geodesic

flow_matching.utils.manifolds.geodesic(manifold: Manifold, start_point: Tensor, end_point: Tensor)

Generate a parameterized function for the geodesic curve between start_point and end_point on manifold.
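Under the usual Riemannian construction, the geodesic can be written with the manifold's exponential and logarithm maps. This page does not spell out the construction, so treat the formula as an assumption. Write \(x_0\) for start_point and \(x_1\) for end_point, and assume \(x_1\) is not in the cut locus of \(x_0\) (for example, not antipodal to it on a sphere):

\[
\gamma(t) = \exp_{x_0}\!\big(t \, \log_{x_0}(x_1)\big), \qquad \gamma(0) = x_0, \quad \gamma(1) = x_1 .
\]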

Parameters:
  • manifold (Manifold) – the manifold to compute geodesic on.

  • start_point (Tensor) – point on the manifold at \(t=0\).

  • end_point (Tensor) – point on the manifold at \(t=1\).

Returns:

a function that takes a time \(t\) and returns the point on the geodesic at time \(t\).

Return type:

Callable[[Tensor], Tensor]