fairseq2.models.qwen¶
The Qwen module provides support for Qwen2.5 and Qwen3 language models. It includes model configurations, hub access, tokenizers, and utilities for loading and working with Qwen models.
Quick Start¶
from fairseq2.models.qwen import get_qwen_model_hub, get_qwen_tokenizer_hub
# Get the model hub
hub = get_qwen_model_hub()
# List available architectures
print("Available Qwen architectures:")
for arch in sorted(hub.get_archs()):
print(f" - {arch}")
# Load a model
model = hub.load_model("qwen25_7b")
# Load corresponding tokenizer
tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")
# Generate some text
text = "The future of AI is"
encoder = tokenizer.create_encoder()
encoded = encoder(text)
# ... model inference code ...
Available Models¶
The Qwen family includes several model sizes and versions:
Qwen 2.5 Series:
- qwen25_1_5b: 1.5B parameters
- qwen25_3b: 3B parameters
- qwen25_7b: 7B parameters
- qwen25_14b: 14B parameters
- qwen25_32b: 32B parameters
Qwen 3 Series:
- qwen3_0.6b: 0.6B parameters
- qwen3_1.7b: 1.7B parameters
- qwen3_4b: 4B parameters
- qwen3_8b: 8B parameters
- qwen3_14b: 14B parameters
- qwen3_32b: 32B parameters
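Architecture names are plain strings resolved at load time, so a quick membership check against hub.get_archs() can catch typos before a load is attempted:
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# Verify that an architecture name from the list above is registered.
name = "qwen25_7b"
if name in hub.get_archs():
    config = hub.get_arch_config(name)
    print(f"{name}: dim={config.model_dim}, layers={config.num_layers}")
else:
    print(f"{name} is not a registered Qwen architecture")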
Model Hub¶
get_qwen_model_hub¶
- fairseq2.models.qwen.get_qwen_model_hub()¶
Returns the model hub for Qwen models, providing access to all model operations.
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()

# List all available Qwen models
for card in hub.iter_cards():
    print(f"Model: {card.name}")

# Get specific architecture config
config = hub.get_arch_config("qwen25_7b")
print(f"Model dimensions: {config.model_dim}")
print(f"Number of layers: {config.num_layers}")
print(f"Attention heads: {config.num_attn_heads}")
- Return type:
ModelHub[ModelT, ModelConfigT]
Model Configuration¶
QwenConfig¶
- class fairseq2.models.qwen.QwenConfig(*, model_dim: 'int' = 3584, max_seq_len: 'int' = 32768, vocab_size: 'int' = 152064, tied_embeddings: 'bool' = False, num_layers: 'int' = 28, num_attn_heads: 'int' = 28, num_key_value_heads: 'int' = 4, head_dim: 'int | None' = None, qkv_proj_bias: 'bool' = True, q_norm: 'bool' = False, k_norm: 'bool' = False, ffn_inner_dim: 'int' = 18944, rope_theta: 'float' = 1000000.0, dropout_p: 'float' = 0.0)[source]¶
Bases:
object
Configuration class for Qwen models. Defines the architecture parameters such as model dimensions, number of layers, attention heads, and other architectural choices.
Key Parameters:
- model_dim: The dimensionality of the model (default: 3584)
- num_layers: Number of decoder layers (default: 28)
- num_attn_heads: Number of attention heads (default: 28)
- num_key_value_heads: Number of key/value heads for grouped-query attention (GQA) (default: 4)
- max_seq_len: Maximum sequence length (default: 32,768)
- vocab_size: Vocabulary size (default: 152,064)
Example:
from fairseq2.models.qwen import QwenConfig

# Create custom configuration
config = QwenConfig()
config.model_dim = 4096
config.num_layers = 32
config.num_attn_heads = 32
config.max_seq_len = 16384

# Or get a pre-defined architecture
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")
- head_dim: int | None = None¶
The dimensionality of attention heads. If None, the standard formula model_dim // num_attn_heads is used.
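Since head_dim may be unset, code that needs the per-head width should apply the fallback itself. A minimal sketch:
from fairseq2.models.qwen import get_qwen_model_hub

hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")

# Apply the documented fallback when head_dim is None.
head_dim = config.head_dim or config.model_dim // config.num_attn_heads
print(f"Effective head dim: {head_dim}")  # 3584 // 28 = 128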
Model Factory¶
QwenFactory¶
create_qwen_model¶
- fairseq2.models.qwen.create_qwen_model(config)[source]¶
Creates a Qwen model instance with the specified configuration.
from fairseq2.models.qwen import create_qwen_model, QwenConfig

config = QwenConfig()
config.model_dim = 2048
config.num_layers = 24

model = create_qwen_model(config)
- Return type:
TransformerLM
Tokenizer¶
QwenTokenizer¶
- final class fairseq2.models.qwen.QwenTokenizer(model, eos_token)[source]¶
Bases:
Tokenizer
Tokenizer for Qwen models. Handles text encoding and decoding using the Qwen vocabulary.
- create_encoder(*, task=None, lang=None, mode=None, device=None, pin_memory=False)[source]¶
Constructs a token encoder.
The valid arguments for the task, lang, and mode parameters are implementation specific. Refer to concrete Tokenizer subclasses for more information.
- Parameters:
task (str | None) – The task for which to generate token indices. Typically, task is used to distinguish between different tasks such as ‘translation’ or ‘transcription’.
lang (str | None) – The language of generated token indices. Typically, multilingual translation tasks use lang to distinguish between different languages such as ‘en-US’ or ‘de-DE’.
mode (str | None) – The mode in which to generate token indices. Typically, translation tasks use mode to distinguish between different modes such as ‘source’ or ‘target’.
device (device | None) – The device on which to construct tensors.
pin_memory (bool) – If True, uses pinned memory while constructing tensors.
- Return type:
TokenEncoder
- create_raw_encoder(*, device=None, pin_memory=False)[source]¶
Constructs a raw token encoder with no control symbols.
- create_decoder(*, skip_special_tokens=False)[source]¶
Constructs a token decoder.
- Return type:
TokenDecoder
- property vocab_info: VocabularyInfo¶
The vocabulary information associated with the tokenizer.
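The encoder and decoder construction methods compose into a simple round trip. A minimal sketch, assuming the decoder returns the decoded string for a 1-D tensor of token indices:
from fairseq2.models.qwen import get_qwen_tokenizer_hub

tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")

encoder = tokenizer.create_encoder()
decoder = tokenizer.create_decoder(skip_special_tokens=True)

# Encode to a tensor of token indices, then decode back to text.
indices = encoder("Hello, Qwen!")
print(decoder(indices))

# Inspect the vocabulary metadata exposed by the tokenizer.
print(tokenizer.vocab_info)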
QwenTokenizerConfig¶
get_qwen_tokenizer_hub¶
- fairseq2.models.qwen.get_qwen_tokenizer_hub()¶
Returns the tokenizer hub for Qwen tokenizers.
from fairseq2.models.qwen import get_qwen_tokenizer_hub

tokenizer_hub = get_qwen_tokenizer_hub()

# Load tokenizer through the hub
tokenizer = tokenizer_hub.load_tokenizer("qwen25_7b")
- Return type:
TokenizerHub[TokenizerT, TokenizerConfigT]
Interoperability¶
convert_qwen_state_dict¶
- fairseq2.models.qwen.convert_qwen_state_dict(state_dict, config)[source]¶
Converts a Qwen model state dictionary from an external format (e.g., the Hugging Face format) to the fairseq2 format. The config argument identifies the architecture of the checkpoint being converted.
from fairseq2.models.qwen import convert_qwen_state_dict, get_qwen_model_hub
import torch

# Load a checkpoint saved in the Hugging Face format
hf_state_dict = torch.load("qwen_hf_checkpoint.pt")

# Convert to the fairseq2 format; the config must match the checkpoint's architecture
config = get_qwen_model_hub().get_arch_config("qwen25_7b")
fs2_state_dict = convert_qwen_state_dict(hf_state_dict, config)
export_qwen¶
Constants¶
QWEN_FAMILY¶
- fairseq2.models.qwen.QWEN_FAMILY = "qwen"¶
The family name identifier for Qwen models.
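One use of the constant is sanity-checking asset cards. The sketch below assumes cards store the family name under a model_family field; the field name is an assumption, not part of the documented API:
from fairseq2.models.qwen import QWEN_FAMILY, get_qwen_model_hub

hub = get_qwen_model_hub()

for card in hub.iter_cards():
    # "model_family" is an assumed card field name holding the family string.
    family = card.field("model_family").as_(str)
    assert family == QWEN_FAMILY, f"{card.name} is not a {QWEN_FAMILY} card"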
Complete Examples¶
Basic Model Usage¶
import torch
from fairseq2.models.qwen import get_qwen_model_hub, get_qwen_tokenizer_hub
from fairseq2.device import get_default_device
from fairseq2.nn import BatchLayout
device = get_default_device()
# Load model and tokenizer
hub = get_qwen_model_hub()
model = hub.load_model("qwen25_7b", device=device)
tokenizer = get_qwen_tokenizer_hub().load_tokenizer("qwen25_7b")
# Prepare input
texts = ["The capital of France is", "The capital of Germany is"]
encoder = tokenizer.create_encoder()
tokens = torch.vstack([encoder(text) for text in texts]).to(device)
# Run inference (simplified)
model.eval()
with torch.inference_mode():
    seqs_layout = BatchLayout.of(tokens)
    output = model(tokens, seqs_layout=seqs_layout)

# Process output...
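One way to finish the example is a single greedy decoding step. This sketch assumes output holds logits of shape (batch, seq_len, vocab_size); it is not a full generation loop:
# Pick the most likely next token for each prompt from the last position.
next_tokens = output[:, -1, :].argmax(dim=-1)

decoder = tokenizer.create_decoder()
for token in next_tokens:
    print(decoder(token.unsqueeze(0)))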
Custom Architecture¶
from fairseq2.models.qwen import get_qwen_model_hub, QwenConfig
hub = get_qwen_model_hub()
# Get base configuration and modify
config = hub.get_arch_config("qwen25_7b")
config.max_seq_len = 16384 # Reduce sequence length
config.dropout_p = 0.1 # Add dropout
# Create model with custom config
model = hub.create_new_model(config)
Loading from Custom Checkpoint¶
from pathlib import Path
from fairseq2.models.qwen import get_qwen_model_hub
hub = get_qwen_model_hub()
config = hub.get_arch_config("qwen25_7b")
# Load from custom checkpoint
checkpoint_path = Path("/path/to/my/qwen_checkpoint.pt")
model = hub.load_custom_model(checkpoint_path, config)
Architecture Comparison¶
from fairseq2.models.qwen import get_qwen_model_hub
hub = get_qwen_model_hub()
# Compare different Qwen architectures
architectures = ["qwen25_3b", "qwen25_7b", "qwen25_14b"]
for arch in architectures:
    config = hub.get_arch_config(arch)

    # Rough parameter estimate: attention + gated FFN per layer, plus embeddings.
    head_dim = config.head_dim or config.model_dim // config.num_attn_heads
    kv_dim = config.num_key_value_heads * head_dim
    per_layer = (2 * config.model_dim**2                   # q and output projections
                 + 2 * config.model_dim * kv_dim           # k and v projections (GQA)
                 + 3 * config.model_dim * config.ffn_inner_dim)  # gated FFN
    embed = config.vocab_size * config.model_dim * (1 if config.tied_embeddings else 2)
    params = config.num_layers * per_layer + embed

    print(f"{arch}:")
    print(f"  Model dim: {config.model_dim}")
    print(f"  Layers: {config.num_layers}")
    print(f"  Attention heads: {config.num_attn_heads}")
    print(f"  Approx parameters: ~{params / 1e9:.1f}B")
    print()
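The estimate ignores small terms such as norm weights and biases, so it will not match an exact count. For a cheap cross-check, instantiate a deliberately small config with create_qwen_model and count parameters directly:
from fairseq2.models.qwen import QwenConfig, create_qwen_model

# A small config keeps the instantiation cheap.
config = QwenConfig()
config.model_dim = 1024
config.num_layers = 4
config.num_attn_heads = 8
config.num_key_value_heads = 2
config.ffn_inner_dim = 2816
config.vocab_size = 32000

model = create_qwen_model(config)
actual = sum(p.numel() for p in model.parameters())
print(f"Actual parameters: ~{actual / 1e6:.0f}M")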
See Also¶
fairseq2.models.hub - Model hub API reference
Add Your Own Model - Tutorial on adding new models
Assets - Understanding the asset system