# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import random
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from are.simulation.apps.app import App
from are.simulation.tool_utils import OperationType, app_tool, data_tool, env_tool
from are.simulation.types import EventType, disable_events, event_registered
from are.simulation.utils import get_state_dict, type_check, uuid_hex
# Fallback text returned by CabApp.cancel_ride when no custom message is given.
DEFAULT_RIDE_CANCELLED_MESSAGE: str = "The ride has been cancelled. Sorry for the inconvenience."
[docs]
@dataclass
class Ride:
ride_id: str = field(default_factory=lambda: uuid.uuid4().hex)
status: str | None = None
service_type: str | None = None
start_location: str | None = None
end_location: str | None = None
price: float | None = None
duration: float | None = None
time_stamp: float | None = None
distance_km: float | None = None
delay: float | None = None
delay_history: list[dict] = field(default_factory=list)
def __post_init__(self):
if self.ride_id is None or len(self.ride_id) == 0:
self.ride_id = uuid.uuid4().hex
[docs]
def set_booked(self):
self.status = "BOOKED"
self.delay_history.append({"delay": self.delay, "time_stamp": self.time_stamp})
[docs]
def update_delay(self, current_time_stamp: float, rng: random.Random):
if self.time_stamp is None or self.delay is None:
raise ValueError("time_stamp or delay is not set")
delta_time = current_time_stamp - self.time_stamp
self.delay = min(0, self.delay - delta_time * (1 + rng.uniform(-1.5, 1)))
self.delay_history.append({"delay": self.delay, "time_stamp": self.time_stamp})
def __str__(self):
return f"""
Ride ID: {self.ride_id}
Status: {self.status}
Service Type: {self.service_type}
Start Location: {self.start_location}
End Location: {self.end_location}
Distance: {self.distance_km} kms
Price: ${self.price:.2f}
Duration: {"N/A" if self.duration is None else self.duration / 60} mins
Timestamp: {self.time_stamp}
Delay: {"N/A" if self.delay is None else self.delay / 60} mins
"""
@dataclass
class RideHistory:
    """A named collection of rides."""

    # Label identifying this history.
    name: str
    # Rides belonging to this history.
    rides: list[Ride]
@dataclass
class OnGoingRide:
    """A named holder for at most one ride."""

    # Label identifying this holder.
    name: str
    # The held ride, or None when there is none.
    ride: Ride | None = None
@dataclass
class CabApp(App):
    """
    A cab service application that manages and facilitates ride requests and bookings. This class provides
    functionality for creating, reading, updating, and canceling rides, as well as calculating fares and
    handling ride history.

    The CabApp maintains rides in a structured format, allowing users to book rides, view current ride status,
    and retrieve ride history. Each ride is represented by a unique Ride object, containing relevant details
    about the journey.

    Key Features:
        - Ride Management: Create, book, cancel, and retrieve ride details
        - Quotation System: Calculate fare estimates based on distance, service type, and historical data
        - Ride History: Track past rides and access details for each ride
        - Delay Management: Update and record delays for ongoing rides
        - State Persistence: Save and load application state to retain ride and quotation history

    Notes:
        - All ride attributes are expected to conform to specific data types (e.g., price as float, distance
          as float)
        - Ride IDs are automatically generated upon creation
        - The class allows for the cancellation of rides by both users and drivers
        - Fare calculations consider historical pricing trends and maximum service distances
        - The distance calculation is currently a mock function; integration with a real mapping API is
          recommended for accurate distance measurements
    """

    name: str | None = None
    # Every quotation ever produced (booked or not); reused to keep
    # distances and prices consistent across repeated requests.
    quotation_history: list[Ride] = field(default_factory=list)
    # The currently booked ride, if any (at most one at a time).
    on_going_ride: Ride | None = None
    # Rides that were booked or added, in chronological order of insertion.
    ride_history: list[Ride] = field(default_factory=list)
    MESSAGE_CANCEL: str = DEFAULT_RIDE_CANCELLED_MESSAGE
    # Per service type: seat count, price per km, base delay and max distance.
    d_service_config: dict[str, dict[str, float]] = field(
        default_factory=lambda: {
            "Default": {
                "nb_seats": 4,
                "price_per_km": 1.0,
                "base_delay_min": 5,
                "max_distance_km": 25,
            },
            "Premium": {
                "nb_seats": 4,
                "price_per_km": 2.0,
                "base_delay_min": 3,
                "max_distance_km": 25,
            },
            "Van": {
                "nb_seats": 6,
                "price_per_km": 1.5,
                "base_delay_min": 7,
                "max_distance_km": 25,
            },
        }
    )

    def __post_init__(self):
        super().__init__(self.name)

    def get_state(self) -> dict[str, Any]:
        """Return the serialisable state: both histories and the service config."""
        return get_state_dict(
            self,
            [
                "ride_history",
                "quotation_history",
                "d_service_config",
            ],
        )

    @staticmethod
    def _rides_from_state(entries) -> list[Ride]:
        """Build Ride objects from state dicts, ignoring 'cab_app' and 'seed' keys."""
        ignored_keys = ("cab_app", "seed")
        return [
            Ride(**{key: value for key, value in entry.items() if key not in ignored_keys})
            for entry in entries
        ]

    def load_state(self, state_dict: dict[str, Any]):
        """
        Restore both histories and the service config from ``state_dict``.

        The incoming dictionaries are left untouched (the 'cab_app' and 'seed'
        keys are filtered into a copy instead of being popped), and the
        histories are replaced rather than extended so that loading the same
        state twice does not duplicate rides.

        :param state_dict: mapping with 'ride_history', 'quotation_history'
            and 'd_service_config' entries
        """
        self.ride_history = self._rides_from_state(state_dict["ride_history"])
        self.quotation_history = self._rides_from_state(
            state_dict["quotation_history"]
        )
        self.d_service_config = state_dict["d_service_config"]

    def reset(self):
        """Reset the app: clear histories, service config and the on-going ride."""
        super().reset()
        self.ride_history = []
        self.quotation_history = []
        # NOTE(review): this leaves no service type available until a config
        # is assigned again (e.g. by load_state) - confirm this is intended.
        self.d_service_config = {}
        self.on_going_ride = None

    def _parse_ride_time(self, ride_time: str | None = None) -> tuple[str, float]:
        """
        Parse and validate ride time, providing default if None.

        :param ride_time: Optional ride time string in 'YYYY-MM-DD HH:MM:SS' format
        :returns: Tuple of (formatted_time_string, timestamp_float)
        :raises ValueError: If ride_time format is invalid
        """
        time_format = "%Y-%m-%d %H:%M:%S"
        # Default to the simulation's current time, rendered in UTC.
        if ride_time is None:
            ride_time = datetime.fromtimestamp(
                self.time_manager.time(), tz=timezone.utc
            ).strftime(time_format)
        # The string carries no zone information; it is interpreted as UTC.
        try:
            parsed = datetime.strptime(ride_time, time_format)
        except ValueError as err:
            raise ValueError(
                "Invalid datetime format for the ride time. Please use YYYY-MM-DD HH:MM:SS"
            ) from err
        return ride_time, parsed.replace(tzinfo=timezone.utc).timestamp()

    def calculate_price(
        self,
        start_location: str,
        end_location: str,
        distance_km: float,
        service_type: str,
        time_stamp,
    ) -> float:
        """
        Compute the price of a ride.

        If the first quotation in history for the same route and service type
        has a (non-zero) price and a time stamp, the new price is that price
        perturbed by a random relative variance; otherwise the price is
        ``distance_km`` times the service's ``price_per_km``.

        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :param distance_km: distance of the ride in kilometers
        :param service_type: key into the service configuration
        :param time_stamp: time stamp of the requested ride
        :returns: computed price
        """
        previous = next(
            (
                ride
                for ride in self.quotation_history
                if ride.start_location == start_location
                and ride.end_location == end_location
                and ride.service_type == service_type
            ),
            None,
        )
        # A previous quotation without a time stamp cannot be aged; fall back
        # to distance-based pricing instead of failing on None arithmetic.
        if (
            previous is not None
            and previous.price
            and previous.time_stamp is not None
        ):
            variance = 0.01 * (time_stamp - previous.time_stamp) / 3600  # 1% per hour
            # NOTE(review): these bounds keep the variance in [0.5, 1.5], so
            # the 1%/hour term only matters after 50 hours and the drawn
            # factor can make the price negative - confirm the bounds.
            variance = min(max(variance, 0.5), 1.5)  # bounded
            return previous.price * (1 + self.rng.uniform(-variance, variance))
        return distance_km * self.d_service_config[service_type]["price_per_km"]

    @type_check
    @data_tool()
    @event_registered(operation_type=OperationType.WRITE)
    def add_new_ride(
        self,
        service_type: str,
        start_location: str,
        end_location: str,
        price: float,
        duration: float = 0.0,
        time_stamp: float = 0.0,
        distance_km: float = 0.0,
    ) -> str:
        """
        Add a new ride to the ride history.
        :param service_type: type of service (Default, Premium, Van)
        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :param price: price of the ride
        :param duration: duration in minutes of the ride
        :param time_stamp: time stamp of the ride
        :param distance_km: distance in kilometers of the ride
        :return: ride id of the added ride if successful, otherwise raise Exception.
        """
        ride = Ride(
            ride_id=uuid_hex(self.rng),
            status="BOOKED",
            service_type=service_type,
            start_location=start_location,
            end_location=end_location,
            price=price,
            duration=duration,
            time_stamp=time_stamp,
            distance_km=distance_km,
        )
        # The same object is shared by both histories, so later quotations
        # for this route can reuse its distance and price.
        self.ride_history.append(ride)
        self.quotation_history.append(ride)
        return ride.ride_id

    @type_check
    @app_tool()
    @event_registered(operation_type=OperationType.READ)
    def get_quotation(
        self,
        start_location: str,
        end_location: str,
        service_type: str,
        ride_time: str | None = None,
    ) -> Ride:
        """
        Calculates the price and estimated delay for a ride.
        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :param service_type: type of service (Default, Premium, Van)
        :param ride_time: the time of the ride in the format 'YYYY-MM-DD HH:MM:SS'. If None, the current time is used.
        :returns: Ride with all the information: start_location, end_location, service_type, price, delay, distance, duration, the time_stamp.
        """
        _, time_stamp = self._parse_ride_time(ride_time)
        if service_type not in self.d_service_config:
            raise ValueError("Invalid service type.")
        service = self.d_service_config[service_type]
        # The order of the random draws below is part of the behaviour
        # (seeded simulation): distance, price, delay, duration, ride id.
        distance_km = self.calculate_distance(start_location, end_location)
        if distance_km > service["max_distance_km"]:
            raise ValueError("Distance exceeds maximum allowed.")
        price = self.calculate_price(
            start_location, end_location, distance_km, service_type, time_stamp
        )
        # NOTE(review): delay and duration are computed here in minutes,
        # while Ride.__str__ divides both by 60 and Ride.update_delay
        # subtracts elapsed time stamps - confirm the intended unit.
        delay = service["base_delay_min"] + self.rng.randint(1, 5)
        duration = (
            distance_km / 50 * 60 + 10 * self.rng.random()
        )  # Assuming average speed of 50 km/h
        ride = Ride(
            ride_id=uuid_hex(self.rng),
            service_type=service_type,
            start_location=start_location,
            end_location=end_location,
            price=price,
            duration=duration,
            time_stamp=time_stamp,
            distance_km=distance_km,
            delay=delay,
        )
        self.quotation_history.append(ride)
        return ride

    @type_check
    @app_tool()
    @data_tool()
    @event_registered(operation_type=OperationType.READ)
    def list_rides(
        self,
        start_location: str,
        end_location: str,
        ride_time: str | None = None,
    ) -> list[Ride]:
        """
        Lists all rides available between two locations.
        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :param ride_time: the time of the ride. If None, the current time is used.
        :returns: list of Ride objects
        """
        # Resolve the time once so every service type is quoted for the same instant.
        ride_time_str, _ = self._parse_ride_time(ride_time)
        return [
            self.get_quotation(
                start_location,
                end_location,
                service_type,
                ride_time=ride_time_str,
            )
            for service_type in self.d_service_config
        ]

    @type_check
    @app_tool()
    @data_tool()
    @event_registered(operation_type=OperationType.WRITE)
    def order_ride(
        self,
        start_location: str,
        end_location: str,
        service_type: str,
        ride_time: str | None = None,
    ) -> Ride:
        """
        Orders a ride and returns the ride details.
        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :param service_type: type of service (Default, Premium, Van)
        :param ride_time: the time of the ride
        :returns: booked ride, represented by a Ride object
        """
        if self.on_going_ride is not None:
            raise ValueError("You have an on-going ride.")
        ride_time_str, _ = self._parse_ride_time(ride_time)
        # Note that here the app looks for a cab but the user is not aware of the delay.
        ride = self.get_quotation(
            start_location, end_location, service_type, ride_time_str
        )
        ride.set_booked()
        self.ride_history.append(ride)
        self.on_going_ride = ride
        return ride

    @data_tool()
    @env_tool()
    @type_check
    @event_registered(operation_type=OperationType.WRITE, event_type=EventType.ENV)
    def cancel_ride(
        self, who_cancel: str = "driver", message: str | None = None
    ) -> str:
        """
        The current ride is cancelled (by user or by driver).
        :param who_cancel: who cancel the ride, either 'driver' or 'user'
        :param message: optional message to send to the user
        :returns: message
        """
        if who_cancel not in ["driver", "user"]:
            raise ValueError("who_cancel must be either 'driver' or 'user'.")
        if self.on_going_ride is None:
            raise ValueError("You have no on-going ride.")
        # Invariant: the on-going ride is the most recent entry of the history.
        assert self.on_going_ride == self.ride_history[-1]
        self.on_going_ride.status = "CANCELLED"
        self.on_going_ride = None
        return message if message else self.MESSAGE_CANCEL

    @type_check
    @app_tool()
    @event_registered(operation_type=OperationType.WRITE)
    def user_cancel_ride(self):
        """
        Cancel the current ride.
        """
        message = "Ride has been cancelled, sorry to see you go."
        # The inner call is made with events disabled, so only this method's
        # own registration applies.
        with disable_events():
            message = self.cancel_ride(who_cancel="user", message=message)
        return message

    @env_tool()
    @type_check
    @event_registered(operation_type=OperationType.WRITE, event_type=EventType.ENV)
    def end_ride(self) -> str:
        """
        End the current ride.
        :returns: "Ride has been completed."
        """
        if self.on_going_ride is None:
            raise ValueError("You have no on-going ride.")
        # Invariant: the on-going ride is the most recent entry of the history.
        assert self.on_going_ride == self.ride_history[-1]
        self.on_going_ride.status = "COMPLETED"
        self.on_going_ride = None
        return "Ride has been completed."

    @env_tool()
    @type_check
    @event_registered(operation_type=OperationType.WRITE, event_type=EventType.ENV)
    def update_ride_status(self, status: str, message: str | None = None):
        """
        Update the status of the current ride.
        :param status: new status of the ride. Must be one of "DELAYED", "IN_PROGRESS", "ARRIVED_AT_PICKUP".
        :param message: optional message from the driver.
        :returns: new status of the ride
        """
        if status not in ["DELAYED", "IN_PROGRESS", "ARRIVED_AT_PICKUP"]:
            raise ValueError(
                "status must be one of 'DELAYED', 'IN_PROGRESS', 'ARRIVED_AT_PICKUP'"
            )
        if self.on_going_ride is None:
            raise ValueError("You have no on-going ride.")
        # Invariant: the on-going ride is the most recent entry of the history.
        assert self.on_going_ride == self.ride_history[-1]
        self.on_going_ride.status = status
        if message:
            return f"Ride status has been updated to {status}. Message from your driver: {message}"
        return f"Ride status has been updated to {status}."

    @type_check
    @app_tool()
    @event_registered(operation_type=OperationType.READ)
    def get_current_ride_status(self) -> Ride:
        """
        Check the status for the current ride ordered.
        :returns: ride details, represented by a Ride object
        """
        if not self.on_going_ride:
            raise ValueError("No ride ordered.")
        # Reading the status also refreshes the ride's delay (and consumes
        # one draw from the app's random generator).
        self.on_going_ride.update_delay(self.time_manager.time(), self.rng)
        return self.on_going_ride

    @type_check
    @app_tool()
    @data_tool()
    @event_registered(operation_type=OperationType.READ)
    def get_ride(self, idx: int):
        """
        Gets a specific ride from the ride history.
        :param idx: index of the ride to retrieve
        :returns: ride details
        """
        # Negative indices are rejected rather than counted from the end.
        if 0 <= idx < len(self.ride_history):
            return self.ride_history[idx]
        raise IndexError("Ride does not exist.")

    @type_check
    @app_tool()
    @data_tool()
    @event_registered(operation_type=OperationType.READ)
    def get_ride_history(self, offset: int = 0, limit: int = 10) -> dict[str, Any]:
        """
        Gets a list of rides from the ride history starting from a specified offset.
        :param offset: starting point to retrieve rides from, default is 0.
        :param limit: maximum number of rides to retrieve, default is 10.
        :returns: dictionary of ride details, where the key is the index of the ride in the ride history and the value is the ride details, with additional metadata about the range of rides retrieved and total number of rides.
        """
        total = len(self.ride_history)
        ride_history_subset = self.ride_history[offset : offset + limit]
        return {
            "rides": {
                offset + idx: ride for idx, ride in enumerate(ride_history_subset)
            },
            "range": (
                offset,
                min(offset + limit, total),
            ),
            "total": total,
        }

    @type_check
    @app_tool()
    @data_tool()
    @event_registered(operation_type=OperationType.READ)
    def get_ride_history_length(self) -> int:
        """
        Gets the length of the ride history.
        :returns: length of the ride history
        """
        return len(self.ride_history)

    def get_distance_from_history(self, start_location, end_location):
        """
        Look up the distance of the first quotation recorded for a route.

        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :returns: ``distance_km`` of the first matching quotation, or None
            when the route has never been quoted
        """
        for ride in self.quotation_history:
            if (
                ride.start_location == start_location
                and ride.end_location == end_location
            ):
                return ride.distance_km
        return None

    def calculate_distance(self, start_location: str, end_location: str) -> float:
        """
        Mock function to calculate the distance between two locations.

        :param start_location: starting point of the ride
        :param end_location: ending point of the ride
        :returns: distance in kilometers
        """
        distance = self.get_distance_from_history(start_location, end_location)
        # A falsy distance (None, or the 0.0 default of add_new_ride) counts
        # as unknown and is replaced by a fresh mock value.
        if not distance:
            # In a real-world scenario, you would integrate with a mapping service API to calculate the distance
            distance = self.rng.uniform(5, 20)  # Mock distance between 5 and 20 km
        return distance

    def delete_future_data(self, timestamp):
        """
        Delete all future data from the ride history.

        Rides and quotations whose time stamp is later than ``timestamp`` are
        dropped from both histories. Entries without a time stamp cannot be
        placed in time and are kept (previously they raised TypeError).

        :param timestamp: timestamp to delete data after
        """

        def is_not_future(ride: Ride) -> bool:
            return ride.time_stamp is None or ride.time_stamp <= timestamp

        self.ride_history = [ride for ride in self.ride_history if is_not_future(ride)]
        self.quotation_history = [
            ride for ride in self.quotation_history if is_not_future(ride)
        ]