Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 132 additions & 65 deletions src/gen_worker/cozy_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,16 @@ def _safe_symlink_dir(target: Path, link: Path) -> None:
max_value=30, # cap backoff at 30s between retries
)
async def _download_one_file(url: str, dst: Path, expected_size: int, expected_blake3: str) -> None:
import fcntl
import logging
log = logging.getLogger("gen_worker.download")

log.info("download_start path=%s expected_size=%s expected_blake3=%s", dst.name, expected_size, (expected_blake3 or "")[:16])
print(f"DEBUG download_start path={dst.name} expected_size={expected_size} expected_blake3={(expected_blake3 or '')[:16]}")

if dst.exists():
log.info("dst_exists path=%s size=%s", dst, dst.stat().st_size)
print(f"DEBUG dst_exists path={dst} size={dst.stat().st_size}")
try:
if expected_size and dst.stat().st_size != expected_size:
raise ValueError("size mismatch")
Expand All @@ -326,73 +335,131 @@ async def _download_one_file(url: str, dst: Path, expected_size: int, expected_b
timeout = aiohttp.ClientTimeout(total=None, sock_connect=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_CONNECT_TIMEOUT_S", "60")),
sock_read=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_READ_TIMEOUT_S", "180")))
tmp = dst.with_suffix(dst.suffix + ".part")
# If we have a partial file, try to resume via HTTP Range.
offset = 0
if tmp.exists():
try:
offset = tmp.stat().st_size
except Exception:
offset = 0
if expected_size and offset > expected_size:
tmp.unlink(missing_ok=True)
offset = 0

# If the partial file is already complete, validate + finalize.
if offset and expected_size and offset == expected_size:
got = _blake3_file(tmp)
if expected_blake3 and got.lower() != expected_blake3.lower():
tmp.unlink(missing_ok=True)
else:
tmp.rename(dst)
return
lock_path = dst.with_suffix(dst.suffix + ".lock")

headers: Dict[str, str] = {}
mode = "wb"
if offset and expected_size:
headers["Range"] = f"bytes={offset}-"
mode = "ab"

async def _stream_to_file(resp: aiohttp.ClientResponse, *, mode: str, start: int) -> None:
nonlocal expected_size
size = start
with open(tmp, mode) as f:
async for chunk in resp.content.iter_chunked(1 << 20):
if not chunk:
continue
f.write(chunk)
size += len(chunk)
if expected_size and size > expected_size:
raise ValueError("download exceeded expected size")

async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as resp:
# If the server ignored our Range request, restart from scratch to avoid
# duplicating bytes by appending a full response.
# Some gateways can return 206 with an unexpected range start.
# Treat that the same as a 200-on-resume and restart from byte 0.
if offset and (
resp.status == 200
or (
resp.status == 206
and not str(resp.headers.get("Content-Range") or "").strip().startswith(f"bytes {offset}-")
)
):
resp.release()
async with session.get(url) as resp2:
resp2.raise_for_status()
await _stream_to_file(resp2, mode="wb", start=0)
# File-level exclusive lock: prevents concurrent writes to the same .part
# file even from different async tasks or downloader instances.
lock_path.parent.mkdir(parents=True, exist_ok=True)
lock_fd = open(lock_path, "w")
try:
fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
print(f"DEBUG file_lock_acquired path={dst.name}")

# Re-check dst after acquiring the lock — another holder might have
# already completed the download while we waited.
if dst.exists():
try:
if expected_size and dst.stat().st_size != expected_size:
raise ValueError("size mismatch after lock")
if expected_blake3:
got = _blake3_file(dst)
if got.lower() != expected_blake3.lower():
raise ValueError("blake3 mismatch after lock")
print(f"DEBUG file_lock_dst_completed path={dst.name} (another writer finished)")
return
except Exception:
pass

# If we have a partial file, try to resume via HTTP Range.
offset = 0
if tmp.exists():
try:
offset = tmp.stat().st_size
except OSError:
# Another coroutine may have renamed tmp→dst between the exists() check and stat().
offset = 0
if offset:
log.info("resume_attempt path=%s offset=%s expected_size=%s", dst.name, offset, expected_size)
print(f"DEBUG resume_attempt path={dst.name} offset={offset} expected_size={expected_size}")
if expected_size and offset > expected_size:
tmp.unlink(missing_ok=True)
offset = 0

# If the partial file is already complete, validate + finalize.
if offset and expected_size and offset == expected_size:
got = _blake3_file(tmp)
if expected_blake3 and got.lower() != expected_blake3.lower():
tmp.unlink(missing_ok=True)
else:
resp.raise_for_status()
await _stream_to_file(resp, mode=mode, start=offset)

# Validate final file.
if expected_size and tmp.stat().st_size != expected_size:
raise ValueError(f"size mismatch (expected {expected_size}, got {tmp.stat().st_size})")
if expected_blake3:
got = _blake3_file(tmp)
if got.lower() != expected_blake3.lower():
raise ValueError("blake3 mismatch")
tmp.rename(dst)
tmp.rename(dst)
return

headers: Dict[str, str] = {}
mode = "wb"
if offset and expected_size:
headers["Range"] = f"bytes={offset}-"
mode = "ab"
print(f"DEBUG range_header path={dst.name} Range=bytes={offset}- mode={mode}")

async def _stream_to_file(resp: aiohttp.ClientResponse, *, mode: str, start: int) -> None:
nonlocal expected_size
size = start
with open(tmp, mode) as f:
async for chunk in resp.content.iter_chunked(1 << 20):
if not chunk:
continue
f.write(chunk)
size += len(chunk)
if expected_size and size > expected_size:
raise ValueError("download exceeded expected size")

async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as resp:
content_range = str(resp.headers.get("Content-Range") or "").strip()
print(f"DEBUG http_response path={dst.name} status={resp.status} content_range={content_range!r} content_length={resp.headers.get('Content-Length', 'unknown')} offset={offset}")
# If the server ignored our Range request, restart from scratch to avoid
# duplicating bytes by appending a full response.
# Some gateways can return 206 with an unexpected range start.
# Treat that the same as a 200-on-resume and restart from byte 0.
if offset and (
resp.status == 200
or (
resp.status == 206
and not content_range.startswith(f"bytes {offset}-")
)
):
print(f"DEBUG range_ignored path={dst.name} status={resp.status} content_range={content_range!r} restarting_from_zero=True")
resp.release()
async with session.get(url) as resp2:
resp2.raise_for_status()
print(f"DEBUG range_restart path={dst.name} status={resp2.status} content_length={resp2.headers.get('Content-Length', 'unknown')}")
await _stream_to_file(resp2, mode="wb", start=0)
else:
resp.raise_for_status()
await _stream_to_file(resp, mode=mode, start=offset)

# Validate final file.
actual_size = tmp.stat().st_size
log.info("download_complete path=%s actual_size=%s expected_size=%s", dst.name, actual_size, expected_size)
print(f"DEBUG download_complete path={dst.name} actual_size={actual_size} expected_size={expected_size}")
if expected_size and actual_size != expected_size:
log.error("size_mismatch path=%s expected=%s got=%s url=%s", dst.name, expected_size, actual_size, url[:80])
print(f"DEBUG size_mismatch path={dst.name} expected={expected_size} got={actual_size} url={url[:80]}")
tmp.unlink(missing_ok=True)
raise ValueError(f"size mismatch (expected {expected_size}, got {actual_size})")
if expected_blake3:
got = _blake3_file(tmp)
log.info("blake3_check path=%s expected=%s got=%s", dst.name, (expected_blake3 or "")[:16], got[:16])
print(f"DEBUG blake3_check path={dst.name} expected={(expected_blake3 or '')[:16]} got={got[:16]}")
if got.lower() != expected_blake3.lower():
log.error("blake3_mismatch path=%s", dst.name)
print(f"DEBUG blake3_mismatch path={dst.name}")
tmp.unlink(missing_ok=True)
raise ValueError("blake3 mismatch")
# A concurrent coroutine may have already renamed tmp→dst (won the race).
# Use an atomic replace so we don't fail if dst now exists.
try:
tmp.replace(dst)
except OSError:
# dst was created by another coroutine; .part is stale, just remove it.
tmp.unlink(missing_ok=True)
finally:
fcntl.flock(lock_fd.fileno(), fcntl.LOCK_UN)
lock_fd.close()
try:
lock_path.unlink(missing_ok=True)
except OSError:
pass


def _blake3_file(path: Path, chunk_size: int = 1 << 20) -> str:
Expand Down
57 changes: 43 additions & 14 deletions src/gen_worker/cozy_pipeline_spec.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from __future__ import annotations

import json
import logging
import os
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import yaml

logger = logging.getLogger(__name__)

COZY_PIPELINE_LOCK_FILENAME = "cozy.pipeline.lock.yaml"
COZY_PIPELINE_FILENAME = "cozy.pipeline.yaml"
PIPELINE_LOCK_TOML_FILENAME = "pipeline.lock"
PIPELINE_TOML_FILENAME = "pipeline.toml"
DIFFUSERS_MODEL_INDEX_FILENAME = "model_index.json"


Expand All @@ -33,6 +38,13 @@ def custom_pipeline_path(self) -> Optional[str]:
s = str(v).strip()
return s or None

@property
def variant(self) -> Optional[str]:
    """Diffusers variant (e.g. 'fp16', 'fp8') from the pipeline spec.

    Reads ``raw["pipe"]["variant"]``; returns None when the key is
    absent, empty, or whitespace-only.
    """
    pipe_section = self.raw.get("pipe") or {}
    value = str(pipe_section.get("variant") or "").strip()
    if value:
        return value
    return None


def _safe_child_path(root: Path, rel: str) -> Path:
# Ensure rel doesn't escape root (best-effort).
Expand All @@ -50,24 +62,41 @@ def load_cozy_pipeline_spec(model_root: Path) -> Optional[CozyPipelineSpec]:
This is a worker-side helper used during pipeline loading to implement:
- prefer `cozy.pipeline.lock.yaml` when present
- fall back to `cozy.pipeline.yaml` otherwise
- fall back to `pipeline.lock` / `pipeline.toml` (TOML) if no YAML found
"""
root = Path(model_root)
lock_path = root / COZY_PIPELINE_LOCK_FILENAME
spec_path = lock_path if lock_path.exists() else (root / COZY_PIPELINE_FILENAME)
if not spec_path.exists():
return None

raw = yaml.safe_load(spec_path.read_text(encoding="utf-8"))
if not isinstance(raw, dict):
raise ValueError("invalid cozy pipeline spec (expected mapping)")
api = str(raw.get("apiVersion") or "").strip()
kind = str(raw.get("kind") or "").strip()
if api and api != "v1":
raise ValueError(f"unsupported cozy pipeline apiVersion: {api!r}")
if kind and kind != "DiffusersPipeline":
raise ValueError(f"unsupported cozy pipeline kind: {kind!r}")

return CozyPipelineSpec(source_path=spec_path, raw=raw)
if spec_path.exists():
raw = yaml.safe_load(spec_path.read_text(encoding="utf-8"))
if not isinstance(raw, dict):
raise ValueError("invalid cozy pipeline spec (expected mapping)")
api = str(raw.get("apiVersion") or "").strip()
kind = str(raw.get("kind") or "").strip()
if api and api != "v1":
raise ValueError(f"unsupported cozy pipeline apiVersion: {api!r}")
if kind and kind != "DiffusersPipeline":
raise ValueError(f"unsupported cozy pipeline kind: {kind!r}")
logger.info("DEBUG loaded cozy pipeline spec from %s", spec_path.name)
return CozyPipelineSpec(source_path=spec_path, raw=raw)

# Fallback: read pipeline.lock / pipeline.toml (TOML format, stored by tensorhub ingest).
toml_lock = root / PIPELINE_LOCK_TOML_FILENAME
toml_spec = toml_lock if toml_lock.exists() else (root / PIPELINE_TOML_FILENAME)
if toml_spec.exists():
raw = tomllib.loads(toml_spec.read_text(encoding="utf-8"))
if not isinstance(raw, dict):
raise ValueError("invalid pipeline toml (expected mapping)")
api = str(raw.get("apiVersion") or "").strip()
kind = str(raw.get("kind") or "").strip()
if api and api != "v1":
raise ValueError(f"unsupported pipeline toml apiVersion: {api!r}")
if kind and kind != "DiffusersPipeline":
raise ValueError(f"unsupported pipeline toml kind: {kind!r}")
logger.info("DEBUG loaded cozy pipeline spec from %s (toml fallback)", toml_spec.name)
return CozyPipelineSpec(source_path=toml_spec, raw=raw)

return None


def cozy_custom_pipeline_arg(model_root: Path, spec: CozyPipelineSpec) -> Optional[str]:
Expand Down
Loading
Loading