MixtureDiscreteProbPath
- class flow_matching.path.MixtureDiscreteProbPath(scheduler: ConvexScheduler)
The MixtureDiscreteProbPath class defines a factorized discrete probability path. This path remains constant at the source data point \(X_0\) until a random time, determined by the scheduler, at which it flips to the target data point \(X_1\). Because the path is factorized, each coordinate (token) of \(X_t\) flips independently of the others. The scheduler determines the flip probability through the parameter \(\sigma_t\), a function of time \(t\): \(\sigma_t\) is the probability of remaining at \(X_0\), and \(1 - \sigma_t\) is the probability of having flipped to \(X_1\):
\[P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,\]where \(\sigma_t\) is provided by the scheduler.
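The sampling rule can be illustrated without PyTorch. A minimal pure-Python sketch, assuming \(X_0\) and \(X_1\) are flat lists of tokens and that \(\sigma_t\) has already been computed by some scheduler; the function name sample_mixture_path is hypothetical and not part of flow_matching:

```python
import random


def sample_mixture_path(x_0, x_1, sigma_t, rng=None):
    """Hypothetical helper: draw X_t coordinate-wise from the mixture path.

    Each coordinate keeps its source value x_0[i] with probability sigma_t
    and takes the target value x_1[i] with probability 1 - sigma_t,
    independently of every other coordinate (the path is factorized).
    """
    rng = rng if rng is not None else random.Random()
    if len(x_0) != len(x_1):
        raise ValueError("x_0 and x_1 must have the same length")
    # rng.random() lies in [0, 1), so sigma_t = 1 always keeps x_0
    # and sigma_t = 0 always picks x_1.
    return [a if rng.random() < sigma_t else b for a, b in zip(x_0, x_1)]


# Usage: 9 tokens with sigma_t = 0.9, so on average about 10% flip to x_1.
x_t = sample_mixture_path([0] * 9, [1] * 9, sigma_t=0.9, rng=random.Random(0))
print(x_t)
```

Selecting per coordinate with a single uniform draw compared against \(\sigma_t\) is the simplest way to realize the two-point distribution above; the endpoints \(\sigma_t = 1\) and \(\sigma_t = 0\) reproduce \(X_0\) and \(X_1\) exactly.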
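Example:
>>> import torch
>>> from flow_matching.path import MixtureDiscreteProbPath
>>> from flow_matching.path.scheduler import PolynomialConvexScheduler
>>> x_0 = torch.zeros((1, 3, 3))
>>> x_1 = torch.ones((1, 3, 3))
>>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
>>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
>>> result
tensor([[[0.0, 0.0, 0.0],
         [0.0, 0.0, 1.0],
         [0.0, 0.0, 0.0]]])
>>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
>>> result
tensor([[[1.0, 0.0, 1.0],
         [0.0, 1.0, 0.0],
         [0.0, 1.0, 0.0]]])
>>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
>>> result
tensor([[[1.0, 1.0, 1.0],
         [1.0, 1.0, 1.0],
         [1.0, 1.0, 1.0]]])
Because each flip is random, the tensors shown for t = 0.1 and t = 0.5 are one possible outcome; rerunning the example will generally flip different entries. At t = 1.0 every entry equals \(X_1\).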
- Parameters:
scheduler (ConvexScheduler) – The scheduler that provides \(\sigma_t\).
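The path only consumes \(\sigma_t\) from its scheduler (a velocity conversion would presumably also need its time derivative). A minimal sketch of a polynomial schedule, assuming \(\sigma_t = 1 - t^n\); we take this to be what PolynomialConvexScheduler(n) in the example computes, but that is an assumption, and the class name PolySchedule is hypothetical, not the library's ConvexScheduler interface:

```python
from dataclasses import dataclass


@dataclass
class PolySchedule:
    """Hypothetical polynomial schedule (assumed form): sigma_t = 1 - t**n."""

    n: float = 1.0

    def sigma(self, t):
        # Probability of still being at X_0 at time t; 1 at t = 0, 0 at t = 1.
        return 1.0 - t ** self.n

    def d_sigma(self, t):
        # Time derivative of sigma_t. Non-positive for n > 0, so probability
        # mass only moves from X_0 toward X_1. For n < 1 this diverges at t = 0.
        return -self.n * t ** (self.n - 1)


# Usage: with n = 2, half-way through time 75% of tokens are still at X_0.
print(PolySchedule(n=2.0).sigma(0.5))  # 0.75
```

Larger \(n\) keeps tokens at \(X_0\) longer and concentrates the flips near \(t = 1\); \(n = 1\) flips them at a constant rate.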
- sample(x_0: Tensor, x_1: Tensor, t: Tensor) → DiscretePathSample
- Sample from the mixture discrete probability path:
- Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), return \(X_0\), \(X_1\), \(t\), and \(X_t \sim p_t\).
- Parameters:
x_0 (Tensor) – source data point, shape (batch_size, …).
x_1 (Tensor) – target data point, shape (batch_size, …).
t (Tensor) – times in [0,1], shape (batch_size).
- Returns:
a conditional sample \(X_t \sim p_t\).
- Return type:
DiscretePathSample
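Because the path is factorized, sample draws each coordinate independently, so the conditional distribution of \(X_t\) given the pair \((X_0, X_1)\) can be written per coordinate \(i\), with \(\delta_a(x)\) denoting a point mass at \(a\):
\[p_t(x^i \mid X_0, X_1) = \sigma_t\,\delta_{X_0^i}(x^i) + (1 - \sigma_t)\,\delta_{X_1^i}(x^i), \qquad p_t(x \mid X_0, X_1) = \prod_i p_t(x^i \mid X_0, X_1).\]
By linearity of expectation, the expected number of coordinates of \(X_t\) that differ from \(X_0\) is \((1 - \sigma_t)\,\lvert\{i : X_0^i \neq X_1^i\}\rvert\); for example, 9 differing positions and \(\sigma_t = 0.5\) give 4.5 flips on average.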
- posterior_to_velocity(posterior_logits: Tensor, x_t: Tensor, t: Tensor) → Tensor
Convert the factorized posterior to velocity.
Given \(p(X_1|X_t)\), which in the factorized case is \(\prod_i p(X_1^i | X_t)\), return the velocity \(u_t\).
- Parameters:
posterior_logits (Tensor) – logits of the x_1 posterior conditional on x_t, shape (…, vocab size).
x_t (Tensor) – path sample at time t, shape (…).
t (Tensor) – time in [0,1].
- Returns:
velocity.
- Return type:
Tensor
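The docstring does not spell out how the posterior becomes a velocity. A common choice for mixture paths in the discrete flow matching literature is \(u_t^i(x^i \mid X_t) = \frac{\dot\kappa_t}{1-\kappa_t}\big(p(X_1^i = x^i \mid X_t) - \delta_{X_t^i}(x^i)\big)\) with \(\kappa_t = 1 - \sigma_t\). The pure-Python sketch below implements that assumed formula for a single sequence; the function name, the explicit sigma_t and d_sigma_t arguments, and the list-based layout are hypothetical stand-ins for the tensor API, not the library's implementation:

```python
import math


def softmax(logits):
    # Numerically stable softmax over one vocabulary-sized list of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def posterior_to_velocity_sketch(posterior_logits, x_t, sigma_t, d_sigma_t):
    """Hypothetical sketch of the assumed mixture-path velocity.

    posterior_logits: per-position logit lists, shape (seq_len, vocab_size).
    x_t: current token ids, shape (seq_len,).
    sigma_t, d_sigma_t: scheduler value at t and its time derivative (assumed inputs).
    Returns u_t as nested lists with shape (seq_len, vocab_size).
    """
    if sigma_t <= 0:
        raise ValueError("velocity is undefined once sigma_t reaches 0 (t = 1)")
    # With kappa_t = 1 - sigma_t: d(kappa_t)/dt / (1 - kappa_t) = -d_sigma_t / sigma_t.
    rate = -d_sigma_t / sigma_t
    velocity = []
    for logits, token in zip(posterior_logits, x_t):
        probs = softmax(logits)
        one_hot = [1.0 if v == token else 0.0 for v in range(len(probs))]
        velocity.append([rate * (p - o) for p, o in zip(probs, one_hot)])
    return velocity


# Usage: vocabulary of 3 tokens, one position currently holding token 0,
# uniform posterior, sigma_t = 0.5 and d_sigma_t = -1.0 (so rate = 2.0).
u = posterior_to_velocity_sketch([[0.0, 0.0, 0.0]], [0], sigma_t=0.5, d_sigma_t=-1.0)
print(u)  # roughly [[-1.333, 0.667, 0.667]]
```

Each row sums to zero and its off-diagonal entries are non-negative, which is what a valid transition-rate vector for a continuous-time Markov chain requires.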