Gang is fairseq2’s abstraction for distributed training. It provides a clean interface for collective operations (e.g., all_reduce, all_gather, and broadcast) across processes in a distributed environment.
The Gang design encapsulates the complexity of PyTorch’s torch.distributed while supporting:
Data Parallelism: Distributing batches of data across multiple GPUs.
Tensor Parallelism: Partitioning model tensors across devices so that large models can be computed in parallel.
Flexible Process Grouping: Organizing processes into groups dynamically.
The Gang interface supports the following methods:
# Reduce a tensor across processes
gang.all_reduce(tensor, ReduceOperation.SUM)

# Gather tensors from all processes
gang.all_gather(output_tensor, input_tensor)

# Gather tensors from all processes into a list
gang.all_gather_to_list(output_tensors, input_tensor)

# Broadcast a tensor from the source rank to all others
gang.broadcast(tensor, source_rank=0)

# Synchronize all processes
gang.barrier()

# Broadcast Python objects
gang.broadcast_objects(objects, source_rank=0)
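As a small sketch of how these methods combine in practice, the snippet below averages a per-rank loss and collects the individual values; it uses setup_default_gang (covered below) to create the gang, and otherwise follows the method signatures above:

import torch

from fairseq2.gang import ReduceOperation, setup_default_gang

gang = setup_default_gang()

# Each process starts with its own scalar loss on the gang's device.
loss = torch.tensor(float(gang.rank), device=gang.device)

# Gather every rank's loss into a list, e.g. for per-rank logging on rank 0.
per_rank_losses = [torch.zeros_like(loss) for _ in range(gang.size)]
gang.all_gather_to_list(per_rank_losses, loss)

# Sum the loss across all processes in place, then average it.
gang.all_reduce(loss, ReduceOperation.SUM)
mean_loss = loss / gang.size

print(f"Rank {gang.rank}: mean loss = {mean_loss.item()}")

# Wait until every process has reached this point.
gang.barrier()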
In fairseq2, parallel training is organized around Data Parallel (DP) Gangs and Tensor Parallel (TP) Gangs, which together enable scalable training of large models.
For example, given a root gang of 8 processes, setup_parallel_gangs(root_gang, tp_size=2) creates 2 DP gangs (each with 4 processes) and 4 TP gangs (each with 2 processes).
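A sketch of what this looks like in code, assuming the script runs with 8 processes and that the object returned by setup_parallel_gangs exposes the DP and TP gangs as dp and tp attributes (the attribute names are an assumption; check the API reference of your fairseq2 version for the exact return type):

from fairseq2.gang import setup_default_gang, setup_parallel_gangs

# Root gang spanning all 8 processes (see setup_default_gang below).
root_gang = setup_default_gang()

# Split the root gang: TP gangs of size 2, DP gangs of size 8 // 2 = 4.
gangs = setup_parallel_gangs(root_gang, tp_size=2)

dp_gang = gangs.dp  # assumed attribute: the DP gang this process belongs to
tp_gang = gangs.tp  # assumed attribute: the TP gang this process belongs to

print(f"DP rank {dp_gang.rank}/{dp_gang.size}, TP rank {tp_gang.rank}/{tp_gang.size}")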
To initialize a gang that spans all processes, use setup_default_gang:

from fairseq2.gang import setup_default_gang

# Initialize the default gang
gang = setup_default_gang()

print(f"Process rank: {gang.rank}, World size: {gang.size}")
Note
If running locally (no torch.distributed backend), a FakeGang is created.
This is useful for local testing and debugging.
If running in a distributed environment, a ProcessGroupGang is created.
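If a script needs to know which implementation it received, it can check the gang’s type. A small sketch, assuming the gang created above and that FakeGang and ProcessGroupGang are importable from fairseq2.gang in your version:

from fairseq2.gang import FakeGang, ProcessGroupGang

if isinstance(gang, FakeGang):
    # Single-process gang used for local testing and debugging.
    print("Running locally with a FakeGang.")
elif isinstance(gang, ProcessGroupGang):
    # Gang backed by a torch.distributed process group.
    print(f"Running distributed with {gang.size} processes.")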
A minimal example of using a gang in a distributed script:
# script.py
import torch

from fairseq2.gang import ReduceOperation, setup_default_gang

# Initialize the gang
gang = setup_default_gang()

# Dummy tensor with a different value on each process
tensor = torch.tensor(gang.rank + 1.0, device=gang.device)

# Sum the tensor across all processes
gang.all_reduce(tensor, ReduceOperation.SUM)

print(f"Rank {gang.rank}: Tensor after all_reduce = {tensor.item()}")

# Synchronize all processes
gang.barrier()
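To run this across multiple processes, launch the script with a standard PyTorch launcher, for example torchrun --nproc-per-node=2 script.py on a machine with two GPUs. Launched as plain python script.py, it still works, falling back to a single-process FakeGang as described in the note above.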