MixturePathGeneralizedKL
- class flow_matching.loss.MixturePathGeneralizedKL(path: MixtureDiscreteProbPath, reduction: str = 'mean')
A generalized KL loss for discrete flow matching. This class measures the generalized KL of a discrete flow model \(p_{1|t}\) w.r.t. the probability path given by the path argument. Note: this class assumes that the model is trained on the same path.
For a model trained on a space \(\mathcal{S} = \mathcal{T}^d\), \(\mathcal{T} = [K] = \{1,2,\ldots,K\}\), the loss for coordinate \(i\) is given by
\[\ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggl[ p_{1|t}(x_t^i|x_t) - \delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i)) \log p_{1|t}(x_1^i|x_t) \biggr],\]
where \(\kappa_t\) is the scheduler associated with the path argument.
- Parameters:
path (MixtureDiscreteProbPath) – Probability path (x-prediction training).
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction is applied to the output, 'mean': the output is reduced by mean over sequence elements, 'sum': the output is reduced by sum over sequence elements. Defaults to 'mean'.
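To make the per-coordinate loss formula above easier to read, it can be split by whether the noisy token already equals the target token. This is only an algebraic rewrite of the stated formula, obtained by substituting \(\delta_{x^i_1}(x_t^i) = 1\) and \(\delta_{x^i_1}(x_t^i) = 0\) in turn; no additional assumptions are made.
\[
\ell_i(x_1, x_t, t) = \frac{\dot{\kappa}_t}{1-\kappa_t}
\begin{cases}
1 - p_{1|t}(x_1^i|x_t), & x_t^i = x_1^i, \\
-p_{1|t}(x_t^i|x_t) - \log p_{1|t}(x_1^i|x_t), & x_t^i \neq x_1^i.
\end{cases}
\]
In the first case the loss vanishes when the model puts all its mass on the current (correct) token; in the second case it combines a cross-entropy term on the target token with a penalty on the mass left on the current token.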
- forward(logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) → Tensor
Evaluates the generalized KL loss.
- Parameters:
logits (Tensor) – posterior model output (i.e., softmax(logits) \(=p_{1|t}(x|x_t)\)), shape (batch, d, K).
x_1 (Tensor) – target data point \(x_1 \sim q\), shape (batch, d).
x_t (Tensor) – conditional sample \(x_t \sim p_t(\cdot|x_1)\), shape (batch, d).
t (Tensor) – times in \([0,1]\), shape (batch).
- Raises:
ValueError – reduction value must be one of 'none' | 'mean' | 'sum'.
- Returns:
Generalized KL loss.
- Return type:
Tensor
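The computation described by forward can be sketched in plain NumPy as a reference for the shapes and the formula. This is a minimal illustration, not the library's implementation: the function name generalized_kl_sketch is hypothetical, tokens are indexed from 0 to K-1 rather than 1 to K, and the scheduler is assumed to be linear, \(\kappa_t = t\), so that \(\dot{\kappa}_t / (1-\kappa_t) = 1/(1-t)\). The real class takes the scheduler from the path argument instead.

```python
import numpy as np


def generalized_kl_sketch(logits, x_1, x_t, t, reduction="mean"):
    """NumPy sketch of the documented loss, ASSUMING kappa_t = t.

    logits: (batch, d, K) floats; x_1, x_t: (batch, d) ints in [0, K);
    t: (batch,) floats in [0, 1).
    """
    # Numerically stable log-softmax over the vocabulary axis K.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_p = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # log p_{1|t}(x_1^i | x_t) and p_{1|t}(x_t^i | x_t), each of shape (batch, d).
    log_p_x1 = np.take_along_axis(log_p, x_1[..., None], axis=-1)[..., 0]
    p_xt = np.exp(np.take_along_axis(log_p, x_t[..., None], axis=-1)[..., 0])

    # delta_{x_1^i}(x_t^i): 1 where the noisy token equals the target token.
    delta = (x_t == x_1).astype(log_p.dtype)

    # Assumed linear scheduler: kappa_t = t, so the rate is 1 / (1 - t).
    rate = (1.0 / (1.0 - t))[:, None]

    loss = -rate * (p_xt - delta + (1.0 - delta) * log_p_x1)  # (batch, d)

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"{reduction} is not a valid value for reduction")


# Tiny example: batch=1, d=2, K=4, uniform model prediction.
logits = np.zeros((1, 2, 4))
x_1 = np.array([[0, 1]])
x_t = np.array([[0, 2]])   # first coordinate already matches the target
t = np.array([0.5])
per_token = generalized_kl_sketch(logits, x_1, x_t, t, reduction="none")
```

With reduction="none" the result keeps the (batch, d) shape, which is convenient for inspecting which coordinates contribute most to the loss.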