Module dora.utils

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez 2020

from contextlib import contextmanager
import importlib
import logging
import os
from pathlib import Path
import pickle
from shutil import rmtree
import tempfile
import typing as tp

from omegaconf.basecontainer import BaseContainer
from omegaconf import OmegaConf

from .log import fatal

logger = logging.getLogger(__name__)


def jsonable(value):
    """Recursively convert `value` into a JSON-serializable structure:
    tensors become (nested) lists, `Path` becomes `str`, OmegaConf containers
    become plain containers, and scalars pass through unchanged."""
    # Imported lazily so that importing dora.utils does not require torch.
    import torch

    if isinstance(value, dict):
        return {k: jsonable(v) for k, v in value.items()}
    elif isinstance(value, (list, tuple)):
        return [jsonable(v) for v in value]
    elif isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    elif isinstance(value, Path):
        return str(value)
    elif value is None or isinstance(value, (int, float, str, bool)):
        return value
    elif isinstance(value, BaseContainer):
        return OmegaConf.to_container(value)
    else:
        raise ValueError(f"{repr(value)} is not jsonable.")


@contextmanager
def write_and_rename(path: Path, mode: str = "wb", suffix: str = ".tmp"):
    """
    Write to a temporary file with the given suffix, then rename it
    to the final filename. As renaming a file is usually much faster
    than writing it, this removes (or, as far as I understand NFS, greatly
    limits) the likelihood of leaving a half-written checkpoint behind
    if the process is killed at the wrong time.
    """
    tmp_path = str(path) + suffix
    with open(tmp_path, mode) as f:
        yield f
    os.rename(tmp_path, path)


def try_load(path: Path, load=pickle.load, mode: str = "rb"):
    """
    Try to load from a path using the given `load` function (`pickle.load`
    by default), and handle various failure cases.
    Return None upon failure.
    """
    try:
        with open(path, mode) as f:
            return load(f)
    except (OSError, pickle.UnpicklingError, RuntimeError, EOFError) as exc:
        # Trying to list everything that can go wrong.
        logger.warning(
            "An error happened when trying to load from %s, this file will be ignored: %r",
            path, exc)
        return None


def import_or_fatal(module_name: str) -> tp.Any:
    """Import and return the given module, or call `fatal()` with a clear
    error message if the import fails."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        logger.info("Could not import module %s", module_name, exc_info=True)
        fatal(f"Failed to import module {module_name}.")


def reliable_rmtree(path: Path):
    """Reliably delete the given folder, trying to remove while ignoring errors,
    and if any files remain, renaming to some trash folder."""
    error = False

    def _on_error(func, error_path, exc_info):
        # shutil.rmtree error callback: remember that at least one entry
        # could not be removed.
        nonlocal error
        error = True
        logger.warning(f"Error deleting file {error_path}")

    rmtree(path, onerror=_on_error)
    if error:
        assert path.exists()
        target_name = tempfile.mkdtemp(dir=path.parent, prefix=path.name + "_", suffix="_trash")
        logger.warning(f"Deletion of {path} failed, moving to {target_name}")
        path.rename(target_name)
    else:
        assert not path.exists()

Functions

def import_or_fatal(module_name: str) -> Any

Import and return the given module, or abort with a fatal error message if the import fails.
def jsonable(value)

Recursively convert value into a JSON-serializable structure: tensors become (nested) lists, Path becomes str, OmegaConf containers become plain containers, and scalars pass through unchanged. Raises ValueError for unsupported types.
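A small usage sketch with made-up metric values, showing the kinds of objects jsonable() converts:

import json
from pathlib import Path

import torch
from omegaconf import OmegaConf

from dora.utils import jsonable

metrics = {
    "loss": torch.tensor(0.123),            # tensor -> plain float
    "checkpoint": Path("/tmp/model.th"),    # Path -> str
    "cfg": OmegaConf.create({"lr": 3e-4}),  # OmegaConf -> plain dict
    "history": [1, 2.5, None],              # scalars pass through
}
print(json.dumps(jsonable(metrics)))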
def reliable_rmtree(path: pathlib.Path)

Reliably delete the given folder: first try to remove it while ignoring errors, and if any files remain, rename the folder to a trash folder next to it.

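Usage sketch with a hypothetical experiment folder; if some files resist deletion (for instance stale NFS handles), the folder ends up renamed to a sibling "*_trash" directory rather than being left half-deleted:

from pathlib import Path

from dora.utils import reliable_rmtree

xp_folder = Path("outputs/xps/1234abcd")  # hypothetical path
if xp_folder.exists():
    reliable_rmtree(xp_folder)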
def try_load(path: pathlib.Path, load=pickle.load, mode: str = 'rb')

Try to load from a path using the given load function (pickle.load by default), and handle various failure cases. Return None upon failure.

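A short sketch: attempt to restore a possibly missing or corrupted checkpoint (path and fallback state are hypothetical), starting fresh when None is returned:

from pathlib import Path

import torch

from dora.utils import try_load

state = try_load(Path("checkpoint.th"), load=torch.load)
if state is None:
    state = {"step": 0}  # no usable checkpoint, start from scratch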
def write_and_rename(path: pathlib.Path, mode: str = 'wb', suffix: str = '.tmp')

Write to a temporary file with the given suffix, then rename it to the final filename. As renaming a file is usually much faster than writing it, this removes (or, as far as I understand NFS, greatly limits) the likelihood of leaving a half-written checkpoint behind if the process is killed at the wrong time.

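Usage sketch: save a checkpoint without risking a truncated file. With the hypothetical names below, the payload is written to "checkpoint.th.tmp" first and only renamed to "checkpoint.th" once fully written:

from pathlib import Path

import torch

from dora.utils import write_and_rename

state = {"step": 1000}  # hypothetical training state
with write_and_rename(Path("checkpoint.th")) as f:
    torch.save(state, f)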