Source code for neuraltrain.models.diffusion_prior

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Diffusion prior used in MindEye (https://github.com/MedARC-AI/fMRI-reconstruction-NSD/tree/main).
"""

import logging
import random
import typing as tp

import torch
from torch import nn
from tqdm import tqdm

from .base import BaseModelConfig

logger = logging.getLogger(__name__)

try:
    import dalle2_pytorch.dalle2_pytorch as dalle2_modules
    from dalle2_pytorch import DiffusionPrior as DalleDiffusionPrior

    class PriorNetwork(nn.Module):  # type: ignore
        def __init__(
            self,
            dim: int,
            num_timesteps: int | None = None,  # type: ignore
            num_time_embeds: int = 1,
            num_text_tokens: int = 257,
            num_image_tokens: int = 257,
            causal: bool = True,
            learned_query_mode: str = "none",
            **kwargs,
        ):
            super().__init__()
            self.dim = dim
            self.num_time_embeds = num_time_embeds
            self.continuous_embedded_time = not dalle2_modules.exists(num_timesteps)
            self.learned_query_mode = learned_query_mode

            self.to_time_embeds = nn.Sequential(
                (
                    nn.Embedding(num_timesteps, dim * num_time_embeds)  # type: ignore
                    if dalle2_modules.exists(num_timesteps)
                    else nn.Sequential(
                        dalle2_modules.SinusoidalPosEmb(dim),
                        dalle2_modules.MLP(dim, dim * num_time_embeds),
                    )
                ),  # also offer a continuous version of timestep embeddings, with a 2 layer MLP
                dalle2_modules.Rearrange("b (n d) -> b n d", n=num_time_embeds),
            )

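            # learned-query handling: "token" appends trainable query tokens
            # to the sequence, "pos_emb" adds a learned positional embedding
            # to the image tokens, and "all_pos_emb" adds one to the whole
            # token sequence.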
            if self.learned_query_mode == "token":
                self.learned_query = nn.Parameter(torch.randn(num_image_tokens, dim))
            if self.learned_query_mode == "pos_emb":
                scale = dim**-0.5
                self.learned_query = nn.Parameter(
                    torch.randn(num_image_tokens, dim) * scale
                )
            if self.learned_query_mode == "all_pos_emb":
                scale = dim**-0.5
                self.learned_query = nn.Parameter(
                    torch.randn(num_image_tokens + num_text_tokens + 1, dim) * scale
                )
            self.causal_transformer = FlaggedCausalTransformer(
                dim=dim, causal=causal, **kwargs
            )

            self.null_text_embeds = nn.Parameter(torch.randn(num_text_tokens, dim))
            self.null_image_embed = nn.Parameter(torch.randn(num_image_tokens, dim))

            self.num_image_tokens = num_image_tokens
            self.num_text_tokens = num_text_tokens
            self.self_cond = False

        def forward_with_cond_scale(self, *args, cond_scale: float = 1.0, **kwargs):
            logits = self.forward(*args, **kwargs)

            if cond_scale == 1.0:
                return logits

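            # classifier-free guidance: run an unconditional pass with all
            # conditioning dropped, then extrapolate from it toward the
            # conditional logits by cond_scale.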
            null_logits = self.forward(
                *args,
                text_cond_drop_prob=1.0,
                image_cond_drop_prob=1.0,
                **kwargs,  # type: ignore
            )
            return null_logits + (logits - null_logits) * cond_scale

        def forward(
            self,
            image_embed: torch.Tensor,  # the target embedding the diffusion network denoises
            diffusion_timesteps: torch.Tensor,
            text_embed: torch.Tensor,  # the conditioning input of the diffusion model
            text_cond_drop_prob: float = 0.0,
            image_cond_drop_prob: float = 0.0,
            **kwargs,
        ):
            image_embed = image_embed.view(len(image_embed), -1, self.dim)
            text_embed = text_embed.view(len(text_embed), -1, self.dim)

            batch, _, dim, device, dtype = (
                *image_embed.shape,
                image_embed.device,
                image_embed.dtype,
            )

            # classifier free guidance masks
            text_keep_mask = dalle2_modules.prob_mask_like(
                (batch,), 1 - text_cond_drop_prob, device=device
            )
            text_keep_mask = dalle2_modules.rearrange(text_keep_mask, "b -> b 1 1")

            image_keep_mask = dalle2_modules.prob_mask_like(
                (batch,), 1 - image_cond_drop_prob, device=device
            )
            image_keep_mask = dalle2_modules.rearrange(image_keep_mask, "b -> b 1 1")

            # mask out text embeddings with null text embeddings
            null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
            text_embed = torch.where(text_keep_mask, text_embed, null_text_embeds[None])

            # mask out image embeddings with null image embeddings
            null_image_embed = self.null_image_embed.to(image_embed.dtype)
            image_embed = torch.where(
                image_keep_mask, image_embed, null_image_embed[None]
            )

            if self.continuous_embedded_time:
                diffusion_timesteps = diffusion_timesteps.type(dtype)
            time_embed = self.to_time_embeds(diffusion_timesteps)

            if self.learned_query_mode == "token":
                learned_queries = dalle2_modules.repeat(
                    self.learned_query, "n d -> b n d", b=batch
                )
            elif self.learned_query_mode == "pos_emb":
                pos_embs = dalle2_modules.repeat(
                    self.learned_query, "n d -> b n d", b=batch
                )
                image_embed = image_embed + pos_embs
                learned_queries = torch.empty((batch, 0, dim), device=text_embed.device)
            elif self.learned_query_mode == "all_pos_emb":
                pos_embs = dalle2_modules.repeat(
                    self.learned_query, "n d -> b n d", b=batch
                )
                learned_queries = torch.empty((batch, 0, dim), device=text_embed.device)
            else:
                learned_queries = torch.empty((batch, 0, dim), device=text_embed.device)

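            # sequence layout: [conditioning tokens, time embedding, noisy
            # image tokens, learned queries]; the transformer's output at the
            # final num_image_tokens positions is read out as the prediction.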
            tokens = torch.cat(
                (text_embed, time_embed, image_embed, learned_queries), dim=-2
            )

            if self.learned_query_mode == "all_pos_emb":
                tokens = tokens + pos_embs

            # attend
            tokens = self.causal_transformer(tokens)

            # get learned query, which should predict the image embedding (per DDPM timestep)
            pred_image_embed = tokens[..., -self.num_image_tokens :, :]

            return pred_image_embed

    class NewDiffusionPrior(DalleDiffusionPrior):  # type: ignore
        @torch.no_grad()
        def p_sample(
            self,
            x: torch.Tensor,
            t: torch.Tensor,
            text_cond: tp.Dict[str, torch.Tensor] | None = None,
            self_cond: torch.Tensor | None = None,
            clip_denoised: bool = True,
            cond_scale: float = 1.0,
            generator: torch.Generator | None = None,
        ):
            b, *_, device = *x.shape, x.device
            model_mean, _, model_log_variance, x_start = self.p_mean_variance(
                x=x,
                t=t,
                text_cond=text_cond,
                self_cond=self_cond,
                clip_denoised=clip_denoised,
                cond_scale=cond_scale,
            )
            if generator is None:
                noise = torch.randn_like(x)
            else:
                noise = torch.randn(
                    x.size(), device=device, dtype=x.dtype, generator=generator
                )
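            # no noise is added at the final step (t == 0), so the last
            # update is the deterministic posterior mean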
            nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
            pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
            return pred, x_start

        @torch.no_grad()
        def p_sample_loop_ddpm(
            self,
            shape: tuple[int, ...],
            text_cond: tp.Dict[str, torch.Tensor] | None = None,
            cond_scale: float = 1.0,
            generator: torch.Generator | None = None,
        ):
            batch, device = shape[0], self.device

            if generator is None:
                image_embed = torch.randn(shape, device=device)
            else:
                image_embed = torch.randn(shape, device=device, generator=generator)
            x_start = None  # for self-conditioning

            if self.init_image_embed_l2norm:
                image_embed = dalle2_modules.l2norm(image_embed) * self.image_embed_scale

            for i in tqdm(
                reversed(range(0, self.noise_scheduler.num_timesteps)),
                desc="sampling loop time step",
                total=self.noise_scheduler.num_timesteps,
                disable=True,
            ):
                times = torch.full((batch,), i, device=device, dtype=torch.long)

                self_cond = x_start if self.net.self_cond else None
                image_embed, x_start = self.p_sample(
                    image_embed,
                    times,
                    text_cond=text_cond,
                    self_cond=self_cond,
                    cond_scale=cond_scale,
                    generator=generator,
                )

            if self.sampling_final_clamp_l2norm and self.predict_x_start:
                image_embed = self.l2norm_clamp_embed(image_embed)

            return image_embed

        def p_losses(
            self,
            image_embed: torch.Tensor,
            times: torch.Tensor,
            text_cond: tp.Dict[str, torch.Tensor],
            noise: torch.Tensor | None = None,
        ):
            noise = dalle2_modules.default(noise, lambda: torch.randn_like(image_embed))

            image_embed_noisy = self.noise_scheduler.q_sample(
                x_start=image_embed, t=times, noise=noise
            )

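            # self-conditioning: with probability 0.5, feed the network's own
            # prediction back in as extra conditioning (no gradient flows
            # through it)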
            self_cond = None
            if self.net.self_cond and random.random() < 0.5:
                with torch.no_grad():
                    self_cond = self.net(image_embed_noisy, times, **text_cond).detach()

            pred = self.net(
                image_embed_noisy,
                times,
                self_cond=self_cond,
                text_cond_drop_prob=self.text_cond_drop_prob,
                image_cond_drop_prob=self.image_cond_drop_prob,
                **text_cond,
            )

            if self.predict_x_start and self.training_clamp_l2norm:
                pred = self.l2norm_clamp_embed(pred)

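            # regression target depends on the parameterization: velocity
            # ("v"), the clean embedding ("x_start"), or the injected noise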
            if self.predict_v:
                target = self.noise_scheduler.calculate_v(image_embed, times, noise)
            elif self.predict_x_start:
                target = image_embed
            else:
                target = noise
            return {"pred": pred, "target": target}

        def forward(
            self,
            text_embed: torch.Tensor,
            image_embed: torch.Tensor | None = None,
            text_encodings: torch.Tensor | None = None,
            *args,
            **kwargs,
        ):
            if self.training:
                if image_embed is None:
                    raise ValueError(
                        "image_embed should be passed to diffusion prior during training"
                    )
                text_cond = dict(text_embed=text_embed)

                if self.condition_on_text_encodings:
                    if not dalle2_modules.exists(text_encodings):
                        raise ValueError(
                            "text_encodings must be provided when"
                            " condition_on_text_encodings is set"
                        )
                    text_cond = {**text_cond, "text_encodings": text_encodings}  # type: ignore

                # timestep conditioning from ddpm
                batch = image_embed.shape[0]
                times = self.noise_scheduler.sample_random_times(batch)

                # calculate forward loss
                out = self.p_losses(
                    image_embed * self.image_embed_scale,
                    times,
                    *args,
                    text_cond=text_cond,  # type: ignore
                    **kwargs,
                )
                return out
            else:
                if image_embed is not None:
                    raise ValueError(
                        "image_embed should not be passed to the diffusion"
                        " prior in evaluation/sampling mode: sampling is"
                        " conditioned on text_embed only, and image_embed"
                        " is only used as a target during training"
                    )
                out_shape = (
                    text_embed.shape[0],
                    self.net.num_image_tokens,
                    self.image_embed_dim,
                )
                return self.p_sample_loop(
                    out_shape, text_cond=dict(text_embed=text_embed)
                )

    DiffusionPriorModel = NewDiffusionPrior  # type: ignore

    class FlaggedCausalTransformer(nn.Module):
        def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            ff_mult=4,
            norm_in=False,
            norm_out=True,
            attn_dropout=0.0,
            ff_dropout=0.0,
            final_proj=True,
            normformer=False,
            rotary_emb=True,
            causal=True,
        ):
            super().__init__()
            self.init_norm = dalle2_modules.LayerNorm(dim) if norm_in else nn.Identity()

            self.rel_pos_bias = dalle2_modules.RelPosBias(heads=heads)

            rotary_emb = (
                dalle2_modules.RotaryEmbedding(dim=min(32, dim_head))
                if rotary_emb
                else None
            )

            self.layers = nn.ModuleList([])
            for _ in range(depth):
                self.layers.append(
                    nn.ModuleList(
                        [
                            dalle2_modules.Attention(
                                dim=dim,
                                causal=causal,
                                dim_head=dim_head,
                                heads=heads,
                                dropout=attn_dropout,
                                rotary_emb=rotary_emb,
                            ),
                            dalle2_modules.FeedForward(
                                dim=dim,
                                mult=ff_mult,
                                dropout=ff_dropout,
                                post_activation_norm=normformer,
                            ),
                        ]
                    )
                )

            self.norm = (
                dalle2_modules.LayerNorm(dim, stable=True) if norm_out else nn.Identity()
            )
            self.project_out = (
                nn.Linear(dim, dim, bias=False) if final_proj else nn.Identity()
            )

        def forward(self, x: torch.Tensor):
            n, device = x.shape[1], x.device

            x = self.init_norm(x)

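            # T5-style relative position bias, shared across all layers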
            attn_bias = self.rel_pos_bias(n, n + 1, device=device)

            for attn, ff in self.layers:
                x = attn(x, attn_bias=attn_bias) + x
                x = ff(x) + x

            out = self.norm(x)
            return self.project_out(out)

except ImportError:

    class DummyDiffusionPrior(torch.nn.Module):  # type: ignore
        def __init__(self, *args, **kwargs):
            super().__init__()
            raise ImportError(
                "Please install dalle2-pytorch to use DiffusionPrior: pip install dalle2-pytorch"
            )

    DiffusionPriorModel = DummyDiffusionPrior  # type: ignore

    class PriorNetwork(torch.nn.Module):  # type: ignore
        def __init__(self, *args, **kwargs):
            super().__init__()
            raise ImportError(
                "Please install dalle2-pytorch to use DiffusionPrior: pip install dalle2-pytorch"
            )


class DiffusionPrior(BaseModelConfig):
    """Diffusion prior module adapted from MindEye [1]_.

    Although the parameters *text_embed* and *image_embed* appear to refer
    specifically to text and image data, they can represent any embedding:
    *text_embed* is the input (x) to the diffusion prior, and *image_embed*
    is the target (y) that the prior aims to denoise.

    Parameters
    ----------
    depth : int
        Number of Transformer layers in the prior network.
    dim_head : int
        Dimension per attention head.
    prior_learned_query_mode : {"token", "pos_emb", "all_pos_emb"}
        How to handle learned queries for image tokens.
    timesteps : int
        Number of diffusion denoising steps.
    cond_drop_prob : float
        Dropout probability applied to the conditioning input for
        classifier-free guidance.
    predict : {"x_start", "v"}
        Prediction target: ``"x_start"`` predicts the clean embedding
        directly; ``"v"`` uses the velocity parameterisation from Imagen.

    References
    ----------
    .. [1] https://github.com/MedARC-AI/fMRI-reconstruction-NSD/blob/main/src/models.py
    """

    depth: int = 6
    dim_head: int = 64
    prior_learned_query_mode: tp.Literal["token", "pos_emb", "all_pos_emb"] = "pos_emb"
    timesteps: int = 100
    cond_drop_prob: float = 0.2
    # prediction type
    predict: tp.Literal["x_start", "v"] = "x_start"

    def build(
        self,
        dim: int,
        num_out_tokens: int,
        num_in_tokens: int,
    ) -> DiffusionPriorModel:
        if dim % self.dim_head != 0:
            raise ValueError(
                f"dim {dim} must be divisible by dim_head {self.dim_head}"
            )
        heads = dim // self.dim_head
        prior_network = PriorNetwork(
            dim=dim,
            depth=self.depth,
            dim_head=self.dim_head,
            heads=heads,
            causal=False,
            num_image_tokens=num_out_tokens,
            num_text_tokens=num_in_tokens,
            learned_query_mode=self.prior_learned_query_mode,
        )
        logger.info("prior_network loaded")
        diffusion_prior = DiffusionPriorModel(
            net=prior_network,
            image_embed_dim=dim,
            condition_on_text_encodings=False,
            timesteps=self.timesteps,
            cond_drop_prob=self.cond_drop_prob,
            image_embed_scale=None,
            predict_x_start=(self.predict == "x_start"),
            predict_v=(self.predict == "v"),
        )
        return diffusion_prior
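

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). A
# minimal round trip through the config, assuming dalle2-pytorch is
# installed and that BaseModelConfig accepts keyword arguments for the
# fields above; the sizes are arbitrary.
if __name__ == "__main__":
    cfg = DiffusionPrior(depth=2, dim_head=64, timesteps=10)
    prior = cfg.build(dim=128, num_out_tokens=4, num_in_tokens=4)

    x = torch.randn(2, 4, 128)  # conditioning embeddings ("text_embed")
    y = torch.randn(2, 4, 128)  # target embeddings ("image_embed")

    # training mode: returns the prediction and its regression target
    out = prior(text_embed=x, image_embed=y)
    print(out["pred"].shape, out["target"].shape)

    # evaluation mode: samples denoised embeddings conditioned on x only
    prior.eval()
    with torch.no_grad():
        sampled = prior(text_embed=x)
    print(sampled.shape)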