MixtureDiscreteEulerSolver

class flow_matching.solver.MixtureDiscreteEulerSolver(model: ModelWrapper, path: MixtureDiscreteProbPath, vocabulary_size: int, source_distribution_p: Tensor | None = None)

Solver that simulates the CTMC process \((X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}\) defined by \(p_t\), the marginal probability path given by the path argument. Given \(X_t \sim p_t\), one solver step from \(t\) to \(t+h\) for the \(i\)-th coordinate is:

\[\begin{align*} & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\ & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\ & Z^i_{\text{change}} \sim U[0,1]\\ & X_{t+h}^i \sim \begin{cases} \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\ \delta_{X_t^i}(\cdot) \text{ else } \end{cases} \end{align*}\]

where \(p_{1|t}(\cdot|X_t)\) is the output of the model, and the conditional probability velocity of the mixture probability path is:

\[u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],\]

where

\[\hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right], \]

and

\[\check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].\]

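The source distribution \(p(x^i)\) is given by source_distribution_p, and the coefficient \(c_{\text{div\_free}}\) is the div_free argument of sample.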
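For \(x^i \ne y^i\), the definitions above combine to

\[u_t^i(x^i, y^i|x_1^i) = (1+c_{\text{div\_free}})\frac{\dot{\kappa}_t}{1-\kappa_t}\delta_{x_1^i}(x^i) + c_{\text{div\_free}}\frac{\dot{\kappa}_t}{\kappa_t}p(x^i),\]

so the off-diagonal rates are non-negative whenever \(\dot{\kappa}_t \ge 0\) and \(c_{\text{div\_free}} \ge 0\). Below is a minimal PyTorch sketch of one step from \(t\) to \(t+h\). It is not the library's implementation. conditional_velocity, euler_jump_step and mixture_euler_step are hypothetical helpers. kappa and d_kappa stand for \(\kappa_t\) and \(\dot{\kappa}_t\) and are assumed to satisfy \(0<\kappa_t<1\). torch.multinomial plays the role of the categorical sampler, and p_1t is a precomputed posterior rather than a call to the model.

```python
from typing import Callable, Union

import torch
import torch.nn.functional as F


def conditional_velocity(
    x_t: torch.Tensor,  # current tokens y^i = X_t^i, LongTensor [...]
    x_1: torch.Tensor,  # sampled X_1^i, LongTensor [...]
    kappa: float,       # kappa_t, assumed to lie in (0, 1)
    d_kappa: float,     # time derivative of kappa_t
    p: torch.Tensor,    # source distribution p, shape [V]
    t: float,
    c: Union[float, Callable[[float], float]] = 0.0,
) -> torch.Tensor:
    """u_t^i(x, X_t^i | X_1^i) for every x, shape [..., V] (hypothetical helper)."""
    vocab = p.shape[0]
    delta_1 = F.one_hot(x_1, num_classes=vocab).to(p.dtype)  # delta_{x_1^i}(x)
    delta_y = F.one_hot(x_t, num_classes=vocab).to(p.dtype)  # delta_{y^i}(x)
    u_hat = d_kappa / (1.0 - kappa) * (delta_1 - delta_y)
    u_check = d_kappa / kappa * (delta_y - p)
    # Like the div_free argument of sample, c may be a float or a function of t.
    c_t = c(t) if callable(c) else c
    return u_hat + c_t * (u_hat - u_check)


def euler_jump_step(x_t: torch.Tensor, u: torch.Tensor, h: float) -> torch.Tensor:
    """Jump part of the update, vectorized over coordinates (hypothetical helper)."""
    # Ignore the diagonal entry so lambda^i sums only over x^i != X_t^i.
    stay = F.one_hot(x_t, num_classes=u.shape[-1]).bool()
    u = torch.where(stay, torch.zeros_like(u), u)
    lam = u.sum(dim=-1)
    # Z^i ~ U[0, 1); the strict "<" means coordinates with lambda^i = 0 never jump.
    z = torch.rand(x_t.shape, device=x_t.device)
    jump = z < 1.0 - torch.exp(-h * lam)
    x_next = x_t.clone()
    if jump.any():
        # New token drawn with probability u / lambda over x^i != X_t^i.
        x_next[jump] = torch.multinomial(u[jump], num_samples=1).squeeze(-1)
    return x_next


def mixture_euler_step(x_t, p_1t, kappa, d_kappa, p, t, h, c=0.0):
    """One full step from t to t + h; p_1t holds p_{1|t}(.|X_t), shape [..., V]."""
    # X_1^i ~ p_{1|t}^i(.|X_t): one categorical draw per coordinate.
    x_1 = torch.multinomial(p_1t.reshape(-1, p_1t.shape[-1]), 1).view(x_t.shape)
    u = conditional_velocity(x_t, x_1, kappa, d_kappa, p, t, c)
    return euler_jump_step(x_t, u, h)
```

A quick sanity check on the formula: every full rate vector returned by conditional_velocity sums to zero over \(x^i\), i.e. its diagonal entry equals \(-\lambda^i\).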

Parameters:
  • model (ModelWrapper) – Model trained with x-prediction that outputs posterior probabilities (in the range \([0,1]\)); its output must have shape […, vocabulary_size].

  • path (MixtureDiscreteProbPath) – Probability path used for x-prediction training.

  • vocabulary_size (int) – size of the discrete vocabulary.

  • source_distribution_p (Optional[Tensor], optional) – Source distribution, must be of shape [vocabulary_size]. Required only when the divergence-free term of the probability velocity is non-zero. Defaults to None.
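source_distribution_p is the probability vector \(p\) used in the divergence-free term. The sketch below builds two typical choices: a uniform source, and a source with all mass on one token (mask_id stands for a hypothetical mask-token index). It also checks the [vocabulary_size] shape required above. make_source_distribution is an illustrative helper, not part of flow_matching.

```python
from typing import Optional

import torch


def make_source_distribution(vocabulary_size: int, mask_id: Optional[int] = None) -> torch.Tensor:
    """Source distribution of shape [vocabulary_size] (hypothetical helper).

    mask_id=None gives a uniform source; otherwise all mass sits on token mask_id.
    """
    if mask_id is None:
        p = torch.full((vocabulary_size,), 1.0 / vocabulary_size)
    else:
        p = torch.zeros(vocabulary_size)
        p[mask_id] = 1.0
    # Shape required by the solver; the entries should form a distribution.
    assert p.shape == (vocabulary_size,)
    assert torch.isclose(p.sum(), torch.tensor(1.0))
    return p
```

The returned tensor would be passed as source_distribution_p when sample is called with a non-zero div_free.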

sample(x_init: Tensor, step_size: float | None, div_free: float | Callable[[float], float] = 0.0, dtype_categorical: dtype = torch.float32, time_grid: Tensor = tensor([0., 1.]), return_intermediates: bool = False, verbose: bool = False, **model_extras) -> Tensor

Sample a sequence of discrete values from the given model.

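import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper

vocabulary_size = 1000

class DummyModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        # Placeholder posterior p_{1|t}(.|x): uniform over the vocabulary,
        # shape [..., vocabulary_size]. Replace with a trained x-prediction model.
        return torch.full((*x.shape, vocabulary_size), 1.0 / vocabulary_size)

model = DummyModel()
# The solver also needs the probability path and the vocabulary size.
path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
solver = MixtureDiscreteEulerSolver(model=model, path=path, vocabulary_size=vocabulary_size)

x_init = torch.LongTensor([122, 725])  # token ids must lie in [0, vocabulary_size)
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

Parameters: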
  • x_init (Tensor) – The initial state.

  • step_size (Optional[float]) – If a float, the time discretization is uniform with the given step size. If None, the time discretization is given by time_grid.

  • div_free (Union[float, Callable[[float], float]]) – The coefficient \(c_{\text{div\_free}}\) of the divergence-free term in the probability velocity. Can be either a float or a time-dependent function. A non-zero value requires source_distribution_p. Defaults to 0.0.

  • dtype_categorical (torch.dtype) – Precision to use for categorical sampler. Defaults to torch.float32.

  • time_grid (Tensor) – The CTMC is solved on the interval [time_grid[0], time_grid[-1]]. If step_size is None, the time discretization is set by time_grid. Defaults to torch.tensor([0.0, 1.0]).

  • return_intermediates (bool) – If True, return the intermediate states at the times given by time_grid. Defaults to False.

  • verbose (bool) – Whether to print progress bars. Defaults to False.

  • **model_extras – Additional keyword arguments passed to the model.
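The documented interplay of step_size and time_grid can be sketched as follows. make_time_discretization is a hypothetical helper. Two details are assumptions of this sketch rather than documented behavior: the number of steps is rounded up, and the final step is shortened so that the grid ends exactly at time_grid[-1].

```python
import math
from typing import Optional

import torch


def make_time_discretization(time_grid: torch.Tensor, step_size: Optional[float]) -> torch.Tensor:
    """Time points visited by the solver (hypothetical helper).

    step_size=None: use time_grid as is.
    float step_size: uniform steps from time_grid[0], with a final point at time_grid[-1].
    """
    if step_size is None:
        return time_grid
    t_init = time_grid[0].item()
    t_final = time_grid[-1].item()
    # Round the number of steps up so the whole interval is covered.
    n_steps = math.ceil((t_final - t_init) / step_size)
    points = [t_init + step_size * i for i in range(n_steps)] + [t_final]
    return torch.tensor(points)
```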

Returns:

The sampled sequence of discrete values.

Return type:

Tensor

Raises:

ImportError – If verbose is True and tqdm is not installed.