DEV Community

freiberg-roman

Learn the Math of Diffusion Models (Enough to Reason about Ideas from Top-ML Conferences)

Diffusion models have taken off at recent machine learning conferences.

This year, NeurIPS had around 400 publications that employ diffusion techniques. At first glance, the math might seem familiar, but some results might catch you off guard and require more clarification. This article tries to fill that gap without taking the painful formal route.

Here you will find the necessary basics, covering these foundational building blocks:

  • Stochastic Processes
  • Brownian Motion
  • Stochastic Integral (Itô)
  • Itô Differential
  • Stochastic Differential Equations
  • Ornstein-Uhlenbeck Process

A quick disclaimer.

The results presented here have been heavily condensed to keep things simple and to stay out of formal math territory. This comes at a significant cost in rigor. If you need a formally accurate foundation, there is no way around the literature; the references below are good starting points.

That being said, let us begin.

Stochastic Processes

In stochastic calculus, one reasons about the probabilities of the paths a process can take.

For our purposes, stochastic processes are noisy continuous paths in $\mathbb{R}^d$ indexed by time. One could think of a function $X: \mathbb{R}_+ \to \mathbb{R}^d$, $t \mapsto X_t$. Since we want to keep track of many such paths, we need an additional variable $\omega$ to tell them apart, hence we write $X_t(\omega)$, where $\omega$ ranges over the space $\Omega$. In Haskell-like notation, this reads $\omega \to (t \to x)$. So we end up with a function that outputs paths.

This allows us to assign probabilities to paths.
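To make the $\omega \to (t \to x)$ picture concrete, here is a minimal Python sketch under simplifying assumptions: the hypothetical `sample_path` uses an integer seed as $\omega$, lives on a discrete time grid, and uses a Gaussian random walk as a stand-in for a noisy continuous path.

```python
import random
from typing import Callable

def sample_path(omega: int, n_steps: int = 100, dt: float = 0.01) -> Callable[[float], float]:
    """Hypothetical helper: map a label omega to a whole path t -> x on [0, n_steps * dt]."""
    rng = random.Random(omega)  # omega determines all the randomness of this path
    values = [0.0]
    for _ in range(n_steps):
        # noisy step on the time grid (a Gaussian random walk as a stand-in)
        values.append(values[-1] + rng.gauss(0.0, dt ** 0.5))

    def path(t: float) -> float:
        # look up the grid point closest to t (clamped to the last grid point)
        return values[min(round(t / dt), n_steps)]

    return path

path_a = sample_path(omega=1)
path_b = sample_path(omega=2)
print(path_a(0.5), path_b(0.5))                  # two different paths at the same time
print(sample_path(omega=1)(0.5) == path_a(0.5))  # True: same omega, same path
```

The outer function is the $\omega \to \dots$ part and the returned closure is the $t \to x$ part.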

Brownian Motion

It is often referred to as the most important object in probability theory.

In essence, Brownian motions are continuous paths (Kolmogorov's continuity theorem guarantees that such continuous paths exist) starting from zero, $B_0 = 0$, with independent Gaussian increments:

$$B_t - B_s \text{ is independent of } B_{s'} \quad \text{for } s' < s < t,$$

$$B_t - B_s \sim \mathcal{N}(0, \sigma^2) \quad \text{with } \sigma^2 = t - s.$$

This means that at each time point, the tracked particle at position $B_t(\omega)$ completely forgets its movement history and picks its next position from a normal distribution centered at the current one. Extending this concept to $\mathbb{R}^d$ amounts to using $d$ independent Brownian motions, one per coordinate.
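A minimal simulation sketch (standard library only) of the two defining properties: the hypothetical `brownian_path` starts at $B_0 = 0$ and adds independent $\mathcal{N}(0, \Delta t)$ increments on a time grid. Sampling many endpoints $B_1$ should give a mean near $0$ and a variance near $t = 1$; grid size and sample count are arbitrary choices.

```python
import math
import random
from typing import List

def brownian_path(n_steps: int, t_end: float, rng: random.Random) -> List[float]:
    """Discretized Brownian motion on [0, t_end]: B_0 = 0 plus independent N(0, dt) increments."""
    dt = t_end / n_steps
    path = [0.0]  # B_0 = 0
    for _ in range(n_steps):
        # each increment is a fresh Gaussian, independent of everything before it
        path.append(path[-1] + rng.gauss(0.0, math.sqrt(dt)))
    return path

rng = random.Random(0)
endpoints = [brownian_path(100, 1.0, rng)[-1] for _ in range(5000)]
mean = sum(endpoints) / len(endpoints)
var = sum((b - mean) ** 2 for b in endpoints) / len(endpoints)
print(f"mean of B_1 ~ {mean:.3f}, variance of B_1 ~ {var:.3f}")  # roughly 0 and 1
```

For $\mathbb{R}^d$, calling `brownian_path` once per coordinate with independent draws gives the $d$-dimensional version.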

Next, we will use Brownian motion to define an integral, with its independent increments serving as the measurement delta.

Stochastic Integral (Itô)

In this chapter, we explore the stochastic integral

$$\int_{t_{start}}^{t_{end}} X_t\,dB_t,$$

which is essential for defining stochastic differential equations.

The key idea is similar to the definition of the Riemann integral as a limit of sums, $\lim_{\pi \to 0} \sum_i f(t_i)(t_{i+1} - t_i) = \int_{t_{start}}^{t_{end}} f(t)\,dt$, taken over partitions $\pi: t_{start} = t_0 < t_1 < \dots < t_N = t_{end}$ that become finer and finer (this is what $\pi \to 0$ denotes). First, we require the process to behave without hindsight (an adapted process). This simply means that the process, at any given time point, does not look into future values. Here lies the crux of the stochastic integral definition:

$$\lim_{\pi \to 0} \sum_{i} X_{t_i}\left(B_{t_{i+1}} - B_{t_i}\right) = \int_{t_{start}}^{t_{end}} X_t\,dB_t$$

where the integrand is evaluated at the beginning of each interval so that it does not peek into the future. (The limit is understood in the mean-square sense rather than path by path.) Hence the resulting integral is itself an adapted process.

Note: There are other definitions of the stochastic integral, such as the Stratonovich integral, where the integrand is evaluated in the middle of the interval; it comes with different properties. For diffusion models, the Itô integral remains the go-to choice.
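To see why the evaluation point matters, the following sketch (assumptions: a single simulated path on a grid of $10^5$ steps, integrand $X_t = B_t$ on $[0, 1]$) compares the left-endpoint Itô sum with a Stratonovich-style sum that averages the two endpoints. For this integrand, the Itô sum approaches $\frac{1}{2}(B_T^2 - T)$, whereas the Stratonovich sum telescopes to $\frac{1}{2}B_T^2$, the answer ordinary calculus would suggest.

```python
import math
import random

rng = random.Random(42)
n_steps, t_end = 100_000, 1.0
dt = t_end / n_steps

# One Brownian path on a fine grid.
b = [0.0]
for _ in range(n_steps):
    b.append(b[-1] + rng.gauss(0.0, math.sqrt(dt)))

# Ito: integrand evaluated at the left endpoint of each interval (no peeking ahead).
ito = sum(b[i] * (b[i + 1] - b[i]) for i in range(n_steps))
# Stratonovich-style: integrand taken as the average of the two endpoints.
strat = sum(0.5 * (b[i] + b[i + 1]) * (b[i + 1] - b[i]) for i in range(n_steps))

b_t = b[-1]
print(ito, (b_t ** 2 - t_end) / 2)  # Ito sum approaches (B_T^2 - T) / 2
print(strat, b_t ** 2 / 2)          # Stratonovich sum matches B_T^2 / 2
```

The gap of about $T/2$ between the two is half the sum of squared increments, which does not vanish as the grid is refined.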

As this material can be quite dense, we will go through some key results and properties of this integral.

We learned previously that the increments, which now appear in the sum, are independent Gaussians. In some cases, this Gaussian structure survives the limit. For this, we require the integrand $f$ to be a deterministic, square-integrable function, i.e. $\int_0^t f(s)^2\,ds < \infty$. Then the process

$$I_t = \int_0^t f(s)\,dB_s$$

is Gaussian itself, with distribution $\mathcal{N}\left(0, \int_0^t f(s)^2\,ds\right)$.

One way to think about it: we integrate against increments that spread out without bias, so the integral is "centered" by design and sums up the "spread" of the integrand.

This neat property sometimes allows us to compute the distribution of the integral in closed form.
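A quick Monte Carlo check of this Gaussian property, as a sketch with the arbitrarily chosen integrand $f(s) = s$ on $[0, 1]$: the left-endpoint sum is a sum of independent Gaussians with variances $f(t_i)^2\,\Delta t$, so its mean should be near $0$ and its variance near $\int_0^1 s^2\,ds = \frac{1}{3}$. The helper name `ito_integral_deterministic` is hypothetical.

```python
import math
import random
from typing import Callable

def ito_integral_deterministic(f: Callable[[float], float], t_end: float,
                               n_steps: int, rng: random.Random) -> float:
    """Left-endpoint sum  sum_i f(t_i) (B_{t_{i+1}} - B_{t_i})  for a deterministic integrand f."""
    dt = t_end / n_steps
    total = 0.0
    for i in range(n_steps):
        total += f(i * dt) * rng.gauss(0.0, math.sqrt(dt))  # fresh independent increment
    return total

rng = random.Random(7)
samples = [ito_integral_deterministic(lambda s: s, 1.0, 200, rng) for _ in range(4000)]
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / len(samples)
print(mean, var)  # close to 0 and to 1/3
```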

Since we are concerned with the distribution of the integral, we will often need its expected square to compute variances. Here, a neat property called the Itô isometry comes in handy:

$$\mathbf{E}\left[\left(\int X_t\,dB_t\right)^2\right] = \mathbf{E}\left[\int X_t^2\,dt\right]$$

For this property to hold, the process needs to be in $M^2$, the space of adapted processes with $\mathbf{E}\int X_t^2\,dt < \infty$. Think of it as the analogue of the $L^2$ space for stochastic processes.
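The isometry can also be checked numerically for a random integrand. This sketch picks $X_t = B_t$ on $[0, 1]$ (an illustrative choice), for which both sides equal $\int_0^1 t\,dt = \frac{1}{2}$; the path count and grid are arbitrary, so expect Monte Carlo error of a few hundredths.

```python
import math
import random

rng = random.Random(3)
n_paths, n_steps, t_end = 4000, 200, 1.0
dt = t_end / n_steps

lhs_samples, rhs_samples = [], []
for _path in range(n_paths):
    b, ito, int_b2 = 0.0, 0.0, 0.0
    for _step in range(n_steps):
        db = rng.gauss(0.0, math.sqrt(dt))
        ito += b * db          # left-endpoint Ito sum for the integrand X_t = B_t
        int_b2 += b * b * dt   # Riemann sum for the integral of X_t^2 dt
        b += db
    lhs_samples.append(ito ** 2)
    rhs_samples.append(int_b2)

lhs = sum(lhs_samples) / n_paths  # estimates E[(int B dB)^2]
rhs = sum(rhs_samples) / n_paths  # estimates E[int B^2 dt]
print(lhs, rhs)  # both close to 1/2
```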

The next theorem is less concerned with computation, but gives us some key insight into the nature of stochastic integrals.

First, observe the bare-bones integral $\int_0^t dB_s = \lim \sum_i (B_{t_{i+1}} - B_{t_i}) = B_t$, which is just the Brownian motion. Now, consider an integrand $X_t$ that "remaps" time: its accumulated clock $\int_0^u X_s^2\,ds$ runs from zero to infinity as $u$ grows. If this process is well-behaved to some extent, i.e. in $M^2$, the integral read on this new clock is again a Brownian motion. Concretely, define

$$\tau(t) = \inf\left\{u : \int_0^u X_s^2\,ds > t\right\}$$

for the integral

$$\int_0^{\tau(t)} X_s\,dB_s,$$

which is our recovered Brownian motion. There are many other properties worth mentioning, especially concerning martingales. However, these few hand-picked examples should already convey the type of structure we are dealing with.
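A small worked example may help, with the illustrative integrand $X_s = s$ (not from the original text), whose clock $\int_0^u s^2\,ds = u^3/3$ grows to infinity. Since $X_s$ is deterministic, the Gaussian property from above gives the distribution of the time-changed integral directly. This only checks the marginal distribution at each $t$, not the full Brownian motion property.

```latex
\int_0^u X_s^2\,ds = \frac{u^3}{3}
\quad\Longrightarrow\quad
\tau(t) = (3t)^{1/3}

\int_0^{\tau(t)} s\,dB_s \sim \mathcal{N}\left(0, \int_0^{\tau(t)} s^2\,ds\right)
= \mathcal{N}\left(0, \frac{\tau(t)^3}{3}\right)
= \mathcal{N}(0, t)
```

This matches the distribution of $B_t$.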

Let us now continue with the core topics.

Itô Differential

The previous topic was concerned with integration; this one is concerned with differentiation.

Differential equations are typically written as $\frac{dy}{dt} = g(y, t)$. For stochastic differential equations (SDEs), it is more convenient to use the differential form $dy = g(y, t)\,dt$, because we can define it through the integral. Thus, the Itô stochastic differential is defined by

$$X_{t_2} - X_{t_1} = \int_{t_1}^{t_2} F_t\,dt + \int_{t_1}^{t_2} G_t\,dB_t$$

for each $t_2 > t_1$, which yields (think of the limit) the differential

$$dX_t = F_t\,dt + G_t\,dB_t.$$

We end up with a deterministic part and a stochastic part.

$F_t$ is referred to as the drift term; it sets the overall direction of the process. One can think of $F_t\,dt$ as the expected value of $dX_t$. The term $G_t$ plays the role of a valve controlling the randomness of the process, i.e. how much the process oscillates around the expected value. This is in line with our intuition on how the stochastic integral operates.
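A minimal sketch of the drift/valve picture, assuming constant coefficients $F_t = F$ and $G_t = G$ (an illustrative special case). Accumulating $F\,\Delta t + G\,\Delta B$ over a grid, the endpoint mean should be close to $F t$ and the variance close to $G^2 t$. The function and parameter names are hypothetical.

```python
import math
import random

def simulate_endpoint(x0: float, drift: float, valve: float, t_end: float,
                      n_steps: int, rng: random.Random) -> float:
    """Accumulate dX = F dt + G dB with constant F (drift) and G (valve) up to t_end."""
    dt = t_end / n_steps
    x = x0
    for _ in range(n_steps):
        x += drift * dt + valve * rng.gauss(0.0, math.sqrt(dt))
    return x

rng = random.Random(11)
ends = [simulate_endpoint(0.0, drift=2.0, valve=0.5, t_end=1.0, n_steps=100, rng=rng)
        for _ in range(4000)]
mean = sum(ends) / len(ends)
var = sum((x - mean) ** 2 for x in ends) / len(ends)
print(mean, var)  # mean near F * t = 2.0, variance near G^2 * t = 0.25
```

Setting `valve=0.0` removes all randomness and leaves the deterministic path $x_0 + F t$.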

So how do we compute the differential of a function of a stochastic process?

This is where Itô's lemma gives us the solution.

$$df(X_t, t) = \frac{\partial f}{\partial t}\,dt + \frac{\partial f}{\partial x}\,dX_t + \frac{1}{2}\frac{\partial^2 f}{\partial x^2}G_t^2\,dt$$

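For example, applying it to $f(x) = x^2$ with $X_t = B_t$ (so $F_t = 0$, $G_t = 1$) gives $d(B_t^2) = 2B_t\,dB_t + dt$, or equivalently $d\left(\tfrac{1}{2}B_t^2\right) = B_t\,dB_t + \tfrac{1}{2}\,dt$. This is contrary to ordinary calculus, which would give only $2B_t\,dB_t$. An intuition for where the last term comes from lies in the Taylor series expansion: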
$$df(X_t) = f'(X_t)\,dX_t + \frac{1}{2}f''(X_t)(dX_t)^2 + \dots$$

$$= f'(X_t)(F_t\,dt + G_t\,dB_t) + \frac{1}{2}f''(X_t)(F_t\,dt + G_t\,dB_t)^2 + \dots$$

$$= f'(X_t)(F_t\,dt + G_t\,dB_t) + \frac{1}{2}f''(X_t)\left(F_t^2\,dt^2 + 2F_tG_t\,dt\,dB_t + G_t^2(dB_t)^2\right) + \dots$$

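In standard calculus, the derivative is the best linear approximation, as the higher-order terms tend to zero in the limit. For the stochastic differential, however, the term $(dB_t)^2$ behaves like $dt$ and does not vanish, while $dt^2$ and $dt\,dB_t$ still do. Replacing $(dB_t)^2$ by $dt$ in the expansion above leaves exactly the extra term $\frac{1}{2}f''(X_t)G_t^2\,dt$ of Itô's lemma. A somewhat hand-wavy explanation of why this happens is given by the following equation:

$$\mathbf{E}\left(B_{t_2} - B_{t_1}\right)^2 = t_2 - t_1$$

Summed over a fine partition, the squared increments therefore add up to the length of the interval, and in the limit (in $L^2$) this lets $(dB_t)^2$ behave like $dt$.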
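The claim that $(dB_t)^2$ behaves like $dt$ while $dt^2$ and $dt\,dB_t$ vanish can be illustrated numerically. In this sketch (grid sizes chosen arbitrarily, $T = 1$), the sum of squared increments stays near $T$ as the grid is refined, while the other two sums shrink toward zero.

```python
import math
import random

rng = random.Random(5)
t_end = 1.0
for n_steps in (10, 1_000, 100_000):
    dt = t_end / n_steps
    dbs = [rng.gauss(0.0, math.sqrt(dt)) for _ in range(n_steps)]
    sum_db2 = sum(db * db for db in dbs)    # (dB)^2 terms: tends to t_end
    sum_dt_db = sum(dt * db for db in dbs)  # dt dB terms: tends to 0
    sum_dt2 = n_steps * dt * dt             # dt^2 terms: equals t_end^2 / n_steps
    print(n_steps, round(sum_db2, 4), round(sum_dt_db, 6), sum_dt2)
```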

However, some aspects of standard calculus do translate.

For example, integration by parts:

$$\int_0^T f(s)\,dB_s = f(T)B_T - \int_0^T B_s f'(s)\,ds$$

which allows us to compute the stochastic integral of any differentiable deterministic function $f$ through an ordinary integral. This is a direct consequence of Itô's lemma applied to the function $f(t)B_t$: since $f(t)x$ is linear in $x$, the second-order term vanishes and $d(f(t)B_t) = f'(t)B_t\,dt + f(t)\,dB_t$.
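A pathwise sanity check of integration by parts, as a sketch with the arbitrary choice $f = \sin$ on $[0, 1]$: the left-endpoint Itô sum and $f(T)B_T - \int_0^T B_s f'(s)\,ds$ (computed as a Riemann sum) are evaluated on the same simulated path and should agree up to discretization error.

```python
import math
import random

rng = random.Random(9)
n_steps, t_end = 10_000, 1.0
dt = t_end / n_steps
f, f_prime = math.sin, math.cos  # an arbitrary smooth deterministic integrand

b = [0.0]
for _ in range(n_steps):
    b.append(b[-1] + rng.gauss(0.0, math.sqrt(dt)))

# Left side: Ito sum of f(s) dB_s.
lhs = sum(f(i * dt) * (b[i + 1] - b[i]) for i in range(n_steps))
# Right side: f(T) B_T minus an ordinary Riemann sum of B_s f'(s) ds.
rhs = f(t_end) * b[-1] - sum(b[i] * f_prime(i * dt) * dt for i in range(n_steps))
print(lhs, rhs)  # agree up to an error that shrinks with dt
```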

Note: We keep things simple here and define all terms in one dimension only. The extension to higher dimensions can get quite hairy, but fortunately, most research publications assume the dimensions to be independent. Hence, we can treat each dimension separately.

Stochastic Differential Equations

As we might now expect, Stochastic Differential Equations (SDEs) are defined as follows

$$dX_t = b(X_t, t)\,dt + \sigma(X_t, t)\,dB_t, \quad t \in [u, T]$$

with an initial condition $X_u = x$ (which could also be a random variable). The solution has to satisfy

$$X_t = x + \int_u^t b(X_s, s)\,ds + \int_u^t \sigma(X_s, s)\,dB_s \quad \text{for all } t \in [u, T].$$

This forms the basic structure of SDEs.
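The integral form suggests the simplest way to approximate a solution numerically: freeze the coefficients on each small interval, evaluating them at the left endpoint just like the Itô sum. This is the standard Euler-Maruyama scheme, which the article itself does not discuss; the sketch below is a minimal version with hypothetical names, not an implementation from the references.

```python
import math
import random
from typing import Callable, List

def euler_maruyama(b: Callable[[float, float], float],
                   sigma: Callable[[float, float], float],
                   x0: float, u: float, t_end: float, n_steps: int,
                   rng: random.Random) -> List[float]:
    """Approximate dX_t = b(X_t, t) dt + sigma(X_t, t) dB_t on [u, t_end] with X_u = x0."""
    dt = (t_end - u) / n_steps
    xs, x, t = [x0], x0, u
    for _ in range(n_steps):
        db = rng.gauss(0.0, math.sqrt(dt))
        # coefficients are evaluated at the left endpoint, mirroring the Ito integral
        x = x + b(x, t) * dt + sigma(x, t) * db
        t += dt
        xs.append(x)
    return xs

# Sanity check: with sigma = 0 the scheme reduces to Euler's method for dx = x dt.
path = euler_maruyama(lambda x, t: x, lambda x, t: 0.0, 1.0, 0.0, 1.0, 10_000, random.Random(0))
print(path[-1], math.e)  # close to e = exp(1)
```

With a nonzero `sigma`, each call produces one random path; the check above only covers the deterministic special case.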

Next, let us explore one of the most prominent SDEs and its applications.

The Ornstein-Uhlenbeck Process

This process is arguably the most commonly used example; it appears in seemingly every second publication on diffusion models.

$$dX_t = -\lambda X_t\,dt + \sigma\,dB_t, \quad X_0 = x, \quad \text{where } \lambda > 0$$

Some orientation first. The negative sign in the drift pushes the process in the direction opposite to its current position, so we should expect it to settle around zero. The constant in front of the Brownian motion lets us expect the stationary distribution (for time going to infinity) to be Gaussian.

As a first step we look at the deterministic part of the equation.

$$dX_t = -\lambda X_t\,dt$$

Its solution is of course $X_t = e^{-\lambda t}x$. Next, we guess that the full solution has the form $X_t = e^{-\lambda t}Z_t$ for some process $Z_t$ and compute its differential using Itô's lemma (the second-order term vanishes because the expression is linear in $Z_t$):

$$d\left(e^{-\lambda t}Z_t\right) = -\lambda e^{-\lambda t}Z_t\,dt + e^{-\lambda t}\,dZ_t$$

and compare it to the original SDE term by term. The drift terms already match, since $-\lambda e^{-\lambda t}Z_t = -\lambda X_t$, which leaves

$$e^{-\lambda t}\,dZ_t = \sigma\,dB_t$$

Finally, we solve for $Z_t$ by integrating the differential, using $Z_0 = X_0 = x$:

$$Z_t = x + \int_0^t e^{\lambda s}\sigma\,dB_s$$

obtaining the solution

$$X_t = e^{-\lambda t}x + e^{-\lambda t}\int_0^t e^{\lambda s}\sigma\,dB_s$$

Recall from before that the stochastic integral of a deterministic, square-integrable function is Gaussian. Here the integrand $e^{\lambda s}\sigma$ is exactly such a function, so we can even compute the distribution of the process at any time in closed form.

Thus, our goal is to find the expectation and variance.

$\mathbf{E}(X_t)$ is simply the deterministic term $e^{-\lambda t}x$, because the stochastic integral has mean zero. For the variance, we recall the property of squared stochastic integrals (the Itô isometry) and apply it to $\mathbf{E}\left(X_t - e^{-\lambda t}x\right)^2$:

$$\mathbf{E}\left(e^{-\lambda t}\int_0^t e^{\lambda s}\sigma\,dB_s\right)^2 = e^{-2\lambda t}\sigma^2\int_0^t e^{2\lambda s}\,ds = \frac{\sigma^2}{2\lambda}\left(1 - e^{-2\lambda t}\right)$$

Hence we end up with the distribution $\mathcal{N}\left(e^{-\lambda t}x, \frac{\sigma^2}{2\lambda}\left(1 - e^{-2\lambda t}\right)\right)$.
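To cross-check the closed form, this sketch simulates the Ornstein-Uhlenbeck SDE with the Euler-Maruyama scheme (a standard discretization, not part of the derivation above) for the arbitrary choice $\lambda = 1$, $\sigma = 1$, $x = 2$, $t = 1$, and compares the empirical mean and variance with $e^{-\lambda t}x$ and $\frac{\sigma^2}{2\lambda}(1 - e^{-2\lambda t})$.

```python
import math
import random

lam, sigma, x0, t_end = 1.0, 1.0, 2.0, 1.0
n_paths, n_steps = 4000, 200
dt = t_end / n_steps
rng = random.Random(2024)

ends = []
for _path in range(n_paths):
    x = x0
    for _step in range(n_steps):
        # Euler-Maruyama step for dX = -lambda X dt + sigma dB
        x += -lam * x * dt + sigma * rng.gauss(0.0, math.sqrt(dt))
    ends.append(x)

emp_mean = sum(ends) / n_paths
emp_var = sum((x - emp_mean) ** 2 for x in ends) / n_paths
exact_mean = math.exp(-lam * t_end) * x0
exact_var = sigma ** 2 / (2 * lam) * (1 - math.exp(-2 * lam * t_end))
print(emp_mean, exact_mean)  # exact mean is about 0.736
print(emp_var, exact_var)    # exact variance is about 0.432
```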

This is basically it. Let us have a final discussion on the result.

As time goes to infinity, the distribution settles into the normal distribution $\mathcal{N}\left(0, \frac{\sigma^2}{2\lambda}\right)$, regardless of the starting point $x$. This is why the process is so useful for diffusion models. First, any diffused data ends up in the same stationary distribution. As a byproduct, the entropy of the distribution is kept in check and does not diverge to infinity. Second, the distribution at any time point is known in closed form and can be computed for each step. All of this allows us to learn the step-wise diffusion updates, which is what diffusion models are all about.
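Because the marginal at every time is known in closed form, $X_t$ can be sampled directly for any $t$ without simulating the intermediate steps; this is, roughly, the role the closed-form perturbation kernel plays in SDE-based diffusion models such as Song et al. The sketch below (hypothetical `sample_ou_marginal`, arbitrary $\lambda = \sigma = 1$) diffuses two very different starting points, $x = \pm 5$, up to $t = 10$; both should end up near the stationary $\mathcal{N}\left(0, \frac{\sigma^2}{2\lambda}\right) = \mathcal{N}(0, 0.5)$.

```python
import math
import random

def sample_ou_marginal(x0: float, t: float, lam: float, sigma: float,
                       rng: random.Random) -> float:
    """Hypothetical helper: draw X_t ~ N(exp(-lam t) x0, sigma^2 / (2 lam) (1 - exp(-2 lam t)))."""
    mean = math.exp(-lam * t) * x0
    var = sigma ** 2 / (2 * lam) * (1 - math.exp(-2 * lam * t))
    return mean + math.sqrt(var) * rng.gauss(0.0, 1.0)

rng = random.Random(1)
lam, sigma = 1.0, 1.0
stats = {}
for x0 in (-5.0, 5.0):
    # one jump straight to t = 10, no intermediate simulation steps
    xs = [sample_ou_marginal(x0, 10.0, lam, sigma, rng) for _ in range(5000)]
    m = sum(xs) / len(xs)
    stats[x0] = (m, sum((x - m) ** 2 for x in xs) / len(xs))
print(stats)  # both entries near (0, sigma^2 / (2 lam)) = (0, 0.5)
```

At $t = 0$ the variance is zero and the helper simply returns the starting point, matching the initial condition.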

Final notes

This article is getting a bit too long.

There are some topics left to discuss, such as reverse processes and the relationship between SDEs and ODEs. Let me know if there is interest in a continuation.

References

Øksendal, Bernt. Stochastic Differential Equations: An Introduction with Applications. Springer, 2003.

Baldi, Paolo. Stochastic Calculus: An Introduction Through Theory and Exercises. Springer, 2010.

Song, Yang; Sohl-Dickstein, Jascha; Kingma, Diederik P.; Kumar, Abhishek; Ermon, Stefano; Poole, Ben. “Score-Based Generative Modeling Through Stochastic Differential Equations.” arXiv preprint arXiv:2011.13456, 2020.
