# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loss functions.
Ensure each new loss function has its own Pydantic configuration object under `losses/config.py` so
that a loss function can be defined in an experiment configuration.
"""
import torch
import torch.nn.functional as F
from torch import nn
[docs]
class ClipLoss(nn.Module):
"""CLIP constrastive loss.
Contrastive Language-Image Pretraining (CLIP) loss from [1]_. Default values reflect the
configuration of the CLIP loss used in [2]_.
Parameters
----------
norm_kind : {"x", "y", "xy"} or None
How to normalize the estimates and/or candidates before computing their dot products.
``'x'``: normalize estimates only.
``'y'``: normalize candidates only (approach originally used in brainmagick).
``'xy'``: normalize both estimates and candidates.
``None``: do not normalize.
temperature : bool
If True, use a learnable temperature parameter.
symmetric : bool
If True, compute loss in both retrieval directions, i.e. retrieve candidates given
estimates and retrieve estimates given candidates (requires estimates and candidates to be
of the same shape). If False, only do the former.
reduction : str
Reduction applied to the per-example cross-entropy loss (forwarded to
``F.cross_entropy``).
References
----------
.. [1] Radford, Alec, et al. "Learning transferable visual models from natural language
supervision." International conference on machine learning. PMLR, 2021.
.. [2] Défossez, Alexandre, et al. "Decoding speech perception from non-invasive brain
recordings." Nature Machine Intelligence (2023): 1-11.
"""
def __init__(
self,
norm_kind: str | None = "y",
temperature: bool = True,
symmetric: bool = True,
reduction: str = "mean",
):
super().__init__()
self.norm_kind = norm_kind
self.temperature = (
nn.Parameter(torch.tensor(1 / 0.07).log())
if temperature
else nn.Parameter(torch.tensor(0.0), requires_grad=False)
)
self.symmetric = symmetric
self.reduction = reduction
@staticmethod
def _compute_similarity(
x: torch.Tensor, y: torch.Tensor, norm: str | None = None, eps=1e-15
) -> torch.Tensor:
if norm is None:
eq, inv_norms = "b", torch.ones(x.shape[0])
elif norm == "x":
eq, inv_norms = "b", 1 / (eps + x.norm(dim=(1), p=2))
elif norm == "y":
eq, inv_norms = "o", 1 / (eps + y.norm(dim=(1), p=2))
elif norm == "xy":
eq = "bo"
inv_norms = 1 / (
eps + torch.outer(x.norm(dim=(1), p=2), y.norm(dim=(1), p=2))
)
else:
raise ValueError(f"norm must be None, x, y or xy, got {norm}.")
# Normalize inside einsum to avoid creating a copy of candidates which can be pretty big
return torch.einsum(f"bc,oc,{eq}->bo", x, y, inv_norms)
[docs]
def get_scores(self, estimate: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
"""Given estimates of shape [B, F] and candidates of shape [B', F], return a [B, B'] matrix
of similarity scores.
"""
scores = self._compute_similarity(estimate, candidate, norm=self.norm_kind)
scores = self.temperature.exp() * scores
return scores
[docs]
def get_probabilities(
self, estimate: torch.Tensor, candidate: torch.Tensor
) -> torch.Tensor:
"""Given estimates of shape [B, F] and candidates of shape [B', F], return a [B, B'] matrix
of matching probability.
"""
scores = self.get_scores(estimate, candidate)
return F.softmax(scores, dim=1)
[docs]
def forward(self, estimate: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
"""Warning: estimate and candidate are not necessarily symmetrical.
If estimate of shape [B, C] and candidate of shape [B', C] with B'>=B, the first B samples
of candidate are targets, while the remaining B'-B samples of candidate are only used as
negatives.
"""
n_est = estimate.size(0)
n_cand = candidate.size(0)
if n_est > n_cand:
raise ValueError(f"need candidates >= estimates, got {n_cand} vs {n_est}")
scores = self.get_scores(estimate, candidate)
target = torch.arange(len(scores), device=estimate.device)
loss_e = F.cross_entropy(scores, target, reduction=self.reduction)
if self.symmetric:
if scores.shape[0] != scores.shape[1]:
raise ValueError(f"need square scores, got {scores.shape}")
loss_c = F.cross_entropy(
scores.transpose(1, 0), target, reduction=self.reduction
)
loss = (loss_e + loss_c) / 2
else:
loss = loss_e
return loss
[docs]
class SigLipLoss(ClipLoss):
"""SigLIP contrastive loss.
Sigmoid loss for Language-Image Pretraining (SigLIP) from [1]_.
Parameters
----------
norm_kind : {"x", "y", "xy"} or None
How to normalize the estimates and/or candidates before computing their dot products.
``'x'``: normalize estimates only.
``'y'``: normalize candidates only (approach originally used in brainmagick).
``'xy'``: normalize both estimates and candidates.
``None``: do not normalize.
temperature : bool
If True, use a learnable temperature parameter initialized to ``ln(10)``.
bias : bool
If True, use a learnable bias parameter initialized to ``-10`` (since
most pairs are negative).
identical_candidates_threshold : float or None
If given, estimates are matched not only to their candidate, but all candidates
that have a large cosine similarity to their candidate (larger or equal this threshold).
Assumes such other candidates with high cosine similarity are duplicates.
Intended to use only if candidate generator is frozen.
reduction : str
Reduction applied to the binary cross-entropy loss (forwarded to
``F.binary_cross_entropy_with_logits``).
reweigh_positives : bool
If True and *identical_candidates_threshold* is set, down-weight
duplicate positive pairs so only one copy contributes to the loss.
References
----------
.. [1] Zhai, Xiaohua, et al. "Sigmoid loss for language image pre-training." arXiv preprint
arXiv:2303.15343 (2023).
Note
----
Official jax implementation: https://github.com/google-research/big_vision/blob/474dd2ebde37268db4ea44decef14c7c1f6a0258/big_vision/trainers/proj/image_text/siglip.py
"""
def __init__(
self,
norm_kind: str | None = "y",
temperature: bool = True,
bias: bool = True,
identical_candidates_threshold: float | None = 0.999,
reduction: str = "sum",
reweigh_positives: bool = False,
):
super().__init__(
norm_kind=norm_kind, temperature=False, symmetric=True, reduction=reduction
)
self.temperature = (
nn.Parameter(torch.tensor(10).log())
if temperature
else nn.Parameter(torch.tensor(0.0), requires_grad=False)
)
self.bias = (
nn.Parameter(torch.tensor(-10.0))
if bias
else nn.Parameter(torch.tensor(0.0), requires_grad=False)
)
self.identical_candidates_threshold = identical_candidates_threshold
self.reweigh_positives = reweigh_positives
[docs]
def get_scores(self, estimate: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
"""Given estimates of shape [B, F] and candidates of shape [B', F], return a [B, B'] matrix
of similarity scores.
"""
return super().get_scores(estimate, candidate) + self.bias
[docs]
def forward(self, estimate: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
n_est = estimate.size(0)
n_cand = candidate.size(0)
if n_est > n_cand:
raise ValueError(f"need candidates >= estimates, got {n_cand} vs {n_est}")
scores = self.get_scores(estimate, candidate)
if self.identical_candidates_threshold is not None:
candidate_sim = self._compute_similarity(
candidate, candidate, "xy", eps=1e-15
)
targets = 1.0 * (candidate_sim >= self.identical_candidates_threshold)
targets = targets[: len(estimate)]
if self.reweigh_positives:
weights = 1.0 * (candidate_sim >= self.identical_candidates_threshold)
weights = 1 - weights # remove all duplicates
weights += torch.eye(
*weights.shape, device=weights.device
) # keep only one
else:
weights = None
else:
weights = None
targets = torch.eye(*scores.shape, device=scores.device)
loss = F.binary_cross_entropy_with_logits(
scores, targets, weights, reduction=self.reduction
)
return loss
[docs]
class MultiLoss(nn.Module):
"""Weighted combination of multiple loss terms.
Can be parametrized through MultiLossConfig.
Parameters
----------
losses :
Different loss terms.
weights :
Weights associated with each loss term. If None, use weight of 1 for each term.
"""
def __init__(
self,
losses: dict[str, nn.Module],
weights: dict[str, float] | None = None,
):
super().__init__()
self.losses = nn.ModuleDict(losses)
if weights is None:
weights = {name: 1.0 for name in losses}
self.weights = weights
assert len(losses) == len(weights)
[docs]
def forward(
self,
x: torch.Tensor | dict[str, torch.Tensor],
y: torch.Tensor | dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""Evaluate the different loss terms.
Parameters
----------
input, target :
If provided as a dictionary, the keys must match the loss names provided when
the class was instantiated.
"""
loss_values = {"total": torch.tensor(0.0)}
for name in self.losses:
_x = x[name] if isinstance(x, dict) else x
_y = y[name] if isinstance(y, dict) else y
loss_values[name] = self.losses[name](_x, _y)
loss_values["total"] += self.weights[name] * loss_values[name]
return loss_values