# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum, IntEnum
from typing import Any
from are.simulation.agents.multimodal import Attachment
from are.simulation.apps.agent_user_interface import AUIMessage, Sender
from are.simulation.apps.reminder import Reminder, ReminderApp
from are.simulation.apps.system import SystemApp, WaitForNotificationTimeout
from are.simulation.priority_queue import PriorityQueue
from are.simulation.time_manager import TimeManager
from are.simulation.types import AbstractEvent, Action, CompletedEvent, EventType
from are.simulation.validation.constants import APP_ALIAS
[docs]
class MessageType(Enum):
USER_MESSAGE = "USER_MESSAGE"
ENVIRONMENT_NOTIFICATION = "ENVIRONMENT_NOTIFICATION"
ENVIRONMENT_STOP = "ENVIRONMENT_STOP"
[docs]
class VerbosityLevel(IntEnum):
"""Defines the verbosity levels for the notification system.
LOW: Only user messages and system notifications (due reminders, wait for notification timeouts) are notified.
MEDIUM: ENV events that are possible consequences of Agent actions are notified.
HIGH: Most ENV events are notified, even independent ENV events that are not caused by Agent actions.
"""
LOW = 1
MEDIUM = 2
HIGH = 3
[docs]
@dataclass
class Message:
message_type: MessageType
message: str
timestamp: datetime
attachments: list[Attachment] = field(default_factory=list)
[docs]
class AbstractMessageQueue:
[docs]
@abc.abstractmethod
def put(self, message: Message) -> None:
pass
[docs]
@abc.abstractmethod
def get_by_timestamp(self, timestamp: datetime) -> list[Message]:
pass
[docs]
class MessageQueue(AbstractMessageQueue):
[docs]
def __init__(self):
self.messages = PriorityQueue[Message](fields=["timestamp"])
[docs]
def put(self, message: Message) -> None:
self.messages.put(message)
[docs]
def get_by_timestamp(self, timestamp: datetime) -> list[Message]:
extracted_messages = []
remaining_messages = []
while not self.messages.empty():
event = self.messages.get()
if event.timestamp <= timestamp:
extracted_messages.append(event)
else:
remaining_messages.append(event)
break # since events are ordered, no need to check further
# Reinsert the remaining items back into the queue
for item in remaining_messages:
self.messages.put(item)
return extracted_messages
[docs]
def has_environment_stop_message(self) -> bool:
return any(
message.message_type == MessageType.ENVIRONMENT_STOP
for message in self.list_view()
)
[docs]
def list_view(self) -> list[Message]:
return list(self.messages)
[docs]
def has_new_messages(self, timestamp: datetime) -> bool:
return any(message.timestamp <= timestamp for message in self.list_view())
[docs]
@dataclass
class NotificationSystemConfig:
notified_tools: dict[str, list[str]] = field(default_factory=dict)
def _add_app_aliases(notified_tools: dict[str, list[str]]) -> dict[str, list[str]]:
"""Helper function to add app aliases to the notified tools dictionary."""
aliases_to_add = {}
for app_name in notified_tools:
if app_name in APP_ALIAS:
if isinstance(APP_ALIAS[app_name], list):
for alias in APP_ALIAS[app_name]:
aliases_to_add[alias] = notified_tools[app_name]
elif isinstance(APP_ALIAS[app_name], str):
aliases_to_add[APP_ALIAS[app_name]] = notified_tools[app_name]
notified_tools.update(aliases_to_add)
return notified_tools
[docs]
class BaseNotificationSystem:
"""Base class for all notification systems."""
[docs]
def __init__(
self, config: NotificationSystemConfig = NotificationSystemConfig()
) -> None:
self.message_queue = MessageQueue()
self.time_manager = None
self.reminder_app = None
self.system_app = None
self._initialized = False
self.config = config
[docs]
def initialize(self, time_manager: TimeManager) -> None:
self.time_manager = time_manager
self._initialized = True
[docs]
def setup_reminder_app(self, reminder_app: ReminderApp) -> None:
self.reminder_app = reminder_app
[docs]
def setup_system_app(self, system_app: SystemApp) -> None:
self.system_app = system_app
[docs]
def clear(self) -> None:
self.message_queue = MessageQueue()
[docs]
def get_current_time(self) -> float:
if not self.time_manager:
raise ValueError("Time manager is not set.")
return self.time_manager.time()
[docs]
def get_next_notification_time(self) -> float | None:
next_message = self.message_queue.messages.peek()
if next_message is None:
return None
return datetime.fromtimestamp(
next_message.timestamp.timestamp(), tz=timezone.utc
).timestamp()
[docs]
def handle_time_based_notifications(self) -> None:
if self.reminder_app:
due_reminders = self.reminder_app.get_due_reminders()
new_due_reminders = [
reminder for reminder in due_reminders if not reminder.already_notified
]
if new_due_reminders:
message = self.convert_reminders_to_message(new_due_reminders)
if message is not None:
self.message_queue.put(message)
[docs]
def handle_timeout_after_events(self) -> None:
"""
Handle timeout notifications after all events in the current tick have been processed.
This ensures that timeout notifications are only triggered if no other notifications
were generated during the same tick.
"""
current_timestamp = datetime.fromtimestamp(
self.get_current_time(), tz=timezone.utc
)
if self.system_app:
if not self.message_queue.has_new_messages(timestamp=current_timestamp):
wait_for_notification_timeout = (
self.system_app.get_wait_for_notification_timeout()
)
if wait_for_notification_timeout is not None:
# Insert a wait for notification timeout message
message = self.convert_wait_for_notification_timeout_to_message(
wait_for_notification_timeout
)
if message is not None:
self.message_queue.put(message)
self.system_app.reset_wait_for_notification_timeout()
else:
# If there are new messages, reset the wait for notification timeout
self.system_app.reset_wait_for_notification_timeout()
[docs]
def convert_reminders_to_message(
self, due_reminders: list[Reminder]
) -> Message | None:
"""Convert due reminders to a notification message."""
if self.reminder_app:
timestamp = self.get_current_time()
for reminder in due_reminders:
reminder.time_notified = timestamp
message = f"{self.reminder_app.name}: New due reminders:\n"
message += "\n".join(str(reminder) for reminder in due_reminders)
return Message(
message_type=MessageType.ENVIRONMENT_NOTIFICATION,
message=message,
timestamp=datetime.fromtimestamp(timestamp, tz=timezone.utc),
)
return None
[docs]
def convert_wait_for_notification_timeout_to_message(
self, wait_for_notification_timeout: WaitForNotificationTimeout
) -> Message | None:
if self.system_app:
timestamp = self.get_current_time()
message = f"{self.system_app.name}: Wait for notification timeout reached after {wait_for_notification_timeout.timeout} seconds"
return Message(
message_type=MessageType.ENVIRONMENT_NOTIFICATION,
message=message,
timestamp=datetime.fromtimestamp(timestamp, tz=timezone.utc),
)
return None
[docs]
def handle_event(self, event: AbstractEvent) -> None:
if not self._initialized:
raise ValueError("Notification system is not initialized.")
message = self.convert_to_message(event)
if message is not None:
self.message_queue.put(message)
[docs]
def convert_to_message(self, event: AbstractEvent) -> Message | None:
if not isinstance(event, CompletedEvent) or not isinstance(
event.action, Action
):
return None
timestamp = self.get_current_time()
# Handle AgentUserInterface send_message_to_agent events
if (
hasattr(event.action, "app")
and event.app_class_name() == "AgentUserInterface"
and event.function_name() == "send_message_to_agent"
):
args: dict[str, Any] = get_args(event)
message = str(
AUIMessage(
sender=Sender.USER,
content=args["content"],
timestamp=timestamp,
time_read=timestamp,
)
)
dumped_attachments: list[dict[str, Any]] = (
args.get("base64_utf8_encoded_attachment_contents") or []
)
attachments: list[Attachment] = [
Attachment.model_validate(attachment)
for attachment in dumped_attachments
]
return Message(
message_type=MessageType.USER_MESSAGE,
message=message,
timestamp=datetime.fromtimestamp(timestamp, tz=timezone.utc),
attachments=attachments,
)
# Handle environment notification events
message = get_content_for_message(event)
app_class_name = event.app_class_name()
if (
app_class_name is not None
and app_class_name in self.config.notified_tools
and event.function_name() in self.config.notified_tools[app_class_name]
and message is not None
):
return Message(
message_type=MessageType.ENVIRONMENT_NOTIFICATION,
message=message,
timestamp=datetime.fromtimestamp(timestamp, tz=timezone.utc),
)
return None
[docs]
def get_args(event: CompletedEvent) -> dict[str, Any]:
assert type(event.action) is Action
if event.action.resolved_args:
return event.action.resolved_args
return event.action.args
[docs]
class VerboseNotificationSystem(BaseNotificationSystem):
[docs]
def __init__(self, verbosity_level: VerbosityLevel = VerbosityLevel.MEDIUM):
self.verbosity_level = verbosity_level
config = NotificationSystemConfig(
notified_tools=get_notification_tools(self.verbosity_level)
)
super().__init__(config)
def __call__(self) -> BaseNotificationSystem:
config = NotificationSystemConfig(
notified_tools=get_notification_tools(self.verbosity_level)
)
return BaseNotificationSystem(config)
NotificationSystem = VerboseNotificationSystem(verbosity_level=VerbosityLevel.LOW)
[docs]
def get_content_for_message(event: AbstractEvent) -> str | None:
if type(event) is CompletedEvent and type(event.action) is Action:
if (
hasattr(event.action, "app")
and (
event.app_class_name() == "EmailClientApp"
or event.app_class_name() in APP_ALIAS["EmailClientApp"]
)
and event.function_name()
in ["create_and_add_email", "send_email_to_user_only"]
):
return f"{event.app_class_name()}: New email received from {get_args(event)['sender']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "EmailClientApp"
or event.app_class_name() in APP_ALIAS["EmailClientApp"]
)
and event.function_name() == "reply_to_email_from_user"
):
return f"{event.app_class_name()}: New email received"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "MessagingApp"
or event.app_class_name() in APP_ALIAS["MessagingApp"]
or event.app_class_name() == "MessagingAppV2"
or event.app_class_name() in APP_ALIAS["MessagingAppV2"]
)
and event.function_name() == "create_and_add_message"
):
return f"{event.app_class_name()}: New message received in conversation ID: {get_args(event)['conversation_id']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "Shopping"
or event.app_class_name() == "ShoppingApp"
)
and event.function_name() == "add_product"
):
return f"{event.app_class_name()}: New product added: {get_args(event)['name']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "Shopping"
or event.app_class_name() == "ShoppingApp"
)
and event.function_name() == "add_item_to_product"
):
return f"{event.app_class_name()}: New item added to product with Id {get_args(event)['product_id']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "Shopping"
or event.app_class_name() == "ShoppingApp"
)
and event.function_name() == "cancel_order"
and event.event_type in [EventType.USER, EventType.ENV]
):
return f"{event.app_class_name()}: Order cancelled with Id {get_args(event)['order_id']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "Shopping"
or event.app_class_name() == "ShoppingApp"
)
and event.function_name() == "update_order_status"
):
return f"{event.app_class_name()}: Updated order with Id {get_args(event)['order_id']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "Shopping"
or event.app_class_name() == "ShoppingApp"
)
and event.function_name() == "add_discount_code"
):
return f"{event.app_class_name()}: New discount code for item with Id {get_args(event)['item_id']}"
elif (
hasattr(event.action, "app")
and (
event.app_class_name() == "RentAFlat"
or event.app_class_name() == "ApartmentListingApp"
)
and event.function_name() == "add_new_apartment"
):
return f"{event.app_class_name()}: New apartment available with name {get_args(event)['name']}"
elif (
hasattr(event.action, "app")
and event.app_class_name() == "CabApp"
and event.function_name() in ["cancel_ride", "user_cancel_ride"]
and event.event_type in [EventType.USER, EventType.ENV]
):
return f"{event.app_class_name()}: ride cancelled"
elif (
hasattr(event.action, "app")
and event.app_class_name() == "CabApp"
and event.function_name() == "update_ride_status"
):
return f"{event.app_class_name()}: ride status updated"
elif (
hasattr(event.action, "app")
and event.app_class_name() == "CabApp"
and event.function_name() == "end_ride"
):
return f"{event.app_class_name()}: ride completed"
elif (
hasattr(event.action, "app")
and event.app_class_name() in ["Calendar", "CalendarApp"]
and event.function_name() == "add_calendar_event_by_attendee"
):
return f"{event.app_class_name()}: New calendar event added by {get_args(event)['who_add']}"
elif (
hasattr(event.action, "app")
and event.app_class_name() in ["Calendar", "CalendarApp"]
and event.function_name() == "delete_calendar_event_by_attendee"
):
return f"{event.app_class_name()}: Calendar event deleted by {get_args(event)['who_delete']}"
else:
return None
else:
return None