diff --git a/changelog.d/codex-runtime-versioned-report-cache.fixed.md b/changelog.d/codex-runtime-versioned-report-cache.fixed.md new file mode 100644 index 000000000..6f97c23de --- /dev/null +++ b/changelog.d/codex-runtime-versioned-report-cache.fixed.md @@ -0,0 +1 @@ +Version economy caches and report output reuse against the full runtime, and strip stale congressional district payloads from legacy US reports so clients refresh district outcomes from live state summaries. diff --git a/policyengine_api/constants.py b/policyengine_api/constants.py index 1bfe6b53d..fa7b6730b 100644 --- a/policyengine_api/constants.py +++ b/policyengine_api/constants.py @@ -1,6 +1,7 @@ from pathlib import Path from importlib.metadata import distributions from datetime import datetime +import hashlib REPO = Path(__file__).parents[1] GET = "GET" @@ -17,14 +18,85 @@ "policyengine_ng", "policyengine_il", ) + + +def _normalize_distribution_name(name: str | None) -> str: + if name is None: + return "" + return name.replace("_", "-").lower() + + +def _resolve_distribution_version( + dist_versions: dict[str, str], *package_names: str +) -> str: + for package_name in package_names: + version = dist_versions.get(_normalize_distribution_name(package_name)) + if version is not None: + return version + return "0.0.0" + + try: - _dist_versions = {d.metadata["Name"]: d.version for d in distributions()} + _dist_versions = { + _normalize_distribution_name(d.metadata["Name"]): d.version + for d in distributions() + } COUNTRY_PACKAGE_VERSIONS = { - country: _dist_versions.get(package_name.replace("_", "-"), "0.0.0") + country: _resolve_distribution_version(_dist_versions, package_name) for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES) } + POLICYENGINE_CORE_VERSION = _resolve_distribution_version( + _dist_versions, "policyengine-core", "policyengine" + ) except Exception: COUNTRY_PACKAGE_VERSIONS = {country: "0.0.0" for country in COUNTRIES} + POLICYENGINE_CORE_VERSION = "0.0.0" + +RUNTIME_CACHE_SCHEMA_VERSIONS = { + "economy_impact": 1, + "report_output": 1, +} + + +def _build_runtime_cache_version( + scope: str, country_id: str, caller_version: str | None = None +) -> str: + """ + Build a compact version token for cache keys stored in legacy VARCHAR(10) + columns. The token changes whenever the relevant runtime or payload schema + changes, even if the country package version is unchanged. + """ + schema_version = str(RUNTIME_CACHE_SCHEMA_VERSIONS[scope]) + prefix = "e" if scope == "economy_impact" else "r" + digest_length = 10 - len(prefix) - len(schema_version) + if digest_length < 4: + raise ValueError( + f"Runtime cache version for {scope} does not fit in VARCHAR(10)" + ) + + raw = "|".join( + ( + scope, + country_id, + caller_version or COUNTRY_PACKAGE_VERSIONS.get(country_id, "0.0.0"), + COUNTRY_PACKAGE_VERSIONS.get(country_id, "0.0.0"), + POLICYENGINE_CORE_VERSION, + schema_version, + ) + ) + digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:digest_length] + return f"{prefix}{schema_version}{digest}" + + +def get_economy_impact_cache_version( + country_id: str, caller_version: str | None = None +) -> str: + return _build_runtime_cache_version("economy_impact", country_id, caller_version) + + +def get_report_output_cache_version(country_id: str) -> str: + return _build_runtime_cache_version("report_output", country_id) + # Valid region types for each country # These define the geographic scope categories for regions diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 94638c53e..93256d778 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -160,7 +160,7 @@ def update_report_output(country_id: str) -> Response: try: # First check if the report output exists - existing_report = report_output_service.get_report_output(report_id) + existing_report = report_output_service.get_stored_report_output(report_id) if existing_report is None: raise NotFound(f"Report #{report_id} not found.") @@ -176,8 +176,9 @@ def update_report_output(country_id: str) -> Response: if not success: raise BadRequest("No fields to update") - # Get the updated record - updated_report = report_output_service.get_report_output(report_id) + # Get the updated stored record so stale-runtime jobs do not appear to + # complete the current runtime lineage in the PATCH response. + updated_report = report_output_service.get_stored_report_output(report_id) response_body = dict( status="ok", diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 031696286..95eae9838 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -8,6 +8,7 @@ EXECUTION_STATUSES_SUCCESS, EXECUTION_STATUSES_FAILURE, EXECUTION_STATUSES_PENDING, + get_economy_impact_cache_version, ) from policyengine_api.gcp_logging import logger from policyengine_api.libs.simulation_api_modal import simulation_api_modal @@ -164,6 +165,8 @@ def get_economic_impact( if country_id == "uk": country_package_version = None + cache_version = get_economy_impact_cache_version(country_id, api_version) + economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( { "process_id": process_id, @@ -174,7 +177,7 @@ def get_economic_impact( "dataset": dataset, "time_period": time_period, "options": options, - "api_version": api_version, + "api_version": cache_version, "target": target, "model_version": country_package_version, "data_version": get_dataset_version(country_id), diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index c34c62f79..3200ec6e8 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1,10 +1,52 @@ from sqlalchemy.engine.row import Row from policyengine_api.data import database -from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS +from policyengine_api.constants import get_report_output_cache_version class ReportOutputService: + def _get_report_output_row(self, report_output_id: int) -> dict | None: + row: Row | None = database.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_output_id,), + ).fetchone() + return dict(row) if row is not None else None + + def get_stored_report_output(self, report_output_id: int) -> dict | None: + """ + Get the raw stored report output row by ID without aliasing to the + current runtime lineage. This is useful for mutation paths, which must + update the originally addressed row rather than a resolved alias. + """ + return self._get_report_output_row(report_output_id) + + def _is_current_report_output(self, report_output: dict) -> bool: + return report_output.get("api_version") == get_report_output_cache_version( + report_output["country_id"] + ) + + def _get_or_create_current_report_output(self, report_output: dict) -> dict: + current_report = self.find_existing_report_output( + country_id=report_output["country_id"], + simulation_1_id=report_output["simulation_1_id"], + simulation_2_id=report_output["simulation_2_id"], + year=report_output["year"], + ) + if current_report is not None: + return current_report + + return self.create_report_output( + country_id=report_output["country_id"], + simulation_1_id=report_output["simulation_1_id"], + simulation_2_id=report_output["simulation_2_id"], + year=report_output["year"], + ) + + def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict: + aliased_report = dict(report_output) + aliased_report["id"] = report_output_id + return aliased_report + def find_existing_report_output( self, country_id: str, @@ -25,11 +67,11 @@ def find_existing_report_output( dict | None: The existing report output data or None if not found. """ print("Checking for existing report output") + api_version = get_report_output_cache_version(country_id) try: - # Check for existing record with the same simulation IDs and year (excluding api_version) - query = "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ?" - params = [country_id, simulation_1_id, year] + query = "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? AND api_version = ?" + params = [country_id, simulation_1_id, year, api_version] if simulation_2_id is not None: query += " AND simulation_2_id = ?" @@ -37,6 +79,8 @@ def find_existing_report_output( else: query += " AND simulation_2_id IS NULL" + query += " ORDER BY id DESC" + row = database.query(query, tuple(params)).fetchone() existing_report = None @@ -71,9 +115,18 @@ def create_report_output( dict: The created report output record. """ print("Creating new report output") - api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + api_version = get_report_output_cache_version(country_id) try: + existing_report = self.find_existing_report_output( + country_id, simulation_1_id, simulation_2_id, year + ) + if existing_report is not None: + print( + f"Reusing existing report output with ID: {existing_report['id']}" + ) + return existing_report + # Insert with default status 'pending' if simulation_2_id is not None: database.query( @@ -132,18 +185,15 @@ def get_report_output(self, report_output_id: int) -> dict | None: f"Invalid report output ID: {report_output_id}. Must be a positive integer." ) - row: Row | None = database.query( - "SELECT * FROM report_outputs WHERE id = ?", - (report_output_id,), - ).fetchone() + report_output = self._get_report_output_row(report_output_id) + if report_output is None: + return None - report_output = None - if row is not None: - report_output = dict(row) - # Keep output as JSON string - frontend expects string format - # Frontend will parse it using JSON.parse() + if self._is_current_report_output(report_output): + return report_output - return report_output + current_report = self._get_or_create_current_report_output(report_output) + return self._alias_report_output(report_output_id, current_report) except Exception as e: print( @@ -172,10 +222,12 @@ def update_report_output( bool: True if update was successful. """ print(f"Updating report output {report_id}") - # Automatically update api_version on every update to latest - api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) try: + requested_report = self._get_report_output_row(report_id) + if requested_report is None: + raise Exception(f"Report output #{report_id} not found") + # Build the update query dynamically based on provided fields update_fields = [] update_values = [] @@ -193,16 +245,12 @@ def update_report_output( update_fields.append("error_message = ?") update_values.append(error_message) - # Always update API version - update_fields.append("api_version = ?") - update_values.append(api_version) - if not update_fields: print("No fields to update") return False # Add report_id to the end of values for WHERE clause - update_values.append(report_id) + update_values.append(requested_report["id"]) query = f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ?" diff --git a/tests/fixtures/services/report_output_fixtures.py b/tests/fixtures/services/report_output_fixtures.py index 5f8d5bb76..3d4ca5a14 100644 --- a/tests/fixtures/services/report_output_fixtures.py +++ b/tests/fixtures/services/report_output_fixtures.py @@ -1,11 +1,13 @@ import pytest import json +from policyengine_api.constants import get_report_output_cache_version + valid_report_data = { "country_id": "us", "simulation_1_id": 1, "simulation_2_id": None, - "api_version": "1.0.0", + "api_version": get_report_output_cache_version("us"), "status": "pending", "output": None, "error_message": None, diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index c49783bad..162d30c20 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -212,6 +212,40 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( ) assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + monkeypatch, + ): + cache_version = "e1cache01" + monkeypatch.setattr( + "policyengine_api.services.economy_service.get_economy_impact_cache_version", + lambda country_id, api_version=None: cache_version, + ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + + economy_service.get_economic_impact(**base_params) + + mock_reform_impacts_service.get_all_reform_impacts.assert_called_once_with( + MOCK_COUNTRY_ID, + MOCK_POLICY_ID, + MOCK_BASELINE_POLICY_ID, + MOCK_REGION, + MOCK_DATASET, + MOCK_TIME_PERIOD, + MOCK_OPTIONS_HASH, + cache_version, + ) + def test__given_exception__raises_error( self, economy_service, diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index c1f6b3e55..e3b63cbd3 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1,6 +1,7 @@ import pytest import json +from policyengine_api.constants import get_report_output_cache_version from policyengine_api.services.report_output_service import ReportOutputService from tests.fixtures.services.report_output_fixtures import ( @@ -47,10 +48,11 @@ def test_find_existing_report_output_not_found(self, test_db): def test_find_existing_report_output_with_null_simulation2(self, test_db): """Test finding reports where simulation_2_id is NULL.""" + api_version = get_report_output_cache_version("us") # GIVEN a report with NULL simulation_2_id test_db.query( "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 100, None, "complete", "1.0.0", "2025"), + ("us", 100, None, "complete", api_version, "2025"), ) # WHEN we search for it @@ -69,14 +71,15 @@ def test_find_existing_report_output_with_null_simulation2(self, test_db): def test_find_existing_report_output_with_year(self, test_db): """Test finding reports with different years.""" + api_version = get_report_output_cache_version("us") # GIVEN reports with different years for the same simulation test_db.query( "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 101, None, "complete", "1.0.0", "2025"), + ("us", 101, None, "complete", api_version, "2025"), ) test_db.query( "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 101, None, "complete", "1.0.0", "2024"), + ("us", 101, None, "complete", api_version, "2024"), ) # WHEN we search for the 2025 report @@ -108,6 +111,25 @@ def test_find_existing_report_output_with_year(self, test_db): # AND the two reports should have different IDs assert result_2025["id"] != result_2024["id"] + def test_find_existing_report_output_ignores_stale_runtime_version(self, test_db): + current_version = get_report_output_cache_version("us") + stale_version = "r0stale1" + assert stale_version != current_version + + test_db.query( + "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", + ("us", 102, None, "complete", stale_version, "2025"), + ) + + result = service.find_existing_report_output( + country_id="us", + simulation_1_id=102, + simulation_2_id=None, + year="2025", + ) + + assert result is None + class TestCreateReportOutput: """Test creating new report outputs in the database.""" @@ -270,7 +292,7 @@ def test_get_report_output_with_json_output(self, test_db): None, "complete", json.dumps(test_output), - "1.0.0", + get_report_output_cache_version("us"), "2025", ), ) @@ -288,6 +310,95 @@ def test_get_report_output_with_json_output(self, test_db): assert result["year"] == "2025" # Frontend will parse this string + def test_get_report_output_resolves_stale_id_to_current_runtime_row(self, test_db): + stale_output = { + "budget": {"budgetary_impact": 1}, + "congressional_district_impact": { + "districts": [ + { + "district": "AL-01", + "average_household_income_change": 120, + "relative_household_income_change": 0.01, + } + ] + }, + } + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + ( + "us", + 2, + None, + "complete", + json.dumps(stale_output), + "r0stale1", + "2025", + ), + ) + + stale_record = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + current_version = get_report_output_cache_version("us") + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + ( + "us", + 2, + None, + "complete", + json.dumps({"budget": {"budgetary_impact": 2}}), + current_version, + "2025", + ), + ) + + current_record = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + result = service.get_report_output(report_output_id=stale_record["id"]) + assert result is not None + assert result["id"] == stale_record["id"] + assert result["api_version"] == current_record["api_version"] + assert result["output"] == current_record["output"] + + def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_db): + stale_version = "r0stale1" + current_version = get_report_output_cache_version("us") + + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, api_version, year) + VALUES (?, ?, ?, ?, ?, ?)""", + ("us", 3, None, "complete", stale_version, "2025"), + ) + + stale_record = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + result = service.get_report_output(report_output_id=stale_record["id"]) + + assert result is not None + assert result["id"] == stale_record["id"] + assert result["api_version"] == current_version + assert result["status"] == "pending" + assert result["output"] is None + + current_rows = test_db.query( + "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? ORDER BY id ASC", + ("us", 3, "2025"), + ).fetchall() + assert len(current_rows) == 2 + assert current_rows[0]["api_version"] == stale_version + assert current_rows[1]["api_version"] == current_version + def test_get_report_output_invalid_id(self, test_db): """Test that invalid report IDs are handled properly.""" # GIVEN any database state @@ -406,28 +517,49 @@ def test_update_report_output_partial_update(self, test_db, existing_report_reco assert result["status"] == "complete" assert result["output"] is None # Should remain unchanged - def test_update_report_output_no_fields(self, test_db, existing_report_record): - """Test that update with no optional fields still updates API version.""" - # GIVEN an existing report - - # WHEN we call update with no optional fields + def test_update_report_output_no_fields_returns_false( + self, test_db, existing_report_record + ): success = service.update_report_output( country_id=existing_report_record["country_id"], report_id=existing_report_record["id"], ) - # THEN it should still succeed (API version always gets updated) - assert success is True + assert success is False - # AND the API version should be updated to the latest - result = test_db.query( - "SELECT * FROM report_outputs WHERE id = ?", - (existing_report_record["id"],), + def test_update_report_output_stale_id_keeps_stale_output_quarantined( + self, test_db + ): + stale_version = "r0stale1" + output_json = json.dumps({"result": "fresh"}) + + test_db.query( + """INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, api_version, year) + VALUES (?, ?, ?, ?, ?, ?)""", + ("us", 4, None, "pending", stale_version, "2025"), + ) + + stale_record = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" ).fetchone() - # API version should be updated to current version - from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS - expected_version = COUNTRY_PACKAGE_VERSIONS.get( - existing_report_record["country_id"] + success = service.update_report_output( + country_id="us", + report_id=stale_record["id"], + status="complete", + output=output_json, ) - assert result["api_version"] == expected_version + + assert success is True + + rows = test_db.query( + "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? ORDER BY id ASC", + ("us", 4, "2025"), + ).fetchall() + + assert len(rows) == 1 + assert rows[0]["id"] == stale_record["id"] + assert rows[0]["api_version"] == stale_version + assert rows[0]["status"] == "complete" + assert rows[0]["output"] == output_json diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py index f83d964ca..2cb9503b1 100644 --- a/tests/unit/test_constants.py +++ b/tests/unit/test_constants.py @@ -1,9 +1,9 @@ -import pytest - from policyengine_api.constants import ( UK_REGION_TYPES, US_REGION_TYPES, REGION_PREFIXES, + _normalize_distribution_name, + _resolve_distribution_version, ) @@ -83,3 +83,32 @@ def test__contains_congressional_district_prefix(self): def test__has_exactly_three_prefixes(self): assert len(REGION_PREFIXES["us"]) == 3 + + +class TestDistributionVersionHelpers: + def test__normalize_distribution_name(self): + assert _normalize_distribution_name("policyengine_core") == ( + "policyengine-core" + ) + assert _normalize_distribution_name("PolicyEngine-Core") == ( + "policyengine-core" + ) + + def test__resolve_distribution_version_prefers_first_available_name(self): + dist_versions = { + "policyengine-core": "3.23.6", + "policyengine": "0.12.1", + } + + assert ( + _resolve_distribution_version( + dist_versions, "policyengine-core", "policyengine" + ) + == "3.23.6" + ) + + def test__resolve_distribution_version_falls_back_to_default(self): + assert ( + _resolve_distribution_version({}, "policyengine-core", "policyengine") + == "0.0.0" + )