Batch Layout

final class fairseq2.nn.BatchLayout(shape, seq_lens, *, packed=False, device=None)

Bases: object

static of(batch, seq_lens=None, *, packed=False)
Return type: BatchLayout

property width: int
property seq_begin_indices: Sequence[int]
property seq_begin_indices_pt: Tensor
property seq_lens: Sequence[int]
property seq_lens_pt: Tensor
property min_seq_len: int
compiled_max_seq_len: ClassVar[int | None] = None
property max_seq_len: int
property position_indices: Tensor
property padded: bool
property packed: bool

BatchLayout is the cornerstone of fairseq2 v0.5’s unified batching system. It consolidates the padded and packed batching strategies under a single, consistent API, making it easier to work with variable-length sequences efficiently. A short mode-agnostic sketch follows the feature list below.

Key Features:

  • Unified API: Single interface for both padded and packed batch modes

  • Memory Efficient: Optimized handling of variable-length sequences

  • Torch.compile Compatible: Full integration with PyTorch’s compilation system

  • Dynamic Sequences: Support for dynamic sequence lengths during training
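
Because both modes expose the same properties, code that consumes a BatchLayout rarely needs to branch on the mode. A minimal sketch, using only the properties documented above:

from fairseq2.nn import BatchLayout

def count_valid_tokens(layout: BatchLayout) -> int:
    # seq_lens reports the true per-sequence lengths in both modes
    return sum(layout.seq_lens)

padded = BatchLayout((4, 6), seq_lens=[4, 2, 3, 5], packed=False)
packed = BatchLayout((14,), seq_lens=[4, 2, 3, 5], packed=True)

assert count_valid_tokens(padded) == count_valid_tokens(packed) == 14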

Creating BatchLayout

For Padded Batches:

from fairseq2.nn import BatchLayout
from fairseq2.device import get_default_device
import torch

device = get_default_device()  # "cuda" or "cpu"
# Create a padded batch layout
# Shape: (batch_size=4, max_seq_len=6)
# Individual sequence lengths: [4, 2, 3, 5]
batch_layout = BatchLayout(
    shape=(4, 6),
    seq_lens=[4, 2, 3, 5],
    packed=False,
    device=device
)

print(f"Width: {batch_layout.width}")  # 6
print(f"Sequence lengths: {list(batch_layout.seq_lens)}")  # [4, 2, 3, 5]
print(f"Is padded: {batch_layout.padded}")  # True
print(f"Is packed: {batch_layout.packed}")  # False

For Packed Batches:

# Create a packed batch layout
# Total elements: 14 (4+2+3+5), sequences: [4, 2, 3, 5]
packed_layout = BatchLayout(
    shape=(14,),  # 1D shape for packed mode
    seq_lens=[4, 2, 3, 5],
    packed=True,
    device=torch.device("cpu")
)

print(f"Sequence begin indices: {list(packed_layout.seq_begin_indices)}")
# [0, 4, 6, 9, 14]
print(f"Is packed: {packed_layout.packed}")  # True

# For packed batches, get sequence boundaries
if packed_layout.packed:
    seq_boundaries = packed_layout.seq_begin_indices_pt
    print(f"Sequence boundaries: {seq_boundaries}")

From Existing Tensors:

# Create from existing batch tensor
batch_tensor = torch.randn(4, 6, 512)  # (batch, seq, features)
batch_layout = BatchLayout.of(
    batch_tensor,
    seq_lens=[4, 2, 3, 5],
    packed=False
)
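
Presumably, of can be used the same way for packed inputs; the following is an unverified sketch that assumes it mirrors the constructor's packed mode:

# Assumption: `of` accepts a flattened 1D-batch tensor when packed=True
packed_tensor = torch.randn(14, 512)  # (total_elements, features)
packed_layout = BatchLayout.of(
    packed_tensor,
    seq_lens=[4, 2, 3, 5],
    packed=True
)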

Working with Position Indices and Masks

BatchLayout automatically computes position indices and masking information:

batch_layout = BatchLayout((4, 6), seq_lens=[4, 2, 3, 5])

# Position indices for each element (-1 indicates padding)
pos_indices = batch_layout.position_indices
# Shape: (4, 6)
# [[0, 1, 2, 3, -1, -1],
#  [0, 1, -1, -1, -1, -1],
#  [0, 1, 2, -1, -1, -1],
#  [0, 1, 2, 3, 4, -1]]

# Create padding mask (True for valid positions)
padding_mask = pos_indices >= 0

# Apply mask to hide padding positions
from fairseq2.nn.utils.mask import apply_mask
masked_batch = apply_mask(batch_tensor, padding_mask, fill_value=0.0)
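
A quick sanity check ties the mask back to the layout: the number of True entries must equal the total number of valid tokens.

# 4 + 2 + 3 + 5 = 14 valid positions in total
assert int(padding_mask.sum()) == sum(batch_layout.seq_lens) == 14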

Sequence Information

BatchLayout provides comprehensive sequence metadata:

batch_layout = BatchLayout((4, 6), seq_lens=[4, 2, 3, 5])

# Sequence properties
print(f"Min sequence length: {batch_layout.min_seq_len}")  # 2
print(f"Max sequence length: {batch_layout.max_seq_len}")  # 5
print(f"Batch width: {batch_layout.width}")  # 6

# Sequence lengths as tensors (for GPU operations)
seq_lens_tensor = batch_layout.seq_lens_pt  # tensor([4, 2, 3, 5])
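
As an example of why the tensor form matters, per-sequence mean pooling can divide masked sums by seq_lens_pt without leaving the device. A minimal sketch using only standard torch ops:

batch_tensor = torch.randn(4, 6, 512)
valid = (batch_layout.position_indices >= 0).unsqueeze(-1)  # (4, 6, 1)

# Zero out padding, sum over the time axis, divide by the true lengths
summed = (batch_tensor * valid).sum(dim=1)                  # (4, 512)
pooled = summed / batch_layout.seq_lens_pt.unsqueeze(-1)    # (4, 512)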

Integration with Neural Network Layers

BatchLayout is designed to work seamlessly with all fairseq2 neural network layers:

import torch.nn as nn
from fairseq2.nn import BatchLayout

# Example attention layer that works with BatchLayout
class AttentionLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

    def forward(self, x, batch_layout: BatchLayout):
        if batch_layout.packed:
            # Packed batches need boundary-aware (variable-length) attention
            # kernels, which nn.MultiheadAttention does not provide.
            raise NotImplementedError("packed batches are not supported here")

        # key_padding_mask: True marks padding positions to ignore
        key_padding_mask = batch_layout.position_indices < 0

        # Standard padded attention with mask
        output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        return output
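
A quick usage example for the layer above (padded mode only, per the forward implementation):

import torch

layer = AttentionLayer(dim=512)
batch_layout = BatchLayout((4, 6), seq_lens=[4, 2, 3, 5])
x = torch.randn(4, 6, 512)

out = layer(x, batch_layout)  # (4, 6, 512); padding positions are ignored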

Torch.compile Integration

BatchLayout is fully compatible with PyTorch’s compilation system:

import torch
from fairseq2.nn import BatchLayout

@torch.compile
def process_batch(batch_tensor, batch_layout: BatchLayout):
    # Position indices are automatically marked as dynamic
    pos_indices = batch_layout.position_indices

    # Use compiled operations with dynamic sequences
    mask = pos_indices >= 0
    return batch_tensor.masked_fill(~mask.unsqueeze(-1), 0.0)

# The compiled function handles dynamic sequence lengths efficiently
batch_layout = BatchLayout((4, 6), seq_lens=[4, 2, 3, 5])
batch_tensor = torch.randn(4, 6, 512)
result = process_batch(batch_tensor, batch_layout)
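
Because the position indices are marked dynamic, a second call with different sequence lengths should reuse the compiled graph rather than trigger a recompile:

# Different batch width and lengths, same compiled function
other_layout = BatchLayout((4, 8), seq_lens=[8, 1, 5, 7])
other_tensor = torch.randn(4, 8, 512)
other_result = process_batch(other_tensor, other_layout)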

Performance Considerations

Packed vs Padded Trade-offs:

  • Packed: More memory efficient and better suited to variable-length sequences, but requires careful boundary indexing

  • Padded: Simpler operations that work directly with standard attention implementations, but may waste memory on padding positions

# Memory comparison
seq_lens = [100, 50, 75, 25]  # variable-length sequences

# Padded: allocates max_seq_len positions for every sequence
padded_layout = BatchLayout((4, 100), seq_lens=seq_lens, packed=False)
padded_positions = 4 * 100  # 400 positions

# Packed: allocates only the positions that are actually used
packed_layout = BatchLayout((250,), seq_lens=seq_lens, packed=True)
packed_positions = sum(seq_lens)  # 250 positions

print(f"Savings: {1 - packed_positions / padded_positions:.1%}")  # 37.5%
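
One way to act on this trade-off is a small heuristic that switches to packed mode once padding waste crosses a threshold. A hypothetical helper (not part of fairseq2):

def prefer_packed(seq_lens: list[int], waste_threshold: float = 0.25) -> bool:
    # Fraction of a padded layout that would be padding
    padded_positions = len(seq_lens) * max(seq_lens)
    waste = 1.0 - sum(seq_lens) / padded_positions
    return waste > waste_threshold

print(prefer_packed([100, 50, 75, 25]))  # True: 37.5% waste > 25%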