|
1 | 1 | import argparse |
| 2 | +from fireworks._client import Fireworks |
| 3 | +from fireworks.types.reinforcement_fine_tuning_job import ReinforcementFineTuningJob |
2 | 4 | import json |
3 | 5 | import os |
4 | 6 | import sys |
5 | 7 | import time |
6 | 8 | from typing import Any, Dict, Optional |
7 | | - |
| 9 | +import inspect |
8 | 10 | import requests |
9 | 11 | from pydantic import ValidationError |
10 | 12 |
|
|
13 | 15 | from ..fireworks_rft import ( |
14 | 16 | build_default_output_model, |
15 | 17 | create_dataset_from_jsonl, |
16 | | - create_reinforcement_fine_tuning_job, |
17 | 18 | detect_dataset_builder, |
18 | 19 | materialize_dataset_via_builder, |
19 | 20 | ) |
|
33 | 34 | ) |
34 | 35 | from .local_test import run_evaluator_test |
35 | 36 |
|
| 37 | +from fireworks import Fireworks |
| 38 | + |
36 | 39 |
|
37 | 40 | def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]: |
38 | 41 | """Import the test module and extract a JSONL path from data_loaders param if present. |
@@ -619,124 +622,48 @@ def _create_rft_job( |
619 | 622 | args: argparse.Namespace, |
620 | 623 | dry_run: bool, |
621 | 624 | ) -> int: |
622 | | - """Build and submit the RFT job request.""" |
623 | | - # Build training config/body |
624 | | - # Exactly one of base-model or warm-start-from must be provided |
625 | | - base_model_raw = getattr(args, "base_model", None) |
626 | | - warm_start_from_raw = getattr(args, "warm_start_from", None) |
627 | | - # Treat empty/whitespace strings as not provided |
628 | | - base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw |
629 | | - warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw |
630 | | - has_base_model = bool(base_model) |
631 | | - has_warm_start = bool(warm_start_from) |
632 | | - if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start): |
633 | | - print("Error: exactly one of --base-model or --warm-start-from must be specified.") |
634 | | - return 1 |
| 625 | + """Build and submit the RFT job request (via Fireworks SDK).""" |
635 | 626 |
|
636 | | - training_config: Dict[str, Any] = {} |
637 | | - if has_base_model: |
638 | | - training_config["baseModel"] = base_model |
639 | | - if has_warm_start: |
640 | | - training_config["warmStartFrom"] = warm_start_from |
641 | | - |
642 | | - # Optional hyperparameters |
643 | | - for key, arg_name in [ |
644 | | - ("epochs", "epochs"), |
645 | | - ("batchSize", "batch_size"), |
646 | | - ("learningRate", "learning_rate"), |
647 | | - ("maxContextLength", "max_context_length"), |
648 | | - ("loraRank", "lora_rank"), |
649 | | - ("gradientAccumulationSteps", "gradient_accumulation_steps"), |
650 | | - ("learningRateWarmupSteps", "learning_rate_warmup_steps"), |
651 | | - ("acceleratorCount", "accelerator_count"), |
652 | | - ("region", "region"), |
653 | | - ]: |
654 | | - val = getattr(args, arg_name, None) |
655 | | - if val is not None: |
656 | | - training_config[key] = val |
657 | | - |
658 | | - inference_params: Dict[str, Any] = {} |
659 | | - for key, arg_name in [ |
660 | | - ("temperature", "temperature"), |
661 | | - ("topP", "top_p"), |
662 | | - ("topK", "top_k"), |
663 | | - ("maxOutputTokens", "max_output_tokens"), |
664 | | - ("responseCandidatesCount", "response_candidates_count"), |
665 | | - ]: |
666 | | - val = getattr(args, arg_name, None) |
667 | | - if val is not None: |
668 | | - inference_params[key] = val |
669 | | - if getattr(args, "extra_body", None): |
670 | | - extra = getattr(args, "extra_body") |
671 | | - if isinstance(extra, (dict, list)): |
672 | | - try: |
673 | | - inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False) |
674 | | - except (TypeError, ValueError) as e: |
675 | | - print(f"Error: --extra-body dict/list must be JSON-serializable: {e}") |
676 | | - return 1 |
677 | | - elif isinstance(extra, str): |
678 | | - inference_params["extraBody"] = extra |
679 | | - else: |
680 | | - print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.") |
681 | | - return 1 |
| 627 | + signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create) |
682 | 628 |
|
683 | | - wandb_config: Optional[Dict[str, Any]] = None |
684 | | - if getattr(args, "wandb_enabled", False): |
685 | | - wandb_config = { |
686 | | - "enabled": True, |
687 | | - "apiKey": getattr(args, "wandb_api_key", None), |
688 | | - "project": getattr(args, "wandb_project", None), |
689 | | - "entity": getattr(args, "wandb_entity", None), |
690 | | - "runId": getattr(args, "wandb_run_id", None), |
691 | | - } |
692 | | - |
693 | | - body: Dict[str, Any] = { |
694 | | - "displayName": getattr(args, "display_name", None), |
695 | | - "dataset": dataset_resource, |
| 629 | + # Build top-level SDK kwargs |
| 630 | + sdk_kwargs: Dict[str, Any] = { |
696 | 631 | "evaluator": evaluator_resource_name, |
697 | | - "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)), |
698 | | - "trainingConfig": training_config, |
699 | | - "inferenceParameters": inference_params or None, |
700 | | - "wandbConfig": wandb_config, |
701 | | - "chunkSize": getattr(args, "chunk_size", None), |
702 | | - "outputStats": None, |
703 | | - "outputMetrics": None, |
704 | | - "mcpServer": getattr(args, "mcp_server", None), |
705 | | - "jobId": getattr(args, "job_id", None), |
| 632 | + "dataset": dataset_resource, |
706 | 633 | } |
707 | | - # Debug: print minimal summary |
708 | | - print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'") |
709 | | - if getattr(args, "evaluation_dataset", None): |
710 | | - body["evaluationDataset"] = args.evaluation_dataset |
711 | 634 |
|
712 | | - output_model_arg = getattr(args, "output_model", None) |
713 | | - if output_model_arg: |
714 | | - if len(output_model_arg) > 63: |
715 | | - print(f"Error: Output model name '{output_model_arg}' exceeds 63 characters.") |
716 | | - return 1 |
717 | | - body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{output_model_arg}" |
718 | | - else: |
719 | | - # Auto-generate output model name if not provided |
720 | | - auto_output_model = build_default_output_model(evaluator_id) |
721 | | - body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{auto_output_model}" |
| 635 | + args_dict = vars(args) |
| 636 | + for name in signature.parameters: |
| 637 | + prefix = name + "_" |
| 638 | + |
| 639 | + # Collect "flattened" argparse fields back into the nested dict expected by the SDK. |
| 640 | + # Example: training_config_epochs=3 becomes sdk_kwargs["training_config"]["epochs"] = 3. |
| 641 | + nested = {} |
| 642 | + for k, v in args_dict.items(): |
| 643 | + if v is None: |
| 644 | + continue |
| 645 | + if not k.startswith(prefix): |
| 646 | + continue |
| 647 | + nested[k[len(prefix) :]] = v |
| 648 | + |
| 649 | + if nested: |
| 650 | + sdk_kwargs[name] = nested |
| 651 | + elif args_dict.get(name) is not None: |
| 652 | + sdk_kwargs[name] = args_dict[name] |
722 | 653 |
|
723 | | - # Clean None fields to avoid noisy payloads |
724 | | - body = {k: v for k, v in body.items() if v is not None} |
| 654 | + print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'") |
725 | 655 |
|
726 | 656 | if dry_run: |
727 | | - print("--dry-run: would create RFT job with body:") |
728 | | - print(json.dumps(body, indent=2)) |
| 657 | + print("--dry-run: would call Fireworks().reinforcement_fine_tuning_jobs.create with kwargs:") |
| 658 | + print(json.dumps(sdk_kwargs, indent=2)) |
729 | 659 | _print_links(evaluator_id, dataset_id, None) |
730 | 660 | return 0 |
731 | 661 |
|
732 | 662 | try: |
733 | | - result = create_reinforcement_fine_tuning_job( |
734 | | - account_id=account_id, api_key=api_key, api_base=api_base, body=body |
735 | | - ) |
736 | | - job_name = result.get("name") if isinstance(result, dict) else None |
737 | | - print("\n✅ Created Reinforcement Fine-tuning Job") |
738 | | - if job_name: |
739 | | - print(f" name: {job_name}") |
| 663 | + fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base) |
| 664 | + job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs) |
| 665 | + job_name = job.name |
| 666 | + print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}") |
740 | 667 | _print_links(evaluator_id, dataset_id, job_name) |
741 | 668 | return 0 |
742 | 669 | except Exception as e: |
|
0 commit comments