Source code for are.simulation.validation.judgment
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections import Counter
from dataclasses import dataclass, field
from enum import Enum
[docs]
@dataclass
class Failure:
    """Base type for judgment failures.

    Carries no data of its own; the subclasses in this module add fields and
    override ``__str__`` with a more specific explanation.
    """

    def __str__(self) -> str:
        # Generic fallback text for a failure without further detail.
        message = "Failure occurred."
        return message
[docs]
class EventComparisonFailureType(Enum):
    """Category of reason an agent event failed to match an oracle event.

    Each member's ``value`` is the human-readable text that is rendered as the
    "reason" when an event comparison failure is turned into a string.
    """

    # NOTE(review): presumably an ordering/dependency check between events
    # failed; the exact rule lives in the matcher, not in this module.
    CAUSALITY = "causality"
    # NOTE(review): presumably the event was already paired with another one;
    # confirm against the matcher.
    ALREADY_MATCHED = "already matched"
    # NOTE(review): presumably a per-tool judge rejected the candidate pairing.
    TOOL_JUDGE_REJECT = "tool judge reject"
[docs]
@dataclass
class EventComparisonFailure(Failure):
    """Why a single agent event could not be matched to a single oracle event.

    Attributes:
        agent_tool_name: Tool name of the agent event that was compared.
        agent_event_id: ID of the agent event.
        oracle_tool_name: Tool name of the oracle event that was compared.
        oracle_event_id: ID of the oracle event.
        failure_type: Category of the failure; its ``value`` is shown as the reason.
    """

    agent_tool_name: str
    agent_event_id: str
    oracle_tool_name: str
    oracle_event_id: str
    failure_type: EventComparisonFailureType

    def __str__(self) -> str:
        # Only the two event IDs and the reason are rendered; the tool names
        # are kept on the object but do not appear in the message.
        return (
            f"Failure matching agent event (ID: {self.agent_event_id}) "
            f"with oracle event (ID: {self.oracle_event_id}), "
            f"reason: {self.failure_type.value}"
        )
[docs]
@dataclass
class OracleEventMatchingFailure(Failure):
    """An oracle tool call for which no agent event could be matched.

    Attributes:
        oracle_tool_name: Name of the oracle tool call the agent did not perform.
        oracle_tool_args: Arguments of that oracle tool call.
        comparison_failures: The individual agent/oracle comparisons that were
            attempted for this oracle event and failed.
    """

    oracle_tool_name: str
    oracle_tool_args: dict[str, str]
    comparison_failures: list[EventComparisonFailure]

    def __str__(self) -> str:
        max_len = 200  # longest rendered argument line before truncation

        def _clip(line: str) -> str:
            # Keep at most ``max_len`` characters and mark a cut with "...".
            if len(line) > max_len:
                return line[:max_len] + "..."
            return line

        arg_lines = "\n".join(
            _clip(f"-{name}: {value}")
            for name, value in self.oracle_tool_args.items()
        )
        attempts = "\n".join(f"-{attempt!s}" for attempt in self.comparison_failures)
        return (
            "Failure: Agent did not perform the following oracle tool call:"
            f"\ntool name: {self.oracle_tool_name}\ntool args:\n{arg_lines}\n"
            "\nList of matching attempts:\n"
            f"{attempts}"
        )
[docs]
@dataclass
class EnvOracleMatchingFailure(Failure):
    """An oracle env/user event that could not be matched.

    Attributes:
        oracle_event_id: ID of the oracle event left unmatched.
    """

    oracle_event_id: str

    def __str__(self) -> str:
        # The rendered message itself labels this situation as a likely bug.
        event_id = self.oracle_event_id
        return (
            f"Failure: Oracle env/user event {event_id} could not be matched."
            " This is likely a bug !"
        )
[docs]
@dataclass
class Judgment:
    """Result of a validation judgment.

    Attributes:
        success: Verdict, ``False`` by default. ``None`` is also allowed.
            NOTE(review): the meaning of ``None`` is not visible here;
            presumably "not determined". Confirm with callers.
        failure: Explanation of a failed judgment, either free text or a
            ``Failure`` instance; ``None`` when no failure was recorded.
        agent_event_id_to_oracle_event_id: Mapping keyed by agent event ID whose
            values are oracle event IDs (presumably the matched pairs; confirm).
    """

    success: bool | None = False
    failure: str | Failure | None = None
    # default_factory gives every instance its own fresh dict.
    agent_event_id_to_oracle_event_id: dict[str, str] = field(default_factory=dict)