A simple 2D Discrete Flow Matching model#

This notebook trains and evaluates a simple 2D discrete FM model with a \(\kappa_t = t^2\) scheduler.

Dataset: 2D discrete checkerboard

Model (probability denoiser): MLP
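
Concretely, the mixture probability path used below moves each coordinate independently from source to data, \(p_t(x^i \mid x_0^i, x_1^i) = (1 - \kappa_t)\,\delta_{x_0^i}(x^i) + \kappa_t\,\delta_{x_1^i}(x^i)\), so a sample \(X_t\) keeps its source token with probability \(1 - t^2\) and takes the data token with probability \(t^2\).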

Imports and init device#

[1]:
import time
import torch

from torch import nn, Tensor

# flow_matching
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL

# visualization
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
[2]:
if torch.cuda.is_available():
    device = 'cuda:0'
    print('Using gpu.')
else:
    device = 'cpu'
    print('Using cpu.')
Using gpu.
[3]:
torch.manual_seed(42)
[3]:
<torch._C.Generator at 0x7f691c1a3c50>

Dataset#

[4]:
def inf_train_gen(n_grid_points: int = 128, batch_size: int = 200, device: str = "cpu") -> Tensor:
    """Sample batch_size points from a 2D checkerboard on an n_grid_points x n_grid_points grid."""
    assert n_grid_points % 4 == 0, "number of grid points must be divisible by 4"

    # work in units of one tile; the board is 4 x 4 tiles, half of which are populated
    n_grid_points = n_grid_points // 4

    # first coordinate is uniform over the full grid
    x1 = torch.randint(low=0, high=n_grid_points * 4, size=(batch_size,), device=device)
    samples_x2 = torch.randint(low=0, high=n_grid_points, size=(batch_size,), device=device)

    # second coordinate is shifted by whole tiles so that only alternating tiles are filled
    x2 = (
        samples_x2
        + 2 * n_grid_points
        - torch.randint(low=0, high=2, size=(batch_size,), device=device) * 2 * n_grid_points
        + (torch.floor(x1 / n_grid_points) % 2) * n_grid_points
    )

    x_end = 1.0 * torch.cat([x1[:, None], x2[:, None]], dim=1)

    return x_end.long()
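
As a quick sanity check of the generator, a large batch plotted as a 2D histogram should show the checkerboard over the full grid (the batch size and figure settings here are arbitrary choices):

x_demo = inf_train_gen(n_grid_points=128, batch_size=200_000).numpy()

plt.figure(figsize=(4, 4))
plt.hist2d(x_demo[:, 0], x_demo[:, 1], bins=128)  # one bin per grid cell
plt.gca().set_aspect('equal')
plt.axis('off')
plt.show()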

Model#

[5]:
# Activation class
class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sigmoid(x) * x

# Model class
class MLP(nn.Module):
    def __init__(
        self, input_dim: int = 128, time_dim: int = 1, hidden_dim: int = 128, length: int = 2
    ):
        super().__init__()
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.time_embedding = nn.Linear(1, time_dim)
        self.token_embedding = nn.Embedding(self.input_dim, hidden_dim)

        self.main = nn.Sequential(
            Swish(),
            nn.Linear(hidden_dim * length + time_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, self.input_dim * length),
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        t = self.time_embedding(t.unsqueeze(-1))
        x = self.token_embedding(x)

        B, N, d = x.shape
        x = x.reshape(B, N * d)

        h = torch.cat([x, t], dim=1)
        h = self.main(h)

        h = h.reshape(B, N, self.input_dim)

        return h
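
A quick shape check of the interface (the tensors here are just illustrative): token indices of shape (batch, 2) and times of shape (batch,) map to per-position logits of shape (batch, 2, input_dim):

demo_model = MLP(input_dim=128, time_dim=1, hidden_dim=128)
demo_x = torch.randint(high=128, size=(4, 2))
demo_t = torch.rand(4)
print(demo_model(demo_x, demo_t).shape)  # torch.Size([4, 2, 128])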

Train Discrete Flow Matching model with a uniform source distribution#

[6]:
source_distribution = "uniform"

# training arguments
lr = 0.001
batch_size = 4096
iterations = 30001
print_every = 3000

vocab_size = 128
hidden_dim = 128

epsilon = 1e-3  # keep t strictly below 1, where X_t equals X_1 almost surely

if source_distribution == "uniform":
    added_token = 0
elif source_distribution == "mask":
    mask_token = vocab_size  # data tokens occupy 0..vocab_size-1, so vocab_size is free for the mask
    added_token = 1
else:
    raise NotImplementedError

# additional mask token
vocab_size += added_token

# probability denoiser model init
probability_denoiser = MLP(input_dim=vocab_size, time_dim=1, hidden_dim=hidden_dim).to(device)

# instantiate a convex path object
scheduler = PolynomialConvexScheduler(n=2.0)
path = MixtureDiscreteProbPath(scheduler=scheduler)

# init optimizer
optim = torch.optim.Adam(probability_denoiser.parameters(), lr=lr)

loss_fn = MixturePathGeneralizedKL(path=path)

# train
start_time = time.time()

for i in range(iterations):
    optim.zero_grad()

    # sample data (user's responsibility): in this case, (X_0, X_1) ~ pi(X_0, X_1)
    x_1 = inf_train_gen(n_grid_points=vocab_size - added_token, batch_size=batch_size, device=device)

    if source_distribution == "uniform":
        x_0 = torch.randint_like(x_1, high=vocab_size)
    elif source_distribution == "mask":
        x_0 = torch.zeros_like(x_1) + mask_token
    else:
        raise NotImplementedError

    # sample time in [0, 1 - epsilon) (user's responsibility)
    t = torch.rand(x_1.shape[0]).to(device) * (1 - epsilon)

    # sample probability path
    path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)

    # discrete flow matching generalized KL loss
    logits = probability_denoiser(x=path_sample.x_t, t=path_sample.t)
    loss = loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t)

    # optimizer step
    loss.backward() # backward
    optim.step() # update

    # log loss
    if (i+1) % print_every == 0:
        elapsed = time.time() - start_time
        print('| iter {:6d} | {:5.2f} ms/step | loss {:8.3f} '
              .format(i+1, elapsed*1000/print_every, loss.item()))
        start_time = time.time()
| iter   3000 |  3.68 ms/step | loss    5.697
| iter   6000 |  3.49 ms/step | loss    5.539
| iter   9000 |  3.31 ms/step | loss    5.296
| iter  12000 |  3.39 ms/step | loss    5.520
| iter  15000 |  3.56 ms/step | loss    5.714
| iter  18000 |  3.49 ms/step | loss    5.556
| iter  21000 |  3.58 ms/step | loss    5.392
| iter  24000 |  3.49 ms/step | loss    5.354
| iter  27000 |  3.30 ms/step | loss    6.423
| iter  30000 |  3.30 ms/step | loss    5.445

Sample from trained model#

[7]:
class WrappedModel(ModelWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        return torch.softmax(self.model(x, t), dim=-1)

wrapped_probability_denoiser = WrappedModel(probability_denoiser)
solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
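
The solver consumes, for every position, a probability distribution over the vocabulary, which is why the wrapper applies a softmax over the last dimension. A minimal check with illustrative inputs:

check_x = torch.randint(high=vocab_size, size=(4, 2), device=device)
check_t = torch.rand(4, device=device)
check_probs = wrapped_probability_denoiser(check_x, check_t)
print(check_probs.shape)        # (4, 2, vocab_size)
print(check_probs.sum(dim=-1))  # each position sums to 1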
[8]:
nfe = 64
step_size = 1 / nfe

n_samples = 1000000
dim = 2

if source_distribution == "uniform":
    x_init = torch.randint(size=(n_samples, dim), high=vocab_size, device=device)
elif source_distribution == "mask":
    x_init = (torch.zeros(size=(n_samples, dim), device=device) + mask_token).long()
else:
    raise NotImplementedError

n_plots = 9
linspace_to_plot = torch.linspace(0, 1 - epsilon, n_plots)

sol = solver.sample(x_init=x_init,
                    step_size=step_size,
                    verbose=True,
                    return_intermediates=True,
                    time_grid=linspace_to_plot)
NFE: 64: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 0.9990000128746033/0.9990000128746033 [00:08<00:00,  8.13s/it]
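
With return_intermediates=True, the solution stacks one snapshot of the samples per entry of time_grid, so with the settings above sol should have shape (n_plots, n_samples, dim):

print(sol.shape)  # expected: torch.Size([9, 1000000, 2])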
[9]:
sol = sol.cpu().numpy()

fig, axs = plt.subplots(1, n_plots, figsize = (20, 20))

if source_distribution == "mask":
    mask_tensor = torch.tensor([mask_token, mask_token]).unsqueeze(0)

for idx, step in enumerate(linspace_to_plot):
    step = int(step.item() * nfe)

    if source_distribution == "uniform":
        sol_step = sol[idx, ...]
    elif source_distribution == "mask":
        sol_step = sol[idx, ...]
        sol_step = sol_step[torch.ne(torch.from_numpy(sol_step), mask_tensor).all(dim=1), ...]

        if sol_step.size == 0:
            axs[idx].hist2d([], [], bins=10)
            axs[idx].set_aspect('equal')
            axs[idx].axis('off')
            axs[idx].set_title('t= %.2f' % (step * step_size))

            continue
    else:
        raise NotImplementedError

    # first pass: compute histogram counts to set a robust color scale
    H = axs[idx].hist2d(sol_step[:, 0], sol_step[:, 1], bins=vocab_size)

    cmin = 0.0
    cmax = torch.quantile(torch.from_numpy(H[0]), 0.95).item()

    norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)

    # second pass: redraw with the 95th-percentile color normalization
    _ = axs[idx].hist2d(sol_step[:, 0], sol_step[:, 1], bins=vocab_size, norm=norm)

    axs[idx].set_aspect('equal')
    axs[idx].axis('off')
    axs[idx].set_title(f't= {linspace_to_plot[idx].item():.2f}')

plt.tight_layout()
plt.show()
../_images/notebooks_2d_discrete_flow_matching_15_0.png

Visualize ELBO#
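
The next cell estimates, for every point \(x_1\) on the vocab_size × vocab_size grid, \(\mathrm{ELBO}(x_1) = -\,\mathbb{E}_{t}\,\mathbb{E}_{X_t \sim p_t(\cdot \mid x_1)}\left[\ell_{\mathrm{GKL}}(x_1, X_t, t)\right]\), where \(\ell_{\mathrm{GKL}}\) is the generalized KL term used as the training loss. The time expectation is approximated on a stratified grid of n_discretization points (shifted by a shared random offset for lower variance), and the inner expectation with n_samples draws of \(X_t\). Since \(\log q(x_1) \ge \mathrm{ELBO}(x_1)\), exponentiating the estimate gives a lower bound on the model probability of each grid point.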

[10]:
n_discretization = 1024  # Time discretization of the integration interval
n_samples = 10  # Number of samples to approximate the expectation over X_t ~ p_t(\cdot | x_1)

# Generalized KL function (used to compute the ELBO)
generalized_kl_fn = MixturePathGeneralizedKL(
    path=path,
    reduction='none'
)

# Grid of vocab_size X vocab_size
grid = torch.meshgrid(
    torch.arange(0, vocab_size, device=device),
    torch.arange(0, vocab_size, device=device),
    indexing='ij'
)
x_1 = torch.stack(
    [grid[0].reshape(-1), grid[1].reshape(-1)],
    dim=1
)

# Time discretization
discretization = (
    torch.linspace(0, 1, n_discretization + 1, device=device)[:-1]
    .view(-1, 1)
    .repeat(1, x_1.shape[0])
)

elbo = torch.zeros(size=(x_1.shape[0],), device=device)

with torch.no_grad():
    for _ in range(n_samples):
        # Lower variance estimator for time discretization
        discretization = discretization + torch.rand(
            size=(1, x_1.shape[0]), device=device
        )
        discretization = discretization % 1
        discretization = discretization * (1 - epsilon)

        for t in discretization:
            # sample X_t ~ p_t(\cdot| x_1)
            if source_distribution == "uniform":
                x_0 = torch.randint(size=x_1.shape, high=vocab_size, device=device)
            elif source_distribution == "mask":
                x_0 = (torch.zeros(size=x_1.shape, device=device) + mask_token).long()
            else:
                raise NotImplementedError

            x_t = path.sample(t=t, x_0=x_0, x_1=x_1).x_t

            logits = probability_denoiser(x_t, t)

            # compute ELBO
            elbo += -generalized_kl_fn(
                logits=logits, x_1=x_1, x_t=x_t, t=t
            ).sum(dim=1)

    elbo /= n_discretization * n_samples

# Remember that log_q(x_1) >= ELBO(x_1)
probability_lower_bound = torch.exp(elbo)
[11]:
cmin = 0.0
cmax = probability_lower_bound.max().item() / 1.5

norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)

plt.figure(figsize=(5, 5))
plt.imshow(
    probability_lower_bound.reshape(vocab_size, vocab_size).cpu(),
    origin='lower', cmap='viridis', norm=norm
)
plt.gca().axis("off")
plt.colorbar(cm.ScalarMappable(norm=norm, cmap='viridis'), ax=plt.gca(), orientation='horizontal', label='probability lower bound')
plt.title("ELBO Estimator")
plt.show()
../_images/notebooks_2d_discrete_flow_matching_18_0.png