-
Notifications
You must be signed in to change notification settings - Fork 2
add QTE support for covariate adaptive randomization #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| from typing import Tuple, Any | ||
| from typing import Optional, Tuple, Any | ||
| from copy import deepcopy | ||
| from scipy.stats import norm | ||
| from tqdm.auto import tqdm | ||
| from dte_adj.base import DistributionEstimatorBase | ||
| from dte_adj.util import ArrayLike, _convert_to_ndarray | ||
|
|
@@ -153,6 +154,77 @@ def _compute_interval_probability( | |
| conditional_prediction[:, 1:] - conditional_prediction[:, :-1], | ||
| ) | ||
|
|
||
| def predict_qte( | ||
| self, | ||
| target_treatment_arm: int, | ||
| control_treatment_arm: int, | ||
| quantiles: Optional[np.ndarray] = None, | ||
| alpha: float = 0.05, | ||
| n_bootstrap=500, | ||
| display_progress: bool = True, | ||
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
| """ | ||
| Compute Quantile Treatment Effects (QTE) using stratified bootstrap. | ||
|
|
||
| Uses stratified bootstrap (resampling independently within each stratum) to | ||
| correctly estimate variance under covariate adaptive randomization (CAR). | ||
|
|
||
| Args: | ||
| target_treatment_arm (int): The index of the treatment arm of the treatment group. | ||
| control_treatment_arm (int): The index of the treatment arm of the control group. | ||
| quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. | ||
| alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. | ||
| n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. | ||
| display_progress (bool, optional): Whether to display a progress bar. Defaults to True. | ||
|
|
||
| Returns: | ||
| Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: | ||
| - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile | ||
| - Lower bounds (np.ndarray): Lower confidence interval bounds | ||
| - Upper bounds (np.ndarray): Upper confidence interval bounds | ||
| """ | ||
| qte = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates, | ||
| self.treatment_arms, | ||
| self.outcomes, | ||
| self.strata, | ||
| ) | ||
|
Comment on lines
+161
to
+194
|
||
|
|
||
| # Precompute stratum indices for stratified bootstrap | ||
| unique_strata = np.unique(self.strata) | ||
| strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} | ||
|
|
||
| qtes = np.zeros((n_bootstrap, qte.shape[0])) | ||
| bootstrap_iter = range(n_bootstrap) | ||
| if display_progress: | ||
| bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") | ||
| for b in bootstrap_iter: | ||
| # Stratified bootstrap: resample within each stratum independently | ||
| bootstrap_indexes = np.concatenate([ | ||
| np.random.choice(idx, size=len(idx), replace=True) | ||
| for idx in strata_indices.values() | ||
| ]) | ||
|
Comment on lines
+196
to
+209
|
||
|
|
||
| qtes[b] = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates[bootstrap_indexes], | ||
| self.treatment_arms[bootstrap_indexes], | ||
| self.outcomes[bootstrap_indexes], | ||
| self.strata[bootstrap_indexes], | ||
| ) | ||
|
Comment on lines
+205
to
+219
|
||
|
|
||
| qte_var = qtes.var(axis=0) | ||
|
|
||
| qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) | ||
| qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) | ||
|
|
||
| return qte, qte_lower, qte_upper | ||
|
|
||
|
|
||
| class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase): | ||
| """A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR.""" | ||
|
|
@@ -405,6 +477,77 @@ def _compute_interval_probability( | |
|
|
||
| return prediction.mean(axis=0), prediction, superset_prediction | ||
|
|
||
| def predict_qte( | ||
| self, | ||
| target_treatment_arm: int, | ||
| control_treatment_arm: int, | ||
| quantiles: Optional[np.ndarray] = None, | ||
| alpha: float = 0.05, | ||
| n_bootstrap=500, | ||
| display_progress: bool = True, | ||
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
| """ | ||
| Compute Quantile Treatment Effects (QTE) using stratified bootstrap. | ||
|
|
||
| Uses stratified bootstrap (resampling independently within each stratum) to | ||
| correctly estimate variance under covariate adaptive randomization (CAR). | ||
|
|
||
| Args: | ||
| target_treatment_arm (int): The index of the treatment arm of the treatment group. | ||
| control_treatment_arm (int): The index of the treatment arm of the control group. | ||
| quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. | ||
| alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. | ||
| n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. | ||
| display_progress (bool, optional): Whether to display a progress bar. Defaults to True. | ||
|
|
||
| Returns: | ||
| Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: | ||
| - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile | ||
| - Lower bounds (np.ndarray): Lower confidence interval bounds | ||
| - Upper bounds (np.ndarray): Upper confidence interval bounds | ||
| """ | ||
| qte = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates, | ||
| self.treatment_arms, | ||
| self.outcomes, | ||
| self.strata, | ||
| ) | ||
|
Comment on lines
+484
to
+517
|
||
|
|
||
| # Precompute stratum indices for stratified bootstrap | ||
| unique_strata = np.unique(self.strata) | ||
| strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} | ||
|
|
||
| qtes = np.zeros((n_bootstrap, qte.shape[0])) | ||
| bootstrap_iter = range(n_bootstrap) | ||
| if display_progress: | ||
| bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") | ||
| for b in bootstrap_iter: | ||
| # Stratified bootstrap: resample within each stratum independently | ||
| bootstrap_indexes = np.concatenate([ | ||
| np.random.choice(idx, size=len(idx), replace=True) | ||
| for idx in strata_indices.values() | ||
| ]) | ||
|
|
||
| qtes[b] = self._compute_qtes( | ||
| target_treatment_arm, | ||
| control_treatment_arm, | ||
| quantiles, | ||
| self.covariates[bootstrap_indexes], | ||
| self.treatment_arms[bootstrap_indexes], | ||
| self.outcomes[bootstrap_indexes], | ||
| self.strata[bootstrap_indexes], | ||
| ) | ||
|
|
||
| qte_var = qtes.var(axis=0) | ||
|
|
||
| qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) | ||
| qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) | ||
|
|
||
| return qte, qte_lower, qte_upper | ||
|
|
||
| def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray: | ||
| if hasattr(model, "predict_proba"): | ||
| if self.is_multi_task: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the addition! But it seems this is identical to BaseEstimator.predict_qte. Does this mean #64 is already fixed by the previous refactor and DTE for CAR is working fine as it is?