Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/codex-runtime-versioned-report-cache.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Version economy caches and report output reuse against the full runtime, and strip stale congressional district payloads from legacy US reports so clients refresh district outcomes from live state summaries.
76 changes: 74 additions & 2 deletions policyengine_api/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from importlib.metadata import distributions
from datetime import datetime
import hashlib

REPO = Path(__file__).parents[1]
GET = "GET"
Expand All @@ -17,14 +18,85 @@
"policyengine_ng",
"policyengine_il",
)


def _normalize_distribution_name(name: str | None) -> str:
if name is None:
return ""
return name.replace("_", "-").lower()


def _resolve_distribution_version(
dist_versions: dict[str, str], *package_names: str
) -> str:
for package_name in package_names:
version = dist_versions.get(_normalize_distribution_name(package_name))
if version is not None:
return version
return "0.0.0"


try:
_dist_versions = {d.metadata["Name"]: d.version for d in distributions()}
_dist_versions = {
_normalize_distribution_name(d.metadata["Name"]): d.version
for d in distributions()
}
COUNTRY_PACKAGE_VERSIONS = {
country: _dist_versions.get(package_name.replace("_", "-"), "0.0.0")
country: _resolve_distribution_version(_dist_versions, package_name)
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES)
}
POLICYENGINE_CORE_VERSION = _resolve_distribution_version(
_dist_versions, "policyengine-core", "policyengine"
)
except Exception:
COUNTRY_PACKAGE_VERSIONS = {country: "0.0.0" for country in COUNTRIES}
POLICYENGINE_CORE_VERSION = "0.0.0"

RUNTIME_CACHE_SCHEMA_VERSIONS = {
"economy_impact": 1,
"report_output": 1,
}


def _build_runtime_cache_version(
scope: str, country_id: str, caller_version: str | None = None
) -> str:
"""
Build a compact version token for cache keys stored in legacy VARCHAR(10)
columns. The token changes whenever the relevant runtime or payload schema
changes, even if the country package version is unchanged.
"""
schema_version = str(RUNTIME_CACHE_SCHEMA_VERSIONS[scope])
prefix = "e" if scope == "economy_impact" else "r"
digest_length = 10 - len(prefix) - len(schema_version)
if digest_length < 4:
raise ValueError(
f"Runtime cache version for {scope} does not fit in VARCHAR(10)"
)

raw = "|".join(
(
scope,
country_id,
caller_version or COUNTRY_PACKAGE_VERSIONS.get(country_id, "0.0.0"),
COUNTRY_PACKAGE_VERSIONS.get(country_id, "0.0.0"),
POLICYENGINE_CORE_VERSION,
schema_version,
)
)
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:digest_length]
return f"{prefix}{schema_version}{digest}"


def get_economy_impact_cache_version(
country_id: str, caller_version: str | None = None
) -> str:
return _build_runtime_cache_version("economy_impact", country_id, caller_version)


def get_report_output_cache_version(country_id: str) -> str:
return _build_runtime_cache_version("report_output", country_id)


# Valid region types for each country
# These define the geographic scope categories for regions
Expand Down
7 changes: 4 additions & 3 deletions policyengine_api/routes/report_output_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def update_report_output(country_id: str) -> Response:

try:
# First check if the report output exists
existing_report = report_output_service.get_report_output(report_id)
existing_report = report_output_service.get_stored_report_output(report_id)
if existing_report is None:
raise NotFound(f"Report #{report_id} not found.")

Expand All @@ -176,8 +176,9 @@ def update_report_output(country_id: str) -> Response:
if not success:
raise BadRequest("No fields to update")

# Get the updated record
updated_report = report_output_service.get_report_output(report_id)
# Get the updated stored record so stale-runtime jobs do not appear to
# complete the current runtime lineage in the PATCH response.
updated_report = report_output_service.get_stored_report_output(report_id)

response_body = dict(
status="ok",
Expand Down
5 changes: 4 additions & 1 deletion policyengine_api/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
EXECUTION_STATUSES_SUCCESS,
EXECUTION_STATUSES_FAILURE,
EXECUTION_STATUSES_PENDING,
get_economy_impact_cache_version,
)
from policyengine_api.gcp_logging import logger
from policyengine_api.libs.simulation_api_modal import simulation_api_modal
Expand Down Expand Up @@ -164,6 +165,8 @@ def get_economic_impact(
if country_id == "uk":
country_package_version = None

cache_version = get_economy_impact_cache_version(country_id, api_version)

economic_impact_setup_options = EconomicImpactSetupOptions.model_validate(
{
"process_id": process_id,
Expand All @@ -174,7 +177,7 @@ def get_economic_impact(
"dataset": dataset,
"time_period": time_period,
"options": options,
"api_version": api_version,
"api_version": cache_version,
"target": target,
"model_version": country_package_version,
"data_version": get_dataset_version(country_id),
Expand Down
92 changes: 70 additions & 22 deletions policyengine_api/services/report_output_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,52 @@
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
from policyengine_api.constants import get_report_output_cache_version


class ReportOutputService:
def _get_report_output_row(self, report_output_id: int) -> dict | None:
row: Row | None = database.query(
"SELECT * FROM report_outputs WHERE id = ?",
(report_output_id,),
).fetchone()
return dict(row) if row is not None else None

def get_stored_report_output(self, report_output_id: int) -> dict | None:
"""
Get the raw stored report output row by ID without aliasing to the
current runtime lineage. This is useful for mutation paths, which must
update the originally addressed row rather than a resolved alias.
"""
return self._get_report_output_row(report_output_id)

def _is_current_report_output(self, report_output: dict) -> bool:
return report_output.get("api_version") == get_report_output_cache_version(
report_output["country_id"]
)

def _get_or_create_current_report_output(self, report_output: dict) -> dict:
current_report = self.find_existing_report_output(
country_id=report_output["country_id"],
simulation_1_id=report_output["simulation_1_id"],
simulation_2_id=report_output["simulation_2_id"],
year=report_output["year"],
)
if current_report is not None:
return current_report

return self.create_report_output(
country_id=report_output["country_id"],
simulation_1_id=report_output["simulation_1_id"],
simulation_2_id=report_output["simulation_2_id"],
year=report_output["year"],
)

def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict:
aliased_report = dict(report_output)
aliased_report["id"] = report_output_id
return aliased_report

def find_existing_report_output(
self,
country_id: str,
Expand All @@ -25,18 +67,20 @@ def find_existing_report_output(
dict | None: The existing report output data or None if not found.
"""
print("Checking for existing report output")
api_version = get_report_output_cache_version(country_id)

try:
# Check for existing record with the same simulation IDs and year (excluding api_version)
query = "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ?"
params = [country_id, simulation_1_id, year]
query = "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? AND api_version = ?"
params = [country_id, simulation_1_id, year, api_version]

if simulation_2_id is not None:
query += " AND simulation_2_id = ?"
params.append(simulation_2_id)
else:
query += " AND simulation_2_id IS NULL"

query += " ORDER BY id DESC"

row = database.query(query, tuple(params)).fetchone()

existing_report = None
Expand Down Expand Up @@ -71,9 +115,18 @@ def create_report_output(
dict: The created report output record.
"""
print("Creating new report output")
api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id)
api_version = get_report_output_cache_version(country_id)

try:
existing_report = self.find_existing_report_output(
country_id, simulation_1_id, simulation_2_id, year
)
if existing_report is not None:
print(
f"Reusing existing report output with ID: {existing_report['id']}"
)
return existing_report

# Insert with default status 'pending'
if simulation_2_id is not None:
database.query(
Expand Down Expand Up @@ -132,18 +185,15 @@ def get_report_output(self, report_output_id: int) -> dict | None:
f"Invalid report output ID: {report_output_id}. Must be a positive integer."
)

row: Row | None = database.query(
"SELECT * FROM report_outputs WHERE id = ?",
(report_output_id,),
).fetchone()
report_output = self._get_report_output_row(report_output_id)
if report_output is None:
return None

report_output = None
if row is not None:
report_output = dict(row)
# Keep output as JSON string - frontend expects string format
# Frontend will parse it using JSON.parse()
if self._is_current_report_output(report_output):
return report_output

return report_output
current_report = self._get_or_create_current_report_output(report_output)
return self._alias_report_output(report_output_id, current_report)

except Exception as e:
print(
Expand Down Expand Up @@ -172,10 +222,12 @@ def update_report_output(
bool: True if update was successful.
"""
print(f"Updating report output {report_id}")
# Automatically update api_version on every update to latest
api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id)

try:
requested_report = self._get_report_output_row(report_id)
if requested_report is None:
raise Exception(f"Report output #{report_id} not found")

# Build the update query dynamically based on provided fields
update_fields = []
update_values = []
Expand All @@ -193,16 +245,12 @@ def update_report_output(
update_fields.append("error_message = ?")
update_values.append(error_message)

# Always update API version
update_fields.append("api_version = ?")
update_values.append(api_version)

if not update_fields:
print("No fields to update")
return False

# Add report_id to the end of values for WHERE clause
update_values.append(report_id)
update_values.append(requested_report["id"])

query = f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ?"

Expand Down
4 changes: 3 additions & 1 deletion tests/fixtures/services/report_output_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pytest
import json

from policyengine_api.constants import get_report_output_cache_version

valid_report_data = {
"country_id": "us",
"simulation_1_id": 1,
"simulation_2_id": None,
"api_version": "1.0.0",
"api_version": get_report_output_cache_version("us"),
"status": "pending",
"output": None,
"error_message": None,
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/services/test_economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,40 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params(
)
assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID

def test__given_runtime_cache_version__uses_versioned_economy_cache_key(
self,
economy_service,
base_params,
mock_country_package_versions,
mock_get_dataset_version,
mock_policy_service,
mock_reform_impacts_service,
mock_simulation_api,
mock_logger,
mock_datetime,
mock_numpy_random,
monkeypatch,
):
cache_version = "e1cache01"
monkeypatch.setattr(
"policyengine_api.services.economy_service.get_economy_impact_cache_version",
lambda country_id, api_version=None: cache_version,
)
mock_reform_impacts_service.get_all_reform_impacts.return_value = []

economy_service.get_economic_impact(**base_params)

mock_reform_impacts_service.get_all_reform_impacts.assert_called_once_with(
MOCK_COUNTRY_ID,
MOCK_POLICY_ID,
MOCK_BASELINE_POLICY_ID,
MOCK_REGION,
MOCK_DATASET,
MOCK_TIME_PERIOD,
MOCK_OPTIONS_HASH,
cache_version,
)

def test__given_exception__raises_error(
self,
economy_service,
Expand Down
Loading
Loading