Module dora.executor

Start multiple processes locally for DDP.

Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Start multiple processes locally for DDP.
"""

from functools import partial
import os
import subprocess as sp
import sys
import typing as tp

from .log import simple_log, fatal


# Module logger: `simple_log` with "Executor:" pre-bound as its first argument.
log = partial(simple_log, "Executor:")


class ChildrenManager:
    """Context manager supervising a group of worker subprocesses.

    Workers are registered with `add` inside the ``with`` block. When the block
    exits, the manager polls the workers until they have all finished, or until
    a failure is detected (non-zero exit code, exception raised in the ``with``
    body, or keyboard interrupt). Workers still registered at that point are
    terminated and waited for. The outcome is available in `failed`.

    `__exit__` returns None, so an exception raised in the ``with`` body is
    still propagated to the caller once the workers have been stopped.
    """

    # Seconds the workers get to exit after `terminate()` before being killed.
    _TERMINATE_GRACE = 5.0

    def __init__(self):
        self.children = []
        self.failed = False

    def add(self, child):
        """Register `child` and set its `rank` to its position in `children`."""
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is not None:
            log(f"An exception happened while starting workers {exc_value}")
            self.failed = True
        try:
            while self.children and not self.failed:
                # Iterate over a copy, finished children are removed on the fly.
                for child in list(self.children):
                    try:
                        exitcode = child.wait(0.05)
                    except sp.TimeoutExpired:
                        continue
                    self.children.remove(child)
                    if exitcode:
                        log(f"Worker {child.rank} died, killing all workers")
                        self.failed = True
        except KeyboardInterrupt:
            log("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        self._stop_children()
        if not self.failed:
            log("All workers completed successfully")

    def _stop_children(self):
        """Terminate the children still registered and wait until they are gone."""
        import time

        if not self.children:
            return
        for child in self.children:
            child.terminate()
        # A single deadline is shared by all children, so the total wait is
        # bounded by the grace period rather than growing with their number.
        deadline = time.monotonic() + self._TERMINATE_GRACE
        try:
            for child in self.children:
                try:
                    child.wait(max(0.0, deadline - time.monotonic()))
                except sp.TimeoutExpired:
                    log(f"Worker {child.rank} did not stop after terminate(), killing it")
                    child.kill()
                    child.wait()
        except KeyboardInterrupt:
            # Interrupted again while waiting: stop being polite.
            for child in self.children:
                child.kill()


def start_ddp_workers(main, argv, num_workers: tp.Optional[int] = None):
    """Spawn one local worker process per rank and exit with their status.

    Every worker runs ``<python> -m dora -P <main.package> --main_module
    <main.main_module> run -- <argv>`` with `RANK`, `LOCAL_RANK`, `WORLD_SIZE`
    and `MASTER_ADDR` set in its environment. The rank 0 worker is started with
    default stdio settings; the others get no stdin and send stdout and stderr
    to ``worker_<rank>.log`` inside the XP folder.

    Args:
        main: object providing `get_xp(argv)`, `package` and `main_module`.
        argv: command line arguments forwarded to each worker after ``--``.
        num_workers: number of workers to start. When None or 0, the value
            of `torch.cuda.device_count()` is used instead.

    Once all workers are done, the process exits through `sys.exit` with
    status 0 if none failed and 1 otherwise.
    """
    import torch as th

    world_size = num_workers or th.cuda.device_count()
    if not world_size:
        fatal(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        # Explicit exit, in case `fatal` returns.
        sys.exit(1)

    xp = main.get_xp(argv)
    xp.folder.mkdir(exist_ok=True, parents=True)
    # Presumably a leftover from a previous run -- start from a clean state.
    if xp.rendezvous_file.exists():
        xp.rendezvous_file.unlink()
    log(f"Starting {world_size} worker processes for DDP.")
    with ChildrenManager() as manager:
        # The command line is the same for every rank, only the env differs.
        command = [sys.executable, "-m", "dora", "-P", main.package,
                   "--main_module", main.main_module, "run", "--"]
        command += argv
        for rank in range(world_size):
            kwargs: tp.Dict[str, tp.Any] = {}
            env = dict(os.environ)
            env['RANK'] = str(rank)
            env['LOCAL_RANK'] = str(rank)
            env['WORLD_SIZE'] = str(world_size)
            env['MASTER_ADDR'] = '127.0.0.1'
            log_file = None
            if rank > 0:
                log_file = open(xp.folder / f'worker_{rank}.log', 'w')
                kwargs['stdin'] = sp.DEVNULL
                kwargs['stdout'] = log_file
                kwargs['stderr'] = sp.STDOUT
            try:
                manager.add(sp.Popen(command, env=env, **kwargs))
            finally:
                # The child holds its own copy of the descriptor; close ours
                # so that the launcher does not leak one handle per worker.
                if log_file is not None:
                    log_file.close()
    sys.exit(int(manager.failed))

Functions

def start_ddp_workers(main, argv, num_workers: Optional[int] = None)
Expand source code
def start_ddp_workers(main, argv, num_workers: tp.Optional[int] = None):
    """Spawn one local worker process per rank and exit with their status.

    Every worker runs ``<python> -m dora -P <main.package> --main_module
    <main.main_module> run -- <argv>`` with `RANK`, `LOCAL_RANK`, `WORLD_SIZE`
    and `MASTER_ADDR` set in its environment. The rank 0 worker is started with
    default stdio settings; the others get no stdin and send stdout and stderr
    to ``worker_<rank>.log`` inside the XP folder.

    Args:
        main: object providing `get_xp(argv)`, `package` and `main_module`.
        argv: command line arguments forwarded to each worker after ``--``.
        num_workers: number of workers to start. When None or 0, the value
            of `torch.cuda.device_count()` is used instead.

    Once all workers are done, the process exits through `sys.exit` with
    status 0 if none failed and 1 otherwise.
    """
    import torch as th

    world_size = num_workers or th.cuda.device_count()
    if not world_size:
        fatal(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        # Explicit exit, in case `fatal` returns.
        sys.exit(1)

    xp = main.get_xp(argv)
    xp.folder.mkdir(exist_ok=True, parents=True)
    # Presumably a leftover from a previous run -- start from a clean state.
    if xp.rendezvous_file.exists():
        xp.rendezvous_file.unlink()
    log(f"Starting {world_size} worker processes for DDP.")
    with ChildrenManager() as manager:
        # The command line is the same for every rank, only the env differs.
        command = [sys.executable, "-m", "dora", "-P", main.package,
                   "--main_module", main.main_module, "run", "--"]
        command += argv
        for rank in range(world_size):
            kwargs: tp.Dict[str, tp.Any] = {}
            env = dict(os.environ)
            env['RANK'] = str(rank)
            env['LOCAL_RANK'] = str(rank)
            env['WORLD_SIZE'] = str(world_size)
            env['MASTER_ADDR'] = '127.0.0.1'
            log_file = None
            if rank > 0:
                log_file = open(xp.folder / f'worker_{rank}.log', 'w')
                kwargs['stdin'] = sp.DEVNULL
                kwargs['stdout'] = log_file
                kwargs['stderr'] = sp.STDOUT
            try:
                manager.add(sp.Popen(command, env=env, **kwargs))
            finally:
                # The child holds its own copy of the descriptor; close ours
                # so that the launcher does not leak one handle per worker.
                if log_file is not None:
                    log_file.close()
    sys.exit(int(manager.failed))

Classes

class ChildrenManager
Expand source code
class ChildrenManager:
    """Context manager supervising a group of worker subprocesses.

    Workers are registered with `add` inside the ``with`` block. When the block
    exits, the manager polls the workers until they have all finished, or until
    a failure is detected (non-zero exit code, exception raised in the ``with``
    body, or keyboard interrupt). Workers still registered at that point are
    terminated and waited for. The outcome is available in `failed`.

    `__exit__` returns None, so an exception raised in the ``with`` body is
    still propagated to the caller once the workers have been stopped.
    """

    # Seconds the workers get to exit after `terminate()` before being killed.
    _TERMINATE_GRACE = 5.0

    def __init__(self):
        self.children = []
        self.failed = False

    def add(self, child):
        """Register `child` and set its `rank` to its position in `children`."""
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is not None:
            log(f"An exception happened while starting workers {exc_value}")
            self.failed = True
        try:
            while self.children and not self.failed:
                # Iterate over a copy, finished children are removed on the fly.
                for child in list(self.children):
                    try:
                        exitcode = child.wait(0.05)
                    except sp.TimeoutExpired:
                        continue
                    self.children.remove(child)
                    if exitcode:
                        log(f"Worker {child.rank} died, killing all workers")
                        self.failed = True
        except KeyboardInterrupt:
            log("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        self._stop_children()
        if not self.failed:
            log("All workers completed successfully")

    def _stop_children(self):
        """Terminate the children still registered and wait until they are gone."""
        import time

        if not self.children:
            return
        for child in self.children:
            child.terminate()
        # A single deadline is shared by all children, so the total wait is
        # bounded by the grace period rather than growing with their number.
        deadline = time.monotonic() + self._TERMINATE_GRACE
        try:
            for child in self.children:
                try:
                    child.wait(max(0.0, deadline - time.monotonic()))
                except sp.TimeoutExpired:
                    log(f"Worker {child.rank} did not stop after terminate(), killing it")
                    child.kill()
                    child.wait()
        except KeyboardInterrupt:
            # Interrupted again while waiting: stop being polite.
            for child in self.children:
                child.kill()

Methods

def add(self, child)
Expand source code
def add(self, child):
    """Register `child` as a worker.

    The child's `rank` attribute is set to the number of children registered
    before it, then the child is appended to `self.children`.
    """
    position = len(self.children)
    child.rank = position
    self.children.append(child)