Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fix-ci-warnings.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replace legacy SQLModel `session.query(...)` lookups in the SOI ETL loaders and their focused tests with `session.exec(select(...))` to remove deprecation warnings in CI.
106 changes: 40 additions & 66 deletions policyengine_us_data/db/etl_irs_soi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd

from sqlmodel import Session, create_engine
from sqlmodel import Session, create_engine, select

from policyengine_us_data.storage import STORAGE_FOLDER
from policyengine_us_data.db.create_database_tables import (
Expand Down Expand Up @@ -313,16 +313,14 @@ def _upsert_target(
source: str,
notes: Optional[str] = None,
) -> None:
existing_target = (
session.query(Target)
.filter(
existing_target = session.exec(
select(Target).where(
Target.stratum_id == stratum_id,
Target.variable == variable,
Target.period == period,
Target.reform_id == 0,
)
.first()
)
).first()
if existing_target:
existing_target.value = value
existing_target.source = source
Expand All @@ -347,14 +345,12 @@ def _get_or_create_national_domain_stratum(
session: Session, national_filer_stratum_id: int, variable: str
) -> Stratum:
note = f"National filers with {variable} > 0"
stratum = (
session.query(Stratum)
.filter(
stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == national_filer_stratum_id,
Stratum.notes == note,
)
.first()
)
).first()
if stratum:
return stratum

Expand Down Expand Up @@ -751,14 +747,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
filer_strata = {"national": None, "state": {}, "district": {}}

# National filer stratum - check if it exists first
national_filer_stratum = (
session.query(Stratum)
.filter(
national_filer_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == geo_strata["national"],
Stratum.notes == "United States - Tax Filers",
)
.first()
)
).first()

if not national_filer_stratum:
national_filer_stratum = Stratum(
Expand All @@ -780,14 +774,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
# State filer strata
for state_fips, state_geo_stratum_id in geo_strata["state"].items():
# Check if state filer stratum exists
state_filer_stratum = (
session.query(Stratum)
.filter(
state_filer_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == state_geo_stratum_id,
Stratum.notes == f"State FIPS {state_fips} - Tax Filers",
)
.first()
)
).first()

if not state_filer_stratum:
state_filer_stratum = Stratum(
Expand All @@ -814,15 +806,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
# District filer strata
for district_geoid, district_geo_stratum_id in geo_strata["district"].items():
# Check if district filer stratum exists
district_filer_stratum = (
session.query(Stratum)
.filter(
district_filer_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == district_geo_stratum_id,
Stratum.notes
== f"Congressional District {district_geoid} - Tax Filers",
)
.first()
)
).first()

if not district_filer_stratum:
district_filer_stratum = Stratum(
Expand Down Expand Up @@ -917,14 +907,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
]

# Check if stratum already exists
existing_stratum = (
session.query(Stratum)
.filter(
existing_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == parent_stratum_id,
Stratum.notes == note,
)
.first()
)
).first()

if existing_stratum:
new_stratum = existing_stratum
Expand Down Expand Up @@ -964,15 +952,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
("tax_unit_count", count_value),
("eitc", amount_value),
]:
existing_target = (
session.query(Target)
.filter(
existing_target = session.exec(
select(Target).where(
Target.stratum_id == new_stratum.stratum_id,
Target.variable == variable,
Target.period == year,
)
.first()
)
).first()

if existing_target:
existing_target.value = value
Expand Down Expand Up @@ -1047,14 +1033,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
note = f"{geo_description} filers with {amount_variable_name} > 0"

# Check if child stratum already exists
existing_stratum = (
session.query(Stratum)
.filter(
existing_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == parent_stratum_id,
Stratum.notes == note,
)
.first()
)
).first()

if existing_stratum:
child_stratum = existing_stratum
Expand Down Expand Up @@ -1119,15 +1103,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
(count_variable_name, count_value),
(amount_variable_name, amount_value),
]:
existing_target = (
session.query(Target)
.filter(
existing_target = session.exec(
select(Target).where(
Target.stratum_id == child_stratum.stratum_id,
Target.variable == variable,
Target.period == year,
)
.first()
)
).first()

if existing_target:
existing_target.value = value
Expand Down Expand Up @@ -1170,15 +1152,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
)

# Check if target already exists
existing_target = (
session.query(Target)
.filter(
existing_target = session.exec(
select(Target).where(
Target.stratum_id == stratum.stratum_id,
Target.variable == "adjusted_gross_income",
Target.period == year,
)
.first()
)
).first()

if existing_target:
existing_target.value = agi_values.iloc[i][["target_value"]].values[0]
Expand Down Expand Up @@ -1211,14 +1191,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
note = f"National filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}"

# Check if national AGI stratum already exists
nat_stratum = (
session.query(Stratum)
.filter(
nat_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == filer_strata["national"],
Stratum.notes == note,
)
.first()
)
).first()

if not nat_stratum:
nat_stratum = Stratum(
Expand Down Expand Up @@ -1296,14 +1274,12 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
continue # Skip if not state or district (shouldn't happen, but defensive)

# Check if stratum already exists
existing_stratum = (
session.query(Stratum)
.filter(
existing_stratum = session.exec(
select(Stratum).where(
Stratum.parent_stratum_id == parent_stratum_id,
Stratum.notes == note,
)
.first()
)
).first()

if existing_stratum:
new_stratum = existing_stratum
Expand Down Expand Up @@ -1331,15 +1307,13 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None):
session.flush()

# Check if target already exists and update or create it
existing_target = (
session.query(Target)
.filter(
existing_target = session.exec(
select(Target).where(
Target.stratum_id == new_stratum.stratum_id,
Target.variable == "person_count",
Target.period == year,
)
.first()
)
).first()

if existing_target:
existing_target.value = person_count
Expand Down
Loading
Loading