ModelWrapper

class flow_matching.utils.model_wrapper.ModelWrapper(model: Module)

This class wraps another model so that custom logic can be added to its forward pass.
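A minimal sketch of the wrapping pattern, assuming PyTorch. To stay self-contained it defines a local stand-in named `ModelWrapper` that mirrors the documented behavior (hold the model, delegate `forward` to it) instead of importing `flow_matching.utils.model_wrapper.ModelWrapper`. The keyword-style call `self.model(x=x, t=t, **extras)` and the toy model `ToyVelocity` are assumptions made for illustration, not the library's source.

```python
import torch
from torch import Tensor, nn


class ModelWrapper(nn.Module):
    """Local stand-in mirroring the documented ModelWrapper (not the library source)."""

    def __init__(self, model: nn.Module):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule,
        # so its parameters show up in wrapper.parameters().
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        # Assumed default behavior: pass x, t and any extras straight through.
        return self.model(x=x, t=t, **extras)


class ToyVelocity(nn.Module):
    """Hypothetical model that takes both x and t, as the wrapper assumes."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Linear(dim + 1, dim)

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        # Append t as an extra feature: (batch, dim) + (batch, 1) -> (batch, dim + 1).
        return self.net(torch.cat([x, t.unsqueeze(-1)], dim=-1))


torch.manual_seed(0)
model = ToyVelocity(dim=2)
wrapper = ModelWrapper(model)
x = torch.randn(4, 2)
t = torch.rand(4)
out = wrapper(x, t)
print(out.shape)  # torch.Size([4, 2])
```

In this sketch the wrapper is itself an `nn.Module` that holds the model as a submodule, so the wrapped parameters stay visible to `.parameters()`, optimizers and `.to(device)`.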

forward(x: Tensor, t: Tensor, **extras) → Tensor

This method defines how inputs are passed through the wrapped model. It assumes that the wrapped model takes both x and t as inputs, along with any additional keyword arguments.

Optional steps in this method:
  • check that t has the dimensions the model expects.

  • add custom forward-pass logic.

  • call the wrapped model.
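The optional steps listed above can be sketched as a subclass that overrides `forward`. This is an illustration under assumptions, not library code: `BatchedTimeWrapper` and `ToyVelocity` are hypothetical names, the base class is a local stand-in for the documented `ModelWrapper`, and the shape the model "expects" for t is assumed to be `(batch_size,)`, matching the parameter description of t.

```python
import torch
from torch import Tensor, nn


class ModelWrapper(nn.Module):
    """Local stand-in for the documented base class (not the library source)."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        return self.model(x=x, t=t, **extras)


class BatchedTimeWrapper(ModelWrapper):
    """Hypothetical subclass that makes t match the (batch_size,) shape the model expects."""

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        # 1. Check the shape of t: a 0-dim time becomes one value per sample.
        if t.dim() == 0:
            t = t.expand(x.shape[0])
        if t.shape != (x.shape[0],):
            raise ValueError(f"expected t of shape ({x.shape[0]},), got {tuple(t.shape)}")
        # 2. Custom forward-pass logic would go here (none in this sketch).
        # 3. Call the wrapped model.
        return self.model(x=x, t=t, **extras)


class ToyVelocity(nn.Module):
    """Hypothetical model that needs t of shape (batch,) to broadcast against x."""

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        return x * t[:, None]


wrapper = BatchedTimeWrapper(ToyVelocity())
x = torch.ones(3, 2)
out = wrapper(x, torch.tensor(0.5))  # a single time value shared by the whole batch
print(out)
```

Without the shape check, the scalar time would fail inside `ToyVelocity`, because `t[:, None]` cannot index a 0-dim tensor.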

Given x and t, returns the model output for input x at time t, conditioned on any additional information passed in extras.
Parameters:
  • x (Tensor) – input data to the model, of shape (batch_size, …).

  • t (Tensor) – time, of shape (batch_size).

  • **extras – additional information forwarded to the model, e.g., a text condition.

Returns:

model output.

Return type:

Tensor
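To illustrate the `**extras` parameter, the sketch below forwards a conditioning tensor through the wrapper. `ToyConditionalModel` and its `cond` keyword are hypothetical stand-ins for something like a text condition, and `ModelWrapper` is a local stand-in that mirrors the documented forwarding behavior rather than the library class itself.

```python
import torch
from torch import Tensor, nn


class ModelWrapper(nn.Module):
    """Local stand-in for the documented base class (not the library source)."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        # Any keyword arguments beyond x and t are forwarded unchanged.
        return self.model(x=x, t=t, **extras)


class ToyConditionalModel(nn.Module):
    """Hypothetical model with an extra `cond` input (e.g. a text embedding)."""

    def forward(self, x: Tensor, t: Tensor, cond: Tensor) -> Tensor:
        # x: (batch, dim), t: (batch,), cond: (batch, dim) -> output: (batch, dim)
        return x + t[:, None] * cond


wrapper = ModelWrapper(ToyConditionalModel())
x = torch.zeros(2, 3)
t = torch.tensor([0.0, 1.0])
cond = torch.ones(2, 3)
out = wrapper(x=x, t=t, cond=cond)  # `cond` travels through **extras
print(out)  # tensor([[0., 0., 0.], [1., 1., 1.]])
```

The wrapper never needs to know the names of the extra inputs; it only has to pass them on, which keeps it reusable across conditional and unconditional models.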