MixturePathGeneralizedKL

class flow_matching.loss.MixturePathGeneralizedKL(path: MixtureDiscreteProbPath, reduction: str = 'mean')

A generalized KL loss for discrete flow matching. It measures the generalized KL of a discrete flow model \(p_{1|t}\) with respect to the probability path given by path. Note: this class assumes that the model was trained on the same path.

For a model trained on a space \(\mathcal{S} = \mathcal{T}^d\), with \(\mathcal{T} = [K] = \{1,2,\ldots,K\}\), the loss is given by

\[\ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggl[ p_{1|t}(x_t^i|x_t) - \delta_{x^i_1}(x_t^i) + \left(1-\delta_{x^i_1}(x_t^i)\right) \log p_{1|t}(x_1^i|x_t) \biggr], \]

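where \(\kappa_t\) is the scheduler associated with path, \(\dot{\kappa}_t\) is its derivative with respect to \(t\), and \(\delta_{x^i_1}(x_t^i)\) equals 1 if \(x_t^i = x_1^i\) and 0 otherwise.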
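As a rough illustration of the formula above, the following pure-Python sketch evaluates \(\ell_i\) for a single coordinate. It is not the library's implementation: the function name `generalized_kl_term`, the 0-based token indices, and the linear scheduler \(\kappa_t = t\) (so \(\dot{\kappa}_t = 1\)) are assumptions made only for this example.

```python
import math


def generalized_kl_term(p_1t, x1_i, xt_i, kappa, d_kappa):
    """Per-coordinate loss l_i for one position (hypothetical helper).

    p_1t    : list of K posterior probabilities p_{1|t}(. | x_t), all > 0
    x1_i    : target token index (0-based here; the text writes tokens as 1..K)
    xt_i    : token index of the noisy sample at time t
    kappa   : scheduler value kappa_t, must be < 1
    d_kappa : time derivative of kappa_t
    """
    # Kronecker delta: 1 when the noisy token already equals the target.
    delta = 1.0 if xt_i == x1_i else 0.0
    bracket = p_1t[xt_i] - delta + (1.0 - delta) * math.log(p_1t[x1_i])
    return -d_kappa / (1.0 - kappa) * bracket


# Assumed linear scheduler: kappa_t = t, hence d_kappa = 1.
t = 0.5
p = [0.7, 0.2, 0.1]

# x_t already equals x_1: only the p - delta term contributes.
loss_match = generalized_kl_term(p, x1_i=0, xt_i=0, kappa=t, d_kappa=1.0)

# x_t differs from x_1: the log-probability of the target also contributes.
loss_mismatch = generalized_kl_term(p, x1_i=0, xt_i=1, kappa=t, d_kappa=1.0)
```

Note that the coefficient \(\dot{\kappa}_t / (1-\kappa_t)\) grows without bound as \(\kappa_t \to 1\), so in this sketch `kappa` must stay strictly below 1.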

Parameters:
  • path (MixtureDiscreteProbPath) – Probability path (x-prediction training).

  • reduction (str, optional) – Reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction is applied; 'mean': the output is averaged over sequence elements; 'sum': the output is summed over sequence elements. Defaults to 'mean'.
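The three reduction modes can be sketched as follows. This is a minimal illustration on nested lists, not the library's code: the helper name `reduce_loss` is hypothetical, and it assumes that 'mean' and 'sum' run over every entry of the (batch, d) loss, which the description above does not state explicitly.

```python
def reduce_loss(loss, reduction="mean"):
    """Reduce a (batch, d) nested list of per-coordinate losses (hypothetical helper)."""
    if reduction == "none":
        # Return the per-coordinate losses unchanged.
        return loss
    # Assumption: 'mean' and 'sum' run over all batch and sequence entries.
    flat = [value for row in loss for value in row]
    if reduction == "mean":
        return sum(flat) / len(flat)
    if reduction == "sum":
        return sum(flat)
    raise ValueError(f"{reduction!r} is not a valid value for reduction")


per_coordinate = [[0.5, 1.5], [1.0, 3.0]]  # shape (batch=2, d=2)
mean_loss = reduce_loss(per_coordinate, "mean")
sum_loss = reduce_loss(per_coordinate, "sum")
unreduced = reduce_loss(per_coordinate, "none")
```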

forward(logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) → Tensor

Evaluates the generalized KL loss.

Parameters:
  • logits (Tensor) – posterior model output (i.e., softmax(logits) \(=p_{1|t}(x|x_t)\)), shape (batch, d, K).

  • x_1 (Tensor) – target data point \(x_1 \sim q\), shape (batch, d).

  • x_t (Tensor) – conditional sample \(x_t \sim p_t(\cdot|x_1)\), shape (batch, d).

  • t (Tensor) – times in \([0,1]\), shape (batch).
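To make the relation between logits and the posterior concrete, the sketch below converts logits of shape (batch, d, K) into log-probabilities with a numerically stable log-softmax over the last axis. It uses nested lists in place of tensors, and the helper name `log_softmax` is chosen for this example only; the actual class may compute the same quantity differently.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of K logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


# Toy logits with shape (batch=1, d=2, K=3).
logits = [[[2.0, 0.0, -1.0], [0.0, 0.0, 0.0]]]

# log p_{1|t}(x | x_t) for every sequence position, same shape as logits.
log_p = [[log_softmax(row) for row in seq] for seq in logits]

# Exponentiating recovers the probabilities softmax(logits).
probs = [[[math.exp(v) for v in row] for row in seq] for seq in log_p]
```

Each innermost list of `probs` sums to 1, and equal logits (the second position) give a uniform distribution over the K tokens.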

Raises:

ValueError – if reduction is not one of 'none' | 'mean' | 'sum'.

Returns:

Generalized KL loss.

Return type:

Tensor