ODESolver

class flow_matching.solver.ODESolver(velocity_model: ModelWrapper | Callable)

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

The velocity field model defines the ODE, which is integrated over a given time grid with a numerical ODE solver from torchdiffeq.

Parameters:

velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)
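In symbols, the solver integrates the ODE defined by the velocity model, starting from x_init at the first time in the grid. With the default "euler" method, each step applies the explicit update on the second line, where \(h\) is the step size (negative when integrating backwards in time). This is the textbook formulation, given for orientation rather than as a description of the library's code.

\[
\frac{dx_t}{dt} = u_t(x_t), \qquad x_{t_0} = x_{\text{init}}
\]

\[
x_{t+h} = x_t + h\,u_t(x_t)
\]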

sample(x_init: Tensor, step_size: float | None, method: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Tensor = tensor([0., 1.]), return_intermediates: bool = False, enable_grad: bool = False, **model_extras) → Tensor | Sequence[Tensor]

Solve the ODE with the velocity field.

Example:

import torch
from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

class DummyModel(ModelWrapper):
    """Toy velocity field u_t(x) = 3 t^2, which does not depend on x."""

    def __init__(self):
        # No underlying network is needed for this analytic field.
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        # Same shape as x, scaled by 3 t^2.
        return torch.ones_like(x) * 3.0 * t**2

velocity_model = DummyModel()
solver = ODESolver(velocity_model=velocity_model)
x_init = torch.tensor([0.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

# Integrate dx/dt = 3 t^2 from t=0 to t=1 with the default fixed-step "euler" method.
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

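As a sanity check on this example: the velocity \(u_t(x) = 3t^2\) does not depend on \(x\), so the exact solution is \(x_1 = x_0 + \int_0^1 3t^2\,dt = x_0 + 1\), and result should be close to tensor([1., 1.]). Assuming the Euler solver takes exactly \(N = 1/h = 1000\) uniform steps, the discrete sum works out as below; the value actually returned may differ in the last digits because of float32 round-off and torchdiffeq's grid construction.

\[
\sum_{k=0}^{N-1} h \cdot 3(kh)^2 = \frac{3}{N^3}\cdot\frac{(N-1)N(2N-1)}{6} = 1 - \frac{3}{2N} + \frac{1}{2N^2} = 0.9985005 \quad (N = 1000)
\]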
Parameters:
  • x_init (Tensor) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].

  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str) – A method supported by torchdiffeq. Defaults to “euler”. Other commonly used solvers are “dopri5”, “midpoint” and “heun3”. For a complete list, see torchdiffeq.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Tensor) – The ODE is solved on the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by time_grid. A descending time_grid solves the ODE in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).

  • return_intermediates (bool, optional) – If True, return the solution at every time in time_grid instead of only the final one. Defaults to False.

  • enable_grad (bool, optional) – Whether to compute gradients during sampling. Defaults to False.

  • **model_extras – Additional input for the model.
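The roles of step_size, time_grid and return_intermediates can be illustrated with a standalone pure-Python sketch. It does not use torch, torchdiffeq or this library: euler_sample and cubic_velocity are hypothetical names, the state is a single float, and the way each interval is split into equal steps is an assumption (torchdiffeq's own fixed-grid construction may differ in detail).

```python
import math
from typing import Callable, List, Optional, Sequence, Union

# Velocity field u_t(x) as a plain function of (x, t); a scalar state keeps the sketch short.
Velocity = Callable[[float, float], float]


def euler_sample(
    velocity: Velocity,
    x_init: float,
    step_size: Optional[float],
    time_grid: Sequence[float] = (0.0, 1.0),
    return_intermediates: bool = False,
) -> Union[float, List[float]]:
    """Hypothetical fixed-step Euler integrator mirroring the documented parameters.

    - step_size=None: one Euler step per consecutive pair of time_grid points.
    - otherwise: each interval is split into ceil(|interval| / step_size) equal steps.
    - a descending time_grid gives negative steps, i.e. integration backwards in time.
    """
    x = x_init
    states = [x]  # solution at every point of time_grid
    for t_start, t_end in zip(time_grid[:-1], time_grid[1:]):
        span = t_end - t_start
        if step_size is None:
            n_steps = 1
        else:
            # The small tolerance guards against float round-off in span / step_size.
            n_steps = max(1, math.ceil(abs(span) / step_size - 1e-9))
        h = span / n_steps  # signed step size
        for k in range(n_steps):
            t = t_start + k * h
            x = x + h * velocity(x, t)
        states.append(x)
    return states if return_intermediates else states[-1]


def cubic_velocity(x: float, t: float) -> float:
    """u_t(x) = 3 t^2, whose exact flow is x_t = x_0 + t^3."""
    return 3.0 * t * t


forward = euler_sample(cubic_velocity, 0.0, step_size=0.001)
backward = euler_sample(cubic_velocity, 1.0, step_size=0.001, time_grid=(1.0, 0.0))
path = euler_sample(cubic_velocity, 0.0, 0.001, time_grid=(0.0, 0.5, 1.0), return_intermediates=True)
```

Here forward approximates the exact value 1.0, backward integrates the same field from t=1 to t=0 and lands near 0.0, and path holds the solution at t = 0, 0.5 and 1.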

Returns:

The solution at the final time in time_grid when return_intermediates=False; otherwise the solution at every time in time_grid.

Return type:

Union[Tensor, Sequence[Tensor]]

compute_likelihood(x_1: Tensor, log_p0: Callable[[Tensor], Tensor], step_size: float | None, method: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Tensor = tensor([1., 0.]), return_intermediates: bool = False, exact_divergence: bool = False, enable_grad: bool = False, **model_extras) → Tuple[Tensor, Tensor] | Tuple[Sequence[Tensor], Tensor]

Solve for the log likelihood of a target sample given at \(t=1\).

Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x. The function assumes log_p0 is the log probability of the source distribution at \(t=0\).
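The quantity being computed follows the instantaneous change-of-variables formula for continuous normalizing flows, stated here as standard background rather than as a description of the library's internals. Here \(x_0\) is obtained by integrating \(dx_t/dt = u_t(x_t)\) backwards from \(x_1\) at \(t=1\) to \(t=0\):

\[
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla\cdot u_t(x_t)\,dt
\]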

Parameters:
  • x_1 (Tensor) – target sample (e.g., samples \(X_1 \sim p_1\)).

  • log_p0 (Callable[[Tensor], Tensor]) – Log probability function of the source distribution.

  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str) – A method supported by torchdiffeq. Defaults to “euler”. Other commonly used solvers are “dopri5”, “midpoint” and “heun3”. For a complete list, see torchdiffeq.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Tensor) – If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).

  • return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.

  • exact_divergence (bool) – If True, compute the exact divergence of the velocity field; otherwise use the Hutchinson trace estimator. Defaults to False.

  • enable_grad (bool, optional) – Whether to compute gradients during the likelihood computation. Defaults to False.

  • **model_extras – Additional input for the model.
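The difference between the two divergence options can be sketched in pure Python. This is an illustration under stated assumptions, not the library's implementation: it estimates Jacobian-vector products with central finite differences instead of automatic differentiation, draws Rademacher noise \(z\) (the distribution the library uses is not stated here), and the velocity is a hypothetical function of \(x\) only, at a fixed time. The Hutchinson estimator relies on \(\nabla\cdot u = \operatorname{tr}(\partial_x u) = \mathbb{E}_z[z^\top \partial_x u\, z]\) whenever \(\mathbb{E}[zz^\top] = I\).

```python
import random
from typing import Callable, List, Optional

Vector = List[float]
# Hypothetical velocity at a fixed time t: maps a state vector x to u_t(x).
Field = Callable[[Vector], Vector]


def jvp_fd(u: Field, x: Vector, z: Vector, eps: float = 1e-5) -> Vector:
    """Jacobian-vector product (du/dx) z via central finite differences (stands in for autograd)."""
    plus = u([xi + eps * zi for xi, zi in zip(x, z)])
    minus = u([xi - eps * zi for xi, zi in zip(x, z)])
    return [(p - m) / (2.0 * eps) for p, m in zip(plus, minus)]


def exact_divergence(u: Field, x: Vector, eps: float = 1e-5) -> float:
    """Trace of the Jacobian: sum over basis vectors e_i of the i-th entry of (du/dx) e_i."""
    d = len(x)
    total = 0.0
    for i in range(d):
        e_i = [1.0 if j == i else 0.0 for j in range(d)]
        total += jvp_fd(u, x, e_i, eps)[i]
    return total


def hutchinson_divergence(
    u: Field, x: Vector, num_samples: int = 1000, eps: float = 1e-5,
    rng: Optional[random.Random] = None,
) -> float:
    """Unbiased Monte Carlo estimate of E_z[z^T (du/dx) z] with Rademacher z."""
    if rng is None:
        rng = random.Random(0)  # fixed seed so the sketch is reproducible
    total = 0.0
    for _ in range(num_samples):
        z = [rng.choice((-1.0, 1.0)) for _ in x]
        total += sum(zi * jzi for zi, jzi in zip(z, jvp_fd(u, x, z, eps)))
    return total / num_samples


A = [[1.0, 2.0, 0.0], [0.5, -3.0, 1.0], [0.0, 4.0, 0.5]]  # trace(A) = -1.5


def linear_velocity(x: Vector) -> Vector:
    """u(x) = A x, so the Jacobian is A and the divergence is trace(A)."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


def diagonal_velocity(x: Vector) -> Vector:
    """Diagonal Jacobian diag(2, -1, 0.25), divergence 1.25."""
    return [2.0 * x[0], -1.0 * x[1], 0.25 * x[2]]


x = [0.3, -1.2, 0.7]
exact = exact_divergence(linear_velocity, x)
estimate = hutchinson_divergence(linear_velocity, x, num_samples=4000)
diag_estimate = hutchinson_divergence(diagonal_velocity, x, num_samples=10)
```

For the dense matrix the estimate fluctuates around the exact trace of -1.5; for a diagonal Jacobian every Rademacher sample is exact, because \(z_i^2 = 1\). In this sketch the exact divergence costs \(d\) Jacobian-vector products in \(d\) dimensions, whereas each Hutchinson sample costs one.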

Returns:

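A tuple (samples, log_likelihood): samples is the solution at \(t=0\) (the source sample corresponding to x_1), or the solution at every time in time_grid when return_intermediates=True; log_likelihood holds the log-likelihood values of x_1.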

Return type:

Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]
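An end-to-end toy version of this computation can be checked against a closed form. For a one-dimensional linear field \(u_t(x) = a x\) with a standard normal source, the flow is \(x_1 = e^{a} x_0\), so \(x_1 \sim \mathcal{N}(0, e^{2a})\). The sketch below is pure Python and hypothetical (euler_log_likelihood, log_standard_normal, linear_velocity and the finite-difference divergence are not library functions); it integrates from \(t=1\) to \(t=0\), as the time_grid requirement above demands, and applies \(\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla\cdot u_t(x_t)\,dt\).

```python
import math
from typing import Callable, Tuple

# Velocity field u_t(x) for a 1-D state, called as velocity(x, t).
Velocity = Callable[[float, float], float]

A_COEF = 0.5  # hypothetical rate a in u_t(x) = a * x


def linear_velocity(x: float, t: float) -> float:
    """u_t(x) = a x; its flow is x_t = exp(a t) x_0 and its divergence is a."""
    return A_COEF * x


def log_standard_normal(x: float) -> float:
    """log p_0(x) for a standard normal source (an assumption of this sketch)."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * x * x


def divergence_1d(u: Velocity, x: float, t: float, eps: float = 1e-5) -> float:
    """In one dimension the divergence is du/dx, estimated by central differences."""
    return (u(x + eps, t) - u(x - eps, t)) / (2.0 * eps)


def euler_log_likelihood(
    u: Velocity,
    x_1: float,
    log_p0: Callable[[float], float],
    num_steps: int = 1000,
) -> Tuple[float, float]:
    """Integrate from t=1 back to t=0 with Euler steps.

    Returns (x_0, log p_1(x_1)) via log p_1(x_1) = log p_0(x_0) - int_0^1 div u_t(x_t) dt.
    """
    h = -1.0 / num_steps  # negative step: time runs from 1 down to 0
    x, div_integral = x_1, 0.0
    for k in range(num_steps):
        t = 1.0 + k * h
        # -h > 0, so this accumulates the integral of the divergence over [0, 1].
        div_integral += -h * divergence_1d(u, x, t)
        x = x + h * u(x, t)
    return x, log_p0(x) - div_integral


x_1 = 1.3
x_0, log_p1 = euler_log_likelihood(linear_velocity, x_1, log_standard_normal)
# Closed form: log N(x_1; 0, exp(2a)) = -0.5 log(2 pi) - a - x_1^2 / (2 exp(2a)).
exact_log_p1 = -0.5 * math.log(2.0 * math.pi) - A_COEF - x_1**2 / (2.0 * math.exp(2.0 * A_COEF))
```

With 1000 steps the Euler estimate of log_p1 agrees with the closed form to well under 1e-3, and x_0 recovers \(e^{-a} x_1\) to similar accuracy.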