diff --git a/apps/ask-gateway/app/chat_service.py b/apps/ask-gateway/app/chat_service.py index a7aef650..5a96be26 100644 --- a/apps/ask-gateway/app/chat_service.py +++ b/apps/ask-gateway/app/chat_service.py @@ -14,23 +14,36 @@ from .mcp_client import McpClientError, McpHttpClient from .models import AskStreamRequest, ToolCall from .response_synthesizer import synthesize_final_response +from .usage_tracker import ( + resolve_model_for_user_async, + record_usage_async, + get_user_usage_async, +) +from . import supabase_store -MAX_TOOL_ITERATIONS = 30 +MAX_TOOL_ITERATIONS = 20 _SYSTEM_PROMPT = """\ You are an AI course assistant for Princeton University students. You help students \ find courses, understand workload, compare options, and make informed decisions. -You have access to tools that search courses, get course details, evaluations, \ -instructor info, and more. Use them to answer accurately. +The upcoming term is Fall 2026 (term code 1272). Unless the user specifies otherwise, \ +default to searching and discussing courses for Fall 2026. The current term is Spring 2026 (1264). Guidelines: - Always use tools to look up real data. Do not fabricate course information. - After receiving tool results, synthesize a helpful, conversational response. - When comparing courses, highlight key differences (rating, workload, schedule). -- Format course codes as "DEPT NNN" (e.g., COS 226, not COS226). - Keep responses concise but thorough. Use bullet points and bold for readability. -- If a course is not found, say so honestly and suggest alternatives. +- When searching for courses, prefer term 1272 (Fall 2026) unless the user asks about a different term. +""" + +_SCHEDULE_PROMPT_ADDENDUM = """ + +You also have access to the user's TigerJunction (junction.tigerapps.org) schedule. +- Get their schedules with get_user_schedules (no userId needed — you are already authenticated) +When the user asks about "my schedule", "my courses", or wants to add/remove/manage courses, use tools. +When the user wants to find courses that fit their schedule, use search_courses with the scheduleId parameter — this combines all search filters (department, text, days, time, instructor, distribution) with schedule conflict checking. """ @@ -120,15 +133,37 @@ async def _stream_agentic( request_id = request_id or str(uuid.uuid4()) conversation_id = payload.conversationId or str(uuid.uuid4()) prompt = payload.messages[-1].content - mcp_client = McpHttpClient(self._settings) + mcp_url = self._settings.junction_mcp_url if payload.netid else None + mcp_client = McpHttpClient(self._settings, netid=payload.netid, mcp_url=mcp_url) llm_client = OpenAiLlmClient(self._settings) session_id: str | None = None + # Quota enforcement + quota_before: dict | None = None + effective_model: str | None = payload.model + if payload.netid: + quota_before = await resolve_model_for_user_async(payload.netid) + if quota_before["blocked"]: + yield sse_event( + "quota_exhausted", + { + "percentUsed": 100, + "resetSeconds": quota_before["resetSeconds"], + "requestId": request_id, + }, + ) + return + effective_model = quota_before["model"] + try: yield sse_event("status", {"phase": "starting", "requestId": request_id}) + system_prompt = _SYSTEM_PROMPT + if payload.netid: + system_prompt += _SCHEDULE_PROMPT_ADDENDUM + messages: list[dict[str, Any]] = [ - {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "system", "content": system_prompt}, *[m.model_dump() for m in payload.messages], ] @@ -139,6 +174,12 @@ async def _stream_agentic( # list_tools initializes the session, so capture it session_id = mcp_client._session_id collected_usage: dict[str, Any] | None = None + # Accumulate usage across all LLM iterations (tool-calling loop) + total_cost = 0.0 + total_input_tokens = 0 + total_output_tokens = 0 + # Track tool calls for persistence + persisted_tool_events: list[dict[str, Any]] = [] for iteration in range(MAX_TOOL_ITERATIONS): if is_disconnected(): @@ -150,13 +191,18 @@ async def _stream_agentic( finish_reason: str | None = None async for chunk in llm_client.stream_chat( - messages=messages, tools=llm_tools, model=payload.model + messages=messages, tools=llm_tools, model=effective_model ): if chunk.get("type") == "done": break if chunk.get("usage"): collected_usage = chunk["usage"] + total_cost += collected_usage.get("cost") or 0 + total_input_tokens += collected_usage.get("prompt_tokens") or 0 + total_output_tokens += ( + collected_usage.get("completion_tokens") or 0 + ) choices = chunk.get("choices", []) if not choices: @@ -200,11 +246,47 @@ async def _stream_agentic( # since some models like Gemini use "stop" even with tool calls). if not collected_tool_calls: usage = { - "inputTokens": (collected_usage or {}).get("prompt_tokens", 0), - "outputTokens": (collected_usage or {}).get( - "completion_tokens", 0 - ), + "inputTokens": total_input_tokens, + "outputTokens": total_output_tokens, } + + # Record cost and get updated quota + quota_after: dict | None = None + if payload.netid: + if total_cost > 0: + await record_usage_async(payload.netid, total_cost) + quota_after = await get_user_usage_async(payload.netid) + + # Save conversation to Supabase + conv_title = ( + payload.messages[0].content[:80] + if payload.messages + else "New chat" + ) + await supabase_store.save_message( + conversation_id, payload.netid, conv_title, "user", prompt + ) + # Save tool calls/results + for te in persisted_tool_events: + await supabase_store.save_message( + conversation_id, + payload.netid, + conv_title, + te["type"], + json.dumps(te, default=str), + ) + await supabase_store.save_message( + conversation_id, + payload.netid, + conv_title, + "assistant", + collected_content, + cost=total_cost if total_cost > 0 else None, + input_tokens=total_input_tokens or None, + output_tokens=total_output_tokens or None, + model=effective_model, + ) + yield sse_event( "status", { @@ -213,15 +295,22 @@ async def _stream_agentic( **({"sessionId": session_id} if session_id else {}), }, ) - yield sse_event( - "done", - { - "conversationId": conversation_id, - "requestId": request_id, - **({"sessionId": session_id} if session_id else {}), - "usage": usage, - }, - ) + + done_data: dict[str, Any] = { + "conversationId": conversation_id, + "requestId": request_id, + **({"sessionId": session_id} if session_id else {}), + "usage": usage, + } + if quota_after is not None: + done_data["quota"] = { + "percentUsed": quota_after["percentUsed"], + "tier": quota_after["tier"], + "tierChanged": quota_before is not None + and quota_before["tier"] != quota_after["tier"], + "resetSeconds": quota_after["resetSeconds"], + } + yield sse_event("done", done_data) return yield sse_event( @@ -267,10 +356,25 @@ async def _stream_agentic( "sessionId": session_id, }, ) + persisted_tool_events.append( + { + "type": "tool_call", + "name": tool_name, + "arguments": tool_args, + } + ) result = await asyncio.wait_for( mcp_client.call_tool(tool_name, tool_args), timeout=self._settings.tool_timeout_seconds, ) + persisted_tool_events.append( + { + "type": "tool_result", + "name": tool_name, + "ok": True, + "result": result, + } + ) yield sse_event( "tool_result", { @@ -354,8 +458,24 @@ async def _stream_deterministic( request_id = request_id or str(uuid.uuid4()) conversation_id = payload.conversationId or str(uuid.uuid4()) prompt = payload.messages[-1].content - mcp_client = McpHttpClient(self._settings) + mcp_url = self._settings.junction_mcp_url if payload.netid else None + mcp_client = McpHttpClient(self._settings, netid=payload.netid, mcp_url=mcp_url) session_id: str | None = None + + # Quota enforcement (deterministic doesn't call LLM, but still check) + if payload.netid: + det_quota = await resolve_model_for_user_async(payload.netid) + if det_quota["blocked"]: + yield sse_event( + "quota_exhausted", + { + "percentUsed": 100, + "resetSeconds": det_quota["resetSeconds"], + "requestId": request_id, + }, + ) + return + try: yield sse_event("status", {"phase": "starting", "requestId": request_id}) tool_calls = _plan_tools(prompt, payload.term) @@ -432,18 +552,24 @@ async def _stream_deterministic( **({"sessionId": session_id} if session_id else {}), }, ) - yield sse_event( - "done", - { - "conversationId": conversation_id, - "requestId": request_id, - **({"sessionId": session_id} if session_id else {}), - "usage": { - "inputTokens": 0, - "outputTokens": len(response_text.split()), - }, + det_done_data: dict[str, Any] = { + "conversationId": conversation_id, + "requestId": request_id, + **({"sessionId": session_id} if session_id else {}), + "usage": { + "inputTokens": 0, + "outputTokens": len(response_text.split()), }, - ) + } + if payload.netid: + det_q = await get_user_usage_async(payload.netid) + det_done_data["quota"] = { + "percentUsed": det_q["percentUsed"], + "tier": det_q["tier"], + "tierChanged": False, + "resetSeconds": det_q["resetSeconds"], + } + yield sse_event("done", det_done_data) except asyncio.CancelledError: yield sse_event( "error", @@ -505,5 +631,3 @@ def _extract_reasoning(delta: dict[str, Any]) -> str: if isinstance(reasoning, str): return reasoning return "" - - diff --git a/apps/ask-gateway/app/config.py b/apps/ask-gateway/app/config.py index 23e20733..2897a2e2 100644 --- a/apps/ask-gateway/app/config.py +++ b/apps/ask-gateway/app/config.py @@ -20,6 +20,7 @@ def _env_bool(name: str, default: bool) -> bool: class Settings: gateway_api_token: str = os.getenv("ASK_GATEWAY_API_TOKEN", "") mcp_url: str = os.getenv("JUNCTION_MCP_URL", "http://localhost:3000/mcp") + junction_mcp_url: str = os.getenv("JUNCTION_MCP_URL_SCHEDULE", "http://localhost:3000/junction/mcp") mcp_token: str = os.getenv("JUNCTION_MCP_TOKEN", "") mcp_protocol_version: str = os.getenv("MCP_PROTOCOL_VERSION", "2025-03-26") tool_timeout_seconds: float = float(os.getenv("ASK_TOOL_TIMEOUT_SECONDS", "10")) @@ -33,3 +34,5 @@ class Settings: ask_llm_timeout_seconds: float = float(os.getenv("ASK_LLM_TIMEOUT_SECONDS", "12")) ask_llm_planner_enabled: bool = _env_bool("ASK_LLM_PLANNER_ENABLED", False) ask_llm_synthesis_enabled: bool = _env_bool("ASK_LLM_SYNTHESIS_ENABLED", False) + supabase_url: str = os.getenv("SUPABASE_URL", "") + supabase_service_role_key: str = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "") diff --git a/apps/ask-gateway/app/llm_client.py b/apps/ask-gateway/app/llm_client.py index ec408ec9..46d22bd5 100644 --- a/apps/ask-gateway/app/llm_client.py +++ b/apps/ask-gateway/app/llm_client.py @@ -48,6 +48,7 @@ async def stream_chat( "model": model or self._settings.ask_llm_model, "messages": messages, "stream": True, + "stream_options": {"include_usage": True}, } if tools: request["tools"] = tools diff --git a/apps/ask-gateway/app/main.py b/apps/ask-gateway/app/main.py index 1415cfd4..d22aec8f 100644 --- a/apps/ask-gateway/app/main.py +++ b/apps/ask-gateway/app/main.py @@ -10,6 +10,8 @@ from .chat_service import ChatService from .config import Settings from .models import AskStreamRequest +from .usage_tracker import get_user_usage_async +from . import supabase_store app = FastAPI(title="Ask Gateway", version="1.0.0") logger = logging.getLogger("ask-gateway") @@ -36,6 +38,46 @@ async def health() -> dict[str, str]: return {"status": "ok"} +@app.get("/ask/quota") +async def get_quota( + netid: str, + authorization: str | None = Header(default=None), + settings: Settings = Depends(get_settings), +) -> dict: + _validate_gateway_auth(settings, authorization) + if not netid: + raise HTTPException(status_code=400, detail="netid is required") + return await get_user_usage_async(netid) + + +@app.get("/ask/conversations") +async def list_conversations( + netid: str, + authorization: str | None = Header(default=None), + settings: Settings = Depends(get_settings), +) -> list: + _validate_gateway_auth(settings, authorization) + if not netid: + raise HTTPException(status_code=400, detail="netid is required") + return await supabase_store.list_conversations(netid) + + +@app.get("/ask/conversations/{conv_id}/messages") +async def get_conversation_messages( + conv_id: str, + netid: str, + authorization: str | None = Header(default=None), + settings: Settings = Depends(get_settings), +) -> list: + _validate_gateway_auth(settings, authorization) + if not netid: + raise HTTPException(status_code=400, detail="netid is required") + messages = await supabase_store.get_conversation_messages(conv_id, netid) + if messages is None: + raise HTTPException(status_code=404, detail="Conversation not found") + return messages + + @app.post("/ask/stream") async def ask_stream( payload: AskStreamRequest, diff --git a/apps/ask-gateway/app/mcp_client.py b/apps/ask-gateway/app/mcp_client.py index 4ee57e0f..f816eb3c 100644 --- a/apps/ask-gateway/app/mcp_client.py +++ b/apps/ask-gateway/app/mcp_client.py @@ -22,8 +22,16 @@ class McpClientError(Exception): class McpHttpClient: - def __init__(self, settings: Settings) -> None: + def __init__( + self, + settings: Settings, + *, + netid: str | None = None, + mcp_url: str | None = None, + ) -> None: self._settings = settings + self._netid = netid + self._mcp_url = mcp_url or settings.mcp_url self._session_id: str | None = None self._client = httpx.AsyncClient( timeout=httpx.Timeout(settings.tool_timeout_seconds, connect=settings.connect_timeout_seconds) @@ -58,7 +66,7 @@ async def list_tools(self) -> list[dict[str, Any]]: if self._session_id is None: await self.initialize() - cache_key = self._settings.mcp_url + cache_key = self._mcp_url cached = _tools_cache.get(cache_key) if cached is not None: openai_tools, _, ts = cached @@ -103,7 +111,7 @@ async def close(self) -> None: try: if self._session_id: await self._client.delete( - self._settings.mcp_url, + self._mcp_url, headers=self._headers(include_session=True), ) finally: @@ -116,7 +124,7 @@ def _next(self) -> int: async def _post(self, payload: dict[str, Any]) -> httpx.Response: response = await self._client.post( - self._settings.mcp_url, + self._mcp_url, headers=self._headers(include_session=True), json=payload, ) @@ -131,6 +139,8 @@ def _headers(self, include_session: bool) -> dict[str, str]: } if self._settings.mcp_token: headers["authorization"] = f"Bearer {self._settings.mcp_token}" + if self._netid: + headers["x-user-netid"] = self._netid if include_session and self._session_id: headers["mcp-session-id"] = self._session_id headers["mcp-protocol-version"] = self._settings.mcp_protocol_version diff --git a/apps/ask-gateway/app/models.py b/apps/ask-gateway/app/models.py index 9a0a572b..db1059a2 100644 --- a/apps/ask-gateway/app/models.py +++ b/apps/ask-gateway/app/models.py @@ -15,6 +15,7 @@ class AskStreamRequest(BaseModel): term: int | None = None model: str | None = None messages: list[ChatMessage] = Field(min_length=1) + netid: str | None = None class ToolCall(BaseModel): diff --git a/apps/ask-gateway/app/supabase_store.py b/apps/ask-gateway/app/supabase_store.py new file mode 100644 index 00000000..941d6cd6 --- /dev/null +++ b/apps/ask-gateway/app/supabase_store.py @@ -0,0 +1,185 @@ +"""Supabase persistence for quotas, conversations, and messages. + +Uses httpx to call Supabase PostgREST API directly — no extra dependencies. +""" + +from __future__ import annotations + +import logging +import os +from typing import Any + +import httpx + +logger = logging.getLogger("ask-gateway.store") + +_SUPABASE_URL = os.getenv("SUPABASE_URL", "").rstrip("/") +_SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "") + + +def _headers() -> dict[str, str]: + return { + "apikey": _SUPABASE_KEY, + "Authorization": f"Bearer {_SUPABASE_KEY}", + "Content-Type": "application/json", + "Prefer": "return=representation", + } + + +def _rest_url(table: str) -> str: + return f"{_SUPABASE_URL}/rest/v1/{table}" + + +def _enabled() -> bool: + return bool(_SUPABASE_URL and _SUPABASE_KEY) + + +# ── Quotas ───────────────────────────────────────────────────────────── + +async def get_quota_spent(netid: str, time_window: str) -> float: + if not _enabled(): + return 0.0 + async with httpx.AsyncClient() as client: + r = await client.get( + _rest_url("ask_quotas"), + headers={**_headers(), "Accept": "application/json"}, + params={"netid": f"eq.{netid}", "time_window": f"eq.{time_window}", "select": "spent"}, + ) + if r.status_code == 200: + rows = r.json() + if rows: + return float(rows[0].get("spent", 0)) + return 0.0 + + +async def upsert_quota(netid: str, time_window: str, spent: float) -> None: + if not _enabled(): + return + async with httpx.AsyncClient() as client: + await client.post( + _rest_url("ask_quotas"), + headers={**_headers(), "Prefer": "resolution=merge-duplicates"}, + json={"netid": netid, "time_window": time_window, "spent": round(spent, 6)}, + ) + + +# ── Conversations ────────────────────────────────────────────────────── + +async def upsert_conversation(conv_id: str, netid: str, title: str) -> None: + if not _enabled(): + return + async with httpx.AsyncClient() as client: + await client.post( + _rest_url("ask_conversations"), + headers={**_headers(), "Prefer": "resolution=merge-duplicates"}, + json={"id": conv_id, "netid": netid, "title": title[:100], "updated_at": "now()"}, + ) + + +async def update_conversation_timestamp(conv_id: str) -> None: + if not _enabled(): + return + async with httpx.AsyncClient() as client: + await client.patch( + _rest_url("ask_conversations"), + headers=_headers(), + params={"id": f"eq.{conv_id}"}, + json={"updated_at": "now()"}, + ) + + +async def save_message( + conv_id: str, + netid: str, + title: str, + role: str, + content: str, + cost: float | None = None, + input_tokens: int | None = None, + output_tokens: int | None = None, + model: str | None = None, +) -> None: + """Save a single message and ensure the conversation exists.""" + if not _enabled(): + return + try: + # Ensure conversation exists + await upsert_conversation(conv_id, netid, title) + + # Insert message + msg: dict[str, Any] = { + "conversation_id": conv_id, + "role": role, + "content": content, + } + if cost is not None: + msg["cost"] = round(cost, 6) + if input_tokens is not None: + msg["input_tokens"] = input_tokens + if output_tokens is not None: + msg["output_tokens"] = output_tokens + if model is not None: + msg["model"] = model + + async with httpx.AsyncClient() as client: + await client.post( + _rest_url("ask_messages"), + headers=_headers(), + json=msg, + ) + + # Update conversation timestamp + await update_conversation_timestamp(conv_id) + except Exception as e: + logger.error("Failed to save message: %s", e) + + +async def list_conversations(netid: str, limit: int = 20) -> list[dict[str, Any]]: + if not _enabled(): + return [] + async with httpx.AsyncClient() as client: + r = await client.get( + _rest_url("ask_conversations"), + headers={**_headers(), "Accept": "application/json"}, + params={ + "netid": f"eq.{netid}", + "select": "id,title,created_at,updated_at", + "order": "updated_at.desc", + "limit": str(limit), + }, + ) + if r.status_code == 200: + return r.json() + return [] + + +async def get_conversation_messages( + conv_id: str, netid: str +) -> list[dict[str, Any]] | None: + """Get messages for a conversation, verifying ownership.""" + if not _enabled(): + return None + + async with httpx.AsyncClient() as client: + # Verify ownership + r = await client.get( + _rest_url("ask_conversations"), + headers={**_headers(), "Accept": "application/json"}, + params={"id": f"eq.{conv_id}", "netid": f"eq.{netid}", "select": "id"}, + ) + if r.status_code != 200 or not r.json(): + return None + + # Fetch messages + r = await client.get( + _rest_url("ask_messages"), + headers={**_headers(), "Accept": "application/json"}, + params={ + "conversation_id": f"eq.{conv_id}", + "select": "role,content,cost,input_tokens,output_tokens,model,created_at", + "order": "created_at.asc", + }, + ) + if r.status_code == 200: + return r.json() + return None diff --git a/apps/ask-gateway/app/usage_tracker.py b/apps/ask-gateway/app/usage_tracker.py new file mode 100644 index 00000000..8fcdf136 --- /dev/null +++ b/apps/ask-gateway/app/usage_tracker.py @@ -0,0 +1,125 @@ +"""Per-user usage tracking with tiered quota system, persisted to Supabase. + +Tier 1: Claude Sonnet 4.6 — $0.50 budget +Tier 2: Claude Haiku 4.5 — $0.25 budget (auto-downgrade when Tier 1 exhausted) +Exhausted: No more requests until next 8-hour window + +Resets at 12:00 AM, 8:00 AM, 4:00 PM US Eastern. +""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, timezone, timedelta +from typing import Any + +from . import supabase_store + +logger = logging.getLogger("ask-gateway.usage") + +TIER1_MODEL = "anthropic/claude-sonnet-4.6" +TIER2_MODEL = "anthropic/claude-haiku-4.5" +TIER1_BUDGET = 0.50 +TIER2_BUDGET = 0.25 +TOTAL_BUDGET = TIER1_BUDGET + TIER2_BUDGET # $0.75 +WINDOW_HOURS = 8 + +_ET_OFFSET = timezone(timedelta(hours=-5)) + + +def _now_et() -> datetime: + return datetime.now(_ET_OFFSET) + + +def _get_window_id() -> str: + now = _now_et() + window_index = now.hour // WINDOW_HOURS + return f"{now.strftime('%Y-%m-%d')}-{window_index}" + + +def _seconds_until_next_window() -> int: + now = _now_et() + window_index = now.hour // WINDOW_HOURS + next_hour = (window_index + 1) * WINDOW_HOURS + if next_hour >= 24: + tomorrow = now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1) + return max(1, int((tomorrow - now).total_seconds())) + next_boundary = now.replace(hour=next_hour, minute=0, second=0, microsecond=0) + return max(1, int((next_boundary - now).total_seconds())) + + +def _build_status(spent: float) -> dict[str, Any]: + if spent < TIER1_BUDGET: + tier = 1 + model = TIER1_MODEL + blocked = False + elif spent < TOTAL_BUDGET: + tier = 2 + model = TIER2_MODEL + blocked = False + else: + tier = "exhausted" + model = TIER2_MODEL + blocked = True + + percent_used = min(100, round((spent / TOTAL_BUDGET) * 100, 1)) + reset_seconds = _seconds_until_next_window() + + return { + "spent": round(spent, 4), + "tier": tier, + "model": model, + "blocked": blocked, + "percentUsed": percent_used, + "resetSeconds": reset_seconds, + "window": _get_window_id(), + } + + +def get_user_usage(netid: str) -> dict[str, Any]: + """Synchronous wrapper — reads from Supabase.""" + window = _get_window_id() + try: + spent = asyncio.get_event_loop().run_until_complete( + supabase_store.get_quota_spent(netid, window) + ) + except RuntimeError: + # If no event loop, create one (shouldn't happen in FastAPI) + spent = asyncio.run(supabase_store.get_quota_spent(netid, window)) + return _build_status(spent) + + +async def get_user_usage_async(netid: str) -> dict[str, Any]: + """Async version for use in async handlers.""" + window = _get_window_id() + spent = await supabase_store.get_quota_spent(netid, window) + return _build_status(spent) + + +def resolve_model_for_user(netid: str) -> dict[str, Any]: + return get_user_usage(netid) + + +async def resolve_model_for_user_async(netid: str) -> dict[str, Any]: + return await get_user_usage_async(netid) + + +async def record_usage_async(netid: str, cost: float) -> None: + if cost <= 0: + return + window = _get_window_id() + current = await supabase_store.get_quota_spent(netid, window) + new_spent = current + cost + await supabase_store.upsert_quota(netid, window, new_spent) + logger.info("usage.record netid=%s cost=%.4f total=%.4f window=%s", netid, cost, new_spent, window) + + +def record_usage(netid: str, cost: float) -> None: + """Synchronous wrapper.""" + if cost <= 0: + return + try: + asyncio.get_event_loop().run_until_complete(record_usage_async(netid, cost)) + except RuntimeError: + asyncio.run(record_usage_async(netid, cost)) diff --git a/apps/engine/.mcp.json b/apps/engine/.mcp.json new file mode 100644 index 00000000..42694d08 --- /dev/null +++ b/apps/engine/.mcp.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "supabase": { + "type": "http", + "url": "https://mcp.supabase.com/mcp" + } + } +} \ No newline at end of file diff --git a/apps/engine/bun.lockb b/apps/engine/bun.lockb index 0f6cf4e5..85625db7 100755 Binary files a/apps/engine/bun.lockb and b/apps/engine/bun.lockb differ diff --git a/apps/engine/package.json b/apps/engine/package.json index cd79c634..a82a43e1 100644 --- a/apps/engine/package.json +++ b/apps/engine/package.json @@ -27,11 +27,13 @@ "@fastify/swagger-ui": "^5.2.3", "@fastify/websocket": "^11.2.0", "@modelcontextprotocol/sdk": "^1.26.0", + "@supabase/supabase-js": "^2.100.1", "dotenv": "^17.2.3", "drizzle-orm": "^0.44.6", "fastify": "^5.5.0", "fastify-plugin": "^5.0.1", "jsdom": "^26.1.0", + "mongodb": "^7.1.1", "pg": "^8.16.3", "redis": "^4.6.0", "zod": "^4.3.6" diff --git a/apps/engine/src/app.ts b/apps/engine/src/app.ts index b9f14dca..7c54e0be 100644 --- a/apps/engine/src/app.ts +++ b/apps/engine/src/app.ts @@ -19,6 +19,8 @@ import instructorsRoutes from "./routes/api/instructors.ts"; import evaluationsRoutes from "./routes/api/evaluations.ts"; import redisPlugin from "./plugins/redis.ts"; import dbPlugin from "./plugins/db.ts"; +import supabasePlugin from "./plugins/supabase.ts"; +import snatchDbPlugin from "./plugins/snatch-db.ts"; import snatchRoutes from "./routes/snatch.ts"; import mcpRoutes from "./routes/mcp.ts"; @@ -75,6 +77,8 @@ export async function build(opts?: { logger?: boolean }): Promise { + async ({ term, department, query, dist, days, daysMatch, startAfter, startBefore, instructor, scheduleId, limit: maxResults, offset }) => { const resultLimit = Math.min(maxResults ?? 50, 200); const resultOffset = offset ?? 0; const conditions = []; @@ -73,6 +81,9 @@ export function registerCourseTools(server: McpServer, db: NodePgDatabase) { ); } + // Fetch more results when schedule filtering is active (some will be filtered out) + const fetchLimit = scheduleId ? resultLimit * 3 : resultLimit; + const courses = await db .select({ id: schema.courses.id, @@ -89,7 +100,115 @@ export function registerCourseTools(server: McpServer, db: NodePgDatabase) { .where(conditions.length > 0 ? and(...conditions) : undefined) .orderBy(asc(schema.courses.code)) .offset(resultOffset) - .limit(resultLimit); + .limit(fetchLimit); + + // If scheduleId is provided and junction context available, post-filter for conflicts + if (scheduleId != null && junctionCtx) { + const { supabase, authContext } = junctionCtx; + + // Resolve user + if (!authContext?.netid) { + return { + content: [{ type: "text" as const, text: "scheduleId filter requires authenticated user (x-user-netid header)." }], + isError: true, + }; + } + + const { data: userId } = await supabase.rpc("get_user_id_by_netid", { netid: authContext.netid }); + if (!userId) { + return { + content: [{ type: "text" as const, text: `No TigerJunction account found for NetID '${authContext.netid}'.` }], + isError: true, + }; + } + + // Verify schedule ownership + const { data: sched } = await supabase + .from("schedules") + .select("id, user_id") + .eq("id", scheduleId) + .single(); + + if (!sched || sched.user_id !== userId) { + return { + content: [{ type: "text" as const, text: "Schedule not found or does not belong to authenticated user." }], + isError: true, + }; + } + + // Get existing courses in schedule + const { data: existingAssocs } = await supabase + .from("course_schedule_associations") + .select("course_id") + .eq("schedule_id", scheduleId); + + const existingCourseIds = new Set((existingAssocs ?? []).map((a: { course_id: number }) => a.course_id)); + + // Get occupied time slots from schedule's sections (via engine DB for consistency) + const occupiedSlots: { days: number; startTime: number; endTime: number }[] = []; + if (existingCourseIds.size > 0) { + // Map Supabase course IDs to engine course IDs via listing_id + term + const { data: supabaseCourses } = await supabase + .from("courses") + .select("listing_id, term") + .in("id", [...existingCourseIds]); + + for (const sc of supabaseCourses ?? []) { + const engineCourseId = `${sc.listing_id}-${sc.term}`; + const sections = await db + .select({ days: schema.sections.days, startTime: schema.sections.startTime, endTime: schema.sections.endTime }) + .from(schema.sections) + .where(eq(schema.sections.courseId, engineCourseId)); + occupiedSlots.push(...sections); + } + } + + // Filter out courses already in schedule and courses that conflict + const filtered = []; + for (const course of courses) { + // Skip courses already in schedule (match by listing_id + term) + const sections = await db + .select({ + title: schema.sections.title, + days: schema.sections.days, + startTime: schema.sections.startTime, + endTime: schema.sections.endTime, + }) + .from(schema.sections) + .where(eq(schema.sections.courseId, course.id)); + + if (sections.length === 0) { filtered.push(course); continue; } + + // Group by section type + const byType = new Map(); + for (const s of sections) { + const type = s.title.match(/^([A-Z]+)/)?.[1] ?? s.title; + if (!byType.has(type)) byType.set(type, []); + byType.get(type)!.push(s); + } + + // Course fits if each section type has at least one non-conflicting option + const fits = [...byType.values()].every((group) => + group.some((s) => + !occupiedSlots.some((o) => + (s.days & o.days) !== 0 && s.startTime < o.endTime && o.startTime < s.endTime + ) + ) + ); + + if (fits) filtered.push(course); + if (filtered.length >= resultLimit) break; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify({ count: filtered.length, scheduleFiltered: true, scheduleId, courses: filtered }, null, 2), + }, + ], + }; + } return { content: [ diff --git a/apps/engine/src/mcp/tools/junction-schedules.ts b/apps/engine/src/mcp/tools/junction-schedules.ts new file mode 100644 index 00000000..975c26dd --- /dev/null +++ b/apps/engine/src/mcp/tools/junction-schedules.ts @@ -0,0 +1,784 @@ +import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { SupabaseClient } from "@supabase/supabase-js"; +import { z } from "zod"; +import { termCodeToName, valueToDays, valueToTime } from "../helpers.js"; +import type { AuthContext } from "../context.js"; + +// Supabase status encoding: 0=open, 1=closed, 2=canceled +const STATUS_MAP: Record = { 0: "open", 1: "closed", 2: "canceled" }; +function statusName(code: number | null): string { + return code != null ? (STATUS_MAP[code] ?? "unknown") : "unknown"; +} + +interface TimeSlot { + days: number; + startTime: number; + endTime: number; +} + +function timeSlotsOverlap(a: TimeSlot, b: TimeSlot): boolean { + if ((a.days & b.days) === 0) return false; + return a.startTime < b.endTime && b.startTime < a.endTime; +} + +function sectionTypePrefix(title: string): string { + const match = title.match(/^([A-Z]+)/); + return match ? match[1] : title; +} + +function formatSectionForDisplay(s: { + title: string; + days: number; + start_time: number; + end_time: number; + room?: string | null; + status: number | null; +}) { + return { + sectionTitle: s.title, + days: valueToDays(s.days), + startTime: valueToTime(s.start_time), + endTime: valueToTime(s.end_time), + room: s.room ?? null, + status: statusName(s.status), + }; +} + +/** + * Resolve the authenticated user's Supabase UUID via: + * NetID → Supabase RPC get_user_id_by_netid → Supabase UUID + * + * Uses a Supabase database function that looks up auth.users by email + * ({netid}@princeton.edu), avoiding dependency on the engine's local users table. + */ +async function resolveSupabaseUserId( + supabase: SupabaseClient, + authContext?: AuthContext +): Promise<{ supabaseUuid?: string; error?: string }> { + if (!authContext?.netid) { + return { error: "Missing user context. Provide x-user-netid header." }; + } + + const { data, error } = await supabase.rpc("get_user_id_by_netid", { + netid: authContext.netid, + }); + + if (error) { + return { error: `Failed to resolve NetID '${authContext.netid}': ${error.message}` }; + } + + if (!data) { + return { + error: `No TigerJunction account found for NetID '${authContext.netid}'. Create a TigerJunction account first to access schedule features.`, + }; + } + + return { supabaseUuid: data as string }; +} + +export function registerJunctionScheduleTools( + server: McpServer, + supabase: SupabaseClient, + authContext?: AuthContext +) { + // ── get_user_schedules ────────────────────────────────────────────── + server.tool( + "get_user_schedules", + "Get all schedules for the authenticated user, optionally filtered by term.", + { + term: z + .number() + .optional() + .describe( + "Term code to filter by. Mapping: 1232=Fall 2022, 1234=Spring 2023, 1242=Fall 2023, 1244=Spring 2024, 1252=Fall 2024, 1254=Spring 2025, 1262=Fall 2025, 1264=Spring 2026 (current). Codes ending in 2=Fall, ending in 4=Spring." + ), + }, + async ({ term }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + let query = supabase + .from("schedules") + .select("id, title, term, is_public") + .eq("user_id", auth.supabaseUuid) + .order("term", { ascending: true }); + + if (term != null) { + query = query.eq("term", term); + } + + const { data: schedules, error } = await query; + + if (error) { + return { content: [{ type: "text" as const, text: `Failed to fetch schedules: ${error.message}` }], isError: true }; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify( + { + count: schedules?.length ?? 0, + schedules: (schedules ?? []).map((s) => ({ + ...s, + termName: termCodeToName(s.term), + })), + }, + null, + 2 + ), + }, + ], + }; + } + ); + + // ── get_schedule_details ──────────────────────────────────────────── + server.tool( + "get_schedule_details", + "Get full details of a schedule including its courses, sections, meeting times, and any time conflicts.", + { + scheduleId: z.number().describe("Schedule ID"), + }, + async ({ scheduleId }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + // Fetch the schedule and verify ownership + const { data: schedule, error: schedError } = await supabase + .from("schedules") + .select("id, title, term, is_public, user_id") + .eq("id", scheduleId) + .single(); + + if (schedError || !schedule) { + return { content: [{ type: "text" as const, text: "Schedule not found." }], isError: true }; + } + if (schedule.user_id !== auth.supabaseUuid) { + return { + content: [{ type: "text" as const, text: "Forbidden: schedule does not belong to authenticated user." }], + isError: true, + }; + } + + // Fetch course associations + const { data: associations } = await supabase + .from("course_schedule_associations") + .select("course_id, metadata") + .eq("schedule_id", scheduleId); + + if (!associations || associations.length === 0) { + return { + content: [ + { + type: "text" as const, + text: JSON.stringify( + { + schedule: { id: schedule.id, title: schedule.title, term: schedule.term, termName: termCodeToName(schedule.term) }, + courses: [], + sections: [], + conflicts: "No courses in this schedule", + }, + null, + 2 + ), + }, + ], + }; + } + + const courseIds = associations.map((a) => a.course_id); + + // Fetch courses + const { data: courses } = await supabase + .from("courses") + .select("id, code, title, status") + .in("id", courseIds); + + const courseMap = new Map((courses ?? []).map((c) => [c.id, c])); + + // Fetch sections for all courses + const { data: sections } = await supabase + .from("sections") + .select("id, course_id, title, days, start_time, end_time, room, status, cap, tot") + .in("course_id", courseIds); + + // Build section list with course codes and detect conflicts + const allSections: (TimeSlot & { courseCode: string; sectionTitle: string; room: string | null; status: string })[] = []; + + for (const s of sections ?? []) { + const course = courseMap.get(s.course_id); + if (!course) continue; + allSections.push({ + courseCode: course.code, + sectionTitle: s.title, + days: s.days, + startTime: s.start_time, + endTime: s.end_time, + room: s.room, + status: statusName(s.status), + }); + } + + const conflicts: string[] = []; + for (let i = 0; i < allSections.length; i++) { + for (let j = i + 1; j < allSections.length; j++) { + if (allSections[i].courseCode === allSections[j].courseCode) continue; + if (timeSlotsOverlap(allSections[i], allSections[j])) { + conflicts.push( + `${allSections[i].courseCode} (${allSections[i].sectionTitle}) overlaps with ${allSections[j].courseCode} (${allSections[j].sectionTitle})` + ); + } + } + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify( + { + schedule: { id: schedule.id, title: schedule.title, term: schedule.term, termName: termCodeToName(schedule.term) }, + courses: (courses ?? []).map((c) => ({ + id: c.id, + code: c.code, + title: c.title, + status: statusName(c.status), + })), + sections: allSections.map((s) => ({ + courseCode: s.courseCode, + sectionTitle: s.sectionTitle, + days: valueToDays(s.days), + startTime: valueToTime(s.startTime), + endTime: valueToTime(s.endTime), + room: s.room, + status: s.status, + })), + conflicts: conflicts.length > 0 ? conflicts : "No conflicts detected", + }, + null, + 2 + ), + }, + ], + }; + } + ); + + // ── verify_schedule ───────────────────────────────────────────────── + server.tool( + "verify_schedule", + "Validate a proposed schedule of courses. Checks for: time conflicts between sections, mixed terms, missing section types (e.g., no precept selected when required), closed/canceled sections, duplicate courses, and exceeding 7 courses. Returns valid=true or a list of issues.", + { + courseCodes: z + .array(z.string()) + .min(1) + .max(10) + .describe("Array of course codes to validate together (e.g., ['COS 226', 'MAT 202', 'ECO 100'])"), + term: z + .number() + .optional() + .describe("Term code. If omitted, uses the most recent term each course is offered."), + }, + async ({ courseCodes, term }) => { + const issues: string[] = []; + const resolvedCourses: { + code: string; + courseId: number; + term: number; + status: string; + sections: { title: string; type: string; days: number; startTime: number; endTime: number; status: string; cap: number; tot: number }[]; + }[] = []; + + for (const code of courseCodes) { + // Query Supabase courses by code (handle both "COS 330" and "COS330") + const noSpace = code.replace(/\s+/g, ""); + const withSpace = code.replace(/([A-Za-z])(\d)/, "$1 $2"); + let courseQuery = supabase + .from("courses") + .select("id, code, term, status") + .or(`code.ilike.${noSpace},code.ilike.${withSpace}`); + + if (term != null) { + courseQuery = courseQuery.eq("term", term); + } + + const { data: matchedCourses } = await courseQuery.order("term", { ascending: false }).limit(1); + + if (!matchedCourses || matchedCourses.length === 0) { + issues.push(`Course not found: "${code}"`); + continue; + } + + const course = matchedCourses[0]; + const courseStatus = statusName(course.status); + + // Check for duplicates + if (resolvedCourses.some((c) => c.courseId === course.id)) { + issues.push(`Duplicate course: ${course.code}`); + continue; + } + + // Fetch sections + const { data: sections } = await supabase + .from("sections") + .select("title, days, start_time, end_time, status, cap, tot") + .eq("course_id", course.id); + + resolvedCourses.push({ + code: course.code, + courseId: course.id, + term: course.term, + status: courseStatus, + sections: (sections ?? []).map((s) => ({ + title: s.title, + type: sectionTypePrefix(s.title), + days: s.days, + startTime: s.start_time, + endTime: s.end_time, + status: statusName(s.status), + cap: s.cap ?? 0, + tot: s.tot ?? 0, + })), + }); + } + + // Check: too many courses + if (resolvedCourses.length > 7) { + issues.push(`Too many courses: ${resolvedCourses.length} courses selected (max recommended is 7)`); + } + + // Check: mixed terms + const terms = [...new Set(resolvedCourses.map((c) => c.term))]; + if (terms.length > 1) { + const termNames = terms.map((t) => `${termCodeToName(t)} (${t})`).join(", "); + issues.push(`Mixed terms: courses span multiple semesters: ${termNames}`); + } + + // Check: canceled courses + for (const course of resolvedCourses) { + if (course.status === "canceled") { + issues.push(`Canceled course: ${course.code} is canceled`); + } + } + + // Check: section completeness + for (const course of resolvedCourses) { + const byType = new Map(); + for (const s of course.sections) { + if (!byType.has(s.type)) byType.set(s.type, []); + byType.get(s.type)!.push(s); + } + + for (const [type, sections] of byType) { + const allClosed = sections.every((s) => s.status === "closed"); + const allCanceled = sections.every((s) => s.status === "canceled"); + if (allCanceled) { + issues.push(`No available ${type} sections: all ${type} sections for ${course.code} are canceled`); + } else if (allClosed) { + issues.push( + `All ${type} sections full: all ${type} sections for ${course.code} are closed (${sections[0].tot}/${sections[0].cap} enrolled)` + ); + } + } + + if (course.sections.length === 0) { + issues.push(`No sections: ${course.code} has no sections listed`); + } + } + + // Check: time conflicts between courses + for (let i = 0; i < resolvedCourses.length; i++) { + for (let j = i + 1; j < resolvedCourses.length; j++) { + const a = resolvedCourses[i]; + const b = resolvedCourses[j]; + + const aByType = new Map(); + for (const s of a.sections) { + if (!aByType.has(s.type)) aByType.set(s.type, []); + aByType.get(s.type)!.push(s); + } + + const bByType = new Map(); + for (const s of b.sections) { + if (!bByType.has(s.type)) bByType.set(s.type, []); + bByType.get(s.type)!.push(s); + } + + for (const [aType, aSections] of aByType) { + for (const [bType, bSections] of bByType) { + const allConflict = aSections.every((aS) => + bSections.every((bS) => timeSlotsOverlap(aS, bS)) + ); + if (allConflict && aSections[0].days !== 0 && bSections[0].days !== 0) { + issues.push( + `Time conflict: ${a.code} ${aType} sections all conflict with ${b.code} ${bType} sections` + ); + } + } + } + } + } + + const valid = issues.length === 0; + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify( + { + valid, + courseCount: resolvedCourses.length, + term: terms.length === 1 ? { code: terms[0], name: termCodeToName(terms[0]) } : null, + courses: resolvedCourses.map((c) => ({ + code: c.code, + status: c.status, + sectionTypes: [...new Set(c.sections.map((s) => s.type))], + })), + issues: valid ? "Schedule is valid — no issues detected" : issues, + }, + null, + 2 + ), + }, + ], + }; + } + ); + + // ── WRITE TOOLS ───────────────────────────────────────────────────── + + // Helper: verify schedule ownership and return the schedule row + async function verifyScheduleOwnership( + scheduleId: number, + supabaseUuid: string + ): Promise<{ schedule?: { id: number; term: number; title: string; user_id: string }; error?: string }> { + const { data, error } = await supabase + .from("schedules") + .select("id, term, title, user_id") + .eq("id", scheduleId) + .single(); + + if (error || !data) return { error: "Schedule not found." }; + if (data.user_id !== supabaseUuid) return { error: "Forbidden: schedule does not belong to authenticated user." }; + return { schedule: data }; + } + + // Helper: resolve course by code within a term + async function resolveCourseByCode( + code: string, + term: number + ): Promise<{ course?: { id: number; code: string; title: string }; error?: string }> { + // Supabase stores codes without spaces (e.g., "COS330"), but users type "COS 330". + // Try both the original and a no-space version. + const noSpace = code.replace(/\s+/g, ""); + const withSpace = code.replace(/([A-Za-z])(\d)/, "$1 $2"); + + const { data } = await supabase + .from("courses") + .select("id, code, title") + .or(`code.ilike.${noSpace},code.ilike.${withSpace}`) + .eq("term", term) + .limit(1); + + if (!data || data.length === 0) { + return { error: `Course "${code}" not found for term ${termCodeToName(term)} (${term}).` }; + } + return { course: data[0] }; + } + + // Helper: generate metadata for a course being added to a schedule + async function generateCourseMetadata( + scheduleId: number, + courseId: number + ): Promise<{ complete: boolean; color: number; sections: string[]; confirms: Record }> { + // Pick a color: find unused 0-6 among existing courses in this schedule + const { data: existingAssocs } = await supabase + .from("course_schedule_associations") + .select("metadata") + .eq("schedule_id", scheduleId); + + const usedColors = new Map(); + for (const a of existingAssocs ?? []) { + const meta = a.metadata as { color?: number } | null; + if (meta?.color != null) { + usedColors.set(meta.color, (usedColors.get(meta.color) ?? 0) + 1); + } + } + + let color = 0; + for (let c = 0; c <= 6; c++) { + if (!usedColors.has(c)) { color = c; break; } + if (c === 6) { + // All used — pick least-used + let minCount = Infinity; + for (const [col, count] of usedColors) { + if (count < minCount) { minCount = count; color = col; } + } + } + } + + // Fetch sections for the course and extract categories + const { data: sections } = await supabase + .from("sections") + .select("title, category") + .eq("course_id", courseId); + + const categoryMap = new Map(); + for (const s of sections ?? []) { + const cat = s.category; + if (!categoryMap.has(cat)) categoryMap.set(cat, []); + categoryMap.get(cat)!.push(s.title); + } + + const sectionCategories = [...categoryMap.keys()].sort(); + + // Auto-confirm categories with exactly 1 section + const confirms: Record = {}; + for (const [cat, titles] of categoryMap) { + if (titles.length === 1) { + confirms[cat] = titles[0]; + } + } + + const complete = sectionCategories.every((cat) => confirms[cat] != null); + + return { complete, color, sections: sectionCategories, confirms }; + } + + // ── create_schedule ───────────────────────────────────────────────── + server.tool( + "create_schedule", + "Create a new schedule for the authenticated user.", + { + term: z.number().describe("Term code for the schedule (e.g., 1272 for Fall 2026)."), + title: z.string().optional().describe("Schedule title (default: 'My Schedule')."), + }, + async ({ term, title }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + const { data, error } = await supabase + .from("schedules") + .insert({ user_id: auth.supabaseUuid, term, title: title ?? "My Schedule" }) + .select("id, title, term") + .single(); + + if (error) { + return { content: [{ type: "text" as const, text: `Failed to create schedule: ${error.message}` }], isError: true }; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify({ created: true, schedule: { ...data, termName: termCodeToName(data.term) } }, null, 2), + }, + ], + }; + } + ); + + // ── add_course_to_schedule ────────────────────────────────────────── + server.tool( + "add_course_to_schedule", + "Add a course to an existing schedule. Automatically generates metadata (color, section categories). The course must exist for the schedule's term.", + { + scheduleId: z.number().describe("Schedule ID to add the course to."), + courseCode: z.string().describe("Course code (e.g., 'COS 226')."), + }, + async ({ scheduleId, courseCode }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + const ownership = await verifyScheduleOwnership(scheduleId, auth.supabaseUuid); + if (!ownership.schedule) { + return { content: [{ type: "text" as const, text: ownership.error ?? "Schedule error." }], isError: true }; + } + + const resolved = await resolveCourseByCode(courseCode, ownership.schedule.term); + if (!resolved.course) { + return { content: [{ type: "text" as const, text: resolved.error ?? "Course not found." }], isError: true }; + } + + // Check if already in schedule + const { data: existing } = await supabase + .from("course_schedule_associations") + .select("course_id") + .eq("schedule_id", scheduleId) + .eq("course_id", resolved.course.id) + .limit(1); + + if (existing && existing.length > 0) { + return { + content: [{ type: "text" as const, text: `${resolved.course.code} is already in this schedule.` }], + isError: true, + }; + } + + const metadata = await generateCourseMetadata(scheduleId, resolved.course.id); + + const { error } = await supabase + .from("course_schedule_associations") + .insert({ course_id: resolved.course.id, schedule_id: scheduleId, metadata }); + + if (error) { + return { content: [{ type: "text" as const, text: `Failed to add course: ${error.message}` }], isError: true }; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify( + { + added: true, + course: resolved.course, + scheduleId, + metadata, + }, + null, + 2 + ), + }, + ], + }; + } + ); + + // ── remove_course_from_schedule ───────────────────────────────────── + server.tool( + "remove_course_from_schedule", + "Remove a course from a schedule.", + { + scheduleId: z.number().describe("Schedule ID."), + courseCode: z.string().describe("Course code to remove (e.g., 'COS 226')."), + }, + async ({ scheduleId, courseCode }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + const ownership = await verifyScheduleOwnership(scheduleId, auth.supabaseUuid); + if (!ownership.schedule) { + return { content: [{ type: "text" as const, text: ownership.error ?? "Schedule error." }], isError: true }; + } + + const resolved = await resolveCourseByCode(courseCode, ownership.schedule.term); + if (!resolved.course) { + return { content: [{ type: "text" as const, text: resolved.error ?? "Course not found." }], isError: true }; + } + + const { error, count } = await supabase + .from("course_schedule_associations") + .delete() + .eq("schedule_id", scheduleId) + .eq("course_id", resolved.course.id); + + if (error) { + return { content: [{ type: "text" as const, text: `Failed to remove course: ${error.message}` }], isError: true }; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify({ removed: true, course: resolved.course, scheduleId }, null, 2), + }, + ], + }; + } + ); + + // ── rename_schedule ───────────────────────────────────────────────── + server.tool( + "rename_schedule", + "Rename a schedule.", + { + scheduleId: z.number().describe("Schedule ID."), + title: z.string().describe("New title for the schedule."), + }, + async ({ scheduleId, title }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + const ownership = await verifyScheduleOwnership(scheduleId, auth.supabaseUuid); + if (!ownership.schedule) { + return { content: [{ type: "text" as const, text: ownership.error ?? "Schedule error." }], isError: true }; + } + + const { error } = await supabase + .from("schedules") + .update({ title }) + .eq("id", scheduleId); + + if (error) { + return { content: [{ type: "text" as const, text: `Failed to rename schedule: ${error.message}` }], isError: true }; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify({ renamed: true, scheduleId, newTitle: title }, null, 2), + }, + ], + }; + } + ); + + // ── delete_schedule ───────────────────────────────────────────────── + server.tool( + "delete_schedule", + "Delete a schedule and all its course associations. This cannot be undone.", + { + scheduleId: z.number().describe("Schedule ID to delete."), + }, + async ({ scheduleId }) => { + const auth = await resolveSupabaseUserId(supabase, authContext); + if (!auth.supabaseUuid) { + return { content: [{ type: "text" as const, text: auth.error ?? "Unauthorized." }], isError: true }; + } + + const ownership = await verifyScheduleOwnership(scheduleId, auth.supabaseUuid); + if (!ownership.schedule) { + return { content: [{ type: "text" as const, text: ownership.error ?? "Schedule error." }], isError: true }; + } + + const { error } = await supabase + .from("schedules") + .delete() + .eq("id", scheduleId); + + if (error) { + return { content: [{ type: "text" as const, text: `Failed to delete schedule: ${error.message}` }], isError: true }; + } + + return { + content: [ + { + type: "text" as const, + text: JSON.stringify({ + deleted: true, + scheduleId, + title: ownership.schedule.title, + }, null, 2), + }, + ], + }; + } + ); +} diff --git a/apps/engine/src/mcp/tools/snatch.ts b/apps/engine/src/mcp/tools/snatch.ts new file mode 100644 index 00000000..9260f301 --- /dev/null +++ b/apps/engine/src/mcp/tools/snatch.ts @@ -0,0 +1,516 @@ +import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { NodePgDatabase } from "drizzle-orm/node-postgres"; +import type { Db } from "mongodb"; +import { z } from "zod"; +import { eq, ilike, asc, and } from "drizzle-orm"; +import * as schema from "../../db/schema.js"; +import { formatSection, termCodeToName } from "../helpers.js"; +import type { AuthContext } from "../context.js"; + +function getSnatchConfig(): { url: string; token: string } | null { + const url = process.env.SNATCH_URL?.trim(); + const token = process.env.SNATCH_TOKEN?.trim(); + if (!url || !token) return null; + return { url: url.replace(/\/$/, ""), token }; +} + +async function snatchFetch( + path: string, + method: "GET" | "POST" = "POST" +): Promise<{ ok: boolean; data?: any; error?: string }> { + const config = getSnatchConfig(); + if (!config) return { ok: false, error: "TigerSnatch is not configured (missing SNATCH_URL or SNATCH_TOKEN)." }; + + try { + const response = await fetch(`${config.url}${path}`, { + method, + headers: { Authorization: config.token }, + }); + const data = await response.json(); + if (!response.ok) { + return { ok: false, error: data?.message ?? `TigerSnatch returned status ${response.status}` }; + } + return { ok: true, data }; + } catch (err) { + return { ok: false, error: `Failed to reach TigerSnatch: ${err instanceof Error ? err.message : String(err)}` }; + } +} + +async function resolveClassId( + db: NodePgDatabase, + courseCode: string, + term?: number, + sectionTitle?: string +): Promise<{ classId?: number; section?: { title: string; days: number; startTime: number; endTime: number }; courseName?: string; error?: string }> { + const conditions = [ilike(schema.courses.code, courseCode)]; + if (term) conditions.push(eq(schema.courses.term, term)); + + const courses = await db + .select({ id: schema.courses.id, code: schema.courses.code, title: schema.courses.title, term: schema.courses.term }) + .from(schema.courses) + .where(and(...conditions)) + .orderBy(asc(schema.courses.term)) + .limit(1); + + if (courses.length === 0) { + return { error: `Course "${courseCode}" not found${term ? ` for ${termCodeToName(term)}` : ""}.` }; + } + + const course = courses[0]; + const sectionConditions = [eq(schema.sections.courseId, course.id)]; + if (sectionTitle) { + sectionConditions.push(ilike(schema.sections.title, sectionTitle)); + } + + const sections = await db + .select({ + id: schema.sections.id, + title: schema.sections.title, + days: schema.sections.days, + startTime: schema.sections.startTime, + endTime: schema.sections.endTime, + status: schema.sections.status, + cap: schema.sections.cap, + tot: schema.sections.tot, + }) + .from(schema.sections) + .where(and(...sectionConditions)) + .orderBy(asc(schema.sections.id)); + + if (sections.length === 0) { + return { error: `No sections found for ${course.code}${sectionTitle ? ` section ${sectionTitle}` : ""}.` }; + } + + if (sections.length === 1 || sectionTitle) { + const s = sections[0]; + return { + classId: s.id, + section: { title: s.title, days: s.days, startTime: s.startTime, endTime: s.endTime }, + courseName: `${course.code} — ${course.title}`, + }; + } + + const sectionList = sections.map((s) => { + const formatted = formatSection(s); + return `${s.title} (${formatted.days.join("")} ${formatted.startTime}–${formatted.endTime}, ${s.status}, ${s.tot}/${s.cap} enrolled)`; + }); + + return { + error: `${course.code} has ${sections.length} sections. Please specify which one:\n${sectionList.join("\n")}`, + }; +} + +export function registerSnatchTools( + server: McpServer, + db: NodePgDatabase, + authContext?: AuthContext, + snatchDb?: Db | null +) { + // ── get_snatch_subscriptions ──────────────────────────────────────── + server.tool( + "get_snatch_subscriptions", + "Get the user's current TigerSnatch notification subscriptions — classes they'll be notified about when a seat opens.", + {}, + async () => { + if (!authContext?.netid) { + return { content: [{ type: "text" as const, text: "Requires authenticated user (x-user-netid header)." }], isError: true }; + } + + const result = await snatchFetch(`/junction/get_user_data/${authContext.netid}`); + if (!result.ok) { + return { content: [{ type: "text" as const, text: result.error ?? "Failed to get subscriptions." }], isError: true }; + } + + const userData = result.data?.data; + if (userData === "missing") { + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ subscriptions: [], message: "No TigerSnatch account found. Visit tigersnatch.com to get started." }, null, 2), + }], + }; + } + + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + netid: authContext.netid, + subscriptions: userData?.waitlists ?? [], + autoResubscribe: userData?.auto_resub ?? null, + }, null, 2), + }], + }; + } + ); + + // ── subscribe_to_snatch ───────────────────────────────────────────── + server.tool( + "subscribe_to_snatch", + "Subscribe to TigerSnatch notifications for a class section. You'll get notified when a seat opens. Provide a course code and optionally a specific section (e.g., 'L01'). If the course has multiple sections and none is specified, you'll be shown the available sections to choose from.", + { + courseCode: z.string().describe("Course code (e.g., 'COS 226')."), + section: z.string().optional().describe("Section title (e.g., 'L01', 'P01'). If omitted and the course has multiple sections, available sections will be listed."), + term: z.number().optional().describe("Term code (e.g., 1272 for Fall 2026). Defaults to most recent term."), + }, + async ({ courseCode, section, term }) => { + if (!authContext?.netid) { + return { content: [{ type: "text" as const, text: "Requires authenticated user (x-user-netid header)." }], isError: true }; + } + + const resolved = await resolveClassId(db, courseCode, term, section); + if (!resolved.classId) { + return { content: [{ type: "text" as const, text: resolved.error ?? "Could not resolve class." }], isError: true }; + } + + const result = await snatchFetch(`/junction/add_to_waitlist/${authContext.netid}/${resolved.classId}`); + if (!result.ok) { + return { content: [{ type: "text" as const, text: result.error ?? "Failed to subscribe." }], isError: true }; + } + + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + subscribed: true, + course: resolved.courseName, + section: resolved.section?.title, + classId: resolved.classId, + message: `You'll be notified when a seat opens in ${resolved.courseName} (${resolved.section?.title}).`, + }, null, 2), + }], + }; + } + ); + + // ── unsubscribe_from_snatch ───────────────────────────────────────── + server.tool( + "unsubscribe_from_snatch", + "Unsubscribe from TigerSnatch notifications for a class section.", + { + courseCode: z.string().describe("Course code (e.g., 'COS 226')."), + section: z.string().optional().describe("Section title (e.g., 'L01', 'P01')."), + term: z.number().optional().describe("Term code (e.g., 1272 for Fall 2026). Defaults to most recent term."), + }, + async ({ courseCode, section, term }) => { + if (!authContext?.netid) { + return { content: [{ type: "text" as const, text: "Requires authenticated user (x-user-netid header)." }], isError: true }; + } + + const resolved = await resolveClassId(db, courseCode, term, section); + if (!resolved.classId) { + return { content: [{ type: "text" as const, text: resolved.error ?? "Could not resolve class." }], isError: true }; + } + + const result = await snatchFetch(`/junction/remove_from_waitlist/${authContext.netid}/${resolved.classId}`); + if (!result.ok) { + return { content: [{ type: "text" as const, text: result.error ?? "Failed to unsubscribe." }], isError: true }; + } + + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + unsubscribed: true, + course: resolved.courseName, + section: resolved.section?.title, + classId: resolved.classId, + }, null, 2), + }], + }; + } + ); + + // ── DEMAND / ANALYTICS TOOLS (powered by TigerSnatch MongoDB) ────── + + if (!snatchDb) return; + + // ── get_course_demand ─────────────────────────────────────────────── + server.tool( + "get_course_demand", + "Get demand signals for a course from TigerSnatch: real-time enrollment vs capacity for every section, how many students are subscribed for seat-open notifications (during add/drop), and whether the course has reserved seats. Useful for gauging how competitive a class is.", + { + courseCode: z.string().describe("Course code (e.g., 'COS 226')."), + }, + async ({ courseCode }) => { + // Search TigerSnatch mappings by displayname + const normalizedCode = courseCode.replace(/\s+/g, "").toUpperCase(); + const course = await snatchDb.collection("mappings").findOne({ + displayname: { $regex: new RegExp(`^${normalizedCode.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}`, "i") }, + }); + + if (!course) { + // Try whitespace version + const wsMatch = await snatchDb.collection("mappings").findOne({ + displayname_whitespace: { $regex: new RegExp(courseCode.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'), "i") }, + }); + if (!wsMatch) { + return { content: [{ type: "text" as const, text: `Course "${courseCode}" not found in TigerSnatch.` }], isError: true }; + } + return await buildDemandResponse(snatchDb, wsMatch.courseid, wsMatch.displayname_whitespace, wsMatch.title); + } + + return await buildDemandResponse(snatchDb, course.courseid, course.displayname_whitespace, course.title); + } + ); + + // ── get_trending_courses ──────────────────────────────────────────── + server.tool( + "get_trending_courses", + "Get the most in-demand courses on TigerSnatch. During add/drop, shows courses with the most notification subscribers (students waiting for seats). Between semesters, shows platform-wide stats and historical data.", + { + limit: z.number().optional().describe("Max courses to return (default 15)."), + }, + async ({ limit: maxResults }) => { + const resultLimit = maxResults ?? 15; + + const admin = await snatchDb.collection("admin").findOne({}); + if (!admin) { + return { content: [{ type: "text" as const, text: "TigerSnatch admin data not available." }], isError: true }; + } + + const topSubs = (admin.stats_top_subs ?? []).slice(0, resultLimit); + const hasTrendingData = topSubs.length > 0; + + if (hasTrendingData) { + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + term: admin.current_term_name, + termCode: admin.current_term_code, + trendingCourses: topSubs, + platformStats: { + totalSubscriptions: admin.stats_total_subs, + subscribedUsers: admin.stats_subbed_users, + subscribedCourses: admin.stats_subbed_courses, + subscribedSections: admin.stats_subbed_sections, + }, + lastUpdated: admin.stats_update_time, + }, null, 2), + }], + }; + } + + // Between semesters — show aggregate stats and most-enrolled courses + const topEnrolled = await snatchDb.collection("enrollments") + .find({ capacity: { $gt: 5 } }) + .sort({ enrollment: -1 }) + .limit(resultLimit) + .toArray(); + + const enriched = []; + for (const e of topEnrolled) { + const course = await snatchDb.collection("mappings").findOne({ courseid: e.courseid }); + enriched.push({ + course: course?.displayname_whitespace ?? e.courseid, + title: course?.title ?? "", + section: e.section, + enrollment: e.enrollment, + capacity: e.capacity, + fillPercent: e.capacity > 0 ? Math.round((e.enrollment / e.capacity) * 100) : 0, + }); + } + + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + term: admin.current_term_name, + termCode: admin.current_term_code, + status: "Between semesters — no active subscriptions. Showing enrollment data.", + topEnrolled: enriched, + platformStats: { + totalUsersAllTime: admin.stats_total_users, + totalNotificationsAllTime: admin.stats_total_notifs, + }, + lastUpdated: admin.stats_update_time, + }, null, 2), + }], + }; + } + ); + + // ── get_course_historical_demand ────────────────────────────────────── + server.tool( + "get_course_historical_demand", + "Get historical demand trends for a course across past semesters. Shows enrollment fill rates, closed section counts, and capacity changes over time. Answers questions like 'Is this class hard to get into?' and 'Does this course usually fill up?'. Uses engine DB data across all available terms.", + { + courseCode: z.string().describe("Course code (e.g., 'COS 226')."), + }, + async ({ courseCode }) => { + // Resolve course to listing_id via engine DB + const courseRows = await db + .select({ + listingId: schema.courses.listingId, + code: schema.courses.code, + title: schema.courses.title, + }) + .from(schema.courses) + .where(ilike(schema.courses.code, courseCode)) + .orderBy(asc(schema.courses.term)) + .limit(1); + + if (courseRows.length === 0) { + return { content: [{ type: "text" as const, text: `Course "${courseCode}" not found.` }], isError: true }; + } + + const { listingId, code, title } = courseRows[0]; + + // Get all offerings of this course across terms + const offerings = await db + .select({ + term: schema.courses.term, + courseId: schema.courses.id, + status: schema.courses.status, + }) + .from(schema.courses) + .where(eq(schema.courses.listingId, listingId)) + .orderBy(asc(schema.courses.term)); + + const termStats = []; + let totalFillRateSum = 0; + let termsWithEnrollment = 0; + let termsFullyClosed = 0; + let termsWithClosedSections = 0; + + for (const offering of offerings) { + const sections = await db + .select({ + title: schema.sections.title, + status: schema.sections.status, + cap: schema.sections.cap, + tot: schema.sections.tot, + }) + .from(schema.sections) + .where(eq(schema.sections.courseId, offering.courseId)); + + const totalCap = sections.reduce((s, sec) => s + (sec.cap ?? 0), 0); + const totalEnrolled = sections.reduce((s, sec) => s + (sec.tot ?? 0), 0); + const closedCount = sections.filter((s) => s.status === "closed").length; + const canceledCount = sections.filter((s) => s.status === "canceled").length; + const fillRate = totalCap > 0 ? Math.round((totalEnrolled / totalCap) * 100) : 0; + + // Only count terms with actual enrollment data (not future terms) + if (totalEnrolled > 0) { + totalFillRateSum += fillRate; + termsWithEnrollment++; + if (offering.status === "closed") termsFullyClosed++; + if (closedCount > 0) termsWithClosedSections++; + } + + termStats.push({ + term: offering.term, + termName: termCodeToName(offering.term), + courseStatus: offering.status, + totalEnrolled, + totalCapacity: totalCap, + fillRate: `${fillRate}%`, + sections: sections.length, + closedSections: closedCount, + canceledSections: canceledCount, + }); + } + + const avgFillRate = termsWithEnrollment > 0 + ? Math.round(totalFillRateSum / termsWithEnrollment) + : 0; + + // Determine competitiveness + let competitiveness: string; + if (avgFillRate >= 95 || termsFullyClosed >= termsWithEnrollment * 0.5) { + competitiveness = "Very Competitive — this course consistently fills up and often closes. Plan to enroll early or use TigerSnatch."; + } else if (avgFillRate >= 80) { + competitiveness = "Competitive — this course typically fills most of its seats. Early enrollment recommended."; + } else if (avgFillRate >= 50) { + competitiveness = "Moderate — this course usually has available seats but does fill up in popular sections."; + } else { + competitiveness = "Low — this course generally has plenty of available seats."; + } + + // Capacity trend + const capsOverTime = termStats + .filter((t) => t.totalCapacity > 0) + .map((t) => t.totalCapacity); + let capacityTrend = "Stable"; + if (capsOverTime.length >= 2) { + const first = capsOverTime[0]; + const last = capsOverTime[capsOverTime.length - 1]; + const change = ((last - first) / first) * 100; + if (change > 15) capacityTrend = `Growing (+${Math.round(change)}% capacity since ${termStats[0].termName})`; + else if (change < -15) capacityTrend = `Shrinking (${Math.round(change)}% capacity since ${termStats[0].termName})`; + } + + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + course: code, + title, + termsOffered: offerings.length, + termsWithEnrollmentData: termsWithEnrollment, + averageFillRate: `${avgFillRate}%`, + timesFullyClosed: `${termsFullyClosed}/${termsWithEnrollment} terms`, + timesWithClosedSections: `${termsWithClosedSections}/${termsWithEnrollment} terms`, + capacityTrend, + competitiveness, + history: termStats, + }, null, 2), + }], + }; + } + ); +} + +async function buildDemandResponse(snatchDb: Db, courseid: string, displayName: string, title: string) { + // Get enrollment data for all sections + const enrollments = await snatchDb.collection("enrollments") + .find({ courseid }) + .toArray(); + + // Get course doc for reserved seats and waitlist info + const courseDoc = await snatchDb.collection("courses").findOne({ courseid }); + + // Get waitlist sizes for each section + const sectionDemand = []; + for (const e of enrollments) { + const classKey = `class_${e.classid}`; + const classInfo = courseDoc?.[classKey] as Record | undefined; + + // Check waitlist for this class + const waitlistDoc = await snatchDb.collection("waitlists").findOne({ classid: String(e.classid) }); + const subscriberCount = Array.isArray(waitlistDoc?.netids) ? waitlistDoc.netids.length : 0; + + sectionDemand.push({ + section: e.section, + classId: e.classid, + enrollment: e.enrollment, + capacity: e.capacity, + fillPercent: e.capacity > 0 ? Math.round((e.enrollment / e.capacity) * 100) : 0, + isOpen: classInfo?.status_is_open ?? (e.enrollment < e.capacity), + subscribers: subscriberCount, + days: classInfo?.days ?? null, + startTime: classInfo?.start_time ?? null, + endTime: classInfo?.end_time ?? null, + }); + } + + const totalEnrollment = enrollments.reduce((s, e) => s + e.enrollment, 0); + const totalCapacity = enrollments.reduce((s, e) => s + e.capacity, 0); + const totalSubscribers = sectionDemand.reduce((s, d) => s + d.subscribers, 0); + + return { + content: [{ + type: "text" as const, + text: JSON.stringify({ + course: displayName, + title, + hasReservedSeats: courseDoc?.has_reserved_seats ?? false, + overallFill: totalCapacity > 0 ? `${totalEnrollment}/${totalCapacity} (${Math.round((totalEnrollment / totalCapacity) * 100)}%)` : "N/A", + totalSubscribers, + sections: sectionDemand, + demandLevel: totalSubscribers > 50 ? "Very High" : totalSubscribers > 20 ? "High" : totalSubscribers > 5 ? "Moderate" : totalSubscribers > 0 ? "Low" : "None (no active watchers)", + }, null, 2), + }], + }; +} diff --git a/apps/engine/src/plugins/snatch-db.ts b/apps/engine/src/plugins/snatch-db.ts new file mode 100644 index 00000000..d2d0b606 --- /dev/null +++ b/apps/engine/src/plugins/snatch-db.ts @@ -0,0 +1,34 @@ +import fp from "fastify-plugin"; +import type { FastifyPluginAsync } from "fastify"; +import { MongoClient, type Db } from "mongodb"; + +declare module "fastify" { + interface FastifyInstance { + snatchDb: Db | null; + } +} + +const snatchDbPlugin: FastifyPluginAsync = async (app) => { + const uri = process.env.SNATCH_DB_URI?.trim(); + + if (!uri) { + app.log.warn("SNATCH_DB_URI not set — TigerSnatch demand tools disabled."); + app.decorate("snatchDb", null); + return; + } + + const client = new MongoClient(uri, { serverSelectionTimeoutMS: 5000 }); + await client.connect(); + const db = client.db("tigersnatch"); + + app.decorate("snatchDb", db); + app.log.info("TigerSnatch MongoDB connected"); + + app.addHook("onClose", async () => { + await client.close(); + }); +}; + +export default fp(snatchDbPlugin, { + name: "snatch-db-plugin", +}); diff --git a/apps/engine/src/plugins/supabase.ts b/apps/engine/src/plugins/supabase.ts new file mode 100644 index 00000000..53b0cdf2 --- /dev/null +++ b/apps/engine/src/plugins/supabase.ts @@ -0,0 +1,32 @@ +import fp from "fastify-plugin"; +import type { FastifyPluginAsync } from "fastify"; +import { createClient, type SupabaseClient } from "@supabase/supabase-js"; + +declare module "fastify" { + interface FastifyInstance { + supabase: SupabaseClient; + } +} + +const supabasePlugin: FastifyPluginAsync = async (app) => { + const url = process.env.SUPABASE_URL?.trim(); + const serviceRoleKey = process.env.SUPABASE_SERVICE_ROLE_KEY?.trim(); + + if (!url || !serviceRoleKey) { + app.log.warn("SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY not set — Supabase plugin disabled. /junction/mcp schedule tools will not work."); + // Decorate with a dummy so Fastify doesn't throw on access + app.decorate("supabase", null as unknown as SupabaseClient); + return; + } + + const supabase = createClient(url, serviceRoleKey, { + auth: { persistSession: false, autoRefreshToken: false }, + }); + + app.decorate("supabase", supabase); + app.log.info("Supabase client initialized"); +}; + +export default fp(supabasePlugin, { + name: "supabase-plugin", +}); diff --git a/apps/engine/src/routes/mcp.ts b/apps/engine/src/routes/mcp.ts index 88c968ac..462579d3 100644 --- a/apps/engine/src/routes/mcp.ts +++ b/apps/engine/src/routes/mcp.ts @@ -135,7 +135,9 @@ const mcpRoutes: FastifyPluginAsync = async (app, opts) => { .send(rpcError(-32009, "Too many active MCP sessions for this client. Close a session and retry.")); } - const mcpServer = createMcpServer(app.db.db, auth.authContext, scope); + const supabase = scope === "junction" ? app.supabase : undefined; + const snatchDb = (scope === "junction" || scope === "snatch") ? app.snatchDb : undefined; + const mcpServer = createMcpServer(app.db.db, auth.authContext, scope, supabase, snatchDb); const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), });