Communication-Efficient Stochastic Gradient MCMC for Neural Networks

Chunyuan Li, Changyou Chen, Yunchen Pu, Ricardo Henao, Lawrence Carin

Research output: Chapter in Book/Report/Conference proceeding › Conference contribution

Abstract

Learning probability distributions on the weights of neural networks has recently proven beneficial in many applications. Bayesian methods such as Stochastic Gradient Markov Chain Monte Carlo (SG-MCMC) offer an elegant framework for reasoning about model uncertainty in neural networks. However, these advantages usually come at a high computational cost. We propose accelerating SG-MCMC under the master-worker framework: workers asynchronously and in parallel share responsibility for gradient computations, while the master collects the final samples. To reduce communication overhead, two protocols (downpour and elastic) are developed to allow periodic interaction between the master and workers. We provide a theoretical analysis of the finite-time estimation consistency of posterior expectations, and establish connections to sample thinning. Our experiments on various neural networks demonstrate that the proposed algorithms can greatly reduce training time while achieving comparable (or better) test accuracy/log-likelihood levels, relative to traditional SG-MCMC. When applied to reinforcement learning, the framework naturally provides exploration for asynchronous policy optimization, with encouraging performance improvements.
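
The abstract leaves the protocol details to the paper itself; as a rough, non-authoritative illustration of the idea, the sketch below shows a stochastic gradient Langevin dynamics (SGLD) update, the canonical SG-MCMC method, together with a downpour-flavored worker loop that exchanges parameters with a master only every tau steps. The master object, its pull/push API, and all other names here are hypothetical, introduced only for illustration, and are not the authors' implementation.

    import numpy as np

    def sgld_step(theta, grad_log_post, step_size, rng):
        # One SGLD update: a gradient step on the log-posterior plus
        # Gaussian noise of variance 2*step_size, so the iterates are
        # (approximate) posterior samples rather than a point estimate.
        noise = rng.normal(scale=np.sqrt(2.0 * step_size), size=theta.shape)
        return theta + step_size * grad_log_post(theta) + noise

    def downpour_worker(master, grad_log_post, step_size, tau, n_steps, rng):
        # Hypothetical downpour-style worker loop: run tau local SGLD
        # steps, then exchange with the master (push the accumulated
        # parameter change, pull the current master state). Communicating
        # only every tau steps is what reduces the overhead.
        theta = master.pull()                  # assumed master API: pull()/push(delta)
        anchor = theta.copy()
        for t in range(1, n_steps + 1):
            theta = sgld_step(theta, grad_log_post, step_size, rng)
            if t % tau == 0:
                master.push(theta - anchor)    # send accumulated local progress
                theta = master.pull()          # resync with master state
                anchor = theta.copy()
        return theta

Under the same assumptions, the elastic protocol would, roughly, replace the push/pull resync with a spring-like penalty pulling the worker and master parameters toward each other; the paper gives the exact update rules and their consistency analysis.
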
Original language: English (US)
Title of host publication: 33rd AAAI Conference on Artificial Intelligence, AAAI 2019, 31st Innovative Applications of Artificial Intelligence Conference, IAAI 2019 and the 9th AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2019
Publisher: AAAI Press
Pages: 4173-4180
Number of pages: 8
ISBN (Print): 9781577358091
State: Published - Jan 1 2019
Externally published: Yes
