RiemannianODESolver
- class flow_matching.solver.RiemannianODESolver(manifold: Manifold, velocity_model: ModelWrapper)
Riemannian ODE solver. Initialize the RiemannianODESolver.
- Parameters:
manifold (Manifold) – the manifold to solve on.
velocity_model (ModelWrapper) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\) which is assumed to lie on the tangent plane at x.
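As an illustration of the tangent-plane assumption on velocity_model, the following minimal sketch uses the unit sphere in \(\mathbb{R}^3\), where a vector is tangent at \(x\) exactly when it is orthogonal to \(x\). The helpers dot and project_to_tangent are hypothetical names chosen for this example and are not part of flow_matching.

```python
def dot(a, b):
    """Euclidean inner product of two equal-length sequences."""
    return sum(ai * bi for ai, bi in zip(a, b))


def project_to_tangent(x, u):
    """Remove the component of u along x (x is assumed to have unit norm)."""
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]


x = [0.0, 0.0, 1.0]   # a point on the unit sphere (the north pole)
u = [1.0, 2.0, 3.0]   # an arbitrary ambient vector, not tangent at x
u_tan = project_to_tangent(x, u)

# After projection the vector is orthogonal to x, i.e. tangent to the sphere at x.
is_tangent = abs(dot(x, u_tan)) < 1e-12
```

A velocity model whose output already satisfies this orthogonality condition is unchanged by such a projection.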
- sample(x_init: Tensor, step_size: float, projx: bool = True, proju: bool = True, method: str = 'euler', time_grid: Tensor = tensor([0., 1.]), return_intermediates: bool = False, verbose: bool = False, enable_grad: bool = False, **model_extras) → Tensor
Solve the ODE defined by the velocity field (velocity_model) on the manifold.
- Parameters:
x_init (Tensor) – initial conditions (e.g., source samples \(X_0 \sim p\)).
step_size (float) – The step size of the time discretization.
projx (bool) – Whether to project the point onto the manifold at each step. Defaults to True.
proju (bool) – Whether to project the vector field onto the tangent plane at each step. Defaults to True.
method (str) – One of [“euler”, “midpoint”, “rk4”]. Defaults to “euler”.
time_grid (Tensor, optional) – The ODE is solved on the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by time_grid. Defaults to torch.tensor([0.0, 1.0]).
return_intermediates (bool, optional) – If True, return the solution at the intermediate times given by time_grid. Defaults to False.
verbose (bool, optional) – Whether to print progress bars. Defaults to False.
enable_grad (bool, optional) – Whether to compute gradients during sampling. Defaults to False.
**model_extras – Additional input for the model.
- Returns:
The sampled sequence. Defaults to returning samples at \(t=1\).
- Return type:
Tensor
- Raises:
ImportError – To run in verbose mode, tqdm must be installed.
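To make the roles of step_size, projx, and proju concrete, here is a minimal pure-Python sketch of a projected Euler scheme on the unit sphere. It is an illustration under stated assumptions, not the library's implementation: the sphere, the rotation field velocity, and all helper names are hypothetical, and projecting onto the sphere by normalization is just one possible choice of projection.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def normalize(x):
    """Project a nonzero point back onto the unit sphere (stand-in for projx)."""
    n = math.sqrt(dot(x, x))
    return [xi / n for xi in x]


def tangent(x, u):
    """Remove the component of u along x (stand-in for proju)."""
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]


def velocity(x, t):
    """Hypothetical velocity field: rotation about the z-axis at unit angular speed."""
    return [-x[1], x[0], 0.0]


def euler_on_sphere(x_init, step_size, t0=0.0, t1=1.0, projx=True, proju=True):
    # Assumes step_size (approximately) divides the interval [t0, t1].
    n_steps = max(1, round((t1 - t0) / step_size))
    dt = (t1 - t0) / n_steps
    x = list(x_init)
    for k in range(n_steps):
        t = t0 + k * dt
        u = velocity(x, t)
        if proju:
            u = tangent(x, u)      # keep the update direction tangent at x
        x = [xi + dt * ui for xi, ui in zip(x, u)]
        if projx:
            x = normalize(x)       # pull the iterate back onto the sphere
    return x


x_final = euler_on_sphere([1.0, 0.0, 0.0], step_size=0.01)
x_drift = euler_on_sphere([1.0, 0.0, 0.0], step_size=0.01, projx=False)
```

In this sketch, x_final stays on the sphere and lands close to a rotation of the starting point by one radian, while x_drift (computed with projx=False) slowly moves off the sphere. That drift is the kind of error a per-step projection is meant to control.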