auto generated cli #384
(first changed file)
@@ -3,11 +3,16 @@
 """

 import argparse
+import inspect
 import json
 import logging
 import os
 import sys
 from pathlib import Path
 from typing import Any, cast
+from .cli_commands.utils import add_args_from_callable_signature

+from fireworks import Fireworks

 logger = logging.getLogger(__name__)
@@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser
         "rft",
         help="Create a Reinforcement Fine-tuning Job on Fireworks",
     )
-    rft_parser.add_argument(
-        "--evaluator",
-        help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
-    )
-    # Dataset options
-    rft_parser.add_argument(
-        "--dataset",
-        help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
-    )
-    rft_parser.add_argument(
-        "--dataset-jsonl",
-        help="Path to JSONL to upload as a new Fireworks dataset",
-    )
-    rft_parser.add_argument(
-        "--dataset-builder",
-        help="Explicit dataset builder spec (module::function or path::function)",
-    )
-    rft_parser.add_argument(
-        "--dataset-display-name",
-        help="Display name for dataset on Fireworks (defaults to dataset id)",
-    )
-    # Training config and evaluator/job settings
-    rft_parser.add_argument("--base-model", help="Base model resource id")
-    rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
-    rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
-    rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
-    rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
-    rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
-    rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
-    rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
-    rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
-    rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
-    rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
-    rft_parser.add_argument("--region", help="Fireworks region for training")
-    rft_parser.add_argument("--display-name", help="Display name for the RFT job")
-    rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
-    rft_parser.add_argument(
-        "--eval-auto-carveout",
-        dest="eval_auto_carveout",
-        action="store_true",
-        default=True,
-        help="Automatically carve out evaluation data from training set",
-    )
-    rft_parser.add_argument(
-        "--no-eval-auto-carveout",
-        dest="eval_auto_carveout",
-        action="store_false",
-        help="Disable automatic evaluation data carveout",
-    )
-    # Rollout chunking
-    rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
-    # Inference params
-    rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
-    rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
-    rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
-    rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
-    rft_parser.add_argument(
-        "--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
-    )
-    rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
-    # MCP server (optional)
-    rft_parser.add_argument(
-        "--mcp-server",
-        help="MCP server resource name for agentic rollouts",
-    )
-    # Wandb
-    rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
-    rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
-    rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
-    rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
-    rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
-    # Misc
-    rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")

     rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
-    rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
+    rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending")
     rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
-    rft_parser.add_argument(
-        "--skip-validation",
-        action="store_true",
-        help="Skip local dataset and evaluator validation before creating the RFT job",
-    )
+    rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation")
     rft_parser.add_argument(
         "--ignore-docker",
         action="store_true",
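One detail worth noting in the removed block: the paired `--eval-auto-carveout` / `--no-eval-auto-carveout` flags sharing a single `dest` are the classic argparse idiom for an on/off switch with a default. A minimal standalone sketch of that pattern (not code from this PR; `argparse.BooleanOptionalAction` is the stdlib shorthand for the same thing since Python 3.9):

```python
import argparse

parser = argparse.ArgumentParser()
# Two flags writing to the same dest: one turns the feature on, one off.
parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout",
                    action="store_true", default=True)
parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout",
                    action="store_false")

print(parser.parse_args([]).eval_auto_carveout)                           # True
print(parser.parse_args(["--no-eval-auto-carveout"]).eval_auto_carveout)  # False
```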
@@ -463,14 +392,67 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser
     rft_parser.add_argument(
         "--docker-build-extra",
         default="",
         metavar="",
         help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")",
     )
     rft_parser.add_argument(
         "--docker-run-extra",
         default="",
         metavar="",
         help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
     )

+    # The flags below are Eval Protocol CLI workflow controls (not part of the Fireworks SDK `create()` signature),
+    # so they can’t be auto-generated via signature introspection and must be maintained here.
+    rft_parser.add_argument(
+        "--source-job",
+        metavar="",
+        help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.",
+    )
+    rft_parser.add_argument(
+        "--quiet",
+        action="store_true",
+        help="If set, only errors will be printed.",
+    )
Review comment: Missing CLI arguments for dataset creation workflow.
+    skip_fields = {
+        "__top_level__": {
+            "extra_headers",
+            "extra_query",
+            "extra_body",
+            "timeout",
+            "display_name",
+            "account_id",
+        },
+        "training_config": {"region", "jinja_template"},
+        "wandb_config": {"run_id"},
+    }
+    aliases = {
+        "wandb_config.api_key": ["--wandb-api-key"],
+        "wandb_config.project": ["--wandb-project"],
+        "wandb_config.entity": ["--wandb-entity"],
+        "wandb_config.enabled": ["--wandb"],
+        "reinforcement_fine_tuning_job_id": ["--job-id"],
+        "loss_config.kl_beta": ["--rl-kl-beta"],
+        "loss_config.method": ["--rl-loss-method"],
+        "node_count": ["--nodes"],
+    }
+    help_overrides = {
+        "training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
+        "training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
+        "mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
+        "loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.",
+    }

+    create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create

+    add_args_from_callable_signature(
+        rft_parser,
+        create_rft_job_fn,
+        skip_fields=skip_fields,
+        aliases=aliases,
+        help_overrides=help_overrides,
+    )

     # Local test command
     local_test_parser = subparsers.add_parser(
         "local-test",
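For readers unfamiliar with the helper being called above, the following is a minimal sketch of what a signature-driven argument generator like `add_args_from_callable_signature` might look like. It is an illustration under stated assumptions, not the actual implementation from `.cli_commands.utils`: the real helper evidently also flattens nested config objects (hence the dotted keys such as `training_config.region` in `skip_fields` and `aliases`), while this sketch only handles top-level parameters.

```python
import argparse
import inspect
from typing import Any, Callable, Dict, List, Optional, Set


def add_args_from_callable_signature_sketch(
    parser: argparse.ArgumentParser,
    fn: Callable[..., Any],
    skip_fields: Optional[Dict[str, Set[str]]] = None,
    aliases: Optional[Dict[str, List[str]]] = None,
    help_overrides: Optional[Dict[str, str]] = None,
) -> None:
    """Add one --flag per keyword parameter of `fn` (top-level params only)."""
    skips = (skip_fields or {}).get("__top_level__", set())
    aliases = aliases or {}
    help_overrides = help_overrides or {}
    for name, param in inspect.signature(fn).parameters.items():
        if name in skips or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # Alias list replaces the auto-derived kebab-case flag when present.
        flags = aliases.get(name, ["--" + name.replace("_", "-")])
        parser.add_argument(
            *flags,
            dest=name,
            default=None,  # keep "not provided" distinguishable from a real value
            help=help_overrides.get(name, f"Maps to the '{name}' parameter of create()."),
        )
```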
@@ -542,8 +524,11 @@ def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
 def parse_args(args=None):
     """Parse command line arguments."""
     parser = build_parser()
-    # Use parse_known_args to allow Hydra to handle its own arguments
-    return parser.parse_known_args(args)
+    # Fail fast on unknown flags so typos don't silently get ignored.
+    parsed, remaining = parser.parse_known_args(args)
+    if remaining:
+        parser.error(f"unrecognized arguments: {' '.join(remaining)}")
+    return parsed, remaining


 def main():
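A quick way to see what the fail-fast change buys: with plain `parse_known_args`, a misspelled flag is returned in `remaining` and previously went unnoticed; the patched `parse_args` now turns that into a hard error. A self-contained illustration (the `ep` prog name and flag here are hypothetical, not taken from this PR):

```python
import argparse

parser = argparse.ArgumentParser(prog="ep")
parser.add_argument("--learning-rate", type=float)

# Misspelled flag: it matches no known option (and is not an unambiguous
# abbreviation of one), so argparse returns it in `remaining` instead of erroring.
parsed, remaining = parser.parse_known_args(["--learning-rte=3e-5"])
assert remaining == ["--learning-rte=3e-5"]  # silently ignored before this change

# With the patched parse_args, this now exits with status 2:
#   ep: error: unrecognized arguments: --learning-rte=3e-5
```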
(second changed file)

@@ -1,10 +1,12 @@
 import argparse
+from fireworks._client import Fireworks
+from fireworks.types.reinforcement_fine_tuning_job import ReinforcementFineTuningJob
 import json
 import os
 import sys
 import time
 from typing import Any, Dict, Optional

 import inspect
 import requests
 from pydantic import ValidationError
@@ -13,7 +15,6 @@
 from ..fireworks_rft import (
     build_default_output_model,
     create_dataset_from_jsonl,
-    create_reinforcement_fine_tuning_job,
     detect_dataset_builder,
     materialize_dataset_via_builder,
 )
@@ -33,6 +34,8 @@
 )
 from .local_test import run_evaluator_test

+from fireworks import Fireworks


 def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
     """Import the test module and extract a JSONL path from data_loaders param if present.
@@ -619,124 +622,48 @@ def _create_rft_job(
     args: argparse.Namespace,
     dry_run: bool,
 ) -> int:
-    """Build and submit the RFT job request."""
-    # Build training config/body
-    # Exactly one of base-model or warm-start-from must be provided
-    base_model_raw = getattr(args, "base_model", None)
-    warm_start_from_raw = getattr(args, "warm_start_from", None)
-    # Treat empty/whitespace strings as not provided
-    base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw
-    warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw
-    has_base_model = bool(base_model)
-    has_warm_start = bool(warm_start_from)
-    if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start):
-        print("Error: exactly one of --base-model or --warm-start-from must be specified.")
-        return 1
+    """Build and submit the RFT job request (via Fireworks SDK)."""

-    training_config: Dict[str, Any] = {}
-    if has_base_model:
-        training_config["baseModel"] = base_model
-    if has_warm_start:
-        training_config["warmStartFrom"] = warm_start_from
-
-    # Optional hyperparameters
-    for key, arg_name in [
-        ("epochs", "epochs"),
-        ("batchSize", "batch_size"),
-        ("learningRate", "learning_rate"),
-        ("maxContextLength", "max_context_length"),
-        ("loraRank", "lora_rank"),
-        ("gradientAccumulationSteps", "gradient_accumulation_steps"),
-        ("learningRateWarmupSteps", "learning_rate_warmup_steps"),
-        ("acceleratorCount", "accelerator_count"),
-        ("region", "region"),
-    ]:
-        val = getattr(args, arg_name, None)
-        if val is not None:
-            training_config[key] = val
-
-    inference_params: Dict[str, Any] = {}
-    for key, arg_name in [
-        ("temperature", "temperature"),
-        ("topP", "top_p"),
-        ("topK", "top_k"),
-        ("maxOutputTokens", "max_output_tokens"),
-        ("responseCandidatesCount", "response_candidates_count"),
-    ]:
-        val = getattr(args, arg_name, None)
-        if val is not None:
-            inference_params[key] = val
-    if getattr(args, "extra_body", None):
-        extra = getattr(args, "extra_body")
-        if isinstance(extra, (dict, list)):
-            try:
-                inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
-            except (TypeError, ValueError) as e:
-                print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
-                return 1
-        elif isinstance(extra, str):
-            inference_params["extraBody"] = extra
-        else:
-            print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
-            return 1
+    signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)

-    wandb_config: Optional[Dict[str, Any]] = None
-    if getattr(args, "wandb_enabled", False):
-        wandb_config = {
-            "enabled": True,
-            "apiKey": getattr(args, "wandb_api_key", None),
-            "project": getattr(args, "wandb_project", None),
-            "entity": getattr(args, "wandb_entity", None),
-            "runId": getattr(args, "wandb_run_id", None),
-        }
-
-    body: Dict[str, Any] = {
-        "displayName": getattr(args, "display_name", None),
-        "dataset": dataset_resource,
+    # Build top-level SDK kwargs
+    sdk_kwargs: Dict[str, Any] = {
         "evaluator": evaluator_resource_name,
-        "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
-        "trainingConfig": training_config,
-        "inferenceParameters": inference_params or None,
-        "wandbConfig": wandb_config,
-        "chunkSize": getattr(args, "chunk_size", None),
-        "outputStats": None,
-        "outputMetrics": None,
-        "mcpServer": getattr(args, "mcp_server", None),
-        "jobId": getattr(args, "job_id", None),
+        "dataset": dataset_resource,
     }
-    # Debug: print minimal summary
-    print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
-    if getattr(args, "evaluation_dataset", None):
-        body["evaluationDataset"] = args.evaluation_dataset
-
-    output_model_arg = getattr(args, "output_model", None)
-    if output_model_arg:
-        if len(output_model_arg) > 63:
-            print(f"Error: Output model name '{output_model_arg}' exceeds 63 characters.")
-            return 1
-        body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{output_model_arg}"
-    else:
-        # Auto-generate output model name if not provided
-        auto_output_model = build_default_output_model(evaluator_id)
-        body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{auto_output_model}"
+    args_dict = vars(args)
+    for name in signature.parameters:
+        prefix = name + "_"

+        # Collect "flattened" argparse fields back into the nested dict expected by the SDK.
+        # Example: training_config_epochs=3 becomes sdk_kwargs["training_config"]["epochs"] = 3.
+        nested = {}
+        for k, v in args_dict.items():
+            if v is None:
+                continue
+            if not k.startswith(prefix):
+                continue
+            nested[k[len(prefix) :]] = v

+        if nested:
+            sdk_kwargs[name] = nested
+        elif args_dict.get(name) is not None:
+            sdk_kwargs[name] = args_dict[name]
Review comment: SDK kwargs overwrite normalized values with raw args.
-    # Clean None fields to avoid noisy payloads
-    body = {k: v for k, v in body.items() if v is not None}
+    print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")

     if dry_run:
-        print("--dry-run: would create RFT job with body:")
-        print(json.dumps(body, indent=2))
+        print("--dry-run: would call Fireworks().reinforcement_fine_tuning_jobs.create with kwargs:")
+        print(json.dumps(sdk_kwargs, indent=2))
         _print_links(evaluator_id, dataset_id, None)
         return 0

     try:
-        result = create_reinforcement_fine_tuning_job(
-            account_id=account_id, api_key=api_key, api_base=api_base, body=body
-        )
-        job_name = result.get("name") if isinstance(result, dict) else None
-        print("\n✅ Created Reinforcement Fine-tuning Job")
-        if job_name:
-            print(f" name: {job_name}")
+        fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base)
+        job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs)
+        job_name = job.name
+        print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}")
         _print_links(evaluator_id, dataset_id, job_name)
         return 0
     except Exception as e:
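To make the regrouping loop in this hunk concrete, here is a standalone sketch of the same transformation with made-up values. The function name and sample parameters are illustrative only (the dest names follow the flattening convention used by the PR, e.g. `training_config_epochs`):

```python
from typing import Any, Dict, List


def regroup(args_dict: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Re-nest flattened argparse dests, mirroring the loop in _create_rft_job."""
    sdk_kwargs: Dict[str, Any] = {}
    for name in param_names:
        prefix = name + "_"
        nested = {
            k[len(prefix):]: v
            for k, v in args_dict.items()
            if v is not None and k.startswith(prefix)
        }
        if nested:
            sdk_kwargs[name] = nested           # training_config_epochs -> training_config["epochs"]
        elif args_dict.get(name) is not None:
            sdk_kwargs[name] = args_dict[name]  # plain top-level kwarg
    return sdk_kwargs


print(regroup(
    {"training_config_epochs": 3, "training_config_learning_rate": 3e-5, "evaluator": "my-eval"},
    ["training_config", "evaluator"],
))
# {'training_config': {'epochs': 3, 'learning_rate': 3e-5}, 'evaluator': 'my-eval'}
```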