diff --git a/truss/templates/control/control/endpoints.py b/truss/templates/control/control/endpoints.py index 0a5d6abbe..4d3d49a2c 100644 --- a/truss/templates/control/control/endpoints.py +++ b/truss/templates/control/control/endpoints.py @@ -118,9 +118,9 @@ def inference_retries( yield attempt -async def _safe_close_ws(ws: WebSocket, logger: logging.Logger): +async def _safe_close_ws(ws: WebSocket, logger: logging.Logger, code: int = 1000, reason: str | None = None): try: - await ws.close() + await ws.close(code, reason) except RuntimeError as close_error: logger.debug(f"Duplicate close of websocket: `{close_error}`.") @@ -135,35 +135,62 @@ async def proxy_ws(client_ws: WebSocket): async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions # for sending data, so it's not easy to create a unified wrapper. + # + # As a rule: + # - On receive errors, forward a 1006 to the other end. + # - On send errors, do nothing, and assume the other actor will see a corresponding recv error, which will trigger it to send a 1006. 
async def forward_to_server(): while True: - message = await client_ws.receive() - if message.get("type") == "websocket.disconnect": - break - if "text" in message: - await server_ws.send_text(message["text"]) - elif "bytes" in message: - await server_ws.send_bytes(message["bytes"]) + try: + message = await client_ws.receive() + except Exception as e: + logger.warning(f"failed to receive message from client: {e}") + await server_ws.close(1006, f"failed to receive message from client: {e}") + return + + try: + if message.get("type") == "websocket.disconnect": + await server_ws.close(message["code"], message.get("reason")) + return + if "text" in message: + await server_ws.send_text(message["text"]) + elif "bytes" in message: + await server_ws.send_bytes(message["bytes"]) + except: + return + async def forward_to_client(): while True: - message = await server_ws.receive() + try: + message = await server_ws.receive() + except httpx_ws_exceptions.WebsocketDisconnect as e: + await client_ws.close(e.code, e.reason) + return + except Exception as e: + await client_ws.close(1006, f"failed to receive message from server: {e}") + return + if message is None: break - if isinstance(message, TextMessage): - await client_ws.send_text(message.data) - elif isinstance(message, BytesMessage): - await client_ws.send_bytes(message.data) + + try: + if isinstance(message, TextMessage): + await client_ws.send_text(message.data) + elif isinstance(message, BytesMessage): + await client_ws.send_bytes(message.data) + except: + return await client_ws.accept() - try: - await asyncio.gather(forward_to_client(), forward_to_server()) - finally: - await _safe_close_ws(client_ws, logger) - except httpx_ws_exceptions.HTTPXWSException as e: - logger.warning(f"WebSocket connection rejected: {e}") - await _safe_close_ws(client_ws, logger) - break + await asyncio.gather(forward_to_client(), forward_to_server()) + # The asyncio funcs all catch their own exceptions, so this should only be triggered by the 
aconnect_ws call. + except Exception as e: + logger.warning(f"error initializing websocket connection to server: {e}") + try: + await client_ws.close(1006, f"error initializing websocket connection to server: {e}") + except: + pass control_app.add_websocket_route("/v1/websocket", proxy_ws)