Source code for fairseq2.gang

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import os
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, final

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import Backend, ProcessGroup, ReduceOp
from typing_extensions import override

from fairseq2.error import InternalError, InvalidOperationError, NotSupportedError
from fairseq2.logging import log
from fairseq2.typing import CPU, Device
from fairseq2.utils.env import (
    InvalidEnvironmentVariableError,
    get_local_world_size,
    get_world_size,
)
from fairseq2.utils.version import torch_greater_or_equal


class ReduceOperation(Enum):
    """Specifies a reduce operation."""

    SUM = 1
    MEAN = 2
    PRODUCT = 3
    MIN = 4
    MAX = 5

class Gang(ABC):
    """Represents a set of processes that work collectively."""

    @abstractmethod
    def close(self) -> None:
        """Close and destroy the gang."""

    @abstractmethod
    def create_gang(self, ranks: Sequence[int]) -> Gang | None:
        """Make a new gang.

        :param ranks:
            The ranks of processes that will be part of the new gang.
        """

    @abstractmethod
    def as_process_group(self) -> ProcessGroup:
        """Return this gang as a process group."""

    @abstractmethod
    def barrier(self) -> None:
        """Synchronize all processes."""

    @abstractmethod
    def all_reduce(self, tensor: Tensor, op: ReduceOperation) -> None:
        """Reduce ``tensor`` across all processes.

        :param tensor:
            The input and output tensor of the operation.
        :param op:
            The element-wise reduce operation.
        """
    @abstractmethod
    def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None:
        """Gather tensors from all processes and put them in ``output_tensor``.

        :param output_tensor:
            The output tensor to accommodate tensors from all processes.
        :param input_tensor:
            The tensor to be gathered from this process.
        """

    @abstractmethod
    def all_gather_to_list(
        self, output_tensors: list[Tensor], input_tensor: Tensor
    ) -> None:
        """Gather tensors from all processes and put them in ``output_tensors``.

        :param output_tensors:
            The tensor list to accommodate tensors from all processes.
        :param input_tensor:
            The tensor to be gathered from this process.
        """
    @abstractmethod
    def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None:
        """Broadcast ``tensor`` from ``source_rank`` to all processes.

        :param tensor:
            The tensor to be sent from ``source_rank``.
        :param source_rank:
            The rank of the process from which to broadcast ``tensor``.
        """

    @abstractmethod
    def broadcast_objects(self, objects: list[object], source_rank: int = 0) -> None:
        """Broadcast picklable ``objects`` from ``source_rank`` to all processes.

        :param objects:
            The list of picklable objects to broadcast. Each process must
            provide lists of equal sizes.
        :param source_rank:
            The rank of the process from which to broadcast ``objects``.
        """

    @property
    @abstractmethod
    def rank(self) -> int:
        """The rank of this process in the gang."""

    @property
    @abstractmethod
    def size(self) -> int:
        """The number of processes that are part of the gang."""

    @property
    @abstractmethod
    def device(self) -> Device:
        """The associated device."""

class GangError(Exception):
    pass

class AbstractGang(Gang):
    """Provides a skeletal implementation of :class:`Gang`."""

    _rank: int
    _size: int
    _device: Device

    def __init__(self, rank: int, size: int, device: Device) -> None:
        """
        :param rank:
            The rank of this process in the gang.
        :param size:
            The number of processes that are part of the gang.
        :param device:
            The associated device.
        """
        if size == 0:
            raise ValueError("`size` must be greater than zero.")

        if rank >= size:
            raise ValueError(
                f"`rank` must be less than `size` ({size}), but is {rank} instead."
            )

        if device.type == "meta":
            raise ValueError("`device` must be a real device.")

        self._rank = rank
        self._size = size
        self._device = device
    @final
    @override
    def create_gang(self, ranks: Sequence[int]) -> Gang | None:
        if len(set(ranks)) != len(ranks):
            raise ValueError("The ranks in ``ranks`` must all be unique.")

        for idx, rank in enumerate(ranks):
            if rank < 0 or rank >= self._size:
                raise ValueError(
                    f"The rank at index {idx} in ``ranks`` must be greater than or equal to 0 and less than the size of the gang ({self._size}), but is {rank} instead."
                )

        return self._do_create_gang(ranks)
    @abstractmethod
    def _do_create_gang(self, ranks: Sequence[int]) -> Gang | None:
        """Make a new gang.

        :param ranks:
            The ranks of processes that will be part of the new gang.
        """

    @final
    @property
    @override
    def rank(self) -> int:
        return self._rank

    @final
    @property
    @override
    def size(self) -> int:
        return self._size

    @final
    @property
    @override
    def device(self) -> Device:
        return self._device

@final
class FakeGang(AbstractGang):
    """Represents a non-distributed gang for local use."""

    def __init__(
        self, device: Device | None = None, *, rank: int = 0, size: int = 1
    ) -> None:
        """
        :param device:
            If ``None``, CPU will be used.
        :param rank:
            The emulated rank of this process in the gang.
        :param size:
            The emulated number of processes that are part of the gang.
        """
        super().__init__(rank=rank, size=size, device=device or CPU)

    @override
    def close(self) -> None:
        pass

    @override
    def _do_create_gang(self, ranks: Sequence[int]) -> FakeGang | None:
        try:
            idx = ranks.index(self._rank)
        except ValueError:
            return None

        return FakeGang(rank=idx, size=len(ranks), device=self._device)

    @override
    def as_process_group(self) -> ProcessGroup:
        raise NotSupportedError(
            "`FakeGang` does not support conversion to a process group."
        )

    @override
    def barrier(self) -> None:
        pass

    @override
    def all_reduce(self, tensor: Tensor, op: ReduceOperation) -> None:
        match op:
            case ReduceOperation.SUM:
                tensor *= self._size
            case ReduceOperation.PRODUCT:
                tensor.pow_(self._size)
            case _:
                raise NotSupportedError(
                    "`FakeGang` supports only `SUM` and `PRODUCT` reduce operations."
                )

    @override
    def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None:
        if not output_tensor.is_contiguous():
            raise ValueError("`output_tensor` must be contiguous.")

        if output_tensor.dim() != input_tensor.dim() + 1:
            raise ValueError(
                "`output_tensor` must have a shape that is compatible with all-gather."
            )

        if output_tensor.size(0) != self._size:
            raise ValueError(
                f"The size of the first dimension of `output_tensor` must match the number of processes in the gang ({self._size}), but is {output_tensor.size(0)} instead."
            )

        for i in range(self._size):
            output_tensor[i].copy_(input_tensor)

    @override
    def all_gather_to_list(
        self, output_tensors: list[Tensor], input_tensor: Tensor
    ) -> None:
        if len(output_tensors) != self._size:
            raise ValueError(
                f"The length of `output_tensors` must match the number of processes in the gang ({self._size}), but is {len(output_tensors)} instead."
            )

        for i in range(self._size):
            output_tensors[i].copy_(input_tensor)

    @override
    def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None:
        if source_rank != self._rank:
            raise ValueError(
                f"`source_rank` must be {self._rank}, but is {source_rank} instead."
            )

    @override
    def broadcast_objects(self, objects: list[object], source_rank: int = 0) -> None:
        if source_rank != self._rank:
            raise ValueError(
                f"`source_rank` must be {self._rank}, but is {source_rank} instead."
            )

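# --- Usage sketch (illustrative, not part of the module) ---
# `FakeGang` emulates a gang within a single process, which makes it handy for
# unit tests and local runs: collectives are computed locally as if every rank
# held the same data.
#
#     from fairseq2.gang import FakeGang, ReduceOperation
#     import torch
#
#     gang = FakeGang(size=4)  # emulate a 4-process gang on CPU
#
#     value = torch.tensor(2.0)
#     gang.all_reduce(value, ReduceOperation.SUM)  # emulated: 2.0 * 4 = 8.0
#
#     out = torch.empty(4)
#     gang.all_gather(out, torch.tensor(3.0))      # out == [3., 3., 3., 3.]
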
@final
class ProcessGroupGang(AbstractGang):
    """Represents a gang that wraps a process group."""

    _pg: ProcessGroup
    _monitor_pg: ProcessGroup | None

    def __init__(
        self,
        pg: ProcessGroup,
        device: Device,
        *,
        monitor_pg: ProcessGroup | None = None,
    ) -> None:
        super().__init__(dist.get_rank(pg), dist.get_world_size(pg), device)

        self._pg = pg
        self._monitor_pg = monitor_pg
    @classmethod
    def init_root_process_group(
        cls,
        device: Device,
        *,
        timeout: timedelta | None = None,
        high_priority: bool = False,
        num_threads: int | None = None,
        monitored: bool = False,
    ) -> ProcessGroupGang:
        """Initialize the root process group and wrap it as a gang.

        :param device:
            The device for which to initialize the gang. For CUDA devices,
            NCCL; for CPU, Gloo will be used.
        :param timeout:
            The timeout for collective operations. If ``None``, the default
            timeout value (15 minutes) will be used.
        :param high_priority:
            If ``True``, the underlying collective operations will be
            performed on high priority channels (e.g. CUDA streams).
        :param num_threads:
            The number of threads to use for intra-op parallelism.
        :param monitored:
            If ``True``, puts a monitored barrier before every collective
            call for troubleshooting purposes.
        """
        if log.is_enabled_for_debug():
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

            dist.set_debug_level_from_env()

        if not dist.is_available():
            raise GangError("`torch.distributed` is not available.")

        if dist.is_initialized():
            raise GangError("The default process group is already initialized.")

        backend: str | None

        if device.type == "cpu":
            backend = Backend.GLOO
        elif device.type == "cuda":
            backend = Backend.NCCL
        else:
            raise NotSupportedError(
                f"`device` must be of type `cpu` or `cuda`, but is of type `{device.type}` instead."
            )

        if num_threads is None:
            try:
                num_procs = get_local_world_size(os.environ)
            except InvalidEnvironmentVariableError as ex:
                raise GangError(
                    "The local world size cannot be determined from the environment variables. See the nested exception for details."
                ) from ex

            if num_procs > 1 and "OMP_NUM_THREADS" not in os.environ:
                # To prevent thread oversubscription, we distribute cores evenly
                # across the workers.
                num_threads = _get_num_cpus(num_procs)

        if num_threads is not None:
            torch.set_num_threads(num_threads)

            log.info("Setting the number of threads used for intra-op parallelism to {}.", num_threads)  # fmt: skip

        if device.type == "cuda":
            # See https://github.com/pytorch/pytorch/issues/46874.
            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"

        if timeout is None:
            timeout = timedelta(minutes=15)

        kwargs: dict[str, Any] = {}

        if torch_greater_or_equal(2, 3):
            # Forces NCCL to initialize immediately, which enables deterministic
            # behavior.
            kwargs = {"device_id": device}

        # If enabled, uses high priority CUDA streams for NCCL.
        if device.type == "cuda" and high_priority:
            # Not available unless PyTorch is built with NCCL.
            from torch.distributed import ProcessGroupNCCL

            pg_options = ProcessGroupNCCL.Options(is_high_priority_stream=True)
        else:
            pg_options = None

        try:
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    action="ignore", message=".*options._timeout was specified.*"
                )

                dist.init_process_group(
                    backend, timeout=timeout, pg_options=pg_options, **kwargs
                )
        except (RuntimeError, ValueError) as ex:
            raise GangError(
                "The underlying process group has failed to initialize. See the nested exception for details."
            ) from ex

        pg = dist.group.WORLD
        if pg is None:
            raise InternalError("`dist.group.WORLD` is `None`.")

        if monitored:
            if backend == Backend.GLOO:
                monitor_pg = pg
            else:
                # Gloo is needed for monitored barrier support.
                try:
                    monitor_pg = dist.new_group(backend=Backend.GLOO, timeout=timeout)
                except RuntimeError as ex:
                    raise GangError(
                        "The underlying process group used for monitoring has failed to initialize. See the nested exception for details."
                    ) from ex
        else:
            monitor_pg = None

        return ProcessGroupGang(pg, device, monitor_pg=monitor_pg)
    @override
    def close(self) -> None:
        dist.destroy_process_group(self._pg)

    @override
    def _do_create_gang(self, ranks: Sequence[int]) -> ProcessGroupGang | None:
        if self._pg is not dist.group.WORLD:
            raise InvalidOperationError(
                "`create_gang()` can only be called on the gang associated with the default (i.e. main) process group."
            )

        try:
            backend = dist.get_backend()
        except RuntimeError as ex:
            raise GangError(
                "The default process group backend cannot be determined. See the nested exception for details."
            ) from ex

        try:
            pg = dist.new_group(ranks, backend=backend)
        except RuntimeError as ex:
            s = ", ".join(sorted(str(r) for r in ranks))

            raise GangError(
                f"The creation of a new child process group has failed for ranks {s}. See the nested exception for details."
            ) from ex

        if self._rank not in ranks:
            return None

        if self._monitor_pg is not None:
            if backend == Backend.GLOO:
                monitor_pg = pg
            else:
                try:
                    monitor_pg = dist.new_group(ranks, backend=Backend.GLOO)
                except RuntimeError as ex:
                    s = ", ".join(sorted(str(r) for r in ranks))

                    raise GangError(
                        f"The creation of a new monitoring child process group has failed for ranks {s}. See the nested exception for details."
                    ) from ex
        else:
            monitor_pg = None

        return ProcessGroupGang(pg, self._device, monitor_pg=monitor_pg)

    @override
    def as_process_group(self) -> ProcessGroup:
        return self._pg

    @override
    def barrier(self) -> None:
        if self._monitor_pg is None:
            try:
                dist.barrier(group=self._pg, device_ids=[self._device.index])
            except RuntimeError as ex:
                raise GangError(
                    "The `barrier` collective operation has failed. See the nested exception for details."
                ) from ex
        else:
            torch.cuda.synchronize()

            try:
                dist.monitored_barrier(group=self._monitor_pg, wait_all_ranks=True)
            except RuntimeError as ex:
                raise GangError(
                    "The `monitored_barrier` collective operation has failed. See the nested exception for details."
                ) from ex

    @override
    def all_reduce(self, tensor: Tensor, op: ReduceOperation) -> None:
        self._maybe_monitored_barrier()

        try:
            dist.all_reduce(tensor, self._get_reduce_op(op), group=self._pg)
        except RuntimeError as ex:
            raise GangError(
                "The `all_reduce` collective operation has failed. See the nested exception for details."
            ) from ex

    @override
    def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None:
        self._maybe_monitored_barrier()

        try:
            dist.all_gather_into_tensor(output_tensor, input_tensor, group=self._pg)
        except RuntimeError as ex:
            raise GangError(
                "The `all_gather` collective operation has failed. See the nested exception for details."
            ) from ex

    @override
    def all_gather_to_list(
        self, output_tensors: list[Tensor], input_tensor: Tensor
    ) -> None:
        self._maybe_monitored_barrier()

        try:
            dist.all_gather(output_tensors, input_tensor, group=self._pg)
        except RuntimeError as ex:
            raise GangError(
                "The `all_gather_to_list` collective operation has failed. See the nested exception for details."
            ) from ex

    @override
    def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None:
        self._maybe_monitored_barrier()

        try:
            dist.broadcast(tensor, source_rank, group=self._pg)
        except RuntimeError as ex:
            raise GangError(
                "The `broadcast` collective operation has failed. See the nested exception for details."
            ) from ex

    @override
    def broadcast_objects(self, objects: list[object], source_rank: int = 0) -> None:
        self._maybe_monitored_barrier()

        try:
            dist.broadcast_object_list(objects, source_rank, group=self._pg)
        except RuntimeError as ex:
            raise GangError(
                "The `broadcast_object_list` collective operation has failed. See the nested exception for details."
            ) from ex

    def _maybe_monitored_barrier(self) -> None:
        if self._monitor_pg is None:
            return

        torch.cuda.synchronize()

        try:
            dist.monitored_barrier(group=self._monitor_pg, wait_all_ranks=True)
        except RuntimeError as ex:
            raise GangError(
                "The `monitored_barrier` collective operation has failed. See the nested exception for details."
            ) from ex

    @staticmethod
    def _get_reduce_op(op: ReduceOperation):  # type: ignore[no-untyped-def]
        if op == ReduceOperation.SUM:
            return ReduceOp.SUM
        if op == ReduceOperation.MEAN:
            return ReduceOp.AVG  # type: ignore[attr-defined]
        if op == ReduceOperation.PRODUCT:
            return ReduceOp.PRODUCT
        if op == ReduceOperation.MIN:
            return ReduceOp.MIN
        if op == ReduceOperation.MAX:
            return ReduceOp.MAX

        raise NotSupportedError(
            f"`{op}` operation is not supported by the underlying process group."
        )

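# --- Usage sketch (illustrative, not part of the module) ---
# `ProcessGroupGang` is a thin wrapper over `torch.distributed`; each collective
# maps to the corresponding `dist.*` call on the wrapped process group. It is
# normally created once per process under a launcher such as
# `torchrun --nproc-per-node=8 train.py`, which sets the environment variables
# that `torch.distributed` expects. Code that wants a raw `ProcessGroup`
# (e.g. FSDP) can unwrap it:
#
#     from fairseq2.gang import ProcessGroupGang
#     from fairseq2.typing import Device
#
#     device = Device("cuda:0")  # typically chosen from LOCAL_RANK
#
#     gang = ProcessGroupGang.init_root_process_group(device)
#
#     pg = gang.as_process_group()  # hand to APIs that expect a ProcessGroup
#
#     gang.barrier()                # synchronize all ranks
#     gang.close()                  # destroy the process group at shutdown
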
def _get_num_cpus(num_procs: int) -> int:
    num_cpus = os.cpu_count()

    affinity_mask = os.sched_getaffinity(0)

    if num_cpus is None or affinity_mask is None:
        log.warning("The number of CPU cores cannot be determined.")

        return 1

    # We should not exceed the number of cores available in the affinity mask.
    return min(max(num_cpus // num_procs, 1), len(affinity_mask))

def setup_root_gang(
    device: Device,
    *,
    timeout: timedelta | None = None,
    high_priority: bool = False,
    monitored: bool = False,
) -> Gang:
    """Create the root gang of this process.

    :param device:
        The device for which to initialize the gang. For CUDA devices, NCCL;
        for CPU, Gloo will be used.
    :param timeout:
        The timeout for collective operations. If ``None``, the default
        timeout value (15 minutes) will be used.
    :param high_priority:
        If ``True``, the underlying collective operations will be performed
        on high priority channels (e.g. CUDA streams).
    :param monitored:
        If ``True``, puts a monitored barrier before every collective call
        for troubleshooting purposes.
    """
    try:
        world_size = get_world_size(os.environ)
    except InvalidEnvironmentVariableError as ex:
        raise GangError(
            "The world size cannot be determined. See the nested exception for details."
        ) from ex

    if world_size == 1:
        return FakeGang(device)

    return ProcessGroupGang.init_root_process_group(
        device, timeout=timeout, high_priority=high_priority, monitored=monitored
    )

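# --- Usage sketch (illustrative, not part of the module) ---
# `setup_root_gang` is the usual entry point: it returns a `FakeGang` when the
# process is not part of a distributed job (world size 1) and a
# `ProcessGroupGang` otherwise, so calling code can stay agnostic:
#
#     import torch
#
#     from fairseq2.gang import setup_root_gang
#     from fairseq2.typing import CPU, Device
#
#     device = Device("cuda:0") if torch.cuda.is_available() else CPU
#
#     root_gang = setup_root_gang(device)
#
#     print(f"rank {root_gang.rank} of {root_gang.size}")
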
@dataclass
class Gangs:
    root: Gang
    dp: Gang
    tp: Gang

    def close(self) -> None:
        self.root.close()


def fake_gangs(device: Device) -> Gangs:
    gang = FakeGang(device=device)

    return Gangs(gang, gang, gang)


def to_gangs(gang: Gang) -> Gangs:
    fake_gang = FakeGang(device=gang.device)

    return Gangs(gang, gang, fake_gang)

def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
    """Sets up gangs to be used for data and model parallelism.

    For instance, if we have 8 devices denoted by g0 to g7 and 2 devices are
    used for tensor parallelism, this function will make 4 tensor parallel
    gangs and 2 data parallel gangs as:

        4 tensor parallel gangs:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel gangs:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    For efficiency, the caller should make sure adjacent ranks are on the same
    host. For example, if there are two hosts with a total of 16 GPUs, ranks 0
    to 7 belong to the first host and ranks 8 to 15 belong to the second host.

    :param root_gang:
        The gang whose topology will be used to make the new gangs.
    :param tp_size:
        The size of tensor parallel gangs.
    """
    if tp_size < 1:
        raise ValueError(f"`tp_size` must be greater than 0, but is {tp_size} instead.")

    if tp_size > root_gang.size:
        raise ValueError(
            f"`tp_size` must be less than or equal to the number of processes in the root gang ({root_gang.size}), but is {tp_size} instead."
        )

    if root_gang.size % tp_size != 0:
        raise ValueError(
            f"`root_gang.size` is expected to be a multiple of `tp_size` ({tp_size}), but is {root_gang.size} instead."
        )

    dp_size = root_gang.size // tp_size

    mesh = torch.arange(root_gang.size).view(dp_size, tp_size)

    # Get the coordinate of this process in the mesh.
    rank_coords = [x.item() for x in torch.where(mesh == root_gang.rank)]

    dp_gang: Gang | None = None

    log.info("Initializing data parallel gang with a size of {}.", dp_size)

    # Build the gangs for data parallelism.
    match dp_size:
        case 1:
            dp_gang = FakeGang(device=root_gang.device)
        case root_gang.size:
            dp_gang = root_gang
        case _:
            for i in range(tp_size):
                sub_gang = root_gang.create_gang(mesh[:, i].tolist())
                if i == rank_coords[1]:
                    dp_gang = sub_gang

    if dp_gang is None:
        raise InternalError("`dp_gang` is `None`.")

    tp_gang: Gang | None = None

    log.info("Initializing tensor parallel gang with a size of {}.", tp_size)

    # Build the gangs for tensor parallelism.
    match tp_size:
        case 1:
            tp_gang = FakeGang(device=root_gang.device)
        case root_gang.size:
            tp_gang = root_gang
        case _:
            for i in range(dp_size):
                sub_gang = root_gang.create_gang(mesh[i, :].tolist())
                if i == rank_coords[0]:
                    tp_gang = sub_gang

    if tp_gang is None:
        raise InternalError("`tp_gang` is `None`.")

    return Gangs(root_gang, dp_gang, tp_gang)

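# --- Usage sketch (illustrative, not part of the module) ---
# On 8 processes with `tp_size=2`, each rank ends up with a data parallel gang
# of size 4 and a tensor parallel gang of size 2, matching the mesh described
# in the docstring above:
#
#     root_gang = setup_root_gang(device)  # `device` as in the earlier sketch
#
#     gangs = setup_parallel_gangs(root_gang, tp_size=2)
#
#     assert gangs.dp.size * gangs.tp.size == gangs.root.size
#
#     # e.g. shard the batch across `gangs.dp` and the model across `gangs.tp`.
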
def setup_hsdp_gangs(dp_gang: Gang, local_world_size: int) -> tuple[Gang, Gang]:
    """Sets up gangs to be used for hybrid sharded data parallelism.

    For instance, if we have 8 devices denoted by g0 to g7 and
    ``local_world_size`` is 4, this function will make 2 intra-node gangs and
    4 inter-node gangs:

        2 intra-node gangs of size 4:
            [g0, g1, g2, g3], [g4, g5, g6, g7]
        4 inter-node gangs of size 2:
            [g0, g4], [g1, g5], [g2, g6], [g3, g7]

    For efficiency, the caller should make sure adjacent ranks are on the same
    host.

    :returns:
        A tuple of the intra-node gang used for sharding and the inter-node
        gang used for replication.
    """
    if local_world_size <= 1:
        raise ValueError(
            f"`local_world_size` must be greater than 1, but is {local_world_size} instead."
        )

    if dp_gang.size % local_world_size != 0:
        raise ValueError(
            f"`dp_gang.size` is expected to be a multiple of `local_world_size` ({local_world_size}), but is {dp_gang.size} instead."
        )

    intra_node_size = local_world_size

    inter_node_size = dp_gang.size // local_world_size

    mesh = torch.arange(dp_gang.size).view(inter_node_size, intra_node_size)

    # Get the coordinate of this process in the mesh.
    rank_coords = [x.item() for x in torch.where(mesh == dp_gang.rank)]

    inter_node_gang: Gang | None = None

    log.info("Initializing inter-node data parallel gang with a size of {}.", inter_node_size)  # fmt: skip

    # Build the gangs for inter-node data parallelism.
    match inter_node_size:
        case 1:
            inter_node_gang = FakeGang(device=dp_gang.device)
        case dp_gang.size:
            inter_node_gang = dp_gang
        case _:
            for i in range(intra_node_size):
                sub_gang = dp_gang.create_gang(mesh[:, i].tolist())
                if i == rank_coords[1]:
                    inter_node_gang = sub_gang

    if inter_node_gang is None:
        raise InternalError("`inter_node_gang` is `None`.")

    intra_node_gang: Gang | None = None

    log.info("Initializing intra-node data parallel gang with a size of {}.", intra_node_size)  # fmt: skip

    # Build the gangs for intra-node data parallelism.
    match intra_node_size:
        case 1:
            intra_node_gang = FakeGang(device=dp_gang.device)
        case dp_gang.size:
            intra_node_gang = dp_gang
        case _:
            for i in range(inter_node_size):
                sub_gang = dp_gang.create_gang(mesh[i, :].tolist())
                if i == rank_coords[0]:
                    intra_node_gang = sub_gang

    if intra_node_gang is None:
        raise InternalError("`intra_node_gang` is `None`.")

    return intra_node_gang, inter_node_gang

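# --- Usage sketch (illustrative, not part of the module) ---
# For hybrid sharded data parallelism, parameters are sharded within a node
# and replicated across nodes. With 2 nodes of 8 GPUs each:
#
#     sharding_gang, replication_gang = setup_hsdp_gangs(gangs.dp, local_world_size=8)
#
#     assert sharding_gang.size == 8                       # intra-node
#     assert replication_gang.size == gangs.dp.size // 8   # inter-node
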
def broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) -> bool:
    """Broadcast ``flag`` to all processes in ``gang`` from ``source_rank``."""
    tmp = torch.tensor(flag, device=gang.device)

    gang.broadcast(tmp, source_rank)

    return bool(tmp)


def all_sum(gang: Gang, value: float | int | Tensor) -> Tensor:
    """Sum ``value`` over all processes in ``gang``."""
    if isinstance(value, Tensor):
        output = value
    else:
        output = torch.tensor(value, device=gang.device)

    gang.all_reduce(output, ReduceOperation.SUM)

    return output
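

# --- Usage sketch (illustrative, not part of the module) ---
# Typical uses of the convenience helpers in a training loop
# (`loss_diverged` and `num_tokens` are placeholder values):
#
#     # Rank 0 decides whether to stop early; every rank receives the decision.
#     should_stop = broadcast_flag(root_gang, loss_diverged, source_rank=0)
#
#     # Sum the number of target tokens processed across all ranks this step.
#     total_num_tokens = all_sum(root_gang, num_tokens)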