ScheduleTransformedModel
- class flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model: ModelWrapper, original_scheduler: Scheduler, new_scheduler: Scheduler)
Change of scheduler for a velocity model.
This class wraps a trained velocity model and re-expresses it under a new scheduler. It changes the model's time dynamics to follow the new scheduler while preserving the data coupling learned by the original model, so the scheduler can be swapped after training.
Example:
```python
import torch

from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...  # the trained velocity model (a ModelWrapper)
original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler,
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

x_0 = torch.randn([10, 2])  # Example initial condition
# With return_intermediates=False (the default), sample returns the final state at t = 1.
x_1 = solver.sample(
    x_init=x_0,
    step_size=1 / 1000,
    time_grid=torch.tensor([0.0, 1.0]),
)
```
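For the two schedulers in the example, the scale-time quantities \(t_r\) and \(s_r\) defined under forward have closed forms. The worked sketch below assumes that CondOTScheduler uses \(\alpha_t = t,\ \sigma_t = 1 - t\) and that CosineScheduler uses \(\bar{\alpha}_r = \sin(\pi r/2),\ \bar{\sigma}_r = \cos(\pi r/2)\); check these coefficients against the scheduler sources before relying on them.

\[\begin{aligned}
\rho(t) &= \frac{t}{1-t}, \qquad \rho^{-1}(y) = \frac{y}{1+y}, \qquad \bar{\rho}(r) = \tan\left(\frac{\pi r}{2}\right),\\
t_r &= \rho^{-1}(\bar{\rho}(r)) = \frac{\sin(\pi r/2)}{\sin(\pi r/2) + \cos(\pi r/2)},\\
s_r &= \frac{\bar{\sigma}_r}{\sigma_{t_r}} = \frac{\cos(\pi r/2)}{1 - t_r} = \sin(\pi r/2) + \cos(\pi r/2),\\
\dot{t}_r &= \frac{\pi/2}{\left(\sin(\pi r/2) + \cos(\pi r/2)\right)^2}, \qquad \dot{s}_r = \frac{\pi}{2}\left(\cos(\pi r/2) - \sin(\pi r/2)\right).
\end{aligned}\]

Under these assumptions \(t_0 = 0\), \(t_1 = 1\) and \(s_0 = s_1 = 1\), so the transformed model still starts at the source distribution at \(r = 0\) and ends at the data at \(r = 1\), which is why the example integrates over the same time grid \([0, 1]\).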
- Parameters:
velocity_model (ModelWrapper) – The original velocity model to be transformed.
original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse method, the inverse of the signal-to-noise ratio \(\rho(t) = \alpha_t / \sigma_t\).
new_scheduler (Scheduler) – The new scheduler to be applied to the model.
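The snr_inverse requirement means the original scheduler must map an SNR value \(y\) back to the time \(t\) with \(\rho(t) = y\). The sketch below illustrates that contract for the CondOT path, assuming \(\alpha_t = t,\ \sigma_t = 1 - t\), and shows one possible numerical fallback by bisection for a strictly increasing SNR. The helper names (snr_inverse_bisect, cond_ot_snr, cond_ot_snr_inverse) are hypothetical and are not part of flow_matching.

```python
import math


def snr_inverse_bisect(snr, y, lo=0.0, hi=1.0, iters=100):
    """Return t in [lo, hi] with snr(t) == y, assuming snr is strictly increasing there."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if snr(mid) < y:
            lo = mid  # the root lies above mid
        else:
            hi = mid  # the root lies at or below mid
    return 0.5 * (lo + hi)


# Assumed CondOT path: alpha_t = t and sigma_t = 1 - t, so rho(t) = t / (1 - t).
def cond_ot_snr(t):
    return t / (1.0 - t) if t < 1.0 else math.inf


def cond_ot_snr_inverse(y):
    return y / (1.0 + y)  # closed-form inverse of t / (1 - t)


# The numerical and closed-form inverses should agree.
print(snr_inverse_bisect(cond_ot_snr, 5.0), cond_ot_snr_inverse(5.0))
```

A closed-form inverse, where one exists, is cheaper and exact; bisection is only a generic option when \(\rho\) is monotone but hard to invert analytically.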
- forward(x: Tensor, t: Tensor, **extras) → Tensor
Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training scheduler change for affine conditional flows. It maps the marginal velocity field \(u_t(x)\), which generates the flow under the original scheduler, to a new marginal velocity field \(\bar{u}_r(x)\) for a different scheduler, while keeping the same data coupling. The transformation is based on the scale-time (ST) transformation between the two conditional flows, defined as:
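\[\bar{X}_r = s_r X_{t_r},\]
where \(X_t = \alpha_t X_1 + \sigma_t X_0\) and \(\bar{X}_r = \bar{\alpha}_r X_1 + \bar{\sigma}_r X_0\) are the conditional flows defined by the original and new schedulers, respectively. The ST transformation is computed as:
\[t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.\]
Here, \(\rho(t)\) is the signal-to-noise ratio (SNR) defined as:
\[\rho(t) = \frac{\alpha_t}{\sigma_t},\]
and \(\bar{\rho}(r)\) is defined in the same way for the new scheduler. Because scaling by \(s_r\) leaves the SNR unchanged, \(t_r\) is the time at which the original flow has the same SNR as the new flow at time \(r\), and \(s_r\) then matches the noise coefficients. Differentiating \(\bar{X}_r = s_r X_{t_r}\) with respect to \(r\) gives the marginal velocity for the new scheduler:
\[\bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).\]
- Parameters: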
x (Tensor) – The input tensor \(x\) at which the transformed velocity \(\bar{u}_r\) is evaluated.
t (Tensor) – The time tensor (denoted as \(r\) above).
**extras – Additional arguments for the model.
- Returns:
The transformed velocity.
- Return type:
Tensor
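The formulas above can be checked numerically. The sketch below is a pure-Python illustration, not the flow_matching implementation: it uses a hypothetical AffineSchedule container with hand-written derivatives in place of the library's Scheduler classes, and it assumes CondOT coefficients \(\alpha_t = t,\ \sigma_t = 1 - t\) and cosine coefficients \(\bar{\alpha}_r = \sin(\pi r/2),\ \bar{\sigma}_r = \cos(\pi r/2)\). It obtains \(\dot{t}_r\) by differentiating \(\rho(t_r) = \bar{\rho}(r)\) and \(\dot{s}_r\) by the quotient rule. As a check, it uses a point-mass data distribution, for which the exact marginal velocity is known under either scheduler, so the transformed CondOT velocity should match the cosine velocity.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class AffineSchedule:
    """Hypothetical stand-in for a scheduler: X_t = alpha(t) * X_1 + sigma(t) * X_0."""

    alpha: Callable[[float], float]
    sigma: Callable[[float], float]
    d_alpha: Callable[[float], float]
    d_sigma: Callable[[float], float]
    snr_inverse: Callable[[float], float]

    def snr(self, t: float) -> float:
        return self.alpha(t) / self.sigma(t)

    def d_snr(self, t: float) -> float:
        # Quotient rule applied to alpha / sigma.
        a, s = self.alpha(t), self.sigma(t)
        return (self.d_alpha(t) * s - a * self.d_sigma(t)) / s**2


HALF_PI = math.pi / 2

cond_ot = AffineSchedule(
    alpha=lambda t: t,
    sigma=lambda t: 1.0 - t,
    d_alpha=lambda t: 1.0,
    d_sigma=lambda t: -1.0,
    snr_inverse=lambda y: y / (1.0 + y),  # inverts t / (1 - t)
)

cosine = AffineSchedule(
    alpha=lambda t: math.sin(HALF_PI * t),
    sigma=lambda t: math.cos(HALF_PI * t),
    d_alpha=lambda t: HALF_PI * math.cos(HALF_PI * t),
    d_sigma=lambda t: -HALF_PI * math.sin(HALF_PI * t),
    snr_inverse=lambda y: math.atan(y) / HALF_PI,  # inverts tan(pi * t / 2)
)


def st_transform(orig, new, r):
    """Return (t_r, s_r, dt_r, ds_r) for the scale-time transformation."""
    t_r = orig.snr_inverse(new.snr(r))
    sig_t = orig.sigma(t_r)
    s_r = new.sigma(r) / sig_t
    # Differentiating rho(t_r) = rho_bar(r) gives rho'(t_r) * dt_r = rho_bar'(r).
    dt_r = new.d_snr(r) / orig.d_snr(t_r)
    # Quotient rule on s_r = sigma_bar(r) / sigma(t_r), with the chain rule for t_r.
    ds_r = (new.d_sigma(r) * sig_t - new.sigma(r) * orig.d_sigma(t_r) * dt_r) / sig_t**2
    return t_r, s_r, dt_r, ds_r


def transformed_velocity(u, orig, new, x, r):
    """u_bar_r(x) = (ds_r / s_r) * x + s_r * dt_r * u(x / s_r, t_r)."""
    t_r, s_r, dt_r, ds_r = st_transform(orig, new, r)
    return (ds_r / s_r) * x + s_r * dt_r * u(x / s_r, t_r)


def point_mass_velocity(sched, x1):
    """Exact marginal velocity u(x, t) when the data distribution is a point mass at x1."""

    def u(x, t):
        x0 = (x - sched.alpha(t) * x1) / sched.sigma(t)  # the noise sample that reaches x
        return sched.d_alpha(t) * x1 + sched.d_sigma(t) * x0

    return u


# Transform the CondOT velocity to the cosine scheduler and compare with the exact answer.
u_ot = point_mass_velocity(cond_ot, x1=2.0)
u_cos = point_mass_velocity(cosine, x1=2.0)
r, x = 0.5, 0.3
print(transformed_velocity(u_ot, cond_ot, cosine, x, r), u_cos(x, r))
```

Here x and t are plain floats because the transformation acts elementwise; the forward method described above instead takes batched Tensors and a wrapped model.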