The first neural world model: sequential latent variables
Introduce latents because pixels are too big to plan in.
In the last blog, we summarized the classical world models and introduced the need for a learnable latent variable in the presence of unknown dynamics.
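Till now, our model was a classical state-space model whose state equations and observation model were both known:
$$x_{t+1} \sim p(x_{t+1} \mid x_t, a_t), \qquad o_t \sim p(o_t \mid x_t)$$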
This is rigid: it only works when we can write down the dynamics and the observation equations ourselves.
Here, we accept that we don’t know the state equations. We must learn a latent dynamics model. The high-dimensional observations (pixels) force two changes:
the observation model p(o_t|x_t) must be learned (decoder).
inference p(x_t|o_≤t,a_<t) must be amortized (encoder).
The core question becomes:
How do we learn a latent dynamical system whose latent state is (1) predictive, (2) trainable from pixels, and (3) usable for control?
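We have sequences of observations, actions, and rewards from an environment or simulator:
$$(o_1, a_1, r_1),\ (o_2, a_2, r_2),\ \dots,\ (o_T, a_T, r_T)$$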
Think of a replay buffer: chunks of length T, e.g., T = 50. Our job is to learn a model that can answer counterfactual questions: if I’m in “the situation implied by history” and I take action a_t, what distribution over next observations and rewards should I expect?
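A minimal sketch of how such fixed-length training chunks might be drawn from stored experience, assuming each episode is a Python list of (o_t, a_t, r_t) tuples. The function name `sample_chunks` and the data layout are hypothetical, chosen only for illustration.

```python
import random

def sample_chunks(episode, chunk_len, num_chunks, rng):
    """Sample contiguous windows of (o_t, a_t, r_t) tuples from one episode.

    `episode` is a list of (obs, action, reward) tuples; this layout is a
    hypothetical choice, not a specific library's format.
    """
    if len(episode) < chunk_len:
        raise ValueError("episode shorter than chunk length")
    chunks = []
    for _ in range(num_chunks):
        # Any start index that leaves room for a full window is valid.
        start = rng.randrange(len(episode) - chunk_len + 1)
        chunks.append(episode[start:start + chunk_len])
    return chunks

# Toy episode of 200 steps with placeholder observations, actions, rewards.
episode = [(f"frame_{t}", t % 4, 0.0) for t in range(200)]
batch = sample_chunks(episode, chunk_len=50, num_chunks=8, rng=random.Random(0))
print(len(batch), len(batch[0]))  # 8 50
```

Sampling contiguous windows (rather than single transitions) matters here because the model has to learn how the latent state evolves across many consecutive steps.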
But we don’t want to store the whole history h_t; we want a compact “state”.
Why not just predict pixels directly?
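A naive world model would regress the next frame directly from the history of frames and actions:
$$\hat{o}_{t+1} = f_\theta(o_{\le t}, a_{\le t}), \qquad \text{trained with } \lVert \hat{o}_{t+1} - o_{t+1} \rVert^2$$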
If we do this with a transformer, it can work for short horizons, but it has three deep problems.
Problem A: multi-modality becomes blurred
If two futures are possible from the same current frame (e.g., an occluded car might appear on the left or on the right), the MSE-optimal prediction is their average, which is not a valid future. So “pixel regression” learns the mean of the possible futures, which shows up as blur.
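To see the averaging effect in numbers, here is a toy one-pixel version. It assumes the two futures are equally likely and take pixel values +1 and -1; these values are illustrative, not from any real environment.

```python
# Two equally likely futures for one pixel: the occluded car appears
# on the left (pixel = +1) or on the right (pixel = -1).
futures = [1.0, -1.0]
probs = [0.5, 0.5]

def expected_mse(prediction):
    """Expected squared error of a single deterministic prediction."""
    return sum(p * (f - prediction) ** 2 for f, p in zip(futures, probs))

# Scan candidate predictions on a grid and keep the one with the lowest expected MSE.
candidates = [i / 100 for i in range(-150, 151)]
best = min(candidates, key=expected_mse)
print(best)               # 0.0 (the average, which matches neither future)
print(expected_mse(1.0))  # 2.0 (committing to a real future costs more)
print(expected_mse(best)) # 1.0
```

Here the expected MSE works out to 1 + p², so the regressor is pushed to the blurry midpoint rather than to either sharp outcome.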
Problem B: partial observability requires “belief”, not a point
In partially observable Markov decision processes, the correct internal state is a distribution over latent states, not a deterministic vector. Direct pixel predictors output a point estimate, so they represent this ambiguity poorly.
Problem C: planning needs a state that evolves cleanly under actions
Planning requires rolling forward under imagined actions for many steps. Pixel space is huge, so small prediction errors compound over the rollout and quickly explode. So we want a latent state that is small, predictive, action-conditioned, and (ideally) carries uncertainty. This is exactly what the classical state-space model did, except now o_t is pixels and the observation model is unknown.
The minimal fix: introduce a latent state
We posit a latent variable z_t at every time step. You can interpret z_t as the compact internal state/belief that explains the observation and predicts the future under actions. Together with a dynamics model and an observation model over z_t, this defines a generative model (the “neural world model”).
Building the generative story…
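First of all, sample the initial latent:
$$z_1 \sim p(z_1)$$
For each time step t, emit an observation and transition under the action:
$$o_t \sim p(o_t \mid z_t), \qquad z_{t+1} \sim p(z_{t+1} \mid z_t, a_t)$$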
This is just a latent state-space model. We still need to figure out: how do we learn it from data?
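A minimal sketch of ancestral sampling from this generative story, assuming a scalar latent with linear-Gaussian dynamics and a linear-Gaussian observation model. In a neural world model these linear maps would be neural networks that output the Gaussian parameters; all constants and names here are hypothetical.

```python
import random

# Hypothetical scalar latent state-space model (every constant is illustrative):
#   z_1     ~ N(0, 1)                      initial latent
#   o_t     ~ N(C * z_t, R)                observation model ("decoder")
#   z_{t+1} ~ N(A * z_t + B * a_t, Q)      dynamics prior
A, B, C = 0.9, 0.5, 2.0
Q, R = 0.1, 0.05  # variances

def sample_trajectory(actions, rng):
    """Ancestral sampling of the generative story for len(actions) steps."""
    z = rng.gauss(0.0, 1.0)                              # z_1 ~ p(z_1)
    latents, observations = [], []
    for a in actions:
        latents.append(z)
        observations.append(rng.gauss(C * z, R ** 0.5))  # o_t ~ p(o_t | z_t)
        z = rng.gauss(A * z + B * a, Q ** 0.5)           # z_{t+1} ~ p(z_{t+1} | z_t, a_t)
    return latents, observations

zs, obs = sample_trajectory([1.0, 0.0, -1.0, 0.0, 1.0], random.Random(42))
print(len(zs), len(obs))  # 5 5
```

Sampling is the easy direction; the hard part, addressed next, is going backwards from observed o_t to the latents that produced them.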
Before we do that, let’s get some concepts in order.
The obstacle: we never observe the latent state (z_t)
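We only see o_t. Training aims to maximize the log-likelihood of the observed sequence, which requires integrating out every latent trajectory:
$$\log p(o_{1:T} \mid a_{1:T}) = \log \int p(z_1) \prod_{t=1}^{T} p(o_t \mid z_t) \prod_{t=1}^{T-1} p(z_{t+1} \mid z_t, a_t)\, dz_{1:T}$$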
This integral is the “intractable marginalization over latent trajectories”. So we make the standard move: approximate the posterior with an inference model.
The single most important concept: two distributions
At each step, we will maintain two distributions over the latent state:
Prior (prediction-only): What the dynamics model predicts before seeing the next observation:
$$p(z_t \mid z_{t-1}, a_{t-1})$$
This is what our model “imagines”.
Posterior (inference, corrected by observation): What we believe after incorporating the actual observation:
$$q(z_t \mid z_{t-1}, a_{t-1}, o_t)$$
This is our learned “filter update”.
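For intuition, the predict-then-correct pattern can be sketched with the exact filter for a scalar linear-Gaussian model (a one-dimensional Kalman filter). This swaps in exact Bayesian updates where a world model would use a learned prior network and a learned encoder; the constants `A`, `B`, `C`, `Q`, `R` are hypothetical.

```python
# Exact predict/update for a scalar linear-Gaussian model (a Kalman filter).
# It stands in for the learned prior network and learned encoder; all
# constants are hypothetical.
A, B, C = 0.9, 0.5, 2.0   # dynamics, action, and observation coefficients
Q, R = 0.1, 0.05          # process and observation noise variances

def prior_step(mean, var, action):
    """Prior: push the current belief through the dynamics (no observation yet)."""
    return A * mean + B * action, A * A * var + Q

def posterior_step(prior_mean, prior_var, obs):
    """Posterior: correct the prior using the new observation."""
    gain = prior_var * C / (C * C * prior_var + R)
    mean = prior_mean + gain * (obs - C * prior_mean)
    var = (1.0 - gain * C) * prior_var
    return mean, var

m, v = 0.0, 1.0                      # current belief about z_t
pm, pv = prior_step(m, v, action=1.0)
qm, qv = posterior_step(pm, pv, obs=1.6)
print(round(pm, 3), round(pv, 3))    # 0.5 0.91
print(round(qm, 3), round(qv, 3))    # 0.796 0.012
```

The posterior is both shifted toward what the observation implies and much sharper than the prior; a learned world model tries to make its prior good enough that this correction is small.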
World models are trained by forcing the prior (imagination) to match the posterior (reality-corrected belief), while also reconstructing observations.
That’s it, everything else is details!
Key equations: ELBO for sequential latents (deriving the training objective)
Please read the equations carefully; they are simple and easy to follow.
How do we train this? We cannot maximize log p(o_{1:T}|a_{1:T}) directly, because it requires integrating over all possible latent trajectories (an intractable problem!). Instead, we maximize the Evidence Lower Bound (ELBO), also called the variational lower bound: a lower bound on the log-likelihood of the observed data.
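Why the ELBO? It replaces the intractable integral with an expectation under q that we can estimate by sampling, and it never exceeds log p(o_{1:T}|a_{1:T}), so raising it raises a guaranteed floor on the likelihood.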
For a single time step t, the loss function L has three parts:
Let’s decode the KL term. This is the most critical equation in World Models.
q(.) (Posterior): the distribution given the current observation
p(.) (Prior): the distribution given only the past
Essentially, we want the “prior” to be close to the “posterior”. Why? Because when the model imagines future steps (for example, while planning at deployment), the future observation o_t does not exist yet. We only have the prior. If the prior accurately matches the posterior, our “dreams” (p) match “reality” (q).
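Assuming both distributions are diagonal Gaussians (a common choice, but an assumption here), the KL between posterior and prior has a closed form. A minimal sketch; the function name `kl_diag_gaussians` and the example numbers are hypothetical.

```python
import math

def kl_diag_gaussians(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, diag var_q) || N(mu_p, diag var_p) ), summed over dimensions."""
    total = 0.0
    for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p):
        total += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return total

posterior = ([0.8, -0.2], [0.05, 0.10])  # q: sharp, observation-corrected belief
prior = ([0.5, 0.0], [0.20, 0.20])       # p: broader, prediction-only belief
print(kl_diag_gaussians(*posterior, *prior))  # positive: the dream disagrees with reality
print(kl_diag_gaussians(*prior, *prior))      # 0.0: perfect agreement
```

Note the direction: the KL is taken from the posterior q to the prior p, matching the term in the objective.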
A simple derivation
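We start with what we want to maximize, the log-likelihood of the observed sequence:
$$\log p(o_{1:T} \mid a_{1:T}) = \log \int p(o_{1:T}, z_{1:T} \mid a_{1:T})\, dz_{1:T}$$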
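Multiply and divide inside the integral by an approximate posterior q(z_{1:T}|o_{1:T},a_{1:T}):
$$\log p(o_{1:T} \mid a_{1:T}) = \log \int q(z_{1:T} \mid o_{1:T}, a_{1:T})\, \frac{p(o_{1:T}, z_{1:T} \mid a_{1:T})}{q(z_{1:T} \mid o_{1:T}, a_{1:T})}\, dz_{1:T}$$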
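Apply Jensen’s inequality (log is concave, so log E[X] ≥ E[log X]):
$$\log p(o_{1:T} \mid a_{1:T}) \ge \mathbb{E}_{q(z_{1:T} \mid o_{1:T}, a_{1:T})}\left[\log \frac{p(o_{1:T}, z_{1:T} \mid a_{1:T})}{q(z_{1:T} \mid o_{1:T}, a_{1:T})}\right]$$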
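Now, let’s expand the first term using the model factorization:
$$\log p(o_{1:T}, z_{1:T} \mid a_{1:T}) = \sum_{t=1}^{T} \log p(o_t \mid z_t) + \log p(z_{1:T} \mid a_{1:T})$$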
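So, the ELBO becomes:
$$\mathrm{ELBO} = \sum_{t=1}^{T} \mathbb{E}_{q}\big[\log p(o_t \mid z_t)\big] - \mathrm{KL}\big(q(z_{1:T} \mid o_{1:T}, a_{1:T}) \,\|\, p(z_{1:T} \mid a_{1:T})\big)$$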
Let’s understand these two terms intuitively
Reconstruction Loss: Forces the latent state (z_t) to carry enough information to render o_t.
KL to the prior: Forces the inferred latent trajectory to be explainable by the dynamics model. Essentially, this is the “physics constraint” in the latent space.
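The bound can be checked numerically on a model small enough to integrate exactly. The sketch below assumes a hypothetical single-step toy, z ~ N(0, 1) and o | z ~ N(z, SIGMA2), not a world model. The ELBO (reconstruction minus KL to the prior) equals log p(o) when q is the exact posterior and falls below it otherwise; the gap is the KL from q to the true posterior.

```python
import math

# Toy single-step model (illustrative only):
#   z ~ N(0, 1),   o | z ~ N(z, SIGMA2)
SIGMA2 = 0.5

def log_evidence(o):
    """Exact log p(o): the marginal of o is N(0, 1 + SIGMA2)."""
    var = 1.0 + SIGMA2
    return -0.5 * (math.log(2 * math.pi * var) + o * o / var)

def elbo(o, m, s2):
    """ELBO for q(z) = N(m, s2): expected log-likelihood minus KL(q || prior)."""
    # E_q[log N(o; z, SIGMA2)], using E_q[(o - z)^2] = (o - m)^2 + s2
    recon = -0.5 * (math.log(2 * math.pi * SIGMA2) + ((o - m) ** 2 + s2) / SIGMA2)
    # KL( N(m, s2) || N(0, 1) )
    kl = 0.5 * (s2 + m * m - 1.0 - math.log(s2))
    return recon - kl

o = 1.2
exact_m = o / (1.0 + SIGMA2)        # true posterior mean
exact_s2 = SIGMA2 / (1.0 + SIGMA2)  # true posterior variance
print(log_evidence(o))
print(elbo(o, exact_m, exact_s2))   # equals log p(o): the bound is tight
print(elbo(o, 0.0, 1.0))            # q = prior: strictly lower
```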
To deeply understand the KL term, let’s decompose it per time step.
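Pick a filtering factorization (with q(z_1|z_0,a_0,o_1) ≡ q(z_1|o_1)):
$$q(z_{1:T} \mid o_{1:T}, a_{1:T}) = \prod_{t=1}^{T} q(z_t \mid z_{t-1}, a_{t-1}, o_t)$$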
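and the Markov prior (with p(z_1|z_0,a_0) ≡ p(z_1)):
$$p(z_{1:T} \mid a_{1:T}) = \prod_{t=1}^{T} p(z_t \mid z_{t-1}, a_{t-1})$$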
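Then, the trajectory KL decomposes as:
$$\mathrm{KL}\big(q(z_{1:T} \mid o_{1:T}, a_{1:T}) \,\|\, p(z_{1:T} \mid a_{1:T})\big) = \sum_{t=1}^{T} \mathbb{E}_{q(z_{t-1})}\Big[\mathrm{KL}\big(q(z_t \mid z_{t-1}, a_{t-1}, o_t) \,\|\, p(z_t \mid z_{t-1}, a_{t-1})\big)\Big]$$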
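This is exactly why implementations can compute the KL one step at a time. The training objective in operational form (we maximize it, i.e., minimize its negative):
$$\mathrm{ELBO} = \sum_{t=1}^{T} \Big( \mathbb{E}_{q}\big[\log p(o_t \mid z_t)\big] - \mathbb{E}_{q(z_{t-1})}\Big[\mathrm{KL}\big(q(z_t \mid z_{t-1}, a_{t-1}, o_t) \,\|\, p(z_t \mid z_{t-1}, a_{t-1})\big)\Big] \Big)$$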
This looks abstract until we realize:
q(.) is the encoder distribution (posterior)
p(.) is the dynamics distribution (prior)
So we are literally penalizing the mismatch between what our model predicts and what reality forces us to infer.
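To make the operational objective concrete, here is a minimal sketch of a single-sample estimate of the negative ELBO for one sequence. It assumes scalar Gaussian latents and hypothetical linear stand-ins (`prior`, `encoder`, `decoder_log_prob`) for what would be neural networks. Plain Python has no autodiff, so this only evaluates the loss; in practice the same computation would be written in a framework that can differentiate it.

```python
import math
import random

# Hypothetical scalar stand-ins; in a real world model each is a neural network.
def prior(z_prev, a_prev):                  # p(z_t | z_{t-1}, a_{t-1}) -> (mean, var)
    return 0.9 * z_prev + 0.5 * a_prev, 0.1

def encoder(z_prev, a_prev, o_t):           # q(z_t | z_{t-1}, a_{t-1}, o_t) -> (mean, var)
    mean, var = prior(z_prev, a_prev)
    return 0.5 * mean + 0.25 * o_t, 0.5 * var

def decoder_log_prob(o_t, z_t, var=0.05):   # log p(o_t | z_t), Gaussian with mean 2 * z_t
    return -0.5 * (math.log(2 * math.pi * var) + (o_t - 2.0 * z_t) ** 2 / var)

def kl_gauss(mq, vq, mp, vp):               # KL( N(mq, vq) || N(mp, vp) )
    return 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)

def sequence_loss(observations, actions, rng):
    """Single-sample estimate of -ELBO = sum_t [KL_t - log p(o_t | z_t)]."""
    z_prev, a_prev = 0.0, 0.0  # assumption: a fixed start state stands in for p(z_1)
    recon_total, kl_total = 0.0, 0.0
    for o_t, a_t in zip(observations, actions):
        pm, pv = prior(z_prev, a_prev)
        qm, qv = encoder(z_prev, a_prev, o_t)
        z_t = qm + math.sqrt(qv) * rng.gauss(0.0, 1.0)  # reparameterized sample from q
        recon_total += decoder_log_prob(o_t, z_t)
        kl_total += kl_gauss(qm, qv, pm, pv)            # per-step KL, as in the decomposition
        z_prev, a_prev = z_t, a_t                       # posterior sample feeds the next step
    return kl_total - recon_total, recon_total, kl_total

loss, recon, kl = sequence_loss([0.4, 1.1, 0.9, -0.3], [1.0, 0.0, -1.0, 0.0], random.Random(0))
print(f"-ELBO={loss:.3f}  recon={recon:.3f}  kl={kl:.3f}")
```

Because the next step’s prior is evaluated at a sample from the posterior, each per-step KL term is a one-sample estimate of the expectation over q(z_{t-1}) in the objective.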
To summarize, we discussed:
prior rollout (imagination)   ----->   p(z_{t+1}|z_t,a_t)
          |                                   |
          | (observe o_{t+1})                 | (no observation)
          v                                   v
posterior correction          ----->   q(z_{t+1}|...,o_{t+1})

Following this line of thinking, a paper introduced Variational Recurrent Neural Networks (VRNN).
Briefly: VRNN uses an RNN hidden state h_t as a memory of the past and samples a latent z_t at each step.
The RNN gives a flexible memory backbone, since a purely Markov latent state can be too weak to summarize the past. The catch is that a powerful RNN can learn to reconstruct o_t without needing z_t, causing posterior collapse: the model minimizes the loss while ignoring the latent state (the posterior simply copies the prior, so the KL term drops to zero).
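A minimal sketch of one VRNN-style step, assuming scalar states and hypothetical linear/tanh maps in place of the networks. The prior and the encoder are both conditioned on h_{t-1}, the decoder sees both z_t and h_{t-1}, and the recurrence consumes o_t and z_t. The action input to the recurrence is our addition for the control setting (the original VRNN models sequences without actions). The hypothetical `collapsed_encoder` shows what posterior collapse looks like at the level of a single step.

```python
import math
import random

# Hypothetical scalar VRNN-style cell; every weight below is illustrative.
def prior_net(h_prev):                   # p(z_t | h_{t-1}) -> (mean, var)
    return 0.5 * h_prev, 0.2

def encoder_net(h_prev, o_t):            # q(z_t | h_{t-1}, o_t) -> (mean, var)
    return 0.3 * h_prev + 0.6 * o_t, 0.05

def decoder_mean(h_prev, z_t):           # mean of p(o_t | z_t, h_{t-1})
    return 0.7 * h_prev + 1.2 * z_t

def recurrence(h_prev, o_t, z_t, a_t):   # h_t = f(h_{t-1}, o_t, z_t, a_t); a_t is our addition
    return math.tanh(0.8 * h_prev + 0.3 * o_t + 0.4 * z_t + 0.2 * a_t)

def kl_gauss(mq, vq, mp, vp):            # KL( N(mq, vq) || N(mp, vp) )
    return 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)

def vrnn_step(h_prev, o_t, a_t, rng, encoder=encoder_net):
    pm, pv = prior_net(h_prev)
    qm, qv = encoder(h_prev, o_t)
    z_t = qm + math.sqrt(qv) * rng.gauss(0.0, 1.0)  # reparameterized sample from q
    kl = kl_gauss(qm, qv, pm, pv)
    o_hat = decoder_mean(h_prev, z_t)               # reconstruction of o_t
    h_t = recurrence(h_prev, o_t, z_t, a_t)
    return h_t, o_hat, kl

# A collapsed encoder ignores o_t and copies the prior: the KL term is zero,
# so z_t carries no information about the observation, and any ability to
# reconstruct o_t has to come from h_{t-1} alone.
def collapsed_encoder(h_prev, o_t):
    return prior_net(h_prev)

rng = random.Random(0)
_, _, kl_used = vrnn_step(0.3, 1.0, 0.0, rng)
_, _, kl_collapsed = vrnn_step(0.3, 1.0, 0.0, rng, encoder=collapsed_encoder)
print(kl_used > 0, kl_collapsed)  # True 0.0
```

A zero KL looks like a perfect score on that term, which is exactly why collapse is tempting for the optimizer.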
The field has converged on a specific architecture known as the Recurrent State Space Model (RSSM), which we are going to cover in the next blog!
~Ashutosh

