
Commit 6edd598

fix: bridge-proxy behavior on mac
1 parent: 2427767

3 files changed, +375 -97 lines

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,9 @@ dependencies = [
     "websockets>=12.0",
     "psutil>=5.8.0,<6",
     "Pillow>=10.0.0,<11",
+    "httpx>=0.28.1,<0.29",
+    "fastapi>=0.116.1,<0.117",
+    "uvicorn>=0.35.0,<0.36",
 ]

 [dependency-groups]
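The three new pins back the rewritten proxy: httpx is the async HTTP client, FastAPI the ASGI framework, and uvicorn the server. Below is a minimal sketch of how the three combine into a pass-through proxy; the upstream URL and ports are hypothetical placeholders, not values from this commit.

# Minimal sketch (not part of the commit): how httpx, FastAPI and uvicorn fit together.
# UPSTREAM and the ports below are illustrative placeholders.
import httpx
import uvicorn
from fastapi import FastAPI, Request, Response

app = FastAPI()
UPSTREAM = "http://127.0.0.1:8080"  # hypothetical backend


@app.api_route("/{path:path}", methods=["GET", "POST"])
async def passthrough(request: Request, path: str):
    # Forward the incoming method and body to the upstream service and relay the reply.
    body = await request.body()
    async with httpx.AsyncClient() as client:
        resp = await client.request(request.method, f"{UPSTREAM}/{path}", content=body)
    return Response(resp.content, status_code=resp.status_code)


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)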

src/bridge_proxy_server.py

Lines changed: 165 additions & 97 deletions
@@ -1,105 +1,173 @@
-#!/usr/bin/env python3
-
-"""
-HTTPServer used as proxy for trezord calls from the outside of docker container
-This is workaround for original ip not beeing passed to the container:
-https://github.com/docker/for-mac/issues/180
-Listening on port 21326 and routes requests to the trezord with changed Origin header
-"""
-
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from socketserver import ThreadingMixIn
-
-import requests
-
-import helpers
-
-TREZORD_HOST = "0.0.0.0:21325"
-HEADERS = {
-    "Host": TREZORD_HOST,
-    "Origin": "https://user-env.trezor.io",
-}
-IP = "0.0.0.0"
-PORT = 21326
-LOG_COLOR = "green"
-
-
-def log(text: str, color: str = LOG_COLOR) -> None:
-    helpers.log(f"BRIDGE PROXY: {text}", color)
+import asyncio
+import logging
+import uuid
+from contextlib import asynccontextmanager
+
+import httpx
+import uvicorn
+from fastapi import FastAPI, Request, Response
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("async_bridge_proxy")
+TREZORD_HOST = "http://0.0.0.0:21325"
+
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Track in-flight upstream tasks per session id to allow cancelling previous read when a new call arrives
+inflight_tasks: dict[int, asyncio.Task] = {}
+inflight_lock = asyncio.Lock()
+
+
+def _prepare_response_headers(
+    resp: httpx.Response, request: Request, remove_content_length: bool = False
+) -> dict:
+    """Return a headers dict for returning to the client.
+
+    - Copies upstream headers
+    - Removes Transfer-Encoding always
+    - Optionally removes Content-Length (for chunked/streaming responses)
+    """
+    headers = dict(resp.headers)
+    headers.pop("transfer-encoding", None)
+    if remove_content_length:
+        headers.pop("content-length", None)
+    return headers


-# POST request headers override
-# origin is set to the actual machine that made the call not localhost
-def merge_headers(original: dict) -> dict:
-    headers = original.copy()
-    headers.update(HEADERS)
+def _prepare_upstream_headers(request: Request) -> dict:
+    """Prepare headers for forwarding to upstream, replacing Origin and removing host."""
+    # Build headers dict, excluding 'host' and case-insensitive 'origin'
+    headers = {
+        k: v for k, v in request.headers.items() if k.lower() not in ("host", "origin")
+    }
+    # Set Origin header from request or use default
+    headers["Origin"] = request.headers.get("origin", "https://user-env.trezor.io")
     return headers


-class Handler(BaseHTTPRequestHandler):
-    def do_HEAD(self) -> None:
-        self.do_GET()
-
-    def do_GET(self) -> None:
-        try:
-            if self.path == "/status/":
-                # read trezord status page
-                url = f"http://{TREZORD_HOST}{self.path}"
-                resp = requests.get(url)
-
-                self.send_response(resp.status_code)
-                self.send_resp_headers(resp)
-                self.wfile.write(resp.content)
-        except Exception as e:
-            self.send_error(404, f"Error trying to proxy: {self.path} Error: {repr(e)}")
-
-    def do_POST(self, body: bool = True) -> None:
-        try:
-            url = f"http://{TREZORD_HOST}{self.path}"
-            data_len = int(self.headers.get("content-length", 0))
-            data = self.rfile.read(data_len)
-            headers = merge_headers(dict(self.headers))
-
-            resp = requests.post(url, data=data, headers=headers)
-
-            self.send_response(resp.status_code)
-            self.send_resp_headers(resp)
-            if body:
-                self.wfile.write(resp.content)
-        except Exception as e:
-            self.send_error(404, f"Error trying to proxy: {self.path} Error: {repr(e)}")
-
-    def send_resp_headers(self, resp) -> None:
-        # response Access-Control header needs to be exact with original
-        # request from the caller
-        self.send_header(
-            "Access-Control-Allow-Origin",
-            self.headers.get("Access-Control-Allow-Origin", "*"),
-        )
-
-        # remove Access-Control and Transfer-Encoding headers
-        # from the original trezord response
-        h = dict(resp.headers)
-        h.pop(
-            "Transfer-Encoding", "chunked"
-        )  # this header returns empty response to the caller (trezor-link)
-        h.pop("Access-Control-Allow-Origin", None)
-        for key, value in h.items():
-            self.send_header(key, value)
-        self.end_headers()
-
-    def log_message(self, format, *args) -> None:
-        """Adds color to make the log clearer."""
-        log(
-            "%s - - [%s] %s\n"
-            % (self.address_string(), self.log_date_time_string(), format % args),
-        )
-
-
-class ThreadingServer(ThreadingMixIn, HTTPServer):
-    pass
+def _create_response_from_upstream(resp: httpx.Response, request: Request) -> Response:
+    """Create a Response object from an upstream httpx.Response with proper logging."""
+    logger.info(f"Upstream responded with status: {resp.status_code}")
+    logger.info(f"Upstream headers: {dict(resp.headers)}")
+    headers = _prepare_response_headers(resp, request, remove_content_length=False)
+    return Response(resp.content, status_code=resp.status_code, headers=headers)
+
+
+async def _cancel_previous_session_task(session_id: int | None, req_id: str):
+    """Cancel any previous in-flight task for the given session."""
+    if session_id is not None:
+        async with inflight_lock:
+            old = inflight_tasks.get(session_id)
+            if old is not None:
+                logger.info(
+                    f"[{req_id}] Cancelling prior inflight task for session {session_id}"
+                )
+                old.cancel()
+
+
+@asynccontextmanager
+async def _manage_session_task(session_id: int | None, task: asyncio.Task):
+    """Context manager to register and cleanup inflight session tasks."""
+    if session_id is not None:
+        async with inflight_lock:
+            inflight_tasks[session_id] = task
+    try:
+        yield
+    finally:
+        task.cancel()
+        if session_id is not None:
+            async with inflight_lock:
+                if inflight_tasks.get(session_id) is task:
+                    del inflight_tasks[session_id]
+
+
+async def _proxy_request(
+    request: Request, path: str, session_id: int | None = None
+) -> Response:
+    """Proxy a request with optional session tracking."""
+    url = f"{TREZORD_HOST}/{path}"
+    headers = _prepare_upstream_headers(request)
+    req_id = uuid.uuid4().hex[:8]
+    logger.info(f"[{req_id}] Proxy received {request.method} request for path: /{path}")
+
+    body = await request.body() if request.method == "POST" else None
+    if body:
+        logger.info(f"POST body length: {len(body)}")
+    logger.info(f"Forwarding {request.method} to upstream: {url}")
+
+    try:
+        # Session-tracked requests (read/call): race against client disconnect
+        if session_id is not None:
+            await _cancel_previous_session_task(session_id, req_id)
+
+            # Use short-lived client without keep-alive for cancellable requests
+            transport = httpx.AsyncHTTPTransport(retries=0)
+            async with httpx.AsyncClient(timeout=None, transport=transport) as client:
+                req_task = asyncio.create_task(
+                    client.request(request.method, url, headers=headers, content=body)
+                )
+
+                async with _manage_session_task(session_id, req_task):
+                    # Poll with short timeouts and check client disconnect between polls
+                    while True:
+                        try:
+                            resp = await asyncio.wait_for(
+                                asyncio.shield(req_task), timeout=0.25
+                            )
+                            break  # upstream finished
+                        except asyncio.TimeoutError:
+                            if await request.is_disconnected():
+                                logger.info(
+                                    "Client disconnected before upstream response; cancelling upstream request"
+                                )
+                                try:
+                                    await req_task
+                                except asyncio.CancelledError:
+                                    pass
+                                return Response(status_code=499)
+
+                    return _create_response_from_upstream(resp, request)
+        # Simple requests: no session tracking
+        else:
+            async with httpx.AsyncClient(timeout=None) as client:
+                resp = await client.request(
+                    request.method, url, headers=headers, content=body
+                )
+                return _create_response_from_upstream(resp, request)
+    except httpx.RequestError as e:
+        logger.error(f"Error proxying request: {e}")
+        return JSONResponse({"error": str(e)}, status_code=502)
+
+
+@app.api_route("/read/{session_id}", methods=["GET", "POST"])
+async def proxy_read(request: Request, session_id: int):
+    """Proxy /read requests with session tracking."""
+    return await _proxy_request(request, f"read/{session_id}", session_id=session_id)
+
+
+@app.api_route("/call/{session_id}", methods=["GET", "POST"])
+async def proxy_call(request: Request, session_id: int):
+    """Proxy /call requests with session tracking."""
+    return await _proxy_request(request, f"call/{session_id}", session_id=session_id)
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST"])
+async def proxy_all(request: Request, path: str):
+    """Proxy all other requests."""
+    return await _proxy_request(request, path)
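The core of the new behavior is the polling loop in _proxy_request above: the upstream call runs as a task, asyncio.wait_for with a short timeout polls it, asyncio.shield prevents the timeout from cancelling it, and the client-disconnect check runs between polls. Here is a standalone sketch of that pattern; the slow upstream coroutine and the disconnect flag are stand-ins, not code from this commit.

# Standalone sketch (not part of the commit) of the shield + wait_for polling pattern.
import asyncio


async def slow_upstream() -> str:
    await asyncio.sleep(2)  # pretend the upstream takes a while to answer
    return "upstream result"


async def poll_until_done_or_disconnected(client_gone: asyncio.Event) -> str | None:
    task = asyncio.create_task(slow_upstream())
    while True:
        try:
            # shield() keeps wait_for's timeout from cancelling the upstream task
            return await asyncio.wait_for(asyncio.shield(task), timeout=0.25)
        except asyncio.TimeoutError:
            if client_gone.is_set():
                task.cancel()  # give up on the upstream call in this sketch
                return None


async def main() -> None:
    print(await poll_until_done_or_disconnected(asyncio.Event()))


if __name__ == "__main__":
    asyncio.run(main())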


 if __name__ == "__main__":
-    SERVER = ThreadingServer((IP, PORT), Handler)
-    SERVER.serve_forever()
+    # No reload, no workers, single process only
+    uvicorn.run("async_bridge_proxy_server:app", host="0.0.0.0", port=21326)
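With the proxy listening on port 21326, the catch-all route forwards any path to trezord on 21325, so the old /status/ check still works through it. A minimal manual check, assuming both processes are already running and reachable from the local machine:

# Quick manual check (not part of the commit): query trezord's status page via the proxy.
import httpx

resp = httpx.get("http://127.0.0.1:21326/status/")
print(resp.status_code)
print(resp.text[:200])  # first part of the trezord status page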
