fairseq2.recipes.wav2vec2.train

    classDiagram
      ABC <|-- CliCommandHandler
      CliCommandHandler <|-- RecipeCommandHandler
      Generic <|-- RecipeCommandHandler

Classes

class fairseq2.recipes.wav2vec2.train.Wav2Vec2TrainConfig(*, dataset='librispeech_960h', train_split='train', valid_split='valid', min_audio_len=32000, max_audio_len=250000, max_num_elements=1500000, normalize_audio=False, batch_shuffle_window=0, num_prefetch=4, model_family='wav2vec2', model_arch='base', model_config=None, dtype=torch.float16, data_parallelism='ddp', fsdp_wrap_granularity='stack', torch_compile=False, optimizer='adamw', optimizer_config=<factory>, lr_scheduler='polynomial-decay', lr_scheduler_config=<factory>, max_gradient_norm=None, fp16_loss_scale=(128.0, 0.0001), gradient_accumulation=1, diversity_loss_weight=0.1, feature_penalty_weight=10.0, max_num_steps=400000, max_num_data_epochs=None, validate_every_n_steps=5000, checkpoint_every_n_steps=25000, keep_best_n_checkpoints=1, publish_metrics_every_n_steps=200, resume_checkpoint_dir=None, seed=2, profile=None, monitored_gang=False, anomaly_detection=False)[source]

Bases: object

Holds the configuration of a wav2vec 2.0 model training task.

The default values correspond to the base ls960h training setup as described in Baevski et al. [BZMA20].

dataset: str | AssetCard | Path = 'librispeech_960h'

The name of the speech dataset, a path to it, or a path to its asset card.

train_split: str = 'train'

The name of the train data split.

valid_split: str = 'valid'

The name of the validation data split.

min_audio_len: int = 32000

The minimum audio sequence length.

max_audio_len: int = 250000

The maximum audio sequence length.

max_num_elements: int = 1500000

The maximum number of elements per batch.
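
For scale, these length and batch-size fields are all expressed in elements of the input sequence. A small, purely illustrative calculation, assuming 16 kHz audio where one element corresponds to one waveform sample (an assumption for illustration, not stated by the config itself):

    # Illustrative only: assumes 16 kHz audio and one element per waveform sample.
    SAMPLE_RATE = 16_000

    print(32_000 / SAMPLE_RATE)     # min_audio_len    -> 2.0 s per utterance
    print(250_000 / SAMPLE_RATE)    # max_audio_len    -> ~15.6 s per utterance
    print(1_500_000 / SAMPLE_RATE)  # max_num_elements -> ~93.8 s of audio per batch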

normalize_audio: bool = False

If True, normalizes audio to have zero mean and unit variance.

batch_shuffle_window: int = 0

The size of the sliding window for shuffling batches.

num_prefetch: int = 4

The number of batches to prefetch in the background.

model_family: str = 'wav2vec2'

The family of the model.

model_arch: str | None = 'base'

The architecture of the wav2vec 2.0 model.

model_config: Any = None

The configuration of the model.

dtype: dtype = torch.float16

The data type of the model.

data_parallelism: Literal['ddp', 'fsdp'] = 'ddp'

The data parallelism API to use.

fsdp_wrap_granularity: Literal['layer', 'stack', 'model'] = 'stack'

The granularity at which to wrap the model when data_parallelism is 'fsdp'.

torch_compile: bool = False

If True, applies torch.compile() to the encoder. (experimental)

optimizer: str = 'adamw'

The name of the optimizer to use.

optimizer_config: Any

The configuration of the optimizer.

lr_scheduler: str = 'polynomial-decay'

The name of the learning rate scheduler to use.

lr_scheduler_config: Any

The configuration of the learning rate scheduler.

max_gradient_norm: float | None = None

The maximum gradient norm. If None, no clipping will be applied.

fp16_loss_scale: tuple[float, float] = (128.0, 0.0001)

The initial and minimum loss scale for fp16 training.

gradient_accumulation: int = 1

The number of steps to accumulate gradients before an optimizer update.
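
A rough, back-of-the-envelope sketch (not part of the recipe itself) of how gradient accumulation interacts with the per-batch element budget and the data-parallel world size:

    # Hypothetical illustration; the world size below is an assumption.
    max_num_elements = 1_500_000   # elements per batch per worker
    gradient_accumulation = 1      # batches accumulated per optimizer update
    world_size = 8                 # assumed number of data-parallel workers

    elements_per_update = max_num_elements * gradient_accumulation * world_size
    print(elements_per_update)     # 12_000_000 elements per optimizer step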

diversity_loss_weight: float = 0.1

The weight of the diversity loss.

feature_penalty_weight: float = 10.0

The weight of the regularization penalty applied to the extracted features.

max_num_steps: int = 400000

The maximum number of steps to train for.

max_num_data_epochs: int | None = None

The maximum number of data epochs to train for.

validate_every_n_steps: int = 5000

The step interval at which to validate the model.

checkpoint_every_n_steps: int = 25000

The step interval at which to checkpoint.

keep_best_n_checkpoints: int | None = 1

The number of checkpoints to keep based on their validation score. If None, none will be deleted.

publish_metrics_every_n_steps: int = 200

The step interval at which to publish metrics.

resume_checkpoint_dir: Path | None = None

If not None, adds the specified checkpoint directory to the default asset store so that the run can resume from its checkpoints.

seed: int = 2

The random number generator seed to use.

profile: tuple[int, int] | None = None

The number of steps the PyTorch profiler should skip, followed by the number of steps it should record.

monitored_gang: bool = False

If True, puts a monitored barrier before every collective call.

anomaly_detection: bool = False

If True, enables the anomaly detection feature of torch.autograd.
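
A minimal sketch of constructing the configuration in Python with a few overrides, assuming the class is importable as documented above; unspecified fields keep the defaults listed here, and the values below are illustrative rather than recommendations:

    import torch

    from fairseq2.recipes.wav2vec2.train import Wav2Vec2TrainConfig

    # All fields are keyword-only; anything not passed keeps its default.
    config = Wav2Vec2TrainConfig(
        dataset="librispeech_960h",
        max_audio_len=320_000,     # allow longer utterances than the default
        dtype=torch.float32,       # train in full precision instead of fp16
        gradient_accumulation=4,
        max_num_steps=100_000,
    )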

final class fairseq2.recipes.wav2vec2.train.Wav2Vec2TrainUnit(criterion, gang)[source]

Bases: AbstractTrainUnit[SequenceBatch]

property metric_bag: Wav2Vec2MetricBag

The training-related metrics.

Functions

fairseq2.recipes.wav2vec2.train.load_wav2vec2_trainer(config, output_dir)[source]

Load a Trainer for wav2vec 2.0 model training.

Return type:

Trainer[SequenceBatch]
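
A hedged usage sketch, assuming the returned Trainer is run by calling it, as fairseq2's recipe command handler does, and omitting any distributed or environment setup that the full CLI performs:

    from pathlib import Path

    from fairseq2.recipes.wav2vec2.train import (
        Wav2Vec2TrainConfig,
        load_wav2vec2_trainer,
    )

    config = Wav2Vec2TrainConfig()  # defaults: base ls960h setup

    trainer = load_wav2vec2_trainer(config, Path("/tmp/wav2vec2_base"))
    trainer()  # assumption: the Trainer is invoked as a callable by the recipe runner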