
Introduction to Variational Inference
In the realm of artificial intelligence and neuroscience, one of the most fascinating challenges is understanding how systems can reason effectively with incomplete information. Whether it's a human brain identifying a partially obscured object or an AI generating a vivid image from a text prompt, both are tackling the same fundamental problem: reasoning under uncertainty with incomplete data.
This challenge lies at the heart of both natural and artificial intelligence. The key question is: How can we create efficient and accurate models of our complex world using limited clues? In this article, we'll explore variational inference, a powerful mathematical tool that transforms this daunting task into something both brains and computers can master.
The Challenge of High-Dimensional Data
Both biological brains and artificial neural networks face a common challenge: they must process and make sense of complex, high-dimensional inputs. For a brain, this might be the activation patterns of retinal neurons. For AI systems, it could be the pixel values in an image.
What makes this particularly challenging is the sheer volume of data. A single 100x100 pixel image, for example, contains 10,000 individual values. These values are not independent of each other but follow complex patterns and relationships.
The key insight is that intelligent systems can leverage these inherent patterns to build compressed models that capture the essence of what they are observing. For a biological brain, this modeling capacity is crucial for survival. For AI systems, particularly in generative modeling, we explicitly design them to capture and reproduce these patterns.
A Simplified Example: Ocean Research
To make this concept more concrete, let's consider a simplified example. Imagine you're an ocean research vessel trying to understand marine life. You have three sensors measuring water temperature, salinity, and chlorophyll concentration. Each time you dip these sensors into the ocean, you collect a data point - a vector of three numbers.
As you collect samples across the ocean, patterns start to emerge. These measurements are not independent. There might be distinct ocean biomes, for example, that constrain how these variables relate to each other. Certain temperature ranges might occur only within specific salinity levels, while chlorophyll concentrations follow patterns depending on other variables.
This measurement procedure is essentially sampling from an underlying three-dimensional probability distribution - a mathematical function that assigns each vector in this measurement space a non-negative number (its probability density), defining which regions are highly likely to be observed (like common biomes) and which are combinations that essentially never occur in nature.
The Goal: Building an Efficient Model
Our goal is to build a model that can describe this complicated distribution effectively from a limited number of samples. This is where we introduce a powerful simplifying assumption: we will model the observations to be determined by one hidden or latent random variable, which we'll call Z.
This latent variable Z has its own probability distribution, P(Z). It might represent fundamental properties of ocean biomes that we cannot directly measure, but which explain the patterns we see. The beauty of this approach is that we don't even need to know what this latent factor actually represents - we will let the model discover it on its own.
The Mechanics of Our Model
Our model works like this:
- We randomly sample a value of the latent factor Z from its initial distribution (called the prior).
- Given that specific value of Z, we determine the distribution of possible measurements that are compatible with this latent factor.
We're essentially chaining two sampling procedures together. The distribution of our measurements (X) thus depends on which value of Z we sampled, creating what is known as a conditional probability - the probability of observing X given a specific value of Z, written as P(X|Z).
In probability notation, we can write this as:
P(X,Z) = P(Z) * P(X|Z)
This equation tells us that the joint probability of both X and Z occurring together equals the probability of first stumbling upon the right value of the latent factor, multiplied by the probability of our measurements given that specific latent factor.
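To make the chaining concrete, here is a minimal sketch in Python (NumPy). The choice of a standard normal prior, the particular mapping from Z to the measurement means, and the noise level are all invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_observation():
    # Step 1: sample the latent factor Z from its prior P(Z),
    # assumed here to be a standard normal distribution.
    z = rng.normal(loc=0.0, scale=1.0)

    # Step 2: given Z, sample the observation X from P(X|Z).
    # A toy mapping from the latent to the mean of a 3D Gaussian
    # (temperature, salinity, chlorophyll) stands in for the real model.
    mean_x = np.array([20.0 + 2.0 * z, 35.0 - 0.5 * z, 1.0 + 0.3 * z])
    x = rng.normal(loc=mean_x, scale=0.1)
    return z, x

z, x = sample_observation()
print("latent z:", z, "observation x:", x)
```

Running this repeatedly produces samples from the joint distribution P(X,Z): the prior picks a latent, and the conditional turns that latent into a plausible measurement.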
Representing Probability Distributions
Before we move further, it's important to clarify what we mean when we say a model "gives us a distribution". While we easily talk about probability distributions as abstract mappings - functions that can be visualized as curves and surfaces - how does a computer practically represent these infinite mathematical objects?
Drawing a smooth curve requires storing coordinates of many points, and the resolution depends on how many points we use. So if our models need to work with probability distributions, they need an efficient way to represent them.
One approach might be to represent the distribution as a long list of probability values for each possible input value, with some finite binning. But this quickly becomes impractical due to the sheer amount of information needed.
Instead, we use what's called a parametric approach - representing entire distributions with just a few key parameters. This is like having a mathematical formula with adjustable knobs, where plugging in different values for a few parameters fully determines different distribution shapes.
The Gaussian Distribution: A Powerful Tool
By far the most common parameterizable distribution is the Gaussian (or normal) distribution - the famous bell curve. Not only is it mathematically convenient, but it is also incredibly common in nature due to the central limit theorem, a famous result that states that when many independent random effects are added together, they tend to create a normal distribution.
What is remarkable about the Gaussian is that despite being a continuous curve extending infinitely in both directions, we can fully specify any normal distribution using just two parameters: the mean (μ), which sets where the center of the bell curve is, and the standard deviation (σ), which determines how spread out it is.
This efficiency is remarkably powerful - just two numbers completely determine the entire probability distribution. And for multi-dimensional data like our ocean measurements, we can use the multivariate Gaussian, which extends the bell curve to multiple dimensions using a mean vector and a covariance matrix.
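As a small illustration of how little needs to be stored, the sketch below (plain NumPy, with an arbitrary mean and standard deviation) rebuilds the entire bell curve from just those two numbers:

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    # The whole distribution is determined by mu and sigma.
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

mu, sigma = 20.0, 2.0          # two numbers fully specify the curve
xs = np.linspace(10, 30, 201)  # a grid is only needed for evaluation/plotting
density = gaussian_pdf(xs, mu, sigma)
print(xs[density.argmax()])    # the peak sits at the mean
```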
Challenges in High Dimensions
While in our toy example it's totally doable to make a neural network output all the parameters needed to describe the distribution, this approach creates significant problems in high dimensions.
For instance, imagine that instead of three sensor measurements, our observations are now images with 10,000 pixels each. To parameterize the mean of this distribution, we would need 10,000 numbers. There's no getting around that. But if we also want a covariance matrix, that would be a 10,000 x 10,000 matrix, which means 100 million parameters - an astronomical amount that we won't be able to estimate properly because training such a network would require more samples than we can possibly obtain.
This is why we often use simplified covariance structures. A common approach is to assume a so-called isotropic Gaussian, which has the same variance (typically fixed to one) in all directions and no correlations between dimensions. With this simplification, our neural network only needs to output the mean vector, because the covariance is fixed.
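A rough sketch of this simplification, with an invented two-layer network standing in for whatever architecture one might actually use: the network maps a latent vector to a 10,000-dimensional mean, and the covariance never appears as an output at all.

```python
import numpy as np

rng = np.random.default_rng(1)
LATENT_DIM, DATA_DIM, HIDDEN = 10, 10_000, 256

# Illustrative two-layer network: its only job is to map a latent
# vector to the 10,000-dimensional mean of P(X|Z).
W1 = rng.normal(0, 0.1, (HIDDEN, LATENT_DIM))
W2 = rng.normal(0, 0.1, (DATA_DIM, HIDDEN))

def decoder_mean(z):
    h = np.tanh(W1 @ z)
    return W2 @ h          # predicted mean vector, one number per pixel

z = rng.normal(size=LATENT_DIM)
mu_x = decoder_mean(z)

# The covariance is fixed to the identity (isotropic, variance one),
# so no 10,000 x 10,000 matrix is ever stored or predicted.
x_sample = rng.normal(loc=mu_x, scale=1.0)
print(mu_x.shape, x_sample.shape)   # (10000,) (10000,)
```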
Training the Model
Now that we understand how distributions are represented, let's talk about how to train such a model. Remember, the goal of training is to adjust the parameters of our model - in this case, the mean and variance of the prior distribution of latent factors, and the weights of the neural network that transform the latent into the mean of the conditional distribution of observations.
Mathematically, we want to minimize some kind of distance between the true data distribution (which we don't know) and the distribution that our model approximates it with. There is a natural measure for how much two probability distributions differ, called the Kullback-Leibler (KL) divergence.
The KL divergence is calculated by summing over all possible values that X can take, multiplying the probability of each value under the first distribution by the logarithm of the ratio between the two probabilities:
KL(P_data || P_θ) = Σ_X P_data(X) * log(P_data(X) / P_θ(X))
Intuitively, for a perfect model the two distributions are exactly the same, that ratio is one everywhere, and since the logarithm of one is zero, the KL divergence becomes zero as well. Any mismatch between the two distributions results in a value greater than zero.
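For discrete distributions this calculation fits in a few lines; the two example distributions below are made up purely to show the behavior described above:

```python
import numpy as np

def kl_divergence(p, q):
    # KL(P || Q) = sum over x of P(x) * log(P(x) / Q(x))
    p, q = np.asarray(p, float), np.asarray(q, float)
    return np.sum(p * np.log(p / q))

p_true  = np.array([0.5, 0.3, 0.2])   # "true" distribution (illustrative)
q_model = np.array([0.4, 0.4, 0.2])   # model's approximation

print(kl_divergence(p_true, p_true))   # 0.0: identical distributions
print(kl_divergence(p_true, q_model))  # > 0: any mismatch adds a positive amount
```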
From KL Divergence to Maximum Likelihood
Using the properties of logarithms, we can split this KL divergence into two terms. The first term is the entropy of the data distribution - a measure of the inherent uncertainty or randomness in our observations. Crucially, this entropy depends only on the true data distribution itself (like how diverse the marine life is) and has nothing to do with our model parameters.
From an optimization perspective, it's a constant factor that we cannot influence. This means that if we want to minimize the KL divergence, we can focus solely on the second term, which is directly affected by our model parameters.
Removing the minus sign, we can rewrite our objective as:
Maximize: E_{X~data}[log P_θ(X)]
This formula tells us to maximize the expected log probability that our model assigns to observations, giving higher weight to observations that are more common in the true distribution.
However, there is a practical challenge. As written, this objective requires us to consider every possible observation and weigh the model's output by the true probability of observing that data point. But the true data distribution is exactly what we don't know in the first place!
Fortunately, we have access to a finite number of samples from this distribution - our dataset of measurements that we painstakingly collected. This allows us to approximate the expectation over the true underlying data distribution with a simple average over our training samples.
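In code, this approximation is just an average over the dataset. A minimal sketch, assuming some function log_p_model that returns log P_θ(x) for a single data point (the toy dataset and the fixed Gaussian "model" here are placeholders):

```python
import numpy as np

def average_log_likelihood(dataset, log_p_model):
    # Approximate E_{X ~ data}[log P_theta(X)] by a simple average
    # over the finite set of collected samples.
    return np.mean([log_p_model(x) for x in dataset])

# Illustrative stand-ins: a toy 1D dataset and a fixed Gaussian model.
rng = np.random.default_rng(2)
dataset = rng.normal(loc=1.0, scale=2.0, size=500)
log_p_model = lambda x: -0.5 * ((x - 1.0) / 2.0) ** 2 - np.log(2.0 * np.sqrt(2 * np.pi))

print(average_log_likelihood(dataset, log_p_model))
```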
The Curse of Dimensionality
While the approach we have described is conceptually sound and could work for simple data, learning more complex distributions requires a higher-dimensional latent space, which leads to a significant computational problem.
To calculate the probability of a data point, we need to sum (or integrate) over all possible values of the latent factor: P(X) = Σ_Z P(Z) * P(X|Z). To properly approximate this sum with sampling, we would need to sample enough values of Z to densely cover all regions of the latent space. The number of required samples grows exponentially with the number of dimensions - a phenomenon known as the curse of dimensionality.
For instance, with 10 latent dimensions, if you need a thousand samples for adequate coverage in one dimension, you would theoretically need 1000^10 samples for the full space - a number that is intractable even with the most advanced computing hardware.
Importance Sampling: A Smarter Approach
But maybe we don't need that many samples after all. For the vast majority of values randomly sampled from this multi-dimensional prior, the likelihood P(X|Z) of any specific data point would be vanishingly small. In other words, only a small fraction of the latent space is relevant for any real data point, and random sampling makes it nearly impossible to find these important regions by chance.
But what if we could somehow make smarter choices about how we sample the values of Z? This brings us to the foundational idea of importance sampling.
Imagine we are trying to estimate the average height of trees in a vast forest. It would be impossible to measure every single tree, so we randomly sample some manageable number, say 100, and use that as our estimate. Now suppose the forest has two distinct regions: a dense valley containing 90% of all trees (mostly young saplings just a few feet tall) and a sparse ridge line of very old trees forming the remaining 10% (many reaching over 100 ft in height).
With naive random sampling, we might easily miss the tall trees entirely - they are quite rare, after all. This would significantly underestimate the true average height and create high variance in our estimates, meaning that different sampling rounds would give wildly different results.
But here's an alternative: We can deliberately take 50% of our samples from the valley and 50% from the ridge line. Then, to mathematically correct for this biased sampling, we downweight the ridge line measurements by multiplying them by 0.2 (since we are sampling them five times more frequently than their natural occurrence), and we upweight the valley measurements by multiplying them by 1.8 (since they represent 90% of all trees but only 50% of our samples).
This gives a much better estimate because we are ensuring we capture rare but important samples. The key insight here is that it's better to over-sample important rare cases and mathematically correct for the bias, rather than risk missing them entirely.
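The same correction can be written directly as code. A small simulation sketch, with tree counts and heights invented to roughly match the scenario above:

```python
import numpy as np

rng = np.random.default_rng(3)

# Ground-truth forest: 90% short valley trees, 10% tall ridge trees (heights in feet).
valley = rng.normal(5.0, 1.0, 90_000)
ridge  = rng.normal(100.0, 10.0, 10_000)
true_mean = np.mean(np.concatenate([valley, ridge]))

# Stratified sampling: 50 trees from each region instead of purely random picks.
valley_sample = rng.choice(valley, 50)
ridge_sample  = rng.choice(ridge, 50)

# Correct for the bias: weight = (true fraction) / (sampled fraction).
estimate = np.mean(np.concatenate([valley_sample * (0.9 / 0.5),
                                   ridge_sample  * (0.1 / 0.5)]))
print(round(true_mean, 2), round(estimate, 2))  # the two should be close
```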
The Variational Distribution: Our Guide in the Latent Space
In our forest example, we somehow knew about the ridge line and where to pay more attention. Maybe we had a local guide who could tell us where to find important samples. But how would we know which regions in the latent space are important - that have high likelihood - without actually computing it first?
Wouldn't it be great to have this kind of "forest guide" for our generative model? The key insight of variational inference is precisely this: Instead of blindly sampling from the entire latent space, we train a separate neural network to serve as our guide. This network learns to predict which regions in the latent space are likely to have generated each specific data point.
This guide neural network learns what we call a variational distribution Q(Z|X) for each data point X. It predicts the distribution over the latent space, focusing on regions likely to have generated that specific observation. If you've watched the video on the free energy principle, you may recall this to be the recognition model, which tries to approximate the inversion of the generative model P and predict what latents are compatible with a given observation.
Formalizing the Variational Approach
Let's formalize this idea. Instead of sampling Z blindly from our prior, we will sample from this new Q(Z|X), tailored to each specific data point we are currently looking at. But there's a catch: if we simply replace P(Z) with Q(Z|X) in our calculations, we'll be computing the wrong probability.
The trick is to multiply and divide by the same quantity Q(Z|X) in our formula:
P(X) = Σ_Z P(Z) * P(X|Z) = Σ_Z Q(Z|X) * [P(X|Z) * P(Z) / Q(Z|X)] = E_{Z~Q(Z|X)}[P(X|Z) * P(Z) / Q(Z|X)]
Notice that this expression can now be interpreted in the following way: we are sampling Z from the distribution Q, computing the likelihood of the data point X for that value of the latent, and adjusting for the sampling bias (downweighting the contributions of Z's that we sample more frequently than they naturally occur).
It is mathematically equivalent to our original formula. But the good thing is now we are ensuring we cover important regions of the latent space with a manageable number of samples.
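Here is a sketch of that importance-weighted estimate in one latent dimension; the specific prior, likelihood, observed value, and guide parameters are all hypothetical, chosen only so the numbers work out nicely:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(4)

# Hypothetical one-dimensional model.
p_z         = lambda z: norm.pdf(z, 0.0, 1.0)          # prior P(Z)
p_x_given_z = lambda x, z: norm.pdf(x, 2.0 * z, 0.5)   # likelihood P(X|Z)

x_obs = 3.0

# Guide Q(Z|X): concentrates on the latent region that can explain x_obs.
q_mu, q_sigma = 1.4, 0.25
q_z_given_x = lambda z: norm.pdf(z, q_mu, q_sigma)

# Importance-sampled estimate: P(X) = E_{Z~Q}[ P(X|Z) * P(Z) / Q(Z|X) ]
z_samples = rng.normal(q_mu, q_sigma, 1_000)
weights   = p_x_given_z(x_obs, z_samples) * p_z(z_samples) / q_z_given_x(z_samples)
print(np.mean(weights))   # estimate of P(x_obs) from only 1,000 guided samples
```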
The Evidence Lower Bound (ELBO)
Now let's take the logarithm of both sides of this equation. Remember, the end goal is to maximize the average log probability of our observations. We could, in theory, optimize all the parameters (including the weights of the guide neural network Q) to maximize this objective directly. However, there are several practical limitations that make this approach problematic.
The fundamental issue is that the logarithm of the average is not well-behaved computationally. It produces a very noisy estimate of the true log likelihood because outlier samples can disproportionately affect the average, and it's also quite nasty to work with from the optimization standpoint.
Instead, it would be much more convenient to swap the order of operations: First take the logarithm of each sample's likelihood (which would reduce the noise and be more numerically stable) and then calculate the average. This would allow us to process samples in parallel and get more stable gradients.
However, the problem is that we can't simply make this swap. The log of an average and the average of logs are not equal. But the cool thing is, for any concave function (like the logarithm), when you draw a straight line between any two points on the curve, this connecting line always falls below the curve.
This property generalizes to averages of any number of points through what's known as Jensen's inequality. It states that the logarithm of expectation is greater than or equal to the expectation of the logarithm. This means we can write down the inequality that the log likelihood (which is what we aim to maximize) is always greater than or equal to what we would get by swapping the expectation and the log.
The right-hand side of this inequality provides us with a lower bound on the log probability of the data, also known as the model evidence:
log P(X) = log E_{Z~Q(Z|X)}[P(X|Z) * P(Z) / Q(Z|X)] ≥ E_{Z~Q(Z|X)}[log(P(X|Z) * P(Z) / Q(Z|X))]
The bound on the right, obtained by swapping the order of operations, is known as the Evidence Lower Bound, or ELBO for short.
Understanding the ELBO
If we unpack this expression using the properties of logarithms and split it into two terms, we arrive at a beautiful interpretation:
- The first term is accuracy. It measures how well the latents predicted by the recognition model actually explain the observed data point X. This rewards our model for making good predictions.
- The second term is the negative of the KL divergence between the distribution Q and the prior distribution of latents. This acts as a complexity penalty. It prevents the guide from becoming too specialized and inventing overly complex distributions of latents, ensuring that the distribution of latent factors for any specific data point doesn't stray too far from our general prior belief about the latent space as a whole.
In other words, it ensures that the emerging latent space is kind of nice and smooth.
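Putting the two terms together, a per-data-point Monte Carlo estimate of the ELBO might look like the sketch below; it reuses the same hypothetical one-dimensional densities as the previous snippet, so every distribution here is an assumption rather than a trained model:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(5)

# Hypothetical one-dimensional model, purely for illustration.
log_p_z         = lambda z: norm.logpdf(z, 0.0, 1.0)        # prior P(Z)
log_p_x_given_z = lambda x, z: norm.logpdf(x, 2.0 * z, 0.5) # generative model P(X|Z)
q_mu, q_sigma   = 1.4, 0.25                                 # guide Q(Z|X) for this x
log_q_z_given_x = lambda z: norm.logpdf(z, q_mu, q_sigma)

x_obs = 3.0
z = rng.normal(q_mu, q_sigma, 1_000)        # sample latents from the guide

accuracy   = np.mean(log_p_x_given_z(x_obs, z))        # how well Z explains X
complexity = np.mean(log_q_z_given_x(z) - log_p_z(z))  # Monte Carlo KL(Q || P(Z))

elbo = accuracy - complexity
print(accuracy, complexity, elbo)   # elbo stays below (or near) log P(x_obs)
```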
Implementation Notes
While in theory we could learn the parameters of the prior distribution as well, what's typically done is to fix the prior to be the standard isotropic Gaussian (with zero mean and identity covariance) and learn only the weights of the recognition model Q and the generative model P.
Then, for each data point:
- We first pass it through the recognition model, which gives us the parameters (the mean and covariance) of the distribution of latents that are likely to account for it.
- Then we sample a bunch of Z's from that distribution and map each Z through the generative model to get the parameters of the distribution of data points given this latent value.
The ELBO objective (which is what we are maximizing) requires us to compute the actual probability of a data point. But remember that our model outputs the parameters of the distribution, not probabilities directly. So to get the probability per sample, we plug both the output of the model (which is the predicted mean vector) and the observation into the formula for the Gaussian probability and take the logarithm.
Because the covariance is fixed and isotropic, the formula simplifies, and taking the logarithm eliminates the exponent. Thus, up to an additive constant, the log probability of a data point is just the negative squared distance between the network's output and the sample (scaled by one half). This is why the accuracy term is often called the "reconstruction term" and is equated with the squared distance between the model's output (the reconstruction) and the actual training data point.
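A quick numerical check of this equivalence, using a made-up predicted mean and data point in five dimensions:

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(6)
d = 5
mu_pred = rng.normal(size=d)   # decoder output (illustrative)
x_obs   = rng.normal(size=d)   # observed data point (illustrative)

# Exact log probability under a Gaussian with identity covariance...
log_p = multivariate_normal.logpdf(x_obs, mean=mu_pred, cov=np.eye(d))

# ...equals the negative squared distance (times 1/2) plus a constant
# that does not depend on the network's output.
const = -0.5 * d * np.log(2 * np.pi)
print(log_p, -0.5 * np.sum((x_obs - mu_pred) ** 2) + const)  # identical values
```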
But this equivalence comes from our assumption that the output distribution is Gaussian with isotropic covariance. If we used a different parameterization, the ELBO equation still holds (because it is about probability distributions as abstract functions), but the accuracy would not equal this simple squared distance. You would have to calculate the probability according to your parameterization.
Another advantage of the Gaussian parameterization for both the prior and the variational distribution Q is that the KL divergence between two Gaussians has a closed-form expression. We can simply plug in the parameters for means and covariances to get the divergence value without having to estimate it from samples.
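For a diagonal Gaussian Q with per-dimension mean μ and standard deviation σ, and a standard normal prior, that closed form is the expression implemented below; the example parameters are invented:

```python
import numpy as np

def kl_diag_gaussian_vs_standard_normal(mu, sigma):
    # KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions.
    return 0.5 * np.sum(mu ** 2 + sigma ** 2 - 2.0 * np.log(sigma) - 1.0)

mu    = np.array([0.3, -1.2, 0.0])   # illustrative guide output for one data point
sigma = np.array([0.8,  0.5, 1.0])

print(kl_diag_gaussian_vs_standard_normal(mu, sigma))   # no sampling required
```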
This way, we can compute the value of ELBO efficiently. And as we do it across all data samples in the training set, we are gradually tweaking the weights of both the Q and P networks to align with each other and identify compressed representations of structure present in the data, effectively learning the underlying distribution.
Conclusion: Tying It All Together
In this article, we've explored variational inference and its key tool, the Evidence Lower Bound (ELBO). We began with the challenge of modeling complex, high-dimensional data and introduced latent variables as a compact language for describing it. We saw how parameterizing distributions makes this practical and how training aligns our model with the true data distribution by minimizing the KL divergence.
When summing over latent possibilities became intractable, importance sampling and Jensen's inequality led us to the ELBO, a practical lower bound that decomposes into an accuracy term and a complexity term. Notably, the negative of this lower bound is known as the variational free energy in neuroscience, framing machine learning's ELBO and the brain's free energy as two sides of the same coin and revealing a unifying picture of how intelligent systems - whether brains or machines - navigate uncertainty efficiently.
Variational inference stands as a powerful tool in the arsenal of both artificial intelligence and neuroscience, enabling us to build efficient models from incomplete data and shedding light on how both artificial and biological systems might grapple with the fundamental challenge of reasoning under uncertainty.
Article created from: https://youtu.be/laaBLUxJUMY?si=pKm7UHgMPVhKlpds