Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "truss"
version = "0.11.0"
version = "0.11.1rc15"
description = "A seamless bridge from model development to model delivery"
authors = [
{ name = "Pankaj Gupta", email = "[email protected]" },
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/public_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ class WebSocketProtocol(Protocol):

async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...

async def receive(self) -> Union[str, bytes]: ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

async def receive_text(self) -> str: ...
async def receive_bytes(self) -> bytes: ...
async def receive_json(self) -> Any: ...
Expand Down
18 changes: 18 additions & 0 deletions truss-chains/truss_chains/remote_chainlet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading
import time
import traceback
import typing
from collections.abc import AsyncIterator
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -586,6 +587,23 @@ def __init__(self, websocket: "fastapi.WebSocket") -> None:
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
await self._websocket.close(code=code, reason=reason)

async def receive(self) -> Union[str, bytes]:
    """Return the next text or binary payload from the wrapped websocket.

    Returns:
        The message content: `str` for a text frame, `bytes` for a binary frame.

    Raises:
        fastapi.WebSocketDisconnect: if the received message is a
            `websocket.disconnect` event. The FastAPI exception is passed
            through on purpose: this wrapper already exposes FastAPI details
            via exceptions, and a separate disconnect protocol would differ
            from what users catch in Truss.
    """
    try:
        import fastapi
    except ImportError:
        raise utils.make_optional_import_error("fastapi")

    message = await self._websocket.receive()

    if message.get("type") == "websocket.disconnect":
        # NB(nikhil): Mimics FastAPI `_raise_on_disconnect`, since otherwise the user has no
        # way of detecting that the client disconnected.
        raise fastapi.WebSocketDisconnect(message["code"], message.get("reason"))

    # Compare against `None` instead of relying on truthiness: an empty text
    # frame (`""`) is a valid message and must not fall through to the bytes
    # branch, where `message["bytes"]` would be missing or `None`.
    text = message.get("text")
    if text is not None:
        return typing.cast(str, text)
    return typing.cast(bytes, message["bytes"])

async def receive_text(self) -> str:
return await self._websocket.receive_text()

Expand Down
115 changes: 82 additions & 33 deletions truss/templates/control/control/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import asyncio
import logging
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Protocol

import httpx
from fastapi import APIRouter, WebSocket
from fastapi.responses import JSONResponse, StreamingResponse
from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
from httpx_ws import _exceptions as httpx_ws_exceptions
from httpx_ws import aconnect_ws
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect
from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
from wsproto.events import BytesMessage, TextMessage

Expand All @@ -29,6 +30,15 @@

control_app = APIRouter()

WEBSOCKET_NORMAL_CLOSURE_CODE = 1000
WEBSOCKET_SERVER_ERROR_CODE = 1011


class CloseableWebsocket(Protocol):
    """Structural type for any websocket that exposes an async `close`.

    The incoming websocket (FastAPI) and the outbound websocket (httpx) are
    different types with incompatible APIs. `close` is common to both, so this
    protocol lets one helper close either kind.
    """

    async def close(
        self, code: int = WEBSOCKET_NORMAL_CLOSURE_CODE, reason: Optional[str] = None
    ) -> None: ...


@control_app.get("/")
def index():
Expand Down Expand Up @@ -118,51 +128,90 @@ def inference_retries(
yield attempt


async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
async def _safe_close_ws(
ws: CloseableWebsocket,
logger: logging.Logger,
code: int,
reason: Optional[str] = None,
):
try:
await ws.close()
await ws.close(code, reason)
except RuntimeError as close_error:
logger.debug(f"Duplicate close of websocket: `{close_error}`.")


async def forward_to_server(
    client_ws: WebSocket, server_ws: AsyncWebSocketSession
) -> None:
    """Relay every message received from the client to the proxied server.

    Loops until an exception ends it; it never returns normally.

    Raises:
        StartletteWebSocketDisconnect: when the client sends an explicit
            `websocket.disconnect` message, carrying the client's close code
            and reason.
    """
    while True:
        message = await client_ws.receive()
        if message.get("type") == "websocket.disconnect":
            # A client close appears to surface either as a raised
            # StartletteWebSocketDisconnect or as an explicit disconnect
            # message. Re-raise the latter so the caller handles both the
            # same way.
            raise StartletteWebSocketDisconnect(
                message.get("code", WEBSOCKET_NORMAL_CLOSURE_CODE),
                message.get("reason"),
            )
        # ASGI allows both keys to be present with the unused one set to
        # `None`, so test the value rather than key membership.
        if message.get("text") is not None:
            await server_ws.send_text(message["text"])
        elif message.get("bytes") is not None:
            await server_ws.send_bytes(message["bytes"])

async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession):
    """Relay every message received from the proxied server to the client.

    Text and binary events are forwarded with the matching send call; any
    other event type is ignored. Loops until an exception ends it.
    """
    while True:
        event = await server_ws.receive()
        if isinstance(event, BytesMessage):
            await client_ws.send_bytes(event.data)
            continue
        if isinstance(event, TextMessage):
            await client_ws.send_text(event.data)


# NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer
# versions of truss we're guaranteed to be running the control server with at least that version.
async def _handle_websocket_forwarding(
    client_ws: WebSocket, server_ws: AsyncWebSocketSession
):
    """Forward messages in both directions, then close both websockets.

    When one direction fails with a disconnect, the close code and reason are
    propagated to the opposite websocket.
    """
    logger = client_ws.app.state.logger
    try:
        # A TaskGroup (rather than asyncio.gather) cancels the sibling task as
        # soon as one raises, instead of letting it run until it tries to
        # receive on a closed websocket.
        async with asyncio.TaskGroup() as task_group:  # type: ignore[attr-defined]
            task_group.create_task(forward_to_client(client_ws, server_ws))
            task_group.create_task(forward_to_server(client_ws, server_ws))
    except ExceptionGroup as group:  # type: ignore[name-defined] # noqa: F821
        # NB(nikhil): The first websocket proxy method to raise an error will
        # be surfaced here, and that contains the information we want to forward to the
        # other websocket. Further errors might raise as a result of cancellation, but we
        # can safely ignore those.
        first_error = group.exceptions[0]
        if isinstance(first_error, WebSocketDisconnect):
            await _safe_close_ws(
                client_ws, logger, first_error.code, first_error.reason
            )
        elif isinstance(first_error, StartletteWebSocketDisconnect):
            await _safe_close_ws(
                server_ws, logger, first_error.code, first_error.reason
            )
        else:
            logger.warning(f"Ungraceful websocket close: {first_error}")
    finally:
        # NB(nikhil): In most common cases, both websockets would have been successfully
        # closed with applicable codes above, these lines are just a failsafe.
        await _safe_close_ws(client_ws, logger, code=WEBSOCKET_SERVER_ERROR_CODE)
        await _safe_close_ws(server_ws, logger, code=WEBSOCKET_SERVER_ERROR_CODE)


async def _attempt_websocket_proxy(
    client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger
):
    """Make one attempt at proxying the client websocket to the server.

    Opens the upstream websocket at `/v1/websocket`, accepts the client
    connection only once that succeeds, and then forwards messages between
    the two. `logger` is accepted but not used in this function.
    """
    async with aconnect_ws("/v1/websocket", proxy_client) as upstream_ws:  # type: ignore
        await client_ws.accept()
        await _handle_websocket_forwarding(client_ws, upstream_ws)


async def proxy_ws(client_ws: WebSocket):
proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
logger = client_ws.app.state.logger

for attempt in inference_retries():
with attempt:
try:
async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
# Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
# for sending data, so it's not easy to create a unified wrapper.
async def forward_to_server():
while True:
message = await client_ws.receive()
if message.get("type") == "websocket.disconnect":
break
if "text" in message:
await server_ws.send_text(message["text"])
elif "bytes" in message:
await server_ws.send_bytes(message["bytes"])

async def forward_to_client():
while True:
message = await server_ws.receive()
if message is None:
break
if isinstance(message, TextMessage):
await client_ws.send_text(message.data)
elif isinstance(message, BytesMessage):
await client_ws.send_bytes(message.data)

await client_ws.accept()
try:
await asyncio.gather(forward_to_client(), forward_to_server())
finally:
await _safe_close_ws(client_ws, logger)
await _attempt_websocket_proxy(client_ws, proxy_client, logger)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for breaking this code block down into smaller functions!

except httpx_ws_exceptions.HTTPXWSException as e:
logger.warning(f"WebSocket connection rejected: {e}")
await _safe_close_ws(client_ws, logger)
await _safe_close_ws(client_ws, logger, WEBSOCKET_SERVER_ERROR_CODE)
break


Expand Down
2 changes: 1 addition & 1 deletion truss/templates/control/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ loguru>=0.7.2
python-json-logger>=2.0.2
tenacity>=8.1.0
# To avoid divergence, this should follow the latest release.
truss==0.9.100
truss==0.11.1rc15
uvicorn>=0.24.0
uvloop>=0.19.0
websockets>=10.0
1 change: 1 addition & 0 deletions truss/templates/server/common/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_BASETEN_CLIENT_ERROR_CODE = 700

MODEL_ERROR_MESSAGE = "Internal Server Error (in model/chainlet)."
WEBSOCKET_SERVER_ERROR_CODE = 1011


class ModelMissingError(Exception):
Expand Down
8 changes: 5 additions & 3 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes:


async def _safe_close_websocket(
ws: WebSocket, reason: Optional[str], status_code: int = 1000
ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None
) -> None:
try:
await ws.close(code=status_code, reason=reason)
Expand Down Expand Up @@ -257,14 +257,16 @@ async def websocket(self, ws: WebSocket) -> None:
try:
await ws.accept()
await self._model.websocket(ws)
await _safe_close_websocket(ws, None, status_code=1000)
await _safe_close_websocket(ws, status_code=1000, reason=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we always specify the arguments, do we still need default values for `_safe_close_websocket()`? Dropping them would make the code DRY-er and more explicit.

Copy link
Contributor Author

@nnarayen nnarayen Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately on L182 I have a failsafe close, but I can make the `code` parameter required (no default)!

except WebSocketDisconnect as ws_error:
logging.info(
f"Client terminated websocket connection: `{ws_error}`."
)
except Exception:
await _safe_close_websocket(
ws, errors.MODEL_ERROR_MESSAGE, status_code=1011
ws,
status_code=errors.WEBSOCKET_SERVER_ERROR_CODE,
reason=errors.MODEL_ERROR_MESSAGE,
)
raise # Re raise to let `intercept_exceptions` deal with it.

Expand Down
34 changes: 20 additions & 14 deletions truss/tests/templates/control/control/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import AsyncMock, MagicMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest
from fastapi import FastAPI, WebSocket
Expand Down Expand Up @@ -31,33 +32,38 @@ def client_ws(app):

@pytest.mark.asyncio
async def test_proxy_ws_bidirectional_messaging(client_ws):
"""Test that both directions of communication work and clean up properly"""
client_ws.receive.side_effect = [
{"type": "websocket.receive", "text": "msg1"},
{"type": "websocket.receive", "text": "msg2"},
{"type": "websocket.disconnect"},
]
client_queue = asyncio.Queue()
client_ws.receive = client_queue.get

server_queue = asyncio.Queue()
mock_server_ws = AsyncMock(spec=AsyncWebSocketSession)
mock_server_ws.receive.side_effect = [
TextMessage(data="response1"),
TextMessage(data="response2"),
None, # server closing connection
]
mock_server_ws.receive = server_queue.get
mock_server_ws.__aenter__.return_value = mock_server_ws
mock_server_ws.__aexit__.return_value = None

client_queue.put_nowait({"type": "websocket.receive", "text": "msg1"})
client_queue.put_nowait({"type": "websocket.receive", "text": "msg2"})
server_queue.put_nowait(TextMessage(data="response1"))
server_queue.put_nowait(TextMessage(data="response2"))

with patch(
"truss.templates.control.control.endpoints.aconnect_ws",
return_value=mock_server_ws,
):
await proxy_ws(client_ws)
proxy_task = asyncio.create_task(proxy_ws(client_ws))
client_queue.put_nowait(
{"type": "websocket.disconnect", "code": 1002, "reason": "test-closure"}
)

await proxy_task

assert mock_server_ws.send_text.call_count == 2
assert mock_server_ws.send_text.call_args_list == [(("msg1",),), (("msg2",),)]
assert client_ws.send_text.call_count == 2
assert client_ws.send_text.call_args_list == [(("response1",),), (("response2",),)]
client_ws.close.assert_called_once()

assert mock_server_ws.close.call_args_list[0] == call(1002, "test-closure")
client_ws.close.assert_called()


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading