Module audiocraft.models.watermark

Classes

class AudioSeal (generator: torch.nn.modules.module.Module,
detector: torch.nn.modules.module.Module,
nbits: int = 0)
Expand source code
class AudioSeal(WMModel):
    """Wrap AudioSeal (https://github.com/facebookresearch/audioseal) for
    training and evaluation. The generator and detector are trained jointly.
    """

    def __init__(
        self,
        generator: nn.Module,
        detector: nn.Module,
        nbits: int = 0,
    ):
        """
        Args:
            generator: Watermark generator; must expose ``get_watermark`` and
                ``msg_processor.nbits``.
            detector: Module whose ``detector`` attribute produces the raw
                detection + decoding output.
            nbits: Message size. When 0 (falsy), the size configured on
                ``generator.msg_processor`` is used instead.
        """
        super().__init__()
        self.generator = generator  # type: ignore
        self.detector = detector  # type: ignore

        # A non-zero ``nbits`` overrides the generator's message size;
        # otherwise fall back to the generator's message processor.
        # NOTE(review): the original comment mentioned re-training an n-bit
        # model with a new 0-bit message, but nbits=0 selects the generator's
        # value here -- confirm intent.
        self.nbits = nbits if nbits else self.generator.msg_processor.nbits

    def get_watermark(
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        """Delegate watermark generation to the wrapped generator."""
        return self.generator.get_watermark(x, message=message, sample_rate=sample_rate)

    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
        """
        Detect the watermarks from the audio signal. The first two units of the output
        are used for detection, the rest is used to decode the message. If the audio is
        not watermarked, the message will be random.

        Args:
            x: Audio signal, size batch x frames
        Returns:
            torch.Tensor: Detection + decoding results of shape (B, 2+nbits, T).
        """
        # Raw detector output, (B, 2+nbits, T) per the docstring above.
        logits = self.detector.detector(x)
        # Softmax over the 2 detection units only; decoding units stay raw.
        # Built out of place (torch.cat) instead of by slice assignment, so the
        # detector's output is not mutated. An in-place write would break
        # autograd whenever the detector's last op saves its output for backward.
        # `.to(logits.dtype)` keeps the detector's dtype (softmax may upcast
        # under autocast).
        detection = torch.softmax(logits[:, :2, :], dim=1).to(logits.dtype)
        return torch.cat([detection, logits[:, 2:, :]], dim=1)

    def forward(  # generator
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
        alpha: float = 1.0,
    ) -> torch.Tensor:
        """Apply the watermark to the audio signal ``x``.

        Args:
            x: Audio signal to watermark.
            message: Optional message to embed, forwarded to ``get_watermark``.
            sample_rate: Sample rate of ``x``, forwarded to ``get_watermark``.
            alpha: Tune-down ratio applied to the watermark (default 1.0).
        Returns:
            torch.Tensor: ``x + alpha * watermark``.
        """
        # sample_rate must be forwarded explicitly; otherwise get_watermark
        # silently falls back to its 16 kHz default.
        wm = self.get_watermark(x, message=message, sample_rate=sample_rate)
        return x + alpha * wm

    @staticmethod
    def get_pretrained(name="base", device=None) -> WMModel:
        """Load pretrained models via ``load_audioseal_models("facebook/audioseal", ...)``.

        Args:
            name: Passed as ``filename`` to ``load_audioseal_models``.
            device: Target device; defaults to "cuda" when at least one CUDA
                device is visible, otherwise "cpu".
        """
        if device is None:
            if torch.cuda.device_count():
                device = "cuda"
            else:
                device = "cpu"
        return load_audioseal_models("facebook/audioseal", filename=name, device=device)

Wraps AudioSeal (https://github.com/facebookresearch/audioseal) for training and evaluation. The generator and detector are trained jointly.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • WMModel
  • abc.ABC
  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Static methods

def get_pretrained(name='base', device=None) ‑> WMModel
Expand source code
@staticmethod
def get_pretrained(name="base", device=None) -> WMModel:
    """Load the pretrained models stored under ``name`` in "facebook/audioseal".

    When ``device`` is None it becomes "cuda" if at least one CUDA device is
    visible, and "cpu" otherwise.
    """
    if device is None:
        has_cuda = torch.cuda.device_count() > 0
        device = "cuda" if has_cuda else "cpu"
    return load_audioseal_models("facebook/audioseal", filename=name, device=device)

Methods

def detect_watermark(self, x: torch.Tensor) ‑> torch.Tensor
Expand source code
def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
    """
    Detect the watermarks from the audio signal. The first two units of the output
    are used for detection, the rest is used to decode the message. If the audio is
    not watermarked, the message will be random.

    Args:
        x: Audio signal, size batch x frames
    Returns:
        torch.Tensor: Detection + decoding results of shape (B, 2+nbits, T).
    """
    # Raw detector output, (B, 2+nbits, T) per the docstring above.
    logits = self.detector.detector(x)
    # Softmax over the 2 detection units only; decoding units stay raw.
    # Built out of place (torch.cat) instead of by slice assignment, so the
    # detector's output is not mutated. An in-place write would break autograd
    # whenever the detector's last op saves its output for backward.
    # `.to(logits.dtype)` keeps the detector's dtype (softmax may upcast under
    # autocast).
    detection = torch.softmax(logits[:, :2, :], dim=1).to(logits.dtype)
    return torch.cat([detection, logits[:, 2:, :]], dim=1)

Detect the watermarks from the audio signal. The first two units of the output are used for detection, the rest is used to decode the message. If the audio is not watermarked, the message will be random.

Args

x
Audio signal, size batch x frames

Returns

torch.Tensor
Detection + decoding results of shape (B, 2+nbits, T).

def forward(self,
x: torch.Tensor,
message: torch.Tensor | None = None,
sample_rate: int = 16000,
alpha: float = 1.0) ‑> torch.Tensor
Expand source code
def forward(  # generator
    self,
    x: torch.Tensor,
    message: tp.Optional[torch.Tensor] = None,
    sample_rate: int = 16_000,
    alpha: float = 1.0,
) -> torch.Tensor:
    """Apply the watermark to the audio signal ``x``.

    Args:
        x: Audio signal to watermark.
        message: Optional message to embed, forwarded to ``get_watermark``.
        sample_rate: Sample rate of ``x``, forwarded to ``get_watermark``.
        alpha: Tune-down ratio applied to the watermark (default 1.0).
    Returns:
        torch.Tensor: ``x + alpha * watermark``.
    """
    # sample_rate must be forwarded explicitly; otherwise get_watermark
    # silently falls back to its 16 kHz default.
    wm = self.get_watermark(x, message=message, sample_rate=sample_rate)
    return x + alpha * wm

Apply the watermark to the audio signal x, scaled by the tune-down ratio alpha (default 1.0).

Inherited members

class WMModel (*args, **kwargs)
Expand source code
class WMModel(ABC, nn.Module):
    """
    A wrapper interface to different watermarking models for
    training or evaluation purposes.

    Concrete subclasses must implement ``get_watermark`` and
    ``detect_watermark``. ``forward`` is not declared abstract here; it is
    inherited from ``nn.Module`` unless a subclass overrides it.
    """

    @abstractmethod
    def get_watermark(
        self,
        x: torch.Tensor,
        message: tp.Optional[torch.Tensor] = None,
        sample_rate: int = 16_000,
    ) -> torch.Tensor:
        """Get the watermark from an audio tensor and a message.
        If the input message is None, a random message of
        n bits {0,1} will be generated.
        """

    @abstractmethod
    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
        """Detect the watermarks from the audio signal.

        Args:
            x: Audio signal, size batch x frames

        Returns:
            Tensor of size (B, 2+n, frames), where the first 2 channels are
            the detection results, of shape (B, 2, frames), and the remaining
            n channels are the message decoding results, of shape (B, n, frames).
        """

A wrapper interface to different watermarking models for training or evaluation purposes.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

  • abc.ABC
  • torch.nn.modules.module.Module

Subclasses

  • AudioSeal

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def detect_watermark(self, x: torch.Tensor) ‑> torch.Tensor
Expand source code
@abstractmethod
def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
    """Detect watermarks in an audio signal.

    Args:
        x: Audio signal, size batch x frames

    Returns:
        Tensor of size (B, 2+n, frames): the first 2 channels hold the
        detection results, shape (B, 2, frames); the remaining n channels hold
        the message decoding results, shape (B, n, frames).
    """

Detect the watermarks from the audio signal

Args

x
Audio signal, size batch x frames

Returns

Tensor of size (B, 2+n, frames), where the first 2 channels are the detection results, of shape (B, 2, frames), and the remaining n channels are the message decoding results, of shape (B, n, frames).

def forward(self, *input: Any) ‑> None
Expand source code
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

def get_watermark(self,
x: torch.Tensor,
message: torch.Tensor | None = None,
sample_rate: int = 16000) ‑> torch.Tensor
Expand source code
@abstractmethod
def get_watermark(
    self,
    x: torch.Tensor,
    message: tp.Optional[torch.Tensor] = None,
    sample_rate: int = 16_000,
) -> torch.Tensor:
    """Compute the watermark for an audio tensor and a message.

    When ``message`` is None, a random message of n bits in {0, 1} is
    generated instead.
    """

Get the watermark from an audio tensor and a message. If the input message is None, a random message of n bits {0,1} will be generated