diff --git a/.gitignore b/.gitignore index 7e434174..918f7fbf 100644 --- a/.gitignore +++ b/.gitignore @@ -243,3 +243,5 @@ package.json tau2-bench *.err eval-protocol + +.vscode/launch.json diff --git a/.vscode/.gitignore b/.vscode/.gitignore new file mode 100644 index 00000000..c2dd2a37 --- /dev/null +++ b/.vscode/.gitignore @@ -0,0 +1 @@ +!launch.json.backup diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 38fff2f8..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Debug Tests", - "type": "python", - "request": "launch", - "module": "pytest", - "args": ["-s", "--tb=short", "${file}"], - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - }, - { - "name": "Python: Debug Current File", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - }, - { - "name": "Python: Debug Logs Server", - "type": "python", - "request": "launch", - "module": "eval_protocol.utils.logs_server", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - } - ] -} diff --git a/.vscode/launch.json.example b/.vscode/launch.json.example new file mode 100644 index 00000000..7b70e735 --- /dev/null +++ b/.vscode/launch.json.example @@ -0,0 +1,60 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "EP: Upload", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["upload"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + }, + { + "name": "EP: Local Test", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["local-test", "--ignore-docker"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + }, + { + "name": "EP: Create RFT", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": [ + "create", + "rft", + "--base-model", + "accounts/fireworks/models/qwen3-0p6b", + "--chunk-size", + "10" + ], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + } + ] +} diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index 68ce134c..40e3c777 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -1,12 +1,75 @@ import logging import os -from typing import Optional +from typing import Dict, Optional import requests +from dotenv import dotenv_values, find_dotenv, load_dotenv logger = logging.getLogger(__name__) +def find_dotenv_path(search_path: Optional[str] = None) -> Optional[str]: + """ + Find the .env file path, searching .env.dev first, then .env. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Path to the .env file if found, otherwise None. + """ + # If a specific search path is provided, look there first + if search_path: + env_dev_path = os.path.join(search_path, ".env.dev") + if os.path.isfile(env_dev_path): + return env_dev_path + env_path = os.path.join(search_path, ".env") + if os.path.isfile(env_path): + return env_path + return None + + # Otherwise use find_dotenv to search up the directory tree + env_dev_path = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) + if env_dev_path: + return env_dev_path + env_path = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) + if env_path: + return env_path + return None + + +def get_dotenv_values(search_path: Optional[str] = None) -> Dict[str, Optional[str]]: + """ + Get all key-value pairs from the .env file. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Dictionary of environment variable names to values. + """ + dotenv_path = find_dotenv_path(search_path) + if dotenv_path: + return dotenv_values(dotenv_path) + return {} + + +# --- Load .env files --- +# Attempt to load .env.dev first, then .env as a fallback. +# This happens when the module is imported. +# We use override=False (default) so that existing environment variables +# (e.g., set in the shell) are NOT overridden by .env files. +_DOTENV_PATH = find_dotenv_path() +if _DOTENV_PATH: + load_dotenv(dotenv_path=_DOTENV_PATH, override=False) + logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_DOTENV_PATH}") +else: + logger.debug( + "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables." + ) +# --- End .env loading --- + + def get_fireworks_api_key() -> Optional[str]: """ Retrieves the Fireworks API key. @@ -73,6 +136,8 @@ def verify_api_key_and_get_account_id( Args: api_key: Optional explicit API key. When None, resolves via get_fireworks_api_key(). api_base: Optional explicit API base. When None, resolves via get_fireworks_api_base(). + If api_base is api.fireworks.ai, it is used directly. Otherwise, defaults to + dev.api.fireworks.ai for the verification call. Returns: The resolved account id if verification succeeds and the header is present; otherwise None. @@ -81,7 +146,12 @@ def verify_api_key_and_get_account_id( resolved_key = api_key or get_fireworks_api_key() if not resolved_key: return None - resolved_base = api_base or get_fireworks_api_base() + provided_base = api_base or get_fireworks_api_base() + # Use api.fireworks.ai if explicitly provided, otherwise fall back to dev + if "api.fireworks.ai" in provided_base: + resolved_base = provided_base + else: + resolved_base = "https://dev.api.fireworks.ai" from .common_utils import get_user_agent diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index ac8a8d9d..9b3bb320 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -81,13 +81,12 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "--env-file", help="Path to .env file containing secrets to upload (default: .env in current directory)", ) - upload_parser.add_argument( - "--force", - action="store_true", - help="Overwrite existing evaluator with the same ID", - ) # Auto-generate flags from SDK Fireworks().evaluators.create() signature + # Note: We use Fireworks() directly here instead of create_fireworks_client() + # because we only need the method signature for introspection, not a fully + # authenticated client. create_fireworks_client() would trigger an HTTP request + # to verify the API key, causing delays even for --help invocations. create_evaluator_fn = Fireworks().evaluators.create upload_skip_fields = { @@ -137,7 +136,6 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending") - rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID") rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation") rft_parser.add_argument( "--ignore-docker", @@ -198,6 +196,10 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.", } + # Note: We use Fireworks() directly here instead of create_fireworks_client() + # because we only need the method signature for introspection, not a fully + # authenticated client. create_fireworks_client() would trigger an HTTP request + # to verify the API key, causing delays even for --help invocations. create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create add_args_from_callable_signature( @@ -284,8 +286,10 @@ def main(): from dotenv import load_dotenv # .env.dev for development-specific overrides, .env for general + # Use explicit paths to avoid find_dotenv() searching up the directory tree + # and potentially finding a different .env file (e.g., in some other repo) load_dotenv(dotenv_path=Path(".") / ".env.dev", override=True) - load_dotenv(override=True) + load_dotenv(dotenv_path=Path(".") / ".env", override=True) except ImportError: pass diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 702eb2fe..6a6123f1 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -7,19 +7,18 @@ import time from typing import Any, Callable, Dict, Optional import inspect -import requests import tempfile from pydantic import ValidationError from ..auth import get_fireworks_api_base, get_fireworks_api_key -from ..common_utils import get_user_agent, load_jsonl +from ..fireworks_client import create_fireworks_client +from ..common_utils import load_jsonl from ..fireworks_rft import ( create_dataset_from_jsonl, detect_dataset_builder, materialize_dataset_via_builder, ) from ..models import EvaluationRow -from .upload import upload_command from .utils import ( _build_entry_point, _build_trimmed_dataset_id, @@ -35,8 +34,6 @@ ) from .local_test import run_evaluator_test -from fireworks import Fireworks - def _extract_dataset_adapter( test_file_path: str, test_func_name: str @@ -223,64 +220,68 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) return None -def _poll_evaluator_status( - evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10 +def _poll_evaluator_version_status( + evaluator_id: str, + version_id: str, + api_key: str, + api_base: str, + timeout_minutes: int = 10, ) -> bool: """ - Poll evaluator status until it becomes ACTIVE or times out. + Poll a specific evaluator version status until it becomes ACTIVE or times out. + + Uses the Fireworks SDK to get the specified version of the evaluator and checks + its build state. Args: - evaluator_resource_name: Full evaluator resource name (e.g., accounts/xxx/evaluators/yyy) + evaluator_id: The evaluator ID (not full resource name) + version_id: The specific version ID to poll api_key: Fireworks API key api_base: Fireworks API base URL timeout_minutes: Maximum time to wait in minutes Returns: - True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED + True if evaluator version becomes ACTIVE, False if timeout or BUILD_FAILED """ - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - - check_url = f"{api_base}/v1/{evaluator_resource_name}" timeout_seconds = timeout_minutes * 60 poll_interval = 10 # seconds start_time = time.time() - print(f"Polling evaluator status (timeout: {timeout_minutes}m, interval: {poll_interval}s)...") + print( + f"Polling evaluator version '{version_id}' status (timeout: {timeout_minutes}m, interval: {poll_interval}s)..." + ) + + client = create_fireworks_client(api_key=api_key, base_url=api_base) while time.time() - start_time < timeout_seconds: try: - response = requests.get(check_url, headers=headers, timeout=30) - response.raise_for_status() - - evaluator_data = response.json() - state = evaluator_data.get("state", "STATE_UNSPECIFIED") - status = evaluator_data.get("status", "") + version = client.evaluator_versions.get(version_id, evaluator_id=evaluator_id) + state = version.state or "STATE_UNSPECIFIED" + status_msg = "" + if version.status and version.status.message: + status_msg = version.status.message if state == "ACTIVE": - print("✅ Evaluator is ACTIVE and ready!") + print("✅ Evaluator version is ACTIVE and ready!") return True elif state == "BUILD_FAILED": - print(f"❌ Evaluator build failed. Status: {status}") + print(f"❌ Evaluator version build failed. Status: {status_msg}") return False elif state == "BUILDING": elapsed_minutes = (time.time() - start_time) / 60 - print(f"⏳ Evaluator is still building... ({elapsed_minutes:.1f}m elapsed)") + print(f"⏳ Evaluator version is still building... ({elapsed_minutes:.1f}m elapsed)") else: - print(f"⏳ Evaluator state: {state}, status: {status}") + print(f"⏳ Evaluator version state: {state}, status: {status_msg}") - except requests.exceptions.RequestException as e: - print(f"Warning: Failed to check evaluator status: {e}") + except Exception as e: + print(f"Warning: Failed to check evaluator version status: {e}") # Wait before next poll time.sleep(poll_interval) # Timeout reached elapsed_minutes = (time.time() - start_time) / 60 - print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator is not yet ACTIVE") + print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator version is not yet ACTIVE") return False @@ -565,42 +566,16 @@ def _upload_dataset( def _upload_and_ensure_evaluator( project_root: str, evaluator_id: str, - evaluator_resource_name: str, api_key: str, api_base: str, - force: bool, ) -> bool: - """Ensure the evaluator exists and is ACTIVE, uploading it if needed.""" - # Optional short-circuit: if evaluator already exists and not forcing, skip upload path - if not force: - try: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) - if resp.ok: - state = resp.json().get("state", "STATE_UNSPECIFIED") - print(f"✓ Evaluator exists (state: {state}). Skipping upload (use --force to overwrite).") - # Poll for ACTIVE before proceeding - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - if not _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ): - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\n❌ Evaluator is not ready within the timeout period.") - print(f"📊 Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - except requests.exceptions.RequestException: - pass + """Upload evaluator and ensure its version becomes ACTIVE. + + Creates/updates the evaluator and uploads the code, then polls the specific + version until it becomes ACTIVE. + """ + from eval_protocol.evaluation import create_evaluation - # Ensure evaluator exists by invoking the upload flow programmatically try: tests = _discover_tests(project_root) selected_entry: Optional[str] = None @@ -617,43 +592,37 @@ def _upload_and_ensure_evaluator( ) return False - upload_args = argparse.Namespace( - path=project_root, - entry=selected_entry, - id=evaluator_id, - display_name=None, - description=None, - force=force, # Pass through the --force flag - yes=True, - env_file=None, # Add the new env_file parameter + print(f"\nUploading evaluator '{evaluator_id}'...") + result, version_id = create_evaluation( + evaluator_id=evaluator_id, + display_name=evaluator_id, + description=f"Evaluator for {evaluator_id}", + entry_point=selected_entry, ) - if force: - print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'") + if not version_id: + print("Warning: Evaluator created but version upload failed.") + return False - rc = upload_command(upload_args) - if rc == 0: - print(f"✓ Uploaded/ensured evaluator: {evaluator_id}") + print(f"✓ Uploaded evaluator: {evaluator_id} (version: {version_id})") - # Poll for evaluator status - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - is_active = _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ) + # Poll for the specific evaluator version status + print(f"Waiting for evaluator '{evaluator_id}' version '{version_id}' to become ACTIVE...") + is_active = _poll_evaluator_version_status( + evaluator_id=evaluator_id, + version_id=version_id, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ) - if not is_active: - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\n❌ Evaluator is not ready within the timeout period.") - print(f"📊 Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - else: - print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") + if not is_active: + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\n❌ Evaluator version is not ready within the timeout period.") + print(f"📊 Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") return False + return True except Exception as e: print(f"Warning: Failed to upload evaluator automatically: {e}") return False @@ -672,7 +641,7 @@ def _create_rft_job( ) -> int: """Build and submit the RFT job request (via Fireworks SDK).""" - signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create) + signature = inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create) # Build top-level SDK kwargs sdk_kwargs: Dict[str, Any] = { @@ -711,7 +680,7 @@ def _create_rft_job( return 0 try: - fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base) + fw: Fireworks = create_fireworks_client(api_key=api_key, base_url=api_base) job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs) job_name = job.name print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}") @@ -739,7 +708,6 @@ def create_rft_command(args) -> int: evaluator_arg: Optional[str] = getattr(args, "evaluator", None) non_interactive: bool = bool(getattr(args, "yes", False)) dry_run: bool = bool(getattr(args, "dry_run", False)) - force: bool = bool(getattr(args, "force", False)) skip_validation: bool = bool(getattr(args, "skip_validation", False)) ignore_docker: bool = bool(getattr(args, "ignore_docker", False)) docker_build_extra: str = getattr(args, "docker_build_extra", "") or "" @@ -810,14 +778,12 @@ def create_rft_command(args) -> int: if not dataset_id or not dataset_resource: return 1 - # 5) Ensure evaluator exists and is ACTIVE (upload + poll if needed) + # 5) Ensure evaluator exists and its latest version is ACTIVE (upload + poll if needed) if not _upload_and_ensure_evaluator( project_root=project_root, evaluator_id=evaluator_id, - evaluator_resource_name=evaluator_resource_name, api_key=api_key, api_base=api_base, - force=force, ): return 1 diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index a8a132d6..5abe49e8 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -289,7 +289,6 @@ def upload_command(args: argparse.Namespace) -> int: base_id = getattr(args, "id", None) display_name = getattr(args, "display_name", None) description = getattr(args, "description", None) - force = bool(getattr(args, "force", False)) env_file = getattr(args, "env_file", None) # Load secrets from .env file and ensure they're available on Fireworks @@ -378,17 +377,18 @@ def upload_command(args: argparse.Namespace) -> int: print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...") try: - result = create_evaluation( + result, version_id = create_evaluation( evaluator_id=evaluator_id, display_name=display_name or evaluator_id, description=description or f"Evaluator for {qualname}", - force=force, entry_point=entry_point, ) name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id # Print success message with Fireworks dashboard link print(f"\n✅ Successfully uploaded evaluator: {evaluator_id}") + if version_id: + print(f" Version: {version_id}") print("📊 View in Fireworks Dashboard:") dashboard_url = _build_evaluator_dashboard_url(evaluator_id) print(f" {dashboard_url}\n") diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 128038bf..31298992 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -4,14 +4,15 @@ from typing import List, Optional import fireworks +from fireworks.types import EvaluatorVersionParam import requests -from fireworks import Fireworks from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_key, verify_api_key_and_get_account_id, ) +from eval_protocol.fireworks_client import create_fireworks_client from eval_protocol.get_pep440_version import get_pep440_version logger = logging.getLogger(__name__) @@ -153,7 +154,7 @@ def _create_tar_gz_with_ignores(output_path: str, source_dir: str) -> int: logger.info(f"Created {output_path} ({size_bytes:,} bytes)") return size_bytes - def create(self, evaluator_id, display_name=None, description=None, force=False): + def create(self, evaluator_id, display_name=None, description=None): auth_token = self.api_key or get_fireworks_api_key() account_id = self.account_id or get_fireworks_account_id() if not account_id and auth_token: @@ -163,7 +164,11 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) logger.error("Authentication error: API credentials appear to be invalid or incomplete.") raise ValueError("Invalid or missing API credentials.") - client = Fireworks(api_key=auth_token, base_url=self.api_base, account_id=account_id) + client = create_fireworks_client( + api_key=auth_token, + base_url=self.api_base, + account_id=account_id, + ) self.display_name = display_name or evaluator_id self.description = description or f"Evaluator created from {evaluator_id}" @@ -197,28 +202,20 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) logger.info(f"Creating evaluator '{evaluator_id}' for account '{account_id}'...") try: - if force: - try: - logger.info("Checking if evaluator exists") - existing_evaluator = client.evaluators.get(evaluator_id=evaluator_id) - if existing_evaluator: - logger.info(f"Evaluator '{evaluator_id}' already exists, deleting and recreating...") - try: - client.evaluators.delete(evaluator_id=evaluator_id) - logger.info(f"Successfully deleted evaluator '{evaluator_id}'") - except fireworks.NotFoundError: - logger.info(f"Evaluator '{evaluator_id}' not found, creating...") - except fireworks.APIError as e: - logger.warning(f"Error deleting evaluator: {str(e)}") - except fireworks.NotFoundError: - logger.info(f"Evaluator '{evaluator_id}' does not exist, creating...") - - # Create evaluator using SDK - result = client.evaluators.create( - evaluator_id=evaluator_id, - evaluator=evaluator_params, - ) - logger.info(f"Successfully created evaluator '{evaluator_id}'") + # Try to create evaluator using SDK + try: + result = client.evaluators.create( + evaluator_id=evaluator_id, + evaluator=evaluator_params, + ) + logger.info(f"Successfully created evaluator '{evaluator_id}'") + except fireworks.APIStatusError as create_error: + if create_error.status_code == 409: + # Evaluator already exists, get the existing one and proceed to create a new version + logger.info(f"Evaluator '{evaluator_id}' already exists, creating new version...") + result = client.evaluators.get(evaluator_id=evaluator_id) + else: + raise # Upload code as tar.gz to GCS evaluator_name = result.name # e.g., "accounts/pyroworks/evaluators/test-123" @@ -229,6 +226,25 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) f"Cannot proceed with code upload. Response: {result}" ) + evaluator_version_param: EvaluatorVersionParam = {} + if "commit_hash" in evaluator_params: + evaluator_version_param["commit_hash"] = evaluator_params["commit_hash"] + if "entry_point" in evaluator_params: + evaluator_version_param["entry_point"] = evaluator_params["entry_point"] + if "requirements" in evaluator_params: + evaluator_version_param["requirements"] = evaluator_params["requirements"] + + evaluator_version = client.evaluator_versions.create( + evaluator_id=evaluator_id, + evaluator_version=evaluator_version_param, + ) + evaluator_version_id = evaluator_version.name.split("/")[-1] if evaluator_version.name else None + if not evaluator_version_id: + raise ValueError( + "Create evaluator version response missing 'name' field. " + f"Cannot proceed with code upload. Response: {evaluator_version}" + ) + try: # Create tar.gz of current directory cwd = os.getcwd() @@ -240,7 +256,8 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) # Call GetEvaluatorUploadEndpoint using SDK logger.info(f"Requesting upload endpoint for {tar_filename}") - upload_response = client.evaluators.get_upload_endpoint( + upload_response = client.evaluator_versions.get_upload_endpoint( + version_id=evaluator_version_id, evaluator_id=evaluator_id, filename_to_size={tar_filename: str(tar_size)}, ) @@ -321,9 +338,9 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) raise # Step 3: Validate upload using SDK - client.evaluators.validate_upload( + client.evaluator_versions.validate_upload( + version_id=evaluator_version_id, evaluator_id=evaluator_id, - body={}, ) logger.info("Upload validated successfully") @@ -334,8 +351,10 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) except Exception as upload_error: logger.warning(f"Code upload failed (evaluator created but code not uploaded): {upload_error}") # Don't fail - evaluator is created, just code upload failed + # Return None for version_id since upload failed + return result, None - return result # Return after attempting upload + return result, evaluator_version_id # Return evaluator result and version ID except fireworks.APIStatusError as e: logger.error(f"Error creating evaluator: {str(e)}") logger.error(f"Status code: {e.status_code}, Response: {e.response.text}") @@ -361,7 +380,6 @@ def create_evaluation( evaluator_id: str, display_name: Optional[str] = None, description: Optional[str] = None, - force: bool = False, account_id: Optional[str] = None, api_key: Optional[str] = None, entry_point: Optional[str] = None, @@ -373,10 +391,13 @@ def create_evaluation( evaluator_id: Unique identifier for the evaluator display_name: Display name for the evaluator description: Description for the evaluator - force: If True, delete and recreate if evaluator exists account_id: Optional Fireworks account ID api_key: Optional Fireworks API key entry_point: Optional entry point (module::function or path::function) + + Returns: + A tuple of (evaluator_result, version_id) where version_id is the ID of the + created evaluator version, or None if upload failed. """ evaluator = Evaluator( account_id=account_id, @@ -384,4 +405,4 @@ def create_evaluation( entry_point=entry_point, ) - return evaluator.create(evaluator_id, display_name, description, force) + return evaluator.create(evaluator_id, display_name, description) diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 5086d6e3..59a026ed 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -11,8 +11,8 @@ # Retry configuration for database operations -SQLITE_RETRY_MAX_TRIES = 5 -SQLITE_RETRY_MAX_TIME = 30 # seconds +SQLITE_RETRY_MAX_TRIES = 10 +SQLITE_RETRY_MAX_TIME = 60 # seconds def _is_database_locked_error(e: Exception) -> bool: diff --git a/eval_protocol/fireworks_client.py b/eval_protocol/fireworks_client.py new file mode 100644 index 00000000..d92d8bec --- /dev/null +++ b/eval_protocol/fireworks_client.py @@ -0,0 +1,132 @@ +""" +Consolidated Fireworks client factory. + +This module provides a single point of instantiation for the Fireworks SDK client, +ensuring consistent handling of environment variables and configuration across the +eval_protocol codebase. + +Environment variables: + FIREWORKS_API_KEY: API key for authentication (required) + FIREWORKS_ACCOUNT_ID: Account ID (optional, can be derived from API key) + FIREWORKS_API_BASE: Base URL for the API (default: https://api.fireworks.ai) + FIREWORKS_EXTRA_HEADERS: JSON-encoded extra headers to include in requests + Example: '{"X-Custom-Header": "value", "X-Another": "another-value"}' +""" + +import json +import logging +import os +from typing import Mapping, Optional + +from fireworks import Fireworks + +from eval_protocol.auth import ( + get_fireworks_account_id, + get_fireworks_api_base, + get_fireworks_api_key, +) + +logger = logging.getLogger(__name__) + + +def get_fireworks_extra_headers() -> Optional[Mapping[str, str]]: + """ + Retrieves extra headers from the FIREWORKS_EXTRA_HEADERS environment variable. + + The value should be a JSON-encoded object mapping header names to values. + Example: '{"X-Custom-Header": "value"}' + + Returns: + A mapping of header names to values if set and valid, otherwise None. + """ + extra_headers_str = os.environ.get("FIREWORKS_EXTRA_HEADERS") + if not extra_headers_str or not extra_headers_str.strip(): + return None + + try: + headers = json.loads(extra_headers_str) + if not isinstance(headers, dict): + logger.warning( + "FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s. Ignoring.", + type(headers).__name__, + ) + return None + # Validate all keys and values are strings + for k, v in headers.items(): + if not isinstance(k, str) or not isinstance(v, str): + logger.warning( + "FIREWORKS_EXTRA_HEADERS contains non-string key or value: %s=%s. Ignoring all extra headers.", + k, + v, + ) + return None + logger.debug("Using FIREWORKS_EXTRA_HEADERS: %s", list(headers.keys())) + return headers + except json.JSONDecodeError as e: + logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s. Ignoring.", e) + return None + + +def create_fireworks_client( + *, + api_key: Optional[str] = None, + account_id: Optional[str] = None, + base_url: Optional[str] = None, + extra_headers: Optional[Mapping[str, str]] = None, +) -> Fireworks: + """ + Create a Fireworks client with consistent configuration. + + This factory function centralizes the logic for creating Fireworks clients, + ensuring that environment variables are handled consistently across the codebase. + + Resolution order for each parameter: + 1. Explicit argument passed to this function + 2. Environment variable (via auth module helpers) + 3. SDK defaults (for base_url only) + + Args: + api_key: Fireworks API key. If not provided, resolves from FIREWORKS_API_KEY. + account_id: Fireworks account ID. If not provided, resolves from FIREWORKS_ACCOUNT_ID + or derives from the API key via the verifyApiKey endpoint. + base_url: Base URL for the Fireworks API. If not provided, resolves from + FIREWORKS_API_BASE or defaults to https://api.fireworks.ai. + extra_headers: Additional headers to include in all requests. If not provided, + resolves from FIREWORKS_EXTRA_HEADERS environment variable (JSON-encoded). + + Returns: + A configured Fireworks client instance. + + Raises: + fireworks.FireworksError: If api_key is not provided and FIREWORKS_API_KEY + environment variable is not set. + """ + # Resolve parameters from environment if not explicitly provided + resolved_api_key = api_key or get_fireworks_api_key() + resolved_account_id = account_id or get_fireworks_account_id() + resolved_base_url = base_url or get_fireworks_api_base() + + # Merge extra headers: env var headers first, then explicit headers override + env_extra_headers = get_fireworks_extra_headers() + merged_headers: Optional[Mapping[str, str]] = None + if env_extra_headers or extra_headers: + merged = {} + if env_extra_headers: + merged.update(env_extra_headers) + if extra_headers: + merged.update(extra_headers) + merged_headers = merged if merged else None + + logger.debug( + "Creating Fireworks client: base_url=%s, account_id=%s, extra_headers=%s", + resolved_base_url, + resolved_account_id, + list(merged_headers.keys()) if merged_headers else None, + ) + + return Fireworks( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_base_url, + default_headers=merged_headers, + ) diff --git a/eval_protocol/mcp/mcp_multi_client.py b/eval_protocol/mcp/mcp_multi_client.py index 4c138796..faa774a9 100644 --- a/eval_protocol/mcp/mcp_multi_client.py +++ b/eval_protocol/mcp/mcp_multi_client.py @@ -13,7 +13,6 @@ class FunctionLike(BaseModel): parameters: Any = None -from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client @@ -26,8 +25,6 @@ class FunctionLike(BaseModel): MCPMultiClientConfiguration, ) -load_dotenv() # load environment variables from .env - class MCPMultiClient: """ diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 60743ccb..8b07f4d7 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -3,39 +3,17 @@ import sys from typing import Optional -from dotenv import find_dotenv, load_dotenv - from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, ) +from eval_protocol.fireworks_client import create_fireworks_client from fireworks.types import Secret -from fireworks import Fireworks, FireworksError, NotFoundError, InternalServerError +from fireworks import FireworksError, NotFoundError, InternalServerError logger = logging.getLogger(__name__) -# --- Load .env files --- -# Attempt to load .env.dev first, then .env as a fallback. -# This happens when the module is imported. -# We use override=False (default) so that existing environment variables -# (e.g., set in the shell) are NOT overridden by .env files. -ENV_DEV_PATH = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) -if ENV_DEV_PATH: - load_dotenv(dotenv_path=ENV_DEV_PATH, override=False) - logger.info(f"eval_protocol.platform_api: Loaded environment variables from: {ENV_DEV_PATH}") -else: - ENV_PATH = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) - if ENV_PATH: - load_dotenv(dotenv_path=ENV_PATH, override=False) - logger.info(f"eval_protocol.platform_api: Loaded environment variables from: {ENV_PATH}") - else: - logger.info( - "eval_protocol.platform_api: No .env.dev or .env file found. " - "Relying on shell/existing environment variables." - ) -# --- End .env loading --- - class PlatformAPIError(Exception): """Custom exception for platform API errors.""" @@ -88,7 +66,11 @@ def create_or_update_fireworks_secret( resolved_api_key = api_key or get_fireworks_api_key() resolved_api_base = api_base or get_fireworks_api_base() resolved_account_id = account_id # Must be provided - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) if not all([resolved_api_key, resolved_api_base, resolved_account_id]): logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") @@ -173,7 +155,11 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) resource_id = _normalize_secret_resource_id(key_name) try: @@ -215,7 +201,11 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) resource_id = _normalize_secret_resource_id(key_name) try: diff --git a/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py b/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py index ffd8b9ea..87db9acb 100644 --- a/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py +++ b/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py @@ -13,11 +13,14 @@ from flask import Flask, request, jsonify from openai import OpenAI import openai +from pathlib import Path + from dotenv import load_dotenv from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter -load_dotenv() +# Use explicit path to avoid find_dotenv() searching up the directory tree +load_dotenv(dotenv_path=Path(".") / ".env") # Configure logging so INFO and below go to stdout, WARNING+ to stderr. # This avoids Vercel marking INFO logs as [error] (stderr). diff --git a/pyproject.toml b/pyproject.toml index e5caa497..841f5ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "pytest-asyncio>=0.21.0", "peewee>=3.18.2", "backoff>=2.2.0", - "fireworks-ai==1.0.0a20", + "fireworks-ai==1.0.0a22", "questionary>=2.0.0", # Dependencies for vendored tau2 package "toml>=0.10.0", diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 1f1e8395..9832aec2 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -1,7 +1,6 @@ import json import os import argparse -import requests from types import SimpleNamespace from unittest.mock import patch from typing import Any, cast @@ -24,7 +23,7 @@ def _write_json(path: str, data: dict) -> None: def stub_fireworks(monkeypatch) -> dict[str, Any]: """ Stub Fireworks SDK so tests stay offline and so create_rft.py can inspect a stable - create() signature (it uses inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)). + create() signature (it uses inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create)). Returns: A dict containing the last captured create() kwargs under key "kwargs". @@ -72,12 +71,15 @@ def create( return SimpleNamespace(name=f"accounts/{account_id}/reinforcementFineTuningJobs/xyz") class _FakeFW: - def __init__(self, api_key=None, base_url=None): + def __init__(self, api_key=None, base_url=None, account_id=None, default_headers=None): self.api_key = api_key self.base_url = base_url + self.account_id = account_id + self.default_headers = default_headers self.reinforcement_fine_tuning_jobs = _FakeJobs() - monkeypatch.setattr(cr, "Fireworks", _FakeFW) + # Patch create_fireworks_client to return our fake client + monkeypatch.setattr(cr, "create_fireworks_client", lambda **kwargs: _FakeFW(**kwargs)) return captured @@ -103,7 +105,7 @@ def rft_test_harness(tmp_path, monkeypatch, stub_fireworks): monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) return project @@ -239,7 +241,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator=None, yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -299,7 +300,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator=None, yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -351,7 +351,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator="my-evaluator", yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -401,7 +400,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator="my-evaluator", yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -447,7 +445,7 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_ monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) captured = {"dataset_id": None} @@ -462,7 +460,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d setattr(args, "evaluator", None) setattr(args, "yes", True) setattr(args, "dry_run", False) - setattr(args, "force", False) setattr(args, "env_file", None) setattr(args, "dataset", None) setattr(args, "dataset_jsonl", str(ds_path)) @@ -530,7 +527,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"), yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -600,7 +596,6 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -645,17 +640,8 @@ def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - # Mock evaluator exists and is ACTIVE - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) + # Mock evaluator upload and version polling - evaluator becomes ACTIVE + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) # Provide dataset via --dataset-jsonl so no test discovery needed ds_path = project / "dataset.jsonl" @@ -674,7 +660,6 @@ def raise_for_status(self): evaluator="some-eval", yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -708,11 +693,8 @@ def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - # Evaluator does not exist (force path into upload section) - def _raise(*a, **k): - raise requests.exceptions.RequestException("nope") - - monkeypatch.setattr(cr.requests, "get", _raise) + # Mock _upload_and_ensure_evaluator to fail (ambiguous tests) + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: False) # Two discovered tests (ambiguous) f1 = project / "a.py" @@ -727,7 +709,6 @@ def _raise(*a, **k): evaluator="some-eval", yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(project / "dataset.jsonl"), @@ -789,7 +770,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -850,7 +830,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -912,7 +891,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -957,18 +935,8 @@ def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(r d2 = SimpleNamespace(qualname="beta.test_two", file_path=str(f2)) monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2]) - # Evaluator exists and is ACTIVE (skip upload) - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # Evaluator upload succeeds and version becomes ACTIVE + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) # We will provide JSONL via input_dataset extractor for matching test (beta.test_two) jsonl_path = project / "data.jsonl" @@ -1007,7 +975,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=eval_id, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -1050,17 +1017,8 @@ def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatc monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "pyroworks-dev") - # Mock evaluator exists and ACTIVE - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) + # Mock evaluator upload succeeds and version becomes ACTIVE + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) captured = stub_fireworks @@ -1143,7 +1101,7 @@ def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_h monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) # Prepare two JSONL paths: one explicit via --dataset-jsonl and one inferable via input_dataset explicit_jsonl = project / "metric" / "explicit.jsonl" @@ -1175,7 +1133,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(explicit_jsonl), @@ -1266,7 +1223,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, diff --git a/tests/test_ep_upload_e2e.py b/tests/test_ep_upload_e2e.py index 8a67fd33..e76ac246 100644 --- a/tests/test_ep_upload_e2e.py +++ b/tests/test_ep_upload_e2e.py @@ -80,8 +80,8 @@ def mock_gcs_upload(): @pytest.fixture def mock_fireworks_client(): - """Mock the Fireworks SDK client used in evaluation.py""" - with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class: + """Mock the Fireworks SDK client used in fireworks_client.py""" + with patch("eval_protocol.fireworks_client.Fireworks") as mock_fw_class: mock_client = MagicMock() mock_fw_class.return_value = mock_client @@ -92,8 +92,13 @@ def mock_fireworks_client(): mock_create_response.description = "Test description" mock_client.evaluators.create.return_value = mock_create_response - # Mock evaluators.get_upload_endpoint response - will be set dynamically - def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): + # Mock evaluator_versions.create response + mock_version_response = MagicMock() + mock_version_response.name = "accounts/test_account/evaluators/test-eval/versions/v1" + mock_client.evaluator_versions.create.return_value = mock_version_response + + # Mock evaluator_versions.get_upload_endpoint response - will be set dynamically + def get_upload_endpoint_side_effect(evaluator_id, version_id, filename_to_size): response = MagicMock() signed_urls = {} for filename in filename_to_size.keys(): @@ -101,35 +106,13 @@ def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): response.filename_to_signed_urls = signed_urls return response - mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect + mock_client.evaluator_versions.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect - # Mock evaluators.validate_upload response + # Mock evaluator_versions.validate_upload response mock_validate_response = MagicMock() mock_validate_response.success = True mock_validate_response.valid = True - mock_client.evaluators.validate_upload.return_value = mock_validate_response - - # Mock evaluators.get (for force flow - raises NotFoundError by default) - import fireworks - - mock_client.evaluators.get.side_effect = fireworks.NotFoundError( - "Evaluator not found", - response=MagicMock(status_code=404), - body={"error": "not found"}, - ) - - # Mock evaluators.delete - mock_client.evaluators.delete.return_value = None - - yield mock_client - - -@pytest.fixture -def mock_platform_api_client(): - """Mock the Fireworks SDK client used in platform_api.py for secrets""" - with patch("eval_protocol.platform_api.Fireworks") as mock_fw_class: - mock_client = MagicMock() - mock_fw_class.return_value = mock_client + mock_client.evaluator_versions.validate_upload.return_value = mock_validate_response # Mock secrets.get - raise NotFoundError to simulate secret doesn't exist from fireworks import NotFoundError @@ -141,13 +124,23 @@ def mock_platform_api_client(): ) # Mock secrets.create - successful - mock_create_response = MagicMock() - mock_create_response.name = "accounts/test_account/secrets/test-secret" - mock_client.secrets.create.return_value = mock_create_response + mock_secrets_create_response = MagicMock() + mock_secrets_create_response.name = "accounts/test_account/secrets/test-secret" + mock_client.secrets.create.return_value = mock_secrets_create_response yield mock_client +@pytest.fixture +def mock_platform_api_client(mock_fireworks_client): + """ + Mock the Fireworks SDK client for secrets. + This is now just an alias for mock_fireworks_client since both use the same patched location. + The mock_fireworks_client fixture already includes secrets mocking. + """ + yield mock_fireworks_client + + def test_ep_upload_discovers_and_uploads_evaluation_test( mock_env_variables, mock_fireworks_client, mock_platform_api_client, mock_gcs_upload, monkeypatch ): @@ -214,7 +207,6 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow: id="test-simple-eval", # Explicit ID display_name="Simple Word Count Eval", description="E2E test evaluator", - force=False, yes=True, # Non-interactive ) @@ -232,13 +224,18 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow: # Step 1: Create evaluator assert mock_fireworks_client.evaluators.create.called, "Should call evaluators.create" - # Step 2: Get upload endpoint - assert mock_fireworks_client.evaluators.get_upload_endpoint.called, ( - "Should call evaluators.get_upload_endpoint" + # Step 1b: Create evaluator version + assert mock_fireworks_client.evaluator_versions.create.called, "Should call evaluator_versions.create" + + # Step 2: Get upload endpoint (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, ( + "Should call evaluator_versions.get_upload_endpoint" ) - # Step 3: Validate upload - assert mock_fireworks_client.evaluators.validate_upload.called, "Should call evaluators.validate_upload" + # Step 3: Validate upload (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.validate_upload.called, ( + "Should call evaluator_versions.validate_upload" + ) # Step 4: GCS upload assert mock_gcs_upload.send.called, "Should upload tar.gz to GCS" @@ -327,7 +324,6 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow: id="test-param-eval", display_name="Parametrized Eval", description="Test parametrized evaluator", - force=False, yes=True, ) @@ -339,8 +335,9 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow: # Verify upload flow completed via Fireworks SDK assert mock_fireworks_client.evaluators.create.called - assert mock_fireworks_client.evaluators.get_upload_endpoint.called - assert mock_fireworks_client.evaluators.validate_upload.called + assert mock_fireworks_client.evaluator_versions.create.called + assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called + assert mock_fireworks_client.evaluator_versions.validate_upload.called assert mock_gcs_upload.send.called finally: @@ -506,7 +503,6 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: id=None, # Auto-generate from test name display_name=None, # Auto-generate description=None, # Auto-generate - force=False, yes=True, ) @@ -520,8 +516,13 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: # Step 1: Create evaluator assert mock_fireworks_client.evaluators.create.called, "Missing create call" - # Step 2: Get upload endpoint - assert mock_fireworks_client.evaluators.get_upload_endpoint.called, "Missing getUploadEndpoint call" + # Step 1b: Create evaluator version + assert mock_fireworks_client.evaluator_versions.create.called, "Missing evaluator_versions.create call" + + # Step 2: Get upload endpoint (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, ( + "Missing evaluator_versions.get_upload_endpoint call" + ) # Step 3: Upload to GCS assert mock_gcs_upload.send.called, "Missing GCS upload" @@ -529,8 +530,10 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: assert gcs_request.method == "PUT" assert "storage.googleapis.com" in gcs_request.url - # Step 4: Validate - assert mock_fireworks_client.evaluators.validate_upload.called, "Missing validateUpload call" + # Step 4: Validate (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.validate_upload.called, ( + "Missing evaluator_versions.validate_upload call" + ) # 4. VERIFY PAYLOAD DETAILS create_call = mock_fireworks_client.evaluators.create.call_args @@ -547,8 +550,8 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: assert "test_math_eval.py::test_math_correctness" in entry_point # 5. VERIFY TAR.GZ WAS CREATED AND UPLOADED - # Check getUploadEndpoint call payload - upload_call = mock_fireworks_client.evaluators.get_upload_endpoint.call_args + # Check getUploadEndpoint call payload (via evaluator_versions API) + upload_call = mock_fireworks_client.evaluator_versions.get_upload_endpoint.call_args assert upload_call is not None filename_to_size = upload_call.kwargs.get("filename_to_size", {}) assert filename_to_size, "Should have filename_to_size" @@ -597,95 +600,3 @@ def test_create_tar_includes_dockerignored_files(tmp_path): for expected_path in expected_paths: assert expected_path in names, f"Expected {expected_path} in archive" - - -def test_ep_upload_force_flag_triggers_delete_flow( - mock_env_variables, - mock_gcs_upload, - mock_platform_api_client, -): - """ - Test that --force flag triggers the check/delete/recreate flow - """ - from eval_protocol.cli_commands.upload import upload_command, _discover_tests - - test_content = """ -from eval_protocol.pytest import evaluation_test -from eval_protocol.models import EvaluationRow - -@evaluation_test(input_rows=[[EvaluationRow()]]) -async def test_force_eval(row: EvaluationRow) -> EvaluationRow: - return row -""" - - test_project_dir, test_file_path = create_test_project_with_evaluation_test(test_content, "test_force.py") - - original_cwd = os.getcwd() - - try: - os.chdir(test_project_dir) - - # Mock the Fireworks client with evaluator existing (for force flow) - with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class: - mock_client = MagicMock() - mock_fw_class.return_value = mock_client - - # Mock evaluators.get to return an existing evaluator (not raise NotFoundError) - mock_existing_evaluator = MagicMock() - mock_existing_evaluator.name = "accounts/test_account/evaluators/test-force" - mock_client.evaluators.get.return_value = mock_existing_evaluator - - # Mock evaluators.delete - mock_client.evaluators.delete.return_value = None - - # Mock evaluators.create response - mock_create_response = MagicMock() - mock_create_response.name = "accounts/test_account/evaluators/test-force" - mock_client.evaluators.create.return_value = mock_create_response - - # Mock get_upload_endpoint - def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): - response = MagicMock() - signed_urls = {} - for filename in filename_to_size.keys(): - signed_urls[filename] = f"https://storage.googleapis.com/test-bucket/{filename}?signed=true" - response.filename_to_signed_urls = signed_urls - return response - - mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect - - # Mock validate_upload - mock_client.evaluators.validate_upload.return_value = MagicMock() - - discovered_tests = _discover_tests(test_project_dir) - - args = argparse.Namespace( - path=test_project_dir, - entry=None, - id="test-force", - display_name=None, - description=None, - force=True, # Force flag enabled - yes=True, - ) - - with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select: - mock_select.return_value = discovered_tests - exit_code = upload_command(args) - - assert exit_code == 0 - - # Verify check happened (evaluators.get was called) - assert mock_client.evaluators.get.called, "Should check if evaluator exists" - - # Verify delete happened (since evaluator existed) - assert mock_client.evaluators.delete.called, "Should delete existing evaluator" - - # Verify create happened after delete - assert mock_client.evaluators.create.called, "Should create evaluator after delete" - - finally: - os.chdir(original_cwd) - if test_project_dir in sys.path: - sys.path.remove(test_project_dir) - shutil.rmtree(test_project_dir, ignore_errors=True) diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 942c1962..0d4bb13e 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -41,6 +41,7 @@ def test_create_evaluation_helper(monkeypatch): # Track SDK calls create_called = False + version_create_called = False upload_endpoint_called = False validate_called = False @@ -61,7 +62,16 @@ def mock_create(evaluator_id, evaluator): assert evaluator["description"] == "Test description" return mock_evaluator_result - def mock_get_upload_endpoint(evaluator_id, filename_to_size): + # Mock evaluator_versions.create + mock_version_result = MagicMock() + mock_version_result.name = "accounts/test_account/evaluators/test-eval/versions/v1" + + def mock_version_create(evaluator_id, evaluator_version): + nonlocal version_create_called + version_create_called = True + return mock_version_result + + def mock_get_upload_endpoint(evaluator_id, version_id, filename_to_size): nonlocal upload_endpoint_called upload_endpoint_called = True mock_response = MagicMock() @@ -71,7 +81,7 @@ def mock_get_upload_endpoint(evaluator_id, filename_to_size): mock_response.filename_to_signed_urls = signed_urls return mock_response - def mock_validate_upload(evaluator_id, body): + def mock_validate_upload(evaluator_id, version_id): nonlocal validate_called validate_called = True return MagicMock() @@ -83,20 +93,21 @@ def mock_validate_upload(evaluator_id, body): mock_gcs_response.raise_for_status = MagicMock() mock_session.send.return_value = mock_gcs_response - # Patch the Fireworks client - with patch("eval_protocol.evaluation.Fireworks") as mock_fireworks_class: + # Patch the Fireworks client at the location where it's imported + with patch("eval_protocol.fireworks_client.Fireworks") as mock_fireworks_class: mock_client = MagicMock() mock_fireworks_class.return_value = mock_client mock_client.evaluators.create = mock_create - mock_client.evaluators.get_upload_endpoint = mock_get_upload_endpoint - mock_client.evaluators.validate_upload = mock_validate_upload + mock_client.evaluator_versions.create = mock_version_create + mock_client.evaluator_versions.get_upload_endpoint = mock_get_upload_endpoint + mock_client.evaluator_versions.validate_upload = mock_validate_upload # Patch requests.Session for GCS upload monkeypatch.setattr("requests.Session", lambda: mock_session) try: os.chdir(tmp_dir) - api_response = create_evaluation( + api_response, version_id = create_evaluation( evaluator_id="test-eval", display_name="Test Evaluator", description="Test description", @@ -107,8 +118,12 @@ def mock_validate_upload(evaluator_id, body): assert api_response.display_name == "Test Evaluator" assert api_response.description == "Test description" + # Verify version ID was returned + assert version_id == "v1", "Version ID should be returned" + # Verify full upload flow was executed assert create_called, "Create endpoint should be called" + assert version_create_called, "Version create should be called" assert upload_endpoint_called, "GetUploadEndpoint should be called" assert validate_called, "ValidateUpload should be called" assert mock_session.send.called, "GCS upload should happen" diff --git a/tests/test_fireworks_client.py b/tests/test_fireworks_client.py new file mode 100644 index 00000000..db0b08c6 --- /dev/null +++ b/tests/test_fireworks_client.py @@ -0,0 +1,143 @@ +"""Tests for the consolidated Fireworks client factory.""" + +import os +from unittest.mock import patch + +import pytest + +from eval_protocol.fireworks_client import ( + create_fireworks_client, + get_fireworks_extra_headers, +) + + +class TestGetFireworksExtraHeaders: + """Tests for get_fireworks_extra_headers function.""" + + def test_returns_none_when_env_var_not_set(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is not set.""" + with patch.dict(os.environ, {}, clear=True): + # Remove the env var if it exists + os.environ.pop("FIREWORKS_EXTRA_HEADERS", None) + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_empty_string(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is empty.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": ""}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_whitespace_only(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is whitespace only.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": " "}): + result = get_fireworks_extra_headers() + assert result is None + + def test_parses_valid_json_object(self): + """Should parse valid JSON object into dict.""" + headers = '{"X-Custom": "value", "X-Another": "test"}' + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": headers}): + result = get_fireworks_extra_headers() + assert result == {"X-Custom": "value", "X-Another": "test"} + + def test_returns_none_for_invalid_json(self): + """Should return None and log warning for invalid JSON.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": "not json"}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_json_array(self): + """Should return None when JSON is an array instead of object.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '["item1", "item2"]'}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_json_string(self): + """Should return None when JSON is a string instead of object.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '"just a string"'}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_non_string_values(self): + """Should return None when JSON object has non-string values.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '{"key": 123}'}): + result = get_fireworks_extra_headers() + assert result is None + + +class TestCreateFireworksClient: + """Tests for create_fireworks_client function.""" + + def test_creates_client_with_explicit_api_key(self): + """Should create client with explicitly provided API key.""" + client = create_fireworks_client(api_key="test-api-key") + assert client.api_key == "test-api-key" + + def test_creates_client_with_explicit_base_url(self): + """Should create client with explicitly provided base URL.""" + client = create_fireworks_client( + api_key="test-api-key", + base_url="https://custom.api.example.com", + ) + assert str(client.base_url).rstrip("/") == "https://custom.api.example.com" + + def test_creates_client_with_explicit_account_id(self): + """Should create client with explicitly provided account ID.""" + client = create_fireworks_client( + api_key="test-api-key", + account_id="test-account-123", + ) + assert client.account_id == "test-account-123" + + def test_creates_client_with_explicit_extra_headers(self): + """Should create client with explicitly provided extra headers.""" + extra_headers = {"X-Custom-Header": "test-value"} + client = create_fireworks_client( + api_key="test-api-key", + extra_headers=extra_headers, + ) + assert "X-Custom-Header" in client._custom_headers + assert client._custom_headers["X-Custom-Header"] == "test-value" + + def test_merges_env_and_explicit_extra_headers(self): + """Should merge env var headers with explicit headers, explicit taking precedence.""" + env_headers = '{"X-Env-Header": "env-value", "X-Override": "env"}' + explicit_headers = {"X-Explicit-Header": "explicit-value", "X-Override": "explicit"} + + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": env_headers}): + client = create_fireworks_client( + api_key="test-api-key", + extra_headers=explicit_headers, + ) + # Both headers should be present + assert client._custom_headers["X-Env-Header"] == "env-value" + assert client._custom_headers["X-Explicit-Header"] == "explicit-value" + # Explicit should override env + assert client._custom_headers["X-Override"] == "explicit" + + def test_uses_env_extra_headers_when_no_explicit(self): + """Should use env var extra headers when no explicit headers provided.""" + env_headers = '{"X-Env-Header": "env-value"}' + + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": env_headers}): + client = create_fireworks_client(api_key="test-api-key") + assert client._custom_headers["X-Env-Header"] == "env-value" + + def test_resolves_api_key_from_env(self): + """Should resolve API key from environment when not explicitly provided.""" + with patch.dict(os.environ, {"FIREWORKS_API_KEY": "env-api-key"}): + client = create_fireworks_client() + assert client.api_key == "env-api-key" + + def test_resolves_base_url_from_env(self): + """Should resolve base URL from environment when not explicitly provided.""" + with patch.dict( + os.environ, + { + "FIREWORKS_API_KEY": "test-key", + "FIREWORKS_API_BASE": "https://env.api.example.com", + }, + ): + client = create_fireworks_client() + assert str(client.base_url).rstrip("/") == "https://env.api.example.com" diff --git a/tests/test_no_implicit_dotenv.py b/tests/test_no_implicit_dotenv.py new file mode 100644 index 00000000..04855821 --- /dev/null +++ b/tests/test_no_implicit_dotenv.py @@ -0,0 +1,209 @@ +""" +Test to ensure load_dotenv() is never called without an explicit path. + +When load_dotenv() is called without a dotenv_path argument, it uses find_dotenv() +which searches up the directory tree for a .env file. This can cause unexpected +behavior when running the CLI from a subdirectory, as it may find a .env file +in a parent directory (e.g., the python-sdk repo's .env) instead of the intended +project's .env file. + +This test scans all Python files in the SDK to ensure that every call to +load_dotenv() includes an explicit dotenv_path argument. +""" + +import ast +import os +from pathlib import Path +from typing import List, Set, Tuple + +# Directories to scan for implicit load_dotenv calls +SCAN_DIRECTORIES = [ + "eval_protocol", +] + +# Directories to exclude from scanning (relative to repo root) +EXCLUDE_DIRECTORIES: Set[str] = { + ".venv", + ".git", + "__pycache__", + ".pytest_cache", + ".mypy_cache", + "node_modules", + "build", + "dist", + ".eggs", + "*.egg-info", +} + + +def find_implicit_load_dotenv_calls(file_path: Path) -> List[Tuple[int, str]]: + """ + Parse a Python file and find any load_dotenv() calls without explicit dotenv_path. + + Returns a list of (line_number, code_snippet) tuples for violations. + """ + violations = [] + + try: + with open(file_path, "r", encoding="utf-8") as f: + source = f.read() + except (IOError, UnicodeDecodeError): + return violations + + try: + tree = ast.parse(source, filename=str(file_path)) + except SyntaxError: + return violations + + for node in ast.walk(tree): + if isinstance(node, ast.Call): + # Check if this is a call to load_dotenv + func_name = None + if isinstance(node.func, ast.Name): + func_name = node.func.id + elif isinstance(node.func, ast.Attribute): + func_name = node.func.attr + + if func_name == "load_dotenv": + # Check if dotenv_path is provided as a positional or keyword argument + has_explicit_path = False + + # Check positional arguments (dotenv_path is the first positional arg) + if node.args: + has_explicit_path = True + + # Check keyword arguments + for keyword in node.keywords: + if keyword.arg == "dotenv_path": + has_explicit_path = True + break + + if not has_explicit_path: + # Get the source line for context + try: + lines = source.splitlines() + line = lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + except (IndexError, AttributeError): + line = "" + + violations.append((node.lineno, line)) + + return violations + + +def _should_exclude_dir(dir_name: str) -> bool: + """Check if a directory should be excluded from scanning.""" + return dir_name in EXCLUDE_DIRECTORIES or dir_name.startswith(".") + + +def _scan_directory(directory: Path, repo_root: Path) -> List[Tuple[Path, int, str]]: + """Scan a directory for implicit load_dotenv calls.""" + all_violations: List[Tuple[Path, int, str]] = [] + + for root, dirs, files in os.walk(directory): + # Filter out excluded directories in-place to prevent os.walk from descending into them + dirs[:] = [d for d in dirs if not _should_exclude_dir(d)] + + for filename in files: + if not filename.endswith(".py"): + continue + + file_path = Path(root) / filename + violations = find_implicit_load_dotenv_calls(file_path) + + for line_no, code in violations: + all_violations.append((file_path, line_no, code)) + + return all_violations + + +def test_no_implicit_load_dotenv_calls(): + """ + Ensure no load_dotenv() calls exist without an explicit dotenv_path argument. + + This prevents the CLI from accidentally loading .env files from parent directories + when running from a subdirectory. + """ + repo_root = Path(__file__).parent.parent + + all_violations: List[Tuple[Path, int, str]] = [] + + for scan_dir in SCAN_DIRECTORIES: + directory = repo_root / scan_dir + if directory.exists(): + violations = _scan_directory(directory, repo_root) + all_violations.extend(violations) + + if all_violations: + error_msg = [ + "Found load_dotenv() calls without explicit dotenv_path argument.", + "This can cause the CLI to load .env files from parent directories unexpectedly.", + "", + "Violations:", + ] + for file_path, line_no, code in all_violations: + try: + rel_path = file_path.relative_to(repo_root) + except ValueError: + rel_path = file_path + error_msg.append(f" {rel_path}:{line_no}: {code}") + + error_msg.extend( + [ + "", + "Fix by providing an explicit path:", + " load_dotenv(dotenv_path=Path('.') / '.env', override=True)", + "", + ] + ) + + assert False, "\n".join(error_msg) + + +def test_load_dotenv_ast_detection(): + """Test that our AST detection correctly identifies implicit vs explicit calls.""" + import tempfile + + # Test case: implicit call (should be detected) + implicit_code = """ +from dotenv import load_dotenv +load_dotenv() +load_dotenv(override=True) +load_dotenv(verbose=True, override=True) +""" + + # Test case: explicit call (should NOT be detected) + explicit_code = """ +from dotenv import load_dotenv +load_dotenv(dotenv_path='.env') +load_dotenv('.env') +load_dotenv(Path('.') / '.env') +load_dotenv(dotenv_path=Path('.') / '.env', override=True) +load_dotenv(env_file_path) # positional arg counts as explicit +""" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(implicit_code) + implicit_file = Path(f.name) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(explicit_code) + explicit_file = Path(f.name) + + try: + implicit_violations = find_implicit_load_dotenv_calls(implicit_file) + explicit_violations = find_implicit_load_dotenv_calls(explicit_file) + + # Should find 3 violations in implicit code + assert len(implicit_violations) == 3, ( + f"Expected 3 implicit violations, got {len(implicit_violations)}: {implicit_violations}" + ) + + # Should find 0 violations in explicit code + assert len(explicit_violations) == 0, ( + f"Expected 0 explicit violations, got {len(explicit_violations)}: {explicit_violations}" + ) + + finally: + implicit_file.unlink() + explicit_file.unlink() diff --git a/tests/test_upload_entrypoint.py b/tests/test_upload_entrypoint.py index 2ae23024..076a6f79 100644 --- a/tests/test_upload_entrypoint.py +++ b/tests/test_upload_entrypoint.py @@ -28,8 +28,8 @@ def test_llm_judge(row=None): def fake_create_evaluation(**kwargs): captured.update(kwargs) - # Simulate API response - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -40,7 +40,6 @@ def fake_create_evaluation(**kwargs): id=None, display_name=None, description=None, - force=False, yes=True, ) @@ -72,7 +71,8 @@ def test_llm_judge(row=None): def fake_create_evaluation(**kwargs): captured.update(kwargs) - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -83,7 +83,6 @@ def fake_create_evaluation(**kwargs): id=None, display_name=None, description=None, - force=False, yes=True, ) @@ -119,7 +118,8 @@ def test_llm_judge(row=None): def fake_create_evaluation(**kwargs): captured.update(kwargs) - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -130,7 +130,6 @@ def fake_create_evaluation(**kwargs): id=None, display_name=None, description=None, - force=False, yes=True, ) @@ -163,8 +162,8 @@ def test_llm_judge(row=None): monkeypatch.setenv("FIREWORKS_API_BASE", "https://dev.api.fireworks.ai") def fake_create_evaluation(**kwargs): - # Simulate creation result with evaluator name - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate creation result with evaluator name - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -174,7 +173,6 @@ def fake_create_evaluation(**kwargs): id="quickstart-test-llm-judge", display_name=None, description=None, - force=True, yes=True, ) @@ -204,7 +202,8 @@ def test_llm_judge(row=None): monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") def fake_create_evaluation(**kwargs): - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -214,7 +213,6 @@ def fake_create_evaluation(**kwargs): id="quickstart-test-llm-judge", display_name=None, description=None, - force=False, yes=True, ) diff --git a/uv.lock b/uv.lock index c175b81f..188413f6 100644 --- a/uv.lock +++ b/uv.lock @@ -1312,7 +1312,7 @@ requires-dist = [ { name = "dspy", marker = "extra == 'dspy'", specifier = ">=3.0.0" }, { name = "e2b", marker = "extra == 'dev'" }, { name = "fastapi", specifier = ">=0.116.1" }, - { name = "fireworks-ai", specifier = "==1.0.0a20" }, + { name = "fireworks-ai", specifier = "==1.0.0a22" }, { name = "google-auth", marker = "extra == 'bigquery'", specifier = ">=2.0.0" }, { name = "google-cloud-bigquery", marker = "extra == 'bigquery'", specifier = ">=3.0.0" }, { name = "gymnasium", marker = "extra == 'dev'", specifier = ">=1.2.0" }, @@ -1582,7 +1582,7 @@ wheels = [ [[package]] name = "fireworks-ai" -version = "1.0.0a20" +version = "1.0.0a22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -1594,9 +1594,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1d/c6/cdc6c152876ee1253491e6f72c65c2cdaf7b22b320be0cec7ac5778d3b1c/fireworks_ai-1.0.0a20.tar.gz", hash = "sha256:c84f702445679ea768461dba8fb027175b82255021832a89f9ece65821a2ab25", size = 564097, upload-time = "2025-12-23T19:21:17.891Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ef/16/073cf6855d18e43c14972d4b8f8fe59a43e41b581d430a7fad1dae3b8ddf/fireworks_ai-1.0.0a22.tar.gz", hash = "sha256:ab6fc7ad2beb8d69454b8c8c34ccd5d97ffa8cefa308a5cac7e568676e4b1188", size = 572510, upload-time = "2026-01-13T23:52:12.538Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/a4/e2bc9c4af291786bc7fe364ae63503ba2c8161c2e71223d570a77f0a1415/fireworks_ai-1.0.0a20-py3-none-any.whl", hash = "sha256:b5e199978f71b564b2e19cf55a71c1ac20906d9a7b4ae75135fdccb245227722", size = 304153, upload-time = "2025-12-23T19:21:15.943Z" }, + { url = "https://files.pythonhosted.org/packages/26/ef/a932f1fc357b7847258c212d53c074df3956ffbcbf74b2d5c3fdf14fd805/fireworks_ai-1.0.0a22-py3-none-any.whl", hash = "sha256:4ee18a0cb454585baab4803d82ec647d70fd8078a737a7ca4be7a686bc468ce3", size = 316745, upload-time = "2026-01-13T23:52:11.268Z" }, ] [[package]]