# Source code for are.simulation.validation.event_judge

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import logging
from typing import Any

from are.simulation.types import CompletedEvent, EventTimeComparator, EventType
from are.simulation.validation.base import EventJudge
from are.simulation.validation.configs import (
    AgentEventJudgeConfig,
    EnvUserEventJudgeConfig,
    MildToolJudgeConfig,
)
from are.simulation.validation.tool_judge import MildToolJudge
from are.simulation.validation.utils.scenario_utils import CompletedOracleEvent
from are.simulation.validation.utils.trace_utils import injected_traceable

logger: logging.Logger = logging.getLogger(__name__)


class EnvUserEventJudge(EventJudge):
    """
    Judge for environment/user events.

    It compares one event taken from the agent log with one taken from the
    oracle agent log. Such a pair is considered a match exactly when the two
    event ids are equal; no other field of the events is inspected.
    """

    def __init__(self, event_type: EventType, config: EnvUserEventJudgeConfig) -> None:
        # The judge is registered under the "env-user" name for the given event type.
        super().__init__(config, event_type, "env-user")

    @injected_traceable(trace_type="eq_checker", tags=["judge"])
    def eq_checker(self, x_agent: Any, x_oracle: Any, **kwargs) -> bool:
        """
        Return the result of ``x_agent == x_oracle``.

        Extra keyword arguments (e.g. ``arg_name``) are not used by this body.
        NOTE(review): presumably they are picked up by the tracing decorator
        as trace metadata; confirm against ``injected_traceable``.
        """
        is_equal = x_agent == x_oracle
        return is_equal

    def compare(
        self, agent_event: CompletedEvent, oracle_event: CompletedOracleEvent, **kwargs
    ) -> bool | None:
        """
        Decide whether ``agent_event`` matches ``oracle_event``.

        Only the event ids are compared (via the traced ``eq_checker``).
        Additional keyword arguments are accepted for interface compatibility
        and ignored.
        """
        agent_id = agent_event.event_id
        oracle_id = oracle_event.event_id
        return self.eq_checker(agent_id, oracle_id, arg_name="event_id")
class AgentEventJudge(EventJudge):
    """
    A judge that compares a pair of agent events from the agent log and the
    oracle agent log.

    A pair matches when the agent event's timing passes ``check_time`` and the
    ``MildToolJudge`` registered for the oracle event's tool accepts the pair.
    """

    def __init__(self, config: AgentEventJudgeConfig) -> None:
        """
        Build one ``MildToolJudge`` per tool listed in
        ``config.per_tool_arg_to_checker_type``.

        Args:
            config: Judge configuration. It supplies the per-tool argument
                checkers, optional per-event checker params, per-tool soft
                checker types, the engine and the time tolerances.
        """
        super().__init__(config, EventType.AGENT, "agent")
        self.config = config
        # Tool name -> judge responsible for comparing calls to that tool.
        self.tool_judges = {}
        for (
            tool_name,
            arg_to_checker_type,
        ) in self.config.per_tool_arg_to_checker_type.items():
            self.tool_judges[tool_name] = MildToolJudge(
                MildToolJudgeConfig(
                    tool_name=tool_name,
                    arg_to_checker_type=arg_to_checker_type,
                    engine=self.config.engine,
                    event_id_to_checker_params=self._checker_params_for_tool(
                        tool_name
                    ),
                    soft_checker_types=self.config.per_tool_soft_checker_types.get(
                        tool_name, []
                    ),
                    tracer=self.tracer,
                )
            )

    def _checker_params_for_tool(self, tool_name: str) -> dict[Any, list[Any]] | None:
        """
        Restrict ``config.event_id_to_checker_params`` to the params of ``tool_name``.

        Returns None when the config provides no checker params at all.
        Otherwise it returns a mapping that keeps only those event ids having
        at least one checker param for ``tool_name``. Each kept id maps to just
        those params, in their original order.
        """
        all_params = self.config.event_id_to_checker_params
        if all_params is None:
            return None
        filtered = {}
        for event_id, checker_params in all_params.items():
            tool_params = [
                checker_param
                for checker_param in checker_params
                if checker_param.tool_name == tool_name
            ]
            # Event ids without any param for this tool are dropped entirely.
            if tool_params:
                filtered[event_id] = tool_params
        return filtered

    @injected_traceable(trace_type="event_time_checker", tags=["judge"])
    def event_time_checker(
        self,
        agent_event_time: float,
        oracle_event_time: float,
        pre_event_tolerance_seconds: float = 5.0,
        post_event_tolerance_seconds: float = 20.0,
        event_time_comparator: str | None = None,
    ) -> bool:
        """
        Checks if the agent event time is within the allowed tolerance range
        compared to the oracle event time.

        Args:
            agent_event_time (float): The time of the agent event (relative or absolute).
            oracle_event_time (float): The time of the oracle event (relative or absolute).
            pre_event_tolerance_seconds (float): The allowed time in seconds before the oracle event time.
            post_event_tolerance_seconds (float): The allowed time in seconds after the oracle event time.
            event_time_comparator (str | None): The type of comparison to perform between the agent
                and oracle event times. None or EQUAL requires the agent time to lie in
                [oracle - pre, oracle + post]. LESS_THAN only enforces the upper bound.
                GREATER_THAN only enforces the lower bound. The arg type is str instead of
                EventTimeComparator for better readability in the tracer.

        Returns:
            bool: True if the agent event time is within the allowed tolerance range, False otherwise.

        Raises:
            ValueError: If ``event_time_comparator`` is not a known comparator value.
        """
        lower_bound = oracle_event_time - pre_event_tolerance_seconds
        upper_bound = oracle_event_time + post_event_tolerance_seconds
        if (
            event_time_comparator is None
            or event_time_comparator == EventTimeComparator.EQUAL.value
        ):
            return lower_bound <= agent_event_time <= upper_bound
        if event_time_comparator == EventTimeComparator.LESS_THAN.value:
            return agent_event_time <= upper_bound
        if event_time_comparator == EventTimeComparator.GREATER_THAN.value:
            return agent_event_time >= lower_bound
        raise ValueError(f"Event time comparator {event_time_comparator} is not valid")

    def check_time(
        self,
        agent_event: CompletedEvent,
        oracle_event: CompletedOracleEvent,
        max_parent_oracle_event_time: float,
        max_parent_agent_event_time: float,
    ) -> bool:
        """
        Check whether the agent event's timing is acceptable for the oracle event.

        - If the oracle event has an ``absolute_event_time``, the agent's
          ``event_time`` is compared directly against it.
        - Otherwise both times are taken relative to their latest parent event
          (``max_parent_*_event_time``). The check is performed only when the
          oracle's relative time exceeds ``config.check_time_threshold_seconds``
          or the oracle event specifies a time comparator. In every other case
          the timing is accepted.

        Returns:
            bool: True if the timing is acceptable, False otherwise.
        """
        assert agent_event.event_time is not None, "Agent event time cannot be None"
        # Use an explicit None check so it agrees with the condition further below.
        comparator = (
            oracle_event.event_time_comparator.value
            if oracle_event.event_time_comparator is not None
            else None
        )
        if oracle_event.absolute_event_time is not None:
            return self.event_time_checker(
                agent_event_time=agent_event.event_time,
                oracle_event_time=oracle_event.absolute_event_time,
                pre_event_tolerance_seconds=self.config.pre_event_tolerance_seconds,
                post_event_tolerance_seconds=self.config.post_event_tolerance_seconds,
                event_time_comparator=comparator,
            )

        agent_event_relative_time = agent_event.event_time - max_parent_agent_event_time
        oracle_event_time = oracle_event.event_time
        assert oracle_event_time is not None, "Oracle event time cannot be None"
        oracle_event_relative_time = oracle_event_time - max_parent_oracle_event_time

        # Oracle events occurring within the threshold of their parents are not
        # time-checked unless an explicit comparator asks for it.
        if (
            oracle_event_relative_time > self.config.check_time_threshold_seconds
            or oracle_event.event_time_comparator is not None
        ):
            return self.event_time_checker(
                agent_event_time=agent_event_relative_time,
                oracle_event_time=oracle_event_relative_time,
                pre_event_tolerance_seconds=self.config.pre_event_tolerance_seconds,
                post_event_tolerance_seconds=self.config.post_event_tolerance_seconds,
                event_time_comparator=comparator,
            )
        return True

    def compare(
        self, agent_event: CompletedEvent, oracle_event: CompletedOracleEvent, **kwargs
    ) -> bool | None:
        """
        Decide whether ``agent_event`` matches ``oracle_event``.

        The time check runs first and returns False on failure. The pair is
        then delegated to the tool judge registered for the oracle event's tool
        name.

        Keyword Args:
            max_parent_oracle_event_time (float): Latest parent time on the oracle
                side (default 0.0).
            max_parent_agent_event_time (float): Latest parent time on the agent
                side (default 0.0).

        All keyword arguments are also forwarded to the tool judge.
        """
        oracle_tool_name = oracle_event.tool_name
        agent_tool_name = agent_event.tool_name
        logger.info("Comparing %s to %s", oracle_tool_name, agent_tool_name)

        # First check time
        if not self.check_time(
            agent_event=agent_event,
            oracle_event=oracle_event,
            max_parent_oracle_event_time=kwargs.get(
                "max_parent_oracle_event_time", 0.0
            ),
            max_parent_agent_event_time=kwargs.get("max_parent_agent_event_time", 0.0),
        ):
            return False

        # Check tool call (action)
        assert oracle_tool_name in self.tool_judges, (
            f"Tool {oracle_tool_name} not found in tool judges"
        )
        return self.tool_judges[oracle_tool_name](agent_event, oracle_event, **kwargs)