Sk refactoring #394

Open: SvenKlaassen wants to merge 40 commits into main from sk-refactoring

@SvenKlaassen SvenKlaassen commented May 9, 2026

Summary

Introduce a new DoubleMLScalar / DoubleMLVector class hierarchy alongside the existing DoubleML API. The refactor delivers a cleaner, more testable design with explicit tuning, nuisance evaluation, and sensitivity analysis as first-class features. Two concrete scalar models (PLR, IRM) and one vector model (PLRVector) are ported, each backed by a comprehensive test suite that checks numerical equivalence with the legacy classes (rtol=1e-9).

Motivation

The legacy DoubleML base class conflates single-parameter estimation, multi-treatment orchestration, and inference into one large class. This makes it hard to:

  • Add new models without inheriting unrelated multi-parameter machinery
  • Swap closed-form vs. numerical score solvers
  • Test nuisance behavior in isolation from causal inference
  • Evolve features (tuning, sensitivity) without touching the core fit loop

The new hierarchy separates these concerns via a layered design with explicit hooks.

New Class Hierarchy

DoubleMLBase (ABC)               # data + framework delegation (coef, se, summary, confint, bootstrap, sensitivity)
└── DoubleMLScalar (ABC)         # single-parameter orchestrator (fit, sample splitting, learners, predictions)
    ├── LinearScoreMixin         # closed-form: theta = -E[psi_b] / E[psi_a]
    │   ├── DoubleMLPLRScalar
    │   └── DoubleMLIRMScalar
    └── NonLinearScoreMixin      # (planned) numerical root-finding

Plus a parallel multi-treatment track:

DoubleMLVector                   # multi-treatment base
└── DoubleMLPLRVector            # exact equivalence with legacy DoubleMLPLR for k>1 treatments

See doc/diagrams/architecture.md for the full UML and method-resolution diagrams.
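
The closed-form solve attributed to LinearScoreMixin above can be sketched as follows. This is a minimal illustration, not the code in this branch: the function name solve_linear_score, the toy residuals, and the plug-in standard error are assumptions. The score components use the standard PLR partialling-out form, psi_a = -(D - m)^2 and psi_b = (Y - l)(D - m).

```python
import numpy as np


def solve_linear_score(psi_a, psi_b):
    """Solve mean(psi_a) * theta + mean(psi_b) = 0 and return (theta, se).

    Hypothetical helper for illustration only.
    """
    psi_a = np.asarray(psi_a, dtype=float)
    psi_b = np.asarray(psi_b, dtype=float)
    n_obs = psi_a.shape[0]

    # Closed form: theta = -E[psi_b] / E[psi_a]
    theta = -psi_b.mean() / psi_a.mean()

    # Plug-in variance: sigma^2 = E[psi^2] / J^2 with J = E[psi_a]
    psi = psi_a * theta + psi_b
    sigma2 = np.mean(psi**2) / psi_a.mean() ** 2
    return theta, np.sqrt(sigma2 / n_obs)


# Toy residuals standing in for cross-fitted Y - l(X) and D - m(X).
rng = np.random.default_rng(0)
d_res = rng.normal(size=500)
y_res = 0.5 * d_res + rng.normal(scale=0.1, size=500)

psi_a = -d_res * d_res
psi_b = d_res * y_res
theta_hat, se_hat = solve_linear_score(psi_a, psi_b)
```

For this score the closed form reduces to the no-intercept least-squares slope of y_res on d_res, which gives a quick sanity check on any implementation.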

Key Design Decisions

  • Learners optional in constructor: __init__ accepts learners as optional kwargs (e.g. ml_l, ml_m, ml_g) for one-line construction, but they can also be configured (or replaced) later via set_learners(...). Decoupling the two paths makes it possible to swap learners, re-tune, or re-fit without rebuilding the model.
  • _learner_names as single source of truth: drives prediction-dict initialization and learner-availability checks; subclasses just declare the list.
  • Resampling separated from constructor: draw_sample_splitting() is its own step and can be called independently or re-drawn.
  • Template method fit(): orchestrates draw_sample_splitting() → fit_nuisance_models() → estimate_causal_parameters(). Subclasses implement _nuisance_est() and _get_score_elements(); the mixin provides _est_causal_pars_and_se().
  • External predictions: passed to fit() / fit_nuisance_models(), validated against _learner_names, and pre-filled before the cross-fitting loop.
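
The design decisions above can be pictured with a small toy. Everything in this sketch is hypothetical (ScalarSketch, LinearScoreSketch, ToyPLR, zero_learner, and the calls log are invented for illustration); it borrows the method names from this description but is not the DoubleMLScalar implementation. In particular, it assumes fit() reuses an existing sample split, and it skips cross-fitting and standard errors.

```python
from abc import ABC, abstractmethod


class LinearScoreSketch:
    """Mixin: closed-form solve of mean(psi_a) * theta + mean(psi_b) = 0."""

    def _est_causal_pars(self, psi_a, psi_b):
        return -(sum(psi_b) / len(psi_b)) / (sum(psi_a) / len(psi_a))


class ScalarSketch(ABC):
    """Hypothetical orchestrator; not the DoubleMLScalar implementation."""

    _learner_names = []  # subclasses declare their nuisance learners here

    def __init__(self, **learners):
        self._learners = {}
        self.smpls = None
        self.predictions = {}
        self.calls = []  # records the order of steps, for illustration only
        self.set_learners(**learners)

    def set_learners(self, **learners):
        unknown = sorted(set(learners) - set(self._learner_names))
        if unknown:
            raise ValueError(f"unknown learners: {unknown}")
        self._learners.update(learners)
        return self

    def draw_sample_splitting(self):
        self.calls.append("draw_sample_splitting")
        self.smpls = "placeholder for train/test folds"
        return self

    def fit_nuisance_models(self, external_predictions=None):
        self.calls.append("fit_nuisance_models")
        external = dict(external_predictions or {})
        unknown = sorted(set(external) - set(self._learner_names))
        if unknown:
            raise ValueError(f"unknown external predictions: {unknown}")
        # Pre-fill external predictions; every other entry needs a learner.
        self.predictions = {n: external.get(n) for n in self._learner_names}
        missing = [n for n, p in self.predictions.items()
                   if p is None and n not in self._learners]
        if missing:
            raise ValueError(f"no learner or external prediction for: {missing}")
        self._nuisance_est()
        return self

    def estimate_causal_parameters(self):
        self.calls.append("estimate_causal_parameters")
        psi_a, psi_b = self._get_score_elements()
        self.theta = self._est_causal_pars(psi_a, psi_b)
        return self

    def fit(self, external_predictions=None):
        if self.smpls is None:  # assumption: reuse an existing split
            self.draw_sample_splitting()
        self.fit_nuisance_models(external_predictions)
        return self.estimate_causal_parameters()

    @abstractmethod
    def _nuisance_est(self): ...

    @abstractmethod
    def _get_score_elements(self): ...


class ToyPLR(LinearScoreSketch, ScalarSketch):
    _learner_names = ["ml_l", "ml_m"]

    def __init__(self, y, d, **learners):
        self.y, self.d = y, d
        super().__init__(**learners)

    def _nuisance_est(self):
        for name in self._learner_names:
            if self.predictions[name] is None:  # skip pre-filled entries
                self.predictions[name] = self._learners[name](len(self.y))

    def _get_score_elements(self):
        y_res = [y - l for y, l in zip(self.y, self.predictions["ml_l"])]
        d_res = [d - m for d, m in zip(self.d, self.predictions["ml_m"])]
        return [-v * v for v in d_res], [u * v for u, v in zip(y_res, d_res)]


def zero_learner(n):
    """Toy 'learner' that predicts 0.0 for every observation."""
    return [0.0] * n


model = ToyPLR(y=[2.0, -2.0, 4.0], d=[1.0, -1.0, 2.0], ml_l=zero_learner)
model.set_learners(ml_m=zero_learner)  # second learner configured later
model.fit()
```

In this toy, a learner can be passed at construction or added later, and a call such as ToyPLR(...).fit(external_predictions={"ml_m": [...]}) skips the corresponding learner entirely because its prediction slot is already filled.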

What's Included

Core infrastructure

Scalar models

Vector models

Cross-cutting features

  • Optuna tuning: tune_ml_models() on DoubleMLScalar with pruning support, _LEARNER_PARAM_ALIASES (e.g. IRM ml_g → [ml_g0, ml_g1]), and a _get_tuning_data() hook for subclass-specific tuning targets
  • Nuisance evaluation: nuisance_targets, nuisance_loss, and evaluate_learners(metric=...) with auto-defaulted RMSE / log-loss and NaN-aware masking
  • Sensitivity analysis: vectorized _sensitivity_element_est() hook running over all reps post-fit, with framework-ready shapes; supports the full sensitivity_analysis() pipeline
  • DoubleMLBLP per-rep basis: basis may be a single pd.DataFrame (shared) or list[pd.DataFrame] of length n_rep. Also fixes a multi-rep / multi-column bug in legacy DoubleMLPLR.cate() (doubleml/utils/blp.py)
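
To make the NaN-aware masking in the nuisance-evaluation bullet concrete, here is a minimal sketch of a masked RMSE. The function name masked_rmse is hypothetical, and the scenario is an assumption: a target such as the one for ml_g0 is taken to be defined only for a subset of observations, with NaN elsewhere.

```python
import numpy as np


def masked_rmse(y_true, y_pred):
    """RMSE over positions where both target and prediction are finite.

    Hypothetical helper; returns NaN when no valid pair exists.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mask = ~(np.isnan(y_true) | np.isnan(y_pred))
    if not mask.any():
        return float("nan")
    return float(np.sqrt(np.mean((y_true[mask] - y_pred[mask]) ** 2)))


# Targets are NaN where the nuisance function is not defined for that unit.
targets = np.array([1.0, np.nan, 3.0, np.nan])
preds = np.array([1.5, 9.0, 2.5, 9.0])
rmse = masked_rmse(targets, preds)
```

Only the first and third observations contribute here, so the predictions at the NaN positions have no effect on the metric.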

Test suites

Every scalar model ships with the mandatory 5-file structure (the first five files below) plus dedicated files for tuning, evaluation, sensitivity, and CATE/GATE:

  • test_<model>_scalar.py — 3-sigma estimation accuracy
  • test_<model>_scalar_return_types.py — property types/shapes
  • test_<model>_scalar_exceptions.py — input validation
  • test_<model>_scalar_vs_<model>.py — exact match with legacy (rtol=1e-9)
  • test_<model>_scalar_external_predictions.py — external-prediction equivalence
  • test_<model>_scalar_tune_ml_models.py — Optuna tuning
  • test_<model>_scalar_evaluate_learners.py — nuisance loss / metrics
  • test_<model>_scalar_sensitivity.py — sensitivity bounds & monotonicity
  • test_<model>_scalar_cate_gate.py — CATE/GATE (PLR & IRM)

The PLR vector model ships with 5 corresponding test files. Shared scalar-level tests live in doubleml/tests/ (cluster splits, fit, set-sample-splitting, tune-pruning, tune-exceptions, ext-predictions).
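
The two tolerance conventions named in the file list (3-sigma accuracy and an exact match at rtol=1e-9) can be expressed as below. This is a sketch with made-up numbers: within_three_sigma and the coefficient arrays are hypothetical and do not come from the test suite.

```python
import numpy as np


def within_three_sigma(theta_hat, se, theta_true):
    """True if the estimate lies within 3 standard errors of the truth."""
    return abs(theta_hat - theta_true) <= 3.0 * se


# Accuracy check: a loose statistical tolerance against a known parameter.
ok_estimate = within_three_sigma(theta_hat=0.52, se=0.02, theta_true=0.5)
bad_estimate = within_three_sigma(theta_hat=0.60, se=0.02, theta_true=0.5)

# Equivalence check: a tight relative tolerance between two implementations.
legacy_coef = np.array([0.512345678901])
scalar_coef = legacy_coef * (1 + 1e-12)  # differs only by float noise
np.testing.assert_allclose(scalar_coef, legacy_coef, rtol=1e-9, atol=0.0)
```

The first check tolerates sampling noise, while the second only passes when both code paths perform essentially the same arithmetic on the same sample splits.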

Tooling & docs

Feature Parity with Legacy Classes

| Feature | Status |
| --- | --- |
| cate() / gate() (PLR + IRM) | ported |
| _partial_out() (PLR) | ported |
| Array weights (IRM) | supported |
| Dict weights with weights_bar (IRM) | supported, validated via _check_smpls_dependent_inputs() hook |
| policy_tree() (IRM) | not yet ported |
| Callable score | intentionally not ported (design decision) |
| trimming_rule / trimming_threshold deprecated props | replaced by ps_processor_config |

Backwards Compatibility

  • All legacy classes (DoubleMLPLR, DoubleMLIRM, …) remain unchanged and pass their existing test suites.
  • The new hierarchy is additive — exported alongside the legacy API in doubleml/__init__.py.
  • A latent multi-rep / multi-column bug in legacy DoubleMLPLR.cate() (basis * D_tilde mis-broadcast for n_rep > 1 and d_basis > 1) is fixed via the new BLP per-rep API.

Test Plan

  • pytest -m ci passes locally
  • pytest doubleml/plm/tests/ doubleml/irm/tests/ -v — full module suites for both refactored and legacy classes
  • pytest doubleml/tests/test_scalar_*.py -v — shared scalar infrastructure tests
  • black ., ruff check ., and mypy doubleml are clean (remaining mypy errors are pre-existing and not introduced by this branch)
  • Spot-check summary, confint, bootstrap, and sensitivity_analysis() on a fitted DoubleMLPLRScalar and DoubleMLIRMScalar
  • Spot-check exact-match tests (*_scalar_vs_*.py) at rtol=1e-9

Follow-ups (out of scope)

  • DoubleMLIRMVector
  • DoubleMLPLIVScalar, DoubleMLPLPRScalar
  • DID scalar variants (DID, DIDCSBinary, DIDMulti)
  • DoubleMLVector base-class tests
  • IRM policy_tree() port

- Introduced learner management in DoubleMLScalar with properties for learner names and instances.
- Added abstract method `set_learners` to enforce learner setting in subclasses.
- Updated PLR to utilize the new learner management system, including validation checks for learner instances.
- Refactored tests to align with the new learner management approach, ensuring proper exception handling and validation.
- Implemented the IRM class for double machine learning with interactive regression models in irm_scalar.py.
- Added core estimation tests for IRM scalar in test_irm_scalar.py.
- Created exception handling tests for IRM scalar in test_irm_scalar_exceptions.py.
- Developed tests for handling external predictions in test_irm_scalar_external_predictions.py.
- Added return type validation tests for IRM scalar in test_irm_scalar_return_types.py.
- Compared the new IRM scalar implementation against the existing DoubleMLIRM in test_irm_scalar_vs_irm.py.
…standards, error handling, performance guidelines, and testing conventions.
…d tests for cluster-based sample splitting and external prediction validation.
…rs; enhance tests for return types and reset behavior.
- Added `_sensitivity_element_est` method to `DoubleMLScalar`, `IRM`, and `PLR` classes to compute sensitivity elements including sigma2, nu2, and their influence functions.
- Introduced `sensitivity_elements` property to retrieve computed sensitivity elements after model fitting.
- Implemented validation checks for sensitivity elements in `DoubleMLScalar`.
- Added exception handling for sensitivity analysis methods in `IRM` and `PLR` classes to ensure proper input types and values.
- Created unit tests for sensitivity analysis, including checks for element shapes, bounds, and exception handling in both `IRM` and `PLR` models.
- Ensured compatibility of sensitivity elements between scalar and legacy models in comparison tests.
- Implemented `cate()` and `gate()` methods in `IRM` and `PLR` classes for estimating conditional average treatment effects.
- Enhanced `DoubleMLBLP` to support per-rep basis for multi-rep scenarios.
- Updated tests for `IRM` and `PLR` to validate new functionality, including checks for correct handling of multi-rep bases and group effects.
- Improved validation of basis inputs in `DoubleMLBLP` to accept both single DataFrame and list of DataFrames.
- Added new test cases to ensure robustness of the new features and backward compatibility with legacy models.
Comment thread doubleml/irm/tests/test_irm_scalar_exceptions.py Fixed