Module dora.executor
Start multiple processes locally for DDP.
Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Start multiple processes locally for DDP.
"""
from functools import partial
import os
import subprocess as sp
import sys
import typing as tp
from .log import simple_log, fatal
# Module-level logging helper: `simple_log` with "Executor:" pre-bound as its
# first positional argument.
log = partial(simple_log, "Executor:")
class ChildrenManager:
    """Context manager that supervises a group of child processes.

    Children are registered with `add()` inside the `with` block. On exit,
    the manager polls them until they have all finished or one of them
    reports a non-zero exit code, then terminates whatever is still
    registered. The `failed` attribute tells whether anything went wrong.
    """

    def __init__(self):
        self.failed = False
        self.children = []

    def add(self, child):
        """Register `child` and store its position in the list as `child.rank`."""
        next_rank = len(self.children)
        child.rank = next_rank
        self.children.append(child)

    def __enter__(self):
        return self

    def _poll_once(self):
        """Give each registered child a short wait, dropping those that exited."""
        # Iterate over a snapshot because finished children are removed on the fly.
        for child in tuple(self.children):
            try:
                exitcode = child.wait(0.05)
            except sp.TimeoutExpired:
                # Still running, look at the next child.
                continue
            self.children.remove(child)
            if exitcode:
                log(f"Worker {child.rank} died, killing all workers")
                self.failed = True

    def __exit__(self, exc_type, exc_value, traceback):
        """Wait for the children, then terminate any that is still registered.

        Returns None, so an exception raised inside the `with` block is
        not suppressed.
        """
        if exc_value is not None:
            log(f"An exception happened while starting workers {exc_value}")
            self.failed = True

        try:
            while self.children and not self.failed:
                self._poll_once()
        except KeyboardInterrupt:
            log("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True

        # On success the list is empty here; otherwise stop the remaining children.
        for leftover in self.children:
            leftover.terminate()

        if self.failed:
            return
        log("All workers completed successfully")
def start_ddp_workers(main, argv, num_workers: tp.Optional[int] = None):
    """Start one local `dora ... run` subprocess per DDP worker and wait for them.

    Args:
        main: object providing `get_xp(argv)`, `package` and `main_module`.
        argv: command line arguments, forwarded to each worker after `run --`.
        num_workers: number of workers to start. When None (or 0), falls back
            to `torch.cuda.device_count()`.

    This function does not return normally: it always ends with `sys.exit`,
    with status 1 if the manager recorded a failure and 0 otherwise.
    """
    import torch as th

    world_size = num_workers or th.cuda.device_count()
    if not world_size:
        fatal(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)

    xp = main.get_xp(argv)
    xp.folder.mkdir(exist_ok=True, parents=True)
    # Drop a pre-existing rendezvous file (presumably left over from a previous run).
    if xp.rendezvous_file.exists():
        xp.rendezvous_file.unlink()
    log(f"Starting {world_size} worker processes for DDP.")

    with ChildrenManager() as manager:
        # The command line is identical for every worker; only the environment
        # and the standard streams differ.
        command = [sys.executable, "-m", "dora", "-P", main.package,
                   "--main_module", main.main_module, "run", "--"]
        command += argv

        for rank in range(world_size):
            env = dict(os.environ)
            env['RANK'] = str(rank)
            env['LOCAL_RANK'] = str(rank)
            env['WORLD_SIZE'] = str(world_size)
            env['MASTER_ADDR'] = '127.0.0.1'

            if rank == 0:
                # The first worker inherits the standard streams of this process.
                child = sp.Popen(command, env=env)
            else:
                # Other workers get no stdin and write both stdout and stderr to
                # their own log file. The child holds its own copy of the file
                # descriptor once `Popen` returns, so the parent's handle is
                # closed right away instead of being leaked.
                with open(xp.folder / f'worker_{rank}.log', 'w') as log_file:
                    child = sp.Popen(
                        command, env=env,
                        stdin=sp.DEVNULL, stdout=log_file, stderr=sp.STDOUT)
            manager.add(child)
    sys.exit(int(manager.failed))
Functions
def start_ddp_workers(main, argv, num_workers: Optional[int] = None)
-
Expand source code
def start_ddp_workers(main, argv, num_workers: tp.Optional[int] = None):
    """Start one local `dora ... run` subprocess per DDP worker and wait for them.

    Args:
        main: object providing `get_xp(argv)`, `package` and `main_module`.
        argv: command line arguments, forwarded to each worker after `run --`.
        num_workers: number of workers to start. When None (or 0), falls back
            to `torch.cuda.device_count()`.

    This function does not return normally: it always ends with `sys.exit`,
    with status 1 if the manager recorded a failure and 0 otherwise.
    """
    import torch as th

    world_size = num_workers or th.cuda.device_count()
    if not world_size:
        fatal(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)

    xp = main.get_xp(argv)
    xp.folder.mkdir(exist_ok=True, parents=True)
    # Drop a pre-existing rendezvous file (presumably left over from a previous run).
    if xp.rendezvous_file.exists():
        xp.rendezvous_file.unlink()
    log(f"Starting {world_size} worker processes for DDP.")

    with ChildrenManager() as manager:
        # The command line is identical for every worker; only the environment
        # and the standard streams differ.
        command = [sys.executable, "-m", "dora", "-P", main.package,
                   "--main_module", main.main_module, "run", "--"]
        command += argv

        for rank in range(world_size):
            env = dict(os.environ)
            env['RANK'] = str(rank)
            env['LOCAL_RANK'] = str(rank)
            env['WORLD_SIZE'] = str(world_size)
            env['MASTER_ADDR'] = '127.0.0.1'

            if rank == 0:
                # The first worker inherits the standard streams of this process.
                child = sp.Popen(command, env=env)
            else:
                # Other workers get no stdin and write both stdout and stderr to
                # their own log file. The child holds its own copy of the file
                # descriptor once `Popen` returns, so the parent's handle is
                # closed right away instead of being leaked.
                with open(xp.folder / f'worker_{rank}.log', 'w') as log_file:
                    child = sp.Popen(
                        command, env=env,
                        stdin=sp.DEVNULL, stdout=log_file, stderr=sp.STDOUT)
            manager.add(child)
    sys.exit(int(manager.failed))
Classes
class ChildrenManager
-
Expand source code
class ChildrenManager:
    """Context manager that supervises a group of child processes.

    Children are registered with `add()` inside the `with` block. On exit,
    the manager polls them until they have all finished or one of them
    reports a non-zero exit code, then terminates whatever is still
    registered. The `failed` attribute tells whether anything went wrong.
    """

    def __init__(self):
        self.failed = False
        self.children = []

    def add(self, child):
        """Register `child` and store its position in the list as `child.rank`."""
        next_rank = len(self.children)
        child.rank = next_rank
        self.children.append(child)

    def __enter__(self):
        return self

    def _poll_once(self):
        """Give each registered child a short wait, dropping those that exited."""
        # Iterate over a snapshot because finished children are removed on the fly.
        for child in tuple(self.children):
            try:
                exitcode = child.wait(0.05)
            except sp.TimeoutExpired:
                # Still running, look at the next child.
                continue
            self.children.remove(child)
            if exitcode:
                log(f"Worker {child.rank} died, killing all workers")
                self.failed = True

    def __exit__(self, exc_type, exc_value, traceback):
        """Wait for the children, then terminate any that is still registered.

        Returns None, so an exception raised inside the `with` block is
        not suppressed.
        """
        if exc_value is not None:
            log(f"An exception happened while starting workers {exc_value}")
            self.failed = True

        try:
            while self.children and not self.failed:
                self._poll_once()
        except KeyboardInterrupt:
            log("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True

        # On success the list is empty here; otherwise stop the remaining children.
        for leftover in self.children:
            leftover.terminate()

        if self.failed:
            return
        log("All workers completed successfully")
Methods
def add(self, child)
-
Expand source code
def add(self, child):
    """Register `child` and store its position in the list as `child.rank`."""
    next_rank = len(self.children)
    child.rank = next_rank
    self.children.append(child)