Source code for are.simulation.validation.configs

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from dataclasses import dataclass, field
from typing import Any, Callable

from are.simulation.agents.are_simulation_agent_config import LLMEngineConfig
from are.simulation.validation.constants import (
    TOOL_ARG_CHECKER_TYPE_REGISTRY,
    TOOL_EVALUATION_CRITERIA_REGISTRY,
    TOOL_SOFT_CHECKER_TYPE_REGISTRY,
    CheckerType,
    SoftCheckerType,
    ToolArgCheckerTypeRegistry,
    ToolCriteriaRegistry,
    ToolSoftCheckerTypeRegistry,
)
from are.simulation.validation.prompts import (
    IN_CONTEXT_JUDGE_SYSTEM_PROMPT_TEMPLATE,
    TIME_SYSTEM_PROMPT_TEMPLATE,
)

# Default judge configuration
DEFAULT_JUDGE_MODEL = "meta-llama/Meta-Llama-3.3-70B-Instruct"
DEFAULT_JUDGE_PROVIDER = "huggingface"


def create_judge_engine(
    judge_engine_config: LLMEngineConfig | None = None,
):
    """Create a judge engine with the specified configuration."""
    if judge_engine_config is None:
        judge_engine_config = LLMEngineConfig(
            model_name=DEFAULT_JUDGE_MODEL,
            provider=None,
            endpoint=None,
        )

    # Use LiteLLM for all other cases (external or internal with overrides)
    from are.simulation.agents.llm.litellm.litellm_engine import (
        LiteLLMEngine,
        LiteLLMModelConfig,
    )

    final_provider = (
        judge_engine_config.provider or DEFAULT_JUDGE_PROVIDER or "huggingface"
    )

    judge_config = LiteLLMModelConfig(
        model_name=judge_engine_config.model_name,
        provider=final_provider,
        endpoint=judge_engine_config.endpoint,
    )

    return LiteLLMEngine(model_config=judge_config)


@dataclass
class ToolCheckerParam:
    # Parameter for the tool checker of the hard judge
    arg_name: str
    checker_type: CheckerType
    tool_name: str
    checker_args: dict[str, Any] = field(default_factory=dict)


@dataclass
class BaseToolJudgeConfig:
    tool_name: str
    arg_to_checker_type: dict[
        str, CheckerType
    ]  # the list of args to check and the type of checker to use for each arg
    tracer: Callable | None = None


@dataclass
class HardToolJudgeConfig(BaseToolJudgeConfig):
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None


@dataclass
class SoftToolJudgeConfig(BaseToolJudgeConfig):
    engine: Callable = field(default_factory=create_judge_engine)
    # Soft checker
    soft_checker_types: list[SoftCheckerType] = field(
        default_factory=lambda: [SoftCheckerType.content_checker]
    )

    def __post_init__(self):
        if len(self.soft_checker_types) == 0:
            self.soft_checker_types = [SoftCheckerType.content_checker]


@dataclass
class MildToolJudgeConfig(BaseToolJudgeConfig):
    engine: Callable = field(default_factory=create_judge_engine)
    soft_checker_types: list[SoftCheckerType] = field(default_factory=list)
    # Scripted checkers related config
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None


@dataclass
class BaseEventJudgeConfig:
    tracer: Callable | None = None


@dataclass
class EnvUserEventJudgeConfig(BaseEventJudgeConfig):
    pass


@dataclass
class AgentEventJudgeConfig(BaseEventJudgeConfig):
    # Time related config
    check_time_threshold_seconds: float = 1.0
    pre_event_tolerance_seconds: float = 10.0
    post_event_tolerance_seconds: float = 25.0
    # Tool related config
    per_tool_arg_to_checker_type: ToolArgCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_ARG_CHECKER_TYPE_REGISTRY
    )
    per_tool_soft_checker_types: ToolSoftCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_SOFT_CHECKER_TYPE_REGISTRY
    )
    engine: Callable = field(default_factory=create_judge_engine)
    # Scripted checkers related config
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None


[docs] @dataclass class BaseJudgeConfig: tracer: Callable | None = None
[docs] @dataclass class GraphPerEventJudgeConfig(BaseJudgeConfig): # Time related config check_time_threshold_seconds: float = 1.0 pre_event_tolerance_seconds: float = 10.0 post_event_tolerance_seconds: float = 25.0 # Tool related config per_tool_arg_to_checker_type: ToolArgCheckerTypeRegistry = field( default_factory=lambda: TOOL_ARG_CHECKER_TYPE_REGISTRY ) engine: Callable = field(default_factory=create_judge_engine) per_tool_soft_checker_types: ToolSoftCheckerTypeRegistry = field( default_factory=lambda: TOOL_SOFT_CHECKER_TYPE_REGISTRY ) # Scripted checkers related config # If this field is not `None`, the soft judge will not be used. event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None # Preliminary check extra_send_message_to_user_allowed: int = 1
[docs] @dataclass class ScriptedGraphPerEventJudgeConfig(GraphPerEventJudgeConfig): """ Config for the scripted graph per event judge. Scripted judge is a judge where the soft judge is deactivated and instead scripted checks will be performed by the hard judge. The `event_id_to_checker_params` field is used to specify the scripted checks to perform. """ # Change default such that soft judge is not used. event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = field( default_factory=dict ) def __post_init__(self): if self.event_id_to_checker_params is None: raise ValueError( "event_id_to_checker_params must be specified for ScriptedGraphPerEventJudgeConfig" )
[docs] @dataclass class InContextJudgeConfig(BaseJudgeConfig): # Time related config check_time_threshold_seconds: float = 1.0 pre_event_tolerance_seconds: float = 10.0 post_event_tolerance_seconds: float = 25.0 time_system_prompt_template: str = TIME_SYSTEM_PROMPT_TEMPLATE # Tool related config per_tool_evaluation_criteria: ToolCriteriaRegistry = field( default_factory=lambda: TOOL_EVALUATION_CRITERIA_REGISTRY ) tool_to_selected_args: ToolArgCheckerTypeRegistry = ( field( # Will not use the checker type but only arg names default_factory=lambda: TOOL_ARG_CHECKER_TYPE_REGISTRY ) ) engine: Callable = field(default_factory=create_judge_engine) system_prompt_template: str = IN_CONTEXT_JUDGE_SYSTEM_PROMPT_TEMPLATE