diff --git a/pyproject.toml b/pyproject.toml index 609f5d6..3d9ea29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,9 @@ dependencies = [ "psutil>=5.8.0,<6", "Pillow>=10.0.0,<11", "tvl @ https://github.com/tropicsquare/ts-tvl/releases/download/2.3/tvl-2.3-py3-none-any.whl", + "httpx>=0.28.1,<0.29", + "fastapi>=0.116.1,<0.117", + "uvicorn>=0.35.0,<0.36", ] [dependency-groups] diff --git a/src/bridge_proxy_server.py b/src/bridge_proxy_server.py index 77ddb05..4a688fe 100644 --- a/src/bridge_proxy_server.py +++ b/src/bridge_proxy_server.py @@ -1,105 +1,174 @@ -#!/usr/bin/env python3 - -""" -HTTPServer used as proxy for trezord calls from the outside of docker container -This is workaround for original ip not beeing passed to the container: - https://github.com/docker/for-mac/issues/180 -Listening on port 21326 and routes requests to the trezord with changed Origin header -""" - -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn - -import requests - -import helpers - -TREZORD_HOST = "0.0.0.0:21325" -HEADERS = { - "Host": TREZORD_HOST, - "Origin": "https://user-env.trezor.io", -} -IP = "0.0.0.0" +import asyncio +import logging +import uuid +from contextlib import asynccontextmanager + +import httpx +import uvicorn +from fastapi import FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("async_bridge_proxy") +TREZORD_HOST = "http://0.0.0.0:21325" PORT = 21326 -LOG_COLOR = "green" - -def log(text: str, color: str = LOG_COLOR) -> None: - helpers.log(f"BRIDGE PROXY: {text}", color) +app = FastAPI() + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Track in-flight upstream tasks per session id to allow cancelling previous read when a new call arrives +inflight_tasks: dict[str, asyncio.Task] = {} +inflight_lock = asyncio.Lock() + + +def _prepare_response_headers( + resp: httpx.Response, request: Request, remove_content_length: bool = False +) -> dict: + """Return a headers dict for returning to the client. + + - Copies upstream headers + - Removes Transfer-Encoding always + - Optionally removes Content-Length (for chunked/streaming responses) + """ + headers = dict(resp.headers) + headers.pop("transfer-encoding", None) + if remove_content_length: + headers.pop("content-length", None) + return headers -# POST request headers override -# origin is set to the actual machine that made the call not localhost -def merge_headers(original: dict) -> dict: - headers = original.copy() - headers.update(HEADERS) +def _prepare_upstream_headers(request: Request) -> dict: + """Prepare headers for forwarding to upstream, replacing Origin and removing host.""" + # Build headers dict, excluding 'host' and case-insensitive 'origin' + headers = { + k: v for k, v in request.headers.items() if k.lower() not in ("host", "origin") + } + # Set Origin header from request or use default + headers["Origin"] = request.headers.get("origin", "https://user-env.trezor.io") return headers -class Handler(BaseHTTPRequestHandler): - def do_HEAD(self) -> None: - self.do_GET() - - def do_GET(self) -> None: - try: - if self.path == "/status/": - # read trezord status page - url = f"http://{TREZORD_HOST}{self.path}" - resp = requests.get(url) - - self.send_response(resp.status_code) - self.send_resp_headers(resp) - self.wfile.write(resp.content) - except Exception as e: - self.send_error(404, f"Error trying to proxy: {self.path} Error: {repr(e)}") - - def do_POST(self, body: bool = True) -> None: - try: - url = f"http://{TREZORD_HOST}{self.path}" - data_len = int(self.headers.get("content-length", 0)) - data = self.rfile.read(data_len) - headers = merge_headers(dict(self.headers)) - - resp = requests.post(url, data=data, headers=headers) - - self.send_response(resp.status_code) - self.send_resp_headers(resp) - if body: - self.wfile.write(resp.content) - except Exception as e: - self.send_error(404, f"Error trying to proxy: {self.path} Error: {repr(e)}") - - def send_resp_headers(self, resp) -> None: - # response Access-Control header needs to be exact with original - # request from the caller - self.send_header( - "Access-Control-Allow-Origin", - self.headers.get("Access-Control-Allow-Origin", "*"), - ) - - # remove Access-Control and Transfer-Encoding headers - # from the original trezord response - h = dict(resp.headers) - h.pop( - "Transfer-Encoding", "chunked" - ) # this header returns empty response to the caller (trezor-link) - h.pop("Access-Control-Allow-Origin", None) - for key, value in h.items(): - self.send_header(key, value) - self.end_headers() - - def log_message(self, format, *args) -> None: - """Adds color to make the log clearer.""" - log( - "%s - - [%s] %s\n" - % (self.address_string(), self.log_date_time_string(), format % args), - ) - - -class ThreadingServer(ThreadingMixIn, HTTPServer): - pass +def _create_response_from_upstream(resp: httpx.Response, request: Request) -> Response: + """Create a Response object from an upstream httpx.Response with proper logging.""" + logger.info(f"Upstream responded with status: {resp.status_code}") + logger.info(f"Upstream headers: {dict(resp.headers)}") + headers = _prepare_response_headers(resp, request, remove_content_length=False) + return Response(resp.content, status_code=resp.status_code, headers=headers) + + +async def _cancel_previous_session_task(session_id: str | None, req_id: str): + """Cancel any previous in-flight task for the given session.""" + if session_id is not None: + async with inflight_lock: + old = inflight_tasks.get(session_id) + if old is not None: + logger.info( + f"[{req_id}] Cancelling prior inflight task for session {session_id}" + ) + old.cancel() + + +@asynccontextmanager +async def _manage_session_task(session_id: str | None, task: asyncio.Task): + """Context manager to register and cleanup inflight session tasks.""" + if session_id is not None: + async with inflight_lock: + inflight_tasks[session_id] = task + try: + yield + finally: + task.cancel() + if session_id is not None: + async with inflight_lock: + if inflight_tasks.get(session_id) is task: + del inflight_tasks[session_id] + + +async def _proxy_request( + request: Request, path: str, session_id: str | None = None +) -> Response: + """Proxy a request with optional session tracking.""" + url = f"{TREZORD_HOST}/{path}" + headers = _prepare_upstream_headers(request) + req_id = uuid.uuid4().hex[:8] + logger.info(f"[{req_id}] Proxy received {request.method} request for path: /{path}") + + body = await request.body() if request.method == "POST" else None + if body: + logger.info(f"POST body length: {len(body)}") + logger.info(f"Forwarding {request.method} to upstream: {url}") + + try: + # Session-tracked requests (read/call): race against client disconnect + if session_id is not None: + await _cancel_previous_session_task(session_id, req_id) + + # Use short-lived client without keep-alive for cancellable requests + transport = httpx.AsyncHTTPTransport(retries=0) + async with httpx.AsyncClient(timeout=None, transport=transport) as client: + req_task = asyncio.create_task( + client.request(request.method, url, headers=headers, content=body) + ) + + async with _manage_session_task(session_id, req_task): + # Poll with short timeouts and check client disconnect between polls + while True: + try: + resp = await asyncio.wait_for( + asyncio.shield(req_task), timeout=0.25 + ) + break # upstream finished + except asyncio.TimeoutError: + if await request.is_disconnected(): + logger.info( + "Client disconnected before upstream response; cancelling upstream request" + ) + try: + await req_task + except asyncio.CancelledError: + pass + return Response(status_code=499) + + return _create_response_from_upstream(resp, request) + # Simple requests: no session tracking + else: + async with httpx.AsyncClient(timeout=None) as client: + resp = await client.request( + request.method, url, headers=headers, content=body + ) + return _create_response_from_upstream(resp, request) + except httpx.RequestError as e: + logger.error(f"Error proxying request: {e}") + return JSONResponse({"error": str(e)}, status_code=502) + + +@app.api_route("/read/{session_id}", methods=["GET", "POST"]) +async def proxy_read(request: Request, session_id: str): + """Proxy /read requests with session tracking.""" + return await _proxy_request(request, f"read/{session_id}", session_id=session_id) + + +@app.api_route("/call/{session_id}", methods=["GET", "POST"]) +async def proxy_call(request: Request, session_id: str): + """Proxy /call requests with session tracking.""" + return await _proxy_request(request, f"call/{session_id}", session_id=session_id) + + +@app.api_route("/{path:path}", methods=["GET", "POST"]) +async def proxy_all(request: Request, path: str): + """Proxy all other requests.""" + return await _proxy_request(request, path) if __name__ == "__main__": - SERVER = ThreadingServer((IP, PORT), Handler) - SERVER.serve_forever() + # No reload, no workers, single process only + uvicorn.run("bridge_proxy_server:app", host="0.0.0.0", port=PORT) diff --git a/uv.lock b/uv.lock index 9f9d7c1..48f1b3d 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,20 @@ resolution-markers = [ "python_full_version < '3.13'", ] +[[package]] +name = "anyio" +version = "4.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, +] + [[package]] name = "atomicwrites" version = "1.4.1" @@ -249,6 +263,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, ] +[[package]] +name = "fastapi" +version = "0.116.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/01/64/1296f46d6b9e3b23fb22e5d01af3f104ef411425531376212f1eefa2794d/fastapi-0.116.2.tar.gz", hash = "sha256:231a6af2fe21cfa2c32730170ad8514985fc250bec16c9b242d3b94c835ef529", size = 298595, upload-time = "2025-09-16T18:29:23.058Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/e4/c543271a8018874b7f682bf6156863c416e1334b8ed3e51a69495c5d4360/fastapi-0.116.2-py3-none-any.whl", hash = "sha256:c3a7a8fb830b05f7e087d920e0d786ca1fc9892eb4e9a84b227be4c1bc7569db", size = 95670, upload-time = "2025-09-16T18:29:21.329Z" }, +] + [[package]] name = "flake8" version = "7.2.0" @@ -263,6 +291,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/83/5c/0627be4c9976d56b1217cb5187b7504e7fd7d3503f8bfd312a04077bd4f7/flake8-7.2.0-py2.py3-none-any.whl", hash = "sha256:93b92ba5bdb60754a6da14fa3b93a9361fd00a59632ada61fd7b130436c40343", size = 57786, upload-time = "2025-03-29T20:08:37.902Z" }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -941,6 +1006,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/cb/6547206300861cfbdf8e105ea6e76132550daf48039b7ac9efc9f41f8c27/slip10-1.0.1-py3-none-any.whl", hash = "sha256:4aa764369db0a261e468160ec1afeeb2b22d26392dd118c49b9daa91f642947b", size = 10708, upload-time = "2024-09-03T13:50:38.884Z" }, ] +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "starlette" +version = "0.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/72/2db2f49247d0a18b4f1bb9a5a39a0162869acf235f3a96418363947b3d46/starlette-0.48.0-py3-none-any.whl", hash = "sha256:0764ca97b097582558ecb498132ed0c7d942f233f365b86ba37770e026510659", size = 73736, upload-time = "2025-09-13T08:41:03.869Z" }, +] + [[package]] name = "termcolor" version = "1.1.0" @@ -989,11 +1076,14 @@ name = "trezor-user-env" version = "0.0.0" source = { editable = "." } dependencies = [ + { name = "fastapi" }, + { name = "httpx" }, { name = "pillow" }, { name = "psutil" }, { name = "termcolor" }, { name = "trezor" }, { name = "tvl" }, + { name = "uvicorn" }, { name = "websockets" }, ] @@ -1011,11 +1101,14 @@ dev = [ [package.metadata] requires-dist = [ + { name = "fastapi", specifier = ">=0.116.1,<0.117" }, + { name = "httpx", specifier = ">=0.28.1,<0.29" }, { name = "pillow", specifier = ">=10.0.0,<11" }, { name = "psutil", specifier = ">=5.8.0,<6" }, { name = "termcolor", specifier = ">=1.1.0,<2" }, { name = "trezor", git = "https://github.com/trezor/trezor-firmware.git?subdirectory=python&rev=fc95bb93fcc2e54dce5be63f5b2a94c551c26267" }, { name = "tvl", url = "https://github.com/tropicsquare/ts-tvl/releases/download/2.3/tvl-2.3-py3-none-any.whl" }, + { name = "uvicorn", specifier = ">=0.35.0,<0.36" }, { name = "websockets", specifier = ">=12.0" }, ] @@ -1119,6 +1212,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680, upload-time = "2025-04-10T15:23:37.377Z" }, ] +[[package]] +name = "uvicorn" +version = "0.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" }, +] + [[package]] name = "websockets" version = "15.0.1"