|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from datetime import datetime |
| 4 | +from typing import Annotated |
| 5 | + |
| 6 | +from agents import Agent, RunContextWrapper, StopAtTools, function_tool |
| 7 | +from chatkit.agents import AgentContext, ClientToolCall |
| 8 | +from chatkit.types import ( |
| 9 | + AssistantMessageContent, |
| 10 | + AssistantMessageItem, |
| 11 | + ProgressUpdateEvent, |
| 12 | + ThreadItemDoneEvent, |
| 13 | +) |
| 14 | +from pydantic import BaseModel, ConfigDict, Field |
| 15 | + |
| 16 | +from ..data.metro_map_store import Line, MetroMap, MetroMapStore, Station |
| 17 | +from ..memory_store import MemoryStore |
| 18 | +from ..request_context import RequestContext |
| 19 | +from ..widgets.line_select_widget import build_line_select_widget |
| 20 | + |
| 21 | +INSTRUCTIONS = """ |
| 22 | + You are a concise metro planner helping city planners update the Orbital Transit map. |
| 23 | + Give short answers, list 2–3 options, and highlight the lines or interchanges involved. |
| 24 | +
|
| 25 | + Before recommending a route, sync the latest map with the provided tools. Cite line |
| 26 | + colors when helpful (e.g., "take Red then Blue at Central Exchange"). |
| 27 | +
|
| 28 | + When the user asks what to do next, reply with 2 concise follow-up ideas and pick one to lead with. |
| 29 | + Default to actionable options like adding another station on the same line or explaining how to travel |
| 30 | + from the newly added station to a nearby destination. |
| 31 | +
|
| 32 | + When the user mentions a station, always call the `get_map` tool to sync the latest map before responding. |
| 33 | +
|
| 34 | + When a user wants to add a station (e.g. "I would like to add a new metro station." or "Add another station"): |
| 35 | + - If the user did not specify a line, you MUST call `show_line_selector` with a message prompting them to choose one |
| 36 | + from the list of lines. You must NEVER ask the user to choose a line without calling `show_line_selector` first. |
| 37 | + This applies even if you just added a station—treat each new "add a station" turn as needing a fresh line selection |
| 38 | + unless the user explicitly included the line in that same turn or in the latest message via <LINE_SELECTED>. |
| 39 | + - If the user replies with a number to pick one of your follow-up options AND that option involves adding a station, |
| 40 | + treat this as a fresh station-add request and immediately call `show_line_selector` before asking anything else. |
| 41 | + - If the user did not specify a station name, ask them to enter a name. |
| 42 | + - If the user did not specify whether to add the station to the end of the line or the beginning, ask them to choose one. |
| 43 | + - When you have all the information you need, call the `add_station` tool with the station name, line id, and append flag. |
| 44 | +
|
| 45 | + Describing: |
| 46 | + - After a new station has been added, describe it to the user in a whimsical and poetic sentence. |
| 47 | + - When describing a station to the user, omit the station id and coordinates. |
| 48 | + - When describing a line to the user, omit the line id and color. |
| 49 | +
|
| 50 | + When a user wants to plan a route: |
| 51 | + - If the user did not specify a starting or detination station, ask them to choose them from the list of stations. |
| 52 | + - Provide a one-sentence route, the estimated travel time, and points of interest along the way. |
| 53 | + - Avoid over-explaining and stay within the given station list. |
| 54 | +
|
| 55 | + Custom tags: |
| 56 | + - <LINE_SELECTED>{line_id}</LINE_SELECTED> - when the user has selected a line, you can use this tag to reference the line id. |
| 57 | + When this is the latest message, acknowledge the selection. |
| 58 | + - <STATION_TAG>...</STATION_TAG> - contains full station details (id, name, description, coordinates, and served lines with ids/colors/orientations). |
| 59 | + Use the data inside the tag directly; do not call `get_station` just to resolve a tagged station. |
| 60 | +""" |
| 61 | + |
| 62 | + |
| 63 | +class MetroAgentContext(AgentContext): |
| 64 | + model_config = ConfigDict(arbitrary_types_allowed=True) |
| 65 | + store: Annotated[MemoryStore, Field(exclude=True)] |
| 66 | + metro: Annotated[MetroMapStore, Field(exclude=True)] |
| 67 | + request_context: Annotated[RequestContext, Field(exclude=True)] |
| 68 | + |
| 69 | + |
| 70 | +class MapResult(BaseModel): |
| 71 | + map: MetroMap |
| 72 | + |
| 73 | + |
| 74 | +class LineListResult(BaseModel): |
| 75 | + lines: list[Line] |
| 76 | + |
| 77 | + |
| 78 | +class StationListResult(BaseModel): |
| 79 | + stations: list[Station] |
| 80 | + |
| 81 | + |
| 82 | +class LineDetailResult(BaseModel): |
| 83 | + line: Line |
| 84 | + stations: list[Station] |
| 85 | + |
| 86 | + |
| 87 | +class StationDetailResult(BaseModel): |
| 88 | + station: Station |
| 89 | + lines: list[Line] |
| 90 | + |
| 91 | + |
| 92 | +@function_tool(description_override="Show a clickable widget listing metro lines.") |
| 93 | +async def show_line_selector(ctx: RunContextWrapper[MetroAgentContext], message: str): |
| 94 | + widget = build_line_select_widget(ctx.context.metro.list_lines()) |
| 95 | + await ctx.context.stream( |
| 96 | + ThreadItemDoneEvent( |
| 97 | + item=AssistantMessageItem( |
| 98 | + thread_id=ctx.context.thread.id, |
| 99 | + id=ctx.context.generate_id("message"), |
| 100 | + created_at=datetime.now(), |
| 101 | + content=[AssistantMessageContent(text=message)], |
| 102 | + ), |
| 103 | + ) |
| 104 | + ) |
| 105 | + await ctx.context.stream_widget(widget) |
| 106 | + |
| 107 | + |
| 108 | +@function_tool(description_override="Load the latest metro map with lines and stations.") |
| 109 | +async def get_map(ctx: RunContextWrapper[MetroAgentContext]) -> MapResult: |
| 110 | + print("[TOOL CALL] get_map") |
| 111 | + metro_map = ctx.context.metro.get_map() |
| 112 | + await ctx.context.stream(ProgressUpdateEvent(text="Retrieving the latest metro map...")) |
| 113 | + return MapResult(map=metro_map) |
| 114 | + |
| 115 | + |
| 116 | +@function_tool(description_override="List all metro lines with their colors and endpoints.") |
| 117 | +async def list_lines(ctx: RunContextWrapper[MetroAgentContext]) -> LineListResult: |
| 118 | + print("[TOOL CALL] list_lines") |
| 119 | + return LineListResult(lines=ctx.context.metro.list_lines()) |
| 120 | + |
| 121 | + |
| 122 | +@function_tool(description_override="List all stations and which lines serve them.") |
| 123 | +async def list_stations(ctx: RunContextWrapper[MetroAgentContext]) -> StationListResult: |
| 124 | + print("[TOOL CALL] list_stations") |
| 125 | + return StationListResult(stations=ctx.context.metro.list_stations()) |
| 126 | + |
| 127 | + |
| 128 | +@function_tool(description_override="Get the ordered stations for a specific line.") |
| 129 | +async def get_line_route( |
| 130 | + ctx: RunContextWrapper[MetroAgentContext], |
| 131 | + line_id: str, |
| 132 | +) -> LineDetailResult: |
| 133 | + print("[TOOL CALL] get_line_route", line_id) |
| 134 | + line = ctx.context.metro.find_line(line_id) |
| 135 | + if not line: |
| 136 | + raise ValueError(f"Line '{line_id}' was not found.") |
| 137 | + stations = ctx.context.metro.stations_for_line(line_id) |
| 138 | + return LineDetailResult(line=line, stations=stations) |
| 139 | + |
| 140 | + |
| 141 | +@function_tool(description_override="Look up a single station and the lines serving it.") |
| 142 | +async def get_station( |
| 143 | + ctx: RunContextWrapper[MetroAgentContext], |
| 144 | + station_id: str, |
| 145 | +) -> StationDetailResult: |
| 146 | + print("[TOOL CALL] get_station", station_id) |
| 147 | + station = ctx.context.metro.find_station(station_id) |
| 148 | + if not station: |
| 149 | + raise ValueError(f"Station '{station_id}' was not found.") |
| 150 | + lines = [ctx.context.metro.find_line(line_id) for line_id in station.lines] |
| 151 | + return StationDetailResult( |
| 152 | + station=station, |
| 153 | + lines=[line for line in lines if line], |
| 154 | + ) |
| 155 | + |
| 156 | + |
| 157 | +@function_tool( |
| 158 | + description_override=( |
| 159 | + """Add a new station to the metro map. |
| 160 | + - `station_name`: The name of the station to add. |
| 161 | + - `line_id`: The id of the line to add the station to. Should be one of the ids returned by list_lines. |
| 162 | + - `append`: Whether to add the station to the end of the line or the beginning. Defaults to True. |
| 163 | + """ |
| 164 | + ) |
| 165 | +) |
| 166 | +async def add_station( |
| 167 | + ctx: RunContextWrapper[MetroAgentContext], |
| 168 | + station_name: str, |
| 169 | + line_id: str, |
| 170 | + append: bool = True, |
| 171 | +) -> MapResult: |
| 172 | + station_name = station_name.strip().title() |
| 173 | + print(f"[TOOL CALL] add_station: {station_name} to {line_id}") |
| 174 | + await ctx.context.stream(ProgressUpdateEvent(text="Adding station...")) |
| 175 | + try: |
| 176 | + updated_map, new_station = ctx.context.metro.add_station(station_name, line_id, append) |
| 177 | + ctx.context.client_tool_call = ClientToolCall( |
| 178 | + name="add_station", |
| 179 | + arguments={ |
| 180 | + "stationId": new_station.id, |
| 181 | + "map": updated_map.model_dump(mode="json"), |
| 182 | + }, |
| 183 | + ) |
| 184 | + return MapResult(map=updated_map) |
| 185 | + except Exception as e: |
| 186 | + print(f"[ERROR] add_station: {e}") |
| 187 | + await ctx.context.stream( |
| 188 | + ThreadItemDoneEvent( |
| 189 | + item=AssistantMessageItem( |
| 190 | + thread_id=ctx.context.thread.id, |
| 191 | + id=ctx.context.generate_id("message"), |
| 192 | + created_at=datetime.now(), |
| 193 | + content=[ |
| 194 | + AssistantMessageContent( |
| 195 | + text=f"There was an error adding **{station_name}**" |
| 196 | + ) |
| 197 | + ], |
| 198 | + ), |
| 199 | + ) |
| 200 | + ) |
| 201 | + raise |
| 202 | + |
| 203 | + |
| 204 | +metro_map_agent = Agent[MetroAgentContext]( |
| 205 | + name="metro_map", |
| 206 | + instructions=INSTRUCTIONS, |
| 207 | + model="gpt-4o-mini", |
| 208 | + tools=[ |
| 209 | + # Retrieve map data |
| 210 | + get_map, |
| 211 | + list_lines, |
| 212 | + list_stations, |
| 213 | + get_line_route, |
| 214 | + get_station, |
| 215 | + # Respond with a widget |
| 216 | + show_line_selector, |
| 217 | + # Update the metro map |
| 218 | + add_station, |
| 219 | + ], |
| 220 | + # Stop inference after client tool call or widget output |
| 221 | + tool_use_behavior=StopAtTools( |
| 222 | + stop_at_tool_names=[ |
| 223 | + add_station.name, |
| 224 | + show_line_selector.name, |
| 225 | + ] |
| 226 | + ), |
| 227 | +) |
0 commit comments