Memory Model: MDN-RNN
Overview
The Memory Model in World Models uses a Mixture Density Network - Recurrent Neural Network (MDN-RNN) to predict the distribution of the next latent state z_{t+1}, given the current latent z_t, the action a_t, and the RNN's hidden state h_t.
Why MDN-RNN?
The environment's dynamics are often:
- Stochastic: The same action can lead to different outcomes
- Multimodal: Multiple valid future states may exist
- Temporal: Current state depends on history
An MDN-RNN addresses these challenges by:
- Using an RNN to capture temporal dependencies
- Using an MDN to model multimodal distributions
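The combination above can be sketched in a few lines. This is a minimal, hypothetical single-step forward pass: the sizes, the randomly initialized weights, and the use of a vanilla tanh RNN cell (the paper uses an LSTM) are all assumptions made for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: latent z, action a, hidden h, K mixture components.
Z, A, H, K = 4, 2, 8, 5

# Random weights stand in for trained parameters.
W_h = rng.normal(0, 0.1, (H, Z + A + H))    # cell weights over [z, a, h]
W_out = rng.normal(0, 0.1, (3 * K * Z, H))  # heads for pi, mu, log-sigma

def mdn_rnn_step(z, a, h):
    """One step: update the hidden state, emit mixture parameters for z_{t+1}."""
    x = np.concatenate([z, a, h])
    h_new = np.tanh(W_h @ x)                # vanilla RNN cell (LSTM in the paper)
    out = (W_out @ h_new).reshape(3, K, Z)
    logits, mu, log_sigma = out
    pi = np.exp(logits - logits.max(0))     # softmax over the K components
    pi /= pi.sum(axis=0, keepdims=True)
    sigma = np.exp(log_sigma)               # keep standard deviations positive
    return h_new, pi, mu, sigma

h = np.zeros(H)
h, pi, mu, sigma = mdn_rnn_step(rng.normal(size=Z), rng.normal(size=A), h)
print(pi.shape, mu.shape, sigma.shape)  # (5, 4) each
```

The RNN carries the temporal context in h; the output head turns that context into a full distribution over the next latent rather than a single point prediction.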
Mixture Density Networks
An MDN models the output distribution as a mixture of Gaussians rather than a single point estimate, allowing the model to:
- Capture uncertainty in predictions
- Model multimodal distributions
- Handle stochastic environments
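The MDN is trained by minimizing the negative log-likelihood of the observed next state under the predicted mixture. Below is a minimal sketch for a 1-D mixture; the bimodal example values (modes at 0 and 5) are illustrative, not from the source.

```python
import numpy as np

def mdn_nll(pi, mu, sigma, y):
    """Negative log-likelihood of target y under a 1-D Gaussian mixture."""
    # log N(y | mu_k, sigma_k) for each component k
    log_norm = -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2
    # weight by mixture coefficients, then log-sum-exp over components
    log_probs = np.log(pi) + log_norm
    m = log_probs.max()
    return -(m + np.log(np.exp(log_probs - m).sum()))

# Two components, one near 0 and one near 5: a bimodal target distribution
# that a single Gaussian (whose mean would sit near 2.5) cannot represent.
pi = np.array([0.5, 0.5]); mu = np.array([0.0, 5.0]); sigma = np.array([1.0, 1.0])
print(mdn_nll(pi, mu, sigma, 0.0) < mdn_nll(pi, mu, sigma, 2.5))  # True
```

Note that the point midway between the modes is *less* likely than either mode; this is exactly the multimodality that a unimodal predictor would average away.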
The Role of Hidden State
The LSTM hidden state h serves multiple purposes:
- Memory: Stores information about past observations
- Context: Provides temporal context for predictions
- Belief State: Represents the agent's belief about the world
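In World Models the controller consumes this hidden state directly, acting on the concatenation [z_t, h_t] through a simple linear layer. The sketch below assumes hypothetical sizes and random stand-in weights (the real parameters are found with evolution strategies such as CMA-ES).

```python
import numpy as np

rng = np.random.default_rng(1)
Z, H, A = 4, 8, 2  # hypothetical latent, hidden, and action sizes

# Random stand-ins for the trained linear controller parameters.
W_c = rng.normal(0, 0.1, (A, Z + H))
b_c = np.zeros(A)

def controller(z, h):
    # The hidden state h supplies temporal context (memory, belief state)
    # that the single-frame latent z alone cannot provide.
    return np.tanh(W_c @ np.concatenate([z, h]) + b_c)

a = controller(rng.normal(size=Z), rng.normal(size=H))
print(a.shape)  # (2,)
```

Keeping the controller this small is deliberate: the heavy lifting (perception and memory) lives in the vision model and the MDN-RNN, so the policy itself stays cheap to optimize.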
Dream Training
A key application of the MDN-RNN is dream training:
- Initialize with a real observation
- Sample actions from the controller
- Use the MDN-RNN to sample predicted next latent states
- Train the controller entirely inside this "dream", without stepping the real environment
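The loop above can be sketched as follows. Both `dream_step` and `controller` are placeholders invented for illustration: the real versions are the trained MDN-RNN (which would also carry a hidden state and a sampling temperature) and the trained policy.

```python
import numpy as np

rng = np.random.default_rng(2)
Z, A, T = 4, 2, 10  # hypothetical latent size, action size, rollout length

def dream_step(z, a, rng):
    """Placeholder for the trained MDN-RNN: sample z_{t+1} from a
    hypothetical 2-component mixture around simple made-up dynamics."""
    mu = np.stack([z + 0.1, z - 0.1])         # two modes (K = 2)
    k = rng.integers(2)                       # pick a component (uniform pi here)
    return mu[k] + 0.05 * rng.normal(size=Z)  # add Gaussian noise (sigma = 0.05)

def controller(z):
    """Placeholder policy; the real controller also sees the hidden state h."""
    return np.tanh(z[:A])

z = rng.normal(size=Z)          # stand-in for the latent of a real observation
trajectory = [z]
for t in range(T):
    a = controller(z)           # sample an action from the controller
    z = dream_step(z, a, rng)   # predict the next latent entirely in the dream
    trajectory.append(z)
print(len(trajectory))  # 11
```

Because every step after the first comes from the model rather than the environment, the controller can be trained on cheap imagined rollouts; the risk is that it exploits model errors, which the paper mitigates by raising the sampling temperature.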