# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from are.simulation.apps.contacts import Contact
from are.simulation.scenarios.scenario import Scenario
from are.simulation.scenarios.validation_result import ScenarioValidationResult
from are.simulation.types import (
AbstractEnvironment,
CompletedEvent,
CompletedOracleEvent,
EventType,
)
from are.simulation.validation.configs import (
BaseEventJudgeConfig,
BaseJudgeConfig,
BaseToolJudgeConfig,
)
from are.simulation.validation.judgment import Judgment
from are.simulation.validation.utils.trace_utils import injected_traceable
logger: logging.Logger = logging.getLogger(__name__)
from are.simulation.agents.agent_log import ErrorLog
@dataclass
class BaseJudgeState:
    """
    Mutable bookkeeping that a judge carries across the turns of a scenario.

    The ``turn_to_*`` lists hold one entry per turn; the ``current_turn_*``
    properties index them with ``turn_idx``.
    """

    # Set once the state has been populated; BaseJudge refuses to validate
    # while this is False.
    initialized: bool = False

    # Turn bookkeeping. Both counters default to -1.
    nb_turns: int = -1
    turn_idx: int = -1
    # Outcome of the previous turn's judgment and the rationale that came with it.
    last_turn_success: bool = True
    last_turn_rationale: str = ""

    # Scenario data
    scenario_start_time: float = 0.0
    scenario_tasks: list[str] = field(default_factory=list)
    user_details: Contact | None = None

    # Oracle events, one list per turn
    turn_to_oracle_events: list[list[CompletedOracleEvent]] = field(
        default_factory=list
    )
    # One str -> list[str] mapping per turn (presumably event id -> related
    # event ids; confirm against the code that fills it).
    turn_to_oracle_graph: list[dict[str, list[str]]] = field(default_factory=list)
    oracle_event_id_to_turn_idx: dict[str, int] = field(default_factory=dict)

    # Agent events, one list per turn
    turn_to_agent_events: list[list[CompletedEvent]] = field(default_factory=list)

    @property
    def agent_events(self) -> list[CompletedEvent]:
        """Return every agent event as one flat list, in turn order."""
        flattened = []
        for turn_events in self.turn_to_agent_events:
            flattened.extend(turn_events)
        return flattened

    @property
    def current_turn_agent_events(self) -> list[CompletedEvent]:
        """Return the agent events stored at position ``turn_idx``."""
        # NOTE: with the default turn_idx of -1 this indexes from the end of
        # the list (and raises IndexError when the list is empty).
        per_turn = self.turn_to_agent_events
        return per_turn[self.turn_idx]

    @property
    def current_turn_oracle_events(self) -> list[CompletedOracleEvent]:
        """Return the oracle events stored at position ``turn_idx``."""
        per_turn = self.turn_to_oracle_events
        return per_turn[self.turn_idx]

    @property
    def current_turn_oracle_graph(self) -> dict[str, list[str]]:
        """Return the oracle graph stored at position ``turn_idx``."""
        per_turn = self.turn_to_oracle_graph
        return per_turn[self.turn_idx]
class BaseJudge(ABC):
    """
    Base class for a judge. A judge compares an agent and oracle event log for a given scenario.

    Subclasses implement ``initialize_state`` and ``__call__``; ``validate``,
    ``validate_current_turn`` and ``trigger_condition`` are built on top of
    ``__call__``.
    """

    def __init__(self, config: BaseJudgeConfig):
        """Store the tracer from ``config`` and start with a fresh state."""
        self.tracer = config.tracer
        self.state = BaseJudgeState()
        # Filled by validate_current_turn when judging raises.
        self.error_logs: list[ErrorLog] = []

    @abstractmethod
    def initialize_state(self, scenario: Scenario) -> None:
        """Prepare ``self.state`` for ``scenario``."""

    @abstractmethod
    def __call__(self, env: AbstractEnvironment) -> Judgment:
        """Judge the environment and return the resulting ``Judgment``."""

    def validate(self, env: AbstractEnvironment) -> ScenarioValidationResult:
        """
        Validate the scenario as a whole.

        Returns a failed result without judging when the previous turn was
        already rejected, or when the call does not fall on the last turn.
        Otherwise delegates to ``validate_current_turn``.

        Raises:
            ValueError: if the judge state has not been initialized.
        """
        if not self.state.initialized:
            raise ValueError("Judge must be initialized before validation")

        # A previously rejected turn fails the scenario; reuse its rationale.
        if not self.state.last_turn_success:
            logger.warning("Last turn was already rejected, skipping validation")
            return ScenarioValidationResult(
                success=False, rationale=self.state.last_turn_rationale
            )

        # NOTE(review): presumably turn_idx is the last turn already judged, so
        # turn_idx + 1 is the turn about to be judged -- confirm with the
        # concrete judge that advances turn_idx.
        is_last_turn = (self.state.turn_idx + 1) == (self.state.nb_turns - 1)
        if not is_last_turn:
            failure = f"Validation called at turn {self.state.turn_idx} but nb_turns is {self.state.nb_turns}"
            logger.info(failure)

            # Build the failed judgment inside a traced call, as the regular
            # judging path is expected to be traced too.
            @injected_traceable(
                trace_type="judge", tags=["judge"], log_input_args=False
            )
            def inner_call(self) -> Judgment:
                return Judgment(success=False, failure=failure)

            judgment = inner_call(self)
            return ScenarioValidationResult(
                success=False, rationale=str(judgment.failure)
            )

        # Judge the current turn (last turn)
        return self.validate_current_turn(env)

    def validate_current_turn(
        self, env: AbstractEnvironment
    ) -> ScenarioValidationResult:
        """
        Judge the current turn and wrap the outcome in a validation result.

        Any exception raised while judging is recorded in ``self.error_logs``,
        logged with its traceback, and reported as a failed result carrying
        the exception; it is not re-raised. On the normal path the rationale
        is ``str(judgment.failure)``.

        Raises:
            ValueError: if the judge state has not been initialized.
        """
        if not self.state.initialized:
            raise ValueError("Judge must be initialized before validation")

        try:
            judgment = self(env)
        except Exception as e:
            error_name = type(e).__name__
            self.error_logs.append(
                ErrorLog(
                    error=error_name,
                    exception=str(e),
                    category=error_name,
                    agent=self.__class__.__name__,
                    timestamp=env.time_manager.time(),
                    agent_id="unknown",
                )
            )
            # logger.exception keeps the traceback, which logger.error(e) dropped.
            logger.exception(e)
            return ScenarioValidationResult(
                success=False, exception=e, rationale="Exception"
            )

        return ScenarioValidationResult(
            success=bool(judgment.success),
            rationale=str(judgment.failure),
        )

    def trigger_condition(
        self, env: AbstractEnvironment, turn_idx: int
    ) -> tuple[bool, dict[str, str]]:
        """
        Judge ``env`` and return ``(success, agent_event_id_to_oracle_event_id)``.

        ``turn_idx`` is accepted for interface compatibility but is not used
        by this base implementation.
        """
        judgment = self(env)
        return bool(judgment.success), judgment.agent_event_id_to_oracle_event_id
class EventJudge(ABC):
    """
    Base class for an event judge. An event judge compares an agent event with an oracle event and decides whether the two match.
    """

    def __init__(
        self, config: BaseEventJudgeConfig, event_type: EventType, judge_type: str = ""
    ) -> None:
        """Keep the tracer from ``config`` plus the event and judge types."""
        self.judge_type = judge_type
        self.event_type = event_type
        self.tracer = config.tracer

    @abstractmethod
    def compare(
        self, agent_event: CompletedEvent, oracle_event: CompletedOracleEvent, **kwargs
    ) -> bool | None:
        """Decide whether ``agent_event`` matches ``oracle_event``."""

    def __call__(
        self, agent_event: CompletedEvent, oracle_event: CompletedOracleEvent, **kwargs
    ) -> bool | None:
        """
        Compare the two events inside a traced call.

        Returns ``False`` straight away when the tool names differ, and
        ``False`` when the event types differ; otherwise returns whatever
        ``compare`` returns.

        Raises:
            ValueError: if the oracle event's type is not this judge's
                ``event_type`` (only checked once the tool names match).
        """
        expected_tool = oracle_event.tool_name
        actual_tool = agent_event.tool_name
        # Different tools can never match. The tool judge level should catch
        # this too, but rejecting here keeps the trace log cleaner.
        if expected_tool != actual_tool:
            return False

        @injected_traceable(
            trace_type=f"{expected_tool}_vs_{actual_tool}",
            tags=["judge"],
            log_input_args=False,
        )
        def inner_call(
            self,
            agent_event: CompletedEvent,
            oracle_event: CompletedOracleEvent,
            **kwargs,
        ) -> bool | None:
            # The oracle event must be of the type this judge was built for.
            if oracle_event.event_type != self.event_type:
                raise ValueError(
                    f"Oracle type {oracle_event.event_type.value} does not match config type {self.event_type.value}"
                )
            types_differ = agent_event.event_type != oracle_event.event_type
            if types_differ:
                return False
            return self.compare(agent_event, oracle_event, **kwargs)

        return inner_call(self, agent_event, oracle_event, **kwargs)