fairseq2.models.olmo

The OLMo module provides support for OLMo2 and OLMo3 language models from the Allen Institute for AI. It includes model configurations, hub access, tokenizers, and utilities for loading and working with OLMo models.

Quick Start

from fairseq2.models.olmo import get_olmo_model_hub, load_olmo_tokenizer

# Get the model hub
hub = get_olmo_model_hub()

# List available architectures
for arch in sorted(hub.get_archs()):
    print(f"  - {arch}")

# Load a model
model = hub.load_model("olmo2_7b")

# Load corresponding tokenizer
tokenizer = load_olmo_tokenizer("olmo2_7b")

Available Models

OLMo2 Series — standard causal attention, 4K context:

  • olmo2_1b - 1B parameters

  • olmo2_7b - 7B parameters

  • olmo2_13b - 13B parameters

  • olmo2_32b - 32B parameters (GQA)

OLMo3 Series — hybrid sliding window + full attention, 8K–65K context with YaRN:

  • olmo3_7b - 7B parameters

  • olmo3_32b - 32B parameters (GQA)

Model Configuration

OLMOConfig

class fairseq2.models.olmo.OLMOConfig(*, model_dim: int = 2048, max_seq_len: int = 4096, vocab_size: int = 100352, pad_idx: int = 100277, bos_token_id: int | None = None, eos_token_id: int = 100257, tied_embeddings: bool = False, num_layers: int = 16, num_attn_heads: int = 16, num_key_value_heads: int = 16, ffn_inner_dim: int = 8192, rms_norm_eps: float = 1e-06, rope_theta: float = 500000.0, dropout_p: float = 0.0, init_std: float | None = None, init_std_scale: Literal['none', 'layer', 'stack'] = 'layer', shard_embed_dim: bool = True, sliding_window: int | None = None, layer_types: list[Literal['sliding_attention', 'full_attention']] | None = None, yarn_scale_config: YaRNScaleConfig | None = None)[source]

Bases: object

Holds the configuration of an OLMO model (OLMO2 and OLMO3).

This configuration supports both OLMO2 and OLMO3 architectures. Relative to LLaMA, OLMO makes architecture choices such as post-norm residual connections, Q/K normalization, and optional hybrid sliding window attention (OLMO3). The default values correspond to the OLMO2 1B (allenai/OLMo-2-0425-1B) base architecture.

OLMO2: standard causal attention, 4K context.
OLMO3: hybrid sliding window + full attention, 8K-65K context.

References:

  • OLMO2: https://arxiv.org/abs/2501.00656

  • HuggingFace: https://huggingface.co/allenai/OLMo-2-0425-1B

Key Parameters:

  • model_dim - Model dimensionality (default: 2048)

  • num_layers - Number of decoder layers (default: 16)

  • num_attn_heads - Number of attention heads (default: 16)

  • num_key_value_heads - Key/value heads for GQA; equals num_attn_heads for MHA (default: 16)

  • max_seq_len - Maximum sequence length (default: 4096)

  • vocab_size - Vocabulary size (default: 100,352)

  • sliding_window - Sliding window size for OLMo3 hybrid attention; None for OLMo2 (default: None)

  • yarn_scale_config - YaRN scaling for OLMo3 long-context models (default: None)

model_dim: int = 2048

The dimensionality of the model.

max_seq_len: int = 4096

The maximum sequence length.

vocab_size: int = 100352

The size of the vocabulary.

pad_idx: int = 100277

The index of the PAD token in the vocabulary.

bos_token_id: int | None = None

The index of the BOS token in the vocabulary.

eos_token_id: int = 100257

The index of the EOS token in the vocabulary.

tied_embeddings: bool = False

If True, ties the embedding table and the output projection layer.

num_layers: int = 16

The number of decoder layers.

num_attn_heads: int = 16

The number of attention heads in decoder layers.

num_key_value_heads: int = 16

The number of key/value heads for Grouped Query Attention.

OLMO2 models use MHA, but the 32B variant uses GQA. OLMO3 7B uses MHA, OLMO3 32B uses GQA.

If num_key_value_heads == num_attn_heads, MHA is used. If num_key_value_heads == 1, MQA is used. Otherwise GQA is used.
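The head-count rule above can be sketched as a small helper (the function name `attention_variant` is illustrative, not part of fairseq2):

```python
def attention_variant(num_attn_heads: int, num_key_value_heads: int) -> str:
    """Classify the attention style implied by the head counts."""
    if num_key_value_heads == num_attn_heads:
        return "MHA"  # each query head has its own key/value head
    if num_key_value_heads == 1:
        return "MQA"  # all query heads share a single key/value head
    return "GQA"  # query heads are grouped over fewer key/value heads


print(attention_variant(16, 16))  # OLMO2 1B defaults -> MHA
print(attention_variant(64, 8))   # a GQA configuration
```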

ffn_inner_dim: int = 8192

The inner dimensionality of feed-forward networks.

Unlike LLaMA which derives the FFN dimension from base dim × scale × multiplier, OLMO directly specifies the final FFN inner dimension (matching HuggingFace intermediate_size). No additional scaling or rounding is applied.

rms_norm_eps: float = 1e-06

The epsilon value for RMSNorm layers.

rope_theta: float = 500000.0

The coefficient of the long-term decay of the Rotary position encoder.

dropout_p: float = 0.0

The dropout probability on outputs of Transformer layers.

init_std: float | None = None

If not None, the standard deviation to initialize input embeddings and projection weights; otherwise, model_dim ** -0.5 will be used instead.

init_std_scale: Literal['none', 'layer', 'stack'] = 'layer'

The method to use to scale init_std per layer. If ‘none’, no scaling will be applied. If ‘layer’, init_std will be scaled by the depth of the layer. If ‘stack’, init_std will be scaled by the total depth of the decoder.

shard_embed_dim: bool = True

If True, the embedding dimension is sharded across devices.

sliding_window: int | None = None

Sliding window size for local attention (OLMO3 only).

If set, enables a hybrid attention pattern in which most layers use sliding window attention with this window size; every 4th layer uses full global attention, and the final layer always uses full global attention.

OLMO3 uses sliding_window=4096 for efficient long-context processing. If None, all layers use full causal attention (OLMO2 behavior).

layer_types: list[Literal['sliding_attention', 'full_attention']] | None = None

Per-layer attention type configuration (OLMO3 only).

Explicitly specifies whether each layer uses ‘sliding_attention’ or ‘full_attention’. If None and sliding_window is set, automatically generates the pattern: 3 sliding window layers, 1 full attention layer, with the final layer always using full attention.

Length must match num_layers if specified.
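The automatic pattern described above (3 sliding window layers, then 1 full attention layer, with the final layer always full) can be sketched as follows. The helper name `default_layer_types` and the exact placement of the full-attention layer within each group of four are assumptions; fairseq2's internal generation may differ in detail:

```python
def default_layer_types(num_layers: int) -> list[str]:
    """Assumed pattern: every 4th layer is full attention, plus the final layer."""
    types = []
    for i in range(num_layers):
        if (i + 1) % 4 == 0 or i == num_layers - 1:
            types.append("full_attention")
        else:
            types.append("sliding_attention")
    return types


# 3:1 sliding-to-full ratio, with a full-attention final layer:
print(default_layer_types(8))
```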

yarn_scale_config: YaRNScaleConfig | None = None

YaRN scaling configuration for long-context models (OLMO3 only).

Enables YaRN (Yet another RoPE extensioN) scaling to extend context length from 8K to 65K. When set, ALL layers (both sliding window and full attention) share the same YaRN-scaled RoPE encoder, matching the HuggingFace behavior where a single RotaryEmbedding is shared.

If None, uses standard RoPE without scaling (default for OLMO2/3 base models).

YaRNScaleConfig

class fairseq2.models.olmo.YaRNScaleConfig(*, scale_factor: float = 8.0, original_max_seq_len: int = 8192, beta_fast: float = 32.0, beta_slow: float = 1.0, mscale: float = 1.0, mscale_all_dim: float = 0.0, truncate: bool = True)[source]

Bases: object

YaRN (Yet another RoPE extensioN) scaling configuration for long-context models.

YaRN is applied to extend the context length of OLMO3 models from 8K to 65K tokens.

Reference: https://arxiv.org/abs/2309.00071

scale_factor: float = 8.0

Ratio between extended and original max sequence length (65536/8192).

original_max_seq_len: int = 8192

Original sequence length before YaRN extension.

beta_fast: float = 32.0

Parameter to set the boundary for extrapolation (high frequency).

beta_slow: float = 1.0

Parameter to set the boundary for interpolation (low frequency).

mscale: float = 1.0

Multiplier for attention scaling to maintain training stability.

mscale_all_dim: float = 0.0

Dimension-wise scaling parameter for YaRN.

truncate: bool = True

If True, truncate correction range bounds to integers. Default: True.
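beta_fast and beta_slow bound which RoPE dimensions are extrapolated versus interpolated, and truncate controls whether those bounds are rounded to integers. A sketch of the correction-range computation, following the YaRN paper's reference implementation (whether fairseq2 mirrors it exactly is an assumption):

```python
import math


def find_correction_dim(num_rotations: float, dim: int, base: float,
                        max_seq_len: int) -> float:
    """RoPE dimension that completes `num_rotations` rotations over `max_seq_len`."""
    return (dim * math.log(max_seq_len / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def find_correction_range(beta_fast: float, beta_slow: float, dim: int, base: float,
                          max_seq_len: int, truncate: bool = True) -> tuple[int, int]:
    """Range of dimensions blended between interpolation and extrapolation."""
    low = find_correction_dim(beta_fast, dim, base, max_seq_len)
    high = find_correction_dim(beta_slow, dim, base, max_seq_len)
    if truncate:
        low, high = math.floor(low), math.ceil(high)
    return max(int(low), 0), min(int(high), dim - 1)


# Assumed head dimension of 128, with rope_theta=500000 and the defaults above.
print(find_correction_range(32.0, 1.0, 128, 500000.0, 8192))
```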

Tokenizer

OLMOTokenizer

final class fairseq2.models.olmo.OLMOTokenizer(model: HuggingFaceTokenModel, eos_token: str)[source]

Bases: Tokenizer

create_encoder(*, task: str | None = None, lang: str | None = None, mode: str | None = None, device: device | None = None, pin_memory: bool = False) → TokenEncoder[source]
create_raw_encoder(*, device: device | None = None, pin_memory: bool = False) → TokenEncoder[source]
create_decoder(*, skip_special_tokens: bool = False) → TokenDecoder[source]
property vocab_info: VocabularyInfo

OLMOTokenizerConfig

class fairseq2.models.olmo.OLMOTokenizerConfig(*, use_im_end: bool = False)[source]

Bases: object

Configuration for OLMO tokenizer.

use_im_end: bool = False

If True, use <|im_end|> as the EOS token (for chat/instruct models). If False, use <|endoftext|> (default, for base models).

load_olmo_tokenizer

fairseq2.models.olmo.load_olmo_tokenizer(path: Path, config: OLMOTokenizerConfig) → Tokenizer[source]

Hub

get_olmo_model_hub

fairseq2.models.olmo.get_olmo_model_hub = <fairseq2.models.hub.ModelHubAccessor object>

Creates a ModelHub instance when called.

This class provides a strongly-typed way to access model hubs. Its direct use is meant for model authors rather than library users.

See src/fairseq2/models/llama/hub.py as an example.

The use of ModelHubAccessor for model authors
from fairseq2.models import ModelHubAccessor

# Defined in the Python module where the model is implemented.
get_my_model_hub = ModelHubAccessor(
    family_name="my_model_family", kls=MyModel, config_kls=MyModelConfig
)

# `get_my_model_hub()` is treated as a standalone function by the model
# users in other parts of the code like below:
model_config = MyModelConfig()

model = get_my_model_hub().create_new_model(model_config)

Returns the model hub accessor for OLMo models.

from fairseq2.models.olmo import get_olmo_model_hub

hub = get_olmo_model_hub()
model = hub.load_model("olmo2_7b", device=device)

Constants

OLMO_FAMILY

fairseq2.models.olmo.OLMO_FAMILY = "olmo"

The family name identifier for OLMo models.

Complete Examples

Basic Model Usage

import torch

from fairseq2.device import get_default_device
from fairseq2.models.olmo import get_olmo_model_hub, load_olmo_tokenizer
from fairseq2.nn import BatchLayout

device = get_default_device()

hub = get_olmo_model_hub()
model = hub.load_model("olmo2_7b", device=device)
tokenizer = load_olmo_tokenizer("olmo2_7b")

texts = ["The capital of France is", "The capital of Germany is"]
encoder = tokenizer.create_encoder()
tokens = torch.vstack([encoder(text) for text in texts]).to(device)

model.eval()
with torch.inference_mode():
    seqs_layout = BatchLayout.of(tokens)
    output = model(tokens, seqs_layout=seqs_layout)

Custom Architecture

from fairseq2.models.olmo import get_olmo_model_hub

hub = get_olmo_model_hub()

config = hub.get_arch_config("olmo2_7b")
config.max_seq_len = 2048
config.dropout_p = 0.1

model = hub.create_new_model(config)

See Also