From 5b301b71f09aad4d2103ce1a7a06de22d9198ac0 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Wed, 27 May 2026 23:06:03 +0530 Subject: [PATCH 01/16] feat(assessment): Implement L1 pipeline with topic relevance and duplicate detection - Added L1 pipeline orchestrator to run topic relevance and duplicate detection filters in series. - Introduced duplicate detection logic to filter out vague submissions and check for duplicates against a corpus. - Created topic relevance filter to assess the relevance of submissions based on user-defined criteria. - Integrated L1 results handling in the assessment service, allowing for conditional processing based on L1 outcomes. - Updated export functionality to include L1 results in the output, enhancing the assessment reporting capabilities. - Added Celery task for executing the L1 pipeline and managing assessment run statuses. --- .../064_add_l1_columns_to_assessment_run.py | 61 ++++ backend/app/api/routes/assessment/runs.py | 3 + backend/app/celery/tasks/job_execution.py | 24 ++ backend/app/core/config.py | 7 + backend/app/crud/assessment/__init__.py | 2 + backend/app/crud/assessment/batch.py | 19 +- backend/app/crud/assessment/core.py | 50 ++- backend/app/crud/assessment/cron.py | 2 +- backend/app/models/assessment.py | 42 ++- .../app/services/assessment/l1/__init__.py | 3 + .../assessment/l1/duplicate_detection.py | 214 +++++++++++++ .../app/services/assessment/l1/pipeline.py | 225 +++++++++++++ .../services/assessment/l1/topic_relevance.py | 93 ++++++ backend/app/services/assessment/service.py | 98 +++--- backend/app/services/assessment/tasks.py | 196 ++++++++++++ .../app/services/assessment/utils/export.py | 299 +++++++++++++----- 16 files changed, 1182 insertions(+), 156 deletions(-) create mode 100644 backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py create mode 100644 backend/app/services/assessment/l1/__init__.py create mode 100644 
backend/app/services/assessment/l1/duplicate_detection.py create mode 100644 backend/app/services/assessment/l1/pipeline.py create mode 100644 backend/app/services/assessment/l1/topic_relevance.py create mode 100644 backend/app/services/assessment/tasks.py diff --git a/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py b/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py new file mode 100644 index 000000000..bce33e6cd --- /dev/null +++ b/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py @@ -0,0 +1,61 @@ +"""Add L1 pipeline columns to assessment_run + +Revision ID: 064 +Revises: 063 +Create Date: 2026-05-27 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +revision = "064" +down_revision = "063" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "assessment_run", + sa.Column( + "l1_object_store_url", + sa.String(), + nullable=True, + comment="S3 URL of stored L1 filter results JSON", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "l1_total_rows", + sa.Integer(), + nullable=True, + comment="Total rows fed into L1 pipeline", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "l1_total_passed", + sa.Integer(), + nullable=True, + comment="Rows that passed topic relevance and went to L2", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "l1_total_rejected", + sa.Integer(), + nullable=True, + comment="Rows rejected by topic relevance, stopped at L1", + ), + ) + + +def downgrade() -> None: + op.drop_column("assessment_run", "l1_total_rejected") + op.drop_column("assessment_run", "l1_total_passed") + op.drop_column("assessment_run", "l1_total_rows") + op.drop_column("assessment_run", "l1_object_store_url") diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 18a9be60e..18398eeb0 100644 --- a/backend/app/api/routes/assessment/runs.py +++ 
b/backend/app/api/routes/assessment/runs.py @@ -65,6 +65,9 @@ def _build_run_public( total_items=run.total_items, error_message=run.error_message, input=run.input, + l1_total_rows=run.l1_total_rows, + l1_total_passed=run.l1_total_passed, + l1_total_rejected=run.l1_total_rejected, inserted_at=run.inserted_at, updated_at=run.updated_at, ) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..ec7ad1bd0 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -232,6 +232,30 @@ def run_tts_batch_submission( ) +@celery_app.task( + bind=True, queue="low_priority", priority=1, soft_time_limit=1800, time_limit=2100 +) +def run_assessment_run( + self, + run_id: int, + organization_id: int, + project_id: int, + trace_id: str, + **kwargs, +): + from app.services.assessment.tasks import execute_assessment_run + + _set_trace(trace_id) + return _run_with_otel_parent( + self, + lambda: execute_assessment_run( + run_id=run_id, + organization_id=organization_id, + project_id=project_id, + ), + ) + + @celery_app.task(bind=True, queue="low_priority", priority=1) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_result_processing") def run_tts_result_processing( diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 720846eb9..60504147b 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -171,6 +171,13 @@ def AWS_S3_BUCKET(self) -> str: DOC_TRANSFORMATION_PENDING_THRESHOLD_MINUTES: int = 30 PENDING_JOB_QUERY_TIMEOUT_MS: int = 1000 + # Assessment + ASSESSMENT_L1_GEMINI_MODEL: str = "gemini-3.1-flash-lite" + ASSESSMENT_L1_CONCURRENT_WORKERS: int = 8 + ASSESSMENT_L1_DUPLICATE_STORE_NAME: str = ( + "fileSearchStores/inquilabcorpus-782mxjcwisaz" + ) + @computed_field # type: ignore[prop-decorator] @property def COMPUTED_CELERY_WORKER_CONCURRENCY(self) -> int: diff --git a/backend/app/crud/assessment/__init__.py 
b/backend/app/crud/assessment/__init__.py index cd71bff91..8e5c9984d 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -13,6 +13,7 @@ list_assessment_runs, list_assessments, recompute_assessment_status, + update_assessment_run_l1_stats, update_assessment_run_status, ) from app.crud.assessment.dataset import ( @@ -42,5 +43,6 @@ "list_assessment_datasets", "list_assessments", "recompute_assessment_status", + "update_assessment_run_l1_stats", "update_assessment_run_status", ] diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index b45603853..e5e52daec 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -161,6 +161,7 @@ def build_openai_jsonl( attachments: list[AssessmentAttachment], prompt_template: str | None, openai_params: dict, + row_indices: list[int] | None = None, ) -> list[dict[str, Any]]: """Build OpenAI batch JSONL data from dataset rows. @@ -174,7 +175,8 @@ def build_openai_jsonl( """ jsonl_data = [] - for idx, row in enumerate(rows): + for i, row in enumerate(rows): + idx = row_indices[i] if row_indices is not None else i # Build input array input_parts: list[dict[str, Any]] = [] @@ -219,6 +221,7 @@ def build_google_jsonl( attachments: list[AssessmentAttachment], prompt_template: str | None, google_params: dict, + row_indices: list[int] | None = None, ) -> list[dict[str, Any]]: """Build Google (Gemini) batch JSONL data from dataset rows. 
@@ -230,7 +233,8 @@ def build_google_jsonl( """ jsonl_data = [] - for idx, row in enumerate(rows): + for i, row in enumerate(rows): + idx = row_indices[i] if row_indices is not None else i parts: list[dict[str, Any]] = [] # Text prompt @@ -349,6 +353,8 @@ def submit_assessment_batch( assessment_input: dict[str, Any], organization_id: int, project_id: int, + preloaded_rows: list[dict[str, str]] | None = None, + row_indices: list[int] | None = None, ) -> BatchJob: """Build JSONL and submit a batch for one assessment run. @@ -371,8 +377,11 @@ def submit_assessment_batch( output_schema = assessment_input.get("output_schema") attachments = [AssessmentAttachment(**a) for a in attachments_raw] - # Load dataset rows - rows = _load_dataset_rows(session, dataset) + # Use preloaded rows (post-L1 filtered) if provided, else load from dataset. + if preloaded_rows is not None: + rows = preloaded_rows + else: + rows = _load_dataset_rows(session, dataset) if not rows: raise ValueError(f"Dataset {dataset.id} has no rows") @@ -412,6 +421,7 @@ def submit_assessment_batch( attachments=attachments, prompt_template=prompt_template, openai_params=mapped_params, + row_indices=row_indices, ) # Get OpenAI client and submit @@ -452,6 +462,7 @@ def submit_assessment_batch( attachments=attachments, prompt_template=prompt_template, google_params=mapped_params, + row_indices=row_indices, ) # Get Gemini client and submit diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index c91626660..6e2c8b2f7 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -223,16 +223,53 @@ def update_assessment_run_status( return run +def update_assessment_run_l1_stats( + session: Session, + run: AssessmentRun, + l1_object_store_url: str | None = None, + l1_total_rows: int | None = None, + l1_total_passed: int | None = None, + l1_total_rejected: int | None = None, +) -> AssessmentRun: + """Persist L1 result stats (rows/passed/rejected + S3 
URL) on a run.""" + run.updated_at = now() + + if l1_object_store_url is not None: + run.l1_object_store_url = l1_object_store_url + if l1_total_rows is not None: + run.l1_total_rows = l1_total_rows + if l1_total_passed is not None: + run.l1_total_passed = l1_total_passed + if l1_total_rejected is not None: + run.l1_total_rejected = l1_total_rejected + + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error(f"[update_assessment_run_l1_stats] Failed: {e}", exc_info=True) + raise + + return run + + +_ACTIVE_RUN_STATUSES = frozenset( + {"l1_processing", "l2_processing", "processing", "in_progress"} +) +_FAILED_RUN_STATUSES = frozenset({"failed", "l1_failed"}) +_COMPLETED_RUN_STATUSES = frozenset({"completed", "completed_with_errors"}) + + def compute_run_counts(runs: list[AssessmentRun]) -> AssessmentRunCounts: """Aggregate child run statuses into counters.""" return AssessmentRunCounts( total=len(runs), pending=sum(1 for run in runs if run.status == "pending"), - processing=sum( - 1 for run in runs if run.status in {"processing", "in_progress"} - ), - completed=sum(1 for run in runs if run.status == "completed"), - failed=sum(1 for run in runs if run.status == "failed"), + processing=sum(1 for run in runs if run.status in _ACTIVE_RUN_STATUSES), + completed=sum(1 for run in runs if run.status in _COMPLETED_RUN_STATUSES), + failed=sum(1 for run in runs if run.status in _FAILED_RUN_STATUSES), ) @@ -267,6 +304,9 @@ def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: total_items=run.total_items, error_message=run.error_message, updated_at=run.updated_at, + l1_total_rows=run.l1_total_rows, + l1_total_passed=run.l1_total_passed, + l1_total_rejected=run.l1_total_rejected, ) for run in runs ] diff --git a/backend/app/crud/assessment/cron.py b/backend/app/crud/assessment/cron.py index c69b3157e..6cb76b1f5 100644 --- a/backend/app/crud/assessment/cron.py +++ 
b/backend/app/crud/assessment/cron.py @@ -78,7 +78,7 @@ async def poll_all_pending_assessment_evaluations( runs = get_assessment_runs_for_assessment( session=session, assessment_id=assessment.id ) - active_runs = [run for run in runs if run.status == "processing"] + active_runs = [run for run in runs if run.status == "l2_processing"] if not active_runs: refreshed = recompute_assessment_status( diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 25ac0f00e..0dd0a96d1 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -108,7 +108,10 @@ class AssessmentRun(SQLModel, table=True): status: str = SQLField( default="pending", sa_column_kwargs={ - "comment": "Run status: pending, processing, completed, failed" + "comment": ( + "Unified pipeline status: pending, l1_processing, l1_failed, " + "l2_processing, completed, completed_with_errors, failed" + ) }, ) batch_job_id: int | None = SQLField( @@ -136,7 +139,27 @@ class AssessmentRun(SQLModel, table=True): object_store_url: str | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "S3 URL of processed batch results"}, + sa_column_kwargs={"comment": "S3 URL of processed L2 batch results"}, + ) + l1_object_store_url: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "S3 URL of stored L1 filter results JSON"}, + ) + l1_total_rows: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Total rows fed into L1 pipeline"}, + ) + l1_total_passed: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Rows that passed topic relevance and went to L2"}, + ) + l1_total_rejected: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Rows rejected by topic relevance, stopped at L1"}, ) error_message: str | None = SQLField( default=None, @@ -185,6 +208,9 @@ class AssessmentRunStat(BaseModel): total_items: int 
error_message: str | None = None updated_at: datetime | None = None + l1_total_rows: int | None = None + l1_total_passed: int | None = None + l1_total_rejected: int | None = None class AssessmentPublic(BaseModel): @@ -224,6 +250,9 @@ class AssessmentRunPublic(BaseModel): "text_columns, attachments, output_schema" ), ) + l1_total_rows: int | None = None + l1_total_passed: int | None = None + l1_total_rejected: int | None = None inserted_at: datetime updated_at: datetime @@ -286,6 +315,13 @@ class AssessmentCreate(BaseModel): configs: list[AssessmentConfigRef] = Field( ..., min_length=1, max_length=4, description="Config versions to run" ) + l1_config: dict[str, Any] | None = Field( + None, + description=( + "L1 pipeline config. Keys: topic_relevance (columns, prompt), " + "duplicate_detection (columns). Omit to skip L1." + ), + ) class AssessmentRunSummary(BaseModel): @@ -324,6 +360,8 @@ class AssessmentExportRow(BaseModel): row_id: str result_status: str input_data: dict[str, str] | None = None + topic_relevance: str | None = None + duplicate_detection: str | None = None output: str | None = None error: str | None = None response_id: str | None = None diff --git a/backend/app/services/assessment/l1/__init__.py b/backend/app/services/assessment/l1/__init__.py new file mode 100644 index 000000000..66e3a0374 --- /dev/null +++ b/backend/app/services/assessment/l1/__init__.py @@ -0,0 +1,3 @@ +from app.services.assessment.l1.pipeline import run_l1_pipeline + +__all__ = ["run_l1_pipeline"] diff --git a/backend/app/services/assessment/l1/duplicate_detection.py b/backend/app/services/assessment/l1/duplicate_detection.py new file mode 100644 index 000000000..608389c1d --- /dev/null +++ b/backend/app/services/assessment/l1/duplicate_detection.py @@ -0,0 +1,214 @@ +"""Duplicate detection filter for L1 pipeline.""" + +import json +import logging +import re +from typing import Any + +from google import genai +from google.genai import types + +logger = logging.getLogger(__name__) 
+ +_VAGUE_SYS = """ +You are a strict VAGUENESS gate for the School Innovation Marathon (SIM) +duplicate-detection pipeline. Submissions come from Indian school students grades 6-12. +You run BEFORE corpus duplicate detection. Decide only if the submission has enough +surface area for corpus matching. NOT a quality gate. + +NOT VAGUE (let through to corpus check): +- Widely-known/textbook ideas (rainwater harvesting, anti-theft alarm) +- Weak novelty / unclear feasibility +- Hindi/Telugu/mixed Indian-language text +- Bad grammar or rambling if content present +- Long essays naming domain + audience + any mechanism + +VAGUE only when ALL: problem names no issue/target/domain, solution names no mechanism, +text is empty / aspirational ("make society better") / gibberish. + +DECISION: 0-1 clear dimensions present -> vague=true. 2+ -> vague=false. Borderline -> false. + +Output ONLY JSON: {"vague": true|false, "reason": "max 15 words"} +""" + +_DUP_SYS = """ +You are a strict duplicate-detection judge for an innovation competition corpus. + +Given a submitted idea, search the corpus and compare precisely. +Focus on MECHANISM of the solution, not category or theme. + +Verdict (exactly one): DUPLICATE / OVERLAP / PARTIAL_MATCH / UNIQUE + + DUPLICATE: Both problem AND solution mechanism substantially match a corpus entry. + OVERLAP: Either problem OR solution mechanism matches, other side clearly different. + PARTIAL_MATCH: Thematic/conceptual similarity only — same domain, different mechanism. + UNIQUE: Neither problem nor solution substantially matches anything in corpus. + +Response format (follow exactly): +Verdict: +Title: +Source: +URL: +Matching sentence: +Reason: + +RULES: +- UNIQUE -> output ONLY Verdict + Reason. +- NOT UNIQUE -> Title, Source, URL, Matching sentence ALL required. +- Source/URL MUST be VERBATIM from "SOURCE_URL:" line in retrieved chunk. +- NEVER write filenames, page numbers, or constructed URLs. 
+""" + + +def _build_combined(content_parts: dict[str, str]) -> str: + parts = [f"{col}:\n{val}" for col, val in content_parts.items() if val.strip()] + return "\n\n".join(parts) + + +def _check_vague( + text: str, + gemini_client: genai.Client, + model: str, +) -> tuple[bool, str]: + try: + response = gemini_client.models.generate_content( + model=model, + contents=f"Submission:\n\n{text}", + config=types.GenerateContentConfig( + system_instruction=_VAGUE_SYS, + response_mime_type="application/json", + temperature=0.0, + ), + ) + parsed = json.loads((response.text or "").strip()) + return bool(parsed.get("vague", False)), str(parsed.get("reason", "")) + except Exception as exc: + logger.warning("[_check_vague] Parse error — defaulting not vague | %s", exc) + return False, "(vague check error — defaulting to not vague)" + + +def _call_file_search( + text: str, + gemini_client: genai.Client, + model: str, + store_name: str, +) -> str: + response = gemini_client.models.generate_content( + model=model, + contents=f"Submitted idea to check for duplicates:\n\n{text}", + config=types.GenerateContentConfig( + system_instruction=_DUP_SYS, + tools=[ + types.Tool( + file_search=types.FileSearch(file_search_store_names=[store_name]) + ) + ], + temperature=0.0, + ), + ) + return response.text or "" + + +_VERDICT_VALUES = {"DUPLICATE", "OVERLAP", "PARTIAL_MATCH", "UNIQUE"} + + +def _parse_verdict(raw: str) -> dict[str, str | None]: + fields: dict[str, str | None] = { + "verdict": "", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": None, + } + keymap = { + "verdict": "verdict", + "title": "match_title", + "source": "source_url", + "url": "source_url", + "matching sentence": "matching_sentence", + "reason": "reason", + } + for line in (raw or "").splitlines(): + if ":" not in line: + continue + k, _, v = line.partition(":") + norm = re.sub(r"[^a-z\s]", "", k.strip().lower()).strip() + if norm in keymap: + fields[keymap[norm]] = v.strip() or 
None + + # Fallback: scan entire response for a known verdict token + if not fields["verdict"] or fields["verdict"] not in _VERDICT_VALUES: + m = re.search(r"\b(DUPLICATE|OVERLAP|PARTIAL_MATCH|UNIQUE)\b", raw or "") + if m: + fields["verdict"] = m.group(1) + logger.warning( + "[_parse_verdict] key-based parse missed verdict; regex fallback found: %s", + fields["verdict"], + ) + else: + logger.warning( + "[_parse_verdict] verdict not found in response. raw=%r", + (raw or "")[:500], + ) + + return fields + + +def run_duplicate_detection( + row_idx: int, + row: dict[str, str], + columns: list[str], + gemini_client: genai.Client, + model: str, + store_name: str, +) -> dict[str, Any]: + """Run duplicate detection on a single row. + + Returns a dict with: row_id, verdict, match_title, source_url, + matching_sentence, reason. + Always passthrough — never gates L2. + """ + content_parts = {col: row.get(col, "") for col in columns} + combined = _build_combined(content_parts) or "(empty submission)" + + try: + is_vague, vague_reason = _check_vague(combined, gemini_client, model) + except Exception as exc: + logger.warning( + "[run_duplicate_detection] Vague check failed row_%s | %s", row_idx, exc + ) + is_vague, vague_reason = False, f"(vague check error: {exc})" + + if is_vague: + return { + "row_id": f"row_{row_idx}", + "verdict": "VAGUE", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": vague_reason, + } + + try: + raw = _call_file_search(combined, gemini_client, model, store_name) + parsed = _parse_verdict(raw) + return { + "row_id": f"row_{row_idx}", + "verdict": parsed["verdict"] or "UNKNOWN", + "match_title": parsed["match_title"], + "source_url": parsed["source_url"], + "matching_sentence": parsed["matching_sentence"], + "reason": parsed["reason"], + } + except Exception as exc: + logger.warning( + "[run_duplicate_detection] File search failed row_%s | %s", row_idx, exc + ) + return { + "row_id": f"row_{row_idx}", + "verdict": 
"ERROR", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": str(exc)[:200], + } diff --git a/backend/app/services/assessment/l1/pipeline.py b/backend/app/services/assessment/l1/pipeline.py new file mode 100644 index 000000000..2a002e5e5 --- /dev/null +++ b/backend/app/services/assessment/l1/pipeline.py @@ -0,0 +1,225 @@ +"""L1 pipeline orchestrator. + +Runs two filters in series for each row: +1. Topic Relevance (go/no-go) — REJECT stops the row. +2. Duplicate Detection (passthrough) — only on ACCEPTED rows. + +""" + +import json +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + +from sqlmodel import Session + +from app.core.batch.client import GeminiClient +from app.core.config import settings +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import upload_jsonl_to_object_store +from app.models.assessment import AssessmentRun +from app.services.assessment.l1.duplicate_detection import run_duplicate_detection +from app.services.assessment.l1.topic_relevance import run_topic_relevance + +logger = logging.getLogger(__name__) + + +def _build_l1_result( + row_idx: int, + tr_result: dict[str, Any] | None, + dup_result: dict[str, Any] | None, +) -> dict[str, Any]: + return { + "row_id": f"row_{row_idx}", + "l1_passed": tr_result["verdict"] if tr_result else True, + "topic_relevance": { + "decision": tr_result["decision"], + "column_relevance": tr_result.get("column_relevance") or {}, + "reasoning": tr_result["reasoning"], + } + if tr_result + else None, + "duplicate_detection": dup_result, + } + + +def run_l1_pipeline( + run: AssessmentRun, + rows: list[dict[str, str]], + l1_config: dict[str, Any], + session: Session, + organization_id: int, + project_id: int, +) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: + """Run L1 filters on all rows. + + Args: + run: The AssessmentRun record (used for S3 path and DB update). 
+ rows: Full dataset rows loaded from object store. + l1_config: User-supplied config with topic_relevance and duplicate_detection keys. + session: DB session. + organization_id: For Gemini credential lookup. + project_id: For Gemini credential lookup and S3 storage. + + Returns: + (passed_rows, passed_indices, all_l1_results) + passed_rows: subset of rows where topic_relevance verdict=true. + passed_indices: original dataset indices of passed_rows (used to preserve row IDs in L2). + all_l1_results: one entry per input row (len == len(rows)). + """ + model = settings.ASSESSMENT_L1_GEMINI_MODEL + workers = settings.ASSESSMENT_L1_CONCURRENT_WORKERS + store_name = settings.ASSESSMENT_L1_DUPLICATE_STORE_NAME + + tr_config = l1_config.get("topic_relevance") or {} + dup_config = l1_config.get("duplicate_detection") or {} + + tr_columns: list[str] = tr_config.get("columns") or [] + tr_prompt: str = tr_config.get("prompt") or "" + dup_columns: list[str] = dup_config.get("columns") or [] + + tr_enabled = bool(tr_columns and tr_prompt) + dup_enabled = bool(dup_columns) + + if not tr_enabled and not dup_enabled: + logger.warning( + "[run_l1_pipeline] run_id=%s — no L1 filters configured, skipping L1", + run.id, + ) + return rows, list(range(len(rows))), [] + + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=organization_id, + project_id=project_id, + ).client + + logger.info( + "[run_l1_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", + run.id, + len(rows), + model, + workers, + tr_enabled, + dup_enabled, + ) + + # tr_results[idx] = None when TR disabled → no topic_relevance columns in export + tr_results: dict[int, dict[str, Any] | None] = {} + if tr_enabled: + with ThreadPoolExecutor(max_workers=workers) as executor: + futs = { + executor.submit( + run_topic_relevance, + idx, + row, + tr_columns, + tr_prompt, + gemini_client, + model, + ): idx + for idx, row in enumerate(rows) + } + for fut in as_completed(futs): + idx = 
futs[fut] + try: + tr_results[idx] = fut.result() + except Exception as exc: + logger.warning( + "[run_l1_pipeline] TR future error row_%s | %s", idx, exc + ) + tr_results[idx] = { + "row_id": f"row_{idx}", + "verdict": True, + "decision": "ACCEPT", + "column_relevance": {}, + "reasoning": f"(future error — defaulting to pass) {exc}", + } + passed_indices = [idx for idx, r in tr_results.items() if r and r["verdict"]] + else: + for idx in range(len(rows)): + tr_results[idx] = None + passed_indices = list(range(len(rows))) + + rejected_count = len(rows) - len(passed_indices) + logger.info( + "[run_l1_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", + run.id, + len(passed_indices), + rejected_count, + ) + + dup_results: dict[int, dict[str, Any]] = {} + if dup_columns and passed_indices: + with ThreadPoolExecutor(max_workers=workers) as executor: + futs = { + executor.submit( + run_duplicate_detection, + idx, + rows[idx], + dup_columns, + gemini_client, + model, + store_name, + ): idx + for idx in passed_indices + } + for fut in as_completed(futs): + idx = futs[fut] + try: + dup_results[idx] = fut.result() + except Exception as exc: + logger.warning( + "[run_l1_pipeline] DUP future error row_%s | %s", idx, exc + ) + dup_results[idx] = { + "row_id": f"row_{idx}", + "verdict": "ERROR", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": str(exc)[:200], + } + + all_l1_results: list[dict[str, Any]] = [ + _build_l1_result(idx, tr_results[idx], dup_results.get(idx)) + for idx in range(len(rows)) + ] + + l1_object_store_url: str | None = None + try: + storage = get_cloud_storage(session=session, project_id=project_id) + l1_object_store_url = upload_jsonl_to_object_store( + storage=storage, + results=all_l1_results, + filename="l1_results.json", + subdirectory=f"assessment/run-{run.id}/l1", + format="json", + ) + logger.info( + "[run_l1_pipeline] run_id=%s | L1 results uploaded to %s", + run.id, + l1_object_store_url, + ) + except 
Exception as exc: + logger.error( + "[run_l1_pipeline] run_id=%s | S3 upload failed | %s", + run.id, + exc, + exc_info=True, + ) + + from app.crud.assessment.core import update_assessment_run_l1_stats + + update_assessment_run_l1_stats( + session=session, + run=run, + l1_object_store_url=l1_object_store_url, + l1_total_rows=len(rows), + l1_total_passed=len(passed_indices), + l1_total_rejected=rejected_count, + ) + + sorted_passed_indices = sorted(passed_indices) + passed_rows = [rows[idx] for idx in sorted_passed_indices] + return passed_rows, sorted_passed_indices, all_l1_results diff --git a/backend/app/services/assessment/l1/topic_relevance.py b/backend/app/services/assessment/l1/topic_relevance.py new file mode 100644 index 000000000..42516d27b --- /dev/null +++ b/backend/app/services/assessment/l1/topic_relevance.py @@ -0,0 +1,93 @@ +"""Topic relevance filter for L1 pipeline. +""" + +import json +import logging +from typing import Any + +from google import genai +from google.genai import types + +logger = logging.getLogger(__name__) + + +def _build_output_schema(columns: list[str]) -> dict[str, Any]: + """Build output schema: locked decision + per-column relevance booleans + reasoning.""" + props: dict[str, Any] = { + "decision": { + "type": "string", + "enum": ["ACCEPT", "REJECT"], + "description": "Final verdict. 
ACCEPT to proceed to full evaluation, REJECT to stop here.", + }, + } + required = ["decision"] + + for col in columns: + props[col] = { + "type": "boolean", + "description": f"Whether the '{col}' column content is relevant to the topic.", + } + required.append(col) + + props["reasoning"] = { + "type": "string", + "description": "Explanation of the verdict and per-column relevance assessment.", + } + required.append("reasoning") + + return {"type": "object", "properties": props, "required": required} + + +def run_topic_relevance( + row_idx: int, + row: dict[str, str], + columns: list[str], + user_prompt: str, + gemini_client: genai.Client, + model: str, +) -> dict[str, Any]: + """Run topic relevance check on a single row. + + System instruction = user_prompt (the evaluation rubric/criteria). + User content = dict of {column_name: value} for the selected columns. + Output schema enforced: decision (ACCEPT/REJECT) + reasoning. + On error defaults to verdict=True (fail-open). + """ + user_content = json.dumps({col: row.get(col, "") or "" for col in columns}) + output_schema = _build_output_schema(columns) + + try: + response = gemini_client.models.generate_content( + model=model, + contents=user_content, + config=types.GenerateContentConfig( + system_instruction=user_prompt.strip(), + response_mime_type="application/json", + response_schema=output_schema, + temperature=0.0, + ), + ) + raw = (response.text or "").strip() + parsed = json.loads(raw) + decision = str(parsed.get("decision", "ACCEPT")).upper() + column_relevance = {col: bool(parsed.get(col, True)) for col in columns} + return { + "row_id": f"row_{row_idx}", + "verdict": decision == "ACCEPT", + "decision": decision, + "column_relevance": column_relevance, + "reasoning": str(parsed.get("reasoning", "")), + } + except Exception as exc: + logger.warning( + "[run_topic_relevance] row_%s error — defaulting verdict=True | %s", + row_idx, + exc, + ) + return { + "row_id": f"row_{row_idx}", + "verdict": True, + 
"decision": "ACCEPT", + "column_relevance": {col: True for col in columns}, + "reasoning": f"(evaluation error — defaulting to pass) {exc}", + } diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index 45a283ea5..d03eb13a7 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -1,9 +1,9 @@ """Assessment run orchestration service.""" import logging -from typing import Any from uuid import UUID +from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session @@ -13,9 +13,7 @@ get_assessment_dataset_by_id, get_assessment_runs_for_assessment, recompute_assessment_status, - update_assessment_run_status, ) -from app.crud.assessment.batch import submit_assessment_batch from app.crud.config import ConfigCrud from app.crud.evaluations.core import resolve_evaluation_config from app.models.assessment import ( @@ -81,6 +79,7 @@ def _build_retry_request( attachments=[AssessmentAttachment.model_validate(item) for item in attachments], output_schema=assessment_input.get("output_schema"), configs=configs, + l1_config=assessment_input.get("l1_config"), ) @@ -90,11 +89,13 @@ def start_assessment( organization_id: int, project_id: int, ) -> AssessmentResponse: - """Start an assessment run request. + """Validate, create Assessment + AssessmentRun records, dispatch Celery tasks. - Validates the dataset, resolves each config, creates one AssessmentRun per config, - and kicks off batch processing for each. + Each run is created with status='pending' and handed off to a Celery worker + that runs L1 filtering then submits the L2 batch. 
""" + from app.celery.tasks.job_execution import run_assessment_run + logger.info( "[start_assessment] Starting | experiment=%s | dataset_id=%s | configs=%s | org_id=%s", request.experiment_name, @@ -110,7 +111,7 @@ def start_assessment( project_id=project_id, ) - assessment_input: dict[str, Any] = { + assessment_input: dict = { "prompt_template": request.prompt_template, "system_instruction": request.system_instruction, "text_columns": request.text_columns, @@ -118,12 +119,13 @@ def start_assessment( } if request.output_schema: assessment_input["output_schema"] = request.output_schema + if request.l1_config: + assessment_input["l1_config"] = request.l1_config config_crud = ConfigCrud(session=session, project_id=project_id) resolved_configs = [] for cfg in request.configs: - # Assessment runs must use configs explicitly tagged for assessment use. parent_config = config_crud.read_one(cfg.config_id) if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: tag_value = ( @@ -165,7 +167,7 @@ def start_assessment( f"Supported providers: {sorted(_SUPPORTED_BATCH_PROVIDERS)}" ), ) - resolved_configs.append((cfg, config_blob)) + resolved_configs.append(cfg) assessment = create_assessment( session=session, @@ -176,54 +178,30 @@ def start_assessment( ) runs: list[AssessmentRun] = [] - try: - for cfg, config_blob in resolved_configs: - run = create_assessment_run( - session=session, - assessment_id=assessment.id, - config_id=cfg.config_id, - config_version=cfg.config_version, - assessment_input=assessment_input, - ) - - try: - batch_job = submit_assessment_batch( - session=session, - run=run, - assessment=assessment, - dataset=dataset, - config_blob=config_blob, - assessment_input=assessment_input, - organization_id=organization_id, - project_id=project_id, - ) + trace_id = correlation_id.get() or "" - run = update_assessment_run_status( - session=session, - run=run, - status="processing", - batch_job_id=batch_job.id, - total_items=batch_job.total_items, - ) 
+ for cfg in resolved_configs: + run = create_assessment_run( + session=session, + assessment_id=assessment.id, + config_id=cfg.config_id, + config_version=cfg.config_version, + assessment_input=assessment_input, + ) + runs.append(run) - except Exception as e: - logger.error( - "[start_assessment] Failed to submit batch for run %s: %s", - run.id, - e, - exc_info=True, - ) - run = update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message="Batch submission failed. Please try again or contact support.", - ) + run_assessment_run.delay( + run_id=run.id, + organization_id=organization_id, + project_id=project_id, + trace_id=trace_id, + ) - runs.append(run) - except Exception: - recompute_assessment_status(session=session, assessment_id=assessment.id) - raise + logger.info( + "[start_assessment] Dispatched Celery task | run_id=%s | config_id=%s", + run.id, + cfg.config_id, + ) recompute_assessment_status(session=session, assessment_id=assessment.id) @@ -242,13 +220,13 @@ def start_assessment( num_configs=len(runs), runs=[ AssessmentRunSummary( - run_id=completed_run.id, - assessment_id=completed_run.assessment_id, - config_id=str(completed_run.config_id), - config_version=completed_run.config_version, - status=completed_run.status, + run_id=run.id, + assessment_id=run.assessment_id, + config_id=str(run.config_id), + config_version=run.config_version, + status=run.status, ) - for completed_run in runs + for run in runs ], ) diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py new file mode 100644 index 000000000..66f644050 --- /dev/null +++ b/backend/app/services/assessment/tasks.py @@ -0,0 +1,196 @@ +"""Celery task logic for running a single assessment run (L1 → L2 batch submit).""" + +import logging + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.assessment import ( + get_assessment_dataset_by_id, + recompute_assessment_status, + update_assessment_run_status, 
+) +from app.crud.assessment.batch import _load_dataset_rows, submit_assessment_batch +from app.crud.config import ConfigCrud +from app.crud.evaluations.core import resolve_evaluation_config +from app.models.assessment import Assessment, AssessmentRun +from app.models.config.config import ConfigTag +from app.services.assessment.l1 import run_l1_pipeline + +logger = logging.getLogger(__name__) + + +def execute_assessment_run( + run_id: int, + organization_id: int, + project_id: int, +) -> None: + """Run L1 filtering then submit L2 batch for one AssessmentRun. + + Status transitions: + pending → l1_processing → l1_failed (stop) + → l2_processing → (cron handles rest) + pending → l2_processing (when no l1_config) + """ + with Session(engine) as session: + run = session.get(AssessmentRun, run_id) + if run is None: + logger.error("[execute_assessment_run] run_id=%s not found", run_id) + return + + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: + logger.error( + "[execute_assessment_run] parent assessment %s not found for run %s", + run.assessment_id, + run_id, + ) + return + + assessment_input = run.input or {} + dataset_id = assessment.dataset_id + + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + config_crud = ConfigCrud(session=session, project_id=project_id) + parent_config = config_crud.read_one(run.config_id) + if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: + logger.error( + "[execute_assessment_run] config %s has wrong tag for run %s", + run.config_id, + run_id, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Config tag is not ASSESSMENT.", + ) + recompute_assessment_status(session=session, assessment_id=assessment.id) + return + + config_blob, error = resolve_evaluation_config( + session=session, + config_id=run.config_id, + 
config_version=run.config_version, + project_id=project_id, + tag=ConfigTag.ASSESSMENT, + ) + if error or config_blob is None: + logger.error( + "[execute_assessment_run] config resolution failed run_id=%s: %s", + run_id, + error, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=f"Config resolution failed: {error}", + ) + recompute_assessment_status(session=session, assessment_id=assessment.id) + return + + all_rows = _load_dataset_rows(session=session, dataset=dataset) + if not all_rows: + logger.error( + "[execute_assessment_run] dataset %s has no rows for run %s", + dataset_id, + run_id, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Dataset has no rows.", + ) + recompute_assessment_status(session=session, assessment_id=assessment.id) + return + + # L1 pipeline + rows_for_l2 = all_rows + row_indices_for_l2: list[int] | None = None + l1_config = assessment_input.get("l1_config") + if l1_config: + update_assessment_run_status( + session=session, run=run, status="l1_processing" + ) + try: + rows_for_l2, row_indices_for_l2, _ = run_l1_pipeline( + run=run, + rows=all_rows, + l1_config=l1_config, + session=session, + organization_id=organization_id, + project_id=project_id, + ) + logger.info( + "[execute_assessment_run] L1 done | run_id=%s | rows_to_l2=%s / %s", + run_id, + len(rows_for_l2), + len(all_rows), + ) + except Exception as l1_exc: + logger.error( + "[execute_assessment_run] L1 failed run_id=%s | %s", + run_id, + l1_exc, + exc_info=True, + ) + update_assessment_run_status( + session=session, + run=run, + status="l1_failed", + error_message=f"L1 pipeline failed: {l1_exc}", + ) + recompute_assessment_status( + session=session, assessment_id=assessment.id + ) + return # L2 does not run when L1 fails + + # L2 batch submit + try: + batch_job = submit_assessment_batch( + session=session, + run=run, + assessment=assessment, + dataset=dataset, + 
config_blob=config_blob, + assessment_input=assessment_input, + organization_id=organization_id, + project_id=project_id, + preloaded_rows=rows_for_l2, + row_indices=row_indices_for_l2, + ) + update_assessment_run_status( + session=session, + run=run, + status="l2_processing", + batch_job_id=batch_job.id, + total_items=batch_job.total_items, + ) + logger.info( + "[execute_assessment_run] L2 batch submitted | run_id=%s | batch_job_id=%s", + run_id, + batch_job.id, + ) + except Exception as e: + logger.error( + "[execute_assessment_run] L2 batch submit failed run_id=%s: %s", + run_id, + e, + exc_info=True, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Batch submission failed. Please try again or contact support.", + ) + + recompute_assessment_status(session=session, assessment_id=assessment.id) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index ca273afc6..d244ecd08 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -22,6 +22,8 @@ from app.services.assessment.utils.parsing import parse_stored_results, usage_totals from app.utils import APIResponse +_L1_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] + logger = logging.getLogger(__name__) @@ -34,6 +36,29 @@ def _load_dataset_rows( return load_dataset_rows(session, dataset) +def _load_l1_results( + session: Session, + run: AssessmentRun, + assessment: Assessment, +) -> dict[str, dict[str, Any]]: + """Load L1 results from object store, keyed by row_id. 
Returns {} if unavailable.""" + if not run.l1_object_store_url: + return {} + try: + storage = get_cloud_storage(session, project_id=assessment.project_id) + body = storage.stream(run.l1_object_store_url) + raw = body.read().decode("utf-8") + results: list[dict[str, Any]] = json.loads(raw) + return {str(item["row_id"]): item for item in results if "row_id" in item} + except Exception as exc: + logger.warning( + "[_load_l1_results] Failed to load L1 results for run id=%s: %s", + run.id, + exc, + ) + return {} + + def _safe_filename_part(value: str) -> str: """Build a filesystem-safe filename component.""" sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") @@ -113,86 +138,99 @@ def _drop_empty_columns( return pruned, non_empty_fields +def _parse_json_col(raw: Any) -> dict[str, Any] | None: + if raw is None: + return None + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else None + except (json.JSONDecodeError, TypeError): + return None + return None + + def _expand_output_columns( row_payload: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[str]]: - """Expand the ``output`` field into separate columns when it contains valid JSON. + """Expand ``output``, ``topic_relevance``, and ``duplicate_detection`` JSON columns + into separate flat columns when they contain valid JSON objects. 
Returns: (expanded_rows, ordered_fieldnames) """ - # First expand input columns row_payload, input_col_names = _expand_input_columns(row_payload) + json_expand_cols = {"output", "input_data"} | set(_L1_JSON_COLUMNS) base_fields = [ field for field in AssessmentExportRow.model_fields.keys() - if field not in ("output", "input_data") + if field not in json_expand_cols ] - parsed_outputs: list[dict[str, Any] | None] = [] - output_keys: list[str] = [] - seen_keys: dict[str, None] = {} # ordered set + # L1 columns are prefixed with their parent name to avoid key collisions + parsed_cols: dict[str, list[dict[str, Any] | None]] = { + col: [] for col in ["output"] + _L1_JSON_COLUMNS + } + col_keys: dict[str, list[str]] = {col: [] for col in ["output"] + _L1_JSON_COLUMNS} + col_seen: dict[str, dict[str, None]] = { + col: {} for col in ["output"] + _L1_JSON_COLUMNS + } has_unparsed_output = False for row in row_payload: - raw = row.get("output") - if raw is None: - parsed_outputs.append(None) - continue - - if isinstance(raw, str): - try: - parsed = json.loads(raw) - except (json.JSONDecodeError, TypeError): - parsed = None - elif isinstance(raw, dict): - parsed = raw - else: - parsed = None - - if not isinstance(parsed, dict): - has_unparsed_output = True - parsed_outputs.append(None) - continue - - parsed_outputs.append(parsed) - for output_key in parsed: - if output_key not in seen_keys: - seen_keys[output_key] = None - output_keys.append(output_key) - - if not output_keys: - # Keep original layout with output as a single column - fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) - fieldnames = [field for field in fieldnames if field != "input_data"] - return row_payload, fieldnames + for col in ["output"] + _L1_JSON_COLUMNS: + parsed = _parse_json_col(row.get(col)) + if parsed is None and col == "output" and row.get(col) is not None: + has_unparsed_output = True + parsed_cols[col].append(parsed) + if parsed: + for k in parsed: + prefixed = 
f"{col}_{k}" if col in _L1_JSON_COLUMNS else k + if prefixed not in col_seen[col]: + col_seen[col][prefixed] = None + col_keys[col].append(prefixed) + + def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: + if not parsed: + return {} + if col in _L1_JSON_COLUMNS: + return {f"{col}_{k}": v for k, v in parsed.items()} + return parsed # Build expanded rows expanded: list[dict[str, Any]] = [] - for row, parsed in zip(row_payload, parsed_outputs, strict=True): - new_row = {col: val for col, val in row.items() if col != "output"} - if parsed: - for output_key in output_keys: - new_row[output_key] = parsed.get(output_key) - else: - for output_key in output_keys: - new_row[output_key] = None - if row.get("output") is not None: - new_row["output_raw"] = row.get("output") + for i, row in enumerate(row_payload): + new_row = {k: v for k, v in row.items() if k not in json_expand_cols} + for col in ["output"] + _L1_JSON_COLUMNS: + parsed = parsed_cols[col][i] + keys = col_keys[col] + prefixed_vals = _get_prefixed(parsed, col) + if prefixed_vals: + for k in keys: + new_row[k] = prefixed_vals.get(k) + else: + for k in keys: + new_row[k] = None + if col == "output" and row.get("output") is not None: + new_row["output_raw"] = row.get("output") expanded.append(new_row) - # Build fieldnames: input columns + base fields + output columns - output_idx = base_fields.index("result_status") + 1 # after result_status - fieldnames = ( - input_col_names - + base_fields[:output_idx] - + output_keys - + base_fields[output_idx:] - ) + l1_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] + output_keys = col_keys["output"] + + all_output_keys = l1_keys + output_keys + if not all_output_keys: + fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) + fieldnames = [f for f in fieldnames if f != "input_data"] + return row_payload, fieldnames + + fieldnames = input_col_names + l1_keys + output_keys + base_fields if has_unparsed_output: 
fieldnames.insert( - len(input_col_names) + output_idx + len(output_keys), "output_raw" + len(input_col_names) + len(l1_keys) + len(output_keys), "output_raw" ) return expanded, fieldnames @@ -212,7 +250,6 @@ def serialize_export_rows( "application/json", ) - # For CSV/XLSX, expand output keys into separate columns expanded, fieldnames = _expand_output_columns(row_payload) if export_format == "csv": @@ -230,7 +267,6 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # XLSX shows input columns + output columns only (no metadata fields). metadata_fields = { field for field in AssessmentExportRow.model_fields.keys() @@ -376,59 +412,154 @@ def _load_dataset_rows_for_run( return [] +def _extract_l1_json_columns( + l1_item: dict[str, Any] | None, +) -> dict[str, Any]: + """Return topic_relevance and duplicate_detection as JSON strings for export expansion.""" + if not l1_item: + return {"topic_relevance": None, "duplicate_detection": None} + + tr = l1_item.get("topic_relevance") + dup = l1_item.get("duplicate_detection") + + tr_flat: dict[str, Any] | None = None + if tr: + tr_flat = {} + for col, val in (tr.get("column_relevance") or {}).items(): + tr_flat[col] = val + tr_flat["decision"] = tr.get("decision") + tr_flat["reasoning"] = tr.get("reasoning") + + dup_flat: dict[str, Any] | None = None + if dup: + dup_flat = {k: v for k, v in dup.items() if k != "row_id"} + + return { + "topic_relevance": json.dumps(tr_flat, ensure_ascii=False) if tr_flat else None, + "duplicate_detection": json.dumps(dup_flat, ensure_ascii=False) + if dup_flat + else None, + } + + def load_export_rows_for_run( session: Session, run: AssessmentRun, assessment: Assessment | None = None, ) -> list[AssessmentExportRow]: - """Load flattened export rows for a single child assessment run.""" - if not run.batch_job_id: + """Load flattened export rows for a single child assessment run. 
+ + When L1 results exist, ALL dataset rows are included in output. + L1-rejected rows have L1 columns filled and L2 columns empty. + L1-passed rows have all columns filled. + Without L1, behaviour is unchanged (only L2 result rows returned). + """ + if assessment is None: + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: logger.warning( - "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id + "[load_export_rows_for_run] Parent assessment missing for run id=%s", + run.id, ) return [] - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if not batch_job: + dataset = session.get(EvaluationDataset, assessment.dataset_id) + dataset_name = dataset.name if dataset else None + dataset_rows = _load_dataset_rows_for_run(session, run, assessment) + + # Load L1 results (empty dict if no L1 was run) + l1_by_row_id = _load_l1_results(session, run, assessment) + + # Load L2 results (may be None if batch not complete) + l2_by_row_id: dict[str, dict[str, Any]] = {} + if run.batch_job_id: + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if batch_job: + parsed_results = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + if parsed_results: + l2_by_row_id = { + str(item["row_id"]): item + for item in parsed_results + if "row_id" in item + } + + has_l1 = bool(l1_by_row_id) + + if has_l1 and dataset_rows: + # All rows in output — build from full dataset + export_rows: list[AssessmentExportRow] = [] + for row_idx, input_data in enumerate(dataset_rows): + row_id_str = f"row_{row_idx}" + l1_item = l1_by_row_id.get(row_id_str) + l1_cols = _extract_l1_json_columns(l1_item) + l2_item = l2_by_row_id.get(row_id_str) + + input_tokens, output_tokens, total_tokens = usage_totals( + l2_item.get("usage") if l2_item else None + ) + l1_passed = (l1_item or {}).get("l1_passed", True) + result_status = ( + "l1_rejected" + if not l1_passed + else ("failed" if l2_item and 
l2_item.get("error") else "passed") + ) + + export_rows.append( + AssessmentExportRow( + assessment_id=run.assessment_id, + experiment_name=assessment.experiment_name, + dataset_id=assessment.dataset_id, + dataset_name=dataset_name, + run_id=run.id, + run_name=assessment.experiment_name, + run_status=run.status, + config_id=run.config_id, + config_version=run.config_version, + row_id=row_id_str, + result_status=result_status, + input_data=input_data, + topic_relevance=l1_cols.get("topic_relevance"), + duplicate_detection=l1_cols.get("duplicate_detection"), + output=l2_item.get("output") if l2_item else None, + error=l2_item.get("error") if l2_item else None, + response_id=l2_item.get("response_id") if l2_item else None, + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + updated_at=run.updated_at, + ) + ) + return export_rows + + # No L1 — original behaviour: only L2 result rows + if not run.batch_job_id: logger.warning( - "[load_export_rows_for_run] Missing batch job for run id=%s", - run.id, + "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id ) return [] - if assessment is None: - assessment = session.get(Assessment, run.assessment_id) - if assessment is None: + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if not batch_job: logger.warning( - "[load_export_rows_for_run] Parent assessment missing for run id=%s", - run.id, + "[load_export_rows_for_run] Missing batch job for run id=%s", run.id ) return [] parsed_results = _load_parsed_results_for_run( - session=session, - run=run, - batch_job=batch_job, + session=session, run=run, batch_job=batch_job ) - if parsed_results is None: - return [] - if not parsed_results: logger.warning( "[load_export_rows_for_run] Parsed results empty for run id=%s", run.id ) return [] - dataset_rows = _load_dataset_rows_for_run(session, run, assessment) - dataset = session.get(EvaluationDataset, assessment.dataset_id) - dataset_name = dataset.name if 
dataset else None - - export_rows: list[AssessmentExportRow] = [] + export_rows = [] for item in parsed_results: input_tokens, output_tokens, total_tokens = usage_totals(item.get("usage")) - - # Correlate with original input row via row_id (format: "row_{idx}") - input_data: dict[str, str] | None = None + input_data = None row_id_str = str(item.get("row_id", "")) if dataset_rows and row_id_str.startswith("row_"): try: From b4128290c424729997a409ea3977d3c28c71a66d Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 28 May 2026 21:33:42 +0530 Subject: [PATCH 02/16] feat(export): Expand output columns to include topic relevance and duplicate detection --- .../app/services/assessment/utils/export.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index d244ecd08..4b299151f 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -154,7 +154,7 @@ def _parse_json_col(raw: Any) -> dict[str, Any] | None: def _expand_output_columns( row_payload: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], list[str]]: +) -> tuple[list[dict[str, Any]], list[str], list[str], list[str], list[str]]: """Expand ``output``, ``topic_relevance``, and ``duplicate_detection`` JSON columns into separate flat columns when they contain valid JSON objects. 
@@ -225,7 +225,7 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: if not all_output_keys: fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) fieldnames = [f for f in fieldnames if f != "input_data"] - return row_payload, fieldnames + return row_payload, fieldnames, input_col_names, [], [] fieldnames = input_col_names + l1_keys + output_keys + base_fields if has_unparsed_output: @@ -233,7 +233,7 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: len(input_col_names) + len(l1_keys) + len(output_keys), "output_raw" ) - return expanded, fieldnames + return expanded, fieldnames, input_col_names, l1_keys, output_keys def serialize_export_rows( @@ -244,13 +244,13 @@ def serialize_export_rows( row_payload = [row.model_dump(mode="json") for row in export_rows] if export_format == "json": - expanded, _ = _expand_output_columns(row_payload) + expanded, *_ = _expand_output_columns(row_payload) return ( json.dumps(expanded, ensure_ascii=False, indent=2).encode("utf-8"), "application/json", ) - expanded, fieldnames = _expand_output_columns(row_payload) + expanded, fieldnames, input_col_names, l1_keys, output_keys = _expand_output_columns(row_payload) if export_format == "csv": output = io.StringIO() @@ -267,14 +267,10 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - metadata_fields = { - field - for field in AssessmentExportRow.model_fields.keys() - if field not in ("output", "input_data") - } - excel_fields = [field for field in fieldnames if field not in metadata_fields] + # Explicit ordering: inputs → L1 topic relevance → L1 duplicate detection → L2 output + excel_fields = input_col_names + l1_keys + output_keys if not excel_fields: - excel_fields = ["output"] + excel_fields = output_keys or ["output"] # Drop columns where every row is null/empty expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) @@ 
-294,7 +290,7 @@ def build_json_export_rows( ) -> list[dict[str, Any]]: """Return JSON rows with structured output expanded into top-level keys.""" row_payload = [row.model_dump(mode="json") for row in export_rows] - expanded, _ = _expand_output_columns(row_payload) + expanded, *_ = _expand_output_columns(row_payload) return expanded From c12ac18bd185cf0f327011050082d99dbc1bb386 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 31 May 2026 08:00:55 +0530 Subject: [PATCH 03/16] feat(post-processing): Implement post-processing configuration for assessment runs --- .../docs/assessment/update_post_processing.md | 15 ++ backend/app/api/routes/assessment/runs.py | 39 +++- backend/app/crud/assessment/__init__.py | 2 + backend/app/crud/assessment/core.py | 25 +++ backend/app/models/assessment.py | 10 +- backend/app/services/assessment/service.py | 3 + .../app/services/assessment/utils/export.py | 39 +++- .../assessment/utils/post_processing.py | 184 ++++++++++++++++++ 8 files changed, 306 insertions(+), 11 deletions(-) create mode 100644 backend/app/api/docs/assessment/update_post_processing.md create mode 100644 backend/app/services/assessment/utils/post_processing.py diff --git a/backend/app/api/docs/assessment/update_post_processing.md b/backend/app/api/docs/assessment/update_post_processing.md new file mode 100644 index 000000000..0d6f3278a --- /dev/null +++ b/backend/app/api/docs/assessment/update_post_processing.md @@ -0,0 +1,15 @@ +Save post-processing config for a single assessment run. + +Stores the config inside the run's `input` JSON blob (key +`post_processing_config`). It is applied at export/preview time and never +re-runs the LLM, so it can be edited after the run completes. + +The config has three optional sections: + +- `computed_columns`: derived columns from formulas, e.g. + `{"name": "Total_Score", "formula": "@Novelty_score + @Usefulness_score"}`. 
+ Formulas reference columns with `@` and support `+ - * /` and parentheses. +- `filter`: row filters combined with AND logic. +- `sort`: sort rules applied in priority order. + +Pass `null` (or an empty body) to clear post-processing for the run. diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 18398eeb0..3c3abd57a 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -3,7 +3,7 @@ import logging from typing import Any, Literal -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from app.api.deps import AuthContextDep, SessionDep @@ -12,6 +12,7 @@ get_assessment_by_id, get_assessment_run_by_id as get_run_by_id, list_assessment_runs as list_runs, + update_run_post_processing_config, ) from app.models.assessment import ( Assessment, @@ -33,6 +34,7 @@ load_export_rows_for_run, sort_export_rows, ) +from app.services.assessment.utils.post_processing import apply_post_processing from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -68,6 +70,7 @@ def _build_run_public( l1_total_rows=run.l1_total_rows, l1_total_passed=run.l1_total_passed, l1_total_rejected=run.l1_total_rejected, + post_processing_config=(run.input or {}).get("post_processing_config"), inserted_at=run.inserted_at, updated_at=run.updated_at, ) @@ -215,12 +218,44 @@ def export_assessment_run_results( ) ) + post_processing_config = (run.input or {}).get("post_processing_config") or None base_label = assessment.experiment_name if assessment else f"run_{run.id}" + if export_format != "json": return build_export_response( export_rows=export_rows, export_format=export_format, base_name=f"{base_label}_run_{run.id}_results", + post_processing_config=post_processing_config, ) - return APIResponse.success_response(data=build_json_export_rows(export_rows)) + rows = 
build_json_export_rows(export_rows) + rows = apply_post_processing(rows, post_processing_config) + return APIResponse.success_response(data=rows) + + +@router.patch( + "/runs/{run_id}/post-processing", + description=load_description("assessment/update_post_processing.md"), + response_model=APIResponse[AssessmentRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def update_post_processing( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, + config: dict[str, Any] | None = Body(default=None), +) -> APIResponse[AssessmentRunPublic]: + """Save post-processing config (computed columns, sort, filter) for a run.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + + run = update_run_post_processing_config(session=session, run=run, config=config) + + return APIResponse.success_response(data=_build_run_public(session, run)) diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py index 8e5c9984d..8e623e3a7 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -15,6 +15,7 @@ recompute_assessment_status, update_assessment_run_l1_stats, update_assessment_run_status, + update_run_post_processing_config, ) from app.crud.assessment.dataset import ( create_assessment_dataset, @@ -45,4 +46,5 @@ "recompute_assessment_status", "update_assessment_run_l1_stats", "update_assessment_run_status", + "update_run_post_processing_config", ] diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index 6e2c8b2f7..d5a184d06 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -5,6 +5,7 @@ from uuid import UUID from fastapi import HTTPException +from sqlalchemy.orm.attributes import flag_modified from 
sqlmodel import Session, select from app.core.util import now @@ -129,6 +130,30 @@ def create_assessment_run( return run +def update_run_post_processing_config( + session: Session, + run: AssessmentRun, + config: dict[str, Any] | None, +) -> AssessmentRun: + """Set post_processing_config inside the run's input JSON blob and persist.""" + run.input = {**(run.input or {}), "post_processing_config": config} + flag_modified(run, "input") + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error( + f"[update_run_post_processing_config] Failed for run id={run.id}: {e}", + exc_info=True, + ) + raise + + logger.info(f"[update_run_post_processing_config] Updated run id={run.id}") + return run + + def get_assessment_run_by_id( session: Session, run_id: int, diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 0dd0a96d1..b8620af06 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -5,7 +5,7 @@ from uuid import UUID from pydantic import BaseModel, Field -from sqlalchemy import Column, Index, Text +from sqlalchemy import JSON, Column, Index, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField from sqlmodel import Relationship, SQLModel @@ -253,6 +253,7 @@ class AssessmentRunPublic(BaseModel): l1_total_rows: int | None = None l1_total_passed: int | None = None l1_total_rejected: int | None = None + post_processing_config: dict[str, Any] | None = None inserted_at: datetime updated_at: datetime @@ -322,6 +323,13 @@ class AssessmentCreate(BaseModel): "duplicate_detection (columns). Omit to skip L1." ), ) + post_processing_config: dict[str, Any] | None = Field( + None, + description=( + "Post-processing config applied at export. " + "Keys: computed_columns, sort, filter." 
+ ), + ) class AssessmentRunSummary(BaseModel): diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index d03eb13a7..cabe2bb4c 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -80,6 +80,7 @@ def _build_retry_request( output_schema=assessment_input.get("output_schema"), configs=configs, l1_config=assessment_input.get("l1_config"), + post_processing_config=assessment_input.get("post_processing_config"), ) @@ -121,6 +122,8 @@ def start_assessment( assessment_input["output_schema"] = request.output_schema if request.l1_config: assessment_input["l1_config"] = request.l1_config + if request.post_processing_config: + assessment_input["post_processing_config"] = request.post_processing_config config_crud = ConfigCrud(session=session, project_id=project_id) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index 4b299151f..86d9186b0 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -239,22 +239,43 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: def serialize_export_rows( export_rows: list[AssessmentExportRow], export_format: Literal["json", "csv", "xlsx"], + post_processing_config: dict[str, Any] | None = None, ) -> tuple[bytes, str]: """Serialize export rows into the requested file format.""" + from app.services.assessment.utils.post_processing import apply_post_processing + row_payload = [row.model_dump(mode="json") for row in export_rows] if export_format == "json": expanded, *_ = _expand_output_columns(row_payload) + expanded = apply_post_processing(expanded, post_processing_config) return ( json.dumps(expanded, ensure_ascii=False, indent=2).encode("utf-8"), "application/json", ) - expanded, fieldnames, input_col_names, l1_keys, output_keys = _expand_output_columns(row_payload) + ( + expanded, + 
fieldnames, + input_col_names, + l1_keys, + output_keys, + ) = _expand_output_columns(row_payload) + expanded = apply_post_processing(expanded, post_processing_config) + + # Add any new computed columns to fieldnames so they appear in output + existing = set(fieldnames) + computed_names = [ + c["name"] + for c in (post_processing_config or {}).get("computed_columns") or [] + if c.get("name") and c["name"] not in existing + ] + if computed_names: + fieldnames = fieldnames + computed_names if export_format == "csv": output = io.StringIO() - writer = csv.DictWriter(output, fieldnames=fieldnames) + writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() writer.writerows(expanded) return output.getvalue().encode("utf-8"), "text/csv" @@ -267,12 +288,11 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # Explicit ordering: inputs → L1 topic relevance → L1 duplicate detection → L2 output - excel_fields = input_col_names + l1_keys + output_keys + # Explicit ordering: inputs → L1 → L2 → computed columns + excel_fields = input_col_names + l1_keys + output_keys + computed_names if not excel_fields: excel_fields = output_keys or ["output"] - # Drop columns where every row is null/empty expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) buf = io.BytesIO() @@ -290,17 +310,20 @@ def build_json_export_rows( ) -> list[dict[str, Any]]: """Return JSON rows with structured output expanded into top-level keys.""" row_payload = [row.model_dump(mode="json") for row in export_rows] - expanded, *_ = _expand_output_columns(row_payload) - return expanded + expanded, fieldnames, *_ = _expand_output_columns(row_payload) + return [{k: row.get(k) for k in fieldnames if k in row} for row in expanded] def build_export_response( export_rows: list[AssessmentExportRow], export_format: Literal["json", "csv", "xlsx"], base_name: str, + post_processing_config: dict[str, Any] 
| None = None, ) -> StreamingResponse: """Return a file download response for assessment exports.""" - payload, media_type = serialize_export_rows(export_rows, export_format) + payload, media_type = serialize_export_rows( + export_rows, export_format, post_processing_config + ) filename = generate_timestamped_filename( _safe_filename_part(base_name), extension=export_format, diff --git a/backend/app/services/assessment/utils/post_processing.py b/backend/app/services/assessment/utils/post_processing.py new file mode 100644 index 000000000..9b0d36c45 --- /dev/null +++ b/backend/app/services/assessment/utils/post_processing.py @@ -0,0 +1,184 @@ +"""Post-processing engine for assessment exports. +""" + +import ast +import logging +import operator +import re +from typing import Any + +logger = logging.getLogger(__name__) + +# Safe formula evaluator +_SAFE_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.USub: operator.neg, +} + + +def _eval_node(node: ast.AST) -> float: + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return float(node.value) + if isinstance(node, ast.BinOp) and type(node.op) in _SAFE_OPS: + return _SAFE_OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right)) + if isinstance(node, ast.UnaryOp) and type(node.op) in _SAFE_OPS: + return _SAFE_OPS[type(node.op)](_eval_node(node.operand)) + raise ValueError(f"Unsupported operation in formula: {ast.dump(node)}") + + +def evaluate_formula(formula: str, row: dict[str, Any]) -> float | None: + """Evaluate a formula like '@Novelty_score + @Feasibility_score * 0.5'. + + Returns None if the formula fails or references missing columns. 
+ """ + + def resolve(match: re.Match) -> str: + col = match.group(1) + val = row.get(col) + if val is None: + return "0" + try: + return str(float(val)) + except (TypeError, ValueError): + return "0" + + expr = re.sub(r"@([\w]+)", resolve, formula) + + try: + tree = ast.parse(expr, mode="eval") + return _eval_node(tree.body) + except Exception as exc: + logger.warning("[evaluate_formula] Failed to evaluate %r: %s", formula, exc) + return None + + +# Filter + +_FILTER_OPS = { + "eq": lambda a, b: str(a).strip().lower() == str(b).strip().lower(), + "ne": lambda a, b: str(a).strip().lower() != str(b).strip().lower(), + "contains": lambda a, b: str(b).lower() in str(a).lower(), + "not_contains": lambda a, b: str(b).lower() not in str(a).lower(), + "in": lambda a, b: str(a).strip().lower() in {str(v).lower() for v in b}, + "not_in": lambda a, b: str(a).strip().lower() not in {str(v).lower() for v in b}, + "is_empty": lambda a, _: a is None or str(a).strip() == "", + "is_not_empty": lambda a, _: a is not None and str(a).strip() != "", +} + + +def _numeric_filter(op: str, a: Any, b: Any) -> bool: + try: + fa, fb = float(a), float(b) + if op == "gt": + return fa > fb + if op == "lt": + return fa < fb + if op == "gte": + return fa >= fb + if op == "lte": + return fa <= fb + except (TypeError, ValueError): + pass + return False + + +def _row_matches_filter(row: dict[str, Any], rule: dict[str, Any]) -> bool: + col = rule["column"] + op = rule["op"] + value = rule.get("value") + cell = row.get(col) + + if op in ("gt", "lt", "gte", "lte"): + return _numeric_filter(op, cell, value) + if op in _FILTER_OPS: + return _FILTER_OPS[op](cell, value) + return True + + +def apply_computed_columns( + rows: list[dict[str, Any]], + computed_columns: list[dict[str, Any]], +) -> None: + """Add computed columns to each row in-place.""" + for row in rows: + for col_def in computed_columns: + name = col_def.get("name", "").strip() + formula = col_def.get("formula", "").strip() + if not name or 
not formula: + continue + row[name] = evaluate_formula(formula, row) + + +def apply_filter( + rows: list[dict[str, Any]], + filter_rules: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Return only rows that match ALL filter rules (AND logic).""" + if not filter_rules: + return rows + return [ + row + for row in rows + if all(_row_matches_filter(row, rule) for rule in filter_rules) + ] + + +def apply_sort( + rows: list[dict[str, Any]], + sort_rules: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Sort rows by priority-ordered rules. First rule has highest priority.""" + if not sort_rules: + return rows + + # Build sort key: iterate rules in reverse (lowest priority first) + # so that highest priority rule is the final (dominant) tiebreaker. + result = rows + for rule in reversed(sort_rules): + col = rule.get("column", "") + desc = str(rule.get("direction", "asc")).lower() == "desc" + + def sort_key(row: dict[str, Any], _col: str = col) -> tuple: + val = row.get(_col) + if val is None: + return (1, 0, "") + try: + return (0, -float(val) if desc else float(val), "") + except (TypeError, ValueError): + s = str(val).lower() + return ( + (0, 0, s) + if not desc + else (0, 0, "".join(chr(0x10FFFF - ord(c)) for c in s)) + ) + + result = sorted(result, key=sort_key) + + return result + + +def apply_post_processing( + rows: list[dict[str, Any]], + config: dict[str, Any] | None, +) -> list[dict[str, Any]]: + """Apply full post-processing pipeline: computed columns → filter → sort. + + Safe to call with config=None (no-op). 
+ """ + if not config: + return rows + + computed_columns = config.get("computed_columns") or [] + filter_rules = config.get("filter") or [] + sort_rules = config.get("sort") or [] + + if computed_columns: + apply_computed_columns(rows, computed_columns) + + rows = apply_filter(rows, filter_rules) + rows = apply_sort(rows, sort_rules) + + return rows From c1791d5f44e7f6be10827b6d3dcbdfd3e5b28565 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:30:09 +0530 Subject: [PATCH 04/16] feat(assessment): Enhance attachment handling in L1 pipeline with mixed type detection and improved utility functions --- backend/app/crud/assessment/batch.py | 71 +------ backend/app/models/assessment.py | 9 +- .../app/services/assessment/l1/pipeline.py | 15 +- .../services/assessment/l1/topic_relevance.py | 34 ++- backend/app/services/assessment/tasks.py | 10 +- .../services/assessment/utils/attachments.py | 193 +++++++++++++++++- backend/app/tests/assessment/test_batch.py | 104 ++++++++++ backend/app/tests/assessment/test_export.py | 12 +- .../tests/assessment/test_topic_relevance.py | 123 +++++++++++ 9 files changed, 487 insertions(+), 84 deletions(-) create mode 100644 backend/app/tests/assessment/test_topic_relevance.py diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index e5e52daec..531dc038d 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -30,11 +30,8 @@ normalize_llm_text, ) from app.services.assessment.utils.attachments import ( + build_gemini_attachment_parts, resolve_attachment_values, - resolve_image_mime_and_payload, - split_attachment_urls, - split_data_url, - to_direct_attachment_url, ) from app.services.llm.providers.registry import LLMProvider @@ -174,6 +171,8 @@ def build_openai_jsonl( } """ jsonl_data = [] + # Memoize per-item type probes across all rows in this build. 
+ type_cache: dict[str, str] = {} for i, row in enumerate(rows): idx = row_indices[i] if row_indices is not None else i @@ -188,7 +187,7 @@ def build_openai_jsonl( # Attachments for att in attachments: cell_value = row.get(att.column, "") - input_parts.extend(resolve_attachment_values(cell_value, att)) + input_parts.extend(resolve_attachment_values(cell_value, att, type_cache)) if not input_parts: logger.warning("[build_openai_jsonl] Skipping empty row | idx=%s", idx) @@ -232,6 +231,8 @@ def build_google_jsonl( } """ jsonl_data = [] + # Memoize per-item type probes across all rows in this build. + type_cache: dict[str, str] = {} for i, row in enumerate(rows): idx = row_indices[i] if row_indices is not None else i @@ -244,64 +245,8 @@ def build_google_jsonl( # Attachments (Gemini uses file_data for inline content) for att in attachments: - cell_value = row.get(att.column, "").strip() - if not cell_value: - continue - - cell_values = ( - split_attachment_urls(cell_value) - if att.format == "url" - else [cell_value] - ) - - for item_value in cell_values: - normalized_value = ( - to_direct_attachment_url(item_value, att.type) - if att.format == "url" - else item_value - ) - if att.type == "image": - mime_type, payload = resolve_image_mime_and_payload( - normalized_value, - att.format, - ) - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": mime_type, - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": mime_type, - "data": payload, - } - } - ) - elif att.type == "pdf": - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": "application/pdf", - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": "application/pdf", - "data": split_data_url(normalized_value)[1], - } - } - ) + cell_value = row.get(att.column, "") + parts.extend(build_gemini_attachment_parts(cell_value, att, type_cache)) if not parts: 
logger.warning("[build_google_jsonl] Skipping empty row | idx=%s", idx) diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index b8620af06..b5a1a31f5 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -275,7 +275,14 @@ class AssessmentAttachment(BaseModel): """Attachment column configuration.""" column: str = Field(..., description="Column name containing the attachment data") - type: Literal["image", "pdf"] = Field(..., description="Attachment type") + type: Literal["image", "pdf", "mixed"] = Field( + ..., + description=( + "Attachment type. 'mixed' detects image vs pdf per item (for columns " + "that contain both); 'image'/'pdf' force a type and act as fallback " + "when per-item detection is inconclusive." + ), + ) format: Literal["url", "base64"] = Field(..., description="Data format") diff --git a/backend/app/services/assessment/l1/pipeline.py b/backend/app/services/assessment/l1/pipeline.py index 2a002e5e5..18df91324 100644 --- a/backend/app/services/assessment/l1/pipeline.py +++ b/backend/app/services/assessment/l1/pipeline.py @@ -17,7 +17,7 @@ from app.core.config import settings from app.core.cloud import get_cloud_storage from app.core.storage_utils import upload_jsonl_to_object_store -from app.models.assessment import AssessmentRun +from app.models.assessment import AssessmentAttachment, AssessmentRun from app.services.assessment.l1.duplicate_detection import run_duplicate_detection from app.services.assessment.l1.topic_relevance import run_topic_relevance @@ -50,6 +50,7 @@ def run_l1_pipeline( session: Session, organization_id: int, project_id: int, + attachments: list[AssessmentAttachment] | None = None, ) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: """Run L1 filters on all rows. 
@@ -78,6 +79,13 @@ def run_l1_pipeline( tr_prompt: str = tr_config.get("prompt") or "" dup_columns: list[str] = dup_config.get("columns") or [] + tr_attachment_columns = tr_config.get("attachment_columns") + if tr_attachment_columns is None: + tr_attachments = list(attachments or []) + else: + selected = set(tr_attachment_columns) + tr_attachments = [a for a in (attachments or []) if a.column in selected] + tr_enabled = bool(tr_columns and tr_prompt) dup_enabled = bool(dup_columns) @@ -105,6 +113,9 @@ def run_l1_pipeline( ) # tr_results[idx] = None when TR disabled → no topic_relevance columns in export + # Shared across rows so each unique attachment file is type-probed once. + attachment_type_cache: dict[str, str] = {} + tr_results: dict[int, dict[str, Any] | None] = {} if tr_enabled: with ThreadPoolExecutor(max_workers=workers) as executor: @@ -117,6 +128,8 @@ def run_l1_pipeline( tr_prompt, gemini_client, model, + tr_attachments, + attachment_type_cache, ): idx for idx, row in enumerate(rows) } diff --git a/backend/app/services/assessment/l1/topic_relevance.py b/backend/app/services/assessment/l1/topic_relevance.py index 42516d27b..c1894c04e 100644 --- a/backend/app/services/assessment/l1/topic_relevance.py +++ b/backend/app/services/assessment/l1/topic_relevance.py @@ -8,6 +8,9 @@ from google import genai from google.genai import types +from app.models.assessment import AssessmentAttachment +from app.services.assessment.utils.attachments import build_gemini_attachment_parts + logger = logging.getLogger(__name__) @@ -45,21 +48,42 @@ def run_topic_relevance( user_prompt: str, gemini_client: genai.Client, model: str, + attachments: list[AssessmentAttachment] | None = None, + type_cache: dict[str, str] | None = None, ) -> dict[str, Any]: """Run topic relevance check on a single row. System instruction = user_prompt (the evaluation rubric/criteria). - User content = dict of {column_name: value} for the selected columns. 
+ User content = the selected columns as JSON plus every mapped attachment + (image/pdf) for the row, so relevance is judged on text and documents. + Each attachment column also gets its own relevance boolean in the schema, + so the export carries a ``topic_relevance_`` column. Output schema enforced: decision (ACCEPT/REJECT) + reasoning. On error defaults to verdict=True (fail-open). """ + # Document columns that actually have a value for this row. + doc_columns: list[str] = [] + for att in attachments or []: + if att.column not in doc_columns and (row.get(att.column) or "").strip(): + doc_columns.append(att.column) + + schema_columns = columns + doc_columns user_content = json.dumps({col: row.get(col, "") or "" for col in columns}) - output_schema = _build_output_schema(columns) + output_schema = _build_output_schema(schema_columns) + + parts: list[dict[str, Any]] = [{"text": user_content}] + for att in attachments or []: + attachment_parts = build_gemini_attachment_parts( + row.get(att.column, ""), att, type_cache + ) + if attachment_parts: + parts.append({"text": f"Attached document(s) for column '{att.column}':"}) + parts.extend(attachment_parts) try: response = gemini_client.models.generate_content( model=model, - contents=user_content, + contents=[{"role": "user", "parts": parts}], config=types.GenerateContentConfig( system_instruction=user_prompt.strip(), response_mime_type="application/json", @@ -70,7 +94,7 @@ def run_topic_relevance( raw = (response.text or "").strip() parsed = json.loads(raw) decision = str(parsed.get("decision", "ACCEPT")).upper() - column_relevance = {col: bool(parsed.get(col, True)) for col in columns} + column_relevance = {col: bool(parsed.get(col, True)) for col in schema_columns} return { "row_id": f"row_{row_idx}", "verdict": decision == "ACCEPT", @@ -88,6 +112,6 @@ def run_topic_relevance( "row_id": f"row_{row_idx}", "verdict": True, "decision": "ACCEPT", - "column_relevance": {col: True for col in columns}, + "column_relevance": 
{col: True for col in schema_columns}, "reasoning": f"(evaluation error — defaulting to pass) {exc}", } diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 66f644050..295c55ad2 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -13,7 +13,11 @@ from app.crud.assessment.batch import _load_dataset_rows, submit_assessment_batch from app.crud.config import ConfigCrud from app.crud.evaluations.core import resolve_evaluation_config -from app.models.assessment import Assessment, AssessmentRun +from app.models.assessment import ( + Assessment, + AssessmentAttachment, + AssessmentRun, +) from app.models.config.config import ConfigTag from app.services.assessment.l1 import run_l1_pipeline @@ -128,6 +132,10 @@ def execute_assessment_run( session=session, organization_id=organization_id, project_id=project_id, + attachments=[ + AssessmentAttachment(**a) + for a in assessment_input.get("attachments") or [] + ], ) logger.info( "[execute_assessment_run] L1 done | run_id=%s | rows_to_l2=%s / %s", diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 5a141a757..3622f9bce 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -6,12 +6,17 @@ import base64 import binascii +import logging import re from typing import Any from urllib.parse import urlparse +import requests + from app.models.assessment import AssessmentAttachment +logger = logging.getLogger(__name__) + _IMAGE_MIME_BY_EXT = { ".png": "image/png", ".jpg": "image/jpeg", @@ -92,10 +97,8 @@ def _decode_base64_prefix(payload: str, max_chars: int = 256) -> bytes | None: return None -def _guess_image_mime_from_base64(payload: str) -> str | None: - blob = _decode_base64_prefix(payload) - if not blob: - return None +def _image_mime_from_magic(blob: bytes) -> str | None: + """Detect image 
mime type from leading magic bytes.""" if blob.startswith(b"\x89PNG\r\n\x1a\n"): return "image/png" if blob.startswith(b"\xff\xd8\xff"): @@ -111,6 +114,22 @@ def _guess_image_mime_from_base64(payload: str) -> str | None: return None +def _guess_image_mime_from_base64(payload: str) -> str | None: + blob = _decode_base64_prefix(payload) + if not blob: + return None + return _image_mime_from_magic(blob) + + +def _type_from_magic(blob: bytes) -> str | None: + """Detect 'image' or 'pdf' from leading magic bytes; None if neither.""" + if blob.startswith(b"%PDF"): + return "pdf" + if _image_mime_from_magic(blob): + return "image" + return None + + def resolve_image_mime_and_payload( value: str, format_type: str, @@ -126,9 +145,110 @@ def resolve_image_mime_and_payload( return _guess_image_mime_from_base64(payload) or "image/png", payload +def _drive_file_id(url: str) -> str | None: + """Extract a Google Drive file id from common share URL shapes.""" + match = re.match(r"https://drive\.google\.com/file/d/([^/]+)", url) + if match: + return match.group(1) + match = re.search(r"[?&]id=([a-zA-Z0-9_-]+)", url) + if match and ("drive.google.com" in url or "drive.usercontent.google.com" in url): + return match.group(1) + return None + + +def _type_from_url_extension(url: str) -> str | None: + """Detect 'image' or 'pdf' from a URL path extension; None if unknown.""" + path = (urlparse(url).path or "").lower() + if path.endswith(".pdf"): + return "pdf" + if _guess_image_mime_from_url(url): + return "image" + return None + + +def _type_from_content_type(content_type: str | None) -> str | None: + if not content_type: + return None + content_type = content_type.split(";")[0].strip().lower() + if content_type == "application/pdf": + return "pdf" + if content_type.startswith("image/"): + return "image" + return None + + +def _probe_url_type(url: str, num_bytes: int = 16) -> str | None: + """Probe a remote URL's type: ranged byte sniff first, Content-Type fallback. 
+ + Reads only the first few bytes (does not download the whole file). Drive + share URLs are routed through the download endpoint so the real file bytes + are read instead of an HTML share page. + """ + file_id = _drive_file_id(url) + probe_url = ( + f"https://drive.google.com/uc?export=download&id={file_id}" if file_id else url + ) + + try: + with requests.get( + probe_url, + headers={"Range": f"bytes=0-{num_bytes - 1}"}, + timeout=10, + stream=True, + allow_redirects=True, + ) as resp: + resp.raise_for_status() + for chunk in resp.iter_content(chunk_size=num_bytes): + magic_type = _type_from_magic(chunk) + if magic_type: + return magic_type + break + return _type_from_content_type(resp.headers.get("Content-Type")) + except requests.RequestException as e: + logger.warning(f"[_probe_url_type] Probe failed for {url}: {e}") + return None + + +def detect_item_type( + value: str, + format_type: str, + fallback: str, + cache: dict[str, str] | None = None, +) -> str: + """Resolve a single attachment item as 'image' or 'pdf'. + + Order: data-URL/base64 magic (no network) -> URL extension -> remote probe + (ranged byte sniff, then Content-Type) -> declared ``fallback`` type. + ``fallback`` may be 'mixed'; when detection is inconclusive it resolves to + 'image'. Remote probe results are memoized in ``cache`` keyed by item value. + """ + # 'mixed' is not a concrete output type; terminal default is image. 
+ safe_fallback = fallback if fallback in ("image", "pdf") else "image" + + if format_type != "url": + data_url_mime, payload = split_data_url(value) + if data_url_mime == "application/pdf": + return "pdf" + if data_url_mime and data_url_mime.startswith("image/"): + return "image" + blob = _decode_base64_prefix(payload) + return (_type_from_magic(blob) if blob else None) or safe_fallback + + if cache is not None and value in cache: + return cache[value] + + item_type = ( + _type_from_url_extension(value) or _probe_url_type(value) or safe_fallback + ) + if cache is not None: + cache[value] = item_type + return item_type + + def resolve_attachment_values( value: str, att: AssessmentAttachment, + type_cache: dict[str, str] | None = None, ) -> list[dict[str, Any]]: """Convert one dataset cell into one or more OpenAI-style input objects.""" value = value.strip() @@ -142,13 +262,14 @@ def resolve_attachment_values( resolved: list[dict[str, Any]] = [] for item_value in values: + item_type = detect_item_type(item_value, att.format, att.type, type_cache) normalized_value = ( - to_direct_attachment_url(item_value, att.type) + to_direct_attachment_url(item_value, item_type) if att.format == "url" else item_value ) - if att.type == "image": + if item_type == "image": if att.format == "url": resolved.append({"type": "input_image", "image_url": normalized_value}) else: @@ -162,7 +283,7 @@ def resolve_attachment_values( "image_url": f"data:{mime_type};base64,{payload}", } ) - elif att.type == "pdf": + elif item_type == "pdf": if att.format == "url": resolved.append( { @@ -181,3 +302,61 @@ def resolve_attachment_values( ) return resolved + + +def build_gemini_attachment_parts( + value: str, + att: AssessmentAttachment, + type_cache: dict[str, str] | None = None, +) -> list[dict[str, Any]]: + """Convert one dataset cell into one or more Gemini content parts. 
+ + Mirrors the per-item type detection used for the L2 batch so the same + image/pdf routing applies to L1 (topic relevance) calls. + """ + value = value.strip() + if not value: + return [] + + values = split_attachment_urls(value) if att.format == "url" else [value] + + parts: list[dict[str, Any]] = [] + for item_value in values: + item_type = detect_item_type(item_value, att.format, att.type, type_cache) + normalized_value = ( + to_direct_attachment_url(item_value, item_type) + if att.format == "url" + else item_value + ) + + if item_type == "image": + mime_type, payload = resolve_image_mime_and_payload( + normalized_value, att.format + ) + if att.format == "url": + parts.append( + {"fileData": {"mimeType": mime_type, "fileUri": normalized_value}} + ) + else: + parts.append({"inlineData": {"mimeType": mime_type, "data": payload}}) + elif item_type == "pdf": + if att.format == "url": + parts.append( + { + "fileData": { + "mimeType": "application/pdf", + "fileUri": normalized_value, + } + } + ) + else: + parts.append( + { + "inlineData": { + "mimeType": "application/pdf", + "data": split_data_url(normalized_value)[1], + } + } + ) + + return parts diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 6d524e81f..41d84198d 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -1,5 +1,6 @@ """Tests for assessment/batch.py provider routing in submit_assessment_batch.""" +import base64 import io from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -21,6 +22,7 @@ _decode_base64_prefix, _guess_image_mime_from_base64, _guess_image_mime_from_url, + detect_item_type, resolve_attachment_values, resolve_image_mime_and_payload, split_attachment_urls, @@ -423,3 +425,105 @@ def test_build_openai_and_google_jsonl(self) -> None: assert google_jsonl[0]["request"]["systemInstruction"] == { "parts": [{"text": "system"}] } + + +class TestDetectItemType: 
+ """Per-item image/pdf detection for mixed-content attachment columns.""" + + def test_data_url_pdf(self) -> None: + assert ( + detect_item_type("data:application/pdf;base64,JVBERi0=", "base64", "image") + == "pdf" + ) + + def test_data_url_image(self) -> None: + assert ( + detect_item_type("data:image/png;base64,AAAA", "base64", "pdf") == "image" + ) + + def test_base64_magic_pdf(self) -> None: + payload = base64.b64encode(b"%PDF-1.7 body").decode() + assert detect_item_type(payload, "base64", "image") == "pdf" + + def test_base64_magic_png(self) -> None: + payload = base64.b64encode(b"\x89PNG\r\n\x1a\n" + b"0" * 8).decode() + assert detect_item_type(payload, "base64", "pdf") == "image" + + def test_base64_unknown_falls_back(self) -> None: + payload = base64.b64encode(b"not a known magic").decode() + assert detect_item_type(payload, "base64", "pdf") == "pdf" + + def test_mixed_fallback_resolves_to_image(self) -> None: + """'mixed' is never a returned type; inconclusive detection -> image.""" + payload = base64.b64encode(b"not a known magic").decode() + assert detect_item_type(payload, "base64", "mixed") == "image" + + def test_url_extension_pdf_case_insensitive(self) -> None: + assert detect_item_type("https://x.com/a/scan.PDF", "url", "image", {}) == "pdf" + + def test_url_extension_image(self) -> None: + assert detect_item_type("https://x.com/a/p.jpg", "url", "pdf", {}) == "image" + + def test_url_no_extension_probes_bytes(self) -> None: + """Extensionless URL (Drive-style) is probed; magic bytes win over fallback.""" + url = "https://drive.google.com/file/d/ABC123/view" + resp = MagicMock() + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + resp.raise_for_status = MagicMock() + resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) + with patch( + "app.services.assessment.utils.attachments.requests.get", + return_value=resp, + ) as mock_get: + assert detect_item_type(url, "url", "image", {}) == "pdf" + # 
Drive share URL is probed through the download endpoint. + assert "uc?export=download&id=ABC123" in mock_get.call_args.args[0] + + def test_url_probe_uses_content_type_when_no_magic(self) -> None: + url = "https://example.com/file" + resp = MagicMock() + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + resp.raise_for_status = MagicMock() + resp.iter_content = MagicMock(return_value=iter([b"\x00\x01\x02\x03"])) + resp.headers = {"Content-Type": "application/pdf; charset=binary"} + with patch( + "app.services.assessment.utils.attachments.requests.get", + return_value=resp, + ): + assert detect_item_type(url, "url", "image", {}) == "pdf" + + def test_url_probe_failure_falls_back(self) -> None: + import requests as _requests + + url = "https://example.com/file" + with patch( + "app.services.assessment.utils.attachments.requests.get", + side_effect=_requests.RequestException("boom"), + ): + assert detect_item_type(url, "url", "image", {}) == "image" + + def test_cache_skips_second_probe(self) -> None: + url = "https://drive.google.com/file/d/XYZ/view" + cache: dict[str, str] = {} + resp = MagicMock() + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + resp.raise_for_status = MagicMock() + resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) + with patch( + "app.services.assessment.utils.attachments.requests.get", + return_value=resp, + ) as mock_get: + assert detect_item_type(url, "url", "image", cache) == "pdf" + assert detect_item_type(url, "url", "image", cache) == "pdf" + assert mock_get.call_count == 1 + + def test_mixed_column_resolves_both_types(self) -> None: + """One column, two URLs with extensions -> one image, one pdf object.""" + att = AssessmentAttachment(column="docs", type="image", format="url") + value = "https://x.com/a/photo.jpg, https://x.com/b/report.pdf" + resolved = resolve_attachment_values(value, att, {}) + types = [obj["type"] for obj in 
resolved] + assert types == ["input_image", "input_file"] diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py index 3ace89dbd..98eb10683 100644 --- a/backend/app/tests/assessment/test_export.py +++ b/backend/app/tests/assessment/test_export.py @@ -144,14 +144,14 @@ def test_all_empty_drops_all(self) -> None: class TestExpandOutputColumns: def test_plain_string_output_not_expanded(self) -> None: rows = [{"output": "plain text", "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "output" in fieldnames def test_json_dict_output_expanded(self) -> None: rows = [ {"output": json.dumps({"score": 5, "reason": "good"}), "input_data": None} ] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "score" in fieldnames assert "reason" in fieldnames assert expanded[0]["score"] == 5 @@ -161,14 +161,14 @@ def test_mixed_parsed_and_unparsed_adds_output_raw(self) -> None: {"output": json.dumps({"score": 3}), "input_data": None}, {"output": "not json", "input_data": None}, ] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "output_raw" in fieldnames # Second row that didn't parse should get output_raw assert expanded[1].get("output_raw") == "not json" def test_none_output_handled(self) -> None: rows = [{"output": None, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert expanded[0].get("output") is None @@ -253,13 +253,13 @@ class TestExpandOutputColumnsDictOutput: def test_dict_output_expanded_directly(self) -> None: # raw output is already a dict (not a JSON string) rows = [{"output": {"score": 9, "label": "good"}, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = 
_expand_output_columns(rows) assert "score" in fieldnames assert expanded[0]["score"] == 9 def test_non_dict_non_string_output_treated_as_unparsed(self) -> None: rows = [{"output": 42, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) # 42 is not a dict/string, treated as unparsed → output stays as-is assert "output" in fieldnames diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py new file mode 100644 index 000000000..ad52c2306 --- /dev/null +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -0,0 +1,123 @@ +"""Tests for L1 topic relevance attachment handling.""" + +import json +from unittest.mock import MagicMock + +from app.models.assessment import AssessmentAttachment +from app.services.assessment.l1.topic_relevance import run_topic_relevance + + +def _client_returning(decision: str) -> MagicMock: + client = MagicMock() + response = MagicMock() + response.text = json.dumps( + {"decision": decision, "Problem": True, "reasoning": "ok"} + ) + client.models.generate_content.return_value = response + return client + + +class TestTopicRelevanceAttachments: + def test_attachments_added_to_contents(self) -> None: + client = _client_returning("ACCEPT") + att = AssessmentAttachment(column="Documents", type="image", format="url") + row = {"Problem": "p", "Documents": "https://x.com/a/photo.jpg"} + + result = run_topic_relevance( + row_idx=0, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + attachments=[att], + type_cache={}, + ) + + assert result["verdict"] is True + contents = client.models.generate_content.call_args.kwargs["contents"] + parts = contents[0]["parts"] + # First part is the text JSON, then a label, then the attachment file part. 
+ assert parts[0]["text"] + file_parts = [p for p in parts if "fileData" in p] + assert len(file_parts) == 1 + assert file_parts[0]["fileData"]["fileUri"] == "https://x.com/a/photo.jpg" + + def test_document_relevance_in_schema_and_result(self) -> None: + """Selected doc column gets its own relevance boolean in column_relevance.""" + client = MagicMock() + response = MagicMock() + response.text = json.dumps( + { + "decision": "ACCEPT", + "Problem": True, + "Documents": True, + "reasoning": "ok", + } + ) + client.models.generate_content.return_value = response + att = AssessmentAttachment(column="Documents", type="image", format="url") + row = {"Problem": "p", "Documents": "https://x.com/a/photo.jpg"} + + result = run_topic_relevance( + row_idx=3, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + attachments=[att], + type_cache={}, + ) + + # Document column carried into the per-column relevance map -> exports + # as topic_relevance_Documents. 
+ assert "Documents" in result["column_relevance"] + assert result["column_relevance"]["Documents"] is True + schema = client.models.generate_content.call_args.kwargs[ + "config" + ].response_schema + assert "Documents" in schema["properties"] + + def test_no_attachments_text_only(self) -> None: + client = _client_returning("REJECT") + row = {"Problem": "p"} + + result = run_topic_relevance( + row_idx=1, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + ) + + assert result["verdict"] is False + contents = client.models.generate_content.call_args.kwargs["contents"] + parts = contents[0]["parts"] + assert len(parts) == 1 + assert parts[0]["text"] + + def test_mixed_column_pdf_item_detected(self) -> None: + client = _client_returning("ACCEPT") + att = AssessmentAttachment(column="Documents", type="mixed", format="url") + row = {"Problem": "p", "Documents": "https://x.com/a/report.pdf"} + + run_topic_relevance( + row_idx=2, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + attachments=[att], + type_cache={}, + ) + + parts = client.models.generate_content.call_args.kwargs["contents"][0]["parts"] + pdf_parts = [ + p + for p in parts + if p.get("fileData", {}).get("mimeType") == "application/pdf" + ] + assert len(pdf_parts) == 1 From 98acf866904d67fa911bb1b659e04fe3d1e5f678 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 11:28:50 +0530 Subject: [PATCH 05/16] feat(tests): update assessment run status to 'l2_processing' and refactor batch submission to Celery task --- backend/app/tests/assessment/test_cron.py | 6 +-- backend/app/tests/assessment/test_crud.py | 3 ++ backend/app/tests/assessment/test_service.py | 53 +++++++------------- 3 files changed, 23 insertions(+), 39 deletions(-) diff --git a/backend/app/tests/assessment/test_cron.py b/backend/app/tests/assessment/test_cron.py 
index c9407bd5c..d9e8527eb 100644 --- a/backend/app/tests/assessment/test_cron.py +++ b/backend/app/tests/assessment/test_cron.py @@ -115,7 +115,7 @@ async def test_active_run_processed(self) -> None: session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.status = "l2_processing" session.exec.return_value.all.return_value = [assessment] with patch( @@ -140,7 +140,7 @@ async def test_active_run_failure_and_cleanup_failure(self) -> None: session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.status = "l2_processing" session.exec.return_value.all.return_value = [assessment] with patch( @@ -164,7 +164,7 @@ async def test_active_run_failure_updates_db_with_same_error_message(self) -> No session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.status = "l2_processing" session.exec.return_value.all.return_value = [assessment] with patch( diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 2bc076342..1d456329e 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -232,6 +232,9 @@ def test_build_run_stats(self) -> None: total_items=2, error_message=None, updated_at=datetime(2024, 1, 1), + l1_total_rows=None, + l1_total_passed=None, + l1_total_rejected=None, ), ] stats = build_run_stats(runs) diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py index b3654fa9b..e3d46ef55 100644 --- a/backend/app/tests/assessment/test_service.py +++ b/backend/app/tests/assessment/test_service.py @@ -142,9 +142,6 @@ def test_google_provider_is_supported(self) -> None: config_blob = SimpleNamespace( completion=SimpleNamespace(provider="google", params={"model": "gemini"}) ) - batch_job = MagicMock() - 
batch_job.id = 101 - batch_job.total_items = 3 with ( patch( @@ -164,13 +161,8 @@ def test_google_provider_is_supported(self) -> None: return_value=run, ), patch( - "app.services.assessment.service.submit_assessment_batch", - return_value=batch_job, - ) as submit_batch, - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ), + "app.celery.tasks.job_execution.run_assessment_run" + ) as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -181,8 +173,10 @@ def test_google_provider_is_supported(self) -> None: project_id=1, ) + # Google is an accepted provider — no rejection, one Celery task dispatched. assert response.num_configs == 1 - assert submit_batch.call_args.kwargs["config_blob"] is config_blob + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 def test_defaults_missing_provider_to_openai(self) -> None: session = MagicMock() @@ -194,9 +188,6 @@ def test_defaults_missing_provider_to_openai(self) -> None: config_blob = SimpleNamespace( completion=SimpleNamespace(provider=None, params={"model": "gpt-4.1-mini"}) ) - batch_job = MagicMock() - batch_job.id = 101 - batch_job.total_items = 3 with ( patch( @@ -216,13 +207,8 @@ def test_defaults_missing_provider_to_openai(self) -> None: return_value=run, ) as create_run, patch( - "app.services.assessment.service.submit_assessment_batch", - return_value=batch_job, - ) as submit_batch, - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ), + "app.celery.tasks.job_execution.run_assessment_run" + ) as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -238,11 +224,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: assert response.runs[0].run_id == 11 assessment_input = create_run.call_args.kwargs["assessment_input"] assert 
assessment_input["system_instruction"] == "Assess strictly" - assert ( - submit_batch.call_args.kwargs["assessment_input"]["system_instruction"] - == "Assess strictly" - ) - submit_batch.assert_called_once() + dispatch.delay.assert_called_once() def test_rejects_default_tagged_config(self) -> None: """Configs explicitly tagged 'default' must be rejected for assessment.""" @@ -278,14 +260,15 @@ def test_rejects_default_tagged_config(self) -> None: # Tag check must fire BEFORE config resolution. resolve.assert_not_called() - def test_batch_submission_failure_marks_run_failed(self) -> None: + def test_dispatches_one_celery_task_per_config(self) -> None: + """Batch submission moved to the Celery task; start_assessment only + creates runs and dispatches one task per resolved config.""" session = MagicMock() request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) dataset = _make_dataset() assessment = MagicMock() assessment.id = 21 run = _make_run() - run.status = "failed" config_blob = SimpleNamespace( completion=SimpleNamespace( provider="openai", params={"model": "gpt-4.1-mini"} @@ -310,13 +293,8 @@ def test_batch_submission_failure_marks_run_failed(self) -> None: return_value=run, ), patch( - "app.services.assessment.service.submit_assessment_batch", - side_effect=RuntimeError("submit failed"), - ), - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ) as update_run, + "app.celery.tasks.job_execution.run_assessment_run" + ) as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -327,7 +305,10 @@ def test_batch_submission_failure_marks_run_failed(self) -> None: project_id=1, ) assert response.num_configs == 1 - assert update_run.called + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 + assert dispatch.delay.call_args.kwargs["organization_id"] == 1 + assert 
dispatch.delay.call_args.kwargs["project_id"] == 1 class TestRetryHelpers: From 0addb717f761bd2aab5895d138db4e259f03a715 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 11:37:11 +0530 Subject: [PATCH 06/16] refactor(tests): streamline patching of run_assessment_run in TestStartAssessment --- backend/app/tests/assessment/test_service.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py index e3d46ef55..b56ec72b7 100644 --- a/backend/app/tests/assessment/test_service.py +++ b/backend/app/tests/assessment/test_service.py @@ -160,9 +160,7 @@ def test_google_provider_is_supported(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch( - "app.celery.tasks.job_execution.run_assessment_run" - ) as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -206,9 +204,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ) as create_run, - patch( - "app.celery.tasks.job_execution.run_assessment_run" - ) as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -292,9 +288,7 @@ def test_dispatches_one_celery_task_per_config(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch( - "app.celery.tasks.job_execution.run_assessment_run" - ) as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): From 
ad8e29f7e304d6fcf7f07afdf450115ea0e7c012 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:16:35 +0530 Subject: [PATCH 07/16] feat(tests): add comprehensive tests for L1 duplicate detection and pipeline orchestrator --- backend/app/tests/assessment/test_batch.py | 46 ++++ backend/app/tests/assessment/test_crud.py | 78 ++++++- .../assessment/test_duplicate_detection.py | 132 +++++++++++ backend/app/tests/assessment/test_pipeline.py | 151 +++++++++++++ .../tests/assessment/test_post_processing.py | 212 ++++++++++++++++++ 5 files changed, 618 insertions(+), 1 deletion(-) create mode 100644 backend/app/tests/assessment/test_duplicate_detection.py create mode 100644 backend/app/tests/assessment/test_pipeline.py create mode 100644 backend/app/tests/assessment/test_post_processing.py diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 41d84198d..aa0fce1a0 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -527,3 +527,49 @@ def test_mixed_column_resolves_both_types(self) -> None: resolved = resolve_attachment_values(value, att, {}) types = [obj["type"] for obj in resolved] assert types == ["input_image", "input_file"] + + +class TestAttachmentMagicAndMime: + def test_image_magic_all_formats(self) -> None: + from app.services.assessment.utils.attachments import _image_mime_from_magic + + assert _image_mime_from_magic(b"\x89PNG\r\n\x1a\n") == "image/png" + assert _image_mime_from_magic(b"\xff\xd8\xff") == "image/jpeg" + assert _image_mime_from_magic(b"GIF89a") == "image/gif" + assert _image_mime_from_magic(b"GIF87a") == "image/gif" + assert _image_mime_from_magic(b"BM....") == "image/bmp" + assert _image_mime_from_magic(b"RIFF\x00\x00\x00\x00WEBP") == "image/webp" + assert _image_mime_from_magic(b"II*\x00") == "image/tiff" + assert _image_mime_from_magic(b"MM\x00*") == "image/tiff" + assert 
_image_mime_from_magic(b"nope") is None + + def test_type_from_magic_pdf_and_none(self) -> None: + from app.services.assessment.utils.attachments import _type_from_magic + + assert _type_from_magic(b"%PDF-1.7") == "pdf" + assert _type_from_magic(b"\x89PNG\r\n\x1a\n") == "image" + assert _type_from_magic(b"random") is None + + def test_guess_image_mime_from_url_variants(self) -> None: + from app.services.assessment.utils.attachments import _guess_image_mime_from_url + + assert _guess_image_mime_from_url("http://x/a.PNG") == "image/png" + assert _guess_image_mime_from_url("http://x/a.jpeg") == "image/jpeg" + assert _guess_image_mime_from_url("http://x/a.webp") == "image/webp" + assert _guess_image_mime_from_url("http://x/a.txt") is None + + def test_resolve_image_mime_data_url(self) -> None: + from app.services.assessment.utils.attachments import ( + resolve_image_mime_and_payload, + ) + + mime, payload = resolve_image_mime_and_payload( + "data:image/webp;base64,AAAA", "base64" + ) + assert mime == "image/webp" + assert payload == "AAAA" + + def test_decode_base64_prefix_empty(self) -> None: + from app.services.assessment.utils.attachments import _decode_base64_prefix + + assert _decode_base64_prefix(" ") is None diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 1d456329e..1cf30249e 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -2,7 +2,7 @@ from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from uuid import UUID import pytest @@ -25,7 +25,9 @@ list_assessments, recompute_assessment_status, update_assessment_run_status, + update_run_post_processing_config, ) +from app.crud.assessment.core import update_assessment_run_l1_stats from app.models.stt_evaluation import EvaluationType @@ -300,3 +302,77 @@ def test_recompute_commit_failure_rolls_back(self) -> None: with 
pytest.raises(RuntimeError): recompute_assessment_status(session=session, assessment_id=1) session.rollback.assert_called_once() + + +class TestUpdateRunPostProcessingConfig: + def test_sets_config_in_input_blob(self) -> None: + session = MagicMock() + run = SimpleNamespace(id=5, input={"text_columns": ["q"]}) + cfg = {"computed_columns": [{"name": "T", "formula": "@a"}]} + with patch("app.crud.assessment.core.flag_modified") as flag: + out = update_run_post_processing_config( + session=session, run=run, config=cfg + ) + assert out.input["post_processing_config"] == cfg + assert out.input["text_columns"] == ["q"] + flag.assert_called_once_with(run, "input") + session.commit.assert_called_once() + + def test_none_input_handled(self) -> None: + session = MagicMock() + run = SimpleNamespace(id=6, input=None) + with patch("app.crud.assessment.core.flag_modified"): + out = update_run_post_processing_config( + session=session, run=run, config=None + ) + assert out.input == {"post_processing_config": None} + + def test_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace(id=7, input={}) + with patch("app.crud.assessment.core.flag_modified"): + with pytest.raises(RuntimeError): + update_run_post_processing_config(session=session, run=run, config={}) + session.rollback.assert_called_once() + + +class TestUpdateAssessmentRunL1Stats: + def test_sets_stats_fields(self) -> None: + session = MagicMock() + run = SimpleNamespace( + id=8, + updated_at=None, + l1_object_store_url=None, + l1_total_rows=None, + l1_total_passed=None, + l1_total_rejected=None, + ) + out = update_assessment_run_l1_stats( + session=session, + run=run, + l1_object_store_url="s3://x", + l1_total_rows=10, + l1_total_passed=7, + l1_total_rejected=3, + ) + assert out.l1_object_store_url == "s3://x" + assert out.l1_total_rows == 10 + assert out.l1_total_passed == 7 + assert out.l1_total_rejected == 3 + 
session.commit.assert_called_once() + + def test_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace( + id=9, + updated_at=None, + l1_object_store_url=None, + l1_total_rows=None, + l1_total_passed=None, + l1_total_rejected=None, + ) + with pytest.raises(RuntimeError): + update_assessment_run_l1_stats(session=session, run=run, l1_total_rows=1) + session.rollback.assert_called_once() diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py new file mode 100644 index 000000000..24d5ac951 --- /dev/null +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -0,0 +1,132 @@ +"""Tests for L1 duplicate detection.""" + +import json +from unittest.mock import MagicMock + +from app.services.assessment.l1.duplicate_detection import ( + _build_combined, + _parse_verdict, + run_duplicate_detection, +) + + +def _vague_client(vague: bool, reason: str = "r") -> MagicMock: + client = MagicMock() + resp = MagicMock() + resp.text = json.dumps({"vague": vague, "reason": reason}) + client.models.generate_content.return_value = resp + return client + + +class TestBuildCombined: + def test_joins_non_empty(self) -> None: + out = _build_combined({"Problem": "p", "Solution": "s", "Empty": " "}) + assert "Problem:\np" in out + assert "Solution:\ns" in out + assert "Empty" not in out + + +class TestParseVerdict: + def test_full_fields(self) -> None: + raw = ( + "Verdict: DUPLICATE\n" + "Title: Some Idea\n" + "Source: https://x.com/a\n" + "URL: https://x.com/a\n" + "Matching sentence: a beam alarm\n" + "Reason: same mechanism" + ) + out = _parse_verdict(raw) + assert out["verdict"] == "DUPLICATE" + assert out["match_title"] == "Some Idea" + assert out["source_url"] == "https://x.com/a" + assert out["matching_sentence"] == "a beam alarm" + assert out["reason"] == "same mechanism" + + def test_unique_verdict_only(self) -> 
None: + out = _parse_verdict("Verdict: UNIQUE\nReason: nothing matches") + assert out["verdict"] == "UNIQUE" + assert out["match_title"] is None + + def test_regex_fallback_when_key_missing(self) -> None: + out = _parse_verdict("The result is clearly OVERLAP here.") + assert out["verdict"] == "OVERLAP" + + def test_no_verdict_stays_empty(self) -> None: + out = _parse_verdict("no decision present") + assert out["verdict"] == "" + + +class TestRunDuplicateDetection: + def test_vague_short_circuits(self) -> None: + client = _vague_client(True, "too vague") + result = run_duplicate_detection( + row_idx=0, + row={"Problem": "x"}, + columns=["Problem"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "VAGUE" + assert result["reason"] == "too vague" + # Only the vague check is called; no file-search second call. + assert client.models.generate_content.call_count == 1 + + def test_not_vague_runs_file_search(self) -> None: + client = MagicMock() + vague_resp = MagicMock() + vague_resp.text = json.dumps({"vague": False, "reason": ""}) + search_resp = MagicMock() + search_resp.text = "Verdict: UNIQUE\nReason: novel" + client.models.generate_content.side_effect = [vague_resp, search_resp] + + result = run_duplicate_detection( + row_idx=1, + row={"Problem": "p", "Solution": "s"}, + columns=["Problem", "Solution"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "UNIQUE" + assert result["reason"] == "novel" + assert result["row_id"] == "row_1" + + def test_file_search_error_returns_error_verdict(self) -> None: + client = MagicMock() + vague_resp = MagicMock() + vague_resp.text = json.dumps({"vague": False, "reason": ""}) + client.models.generate_content.side_effect = [ + vague_resp, + RuntimeError("search boom"), + ] + + result = run_duplicate_detection( + row_idx=2, + row={"Problem": "p"}, + columns=["Problem"], + gemini_client=client, + 
model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "ERROR" + assert "search boom" in result["reason"] + + def test_vague_check_parse_error_defaults_not_vague(self) -> None: + client = MagicMock() + bad_vague = MagicMock() + bad_vague.text = "not json" + search_resp = MagicMock() + search_resp.text = "Verdict: PARTIAL_MATCH\nTitle: T\nReason: theme" + client.models.generate_content.side_effect = [bad_vague, search_resp] + + result = run_duplicate_detection( + row_idx=3, + row={"Problem": "p"}, + columns=["Problem"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "PARTIAL_MATCH" diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py new file mode 100644 index 000000000..faa64693e --- /dev/null +++ b/backend/app/tests/assessment/test_pipeline.py @@ -0,0 +1,151 @@ +"""Tests for the L1 pipeline orchestrator.""" + +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +from app.services.assessment.l1.pipeline import run_l1_pipeline + + +def _run() -> MagicMock: + run = MagicMock() + run.id = 99 + return run + + +def _tr(verdict: bool, decision: str = "ACCEPT") -> dict: + return { + "row_id": "row", + "verdict": verdict, + "decision": decision, + "column_relevance": {"Problem": verdict}, + "reasoning": "r", + } + + +def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): + """Patch the pipeline's external deps; return the TR mock.""" + client = MagicMock() + stack.enter_context( + patch( + "app.services.assessment.l1.pipeline.GeminiClient.from_credentials", + return_value=MagicMock(client=client), + ) + ) + stack.enter_context( + patch( + "app.services.assessment.l1.pipeline.get_cloud_storage", + return_value=MagicMock(), + ) + ) + stack.enter_context( + patch( + "app.services.assessment.l1.pipeline.upload_jsonl_to_object_store", + return_value="s3://l1.json", + ) + ) + stack.enter_context( + 
patch("app.crud.assessment.core.update_assessment_run_l1_stats") + ) + tr_mock = stack.enter_context( + patch("app.services.assessment.l1.pipeline.run_topic_relevance") + ) + if tr_side is not None: + tr_mock.side_effect = tr_side + dup_mock = stack.enter_context( + patch("app.services.assessment.l1.pipeline.run_duplicate_detection") + ) + if dup_return is not None: + dup_mock.return_value = dup_return + return tr_mock, dup_mock + + +class TestRunL1Pipeline: + def test_no_filters_configured_passthrough(self) -> None: + rows = [{"Problem": "a"}, {"Problem": "b"}] + passed, indices, results = run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={}, + session=MagicMock(), + organization_id=1, + project_id=1, + ) + assert passed == rows + assert indices == [0, 1] + assert results == [] + + def test_topic_relevance_filters_rejected_rows(self) -> None: + rows = [{"Problem": "keep"}, {"Problem": "drop"}, {"Problem": "keep2"}] + # idx 1 rejected. + side = [_tr(True), _tr(False, "REJECT"), _tr(True)] + with ExitStack() as stack: + _patches(stack, tr_side=side) + passed, indices, results = run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={ + "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"} + }, + session=MagicMock(), + organization_id=1, + project_id=1, + ) + assert indices == [0, 2] + assert [r["Problem"] for r in passed] == ["keep", "keep2"] + assert len(results) == 3 + assert results[1]["l1_passed"] is False + + def test_duplicate_detection_runs_on_passed_rows(self) -> None: + rows = [{"Problem": "a", "Solution": "b"}] + dup = { + "row_id": "row_0", + "verdict": "UNIQUE", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": "novel", + } + with ExitStack() as stack: + tr_mock, dup_mock = _patches(stack, tr_side=[_tr(True)], dup_return=dup) + _, _, results = run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={ + "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, + "duplicate_detection": {"columns": 
["Problem", "Solution"]}, + }, + session=MagicMock(), + organization_id=1, + project_id=1, + ) + dup_mock.assert_called_once() + assert results[0]["duplicate_detection"]["verdict"] == "UNIQUE" + + def test_attachment_columns_filtered_to_selection(self) -> None: + from app.models.assessment import AssessmentAttachment + + rows = [{"Problem": "a", "Docs": "x", "Other": "y"}] + atts = [ + AssessmentAttachment(column="Docs", type="image", format="url"), + AssessmentAttachment(column="Other", type="image", format="url"), + ] + with ExitStack() as stack: + tr_mock, _ = _patches(stack, tr_side=[_tr(True)]) + run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={ + "topic_relevance": { + "columns": ["Problem"], + "prompt": "rubric", + "attachment_columns": ["Docs"], + } + }, + session=MagicMock(), + organization_id=1, + project_id=1, + attachments=atts, + ) + # run_topic_relevance is called with only the selected attachment ("Docs"). + passed_atts = tr_mock.call_args.args[6] + assert [a.column for a in passed_atts] == ["Docs"] diff --git a/backend/app/tests/assessment/test_post_processing.py b/backend/app/tests/assessment/test_post_processing.py new file mode 100644 index 000000000..0ee7b81cc --- /dev/null +++ b/backend/app/tests/assessment/test_post_processing.py @@ -0,0 +1,212 @@ +"""Tests for the assessment export post-processing engine.""" + +from app.services.assessment.utils.post_processing import ( + apply_computed_columns, + apply_filter, + apply_post_processing, + apply_sort, + evaluate_formula, +) + + +class TestEvaluateFormula: + def test_addition(self) -> None: + assert evaluate_formula("@a + @b", {"a": 2, "b": 3}) == 5.0 + + def test_all_operators(self) -> None: + row = {"a": 10, "b": 4} + assert evaluate_formula("@a - @b", row) == 6.0 + assert evaluate_formula("@a * @b", row) == 40.0 + assert evaluate_formula("@a / @b", row) == 2.5 + assert evaluate_formula("-@a", row) == -10.0 + + def test_precedence_and_constants(self) -> None: + assert 
evaluate_formula("@a + @b * 0.5", {"a": 1, "b": 4}) == 3.0 + + def test_string_numeric_values_coerced(self) -> None: + assert evaluate_formula("@a + @b", {"a": "2", "b": "3"}) == 5.0 + + def test_missing_column_is_zero(self) -> None: + assert evaluate_formula("@a + @b", {"a": 5}) == 5.0 + + def test_non_numeric_value_is_zero(self) -> None: + assert evaluate_formula("@a + @b", {"a": 5, "b": "abc"}) == 5.0 + + def test_unsupported_operation_returns_none(self) -> None: + # Power operator is not in the safe-ops allowlist. + assert evaluate_formula("@a ** @b", {"a": 2, "b": 3}) is None + + def test_syntax_error_returns_none(self) -> None: + assert evaluate_formula("@a +", {"a": 1}) is None + + +class TestApplyComputedColumns: + def test_adds_column_in_place(self) -> None: + rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + apply_computed_columns(rows, [{"name": "total", "formula": "@a + @b"}]) + assert rows[0]["total"] == 3.0 + assert rows[1]["total"] == 7.0 + + def test_skips_empty_name_or_formula(self) -> None: + rows = [{"a": 1}] + apply_computed_columns( + rows, + [ + {"name": "", "formula": "@a"}, + {"name": "x", "formula": ""}, + ], + ) + assert rows[0] == {"a": 1} + + +class TestApplyFilter: + def test_no_rules_returns_all(self) -> None: + rows = [{"a": 1}, {"a": 2}] + assert apply_filter(rows, []) == rows + + def test_eq_ne(self) -> None: + rows = [{"x": "Yes"}, {"x": "no"}] + assert apply_filter(rows, [{"column": "x", "op": "eq", "value": "yes"}]) == [ + {"x": "Yes"} + ] + assert apply_filter(rows, [{"column": "x", "op": "ne", "value": "yes"}]) == [ + {"x": "no"} + ] + + def test_contains_not_contains(self) -> None: + rows = [{"x": "hello world"}, {"x": "bye"}] + assert apply_filter( + rows, [{"column": "x", "op": "contains", "value": "world"}] + ) == [{"x": "hello world"}] + assert apply_filter( + rows, [{"column": "x", "op": "not_contains", "value": "world"}] + ) == [{"x": "bye"}] + + def test_in_not_in(self) -> None: + rows = [{"x": "a"}, {"x": "b"}] + assert 
apply_filter( + rows, [{"column": "x", "op": "in", "value": ["a", "c"]}] + ) == [{"x": "a"}] + assert apply_filter( + rows, [{"column": "x", "op": "not_in", "value": ["a", "c"]}] + ) == [{"x": "b"}] + + def test_is_empty_is_not_empty(self) -> None: + rows = [{"x": ""}, {"x": "v"}, {"x": None}] + assert apply_filter(rows, [{"column": "x", "op": "is_empty"}]) == [ + {"x": ""}, + {"x": None}, + ] + assert apply_filter(rows, [{"column": "x", "op": "is_not_empty"}]) == [ + {"x": "v"} + ] + + def test_numeric_comparisons(self) -> None: + rows = [{"n": 1}, {"n": 5}, {"n": 10}] + assert apply_filter(rows, [{"column": "n", "op": "gt", "value": 4}]) == [ + {"n": 5}, + {"n": 10}, + ] + assert apply_filter(rows, [{"column": "n", "op": "lt", "value": 5}]) == [ + {"n": 1} + ] + assert apply_filter(rows, [{"column": "n", "op": "gte", "value": 5}]) == [ + {"n": 5}, + {"n": 10}, + ] + assert apply_filter(rows, [{"column": "n", "op": "lte", "value": 5}]) == [ + {"n": 1}, + {"n": 5}, + ] + + def test_numeric_filter_non_numeric_excluded(self) -> None: + rows = [{"n": "abc"}, {"n": 5}] + assert apply_filter(rows, [{"column": "n", "op": "gt", "value": 1}]) == [ + {"n": 5} + ] + + def test_unknown_op_keeps_row(self) -> None: + rows = [{"x": "a"}] + assert apply_filter(rows, [{"column": "x", "op": "weird", "value": 1}]) == rows + + def test_and_logic_across_rules(self) -> None: + rows = [{"n": 5, "x": "yes"}, {"n": 5, "x": "no"}, {"n": 1, "x": "yes"}] + out = apply_filter( + rows, + [ + {"column": "n", "op": "gte", "value": 5}, + {"column": "x", "op": "eq", "value": "yes"}, + ], + ) + assert out == [{"n": 5, "x": "yes"}] + + +class TestApplySort: + def test_no_rules_returns_input(self) -> None: + rows = [{"n": 2}, {"n": 1}] + assert apply_sort(rows, []) == rows + + def test_numeric_asc_desc(self) -> None: + rows = [{"n": 3}, {"n": 1}, {"n": 2}] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "asc"}]) + ] == [1, 2, 3] + assert [ + r["n"] for r in 
apply_sort(rows, [{"column": "n", "direction": "desc"}]) + ] == [3, 2, 1] + + def test_none_values_sort_last(self) -> None: + rows = [{"n": None}, {"n": 2}, {"n": 1}] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "asc"}]) + ] == [1, 2, None] + + def test_string_asc_desc(self) -> None: + rows = [{"s": "banana"}, {"s": "apple"}, {"s": "cherry"}] + assert [ + r["s"] for r in apply_sort(rows, [{"column": "s", "direction": "asc"}]) + ] == ["apple", "banana", "cherry"] + assert [ + r["s"] for r in apply_sort(rows, [{"column": "s", "direction": "desc"}]) + ] == ["cherry", "banana", "apple"] + + def test_multi_rule_priority(self) -> None: + rows = [ + {"grp": "a", "n": 2}, + {"grp": "b", "n": 1}, + {"grp": "a", "n": 1}, + ] + out = apply_sort( + rows, + [ + {"column": "grp", "direction": "asc"}, + {"column": "n", "direction": "desc"}, + ], + ) + assert out == [ + {"grp": "a", "n": 2}, + {"grp": "a", "n": 1}, + {"grp": "b", "n": 1}, + ] + + +class TestApplyPostProcessing: + def test_none_config_is_noop(self) -> None: + rows = [{"a": 1}] + assert apply_post_processing(rows, None) is rows + + def test_full_pipeline(self) -> None: + rows = [ + {"Novelty": 3, "Feasibility": 4}, + {"Novelty": 9, "Feasibility": 8}, + {"Novelty": 1, "Feasibility": 1}, + ] + config = { + "computed_columns": [ + {"name": "Total", "formula": "@Novelty + @Feasibility"} + ], + "filter": [{"column": "Total", "op": "gt", "value": 5}], + "sort": [{"column": "Total", "direction": "desc"}], + } + out = apply_post_processing(rows, config) + assert [r["Total"] for r in out] == [17.0, 7.0] From e020717949fe8b1a5b8bdf1d09207c1b6ceb7e23 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:28:06 +0530 Subject: [PATCH 08/16] feat: implement prefilter pipeline with topic relevance and duplicate detection - Added a new prefilter pipeline orchestrator that runs topic relevance and duplicate detection filters in series. 
- Created `run_topic_relevance` and `run_duplicate_detection` functions to handle respective filtering logic. - Updated assessment service to utilize prefilter configuration instead of L1 configuration. - Modified assessment tasks to reflect the new prefilter processing status and error handling. - Adjusted utility functions and export logic to accommodate prefilter results. - Enhanced tests to cover the new prefilter functionality and ensure proper integration. --- ...dd_prefilter_columns_to_assessment_run.py} | 24 ++--- backend/app/api/routes/assessment/runs.py | 6 +- backend/app/celery/tasks/job_execution.py | 5 +- backend/app/core/config.py | 13 ++- backend/app/crud/assessment/__init__.py | 4 +- backend/app/crud/assessment/batch.py | 2 +- backend/app/crud/assessment/core.py | 42 ++++----- backend/app/models/assessment.py | 36 ++++---- .../app/services/assessment/l1/__init__.py | 3 - .../services/assessment/prefilter/__init__.py | 3 + .../{l1 => prefilter}/duplicate_detection.py | 10 ++- .../assessment/{l1 => prefilter}/pipeline.py | 87 +++++++++--------- .../{l1 => prefilter}/topic_relevance.py | 6 +- backend/app/services/assessment/service.py | 8 +- backend/app/services/assessment/tasks.py | 36 ++++---- .../services/assessment/utils/attachments.py | 55 +++++++----- .../app/services/assessment/utils/export.py | 90 ++++++++++--------- backend/app/tests/assessment/test_batch.py | 46 ++++++++++ backend/app/tests/assessment/test_crud.py | 46 +++++----- .../assessment/test_duplicate_detection.py | 4 +- backend/app/tests/assessment/test_pipeline.py | 36 ++++---- .../tests/assessment/test_topic_relevance.py | 4 +- 22 files changed, 328 insertions(+), 238 deletions(-) rename backend/app/alembic/versions/{064_add_l1_columns_to_assessment_run.py => 064_add_prefilter_columns_to_assessment_run.py} (60%) delete mode 100644 backend/app/services/assessment/l1/__init__.py create mode 100644 backend/app/services/assessment/prefilter/__init__.py rename 
backend/app/services/assessment/{l1 => prefilter}/duplicate_detection.py (95%) rename backend/app/services/assessment/{l1 => prefilter}/pipeline.py (67%) rename backend/app/services/assessment/{l1 => prefilter}/topic_relevance.py (94%) diff --git a/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py similarity index 60% rename from backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py rename to backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py index bce33e6cd..1720e21b4 100644 --- a/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py +++ b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py @@ -1,4 +1,4 @@ -"""Add L1 pipeline columns to assessment_run +"""Add prefilter pipeline columns to assessment_run Revision ID: 064 Revises: 063 @@ -19,25 +19,25 @@ def upgrade() -> None: op.add_column( "assessment_run", sa.Column( - "l1_object_store_url", + "prefilter_object_store_url", sa.String(), nullable=True, - comment="S3 URL of stored L1 filter results JSON", + comment="S3 URL of stored prefilter filter results JSON", ), ) op.add_column( "assessment_run", sa.Column( - "l1_total_rows", + "prefilter_total_rows", sa.Integer(), nullable=True, - comment="Total rows fed into L1 pipeline", + comment="Total rows fed into prefilter pipeline", ), ) op.add_column( "assessment_run", sa.Column( - "l1_total_passed", + "prefilter_total_passed", sa.Integer(), nullable=True, comment="Rows that passed topic relevance and went to L2", @@ -46,16 +46,16 @@ def upgrade() -> None: op.add_column( "assessment_run", sa.Column( - "l1_total_rejected", + "prefilter_total_rejected", sa.Integer(), nullable=True, - comment="Rows rejected by topic relevance, stopped at L1", + comment="Rows rejected by topic relevance, stopped at prefilter", ), ) def downgrade() -> None: - op.drop_column("assessment_run", "l1_total_rejected") - 
op.drop_column("assessment_run", "l1_total_passed") - op.drop_column("assessment_run", "l1_total_rows") - op.drop_column("assessment_run", "l1_object_store_url") + op.drop_column("assessment_run", "prefilter_total_rejected") + op.drop_column("assessment_run", "prefilter_total_passed") + op.drop_column("assessment_run", "prefilter_total_rows") + op.drop_column("assessment_run", "prefilter_object_store_url") diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 3c3abd57a..2825e5c86 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -67,9 +67,9 @@ def _build_run_public( total_items=run.total_items, error_message=run.error_message, input=run.input, - l1_total_rows=run.l1_total_rows, - l1_total_passed=run.l1_total_passed, - l1_total_rejected=run.l1_total_rejected, + prefilter_total_rows=run.prefilter_total_rows, + prefilter_total_passed=run.prefilter_total_passed, + prefilter_total_rejected=run.prefilter_total_rejected, post_processing_config=(run.input or {}).get("post_processing_config"), inserted_at=run.inserted_at, updated_at=run.updated_at, diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index ec7ad1bd0..6a249a92e 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -232,9 +232,8 @@ def run_tts_batch_submission( ) -@celery_app.task( - bind=True, queue="low_priority", priority=1, soft_time_limit=1800, time_limit=2100 -) +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.ASSESSMENT_RUN_SOFT_TIME_LIMIT, "run_assessment_run") def run_assessment_run( self, run_id: int, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 60504147b..e2e3a5ff1 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -172,11 +172,18 @@ def AWS_S3_BUCKET(self) -> str: PENDING_JOB_QUERY_TIMEOUT_MS: int = 
1000 # Assessment - ASSESSMENT_L1_GEMINI_MODEL: str = "gemini-3.1-flash-lite" - ASSESSMENT_L1_CONCURRENT_WORKERS: int = 8 - ASSESSMENT_L1_DUPLICATE_STORE_NAME: str = ( + ASSESSMENT_PREFILTER_GEMINI_MODEL: str = "gemini-3.1-flash-lite" + ASSESSMENT_PREFILTER_CONCURRENT_WORKERS: int = 8 + ASSESSMENT_PREFILTER_DUPLICATE_STORE_NAME: str = ( "fileSearchStores/inquilabcorpus-782mxjcwisaz" ) + # Soft timeout for the full assessment run task (prefilter pipeline + batch + # submission). Larger than the default task limit because prefilter runs many + # concurrent LLM calls over the whole dataset. Seconds. Default 2 hours. + ASSESSMENT_RUN_SOFT_TIME_LIMIT: int = 7200 + # Timeout for prefilter Gemini calls to prevent pipeline stalls from slow/hung requests + # (default: 2 minutes, in ms) + ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS: int = 120000 @computed_field # type: ignore[prop-decorator] @property diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py index 8e623e3a7..2f5f6f217 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -13,7 +13,7 @@ list_assessment_runs, list_assessments, recompute_assessment_status, - update_assessment_run_l1_stats, + update_assessment_run_prefilter_stats, update_assessment_run_status, update_run_post_processing_config, ) @@ -44,7 +44,7 @@ "list_assessment_datasets", "list_assessments", "recompute_assessment_status", - "update_assessment_run_l1_stats", + "update_assessment_run_prefilter_stats", "update_assessment_run_status", "update_run_post_processing_config", ] diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index 531dc038d..5debc3156 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -322,7 +322,7 @@ def submit_assessment_batch( output_schema = assessment_input.get("output_schema") attachments = [AssessmentAttachment(**a) for a in attachments_raw] - # Use preloaded rows 
(post-L1 filtered) if provided, else load from dataset. + # Use preloaded rows (post-prefilter filtered) if provided, else load from dataset. if preloaded_rows is not None: rows = preloaded_rows else: diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index d5a184d06..547cc2e31 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -248,25 +248,25 @@ def update_assessment_run_status( return run -def update_assessment_run_l1_stats( +def update_assessment_run_prefilter_stats( session: Session, run: AssessmentRun, - l1_object_store_url: str | None = None, - l1_total_rows: int | None = None, - l1_total_passed: int | None = None, - l1_total_rejected: int | None = None, + prefilter_object_store_url: str | None = None, + prefilter_total_rows: int | None = None, + prefilter_total_passed: int | None = None, + prefilter_total_rejected: int | None = None, ) -> AssessmentRun: - """Persist L1 result stats (rows/passed/rejected + S3 URL) on a run.""" + """Persist prefilter result stats (rows/passed/rejected + S3 URL) on a run.""" run.updated_at = now() - if l1_object_store_url is not None: - run.l1_object_store_url = l1_object_store_url - if l1_total_rows is not None: - run.l1_total_rows = l1_total_rows - if l1_total_passed is not None: - run.l1_total_passed = l1_total_passed - if l1_total_rejected is not None: - run.l1_total_rejected = l1_total_rejected + if prefilter_object_store_url is not None: + run.prefilter_object_store_url = prefilter_object_store_url + if prefilter_total_rows is not None: + run.prefilter_total_rows = prefilter_total_rows + if prefilter_total_passed is not None: + run.prefilter_total_passed = prefilter_total_passed + if prefilter_total_rejected is not None: + run.prefilter_total_rejected = prefilter_total_rejected session.add(run) try: @@ -274,16 +274,18 @@ def update_assessment_run_l1_stats( session.refresh(run) except Exception as e: session.rollback() - 
logger.error(f"[update_assessment_run_l1_stats] Failed: {e}", exc_info=True) + logger.error( + f"[update_assessment_run_prefilter_stats] Failed: {e}", exc_info=True + ) raise return run _ACTIVE_RUN_STATUSES = frozenset( - {"l1_processing", "l2_processing", "processing", "in_progress"} + {"prefilter_processing", "l2_processing", "processing", "in_progress"} ) -_FAILED_RUN_STATUSES = frozenset({"failed", "l1_failed"}) +_FAILED_RUN_STATUSES = frozenset({"failed", "prefilter_failed"}) _COMPLETED_RUN_STATUSES = frozenset({"completed", "completed_with_errors"}) @@ -329,9 +331,9 @@ def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: total_items=run.total_items, error_message=run.error_message, updated_at=run.updated_at, - l1_total_rows=run.l1_total_rows, - l1_total_passed=run.l1_total_passed, - l1_total_rejected=run.l1_total_rejected, + prefilter_total_rows=run.prefilter_total_rows, + prefilter_total_passed=run.prefilter_total_passed, + prefilter_total_rejected=run.prefilter_total_rejected, ) for run in runs ] diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index b5a1a31f5..8ff468db2 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -109,7 +109,7 @@ class AssessmentRun(SQLModel, table=True): default="pending", sa_column_kwargs={ "comment": ( - "Unified pipeline status: pending, l1_processing, l1_failed, " + "Unified pipeline status: pending, prefilter_processing, prefilter_failed, " "l2_processing, completed, completed_with_errors, failed" ) }, @@ -141,25 +141,27 @@ class AssessmentRun(SQLModel, table=True): nullable=True, sa_column_kwargs={"comment": "S3 URL of processed L2 batch results"}, ) - l1_object_store_url: str | None = SQLField( + prefilter_object_store_url: str | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "S3 URL of stored L1 filter results JSON"}, + sa_column_kwargs={"comment": "S3 URL of stored prefilter filter results JSON"}, ) - 
l1_total_rows: int | None = SQLField( + prefilter_total_rows: int | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "Total rows fed into L1 pipeline"}, + sa_column_kwargs={"comment": "Total rows fed into prefilter pipeline"}, ) - l1_total_passed: int | None = SQLField( + prefilter_total_passed: int | None = SQLField( default=None, nullable=True, sa_column_kwargs={"comment": "Rows that passed topic relevance and went to L2"}, ) - l1_total_rejected: int | None = SQLField( + prefilter_total_rejected: int | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "Rows rejected by topic relevance, stopped at L1"}, + sa_column_kwargs={ + "comment": "Rows rejected by topic relevance, stopped at prefilter" + }, ) error_message: str | None = SQLField( default=None, @@ -208,9 +210,9 @@ class AssessmentRunStat(BaseModel): total_items: int error_message: str | None = None updated_at: datetime | None = None - l1_total_rows: int | None = None - l1_total_passed: int | None = None - l1_total_rejected: int | None = None + prefilter_total_rows: int | None = None + prefilter_total_passed: int | None = None + prefilter_total_rejected: int | None = None class AssessmentPublic(BaseModel): @@ -250,9 +252,9 @@ class AssessmentRunPublic(BaseModel): "text_columns, attachments, output_schema" ), ) - l1_total_rows: int | None = None - l1_total_passed: int | None = None - l1_total_rejected: int | None = None + prefilter_total_rows: int | None = None + prefilter_total_passed: int | None = None + prefilter_total_rejected: int | None = None post_processing_config: dict[str, Any] | None = None inserted_at: datetime updated_at: datetime @@ -323,11 +325,11 @@ class AssessmentCreate(BaseModel): configs: list[AssessmentConfigRef] = Field( ..., min_length=1, max_length=4, description="Config versions to run" ) - l1_config: dict[str, Any] | None = Field( + prefilter_config: dict[str, Any] | None = Field( None, description=( - "L1 pipeline config. 
Keys: topic_relevance (columns, prompt), " - "duplicate_detection (columns). Omit to skip L1." + "prefilter pipeline config. Keys: topic_relevance (columns, prompt), " + "duplicate_detection (columns). Omit to skip prefilter." ), ) post_processing_config: dict[str, Any] | None = Field( diff --git a/backend/app/services/assessment/l1/__init__.py b/backend/app/services/assessment/l1/__init__.py deleted file mode 100644 index 66e3a0374..000000000 --- a/backend/app/services/assessment/l1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from app.services.assessment.l1.pipeline import run_l1_pipeline - -__all__ = ["run_l1_pipeline"] diff --git a/backend/app/services/assessment/prefilter/__init__.py b/backend/app/services/assessment/prefilter/__init__.py new file mode 100644 index 000000000..6cd16dce2 --- /dev/null +++ b/backend/app/services/assessment/prefilter/__init__.py @@ -0,0 +1,3 @@ +from app.services.assessment.prefilter.pipeline import run_prefilter_pipeline + +__all__ = ["run_prefilter_pipeline"] diff --git a/backend/app/services/assessment/l1/duplicate_detection.py b/backend/app/services/assessment/prefilter/duplicate_detection.py similarity index 95% rename from backend/app/services/assessment/l1/duplicate_detection.py rename to backend/app/services/assessment/prefilter/duplicate_detection.py index 608389c1d..bba004457 100644 --- a/backend/app/services/assessment/l1/duplicate_detection.py +++ b/backend/app/services/assessment/prefilter/duplicate_detection.py @@ -1,4 +1,4 @@ -"""Duplicate detection filter for L1 pipeline.""" +"""Duplicate detection filter for prefilter pipeline.""" import json import logging @@ -8,6 +8,8 @@ from google import genai from google.genai import types +from app.core.config import settings + logger = logging.getLogger(__name__) _VAGUE_SYS = """ @@ -78,6 +80,9 @@ def _check_vague( system_instruction=_VAGUE_SYS, response_mime_type="application/json", temperature=0.0, + http_options=types.HttpOptions( + 
timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS + ), ), ) parsed = json.loads((response.text or "").strip()) @@ -104,6 +109,9 @@ def _call_file_search( ) ], temperature=0.0, + http_options=types.HttpOptions( + timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS + ), ), ) return response.text or "" diff --git a/backend/app/services/assessment/l1/pipeline.py b/backend/app/services/assessment/prefilter/pipeline.py similarity index 67% rename from backend/app/services/assessment/l1/pipeline.py rename to backend/app/services/assessment/prefilter/pipeline.py index 18df91324..131fdde8b 100644 --- a/backend/app/services/assessment/l1/pipeline.py +++ b/backend/app/services/assessment/prefilter/pipeline.py @@ -1,4 +1,4 @@ -"""L1 pipeline orchestrator. +"""prefilter pipeline orchestrator. Runs two filters in series for each row: 1. Topic Relevance (go/no-go) — REJECT stops the row. @@ -18,20 +18,22 @@ from app.core.cloud import get_cloud_storage from app.core.storage_utils import upload_jsonl_to_object_store from app.models.assessment import AssessmentAttachment, AssessmentRun -from app.services.assessment.l1.duplicate_detection import run_duplicate_detection -from app.services.assessment.l1.topic_relevance import run_topic_relevance +from app.services.assessment.prefilter.duplicate_detection import ( + run_duplicate_detection, +) +from app.services.assessment.prefilter.topic_relevance import run_topic_relevance logger = logging.getLogger(__name__) -def _build_l1_result( +def _build_prefilter_result( row_idx: int, tr_result: dict[str, Any] | None, dup_result: dict[str, Any] | None, ) -> dict[str, Any]: return { "row_id": f"row_{row_idx}", - "l1_passed": tr_result["verdict"] if tr_result else True, + "prefilter_passed": tr_result["verdict"] if tr_result else True, "topic_relevance": { "decision": tr_result["decision"], "column_relevance": tr_result.get("column_relevance") or {}, @@ -43,37 +45,40 @@ def _build_l1_result( } -def run_l1_pipeline( +def 
run_prefilter_pipeline( run: AssessmentRun, rows: list[dict[str, str]], - l1_config: dict[str, Any], + prefilter_config: dict[str, Any], session: Session, organization_id: int, project_id: int, attachments: list[AssessmentAttachment] | None = None, ) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: - """Run L1 filters on all rows. + """Run prefilter filters on all rows. Args: run: The AssessmentRun record (used for S3 path and DB update). rows: Full dataset rows loaded from object store. - l1_config: User-supplied config with topic_relevance and duplicate_detection keys. + prefilter_config: User-supplied config with topic_relevance and duplicate_detection keys. session: DB session. organization_id: For Gemini credential lookup. project_id: For Gemini credential lookup and S3 storage. Returns: - (passed_rows, passed_indices, all_l1_results) + (passed_rows, passed_indices, all_prefilter_results) passed_rows: subset of rows where topic_relevance verdict=true. passed_indices: original dataset indices of passed_rows (used to preserve row IDs in L2). - all_l1_results: one entry per input row (len == len(rows)). + all_prefilter_results: one entry per input row (len == len(rows)). """ - model = settings.ASSESSMENT_L1_GEMINI_MODEL - workers = settings.ASSESSMENT_L1_CONCURRENT_WORKERS - store_name = settings.ASSESSMENT_L1_DUPLICATE_STORE_NAME + model = settings.ASSESSMENT_PREFILTER_GEMINI_MODEL + workers = settings.ASSESSMENT_PREFILTER_CONCURRENT_WORKERS + store_name = settings.ASSESSMENT_PREFILTER_DUPLICATE_STORE_NAME + # Future wait bound: the per-request HTTP timeout plus a small margin so a + # hung Gemini call surfaces as a future error instead of blocking forever. 
+ future_timeout = settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS / 1000 + 30 - tr_config = l1_config.get("topic_relevance") or {} - dup_config = l1_config.get("duplicate_detection") or {} + tr_config = prefilter_config.get("topic_relevance") or {} + dup_config = prefilter_config.get("duplicate_detection") or {} tr_columns: list[str] = tr_config.get("columns") or [] tr_prompt: str = tr_config.get("prompt") or "" @@ -91,7 +96,7 @@ def run_l1_pipeline( if not tr_enabled and not dup_enabled: logger.warning( - "[run_l1_pipeline] run_id=%s — no L1 filters configured, skipping L1", + "[run_prefilter_pipeline] run_id=%s — no prefilter filters configured, skipping prefilter", run.id, ) return rows, list(range(len(rows))), [] @@ -103,7 +108,7 @@ def run_l1_pipeline( ).client logger.info( - "[run_l1_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", + "[run_prefilter_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", run.id, len(rows), model, @@ -136,10 +141,10 @@ def run_l1_pipeline( for fut in as_completed(futs): idx = futs[fut] try: - tr_results[idx] = fut.result() + tr_results[idx] = fut.result(timeout=future_timeout) except Exception as exc: logger.warning( - "[run_l1_pipeline] TR future error row_%s | %s", idx, exc + "[run_prefilter_pipeline] TR future error row_%s | %s", idx, exc ) tr_results[idx] = { "row_id": f"row_{idx}", @@ -156,7 +161,7 @@ def run_l1_pipeline( rejected_count = len(rows) - len(passed_indices) logger.info( - "[run_l1_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", + "[run_prefilter_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", run.id, len(passed_indices), rejected_count, @@ -180,10 +185,12 @@ def run_l1_pipeline( for fut in as_completed(futs): idx = futs[fut] try: - dup_results[idx] = fut.result() + dup_results[idx] = fut.result(timeout=future_timeout) except Exception as exc: logger.warning( - "[run_l1_pipeline] DUP future error row_%s | %s", idx, exc + "[run_prefilter_pipeline] 
DUP future error row_%s | %s", + idx, + exc, ) dup_results[idx] = { "row_id": f"row_{idx}", @@ -194,45 +201,45 @@ def run_l1_pipeline( "reason": str(exc)[:200], } - all_l1_results: list[dict[str, Any]] = [ - _build_l1_result(idx, tr_results[idx], dup_results.get(idx)) + all_prefilter_results: list[dict[str, Any]] = [ + _build_prefilter_result(idx, tr_results[idx], dup_results.get(idx)) for idx in range(len(rows)) ] - l1_object_store_url: str | None = None + prefilter_object_store_url: str | None = None try: storage = get_cloud_storage(session=session, project_id=project_id) - l1_object_store_url = upload_jsonl_to_object_store( + prefilter_object_store_url = upload_jsonl_to_object_store( storage=storage, - results=all_l1_results, - filename="l1_results.json", - subdirectory=f"assessment/run-{run.id}/l1", + results=all_prefilter_results, + filename="prefilter_results.json", + subdirectory=f"assessment/run-{run.id}/prefilter", format="json", ) logger.info( - "[run_l1_pipeline] run_id=%s | L1 results uploaded to %s", + "[run_prefilter_pipeline] run_id=%s | prefilter results uploaded to %s", run.id, - l1_object_store_url, + prefilter_object_store_url, ) except Exception as exc: logger.error( - "[run_l1_pipeline] run_id=%s | S3 upload failed | %s", + "[run_prefilter_pipeline] run_id=%s | S3 upload failed | %s", run.id, exc, exc_info=True, ) - from app.crud.assessment.core import update_assessment_run_l1_stats + from app.crud.assessment.core import update_assessment_run_prefilter_stats - update_assessment_run_l1_stats( + update_assessment_run_prefilter_stats( session=session, run=run, - l1_object_store_url=l1_object_store_url, - l1_total_rows=len(rows), - l1_total_passed=len(passed_indices), - l1_total_rejected=rejected_count, + prefilter_object_store_url=prefilter_object_store_url, + prefilter_total_rows=len(rows), + prefilter_total_passed=len(passed_indices), + prefilter_total_rejected=rejected_count, ) sorted_passed_indices = sorted(passed_indices) passed_rows = 
[rows[idx] for idx in sorted_passed_indices] - return passed_rows, sorted_passed_indices, all_l1_results + return passed_rows, sorted_passed_indices, all_prefilter_results diff --git a/backend/app/services/assessment/l1/topic_relevance.py b/backend/app/services/assessment/prefilter/topic_relevance.py similarity index 94% rename from backend/app/services/assessment/l1/topic_relevance.py rename to backend/app/services/assessment/prefilter/topic_relevance.py index c1894c04e..053547ab7 100644 --- a/backend/app/services/assessment/l1/topic_relevance.py +++ b/backend/app/services/assessment/prefilter/topic_relevance.py @@ -1,4 +1,4 @@ -"""Topic relevance filter for L1 pipeline. +"""Topic relevance filter for prefilter pipeline. """ import json @@ -8,6 +8,7 @@ from google import genai from google.genai import types +from app.core.config import settings from app.models.assessment import AssessmentAttachment from app.services.assessment.utils.attachments import build_gemini_attachment_parts @@ -89,6 +90,9 @@ def run_topic_relevance( response_mime_type="application/json", response_schema=output_schema, temperature=0.0, + http_options=types.HttpOptions( + timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS + ), ), ) raw = (response.text or "").strip() diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index cabe2bb4c..b2e2cea05 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -79,7 +79,7 @@ def _build_retry_request( attachments=[AssessmentAttachment.model_validate(item) for item in attachments], output_schema=assessment_input.get("output_schema"), configs=configs, - l1_config=assessment_input.get("l1_config"), + prefilter_config=assessment_input.get("prefilter_config"), post_processing_config=assessment_input.get("post_processing_config"), ) @@ -93,7 +93,7 @@ def start_assessment( """Validate, create Assessment + AssessmentRun records, dispatch Celery tasks. 
Each run is created with status='pending' and handed off to a Celery worker - that runs L1 filtering then submits the L2 batch. + that runs prefilter filtering then submits the L2 batch. """ from app.celery.tasks.job_execution import run_assessment_run @@ -120,8 +120,8 @@ def start_assessment( } if request.output_schema: assessment_input["output_schema"] = request.output_schema - if request.l1_config: - assessment_input["l1_config"] = request.l1_config + if request.prefilter_config: + assessment_input["prefilter_config"] = request.prefilter_config if request.post_processing_config: assessment_input["post_processing_config"] = request.post_processing_config diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 295c55ad2..909a89a25 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -1,4 +1,4 @@ -"""Celery task logic for running a single assessment run (L1 → L2 batch submit).""" +"""Celery task logic for running a single assessment run (prefilter → L2 batch submit).""" import logging @@ -19,7 +19,7 @@ AssessmentRun, ) from app.models.config.config import ConfigTag -from app.services.assessment.l1 import run_l1_pipeline +from app.services.assessment.prefilter import run_prefilter_pipeline logger = logging.getLogger(__name__) @@ -29,12 +29,12 @@ def execute_assessment_run( organization_id: int, project_id: int, ) -> None: - """Run L1 filtering then submit L2 batch for one AssessmentRun. + """Run prefilter filtering then submit L2 batch for one AssessmentRun. 
Status transitions: - pending → l1_processing → l1_failed (stop) + pending → prefilter_processing → prefilter_failed (stop) → l2_processing → (cron handles rest) - pending → l2_processing (when no l1_config) + pending → l2_processing (when no prefilter_config) """ with Session(engine) as session: run = session.get(AssessmentRun, run_id) @@ -116,19 +116,19 @@ def execute_assessment_run( recompute_assessment_status(session=session, assessment_id=assessment.id) return - # L1 pipeline + # prefilter pipeline rows_for_l2 = all_rows row_indices_for_l2: list[int] | None = None - l1_config = assessment_input.get("l1_config") - if l1_config: + prefilter_config = assessment_input.get("prefilter_config") + if prefilter_config: update_assessment_run_status( - session=session, run=run, status="l1_processing" + session=session, run=run, status="prefilter_processing" ) try: - rows_for_l2, row_indices_for_l2, _ = run_l1_pipeline( + rows_for_l2, row_indices_for_l2, _ = run_prefilter_pipeline( run=run, rows=all_rows, - l1_config=l1_config, + prefilter_config=prefilter_config, session=session, organization_id=organization_id, project_id=project_id, @@ -138,28 +138,28 @@ def execute_assessment_run( ], ) logger.info( - "[execute_assessment_run] L1 done | run_id=%s | rows_to_l2=%s / %s", + "[execute_assessment_run] prefilter done | run_id=%s | rows_to_l2=%s / %s", run_id, len(rows_for_l2), len(all_rows), ) - except Exception as l1_exc: + except Exception as prefilter_exc: logger.error( - "[execute_assessment_run] L1 failed run_id=%s | %s", + "[execute_assessment_run] prefilter failed run_id=%s | %s", run_id, - l1_exc, + prefilter_exc, exc_info=True, ) update_assessment_run_status( session=session, run=run, - status="l1_failed", - error_message=f"L1 pipeline failed: {l1_exc}", + status="prefilter_failed", + error_message=f"prefilter pipeline failed: {prefilter_exc}", ) recompute_assessment_status( session=session, assessment_id=assessment.id ) - return # L2 does not run when L1 fails + 
return # L2 does not run when prefilter fails # L2 batch submit try: diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 3622f9bce..87ca3aba7 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -9,11 +9,12 @@ import logging import re from typing import Any -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import requests from app.models.assessment import AssessmentAttachment +from app.utils import validate_callback_url logger = logging.getLogger(__name__) @@ -177,33 +178,43 @@ def _type_from_content_type(content_type: str | None) -> str | None: return None +_PROBE_MAX_REDIRECTS = 3 + + def _probe_url_type(url: str, num_bytes: int = 16) -> str | None: """Probe a remote URL's type: ranged byte sniff first, Content-Type fallback. - - Reads only the first few bytes (does not download the whole file). Drive - share URLs are routed through the download endpoint so the real file bytes - are read instead of an HTML share page. 
- """ + Handles Google Drive URLs with the same logic as to_direct_attachment_url, since""" file_id = _drive_file_id(url) - probe_url = ( + current = ( f"https://drive.google.com/uc?export=download&id={file_id}" if file_id else url ) try: - with requests.get( - probe_url, - headers={"Range": f"bytes=0-{num_bytes - 1}"}, - timeout=10, - stream=True, - allow_redirects=True, - ) as resp: - resp.raise_for_status() - for chunk in resp.iter_content(chunk_size=num_bytes): - magic_type = _type_from_magic(chunk) - if magic_type: - return magic_type - break - return _type_from_content_type(resp.headers.get("Content-Type")) + for _ in range(_PROBE_MAX_REDIRECTS + 1): + validate_callback_url(current) + with requests.get( + current, + headers={"Range": f"bytes=0-{num_bytes - 1}"}, + timeout=10, + stream=True, + allow_redirects=False, + ) as resp: + location = resp.headers.get("Location") + if resp.is_redirect and location: + current = urljoin(current, location) + continue + resp.raise_for_status() + for chunk in resp.iter_content(chunk_size=num_bytes): + magic_type = _type_from_magic(chunk) + if magic_type: + return magic_type + break + return _type_from_content_type(resp.headers.get("Content-Type")) + logger.warning(f"[_probe_url_type] Too many redirects probing {url}") + return None + except ValueError as e: + logger.warning(f"[_probe_url_type] Blocked unsafe probe URL {url}: {e}") + return None except requests.RequestException as e: logger.warning(f"[_probe_url_type] Probe failed for {url}: {e}") return None @@ -312,7 +323,7 @@ def build_gemini_attachment_parts( """Convert one dataset cell into one or more Gemini content parts. Mirrors the per-item type detection used for the L2 batch so the same - image/pdf routing applies to L1 (topic relevance) calls. + image/pdf routing applies to prefilter (topic relevance) calls. 
""" value = value.strip() if not value: diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index 86d9186b0..39fa7691c 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -22,7 +22,7 @@ from app.services.assessment.utils.parsing import parse_stored_results, usage_totals from app.utils import APIResponse -_L1_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] +_PREFILTER_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] logger = logging.getLogger(__name__) @@ -36,23 +36,23 @@ def _load_dataset_rows( return load_dataset_rows(session, dataset) -def _load_l1_results( +def _load_prefilter_results( session: Session, run: AssessmentRun, assessment: Assessment, ) -> dict[str, dict[str, Any]]: - """Load L1 results from object store, keyed by row_id. Returns {} if unavailable.""" - if not run.l1_object_store_url: + """Load prefilter results from object store, keyed by row_id. 
Returns {} if unavailable.""" + if not run.prefilter_object_store_url: return {} try: storage = get_cloud_storage(session, project_id=assessment.project_id) - body = storage.stream(run.l1_object_store_url) + body = storage.stream(run.prefilter_object_store_url) raw = body.read().decode("utf-8") results: list[dict[str, Any]] = json.loads(raw) return {str(item["row_id"]): item for item in results if "row_id" in item} except Exception as exc: logger.warning( - "[_load_l1_results] Failed to load L1 results for run id=%s: %s", + "[_load_prefilter_results] Failed to load prefilter results for run id=%s: %s", run.id, exc, ) @@ -163,32 +163,34 @@ def _expand_output_columns( """ row_payload, input_col_names = _expand_input_columns(row_payload) - json_expand_cols = {"output", "input_data"} | set(_L1_JSON_COLUMNS) + json_expand_cols = {"output", "input_data"} | set(_PREFILTER_JSON_COLUMNS) base_fields = [ field for field in AssessmentExportRow.model_fields.keys() if field not in json_expand_cols ] - # L1 columns are prefixed with their parent name to avoid key collisions + # prefilter columns are prefixed with their parent name to avoid key collisions parsed_cols: dict[str, list[dict[str, Any] | None]] = { - col: [] for col in ["output"] + _L1_JSON_COLUMNS + col: [] for col in ["output"] + _PREFILTER_JSON_COLUMNS + } + col_keys: dict[str, list[str]] = { + col: [] for col in ["output"] + _PREFILTER_JSON_COLUMNS } - col_keys: dict[str, list[str]] = {col: [] for col in ["output"] + _L1_JSON_COLUMNS} col_seen: dict[str, dict[str, None]] = { - col: {} for col in ["output"] + _L1_JSON_COLUMNS + col: {} for col in ["output"] + _PREFILTER_JSON_COLUMNS } has_unparsed_output = False for row in row_payload: - for col in ["output"] + _L1_JSON_COLUMNS: + for col in ["output"] + _PREFILTER_JSON_COLUMNS: parsed = _parse_json_col(row.get(col)) if parsed is None and col == "output" and row.get(col) is not None: has_unparsed_output = True parsed_cols[col].append(parsed) if parsed: for k in 
parsed: - prefixed = f"{col}_{k}" if col in _L1_JSON_COLUMNS else k + prefixed = f"{col}_{k}" if col in _PREFILTER_JSON_COLUMNS else k if prefixed not in col_seen[col]: col_seen[col][prefixed] = None col_keys[col].append(prefixed) @@ -196,7 +198,7 @@ def _expand_output_columns( def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: if not parsed: return {} - if col in _L1_JSON_COLUMNS: + if col in _PREFILTER_JSON_COLUMNS: return {f"{col}_{k}": v for k, v in parsed.items()} return parsed @@ -204,7 +206,7 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: expanded: list[dict[str, Any]] = [] for i, row in enumerate(row_payload): new_row = {k: v for k, v in row.items() if k not in json_expand_cols} - for col in ["output"] + _L1_JSON_COLUMNS: + for col in ["output"] + _PREFILTER_JSON_COLUMNS: parsed = parsed_cols[col][i] keys = col_keys[col] prefixed_vals = _get_prefixed(parsed, col) @@ -218,22 +220,22 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: new_row["output_raw"] = row.get("output") expanded.append(new_row) - l1_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] + prefilter_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] output_keys = col_keys["output"] - all_output_keys = l1_keys + output_keys + all_output_keys = prefilter_keys + output_keys if not all_output_keys: fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) fieldnames = [f for f in fieldnames if f != "input_data"] return row_payload, fieldnames, input_col_names, [], [] - fieldnames = input_col_names + l1_keys + output_keys + base_fields + fieldnames = input_col_names + prefilter_keys + output_keys + base_fields if has_unparsed_output: fieldnames.insert( - len(input_col_names) + len(l1_keys) + len(output_keys), "output_raw" + len(input_col_names) + len(prefilter_keys) + len(output_keys), "output_raw" ) - return expanded, fieldnames, input_col_names, 
l1_keys, output_keys + return expanded, fieldnames, input_col_names, prefilter_keys, output_keys def serialize_export_rows( @@ -258,7 +260,7 @@ def serialize_export_rows( expanded, fieldnames, input_col_names, - l1_keys, + prefilter_keys, output_keys, ) = _expand_output_columns(row_payload) expanded = apply_post_processing(expanded, post_processing_config) @@ -288,8 +290,8 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # Explicit ordering: inputs → L1 → L2 → computed columns - excel_fields = input_col_names + l1_keys + output_keys + computed_names + # Explicit ordering: inputs → prefilter → L2 → computed columns + excel_fields = input_col_names + prefilter_keys + output_keys + computed_names if not excel_fields: excel_fields = output_keys or ["output"] @@ -431,15 +433,15 @@ def _load_dataset_rows_for_run( return [] -def _extract_l1_json_columns( - l1_item: dict[str, Any] | None, +def _extract_prefilter_json_columns( + prefilter_item: dict[str, Any] | None, ) -> dict[str, Any]: """Return topic_relevance and duplicate_detection as JSON strings for export expansion.""" - if not l1_item: + if not prefilter_item: return {"topic_relevance": None, "duplicate_detection": None} - tr = l1_item.get("topic_relevance") - dup = l1_item.get("duplicate_detection") + tr = prefilter_item.get("topic_relevance") + dup = prefilter_item.get("duplicate_detection") tr_flat: dict[str, Any] | None = None if tr: @@ -468,10 +470,10 @@ def load_export_rows_for_run( ) -> list[AssessmentExportRow]: """Load flattened export rows for a single child assessment run. - When L1 results exist, ALL dataset rows are included in output. - L1-rejected rows have L1 columns filled and L2 columns empty. - L1-passed rows have all columns filled. - Without L1, behaviour is unchanged (only L2 result rows returned). + When prefilter results exist, ALL dataset rows are included in output. 
+ prefilter-rejected rows have prefilter columns filled and L2 columns empty. + prefilter-passed rows have all columns filled. + Without prefilter, behaviour is unchanged (only L2 result rows returned). """ if assessment is None: assessment = session.get(Assessment, run.assessment_id) @@ -486,8 +488,8 @@ def load_export_rows_for_run( dataset_name = dataset.name if dataset else None dataset_rows = _load_dataset_rows_for_run(session, run, assessment) - # Load L1 results (empty dict if no L1 was run) - l1_by_row_id = _load_l1_results(session, run, assessment) + # Load prefilter results (empty dict if no prefilter was run) + prefilter_by_row_id = _load_prefilter_results(session, run, assessment) # Load L2 results (may be None if batch not complete) l2_by_row_id: dict[str, dict[str, Any]] = {} @@ -504,24 +506,24 @@ def load_export_rows_for_run( if "row_id" in item } - has_l1 = bool(l1_by_row_id) + has_prefilter = bool(prefilter_by_row_id) - if has_l1 and dataset_rows: + if has_prefilter and dataset_rows: # All rows in output — build from full dataset export_rows: list[AssessmentExportRow] = [] for row_idx, input_data in enumerate(dataset_rows): row_id_str = f"row_{row_idx}" - l1_item = l1_by_row_id.get(row_id_str) - l1_cols = _extract_l1_json_columns(l1_item) + prefilter_item = prefilter_by_row_id.get(row_id_str) + prefilter_cols = _extract_prefilter_json_columns(prefilter_item) l2_item = l2_by_row_id.get(row_id_str) input_tokens, output_tokens, total_tokens = usage_totals( l2_item.get("usage") if l2_item else None ) - l1_passed = (l1_item or {}).get("l1_passed", True) + prefilter_passed = (prefilter_item or {}).get("prefilter_passed", True) result_status = ( - "l1_rejected" - if not l1_passed + "prefilter_rejected" + if not prefilter_passed else ("failed" if l2_item and l2_item.get("error") else "passed") ) @@ -539,8 +541,8 @@ def load_export_rows_for_run( row_id=row_id_str, result_status=result_status, input_data=input_data, - 
topic_relevance=l1_cols.get("topic_relevance"), - duplicate_detection=l1_cols.get("duplicate_detection"), + topic_relevance=prefilter_cols.get("topic_relevance"), + duplicate_detection=prefilter_cols.get("duplicate_detection"), output=l2_item.get("output") if l2_item else None, error=l2_item.get("error") if l2_item else None, response_id=l2_item.get("response_id") if l2_item else None, @@ -552,7 +554,7 @@ def load_export_rows_for_run( ) return export_rows - # No L1 — original behaviour: only L2 result rows + # No prefilter — original behaviour: only L2 result rows if not run.batch_job_id: logger.warning( "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index aa0fce1a0..38373774c 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -470,9 +470,12 @@ def test_url_no_extension_probes_bytes(self) -> None: resp = MagicMock() resp.__enter__ = MagicMock(return_value=resp) resp.__exit__ = MagicMock(return_value=False) + resp.is_redirect = False resp.raise_for_status = MagicMock() resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", return_value=resp, ) as mock_get: @@ -485,10 +488,13 @@ def test_url_probe_uses_content_type_when_no_magic(self) -> None: resp = MagicMock() resp.__enter__ = MagicMock(return_value=resp) resp.__exit__ = MagicMock(return_value=False) + resp.is_redirect = False resp.raise_for_status = MagicMock() resp.iter_content = MagicMock(return_value=iter([b"\x00\x01\x02\x03"])) resp.headers = {"Content-Type": "application/pdf; charset=binary"} with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", return_value=resp, ): @@ -499,20 
+505,60 @@ def test_url_probe_failure_falls_back(self) -> None: url = "https://example.com/file" with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", side_effect=_requests.RequestException("boom"), ): assert detect_item_type(url, "url", "image", {}) == "image" + def test_url_probe_follows_validated_redirect(self) -> None: + """A redirect hop is followed and re-validated before the next request.""" + url = "https://drive.google.com/file/d/RID/view" + redirect = MagicMock() + redirect.__enter__ = MagicMock(return_value=redirect) + redirect.__exit__ = MagicMock(return_value=False) + redirect.is_redirect = True + redirect.headers = {"Location": "https://files.example.com/real.pdf"} + final = MagicMock() + final.__enter__ = MagicMock(return_value=final) + final.__exit__ = MagicMock(return_value=False) + final.is_redirect = False + final.raise_for_status = MagicMock() + final.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) + with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ) as validate, patch( + "app.services.assessment.utils.attachments.requests.get", + side_effect=[redirect, final], + ) as mock_get: + assert detect_item_type(url, "url", "image", {}) == "pdf" + # Both the initial and redirected URLs were validated and fetched. + assert validate.call_count == 2 + assert mock_get.call_count == 2 + + def test_url_probe_blocked_by_ssrf_falls_back(self) -> None: + url = "https://internal.host/file" + with patch( + "app.services.assessment.utils.attachments.validate_callback_url", + side_effect=ValueError("private IP"), + ), patch("app.services.assessment.utils.attachments.requests.get") as mock_get: + # SSRF guard blocks the probe -> falls back to declared type. 
+ assert detect_item_type(url, "url", "pdf", {}) == "pdf" + mock_get.assert_not_called() + def test_cache_skips_second_probe(self) -> None: url = "https://drive.google.com/file/d/XYZ/view" cache: dict[str, str] = {} resp = MagicMock() resp.__enter__ = MagicMock(return_value=resp) resp.__exit__ = MagicMock(return_value=False) + resp.is_redirect = False resp.raise_for_status = MagicMock() resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", return_value=resp, ) as mock_get: diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 1cf30249e..e2f44a21a 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -27,7 +27,7 @@ update_assessment_run_status, update_run_post_processing_config, ) -from app.crud.assessment.core import update_assessment_run_l1_stats +from app.crud.assessment.core import update_assessment_run_prefilter_stats from app.models.stt_evaluation import EvaluationType @@ -234,9 +234,9 @@ def test_build_run_stats(self) -> None: total_items=2, error_message=None, updated_at=datetime(2024, 1, 1), - l1_total_rows=None, - l1_total_passed=None, - l1_total_rejected=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, ), ] stats = build_run_stats(runs) @@ -343,23 +343,23 @@ def test_sets_stats_fields(self) -> None: run = SimpleNamespace( id=8, updated_at=None, - l1_object_store_url=None, - l1_total_rows=None, - l1_total_passed=None, - l1_total_rejected=None, + prefilter_object_store_url=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, ) - out = update_assessment_run_l1_stats( + out = update_assessment_run_prefilter_stats( session=session, run=run, - l1_object_store_url="s3://x", - l1_total_rows=10, - 
l1_total_passed=7, - l1_total_rejected=3, + prefilter_object_store_url="s3://x", + prefilter_total_rows=10, + prefilter_total_passed=7, + prefilter_total_rejected=3, ) - assert out.l1_object_store_url == "s3://x" - assert out.l1_total_rows == 10 - assert out.l1_total_passed == 7 - assert out.l1_total_rejected == 3 + assert out.prefilter_object_store_url == "s3://x" + assert out.prefilter_total_rows == 10 + assert out.prefilter_total_passed == 7 + assert out.prefilter_total_rejected == 3 session.commit.assert_called_once() def test_commit_failure_rolls_back(self) -> None: @@ -368,11 +368,13 @@ def test_commit_failure_rolls_back(self) -> None: run = SimpleNamespace( id=9, updated_at=None, - l1_object_store_url=None, - l1_total_rows=None, - l1_total_passed=None, - l1_total_rejected=None, + prefilter_object_store_url=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, ) with pytest.raises(RuntimeError): - update_assessment_run_l1_stats(session=session, run=run, l1_total_rows=1) + update_assessment_run_prefilter_stats( + session=session, run=run, prefilter_total_rows=1 + ) session.rollback.assert_called_once() diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py index 24d5ac951..5d363f896 100644 --- a/backend/app/tests/assessment/test_duplicate_detection.py +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -1,9 +1,9 @@ -"""Tests for L1 duplicate detection.""" +"""Tests for prefilter duplicate detection.""" import json from unittest.mock import MagicMock -from app.services.assessment.l1.duplicate_detection import ( +from app.services.assessment.prefilter.duplicate_detection import ( _build_combined, _parse_verdict, run_duplicate_detection, diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py index faa64693e..d74841650 100644 --- a/backend/app/tests/assessment/test_pipeline.py +++ 
b/backend/app/tests/assessment/test_pipeline.py @@ -1,9 +1,9 @@ -"""Tests for the L1 pipeline orchestrator.""" +"""Tests for the prefilter pipeline orchestrator.""" from contextlib import ExitStack from unittest.mock import MagicMock, patch -from app.services.assessment.l1.pipeline import run_l1_pipeline +from app.services.assessment.prefilter.pipeline import run_prefilter_pipeline def _run() -> MagicMock: @@ -27,32 +27,32 @@ def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): client = MagicMock() stack.enter_context( patch( - "app.services.assessment.l1.pipeline.GeminiClient.from_credentials", + "app.services.assessment.prefilter.pipeline.GeminiClient.from_credentials", return_value=MagicMock(client=client), ) ) stack.enter_context( patch( - "app.services.assessment.l1.pipeline.get_cloud_storage", + "app.services.assessment.prefilter.pipeline.get_cloud_storage", return_value=MagicMock(), ) ) stack.enter_context( patch( - "app.services.assessment.l1.pipeline.upload_jsonl_to_object_store", - return_value="s3://l1.json", + "app.services.assessment.prefilter.pipeline.upload_jsonl_to_object_store", + return_value="s3://prefilter.json", ) ) stack.enter_context( - patch("app.crud.assessment.core.update_assessment_run_l1_stats") + patch("app.crud.assessment.core.update_assessment_run_prefilter_stats") ) tr_mock = stack.enter_context( - patch("app.services.assessment.l1.pipeline.run_topic_relevance") + patch("app.services.assessment.prefilter.pipeline.run_topic_relevance") ) if tr_side is not None: tr_mock.side_effect = tr_side dup_mock = stack.enter_context( - patch("app.services.assessment.l1.pipeline.run_duplicate_detection") + patch("app.services.assessment.prefilter.pipeline.run_duplicate_detection") ) if dup_return is not None: dup_mock.return_value = dup_return @@ -62,10 +62,10 @@ def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): class TestRunL1Pipeline: def test_no_filters_configured_passthrough(self) -> None: rows = [{"Problem": "a"}, 
{"Problem": "b"}] - passed, indices, results = run_l1_pipeline( + passed, indices, results = run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={}, + prefilter_config={}, session=MagicMock(), organization_id=1, project_id=1, @@ -80,10 +80,10 @@ def test_topic_relevance_filters_rejected_rows(self) -> None: side = [_tr(True), _tr(False, "REJECT"), _tr(True)] with ExitStack() as stack: _patches(stack, tr_side=side) - passed, indices, results = run_l1_pipeline( + passed, indices, results = run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={ + prefilter_config={ "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"} }, session=MagicMock(), @@ -93,7 +93,7 @@ def test_topic_relevance_filters_rejected_rows(self) -> None: assert indices == [0, 2] assert [r["Problem"] for r in passed] == ["keep", "keep2"] assert len(results) == 3 - assert results[1]["l1_passed"] is False + assert results[1]["prefilter_passed"] is False def test_duplicate_detection_runs_on_passed_rows(self) -> None: rows = [{"Problem": "a", "Solution": "b"}] @@ -107,10 +107,10 @@ def test_duplicate_detection_runs_on_passed_rows(self) -> None: } with ExitStack() as stack: tr_mock, dup_mock = _patches(stack, tr_side=[_tr(True)], dup_return=dup) - _, _, results = run_l1_pipeline( + _, _, results = run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={ + prefilter_config={ "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, "duplicate_detection": {"columns": ["Problem", "Solution"]}, }, @@ -131,10 +131,10 @@ def test_attachment_columns_filtered_to_selection(self) -> None: ] with ExitStack() as stack: tr_mock, _ = _patches(stack, tr_side=[_tr(True)]) - run_l1_pipeline( + run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={ + prefilter_config={ "topic_relevance": { "columns": ["Problem"], "prompt": "rubric", diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py index ad52c2306..064d4476b 100644 
--- a/backend/app/tests/assessment/test_topic_relevance.py +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -1,10 +1,10 @@ -"""Tests for L1 topic relevance attachment handling.""" +"""Tests for prefilter topic relevance attachment handling.""" import json from unittest.mock import MagicMock from app.models.assessment import AssessmentAttachment -from app.services.assessment.l1.topic_relevance import run_topic_relevance +from app.services.assessment.prefilter.topic_relevance import run_topic_relevance def _client_returning(decision: str) -> MagicMock: From 4a4e4f87007d7c102fd0363dc6443ec64ab6ec05 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 08:46:08 +0530 Subject: [PATCH 09/16] Refactor assessment tests and add new functionality - Consolidated and refactored tests for the prefilter pipeline and related services. - Introduced new tests for orchestrating assessment runs, ensuring proper handling of pipeline stages and statuses. - Added tests for the topic relevance request builder and result parser, improving coverage for attachment handling. - Implemented failure guards in task execution to prevent dangling runs and ensure proper error handling. - Enhanced the resume functionality for failed assessment runs, allowing for retries from the last failed stage. 
--- ...add_prefilter_columns_to_assessment_run.py | 54 ++- backend/app/api/docs/assessment/resume_run.md | 5 + backend/app/api/routes/assessment/runs.py | 40 +- backend/app/celery/tasks/job_execution.py | 8 +- backend/app/core/config.py | 17 +- backend/app/crud/assessment/batch.py | 32 +- backend/app/crud/assessment/core.py | 15 +- backend/app/crud/assessment/cron.py | 59 +-- backend/app/crud/assessment/processing.py | 378 ++++++----------- backend/app/models/assessment.py | 82 +++- .../services/assessment/prefilter/__init__.py | 4 +- .../prefilter/duplicate_detection.py | 301 +++++--------- .../services/assessment/prefilter/pipeline.py | 251 +---------- .../assessment/prefilter/request_builder.py | 73 ++++ .../assessment/prefilter/topic_relevance.py | 186 ++++----- backend/app/services/assessment/service.py | 79 +++- backend/app/services/assessment/stages.py | 195 +++++++++ backend/app/services/assessment/tasks.py | 392 +++++++++++------- .../services/assessment/utils/attachments.py | 161 ++----- .../app/services/assessment/utils/export.py | 355 +++++++++------- backend/app/tests/assessment/test_batch.py | 250 +++++------ backend/app/tests/assessment/test_cron.py | 53 +-- backend/app/tests/assessment/test_crud.py | 3 + .../assessment/test_duplicate_detection.py | 178 +++----- backend/app/tests/assessment/test_export.py | 131 +++--- backend/app/tests/assessment/test_pipeline.py | 185 ++------- .../assessment/test_prefilter_batching.py | 115 +++++ .../app/tests/assessment/test_processing.py | 238 +++++------ backend/app/tests/assessment/test_service.py | 72 +++- .../assessment/test_tasks_failure_guard.py | 82 ++++ .../tests/assessment/test_topic_relevance.py | 194 ++++----- 31 files changed, 2085 insertions(+), 2103 deletions(-) create mode 100644 backend/app/api/docs/assessment/resume_run.md create mode 100644 backend/app/services/assessment/prefilter/request_builder.py create mode 100644 backend/app/services/assessment/stages.py create mode 100644 
backend/app/tests/assessment/test_prefilter_batching.py create mode 100644 backend/app/tests/assessment/test_tasks_failure_guard.py diff --git a/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py index 1720e21b4..25ba9a3c0 100644 --- a/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py +++ b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py @@ -1,4 +1,4 @@ -"""Add prefilter pipeline columns to assessment_run +"""Add prefilter columns and pipeline stage-machine columns to assessment_run Revision ID: 064 Revises: 063 @@ -8,6 +8,7 @@ import sqlalchemy as sa from alembic import op +from sqlalchemy.dialects import postgresql revision = "064" down_revision = "063" @@ -22,7 +23,7 @@ def upgrade() -> None: "prefilter_object_store_url", sa.String(), nullable=True, - comment="S3 URL of stored prefilter filter results JSON", + comment="S3 URL of prefilter results JSON", ), ) op.add_column( @@ -31,7 +32,7 @@ def upgrade() -> None: "prefilter_total_rows", sa.Integer(), nullable=True, - comment="Total rows fed into prefilter pipeline", + comment="Total rows fed into the prefilter stages", ), ) op.add_column( @@ -40,7 +41,7 @@ def upgrade() -> None: "prefilter_total_passed", sa.Integer(), nullable=True, - comment="Rows that passed topic relevance and went to L2", + comment="Rows that passed the go/no-go gates and went to L2", ), ) op.add_column( @@ -49,12 +50,55 @@ def upgrade() -> None: "prefilter_total_rejected", sa.Integer(), nullable=True, - comment="Rows rejected by topic relevance, stopped at prefilter", + comment="Rows rejected by a go/no-go gate", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "stage", + sa.String(), + nullable=True, + comment=( + "Current pipeline stage: PRE_FILTER_TOPIC_RELEVANCE, " + "PRE_FILTER_DUPLICATE_DETECTION, L2_ASSESSMENT, COMPLETED, FAILED" + ), + ), + ) + op.add_column( + 
"assessment_run", + sa.Column( + "stage_status", + sa.String(), + nullable=True, + comment="Status of stage: PENDING, PROCESSING, COMPLETED, FAILED", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "pipeline", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Ordered stage config driving execution: {'stages': [...]}", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "stage_batches", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Map of stage name -> batch_job id, for per-stage result lookup", ), ) def downgrade() -> None: + op.drop_column("assessment_run", "stage_batches") + op.drop_column("assessment_run", "pipeline") + op.drop_column("assessment_run", "stage_status") + op.drop_column("assessment_run", "stage") op.drop_column("assessment_run", "prefilter_total_rejected") op.drop_column("assessment_run", "prefilter_total_passed") op.drop_column("assessment_run", "prefilter_total_rows") diff --git a/backend/app/api/docs/assessment/resume_run.md b/backend/app/api/docs/assessment/resume_run.md new file mode 100644 index 000000000..a7dc2713f --- /dev/null +++ b/backend/app/api/docs/assessment/resume_run.md @@ -0,0 +1,5 @@ +Resume a failed assessment run from its failed stage. + +Re-runs the same child run in place, starting at the stage that failed. +Stages that already completed are reused (their batch results are not +recomputed). Only valid when the run is in a failed state. 
diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 2825e5c86..3ed8305ef 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -10,9 +10,13 @@ from app.api.permissions import Permission, require_permission from app.crud.assessment import ( get_assessment_by_id, + update_run_post_processing_config, +) +from app.crud.assessment import ( get_assessment_run_by_id as get_run_by_id, +) +from app.crud.assessment import ( list_assessment_runs as list_runs, - update_run_post_processing_config, ) from app.models.assessment import ( Assessment, @@ -22,6 +26,9 @@ AssessmentRunPublic, ) from app.models.evaluation import EvaluationDataset +from app.services.assessment.service import ( + resume_assessment_run as resume_run, +) from app.services.assessment.service import ( retry_assessment_run as retry_run, ) @@ -70,6 +77,9 @@ def _build_run_public( prefilter_total_rows=run.prefilter_total_rows, prefilter_total_passed=run.prefilter_total_passed, prefilter_total_rejected=run.prefilter_total_rejected, + stage=run.stage, + stage_status=run.stage_status, + pipeline=run.pipeline, post_processing_config=(run.input or {}).get("post_processing_config"), inserted_at=run.inserted_at, updated_at=run.updated_at, @@ -133,6 +143,34 @@ def retry_assessment_run( return APIResponse.success_response(data=result) +@router.post( + "/runs/{run_id}/resume", + description=load_description("assessment/resume_run.md"), + response_model=APIResponse[AssessmentResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def resume_assessment_run( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentResponse]: + """Resume a failed child run from its failed stage, reusing completed stages.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + 
project_id=auth_context.project_.id, + ) + + result = resume_run( + session=session, + run=run, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + return APIResponse.success_response(data=result) + + @router.get( "/runs", description=load_description("assessment/list_runs.md"), diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 6a249a92e..34a3f5878 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -233,8 +233,8 @@ def run_tts_batch_submission( @celery_app.task(bind=True, queue="low_priority", priority=1) -@gevent_timeout(settings.ASSESSMENT_RUN_SOFT_TIME_LIMIT, "run_assessment_run") -def run_assessment_run( +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_assessment_pipeline") +def run_assessment_pipeline( self, run_id: int, organization_id: int, @@ -242,12 +242,12 @@ def run_assessment_run( trace_id: str, **kwargs, ): - from app.services.assessment.tasks import execute_assessment_run + from app.services.assessment.tasks import execute_assessment_pipeline _set_trace(trace_id) return _run_with_otel_parent( self, - lambda: execute_assessment_run( + lambda: execute_assessment_pipeline( run_id=run_id, organization_id=organization_id, project_id=project_id, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index e2e3a5ff1..a54155922 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -171,19 +171,10 @@ def AWS_S3_BUCKET(self) -> str: DOC_TRANSFORMATION_PENDING_THRESHOLD_MINUTES: int = 30 PENDING_JOB_QUERY_TIMEOUT_MS: int = 1000 - # Assessment - ASSESSMENT_PREFILTER_GEMINI_MODEL: str = "gemini-3.1-flash-lite" - ASSESSMENT_PREFILTER_CONCURRENT_WORKERS: int = 8 - ASSESSMENT_PREFILTER_DUPLICATE_STORE_NAME: str = ( - "fileSearchStores/inquilabcorpus-782mxjcwisaz" - ) - # Soft timeout for the full assessment run task (prefilter pipeline + batch - # submission). 
Larger than the default task limit because prefilter runs many - # concurrent LLM calls over the whole dataset. Seconds. Default 2 hours. - ASSESSMENT_RUN_SOFT_TIME_LIMIT: int = 7200 - # Timeout for prefilter Gemini calls to prevent pipeline stalls from slow/hung requests - # (default: 2 minutes, in ms) - ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS: int = 120000 + # Assessment prefilter — provider + model for the batch prefilter stages. + ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "openai" + ASSESSMENT_PREFILTER_MODEL: str = "gpt-5-mini" + ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "vs_6a20339fbc148191867fd06d29133278" @computed_field # type: ignore[prop-decorator] @property diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index 5debc3156..4918a97d0 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -13,7 +13,8 @@ from openpyxl.utils.exceptions import InvalidFileException from sqlmodel import Session -from app.core.batch import BATCH_KEY, start_batch_job +from app.core.batch import BATCH_KEY, GeminiBatchProvider, start_batch_job +from app.core.batch.client import GeminiClient from app.core.batch.openai import OpenAIBatchProvider from app.core.cloud import get_cloud_storage from app.models.assessment import ( @@ -30,10 +31,12 @@ normalize_llm_text, ) from app.services.assessment.utils.attachments import ( + attachment_type_for_row, build_gemini_attachment_parts, resolve_attachment_values, ) from app.services.llm.providers.registry import LLMProvider +from app.utils import get_openai_client logger = logging.getLogger(__name__) @@ -171,8 +174,6 @@ def build_openai_jsonl( } """ jsonl_data = [] - # Memoize per-item type probes across all rows in this build. 
- type_cache: dict[str, str] = {} for i, row in enumerate(rows): idx = row_indices[i] if row_indices is not None else i @@ -187,7 +188,13 @@ def build_openai_jsonl( # Attachments for att in attachments: cell_value = row.get(att.column, "") - input_parts.extend(resolve_attachment_values(cell_value, att, type_cache)) + input_parts.extend( + resolve_attachment_values( + cell_value, + att, + type_override=attachment_type_for_row(att, row), + ) + ) if not input_parts: logger.warning("[build_openai_jsonl] Skipping empty row | idx=%s", idx) @@ -231,8 +238,6 @@ def build_google_jsonl( } """ jsonl_data = [] - # Memoize per-item type probes across all rows in this build. - type_cache: dict[str, str] = {} for i, row in enumerate(rows): idx = row_indices[i] if row_indices is not None else i @@ -246,7 +251,13 @@ def build_google_jsonl( # Attachments (Gemini uses file_data for inline content) for att in attachments: cell_value = row.get(att.column, "") - parts.extend(build_gemini_attachment_parts(cell_value, att, type_cache)) + parts.extend( + build_gemini_attachment_parts( + cell_value, + att, + type_override=attachment_type_for_row(att, row), + ) + ) if not parts: logger.warning("[build_google_jsonl] Skipping empty row | idx=%s", idx) @@ -369,9 +380,6 @@ def submit_assessment_batch( row_indices=row_indices, ) - # Get OpenAI client and submit - from app.utils import get_openai_client - openai_client = get_openai_client( session=session, org_id=organization_id, @@ -410,10 +418,6 @@ def submit_assessment_batch( row_indices=row_indices, ) - # Get Gemini client and submit - from app.core.batch import GeminiBatchProvider - from app.core.batch.client import GeminiClient - gemini_client = GeminiClient.from_credentials( session=session, org_id=organization_id, diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index 547cc2e31..eb3b529a9 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -282,11 +282,14 @@ 
def update_assessment_run_prefilter_stats( return run -_ACTIVE_RUN_STATUSES = frozenset( - {"prefilter_processing", "l2_processing", "processing", "in_progress"} -) -_FAILED_RUN_STATUSES = frozenset({"failed", "prefilter_failed"}) -_COMPLETED_RUN_STATUSES = frozenset({"completed", "completed_with_errors"}) +_ACTIVE_RUN_STATUSES = { + "prefilter_processing", + "l2_processing", + "processing", + "in_progress", +} +_FAILED_RUN_STATUSES = {"failed", "prefilter_failed"} +_COMPLETED_RUN_STATUSES = {"completed", "completed_with_errors"} def compute_run_counts(runs: list[AssessmentRun]) -> AssessmentRunCounts: @@ -334,6 +337,8 @@ def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: prefilter_total_rows=run.prefilter_total_rows, prefilter_total_passed=run.prefilter_total_passed, prefilter_total_rejected=run.prefilter_total_rejected, + stage=run.stage, + stage_status=run.stage_status, ) for run in runs ] diff --git a/backend/app/crud/assessment/cron.py b/backend/app/crud/assessment/cron.py index 6cb76b1f5..000d61666 100644 --- a/backend/app/crud/assessment/cron.py +++ b/backend/app/crud/assessment/cron.py @@ -9,13 +9,12 @@ compute_run_counts, get_assessment_runs_for_assessment, recompute_assessment_status, - update_assessment_run_status, ) from app.crud.assessment.processing import ( - check_and_process_assessment, format_assessment_failure_message, + process_run_batches, ) -from app.models.assessment import Assessment, AssessmentRun +from app.models.assessment import Assessment, AssessmentRun, StageStatus logger = logging.getLogger(__name__) @@ -78,7 +77,9 @@ async def poll_all_pending_assessment_evaluations( runs = get_assessment_runs_for_assessment( session=session, assessment_id=assessment.id ) - active_runs = [run for run in runs if run.status == "l2_processing"] + active_runs = [ + run for run in runs if run.stage_status == StageStatus.PROCESSING + ] if not active_runs: refreshed = recompute_assessment_status( @@ -100,7 +101,7 @@ async def 
poll_all_pending_assessment_evaluations( for run in active_runs: try: - result = await check_and_process_assessment( + result = await process_run_batches( run=run, session=session, ) @@ -115,51 +116,15 @@ async def poll_all_pending_assessment_evaluations( still_processing += 1 except Exception as e: - error_msg = format_assessment_failure_message(e) - logger.error( - "[poll_all_pending_assessment_evaluations] Failed run %s | " - "experiment=%s | assessment_id=%s | config_id=%s | config_version=%s | error=%s", + session.rollback() + logger.warning( + "[poll_all_pending_assessment_evaluations] transient error polling " + "run %s (assessment %s), will retry: %s", run.id, - assessment.experiment_name, run.assessment_id, - run.config_id, - run.config_version, - error_msg, - exc_info=True, + format_assessment_failure_message(e), ) - try: - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, - ) - recompute_assessment_status( - session=session, assessment_id=assessment.id - ) - failure_result = { - "assessment_id": run.assessment_id, - "run_id": run.id, - "experiment_name": assessment.experiment_name, - "config_id": str(run.config_id) if run.config_id else None, - "config_version": run.config_version, - "action": "failed", - "error": error_msg, - "current_status": "failed", - } - all_results.append(failure_result) - failed += 1 - except Exception as cleanup_exc: - logger.error( - "[poll_all_pending_assessment_evaluations] Cleanup failed for run %s | " - "assessment_id=%s | experiment=%s | error=%s", - run.id, - run.assessment_id, - assessment.experiment_name, - cleanup_exc, - exc_info=True, - ) - failed += 1 + still_processing += 1 logger.info( "[poll_all_pending_assessment_evaluations] Summary | processed=%s | failed=%s | still_processing=%s", diff --git a/backend/app/crud/assessment/processing.py b/backend/app/crud/assessment/processing.py index f2a27455c..7b860de5d 100644 --- 
a/backend/app/crud/assessment/processing.py +++ b/backend/app/crud/assessment/processing.py @@ -10,25 +10,25 @@ from fastapi import HTTPException from sqlmodel import Session -from app.core.batch import ( - BATCH_KEY, - GeminiBatchProvider, - OpenAIBatchProvider, - download_batch_results, - poll_batch_status, - upload_batch_results_to_object_store, -) +from app.celery.tasks.job_execution import run_assessment_pipeline +from app.core.batch import BATCH_KEY, poll_batch_status, process_completed_batch from app.core.batch.base import BatchProvider -from app.core.batch.client import GeminiClient from app.core.batch.gemini import BatchJobState, extract_text_from_response_dict from app.crud.assessment import ( recompute_assessment_status, + update_assessment_run_prefilter_stats, update_assessment_run_status, ) from app.crud.job import get_batch_job -from app.models.assessment import Assessment, AssessmentRun +from app.models.assessment import Assessment, AssessmentRun, StageStatus +from app.services.assessment.stages import ( + GATE_STAGES, + STAGE_PARSERS, + _get_batch_provider, + advance_or_finalize, + load_raw_batch_results, +) from app.services.llm.providers.registry import LLMProvider -from app.utils import get_openai_client logger = logging.getLogger(__name__) @@ -87,32 +87,6 @@ def _sanitize_json_output(raw: str) -> str: return "".join(result) -def _get_batch_provider( - session: Session, - provider_name: str, - organization_id: int, - project_id: int, -) -> BatchProvider: - """Get the appropriate batch provider instance.""" - if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): - openai_client = get_openai_client( - session=session, - org_id=organization_id, - project_id=project_id, - ) - return OpenAIBatchProvider(client=openai_client) - - if provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): - gemini_client = GeminiClient.from_credentials( - session=session, - org_id=organization_id, - project_id=project_id, - ) - return 
GeminiBatchProvider(client=gemini_client.client) - - raise ValueError(f"Unsupported provider for assessment polling: {provider_name}") - - def parse_assessment_output( raw_results: list[dict[str, Any]], provider_name: str, @@ -262,234 +236,142 @@ def parse_assessment_output( return results -async def check_and_process_assessment( - run: AssessmentRun, - session: Session, -) -> dict[str, Any]: - """Check assessment batch status and process if completed. - - Args: - run: AssessmentRun to check - session: Database session +_PROVIDER_SUCCESS = {"completed", BatchJobState.SUCCEEDED.value} +_PROVIDER_FAILED = { + "failed", + "expired", + "cancelled", + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} - Returns: - Dict with status information - """ - log_prefix = f"[check_and_process_assessment][assessment_run={run.id}]" - previous_status = run.status - parent_pre = session.get(Assessment, run.assessment_id) - experiment_name_pre = parent_pre.experiment_name if parent_pre else None - try: - if not run.batch_job_id: - raise ValueError(f"Assessment run {run.id} has no batch_job_id") +def _poll_stage_outcome(session: Session, provider: BatchProvider, batch_job) -> str: + """Poll one stage's batch; on success download+persist. 
Returns the outcome.""" + status_result = poll_batch_status( + session=session, provider=provider, batch_job=batch_job + ) + session.refresh(batch_job) + status = batch_job.provider_status - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if not batch_job: - raise ValueError(f"BatchJob {run.batch_job_id} not found") + if status in _PROVIDER_SUCCESS: + if batch_job.provider_output_file_id: + process_completed_batch( + session=session, provider=provider, batch_job=batch_job + ) + return "completed" + counts = status_result.get("request_counts") or {} + if counts.get("completed", 0) == 0 and ( + counts.get("failed", 0) > 0 or status_result.get("error_file_id") + ): + return "failed" + return "no_change" # output genuinely not ready yet — retry next cycle + if status in _PROVIDER_FAILED: + return "failed" + return "no_change" - parent = parent_pre - if not parent: - raise ValueError(f"Parent assessment {run.assessment_id} not found") - # Get provider and poll status - provider = _get_batch_provider( +def _record_gate_stats( + session: Session, run: AssessmentRun, stage: str, batch_job, project_id: int +) -> None: + """For a go/no-go stage, persist passed/rejected counts from its results.""" + try: + raw = load_raw_batch_results(session, batch_job, project_id) + outputs = parse_assessment_output(raw, batch_job.provider) + parsed = STAGE_PARSERS[stage](outputs) + total = len(parsed) + passed = sum(1 for r in parsed.values() if r.get("verdict")) + update_assessment_run_prefilter_stats( session=session, - provider_name=batch_job.provider, - organization_id=parent.organization_id, - project_id=parent.project_id, + run=run, + prefilter_total_rows=total, + prefilter_total_passed=passed, + prefilter_total_rejected=total - passed, ) - status_result = poll_batch_status( - session=session, - provider=provider, - batch_job=batch_job, + except Exception as exc: + logger.warning( + "[_record_gate_stats] run_id=%s stage=%s — %s", run.id, stage, exc ) - 
session.refresh(batch_job) - - provider_status = batch_job.provider_status - if ( - provider_status == "completed" - or provider_status == BatchJobState.SUCCEEDED.value - ): - if not batch_job.provider_output_file_id: - request_counts = status_result.get("request_counts") or {} - error_file_id = status_result.get("error_file_id") - failed_count = request_counts.get("failed", 0) - completed_count = request_counts.get("completed", 0) - total_count = request_counts.get("total", 0) - - if error_file_id and failed_count > 0 and completed_count == 0: - error_msg = ( - f"Batch completed with {failed_count} failed request(s)" - f" and no successful outputs" - ) - if total_count: - error_msg += f" out of {total_count}" - error_msg += f" (error_file_id: {error_file_id})" - - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, - ) - recompute_assessment_status( - session=session, assessment_id=run.assessment_id - ) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": "failed", - "provider_status": provider_status, - "action": "failed", - "error": error_msg, - } +def _fail_run_stage( + session: Session, run: AssessmentRun, message: str +) -> dict[str, Any]: + # Keep run.stage at the failed stage so a resume knows where to restart; + # stage_status == FAILED is the failure marker. 
+ run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, run=run, status="failed", error_message=message + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return {"run_id": run.id, "current_status": "failed", "action": "failed"} - logger.info( - f"{log_prefix} Batch completed but output file is not ready yet | " - f"batch_job_id={batch_job.id} | provider_status={provider_status}" - ) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": run.status, - "provider_status": provider_status, - "action": "no_change", - } - # Download and process results - raw_results = download_batch_results(provider=provider, batch_job=batch_job) +async def process_run_batches(run: AssessmentRun, session: Session) -> dict[str, Any]: + """Poll the run's current-stage batch; on completion advance to the next stage.""" + parent = session.get(Assessment, run.assessment_id) + if not parent: + raise ValueError(f"Parent assessment {run.assessment_id} not found") - # Upload raw results to object store - object_store_url = None - try: - object_store_url = upload_batch_results_to_object_store( - session=session, batch_job=batch_job, results=raw_results - ) - except Exception as e: - logger.error( - "%s Object store upload failed — results may be unrecoverable " - "if the provider deletes the output file before next poll: %s", - log_prefix, - e, - exc_info=True, - ) + stage = run.stage + if not stage or run.stage_status != StageStatus.PROCESSING: + return {"run_id": run.id, "current_status": run.status, "action": "no_change"} - # Parse results - parsed = parse_assessment_output(raw_results, batch_job.provider) - error_count = sum(1 for result in parsed if result.get("error")) - success_count = sum(1 for result in parsed if not result.get("error")) - - # Update run status - error_msg = f"{error_count} item(s) failed" 
if error_count > 0 else None - run_status = ( - "failed" - if parsed and success_count == 0 and error_count > 0 - else "completed" - ) + batch_id = (run.stage_batches or {}).get(stage) + batch_job = ( + get_batch_job(session=session, batch_job_id=batch_id) if batch_id else None + ) + if not batch_job: + return _fail_run_stage(session, run, f"Stage {stage} batch not found") - if not parsed: - run_status = "failed" - error_msg = "Batch completed but no valid results were produced" + # Transient errors here (DNS, network, provider hiccup) must NOT fail the run — + # the batch is still running. Skip this cycle; the cron retries next tick. + try: + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=parent.organization_id, + project_id=parent.project_id, + ) + outcome = _poll_stage_outcome(session, provider, batch_job) + except Exception as exc: + logger.warning( + "[process_run_batches] run_id=%s stage=%s poll error, will retry: %s", + run.id, + stage, + exc, + ) + return {"run_id": run.id, "current_status": run.status, "action": "no_change"} - update_assessment_run_status( - session=session, - run=run, - status=run_status, - error_message=error_msg, - object_store_url=object_store_url, - ) - recompute_assessment_status( - session=session, assessment_id=run.assessment_id - ) + if outcome == "no_change": + return {"run_id": run.id, "current_status": run.status, "action": "no_change"} + if outcome == "failed": + return _fail_run_stage( + session, run, batch_job.error_message or f"Stage {stage} failed" + ) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": run_status, - "provider_status": provider_status, - "action": "processed" if run_status == "completed" else "failed", - "total_results": len(parsed), - "errors": error_count, - } - - elif provider_status in ( - "failed", - "expired", - "cancelled", - 
BatchJobState.FAILED.value, - BatchJobState.CANCELLED.value, - BatchJobState.EXPIRED.value, - ): - error_msg = batch_job.error_message or f"Batch {provider_status}" - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, - ) - recompute_assessment_status( - session=session, assessment_id=run.assessment_id - ) + run.stage_status = StageStatus.COMPLETED + if stage in GATE_STAGES: + _record_gate_stats(session, run, stage, batch_job, parent.project_id) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": "failed", - "provider_status": provider_status, - "action": "failed", - "error": error_msg, - } + nxt = advance_or_finalize(run) + session.add(run) + session.commit() + recompute_assessment_status(session=session, assessment_id=run.assessment_id) - else: - # Still processing - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": run.status, - "provider_status": provider_status, - "action": "no_change", - } - - except Exception as e: - error_msg = format_assessment_failure_message(e) - logger.error( - f"{log_prefix} Error checking assessment: {error_msg}", - exc_info=True, - ) - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, + if nxt: + run_assessment_pipeline.delay( + run_id=run.id, + organization_id=parent.organization_id, + project_id=parent.project_id, + trace_id="", ) - recompute_assessment_status(session=session, assessment_id=run.assessment_id) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": "failed", - "provider_status": "unknown", - "action": "failed", - "error": error_msg, - } - - -async def 
poll_all_pending_assessments(session: Session) -> dict[str, Any]: - """Backward-compatible wrapper for parent-first assessment polling.""" - from app.crud.assessment.cron import poll_all_pending_assessment_evaluations - - return await poll_all_pending_assessment_evaluations(session=session) + + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": parent.experiment_name, + "current_status": run.status, + "action": "processed", + } diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 8ff468db2..9d1378f83 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -1,11 +1,12 @@ """Assessment models — DB tables, Pydantic schemas, and LLM param wrappers.""" from datetime import datetime +from enum import StrEnum from typing import TYPE_CHECKING, Any, Literal, Optional from uuid import UUID from pydantic import BaseModel, Field -from sqlalchemy import JSON, Column, Index, Text +from sqlalchemy import Column, Index, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField from sqlmodel import Relationship, SQLModel @@ -17,6 +18,25 @@ from app.models.batch_job import BatchJob +class Stage(StrEnum): + """Pipeline stages, in execution order. 
Business step only (status is separate).""" + + PRE_FILTER_TOPIC_RELEVANCE = "PRE_FILTER_TOPIC_RELEVANCE" + PRE_FILTER_DUPLICATE_DETECTION = "PRE_FILTER_DUPLICATE_DETECTION" + L2_ASSESSMENT = "L2_ASSESSMENT" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +class StageStatus(StrEnum): + """Execution status of the current stage.""" + + PENDING = "PENDING" + PROCESSING = "PROCESSING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + class Assessment(SQLModel, table=True): """Parent assessment — one experiment over a dataset, grouping N config runs.""" @@ -109,11 +129,44 @@ class AssessmentRun(SQLModel, table=True): default="pending", sa_column_kwargs={ "comment": ( - "Unified pipeline status: pending, prefilter_processing, prefilter_failed, " - "l2_processing, completed, completed_with_errors, failed" + "Overall run status: pending, processing, completed, " + "completed_with_errors, failed" + ) + }, + ) + stage: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": ( + "Current pipeline stage (Stage enum): PRE_FILTER_TOPIC_RELEVANCE, " + "PRE_FILTER_DUPLICATE_DETECTION, L2_ASSESSMENT, COMPLETED, FAILED" ) }, ) + stage_status: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "StageStatus of stage: PENDING, PROCESSING, COMPLETED, FAILED" + }, + ) + pipeline: dict[str, Any] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="Ordered stage config driving execution: {'stages': [...]}", + ), + ) + stage_batches: dict[str, int] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="Map of stage name -> batch_job id, for per-stage result lookup", + ), + ) batch_job_id: int | None = SQLField( default=None, foreign_key="batch_job.id", @@ -213,6 +266,8 @@ class AssessmentRunStat(BaseModel): prefilter_total_rows: int | None = None prefilter_total_passed: int | None = None prefilter_total_rejected: int | None = None + stage: 
str | None = None + stage_status: str | None = None class AssessmentPublic(BaseModel): @@ -255,6 +310,9 @@ class AssessmentRunPublic(BaseModel): prefilter_total_rows: int | None = None prefilter_total_passed: int | None = None prefilter_total_rejected: int | None = None + stage: str | None = None + stage_status: str | None = None + pipeline: dict[str, Any] | None = None post_processing_config: dict[str, Any] | None = None inserted_at: datetime updated_at: datetime @@ -280,12 +338,24 @@ class AssessmentAttachment(BaseModel): type: Literal["image", "pdf", "mixed"] = Field( ..., description=( - "Attachment type. 'mixed' detects image vs pdf per item (for columns " - "that contain both); 'image'/'pdf' force a type and act as fallback " - "when per-item detection is inconclusive." + "Attachment type. 'image'/'pdf' force the type for every row. 'mixed' " + "resolves the per-row type from type_column via type_value_map." ), ) format: Literal["url", "base64"] = Field(..., description="Data format") + type_column: str | None = Field( + None, + description=( + "For 'mixed': the dataset column whose value decides each row's type." + ), + ) + type_value_map: dict[str, str] | None = Field( + None, + description=( + "For 'mixed': maps a type_column value to 'image' or 'pdf' " + "(e.g. {'Photo': 'image', 'Report': 'pdf'})." 
+ ), + ) class AssessmentConfigRef(BaseModel): diff --git a/backend/app/services/assessment/prefilter/__init__.py b/backend/app/services/assessment/prefilter/__init__.py index 6cd16dce2..2e763bd4f 100644 --- a/backend/app/services/assessment/prefilter/__init__.py +++ b/backend/app/services/assessment/prefilter/__init__.py @@ -1,3 +1,3 @@ -from app.services.assessment.prefilter.pipeline import run_prefilter_pipeline +from app.services.assessment.prefilter.pipeline import resolve_prefilter_settings -__all__ = ["run_prefilter_pipeline"] +__all__ = ["resolve_prefilter_settings"] diff --git a/backend/app/services/assessment/prefilter/duplicate_detection.py b/backend/app/services/assessment/prefilter/duplicate_detection.py index bba004457..6d4358b75 100644 --- a/backend/app/services/assessment/prefilter/duplicate_detection.py +++ b/backend/app/services/assessment/prefilter/duplicate_detection.py @@ -1,222 +1,117 @@ -"""Duplicate detection filter for prefilter pipeline.""" +"""Duplicate detection stage: build per-record file_search batch requests, parse verdicts.""" import json import logging -import re from typing import Any -from google import genai -from google.genai import types - from app.core.config import settings +from app.services.assessment.prefilter.request_builder import build_request_line logger = logging.getLogger(__name__) -_VAGUE_SYS = """ -You are a strict VAGUENESS gate for the School Innovation Marathon (SIM) -duplicate-detection pipeline. Submissions come from Indian school students grades 6-12. -You run BEFORE corpus duplicate detection. Decide only if the submission has enough -surface area for corpus matching. NOT a quality gate. 
- -NOT VAGUE (let through to corpus check): -- Widely-known/textbook ideas (rainwater harvesting, anti-theft alarm) -- Weak novelty / unclear feasibility -- Hindi/Telugu/mixed Indian-language text -- Bad grammar or rambling if content present -- Long essays naming domain + audience + any mechanism - -VAGUE only when ALL: problem names no issue/target/domain, solution names no mechanism, -text is empty / aspirational ("make society better") / gibberish. - -DECISION: 0-1 clear dimensions present -> vague=true. 2+ -> vague=false. Borderline -> false. - -Output ONLY JSON: {"vague": true|false, "reason": "max 15 words"} -""" - _DUP_SYS = """ You are a strict duplicate-detection judge for an innovation competition corpus. -Given a submitted idea, search the corpus and compare precisely. -Focus on MECHANISM of the solution, not category or theme. - -Verdict (exactly one): DUPLICATE / OVERLAP / PARTIAL_MATCH / UNIQUE - - DUPLICATE: Both problem AND solution mechanism substantially match a corpus entry. - OVERLAP: Either problem OR solution mechanism matches, other side clearly different. - PARTIAL_MATCH: Thematic/conceptual similarity only — same domain, different mechanism. - UNIQUE: Neither problem nor solution substantially matches anything in corpus. - -Response format (follow exactly): -Verdict: -Title: -Source: -URL: -Matching sentence: -Reason: - -RULES: -- UNIQUE -> output ONLY Verdict + Reason. -- NOT UNIQUE -> Title, Source, URL, Matching sentence ALL required. -- Source/URL MUST be VERBATIM from "SOURCE_URL:" line in retrieved chunk. -- NEVER write filenames, page numbers, or constructed URLs. +If the submission is too vague for corpus matching (no problem/target/domain AND no +solution mechanism, or empty/gibberish), use verdict VAGUE. + +Otherwise search the corpus and compare precisely. Focus on the MECHANISM of the +solution, not category or theme: +- DUPLICATE: problem AND solution mechanism substantially match a corpus entry. 
+- OVERLAP: problem OR solution mechanism matches; the other side clearly differs. +- PARTIAL_MATCH: thematic/conceptual similarity only — same domain, different mechanism. +- UNIQUE: neither problem nor solution substantially matches anything in the corpus. + +Return JSON with keys: verdict, match_title, source_url, matching_sentence, reason. +For UNIQUE or VAGUE, set match_title, source_url and matching_sentence to "" and give a +short reason. Otherwise fill match_title, source_url (the SOURCE_URL verbatim from the +retrieved chunk), matching_sentence (the exact sentence) and a one-sentence reason. +Never invent or construct URLs or filenames. """ - -def _build_combined(content_parts: dict[str, str]) -> str: - parts = [f"{col}:\n{val}" for col, val in content_parts.items() if val.strip()] - return "\n\n".join(parts) - - -def _check_vague( - text: str, - gemini_client: genai.Client, - model: str, -) -> tuple[bool, str]: - try: - response = gemini_client.models.generate_content( - model=model, - contents=f"Submission:\n\n{text}", - config=types.GenerateContentConfig( - system_instruction=_VAGUE_SYS, - response_mime_type="application/json", - temperature=0.0, - http_options=types.HttpOptions( - timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS - ), - ), +_DUP_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "verdict": { + "type": "string", + "enum": ["DUPLICATE", "OVERLAP", "PARTIAL_MATCH", "UNIQUE", "VAGUE"], + }, + "match_title": {"type": "string"}, + "source_url": {"type": "string"}, + "matching_sentence": {"type": "string"}, + "reason": {"type": "string"}, + }, + "required": [ + "verdict", + "match_title", + "source_url", + "matching_sentence", + "reason", + ], +} + + +def _combined_text(row: dict[str, str], columns: list[str]) -> str: + parts = [ + f"{col}:\n{row.get(col, '')}" for col in columns if row.get(col, "").strip() + ] + return "\n\n".join(parts) or "(empty submission)" + + +def build_duplicate_detection_requests( + rows: 
list[tuple[int, dict[str, str]]], + columns: list[str], +) -> list[dict[str, Any]]: + """Build one batch JSONL line per record, grounded on the provider's corpus store.""" + store = settings.ASSESSMENT_PREFILTER_DUPLICATE_STORE or None + return [ + build_request_line( + key=f"dup_{idx}", + system=_DUP_SYS, + user_text=f"Submitted idea to check:\n\n{_combined_text(row, columns)}", + response_schema=_DUP_SCHEMA, + file_search_store=store, ) - parsed = json.loads((response.text or "").strip()) - return bool(parsed.get("vague", False)), str(parsed.get("reason", "")) - except Exception as exc: - logger.warning("[_check_vague] Parse error — defaulting not vague | %s", exc) - return False, "(vague check error — defaulting to not vague)" - - -def _call_file_search( - text: str, - gemini_client: genai.Client, - model: str, - store_name: str, -) -> str: - response = gemini_client.models.generate_content( - model=model, - contents=f"Submitted idea to check for duplicates:\n\n{text}", - config=types.GenerateContentConfig( - system_instruction=_DUP_SYS, - tools=[ - types.Tool( - file_search=types.FileSearch(file_search_store_names=[store_name]) - ) - ], - temperature=0.0, - http_options=types.HttpOptions( - timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS - ), - ), - ) - return response.text or "" - - -_VERDICT_VALUES = {"DUPLICATE", "OVERLAP", "PARTIAL_MATCH", "UNIQUE"} - - -def _parse_verdict(raw: str) -> dict[str, str | None]: - fields: dict[str, str | None] = { - "verdict": "", + for idx, row in rows + ] + + +def parse_duplicate_detection_results( + outputs: list[dict[str, Any]], +) -> dict[int, dict[str, Any]]: + """Parse extracted batch outputs into {row_id: {verdict, match_title, ...}}.""" + parsed: dict[int, dict[str, Any]] = {} + for out in outputs: + key = str(out.get("row_id", "")) + if not key.startswith("dup_"): + continue + try: + idx = int(key.split("_", 1)[1]) + except (ValueError, IndexError): + continue + if out.get("error") or not out.get("output"): + 
parsed[idx] = _error_record(out.get("error") or "Empty response") + continue + try: + data = json.loads(out["output"]) + parsed[idx] = { + "verdict": str(data.get("verdict") or "UNKNOWN"), + "match_title": data.get("match_title") or None, + "source_url": data.get("source_url") or None, + "matching_sentence": data.get("matching_sentence") or None, + "reason": data.get("reason") or None, + } + except Exception as exc: + logger.warning("[parse_duplicate_detection_results] %s — %s", key, exc) + parsed[idx] = _error_record(str(exc)[:200]) + return parsed + + +def _error_record(reason: str) -> dict[str, Any]: + return { + "verdict": "ERROR", "match_title": None, "source_url": None, "matching_sentence": None, - "reason": None, - } - keymap = { - "verdict": "verdict", - "title": "match_title", - "source": "source_url", - "url": "source_url", - "matching sentence": "matching_sentence", - "reason": "reason", + "reason": reason, } - for line in (raw or "").splitlines(): - if ":" not in line: - continue - k, _, v = line.partition(":") - norm = re.sub(r"[^a-z\s]", "", k.strip().lower()).strip() - if norm in keymap: - fields[keymap[norm]] = v.strip() or None - - # Fallback: scan entire response for a known verdict token - if not fields["verdict"] or fields["verdict"] not in _VERDICT_VALUES: - m = re.search(r"\b(DUPLICATE|OVERLAP|PARTIAL_MATCH|UNIQUE)\b", raw or "") - if m: - fields["verdict"] = m.group(1) - logger.warning( - "[_parse_verdict] key-based parse missed verdict; regex fallback found: %s", - fields["verdict"], - ) - else: - logger.warning( - "[_parse_verdict] verdict not found in response. raw=%r", - (raw or "")[:500], - ) - - return fields - - -def run_duplicate_detection( - row_idx: int, - row: dict[str, str], - columns: list[str], - gemini_client: genai.Client, - model: str, - store_name: str, -) -> dict[str, Any]: - """Run duplicate detection on a single row. - - Returns a dict with: row_id, verdict, match_title, source_url, - matching_sentence, reason. 
- Always passthrough — never gates L2. - """ - content_parts = {col: row.get(col, "") for col in columns} - combined = _build_combined(content_parts) or "(empty submission)" - - try: - is_vague, vague_reason = _check_vague(combined, gemini_client, model) - except Exception as exc: - logger.warning( - "[run_duplicate_detection] Vague check failed row_%s | %s", row_idx, exc - ) - is_vague, vague_reason = False, f"(vague check error: {exc})" - - if is_vague: - return { - "row_id": f"row_{row_idx}", - "verdict": "VAGUE", - "match_title": None, - "source_url": None, - "matching_sentence": None, - "reason": vague_reason, - } - - try: - raw = _call_file_search(combined, gemini_client, model, store_name) - parsed = _parse_verdict(raw) - return { - "row_id": f"row_{row_idx}", - "verdict": parsed["verdict"] or "UNKNOWN", - "match_title": parsed["match_title"], - "source_url": parsed["source_url"], - "matching_sentence": parsed["matching_sentence"], - "reason": parsed["reason"], - } - except Exception as exc: - logger.warning( - "[run_duplicate_detection] File search failed row_%s | %s", row_idx, exc - ) - return { - "row_id": f"row_{row_idx}", - "verdict": "ERROR", - "match_title": None, - "source_url": None, - "matching_sentence": None, - "reason": str(exc)[:200], - } diff --git a/backend/app/services/assessment/prefilter/pipeline.py b/backend/app/services/assessment/prefilter/pipeline.py index 131fdde8b..55bfb4f71 100644 --- a/backend/app/services/assessment/prefilter/pipeline.py +++ b/backend/app/services/assessment/prefilter/pipeline.py @@ -1,245 +1,22 @@ -"""prefilter pipeline orchestrator. +"""Prefilter config helpers shared by the batch pipeline stages.""" -Runs two filters in series for each row: -1. Topic Relevance (go/no-go) — REJECT stops the row. -2. Duplicate Detection (passthrough) — only on ACCEPTED rows. 
- -""" - -import json -import logging -from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any -from sqlmodel import Session - -from app.core.batch.client import GeminiClient -from app.core.config import settings -from app.core.cloud import get_cloud_storage -from app.core.storage_utils import upload_jsonl_to_object_store -from app.models.assessment import AssessmentAttachment, AssessmentRun -from app.services.assessment.prefilter.duplicate_detection import ( - run_duplicate_detection, -) -from app.services.assessment.prefilter.topic_relevance import run_topic_relevance - -logger = logging.getLogger(__name__) - - -def _build_prefilter_result( - row_idx: int, - tr_result: dict[str, Any] | None, - dup_result: dict[str, Any] | None, -) -> dict[str, Any]: - return { - "row_id": f"row_{row_idx}", - "prefilter_passed": tr_result["verdict"] if tr_result else True, - "topic_relevance": { - "decision": tr_result["decision"], - "column_relevance": tr_result.get("column_relevance") or {}, - "reasoning": tr_result["reasoning"], - } - if tr_result - else None, - "duplicate_detection": dup_result, - } - - -def run_prefilter_pipeline( - run: AssessmentRun, - rows: list[dict[str, str]], - prefilter_config: dict[str, Any], - session: Session, - organization_id: int, - project_id: int, - attachments: list[AssessmentAttachment] | None = None, -) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: - """Run prefilter filters on all rows. - - Args: - run: The AssessmentRun record (used for S3 path and DB update). - rows: Full dataset rows loaded from object store. - prefilter_config: User-supplied config with topic_relevance and duplicate_detection keys. - session: DB session. - organization_id: For Gemini credential lookup. - project_id: For Gemini credential lookup and S3 storage. - - Returns: - (passed_rows, passed_indices, all_prefilter_results) - passed_rows: subset of rows where topic_relevance verdict=true. 
- passed_indices: original dataset indices of passed_rows (used to preserve row IDs in L2). - all_prefilter_results: one entry per input row (len == len(rows)). - """ - model = settings.ASSESSMENT_PREFILTER_GEMINI_MODEL - workers = settings.ASSESSMENT_PREFILTER_CONCURRENT_WORKERS - store_name = settings.ASSESSMENT_PREFILTER_DUPLICATE_STORE_NAME - # Future wait bound: the per-request HTTP timeout plus a small margin so a - # hung Gemini call surfaces as a future error instead of blocking forever. - future_timeout = settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS / 1000 + 30 +def resolve_prefilter_settings(prefilter_config: dict[str, Any]) -> dict[str, Any]: + """Flatten the prefilter config into the values the stage builders need.""" tr_config = prefilter_config.get("topic_relevance") or {} dup_config = prefilter_config.get("duplicate_detection") or {} - tr_columns: list[str] = tr_config.get("columns") or [] - tr_prompt: str = tr_config.get("prompt") or "" - dup_columns: list[str] = dup_config.get("columns") or [] - - tr_attachment_columns = tr_config.get("attachment_columns") - if tr_attachment_columns is None: - tr_attachments = list(attachments or []) - else: - selected = set(tr_attachment_columns) - tr_attachments = [a for a in (attachments or []) if a.column in selected] - - tr_enabled = bool(tr_columns and tr_prompt) - dup_enabled = bool(dup_columns) + tr_columns = tr_config.get("columns") or [] + tr_prompt = tr_config.get("prompt") or "" + dup_columns = dup_config.get("columns") or [] - if not tr_enabled and not dup_enabled: - logger.warning( - "[run_prefilter_pipeline] run_id=%s — no prefilter filters configured, skipping prefilter", - run.id, - ) - return rows, list(range(len(rows))), [] - - gemini_client = GeminiClient.from_credentials( - session=session, - org_id=organization_id, - project_id=project_id, - ).client - - logger.info( - "[run_prefilter_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", - run.id, - len(rows), - model, - 
workers, - tr_enabled, - dup_enabled, - ) - - # tr_results[idx] = None when TR disabled → no topic_relevance columns in export - # Shared across rows so each unique attachment file is type-probed once. - attachment_type_cache: dict[str, str] = {} - - tr_results: dict[int, dict[str, Any] | None] = {} - if tr_enabled: - with ThreadPoolExecutor(max_workers=workers) as executor: - futs = { - executor.submit( - run_topic_relevance, - idx, - row, - tr_columns, - tr_prompt, - gemini_client, - model, - tr_attachments, - attachment_type_cache, - ): idx - for idx, row in enumerate(rows) - } - for fut in as_completed(futs): - idx = futs[fut] - try: - tr_results[idx] = fut.result(timeout=future_timeout) - except Exception as exc: - logger.warning( - "[run_prefilter_pipeline] TR future error row_%s | %s", idx, exc - ) - tr_results[idx] = { - "row_id": f"row_{idx}", - "verdict": True, - "decision": "ACCEPT", - "column_relevance": {}, - "reasoning": f"(future error — defaulting to pass) {exc}", - } - passed_indices = [idx for idx, r in tr_results.items() if r and r["verdict"]] - else: - for idx in range(len(rows)): - tr_results[idx] = None - passed_indices = list(range(len(rows))) - - rejected_count = len(rows) - len(passed_indices) - logger.info( - "[run_prefilter_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", - run.id, - len(passed_indices), - rejected_count, - ) - - dup_results: dict[int, dict[str, Any]] = {} - if dup_columns and passed_indices: - with ThreadPoolExecutor(max_workers=workers) as executor: - futs = { - executor.submit( - run_duplicate_detection, - idx, - rows[idx], - dup_columns, - gemini_client, - model, - store_name, - ): idx - for idx in passed_indices - } - for fut in as_completed(futs): - idx = futs[fut] - try: - dup_results[idx] = fut.result(timeout=future_timeout) - except Exception as exc: - logger.warning( - "[run_prefilter_pipeline] DUP future error row_%s | %s", - idx, - exc, - ) - dup_results[idx] = { - "row_id": f"row_{idx}", - "verdict": 
"ERROR", - "match_title": None, - "source_url": None, - "matching_sentence": None, - "reason": str(exc)[:200], - } - - all_prefilter_results: list[dict[str, Any]] = [ - _build_prefilter_result(idx, tr_results[idx], dup_results.get(idx)) - for idx in range(len(rows)) - ] - - prefilter_object_store_url: str | None = None - try: - storage = get_cloud_storage(session=session, project_id=project_id) - prefilter_object_store_url = upload_jsonl_to_object_store( - storage=storage, - results=all_prefilter_results, - filename="prefilter_results.json", - subdirectory=f"assessment/run-{run.id}/prefilter", - format="json", - ) - logger.info( - "[run_prefilter_pipeline] run_id=%s | prefilter results uploaded to %s", - run.id, - prefilter_object_store_url, - ) - except Exception as exc: - logger.error( - "[run_prefilter_pipeline] run_id=%s | S3 upload failed | %s", - run.id, - exc, - exc_info=True, - ) - - from app.crud.assessment.core import update_assessment_run_prefilter_stats - - update_assessment_run_prefilter_stats( - session=session, - run=run, - prefilter_object_store_url=prefilter_object_store_url, - prefilter_total_rows=len(rows), - prefilter_total_passed=len(passed_indices), - prefilter_total_rejected=rejected_count, - ) - - sorted_passed_indices = sorted(passed_indices) - passed_rows = [rows[idx] for idx in sorted_passed_indices] - return passed_rows, sorted_passed_indices, all_prefilter_results + return { + "tr_columns": tr_columns, + "tr_prompt": tr_prompt, + "tr_attachment_columns": tr_config.get("attachment_columns"), + "dup_columns": dup_columns, + "tr_enabled": bool(tr_columns and tr_prompt), + "dup_enabled": bool(dup_columns), + } diff --git a/backend/app/services/assessment/prefilter/request_builder.py b/backend/app/services/assessment/prefilter/request_builder.py new file mode 100644 index 000000000..c9fa2e3ee --- /dev/null +++ b/backend/app/services/assessment/prefilter/request_builder.py @@ -0,0 +1,73 @@ +"""Provider-aware batch request line builder for 
prefilter stages.""" + +from typing import Any + +from app.core.config import settings +from app.services.assessment.mappers import _ensure_openai_strict_schema + + +def build_request_line( + key: str, + system: str, + user_text: str, + *, + attachment_parts: list[dict[str, Any]] | None = None, + response_schema: dict[str, Any] | None = None, + file_search_store: str | None = None, +) -> dict[str, Any]: + """Build one batch JSONL line shaped for the configured prefilter provider. + + ``attachment_parts`` are provider-shaped content parts (from the OpenAI/Gemini + attachment resolvers) appended after the text part. + """ + model = settings.ASSESSMENT_PREFILTER_MODEL + + if settings.ASSESSMENT_PREFILTER_PROVIDER == "openai": + content: list[dict[str, Any]] = [{"type": "input_text", "text": user_text}] + content.extend(attachment_parts or []) + body: dict[str, Any] = { + "model": model, + "instructions": system, + "input": [{"role": "user", "content": content}], + } + if response_schema is not None: + body["text"] = { + "format": { + "type": "json_schema", + "name": "result", + "strict": True, + "schema": _ensure_openai_strict_schema(response_schema), + } + } + if file_search_store: + body["tools"] = [ + { + "type": "file_search", + "vector_store_ids": [file_search_store], + "max_num_results": 20, + } + ] + return { + "custom_id": key, + "method": "POST", + "url": "/v1/responses", + "body": body, + } + + parts: list[dict[str, Any]] = [{"text": user_text}] + parts.extend(attachment_parts or []) + request: dict[str, Any] = { + "contents": [{"role": "user", "parts": parts}], + "systemInstruction": {"parts": [{"text": system}]}, + "model": f"models/{model}", + } + if response_schema is not None: + request["generationConfig"] = { + "responseMimeType": "application/json", + "responseSchema": response_schema, + } + if file_search_store: + request["tools"] = [ + {"fileSearch": {"fileSearchStoreNames": [file_search_store]}} + ] + return {"key": key, "request": request} diff 
--git a/backend/app/services/assessment/prefilter/topic_relevance.py b/backend/app/services/assessment/prefilter/topic_relevance.py index 053547ab7..9b4722827 100644 --- a/backend/app/services/assessment/prefilter/topic_relevance.py +++ b/backend/app/services/assessment/prefilter/topic_relevance.py @@ -1,121 +1,115 @@ -"""Topic relevance filter for prefilter pipeline. +"""Topic relevance go/no-go gate: one batch request per row (text + attachments). + +Each request returns a per-column relevance boolean for every text and attachment +column plus a final ACCEPT/REJECT verdict. """ import json import logging from typing import Any -from google import genai -from google.genai import types - from app.core.config import settings from app.models.assessment import AssessmentAttachment -from app.services.assessment.utils.attachments import build_gemini_attachment_parts +from app.services.assessment.prefilter.request_builder import build_request_line +from app.services.assessment.utils.attachments import ( + attachment_type_for_row, + build_gemini_attachment_parts, + resolve_attachment_values, +) logger = logging.getLogger(__name__) +_INSTRUCTIONS = ( + "\n\nJudge whether this submission is relevant to the topic. For EACH listed " + "column (including any attached document/image columns) set its value to true if " + "that column's content is relevant to the topic, else false. Then give a final " + "decision: ACCEPT if relevant enough to proceed, otherwise REJECT." +) + -def _build_output_schema(columns: list[str]) -> dict[str, Any]: - """Build output schema: locked decision + per-column relevance booleans + reasoning.""" +def _build_schema(columns: list[str]) -> dict[str, Any]: + """Output schema: decision + reasoning + a boolean per column.""" props: dict[str, Any] = { - "decision": { - "type": "string", - "enum": ["ACCEPT", "REJECT"], - "description": "Final verdict. 
ACCEPT to proceed to full evaluation, REJECT to stop here.", - }, + "decision": {"type": "string", "enum": ["ACCEPT", "REJECT"]}, + "reasoning": {"type": "string"}, } - required = ["decision"] - for col in columns: - props[col] = { - "type": "boolean", - "description": f"Whether the '{col}' column content is relevant to the topic.", - } - required.append(col) - - props["reasoning"] = { - "type": "string", - "description": "Explanation of the verdict and per-column relevance assessment.", + props[col] = {"type": "boolean"} + return { + "type": "object", + "properties": props, + "required": ["decision", "reasoning", *columns], } - required.append("reasoning") - return {"type": "object", "properties": props, "required": required} + +def _record_text(row: dict[str, str], columns: list[str]) -> str: + return "\n\n".join(f"{col}:\n{row.get(col, '') or ''}" for col in columns) -def run_topic_relevance( - row_idx: int, - row: dict[str, str], +def build_topic_relevance_requests( + rows: list[tuple[int, dict[str, str]]], columns: list[str], user_prompt: str, - gemini_client: genai.Client, - model: str, attachments: list[AssessmentAttachment] | None = None, - type_cache: dict[str, str] | None = None, -) -> dict[str, Any]: - """Run topic relevance check on a single row. +) -> list[dict[str, Any]]: + """Build one batch JSONL line per row, with text columns + attachment parts.""" + attachments = attachments or [] + is_openai = settings.ASSESSMENT_PREFILTER_PROVIDER == "openai" + schema = _build_schema(columns + [a.column for a in attachments]) + system = user_prompt.strip() + _INSTRUCTIONS - System instruction = user_prompt (the evaluation rubric/criteria). - User content = the selected columns as JSON plus every mapped attachment - (image/pdf) for the row, so relevance is judged on text and documents. - Each attachment column also gets its own relevance boolean in the schema, - so the export carries a ``topic_relevance_`` column. 
- Output schema enforced: decision (ACCEPT/REJECT) + reasoning. - On error defaults to verdict=True (fail-open). - """ - # Document columns that actually have a value for this row. - doc_columns: list[str] = [] - for att in attachments or []: - if att.column not in doc_columns and (row.get(att.column) or "").strip(): - doc_columns.append(att.column) - - schema_columns = columns + doc_columns - user_content = json.dumps({col: row.get(col, "") or "" for col in columns}) - output_schema = _build_output_schema(schema_columns) - - parts: list[dict[str, Any]] = [{"text": user_content}] - for att in attachments or []: - attachment_parts = build_gemini_attachment_parts( - row.get(att.column, ""), att, type_cache + lines: list[dict[str, Any]] = [] + for idx, row in rows: + attachment_parts: list[dict[str, Any]] = [] + for att in attachments: + cell = row.get(att.column, "") + if not cell.strip(): + continue + override = attachment_type_for_row(att, row) + attachment_parts.extend( + resolve_attachment_values(cell, att, type_override=override) + if is_openai + else build_gemini_attachment_parts(cell, att, type_override=override) + ) + lines.append( + build_request_line( + key=f"tr_{idx}", + system=system, + user_text=_record_text(row, columns), + attachment_parts=attachment_parts or None, + response_schema=schema, + ) ) - if attachment_parts: - parts.append({"text": f"Attached document(s) for column '{att.column}':"}) - parts.extend(attachment_parts) + return lines - try: - response = gemini_client.models.generate_content( - model=model, - contents=[{"role": "user", "parts": parts}], - config=types.GenerateContentConfig( - system_instruction=user_prompt.strip(), - response_mime_type="application/json", - response_schema=output_schema, - temperature=0.0, - http_options=types.HttpOptions( - timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS - ), - ), - ) - raw = (response.text or "").strip() - parsed = json.loads(raw) - decision = str(parsed.get("decision", 
"ACCEPT")).upper() - column_relevance = {col: bool(parsed.get(col, True)) for col in schema_columns} - return { - "row_id": f"row_{row_idx}", - "verdict": decision == "ACCEPT", - "decision": decision, - "column_relevance": column_relevance, - "reasoning": str(parsed.get("reasoning", "")), - } - except Exception as exc: - logger.warning( - "[run_topic_relevance] row_%s error — defaulting verdict=True | %s", - row_idx, - exc, - ) - return { - "row_id": f"row_{row_idx}", - "verdict": True, - "decision": "ACCEPT", - "column_relevance": {col: True for col in schema_columns}, - "reasoning": f"(evaluation error — defaulting to pass) {exc}", - } + +def parse_topic_relevance_results( + outputs: list[dict[str, Any]], +) -> dict[int, dict[str, Any]]: + """Parse outputs into {row_id: {verdict, decision, reasoning, column_relevance}}.""" + parsed: dict[int, dict[str, Any]] = {} + for out in outputs: + key = str(out.get("row_id", "")) + if not key.startswith("tr_"): + continue + try: + idx = int(key.split("_", 1)[1]) + except (ValueError, IndexError): + continue + try: + data = json.loads(out.get("output") or "") + decision = str(data.get("decision", "ACCEPT")).upper() + column_relevance = { + k: bool(v) + for k, v in data.items() + if k not in ("decision", "reasoning") + } + parsed[idx] = { + "verdict": decision == "ACCEPT", + "decision": decision, + "reasoning": str(data.get("reasoning", "")), + "column_relevance": column_relevance, + } + except Exception as exc: + logger.warning("[parse_topic_relevance_results] %s — %s", key, exc) + return parsed diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index b2e2cea05..a0bcf7487 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -24,6 +24,7 @@ AssessmentResponse, AssessmentRun, AssessmentRunSummary, + StageStatus, ) from app.models.config.config import ConfigTag from app.services.llm.providers.registry import LLMProvider @@ 
-95,7 +96,7 @@ def start_assessment( Each run is created with status='pending' and handed off to a Celery worker that runs prefilter filtering then submits the L2 batch. """ - from app.celery.tasks.job_execution import run_assessment_run + from app.celery.tasks.job_execution import run_assessment_pipeline logger.info( "[start_assessment] Starting | experiment=%s | dataset_id=%s | configs=%s | org_id=%s", @@ -193,7 +194,7 @@ def start_assessment( ) runs.append(run) - run_assessment_run.delay( + run_assessment_pipeline.delay( run_id=run.id, organization_id=organization_id, project_id=project_id, @@ -283,3 +284,77 @@ def retry_assessment_run( organization_id=organization_id, project_id=project_id, ) + + +def resume_assessment_run( + session: Session, + run: AssessmentRun, + organization_id: int, + project_id: int, +) -> AssessmentResponse: + """Re-run a failed run from its failed stage, reusing completed upstream batches.""" + from app.celery.tasks.job_execution import run_assessment_pipeline + from app.services.assessment.stages import ordered_stages + + if run.stage_status != StageStatus.FAILED: + raise HTTPException( + status_code=400, + detail=f"Run {run.id} is not in a failed state and cannot be resumed", + ) + if run.stage not in ordered_stages(run.pipeline): + raise HTTPException( + status_code=400, + detail=f"Run {run.id} has no resumable failed stage", + ) + + parent = getattr(run, "assessment", None) or session.get( + Assessment, run.assessment_id + ) + if not parent: + raise HTTPException( + status_code=404, + detail=f"Parent assessment {run.assessment_id} not found", + ) + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=parent.dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + run.stage_status = StageStatus.PENDING + run.status = "processing" + run.error_message = None + session.add(run) + session.commit() + session.refresh(run) + recompute_assessment_status(session=session, 
assessment_id=run.assessment_id) + + logger.info( + "[resume_assessment_run] Resuming run_id=%s from stage=%s", + run.id, + run.stage, + ) + run_assessment_pipeline.delay( + run_id=run.id, + organization_id=organization_id, + project_id=project_id, + trace_id=correlation_id.get() or "", + ) + + return AssessmentResponse( + assessment_id=parent.id, + experiment_name=parent.experiment_name, + dataset_id=parent.dataset_id, + dataset_name=dataset.name if dataset else None, + num_configs=1, + runs=[ + AssessmentRunSummary( + run_id=run.id, + assessment_id=run.assessment_id, + config_id=str(run.config_id), + config_version=run.config_version, + status=run.status, + ) + ], + ) diff --git a/backend/app/services/assessment/stages.py b/backend/app/services/assessment/stages.py new file mode 100644 index 000000000..ba24e4f67 --- /dev/null +++ b/backend/app/services/assessment/stages.py @@ -0,0 +1,195 @@ +"""Stage registry, pipeline ordering, and Batch API executor.""" + +import logging +from collections.abc import Callable +from typing import Any + +from sqlmodel import Session + +from app.core.batch import ( + GeminiBatchProvider, + OpenAIBatchProvider, + download_batch_results, + start_batch_job, +) +from app.core.batch.base import BatchProvider +from app.core.batch.client import GeminiClient +from app.core.cloud import get_cloud_storage +from app.core.config import settings +from app.models.assessment import AssessmentRun, Stage, StageStatus +from app.models.batch_job import BatchJob, BatchJobType +from app.services.assessment.prefilter import resolve_prefilter_settings +from app.services.assessment.prefilter.duplicate_detection import ( + build_duplicate_detection_requests, + parse_duplicate_detection_results, +) +from app.services.assessment.prefilter.topic_relevance import ( + build_topic_relevance_requests, + parse_topic_relevance_results, +) +from app.services.llm.providers.registry import LLMProvider +from app.utils import get_openai_client + +logger = 
logging.getLogger(__name__) + +# Stages that gate the pipeline (only ACCEPTed rows continue). Others annotate. +GATE_STAGES = {Stage.PRE_FILTER_TOPIC_RELEVANCE} + +# Result parser per stage: raw batch results -> {row_id: result dict}. +STAGE_PARSERS: dict[str, Callable[[list[dict]], dict[int, dict[str, Any]]]] = { + Stage.PRE_FILTER_TOPIC_RELEVANCE: parse_topic_relevance_results, + Stage.PRE_FILTER_DUPLICATE_DETECTION: parse_duplicate_detection_results, +} + + +def build_pipeline(assessment_input: dict[str, Any]) -> dict[str, Any]: + """Build the ordered stage config; prefilter stages added only when configured.""" + cfg = resolve_prefilter_settings(assessment_input.get("prefilter_config") or {}) + stages: list[dict[str, Any]] = [] + if cfg["tr_enabled"]: + stages.append({"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "type": "GO_NO_GO"}) + if cfg["dup_enabled"]: + stages.append( + {"stage": Stage.PRE_FILTER_DUPLICATE_DETECTION, "type": "ANNOTATIVE"} + ) + stages.append({"stage": Stage.L2_ASSESSMENT, "type": "ASSESSMENT"}) + + for order, entry in enumerate(stages, start=1): + entry["order"] = order + return {"stages": stages} + + +def ordered_stages(pipeline: dict[str, Any] | None) -> list[str]: + """The stage names in execution order.""" + return [s["stage"] for s in (pipeline or {}).get("stages", [])] + + +def next_stage( + pipeline: dict[str, Any] | None, current: str | None = None +) -> str | None: + """First stage when ``current`` is None, else the stage after it (None if last).""" + stages = ordered_stages(pipeline) + if current is None: + return stages[0] if stages else None + if current in stages and stages.index(current) + 1 < len(stages): + return stages[stages.index(current) + 1] + return None + + +def submit_prefilter_batch( + session: Session, + organization_id: int, + project_id: int, + jsonl_data: list[dict[str, Any]], + display_name: str, +) -> BatchJob: + """Submit a prefilter batch on the configured provider and return the BatchJob.""" + base = 
settings.ASSESSMENT_PREFILTER_PROVIDER + provider = _get_batch_provider( + session=session, + provider_name=base, + organization_id=organization_id, + project_id=project_id, + ) + if base == "openai": + config = { + "endpoint": "/v1/responses", + "completion_window": "24h", + "description": display_name, + } + else: + config = { + "display_name": display_name, + "model": f"models/{settings.ASSESSMENT_PREFILTER_MODEL}", + } + return start_batch_job( + session=session, + provider=provider, + provider_name=base, + job_type=BatchJobType.ASSESSMENT, + organization_id=organization_id, + project_id=project_id, + jsonl_data=jsonl_data, + config=config, + ) + + +def build_prefilter_requests( + stage: str, + rows: list[tuple[int, dict[str, str]]], + cfg: dict[str, Any], + attachments: list | None = None, +) -> list[dict[str, Any]]: + """Build the JSONL request lines for a prefilter stage.""" + if stage == Stage.PRE_FILTER_TOPIC_RELEVANCE: + return build_topic_relevance_requests( + rows, cfg["tr_columns"], cfg["tr_prompt"], attachments + ) + if stage == Stage.PRE_FILTER_DUPLICATE_DETECTION: + return build_duplicate_detection_requests(rows, cfg["dup_columns"]) + raise ValueError(f"Unknown prefilter stage: {stage}") + + +def _get_batch_provider( + session: Session, + provider_name: str, + organization_id: int, + project_id: int, +) -> BatchProvider: + """Build the batch provider instance for a given provider name.""" + if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): + return OpenAIBatchProvider( + client=get_openai_client( + session=session, org_id=organization_id, project_id=project_id + ) + ) + if provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): + gemini_client = GeminiClient.from_credentials( + session=session, org_id=organization_id, project_id=project_id + ) + return GeminiBatchProvider(client=gemini_client.client) + raise ValueError(f"Unsupported batch provider: {provider_name}") + + +def load_raw_batch_results( + session: Session, 
batch_job: BatchJob, project_id: int +) -> list[dict[str, Any]]: + """Load a completed batch's raw result lines (object store first, else provider).""" + # Lazy import: app.services.assessment.utils.__init__ pulls in export, which + # imports this module's package — a top-level import would be circular. + from app.services.assessment.utils.parsing import parse_stored_results + + if batch_job.raw_output_url: + try: + storage = get_cloud_storage(session, project_id=project_id) + raw = parse_stored_results( + storage.stream(batch_job.raw_output_url).read().decode("utf-8") + ) + if raw: + return raw + except Exception as exc: + logger.warning( + "[load_raw_batch_results] S3 read failed batch %s — %s", + batch_job.id, + exc, + ) + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=batch_job.organization_id, + project_id=project_id, + ) + return download_batch_results(provider=provider, batch_job=batch_job) + + +def advance_or_finalize(run: AssessmentRun) -> str | None: + """Advance the run to the next stage (returned) or finalize it (returns None).""" + nxt = next_stage(run.pipeline, run.stage) + if nxt: + run.stage = nxt + run.stage_status = StageStatus.PENDING + return nxt + run.stage = Stage.COMPLETED + run.stage_status = StageStatus.COMPLETED + run.status = "completed" + return None diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 909a89a25..2360583f8 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -1,9 +1,13 @@ -"""Celery task logic for running a single assessment run (prefilter → L2 batch submit).""" +"""Orchestrator: submit the run's current PENDING stage as a batch, then exit.""" import logging +from asgi_correlation_id import correlation_id +from celery.exceptions import SoftTimeLimitExceeded +from sqlalchemy.orm.attributes import flag_modified from sqlmodel import Session +from 
app.celery.tasks.job_execution import run_assessment_pipeline from app.core.db import engine from app.crud.assessment import ( get_assessment_dataset_by_id, @@ -11,194 +15,260 @@ update_assessment_run_status, ) from app.crud.assessment.batch import _load_dataset_rows, submit_assessment_batch -from app.crud.config import ConfigCrud +from app.crud.assessment.processing import parse_assessment_output from app.crud.evaluations.core import resolve_evaluation_config +from app.crud.job import get_batch_job from app.models.assessment import ( Assessment, AssessmentAttachment, AssessmentRun, + Stage, + StageStatus, ) from app.models.config.config import ConfigTag -from app.services.assessment.prefilter import run_prefilter_pipeline +from app.services.assessment.prefilter import resolve_prefilter_settings +from app.services.assessment.stages import ( + GATE_STAGES, + STAGE_PARSERS, + advance_or_finalize, + build_pipeline, + build_prefilter_requests, + load_raw_batch_results, + next_stage, + ordered_stages, + submit_prefilter_batch, +) logger = logging.getLogger(__name__) +_PREFILTER_STAGES = { + Stage.PRE_FILTER_TOPIC_RELEVANCE, + Stage.PRE_FILTER_DUPLICATE_DETECTION, +} + + +def _mark_run_failed(run_id: int, error_message: str) -> None: + """Fail a run from a fresh session so a killed task leaves no dangling run.""" + try: + with Session(engine) as session: + run = session.get(AssessmentRun, run_id) + if ( + run is None + or run.stage == Stage.COMPLETED + or run.stage_status == StageStatus.FAILED + ): + return + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, run=run, status="failed", error_message=error_message + ) + recompute_assessment_status( + session=session, assessment_id=run.assessment_id + ) + logger.info("[_mark_run_failed] run_id=%s marked failed", run_id) + except Exception: + logger.error( + "[_mark_run_failed] could not mark run_id=%s failed", run_id, exc_info=True + ) + -def execute_assessment_run( - run_id: int, - 
organization_id: int, - project_id: int, +def execute_assessment_pipeline( + run_id: int, organization_id: int, project_id: int ) -> None: - """Run prefilter filtering then submit L2 batch for one AssessmentRun. + """Guarded entrypoint: submit the run's current stage, never leave it dangling.""" + try: + _orchestrate(run_id, organization_id, project_id) + except SoftTimeLimitExceeded: + logger.error("[execute_assessment_pipeline] soft time limit run_id=%s", run_id) + _mark_run_failed(run_id, "Assessment run exceeded the time limit.") + raise + except Exception: + logger.error( + "[execute_assessment_pipeline] unexpected failure run_id=%s", + run_id, + exc_info=True, + ) + _mark_run_failed(run_id, "Assessment run failed unexpectedly.") + raise + + +def _dispatch(run_id: int, organization_id: int, project_id: int) -> None: + run_assessment_pipeline.delay( + run_id=run_id, + organization_id=organization_id, + project_id=project_id, + trace_id=correlation_id.get() or "", + ) + + +def _resolve_run_context( + session: Session, run: AssessmentRun, organization_id: int, project_id: int +): + """Load the assessment, dataset, and resolved config; ``error`` set on failure.""" + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: + return None, None, None, "Parent assessment not found." 
+ dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=assessment.dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + config_blob, error = resolve_evaluation_config( + session=session, + config_id=run.config_id, + config_version=run.config_version, + project_id=project_id, + tag=ConfigTag.ASSESSMENT, + ) + if error or config_blob is None: + return assessment, dataset, None, f"Config resolution failed: {error}" + return assessment, dataset, config_blob, None + - Status transitions: - pending → prefilter_processing → prefilter_failed (stop) - → l2_processing → (cron handles rest) - pending → l2_processing (when no prefilter_config) - """ +def _accepted_indices( + session: Session, run: AssessmentRun, total_rows: int, project_id: int +) -> list[int]: + """Row indices that passed every gate stage before the current one.""" + accepted = set(range(total_rows)) + for stage in ordered_stages(run.pipeline): + if stage == run.stage: + break + if stage not in GATE_STAGES: + continue + batch_id = (run.stage_batches or {}).get(stage) + if batch_id is None: + continue + batch_job = get_batch_job(session=session, batch_job_id=batch_id) + if not batch_job: + continue + raw = load_raw_batch_results(session, batch_job, project_id) + outputs = parse_assessment_output(raw, batch_job.provider) + parsed = STAGE_PARSERS[stage](outputs) + accepted &= {idx for idx, r in parsed.items() if r.get("verdict")} + return sorted(accepted) + + +def _orchestrate(run_id: int, organization_id: int, project_id: int) -> None: with Session(engine) as session: run = session.get(AssessmentRun, run_id) if run is None: - logger.error("[execute_assessment_run] run_id=%s not found", run_id) + logger.error("[execute_assessment_pipeline] run_id=%s not found", run_id) + return + if run.stage == Stage.COMPLETED or run.stage_status == StageStatus.FAILED: return - assessment = session.get(Assessment, run.assessment_id) - if assessment is None: - logger.error( - 
"[execute_assessment_run] parent assessment %s not found for run %s", - run.assessment_id, - run_id, - ) + if not run.pipeline: + run.pipeline = build_pipeline(run.input or {}) + flag_modified(run, "pipeline") + if run.stage is None: + run.stage = next_stage(run.pipeline) + run.stage_status = StageStatus.PENDING + run.status = "processing" + if run.stage_status != StageStatus.PENDING: + session.add(run) + session.commit() return + session.add(run) + session.commit() + session.refresh(run) - assessment_input = run.input or {} - dataset_id = assessment.dataset_id + _submit_stage(session, run, organization_id, project_id) - dataset = get_assessment_dataset_by_id( + +def _submit_stage( + session: Session, run: AssessmentRun, organization_id: int, project_id: int +) -> None: + assessment, dataset, config_blob, error = _resolve_run_context( + session, run, organization_id, project_id + ) + if error: + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, run=run, status="failed", error_message=error + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return + + all_rows = _load_dataset_rows(session, dataset) + if not all_rows: + run.stage_status = StageStatus.FAILED + update_assessment_run_status( session=session, - dataset_id=dataset_id, - organization_id=organization_id, - project_id=project_id, + run=run, + status="failed", + error_message="Dataset has no rows.", ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return - config_crud = ConfigCrud(session=session, project_id=project_id) - parent_config = config_crud.read_one(run.config_id) - if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: - logger.error( - "[execute_assessment_run] config %s has wrong tag for run %s", - run.config_id, - run_id, - ) - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message="Config tag is not ASSESSMENT.", - ) - 
recompute_assessment_status(session=session, assessment_id=assessment.id) - return + accepted = _accepted_indices(session, run, len(all_rows), project_id) + rows_with_idx = [(i, all_rows[i]) for i in accepted] + stage = run.stage + + if not rows_with_idx: + # Nothing left for this stage (all rows rejected upstream) — advance. + _persist_advance(session, run, organization_id, project_id) + return - config_blob, error = resolve_evaluation_config( + if stage in _PREFILTER_STAGES: + cfg = resolve_prefilter_settings(run.input.get("prefilter_config") or {}) + attachments = [ + AssessmentAttachment(**a) for a in (run.input.get("attachments") or []) + ] + selected = cfg.get("tr_attachment_columns") + if selected is not None: + attachments = [a for a in attachments if a.column in set(selected)] + jsonl = build_prefilter_requests(stage, rows_with_idx, cfg, attachments) + batch_job = submit_prefilter_batch( session=session, - config_id=run.config_id, - config_version=run.config_version, + organization_id=organization_id, project_id=project_id, - tag=ConfigTag.ASSESSMENT, + jsonl_data=jsonl, + display_name=f"assessment-{run.id}-{stage}", ) - if error or config_blob is None: - logger.error( - "[execute_assessment_run] config resolution failed run_id=%s: %s", - run_id, - error, - ) - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=f"Config resolution failed: {error}", - ) - recompute_assessment_status(session=session, assessment_id=assessment.id) - return + elif stage == Stage.L2_ASSESSMENT: + batch_job = submit_assessment_batch( + session=session, + run=run, + assessment=assessment, + dataset=dataset, + config_blob=config_blob, + assessment_input=run.input or {}, + organization_id=organization_id, + project_id=project_id, + preloaded_rows=[r for _, r in rows_with_idx], + row_indices=[i for i, _ in rows_with_idx], + ) + run.total_items = batch_job.total_items + else: + raise ValueError(f"Unknown stage: {stage}") - all_rows = 
_load_dataset_rows(session=session, dataset=dataset) - if not all_rows: - logger.error( - "[execute_assessment_run] dataset %s has no rows for run %s", - dataset_id, - run_id, - ) - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message="Dataset has no rows.", - ) - recompute_assessment_status(session=session, assessment_id=assessment.id) - return + stage_batches = dict(run.stage_batches or {}) + stage_batches[stage] = batch_job.id + run.stage_batches = stage_batches + flag_modified(run, "stage_batches") + run.stage_status = StageStatus.PROCESSING + run.status = "processing" + session.add(run) + session.commit() + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + + logger.info( + "[execute_assessment_pipeline] run_id=%s | stage=%s submitted | batch=%s | rows=%s", + run.id, + stage, + batch_job.id, + len(rows_with_idx), + ) - # prefilter pipeline - rows_for_l2 = all_rows - row_indices_for_l2: list[int] | None = None - prefilter_config = assessment_input.get("prefilter_config") - if prefilter_config: - update_assessment_run_status( - session=session, run=run, status="prefilter_processing" - ) - try: - rows_for_l2, row_indices_for_l2, _ = run_prefilter_pipeline( - run=run, - rows=all_rows, - prefilter_config=prefilter_config, - session=session, - organization_id=organization_id, - project_id=project_id, - attachments=[ - AssessmentAttachment(**a) - for a in assessment_input.get("attachments") or [] - ], - ) - logger.info( - "[execute_assessment_run] prefilter done | run_id=%s | rows_to_l2=%s / %s", - run_id, - len(rows_for_l2), - len(all_rows), - ) - except Exception as prefilter_exc: - logger.error( - "[execute_assessment_run] prefilter failed run_id=%s | %s", - run_id, - prefilter_exc, - exc_info=True, - ) - update_assessment_run_status( - session=session, - run=run, - status="prefilter_failed", - error_message=f"prefilter pipeline failed: {prefilter_exc}", - ) - recompute_assessment_status( - 
session=session, assessment_id=assessment.id - ) - return # L2 does not run when prefilter fails - - # L2 batch submit - try: - batch_job = submit_assessment_batch( - session=session, - run=run, - assessment=assessment, - dataset=dataset, - config_blob=config_blob, - assessment_input=assessment_input, - organization_id=organization_id, - project_id=project_id, - preloaded_rows=rows_for_l2, - row_indices=row_indices_for_l2, - ) - update_assessment_run_status( - session=session, - run=run, - status="l2_processing", - batch_job_id=batch_job.id, - total_items=batch_job.total_items, - ) - logger.info( - "[execute_assessment_run] L2 batch submitted | run_id=%s | batch_job_id=%s", - run_id, - batch_job.id, - ) - except Exception as e: - logger.error( - "[execute_assessment_run] L2 batch submit failed run_id=%s: %s", - run_id, - e, - exc_info=True, - ) - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message="Batch submission failed. Please try again or contact support.", - ) - recompute_assessment_status(session=session, assessment_id=assessment.id) +def _persist_advance( + session: Session, run: AssessmentRun, organization_id: int, project_id: int +) -> None: + nxt = advance_or_finalize(run) + session.add(run) + session.commit() + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + if nxt: + _dispatch(run.id, organization_id, project_id) diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 87ca3aba7..7228ada71 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -9,12 +9,9 @@ import logging import re from typing import Any -from urllib.parse import urljoin, urlparse - -import requests +from urllib.parse import urlparse from app.models.assessment import AssessmentAttachment -from app.utils import validate_callback_url logger = logging.getLogger(__name__) @@ 
-122,15 +119,6 @@ def _guess_image_mime_from_base64(payload: str) -> str | None: return _image_mime_from_magic(blob) -def _type_from_magic(blob: bytes) -> str | None: - """Detect 'image' or 'pdf' from leading magic bytes; None if neither.""" - if blob.startswith(b"%PDF"): - return "pdf" - if _image_mime_from_magic(blob): - return "image" - return None - - def resolve_image_mime_and_payload( value: str, format_type: str, @@ -146,120 +134,61 @@ def resolve_image_mime_and_payload( return _guess_image_mime_from_base64(payload) or "image/png", payload -def _drive_file_id(url: str) -> str | None: - """Extract a Google Drive file id from common share URL shapes.""" - match = re.match(r"https://drive\.google\.com/file/d/([^/]+)", url) - if match: - return match.group(1) - match = re.search(r"[?&]id=([a-zA-Z0-9_-]+)", url) - if match and ("drive.google.com" in url or "drive.usercontent.google.com" in url): - return match.group(1) - return None +def resolve_item_type(declared: str, type_override: str | None = None) -> str: + """Resolve an attachment item as 'image' or 'pdf' from the user-declared type. + Trusts the user: a per-row ``type_override`` (for 'mixed' columns) wins, else the + column's declared ``type``. Anything non-concrete falls back to 'image'. 
+ """ + item_type = type_override or declared + return item_type if item_type in ("image", "pdf") else "image" -def _type_from_url_extension(url: str) -> str | None: - """Detect 'image' or 'pdf' from a URL path extension; None if unknown.""" - path = (urlparse(url).path or "").lower() - if path.endswith(".pdf"): - return "pdf" - if _guess_image_mime_from_url(url): - return "image" - return None +def _normalize_type_value(value: str) -> str: + return re.sub(r"\s+", " ", value).strip().casefold() -def _type_from_content_type(content_type: str | None) -> str | None: - if not content_type: - return None - content_type = content_type.split(";")[0].strip().lower() - if content_type == "application/pdf": - return "pdf" - if content_type.startswith("image/"): - return "image" - return None +def _split_type_values(value: str) -> list[str]: + return [ + normalized + for part in re.split(r"[\n,]+", value) + if (normalized := _normalize_type_value(part)) + ] -_PROBE_MAX_REDIRECTS = 3 +def attachment_type_for_row( + att: AssessmentAttachment, row: dict[str, str] +) -> str | None: + """For a 'mixed' column, resolve this row's type from type_column + type_value_map. -def _probe_url_type(url: str, num_bytes: int = 16) -> str | None: - """Probe a remote URL's type: ranged byte sniff first, Content-Type fallback. 
- Handles Google Drive URLs with the same logic as to_direct_attachment_url, since""" - file_id = _drive_file_id(url) - current = ( - f"https://drive.google.com/uc?export=download&id={file_id}" if file_id else url - ) - - try: - for _ in range(_PROBE_MAX_REDIRECTS + 1): - validate_callback_url(current) - with requests.get( - current, - headers={"Range": f"bytes=0-{num_bytes - 1}"}, - timeout=10, - stream=True, - allow_redirects=False, - ) as resp: - location = resp.headers.get("Location") - if resp.is_redirect and location: - current = urljoin(current, location) - continue - resp.raise_for_status() - for chunk in resp.iter_content(chunk_size=num_bytes): - magic_type = _type_from_magic(chunk) - if magic_type: - return magic_type - break - return _type_from_content_type(resp.headers.get("Content-Type")) - logger.warning(f"[_probe_url_type] Too many redirects probing {url}") - return None - except ValueError as e: - logger.warning(f"[_probe_url_type] Blocked unsafe probe URL {url}: {e}") - return None - except requests.RequestException as e: - logger.warning(f"[_probe_url_type] Probe failed for {url}: {e}") + Returns 'image'/'pdf', or None to let normal detection (extension/declared) decide. + """ + type_column = getattr(att, "type_column", None) + type_value_map = getattr(att, "type_value_map", None) + if att.type != "mixed" or not type_column or not type_value_map: return None + normalized_map: dict[str, str] = {} + for raw_values, mapped_type in type_value_map.items(): + if mapped_type not in ("image", "pdf"): + continue + for value in _split_type_values(raw_values): + normalized_map[value] = mapped_type -def detect_item_type( - value: str, - format_type: str, - fallback: str, - cache: dict[str, str] | None = None, -) -> str: - """Resolve a single attachment item as 'image' or 'pdf'. - - Order: data-URL/base64 magic (no network) -> URL extension -> remote probe - (ranged byte sniff, then Content-Type) -> declared ``fallback`` type. 
- ``fallback`` may be 'mixed'; when detection is inconclusive it resolves to - 'image'. Remote probe results are memoized in ``cache`` keyed by item value. - """ - # 'mixed' is not a concrete output type; terminal default is image. - safe_fallback = fallback if fallback in ("image", "pdf") else "image" - - if format_type != "url": - data_url_mime, payload = split_data_url(value) - if data_url_mime == "application/pdf": - return "pdf" - if data_url_mime and data_url_mime.startswith("image/"): - return "image" - blob = _decode_base64_prefix(payload) - return (_type_from_magic(blob) if blob else None) or safe_fallback - - if cache is not None and value in cache: - return cache[value] - - item_type = ( - _type_from_url_extension(value) or _probe_url_type(value) or safe_fallback - ) - if cache is not None: - cache[value] = item_type - return item_type + row_values = _split_type_values(row.get(type_column) or "") + if not row_values: + return None + + mapped_values = { + normalized_map[value] for value in row_values if value in normalized_map + } + return mapped_values.pop() if len(mapped_values) == 1 else None def resolve_attachment_values( value: str, att: AssessmentAttachment, - type_cache: dict[str, str] | None = None, + type_override: str | None = None, ) -> list[dict[str, Any]]: """Convert one dataset cell into one or more OpenAI-style input objects.""" value = value.strip() @@ -271,9 +200,9 @@ def resolve_attachment_values( else: values = [value] + item_type = resolve_item_type(att.type, type_override) resolved: list[dict[str, Any]] = [] for item_value in values: - item_type = detect_item_type(item_value, att.format, att.type, type_cache) normalized_value = ( to_direct_attachment_url(item_value, item_type) if att.format == "url" @@ -318,12 +247,12 @@ def resolve_attachment_values( def build_gemini_attachment_parts( value: str, att: AssessmentAttachment, - type_cache: dict[str, str] | None = None, + type_override: str | None = None, ) -> list[dict[str, Any]]: 
"""Convert one dataset cell into one or more Gemini content parts. - Mirrors the per-item type detection used for the L2 batch so the same - image/pdf routing applies to prefilter (topic relevance) calls. + Mirrors the per-item type routing used for the L2 batch so the same + image/pdf handling applies to prefilter (topic relevance) calls. """ value = value.strip() if not value: @@ -331,9 +260,9 @@ def build_gemini_attachment_parts( values = split_attachment_urls(value) if att.format == "url" else [value] + item_type = resolve_item_type(att.type, type_override) parts: list[dict[str, Any]] = [] for item_value in values: - item_type = detect_item_type(item_value, att.format, att.type, type_cache) normalized_value = ( to_direct_attachment_url(item_value, item_type) if att.format == "url" diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index 39fa7691c..c83714d9a 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -12,17 +12,32 @@ from fastapi.responses import StreamingResponse from sqlmodel import Session +from app.core.batch import download_batch_results from app.core.cloud import get_cloud_storage from app.core.storage_utils import generate_timestamped_filename from app.crud.assessment.processing import parse_assessment_output from app.crud.job import get_batch_job -from app.models.assessment import Assessment, AssessmentExportRow, AssessmentRun +from app.models.assessment import ( + Assessment, + AssessmentExportRow, + AssessmentRun, + Stage, +) from app.models.batch_job import BatchJob from app.models.evaluation import EvaluationDataset +from app.services.assessment.prefilter.duplicate_detection import ( + parse_duplicate_detection_results, +) +from app.services.assessment.prefilter.topic_relevance import ( + parse_topic_relevance_results, +) +from app.services.assessment.stages import _get_batch_provider, load_raw_batch_results from 
app.services.assessment.utils.parsing import parse_stored_results, usage_totals +from app.services.assessment.utils.post_processing import apply_post_processing from app.utils import APIResponse _PREFILTER_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] +_XLSX_ILLEGAL_RE = re.compile("[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f\ud800-\udfff﷐-﷯￾￿]") logger = logging.getLogger(__name__) @@ -31,32 +46,59 @@ def _load_dataset_rows( session: Session, dataset: EvaluationDataset, ) -> list[dict[str, str]]: + # Imported lazily: app.crud.assessment.batch pulls this module via + # app.services.assessment.utils, so a top-level import would be circular. from app.crud.assessment.batch import _load_dataset_rows as load_dataset_rows return load_dataset_rows(session, dataset) +def _stage_batch_job( + session: Session, run: AssessmentRun, stage: str +) -> BatchJob | None: + """The batch job a run produced for a given stage, via stage_batches.""" + batch_id = (run.stage_batches or {}).get(stage) + return get_batch_job(session=session, batch_job_id=batch_id) if batch_id else None + + def _load_prefilter_results( session: Session, run: AssessmentRun, assessment: Assessment, ) -> dict[str, dict[str, Any]]: - """Load prefilter results from object store, keyed by row_id. 
Returns {} if unavailable.""" - if not run.prefilter_object_store_url: - return {} - try: - storage = get_cloud_storage(session, project_id=assessment.project_id) - body = storage.stream(run.prefilter_object_store_url) - raw = body.read().decode("utf-8") - results: list[dict[str, Any]] = json.loads(raw) - return {str(item["row_id"]): item for item in results if "row_id" in item} - except Exception as exc: - logger.warning( - "[_load_prefilter_results] Failed to load prefilter results for run id=%s: %s", - run.id, - exc, - ) - return {} + """Build per-row prefilter annotations from the TR + dup stage batches.""" + out: dict[str, dict[str, Any]] = {} + + tr_job = _stage_batch_job(session, run, Stage.PRE_FILTER_TOPIC_RELEVANCE) + if tr_job: + try: + raw = load_raw_batch_results(session, tr_job, assessment.project_id) + outputs = parse_assessment_output(raw, tr_job.provider) + for idx, r in parse_topic_relevance_results(outputs).items(): + out.setdefault(f"row_{idx}", {})["prefilter_passed"] = r["verdict"] + out[f"row_{idx}"]["topic_relevance"] = { + "decision": r["decision"], + "reasoning": r["reasoning"], + "column_relevance": r.get("column_relevance") or {}, + } + except Exception as exc: + logger.warning( + "[_load_prefilter_results] TR load failed run=%s: %s", run.id, exc + ) + + dup_job = _stage_batch_job(session, run, Stage.PRE_FILTER_DUPLICATE_DETECTION) + if dup_job: + try: + raw = load_raw_batch_results(session, dup_job, assessment.project_id) + outputs = parse_assessment_output(raw, dup_job.provider) + for idx, r in parse_duplicate_detection_results(outputs).items(): + out.setdefault(f"row_{idx}", {})["duplicate_detection"] = r + except Exception as exc: + logger.warning( + "[_load_prefilter_results] dup load failed run=%s: %s", run.id, exc + ) + + return out def _safe_filename_part(value: str) -> str: @@ -244,8 +286,6 @@ def serialize_export_rows( post_processing_config: dict[str, Any] | None = None, ) -> tuple[bytes, str]: """Serialize export rows into the 
requested file format.""" - from app.services.assessment.utils.post_processing import apply_post_processing - row_payload = [row.model_dump(mode="json") for row in export_rows] if export_format == "json": @@ -297,6 +337,11 @@ def serialize_export_rows( expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) + def _clean(value: Any) -> Any: + return _XLSX_ILLEGAL_RE.sub("", value) if isinstance(value, str) else value + + expanded = [{k: _clean(v) for k, v in row.items()} for row in expanded] + buf = io.BytesIO() data_frame = pd.DataFrame(expanded, columns=excel_fields) with pd.ExcelWriter(buf) as writer: @@ -377,9 +422,6 @@ def _load_parsed_results_for_run( # 2. Fallback: download directly from batch provider if batch_job.provider_output_file_id: try: - from app.core.batch import download_batch_results - from app.crud.assessment.processing import _get_batch_provider - provider = _get_batch_provider( session=session, provider_name=batch_job.provider, @@ -463,18 +505,84 @@ def _extract_prefilter_json_columns( } +def _load_parsed_results_for_batch_job( + session: Session, + batch_job: BatchJob, + assessment: Assessment, +) -> list[dict[str, Any]] | None: + """Parse one chunk batch's stored results (object store first, provider fallback).""" + if batch_job.raw_output_url: + try: + storage = get_cloud_storage(session, project_id=assessment.project_id) + raw = parse_stored_results( + storage.stream(batch_job.raw_output_url).read().decode("utf-8") + ) + if raw: + return parse_assessment_output(raw, batch_job.provider) + except Exception as exc: + logger.warning( + "[_load_parsed_results_for_batch_job] S3 read failed for batch %s: %s", + batch_job.id, + exc, + ) + + if batch_job.provider_output_file_id: + try: + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=assessment.organization_id, + project_id=assessment.project_id, + ) + raw = download_batch_results(provider=provider, batch_job=batch_job) + return 
parse_assessment_output(raw, batch_job.provider) + except Exception as exc: + logger.error( + "[_load_parsed_results_for_batch_job] Provider download failed for " + "batch %s: %s", + batch_job.id, + exc, + exc_info=True, + ) + return None + + +def _load_l2_results_for_run( + session: Session, + run: AssessmentRun, + assessment: Assessment, +) -> dict[str, dict[str, Any]]: + """L2 results keyed by row_id, from the run's L2 stage batch ({} if not done).""" + merged: dict[str, dict[str, Any]] = {} + batch_job = _stage_batch_job(session, run, Stage.L2_ASSESSMENT) + if batch_job: + for item in ( + _load_parsed_results_for_batch_job(session, batch_job, assessment) or [] + ): + if "row_id" in item: + merged[str(item["row_id"])] = item + return merged + + +def _row_result_status( + prefilter_passed: bool, + l2_item: dict[str, Any] | None, + run_status: str, +) -> str: + """Per-row status: rejected, failed, passed, or processing (batch not done).""" + if not prefilter_passed: + return "prefilter_rejected" + if l2_item is None: + return "failed" if run_status == "failed" else "processing" + return "failed" if l2_item.get("error") else "passed" + + def load_export_rows_for_run( session: Session, run: AssessmentRun, assessment: Assessment | None = None, ) -> list[AssessmentExportRow]: - """Load flattened export rows for a single child assessment run. - - When prefilter results exist, ALL dataset rows are included in output. - prefilter-rejected rows have prefilter columns filled and L2 columns empty. - prefilter-passed rows have all columns filled. - Without prefilter, behaviour is unchanged (only L2 result rows returned). 
- """ + """Flatten one run's rows, merging prefilter annotations + L2 results by row_id.""" if assessment is None: assessment = session.get(Assessment, run.assessment_id) if assessment is None: @@ -488,133 +596,84 @@ def load_export_rows_for_run( dataset_name = dataset.name if dataset else None dataset_rows = _load_dataset_rows_for_run(session, run, assessment) - # Load prefilter results (empty dict if no prefilter was run) prefilter_by_row_id = _load_prefilter_results(session, run, assessment) - - # Load L2 results (may be None if batch not complete) - l2_by_row_id: dict[str, dict[str, Any]] = {} - if run.batch_job_id: - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if batch_job: - parsed_results = _load_parsed_results_for_run( - session=session, run=run, batch_job=batch_job - ) - if parsed_results: - l2_by_row_id = { - str(item["row_id"]): item - for item in parsed_results - if "row_id" in item - } - + l2_by_row_id = _load_l2_results_for_run(session, run, assessment) has_prefilter = bool(prefilter_by_row_id) - if has_prefilter and dataset_rows: - # All rows in output — build from full dataset - export_rows: list[AssessmentExportRow] = [] - for row_idx, input_data in enumerate(dataset_rows): - row_id_str = f"row_{row_idx}" - prefilter_item = prefilter_by_row_id.get(row_id_str) - prefilter_cols = _extract_prefilter_json_columns(prefilter_item) - l2_item = l2_by_row_id.get(row_id_str) - - input_tokens, output_tokens, total_tokens = usage_totals( - l2_item.get("usage") if l2_item else None - ) - prefilter_passed = (prefilter_item or {}).get("prefilter_passed", True) - result_status = ( - "prefilter_rejected" - if not prefilter_passed - else ("failed" if l2_item and l2_item.get("error") else "passed") - ) - - export_rows.append( - AssessmentExportRow( - assessment_id=run.assessment_id, - experiment_name=assessment.experiment_name, - dataset_id=assessment.dataset_id, - dataset_name=dataset_name, - run_id=run.id, - 
run_name=assessment.experiment_name, - run_status=run.status, - config_id=run.config_id, - config_version=run.config_version, - row_id=row_id_str, - result_status=result_status, - input_data=input_data, - topic_relevance=prefilter_cols.get("topic_relevance"), - duplicate_detection=prefilter_cols.get("duplicate_detection"), - output=l2_item.get("output") if l2_item else None, - error=l2_item.get("error") if l2_item else None, - response_id=l2_item.get("response_id") if l2_item else None, - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - updated_at=run.updated_at, - ) - ) - return export_rows - - # No prefilter — original behaviour: only L2 result rows - if not run.batch_job_id: - logger.warning( - "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id - ) - return [] - - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if not batch_job: - logger.warning( - "[load_export_rows_for_run] Missing batch job for run id=%s", run.id - ) - return [] - - parsed_results = _load_parsed_results_for_run( - session=session, run=run, batch_job=batch_job - ) - if not parsed_results: - logger.warning( - "[load_export_rows_for_run] Parsed results empty for run id=%s", run.id - ) - return [] - - export_rows = [] - for item in parsed_results: - input_tokens, output_tokens, total_tokens = usage_totals(item.get("usage")) - input_data = None - row_id_str = str(item.get("row_id", "")) - if dataset_rows and row_id_str.startswith("row_"): - try: - row_idx = int(row_id_str.split("_", 1)[1]) - if 0 <= row_idx < len(dataset_rows): - input_data = dataset_rows[row_idx] - except (ValueError, IndexError): - pass - - export_rows.append( - AssessmentExportRow( - assessment_id=run.assessment_id, - experiment_name=assessment.experiment_name, - dataset_id=assessment.dataset_id, + if dataset_rows: + rows = [ + _build_export_row( + run=run, + assessment=assessment, dataset_name=dataset_name, - run_id=run.id, - 
run_name=assessment.experiment_name, - run_status=run.status, - config_id=run.config_id, - config_version=run.config_version, - row_id=row_id_str, - result_status="failed" if item.get("error") else "passed", + row_id=f"row_{row_idx}", input_data=input_data, - output=item.get("output"), - error=item.get("error"), - response_id=item.get("response_id"), - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - updated_at=run.updated_at, + prefilter_item=prefilter_by_row_id.get(f"row_{row_idx}"), + l2_item=l2_by_row_id.get(f"row_{row_idx}"), + has_prefilter=has_prefilter, ) + for row_idx, input_data in enumerate(dataset_rows) + ] + return rows + + # Dataset unavailable — emit whatever results we have, indexed by row_id. + return [ + _build_export_row( + run=run, + assessment=assessment, + dataset_name=dataset_name, + row_id=str(row_id), + input_data=None, + prefilter_item=prefilter_by_row_id.get(str(row_id)), + l2_item=l2_item, + has_prefilter=has_prefilter, ) + for row_id, l2_item in l2_by_row_id.items() + ] - return export_rows + +def _build_export_row( + run: AssessmentRun, + assessment: Assessment, + dataset_name: str | None, + row_id: str, + input_data: dict[str, str] | None, + prefilter_item: dict[str, Any] | None, + l2_item: dict[str, Any] | None, + has_prefilter: bool, +) -> AssessmentExportRow: + prefilter_cols = ( + _extract_prefilter_json_columns(prefilter_item) + if has_prefilter + else {"topic_relevance": None, "duplicate_detection": None} + ) + prefilter_passed = (prefilter_item or {}).get("prefilter_passed", True) + input_tokens, output_tokens, total_tokens = usage_totals( + l2_item.get("usage") if l2_item else None + ) + return AssessmentExportRow( + assessment_id=run.assessment_id, + experiment_name=assessment.experiment_name, + dataset_id=assessment.dataset_id, + dataset_name=dataset_name, + run_id=run.id, + run_name=assessment.experiment_name, + run_status=run.status, + config_id=run.config_id, + 
config_version=run.config_version, + row_id=row_id, + result_status=_row_result_status(prefilter_passed, l2_item, run.status), + input_data=input_data, + topic_relevance=prefilter_cols.get("topic_relevance"), + duplicate_detection=prefilter_cols.get("duplicate_detection"), + output=l2_item.get("output") if l2_item else None, + error=l2_item.get("error") if l2_item else None, + response_id=l2_item.get("response_id") if l2_item else None, + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + updated_at=run.updated_at, + ) def sort_export_rows( diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 38373774c..472a6f67a 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -1,6 +1,5 @@ """Tests for assessment/batch.py provider routing in submit_assessment_batch.""" -import base64 import io from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -22,9 +21,9 @@ _decode_base64_prefix, _guess_image_mime_from_base64, _guess_image_mime_from_url, - detect_item_type, resolve_attachment_values, resolve_image_mime_and_payload, + resolve_item_type, split_attachment_urls, split_data_url, to_direct_attachment_url, @@ -79,7 +78,7 @@ def test_openai_native_routes_to_openai_batch(self) -> None: return_value=[{"custom_id": "row_0"}], ), patch( - "app.utils.get_openai_client", + "app.crud.assessment.batch.get_openai_client", return_value=MagicMock(), ), patch( @@ -141,7 +140,7 @@ def test_config_instruction_is_not_used_without_request_instruction(self) -> Non return_value=[{"custom_id": "row_0"}], ), patch( - "app.utils.get_openai_client", + "app.crud.assessment.batch.get_openai_client", return_value=MagicMock(), ), patch( @@ -196,9 +195,9 @@ def test_google_native_routes_to_google_batch(self) -> None: "app.crud.assessment.batch.build_google_jsonl", return_value=[{"key": "row_0"}], ), - 
patch("app.core.batch.client.GeminiClient") as gemini_cls, + patch("app.crud.assessment.batch.GeminiClient") as gemini_cls, patch( - "app.core.batch.GeminiBatchProvider", + "app.crud.assessment.batch.GeminiBatchProvider", return_value=MagicMock(), ), patch( @@ -427,152 +426,32 @@ def test_build_openai_and_google_jsonl(self) -> None: } -class TestDetectItemType: - """Per-item image/pdf detection for mixed-content attachment columns.""" +class TestResolveItemType: + """Image/pdf routing now trusts the user-declared type (no detection).""" - def test_data_url_pdf(self) -> None: - assert ( - detect_item_type("data:application/pdf;base64,JVBERi0=", "base64", "image") - == "pdf" - ) + def test_declared_image(self) -> None: + assert resolve_item_type("image") == "image" - def test_data_url_image(self) -> None: - assert ( - detect_item_type("data:image/png;base64,AAAA", "base64", "pdf") == "image" - ) + def test_declared_pdf(self) -> None: + assert resolve_item_type("pdf") == "pdf" - def test_base64_magic_pdf(self) -> None: - payload = base64.b64encode(b"%PDF-1.7 body").decode() - assert detect_item_type(payload, "base64", "image") == "pdf" - - def test_base64_magic_png(self) -> None: - payload = base64.b64encode(b"\x89PNG\r\n\x1a\n" + b"0" * 8).decode() - assert detect_item_type(payload, "base64", "pdf") == "image" - - def test_base64_unknown_falls_back(self) -> None: - payload = base64.b64encode(b"not a known magic").decode() - assert detect_item_type(payload, "base64", "pdf") == "pdf" - - def test_mixed_fallback_resolves_to_image(self) -> None: - """'mixed' is never a returned type; inconclusive detection -> image.""" - payload = base64.b64encode(b"not a known magic").decode() - assert detect_item_type(payload, "base64", "mixed") == "image" - - def test_url_extension_pdf_case_insensitive(self) -> None: - assert detect_item_type("https://x.com/a/scan.PDF", "url", "image", {}) == "pdf" - - def test_url_extension_image(self) -> None: - assert 
detect_item_type("https://x.com/a/p.jpg", "url", "pdf", {}) == "image" - - def test_url_no_extension_probes_bytes(self) -> None: - """Extensionless URL (Drive-style) is probed; magic bytes win over fallback.""" - url = "https://drive.google.com/file/d/ABC123/view" - resp = MagicMock() - resp.__enter__ = MagicMock(return_value=resp) - resp.__exit__ = MagicMock(return_value=False) - resp.is_redirect = False - resp.raise_for_status = MagicMock() - resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) - with patch( - "app.services.assessment.utils.attachments.validate_callback_url" - ), patch( - "app.services.assessment.utils.attachments.requests.get", - return_value=resp, - ) as mock_get: - assert detect_item_type(url, "url", "image", {}) == "pdf" - # Drive share URL is probed through the download endpoint. - assert "uc?export=download&id=ABC123" in mock_get.call_args.args[0] - - def test_url_probe_uses_content_type_when_no_magic(self) -> None: - url = "https://example.com/file" - resp = MagicMock() - resp.__enter__ = MagicMock(return_value=resp) - resp.__exit__ = MagicMock(return_value=False) - resp.is_redirect = False - resp.raise_for_status = MagicMock() - resp.iter_content = MagicMock(return_value=iter([b"\x00\x01\x02\x03"])) - resp.headers = {"Content-Type": "application/pdf; charset=binary"} - with patch( - "app.services.assessment.utils.attachments.validate_callback_url" - ), patch( - "app.services.assessment.utils.attachments.requests.get", - return_value=resp, - ): - assert detect_item_type(url, "url", "image", {}) == "pdf" + def test_override_wins(self) -> None: + assert resolve_item_type("image", "pdf") == "pdf" + assert resolve_item_type("pdf", "image") == "image" - def test_url_probe_failure_falls_back(self) -> None: - import requests as _requests + def test_mixed_without_override_defaults_to_image(self) -> None: + assert resolve_item_type("mixed") == "image" - url = "https://example.com/file" - with patch( - 
"app.services.assessment.utils.attachments.validate_callback_url" - ), patch( - "app.services.assessment.utils.attachments.requests.get", - side_effect=_requests.RequestException("boom"), - ): - assert detect_item_type(url, "url", "image", {}) == "image" - - def test_url_probe_follows_validated_redirect(self) -> None: - """A redirect hop is followed and re-validated before the next request.""" - url = "https://drive.google.com/file/d/RID/view" - redirect = MagicMock() - redirect.__enter__ = MagicMock(return_value=redirect) - redirect.__exit__ = MagicMock(return_value=False) - redirect.is_redirect = True - redirect.headers = {"Location": "https://files.example.com/real.pdf"} - final = MagicMock() - final.__enter__ = MagicMock(return_value=final) - final.__exit__ = MagicMock(return_value=False) - final.is_redirect = False - final.raise_for_status = MagicMock() - final.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) - with patch( - "app.services.assessment.utils.attachments.validate_callback_url" - ) as validate, patch( - "app.services.assessment.utils.attachments.requests.get", - side_effect=[redirect, final], - ) as mock_get: - assert detect_item_type(url, "url", "image", {}) == "pdf" - # Both the initial and redirected URLs were validated and fetched. - assert validate.call_count == 2 - assert mock_get.call_count == 2 - - def test_url_probe_blocked_by_ssrf_falls_back(self) -> None: - url = "https://internal.host/file" - with patch( - "app.services.assessment.utils.attachments.validate_callback_url", - side_effect=ValueError("private IP"), - ), patch("app.services.assessment.utils.attachments.requests.get") as mock_get: - # SSRF guard blocks the probe -> falls back to declared type. 
- assert detect_item_type(url, "url", "pdf", {}) == "pdf" - mock_get.assert_not_called() - - def test_cache_skips_second_probe(self) -> None: - url = "https://drive.google.com/file/d/XYZ/view" - cache: dict[str, str] = {} - resp = MagicMock() - resp.__enter__ = MagicMock(return_value=resp) - resp.__exit__ = MagicMock(return_value=False) - resp.is_redirect = False - resp.raise_for_status = MagicMock() - resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) - with patch( - "app.services.assessment.utils.attachments.validate_callback_url" - ), patch( - "app.services.assessment.utils.attachments.requests.get", - return_value=resp, - ) as mock_get: - assert detect_item_type(url, "url", "image", cache) == "pdf" - assert detect_item_type(url, "url", "image", cache) == "pdf" - assert mock_get.call_count == 1 - - def test_mixed_column_resolves_both_types(self) -> None: - """One column, two URLs with extensions -> one image, one pdf object.""" - att = AssessmentAttachment(column="docs", type="image", format="url") - value = "https://x.com/a/photo.jpg, https://x.com/b/report.pdf" - resolved = resolve_attachment_values(value, att, {}) + def test_unknown_declared_defaults_to_image(self) -> None: + assert resolve_item_type("whatever") == "image" + + def test_column_uses_single_declared_type(self) -> None: + """One column, many URLs -> all routed by the declared type.""" + att = AssessmentAttachment(column="docs", type="pdf", format="url") + value = "https://x.com/a/photo.jpg, https://x.com/b/report" + resolved = resolve_attachment_values(value, att) types = [obj["type"] for obj in resolved] - assert types == ["input_image", "input_file"] + assert types == ["input_file", "input_file"] class TestAttachmentMagicAndMime: @@ -589,13 +468,6 @@ def test_image_magic_all_formats(self) -> None: assert _image_mime_from_magic(b"MM\x00*") == "image/tiff" assert _image_mime_from_magic(b"nope") is None - def test_type_from_magic_pdf_and_none(self) -> None: - from 
app.services.assessment.utils.attachments import _type_from_magic - - assert _type_from_magic(b"%PDF-1.7") == "pdf" - assert _type_from_magic(b"\x89PNG\r\n\x1a\n") == "image" - assert _type_from_magic(b"random") is None - def test_guess_image_mime_from_url_variants(self) -> None: from app.services.assessment.utils.attachments import _guess_image_mime_from_url @@ -619,3 +491,75 @@ def test_decode_base64_prefix_empty(self) -> None: from app.services.assessment.utils.attachments import _decode_base64_prefix assert _decode_base64_prefix(" ") is None + + +class TestAttachmentTypeForRow: + def test_mixed_resolves_from_type_column(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Photo": "image", "Report": "pdf"}, + ) + assert attachment_type_for_row(att, {"DOC type": "Photo"}) == "image" + assert attachment_type_for_row(att, {"DOC type": "Report"}) == "pdf" + assert attachment_type_for_row(att, {"DOC type": "Unknown"}) is None + + def test_mixed_resolves_comma_separated_value_lists(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Img-Prototype, Img-Handtext": "image", "Pdf": "pdf"}, + ) + + assert attachment_type_for_row(att, {"DOC type": "Img-Prototype"}) == "image" + assert attachment_type_for_row(att, {"DOC type": "Img-Handtext"}) == "image" + assert attachment_type_for_row(att, {"DOC type": "pdf"}) == "pdf" + + def test_mixed_resolves_row_value_lists_when_same_type(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Img-Prototype, Img-Handtext": "image", "Pdf": 
"pdf"}, + ) + + assert ( + attachment_type_for_row( + att, + {"DOC type": "Img-Prototype, Img-Handtext"}, + ) + == "image" + ) + assert attachment_type_for_row(att, {"DOC type": "Img-Prototype, Pdf"}) is None + + def test_mixed_missing_type_mapping_fields_returns_none(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = SimpleNamespace(column="Docs", type="mixed", format="url") + + assert attachment_type_for_row(att, {"Docs": "x"}) is None + + def test_non_mixed_returns_none(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment(column="Docs", type="image", format="url") + assert attachment_type_for_row(att, {"Docs": "x"}) is None + + def test_override_forces_part_type(self) -> None: + from app.services.assessment.utils.attachments import resolve_attachment_values + + att = AssessmentAttachment(column="Docs", type="mixed", format="url") + url = "https://drive.google.com/file/d/ID/view" + parts = resolve_attachment_values(url, att, type_override="pdf") + assert parts[0]["type"] == "input_file" diff --git a/backend/app/tests/assessment/test_cron.py b/backend/app/tests/assessment/test_cron.py index d9e8527eb..e2dc54fd8 100644 --- a/backend/app/tests/assessment/test_cron.py +++ b/backend/app/tests/assessment/test_cron.py @@ -103,7 +103,7 @@ async def test_no_active_runs_recompute(self) -> None: "app.crud.assessment.cron.recompute_assessment_status", return_value=refreshed, ), patch( - "app.crud.assessment.cron.check_and_process_assessment", new=AsyncMock() + "app.crud.assessment.cron.process_run_batches", new=AsyncMock() ): result = await poll_all_pending_assessment_evaluations(session=session) @@ -115,14 +115,14 @@ async def test_active_run_processed(self) -> None: session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "l2_processing" + run.stage_status = "PROCESSING" 
session.exec.return_value.all.return_value = [assessment] with patch( "app.crud.assessment.cron.get_assessment_runs_for_assessment", return_value=[run], ), patch( - "app.crud.assessment.cron.check_and_process_assessment", + "app.crud.assessment.cron.process_run_batches", new=AsyncMock( return_value={ "action": "processed", @@ -136,55 +136,22 @@ async def test_active_run_processed(self) -> None: assert result["processed"] == 1 @pytest.mark.asyncio - async def test_active_run_failure_and_cleanup_failure(self) -> None: + async def test_transient_poll_exception_does_not_fail_run(self) -> None: + """A transient error while polling leaves the run active for retry.""" session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "l2_processing" + run.stage_status = "PROCESSING" session.exec.return_value.all.return_value = [assessment] with patch( "app.crud.assessment.cron.get_assessment_runs_for_assessment", return_value=[run], ), patch( - "app.crud.assessment.cron.check_and_process_assessment", - new=AsyncMock(side_effect=RuntimeError("boom")), - ), patch( - "app.crud.assessment.cron.update_assessment_run_status", - side_effect=RuntimeError("cleanup-failed"), - ), patch( - "app.crud.assessment.cron.recompute_assessment_status", + "app.crud.assessment.cron.process_run_batches", + new=AsyncMock(side_effect=RuntimeError("nodename nor servname provided")), ): result = await poll_all_pending_assessment_evaluations(session=session) - assert result["failed"] == 1 - - @pytest.mark.asyncio - async def test_active_run_failure_updates_db_with_same_error_message(self) -> None: - session = MagicMock() - assessment = _make_assessment(id=1, status="processing") - run = _make_run(id=11) - run.status = "l2_processing" - session.exec.return_value.all.return_value = [assessment] - - with patch( - "app.crud.assessment.cron.get_assessment_runs_for_assessment", - return_value=[run], - ), patch( - 
"app.crud.assessment.cron.check_and_process_assessment", - new=AsyncMock(side_effect=RuntimeError("gemini quota exceeded")), - ), patch( - "app.crud.assessment.cron.update_assessment_run_status", - ) as update_run, patch( - "app.crud.assessment.cron.recompute_assessment_status", - ): - result = await poll_all_pending_assessment_evaluations(session=session) - - assert result["failed"] == 1 - assert result["details"][0]["error"] == "gemini quota exceeded" - update_run.assert_called_once_with( - session=session, - run=run, - status="failed", - error_message="gemini quota exceeded", - ) + assert result["failed"] == 0 + assert result["still_processing"] == 1 diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index e2f44a21a..e68feb813 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -237,12 +237,15 @@ def test_build_run_stats(self) -> None: prefilter_total_rows=None, prefilter_total_passed=None, prefilter_total_rejected=None, + stage="COMPLETED", + stage_status="COMPLETED", ), ] stats = build_run_stats(runs) assert len(stats) == 1 assert stats[0].run_id == 1 assert stats[0].status == "completed" + assert stats[0].stage == "COMPLETED" def test_derive_aggregate_error(self) -> None: assert derive_aggregate_error(_counts(total=2, completed=2)) is None diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py index 5d363f896..89ff2ddfc 100644 --- a/backend/app/tests/assessment/test_duplicate_detection.py +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -1,132 +1,60 @@ -"""Tests for prefilter duplicate detection.""" - -import json -from unittest.mock import MagicMock +"""Tests for the duplicate-detection batch request builder and result parser.""" from app.services.assessment.prefilter.duplicate_detection import ( - _build_combined, - _parse_verdict, - run_duplicate_detection, + 
build_duplicate_detection_requests, + parse_duplicate_detection_results, ) -def _vague_client(vague: bool, reason: str = "r") -> MagicMock: - client = MagicMock() - resp = MagicMock() - resp.text = json.dumps({"vague": vague, "reason": reason}) - client.models.generate_content.return_value = resp - return client - - -class TestBuildCombined: - def test_joins_non_empty(self) -> None: - out = _build_combined({"Problem": "p", "Solution": "s", "Empty": " "}) - assert "Problem:\np" in out - assert "Solution:\ns" in out - assert "Empty" not in out - - -class TestParseVerdict: - def test_full_fields(self) -> None: - raw = ( - "Verdict: DUPLICATE\n" - "Title: Some Idea\n" - "Source: https://x.com/a\n" - "URL: https://x.com/a\n" - "Matching sentence: a beam alarm\n" - "Reason: same mechanism" - ) - out = _parse_verdict(raw) - assert out["verdict"] == "DUPLICATE" - assert out["match_title"] == "Some Idea" - assert out["source_url"] == "https://x.com/a" - assert out["matching_sentence"] == "a beam alarm" - assert out["reason"] == "same mechanism" - - def test_unique_verdict_only(self) -> None: - out = _parse_verdict("Verdict: UNIQUE\nReason: nothing matches") - assert out["verdict"] == "UNIQUE" - assert out["match_title"] is None - - def test_regex_fallback_when_key_missing(self) -> None: - out = _parse_verdict("The result is clearly OVERLAP here.") - assert out["verdict"] == "OVERLAP" - - def test_no_verdict_stays_empty(self) -> None: - out = _parse_verdict("no decision present") - assert out["verdict"] == "" - - -class TestRunDuplicateDetection: - def test_vague_short_circuits(self) -> None: - client = _vague_client(True, "too vague") - result = run_duplicate_detection( - row_idx=0, - row={"Problem": "x"}, - columns=["Problem"], - gemini_client=client, - model="gemini-2.5-flash", - store_name="store", - ) - assert result["verdict"] == "VAGUE" - assert result["reason"] == "too vague" - # Only the vague check is called; no file-search second call. 
- assert client.models.generate_content.call_count == 1 - - def test_not_vague_runs_file_search(self) -> None: - client = MagicMock() - vague_resp = MagicMock() - vague_resp.text = json.dumps({"vague": False, "reason": ""}) - search_resp = MagicMock() - search_resp.text = "Verdict: UNIQUE\nReason: novel" - client.models.generate_content.side_effect = [vague_resp, search_resp] - - result = run_duplicate_detection( - row_idx=1, - row={"Problem": "p", "Solution": "s"}, - columns=["Problem", "Solution"], - gemini_client=client, - model="gemini-2.5-flash", - store_name="store", - ) - assert result["verdict"] == "UNIQUE" - assert result["reason"] == "novel" - assert result["row_id"] == "row_1" - - def test_file_search_error_returns_error_verdict(self) -> None: - client = MagicMock() - vague_resp = MagicMock() - vague_resp.text = json.dumps({"vague": False, "reason": ""}) - client.models.generate_content.side_effect = [ - vague_resp, - RuntimeError("search boom"), +class TestBuildRequests: + def test_one_request_per_record(self) -> None: + rows = [(0, {"Problem": "p0", "Solution": "s0"}), (1, {"Problem": "p1"})] + lines = build_duplicate_detection_requests(rows, ["Problem", "Solution"]) + # key (gemini) or custom_id (openai) depending on configured provider. 
+ keys = [ln.get("key") or ln.get("custom_id") for ln in lines] + assert keys == ["dup_0", "dup_1"] + + +class TestParseResults: + def test_parses_structured_verdict_per_row(self) -> None: + import json + + outputs = [ + { + "row_id": "dup_0", + "output": json.dumps( + { + "verdict": "UNIQUE", + "match_title": "", + "source_url": "", + "matching_sentence": "", + "reason": "novel", + } + ), + "error": None, + }, + { + "row_id": "dup_1", + "output": json.dumps( + { + "verdict": "DUPLICATE", + "match_title": "T", + "source_url": "http://x", + "matching_sentence": "s", + "reason": "same mechanism", + } + ), + "error": None, + }, ] - - result = run_duplicate_detection( - row_idx=2, - row={"Problem": "p"}, - columns=["Problem"], - gemini_client=client, - model="gemini-2.5-flash", - store_name="store", - ) - assert result["verdict"] == "ERROR" - assert "search boom" in result["reason"] - - def test_vague_check_parse_error_defaults_not_vague(self) -> None: - client = MagicMock() - bad_vague = MagicMock() - bad_vague.text = "not json" - search_resp = MagicMock() - search_resp.text = "Verdict: PARTIAL_MATCH\nTitle: T\nReason: theme" - client.models.generate_content.side_effect = [bad_vague, search_resp] - - result = run_duplicate_detection( - row_idx=3, - row={"Problem": "p"}, - columns=["Problem"], - gemini_client=client, - model="gemini-2.5-flash", - store_name="store", + parsed = parse_duplicate_detection_results(outputs) + assert parsed[0]["verdict"] == "UNIQUE" + assert parsed[0]["source_url"] is None # "" -> None + assert parsed[1]["verdict"] == "DUPLICATE" + assert parsed[1]["source_url"] == "http://x" + + def test_empty_response_records_error(self) -> None: + parsed = parse_duplicate_detection_results( + [{"row_id": "dup_3", "output": None, "error": None}] ) - assert result["verdict"] == "PARTIAL_MATCH" + assert parsed[3]["verdict"] == "ERROR" diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py index 
98eb10683..32a0d8783 100644 --- a/backend/app/tests/assessment/test_export.py +++ b/backend/app/tests/assessment/test_export.py @@ -19,6 +19,12 @@ ) +def _named_dataset() -> MagicMock: + ds = MagicMock() + ds.name = "ds" + return ds + + def _make_row( *, run_id: int = 1, @@ -415,10 +421,11 @@ def test_s3_failure_falls_back_to_provider_download(self) -> None: "app.services.assessment.utils.export.get_cloud_storage", side_effect=Exception("S3 down"), ), patch( - "app.crud.assessment.processing._get_batch_provider", + "app.services.assessment.utils.export._get_batch_provider", return_value=MagicMock(), ), patch( - "app.core.batch.download_batch_results", return_value=raw + "app.services.assessment.utils.export.download_batch_results", + return_value=raw, ): result = _load_parsed_results_for_run( session=session, run=run, batch_job=batch_job @@ -523,64 +530,48 @@ def _make_assessment(self) -> MagicMock: assessment.dataset_id = 2 return assessment - def test_no_batch_job_id_returns_empty(self) -> None: - session = MagicMock() - run = self._make_run() - run.batch_job_id = None - result = load_export_rows_for_run(session=session, run=run) - assert result == [] - - def test_batch_job_not_found_returns_empty(self) -> None: - session = MagicMock() - run = self._make_run() - with patch( - "app.services.assessment.utils.export.get_batch_job", return_value=None - ): - result = load_export_rows_for_run( - session=session, run=run, assessment=self._make_assessment() - ) - assert result == [] + def _patches(self, *, l2, prefilter=None, dataset_rows=None): + return [ + patch( + "app.services.assessment.utils.export._load_l2_results_for_run", + return_value=l2, + ), + patch( + "app.services.assessment.utils.export._load_prefilter_results", + return_value=prefilter or {}, + ), + patch( + "app.services.assessment.utils.export._load_dataset_rows_for_run", + return_value=dataset_rows if dataset_rows is not None else [], + ), + ] - def test_no_parsed_results_returns_empty(self) -> None: 
+ def test_no_results_no_dataset_returns_empty(self) -> None: session = MagicMock() + session.get.return_value = _named_dataset() run = self._make_run() - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=None, - ): + p1, p2, p3 = self._patches(l2={}) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) assert result == [] - def test_parsed_results_build_export_rows(self) -> None: + def test_merged_results_build_export_rows(self) -> None: session = MagicMock() - dataset = MagicMock() - dataset.name = "ds" - session.get.return_value = dataset + session.get.return_value = _named_dataset() run = self._make_run() - parsed = [ - { + l2 = { + "row_0": { "row_id": "row_0", "output": '{"score": 5}', "error": None, "usage": None, "response_id": "r1", } - ] - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=parsed, - ), patch( - "app.services.assessment.utils.export._load_dataset_rows_for_run", - return_value=[], - ): + } + p1, p2, p3 = self._patches(l2=l2) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) @@ -590,61 +581,45 @@ def test_parsed_results_build_export_rows(self) -> None: def test_error_result_sets_failed_status(self) -> None: session = MagicMock() - dataset = MagicMock() - dataset.name = "ds" - session.get.return_value = dataset + session.get.return_value = _named_dataset() run = self._make_run() - parsed = [ - { + l2 = { + "row_0": { "row_id": "row_0", "output": None, "error": "timeout", "usage": None, "response_id": None, } - ] - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - 
"app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=parsed, - ), patch( - "app.services.assessment.utils.export._load_dataset_rows_for_run", - return_value=[], - ): + } + p1, p2, p3 = self._patches(l2=l2) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) assert result[0].result_status == "failed" - def test_input_data_correlated_via_row_id(self) -> None: + def test_dataset_rows_include_pending_and_correlate_input(self) -> None: session = MagicMock() - dataset = MagicMock() - dataset.name = "ds" - session.get.return_value = dataset + session.get.return_value = _named_dataset() run = self._make_run() - parsed = [ - { + run.status = "l2_processing" + l2 = { + "row_1": { "row_id": "row_1", "output": "x", "error": None, "usage": None, "response_id": None, } - ] + } dataset_rows = [{"q": "first"}, {"q": "second"}] - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=parsed, - ), patch( - "app.services.assessment.utils.export._load_dataset_rows_for_run", - return_value=dataset_rows, - ): + p1, p2, p3 = self._patches(l2=l2, dataset_rows=dataset_rows) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) - assert result[0].input_data == {"q": "second"} + assert len(result) == 2 + assert result[0].result_status == "processing" # row_0 not done yet + assert result[1].input_data == {"q": "second"} + assert result[1].result_status == "passed" diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py index d74841650..c010d997d 100644 --- a/backend/app/tests/assessment/test_pipeline.py +++ b/backend/app/tests/assessment/test_pipeline.py @@ -1,151 +1,50 @@ -"""Tests for the prefilter pipeline orchestrator.""" - -from contextlib 
import ExitStack -from unittest.mock import MagicMock, patch - -from app.services.assessment.prefilter.pipeline import run_prefilter_pipeline - - -def _run() -> MagicMock: - run = MagicMock() - run.id = 99 - return run - - -def _tr(verdict: bool, decision: str = "ACCEPT") -> dict: - return { - "row_id": "row", - "verdict": verdict, - "decision": decision, - "column_relevance": {"Problem": verdict}, - "reasoning": "r", +"""Tests for prefilter settings + pipeline stage ordering.""" + +from app.models.assessment import Stage +from app.services.assessment.prefilter import resolve_prefilter_settings +from app.services.assessment.stages import ( + build_pipeline, + next_stage, + ordered_stages, +) + +_FULL_INPUT = { + "prefilter_config": { + "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, + "duplicate_detection": {"columns": ["Problem"]}, } +} -def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): - """Patch the pipeline's external deps; return the TR mock.""" - client = MagicMock() - stack.enter_context( - patch( - "app.services.assessment.prefilter.pipeline.GeminiClient.from_credentials", - return_value=MagicMock(client=client), - ) - ) - stack.enter_context( - patch( - "app.services.assessment.prefilter.pipeline.get_cloud_storage", - return_value=MagicMock(), - ) - ) - stack.enter_context( - patch( - "app.services.assessment.prefilter.pipeline.upload_jsonl_to_object_store", - return_value="s3://prefilter.json", - ) - ) - stack.enter_context( - patch("app.crud.assessment.core.update_assessment_run_prefilter_stats") - ) - tr_mock = stack.enter_context( - patch("app.services.assessment.prefilter.pipeline.run_topic_relevance") - ) - if tr_side is not None: - tr_mock.side_effect = tr_side - dup_mock = stack.enter_context( - patch("app.services.assessment.prefilter.pipeline.run_duplicate_detection") - ) - if dup_return is not None: - dup_mock.return_value = dup_return - return tr_mock, dup_mock - +class TestResolvePrefilterSettings: + def 
test_both_enabled(self) -> None: + cfg = resolve_prefilter_settings(_FULL_INPUT["prefilter_config"]) + assert cfg["tr_enabled"] is True + assert cfg["dup_enabled"] is True -class TestRunL1Pipeline: - def test_no_filters_configured_passthrough(self) -> None: - rows = [{"Problem": "a"}, {"Problem": "b"}] - passed, indices, results = run_prefilter_pipeline( - run=_run(), - rows=rows, - prefilter_config={}, - session=MagicMock(), - organization_id=1, - project_id=1, - ) - assert passed == rows - assert indices == [0, 1] - assert results == [] + def test_disabled_when_empty(self) -> None: + cfg = resolve_prefilter_settings({}) + assert cfg["tr_enabled"] is False + assert cfg["dup_enabled"] is False - def test_topic_relevance_filters_rejected_rows(self) -> None: - rows = [{"Problem": "keep"}, {"Problem": "drop"}, {"Problem": "keep2"}] - # idx 1 rejected. - side = [_tr(True), _tr(False, "REJECT"), _tr(True)] - with ExitStack() as stack: - _patches(stack, tr_side=side) - passed, indices, results = run_prefilter_pipeline( - run=_run(), - rows=rows, - prefilter_config={ - "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"} - }, - session=MagicMock(), - organization_id=1, - project_id=1, - ) - assert indices == [0, 2] - assert [r["Problem"] for r in passed] == ["keep", "keep2"] - assert len(results) == 3 - assert results[1]["prefilter_passed"] is False - def test_duplicate_detection_runs_on_passed_rows(self) -> None: - rows = [{"Problem": "a", "Solution": "b"}] - dup = { - "row_id": "row_0", - "verdict": "UNIQUE", - "match_title": None, - "source_url": None, - "matching_sentence": None, - "reason": "novel", - } - with ExitStack() as stack: - tr_mock, dup_mock = _patches(stack, tr_side=[_tr(True)], dup_return=dup) - _, _, results = run_prefilter_pipeline( - run=_run(), - rows=rows, - prefilter_config={ - "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, - "duplicate_detection": {"columns": ["Problem", "Solution"]}, - }, - session=MagicMock(), - 
organization_id=1, - project_id=1, - ) - dup_mock.assert_called_once() - assert results[0]["duplicate_detection"]["verdict"] == "UNIQUE" +class TestPipeline: + def test_full_pipeline_order(self) -> None: + pipeline = build_pipeline(_FULL_INPUT) + assert ordered_stages(pipeline) == [ + Stage.PRE_FILTER_TOPIC_RELEVANCE, + Stage.PRE_FILTER_DUPLICATE_DETECTION, + Stage.L2_ASSESSMENT, + ] - def test_attachment_columns_filtered_to_selection(self) -> None: - from app.models.assessment import AssessmentAttachment + def test_no_prefilter_is_l2_only(self) -> None: + pipeline = build_pipeline({}) + assert ordered_stages(pipeline) == [Stage.L2_ASSESSMENT] + assert next_stage(pipeline) == Stage.L2_ASSESSMENT - rows = [{"Problem": "a", "Docs": "x", "Other": "y"}] - atts = [ - AssessmentAttachment(column="Docs", type="image", format="url"), - AssessmentAttachment(column="Other", type="image", format="url"), - ] - with ExitStack() as stack: - tr_mock, _ = _patches(stack, tr_side=[_tr(True)]) - run_prefilter_pipeline( - run=_run(), - rows=rows, - prefilter_config={ - "topic_relevance": { - "columns": ["Problem"], - "prompt": "rubric", - "attachment_columns": ["Docs"], - } - }, - session=MagicMock(), - organization_id=1, - project_id=1, - attachments=atts, - ) - # run_topic_relevance is called with only the selected attachment ("Docs"). 
- passed_atts = tr_mock.call_args.args[6] - assert [a.column for a in passed_atts] == ["Docs"] + def test_next_stage(self) -> None: + pipeline = build_pipeline(_FULL_INPUT) + assert next_stage(pipeline, Stage.PRE_FILTER_TOPIC_RELEVANCE) == ( + Stage.PRE_FILTER_DUPLICATE_DETECTION + ) + assert next_stage(pipeline, Stage.L2_ASSESSMENT) is None diff --git a/backend/app/tests/assessment/test_prefilter_batching.py b/backend/app/tests/assessment/test_prefilter_batching.py new file mode 100644 index 000000000..73d103a3c --- /dev/null +++ b/backend/app/tests/assessment/test_prefilter_batching.py @@ -0,0 +1,115 @@ +"""Tests for the single pipeline orchestrator (state-machine submit step).""" + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from app.models.assessment import Stage, StageStatus +from app.services.assessment import tasks + + +@contextmanager +def _session_cm(session): + yield session + + +def _run(**kw): + base = { + "id": 5, + "assessment_id": 9, + "input": { + "prefilter_config": {"topic_relevance": {"columns": ["a"], "prompt": "p"}} + }, + "config_id": "c", + "config_version": 1, + "pipeline": None, + "stage": None, + "stage_status": None, + "status": "pending", + "stage_batches": None, + "total_items": 0, + } + base.update(kw) + return SimpleNamespace(**base) + + +class TestOrchestrate: + def test_inits_pipeline_and_submits_first_stage(self) -> None: + run = _run() + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "flag_modified"), patch.object( + tasks, "_submit_stage" + ) as submit: + tasks._orchestrate(5, 1, 1) + assert run.stage == Stage.PRE_FILTER_TOPIC_RELEVANCE + assert run.stage_status == StageStatus.PENDING + submit.assert_called_once() + + def test_skips_when_not_pending(self) -> None: + run = _run( + pipeline={"stages": [{"stage": Stage.L2_ASSESSMENT, "order": 1}]}, + 
stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PROCESSING, + ) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "_submit_stage") as submit: + tasks._orchestrate(5, 1, 1) + submit.assert_not_called() + + def test_terminal_stage_returns(self) -> None: + run = _run(stage=Stage.COMPLETED) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "_submit_stage") as submit: + tasks._orchestrate(5, 1, 1) + submit.assert_not_called() + + +class TestSubmitCurrentStage: + def _ctx(self, accepted): + return [ + patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), + patch.object(tasks, "_load_dataset_rows", return_value=[{"a": "1"}] * 3), + patch.object(tasks, "_accepted_indices", return_value=accepted), + patch.object(tasks, "recompute_assessment_status"), + ] + + def test_submits_prefilter_batch(self) -> None: + run = _run( + stage=Stage.PRE_FILTER_TOPIC_RELEVANCE, + stage_status=StageStatus.PENDING, + stage_batches={}, + ) + session = MagicMock() + batch_job = SimpleNamespace(id=7, total_items=3) + p = self._ctx([0, 1, 2]) + with p[0], p[1], p[2], p[3], patch.object(tasks, "flag_modified"), patch.object( + tasks, "build_prefilter_requests", return_value=[{"key": "tr_0"}] + ), patch.object(tasks, "submit_prefilter_batch", return_value=batch_job): + tasks._submit_stage(session, run, 1, 1) + assert run.stage_batches[Stage.PRE_FILTER_TOPIC_RELEVANCE] == 7 + assert run.stage_status == StageStatus.PROCESSING + + def test_zero_accepted_advances(self) -> None: + run = _run( + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PENDING, + stage_batches={}, + ) + session = MagicMock() + p = self._ctx([]) + with p[0], p[1], p[2], p[3], patch.object(tasks, "_persist_advance") as advance: + 
tasks._submit_stage(session, run, 1, 1) + advance.assert_called_once() diff --git a/backend/app/tests/assessment/test_processing.py b/backend/app/tests/assessment/test_processing.py index 958ab3019..da9c930b7 100644 --- a/backend/app/tests/assessment/test_processing.py +++ b/backend/app/tests/assessment/test_processing.py @@ -1,17 +1,18 @@ """Tests for assessment/processing.py pure functions.""" import json -from unittest.mock import AsyncMock, MagicMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, patch import pytest from app.crud.assessment.processing import ( _get_batch_provider, _sanitize_json_output, - check_and_process_assessment, parse_assessment_output, - poll_all_pending_assessments, + process_run_batches, ) +from app.models.assessment import Stage, StageStatus class TestSanitizeJsonOutput: @@ -213,7 +214,7 @@ def test_google_native_provider_accepted(self) -> None: class TestGetBatchProvider: def test_unsupported_provider_raises(self) -> None: session = MagicMock() - with pytest.raises(ValueError, match="Unsupported provider"): + with pytest.raises(ValueError, match="Unsupported batch provider"): _get_batch_provider( session=session, provider_name="anthropic", @@ -225,8 +226,8 @@ def test_openai_provider_returned(self) -> None: session = MagicMock() mock_client = MagicMock() with patch( - "app.crud.assessment.processing.get_openai_client", return_value=mock_client - ), patch("app.crud.assessment.processing.OpenAIBatchProvider") as mock_cls: + "app.services.assessment.stages.get_openai_client", return_value=mock_client + ), patch("app.services.assessment.stages.OpenAIBatchProvider") as mock_cls: _get_batch_provider( session=session, provider_name="openai", @@ -238,8 +239,8 @@ def test_openai_provider_returned(self) -> None: def test_google_provider_returned(self) -> None: session = MagicMock() mock_gemini = MagicMock() - with patch("app.crud.assessment.processing.GeminiClient") as mock_cls, patch( - 
"app.crud.assessment.processing.GeminiBatchProvider" + with patch("app.services.assessment.stages.GeminiClient") as mock_cls, patch( + "app.services.assessment.stages.GeminiBatchProvider" ) as mock_batch_cls: mock_cls.from_credentials.return_value = mock_gemini _get_batch_provider( @@ -251,186 +252,159 @@ def test_google_provider_returned(self) -> None: mock_batch_cls.assert_called_once_with(client=mock_gemini.client) -class TestPollAllPendingAssessments: - @pytest.mark.asyncio - async def test_delegates_to_cron(self) -> None: - session = MagicMock() - expected = {"processed": 2, "failed": 0} - with patch( - "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", - new=AsyncMock(return_value=expected), - ): - result = await poll_all_pending_assessments(session=session) - assert result == expected - - -class TestCheckAndProcessAssessment: - def _make_run(self) -> MagicMock: - run = MagicMock() - run.id = 1 - run.batch_job_id = 99 - run.status = "processing" - run.assessment_id = 10 - run.organization_id = 1 - run.project_id = 1 - run.run_name = "exp" - return run +class TestProcessRunBatches: + def _parent(self): + return SimpleNamespace(organization_id=1, project_id=1, experiment_name="exp") + + def _run(self): + return SimpleNamespace( + id=1, + assessment_id=10, + status="processing", + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PROCESSING, + stage_batches={Stage.L2_ASSESSMENT: 5}, + ) @pytest.mark.asyncio - async def test_completed_with_no_output_file_and_failed_counts(self) -> None: + async def test_completes_stage_and_finalizes(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "completed" - batch_job.provider_output_file_id = None - batch_job.id = 99 + session.get.return_value = self._parent() + run = self._run() with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + 
"app.crud.assessment.processing.get_batch_job", return_value=MagicMock() ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", - return_value={ - "request_counts": {"failed": 3, "completed": 0, "total": 3}, - "error_file_id": "err-1", - }, - ), patch( - "app.crud.assessment.processing.update_assessment_run_status" + "app.crud.assessment.processing._poll_stage_outcome", + return_value="completed", ), patch( + "app.crud.assessment.processing.advance_or_finalize", return_value=None + ) as advance, patch( "app.crud.assessment.processing.recompute_assessment_status" ): - result = await check_and_process_assessment(run=run, session=session) + result = await process_run_batches(run=run, session=session) - assert result["action"] == "failed" - assert result["current_status"] == "failed" + advance.assert_called_once() + assert result["action"] == "processed" + assert run.stage_status == StageStatus.COMPLETED @pytest.mark.asyncio - async def test_completed_with_no_output_file_not_ready(self) -> None: + async def test_advances_and_dispatches_next_stage(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "completed" - batch_job.provider_output_file_id = None - batch_job.id = 99 + session.get.return_value = self._parent() + run = self._run() + run.stage = Stage.PRE_FILTER_TOPIC_RELEVANCE + run.stage_batches = {Stage.PRE_FILTER_TOPIC_RELEVANCE: 5} with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", return_value=MagicMock() ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", - return_value={"request_counts": {"failed": 0, "completed": 1, "total": 1}}, - ): - result = await 
check_and_process_assessment(run=run, session=session) + "app.crud.assessment.processing._poll_stage_outcome", + return_value="completed", + ), patch( + "app.crud.assessment.processing._record_gate_stats" + ) as gate_stats, patch( + "app.crud.assessment.processing.advance_or_finalize", + return_value=Stage.L2_ASSESSMENT, + ), patch( + "app.crud.assessment.processing.recompute_assessment_status" + ), patch( + "app.crud.assessment.processing.run_assessment_pipeline" + ) as dispatch: + result = await process_run_batches(run=run, session=session) - assert result["action"] == "no_change" + gate_stats.assert_called_once() # TR is a gate stage + dispatch.delay.assert_called_once() + assert result["action"] == "processed" @pytest.mark.asyncio - async def test_completed_with_output_file_processes_results(self) -> None: + async def test_no_change_while_in_progress(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "completed" - batch_job.provider_output_file_id = "file-1" - batch_job.id = 99 + session.get.return_value = self._parent() + run = self._run() with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", return_value=MagicMock() ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", - return_value={}, - ), patch( - "app.crud.assessment.processing.download_batch_results", - return_value=[{"custom_id": "row_0"}], - ), patch( - "app.crud.assessment.processing.upload_batch_results_to_object_store", - return_value="s3://results", - ), patch( - "app.crud.assessment.processing.parse_assessment_output", - return_value=[{"row_id": "row_0", "error": None}], - ), patch( - "app.crud.assessment.processing.update_assessment_run_status" - ), patch( - "app.crud.assessment.processing.recompute_assessment_status" + 
"app.crud.assessment.processing._poll_stage_outcome", + return_value="no_change", ): - result = await check_and_process_assessment(run=run, session=session) + result = await process_run_batches(run=run, session=session) - assert result["action"] == "processed" + assert result["action"] == "no_change" @pytest.mark.asyncio - async def test_terminal_provider_status_marks_failed(self) -> None: + async def test_failed_stage_fails_run(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "failed" - batch_job.error_message = "provider failed" + session.get.return_value = self._parent() + run = self._run() with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", + return_value=MagicMock(error_message="boom"), ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", return_value={} + "app.crud.assessment.processing._poll_stage_outcome", return_value="failed" ), patch( "app.crud.assessment.processing.update_assessment_run_status" ), patch( "app.crud.assessment.processing.recompute_assessment_status" ): - result = await check_and_process_assessment(run=run, session=session) + result = await process_run_batches(run=run, session=session) assert result["action"] == "failed" - assert result["provider_status"] == "failed" + # Failed stage preserved (so a resume knows where to restart); only status flips. 
+ assert run.stage == Stage.L2_ASSESSMENT + assert run.stage_status == StageStatus.FAILED - @pytest.mark.asyncio - async def test_still_processing_returns_no_change(self) -> None: - session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "in_progress" + +class TestPollStageOutcome: + def _job(self, **kw): + base = dict(provider_status="completed", provider_output_file_id=None) + base.update(kw) + return SimpleNamespace(**base) + + def test_all_failed_no_output_is_failed(self) -> None: + from app.crud.assessment.processing import _poll_stage_outcome with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job - ), patch( - "app.crud.assessment.processing._get_batch_provider", - return_value=MagicMock(), - ), patch( - "app.crud.assessment.processing.poll_batch_status", return_value={} + "app.crud.assessment.processing.poll_batch_status", + return_value={ + "request_counts": {"completed": 0, "failed": 3}, + "error_file_id": "err", + }, ): - result = await check_and_process_assessment(run=run, session=session) + outcome = _poll_stage_outcome(MagicMock(), MagicMock(), self._job()) + assert outcome == "failed" - assert result["action"] == "no_change" - - @pytest.mark.asyncio - async def test_exception_path_marks_failed(self) -> None: - session = MagicMock() - run = self._make_run() - run.batch_job_id = None + def test_no_output_not_ready_is_no_change(self) -> None: + from app.crud.assessment.processing import _poll_stage_outcome with patch( - "app.crud.assessment.processing.update_assessment_run_status" - ) as update_run, patch( - "app.crud.assessment.processing.recompute_assessment_status" + "app.crud.assessment.processing.poll_batch_status", + return_value={"request_counts": {"completed": 0, "failed": 0}}, ): - result = await check_and_process_assessment(run=run, session=session) + outcome = _poll_stage_outcome(MagicMock(), MagicMock(), self._job()) + assert outcome 
== "no_change" - assert result["action"] == "failed" - assert result["provider_status"] == "unknown" - assert result["error"] == "Assessment run 1 has no batch_job_id" - update_run.assert_called_once_with( - session=session, - run=run, - status="failed", - error_message="Assessment run 1 has no batch_job_id", - ) + def test_output_ready_is_completed(self) -> None: + from app.crud.assessment.processing import _poll_stage_outcome + + with patch( + "app.crud.assessment.processing.poll_batch_status", return_value={} + ), patch("app.crud.assessment.processing.process_completed_batch"): + outcome = _poll_stage_outcome( + MagicMock(), MagicMock(), self._job(provider_output_file_id="file_1") + ) + assert outcome == "completed" diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py index b56ec72b7..e22c50e90 100644 --- a/backend/app/tests/assessment/test_service.py +++ b/backend/app/tests/assessment/test_service.py @@ -7,10 +7,16 @@ import pytest from fastapi import HTTPException -from app.models.assessment import AssessmentConfigRef, AssessmentCreate +from app.models.assessment import ( + AssessmentConfigRef, + AssessmentCreate, + Stage, + StageStatus, +) from app.models.config.config import ConfigTag from app.services.assessment.service import ( _build_retry_request, + resume_assessment_run, retry_assessment, retry_assessment_run, start_assessment, @@ -160,7 +166,7 @@ def test_google_provider_is_supported(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -204,7 +210,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ) as create_run, - 
patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -288,7 +294,7 @@ def test_dispatches_one_celery_task_per_config(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -378,3 +384,61 @@ def test_retry_assessment_wrappers(self) -> None: ): resp2 = retry_assessment_run(session, run, 1, 1) assert resp2.assessment_id == 1 + + +class TestResumeAssessmentRun: + def _failed_run(self, stage: str) -> MagicMock: + run = MagicMock() + run.id = 11 + run.assessment_id = 21 + run.config_id = UUID("00000000-0000-0000-0000-000000000001") + run.config_version = 1 + run.status = "failed" + run.stage = stage + run.stage_status = StageStatus.FAILED + run.pipeline = { + "stages": [ + {"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "order": 1}, + {"stage": Stage.PRE_FILTER_DUPLICATE_DETECTION, "order": 2}, + {"stage": Stage.L2_ASSESSMENT, "order": 3}, + ] + } + run.assessment = SimpleNamespace(id=21, experiment_name="exp", dataset_id=7) + return run + + def test_rejects_non_failed_run(self) -> None: + run = self._failed_run(Stage.L2_ASSESSMENT) + run.stage_status = StageStatus.PROCESSING + with pytest.raises(HTTPException) as exc: + resume_assessment_run(MagicMock(), run, 1, 1) + assert exc.value.status_code == 400 + + def test_rejects_stage_not_in_pipeline(self) -> None: + run = self._failed_run(Stage.FAILED) + with pytest.raises(HTTPException) as exc: + resume_assessment_run(MagicMock(), run, 1, 1) + assert exc.value.status_code == 400 + + def 
test_resumes_in_place_from_failed_stage(self) -> None: + run = self._failed_run(Stage.L2_ASSESSMENT) + session = MagicMock() + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=_make_dataset(), + ), + patch("app.services.assessment.service.recompute_assessment_status"), + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, + ): + resp = resume_assessment_run(session, run, 1, 1) + + # Same run, reset to PENDING at the same (failed) stage, re-dispatched. + assert run.stage == Stage.L2_ASSESSMENT + assert run.stage_status == StageStatus.PENDING + assert run.status == "processing" + assert run.error_message is None + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 + assert resp.assessment_id == 21 + assert resp.num_configs == 1 diff --git a/backend/app/tests/assessment/test_tasks_failure_guard.py b/backend/app/tests/assessment/test_tasks_failure_guard.py new file mode 100644 index 000000000..9e48486b3 --- /dev/null +++ b/backend/app/tests/assessment/test_tasks_failure_guard.py @@ -0,0 +1,82 @@ +"""Tests for the pipeline orchestrator failure guard (no dangling runs).""" + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from celery.exceptions import SoftTimeLimitExceeded + +from app.models.assessment import Stage +from app.services.assessment import tasks + + +@contextmanager +def _session_cm(session): + yield session + + +def _patch_session(run): + session = MagicMock() + session.get.return_value = run + cm = patch.object(tasks, "Session", return_value=_session_cm(session)) + return cm, session + + +class TestMarkRunFailed: + def test_marks_non_terminal_run_failed(self) -> None: + run = SimpleNamespace( + stage=Stage.PRE_FILTER_TOPIC_RELEVANCE, + stage_status="PENDING", + assessment_id=7, + ) + cm, session = _patch_session(run) + with cm, patch.object( + tasks, 
"update_assessment_run_status" + ) as upd, patch.object(tasks, "recompute_assessment_status") as recompute: + tasks._mark_run_failed(11, "boom") + upd.assert_called_once() + assert upd.call_args.kwargs["status"] == "failed" + # Failed stage preserved for resume; only stage_status flips to FAILED. + assert run.stage == Stage.PRE_FILTER_TOPIC_RELEVANCE + assert run.stage_status == "FAILED" + recompute.assert_called_once_with(session=session, assessment_id=7) + + def test_skips_terminal_run(self) -> None: + run = SimpleNamespace(stage=Stage.COMPLETED, assessment_id=7) + cm, _ = _patch_session(run) + with cm, patch.object(tasks, "update_assessment_run_status") as upd: + tasks._mark_run_failed(11, "boom") + upd.assert_not_called() + + def test_missing_run_noop(self) -> None: + cm, _ = _patch_session(None) + with cm, patch.object(tasks, "update_assessment_run_status") as upd: + tasks._mark_run_failed(11, "boom") + upd.assert_not_called() + + +class TestExecutePipelineGuard: + def test_soft_timeout_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=SoftTimeLimitExceeded() + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(SoftTimeLimitExceeded): + tasks.execute_assessment_pipeline(11, 1, 1) + mark.assert_called_once() + assert mark.call_args.args[0] == 11 + + def test_unexpected_exception_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=RuntimeError("kaboom") + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(RuntimeError): + tasks.execute_assessment_pipeline(11, 1, 1) + mark.assert_called_once_with(11, "Assessment run failed unexpectedly.") + + def test_success_does_not_mark_failed(self) -> None: + with patch.object(tasks, "_orchestrate", return_value=None), patch.object( + tasks, "_mark_run_failed" + ) as mark: + tasks.execute_assessment_pipeline(11, 1, 1) + mark.assert_not_called() diff --git 
a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py index 064d4476b..fa623c9c4 100644 --- a/backend/app/tests/assessment/test_topic_relevance.py +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -1,123 +1,83 @@ -"""Tests for prefilter topic relevance attachment handling.""" +"""Tests for the topic-relevance per-record request builder and result parser.""" import json -from unittest.mock import MagicMock +from unittest.mock import patch +from app.core.config import settings from app.models.assessment import AssessmentAttachment -from app.services.assessment.prefilter.topic_relevance import run_topic_relevance - - -def _client_returning(decision: str) -> MagicMock: - client = MagicMock() - response = MagicMock() - response.text = json.dumps( - {"decision": decision, "Problem": True, "reasoning": "ok"} - ) - client.models.generate_content.return_value = response - return client - - -class TestTopicRelevanceAttachments: - def test_attachments_added_to_contents(self) -> None: - client = _client_returning("ACCEPT") - att = AssessmentAttachment(column="Documents", type="image", format="url") - row = {"Problem": "p", "Documents": "https://x.com/a/photo.jpg"} - - result = run_topic_relevance( - row_idx=0, - row=row, - columns=["Problem"], - user_prompt="rubric", - gemini_client=client, - model="gemini-2.5-flash", - attachments=[att], - type_cache={}, - ) - - assert result["verdict"] is True - contents = client.models.generate_content.call_args.kwargs["contents"] - parts = contents[0]["parts"] - # First part is the text JSON, then a label, then the attachment file part. 
- assert parts[0]["text"] - file_parts = [p for p in parts if "fileData" in p] - assert len(file_parts) == 1 - assert file_parts[0]["fileData"]["fileUri"] == "https://x.com/a/photo.jpg" - - def test_document_relevance_in_schema_and_result(self) -> None: - """Selected doc column gets its own relevance boolean in column_relevance.""" - client = MagicMock() - response = MagicMock() - response.text = json.dumps( +from app.services.assessment.prefilter.topic_relevance import ( + build_topic_relevance_requests, + parse_topic_relevance_results, +) + + +def _gemini(): + return patch.object(settings, "ASSESSMENT_PREFILTER_PROVIDER", "google") + + +class TestBuildRequests: + def test_one_request_per_row_with_per_column_schema(self) -> None: + rows = [(0, {"Problem": "p0"}), (1, {"Problem": "p1"})] + with _gemini(): + lines = build_topic_relevance_requests(rows, ["Problem"], "rubric") + assert [ln["key"] for ln in lines] == ["tr_0", "tr_1"] + schema = lines[0]["request"]["generationConfig"]["responseSchema"] + # per-column boolean + decision/reasoning + assert schema["properties"]["Problem"]["type"] == "boolean" + assert set(schema["required"]) == {"decision", "reasoning", "Problem"} + assert "p0" in lines[0]["request"]["contents"][0]["parts"][0]["text"] + + def test_attachment_column_adds_part_and_schema_field(self) -> None: + rows = [ + (0, {"Problem": "p0", "Docs": "https://drive.google.com/file/d/A/view"}) + ] + atts = [AssessmentAttachment(column="Docs", type="image", format="url")] + with _gemini(): + lines = build_topic_relevance_requests(rows, ["Problem"], "rubric", atts) + schema = lines[0]["request"]["generationConfig"]["responseSchema"] + assert "Docs" in schema["properties"] # attachment column gets a verdict + parts = lines[0]["request"]["contents"][0]["parts"] + assert len(parts) >= 2 # text + at least one attachment part + + def test_empty_attachments_is_text_only(self) -> None: + with _gemini(): + lines = build_topic_relevance_requests( + [(0, {"Problem": 
"p"})], ["Problem"], "r" + ) + assert len(lines[0]["request"]["contents"][0]["parts"]) == 1 + + +class TestParseResults: + def test_maps_decision_and_per_column_relevance(self) -> None: + outputs = [ { - "decision": "ACCEPT", - "Problem": True, - "Documents": True, - "reasoning": "ok", - } - ) - client.models.generate_content.return_value = response - att = AssessmentAttachment(column="Documents", type="image", format="url") - row = {"Problem": "p", "Documents": "https://x.com/a/photo.jpg"} - - result = run_topic_relevance( - row_idx=3, - row=row, - columns=["Problem"], - user_prompt="rubric", - gemini_client=client, - model="gemini-2.5-flash", - attachments=[att], - type_cache={}, - ) - - # Document column carried into the per-column relevance map -> exports - # as topic_relevance_Documents. - assert "Documents" in result["column_relevance"] - assert result["column_relevance"]["Documents"] is True - schema = client.models.generate_content.call_args.kwargs[ - "config" - ].response_schema - assert "Documents" in schema["properties"] - - def test_no_attachments_text_only(self) -> None: - client = _client_returning("REJECT") - row = {"Problem": "p"} - - result = run_topic_relevance( - row_idx=1, - row=row, - columns=["Problem"], - user_prompt="rubric", - gemini_client=client, - model="gemini-2.5-flash", - ) - - assert result["verdict"] is False - contents = client.models.generate_content.call_args.kwargs["contents"] - parts = contents[0]["parts"] - assert len(parts) == 1 - assert parts[0]["text"] - - def test_mixed_column_pdf_item_detected(self) -> None: - client = _client_returning("ACCEPT") - att = AssessmentAttachment(column="Documents", type="mixed", format="url") - row = {"Problem": "p", "Documents": "https://x.com/a/report.pdf"} - - run_topic_relevance( - row_idx=2, - row=row, - columns=["Problem"], - user_prompt="rubric", - gemini_client=client, - model="gemini-2.5-flash", - attachments=[att], - type_cache={}, - ) - - parts = 
client.models.generate_content.call_args.kwargs["contents"][0]["parts"] - pdf_parts = [ - p - for p in parts - if p.get("fileData", {}).get("mimeType") == "application/pdf" + "row_id": "tr_0", + "output": json.dumps( + { + "decision": "ACCEPT", + "reasoning": "ok", + "Problem": True, + "Docs": False, + } + ), + "error": None, + }, + { + "row_id": "tr_1", + "output": json.dumps( + {"decision": "REJECT", "reasoning": "no", "Problem": False} + ), + "error": None, + }, ] - assert len(pdf_parts) == 1 + parsed = parse_topic_relevance_results(outputs) + assert parsed[0]["verdict"] is True + assert parsed[0]["column_relevance"] == {"Problem": True, "Docs": False} + assert parsed[1]["verdict"] is False + assert parsed[1]["column_relevance"] == {"Problem": False} + + def test_bad_output_skipped(self) -> None: + parsed = parse_topic_relevance_results( + [{"row_id": "tr_0", "output": "not json", "error": None}] + ) + assert parsed == {} From bb30f883ddc0fa1132e9ffb6acb0a7c9ecc9a80a Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 11:28:06 +0530 Subject: [PATCH 10/16] feat: refactor assessment prefilter configuration and enhance pipeline acceptance logic --- backend/app/core/config.py | 5 - backend/app/crud/assessment/processing.py | 17 +- .../assessment/prefilter/constants.py | 13 ++ .../prefilter/duplicate_detection.py | 4 +- .../assessment/prefilter/request_builder.py | 6 +- .../assessment/prefilter/topic_relevance.py | 15 +- backend/app/services/assessment/stages.py | 7 +- backend/app/services/assessment/tasks.py | 12 +- .../services/assessment/utils/attachments.py | 167 ++---------------- backend/app/tests/assessment/test_batch.py | 75 +------- .../assessment/test_prefilter_batching.py | 37 ++++ .../app/tests/assessment/test_processing.py | 43 +++++ .../tests/assessment/test_topic_relevance.py | 20 ++- 13 files changed, 178 insertions(+), 243 deletions(-) create mode 100644 
backend/app/services/assessment/prefilter/constants.py diff --git a/backend/app/core/config.py b/backend/app/core/config.py index a54155922..720846eb9 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -171,11 +171,6 @@ def AWS_S3_BUCKET(self) -> str: DOC_TRANSFORMATION_PENDING_THRESHOLD_MINUTES: int = 30 PENDING_JOB_QUERY_TIMEOUT_MS: int = 1000 - # Assessment prefilter — provider + model for the batch prefilter stages. - ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "openai" - ASSESSMENT_PREFILTER_MODEL: str = "gpt-5-mini" - ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "vs_6a20339fbc148191867fd06d29133278" - @computed_field # type: ignore[prop-decorator] @property def COMPUTED_CELERY_WORKER_CONCURRENCY(self) -> int: diff --git a/backend/app/crud/assessment/processing.py b/backend/app/crud/assessment/processing.py index 7b860de5d..2d6b30613 100644 --- a/backend/app/crud/assessment/processing.py +++ b/backend/app/crud/assessment/processing.py @@ -8,6 +8,7 @@ from typing import Any from fastapi import HTTPException +from sqlalchemy.orm.attributes import flag_modified from sqlmodel import Session from app.celery.tasks.job_execution import run_assessment_pipeline @@ -275,7 +276,11 @@ def _poll_stage_outcome(session: Session, provider: BatchProvider, batch_job) -> def _record_gate_stats( session: Session, run: AssessmentRun, stage: str, batch_job, project_id: int ) -> None: - """For a go/no-go stage, persist passed/rejected counts from its results.""" + """For a go/no-go stage, persist passed/rejected counts and accepted row indices. + + The accepted indices are stored on ``run.pipeline`` so the next stage's batch + build reads them directly instead of re-downloading and re-parsing this batch. 
+ """ try: raw = load_raw_batch_results(session, batch_job, project_id) outputs = parse_assessment_output(raw, batch_job.provider) @@ -289,6 +294,16 @@ def _record_gate_stats( prefilter_total_passed=passed, prefilter_total_rejected=total - passed, ) + + # Persist the cumulative accepted set (intersect with prior gates). + accepted = {idx for idx, r in parsed.items() if r.get("verdict")} + prev = (run.pipeline or {}).get("accepted_indices") + if prev is not None: + accepted &= set(prev) + pipeline = dict(run.pipeline or {}) + pipeline["accepted_indices"] = sorted(accepted) + run.pipeline = pipeline + flag_modified(run, "pipeline") except Exception as exc: logger.warning( "[_record_gate_stats] run_id=%s stage=%s — %s", run.id, stage, exc diff --git a/backend/app/services/assessment/prefilter/constants.py b/backend/app/services/assessment/prefilter/constants.py new file mode 100644 index 000000000..a18186d06 --- /dev/null +++ b/backend/app/services/assessment/prefilter/constants.py @@ -0,0 +1,13 @@ +"""Static config for the assessment prefilter stages. +""" + +from typing import Literal + +# Provider + model that run the batch prefilter stages (topic relevance, dup check). +ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "google" +ASSESSMENT_PREFILTER_MODEL: str = "gemini-3.1-flash-lite" + +# File-search/vector store holding the corpus for duplicate detection. 
+ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = ( + "fileSearchStores/inquilabcorpus-782mxjcwisaz" +) diff --git a/backend/app/services/assessment/prefilter/duplicate_detection.py b/backend/app/services/assessment/prefilter/duplicate_detection.py index 6d4358b75..7512616db 100644 --- a/backend/app/services/assessment/prefilter/duplicate_detection.py +++ b/backend/app/services/assessment/prefilter/duplicate_detection.py @@ -4,7 +4,7 @@ import logging from typing import Any -from app.core.config import settings +from app.services.assessment.prefilter import constants from app.services.assessment.prefilter.request_builder import build_request_line logger = logging.getLogger(__name__) @@ -63,7 +63,7 @@ def build_duplicate_detection_requests( columns: list[str], ) -> list[dict[str, Any]]: """Build one batch JSONL line per record, grounded on the provider's corpus store.""" - store = settings.ASSESSMENT_PREFILTER_DUPLICATE_STORE or None + store = constants.ASSESSMENT_PREFILTER_DUPLICATE_STORE or None return [ build_request_line( key=f"dup_{idx}", diff --git a/backend/app/services/assessment/prefilter/request_builder.py b/backend/app/services/assessment/prefilter/request_builder.py index c9fa2e3ee..3e438ffed 100644 --- a/backend/app/services/assessment/prefilter/request_builder.py +++ b/backend/app/services/assessment/prefilter/request_builder.py @@ -2,8 +2,8 @@ from typing import Any -from app.core.config import settings from app.services.assessment.mappers import _ensure_openai_strict_schema +from app.services.assessment.prefilter import constants def build_request_line( @@ -20,9 +20,9 @@ def build_request_line( ``attachment_parts`` are provider-shaped content parts (from the OpenAI/Gemini attachment resolvers) appended after the text part. 
""" - model = settings.ASSESSMENT_PREFILTER_MODEL + model = constants.ASSESSMENT_PREFILTER_MODEL - if settings.ASSESSMENT_PREFILTER_PROVIDER == "openai": + if constants.ASSESSMENT_PREFILTER_PROVIDER == "openai": content: list[dict[str, Any]] = [{"type": "input_text", "text": user_text}] content.extend(attachment_parts or []) body: dict[str, Any] = { diff --git a/backend/app/services/assessment/prefilter/topic_relevance.py b/backend/app/services/assessment/prefilter/topic_relevance.py index 9b4722827..1fa6acb43 100644 --- a/backend/app/services/assessment/prefilter/topic_relevance.py +++ b/backend/app/services/assessment/prefilter/topic_relevance.py @@ -8,8 +8,8 @@ import logging from typing import Any -from app.core.config import settings from app.models.assessment import AssessmentAttachment +from app.services.assessment.prefilter import constants from app.services.assessment.prefilter.request_builder import build_request_line from app.services.assessment.utils.attachments import ( attachment_type_for_row, @@ -54,7 +54,7 @@ def build_topic_relevance_requests( ) -> list[dict[str, Any]]: """Build one batch JSONL line per row, with text columns + attachment parts.""" attachments = attachments or [] - is_openai = settings.ASSESSMENT_PREFILTER_PROVIDER == "openai" + is_openai = constants.ASSESSMENT_PREFILTER_PROVIDER == "openai" schema = _build_schema(columns + [a.column for a in attachments]) system = user_prompt.strip() + _INSTRUCTIONS @@ -112,4 +112,15 @@ def parse_topic_relevance_results( } except Exception as exc: logger.warning("[parse_topic_relevance_results] %s — %s", key, exc) + parsed[idx] = _accept_on_error() return parsed + + +def _accept_on_error() -> dict[str, Any]: + """Fail-open gate record for a row whose topic-relevance output was unparseable.""" + return { + "verdict": True, + "decision": "", + "reasoning": "", + "column_relevance": {}, + } diff --git a/backend/app/services/assessment/stages.py b/backend/app/services/assessment/stages.py index 
ba24e4f67..7afc276da 100644 --- a/backend/app/services/assessment/stages.py +++ b/backend/app/services/assessment/stages.py @@ -15,10 +15,9 @@ from app.core.batch.base import BatchProvider from app.core.batch.client import GeminiClient from app.core.cloud import get_cloud_storage -from app.core.config import settings from app.models.assessment import AssessmentRun, Stage, StageStatus from app.models.batch_job import BatchJob, BatchJobType -from app.services.assessment.prefilter import resolve_prefilter_settings +from app.services.assessment.prefilter import constants, resolve_prefilter_settings from app.services.assessment.prefilter.duplicate_detection import ( build_duplicate_detection_requests, parse_duplicate_detection_results, @@ -84,7 +83,7 @@ def submit_prefilter_batch( display_name: str, ) -> BatchJob: """Submit a prefilter batch on the configured provider and return the BatchJob.""" - base = settings.ASSESSMENT_PREFILTER_PROVIDER + base = constants.ASSESSMENT_PREFILTER_PROVIDER provider = _get_batch_provider( session=session, provider_name=base, @@ -100,7 +99,7 @@ def submit_prefilter_batch( else: config = { "display_name": display_name, - "model": f"models/{settings.ASSESSMENT_PREFILTER_MODEL}", + "model": f"models/{constants.ASSESSMENT_PREFILTER_MODEL}", } return start_batch_job( session=session, diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 2360583f8..11cdaf0a4 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -129,7 +129,17 @@ def _resolve_run_context( def _accepted_indices( session: Session, run: AssessmentRun, total_rows: int, project_id: int ) -> list[int]: - """Row indices that passed every gate stage before the current one.""" + """Row indices that passed every gate stage before the current one. 
+ + Prefers the accepted set persisted by the gate stage on ``run.pipeline`` + (set in ``_record_gate_stats``), avoiding a re-download + re-parse of the + gate batch at the memory-heavy prefilter -> assessment transition. Falls back + to recomputing from the gate batches only if nothing was persisted. + """ + stored = (run.pipeline or {}).get("accepted_indices") + if stored is not None: + return [i for i in sorted(stored) if 0 <= i < total_rows] + accepted = set(range(total_rows)) for stage in ordered_stages(run.pipeline): if stage == run.stage: diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 7228ada71..341ff3c81 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -1,11 +1,11 @@ """Attachment resolution utilities for assessment batch builds. -Handles MIME type detection, base64 decoding, Google Drive URL normalization, -data-URL parsing, and conversion of dataset cell values into provider input objects. +URL-only: dataset cells hold attachment URLs. Handles Google Drive URL +normalization and conversion of cell values into provider input objects. +Attachments are passed to providers by reference (URL), never inlined as base64, +to keep the batch build memory-light. 
""" -import base64 -import binascii import logging import re from typing import Any @@ -63,18 +63,6 @@ def to_direct_attachment_url(url: str, attachment_type: str) -> str: return f"https://drive.google.com/uc?export=download&id={file_id}" -def split_data_url(value: str) -> tuple[str | None, str]: - """Return (mime_type, base64_payload) for a data URL; otherwise (None, value).""" - match = re.match( - r"^data:([^;]+);base64,(.+)$", - value.strip(), - flags=re.IGNORECASE | re.DOTALL, - ) - if not match: - return None, value.strip() - return match.group(1).strip().lower(), match.group(2).strip() - - def _guess_image_mime_from_url(url: str) -> str | None: path = urlparse(url).path or "" for ext, mime in _IMAGE_MIME_BY_EXT.items(): @@ -83,57 +71,6 @@ def _guess_image_mime_from_url(url: str) -> str | None: return None -def _decode_base64_prefix(payload: str, max_chars: int = 256) -> bytes | None: - compact = re.sub(r"\s+", "", payload) - if not compact: - return None - sample = compact[:max_chars] - padding = "=" * (-len(sample) % 4) - try: - return base64.b64decode(sample + padding, validate=False) - except (binascii.Error, ValueError): - return None - - -def _image_mime_from_magic(blob: bytes) -> str | None: - """Detect image mime type from leading magic bytes.""" - if blob.startswith(b"\x89PNG\r\n\x1a\n"): - return "image/png" - if blob.startswith(b"\xff\xd8\xff"): - return "image/jpeg" - if blob.startswith((b"GIF87a", b"GIF89a")): - return "image/gif" - if blob.startswith(b"BM"): - return "image/bmp" - if len(blob) >= 12 and blob[:4] == b"RIFF" and blob[8:12] == b"WEBP": - return "image/webp" - if blob.startswith((b"II*\x00", b"MM\x00*")): - return "image/tiff" - return None - - -def _guess_image_mime_from_base64(payload: str) -> str | None: - blob = _decode_base64_prefix(payload) - if not blob: - return None - return _image_mime_from_magic(blob) - - -def resolve_image_mime_and_payload( - value: str, - format_type: str, -) -> tuple[str, str]: - """Resolve image mime 
type and raw base64 payload (for base64 format).""" - if format_type == "url": - return _guess_image_mime_from_url(value) or "image/png", value - - data_url_mime, payload = split_data_url(value) - if data_url_mime and data_url_mime.startswith("image/"): - return data_url_mime, payload - - return _guess_image_mime_from_base64(payload) or "image/png", payload - - def resolve_item_type(declared: str, type_override: str | None = None) -> str: """Resolve an attachment item as 'image' or 'pdf' from the user-declared type. @@ -190,57 +127,19 @@ def resolve_attachment_values( att: AssessmentAttachment, type_override: str | None = None, ) -> list[dict[str, Any]]: - """Convert one dataset cell into one or more OpenAI-style input objects.""" + """Convert one dataset cell into one or more OpenAI-style input objects (by URL).""" value = value.strip() if not value: return [] - if att.format == "url": - values = split_attachment_urls(value) - else: - values = [value] - item_type = resolve_item_type(att.type, type_override) resolved: list[dict[str, Any]] = [] - for item_value in values: - normalized_value = ( - to_direct_attachment_url(item_value, item_type) - if att.format == "url" - else item_value - ) - + for item_value in split_attachment_urls(value): + url = to_direct_attachment_url(item_value, item_type) if item_type == "image": - if att.format == "url": - resolved.append({"type": "input_image", "image_url": normalized_value}) - else: - mime_type, payload = resolve_image_mime_and_payload( - normalized_value, - "base64", - ) - resolved.append( - { - "type": "input_image", - "image_url": f"data:{mime_type};base64,{payload}", - } - ) - elif item_type == "pdf": - if att.format == "url": - resolved.append( - { - "type": "input_file", - "file_url": normalized_value, - } - ) - else: - _, payload = split_data_url(normalized_value) - resolved.append( - { - "type": "input_file", - "file_data": f"data:application/pdf;base64,{payload}", - "filename": "document.pdf", - } - ) - + 
resolved.append({"type": "input_image", "image_url": url}) + else: + resolved.append({"type": "input_file", "file_url": url}) return resolved @@ -249,7 +148,7 @@ def build_gemini_attachment_parts( att: AssessmentAttachment, type_override: str | None = None, ) -> list[dict[str, Any]]: - """Convert one dataset cell into one or more Gemini content parts. + """Convert one dataset cell into one or more Gemini content parts (by URL). Mirrors the per-item type routing used for the L2 batch so the same image/pdf handling applies to prefilter (topic relevance) calls. @@ -258,45 +157,13 @@ def build_gemini_attachment_parts( if not value: return [] - values = split_attachment_urls(value) if att.format == "url" else [value] - item_type = resolve_item_type(att.type, type_override) parts: list[dict[str, Any]] = [] - for item_value in values: - normalized_value = ( - to_direct_attachment_url(item_value, item_type) - if att.format == "url" - else item_value - ) - + for item_value in split_attachment_urls(value): + url = to_direct_attachment_url(item_value, item_type) if item_type == "image": - mime_type, payload = resolve_image_mime_and_payload( - normalized_value, att.format - ) - if att.format == "url": - parts.append( - {"fileData": {"mimeType": mime_type, "fileUri": normalized_value}} - ) - else: - parts.append({"inlineData": {"mimeType": mime_type, "data": payload}}) - elif item_type == "pdf": - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": "application/pdf", - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": "application/pdf", - "data": split_data_url(normalized_value)[1], - } - } - ) - + mime_type = _guess_image_mime_from_url(url) or "image/png" + parts.append({"fileData": {"mimeType": mime_type, "fileUri": url}}) + else: + parts.append({"fileData": {"mimeType": "application/pdf", "fileUri": url}}) return parts diff --git a/backend/app/tests/assessment/test_batch.py 
b/backend/app/tests/assessment/test_batch.py index 472a6f67a..ca435b781 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -18,14 +18,10 @@ ) from app.models.assessment import AssessmentAttachment from app.services.assessment.utils.attachments import ( - _decode_base64_prefix, - _guess_image_mime_from_base64, _guess_image_mime_from_url, resolve_attachment_values, - resolve_image_mime_and_payload, resolve_item_type, split_attachment_urls, - split_data_url, to_direct_attachment_url, ) @@ -347,56 +343,24 @@ def test_split_and_direct_urls(self) -> None: ) assert "drive.google.com/uc" in pdf_url - def test_data_url_and_mime_guessers(self) -> None: - mime, payload = split_data_url("data:image/png;base64,AAAA") - assert mime == "image/png" - assert payload == "AAAA" - none_mime, raw = split_data_url("rawbase64") - assert none_mime is None - assert raw == "rawbase64" + def test_url_mime_guessers(self) -> None: assert _guess_image_mime_from_url("https://x/y/file.jpeg") == "image/jpeg" assert _guess_image_mime_from_url("https://x/y/file.unknown") is None - def test_base64_guess_and_decode(self) -> None: - png_head = "iVBORw0KGgoAAAANSUhEUg==" - assert _guess_image_mime_from_base64(png_head) == "image/png" - assert _decode_base64_prefix("###") == b"" - - def testresolve_image_mime_and_payload(self) -> None: - mime, payload = resolve_image_mime_and_payload("https://x/y/file.webp", "url") - assert mime == "image/webp" - assert payload.endswith("file.webp") - mime2, payload2 = resolve_image_mime_and_payload( - "data:image/jpeg;base64,AAAA", "base64" - ) - assert mime2 == "image/jpeg" - assert payload2 == "AAAA" - def testresolve_attachment_values(self) -> None: image_url_att = AssessmentAttachment(column="img", type="image", format="url") - image_b64_att = AssessmentAttachment( - column="img", type="image", format="base64" - ) pdf_url_att = AssessmentAttachment(column="pdf", type="pdf", format="url") - pdf_b64_att = 
AssessmentAttachment(column="pdf", type="pdf", format="base64") values = resolve_attachment_values( "https://example.com/a.png,https://example.com/b.png", image_url_att ) assert len(values) == 2 assert values[0]["type"] == "input_image" - - values = resolve_attachment_values("data:image/png;base64,AAAA", image_b64_att) - assert values[0]["image_url"].startswith("data:image/png;base64,") + assert values[0]["image_url"] == "https://example.com/a.png" values = resolve_attachment_values("https://example.com/a.pdf", pdf_url_att) assert values[0]["type"] == "input_file" - assert "file_url" in values[0] - - values = resolve_attachment_values( - "data:application/pdf;base64,AAAA", pdf_b64_att - ) - assert values[0]["file_data"].startswith("data:application/pdf;base64,") + assert values[0]["file_url"] == "https://example.com/a.pdf" def test_build_openai_and_google_jsonl(self) -> None: rows = [{"q": "What is 2+2?", "img": "https://example.com/a.png"}] @@ -454,44 +418,13 @@ def test_column_uses_single_declared_type(self) -> None: assert types == ["input_file", "input_file"] -class TestAttachmentMagicAndMime: - def test_image_magic_all_formats(self) -> None: - from app.services.assessment.utils.attachments import _image_mime_from_magic - - assert _image_mime_from_magic(b"\x89PNG\r\n\x1a\n") == "image/png" - assert _image_mime_from_magic(b"\xff\xd8\xff") == "image/jpeg" - assert _image_mime_from_magic(b"GIF89a") == "image/gif" - assert _image_mime_from_magic(b"GIF87a") == "image/gif" - assert _image_mime_from_magic(b"BM....") == "image/bmp" - assert _image_mime_from_magic(b"RIFF\x00\x00\x00\x00WEBP") == "image/webp" - assert _image_mime_from_magic(b"II*\x00") == "image/tiff" - assert _image_mime_from_magic(b"MM\x00*") == "image/tiff" - assert _image_mime_from_magic(b"nope") is None - +class TestAttachmentMime: def test_guess_image_mime_from_url_variants(self) -> None: - from app.services.assessment.utils.attachments import _guess_image_mime_from_url - assert 
_guess_image_mime_from_url("http://x/a.PNG") == "image/png" assert _guess_image_mime_from_url("http://x/a.jpeg") == "image/jpeg" assert _guess_image_mime_from_url("http://x/a.webp") == "image/webp" assert _guess_image_mime_from_url("http://x/a.txt") is None - def test_resolve_image_mime_data_url(self) -> None: - from app.services.assessment.utils.attachments import ( - resolve_image_mime_and_payload, - ) - - mime, payload = resolve_image_mime_and_payload( - "data:image/webp;base64,AAAA", "base64" - ) - assert mime == "image/webp" - assert payload == "AAAA" - - def test_decode_base64_prefix_empty(self) -> None: - from app.services.assessment.utils.attachments import _decode_base64_prefix - - assert _decode_base64_prefix(" ") is None - class TestAttachmentTypeForRow: def test_mixed_resolves_from_type_column(self) -> None: diff --git a/backend/app/tests/assessment/test_prefilter_batching.py b/backend/app/tests/assessment/test_prefilter_batching.py index 73d103a3c..4c631daf2 100644 --- a/backend/app/tests/assessment/test_prefilter_batching.py +++ b/backend/app/tests/assessment/test_prefilter_batching.py @@ -113,3 +113,40 @@ def test_zero_accepted_advances(self) -> None: with p[0], p[1], p[2], p[3], patch.object(tasks, "_persist_advance") as advance: tasks._submit_stage(session, run, 1, 1) advance.assert_called_once() + + +class TestAcceptedIndices: + def test_uses_persisted_indices_without_downloading(self) -> None: + """Stored accepted set is read directly — no gate batch re-download.""" + run = _run( + pipeline={ + "stages": [ + {"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "order": 1}, + {"stage": Stage.L2_ASSESSMENT, "order": 2}, + ], + "accepted_indices": [0, 2, 5], + }, + stage=Stage.L2_ASSESSMENT, + ) + with patch.object(tasks, "load_raw_batch_results") as load: + result = tasks._accepted_indices( + MagicMock(), run, total_rows=10, project_id=1 + ) + assert result == [0, 2, 5] + load.assert_not_called() + + def test_persisted_indices_clamped_to_total_rows(self) -> 
None: + run = _run( + pipeline={"stages": [], "accepted_indices": [0, 3, 99]}, + stage=Stage.L2_ASSESSMENT, + ) + result = tasks._accepted_indices(MagicMock(), run, total_rows=4, project_id=1) + assert result == [0, 3] + + def test_falls_back_to_full_range_when_nothing_persisted(self) -> None: + run = _run( + pipeline={"stages": [{"stage": Stage.L2_ASSESSMENT, "order": 1}]}, + stage=Stage.L2_ASSESSMENT, + ) + result = tasks._accepted_indices(MagicMock(), run, total_rows=3, project_id=1) + assert result == [0, 1, 2] diff --git a/backend/app/tests/assessment/test_processing.py b/backend/app/tests/assessment/test_processing.py index da9c930b7..84bace368 100644 --- a/backend/app/tests/assessment/test_processing.py +++ b/backend/app/tests/assessment/test_processing.py @@ -6,8 +6,10 @@ import pytest +from app.crud.assessment import processing as processing_mod from app.crud.assessment.processing import ( _get_batch_provider, + _record_gate_stats, _sanitize_json_output, parse_assessment_output, process_run_batches, @@ -211,6 +213,47 @@ def test_google_native_provider_accepted(self) -> None: assert results[0]["output"] == "out" +class TestRecordGateStats: + def _patches(self, parsed): + return [ + patch.object(processing_mod, "load_raw_batch_results", return_value=[]), + patch.object(processing_mod, "parse_assessment_output", return_value=[]), + patch.dict( + processing_mod.STAGE_PARSERS, + {Stage.PRE_FILTER_TOPIC_RELEVANCE: lambda _outputs: parsed}, + ), + patch.object(processing_mod, "update_assessment_run_prefilter_stats"), + patch.object(processing_mod, "flag_modified"), + ] + + def test_persists_accepted_indices_to_pipeline(self) -> None: + run = SimpleNamespace(id=1, assessment_id=2, pipeline={"stages": []}) + parsed = { + 0: {"verdict": True}, + 1: {"verdict": False}, + 2: {"verdict": True}, + } + p = self._patches(parsed) + with p[0], p[1], p[2], p[3], p[4]: + _record_gate_stats( + MagicMock(), run, Stage.PRE_FILTER_TOPIC_RELEVANCE, MagicMock(), 1 + ) + assert 
run.pipeline["accepted_indices"] == [0, 2] + + def test_intersects_with_prior_gate(self) -> None: + run = SimpleNamespace( + id=1, assessment_id=2, pipeline={"accepted_indices": [2, 3]} + ) + parsed = {2: {"verdict": True}, 3: {"verdict": False}, 4: {"verdict": True}} + p = self._patches(parsed) + with p[0], p[1], p[2], p[3], p[4]: + _record_gate_stats( + MagicMock(), run, Stage.PRE_FILTER_TOPIC_RELEVANCE, MagicMock(), 1 + ) + # 2 passes this gate and was in the prior accepted set; 4 wasn't; 3 rejected. + assert run.pipeline["accepted_indices"] == [2] + + class TestGetBatchProvider: def test_unsupported_provider_raises(self) -> None: session = MagicMock() diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py index fa623c9c4..db582584b 100644 --- a/backend/app/tests/assessment/test_topic_relevance.py +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -3,8 +3,8 @@ import json from unittest.mock import patch -from app.core.config import settings from app.models.assessment import AssessmentAttachment +from app.services.assessment.prefilter import constants from app.services.assessment.prefilter.topic_relevance import ( build_topic_relevance_requests, parse_topic_relevance_results, @@ -12,7 +12,7 @@ def _gemini(): - return patch.object(settings, "ASSESSMENT_PREFILTER_PROVIDER", "google") + return patch.object(constants, "ASSESSMENT_PREFILTER_PROVIDER", "google") class TestBuildRequests: @@ -76,8 +76,20 @@ def test_maps_decision_and_per_column_relevance(self) -> None: assert parsed[1]["verdict"] is False assert parsed[1]["column_relevance"] == {"Problem": False} - def test_bad_output_skipped(self) -> None: + def test_unparseable_output_fails_open_accepted(self) -> None: + # A gate response we cannot parse must NOT silently drop the submission: + # it is accepted (verdict=True) so it still reaches L2 and is counted. 
parsed = parse_topic_relevance_results( [{"row_id": "tr_0", "output": "not json", "error": None}] ) - assert parsed == {} + assert parsed[0]["verdict"] is True + assert parsed[0]["decision"] == "" + assert parsed[0]["reasoning"] == "" + assert parsed[0]["column_relevance"] == {} + + def test_empty_output_fails_open_accepted(self) -> None: + parsed = parse_topic_relevance_results( + [{"row_id": "tr_0", "output": None, "error": "provider error"}] + ) + assert parsed[0]["verdict"] is True + assert parsed[0]["decision"] == "" From e89f1f2350b406df59645e9dcb6688383ead2e19 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 11:47:40 +0530 Subject: [PATCH 11/16] feat: enhance error handling in assessment pipeline and improve attachment type validation --- backend/app/crud/assessment/processing.py | 26 ++++++++++---- backend/app/models/assessment.py | 16 +++++++-- backend/app/services/assessment/tasks.py | 22 +++++++++++- .../services/assessment/utils/attachments.py | 21 ++++++++--- .../app/services/assessment/utils/export.py | 11 +++--- backend/app/tests/assessment/test_batch.py | 36 ++++++++++++++++--- 6 files changed, 110 insertions(+), 22 deletions(-) diff --git a/backend/app/crud/assessment/processing.py b/backend/app/crud/assessment/processing.py index 2d6b30613..dbf792bae 100644 --- a/backend/app/crud/assessment/processing.py +++ b/backend/app/crud/assessment/processing.py @@ -376,12 +376,26 @@ async def process_run_batches(run: AssessmentRun, session: Session) -> dict[str, recompute_assessment_status(session=session, assessment_id=run.assessment_id) if nxt: - run_assessment_pipeline.delay( - run_id=run.id, - organization_id=parent.organization_id, - project_id=parent.project_id, - trace_id="", - ) + try: + run_assessment_pipeline.delay( + run_id=run.id, + organization_id=parent.organization_id, + project_id=parent.project_id, + trace_id="", + ) + except Exception as exc: + logger.error( + 
"[process_run_batches] run_id=%s stage=%s enqueue failed — marking failed for resume: %s", + run.id, + run.stage, + exc, + exc_info=True, + ) + return _fail_run_stage( + session, + run, + "Failed to enqueue the next pipeline stage. Resume the run to retry.", + ) return { "run_id": run.id, diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 9d1378f83..fee792d8a 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy import Column, Index, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField @@ -349,7 +349,7 @@ class AssessmentAttachment(BaseModel): "For 'mixed': the dataset column whose value decides each row's type." ), ) - type_value_map: dict[str, str] | None = Field( + type_value_map: dict[str, Literal["image", "pdf"]] | None = Field( None, description=( "For 'mixed': maps a type_column value to 'image' or 'pdf' " @@ -357,6 +357,18 @@ class AssessmentAttachment(BaseModel): ), ) + @model_validator(mode="after") + def _validate_mixed_config(self) -> "AssessmentAttachment": + """Require type_column and a non-empty type_value_map on 'mixed' columns.""" + if self.type == "mixed": + if self.type_column is None or self.type_value_map is None: + raise ValueError( + "type='mixed' requires both 'type_column' and 'type_value_map'."
+ ) + if not self.type_value_map: + raise ValueError("type_value_map must not be empty for type='mixed'.") + return self + class AssessmentConfigRef(BaseModel): """Reference to a stored config version.""" diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 11cdaf0a4..d21d86eeb 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -280,5 +280,25 @@ def _persist_advance( session.add(run) session.commit() recompute_assessment_status(session=session, assessment_id=run.assessment_id) - if nxt: + if not nxt: + return + # Commit precedes dispatch (the worker only acts on a committed PENDING run). + # If the broker call fails the run would otherwise sit at PENDING forever — the + # cron only re-polls PROCESSING runs — so mark it failed (resumable) instead. + try: _dispatch(run.id, organization_id, project_id) + except Exception: + logger.error( + "[_persist_advance] run_id=%s stage=%s enqueue failed — marking failed for resume", + run.id, + run.stage, + exc_info=True, + ) + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Failed to enqueue the next pipeline stage. Resume the run to retry.", + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 341ff3c81..26f82ad75 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -71,14 +71,15 @@ def _guess_image_mime_from_url(url: str) -> str | None: return None -def resolve_item_type(declared: str, type_override: str | None = None) -> str: +def resolve_item_type(declared: str, type_override: str | None = None) -> str | None: """Resolve an attachment item as 'image' or 'pdf' from the user-declared type. 
- Trusts the user: a per-row ``type_override`` (for 'mixed' columns) wins, else the - column's declared ``type``. Anything non-concrete falls back to 'image'. + A per-row ``type_override`` (for 'mixed' columns) wins, else the column's declared + ``type``. Returns None when the type stays unresolved (e.g. a 'mixed' row whose + value didn't map to a concrete type) so callers can skip rather than guess. """ item_type = type_override or declared - return item_type if item_type in ("image", "pdf") else "image" + return item_type if item_type in ("image", "pdf") else None def _normalize_type_value(value: str) -> str: @@ -133,6 +134,12 @@ def resolve_attachment_values( return [] item_type = resolve_item_type(att.type, type_override) + if item_type is None: + logger.warning( + "[resolve_attachment_values] Unresolved type for column=%s — skipping", + att.column, + ) + return [] resolved: list[dict[str, Any]] = [] for item_value in split_attachment_urls(value): url = to_direct_attachment_url(item_value, item_type) @@ -158,6 +165,12 @@ def build_gemini_attachment_parts( return [] item_type = resolve_item_type(att.type, type_override) + if item_type is None: + logger.warning( + "[build_gemini_attachment_parts] Unresolved type for column=%s — skipping", + att.column, + ) + return [] parts: list[dict[str, Any]] = [] for item_value in split_attachment_urls(value): url = to_direct_attachment_url(item_value, item_type) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index c83714d9a..4dded4bab 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -617,18 +617,21 @@ def load_export_rows_for_run( return rows # Dataset unavailable — emit whatever results we have, indexed by row_id. 
+ all_row_ids = sorted( + {str(rid) for rid in l2_by_row_id} | {str(rid) for rid in prefilter_by_row_id} + ) return [ _build_export_row( run=run, assessment=assessment, dataset_name=dataset_name, - row_id=str(row_id), + row_id=row_id, input_data=None, - prefilter_item=prefilter_by_row_id.get(str(row_id)), - l2_item=l2_item, + prefilter_item=prefilter_by_row_id.get(row_id), + l2_item=l2_by_row_id.get(row_id), has_prefilter=has_prefilter, ) - for row_id, l2_item in l2_by_row_id.items() + for row_id in all_row_ids ] diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index ca435b781..18628c7f6 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -403,11 +403,11 @@ def test_override_wins(self) -> None: assert resolve_item_type("image", "pdf") == "pdf" assert resolve_item_type("pdf", "image") == "image" - def test_mixed_without_override_defaults_to_image(self) -> None: - assert resolve_item_type("mixed") == "image" + def test_mixed_without_override_is_unresolved(self) -> None: + assert resolve_item_type("mixed") is None - def test_unknown_declared_defaults_to_image(self) -> None: - assert resolve_item_type("whatever") == "image" + def test_unknown_declared_is_unresolved(self) -> None: + assert resolve_item_type("whatever") is None def test_column_uses_single_declared_type(self) -> None: """One column, many URLs -> all routed by the declared type.""" @@ -489,10 +489,36 @@ def test_non_mixed_returns_none(self) -> None: att = AssessmentAttachment(column="Docs", type="image", format="url") assert attachment_type_for_row(att, {"Docs": "x"}) is None + def test_mixed_config_missing_routing_fields_is_rejected(self) -> None: + import pytest + from pydantic import ValidationError + + with pytest.raises(ValidationError): + AssessmentAttachment(column="Docs", type="mixed", format="url") + + def test_mixed_config_invalid_map_value_is_rejected(self) -> None: + import pytest + from 
pydantic import ValidationError + + with pytest.raises(ValidationError): + AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Report": "spreadsheet"}, + ) + def test_override_forces_part_type(self) -> None: from app.services.assessment.utils.attachments import resolve_attachment_values - att = AssessmentAttachment(column="Docs", type="mixed", format="url") + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Report": "pdf"}, + ) url = "https://drive.google.com/file/d/ID/view" parts = resolve_attachment_values(url, att, type_override="pdf") assert parts[0]["type"] == "input_file" From 87ee6a54c04cf37837b13a54b711b22afc22f1b9 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 12:27:31 +0530 Subject: [PATCH 12/16] feat: add error handling for deterministic failures in assessment evaluation polling --- backend/app/crud/assessment/cron.py | 29 ++++++++++++++ backend/app/tests/assessment/test_cron.py | 24 +++++++++++ .../assessment/test_duplicate_detection.py | 25 ++++++++++++ .../tests/assessment/test_topic_relevance.py | 40 +++++++++++++++++++ 4 files changed, 118 insertions(+) diff --git a/backend/app/crud/assessment/cron.py b/backend/app/crud/assessment/cron.py index 000d61666..397554607 100644 --- a/backend/app/crud/assessment/cron.py +++ b/backend/app/crud/assessment/cron.py @@ -9,6 +9,7 @@ compute_run_counts, get_assessment_runs_for_assessment, recompute_assessment_status, + update_assessment_run_status, ) from app.crud.assessment.processing import ( format_assessment_failure_message, @@ -115,6 +116,34 @@ async def poll_all_pending_assessment_evaluations( else: still_processing += 1 + except ValueError as e: + session.rollback() + message = format_assessment_failure_message(e) + logger.error( + "[poll_all_pending_assessment_evaluations] deterministic error on " + 
"run %s (assessment %s), marking failed: %s", + run.id, + run.assessment_id, + message, + ) + try: + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=message, + ) + failed += 1 + except Exception: + session.rollback() + logger.error( + "[poll_all_pending_assessment_evaluations] could not mark run " + "%s failed", + run.id, + exc_info=True, + ) + still_processing += 1 except Exception as e: session.rollback() logger.warning( diff --git a/backend/app/tests/assessment/test_cron.py b/backend/app/tests/assessment/test_cron.py index e2dc54fd8..77797bb88 100644 --- a/backend/app/tests/assessment/test_cron.py +++ b/backend/app/tests/assessment/test_cron.py @@ -155,3 +155,27 @@ async def test_transient_poll_exception_does_not_fail_run(self) -> None: assert result["failed"] == 0 assert result["still_processing"] == 1 + + @pytest.mark.asyncio + async def test_deterministic_error_marks_run_failed(self) -> None: + """A deterministic ValueError fails the run instead of retrying forever.""" + session = MagicMock() + assessment = _make_assessment(id=1, status="processing") + run = _make_run(id=11) + run.stage_status = "PROCESSING" + session.exec.return_value.all.return_value = [assessment] + + with patch( + "app.crud.assessment.cron.get_assessment_runs_for_assessment", + return_value=[run], + ), patch( + "app.crud.assessment.cron.process_run_batches", + new=AsyncMock(side_effect=ValueError("Parent assessment 1 not found")), + ), patch( + "app.crud.assessment.cron.update_assessment_run_status" + ) as mark_failed: + result = await poll_all_pending_assessment_evaluations(session=session) + + assert result["failed"] == 1 + assert result["still_processing"] == 0 + assert mark_failed.call_args.kwargs["status"] == "failed" diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py index 89ff2ddfc..b9b0c033a 100644 --- 
a/backend/app/tests/assessment/test_duplicate_detection.py +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -1,5 +1,8 @@ """Tests for the duplicate-detection batch request builder and result parser.""" +from unittest.mock import patch + +from app.services.assessment.prefilter import constants from app.services.assessment.prefilter.duplicate_detection import ( build_duplicate_detection_requests, parse_duplicate_detection_results, @@ -14,6 +17,17 @@ def test_one_request_per_record(self) -> None: keys = [ln.get("key") or ln.get("custom_id") for ln in lines] assert keys == ["dup_0", "dup_1"] + def test_openai_request_grounds_on_file_search_store(self) -> None: + with patch.object( + constants, "ASSESSMENT_PREFILTER_PROVIDER", "openai" + ), patch.object(constants, "ASSESSMENT_PREFILTER_DUPLICATE_STORE", "vs_corpus"): + lines = build_duplicate_detection_requests( + [(0, {"Problem": "p"})], ["Problem"] + ) + tool = lines[0]["body"]["tools"][0] + assert tool["type"] == "file_search" + assert tool["vector_store_ids"] == ["vs_corpus"] + class TestParseResults: def test_parses_structured_verdict_per_row(self) -> None: @@ -58,3 +72,14 @@ def test_empty_response_records_error(self) -> None: [{"row_id": "dup_3", "output": None, "error": None}] ) assert parsed[3]["verdict"] == "ERROR" + + def test_bad_json_records_error_and_foreign_keys_skipped(self) -> None: + parsed = parse_duplicate_detection_results( + [ + {"row_id": "tr_0", "output": "{}", "error": None}, # not a dup key + {"row_id": "dup_x", "output": "{}", "error": None}, # bad index + {"row_id": "dup_4", "output": "{not json", "error": None}, + ] + ) + assert set(parsed) == {4} + assert parsed[4]["verdict"] == "ERROR" diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py index db582584b..b06166982 100644 --- a/backend/app/tests/assessment/test_topic_relevance.py +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -15,6 +15,28 @@ 
def _gemini(): return patch.object(constants, "ASSESSMENT_PREFILTER_PROVIDER", "google") +def _openai(): + return patch.object(constants, "ASSESSMENT_PREFILTER_PROVIDER", "openai") + + +class TestBuildRequestsOpenAI: + def test_openai_request_shape(self) -> None: + rows = [(0, {"Problem": "p0", "Docs": "https://x.com/a.png"})] + atts = [AssessmentAttachment(column="Docs", type="image", format="url")] + with _openai(): + lines = build_topic_relevance_requests(rows, ["Problem"], "rubric", atts) + line = lines[0] + assert line["custom_id"] == "tr_0" + assert line["url"] == "/v1/responses" + body = line["body"] + assert body["instructions"].startswith("rubric") + content = body["input"][0]["content"] + assert content[0] == {"type": "input_text", "text": "Problem:\np0"} + assert content[1]["type"] == "input_image" + assert body["text"]["format"]["type"] == "json_schema" + assert body["text"]["format"]["schema"]["additionalProperties"] is False + + class TestBuildRequests: def test_one_request_per_row_with_per_column_schema(self) -> None: rows = [(0, {"Problem": "p0"}), (1, {"Problem": "p1"})] @@ -46,6 +68,15 @@ def test_empty_attachments_is_text_only(self) -> None: ) assert len(lines[0]["request"]["contents"][0]["parts"]) == 1 + def test_blank_attachment_cell_is_skipped(self) -> None: + att = AssessmentAttachment(column="Docs", type="image", format="url") + with _gemini(): + lines = build_topic_relevance_requests( + [(0, {"Problem": "p", "Docs": " "})], ["Problem"], "r", [att] + ) + # Whitespace-only attachment cell -> only the text part survives. 
+ assert len(lines[0]["request"]["contents"][0]["parts"]) == 1 + class TestParseResults: def test_maps_decision_and_per_column_relevance(self) -> None: @@ -93,3 +124,12 @@ def test_empty_output_fails_open_accepted(self) -> None: ) assert parsed[0]["verdict"] is True assert parsed[0]["decision"] == "" + + def test_foreign_and_bad_index_keys_skipped(self) -> None: + parsed = parse_topic_relevance_results( + [ + {"row_id": "dup_0", "output": "{}", "error": None}, # not a tr key + {"row_id": "tr_x", "output": "{}", "error": None}, # bad index + ] + ) + assert parsed == {} From 61798b6357d3f491b2df84deb30a624af9198051 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 12:37:43 +0530 Subject: [PATCH 13/16] feat: update assessment prefilter constants for provider and model configuration --- backend/app/services/assessment/prefilter/constants.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/backend/app/services/assessment/prefilter/constants.py b/backend/app/services/assessment/prefilter/constants.py index a18186d06..429044ade 100644 --- a/backend/app/services/assessment/prefilter/constants.py +++ b/backend/app/services/assessment/prefilter/constants.py @@ -4,10 +4,8 @@ from typing import Literal # Provider + model that run the batch prefilter stages (topic relevance, dup check). -ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "google" -ASSESSMENT_PREFILTER_MODEL: str = "gemini-3.1-flash-lite" +ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "openai" +ASSESSMENT_PREFILTER_MODEL: str = "gpt-5-mini" # File-search/vector store holding the corpus for duplicate detection. 
-ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = ( - "fileSearchStores/inquilabcorpus-782mxjcwisaz" -) +ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "vs_6a20339fbc148191867fd06d29133278" From d4d88a2d95c5287e5a5ba794b782d0766e3f4e77 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 14:22:14 +0530 Subject: [PATCH 14/16] feat: enhance assessment tests with additional attachment handling and pipeline logic --- backend/app/tests/assessment/test_batch.py | 43 ++++ backend/app/tests/assessment/test_export.py | 188 +++++++++++++++++ backend/app/tests/assessment/test_pipeline.py | 58 ++++- .../assessment/test_prefilter_batching.py | 198 ++++++++++++++++++ 4 files changed, 486 insertions(+), 1 deletion(-) diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 18628c7f6..c025ba98c 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -19,6 +19,8 @@ from app.models.assessment import AssessmentAttachment from app.services.assessment.utils.attachments import ( _guess_image_mime_from_url, + attachment_type_for_row, + build_gemini_attachment_parts, resolve_attachment_values, resolve_item_type, split_attachment_urls, @@ -522,3 +524,44 @@ def test_override_forces_part_type(self) -> None: url = "https://drive.google.com/file/d/ID/view" parts = resolve_attachment_values(url, att, type_override="pdf") assert parts[0]["type"] == "input_file" + + +class TestAttachmentResolutionBranches: + _IMG = AssessmentAttachment(column="Docs", type="image", format="url") + _PDF = AssessmentAttachment(column="Docs", type="pdf", format="url") + _MIXED = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Report": "pdf"}, + ) + + def test_blank_value_returns_empty(self) -> None: + assert resolve_attachment_values(" ", self._IMG) == [] + assert 
build_gemini_attachment_parts(" ", self._IMG) == [] + + def test_unresolved_mixed_is_skipped(self) -> None: + url = "https://x.com/a.jpg" + # No override and declared 'mixed' -> unresolved -> skip rather than guess. + assert resolve_attachment_values(url, self._MIXED) == [] + assert build_gemini_attachment_parts(url, self._MIXED) == [] + + def test_gemini_image_and_pdf_parts(self) -> None: + img = build_gemini_attachment_parts("https://x.com/a.png", self._IMG)[0] + pdf = build_gemini_attachment_parts("https://x.com/a.pdf", self._PDF)[0] + assert img["fileData"]["mimeType"] == "image/png" + assert pdf["fileData"]["mimeType"] == "application/pdf" + + def test_type_for_row_blank_value_returns_none(self) -> None: + assert attachment_type_for_row(self._MIXED, {"DOC type": " "}) is None + + def test_type_for_row_ignores_invalid_map_value(self) -> None: + # SimpleNamespace bypasses the model validator to exercise the guard that + # skips map entries whose target type isn't 'image'/'pdf'. + att = SimpleNamespace( + type="mixed", + type_column="DOC type", + type_value_map={"Report": "spreadsheet"}, + ) + assert attachment_type_for_row(att, {"DOC type": "Report"}) is None diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py index 32a0d8783..9dfd6cbd2 100644 --- a/backend/app/tests/assessment/test_export.py +++ b/backend/app/tests/assessment/test_export.py @@ -2,21 +2,121 @@ import json from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch from app.models.assessment import AssessmentExportRow +from app.services.assessment.utils import export as export_mod from app.services.assessment.utils.export import ( + _build_export_row, _drop_empty_columns, _expand_input_columns, _expand_output_columns, _load_dataset_rows_for_run, + _load_l2_results_for_run, + _load_parsed_results_for_batch_job, _load_parsed_results_for_run, + _load_prefilter_results, _safe_filename_part, + 
_stage_batch_job, build_json_export_rows, load_export_rows_for_run, serialize_export_rows, sort_export_rows, ) +from app.models.assessment import Stage + + +def _run_ns(status: str = "processing") -> SimpleNamespace: + return SimpleNamespace( + id=5, + assessment_id=9, + status=status, + config_id="00000000-0000-0000-0000-000000000001", + config_version=1, + updated_at=datetime(2026, 1, 1), + ) + + +def _assessment_ns() -> SimpleNamespace: + return SimpleNamespace(experiment_name="exp", dataset_id=3) + + +class TestBuildExportRow: + def test_prefilter_rejected_with_annotations(self) -> None: + prefilter_item = { + "prefilter_passed": False, + "topic_relevance": { + "decision": "REJECT", + "reasoning": "off-topic", + "column_relevance": {"Problem": False}, + }, + "duplicate_detection": {"row_id": "dup_0", "verdict": "UNIQUE"}, + } + row = _build_export_row( + run=_run_ns(), + assessment=_assessment_ns(), + dataset_name="ds", + row_id="row_0", + input_data={"Problem": "p"}, + prefilter_item=prefilter_item, + l2_item=None, + has_prefilter=True, + ) + assert row.result_status == "prefilter_rejected" + assert json.loads(row.topic_relevance)["decision"] == "REJECT" + assert json.loads(row.duplicate_detection)["verdict"] == "UNIQUE" + + def test_passed_with_l2_output(self) -> None: + row = _build_export_row( + run=_run_ns(), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_1", + input_data=None, + prefilter_item={"prefilter_passed": True}, + l2_item={"output": "{}", "error": None}, + has_prefilter=True, + ) + assert row.result_status == "passed" + + def test_l2_error_is_failed_and_no_prefilter_cols(self) -> None: + row = _build_export_row( + run=_run_ns(), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_2", + input_data=None, + prefilter_item=None, + l2_item={"output": None, "error": "boom"}, + has_prefilter=False, + ) + assert row.result_status == "failed" + assert row.topic_relevance is None + + def 
test_no_l2_processing_vs_failed(self) -> None: + processing = _build_export_row( + run=_run_ns(status="processing"), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_3", + input_data=None, + prefilter_item=None, + l2_item=None, + has_prefilter=False, + ) + failed = _build_export_row( + run=_run_ns(status="failed"), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_4", + input_data=None, + prefilter_item=None, + l2_item=None, + has_prefilter=False, + ) + assert processing.result_status == "processing" + assert failed.result_status == "failed" def _named_dataset() -> MagicMock: @@ -623,3 +723,91 @@ def test_dataset_rows_include_pending_and_correlate_input(self) -> None: assert result[0].result_status == "processing" # row_0 not done yet assert result[1].input_data == {"q": "second"} assert result[1].result_status == "passed" + + +class TestStageBatchJob: + def test_returns_job_for_stage(self) -> None: + run = SimpleNamespace(stage_batches={Stage.L2_ASSESSMENT: 7}) + with patch.object(export_mod, "get_batch_job", return_value="JOB") as g: + assert _stage_batch_job(MagicMock(), run, Stage.L2_ASSESSMENT) == "JOB" + assert g.call_args.kwargs["batch_job_id"] == 7 + + def test_none_when_no_batch(self) -> None: + run = SimpleNamespace(stage_batches=None) + assert _stage_batch_job(MagicMock(), run, Stage.L2_ASSESSMENT) is None + + +class TestLoadPrefilterResults: + def test_merges_tr_and_dup_annotations(self) -> None: + run = SimpleNamespace(id=5) + assessment = SimpleNamespace(project_id=1) + with patch.object(export_mod, "_stage_batch_job", return_value=SimpleNamespace(provider="openai")), patch.object( + export_mod, "load_raw_batch_results", return_value=[] + ), patch.object(export_mod, "parse_assessment_output", return_value=[]), patch.object( + export_mod, + "parse_topic_relevance_results", + return_value={0: {"verdict": True, "decision": "ACCEPT", "reasoning": "ok", "column_relevance": {"a": True}}}, + ), patch.object( + export_mod, + 
"parse_duplicate_detection_results", + return_value={0: {"verdict": "UNIQUE"}}, + ): + out = _load_prefilter_results(MagicMock(), run, assessment) + assert out["row_0"]["prefilter_passed"] is True + assert out["row_0"]["topic_relevance"]["decision"] == "ACCEPT" + assert out["row_0"]["duplicate_detection"]["verdict"] == "UNIQUE" + + def test_tr_load_failure_is_swallowed(self) -> None: + run = SimpleNamespace(id=5) + assessment = SimpleNamespace(project_id=1) + with patch.object(export_mod, "_stage_batch_job", return_value=SimpleNamespace(provider="openai")), patch.object( + export_mod, "load_raw_batch_results", side_effect=RuntimeError("s3 down") + ): + out = _load_prefilter_results(MagicMock(), run, assessment) + assert out == {} + + +class TestLoadParsedResultsForBatchJob: + def test_object_store_path(self) -> None: + job = SimpleNamespace(id=1, provider="openai", raw_output_url="s3://x", provider_output_file_id=None) + assessment = SimpleNamespace(project_id=1, organization_id=1) + storage = MagicMock() + storage.stream.return_value.read.return_value.decode.return_value = "raw" + with patch.object(export_mod, "get_cloud_storage", return_value=storage), patch.object( + export_mod, "parse_stored_results", return_value=[{"k": 1}] + ), patch.object(export_mod, "parse_assessment_output", return_value=[{"row_id": "row_0"}]) as parse: + result = _load_parsed_results_for_batch_job(MagicMock(), job, assessment) + assert result == [{"row_id": "row_0"}] + parse.assert_called_once() + + def test_provider_fallback_path(self) -> None: + job = SimpleNamespace(id=1, provider="openai", raw_output_url=None, provider_output_file_id="f1", organization_id=1) + assessment = SimpleNamespace(project_id=1, organization_id=1) + with patch.object(export_mod, "_get_batch_provider", return_value=MagicMock()), patch.object( + export_mod, "download_batch_results", return_value=[{"k": 1}] + ), patch.object(export_mod, "parse_assessment_output", return_value=[{"row_id": "row_1"}]): + result = 
_load_parsed_results_for_batch_job(MagicMock(), job, assessment) + assert result == [{"row_id": "row_1"}] + + def test_returns_none_without_outputs(self) -> None: + job = SimpleNamespace(id=1, provider="openai", raw_output_url=None, provider_output_file_id=None) + assessment = SimpleNamespace(project_id=1, organization_id=1) + assert _load_parsed_results_for_batch_job(MagicMock(), job, assessment) is None + + +class TestLoadL2ResultsForRun: + def test_keys_by_row_id(self) -> None: + run = SimpleNamespace() + assessment = SimpleNamespace() + with patch.object(export_mod, "_stage_batch_job", return_value=SimpleNamespace()), patch.object( + export_mod, + "_load_parsed_results_for_batch_job", + return_value=[{"row_id": "row_0", "output": "x"}, {"no_row": 1}], + ): + merged = _load_l2_results_for_run(MagicMock(), run, assessment) + assert set(merged) == {"row_0"} + + def test_empty_when_no_batch(self) -> None: + with patch.object(export_mod, "_stage_batch_job", return_value=None): + merged = _load_l2_results_for_run(MagicMock(), SimpleNamespace(), SimpleNamespace()) + assert merged == {} diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py index c010d997d..c807d415d 100644 --- a/backend/app/tests/assessment/test_pipeline.py +++ b/backend/app/tests/assessment/test_pipeline.py @@ -1,9 +1,15 @@ """Tests for prefilter settings + pipeline stage ordering.""" -from app.models.assessment import Stage +from types import SimpleNamespace + +import pytest + +from app.models.assessment import Stage, StageStatus from app.services.assessment.prefilter import resolve_prefilter_settings from app.services.assessment.stages import ( + advance_or_finalize, build_pipeline, + build_prefilter_requests, next_stage, ordered_stages, ) @@ -48,3 +54,53 @@ def test_next_stage(self) -> None: Stage.PRE_FILTER_DUPLICATE_DETECTION ) assert next_stage(pipeline, Stage.L2_ASSESSMENT) is None + + +class TestAdvanceOrFinalize: + def 
test_advances_to_next_pending_stage(self) -> None: + run = SimpleNamespace( + pipeline=build_pipeline(_FULL_INPUT), + stage=Stage.PRE_FILTER_TOPIC_RELEVANCE, + stage_status=StageStatus.COMPLETED, + status="processing", + ) + nxt = advance_or_finalize(run) + assert nxt == Stage.PRE_FILTER_DUPLICATE_DETECTION + assert run.stage == Stage.PRE_FILTER_DUPLICATE_DETECTION + assert run.stage_status == StageStatus.PENDING + + def test_finalizes_after_last_stage(self) -> None: + run = SimpleNamespace( + pipeline=build_pipeline({}), + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.COMPLETED, + status="processing", + ) + assert advance_or_finalize(run) is None + assert run.stage == Stage.COMPLETED + assert run.stage_status == StageStatus.COMPLETED + assert run.status == "completed" + + +class TestBuildPrefilterRequests: + _CFG = { + "tr_columns": ["Problem"], + "tr_prompt": "rubric", + "dup_columns": ["Problem"], + } + + def test_topic_relevance_stage(self) -> None: + lines = build_prefilter_requests( + Stage.PRE_FILTER_TOPIC_RELEVANCE, [(0, {"Problem": "p"})], self._CFG + ) + assert len(lines) == 1 + + def test_duplicate_detection_stage(self) -> None: + lines = build_prefilter_requests( + Stage.PRE_FILTER_DUPLICATE_DETECTION, [(0, {"Problem": "p"})], self._CFG + ) + assert len(lines) == 1 + + def test_unknown_stage_raises(self) -> None: + with pytest.raises(ValueError): + build_prefilter_requests("BOGUS", [(0, {"Problem": "p"})], self._CFG) diff --git a/backend/app/tests/assessment/test_prefilter_batching.py b/backend/app/tests/assessment/test_prefilter_batching.py index 4c631daf2..65dd9c8d5 100644 --- a/backend/app/tests/assessment/test_prefilter_batching.py +++ b/backend/app/tests/assessment/test_prefilter_batching.py @@ -4,6 +4,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest +from celery.exceptions import SoftTimeLimitExceeded + from app.models.assessment import Stage, StageStatus from app.services.assessment import 
tasks @@ -150,3 +153,198 @@ def test_falls_back_to_full_range_when_nothing_persisted(self) -> None: ) result = tasks._accepted_indices(MagicMock(), run, total_rows=3, project_id=1) assert result == [0, 1, 2] + + +class TestGuardEntrypoint: + def test_unexpected_exception_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=RuntimeError("boom") + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(RuntimeError): + tasks.execute_assessment_pipeline(5, 1, 1) + mark.assert_called_once() + + def test_soft_timeout_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=SoftTimeLimitExceeded() + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(SoftTimeLimitExceeded): + tasks.execute_assessment_pipeline(5, 1, 1) + mark.assert_called_once() + + +class TestMarkRunFailed: + def test_marks_non_terminal_run_failed(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT, stage_status=StageStatus.PROCESSING) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "update_assessment_run_status") as upd, patch.object( + tasks, "recompute_assessment_status" + ): + tasks._mark_run_failed(5, "dead") + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() + + def test_skips_terminal_run(self) -> None: + run = _run(stage=Stage.COMPLETED) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "update_assessment_run_status") as upd: + tasks._mark_run_failed(5, "dead") + upd.assert_not_called() + + +class TestDispatch: + def test_dispatch_enqueues_task(self) -> None: + with patch.object(tasks, "run_assessment_pipeline") as task: + tasks._dispatch(5, 1, 2) + task.delay.assert_called_once() + assert task.delay.call_args.kwargs["run_id"] == 5 + 
+ +class TestResolveRunContext: + def test_success(self) -> None: + session = MagicMock() + run = _run() + session.get.return_value = SimpleNamespace(dataset_id=3) + with patch.object( + tasks, "get_assessment_dataset_by_id", return_value=MagicMock() + ), patch.object( + tasks, "resolve_evaluation_config", return_value=({"x": 1}, None) + ): + _a, _d, blob, err = tasks._resolve_run_context(session, run, 1, 1) + assert blob == {"x": 1} + assert err is None + + def test_missing_parent(self) -> None: + session = MagicMock() + session.get.return_value = None + _a, _d, blob, err = tasks._resolve_run_context(session, _run(), 1, 1) + assert blob is None + assert "not found" in err + + def test_config_error(self) -> None: + session = MagicMock() + session.get.return_value = SimpleNamespace(dataset_id=3) + with patch.object( + tasks, "get_assessment_dataset_by_id", return_value=MagicMock() + ), patch.object( + tasks, "resolve_evaluation_config", return_value=(None, "bad config") + ): + _a, _d, blob, err = tasks._resolve_run_context(session, _run(), 1, 1) + assert blob is None + assert "bad config" in err + + +class TestAcceptedIndicesFallback: + def test_recomputes_from_gate_batch(self) -> None: + run = _run( + pipeline={ + "stages": [ + {"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "order": 1}, + {"stage": Stage.L2_ASSESSMENT, "order": 2}, + ] + }, + stage=Stage.L2_ASSESSMENT, + stage_batches={Stage.PRE_FILTER_TOPIC_RELEVANCE: 1}, + ) + with patch.object(tasks, "get_batch_job", return_value=SimpleNamespace(provider="openai")), patch.object( + tasks, "load_raw_batch_results", return_value=[] + ), patch.object(tasks, "parse_assessment_output", return_value=[]), patch.dict( + tasks.STAGE_PARSERS, + {Stage.PRE_FILTER_TOPIC_RELEVANCE: lambda outs: {0: {"verdict": True}, 1: {"verdict": False}}}, + ): + result = tasks._accepted_indices(MagicMock(), run, total_rows=3, project_id=1) + # Only row 0 passed the gate. 
+ assert result == [0] + + +class TestSubmitStageBranches: + def test_config_error_fails_run(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT, stage_status=StageStatus.PENDING) + with patch.object( + tasks, "_resolve_run_context", return_value=(None, None, None, "boom") + ), patch.object(tasks, "update_assessment_run_status") as upd, patch.object( + tasks, "recompute_assessment_status" + ): + tasks._submit_stage(MagicMock(), run, 1, 1) + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() + + def test_empty_dataset_fails_run(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT, stage_status=StageStatus.PENDING) + with patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), patch.object(tasks, "_load_dataset_rows", return_value=[]), patch.object( + tasks, "update_assessment_run_status" + ) as upd, patch.object(tasks, "recompute_assessment_status"): + tasks._submit_stage(MagicMock(), run, 1, 1) + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() + + def test_submits_l2_batch(self) -> None: + run = _run( + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PENDING, + stage_batches={}, + ) + batch_job = SimpleNamespace(id=8, total_items=2) + with patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), patch.object(tasks, "_load_dataset_rows", return_value=[{"a": "1"}] * 3), patch.object( + tasks, "_accepted_indices", return_value=[0, 1] + ), patch.object(tasks, "flag_modified"), patch.object( + tasks, "submit_assessment_batch", return_value=batch_job + ), patch.object(tasks, "recompute_assessment_status"): + tasks._submit_stage(MagicMock(), run, 1, 1) + assert run.total_items == 2 + assert run.stage_batches[Stage.L2_ASSESSMENT] == 8 + + def test_unknown_stage_raises(self) -> None: + run = _run(stage="BOGUS", stage_status=StageStatus.PENDING, stage_batches={}) + with 
patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), patch.object(tasks, "_load_dataset_rows", return_value=[{"a": "1"}]), patch.object( + tasks, "_accepted_indices", return_value=[0] + ): + with pytest.raises(ValueError): + tasks._submit_stage(MagicMock(), run, 1, 1) + + +class TestPersistAdvance: + def test_dispatches_next_stage(self) -> None: + run = _run() + with patch.object(tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT), patch.object( + tasks, "recompute_assessment_status" + ), patch.object(tasks, "_dispatch") as dispatch: + tasks._persist_advance(MagicMock(), run, 1, 1) + dispatch.assert_called_once() + + def test_finalize_does_not_dispatch(self) -> None: + run = _run() + with patch.object(tasks, "advance_or_finalize", return_value=None), patch.object( + tasks, "recompute_assessment_status" + ), patch.object(tasks, "_dispatch") as dispatch: + tasks._persist_advance(MagicMock(), run, 1, 1) + dispatch.assert_not_called() + + def test_enqueue_failure_marks_failed(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT) + with patch.object(tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT), patch.object( + tasks, "recompute_assessment_status" + ), patch.object(tasks, "_dispatch", side_effect=RuntimeError("broker down")), patch.object( + tasks, "update_assessment_run_status" + ) as upd: + tasks._persist_advance(MagicMock(), run, 1, 1) + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() From 827547d9c491ed04397de21358310ac011d1aa25 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 14:22:45 +0530 Subject: [PATCH 15/16] feat: improve test readability by formatting patch calls in assessment tests --- backend/app/tests/assessment/test_export.py | 68 +++++++++++++++---- .../assessment/test_prefilter_batching.py | 63 +++++++++++------ 2 files changed, 99 insertions(+), 32 
deletions(-) diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py index 9dfd6cbd2..a13a1e03a 100644 --- a/backend/app/tests/assessment/test_export.py +++ b/backend/app/tests/assessment/test_export.py @@ -741,12 +741,25 @@ class TestLoadPrefilterResults: def test_merges_tr_and_dup_annotations(self) -> None: run = SimpleNamespace(id=5) assessment = SimpleNamespace(project_id=1) - with patch.object(export_mod, "_stage_batch_job", return_value=SimpleNamespace(provider="openai")), patch.object( + with patch.object( + export_mod, + "_stage_batch_job", + return_value=SimpleNamespace(provider="openai"), + ), patch.object( export_mod, "load_raw_batch_results", return_value=[] - ), patch.object(export_mod, "parse_assessment_output", return_value=[]), patch.object( + ), patch.object( + export_mod, "parse_assessment_output", return_value=[] + ), patch.object( export_mod, "parse_topic_relevance_results", - return_value={0: {"verdict": True, "decision": "ACCEPT", "reasoning": "ok", "column_relevance": {"a": True}}}, + return_value={ + 0: { + "verdict": True, + "decision": "ACCEPT", + "reasoning": "ok", + "column_relevance": {"a": True}, + } + }, ), patch.object( export_mod, "parse_duplicate_detection_results", @@ -760,7 +773,11 @@ def test_merges_tr_and_dup_annotations(self) -> None: def test_tr_load_failure_is_swallowed(self) -> None: run = SimpleNamespace(id=5) assessment = SimpleNamespace(project_id=1) - with patch.object(export_mod, "_stage_batch_job", return_value=SimpleNamespace(provider="openai")), patch.object( + with patch.object( + export_mod, + "_stage_batch_job", + return_value=SimpleNamespace(provider="openai"), + ), patch.object( export_mod, "load_raw_batch_results", side_effect=RuntimeError("s3 down") ): out = _load_prefilter_results(MagicMock(), run, assessment) @@ -769,28 +786,49 @@ def test_tr_load_failure_is_swallowed(self) -> None: class TestLoadParsedResultsForBatchJob: def test_object_store_path(self) -> 
None: - job = SimpleNamespace(id=1, provider="openai", raw_output_url="s3://x", provider_output_file_id=None) + job = SimpleNamespace( + id=1, + provider="openai", + raw_output_url="s3://x", + provider_output_file_id=None, + ) assessment = SimpleNamespace(project_id=1, organization_id=1) storage = MagicMock() storage.stream.return_value.read.return_value.decode.return_value = "raw" - with patch.object(export_mod, "get_cloud_storage", return_value=storage), patch.object( + with patch.object( + export_mod, "get_cloud_storage", return_value=storage + ), patch.object( export_mod, "parse_stored_results", return_value=[{"k": 1}] - ), patch.object(export_mod, "parse_assessment_output", return_value=[{"row_id": "row_0"}]) as parse: + ), patch.object( + export_mod, "parse_assessment_output", return_value=[{"row_id": "row_0"}] + ) as parse: result = _load_parsed_results_for_batch_job(MagicMock(), job, assessment) assert result == [{"row_id": "row_0"}] parse.assert_called_once() def test_provider_fallback_path(self) -> None: - job = SimpleNamespace(id=1, provider="openai", raw_output_url=None, provider_output_file_id="f1", organization_id=1) + job = SimpleNamespace( + id=1, + provider="openai", + raw_output_url=None, + provider_output_file_id="f1", + organization_id=1, + ) assessment = SimpleNamespace(project_id=1, organization_id=1) - with patch.object(export_mod, "_get_batch_provider", return_value=MagicMock()), patch.object( + with patch.object( + export_mod, "_get_batch_provider", return_value=MagicMock() + ), patch.object( export_mod, "download_batch_results", return_value=[{"k": 1}] - ), patch.object(export_mod, "parse_assessment_output", return_value=[{"row_id": "row_1"}]): + ), patch.object( + export_mod, "parse_assessment_output", return_value=[{"row_id": "row_1"}] + ): result = _load_parsed_results_for_batch_job(MagicMock(), job, assessment) assert result == [{"row_id": "row_1"}] def test_returns_none_without_outputs(self) -> None: - job = SimpleNamespace(id=1, 
provider="openai", raw_output_url=None, provider_output_file_id=None) + job = SimpleNamespace( + id=1, provider="openai", raw_output_url=None, provider_output_file_id=None + ) assessment = SimpleNamespace(project_id=1, organization_id=1) assert _load_parsed_results_for_batch_job(MagicMock(), job, assessment) is None @@ -799,7 +837,9 @@ class TestLoadL2ResultsForRun: def test_keys_by_row_id(self) -> None: run = SimpleNamespace() assessment = SimpleNamespace() - with patch.object(export_mod, "_stage_batch_job", return_value=SimpleNamespace()), patch.object( + with patch.object( + export_mod, "_stage_batch_job", return_value=SimpleNamespace() + ), patch.object( export_mod, "_load_parsed_results_for_batch_job", return_value=[{"row_id": "row_0", "output": "x"}, {"no_row": 1}], @@ -809,5 +849,7 @@ def test_keys_by_row_id(self) -> None: def test_empty_when_no_batch(self) -> None: with patch.object(export_mod, "_stage_batch_job", return_value=None): - merged = _load_l2_results_for_run(MagicMock(), SimpleNamespace(), SimpleNamespace()) + merged = _load_l2_results_for_run( + MagicMock(), SimpleNamespace(), SimpleNamespace() + ) assert merged == {} diff --git a/backend/app/tests/assessment/test_prefilter_batching.py b/backend/app/tests/assessment/test_prefilter_batching.py index 65dd9c8d5..4ecdf52e9 100644 --- a/backend/app/tests/assessment/test_prefilter_batching.py +++ b/backend/app/tests/assessment/test_prefilter_batching.py @@ -252,13 +252,22 @@ def test_recomputes_from_gate_batch(self) -> None: stage=Stage.L2_ASSESSMENT, stage_batches={Stage.PRE_FILTER_TOPIC_RELEVANCE: 1}, ) - with patch.object(tasks, "get_batch_job", return_value=SimpleNamespace(provider="openai")), patch.object( - tasks, "load_raw_batch_results", return_value=[] - ), patch.object(tasks, "parse_assessment_output", return_value=[]), patch.dict( + with patch.object( + tasks, "get_batch_job", return_value=SimpleNamespace(provider="openai") + ), patch.object(tasks, "load_raw_batch_results", 
return_value=[]), patch.object( + tasks, "parse_assessment_output", return_value=[] + ), patch.dict( tasks.STAGE_PARSERS, - {Stage.PRE_FILTER_TOPIC_RELEVANCE: lambda outs: {0: {"verdict": True}, 1: {"verdict": False}}}, + { + Stage.PRE_FILTER_TOPIC_RELEVANCE: lambda outs: { + 0: {"verdict": True}, + 1: {"verdict": False}, + } + }, ): - result = tasks._accepted_indices(MagicMock(), run, total_rows=3, project_id=1) + result = tasks._accepted_indices( + MagicMock(), run, total_rows=3, project_id=1 + ) # Only row 0 passed the gate. assert result == [0] @@ -283,7 +292,9 @@ def test_empty_dataset_fails_run(self) -> None: return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), ), patch.object(tasks, "_load_dataset_rows", return_value=[]), patch.object( tasks, "update_assessment_run_status" - ) as upd, patch.object(tasks, "recompute_assessment_status"): + ) as upd, patch.object( + tasks, "recompute_assessment_status" + ): tasks._submit_stage(MagicMock(), run, 1, 1) assert run.stage_status == StageStatus.FAILED upd.assert_called_once() @@ -299,11 +310,17 @@ def test_submits_l2_batch(self) -> None: tasks, "_resolve_run_context", return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), - ), patch.object(tasks, "_load_dataset_rows", return_value=[{"a": "1"}] * 3), patch.object( + ), patch.object( + tasks, "_load_dataset_rows", return_value=[{"a": "1"}] * 3 + ), patch.object( tasks, "_accepted_indices", return_value=[0, 1] - ), patch.object(tasks, "flag_modified"), patch.object( + ), patch.object( + tasks, "flag_modified" + ), patch.object( tasks, "submit_assessment_batch", return_value=batch_job - ), patch.object(tasks, "recompute_assessment_status"): + ), patch.object( + tasks, "recompute_assessment_status" + ): tasks._submit_stage(MagicMock(), run, 1, 1) assert run.total_items == 2 assert run.stage_batches[Stage.L2_ASSESSMENT] == 8 @@ -314,7 +331,9 @@ def test_unknown_stage_raises(self) -> None: tasks, "_resolve_run_context", 
return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), - ), patch.object(tasks, "_load_dataset_rows", return_value=[{"a": "1"}]), patch.object( + ), patch.object( + tasks, "_load_dataset_rows", return_value=[{"a": "1"}] + ), patch.object( tasks, "_accepted_indices", return_value=[0] ): with pytest.raises(ValueError): @@ -324,25 +343,31 @@ def test_unknown_stage_raises(self) -> None: class TestPersistAdvance: def test_dispatches_next_stage(self) -> None: run = _run() - with patch.object(tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT), patch.object( - tasks, "recompute_assessment_status" - ), patch.object(tasks, "_dispatch") as dispatch: + with patch.object( + tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT + ), patch.object(tasks, "recompute_assessment_status"), patch.object( + tasks, "_dispatch" + ) as dispatch: tasks._persist_advance(MagicMock(), run, 1, 1) dispatch.assert_called_once() def test_finalize_does_not_dispatch(self) -> None: run = _run() - with patch.object(tasks, "advance_or_finalize", return_value=None), patch.object( - tasks, "recompute_assessment_status" - ), patch.object(tasks, "_dispatch") as dispatch: + with patch.object( + tasks, "advance_or_finalize", return_value=None + ), patch.object(tasks, "recompute_assessment_status"), patch.object( + tasks, "_dispatch" + ) as dispatch: tasks._persist_advance(MagicMock(), run, 1, 1) dispatch.assert_not_called() def test_enqueue_failure_marks_failed(self) -> None: run = _run(stage=Stage.L2_ASSESSMENT) - with patch.object(tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT), patch.object( - tasks, "recompute_assessment_status" - ), patch.object(tasks, "_dispatch", side_effect=RuntimeError("broker down")), patch.object( + with patch.object( + tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT + ), patch.object(tasks, "recompute_assessment_status"), patch.object( + tasks, "_dispatch", side_effect=RuntimeError("broker down") + ), 
patch.object( tasks, "update_assessment_run_status" ) as upd: tasks._persist_advance(MagicMock(), run, 1, 1) From 8bd54a7d5e6f193d6395099edd80f5babbd605ef Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 4 Jun 2026 14:26:11 +0530 Subject: [PATCH 16/16] feat: add commented alternative model and duplicate store configurations in assessment prefilter constants --- backend/app/services/assessment/prefilter/constants.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend/app/services/assessment/prefilter/constants.py b/backend/app/services/assessment/prefilter/constants.py index 429044ade..1fe54ba85 100644 --- a/backend/app/services/assessment/prefilter/constants.py +++ b/backend/app/services/assessment/prefilter/constants.py @@ -6,6 +6,9 @@ # Provider + model that run the batch prefilter stages (topic relevance, dup check). ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "openai" ASSESSMENT_PREFILTER_MODEL: str = "gpt-5-mini" +# ASSESSMENT_PREFILTER_MODEL: str = "gemini-3.1-flash-lite" + # File-search/vector store holding the corpus for duplicate detection. ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "vs_6a20339fbc148191867fd06d29133278" +# ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "fileSearchStores/inquilabcorpus-782mxjcwisaz"