fairseq2.gang¶
A Gang represents a set of processes that can perform collective operations such as all-reduce, broadcast, and other distributed primitives. This module provides Gang implementations that support both real distributed environments (using PyTorch’s distributed backend) and simulated environments for testing and single-process scenarios.
See What is a Gang? for more information.
Classes¶
- class fairseq2.gang.Gang(*args, **kwargs)[source]¶
Bases: Closable
Represents a set of processes that work collectively.
- abstract create_gang(ranks: Sequence[int]) Gang | None [source]¶
Creates a new sub-gang with the specified process ranks.
The ranks must be unique and within the range [0, gang.size).
Returns None if the current process is not included in ranks.
- Raises:
ValueError – If ranks contains duplicates, or has one or more out of range values.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Creating a sub-gang with specific processes¶
# Create a gang with ranks 0, 2, 4 from an 8-process gang
sub_gang = root_gang.create_gang([0, 2, 4])

if sub_gang is not None:
    # Current process is part of the new gang
    print(f"New gang rank: {sub_gang.rank}, size: {sub_gang.size}")
- abstract as_process_group() ProcessGroup [source]¶
Returns this gang as a PyTorch ProcessGroup that can be used with PyTorch’s distributed operations and collective communication functions.
- Raises:
NotSupportedError – If the gang implementation does not support conversion to a ProcessGroup (e.g. FakeGang).
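Since the return value is a regular ProcessGroup, it can be passed directly to torch.distributed functions. A minimal sketch, assuming gang is a ProcessGroupGang and tensor is a tensor already placed on the gang's device:
Using the gang with torch.distributed¶
import torch.distributed as dist

# Obtain the underlying ProcessGroup and use it with a native collective.
pg = gang.as_process_group()

dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=pg)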
- abstract barrier() None [source]¶
Synchronizes all processes in the gang.
This is a collective operation that blocks until all processes in the gang reach this synchronization point. Used for ensuring a consistent state across all processes before proceeding.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
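A typical use is to make sure rank 0 has finished a side effect, such as writing a file, before the other processes read it. A minimal sketch, where write_results() and load_results() are hypothetical helpers:
Synchronizing before reading shared output¶
if gang.rank == 0:
    write_results()  # hypothetical helper that writes to shared storage

# Block until every process, including rank 0, reaches this point.
gang.barrier()

results = load_results()  # hypothetical helper; safe to call on all ranks now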
- abstract all_reduce(tensor: Tensor, op: ReduceOperation) None [source]¶
Reduces tensor across all processes using the specified operation.
All-reduce combines tensors from all processes using the specified operation and distributes the result to all processes. The input tensor is modified in-place to contain the reduction result.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Computing sum across all processes¶
import torch

from fairseq2.gang import ReduceOperation

# Each process has a different tensor
tensor = torch.tensor([gang.rank], dtype=torch.float32)

# Sum across all processes
gang.all_reduce(tensor, ReduceOperation.SUM)

# Now tensor contains the sum of all ranks
print(f"Sum of all ranks: {tensor.item()}")
- abstract all_gather(output_tensor: Tensor, input_tensor: Tensor) None [source]¶
Gathers tensors from all processes and puts them in output_tensor.
All-gather collects input tensors from all processes, concatenates them along a new first dimension in rank order, and writes the result to the output tensor. The output tensor must have shape [gang.size, *input_tensor.shape] and be contiguous in memory.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Gathering tensors from all processes¶
import torch

# Each process contributes a tensor with its rank
input_tensor = torch.tensor([gang.rank * 10], dtype=torch.float32)

# Prepare output tensor for all gathered tensors
output_tensor = torch.empty([gang.size, 1], dtype=torch.float32)

# Gather from all processes
gang.all_gather(output_tensor, input_tensor)

# output_tensor now contains [0, 10, 20, ...] for ranks 0, 1, 2, ...
- abstract all_gather_to_list(output_tensors: list[Tensor], input_tensor: Tensor) None [source]¶
Gathers tensors from all processes and puts them in output_tensors.
Similar to all_gather(), but stores the gathered tensors in a list instead of concatenating them into a single tensor. output_tensors must be a pre-allocated list with length equal to gang.size.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
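A minimal sketch, assuming each element of the output list is pre-allocated with the same shape and dtype as the input tensor:
Gathering tensors into a list¶
import torch

# Each process contributes a tensor holding its rank
input_tensor = torch.tensor([gang.rank], dtype=torch.float32)

# Pre-allocate one output tensor per process in the gang
output_tensors = [torch.empty_like(input_tensor) for _ in range(gang.size)]

gang.all_gather_to_list(output_tensors, input_tensor)

# output_tensors[i] now holds the tensor contributed by rank i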
- abstract broadcast(tensor: Tensor, source_rank: int = 0) None [source]¶
Broadcasts tensor from the specified rank to all processes.
Broadcast copies the tensor from the source process to all other processes. The tensor is modified in-place on non-source processes to contain the broadcasted data. source_rank must be in range [0, gang.size).
- Raises:
ValueError – If source_rank is out of valid range.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
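A minimal sketch, assuming all ranks allocate a tensor of the same shape and dtype before the call:
Broadcasting a tensor from rank 0¶
import torch

if gang.rank == 0:
    # The source rank holds the real data
    tensor = torch.arange(4, dtype=torch.float32)
else:
    # Non-source ranks provide a buffer that is overwritten in-place
    tensor = torch.empty(4, dtype=torch.float32)

gang.broadcast(tensor, source_rank=0)

# Every rank now holds [0., 1., 2., 3.]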
- abstract broadcast_objects(objects: list[object], source_rank: int = 0) None [source]¶
Broadcasts picklable objects from the specified rank to all processes.
Similar to broadcast(), but copies arbitrary Python objects that can be pickled. The objects are modified in-place on non-source processes. Each process must provide lists of equal size. source_rank must be in range [0, gang.size).
- Raises:
ValueError – If source_rank is out of valid range.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
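A minimal sketch, assuming non-source ranks pass placeholder entries of the same list length that are overwritten in-place:
Broadcasting Python objects from rank 0¶
if gang.rank == 0:
    # The source rank holds the real objects
    objects = [{"lr": 1e-4}, "ready"]
else:
    # Non-source ranks provide a list of the same length
    objects = [None, None]

gang.broadcast_objects(objects, source_rank=0)

# On every rank, objects now equals [{"lr": 1e-4}, "ready"]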
- final class fairseq2.gang.ProcessGroupGang(_pg: ProcessGroup, _device: device)[source]¶
Bases: Gang
Represents a gang that wraps a PyTorch ProcessGroup.
This is a distributed gang implementation that uses PyTorch’s distributed backend for actual inter-process communication.
- classmethod create_default_process_group(device: device, *, timeout: timedelta | None = None, high_priority: bool = False) ProcessGroupGang [source]¶
Initializes the default process group and wraps it as a gang.
For CUDA devices, the NCCL backend will be used; for CPU devices, Gloo.
timeout specifies the timeout for collective operations. If None, the default timeout (15 minutes) will be used.
If high_priority is True, the underlying collective operations will be performed on high-priority channels (e.g. CUDA streams) if supported by the underlying backend.
- Raises:
ValueError – If device is not of type cpu or cuda.
NotSupportedError – If torch.distributed is not available.
InvalidOperationError – If the root process group is already initialized.
GangError – If the underlying process group fails to initialize due to an unexpected error such as a network communication failure.
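A minimal sketch, assuming the process was started by a distributed launcher such as torchrun so that the environment variables expected by torch.distributed are already set:
Initializing the default gang¶
import torch

from fairseq2.gang import ProcessGroupGang

# Pass the device this process should use; a CPU device selects the Gloo backend.
device = torch.device("cuda")

gang = ProcessGroupGang.create_default_process_group(device)

print(f"rank {gang.rank} of {gang.size}")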
- create_gang(ranks: Sequence[int]) ProcessGroupGang | None [source]¶
Creates a new sub-gang with the specified process ranks.
The ranks must be unique and within the range [0, gang.size).
Returns None if the current process is not included in ranks.
- Raises:
ValueError – If ranks contains duplicates, or has one or more out of range values.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Creating a sub-gang with specific processes¶
# Create a gang with ranks 0, 2, 4 from an 8-process gang
sub_gang = root_gang.create_gang([0, 2, 4])

if sub_gang is not None:
    # Current process is part of the new gang
    print(f"New gang rank: {sub_gang.rank}, size: {sub_gang.size}")
- as_process_group() ProcessGroup [source]¶
Returns this gang as a PyTorch ProcessGroup that can be used with PyTorch’s distributed operations and collective communication functions.
- Raises:
NotSupportedError – If the gang implementation does not support conversion to a ProcessGroup (e.g. FakeGang).
- barrier() None [source]¶
Synchronizes all processes in the gang.
This is a collective operation that blocks until all processes in the gang reach this synchronization point. Used for ensuring a consistent state across all processes before proceeding.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- all_reduce(tensor: Tensor, op: ReduceOperation) None [source]¶
Reduces tensor across all processes using the specified operation.
All-reduce combines tensors from all processes using the specified operation and distributes the result to all processes. The input tensor is modified in-place to contain the reduction result.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Computing sum across all processes¶
import torch

from fairseq2.gang import ReduceOperation

# Each process has a different tensor
tensor = torch.tensor([gang.rank], dtype=torch.float32)

# Sum across all processes
gang.all_reduce(tensor, ReduceOperation.SUM)

# Now tensor contains the sum of all ranks
print(f"Sum of all ranks: {tensor.item()}")
- all_gather(output_tensor: Tensor, input_tensor: Tensor) None [source]¶
Gathers tensors from all processes and puts them in output_tensor.
All-gather collects input tensors from all processes, concatenates them along a new first dimension in rank order, and writes the result to the output tensor. The output tensor must have shape [gang.size, *input_tensor.shape] and be contiguous in memory.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Gathering tensors from all processes¶
import torch

# Each process contributes a tensor with its rank
input_tensor = torch.tensor([gang.rank * 10], dtype=torch.float32)

# Prepare output tensor for all gathered tensors
output_tensor = torch.empty([gang.size, 1], dtype=torch.float32)

# Gather from all processes
gang.all_gather(output_tensor, input_tensor)

# output_tensor now contains [0, 10, 20, ...] for ranks 0, 1, 2, ...
- all_gather_to_list(output_tensors: list[Tensor], input_tensor: Tensor) None [source]¶
Gathers tensors from all processes and puts them in output_tensors.
Similar to all_gather(), but stores the gathered tensors in a list instead of concatenating them into a single tensor. output_tensors must be a pre-allocated list with length equal to gang.size.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- broadcast(tensor: Tensor, source_rank: int = 0) None [source]¶
Broadcasts tensor from the specified rank to all processes.
Broadcast copies the tensor from the source process to all other processes. The tensor is modified in-place on non-source processes to contain the broadcasted data. source_rank must be in range [0, gang.size).
- Raises:
ValueError – If source_rank is out of valid range.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- broadcast_objects(objects: list[object], source_rank: int = 0) None [source]¶
Broadcasts picklable objects from the specified rank to all processes.
Similar to broadcast(), but copies arbitrary Python objects that can be pickled. The objects are modified in-place on non-source processes. Each process must provide lists of equal size. source_rank must be in range [0, gang.size).
- Raises:
ValueError – If source_rank is out of valid range.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- final class fairseq2.gang.FakeGang(device: device, *, rank: int = 0, size: int = 1)[source]¶
Bases: Gang
Represents a non-distributed gang for local use.
This implementation simulates gang operations without actual distributed communication, making it useful for testing, debugging, and single-process execution. All collective operations are no-ops.
Simulating a collective operation¶
import torch

from fairseq2.gang import FakeGang, ReduceOperation

device = torch.device("cpu")

gang = FakeGang(device, rank=0, size=8)

tensor = torch.tensor([gang.rank], dtype=torch.float32)

# Simulates as if a real all-reduce operation is performed on the gang.
gang.all_reduce(tensor, ReduceOperation.SUM)
- create_gang(ranks: Sequence[int]) FakeGang | None [source]¶
Creates a new sub-gang with the specified process ranks.
The ranks must be unique and within the range [0, gang.size).
Returns None if the current process is not included in ranks.
- Raises:
ValueError – If ranks contains duplicates, or has one or more out of range values.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Creating a sub-gang with specific processes¶
# Create a gang with ranks 0, 2, 4 from an 8-process gang
sub_gang = root_gang.create_gang([0, 2, 4])

if sub_gang is not None:
    # Current process is part of the new gang
    print(f"New gang rank: {sub_gang.rank}, size: {sub_gang.size}")
- as_process_group() ProcessGroup [source]¶
Returns this gang as a PyTorch ProcessGroup that can be used with PyTorch’s distributed operations and collective communication functions.
- Raises:
NotSupportedError – If the gang implementation does not support conversion to a ProcessGroup (e.g. FakeGang).
- barrier() None [source]¶
Synchronizes all processes in the gang.
This is a collective operation that blocks until all processes in the gang reach this synchronization point. Used for ensuring a consistent state across all processes before proceeding.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- all_reduce(tensor: Tensor, op: ReduceOperation) None [source]¶
Reduces tensor across all processes using the specified operation.
All-reduce combines tensors from all processes using the specified operation and distributes the result to all processes. The input tensor is modified in-place to contain the reduction result.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Computing sum across all processes¶
import torch

from fairseq2.gang import ReduceOperation

# Each process has a different tensor
tensor = torch.tensor([gang.rank], dtype=torch.float32)

# Sum across all processes
gang.all_reduce(tensor, ReduceOperation.SUM)

# Now tensor contains the sum of all ranks
print(f"Sum of all ranks: {tensor.item()}")
- all_gather(output_tensor: Tensor, input_tensor: Tensor) None [source]¶
Gathers tensors from all processes and puts them in output_tensor.
All-gather collects input tensors from all processes, concatenates them along a new first dimension in rank order, and writes the result to the output tensor. The output tensor must have shape [gang.size, *input_tensor.shape] and be contiguous in memory.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
Gathering tensors from all processes¶
import torch

# Each process contributes a tensor with its rank
input_tensor = torch.tensor([gang.rank * 10], dtype=torch.float32)

# Prepare output tensor for all gathered tensors
output_tensor = torch.empty([gang.size, 1], dtype=torch.float32)

# Gather from all processes
gang.all_gather(output_tensor, input_tensor)

# output_tensor now contains [0, 10, 20, ...] for ranks 0, 1, 2, ...
- all_gather_to_list(output_tensors: list[Tensor], input_tensor: Tensor) None [source]¶
Gathers tensors from all processes and puts them in output_tensors.
Similar to all_gather(), but stores the gathered tensors in a list instead of concatenating them into a single tensor. output_tensors must be a pre-allocated list with length equal to gang.size.
- Raises:
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- broadcast(tensor: Tensor, source_rank: int = 0) None [source]¶
Broadcasts tensor from the specified rank to all processes.
Broadcast copies the tensor from the source process to all other processes. The tensor is modified in-place on non-source processes to contain the broadcasted data. source_rank must be in range [0, gang.size).
- Raises:
ValueError – If source_rank is out of valid range.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- broadcast_objects(objects: list[object], source_rank: int = 0) None [source]¶
Broadcasts picklable objects from the specified rank to all processes.
Similar to broadcast(), but copies arbitrary Python objects that can be pickled. The objects are modified in-place on non-source processes. Each process must provide lists of equal size. source_rank must be in range [0, gang.size).
- Raises:
ValueError – If source_rank is out of valid range.
GangError – If the collective operation fails due to an unexpected error such as a network communication failure.
- class fairseq2.gang.Gangs(*, root: Gang, dp: Gang, rdp: Gang, sdp: Gang, tp: Gang, pp: Gang)[source]¶
Bases: Closable
Holds parallel gangs used in distributed configurations.
Each gang is used for a different parallelism strategy such as data, tensor, or pipeline parallelism.
Check out create_parallel_gangs() and create_fsdp_gangs() to see how to initialize a Gangs instance.
- rdp: Gang¶
The replicated data parallel gang (i.e. inter-node for HSDP).
This is a sub-gang of dp used for replicated data parallelism. In PyTorch, this gang is used by DDP as well as by FSDP for inter-node communication when hybrid sharding is enabled.
- sdp: Gang¶
The sharded data parallel gang (i.e. intra-node for HSDP).
This is a sub-gang of dp used for sharded data parallelism. In PyTorch, this gang is used by FSDP. If hybrid sharding is enabled, it will be used only for intra-node communication, while inter-node communication will be handled by rdp.
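For testing, a Gangs instance can also be assembled by hand from FakeGang objects. A minimal sketch of one plausible layout, simulating rank 1 of a 4-process fully sharded data parallel setup with no tensor or pipeline parallelism:
Constructing Gangs manually for testing¶
import torch

from fairseq2.gang import FakeGang, Gangs

device = torch.device("cpu")

# The root and data parallel gangs span all 4 simulated processes.
dp_gang = FakeGang(device, rank=1, size=4)

# Single-process gangs for the unused parallelism strategies.
single = FakeGang(device, rank=0, size=1)

gangs = Gangs(root=dp_gang, dp=dp_gang, rdp=single, sdp=dp_gang, tp=single, pp=single)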
Enums¶
- class fairseq2.gang.ReduceOperation(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]¶
Bases: Enum
Defines the standard reduction operations that can be performed across processes during collective communication operations like all-reduce.
- SUM = 1¶
- MEAN = 2¶
- PRODUCT = 3¶
- MIN = 4¶
- MAX = 5¶
Factory Functions¶
See the ProcessGroupGang.create_default_process_group() method for creating the default PyTorch ProcessGroup. The rest of the factory functions listed below are used to create sub-gangs for different parallelism strategies.
- fairseq2.gang.create_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) Gangs [source]¶
Creates gangs to be used for data and model parallelism.
For instance, if there are 8 devices denoted by d0 to d7 and 2 devices are used for tensor parallelism (i.e. tp_size is 2), this function will create 4 tensor parallel gangs and 2 data parallel gangs by splitting root_gang as:
- 4 tensor parallel gangs: [d0, d1], [d2, d3], [d4, d5], [d6, d7]
- 2 data parallel gangs: [d0, d2, d4, d6], [d1, d3, d5, d7]
For efficiency, the caller should make sure adjacent ranks are on the same host. For example, if there are two hosts with a total of 16 GPUs, ranks 0 to 7 belong to the first host and ranks 8 to 15 belong to the second host.
Note
If root_gang is a PyTorch ProcessGroup with NCCL backend, this function uses the experimental split_group API in PyTorch 2.5 and later. See here for more information.
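A minimal sketch of the 8-process example above, where root_gang is assumed to be the gang returned by ProcessGroupGang.create_default_process_group():
Splitting the root gang for tensor parallelism¶
from fairseq2.gang import create_parallel_gangs

# With 8 processes and tp_size=2, this yields data parallel gangs of size 4
# and tensor parallel gangs of size 2.
gangs = create_parallel_gangs(root_gang, tp_size=2)

print(f"dp: {gangs.dp.rank}/{gangs.dp.size}, tp: {gangs.tp.rank}/{gangs.tp.size}")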
- fairseq2.gang.create_fsdp_gangs(gangs: Gangs, intra_node_size: int | None = None) Gangs [source]¶
Creates gangs to be used for hybrid or fully sharded data parallelism.
If intra_node_size is None, Gangs.sdp (sharded data gang) will be set to the same gang as Gangs.dp and Gangs.rdp (replicated data gang) will be set to a fake gang of size 1. This topology represents a fully sharded data parallel strategy.
An integer intra_node_size indicates hybrid sharded data parallelism. For instance, if there are 8 devices denoted by d0 to d7 and 4 devices are used for intra-node parallelism (i.e. intra_node_size is 4), this function will create 2 intra-node gangs and 4 inter-node gangs by splitting gangs.dp as:
- 2 intra-node gangs of size 4: [d0, d1, d2, d3], [d4, d5, d6, d7]
- 4 inter-node gangs of size 2: [d0, d4], [d1, d5], [d2, d6], [d3, d7]
For efficiency, the caller should make sure adjacent ranks are on the same host.
At the end of the call, gangs.rdp (replicated data gang) will point to the inter-node gang and gangs.sdp (sharded data gang) will point to the intra-node gang.
Returns the same Gangs instance passed as gangs, with its Gangs.rdp and Gangs.sdp attributes set accordingly.
Note
If root_gang is a PyTorch ProcessGroup with NCCL backend, this function uses the experimental split_group API in PyTorch 2.5 and later. See here for more information.
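A minimal sketch of the hybrid sharding example above, assuming gangs was produced by create_parallel_gangs() on 8 processes:
Setting up hybrid sharded data parallel gangs¶
from fairseq2.gang import create_fsdp_gangs

# Shard within groups of 4 ranks and replicate across the 2 resulting groups.
gangs = create_fsdp_gangs(gangs, intra_node_size=4)

print(f"sharded dp size: {gangs.sdp.size}, replicated dp size: {gangs.rdp.size}")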
- fairseq2.gang.create_fake_gangs(device: device) Gangs [source]¶
Creates a set of fake gangs for single-process scenarios.
This is a helper function where every FakeGang is initialized with rank 0 and world size 1. For more complex simulated/testing environments, FakeGang instances can be individually constructed per parallelism strategy and passed to a Gangs object.
Creating fake gangs for testing¶
import torch

from fairseq2.gang import ReduceOperation, create_fake_gangs

device = torch.device("cpu")

gangs = create_fake_gangs(device)

tensor = torch.tensor([gangs.dp.rank], dtype=torch.float32)

# Simulates as if a real all-reduce operation is performed on the data
# parallel gang.
gangs.dp.all_reduce(tensor, ReduceOperation.SUM)
Utilities¶
- fairseq2.gang.maybe_get_current_gangs() Gangs | None [source]¶
Returns the current gangs to use for collective operations.
By default, this function returns None. The current gangs of the calling thread can be set by using Gangs as a context manager:
from fairseq2.gang import Gangs, maybe_get_current_gangs

gangs = Gangs(...)

with gangs:
    current_gangs = maybe_get_current_gangs()

    assert current_gangs is gangs

current_gangs = maybe_get_current_gangs()

assert current_gangs is None
Within fairseq2, this function is used by model factories to retrieve the current gangs and shard the constructed models accordingly. The current gangs are set internally by fairseq2 before calling the factories.
Note that the return value of this function is thread specific. Individual threads may have their own set of current gangs.
- fairseq2.gang.broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) bool [source]¶
Broadcasts a boolean flag to all processes in gang from the specified rank.
Returns the broadcasted boolean value on all processes.
Broadcasting a flag across processes¶
# Only rank 0 sets the flag
should_continue = gang.rank == 0 and some_condition()

# Broadcast the decision to all processes
should_continue = broadcast_flag(gang, should_continue, source_rank=0)

if should_continue:
    # All processes execute this together
    continue_processing()
- fairseq2.gang.all_sum(gang: Gang, value: float | int) Tensor [source]¶
Sums a scalar value over all processes in gang.
Returns a tensor containing the sum of all process values.
Computing total loss across processes¶
# Each process computes its local loss
local_loss = compute_loss(batch)

# Sum losses across all processes
total_loss = all_sum(gang, local_loss)

# Now `total_loss` contains the sum from all processes
average_loss = total_loss / gang.size