147 changes: 66 additions & 81 deletions eval_protocol/cli.py
@@ -3,11 +3,16 @@
"""

import argparse
import inspect
import json
import logging
import os
import sys
from pathlib import Path
from typing import Any, cast
from .cli_commands.utils import add_args_from_callable_signature

from fireworks import Fireworks

logger = logging.getLogger(__name__)

@@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
"rft",
help="Create a Reinforcement Fine-tuning Job on Fireworks",
)
rft_parser.add_argument(
"--evaluator",
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
)
# Dataset options
rft_parser.add_argument(
"--dataset",
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
)
rft_parser.add_argument(
"--dataset-jsonl",
help="Path to JSONL to upload as a new Fireworks dataset",
)
rft_parser.add_argument(
"--dataset-builder",
help="Explicit dataset builder spec (module::function or path::function)",
)
rft_parser.add_argument(
"--dataset-display-name",
help="Display name for dataset on Fireworks (defaults to dataset id)",
)
# Training config and evaluator/job settings
rft_parser.add_argument("--base-model", help="Base model resource id")
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
rft_parser.add_argument("--region", help="Fireworks region for training")
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
rft_parser.add_argument(
"--eval-auto-carveout",
dest="eval_auto_carveout",
action="store_true",
default=True,
help="Automatically carve out evaluation data from training set",
)
rft_parser.add_argument(
"--no-eval-auto-carveout",
dest="eval_auto_carveout",
action="store_false",
help="Disable automatic evaluation data carveout",
)
# Rollout chunking
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
# Inference params
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
rft_parser.add_argument(
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
)
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
# MCP server (optional)
rft_parser.add_argument(
"--mcp-server",
help="MCP server resource name for agentic rollouts",
)
# Wandb
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
# Misc
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")

rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending")
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
rft_parser.add_argument(
"--skip-validation",
action="store_true",
help="Skip local dataset and evaluator validation before creating the RFT job",
)
rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation")
rft_parser.add_argument(
"--ignore-docker",
action="store_true",
@@ -463,14 +392,67 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
rft_parser.add_argument(
"--docker-build-extra",
default="",
metavar="",
help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")",
)
rft_parser.add_argument(
"--docker-run-extra",
default="",
metavar="",
help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
)

# The flags below are Eval Protocol CLI workflow controls (not part of the Fireworks SDK `create()` signature),
# so they can’t be auto-generated via signature introspection and must be maintained here.
rft_parser.add_argument(
"--source-job",
metavar="",
help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.",
)
rft_parser.add_argument(
"--quiet",
action="store_true",
help="If set, only errors will be printed.",
)
Missing CLI arguments for dataset creation workflow

The --dataset-jsonl, --dataset-builder, and --dataset-display-name arguments were removed from the CLI, but the code in create_rft.py still expects them via getattr(args, "dataset_jsonl", None) and getattr(args, "dataset_display_name", None). These are workflow-specific arguments (not SDK parameters) for creating datasets from local JSONL files. Users attempting to use --dataset-jsonl will get "unrecognized arguments" errors, and the dataset-creation workflow from local files is broken. The comment at lines 405-406 acknowledges that workflow controls must be maintained manually, but these arguments were not added back.
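
A possible fix, sketched below: restore the three flags in the manually maintained workflow block, reusing the help text from the removed arguments (the placement next to --source-job is a suggestion, not part of this PR):

    rft_parser.add_argument(
        "--dataset-jsonl",
        help="Path to JSONL to upload as a new Fireworks dataset",
    )
    rft_parser.add_argument(
        "--dataset-builder",
        help="Explicit dataset builder spec (module::function or path::function)",
    )
    rft_parser.add_argument(
        "--dataset-display-name",
        help="Display name for dataset on Fireworks (defaults to dataset id)",
    )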


skip_fields = {
"__top_level__": {
"extra_headers",
"extra_query",
"extra_body",
"timeout",
"display_name",
"account_id",
},
"training_config": {"region", "jinja_template"},
"wandb_config": {"run_id"},
}
aliases = {
"wandb_config.api_key": ["--wandb-api-key"],
"wandb_config.project": ["--wandb-project"],
"wandb_config.entity": ["--wandb-entity"],
"wandb_config.enabled": ["--wandb"],
"reinforcement_fine_tuning_job_id": ["--job-id"],
"loss_config.kl_beta": ["--rl-kl-beta"],
"loss_config.method": ["--rl-loss-method"],
"node_count": ["--nodes"],
}
help_overrides = {
"training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
"training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
"mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
"loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.",
}

create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create

add_args_from_callable_signature(
rft_parser,
create_rft_job_fn,
skip_fields=skip_fields,
aliases=aliases,
help_overrides=help_overrides,
)
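
The helper imported from eval_protocol/cli_commands/utils.py is not shown in this diff. As a rough illustration of the mechanism (the function below is a hypothetical stand-in, not the actual implementation), such a helper can introspect the SDK method's signature and emit one argparse flag per parameter, with nested config fields flattened under a "<parent>_" prefix so that, per the comment in create_rft.py, training_config.epochs surfaces as --training-config-epochs:

    import argparse
    import inspect
    from typing import Any, Callable, Optional

    def add_args_from_signature_sketch(
        parser: argparse.ArgumentParser,
        fn: Callable[..., Any],
        skip_fields: Optional[dict] = None,
        aliases: Optional[dict] = None,
        help_overrides: Optional[dict] = None,
    ) -> None:
        # Sketch only: the real helper presumably also recurses into nested
        # config types; this version handles top-level parameters.
        top_level_skips = (skip_fields or {}).get("__top_level__", set())
        for name, param in inspect.signature(fn).parameters.items():
            if name in top_level_skips:
                continue
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue  # skip *args/**kwargs
            flags = (aliases or {}).get(name) or [f"--{name.replace('_', '-')}"]
            parser.add_argument(
                *flags,
                dest=name,
                default=None,  # None means "flag not provided"; filtered out later
                help=(help_overrides or {}).get(name, f"SDK parameter '{name}'."),
            )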

# Local test command
local_test_parser = subparsers.add_parser(
"local-test",
@@ -542,8 +524,11 @@ def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
def parse_args(args=None):
"""Parse command line arguments."""
parser = build_parser()
# Use parse_known_args to allow Hydra to handle its own arguments
return parser.parse_known_args(args)
# Fail fast on unknown flags so typos don't silently get ignored.
parsed, remaining = parser.parse_known_args(args)
if remaining:
parser.error(f"unrecognized arguments: {' '.join(remaining)}")
return parsed, remaining
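
With this change a mistyped flag aborts the run instead of being silently returned in `remaining`. A hypothetical invocation (command name assumed for illustration):

    # Hypothetical example of the fail-fast behavior:
    #   $ ep create rft --evaluattor my-eval
    #   error: unrecognized arguments: --evaluattor my-eval
    # Previously parse_known_args() passed the typo through in `remaining`
    # (kept for Hydra), and it was never acted on.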


def main():
143 changes: 35 additions & 108 deletions eval_protocol/cli_commands/create_rft.py
@@ -1,10 +1,12 @@
import argparse
from fireworks._client import Fireworks
from fireworks.types.reinforcement_fine_tuning_job import ReinforcementFineTuningJob
import json
import os
import sys
import time
from typing import Any, Dict, Optional

import inspect
import requests
from pydantic import ValidationError

@@ -13,7 +15,6 @@
from ..fireworks_rft import (
build_default_output_model,
create_dataset_from_jsonl,
create_reinforcement_fine_tuning_job,
detect_dataset_builder,
materialize_dataset_via_builder,
)
Expand All @@ -33,6 +34,8 @@
)
from .local_test import run_evaluator_test

from fireworks import Fireworks


def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
"""Import the test module and extract a JSONL path from data_loaders param if present.
@@ -619,124 +622,48 @@ def _create_rft_job(
args: argparse.Namespace,
dry_run: bool,
) -> int:
"""Build and submit the RFT job request."""
# Build training config/body
# Exactly one of base-model or warm-start-from must be provided
base_model_raw = getattr(args, "base_model", None)
warm_start_from_raw = getattr(args, "warm_start_from", None)
# Treat empty/whitespace strings as not provided
base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw
warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw
has_base_model = bool(base_model)
has_warm_start = bool(warm_start_from)
if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start):
print("Error: exactly one of --base-model or --warm-start-from must be specified.")
return 1
"""Build and submit the RFT job request (via Fireworks SDK)."""

training_config: Dict[str, Any] = {}
if has_base_model:
training_config["baseModel"] = base_model
if has_warm_start:
training_config["warmStartFrom"] = warm_start_from

# Optional hyperparameters
for key, arg_name in [
("epochs", "epochs"),
("batchSize", "batch_size"),
("learningRate", "learning_rate"),
("maxContextLength", "max_context_length"),
("loraRank", "lora_rank"),
("gradientAccumulationSteps", "gradient_accumulation_steps"),
("learningRateWarmupSteps", "learning_rate_warmup_steps"),
("acceleratorCount", "accelerator_count"),
("region", "region"),
]:
val = getattr(args, arg_name, None)
if val is not None:
training_config[key] = val

inference_params: Dict[str, Any] = {}
for key, arg_name in [
("temperature", "temperature"),
("topP", "top_p"),
("topK", "top_k"),
("maxOutputTokens", "max_output_tokens"),
("responseCandidatesCount", "response_candidates_count"),
]:
val = getattr(args, arg_name, None)
if val is not None:
inference_params[key] = val
if getattr(args, "extra_body", None):
extra = getattr(args, "extra_body")
if isinstance(extra, (dict, list)):
try:
inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
except (TypeError, ValueError) as e:
print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
return 1
elif isinstance(extra, str):
inference_params["extraBody"] = extra
else:
print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
return 1
signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)

wandb_config: Optional[Dict[str, Any]] = None
if getattr(args, "wandb_enabled", False):
wandb_config = {
"enabled": True,
"apiKey": getattr(args, "wandb_api_key", None),
"project": getattr(args, "wandb_project", None),
"entity": getattr(args, "wandb_entity", None),
"runId": getattr(args, "wandb_run_id", None),
}

body: Dict[str, Any] = {
"displayName": getattr(args, "display_name", None),
"dataset": dataset_resource,
# Build top-level SDK kwargs
sdk_kwargs: Dict[str, Any] = {
"evaluator": evaluator_resource_name,
"evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
"trainingConfig": training_config,
"inferenceParameters": inference_params or None,
"wandbConfig": wandb_config,
"chunkSize": getattr(args, "chunk_size", None),
"outputStats": None,
"outputMetrics": None,
"mcpServer": getattr(args, "mcp_server", None),
"jobId": getattr(args, "job_id", None),
"dataset": dataset_resource,
}
# Debug: print minimal summary
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
if getattr(args, "evaluation_dataset", None):
body["evaluationDataset"] = args.evaluation_dataset

output_model_arg = getattr(args, "output_model", None)
if output_model_arg:
if len(output_model_arg) > 63:
print(f"Error: Output model name '{output_model_arg}' exceeds 63 characters.")
return 1
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{output_model_arg}"
else:
# Auto-generate output model name if not provided
auto_output_model = build_default_output_model(evaluator_id)
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{auto_output_model}"
args_dict = vars(args)
for name in signature.parameters:
prefix = name + "_"

# Collect "flattened" argparse fields back into the nested dict expected by the SDK.
# Example: training_config_epochs=3 becomes sdk_kwargs["training_config"]["epochs"] = 3.
nested = {}
for k, v in args_dict.items():
if v is None:
continue
if not k.startswith(prefix):
continue
nested[k[len(prefix) :]] = v

if nested:
sdk_kwargs[name] = nested
elif args_dict.get(name) is not None:
sdk_kwargs[name] = args_dict[name]

SDK kwargs overwrite normalized values with raw args

The _create_rft_job function correctly initializes sdk_kwargs with normalized evaluator_resource_name and dataset_resource values from function parameters. However, the loop that builds nested SDK kwargs (lines 636-652) also checks for top-level parameters like evaluator and dataset in args_dict. Since these are parameters in the SDK signature and are added as CLI flags via add_args_from_callable_signature, user-provided raw values (e.g., --evaluator my-eval) will overwrite the correctly normalized full resource names (e.g., accounts/{id}/evaluators/my-eval). This causes the SDK call to receive short IDs instead of full resource names, likely causing API failures.
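
One possible fix, sketched against the loop shown above (the skip set is an assumption about which parameters are pre-normalized):

    # Possible fix (sketch): don't let raw CLI values clobber the resource
    # names already normalized into sdk_kwargs before this loop runs.
    already_normalized = {"evaluator", "dataset"}
    for name in signature.parameters:
        if name in already_normalized:
            continue  # keep the accounts/{acct}/... resource names intact
        prefix = name + "_"
        nested = {
            k[len(prefix):]: v
            for k, v in args_dict.items()
            if v is not None and k.startswith(prefix)
        }
        if nested:
            sdk_kwargs[name] = nested
        elif args_dict.get(name) is not None:
            sdk_kwargs[name] = args_dict[name]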



# Clean None fields to avoid noisy payloads
body = {k: v for k, v in body.items() if v is not None}
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")

if dry_run:
print("--dry-run: would create RFT job with body:")
print(json.dumps(body, indent=2))
print("--dry-run: would call Fireworks().reinforcement_fine_tuning_jobs.create with kwargs:")
print(json.dumps(sdk_kwargs, indent=2))
_print_links(evaluator_id, dataset_id, None)
return 0

try:
result = create_reinforcement_fine_tuning_job(
account_id=account_id, api_key=api_key, api_base=api_base, body=body
)
job_name = result.get("name") if isinstance(result, dict) else None
print("\n✅ Created Reinforcement Fine-tuning Job")
if job_name:
print(f" name: {job_name}")
fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base)
job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs)
job_name = job.name
print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}")
_print_links(evaluator_id, dataset_id, job_name)
return 0
except Exception as e: