Learning a World Model With Multitimescale Memory Augmentation

IEEE Trans Neural Netw Learn Syst. 2023 Nov;34(11):8493-8502. doi: 10.1109/TNNLS.2022.3151412. Epub 2023 Oct 27.

Abstract

Model-based reinforcement learning (RL) is regarded as a promising approach to tackling the challenges that hinder model-free RL. The success of model-based RL hinges critically on the quality of the learned dynamics model. However, for many real-world tasks with high-dimensional state spaces, current dynamics prediction models perform poorly at long-term prediction. To address this, we propose a novel two-branch neural network architecture with multi-timescale memory augmentation that handles long-term and short-term memory differently. Specifically, following previous work, we use a recurrent neural network to encode historical observation sequences into a latent space, characterizing the agent's long-term memory. Unlike previous work, we treat the most recent observations as the agent's short-term memory and use them to reconstruct the next frame directly, avoiding compounding errors. This is achieved by a self-supervised optical flow prediction structure that models the action-conditional feature transformation at the pixel level. The reconstructed observation is then augmented with the long-term memory to ensure semantic consistency. Experimental results show that our approach generates visually realistic long-term predictions in DeepMind maze navigation games and outperforms state-of-the-art methods in prediction accuracy by a large margin. Furthermore, we evaluate the usefulness of our world model by using the predicted frames to drive an imagination-augmented exploration strategy that improves a model-free RL controller.
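The full architecture is specified in the article body; the PyTorch-style sketch below only illustrates the two-branch idea summarized above (a recurrent long-term memory, a flow-based warp of the most recent frame, and a fusion step). All module names, layer sizes, the 64x64 input resolution, and the normalized-flow convention are assumptions for illustration, not the authors' implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

def warp(frame, flow):
    # Warp a frame (B, C, H, W) by a dense flow field (B, 2, H, W).
    # Assumption: flow is expressed in normalized [-1, 1] coordinates.
    B, _, H, W = frame.shape
    ys, xs = torch.meshgrid(
        torch.linspace(-1.0, 1.0, H, device=frame.device),
        torch.linspace(-1.0, 1.0, W, device=frame.device),
        indexing="ij",
    )
    base_grid = torch.stack((xs, ys), dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
    grid = base_grid + flow.permute(0, 2, 3, 1)  # (B, H, W, 2)
    return F.grid_sample(frame, grid, align_corners=True)

class TwoBranchWorldModel(nn.Module):
    def __init__(self, obs_channels=3, action_dim=8, latent_dim=256):
        super().__init__()
        # Long-term branch: encode the observation history into a recurrent latent state.
        self.encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
        )
        self.rnn = nn.GRUCell(64 * 16 * 16 + action_dim, latent_dim)  # assumes 64x64 observations
        # Short-term branch: predict an action-conditional flow field for the latest frame.
        self.flow_head = nn.Sequential(
            nn.Conv2d(obs_channels + action_dim, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 2, 3, padding=1),
        )
        # Fusion: refine the warped frame with the long-term latent for semantic consistency.
        self.fuse = nn.Sequential(
            nn.Conv2d(obs_channels + latent_dim, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, obs_channels, 3, padding=1),
        )

    def forward(self, obs, action, hidden):
        B, C, H, W = obs.shape
        # Update the long-term (recurrent) memory from the encoded observation and action.
        feat = self.encoder(obs)
        hidden = self.rnn(torch.cat([feat, action], dim=-1), hidden)
        # Short-term branch: warp the most recent frame with the predicted flow field.
        a_map = action.view(B, -1, 1, 1).expand(-1, -1, H, W)
        flow = self.flow_head(torch.cat([obs, a_map], dim=1))
        warped = warp(obs, flow)
        # Augment the reconstruction with the broadcast long-term latent.
        h_map = hidden.view(B, -1, 1, 1).expand(-1, -1, H, W)
        next_obs = self.fuse(torch.cat([warped, h_map], dim=1))
        return next_obs, hidden

In this reading, long-horizon rollouts would feed each predicted frame back in as the next observation while the recurrent state carries the long-term context; the details of training losses and the exact fusion mechanism are left to the article itself.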