Sk refactoring #394
Open
SvenKlaassen wants to merge 40 commits into
- Introduced learner management in DoubleMLScalar with properties for learner names and instances.
- Added abstract method `set_learners` to enforce learner setting in subclasses.
- Updated PLR to utilize the new learner management system, including validation checks for learner instances.
- Refactored tests to align with the new learner management approach, ensuring proper exception handling and validation.
…nd utility functions
- Implemented the IRM class for double machine learning with interactive regression models in irm_scalar.py.
- Added core estimation tests for IRM scalar in test_irm_scalar.py.
- Created exception handling tests for IRM scalar in test_irm_scalar_exceptions.py.
- Developed tests for handling external predictions in test_irm_scalar_external_predictions.py.
- Added return type validation tests for IRM scalar in test_irm_scalar_return_types.py.
- Compared the new IRM scalar implementation against the existing DoubleMLIRM in test_irm_scalar_vs_irm.py.
…standards, error handling, performance guidelines, and testing conventions.
…d tests for cluster-based sample splitting and external prediction validation.
…rs; enhance tests for return types and reset behavior.
…, testing, and scalar model test structure
…g; update tests for consistency
- Added `_sensitivity_element_est` method to `DoubleMLScalar`, `IRM`, and `PLR` classes to compute sensitivity elements including sigma2, nu2, and their influence functions.
- Introduced `sensitivity_elements` property to retrieve computed sensitivity elements after model fitting.
- Implemented validation checks for sensitivity elements in `DoubleMLScalar`.
- Added exception handling for sensitivity analysis methods in `IRM` and `PLR` classes to ensure proper input types and values.
- Created unit tests for sensitivity analysis, including checks for element shapes, bounds, and exception handling in both `IRM` and `PLR` models.
- Ensured compatibility of sensitivity elements between scalar and legacy models in comparison tests.
…ndling in DoubleMLScalar
- Implemented `cate()` and `gate()` methods in `IRM` and `PLR` classes for estimating conditional average treatment effects.
- Enhanced `DoubleMLBLP` to support per-rep basis for multi-rep scenarios.
- Updated tests for `IRM` and `PLR` to validate new functionality, including checks for correct handling of multi-rep bases and group effects.
- Improved validation of basis inputs in `DoubleMLBLP` to accept both single DataFrame and list of DataFrames.
- Added new test cases to ensure robustness of the new features and backward compatibility with legacy models.
…sion and add comprehensive tests
… and enhance error handling in PLR and LearnerSpec validation
… checks into dedicated functions
Apply ruff D200/D213/D413 auto-fixes and add __init__ docstrings to DoubleMLVector and PLRVector.
…reamline sample comparison logic in tests
…bleMLScalar class
… DoubleMLScalar class
Summary
Introduce a new `DoubleMLScalar` / `DoubleMLVector` class hierarchy alongside the existing `DoubleML` API. The refactor delivers a cleaner, more testable design with explicit tuning, nuisance evaluation, and sensitivity analysis as first-class features. Two concrete scalar models (`PLR`, `IRM`) and one vector model (`PLRVector`) are ported, each backed by a comprehensive test suite that checks numerical equivalence with the legacy classes at `rtol=1e-9`.

Motivation
The legacy `DoubleML` base class conflates single-parameter estimation, multi-treatment orchestration, and inference into one large class. This makes it hard to:
The new hierarchy separates these concerns via a layered design with explicit hooks.
New Class Hierarchy
Plus a parallel multi-treatment track:
See doc/diagrams/architecture.md for the full UML and method-resolution diagrams.
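As a rough illustration of a layered design with explicit hooks, the pure-Python sketch below wires a score mixin, a base class, and one concrete model together. The public method and hook names (`set_learners`, `draw_sample_splitting`, `fit_nuisance_models`, `estimate_causal_parameters`, `_nuisance_est`, `_get_score_elements`, `_est_causal_pars_and_se`, `_learner_names`) are taken from this description. Everything else is an assumption made for illustration and is not the code in this branch: the class names `ScalarSketch` and `PLRSketch`, the `_check_names` helper, the `_targets` map, the `mean_learner` toy learner, the dict-of-lists data layout, the partialling-out score, and the choice to reuse an existing split inside `fit()`.

```python
import random
from statistics import fmean


class LinearScoreMixin:
    """Solves a score that is linear in theta: psi = psi_a * theta + psi_b."""
    def _est_causal_pars_and_se(self, psi_a, psi_b):
        n, mean_a = len(psi_a), fmean(psi_a)
        theta = -fmean(psi_b) / mean_a
        psi = [a * theta + b for a, b in zip(psi_a, psi_b)]
        return theta, (fmean([p * p for p in psi]) / (n * mean_a**2)) ** 0.5


class ScalarSketch:
    """Hypothetical base; subclasses add _nuisance_est/_get_score_elements."""
    _learner_names = ()  # single source of truth, declared by subclasses

    def __init__(self, data, n_folds=2, seed=0, **learners):
        self.data, self.n_folds, self.n_obs = data, n_folds, len(data["y"])
        self._rng, self._learners, self.smpls = random.Random(seed), {}, None
        self.set_learners(**learners)  # learners are optional at this point

    def _check_names(self, names):
        unknown = sorted(set(names) - set(self._learner_names))
        if unknown:
            raise ValueError(f"Unknown learner name(s): {unknown}")

    def set_learners(self, **learners):
        self._check_names(learners)
        self._learners.update(learners)
        return self

    def draw_sample_splitting(self):
        idx = self._rng.sample(range(self.n_obs), self.n_obs)  # shuffled indices
        folds = [sorted(idx[k :: self.n_folds]) for k in range(self.n_folds)]
        self.smpls = [(sorted(set(idx) - set(test)), test) for test in folds]
        return self

    def fit_nuisance_models(self, external_predictions=None):
        ext = external_predictions or {}
        self._check_names(ext)
        missing = set(self._learner_names) - set(ext) - set(self._learners)
        if missing:
            raise ValueError(f"No learner or prediction for: {sorted(missing)}")
        # Pre-fill the prediction dict before the cross-fitting loop.
        empty = [None] * self.n_obs
        self.predictions = {k: list(ext.get(k, empty)) for k in self._learner_names}
        for train, test in self.smpls:
            self._nuisance_est(train, test, skip=set(ext))
        return self

    def estimate_causal_parameters(self):
        self.coef, self.se = self._est_causal_pars_and_se(*self._get_score_elements())
        return self

    def fit(self, external_predictions=None):
        if self.smpls is None:  # keep a split that was drawn explicitly
            self.draw_sample_splitting()
        self.fit_nuisance_models(external_predictions)
        return self.estimate_causal_parameters()


class PLRSketch(LinearScoreMixin, ScalarSketch):
    _learner_names = ("ml_l", "ml_m")
    _targets = {"ml_l": "y", "ml_m": "d"}  # hypothetical learner -> target map

    def _nuisance_est(self, train, test, skip):
        x = self.data["x"]
        for name in (k for k in self._learner_names if k not in skip):
            target = self.data[self._targets[name]]
            fit_learner = self._learners[name]
            predict = fit_learner([x[i] for i in train], [target[i] for i in train])
            for i in test:
                self.predictions[name][i] = predict(x[i])

    def _get_score_elements(self):
        y, d = self.data["y"], self.data["d"]
        l_hat, m_hat = self.predictions["ml_l"], self.predictions["ml_m"]
        psi_a = [-((di - mi) ** 2) for di, mi in zip(d, m_hat)]
        psi_b = [(di - mi) * (yi - li) for di, mi, yi, li in zip(d, m_hat, y, l_hat)]
        return psi_a, psi_b


def mean_learner(x_train, target_train):
    mu = fmean(target_train)  # toy learner: ignores x, predicts the train mean
    return lambda x: mu


rng = random.Random(42)
d = [rng.gauss(0, 1) for _ in range(400)]
y = [0.5 * di + rng.gauss(0, 1) for di in d]  # true effect is 0.5
data = {"y": y, "d": d, "x": [rng.gauss(0, 1) for _ in range(400)]}

model = PLRSketch(data, ml_l=mean_learner)  # one learner at construction
model.set_learners(ml_m=mean_learner)  # the other configured afterwards
model.fit()
```

In this sketch the concrete model only declares its learner names and implements the two hooks; splitting, validation, pre-filling of external predictions, and the linear-score solve all live in the shared layers. Calling `model.draw_sample_splitting()` again followed by `model.fit()` re-estimates on a fresh split without rebuilding the object.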
Key Design Decisions
- `__init__` accepts learners as optional kwargs (e.g. `ml_l`, `ml_m`, `ml_g`) for one-line construction, but they can also be configured (or replaced) later via `set_learners(...)`. Decoupling the two paths makes it possible to swap learners, re-tune, or re-fit without rebuilding the model.
- `_learner_names` as single source of truth: it drives prediction-dict initialization and learner-availability checks; subclasses just declare the list.
- `draw_sample_splitting()` is its own step: it can be called independently, and the split can be re-drawn.
- `fit()` orchestrates `draw_sample_splitting()` → `fit_nuisance_models()` → `estimate_causal_parameters()`. Subclasses implement `_nuisance_est()` and `_get_score_elements()`; the mixin provides `_est_causal_pars_and_se()`.
- External predictions are passed to `fit()` / `fit_nuisance_models()`, validated against `_learner_names`, and pre-filled before the cross-fitting loop.

What's Included
Core infrastructure
- `LinearScoreMixin`

Scalar models
- `DoubleMLPLRScalar` with `cate()`, `gate()`, `_partial_out()`
- `DoubleMLIRMScalar` with `cate()`, `gate()`, weighted scores (array + dict-with-`weights_bar`)

Vector models
- `DoubleMLPLRVector`, validated against legacy `DoubleMLPLR` for multi-treatment settings

Cross-cutting features
- `tune_ml_models()` on `DoubleMLScalar` with pruning support, `_LEARNER_PARAM_ALIASES` (e.g. IRM `ml_g` → `[ml_g0, ml_g1]`), and a `_get_tuning_data()` hook for subclass-specific tuning targets
- `nuisance_targets`, `nuisance_loss`, and `evaluate_learners(metric=...)` with auto-defaulted RMSE / log-loss and NaN-aware masking
- `_sensitivity_element_est()` hook running over all reps post-fit, with framework-ready shapes; supports the full `sensitivity_analysis()` pipeline
- `DoubleMLBLP` per-rep basis: `basis` may be a single `pd.DataFrame` (shared) or a `list[pd.DataFrame]` of length `n_rep`. Also fixes a multi-rep / multi-column bug in legacy `DoubleMLPLR.cate()` (`doubleml/utils/blp.py`)

Test suites
Every scalar model ships with the mandatory 5-file structure plus dedicated files for tuning, evaluation, and sensitivity:
- `test_<model>_scalar.py`: 3-sigma estimation accuracy
- `test_<model>_scalar_return_types.py`: property types/shapes
- `test_<model>_scalar_exceptions.py`: input validation
- `test_<model>_scalar_vs_<model>.py`: exact match with legacy (`rtol=1e-9`)
- `test_<model>_scalar_external_predictions.py`: external-prediction equivalence
- `test_<model>_scalar_tune_ml_models.py`: Optuna tuning
- `test_<model>_scalar_evaluate_learners.py`: nuisance loss / metrics
- `test_<model>_scalar_sensitivity.py`: sensitivity bounds & monotonicity
- `test_<model>_scalar_cate_gate.py`: CATE/GATE (PLR & IRM)

The PLR vector model ships with 5 corresponding test files. In addition, shared scalar-level tests live in `doubleml/tests/` (cluster splits, fit, set-sample-splitting, tune-pruning, tune-exceptions, ext-predictions).
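The two numerical conventions named in the file list, 3-sigma estimation accuracy and an exact match with legacy results at `rtol=1e-9`, can be sketched with the standard library alone. Both helper names are hypothetical, the numbers are made up for illustration, and `math.isclose` merely stands in for whatever comparison helper the actual suites use (its relative tolerance is scaled by the larger of the two magnitudes).

```python
import math


def within_three_sigma(estimate, true_value, std_error):
    """3-sigma check: the estimate lies within 3 standard errors of the truth."""
    return abs(estimate - true_value) <= 3.0 * std_error


def assert_matches_legacy(new_values, legacy_values, rtol=1e-9):
    """Element-wise relative comparison of new results against legacy ones."""
    assert len(new_values) == len(legacy_values), "length mismatch"
    for new, old in zip(new_values, legacy_values):
        assert math.isclose(new, old, rel_tol=rtol, abs_tol=0.0), (new, old)


# Illustrative values only; they do not come from the PR's test runs.
assert within_three_sigma(estimate=0.52, true_value=0.5, std_error=0.05)
assert not within_three_sigma(estimate=0.70, true_value=0.5, std_error=0.05)
assert_matches_legacy([0.5, 1.25], [0.5, 1.25 * (1 + 1e-12)])
```

The 3-sigma check tolerates sampling noise in a single simulated dataset, whereas the legacy comparison is deterministic and can therefore use a tolerance close to floating-point precision.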
Tooling & docs
Feature Parity with Legacy Classes
- `cate()` / `gate()` (PLR + IRM)
- `_partial_out()` (PLR)
- `weights_bar` (IRM)
- `_check_smpls_dependent_inputs()` hook
- `policy_tree()` (IRM)
- `trimming_rule` / `trimming_threshold` deprecated props
- `ps_processor_config`

Backwards Compatibility
- Legacy classes (`DoubleMLPLR`, `DoubleMLIRM`, …) remain unchanged and pass their existing test suites.
- The multi-rep / multi-column bug in legacy `DoubleMLPLR.cate()` (`basis * D_tilde` mis-broadcast for `n_rep > 1` and `d_basis > 1`) is fixed via the new BLP per-rep API.

Test Plan
- `pytest -m ci` passes locally
- `pytest doubleml/plm/tests/ doubleml/irm/tests/ -v`: full module suites for both refactored and legacy classes
- `pytest doubleml/tests/test_scalar_*.py -v`: shared scalar infrastructure tests
- `black .`, `ruff check .`, `mypy doubleml` clean (apart from pre-existing mypy errors not introduced by this branch)
- `summary`, `confint`, `bootstrap`, and `sensitivity_analysis()` on a fitted `DoubleMLPLRScalar` and `DoubleMLIRMScalar`
- Exact match with legacy (`*_scalar_vs_*.py`) at `rtol=1e-9`

Follow-ups (out of scope)
- `DoubleMLIRMVector`
- `DoubleMLPLIVScalar`, `DoubleMLPLPRScalar`
- … (`DID`, `DIDCSBinary`, `DIDMulti`)
- `DoubleMLVector` base-class tests
- `policy_tree()` port