diff --git a/.env.example b/.env.example index 41810d09..1fbc44a6 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,9 @@ FIREWORKS_API_KEY="your_fireworks_api_key_here" FIREWORKS_ACCOUNT_ID="your_fireworks_account_id_here" # e.g., "fireworks" or your specific account +# OpenAI Credentials (for using OpenAI models as judge) +OPENAI_API_KEY="your_openai_api_key_here" + # Optional: If targeting a non-production Fireworks API endpoint # FIREWORKS_API_BASE="https://dev.api.fireworks.ai" diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index f1d6c922..7cbfc6e5 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -287,6 +287,79 @@ def get_fireworks_api_base() -> str: return api_base +def get_extra_headers() -> Dict[str, str]: + """ + Retrieves extra headers from the FIREWORKS_EXTRA_HEADERS environment variable. + + The value should be a JSON object mapping header names to values. + Example: FIREWORKS_EXTRA_HEADERS='{"x-custom-header": "value", "x-another": "value2"}' + + Returns: + Dictionary of extra headers, or empty dict if not set or invalid. + """ + import json + + extra_headers_str = os.environ.get("FIREWORKS_EXTRA_HEADERS") + if not extra_headers_str: + return {} + + try: + extra_headers = json.loads(extra_headers_str) + if isinstance(extra_headers, dict): + # Ensure all values are strings + return {str(k): str(v) for k, v in extra_headers.items()} + else: + logger.warning("FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s", type(extra_headers).__name__) + return {} + except json.JSONDecodeError as e: + logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s", e) + return {} + + +def get_platform_headers( + api_key: Optional[str] = None, + content_type: Optional[str] = "application/json", + include_extra_headers: bool = True, +) -> Dict[str, str]: + """ + Builds standard headers for Fireworks platform API requests. + + This centralizes header construction including: + - Authorization bearer token + - Content-Type + - User-Agent + - Extra headers from FIREWORKS_EXTRA_HEADERS env var (JSON format) + + Args: + api_key: The API key for authorization. If None, resolves via get_fireworks_api_key(). + content_type: The Content-Type header value. Set to None to omit. + include_extra_headers: Whether to include extra headers from FIREWORKS_EXTRA_HEADERS env var. + + Returns: + Dictionary of headers for platform API requests. + """ + from .common_utils import get_user_agent + + resolved_api_key = api_key or get_fireworks_api_key() + + headers: Dict[str, str] = { + "User-Agent": get_user_agent(), + } + + if resolved_api_key: + headers["Authorization"] = f"Bearer {resolved_api_key}" + + if content_type: + headers["Content-Type"] = content_type + + # Include extra headers if set in environment + if include_extra_headers: + extra = get_extra_headers() + headers.update(extra) + + return headers + + def verify_api_key_and_get_account_id( api_key: Optional[str] = None, api_base: Optional[str] = None, diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 4f566338..044b90a9 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -8,7 +8,7 @@ import requests from pydantic import ValidationError -from ..auth import get_fireworks_api_base, get_fireworks_api_key +from ..auth import get_fireworks_api_base, get_fireworks_api_key, get_platform_headers from ..common_utils import get_user_agent from ..fireworks_rft import ( build_default_output_model, @@ -175,11 +175,7 @@ def _poll_evaluator_status( Returns: True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED """ - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=api_key, content_type="application/json") check_url = f"{api_base}/v1/{evaluator_resource_name}" timeout_seconds = timeout_minutes * 60 @@ -517,11 +513,7 @@ def _upload_and_ensure_evaluator( # Optional short-circuit: if evaluator already exists and not forcing, skip upload path if not force: try: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=api_key, content_type="application/json") resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) if resp.ok: state = resp.json().get("state", "STATE_UNSPECIFIED") @@ -702,7 +694,7 @@ def _create_rft_job( print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'") if getattr(args, "evaluation_dataset", None): body["evaluationDataset"] = args.evaluation_dataset - + output_model_arg = getattr(args, "output_model", None) if output_model_arg: if len(output_model_arg) > 63: diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 6123f15d..90da3646 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -18,6 +18,7 @@ from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_key, + get_platform_headers, verify_api_key_and_get_account_id, ) from eval_protocol.common_utils import get_user_agent @@ -403,11 +404,7 @@ def preview(self, sample_file, max_samples=5): account_id = "pyroworks-dev" url = f"{api_base}/v1/accounts/{account_id}/evaluators:previewEvaluator" - headers = { - "Authorization": f"Bearer {auth_token}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=auth_token, content_type="application/json") logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}") logger.debug(f"Preview API Request URL: {url}") logger.debug(f"Preview API Request Headers: {json.dumps(headers, indent=2)}") @@ -749,11 +746,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) account_id = "pyroworks-dev" base_url = f"{self.api_base}/v1/{parent}/evaluatorsV2" - headers = { - "Authorization": f"Bearer {auth_token}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=auth_token, content_type="application/json") self._ensure_requirements_present(os.getcwd()) diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 777547fe..1e2ed32b 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -13,8 +13,7 @@ import requests -from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key -from .common_utils import get_user_agent +from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, get_platform_headers def _map_api_host_to_app_host(api_base: str) -> str: @@ -142,11 +141,17 @@ def create_dataset_from_jsonl( display_name: Optional[str], jsonl_path: str, ) -> Tuple[str, Dict[str, Any]]: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + import os + + # DEBUG: Check environment variable + extra_headers_env = os.environ.get("FIREWORKS_EXTRA_HEADERS", "") + print(f"[DEBUG] FIREWORKS_EXTRA_HEADERS env: {extra_headers_env}") + + headers = get_platform_headers(api_key=api_key, content_type="application/json") + + # DEBUG: Print headers (mask auth token) + debug_headers = {k: (v[:20] + "..." if k == "Authorization" else v) for k, v in headers.items()} + print(f"[DEBUG] Headers being sent: {debug_headers}") # Count examples quickly example_count = 0 with open(jsonl_path, "r", encoding="utf-8") as f: @@ -171,10 +176,8 @@ def create_dataset_from_jsonl( upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" with open(jsonl_path, "rb") as f: files = {"file": f} - up_headers = { - "Authorization": f"Bearer {api_key}", - "User-Agent": get_user_agent(), - } + # For file uploads, omit Content-Type (let requests set multipart boundary) + up_headers = get_platform_headers(api_key=api_key, content_type=None) up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) if up_resp.status_code not in (200, 201): raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") @@ -196,12 +199,8 @@ def create_reinforcement_fine_tuning_job( # Remove from body and append as query param body.pop("jobId", None) url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}" - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "Accept": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=api_key, content_type="application/json") + headers["Accept"] = "application/json" resp = requests.post(url, json=body, headers=headers, timeout=60) if resp.status_code not in (200, 201): raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}") @@ -217,11 +216,11 @@ def build_default_dataset_id(evaluator_id: str) -> str: def build_default_output_model(evaluator_id: str) -> str: base = evaluator_id.lower().replace("_", "-") uuid_suffix = str(uuid.uuid4())[:4] - + # suffix is "-rft-{4chars}" -> 9 chars suffix_len = 9 max_len = 63 - + # Check if we need to truncate if len(base) + suffix_len > max_len: # Calculate hash of the full base to preserve uniqueness @@ -229,10 +228,10 @@ def build_default_output_model(evaluator_id: str) -> str: # New structure: {truncated_base}-{hash}-{uuid_suffix} # Space needed for "-{hash}" is 1 + 6 = 7 hash_part_len = 7 - + allowed_base_len = max_len - suffix_len - hash_part_len truncated_base = base[:allowed_base_len].strip("-") - + return f"{truncated_base}-{hash_digest}-rft-{uuid_suffix}" return f"{base}-rft-{uuid_suffix}" diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 81754e13..8781ff8a 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -10,8 +10,8 @@ get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, + get_platform_headers, ) -from eval_protocol.common_utils import get_user_agent logger = logging.getLogger(__name__) @@ -93,11 +93,7 @@ def create_or_update_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") return False - headers = { - "Authorization": f"Bearer {resolved_api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=resolved_api_key, content_type="application/json") # The secret_id for GET/PATCH/DELETE operations is the key_name. # The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous. @@ -219,10 +215,7 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - headers = { - "Authorization": f"Bearer {resolved_api_key}", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=resolved_api_key, content_type=None) resource_id = _normalize_secret_resource_id(key_name) try: @@ -259,10 +252,7 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - headers = { - "Authorization": f"Bearer {resolved_api_key}", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=resolved_api_key, content_type=None) resource_id = _normalize_secret_resource_id(key_name) try: diff --git a/eval_protocol/pytest/handle_persist_flow.py b/eval_protocol/pytest/handle_persist_flow.py index 66538903..779b265d 100644 --- a/eval_protocol/pytest/handle_persist_flow.py +++ b/eval_protocol/pytest/handle_persist_flow.py @@ -7,7 +7,6 @@ import re from typing import Any -from eval_protocol.common_utils import get_user_agent from eval_protocol.directory_utils import find_eval_protocol_dir from eval_protocol.models import EvaluationRow from eval_protocol.pytest.store_experiment_link import store_experiment_link @@ -16,6 +15,7 @@ get_fireworks_account_id, verify_api_key_and_get_account_id, get_fireworks_api_base, + get_platform_headers, ) import requests @@ -130,11 +130,7 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: continue api_base = get_fireworks_api_base() - headers = { - "Authorization": f"Bearer {fireworks_api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + headers = get_platform_headers(api_key=fireworks_api_key, content_type="application/json") # Make dataset first @@ -167,10 +163,8 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: upload_url = f"{api_base}/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload" with open(exp_file, "rb") as f: files = {"file": f} - upload_headers = { - "Authorization": f"Bearer {fireworks_api_key}", - "User-Agent": get_user_agent(), - } + # For file uploads, omit Content-Type (let requests set multipart boundary) + upload_headers = get_platform_headers(api_key=fireworks_api_key, content_type=None) upload_response = requests.post(upload_url, files=files, headers=upload_headers) # Skip if upload failed