Module audiocraft.losses.wmloss

Classes

class WMDetectionLoss (p_weight: float = 1.0, n_weight: float = 1.0)
class WMDetectionLoss(nn.Module):
    """Compute the detection loss"""
    def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None:
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.p_weight = p_weight
        self.n_weight = n_weight

    def forward(self, positive, negative, mask, message=None):

        positive = positive[:, :2, :]  # b 2+nbits t -> b 2 t
        negative = negative[:, :2, :]  # b 2+nbits t -> b 2 t

        # dimensionality of positive [bsz, classes=2, time_steps]
        # correct classes for pos = [bsz, time_steps] where all values = 1 for positive
        classes_shape = positive[
            :, 0, :
        ]  # same as positive or negative but dropping dim=1
        pos_correct_classes = torch.ones_like(classes_shape, dtype=int)
        neg_correct_classes = torch.zeros_like(classes_shape, dtype=int)

        # taking log because network outputs softmax
        # NLLLoss expects a logsoftmax input
        positive = torch.log(positive)
        negative = torch.log(negative)

        if not torch.all(mask == 1):
            # pos_correct_classes: [bsz, time_steps]; mask: [bsz, 1, time_steps]
            # the mask is applied to the watermark: it flips the target class from 1 (positive)
            # to 0 (negative) wherever the watermark is masked out
            pos_correct_classes = pos_correct_classes * mask[:, 0, :].to(int)
            loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
            # no need for negative class loss here since some of the watermark
            # is masked to negative
            return loss_p

        else:
            loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
            loss_n = self.n_weight * self.criterion(negative, neg_correct_classes)
            return loss_p + loss_n

Compute the detection loss

Initializes internal Module state, shared by both nn.Module and ScriptModule.
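
A minimal usage sketch (the shapes follow the comments in the source; the sizes and random tensors below are illustrative only, and, as the source comments note, the first two channels are expected to be post-softmax probabilities):

import torch
from audiocraft.losses.wmloss import WMDetectionLoss

loss_fn = WMDetectionLoss(p_weight=1.0, n_weight=1.0)

bsz, nbits, t = 4, 16, 800
positive = torch.rand(bsz, 2 + nbits, t)  # detector outputs on watermarked audio
negative = torch.rand(bsz, 2 + nbits, t)  # detector outputs on clean audio
# only the first two channels are used by this loss; make them probabilities
positive[:, :2] = torch.softmax(positive[:, :2], dim=1)
negative[:, :2] = torch.softmax(negative[:, :2], dim=1)
mask = torch.ones(bsz, 1, t)              # 1 where the watermark is present

loss = loss_fn(positive, negative, mask)  # scalar tensor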

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, positive, negative, mask, message=None) ‑> torch.Tensor
def forward(self, positive, negative, mask, message=None):

    positive = positive[:, :2, :]  # b 2+nbits t -> b 2 t
    negative = negative[:, :2, :]  # b 2+nbits t -> b 2 t

    # dimensionality of positive [bsz, classes=2, time_steps]
    # correct classes for pos = [bsz, time_steps] where all values = 1 for positive
    classes_shape = positive[
        :, 0, :
    ]  # same as positive or negative but dropping dim=1
    pos_correct_classes = torch.ones_like(classes_shape, dtype=int)
    neg_correct_classes = torch.zeros_like(classes_shape, dtype=int)

    # taking log because network outputs softmax
    # NLLLoss expects a logsoftmax input
    positive = torch.log(positive)
    negative = torch.log(negative)

    if not torch.all(mask == 1):
        # pos_correct_classes: [bsz, time_steps]; mask: [bsz, 1, time_steps]
        # the mask is applied to the watermark: it flips the target class from 1 (positive)
        # to 0 (negative) wherever the watermark is masked out
        pos_correct_classes = pos_correct_classes * mask[:, 0, :].to(int)
        loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
        # no need for negative class loss here since some of the watermark
        # is masked to negative
        return loss_p

    else:
        loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
        loss_n = self.n_weight * self.criterion(negative, neg_correct_classes)
        return loss_p + loss_n

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
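
For intuition on the masking branch above: wherever mask is 0, the target for the watermarked output is flipped from class 1 (watermark present) to class 0, so those time steps are supervised as non-watermarked. A small illustrative sketch (the tensors are made up):

import torch

# targets for the watermarked branch: class 1 everywhere
pos_correct_classes = torch.ones(2, 5, dtype=torch.long)   # [bsz=2, time_steps=5]
mask = torch.tensor([[[1, 1, 0, 0, 1]],
                     [[1, 1, 1, 1, 1]]])                   # [bsz, 1, time_steps]

# same operation as in forward(): masked time steps become class 0
pos_correct_classes = pos_correct_classes * mask[:, 0, :].to(int)
print(pos_correct_classes)
# tensor([[1, 1, 0, 0, 1],
#         [1, 1, 1, 1, 1]])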

class WMMbLoss (temperature: float, loss_type: Literal['bce', 'mse'])
class WMMbLoss(nn.Module):
    def __init__(self, temperature: float, loss_type: Literal["bce", "mse"]) -> None:
        """
        Compute the masked sample-level detection loss
        (https://arxiv.org/pdf/2401.17264)

        Args:
            temperature: temperature for loss computation
            loss_type: bce or mse between outputs and original message
        """
        super().__init__()
        self.bce_with_logits = (
            nn.BCEWithLogitsLoss()
        )  # same as Softmax + NLLLoss, but when only 1 output unit
        self.mse = nn.MSELoss()
        self.loss_type = loss_type
        self.temperature = temperature

    def forward(self, positive, negative, mask, message):
        """
        Compute decoding loss
        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
            negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps]
            mask: watermark mask [bsz, 1, time_steps]
            message: original message [bsz, nbits] or None
        """
        # # no use of negative at the moment
        # negative = negative[:, 2:, :]  # b 2+nbits t -> b nbits t
        # negative = torch.masked_select(negative, mask)
        if message.size(0) == 0:
            return torch.tensor(0.0)
        positive = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
        assert (
            positive.shape[-2] == message.shape[1]
        ), "in decoding loss: \
            enc and dec don't share nbits, are you using multi-bit?"

        # cut last dim of positive to keep only where mask is 1
        new_shape = [*positive.shape[:-1], -1]  # b nbits -1
        positive = torch.masked_select(positive, mask == 1).reshape(new_shape)

        message = message.unsqueeze(-1).repeat(1, 1, positive.shape[2])  # b k -> b k t
        if self.loss_type == "bce":
            # in this case similar to temperature in softmax
            loss = self.bce_with_logits(positive / self.temperature, message.float())
        elif self.loss_type == "mse":
            loss = self.mse(positive / self.temperature, message.float())

        return loss

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

Compute the masked sample-level detection loss (https://arxiv.org/pdf/2401.17264)

Args

temperature
    Temperature for the loss computation.
loss_type
    "bce" or "mse" loss between the outputs and the original message.
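
A minimal usage sketch, mirroring the shapes documented in forward() below (the sizes, the temperature value, and the random tensors are illustrative only):

import torch
from audiocraft.losses.wmloss import WMMbLoss

loss_fn = WMMbLoss(temperature=1.0, loss_type="bce")

bsz, nbits, t = 4, 16, 800
positive = torch.randn(bsz, 2 + nbits, t)        # detector outputs on watermarked audio
negative = torch.randn(bsz, 2 + nbits, t)        # currently unused by the loss
mask = torch.ones(bsz, 1, t, dtype=torch.bool)   # watermark present at every time step
message = torch.randint(0, 2, (bsz, nbits))      # random nbits-bit message

loss = loss_fn(positive, negative, mask, message)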

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, positive, negative, mask, message) ‑> torch.Tensor
def forward(self, positive, negative, mask, message):
    """
    Compute decoding loss
    Args:
        positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
        negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps]
        mask: watermark mask [bsz, 1, time_steps]
        message: original message [bsz, nbits] or None
    """
    # # no use of negative at the moment
    # negative = negative[:, 2:, :]  # b 2+nbits t -> b nbits t
    # negative = torch.masked_select(negative, mask)
    if message.size(0) == 0:
        return torch.tensor(0.0)
    positive = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
    assert (
        positive.shape[-2] == message.shape[1]
    ), "in decoding loss: \
        enc and dec don't share nbits, are you using multi-bit?"

    # cut last dim of positive to keep only where mask is 1
    new_shape = [*positive.shape[:-1], -1]  # b nbits -1
    positive = torch.masked_select(positive, mask == 1).reshape(new_shape)

    message = message.unsqueeze(-1).repeat(1, 1, positive.shape[2])  # b k -> b k t
    if self.loss_type == "bce":
        # in this case similar to temperature in softmax
        loss = self.bce_with_logits(positive / self.temperature, message.float())
    elif self.loss_type == "mse":
        loss = self.mse(positive / self.temperature, message.float())

    return loss

Compute decoding loss

Args

positive
    Outputs on watermarked samples, [bsz, 2+nbits, time_steps].
negative
    Outputs on non-watermarked samples, [bsz, 2+nbits, time_steps].
mask
    Watermark mask, [bsz, 1, time_steps].
message
    Original message, [bsz, nbits], or None.
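
To see how the masking inside forward() behaves: only time steps where mask == 1 are kept, and the message is then repeated across those kept steps before the BCE/MSE. A small sketch (shapes are illustrative; note that the reshape implicitly assumes every item in the batch has the same number of watermarked time steps):

import torch

bsz, nbits, t = 2, 4, 6
positive = torch.randn(bsz, nbits, t)       # per-bit outputs after dropping the 2 detection channels
mask = torch.zeros(bsz, 1, t, dtype=torch.bool)
mask[:, :, :3] = True                       # watermark only in the first 3 time steps

# same selection as in forward(): keep only the time steps where mask == 1
kept = torch.masked_select(positive, mask == 1).reshape(bsz, nbits, -1)
print(kept.shape)                           # torch.Size([2, 4, 3])

# the message is repeated over the kept time steps to form the targets
message = torch.randint(0, 2, (bsz, nbits))
targets = message.unsqueeze(-1).repeat(1, 1, kept.shape[2])
print(targets.shape)                        # torch.Size([2, 4, 3])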