A multivariate adaptive gradient algorithm with reduced tuning efforts

Neural Netw. 2022 Aug:152:499-509. doi: 10.1016/j.neunet.2022.05.016. Epub 2022 May 21.

Abstract

Large neural networks usually perform well on machine learning tasks. However, models that achieve state-of-the-art performance involve an arbitrarily large number of parameters, and their training is therefore very expensive. It is thus desirable to develop methods with small per-iteration costs, fast convergence rates, and reduced tuning. This paper proposes a multivariate adaptive gradient descent method that has these attributes. The proposed method updates every element of the model parameters separately, in a computationally efficient manner, using an adaptive vector-form learning rate, which results in a low per-iteration cost. The adaptive learning rate is computed as the absolute difference between the current and previous model parameters over the difference between the subgradients at the current and previous state estimates. In the deterministic setting, we show that the cost function value converges at a linear rate for smooth and strongly convex cost functions. In both the deterministic and stochastic settings, we show that the gradient converges in expectation at the order of O(1/k) for a non-convex cost function with Lipschitz continuous gradient. In addition, we show that after T iterations, the cost function value at the last iterate scales as O(log(T)/T) for non-smooth strongly convex cost functions. The effectiveness of the proposed method is validated on convex functions, a smooth non-convex function, a non-smooth convex function, and four image classification data sets. These experiments also show that its execution requires hardly any tuning, unlike existing popular optimizers, which entail relatively large tuning efforts. Our empirical results show that the proposed algorithm provides the best overall performance when compared with tuned state-of-the-art optimizers.
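
The element-wise learning rate described above can be illustrated with a minimal sketch. It is not the paper's algorithm: the abstract gives only the ratio of parameter difference to subgradient difference, so the absolute value in the denominator, the bootstrap step, the fallback rate `init_lr`, the guard `eps`, and the function name `elementwise_adaptive_gd` are all assumptions made here for illustration. The test problem is a hypothetical quadratic with a diagonal Hessian.

```python
def elementwise_adaptive_gd(grad, x0, steps=50, init_lr=1e-3, eps=1e-12):
    """Gradient descent with a per-coordinate step |dx_i| / |dg_i|.

    Illustrative sketch only; init_lr and eps are assumed safeguards,
    not quantities taken from the paper.
    """
    x_prev = list(x0)
    g_prev = grad(x_prev)
    # Bootstrap with one plain gradient step, since no previous iterate exists yet.
    x = [xi - init_lr * gi for xi, gi in zip(x_prev, g_prev)]
    for _ in range(steps):
        g = grad(x)
        x_next = []
        for xi, xpi, gi, gpi in zip(x, x_prev, g, g_prev):
            dg = abs(gi - gpi)
            # Fall back to init_lr when the gradient difference (nearly) vanishes.
            lr = abs(xi - xpi) / dg if dg > eps else init_lr
            x_next.append(xi - lr * gi)
        x_prev, g_prev, x = x, g, x_next
    return x


# Hypothetical test problem: f(x) = 0.5 * sum(h_i * x_i**2), minimized at x = 0.
curvatures = [1.0, 10.0, 100.0]

def quadratic_grad(x):
    return [h * xi for h, xi in zip(curvatures, x)]

x_final = elementwise_adaptive_gd(quadratic_grad, [1.0, -2.0, 3.0])
print(x_final)
```

On this separable quadratic each coordinate's ratio equals the inverse of its curvature, so no learning rate has to be tuned per coordinate. That is the intuition behind the "reduced tuning" claim, though this toy case says nothing about behavior on non-separable or stochastic problems.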

Keywords: Adaptive learning rate; Deep learning; Gradient descent optimization.

MeSH terms

  • Algorithms*
  • Machine Learning
  • Neural Networks, Computer*