neuraltrain.losses.losses.MultiLoss

class neuraltrain.losses.losses.MultiLoss(losses: dict[str, Module], weights: dict[str, float] | None = None)

Weighted combination of multiple loss terms.

Can be parametrized through MultiLossConfig.

Parameters:
  • losses – Loss terms to combine, keyed by name.

  • weights – Weight associated with each loss term, keyed by the same names as losses. If None, each term gets a weight of 1.
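The description above says only "weighted combination". Assuming the combination is a weighted sum, which is an assumption and not stated by the source, the overall objective for loss terms L_k with weights w_k would be:

```latex
\mathcal{L}_{\text{total}}(x, y) = \sum_{k \in \texttt{losses}} w_k \, \mathcal{L}_k(x, y),
\qquad w_k = 1 \ \text{for every } k \text{ when } \texttt{weights} \text{ is None}
```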

forward(x: Tensor | dict[str, Tensor], y: Tensor | dict[str, Tensor]) → dict[str, Tensor]

Evaluate each loss term.

Parameters:
  • x – Input. If provided as a dictionary, its keys must match the loss names passed in losses at instantiation.

  • y – Target. If provided as a dictionary, its keys must match the loss names passed in losses at instantiation.
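The following is a minimal sketch of how this kind of class can dispatch inputs. It is not the neuraltrain implementation. `MultiLossSketch`, `mse` and `mae` are hypothetical names. Three behaviors are assumptions: a non-dict `x` or `y` is shared by every term, each returned term is already multiplied by its weight, and the caller sums the terms. Plain callables and floats stand in for `torch.nn.Module` and `Tensor`, so the sketch runs without PyTorch.

```python
from __future__ import annotations

from typing import Callable, Mapping, Union

# Hypothetical stand-ins: in neuraltrain the terms are torch.nn.Module
# instances and the values are torch.Tensor objects.
LossFn = Callable[[object, object], float]
Batch = Union[object, Mapping[str, object]]


class MultiLossSketch:
    """Weighted multi-term loss sketch (illustrative, not the library code)."""

    def __init__(self, losses: Mapping[str, LossFn],
                 weights: Mapping[str, float] | None = None):
        self.losses = dict(losses)
        # Documented default: weight 1 for every term when weights is None.
        if weights is None:
            self.weights = {name: 1.0 for name in self.losses}
        else:
            self.weights = dict(weights)
        # Design choice of this sketch: weights must cover exactly the losses.
        if set(self.weights) != set(self.losses):
            raise ValueError("weights must have the same keys as losses")

    @staticmethod
    def _select(value: Batch, name: str) -> object:
        # A dict is looked up by loss name; anything else is assumed
        # to be shared by every loss term.
        if isinstance(value, Mapping):
            return value[name]
        return value

    def forward(self, x: Batch, y: Batch) -> dict[str, float]:
        # Assumption: each returned term is already scaled by its weight.
        return {
            name: self.weights[name] * fn(self._select(x, name), self._select(y, name))
            for name, fn in self.losses.items()
        }


def mse(a, b):
    # Mean squared error over two equal-length sequences.
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


def mae(a, b):
    # Mean absolute error over two equal-length sequences.
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


loss = MultiLossSketch({"mse": mse, "mae": mae}, weights={"mse": 1.0, "mae": 0.5})

# Shared input and target for both terms.
terms = loss.forward([1.0, 2.0], [0.0, 4.0])  # {"mse": 2.5, "mae": 0.75}
total = sum(terms.values())                   # 3.25

# Per-term inputs and targets, keyed by loss name.
per_term = loss.forward({"mse": [1.0], "mae": [2.0]},
                        {"mse": [0.0], "mae": [0.0]})  # {"mse": 1.0, "mae": 1.0}
```

Returning the individual terms rather than only their sum lets each term be logged separately. Whether neuraltrain also adds a combined entry to the returned dictionary is not stated in this reference.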