fairseq2.gang¶
What’s a Gang?¶
A Gang represents a group of processes (e.g., GPUs) that work together in a distributed setting. Each Gang:
- Has a unique rank for each process
- Knows its total size (number of processes)
- Supports collective operations (e.g., all_reduce, broadcast)
- Is associated with a specific device (CPU or CUDA)
The gang module provides distributed computing primitives for fairseq2, enabling efficient coordination and communication between multiple processes in distributed training scenarios.
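The snippet below is a minimal sketch of these concepts using the single-process FakeGang documented later on this page (it assumes fairseq2 is installed and that Device is importable from fairseq2.device, as in the Examples section):

```python
import torch

from fairseq2.device import Device
from fairseq2.gang import FakeGang, ReduceOperation

# A single-process gang: this process has rank 0 in a gang of size 1.
gang = FakeGang(Device("cpu"))
print(gang.rank, gang.size)  # 0 1

# Collective operations work in place; with only one process this is a no-op.
total = torch.tensor([1.0, 2.0])
gang.all_reduce(total, ReduceOperation.SUM)
```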
Core Classes¶
- class fairseq2.gang.Gang(*args, **kwargs)[source]¶
Bases:
Closable
Represents a set of processes that work collectively.
- abstract all_reduce(tensor, op)[source]¶
Reduce tensor across all processes.
- Parameters:
tensor (Tensor) – The input and output tensor of the operation.
op (ReduceOperation) – The element-wise reduce operation.
- abstract all_gather(output_tensor, input_tensor)[source]¶
Gather tensors from all processes and put them in output_tensor.
- abstract all_gather_to_list(output_tensors, input_tensor)[source]¶
Gather tensors from all processes and put them in output_tensors.
- abstract broadcast(tensor, source_rank=0)[source]¶
Broadcast tensor from source_rank to all processes.
Gang represents a set of processes that work collectively. It provides an abstract interface for collective communication operations like all-reduce, broadcast, and all-gather.
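Because the collective operations are declared on the abstract Gang class, code written against this interface runs unchanged on a FakeGang or a ProcessGroupGang. The helper below is a hypothetical sketch (the function name is illustrative and not part of fairseq2):

```python
import torch

from fairseq2.gang import Gang, ReduceOperation

def average_loss(gang: Gang, loss: torch.Tensor) -> torch.Tensor:
    # all_reduce modifies `loss` in place across all processes in the gang.
    gang.all_reduce(loss, ReduceOperation.MEAN)
    return loss
```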
Gang Implementations¶
- final class fairseq2.gang.FakeGang(device, *, rank=0, size=1)[source]¶
Bases:
Gang
Represents a non-distributed gang for local use.
- all_reduce(tensor, op)[source]¶
Reduce tensor across all processes.
- Parameters:
tensor (Tensor) – The input and output tensor of the operation.
op (ReduceOperation) – The element-wise reduce operation.
- all_gather(output_tensor, input_tensor)[source]¶
Gather tensors from all processes and put them in output_tensor.
- all_gather_to_list(output_tensors, input_tensor)[source]¶
Gather tensors from all processes and put them in output_tensors.
FakeGang is used for local, non-distributed scenarios. It simulates gang behavior for single-process execution.
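Since the constructor accepts rank and size keyword arguments (see the signature above), a FakeGang can also impersonate an arbitrary process in a larger gang, which is handy for unit-testing rank-dependent code without launching multiple processes. A small sketch:

```python
from fairseq2.device import Device
from fairseq2.gang import FakeGang

# Pretend to be rank 2 in a gang of 4 processes.
gang = FakeGang(Device("cpu"), rank=2, size=4)

# Exercise rank-dependent logic, e.g. only rank 0 writes checkpoints.
if gang.rank == 0:
    print("would write checkpoint")
else:
    print(f"rank {gang.rank} skips checkpointing")
```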
- final class fairseq2.gang.ProcessGroupGang(_pg, _device)[source]¶
Bases:
Gang
Represents a gang that wraps a process group.
- classmethod create_default_process_group(device, *, timeout=None, high_priority=False)[source]¶
Initializes the root process group and wraps it as a gang.
- Parameters:
device (device) – The device for which to initialize the gang. For CUDA devices, the NCCL backend will be used; for CPU, Gloo.
timeout (timedelta | None) – The timeout for collective operations. If None, the default timeout value (15 minutes) will be used.
high_priority (bool) – If True, the underlying collective operations will be performed on high-priority channels (e.g. CUDA streams).
- Return type:
ProcessGroupGang
- create_gang(ranks)[source]¶
Creates a new gang.
- Parameters:
ranks (Sequence[int]) – The ranks of processes that will be part of the new gang.
- Return type:
ProcessGroupGang | None
- all_reduce(tensor, op)[source]¶
Reduce tensor across all processes.
- Parameters:
tensor (Tensor) – The input and output tensor of the operation.
op (ReduceOperation) – The element-wise reduce operation.
- all_gather(output_tensor, input_tensor)[source]¶
Gather tensors from all processes and put them in output_tensor.
- all_gather_to_list(output_tensors, input_tensor)[source]¶
Gather tensors from all processes and put them in output_tensors.
ProcessGroupGang wraps PyTorch’s ProcessGroup to provide gang functionality in distributed environments.
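The sketch below shows a typical entry point; it has to run under a distributed launcher such as torchrun so that the process group environment variables are set. It also uses create_gang to carve a sub-gang out of the first half of the ranks; given the ProcessGroupGang | None return type, processes outside the subset presumably receive None:

```python
from fairseq2.device import get_default_device
from fairseq2.gang import ProcessGroupGang

device = get_default_device()
root_gang = ProcessGroupGang.create_default_process_group(device)

# Build a sub-gang from the first half of the ranks.
first_half = list(range(root_gang.size // 2))
sub_gang = root_gang.create_gang(first_half)

if sub_gang is not None:
    print(f"rank {root_gang.rank} -> sub-gang rank {sub_gang.rank} of {sub_gang.size}")
```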
Gang Configuration¶
- class fairseq2.gang.Gangs(*, root: 'Gang', dp: 'Gang', rdp: 'Gang', sdp: 'Gang', tp: 'Gang', pp: 'Gang')[source]¶
Bases:
object
Gangs is a dataclass that holds different types of gangs used in parallel training:
root: The root gang containing all processes
dp: Data parallel gang
rdp: Replicated data parallel gang (inter-node)
sdp: Sharded data parallel gang (intra-node)
tp: Tensor parallel gang
pp: Pipeline parallel gang
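As a quick illustration, the sketch below builds a single-process Gangs object with create_fake_gangs (used again in the Examples section) and inspects each field; in that setup every gang has size 1:

```python
from fairseq2.device import Device
from fairseq2.gang import create_fake_gangs

gangs = create_fake_gangs(Device("cpu"))

# Every gang in a single-process setup contains only this process.
for name in ("root", "dp", "rdp", "sdp", "tp", "pp"):
    gang = getattr(gangs, name)
    print(f"{name}: rank {gang.rank} of {gang.size}")
```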
Enums and Types¶
- class fairseq2.gang.ReduceOperation(value)[source]¶
Bases:
Enum
Specifies a reduce operation.
- SUM = 1¶
- MEAN = 2¶
- PRODUCT = 3¶
- MIN = 4¶
- MAX = 5¶
ReduceOperation specifies the type of reduction to perform in all-reduce operations.
Factory Functions¶
- fairseq2.gang.create_fake_gangs(device)[source]¶
Creates fake gangs for local, non-distributed execution.
- fairseq2.gang.create_parallel_gangs(root_gang, *, tp_size=1)[source]¶
Sets up gangs to be used for data and model parallelism.
For instance, if we have 8 devices denoted by g0 to g7 and 2 devices are used for tensor parallelism, this function will make 4 tensor parallel gangs and 2 data parallel gangs:
- 4 tensor parallel gangs:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
- 2 data parallel gangs:
[g0, g2, g4, g6], [g1, g3, g5, g7]
For efficiency, the caller should make sure adjacent ranks are on the same host. For example, if there are two hosts with a total of 16 GPUs, ranks 0 to 7 belong to the first host and ranks 8 to 15 belong to the second host.
Note
If root_gang wraps a PyTorch ProcessGroup with the NCCL backend, this function uses the experimental split_group API available in PyTorch 2.5 and later.
Sets up gangs for data and tensor parallelism in distributed training.
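The rank-to-gang mapping described above can be reproduced with a few lines of plain Python (no fairseq2 calls), which may help when reasoning about larger topologies:

```python
# 8 ranks with a tensor parallel size of 2, matching the example above.
world_size, tp_size = 8, 2

tp_gangs = [list(range(start, start + tp_size)) for start in range(0, world_size, tp_size)]
dp_gangs = [list(range(offset, world_size, tp_size)) for offset in range(tp_size)]

print(tp_gangs)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(dp_gangs)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```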
- fairseq2.gang.create_fsdp_gangs(gangs, intra_node_size=None)[source]¶
Sets up gangs to be used for sharded data parallelism.
For instance, if we have 8 devices denoted by g0 to g7 and intra_node_size is 4, this function will make 2 intra-node gangs and 4 inter-node gangs:
- 2 intra-node gangs of size 4:
[g0, g1, g2, g3], [g4, g5, g6, g7]
- 4 inter-node gangs of size 2:
[g0, g4], [g1, g5], [g2, g6], [g3, g7]
For efficiency, the caller should make sure adjacent ranks are on the same host.
- Return type:
Gangs
Sets up gangs for Fully Sharded Data Parallel (FSDP) training.
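As with create_parallel_gangs, the intra-/inter-node grouping described above can be checked with plain Python before launching a job:

```python
# 8 ranks with intra_node_size set to 4, matching the example above.
world_size, intra_node_size = 8, 4

intra_node = [list(range(start, start + intra_node_size))
              for start in range(0, world_size, intra_node_size)]
inter_node = [list(range(offset, world_size, intra_node_size))
              for offset in range(intra_node_size)]

print(intra_node)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(inter_node)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```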
Utility Functions¶
- fairseq2.gang.broadcast_flag(gang, flag, source_rank=0)[source]¶
Broadcasts flag to all processes in gang from source_rank.
- Return type:
bool
Broadcasts a boolean flag to all processes in a gang.
Sums a value over all processes in a gang.
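Below is a minimal sketch of broadcast_flag using a single-process FakeGang so that it runs without a distributed launcher; in a real job the flag computed on source_rank is propagated to every process in the gang:

```python
from fairseq2.device import Device
from fairseq2.gang import FakeGang, broadcast_flag

gang = FakeGang(Device("cpu"))

# Rank 0 decides whether training should stop early; every process then
# receives the same decision.
should_stop = gang.rank == 0  # placeholder condition evaluated on rank 0
should_stop = broadcast_flag(gang, should_stop, source_rank=0)
print(should_stop)
```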
Examples¶
Basic Gang Setup
from fairseq2.gang import create_fake_gangs
from fairseq2.device import Device
# Create fake gangs for local development
device = Device("cpu")
gangs = create_fake_gangs(device)
print(f"Root gang size: {gangs.root.size}")
print(f"Data parallel gang size: {gangs.dp.size}")
Distributed Training Setup
from fairseq2.gang import ProcessGroupGang, create_parallel_gangs
from fairseq2.device import get_default_device
# Initialize distributed process group
device = get_default_device()
root_gang = ProcessGroupGang.create_default_process_group(device)
# Create parallel gangs with tensor parallelism
gangs = create_parallel_gangs(root_gang, tp_size=2)
print(f"Process rank: {gangs.root.rank}")
print(f"World size: {gangs.root.size}")
print(f"TP gang size: {gangs.tp.size}")
Gang Topology¶
fairseq2’s gang system supports complex parallel training topologies:
- Data Parallelism
Multiple processes train on different data shards but maintain identical model copies.
- Tensor Parallelism
The model is split across multiple devices, with each device handling part of each layer.
- Pipeline Parallelism
Different layers of the model run on different devices in a pipeline fashion.
- Fully Sharded Data Parallelism (FSDP)
Combines data parallelism with parameter sharding, reducing memory usage while maintaining training efficiency.
The gang system automatically handles the communication patterns required for each parallelism strategy, making it easy to scale training across many GPUs and nodes.
See Also¶
fairseq2.device - Device management utilities