Skip to content

Commit b3ebb8a

Browse files
committed
fix: bridge-proxy behavior on mac
1 parent d699116 commit b3ebb8a

File tree

3 files changed

+273
-95
lines changed

3 files changed

+273
-95
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ dependencies = [
1111
"psutil>=5.8.0,<6",
1212
"Pillow>=10.0.0,<11",
1313
"tvl @ https://github.com/tropicsquare/ts-tvl/releases/download/2.3/tvl-2.3-py3-none-any.whl",
14+
"httpx>=0.28.1,<0.29",
15+
"fastapi>=0.116.1,<0.117",
16+
"uvicorn>=0.35.0,<0.36",
1417
]
1518

1619
[dependency-groups]

src/bridge_proxy_server.py

Lines changed: 164 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,174 @@
1-
#!/usr/bin/env python3
2-
3-
"""
4-
HTTPServer used as proxy for trezord calls from the outside of docker container
5-
This is workaround for original ip not beeing passed to the container:
6-
https://github.com/docker/for-mac/issues/180
7-
Listening on port 21326 and routes requests to the trezord with changed Origin header
8-
"""
9-
10-
from http.server import BaseHTTPRequestHandler, HTTPServer
11-
from socketserver import ThreadingMixIn
12-
13-
import requests
14-
15-
import helpers
16-
17-
TREZORD_HOST = "0.0.0.0:21325"
18-
HEADERS = {
19-
"Host": TREZORD_HOST,
20-
"Origin": "https://user-env.trezor.io",
21-
}
22-
IP = "0.0.0.0"
1+
import asyncio
2+
import logging
3+
import uuid
4+
from contextlib import asynccontextmanager
5+
6+
import httpx
7+
import uvicorn
8+
from fastapi import FastAPI, Request, Response
9+
from fastapi.middleware.cors import CORSMiddleware
10+
from fastapi.responses import JSONResponse
11+
12+
logging.basicConfig(level=logging.INFO)
13+
logger = logging.getLogger("async_bridge_proxy")
14+
TREZORD_HOST = "http://0.0.0.0:21325"
2315
PORT = 21326
24-
LOG_COLOR = "green"
25-
2616

27-
def log(text: str, color: str = LOG_COLOR) -> None:
28-
helpers.log(f"BRIDGE PROXY: {text}", color)
17+
app = FastAPI()
18+
19+
# Add CORS middleware
20+
app.add_middleware(
21+
CORSMiddleware,
22+
allow_origins=["*"],
23+
allow_credentials=True,
24+
allow_methods=["*"],
25+
allow_headers=["*"],
26+
)
27+
28+
# Track in-flight upstream tasks per session id to allow cancelling previous read when a new call arrives
29+
inflight_tasks: dict[str, asyncio.Task] = {}
30+
inflight_lock = asyncio.Lock()
31+
32+
33+
def _prepare_response_headers(
34+
resp: httpx.Response, request: Request, remove_content_length: bool = False
35+
) -> dict:
36+
"""Return a headers dict for returning to the client.
37+
38+
- Copies upstream headers
39+
- Removes Transfer-Encoding always
40+
- Optionally removes Content-Length (for chunked/streaming responses)
41+
"""
42+
headers = dict(resp.headers)
43+
headers.pop("transfer-encoding", None)
44+
if remove_content_length:
45+
headers.pop("content-length", None)
46+
return headers
2947

3048

31-
# POST request headers override
32-
# origin is set to the actual machine that made the call not localhost
33-
def merge_headers(original: dict) -> dict:
34-
headers = original.copy()
35-
headers.update(HEADERS)
49+
def _prepare_upstream_headers(request: Request) -> dict:
50+
"""Prepare headers for forwarding to upstream, replacing Origin and removing host."""
51+
# Build headers dict, excluding 'host' and case-insensitive 'origin'
52+
headers = {
53+
k: v for k, v in request.headers.items() if k.lower() not in ("host", "origin")
54+
}
55+
# Set Origin header from request or use default
56+
headers["Origin"] = request.headers.get("origin", "https://user-env.trezor.io")
3657
return headers
3758

3859

39-
class Handler(BaseHTTPRequestHandler):
40-
def do_HEAD(self) -> None:
41-
self.do_GET()
42-
43-
def do_GET(self) -> None:
44-
try:
45-
if self.path == "/status/":
46-
# read trezord status page
47-
url = f"http://{TREZORD_HOST}{self.path}"
48-
resp = requests.get(url)
49-
50-
self.send_response(resp.status_code)
51-
self.send_resp_headers(resp)
52-
self.wfile.write(resp.content)
53-
except Exception as e:
54-
self.send_error(404, f"Error trying to proxy: {self.path} Error: {repr(e)}")
55-
56-
def do_POST(self, body: bool = True) -> None:
57-
try:
58-
url = f"http://{TREZORD_HOST}{self.path}"
59-
data_len = int(self.headers.get("content-length", 0))
60-
data = self.rfile.read(data_len)
61-
headers = merge_headers(dict(self.headers))
62-
63-
resp = requests.post(url, data=data, headers=headers)
64-
65-
self.send_response(resp.status_code)
66-
self.send_resp_headers(resp)
67-
if body:
68-
self.wfile.write(resp.content)
69-
except Exception as e:
70-
self.send_error(404, f"Error trying to proxy: {self.path} Error: {repr(e)}")
71-
72-
def send_resp_headers(self, resp) -> None:
73-
# response Access-Control header needs to be exact with original
74-
# request from the caller
75-
self.send_header(
76-
"Access-Control-Allow-Origin",
77-
self.headers.get("Access-Control-Allow-Origin", "*"),
78-
)
79-
80-
# remove Access-Control and Transfer-Encoding headers
81-
# from the original trezord response
82-
h = dict(resp.headers)
83-
h.pop(
84-
"Transfer-Encoding", "chunked"
85-
) # this header returns empty response to the caller (trezor-link)
86-
h.pop("Access-Control-Allow-Origin", None)
87-
for key, value in h.items():
88-
self.send_header(key, value)
89-
self.end_headers()
90-
91-
def log_message(self, format, *args) -> None:
92-
"""Adds color to make the log clearer."""
93-
log(
94-
"%s - - [%s] %s\n"
95-
% (self.address_string(), self.log_date_time_string(), format % args),
96-
)
97-
98-
99-
class ThreadingServer(ThreadingMixIn, HTTPServer):
100-
pass
60+
def _create_response_from_upstream(resp: httpx.Response, request: Request) -> Response:
61+
"""Create a Response object from an upstream httpx.Response with proper logging."""
62+
logger.info(f"Upstream responded with status: {resp.status_code}")
63+
logger.info(f"Upstream headers: {dict(resp.headers)}")
64+
headers = _prepare_response_headers(resp, request, remove_content_length=False)
65+
return Response(resp.content, status_code=resp.status_code, headers=headers)
66+
67+
68+
async def _cancel_previous_session_task(session_id: str | None, req_id: str):
69+
"""Cancel any previous in-flight task for the given session."""
70+
if session_id is not None:
71+
async with inflight_lock:
72+
old = inflight_tasks.get(session_id)
73+
if old is not None:
74+
logger.info(
75+
f"[{req_id}] Cancelling prior inflight task for session {session_id}"
76+
)
77+
old.cancel()
78+
79+
80+
@asynccontextmanager
81+
async def _manage_session_task(session_id: str | None, task: asyncio.Task):
82+
"""Context manager to register and cleanup inflight session tasks."""
83+
if session_id is not None:
84+
async with inflight_lock:
85+
inflight_tasks[session_id] = task
86+
try:
87+
yield
88+
finally:
89+
task.cancel()
90+
if session_id is not None:
91+
async with inflight_lock:
92+
if inflight_tasks.get(session_id) is task:
93+
del inflight_tasks[session_id]
94+
95+
96+
async def _proxy_request(
97+
request: Request, path: str, session_id: str | None = None
98+
) -> Response:
99+
"""Proxy a request with optional session tracking."""
100+
url = f"{TREZORD_HOST}/{path}"
101+
headers = _prepare_upstream_headers(request)
102+
req_id = uuid.uuid4().hex[:8]
103+
logger.info(f"[{req_id}] Proxy received {request.method} request for path: /{path}")
104+
105+
body = await request.body() if request.method == "POST" else None
106+
if body:
107+
logger.info(f"POST body length: {len(body)}")
108+
logger.info(f"Forwarding {request.method} to upstream: {url}")
109+
110+
try:
111+
# Session-tracked requests (read/call): race against client disconnect
112+
if session_id is not None:
113+
await _cancel_previous_session_task(session_id, req_id)
114+
115+
# Use short-lived client without keep-alive for cancellable requests
116+
transport = httpx.AsyncHTTPTransport(retries=0)
117+
async with httpx.AsyncClient(timeout=None, transport=transport) as client:
118+
req_task = asyncio.create_task(
119+
client.request(request.method, url, headers=headers, content=body)
120+
)
121+
122+
async with _manage_session_task(session_id, req_task):
123+
# Poll with short timeouts and check client disconnect between polls
124+
while True:
125+
try:
126+
resp = await asyncio.wait_for(
127+
asyncio.shield(req_task), timeout=0.25
128+
)
129+
break # upstream finished
130+
except asyncio.TimeoutError:
131+
if await request.is_disconnected():
132+
logger.info(
133+
"Client disconnected before upstream response; cancelling upstream request"
134+
)
135+
try:
136+
await req_task
137+
except asyncio.CancelledError:
138+
pass
139+
return Response(status_code=499)
140+
141+
return _create_response_from_upstream(resp, request)
142+
# Simple requests: no session tracking
143+
else:
144+
async with httpx.AsyncClient(timeout=None) as client:
145+
resp = await client.request(
146+
request.method, url, headers=headers, content=body
147+
)
148+
return _create_response_from_upstream(resp, request)
149+
except httpx.RequestError as e:
150+
logger.error(f"Error proxying request: {e}")
151+
return JSONResponse({"error": str(e)}, status_code=502)
152+
153+
154+
@app.api_route("/read/{session_id}", methods=["GET", "POST"])
155+
async def proxy_read(request: Request, session_id: str):
156+
"""Proxy /read requests with session tracking."""
157+
return await _proxy_request(request, f"read/{session_id}", session_id=session_id)
158+
159+
160+
@app.api_route("/call/{session_id}", methods=["GET", "POST"])
161+
async def proxy_call(request: Request, session_id: str):
162+
"""Proxy /call requests with session tracking."""
163+
return await _proxy_request(request, f"call/{session_id}", session_id=session_id)
164+
165+
166+
@app.api_route("/{path:path}", methods=["GET", "POST"])
167+
async def proxy_all(request: Request, path: str):
168+
"""Proxy all other requests."""
169+
return await _proxy_request(request, path)
101170

102171

103172
if __name__ == "__main__":
104-
SERVER = ThreadingServer((IP, PORT), Handler)
105-
SERVER.serve_forever()
173+
# No reload, no workers, single process only
174+
uvicorn.run("bridge_proxy_server:app", host="0.0.0.0", port=PORT)

0 commit comments

Comments
 (0)