Source code for are.simulation.validation.factory
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Type
from are.simulation.validation.base import BaseJudge
from are.simulation.validation.configs import (
BaseJudgeConfig,
GraphPerEventJudgeConfig,
InContextJudgeConfig,
ScriptedGraphPerEventJudgeConfig,
)
from are.simulation.validation.judge import GraphPerEventJudge, InContextJudge
[docs]
class JudgeFactory:
    """Callable factory that builds the judge matching a judge configuration.

    The factory keeps a registry (``judge_classes``) mapping configuration
    classes to judge classes. Calling the factory with a configuration
    instance instantiates the registered judge class with that configuration.
    """

    def __init__(self) -> None:
        # Registry of supported configuration classes. Both graph-per-event
        # configurations are served by the same judge implementation.
        self.judge_classes: dict[Type[BaseJudgeConfig], Type[BaseJudge]] = {
            ScriptedGraphPerEventJudgeConfig: GraphPerEventJudge,
            GraphPerEventJudgeConfig: GraphPerEventJudge,
            InContextJudgeConfig: InContextJudge,
        }

    def __call__(self, config: BaseJudgeConfig) -> BaseJudge:
        """Instantiate the judge registered for ``config``.

        The configuration's exact class is looked up first; if it is not
        registered, its base classes are tried in method-resolution order so
        that a subclass of a registered configuration resolves to the judge
        of its nearest registered ancestor.

        Args:
            config: The judge configuration to build a judge for.

        Returns:
            The registered judge class called with ``config``.

        Raises:
            ValueError: If neither the configuration's class nor any of its
                base classes is registered in ``judge_classes``.
        """
        config_type = type(config)
        # __mro__ starts with config_type itself, so an exact registration
        # always wins over one inherited from a parent class.
        for candidate in config_type.__mro__:
            judge_class = self.judge_classes.get(candidate)
            if judge_class is not None:
                return judge_class(config)
        raise ValueError(f"No judge class found for config {config_type}")