Diffusion model from scratch

Milos Vukadinovic, paper review

Denoising diffusion probabilistic models are behind many popular image/audio generation apps such as Stable Diffusion, Midjourney, Dall-E, and Imagen.

ℹ️

Q: Why can both processes be defined as parametrized Markov chains?

A: Because each circle is a whole distribution, and there is a "line" from each point in the $x_t$ distribution to each point in the $x_{t+1}$ distribution representing the probability of transition.

Most ML folks have a surface-level understanding of diffusion models and think about them as follows. During training there is a forward process $q$ and a reverse process $p$. $q$ takes an image and gradually adds Gaussian noise over $n$ steps. $p$ takes the output of $q$ after $n$ steps and tries to gradually reconstruct the original image. We aim for the distributions $p_i$ and $q_i$ (where $i$ is the step) to be the same, and that is what we compute the loss with respect to. Finally, once training is done and the loss is minimized, we need only the process $p$ for inference: $p$ takes any random Gaussian noise as input and produces some meaningful image.

That sounds simple, right? And this is a correct way to understand the abstract concept of diffusion models. But my curious mind didn't give me any rest, so questions kept popping up in my head:

  1. How do we measure the distance between two distributions and how do we minimize it?
  2. How complex does the model that predicts the denoising need to be?
  3. How can the reverse process add meaningful features and produce a real image from Gaussian noise?

You can only fully understand how something works if you build it yourself. So I read up on various blogs and papers: 1, 2, 3, 4, 5, and started implementing a diffusion model myself.

For a full implementation see my Google Colab notebook https://colab.research.google.com/drive/14UJQ6LTKPo4RgQT-kSjpswkI7Q1oYPW4; here I will explain each code snippet.

from sklearn.datasets import make_swiss_roll

def sample_batch(size, noise=1.0):
    x, _ = make_swiss_roll(size, noise=noise)
    # we want it in 2D, so we take the 0th and 2nd coordinates of the 3D swiss roll
    return x[:, [0, 2]] / 10.0

data = sample_batch(10**4).T

First, we will focus on the swiss roll dataset. The aim is for the reverse process $p$ to learn to generate points that belong to the swiss roll; therefore, if we randomly sample many points using the trained $p$, we will reconstruct the swiss roll distribution. In the code snippet above we use sklearn's function to sample 10,000 data points and take the x and z coordinates (x[:, [0, 2]]), since by default sklearn creates a 3D swiss roll. We could have also just defined the dataset ourselves, since the swiss roll in 2D is defined by a map $f:\mathbb{R}^2 \to \mathbb{R}^2$ such that $f(r, \theta) = (r\cos(\theta), r\sin(\theta))$.
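For example, a rough sketch of that do-it-yourself alternative could look like the snippet below (the exact angle range and radius scaling here are arbitrary choices for illustration, not necessarily what sklearn uses internally):

import numpy as np

def sample_spiral(size, noise=0.1):
    # sample an angle theta, let the radius grow with it,
    # and map (r, theta) -> (r cos(theta), r sin(theta))
    theta = np.random.uniform(1.5 * np.pi, 4.5 * np.pi, size)
    r = theta / (4.5 * np.pi)
    points = np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)
    return points + noise * np.random.randn(size, 2)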

import torch

start = 1e-5
end = 1e-2
n_steps = 100
# sigmoid noise schedule: betas grow smoothly from `start` to `end` over n_steps
betas = torch.sigmoid(torch.linspace(-6, 6, n_steps)) * (end - start) + start
alphas = 1 - betas
# alpha_bar_t is the cumulative product of the alphas up to step t
alphas_bar = torch.cumprod(alphas, 0)

This code snippet essentially just defines the required variables, but there is a lot to unpack here.

The forward process is defined by the formula

$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t I)$$

Here we sample a point that is similar to the point at the previous step but with some noise added. I think of this as: "remove a small part $\beta_t$ of the data point and add random noise in its place." The betas are hyperparameters representing the noise schedule: the higher the betas, the faster we inject noise, and note that they have to be between zero and one. We select the sigmoid schedule above because it worked the best in our experiments.
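A quick way to sanity-check the schedule (a little inspection of my own, not part of the original notebook) is to print a few values:

print(betas[:3])                              # tiny noise at the first steps
print(betas[-3:])                             # close to `end` at the last steps
print((betas > 0).all(), (betas < 1).all())   # the schedule must stay strictly in (0, 1)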

But with the above formula we cannot directly sample $x_{10}$ given $x_0$; we would have to compute the 9 intermediate points first. It would be neat if we could get it in $O(1)$. In the denoising diffusion paper, they simplify the formula to:

$$q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}\, x_0, (1 - \bar{\alpha}_t) I) \quad \text{where} \quad \alpha_t = 1 - \beta_t$$

$$\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$

With this we achieve much faster compute time, since we can jump from $x_0$ to any step $t$ directly.
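In code, that closed form looks like the helper below (a sketch of my own for illustration; the training loss later inlines the same expression):

def q_sample(x0, t):
    # closed-form forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    eps = torch.randn_like(x0)
    return torch.sqrt(alphas_bar[t]) * x0 + torch.sqrt(1 - alphas_bar[t]) * eps

x0 = torch.tensor(data.T).float()
x50 = q_sample(x0, 50)   # noisy version of the whole dataset at step 50, computed in one shot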

Next, we define a model that will perform the denoising. The reverse process $p$ is defined as:

$$p(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \Sigma_{\theta}(x_t, t))$$

where $\mu_{\theta}(x_t, t)$ and $\Sigma_{\theta}(x_t, t)$ are unknown, and are what we need the model to predict.

In the denoising diffusion models paper, they further simplify this process to rely only on the function $\epsilon_\theta : \mathbb{R}^2 \to \mathbb{R}^2$ by defining the mean as:

$$\mu_{\theta}(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_{\theta}(x_t, t)\right)$$

and using a fixed variance $\sigma_t^2$ instead (here we will take $\sigma_t^2 = \beta_t$). This results in the following sampling equation:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_{\theta}(x_t, t)\right) + \sigma_t z \quad \text{where} \quad z \sim \mathcal{N}(0, I)$$

Therefore, our model only needs to learn the function $\epsilon_\theta$.

import torch.nn as nn
import torch.nn.functional as F

class ConditionalLinear(nn.Module):
    def __init__(self, num_in, num_out, num_classes):
        super().__init__()
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.embed = nn.Embedding(num_classes, num_out)
        # initialize the embedding weights from a uniform distribution on [0, 1)
        self.embed.weight.data.uniform_()
 
    def forward(self, x, y):
        out = self.lin(x)
        gamma = self.embed(y)
        out = gamma.view(-1, self.num_out) * out
        return out
 
class ConditionalModel(nn.Module):
    def __init__(self, n_steps):
        super().__init__()
        self.lin1 = ConditionalLinear(2, 128, n_steps)
        self.lin2 = ConditionalLinear(128, 128, n_steps)
        self.lin3 = ConditionalLinear(128, 128, n_steps)
        # the final layer outputs the two components of the predicted noise
        self.lin4 = nn.Linear(128, 2)
    
    def forward(self, x, y):
        x = F.softplus(self.lin1(x, y))
        x = F.softplus(self.lin2(x, y))
        x = F.softplus(self.lin3(x, y))
        return self.lin4(x)

For our architecture, we will simply use a stack of 4 linear layers with the softplus activation function. However, since we are going to use the same set of weights to predict every transition, we need to condition the model on the timestep. So instead of a plain linear layer we define a conditional linear layer: for each timestep, each layer learns an embedding vector (nn.Embedding(num_classes, num_out)) and multiplies its output by that vector to condition it (out = gamma.view(-1, self.num_out) * out). The output of the last layer has only 2 numbers, namely the two components of the two-dimensional noise $\epsilon$.
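As a quick sanity check on the shapes (my own snippet; test_model here is a throwaway, untrained instance), we can push a dummy batch through the network:

test_model = ConditionalModel(n_steps)
x = torch.randn(16, 2)                         # a batch of 16 two-dimensional points
t = torch.randint(0, n_steps, (16,)).long()    # one timestep per element of the batch
print(test_model(x, t).shape)                  # torch.Size([16, 2]) -- the predicted noise

Next, it's time to talk about the loss.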

def noise_estimation_loss(model, x0):
    batch_size = x0.shape[0]
    # pick a random timestep for every element in the batch
    t = torch.randint(0, n_steps, size=(batch_size,)).long()
    # the noise that will be mixed into x0
    e = torch.randn_like(x0)
    # closed-form forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * e
    # (unsqueeze so the per-sample coefficients broadcast over the 2 coordinates)
    x = x0 * torch.sqrt(alphas_bar[t]).unsqueeze(1) + e * torch.sqrt(1 - alphas_bar[t]).unsqueeze(1)
    # the model predicts the noise; the loss is the MSE between true and predicted noise
    output = model(x, t)
    return (e - output).square().mean()

The above code snippet might be the most complicated bit to understand. To fully understand how we arrive at the noise estimation loss we would need to do a lot of math, so for now I will explain the intuition behind it and leave it to the reader to check the math in the original paper. With generative models, we usually optimize a variational bound on the negative log likelihood. In other words, we ask how probable the observed data is under a given set of weights, and look for the set of weights that maximizes this likelihood.

The probability that the diffusion model assigns to the data is:

$$p(x_0) = \int p(x_{0 \dots T})\, dx_{1 \dots T}$$ (the joint probability of the entire trajectory, integrated over all intermediate steps)

and then we want to minimize

$$E[-\log p_\theta(x_0)]$$
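The step that connects this objective to a tractable loss is the standard variational bound from the DDPM paper, sketched here for completeness:

$$E[-\log p_\theta(x_0)] \le E_q\left[-\log \frac{p_\theta(x_{0 \dots T})}{q(x_{1 \dots T} | x_0)}\right] = E_q\left[-\log p(x_T) - \sum_{t \ge 1} \log \frac{p_\theta(x_{t-1} | x_t)}{q(x_t | x_{t-1})}\right]$$

Rewriting this bound as a sum of per-step KL divergences between the forward posterior $q(x_{t-1} | x_t, x_0)$ and $p_\theta(x_{t-1} | x_t)$, and then dropping the step-dependent weighting terms, is what leads to the simple loss.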

Now, the authors of the denoising diffusion paper simplify this objective and obtain the following loss:

$$L_{simple}(\theta) = E_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, t)\|^2\right]$$

One other term that might not be clear in the snippet above is t. For each element in the batch, we randomly choose a step at which to compute the loss. For that step we compute the difference between the $p(x_t)$ and $q(x_t)$ distributions. Or in other words, we want the distribution obtained after adding noise for $t$ steps to be the same as the distribution obtained after removing noise for $t$ steps.
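As a rough sanity check of my own: for an untrained model the loss should come out on the order of 1, since the target noise has unit variance and a randomly initialized network's prediction carries essentially no information about it.

untrained = ConditionalModel(n_steps)
batch = torch.tensor(data.T).float()[:128]
print(noise_estimation_loss(untrained, batch))   # roughly on the order of 1 before training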

def get_pt(x, t):
    t = torch.tensor([t])
    # coefficient of the predicted noise: beta_t / sqrt(1 - alpha_bar_t)
    eps_factor = (1 - alphas[t]) / torch.sqrt(1 - alphas_bar[t])
    # predict the noise with the trained model
    eps_theta = model(x, t)
    # mean of p(x_{t-1} | x_t)
    mean = 1 / (alphas[t].sqrt()) * (x - (eps_factor * eps_theta))
    # add Gaussian noise with the fixed standard deviation sigma_t = sqrt(beta_t)
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    sample = mean + sigma_t * z
    return sample
 
def p_sample_loop(shape):
    # start from pure Gaussian noise and denoise it step by step
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):
        cur_x = get_pt(cur_x, i)
        x_seq.append(cur_x)
    return x_seq

We defined these functions to sample a point from the learned distribution using the formula from above:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_{\theta}(x_t, t)\right) + \sigma_t z$$

import torch.optim as optim

model = ConditionalModel(n_steps)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
dataset = torch.tensor(data.T).float()
batch_size = 128
for epoch in range(1000):
    # shuffle the dataset and iterate over it in mini-batches
    permutation = torch.randperm(dataset.size()[0])
    for i in range(0, dataset.size()[0], batch_size):
        indices = permutation[i:i+batch_size]
        batch_x = dataset[indices]

        loss = noise_estimation_loss(model, batch_x)
        optimizer.zero_grad()
        loss.backward()
        # clip gradients to stabilize training
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()

    if epoch % 100 == 0:
        print(loss.item())

Finally, we train the model. We use the Adam optimizer, a batch size of 128, and clip gradients at 1.

import matplotlib.pyplot as plt

x_seq = p_sample_loop(dataset.shape)
plt.scatter(*x_seq[-1].T.detach().numpy(), alpha=0.5, edgecolor='white', s=20)

We visualize the learned distribution by plotting the final samples from the reverse process; the points should trace out the swiss roll.
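To watch the reverse process in action, we can also plot a few intermediate steps of x_seq (a small extra visualization of my own), as pure noise gradually collapses onto the swiss roll:

fig, axes = plt.subplots(1, 5, figsize=(20, 4))
for ax, step in zip(axes, [0, 25, 50, 75, 100]):
    # x_seq[0] is pure noise, x_seq[100] is the final sample
    points = x_seq[step].detach().numpy()
    ax.scatter(points[:, 0], points[:, 1], alpha=0.5, edgecolor='white', s=20)
    ax.set_title(f'step {step}')
plt.show()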