fairseq2.models.qwen

The Qwen module provides support for Qwen2.5 and Qwen3 language models. It includes model configurations, hub access, tokenizers, and utilities for loading and working with Qwen models.

Quick Start

from fairseq2.models.qwen import get_qwen_model_hub, get_qwen_tokenizer_hub

# Get the model hub
hub = get_qwen_model_hub()

# List available architectures
print("Available Qwen architectures:")
for arch in sorted(hub.get_archs()):
    print(f"  - {arch}")

# Load a model
model = hub.load_model("qwen25_7b")

# Load corresponding tokenizer
tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")

# Generate some text
text = "The future of AI is"
encoder = tokenizer.create_encoder()
encoded = encoder(text)
# ... model inference code ...

Available Models

The Qwen family includes several model sizes and versions (the snippet after the lists shows how to check which names are registered in your installation):

Qwen 2.5 Series:

  • qwen25_1_5b - 1.5B parameters
  • qwen25_3b - 3B parameters
  • qwen25_7b - 7B parameters
  • qwen25_14b - 14B parameters
  • qwen25_32b - 32B parameters

Qwen 3 Series:

  • qwen3_0.6b - 0.6B parameters
  • qwen3_1.7b - 1.7B parameters
  • qwen3_4b - 4B parameters
  • qwen3_8b - 8B parameters
  • qwen3_14b - 14B parameters
  • qwen3_32b - 32B parameters
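
The exact set of registered architectures depends on your fairseq2 version; the hub's get_archs() method (used in the Quick Start above) lets you verify the names programmatically:

from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# Print the Qwen 2.5 and Qwen 3 architectures registered in this installation
for arch in sorted(hub.get_archs()):
    if arch.startswith(("qwen25", "qwen3")):
        print(arch)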

Model Hub

get_qwen_model_hub

fairseq2.models.qwen.get_qwen_model_hub()

Returns the model hub for Qwen models, providing access to all model operations.

from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# List all available Qwen models
for card in hub.iter_cards():
    print(f"Model: {card.name}")

# Get specific architecture config
config = hub.get_arch_config("qwen25_7b")
print(f"Model dimensions: {config.model_dim}")
print(f"Number of layers: {config.num_layers}")
print(f"Attention heads: {config.num_attn_heads}")
Return type:

ModelHub[ModelT, ModelConfigT]

Model Configuration

QwenConfig

class fairseq2.models.qwen.QwenConfig(*, model_dim: 'int' = 3584, max_seq_len: 'int' = 32768, vocab_size: 'int' = 152064, tied_embeddings: 'bool' = False, num_layers: 'int' = 28, num_attn_heads: 'int' = 28, num_key_value_heads: 'int' = 4, head_dim: 'int | None' = None, qkv_proj_bias: 'bool' = True, q_norm: 'bool' = False, k_norm: 'bool' = False, ffn_inner_dim: 'int' = 18944, rope_theta: 'float' = 1000000.0, dropout_p: 'float' = 0.0)[source]

Bases: object

Configuration class for Qwen models. Defines the architecture parameters such as model dimensions, number of layers, attention heads, and other architectural choices.

Key Parameters:

  • model_dim - The dimensionality of the model (default: 3584)

  • num_layers - Number of decoder layers (default: 28)

  • num_attn_heads - Number of attention heads (default: 28)

  • num_key_value_heads - Number of key/value heads for GQA (default: 4)

  • max_seq_len - Maximum sequence length (default: 32,768)

  • vocab_size - Vocabulary size (default: 152,064)

Example:

from fairseq2.models.qwen import QwenConfig

# Create custom configuration
config = QwenConfig()
config.model_dim = 4096
config.num_layers = 32
config.num_attn_heads = 32
config.max_seq_len = 16384

# Or get pre-defined architecture
from fairseq2.models.qwen import get_qwen_model_hub
hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")
model_dim: int = 3584

The dimensionality of the model.

max_seq_len: int = 32768

The maximum sequence length.

vocab_size: int = 152064

The size of the vocabulary.

tied_embeddings: bool = False

If True, ties the embedding table and the output projection layer.

num_layers: int = 28

The number of decoder layers.

num_attn_heads: int = 28

The number of attention heads in decoder layers.

num_key_value_heads: int = 4

The number of key/value heads for Grouped Query Attention.

head_dim: int | None = None

The dimensionality of attention heads. If None, uses the standard formula model_dim // num_attn_heads.

qkv_proj_bias: bool = True

If True, query, key, and value projections learn an additive bias.

q_norm: bool = False

If True, applies Layer Normalization to projected attention queries.

k_norm: bool = False

If True, applies Layer Normalization to projected attention keys.

ffn_inner_dim: int = 18944

The dimensionality of inner projection layers in feed-forward networks.

rope_theta: float = 1000000.0

The theta coefficient (frequency base) of the Rotary position encoder, which controls the long-term decay of positional information.

dropout_p: float = 0.0

The dropout probability on outputs of Transformer layers.
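
Several of these fields interact: when head_dim is None, the per-head dimensionality falls back to model_dim // num_attn_heads, and num_attn_heads should be divisible by num_key_value_heads for Grouped Query Attention. A small sanity-check sketch of the derived values:

from fairseq2.models.qwen import QwenConfig

config = QwenConfig()

# Derived per-head dimensionality (the documented fallback)
head_dim = config.head_dim or config.model_dim // config.num_attn_heads
print(f"head_dim: {head_dim}")  # 3584 // 28 = 128

# GQA group size: how many query heads share each key/value head
group_size = config.num_attn_heads // config.num_key_value_heads
print(f"GQA group size: {group_size}")  # 28 // 4 = 7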

Model Factory

QwenFactory

class fairseq2.models.qwen.QwenFactory(config)[source]

Bases: object

Factory class for creating Qwen models. Handles model instantiation and checkpoint loading.
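
A minimal sketch of direct factory use; the create_model() call is an assumption based on fairseq2's factory conventions, and create_qwen_model() below is the documented convenience entry point:

from fairseq2.models.qwen import QwenConfig, QwenFactory

config = QwenConfig()

# Hypothetical method name: fairseq2 factories conventionally expose create_model()
factory = QwenFactory(config)
model = factory.create_model()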

create_qwen_model

fairseq2.models.qwen.create_qwen_model(config)[source]

Creates a Qwen model instance with the specified configuration.

from fairseq2.models.qwen import create_qwen_model, QwenConfig

config = QwenConfig()
config.model_dim = 2048
config.num_layers = 24

model = create_qwen_model(config)
Return type:

TransformerLM

Tokenizer

QwenTokenizer

final class fairseq2.models.qwen.QwenTokenizer(model, eos_token)[source]

Bases: Tokenizer

Tokenizer for Qwen models. Handles text encoding and decoding using the Qwen vocabulary.
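
A typical encode/decode round trip using the methods documented below:

from fairseq2.models.qwen import get_qwen_tokenizer_hub

tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")

encoder = tokenizer.create_encoder()
decoder = tokenizer.create_decoder()

# Encode text into a tensor of token indices, then decode it back
indices = encoder("Hello, Qwen!")
print(decoder(indices))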

create_encoder(*, task=None, lang=None, mode=None, device=None, pin_memory=False)[source]

Constructs a token encoder.

The valid arguments for the task, lang, and mode parameters are implementation specific. Refer to concrete Tokenizer subclasses for more information.

Parameters:
  • task (str | None) – The task for which to generate token indices. Typically, task is used to distinguish between different tasks such as ‘translation’ or ‘transcription’.

  • lang (str | None) – The language of generated token indices. Typically, multilingual translation tasks use lang to distinguish between different languages such as ‘en-US’ or ‘de-DE’.

  • mode (str | None) – The mode in which to generate token indices. Typically, translation tasks use mode to distinguish between different modes such as ‘source’ or ‘target’.

  • device (device | None) – The device on which to construct tensors.

  • pin_memory (bool) – If True, uses pinned memory while constructing tensors.

Return type:

TokenEncoder

create_raw_encoder(*, device=None, pin_memory=False)[source]

Constructs a raw token encoder with no control symbols.

Parameters:
  • device (device | None) – The device on which to construct tensors.

  • pin_memory (bool) – If True, uses pinned memory for tensors.

Return type:

TokenEncoder

create_decoder(*, skip_special_tokens=False)[source]

Constructs a token decoder.

Return type:

TokenDecoder

property vocab_info: VocabularyInfo

The vocabulary information associated with the tokenizer.
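
For example, inspecting the vocabulary (field names follow fairseq2's VocabularyInfo):

from fairseq2.models.qwen import get_qwen_tokenizer_hub

tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")

info = tokenizer.vocab_info
print(f"Vocabulary size: {info.size}")
print(f"EOS index: {info.eos_idx}")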

QwenTokenizerConfig

class fairseq2.models.qwen.QwenTokenizerConfig(*, use_im_end: 'bool' = False)[source]

Bases: object

Configuration for the Qwen tokenizer.
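
A minimal sketch; the use_im_end flag presumably selects <|im_end|> rather than <|endoftext|> as the end-of-sequence token (an assumption based on the flag name, relevant for chat-style checkpoints):

from fairseq2.models.qwen import QwenTokenizerConfig

# use_im_end is assumed here to switch the EOS token to <|im_end|>
config = QwenTokenizerConfig(use_im_end=True)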

get_qwen_tokenizer_hub

fairseq2.models.qwen.get_qwen_tokenizer_hub()

Returns the tokenizer hub for Qwen tokenizers.

from fairseq2.models.qwen import get_qwen_tokenizer_hub

tokenizer_hub = get_qwen_tokenizer_hub()

# Load tokenizer through hub
tokenizer = tokenizer_hub.load_tokenizer("qwen25_7b")
Return type:

TokenizerHub[TokenizerT, TokenizerConfigT]

Interoperability

convert_qwen_state_dict

fairseq2.models.qwen.convert_qwen_state_dict(state_dict, config)[source]

Converts a Qwen state dictionary from the Hugging Face format to the fairseq2 format, as defined by the given configuration.

from fairseq2.models.qwen import convert_qwen_state_dict, get_qwen_model_hub
import torch

# Load a checkpoint stored in Hugging Face format
hf_state_dict = torch.load("qwen_hf_checkpoint.pt")

# Convert to fairseq2 format, using the matching architecture config
config = get_qwen_model_hub().get_arch_config("qwen25_7b")
fs2_state_dict = convert_qwen_state_dict(hf_state_dict, config)
Return type:

dict[str, object]

export_qwen

fairseq2.models.qwen.export_qwen(state_dict, config)[source]

Exports a fairseq2 Qwen state dictionary to the Hugging Face format for interoperability.
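
For example, exporting a loaded model's weights to the Hugging Face layout (only the documented signature is used here; the fields of the returned HuggingFaceExport are not shown):

from fairseq2.models.qwen import export_qwen, get_qwen_model_hub

hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")
model = hub.load_model("qwen25_7b")

# Convert the fairseq2 state dict into a Hugging Face export
export = export_qwen(model.state_dict(), config)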

Return type:

HuggingFaceExport

Sharding

get_qwen_shard_specs

fairseq2.models.qwen.get_qwen_shard_specs(config)[source]

Returns sharding specifications for distributed training and inference of Qwen models.

from fairseq2.models.qwen import get_qwen_shard_specs, QwenConfig

config = QwenConfig()
shard_specs = get_qwen_shard_specs(config)
Return type:

dict[str, ShardSpec]
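
The returned mapping can be inspected to see which parameters are covered by a sharding spec (the exact ShardSpec fields depend on your fairseq2 version):

from fairseq2.models.qwen import get_qwen_shard_specs, QwenConfig

config = QwenConfig()

# The keys of the mapping identify the sharded parameters
for name in get_qwen_shard_specs(config):
    print(name)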

Constants

QWEN_FAMILY

fairseq2.models.qwen.QWEN_FAMILY = "qwen"

The family name identifier for Qwen models.

Complete Examples

Basic Model Usage

import torch

from fairseq2.models.qwen import get_qwen_model_hub, get_qwen_tokenizer_hub
from fairseq2.device import get_default_device
from fairseq2.nn import BatchLayout

device = get_default_device()

# Load model and tokenizer
hub = get_qwen_model_hub()
model = hub.load_model("qwen25_7b", device=device)
tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")

# Prepare input (torch.vstack requires the prompts to tokenize to the same length)
texts = ["The capital of France is", "The capital of Germany is"]
encoder = tokenizer.create_encoder()
tokens = torch.vstack([encoder(text) for text in texts]).to(device)

# Run inference (simplified)
model.eval()
with torch.inference_mode():
    seqs_layout = BatchLayout.of(tokens)
    output = model(tokens, seqs_layout=seqs_layout)
    # Process output...
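
    # A hedged sketch of greedy next-token selection: this assumes the model
    # returns a logits tensor of shape (batch_size, seq_len, vocab_size).
    next_tokens = output[:, -1, :].argmax(dim=-1)

    # Decode each predicted token back to text.
    decoder = tokenizer.create_decoder()
    for token in next_tokens:
        print(decoder(token.unsqueeze(0)))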

Custom Architecture

from fairseq2.models.qwen import get_qwen_model_hub, QwenConfig

hub = get_qwen_model_hub()

# Get base configuration and modify
config = hub.get_arch_config("qwen25_7b")
config.max_seq_len = 16384  # Reduce sequence length
config.dropout_p = 0.1      # Add dropout

# Create model with custom config
model = hub.create_new_model(config)

Loading from Custom Checkpoint

from pathlib import Path
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")

# Load from custom checkpoint
checkpoint_path = Path("/path/to/my/qwen_checkpoint.pt")
model = hub.load_custom_model(checkpoint_path, config)

Architecture Comparison

from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# Compare different Qwen architectures
architectures = ["qwen25_3b", "qwen25_7b", "qwen25_14b"]

for arch in architectures:
    config = hub.get_arch_config(arch)
    # Rough estimate: attention (q/k/v/out with GQA), gated FFN (3 matrices),
    # plus embedding and output projection; norms and biases are ignored.
    head_dim = config.head_dim or config.model_dim // config.num_attn_heads
    qkv_heads = config.num_attn_heads + config.num_key_value_heads
    attn = 2 * config.model_dim * head_dim * qkv_heads
    ffn = 3 * config.model_dim * config.ffn_inner_dim
    embed = config.vocab_size * config.model_dim * (1 if config.tied_embeddings else 2)
    params = config.num_layers * (attn + ffn) + embed
    print(f"{arch}:")
    print(f"  Model dim: {config.model_dim}")
    print(f"  Layers: {config.num_layers}")
    print(f"  Attention heads: {config.num_attn_heads}")
    print(f"  Approx parameters: ~{params / 1e9:.1f}B")
    print()

See Also