
Score-Based Generative Modeling with SDEs

Creating noise from data is easy; creating data from noise is generative modeling.
~ Song et al., circa 2020

Banger introduction for a paper. I worked through some of it myself and made a very simple tutorial. You’ll still have to work out some of the math yourself. If you wish, skip to training.

1. Background

1.1 Introduction

Generative modeling techniques can be likelihood-based or implicit (think VAEs vs. GANs). And then there is diffusion, the disruptor of generative modeling. There are multiple mathematical formulations of diffusion-based modeling; two of the most prominent are Denoising Diffusion Probabilistic Models and score-matching generative models. Both can be viewed as discretizations of stochastic differential equations determined by score functions.

The key idea is to make a model learn how to nudge random noise into something that looks like our target data. By modeling the gradient of the log density function (known as the score function), you can make the model do exactly that. To put it formally:

Given a probability density function $p(x)$, the score function is defined as $\nabla_x \log p(x)$, and a model $s_\theta(x)$ for the score function is called a score-based model. Score-based generative models are trained so that $s_\theta(x) \approx \nabla_x \log p(x)$.

Unlike flow models or auto-regressive models, score-based models do not have to be normalized and are easier to parameterize.

For example, consider a non-normalized statistical model

$$p_{\theta}(x) = \frac{e^{-E_{\theta}(x)}}{Z_{\theta}}$$

where $E_{\theta}(x) \in \mathbb{R}$ is called the energy function and $Z_{\theta}$ is an unknown normalizing constant that makes $p_{\theta}(x)$ a proper probability density function. The energy function is typically parameterized by a flexible neural network. To train it as a likelihood model, we would need the normalizing constant $Z_\theta$, which requires computing a high-dimensional integral and is typically intractable. In contrast, when computing its score, we obtain

$$\nabla_x \log p_{\theta}(x) = -\nabla_x E_{\theta}(x) - \underbrace{\nabla_x \log Z_{\theta}}_{=0} = -\nabla_x E_{\theta}(x)$$

which does not require computing the normalizing constant $Z_\theta$.
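
As a minimal sketch of this point, the score of an energy-based model can be obtained with automatic differentiation and no normalization at all; `energy_net` below is a hypothetical stand-in for whatever network parameterizes $E_\theta$:

import torch
import torch.nn as nn

# Hypothetical energy network: maps a 2-dimensional input to a scalar energy E_theta(x).
energy_net = nn.Sequential(nn.Linear(2, 64), nn.SiLU(), nn.Linear(64, 1))

def score_from_energy(x):
    """Return -grad_x E_theta(x), i.e. the score, without ever touching Z_theta."""
    x = x.clone().requires_grad_(True)
    energy = energy_net(x).sum()  # summing over the batch keeps per-sample gradients
    (grad_E,) = torch.autograd.grad(energy, x)
    return -grad_E

x = torch.randn(8, 2)
print(score_from_energy(x).shape)  # torch.Size([8, 2]): same dimensionality as the input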

Any neural network that maps an input vector $x \in \mathbb{R}^d$ to an output vector $y \in \mathbb{R}^d$ can be used as a score-based model, as long as the output and input have the same dimensionality.

Similar to likelihood-based models, we can train score-based models by minimizing the quantity

$$\mathbb{E}_{p(x)}[||\nabla_x \log p(x) - s_{\theta}(x)||^2]$$

which is the Fisher divergence¹ between the model and the data distribution. Minimizing the Fisher divergence without knowledge of the unknown data score is called score matching.
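
As a quick sanity check of these definitions, take the Gaussian $p(x) = \mathcal{N}(x; \mu, \sigma^2 I)$. Its score is

$$\nabla_x \log p(x) = \nabla_x \left(-\frac{||x - \mu||^2}{2\sigma^2}\right) = -\frac{x - \mu}{\sigma^2},$$

a vector field pointing from $x$ towards the mean, i.e., towards regions of higher density. That is exactly the "nudge noise towards the data" picture from above.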

1.2 The Connection to SDEs

Again, this is barebones coverage. Feel free to pick up any resource on stochastic processes if you’re interested in any of the gibberish below.

1.2.1 Data Perturbation with a Diffusion Process

To generate samples with score-based models, we have to consider a diffusion process that slowly corrupts data into random noise, very similar to how DDPMs work.

Let $\{X_t \in \mathbb{R}^d\}_{t=0}^T$ be a diffusion process, indexed by the continuous time variable $t \in [0, T]$. The diffusion process is governed by an Itô stochastic differential equation (SDE), of the form

$$dX_t = b(X_t, t)dt + \sigma(t)dW,$$

where the diffusion coefficient $\sigma$ is independent of $X_t$ (for simplicity). Denote the distribution of $X_t$ by $p(X_t, t)$.

We choose a process such that $X_0 \sim \pi$ and $X_T \sim p$. Here $\pi$ is the data distribution, from which we have a dataset of i.i.d. samples, and $p$ is the prior distribution, which has a tractable form and is easy to sample from (the Gaussian, obviously). The noise added by the diffusion process is large enough that $p$ does not depend on $\pi$.

In practice, only the transition kernel $p(X_t, t \mid X_0)$ is assumed to be known in closed form; at the final time $T$, $X_T$ is only approximately distributed according to the prior $p$.
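
For intuition, simulating such a forward SDE with the Euler-Maruyama scheme takes only a few lines. The sketch below assumes zero drift ($b \equiv 0$) and a constant diffusion coefficient ($\sigma(t) \equiv 1$), purely for illustration:

import torch

def simulate_forward(x0, num_steps=1000, T=1.0):
    """Euler-Maruyama simulation of dX_t = sigma(t) dW with b = 0 and sigma(t) = 1."""
    dt = T / num_steps
    x = x0.clone()
    for _ in range(num_steps):
        x = x + (dt ** 0.5) * torch.randn_like(x)  # only the diffusion term remains
    return x

x0 = torch.zeros(1000, 2)     # "data" concentrated at the origin
xT = simulate_forward(x0)     # approximately N(0, T * I) at time T
print(xT.mean(0), xT.var(0))  # mean close to 0, variance close to 1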

1.2.2 Reversing the SDE for Sample Generation

By starting from a sample from the prior distribution $p$ and reversing the diffusion process, we can obtain a sample from the data distribution $\pi$. Crucially, the reverse process is itself a diffusion process, running backwards in time. The proof isn’t obvious, but here’s how the corresponding reverse SDE looks:

$$dX_t = b(X_t, t)dt - \underbrace{\sigma^2(t)\nabla \log p(X_t, t)}_{\text{score function}}dt + \sigma(t)d\bar{W}$$

where $dt$ is an infinitesimal negative time step and $\bar{W}$ is a Brownian motion running backwards in time. To simulate the reverse SDE, we need to estimate $\nabla \log p(X_t, t)$, which is exactly the score function of the marginal distribution at time $t$.

1.2.3 Denoising Score Matching

A family of methods exists for score matching. We can train a neural network $s_\theta$ to approximate $\nabla_x \log p(X_t, t)$ by replacing the unknown marginal $p(X_t, t)$ with the known transition kernel $p(X_t, t \mid X_0, 0)$, averaged over data samples $X_0$.

This essentially perturbs the data distribution with a specified noise and then employs score matching to estimate the score of the perturbed data distribution, hence the name denoising score matching.

The objective is then a weighted sum of denoising score matching losses:

$$\min_{\theta} \mathbb{E}_{t\sim U(0,T)}\Big[\lambda(t)\,\mathbb{E}_{X_0\sim\pi}\, \mathbb{E}_{X_t\sim p(X_t|X_0)}\big[||s_{\theta}(X_t, t) - \nabla_{X_t} \log p(X_t | X_0)||^2\big]\Big],$$

where $U(0, T)$ is the uniform distribution over $[0, T]$, $p(X_t | X_0)$ denotes the transition kernel of the forward diffusion process, and $\lambda(t) \in \mathbb{R}_{>0}$ denotes a positive weighting function.

In the objective, the expectation over $X_0$ can be estimated with empirical means over data samples from $\pi$. The expectation over $X_t$ can be estimated by sampling from $p(X_t | X_0)$, which is efficient when the drift coefficient $b(X_t, t)$ is affine. The weight function $\lambda(t)$ is typically chosen to be inversely proportional to $\mathbb{E}[||\nabla_x \log p(X_t | X_0)||^2]$.

The cool part of this technique is that the model is trained not on a single noise level, but on a continuum of noise levels.
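
To see why this objective is tractable, note that for a zero-drift forward process like the one used later, the transition kernel is Gaussian with mean $X_0$ and some standard deviation $\mathrm{std}(t)$, i.e. $X_t = X_0 + \mathrm{std}(t)\,z$ with $z \sim \mathcal{N}(0, I)$. The target score then has a closed form,

$$\nabla_{X_t} \log p(X_t | X_0) = -\frac{X_t - X_0}{\mathrm{std}^2(t)} = -\frac{z}{\mathrm{std}(t)},$$

so the network only ever has to regress against $-z / \mathrm{std}(t)$, which is cheap to sample. This is exactly the identity used by the loss function implemented in Section 2.2: at the optimum, $\mathrm{std}(t)\,s_\theta(X_t, t) + z \approx 0$.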

1.2.4 Solving the Reverse SDE

We can simulate the reverse process by solving the (estimated) reverse SDE with numerical SDE solvers, the simplest being the Euler-Maruyama method. It discretizes the SDE using finite time steps and small Gaussian noise increments.

Choose a small negative time step $\Delta t < 0$, initialize $t \leftarrow T$, and iterate the following until $t \approx 0$:

ΔX ← [b(X, t) - σ^2(t)s_θ(X, t)]Δt + σ(t)√|Δt|z_t
X ← X + ΔX
t ← t + Δt

where, $z_t \sim \mathcal{N}(0, I)$.

2 Training a Time-Dependent Score-Based Model

2.1 Set-up

Let’s specify an SDE for demonstration:

$$dX_t = \sigma^t dW, \quad t \in [0, 1].$$

In this case,

$$p(X_t | X_0) = \mathcal{N}\left(X_0, \frac{1}{2\log \sigma} (\sigma^{2t} - 1)I\right)$$
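
The variance follows from the Itô isometry: with zero drift, $X_t = X_0 + \int_0^t \sigma^s\, dW_s$, so

$$\mathrm{Var}(X_t \mid X_0) = \left(\int_0^t \sigma^{2s}\, ds\right) I = \frac{\sigma^{2t} - 1}{2\log \sigma} I.$$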

When $\sigma$ is large, the distribution at $t = 1$, which serves as the prior $p$, is

$$\int \pi(y)\,\mathcal{N}\left(x; y, \frac{1}{2\log \sigma} (\sigma^2 - 1)I\right)dy \approx \mathcal{N}\left(x; 0, \frac{1}{2\log \sigma} (\sigma^2 - 1)I\right),$$

which is approximately independent of the data distribution and is easy to sample from.

Intuitively, this SDE captures a continuum of Gaussian perturbations with zero mean and variance $\frac{1}{2\log \sigma} (\sigma^{2t} - 1)$. This continuum of perturbations allows us to gradually transfer samples from the data distribution $\pi$ to the simple Gaussian distribution $p$.

The reverse-time SDE for the above toy equation is given by

$$dX_t = -\sigma^{2t}\nabla \log p(X_t, t)\,dt + \sigma^t d\bar{W}.$$
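
The code in the next section relies on two helper functions that encode these formulas: `marginal_prob_std` (the standard deviation of $p(X_t | X_0)$) and `diffusion_coeff` (the coefficient $\sigma^t$). One possible implementation is sketched below; the particular value of `sigma` and the exact signatures are assumptions, chosen to match how these names are used later:

import numpy as np
import torch

sigma = 25.0  # illustrative choice; any sufficiently large sigma works

def marginal_prob_std(t):
    """Standard deviation of p(X_t | X_0) = N(X_0, (sigma^(2t) - 1) / (2 log sigma) I)."""
    t = torch.as_tensor(t)
    return torch.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * np.log(sigma)))

def diffusion_coeff(t):
    """Diffusion coefficient sigma^t of the forward SDE dX_t = sigma^t dW."""
    return sigma ** torch.as_tensor(t)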

2.2 Implementation Details

There are no restrictions on the network architecture of time-dependent score-based models, except that their output should have the same dimensionality as the input, and they should be conditioned on time.

# Imports used by the model, loss, and sampler code below.
import numpy as np
import torch
import torch.nn as nn
import tqdm.notebook


def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
      model: A PyTorch model instance that represents a
        time-dependent score-based model.
      x: A mini-batch of training data.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      eps: A tolerance value for numerical stability.
    """
    # Sample a time uniformly in (eps, 1]; eps avoids numerical issues near t = 0.
    random_t = torch.rand(x.shape[0], device=x.device) * (1.0 - eps) + eps
    # Perturb the data: x_t = x + std(t) * z, with z ~ N(0, I).
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t)
    # Denoising score matching: the target score is -z / std(t), so the residual
    # std(t) * s_theta(x_t, t) + z should vanish at the optimum.
    loss = torch.mean(
        torch.sum((score * std[:, None, None, None] + z) ** 2, dim=(1, 2, 3))
    )
    return loss


class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=30.0):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class Dense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)[..., None, None]


class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.

        Args:
          marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
          channels: The number of channels for feature maps of each resolution.
          embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(
            channels[3], channels[2], 3, stride=2, bias=False
        )
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(
            channels[2] + channels[2],
            channels[1],
            3,
            stride=2,
            bias=False,
            output_padding=1,
        )
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(
            channels[1] + channels[1],
            channels[0],
            3,
            stride=2,
            bias=False,
            output_padding=1,
        )
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.embed(t))
        # Encoding path
        h1 = self.conv1(x)
        ## Incorporate information from t
        h1 += self.dense1(embed)
        ## Group normalization
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path
        h = self.tconv4(h4)
        ## Incorporate information from t
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        ## Skip connection from the encoding path
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h

In summary:

  • The U-Net architecture is the most popular choice for the backbone of the score network $s_\theta(x, t)$, and that is what is used in this toy example.

  • Time information is incorporated via Gaussian random features. First, $\omega \sim \mathcal{N}(0, s^2I)$ is sampled and then fixed for the model (i.e., not learnable). For a time step $t$, the corresponding Gaussian random feature is defined as

    $$[\sin(2\pi\omega t); \cos(2\pi\omega t)],$$

    where $[\vec{a}; \vec{b}]$ denotes the concatenation of vector $\vec{a}$ and $\vec{b}$. This Gaussian random feature is used as an encoding for time step $t$ so that the score network can condition on $t$ by incorporating this encoding.

  • The output of the U-Net is divided by $\mathrm{std}(t)$, the standard deviation of the perturbation kernel $p(X_t | X_0)$ (see the last line of `forward`). The true score $\nabla_{X_t} \log p(X_t | X_0)$ has an $\ell_2$-norm of order $\sqrt{\mathbb{E}[||\nabla_{X_t} \log p(X_t | X_0)||^2]} \propto 1/\mathrm{std}(t)$, so this rescaling helps the output of the network match the scale of the true score.
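
With the loss and the network in place, a minimal training loop might look like the following sketch. The dataset (MNIST), optimizer, hyperparameters, and checkpoint path are illustrative assumptions, not values prescribed by the method:

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs, batch_size, lr = 50, 32, 1e-4  # illustrative hyperparameters

dataset = MNIST(".", train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

score_model = ScoreNet(marginal_prob_std=marginal_prob_std).to(device)
optimizer = Adam(score_model.parameters(), lr=lr)

for epoch in range(n_epochs):
    avg_loss, num_items = 0.0, 0
    for x, _ in data_loader:
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    print(f"epoch {epoch}: average loss {avg_loss / num_items:.4f}")
    torch.save(score_model.state_dict(), "ckpt.pth")  # checkpoint path is arbitrary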

2.3 Sampling Method

To sample from the trained time-dependent score-based model $s_\theta(x, t)$, a sample is first drawn from the prior distribution $p \approx \mathcal{N}\left(x; 0, \frac{1}{2\log \sigma} (\sigma^2 - 1)I\right)$. Then, the reverse-time SDE is solved using the Euler-Maruyama approach. When applied to the above reverse-time SDE, the following iteration rule is obtained

$$x_{t-\Delta t} = x_t + \sigma^{2t}s_\theta(x_t, t)\Delta t + \sigma^t\sqrt{\Delta t}\,z_t,$$

where $z_t \sim \mathcal{N}(0, I)$.

num_steps = 500

def Euler_Maruyama_Sampler(
    score_model,
    marginal_prob_std,
    diffusion_coeff,
    batch_size=64,
    num_steps=num_steps,
    device="cuda",
    eps=1e-3,
):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Args:
      score_model: A PyTorch model that represents the time-dependent score-based model.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
      batch_size: The number of samples to generate by calling this function once.
      num_steps: The number of sampling steps.
        Equivalent to the number of discretized time steps.
      device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
      eps: The smallest time step for numerical stability.

    Returns:
      Samples.
    """
    t = torch.ones(batch_size, device=device)
    init_x = (
        torch.randn(batch_size, 1, 28, 28, device=device)
        * marginal_prob_std(t)[:, None, None, None]
    )
    time_steps = torch.linspace(1.0, eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.notebook.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = (
                x
                + (g**2)[:, None, None, None]
                * score_model(x, batch_time_step)
                * step_size
            )
            x = mean_x + torch.sqrt(step_size) * g[
                :, None, None, None
            ] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x
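
A call to the sampler could then look like this sketch, reusing `score_model`, `marginal_prob_std`, and `diffusion_coeff` from above; the plotting code is an illustrative assumption:

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

score_model.eval()
samples = Euler_Maruyama_Sampler(
    score_model, marginal_prob_std, diffusion_coeff, batch_size=64, device=device
)
samples = samples.clamp(0.0, 1.0)  # MNIST pixels live in [0, 1]

grid = make_grid(samples, nrow=8)
plt.imshow(grid.permute(1, 2, 0).cpu(), cmap="gray")
plt.axis("off")
plt.show()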

Using 500 sampling steps and a decent amount of training, samples such as these can be generated on MNIST data,

[Figure: a grid of MNIST digit samples generated by the trained score-based model]

These are not perfect, but they are a good baseline for training and experimenting with score-based generative modeling using SDEs.

3. References

@misc{song2021scorebasedgenerativemodelingstochastic,
      title={Score-Based Generative Modeling through Stochastic Differential Equations}, 
      author={Yang Song and Jascha Sohl-Dickstein and Diederik P. Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
      year={2021},
      eprint={2011.13456},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2011.13456}, 
}

Yang Song’s Blog “delves” into this in excruciating detail and there are more example notebooks you can try.


  1. This is actually an abuse of notation.