A mathematical framework for improved weight initialization of neural networks using Lagrange multipliers

Neural Netw. 2023 Sep:166:579-594. doi: 10.1016/j.neunet.2023.07.035. Epub 2023 Aug 3.

Abstract

A good weight initialization is crucial to accelerate the convergence of the weights in a neural network. However, training a neural network is still time-consuming, despite recent advances in weight initialization approaches. In this paper, we propose a mathematical framework for the weight initialization in the last layer of a neural network. We first analytically derive a tight constraint on the weights that accelerates their convergence during the back-propagation algorithm. We then use linear regression and Lagrange multipliers to analytically derive the optimal initial weights and initial bias of the last layer, which minimize the initial training loss subject to the derived tight constraint. We also show that the restrictive assumption of traditional weight initialization algorithms, namely that the expected value of the weights is zero, is redundant for our approach. We first apply our proposed weight initialization approach to a Convolutional Neural Network that predicts the Remaining Useful Life of aircraft engines. The initial training and validation losses are relatively small, the weights do not get stuck in a local optimum, and the convergence of the weights is accelerated. We compare our approach with several benchmark strategies. Compared to the best-performing state-of-the-art initialization strategy (Kaiming initialization), our approach needs 34% fewer epochs to reach the same validation loss. We also apply our approach to ResNets for the CIFAR-100 dataset, combined with transfer learning. Here, the initial accuracy is already at least 53%. This gives a faster weight convergence and a higher test accuracy than the benchmark strategies.
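
The general technique named in the abstract, least-squares fitting of a linear last layer under an equality constraint handled with Lagrange multipliers, can be illustrated with a minimal sketch. The abstract does not state the constraint derived in the paper, so the sketch assumes a placeholder linear constraint `C @ w = d` on the weights; the function name `init_last_layer` and all variable names are hypothetical, and this is not the authors' implementation.

```python
import numpy as np


def init_last_layer(features, targets, C, d):
    """Sketch: least-squares weights and bias for a linear last layer,
    subject to the PLACEHOLDER constraint C @ w = d (an assumption; the
    constraint derived in the paper is not given in the abstract).

    features: (n, k) outputs of the penultimate layer for n training samples
    targets:  (n,)   regression targets
    C:        (m, k) constraint matrix, d: (m,) constraint values
    """
    n, k = features.shape
    m = C.shape[0]

    # Append a column of ones so the bias is fitted together with the weights.
    A = np.hstack([features, np.ones((n, 1))])
    # The placeholder constraint acts on the weights only, not on the bias.
    C_aug = np.hstack([C, np.zeros((m, 1))])

    # Lagrangian: ||A x - y||^2 + lam^T (C_aug x - d), with x = [w, b].
    # Setting its gradient to zero gives the linear KKT system
    #   [2 A^T A   C_aug^T] [x  ]   [2 A^T y]
    #   [C_aug     0      ] [lam] = [d      ]
    kkt = np.block([
        [2.0 * A.T @ A, C_aug.T],
        [C_aug, np.zeros((m, m))],
    ])
    rhs = np.concatenate([2.0 * A.T @ targets, d])

    # Assumes A has full column rank and C has full row rank,
    # so that the KKT matrix is invertible.
    solution = np.linalg.solve(kkt, rhs)
    w, b, lam = solution[:k], solution[k], solution[k + 1:]
    return w, b, lam


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    H = rng.normal(size=(200, 8))           # stand-in penultimate features
    y = H @ rng.normal(size=8) + 0.5        # stand-in targets
    C = np.ones((1, 8))                     # placeholder: weights sum to 1
    d = np.array([1.0])
    w, b, lam = init_last_layer(H, y, C, d)
    print("constraint residual:", C @ w - d)
    print("initial training loss:", np.mean((H @ w + b - y) ** 2))
```

In this sketch the features would come from a forward pass of the training data through the earlier layers, which are initialized separately. Solving the KKT system in one step avoids iterative optimization, in keeping with the analytical derivation the abstract describes; a different constraint, such as a nonlinear one, would require a different solution procedure.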

Keywords: Lagrange function; Linear regression; Neural network training; Remaining useful life; Weight initialization.

MeSH terms

  • Algorithms*
  • Language
  • Learning
  • Linear Models
  • Neural Networks, Computer*