Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/gen_worker/cozy_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,11 @@ def _safe_symlink_dir(target: Path, link: Path) -> None:

@backoff.on_exception(
backoff.expo,
(aiohttp.ClientError, asyncio.TimeoutError, ValueError, OSError, ConnectionError),
max_tries=max(1, int(os.getenv("WORKER_MODEL_DOWNLOAD_MAX_RETRIES", "12") or "12")),
max_time=max(30.0, float(os.getenv("WORKER_MODEL_DOWNLOAD_RETRY_MAX_TIME_S", "900") or "900")),
max_value=max(1.0, float(os.getenv("WORKER_MODEL_DOWNLOAD_BACKOFF_MAX_S", "8") or "8")),
(aiohttp.ClientError, asyncio.TimeoutError, ValueError, OSError),
max_tries=30,
max_time=3600,
factor=1,
max_value=30, # cap backoff at 30s between retries
)
async def _download_one_file(url: str, dst: Path, expected_size: int, expected_blake3: str) -> None:
if dst.exists():
Expand All @@ -319,12 +320,11 @@ async def _download_one_file(url: str, dst: Path, expected_size: int, expected_b
# Fall through to re-download.
pass

timeout = aiohttp.ClientTimeout(
total=None,
connect=float(os.getenv("WORKER_MODEL_DOWNLOAD_CONNECT_TIMEOUT_S", "60") or "60"),
sock_connect=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_CONNECT_TIMEOUT_S", "60") or "60"),
sock_read=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_READ_TIMEOUT_S", "120") or "120"),
)
# Use sock_read instead of total timeout so actively-streaming large files
# are not killed. total=None lets multi-GB downloads run as long as data
# keeps flowing; sock_read=120 catches genuine stalls.
timeout = aiohttp.ClientTimeout(total=None, sock_connect=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_CONNECT_TIMEOUT_S", "60")),
sock_read=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_READ_TIMEOUT_S", "120")))
tmp = dst.with_suffix(dst.suffix + ".part")
# If we have a partial file, try to resume via HTTP Range.
offset = 0
Expand Down
18 changes: 18 additions & 0 deletions src/gen_worker/cozy_snapshot_v2_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,24 @@ def _get_lock(self, mp: Dict[str, asyncio.Lock], key: str) -> asyncio.Lock:
return lock


async def ensure_snapshot_async(
    *,
    base_dir: Path,
    ref: CozyRef,
    base_url: str,
    token: Optional[str],
    resolved: Optional[Any] = None,
) -> Path:
    """Async version of ensure_snapshot_sync for use in async contexts."""
    hub_client: Optional[CozyHubV2Client] = None
    if resolved is None:
        # No pre-resolved manifest was supplied, so we must contact the hub;
        # that requires a non-empty base URL to be configured.
        if not (base_url or "").strip():
            raise RuntimeError("cozy downloads require TENSORHUB_URL")
        hub_client = CozyHubV2Client(base_url=base_url, token=token)
    downloader = CozySnapshotV2Downloader(hub_client)
    return await downloader.ensure_snapshot(base_dir, ref, resolved=resolved)


def ensure_snapshot_sync(
*,
base_dir: Path,
Expand Down
10 changes: 9 additions & 1 deletion src/gen_worker/diffusers_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ async def load_model_into_vram(self, model_id: str) -> bool:
local_path: Optional[str] = None
if self._downloader is not None:
from .cache_paths import worker_model_cache_dir
from .model_refs import parse_model_ref
from pathlib import Path

cache_dir = str(worker_model_cache_dir())
try:
local_path = self._downloader.download(model_id, cache_dir)
# Use async download path directly to avoid nested event loop issues.
if hasattr(self._downloader, '_download_async'):
parsed = parse_model_ref(model_id)
result = await self._downloader._download_async(parsed, Path(cache_dir))
local_path = result.as_posix()
else:
local_path = self._downloader.download(model_id, cache_dir)
except Exception as e:
logger.warning("DiffusersModelManager: download failed for %s: %s", model_id, e)

Expand Down
17 changes: 9 additions & 8 deletions src/gen_worker/model_ref_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time

from .cozy_cas import CozyHubClient, CozySnapshotDownloader
from .cozy_snapshot_v2_downloader import ensure_snapshot_sync
from .cozy_snapshot_v2_downloader import ensure_snapshot_async, ensure_snapshot_sync
from .downloader import ModelDownloader
from .tensorhub_v2 import (
CozyHubError,
Expand All @@ -19,6 +19,7 @@
)
from .hf_downloader import HuggingFaceHubDownloader
from .model_refs import CozyRef, ParsedModelRef, parse_model_ref
import threading

# Per-task resolved manifests provided by gen-orchestrator (issue #92).
# Shape: {canonical_model_id: ResolvedCozyModel-like object}
Expand Down Expand Up @@ -134,7 +135,7 @@ async def _download_async(self, parsed: ParsedModelRef, dest_dir: Path) -> Path:
canonical = parsed.hf.canonical()
prefs = _get_prefs_for_ref(canonical)
resolved_artifact = await self._request_public_model_with_wait(canonical, prefs=prefs)
return ensure_snapshot_sync(
return await ensure_snapshot_async(
base_dir=dest_dir,
ref=CozyRef(owner="public", repo="public", tag="latest"),
base_url=self._cozy_base_url or "",
Expand All @@ -151,7 +152,7 @@ async def _download_async(self, parsed: ParsedModelRef, dest_dir: Path) -> Path:
resolved_entry = _lookup_resolved_cozy_entry(resolved_mapping, canonical)

if resolved_entry is not None:
return ensure_snapshot_sync(
return await ensure_snapshot_async(
base_dir=dest_dir,
ref=parsed.cozy,
base_url=self._cozy_base_url or "",
Expand All @@ -164,7 +165,7 @@ async def _download_async(self, parsed: ParsedModelRef, dest_dir: Path) -> Path:
if self._cozy_v2 is not None and parsed.cozy.digest is None:
prefs = _get_prefs_for_ref(canonical)
resolved = await self._request_public_model_with_wait(canonical, prefs=prefs)
return ensure_snapshot_sync(
return await ensure_snapshot_async(
base_dir=dest_dir,
ref=parsed.cozy,
base_url=self._cozy_base_url or "",
Expand All @@ -181,7 +182,7 @@ async def _download_async(self, parsed: ParsedModelRef, dest_dir: Path) -> Path:

# Prefer Cozy Hub v2 resolve flow.
try:
return ensure_snapshot_sync(
return await ensure_snapshot_async(
base_dir=dest_dir,
ref=parsed.cozy,
base_url=self._cozy_base_url,
Expand Down Expand Up @@ -291,14 +292,14 @@ def _run_in_thread(coro: Coroutine[Any, Any, Path]) -> str:
out: dict[str, str] = {}
err: dict[str, BaseException] = {}

ctx = contextvars.copy_context()

def runner() -> None:
try:
out["v"] = asyncio.run(coro).as_posix()
out["v"] = ctx.run(asyncio.run, coro).as_posix()
except BaseException as e:
err["e"] = e

import threading

t = threading.Thread(target=runner, daemon=True)
t.start()
t.join()
Expand Down
32 changes: 16 additions & 16 deletions src/gen_worker/pb/frontend_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 14 additions & 14 deletions src/gen_worker/pb/frontend_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def __init__(self, channel):
request_serializer=frontend__pb2.ExecuteActionRequest.SerializeToString,
response_deserializer=frontend__pb2.ExecuteActionResponse.FromString,
_registered_method=True)
self.CancelRequest = channel.unary_unary(
'/frontend.v1.FrontendService/CancelRequest',
request_serializer=frontend__pb2.CancelRequestRequest.SerializeToString,
response_deserializer=frontend__pb2.CancelRequestResponse.FromString,
self.CancelRun = channel.unary_unary(
'/frontend.v1.FrontendService/CancelRun',
request_serializer=frontend__pb2.CancelRunRequest.SerializeToString,
response_deserializer=frontend__pb2.CancelRunResponse.FromString,
_registered_method=True)
self.RealtimeSession = channel.stream_stream(
'/frontend.v1.FrontendService/RealtimeSession',
Expand All @@ -63,8 +63,8 @@ def ExecuteAction(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def CancelRequest(self, request, context):
"""2) Cancel an in-flight request.
def CancelRun(self, request, context):
"""2) Cancel an in-flight action/job.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
Expand All @@ -85,10 +85,10 @@ def add_FrontendServiceServicer_to_server(servicer, server):
request_deserializer=frontend__pb2.ExecuteActionRequest.FromString,
response_serializer=frontend__pb2.ExecuteActionResponse.SerializeToString,
),
'CancelRequest': grpc.unary_unary_rpc_method_handler(
servicer.CancelRequest,
request_deserializer=frontend__pb2.CancelRequestRequest.FromString,
response_serializer=frontend__pb2.CancelRequestResponse.SerializeToString,
'CancelRun': grpc.unary_unary_rpc_method_handler(
servicer.CancelRun,
request_deserializer=frontend__pb2.CancelRunRequest.FromString,
response_serializer=frontend__pb2.CancelRunResponse.SerializeToString,
),
'RealtimeSession': grpc.stream_stream_rpc_method_handler(
servicer.RealtimeSession,
Expand Down Expand Up @@ -135,7 +135,7 @@ def ExecuteAction(request,
_registered_method=True)

@staticmethod
def CancelRequest(request,
def CancelRun(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -148,9 +148,9 @@ def CancelRequest(request,
return grpc.experimental.unary_unary(
request,
target,
'/frontend.v1.FrontendService/CancelRequest',
frontend__pb2.CancelRequestRequest.SerializeToString,
frontend__pb2.CancelRequestResponse.FromString,
'/frontend.v1.FrontendService/CancelRun',
frontend__pb2.CancelRunRequest.SerializeToString,
frontend__pb2.CancelRunResponse.FromString,
options,
channel_credentials,
insecure,
Expand Down
Loading