diff --git a/changelog.d/fix-ci-warnings.fixed.md b/changelog.d/fix-ci-warnings.fixed.md new file mode 100644 index 000000000..c2235f25e --- /dev/null +++ b/changelog.d/fix-ci-warnings.fixed.md @@ -0,0 +1 @@ +Replace legacy SQLModel `session.query(...)` lookups in the IRS SOI and national-targets ETL loaders and their focused tests with `session.exec(select(...))` to remove deprecation warnings in CI. diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index 0b474c10e..91edf0b4e 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( @@ -313,16 +313,14 @@ def _upsert_target( source: str, notes: Optional[str] = None, ) -> None: - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == stratum_id, Target.variable == variable, Target.period == period, Target.reform_id == 0, ) - .first() - ) + ).first() if existing_target: existing_target.value = value existing_target.source = source @@ -347,14 +345,12 @@ def _get_or_create_national_domain_stratum( session: Session, national_filer_stratum_id: int, variable: str ) -> Stratum: note = f"National filers with {variable} > 0" - stratum = ( - session.query(Stratum) - .filter( + stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == national_filer_stratum_id, Stratum.notes == note, ) - .first() - ) + ).first() if stratum: return stratum @@ -751,14 +747,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): filer_strata = {"national": None, "state": {}, "district": {}} # National filer stratum - check if it exists first - national_filer_stratum = ( - session.query(Stratum) - .filter( + 
national_filer_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == geo_strata["national"], Stratum.notes == "United States - Tax Filers", ) - .first() - ) + ).first() if not national_filer_stratum: national_filer_stratum = Stratum( @@ -780,14 +774,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): # State filer strata for state_fips, state_geo_stratum_id in geo_strata["state"].items(): # Check if state filer stratum exists - state_filer_stratum = ( - session.query(Stratum) - .filter( + state_filer_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == state_geo_stratum_id, Stratum.notes == f"State FIPS {state_fips} - Tax Filers", ) - .first() - ) + ).first() if not state_filer_stratum: state_filer_stratum = Stratum( @@ -814,15 +806,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): # District filer strata for district_geoid, district_geo_stratum_id in geo_strata["district"].items(): # Check if district filer stratum exists - district_filer_stratum = ( - session.query(Stratum) - .filter( + district_filer_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == district_geo_stratum_id, Stratum.notes == f"Congressional District {district_geoid} - Tax Filers", ) - .first() - ) + ).first() if not district_filer_stratum: district_filer_stratum = Stratum( @@ -917,14 +907,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): ] # Check if stratum already exists - existing_stratum = ( - session.query(Stratum) - .filter( + existing_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == parent_stratum_id, Stratum.notes == note, ) - .first() - ) + ).first() if existing_stratum: new_stratum = existing_stratum @@ -964,15 +952,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): ("tax_unit_count", count_value), ("eitc", amount_value), ]: - existing_target = ( - session.query(Target) - 
.filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == new_stratum.stratum_id, Target.variable == variable, Target.period == year, ) - .first() - ) + ).first() if existing_target: existing_target.value = value @@ -1047,14 +1033,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): note = f"{geo_description} filers with {amount_variable_name} > 0" # Check if child stratum already exists - existing_stratum = ( - session.query(Stratum) - .filter( + existing_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == parent_stratum_id, Stratum.notes == note, ) - .first() - ) + ).first() if existing_stratum: child_stratum = existing_stratum @@ -1119,15 +1103,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): (count_variable_name, count_value), (amount_variable_name, amount_value), ]: - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == child_stratum.stratum_id, Target.variable == variable, Target.period == year, ) - .first() - ) + ).first() if existing_target: existing_target.value = value @@ -1170,15 +1152,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): ) # Check if target already exists - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == stratum.stratum_id, Target.variable == "adjusted_gross_income", Target.period == year, ) - .first() - ) + ).first() if existing_target: existing_target.value = agi_values.iloc[i][["target_value"]].values[0] @@ -1211,14 +1191,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): note = f"National filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" # Check if national AGI stratum already exists - nat_stratum = ( - session.query(Stratum) - .filter( + nat_stratum = session.exec( + select(Stratum).where( 
Stratum.parent_stratum_id == filer_strata["national"], Stratum.notes == note, ) - .first() - ) + ).first() if not nat_stratum: nat_stratum = Stratum( @@ -1296,14 +1274,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): continue # Skip if not state or district (shouldn't happen, but defensive) # Check if stratum already exists - existing_stratum = ( - session.query(Stratum) - .filter( + existing_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == parent_stratum_id, Stratum.notes == note, ) - .first() - ) + ).first() if existing_stratum: new_stratum = existing_stratum @@ -1331,15 +1307,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): session.flush() # Check if target already exists and update or create it - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == new_stratum.stratum_id, Target.variable == "person_count", Target.period == year, ) - .first() - ) + ).first() if existing_target: existing_target.value = person_count diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index 7f98dcd86..b6aaa5c2c 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ -1,6 +1,6 @@ import warnings -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select import pandas as pd from policyengine_us_data.storage import STORAGE_FOLDER @@ -527,9 +527,9 @@ def load_national_targets( with Session(engine) as session: # Get the national stratum - us_stratum = ( - session.query(Stratum).filter(Stratum.parent_stratum_id.is_(None)).first() - ) + us_stratum = session.exec( + select(Stratum).where(Stratum.parent_stratum_id.is_(None)) + ).first() if not us_stratum: raise ValueError( @@ -540,15 +540,13 @@ def load_national_targets( for _, target_data in 
direct_targets_df.iterrows(): target_year = target_data["year"] # Check if target already exists - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == us_stratum.stratum_id, Target.variable == target_data["variable"], Target.period == target_year, ) - .first() - ) + ).first() # Combine source info into notes notes_parts = [] @@ -580,14 +578,12 @@ def load_national_targets( # Process tax-related targets that need filer constraint if not tax_filer_df.empty: # Get or create the national filer stratum - national_filer_stratum = ( - session.query(Stratum) - .filter( + national_filer_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == us_stratum.stratum_id, Stratum.notes == "United States - Tax Filers", ) - .first() - ) + ).first() if not national_filer_stratum: # Create national filer stratum @@ -610,15 +606,13 @@ def load_national_targets( for _, target_data in tax_filer_df.iterrows(): target_year = target_data["year"] # Check if target already exists - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == national_filer_stratum.stratum_id, Target.variable == target_data["variable"], Target.period == target_year, ) - .first() - ) + ).first() # Combine source info into notes notes_parts = [] @@ -649,9 +643,8 @@ def load_national_targets( # Process reform-based tax expenditure targets. 
if not tax_expenditure_df.empty: - migrated_strata = ( - session.query(Stratum) - .filter( + migrated_stratum_ids = session.exec( + select(Stratum.stratum_id).where( Stratum.parent_stratum_id == us_stratum.stratum_id, Stratum.notes.in_( [ @@ -660,9 +653,7 @@ def load_national_targets( ] ), ) - .all() - ) - migrated_stratum_ids = [s.stratum_id for s in migrated_strata] + ).all() for _, target_data in tax_expenditure_df.iterrows(): target_year = target_data["year"] @@ -670,30 +661,26 @@ def load_national_targets( # Clean up incorrectly scoped baseline rows from older DBs. if migrated_stratum_ids: - stale_targets = ( - session.query(Target) - .filter( + stale_targets = session.exec( + select(Target).where( Target.stratum_id.in_(migrated_stratum_ids), Target.variable == target_data["variable"], Target.period == target_year, Target.reform_id == 0, Target.active, ) - .all() - ) + ).all() for stale_target in stale_targets: stale_target.active = False - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == us_stratum.stratum_id, Target.variable == target_data["variable"], Target.period == target_year, Target.reform_id == target_reform_id, ) - .first() - ) + ).first() notes_parts = [] if pd.notna(target_data.get("notes")): @@ -724,11 +711,9 @@ def load_national_targets( session.add(target) session.flush() - persisted = ( - session.query(Target) - .filter(Target.target_id == target.target_id) - .first() - ) + persisted = session.exec( + select(Target).where(Target.target_id == target.target_id) + ).first() if persisted.reform_id != target_reform_id: print( f" WARNING: {target_data['variable']} persisted " @@ -770,26 +755,22 @@ def load_national_targets( constraint_value = "0" # Check if this stratum already exists - existing_stratum = ( - session.query(Stratum) - .filter( + existing_stratum = session.exec( + select(Stratum).where( Stratum.parent_stratum_id == us_stratum.stratum_id, Stratum.notes 
== stratum_notes, ) - .first() - ) + ).first() if existing_stratum: # Update the existing target in this stratum - existing_target = ( - session.query(Target) - .filter( + existing_target = session.exec( + select(Target).where( Target.stratum_id == existing_stratum.stratum_id, Target.variable == target_variable, Target.period == target_year, ) - .first() - ) + ).first() if existing_target: existing_target.value = target_value @@ -848,17 +829,16 @@ def load_national_targets( "medical_expense_deduction", "qualified_business_income_deduction", ] - bad_targets = ( - session.query(Target) + bad_targets = session.exec( + select(Target) .join(Stratum, Target.stratum_id == Stratum.stratum_id) - .filter( + .where( Target.variable.in_(tax_exp_vars), - Target.active == True, - Stratum.parent_stratum_id == None, + Target.active, + Stratum.parent_stratum_id.is_(None), Target.reform_id == 0, ) - .all() - ) + ).all() if bad_targets: bad_names = [t.variable for t in bad_targets] raise ValueError( diff --git a/tests/unit/test_etl_national_targets.py b/tests/unit/test_etl_national_targets.py index 6e9448376..84d8c748b 100644 --- a/tests/unit/test_etl_national_targets.py +++ b/tests/unit/test_etl_national_targets.py @@ -1,5 +1,5 @@ import pandas as pd -from sqlmodel import Session +from sqlmodel import Session, select from policyengine_us_data.db.create_database_tables import ( Stratum, @@ -121,11 +121,11 @@ def test_load_national_targets_deactivates_stale_baseline_rows(tmp_path, monkeyp ) with Session(engine) as session: - stale_rows = session.query(Target).filter(Target.reform_id == 0).all() + stale_rows = session.exec(select(Target).where(Target.reform_id == 0)).all() assert stale_rows assert all(not target.active for target in stale_rows) - reform_rows = session.query(Target).filter(Target.reform_id > 0).all() + reform_rows = session.exec(select(Target).where(Target.reform_id > 0)).all() assert len(reform_rows) == 2 assert all(target.active for target in reform_rows) assert 
{(target.variable, target.reform_id) for target in reform_rows} == { @@ -173,11 +173,11 @@ def test_load_national_targets_supports_liheap_household_counts(tmp_path, monkey ) with Session(engine) as session: - liheap_stratum = ( - session.query(Stratum) - .filter(Stratum.notes == "National LIHEAP Recipient Households") - .first() - ) + liheap_stratum = session.exec( + select(Stratum).where( + Stratum.notes == "National LIHEAP Recipient Households" + ) + ).first() assert liheap_stratum is not None constraints = { @@ -190,14 +190,12 @@ def test_load_national_targets_supports_liheap_household_counts(tmp_path, monkey } assert ("spm_unit_energy_subsidy_reported", ">", "0") in constraints - liheap_target = ( - session.query(Target) - .filter( + liheap_target = session.exec( + select(Target).where( Target.stratum_id == liheap_stratum.stratum_id, Target.variable == "household_count", Target.period == 2024, ) - .first() - ) + ).first() assert liheap_target is not None assert liheap_target.value == 5_876_646