ScheduleTransformedModel

class flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model: ModelWrapper, original_scheduler: Scheduler, new_scheduler: Scheduler)

Change of scheduler for a velocity model.

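This class wraps a velocity model trained with one scheduler and exposes the velocity field for a different scheduler, without retraining. Sampling then follows the time dynamics of the new scheduler, while the coupling between source and data samples learned by the original model is preserved: with an exact ODE solve, the same initial sample is mapped to the same final sample.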

Example:

import torch
from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...  # a trained velocity model (ModelWrapper) whose training path used CondOTScheduler

original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

x_0 = torch.randn([10, 2])  # Example initial condition

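x_1 = solver.sample(
    time_grid=torch.tensor([0.0, 1.0]),  # integrate from source (t=0) to data (t=1)
    x_init=x_0,
    step_size=1/1000,
)  # with return_intermediates=False (the default), only the final state is returned

Parameters: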
  • velocity_model (ModelWrapper) – The velocity model to be transformed, trained with original_scheduler.

  • original_scheduler (Scheduler) – The scheduler used to train the original model. It must implement the snr_inverse method, which forward uses to map the SNR of the new scheduler back to a time of the original scheduler.

  • new_scheduler (Scheduler) – The new scheduler to be applied to the model.
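For the original scheduler, snr_inverse maps an SNR value back to the time at which the scheduler reaches it, so the SNR \(\rho(t) = \alpha_t / \sigma_t\) has to be strictly monotone on \([0, 1)\). The scalar sketch below illustrates this for the two schedulers in the example, assuming their standard closed forms (CondOT: \(\alpha_t = t\), \(\sigma_t = 1 - t\); cosine: \(\alpha_t = \sin(\pi t / 2)\), \(\sigma_t = \cos(\pi t / 2)\)). The function names are hypothetical helpers, not part of the flow_matching API.

```python
import math

# Assumed closed forms (hypothetical helpers, not the library implementation):
#   CondOT: alpha_t = t,             sigma_t = 1 - t,             rho(t) = t / (1 - t)
#   Cosine: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2),     rho(t) = tan(pi t / 2)


def condot_snr(t: float) -> float:
    return t / (1.0 - t)


def condot_snr_inverse(snr: float) -> float:
    # Solve snr = t / (1 - t) for t.
    return snr / (1.0 + snr)


def cosine_snr(t: float) -> float:
    return math.tan(0.5 * math.pi * t)


def cosine_snr_inverse(snr: float) -> float:
    # Solve snr = tan(pi t / 2) for t in [0, 1).
    return 2.0 / math.pi * math.atan(snr)


# Round trip: snr_inverse undoes snr on [0, 1), up to floating-point rounding.
for t in (0.0, 0.25, 0.5, 0.75):
    print(t, condot_snr_inverse(condot_snr(t)), cosine_snr_inverse(cosine_snr(t)))
```

Both SNRs grow from 0 at \(t = 0\) to infinity as \(t \to 1\), which is why each inverse is well defined on \([0, \infty)\).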

forward(x: Tensor, t: Tensor, **extras) → Tensor

Compute the marginal velocity field for the new scheduler. This method implements a post-training scheduler change for affine conditional flows: it converts the marginal velocity field \(u_t(x)\), which generates the path of the original scheduler, into the marginal velocity field \(\bar{u}_r(x)\) of the new scheduler, while keeping the same data coupling. The transformation relies on the scale-time (ST) transformation between the two conditional flows:

\[\bar{X}_r = s_r X_{t_r},\]

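where \(X_t = \sigma_t X_0 + \alpha_t X_1\) and \(\bar{X}_r = \bar{\sigma}_r X_0 + \bar{\alpha}_r X_1\) are the conditional flows of the original and new schedulers, respectively, with source sample \(X_0\) and data sample \(X_1\). The ST transformation is given by: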

\[t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.\]

Here, \(\rho(t)\) is the signal-to-noise ratio (SNR) defined as:

\[\rho(t) = \frac{\alpha_t}{\sigma_t}.\]

The SNR of the new scheduler, \(\bar{\rho}(r) = \bar{\alpha}_r / \bar{\sigma}_r\), is defined in the same way. The marginal velocity for the new scheduler is then given by:

\[\bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).\]
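The formula needs \(\dot{t}_r\) and \(\dot{s}_r\), which follow from the definitions above: differentiating \(\rho(t_r) = \bar{\rho}(r)\) with respect to \(r\) gives \(\dot{t}_r\), and the quotient rule on \(s_r = \bar{\sigma}_r / \sigma_{t_r}\) gives \(\dot{s}_r\). The second part works these out for the CondOT-to-cosine change in the class example, assuming \(\alpha_t = t\), \(\sigma_t = 1 - t\) and \(\bar{\alpha}_r = \sin(\pi r / 2)\), \(\bar{\sigma}_r = \cos(\pi r / 2)\); it is an illustration under those assumptions, not part of the library documentation.

```latex
\dot{\rho}(t) = \frac{\dot{\alpha}_t \sigma_t - \alpha_t \dot{\sigma}_t}{\sigma_t^2},
\qquad
\dot{t}_r = \frac{\dot{\bar{\rho}}(r)}{\dot{\rho}(t_r)}
          = \frac{\sigma_{t_r}^2 \left(\bar{\sigma}_r \dot{\bar{\alpha}}_r - \bar{\alpha}_r \dot{\bar{\sigma}}_r\right)}
                 {\bar{\sigma}_r^2 \left(\sigma_{t_r} \dot{\alpha}_{t_r} - \alpha_{t_r} \dot{\sigma}_{t_r}\right)},
\qquad
\dot{s}_r = \frac{\dot{\bar{\sigma}}_r \, \sigma_{t_r} - \bar{\sigma}_r \, \dot{\sigma}_{t_r} \, \dot{t}_r}{\sigma_{t_r}^2}.

% CondOT -> cosine, writing S = \sin(\pi r / 2) and C = \cos(\pi r / 2):
t_r = \frac{S}{S + C}, \qquad s_r = S + C, \qquad
\dot{t}_r = \frac{\pi / 2}{(S + C)^2}, \qquad \dot{s}_r = \frac{\pi}{2}\,(C - S).
```

At \(r = 0\) and \(r = 1\) this gives \(t_r = r\) and \(s_r = 1\), so both paths start and end at the same points, which is why the data coupling is unchanged.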

Parameters:
  • x (Tensor) – \(x\), the input tensor at time \(r\), i.e. the point at which \(\bar{u}_r\) is evaluated.

  • t (Tensor) – The time tensor (denoted as \(r\) above).

  • **extras – Additional keyword arguments passed to the wrapped velocity model.

Returns:

The transformed velocity \(\bar{u}_r(x)\).

Return type:

Tensor
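The computation described for forward can be mirrored on scalars to check the formulas. The sketch below is a self-contained re-derivation under stated assumptions, not the library source: schedulers are hypothetical functions returning (alpha, sigma, d_alpha, d_sigma) with the assumed CondOT and cosine closed forms, and the "trained model" is replaced by the exact marginal velocity of a CondOT path from N(0, 1) to a 1-D Gaussian N(m, v), chosen because its flow map \(x_0 \mapsto m + \sqrt{v}\, x_0\) is known in closed form.

```python
import math


# Hypothetical schedulers: each returns (alpha, sigma, d_alpha, d_sigma) at time t.
def condot(t):
    return t, 1.0 - t, 1.0, -1.0


def condot_snr_inverse(snr):
    return snr / (1.0 + snr)  # inverse of rho(t) = t / (1 - t)


def cosine(t):
    a = 0.5 * math.pi
    return math.sin(a * t), math.cos(a * t), a * math.cos(a * t), -a * math.sin(a * t)


def st_transform(r, orig, orig_snr_inverse, new):
    """Return (t_r, s_r, dt_r/dr, ds_r/dr); requires sigma_r > 0, i.e. r < 1."""
    alpha_r, sigma_r, d_alpha_r, d_sigma_r = new(r)
    t = orig_snr_inverse(alpha_r / sigma_r)  # t_r = rho^{-1}(bar_rho(r))
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = orig(t)
    s = sigma_r / sigma_t  # s_r = bar_sigma_r / sigma_{t_r}
    # dt_r/dr = bar_rho'(r) / rho'(t_r)
    dt = (sigma_t ** 2 * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)) / (
        sigma_r ** 2 * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
    )
    # Quotient rule on s_r, with the chain rule for sigma_{t_r}.
    ds = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt) / sigma_t ** 2
    return t, s, dt, ds


def transformed_velocity(u, x, r, orig, orig_snr_inverse, new):
    """bar_u_r(x) = (ds_r / s_r) x + s_r dt_r u_{t_r}(x / s_r)."""
    t, s, dt, ds = st_transform(r, orig, orig_snr_inverse, new)
    return (ds / s) * x + s * dt * u(x / s, t)


def gaussian_condot_velocity(m, v):
    """Exact marginal CondOT velocity from N(0, 1) to N(m, v); stands in for a trained model."""
    def u(x, t):
        mean = t * m
        std = math.sqrt((1.0 - t) ** 2 + t ** 2 * v)
        d_std = (-(1.0 - t) + t * v) / std
        return m + (d_std / std) * (x - mean)
    return u


def integrate_midpoint(f, x, n_steps):
    """Midpoint rule for dx/dr = f(x, r) on [0, 1]; never evaluates f at r = 1."""
    h = 1.0 / n_steps
    for i in range(n_steps):
        r = i * h
        x_mid = x + 0.5 * h * f(x, r)
        x = x + h * f(x_mid, r + 0.5 * h)
    return x


m, v, x0 = 2.0, 0.25, 0.7
u = gaussian_condot_velocity(m, v)


def u_bar(x, r):
    return transformed_velocity(u, x, r, condot, condot_snr_inverse, cosine)


x1_orig = integrate_midpoint(u, x0, 1000)
x1_new = integrate_midpoint(u_bar, x0, 1000)
print(x1_orig, x1_new, m + math.sqrt(v) * x0)  # all three should agree up to solver error
```

The midpoint rule is used because it never evaluates the velocity at \(r = 1\), where \(\bar{\sigma}_r = 0\) makes the SNR infinite. The two trajectories differ in between, but both end at approximately \(m + \sqrt{v}\, x_0\), illustrating that the scheduler change alters the path but not the coupling.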