ProbPath

class flow_matching.path.ProbPath

Abstract class representing a probability path.

A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).
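As a worked example, here is one common choice of path. It is assumed for illustration only, since ProbPath itself does not fix any particular path. It is the linear interpolation between a source sample and a target sample:

```latex
X_t = t\,X_1 + (1-t)\,X_0, \qquad t \in [0,1],
\qquad X_{t=0} = X_0, \quad X_{t=1} = X_1, \quad \frac{\mathrm{d}X_t}{\mathrm{d}t} = X_1 - X_0 .
```

At the endpoints the sample reduces to \(X_0\) and \(X_1\), so the path starts at \(p(X_0)\) and ends at \(p(X_1)\). The time derivative \(X_1 - X_0\) is the conditional velocity that is often used as a training target.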

The ProbPath class is designed to support model training in the flow matching framework. It provides two key functionalities: (1) sampling the conditional probability path and (2) converting between different training objectives. A high-level training loop looks like this:

import torch

# Instantiate a probability path (ProbPath is abstract; use a concrete subclass)
my_path = ProbPath(...)

for x_0, x_1 in dataset:
    # Sample one time per example, uniformly in [0, 1); shape (batch_size,)
    t = torch.rand(x_0.shape[0])

    # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

    # Compute gradients. The loss function varies, depending on model and path.
    loss(path_sample, my_model(path_sample.x_t, path_sample.t)).backward()
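The loop above leaves the path, model, and loss abstract. The sketch below fills them in under stated assumptions so that it runs end to end:

- a hypothetical `sample_linear_path` function stands in for a concrete ProbPath using the linear path x_t = t * x_1 + (1 - t) * x_0;
- `VelocityMLP` is a toy model;
- the loss is mean squared error against the conditional velocity x_1 - x_0.

None of these names come from flow_matching.

```python
import torch
from torch import nn


# Hypothetical stand-in for a concrete ProbPath: the linear path
# x_t = t * x_1 + (1 - t) * x_0, whose conditional velocity is x_1 - x_0.
def sample_linear_path(x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    t_b = t.view(-1, *([1] * (x_0.dim() - 1)))  # (batch,) -> (batch, 1, ..., 1)
    x_t = t_b * x_1 + (1 - t_b) * x_0
    dx_t = x_1 - x_0
    return x_t, dx_t


class VelocityMLP(nn.Module):
    """Toy velocity model v(x_t, t) for 2-D data (illustrative only)."""

    def __init__(self, dim: int = 2, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(), nn.Linear(hidden, dim)
        )

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Append the time as an extra input feature.
        return self.net(torch.cat([x_t, t[:, None]], dim=-1))


torch.manual_seed(0)
model = VelocityMLP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    x_1 = torch.randn(128, 2) * 0.5 + 2.0  # toy "data" distribution
    x_0 = torch.randn(128, 2)              # source noise
    t = torch.rand(x_0.shape[0])           # one time per example, shape (batch_size,)
    x_t, dx_t = sample_linear_path(x_0, x_1, t)
    loss = ((model(x_t, t) - dx_t) ** 2).mean()  # squared-error velocity loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # backward() alone only computes gradients
    losses.append(loss.item())
```

Note that `optimizer.step()` is what actually updates the model. The original loop stops at `backward()`, which only populates gradients.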

abstract sample(x_0: Tensor, x_1: Tensor, t: Tensor) → PathSample

Sample from an abstract probability path:

given \((X_0,X_1) \sim \pi(X_0,X_1)\),
returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t)\), and a conditional target \(Y\), all bundled in a PathSample object.
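The conditional target \(Y\) depends on the training objective, and one objective can be converted into another. The minimal sketch below assumes the linear path \(X_t = t X_1 + (1-t) X_0\), whose velocity target is \(X_1 - X_0\). Under that assumption, a model that predicts \(X_1\) or \(X_0\) can be turned into a velocity prediction. The function names are hypothetical and not part of flow_matching.

```python
import torch


# Assumption: linear path x_t = t * x_1 + (1 - t) * x_0, velocity target x_1 - x_0.
def velocity_from_x1(x1_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Velocity implied by a predicted target sample x_1 (requires t < 1)."""
    # x_1 - x_t = (1 - t)(x_1 - x_0)  =>  x_1 - x_0 = (x_1 - x_t) / (1 - t)
    return (x1_pred - x_t) / (1 - t)


def velocity_from_x0(x0_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Velocity implied by a predicted source sample x_0 (requires t > 0)."""
    # x_t - x_0 = t (x_1 - x_0)  =>  x_1 - x_0 = (x_t - x_0) / t
    return (x_t - x0_pred) / t


torch.manual_seed(0)
x_0, x_1 = torch.randn(4, 2), torch.randn(4, 2)
t = torch.rand(4, 1) * 0.8 + 0.1  # keep t away from 0 and 1; shape (4, 1) broadcasts
x_t = t * x_1 + (1 - t) * x_0
u = x_1 - x_0                      # true conditional velocity
```

With perfect predictions, both conversions recover `u`. The divisions also show why each conversion becomes ill-conditioned near one end of the time interval.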
Parameters:
  • x_0 (Tensor) – source data point, shape (batch_size, …).

  • x_1 (Tensor) – target data point, shape (batch_size, …).

  • t (Tensor) – times in [0,1], shape (batch_size).
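Because `t` has shape (batch_size,) while `x_0` and `x_1` have shape (batch_size, …), an implementation of `sample` typically reshapes `t` before combining it with the data. The helper below, `expand_t_like`, is hypothetical and not taken from flow_matching. It sketches that step and validates the shapes, assuming the linear path for the final line.

```python
import torch


def expand_t_like(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reshape times of shape (batch_size,) to (batch_size, 1, ..., 1) to broadcast against x."""
    if t.dim() != 1 or t.shape[0] != x.shape[0]:
        raise ValueError(f"expected t of shape ({x.shape[0]},), got {tuple(t.shape)}")
    return t.reshape(-1, *([1] * (x.dim() - 1)))


x_0 = torch.randn(8, 3, 32, 32)    # e.g. a batch of images
x_1 = torch.randn(8, 3, 32, 32)
t = torch.rand(8)                  # one time per example
t_b = expand_t_like(t, x_1)        # shape (8, 1, 1, 1)
x_t = t_b * x_1 + (1 - t_b) * x_0  # shape (8, 3, 32, 32), linear path assumed
```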

Returns:

a conditional sample.

Return type:

PathSample
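Because ProbPath is abstract, concrete paths subclass it and implement `sample`. The sketch below shows that pattern using self-contained, hypothetical re-creations of ProbPath and PathSample; the library's actual definitions may have different fields or extra methods. It adds an illustrative `LinearPath` subclass for the linear path x_t = t * x_1 + (1 - t) * x_0, which returns the conditional velocity x_1 - x_0 as the target \(Y\).

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch


@dataclass
class PathSample:
    """Hypothetical container for the objects returned by sample()."""
    x_0: torch.Tensor
    x_1: torch.Tensor
    t: torch.Tensor
    x_t: torch.Tensor
    dx_t: torch.Tensor  # conditional target Y (here: the conditional velocity)


class ProbPath(ABC):
    """Hypothetical abstract base: subclasses must implement sample()."""

    @abstractmethod
    def sample(self, x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor) -> PathSample:
        ...


class LinearPath(ProbPath):
    """Illustrative path x_t = t * x_1 + (1 - t) * x_0."""

    def sample(self, x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor) -> PathSample:
        if x_0.shape != x_1.shape or t.shape != (x_0.shape[0],):
            raise ValueError("expected x_0, x_1 of equal shape and t of shape (batch_size,)")
        t_b = t.reshape(-1, *([1] * (x_0.dim() - 1)))  # broadcast t over data dims
        x_t = t_b * x_1 + (1 - t_b) * x_0
        return PathSample(x_0=x_0, x_1=x_1, t=t, x_t=x_t, dx_t=x_1 - x_0)


path = LinearPath()
s = path.sample(torch.zeros(4, 2), torch.ones(4, 2), torch.full((4,), 0.25))
```

Here every entry of `s.x_t` is 0.25 and every entry of `s.dx_t` is 1.0. Calling `ProbPath()` directly raises `TypeError`, because `sample` is abstract.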