Skip to content
13 changes: 11 additions & 2 deletions miles/rollout/rm_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from miles.utils.misc import load_function
from miles.utils.types import Sample

from .ocr import ocr_rm


def _resolve_rm_type(args, sample: Sample) -> str:
metadata = sample.metadata if isinstance(sample.metadata, dict) else {}
Expand Down Expand Up @@ -38,7 +36,13 @@ async def async_rm(args, sample: Sample, **kwargs):
elif rm_type == "random":
return random.randint(0, 1)
elif rm_type == "ocr":
from .ocr import ocr_rm

return await ocr_rm(args, sample)
elif rm_type == "hps":
from .hps import hps_rm

return (await hps_rm(args, [sample]))[0]
elif rm_type:
raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.")
else:
Expand All @@ -55,6 +59,11 @@ async def batched_async_rm(
rm_function = load_function(args.custom_rm_path)
return await rm_function(args, samples, **kwargs)

if samples and all(_resolve_rm_type(args, sample) == "hps" for sample in samples):
from .hps import hps_rm

return await hps_rm(args, samples)

tasks = [async_rm(args, sample, **kwargs) for sample in samples]
rewards = await asyncio.gather(*tasks)
return rewards
186 changes: 186 additions & 0 deletions miles/rollout/rm_hub/hps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from __future__ import annotations

import asyncio
import logging
from collections.abc import Sequence

import numpy as np
import ray
import torch
from PIL import Image

from miles.utils.misc import SingletonMeta
from miles.utils.types import Sample

logger = logging.getLogger(__name__)


def _sample_to_rgb_hwc_uint8(sample: Sample) -> np.ndarray:
    """Convert the first frame of ``sample.generated_output`` to an HWC uint8 image.

    ``generated_output`` is indexed as ``[C, F, H, W]``; only frame 0 is used.

    Value-range handling:
      * Integer tensors are taken to already hold 0-255 pixel values and are
        never rescaled (a nearly black uint8 frame whose maximum is 1 must not
        be blown up to 255).
      * Floating-point and bool tensors whose maximum is <= 1 (plus a small
        tolerance) are treated as normalized and multiplied by 255.
      * Any other floating-point tensor is assumed to be on the 0-255 scale.

    Values are rounded to the nearest integer and clipped to [0, 255] before
    the uint8 cast, on both paths.

    Args:
        sample: object exposing a 4D tensor as ``generated_output``.

    Returns:
        A C-contiguous ``uint8`` array of shape ``(H, W, C)``.

    Raises:
        ValueError: if ``generated_output`` is None, is not 4D, or has no elements.
    """
    t = sample.generated_output
    if t is None:
        raise ValueError("generated_output is None")
    if t.ndim != 4:
        raise ValueError(f"generated_output must be 4D [C, F, H, W], got {tuple(t.shape)}")
    if t.numel() == 0:
        # Fail early with a clear message instead of an IndexError / NumPy
        # zero-size reduction error further down.
        raise ValueError(f"generated_output must be non-empty, got shape {tuple(t.shape)}")

    frame_chw = t.detach().cpu()[:, 0, :, :]
    # .float() also covers half / bfloat16 inputs, which NumPy cannot represent directly.
    hwc = frame_chw.float().numpy().transpose(1, 2, 0)

    # Only float/bool inputs can be "normalized"; integer dtypes are raw pixel values.
    may_be_normalized = t.is_floating_point() or t.dtype == torch.bool
    # NOTE(review): inputs normalized to [-1, 1] would have their negative half
    # clipped to 0 here - confirm the rollout always emits [0, 1] or [0, 255].
    if may_be_normalized and float(hwc.max()) <= 1.0 + 1e-3:
        hwc = hwc * 255.0
    return np.ascontiguousarray(np.round(hwc).clip(0, 255).astype(np.uint8))


class HPSScorer(torch.nn.Module):
    """HPS / HPSv2.1 reward scorer.

    Loads the ViT-H-14 backbone via hpsv2's own ``create_model_and_transforms``
    (so the preprocessing pipeline matches the published checkpoint exactly),
    then patches in the HPS preference-tuned weights. Scoring returns the raw
    image/text logit diagonal — the same scalar hpsv2.score() returns.
    """

    def __init__(
        self,
        *,
        device: str = "cuda",
        hps_version: str = "v2.1",
        checkpoint_path: str | None = None,
    ) -> None:
        """Build the model and load the HPS weights.

        Args:
            device: torch device string the model and inputs are placed on.
            hps_version: key into hpsv2's ``hps_version_map``; only used to pick
                the file to download when ``checkpoint_path`` is None.
            checkpoint_path: local checkpoint file. When None, the checkpoint is
                fetched from the ``xswu/HPSv2`` Hugging Face repository.
        """
        super().__init__()
        # Imported lazily so that importing this module does not require hpsv2.
        import huggingface_hub
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
        from hpsv2.utils import hps_version_map

        self.device = torch.device(device)
        self.hps_version = hps_version

        model, _, preprocess_val = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision="amp",
            device=str(self.device),
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )

        if checkpoint_path is None:
            checkpoint_path = huggingface_hub.hf_hub_download(
                "xswu/HPSv2", hps_version_map[hps_version]
            )
        # NOTE(security): torch.load unpickles the file. Only point
        # ``checkpoint_path`` at checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=str(self.device))
        model.load_state_dict(checkpoint["state_dict"])
        model.to(self.device).eval()

        self.model = model
        self.preprocess = preprocess_val
        self.tokenizer = get_tokenizer("ViT-H-14")

    @torch.no_grad()
    def forward(self, prompts: Sequence[str], images: Sequence[Image.Image]) -> list[float]:
        """Score each ``(prompt, image)`` pair.

        Args:
            prompts: text prompts, one per image.
            images: images, aligned index-by-index with ``prompts``.

        Returns:
            One float per pair, in input order; ``[]`` for empty input.

        Raises:
            ValueError: if ``prompts`` and ``images`` differ in length. Without
                this check the logit matrix would be non-square and its diagonal
                would silently yield fewer scores than inputs.
        """
        prompts = list(prompts)
        images = list(images)
        if len(prompts) != len(images):
            raise ValueError(
                f"prompts and images must have the same length, got {len(prompts)} and {len(images)}"
            )
        if not prompts:
            return []

        image_batch = torch.stack([self.preprocess(img) for img in images]).to(
            self.device, non_blocking=True
        )
        text_batch = self.tokenizer(prompts).to(self.device, non_blocking=True)

        # Autocast is only enabled on CUDA; other devices run without it.
        with torch.amp.autocast(self.device.type, enabled=self.device.type == "cuda"):
            outputs = self.model(image_batch, text_batch)
            image_features = outputs["image_features"]
            text_features = outputs["text_features"]
            logits = image_features @ text_features.T
            # Entry (i, i) is the score of image i against its own prompt i.
            scores = torch.diagonal(logits)

        return [float(score) for score in scores.detach().float().cpu()]


@ray.remote
class HPSRewardActor:
    """Ray actor that owns one ``HPSScorer`` and scores image/prompt batches."""

    def __init__(
        self,
        *,
        hps_version: str,
        checkpoint_path: str | None = None,
    ) -> None:
        """Create the scorer on CUDA when Ray assigned a GPU and CUDA is usable, else on CPU."""
        if ray.get_gpu_ids() and torch.cuda.is_available():
            # NOTE(review): index 0 presumably refers to the GPU Ray made
            # visible to this actor - confirm for the deployment in use.
            torch.cuda.set_device(0)
            target_device = "cuda"
        else:
            target_device = "cpu"

        self.scorer = HPSScorer(
            device=target_device,
            hps_version=hps_version,
            checkpoint_path=checkpoint_path,
        )

    def score_batch(self, images: list[np.ndarray], prompts: list[str]) -> list[float]:
        """Wrap each array in a PIL image and return the scorer's result for the batch."""
        as_pil = list(map(Image.fromarray, images))
        return self.scorer(prompts, as_pil)


class AsyncHPSPool(metaclass=SingletonMeta):
    """Ray actor pool for GPU HPS reward inference.

    Batches are handed to the actors in round-robin order. The class uses
    ``SingletonMeta``, so presumably only the first construction's ``args``
    take effect within a process - verify against ``SingletonMeta``.
    """

    def __init__(self, args) -> None:
        """Spawn ``args.hps_num_workers`` ``HPSRewardActor`` instances.

        Reads ``hps_num_workers``, ``hps_num_gpus_per_worker``,
        ``hps_batch_size``, ``hps_version`` and ``hps_checkpoint_path`` from ``args``.

        Raises:
            ValueError: if the worker count or the batch size is not positive.
        """
        num_workers = args.hps_num_workers
        num_gpus_per_worker = args.hps_num_gpus_per_worker
        if num_workers <= 0:
            raise ValueError("--hps-num-workers must be positive")
        if args.hps_batch_size <= 0:
            raise ValueError("--hps-batch-size must be positive")

        self._batch_size = args.hps_batch_size
        self._actors = [
            HPSRewardActor.options(
                num_cpus=1,
                num_gpus=num_gpus_per_worker,
                scheduling_strategy="DEFAULT",
            ).remote(
                hps_version=args.hps_version,
                checkpoint_path=args.hps_checkpoint_path,
            )
            for _ in range(num_workers)
        ]
        self._round_robin_index = 0
        logger.info(
            "Initialized HPS actor pool with %d workers, %.3f GPUs/worker, batch_size=%d, version=%s.",
            num_workers,
            num_gpus_per_worker,
            self._batch_size,
            args.hps_version,
        )

    def _next_actor(self):
        """Return the next actor in round-robin order and advance the cursor."""
        i = self._round_robin_index % len(self._actors)
        self._round_robin_index += 1
        return self._actors[i]

    async def score(self, images: list[np.ndarray], prompts: list[str]) -> list[float]:
        """Score ``images`` against ``prompts`` and return one float per pair, in order.

        The inputs are split into chunks of at most ``hps_batch_size`` items and
        each chunk is dispatched to the next actor.

        Raises:
            ValueError: if ``images`` and ``prompts`` differ in length. The two
                lists are sliced independently, so a mismatch would otherwise
                ship misaligned batches to the actors.
        """
        if len(images) != len(prompts):
            raise ValueError(
                f"images and prompts must have the same length, got {len(images)} and {len(prompts)}"
            )
        if not images:
            return []

        refs = []
        for start in range(0, len(images), self._batch_size):
            end = start + self._batch_size
            refs.append(self._next_actor().score_batch.remote(images[start:end], prompts[start:end]))

        # ray.get blocks, so run it in the default executor to keep the event loop free.
        loop = asyncio.get_running_loop()
        chunked_scores = await loop.run_in_executor(None, ray.get, refs)
        # ray.get on a list returns results in the order of ``refs``, so
        # flattening restores the input order.
        return [float(score) for chunk in chunked_scores for score in chunk]


async def hps_rm(args, samples: Sequence[Sample]) -> list[float]:
    """Return one HPS reward per sample, in the order of ``samples``.

    Each sample's generated output is converted to an HWC uint8 image and
    scored against ``sample.prompt`` through the ``AsyncHPSPool`` built from ``args``.
    """
    # Obtain the pool first so configuration errors surface before any conversion work.
    scoring_pool = AsyncHPSPool(args)
    frames = list(map(_sample_to_rgb_hwc_uint8, samples))
    texts = [item.prompt for item in samples]
    rewards = await scoring_pool.score(frames, texts)
    return rewards
31 changes: 31 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,37 @@ def add_reward_model_arguments(parser):
default=4,
help="Number of Ray OCR actors used when --rm-type ocr.",
)
parser.add_argument(
"--hps-num-workers",
type=int,
default=1,
help="Number of Ray HPS actors used when --rm-type hps.",
)
parser.add_argument(
"--hps-num-gpus-per-worker",
type=float,
default=1.0,
help="GPU resources per HPS actor. Use 1.0 for a dedicated GPU smoke test.",
)
parser.add_argument(
"--hps-batch-size",
type=int,
default=8,
help="Batch size per HPS actor call.",
)
parser.add_argument(
"--hps-version",
type=str,
default="v2.1",
choices=["v2.0", "v2.1"],
help="HPS checkpoint version to use when --rm-type hps.",
)
parser.add_argument(
"--hps-checkpoint-path",
type=str,
default=None,
help="Optional local HPS checkpoint path. If unset, the checkpoint is downloaded from Hugging Face.",
)
parser.add_argument(
"--custom-rm-path",
type=str,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
accelerate
blobfile
datasets
hpsv2
httpx[http2]
mcp[cli]
memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml
Expand Down
97 changes: 97 additions & 0 deletions scripts/run-diffusion-grpo-hps-smoke.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env bash
#
# Smoke test: one diffusion GRPO rollout/train step with the HPS reward model.
#
# Environment overrides (defaults in parentheses):
#   PYTHON_BIN                      Python interpreter (/usr/bin/python3)
#   SMOKE_COLOCATE                  1/true/yes passes --colocate (1)
#   SMOKE_ACTOR_GPUS_PER_NODE       --actor-num-gpus-per-node (2)
#   SMOKE_ROLLOUT_GPUS              --rollout-num-gpus (2)
#   SMOKE_ROLLOUT_GPUS_PER_ENGINE   --rollout-num-gpus-per-engine (1)
#   SMOKE_NUM_GPUS_PER_NODE         --num-gpus-per-node (3 colocated, 5 otherwise)
#   SMOKE_HPS_VERSION               --hps-version (v2.1)
#   SMOKE_HPS_NUM_WORKERS           --hps-num-workers (1)
#   SMOKE_HPS_NUM_GPUS_PER_WORKER   --hps-num-gpus-per-worker (1)
#   SMOKE_HPS_BATCH_SIZE            --hps-batch-size (8)
#   CUDA_VISIBLE_DEVICES            GPUs to expose (4,5,6 colocated, 1,2,3,4,5 otherwise)
#   WANDB_API_KEY                   when set, enables Weights & Biases logging

set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
PYTHON_BIN="${PYTHON_BIN:-/usr/bin/python3}"

SMOKE_COLOCATE="${SMOKE_COLOCATE:-1}"
SMOKE_ACTOR_GPUS_PER_NODE="${SMOKE_ACTOR_GPUS_PER_NODE:-2}"
SMOKE_ROLLOUT_GPUS="${SMOKE_ROLLOUT_GPUS:-2}"
SMOKE_ROLLOUT_GPUS_PER_ENGINE="${SMOKE_ROLLOUT_GPUS_PER_ENGINE:-1}"
SMOKE_HPS_VERSION="${SMOKE_HPS_VERSION:-v2.1}"
SMOKE_HPS_NUM_WORKERS="${SMOKE_HPS_NUM_WORKERS:-1}"
SMOKE_HPS_NUM_GPUS_PER_WORKER="${SMOKE_HPS_NUM_GPUS_PER_WORKER:-1}"
SMOKE_HPS_BATCH_SIZE="${SMOKE_HPS_BATCH_SIZE:-8}"

COLOCATE_ARGS=()
if [[ "${SMOKE_COLOCATE}" == "1" || "${SMOKE_COLOCATE}" == "true" || "${SMOKE_COLOCATE}" == "yes" ]]; then
    # Use two colocated train/rollout GPUs plus one dedicated HPS reward GPU.
    DEFAULT_CUDA_VISIBLE_DEVICES="4,5,6"
    DEFAULT_NUM_GPUS_PER_NODE="3"
    COLOCATE_ARGS+=(--colocate)
else
    # Use two train GPUs, two rollout GPUs, and one dedicated HPS reward GPU.
    DEFAULT_CUDA_VISIBLE_DEVICES="1,2,3,4,5"
    DEFAULT_NUM_GPUS_PER_NODE="5"
fi

SMOKE_NUM_GPUS_PER_NODE="${SMOKE_NUM_GPUS_PER_NODE:-${DEFAULT_NUM_GPUS_PER_NODE}}"

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${DEFAULT_CUDA_VISIBLE_DEVICES}}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

RUN_NAME="diffusion_grpo_hps_smoke_$(date +%Y%m%d_%H%M%S)"

WANDB_ARGS=()
if [[ -n "${WANDB_API_KEY:-}" ]]; then
    WANDB_ARGS+=(
        --use-wandb
        --wandb-project miles-diffusion-grpo
        --wandb-group "${RUN_NAME}"
        --wandb-key "${WANDB_API_KEY}"
        --diffusion-log-images 4
        --diffusion-log-image-interval 1
        --disable-wandb-random-suffix
    )
fi

# The smoke test uses the JSONL written by the OCR preparation tool as its prompt data.
"${PYTHON_BIN}" "${ROOT_DIR}/tools/prepare_ocr_jsonl.py"

# The optional arrays are expanded as ${arr[@]+"${arr[@]}"}: under `set -u`,
# bash older than 4.4 treats "${arr[@]}" on an EMPTY array as an unbound
# variable and aborts. The guarded form expands to nothing in that case.
"${PYTHON_BIN}" -u "${ROOT_DIR}/train_diffusion.py" \
    --train-backend fsdp \
    --diffusion-train \
    --rollout-function-path miles.rollout.sglang_diffusion_rollout.generate_rollout \
    --hf-checkpoint gpt2 \
    --prompt-data "${ROOT_DIR}/data/ocr/train.jsonl" \
    --input-key input \
    --rollout-batch-size 1 \
    --n-samples-per-prompt 2 \
    --num-rollout 1 \
    --diffusion-timestep-batch 10 \
    --gradient-checkpointing \
    --actor-num-gpus-per-node "${SMOKE_ACTOR_GPUS_PER_NODE}" \
    --rollout-num-gpus "${SMOKE_ROLLOUT_GPUS}" \
    --rollout-num-gpus-per-engine "${SMOKE_ROLLOUT_GPUS_PER_ENGINE}" \
    --num-gpus-per-node "${SMOKE_NUM_GPUS_PER_NODE}" \
    ${COLOCATE_ARGS[@]+"${COLOCATE_ARGS[@]}"} \
    --no-offload-rollout \
    --use-lora \
    --lora-rank 64 \
    --lora-alpha 128 \
    --diffusion-init-lora-weight gaussian \
    --use-miles-router \
    --sglang-server-concurrency 2 \
    --diffusion-model Qwen/Qwen-Image \
    --diffusion-reward hps:1.0 \
    --advantage-estimator grpo \
    --globalize-reward-std \
    --rm-type hps \
    --hps-version "${SMOKE_HPS_VERSION}" \
    --hps-num-workers "${SMOKE_HPS_NUM_WORKERS}" \
    --hps-num-gpus-per-worker "${SMOKE_HPS_NUM_GPUS_PER_WORKER}" \
    --hps-batch-size "${SMOKE_HPS_BATCH_SIZE}" \
    --diffusion-dtype bf16 \
    --diffusion-num-steps 10 \
    --diffusion-guidance-scale 4.0 \
    --diffusion-true-cfg-scale 4.0 \
    --diffusion-noise-level 1.2 \
    --diffusion-step-strategy-path miles.rollout.step_strategy_hub.sde_window \
    --diffusion-sde-window-size 2 \
    --diffusion-sde-window-range 0,5 \
    --diffusion-height 256 \
    --diffusion-width 256 \
    --global-batch-size 2 \
    --diffusion-ignore-last 1 \
    --diffusion-debug-mode \
    --debug-skip-optimizer-step \
    ${WANDB_ARGS[@]+"${WANDB_ARGS[@]}"}