-
Notifications
You must be signed in to change notification settings - Fork 91
Better code propagation, better Chains websocket API #1918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| [project] | ||
| name = "truss" | ||
| version = "0.11.0" | ||
| version = "0.11.1rc15" | ||
| description = "A seamless bridge from model development to model delivery" | ||
| authors = [ | ||
| { name = "Pankaj Gupta", email = "[email protected]" }, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| import threading | ||
| import time | ||
| import traceback | ||
| import typing | ||
| from collections.abc import AsyncIterator | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
|
|
@@ -586,6 +587,23 @@ def __init__(self, websocket: "fastapi.WebSocket") -> None: | |
| async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: | ||
| await self._websocket.close(code=code, reason=reason) | ||
|
|
||
| async def receive(self) -> Union[str, bytes]: | ||
| try: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is pulled from lines above, since we need to expose `WebSocketDisconnect` to end users. This wrapper exists to try and hide FastAPI, but per a comment above we already knew we were leaking details via exceptions. I'd rather not come up with a separate protocol for how users would have to think about catching websocket disconnects (especially since it would be different in Truss), so for now I just pass it through. |
||
| import fastapi | ||
| except ImportError: | ||
| raise utils.make_optional_import_error("fastapi") | ||
|
|
||
| message = await self._websocket.receive() | ||
|
|
||
| if message.get("type") == "websocket.disconnect": | ||
| # NB(nikhil): Mimics FastAPI `_raise_on_disconnect`, since otherwise the user has no | ||
| # way of detecting that the client disconnected. | ||
| raise fastapi.WebSocketDisconnect(message["code"], message.get("reason")) | ||
| elif message.get("text"): | ||
| return typing.cast(str, message["text"]) | ||
| else: | ||
| return typing.cast(bytes, message["bytes"]) | ||
|
|
||
| async def receive_text(self) -> str: | ||
| return await self._websocket.receive_text() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,15 @@ | ||
| import asyncio | ||
| import logging | ||
| from typing import Any, Callable, Dict | ||
| from typing import Any, Callable, Dict, Optional, Protocol | ||
|
|
||
| import httpx | ||
| from fastapi import APIRouter, WebSocket | ||
| from fastapi.responses import JSONResponse, StreamingResponse | ||
| from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws | ||
| from httpx_ws import _exceptions as httpx_ws_exceptions | ||
| from httpx_ws import aconnect_ws | ||
| from starlette.requests import ClientDisconnect, Request | ||
| from starlette.responses import Response | ||
| from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect | ||
| from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed | ||
| from wsproto.events import BytesMessage, TextMessage | ||
|
|
||
|
|
@@ -29,6 +30,15 @@ | |
|
|
||
| control_app = APIRouter() | ||
|
|
||
| WEBSOCKET_NORMAL_CLOSURE_CODE = 1000 | ||
| WEBSOCKET_SERVER_ERROR_CODE = 1011 | ||
|
|
||
|
|
||
| class CloseableWebsocket(Protocol): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe overkill, and mentioned in comments below, but it's unfortunate that the incoming websocket (FastAPI) and outbound websockets (httpx) are different types with incompatible APIs. We get around that separately below, but luckily both expose a compatible `close` method, which is all this protocol needs. |
||
| async def close( | ||
| self, code: int = WEBSOCKET_NORMAL_CLOSURE_CODE, reason: Optional[str] = None | ||
| ) -> None: ... | ||
|
|
||
|
|
||
| @control_app.get("/") | ||
| def index(): | ||
|
|
@@ -118,51 +128,90 @@ def inference_retries( | |
| yield attempt | ||
|
|
||
|
|
||
| async def _safe_close_ws(ws: WebSocket, logger: logging.Logger): | ||
| async def _safe_close_ws( | ||
| ws: CloseableWebsocket, | ||
| logger: logging.Logger, | ||
| code: int, | ||
| reason: Optional[str] = None, | ||
| ): | ||
| try: | ||
| await ws.close() | ||
| await ws.close(code, reason) | ||
| except RuntimeError as close_error: | ||
| logger.debug(f"Duplicate close of websocket: `{close_error}`.") | ||
|
|
||
|
|
||
| async def forward_to_server( | ||
| client_ws: WebSocket, server_ws: AsyncWebSocketSession | ||
| ) -> None: | ||
| while True: | ||
| message = await client_ws.receive() | ||
| if message.get("type") == "websocket.disconnect": | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The docs seem to suggest that the client closing the connection could either:
To be safe, we re-raise a `WebSocketDisconnect` whenever a `websocket.disconnect` message is received. |
||
| raise StartletteWebSocketDisconnect( | ||
| message.get("code", 1000), message.get("reason") | ||
| ) | ||
| if "text" in message: | ||
| await server_ws.send_text(message["text"]) | ||
| elif "bytes" in message: | ||
| await server_ws.send_bytes(message["bytes"]) | ||
|
|
||
|
|
||
| async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession): | ||
| while True: | ||
| message = await server_ws.receive() | ||
| if isinstance(message, TextMessage): | ||
| await client_ws.send_text(message.data) | ||
| elif isinstance(message, BytesMessage): | ||
| await client_ws.send_bytes(message.data) | ||
|
|
||
|
|
||
| # NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer | ||
| # versions of truss we're guaranteed to be running the control server with at least that version. | ||
| async def _handle_websocket_forwarding( | ||
| client_ws: WebSocket, server_ws: AsyncWebSocketSession | ||
| ): | ||
| logger = client_ws.app.state.logger | ||
| try: | ||
| async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changing from `asyncio.gather` to `asyncio.TaskGroup` here. |
||
| tg.create_task(forward_to_client(client_ws, server_ws)) | ||
| tg.create_task(forward_to_server(client_ws, server_ws)) | ||
| except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821 | ||
| # NB(nikhil): The first websocket proxy method to raise an error will | ||
| # be surfaced here, and that contains the information we want to forward to the | ||
| # other websocket. Further errors might raise as a result of cancellation, but we | ||
| # can safely ignore those. | ||
| exc = eg.exceptions[0] | ||
| if isinstance(exc, WebSocketDisconnect): | ||
| await _safe_close_ws(client_ws, logger, exc.code, exc.reason) | ||
| elif isinstance(exc, StartletteWebSocketDisconnect): | ||
| await _safe_close_ws(server_ws, logger, exc.code, exc.reason) | ||
| else: | ||
| logger.warning(f"Ungraceful websocket close: {exc}") | ||
| finally: | ||
| # NB(nikhil): In most common cases, both websockets would have been successfully | ||
| # closed with applicable codes above, these lines are just a failsafe. | ||
| await _safe_close_ws(client_ws, logger, code=WEBSOCKET_SERVER_ERROR_CODE) | ||
| await _safe_close_ws(server_ws, logger, code=WEBSOCKET_SERVER_ERROR_CODE) | ||
|
|
||
|
|
||
| async def _attempt_websocket_proxy( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This refactor was mainly to improve on the 6-layer indent for the forwarding loops. |
||
| client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger | ||
| ): | ||
| async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore | ||
| await client_ws.accept() | ||
| await _handle_websocket_forwarding(client_ws, server_ws) | ||
|
|
||
|
|
||
| async def proxy_ws(client_ws: WebSocket): | ||
| proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client | ||
| logger = client_ws.app.state.logger | ||
|
|
||
| for attempt in inference_retries(): | ||
| with attempt: | ||
| try: | ||
| async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore | ||
| # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions | ||
| # for sending data, so it's not easy to create a unified wrapper. | ||
| async def forward_to_server(): | ||
| while True: | ||
| message = await client_ws.receive() | ||
| if message.get("type") == "websocket.disconnect": | ||
| break | ||
| if "text" in message: | ||
| await server_ws.send_text(message["text"]) | ||
| elif "bytes" in message: | ||
| await server_ws.send_bytes(message["bytes"]) | ||
|
|
||
| async def forward_to_client(): | ||
| while True: | ||
| message = await server_ws.receive() | ||
| if message is None: | ||
| break | ||
| if isinstance(message, TextMessage): | ||
| await client_ws.send_text(message.data) | ||
| elif isinstance(message, BytesMessage): | ||
| await client_ws.send_bytes(message.data) | ||
|
|
||
| await client_ws.accept() | ||
| try: | ||
| await asyncio.gather(forward_to_client(), forward_to_server()) | ||
| finally: | ||
| await _safe_close_ws(client_ws, logger) | ||
| await _attempt_websocket_proxy(client_ws, proxy_client, logger) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for breaking this code block down into smaller functions! |
||
| except httpx_ws_exceptions.HTTPXWSException as e: | ||
| logger.warning(f"WebSocket connection rejected: {e}") | ||
| await _safe_close_ws(client_ws, logger) | ||
| await _safe_close_ws(client_ws, logger, WEBSOCKET_SERVER_ERROR_CODE) | ||
| break | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes: | |
|
|
||
|
|
||
| async def _safe_close_websocket( | ||
| ws: WebSocket, reason: Optional[str], status_code: int = 1000 | ||
| ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None | ||
| ) -> None: | ||
| try: | ||
| await ws.close(code=status_code, reason=reason) | ||
|
|
@@ -257,14 +257,16 @@ async def websocket(self, ws: WebSocket) -> None: | |
| try: | ||
| await ws.accept() | ||
| await self._model.websocket(ws) | ||
| await _safe_close_websocket(ws, None, status_code=1000) | ||
| await _safe_close_websocket(ws, status_code=1000, reason=None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we always specify arguments, do we still need default values for `status_code` and `reason`?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unfortunately on L182 I have a failsafe close, but I can make |
||
| except WebSocketDisconnect as ws_error: | ||
| logging.info( | ||
| f"Client terminated websocket connection: `{ws_error}`." | ||
| ) | ||
| except Exception: | ||
| await _safe_close_websocket( | ||
| ws, errors.MODEL_ERROR_MESSAGE, status_code=1011 | ||
| ws, | ||
| status_code=errors.WEBSOCKET_SERVER_ERROR_CODE, | ||
| reason=errors.MODEL_ERROR_MESSAGE, | ||
| ) | ||
| raise # Re raise to let `intercept_exceptions` deal with it. | ||
|
|
||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Context for this change is here: https://basetenlabs.slack.com/archives/C06CZ3RSXRU/p1757360869605229