Residual Connections

class fairseq2.nn.ResidualConnect(*args, **kwargs)

Bases: Module, ABC

Represents a residual connection, which combines the output of a module with that module's input.


abstract forward(seqs, residual)
Parameters:
  • seqs (Tensor) – The sequences output by a module. Shape: \((N,S,M)\), where \(N\) is the batch size, \(S\) is the sequence length, and \(M\) is the dimensionality of the model.

  • residual (Tensor) – The input sequences to the module. Shape: Same as seqs.

Returns:

The output sequences with residuals applied. Shape: Same as seqs.

Return type:

Tensor
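The contract can be illustrated with a minimal sketch that assumes only PyTorch. `ResidualConnectSketch`, `ExampleResidualConnect`, and `run_sublayer` are hypothetical names, not part of fairseq2. They mirror the documented `forward(seqs, residual)` signature, where `seqs` is a module's output and `residual` is that module's input.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor
from torch.nn import Linear, Module


class ResidualConnectSketch(Module, ABC):
    """Hypothetical stand-in for the documented ResidualConnect interface."""

    @abstractmethod
    def forward(self, seqs: Tensor, residual: Tensor) -> Tensor:
        """seqs: module output, (N, S, M); residual: module input, same shape."""


class ExampleResidualConnect(ResidualConnectSketch):
    """Hypothetical concrete subclass that simply sums the two tensors."""

    def forward(self, seqs: Tensor, residual: Tensor) -> Tensor:
        return seqs + residual


def run_sublayer(module: Module, connect: ResidualConnectSketch, x: Tensor) -> Tensor:
    # The module's input `x` is the residual; the module's output is `seqs`.
    return connect(module(x), x)


# The abstract base cannot be instantiated because forward() is unimplemented.
try:
    ResidualConnectSketch()
    instantiation_failed = False
except TypeError:
    instantiation_failed = True

torch.manual_seed(0)
layer = Linear(8, 8)
x = torch.randn(2, 5, 8)  # (N, S, M)
out = run_sublayer(layer, ExampleResidualConnect(), x)
print(out.shape)  # torch.Size([2, 5, 8])
```

Declaring `forward` abstract keeps the base class from being used directly; each concrete connection supplies its own rule for combining the two tensors.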

final class fairseq2.nn.AdditiveResidualConnect(*args, **kwargs)

Bases: ResidualConnect

Sums the output of a module (seqs) and its input (residual).


forward(seqs, residual)
Parameters:
  • seqs (Tensor) – The sequences output by a module. Shape: \((N,S,M)\), where \(N\) is the batch size, \(S\) is the sequence length, and \(M\) is the dimensionality of the model.

  • residual (Tensor) – The input sequences to the module. Shape: Same as seqs.

Returns:

The output sequences with residuals applied. Shape: Same as seqs.

Return type:

Tensor
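A minimal sketch of the documented behavior, assuming only PyTorch. `AdditiveResidualConnectSketch` is a hypothetical stand-in, not fairseq2's source; it keeps the example runnable without fairseq2 installed. The gradient check illustrates why the additive form is useful: the skip path passes gradients back to the module input unchanged.

```python
import torch
from torch import Tensor
from torch.nn import Module


class AdditiveResidualConnectSketch(Module):
    """Hypothetical re-implementation of the documented behavior: seqs + residual."""

    def forward(self, seqs: Tensor, residual: Tensor) -> Tensor:
        # Element-wise sum; both tensors share the (N, S, M) shape.
        return seqs + residual


connect = AdditiveResidualConnectSketch()

seqs = torch.randn(2, 3, 4)  # output of some module, (N, S, M)
residual = torch.randn(2, 3, 4, requires_grad=True)  # that module's input

out = connect(seqs, residual)
out.sum().backward()

# d(out.sum())/d(residual) is all ones: the skip path does not rescale gradients.
print(torch.equal(residual.grad, torch.ones_like(residual)))  # True
```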

final class fairseq2.nn.ScaledResidualConnect(scale)

Bases: ResidualConnect

Scales residuals by a constant factor before adding them to the output of a Transformer module.

Parameters:

scale (float) – The scale factor.
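Written as a formula, with \(\alpha\) denoting scale, \(x\) the module input (residual), and \(f(x)\) the module output (seqs), the description above corresponds to:

```latex
\mathrm{out} = f(x) + \alpha\, x, \qquad f(x),\; x,\; \mathrm{out} \in \mathbb{R}^{N \times S \times M}
```

With \(\alpha = 1\) this reduces to the plain sum computed by AdditiveResidualConnect.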

forward(seqs, residual)
Parameters:
  • seqs (Tensor) – The sequences output by a module. Shape: \((N,S,M)\), where \(N\) is the batch size, \(S\) is the sequence length, and \(M\) is the dimensionality of the model.

  • residual (Tensor) – The input sequences to the module. Shape: Same as seqs.

Returns:

The output sequences with residuals applied. Shape: Same as seqs.

Return type:

Tensor
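A minimal sketch of the scaled variant, assuming only PyTorch. `ScaledResidualConnectSketch` is a hypothetical stand-in that follows the description above (scale the residual, then add the module output); it is not fairseq2's source.

```python
import torch
from torch import Tensor
from torch.nn import Module


class ScaledResidualConnectSketch(Module):
    """Hypothetical stand-in: computes seqs + scale * residual."""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

    def forward(self, seqs: Tensor, residual: Tensor) -> Tensor:
        # Scale the residual (the module input) first, then add the module output.
        return seqs + self.scale * residual

    def extra_repr(self) -> str:
        return f"scale={self.scale}"


seqs = torch.ones(1, 2, 3)  # module output, (N, S, M)
residual = torch.full((1, 2, 3), 4.0)  # module input, same shape

half = ScaledResidualConnectSketch(scale=0.5)
print(half(seqs, residual)[0, 0, 0].item())  # 3.0, i.e. 1 + 0.5 * 4

# A scale of 1.0 reproduces the plain additive connection.
same_as_sum = torch.equal(ScaledResidualConnectSketch(1.0)(seqs, residual), seqs + residual)
print(same_as_sum)  # True
```

Storing `scale` as a plain float, rather than a learnable parameter, matches the "constant factor" wording in the class description.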