diff --git a/configs/finetune_af3.yaml b/configs/finetune_af3.yaml new file mode 100644 index 0000000..fcce543 --- /dev/null +++ b/configs/finetune_af3.yaml @@ -0,0 +1,79 @@ +data: + batch_size: 32 + max_batch_size: 32 + num_workers: 0 + train_size: 0.8 + mono: false + prefetch_factor: 100 + csv_path: /home/am3826/scratch_pi_sk2433/am3826/iedb_af3/merged_af3.csv + mask: + mask_rate: 0.5 + max_distance: 8 + max_neighbors: 12 + structure: + adj: true + graph_type: knn + k: 15 + +model: + sequence: + model_type: esm + freeze_esm: true + aggregate: false + esm_dim: 1280 + rep_layer: 33 + esm_variant: esm2_t33_650M_UR50D + out_dim: 32 + n_heads: 8 + dim_ffn: 256 + n_layers: 10 + structure: + model_type: transformer + out_dim: 32 + n_heads: 8 + dim_ffn: 256 + n_layers: 4 + bio_chem: + model_type: mlp + n_bio_prop: 93 # DO NOT CHANGE THIS + hidden_dim: 64 + out_dim: 32 + n_layers: 4 + # FinetuneClassifierModule config + classifier: + module: FinetuneClassifierModule + num_classes: 2 + bio_dim: 32 + hidden_dims: [512, 256, 128, 64, 32] + lr: 1e-4 + class_weights: null + +optimizer: + lr: 1e-4 + weight_decay: 1e-6 + +experiment: + debug: False + num_devices: 2 + wandb: + mode: "disabled" + name: "af3_finetune_mlp" + project: immunofoundation + save_code: false + tags: [af3, multimer, finetune, mlp] + trainer: + min_epochs: 1 + max_epochs: 200 + accelerator: gpu + log_every_n_steps: 1 + deterministic: False + strategy: ddp + check_val_every_n_epoch: 2 + accumulate_grad_batches: 4 + checkpointer: + dirpath: ckpt/${experiment.wandb.project}/${experiment.wandb.name} + save_last: True + save_top_k: 3 + monitor: val/loss + mode: min + every_n_epochs: 2 diff --git a/configs/finetune_classifier.yaml b/configs/finetune_classifier.yaml new file mode 100644 index 0000000..104a56a --- /dev/null +++ b/configs/finetune_classifier.yaml @@ -0,0 +1,12 @@ +# Example finetuning config for ImmunoFoundation classifier +# Update the values as needed for your experiment + +csv: /home/am3826/scratch_pi_sk2433/am3826/finetuning/merged.csv +out_dir: /home/am3826/scratch_pi_sk2433/am3826/finetuning/run_af2_full +max_epochs: 50 +batch_size: 8 +num_workers: 8 +matmul_precision: medium +bio_dim: 32 +hidden_dims: [512, 256, 128, 64, 32] +class_weights: [0.61755624, 2.62664165] diff --git a/configs/train_af3.yaml b/configs/train_af3.yaml new file mode 100644 index 0000000..c796226 --- /dev/null +++ b/configs/train_af3.yaml @@ -0,0 +1,77 @@ +data: + batch_size: 32 + max_batch_size: 32 + num_workers: 0 + train_size: 0.8 + mono: false + prefetch_factor: 100 + csv_path: /home/am3826/scratch_pi_sk2433/am3826/iedb_af3/merged_af3.csv + mask: + mask_rate: 0.5 + max_distance: 8 + max_neighbors: 12 + structure: + adj: true + graph_type: knn + k: 15 + +model: + sequence: + model_type: esm + freeze_esm: true + aggregate: false + esm_dim: 1280 + rep_layer: 33 + esm_variant: esm2_t33_650M_UR50D + out_dim: 32 + n_heads: 8 + dim_ffn: 256 + n_layers: 10 + structure: + model_type: transformer + out_dim: 32 + n_heads: 8 + dim_ffn: 256 + n_layers: 4 + bio_chem: + model_type: mlp + n_bio_prop: 93 # DO NOT CHANGE THIS + hidden_dim: 64 + out_dim: 32 + n_layers: 4 + +optimizer: + lr: 1e-4 + weight_decay: 1e-6 + +experiment: + debug: False + num_devices: 2 + + wandb: + name: "af3_phase_1" + project: immunofoundation + save_code: false + tags: [af3, multimer] + mode: "online" + + optimizer: + lr: 0.0001 + + trainer: + min_epochs: 1 + max_epochs: 200 + accelerator: gpu + log_every_n_steps: 1 + deterministic: False + strategy: ddp + check_val_every_n_epoch: 2 + accumulate_grad_batches: 4 + + checkpointer: + dirpath: ckpt/${experiment.wandb.project}/${experiment.wandb.name} + save_last: True + save_top_k: 3 + monitor: val/total_loss + mode: min + every_n_epochs: 2 diff --git a/finetune_classifier.py b/finetune_classifier.py new file mode 100644 index 0000000..3ca1b1f --- /dev/null +++ b/finetune_classifier.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +"""Project-aligned finetuning script. + +This reuses the repository's `ImmunoMonomerDataset` and `ImmunoFoundationMonomerModule`. +It expects the CSV to include a `cif_path` column pointing to mmCIF files and an +`immunogenicity` column (0/1). It will attempt to load a provided checkpoint into +the backbone (best-effort) and finetune a classifier head. +""" +import os +from types import SimpleNamespace + +import torch +from torch.utils.data import DataLoader +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback +import yaml + +from immunofoundation.data.components.ImmunoMonomerDataset import ImmunoMonomerDataset, custom_collate_mono +from immunofoundation.models.ImmunoFoundationMonomerModule import ImmunoFoundationMonomerModule +from immunofoundation.models.FinetuneClassifierModule import FinetuneClassifierModule + + +def make_data_cfg(csv_path, batch_size=16, train_size=0.8, num_workers=4): + # Build a lightweight data_cfg object compatible with ImmunoMonomerDataset + mask = SimpleNamespace(mask_rate=0.5, max_distance=8, max_neighbors=12) + structure = SimpleNamespace(adj=True, k=15) + data_cfg = SimpleNamespace(csv_path=csv_path, train_size=train_size, batch_size=batch_size, num_workers=num_workers, mono=True, mask=mask, structure=structure) + return data_cfg + + +def main(): + # Load config from YAML file + config_path = 'configs/finetune_classifier.yaml' # Update path if needed + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + + # Use config values + csv_path = config['csv'] + out_dir = config['out_dir'] + max_epochs = config.get('max_epochs', 50) + batch_size = config.get('batch_size', 8) + num_workers = config.get('num_workers', 8) + matmul_precision = config.get('matmul_precision', 'medium') + bio_dim = config.get('bio_dim', 32) + hidden_dims = config.get('hidden_dims', [512, 256, 128, 64, 32]) + class_weights = config.get('class_weights', None) + + os.makedirs(out_dir, exist_ok=True) + + # Optionally set float32 matmul precision to leverage Tensor Cores (perf vs numeric tradeoff) + if matmul_precision is not None and matmul_precision != 'none': + try: + torch.set_float32_matmul_precision(matmul_precision) + print(f"Set torch float32 matmul precision -> {matmul_precision}") + except Exception as e: + print(f"Warning: failed to set float32 matmul precision: {e}") + + data_cfg = make_data_cfg(csv_path, batch_size=batch_size, num_workers=num_workers) + train_ds = ImmunoMonomerDataset(data_cfg, is_training=True) + val_ds = ImmunoMonomerDataset(data_cfg, is_training=False) + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_mono, num_workers=num_workers) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_mono, num_workers=num_workers) + + # Build backbone using defaults similar to training config + seq = SimpleNamespace(model_type='esm', freeze_esm=True, aggregate=False, esm_dim=1280, rep_layer=33, esm_variant='esm2_t33_650M_UR50D', out_dim=32, n_heads=8, dim_ffn=256, n_layers=10) + struct = SimpleNamespace(model_type='transformer', out_dim=32, n_heads=8, dim_ffn=256, n_layers=4) + bio = SimpleNamespace(model_type='mlp', n_bio_prop=93, hidden_dim=64, out_dim=32, n_layers=4) + model_cfg = SimpleNamespace(sequence=seq, structure=struct, bio_chem=bio) + + backbone = ImmunoFoundationMonomerModule(model_cfg) + + # Attempt to load checkpoint into backbone (non-fatal) + checkpoint_path = config.get('checkpoint', None) + if checkpoint_path is not None: + try: + ckpt = torch.load(checkpoint_path, map_location='cpu') + state = ckpt.get('state_dict', ckpt) + backbone.load_state_dict(state, strict=False) + print('Loaded checkpoint (partial) into backbone') + except Exception as e: + print('Warning: failed to load checkpoint into backbone:', e) + + # infer num_classes from training dataset labels (if present) + sample = None + try: + sample = train_ds[0] + except Exception: + pass + if sample is not None and 'label' in sample: + # compute unique labels in small pass + labels = [] + for i in range(min(len(train_ds), 1000)): + try: + labels.append(train_ds[i]['label']) + except Exception: + break + num_classes = int(max(labels)) + 1 if len(labels) > 0 else 2 + else: + num_classes = 2 + + finetune = FinetuneClassifierModule( + backbone, + num_classes=num_classes, + bio_dim=model_cfg.bio_chem.out_dim, + hidden_dims=hidden_dims, + class_weights=class_weights + ) + + checkpoint_cb = ModelCheckpoint(dirpath=out_dir, filename='finetune-{epoch:02d}-{val_loss:.4f}', save_top_k=3, monitor='val/loss', mode='min') + # GPU-aware trainer: use a GPU if available + accelerator = 'gpu' if torch.cuda.is_available() else 'cpu' + devices = 1 if torch.cuda.is_available() else None + # callback to print per-epoch metrics (loss/acc) to stdout for easy inspection + class PrintMetricsCallback(Callback): + def __init__(self, out_dir): + super().__init__() + self.out_dir = out_dir + self.csv_path = os.path.join(self.out_dir, 'metrics.csv') + # write header if not exists + if not os.path.exists(self.csv_path): + with open(self.csv_path, 'w') as fh: + fh.write('epoch,train_loss,train_acc,val_loss,val_acc\n') + + def on_validation_epoch_end(self, trainer, pl_module): + metrics = trainer.callback_metrics + # gather epoch and metrics with safe extraction + epoch = int(trainer.current_epoch) if hasattr(trainer, 'current_epoch') else None + def _safe_get(k): + v = metrics.get(k, None) + if v is None: + return None + try: + if hasattr(v, 'item'): + return float(v.item()) + elif isinstance(v, (int, float)): + return float(v) + else: + return float(v) + except Exception: + return None + + train_loss = _safe_get('train/loss') + train_acc = _safe_get('train/acc') + val_loss = _safe_get('val/loss') + val_acc = _safe_get('val/acc') + + out = { + 'epoch': epoch, + 'train/loss': train_loss, + 'train/acc': train_acc, + 'val/loss': val_loss, + 'val/acc': val_acc, + } + print('Epoch metrics:', out) + + # append to CSV (use empty string for missing) + with open(self.csv_path, 'a') as fh: + fh.write(f"{epoch},{'' if train_loss is None else train_loss},{'' if train_acc is None else train_acc},{'' if val_loss is None else val_loss},{'' if val_acc is None else val_acc}\n") + + print_cb = PrintMetricsCallback(out_dir=out_dir) + # Disable the default sanity validation steps (they can produce an extra validation run + # before training starts which confuses single-epoch metrics printing). Set to 0 so + # we only see the real validation at the end of each epoch. + trainer = Trainer(max_epochs=max_epochs, callbacks=[checkpoint_cb, print_cb], accelerator=accelerator, devices=devices, num_sanity_val_steps=0) + + # Run training (Lightning will log metrics per epoch). The FinetuneClassifierModule + # already logs train/loss and train/acc (on_step/on_epoch) and val/loss and val/acc (on_epoch). + trainer.fit(finetune, train_loader, val_loader) + + # Optionally run test evaluation on a held-out test split using the best checkpoint + run_test = config.get('test', False) + test_csv = config.get('test_csv', None) + if run_test: + # determine test CSV path + if test_csv is None: + test_csv = os.path.join(out_dir, 'test.csv') + if not os.path.exists(test_csv): + print(f"Test requested but no test CSV found at {test_csv}. Skipping test run.") + else: + print('Building test dataloader from', test_csv) + test_cfg = make_data_cfg(test_csv, batch_size=batch_size, num_workers=num_workers) + test_ds = ImmunoMonomerDataset(test_cfg, is_training=False) + test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_mono, num_workers=num_workers) + + # find best checkpoint saved by the checkpoint callback + best_ckpt = checkpoint_cb.best_model_path if getattr(checkpoint_cb, 'best_model_path', None) else None + if not best_ckpt or not os.path.exists(best_ckpt): + # fallback: pick the most recent finetune-*.ckpt in out_dir + import glob + ckpts = sorted(glob.glob(os.path.join(out_dir, 'finetune-*.ckpt')), key=os.path.getmtime) + best_ckpt = ckpts[-1] if ckpts else None + + if best_ckpt is None: + print('No checkpoint found to run test with. Skipping test run.') + else: + print('Running test with checkpoint:', best_ckpt) + # load model from checkpoint (ensures weights match expected backbone signature) + test_model = FinetuneClassifierModule.load_from_checkpoint(best_ckpt, backbone=backbone) + test_model.eval() + # run Lightning test and persist results + test_res = trainer.test(test_model, dataloaders=test_loader, ckpt_path=None) + try: + import json + out_path = os.path.join(out_dir, 'test_results.json') + with open(out_path, 'w') as fh: + json.dump(test_res, fh, indent=2) + print('Wrote test results to', out_path) + except Exception as e: + print('Test finished but failed to write results:', e) + + +if __name__ == '__main__': + main() diff --git a/immunofoundation/data/components/ImmunoMonomerDataset.py b/immunofoundation/data/components/ImmunoMonomerDataset.py index 084d666..bebb18b 100644 --- a/immunofoundation/data/components/ImmunoMonomerDataset.py +++ b/immunofoundation/data/components/ImmunoMonomerDataset.py @@ -22,15 +22,20 @@ def __init__(self, data_cfg, is_training): def _init_metadata(self): pdb_csv = pd.read_csv(self.data_cfg.csv_path) self.raw_csv = pdb_csv - train_len = int(pdb_csv.shape[0]*self.data_cfg.train_size)-1 - if self.is_training: - pdb_csv = pdb_csv.iloc[:train_len, :] + n = pdb_csv.shape[0] + train_size = float(getattr(self.data_cfg, 'train_size', 1.0)) + if train_size >= 1.0 or train_size <= 0.0: + # Use all data for both train and val if train_size is 1.0 or 0.0 self.csv = pdb_csv - print (f"Training: {len(self.csv)} samples") + print(f"Using all {len(self.csv)} samples (no split)") else: - pdb_csv = pdb_csv.iloc[train_len:, :] - self.csv = pdb_csv - print (f"Validation: {len(self.csv)} samples") + train_len = int(n * train_size) + if self.is_training: + self.csv = pdb_csv.iloc[:train_len, :] + print(f"Training: {len(self.csv)} samples") + else: + self.csv = pdb_csv.iloc[train_len:, :] + print(f"Validation: {len(self.csv)} samples") def _process_csv_row(self, csv_row): ''' @@ -54,6 +59,17 @@ def _process_csv_row(self, csv_row): final_features['adj'] = kneighbors_graph(coords, n_neighbors = self.data_cfg.structure.k) else: final_features['adj'] = None + # include label if present in CSV (optional for supervised finetuning) + # default label column name expected: 'immunogenicity' + if 'immunogenicity' in csv_row.index: + try: + final_features['label'] = int(csv_row['immunogenicity']) + except Exception: + # try mapping non-integer labels to 0/1 + if str(csv_row['immunogenicity']).lower() in ('true', '1', 'yes'): + final_features['label'] = 1 + else: + final_features['label'] = 0 return final_features def __getitem__(self, idx): @@ -143,14 +159,24 @@ def custom_collate_mono(batch_list): else: adjs = torch.utils.data.default_collate([0]*len(batch_list)) masks = torch.utils.data.default_collate([pad(torch.tensor(rec['mask']).float(), max_len) for rec in batch_list]) - sequences = torch.utils.data.default_collate([rec['sequence'] for rec in batch_list]) + # keep sequences as python list so ESM wrapper can batch-convert them + sequences = [rec['sequence'] for rec in batch_list] - return { + batch = { "coords": padded_coords, "adjs": adjs, "sequence": sequences, "masks": masks, } + # optional fields: label and biochem + if 'label' in batch_list[0]: + labels = torch.tensor([rec['label'] for rec in batch_list], dtype=torch.long) + batch['label'] = labels + if 'biochem' in batch_list[0]: + biochems = torch.utils.data.default_collate([rec['biochem'] for rec in batch_list]) + batch['biochem'] = biochems + + return batch def custom_collate_mono_sparse(batch_list): """ @@ -211,4 +237,4 @@ def custom_collate_mono_sparse(batch_list): "edge_index": edge_index, "batch_vec": batch_vec, "node_counts": node_counts_tensor, - } \ No newline at end of file + } diff --git a/immunofoundation/data/components/MLPDataset.py b/immunofoundation/data/components/MLPDataset.py new file mode 100644 index 0000000..8a7f34d --- /dev/null +++ b/immunofoundation/data/components/MLPDataset.py @@ -0,0 +1,80 @@ +import os +from typing import List, Optional + +import pandas as pd +import numpy as np +import torch +from torch.utils.data import Dataset + + +class MLPDataset(Dataset): + """Simple dataset to load CSV rows and return numeric features + label. + + Behavior: + - Attempts to infer label column (common names). If `label_col` provided, use it. + - Uses `feature_cols` if provided, otherwise selects all numeric columns except label and id-like columns. + - Returns dict with 'features' (float32 tensor) and 'label' (long tensor) + """ + + DEFAULT_LABEL_NAMES = ["label", "labels", "immunogenicity", "is_immunogenic", "y"] + + def __init__(self, csv_path: str, label_col: Optional[str] = None, feature_cols: Optional[List[str]] = None): + if not os.path.exists(csv_path): + raise FileNotFoundError(f"CSV not found: {csv_path}") + self.df = pd.read_csv(csv_path) + if label_col is None: + label_col = self._infer_label_col() + if label_col is None: + raise ValueError("Could not infer a label column. Please pass `label_col` explicitly.") + self.label_col = label_col + + if feature_cols is None: + # choose numeric dtypes except the label + numeric_df = self.df.select_dtypes(include=["number"]).copy() + if label_col in numeric_df.columns: + numeric_df = numeric_df.drop(columns=[label_col]) + feature_cols = list(numeric_df.columns) + if len(feature_cols) == 0: + raise ValueError("No numeric feature columns found. Provide `feature_cols` explicitly.") + + self.feature_cols = feature_cols + # drop rows with NaNs in selected cols + keep_mask = self.df[self.feature_cols + [self.label_col]].notnull().all(axis=1) + self.df = self.df.loc[keep_mask].reset_index(drop=True) + + # prepare numpy arrays for speed + self.features = self.df[self.feature_cols].astype(np.float32).to_numpy() + self.labels = self.df[self.label_col].to_numpy() + + # if labels are not integer, try to map + if not np.issubdtype(self.labels.dtype, np.integer): + unique = sorted(pd.unique(self.labels)) + self.label_map = {v: i for i, v in enumerate(unique)} + self.labels = np.array([self.label_map[x] for x in self.labels], dtype=np.int64) + else: + self.label_map = None + + def _infer_label_col(self) -> Optional[str]: + for name in self.DEFAULT_LABEL_NAMES: + if name in self.df.columns: + return name + # try common suffixes + for col in self.df.columns: + if col.lower().startswith("label") or col.lower().endswith("label"): + return col + return None + + def __len__(self): + return len(self.features) + + def __getitem__(self, idx): + x = torch.from_numpy(self.features[idx]).float() + y = int(self.labels[idx]) + return {"features": x, "label": torch.tensor(y, dtype=torch.long)} + + +def collate_fn(batch): + """Collate function to stack features and labels.""" + features = torch.stack([b["features"] for b in batch], dim=0) + labels = torch.stack([b["label"] for b in batch], dim=0) + return {"features": features, "labels": labels} diff --git a/immunofoundation/data/components/preprocess_pdb.py b/immunofoundation/data/components/preprocess_pdb.py index 91da607..bd893d0 100644 --- a/immunofoundation/data/components/preprocess_pdb.py +++ b/immunofoundation/data/components/preprocess_pdb.py @@ -2,19 +2,57 @@ from Bio.SeqUtils import seq1 import gzip import numpy as np +import io -parser = MMCIFParser(QUIET=True) - -def extract_ca_and_sequence(pdb_file): - parser = MMCIFParser(QUIET=True) - if pdb_file.endswith(".gz"): - with gzip.open(pdb_file, 'rt') as f: - structure = parser.get_structure('protein', f) +def _read_file_text(path): + if path.endswith('.gz'): + with gzip.open(path, 'rt') as f: + return f.read() else: - structure = parser.get_structure('protein', pdb_file) + with open(path, 'r') as f: + return f.read() + + +def extract_ca_and_sequence(pdb_file): + """Robustly extract CA coordinates and sequence from mmCIF or PDB files. + + Tries MMCIFParser first (for mmCIF), and falls back to PDBParser when that + fails or the file appears to be a PDB file. Accepts gzipped files as well. + Returns: (ca_coords_peptide, sequence_peptide, ca_coords_mhc, sequence_mhc) + """ + mmcif_parser = MMCIFParser(QUIET=True) + pdb_parser = PDBParser(QUIET=True) + + text = None + try: + text = _read_file_text(pdb_file) + except Exception as e: + raise ValueError(f"Could not read structure file {pdb_file}: {e}") + + # Heuristic: mmCIF files usually start with 'data_' directive + is_mmcif = text.lstrip().startswith('data_') + + structure = None + if is_mmcif: + try: + # MMCIFParser can accept a file handle + fh = io.StringIO(text) + structure = mmcif_parser.get_structure('protein', fh) + except Exception: + structure = None + + if structure is None: + # try PDB parser + try: + fh = io.StringIO(text) + structure = pdb_parser.get_structure('protein', fh) + except Exception as e: + # give a helpful error + raise ValueError(f"Failed to parse structure file {pdb_file} as mmCIF or PDB: {e}") + model = structure[0] - + ca_coords_peptide = [] sequence_peptide = [] if 'A' in model: @@ -22,30 +60,25 @@ def extract_ca_and_sequence(pdb_file): if 'CA' in residue: ca_atom = residue['CA'] ca_coords_peptide.append(tuple(ca_atom.get_coord())) - try: aa = seq1(residue.get_resname()) sequence_peptide.append(aa) - except KeyError: - # Handle non-standard residues + except Exception: sequence_peptide.append('X') - + ca_coords_mhc = [] sequence_mhc = [] - if 'B' in model: for residue in model['B']: if 'CA' in residue: ca_atom = residue['CA'] ca_coords_mhc.append(tuple(ca_atom.get_coord())) - try: aa = seq1(residue.get_resname()) sequence_mhc.append(aa) - except KeyError: - # Handle non-standard residues + except Exception: sequence_mhc.append('X') - + sequence_peptide = ''.join(sequence_peptide) sequence_mhc = ''.join(sequence_mhc) @@ -55,4 +88,10 @@ def extract_ca_and_sequence(pdb_file): def normalize_coords(coords): """Normalize coordinates to [-1, 1].""" coords = np.array(coords) - return 2 * (coords - coords.min()) / (coords.max() - coords.min()) - 1 + if coords.size == 0: + return coords + minv = coords.min() + maxv = coords.max() + if maxv == minv: + return coords - minv + return 2 * (coords - minv) / (maxv - minv) - 1 diff --git a/immunofoundation/models/FinetuneClassifierModule.py b/immunofoundation/models/FinetuneClassifierModule.py new file mode 100644 index 0000000..0c641fb --- /dev/null +++ b/immunofoundation/models/FinetuneClassifierModule.py @@ -0,0 +1,140 @@ +""" +Finetune classifier wrapper for the ImmunoFoundation backbone. + +""" + +import torch +import torch.nn as nn +from pytorch_lightning import LightningModule + + +class FinetuneClassifierModule(LightningModule): + """LightningModule that reuses ImmunoFoundationMonomerModule as an encoder + and places a small classifier head on top to predict immunogenicity. + """ + + def __init__(self, backbone, num_classes: int = 2, bio_dim: int = 0, hidden_dims=None, lr: float = 1e-3, class_weights=None): + super().__init__() + self.backbone = backbone + self.lr = lr + self.bio_dim = int(bio_dim) + + # infer dims from backbone config if available + seq_dim = getattr(self.backbone.model_cfg.sequence, 'out_dim', None) + struct_dim = getattr(self.backbone.model_cfg.structure, 'out_dim', None) + if seq_dim is None or struct_dim is None: + raise ValueError('Backbone must have sequence.out_dim and structure.out_dim in its model_cfg') + + in_dim = int(seq_dim) + int(struct_dim) + int(bio_dim) + + # default to a simple 5-layer MLP if not provided + if hidden_dims is None: + hidden_dims = [512, 256, 128, 64, 32] + layers = [] + + prev = in_dim + for h in hidden_dims: + layers.append(nn.Linear(prev, h)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(0.1)) + prev = h + layers.append(nn.Linear(prev, int(num_classes))) + self.classifier = nn.Sequential(*layers) + + # Set up class weights for CrossEntropyLoss if provided + if class_weights is not None: + weights = torch.tensor(class_weights, dtype=torch.float32) + self.criterion = nn.CrossEntropyLoss(weight=weights) + else: + self.criterion = nn.CrossEntropyLoss() + + def forward(self, batch): + # backbone.encode returns seq_embeddings, struct_embeddings + seq_embeddings, struct_embeddings = self.backbone.encode(batch) + + # pooling: masked mean over residues if masks present + if 'masks' in batch: + mask = batch['masks'].unsqueeze(-1) # (B, L, 1) + inv_mask = (1 - mask) + seq_pool = (seq_embeddings * inv_mask).sum(dim=1) / inv_mask.sum(dim=1).clamp(min=1) + struct_pool = (struct_embeddings * inv_mask).sum(dim=1) / inv_mask.sum(dim=1).clamp(min=1) + else: + seq_pool = seq_embeddings.mean(dim=1) + struct_pool = struct_embeddings.mean(dim=1) + + bio = batch.get('biochem') + if bio is None: + # pad with zeros to match expected bio_dim + if self.bio_dim > 0: + print("[WARNING] bio_dim > 0 but 'biochem' features not found in batch. Padding with zeros.") + bio = torch.zeros(seq_pool.size(0), self.bio_dim, dtype=torch.float32, device=seq_pool.device) + else: + bio = None + + # Classifier input: [pooled, bio] if bio_dim > 0 and bio present in batch + if self.bio_dim > 0 and batch.get('bio') is not None: + print("[WARNING] bio_dim > 0 but 'biochem' features not found in batch. Using 'bio' features instead.") + bio = batch['bio'] + x = torch.cat([seq_pool, struct_pool, bio], dim=-1) + else: + x = torch.cat([seq_pool, struct_pool], dim=-1) + + # Save for debug printing in training/validation step + self._last_classifier_input = x.detach().cpu() if not hasattr(self, '_last_classifier_input') else x.detach().cpu() + return self.classifier(x) + + def training_step(self, batch, batch_idx): + logits = self(batch) + labels = batch.get('label') + if labels is None: + labels = batch.get('labels') + if labels is None: + raise ValueError('Batch must contain label for supervised finetuning') + if not torch.is_tensor(labels): + labels = torch.tensor(labels, dtype=torch.long, device=logits.device) + else: + labels = labels.to(logits.device) + loss = self.criterion(logits, labels) + preds = logits.argmax(dim=-1) + acc = (preds == labels).float().mean() + + # # Debug print for first batch of each epoch + # if batch_idx == 0: + # print("[DEBUG][TRAIN] Classifier input (first batch):", self._last_classifier_input[:5]) + # print("[DEBUG][TRAIN] Predictions:", preds[:10].detach().cpu().numpy()) + # print("[DEBUG][TRAIN] Labels:", labels[:10].detach().cpu().numpy()) + + self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=False) + self.log('train/acc', acc, on_step=True, on_epoch=True, prog_bar=False) + return loss + + def validation_step(self, batch, batch_idx): + logits = self(batch) + labels = batch.get('label') + if labels is None: + labels = batch.get('labels') + if labels is None: + raise ValueError( + "Validation batch is missing labels (key 'label' or 'labels'). " + "Ensure your merged CSV contains an 'immunogenicity' column and that the validation split includes labels." + ) + if not torch.is_tensor(labels): + labels = torch.tensor(labels, dtype=torch.long, device=logits.device) + else: + labels = labels.to(logits.device) + loss = self.criterion(logits, labels) + preds = logits.argmax(dim=-1) + acc = (preds == labels).float().mean() + + # # Debug print for first batch of each epoch + # if batch_idx == 0: + # print("[DEBUG][VAL] Classifier input (first batch):", self._last_classifier_input[:5]) + # print("[DEBUG][VAL] Predictions:", preds[:10].detach().cpu().numpy()) + # print("[DEBUG][VAL] Labels:", labels[:10].detach().cpu().numpy()) + + self.log('val/loss', loss, on_epoch=True) + self.log('val/acc', acc, on_epoch=True) + return loss + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.lr) diff --git a/inspect_checkpoint.py b/inspect_checkpoint.py new file mode 100644 index 0000000..e38b9d3 --- /dev/null +++ b/inspect_checkpoint.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +"""Utility: print keys and tensor shapes inside a checkpoint file. + +Helps determine which submodules can be loaded into a new model. +""" +import argparse +import torch + + +def summarize_state_dict(state_dict, top_n=200): + keys = list(state_dict.keys()) + print(f"Total keys: {len(keys)}") + print("First keys (name -> shape):") + for k in keys[:top_n]: + v = state_dict[k] + try: + shape = v.shape + except Exception: + shape = type(v) + print(f" {k} -> {shape}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('ckpt', help='Path to checkpoint (.ckpt or .pth)') + parser.add_argument('--list-only', action='store_true') + args = parser.parse_args() + + ckpt = torch.load(args.ckpt, map_location='cpu') + if isinstance(ckpt, dict) and 'state_dict' in ckpt: + sd = ckpt['state_dict'] + else: + sd = ckpt + + if not isinstance(sd, dict): + print('Checkpoint content is not a mapping; top-level type:', type(sd)) + return + + summarize_state_dict(sd) + + +if __name__ == '__main__': + main() diff --git a/scripts/build_finetune_csvs.py b/scripts/build_finetune_csvs.py new file mode 100644 index 0000000..d565e6e --- /dev/null +++ b/scripts/build_finetune_csvs.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +"""Build merged CSVs for finetuning from a PDB/mmCIF directory and a labels CSV. + +Produces a merged CSV and optional train/val(/test) splits written to an output directory +so downstream training scripts can consume them directly. + +Usage examples: + scripts/build_finetune_csvs.py --pdb-dir /path/to/pdbs --labels-csv labels.csv --out-dir data/finetune + scripts/build_finetune_csvs.py --pdb-dir /path/to/pdbs --out-dir data/finetune +""" +import argparse +import os +import glob +import pandas as pd +from sklearn.model_selection import train_test_split + + +def stem(path): + base = os.path.basename(path) + # remove common extensions + for ext in ('.cif.gz', '.pdb.gz', '.cif', '.pdb'): + if base.endswith(ext): + return base[: -len(ext)] + return os.path.splitext(base)[0] + + +def build_paths_df(pdb_dir): + exts = ('**/*.cif', '**/*.cif.gz', '**/*.pdb', '**/*.pdb.gz') + files = [] + for e in exts: + files.extend(glob.glob(os.path.join(pdb_dir, e), recursive=True)) + files = sorted(files) + if len(files) == 0: + raise ValueError(f'No structure files found under {pdb_dir}') + df = pd.DataFrame({'cif_path': files}) + df['stem'] = df['cif_path'].apply(stem) + return df + + +def try_merge(paths_df, labels_df, merge_on=None): + # If user provided explicit merge key, use it + if merge_on is not None: + if merge_on not in labels_df.columns: + raise ValueError(f"merge_on column '{merge_on}' not found in labels CSV") + # try direct merge + merged = paths_df.merge(labels_df, left_on='stem', right_on=merge_on, how='left') + return merged + + # If labels contain `allele` and `mut_pep`, prefer constructing the filename stem as + # "{allele}_{mut_pep}" which matches the PDB filenames in your dataset (e.g. + # HLA-A*02:01_AAAAQQIQV.pdb). + if 'allele' in labels_df.columns and 'mut_pep' in labels_df.columns: + labels_df['_allele_mut_stem'] = labels_df['allele'].astype(str).str.strip() + '_' + labels_df['mut_pep'].astype(str).str.strip() + merged = paths_df.merge(labels_df, left_on='stem', right_on='_allele_mut_stem', how='left') + return merged + + # try to detect a column in labels that overlaps with stems + best_col = None + best_overlap = 0 + stems = set(paths_df['stem'].astype(str).unique()) + for c in labels_df.columns: + vals = labels_df[c].astype(str).str.split('.').str[0].str.strip().unique() + overlap = len(set(vals).intersection(stems)) + if overlap > best_overlap: + best_overlap = overlap + best_col = c + + if best_col is not None and best_overlap > 0: + labels_df['_stem_candidate'] = labels_df[best_col].astype(str).str.split('.').str[0].str.strip() + merged = paths_df.merge(labels_df, left_on='stem', right_on='_stem_candidate', how='left') + return merged + + # fallback: if lengths match, concat by index + if len(labels_df) == len(paths_df): + merged = pd.concat([paths_df.reset_index(drop=True), labels_df.reset_index(drop=True)], axis=1) + return merged + + raise ValueError('Could not automatically merge labels CSV to paths. Provide --merge-on to specify the key.') + + +def split_and_write(merged_df, out_dir, label_col='immunogenicity', train_size=0.8, test_size=0.1, random_state=42): + os.makedirs(out_dir, exist_ok=True) + merged_path = os.path.join(out_dir, 'merged.csv') + merged_df.to_csv(merged_path, index=False) + print('Wrote merged CSV ->', merged_path) + + if label_col in merged_df.columns and merged_df[label_col].notnull().any(): + labels = merged_df[label_col].astype(int) + # compute splits + if test_size > 0: + train_val_idx, test_idx = train_test_split(merged_df.index.tolist(), test_size=test_size, stratify=labels, random_state=random_state) + rel_val = (1.0 - train_size) / (1.0 - test_size) + train_idx, val_idx = train_test_split(train_val_idx, test_size=rel_val, stratify=labels.iloc[train_val_idx], random_state=random_state) + else: + train_idx, val_idx = train_test_split(merged_df.index.tolist(), test_size=(1.0 - train_size), stratify=labels, random_state=random_state) + test_idx = [] + + train_df = merged_df.loc[train_idx].reset_index(drop=True) + val_df = merged_df.loc[val_idx].reset_index(drop=True) + train_df.to_csv(os.path.join(out_dir, 'train.csv'), index=False) + val_df.to_csv(os.path.join(out_dir, 'val.csv'), index=False) + print(f'Wrote train ({len(train_df)} rows) and val ({len(val_df)} rows) CSVs to', out_dir) + if len(test_idx) > 0: + test_df = merged_df.loc[test_idx].reset_index(drop=True) + test_df.to_csv(os.path.join(out_dir, 'test.csv'), index=False) + print('Wrote test CSV ->', os.path.join(out_dir, 'test.csv')) + else: + print(f"Label column '{label_col}' not present or empty in merged CSV; only wrote merged paths.") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--pdb-dir', required=True) + parser.add_argument('--labels-csv', default=None) + parser.add_argument('--merge-on', default=None, help='Column in labels CSV to match to filename stem') + parser.add_argument('--out-dir', default='data/finetune', help='Directory to write merged and split CSVs') + parser.add_argument('--label-col', default='immunogenicity') + parser.add_argument('--train-size', type=float, default=0.8) + parser.add_argument('--test-size', type=float, default=0.1) + args = parser.parse_args() + + paths_df = build_paths_df(args.pdb_dir) + print(f'Found {len(paths_df)} structure files under {args.pdb_dir}') + + if args.labels_csv is None: + # just write paths CSV + os.makedirs(args.out_dir, exist_ok=True) + p = os.path.join(args.out_dir, 'paths.csv') + paths_df.to_csv(p, index=False) + print('Wrote paths CSV ->', p) + return + + labels_df = pd.read_csv(args.labels_csv) + merged = try_merge(paths_df, labels_df, merge_on=args.merge_on) + split_and_write(merged, args.out_dir, label_col=args.label_col, train_size=args.train_size, test_size=args.test_size) + + +if __name__ == '__main__': + main() diff --git a/scripts/check_af3_data.py b/scripts/check_af3_data.py new file mode 100644 index 0000000..8f95e33 --- /dev/null +++ b/scripts/check_af3_data.py @@ -0,0 +1,9 @@ +import pandas as pd + +af3_csv = "/home/am3826/scratch_pi_sk2433/am3826/iedb_af3/merged_af3.csv" +df = pd.read_csv(af3_csv) +display(df.head()) +print(f"Rows: {len(df)}") +print(f"Columns: {df.columns.tolist()}") +print(df['immunogenicity'].value_counts()) +print(df.isnull().sum()) diff --git a/scripts/finetune_af3.sh b/scripts/finetune_af3.sh new file mode 100644 index 0000000..b0232f8 --- /dev/null +++ b/scripts/finetune_af3.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#SBATCH --job-name=finetune_af3_mlp +#SBATCH --output=outputs/finetune_af3_mlp_%j.out +#SBATCH --error=outputs/finetune_af3_mlp_%j.err +#SBATCH --partition=pi_sk2433_gpu +#SBATCH --gres=gpu:2 +#SBATCH --mem=64G +#SBATCH --cpus-per-task=8 +#SBATCH --time=48:00:00 +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=am3826@columbia.edu + +# Activate environment +source ~/.bashrc +conda activate immuno +cd /home/am3826/workspace/ImmunoFoundation + +# Run finetuning for MLP classifier on AF3 data +python train.py experiment=finetune_af3.yaml diff --git a/scripts/finetune_classifier.sh b/scripts/finetune_classifier.sh new file mode 100644 index 0000000..b735111 --- /dev/null +++ b/scripts/finetune_classifier.sh @@ -0,0 +1,16 @@ +#!/bin/bash +#SBATCH --job-name=finetune_if +#SBATCH --partition=gpu_h200 +#SBATCH --gres=gpu:h200:1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=48G +#SBATCH --time=6:00:00 +#SBATCH --output=finetune_%j.out +#SBATCH --error=finetune_%j.err + +module load Python/3.12.3-GCCcore-13.3.0 +module load PyTorch/2.7.1-foss-2024a-CUDA-12.6.0 +source /home/am3826/workspace/ImmunoFoundation/.venv/bin/activate +cd /home/am3826/workspace/ImmunoFoundation + +python3 finetune_classifier.py diff --git a/train.py b/train.py index 7e25991..3725e50 100644 --- a/train.py +++ b/train.py @@ -32,15 +32,32 @@ def __init__(self, *, cfg: DictConfig): self._cfg = cfg self._data_cfg = cfg.data self._exp_cfg = cfg.experiment + # Build backbone (monomer or multimer) if self._data_cfg.mono: - self._model = ImmunoFoundationMonomerModule(self._cfg.model) + backbone = ImmunoFoundationMonomerModule(self._cfg.model) else: - self._model = ImmunoFoundationMultimerModule(self._cfg.model) + backbone = ImmunoFoundationMultimerModule(self._cfg.model) + + # If classifier module specified in config, wrap backbone with FinetuneClassifierModule + classifier_cfg = getattr(self._cfg.model, 'classifier', None) + if classifier_cfg is not None and getattr(classifier_cfg, 'module', None) == 'FinetuneClassifierModule': + from immunofoundation.models.FinetuneClassifierModule import FinetuneClassifierModule + self._model = FinetuneClassifierModule( + backbone=backbone, + num_classes=getattr(classifier_cfg, 'num_classes', 2), + bio_dim=getattr(classifier_cfg, 'bio_dim', 0), + hidden_dims=getattr(classifier_cfg, 'hidden_dims', [512, 256, 128, 64, 32]), + lr=getattr(classifier_cfg, 'lr', 1e-4), + class_weights=getattr(classifier_cfg, 'class_weights', None) + ) + else: + self._model = backbone + # Load checkpoint if specified init_ckpt = cfg.get("init_checkpoint", None) if init_ckpt: ckpt = torch.load(init_ckpt, map_location="cpu") - # TODO: strict=False is required for cross-module transfer (e.g. monomer → multimer) + # TODO: strict=False is required for cross-module transfer (e.g. monomer 92 multimer) # because the multimer module has keys absent from the monomer checkpoint (bio_model, # biochem_decoder). For same-module continuation strict=True would be safer. missing, unexpected = self._model.load_state_dict(ckpt["state_dict"], strict=False)