From 4bd2594321da4035856b7736f9483a6e699fb32d Mon Sep 17 00:00:00 2001 From: okiner-3 Date: Wed, 29 Apr 2026 00:50:25 +0900 Subject: [PATCH] add QTE support for covariate adaptive randomization Implement predict_qte in SimpleStratifiedDistributionEstimator and AdjustedStratifiedDistributionEstimator with stratified bootstrap (resampling within each stratum independently) to correctly estimate variance under CAR designs. --- dte_adj/stratified.py | 145 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 144 insertions(+), 1 deletion(-) diff --git a/dte_adj/stratified.py b/dte_adj/stratified.py index a4038e0..88679c6 100644 --- a/dte_adj/stratified.py +++ b/dte_adj/stratified.py @@ -1,8 +1,9 @@ from __future__ import annotations import numpy as np -from typing import Tuple, Any +from typing import Optional, Tuple, Any from copy import deepcopy +from scipy.stats import norm from tqdm.auto import tqdm from dte_adj.base import DistributionEstimatorBase from dte_adj.util import ArrayLike, _convert_to_ndarray @@ -153,6 +154,77 @@ def _compute_interval_probability( conditional_prediction[:, 1:] - conditional_prediction[:, :-1], ) + def predict_qte( + self, + target_treatment_arm: int, + control_treatment_arm: int, + quantiles: Optional[np.ndarray] = None, + alpha: float = 0.05, + n_bootstrap=500, + display_progress: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute Quantile Treatment Effects (QTE) using stratified bootstrap. + + Uses stratified bootstrap (resampling independently within each stratum) to + correctly estimate variance under covariate adaptive randomization (CAR). + + Args: + target_treatment_arm (int): The index of the treatment arm of the treatment group. + control_treatment_arm (int): The index of the treatment arm of the control group. + quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. + n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. + display_progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: + - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile + - Lower bounds (np.ndarray): Lower confidence interval bounds + - Upper bounds (np.ndarray): Upper confidence interval bounds + """ + qte = self._compute_qtes( + target_treatment_arm, + control_treatment_arm, + quantiles, + self.covariates, + self.treatment_arms, + self.outcomes, + self.strata, + ) + + # Precompute stratum indices for stratified bootstrap + unique_strata = np.unique(self.strata) + strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} + + qtes = np.zeros((n_bootstrap, qte.shape[0])) + bootstrap_iter = range(n_bootstrap) + if display_progress: + bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") + for b in bootstrap_iter: + # Stratified bootstrap: resample within each stratum independently + bootstrap_indexes = np.concatenate([ + np.random.choice(idx, size=len(idx), replace=True) + for idx in strata_indices.values() + ]) + + qtes[b] = self._compute_qtes( + target_treatment_arm, + control_treatment_arm, + quantiles, + self.covariates[bootstrap_indexes], + self.treatment_arms[bootstrap_indexes], + self.outcomes[bootstrap_indexes], + self.strata[bootstrap_indexes], + ) + + qte_var = qtes.var(axis=0) + + qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) + qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) + + return qte, qte_lower, qte_upper + class AdjustedStratifiedDistributionEstimator(DistributionEstimatorBase): """A class is for estimating the adjusted distribution function and computing the Distributional parameters for CAR.""" @@ -405,6 +477,77 @@ def _compute_interval_probability( return prediction.mean(axis=0), prediction, superset_prediction + def predict_qte( + self, + target_treatment_arm: int, + control_treatment_arm: int, + quantiles: Optional[np.ndarray] = None, + alpha: float = 0.05, + n_bootstrap=500, + display_progress: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute Quantile Treatment Effects (QTE) using stratified bootstrap. + + Uses stratified bootstrap (resampling independently within each stratum) to + correctly estimate variance under covariate adaptive randomization (CAR). + + Args: + target_treatment_arm (int): The index of the treatment arm of the treatment group. + control_treatment_arm (int): The index of the treatment arm of the control group. + quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1, 0.2, ..., 0.9]. + alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. + n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500. + display_progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: + - Expected QTEs (np.ndarray): Treatment effect estimates at each quantile + - Lower bounds (np.ndarray): Lower confidence interval bounds + - Upper bounds (np.ndarray): Upper confidence interval bounds + """ + qte = self._compute_qtes( + target_treatment_arm, + control_treatment_arm, + quantiles, + self.covariates, + self.treatment_arms, + self.outcomes, + self.strata, + ) + + # Precompute stratum indices for stratified bootstrap + unique_strata = np.unique(self.strata) + strata_indices = {s: np.where(self.strata == s)[0] for s in unique_strata} + + qtes = np.zeros((n_bootstrap, qte.shape[0])) + bootstrap_iter = range(n_bootstrap) + if display_progress: + bootstrap_iter = tqdm(bootstrap_iter, desc="Bootstrap QTE") + for b in bootstrap_iter: + # Stratified bootstrap: resample within each stratum independently + bootstrap_indexes = np.concatenate([ + np.random.choice(idx, size=len(idx), replace=True) + for idx in strata_indices.values() + ]) + + qtes[b] = self._compute_qtes( + target_treatment_arm, + control_treatment_arm, + quantiles, + self.covariates[bootstrap_indexes], + self.treatment_arms[bootstrap_indexes], + self.outcomes[bootstrap_indexes], + self.strata[bootstrap_indexes], + ) + + qte_var = qtes.var(axis=0) + + qte_lower = qte + norm.ppf(alpha / 2) * np.sqrt(qte_var) + qte_upper = qte + norm.ppf(1 - alpha / 2) * np.sqrt(qte_var) + + return qte, qte_lower, qte_upper + def _compute_model_prediction(self, model, covariates: np.ndarray) -> np.ndarray: if hasattr(model, "predict_proba"): if self.is_multi_task: