# Source code for are.simulation.validation.tool_judge
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import logging
import os
import re
from datetime import datetime, timezone
from typing import Any
from are.simulation.types import CompletedEvent
from are.simulation.validation.base import ToolJudge
from are.simulation.validation.configs import (
CheckerType,
HardToolJudgeConfig,
MildToolJudgeConfig,
SoftCheckerType,
SoftToolJudgeConfig,
)
from are.simulation.validation.utils.llm_utils import (
build_llm_checkers,
build_subtask_extractor,
)
from are.simulation.validation.utils.misc import (
extract_text_between_tags,
normalize_arg,
normalize_str,
)
from are.simulation.validation.utils.scenario_utils import CompletedOracleEvent
from are.simulation.validation.utils.trace_utils import injected_traceable
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class HardToolJudge(ToolJudge):
    """
    A judge that performs scripted checks on some action args to compare an
    agent event and an oracle event representing a tool call.

    Every checker is a method returning ``True`` when the agent value is
    accepted. Checkers are registered in ``self.checkers`` under the value of
    the matching ``CheckerType`` member and are selected per argument through
    ``config.arg_to_checker_type`` and, optionally, per oracle event through
    ``config.event_id_to_checker_params``.
    """

    def __init__(self, config: HardToolJudgeConfig):
        """Store the config and build the checker registry."""
        super().__init__(config, "hard")
        self.config = config
        # Registry: checker type value -> bound checker method.
        self.checkers = {
            CheckerType.eq_checker.value: self.eq_checker,
            CheckerType.unordered_list_checker.value: self.unordered_list_checker,
            CheckerType.datetime_checker.value: self.datetime_checker,
            CheckerType.list_attendees_checker.value: self.list_attendees_checker,
            CheckerType.phone_number_checker.value: self.phone_number_checker,
            CheckerType.eq_str_strip_checker.value: self.eq_str_strip_checker,
            CheckerType.contain_any_checker.value: self.contain_any_checker,
            CheckerType.contain_all_checker.value: self.contain_all_checker,
            CheckerType.path_checker.value: self.path_checker,
            CheckerType.unordered_path_list_checker.value: self.unordered_path_list_checker,
        }
        # Collect checker args names; `compare` uses them to forward only the
        # extra keyword arguments whose name appears in the checker signature.
        self.checker_to_args_names = {
            checker_name: list(inspect.signature(checker_fn).parameters.keys())
            for checker_name, checker_fn in self.checkers.items()
        }

    @staticmethod
    def _normalize_path(path: str) -> str:
        """Normalize ``path`` with ``os.path.normpath`` and drop leading slashes."""
        return os.path.normpath(path).lstrip("/")

    @injected_traceable(trace_type="eq_checker", tags=["judge"])
    def eq_checker(self, x_agent: Any, x_oracle: Any, **kwargs) -> bool:
        """Return True when both values compare equal with ``==``."""
        return x_agent == x_oracle

    @injected_traceable(trace_type="unordered_list_checker", tags=["judge"])
    def unordered_list_checker(
        self, x_agent: list[Any] | None, x_oracle: list[Any] | None, **kwargs
    ) -> bool:
        """
        Compare two lists as sets (order and duplicates are ignored).

        ``None`` on one side matches ``None`` or an empty list on the other.
        """
        if x_agent is None:
            return x_oracle is None or len(x_oracle) == 0
        if x_oracle is None:
            # x_agent is known to be a list here, so only emptiness matters.
            return len(x_agent) == 0
        return set(x_agent) == set(x_oracle)

    @injected_traceable(trace_type="path_checker", tags=["judge"])
    def path_checker(self, x_agent: str | None, x_oracle: str | None, **kwargs) -> bool:
        """
        Compare two paths after normalization and removal of leading slashes.

        If either value is ``None`` the raw values are compared with ``==``.
        """
        if x_agent is None or x_oracle is None:
            return x_agent == x_oracle
        return self._normalize_path(x_agent) == self._normalize_path(x_oracle)

    @injected_traceable(trace_type="unordered_path_list_checker", tags=["judge"])
    def unordered_path_list_checker(
        self, x_agent: list[str] | None, x_oracle: list[str] | None, **kwargs
    ) -> bool:
        """
        Compare two lists of paths as sets of normalized paths.

        ``None`` on one side matches ``None`` or an empty list on the other.
        """
        if x_agent is None:
            return x_oracle is None or len(x_oracle) == 0
        if x_oracle is None:
            return len(x_agent) == 0
        agent_paths = {self._normalize_path(path) for path in x_agent}
        oracle_paths = {self._normalize_path(path) for path in x_oracle}
        return agent_paths == oracle_paths

    def list_attendees_checker(
        self,
        x_agent: list[str] | None,
        x_oracle: list[str] | None,
        tolerance_list_str: list[str] | None = None,
        **kwargs,
    ) -> bool:
        """
        Compare attendee lists, ignoring the entries of ``tolerance_list_str``.

        Returns True without looking at the agent value when the oracle list is
        ``None``, empty, or made only of tolerated entries. Otherwise delegates
        to ``unordered_str_list_with_tolerance_checker`` with the normalized
        tolerance list.
        """
        if tolerance_list_str is None:
            tolerance_list_str = []
        _tolerance_list = [normalize_str(x) for x in tolerance_list_str]
        # If the oracle list is empty or contains only tolerance strings, return True
        if x_oracle is None or len(x_oracle) == 0:
            return True
        if all(normalize_str(attendee) in _tolerance_list for attendee in x_oracle):
            return True
        # Otherwise call unordered list checker with tolerance list
        return self.unordered_str_list_with_tolerance_checker(
            x_agent, x_oracle, _tolerance_list
        )

    @injected_traceable(
        trace_type="unordered_str_list_with_tolerance_checker", tags=["judge"]
    )
    def unordered_str_list_with_tolerance_checker(
        self,
        x_agent: list[str] | None,
        x_oracle: list[str] | None,
        tolerance_list_str: list[str] | None = None,
        **kwargs,
    ) -> bool:
        """
        Compare two string lists as sets after ``normalize_str``, dropping every
        normalized element found in ``tolerance_list_str``.

        The tolerance entries are compared as given (they are not normalized
        here). ``None`` lists are treated as empty.
        """
        if tolerance_list_str is None:
            tolerance_list_str = []
        _x_agent = [normalize_str(x) for x in x_agent] if x_agent is not None else []
        _x_oracle = [normalize_str(x) for x in x_oracle] if x_oracle is not None else []
        # Remove tolerated elements, then compare as sets
        kept_agent = {x for x in _x_agent if x not in tolerance_list_str}
        kept_oracle = {x for x in _x_oracle if x not in tolerance_list_str}
        return kept_agent == kept_oracle

    @injected_traceable(trace_type="datetime_checker", tags=["judge"])
    def datetime_checker(
        self, x_agent: str | None, x_oracle: str | None, **kwargs
    ) -> bool:
        """
        Compare two datetimes given as ``"%Y-%m-%d %H:%M:%S"`` strings.

        If either value is ``None`` the raw values are compared with ``==``.

        Raises:
            ValueError: if a non-None value does not match the expected format
                (raised by ``datetime.strptime``).
        """
        if x_agent is None or x_oracle is None:
            return x_agent == x_oracle
        datetime_format = "%Y-%m-%d %H:%M:%S"
        _x_agent = datetime.strptime(x_agent, datetime_format)
        _x_oracle = datetime.strptime(x_oracle, datetime_format)
        return _x_agent == _x_oracle

    @injected_traceable(trace_type="eq_str_strip_checker", tags=["judge"])
    def eq_str_strip_checker(
        self, x_agent: str | None, x_oracle: str | None, **kwargs
    ) -> bool:
        """
        Compare two strings after stripping surrounding whitespace.

        Falsy values (``None`` or ``""``) are treated as the empty string.
        """
        _x_agent = (
            x_agent.strip() if bool(x_agent) else ""
        )  # Are we sure we want strip here?
        _x_oracle = x_oracle.strip() if bool(x_oracle) else ""
        return _x_agent == _x_oracle

    @injected_traceable(trace_type="phone_number_checker", tags=["judge"])
    def phone_number_checker(
        self, x_agent: str | None, x_oracle: str | None, **kwargs
    ) -> bool:
        """
        Compare two phone numbers on their digits only.

        If either value is ``None``, returns True only when both are ``None``.
        """
        if x_agent is None or x_oracle is None:
            return x_agent is None and x_oracle is None
        # Remove any non-digit characters from the input strings
        _x_agent = "".join(char for char in x_agent if char.isdigit())
        _x_oracle = "".join(char for char in x_oracle if char.isdigit())
        # Compare the cleaned phone numbers
        return _x_agent == _x_oracle

    @injected_traceable(trace_type="contain_any_checker", tags=["judge"])
    def contain_any_checker(self, x_agent: str, targets: list[str], **kwargs) -> bool:
        """Return True when at least one target occurs in ``x_agent`` (case-insensitive)."""
        return any(target.lower() in x_agent.lower() for target in targets)

    # The trace type now matches the checker name; it previously reused the
    # "contain_any_checker" label.
    @injected_traceable(trace_type="contain_all_checker", tags=["judge"])
    def contain_all_checker(self, x_agent: str, targets: list[str], **kwargs) -> bool:
        """Return True when every target occurs in ``x_agent`` (case-insensitive)."""
        return all(target.lower() in x_agent.lower() for target in targets)

    def compare(
        self, agent_event: CompletedEvent, oracle_event: CompletedOracleEvent, **kwargs
    ) -> bool:
        """
        Compare the args of an agent event and an oracle event.

        First runs, for every argument mapped to a hard checker type in
        ``config.arg_to_checker_type``, the registered checker on the agent and
        oracle values. Then, if ``config.event_id_to_checker_params`` has an
        entry for the oracle event id, runs those per-event checkers as well:
        hard ones receive the agent and oracle values, scripted ones receive
        only the agent value; both receive ``params.checker_args``.

        Args:
            agent_event: event produced by the agent.
            oracle_event: reference event.
            **kwargs: extra values; only those whose name appears in a checker's
                signature are forwarded to that checker.

        Returns:
            False as soon as one checker rejects, True otherwise.
        """
        # Get args
        agent_args = agent_event.get_args()
        oracle_args = oracle_event.get_args()
        # Hard checks
        for arg_name, check_type in self.config.arg_to_checker_type.items():
            if not check_type.is_hard():
                continue
            checker_fn = self.checkers[check_type.value]  # type: ignore
            args_names = self.checker_to_args_names[check_type.value]  # type: ignore
            # This is only for logging purposes
            checker_args = {k: v for k, v in kwargs.items() if k in args_names}
            checker_args["arg_name"] = arg_name
            # Call the checker
            if not checker_fn(
                x_agent=agent_args[arg_name],
                x_oracle=oracle_args[arg_name],
                **checker_args,
            ):
                return False
        # Scripted checks TODO: unify hard and scripted checks
        if (
            self.config.event_id_to_checker_params
            and oracle_event.event_id in self.config.event_id_to_checker_params
        ):
            for params in self.config.event_id_to_checker_params[oracle_event.event_id]:
                if params.checker_type.is_hard():
                    if not self.checkers[params.checker_type.value](  # type: ignore
                        agent_args[params.arg_name],
                        oracle_args[params.arg_name],
                        **params.checker_args,
                    ):
                        return False
                elif params.checker_type.is_scripted():
                    # Scripted checkers only look at the agent value.
                    if not self.checkers[params.checker_type.value](  # type: ignore
                        agent_args[params.arg_name],
                        **params.checker_args,
                    ):
                        return False
        return True
# [docs] -- documentation-link marker left over from HTML extraction; not code
class SoftToolJudge(ToolJudge):
    """
    A soft judge that compares some action args of an agent and oracle event
    representing a tool call with an llm.

    Only the arguments mapped to ``CheckerType.llm_checker`` in
    ``config.arg_to_checker_type`` are considered. Each soft checker delegates
    to the LLM checker of the same name built by ``build_llm_checkers``.
    """

    # Placeholder snippets whose presence (case-insensitive) in the agent args
    # makes `placeholder_checker` fail.
    _PLACEHOLDERS = (
        "[User's Name]",
        "[User Name]",
        "[User]",
        "[Your Name]",
        "[My Name]",
        "Best regards,\nYour Name",
        "Best,\nYour Name",
    )
    # Matches a described action call consisting of a single numeric `content` arg.
    _NUMERIC_CONTENT_RE = re.compile(r"^content: \d+(\.\d+)?$")

    def __init__(self, config: SoftToolJudgeConfig):
        """Store the config and build the subtask extractor and checker registries."""
        super().__init__(config, "soft")
        self.config = config
        # Engine
        self.engine = self.config.engine
        # Args checked by the LLM; when there are none, `compare` returns True
        self.selected_action_args = [
            name
            for name, check_type in self.config.arg_to_checker_type.items()
            if check_type == CheckerType.llm_checker
        ]
        self.no_arg_to_check = len(self.selected_action_args) == 0
        # Subtask extractor
        self.subtask_extractor = build_subtask_extractor(
            engine=self.config.engine, tool_name=self.tool_name
        )
        # LLM checkers
        self.llm_checkers = build_llm_checkers(engine=self.config.engine)
        # Soft checkers: checker type value -> bound checker method
        self.soft_checkers = {
            SoftCheckerType.content_checker.value: self.content_checker,
            SoftCheckerType.sanity_checker.value: self.sanity_checker,
            SoftCheckerType.signature_checker.value: self.signature_checker,
            SoftCheckerType.placeholder_checker.value: self.placeholder_checker,
            SoftCheckerType.cab_checker.value: self.cab_checker,
            SoftCheckerType.email_checker.value: self.email_checker,
            SoftCheckerType.message_checker.value: self.message_checker,
            SoftCheckerType.user_message_checker.value: self.user_message_checker,
            SoftCheckerType.event_checker.value: self.event_checker,
            SoftCheckerType.tone_checker.value: self.tone_checker,
        }
        # Checker to args names; `compare` forwards only matching kwargs
        self.checker_to_args_names = {
            checker_name: list(inspect.signature(checker_fn).parameters.keys())
            for checker_name, checker_fn in self.soft_checkers.items()
        }
        # Flag if we need to extract subtask
        self.need_subtask = any(c.need_subtask for c in self.config.soft_checker_types)

    def describe_action_args(self, args: dict[str, Any]) -> str:
        """Render ``args`` as one ``"name: value"`` entry per line, stripped at both ends."""
        return "".join(f"{k}: {v} \n" for k, v in args.items()).strip()

    @injected_traceable(trace_type="equality_checker", tags=["judge"])
    def equality_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Return True when every oracle arg equals the agent arg of the same name
        after ``normalize_arg``.

        Raises:
            KeyError: if an oracle arg name is missing from ``agent_args``.
        """
        return all(
            normalize_arg(agent_args[arg_name]) == normalize_arg(oracle_value)
            for arg_name, oracle_value in oracle_args.items()
        )

    @injected_traceable(trace_type="placeholder_checker", tags=["judge"])
    def placeholder_checker(
        self,
        agent_args: dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Return False when the agent args contain a known unfilled placeholder
        (for example ``[Your Name]``), compared case-insensitively.
        """
        # Values are typed Any: stringify them so that non-string args do not
        # make the join raise TypeError.
        agent_args_str = " ".join(str(v) for v in agent_args.values()).lower()
        return not any(p.lower() in agent_args_str for p in self._PLACEHOLDERS)

    @injected_traceable(trace_type="extract_subtask", tags=["judge"])
    def extract_subtask(self, oracle_action_call: str, task: str) -> str:
        """
        Run the subtask extractor for the given oracle call and task.

        Returns:
            The stripped text of the first ``subtask`` tag found in the
            extractor output, or ``""`` when the task is blank, the extractor
            returns ``None``, or no tag is found.
        """
        task = task.strip()
        if len(task) == 0:
            return ""
        # Query the extractor
        raw_output = self.subtask_extractor(
            user_prompt_args={
                "tool_name": self.tool_name,
                "oracle_action_call": oracle_action_call,
                "task": task,
            }
        )
        if raw_output is None:
            return ""
        # Keep the first tagged subtask, if any
        subtasks = extract_text_between_tags(raw_output, "subtask")
        return subtasks[0].strip() if len(subtasks) > 0 else ""

    @injected_traceable(trace_type="content_checker", tags=["judge"])
    def content_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        today_date: str,
        user_address: str,
        subtask: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``content_checker``, passing
        the described agent and oracle calls, the subtask (as ``task``), the tool
        name, today's date and the user address. Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.content_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "oracle_action_call": self.describe_action_args(oracle_args),
                "task": subtask,
                "tool_name": self.tool_name,
                "today_date": today_date,
                "user_address": user_address,
            }
        )

    @injected_traceable(trace_type="signature_checker", tags=["judge"])
    def signature_checker(
        self,
        agent_args: dict[str, Any],
        user_name: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``signature_checker``, passing
        the described agent call and the user name. Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.signature_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "user_name": user_name,
            }
        )

    @injected_traceable(trace_type="sanity_checker", tags=["judge"])
    def sanity_checker(
        self,
        agent_args: dict[str, Any],
        task: str = "",
        previous_task: str = "",
        **kwargs,
    ) -> bool | None:
        """
        Accept purely numeric ``content`` calls directly; otherwise delegate to
        the LLM checker registered for ``sanity_checker`` with the described
        agent call and the previous and current tasks joined by a newline.
        """
        agent_action_call = self.describe_action_args(agent_args)
        # Check for numerical values first: no LLM call is needed for them
        if self._NUMERIC_CONTENT_RE.match(agent_action_call) is not None:
            return True
        # Call the soft checker
        llm_checker = self.llm_checkers[SoftCheckerType.sanity_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": agent_action_call,
                "task": "\n".join([previous_task, task]),
            }
        )

    @injected_traceable(trace_type="cab_checker", tags=["judge"])
    def cab_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        user_address: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``cab_checker``, passing the
        described agent and oracle calls and the user address. Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.cab_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "oracle_action_call": self.describe_action_args(oracle_args),
                "user_address": user_address,
            }
        )

    @injected_traceable(trace_type="email_checker", tags=["judge"])
    def email_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        today_date: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``email_checker``, passing the
        described agent and oracle calls and today's date. Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.email_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "oracle_action_call": self.describe_action_args(oracle_args),
                "today_date": today_date,
            }
        )

    @injected_traceable(trace_type="message_checker", tags=["judge"])
    def message_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        today_date: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``message_checker``, passing
        the described agent and oracle calls and today's date. Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.message_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "oracle_action_call": self.describe_action_args(oracle_args),
                "today_date": today_date,
            }
        )

    @injected_traceable(trace_type="event_checker", tags=["judge"])
    def event_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        user_address: str,
        subtask: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``event_checker``, passing the
        described agent and oracle calls, the user address and the subtask (as
        ``task``). Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.event_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "oracle_action_call": self.describe_action_args(oracle_args),
                "user_address": user_address,
                "task": subtask,
            }
        )

    @injected_traceable(trace_type="user_message_checker", tags=["judge"])
    def user_message_checker(
        self,
        agent_args: dict[str, Any],
        oracle_args: dict[str, Any],
        subtask: str,
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``user_message_checker``,
        passing the described agent and oracle calls and the subtask (as
        ``task``). Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.user_message_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
                "oracle_action_call": self.describe_action_args(oracle_args),
                "task": subtask,
            }
        )

    @injected_traceable(trace_type="tone_checker", tags=["judge"])
    def tone_checker(
        self,
        agent_args: dict[str, Any],
        **kwargs,
    ) -> bool | None:
        """
        Delegate to the LLM checker registered for ``tone_checker``, passing the
        described agent call. Returns its result.
        """
        llm_checker = self.llm_checkers[SoftCheckerType.tone_checker.value]
        return llm_checker(
            user_prompt_args={
                "agent_action_call": self.describe_action_args(agent_args),
            }
        )

    def get_checker_kwargs(
        self,
        kwargs: dict[str, Any],
        oracle_event: CompletedOracleEvent,
        oracle_args: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Build the pool of keyword arguments offered to the soft checkers.

        Args:
            kwargs: extra values given to ``compare``; ``"tasks"`` (list of task
                strings) and ``"user_details"`` (object exposing ``first_name``,
                ``last_name`` and ``address``) are read when present.
            oracle_event: oracle event; its ``event_time`` gives ``today_date``.
            oracle_args: oracle args already restricted to the selected ones.

        Returns:
            A dict with the keys ``task``, ``previous_task``, ``user_name``,
            ``user_address``, ``today_date``, ``oracle_args`` and ``subtask``.
        """
        # Get the tasks; a missing, None or empty list falls back to one empty
        # task so that `tasks[-1]` below cannot fail.
        tasks = kwargs.get("tasks") or [""]
        # Today's date, from the oracle event timestamp interpreted in UTC
        today_date = ""
        if oracle_event.event_time is not None:
            today_date = datetime.fromtimestamp(
                oracle_event.event_time, tz=timezone.utc
            ).strftime("%Y-%m-%d %A")
        # Subtask, extracted only when a configured checker needs it
        subtask = ""
        if self.need_subtask:
            subtask = self.extract_subtask(
                oracle_action_call=self.describe_action_args(oracle_args),
                task="\n".join(tasks),
            )
        # User name and address
        user_details = kwargs.get("user_details")
        user_name = (
            f"{user_details.first_name} {user_details.last_name}"
            if user_details
            else ""
        )
        user_address = f"{user_details.address}" if user_details else ""
        return {
            "task": tasks[-1],
            # Previous tasks are joined with a real newline, as `sanity_checker`
            # does when it joins `previous_task` and `task`.
            "previous_task": "\n".join(tasks[:-1]) if len(tasks) > 1 else "",
            "user_name": user_name,
            "user_address": user_address,
            "today_date": today_date,
            "oracle_args": oracle_args,
            "subtask": subtask,
        }

    def compare(
        self,
        agent_event: CompletedEvent,
        oracle_event: CompletedOracleEvent,
        **kwargs,
    ) -> bool | None:
        """
        Compare the LLM-checked args of an agent event and an oracle event.

        Returns True immediately when there is no arg to check or when the
        selected args are equal after normalization. Otherwise runs every
        configured soft checker in order and returns False as soon as one of
        them returns a falsy value (including ``None``), True if all pass.
        """
        # If no args to check, then return True
        if self.no_arg_to_check:
            return True
        # Get the args
        oracle_args = oracle_event.get_args()
        agent_args = agent_event.get_args()
        # Keep only the selected args whose oracle value is truthy
        selected_action_args = [
            arg for arg in self.selected_action_args if bool(oracle_args[arg])
        ]
        oracle_args = {
            k: v for k, v in oracle_args.items() if k in selected_action_args
        }
        agent_args = {k: v for k, v in agent_args.items() if k in selected_action_args}
        # Check equality: identical args need no LLM call
        if self.equality_checker(
            agent_args=agent_args,
            oracle_args=oracle_args,
        ):
            return True
        # Get checker kwargs
        checker_kwargs = self.get_checker_kwargs(
            kwargs=kwargs,
            oracle_event=oracle_event,
            oracle_args=oracle_args,
        )
        # Apply soft checkers
        for checker in self.config.soft_checker_types:
            checker_fn = self.soft_checkers[checker.value]
            accepted_names = self.checker_to_args_names[checker.value]
            # Forward only the kwargs named in the checker signature
            _checker_kwargs = {
                k: v for k, v in checker_kwargs.items() if k in accepted_names
            }
            # Call the checker
            if not checker_fn(
                agent_args=agent_args,
                **_checker_kwargs,
            ):
                return False
        return True
# [docs] -- documentation-link marker left over from HTML extraction; not code
class MildToolJudge(ToolJudge):
    """
    A mild judge that chains a hard judge and a soft judge to compare an agent
    and oracle event representing a tool call.

    The hard judge runs first; the soft judge is consulted only when the hard
    judge accepts and the soft judge has at least one arg to check.
    """

    def __init__(self, config: MildToolJudgeConfig):
        """Build the hard and soft sub-judges from the shared config fields."""
        super().__init__(config, "mild")
        self.config = config

        # Hard judge: same tool, arg mapping and per-event checker params.
        hard_config = HardToolJudgeConfig(
            tool_name=config.tool_name,
            arg_to_checker_type=config.arg_to_checker_type,
            event_id_to_checker_params=config.event_id_to_checker_params,
            tracer=self.tracer,
        )
        self.hard_judge = HardToolJudge(hard_config)

        # Soft judge: same tool and arg mapping, plus the LLM settings.
        soft_config = SoftToolJudgeConfig(
            tool_name=config.tool_name,
            arg_to_checker_type=config.arg_to_checker_type,
            soft_checker_types=config.soft_checker_types,
            engine=config.engine,
            tracer=self.tracer,
        )
        self.soft_judge = SoftToolJudge(soft_config)

        # If scripted checkers are provided we do not use the soft judge.
        # NOTE(review): this tests `is not None`, so an empty mapping also
        # disables the soft judge -- confirm that is intended.
        scripted_params = self.config.event_id_to_checker_params
        if scripted_params is not None:
            self.soft_judge.no_arg_to_check = True

    def compare(
        self, agent_event: CompletedEvent, oracle_event: CompletedOracleEvent, **kwargs
    ) -> bool | None:
        """
        Return the hard judge verdict when it is falsy or when the soft judge
        has nothing to check; otherwise return the soft judge verdict.
        """
        # TODO: change system prompt when hard comparison fails to call soft judge
        hard_verdict = self.hard_judge(agent_event, oracle_event, **kwargs)
        if hard_verdict and not self.soft_judge.no_arg_to_check:
            return self.soft_judge(agent_event, oracle_event, **kwargs)
        return hard_verdict