diff --git a/.gitignore b/.gitignore index 922c724a..ec45ac1c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ vci_job_* lightning_logs/ outputs/ log/ +logs/ +debugging/ uv.lock tmp/ notebooks/ diff --git a/README.md b/README.md index 8c35a7dd..a4a8d926 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,29 @@ Optional downsampling overrides: - `data.kwargs.downsample` downsample counts during loading (`<=1` keeps that fraction by binomial sampling, `>1` targets that read depth per cell for `output_space=all`). - `data.kwargs.downsample_cells` caps the number of cells loaded per `(cell_type, perturbation[, batch])` group while leaving smaller groups unchanged. +### ST training contract: `embed_key`, `output_space`, decoder, and metrics + +For ST (`state tx train`), two settings jointly define target tensors, decoder usage, and logged losses: +- `data.kwargs.embed_key` +- `data.kwargs.output_space` + +The current contract is: + +| `embed_key` | `output_space` | Main target/loss space | Decoder target (if present) | Logged metrics | +|---|---|---|---|---| +| `X_hvg` or `null` | `gene` | Expression (HVG/count space) | None expected | `train/expression_loss`, `val/expression_loss` | +| non-`X_hvg` embedding key (for example `X_state`, `X_uce`) | `gene` | Embedding space | Expression (HVG/count space) | `train/embedding_loss`, `val/embedding_loss`, plus `train/expression_loss`, `val/expression_loss` | +| non-`X_hvg` embedding key | `all` | Embedding space | Expression (full transcriptome) | `train/embedding_loss`, `val/embedding_loss`, plus `train/expression_loss`, `val/expression_loss` | +| any | `embedding` | Embedding space | Disabled | `train/embedding_loss`, `val/embedding_loss` | + +Notes: +- Main loss is always computed from the model output against `batch["pert_cell_emb"]`. +- Decoder loss (expression loss) is only logged when decoder targets are available in `batch["pert_cell_counts"]`. +- With `output_space=embedding`, decoder is disabled. +- Best-checkpoint monitoring is: + - `val/expression_loss` when `output_space` is `gene` or `all` + - `val/embedding_loss` when `output_space` is `embedding` + ### predict Evaluates a trained run with `cell-eval` metrics (or just runs prediction with `--predict-only`). diff --git a/pyproject.toml b/pyproject.toml index 3f53bfb1..d03cdd42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arc-state" -version = "0.10.4" +version = "1.0.0" description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts." readme = "README.md" authors = [ diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 75d81b73..86e1911b 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -443,12 +443,21 @@ def pad_adata_with_tsv( cell_set_len = args.max_set_len if args.max_set_len is not None else getattr(model, "cell_sentence_len", 256) uses_batch_encoder = getattr(model, "batch_encoder", None) is not None output_space = getattr(model, "output_space", cfg.get("data", {}).get("kwargs", {}).get("output_space", "gene")) + nb_loss_enabled = bool(getattr(model, "nb_loss", cfg.get("model", {}).get("kwargs", {}).get("nb_loss", False))) + if nb_loss_enabled and output_space == "embedding": + raise ValueError( + "model.kwargs.nb_loss=True is incompatible with data.kwargs.output_space='embedding'. " + "Use output_space='gene' or output_space='all'." 
+ ) + if nb_loss_enabled and output_space not in {"gene", "all"}: + raise ValueError(f"nb_loss=True requires output_space in {{'gene', 'all'}}; got {output_space!r}.") if not args.quiet: print(f"Model device: {device}") print(f"Model cell_set_len (max sequence length): {cell_set_len}") print(f"Model uses batch encoder: {bool(uses_batch_encoder)}") print(f"Model output space: {output_space}") + print(f"Model nb_loss enabled: {nb_loss_enabled}") # ----------------------- # 3) Load AnnData @@ -720,7 +729,7 @@ def pad_adata_with_tsv( store_raw_expression = (args.embed_key is not None and args.embed_key != "X_hvg" and output_space == "gene") or ( args.embed_key is not None and output_space == "all" ) - counts_expected = store_raw_expression + counts_expected = store_raw_expression or nb_loss_enabled counts_out_target: Optional[str] = None counts_obsm_key: Optional[str] = None sim_counts: Optional[np.ndarray] = None @@ -916,24 +925,28 @@ def group_control_indices(group_name: str) -> np.ndarray: start = end # next window - # Clip gene-space predictions to keep downstream eval consistent. + # Clip legacy decoder outputs only; NB count outputs remain unclipped. if output_space in {"gene", "all"}: - if out_target == "X": - clip_array(sim_X) - elif out_target.startswith("obsm['") and out_target.endswith("']"): - pred_key = out_target[6:-2] - if writes_to[0] == ".obsm" and pred_key == writes_to[1]: - clip_array(sim_obsm) - elif pred_key in adata.obsm: - clip_array(adata.obsm[pred_key]) + if nb_loss_enabled: + if not args.quiet: + print("nb_loss=True: skipping clipping of simulated outputs.") else: - if writes_to[0] == ".X": + if out_target == "X": clip_array(sim_X) + elif out_target.startswith("obsm['") and out_target.endswith("']"): + pred_key = out_target[6:-2] + if writes_to[0] == ".obsm" and pred_key == writes_to[1]: + clip_array(sim_obsm) + elif pred_key in adata.obsm: + clip_array(adata.obsm[pred_key]) else: - clip_array(sim_obsm) + if writes_to[0] == ".X": + clip_array(sim_X) + else: + clip_array(sim_obsm) - if counts_written and sim_counts is not None: - clip_array(sim_counts) + if counts_written and sim_counts is not None: + clip_array(sim_counts) # ----------------------- # 5) Persist the updated AnnData diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 7c88929e..9182c0d8 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -227,6 +227,46 @@ def ensure_list(values, batch_size: int): data_module.test_datasets = [] data_module._setup_global_maps() data_module.setup(stage="test") + nb_loss_enabled = bool(cfg.get("model", {}).get("kwargs", {}).get("nb_loss", False)) + output_space = cfg.get("data", {}).get("kwargs", {}).get("output_space", "gene") + if nb_loss_enabled and output_space == "embedding": + raise ValueError( + "model.kwargs.nb_loss=True is incompatible with data.kwargs.output_space='embedding'. " + "Use output_space='gene' or output_space='all'." + ) + if nb_loss_enabled and output_space not in {"gene", "all"}: + raise ValueError( + f"model.kwargs.nb_loss=True requires data.kwargs.output_space in {{'gene', 'all'}}; got {output_space!r}." + ) + if nb_loss_enabled: + resolved_is_log1p = bool(getattr(data_module, "is_log1p", cfg["data"]["kwargs"].get("is_log1p", False))) + expected_exp_counts = resolved_is_log1p + current_exp_counts = bool(getattr(data_module, "exp_counts", False)) + if current_exp_counts != expected_exp_counts: + logger.warning( + "nb_loss=True requires exp_counts to follow is_log1p. 
" + "Resolved is_log1p=%s, overriding exp_counts %s -> %s.", + resolved_is_log1p, + current_exp_counts, + expected_exp_counts, + ) + data_module.exp_counts = expected_exp_counts + if data_module.embed_key not in {None, "X_hvg"} and not bool(getattr(data_module, "store_raw_basal", False)): + logger.warning( + "nb_loss=True with embed_key=%r and store_raw_basal=False. " + "NB library-size estimation will fall back to ctrl_cell_emb.", + data_module.embed_key, + ) + cfg["data"]["kwargs"]["is_log1p"] = resolved_is_log1p + cfg["data"]["kwargs"]["exp_counts"] = expected_exp_counts + resolved_exp_counts = bool(getattr(data_module, "exp_counts", cfg["data"]["kwargs"].get("exp_counts", False))) + metrics_is_log1p = not (nb_loss_enabled or resolved_exp_counts) + logger.info( + "Metrics config: setting pdex is_log1p=%s (nb_loss=%s, exp_counts=%s)", + metrics_is_log1p, + nb_loss_enabled, + resolved_exp_counts, + ) logger.info("Loaded data module from %s", data_module_path) # Seed everything @@ -250,10 +290,6 @@ def ensure_list(values, batch_size: int): from ...tx.models.embed_sum import EmbedSumPerturbationModel ModelClass = EmbedSumPerturbationModel - elif model_class_name.lower() == "old_neuralot": - from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel - - ModelClass = OldNeuralOTPerturbationModel elif model_class_name.lower() in ["neuralot", "pertsets", "state"]: from ...tx.models.state_transition import StateTransitionPerturbationModel @@ -337,6 +373,12 @@ def ensure_list(values, batch_size: int): and data_module.embed_key != "X_hvg" and cfg["data"]["kwargs"]["output_space"] == "gene" ) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all") + use_count_outputs = store_raw_expression or nb_loss_enabled + if nb_loss_enabled and not store_raw_expression: + logger.info( + "nb_loss=True: forcing prediction artifacts to use NB count outputs even though " + "store_raw_expression would otherwise be disabled." 
+ ) if args.pseudobulk: logger.info("Pseudobulk enabled; aggregating running means by (context, perturbation).") @@ -348,13 +390,13 @@ def ensure_list(values, batch_size: int): os.makedirs(results_dir, exist_ok=True) pseudo_x_dim = None - if store_raw_expression: - if cfg["data"]["kwargs"]["output_space"] == "gene": + if use_count_outputs: + if output_space == "gene": pseudo_x_dim = hvg_dim - elif cfg["data"]["kwargs"]["output_space"] == "all": + elif output_space == "all": pseudo_x_dim = gene_dim else: - raise ValueError(f"Unsupported output_space for pseudobulk: {cfg['data']['kwargs']['output_space']}") + raise ValueError(f"Unsupported output_space for pseudobulk: {output_space}") pb_groups: dict[tuple[str, str], dict] = {} context_mode = None @@ -402,7 +444,7 @@ def ensure_list(values, batch_size: int): batch_real_gene_np = None batch_gene_pred_np = None - if store_raw_expression: + if use_count_outputs: batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32) batch_gene_pred_np = batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32) @@ -427,10 +469,8 @@ def ensure_list(values, batch_size: int): "count": 0, "pred_sum": np.zeros(output_dim, dtype=np.float64), "real_sum": np.zeros(output_dim, dtype=np.float64), - "x_hvg_sum": np.zeros(pseudo_x_dim, dtype=np.float64) if store_raw_expression else None, - "counts_pred_sum": ( - np.zeros(pseudo_x_dim, dtype=np.float64) if store_raw_expression else None - ), + "x_hvg_sum": np.zeros(pseudo_x_dim, dtype=np.float64) if use_count_outputs else None, + "counts_pred_sum": np.zeros(pseudo_x_dim, dtype=np.float64) if use_count_outputs else None, } pb_groups[(context_label, pert_name)] = entry elif entry["celltype_name"] != current_celltype: @@ -442,7 +482,7 @@ def ensure_list(values, batch_size: int): entry["count"] += int(idx_arr.size) entry["pred_sum"] += batch_pred_np[idx_arr].sum(axis=0, dtype=np.float64) entry["real_sum"] += batch_real_np[idx_arr].sum(axis=0, dtype=np.float64) - if store_raw_expression: + if use_count_outputs: entry["x_hvg_sum"] += batch_real_gene_np[idx_arr].sum(axis=0, dtype=np.float64) entry["counts_pred_sum"] += batch_gene_pred_np[idx_arr].sum(axis=0, dtype=np.float64) @@ -462,8 +502,8 @@ def ensure_list(values, batch_size: int): pred_bulk = np.empty((n_groups, output_dim), dtype=np.float32) real_bulk = np.empty((n_groups, output_dim), dtype=np.float32) - pred_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if store_raw_expression else None - real_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if store_raw_expression else None + pred_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if use_count_outputs else None + real_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if use_count_outputs else None reserved_obs_keys = { str(data_module.pert_col), @@ -495,7 +535,7 @@ def ensure_list(values, batch_size: int): denom = float(count) pred_bulk[idx, :] = (entry["pred_sum"] / denom).astype(np.float32, copy=False) real_bulk[idx, :] = (entry["real_sum"] / denom).astype(np.float32, copy=False) - if store_raw_expression: + if use_count_outputs: pred_x[idx, :] = (entry["counts_pred_sum"] / denom).astype(np.float32, copy=False) real_x[idx, :] = (entry["x_hvg_sum"] / denom).astype(np.float32, copy=False) @@ -508,18 +548,22 @@ def ensure_list(values, batch_size: int): obs_dict[data_module.batch_col].append(entry["batch_name"]) obs = pd.DataFrame(obs_dict) - if store_raw_expression: + if use_count_outputs: adata_pred = anndata.AnnData(X=pred_x, obs=obs) adata_real 
= anndata.AnnData(X=real_x, obs=obs) - adata_pred.obsm[data_module.embed_key] = pred_bulk - adata_real.obsm[data_module.embed_key] = real_bulk + if data_module.embed_key is not None: + adata_pred.obsm[data_module.embed_key] = pred_bulk + adata_real.obsm[data_module.embed_key] = real_bulk else: adata_pred = anndata.AnnData(X=pred_bulk, obs=obs) adata_real = anndata.AnnData(X=real_bulk, obs=obs) - clip_anndata_values(adata_pred, max_value=14.0) - clip_anndata_values(adata_real, max_value=14.0) - logger.info("Clipped pseudobulk adata_pred and adata_real X values to [0.0, 14.0].") + if nb_loss_enabled: + logger.info("nb_loss=True in run config; skipping pseudobulk clipping of adata_pred/adata_real X values.") + else: + clip_anndata_values(adata_pred, max_value=14.0) + clip_anndata_values(adata_real, max_value=14.0) + logger.info("Clipped pseudobulk adata_pred and adata_real X values to [0.0, 14.0].") if args.shared_only: try: @@ -566,7 +610,7 @@ def ensure_list(values, batch_size: int): f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" ) - pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + pdex_kwargs = dict(exp_post_agg=True, is_log1p=metrics_is_log1p) for ct in ct_split_real.keys(): real_ct = ct_split_real[ct] pred_ct = ct_split_pred[ct] @@ -608,12 +652,12 @@ def ensure_list(values, batch_size: int): final_X_hvg = None final_pert_cell_counts_preds = None - if store_raw_expression: + if use_count_outputs: # Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions. - if cfg["data"]["kwargs"]["output_space"] == "gene": + if output_space == "gene": final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32) final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32) - if cfg["data"]["kwargs"]["output_space"] == "all": + if output_space == "all": final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32) final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32) @@ -719,9 +763,10 @@ def ensure_list(values, batch_size: int): adata_real = anndata.AnnData(X=final_X_hvg, obs=obs) # add the embedding predictions - adata_pred.obsm[data_module.embed_key] = final_preds - adata_real.obsm[data_module.embed_key] = final_reals - logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") + if data_module.embed_key is not None: + adata_pred.obsm[data_module.embed_key] = final_preds + adata_real.obsm[data_module.embed_key] = final_reals + logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") else: # if len(gene_names) != final_preds.shape[1]: # gene_names = np.load( @@ -736,10 +781,13 @@ def ensure_list(values, batch_size: int): # adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) adata_real = anndata.AnnData(X=final_reals, obs=obs) - # Clip extreme values to keep cell-eval log1p checks happy. - clip_anndata_values(adata_pred, max_value=14.0) - clip_anndata_values(adata_real, max_value=14.0) - logger.info("Clipped adata_pred and adata_real X values to [0.0, 14.0] before evaluation.") + # Clip extreme values to keep cell-eval log1p checks happy unless NB loss is active. 
+ if nb_loss_enabled: + logger.info("nb_loss=True in run config; skipping clipping of adata_pred/adata_real X values.") + else: + clip_anndata_values(adata_pred, max_value=14.0) + clip_anndata_values(adata_real, max_value=14.0) + logger.info("Clipped adata_pred and adata_real X values to [0.0, 14.0] before evaluation.") # Optionally filter to perturbations seen in at least one training context if args.shared_only: @@ -798,7 +846,7 @@ def ensure_list(values, batch_size: int): f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" ) - pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + pdex_kwargs = dict(exp_post_agg=True, is_log1p=metrics_is_log1p) for ct in ct_split_real.keys(): real_ct = ct_split_real[ct] diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 2a939d6e..ddab0b95 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -11,20 +11,17 @@ def add_arguments_train(parser: ap.ArgumentParser): def run_tx_train(cfg: DictConfig): - import json import logging import os import pickle import shutil from os.path import exists, join - from pathlib import Path import lightning.pytorch as pl import torch from cell_load.data_modules import PerturbationDataModule from cell_load.utils.modules import get_datamodule from lightning.pytorch.loggers import WandbLogger - from lightning.pytorch.plugins.precision import MixedPrecision from ...tx.callbacks import ( BatchSpeedMonitorCallback, @@ -66,52 +63,39 @@ def run_tx_train(cfg: DictConfig): try: sentence_len = cfg["model"]["cell_set_len"] except KeyError: - if cfg["model"]["name"].lower() in ["cpa", "scvi"] or cfg["model"]["name"].lower().startswith("scgpt"): - if "cell_sentence_len" in cfg["model"]["kwargs"] and cfg["model"]["kwargs"]["cell_sentence_len"] > 1: - sentence_len = cfg["model"]["kwargs"]["cell_sentence_len"] - cfg["training"]["batch_size"] = 1 - else: - sentence_len = 1 - else: - try: - sentence_len = cfg["model"]["kwargs"]["transformer_backbone_kwargs"]["n_positions"] - except: - sentence_len = cfg["model"]["kwargs"]["transformer_backbone_kwargs"]["max_position_embeddings"] - - if cfg["model"]["name"].lower().startswith("scgpt"): # scGPT uses log-normalized expression - cfg["data"]["kwargs"]["transform"] = "log-normalize" - cfg["data"]["kwargs"]["hvg_names_uns_key"] = ( - "hvg_names" if cfg["data"]["kwargs"]["train_task"] != "replogle" else None - ) # TODO: better to not hardcode this - - cfg["data"]["kwargs"]["dataset_cls"] = "scGPTPerturbationDataset" - - model_dir = Path(cfg["model"]["kwargs"]["pretrained_path"]) - - vocab_file = model_dir / "vocab.json" - - vocab = json.load(open(vocab_file, "r")) - cfg["model"]["kwargs"]["pad_token_id"] = vocab[""] - for s in cfg["model"]["kwargs"]["special_tokens"]: - if s not in vocab: - vocab[s] = len(vocab) - - cfg["data"]["kwargs"]["vocab"] = vocab - cfg["data"]["kwargs"]["perturbation_type"] = cfg["model"]["kwargs"]["perturbation_type"] - cfg["model"]["kwargs"]["ntoken"] = len(vocab) - cfg["model"]["kwargs"]["d_model"] = cfg["model"]["kwargs"]["embsize"] - - logger.info("Added vocab and hvg_names_uns_key to data kwargs for scGPT") - - elif cfg["model"]["name"].lower() == "cpa" and cfg["model"]["kwargs"]["recon_loss"] == "gauss": - cfg["data"]["kwargs"]["transform"] = "log-normalize" - elif cfg["model"]["name"].lower() == "scvi": - cfg["data"]["kwargs"]["transform"] = None + try: + sentence_len = cfg["model"]["kwargs"]["transformer_backbone_kwargs"]["n_positions"] + except: + sentence_len = 
cfg["model"]["kwargs"]["transformer_backbone_kwargs"]["max_position_embeddings"] output_space = cfg["data"]["kwargs"].get("output_space", "gene") assert output_space in {"embedding", "gene", "all"}, ( f"data.kwargs.output_space must be one of 'embedding', 'gene', or 'all'; got {output_space!r}" ) + nb_loss_enabled = bool(cfg["model"]["kwargs"].get("nb_loss", False)) + if nb_loss_enabled and output_space == "embedding": + raise ValueError( + "model.kwargs.nb_loss=True is incompatible with data.kwargs.output_space='embedding'. " + "Use output_space='gene' or output_space='all'." + ) + if nb_loss_enabled and output_space not in {"gene", "all"}: + raise ValueError( + f"model.kwargs.nb_loss=True requires data.kwargs.output_space in {{'gene', 'all'}}; got {output_space!r}." + ) + embed_key = cfg["data"]["kwargs"].get("embed_key", None) + if nb_loss_enabled and embed_key not in {None, "X_hvg"}: + if not bool(cfg["data"]["kwargs"].get("store_raw_basal", False)): + logger.warning( + "nb_loss=True with embed_key=%r requires control counts for library-size estimation; " + "setting data.kwargs.store_raw_basal=True.", + embed_key, + ) + cfg["data"]["kwargs"]["store_raw_basal"] = True + + if output_space == "embedding": + checkpoint_monitor_metric = "val/embedding_loss" + else: + checkpoint_monitor_metric = "val/expression_loss" data_module: PerturbationDataModule = get_datamodule( cfg["data"]["name"], @@ -120,11 +104,33 @@ def run_tx_train(cfg: DictConfig): cell_sentence_len=sentence_len, ) + data_module.setup(stage="fit") + if nb_loss_enabled: + resolved_is_log1p = bool(getattr(data_module, "is_log1p", cfg["data"]["kwargs"].get("is_log1p", False))) + expected_exp_counts = resolved_is_log1p + current_exp_counts = bool(getattr(data_module, "exp_counts", False)) + if current_exp_counts != expected_exp_counts: + logger.warning( + "nb_loss=True requires exp_counts to follow is_log1p. 
" + "Resolved is_log1p=%s, overriding exp_counts %s -> %s.", + resolved_is_log1p, + current_exp_counts, + expected_exp_counts, + ) + data_module.exp_counts = expected_exp_counts + else: + logger.info( + "nb_loss=True with resolved is_log1p=%s and exp_counts=%s.", + resolved_is_log1p, + current_exp_counts, + ) + cfg["data"]["kwargs"]["is_log1p"] = resolved_is_log1p + cfg["data"]["kwargs"]["exp_counts"] = expected_exp_counts + with open(join(run_output_dir, "data_module.torch"), "wb") as f: # TODO-Abhi: only save necessary data data_module.save_state(f) - data_module.setup(stage="fit") dl = data_module.train_dataloader() print("num_workers:", dl.num_workers) print("batch size:", dl.batch_size) @@ -143,7 +149,6 @@ def run_tx_train(cfg: DictConfig): gene_dim=gene_dim, hidden_dims=hidden_dims, dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), - residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), ) # tuck it into the kwargs that will reach the LightningModule @@ -164,11 +169,6 @@ def run_tx_train(cfg: DictConfig): with open(var_dims_path, "wb") as f: pickle.dump(var_dims, f) - if cfg["model"]["name"].lower() in ["cpa", "scvi"] or cfg["model"]["name"].lower().startswith("scgpt"): - cfg["model"]["kwargs"]["n_cell_types"] = len(data_module.celltype_onehot_map) - cfg["model"]["kwargs"]["n_perts"] = len(data_module.pert_onehot_map) - cfg["model"]["kwargs"]["n_batches"] = len(data_module.batch_onehot_map) - # Create model model = get_lightning_module( cfg["model"]["name"], @@ -206,6 +206,7 @@ def run_tx_train(cfg: DictConfig): cfg["name"], cfg["training"]["val_freq"], cfg["training"].get("ckpt_every_n_steps", 4000), + monitor_metric=checkpoint_monitor_metric, ) # Add BatchSpeedMonitorCallback to log batches per second to wandb batch_speed_monitor = BatchSpeedMonitorCallback() @@ -239,15 +240,7 @@ def run_tx_train(cfg: DictConfig): logger.info("Loggers and callbacks set up.") - if cfg["model"]["name"].lower().startswith("scgpt"): - plugins = [ - MixedPrecision( - precision="bf16-mixed", - device="cuda", - ) - ] - else: - plugins = [] + plugins = [] if torch.cuda.is_available(): accelerator = "gpu" @@ -272,7 +265,7 @@ def run_tx_train(cfg: DictConfig): logger=loggers, plugins=plugins, callbacks=callbacks, - gradient_clip_val=cfg["training"]["gradient_clip_val"] if cfg["model"]["name"].lower() != "cpa" else None, + gradient_clip_val=cfg["training"]["gradient_clip_val"], accumulate_grad_batches=cfg["training"].get("gradient_accumulation_steps", 1), use_distributed_sampler=False, ) @@ -320,7 +313,9 @@ def run_tx_train(cfg: DictConfig): ) print("Creating new decoder for the specified output space...") - if cfg["model"]["kwargs"].get("gene_decoder_bool", True) == False: + if cfg["model"]["kwargs"].get("gene_decoder_bool", True) == False or cfg["model"]["kwargs"].get( + "nb_loss", False + ): model._decoder_externally_configured = False else: # Override the decoder_cfg to match the new output_space @@ -334,7 +329,6 @@ def run_tx_train(cfg: DictConfig): gene_dim=new_gene_dim, hidden_dims=cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]), dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), - residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), ) # Update the model's decoder_cfg and rebuild decoder diff --git a/src/state/configs/model/cpa.yaml b/src/state/configs/model/cpa.yaml deleted file mode 100644 index 5f6d285c..00000000 --- a/src/state/configs/model/cpa.yaml +++ /dev/null @@ -1,26 +0,0 @@ -name: CPA -checkpoint: null -device: cuda 
- -kwargs: - n_latent: 84 - recon_loss: gauss - pert_embeddings: null - hidden_dim: 256 # not used - n_hidden_encoder: 1024 - n_layers_encoder: 5 - n_hidden_decoder: 1024 - n_layers_decoder: 4 - use_batch_norm: decoder - use_layer_norm: encoder - dropout_rate_encoder: 0.2 - dropout_rate_decoder: 0.2 - n_hidden_adv: 128 - n_layers_adv: 3 - use_norm_adv: batch - dropout_rate_adv: 0.25 - variational: False - expr_transform: none - seed: 2025 - cell_sentence_len: 512 - nb_decoder: false \ No newline at end of file diff --git a/src/state/configs/model/old_neuralot.yaml b/src/state/configs/model/old_neuralot.yaml deleted file mode 100644 index f7e1f817..00000000 --- a/src/state/configs/model/old_neuralot.yaml +++ /dev/null @@ -1,28 +0,0 @@ -name: old_neuralot -checkpoint: null -device: cuda - -kwargs: - cell_set_len: 512 - hidden_dim: 328 - loss: energy - n_encoder_layers: 4 - n_decoder_layers: 4 - predict_residual: True - softplus: True - freeze_pert_backbone: False - transformer_decoder: False - finetune_vci_decoder: False - batch_encoder: False - distributional_loss: energy - transformer_backbone_key: GPT2 - transformer_backbone_kwargs: - n_positions: ${model.kwargs.cell_set_len} - n_embd: ${model.kwargs.hidden_dim} - d_inner: 1024 - n_layer: 8 - n_head: 8 - resid_pdrop: 0.0 - embd_pdrop: 0.0 - attn_pdrop: 0.0 - use_cache: false diff --git a/src/state/configs/model/pertsets.yaml b/src/state/configs/model/pertsets.yaml index 31088591..de0b1be2 100644 --- a/src/state/configs/model/pertsets.yaml +++ b/src/state/configs/model/pertsets.yaml @@ -9,15 +9,12 @@ kwargs: blur: 0.05 hidden_dim: 328 # hidden dimension going into the transformer backbone loss: energy - confidence_token: False # if true, model tries to predict its own confidence n_encoder_layers: 4 # number of MLP layers for pert, basal encoders n_decoder_layers: 4 predict_residual: True # if true, predicts the residual in embedding space to the basal cells freeze_pert_backbone: False # if true, the perturbation model is frozen finetune_vci_decoder: False # if true, the pretrained state decoder is used in finetuning - residual_decoder: False # if true, the pretrained state decoder is used in finetuning batch_encoder: False # if true, batch variables are used - use_batch_token: False # if true, batch token is appended to the sequence decoder_loss_weight: 1.0 use_basal_projection: False mask_attn: False # if true, mask the attention diff --git a/src/state/configs/model/pseudobulk.yaml b/src/state/configs/model/pseudobulk.yaml index 9342163d..2aa64843 100644 --- a/src/state/configs/model/pseudobulk.yaml +++ b/src/state/configs/model/pseudobulk.yaml @@ -15,7 +15,6 @@ kwargs: freeze_pert_backbone: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False batch_encoder: False mask_attn: False use_effect_gating_token: False diff --git a/src/state/configs/model/scgpt-chemical.yaml b/src/state/configs/model/scgpt-chemical.yaml deleted file mode 100644 index 482c172c..00000000 --- a/src/state/configs/model/scgpt-chemical.yaml +++ /dev/null @@ -1,46 +0,0 @@ -name: scGPT-chemical -checkpoint: null -device: cuda - -kwargs: - hidden_dim: 256 # not used - pad_token: "" - special_tokens: - - "" - - "" - - "" - - pad_value: 0 - pert_pad_id: 2 - - include_zero_gene: "all" # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False - max_seq_len: 1536 - - do_MLM: true # whether to use masked language modeling, currently it is always on. 
- do_CLS: false # celltype classification objective - do_CCE: false # Contrastive cell embedding objective - do_MVC: false # Masked value prediction for cell embedding - do_ECS: false # Elastic cell similarity objective - cell_emb_style: "cls" - mvc_decoder_style: "inner product, detach" - use_amp: true - pretrained_path: "/large_storage/goodarzilab/userspace/mohsen/scGPT/scGPT_human/" - load_param_prefixes: - - "encoder" - - "value_encoder" - - "transformer_encoder" - - # settings for the model - embsize: 512 # embedding dimension - d_hid: 512 # dimension of the feedforward network model in nn.TransformerEncoder - nlayers: 12 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder - nhead: 8 # number of heads in nn.MultiheadAttention - n_layers_cls: 3 - dropout: 0.2 # dropout probability - use_fast_transformer: true # whether to use fast transformer - - expr_transform: none - perturbation_type: "chemical" - cell_sentence_len: 2048 - seed: 2025 - nb_decoder: false \ No newline at end of file diff --git a/src/state/configs/model/scgpt-genetic.yaml b/src/state/configs/model/scgpt-genetic.yaml deleted file mode 100644 index 85d02f8c..00000000 --- a/src/state/configs/model/scgpt-genetic.yaml +++ /dev/null @@ -1,46 +0,0 @@ -name: scGPT-genetic -checkpoint: null -device: cuda - -kwargs: - hidden_dim: 256 # not used - pad_token: "" - special_tokens: - - "" - - "" - - "" - - pad_value: 0 - pert_pad_id: 2 - - include_zero_gene: "all" # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False - max_seq_len: 1536 - - do_MLM: true # whether to use masked language modeling, currently it is always on. - do_CLS: false # celltype classification objective - do_CCE: false # Contrastive cell embedding objective - do_MVC: false # Masked value prediction for cell embedding - do_ECS: false # Elastic cell similarity objective - cell_emb_style: "cls" - mvc_decoder_style: "inner product, detach" - use_amp: true - pretrained_path: "/large_storage/goodarzilab/userspace/mohsen/scGPT/scGPT_human/" - load_param_prefixes: - - "encoder" - - "value_encoder" - - "transformer_encoder" - - # settings for the model - embsize: 512 # embedding dimension - d_hid: 512 # dimension of the feedforward network model in nn.TransformerEncoder - nlayers: 12 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder - nhead: 8 # number of heads in nn.MultiheadAttention - n_layers_cls: 3 - dropout: 0.2 # dropout probability - use_fast_transformer: true # whether to use fast transformer - - expr_transform: none - perturbation_type: genetic - seed: 2025 - cell_sentence_len: 2048 - nb_decoder: false \ No newline at end of file diff --git a/src/state/configs/model/scvi.yaml b/src/state/configs/model/scvi.yaml deleted file mode 100644 index 85cd3a94..00000000 --- a/src/state/configs/model/scvi.yaml +++ /dev/null @@ -1,21 +0,0 @@ -name: scVI -checkpoint: null -device: cuda - -kwargs: - n_latent: 84 - recon_loss: zinb - pert_embeddings: null - hidden_dim: 256 # not used - n_hidden_encoder: 512 - n_layers_encoder: 2 - n_hidden_decoder: 512 - n_layers_decoder: 2 - use_batch_norm: both - use_layer_norm: none - dropout_rate_encoder: 0.1 - dropout_rate_decoder: 0.1 - expr_transform: none - seed: 2025 - cell_sentence_len: 512 - nb_decoder: false \ No newline at end of file diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml index a1bae45f..671f211b 100644 --- a/src/state/configs/model/state.yaml +++ b/src/state/configs/model/state.yaml @@ -7,7 +7,7 @@ kwargs: blur: 0.05 
hidden_dim: 768 loss: energy - confidence_token: False + nb_loss: false n_encoder_layers: 1 n_decoder_layers: 1 predict_residual: True @@ -15,16 +15,12 @@ kwargs: freeze_pert_backbone: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False batch_encoder: False - use_batch_token: False mask_attn: False use_effect_gating_token: False distributional_loss: energy init_from: null - mmd_num_chunks: 1 - randomize_mmd_chunks: false llm_name: null transformer_backbone_key: llama diff --git a/src/state/configs/model/state_lg.yaml b/src/state/configs/model/state_lg.yaml index a2f2f622..a760bb88 100644 --- a/src/state/configs/model/state_lg.yaml +++ b/src/state/configs/model/state_lg.yaml @@ -15,7 +15,6 @@ kwargs: freeze_pert_backbone: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False decoder_loss_weight: 1.0 batch_encoder: False mask_attn: False diff --git a/src/state/configs/model/state_sm.yaml b/src/state/configs/model/state_sm.yaml index 2f6d156b..87cf6aac 100644 --- a/src/state/configs/model/state_sm.yaml +++ b/src/state/configs/model/state_sm.yaml @@ -15,7 +15,6 @@ kwargs: freeze_pert_backbone: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False batch_encoder: False mask_attn: False use_effect_gating_token: False @@ -23,8 +22,6 @@ kwargs: distributional_loss: energy gene_decoder_bool: False init_from: null - mmd_num_chunks: 1 - randomize_mmd_chunks: false transformer_backbone_key: llama transformer_backbone_kwargs: bidirectional_attention: true diff --git a/src/state/configs/model/tahoe_best.yaml b/src/state/configs/model/tahoe_best.yaml index 70a9cdf2..50224e42 100644 --- a/src/state/configs/model/tahoe_best.yaml +++ b/src/state/configs/model/tahoe_best.yaml @@ -15,7 +15,6 @@ kwargs: freeze_pert: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False batch_encoder: False mask_attn: False use_effect_gating_token: False diff --git a/src/state/configs/model/tahoe_llama_212693232.yaml b/src/state/configs/model/tahoe_llama_212693232.yaml index ae113528..80eb7bdd 100644 --- a/src/state/configs/model/tahoe_llama_212693232.yaml +++ b/src/state/configs/model/tahoe_llama_212693232.yaml @@ -15,7 +15,6 @@ kwargs: freeze_pert: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False decoder_loss_weight: 1.0 batch_encoder: False mask_attn: False diff --git a/src/state/configs/model/tahoe_llama_62089464.yaml b/src/state/configs/model/tahoe_llama_62089464.yaml index a985f997..e5f790a7 100644 --- a/src/state/configs/model/tahoe_llama_62089464.yaml +++ b/src/state/configs/model/tahoe_llama_62089464.yaml @@ -15,7 +15,6 @@ kwargs: freeze_pert: False transformer_decoder: False finetune_vci_decoder: False - residual_decoder: False batch_encoder: False mask_attn: False use_effect_gating_token: False diff --git a/src/state/configs/training/cpa.yaml b/src/state/configs/training/cpa.yaml deleted file mode 100644 index 63cc280c..00000000 --- a/src/state/configs/training/cpa.yaml +++ /dev/null @@ -1,20 +0,0 @@ -max_steps: 250000 -train_seed: 42 -val_freq: 5000 -test_freq: 9000 -gradient_clip_val: 10 # 0 means no clipping - -n_epochs_kl_warmup: null -n_steps_adv_warmup: 50000 -n_steps_pretrain_ae: 50000 -adv_steps: null -reg_adv: 15.0 -pen_adv: 20.0 -lr: 5e-4 -wd: 4e-7 -adv_lr: 5e-4 -adv_wd: 4e-7 -step_size_lr: 25 -do_clip_grad: false -adv_loss: "cce" -batch_size: 2048 \ No newline at end of file diff --git a/src/state/configs/training/default.yaml 
b/src/state/configs/training/default.yaml index e865403c..aae4fa9a 100644 --- a/src/state/configs/training/default.yaml +++ b/src/state/configs/training/default.yaml @@ -2,6 +2,11 @@ wandb_track: true weight_decay: 0.0005 batch_size: 16 lr: 1e-4 +optimizer: adam +use_cosine_decay: false +max_lr: null +lr_decay_steps: null +max_lr_fraction: 0.1 max_steps: 400000 train_seed: 42 val_freq: 2000 diff --git a/src/state/configs/training/scgpt.yaml b/src/state/configs/training/scgpt.yaml deleted file mode 100644 index 3a0457c9..00000000 --- a/src/state/configs/training/scgpt.yaml +++ /dev/null @@ -1,11 +0,0 @@ -max_steps: 250000 -train_seed: 42 -val_freq: 5000 -test_freq: 9000 -gradient_clip_val: 10 # 0 means no clipping - -lr: 5e-5 -wd: 4e-7 -step_size_lr: 25 -do_clip_grad: false -batch_size: 256 diff --git a/src/state/configs/training/scvi.yaml b/src/state/configs/training/scvi.yaml deleted file mode 100644 index 756a436f..00000000 --- a/src/state/configs/training/scvi.yaml +++ /dev/null @@ -1,12 +0,0 @@ -max_steps: 250000 -train_seed: 42 -val_freq: 5000 -test_freq: 9000 -gradient_clip_val: 10 # 0 means no clipping - -n_epochs_kl_warmup: 1e4 -lr: 5e-4 -wd: 4e-7 -step_size_lr: 25 -do_clip_grad: false -batch_size: 2048 \ No newline at end of file diff --git a/src/state/tx/callbacks/batch_speed_monitor.py b/src/state/tx/callbacks/batch_speed_monitor.py index 2f34323f..f001263c 100644 --- a/src/state/tx/callbacks/batch_speed_monitor.py +++ b/src/state/tx/callbacks/batch_speed_monitor.py @@ -45,22 +45,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # Log to wandb pl_module.log("batches_per_second", batches_per_second) - # Also log min, max, and coefficient of variation to help diagnose variability - if len(self.batch_times) > 1: - min_time = min(self.batch_times) - max_time = max(self.batch_times) - std_dev = (sum((t - avg_batch_time) ** 2 for t in self.batch_times) / len(self.batch_times)) ** 0.5 - cv = (std_dev / avg_batch_time) * 100 if avg_batch_time > 0 else 0 - - pl_module.log("batch_time_min", min_time) - pl_module.log("batch_time_max", max_time) - pl_module.log("batch_time_avg", avg_batch_time) - pl_module.log("batch_time_cv_percent", cv) - - # Log max/min ratio to identify extreme outliers - if min_time > 0: - pl_module.log("batch_time_max_min_ratio", max_time / min_time) - # Reset for next interval self.batch_times = [] self.last_logged_batch = batch_idx diff --git a/src/state/tx/callbacks/model_flops_utilization.py b/src/state/tx/callbacks/model_flops_utilization.py index 071e79dc..9f83a085 100644 --- a/src/state/tx/callbacks/model_flops_utilization.py +++ b/src/state/tx/callbacks/model_flops_utilization.py @@ -169,14 +169,3 @@ def on_train_batch_end(self, trainer: Trainer, pl_module: Any, outputs: Any, bat if mfu is not None: mfu = 100 * mfu pl_module.log("mfu (%)", mfu, prog_bar=True, on_step=True, on_epoch=False) - - # Log cell_sets (cell_sentences) per second - cell_sets_per_sec = metrics.get("global/samples_per_sec", metrics.get("device/samples_per_sec", None)) - if cell_sets_per_sec is not None: - pl_module.log( - "cell_sets_per_sec", - cell_sets_per_sec, - prog_bar=False, - on_step=True, - on_epoch=False, - ) diff --git a/src/state/tx/data/dataset/__init__.py b/src/state/tx/data/dataset/__init__.py deleted file mode 100644 index 133972f0..00000000 --- a/src/state/tx/data/dataset/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .scgpt_perturbation_dataset import scGPTPerturbationDataset - -__all__ = ["scGPTPerturbationDataset"] diff --git 
a/src/state/tx/data/dataset/scgpt_perturbation_dataset.py b/src/state/tx/data/dataset/scgpt_perturbation_dataset.py deleted file mode 100644 index 96a47fd2..00000000 --- a/src/state/tx/data/dataset/scgpt_perturbation_dataset.py +++ /dev/null @@ -1,344 +0,0 @@ -""" -PerturbationDataset is used to load perturbation data from h5 files. -Originally, each file was assumed to contain a single cell type. -Now, we remove that assumption so that each file (a plate) may contain -multiple cell types. -""" - -import logging -from pathlib import Path -from typing import Dict, List, Literal, Optional, Union - -import h5py -import numpy as np -import torch -from cell_load.dataset import PerturbationDataset -from cell_load.mapping_strategies import BaseMappingStrategy -from cell_load.utils.data_utils import GlobalH5MetadataCache - -logger = logging.getLogger(__name__) - - -class scGPTPerturbationDataset(PerturbationDataset): - """ - Dataset class for loading perturbation data from h5 files. - - Each instance handles a single dataset-cell_type combination. Therefore this class is responsible for - serving a single dataset/cell_type pair. Future improvements will also allow for splitting on - the perturbation level. - - Currently there are three strategies for mapping basal cells to perturbed cells: - - "batch": Basal cells are sampled from the same batch as the perturbed cell - - "random": Basal cells are sampled randomly from the same cell type as the perturbed cell - - "nearest": Basal cells are sampled from the nearest neighbors of the perturbed cell - - A control cell is always mapped to a perturbed cell within the same dataset and with the same cell type. - """ - - def __init__( - self, - name: str, - h5_path: Union[str, Path], - mapping_strategy: BaseMappingStrategy, - pert_onehot_map: Optional[Dict[str, int]] = None, - cell_type_onehot_map: Optional[Dict[str, int]] = None, - batch_onehot_map: Optional[Dict[str, int]] = None, - pert_col: str = "gene", - cell_type_key: str = "cell_type", - batch_col: str = "batch", - control_pert: str = "non-targeting", - embed_key: Literal["X_uce", "X_pca"] = "X_uce", - store_raw_expression: bool = False, - random_state: int = 42, - should_yield_control_cells: bool = True, - store_raw_basal: bool = False, - vocab: Optional[Dict[str, int]] = None, - hvg_names_uns_key: Optional[str] = None, - perturbation_type: Literal["chemical", "genetic"] = "chemical", - **kwargs, - ): - """ - Args: - name: Name of the dataset - h5_path: Path to the h5 file containing the dataset - mapping_strategy: Strategy for mapping basal cells to perturbed cells, one of "batch", "random", "nearest" - pert_onehot_map: Global mapping of perturbation names to one-hot encodings or featurizations - cell_type_onehot_map: Global mapping of cell type names to one-hot encodings or featurizations - batch_onehot_map: Global mapping of batch names to one-hot encodings - pert_col: Column in the h5 file containing perturbation information - cell_type_key: Column in the h5 file containing cell type information - batch_col: Column in the h5 file containing batch information - control_pert: Name of the control perturbation - embed_key: Key in the h5 file containing the expression data, one of "pert_cell_emb" or "X_uce" - random_state: Random seed for reproducibility - pert_tracker: PerturbationTracker instance for tracking valid perturbations - should_yield_control_cells: If True, control cells will be included in the dataset - """ - super().__init__( - name=name, - h5_path=h5_path, - 
mapping_strategy=mapping_strategy, - pert_onehot_map=pert_onehot_map, - cell_type_onehot_map=cell_type_onehot_map, - batch_onehot_map=batch_onehot_map, - pert_col=pert_col, - cell_type_key=cell_type_key, - batch_col=batch_col, - control_pert=control_pert, - embed_key=embed_key, - store_raw_expression=store_raw_expression, - random_state=random_state, - should_yield_control_cells=should_yield_control_cells, - store_raw_basal=store_raw_basal, - **kwargs, - ) - self.vocab = vocab - self.hvg_names_uns_key = hvg_names_uns_key - - assert vocab is not None, "vocab must be provided for scGPTPerturbationDataset" - - self.gene_names = self.get_gene_names() - - self.gene_ids = np.array( - [vocab[gene] if gene in vocab else vocab[""] for gene in self.gene_names], - dtype=int, - ) - - num_invalid_genes = np.sum(self.gene_ids == vocab[""]) - if num_invalid_genes > 0: - logger.warning(f"scGPTPerturbationDataset ([{self.name}]) Number of invalid genes: {num_invalid_genes}") - - self.perturbation_type = perturbation_type.lower() - - if self.perturbation_type == "genetic": - num_genes_X = len(self.gene_names) - self.pert_flags = {} - for pert in self.pert_onehot_map.keys(): - self.pert_flags[pert] = np.zeros(num_genes_X) - if pert in self.gene_names: - self.pert_flags[pert][self.gene_names.index(pert)] = 1 - else: - logger.warning( - f"scGPTPerturbationDataset ([{self.name}]) Perturbation {pert} not found in gene names" - ) - - def __getitem__(self, idx: int): - """ - Returns a dictionary with: - - 'X': the (possibly transformed) expression of the perturbed cell - - 'basal': the control cell’s expression as chosen by the mapping strategy - - 'pert': the one-hot encoding (or other featurization) for the perturbation - - 'pert_name': the perturbation name - - 'cell_type': the cell type (from the full array) - - 'gem_group': the batch (as an int or string) - - The index `idx` here is into the filtered set of cells. - """ - # Map idx to the underlying file index - underlying_idx = int(self.all_indices[idx]) - split = self._find_split_for_idx(underlying_idx) - - # Get expression from the h5 file. - # For now, we assume the data is stored in "pert_cell_emb" (could be counts) and/or in obsm (embed_key) - # (It is up to the downstream code to decide whether to use raw gene expression or a precomputed embedding.) 
- - pert_expr, ctrl_expr, ctrl_idx = self.mapping_strategy.get_mapped_expressions(self, split, underlying_idx) - - # Get perturbation information using metadata cache - pert_code = self.metadata_cache.pert_codes[underlying_idx] - pert_name = self.metadata_cache.pert_categories[pert_code] - if self.pert_onehot_map is not None: - # map across all files to a consistent one hot encoding or featurization - pert_onehot = self.pert_onehot_map[pert_name] - else: - pert_onehot = None - - # Get cell type using metadata cache - cell_type_code = self.metadata_cache.cell_type_codes[underlying_idx] - cell_type = self.metadata_cache.cell_type_categories[cell_type_code] - - if self.cell_type_onehot_map is not None: - cell_type_onehot = self.cell_type_onehot_map[cell_type] - else: - cell_type_onehot = None - - # Get batch information - batch_code = self.metadata_cache.batch_codes[underlying_idx] - batch_name = self.metadata_cache.batch_categories[batch_code] - - if self.batch_onehot_map is not None: - # map across all files to a consistent one hot encoding or featurization - batch = self.batch_onehot_map[batch_name] - else: - batch = None - - sample = { - "pert_cell_emb": pert_expr, # the perturbed cell’s data - "ctrl_cell_emb": ctrl_expr, # will be filled in by the mapping strategy - "pert_emb": pert_onehot, - "pert_name": pert_name, - "cell_type": cell_type, - "cell_type_onehot": cell_type_onehot, - "batch": batch, - "batch_name": batch_name, - "gene_ids": torch.tensor( - self.gene_ids, dtype=torch.long - ), # TODO: should be a more efficient way to do this as this is repeated for every cell - } - - if "perturbation_type" in self.__dict__ and self.perturbation_type == "genetic": - sample["pert_flags"] = torch.tensor(self.pert_flags[pert_name], dtype=torch.long) - - # Optionally, if raw gene expression is needed: - # backwards compatibility for old cktps - if self.store_raw_expression and self.output_space == "gene": - sample["pert_cell_counts"] = self.fetch_obsm_expression(underlying_idx, "X_hvg") - elif self.store_raw_expression and self.output_space == "all": - sample["pert_cell_counts"] = self.fetch_gene_expression(underlying_idx) - return sample - - def fetch_obsm_expression(self, idx: int, key: str) -> torch.Tensor: - row_data = self.h5_file[f"/obsm/{key}"][idx] - return torch.tensor(row_data, dtype=torch.float32) - - def get_gene_names(self) -> List[str]: - """ - Get the gene names, which are under adata.var.index, using h5. - """ - if self.hvg_names_uns_key is not None: # return hvg names if provided - hvg_names = self.h5_file[f"uns/{self.hvg_names_uns_key}"][:].astype(str).tolist() - return hvg_names - - try: - genes = self.h5_file["var/gene_name"][:].astype(str).tolist() - - # TODO: handle raw exception - except: - try: - categories = self.h5_file["var/gene_name/categories"][:].astype(str) - codes = self.h5_file["var/gene_name/codes"][:] - genes = categories[codes].tolist() - - # TODO: handle raw exception - except: - genes = self.h5_file["var/_index"][:].astype(str).tolist() - - return genes - - ############################## - # Static methods - ############################## - @staticmethod - def collate_fn(batch, transform=None, pert_col="drug", int_counts=False): - """ - Custom collate that reshapes data into sequences. - Safely handles normalization when vectors sum to zero. 
- """ - # First do normal collation - batch_dict = { - "pert_cell_emb": torch.stack([item["pert_cell_emb"] for item in batch]), - "ctrl_cell_emb": torch.stack([item["ctrl_cell_emb"] for item in batch]), - "pert_emb": torch.stack([item["pert_emb"] for item in batch]), - "pert_name": [item["pert_name"] for item in batch], - "cell_type": [item["cell_type"] for item in batch], - "cell_type_onehot": torch.stack([item["cell_type_onehot"] for item in batch]), - "batch": torch.stack([item["batch"] for item in batch]), - "batch_name": [item["batch_name"] for item in batch], - "gene_ids": torch.stack([item["gene_ids"] for item in batch]), - } - - if "pert_flags" in batch[0]: # only add pert_flags in case of genetic perturbations - batch_dict["pert_flags"] = torch.stack([item["pert_flags"] for item in batch]) - - # If the first sample has "pert_cell_counts", assume the entire batch does - if "pert_cell_counts" in batch[0]: - X_hvg = torch.stack([item["pert_cell_counts"] for item in batch]) - - # Handle Tahoe dataset (needs log transform) - if pert_col == "drug" or pert_col == "drugname_drugconc": - if transform == "log-normalize": - library_sizes = X_hvg.sum( - dim=1, keepdim=True - ) # TODO: Need to replace with library size from all genes - # Replace zeros with ones (will result in no change for zero vectors) - safe_sizes = torch.where(library_sizes > 0, library_sizes, torch.ones_like(library_sizes) * 10000) - X_hvg_norm = X_hvg * 10000 / safe_sizes - batch_dict["pert_cell_counts"] = torch.log1p(X_hvg_norm) - elif transform == "log1p" or transform is True: - batch_dict["pert_cell_counts"] = torch.log1p(X_hvg) - elif int_counts: - # this is for log transformed data. let's make it count data - batch_dict["pert_cell_counts"] = torch.expm1(X_hvg).round().to(torch.int32) - - # If the first sample has "ctrl_cell_counts", assume the entire batch does - if "ctrl_cell_counts" in batch[0]: # either control hvg gene space or 19k gene space - basal_hvg = torch.stack([item["ctrl_cell_counts"] for item in batch]) - - # Handle Tahoe dataset (needs log transform) - if pert_col == "drug" or pert_col == "drugname_drugconc": - if transform == "log-normalize": - library_sizes = basal_hvg.sum( - dim=1, keepdim=True - ) # TODO: Need to replace with library size from all genes - # Replace zeros with ones (will result in no change for zero vectors) - safe_sizes = torch.where(library_sizes > 0, library_sizes, torch.ones_like(library_sizes) * 10000) - basal_hvg_norm = basal_hvg * 10000 / safe_sizes - batch_dict["ctrl_cell_counts"] = torch.log1p(basal_hvg_norm) - elif transform == "log1p" or transform is True: - batch_dict["ctrl_cell_counts"] = torch.log1p(basal_hvg) - elif int_counts: - # this is for log transformed data. 
let's make it count data - batch_dict["ctrl_cell_counts"] = torch.expm1(basal_hvg).round().to(torch.int32) - else: - batch_dict["ctrl_cell_counts"] = basal_hvg - # Apply transform if provided - if transform == "log-normalize": - X_library_sizes = batch_dict["pert_cell_emb"].sum(dim=1, keepdim=True) - X_safe_sizes = torch.where(X_library_sizes > 0, X_library_sizes, torch.ones_like(X_library_sizes) * 10000) - X_norm = batch_dict["pert_cell_emb"] * 10000 / X_safe_sizes - batch_dict["pert_cell_emb"] = torch.log1p(X_norm) - - # Normalize basal by library size before log transform - basal_library_sizes = batch_dict["ctrl_cell_emb"].sum(dim=1, keepdim=True) - basal_safe_sizes = torch.where( - basal_library_sizes > 0, basal_library_sizes, torch.ones_like(basal_library_sizes) * 10000 - ) - basal_norm = batch_dict["ctrl_cell_emb"] * 10000 / basal_safe_sizes - batch_dict["ctrl_cell_emb"] = torch.log1p(basal_norm) - elif transform == "log1p" or transform is True: # True is for backwards compatibility - # Original behavior: just log transform without normalization - batch_dict["pert_cell_emb"] = torch.log1p(batch_dict["pert_cell_emb"]) - batch_dict["ctrl_cell_emb"] = torch.log1p(batch_dict["ctrl_cell_emb"]) - - return batch_dict - - def __len__(self) -> int: - return self.n_cells - - def __getstate__(self): - """ - Return a dictionary of this dataset's state without the open h5 file object. - """ - # Copy the object's dict - state = self.__dict__.copy() - # Remove the open file object if it exists - if "h5_file" in state: - # We'll also store whether it's currently open, so that we can re-open later if needed - del state["h5_file"] - return state - - def __setstate__(self, state): - """ - Reconstruct the dataset after unpickling. Re-open the HDF5 file by path. - """ - # TODO-Abhi: remove this before release - self.__dict__.update(state) - # This ensures that after we unpickle, we have a valid h5_file handle again - self.h5_file = h5py.File(self.h5_path, "r") - self.metadata_cache = GlobalH5MetadataCache().get_cache( - str(self.h5_path), - self.pert_col, - self.cell_type_key, - self.control_pert, - self.batch_col, - ) diff --git a/src/state/tx/models/__init__.py b/src/state/tx/models/__init__.py index 850f30bb..758a9e35 100644 --- a/src/state/tx/models/__init__.py +++ b/src/state/tx/models/__init__.py @@ -3,7 +3,6 @@ from .decoder_only import DecoderOnlyPerturbationModel from .embed_sum import EmbedSumPerturbationModel from .perturb_mean import PerturbMeanPerturbationModel -from .old_neural_ot import OldNeuralOTPerturbationModel from .state_transition import StateTransitionPerturbationModel from .pseudobulk import PseudobulkPerturbationModel @@ -13,7 +12,6 @@ "ContextMeanPerturbationModel", "EmbedSumPerturbationModel", "StateTransitionPerturbationModel", - "OldNeuralOTPerturbationModel", "DecoderOnlyPerturbationModel", "PseudobulkPerturbationModel", ] diff --git a/src/state/tx/models/base.py b/src/state/tx/models/base.py index e0ba1ca4..06e985bb 100644 --- a/src/state/tx/models/base.py +++ b/src/state/tx/models/base.py @@ -26,7 +26,6 @@ class LatentToGeneDecoder(nn.Module): gene_dim: Dimension of gene space (number of HVGs) hidden_dims: List of hidden layer dimensions dropout: Dropout rate - residual_decoder: If True, adds residual connections between every other layer block """ def __init__( @@ -35,54 +34,32 @@ def __init__( gene_dim: int, hidden_dims: List[int] = [512, 1024], dropout: float = 0.1, - residual_decoder=False, ): super().__init__() - self.residual_decoder = residual_decoder + layers = [] 
+ input_dim = latent_dim - if residual_decoder: - # Build individual blocks for residual connections - self.blocks = nn.ModuleList() - input_dim = latent_dim + for hidden_dim in hidden_dims: + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.LayerNorm(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Dropout(dropout)) + input_dim = hidden_dim - for hidden_dim in hidden_dims: - block = nn.Sequential( - nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout) - ) - self.blocks.append(block) - input_dim = hidden_dim + # Final output layer + layers.append(nn.Linear(input_dim, gene_dim)) + # Make sure outputs are non-negative + layers.append(nn.ReLU()) - # Final output layer - self.final_layer = nn.Sequential(nn.Linear(input_dim, gene_dim), nn.ReLU()) - else: - # Original implementation without residual connections - layers = [] - input_dim = latent_dim - - for hidden_dim in hidden_dims: - layers.append(nn.Linear(input_dim, hidden_dim)) - layers.append(nn.LayerNorm(hidden_dim)) - layers.append(nn.GELU()) - layers.append(nn.Dropout(dropout)) - input_dim = hidden_dim - - # Final output layer - layers.append(nn.Linear(input_dim, gene_dim)) - # Make sure outputs are non-negative - layers.append(nn.ReLU()) - - self.decoder = nn.Sequential(*layers) + self.decoder = nn.Sequential(*layers) def gene_dim(self): # return the output dimension of the last layer - if self.residual_decoder: - return self.final_layer[0].out_features - else: - for module in reversed(self.decoder): - if isinstance(module, nn.Linear): - return module.out_features - return None + for module in reversed(self.decoder): + if isinstance(module, nn.Linear): + return module.out_features + return None def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -94,26 +71,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Gene expression predictions of shape [batch_size, gene_dim] """ - if self.residual_decoder: - # Apply blocks with residual connections between every other block - block_outputs = [] - current = x - - for i, block in enumerate(self.blocks): - output = block(current) - - # Add residual connection from every other previous block - # Pattern: blocks 1, 3, 5, ... get residual from blocks 0, 2, 4, ... - if i >= 1 and i % 2 == 1: # Odd-indexed blocks (1, 3, 5, ...) 
- residual_idx = i - 1 # Previous even-indexed block - output = output + block_outputs[residual_idx] - - block_outputs.append(output) - current = output - - return self.final_layer(current) - else: - return self.decoder(x) + return self.decoder(x) class PerturbationModel(ABC, LightningModule): @@ -131,6 +89,16 @@ class PerturbationModel(ABC, LightningModule): output_space: 'gene', 'all', or 'embedding' """ + @staticmethod + def _sanitize_decoder_cfg(decoder_cfg: dict | None) -> dict | None: + if decoder_cfg is None: + return None + sanitized_cfg = dict(decoder_cfg) + if "residual_decoder" in sanitized_cfg: + sanitized_cfg.pop("residual_decoder") + logger.warning("decoder_cfg.residual_decoder is deprecated and will be ignored.") + return sanitized_cfg + def __init__( self, input_dim: int, @@ -152,7 +120,11 @@ def __init__( **kwargs, ): super().__init__() - self.decoder_cfg = decoder_cfg + if "residual_decoder" in kwargs: + kwargs = dict(kwargs) + kwargs.pop("residual_decoder") + logger.warning("model.kwargs.residual_decoder is deprecated and will be ignored.") + self.decoder_cfg = self._sanitize_decoder_cfg(decoder_cfg) self.save_hyperparameters() self.gene_decoder_bool = kwargs.get("gene_decoder_bool", True) @@ -170,8 +142,6 @@ def __init__( else: self.batch_dim = None - self.residual_decoder = kwargs.get("residual_decoder", False) - self.embed_key = embed_key self.output_space = output_space if self.output_space not in {"embedding", "gene", "all"}: @@ -210,6 +180,7 @@ def _build_networks(self): def _build_decoder(self): """Create self.gene_decoder from self.decoder_cfg (or leave None).""" + self.decoder_cfg = self._sanitize_decoder_cfg(self.decoder_cfg) if self.gene_decoder_bool == False: self.gene_decoder = None return @@ -218,6 +189,28 @@ def _build_decoder(self): return self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) + def _main_loss_is_expression(self) -> bool: + """ + Determine whether the primary train/val loss is in expression/count space. + """ + if self.output_space == "embedding": + return False + return self.embed_key in {"X_hvg", None} + + def _train_main_loss_key(self) -> str: + return "train/expression_loss" if self._main_loss_is_expression() else "train/embedding_loss" + + def _val_main_loss_key(self) -> str: + return "val/expression_loss" if self._main_loss_is_expression() else "val/embedding_loss" + + @staticmethod + def _train_expression_loss_key() -> str: + return "train/expression_loss" + + @staticmethod + def _val_expression_loss_key() -> str: + return "val/expression_loss" + def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: """ Lightning calls this *before* the checkpoint's state_dict is loaded. @@ -233,66 +226,22 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: self.gene_decoder = None return - # When finetuning with the pretrained VCI decoder, keep the existing - # FinetuneVCICountsDecoder instance. Overwriting it with a freshly - # constructed LatentToGeneDecoder would make the checkpoint weights - # incompatible and surface load_state_dict errors. 
- finetune_decoder_active = False - hparams = getattr(self, "hparams", None) - if hparams is not None: - if hasattr(hparams, "get"): - finetune_decoder_active = bool(hparams.get("finetune_vci_decoder", False)) - else: - finetune_decoder_active = bool(getattr(hparams, "finetune_vci_decoder", False)) - if not finetune_decoder_active: - finetune_decoder_active = bool(getattr(self, "finetune_vci_decoder", False)) - - if finetune_decoder_active: - # Preserve decoder_cfg for completeness but avoid rebuilding the module. - if "decoder_cfg" in checkpoint.get("hyper_parameters", {}): - self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] - logger.info("Finetune VCI decoder active; keeping existing decoder during checkpoint load") + if decoder_already_configured: + logger.info("Decoder was already configured externally, skipping checkpoint decoder configuration") return - if not decoder_already_configured and "decoder_cfg" in checkpoint["hyper_parameters"]: - self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] - self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) - logger.info(f"Loaded decoder from checkpoint decoder_cfg: {self.decoder_cfg}") - elif not decoder_already_configured: - # Only fall back to old logic if no decoder_cfg was saved and not externally configured - self.decoder_cfg = None - self._build_decoder() - logger.info(f"DEBUG: output_space: {self.output_space}") - if self.gene_decoder is None: - gene_dim = self.hvg_dim if self.output_space == "gene" else self.gene_dim - logger.info(f"DEBUG: gene_dim: {gene_dim}") - if (self.embed_key and self.embed_key != "X_hvg" and self.output_space == "gene") or ( - self.embed_key and self.output_space == "all" - ): # we should be able to decode from hvg to all - logger.info("DEBUG: Creating gene_decoder, checking conditions...") - if gene_dim > 10000: - hidden_dims = [1024, 512, 256] - else: - if "DMSO_TF" in self.control_pert: - if self.residual_decoder: - hidden_dims = [2058, 2058, 2058, 2058, 2058] - else: - hidden_dims = [4096, 2048, 2048] - elif "PBS" in self.control_pert: - hidden_dims = [2048, 1024, 1024] - else: - hidden_dims = [1024, 1024, 512] # make this config - - self.gene_decoder = LatentToGeneDecoder( - latent_dim=self.output_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=self.dropout, - residual_decoder=self.residual_decoder, - ) - logger.info(f"Initialized gene decoder for embedding {self.embed_key} to gene space") - else: - logger.info("Decoder was already configured externally, skipping checkpoint decoder configuration") + checkpoint_hparams = checkpoint.get("hyper_parameters", {}) + if "decoder_cfg" in checkpoint_hparams: + self.decoder_cfg = self._sanitize_decoder_cfg(checkpoint_hparams["decoder_cfg"]) + elif self.decoder_cfg is None: + raise ValueError( + "Checkpoint is missing hyper_parameters.decoder_cfg and no decoder_cfg was provided at init. " + "Decoder configuration is required." 
+ ) + + self.decoder_cfg = self._sanitize_decoder_cfg(self.decoder_cfg) + self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) + logger.info(f"Loaded decoder from decoder_cfg: {self.decoder_cfg}") def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Training step logic for both main model and decoder.""" @@ -301,7 +250,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch # Compute main model loss main_loss = self.loss_fn(pred, batch["pert_cell_emb"]) - self.log("train_loss", main_loss) + self.log(self._train_main_loss_key(), main_loss) # Process decoder if available decoder_loss = None @@ -315,7 +264,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets) # Log decoder loss - self.log("decoder_loss", decoder_loss) + self.log(self._train_expression_loss_key(), decoder_loss) total_loss = main_loss + decoder_loss else: @@ -330,14 +279,12 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non # TODO: remove unused # is_control = self.control_pert in batch["pert_name"] - self.log("val_loss", loss) + self.log(self._val_main_loss_key(), loss) return {"loss": loss, "predictions": pred} def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: latent_output = self(batch) - target = batch[self.embed_key] - loss = self.loss_fn(latent_output, target) output_dict = { "preds": latent_output, # The distribution's sample @@ -352,10 +299,6 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: if self.gene_decoder is not None: pert_cell_counts_preds = self.gene_decoder(latent_output) output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds - decoder_loss = self.loss_fn(pert_cell_counts_preds, batch["pert_cell_counts"]) - self.log("test_decoder_loss", decoder_loss, prog_bar=True) - - self.log("test_loss", loss, prog_bar=True) def predict_step(self, batch, batch_idx, **kwargs): """ diff --git a/src/state/tx/models/context_mean.py b/src/state/tx/models/context_mean.py index 386bf0a3..472f263d 100644 --- a/src/state/tx/models/context_mean.py +++ b/src/state/tx/models/context_mean.py @@ -214,7 +214,7 @@ def training_step(self, batch, batch_idx): ) target = batch[output_key] loss = self.loss_fn(pred, target) - self.log("train_loss", loss, prog_bar=True) + self.log(self._train_main_loss_key(), loss, prog_bar=True) return None def on_save_checkpoint(self, checkpoint): diff --git a/src/state/tx/models/cpa/__init__.py b/src/state/tx/models/cpa/__init__.py deleted file mode 100644 index 1cd283fa..00000000 --- a/src/state/tx/models/cpa/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._model import CPAPerturbationModel - -__all__ = ["CPAPerturbationModel"] diff --git a/src/state/tx/models/cpa/_base_modules.py b/src/state/tx/models/cpa/_base_modules.py deleted file mode 100644 index 9475b8b6..00000000 --- a/src/state/tx/models/cpa/_base_modules.py +++ /dev/null @@ -1,345 +0,0 @@ -from typing import Literal, Optional - -import torch -import torch.nn as nn -from torch.distributions import Normal -from torch.nn import functional as F - - -class FocalLoss(nn.Module): - """Inspired by https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py - - Focal Loss, as described in https://arxiv.org/abs/1708.02002. - It is essentially an enhancement to cross entropy loss and is - useful for classification tasks when there is a large class imbalance. 
- x is expected to contain raw, unnormalized scores for each class. - y is expected to contain class labels. - Shape: - - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0. - - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0. - """ - - def __init__( - self, - alpha: Optional[torch.Tensor] = None, - gamma: float = 2.0, - reduction: str = "mean", - ): - """ - Args: - alpha (Tensor, optional): Weights for each class. Defaults to None. - gamma (float, optional): A constant, as described in the paper. - Defaults to 0. - reduction (str, optional): 'mean', 'sum' or 'none'. - Defaults to 'mean'. - """ - if reduction not in ("mean", "sum", "none"): - raise ValueError('Reduction must be one of: "mean", "sum", "none".') - - super().__init__() - self.alpha = alpha - self.gamma = gamma - self.reduction = reduction - - self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none") - - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - if len(y_true) == 0: - return torch.tensor(0.0) - - # compute weighted cross entropy term: -alpha * log(pt) - # (alpha is already part of self.nll_loss) - log_p = F.log_softmax(y_pred, dim=-1) - ce = self.nll_loss(log_p, y_true) - - # get true class column from each row - all_rows = torch.arange(len(y_pred)) - log_pt = log_p[all_rows, y_true] - - # compute focal term: (1 - pt)^gamma - pt = log_pt.exp() - focal_term = (1 - pt) ** self.gamma - - # the full loss: -alpha * ((1 - pt)^gamma) * log(pt) - loss = focal_term * ce - - if self.reduction == "mean": - loss = loss.mean() - elif self.reduction == "sum": - loss = loss.sum() - - return loss - - -class MLP(nn.Module): - def __init__( - self, - n_input, - n_output, - n_hidden, - n_layers, - activation_fn: Optional[nn.Module] = nn.ReLU, - use_norm: str = "batch", - dropout_rate: float = 0.3, - drop_norm_last_layer: bool = True, - ): - super().__init__() - if drop_norm_last_layer: - layers = [n_input] + [n_hidden] * n_layers - else: - layers = [n_input] + [n_hidden] * (n_layers - 1) + [n_output] - - network = [] - for n_in, n_out in zip(layers[:-1], layers[1:]): - network.append(nn.Linear(n_in, n_out)) - if use_norm == "batch": - network.append(nn.BatchNorm1d(n_out)) - elif use_norm == "layer": - network.append(nn.LayerNorm(n_out)) - network.append(activation_fn()) - network.append(nn.Dropout(dropout_rate)) - - if drop_norm_last_layer: - network.append(nn.Linear(n_hidden, n_output)) - - self.network = nn.Sequential(*network) - - def forward(self, x): - """ - x: (batch_size, n_input) - """ - return self.network(x) - - -class Classifier(nn.Module): - def __init__( - self, - n_input, - n_labels, - n_hidden, - n_layers, - activation_fn=nn.ReLU, - use_norm: str = "batch", - dropout_rate: float = 0.3, - ): - super().__init__() - self.n_output = n_labels - - self.network = MLP( - n_input=n_input, - n_output=n_labels, - n_layers=n_layers, - n_hidden=n_hidden, - use_norm=use_norm, - dropout_rate=dropout_rate, - activation_fn=activation_fn, - drop_norm_last_layer=True, - ) - - def forward(self, x): - y = self.network(x) - return y - - -class VariationalEncoder(nn.Module): - def __init__( - self, - n_input: int, - n_output: int, - n_layers: int = 1, - n_hidden: int = 128, - dropout_rate: float = 0.1, - use_norm: str = "batch", - var_eps: float = 1e-4, - var_activation=None, - return_dist: bool = False, - **kwargs, - ): - super().__init__() - - self.var_eps = var_eps - self.encoder = MLP( - n_input=n_input, - n_output=n_hidden, - n_layers=n_layers, - n_hidden=n_hidden, - 
dropout_rate=dropout_rate, - use_norm=use_norm, - drop_norm_last_layer=False, - ) - self.mean_encoder = nn.Linear(n_hidden, n_output) - self.var_encoder = nn.Linear(n_hidden, n_output) - self.return_dist = return_dist - - self.var_activation = torch.exp if var_activation is None else var_activation - - def forward(self, x: torch.Tensor, *cat_list: int): - """ """ - q = self.encoder(x, *cat_list) - - q_m = self.mean_encoder(q) - q_v = self.var_activation(self.var_encoder(q)) + self.var_eps - - dist = Normal(q_m, q_v.sqrt()) - latent = dist.rsample() - - if self.return_dist: - return dist, latent - - return q_m, q_v, latent - - -# Inspired by scvi-tools source code: https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/nn/_base_components.py -class CountDecoder(nn.Module): - """Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions. - - Uses a fully-connected neural network of ``n_hidden`` layers. - - Parameters - ---------- - n_input - The dimensionality of the input (latent space) - n_output - The dimensionality of the output (data space) - n_cat_list - A list containing the number of categories - for each category of interest. Each category will be - included using a one-hot encoding - n_layers - The number of fully-connected hidden layers - n_hidden - The number of nodes per hidden layer - dropout_rate - Dropout rate to apply to each of the hidden layers - inject_covariates - Whether to inject covariates in each layer, or just the first (default). - use_batch_norm - Whether to use batch norm in layers - use_layer_norm - Whether to use layer norm in layers - scale_activation - Activation layer to use for px_scale_decoder - """ - - def __init__( - self, - n_input: int, - n_output: int, - n_layers: int = 1, - n_hidden: int = 128, - use_norm: Literal["batch", "layer"] = "batch", - scale_activation: Literal["softmax", "softplus"] = "softmax", - ): - super().__init__() - self.px_decoder = MLP( - n_input=n_input, - n_output=n_hidden, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=0.0, - use_norm=use_norm, - drop_norm_last_layer=False, - ) - - # mean gamma - if scale_activation == "softmax": - px_scale_activation = nn.Softmax(dim=-1) - elif scale_activation == "softplus": - px_scale_activation = nn.Softplus() - self.px_scale_decoder = nn.Sequential( - nn.Linear(n_hidden, n_output), - px_scale_activation, - ) - - # dispersion: here we only deal with gene-cell dispersion case - self.px_r_decoder = nn.Linear(n_hidden, n_output) - - # dropout - self.px_dropout_decoder = nn.Linear(n_hidden, n_output) - - def forward( - self, - dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"], - z: torch.Tensor, - library: torch.Tensor, - ): - """The forward computation for a single sample. - - #. Decodes the data from the latent space using the decoder network - #. Returns parameters for the ZINB distribution of expression - #. 
If ``dispersion != 'gene-cell'`` then value for that param will be ``None`` - - Parameters - ---------- - z : - tensor with shape ``(n_input,)`` - library_size - library size - cat_list - list of category membership(s) for this sample - dispersion - One of the following - - * ``'gene'`` - dispersion parameter of NB is constant per gene across cells - * ``'gene-batch'`` - dispersion can differ between different batches - * ``'gene-label'`` - dispersion can differ between different labels - * ``'gene-cell'`` - dispersion can differ for every gene in every cell - - Returns - ------- - 4-tuple of :py:class:`torch.Tensor` - parameters for the ZINB distribution of expression - - """ - # The decoder returns values for the parameters of the ZINB distribution - px = self.px_decoder(z) - px_scale = self.px_scale_decoder(px) - px_dropout = self.px_dropout_decoder(px) - # Clamp to high value: exp(12) ~ 160000 to avoid nans (computational stability) - px_rate = torch.exp(library) * px_scale # torch.clamp( , max=12) - px_r = self.px_r_decoder(px) if dispersion == "gene-cell" else None - return px_scale, px_r, px_rate, px_dropout - - -class GeneralizedSigmoid(nn.Module): - """ - Sigmoid, log-sigmoid or linear functions for encoding dose-response for - drug perurbations. - """ - - def __init__(self, n_drugs, non_linearity="sigmoid"): - """Sigmoid modeling of continuous variable. - Params - ------ - nonlin : str (default: logsigm) - One of logsigm, sigm. - """ - super(GeneralizedSigmoid, self).__init__() - self.non_linearity = non_linearity - self.n_drugs = n_drugs - - self.beta = torch.nn.Parameter(torch.ones(1, n_drugs), requires_grad=True) - self.bias = torch.nn.Parameter(torch.zeros(1, n_drugs), requires_grad=True) - - self.vmap = None - - def forward(self, x, y): - """ - Parameters - ---------- - x: (batch_size, max_comb_len) - y: (batch_size, max_comb_len) - """ - y = y.long() - if self.non_linearity == "logsigm": - bias = self.bias[0][y] - beta = self.beta[0][y] - c0 = bias.sigmoid() - return (torch.log1p(x) * beta + bias).sigmoid() - c0 - elif self.non_linearity == "sigm": - bias = self.bias[0][y] - beta = self.beta[0][y] - c0 = bias.sigmoid() - return (x * beta + bias).sigmoid() - c0 - else: - return x diff --git a/src/state/tx/models/cpa/_callbacks.py b/src/state/tx/models/cpa/_callbacks.py deleted file mode 100644 index befb433e..00000000 --- a/src/state/tx/models/cpa/_callbacks.py +++ /dev/null @@ -1,29 +0,0 @@ -from lightning.pytorch.callbacks import Callback - - -class CPABestModelTracker(Callback): - def __init__(self, monitor: str = "val_loss", mode: str = "min"): - super().__init__() - self.monitor = monitor - self.mode = mode - self.best_model = None - self.best_score = None - - def on_validation_end(self, trainer, pl_module): - if self.best_score is None: - self.best_score = trainer.callback_metrics[self.monitor] - self.best_model = pl_module.state_dict() - else: - if self.mode == "min": - if trainer.callback_metrics[self.monitor] < self.best_score: - self.best_score = trainer.callback_metrics[self.monitor] - self.best_model = pl_module.state_dict() - else: - if trainer.callback_metrics[self.monitor] > self.best_score: - self.best_score = trainer.callback_metrics[self.monitor] - self.best_model = pl_module.state_dict() - - def on_train_end(self, trainer, pl_module): - pl_module.load_state_dict(self.best_model) - print(f"Best model loaded with {self.monitor} = {self.best_score}") - return self.best_model diff --git a/src/state/tx/models/cpa/_dists.py b/src/state/tx/models/cpa/_dists.py 
deleted file mode 100644 index 18423b16..00000000 --- a/src/state/tx/models/cpa/_dists.py +++ /dev/null @@ -1,387 +0,0 @@ -import warnings - -import torch -import torch.nn.functional as F -from torch.distributions import Distribution, Gamma, constraints -from torch.distributions import Poisson as PoissonTorch -from torch.distributions.constraints import Constraint -from torch.distributions.utils import ( - broadcast_all, - lazy_property, - logits_to_probs, - probs_to_logits, -) - - -class _Optional(Constraint): - def __init__(self, constraint: Constraint): - self.constraint = constraint - - def check(self, value: torch.Tensor) -> torch.Tensor: - if value is None: - return torch.ones(1, dtype=torch.bool) - return self.constraint.check(value) - - def __repr__(self) -> str: - return f"Optional({self.constraint})" - - -def optional_constraint(constraint: Constraint) -> Constraint: - """Returns a wrapped constraint that allows optional values.""" - return _Optional(constraint) - - -def log_zinb_positive( - x: torch.Tensor, - mu: torch.Tensor, - theta: torch.Tensor, - pi: torch.Tensor, - eps: float = 1e-8, -) -> torch.Tensor: - """Log likelihood (scalar) of a minibatch according to a zinb model. - - Parameters - ---------- - x - Data - mu - mean of the negative binomial (has to be positive support) (shape: minibatch x vars) - theta - inverse dispersion parameter (has to be positive support) (shape: minibatch x vars) - pi - logit of the dropout parameter (real support) (shape: minibatch x vars) - eps - numerical stability constant - - Notes - ----- - We parametrize the bernoulli using the logits, hence the softplus functions appearing. - """ - # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless - # of batch or labels) - if theta.ndimension() == 1: - theta = theta.view(1, theta.size(0)) # In this case, we reshape theta for broadcasting - - # Uses log(sigmoid(x)) = -softplus(-x) - softplus_pi = F.softplus(-pi) - log_theta_eps = torch.log(theta + eps) - log_theta_mu_eps = torch.log(theta + mu + eps) - pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps) - - case_zero = F.softplus(pi_theta_log) - softplus_pi - mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero) - - case_non_zero = ( - -softplus_pi - + pi_theta_log - + x * (torch.log(mu + eps) - log_theta_mu_eps) - + torch.lgamma(x + theta) - - torch.lgamma(theta) - - torch.lgamma(x + 1) - ) - mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero) - - res = mul_case_zero + mul_case_non_zero - - return res - - -def log_nb_positive( - x: torch.Tensor, - mu: torch.Tensor, - theta: torch.Tensor, - eps: float = 1e-8, - log_fn: callable = torch.log, - lgamma_fn: callable = torch.lgamma, -) -> torch.Tensor: - """Log likelihood (scalar) of a minibatch according to a nb model. 
- - Parameters - ---------- - x - data - mu - mean of the negative binomial (has to be positive support) (shape: minibatch x vars) - theta - inverse dispersion parameter (has to be positive support) (shape: minibatch x vars) - eps - numerical stability constant - log_fn - log function - lgamma_fn - log gamma function - """ - log = log_fn - lgamma = lgamma_fn - log_theta_mu_eps = log(theta + mu + eps) - res = ( - theta * (log(theta + eps) - log_theta_mu_eps) - + x * (log(mu + eps) - log_theta_mu_eps) - + lgamma(x + theta) - - lgamma(theta) - - lgamma(x + 1) - ) - - return res - - -def _convert_counts_logits_to_mean_disp( - total_count: torch.Tensor, logits: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """NB parameterizations conversion. - - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - logits - success logits. - - Returns - ------- - type - the mean and inverse overdispersion of the NB distribution. - - """ - theta = total_count - mu = logits.exp() * theta - return mu, theta - - -def _gamma(theta: torch.Tensor, mu: torch.Tensor) -> Gamma: - concentration = theta - rate = theta / mu - # Important remark: Gamma is parametrized by the rate = 1/scale! - gamma_d = Gamma(concentration=concentration, rate=rate) - return gamma_d - - -# Inspired from scvi-tools source code:https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/distributions/_negative_binomial.py -class NegativeBinomial(Distribution): - r"""Negative binomial distribution. - - One of the following parameterizations must be provided: - - (1), (`total_count`, `probs`) where `total_count` is the number of failures until - the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`) - parameterization, which is the one used by scvi-tools. These parameters respectively - control the mean and inverse dispersion of the distribution. - - In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as - follows: - - 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, - \underbrace{\theta/\mu}_{\text{rate}})` - 2. :math:`x \sim \textrm{Poisson}(w)` - - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - probs - The success probability. - mu - Mean of the distribution. - theta - Inverse dispersion. - scale - Normalized mean expression of the distribution. - validate_args - Raise ValueError if arguments do not match constraints - """ - - arg_constraints = { - "mu": optional_constraint(constraints.greater_than_eq(0)), - "theta": optional_constraint(constraints.greater_than_eq(0)), - "scale": optional_constraint(constraints.greater_than_eq(0)), - } - support = constraints.nonnegative_integer - - def __init__( - self, - total_count: torch.Tensor | None = None, - probs: torch.Tensor | None = None, - logits: torch.Tensor | None = None, - mu: torch.Tensor | None = None, - theta: torch.Tensor | None = None, - scale: torch.Tensor | None = None, - validate_args: bool = False, - ): - self._eps = 1e-8 - if (mu is None) == (total_count is None): - raise ValueError( - "Please use one of the two possible parameterizations. Refer to the documentation for more information." 
- ) - - using_param_1 = total_count is not None and (logits is not None or probs is not None) - if using_param_1: - logits = logits if logits is not None else probs_to_logits(probs) - total_count = total_count.type_as(logits) - total_count, logits = broadcast_all(total_count, logits) - mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits) - else: - mu, theta = broadcast_all(mu, theta) - self.mu = mu - self.theta = theta - self.scale = scale - super().__init__(validate_args=validate_args) - - @property - def mean(self) -> torch.Tensor: - return self.mu - - @property - def variance(self) -> torch.Tensor: - return self.mean + (self.mean**2) / self.theta - - @torch.inference_mode() - def sample( - self, - sample_shape: torch.Size | tuple | None = None, - ) -> torch.Tensor: - """Sample from the distribution.""" - sample_shape = sample_shape or torch.Size() - gamma_d = self._gamma() - p_means = gamma_d.sample(sample_shape) - - # Clamping as distributions objects can have buggy behaviors when - # their parameters are too high - l_train = torch.clamp(p_means, max=1e8) - counts = PoissonTorch(l_train).sample() # Shape : (n_samples, n_cells_batch, n_vars) - return counts - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - if self._validate_args: - try: - self._validate_sample(value) - except ValueError: - warnings.warn( - "The value argument must be within the support of the distribution", - UserWarning, - stacklevel=1, - ) - - return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps) - - def _gamma(self) -> Gamma: - return _gamma(self.theta, self.mu) - - def __repr__(self) -> str: - param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] - args_string = ", ".join( - [ - f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}" - for p in param_names - if self.__dict__[p] is not None - ] - ) - return self.__class__.__name__ + "(" + args_string + ")" - - -# Inspired from scvi-tools source code:https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/distributions/_negative_binomial.py -class ZeroInflatedNegativeBinomial(NegativeBinomial): - r"""Zero-inflated negative binomial distribution. - - One of the following parameterizations must be provided: - - (1), (`total_count`, `probs`) where `total_count` is the number of failures until - the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`) - parameterization, which is the one used by scvi-tools. These parameters respectively - control the mean and inverse dispersion of the distribution. - - In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as - follows: - - 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, - \underbrace{\theta/\mu}_{\text{rate}})` - 2. :math:`x \sim \textrm{Poisson}(w)` - - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - probs - The success probability. - mu - Mean of the distribution. - theta - Inverse dispersion. - zi_logits - Logits scale of zero inflation probability. - scale - Normalized mean expression of the distribution. 
- validate_args - Raise ValueError if arguments do not match constraints - """ - - arg_constraints = { - "mu": optional_constraint(constraints.greater_than_eq(0)), - "theta": optional_constraint(constraints.greater_than_eq(0)), - "zi_logits": optional_constraint(constraints.real), - "scale": optional_constraint(constraints.greater_than_eq(0)), - } - support = constraints.nonnegative_integer - - def __init__( - self, - total_count: torch.Tensor | None = None, - probs: torch.Tensor | None = None, - logits: torch.Tensor | None = None, - mu: torch.Tensor | None = None, - theta: torch.Tensor | None = None, - zi_logits: torch.Tensor | None = None, - scale: torch.Tensor | None = None, - validate_args: bool = False, - ): - super().__init__( - total_count=total_count, - probs=probs, - logits=logits, - mu=mu, - theta=theta, - scale=scale, - validate_args=validate_args, - ) - self.zi_logits, self.mu, self.theta = broadcast_all(zi_logits, self.mu, self.theta) - - @property - def mean(self) -> torch.Tensor: - pi = self.zi_probs - return (1 - pi) * self.mu - - @property - def variance(self) -> None: - raise NotImplementedError - - @lazy_property - def zi_logits(self) -> torch.Tensor: - """ZI logits.""" - return probs_to_logits(self.zi_probs, is_binary=True) - - @lazy_property - def zi_probs(self) -> torch.Tensor: - return logits_to_probs(self.zi_logits, is_binary=True) - - @torch.inference_mode() - def sample( - self, - sample_shape: torch.Size | tuple | None = None, - ) -> torch.Tensor: - """Sample from the distribution.""" - sample_shape = sample_shape or torch.Size() - samp = super().sample(sample_shape=sample_shape) - is_zero = torch.rand_like(samp) <= self.zi_probs - samp_ = torch.where(is_zero, torch.zeros_like(samp), samp) - return samp_ - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - """Log probability.""" - try: - self._validate_sample(value) - except ValueError: - warnings.warn( - "The value argument must be within the support of the distribution", - UserWarning, - stacklevel=1, - ) - return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08) diff --git a/src/state/tx/models/cpa/_model.py b/src/state/tx/models/cpa/_model.py deleted file mode 100644 index a10ec8ce..00000000 --- a/src/state/tx/models/cpa/_model.py +++ /dev/null @@ -1,731 +0,0 @@ -import math -from typing import Dict - -import torch -from torch.optim.lr_scheduler import StepLR -from torchmetrics.functional import accuracy - -from ..base import PerturbationModel - -from ._module import CPAModule - - -class CPAPerturbationModel(PerturbationModel): - """ - Implementation of the CPA model. The outputs are always in - gene expression space. 
- - Args: - input_dim: Dimension of input embeddings (either number of genes or latent dim from obsm key) - hidden_dim: Dimension of hidden layers - output_dim: Number of genes to predict - pert_dim: Dimension of perturbation inputs (usually one-hot size) - decode_intermediate_dim: Optional intermediate dimension for decoder - n_encoder_layers: Number of layers in encoder (default: 2) - n_decoder_layers: Number of layers in encoder (default: 2) - dropout: Dropout rate (default: 0.1) - learning_rate: Learning rate for optimizer (default: 1e-3) - loss_fn: Loss function (default: 'nn.MSELoss()') - """ - - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - pert_dim: int, - n_cell_types: int, - n_perts: int, - n_batches: int, - output_space: str = "gene", - encode_dosage: bool = False, - dosage_non_linearity: str = "linear", - lr=5e-4, - wd=1e-6, - n_steps_pretrain_ae: int = None, - n_epochs_pretrain_ae: int = None, - n_steps_kl_warmup: int = None, - n_epochs_kl_warmup: int = None, - n_steps_adv_warmup: int = None, - n_epochs_adv_warmup: int = None, - adv_steps: int = 3, - reg_adv: float = 1.0, - pen_adv: float = 1.0, - adv_lr=1e-3, - adv_wd=1e-6, - step_size_lr: int = 45, - do_clip_grad: bool = False, - gradient_clip_value: float = 3.0, - adv_loss: str = "cce", - check_val_every_n_epoch: int = 5, - **kwargs, - ): - # Register with parent constructor - super().__init__( - input_dim=input_dim, - hidden_dim=hidden_dim, - output_dim=output_dim, - pert_dim=pert_dim, - output_space=output_space, - **kwargs, - ) - - # Set class specific parameters before registering with parent constructor - self.n_cell_types = n_cell_types - self.n_perts = n_perts - self.n_batches = n_batches - - self.n_layers_encoder = kwargs.get("n_layers_encoder", 2) - self.n_layers_decoder = kwargs.get("n_layers_decoder", 2) - self.n_hidden_encoder = kwargs.get("n_hidden_encoder", 256) - self.n_hidden_decoder = kwargs.get("n_hidden_decoder", 256) - self.n_latent = kwargs.get("n_latent", 64) - self.recon_loss = kwargs.get("recon_loss", "nb") - - self.use_batch_norm = kwargs.get("use_batch_norm", "both") - self.use_layer_norm = kwargs.get("use_layer_norm", "none") - - self.pert_embeddings = None # will be set in _build_networks - self.encode_dosage = encode_dosage - self.dosage_non_linearity = dosage_non_linearity - - self.dropout_rate_encoder = kwargs.get("dropout_rate_encoder", 0.0) - self.dropout_rate_decoder = kwargs.get("dropout_rate_decoder", 0.0) - self.n_hidden_adv = kwargs.get("n_hidden_adv", 128) - self.n_layers_adv = kwargs.get("n_layers_adv", 2) - self.use_norm_adv = kwargs.get("use_norm_adv", "batch") - self.dropout_rate_adv = kwargs.get("dropout_rate_adv", 0.0) - self.seed = kwargs.get("seed", 0) - - # training params - self.lr = lr - self.wd = wd - self.n_steps_pretrain_ae = n_steps_pretrain_ae - self.n_epochs_pretrain_ae = n_epochs_pretrain_ae - self.n_steps_kl_warmup = n_steps_kl_warmup - self.n_epochs_kl_warmup = n_epochs_kl_warmup - self.n_steps_adv_warmup = n_steps_adv_warmup - self.n_epochs_adv_warmup = n_epochs_adv_warmup - self.adv_steps = adv_steps - self.reg_adv = reg_adv - self.pen_adv = pen_adv - self.adv_lr = adv_lr - self.adv_wd = adv_wd - - self.step_size_lr = step_size_lr - self.do_clip_grad = do_clip_grad - self.gradient_clip_value = gradient_clip_value - self.check_val_every_n_epoch = check_val_every_n_epoch - - self.kl_weight = 0.0 # disabled for now - - self.kwargs = kwargs - - self.adv_loss = adv_loss.lower() - self.gamma = kwargs.get("gamma", 2.0) - if self.adv_loss 
== "focal": - # self.adv_loss_fn = FocalLoss(gamma=self.gamma, reduction="mean") - raise NotImplementedError("Focal loss not implemented for CPA model yet") - else: - self.adv_loss_fn = torch.nn.CrossEntropyLoss() - - assert self.output_space == "gene", "CPA model only supports gene-level output" - - self.automatic_optimization = False - - # Build model components - self._build_networks() - - def _build_networks(self): - """ - Build the core components: - """ - self.module = CPAModule( - n_genes=self.input_dim, - n_perts=self.n_perts, - n_cell_types=self.n_cell_types, - n_batches=self.n_batches, - pert_embeddings=self.pert_embeddings, - n_latent=self.n_latent, - recon_loss=self.recon_loss, - n_hidden_encoder=self.n_hidden_encoder, - n_layers_encoder=self.n_layers_encoder, - n_hidden_decoder=self.n_hidden_decoder, - n_layers_decoder=self.n_layers_decoder, - use_batch_norm=self.use_batch_norm, - use_layer_norm=self.use_layer_norm, - dropout_rate_encoder=self.dropout_rate_encoder, - dropout_rate_decoder=self.dropout_rate_decoder, - n_hidden_adv=self.n_hidden_adv, - n_layers_adv=self.n_layers_adv, - use_norm_adv=self.use_norm_adv, - dropout_rate_adv=self.dropout_rate_adv, - variational=False, - encode_dosage=self.encode_dosage, - dosage_non_linearity=self.dosage_non_linearity, - seed=self.seed, - ) - - def _forward_step(self, batch: Dict[str, torch.Tensor]): - """ - Given - - Args: - batch: Dictionary containing: - - pert: Perturbation one-hot - - basal: Control expression embedding - - cell_type: Cell type one-hot - - batch: Batch one-hot - """ - basal = batch["ctrl_cell_emb"] - pert = batch["pert_emb"] - cell_type = batch["cell_type_onehot"] - batch_ids = batch["batch"] - pert_dosages = batch.get("pert_dosage", None) - - # if pert is one-hot, convert to index - if pert.dim() == 2 and pert.size(1) == self.n_perts: - pert = pert.argmax(1) - - if cell_type.dim() == 2 and cell_type.size(1) == self.n_cell_types: - cell_type = cell_type.argmax(1) - - if batch_ids.dim() == 2 and batch_ids.size(1) == self.n_batches: - batch_ids = batch_ids.argmax(1) - - encoder_outputs, decoder_outputs = self.module.forward(basal, pert, cell_type, batch_ids, pert_dosages) - - return encoder_outputs, decoder_outputs - - def encode_perturbation(self, pert: torch.Tensor) -> torch.Tensor: - """Map perturbation to an effect vector in embedding space.""" - raise NotImplementedError("Perturbation encoding not supported for CPA model") - - def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: - """Expression is already in embedding space, pass through.""" - raise NotImplementedError("Basal expression encoding not supported for CPA model") - - def perturb(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor: - """ - Given a perturbation and basal embeddings, compute the perturbed embedding. 
- """ - # Project perturbation and basal cell state to latent space - raise NotImplementedError("Perturbation not supported for CPA model") - - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - encoder_outputs, decoder_outputs = self._forward_step(batch) - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - x_pred = getattr(decoder_outputs["px"], output_key) - - return x_pred - - def adversarial_loss(self, perts, cell_types, batch_ids, z_basal, compute_penalty=True): - """Computes adversarial classification losses and regularizations""" - if compute_penalty: - z_basal = z_basal.requires_grad_(True) - - adv_logits = self.module.forward_adv(z_basal) - pert_logits = adv_logits["pert_logits"] - - pert_adv_loss = self.adv_loss_fn(pert_logits, perts.long()) - pert_acc = accuracy( - pert_logits.argmax(1), - perts.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - cell_types_logits = adv_logits["cell_type_logits"] - cell_types_adv_loss = self.adv_loss_fn(cell_types_logits, cell_types.long()) - cell_types_acc = accuracy( - cell_types_logits.argmax(1), - cell_types.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_cell_types, - ) - - batch_ids_logits = adv_logits["batch_logits"] - batch_ids_adv_loss = self.adv_loss_fn(batch_ids_logits, batch_ids.long()) - batch_ids_acc = accuracy( - batch_ids_logits.argmax(1), - batch_ids.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_batches, - ) - - adv_loss = pert_adv_loss + cell_types_adv_loss + batch_ids_adv_loss - adv_acc = (pert_acc + cell_types_acc + batch_ids_acc) / 3.0 - - if compute_penalty: - # Penalty losses - cell_type_penalty = ( - torch.autograd.grad( - cell_types_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - batch_penalty = ( - torch.autograd.grad( - batch_ids_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - pert_penalty = ( - torch.autograd.grad( - pert_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - total_penalty = cell_type_penalty + batch_penalty + pert_penalty - - else: - total_penalty = torch.tensor(0.0, device=z_basal.device) - - return adv_loss, adv_acc, total_penalty - - @property - def do_start_adv_training(self): - if self.n_steps_pretrain_ae is not None: - return self.global_step > self.n_steps_pretrain_ae - elif self.n_epochs_pretrain_ae is not None: - return self.current_epoch > self.n_epochs_pretrain_ae - else: - return True - - @property - def adv_lambda(self): - slope = self.reg_adv - if self.n_steps_adv_warmup is not None: - global_step = self.global_step - - if self.n_steps_pretrain_ae: - global_step -= self.n_steps_pretrain_ae - - if global_step <= self.n_steps_adv_warmup: - proportion = global_step / self.n_steps_adv_warmup - return slope * proportion - else: - return slope - elif self.n_epochs_adv_warmup is not None: - current_epoch = self.current_epoch - - if self.n_epochs_pretrain_ae: - current_epoch -= self.n_epochs_pretrain_ae - - if current_epoch <= self.n_epochs_adv_warmup: - proportion = current_epoch / self.n_epochs_adv_warmup - return slope * proportion - else: - return slope - else: - return slope - - def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Training step logic.""" - opt, 
opt_adv = self.optimizers() - - enc_outputs, dec_outputs = self._forward_step(batch) - - recon_loss, kl_loss = self.module.loss( - x_pert=batch["pert_cell_emb"], - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - if self.do_start_adv_training: - if self.adv_steps is None: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - z_basal=z_basal, - compute_penalty=False, - ) - - loss = recon_loss + self.kl_weight * kl_loss - self.adv_lambda * adv_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - - opt_adv.zero_grad() - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - elif batch_idx % self.adv_steps == 0: - opt_adv.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - loss = adv_loss - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - # Model update - else: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - z_basal=z_basal, - compute_penalty=False, - ) - - loss = recon_loss + self.kl_weight * kl_loss - self.adv_lambda * adv_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - else: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - loss = recon_loss + self.kl_weight * kl_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - - opt_adv.zero_grad() - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - r2_mean, pearson_lfc = self.module.r2_metric( - x_pert=batch["pert_cell_emb"], - x_basal=batch["ctrl_cell_emb"], - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - 
disnt_basal, disnt_after = self.module.disentanglement( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - self.log( - "recon_loss", - recon_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "r2_mean", - r2_mean, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "pearson_lfc", - pearson_lfc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "adv_loss", - adv_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "disnt_basal", - disnt_basal, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "disnt_after", - disnt_after, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "adv_acc", - adv_acc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - if self.global_step % self.step_size_lr * 1000 == 0: - sch, sch_adv = self.lr_schedulers() - sch.step() - sch_adv.step() - - return loss - - def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - """Validation step logic.""" - enc_outputs, dec_outputs = self._forward_step(batch) - - recon_loss, kl_loss = self.module.loss( - x_pert=batch["pert_cell_emb"], - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - r2_mean, pearson_lfc = self.module.r2_metric( - x_pert=batch["pert_cell_emb"], - x_basal=batch["ctrl_cell_emb"], - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - disnt_basal, disnt_after = self.module.disentanglement( - perts=batch["pert_emb"].argmax(1), - cell_types=batch["cell_type_onehot"].argmax(1), - batch_ids=batch["batch"].argmax(1), - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - self.log("val_loss", recon_loss + self.kl_weight * kl_loss, prog_bar=True) - self.log("val_r2_mean", r2_mean, prog_bar=True) - self.log("val_pearson_lfc", pearson_lfc, prog_bar=True) - self.log("es_metric", r2_mean + math.e ** (disnt_after - disnt_basal), prog_bar=True) - - def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - enc_outputs, dec_outputs = self._forward_step(batch) - - recon_loss, kl_loss = self.module.loss( - x_pert=batch["pert_cell_emb"], - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - loss = recon_loss + self.kl_weight * kl_loss - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - x_pred = getattr(dec_outputs["px"], output_key) - - self.log("test_loss", loss, prog_bar=True) - - return x_pred - - def predict_step(self, batch, batch_idx, **kwargs): - """ - Typically used for final inference. We'll replicate old logic: - returning 'preds', 'X', 'pert_name', etc. 
- """ - - enc_outputs, dec_outputs = self._forward_step(batch) - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - x_pred = getattr(dec_outputs["px"], output_key) - - outputs = { - "preds": x_pred, - "pert_cell_emb": batch.get("pert_cell_emb", None), - "X_gene": batch.get("X_gene", None), - "pert_emb": batch.get("pert_emb", None), - "pert_name": batch.get("pert_name", None), - "cell_type": batch.get("cell_type", None), - "batch": batch.get("batch_name", None), - "ctrl_cell_emb": batch.get("ctrl_cell_emb", None), - } - - outputs = {k: v for k, v in outputs.items() if v is not None} - - return outputs - - def configure_optimizers(self): - """Set up optimizer.""" - ae_params = ( - list(filter(lambda p: p.requires_grad, self.module.encoder.parameters())) - + list(filter(lambda p: p.requires_grad, self.module.decoder.parameters())) - + list( - filter( - lambda p: p.requires_grad, - self.module.pert_embeddings.parameters(), - ) - ) - + list( - filter( - lambda p: p.requires_grad, - self.module.cell_type_embeddings.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, self.module.batch_embeddings.parameters())) - ) - - if self.module.recon_loss in ["zinb", "nb"]: - ae_params += [self.module.px_r] - - optimizer_autoencoder = torch.optim.Adam(ae_params, lr=self.lr, weight_decay=self.wd) - - scheduler_autoencoder = StepLR(optimizer_autoencoder, step_size=self.step_size_lr, gamma=0.9) - - adv_params = ( - list( - filter( - lambda p: p.requires_grad, - self.module.perturbation_classifier.parameters(), - ) - ) - + list( - filter( - lambda p: p.requires_grad, - self.module.cell_type_classifier.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, self.module.batch_classifier.parameters())) - ) - - optimizer_adversaries = torch.optim.Adam(adv_params, lr=self.adv_lr, weight_decay=self.adv_wd) - scheduler_adversaries = StepLR(optimizer_adversaries, step_size=self.step_size_lr, gamma=0.9) - - optimizers = [optimizer_autoencoder, optimizer_adversaries] - schedulers = [scheduler_autoencoder, scheduler_adversaries] - - if self.step_size_lr is not None: - return optimizers, schedulers - else: - return optimizers diff --git a/src/state/tx/models/cpa/_module.py b/src/state/tx/models/cpa/_module.py deleted file mode 100644 index 9027be34..00000000 --- a/src/state/tx/models/cpa/_module.py +++ /dev/null @@ -1,470 +0,0 @@ -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -from torch.distributions import Normal -from torch.distributions.kl import kl_divergence as kl -from torch_scatter import scatter_mean -from torchmetrics.functional import pairwise_euclidean_distance, pearson_corrcoef, r2_score -from torchmetrics.functional.clustering import normalized_mutual_info_score - -from ._base_modules import MLP, Classifier, CountDecoder, GeneralizedSigmoid, VariationalEncoder -from ._dists import NegativeBinomial, ZeroInflatedNegativeBinomial - - -def knn_purity(data, labels, n_neighbors=15): - """Computes KNN Purity for ``data`` given the labels. - Parameters - ---------- - data: - torch tensor of data (n_samples, n_features) - labels - torch tensor of labels (n_samples,) - n_neighbors: int - Number of nearest neighbors. - Returns - ------- - score: float - KNN purity score. A float between 0 and 1. 
- """ - distances = pairwise_euclidean_distance(data) - # sort each row in distances to get nearest neighbors - - _, indices = torch.topk(distances, k=n_neighbors + 1, dim=1, largest=False, sorted=True) - indices = indices[:, 1:] # remove self - # neighbors_labels = np.vectorize(lambda i: labels[i])(indices) - neighbors_labels = labels[indices] # (n_samples, n_neighbors) - - # pre cell purity scores - scores = ((neighbors_labels - labels.reshape(-1, 1)) == 0).float().mean(axis=1) # (n_samples,) - res = scatter_mean(scores, labels).mean() # per category purity - - return res - - -class CPAModule(nn.Module): - """ - CPA module using Gaussian/NegativeBinomial Likelihood - - Parameters - ---------- - n_genes: int - n_treatments: int - covars_encoder: dict - Dictionary of covariates with keys as each covariate name and values as - number of unique values of the corresponding covariate - n_latent: int - Latent Dimension - loss_ae: str - Autoencoder loss (either "gauss" or "nb") - doser_type: str - # Type of doser network, either `mlp` or `linear`. - autoencoder_width: int - autoencoder_depth: int - use_batch_norm: bool - use_layer_norm: bool - variational: bool - """ - - def __init__( - self, - n_genes: int, - n_perts: int, - n_cell_types: int, - n_batches: int = 1, - pert_embeddings: Optional[np.ndarray] = None, - n_latent: int = 128, - recon_loss: str = "nb", - n_hidden_encoder: int = 256, - n_layers_encoder: int = 3, - n_hidden_decoder: int = 256, - n_layers_decoder: int = 3, - use_batch_norm: str = "both", - use_layer_norm: str = "none", - dropout_rate_encoder: float = 0.0, - dropout_rate_decoder: float = 0.0, - n_hidden_adv: int = 128, - n_layers_adv: int = 2, - use_norm_adv: str = "batch", - dropout_rate_adv: float = 0.0, - variational: bool = False, - encode_dosage: bool = False, - dosage_non_linearity: str = "linear", - seed: int = 0, - **kwargs, - ): - super().__init__() - - torch.manual_seed(seed) - np.random.seed(seed) - - recon_loss = recon_loss.lower() - - assert recon_loss in ["gauss", "nb", "zinb"] - - self.n_genes = n_genes - self.n_latent = n_latent - self.n_perts = n_perts - self.recon_loss = recon_loss - self.variational = variational - self.encode_dosage = encode_dosage - - if variational: - self.encoder = VariationalEncoder( - n_genes, - n_latent, - var_activation=nn.Softplus(), - n_hidden=n_hidden_encoder, - n_layers=n_layers_encoder, - use_batch_norm=use_batch_norm in ["both", "encoder"], - use_layer_norm=use_layer_norm in ["both", "encoder"], - dropout_rate=dropout_rate_encoder, - activation_fn=nn.ReLU, - return_dist=True, - ) - else: - self.encoder = MLP( - n_input=n_genes, - n_output=n_latent, - n_hidden=n_hidden_encoder, - n_layers=n_layers_encoder, - use_norm=( - "batch" - if use_batch_norm in ["both", "encoder"] - else "layer" - if use_layer_norm in ["both", "encoder"] - else "none" - ), - dropout_rate=dropout_rate_encoder, - activation_fn=nn.ReLU, - drop_norm_last_layer=False, - ) - - # Decoder components - if self.recon_loss in ["zinb", "nb"]: - # setup the parameters of your generative model, as well as your inference model - self.px_r = torch.nn.Parameter(torch.randn(self.n_genes)) - - # decoder goes from n_latent-dimensional space to n_input-d data - self.decoder = CountDecoder( - n_input=n_latent, - n_output=n_genes, - n_layers=n_layers_decoder, - n_hidden=n_hidden_decoder, - use_norm=( - "batch" - if use_batch_norm in ["both", "decoder"] - else "layer" - if use_layer_norm in ["both", "decoder"] - else "none" - ), - ) - - elif recon_loss == "gauss": - 
self.decoder = VariationalEncoder( - n_input=n_latent, - n_output=n_genes, - n_layers=n_layers_decoder, - n_hidden=n_hidden_decoder, - dropout_rate=dropout_rate_decoder, - use_norm=( - "batch" - if use_batch_norm in ["both", "decoder"] - else "layer" - if use_layer_norm in ["both", "decoder"] - else "none" - ), - var_activation=nn.Softplus(), - ) - - else: - raise Exception("Invalid Loss function for Autoencoder") - - # Embeddings - if pert_embeddings is not None: - self.pert_embeddings = nn.Embedding.from_pretrained(torch.FloatTensor(pert_embeddings), freeze=True) - else: - self.pert_embeddings = nn.Embedding(n_perts, n_latent) - - if n_batches > 1: - self.batch_embeddings = nn.Embedding(n_batches, n_latent) - - self.cell_type_embeddings = nn.Embedding(n_cell_types, n_latent) - - if self.encode_dosage: - self.dosage_encoder = GeneralizedSigmoid(n_perts, non_linearity=dosage_non_linearity) - else: - self.dosage_encoder = None - - # Adversarial Components - self.perturbation_classifier = Classifier( - n_input=n_latent, - n_labels=n_perts, - n_hidden=n_hidden_adv, - n_layers=n_layers_adv, - use_norm=use_norm_adv, - dropout_rate=dropout_rate_adv, - activation_fn=nn.ReLU, - ) - - self.cell_type_classifier = Classifier( - n_input=n_latent, - n_labels=n_cell_types, - n_hidden=n_hidden_adv, - n_layers=n_layers_adv, - use_norm=use_norm_adv, - dropout_rate=dropout_rate_adv, - activation_fn=nn.ReLU, - ) - - self.batch_classifier = Classifier( - n_input=n_latent, - n_labels=n_batches, - n_hidden=n_hidden_adv, - n_layers=n_layers_adv, - use_norm=use_norm_adv, - dropout_rate=dropout_rate_adv, - activation_fn=nn.ReLU, - ) - - self.metrics = { - "pearson_r": pearson_corrcoef, - "r2_score": r2_score, - "nmi": normalized_mutual_info_score, - } - - def forward(self, x_basal, perts, cell_types, batch_ids, pert_dosages=None, n_samples: int = 1): - enc_outputs = self.forward_encoder( - x=x_basal, - perts=perts, - cell_types=cell_types, - batch_ids=batch_ids, - pert_dosages=pert_dosages, - n_samples=n_samples, - ) - - dec_outputs = self.forward_decoder( - z=enc_outputs["z"], - library=enc_outputs["library"], - ) - - return enc_outputs, dec_outputs - - def forward_encoder( - self, - x, - perts, - cell_types, - batch_ids: Optional[torch.Tensor] = None, - pert_dosages: Optional[torch.Tensor] = None, - n_samples: int = 1, - ): - # TODO: remove unused - # batch_size = x.shape[0] - - if self.recon_loss in ["nb", "zinb"]: - # log the input to the variational distribution for numerical stability - x_ = torch.log(1 + x) - library = torch.log(x.sum(1)).unsqueeze(1) - else: - x_ = x - library = None, None - - if self.variational: - qz, z_basal = self.encoder(x_) - else: - qz, z_basal = None, self.encoder(x_) - - if self.variational and n_samples > 1: - sampled_z = qz.sample((n_samples,)) - z_basal = self.encoder.z_transformation(sampled_z) - if self.recon_loss in ["nb", "zinb"]: - library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1))) - - z_covs = self.cell_type_embeddings(cell_types.long()) - z_batch = self.batch_embeddings(batch_ids.long()) - z_pert = self.pert_embeddings(perts.long()) - - if self.encode_dosage and pert_dosages is not None: - pert_dosages = torch.tensor(pert_dosages, device=x.device, dtype=torch.float32) - scaled_dosages = self.dosage_encoder(pert_dosages, perts.long()).squeeze(-1) # (batch_size,) - - z_pert = torch.einsum("b, b d -> b d", scaled_dosages, z_pert) - - z = z_basal + z_pert + z_covs + z_batch - z_corrected = z_basal + z_pert + z_covs - - return dict( - z=z, - 
z_basal=z_basal, - z_corrected=z_corrected, - library=library, - qz=qz, - ) - - def forward_decoder( - self, - z, - library, - ): - if self.recon_loss == "nb": - px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) - px_r = torch.exp(self.px_r) - - px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) - - elif self.recon_loss == "zinb": - px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) - px_r = torch.exp(self.px_r) - - px = ZeroInflatedNegativeBinomial( - mu=px_rate, - theta=px_r, - zi_logits=px_dropout, - scale=px_scale, - ) - - else: - px_mean, px_var, x_pred = self.decoder(z) - - px = Normal(loc=px_mean, scale=px_var.sqrt()) - - pz = Normal(torch.zeros_like(z), torch.ones_like(z)) - return dict(px=px, pz=pz) - - def forward_adv(self, z_basal): - pert_logits = self.perturbation_classifier(z_basal) - cell_type_logits = self.cell_type_classifier(z_basal) - batch_logits = self.batch_classifier(z_basal) - - return dict( - pert_logits=pert_logits, - cell_type_logits=cell_type_logits, - batch_logits=batch_logits, - ) - - def loss(self, x_pert, encoder_outputs, decoder_outputs): - """Computes the reconstruction loss (AE) or the ELBO (VAE)""" - px = decoder_outputs["px"] - recon_loss = -px.log_prob(x_pert).sum(dim=-1).mean() - - if self.variational: - qz = encoder_outputs["qz"] - pz = decoder_outputs["pz"] - - kl_divergence_z = kl(qz, pz).sum(dim=1) - kl_loss = kl_divergence_z.mean() - else: - kl_loss = torch.zeros_like(recon_loss) - - return recon_loss, kl_loss - - def r2_metric(self, x_pert, x_basal, encoder_outputs, decoder_outputs): - px = decoder_outputs["px"] - if self.recon_loss == "gauss": - x_pred_mean = px.loc - - x_pred_mean = torch.nan_to_num(x_pred_mean, nan=0, posinf=1e3, neginf=-1e3) - - r2_mean = torch.nan_to_num(self.metrics["r2_score"](x_pred_mean.mean(0), x_pert.mean(0)), nan=0.0).item() - - lfc_true = (x_pert - x_basal).mean(0) - lfc_pred = (x_pred_mean - x_basal).mean(0) - - pearson_lfc = torch.nan_to_num(self.metrics["r2_score"](lfc_pred, lfc_true), nan=0.0).item() - - elif self.recon_loss in ["nb", "zinb"]: - x_pert = torch.log(1 + x_pert) - x_pred = px.mu - x_pred = torch.log(1 + x_pred) - x_basal = torch.log(1 + x_basal) - - x_pred = torch.nan_to_num(x_pred, nan=0, posinf=1e3, neginf=-1e3) - - r2_mean = torch.nan_to_num(self.metrics["r2_score"](x_pred.mean(0), x_pert.mean(0)), nan=0.0).item() - - lfc_true = (x_pert - x_basal).mean(0) - lfc_pred = (x_pred - x_basal).mean(0) - - pearson_lfc = torch.nan_to_num(self.metrics["pearson_r"](lfc_pred, lfc_true), nan=0.0).item() - - return r2_mean, pearson_lfc - - def disentanglement( - self, - perts, - cell_types, - batch_ids, - encoder_outputs, - decoder_outputs, - ): - z_basal = encoder_outputs["z_basal"] - z = encoder_outputs["z"] - - knn_basal = knn_purity( - z_basal, - perts.ravel(), - n_neighbors=min(perts.shape[0] - 1, 30), - ) - knn_after = knn_purity( - z, - perts.ravel(), - n_neighbors=min(perts.shape[0] - 1, 30), - ) - - knn_basal += knn_purity( - z_basal, - cell_types.ravel(), - n_neighbors=min(cell_types.shape[0] - 1, 30), - ) - - knn_after += knn_purity( - z, - cell_types.ravel(), - n_neighbors=min(cell_types.shape[0] - 1, 30), - ) - - if batch_ids is not None: - knn_basal += knn_purity( - z_basal, - batch_ids.ravel(), - n_neighbors=min(batch_ids.shape[0] - 1, 30), - ) - - knn_after += knn_purity( - z, - batch_ids.ravel(), - n_neighbors=min(batch_ids.shape[0] - 1, 30), - ) - - return knn_basal.item(), knn_after.item() - - def get_expression( - self, - ): - """Computes 
knockout gene expression - - Parameters - ---------- - tensors : dict - dictionary of input tensors - - """ - ## TODO: remove broken code - # _, decoder_outputs = self.forward( - # batch, - # n_samples=n_samples, - # ) - - # px = decoder_outputs["px"] - - # if self.recon_loss == "gauss": - # output_key = "loc" - # else: - # output_key = "mu" - - # output = getattr(px, output_key) - - # return output - raise NotImplementedError("Expression is not implemented") diff --git a/src/state/tx/models/cpa/_task.py b/src/state/tx/models/cpa/_task.py deleted file mode 100644 index a2070701..00000000 --- a/src/state/tx/models/cpa/_task.py +++ /dev/null @@ -1,552 +0,0 @@ -import math -from collections import defaultdict - -import lightning as L -import numpy as np -import torch -from torch import nn -from torch.optim.lr_scheduler import StepLR -from torchmetrics.functional import accuracy - -from ._base_modules import FocalLoss -from ._module import CPAModule - - -class CPATrainer(L.LightningModule): - def __init__( - self, - module: CPAModule, - lr=5e-4, - wd=1e-6, - n_steps_pretrain_ae: int = None, - n_epochs_pretrain_ae: int = None, - n_steps_kl_warmup: int = None, - n_epochs_kl_warmup: int = None, - n_steps_adv_warmup: int = None, - n_epochs_adv_warmup: int = None, - adv_steps: int = 3, - reg_adv: float = 1.0, - pen_adv: float = 1.0, - adv_lr=1e-3, - adv_wd=1e-6, - step_size_lr: int = 45, - do_clip_grad: bool = False, - gradient_clip_value: float = 3.0, - adv_loss: str = "cce", - check_val_every_n_epoch: int = 5, - **kwargs, - ): - """Training plan for the CPA model""" - super().__init__() - - self.module = module - - self.lr = float(lr) - self.n_steps_kl_warmup = n_steps_kl_warmup - self.n_epochs_kl_warmup = n_epochs_kl_warmup - - self.automatic_optimization = False - - self.wd = float(wd) - - self.n_perts = module.n_perts - - self.n_steps_pretrain_ae = n_steps_pretrain_ae - self.n_epochs_pretrain_ae = n_epochs_pretrain_ae - - self.n_steps_adv_warmup = n_steps_adv_warmup - self.n_epochs_adv_warmup = n_epochs_adv_warmup - self.adv_steps = adv_steps - - self.reg_adv = reg_adv - self.pen_adv = pen_adv - - self.adv_lr = float(adv_lr) - self.adv_wd = float(adv_wd) - - self.step_size_lr = step_size_lr - - self.do_clip_grad = do_clip_grad - self.gradient_clip_value = gradient_clip_value - self.check_val_every_n_epoch = check_val_every_n_epoch - - self.metrics = [ - "recon_loss", - "KL", - "disnt_basal", - "disnt_after", - "r2_mean", - "r2_var", - "adv_loss", - "penalty_adv", - "adv_perts", - "acc_perts", - "penalty_perts", - ] - - self.epoch_history = defaultdict(list) - - # TODO: remove unused - # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.adv_loss = adv_loss.lower() - self.gamma = kwargs.get("gamma", 2.0) - if self.adv_loss == "focal": - self.adv_loss_fn = FocalLoss(gamma=self.gamma, reduction="mean") - else: - self.adv_loss_fn = nn.CrossEntropyLoss() - - @property - def kl_weight(self): - return 0.0 - - @property - def adv_lambda(self): - slope = self.reg_adv - if self.n_steps_adv_warmup: - global_step = self.global_step - - if self.n_steps_pretrain_ae: - global_step -= self.n_steps_pretrain_ae - - if global_step <= self.n_steps_adv_warmup: - proportion = global_step / self.n_steps_adv_warmup - return slope * proportion - else: - return slope - elif self.n_epochs_adv_warmup is not None: - current_epoch = self.current_epoch - - if self.n_epochs_pretrain_ae: - current_epoch -= self.n_epochs_pretrain_ae - - if current_epoch <= self.n_epochs_adv_warmup: - proportion = 
current_epoch / self.n_epochs_adv_warmup - return slope * proportion - else: - return slope - else: - return slope - - @property - def do_start_adv_training(self): - if self.n_steps_pretrain_ae: - return self.global_step > self.n_steps_pretrain_ae - elif self.n_epochs_pretrain_ae: - return self.current_epoch > self.n_epochs_pretrain_ae - else: - return True - - def adversarial_loss(self, batch, z_basal, compute_penalty=True): - """Computes adversarial classification losses and regularizations""" - if compute_penalty: - z_basal = z_basal.requires_grad_(True) - - adv_logits = self.module.forward_adv(z_basal) - perts = batch["pert_emb"].argmax(1) - pert_logits = adv_logits["pert_logits"] - - pert_adv_loss = self.adv_loss_fn(pert_logits, perts.long()) - pert_acc = accuracy( - pert_logits.argmax(1), - perts.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - cell_types = batch["cell_type_onehot"].argmax(1) - cell_types_logits = adv_logits["cell_type_logits"] - cell_types_adv_loss = self.adv_loss_fn(cell_types_logits, cell_types.long()) - cell_types_acc = accuracy( - cell_types_logits.argmax(1), - cell_types.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - batch_ids = batch["batch"].argmax(1) - batch_ids_logits = adv_logits["batch_logits"] - batch_ids_adv_loss = self.adv_loss_fn(batch_ids_logits, batch_ids.long()) - batch_ids_acc = accuracy( - batch_ids_logits.argmax(1), - batch_ids.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - adv_loss = pert_adv_loss + cell_types_adv_loss + batch_ids_adv_loss - adv_acc = (pert_acc + cell_types_acc + batch_ids_acc) / 3.0 - - if compute_penalty: - # Penalty losses - cell_type_penalty = ( - torch.autograd.grad( - cell_types_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - batch_penalty = ( - torch.autograd.grad( - batch_ids_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - pert_penalty = ( - torch.autograd.grad( - pert_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - total_penalty = cell_type_penalty + batch_penalty + pert_penalty - - else: - total_penalty = torch.tensor(0.0, device=z_basal.device) - - return adv_loss, adv_acc, total_penalty - - def configure_optimizers(self): - ae_params = ( - list(filter(lambda p: p.requires_grad, self.module.encoder.parameters())) - + list(filter(lambda p: p.requires_grad, self.module.decoder.parameters())) - + list( - filter( - lambda p: p.requires_grad, - self.module.pert_embeddings.parameters(), - ) - ) - + list( - filter( - lambda p: p.requires_grad, - self.module.cell_type_embeddings.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, self.module.batch_embeddings.parameters())) - ) - - if self.module.recon_loss in ["zinb", "nb"]: - ae_params += [self.module.px_r] - - optimizer_autoencoder = torch.optim.Adam(ae_params, lr=self.lr, weight_decay=self.wd) - - scheduler_autoencoder = StepLR(optimizer_autoencoder, step_size=self.step_size_lr, gamma=0.9) - - adv_params = ( - list( - filter( - lambda p: p.requires_grad, - self.module.perturbation_classifier.parameters(), - ) - ) - + list( - filter( - lambda p: p.requires_grad, - self.module.cell_type_classifier.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, 
self.module.batch_classifier.parameters())) - ) - - optimizer_adversaries = torch.optim.Adam(adv_params, lr=self.adv_lr, weight_decay=self.adv_wd) - scheduler_adversaries = StepLR(optimizer_adversaries, step_size=self.step_size_lr, gamma=0.9) - - optimizers = [optimizer_autoencoder, optimizer_adversaries] - schedulers = [scheduler_autoencoder, scheduler_adversaries] - - if self.step_size_lr is not None: - return optimizers, schedulers - else: - return optimizers - - def training_step(self, batch, batch_idx): - opt, opt_adv = self.optimizers() - - enc_outputs, dec_outputs = self.module.forward(batch) - - recon_loss, kl_loss = self.module.loss( - batch=batch, - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - if self.do_start_adv_training: - if self.adv_steps is None: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal, - compute_penalty=False, - ) - - loss = recon_loss + self.kl_weight * kl_loss - self.adv_lambda * adv_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - - opt_adv.zero_grad() - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - elif batch_idx % self.adv_steps == 0: - opt_adv.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - # Model update - else: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal, - compute_penalty=False, - ) - - loss = recon_loss + self.kl_weight * kl_loss - self.adv_lambda * adv_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - else: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - loss = recon_loss + self.kl_weight * kl_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - - opt_adv.zero_grad() - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - r2_mean, r2_lfc = self.module.r2_metric(batch, enc_outputs, dec_outputs) - - disnt_basal, disnt_after = self.module.disentanglement(batch, enc_outputs, dec_outputs) - - results = { - "recon_loss": recon_loss.item(), - "KL": kl_loss.item(), - 
"r2_mean": r2_mean, - "r2_lfc": r2_lfc, - "adv_loss": adv_loss.item(), - "adv_acc": adv_acc.item(), - "penalty_adv": adv_penalty.item(), - "es_metric": r2_mean + np.e ** (disnt_after - disnt_basal), - "disnt_basal": disnt_basal, - "disnt_after": disnt_after, - } - - self.log( - "recon_loss", - recon_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "r2_mean", - r2_mean, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "r2_lfc", - r2_lfc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "adv_loss", - adv_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "disnt_basal", - disnt_basal, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "disnt_after", - disnt_after, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "adv_acc", - adv_acc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - if self.global_step % self.step_size_lr * 1000 == 0: - sch, sch_adv = self.lr_schedulers() - sch.step() - sch_adv.step() - - return results - - def validation_step(self, batch, batch_idx): - enc_outputs, dec_outputs = self.module.forward(batch) - - recon_loss, kl_loss = self.module.loss( - batch=batch, - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - r2_mean, r2_lfc = self.module.r2_metric(batch, enc_outputs, dec_outputs) - - disnt_basal, disnt_after = self.module.disentanglement(batch, enc_outputs, dec_outputs) - - self.log("val_r2_mean", r2_mean, prog_bar=True) - self.log("val_r2_lfc", r2_lfc, prog_bar=True) - self.log("es_metric", r2_mean + math.e ** (disnt_after - disnt_basal), prog_bar=True) - - def test_step(self, batch, batch_idx): - enc_outputs, dec_outputs = self.module.forward(batch) - - recon_loss, kl_loss = self.module.loss( - batch=batch, - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - r2_mean, r2_lfc = self.module.r2_metric(batch, enc_outputs, dec_outputs) - - self.log("test_recon", recon_loss.item(), prog_bar=True) - self.log("test_r2_mean", r2_mean, prog_bar=True) - self.log("test_r2_lfc", r2_lfc, prog_bar=True) - - x_pred = self.module.get_expression(batch, n_samples=1) - - return x_pred.detach().cpu().numpy() diff --git a/src/state/tx/models/decoder_only.py b/src/state/tx/models/decoder_only.py index ad0b2ce2..47682a96 100644 --- a/src/state/tx/models/decoder_only.py +++ b/src/state/tx/models/decoder_only.py @@ -16,8 +16,8 @@ class DecoderOnlyPerturbationModel(PerturbationModel): this model simply feeds the latent representation through a decoder network. The loss is computed between the decoder output and the target HVG expression. - It keeps the overall architectural style (and uses the SamplesLoss loss function from geomloss) - as in the OldNeuralOT model. + It keeps the overall architectural style and uses the SamplesLoss loss + function from geomloss. """ def __init__( @@ -49,7 +49,7 @@ def __init__( self.activation_class = get_activation_class(kwargs.get("activation", "gelu")) self.gene_dim = gene_dim - # Use the same loss function as OldNeuralOT (e.g. using the MMD loss via geomloss) + # Use an MMD-style distributional loss via geomloss. 
self.loss_fn = SamplesLoss(loss=self.distributional_loss) def _build_networks(self): @@ -71,7 +71,7 @@ def training_step(self, batch, batch_idx): """ pred = self(batch) # log a zero tensor - self.log("train_loss", 0.0) + self.log(self._train_main_loss_key(), 0.0) if self.gene_decoder is not None and "pert_cell_counts" in batch: pert_cell_counts_preds = self.gene_decoder(pred) @@ -79,15 +79,15 @@ def training_step(self, batch, batch_idx): gene_targets = batch["pert_cell_counts"] gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() - self.log("decoder_loss", decoder_loss) + self.log(self._train_expression_loss_key(), decoder_loss) else: - self.log("decoder_loss", 0.0) + self.log(self._train_expression_loss_key(), 0.0) decoder_loss = None return decoder_loss def validation_step(self, batch, batch_idx): pred = self(batch) - self.log("val_loss", 0.0) + self.log(self._val_main_loss_key(), 0.0) return {"loss": None, "predictions": pred} @@ -100,7 +100,7 @@ def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() - self.log("decoder_val_loss", decoder_loss) + self.log(self._val_expression_loss_key(), decoder_loss) def test_step(self, batch, batch_idx): pred = self(batch) @@ -109,8 +109,7 @@ def test_step(self, batch, batch_idx): pert_cell_counts_preds = self.gene_decoder(pred) gene_targets = batch["pert_cell_counts"] gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() - self.log("decoder_test_loss", decoder_loss) + _ = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() return {"loss": None, "predictions": pred} def predict_step(self, batch, batch_idx, padded=True, **kwargs): diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py deleted file mode 100644 index 613397f7..00000000 --- a/src/state/tx/models/decoders.py +++ /dev/null @@ -1,195 +0,0 @@ -import logging -import os -from typing import Optional - -import torch -import torch.nn as nn - -from omegaconf import OmegaConf - -from ...emb.finetune_decoder import Finetune - -logger = logging.getLogger(__name__) - - -class FinetuneVCICountsDecoder(nn.Module): - def __init__( - self, - genes=None, - adata=None, - # checkpoint: Optional[str] = "/large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt", - # config: Optional[str] = "/large_storage/ctc/userspace/aadduri/SE-600M/config.yaml", - checkpoint: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/vci_1.4.4_v7.ckpt", - config: Optional[str] = "/home/aadduri/vci_pretrain/vci_1.4.4/config.yaml", - latent_dim: int = 1034, # total input dim (cell emb + optional ds emb) - read_depth: float = 4.0, - ds_emb_dim: int = 10, # dataset embedding dim at the tail of input - hidden_dim: int = 512, - dropout: float = 0.1, - basal_residual: bool = False, - train_binary_decoder: bool = True, - ): - super().__init__() - # Initialize finetune helper and model from a single checkpoint - if config is None: - raise ValueError( - "FinetuneVCICountsDecoder requires a VCI/SE config. Set kwargs.vci_config or env STATE_VCI_CONFIG." 
- ) - self.finetune = Finetune(cfg=OmegaConf.load(config), train_binary_decoder=train_binary_decoder) - self.finetune.load_model(checkpoint) - # Resolve genes: prefer explicit list; else infer from anndata if provided - if genes is None and adata is not None: - try: - genes = self.finetune.genes_from_adata(adata) - except Exception as e: - raise ValueError(f"Failed to infer genes from AnnData: {e}") - if genes is None: - raise ValueError("FinetuneVCICountsDecoder requires 'genes' or 'adata' to derive gene names") - self.genes = genes - # Keep read_depth as a learnable parameter so decoded counts can adapt - self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=True) - self.basal_residual = basal_residual - self.ds_emb_dim = int(ds_emb_dim) if ds_emb_dim is not None else 0 - self.input_total_dim = int(latent_dim) - - self.latent_decoder = nn.Sequential( - nn.Linear(latent_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, len(self.genes)), - ) - - self.gene_decoder_proj = nn.Sequential( - nn.Linear(len(self.genes), 128), - nn.LayerNorm(128), - nn.GELU(), - nn.Linear(128, 128), - nn.LayerNorm(128), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(128, len(self.genes)), - ) - - self.binary_decoder = self.finetune.model.binary_decoder # type: ignore - - # Validate that all requested genes exist in the pretrained checkpoint's embeddings - pe = getattr(self.finetune, "protein_embeds", {}) - self.present_mask = [g in pe for g in self.genes] - self.missing_positions = [i for i, g in enumerate(self.genes) if g not in pe] - self.missing_genes = [self.genes[i] for i in self.missing_positions] - total_req = len(self.genes) - found = total_req - len(self.missing_positions) - total_pe = len(pe) if hasattr(pe, "__len__") else -1 - miss_pct = (len(self.missing_positions) / total_req) if total_req > 0 else 0.0 - logger.info( - f"FinetuneVCICountsDecoder gene check: requested={total_req}, found={found}, missing={len(self.missing_positions)} ({miss_pct:.1%}), all_embeddings_size={total_pe}" - ) - - # Create learnable embeddings for missing genes in the post-ESM gene embedding space - if len(self.missing_positions) > 0: - # Infer gene embedding output dimension by a dry-run through gene_embedding_layer - try: - sample_vec = next(iter(pe.values())).to(self.finetune.model.device) - if sample_vec.dim() == 1: - sample_vec = sample_vec.unsqueeze(0) - gene_embed_dim = self.finetune.model.gene_embedding_layer(sample_vec).shape[-1] - except Exception: - # Conservative fallback - gene_embed_dim = 1024 - - self.missing_table = nn.Embedding(len(self.missing_positions), gene_embed_dim) - nn.init.normal_(self.missing_table.weight, mean=0.0, std=0.02) - # For user visibility - try: - self.finetune.missing_genes = self.missing_genes - except Exception: - pass - else: - # Register a dummy buffer so attributes exist - self.missing_table = None - - # Ensure the wrapped Finetune helper creates its own missing-table parameters - # prior to Lightning's checkpoint load. Otherwise the checkpoint will contain - # weights like `gene_decoder.finetune.missing_table.weight` that are absent - # from a freshly constructed module, triggering "unexpected key" errors. 
- try: - with torch.no_grad(): - self.finetune.get_gene_embedding(self.genes) - except Exception as exc: - logger.debug(f"Deferred Finetune missing-table initialization failed: {exc}") - - def gene_dim(self): - return len(self.genes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x is [B, S, total_dim] - if x.dim() != 3: - x = x.unsqueeze(0) - batch_size, seq_len, total_dim = x.shape - x_flat = x.reshape(batch_size * seq_len, total_dim) - - # Split cell and dataset embeddings - if self.ds_emb_dim > 0: - cell_embeds = x_flat[:, : total_dim - self.ds_emb_dim] - ds_emb = x_flat[:, total_dim - self.ds_emb_dim : total_dim] - else: - cell_embeds = x_flat - ds_emb = None - - # Prepare gene embeddings (replace any missing with learned vectors) - gene_embeds = self.finetune.get_gene_embedding(self.genes) - if self.missing_table is not None and len(self.missing_positions) > 0: - device = gene_embeds.device - learned = self.missing_table.weight.to(device) - idx = torch.tensor(self.missing_positions, device=device, dtype=torch.long) - gene_embeds = gene_embeds.clone() - gene_embeds.index_copy_(0, idx, learned) - # Ensure embeddings live on the same device as cell_embeds - if gene_embeds.device != cell_embeds.device: - gene_embeds = gene_embeds.to(cell_embeds.device) - - # RDA read depth vector (if enabled in SE model) - use_rda = getattr(self.finetune.model.cfg.model, "rda", False) - task_counts = None - if use_rda: - task_counts = self.read_depth.expand(cell_embeds.shape[0]) - if task_counts.device != cell_embeds.device: - task_counts = task_counts.to(cell_embeds.device) - - # Binary decoder forward with safe dtype handling. - # - On CUDA: enable bf16 autocast for speed. - # - On CPU: ensure inputs match decoder weight dtype to avoid BF16/FP32 mismatch. 
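# The comments above describe the dtype-safety pattern used by the decoder call
# that follows. A minimal self-contained sketch of device-conditional bf16
# autocast, assuming a generic nn.Linear stand-in for the binary decoder; the
# module and tensor names here are placeholders, not part of this file.
import torch
import torch.nn as nn

decoder = nn.Linear(512, 2000)
inputs = torch.randn(8, 512)
device_type = "cuda" if inputs.is_cuda else "cpu"
with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")):
    if device_type != "cuda":
        # CPU path: autocast is disabled, so match the decoder's weight dtype explicitly.
        inputs = inputs.to(next(decoder.parameters()).dtype)
    out = decoder(inputs)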
- device_type = "cuda" if cell_embeds.is_cuda else "cpu" - with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(device_type == "cuda")): - merged = self.finetune.model.resize_batch( - cell_embeds=cell_embeds, task_embeds=gene_embeds, task_counts=task_counts, ds_emb=ds_emb - ) - - # Align input dtype with decoder weights when autocast is not active (e.g., CPU path) - dec_param_dtype = next(self.binary_decoder.parameters()).dtype - if device_type != "cuda" and merged.dtype != dec_param_dtype: - merged = merged.to(dec_param_dtype) - - logprobs = self.binary_decoder(merged) - if logprobs.dim() == 3 and logprobs.size(-1) == 1: - logprobs = logprobs.squeeze(-1) - - # Reshape back to [B, S, gene_dim] - decoded_gene = logprobs.view(batch_size, seq_len, len(self.genes)) - - # Match dtype for post-decoder projection to avoid mixed-dtype matmul - proj_param_dtype = next(self.gene_decoder_proj.parameters()).dtype - if decoded_gene.dtype != proj_param_dtype: - decoded_gene = decoded_gene.to(proj_param_dtype) - decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) - - # Optional residual from latent decoder (operates on full input features) - ld_param_dtype = next(self.latent_decoder.parameters()).dtype - x_flat_for_ld = x_flat if x_flat.dtype == ld_param_dtype else x_flat.to(ld_param_dtype) - decoded_x = self.latent_decoder(x_flat_for_ld).view(batch_size, seq_len, len(self.genes)) - return torch.nn.functional.relu(decoded_gene + decoded_x) diff --git a/src/state/tx/models/embed_sum.py b/src/state/tx/models/embed_sum.py index 8aef97b4..655ef06f 100644 --- a/src/state/tx/models/embed_sum.py +++ b/src/state/tx/models/embed_sum.py @@ -10,7 +10,7 @@ class EmbedSumPerturbationModel(PerturbationModel): """ Implementation of the EmbedSum model which treats perturbations as learned embeddings that are added to control cell representations, which are input as gene expression counts - or as embeddings from a foundation model (UCE, scGPT, etc). The outputs are always in + or as embeddings from a foundation model (for example, UCE). The outputs are always in gene expression space. This model: diff --git a/src/state/tx/models/old_neural_ot.py b/src/state/tx/models/old_neural_ot.py deleted file mode 100644 index 88c2090f..00000000 --- a/src/state/tx/models/old_neural_ot.py +++ /dev/null @@ -1,241 +0,0 @@ -from typing import Dict, Optional - -import torch -from geomloss import SamplesLoss - -from .base import PerturbationModel -from .utils import build_mlp, get_activation_class, get_transformer_backbone - - -class OldNeuralOTPerturbationModel(PerturbationModel): - """ - This model: - 1) Projects basal expression and perturbation encodings into a shared latent space. - 2) Uses an OT-based distributional loss (energy, sinkhorn, etc.) from geomloss. - 3) Enables cells to attend to one another, learning a set-to-set function rather than - a sample-to-sample single-cell map. - """ - - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - pert_dim: int, - predict_residual: bool = True, - distributional_loss: str = "energy", - transformer_backbone_key: str = "GPT2", - transformer_backbone_kwargs: dict = None, - output_space: str = "gene", - gene_dim: Optional[int] = None, - **kwargs, - ): - """ - Args: - input_dim: dimension of the input expression (e.g. number of genes or embedding dimension). - hidden_dim: not necessarily used, but required by PerturbationModel signature. - output_dim: dimension of the output space (genes or latent). 
- pert_dim: dimension of perturbation embedding. - gpt: e.g. "TranslationTransformerSamplesModel". - model_kwargs: dictionary passed to that model's constructor. - loss: choice of distributional metric ("sinkhorn", "energy", etc.). - **kwargs: anything else to pass up to PerturbationModel or not used. - """ - # Call the parent PerturbationModel constructor - super().__init__( - input_dim=input_dim, - hidden_dim=hidden_dim, - gene_dim=gene_dim, - output_dim=output_dim, - pert_dim=pert_dim, - output_space=output_space, - **kwargs, - ) - - # Save or store relevant hyperparams - self.predict_residual = predict_residual - self.n_encoder_layers = kwargs.get("n_encoder_layers", 2) - self.n_decoder_layers = kwargs.get("n_decoder_layers", 2) - self.activation_class = get_activation_class(kwargs.get("activation", "gelu")) - self.transformer_backbone_key = transformer_backbone_key - self.transformer_backbone_kwargs = transformer_backbone_kwargs - self.distributional_loss = distributional_loss - self.cell_sentence_len = self.transformer_backbone_kwargs["n_positions"] - self.gene_dim = gene_dim - - # Build the distributional loss from geomloss - self.loss_fn = SamplesLoss(loss=self.distributional_loss) - # self.loss_fn = LearnableAlignmentLoss() - - # Build the underlying neural OT network - self._build_networks() - - def _build_networks(self): - """ - Here we instantiate the actual GPT2-based model or any neuralOT translator - via your old get_model(model_key, model_kwargs) approach. - """ - self.pert_encoder = build_mlp( - in_dim=self.pert_dim, - out_dim=self.hidden_dim, - hidden_dim=self.hidden_dim, - n_layers=self.n_encoder_layers, - dropout=self.dropout, - activation=self.activation_class, - ) - - # Map the input embedding to the hidden space - self.basal_encoder = build_mlp( - in_dim=self.input_dim, - out_dim=self.hidden_dim, - hidden_dim=self.hidden_dim, - n_layers=self.n_encoder_layers, - dropout=self.dropout, - activation=self.activation_class, - ) - - self.transformer_backbone, self.transformer_model_dim = get_transformer_backbone( - self.transformer_backbone_key, - self.transformer_backbone_kwargs, - ) - - self.project_out = build_mlp( - in_dim=self.hidden_dim, - out_dim=self.output_dim, - hidden_dim=self.hidden_dim, - n_layers=self.n_decoder_layers, - dropout=self.dropout, - activation=self.activation_class, - ) - - print(self) - - def encode_perturbation(self, pert: torch.Tensor) -> torch.Tensor: - """If needed, define how we embed the raw perturbation input.""" - return self.pert_encoder(pert) - - def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: - """Define how we embed basal state input, if needed.""" - return self.basal_encoder(expr) - - def perturb(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor: - """ - Return the latent perturbed state given the perturbation and basal state. 
- """ - pert_embedding = self.encode_perturbation(pert).unsqueeze(1) # shape: [batch_size, 1, hidden_dim] - control_cells = self.encode_basal_expression(basal).unsqueeze(1) # shape: [batch_size, 1, hidden_dim] - cls_input = torch.zeros_like(pert_embedding) # shape: [batch_size, 1, hidden_dim] - seq_input = torch.cat([pert_embedding, control_cells, cls_input], dim=1) # shape: [batch_size, 3, hidden_dim] - - # forward pass + extract CLS last hidden state - prediction = self.transformer_backbone(inputs_embeds=seq_input).last_hidden_state[:, -1] - - # add to basal if predicting residual - if self.predict_residual: - # treat the actual prediction as a residual sum to basal - return prediction + control_cells.squeeze(1) - else: - return prediction - - def forward(self, batch: dict) -> torch.Tensor: - """ - The main forward call. - """ - prediction = self.perturb(batch["pert_emb"], batch["ctrl_cell_emb"]) - output = self.project_out(prediction) - - return output - - def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Training step logic for both main model and decoder.""" - # Get model predictions (in latent space) - pred = self(batch) - pred = pred.reshape(-1, self.cell_sentence_len, self.output_dim) - # TODO: please improve this, do not assume self.cell_sentence_len for this model - target = batch["pert_cell_emb"] - target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - main_loss = self.loss_fn(pred, target).mean() - self.log("train_loss", main_loss) - - # Process decoder if available - decoder_loss = None - if self.gene_decoder is not None and "pert_cell_counts" in batch: - # Train decoder to map latent predictions to gene space - with torch.no_grad(): - latent_preds = pred.detach() # Detach to prevent gradient flow back to main model - - pert_cell_counts_preds = self.gene_decoder(latent_preds) - gene_targets = batch["pert_cell_counts"] - gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() - - # Log decoder loss - self.log("decoder_loss", decoder_loss) - - total_loss = main_loss + decoder_loss - else: - total_loss = main_loss - - return total_loss - - def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - """Validation step logic.""" - pred = self(batch) - pred = pred.reshape(-1, self.cell_sentence_len, self.output_dim) - target = batch["pert_cell_emb"] - target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - loss = self.loss_fn(pred, target).mean() - self.log("val_loss", loss) - - return {"loss": loss, "predictions": pred} - - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0) -> None: - """Track decoder performance during validation without training it.""" - if self.gene_decoder is not None and "pert_cell_counts" in batch: - # Get model predictions from validation step - latent_preds = outputs["predictions"] - - # Train decoder to map latent predictions to gene space - pert_cell_counts_preds = self.gene_decoder(latent_preds) # verify this is automatically detached - gene_targets = batch["pert_cell_counts"] - - # Get decoder predictions - pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) - gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) - decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() - - # Log the validation metric - self.log("decoder_val_loss", 
decoder_loss) - - def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - pred = self.forward(batch, padded=False) - target = batch["pert_cell_emb"] - pred = pred.reshape(1, -1, self.output_dim) - target = target.reshape(1, -1, self.output_dim) - loss = self.loss_fn(pred, target).mean() - self.log("test_loss", loss) - pred = pred.reshape(-1, self.output_dim) - target = target.reshape(-1, self.output_dim) - - def predict_step(self, batch, batch_idx, padded=True, **kwargs): - """ - Typically used for final inference. We'll replicate old logic: - returning 'preds', 'X', 'pert_name', etc. - """ - latent_output = self.forward(batch) # shape [B, ...] - output_dict = { - "preds": latent_output, - "pert_cell_emb": batch.get("pert_cell_emb", None), - "pert_cell_counts": batch.get("pert_cell_counts", None), - "pert_name": batch.get("pert_name", None), - "celltype_name": batch.get("cell_type", None), - "batch": batch.get("batch", None), - "ctrl_cell_emb": batch.get("ctrl_cell_emb", None), - "pert_cell_barcode": batch.get("pert_cell_barcode", None), - } - - if self.gene_decoder is not None: - pert_cell_counts_preds = self.gene_decoder(latent_output) - output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds - - return output_dict diff --git a/src/state/tx/models/perturb_mean.py b/src/state/tx/models/perturb_mean.py index 618236d9..ba821218 100644 --- a/src/state/tx/models/perturb_mean.py +++ b/src/state/tx/models/perturb_mean.py @@ -216,7 +216,7 @@ def training_step(self, batch, batch_idx): else: target = batch["pert_cell_emb"] loss = self.loss_fn(pred, target) - self.log("train_loss", loss, prog_bar=True) + self.log(self._train_main_loss_key(), loss, prog_bar=True) return None def on_save_checkpoint(self, checkpoint): diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py index 8f7b8577..fa0051dc 100644 --- a/src/state/tx/models/pseudobulk.py +++ b/src/state/tx/models/pseudobulk.py @@ -8,7 +8,6 @@ from geomloss import SamplesLoss from .base import PerturbationModel -from .decoders import FinetuneVCICountsDecoder from .utils import build_mlp, get_activation_class, get_transformer_backbone logger = logging.getLogger(__name__) @@ -104,25 +103,10 @@ def __init__( # actually just set this to a relu for now self.relu = torch.nn.ReLU() - control_pert = kwargs.get("control_pert", "non-targeting") if kwargs.get("finetune_vci_decoder", False): - # Prefer the gene names supplied by the data module (aligned to training output) - gene_names = self.gene_names - if gene_names is None: - raise ValueError( - "finetune_vci_decoder=True but model.gene_names is None. " - "Please provide gene_names via data module var_dims." - ) - - n_genes = len(gene_names) - logger.info( - f"Initializing FinetuneVCICountsDecoder with {n_genes} genes (output_space={output_space}; " - + ("HVG subset" if output_space == "gene" else "all genes") - + ")" - ) - self.gene_decoder = FinetuneVCICountsDecoder( - genes=gene_names, - checkpoint=kwargs.get("vci_checkpoint", None), + logger.warning( + "model.kwargs.finetune_vci_decoder is no longer supported. " + "Ignoring it and using the standard latent-to-gene decoder path." ) print(self) @@ -177,7 +161,7 @@ def _maybe_concat_batch(self, latent: torch.Tensor, batch: torch.Tensor, padded: # Decide whether to concatenate based on the decoder's input expectation if expected_in is None: - # Fallback to previous behavior: concatenate for non-VCI decoders + # Fallback to previous behavior: concatenate batch covariates. 
return torch.cat([latent, batch_var], dim=-1) if expected_in == last_dim: @@ -318,7 +302,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T target = target.reshape(1, -1, self.output_dim) main_loss = self.loss_fn(pred, target).nanmean() - self.log("train_loss", main_loss) + self.log(self._train_main_loss_key(), main_loss) # Process decoder if available decoder_loss = None @@ -331,8 +315,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T # with torch.no_grad(): # latent_preds = pred.detach() # Detach to prevent gradient flow back to main model - if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) + latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) pert_cell_counts_preds = self.gene_decoder(latent_preds) if padded: @@ -343,7 +326,7 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() # Log decoder loss - self.log("decoder_loss", decoder_loss) + self.log(self._train_expression_loss_key(), decoder_loss) total_loss = total_loss + 0.1 * decoder_loss @@ -358,7 +341,7 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non target = target.reshape(-1, self.cell_sentence_len, self.output_dim) loss = torch.nanmean(self.loss_fn(pred, target)) - self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log(self._val_main_loss_key(), loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) if self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] @@ -367,8 +350,7 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non latent_preds = pred # Match decoder input dims - if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) + latent_preds = self._maybe_concat_batch(latent_preds, batch["batch"], padded=True) pert_cell_counts_preds = self.gene_decoder(latent_preds) # Get decoder predictions @@ -377,17 +359,12 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() # Log the validation metric - self.log("decoder_val_loss", decoder_loss) + self.log(self._val_expression_loss_key(), decoder_loss) return {"loss": loss, "predictions": pred} def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - pred = self.forward(batch, padded=False) - target = batch["pert_cell_emb"] - pred = pred.reshape(1, -1, self.output_dim) - target = target.reshape(1, -1, self.output_dim) - loss = self.loss_fn(pred, target).mean() - self.log("test_loss", loss) + _ = self.forward(batch, padded=False) def predict_step(self, batch, batch_idx, padded=True, **kwargs): """ @@ -410,8 +387,7 @@ def predict_step(self, batch, batch_idx, padded=True, **kwargs): if self.gene_decoder is not None: # Only concat batch covariates if decoder expects them - if not isinstance(self.gene_decoder, FinetuneVCICountsDecoder): - latent_output = self._maybe_concat_batch(latent_output, batch["batch"], padded=padded) + latent_output = self._maybe_concat_batch(latent_output, batch["batch"], padded=padded) pert_cell_counts_preds = self.gene_decoder(latent_output) output_dict["pert_cell_counts_preds"] = 
pert_cell_counts_preds diff --git a/src/state/tx/models/scgpt/__init__.py b/src/state/tx/models/scgpt/__init__.py deleted file mode 100644 index 852431e8..00000000 --- a/src/state/tx/models/scgpt/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .generation_model import TransformerGenerator -from .lightning_model import scGPTForPerturbation -from .loss import criterion_neg_log_bernoulli, masked_mse_loss, masked_relative_error -from .utils import map_raw_id_to_vocab_id - -__all__ = [ - "scGPTForPerturbation", - "TransformerGenerator", - "masked_mse_loss", - "criterion_neg_log_bernoulli", - "masked_relative_error", - "map_raw_id_to_vocab_id", -] diff --git a/src/state/tx/models/scgpt/dsbn.py b/src/state/tx/models/scgpt/dsbn.py deleted file mode 100644 index 50719df4..00000000 --- a/src/state/tx/models/scgpt/dsbn.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Optional - -import torch -from torch import nn - - -# The code is modified from https://github.com/wgchang/DSBN/blob/master/model/dsbn.py -class _DomainSpecificBatchNorm(nn.Module): - _version = 2 - - def __init__( - self, - num_features: int, - num_domains: int, - eps: float = 1e-5, - momentum: float = 0.1, - affine: bool = True, - track_running_stats: bool = True, - ): - super(_DomainSpecificBatchNorm, self).__init__() - self._cur_domain = None - self.num_domains = num_domains - self.bns = nn.ModuleList( - [self.bn_handle(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_domains)] - ) - - @property - def bn_handle(self) -> nn.Module: - raise NotImplementedError - - @property - def cur_domain(self) -> Optional[int]: - return self._cur_domain - - @cur_domain.setter - def cur_domain(self, domain_label: int): - self._cur_domain = domain_label - - def reset_running_stats(self): - for bn in self.bns: - bn.reset_running_stats() - - def reset_parameters(self): - for bn in self.bns: - bn.reset_parameters() - - def _check_input_dim(self, input: torch.Tensor): - raise NotImplementedError - - def forward(self, x: torch.Tensor, domain_label: int) -> torch.Tensor: - self._check_input_dim(x) - if domain_label >= self.num_domains: - raise ValueError(f"Domain label {domain_label} exceeds the number of domains {self.num_domains}") - bn = self.bns[domain_label] - self.cur_domain = domain_label - return bn(x) - - -class DomainSpecificBatchNorm1d(_DomainSpecificBatchNorm): - @property - def bn_handle(self) -> nn.Module: - return nn.BatchNorm1d - - def _check_input_dim(self, input: torch.Tensor): - if input.dim() > 3: - raise ValueError("expected at most 3D input (got {}D input)".format(input.dim())) - - -class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm): - @property - def bn_handle(self) -> nn.Module: - return nn.BatchNorm2d - - def _check_input_dim(self, input: torch.Tensor): - if input.dim() != 4: - raise ValueError("expected 4D input (got {}D input)".format(input.dim())) diff --git a/src/state/tx/models/scgpt/gene_tokenizer.py b/src/state/tx/models/scgpt/gene_tokenizer.py deleted file mode 100644 index d70ef996..00000000 --- a/src/state/tx/models/scgpt/gene_tokenizer.py +++ /dev/null @@ -1,398 +0,0 @@ -import json -import pickle -from collections import Counter, OrderedDict -from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import torch - -# import torchtext.vocab as torch_vocab -from torchtext.vocab import Vocab -from typing_extensions import Self - -from .. import logger - - -class GeneVocab(Vocab): - """ - Vocabulary for genes. 
- """ - - def __init__( - self, - gene_list_or_vocab: Union[List[str], Vocab], - specials: Optional[List[str]] = None, - special_first: bool = True, - default_token: Optional[str] = "", - ) -> None: - """ - Initialize the vocabulary. - Note: add specials only works when init from a gene list. - - Args: - gene_list_or_vocab (List[str] or Vocab): List of gene names or a - Vocab object. - specials (List[str]): List of special tokens. - special_first (bool): Whether to add special tokens to the beginning - of the vocabulary. - default_token (str): Default token, by default will set to "", - if "" is in the vocabulary. - """ - if isinstance(gene_list_or_vocab, Vocab): - _vocab = gene_list_or_vocab - if specials is not None: - raise ValueError("receive non-empty specials when init from a Vocab object.") - elif isinstance(gene_list_or_vocab, list): - _vocab = self._build_vocab_from_iterator( - gene_list_or_vocab, - specials=specials, - special_first=special_first, - ) - else: - raise ValueError("gene_list_or_vocab must be a list of gene names or a Vocab object.") - super().__init__(_vocab.vocab) - if default_token is not None and default_token in self: - self.set_default_token(default_token) - - @classmethod - def from_file(cls, file_path: Union[Path, str]) -> Self: - """ - Load the vocabulary from a file. The file should be either a pickle or a - json file of token to index mapping. - """ - if isinstance(file_path, str): - file_path = Path(file_path) - if file_path.suffix == ".pkl": - with file_path.open("rb") as f: - vocab = pickle.load(f) - return cls(vocab) - elif file_path.suffix == ".json": - with file_path.open("r") as f: - token2idx = json.load(f) - return cls.from_dict(token2idx) - else: - raise ValueError(f"{file_path} is not a valid file type. Only .pkl and .json are supported.") - - @classmethod - def from_dict( - cls, - token2idx: Dict[str, int], - default_token: Optional[str] = "", - ) -> Self: - """ - Load the vocabulary from a dictionary. - - Args: - token2idx (Dict[str, int]): Dictionary mapping tokens to indices. - """ - # initiate an empty vocabulary first - _vocab = cls([]) - - # add the tokens to the vocabulary, GeneVocab requires consecutive indices - for t, i in sorted(token2idx.items(), key=lambda x: x[1]): - _vocab.insert_token(t, i) - - if default_token is not None and default_token in _vocab: - _vocab.set_default_token(default_token) - - return _vocab - - def _build_vocab_from_iterator( - self, - iterator: Iterable, - min_freq: int = 1, - specials: Optional[List[str]] = None, - special_first: bool = True, - ) -> Vocab: - """ - Build a Vocab from an iterator. This function is modified from - torchtext.vocab.build_vocab_from_iterator. The original function always - splits tokens into characters, which is not what we want. - - Args: - iterator (Iterable): Iterator used to build Vocab. Must yield list - or iterator of tokens. - min_freq (int): The minimum frequency needed to include a token in - the vocabulary. - specials (List[str]): Special symbols to add. The order of supplied - tokens will be preserved. 
- special_first (bool): Whether to add special tokens to the beginning - - Returns: - torchtext.vocab.Vocab: A `Vocab` object - """ - - counter = Counter() - counter.update(iterator) - - if specials is not None: - for tok in specials: - del counter[tok] - - sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) - sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) - ordered_dict = OrderedDict(sorted_by_freq_tuples) - - if specials is not None: - if special_first: - specials = specials[::-1] - for symbol in specials: - ordered_dict.update({symbol: min_freq}) - ordered_dict.move_to_end(symbol, last=not special_first) - - ## TODO: fix broken usage - # word_vocab = torch_vocab.vocab(ordered_dict, min_freq=min_freq) - # return word_vocab - - @property - def pad_token(self) -> Optional[str]: - """ - Get the pad token. - """ - if getattr(self, "_pad_token", None) is None: - self._pad_token = None - return self._pad_token - - @pad_token.setter - def pad_token(self, pad_token: str) -> None: - """ - Set the pad token. Will not add the pad token to the vocabulary. - - Args: - pad_token (str): Pad token, should be in the vocabulary. - """ - if pad_token not in self: - raise ValueError(f"{pad_token} is not in the vocabulary.") - self._pad_token = pad_token - - def save_json(self, file_path: Union[Path, str]) -> None: - """ - Save the vocabulary to a json file. - """ - if isinstance(file_path, str): - file_path = Path(file_path) - with file_path.open("w") as f: - json.dump(self.get_stoi(), f, indent=2) - - def set_default_token(self, default_token: str) -> None: - """ - Set the default token. - - Args: - default_token (str): Default token. - """ - if default_token not in self: - raise ValueError(f"{default_token} is not in the vocabulary.") - self.set_default_index(self[default_token]) - - -def get_default_gene_vocab() -> GeneVocab: - """ - Get the default gene vocabulary, consisting of gene symbols and ids. - """ - vocab_file = Path(__file__).parent / "default_gene_vocab.json" - if not vocab_file.exists(): - logger.info(f"No existing default vocab, will build one and save to {vocab_file}") - return _build_default_gene_vocab(save_vocab_to=vocab_file) - logger.info(f"Loading gene vocabulary from {vocab_file}") - return GeneVocab.from_file(vocab_file) - - -def _build_default_gene_vocab( - download_source_to: str = "/tmp", - save_vocab_to: Union[Path, str, None] = None, -) -> GeneVocab: - """ - Build the default gene vocabulary from HGNC gene symbols. - - Args: - download_source_to (str): Directory to download the source data. - save_vocab_to (Path or str): Path to save the vocabulary. If None, - the vocabulary will not be saved. Default to None. 
- """ - gene_collection_file = Path(download_source_to) / "human.gene_name_symbol.from_genenames.org.tsv" - if not gene_collection_file.exists(): - # download and save file from url - url = ( - "https://www.genenames.org/cgi-bin/download/custom?col=gd_app_sym&" - "col=md_ensembl_id&status=Approved&status=Entry%20Withdrawn&hgnc_dbtag" - "=on&order_by=gd_app_sym_sort&format=text&submit=submit" - ) - import requests - - r = requests.get(url) - gene_collection_file.write_text(r.text) - - logger.info(f"Building gene vocabulary from {gene_collection_file}") - df = pd.read_csv(gene_collection_file, sep="\t") - gene_list = df["Approved symbol"].dropna().unique().tolist() - gene_vocab = GeneVocab(gene_list) # no special tokens set in default vocab - if save_vocab_to is not None: - gene_vocab.save_json(Path(save_vocab_to)) - return gene_vocab - - -def tokenize_batch( - data: np.ndarray, - gene_ids: np.ndarray, - return_pt: bool = True, - append_cls: bool = True, - include_zero_gene: bool = False, - cls_id: int = "", -) -> List[Tuple[Union[torch.Tensor, np.ndarray]]]: - """ - Tokenize a batch of data. Returns a list of tuple (gene_id, count). - - Args: - data (array-like): A batch of data, with shape (batch_size, n_features). - n_features equals the number of all genes. - gene_ids (array-like): A batch of gene ids, with shape (n_features,). - return_pt (bool): Whether to return torch tensors of gene_ids and counts, - default to True. - - Returns: - list: A list of tuple (gene_id, count) of non zero gene expressions. - """ - if data.shape[1] != len(gene_ids): - raise ValueError( - f"Number of features in data ({data.shape[1]}) does not match number of gene_ids ({len(gene_ids)})." - ) - tokenized_data = [] - for i in range(len(data)): - row = data[i] - if include_zero_gene: - values = row - genes = gene_ids - else: - idx = np.nonzero(row)[0] - values = row[idx] - genes = gene_ids[idx] - if append_cls: - genes = np.insert(genes, 0, cls_id) - values = np.insert(values, 0, 0) - if return_pt: - genes = torch.from_numpy(genes).long() - values = torch.from_numpy(values) - tokenized_data.append((genes, values)) - return tokenized_data - - -def pad_batch( - batch: List[Tuple], - max_len: int, - vocab: Vocab, - pad_token: str = "", - pad_value: int = 0, - cls_appended: bool = True, -) -> Dict[str, torch.Tensor]: - """ - Pad a batch of data. Returns a list of Dict[gene_id, count]. - - Args: - batch (list): A list of tuple (gene_id, count). - max_len (int): The maximum length of the batch. - vocab (Vocab): The vocabulary containing the pad token. - pad_token (str): The token to pad with. - - Returns: - Dict[str, torch.Tensor]: A dictionary of gene_id and count. 
- """ - pad_id = vocab[pad_token] - gene_ids_list = [] - values_list = [] - for i in range(len(batch)): - gene_ids, values = batch[i] - if len(gene_ids) > max_len: - # sample max_len genes - if not cls_appended: - idx = np.random.choice(len(gene_ids), max_len, replace=False) - else: - idx = np.random.choice(len(gene_ids) - 1, max_len - 1, replace=False) - idx = idx + 1 - idx = np.insert(idx, 0, 0) - gene_ids = gene_ids[idx] - values = values[idx] - if len(gene_ids) < max_len: - gene_ids = torch.cat( - [ - gene_ids, - torch.full((max_len - len(gene_ids),), pad_id, dtype=gene_ids.dtype), - ] - ) - values = torch.cat( - [ - values, - torch.full((max_len - len(values),), pad_value, dtype=values.dtype), - ] - ) - gene_ids_list.append(gene_ids) - values_list.append(values) - batch_padded = { - "genes": torch.stack(gene_ids_list, dim=0), - "values": torch.stack(values_list, dim=0), - } - return batch_padded - - -def tokenize_and_pad_batch( - data: np.ndarray, - gene_ids: np.ndarray, - max_len: int, - vocab: Vocab, - pad_token: str, - pad_value: int, - append_cls: bool = True, - include_zero_gene: bool = False, - cls_token: str = "", - return_pt: bool = True, -) -> Dict[str, torch.Tensor]: - """ - Tokenize and pad a batch of data. Returns a list of tuple (gene_id, count). - """ - cls_id = vocab[cls_token] - tokenized_data = tokenize_batch( - data, - gene_ids, - return_pt=return_pt, - append_cls=append_cls, - include_zero_gene=include_zero_gene, - cls_id=cls_id, - ) - batch_padded = pad_batch(tokenized_data, max_len, vocab, pad_token, pad_value, cls_appended=append_cls) - return batch_padded - - -def random_mask_value( - values: Union[torch.Tensor, np.ndarray], - mask_ratio: float = 0.15, - mask_value: int = -1, - pad_value: int = 0, -) -> torch.Tensor: - """ - Randomly mask a batch of data. - - Args: - values (array-like): - A batch of tokenized data, with shape (batch_size, n_features). - mask_ratio (float): The ratio of genes to mask, default to 0.15. - mask_value (int): The value to mask with, default to -1. - pad_value (int): The value of padding in the values, will be kept unchanged. - - Returns: - torch.Tensor: A tensor of masked data. 
- """ - if isinstance(values, torch.Tensor): - # it is crutial to clone the tensor, otherwise it changes the original tensor - values = values.clone().detach().numpy() - else: - values = values.copy() - - for i in range(len(values)): - row = values[i] - non_padding_idx = np.nonzero(row - pad_value)[0] - n_mask = int(len(non_padding_idx) * mask_ratio) - mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False) - row[mask_idx] = mask_value - return torch.from_numpy(values).float() diff --git a/src/state/tx/models/scgpt/generation_model.py b/src/state/tx/models/scgpt/generation_model.py deleted file mode 100644 index 50467692..00000000 --- a/src/state/tx/models/scgpt/generation_model.py +++ /dev/null @@ -1,701 +0,0 @@ -import math -from typing import Mapping, Optional, Union - -import torch -import torch.nn.functional as F -from torch import Tensor, nn -from torch.distributions import Bernoulli -from torch.nn import TransformerEncoder, TransformerEncoderLayer -from tqdm import trange - -from .model import ( - ContinuousValueEncoder, - ExprDecoder, - FastTransformerEncoderWrapper, - FlashTransformerEncoderLayer, - MVCDecoder, -) -from .utils import map_raw_id_to_vocab_id - - -def generate_square_subsequent_mask(sz: int) -> Tensor: - """Generates an upper-triangular matrix of -inf, with zeros on diag.""" - return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) - - -class GeneEncoder(nn.Module): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - ): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - self.enc_norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tensor: - x = self.embedding(x) # (batch, seq_len, embsize) - x = self.enc_norm(x) - return x - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - pe = torch.zeros(max_len, 1, d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer("pe", pe) - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [seq_len, batch_size, embedding_dim] - """ - x = x + self.pe[: x.size(0)] - return self.dropout(x) - - -class Similarity(nn.Module): - """ - Dot product or cosine similarity - """ - - def __init__(self, temp): - super().__init__() - self.temp = temp - self.cos = nn.CosineSimilarity(dim=-1) - - def forward(self, x, y): - return self.cos(x, y) / self.temp - - -class ClsDecoder(nn.Module): - """ - Decoder for classification task. 
- """ - - def __init__( - self, - d_model: int, - n_cls: int, - nlayers: int = 3, - activation: callable = nn.ReLU, - ): - super().__init__() - # module list - self._decoder = nn.ModuleList() - for i in range(nlayers - 1): - self._decoder.append(nn.Linear(d_model, d_model)) - self._decoder.append(activation()) - self._decoder.append(nn.LayerNorm(d_model)) - self.out_layer = nn.Linear(d_model, n_cls) - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [batch_size, embsize] - """ - for layer in self._decoder: - x = layer(x) - return self.out_layer(x) - - -class DrugEncoder(nn.Module): - def __init__(self, d_model: int, n_drug_tokens: int): - super().__init__() - self.embedding = nn.Embedding(n_drug_tokens, d_model) - self.enc_norm = nn.LayerNorm(d_model) - - def forward(self, x: Tensor) -> Tensor: - x = self.embedding(x) - x = self.enc_norm(x) - return x - - -class TransformerGenerator(nn.Module): - def __init__( - self, - ntoken: int, - d_model: int, - nhead: int, - d_hid: int, - nlayers: int, - nlayers_cls: int, - n_cls: int, - pad_token_id: int, - dropout: float = 0.5, - pad_value: int = 0, - pert_pad_id: int = 2, - do_mvc: bool = False, - domain_spec_batchnorm: Union[bool, str] = False, - cell_emb_style: str = "cls", - mvc_decoder_style: str = "inner product", - ecs_threshold: float = 0.3, - explicit_zero_prob: bool = False, - use_fast_transformer: bool = False, - fast_transformer_backend: str = "flash", - pre_norm: bool = False, - ): - super().__init__() - self.model_type = "Transformer" - self.d_model = d_model - self.pad_token_id = pad_token_id - self.pad_value = pad_value - self.pert_pad_id = pert_pad_id - self.ecs_threshold = ecs_threshold - self.domain_spec_batchnorm = domain_spec_batchnorm - self.cell_emb_style = cell_emb_style - self.explicit_zero_prob = explicit_zero_prob - self.norm_scheme = "pre" if pre_norm else "post" - if cell_emb_style not in ["cls", "avg-pool", "w-pool"]: - raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}") - - self.encoder = GeneEncoder(ntoken, d_model, padding_idx=pad_token_id) - self.value_encoder = ContinuousValueEncoder(d_model, dropout) - self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id) - - print("Using simple batchnorm instead of domain specific batchnorm") - self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5) - - if use_fast_transformer: - if fast_transformer_backend == "linear": - self.transformer_encoder = FastTransformerEncoderWrapper(d_model, nhead, d_hid, nlayers, dropout) - elif fast_transformer_backend == "flash": - encoder_layers = FlashTransformerEncoderLayer( - d_model, - nhead, - d_hid, - dropout, - batch_first=True, - norm_scheme=self.norm_scheme, - ) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - else: - encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - - # self.decoder = nn.Linear(d_model, 1) - self.decoder = ExprDecoder( - d_model, - explicit_zero_prob=explicit_zero_prob, - ) - self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls) - if do_mvc: - self.mvc_decoder = MVCDecoder( - d_model, - arch_style=mvc_decoder_style, - explicit_zero_prob=explicit_zero_prob, - ) - - self.sim = Similarity(temp=0.5) - self.creterion_cce = nn.CrossEntropyLoss() - - self.init_weights() - - def init_weights(self) -> None: - initrange = 0.1 - self.encoder.embedding.weight.data.uniform_(-initrange, initrange) - - def _encode( - self, - src: Tensor, - 
values: Tensor, - input_pert_flags, - src_key_padding_mask: Tensor, - ) -> Tensor: - src = self.encoder(src) # (batch, seq_len, embsize) - self.cur_gene_token_embs = src - values = self.value_encoder(values) # (batch, seq_len, embsize) - perts = self.pert_encoder(input_pert_flags) # (batch, seq_len, embsize) - total_embs = src + values + perts - - total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) - output = self.transformer_encoder(total_embs, src_key_padding_mask=src_key_padding_mask) - return output # (batch, seq_len, embsize) - - def _get_cell_emb_from_layer(self, layer_output: Tensor, weights: Tensor = None) -> Tensor: - """ - Args: - layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) - weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used - when :attr:`self.cell_emb_style` is "w-pool". - - Returns: - :obj:`Tensor`: shape (batch, embsize) - """ - if self.cell_emb_style == "cls": - cell_emb = layer_output[:, 0, :] # (batch, embsize) - elif self.cell_emb_style == "avg-pool": - cell_emb = torch.mean(layer_output, dim=1) - elif self.cell_emb_style == "w-pool": - if weights is None: - raise ValueError("weights is required when cell_emb_style is w-pool") - if weights.dim() != 2: - raise ValueError("weights should be 2D") - cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) - cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) - - return cell_emb - - def forward( - self, - src: Tensor, - pert_ids: Tensor, # unused for genetic perturbations but added for compatibility with chemical generator - values: Tensor, - input_pert_flags: Tensor, - src_key_padding_mask: Tensor, - CLS: bool = False, - CCE: bool = False, - MVC: bool = False, - ECS: bool = False, - do_sample: bool = False, - ) -> Mapping[str, Tensor]: - """ - Args: - src (:obj:`Tensor`): token ids, shape [batch_size, seq_len] - values (:obj:`Tensor`): token values, shape [batch_size, seq_len] - src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size, - seq_len] - CLS (:obj:`bool`): if True, return the celltype classification objective - (CLS) output - CCE (:obj:`bool`): if True, return the contrastive cell embedding objective - (CCE) output - MVC (:obj:`bool`): if True, return the masked value prediction for cell - embedding MVC output - ECS (:obj:`bool`): if True, return the elastic cell similarity objective - (ECS) output. - - Returns: - dict of output Tensors. 
- """ - if self.explicit_zero_prob and not do_sample and not self.training: - do_sample = True - # logger.warning("Auto set do_sample to True when model is in eval mode.") - - transformer_output = self._encode(src, values, input_pert_flags, src_key_padding_mask) - output = {} - mlm_output = self.decoder(transformer_output) - if self.explicit_zero_prob and do_sample: - bernoulli = Bernoulli(probs=mlm_output["zero_probs"]) - output["mlm_output"] = bernoulli.sample() * mlm_output["pred"] - else: - output["mlm_output"] = mlm_output["pred"] # (batch, seq_len) - if self.explicit_zero_prob: - output["mlm_zero_probs"] = mlm_output["zero_probs"] - - cell_emb = self._get_cell_emb_from_layer(transformer_output, values) - if CLS: - output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) - if MVC: - mvc_output = self.mvc_decoder( - cell_emb, - self.cur_gene_token_embs, - ) # (batch, seq_len) - if self.explicit_zero_prob and do_sample: - bernoulli = Bernoulli(probs=mvc_output["zero_probs"]) - output["mvc_output"] = bernoulli.sample() * mvc_output["pred"] - else: - output["mvc_output"] = mvc_output["pred"] # (batch, seq_len) - if self.explicit_zero_prob: - output["mvc_zero_probs"] = mvc_output["zero_probs"] - if ECS: - # Here using customized cosine similarity instead of F.cosine_similarity - # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064 - # normalize the embedding - cell_emb_normed = F.normalize(cell_emb, p=2, dim=1) - cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch) - - # mask out diagnal elements - mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device) - cos_sim = cos_sim.masked_fill(mask, 0.0) - # only optimize positive similarities - cos_sim = F.relu(cos_sim) - - output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2) - - return output - - def encode_batch( - self, - src: Tensor, - values: Tensor, - src_key_padding_mask: Tensor, - batch_size: int, - output_to_cpu: bool = True, - ) -> Tensor: - """ - Args: - src: Tensor, shape [N, seq_len] - values: Tensor, shape [N, seq_len] - src_key_padding_mask: Tensor, shape [N, seq_len] - - Returns: - output Tensor of shape [N, seq_len, embsize] - """ - outputs = [] - N = src.size(0) - device = next(self.parameters()).device - for i in trange(0, N, batch_size): - output = self._encode( - src[i : i + batch_size].to(device), - values[i : i + batch_size].to(device), - src_key_padding_mask[i : i + batch_size].to(device), - ) - if output_to_cpu: - output = output.cpu() - outputs.append(output) - return torch.cat(outputs, dim=0) - - def pred_perturb( - self, - batch_data, - include_zero_gene="batch-wise", - gene_ids=None, - amp=True, - ) -> Tensor: - """ - Args: - batch_data: a dictionary of input data with keys. 
- - Returns: - output Tensor of shape [N, seq_len] - """ - self.eval() - device = next(self.parameters()).device - batch_data.to(device) - batch_size = len(batch_data.pert) - x: torch.Tensor = batch_data.x.reshape(batch_size, -1) - ori_gene_values = x[:, 0].view(batch_size, -1) # (batch_size, n_genes) - pert_flags = x[:, 1].long().view(batch_size, -1) - - if include_zero_gene in ["all", "batch-wise"]: - assert gene_ids is not None - if include_zero_gene == "all": - input_gene_ids = torch.arange(ori_gene_values.size(1), device=device) - else: # batch-wise - input_gene_ids = ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0] - input_values = ori_gene_values[:, input_gene_ids] - input_pert_flags = pert_flags[:, input_gene_ids] - - print(input_gene_ids) - - mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids) - mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1) - - src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device) - with torch.cuda.amp.autocast(enabled=amp): - output_dict = self( - mapped_input_gene_ids, - input_values, - input_pert_flags, - src_key_padding_mask=src_key_padding_mask, - CLS=False, - CCE=False, - MVC=False, - ECS=False, - do_sample=True, - ) - output_values = output_dict["mlm_output"].float() - pred_gene_values = torch.zeros_like(ori_gene_values) - pred_gene_values[:, input_gene_ids] = output_values - return pred_gene_values - - -class ChemicalTransformerGenerator(nn.Module): - def __init__( - self, - ntoken: int, - n_drug_tokens: int, - d_model: int, - nhead: int, - d_hid: int, - nlayers: int, - nlayers_cls: int, - n_cls: int, - pad_token_id: int, - dropout: float = 0.5, - pad_value: int = 0, - pert_pad_id: int = 2, - do_mvc: bool = False, - domain_spec_batchnorm: Union[bool, str] = False, - cell_emb_style: str = "cls", - mvc_decoder_style: str = "inner product", - ecs_threshold: float = 0.3, - explicit_zero_prob: bool = False, - use_fast_transformer: bool = False, - fast_transformer_backend: str = "flash", - pre_norm: bool = False, - ): - super().__init__() - self.model_type = "Transformer" - self.d_model = d_model - self.n_drug_tokens = n_drug_tokens - self.pad_token_id = pad_token_id - self.pad_value = pad_value - self.pert_pad_id = pert_pad_id - self.ecs_threshold = ecs_threshold - self.domain_spec_batchnorm = domain_spec_batchnorm - self.cell_emb_style = cell_emb_style - self.explicit_zero_prob = explicit_zero_prob - self.norm_scheme = "pre" if pre_norm else "post" - if cell_emb_style not in ["cls", "avg-pool", "w-pool"]: - raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}") - - self.encoder = GeneEncoder(ntoken, d_model, padding_idx=pad_token_id) - self.drug_encoder = DrugEncoder(d_model, n_drug_tokens) - self.value_encoder = ContinuousValueEncoder(d_model, dropout) - self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id) - - print("Using simple batchnorm instead of domain specific batchnorm") - self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5) - - if use_fast_transformer: - if fast_transformer_backend == "linear": - self.transformer_encoder = FastTransformerEncoderWrapper(d_model, nhead, d_hid, nlayers, dropout) - elif fast_transformer_backend == "flash": - encoder_layers = FlashTransformerEncoderLayer( - d_model, - nhead, - d_hid, - dropout, - batch_first=True, - norm_scheme=self.norm_scheme, - ) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - else: - encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True) 
- self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - - # self.decoder = nn.Linear(d_model, 1) - self.decoder = ExprDecoder( - d_model, - explicit_zero_prob=explicit_zero_prob, - ) - self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls) - if do_mvc: - self.mvc_decoder = MVCDecoder( - d_model, - arch_style=mvc_decoder_style, - explicit_zero_prob=explicit_zero_prob, - ) - - self.sim = Similarity(temp=0.5) - self.creterion_cce = nn.CrossEntropyLoss() - - self.init_weights() - - def init_weights(self) -> None: - initrange = 0.1 - self.encoder.embedding.weight.data.uniform_(-initrange, initrange) - - def _encode( - self, - src: Tensor, - drug_ids: Tensor, - values: Tensor, - input_pert_flags, - src_key_padding_mask: Tensor, - ) -> Tensor: - src = self.encoder(src) # (batch, seq_len, embsize) - drug_embs = self.drug_encoder(drug_ids).unsqueeze(1) # (batch, 1, embsize) - self.cur_gene_token_embs = src - values = self.value_encoder(values) # (batch, seq_len, embsize) - perts = self.pert_encoder(input_pert_flags) # (batch, seq_len, embsize) - total_embs = src + drug_embs + values + perts - - total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) - output = self.transformer_encoder(total_embs, src_key_padding_mask=src_key_padding_mask) - return output # (batch, seq_len, embsize) - - def _get_cell_emb_from_layer(self, layer_output: Tensor, weights: Tensor = None) -> Tensor: - """ - Args: - layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) - weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used - when :attr:`self.cell_emb_style` is "w-pool". - - Returns: - :obj:`Tensor`: shape (batch, embsize) - """ - if self.cell_emb_style == "cls": - cell_emb = layer_output[:, 0, :] # (batch, embsize) - elif self.cell_emb_style == "avg-pool": - cell_emb = torch.mean(layer_output, dim=1) - elif self.cell_emb_style == "w-pool": - if weights is None: - raise ValueError("weights is required when cell_emb_style is w-pool") - if weights.dim() != 2: - raise ValueError("weights should be 2D") - cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) - cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) - - return cell_emb - - def forward( - self, - src: Tensor, - drug_ids: Tensor, - values: Tensor, - input_pert_flags: Tensor, - src_key_padding_mask: Tensor, - CLS: bool = False, - CCE: bool = False, - MVC: bool = False, - ECS: bool = False, - do_sample: bool = False, - ) -> Mapping[str, Tensor]: - """ - Args: - src (:obj:`Tensor`): token ids, shape [batch_size, seq_len] - values (:obj:`Tensor`): token values, shape [batch_size, seq_len] - src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size, - seq_len] - CLS (:obj:`bool`): if True, return the celltype classification objective - (CLS) output - CCE (:obj:`bool`): if True, return the contrastive cell embedding objective - (CCE) output - MVC (:obj:`bool`): if True, return the masked value prediction for cell - embedding MVC output - ECS (:obj:`bool`): if True, return the elastic cell similarity objective - (ECS) output. - - Returns: - dict of output Tensors. 
- """ - if self.explicit_zero_prob and not do_sample and not self.training: - do_sample = True - # logger.warning("Auto set do_sample to True when model is in eval mode.") - - transformer_output = self._encode(src, drug_ids, values, input_pert_flags, src_key_padding_mask) - output = {} - mlm_output = self.decoder(transformer_output) - if self.explicit_zero_prob and do_sample: - bernoulli = Bernoulli(probs=mlm_output["zero_probs"]) - output["mlm_output"] = bernoulli.sample() * mlm_output["pred"] - else: - output["mlm_output"] = mlm_output["pred"] # (batch, seq_len) - if self.explicit_zero_prob: - output["mlm_zero_probs"] = mlm_output["zero_probs"] - - cell_emb = self._get_cell_emb_from_layer(transformer_output, values) - if CLS: - output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) - if MVC: - mvc_output = self.mvc_decoder( - cell_emb, - self.cur_gene_token_embs, - ) # (batch, seq_len) - if self.explicit_zero_prob and do_sample: - bernoulli = Bernoulli(probs=mvc_output["zero_probs"]) - output["mvc_output"] = bernoulli.sample() * mvc_output["pred"] - else: - output["mvc_output"] = mvc_output["pred"] # (batch, seq_len) - if self.explicit_zero_prob: - output["mvc_zero_probs"] = mvc_output["zero_probs"] - if ECS: - # Here using customized cosine similarity instead of F.cosine_similarity - # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064 - # normalize the embedding - cell_emb_normed = F.normalize(cell_emb, p=2, dim=1) - cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch) - - # mask out diagnal elements - mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device) - cos_sim = cos_sim.masked_fill(mask, 0.0) - # only optimize positive similarities - cos_sim = F.relu(cos_sim) - - output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2) - - return output - - def encode_batch( - self, - src: Tensor, - values: Tensor, - src_key_padding_mask: Tensor, - batch_size: int, - output_to_cpu: bool = True, - ) -> Tensor: - """ - Args: - src: Tensor, shape [N, seq_len] - values: Tensor, shape [N, seq_len] - src_key_padding_mask: Tensor, shape [N, seq_len] - - Returns: - output Tensor of shape [N, seq_len, embsize] - """ - outputs = [] - N = src.size(0) - device = next(self.parameters()).device - for i in trange(0, N, batch_size): - output = self._encode( - src[i : i + batch_size].to(device), - values[i : i + batch_size].to(device), - src_key_padding_mask[i : i + batch_size].to(device), - ) - if output_to_cpu: - output = output.cpu() - outputs.append(output) - return torch.cat(outputs, dim=0) - - def pred_perturb( - self, - batch_data, - include_zero_gene="batch-wise", - gene_ids=None, - amp=True, - ) -> Tensor: - """ - Args: - batch_data: a dictionary of input data with keys. 
- - Returns: - output Tensor of shape [N, seq_len] - """ - self.eval() - device = next(self.parameters()).device - batch_data.to(device) - batch_size = len(batch_data.pert) - x: torch.Tensor = batch_data.x.reshape(batch_size, -1) - ori_gene_values = x[:, 0].view(batch_size, -1) # (batch_size, n_genes) - pert_flags = x[:, 1].long().view(batch_size, -1) - - if include_zero_gene in ["all", "batch-wise"]: - assert gene_ids is not None - if include_zero_gene == "all": - input_gene_ids = torch.arange(ori_gene_values.size(1), device=device) - else: # batch-wise - input_gene_ids = ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0] - input_values = ori_gene_values[:, input_gene_ids] - input_pert_flags = pert_flags[:, input_gene_ids] - - print(input_gene_ids) - - mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids) - mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1) - - src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device) - with torch.cuda.amp.autocast(enabled=amp): - output_dict = self( - mapped_input_gene_ids, - input_values, - input_pert_flags, - src_key_padding_mask=src_key_padding_mask, - CLS=False, - CCE=False, - MVC=False, - ECS=False, - do_sample=True, - ) - output_values = output_dict["mlm_output"].float() - pred_gene_values = torch.zeros_like(ori_gene_values) - pred_gene_values[:, input_gene_ids] = output_values - return pred_gene_values diff --git a/src/state/tx/models/scgpt/grad_reverse.py b/src/state/tx/models/scgpt/grad_reverse.py deleted file mode 100644 index 7dc7e612..00000000 --- a/src/state/tx/models/scgpt/grad_reverse.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch -from torch.autograd import Function - - -class GradReverse(Function): - @staticmethod - def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor: - ctx.lambd = lambd - return x.view_as(x) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - return grad_output.neg() * ctx.lambd, None - - -def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor: - return GradReverse.apply(x, lambd) diff --git a/src/state/tx/models/scgpt/lightning_model.py b/src/state/tx/models/scgpt/lightning_model.py deleted file mode 100644 index c18264c8..00000000 --- a/src/state/tx/models/scgpt/lightning_model.py +++ /dev/null @@ -1,321 +0,0 @@ -from typing import List, Literal, Optional, Union - -import torch -from torchmetrics.functional import pearson_corrcoef, r2_score - -from ..base import PerturbationModel -from .generation_model import ChemicalTransformerGenerator, TransformerGenerator -from .loss import masked_mse_loss -from .utils import map_raw_id_to_vocab_id - - -class scGPTForPerturbation(PerturbationModel): - def __init__( - self, - ntoken: int, - d_model: int, - nhead: int, - d_hid: int, - nlayers: int, - nlayers_cls: int, - n_cls: int, - n_drug_tokens: int = 0, - dropout: float = 0.5, - pad_token_id: int = 0, - pad_value: int = 0, - pert_pad_id: int = 2, - do_mvc: bool = False, - domain_spec_batchnorm: Union[bool, str] = False, - cell_emb_style: str = "cls", - mvc_decoder_style: str = "inner product", - ecs_threshold: float = 0.3, - explicit_zero_prob: bool = False, - use_fast_transformer: bool = False, - fast_transformer_backend: str = "flash", - pre_norm: bool = False, - lr: float = 1e-4, - step_size_lr: int = 1, - include_zero_gene: Optional[Literal["all", "batch-wise"]] = "all", - max_seq_len: int = 1536, - do_CLS: bool = True, - do_CCE: bool = False, - do_MVC: bool = False, - do_ECS: bool = False, 
- gene_names: Optional[List[str]] = None, - embed_key: Optional[str] = None, - perturbation_type: Literal["chemical", "genetic"] = "chemical", - **kwargs, - ): - super().__init__( - input_dim=None, - hidden_dim=None, - output_dim=None, - pert_dim=None, - dropout=dropout, - lr=lr, - loss_fn="mse", - embed_key=embed_key, - output_space="gene", - gene_names=gene_names, - batch_size=64, - ) - self.ntoken = ntoken - self.d_model = d_model - self.nhead = nhead - self.d_hid = d_hid - self.nlayers = nlayers - self.nlayers_cls = nlayers_cls - self.n_cls = n_cls - self.n_drug_tokens = n_drug_tokens - self.dropout = dropout - self.pad_token_id = pad_token_id - self.pad_value = pad_value - self.pert_pad_id = pert_pad_id - self.do_mvc = do_mvc - self.domain_spec_batchnorm = domain_spec_batchnorm - self.cell_emb_style = cell_emb_style - self.mvc_decoder_style = mvc_decoder_style - self.ecs_threshold = ecs_threshold - self.explicit_zero_prob = explicit_zero_prob - self.use_fast_transformer = use_fast_transformer - self.fast_transformer_backend = fast_transformer_backend - self.pre_norm = pre_norm - self.perturbation_type = perturbation_type.lower() - - for k, v in kwargs.items(): - print(f"WARNING: scGPTForPerturbation Model unused kwarg: {k}") - - self.lr = lr - self.step_size_lr = step_size_lr - self.include_zero_gene = include_zero_gene - self.max_seq_len = max_seq_len - self.do_CLS = do_CLS - self.do_CCE = do_CCE - self.do_MVC = do_MVC - self.do_ECS = do_ECS - - assert self.perturbation_type in ["chemical", "genetic"], ( - "perturbation_type must be either 'chemical' or 'genetic'" - ) - - if self.perturbation_type == "chemical": - assert self.n_drug_tokens > 0, "n_drug_tokens must be greater than 0 for chemical perturbation" - - self.save_hyperparameters() - - self.validation_outputs = [] - - self._build_networks() - - def _build_networks(self): - generator_params = dict( - ntoken=self.ntoken, - d_model=self.d_model, - nhead=self.nhead, - d_hid=self.d_hid, - nlayers=self.nlayers, - nlayers_cls=self.nlayers_cls, - n_cls=self.n_cls, - dropout=self.dropout, - pad_token_id=self.pad_token_id, - pad_value=self.pad_value, - pert_pad_id=self.pert_pad_id, - do_mvc=self.do_mvc, - domain_spec_batchnorm=self.domain_spec_batchnorm, - cell_emb_style=self.cell_emb_style, - mvc_decoder_style=self.mvc_decoder_style, - ecs_threshold=self.ecs_threshold, - explicit_zero_prob=self.explicit_zero_prob, - use_fast_transformer=self.use_fast_transformer, - fast_transformer_backend=self.fast_transformer_backend, - pre_norm=self.pre_norm, - ) - if self.perturbation_type == "chemical": - self.model = ChemicalTransformerGenerator( - n_drug_tokens=self.n_drug_tokens, - **generator_params, - ) - elif self.perturbation_type == "genetic": - self.model = TransformerGenerator(**generator_params) - - def encode_perturbation(self, pert: torch.Tensor) -> torch.Tensor: - """Map perturbation to an effect vector in embedding space.""" - raise NotImplementedError("Perturbation encoding not supported for scGPT model") - - def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: - """Expression is already in embedding space, pass through.""" - raise NotImplementedError("Basal expression encoding not supported for scGPT model") - - def perturb(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor: - """ - Given a perturbation and basal embeddings, compute the perturbed embedding. 
- """ - # Project perturbation and basal cell state to latent space - raise NotImplementedError("Perturb function not supported for scGPT model") - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.step_size_lr, gamma=0.9) - return [optimizer], [scheduler] - - def shared_step( - self, - batch, - truncate=True, - ): - x_basal = batch["ctrl_cell_emb"] # (batch_size, n_genes) - x_pert = batch["pert_cell_emb"] # (batch_size, n_genes) - pert_ids = batch["pert_emb"].argmax(dim=1) - if self.perturbation_type == "chemical": - pert_flags = torch.zeros_like(x_pert, dtype=torch.long) # no genes are perturbed - else: - pert_flags = batch["pert_flags"] # (batch_size, n_genes) - - gene_ids = batch["gene_ids"][0] # (n_genes,) - - nonpad_genes_mask = gene_ids != self.pad_token_id - - x_basal = x_basal[:, nonpad_genes_mask] - x_pert = x_pert[:, nonpad_genes_mask] - gene_ids = gene_ids[nonpad_genes_mask] - pert_flags = pert_flags[:, nonpad_genes_mask] - - batch_size, n_genes = x_basal.size() - - if self.include_zero_gene == "all": - input_gene_ids = torch.arange(n_genes, device=x_basal.device, dtype=torch.long) - else: - input_gene_ids = x_basal.nonzero()[:, 1].flatten().unique().sort()[0] # TODO: need double-check - - # sample input_gene_id - if truncate: - if len(input_gene_ids) > self.max_seq_len: - input_gene_ids = torch.randperm(len(input_gene_ids), device=x_basal.device)[: self.max_seq_len] - - x_basal = x_basal[:, input_gene_ids] - x_pert = x_pert[:, input_gene_ids] - pert_flags = pert_flags[:, input_gene_ids] - - mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids) - mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1) # (batch_size, max_seq_len) - - src_key_padding_mask = mapped_input_gene_ids.eq(self.pad_token_id) - # src_key_padding_mask = torch.zeros_like( - # x_basal, dtype=torch.bool, device=x_basal.device - # ) - - if self.perturbation_type == "genetic": - pert_ids = None - - output_dict = self.model( - mapped_input_gene_ids.long(), - pert_ids, - x_basal, - pert_flags.long(), - src_key_padding_mask=src_key_padding_mask, - CLS=self.do_CLS and self.training, - CCE=self.do_CCE and self.training, - MVC=self.do_MVC and self.training, - ECS=self.do_ECS and self.training, - ) - output_values = output_dict["mlm_output"] # batch_size, max_seq_len - - masked_positions = torch.ones_like(x_basal, dtype=torch.bool) # Use all genes - loss = masked_mse_loss(output_values, x_pert, masked_positions) - - output_dict["x_pred"] = output_dict["mlm_output"].float() - output_dict["x_true"] = x_pert - output_dict["x_basal"] = x_basal - - return loss, output_dict - - def compute_metrics(self, x_basal, x_pred, x_true): - """ - Computes a sets of evaluation metrics for assessing perturbation response prediction's quality. - - Parameters - ---------- - x_pred : torch.Tensor of shape (batch_size, n_genes) - The predicted perturbation response. - x_true : torch.Tensor of shape (batch_size, n_genes) - The true perturbation response. 
- """ - r2_mean = r2_score(x_pred.mean(0), x_true.mean(0)) # R2 score for mean gene expression - - change_pred = x_pred - x_basal - change_true = x_true - x_basal - - pearson_mean_lfc = pearson_corrcoef(change_pred.mean(0), change_true.mean(0)) - - return { - "r2_mean": r2_mean, - "pearson_mean_lfc": pearson_mean_lfc, - } - - def training_step(self, batch, batch_idx): - loss, batch_outputs = self.shared_step(batch) - - metrics = self.compute_metrics( - x_basal=batch_outputs["x_basal"], - x_pred=batch_outputs["x_pred"], - x_true=batch_outputs["x_true"], - ) - - self.log("loss", loss, prog_bar=True) - - for k, v in metrics.items(): - self.log(k, v, prog_bar=True) - - return loss - - def validation_step(self, batch, batch_idx): - loss, batch_outputs = self.shared_step(batch) - - metrics = self.compute_metrics( - x_basal=batch_outputs["x_basal"], - x_pred=batch_outputs["x_pred"], - x_true=batch_outputs["x_true"], - ) - - # self.validation_outputs.append( - # ( - # batch_outputs["x_basal"].detach().cpu(), - # batch_outputs["x_pred"].detach().cpu(), - # batch_outputs["x_true"].detach().cpu(), - # ) - # ) - - self.log("val_loss", loss, prog_bar=True) - - for k, v in metrics.items(): - self.log(f"val_{k}", v, prog_bar=True) - - return loss - - # def validation_epoch_end(self, outputs): - # pass - - @torch.no_grad() - def predict(self, batch): - loss, batch_outputs = self.shared_step(batch, truncate=False) - - x_pred = batch_outputs["x_pred"] - return x_pred - - def predict_step(self, batch, batch_idx, **kwargs): - """ - Typically used for final inference. We'll replicate old logic: - returning 'pred', 'X', 'pert_name', etc. - """ - loss, batch_outputs = self.shared_step(batch, truncate=False) - - return { - "preds": batch_outputs["x_pred"].float(), - "pert_cell_emb": batch_outputs["x_true"].float(), - "X_gene": batch_outputs["x_true"].float(), - "pert_emb": batch.get("pert_emb", None), - "pert_name": batch.get("pert_name", None), - "cell_type": batch.get("cell_type", None), - "batch": batch.get("batch_name", None), - "ctrl_cell_emb": batch_outputs["x_basal"].float(), - } diff --git a/src/state/tx/models/scgpt/loss.py b/src/state/tx/models/scgpt/loss.py deleted file mode 100644 index 6d734124..00000000 --- a/src/state/tx/models/scgpt/loss.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -import torch.nn.functional as F - - -def masked_mse_loss(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """ - Compute the masked MSE loss between input and target. - """ - mask = mask.float() - loss = F.mse_loss(input * mask, target * mask, reduction="sum") - return loss / mask.sum() - - -def criterion_neg_log_bernoulli(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """ - Compute the negative log-likelihood of Bernoulli distribution - """ - mask = mask.float() - bernoulli = torch.distributions.Bernoulli(probs=input) - masked_log_probs = bernoulli.log_prob((target > 0).float()) * mask - return -masked_log_probs.sum() / mask.sum() - - -def masked_relative_error(input: torch.Tensor, target: torch.Tensor, mask: torch.LongTensor) -> torch.Tensor: - """ - Compute the masked relative error between input and target. 
- """ - assert mask.any() - loss = torch.abs(input[mask] - target[mask]) / (target[mask] + 1e-6) - return loss.mean() diff --git a/src/state/tx/models/scgpt/model.py b/src/state/tx/models/scgpt/model.py deleted file mode 100644 index 21c96ae0..00000000 --- a/src/state/tx/models/scgpt/model.py +++ /dev/null @@ -1,976 +0,0 @@ -import math -from typing import Any, Dict, Mapping, Optional, Union - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn.functional as F -from fast_transformers.masking import LengthMask -from torch import Tensor, nn -from torch.distributions import Bernoulli -from torch.nn import TransformerEncoder, TransformerEncoderLayer -from tqdm import trange - -try: - from flash_attn.flash_attention import FlashMHA -except ImportError: - import warnings - - warnings.warn("flash_attn is not installed") - -from .dsbn import DomainSpecificBatchNorm1d -from .grad_reverse import grad_reverse - - -class TransformerModel(nn.Module): - def __init__( - self, - ntoken: int, - d_model: int, - nhead: int, - d_hid: int, - nlayers: int, - nlayers_cls: int = 3, - n_cls: int = 1, - vocab: Any = None, - dropout: float = 0.5, - pad_token: str = "", - pad_value: int = 0, - do_mvc: bool = False, - do_dab: bool = False, - use_batch_labels: bool = False, - num_batch_labels: Optional[int] = None, - domain_spec_batchnorm: Union[bool, str] = False, - input_emb_style: str = "continuous", - n_input_bins: Optional[int] = None, - cell_emb_style: str = "cls", - mvc_decoder_style: str = "inner product", - ecs_threshold: float = 0.3, - explicit_zero_prob: bool = False, - use_fast_transformer: bool = False, - fast_transformer_backend: str = "flash", - pre_norm: bool = False, - ): - super().__init__() - self.model_type = "Transformer" - self.d_model = d_model - self.do_dab = do_dab - self.ecs_threshold = ecs_threshold - self.use_batch_labels = use_batch_labels - self.domain_spec_batchnorm = domain_spec_batchnorm - self.input_emb_style = input_emb_style - self.cell_emb_style = cell_emb_style - self.explicit_zero_prob = explicit_zero_prob - self.norm_scheme = "pre" if pre_norm else "post" - if self.input_emb_style not in ["category", "continuous", "scaling"]: - raise ValueError(f"input_emb_style should be one of category, continuous, scaling, got {input_emb_style}") - if cell_emb_style not in ["cls", "avg-pool", "w-pool"]: - raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}") - - # TODO: add dropout in the GeneEncoder - self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token]) - - # Value Encoder, NOTE: the scaling style is also handled in _encode method - if input_emb_style == "continuous": - self.value_encoder = ContinuousValueEncoder(d_model, dropout) - elif input_emb_style == "category": - assert n_input_bins > 0 - self.value_encoder = CategoryValueEncoder(n_input_bins, d_model, padding_idx=pad_value) - else: - self.value_encoder = nn.Identity() # nn.Softmax(dim=1) - # TODO: consider row-wise normalization or softmax - # TODO: Correct handle the mask_value when using scaling - - # Batch Encoder - if use_batch_labels: - self.batch_encoder = BatchLabelEncoder(num_batch_labels, d_model) - - if domain_spec_batchnorm: - use_affine = True if domain_spec_batchnorm == "do_affine" else False - print(f"Use domain specific batchnorm with affine={use_affine}") - self.dsbn = DomainSpecificBatchNorm1d(d_model, num_batch_labels, eps=6.1e-5, affine=use_affine) - else: - print("Using simple batchnorm instead of domain specific batchnorm") - self.bn = 
nn.BatchNorm1d(d_model, eps=6.1e-5) - - if use_fast_transformer: - if fast_transformer_backend == "linear": - self.transformer_encoder = FastTransformerEncoderWrapper(d_model, nhead, d_hid, nlayers, dropout) - elif fast_transformer_backend == "flash": - encoder_layers = FlashTransformerEncoderLayer( - d_model, - nhead, - d_hid, - dropout, - batch_first=True, - norm_scheme=self.norm_scheme, - ) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - else: - encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - - self.decoder = ExprDecoder( - d_model, - explicit_zero_prob=explicit_zero_prob, - use_batch_labels=use_batch_labels, - ) - self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls) - if do_mvc: - self.mvc_decoder = MVCDecoder( - d_model, - arch_style=mvc_decoder_style, - explicit_zero_prob=explicit_zero_prob, - use_batch_labels=use_batch_labels, - ) - - if do_dab: - self.grad_reverse_discriminator = AdversarialDiscriminator( - d_model, - n_cls=num_batch_labels, - reverse_grad=True, - ) - - self.sim = Similarity(temp=0.5) # TODO: auto set temp - self.creterion_cce = nn.CrossEntropyLoss() - - self.init_weights() - - def init_weights(self) -> None: - initrange = 0.1 - # TODO: check if this initialization is helpful and shall we apply to all? - self.encoder.embedding.weight.data.uniform_(-initrange, initrange) - - def _encode( - self, - src: Tensor, - values: Tensor, - src_key_padding_mask: Tensor, - batch_labels: Optional[Tensor] = None, # (batch,) - ) -> Tensor: - self._check_batch_labels(batch_labels) - - src = self.encoder(src) # (batch, seq_len, embsize) - self.cur_gene_token_embs = src - - values = self.value_encoder(values) # (batch, seq_len, embsize) - if self.input_emb_style == "scaling": - values = values.unsqueeze(2) - total_embs = src * values - else: - total_embs = src + values - - if self.domain_spec_batchnorm: - batch_label = int(batch_labels[0].item()) - total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute( - 0, 2, 1 - ) # the batch norm always works on dim 1 - else: - total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) - - output = self.transformer_encoder(total_embs, src_key_padding_mask=src_key_padding_mask) - return output # (batch, seq_len, embsize) - - def _get_cell_emb_from_layer(self, layer_output: Tensor, weights: Tensor = None) -> Tensor: - """ - Args: - layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) - weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used - when :attr:`self.cell_emb_style` is "w-pool". 
- - Returns: - :obj:`Tensor`: shape (batch, embsize) - """ - if self.cell_emb_style == "cls": - cell_emb = layer_output[:, 0, :] # (batch, embsize) - elif self.cell_emb_style == "avg-pool": - cell_emb = torch.mean(layer_output, dim=1) - elif self.cell_emb_style == "w-pool": - if weights is None: - raise ValueError("weights is required when cell_emb_style is w-pool") - if weights.dim() != 2: - raise ValueError("weights should be 2D") - cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) - cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) - - return cell_emb - - def _check_batch_labels(self, batch_labels: Tensor) -> None: - if self.use_batch_labels or self.domain_spec_batchnorm: - assert batch_labels is not None - elif batch_labels is not None: - raise ValueError( - "batch_labels should only be provided when `self.use_batch_labels`" - " or `self.domain_spec_batchnorm` is True" - ) - - def generate( - self, - cell_emb: Tensor, - src: Tensor, - values: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - gen_iters: int = 1, - batch_labels: Optional[Tensor] = None, # (batch,) - ) -> Tensor: - """ - Args: - cell_emb(:obj:`Tensor`): shape (batch, embsize) - src(:obj:`Tensor`): shape (batch, seq_len) - values(:obj:`Tensor`): shape (batch, seq_len), optional - src_key_padding_mask(:obj:`Tensor`): shape (batch, seq_len), optional - gen_iters(:obj:`int`): number of generation iterations - batch_labels(:obj:`Tensor`): shape (batch,), optional - """ - # TODO: should have a tag indicate the generation mode - # TODO: if gen_iters > 1, should have a tag indicate the current iteration - try: - self._check_batch_labels(batch_labels) - # TODO: handle raw exception - except: - import warnings - - warnings.warn("batch_labels is required but not provided, using zeros instead") - batch_labels = torch.zeros(cell_emb.shape[0], dtype=torch.long, device=cell_emb.device) - - src = self.encoder(src) # (batch, seq_len, embsize) - - if values is not None: - values = self.value_encoder(values) # (batch, seq_len, embsize) - if self.input_emb_style == "scaling": - values = values.unsqueeze(2) - total_embs = src * values - else: - total_embs = src + values - else: - total_embs = src - - if self.domain_spec_batchnorm: - batch_label = int(batch_labels[0].item()) - total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute( - 0, 2, 1 - ) # the batch norm always works on dim 1 - else: - total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) - - total_embs[:, 0, :] = cell_emb - - if src_key_padding_mask is None: - src_key_padding_mask = torch.zeros(total_embs.shape[:2], dtype=torch.bool, device=total_embs.device) - transformer_output = self.transformer_encoder(total_embs, src_key_padding_mask=src_key_padding_mask) - - if self.use_batch_labels: - batch_emb = self.batch_encoder(batch_labels) # (batch, embsize) - mlm_output = self.decoder( - ( - transformer_output - if not self.use_batch_labels - else torch.cat( - [ - transformer_output, - batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1), - ], - dim=2, - ) - ), - # else transformer_output + batch_emb.unsqueeze(1), - ) - output = mlm_output["pred"] # (batch, seq_len) - - return output # (batch, seq_len) - - def forward( - self, - src: Tensor, - values: Tensor, - src_key_padding_mask: Tensor, - batch_labels: Optional[Tensor] = None, - CLS: bool = False, - CCE: bool = False, - MVC: bool = False, - ECS: bool = False, - do_sample: bool = False, - ) -> Mapping[str, Tensor]: - """ - Args: - src 
(:obj:`Tensor`): token ids, shape [batch_size, seq_len] - values (:obj:`Tensor`): token values, shape [batch_size, seq_len] - src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size, - seq_len] - batch_labels (:obj:`Tensor`): batch labels, shape [batch_size] - CLS (:obj:`bool`): if True, return the celltype classification objective - (CLS) output - CCE (:obj:`bool`): if True, return the contrastive cell embedding objective - (CCE) output - MVC (:obj:`bool`): if True, return the masked value prediction for cell - embedding MVC output - ECS (:obj:`bool`): if True, return the elastic cell similarity objective - (ECS) output. - - Returns: - dict of output Tensors. - """ - transformer_output = self._encode(src, values, src_key_padding_mask, batch_labels) - if self.use_batch_labels: - batch_emb = self.batch_encoder(batch_labels) # (batch, embsize) - - output = {} - mlm_output = self.decoder( - ( - transformer_output - if not self.use_batch_labels - else torch.cat( - [ - transformer_output, - batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1), - ], - dim=2, - ) - ), - # else transformer_output + batch_emb.unsqueeze(1), - ) - if self.explicit_zero_prob and do_sample: - bernoulli = Bernoulli(probs=mlm_output["zero_probs"]) - output["mlm_output"] = bernoulli.sample() * mlm_output["pred"] - else: - output["mlm_output"] = mlm_output["pred"] # (batch, seq_len) - if self.explicit_zero_prob: - output["mlm_zero_probs"] = mlm_output["zero_probs"] - - cell_emb = self._get_cell_emb_from_layer(transformer_output, values) - output["cell_emb"] = cell_emb - - if CLS: - output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) - if CCE: - cell1 = cell_emb - transformer_output2 = self._encode(src, values, src_key_padding_mask, batch_labels) - cell2 = self._get_cell_emb_from_layer(transformer_output2) - - # Gather embeddings from all devices if distributed training - if dist.is_initialized() and self.training: - cls1_list = [torch.zeros_like(cell1) for _ in range(dist.get_world_size())] - cls2_list = [torch.zeros_like(cell2) for _ in range(dist.get_world_size())] - dist.all_gather(tensor_list=cls1_list, tensor=cell1.contiguous()) - dist.all_gather(tensor_list=cls2_list, tensor=cell2.contiguous()) - - # NOTE: all_gather results have no gradients, so replace the item - # of the current rank with the original tensor to keep gradients. - # See https://github.com/princeton-nlp/SimCSE/blob/main/simcse/models.py#L186 - cls1_list[dist.get_rank()] = cell1 - cls2_list[dist.get_rank()] = cell2 - - cell1 = torch.cat(cls1_list, dim=0) - cell2 = torch.cat(cls2_list, dim=0) - # TODO: should detach the second run cls2? 
Can have a try - cos_sim = self.sim(cell1.unsqueeze(1), cell2.unsqueeze(0)) # (batch, batch) - labels = torch.arange(cos_sim.size(0)).long().to(cell1.device) - output["loss_cce"] = self.creterion_cce(cos_sim, labels) - if MVC: - mvc_output = self.mvc_decoder( - (cell_emb if not self.use_batch_labels else torch.cat([cell_emb, batch_emb], dim=1)), - # else cell_emb + batch_emb, - self.cur_gene_token_embs, - ) - if self.explicit_zero_prob and do_sample: - bernoulli = Bernoulli(probs=mvc_output["zero_probs"]) - output["mvc_output"] = bernoulli.sample() * mvc_output["pred"] - else: - output["mvc_output"] = mvc_output["pred"] # (batch, seq_len) - if self.explicit_zero_prob: - output["mvc_zero_probs"] = mvc_output["zero_probs"] - if ECS: - # Here using customized cosine similarity instead of F.cosine_similarity - # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064 - # normalize the embedding - cell_emb_normed = F.normalize(cell_emb, p=2, dim=1) - cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch) - - # mask out diagnal elements - mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device) - cos_sim = cos_sim.masked_fill(mask, 0.0) - # only optimize positive similarities - cos_sim = F.relu(cos_sim) - - output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2) - - if self.do_dab: - output["dab_output"] = self.grad_reverse_discriminator(cell_emb) - - return output - - def encode_batch( - self, - src: Tensor, - values: Tensor, - src_key_padding_mask: Tensor, - batch_size: int, - batch_labels: Optional[Tensor] = None, - output_to_cpu: bool = True, - time_step: Optional[int] = None, - return_np: bool = False, - ) -> Tensor: - """ - Args: - src (Tensor): shape [N, seq_len] - values (Tensor): shape [N, seq_len] - src_key_padding_mask (Tensor): shape [N, seq_len] - batch_size (int): batch size for encoding - batch_labels (Tensor): shape [N, n_batch_labels] - output_to_cpu (bool): whether to move the output to cpu - time_step (int): the time step index in the transformer output to return. - The time step is along the second dimenstion. If None, return all. 
- return_np (bool): whether to return numpy array - - Returns: - output Tensor of shape [N, seq_len, embsize] - """ - N = src.size(0) - device = next(self.parameters()).device - - # initialize the output tensor - array_func = np.zeros if return_np else torch.zeros - float32_ = np.float32 if return_np else torch.float32 - shape = (N, self.d_model) if time_step is not None else (N, src.size(1), self.d_model) - outputs = array_func(shape, dtype=float32_) - - for i in trange(0, N, batch_size): - raw_output = self._encode( - src[i : i + batch_size].to(device), - values[i : i + batch_size].to(device), - src_key_padding_mask[i : i + batch_size].to(device), - (batch_labels[i : i + batch_size].to(device) if batch_labels is not None else None), - ) - output = raw_output.detach() - if output_to_cpu: - output = output.cpu() - if return_np: - output = output.numpy() - if time_step is not None: - output = output[:, time_step, :] - outputs[i : i + batch_size] = output - - return outputs - - -def generate_square_subsequent_mask(sz: int) -> Tensor: - """Generates an upper-triangular matrix of -inf, with zeros on diag.""" - return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) - - -class FastTransformerEncoderWrapper(nn.Module): - def __init__( - self, - d_model: int, - nhead: int, - d_hid: int, - nlayers: int, - dropout: float = 0.5, - ): - super().__init__() - self.fast_transformer_encoder = self.build_fast_transformer_encoder(d_model, nhead, d_hid, nlayers, dropout) - - @staticmethod - def build_fast_transformer_encoder(d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float) -> nn.Module: - from fast_transformers.builders import TransformerEncoderBuilder - - if d_model % nhead != 0: - raise ValueError(f"d_model must be divisible by nhead, got d_model={d_model} and nhead={nhead}") - builder = TransformerEncoderBuilder.from_kwargs( - n_layers=nlayers, - n_heads=nhead, - query_dimensions=d_model // nhead, - value_dimensions=d_model // nhead, - feed_forward_dimensions=d_hid, - attention_type="linear", - attention_dropout=dropout, - dropout=dropout, - activation="gelu", - ) - assert builder.attention_type == "linear" - return builder.get() - - @staticmethod - def build_length_mask( - src: Tensor, - src_key_padding_mask: torch.BoolTensor, - ) -> LengthMask: - from fast_transformers.masking import LengthMask - - seq_len = src.shape[1] - num_paddings = src_key_padding_mask.sum(dim=1) - actual_seq_len = seq_len - num_paddings # (N,) - length_mask = LengthMask(actual_seq_len, max_len=seq_len, device=src.device) - - if src_key_padding_mask[length_mask.bool_matrix].sum() != 0: - raise ValueError( - "Found padding tokens in the middle of the sequence. " - "src_key_padding_mask and length_mask are not compatible." 
- ) - return length_mask - - def forward( - self, - src: Tensor, - src_key_padding_mask: torch.BoolTensor, - ) -> Tensor: - """ - Args: - src: Tensor, shape [N, seq_len, embsize] - src_key_padding_mask: Tensor, shape [N, seq_len] - - Returns: - output Tensor of shape [N, seq_len, embsize] - """ - if src_key_padding_mask.shape != src.shape[:2]: - raise ValueError( - f"src_key_padding_mask shape {src_key_padding_mask.shape} " - f"does not match first two dims of src shape {src.shape[:2]}" - ) - - if src_key_padding_mask.dtype != torch.bool: - raise ValueError(f"src_key_padding_mask needs to be of type torch.bool, got {src_key_padding_mask.dtype}") - - length_mask = self.build_length_mask(src, src_key_padding_mask) - output = self.fast_transformer_encoder(src, length_mask=length_mask) - return output - - -class FlashTransformerEncoderLayer(nn.Module): - r"""TransformerEncoderLayer is made up of self-attn and feedforward network. - The class is modified from torch.nn.TransformerEncoderLayer to support the - FlashAttention. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). - layer_norm_eps: the eps value in layer normalization components (default=1e-5). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False``. - - Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - - Alternatively, when ``batch_first`` is ``True``: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) - >>> src = torch.rand(32, 10, 512) - >>> out = encoder_layer(src) - """ - - __constants__ = ["batch_first"] - - def __init__( - self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - layer_norm_eps=1e-5, - batch_first=True, - device=None, - dtype=None, - norm_scheme="post", # "pre" or "post" - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.self_attn = FlashMHA( - embed_dim=d_model, - num_heads=nhead, - batch_first=batch_first, - attention_dropout=dropout, - **factory_kwargs, - ) - self.self_attn.batch_first = batch_first - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) - - self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = self._get_activation_fn(activation) - self.norm_scheme = norm_scheme - if self.norm_scheme not in ["pre", "post"]: - raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}") - - @staticmethod - def _get_activation_fn(activation): - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = F.relu - super().__setstate__(state) - - def forward( - self, - 
src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - **kwargs, - ) -> Tensor: - r"""Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - see the docs in Transformer class. - """ - if src_mask is not None: - raise ValueError("FlashTransformerEncoderLayer does not support src_mask") - - if not src_key_padding_mask.any().item(): - # no padding tokens in src - src_key_padding_mask_ = None - else: - if src_key_padding_mask.dtype != torch.bool: - src_key_padding_mask = src_key_padding_mask.bool() - # NOTE: the FlashMHA uses mask 0 for padding tokens, which is the opposite - src_key_padding_mask_ = ~src_key_padding_mask - - if self.norm_scheme == "pre": - src = self.norm1(src) - src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0] - src = src + self.dropout1(src2) - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = src + self.dropout2(src2) - else: - src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0] - src = src + self.dropout1(src2) - src = self.norm1(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = src + self.dropout2(src2) - src = self.norm2(src) - - return src - - -class GeneEncoder(nn.Module): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - ): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - self.enc_norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tensor: - x = self.embedding(x) # (batch, seq_len, embsize) - x = self.enc_norm(x) - return x - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - pe = torch.zeros(max_len, 1, d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer("pe", pe) - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [seq_len, batch_size, embedding_dim] - """ - x = x + self.pe[: x.size(0)] - return self.dropout(x) - - -class ContinuousValueEncoder(nn.Module): - """ - Encode real number values to a vector using neural nets projection. 
- """ - - def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - self.linear1 = nn.Linear(1, d_model) - self.activation = nn.ReLU() - self.linear2 = nn.Linear(d_model, d_model) - self.norm = nn.LayerNorm(d_model) - self.max_value = max_value - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [batch_size, seq_len] - """ - # TODO: test using actual embedding layer if input is categorical - # expand last dimension - x = x.unsqueeze(-1) - # clip x to [-inf, max_value] - x = torch.clamp(x, max=self.max_value) - x = self.activation(self.linear1(x)) - x = self.linear2(x) - x = self.norm(x) - return self.dropout(x) - - -class CategoryValueEncoder(nn.Module): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - ): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - self.enc_norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tensor: - x = x.long() - x = self.embedding(x) # (batch, seq_len, embsize) - x = self.enc_norm(x) - return x - - -class BatchLabelEncoder(nn.Module): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - ): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - self.enc_norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tensor: - x = self.embedding(x) # (batch, embsize) - x = self.enc_norm(x) - return x - - -class Similarity(nn.Module): - """ - Dot product or cosine similarity - """ - - def __init__(self, temp): - super().__init__() - self.temp = temp - self.cos = nn.CosineSimilarity(dim=-1) - - def forward(self, x, y): - return self.cos(x, y) / self.temp - - -class ExprDecoder(nn.Module): - def __init__( - self, - d_model: int, - explicit_zero_prob: bool = False, - use_batch_labels: bool = False, - ): - super().__init__() - d_in = d_model * 2 if use_batch_labels else d_model - self.fc = nn.Sequential( - nn.Linear(d_in, d_model), - nn.LeakyReLU(), - nn.Linear(d_model, d_model), - nn.LeakyReLU(), - nn.Linear(d_model, 1), - ) - self.explicit_zero_prob = explicit_zero_prob - if explicit_zero_prob: - self.zero_logit = nn.Sequential( - nn.Linear(d_in, d_model), - nn.LeakyReLU(), - nn.Linear(d_model, d_model), - nn.LeakyReLU(), - nn.Linear(d_model, 1), - ) - - def forward(self, x: Tensor) -> Dict[str, Tensor]: - """x is the output of the transformer, (batch, seq_len, d_model)""" - pred_value = self.fc(x).squeeze(-1) # (batch, seq_len) - - if not self.explicit_zero_prob: - return dict(pred=pred_value) - zero_logits = self.zero_logit(x).squeeze(-1) # (batch, seq_len) - zero_probs = torch.sigmoid(zero_logits) - return dict(pred=pred_value, zero_probs=zero_probs) - # TODO: note that the return currently is only for training. Since decoder - # is not used in the test setting for the integration task, the eval/inference - # logic is not implemented yet. However, remember to implement it when - # the decoder is used in any test setting. The inference logic will need - # to sample from the bernoulli distribution with the zero_probs. - - -class ClsDecoder(nn.Module): - """ - Decoder for classification task. 
- """ - - def __init__( - self, - d_model: int, - n_cls: int, - nlayers: int = 3, - activation: callable = nn.ReLU, - ): - super().__init__() - # module list - self._decoder = nn.ModuleList() - for i in range(nlayers - 1): - self._decoder.append(nn.Linear(d_model, d_model)) - self._decoder.append(activation()) - self._decoder.append(nn.LayerNorm(d_model)) - self.out_layer = nn.Linear(d_model, n_cls) - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [batch_size, embsize] - """ - for layer in self._decoder: - x = layer(x) - return self.out_layer(x) - - -class MVCDecoder(nn.Module): - """ - Decoder for the masked value prediction for cell embeddings. - """ - - def __init__( - self, - d_model: int, - arch_style: str = "inner product", - query_activation: nn.Module = nn.Sigmoid, - hidden_activation: nn.Module = nn.PReLU, - explicit_zero_prob: bool = False, - use_batch_labels: bool = False, - ) -> None: - """ - Args: - d_model (:obj:`int`): dimension of the gene embedding. - arch_style (:obj:`str`): architecture style of the decoder, choice from - 1. "inner product" or 2. "concat query" or 3. "sum query". - query_activation (:obj:`nn.Module`): activation function for the query - vectors. - hidden_activation (:obj:`nn.Module`): activation function for the hidden - layers. - """ - super().__init__() - d_in = d_model * 2 if use_batch_labels else d_model - if arch_style in ["inner product", "inner product, detach"]: - self.gene2query = nn.Linear(d_model, d_model) - self.query_activation = query_activation() - self.W = nn.Linear(d_model, d_in, bias=False) - if explicit_zero_prob: # by default, gene-wise prob rate - self.W_zero_logit = nn.Linear(d_model, d_in) - elif arch_style == "concat query": - self.gene2query = nn.Linear(d_model, 64) - self.query_activation = query_activation() - self.fc1 = nn.Linear(d_model + 64, 64) - self.hidden_activation = hidden_activation() - self.fc2 = nn.Linear(64, 1) - elif arch_style == "sum query": - self.gene2query = nn.Linear(d_model, d_model) - self.query_activation = query_activation() - self.fc1 = nn.Linear(d_model, 64) - self.hidden_activation = hidden_activation() - self.fc2 = nn.Linear(64, 1) - else: - raise ValueError(f"Unknown arch_style: {arch_style}") - - self.arch_style = arch_style - self.do_detach = arch_style.endswith("detach") - self.explicit_zero_prob = explicit_zero_prob - - def forward(self, cell_emb: Tensor, gene_embs: Tensor) -> Union[Tensor, Dict[str, Tensor]]: - """ - Args: - cell_emb: Tensor, shape (batch, embsize=d_model) - gene_embs: Tensor, shape (batch, seq_len, embsize=d_model) - """ - gene_embs = gene_embs.detach() if self.do_detach else gene_embs - if self.arch_style in ["inner product", "inner product, detach"]: - query_vecs = self.query_activation(self.gene2query(gene_embs)) - cell_emb = cell_emb.unsqueeze(2) # (batch, embsize, 1) - # the pred gene expr values, # (batch, seq_len) - pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2) - if not self.explicit_zero_prob: - return dict(pred=pred_value) - # zero logits need to based on the cell_emb, because of input exprs - zero_logits = torch.bmm(self.W_zero_logit(query_vecs), cell_emb).squeeze(2) - zero_probs = torch.sigmoid(zero_logits) - return dict(pred=pred_value, zero_probs=zero_probs) - elif self.arch_style == "concat query": - query_vecs = self.query_activation(self.gene2query(gene_embs)) - # expand cell_emb to (batch, seq_len, embsize) - cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1) - - h = 
self.hidden_activation(self.fc1(torch.cat([cell_emb, query_vecs], dim=2))) - if self.explicit_zero_prob: - raise NotImplementedError - return self.fc2(h).squeeze(2) # (batch, seq_len) - elif self.arch_style == "sum query": - query_vecs = self.query_activation(self.gene2query(gene_embs)) - cell_emb = cell_emb.unsqueeze(1) - - h = self.hidden_activation(self.fc1(cell_emb + query_vecs)) - if self.explicit_zero_prob: - raise NotImplementedError - return self.fc2(h).squeeze(2) # (batch, seq_len) - - -class AdversarialDiscriminator(nn.Module): - """ - Discriminator for the adversarial training for batch correction. - """ - - def __init__( - self, - d_model: int, - n_cls: int, - nlayers: int = 3, - activation: callable = nn.LeakyReLU, - reverse_grad: bool = False, - ): - super().__init__() - # module list - self._decoder = nn.ModuleList() - for i in range(nlayers - 1): - self._decoder.append(nn.Linear(d_model, d_model)) - self._decoder.append(activation()) - self._decoder.append(nn.LayerNorm(d_model)) - self.out_layer = nn.Linear(d_model, n_cls) - self.reverse_grad = reverse_grad - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [batch_size, embsize] - """ - if self.reverse_grad: - x = grad_reverse(x, lambd=1.0) - for layer in self._decoder: - x = layer(x) - return self.out_layer(x) diff --git a/src/state/tx/models/scgpt/utils.py b/src/state/tx/models/scgpt/utils.py deleted file mode 100644 index c186f58d..00000000 --- a/src/state/tx/models/scgpt/utils.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Union - -import numpy as np -import torch - - -def map_raw_id_to_vocab_id( - raw_ids: Union[np.ndarray, torch.Tensor], - gene_ids: np.ndarray, -) -> Union[np.ndarray, torch.Tensor]: - """ - Map some raw ids which are indices of the raw gene names to the indices of the - - Args: - raw_ids: the raw ids to map - gene_ids: the gene ids to map to - """ - if isinstance(raw_ids, torch.Tensor): - device = raw_ids.device - dtype = raw_ids.dtype - return_pt = True - raw_ids = raw_ids.cpu().numpy() - - elif isinstance(raw_ids, np.ndarray): - return_pt = False - dtype = raw_ids.dtype - - else: - raise ValueError("raw_ids must be either torch.Tensor or np.ndarray.") - - if raw_ids.ndim != 1: - raise ValueError(f"raw_ids must be 1d, got {raw_ids.ndim}d.") - - if gene_ids.ndim != 1: - raise ValueError(f"gene_ids must be 1d, got {gene_ids.ndim}d.") - - mapped_ids: np.ndarray = gene_ids[raw_ids] - assert mapped_ids.shape == raw_ids.shape - if return_pt: - if isinstance(mapped_ids, np.ndarray): - return torch.from_numpy(mapped_ids).type(dtype).to(device) - return mapped_ids.to(dtype) - return mapped_ids.astype(dtype) diff --git a/src/state/tx/models/scvi/__init__.py b/src/state/tx/models/scvi/__init__.py deleted file mode 100644 index 46582dde..00000000 --- a/src/state/tx/models/scvi/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._model import SCVIPerturbationModel - -__all__ = ["SCVIPerturbationModel"] diff --git a/src/state/tx/models/scvi/_base_modules.py b/src/state/tx/models/scvi/_base_modules.py deleted file mode 100644 index 2d4605af..00000000 --- a/src/state/tx/models/scvi/_base_modules.py +++ /dev/null @@ -1,301 +0,0 @@ -from typing import Literal, Optional - -import torch -import torch.nn as nn -from torch.distributions import Normal -from torch.nn import functional as F - - -class FocalLoss(nn.Module): - """Inspired by https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py - - Focal Loss, as described in https://arxiv.org/abs/1708.02002. 
- It is essentially an enhancement to cross entropy loss and is - useful for classification tasks when there is a large class imbalance. - x is expected to contain raw, unnormalized scores for each class. - y is expected to contain class labels. - Shape: - - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0. - - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0. - """ - - def __init__( - self, - alpha: Optional[torch.Tensor] = None, - gamma: float = 2.0, - reduction: str = "mean", - ): - """ - Args: - alpha (Tensor, optional): Weights for each class. Defaults to None. - gamma (float, optional): A constant, as described in the paper. - Defaults to 0. - reduction (str, optional): 'mean', 'sum' or 'none'. - Defaults to 'mean'. - """ - if reduction not in ("mean", "sum", "none"): - raise ValueError('Reduction must be one of: "mean", "sum", "none".') - - super().__init__() - self.alpha = alpha - self.gamma = gamma - self.reduction = reduction - - self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none") - - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - if len(y_true) == 0: - return torch.tensor(0.0) - - # compute weighted cross entropy term: -alpha * log(pt) - # (alpha is already part of self.nll_loss) - log_p = F.log_softmax(y_pred, dim=-1) - ce = self.nll_loss(log_p, y_true) - - # get true class column from each row - all_rows = torch.arange(len(y_pred)) - log_pt = log_p[all_rows, y_true] - - # compute focal term: (1 - pt)^gamma - pt = log_pt.exp() - focal_term = (1 - pt) ** self.gamma - - # the full loss: -alpha * ((1 - pt)^gamma) * log(pt) - loss = focal_term * ce - - if self.reduction == "mean": - loss = loss.mean() - elif self.reduction == "sum": - loss = loss.sum() - - return loss - - -class MLP(nn.Module): - def __init__( - self, - n_input, - n_output, - n_hidden, - n_layers, - activation_fn: Optional[nn.Module] = nn.ReLU, - use_norm: str = "batch", - dropout_rate: float = 0.3, - drop_norm_last_layer: bool = True, - ): - super().__init__() - if drop_norm_last_layer: - layers = [n_input] + [n_hidden] * n_layers - else: - layers = [n_input] + [n_hidden] * (n_layers - 1) + [n_output] - - network = [] - for n_in, n_out in zip(layers[:-1], layers[1:]): - network.append(nn.Linear(n_in, n_out)) - if use_norm == "batch": - network.append(nn.BatchNorm1d(n_out)) - elif use_norm == "layer": - network.append(nn.LayerNorm(n_out)) - network.append(activation_fn()) - network.append(nn.Dropout(dropout_rate)) - - if drop_norm_last_layer: - network.append(nn.Linear(n_hidden, n_output)) - - self.network = nn.Sequential(*network) - - def forward(self, x): - """ - x: (batch_size, n_input) - """ - return self.network(x) - - -class Classifier(nn.Module): - def __init__( - self, - n_input, - n_labels, - n_hidden, - n_layers, - activation_fn=nn.ReLU, - use_norm: str = "batch", - dropout_rate: float = 0.3, - ): - super().__init__() - self.n_output = n_labels - - self.network = MLP( - n_input=n_input, - n_output=n_labels, - n_layers=n_layers, - n_hidden=n_hidden, - use_norm=use_norm, - dropout_rate=dropout_rate, - activation_fn=activation_fn, - drop_norm_last_layer=True, - ) - - def forward(self, x): - y = self.network(x) - return y - - -class VariationalEncoder(nn.Module): - def __init__( - self, - n_input: int, - n_output: int, - n_layers: int = 1, - n_hidden: int = 128, - dropout_rate: float = 0.1, - use_norm: str = "batch", - var_eps: float = 1e-4, - var_activation=None, - return_dist: bool = False, - **kwargs, - ): - super().__init__() - - self.var_eps = 
var_eps - self.encoder = MLP( - n_input=n_input, - n_output=n_hidden, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - use_norm=use_norm, - drop_norm_last_layer=False, - ) - self.mean_encoder = nn.Linear(n_hidden, n_output) - self.var_encoder = nn.Linear(n_hidden, n_output) - self.return_dist = return_dist - - self.var_activation = torch.exp if var_activation is None else var_activation - - def forward(self, x: torch.Tensor, *cat_list: int): - """ """ - q = self.encoder(x, *cat_list) - - q_m = self.mean_encoder(q) - q_v = self.var_activation(self.var_encoder(q)) + self.var_eps - - dist = Normal(q_m, q_v.sqrt()) - latent = dist.rsample() - - if self.return_dist: - return dist, latent - - return q_m, q_v, latent - - -# Inspired by scvi-tools source code: https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/nn/_base_components.py -class CountDecoder(nn.Module): - """Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions. - - Uses a fully-connected neural network of ``n_hidden`` layers. - - Parameters - ---------- - n_input - The dimensionality of the input (latent space) - n_output - The dimensionality of the output (data space) - n_cat_list - A list containing the number of categories - for each category of interest. Each category will be - included using a one-hot encoding - n_layers - The number of fully-connected hidden layers - n_hidden - The number of nodes per hidden layer - dropout_rate - Dropout rate to apply to each of the hidden layers - inject_covariates - Whether to inject covariates in each layer, or just the first (default). - use_batch_norm - Whether to use batch norm in layers - use_layer_norm - Whether to use layer norm in layers - scale_activation - Activation layer to use for px_scale_decoder - """ - - def __init__( - self, - n_input: int, - n_output: int, - n_layers: int = 1, - n_hidden: int = 128, - use_norm: Literal["batch", "layer"] = "batch", - scale_activation: Literal["softmax", "softplus"] = "softmax", - ): - super().__init__() - self.px_decoder = MLP( - n_input=n_input, - n_output=n_hidden, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=0.0, - use_norm=use_norm, - drop_norm_last_layer=False, - ) - - # mean gamma - if scale_activation == "softmax": - px_scale_activation = nn.Softmax(dim=-1) - elif scale_activation == "softplus": - px_scale_activation = nn.Softplus() - self.px_scale_decoder = nn.Sequential( - nn.Linear(n_hidden, n_output), - px_scale_activation, - ) - - # dispersion: here we only deal with gene-cell dispersion case - self.px_r_decoder = nn.Linear(n_hidden, n_output) - - # dropout - self.px_dropout_decoder = nn.Linear(n_hidden, n_output) - - def forward( - self, - dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"], - z: torch.Tensor, - library: torch.Tensor, - ): - """The forward computation for a single sample. - - #. Decodes the data from the latent space using the decoder network - #. Returns parameters for the ZINB distribution of expression - #. 
If ``dispersion != 'gene-cell'`` then value for that param will be ``None`` - - Parameters - ---------- - z : - tensor with shape ``(n_input,)`` - library_size - library size - cat_list - list of category membership(s) for this sample - dispersion - One of the following - - * ``'gene'`` - dispersion parameter of NB is constant per gene across cells - * ``'gene-batch'`` - dispersion can differ between different batches - * ``'gene-label'`` - dispersion can differ between different labels - * ``'gene-cell'`` - dispersion can differ for every gene in every cell - - Returns - ------- - 4-tuple of :py:class:`torch.Tensor` - parameters for the ZINB distribution of expression - - """ - # The decoder returns values for the parameters of the ZINB distribution - px = self.px_decoder(z) - px_scale = self.px_scale_decoder(px) - px_dropout = self.px_dropout_decoder(px) - # Clamp to high value: exp(12) ~ 160000 to avoid nans (computational stability) - px_rate = torch.exp(library) * px_scale # torch.clamp( , max=12) - px_r = self.px_r_decoder(px) if dispersion == "gene-cell" else None - return px_scale, px_r, px_rate, px_dropout diff --git a/src/state/tx/models/scvi/_callbacks.py b/src/state/tx/models/scvi/_callbacks.py deleted file mode 100644 index befb433e..00000000 --- a/src/state/tx/models/scvi/_callbacks.py +++ /dev/null @@ -1,29 +0,0 @@ -from lightning.pytorch.callbacks import Callback - - -class CPABestModelTracker(Callback): - def __init__(self, monitor: str = "val_loss", mode: str = "min"): - super().__init__() - self.monitor = monitor - self.mode = mode - self.best_model = None - self.best_score = None - - def on_validation_end(self, trainer, pl_module): - if self.best_score is None: - self.best_score = trainer.callback_metrics[self.monitor] - self.best_model = pl_module.state_dict() - else: - if self.mode == "min": - if trainer.callback_metrics[self.monitor] < self.best_score: - self.best_score = trainer.callback_metrics[self.monitor] - self.best_model = pl_module.state_dict() - else: - if trainer.callback_metrics[self.monitor] > self.best_score: - self.best_score = trainer.callback_metrics[self.monitor] - self.best_model = pl_module.state_dict() - - def on_train_end(self, trainer, pl_module): - pl_module.load_state_dict(self.best_model) - print(f"Best model loaded with {self.monitor} = {self.best_score}") - return self.best_model diff --git a/src/state/tx/models/scvi/_dists.py b/src/state/tx/models/scvi/_dists.py deleted file mode 100644 index 18423b16..00000000 --- a/src/state/tx/models/scvi/_dists.py +++ /dev/null @@ -1,387 +0,0 @@ -import warnings - -import torch -import torch.nn.functional as F -from torch.distributions import Distribution, Gamma, constraints -from torch.distributions import Poisson as PoissonTorch -from torch.distributions.constraints import Constraint -from torch.distributions.utils import ( - broadcast_all, - lazy_property, - logits_to_probs, - probs_to_logits, -) - - -class _Optional(Constraint): - def __init__(self, constraint: Constraint): - self.constraint = constraint - - def check(self, value: torch.Tensor) -> torch.Tensor: - if value is None: - return torch.ones(1, dtype=torch.bool) - return self.constraint.check(value) - - def __repr__(self) -> str: - return f"Optional({self.constraint})" - - -def optional_constraint(constraint: Constraint) -> Constraint: - """Returns a wrapped constraint that allows optional values.""" - return _Optional(constraint) - - -def log_zinb_positive( - x: torch.Tensor, - mu: torch.Tensor, - theta: torch.Tensor, - pi: torch.Tensor, 
- eps: float = 1e-8, -) -> torch.Tensor: - """Log likelihood (scalar) of a minibatch according to a zinb model. - - Parameters - ---------- - x - Data - mu - mean of the negative binomial (has to be positive support) (shape: minibatch x vars) - theta - inverse dispersion parameter (has to be positive support) (shape: minibatch x vars) - pi - logit of the dropout parameter (real support) (shape: minibatch x vars) - eps - numerical stability constant - - Notes - ----- - We parametrize the bernoulli using the logits, hence the softplus functions appearing. - """ - # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless - # of batch or labels) - if theta.ndimension() == 1: - theta = theta.view(1, theta.size(0)) # In this case, we reshape theta for broadcasting - - # Uses log(sigmoid(x)) = -softplus(-x) - softplus_pi = F.softplus(-pi) - log_theta_eps = torch.log(theta + eps) - log_theta_mu_eps = torch.log(theta + mu + eps) - pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps) - - case_zero = F.softplus(pi_theta_log) - softplus_pi - mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero) - - case_non_zero = ( - -softplus_pi - + pi_theta_log - + x * (torch.log(mu + eps) - log_theta_mu_eps) - + torch.lgamma(x + theta) - - torch.lgamma(theta) - - torch.lgamma(x + 1) - ) - mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero) - - res = mul_case_zero + mul_case_non_zero - - return res - - -def log_nb_positive( - x: torch.Tensor, - mu: torch.Tensor, - theta: torch.Tensor, - eps: float = 1e-8, - log_fn: callable = torch.log, - lgamma_fn: callable = torch.lgamma, -) -> torch.Tensor: - """Log likelihood (scalar) of a minibatch according to a nb model. - - Parameters - ---------- - x - data - mu - mean of the negative binomial (has to be positive support) (shape: minibatch x vars) - theta - inverse dispersion parameter (has to be positive support) (shape: minibatch x vars) - eps - numerical stability constant - log_fn - log function - lgamma_fn - log gamma function - """ - log = log_fn - lgamma = lgamma_fn - log_theta_mu_eps = log(theta + mu + eps) - res = ( - theta * (log(theta + eps) - log_theta_mu_eps) - + x * (log(mu + eps) - log_theta_mu_eps) - + lgamma(x + theta) - - lgamma(theta) - - lgamma(x + 1) - ) - - return res - - -def _convert_counts_logits_to_mean_disp( - total_count: torch.Tensor, logits: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """NB parameterizations conversion. - - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - logits - success logits. - - Returns - ------- - type - the mean and inverse overdispersion of the NB distribution. - - """ - theta = total_count - mu = logits.exp() * theta - return mu, theta - - -def _gamma(theta: torch.Tensor, mu: torch.Tensor) -> Gamma: - concentration = theta - rate = theta / mu - # Important remark: Gamma is parametrized by the rate = 1/scale! - gamma_d = Gamma(concentration=concentration, rate=rate) - return gamma_d - - -# Inspired from scvi-tools source code:https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/distributions/_negative_binomial.py -class NegativeBinomial(Distribution): - r"""Negative binomial distribution. - - One of the following parameterizations must be provided: - - (1), (`total_count`, `probs`) where `total_count` is the number of failures until - the experiment is stopped and `probs` the success probability. 
(2), (`mu`, `theta`) - parameterization, which is the one used by scvi-tools. These parameters respectively - control the mean and inverse dispersion of the distribution. - - In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as - follows: - - 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, - \underbrace{\theta/\mu}_{\text{rate}})` - 2. :math:`x \sim \textrm{Poisson}(w)` - - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - probs - The success probability. - mu - Mean of the distribution. - theta - Inverse dispersion. - scale - Normalized mean expression of the distribution. - validate_args - Raise ValueError if arguments do not match constraints - """ - - arg_constraints = { - "mu": optional_constraint(constraints.greater_than_eq(0)), - "theta": optional_constraint(constraints.greater_than_eq(0)), - "scale": optional_constraint(constraints.greater_than_eq(0)), - } - support = constraints.nonnegative_integer - - def __init__( - self, - total_count: torch.Tensor | None = None, - probs: torch.Tensor | None = None, - logits: torch.Tensor | None = None, - mu: torch.Tensor | None = None, - theta: torch.Tensor | None = None, - scale: torch.Tensor | None = None, - validate_args: bool = False, - ): - self._eps = 1e-8 - if (mu is None) == (total_count is None): - raise ValueError( - "Please use one of the two possible parameterizations. Refer to the documentation for more information." - ) - - using_param_1 = total_count is not None and (logits is not None or probs is not None) - if using_param_1: - logits = logits if logits is not None else probs_to_logits(probs) - total_count = total_count.type_as(logits) - total_count, logits = broadcast_all(total_count, logits) - mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits) - else: - mu, theta = broadcast_all(mu, theta) - self.mu = mu - self.theta = theta - self.scale = scale - super().__init__(validate_args=validate_args) - - @property - def mean(self) -> torch.Tensor: - return self.mu - - @property - def variance(self) -> torch.Tensor: - return self.mean + (self.mean**2) / self.theta - - @torch.inference_mode() - def sample( - self, - sample_shape: torch.Size | tuple | None = None, - ) -> torch.Tensor: - """Sample from the distribution.""" - sample_shape = sample_shape or torch.Size() - gamma_d = self._gamma() - p_means = gamma_d.sample(sample_shape) - - # Clamping as distributions objects can have buggy behaviors when - # their parameters are too high - l_train = torch.clamp(p_means, max=1e8) - counts = PoissonTorch(l_train).sample() # Shape : (n_samples, n_cells_batch, n_vars) - return counts - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - if self._validate_args: - try: - self._validate_sample(value) - except ValueError: - warnings.warn( - "The value argument must be within the support of the distribution", - UserWarning, - stacklevel=1, - ) - - return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps) - - def _gamma(self) -> Gamma: - return _gamma(self.theta, self.mu) - - def __repr__(self) -> str: - param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] - args_string = ", ".join( - [ - f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}" - for p in param_names - if self.__dict__[p] is not None - ] - ) - return self.__class__.__name__ + "(" + args_string + ")" - - -# Inspired from scvi-tools source 
code:https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/distributions/_negative_binomial.py -class ZeroInflatedNegativeBinomial(NegativeBinomial): - r"""Zero-inflated negative binomial distribution. - - One of the following parameterizations must be provided: - - (1), (`total_count`, `probs`) where `total_count` is the number of failures until - the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`) - parameterization, which is the one used by scvi-tools. These parameters respectively - control the mean and inverse dispersion of the distribution. - - In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as - follows: - - 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, - \underbrace{\theta/\mu}_{\text{rate}})` - 2. :math:`x \sim \textrm{Poisson}(w)` - - Parameters - ---------- - total_count - Number of failures until the experiment is stopped. - probs - The success probability. - mu - Mean of the distribution. - theta - Inverse dispersion. - zi_logits - Logits scale of zero inflation probability. - scale - Normalized mean expression of the distribution. - validate_args - Raise ValueError if arguments do not match constraints - """ - - arg_constraints = { - "mu": optional_constraint(constraints.greater_than_eq(0)), - "theta": optional_constraint(constraints.greater_than_eq(0)), - "zi_logits": optional_constraint(constraints.real), - "scale": optional_constraint(constraints.greater_than_eq(0)), - } - support = constraints.nonnegative_integer - - def __init__( - self, - total_count: torch.Tensor | None = None, - probs: torch.Tensor | None = None, - logits: torch.Tensor | None = None, - mu: torch.Tensor | None = None, - theta: torch.Tensor | None = None, - zi_logits: torch.Tensor | None = None, - scale: torch.Tensor | None = None, - validate_args: bool = False, - ): - super().__init__( - total_count=total_count, - probs=probs, - logits=logits, - mu=mu, - theta=theta, - scale=scale, - validate_args=validate_args, - ) - self.zi_logits, self.mu, self.theta = broadcast_all(zi_logits, self.mu, self.theta) - - @property - def mean(self) -> torch.Tensor: - pi = self.zi_probs - return (1 - pi) * self.mu - - @property - def variance(self) -> None: - raise NotImplementedError - - @lazy_property - def zi_logits(self) -> torch.Tensor: - """ZI logits.""" - return probs_to_logits(self.zi_probs, is_binary=True) - - @lazy_property - def zi_probs(self) -> torch.Tensor: - return logits_to_probs(self.zi_logits, is_binary=True) - - @torch.inference_mode() - def sample( - self, - sample_shape: torch.Size | tuple | None = None, - ) -> torch.Tensor: - """Sample from the distribution.""" - sample_shape = sample_shape or torch.Size() - samp = super().sample(sample_shape=sample_shape) - is_zero = torch.rand_like(samp) <= self.zi_probs - samp_ = torch.where(is_zero, torch.zeros_like(samp), samp) - return samp_ - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - """Log probability.""" - try: - self._validate_sample(value) - except ValueError: - warnings.warn( - "The value argument must be within the support of the distribution", - UserWarning, - stacklevel=1, - ) - return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08) diff --git a/src/state/tx/models/scvi/_model.py b/src/state/tx/models/scvi/_model.py deleted file mode 100644 index 2760d74c..00000000 --- a/src/state/tx/models/scvi/_model.py +++ /dev/null @@ -1,380 +0,0 @@ -from typing import Dict, Tuple - -import 
torch -from torch.optim.lr_scheduler import StepLR - -from ..base import PerturbationModel -from ._module import scVIModule - - -class SCVIPerturbationModel(PerturbationModel): - """ - Implementation of the scVI model. The outputs are always in - gene expression space. - - Args: - input_dim: Dimension of input embeddings (either number of genes or latent dim from obsm key) - hidden_dim: Dimension of hidden layers - output_dim: Number of genes to predict - pert_dim: Dimension of perturbation inputs (usually one-hot size) - decode_intermediate_dim: Optional intermediate dimension for decoder - n_encoder_layers: Number of layers in encoder (default: 2) - n_decoder_layers: Number of layers in encoder (default: 2) - dropout: Dropout rate (default: 0.1) - learning_rate: Learning rate for optimizer (default: 1e-3) - loss_fn: Loss function (default: 'nn.MSELoss()') - """ - - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - pert_dim: int, - n_cell_types: int, - n_perts: int, - n_batches: int, - output_space: str = "gene", - lr=5e-4, - wd=1e-6, - n_steps_kl_warmup: int = None, - n_epochs_kl_warmup: int = None, - step_size_lr: int = 45, - do_clip_grad: bool = False, - gradient_clip_value: float = 3.0, - check_val_every_n_epoch: int = 5, - **kwargs, - ): - # Register with parent constructor - super().__init__( - input_dim=input_dim, - hidden_dim=hidden_dim, - output_dim=output_dim, - pert_dim=pert_dim, - output_space=output_space, - **kwargs, - ) - - # Set class specific parameters before registering with parent constructor - self.n_cell_types = n_cell_types - self.n_perts = n_perts - self.n_batches = n_batches - - self.n_layers_encoder = kwargs.get("n_layers_encoder", 2) - self.n_layers_decoder = kwargs.get("n_layers_decoder", 2) - self.n_hidden_encoder = kwargs.get("n_hidden_encoder", 256) - self.n_hidden_decoder = kwargs.get("n_hidden_decoder", 256) - self.n_latent = kwargs.get("n_latent", 64) - self.recon_loss = kwargs.get("recon_loss", "nb") - - self.use_batch_norm = kwargs.get("use_batch_norm", "both") - self.use_layer_norm = kwargs.get("use_layer_norm", "none") - - self.pert_embeddings = None # will be set in _build_networks - - self.dropout_rate_encoder = kwargs.get("dropout_rate_encoder", 0.0) - self.dropout_rate_decoder = kwargs.get("dropout_rate_decoder", 0.0) - self.seed = kwargs.get("seed", 0) - - # training params - self.lr = lr - self.wd = wd - self.n_steps_kl_warmup = n_steps_kl_warmup - self.n_epochs_kl_warmup = n_epochs_kl_warmup - - self.step_size_lr = step_size_lr - self.do_clip_grad = do_clip_grad - self.gradient_clip_value = gradient_clip_value - self.check_val_every_n_epoch = check_val_every_n_epoch - - self.kwargs = kwargs - - assert self.output_space in ["gene", "all"], "scVI model only supports gene-level or all-level output" - - # Build model components - self._build_networks() - - def _build_networks(self): - """ - Build the core components: - """ - self.module = scVIModule( - n_genes=self.input_dim, - n_perts=self.n_perts, - n_cell_types=self.n_cell_types, - n_batches=self.n_batches, - pert_embeddings=self.pert_embeddings, - n_latent=self.n_latent, - recon_loss=self.recon_loss, - n_hidden_encoder=self.n_hidden_encoder, - n_layers_encoder=self.n_layers_encoder, - n_hidden_decoder=self.n_hidden_decoder, - n_layers_decoder=self.n_layers_decoder, - use_batch_norm=self.use_batch_norm, - use_layer_norm=self.use_layer_norm, - dropout_rate_encoder=self.dropout_rate_encoder, - dropout_rate_decoder=self.dropout_rate_decoder, - seed=self.seed, - ) - - 
def encode_perturbation(self, pert: torch.Tensor) -> torch.Tensor: - """Map perturbation to an effect vector in embedding space.""" - raise NotImplementedError("Perturbation encoding not supported for scVI model") - - def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: - """Expression is already in embedding space, pass through.""" - raise NotImplementedError("Basal expression encoding not supported for scVI model") - - def perturb(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor: - """ - Given a perturbation and basal embeddings, compute the perturbed embedding. - """ - # Project perturbation and basal cell state to latent space - raise NotImplementedError("Perturbation not supported for scVI model") - - @property - def kl_weight(self): - slope = 1.0 - if self.n_steps_kl_warmup is not None: - global_step = self.global_step - - if global_step <= self.n_steps_kl_warmup: - proportion = global_step / self.n_steps_kl_warmup - return slope * proportion - else: - return slope - elif self.n_epochs_kl_warmup is not None: - current_epoch = self.current_epoch - - if current_epoch <= self.n_epochs_kl_warmup: - proportion = current_epoch / self.n_epochs_kl_warmup - return slope * proportion - else: - return slope - else: - return slope - - def extract_batch_tensors( - self, batch: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x_pert = batch["pert_cell_emb"] - x_basal = batch["ctrl_cell_emb"] - pert = batch["pert_emb"] - cell_type = batch["cell_type_onehot"] - batch_ids = batch["batch"] - - # if pert is one-hot, convert to index - if pert.dim() == 2 and pert.size(1) == self.n_perts: - pert = pert.argmax(1) - - if cell_type.dim() == 2 and cell_type.size(1) == self.n_cell_types: - cell_type = cell_type.argmax(1) - - if batch_ids.dim() == 2 and batch_ids.size(1) == self.n_batches: - batch_ids = batch_ids.argmax(1) - - return x_pert, x_basal, pert, cell_type, batch_ids - - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - """ - Given - - Args: - batch: Dictionary containing: - - pert: Perturbation one-hot - - basal: Control expression embedding - - cell_type: Cell type one-hot - - batch: Batch one-hot - """ - x_pert, x_basal, pert, cell_type, batch_ids = self.extract_batch_tensors(batch) - - encoder_outputs, decoder_outputs = self.module.forward(x_basal, pert, cell_type, batch_ids) - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - output = getattr(decoder_outputs["px"], output_key) - - return output - - def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Training step logic.""" - x_pert, x_basal, pert, cell_type, batch_ids = self.extract_batch_tensors(batch) - - encoder_outputs, decoder_outputs = self.module.forward(x_basal, pert, cell_type, batch_ids) - - recon_loss, kl_loss = self.module.loss( - x_pert=x_pert, - encoder_outputs=encoder_outputs, - decoder_outputs=decoder_outputs, - ) - - loss = recon_loss + self.kl_weight * kl_loss - - r2_mean, r2_lfc = self.module.r2_metric( - x_pert=x_pert, - x_basal=x_basal, - encoder_outputs=encoder_outputs, - decoder_outputs=decoder_outputs, - ) - - self.log("KL_weight", self.kl_weight, on_epoch=True, on_step=False, prog_bar=True, logger=True) - - self.log( - "recon_loss", - recon_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "r2_mean", - r2_mean, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - 
"r2_lfc", - r2_lfc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - if self.global_step % self.step_size_lr * 1000 == 0: - sch = self.lr_schedulers() - sch.step() - - return loss - - def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - """Validation step logic.""" - x_pert, x_basal, pert, cell_type, batch_ids = self.extract_batch_tensors(batch) - - encoder_outputs, decoder_outputs = self.module.forward(x_basal, pert, cell_type, batch_ids) - - recon_loss, kl_loss = self.module.loss( - x_pert=x_pert, - encoder_outputs=encoder_outputs, - decoder_outputs=decoder_outputs, - ) - - r2_mean, r2_lfc = self.module.r2_metric( - x_pert=x_pert, - x_basal=x_basal, - encoder_outputs=encoder_outputs, - decoder_outputs=decoder_outputs, - ) - - self.log("val_loss", recon_loss + self.kl_weight * kl_loss, prog_bar=True) - self.log("val_r2_mean", r2_mean, prog_bar=True) - self.log("val_r2_lfc", r2_lfc, prog_bar=True) - - # is_control = "DMSO_TF" == batch["pert_name"][0] or "non-targeting" == batch["pert_name"][0] - # if np.random.rand() < 0.1 or is_control: - # if self.recon_loss == "gauss": - # output_key = "loc" - # else: - # output_key = "mu" - - # x_pred = getattr(decoder_outputs["px"], output_key) - - # self._update_val_cache(batch, x_pred) - - def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - x_pert, x_basal, pert, cell_type, batch_ids = self.extract_batch_tensors(batch) - - encoder_outputs, decoder_outputs = self.module.forward(x_basal, pert, cell_type, batch_ids) - - recon_loss, kl_loss = self.module.loss( - x_pert=x_pert, - encoder_outputs=encoder_outputs, - decoder_outputs=decoder_outputs, - ) - - loss = recon_loss + self.kl_weight * kl_loss - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - x_pred = getattr(decoder_outputs["px"], output_key) - - self.log("test_loss", loss, prog_bar=True) - # self._update_test_cache(batch, x_pred) - - return x_pred - - def predict_step(self, batch, batch_idx, **kwargs): - """ - Typically used for final inference. We'll replicate old logic: - returning 'preds', 'X', 'pert_name', etc. 
- """ - x_pert, x_basal, pert, cell_type, batch_ids = self.extract_batch_tensors(batch) - - encoder_outputs, decoder_outputs = self.module.forward(x_basal, pert, cell_type, batch_ids) - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - x_pred = getattr(decoder_outputs["px"], output_key) - - return { - "preds": torch.nan_to_num(torch.log(x_pred + 1), nan=0.0, posinf=1e4, neginf=0.0), - "pert_cell_counts_preds": torch.nan_to_num(torch.log(x_pred + 1), nan=0.0, posinf=1e4, neginf=0.0), - "pert_cell_emb": batch.get("pert_cell_emb", None), - "pert_cell_counts": batch.get("pert_cell_counts", None), - "pert_emb": batch.get("pert_emb", None), - "pert_name": batch.get("pert_name", None), - "cell_type": batch.get("cell_type", None), - "cell_type_name": batch.get("cell_type_name", None), - "batch": batch.get("batch", None), - "batch_name": batch.get("batch_name", None), - "ctrl_cell_emb": batch.get("ctrl_cell_emb", None), - } - - def configure_optimizers(self): - """Set up optimizer.""" - ae_params = ( - list(filter(lambda p: p.requires_grad, self.module.encoder.parameters())) - + list(filter(lambda p: p.requires_grad, self.module.decoder.parameters())) - + list( - filter( - lambda p: p.requires_grad, - self.module.pert_embeddings.parameters(), - ) - ) - + list( - filter( - lambda p: p.requires_grad, - self.module.cell_type_embeddings.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, self.module.batch_embeddings.parameters())) - ) - - if self.module.recon_loss in ["zinb", "nb"]: - ae_params += [self.module.px_r] - - optimizer_autoencoder = torch.optim.Adam(ae_params, lr=self.lr, weight_decay=self.wd) - - scheduler_autoencoder = StepLR(optimizer_autoencoder, step_size=self.step_size_lr, gamma=0.9) - - optimizers = [optimizer_autoencoder] - schedulers = [scheduler_autoencoder] - - if self.step_size_lr is not None: - return optimizers, schedulers - else: - return optimizers diff --git a/src/state/tx/models/scvi/_module.py b/src/state/tx/models/scvi/_module.py deleted file mode 100644 index f82b8827..00000000 --- a/src/state/tx/models/scvi/_module.py +++ /dev/null @@ -1,342 +0,0 @@ -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -from torch.distributions import Normal -from torch.distributions.kl import kl_divergence as kl -from torch_scatter import scatter_mean -from torchmetrics.functional import pairwise_euclidean_distance, pearson_corrcoef, r2_score -from torchmetrics.functional.clustering import normalized_mutual_info_score - -from ._base_modules import CountDecoder, VariationalEncoder -from ._dists import NegativeBinomial, ZeroInflatedNegativeBinomial - - -def knn_purity(data, labels, n_neighbors=15): - """Computes KNN Purity for ``data`` given the labels. - Parameters - ---------- - data: - torch tensor of data (n_samples, n_features) - labels - torch tensor of labels (n_samples,) - n_neighbors: int - Number of nearest neighbors. - Returns - ------- - score: float - KNN purity score. A float between 0 and 1. 
- """ - distances = pairwise_euclidean_distance(data) - # sort each row in distances to get nearest neighbors - - _, indices = torch.topk(distances, k=n_neighbors + 1, dim=1, largest=False, sorted=True) - indices = indices[:, 1:] # remove self - # neighbors_labels = np.vectorize(lambda i: labels[i])(indices) - neighbors_labels = labels[indices] # (n_samples, n_neighbors) - - # pre cell purity scores - scores = ((neighbors_labels - labels.reshape(-1, 1)) == 0).float().mean(axis=1) # (n_samples,) - res = scatter_mean(scores, labels).mean() # per category purity - - return res - - -class scVIModule(nn.Module): - """ - scVI module using NegativeBinomial/Zero-Inflated NegativeBinomial Likelihood objectives - - Parameters - ---------- - n_genes: int - n_treatments: int - covars_encoder: dict - Dictionary of covariates with keys as each covariate name and values as - number of unique values of the corresponding covariate - n_latent: int - Latent Dimension - loss_ae: str - Autoencoder loss (either "gauss" or "nb") - doser_type: str - # Type of doser network, either `mlp` or `linear`. - autoencoder_width: int - autoencoder_depth: int - use_batch_norm: bool - use_layer_norm: bool - variational: bool - """ - - def __init__( - self, - n_genes: int, - n_perts: int, - n_cell_types: int, - n_batches: int = 1, - pert_embeddings: Optional[np.ndarray] = None, - n_latent: int = 128, - n_pert_latent: int = 64, - n_cell_type_latent: int = 32, - n_batch_latent: int = 16, - recon_loss: str = "nb", - n_hidden_encoder: int = 256, - n_layers_encoder: int = 3, - n_hidden_decoder: int = 256, - n_layers_decoder: int = 3, - use_batch_norm: str = "both", - use_layer_norm: str = "none", - dropout_rate_encoder: float = 0.0, - dropout_rate_decoder: float = 0.0, - seed: int = 0, - ): - super().__init__() - - torch.manual_seed(seed) - np.random.seed(seed) - - recon_loss = recon_loss.lower() - - assert recon_loss in ["nb", "zinb"] - - self.n_genes = n_genes - self.n_latent = n_latent - self.n_cell_type_latent = n_cell_type_latent - self.n_batch_latent = n_batch_latent - self.n_pert_latent = n_pert_latent - self.n_perts = n_perts - self.recon_loss = recon_loss - - self.encoder = VariationalEncoder( - n_genes, - n_latent, - var_activation=nn.Softplus(), - n_hidden=n_hidden_encoder, - n_layers=n_layers_encoder, - use_batch_norm=use_batch_norm in ["both", "encoder"], - use_layer_norm=use_layer_norm in ["both", "encoder"], - dropout_rate=dropout_rate_encoder, - activation_fn=nn.ReLU, - return_dist=True, - ) - - n_input_decoder = n_latent + n_cell_type_latent + n_batch_latent + n_pert_latent - - # Decoder components - if self.recon_loss in ["zinb", "nb"]: - # setup the parameters of your generative model, as well as your inference model - self.px_r = torch.nn.Parameter(torch.randn(self.n_genes)) - - # decoder goes from n_latent-dimensional space to n_input-d data - self.decoder = CountDecoder( - n_input=n_input_decoder, - n_output=n_genes, - n_layers=n_layers_decoder, - n_hidden=n_hidden_decoder, - use_norm=( - "batch" - if use_batch_norm in ["both", "decoder"] - else "layer" - if use_layer_norm in ["both", "decoder"] - else "none" - ), - ) - - elif recon_loss == "gauss": - self.decoder = VariationalEncoder( - n_input=n_input_decoder, - n_output=n_genes, - n_layers=n_layers_decoder, - n_hidden=n_hidden_decoder, - dropout_rate=dropout_rate_decoder, - use_norm=( - "batch" - if use_batch_norm in ["both", "decoder"] - else "layer" - if use_layer_norm in ["both", "decoder"] - else "none" - ), - var_activation=nn.Softplus(), - ) - - 
else: - raise Exception("Invalid Loss function for Autoencoder") - - # Embeddings - if pert_embeddings is not None: - self.pert_embeddings = nn.Embedding.from_pretrained(torch.FloatTensor(pert_embeddings), freeze=True) - self.n_pert_latent = pert_embeddings.shape[1] - else: - self.pert_embeddings = nn.Embedding(n_perts, n_pert_latent) - - if n_batches > 1: - self.batch_embeddings = nn.Embedding(n_batches, n_batch_latent) - - self.cell_type_embeddings = nn.Embedding(n_cell_types, n_cell_type_latent) - - self.metrics = { - "pearson_r": pearson_corrcoef, - "r2_score": r2_score, - "nmi": normalized_mutual_info_score, - } - - def forward(self, x_basal, perts, cell_types, batch_ids, n_samples: int = 1): - enc_outputs = self.forward_encoder( - x=x_basal, - perts=perts, - cell_types=cell_types, - batch_ids=batch_ids, - n_samples=n_samples, - ) - - dec_outputs = self.forward_decoder( - z_basal=enc_outputs["z_basal"], - z=enc_outputs["z"], - library=enc_outputs["library"], - ) - - return enc_outputs, dec_outputs - - def forward_encoder( - self, - x, - perts, - cell_types, - batch_ids: Optional[torch.Tensor] = None, - n_samples: int = 1, - ): - ## TODO: remove unused - # batch_size = x.shape[0] - - if self.recon_loss in ["nb", "zinb"]: - # log the input to the variational distribution for numerical stability - x_ = torch.log(1 + x) - library = torch.log(x.sum(1)).unsqueeze(1) - else: - x_ = x - library = None, None - - qz, z_basal = self.encoder(x_) - - if n_samples > 1: - sampled_z = qz.sample((n_samples,)) - z_basal = self.encoder.z_transformation(sampled_z) - if self.recon_loss in ["nb", "zinb"]: - library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1))) - - z_covs = self.cell_type_embeddings(cell_types.long()) - z_batch = self.batch_embeddings(batch_ids.long()) - z_pert = self.pert_embeddings(perts.long()) - - z = torch.cat([z_basal, z_covs, z_batch, z_pert], dim=-1) - z_corrected = torch.cat([z_basal, z_covs, z_pert], dim=-1) - - return dict( - z=z, - z_basal=z_basal, - z_corrected=z_corrected, - library=library, - qz=qz, - ) - - def forward_decoder( - self, - z_basal, - z, - library, - ): - if self.recon_loss == "nb": - px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) - px_r = torch.exp(self.px_r) - - px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) - - elif self.recon_loss == "zinb": - px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) - px_r = torch.exp(self.px_r) - - px = ZeroInflatedNegativeBinomial( - mu=px_rate, - theta=px_r, - zi_logits=px_dropout, - scale=px_scale, - ) - - else: - px_mean, px_var, x_pred = self.decoder(z) - - px = Normal(loc=px_mean, scale=px_var.sqrt()) - - pz = Normal(torch.zeros_like(z_basal), torch.ones_like(z_basal)) - return dict(px=px, pz=pz) - - def loss(self, x_pert, encoder_outputs, decoder_outputs): - """Computes the reconstruction loss (AE) or the ELBO (VAE)""" - px = decoder_outputs["px"] - recon_loss = -px.log_prob(x_pert).sum(dim=-1).mean() - - qz = encoder_outputs["qz"] - pz = decoder_outputs["pz"] - - kl_divergence_z = kl(qz, pz).sum(dim=1) - kl_loss = kl_divergence_z.mean() - - return recon_loss, kl_loss - - def r2_metric(self, x_pert, x_basal, encoder_outputs, decoder_outputs): - px = decoder_outputs["px"] - if self.recon_loss == "gauss": - x_pred_mean = px.loc - - x_pred_mean = torch.nan_to_num(x_pred_mean, nan=0, posinf=1e3, neginf=-1e3) - - r2_mean = torch.nan_to_num(self.metrics["r2_score"](x_pred_mean.mean(0), x_pert.mean(0)), nan=0.0).item() - - lfc_true = (x_pert - 
x_basal).mean(0) - lfc_pred = (x_pred_mean - x_basal).mean(0) - - r2_lfc = torch.nan_to_num(self.metrics["pearson_r"](lfc_pred, lfc_true), nan=0.0).item() - - elif self.recon_loss in ["nb", "zinb"]: - x_pert = torch.log(1 + x_pert) - x_pred = px.mu - x_pred = torch.log(1 + x_pred) - x_basal = torch.log(1 + x_basal) - - x_pred = torch.nan_to_num(x_pred, nan=0, posinf=1e3, neginf=-1e3) - - r2_mean = torch.nan_to_num(self.metrics["r2_score"](x_pred.mean(0), x_pert.mean(0)), nan=0.0).item() - - lfc_true = (x_pert - x_basal).mean(0) - lfc_pred = (x_pred - x_basal).mean(0) - - r2_lfc = torch.nan_to_num(self.metrics["pearson_r"](lfc_pred, lfc_true), nan=0.0).item() - - return r2_mean, r2_lfc - - def sample(self, x_basal, perts, cell_types, batch_ids, n_samples=1): - """Computes knockout gene expression - - Parameters - ---------- - tensors : dict - dictionary of input tensors - - """ - _, decoder_outputs = self.forward( - x_basal=x_basal, - perts=perts, - cell_types=cell_types, - batch_ids=batch_ids, - n_samples=n_samples, - ) - - px = decoder_outputs["px"] - - if self.recon_loss == "gauss": - output_key = "loc" - else: - output_key = "mu" - - output = getattr(px, output_key) - - return output diff --git a/src/state/tx/models/scvi/_task.py b/src/state/tx/models/scvi/_task.py deleted file mode 100644 index 71a59671..00000000 --- a/src/state/tx/models/scvi/_task.py +++ /dev/null @@ -1,552 +0,0 @@ -import math -from collections import defaultdict - -import lightning as L -import numpy as np -import torch -from torch import nn -from torch.optim.lr_scheduler import StepLR -from torchmetrics.functional import accuracy - -from .._base_modules import FocalLoss -from ._module import CPAModule - - -class CPATrainer(L.LightningModule): - def __init__( - self, - module: CPAModule, - lr=5e-4, - wd=1e-6, - n_steps_pretrain_ae: int = None, - n_epochs_pretrain_ae: int = None, - n_steps_kl_warmup: int = None, - n_epochs_kl_warmup: int = None, - n_steps_adv_warmup: int = None, - n_epochs_adv_warmup: int = None, - adv_steps: int = 3, - reg_adv: float = 1.0, - pen_adv: float = 1.0, - adv_lr=1e-3, - adv_wd=1e-6, - step_size_lr: int = 45, - do_clip_grad: bool = False, - gradient_clip_value: float = 3.0, - adv_loss: str = "cce", - check_val_every_n_epoch: int = 5, - **kwargs, - ): - """Training plan for the CPA model""" - super().__init__() - - self.module = module - - self.lr = float(lr) - self.n_steps_kl_warmup = n_steps_kl_warmup - self.n_epochs_kl_warmup = n_epochs_kl_warmup - - self.automatic_optimization = False - - self.wd = float(wd) - - self.n_perts = module.n_perts - - self.n_steps_pretrain_ae = n_steps_pretrain_ae - self.n_epochs_pretrain_ae = n_epochs_pretrain_ae - - self.n_steps_adv_warmup = n_steps_adv_warmup - self.n_epochs_adv_warmup = n_epochs_adv_warmup - self.adv_steps = adv_steps - - self.reg_adv = reg_adv - self.pen_adv = pen_adv - - self.adv_lr = float(adv_lr) - self.adv_wd = float(adv_wd) - - self.step_size_lr = step_size_lr - - self.do_clip_grad = do_clip_grad - self.gradient_clip_value = gradient_clip_value - self.check_val_every_n_epoch = check_val_every_n_epoch - - self.metrics = [ - "recon_loss", - "KL", - "disnt_basal", - "disnt_after", - "r2_mean", - "r2_var", - "adv_loss", - "penalty_adv", - "adv_perts", - "acc_perts", - "penalty_perts", - ] - - self.epoch_history = defaultdict(list) - - ## TODO: remove unused - # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.adv_loss = adv_loss.lower() - self.gamma = kwargs.get("gamma", 2.0) - if self.adv_loss == "focal": 
- self.adv_loss_fn = FocalLoss(gamma=self.gamma, reduction="mean") - else: - self.adv_loss_fn = nn.CrossEntropyLoss() - - @property - def kl_weight(self): - return 0.0 - - @property - def adv_lambda(self): - slope = self.reg_adv - if self.n_steps_adv_warmup: - global_step = self.global_step - - if self.n_steps_pretrain_ae: - global_step -= self.n_steps_pretrain_ae - - if global_step <= self.n_steps_adv_warmup: - proportion = global_step / self.n_steps_adv_warmup - return slope * proportion - else: - return slope - elif self.n_epochs_adv_warmup is not None: - current_epoch = self.current_epoch - - if self.n_epochs_pretrain_ae: - current_epoch -= self.n_epochs_pretrain_ae - - if current_epoch <= self.n_epochs_adv_warmup: - proportion = current_epoch / self.n_epochs_adv_warmup - return slope * proportion - else: - return slope - else: - return slope - - @property - def do_start_adv_training(self): - if self.n_steps_pretrain_ae: - return self.global_step > self.n_steps_pretrain_ae - elif self.n_epochs_pretrain_ae: - return self.current_epoch > self.n_epochs_pretrain_ae - else: - return True - - def adversarial_loss(self, batch, z_basal, compute_penalty=True): - """Computes adversarial classification losses and regularizations""" - if compute_penalty: - z_basal = z_basal.requires_grad_(True) - - adv_logits = self.module.forward_adv(z_basal) - perts = batch["pert_emb"].argmax(1) - pert_logits = adv_logits["pert_logits"] - - pert_adv_loss = self.adv_loss_fn(pert_logits, perts.long()) - pert_acc = accuracy( - pert_logits.argmax(1), - perts.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - cell_types = batch["cell_type_onehot"].argmax(1) - cell_types_logits = adv_logits["cell_type_logits"] - cell_types_adv_loss = self.adv_loss_fn(cell_types_logits, cell_types.long()) - cell_types_acc = accuracy( - cell_types_logits.argmax(1), - cell_types.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - batch_ids = batch["batch"].argmax(1) - batch_ids_logits = adv_logits["batch_logits"] - batch_ids_adv_loss = self.adv_loss_fn(batch_ids_logits, batch_ids.long()) - batch_ids_acc = accuracy( - batch_ids_logits.argmax(1), - batch_ids.long().view( - -1, - ), - average="macro", - task="multiclass", - num_classes=self.n_perts, - ) - - adv_loss = pert_adv_loss + cell_types_adv_loss + batch_ids_adv_loss - adv_acc = (pert_acc + cell_types_acc + batch_ids_acc) / 3.0 - - if compute_penalty: - # Penalty losses - cell_type_penalty = ( - torch.autograd.grad( - cell_types_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - batch_penalty = ( - torch.autograd.grad( - batch_ids_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - pert_penalty = ( - torch.autograd.grad( - pert_logits.sum(), - z_basal, - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - .pow(2) - .mean() - ) - - total_penalty = cell_type_penalty + batch_penalty + pert_penalty - - else: - total_penalty = torch.tensor(0.0, device=z_basal.device) - - return adv_loss, adv_acc, total_penalty - - def configure_optimizers(self): - ae_params = ( - list(filter(lambda p: p.requires_grad, self.module.encoder.parameters())) - + list(filter(lambda p: p.requires_grad, self.module.decoder.parameters())) - + list( - filter( - lambda p: p.requires_grad, - self.module.pert_embeddings.parameters(), - ) - ) - + list( - filter( - 
lambda p: p.requires_grad, - self.module.cell_type_embeddings.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, self.module.batch_embeddings.parameters())) - ) - - if self.module.recon_loss in ["zinb", "nb"]: - ae_params += [self.module.px_r] - - optimizer_autoencoder = torch.optim.Adam(ae_params, lr=self.lr, weight_decay=self.wd) - - scheduler_autoencoder = StepLR(optimizer_autoencoder, step_size=self.step_size_lr, gamma=0.9) - - adv_params = ( - list( - filter( - lambda p: p.requires_grad, - self.module.perturbation_classifier.parameters(), - ) - ) - + list( - filter( - lambda p: p.requires_grad, - self.module.cell_type_classifier.parameters(), - ) - ) - + list(filter(lambda p: p.requires_grad, self.module.batch_classifier.parameters())) - ) - - optimizer_adversaries = torch.optim.Adam(adv_params, lr=self.adv_lr, weight_decay=self.adv_wd) - scheduler_adversaries = StepLR(optimizer_adversaries, step_size=self.step_size_lr, gamma=0.9) - - optimizers = [optimizer_autoencoder, optimizer_adversaries] - schedulers = [scheduler_autoencoder, scheduler_adversaries] - - if self.step_size_lr is not None: - return optimizers, schedulers - else: - return optimizers - - def training_step(self, batch, batch_idx): - opt, opt_adv = self.optimizers() - - enc_outputs, dec_outputs = self.module.forward(batch) - - recon_loss, kl_loss = self.module.loss( - batch=batch, - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - if self.do_start_adv_training: - if self.adv_steps is None: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal, - compute_penalty=False, - ) - - loss = recon_loss + self.kl_weight * kl_loss - self.adv_lambda * adv_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - - opt_adv.zero_grad() - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - elif batch_idx % self.adv_steps == 0: - opt_adv.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - # Model update - else: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal, - compute_penalty=False, - ) - - loss = recon_loss + self.kl_weight * kl_loss - self.adv_lambda * adv_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - else: - opt.zero_grad() - - z_basal = enc_outputs["z_basal"] - - loss = recon_loss + self.kl_weight * kl_loss - - self.manual_backward(loss) - - if self.do_clip_grad: - self.clip_gradients( - opt, - 
gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt.step() - - opt_adv.zero_grad() - - adv_loss, adv_acc, adv_penalty = self.adversarial_loss( - batch=batch, - z_basal=z_basal.detach(), - compute_penalty=True, - ) - - adv_loss = adv_loss + self.pen_adv * adv_penalty - - self.manual_backward(adv_loss) - - if self.do_clip_grad: - self.clip_gradients( - opt_adv, - gradient_clip_val=self.gradient_clip_value, - gradient_clip_algorithm="norm", - ) - - opt_adv.step() - - r2_mean, r2_lfc = self.module.r2_metric(batch, enc_outputs, dec_outputs) - - disnt_basal, disnt_after = self.module.disentanglement(batch, enc_outputs, dec_outputs) - - results = { - "recon_loss": recon_loss.item(), - "KL": kl_loss.item(), - "r2_mean": r2_mean, - "r2_lfc": r2_lfc, - "adv_loss": adv_loss.item(), - "adv_acc": adv_acc.item(), - "penalty_adv": adv_penalty.item(), - "es_metric": r2_mean + np.e ** (disnt_after - disnt_basal), - "disnt_basal": disnt_basal, - "disnt_after": disnt_after, - } - - self.log( - "recon_loss", - recon_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "r2_mean", - r2_mean, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "r2_lfc", - r2_lfc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "adv_loss", - adv_loss.item(), - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "disnt_basal", - disnt_basal, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "disnt_after", - disnt_after, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - self.log( - "adv_acc", - adv_acc, - on_epoch=True, - on_step=False, - prog_bar=True, - logger=True, - ) - - if self.global_step % self.step_size_lr * 1000 == 0: - sch, sch_adv = self.lr_schedulers() - sch.step() - sch_adv.step() - - return results - - def validation_step(self, batch, batch_idx): - enc_outputs, dec_outputs = self.module.forward(batch) - - recon_loss, kl_loss = self.module.loss( - batch=batch, - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - r2_mean, r2_lfc = self.module.r2_metric(batch, enc_outputs, dec_outputs) - - disnt_basal, disnt_after = self.module.disentanglement(batch, enc_outputs, dec_outputs) - - self.log("val_r2_mean", r2_mean, prog_bar=True) - self.log("val_r2_lfc", r2_lfc, prog_bar=True) - self.log("es_metric", r2_mean + math.e ** (disnt_after - disnt_basal), prog_bar=True) - - def test_step(self, batch, batch_idx): - enc_outputs, dec_outputs = self.module.forward(batch) - - recon_loss, kl_loss = self.module.loss( - batch=batch, - encoder_outputs=enc_outputs, - decoder_outputs=dec_outputs, - ) - - r2_mean, r2_lfc = self.module.r2_metric(batch, enc_outputs, dec_outputs) - - self.log("test_recon", recon_loss.item(), prog_bar=True) - self.log("test_r2_mean", r2_mean, prog_bar=True) - self.log("test_r2_lfc", r2_lfc, prog_bar=True) - - x_pred = self.module.get_expression(batch, n_samples=1) - - return x_pred.detach().cpu().numpy() diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py index b9ead679..979d5496 100644 --- a/src/state/tx/models/state_transition.py +++ b/src/state/tx/models/state_transition.py @@ -1,7 +1,6 @@ import logging import math -import anndata as ad import numpy as np import torch import torch.nn as nn @@ -11,91 +10,12 @@ from typing import Dict, Optional, Tuple from .base import PerturbationModel -from .decoders 
import FinetuneVCICountsDecoder from .utils import build_mlp, get_activation_class, get_transformer_backbone, apply_lora logger = logging.getLogger(__name__) -class CombinedLoss(nn.Module): - """Combined Sinkhorn + Energy loss.""" - - def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05): - super().__init__() - self.sinkhorn_weight = sinkhorn_weight - self.energy_weight = energy_weight - self.sinkhorn_loss = SamplesLoss(loss="sinkhorn", blur=blur) - self.energy_loss = SamplesLoss(loss="energy", blur=blur) - - def forward(self, pred, target): - sinkhorn_val = self.sinkhorn_loss(pred, target) - energy_val = self.energy_loss(pred, target) - return self.sinkhorn_weight * sinkhorn_val + self.energy_weight * energy_val - - -class ConfidenceToken(nn.Module): - """ - Learnable confidence token that gets appended to the input sequence - and learns to predict the expected loss value. - """ - - def __init__(self, hidden_dim: int, dropout: float = 0.1): - super().__init__() - # Learnable confidence token embedding - self.confidence_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) - - # Projection head to map confidence token output to scalar loss prediction - self.confidence_projection = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.LayerNorm(hidden_dim // 2), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim // 2, hidden_dim // 4), - nn.LayerNorm(hidden_dim // 4), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim // 4, 1), - nn.ReLU(), # Ensure positive loss prediction - ) - - def append_confidence_token(self, seq_input: torch.Tensor) -> torch.Tensor: - """ - Append confidence token to the sequence input. - - Args: - seq_input: Input tensor of shape [B, S, E] - - Returns: - Extended tensor of shape [B, S+1, E] - """ - batch_size = seq_input.size(0) - # Expand confidence token to batch size - confidence_tokens = self.confidence_token.expand(batch_size, -1, -1) - # Concatenate along sequence dimension - return torch.cat([seq_input, confidence_tokens], dim=1) - - def extract_confidence_prediction(self, transformer_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Extract main output and confidence prediction from transformer output. 
- - Args: - transformer_output: Output tensor of shape [B, S+1, E] - - Returns: - main_output: Tensor of shape [B, S, E] - confidence_pred: Tensor of shape [B, 1] - """ - # Split the output - main_output = transformer_output[:, :-1, :] # [B, S, E] - confidence_output = transformer_output[:, -1:, :] # [B, 1, E] - - # Project confidence token output to scalar - confidence_pred = self.confidence_projection(confidence_output).squeeze(-1) # [B, 1] - - return main_output, confidence_pred - - class StateTransitionPerturbationModel(PerturbationModel): """ This model: @@ -147,12 +67,11 @@ def __init__( # Save or store relevant hyperparams self.predict_residual = predict_residual self.output_space = output_space - self.n_encoder_layers = kwargs.get("n_encoder_layers", 2) - self.n_decoder_layers = kwargs.get("n_decoder_layers", 2) + self.n_encoder_layers = kwargs.get("n_encoder_layers", 1) + self.n_decoder_layers = kwargs.get("n_decoder_layers", 1) self.activation_class = get_activation_class(kwargs.get("activation", "gelu")) self.cell_sentence_len = kwargs.get("cell_set_len", 256) self.decoder_loss_weight = kwargs.get("decoder_weight", 1.0) - self.regularization = kwargs.get("regularization", 0.0) self.detach_decoder = kwargs.get("detach_decoder", False) self.transformer_backbone_key = transformer_backbone_key @@ -161,8 +80,28 @@ def __init__( self.distributional_loss = distributional_loss self.gene_dim = gene_dim - self.mmd_num_chunks = max(int(kwargs.get("mmd_num_chunks", 1)), 1) - self.randomize_mmd_chunks = bool(kwargs.get("randomize_mmd_chunks", False)) + self.nb_loss = bool(kwargs.get("nb_loss", False)) + self.nb_eps = float(kwargs.get("nb_eps", 1e-8)) + default_nb_embed_loss_weight = 1.0 if self.embed_key is not None else 0.0 + self.nb_embed_loss_weight = float(kwargs.get("nb_embed_loss_weight", default_nb_embed_loss_weight)) + if self.nb_loss and self.output_space == "embedding": + raise ValueError( + "nb_loss=True is incompatible with output_space='embedding'. " + "Use output_space='gene' or output_space='all'." + ) + if self.nb_loss and self.output_space not in {"gene", "all"}: + raise ValueError(f"nb_loss=True requires output_space in {{'gene', 'all'}}; got {self.output_space!r}.") + if self.nb_loss: + if self.gene_decoder is not None: + logger.info("nb_loss=True: disabling gene_decoder and decoder loss branches.") + self.gene_decoder = None + self.gene_decoder_bool = False + self.decoder_cfg = None + try: + self.hparams["gene_decoder_bool"] = False # type: ignore[index] + self.hparams["decoder_cfg"] = None # type: ignore[index] + except Exception: + pass # Build the distributional loss from geomloss blur = kwargs.get("blur", 0.05) @@ -172,9 +111,10 @@ def __init__( elif loss_name == "mse": self.loss_fn = nn.MSELoss() elif loss_name == "se": - sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01) - energy_weight = kwargs.get("energy_weight", 1.0) - self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur) + raise ValueError( + "loss='se' (combined sinkhorn+energy) has been removed. " + "Use loss='energy', loss='sinkhorn', or loss='mse'." 
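
Since `loss='se'` now raises, any configuration that still wants the combined objective has to build it outside the model. A minimal sketch that mirrors the removed `CombinedLoss` (weights and `blur` follow the removed class defaults; this helper is illustrative, not part of the package):

```python
# Illustrative stand-in for the removed CombinedLoss (assumes geomloss is installed).
import torch
from geomloss import SamplesLoss

def combined_sinkhorn_energy(pred, target, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
    """Weighted sum of Sinkhorn and energy distances over point clouds of shape [B, N, D]."""
    sinkhorn = SamplesLoss(loss="sinkhorn", blur=blur)
    energy = SamplesLoss(loss="energy", blur=blur)
    return sinkhorn_weight * sinkhorn(pred, target) + energy_weight * energy(pred, target)

pred = torch.randn(2, 64, 32)
target = torch.randn(2, 64, 32)
per_set = combined_sinkhorn_energy(pred, target)  # one loss value per set, shape [2]
```

Inside the model itself, only `loss='energy'`, `loss='sinkhorn'`, and `loss='mse'` remain valid.
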
+ ) elif loss_name == "sinkhorn": self.loss_fn = SamplesLoss(loss="sinkhorn", blur=blur) else: @@ -184,6 +124,27 @@ def __init__( # Build the underlying neural OT network self._build_networks(lora_cfg=kwargs.get("lora", None)) + self.nb_parameter_head: Optional[nn.Module] = None + self.nb_target_dim = int(self.output_dim) + if self.nb_loss: + if self.output_space == "all": + self.nb_target_dim = int(self.gene_dim if self.gene_dim is not None else self.output_dim) + elif self.embed_key is not None and self.embed_key != "X_hvg": + self.nb_target_dim = int(self.hvg_dim if self.hvg_dim is not None else self.output_dim) + + self.nb_parameter_head = build_mlp( + in_dim=self.output_dim, + out_dim=self.nb_target_dim * 2, + hidden_dim=self.hidden_dim, + n_layers=self.n_decoder_layers, + dropout=self.dropout, + activation=self.activation_class, + ) + logger.info( + "NB loss enabled for state transition model (nb_target_dim=%d, nb_embed_loss_weight=%.3f).", + self.nb_target_dim, + self.nb_embed_loss_weight, + ) # Add an optional encoder that introduces a batch variable self.batch_encoder = None @@ -196,102 +157,23 @@ def __init__( ) self.batch_dim = batch_dim - # Optional batch predictor ablation: learns a single batch token added to every position, - # and adds an auxiliary per-token batch classification head + CE loss. - self.batch_predictor = bool(kwargs.get("batch_predictor", False)) - # If batch_encoder is enabled, disable batch_predictor per request - if self.batch_encoder is not None and self.batch_predictor: - logger.warning( - "Both model.kwargs.batch_encoder and model.kwargs.batch_predictor are True. " - "Disabling batch_predictor and proceeding with batch_encoder." - ) - self.batch_predictor = False - try: - # Keep hparams in sync if available - self.hparams["batch_predictor"] = False # type: ignore[index] - except Exception: - pass - - self.batch_predictor_weight = float(kwargs.get("batch_predictor_weight", 0.1)) - self.batch_predictor_num_classes: Optional[int] = batch_dim if self.batch_predictor else None - if self.batch_predictor: - if self.batch_predictor_num_classes is None: - raise ValueError("batch_predictor=True requires a valid `batch_dim` (number of batch classes).") - # A single learnable batch token that is added to each position - self.batch_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim)) - # Simple per-token classifier from transformer hidden to batch classes - self.batch_classifier = build_mlp( - in_dim=self.hidden_dim, - out_dim=self.batch_predictor_num_classes, - hidden_dim=self.hidden_dim, - n_layers=4, - dropout=self.dropout, - activation=self.activation_class, - ) - else: - self.batch_token = None - self.batch_classifier = None - # Internal cache for last token features (B, S, H) from transformer for aux loss - self._token_features: Optional[torch.Tensor] = None - # if the model is outputting to counts space, apply relu # otherwise its in embedding space and we don't want to is_gene_space = kwargs["embed_key"] == "X_hvg" or kwargs["embed_key"] is None if is_gene_space or self.gene_decoder is None: self.relu = torch.nn.ReLU() - self.use_batch_token = kwargs.get("use_batch_token", False) - self.basal_mapping_strategy = basal_mapping_strategy - # Disable batch token only for truly incompatible cases - disable_reasons = [] - if self.batch_encoder and self.use_batch_token: - disable_reasons.append("batch encoder is used") - if basal_mapping_strategy == "random" and self.use_batch_token: - disable_reasons.append("basal mapping strategy is random") - - if 
disable_reasons: - self.use_batch_token = False + if kwargs.get("use_batch_token", False) or kwargs.get("batch_predictor", False): logger.warning( - f"Batch token is not supported when {' or '.join(disable_reasons)}, setting use_batch_token to False" + "Batch-token logic has been removed from StateTransitionPerturbationModel. " + "Ignoring model.kwargs.use_batch_token and model.kwargs.batch_predictor." ) - try: - self.hparams["use_batch_token"] = False - except Exception: - pass - self.batch_token_weight = kwargs.get("batch_token_weight", 0.1) - self.batch_token_num_classes: Optional[int] = batch_dim if self.use_batch_token else None - - if self.use_batch_token: - if self.batch_token_num_classes is None: - raise ValueError("batch_token_num_classes must be set when use_batch_token is True") - self.batch_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim)) - self.batch_classifier = build_mlp( - in_dim=self.hidden_dim, - out_dim=self.batch_token_num_classes, - hidden_dim=self.hidden_dim, - n_layers=1, - dropout=self.dropout, - activation=self.activation_class, - ) - else: - self.batch_token = None - self.batch_classifier = None - - # Internal cache for last token features (B, S, H) from transformer for aux loss - self._batch_token_cache: Optional[torch.Tensor] = None - - # initialize a confidence token - self.confidence_token = None - self.confidence_loss_fn = None if kwargs.get("confidence_token", False): - self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout) - self.confidence_loss_fn = nn.MSELoss() - self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0)) - self.confidence_weight = float(kwargs.get("confidence_weight", 0.01)) - else: - self.confidence_target_scale = None - self.confidence_weight = 0.0 + logger.warning( + "Confidence-token logic has been removed from StateTransitionPerturbationModel. " + "Ignoring model.kwargs.confidence_token." + ) # Backward-compat: accept legacy key `freeze_pert` self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False)) @@ -306,24 +188,10 @@ def __init__( for param in self.project_out.parameters(): param.requires_grad = False - control_pert = kwargs.get("control_pert", "non-targeting") - if kwargs.get("finetune_vci_decoder", False): # TODO: This will go very soon - # Prefer the gene names supplied by the data module (aligned to training output) - gene_names = self.gene_names - if gene_names is None: - raise ValueError( - "finetune_vci_decoder=True but model.gene_names is None. " - "Please provide gene_names via data module var_dims." - ) - - n_genes = len(gene_names) - logger.info( - f"Initializing FinetuneVCICountsDecoder with {n_genes} genes (output_space={output_space}; " - + ("HVG subset" if output_space == "gene" else "all genes") - + ")" - ) - self.gene_decoder = FinetuneVCICountsDecoder( - genes=gene_names, + if kwargs.get("finetune_vci_decoder", False): + logger.warning( + "model.kwargs.finetune_vci_decoder is no longer supported. " + "Ignoring it and using the standard latent-to-gene decoder path." 
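
The NB path needs integer counts and a per-set library size, so inputs that arrive in log1p space are heuristically converted back to counts by the `_suspected_discrete_torch` / `_suspected_log_torch` / `_to_count_space` helpers added further down in this file. A simplified, self-contained check of that heuristic (illustrative only):

```python
# Illustrative simplification of the count-space heuristic used by the NB path.
import torch

def to_counts(x: torch.Tensor) -> torch.Tensor:
    rowsum = x.sum(dim=-1)
    is_discrete = torch.all((rowsum - rowsum.floor()).abs() < 1e-7)  # integer row sums -> raw counts
    is_log = x.max() < 15.0                                          # small max value -> likely log1p
    counts = torch.expm1(x) if (not is_discrete and is_log) else x
    return counts.clamp_min(0.0).round()

raw = torch.tensor([[0.0, 3.0, 10.0], [1.0, 0.0, 2.0]])
assert torch.equal(to_counts(raw), raw)               # counts pass through unchanged
assert torch.equal(to_counts(torch.log1p(raw)), raw)  # log1p data is inverted and rounded
```

Per-set library sizes are then taken as the median per-cell depth of each set's control cells (`_compute_set_library_sizes_from_control`), clamped to at least 1.
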
) print(self) @@ -366,9 +234,6 @@ def _build_networks(self, lora_cfg=None): lora_cfg, ) - # Project from input_dim to hidden_dim for transformer input - # self.project_to_hidden = nn.Linear(self.input_dim, self.hidden_dim) - self.project_out = build_mlp( in_dim=self.hidden_dim, out_dim=self.output_dim, @@ -393,7 +258,107 @@ def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: """Define how we embed basal state input, if needed.""" return self.basal_encoder(expr) - def forward(self, batch: dict, padded=True) -> torch.Tensor: + def _main_loss_is_expression(self) -> bool: + if self.nb_loss: + return True + return super()._main_loss_is_expression() + + @staticmethod + def _suspected_discrete_torch(x: torch.Tensor, n_cells: int = 100) -> bool: + if x.numel() == 0: + return False + flat = x.reshape(-1, x.shape[-1]) + top_n = min(flat.shape[0], n_cells) + rowsum = flat[:top_n].sum(dim=1) + frac_part = rowsum - rowsum.floor() + return bool(torch.all(torch.abs(frac_part) < 1e-7)) + + @staticmethod + def _suspected_log_torch(x: torch.Tensor) -> bool: + if x.numel() == 0: + return False + return bool(x.max().item() < 15.0) + + def _to_count_space(self, x: torch.Tensor) -> torch.Tensor: + x_float = x.float() + is_discrete = self._suspected_discrete_torch(x_float) + is_log = self._suspected_log_torch(x_float) + + if (not is_discrete) and is_log: + counts = torch.expm1(x_float) + else: + counts = x_float + + counts = torch.nan_to_num(counts, nan=0.0, posinf=0.0, neginf=0.0) + return counts.clamp_min(0.0).round() + + def _compute_set_library_sizes_from_control(self, ctrl_cells: torch.Tensor) -> torch.Tensor: + ctrl_counts = self._to_count_space(ctrl_cells) + per_cell_library_sizes = ctrl_counts.sum(dim=-1) + per_cell_library_sizes = torch.nan_to_num(per_cell_library_sizes, nan=0.0, posinf=0.0, neginf=0.0) + per_set_library_sizes = per_cell_library_sizes.median(dim=1, keepdim=True).values.clamp_min(1.0) + return per_set_library_sizes.unsqueeze(-1) + + def _reshape_sequence_tensor(self, x: torch.Tensor, padded: bool) -> torch.Tensor: + if padded: + return x.reshape(-1, self.cell_sentence_len, x.shape[-1]) + return x.reshape(1, -1, x.shape[-1]) + + def _get_nb_target_tensor( + self, batch: Dict[str, torch.Tensor], fallback_target: torch.Tensor, padded: bool + ) -> torch.Tensor: + pert_counts = batch.get("pert_cell_counts", None) + if pert_counts is not None: + return self._reshape_sequence_tensor(pert_counts.to(fallback_target.device), padded) + return fallback_target + + def _get_nb_control_tensor_for_library( + self, + batch: Dict[str, torch.Tensor], + fallback_ctrl: torch.Tensor, + padded: bool, + ) -> torch.Tensor: + ctrl_counts = batch.get("ctrl_cell_counts", None) + if ctrl_counts is not None: + return self._reshape_sequence_tensor(ctrl_counts.to(fallback_ctrl.device), padded) + return fallback_ctrl + + def _compute_nb_nll_loss( + self, + nb_mean: torch.Tensor, + nb_dispersion: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + target_counts = self._to_count_space(target).to(nb_mean.dtype) + if nb_mean.shape != nb_dispersion.shape: + raise RuntimeError( + f"NB parameter shape mismatch: mean={tuple(nb_mean.shape)} dispersion={tuple(nb_dispersion.shape)}" + ) + if target_counts.shape[-1] != nb_mean.shape[-1]: + raise RuntimeError( + "NB target dimension mismatch: " + f"target={target_counts.shape[-1]} vs nb_params={nb_mean.shape[-1]}. " + "Ensure pert_cell_counts has the expected gene dimension for NB training." 
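
The expression assembled just below in `_compute_nb_nll_loss` is the standard negative-binomial log-likelihood with mean `mu` and inverse-dispersion `theta` (the scVI parameterization). Dropping the `nb_eps` stabilizers, it can be cross-checked against `torch.distributions.NegativeBinomial` with `total_count=theta` and `probs=mu / (mu + theta)`; an illustrative verification:

```python
# Illustrative cross-check of the NB log-likelihood used by _compute_nb_nll_loss
# (eps terms dropped; float64 keeps the comparison tight).
import torch
from torch.distributions import NegativeBinomial

torch.manual_seed(0)
mu = torch.rand(4, 5, dtype=torch.float64) * 20 + 0.1    # per-gene NB mean
theta = torch.rand(4, 5, dtype=torch.float64) * 5 + 0.5  # per-gene inverse-dispersion
y = torch.poisson(mu)                                     # stand-in integer counts

log_theta_mu = torch.log(theta + mu)
log_nb = (
    theta * (torch.log(theta) - log_theta_mu)
    + y * (torch.log(mu) - log_theta_mu)
    + torch.lgamma(y + theta)
    - torch.lgamma(theta)
    - torch.lgamma(y + 1)
)

reference = NegativeBinomial(total_count=theta, probs=mu / (mu + theta)).log_prob(y)
assert torch.allclose(log_nb, reference)
```

In `forward`, the mean is produced as `softmax(px_scale_logits) * library_size`, with the library size estimated from each set's control cells, and the dispersion as `softplus(...) + nb_eps`; the per-set loss then averages `-log_nb` over cells and genes within each set.
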
+ ) + mu = nb_mean.clamp_min(self.nb_eps) + theta = nb_dispersion.clamp_min(self.nb_eps) + log_theta_mu_eps = torch.log(theta + mu + self.nb_eps) + log_nb = ( + theta * (torch.log(theta + self.nb_eps) - log_theta_mu_eps) + + target_counts * (torch.log(mu + self.nb_eps) - log_theta_mu_eps) + + torch.lgamma(target_counts + theta) + - torch.lgamma(theta) + - torch.lgamma(target_counts + 1) + ) + recon_loss_all = -log_nb + return torch.nanmean(recon_loss_all.reshape(recon_loss_all.shape[0], -1), dim=1) + + def forward( + self, + batch: dict, + padded=True, + return_nb_params: bool = False, + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ The main forward call. Batch is a flattened sequence of cell sentences, which we reshape into sequences of length cell_sentence_len. @@ -440,24 +405,13 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: batch_embeddings = self.batch_encoder(batch_indices.long()) # Shape: [B, S, hidden_dim] seq_input = seq_input + batch_embeddings - if self.use_batch_token and self.batch_token is not None: - batch_size, _, _ = seq_input.shape - # Prepend the batch token to the sequence along the sequence dimension - # [B, S, H] -> [B, S+1, H], batch token at position 0 - seq_input = torch.cat([self.batch_token.expand(batch_size, -1, -1), seq_input], dim=1) - - confidence_pred = None - if self.confidence_token is not None: - # Append confidence token: [B, S, E] -> [B, S+1, E] (might be one more if we have the batch token) - seq_input = self.confidence_token.append_confidence_token(seq_input) - # forward pass + extract CLS last hidden state if self.hparams.get("mask_attn", False): batch_size, seq_length, _ = seq_input.shape device = seq_input.device self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType] - # create a [1,1,S,S] mask (now S+1 if confidence token is used) + # create a [1,1,S,S] mask base = torch.eye(seq_length, device=device, dtype=torch.bool).view(1, 1, seq_length, seq_length) # Get number of attention heads from model config @@ -472,36 +426,11 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: outputs = self.transformer_backbone(inputs_embeds=seq_input) transformer_output = outputs.last_hidden_state - # Extract outputs accounting for optional prepended batch token and optional confidence token at the end - if self.confidence_token is not None and self.use_batch_token and self.batch_token is not None: - # transformer_output: [B, 1 + S + 1, H] -> batch token at 0, cells 1..S, confidence at -1 - batch_token_pred = transformer_output[:, :1, :] # [B, 1, H] - res_pred, confidence_pred = self.confidence_token.extract_confidence_prediction( - transformer_output[:, 1:, :] - ) - # res_pred currently excludes the confidence token and starts from former index 1 - self._batch_token_cache = batch_token_pred - elif self.confidence_token is not None: - # Only confidence token appended at the end - res_pred, confidence_pred = self.confidence_token.extract_confidence_prediction(transformer_output) - self._batch_token_cache = None - elif self.use_batch_token and self.batch_token is not None: - # Only batch token prepended at the beginning - batch_token_pred = transformer_output[:, :1, :] # [B, 1, H] - res_pred = transformer_output[:, 1:, :] # [B, S, H] - self._batch_token_cache = batch_token_pred - else: - # Neither special token used - res_pred = transformer_output - self._batch_token_cache = None - - # Cache token features for auxiliary batch prediction loss (B, 
S, H) - self._token_features = res_pred + res_pred = transformer_output # add to basal if predicting residual if self.predict_residual and self.output_space == "all": # Project control_cells to hidden_dim space to match res_pred - # control_cells_hidden = self.project_to_hidden(control_cells) # treat the actual prediction as a residual sum to basal out_pred = self.project_out(res_pred) + basal out_pred = self.final_down_then_up(out_pred) @@ -512,40 +441,39 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: # apply relu if specified and we output to HVG space is_gene_space = self.hparams["embed_key"] == "X_hvg" or self.hparams["embed_key"] is None - if is_gene_space or self.gene_decoder is None: + if is_gene_space or (self.gene_decoder is None and not self.nb_loss): out_pred = self.relu(out_pred) + nb_mean = None + nb_dispersion = None + if self.nb_loss: + if self.nb_parameter_head is None: + raise RuntimeError("nb_loss=True but nb_parameter_head was not initialized.") + nb_params = self.nb_parameter_head(out_pred) + px_scale_logits, nb_dispersion_logits = torch.chunk(nb_params, chunks=2, dim=-1) + px_scale = F.softmax(px_scale_logits, dim=-1) + ctrl_for_library = self._get_nb_control_tensor_for_library(batch, basal, padded) + set_library_sizes = self._compute_set_library_sizes_from_control(ctrl_for_library) + nb_mean = px_scale * set_library_sizes + nb_dispersion = F.softplus(nb_dispersion_logits) + self.nb_eps + output = out_pred.reshape(-1, self.output_dim) - if confidence_pred is not None: - return output, confidence_pred - else: + if not self.nb_loss or not return_nb_params: return output + if nb_mean is None or nb_dispersion is None: + raise RuntimeError("nb_loss=True but NB parameters were not produced in forward().") + return output, nb_mean.reshape(-1, self.nb_target_dim), nb_dispersion.reshape(-1, self.nb_target_dim) def _compute_distribution_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss.""" - - if isinstance(self.loss_fn, SamplesLoss) and self.mmd_num_chunks > 1: - feature_dim = pred.shape[-1] - num_chunks = min(self.mmd_num_chunks, feature_dim) - if num_chunks > 1 and feature_dim > 0: - if self.randomize_mmd_chunks and self.training: - perm = torch.randperm(feature_dim, device=pred.device) - pred = pred.index_select(-1, perm) - target = target.index_select(-1, perm) - pred_chunks = torch.chunk(pred, num_chunks, dim=-1) - target_chunks = torch.chunk(target, num_chunks, dim=-1) - chunk_losses = [self.loss_fn(p_chunk, t_chunk) for p_chunk, t_chunk in zip(pred_chunks, target_chunks)] - return torch.stack(chunk_losses, dim=0).nanmean(dim=0) - + """Apply the primary distributional loss.""" return self.loss_fn(pred, target) def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=True) -> torch.Tensor: """Training step logic for both main model and decoder.""" # Get model predictions (in latent space) - confidence_pred = None - if self.confidence_token is not None: - pred, confidence_pred = self.forward(batch, padded=padded) + if self.nb_loss: + pred, nb_mean_flat, nb_dispersion_flat = self.forward(batch, padded=padded, return_nb_params=True) else: pred = self.forward(batch, padded=padded) @@ -558,73 +486,34 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T pred = pred.reshape(1, -1, self.output_dim) target = target.reshape(1, -1, self.output_dim) - per_set_main_losses = 
self._compute_distribution_loss(pred, target) + embedding_aux_loss = None + if self.nb_loss: + if padded: + nb_mean = nb_mean_flat.reshape(-1, self.cell_sentence_len, self.nb_target_dim) + nb_dispersion = nb_dispersion_flat.reshape(-1, self.cell_sentence_len, self.nb_target_dim) + else: + nb_mean = nb_mean_flat.reshape(1, -1, self.nb_target_dim) + nb_dispersion = nb_dispersion_flat.reshape(1, -1, self.nb_target_dim) + + nb_target = self._get_nb_target_tensor(batch, target, padded) + per_set_main_losses = self._compute_nb_nll_loss(nb_mean, nb_dispersion, nb_target) + if self.nb_embed_loss_weight > 0.0: + embedding_aux_losses = self._compute_distribution_loss(pred, target) + embedding_aux_loss = torch.nanmean(embedding_aux_losses) + self.log("train/embedding_loss", embedding_aux_loss) + else: + per_set_main_losses = self._compute_distribution_loss(pred, target) main_loss = torch.nanmean(per_set_main_losses) - self.log("train_loss", main_loss) - - # Log individual loss components if using combined loss - if hasattr(self.loss_fn, "sinkhorn_loss") and hasattr(self.loss_fn, "energy_loss"): - sinkhorn_component = self.loss_fn.sinkhorn_loss(pred, target).nanmean() - energy_component = self.loss_fn.energy_loss(pred, target).nanmean() - self.log("train/sinkhorn_loss", sinkhorn_component) - self.log("train/energy_loss", energy_component) + self.log(self._train_main_loss_key(), main_loss) # Process decoder if available decoder_loss = None total_loss = main_loss + if embedding_aux_loss is not None: + total_loss = total_loss + self.nb_embed_loss_weight * embedding_aux_loss - if self.use_batch_token and self.batch_classifier is not None and self._batch_token_cache is not None: - logits = self.batch_classifier(self._batch_token_cache) # [B, 1, C] - batch_token_targets = batch["batch"] - - B = logits.shape[0] - C = logits.size(-1) - - # Prepare one label per sequence (all S cells share the same batch) - if batch_token_targets.dim() > 1 and batch_token_targets.size(-1) == C: - # One-hot labels; reshape to [B, S, C] - if padded: - target_oh = batch_token_targets.reshape(-1, self.cell_sentence_len, C) - else: - target_oh = batch_token_targets.reshape(1, -1, C) - sentence_batch_labels = target_oh.argmax(-1) - else: - # Integer labels; reshape to [B, S] - if padded: - sentence_batch_labels = batch_token_targets.reshape(-1, self.cell_sentence_len) - else: - sentence_batch_labels = batch_token_targets.reshape(1, -1) - - if sentence_batch_labels.shape[0] != B: - sentence_batch_labels = sentence_batch_labels.reshape(B, -1) - - if self.basal_mapping_strategy == "batch": - uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1) - if not torch.all(uniform_mask): - bad_indices = torch.where(~uniform_mask)[0] - label_strings = [] - for idx in bad_indices: - labels = sentence_batch_labels[idx].detach().cpu().tolist() - logger.error("Batch labels for sentence %d: %s", idx.item(), labels) - label_strings.append(f"sentence {idx.item()}: {labels}") - raise ValueError( - "Expected all cells in a sentence to share the same batch when " - "basal_mapping_strategy is 'batch'. 
" - f"Found mixed batch labels: {', '.join(label_strings)}" - ) - - target_idx = sentence_batch_labels[:, 0] - - # Safety: ensure exactly one target per sequence - if target_idx.numel() != B: - target_idx = target_idx.reshape(-1)[:B] - - ce_loss = F.cross_entropy(logits.reshape(B, -1, C).squeeze(1), target_idx.long()) - self.log("train/batch_token_loss", ce_loss) - total_loss = total_loss + self.batch_token_weight * ce_loss - - # Auxiliary batch prediction loss (per token), if enabled - if self.gene_decoder is not None and "pert_cell_counts" in batch: + # Decoder loss in gene space, if a decoder is configured. + if (not self.nb_loss) and self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] # Train decoder to map latent predictions to gene space @@ -647,63 +536,41 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T decoder_loss = decoder_per_set.mean() # Log decoder loss - self.log("decoder_loss", decoder_loss) + self.log(self._train_expression_loss_key(), decoder_loss) total_loss = total_loss + self.decoder_loss_weight * decoder_loss - if confidence_pred is not None: - confidence_pred_vals = confidence_pred - if confidence_pred_vals.dim() > 1: - confidence_pred_vals = confidence_pred_vals.squeeze(-1) - confidence_targets = per_set_main_losses.detach() - if self.confidence_target_scale is not None: - confidence_targets = confidence_targets * self.confidence_target_scale - confidence_targets = confidence_targets.to(confidence_pred_vals.device) - - confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) - self.log("train/confidence_loss", confidence_loss) - self.log("train/actual_loss", confidence_targets.mean()) - - total_loss = total_loss + confidence_loss - - if self.regularization > 0.0: - ctrl_cell_emb = batch["ctrl_cell_emb"].reshape_as(pred) - delta = pred - ctrl_cell_emb - - # compute l1 loss - l1_loss = torch.abs(delta).mean() - - # Log the regularization loss - self.log("train/l1_regularization", l1_loss) - - # Add regularization to total loss - total_loss = total_loss + self.regularization * l1_loss - return total_loss def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: """Validation step logic.""" - if self.confidence_token is None: - pred, confidence_pred = self.forward(batch), None + if self.nb_loss: + pred, nb_mean_flat, nb_dispersion_flat = self.forward(batch, return_nb_params=True) else: - pred, confidence_pred = self.forward(batch) + pred = self.forward(batch) pred = pred.reshape(-1, self.cell_sentence_len, self.output_dim) target = batch["pert_cell_emb"] target = target.reshape(-1, self.cell_sentence_len, self.output_dim) - per_set_main_losses = self._compute_distribution_loss(pred, target) + embedding_aux_loss = None + if self.nb_loss: + nb_mean = nb_mean_flat.reshape(-1, self.cell_sentence_len, self.nb_target_dim) + nb_dispersion = nb_dispersion_flat.reshape(-1, self.cell_sentence_len, self.nb_target_dim) + nb_target = self._get_nb_target_tensor(batch, target, padded=True) + per_set_main_losses = self._compute_nb_nll_loss(nb_mean, nb_dispersion, nb_target) + if self.nb_embed_loss_weight > 0.0: + embedding_aux_losses = self._compute_distribution_loss(pred, target) + embedding_aux_loss = torch.nanmean(embedding_aux_losses) + self.log("val/embedding_loss", embedding_aux_loss) + else: + per_set_main_losses = self._compute_distribution_loss(pred, target) loss = torch.nanmean(per_set_main_losses) - 
self.log("val_loss", loss) + if embedding_aux_loss is not None: + loss = loss + self.nb_embed_loss_weight * embedding_aux_loss + self.log(self._val_main_loss_key(), loss) - # Log individual loss components if using combined loss - if hasattr(self.loss_fn, "sinkhorn_loss") and hasattr(self.loss_fn, "energy_loss"): - sinkhorn_component = self.loss_fn.sinkhorn_loss(pred, target).mean() - energy_component = self.loss_fn.energy_loss(pred, target).mean() - self.log("val/sinkhorn_loss", sinkhorn_component) - self.log("val/energy_loss", energy_component) - - if self.gene_decoder is not None and "pert_cell_counts" in batch: + if (not self.nb_loss) and self.gene_decoder is not None and "pert_cell_counts" in batch: gene_targets = batch["pert_cell_counts"] # Get model predictions from validation step @@ -718,59 +585,42 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non decoder_loss = decoder_per_set.mean() # Log the validation metric - self.log("val/decoder_loss", decoder_loss) + self.log(self._val_expression_loss_key(), decoder_loss) loss = loss + self.decoder_loss_weight * decoder_loss - if confidence_pred is not None: - confidence_pred_vals = confidence_pred - if confidence_pred_vals.dim() > 1: - confidence_pred_vals = confidence_pred_vals.squeeze(-1) - confidence_targets = per_set_main_losses.detach() - if self.confidence_target_scale is not None: - confidence_targets = confidence_targets * self.confidence_target_scale - confidence_targets = confidence_targets.to(confidence_pred_vals.device) - - confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) - self.log("val/confidence_loss", confidence_loss) - self.log("val/actual_loss", confidence_targets.mean()) - return {"loss": loss, "predictions": pred} def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: - if self.confidence_token is None: - pred, confidence_pred = self.forward(batch, padded=False), None - else: - pred, confidence_pred = self.forward(batch, padded=False) - - target = batch["pert_cell_emb"] - pred = pred.reshape(1, -1, self.output_dim) - target = target.reshape(1, -1, self.output_dim) - per_set_main_losses = self._compute_distribution_loss(pred, target) - loss = torch.nanmean(per_set_main_losses) - self.log("test_loss", loss) - - if confidence_pred is not None: - confidence_pred_vals = confidence_pred - if confidence_pred_vals.dim() > 1: - confidence_pred_vals = confidence_pred_vals.squeeze(-1) - confidence_targets = per_set_main_losses.detach() - if self.confidence_target_scale is not None: - confidence_targets = confidence_targets * self.confidence_target_scale - confidence_targets = confidence_targets.to(confidence_pred_vals.device) - - confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets) - self.log("test/confidence_loss", confidence_loss) + if self.nb_loss: + pred, nb_mean_flat, nb_dispersion_flat = self.forward(batch, padded=False, return_nb_params=True) + target = batch["pert_cell_emb"] + pred = pred.reshape(1, -1, self.output_dim) + target = target.reshape(1, -1, self.output_dim) + nb_mean = nb_mean_flat.reshape(1, -1, self.nb_target_dim) + nb_dispersion = nb_dispersion_flat.reshape(1, -1, self.nb_target_dim) + nb_target = self._get_nb_target_tensor(batch, target, padded=False) + per_set_main_losses = self._compute_nb_nll_loss(nb_mean, nb_dispersion, nb_target) + loss = torch.nanmean(per_set_main_losses) + if self.nb_embed_loss_weight > 0.0: + embedding_aux_losses 
= self._compute_distribution_loss(pred, target) + loss = loss + self.nb_embed_loss_weight * torch.nanmean(embedding_aux_losses) + self.log("test_loss", loss) + return + _ = self.forward(batch, padded=False) def predict_step(self, batch, batch_idx, padded=True, **kwargs): """ Typically used for final inference. We'll replicate old logic:s returning 'preds', 'X', 'pert_name', etc. """ - if self.confidence_token is None: - latent_output = self.forward(batch, padded=padded) # shape [B, ...] - confidence_pred = None + if self.nb_loss: + latent_output, nb_mean_flat, nb_dispersion_flat = self.forward( + batch, + padded=padded, + return_nb_params=True, + ) else: - latent_output, confidence_pred = self.forward(batch, padded=padded) + latent_output = self.forward(batch, padded=padded) # shape [B, ...] output_dict = { "preds": latent_output, @@ -784,13 +634,69 @@ def predict_step(self, batch, batch_idx, padded=True, **kwargs): "ctrl_cell_barcode": batch.get("ctrl_cell_barcode", None), } - # Add confidence prediction to output if available - if confidence_pred is not None: - output_dict["confidence_pred"] = confidence_pred - - if self.gene_decoder is not None: + if self.nb_loss: + output_dict["pert_cell_counts_preds"] = nb_mean_flat.reshape(-1, self.nb_target_dim) + output_dict["pert_cell_counts_dispersion"] = nb_dispersion_flat.reshape(-1, self.nb_target_dim) + elif self.gene_decoder is not None: pert_cell_counts_preds = self.gene_decoder(latent_output) - output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds return output_dict + + def configure_optimizers(self): + """ + Configure optimizer and optional cosine LR decay. + + This is intentionally scoped to StateTransitionPerturbationModel only. + """ + optimizer_name = str(self.hparams.get("optimizer", "adam")).lower() + base_lr = float(self.hparams.get("lr", self.lr)) + weight_decay = float(self.hparams.get("weight_decay", 0.0)) + + if optimizer_name == "adamw": + optimizer = torch.optim.AdamW(self.parameters(), lr=base_lr, weight_decay=weight_decay) + elif optimizer_name == "adam": + optimizer = torch.optim.Adam(self.parameters(), lr=base_lr, weight_decay=weight_decay) + else: + raise ValueError(f"Unsupported optimizer '{optimizer_name}'. Expected one of: adam, adamw.") + + if not bool(self.hparams.get("use_cosine_decay", False)): + return optimizer + + max_lr_cfg = self.hparams.get("max_lr", None) + max_lr = float(max_lr_cfg) if max_lr_cfg is not None else base_lr + if max_lr <= 0: + raise ValueError(f"max_lr must be > 0 when cosine decay is enabled. Received: {max_lr}") + + decay_steps_cfg = self.hparams.get("lr_decay_steps", None) + if decay_steps_cfg is None: + decay_steps = int(self.hparams.get("max_steps", 0)) + else: + decay_steps = int(decay_steps_cfg) + if decay_steps <= 0: + raise ValueError( + "lr_decay_steps must be a positive integer when cosine decay is enabled " + "(or training.max_steps must be set > 0)." + ) + + max_lr_fraction = float(self.hparams.get("max_lr_fraction", 0.1)) + if not (0 < max_lr_fraction <= 1.0): + raise ValueError(f"max_lr_fraction must be in (0, 1]. 
Received: {max_lr_fraction}") + + min_lr = max_lr * max_lr_fraction + for param_group in optimizer.param_groups: + param_group["lr"] = max_lr + + def _lr_lambda(step: int) -> float: + if step >= decay_steps: + return max_lr_fraction + decay_ratio = step / decay_steps + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + lr = min_lr + coeff * (max_lr - min_lr) + return lr / max_lr + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_lambda) + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1}, + } diff --git a/src/state/tx/utils/__init__.py b/src/state/tx/utils/__init__.py index e4fcf9ac..6ad5600c 100644 --- a/src/state/tx/utils/__init__.py +++ b/src/state/tx/utils/__init__.py @@ -127,7 +127,13 @@ def get_loggers( return loggers -def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, _ckpt_every_n_steps: int): +def get_checkpoint_callbacks( + output_dir: str, + name: str, + val_freq: int, + _ckpt_every_n_steps: int, + monitor_metric: str = "val/expression_loss", +): """ Create checkpoint callbacks based on validation frequency. @@ -136,12 +142,12 @@ def get_checkpoint_callbacks(output_dir: str, name: str, val_freq: int, _ckpt_ev checkpoint_dir = join(output_dir, name, "checkpoints") callbacks = [] - # Save only the best checkpoint (by val_loss) plus the latest checkpoint + # Save only the best checkpoint (by monitor_metric) plus the latest checkpoint best_ckpt = ModelCheckpoint( dirpath=checkpoint_dir, filename="best", save_last=True, - monitor="val_loss", + monitor=monitor_metric, mode="min", save_top_k=1, every_n_train_steps=val_freq, @@ -178,18 +184,6 @@ def get_lightning_module(model_type: str, data_config: dict, model_config: dict, batch_dim=var_dims["batch_dim"], **module_config, ) - elif model_type.lower() == "old_neuralot": - from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel - - return OldNeuralOTPerturbationModel( - input_dim=var_dims["input_dim"], - gene_dim=gene_dim, - hvg_dim=var_dims["hvg_dim"], - output_dim=var_dims["output_dim"], - pert_dim=var_dims["pert_dim"], - batch_dim=var_dims["batch_dim"], - **module_config, - ) elif model_type.lower() == "neuralot" or model_type.lower() == "pertsets" or model_type.lower() == "state": from ...tx.models.state_transition import StateTransitionPerturbationModel @@ -251,93 +245,5 @@ def get_lightning_module(model_type: str, data_config: dict, model_config: dict, batch_dim=var_dims["batch_dim"], **module_config, ) - elif model_type.lower() == "cpa": - from ...tx.models.cpa import CPAPerturbationModel - - return CPAPerturbationModel( - input_dim=var_dims["input_dim"], - output_dim=var_dims["output_dim"], - pert_dim=var_dims["pert_dim"], - gene_dim=gene_dim, - **module_config, - ) - elif model_type.lower() == "scvi": - from ...tx.models.scvi import SCVIPerturbationModel - - return SCVIPerturbationModel( - input_dim=var_dims["input_dim"], - gene_dim=gene_dim, - hvg_dim=var_dims["hvg_dim"], - output_dim=var_dims["output_dim"], - pert_dim=var_dims["pert_dim"], - batch_dim=var_dims["batch_dim"], - **module_config, - ) - elif model_type.lower() == "scgpt-chemical" or model_type.lower() == "scgpt-genetic": - from ...tx.models.scgpt import scGPTForPerturbation - - pretrained_path = module_config["pretrained_path"] - assert pretrained_path is not None, "pretrained_path must be provided for scGPT" - - model_dir = Path(pretrained_path) - model_file = model_dir / "best_model.pt" - - model = scGPTForPerturbation( - 
ntoken=module_config["ntoken"], - n_drug_tokens=module_config["n_perts"], # only used for chemical perturbations - d_model=module_config["d_model"], - nhead=module_config["nhead"], - d_hid=module_config["d_hid"], - nlayers=module_config["nlayers"], - nlayers_cls=module_config["n_layers_cls"], - n_cls=1, - dropout=module_config["dropout"], - pad_token_id=module_config["pad_token_id"], - pad_value=module_config["pad_value"], - pert_pad_id=module_config["pert_pad_id"], - do_mvc=module_config["do_MVC"], - cell_emb_style=module_config["cell_emb_style"], - mvc_decoder_style=module_config["mvc_decoder_style"], - use_fast_transformer=module_config["use_fast_transformer"], - lr=module_config["lr"], - step_size_lr=module_config["step_size_lr"], - include_zero_gene=module_config["include_zero_gene"], - embed_key=module_config["embed_key"], - perturbation_type=module_config["perturbation_type"], - ) - - load_param_prefixes = module_config["load_param_prefixes"] - - if load_param_prefixes is not None: - model_dict = model.model.state_dict() - pretrained_dict = torch.load(model_file) - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if any([k.startswith(prefix) for prefix in module_config["load_param_prefixes"]]) - } - for k, v in pretrained_dict.items(): - print(f"Loading params {k} with shape {v.shape}") - - model_dict.update(pretrained_dict) - model.model.load_state_dict(model_dict) - else: - try: - model.model.load_state_dict(torch.load(model_file)) - print(f"Loading all model params from {model_file}") - except: - # only load params that are in the model and match the size - model_dict = model.model.state_dict() - pretrained_dict = torch.load(model_file) - pretrained_dict = { - k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape - } - for k, v in pretrained_dict.items(): - print(f"Loading params {k} with shape {v.shape}") - - model_dict.update(pretrained_dict) - model.model.load_state_dict(model_dict) - - return model else: raise ValueError(f"Unknown model type: {model_type}")
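
Two smaller training-loop changes round out the diff: `get_checkpoint_callbacks` now accepts a `monitor_metric` argument (default `val/expression_loss`) instead of hard-coding `val_loss`, and `configure_optimizers` in `StateTransitionPerturbationModel` optionally applies a cosine decay from `max_lr` down to `max_lr * max_lr_fraction` over `lr_decay_steps`. A standalone sketch of that schedule (toy module and hyperparameters, purely illustrative; in the model they come from `hparams`):

```python
# Illustrative reproduction of the cosine decay wired up in configure_optimizers.
import math
import torch

max_lr, max_lr_fraction, decay_steps = 1e-3, 0.1, 1_000
min_lr = max_lr * max_lr_fraction

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)

def lr_lambda(step: int) -> float:
    if step >= decay_steps:
        return max_lr_fraction
    coeff = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return (min_lr + coeff * (max_lr - min_lr)) / max_lr

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

lrs = []
for _ in range(decay_steps + 1):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()

assert math.isclose(lrs[0], max_lr)   # starts at max_lr
assert math.isclose(lrs[-1], min_lr)  # settles at max_lr * max_lr_fraction
```

Stepping the scheduler per training step (interval `"step"`, frequency 1) matches how the model returns the `lr_scheduler` dict.
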