neuraltrain.models.diffusion_prior.DiffusionPrior
- class neuraltrain.models.diffusion_prior.DiffusionPrior(*, depth: int = 6, dim_head: int = 64, prior_learned_query_mode: Literal['token', 'pos_emb', 'all_pos_emb'] = 'pos_emb', timesteps: int = 100, cond_drop_prob: float = 0.2, predict: Literal['x_start', 'v'] = 'x_start')[source]
Diffusion prior module adapted from MindEye [1].
Although the parameter names text_embed and image_embed suggest text and image data specifically, they can represent any embeddings: text_embed is the input (x) to the diffusion prior, and image_embed is the target (y) that the prior learns to denoise.
- Parameters:
depth (int) – Number of Transformer layers in the prior network.
dim_head (int) – Dimension per attention head.
prior_learned_query_mode ({"token", "pos_emb", "all_pos_emb"}) – How to handle learned queries for image tokens.
timesteps (int) – Number of diffusion denoising steps.
cond_drop_prob (float) – Dropout probability applied to the conditioning input for classifier-free guidance.
predict ({"x_start", "v"}) – Prediction target: "x_start" predicts the clean embedding directly; "v" uses the velocity parameterisation from Imagen.
References