# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Any, Callable

from are.simulation.agents.are_simulation_agent_config import LLMEngineConfig
from are.simulation.validation.constants import (
    TOOL_ARG_CHECKER_TYPE_REGISTRY,
    TOOL_EVALUATION_CRITERIA_REGISTRY,
    TOOL_SOFT_CHECKER_TYPE_REGISTRY,
    CheckerType,
    SoftCheckerType,
    ToolArgCheckerTypeRegistry,
    ToolCriteriaRegistry,
    ToolSoftCheckerTypeRegistry,
)
from are.simulation.validation.prompts import (
    IN_CONTEXT_JUDGE_SYSTEM_PROMPT_TEMPLATE,
    TIME_SYSTEM_PROMPT_TEMPLATE,
)

# Default judge configuration
DEFAULT_JUDGE_MODEL = "meta-llama/Meta-Llama-3.3-70B-Instruct"
DEFAULT_JUDGE_PROVIDER = "huggingface"


def create_judge_engine(
    judge_engine_config: LLMEngineConfig | None = None,
):
    """Create a judge engine with the specified configuration.

    Args:
        judge_engine_config: Engine settings to use. When ``None``, a config
            is built from ``DEFAULT_JUDGE_MODEL`` with no provider and no
            endpoint.

    Returns:
        A ``LiteLLMEngine`` built from the config's model name and endpoint.
        The provider is the config's provider when it is truthy, otherwise
        ``DEFAULT_JUDGE_PROVIDER``.
    """
    if judge_engine_config is None:
        judge_engine_config = LLMEngineConfig(
            model_name=DEFAULT_JUDGE_MODEL,
            provider=None,
            endpoint=None,
        )

    # The LiteLLM engine is imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle or a hard dependency
    # when this module is merely imported -- confirm before hoisting.
    from are.simulation.agents.llm.litellm.litellm_engine import (
        LiteLLMEngine,
        LiteLLMModelConfig,
    )

    # An unset/empty provider falls back to the module-level default.
    final_provider = judge_engine_config.provider or DEFAULT_JUDGE_PROVIDER

    judge_config = LiteLLMModelConfig(
        model_name=judge_engine_config.model_name,
        provider=final_provider,
        endpoint=judge_engine_config.endpoint,
    )
    return LiteLLMEngine(model_config=judge_config)


@dataclass
class ToolCheckerParam:
    """Parameter for the tool checker of the hard judge."""

    arg_name: str
    checker_type: CheckerType
    tool_name: str
    # Extra arguments for the checker; each instance gets its own empty dict.
    checker_args: dict[str, Any] = field(default_factory=dict)


@dataclass
class BaseToolJudgeConfig:
    """Fields shared by the hard, soft and mild tool judge configs."""

    tool_name: str
    # The list of args to check and the type of checker to use for each arg.
    arg_to_checker_type: dict[str, CheckerType]
    tracer: Callable | None = None


@dataclass
class HardToolJudgeConfig(BaseToolJudgeConfig):
    """Tool judge config that adds optional per-event checker parameters."""

    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None


@dataclass
class SoftToolJudgeConfig(BaseToolJudgeConfig):
    """Tool judge config carrying an engine and a list of soft checker types."""

    # When no engine is passed, constructing the config calls
    # ``create_judge_engine()`` and therefore builds a new engine object.
    engine: Callable = field(default_factory=create_judge_engine)
    # Soft checker
    soft_checker_types: list[SoftCheckerType] = field(
        default_factory=lambda: [SoftCheckerType.content_checker]
    )

    def __post_init__(self):
        """Fall back to the content checker when no soft checker is given."""
        # Truthiness covers the empty list and also an explicit ``None``.
        if not self.soft_checker_types:
            self.soft_checker_types = [SoftCheckerType.content_checker]


@dataclass
class MildToolJudgeConfig(BaseToolJudgeConfig):
    """Tool judge config with an engine, soft checker types and scripted checks.

    Unlike ``SoftToolJudgeConfig``, ``soft_checker_types`` defaults to an empty
    list and is left empty if not provided.
    """

    engine: Callable = field(default_factory=create_judge_engine)
    soft_checker_types: list[SoftCheckerType] = field(default_factory=list)
    # Scripted checkers related config
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None


@dataclass
class BaseEventJudgeConfig:
    """Base config for event judges; holds only an optional tracer."""

    tracer: Callable | None = None


@dataclass
class EnvUserEventJudgeConfig(BaseEventJudgeConfig):
    """Event judge config that adds no fields to the base config."""

    pass


@dataclass
class AgentEventJudgeConfig(BaseEventJudgeConfig):
    """Event judge config with time, tool and scripted-checker settings."""

    # Time related config
    # NOTE(review): units are seconds per the field names; the exact meaning
    # of each tolerance is defined by the consuming judge, not here.
    check_time_threshold_seconds: float = 1.0
    pre_event_tolerance_seconds: float = 10.0
    post_event_tolerance_seconds: float = 25.0
    # Tool related config
    # The registry defaults return the module-level registry objects
    # themselves (not copies), so in-place mutation is visible globally.
    per_tool_arg_to_checker_type: ToolArgCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_ARG_CHECKER_TYPE_REGISTRY
    )
    per_tool_soft_checker_types: ToolSoftCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_SOFT_CHECKER_TYPE_REGISTRY
    )
    engine: Callable = field(default_factory=create_judge_engine)
    # Scripted checkers related config
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None
@dataclass
class GraphPerEventJudgeConfig(BaseJudgeConfig):
    """Config for the graph per event judge.

    Groups time settings, tool checker registries, the judge engine, optional
    scripted checker parameters and a preliminary-check setting.
    """

    # Time related config
    # NOTE(review): units are seconds per the field names; the exact meaning
    # of each tolerance is defined by the consuming judge, not here.
    check_time_threshold_seconds: float = 1.0
    pre_event_tolerance_seconds: float = 10.0
    post_event_tolerance_seconds: float = 25.0

    # Tool related config
    # The registry defaults return the module-level registry objects
    # themselves (not copies), so in-place mutation is visible globally.
    per_tool_arg_to_checker_type: ToolArgCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_ARG_CHECKER_TYPE_REGISTRY
    )
    # When no engine is passed, ``create_judge_engine()`` is called with no
    # arguments at construction time.
    engine: Callable = field(default_factory=create_judge_engine)
    per_tool_soft_checker_types: ToolSoftCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_SOFT_CHECKER_TYPE_REGISTRY
    )

    # Scripted checkers related config
    # If this field is not `None`, the soft judge will not be used.
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = None

    # Preliminary check
    extra_send_message_to_user_allowed: int = 1
@dataclass
class ScriptedGraphPerEventJudgeConfig(GraphPerEventJudgeConfig):
    """
    Config for the scripted graph per event judge.

    Scripted judge is a judge where the soft judge is deactivated and instead
    scripted checks will be performed by the hard judge. The
    `event_id_to_checker_params` field is used to specify the scripted checks
    to perform.
    """

    # Change default such that soft judge is not used: the field now defaults
    # to a fresh empty dict instead of `None`.
    event_id_to_checker_params: dict[str, list[ToolCheckerParam]] | None = field(
        default_factory=dict
    )

    def __post_init__(self):
        """Reject an explicit ``None`` for ``event_id_to_checker_params``.

        Raises:
            ValueError: If ``event_id_to_checker_params`` is ``None``.
        """
        if self.event_id_to_checker_params is not None:
            return
        raise ValueError(
            "event_id_to_checker_params must be specified for ScriptedGraphPerEventJudgeConfig"
        )
@dataclass
class InContextJudgeConfig(BaseJudgeConfig):
    """Config for the in-context judge.

    Groups time settings, per-tool evaluation criteria, the selected tool
    arguments, the judge engine and the system prompt templates.
    """

    # Time related config
    # NOTE(review): units are seconds per the field names; the exact meaning
    # of each tolerance is defined by the consuming judge, not here.
    check_time_threshold_seconds: float = 1.0
    pre_event_tolerance_seconds: float = 10.0
    post_event_tolerance_seconds: float = 25.0
    time_system_prompt_template: str = TIME_SYSTEM_PROMPT_TEMPLATE

    # Tool related config
    # The registry defaults return the module-level registry objects
    # themselves (not copies), so in-place mutation is visible globally.
    per_tool_evaluation_criteria: ToolCriteriaRegistry = field(
        default_factory=lambda: TOOL_EVALUATION_CRITERIA_REGISTRY
    )
    # Will not use the checker type but only arg names
    tool_to_selected_args: ToolArgCheckerTypeRegistry = field(
        default_factory=lambda: TOOL_ARG_CHECKER_TYPE_REGISTRY
    )
    # When no engine is passed, ``create_judge_engine()`` is called with no
    # arguments at construction time.
    engine: Callable = field(default_factory=create_judge_engine)
    system_prompt_template: str = IN_CONTEXT_JUDGE_SYSTEM_PROMPT_TEMPLATE