ODESolver
- class flow_matching.solver.ODESolver(velocity_model: ModelWrapper | Callable)
A class to solve ordinary differential equations (ODEs) using a specified velocity model.
The velocity field model defines the ODE, which is integrated over a given time grid with a numerical ODE solver.
- Parameters:
velocity_model (Union[ModelWrapper, Callable]) – A velocity field model that receives \((x, t)\) and returns \(u_t(x)\).
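Since velocity_model may be any callable rather than a ModelWrapper, here is a minimal sketch of a plain function with the same (x, t, **extras) interface as the ModelWrapper.forward in the example below. It assumes the solver calls the model with keyword arguments x and t and forwards **model_extras unchanged; that is an assumption based on that forward signature. Lists of floats stand in for tensors, and both the function and the scale keyword are hypothetical.

```python
from typing import Any, List


def velocity(x: List[float], t: float, **extras: Any) -> List[float]:
    """Hypothetical velocity field u_t(x) = scale * 3 t^2, applied to every entry of x.

    `scale` is an illustrative extra keyword standing in for **model_extras.
    """
    scale = extras.get("scale", 1.0)
    return [scale * 3.0 * t**2 for _ in x]


# Called with keyword arguments, the way a ModelWrapper.forward(x, t, **extras) would be.
u = velocity(x=[0.0, 0.0], t=0.5)                    # [0.75, 0.75]
u_scaled = velocity(x=[0.0, 0.0], t=0.5, scale=2.0)  # [1.5, 1.5]
```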
- sample(x_init: Tensor, step_size: float | None, method: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Tensor = tensor([0., 1.]), return_intermediates: bool = False, enable_grad: bool = False, **model_extras) → Tensor | Sequence[Tensor]
Solve the ODE with the velocity field.
Example:
import torch

from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

class DummyModel(ModelWrapper):
    def __init__(self):
        # No wrapped network is needed because forward is overridden below
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        # Velocity u_t(x) = 3t^2, so the exact solution is x(t) = x(0) + t^3
        return torch.ones_like(x) * 3.0 * t**2

velocity_model = DummyModel()
solver = ODESolver(velocity_model=velocity_model)
x_init = torch.tensor([0.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

# With the default method="euler", each entry of result is close to the exact value 1.0
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
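With the default method="euler" and a fixed step_size, the solve amounts to repeated forward-Euler updates x ← x + h·u_t(x). The pure-Python sketch below runs that loop for the DummyModel velocity, which shows why the example's result is close to, but not exactly, x(1) = 1. The name euler_sample is hypothetical, lists stand in for tensors, and the sketch does not claim to reproduce torchdiffeq's grid construction exactly.

```python
from typing import Callable, List

Vec = List[float]


def euler_sample(
    velocity: Callable[[Vec, float], Vec], x_init: Vec, t0: float, t1: float, step_size: float
) -> Vec:
    """Forward Euler from t0 to t1 on a uniform grid of roughly step_size spacing."""
    n_steps = round(abs(t1 - t0) / step_size)
    h = (t1 - t0) / n_steps
    x = list(x_init)
    for k in range(n_steps):
        t = t0 + k * h                   # evaluate u_t at the left end of each step
        u = velocity(x, t)
        x = [xi + h * ui for xi, ui in zip(x, u)]
    return x


def velocity(x: Vec, t: float) -> Vec:
    return [3.0 * t**2 for _ in x]       # same field as DummyModel.forward


x_1 = euler_sample(velocity, [0.0, 0.0], 0.0, 1.0, 0.001)
# Each entry is about 0.9985; the exact solution gives x(1) = 1.0
```

The first-order Euler error shrinks linearly with step_size, so halving the step roughly halves the gap to 1.0.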
- Parameters:
x_init (Tensor) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].
step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.
method (str) – A method supported by torchdiffeq. Defaults to “euler”. Other commonly used solvers are “dopri5”, “midpoint” and “heun3”. For a complete list, see torchdiffeq.
atol (float) – Absolute tolerance, used for adaptive step solvers.
rtol (float) – Relative tolerance, used for adaptive step solvers.
time_grid (Tensor) – The ODE is solved on the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by time_grid. A descending time_grid solves in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
return_intermediates (bool, optional) – If True, return the solution at every time point in time_grid. Defaults to False.
enable_grad (bool, optional) – Whether to compute gradients during sampling. Defaults to False.
**model_extras – Additional keyword arguments passed to the velocity model.
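How method, time_grid and return_intermediates interact when step_size is None can be sketched with a hand-rolled fixed-step integrator that steps exactly between consecutive time_grid points. This is not torchdiffeq: the function names are hypothetical, only "euler" and "midpoint" are sketched, and adaptive methods such as "dopri5" instead choose their own internal steps under atol and rtol.

```python
from typing import Callable, List, Sequence, Union

Vec = List[float]
Velocity = Callable[[Vec, float], Vec]


def _step(velocity: Velocity, x: Vec, t: float, h: float, method: str) -> Vec:
    """Advance x from t to t + h with one fixed-step update; h may be negative."""
    u = velocity(x, t)
    if method == "euler":
        return [xi + h * ui for xi, ui in zip(x, u)]
    if method == "midpoint":
        x_mid = [xi + 0.5 * h * ui for xi, ui in zip(x, u)]
        u_mid = velocity(x_mid, t + 0.5 * h)
        return [xi + h * ui for xi, ui in zip(x, u_mid)]
    raise ValueError(f"unsupported method: {method}")


def solve_on_grid(
    velocity: Velocity,
    x_init: Vec,
    time_grid: Sequence[float],
    method: str = "euler",
    return_intermediates: bool = False,
) -> Union[Vec, List[Vec]]:
    """Step between consecutive time_grid points; a descending grid integrates backward."""
    x = list(x_init)
    states = [x]
    for t_a, t_b in zip(time_grid[:-1], time_grid[1:]):
        x = _step(velocity, x, t_a, t_b - t_a, method)
        states.append(x)
    return states if return_intermediates else x


def velocity(x: Vec, t: float) -> Vec:
    return [3.0 * t**2 for _ in x]       # exact flow: x(t) = x(0) + t^3


grid = [k / 10 for k in range(11)]       # 0.0, 0.1, ..., 1.0
x_euler = solve_on_grid(velocity, [0.0], grid)                   # about [0.855]
x_mid = solve_on_grid(velocity, [0.0], grid, method="midpoint")  # about [0.9975]
# Reverse direction: start from x(1) = 1.0 and keep the state at every grid point
path = solve_on_grid(velocity, [1.0], grid[::-1], method="midpoint", return_intermediates=True)
```

On the same coarse grid the second-order midpoint rule lands much closer to the exact value 1.0 than Euler, and the descending grid recovers x(0) ≈ 0 from x(1) = 1.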
- Returns:
The solution at the final time point when return_intermediates=False; otherwise the solutions at all time points specified in time_grid.
- Return type:
Union[Tensor, Sequence[Tensor]]
- compute_likelihood(x_1: Tensor, log_p0: Callable[[Tensor], Tensor], step_size: float | None, method: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Tensor = tensor([1., 0.]), return_intermediates: bool = False, exact_divergence: bool = False, enable_grad: bool = False, **model_extras) → Tuple[Tensor, Tensor] | Tuple[Sequence[Tensor], Tensor]
Solve for the log-likelihood of a target sample given at \(t=1\).
Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x. The function assumes log_p0 is the log probability of the source distribution at \(t=0\).
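The reverse-time solve rests on the instantaneous change-of-variables identity for continuous normalizing flows. It is stated here as background, on the assumption that this is the quantity compute_likelihood integrates. \(x_0\) is obtained by integrating the ODE from \(t=1\) down to \(t=0\) starting at \(x_1\). The last line is the Hutchinson identity that applies when the exact divergence is not computed.

```latex
\begin{aligned}
\frac{\mathrm{d}x_t}{\mathrm{d}t} &= u_t(x_t), \qquad x_1 \text{ given},\\
\log p_1(x_1) &= \log p_0(x_0) - \int_0^1 \operatorname{div}\, u_t(x_t)\,\mathrm{d}t,\\
\operatorname{div}\, u_t(x) &= \mathbb{E}_z\!\left[z^\top \frac{\partial u_t}{\partial x}(x)\, z\right], \qquad \mathbb{E}[z z^\top] = I.
\end{aligned}
```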
- Parameters:
x_1 (Tensor) – target sample (e.g., samples \(X_1 \sim p_1\)).
log_p0 (Callable[[Tensor], Tensor]) – Log probability function of the source distribution.
step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.
method (str) – A method supported by torchdiffeq. Defaults to “euler”. Other commonly used solvers are “dopri5”, “midpoint” and “heun3”. For a complete list, see torchdiffeq.
atol (float) – Absolute tolerance, used for adaptive step solvers.
rtol (float) – Relative tolerance, used for adaptive step solvers.
time_grid (Tensor) – If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
return_intermediates (bool, optional) – If True, return the samples at every time point in time_grid; otherwise return only the final sample. Defaults to False.
exact_divergence (bool) – If True, compute the exact divergence of the velocity field; otherwise use the Hutchinson estimator. Defaults to False.
enable_grad (bool, optional) – Whether to compute gradients during sampling. Defaults to False.
**model_extras – Additional keyword arguments passed to the velocity model.
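The exact_divergence flag trades cost for variance. The exact trace of the Jacobian needs one derivative pass per dimension. The Hutchinson estimator averages \(z^\top (\partial u/\partial x) z\) over random probes with \(\mathbb{E}[z z^\top] = I\); it is unbiased but noisy. The pure-Python sketch below computes both. Central finite differences stand in for automatic differentiation, which the library presumably uses given that the model must be differentiable in x. Rademacher probes are one common choice, though the library's probe distribution may differ. The velocity field and all function names are hypothetical.

```python
import math
import random
from typing import Callable, List

Vec = List[float]


def velocity(x: Vec, t: float) -> Vec:
    """Hypothetical nonlinear 2-D velocity field; t is unused here."""
    return [math.sin(x[0]) + x[1], x[0] * x[1]]


def jvp(f: Callable[[Vec, float], Vec], x: Vec, t: float, v: Vec, eps: float = 1e-5) -> Vec:
    """Central finite-difference approximation of the Jacobian-vector product (du/dx) v."""
    f_plus = f([xi + eps * vi for xi, vi in zip(x, v)], t)
    f_minus = f([xi - eps * vi for xi, vi in zip(x, v)], t)
    return [(p - m) / (2.0 * eps) for p, m in zip(f_plus, f_minus)]


def exact_divergence(f: Callable[[Vec, float], Vec], x: Vec, t: float) -> float:
    """Trace of the Jacobian: one JVP per coordinate direction."""
    total = 0.0
    for i in range(len(x)):
        e_i = [1.0 if j == i else 0.0 for j in range(len(x))]
        total += jvp(f, x, t, e_i)[i]
    return total


def hutchinson_divergence(
    f: Callable[[Vec, float], Vec], x: Vec, t: float, n_probes: int, rng: random.Random
) -> float:
    """Unbiased estimate of E_z[z^T (du/dx) z] using Rademacher probes z."""
    total = 0.0
    for _ in range(n_probes):
        z = [rng.choice((-1.0, 1.0)) for _ in x]
        total += sum(zi * ji for zi, ji in zip(z, jvp(f, x, t, z)))
    return total / n_probes


x = [0.3, 0.7]
analytic = math.cos(x[0]) + x[0]         # du1/dx1 + du2/dx2
div_exact = exact_divergence(velocity, x, 0.0)
div_hutch = hutchinson_divergence(velocity, x, 0.0, n_probes=20_000, rng=random.Random(0))
```

In practice a single probe per ODE step is often used, so the estimator's noise enters the log-likelihood directly; the exact mode avoids that noise at roughly d times the derivative cost.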
- Returns:
The samples at the time points in time_grid (only the final sample at \(t=0\) unless return_intermediates=True) and the log-likelihood of each x_1.
- Return type:
Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]
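A worked check of the reverse-time computation is possible with a 1-D linear velocity \(u_t(x) = a x\) and a standard normal source. Then \(x_t = x_0 e^{a t}\), \(\operatorname{div} u_t = a\) and \(p_1 = \mathcal{N}(0, e^{2a})\), so the result can be compared with a closed form. In this pure-Python sketch the function names are hypothetical, forward Euler stands in for whichever method is chosen, and the divergence is supplied analytically rather than by autograd or the Hutchinson estimator. It returns the pair \((x_0, \log p_1(x_1))\), mirroring the Tuple[Tensor, Tensor] return type documented above.

```python
import math
from typing import Callable, Tuple


def standard_normal_logpdf(x: float) -> float:
    """Log-density of the source N(0, 1); plays the role of log_p0."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def likelihood_sketch(
    velocity: Callable[[float, float], float],
    divergence: Callable[[float, float], float],
    x_1: float,
    log_p0: Callable[[float], float],
    n_steps: int = 1000,
) -> Tuple[float, float]:
    """Integrate from t=1 to t=0 with forward Euler and return (x_0, log p_1(x_1))."""
    h = -1.0 / n_steps                 # negative step: time runs backward
    x, div_integral = x_1, 0.0         # div_integral approximates the integral of div u_t over [0, 1]
    for k in range(n_steps):
        t = 1.0 + k * h
        div_integral -= h * divergence(x, t)   # h < 0, so this adds |h| * div
        x += h * velocity(x, t)
    return x, log_p0(x) - div_integral


a = 0.5
x_0, log_p1 = likelihood_sketch(lambda x, t: a * x, lambda x, t: a, 1.3, standard_normal_logpdf)

# Closed form for comparison: p_1 = N(0, e^{2a})
exact = -0.5 * 1.3**2 * math.exp(-2 * a) - 0.5 * math.log(2 * math.pi) - a
```

With 1000 steps the Euler result agrees with the closed form to about four decimal places.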