Changes from all commits (18 commits)
a978211
removed scgpt from code base, implementation exists in state-reproduce
abhinadduri Feb 21, 2026
351edbb
removed cpa and scvi models from this repo as they are available in s…
abhinadduri Feb 21, 2026
7dcb6b7
removes old neuralot model
abhinadduri Feb 21, 2026
95a73dc
removed stale vci fine tune decoder
abhinadduri Feb 21, 2026
68bc593
removed batch token ablation as it does not work too well
abhinadduri Feb 21, 2026
6a97538
removed confidence token which was also unused
abhinadduri Feb 21, 2026
e22636d
removed combined loss as well as l1 regularization
abhinadduri Feb 21, 2026
32e4e9a
removed the mmd chunks implementation
abhinadduri Feb 21, 2026
b23d4fd
cleaned base perturbation model and latent to gene decoder to remove …
abhinadduri Feb 22, 2026
df2c83c
refactored wandb metrics names
abhinadduri Feb 22, 2026
906611b
updated to specify st training contract with cell load in readme
abhinadduri Feb 22, 2026
cbb5e06
adds learning rate scheduling and choice of adam vs adamw in hydra co…
abhinadduri Feb 22, 2026
5202767
updated residual to match main
abhinadduri Feb 22, 2026
761acda
correctly implemented nb nll loss to natively predict out counts
abhinadduri Feb 23, 2026
9d18311
adds nb loss to config
abhinadduri Feb 23, 2026
c4841b5
chore: ruff
abhinadduri Feb 23, 2026
c16276b
embedding output is incompatible with nb loss, also nb loss forces ex…
abhinadduri Feb 23, 2026
7011a63
bump semver and add is_log1p logic for the predict metrics evaluator
abhinadduri Feb 23, 2026
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,6 +5,8 @@ vci_job_*
lightning_logs/
outputs/
log/
logs/
debugging/
uv.lock
tmp/
notebooks/
23 changes: 23 additions & 0 deletions README.md
@@ -85,6 +85,29 @@ Optional downsampling overrides:
- `data.kwargs.downsample` downsamples counts during loading (`<=1` keeps that fraction by binomial sampling, `>1` targets that read depth per cell for `output_space=all`).
- `data.kwargs.downsample_cells` caps the number of cells loaded per `(cell_type, perturbation[, batch])` group while leaving smaller groups unchanged.
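
For intuition, here is a minimal sketch of the two downsampling modes (a hypothetical helper for illustration; the loader's actual implementation may differ):

```python
import numpy as np

rng = np.random.default_rng(0)

def downsample_counts(counts: np.ndarray, downsample: float) -> np.ndarray:
    """Sketch of `data.kwargs.downsample`: fraction mode vs. target-depth mode."""
    counts = counts.astype(np.int64)
    if downsample <= 1:
        # Keep roughly that fraction of reads per cell via binomial thinning.
        return rng.binomial(counts, downsample)
    # Otherwise interpret the value as a target read depth per cell.
    depth = counts.sum(axis=1, keepdims=True)
    p = np.clip(downsample / np.maximum(depth, 1), 0.0, 1.0)
    return rng.binomial(counts, p)
```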

### ST training contract: `embed_key`, `output_space`, decoder, and metrics

For ST (`state tx train`), two settings jointly define target tensors, decoder usage, and logged losses:
- `data.kwargs.embed_key`
- `data.kwargs.output_space`

The current contract is:

| `embed_key` | `output_space` | Main target/loss space | Decoder target (if present) | Logged metrics |
|---|---|---|---|---|
| `X_hvg` or `null` | `gene` | Expression (HVG/count space) | None expected | `train/expression_loss`, `val/expression_loss` |
| non-`X_hvg` embedding key (for example `X_state`, `X_uce`) | `gene` | Embedding space | Expression (HVG/count space) | `train/embedding_loss`, `val/embedding_loss`, plus `train/expression_loss`, `val/expression_loss` |
| non-`X_hvg` embedding key | `all` | Embedding space | Expression (full transcriptome) | `train/embedding_loss`, `val/embedding_loss`, plus `train/expression_loss`, `val/expression_loss` |
| any | `embedding` | Embedding space | Disabled | `train/embedding_loss`, `val/embedding_loss` |

Notes:
- Main loss is always computed from the model output against `batch["pert_cell_emb"]`.
- Decoder loss (expression loss) is only logged when decoder targets are available in `batch["pert_cell_counts"]`.
- With `output_space=embedding`, decoder is disabled.
- Best-checkpoint monitoring is:
- `val/expression_loss` when `output_space` is `gene` or `all`
- `val/embedding_loss` when `output_space` is `embedding`
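
Putting the table and notes together, the contract can be condensed into a small sketch (a hypothetical helper for illustration, not code from this repo):

```python
def resolve_contract(embed_key: str | None, output_space: str) -> dict:
    """Sketch: main target space, decoder usage, and best-checkpoint monitor."""
    if output_space == "embedding":
        return {"target": "embedding", "decoder": False, "monitor": "val/embedding_loss"}
    if embed_key in (None, "X_hvg"):
        # Main loss directly in HVG/count space; no decoder expected.
        return {"target": "expression", "decoder": False, "monitor": "val/expression_loss"}
    # Embedding-space main loss, decoded back to expression for gene/all.
    return {"target": "embedding", "decoder": True, "monitor": "val/expression_loss"}
```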

### predict

Evaluates a trained run with `cell-eval` metrics (or just runs prediction with `--predict-only`).
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "arc-state"
version = "0.10.4"
version = "1.0.0"
description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts."
readme = "README.md"
authors = [
41 changes: 27 additions & 14 deletions src/state/_cli/_tx/_infer.py
@@ -443,12 +443,21 @@ def pad_adata_with_tsv(
cell_set_len = args.max_set_len if args.max_set_len is not None else getattr(model, "cell_sentence_len", 256)
uses_batch_encoder = getattr(model, "batch_encoder", None) is not None
output_space = getattr(model, "output_space", cfg.get("data", {}).get("kwargs", {}).get("output_space", "gene"))
nb_loss_enabled = bool(getattr(model, "nb_loss", cfg.get("model", {}).get("kwargs", {}).get("nb_loss", False)))
if nb_loss_enabled and output_space == "embedding":
raise ValueError(
"model.kwargs.nb_loss=True is incompatible with data.kwargs.output_space='embedding'. "
"Use output_space='gene' or output_space='all'."
)
if nb_loss_enabled and output_space not in {"gene", "all"}:
raise ValueError(f"nb_loss=True requires output_space in {{'gene', 'all'}}; got {output_space!r}.")

if not args.quiet:
print(f"Model device: {device}")
print(f"Model cell_set_len (max sequence length): {cell_set_len}")
print(f"Model uses batch encoder: {bool(uses_batch_encoder)}")
print(f"Model output space: {output_space}")
print(f"Model nb_loss enabled: {nb_loss_enabled}")

# -----------------------
# 3) Load AnnData
@@ -720,7 +729,7 @@ def pad_adata_with_tsv(
store_raw_expression = (args.embed_key is not None and args.embed_key != "X_hvg" and output_space == "gene") or (
args.embed_key is not None and output_space == "all"
)
counts_expected = store_raw_expression
counts_expected = store_raw_expression or nb_loss_enabled
counts_out_target: Optional[str] = None
counts_obsm_key: Optional[str] = None
sim_counts: Optional[np.ndarray] = None
@@ -916,24 +925,28 @@ def group_control_indices(group_name: str) -> np.ndarray:

start = end # next window

# Clip gene-space predictions to keep downstream eval consistent.
# Clip legacy decoder outputs only; NB count outputs remain unclipped.
if output_space in {"gene", "all"}:
if out_target == "X":
clip_array(sim_X)
elif out_target.startswith("obsm['") and out_target.endswith("']"):
pred_key = out_target[6:-2]
if writes_to[0] == ".obsm" and pred_key == writes_to[1]:
clip_array(sim_obsm)
elif pred_key in adata.obsm:
clip_array(adata.obsm[pred_key])
if nb_loss_enabled:
if not args.quiet:
print("nb_loss=True: skipping clipping of simulated outputs.")
else:
if writes_to[0] == ".X":
if out_target == "X":
clip_array(sim_X)
elif out_target.startswith("obsm['") and out_target.endswith("']"):
pred_key = out_target[6:-2]
if writes_to[0] == ".obsm" and pred_key == writes_to[1]:
clip_array(sim_obsm)
elif pred_key in adata.obsm:
clip_array(adata.obsm[pred_key])
else:
clip_array(sim_obsm)
if writes_to[0] == ".X":
clip_array(sim_X)
else:
clip_array(sim_obsm)

if counts_written and sim_counts is not None:
clip_array(sim_counts)
if counts_written and sim_counts is not None:
clip_array(sim_counts)

# -----------------------
# 5) Persist the updated AnnData
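For context on the NB-loss commits above: the negative binomial negative log-likelihood that `nb_loss` typically refers to, in the (mean, inverse-dispersion) parameterization, looks like the sketch below. The exact parameterization used in this repo is not shown in this diff.

```python
import torch

def nb_nll(counts: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """NLL of NB(mu, theta), where theta is the inverse-dispersion; counts are raw (not log1p)."""
    log_theta_mu = torch.log(theta + mu + eps)
    log_lik = (
        theta * (torch.log(theta + eps) - log_theta_mu)
        + counts * (torch.log(mu + eps) - log_theta_mu)
        + torch.lgamma(counts + theta)
        - torch.lgamma(theta)
        - torch.lgamma(counts + 1.0)
    )
    return -log_lik.mean()
```

Predicting counts natively this way is also why the clipping step above is skipped when `nb_loss` is enabled: clipping to a log1p-style range would distort count-scale outputs.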
118 changes: 83 additions & 35 deletions src/state/_cli/_tx/_predict.py
@@ -227,6 +227,46 @@ def ensure_list(values, batch_size: int):
data_module.test_datasets = []
data_module._setup_global_maps()
data_module.setup(stage="test")
nb_loss_enabled = bool(cfg.get("model", {}).get("kwargs", {}).get("nb_loss", False))
output_space = cfg.get("data", {}).get("kwargs", {}).get("output_space", "gene")
if nb_loss_enabled and output_space == "embedding":
raise ValueError(
"model.kwargs.nb_loss=True is incompatible with data.kwargs.output_space='embedding'. "
"Use output_space='gene' or output_space='all'."
)
if nb_loss_enabled and output_space not in {"gene", "all"}:
raise ValueError(
f"model.kwargs.nb_loss=True requires data.kwargs.output_space in {{'gene', 'all'}}; got {output_space!r}."
)
if nb_loss_enabled:
resolved_is_log1p = bool(getattr(data_module, "is_log1p", cfg["data"]["kwargs"].get("is_log1p", False)))
expected_exp_counts = resolved_is_log1p
current_exp_counts = bool(getattr(data_module, "exp_counts", False))
if current_exp_counts != expected_exp_counts:
logger.warning(
"nb_loss=True requires exp_counts to follow is_log1p. "
"Resolved is_log1p=%s, overriding exp_counts %s -> %s.",
resolved_is_log1p,
current_exp_counts,
expected_exp_counts,
)
data_module.exp_counts = expected_exp_counts
if data_module.embed_key not in {None, "X_hvg"} and not bool(getattr(data_module, "store_raw_basal", False)):
logger.warning(
"nb_loss=True with embed_key=%r and store_raw_basal=False. "
"NB library-size estimation will fall back to ctrl_cell_emb.",
data_module.embed_key,
)
cfg["data"]["kwargs"]["is_log1p"] = resolved_is_log1p
cfg["data"]["kwargs"]["exp_counts"] = expected_exp_counts
resolved_exp_counts = bool(getattr(data_module, "exp_counts", cfg["data"]["kwargs"].get("exp_counts", False)))
metrics_is_log1p = not (nb_loss_enabled or resolved_exp_counts)
logger.info(
"Metrics config: setting pdex is_log1p=%s (nb_loss=%s, exp_counts=%s)",
metrics_is_log1p,
nb_loss_enabled,
resolved_exp_counts,
)
logger.info("Loaded data module from %s", data_module_path)

# Seed everything
@@ -250,10 +290,6 @@ def ensure_list(values, batch_size: int):
from ...tx.models.embed_sum import EmbedSumPerturbationModel

ModelClass = EmbedSumPerturbationModel
elif model_class_name.lower() == "old_neuralot":
from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel

ModelClass = OldNeuralOTPerturbationModel
elif model_class_name.lower() in ["neuralot", "pertsets", "state"]:
from ...tx.models.state_transition import StateTransitionPerturbationModel

@@ -337,6 +373,12 @@ def ensure_list(values, batch_size: int):
and data_module.embed_key != "X_hvg"
and cfg["data"]["kwargs"]["output_space"] == "gene"
) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all")
use_count_outputs = store_raw_expression or nb_loss_enabled
if nb_loss_enabled and not store_raw_expression:
logger.info(
"nb_loss=True: forcing prediction artifacts to use NB count outputs even though "
"store_raw_expression would otherwise be disabled."
)

if args.pseudobulk:
logger.info("Pseudobulk enabled; aggregating running means by (context, perturbation).")
@@ -348,13 +390,13 @@ def ensure_list(values, batch_size: int):
os.makedirs(results_dir, exist_ok=True)

pseudo_x_dim = None
if store_raw_expression:
if cfg["data"]["kwargs"]["output_space"] == "gene":
if use_count_outputs:
if output_space == "gene":
pseudo_x_dim = hvg_dim
elif cfg["data"]["kwargs"]["output_space"] == "all":
elif output_space == "all":
pseudo_x_dim = gene_dim
else:
raise ValueError(f"Unsupported output_space for pseudobulk: {cfg['data']['kwargs']['output_space']}")
raise ValueError(f"Unsupported output_space for pseudobulk: {output_space}")

pb_groups: dict[tuple[str, str], dict] = {}
context_mode = None
@@ -402,7 +444,7 @@ def ensure_list(values, batch_size: int):

batch_real_gene_np = None
batch_gene_pred_np = None
if store_raw_expression:
if use_count_outputs:
batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32)
batch_gene_pred_np = batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32)

@@ -427,10 +469,8 @@ def ensure_list(values, batch_size: int):
"count": 0,
"pred_sum": np.zeros(output_dim, dtype=np.float64),
"real_sum": np.zeros(output_dim, dtype=np.float64),
"x_hvg_sum": np.zeros(pseudo_x_dim, dtype=np.float64) if store_raw_expression else None,
"counts_pred_sum": (
np.zeros(pseudo_x_dim, dtype=np.float64) if store_raw_expression else None
),
"x_hvg_sum": np.zeros(pseudo_x_dim, dtype=np.float64) if use_count_outputs else None,
"counts_pred_sum": np.zeros(pseudo_x_dim, dtype=np.float64) if use_count_outputs else None,
}
pb_groups[(context_label, pert_name)] = entry
elif entry["celltype_name"] != current_celltype:
@@ -442,7 +482,7 @@ def ensure_list(values, batch_size: int):
entry["count"] += int(idx_arr.size)
entry["pred_sum"] += batch_pred_np[idx_arr].sum(axis=0, dtype=np.float64)
entry["real_sum"] += batch_real_np[idx_arr].sum(axis=0, dtype=np.float64)
if store_raw_expression:
if use_count_outputs:
entry["x_hvg_sum"] += batch_real_gene_np[idx_arr].sum(axis=0, dtype=np.float64)
entry["counts_pred_sum"] += batch_gene_pred_np[idx_arr].sum(axis=0, dtype=np.float64)

@@ -462,8 +502,8 @@ def ensure_list(values, batch_size: int):

pred_bulk = np.empty((n_groups, output_dim), dtype=np.float32)
real_bulk = np.empty((n_groups, output_dim), dtype=np.float32)
pred_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if store_raw_expression else None
real_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if store_raw_expression else None
pred_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if use_count_outputs else None
real_x = np.empty((n_groups, pseudo_x_dim), dtype=np.float32) if use_count_outputs else None

reserved_obs_keys = {
str(data_module.pert_col),
@@ -495,7 +535,7 @@ def ensure_list(values, batch_size: int):
denom = float(count)
pred_bulk[idx, :] = (entry["pred_sum"] / denom).astype(np.float32, copy=False)
real_bulk[idx, :] = (entry["real_sum"] / denom).astype(np.float32, copy=False)
if store_raw_expression:
if use_count_outputs:
pred_x[idx, :] = (entry["counts_pred_sum"] / denom).astype(np.float32, copy=False)
real_x[idx, :] = (entry["x_hvg_sum"] / denom).astype(np.float32, copy=False)

@@ -508,18 +548,22 @@ def ensure_list(values, batch_size: int):
obs_dict[data_module.batch_col].append(entry["batch_name"])

obs = pd.DataFrame(obs_dict)
if store_raw_expression:
if use_count_outputs:
adata_pred = anndata.AnnData(X=pred_x, obs=obs)
adata_real = anndata.AnnData(X=real_x, obs=obs)
adata_pred.obsm[data_module.embed_key] = pred_bulk
adata_real.obsm[data_module.embed_key] = real_bulk
if data_module.embed_key is not None:
adata_pred.obsm[data_module.embed_key] = pred_bulk
adata_real.obsm[data_module.embed_key] = real_bulk
else:
adata_pred = anndata.AnnData(X=pred_bulk, obs=obs)
adata_real = anndata.AnnData(X=real_bulk, obs=obs)

clip_anndata_values(adata_pred, max_value=14.0)
clip_anndata_values(adata_real, max_value=14.0)
logger.info("Clipped pseudobulk adata_pred and adata_real X values to [0.0, 14.0].")
if nb_loss_enabled:
logger.info("nb_loss=True in run config; skipping pseudobulk clipping of adata_pred/adata_real X values.")
else:
clip_anndata_values(adata_pred, max_value=14.0)
clip_anndata_values(adata_real, max_value=14.0)
logger.info("Clipped pseudobulk adata_pred and adata_real X values to [0.0, 14.0].")

if args.shared_only:
try:
@@ -566,7 +610,7 @@ def ensure_list(values, batch_size: int):
f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}"
)

pdex_kwargs = dict(exp_post_agg=True, is_log1p=True)
pdex_kwargs = dict(exp_post_agg=True, is_log1p=metrics_is_log1p)
for ct in ct_split_real.keys():
real_ct = ct_split_real[ct]
pred_ct = ct_split_pred[ct]
@@ -608,12 +652,12 @@ def ensure_list(values, batch_size: int):

final_X_hvg = None
final_pert_cell_counts_preds = None
if store_raw_expression:
if use_count_outputs:
# Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions.
if cfg["data"]["kwargs"]["output_space"] == "gene":
if output_space == "gene":
final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32)
final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32)
if cfg["data"]["kwargs"]["output_space"] == "all":
if output_space == "all":
final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32)
final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32)

@@ -719,9 +763,10 @@ def ensure_list(values, batch_size: int):
adata_real = anndata.AnnData(X=final_X_hvg, obs=obs)

# add the embedding predictions
adata_pred.obsm[data_module.embed_key] = final_preds
adata_real.obsm[data_module.embed_key] = final_reals
logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']")
if data_module.embed_key is not None:
adata_pred.obsm[data_module.embed_key] = final_preds
adata_real.obsm[data_module.embed_key] = final_reals
logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']")
else:
# if len(gene_names) != final_preds.shape[1]:
# gene_names = np.load(
@@ -736,10 +781,13 @@ def ensure_list(values, batch_size: int):
# adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var)
adata_real = anndata.AnnData(X=final_reals, obs=obs)

# Clip extreme values to keep cell-eval log1p checks happy.
clip_anndata_values(adata_pred, max_value=14.0)
clip_anndata_values(adata_real, max_value=14.0)
logger.info("Clipped adata_pred and adata_real X values to [0.0, 14.0] before evaluation.")
# Clip extreme values to keep cell-eval log1p checks happy unless NB loss is active.
if nb_loss_enabled:
logger.info("nb_loss=True in run config; skipping clipping of adata_pred/adata_real X values.")
else:
clip_anndata_values(adata_pred, max_value=14.0)
clip_anndata_values(adata_real, max_value=14.0)
logger.info("Clipped adata_pred and adata_real X values to [0.0, 14.0] before evaluation.")

# Optionally filter to perturbations seen in at least one training context
if args.shared_only:
@@ -798,7 +846,7 @@ def ensure_list(values, batch_size: int):
f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}"
)

pdex_kwargs = dict(exp_post_agg=True, is_log1p=True)
pdex_kwargs = dict(exp_post_agg=True, is_log1p=metrics_is_log1p)

for ct in ct_split_real.keys():
real_ct = ct_split_real[ct]
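A toy illustration of the `exp_counts`/`is_log1p` coupling enforced in this file (values are made up): NB likelihoods are defined over raw counts, so log1p-stored targets must be mapped back with `expm1`, and pdex metrics are then told the data is no longer log1p.

```python
import numpy as np

# If the stored targets are log1p-transformed, exp_counts=True recovers counts...
x_log1p = np.log1p(np.array([0.0, 1.0, 4.0, 99.0]))
counts = np.expm1(x_log1p)
print(counts)  # [ 0.  1.  4. 99.]

# ...and the metrics flag follows directly from the expression in the diff:
for nb_loss, exp_counts in [(False, False), (True, True)]:
    print(nb_loss, exp_counts, "-> is_log1p =", not (nb_loss or exp_counts))
```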