Manifold

class flow_matching.utils.manifolds.Manifold(*args, **kwargs)

Abstract base class for manifolds, providing projection operations together with the logarithmic and exponential maps. Subclasses must implement expmap, logmap, projx, and proju.
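A minimal sketch of a concrete subclass, assuming flat Euclidean space as the manifold. There, geodesics are straight lines, so \(\exp_x(u) = x + u\) and \(\log_x(y) = y - x\), and both projections are the identity. To keep the sketch runnable without the library installed, it subclasses a hypothetical local `ManifoldSketch` ABC that mirrors the four abstract methods documented on this page; in real code you would subclass `flow_matching.utils.manifolds.Manifold` instead. The name `FlatSpace` is likewise hypothetical.

```python
import abc

import torch
from torch import Tensor


class ManifoldSketch(abc.ABC):
    """Hypothetical local stand-in mirroring the four abstract methods of Manifold."""

    @abc.abstractmethod
    def expmap(self, x: Tensor, u: Tensor) -> Tensor: ...

    @abc.abstractmethod
    def logmap(self, x: Tensor, y: Tensor) -> Tensor: ...

    @abc.abstractmethod
    def projx(self, x: Tensor) -> Tensor: ...

    @abc.abstractmethod
    def proju(self, x: Tensor, u: Tensor) -> Tensor: ...


class FlatSpace(ManifoldSketch):
    """Euclidean space R^d: geodesics are straight lines and every tangent space is R^d."""

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        return x + u  # follow the straight line from x with velocity u for unit time

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        return y - x  # the tangent vector at x whose exponential map reaches y

    def projx(self, x: Tensor) -> Tensor:
        return x  # every point already lies on the manifold

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        return u  # every vector is already tangent


manifold = FlatSpace()
x = torch.zeros(3)
y = torch.tensor([1.0, 2.0, 3.0])
# exp_x(log_x(y)) recovers y exactly in flat space.
recovered = manifold.expmap(x, manifold.logmap(x, y))
```

Because the four methods are declared abstract, instantiating the base class directly raises a `TypeError`; only subclasses that implement all of them can be created.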

abstract expmap(x: Tensor, u: Tensor) → Tensor

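Computes the exponential map \(\exp_x(u)\): the point reached by following the geodesic that starts at \(x\) with initial velocity \(u\) for unit time.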

Parameters:
  • x (Tensor) – point on the manifold

  • u (Tensor) – tangent vector at point \(x\)

Raises:

NotImplementedError – if not implemented

Returns:

point \(\exp_x(u)\) on the manifold

Return type:

Tensor
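As an illustration, here is a sketch of \(\exp_x(u)\) on the unit sphere, assuming \(x\) has unit norm and \(u\) is tangent at \(x\) (that is, \(\langle x, u\rangle = 0\)). Geodesics on the sphere are great circles, which gives \(\exp_x(u) = \cos(\lVert u\rVert)\,x + \sin(\lVert u\rVert)\,u/\lVert u\rVert\). The function `sphere_expmap` and its `eps` guard are hypothetical, not part of the library.

```python
import math

import torch


def sphere_expmap(x: torch.Tensor, u: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Exponential map on the unit sphere (hypothetical helper).

    Assumes ||x|| = 1 and <x, u> = 0 along the last dimension; leading dims broadcast.
    """
    norm_u = u.norm(dim=-1, keepdim=True)
    # Unit direction of u. Clamping avoids 0/0; when u == 0 the sin term vanishes,
    # so the result is x itself, matching exp_x(0) = x.
    direction = u / norm_u.clamp_min(eps)
    return torch.cos(norm_u) * x + torch.sin(norm_u) * direction


# A quarter turn along the great circle through e1 and e2.
x = torch.tensor([1.0, 0.0, 0.0])
u = torch.tensor([0.0, math.pi / 2, 0.0])
y = sphere_expmap(x, u)  # approximately [0, 1, 0]
```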

abstract logmap(x: Tensor, y: Tensor) → Tensor

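Computes the logarithmic map \(\log_x(y)\), the (local) inverse of the exponential map: the tangent vector \(u\) at \(x\) such that \(\exp_x(u) = y\).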

Parameters:
  • x (Tensor) – point on the manifold

  • y (Tensor) – point on the manifold

Raises:

NotImplementedError – if not implemented

Returns:

tangent vector at point \(x\)

Return type:

Tensor
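On the unit sphere (used here purely as an illustrative assumption), the logarithmic map returns the tangent vector at \(x\) that points along the great circle toward \(y\), with length equal to the angle \(\theta = \arccos\langle x, y\rangle\) between them: \(\log_x(y) = \theta\,(y - \langle x, y\rangle x)/\lVert y - \langle x, y\rangle x\rVert\). The map is undefined for antipodal points \(y = -x\). `sphere_logmap` is a hypothetical name.

```python
import math

import torch


def sphere_logmap(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Logarithmic map on the unit sphere (hypothetical helper).

    Assumes ||x|| = ||y|| = 1; undefined when y = -x (antipodal points).
    """
    # Cosine of the angle between x and y, clamped so acos never sees values outside [-1, 1].
    cos_theta = (x * y).sum(dim=-1, keepdim=True).clamp(-1.0, 1.0)
    theta = torch.acos(cos_theta)  # geodesic distance from x to y
    # Component of y orthogonal to x: the direction of the great circle from x to y.
    v = y - cos_theta * x
    return theta * v / v.norm(dim=-1, keepdim=True).clamp_min(eps)


x = torch.tensor([1.0, 0.0, 0.0])
y = torch.tensor([0.0, 1.0, 0.0])
u = sphere_logmap(x, y)  # approximately [0, pi/2, 0]
```

The returned vector is orthogonal to \(x\), so it is a valid tangent vector, and its norm equals the geodesic distance between \(x\) and \(y\).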

abstract projx(x: Tensor) → Tensor

Projects point \(x\) onto the manifold.

Parameters:

x (Tensor) – point to be projected

Raises:

NotImplementedError – if not implemented

Returns:

projected point on the manifold

Return type:

Tensor
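For the unit sphere (chosen here only as an illustrative assumption), projecting a nonzero point onto the manifold amounts to rescaling it to unit norm, which gives the closest point on the sphere in Euclidean distance. One common use of such a projection is pulling points back onto the manifold after floating-point error has moved them slightly off it. `sphere_projx` is a hypothetical name.

```python
import torch


def sphere_projx(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Project points onto the unit sphere by normalizing the last dimension (hypothetical helper)."""
    # clamp_min keeps the zero vector from producing NaNs; it maps to zero, which is
    # not on the sphere, so callers should avoid projecting the origin.
    return x / x.norm(dim=-1, keepdim=True).clamp_min(eps)


points = torch.tensor([[3.0, 4.0, 0.0], [0.0, 0.0, -2.0]])
projected = sphere_projx(points)  # rows [0.6, 0.8, 0.0] and [0.0, 0.0, -1.0]
```

The projection is idempotent: projecting an already-projected point leaves it unchanged.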

abstract proju(x: Tensor, u: Tensor) → Tensor

Projects vector \(u\) onto the tangent space at \(x\).

Parameters:
  • x (Tensor) – point on the manifold

  • u (Tensor) – vector to be projected

Raises:

NotImplementedError – if not implemented

Returns:

projected tangent vector

Return type:

Tensor
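On the unit sphere (an illustrative assumption, not a manifold this page documents), the tangent space at \(x\) is the set of vectors orthogonal to \(x\), so projection subtracts the normal component: \(u - \langle x, u\rangle x\). A plausible use, assumed here rather than specified by this page, is turning an arbitrary vector (for example a model's raw output) into a valid tangent vector before passing it to an exponential map. `sphere_proju` is a hypothetical name.

```python
import torch


def sphere_proju(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Project u onto the tangent space of the unit sphere at x (hypothetical helper).

    Assumes ||x|| = 1; removes the component of u along the normal direction x.
    """
    return u - (x * u).sum(dim=-1, keepdim=True) * x


x = torch.tensor([1.0, 0.0, 0.0])
u = torch.tensor([2.0, 3.0, 0.0])
v = sphere_proju(x, u)  # [0.0, 3.0, 0.0]
```

The result is orthogonal to \(x\), and projecting it a second time leaves it unchanged.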