Embeddings¶
- class fairseq2.nn.Embedding(num_embeddings, embed_dim, pad_idx)[source]¶
Stores embeddings of a fixed dictionary and size.
- final class fairseq2.nn.StandardEmbedding(num_embeddings, embed_dim, pad_idx=None, *, init_fn=None, device=None, dtype=None)[source]¶
Bases: Embedding
Stores embeddings of a fixed dictionary and size in an in-memory table.
- Parameters:
num_embeddings (int) – The size of the embedding table.
embed_dim (int) – The dimensionality of returned embeddings.
pad_idx (int | None) – If not None, entries at pad_idx do not contribute to the gradient; therefore, the embedding at pad_idx is not updated during training.
init_fn (Callable[[StandardEmbedding], None] | None) – The callable to initialize the embedding table.
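A minimal sketch of the pad_idx behavior described above. It assumes the embedding table is exposed as a weight attribute (mirroring torch.nn.Embedding); the toy sizes are illustrative only:
import torch

from fairseq2.nn import StandardEmbedding

embed = StandardEmbedding(num_embeddings=8, embed_dim=4, pad_idx=0)

seqs = torch.tensor([[1, 2, 0, 0]])  # index 0 marks the padding positions

out = embed(seqs)  # (1, 4, 4)
out.sum().backward()

# The gradient row for pad_idx is all zeros, so the optimizer never
# updates the padding embedding.
assert torch.all(embed.weight.grad[0] == 0.0)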
- final class fairseq2.nn.ShardedEmbedding(gang, num_embeddings, embed_dim, pad_idx=None, *, init_fn=None, device=None, dtype=None)[source]¶
Bases: Embedding
Represents a StandardEmbedding that is sharded across its embedding dimension.
- Parameters:
gang (Gang) – The gang over which to shard the embedding table.
num_embeddings (int) – The size of the embedding table.
embed_dim (int) – The dimensionality of returned embeddings.
pad_idx (int | None) – If not None, entries at pad_idx do not contribute to the gradient; therefore, the embedding at pad_idx is not updated during training.
init_fn (Callable[[StandardEmbedding], None] | None) – The callable to initialize the embedding table.
- static from_embedding(embed, gang)[source]¶
Constructs a ShardedEmbedding by sharding embed.
- Parameters:
embed (StandardEmbedding) – The embedding to shard.
gang (Gang) – The gang over which to shard embed.
- Return type:
ShardedEmbedding
- to_embedding(device=None)[source]¶
Converts this instance to a StandardEmbedding.
- Return type:
StandardEmbedding
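A hedged sketch tying ShardedEmbedding and its two conversion helpers together. A real tensor-parallel run obtains its Gang from the distributed setup; here FakeGang, fairseq2's single-process stand-in (whose exact constructor may vary across versions), is assumed so the snippet runs on one process:
import torch

from fairseq2.gang import FakeGang  # assumption: single-process stand-in gang
from fairseq2.nn import ShardedEmbedding, StandardEmbedding

gang = FakeGang(device=torch.device("cpu"))  # rank 0 of a size-1 gang

embed = StandardEmbedding(num_embeddings=32000, embed_dim=512, pad_idx=0)

# Shard the table across the embedding dimension; each rank holds
# embed_dim // gang.size columns (all 512 with a single-rank gang).
sharded = ShardedEmbedding.from_embedding(embed, gang)

# A sharded table can also be built directly:
# ShardedEmbedding(gang, num_embeddings=32000, embed_dim=512, pad_idx=0)

# Gather the shards back into a regular in-memory table, e.g. for export.
full = sharded.to_embedding()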
- fairseq2.nn.init_scaled_embedding(embed)[source]¶
Initializes embed from \(\mathcal{N}(0, \frac{1}{\text{embed_dim}})\).
Example Usage:
import torch

from fairseq2.nn import BatchLayout, StandardEmbedding, init_scaled_embedding

# Create token embeddings.
embed = StandardEmbedding(
    num_embeddings=32000,  # vocabulary size
    embed_dim=512,
    pad_idx=0,
    init_fn=init_scaled_embedding,
)

# Describe per-sequence lengths with BatchLayout for downstream modules.
tokens = torch.randint(0, 32000, (4, 6))  # (batch, seq)
batch_layout = BatchLayout.of(tokens, seq_lens=[4, 2, 3, 5])

embeddings = embed(tokens)  # (4, 6, 512)
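A quick sanity check on the scaled initialization above (sample statistics vary from run to run; the weight attribute is assumed to expose the table, as with torch.nn.Embedding):
# The empirical std of the table should be close to embed_dim ** -0.5
# (512 ** -0.5 ≈ 0.044), i.e. variance ≈ 1 / embed_dim.
print(embed.weight.std().item())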