Module dora.names

Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from pathlib import Path
import typing as tp

from .xp import XP


class NamesMixin:
    """Mixin that handles everything related to the naming of experiments.

    Subclasses must implement `get_name_parts`; the other methods build
    human readable names from the parts it returns.
    """

    def short_name_part(self, key: str, value: tp.Any) -> str:
        """Return the short display form of a single name part.

        Every dotted component of `key` except the last one is truncated to
        its first 3 characters, e.g. `model.optim.lr` becomes `mod.opt.lr`.

        Args:
            key: dotted parameter name.
            value: parameter value. A `Path` is replaced by its final
                component (`Path.name`).

        Returns:
            The shortened key alone when `value is True`, otherwise
            `"{short_key}={value}"`.
        """
        key_parts = key.split(".")
        short_key_parts = [part[:3] for part in key_parts[:-1]]
        short_key_parts.append(key_parts[-1])
        key = ".".join(short_key_parts)

        if isinstance(value, Path):
            value = value.name
        if value is True:
            # Identity check on purpose: only the `True` singleton is rendered
            # as a bare key, a value such as `1` still gives `key=1`.
            return key
        return f"{key}={value}"

    def get_name_parts(self, xp: XP) -> OrderedDict:
        """Returns name parts, i.e. an OrderedDict from param name -> param value.
        Name parts that don't impact the signature should be ignored.

        Must be overridden by subclasses; this default always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_name(self, xp: XP) -> str:
        """Returns the XP name.

        With a single XP, `get_names` removes nothing from the reference, so
        the base name (last item of the returned tuple) holds every name part.
        """
        return self.get_names([xp])[-1]

    def _get_short_name(self, parts: OrderedDict,
                        reference: tp.Optional[dict] = None) -> str:
        """Join the short form of every part whose key is not in `reference`.

        Args:
            parts: name parts to render, in display order.
            reference: keys to leave out. `None` (the default) leaves
                nothing out.

        Returns:
            The short name parts separated by a single space.
        """
        # `None` instead of a `{}` default: avoids sharing one mutable dict
        # across every call.
        if reference is None:
            reference = {}
        return " ".join(
            self.short_name_part(key, value)
            for key, value in parts.items()
            if key not in reference)

    def get_names(self, xps: tp.List[XP]) -> tp.Tuple[tp.List[str], str]:
        """Given list of XPs, return individual XP names + base name.
        The parts common to all XPs are factored into the base name.

        Args:
            xps: non-empty list of XPs.

        Returns:
            A tuple `(names, base_name)`: `names[i]` is built from the parts
            of `xps[i]` that are not shared by every XP, and `base_name` from
            the parts that are shared (same key and same value everywhere).
        """
        assert len(xps) > 0
        # Work on a copy: keys are popped from `reference` below, and
        # `get_name_parts` may hand back an object it keeps around (the same
        # object may even be returned again for `xps[0]` inside the loop).
        reference = OrderedDict(self.get_name_parts(xps[0]))
        all_xp_parts = []
        for xp in xps:
            parts = self.get_name_parts(xp)
            # A key whose value differs between XPs is not common.
            for key, val in parts.items():
                if key in reference and reference[key] != val:
                    reference.pop(key)

            # A key absent from this XP is not common either.
            for key in set(reference) - set(parts):
                reference.pop(key)
            all_xp_parts.append(parts)

        names = [self._get_short_name(parts, reference) for parts in all_xp_parts]
        base_name = self._get_short_name(reference)
        return names, base_name

Classes

class NamesMixin

Mixin that handles everything related to the naming of experiments.

Expand source code
class NamesMixin:
    """Mixin that handles everything related to the naming of experiments.

    Subclasses must implement `get_name_parts`; the other methods build
    human readable names from the parts it returns.
    """

    def short_name_part(self, key: str, value: tp.Any) -> str:
        """Return the short display form of a single name part.

        Every dotted component of `key` except the last one is truncated to
        its first 3 characters, e.g. `model.optim.lr` becomes `mod.opt.lr`.

        Args:
            key: dotted parameter name.
            value: parameter value. A `Path` is replaced by its final
                component (`Path.name`).

        Returns:
            The shortened key alone when `value is True`, otherwise
            `"{short_key}={value}"`.
        """
        key_parts = key.split(".")
        short_key_parts = [part[:3] for part in key_parts[:-1]]
        short_key_parts.append(key_parts[-1])
        key = ".".join(short_key_parts)

        if isinstance(value, Path):
            value = value.name
        if value is True:
            # Identity check on purpose: only the `True` singleton is rendered
            # as a bare key, a value such as `1` still gives `key=1`.
            return key
        return f"{key}={value}"

    def get_name_parts(self, xp: XP) -> OrderedDict:
        """Returns name parts, i.e. an OrderedDict from param name -> param value.
        Name parts that don't impact the signature should be ignored.

        Must be overridden by subclasses; this default always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_name(self, xp: XP) -> str:
        """Returns the XP name.

        With a single XP, `get_names` removes nothing from the reference, so
        the base name (last item of the returned tuple) holds every name part.
        """
        return self.get_names([xp])[-1]

    def _get_short_name(self, parts: OrderedDict,
                        reference: tp.Optional[dict] = None) -> str:
        """Join the short form of every part whose key is not in `reference`.

        Args:
            parts: name parts to render, in display order.
            reference: keys to leave out. `None` (the default) leaves
                nothing out.

        Returns:
            The short name parts separated by a single space.
        """
        # `None` instead of a `{}` default: avoids sharing one mutable dict
        # across every call.
        if reference is None:
            reference = {}
        return " ".join(
            self.short_name_part(key, value)
            for key, value in parts.items()
            if key not in reference)

    def get_names(self, xps: tp.List[XP]) -> tp.Tuple[tp.List[str], str]:
        """Given list of XPs, return individual XP names + base name.
        The parts common to all XPs are factored into the base name.

        Args:
            xps: non-empty list of XPs.

        Returns:
            A tuple `(names, base_name)`: `names[i]` is built from the parts
            of `xps[i]` that are not shared by every XP, and `base_name` from
            the parts that are shared (same key and same value everywhere).
        """
        assert len(xps) > 0
        # Work on a copy: keys are popped from `reference` below, and
        # `get_name_parts` may hand back an object it keeps around (the same
        # object may even be returned again for `xps[0]` inside the loop).
        reference = OrderedDict(self.get_name_parts(xps[0]))
        all_xp_parts = []
        for xp in xps:
            parts = self.get_name_parts(xp)
            # A key whose value differs between XPs is not common.
            for key, val in parts.items():
                if key in reference and reference[key] != val:
                    reference.pop(key)

            # A key absent from this XP is not common either.
            for key in set(reference) - set(parts):
                reference.pop(key)
            all_xp_parts.append(parts)

        names = [self._get_short_name(parts, reference) for parts in all_xp_parts]
        base_name = self._get_short_name(reference)
        return names, base_name

Subclasses

Methods

def get_name(self, xp: XP) ‑> str

Returns the XP name.

Expand source code
def get_name(self, xp: XP) -> str:
    """Return the name of a single XP.

    Delegates to `get_names` with a one-element list and keeps the last
    item of its result.
    """
    result = self.get_names([xp])
    return result[-1]
def get_name_parts(self, xp: XP) ‑> collections.OrderedDict

Returns name parts, i.e. an OrderedDict from param name -> param value. Name parts that don't impact the signature should be ignored.

Expand source code
def get_name_parts(self, xp: XP) -> OrderedDict:
    """Return the name parts of `xp` as an OrderedDict mapping each
    param name to its param value. Parts that have no impact on the
    signature should be left out.

    This base version is only a placeholder and always raises
    `NotImplementedError`.
    """
    raise NotImplementedError
def get_names(self, xps: List[XP]) ‑> Tuple[List[str], str]

Given a list of XPs, return the individual XP names + the base name. The parts common to all XPs are factored into the base name.

Expand source code
def get_names(self, xps: tp.List[XP]) -> tp.Tuple[tp.List[str], str]:
    """Given list of XPs, return individual XP names + base name.
    The parts common to all XPs are factored into the base name.

    Args:
        xps: non-empty list of XPs.

    Returns:
        A tuple `(names, base_name)`: `names[i]` is built from the parts
        of `xps[i]` that are not shared by every XP, and `base_name` from
        the parts that are shared (same key and same value everywhere).
    """
    assert len(xps) > 0
    # Work on a copy: keys are popped from `reference` below, and
    # `get_name_parts` may hand back an object it keeps around (the same
    # object may even be returned again for `xps[0]` inside the loop).
    reference = OrderedDict(self.get_name_parts(xps[0]))
    all_xp_parts = []
    for xp in xps:
        parts = self.get_name_parts(xp)
        # A key whose value differs between XPs is not common.
        for key, val in parts.items():
            if key in reference and reference[key] != val:
                reference.pop(key)

        # A key absent from this XP is not common either.
        for key in set(reference) - set(parts):
            reference.pop(key)
        all_xp_parts.append(parts)

    names = [self._get_short_name(parts, reference) for parts in all_xp_parts]
    base_name = self._get_short_name(reference)
    return names, base_name
def short_name_part(self, key: str, value: Any) ‑> str

Return the shortened `key=value` form of a single name part of an XP.

Expand source code
def short_name_part(self, key: str, value: tp.Any) -> str:
    """Build the short display form of one name part.

    All dotted components of `key` but the last are cut down to their
    first 3 characters. A `Path` value is shown by its final component,
    and a value that is exactly `True` yields the short key on its own.
    """
    *parents, leaf = key.split(".")
    short_key = ".".join([parent[:3] for parent in parents] + [leaf])

    shown = value.name if isinstance(value, Path) else value
    return short_key if shown is True else f"{short_key}={shown}"