Statistical mechanics of continual learning: Variational principle and mean-field potential

Phys Rev E. 2023 Jul;108(1-1):014309. doi: 10.1103/PhysRevE.108.014309.

Abstract

One obstacle to artificial general intelligence is continual learning of multiple tasks of different natures. Recently, various heuristic tricks, from both machine learning and neuroscience perspectives, have been proposed, but they lack a unified theoretical foundation. Here, we focus on continual learning in single-layered and multilayered neural networks with binary weights. We propose a variational Bayesian learning setting in which the networks are trained in field space rather than in the discrete-weight space, where gradients are ill-defined. Weight uncertainty is naturally incorporated in this setting, and it modulates synaptic resources among tasks. From a physics perspective, we translate variational continual learning into a Franz-Parisi thermodynamic potential framework, in which knowledge of previous tasks serves both as a prior probability and as a reference. We thus interpret continual learning of the binary perceptron in a teacher-student setting as a Franz-Parisi potential computation. The learning performance can then be studied analytically with mean-field order parameters, whose predictions agree with numerical experiments using stochastic gradient descent. Based on the variational principle and a Gaussian field approximation of the internal preactivations in hidden layers, we also derive a learning algorithm that accounts for weight uncertainty. This algorithm solves continual learning with binary weights in multilayered neural networks and performs better than the currently available metaplasticity algorithm, in which binary synapses carry hidden continuous states and synaptic plasticity is modulated by a heuristic regularization function. Our principled framework also connects to elastic weight consolidation, weight-uncertainty-modulated learning, and neuroscience-inspired metaplasticity, providing a theoretically grounded method for real-world multitask learning with deep networks.
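The field-space idea can be illustrated for a single-layer binary perceptron with a minimal sketch. It assumes a factorized distribution q(w_i) ∝ exp(h_i w_i) over ±1 weights (so m_i = tanh h_i and Var w_i = 1 - m_i²), a Gaussian approximation of the preactivation w·x/√N with mean μ and variance σ², a data term -log Φ(yμ/σ), and a KL penalty toward the previous task's fields h' acting as the prior. Training uses plain full-batch gradient descent on h. The data term, the function names (`field_stats`, `kl_fields`, `loss_and_grad`, `train_task`, `make_task`, `accuracy`), and the knobs `lam`, `lr`, `steps` are illustrative assumptions. This is not the paper's algorithm or its multilayer derivation.

```python
import math
import numpy as np

EPS = 1e-6      # keeps the preactivation variance strictly positive
U_CLIP = 30.0   # |u| is clipped here to avoid underflow of Phi(u)
SQRT2 = math.sqrt(2.0)
_erfc = np.vectorize(math.erfc, otypes=[float])


def _log_cosh(x):
    # numerically stable log(cosh x) = logaddexp(x, -x) - log 2
    return np.logaddexp(x, -x) - math.log(2.0)


def field_stats(h, X):
    """Gaussian-field statistics of z = w.x / sqrt(N) for w_i = +-1 drawn from q_h."""
    N = X.shape[1]
    m = np.tanh(h)                 # mean of w_i under q(w_i) ∝ exp(h_i w_i)
    v = 1.0 - m**2                 # variance of w_i
    mu = X @ m / np.sqrt(N)        # mean of each preactivation
    var = (X**2) @ v / N + EPS     # variance of each preactivation
    return m, v, mu, var


def kl_fields(h, h_prior):
    """KL(q_h || q_prior) for factorized distributions q(w_i) ∝ exp(h_i w_i)."""
    m = np.tanh(h)
    return float(np.sum((h - h_prior) * m - _log_cosh(h) + _log_cosh(h_prior)))


def loss_and_grad(h, X, y, h_prior, lam):
    """Mean of -log Phi(y mu / sigma) plus (lam / P) * KL, and its gradient in h."""
    P, N = X.shape
    m, v, mu, var = field_stats(h, X)
    sigma = np.sqrt(var)
    u_raw = y * mu / sigma
    u = np.clip(u_raw, -U_CLIP, U_CLIP)
    Phi = 0.5 * _erfc(-u / SQRT2)
    loss = -np.mean(np.log(Phi)) + lam * kl_fields(h, h_prior) / P
    # d(-log Phi(u))/du = -phi(u)/Phi(u); zero where u was clipped
    phi = np.exp(-0.5 * u**2) / math.sqrt(2.0 * math.pi)
    g_u = np.where(np.abs(u_raw) < U_CLIP, -phi / Phi, 0.0)
    # du/dh_i = y v_i [x_i / (sqrt(N) sigma) + mu m_i x_i^2 / (N sigma^3)]
    a = y * g_u / sigma / P
    b = y * g_u * mu / sigma**3 / P
    grad_data = v * (a @ X / np.sqrt(N) + m * (b @ X**2) / N)
    grad_kl = (h - h_prior) * v    # d KL / d h_i = (h_i - h'_i)(1 - m_i^2)
    return float(loss), grad_data + lam * grad_kl / P


def train_task(X, y, h_init, h_prior, lam, lr=1.0, steps=200):
    """Full-batch gradient descent on the fields h."""
    h = h_init.copy()
    for _ in range(steps):
        _, g = loss_and_grad(h, X, y, h_prior, lam)
        h -= lr * g
    return h


def make_task(rng, N, P):
    """Hypothetical teacher-student task: labels from a random +-1 teacher."""
    teacher = rng.choice([-1.0, 1.0], size=N)
    X = rng.standard_normal((P, N))
    return X, np.where(X @ teacher >= 0, 1.0, -1.0)


def accuracy(h, X, y):
    w = np.where(h >= 0, 1.0, -1.0)   # binary weights read out from the field signs
    return float(np.mean(np.where(X @ w >= 0, 1.0, -1.0) == y))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    N, P = 100, 300
    task_a, task_b = make_task(rng, N, P), make_task(rng, N, P)
    h0 = np.zeros(N)
    h_a = train_task(*task_a, h_init=h0, h_prior=h0, lam=1.0)
    # task B: the fields learned on task A serve as prior and initialization
    h_ab = train_task(*task_b, h_init=h_a, h_prior=h_a, lam=20.0)
    print("after A: acc A = %.2f" % accuracy(h_a, *task_a))
    print("after B: acc A = %.2f, acc B = %.2f"
          % (accuracy(h_ab, *task_a), accuracy(h_ab, *task_b)))
```

To second order in h - h', the KL term equals ½ Σ_i (1 - m_i²)(h_i - h'_i)². This is a quadratic penalty weighted by the Fisher information of each field, which resembles the form of elastic weight consolidation. The value of `lam` sets how strongly the previous task's fields constrain the new one.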