# Source: neuraltrain.models.conformer  (line was a docs-page scrape artifact; commented out to keep the module importable)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Conformer configuration wrapper for torchaudio.models.Conformer.
"""
import torch
from torch import nn
from .base import BaseModelConfig
# [docs]  (Sphinx cross-reference artifact from the scraped HTML; commented out — it is not part of the module)
class Conformer(BaseModelConfig):
    """
    Configuration wrapper around :class:`torchaudio.models.Conformer`.

    Reference: Gulati et al., "Conformer: Convolution-augmented Transformer
    for Speech Recognition", Interspeech 2020.
    See https://arxiv.org/abs/2005.08100.

    The Conformer combines self-attention (Transformer) and local convolution
    blocks to capture both global and short-range temporal dependencies.

    Parameters
    ----------
    num_heads : int, optional
        Number of attention heads used in the multi-head self-attention layers.
        Each head learns a different temporal relationship.
    ffn_dim : int, optional
        Dimension of the feed-forward layer inside each Conformer block.
        Acts as the hidden expansion size for each token.
    num_layers : int, optional
        Number of Conformer layers to stack.
        Controls the model's depth and temporal abstraction capacity.
    depthwise_conv_kernel_size : int, optional
        Kernel size of the depthwise convolution in each convolution module.
        Controls the temporal receptive field of local processing.
    dropout : float, optional
        Dropout probability applied within the Conformer layers.
        Helps regularize the model and prevent overfitting.
    use_group_norm : bool, optional
        Whether to use GroupNorm instead of BatchNorm inside the convolutional
        modules. GroupNorm can be more stable for small batch sizes.
    convolution_first : bool, optional
        If True, applies the convolutional module before the self-attention
        module inside each block. In practice this may slightly alter inductive
        bias but rarely changes performance significantly.

    The values shown below are the default settings used in the original paper.
    """

    num_heads: int = 4
    ffn_dim: int = 144
    num_layers: int = 2
    depthwise_conv_kernel_size: int = 31
    dropout: float = 0.1
    use_group_norm: bool = False
    convolution_first: bool = False

    def build(self, dim: int) -> nn.Module:
        """Instantiate the torchaudio Conformer with this config's fields.

        ``dim`` is forwarded as ``input_dim``; all other config fields (except
        the base-config ``name``) are passed through as keyword arguments.
        """
        # Alias the import so it does not shadow this config class's name.
        from torchaudio.models import Conformer as TorchaudioConformer

        class _ConformerTensorOut(TorchaudioConformer):
            """Forward variant that takes only a tensor and returns only a tensor.

            Assumes every sequence in the batch is full length (no padding),
            so ``lengths`` is synthesized as the time dimension for each item,
            and the lengths tensor returned by the parent is discarded.
            """

            def forward(self, features: torch.Tensor) -> torch.Tensor:
                batch_size, seq_len = features.shape[0], features.shape[1]
                full_lengths = torch.full(
                    (batch_size,),
                    seq_len,
                    dtype=torch.long,
                    device=features.device,
                )
                output, _ = super().forward(input=features, lengths=full_lengths)
                return output

        kwargs = self.model_dump()
        # ``name`` belongs to BaseModelConfig, not to torchaudio's Conformer;
        # pop (like the original ``del``) raises KeyError if it were missing.
        kwargs.pop("name")
        return _ConformerTensorOut(input_dim=dim, **kwargs)