Source code for are.simulation.apps.app
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import logging
import random
from abc import ABC
from enum import Enum, auto
from typing import Any, Callable
from are.simulation.time_manager import TimeManager
from are.simulation.tool_utils import AppTool, ToolAttributeName, build_tool
from are.simulation.utils import SkippableDeepCopy, add_reset
logger = logging.getLogger(__name__)
[docs]
@add_reset
class App(ABC, SkippableDeepCopy):
"""
Base class for all applications in the Meta Agents Research Environments environment.
"""
# We skip add_event_callbacks from copy and pickling because it is not serializable
# That is because after apps are registered, it holds on to a reference to the environment
# Which itself contains non serializable fields like a threading.lock instance.
# add_event seems to be legacy, but leaving it here to not break existing apps.
_skip_deepcopy_fields = ["add_event", "add_event_callbacks"]
_skip_pickle_fields = _skip_deepcopy_fields
[docs]
def __init__(self, name: str | None = None, *args, **kwargs):
super().__init__()
self.name = self.__class__.__name__ if name is None else name
self.is_state_modified = False
self.add_event_callbacks = {}
self._tool_registries: dict[ToolType, list[AppTool] | None] = {
ToolType.APP: None,
ToolType.USER: None,
ToolType.ENV: None,
ToolType.DATA: None,
}
# We can augment App behavior by adding a failure_probability, so that each tool call can fail randomly
self.failure_probability: float | None = None
self.time_manager = TimeManager()
self.set_seed(0)
[docs]
def set_seed(self, seed: int) -> None:
# Derive a new seed from the combination of the input seed and app name
# This ensures each app instance gets a unique but deterministic seed
combined_seed = f"{seed}_{self.name}"
self.seed = hash(combined_seed) % (2**32)
self.rng = random.Random(self.seed)
[docs]
def register_to_env(
self,
key: str,
add_event: Callable[[Any], None],
):
self.add_event_callbacks[key] = add_event
[docs]
def add_event(self, event: Any) -> None:
for callback in self.add_event_callbacks.values():
callback(event)
[docs]
def get_implemented_protocols(self) -> list[Protocol]:
"""
App can provide protocols, e.g. FileSystem which could be used by other apps
Returns a list of protocol names that the app implements.
"""
return []
[docs]
def connect_to_protocols(self, protocols: dict[Protocol, Any]) -> None:
"""
App can connect to other apps via protocols.
"""
pass
def reset(self):
self.rng = random.Random(self.seed)
[docs]
def set_failure_probability(self, failure_probability: float) -> None:
logger.debug(f"Setting failure_probability to {failure_probability}")
self.failure_probability = failure_probability
# When failure probability is set, we need to reset the tool registries
logger.debug("Resetting tool registries")
self._tool_registries: dict[ToolType, list[AppTool] | None] = {
ToolType.APP: None,
ToolType.USER: None,
ToolType.ENV: None,
ToolType.DATA: None,
}
[docs]
def get_tools_with_attribute(
self, attribute: ToolAttributeName, tool_type: ToolType
) -> list[AppTool]:
"""
Retrieves a list of tools that have a specific attribute from the class and its base classes.
:param attribute: The attribute to look for in the tools as a ToolAttributeName enum value.
:type attribute: ToolAttributeName
:param tool_type: The type of tool being registered.
:type tool_type: ToolType
:return: A list of AppTool objects that have the specified attribute.
:rtype: list[AppTool]
"""
# Get the string value from the enum
attr_name = attribute.value
tools = []
processed_attributes = set() # Track processed attributes
cls = self.__class__
# Iterate over the class and its base classes (MRO)
# Child classes are processed first, this ensures that if a function is overridden
# We only include the child class version
for base_cls in inspect.getmro(cls):
for attr_name, attr_value in base_cls.__dict__.items():
if attr_name in processed_attributes:
# Skip if the attribute has been overridden
continue
# Convert enum to string if needed
attr_name_str = (
attribute.value
if isinstance(attribute, ToolAttributeName)
else attribute
)
if hasattr(attr_value, attr_name_str):
logger.debug(
f"[Registering {tool_type} Tool] {attr_name} of class {base_cls.__name__}"
)
if not attr_value.__doc__:
logger.error(
f"\tDid not find doc of {attr_name} of class {base_cls.__name__} - trying base class method"
)
attr_value.__doc__ = get_base_method_doc(base_cls, attr_value)
# We only want to have random failures (when failure_probability is set) for Agent tools
failure_probability = (
self.failure_probability if tool_type == ToolType.APP else None
)
tools.append(build_tool(self, attr_value, failure_probability))
processed_attributes.add(attr_name) # Mark as processed
# Also check instance methods to get dynamically added methods
for attr_name, attr_value in self.__dict__.items():
if attr_name in processed_attributes:
continue
# Convert enum to string if needed
attr_name_str = (
attribute.value
if isinstance(attribute, ToolAttributeName)
else attribute
)
if hasattr(attr_value, attr_name_str):
logger.debug(
f"[Registering {tool_type} Tool] {attr_name} of instance {self.__class__.__name__}"
)
if not attr_value.__doc__:
logger.error(
f"\tDid not find doc of {attr_name} of instance {self.__class__.__name__}"
)
# We only want to have random failures (when failure_probability is set) for Agent tools
failure_probability = (
self.failure_probability if tool_type == ToolType.APP else None
)
tools.append(build_tool(self, attr_value, failure_probability))
processed_attributes.add(attr_name) # Mark as processed
logger.debug(
f"[Registering {tool_type} Tool] Built Tool Registry for class {cls.__name__} with {len(tools)} tools"
)
return tools
def _get_or_initialize_tools(
self, tool_type: ToolType, attribute: ToolAttributeName
) -> list[AppTool]:
"""
Helper method to get or initialize tools for a specific tool type
:param tool_type: The type of tool being registered (e.g., "Agent", "User", "Env")
:param attribute: The attribute to look for in the tools as a ToolAttributeName enum value.
:return: A list of AppTool.
"""
if self._tool_registries[tool_type] is None:
tools = self.get_tools_with_attribute(
attribute=attribute, tool_type=tool_type
)
self._tool_registries[tool_type] = tools
return tools
return self._tool_registries[tool_type] or []
[docs]
def get_tools(self) -> list[AppTool]:
"""
Retrieves the list of agent tools, initializing the tool registry if necessary.
:return: A list of AppTool objects for agents.
"""
return self._get_or_initialize_tools(ToolType.APP, ToolAttributeName.APP)
[docs]
def get_user_tools(self) -> list[AppTool]:
"""
Retrieves the list of user tools, initializing the user tool registry if necessary.
:return: A list of AppTool objects for users.
"""
return self._get_or_initialize_tools(ToolType.USER, ToolAttributeName.USER)
[docs]
def get_env_tools(self) -> list[AppTool]:
"""
Retrieves the list of environment tools, initializing the environment tool registry if necessary.
:return: A list of AppTool objects for the environment.
"""
return self._get_or_initialize_tools(ToolType.ENV, ToolAttributeName.ENV)
[docs]
def get_tool(self, tool_name: str) -> AppTool | None:
"""
Retrieves a specific tool by name, searching through all tool types.
:param tool_name: The name of the tool to retrieve.
:type tool_name: str
:return: The AppTool object with the specified name, or None if not found.
:rtype: AppTool | None
"""
# Try to find the tool in all tool types
for tool_getter in [
self.get_tools,
self.get_user_tools,
self.get_env_tools,
self.get_data_tools,
]:
try:
tools = tool_getter()
# Look for the tool with the matching name
for tool in tools:
if tool.name == f"{self.name}__{tool_name}":
return tool
except Exception:
# Continue to the next tool type if there's an error
continue
return None
[docs]
def get_data_tools(self) -> list[AppTool]:
"""
Retrieves the list of data tools, initializing the data tool registry if necessary.
:return: A list of AppTool objects for data.
"""
return self._get_or_initialize_tools(ToolType.DATA, ToolAttributeName.DATA)
[docs]
def get_base_method_doc(cls: type, method: Callable) -> str | None:
"""
Retrieves the docstring of a method from its base class if it exists.
:param cls: The class containing the method.
:type cls: type
:param method: The method whose docstring is being retrieved.
:type method: Callable
:return: The docstring of the method, or None if it does not exist in the base class.
"""
# we get the first one we find
for base in cls.__bases__:
base_method = getattr(base, method.__name__, None)
if base_method and base_method.__doc__:
return base_method.__doc__
return None