AffineProbPath

class flow_matching.path.AffineProbPath(scheduler: Scheduler)

The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine. An affine transformation can be represented as:

\[X_t = \alpha_t X_1 + \sigma_t X_0,\]

where \(X_t\) is the transformed data point at time t. \(X_0\) and \(X_1\) are the source and target data points, respectively. \(\alpha_t\) and \(\sigma_t\) are the parameters of the affine transformation at time t.

The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
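
To make the scheduler's role concrete, here is a minimal sketch that assumes a linear schedule, \(\alpha_t = t\) and \(\sigma_t = 1 - t\). The names `ScheduleValues`, `linear_schedule`, and `affine_point` are hypothetical and are not part of the library; the arithmetic is written for plain floats but applies elementwise to tensors as well.

```python
from typing import NamedTuple


class ScheduleValues(NamedTuple):
    """Hypothetical container, not the library's scheduler output type."""
    alpha_t: float
    sigma_t: float
    d_alpha_t: float
    d_sigma_t: float


def linear_schedule(t: float) -> ScheduleValues:
    """Assumed linear schedule: alpha_t = t, sigma_t = 1 - t."""
    return ScheduleValues(alpha_t=t, sigma_t=1.0 - t, d_alpha_t=1.0, d_sigma_t=-1.0)


def affine_point(x_0: float, x_1: float, t: float) -> float:
    """X_t = alpha_t * X_1 + sigma_t * X_0 for the schedule above."""
    s = linear_schedule(t)
    return s.alpha_t * x_1 + s.sigma_t * x_0


# The path starts at the source (t = 0) and ends at the target (t = 1).
start = affine_point(x_0=-2.0, x_1=4.0, t=0.0)
middle = affine_point(x_0=-2.0, x_1=4.0, t=0.5)
end = affine_point(x_0=-2.0, x_1=4.0, t=1.0)
```

Under this assumed schedule the derivatives are constant, so the conditional velocity \(\dot{X}_t = X_1 - X_0\) does not depend on t.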

Using AffineProbPath in the flow matching framework:

import torch
from flow_matching.path import AffineProbPath

# Instantiates a probability path
my_path = AffineProbPath(...)
mse_loss = torch.nn.MSELoss()

for x_1 in dataset:
    # Sets x_0 to random noise with the same shape as x_1
    x_0 = torch.randn_like(x_1)

    # Sets t to a random value in [0,1] for each element of the batch
    t = torch.rand(x_1.shape[0], device=x_1.device)

    # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

    # Computes the MSE loss w.r.t. the velocity
    loss = mse_loss(path_sample.dx_t, my_model(path_sample.x_t, t))
    loss.backward()
Parameters:

scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.

sample(x_0: Tensor, x_1: Tensor, t: Tensor) → PathSample

Sample from the affine probability path:

  • given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\);

  • returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).

Parameters:
  • x_0 (Tensor) – source data point, shape (batch_size, …).

  • x_1 (Tensor) – target data point, shape (batch_size, …).

  • t (Tensor) – times in [0,1], shape (batch_size).

Returns:

a conditional sample at \(X_t \sim p_t\).

Return type:

PathSample
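
The shapes above imply that t, of shape (batch_size), has to be broadcast against data of shape (batch_size, …). A minimal sketch of that computation in plain PyTorch, assuming a linear schedule (\(\alpha_t = t\), \(\sigma_t = 1 - t\)); `affine_sample_sketch` is a hypothetical helper and not the library's implementation:

```python
import torch


def affine_sample_sketch(x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    """Return (x_t, dx_t) for an assumed linear schedule."""
    # Reshape t from (batch_size,) to (batch_size, 1, ..., 1) so it broadcasts.
    t_b = t.reshape(-1, *([1] * (x_1.dim() - 1)))
    alpha_t, sigma_t = t_b, 1.0 - t_b
    d_alpha_t, d_sigma_t = torch.ones_like(t_b), -torch.ones_like(t_b)

    x_t = alpha_t * x_1 + sigma_t * x_0       # point on the path
    dx_t = d_alpha_t * x_1 + d_sigma_t * x_0  # conditional velocity
    return x_t, dx_t


torch.manual_seed(0)
x_1 = torch.randn(4, 3, 8, 8)   # stand-in for a batch of target data
x_0 = torch.randn_like(x_1)     # Gaussian source samples
t = torch.rand(x_1.shape[0])    # one time per batch element
x_t, dx_t = affine_sample_sketch(x_0, x_1, t)
```

Both outputs keep the shape of the inputs, and each batch element is interpolated with its own time.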

target_to_velocity(x_1: Tensor, x_t: Tensor, t: Tensor) → Tensor

Convert from x_1 representation to velocity.

  • given \(X_1\);

  • returns \(\dot{X}_t\).

Parameters:
  • x_1 (Tensor) – target data point.

  • x_t (Tensor) – path sample at time t.

  • t (Tensor) – time in [0,1].

Returns:

velocity.

Return type:

Tensor
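
One way to see how this conversion can work: solving \(X_t = \alpha_t X_1 + \sigma_t X_0\) for \(X_0\) and substituting into \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\) gives a velocity that depends only on \(X_1\) and \(X_t\). The sketch below follows that algebra; `target_to_velocity_sketch` is a hypothetical helper that takes the scheduler values as explicit arguments, and the numbers assume a linear schedule at t = 0.25.

```python
def target_to_velocity_sketch(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Hypothetical helper: velocity from a target and a path sample."""
    # Eliminate X_0 via X_0 = (X_t - alpha_t * X_1) / sigma_t, then regroup.
    a_t = d_sigma_t / sigma_t
    b_t = (d_alpha_t * sigma_t - d_sigma_t * alpha_t) / sigma_t
    return a_t * x_t + b_t * x_1


# Assumed linear schedule at t = 0.25: alpha = 0.25, sigma = 0.75.
alpha_t, sigma_t, d_alpha_t, d_sigma_t = 0.25, 0.75, 1.0, -1.0
x_0, x_1 = -2.0, 4.0
x_t = alpha_t * x_1 + sigma_t * x_0

velocity = target_to_velocity_sketch(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
direct = d_alpha_t * x_1 + d_sigma_t * x_0  # velocity computed from both endpoints
```

The division by \(\sigma_t\) means this form is undefined wherever \(\sigma_t = 0\), which for the assumed linear schedule is t = 1.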

epsilon_to_velocity(epsilon: Tensor, x_t: Tensor, t: Tensor) → Tensor

Convert from epsilon representation to velocity.

  • given \(\epsilon\);

  • returns \(\dot{X}_t\).

Parameters:
  • epsilon (Tensor) – noise in the path sample.

  • x_t (Tensor) – path sample at time t.

  • t (Tensor) – time in [0,1].

Returns:

velocity.

Return type:

Tensor
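
A sketch of the algebra behind this conversion, under the assumption that \(\epsilon\) plays the role of the source sample \(X_0\) (as when the source is Gaussian noise). Solving the path equation for \(X_1\) and substituting into the velocity leaves an expression in \(\epsilon\) and \(X_t\) only. `epsilon_to_velocity_sketch` is hypothetical, and the values assume a linear schedule at t = 0.25.

```python
def epsilon_to_velocity_sketch(epsilon, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Hypothetical helper: velocity from the noise and a path sample."""
    # Eliminate X_1 via X_1 = (X_t - sigma_t * epsilon) / alpha_t, then regroup.
    a_t = d_alpha_t / alpha_t
    b_t = (d_sigma_t * alpha_t - sigma_t * d_alpha_t) / alpha_t
    return a_t * x_t + b_t * epsilon


# Assumed linear schedule at t = 0.25: alpha = 0.25, sigma = 0.75.
alpha_t, sigma_t, d_alpha_t, d_sigma_t = 0.25, 0.75, 1.0, -1.0
epsilon, x_1 = -2.0, 4.0                # epsilon is treated as X_0 here
x_t = alpha_t * x_1 + sigma_t * epsilon

velocity = epsilon_to_velocity_sketch(epsilon, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
direct = d_alpha_t * x_1 + d_sigma_t * epsilon
```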

velocity_to_target(velocity: Tensor, x_t: Tensor, t: Tensor) → Tensor

Convert from velocity to x_1 representation.

  • given \(\dot{X}_t\);

  • returns \(X_1\).

Parameters:
  • velocity (Tensor) – velocity at the path sample.

  • x_t (Tensor) – path sample at time t.

  • t (Tensor) – time in [0,1].

Returns:

target data point.

Return type:

Tensor
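
This direction can be understood as solving the pair of equations for \(X_t\) and \(\dot{X}_t\) for \(X_1\): multiplying the velocity by \(\sigma_t\) and subtracting \(\dot{\sigma}_t X_t\) cancels \(X_0\). The sketch below assumes a linear schedule at t = 0.25 and uses a hypothetical helper name; it illustrates the algebra rather than the library's code.

```python
def velocity_to_target_sketch(velocity, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Hypothetical helper: recover X_1 from a velocity and a path sample."""
    # sigma_t * velocity - d_sigma_t * x_t = (sigma_t * d_alpha_t - alpha_t * d_sigma_t) * X_1
    denom = sigma_t * d_alpha_t - alpha_t * d_sigma_t
    return (sigma_t * velocity - d_sigma_t * x_t) / denom


# Assumed linear schedule at t = 0.25: alpha = 0.25, sigma = 0.75.
alpha_t, sigma_t, d_alpha_t, d_sigma_t = 0.25, 0.75, 1.0, -1.0
x_0, x_1 = -2.0, 4.0
x_t = alpha_t * x_1 + sigma_t * x_0
velocity = d_alpha_t * x_1 + d_sigma_t * x_0

recovered_x_1 = velocity_to_target_sketch(velocity, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
```

For the assumed linear schedule the denominator equals 1 at every t, so this conversion has no singular time.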

epsilon_to_target(epsilon: Tensor, x_t: Tensor, t: Tensor) → Tensor

Convert from epsilon representation to x_1 representation.

  • given \(\epsilon\);

  • returns \(X_1\).

Parameters:
  • epsilon (Tensor) – noise in the path sample.

  • x_t (Tensor) – path sample at time t.

  • t (Tensor) – time in [0,1].

Returns:

target data point.

Return type:

Tensor
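
If \(\epsilon\) is taken to be the source sample \(X_0\) (an assumption, not stated above), this conversion amounts to rearranging the path equation for \(X_1\). A minimal sketch with a hypothetical helper and values from an assumed linear schedule at t = 0.25:

```python
def epsilon_to_target_sketch(epsilon, x_t, alpha_t, sigma_t):
    """Hypothetical helper: X_1 = (X_t - sigma_t * epsilon) / alpha_t."""
    return (x_t - sigma_t * epsilon) / alpha_t


# Assumed linear schedule at t = 0.25: alpha = 0.25, sigma = 0.75.
alpha_t, sigma_t = 0.25, 0.75
epsilon, x_1 = -2.0, 4.0                # epsilon is treated as X_0 here
x_t = alpha_t * x_1 + sigma_t * epsilon

recovered_x_1 = epsilon_to_target_sketch(epsilon, x_t, alpha_t, sigma_t)
```

The division by \(\alpha_t\) makes this form unusable where \(\alpha_t = 0\), which for the assumed linear schedule is t = 0.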

velocity_to_epsilon(velocity: Tensor, x_t: Tensor, t: Tensor) → Tensor

Convert from velocity to noise representation.

  • given \(\dot{X}_t\);

  • returns \(\epsilon\).

Parameters:
  • velocity (Tensor) – velocity at the path sample.

  • x_t (Tensor) – path sample at time t.

  • t (Tensor) – time in [0,1].

Returns:

noise in the path sample.

Return type:

Tensor
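
A sketch of how the noise can be recovered from a velocity, again assuming \(\epsilon\) stands for \(X_0\): multiplying the velocity by \(\alpha_t\) and subtracting \(\dot{\alpha}_t X_t\) cancels \(X_1\) and leaves a multiple of \(\epsilon\). The helper name is hypothetical and the values assume a linear schedule at t = 0.25.

```python
def velocity_to_epsilon_sketch(velocity, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Hypothetical helper: recover the noise from a velocity and a path sample."""
    # alpha_t * velocity - d_alpha_t * x_t = (alpha_t * d_sigma_t - sigma_t * d_alpha_t) * epsilon
    denom = alpha_t * d_sigma_t - sigma_t * d_alpha_t
    return (alpha_t * velocity - d_alpha_t * x_t) / denom


# Assumed linear schedule at t = 0.25: alpha = 0.25, sigma = 0.75.
alpha_t, sigma_t, d_alpha_t, d_sigma_t = 0.25, 0.75, 1.0, -1.0
epsilon, x_1 = -2.0, 4.0                # epsilon is treated as X_0 here
x_t = alpha_t * x_1 + sigma_t * epsilon
velocity = d_alpha_t * x_1 + d_sigma_t * epsilon

recovered_epsilon = velocity_to_epsilon_sketch(
    velocity, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t
)
```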

target_to_epsilon(x_1: Tensor, x_t: Tensor, t: Tensor) → Tensor

Convert from x_1 representation to noise representation.

  • given \(X_1\);

  • returns \(\epsilon\).

Parameters:
  • x_1 (Tensor) – target data point.

  • x_t (Tensor) – path sample at time t.

  • t (Tensor) – time in [0,1].

Returns:

noise in the path sample.

Return type:

Tensor
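
Assuming \(\epsilon\) denotes the source sample \(X_0\), this conversion is the path equation rearranged for the noise. A minimal sketch with a hypothetical helper, using values from an assumed linear schedule at t = 0.25:

```python
def target_to_epsilon_sketch(x_1, x_t, alpha_t, sigma_t):
    """Hypothetical helper: epsilon = (X_t - alpha_t * X_1) / sigma_t."""
    return (x_t - alpha_t * x_1) / sigma_t


# Assumed linear schedule at t = 0.25: alpha = 0.25, sigma = 0.75.
alpha_t, sigma_t = 0.25, 0.75
epsilon, x_1 = -2.0, 4.0                # epsilon is treated as X_0 here
x_t = alpha_t * x_1 + sigma_t * epsilon

recovered_epsilon = target_to_epsilon_sketch(x_1, x_t, alpha_t, sigma_t)
```

Because of the division by \(\sigma_t\), the noise cannot be recovered this way where \(\sigma_t = 0\), which for the assumed linear schedule is t = 1.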