Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion MIOFlow/gaga.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numpy as np
from tqdm import tqdm
import phate
import scipy
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import pairwise_distances
from typing import List, Union, Optional, Tuple


Expand Down Expand Up @@ -331,6 +332,87 @@ def dataloader_from_pc(pointcloud, distances, batch_size=64, shuffle=True):
return dataloader


def fit_gaga(
    X_pca: np.ndarray,
    X_phate: np.ndarray,
    latent_dim: int = 2,
    hidden_dims: Optional[List[int]] = None,
    batch_size: int = 1024,
    encoder_epochs: int = 300,
    decoder_epochs: int = 300,
    learning_rate: float = 1e-3,
    dist_weight_phase1: float = 1.0,
    recon_weight_phase2: float = 1.0,
    device: Optional[str] = None,
) -> 'Autoencoder':
    """
    Convenience wrapper: scale inputs, build Autoencoder, run two-phase GAGA training.

    Handles all preprocessing internally so the caller only needs to supply raw
    ``adata.obsm`` arrays. The fitted PCA scaler is stored on the returned model
    as ``model.input_scaler`` so that ``MIOFlow`` can pick it up automatically.

    Parameters
    ----------
    X_pca : np.ndarray, shape (n_cells, n_pcs)
        Raw PCA coordinates (e.g. ``adata.obsm['X_pca']``).
    X_phate : np.ndarray, shape (n_cells, n_phate_dims)
        Raw PHATE coordinates (e.g. ``adata.obsm['X_phate']``).
    latent_dim : int
        Dimensionality of the GAGA latent space (should match PHATE dims).
    hidden_dims : list of int, optional
        Hidden layer sizes for encoder and decoder. Defaults to ``[128, 64]``.
        (``None`` sentinel avoids a shared mutable default argument.)
    batch_size : int
        Batch size for the DataLoader.
    encoder_epochs : int
        Phase 1 epochs (distance preservation, decoder frozen).
    decoder_epochs : int
        Phase 2 epochs (reconstruction, encoder frozen).
    learning_rate : float
        Adam learning rate.
    dist_weight_phase1 : float
        Weight for distance loss in phase 1.
    recon_weight_phase2 : float
        Weight for reconstruction loss in phase 2.
    device : str, optional
        Torch device string. Defaults to ``'cuda'`` when available, else
        ``'cpu'`` — resolved at call time, not import time.

    Returns
    -------
    Autoencoder
        Trained model with ``model.input_scaler`` (StandardScaler fitted on X_pca).
    """
    # Resolve defaults at call time: a list default would be shared across
    # calls, and an import-time device check would miss late CUDA availability.
    if hidden_dims is None:
        hidden_dims = [128, 64]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    X_pca = X_pca.astype(np.float32)
    X_phate = X_phate.astype(np.float32)

    # Scale PCA inputs — scaler stored on the model for downstream use
    scaler_pca = StandardScaler().fit(X_pca)
    X_pca_scaled = scaler_pca.transform(X_pca)

    # Scale PHATE and compute pairwise distances for geometric regularisation
    X_phate_scaled = StandardScaler().fit_transform(X_phate)
    phate_dist = pairwise_distances(X_phate_scaled, metric='euclidean').astype(np.float32)

    input_dim = X_pca_scaled.shape[1]
    model = Autoencoder(input_dim, latent_dim, hidden_dims=hidden_dims)
    loader = dataloader_from_pc(X_pca_scaled, phate_dist, batch_size=batch_size)
    print(f'GAGA architecture: {input_dim} → {latent_dim}')

    train_gaga_two_phase(
        model,
        loader,
        encoder_epochs=encoder_epochs,
        decoder_epochs=decoder_epochs,
        learning_rate=learning_rate,
        device=device,
        dist_weight_phase1=dist_weight_phase1,
        recon_weight_phase2=recon_weight_phase2,
    )

    # Attach the fitted scaler so MIOFlow/GrowthRateModel can normalise inputs
    # and inverse-transform decoder outputs without a separate argument.
    model.input_scaler = scaler_pca
    return model


def train_valid_loader_from_pc(pointcloud, distances, batch_size=64,
train_valid_split=0.8, shuffle=True, seed=42):
"""Split point cloud data into train/validation sets."""
Expand Down
6 changes: 2 additions & 4 deletions MIOFlow/growth_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@ def __init__(
# Data encoding — same interface as MIOFlow
gaga_model=None,
gaga_input_key: str = 'X_pca',
gaga_input_scaler=None,
obs_time_key: str = 'time_bin',
use_cuda: bool = True,
):
super().__init__()
self.use_time = use_time
self.gaga_autoencoder = gaga_model
self.gaga_input_key = gaga_input_key
self.gaga_input_scaler = gaga_input_scaler
self.obs_time_key = obs_time_key
self.device = 'cuda' if (use_cuda and torch.cuda.is_available()) else 'cpu'

Expand All @@ -56,8 +54,8 @@ def _prepare_data(self, adata) -> TimeSeriesDataset:
X_raw = adata.obsm[self.gaga_input_key].astype(np.float32)

if self.gaga_autoencoder is not None:
X_scaled = (self.gaga_input_scaler.transform(X_raw)
if self.gaga_input_scaler is not None else X_raw)
X_scaled = (self.gaga_autoencoder.input_scaler.transform(X_raw)
if self.gaga_autoencoder.input_scaler is not None else X_raw)
self.gaga_autoencoder.eval()
with torch.no_grad():
embedding = self.gaga_autoencoder.encode(
Expand Down
20 changes: 6 additions & 14 deletions MIOFlow/mioflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ class MIOFlow:
gaga_model = Autoencoder(input_dim, latent_dim)
train_gaga_two_phase(gaga_model, dataloader, ...)

# 2. Pass the trained model + its input scaler to MIOFlow
# 2. Pass the trained model to MIOFlow (scaler read from gaga_model.input_scaler)
mf = MIOFlow(
adata,
gaga_model=gaga_model,
gaga_input_scaler=scaler, # StandardScaler fitted on X_pca
obs_time_key='time_bin',
n_epochs=100,
)
Expand All @@ -57,10 +56,6 @@ class MIOFlow:
training; its decoder is used in ``decode_to_gene_space()``.
gaga_input_key : str
Key in ``adata.obsm`` fed to the GAGA encoder (default: ``'X_pca'``).
gaga_input_scaler : sklearn-compatible scaler, optional
Scaler (e.g. ``StandardScaler``) already fitted on the data in
``adata.obsm[gaga_input_key]``. Used to normalise inputs before
encoding and to inverse-transform decoder outputs.
obs_time_key : str
Column in ``adata.obs`` holding the time/group label
(default: ``'day'``).
Expand Down Expand Up @@ -101,7 +96,6 @@ def __init__(
# GAGA autoencoder (trained externally)
gaga_model=None,
gaga_input_key: str = 'X_pca',
gaga_input_scaler=None,
obs_time_key: str = "time_bin",
debug_level: str = 'info',
hidden_dim: float = 64,
Expand Down Expand Up @@ -144,7 +138,6 @@ def __init__(
self.adata = adata
self.gaga_autoencoder = gaga_model
self.gaga_input_key = gaga_input_key
self.gaga_input_scaler = gaga_input_scaler
self.obs_time_key = obs_time_key

# Model config
Expand Down Expand Up @@ -238,8 +231,8 @@ def _encode(self) -> Tuple[np.ndarray, np.ndarray]:

# If no autoencoder was used reads adata.obsm[gaga_input_key]:
if self.gaga_autoencoder is not None:
X_scaled = (self.gaga_input_scaler.transform(X_raw)
if self.gaga_input_scaler is not None else X_raw)
scaler = getattr(self.gaga_autoencoder, 'input_scaler', None)
X_scaled = scaler.transform(X_raw) if scaler is not None else X_raw
self.gaga_autoencoder.eval()
with torch.no_grad():
embedding = self.gaga_autoencoder.encode(
Expand Down Expand Up @@ -378,10 +371,9 @@ def decode_to_gene_space(self) -> np.ndarray:
torch.tensor(traj_flat, dtype=torch.float32)
).cpu().numpy()

# Inverse-scale back to PCA space (if a scaler was provided)
traj_pca = (self.gaga_input_scaler.inverse_transform(traj_pca_gaga)
if self.gaga_input_scaler is not None
else traj_pca_gaga)
# Inverse-scale back to PCA space
scaler = getattr(self.gaga_autoencoder, 'input_scaler', None)
traj_pca = scaler.inverse_transform(traj_pca_gaga) if scaler is not None else traj_pca_gaga

# PCA → gene space
X_reconstructed = np.array(
Expand Down
158 changes: 69 additions & 89 deletions notebooks/1.01.01-SERGIO-single-trajectory.ipynb

Large diffs are not rendered by default.

197 changes: 50 additions & 147 deletions notebooks/1.01.02-SERGIO-separated-branches.ipynb

Large diffs are not rendered by default.

269 changes: 47 additions & 222 deletions notebooks/1.02.01-SERGIO-SDE.ipynb

Large diffs are not rendered by default.

215 changes: 51 additions & 164 deletions notebooks/1.02.02-SERGIO-SDE-multi.ipynb

Large diffs are not rendered by default.

271 changes: 43 additions & 228 deletions notebooks/1.03.01-SERGIO-growth_rate.ipynb

Large diffs are not rendered by default.

207 changes: 50 additions & 157 deletions notebooks/1.03.02-SERGIO-growth_rate-multi.ipynb

Large diffs are not rendered by default.

415 changes: 134 additions & 281 deletions notebooks/2.01-EBD-vanilla.ipynb

Large diffs are not rendered by default.

Loading