Module audiocraft.losses.wmloss
Classes
class WMDetectionLoss (p_weight: float = 1.0, n_weight: float = 1.0)-
Expand source code
class WMDetectionLoss(nn.Module):
    """Compute the sample-level watermark detection loss.

    The detector outputs, for every time step, a softmax over two classes
    (0 = not watermarked, 1 = watermarked).  This loss applies NLL to the
    log-probabilities, targeting class 1 on watermarked audio and class 0
    on non-watermarked audio.  When part of the watermark is masked out,
    the masked positions of the positive batch are targeted as class 0
    instead.

    Args:
        p_weight: weight of the loss on watermarked (positive) samples.
        n_weight: weight of the loss on non-watermarked (negative) samples.
    """

    def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None:
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.p_weight = p_weight
        self.n_weight = n_weight

    def forward(self, positive, negative, mask, message=None):
        """Compute the detection loss.

        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps].
            negative: outputs on non-watermarked samples [bsz, 2+nbits, time_steps].
            mask: watermark mask [bsz, 1, time_steps], 1 where the watermark is present.
            message: unused; kept for signature compatibility with WMMbLoss.

        Returns:
            Scalar loss tensor.
        """
        positive = positive[:, :2, :]  # b 2+nbits t -> b 2 t
        negative = negative[:, :2, :]  # b 2+nbits t -> b 2 t

        # Targets have shape [bsz, time_steps]: all ones for the positive
        # batch, all zeros for the negative batch.
        classes_shape = positive[:, 0, :]  # same as positive/negative minus dim=1
        pos_correct_classes = torch.ones_like(classes_shape, dtype=torch.long)
        neg_correct_classes = torch.zeros_like(classes_shape, dtype=torch.long)

        # The network outputs a softmax but NLLLoss expects log-probabilities.
        # Clamp before the log so that a probability that underflowed to
        # exactly 0 cannot produce -inf loss values and NaN gradients.
        positive = torch.log(positive.clamp_min(1e-12))
        negative = torch.log(negative.clamp_min(1e-12))

        if not torch.all(mask == 1):
            # mask [bsz, 1, time_steps] is applied to the watermark: flip the
            # target class from 1 (positive) to 0 (negative) where mask == 0.
            pos_correct_classes = pos_correct_classes * mask[:, 0, :].long()
            # No negative-class term is needed here since part of the
            # watermark is already masked to the negative class.
            return self.p_weight * self.criterion(positive, pos_correct_classes)

        loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
        loss_n = self.n_weight * self.criterion(negative, neg_correct_classes)
        return loss_p + loss_n
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : boolvar dump_patches : boolvar training : bool
Methods
def forward(self, positive, negative, mask, message=None) ‑> torch.Tensor
Expand source code
def forward(self, positive, negative, mask, message=None):
    """Detection loss: NLL over the two detection classes.

    positive / negative are detector outputs [bsz, 2+nbits, time_steps];
    mask is the watermark mask [bsz, 1, time_steps]; message is unused.
    """
    # Only the two detection channels matter here: b 2+nbits t -> b 2 t
    positive, negative = positive[:, :2, :], negative[:, :2, :]

    # Build [bsz, time_steps] targets: ones for watermarked audio,
    # zeros for non-watermarked audio.
    template = positive[:, 0, :]
    pos_targets = torch.ones_like(template, dtype=int)
    neg_targets = torch.zeros_like(template, dtype=int)

    # NLLLoss wants log-probabilities; the network emits a softmax.
    log_pos = torch.log(positive)
    log_neg = torch.log(negative)

    if torch.all(mask == 1):
        # Fully watermarked: supervise each batch with its own class.
        loss_p = self.p_weight * self.criterion(log_pos, pos_targets)
        loss_n = self.n_weight * self.criterion(log_neg, neg_targets)
        return loss_p + loss_n

    # Part of the watermark is masked out: flip the positive targets to
    # class 0 at masked positions; the negative term is then redundant.
    pos_targets = pos_targets * mask[:, 0, :].to(int)
    return self.p_weight * self.criterion(log_pos, pos_targets)
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
class WMMbLoss (temperature: float, loss_type: Literal['bce', 'mse'])-
Expand source code
class WMMbLoss(nn.Module):
    def __init__(self, temperature: float, loss_type: Literal["bce", "mse"]) -> None:
        """
        Compute the masked sample-level detection loss
        (https://arxiv.org/pdf/2401.17264)

        Args:
            temperature: temperature for loss computation
            loss_type: bce or mse between outputs and original message

        Raises:
            ValueError: if ``loss_type`` is not "bce" or "mse".
        """
        super().__init__()
        # Same as Softmax + NLLLoss, but for a single output unit and
        # numerically stabler than the two-step form.
        self.bce_with_logits = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        if loss_type not in ("bce", "mse"):
            # Fail fast at construction instead of a NameError at the
            # first forward() call.
            raise ValueError(f"Unsupported loss type: {loss_type}")
        self.loss_type = loss_type
        self.temperature = temperature

    def forward(self, positive, negative, mask, message):
        """
        Compute decoding loss.

        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
            negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps]
            mask: watermark mask [bsz, 1, time_steps]
            message: original message [bsz, nbits] or None

        Returns:
            Scalar loss tensor (zero when the batch carries no message).
        """
        # negative is currently unused; kept in the signature for symmetry
        # with WMDetectionLoss:
        # negative = negative[:, 2:, :]  # b 2+nbits t -> b nbits t
        # negative = torch.masked_select(negative, mask)
        if message is None or message.size(0) == 0:
            # No message in this batch: zero loss on the input's device.
            return torch.tensor(0.0, device=positive.device)

        positive = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
        assert positive.shape[-2] == message.shape[1], (
            "in decoding loss: enc and dec don't share nbits, "
            "are you using multi-bit?"
        )

        # Keep only the time steps where mask == 1, restoring the
        # [bsz, nbits, -1] layout afterwards.
        new_shape = [*positive.shape[:-1], -1]  # b nbits -1
        positive = torch.masked_select(positive, mask == 1).reshape(new_shape)

        message = message.unsqueeze(-1).repeat(1, 1, positive.shape[2])  # b k -> b k t
        if self.loss_type == "bce":
            # Temperature plays the same role here as in a softmax.
            return self.bce_with_logits(positive / self.temperature, message.float())
        # __init__ guarantees loss_type is "mse" here.
        return self.mse(positive / self.temperature, message.float())
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x))Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:
`to`, etc.
Note
As per the example above, an
An `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
Compute the masked sample-level detection loss (https://arxiv.org/pdf/2401.17264)
Args
temperature- temperature for loss computation
loss_type- bce or mse between outputs and original message
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : boolvar dump_patches : boolvar training : bool
Methods
def forward(self, positive, negative, mask, message) ‑> torch.Tensor
Expand source code
def forward(self, positive, negative, mask, message):
    """
    Compute decoding loss

    Args:
        positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
        negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps]
        mask: watermark mask [bsz, 1, time_steps]
        message: original message [bsz, nbits] or None
    """
    # # no use of negative at the moment
    # negative = negative[:, 2:, :]  # b 2+nbits t -> b nbits t
    # negative = torch.masked_select(negative, mask)

    # Empty batch: nothing to decode.
    if message.size(0) == 0:
        return torch.tensor(0.0)

    # Drop the two detection channels: b 2+nbits t -> b nbits t
    bits = positive[:, 2:, :]
    assert bits.shape[-2] == message.shape[1], "in decoding loss: \
        enc and dec don't share nbits, are you using multi-bit?"

    # Keep only the positions where the watermark is present, restoring
    # the [bsz, nbits, -1] layout afterwards.
    kept = torch.masked_select(bits, mask == 1).reshape([*bits.shape[:-1], -1])

    # b k -> b k t so the message lines up with every kept time step.
    target = message.unsqueeze(-1).repeat(1, 1, kept.shape[2]).float()

    scaled = kept / self.temperature
    if self.loss_type == "bce":
        # Temperature acts like the temperature of a softmax here.
        loss = self.bce_with_logits(scaled, target)
    elif self.loss_type == "mse":
        loss = self.mse(scaled, target)
    return loss
Args
positive- outputs on watermarked samples [bsz, 2+nbits, time_steps]
negative- outputs on not watermarked samples [bsz, 2+nbits, time_steps]
mask- watermark mask [bsz, 1, time_steps]
message- original message [bsz, nbits] or None