Gang is fairseq2’s abstraction for distributed training that provides a clean interface for collective operations (e.g., all_reduce, all_gather, and broadcast) across processes in a distributed environment.
It encapsulates the complexity of PyTorch’s torch.distributed while supporting:
Data Parallelism: Distributing batches of data across multiple GPUs.
Tensor Parallelism: Partitioning model tensors for efficient computation.
Flexible Process Grouping: Organizing processes into groups dynamically.
The Gang interface supports the following methods:
# Reduce tensor across processes
gang.all_reduce(tensor, ReduceOperation.SUM)

# Gather tensors from all processes
gang.all_gather(output_tensor, input_tensor)

# Gather tensors from all processes into a list
gang.all_gather_to_list(output_tensors, input_tensor)

# Broadcast tensor from source rank to all others
gang.broadcast(tensor, source_rank=0)

# Synchronize all processes
gang.barrier()

# Broadcast Python objects
gang.broadcast_objects(objects, source_rank=0)
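For illustration, the snippet below broadcasts a small Python object from rank 0 to every rank. The list-based, in-place semantics shown here are an assumption modeled on torch.distributed.broadcast_object_list, and the config dict is purely hypothetical.

# Hypothetical config dict; only rank 0 holds the real value before the call.
objects = [{"lr": 1e-4}] if gang.rank == 0 else [None]

# Assumed to fill `objects` in place on every rank, mirroring
# torch.distributed.broadcast_object_list.
gang.broadcast_objects(objects, source_rank=0)

print(f"Rank {gang.rank}: received {objects[0]}")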
In fairseq2, parallel training is organized around Data Parallel (DP) Gangs and Tensor Parallel (TP) Gangs, which together enable scalable training of large models.
For example, given a root gang of 8 processes, setup_parallel_gangs(root_gang, tp_size=2) creates 2 DP gangs of 4 processes each and 4 TP gangs of 2 processes each; every process belongs to exactly one DP gang and one TP gang.
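A minimal sketch of this setup is shown below. The call itself comes from fairseq2; how the sub-gangs are returned is an assumption here (a (dp, tp) pair; some releases instead return a holder object with dp and tp attributes), so check the signature in your installed version.

from fairseq2.gang import setup_parallel_gangs, setup_root_gang

# Launched with 8 processes; TP groups of 2 imply DP groups of 4.
root_gang = setup_root_gang()

# Assumption: the DP and TP sub-gangs of the calling process are returned as a pair.
dp_gang, tp_gang = setup_parallel_gangs(root_gang, tp_size=2)

print(
    f"global rank {root_gang.rank}: "
    f"DP rank {dp_gang.rank}/{dp_gang.size}, "
    f"TP rank {tp_gang.rank}/{tp_gang.size}"
)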
fairseq2 also supports hybrid-sharded FSDP through setup_hybrid_fsdp_gangs(), which splits a gang into sub-gangs so that the model can be fully sharded within each sub-gang and replicated across sub-gangs.
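A sketch of the hybrid-sharding setup follows, under the assumption that setup_hybrid_fsdp_gangs takes the gang to split plus the number of consecutive processes per sub-gang and returns the sharding gang and the replication gang of the calling process:

from fairseq2.gang import setup_hybrid_fsdp_gangs, setup_root_gang

root_gang = setup_root_gang()

# Assumption: sub-gangs of 8 consecutive processes each (e.g., one 8-GPU node);
# the model is fully sharded within a sub-gang and replicated across sub-gangs.
sharding_gang, replication_gang = setup_hybrid_fsdp_gangs(root_gang, 8)

print(
    f"global rank {root_gang.rank}: "
    f"shard rank {sharding_gang.rank}/{sharding_gang.size}, "
    f"replica rank {replication_gang.rank}/{replication_gang.size}"
)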
from datetime import timedelta

from fairseq2.gang import setup_root_gang

# Initialize the default gang with custom settings
gang = setup_root_gang(
    timeout=timedelta(minutes=30),  # Custom timeout for monitored barriers
    monitored=True,                 # Enable monitored barriers for deadlock detection
)

print(f"Process rank: {gang.rank}, World size: {gang.size}")
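When the script is started with a distributed launcher such as torchrun, the rank, world size, and rendezvous address are normally taken from the standard torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that the launcher sets; the exact discovery mechanism may vary across fairseq2 versions.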
Note
If running locally (i.e., without a torch.distributed backend), a single-process FakeGang is created, which is useful for local testing and debugging.
If running in a distributed environment, a ProcessGroupGang is created.
A minimal example of distributed training with gangs:
# script.py
import torch

from fairseq2.gang import ReduceOperation, setup_root_gang

# Initialize gang
gang = setup_root_gang()

# Dummy tensor
tensor = torch.tensor(gang.rank + 1.0, device=gang.device)

# Sum tensor across all processes
gang.all_reduce(tensor, ReduceOperation.SUM)

print(f"Rank {gang.rank}: Tensor after all_reduce = {tensor.item()}")

# Synchronize
gang.barrier()
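To try this on a single machine, launch the script with a distributed launcher, for example torchrun --nproc-per-node=2 script.py. With two processes, rank 0 contributes 1.0 and rank 1 contributes 2.0, so every rank prints 3.0 after the all-reduce. Run without a launcher, the script falls back to a single-process FakeGang and the tensor keeps its initial value of 1.0.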