Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 49 additions & 22 deletions truss/templates/control/control/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def inference_retries(
yield attempt


async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
async def _safe_close_ws(ws: WebSocket, logger: logging.Logger, code: int = 1000, reason: str | None = None):
try:
await ws.close()
await ws.close(code, reason)
except RuntimeError as close_error:
logger.debug(f"Duplicate close of websocket: `{close_error}`.")

Expand All @@ -135,35 +135,62 @@ async def proxy_ws(client_ws: WebSocket):
async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
# Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
# for sending data, so it's not easy to create a unified wrapper.
#
# As a rule:
# - On receive errors, forward a 1006 to the other end.
# - On send errors, do nothing, and assume the other actor will see a corresponding recv error, which will trigger it to send a 1006.
async def forward_to_server():
while True:
message = await client_ws.receive()
if message.get("type") == "websocket.disconnect":
break
if "text" in message:
await server_ws.send_text(message["text"])
elif "bytes" in message:
await server_ws.send_bytes(message["bytes"])
try:
message = await client_ws.receive()
except Exception as e:
logger.warning(f"failed to receive message from client: {e}")
await server_ws.close(1006, f"failed to receive message from client: {e}")
return

try:
if message.get("type") == "websocket.disconnect":
await server_ws.close(message["code"], message.get("reason"))
return
if "text" in message:
await server_ws.send_text(message["text"])
elif "bytes" in message:
await server_ws.send_bytes(message["bytes"])
except:
return


async def forward_to_client():
while True:
message = await server_ws.receive()
try:
message = await server_ws.receive()
except httpx_ws_exceptions.WebsocketDisconnect as e:
await client_ws.close(e.code, e.reason)
return
except Exception as e:
await client_ws.close(1006, f"failed to receive message from server: {e}")
return

if message is None:
break
if isinstance(message, TextMessage):
await client_ws.send_text(message.data)
elif isinstance(message, BytesMessage):
await client_ws.send_bytes(message.data)

try:
if isinstance(message, TextMessage):
await client_ws.send_text(message.data)
elif isinstance(message, BytesMessage):
await client_ws.send_bytes(message.data)
except:
return

await client_ws.accept()
try:
await asyncio.gather(forward_to_client(), forward_to_server())
finally:
await _safe_close_ws(client_ws, logger)
except httpx_ws_exceptions.HTTPXWSException as e:
logger.warning(f"WebSocket connection rejected: {e}")
await _safe_close_ws(client_ws, logger)
break
await asyncio.gather(forward_to_client(), forward_to_server())
# The asyncio funcs all catch their own exceptions, so this should only be triggered by the aconnect_ws call.
except Exception as e:
logger.warning(f"error initializing websocket connection to server: {e}")
try:
await client_ws.close(1006, f"error initializing websocket connetion to server: {e}")
except:
pass


control_app.add_websocket_route("/v1/websocket", proxy_ws)
Expand Down
Loading