Source code for fairseq2.utils.validation

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This module provides a set of helper types for object state validation. These
types are primarly used by recipe and asset configuration dataclasses to ensure
that all options are set correctly.

A class -*typically a configuration dataclass*- that wants to support validation
should expose a ``validate(self) -> ValidationResult`` method. Optionally,
the class can derive from the runtime-checkable :class:`Validatable` protocol
to make its intent more clear.

A typical implementation of a ``validate()`` method looks like the following:

.. code-block:: python

    from dataclasses import dataclass

    from fairseq2.utils.validation import Validatable, ValidationResult

    @dataclass
    class FooConfig(Validatable):
        field1: str
        field2: int

        def validate(self) -> ValidationResult:
            result = ValidationResult()

            if not self.field1:
                result.add_error("`field1` must be a non-empty string.")

            if self.field2 < 1:
                result.add_error("`field2` must be a positive integer.")

            return result


Note that ``FooConfig`` must NOT call ``validate()`` on its sub-fields that are
validatable. :class:`ObjectValidator` is responsible for traversing the object
graph and call each ``validate()`` method it finds in dataclasses, as well as in
composite objects of types ``list``, ``Mapping``, ``Set``, and ``tuple``.

Whenever ``FooConfig`` is used in a recipe configuration, fairseq2 will ensure
that it is validated before setting :attr:`RecipeContext.config`. To manually
validate an object outside of recipes, :class:`StandardObjectValidator` can
be used:

.. code-block:: python

    from dataclasses import dataclass

    from fairseq2.utils.validation import (
        ObjectValidator,
        StandardObjectValidator,
        Validatable,
        ValidationError,
        ValidationResult,
    )

    @dataclass
    class FooConfig(Validatable):
        field1: str
        field2: int

        def validate(self) -> ValidationResult:
            result = ValidationResult()

            if not self.field1:
                result.add_error("`field1` must be a non-empty string.")

            if self.field2 < 1:
                result.add_error("`field2` must be a positive integer.")

            return result

    config = FooConfig(field1="foo", field2=0)

    validator: ObjectValidator = StandardObjectValidator()

    try:
        validator.validate(config)
    except ValidationError as ex:
        # Prints an error message indicating that `field2` must be a
        # positive integer.
        print(ex.result)
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence, Set
from dataclasses import fields
from typing import Protocol, final, runtime_checkable

from typing_extensions import override

from fairseq2.typing import is_dataclass_instance


[docs] class ObjectValidator(ABC): """ Validates an object along with its sub-objects if it is a composite object (i.e. a ``dataclass``, ``list``, ``Mapping``, ``Set``, or ``tuple``) and raises a :class:`ValidationError` if any of them returns an error. """
[docs] @abstractmethod def validate(self, obj: object) -> None: """ Validates ``obj``. :param obj: The object to validate. :raises ValidationError: ``obj`` or one of its sub-objects has a validation error. """
[docs] @final class StandardObjectValidator(ObjectValidator): """Represents the standard implementation of :class:`ObjectValidator`."""
[docs] @override def validate(self, obj: object) -> None: result = self._do_validate(obj) if result.has_error(): raise ValidationError(result)
def _do_validate(self, obj: object) -> ValidationResult: if isinstance(obj, Validatable): result = obj.validate() else: result = ValidationResult() if is_dataclass_instance(obj): for field in fields(obj): value = getattr(obj, field.name) sub_result = self._do_validate(value) if sub_result.has_error(): result.add_sub_result(field.name, sub_result) elif isinstance(obj, Mapping): for k, v in obj.items(): sub_result = self._do_validate(v) if sub_result.has_error(): result.add_sub_result(f"[{repr(k)}]", sub_result) elif isinstance(obj, (list, tuple, Set)): for i, v in enumerate(obj): sub_result = self._do_validate(v) if sub_result.has_error(): result.add_sub_result(f"[{i}]", sub_result) return result
[docs] @runtime_checkable class Validatable(Protocol): """Represents the protocol for validatable objects."""
[docs] def validate(self) -> ValidationResult: """ Validates the state of the object. :returns: The result of the validation. """
[docs] @final class ValidationResult: """Holds the result of a :meth:`~Validatable.validate` call.""" def __init__(self) -> None: self._errors: list[str] = [] self._sub_results: dict[str, ValidationResult] = {}
[docs] def add_error(self, message: str) -> None: """Adds an error message to the result.""" self._errors.append(message)
[docs] def add_sub_result(self, field: str, result: ValidationResult) -> None: """ Adds the validation result of a sub-object as a sub-result. :param field: The name of the sub-object. For a dataclass, it is the name of the field, for a ``Mapping`` it is the name of the key formatted as ``f"[{repr(key)}]"``, for ``list``, ``Set``, and ``tuple`` it is the index of the value formatted as ``f"[{index}]"``. :param result: The validation result of the sub-object. """ self._sub_results[field] = result
[docs] def has_error(self) -> bool: """ Returns ``True`` if the object or any of its sub-objects have a validation error. """ if self._errors: return True return any(r.has_error() for r in self._sub_results.values())
@property def errors(self) -> Sequence[str]: """ Returns the validation errors of the object, excluding errors of its sub-objects. """ return self._errors @property def sub_results(self) -> Mapping[str, ValidationResult]: """Returns the validation results of the sub-objects.""" return self._sub_results def __str__(self) -> str: output: list[str] = [] self._create_error_string(output, field_path=[]) return " ".join(output) def _create_error_string(self, output: list[str], field_path: list[str]) -> None: s = " ".join(self._errors) if s: if field_path: pathname = self._build_pathname(field_path) output.append(f"`{pathname}` is not valid: {s}") else: output.append(s) for field, result in self._sub_results.items(): field_path.append(field) result._create_error_string(output, field_path) field_path.pop() def _build_pathname(self, field_path: list[str]) -> str: segments = [field_path[0]] for p in field_path[1:]: if not p.startswith("[") or not p.endswith("]"): segments.append(".") segments.append(p) return "".join(segments)
[docs] class ValidationError(Exception): """Raised when a validation error occurs.""" result: ValidationResult """The result containing validation errors.""" def __init__( self, result: ValidationResult | str, *, field: str | None = None ) -> None: """ :param result: The validation result. If ``str``, will be converted to a result with ``ValidationResult(result)``. :param field: If not ``None``, ``result`` will be treated as the sub-result of ``field``. """ if isinstance(result, str): tmp = ValidationResult() tmp.add_error(result) result = tmp if field is not None: tmp = ValidationResult() tmp.add_sub_result(field, result) result = tmp self.result = result def __str__(self) -> str: return str(self.result)