Embeddings¶
- class fairseq2.nn.Embedding(num_embeddings: int, embed_dim: int, pad_idx: int | None)[source]¶
Stores embeddings of a fixed dictionary and size.
If pad_idx is provided, the embedding at the specified index won't contribute to the gradient and therefore won't be updated during training.
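A minimal sketch of the pad_idx behavior, using the concrete StandardEmbedding subclass and assuming the table is exposed as the module's weight parameter:

import torch

from fairseq2.nn import StandardEmbedding

# Small table with index 0 reserved for padding.
embed = StandardEmbedding(num_embeddings=10, embed_dim=4, pad_idx=0)

seqs = torch.tensor([[1, 2, 0, 0]])  # trailing positions are padding

embed(seqs).sum().backward()

# The gradient row for the padding index stays zero, so that row is never updated.
assert torch.all(embed.weight.grad[0] == 0.0)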
- final class fairseq2.nn.StandardEmbedding(num_embeddings: int, embed_dim: int, pad_idx: int | None = None, *, init_fn: Callable[[StandardEmbedding], None] | None = None, device: device | None = None, dtype: dtype | None = None)[source]¶
Bases: Embedding
Represents the standard implementation of Embedding. If init_fn is provided, it will be used to initialize the embedding table in reset_parameters().
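As a sketch of the init_fn hook, the initializer below is a hypothetical example that zero-fills the table (it assumes the table is exposed as the weight parameter); init_scaled_embedding, documented further down, is the typical built-in choice:

import torch.nn as nn

from fairseq2.nn import StandardEmbedding

def init_zero_embedding(embed: StandardEmbedding) -> None:
    # Hypothetical initializer: zero-fill the embedding table.
    nn.init.zeros_(embed.weight)

embed = StandardEmbedding(
    num_embeddings=1000, embed_dim=64, pad_idx=0, init_fn=init_zero_embedding
)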
- final class fairseq2.nn.ShardedEmbedding(gang: Gang, num_embeddings: int, embed_dim: int, pad_idx: int | None = None, *, init_fn: Callable[[StandardEmbedding], None] | None = None, device: device | None = None, dtype: dtype | None = None)[source]¶
Bases: Embedding, Sharded
Represents an Embedding sharded across its embedding dimension. If pad_idx is provided, the embedding at the specified index won't contribute to the gradient and therefore won't be updated during training. A usage sketch follows the method descriptions below.
- static from_embedding(embed: StandardEmbedding, gang: Gang) → ShardedEmbedding [source]¶
Creates a ShardedEmbedding by sharding embed over its embedding dimension using gang.
- forward(x: Tensor) → Tensor [source]¶
Returns the embeddings corresponding to the specified indices. x can have any shape.
The return value will be of shape \((*, E)\), where \(*\) is the input shape and \(E\) is the dimensionality of the embeddings.
- to_embedding(device: device | None = None) → StandardEmbedding [source]¶
Unshards this instance to a StandardEmbedding.
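A rough end-to-end sketch of sharding and unsharding. It uses fairseq2.gang.FakeGang as a single-process stand-in gang, which is an assumption of this example; in a real tensor-parallel run you would pass the gang created by your distributed setup instead:

import torch

from fairseq2.gang import FakeGang
from fairseq2.nn import ShardedEmbedding, StandardEmbedding

# Single-process stand-in gang (assumption); a real setup provides a
# tensor-parallel gang spanning several processes.
gang = FakeGang(device=torch.device("cpu"))

embed = StandardEmbedding(num_embeddings=32000, embed_dim=512, pad_idx=0)

# Shard the table over its embedding dimension across the gang.
sharded = ShardedEmbedding.from_embedding(embed, gang)

tokens = torch.randint(0, 32000, (4, 6))

embeddings = sharded(tokens)  # (4, 6, 512) with a single-shard gang

# Reassemble a regular, unsharded StandardEmbedding.
full = sharded.to_embedding()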
- fairseq2.nn.init_scaled_embedding(embed: StandardEmbedding) → None [source]¶
Initializes embed from \(\mathcal{N}(0, \frac{1}{\text{embed_dim}})\).
Example Usage:
import torch

from fairseq2.nn import BatchLayout, StandardEmbedding, init_scaled_embedding

# Create token embeddings.
embed = StandardEmbedding(
    num_embeddings=32000,  # vocabulary size
    embed_dim=512,
    pad_idx=0,
    init_fn=init_scaled_embedding,
)

# Build a padded batch of token indices and describe its layout with BatchLayout.
tokens = torch.randint(0, 32000, (4, 6))  # (batch, seq)
batch_layout = BatchLayout.of(tokens, seq_lens=[4, 2, 3, 5])

embeddings = embed(tokens)  # (4, 6, 512)