Module audiocraft.adversarial.discriminators.base

Classes

class MultiDiscriminator
class MultiDiscriminator(ABC, nn.Module):
    """Base implementation for discriminators composed of sub-discriminators acting at different scales.
    """
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        ...

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        """Number of discriminators.
        """
        ...

Base implementation for discriminators composed of sub-discriminators acting at different scales.
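
The contract described above (an abstract forward plus an abstract num_discriminators property) can be illustrated with a minimal sketch. So that it runs without audiocraft installed, the base class is reproduced locally from the source shown on this page. ToyMultiScaleDiscriminator, its layer sizes, and the pooling scheme are hypothetical and are not part of the library. Reading the return value as (logits per sub-discriminator, feature maps per sub-discriminator) is an interpretation of the annotated return type.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

import torch
import torch.nn as nn

# Matches the return annotation shown for forward() on this page.
MultiDiscriminatorOutputType = Tuple[List[torch.Tensor], List[List[torch.Tensor]]]


class MultiDiscriminator(ABC, nn.Module):
    """Local copy of the base class, so the sketch does not import audiocraft."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        ...

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        ...


class ToyMultiScaleDiscriminator(MultiDiscriminator):
    """Hypothetical subclass: one tiny conv stack per average-pooled scale."""

    def __init__(self, n_scales: int = 3, channels: int = 4):
        super().__init__()
        # Both convolutions preserve the time dimension (padding = kernel // 2).
        self.first_convs = nn.ModuleList(
            [nn.Conv1d(1, channels, kernel_size=5, padding=2) for _ in range(n_scales)]
        )
        self.last_convs = nn.ModuleList(
            [nn.Conv1d(channels, 1, kernel_size=3, padding=1) for _ in range(n_scales)]
        )
        # Halves the time resolution between consecutive sub-discriminators.
        self.pool = nn.AvgPool1d(kernel_size=2, stride=2)

    @property
    def num_discriminators(self) -> int:
        return len(self.first_convs)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits: List[torch.Tensor] = []
        fmaps: List[List[torch.Tensor]] = []
        for first, last in zip(self.first_convs, self.last_convs):
            hidden = torch.relu(first(x))
            logits.append(last(hidden))  # one logit tensor per scale
            fmaps.append([hidden])       # intermediate features for this scale
            x = self.pool(x)             # the next sub-discriminator sees a coarser signal
        return logits, fmaps
```

Because both members are abstract, instantiating MultiDiscriminator directly raises TypeError; only a subclass that overrides both, such as the toy class here, can be constructed.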

Ancestors

  • abc.ABC
  • torch.nn.modules.module.Module

Subclasses

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Instance variables

prop num_discriminators : int

Number of discriminators.

Methods

def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, while calling forward() silently ignores them.
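
The difference described in the note can be observed with a forward hook. The following is a minimal sketch using a hypothetical Doubler module, which is not part of audiocraft. It relies only on the standard torch.nn.Module.register_forward_hook API.

```python
import torch
import torch.nn as nn


class Doubler(nn.Module):
    """Hypothetical module used only to demonstrate hook behaviour."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x


module = Doubler()
calls = []

# The hook records the output shape each time it runs and returns None,
# so the module output is left unchanged.
handle = module.register_forward_hook(
    lambda mod, inputs, output: calls.append(tuple(output.shape))
)

x = torch.ones(3)
via_call = module(x)             # goes through __call__, so the hook runs
via_forward = module.forward(x)  # bypasses __call__, so the hook is skipped

handle.remove()  # detach the hook once it is no longer needed
```

Both calls return the same tensor, but calls holds a single entry afterwards, because only the call on the instance triggered the hook.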