MixtureDiscreteEulerSolver
- class flow_matching.solver.MixtureDiscreteEulerSolver(model: ModelWrapper, path: MixtureDiscreteProbPath, vocabulary_size: int, source_distribution_p: Tensor | None = None)
Solver that simulates the CTMC process \((X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}\) defined by \(p_t\), the marginal probability path of path. Given \(X_t \sim p_t\), one solver step from \(t\) to \(t+h\) for the i-th coordinate is:
\[\begin{align*}
& X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
& \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
& Z^i_{\text{change}} \sim U[0,1]\\
& X_{t+h}^i \sim \begin{cases} \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\ \delta_{X_t^i}(\cdot) \text{ else } \end{cases}
\end{align*}\]
Here \(p_{1|t}(\cdot|X_t)\) is the output of model, and the conditional probability velocity of the mixture probability path is:
\[u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],\]
where
\[\hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right], \]
and
\[\check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].\]
The source distribution \(p(x^i)\) is given by source_distribution_p.
- Parameters:
model (ModelWrapper) – Model trained with x-prediction that outputs posterior probabilities (in the range \([0,1]\)). The output must have shape […, vocabulary_size].
path (MixtureDiscreteProbPath) – Probability path used for x-prediction training.
vocabulary_size (int) – Size of the discrete vocabulary.
source_distribution_p (Optional[Tensor], optional) – Source distribution, must be of shape [vocabulary_size]. Required only when the divergence-free term of the probability velocity is non-zero. Defaults to None.
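The solver step and velocity formulas above can be illustrated with a minimal pure-Python sketch for a single coordinate. This is not the library's implementation. The function names `velocity` and `euler_step` are hypothetical, and `kappa` and `d_kappa` are assumed to be the scheduler value \(\kappa_t\) and its time derivative \(\dot{\kappa}_t\), passed in as plain numbers.

```python
import math
import random


def velocity(x, y, x1, kappa, d_kappa, p=None, c_div_free=0.0):
    """Conditional probability velocity u_t(x, y | x1) for one coordinate.

    Hypothetical helper: p is the source distribution as a list of
    probabilities, needed only when c_div_free is non-zero.
    """
    def delta(a, b):
        return 1.0 if a == b else 0.0

    u_hat = d_kappa / (1.0 - kappa) * (delta(x1, x) - delta(y, x))
    if c_div_free == 0.0:
        return u_hat
    # The divergence-free part divides by kappa, so kappa must be > 0 here.
    u_check = d_kappa / kappa * (delta(y, x) - p[x])
    return u_hat + c_div_free * (u_hat - u_check)


def euler_step(x_t, posterior, kappa, d_kappa, h, p=None, c_div_free=0.0, rng=random):
    """One step from t to t + h for a single coordinate (illustrative only).

    posterior is the model output p_{1|t}(. | X_t) as a list of probabilities.
    """
    vocab = range(len(posterior))
    # X_1 ~ p_{1|t}(. | X_t)
    x_1 = rng.choices(vocab, weights=posterior)[0]
    # Jump rates to every state other than the current one.
    rates = [
        0.0 if x == x_t else velocity(x, x_t, x_1, kappa, d_kappa, p, c_div_free)
        for x in vocab
    ]
    lam = sum(rates)
    # Z_change ~ U[0, 1]; jump with probability 1 - exp(-h * lambda).
    z = rng.random()
    if lam > 0.0 and z <= 1.0 - math.exp(-h * lam):
        return rng.choices(vocab, weights=rates)[0]
    return x_t
```

With \(c_{\text{div\_free}} = 0\) the total rate reduces to \(\dot{\kappa}_t/(1-\kappa_t)\) when \(X_1^i \ne X_t^i\) and to zero otherwise, so a jump always lands on \(X_1^i\).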
- sample(x_init: Tensor, step_size: float | None, div_free: float | Callable[[float], float] = 0.0, dtype_categorical: dtype = torch.float32, time_grid: Tensor = tensor([0., 1.]), return_intermediates: bool = False, verbose: bool = False, **model_extras) -> Tensor
Sample a sequence of discrete values from the given model.
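    import torch
    from flow_matching.path import MixtureDiscreteProbPath
    from flow_matching.path.scheduler import PolynomialConvexScheduler
    from flow_matching.solver import MixtureDiscreteEulerSolver
    from flow_matching.utils import ModelWrapper

    class DummyModel(ModelWrapper):
        def __init__(self):
            super().__init__(None)

        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
            # Should return posterior probabilities of shape [..., vocabulary_size].
            return ...

    model = DummyModel()
    # path and vocabulary_size are required by the constructor. The scheduler and
    # the value 1000 are illustrative choices (it must exceed the ids in x_init).
    path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
    solver = MixtureDiscreteEulerSolver(model=model, path=path, vocabulary_size=1000)

    x_init = torch.LongTensor([122, 725])
    step_size = 0.001
    time_grid = torch.tensor([0.0, 1.0])

    result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)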
- Parameters:
x_init (Tensor) – The initial state.
step_size (Optional[float]) – If a float, the time discretization is uniform with the given step size. If None, the time discretization is taken from time_grid.
div_free (Union[float, Callable[[float], float]]) – The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time-dependent function. Defaults to 0.0.
dtype_categorical (torch.dtype) – Precision to use for categorical sampler. Defaults to torch.float32.
time_grid (Tensor) – The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
return_intermediates (bool) – If True then return intermediate time steps according to time_grid. Defaults to False.
verbose (bool) – Whether to print progress bars. Defaults to False.
**model_extras – Additional input for the model.
- Returns:
The sampled sequence of discrete values.
- Return type:
Tensor
- Raises:
ImportError – If verbose is True and tqdm is not installed.
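The interaction between step_size, time_grid and div_free described in the parameter list can be sketched in plain Python. The helpers `build_discretization` and `div_free_at` are hypothetical names used for illustration. The rounding rule (ceil on the number of steps, with a shortened final step) is an assumption about one reasonable way to build a uniform grid, not a statement of what the library does internally.

```python
import math


def build_discretization(time_grid, step_size=None):
    """Return the list of times a solver would step through (sketch)."""
    if step_size is None:
        # No step size given: the supplied grid is the discretization.
        return list(time_grid)
    t_init, t_final = time_grid[0], time_grid[-1]
    # Uniform steps of size step_size; the last step is shortened so the
    # grid ends exactly at t_final (assumed behaviour).
    n_steps = math.ceil((t_final - t_init) / step_size)
    return [t_init + i * step_size for i in range(n_steps)] + [t_final]


def div_free_at(div_free, t):
    """Evaluate the divergence-free coefficient at time t.

    div_free may be a constant or a function of time, mirroring the
    Union[float, Callable[[float], float]] annotation.
    """
    return div_free(t) if callable(div_free) else float(div_free)
```

For example, `build_discretization([0.0, 1.0], step_size=0.25)` gives five time points from 0.0 to 1.0, while `build_discretization([0.0, 0.5, 1.0])` returns the grid unchanged.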