RiemannianODESolver

class flow_matching.solver.RiemannianODESolver(manifold: Manifold, velocity_model: ModelWrapper)

Riemannian ODE solver. Initializes the RiemannianODESolver with a manifold and a velocity model.

Parameters:
  • manifold (Manifold) – the manifold to solve on.

  • velocity_model (ModelWrapper) – a velocity field model that receives \((x,t)\) and returns \(u_t(x)\), which is assumed to lie in the tangent space at \(x\).
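To make the two constructor arguments concrete, here is a minimal sketch of a manifold object and a tangent velocity field on the unit sphere \(S^2\). It is an illustration under stated assumptions, not the library's API: NumPy arrays stand in for Tensors, `UnitSphere` and `rotation_velocity` are hypothetical names, and the `projx`/`proju` method names are chosen only to echo the flags of `sample`.

```python
import numpy as np


class UnitSphere:
    """Hypothetical stand-in for a Manifold: the unit sphere in R^d."""

    def projx(self, x: np.ndarray) -> np.ndarray:
        # Project points onto the sphere by normalizing along the last axis.
        return x / np.linalg.norm(x, axis=-1, keepdims=True)

    def proju(self, x: np.ndarray, u: np.ndarray) -> np.ndarray:
        # Project u onto the tangent space at x (x assumed unit-norm)
        # by removing its component along x.
        return u - np.sum(u * x, axis=-1, keepdims=True) * x


def rotation_velocity(x: np.ndarray, t: float) -> np.ndarray:
    """Hypothetical velocity field u_t(x) on S^2: rotation about the z-axis.

    omega x x is orthogonal to x, so the field is tangent at every point.
    """
    omega = np.array([0.0, 0.0, 1.0])
    return np.cross(omega, x)


sphere = UnitSphere()
x = sphere.projx(np.array([[1.0, 2.0, 2.0]]))  # point [[1/3, 2/3, 2/3]] on S^2
u = rotation_velocity(x, 0.0)
print(np.allclose(np.sum(u * x, axis=-1), 0.0))  # True: u is tangent at x
```

A field that is not tangent, for example a constant vector, would have to be passed through `proju` before use; that is the role the `proju` flag of `sample` plays.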

sample(x_init: Tensor, step_size: float, projx: bool = True, proju: bool = True, method: str = 'euler', time_grid: Tensor = tensor([0., 1.]), return_intermediates: bool = False, verbose: bool = False, enable_grad: bool = False, **model_extras) → Tensor

Solve the ODE defined by the velocity model on the manifold.
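Written out, the solver integrates the initial value problem below from the smallest to the largest value in time_grid. The second line is a sketch of one projected Euler step of size \(h\), assuming \(\Pi_{\mathcal{M}}\) denotes projection onto the manifold (the projx flag) and \(\Pi_{T_{x}\mathcal{M}}\) projection onto the tangent space at \(x\) (the proju flag); this notation is introduced here for illustration and is not taken from the library.

```latex
\begin{aligned}
\frac{\mathrm{d}x_t}{\mathrm{d}t} &= u_t(x_t), \qquad x_{t_0} = x_{\mathrm{init}}, \qquad t \in [t_0, t_1], \\
x_{t+h} &\approx \Pi_{\mathcal{M}}\bigl(x_t + h\,\Pi_{T_{x_t}\mathcal{M}}\,u_t(x_t)\bigr).
\end{aligned}
```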

Parameters:
  • x_init (Tensor) – initial conditions (e.g., source samples \(X_0 \sim p\)).

  • step_size (float) – The step size. If None, the time discretization is given by time_grid instead.

  • projx (bool) – Whether to project the point onto the manifold at each step. Defaults to True.

  • proju (bool) – Whether to project the vector field onto the tangent space at each step. Defaults to True.

  • method (str) – One of [“euler”, “midpoint”, “rk4”]. Defaults to “euler”.

  • time_grid (Tensor, optional) – The process is solved on the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by time_grid. Defaults to torch.tensor([0.0, 1.0]).

  • return_intermediates (bool, optional) – If True, return the intermediate states at the times given by time_grid. Defaults to False.

  • verbose (bool, optional) – Whether to print progress bars. Defaults to False.

  • enable_grad (bool, optional) – Whether to compute gradients during sampling. Defaults to False.

  • **model_extras – Additional keyword inputs forwarded to the velocity model.
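The interaction of method, projx and proju can be illustrated with a single-step sketch on the unit sphere. This is a hypothetical re-implementation, not the library's code: it assumes every velocity evaluation is projected onto the tangent space when proju is set, and every updated point (including intermediate Runge-Kutta stages) is projected back onto the manifold when projx is set. The helper name `projected_step` is made up, and NumPy arrays stand in for Tensors.

```python
import numpy as np


def _projx(x):
    # Unit-sphere projection: normalize along the last axis.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def _proju(x, u):
    # Tangent-space projection on the unit sphere: drop the radial component.
    return u - np.sum(u * x, axis=-1, keepdims=True) * x


def projected_step(velocity, x, t, dt, method="euler", projx=True, proju=True):
    """Hypothetical single ODE step on the unit sphere (sketch, not library code)."""

    def vel(y, s):
        # Evaluate the velocity, optionally projected onto the tangent space at y.
        v = velocity(y, s)
        return _proju(y, v) if proju else v

    def px(y):
        # Optionally project a point back onto the sphere.
        return _projx(y) if projx else y

    if method == "euler":
        return px(x + dt * vel(x, t))
    if method == "midpoint":
        x_mid = px(x + 0.5 * dt * vel(x, t))
        return px(x + dt * vel(x_mid, t + 0.5 * dt))
    if method == "rk4":
        k1 = vel(x, t)
        k2 = vel(px(x + 0.5 * dt * k1), t + 0.5 * dt)
        k3 = vel(px(x + 0.5 * dt * k2), t + 0.5 * dt)
        k4 = vel(px(x + dt * k3), t + dt)
        return px(x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6)
    raise ValueError(f"Unknown method {method!r}")
```

For a rotation field about the z-axis, one Euler step from (1, 0, 0) with dt = 0.1 and projx=False lands at (1, 0.1, 0), whose norm is sqrt(1.01), so the point drifts off the sphere; with projx=True the same step is normalized back onto it.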

Returns:

The sampled sequence. Defaults to returning samples at \(t=1\).

Return type:

Tensor

Raises:

ImportError – If verbose is True and tqdm is not installed.
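Putting the sample parameters together, the sketch below shows one plausible way to build a time discretization from step_size and time_grid and to collect intermediate states. It is a hypothetical NumPy re-implementation on the unit sphere using projected Euler only (`time_discretization` and `sample_sphere` are not library functions), and it assumes that with return_intermediates=True the states at the times in time_grid are stacked along a new leading axis.

```python
import math

import numpy as np


def _projx(x):
    # Unit-sphere projection: normalize along the last axis.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def _proju(x, u):
    # Tangent-space projection on the unit sphere: drop the radial component.
    return u - np.sum(u * x, axis=-1, keepdims=True) * x


def time_discretization(time_grid, step_size):
    """Hypothetical helper: solver time points on [min(time_grid), max(time_grid)]."""
    grid = np.unique(np.asarray(time_grid, dtype=float))  # sorted, deduplicated
    if step_size is None:
        return grid  # the time grid itself sets the discretization
    t0, t1 = grid[0], grid[-1]
    n_steps = math.ceil((t1 - t0) / step_size)
    uniform = t0 + step_size * np.arange(n_steps)
    # Drop uniform points that (nearly) coincide with requested times, then
    # merge in time_grid so every requested time is hit exactly.
    near_grid = np.isclose(uniform[:, None], grid[None, :]).any(axis=1)
    return np.sort(np.concatenate([uniform[~near_grid], grid]))


def sample_sphere(velocity, x_init, step_size, time_grid=(0.0, 1.0),
                  return_intermediates=False):
    """Hypothetical projected-Euler sampler on the unit sphere (sketch only)."""
    grid = np.unique(np.asarray(time_grid, dtype=float))
    requested = set(grid.tolist())
    ts = time_discretization(grid, step_size)
    x = _projx(np.asarray(x_init, dtype=float))
    saved = [x]  # ts[0] is min(time_grid), the first requested time
    for t, t_next in zip(ts[:-1], ts[1:]):
        v = _proju(x, velocity(x, t))     # project velocity (proju=True)
        x = _projx(x + (t_next - t) * v)  # Euler step, then projx=True
        if float(t_next) in requested:
            saved.append(x)
    # Intermediates are stacked along a new leading (time) axis.
    return np.stack(saved) if return_intermediates else x
```

For a single starting point x0 of shape (1, 3), `sample_sphere(rot, x0, 0.1, time_grid=(0.0, 0.3, 1.0), return_intermediates=True)` returns an array of shape (3, 1, 3), one state per requested time. Merging the requested times into the uniform grid avoids any interpolation between solver steps; an implementation could instead interpolate, which this sketch does not attempt.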