neuraltrain.models.diffusion_prior.DiffusionPrior

class neuraltrain.models.diffusion_prior.DiffusionPrior(*, depth: int = 6, dim_head: int = 64, prior_learned_query_mode: Literal['token', 'pos_emb', 'all_pos_emb'] = 'pos_emb', timesteps: int = 100, cond_drop_prob: float = 0.2, predict: Literal['x_start', 'v'] = 'x_start')[source]

Diffusion prior module adapted from MindEye [1].

Although the parameters text_embed and image_embed appear to refer specifically to text and image data, they can represent any pair of embeddings: text_embed is the conditioning input (x) to the diffusion prior, and image_embed is the target (y) that the prior learns to denoise.

Parameters:
  • depth (int) – Number of Transformer layers in the prior network.

  • dim_head (int) – Dimension per attention head.

  • prior_learned_query_mode ({"token", "pos_emb", "all_pos_emb"}) – How to handle learned queries for image tokens.

  • timesteps (int) – Number of diffusion denoising steps.

  • cond_drop_prob (float) – Dropout probability applied to the conditioning input for classifier-free guidance.

  • predict ({"x_start", "v"}) – Prediction target: "x_start" predicts the clean embedding directly; "v" uses the velocity parameterisation from Imagen.
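To illustrate the role of cond_drop_prob, the sketch below shows a generic implementation of conditioning dropout for classifier-free guidance: with some probability, each sample's conditioning vector is replaced by a "null" embedding during training, so that at sampling time the model can be queried both with and without conditioning. The function name, shapes, and use of NumPy are illustrative assumptions, not the module's internal API.

```python
import numpy as np

def drop_conditioning(cond, null_cond, drop_prob, rng):
    """Randomly replace per-sample conditioning vectors with a 'null'
    embedding, enabling classifier-free guidance at sampling time.

    cond:      (batch, dim) conditioning embeddings (e.g. text_embed)
    null_cond: (dim,) null embedding substituted when conditioning is dropped
    drop_prob: probability of dropping each sample's conditioning
    """
    keep = rng.random(cond.shape[0]) >= drop_prob  # (batch,) boolean mask
    return np.where(keep[:, None], cond, null_cond)

rng = np.random.default_rng(0)
cond = rng.normal(size=(8, 16))
null_cond = np.zeros(16)
mixed = drop_conditioning(cond, null_cond, drop_prob=0.2, rng=rng)
```

With drop_prob=0.0 every row is kept; with drop_prob=1.0 every row becomes the null embedding, which is how the unconditional branch of classifier-free guidance is trained.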
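For intuition about predict="v", the following self-contained numerical sketch demonstrates the velocity parameterisation under a variance-preserving schedule (alpha_t**2 + sigma_t**2 = 1): the network's target is v = alpha_t * eps - sigma_t * x_start, and a perfect v prediction lets the clean embedding be recovered from the noised one. This is a standalone illustration, not the module's internal code; the schedule values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
x_start = rng.normal(size=4)   # clean target embedding (image_embed)
noise = rng.normal(size=4)     # Gaussian noise eps

# Variance-preserving schedule value at an arbitrary timestep:
# alpha_t**2 + sigma_t**2 == 1 by construction.
alpha_t, sigma_t = np.cos(0.4), np.sin(0.4)

x_t = alpha_t * x_start + sigma_t * noise  # noised embedding at time t
v = alpha_t * noise - sigma_t * x_start    # velocity training target

# Given a perfect v prediction, the clean embedding is recoverable:
# alpha_t * x_t - sigma_t * v
#   = (alpha_t**2 + sigma_t**2) * x_start = x_start.
x_rec = alpha_t * x_t - sigma_t * v
```

Predicting v rather than x_start interpolates between noise prediction and direct prediction across timesteps, which tends to stabilise training at high noise levels.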

References

[1] P. S. Scotti et al., "Reconstructing the Mind's Eye: fMRI-to-Image with Contrastive Learning and Diffusion Priors," NeurIPS 2023.