A scalable second order optimizer with an adaptive trust region for neural networks

Neural Netw. 2023 Oct:167:692-705. doi: 10.1016/j.neunet.2023.09.010. Epub 2023 Sep 11.

Abstract

We introduce Tadam (Trust region ADAptive Moment estimation), a new optimizer based on the trust region of a second-order approximation of the loss built from the Fisher information matrix. Second-order approximations provide better gradient estimates, but using them in practice requires large batch sizes to estimate the second-order matrices and to invert them. As a result, second-order methods consume additional memory and impose heavy computational costs from inverting large matrices. To address these challenges, we devised a second-order algorithm that approximates the relevant large matrix, requiring only a marginal increase in memory while minimizing the computational burden. Tadam approximates the loss up to second order using the Fisher information matrix. Because estimating the Fisher information matrix is expensive in both memory and time, Tadam approximates it and reduces the computational cost to O(N). Furthermore, Tadam employs an adaptive trust region scheme to reduce approximation errors and guarantee stability: it evaluates how well each step minimizes the loss function and uses this information to adjust the trust region dynamically. In addition, Tadam adjusts the learning rate internally, even when the learning rate is provided as a fixed constant. We ran several experiments comparing Tadam with Adam, AMSGrad, RAdam, and Nadam, which have the same space and time complexity as Tadam. The results show that Tadam outperforms these benchmarks and finds reasonable solutions quickly and stably.
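The abstract does not give Tadam's update rule. The following is only a minimal sketch of the general idea it describes, assuming a *diagonal* Fisher approximation for the O(N) cost and a textbook trust-region radius update driven by the ratio of actual to predicted loss reduction. It is not the authors' algorithm. The toy logistic-regression problem, the helper names (`grad_and_fisher_diag`, `trust_region_step`), the damping term, and the 0.25/0.75 thresholds are all hypothetical choices made for illustration. For logistic regression the Fisher information coincides with the Hessian, which keeps the quadratic model meaningful in this toy setting.

```python
import math

# Hypothetical toy problem (not from the paper): logistic regression on six
# points, each row being (bias, feature). Labels are not linearly separable.
X = [(1.0, 0.5), (1.0, -1.5), (1.0, 2.0), (1.0, -0.3), (1.0, 1.0), (1.0, -2.0)]
Y = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0]


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(z):
    # Stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def loss(w):
    # Mean negative log-likelihood of the logistic model.
    total = 0.0
    for x, y in zip(X, Y):
        z = sum(wi * xi for wi, xi in zip(w, x))
        total += softplus(z) - y * z
    return total / len(X)


def grad_and_fisher_diag(w):
    # Gradient plus the diagonal of the Fisher information matrix,
    # mean(p * (1 - p) * x_i**2): O(N) storage instead of an N x N matrix.
    g = [0.0] * len(w)
    f = [0.0] * len(w)
    for x, y in zip(X, Y):
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        for i, xi in enumerate(x):
            g[i] += (p - y) * xi
            f[i] += p * (1.0 - p) * xi * xi
    m = len(X)
    return [gi / m for gi in g], [fi / m for fi in f]


def trust_region_step(w, radius, damping=1e-4, eta=0.1, max_radius=10.0):
    g, fisher = grad_and_fisher_diag(w)
    # Damping keeps the step finite where the Fisher diagonal is near zero.
    curv = [fi + damping for fi in fisher]
    # Minimizer of the diagonal model m(d) = g.d + 0.5 * sum(curv_i * d_i**2).
    d = [-gi / ci for gi, ci in zip(g, curv)]
    norm = math.sqrt(sum(di * di for di in d))
    if norm == 0.0:
        return w, radius, 1.0  # already stationary
    scale = min(1.0, radius / norm)  # clip the step to the trust region
    d = [scale * di for di in d]
    predicted = -sum(gi * di + 0.5 * ci * di * di for gi, ci, di in zip(g, curv, d))
    if predicted <= 0.0:
        return w, radius, 0.0
    w_new = [wi + di for wi, di in zip(w, d)]
    rho = (loss(w) - loss(w_new)) / predicted  # actual / predicted reduction
    if rho < 0.25:
        radius *= 0.25  # poor model fit: shrink the region
    elif rho > 0.75 and scale < 1.0:
        radius = min(2.0 * radius, max_radius)  # good fit at the boundary: grow
    return (w_new if rho > eta else w), radius, rho


w, radius = [0.0, 0.0], 1.0
history = [loss(w)]
for step in range(50):
    w, radius, rho = trust_region_step(w, radius)
    history.append(loss(w))
```

A step is accepted only when `rho` exceeds `eta`, so the recorded loss never increases. The effective step length is set by `radius` rather than by a fixed learning rate. This is one possible reading of the abstract's remark about internal learning-rate adjustment, but the paper's actual mechanism is not stated here.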

Keywords: Fisher information matrix; Gradient descent; Neural network; Second order optimizer; Trust region.

MeSH terms

  • Algorithms
  • Learning
  • Neural Networks, Computer*
  • Trust*