Scalable Moment Propagation and Analysis of Variational Distributions for Practical Bayesian Deep Learning

IEEE Trans Neural Netw Learn Syst. 2024 Feb 27:PP. doi: 10.1109/TNNLS.2024.3367363. Online ahead of print.

Abstract

Bayesian deep learning is one of the key frameworks for handling predictive uncertainty. Variational inference (VI), an extensively used inference method, derives predictive distributions by Monte Carlo (MC) sampling. The drawback of MC sampling is its extremely high computational cost compared with that of ordinary deep learning. In contrast, the moment propagation (MP)-based approach derives predictive distributions by propagating the output moments of each layer instead of MC sampling. Because it avoids sampling, MP is expected to achieve faster inference than MC-based approaches. However, the applicability of the MP-based method to deep models has not been explored sufficiently; existing studies have demonstrated the effectiveness of MP only in small toy models. One reason is that training deep models with MP is difficult because the variance of activations grows large. Realizing MP in deep models therefore requires normalization layers, which have not yet been studied in this setting. In addition, it remains difficult to design well-calibrated MP-based models, because the effectiveness of MP-based methods under various variational distributions has not been investigated. In this study, we propose a fast and reliable MP-based Bayesian deep-learning method. First, to train deep-learning models using MP, we introduce a batch normalization layer extended to random variables that prevents the variance of activations from increasing. Second, to identify an appropriate variational distribution for MP, we investigate how the moments of several variational distributions are treated and evaluate the quality of their predictive uncertainty. Experiments on regression tasks demonstrate that the MP-based method provides predictive performance qualitatively and quantitatively equivalent to that of MC-based methods, regardless of the variational distribution. On classification tasks, we show that MP-based deep models can be trained using the extended batch normalization. We also show that the MP-based approach achieves 2.0-2.8 times faster inference than the MC-based approach while maintaining predictive performance. These results can help realize a fast and well-calibrated uncertainty estimation method that can be deployed in a wider range of reliability-aware applications.
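To make the MP idea concrete, the following is a minimal sketch of one MP forward pass under a mean-field Gaussian variational posterior: a Bayesian linear layer, a variance-aware batch normalization step, and a ReLU, each mapping (mean, variance) pairs to (mean, variance) pairs with no sampling. This is an illustrative assumption of how such a pipeline could look, not the paper's implementation; in particular, `batchnorm_mp` is one plausible formulation of batch normalization extended to random variables (standardizing by the batch's total variance), and all function names are hypothetical.

```python
import numpy as np
from scipy.stats import norm

def linear_mp(m_x, v_x, m_w, v_w, m_b, v_b):
    """Propagate moments through y = W x + b with independent factorized
    Gaussian W and b. Shapes: m_x, v_x: (B, D_in); m_w, v_w: (D_out, D_in);
    m_b, v_b: (D_out,)."""
    m_y = m_x @ m_w.T + m_b
    # Var(w x) = Var(w) (E[x]^2 + Var(x)) + E[w]^2 Var(x), summed over inputs
    v_y = (m_x**2 + v_x) @ v_w.T + v_x @ (m_w**2).T + v_b
    return m_y, v_y

def batchnorm_mp(m, v, gamma, beta, eps=1e-5):
    """Assumed batch normalization over random activations: standardize the
    means using the batch's total variance (law of total variance) and
    rescale the per-unit variances by the same factor."""
    mu_b = m.mean(axis=0)
    var_b = (m**2 + v).mean(axis=0) - mu_b**2  # total variance per unit
    scale = gamma / np.sqrt(var_b + eps)
    return (m - mu_b) * scale + beta, v * scale**2

def relu_mp(m, v, eps=1e-8):
    """Closed-form first two moments of ReLU(x) for x ~ N(m, v)."""
    s = np.sqrt(v + eps)
    a = m / s
    m_out = m * norm.cdf(a) + s * norm.pdf(a)
    ex2 = (m**2 + v) * norm.cdf(a) + m * s * norm.pdf(a)
    return m_out, np.maximum(ex2 - m_out**2, 0.0)

# Toy usage: one deterministic MP forward pass over a batch of 4 inputs.
rng = np.random.default_rng(0)
B, D_in, D_out = 4, 3, 2
m_x, v_x = rng.normal(size=(B, D_in)), np.full((B, D_in), 0.1)
m_w, v_w = rng.normal(size=(D_out, D_in)), np.full((D_out, D_in), 0.05)
m_b, v_b = np.zeros(D_out), np.full(D_out, 0.05)
m_h, v_h = linear_mp(m_x, v_x, m_w, v_w, m_b, v_b)
m_h, v_h = batchnorm_mp(m_h, v_h, np.ones(D_out), np.zeros(D_out))
m_h, v_h = relu_mp(m_h, v_h)
print(m_h.shape, v_h.shape)  # (4, 2) (4, 2): per-unit predictive moments
```

The design point this sketch illustrates is why normalization matters for MP: the variance formula in `linear_mp` accumulates terms from both weight and input uncertainty at every layer, so without a rescaling step like `batchnorm_mp` the propagated variances can grow with depth.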