Skip to content

Commit 5fc4592

Browse files
authored
auto generated cli (#384)
* auto generated cli
* update
* fix rft command
* update test
* also fail fast on unknown flags.
* update test
* rft create uses sdk now
* clean up tests
* add a smoke test
1 parent 37d2e02 commit 5fc4592

File tree

8 files changed

+584
-776
lines changed

8 files changed

+584
-776
lines changed

eval_protocol/cli.py

Lines changed: 66 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,16 @@
33
"""
44

55
import argparse
6+
import inspect
7+
import json
68
import logging
79
import os
810
import sys
911
from pathlib import Path
1012
from typing import Any, cast
13+
from .cli_commands.utils import add_args_from_callable_signature
14+
15+
from fireworks import Fireworks
1116

1217
logger = logging.getLogger(__name__)
1318

@@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
374379
"rft",
375380
help="Create a Reinforcement Fine-tuning Job on Fireworks",
376381
)
377-
rft_parser.add_argument(
378-
"--evaluator",
379-
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
380-
)
381-
# Dataset options
382-
rft_parser.add_argument(
383-
"--dataset",
384-
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
385-
)
386-
rft_parser.add_argument(
387-
"--dataset-jsonl",
388-
help="Path to JSONL to upload as a new Fireworks dataset",
389-
)
390-
rft_parser.add_argument(
391-
"--dataset-builder",
392-
help="Explicit dataset builder spec (module::function or path::function)",
393-
)
394-
rft_parser.add_argument(
395-
"--dataset-display-name",
396-
help="Display name for dataset on Fireworks (defaults to dataset id)",
397-
)
398-
# Training config and evaluator/job settings
399-
rft_parser.add_argument("--base-model", help="Base model resource id")
400-
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
401-
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
402-
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
403-
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
404-
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
405-
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
406-
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
407-
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
408-
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
409-
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
410-
rft_parser.add_argument("--region", help="Fireworks region for training")
411-
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
412-
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
413-
rft_parser.add_argument(
414-
"--eval-auto-carveout",
415-
dest="eval_auto_carveout",
416-
action="store_true",
417-
default=True,
418-
help="Automatically carve out evaluation data from training set",
419-
)
420-
rft_parser.add_argument(
421-
"--no-eval-auto-carveout",
422-
dest="eval_auto_carveout",
423-
action="store_false",
424-
help="Disable automatic evaluation data carveout",
425-
)
426-
# Rollout chunking
427-
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
428-
# Inference params
429-
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
430-
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
431-
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
432-
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
433-
rft_parser.add_argument(
434-
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
435-
)
436-
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
437-
# MCP server (optional)
438-
rft_parser.add_argument(
439-
"--mcp-server",
440-
help="MCP server resource name for agentic rollouts",
441-
)
442-
# Wandb
443-
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
444-
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
445-
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
446-
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
447-
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
448-
# Misc
449-
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
382+
450383
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
451-
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
384+
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending")
452385
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
453-
rft_parser.add_argument(
454-
"--skip-validation",
455-
action="store_true",
456-
help="Skip local dataset and evaluator validation before creating the RFT job",
457-
)
386+
rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation")
458387
rft_parser.add_argument(
459388
"--ignore-docker",
460389
action="store_true",
@@ -463,14 +392,67 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
463392
rft_parser.add_argument(
464393
"--docker-build-extra",
465394
default="",
395+
metavar="",
466396
help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")",
467397
)
468398
rft_parser.add_argument(
469399
"--docker-run-extra",
470400
default="",
401+
metavar="",
471402
help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
472403
)
473404

405+
# The flags below are Eval Protocol CLI workflow controls (not part of the Fireworks SDK `create()` signature),
406+
# so they can’t be auto-generated via signature introspection and must be maintained here.
407+
rft_parser.add_argument(
408+
"--source-job",
409+
metavar="",
410+
help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.",
411+
)
412+
rft_parser.add_argument(
413+
"--quiet",
414+
action="store_true",
415+
help="If set, only errors will be printed.",
416+
)
417+
skip_fields = {
418+
"__top_level__": {
419+
"extra_headers",
420+
"extra_query",
421+
"extra_body",
422+
"timeout",
423+
"display_name",
424+
"account_id",
425+
},
426+
"training_config": {"region", "jinja_template"},
427+
"wandb_config": {"run_id"},
428+
}
429+
aliases = {
430+
"wandb_config.api_key": ["--wandb-api-key"],
431+
"wandb_config.project": ["--wandb-project"],
432+
"wandb_config.entity": ["--wandb-entity"],
433+
"wandb_config.enabled": ["--wandb"],
434+
"reinforcement_fine_tuning_job_id": ["--job-id"],
435+
"loss_config.kl_beta": ["--rl-kl-beta"],
436+
"loss_config.method": ["--rl-loss-method"],
437+
"node_count": ["--nodes"],
438+
}
439+
help_overrides = {
440+
"training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
441+
"training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
442+
"mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
443+
"loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.",
444+
}
445+
446+
create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create
447+
448+
add_args_from_callable_signature(
449+
rft_parser,
450+
create_rft_job_fn,
451+
skip_fields=skip_fields,
452+
aliases=aliases,
453+
help_overrides=help_overrides,
454+
)
455+
474456
# Local test command
475457
local_test_parser = subparsers.add_parser(
476458
"local-test",
@@ -542,8 +524,11 @@ def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
542524
def parse_args(args=None):
543525
"""Parse command line arguments."""
544526
parser = build_parser()
545-
# Use parse_known_args to allow Hydra to handle its own arguments
546-
return parser.parse_known_args(args)
527+
# Fail fast on unknown flags so typos don't silently get ignored.
528+
parsed, remaining = parser.parse_known_args(args)
529+
if remaining:
530+
parser.error(f"unrecognized arguments: {' '.join(remaining)}")
531+
return parsed, remaining
547532

548533

549534
def main():

eval_protocol/cli_commands/create_rft.py

Lines changed: 35 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
import argparse
2+
from fireworks._client import Fireworks
3+
from fireworks.types.reinforcement_fine_tuning_job import ReinforcementFineTuningJob
24
import json
35
import os
46
import sys
57
import time
68
from typing import Any, Dict, Optional
7-
9+
import inspect
810
import requests
911
from pydantic import ValidationError
1012

@@ -13,7 +15,6 @@
1315
from ..fireworks_rft import (
1416
build_default_output_model,
1517
create_dataset_from_jsonl,
16-
create_reinforcement_fine_tuning_job,
1718
detect_dataset_builder,
1819
materialize_dataset_via_builder,
1920
)
@@ -33,6 +34,8 @@
3334
)
3435
from .local_test import run_evaluator_test
3536

37+
from fireworks import Fireworks
38+
3639

3740
def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
3841
"""Import the test module and extract a JSONL path from data_loaders param if present.
@@ -619,124 +622,48 @@ def _create_rft_job(
619622
args: argparse.Namespace,
620623
dry_run: bool,
621624
) -> int:
622-
"""Build and submit the RFT job request."""
623-
# Build training config/body
624-
# Exactly one of base-model or warm-start-from must be provided
625-
base_model_raw = getattr(args, "base_model", None)
626-
warm_start_from_raw = getattr(args, "warm_start_from", None)
627-
# Treat empty/whitespace strings as not provided
628-
base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw
629-
warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw
630-
has_base_model = bool(base_model)
631-
has_warm_start = bool(warm_start_from)
632-
if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start):
633-
print("Error: exactly one of --base-model or --warm-start-from must be specified.")
634-
return 1
625+
"""Build and submit the RFT job request (via Fireworks SDK)."""
635626

636-
training_config: Dict[str, Any] = {}
637-
if has_base_model:
638-
training_config["baseModel"] = base_model
639-
if has_warm_start:
640-
training_config["warmStartFrom"] = warm_start_from
641-
642-
# Optional hyperparameters
643-
for key, arg_name in [
644-
("epochs", "epochs"),
645-
("batchSize", "batch_size"),
646-
("learningRate", "learning_rate"),
647-
("maxContextLength", "max_context_length"),
648-
("loraRank", "lora_rank"),
649-
("gradientAccumulationSteps", "gradient_accumulation_steps"),
650-
("learningRateWarmupSteps", "learning_rate_warmup_steps"),
651-
("acceleratorCount", "accelerator_count"),
652-
("region", "region"),
653-
]:
654-
val = getattr(args, arg_name, None)
655-
if val is not None:
656-
training_config[key] = val
657-
658-
inference_params: Dict[str, Any] = {}
659-
for key, arg_name in [
660-
("temperature", "temperature"),
661-
("topP", "top_p"),
662-
("topK", "top_k"),
663-
("maxOutputTokens", "max_output_tokens"),
664-
("responseCandidatesCount", "response_candidates_count"),
665-
]:
666-
val = getattr(args, arg_name, None)
667-
if val is not None:
668-
inference_params[key] = val
669-
if getattr(args, "extra_body", None):
670-
extra = getattr(args, "extra_body")
671-
if isinstance(extra, (dict, list)):
672-
try:
673-
inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
674-
except (TypeError, ValueError) as e:
675-
print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
676-
return 1
677-
elif isinstance(extra, str):
678-
inference_params["extraBody"] = extra
679-
else:
680-
print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
681-
return 1
627+
signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)
682628

683-
wandb_config: Optional[Dict[str, Any]] = None
684-
if getattr(args, "wandb_enabled", False):
685-
wandb_config = {
686-
"enabled": True,
687-
"apiKey": getattr(args, "wandb_api_key", None),
688-
"project": getattr(args, "wandb_project", None),
689-
"entity": getattr(args, "wandb_entity", None),
690-
"runId": getattr(args, "wandb_run_id", None),
691-
}
692-
693-
body: Dict[str, Any] = {
694-
"displayName": getattr(args, "display_name", None),
695-
"dataset": dataset_resource,
629+
# Build top-level SDK kwargs
630+
sdk_kwargs: Dict[str, Any] = {
696631
"evaluator": evaluator_resource_name,
697-
"evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
698-
"trainingConfig": training_config,
699-
"inferenceParameters": inference_params or None,
700-
"wandbConfig": wandb_config,
701-
"chunkSize": getattr(args, "chunk_size", None),
702-
"outputStats": None,
703-
"outputMetrics": None,
704-
"mcpServer": getattr(args, "mcp_server", None),
705-
"jobId": getattr(args, "job_id", None),
632+
"dataset": dataset_resource,
706633
}
707-
# Debug: print minimal summary
708-
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
709-
if getattr(args, "evaluation_dataset", None):
710-
body["evaluationDataset"] = args.evaluation_dataset
711634

712-
output_model_arg = getattr(args, "output_model", None)
713-
if output_model_arg:
714-
if len(output_model_arg) > 63:
715-
print(f"Error: Output model name '{output_model_arg}' exceeds 63 characters.")
716-
return 1
717-
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{output_model_arg}"
718-
else:
719-
# Auto-generate output model name if not provided
720-
auto_output_model = build_default_output_model(evaluator_id)
721-
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{auto_output_model}"
635+
args_dict = vars(args)
636+
for name in signature.parameters:
637+
prefix = name + "_"
638+
639+
# Collect "flattened" argparse fields back into the nested dict expected by the SDK.
640+
# Example: training_config_epochs=3 becomes sdk_kwargs["training_config"]["epochs"] = 3.
641+
nested = {}
642+
for k, v in args_dict.items():
643+
if v is None:
644+
continue
645+
if not k.startswith(prefix):
646+
continue
647+
nested[k[len(prefix) :]] = v
648+
649+
if nested:
650+
sdk_kwargs[name] = nested
651+
elif args_dict.get(name) is not None:
652+
sdk_kwargs[name] = args_dict[name]
722653

723-
# Clean None fields to avoid noisy payloads
724-
body = {k: v for k, v in body.items() if v is not None}
654+
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
725655

726656
if dry_run:
727-
print("--dry-run: would create RFT job with body:")
728-
print(json.dumps(body, indent=2))
657+
print("--dry-run: would call Fireworks().reinforcement_fine_tuning_jobs.create with kwargs:")
658+
print(json.dumps(sdk_kwargs, indent=2))
729659
_print_links(evaluator_id, dataset_id, None)
730660
return 0
731661

732662
try:
733-
result = create_reinforcement_fine_tuning_job(
734-
account_id=account_id, api_key=api_key, api_base=api_base, body=body
735-
)
736-
job_name = result.get("name") if isinstance(result, dict) else None
737-
print("\n✅ Created Reinforcement Fine-tuning Job")
738-
if job_name:
739-
print(f" name: {job_name}")
663+
fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base)
664+
job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs)
665+
job_name = job.name
666+
print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}")
740667
_print_links(evaluator_id, dataset_id, job_name)
741668
return 0
742669
except Exception as e:

0 commit comments

Comments (0)