neuraltrain.losses.losses.MultiLoss
- class neuraltrain.losses.losses.MultiLoss(losses: dict[str, Module], weights: dict[str, float] | None = None)
Weighted combination of multiple loss terms.
Can be parametrized through MultiLossConfig.
- Parameters:
losses – Loss terms to combine, keyed by loss name.
weights – Weight of each loss term, keyed by loss name. If None, every term gets a weight of 1.
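The weighting described above can be illustrated with a minimal sketch of the combination rule, total = Σₖ wₖ·Lₖ. This is not the neuraltrain implementation: the helper names `resolve_weights` and `weighted_total` are hypothetical, plain floats stand in for tensors, and the check that weight keys exactly match loss keys is an assumption, since the reference above does not say whether partial weight dicts are accepted.

```python
from typing import Callable, Dict, Optional


def resolve_weights(
    losses: Dict[str, Callable[..., float]],
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Return one weight per loss name, defaulting to 1.0 when weights is None."""
    if weights is None:
        return {name: 1.0 for name in losses}
    # Assumption (hypothetical): every loss needs a weight and no extra keys are allowed.
    if set(weights) != set(losses):
        raise ValueError(
            f"weight keys {sorted(weights)} do not match loss keys {sorted(losses)}"
        )
    return dict(weights)


def weighted_total(terms: Dict[str, float], weights: Dict[str, float]) -> float:
    """Combine per-term loss values as sum_k w_k * L_k."""
    return sum(weights[name] * value for name, value in terms.items())


# Usage with two hypothetical loss terms already evaluated to floats.
losses = {"mse": lambda a, b: (a - b) ** 2, "l1": lambda a, b: abs(a - b)}
terms = {"mse": 0.5, "l1": 2.0}
print(weighted_total(terms, resolve_weights(losses)))  # unit weights: 0.5 + 2.0
print(weighted_total(terms, resolve_weights(losses, {"mse": 2.0, "l1": 0.25})))
```

With `weights=None` every term contributes equally, which matches the default described for the `weights` parameter.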
- forward(x: Tensor | dict[str, Tensor], y: Tensor | dict[str, Tensor]) → dict[str, Tensor]
Evaluate each loss term on the given inputs and targets.
- Parameters:
x – Input. If provided as a dictionary, its keys must match the loss names passed to the constructor.
y – Target. If provided as a dictionary, its keys must match the loss names passed to the constructor.
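How a dict-valued `x` or `y` could be routed to individual terms can be sketched as follows. This is a hypothetical re-implementation, not neuraltrain code: `evaluate_terms` is an illustrative name, plain floats stand in for tensors, and three behaviours are assumptions not stated in the reference above: a non-dict argument is shared by every term, each returned value is already weighted, and a `"total"` entry holds the weighted sum.

```python
from typing import Any, Callable, Dict, Mapping, Union


def evaluate_terms(
    losses: Mapping[str, Callable[[Any, Any], float]],
    weights: Mapping[str, float],
    x: Union[Any, Mapping[str, Any]],
    y: Union[Any, Mapping[str, Any]],
) -> Dict[str, float]:
    """Evaluate every loss term, indexing dict arguments by loss name."""
    results: Dict[str, float] = {}
    for name, loss_fn in losses.items():
        # A dict argument is looked up by loss name; anything else is
        # (by assumption) passed unchanged to every term.
        x_k = x[name] if isinstance(x, Mapping) else x
        y_k = y[name] if isinstance(y, Mapping) else y
        results[name] = weights[name] * loss_fn(x_k, y_k)
    # Hypothetical "total" key: the real return dict may be structured differently.
    results["total"] = sum(results.values())
    return results


losses = {"sq": lambda a, b: (a - b) ** 2, "abs": lambda a, b: abs(a - b)}
weights = {"sq": 1.0, "abs": 0.5}
print(evaluate_terms(losses, weights, 3.0, 1.0))                         # shared x and y
print(evaluate_terms(losses, weights, {"sq": 2.0, "abs": 5.0}, 1.0))     # per-term x
```

Keying the per-term arguments by loss name lets terms that consume different model outputs (for example, different heads) share one `forward` call.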