Embeddings

class fairseq2.nn.Embedding(num_embeddings: int, embed_dim: int, pad_idx: int | None)

Bases: Module, ABC

Stores embeddings of a fixed dictionary and size.

If pad_idx is provided, the embedding at the specified index won’t contribute to the gradient and therefore won’t be updated during training.
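To make the `pad_idx` behavior concrete, here is a minimal pure-Python sketch (no fairseq2 or PyTorch) of how gradients flow back into an embedding table when the padding row is excluded. The helper `accumulate_grad` is hypothetical and only illustrates the idea. fairseq2 relies on PyTorch autograd rather than hand-written code like this.

```python
def accumulate_grad(num_embeddings, embed_dim, indices, out_grads, pad_idx=None):
    """Scatter-add per-position output gradients into a table-shaped gradient.

    Positions whose index equals `pad_idx` are skipped, so the padding row
    receives a zero gradient and an optimizer step leaves it unchanged.
    """
    grad = [[0.0] * embed_dim for _ in range(num_embeddings)]
    for idx, g in zip(indices, out_grads):
        if idx == pad_idx:
            continue  # padding positions do not contribute to the gradient
        for j in range(embed_dim):
            grad[idx][j] += g[j]
    return grad


indices = [1, 0, 2, 1]                     # index 0 plays the padding role
out_grads = [[1.0, 1.0]] * len(indices)    # upstream gradient for each position
grad = accumulate_grad(3, 2, indices, out_grads, pad_idx=0)
print(grad)  # [[0.0, 0.0], [2.0, 2.0], [1.0, 1.0]]
```

Row 1 is looked up twice, so its gradient accumulates twice; row 0 is looked up once but, being the padding row, stays at zero.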

abstract forward(x: Tensor) → Tensor

Returns the embeddings corresponding to the specified indices.

x can have any shape.

The return value will be of shape \((*,E)\), where \(*\) is the input shape and \(E\) is the dimensionality of the embeddings.
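The shape rule can be illustrated with a small pure-Python sketch: a lookup replaces each scalar index with an E-element row, so an input of shape \((*)\) becomes \((*,E)\). The helpers `embed_nested` and `shape` are hypothetical and only mirror the documented contract, not fairseq2's implementation.

```python
def embed_nested(table, x):
    """Recursively replace each integer index in `x` with its row in `table`."""
    if isinstance(x, int):
        return list(table[x])
    return [embed_nested(table, item) for item in x]


def shape(x):
    """Return the shape of a rectangular nested list as a tuple."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


E = 3
table = [[float(i)] * E for i in range(10)]  # 10 embeddings of dimension E

tokens = [[1, 2, 3, 4], [5, 6, 7, 8]]  # input shape (2, 4)
out = embed_nested(table, tokens)
print(shape(out))  # (2, 4, 3)
```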

final class fairseq2.nn.StandardEmbedding(num_embeddings: int, embed_dim: int, pad_idx: int | None = None, *, init_fn: Callable[[StandardEmbedding], None] | None = None, device: device | None = None, dtype: dtype | None = None)

Bases: Embedding

Represents the standard implementation of Embedding.

If init_fn is provided, it will be used to initialize the embedding table in reset_parameters().
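The sketch below mimics the `init_fn` hook with a hypothetical `ToyEmbedding` class: `reset_parameters()` calls the supplied callable with the module itself, and the callable fills the table in place. This is not fairseq2's code. The attribute names `weight` and `embed_dim` and the standard-normal fallback are assumptions made for this illustration.

```python
import random
from typing import Callable, Optional


class ToyEmbedding:
    """Hypothetical stand-in showing how an init_fn hook can be wired in."""

    def __init__(self, num_embeddings: int, embed_dim: int,
                 init_fn: Optional[Callable[["ToyEmbedding"], None]] = None):
        self.num_embeddings = num_embeddings
        self.embed_dim = embed_dim
        self.init_fn = init_fn
        self.weight = [[0.0] * embed_dim for _ in range(num_embeddings)]
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.init_fn is not None:
            self.init_fn(self)  # the hook receives the module itself
        else:
            # Assumed fallback for this sketch only: standard normal entries.
            for row in self.weight:
                for j in range(self.embed_dim):
                    row[j] = random.gauss(0.0, 1.0)


def constant_init(embed: ToyEmbedding) -> None:
    """Example init_fn that fills the whole table with 0.5."""
    for row in embed.weight:
        for j in range(embed.embed_dim):
            row[j] = 0.5


embed = ToyEmbedding(4, 2, init_fn=constant_init)
print(embed.weight)  # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
```

With the real class the pattern is the same: pass a callable that accepts the StandardEmbedding and initializes its table, as init_scaled_embedding does.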

reset_parameters() → None

Resets the parameters of the module, using init_fn for the initialization if it was provided.

forward(x: Tensor) → Tensor

Returns the embeddings corresponding to the specified indices.

x can have any shape.

The return value will be of shape \((*,E)\), where \(*\) is the input shape and \(E\) is the dimensionality of the embeddings.

final class fairseq2.nn.ShardedEmbedding(gang: Gang, num_embeddings: int, embed_dim: int, pad_idx: int | None = None, *, init_fn: Callable[[StandardEmbedding], None] | None = None, device: device | None = None, dtype: dtype | None = None)

Bases: Embedding, Sharded

Represents an Embedding sharded across its embedding dimension.

If pad_idx is provided, the embedding at the specified index won’t contribute to the gradient and therefore won’t be updated during training.
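Sharding across the embedding dimension means each rank stores every row of the table but only a contiguous slice of its columns. A lookup is then purely local, and concatenating the per-rank outputs along the last dimension (the role an all-gather over the gang would play) reconstructs the full embeddings. The pure-Python sketch below simulates the ranks with a list; `shard_columns` and `gather_last_dim` are hypothetical helpers, and the sketch assumes embed_dim is divisible by the number of ranks.

```python
def shard_columns(table, world_size):
    """Split each row of `table` into `world_size` equal, contiguous column slices."""
    embed_dim = len(table[0])
    assert embed_dim % world_size == 0, "embed_dim must be divisible by world_size"
    width = embed_dim // world_size
    return [[row[r * width:(r + 1) * width] for row in table] for r in range(world_size)]


def gather_last_dim(per_rank_outputs):
    """Concatenate per-rank embedding slices position by position (simulated all-gather)."""
    return [sum(parts, []) for parts in zip(*per_rank_outputs)]


table = [[float(10 * i + j) for j in range(4)] for i in range(3)]  # 3 x 4 table
shards = shard_columns(table, world_size=2)  # each shard is 3 x 2

indices = [2, 0]
# Each simulated rank looks up the same indices in its own column slice.
local_outputs = [[shard[i] for i in indices] for shard in shards]
full = gather_last_dim(local_outputs)
print(full)  # [[20.0, 21.0, 22.0, 23.0], [0.0, 1.0, 2.0, 3.0]]
```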

static from_embedding(embed: StandardEmbedding, gang: Gang) → ShardedEmbedding

Creates a ShardedEmbedding by sharding embed over its embedding dimension using gang.

reset_parameters() → None

Resets the parameters of the module.

forward(x: Tensor) → Tensor

Returns the embeddings corresponding to the specified indices.

x can have any shape.

The return value will be of shape \((*,E)\), where \(*\) is the input shape and \(E\) is the dimensionality of the embeddings.

to_embedding(device: device | None = None) → StandardEmbedding

Unshards this instance to a StandardEmbedding.

get_shard_dims() → list[tuple[Parameter, int]]

Returns the sharding information for this module’s parameters.

The result is a list of tuples, each containing a sharded parameter of this module and the tensor dimension along which that parameter is sharded.
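One way a consumer (for example, checkpoint-merging code) might use these (parameter, dimension) pairs is to concatenate the per-rank shards along the reported dimension. The pure-Python sketch below is hypothetical: it uses parameter names in place of the Parameter objects that get_shard_dims() actually returns, handles only 2-D parameters, and assumes that for ShardedEmbedding the weight is split along dimension 1, which matches the class being sharded over its embedding dimension.

```python
def unsharded_shape(local_shape, shard_dim, world_size):
    """Return the full shape of a parameter whose shards are split along `shard_dim`."""
    full = list(local_shape)
    full[shard_dim] *= world_size
    return tuple(full)


def concat_along(shards, dim):
    """Concatenate 2-D nested-list shards along dim 0 (rows) or dim 1 (columns)."""
    if dim == 0:
        return [row for shard in shards for row in shard]
    if dim == 1:
        return [sum(rows, []) for rows in zip(*shards)]
    raise ValueError("this sketch only handles 2-D parameters")


shard_dims = [("weight", 1)]  # stand-in for get_shard_dims(); real entries hold Parameters
world_size = 2
rank_shards = {"weight": [[[1.0], [3.0]], [[2.0], [4.0]]]}  # two ranks, each 2 x 1

for name, dim in shard_dims:
    full = concat_along(rank_shards[name], dim)
    print(name, full, unsharded_shape((2, 1), dim, world_size))
# weight [[1.0, 2.0], [3.0, 4.0]] (2, 2)
```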

fairseq2.nn.init_scaled_embedding(embed: StandardEmbedding) → None

Initializes embed from \(\mathcal{N}\left(0, \frac{1}{\text{embed\_dim}}\right)\), that is, with zero mean and standard deviation \(1/\sqrt{\text{embed\_dim}}\).
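A quick pure-Python sketch of what this distribution implies: sampling each entry with standard deviation embed_dim ** -0.5 gives an empirical variance close to 1/embed_dim, and because each row has embed_dim entries, its expected squared norm is embed_dim · (1/embed_dim) = 1 regardless of the embedding size. The helper `scaled_normal_table` is hypothetical and not part of fairseq2.

```python
import math
import random


def scaled_normal_table(num_embeddings, embed_dim, seed=0):
    """Sample a table with entries drawn from N(0, 1/embed_dim)."""
    rng = random.Random(seed)
    std = 1.0 / math.sqrt(embed_dim)  # variance 1/embed_dim
    return [[rng.gauss(0.0, std) for _ in range(embed_dim)] for _ in range(num_embeddings)]


embed_dim = 512
table = scaled_normal_table(num_embeddings=100, embed_dim=embed_dim)

values = [v for row in table for v in row]
mean = sum(values) / len(values)
var = sum((v - mean) ** 2 for v in values) / len(values)
avg_sq_norm = sum(sum(v * v for v in row) for row in table) / len(table)

print(f"empirical variance={var:.6f}, target={1 / embed_dim:.6f}")
print(f"average squared row norm={avg_sq_norm:.3f} (expected about 1)")
```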

Example Usage:

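import torch

from fairseq2.nn import BatchLayout, StandardEmbedding, init_scaled_embedding

# Create token embeddings
embed = StandardEmbedding(
    num_embeddings=32000,  # vocabulary size
    embed_dim=512,
    pad_idx=0,
    init_fn=init_scaled_embedding,
)

# Token indices for a batch of 4 sequences padded to length 6
tokens = torch.randint(0, 32000, (4, 6))  # (batch, seq)

# BatchLayout records the actual length of each sequence for downstream
# layers; the embedding lookup itself only consumes the indices
batch_layout = BatchLayout.of(tokens, seq_lens=[4, 2, 3, 5])

embeddings = embed(tokens)  # (4, 6, 512)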