Module dora.names
Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from pathlib import Path
import typing as tp
from .xp import XP
class NamesMixin:
    """Mixin that handles everything related to the naming of experiments.

    Subclasses must implement :meth:`get_name_parts`; the other methods turn
    those parts into short, human-readable names.
    """

    def short_name_part(self, key: str, value: tp.Any) -> str:
        """Format a single name part as a short ``key=value`` string.

        Every dotted component of ``key`` except the last is truncated to its
        first 3 characters (e.g. ``"model.encoder.dim"`` -> ``"mod.enc.dim"``).
        A :class:`pathlib.Path` value is replaced by its final component, and a
        value that is exactly ``True`` is rendered as the bare (shortened) key.

        Args:
            key: dotted parameter name.
            value: parameter value.

        Returns:
            The shortened part, e.g. ``"mod.enc.dim=3"`` or ``"mod.flag"``.
        """
        *prefixes, last = key.split(".")
        key = ".".join([part[:3] for part in prefixes] + [last])
        if isinstance(value, Path):
            value = value.name
        # Identity check on purpose: only the boolean True is elided, not 1.
        if value is True:
            return key
        return f"{key}={value}"

    def get_name_parts(self, xp: XP) -> OrderedDict:
        """Return the name parts of ``xp``: an OrderedDict param name -> value.

        Name parts that don't impact the signature should be ignored.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()

    def get_name(self, xp: XP) -> str:
        """Return the short name of a single XP.

        With a single XP every part is common, so :meth:`get_names` factors
        the whole name into its base name.
        """
        _, base_name = self.get_names([xp])
        return base_name

    def _get_short_name(self, parts: OrderedDict,
                        reference: tp.Optional[dict] = None) -> str:
        """Join the short forms of the entries of ``parts`` not in ``reference``.

        Args:
            parts: mapping param name -> value, in display order.
            reference: keys to leave out (the parts shared by several XPs).
                ``None`` means nothing is left out.

        Returns:
            The space-separated short parts.
        """
        # `None` sentinel instead of a mutable `{}` default shared across calls.
        if reference is None:
            reference = {}
        return " ".join(
            self.short_name_part(key, value)
            for key, value in parts.items()
            if key not in reference)

    def get_names(self, xps: tp.List[XP]) -> tp.Tuple[tp.List[str], str]:
        """Given a list of XPs, return the individual XP names and the base name.

        The parts common to all XPs (same key with an equal value) are
        factored into the base name and left out of the individual names.

        Args:
            xps: non-empty list of XPs.

        Returns:
            ``(names, base_name)`` where ``names[i]`` is the short name of
            ``xps[i]`` without the common parts.
        """
        assert len(xps) > 0, "get_names requires at least one XP"
        # Work on a copy so that pruning below never mutates the dict returned
        # by `get_name_parts`. If an implementation hands back a shared/cached
        # dict, mutating it would corrupt the first XP's own parts.
        reference = OrderedDict(self.get_name_parts(xps[0]))
        all_xp_parts = []
        for xp in xps:
            parts = self.get_name_parts(xp)
            # Keep in `reference` only keys present in `parts` with equal value.
            for key in list(reference):
                if key not in parts or reference[key] != parts[key]:
                    del reference[key]
            all_xp_parts.append(parts)
        names = [self._get_short_name(parts, reference) for parts in all_xp_parts]
        base_name = self._get_short_name(reference)
        return names, base_name
Classes
class NamesMixin
-
Mixin that handles everything related to the naming of experiments.
Expand source code
class NamesMixin:
    """Mixin that handles everything related to the naming of experiments.

    Subclasses must implement :meth:`get_name_parts`; the other methods turn
    those parts into short, human-readable names.
    """

    def short_name_part(self, key: str, value: tp.Any) -> str:
        """Format a single name part as a short ``key=value`` string.

        Every dotted component of ``key`` except the last is truncated to its
        first 3 characters (e.g. ``"model.encoder.dim"`` -> ``"mod.enc.dim"``).
        A :class:`pathlib.Path` value is replaced by its final component, and a
        value that is exactly ``True`` is rendered as the bare (shortened) key.

        Args:
            key: dotted parameter name.
            value: parameter value.

        Returns:
            The shortened part, e.g. ``"mod.enc.dim=3"`` or ``"mod.flag"``.
        """
        *prefixes, last = key.split(".")
        key = ".".join([part[:3] for part in prefixes] + [last])
        if isinstance(value, Path):
            value = value.name
        # Identity check on purpose: only the boolean True is elided, not 1.
        if value is True:
            return key
        return f"{key}={value}"

    def get_name_parts(self, xp: XP) -> OrderedDict:
        """Return the name parts of ``xp``: an OrderedDict param name -> value.

        Name parts that don't impact the signature should be ignored.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()

    def get_name(self, xp: XP) -> str:
        """Return the short name of a single XP.

        With a single XP every part is common, so :meth:`get_names` factors
        the whole name into its base name.
        """
        _, base_name = self.get_names([xp])
        return base_name

    def _get_short_name(self, parts: OrderedDict,
                        reference: tp.Optional[dict] = None) -> str:
        """Join the short forms of the entries of ``parts`` not in ``reference``.

        Args:
            parts: mapping param name -> value, in display order.
            reference: keys to leave out (the parts shared by several XPs).
                ``None`` means nothing is left out.

        Returns:
            The space-separated short parts.
        """
        # `None` sentinel instead of a mutable `{}` default shared across calls.
        if reference is None:
            reference = {}
        return " ".join(
            self.short_name_part(key, value)
            for key, value in parts.items()
            if key not in reference)

    def get_names(self, xps: tp.List[XP]) -> tp.Tuple[tp.List[str], str]:
        """Given a list of XPs, return the individual XP names and the base name.

        The parts common to all XPs (same key with an equal value) are
        factored into the base name and left out of the individual names.

        Args:
            xps: non-empty list of XPs.

        Returns:
            ``(names, base_name)`` where ``names[i]`` is the short name of
            ``xps[i]`` without the common parts.
        """
        assert len(xps) > 0, "get_names requires at least one XP"
        # Work on a copy so that pruning below never mutates the dict returned
        # by `get_name_parts`. If an implementation hands back a shared/cached
        # dict, mutating it would corrupt the first XP's own parts.
        reference = OrderedDict(self.get_name_parts(xps[0]))
        all_xp_parts = []
        for xp in xps:
            parts = self.get_name_parts(xp)
            # Keep in `reference` only keys present in `parts` with equal value.
            for key in list(reference):
                if key not in parts or reference[key] != parts[key]:
                    del reference[key]
            all_xp_parts.append(parts)
        names = [self._get_short_name(parts, reference) for parts in all_xp_parts]
        base_name = self._get_short_name(reference)
        return names, base_name
Subclasses
Methods
def get_name(self, xp: XP) -> str
-
Returns the XP name.
Expand source code
def get_name(self, xp: XP) -> str:
    """Compute the name of one XP on its own.

    ``get_names`` yields ``(names, base_name)``. With a lone XP every part
    is shared, so the last element (the base name) holds the full name.
    """
    naming = self.get_names([xp])
    return naming[-1]
def get_name_parts(self, xp: XP) -> collections.OrderedDict
-
Returns name parts, i.e. an OrderedDict from param name -> param value. Name parts that don't impact the signature should be ignored.
Expand source code
def get_name_parts(self, xp: XP) -> OrderedDict:
    """Hook for subclasses: map each naming parameter of ``xp`` to its value.

    Implementations return an OrderedDict (param name -> param value) and
    should leave out parameters that don't affect the XP signature. This
    default implementation always raises NotImplementedError.
    """
    raise NotImplementedError
def get_names(self, xps: List[XP]) -> Tuple[List[str], str]
-
Given a list of XPs, return the individual XP names and the base name. The parts common to all XPs are factored into the base name and left out of the individual names.
Expand source code
def get_names(self, xps: tp.List[XP]) -> tp.Tuple[tp.List[str], str]:
    """Given a list of XPs, return the individual XP names and the base name.

    The parts common to all XPs (same key with an equal value) are factored
    into the base name and left out of the individual names.

    Args:
        xps: non-empty list of XPs.

    Returns:
        ``(names, base_name)`` where ``names[i]`` is the short name of
        ``xps[i]`` without the common parts.
    """
    assert len(xps) > 0, "get_names requires at least one XP"
    # Work on a copy so that pruning below never mutates the dict returned
    # by `get_name_parts`. If an implementation hands back a shared/cached
    # dict, mutating it would corrupt the first XP's own parts.
    reference = OrderedDict(self.get_name_parts(xps[0]))
    all_xp_parts = []
    for xp in xps:
        parts = self.get_name_parts(xp)
        # Keep in `reference` only keys present in `parts` with equal value.
        for key in list(reference):
            if key not in parts or reference[key] != parts[key]:
                del reference[key]
        all_xp_parts.append(parts)
    names = [self._get_short_name(parts, reference) for parts in all_xp_parts]
    base_name = self._get_short_name(reference)
    return names, base_name
def short_name_part(self, key: str, value: Any) -> str
-
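Return the short form of a single name part. Every dotted component of the key except the last is truncated to 3 characters, a Path value is replaced by its final component, and a value of exactly True yields just the key; otherwise the result is key=value.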
Expand source code
def short_name_part(self, key: str, value: tp.Any) -> str:
    """Render one (key, value) name part in its abbreviated form.

    All dotted segments of ``key`` except the last are cut to 3 characters.
    ``Path`` values are reduced to their final component. A value that is
    exactly ``True`` yields only the abbreviated key.
    """
    *heads, tail = key.split(".")
    short_key = ".".join([head[:3] for head in heads] + [tail])
    shown = value.name if isinstance(value, Path) else value
    if shown is True:
        return short_key
    return f"{short_key}={shown}"