# Source code for are.simulation.scenarios.scenario_imported_from_json.benchmark_scenario

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import logging
from collections import defaultdict, deque
from typing import Callable

from are.simulation.data_handler.models import ExportedHuggingFaceMetadata
from are.simulation.scenarios.scenario import Scenario, ScenarioValidationResult
from are.simulation.scenarios.utils.turn_conditions import (
    is_send_message_to_user,
    turn_condition_wrapper,
)
from are.simulation.types import (
    AbstractEnvironment,
    AbstractEvent,
    CompletedEvent,
    ConditionCheckEvent,
    OracleEvent,
)

from .scenario import ScenarioImportedFromJson

logger = logging.getLogger(__name__)


def build_event_id_to_turn_idx(
    scenario: Scenario,
    is_end_of_turn_event: Callable[[AbstractEvent], bool] = is_send_message_to_user,
):
    """
    Build a dictionary to store the turn of each event
    The turn of an event without dependencies is 0. Otherwise it is the highest
    turn among its dependencies, plus one if at least one of these dependencies
    is an end of turn event (by default an event send_message_to_user).
    The dictionary and the number of turns are stored to the scenario

    Events are resolved in dependency order: an event reached through a
    successor link is held back until every one of its dependencies has a turn,
    so its turn never relies on a placeholder value for a dependency that would
    only be processed later in the traversal.

    Args:
        scenario (Scenario): The scenario object. It is updated in place with
            the attributes ``nb_turns`` and ``event_id_to_turn_idx``.
        is_end_of_turn_event (Callable[[AbstractEvent], bool]): Predicate
            telling whether an event ends a turn.

    Raises:
        ValueError: If ``scenario.events`` is None.
    """
    if scenario.events is None:
        raise ValueError("Events not found")
    # Turn of every resolved event; membership doubles as the "visited" marker
    event_id_to_turn_idx: dict[str, int] = {}
    # Events already reached but still missing the turn of a dependency
    waiting: dict[str, AbstractEvent] = {}
    # Start from the events without dependencies
    queue = deque(e for e in scenario.events if not e.dependencies)
    while queue or waiting:
        if queue:
            event = queue.popleft()
            if event.event_id in event_id_to_turn_idx:
                continue
            if any(
                dep.event_id not in event_id_to_turn_idx
                for dep in event.dependencies or ()
            ):
                # Too early: it is enqueued again each time one of its
                # dependencies gets resolved
                waiting[event.event_id] = event
                continue
        else:
            # Nothing else can progress (dependency cycle, or a dependency that
            # is not reachable from the root events). Resolve the oldest waiting
            # event anyway, counting its unresolved dependencies as turn 0, so
            # that the loop terminates and every reached event gets a turn.
            event = next(iter(waiting.values()))
        waiting.pop(event.event_id, None)

        dependencies = event.dependencies or ()
        turn_idx = 0
        if dependencies:
            turn_idx = max(
                event_id_to_turn_idx.get(dep.event_id, 0) for dep in dependencies
            )
            # Check if there is a end of turn event among the dependencies
            if any(
                is_end_of_turn_event(dep)
                for dep in dependencies  # type: ignore
            ):
                turn_idx += 1
        event_id_to_turn_idx[event.event_id] = turn_idx

        # Add the successors to the queue
        queue.extend(
            e for e in event.successors if e.event_id not in event_id_to_turn_idx
        )
    # Set the number of turns
    nb_turns = (max(event_id_to_turn_idx.values()) if event_id_to_turn_idx else 0) + 1
    # Append event_id_to_turn_idx and number of turn to scenario
    scenario.nb_turns = nb_turns  # type: ignore
    scenario.event_id_to_turn_idx = event_id_to_turn_idx  # type: ignore


class BenchmarkScenarioImportedFromJson(ScenarioImportedFromJson):
    """
    Special scenario class used for importing scenarios from JSON files.

    On top of its parent class it keeps the multi turn bookkeeping (number of
    turns, turn of each event) and offers helpers to wire the turn triggers
    and the validation function.
    """

    # Multi turn parameters
    # Number of turns in the scenario
    nb_turns: int | None = None
    # Dictionary to store the turn of each event
    event_id_to_turn_idx: dict[str, int] | None = None
    # Event log of a run in oracle mode
    oracle_run_event_log: list[CompletedEvent] | None = None
    # Set once initialize_turns has completed, making later calls no-ops
    _turns_initialized: bool = False

    hf_metadata: ExportedHuggingFaceMetadata | None = None

    def build_turn_trigger(
        self,
        trigger_condition: Callable[
            [AbstractEnvironment, int], tuple[bool, dict[str, str]]
        ],
        is_end_of_turn_event: Callable[[AbstractEvent], bool] = is_send_message_to_user,
    ):
        """
        Rewire the events so that each turn is started by a condition check.

        For every turn after the first one, a ConditionCheckEvent named
        ``condition_turn_<idx>`` is created. It depends on the condition of the
        previous turn (none for turn 1) and takes over the successors of the
        end of turn event of the previous turn, whose dependencies on end of
        turn events are removed. The new events are appended to ``self.events``.

        Args:
            trigger_condition: Condition handed to turn_condition_wrapper
                together with the turn index and the scenario id.
            is_end_of_turn_event: Predicate telling whether an event ends a turn.

        Raises:
            ValueError: If an end of turn event has an oracle successor, or if
                a non oracle event still depends on an oracle event afterwards.
        """
        assert self.event_id_to_turn_idx is not None, "Turn index must be set"
        turn_of_event = self.event_id_to_turn_idx

        # End of turn event per turn; the last one found wins for a given turn
        closing_event_by_turn = {}
        for candidate in self.events:
            if is_end_of_turn_event(candidate):
                closing_event_by_turn[turn_of_event[candidate.event_id]] = candidate

        assert self.nb_turns is not None, "Number of turns must be set"

        condition_events = {}
        # Condition of the previous turn, None while building the first one
        previous_condition = None
        for turn_idx in range(1, self.nb_turns):
            condition_id = f"condition_turn_{turn_idx}"

            # Work on a copy of the successors of the previous end of turn event
            successors = closing_event_by_turn[turn_idx - 1].successors[:]
            for successor in successors:
                if isinstance(successor, OracleEvent):
                    raise ValueError(
                        f"Scenario {self.scenario_id} has a end of turn event with oracle successors"
                    )

            # The successors must no longer wait for end of turn events
            for successor in successors:
                kept_dependencies = []
                for dependency in successor.dependencies:
                    if not is_end_of_turn_event(dependency):
                        kept_dependencies.append(dependency)
                successor.dependencies = kept_dependencies

            # Build the condition of the turn, step by step
            wrapped_condition = turn_condition_wrapper(
                trigger_condition=trigger_condition,
                turn_idx=turn_idx,
                scenario_id=self.scenario_id,
            )
            condition_event = ConditionCheckEvent.from_condition(
                condition=wrapped_condition,  # type: ignore
                every_tick=1,
                timeout=None,
            )
            condition_event = condition_event.with_id(condition_id)
            condition_event = condition_event.depends_on(previous_condition)
            condition_event = condition_event.followed_by(
                events=successors,
                # Keep the relative time of each successor, 0.0 when unset
                delay_seconds=[
                    (
                        successor.event_relative_time
                        if successor.event_relative_time is not None
                        else 0.0
                    )
                    for successor in successors
                ],
            )

            condition_events[condition_id] = condition_event
            previous_condition = condition_event

        # Non oracle events must not depend on any oracle event now
        for event in self.events:
            if type(event) is OracleEvent or not event.dependencies:
                continue
            for dependency in event.dependencies:
                if type(dependency) is OracleEvent:
                    raise ValueError(
                        f"Event {event.event_id} depends on an oracle event. This is not expected."
                    )

        # Register the new condition events
        for condition_id, condition_event in condition_events.items():
            self.events.append(condition_event.with_id(condition_id))

    def build_validation_fn(
        self,
        validation_fn: Callable[[AbstractEnvironment], ScenarioValidationResult],
        offline_validation: bool = False,
    ):
        """
        Install ``self.validate`` as a wrapper around ``validation_fn``.

        Args:
            validation_fn: Function validating the environment.
            offline_validation: When True, the installed wrapper calls
                ``validation_fn`` once per turn; otherwise it calls it once.
        """

        def online_wrapped_validation_fn(
            env: AbstractEnvironment,
        ) -> ScenarioValidationResult:
            """
            Validate the last turn only: in the online mode, a validation is
            already performed at the end of each turn in the trigger condition.
            """
            return validation_fn(env)

        def offline_wrapped_validation_fn(
            env: AbstractEnvironment,
        ) -> ScenarioValidationResult:
            """
            Validate all the turns of the scenario: in the offline mode, no
            intermediate validation is performed between the turns.
            """
            assert self.nb_turns is not None, "Number of turns must be set"
            # Returned as is when there is no turn to evaluate
            outcome = ScenarioValidationResult(success=False)
            for turn_number in range(1, self.nb_turns + 1):
                logger.info(f"Validating turn {turn_number} / {self.nb_turns}")
                # validation_fn is expected to validate the current turn at each
                # call and to update its internal state for the next call
                outcome = validation_fn(env)
                if not outcome.success:
                    # Stop at the first failing turn
                    break
            return outcome

        if offline_validation:
            self.validate = offline_wrapped_validation_fn
        else:
            self.validate = online_wrapped_validation_fn

    def initialize_turns(
        self,
        trigger_condition: (
            Callable[[AbstractEnvironment, int], tuple[bool, dict[str, str]]] | None
        ) = None,
        validation_fn: (
            Callable[[AbstractEnvironment], ScenarioValidationResult] | None
        ) = None,
        is_end_of_turn_event: Callable[[AbstractEvent], bool] = is_send_message_to_user,
        offline_validation: bool = False,
    ):
        """
        Initialize the turns: compute the turn of each event, then build the
        turn triggers and the validation function when they are provided.

        Calling it again once the turns are initialized does nothing.

        Args:
            trigger_condition: Condition used to trigger the turns. When None,
                the events graph is left untouched and a warning is logged.
            validation_fn: Validation function to wrap. When None, a warning is
                logged and a dummy validation returning ``success=None`` is set.
            is_end_of_turn_event: Predicate telling whether an event ends a turn.
            offline_validation: Forwarded to build_validation_fn.

        Raises:
            ValueError: If the scenario itself has not been initialized yet.
        """
        if not self._initialized:
            raise ValueError("Scenario must be initialized before initializing turns")
        if self._turns_initialized:
            return

        # Compute nb_turns and event_id_to_turn_idx
        build_event_id_to_turn_idx(
            scenario=self,
            is_end_of_turn_event=is_end_of_turn_event,
        )

        # Modify the events graph to trigger turns
        if trigger_condition is None:
            logger.warning(
                "Trigger condition is not provided. Building turn triggers is skipped"
            )
        else:
            self.build_turn_trigger(
                trigger_condition=trigger_condition,
                is_end_of_turn_event=is_end_of_turn_event,
            )

        # Build the validation function
        if validation_fn is None:
            logger.warning(
                "Validation function is not provided. Building validation function is skipped"
            )
            # Fall back to a dummy validation function
            self.validate = lambda env: ScenarioValidationResult(success=None)
        else:
            self.build_validation_fn(
                validation_fn, offline_validation=offline_validation
            )

        # Later calls return early
        self._turns_initialized = True