Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions chorus/analysis/_igv_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _ensure_igv_local() -> Path | None:
# 3x stronger than the genome-wide top 1%. Bins above 3.0 clip but
# this is rare for real biology.
_DISPLAY_MAX = 3.0

_HIGH_RES_ORACLES = ["chrombpnet", "legnet"] # for visualization mean vs max pooling

def apply_floor_rescale(
normalizer,
Expand All @@ -162,12 +162,14 @@ def apply_floor_rescale(
(:func:`chorus.analysis.causal._build_causal_igv`) so both reports
render the same "scaled-by-default" IGV tracks.
"""

if normalizer is None or oracle_name is None:
return False, ref_vals, alt_vals
from .normalization import PerTrackNormalizer
if not isinstance(normalizer, PerTrackNormalizer):
return False, ref_vals, alt_vals
floor_p = _LAYER_FLOOR_PCTILE.get(layer, _DEFAULT_FLOOR_PCTILE)

ref_fl = normalizer.perbin_floor_rescale_batch(
oracle_name, assay_id, ref_vals,
floor_pctile=floor_p,
Expand All @@ -186,6 +188,29 @@ def apply_floor_rescale(
return False, ref_vals, alt_vals
return True, ref_fl, alt_fl

def _calculate_track_bin_size(
resolution: int,
window_bp: int,
source_oracle: str,
) -> tuple[int, str]:
"""Calculate appropriate bin size and aggregation method.

Returns:
(bin_size, aggregation_method) where aggregation is "mean" or "max"
"""

# For chrombpnet or legnet models, apply max pooling
# For any other oracle, apply mean pooling
if source_oracle == "chrombpnet":
bin_size = 20
return bin_size, "mean"
elif source_oracle == "legnet":
return resolution, "max"

# Fallback: return 3_000 features per bin
num_features = 3_000
bin_size = window_bp // num_features
return bin_size, "mean"

def build_igv_html(
ref_pred,
Expand Down Expand Up @@ -279,8 +304,9 @@ def build_igv_html(
layer = classify_track_layer(ref_track)
rgb = _LAYER_COLORS.get(layer, "70,130,180")

t_start = ref_track.prediction_interval.reference.start
t_res = ref_track.resolution
actual_bp_in_array = len(ref_track.values) * t_res
t_start = variant_pos - (actual_bp_in_array // 2)

ref_vals = ref_track.values
alt_vals = alt_track.values
Expand All @@ -291,15 +317,21 @@ def build_igv_html(
floor_ok, ref_vals, alt_vals = apply_floor_rescale(
normalizer, oracle_name, assay_id, layer, ref_vals, alt_vals,
)


track_bin_size, agg_method = _calculate_track_bin_size(
t_res, window_bp, first.source_model,
)

ref_features = _downsample_to_features(
ref_vals, variant_chrom, t_start, t_res, bin_size,
ref_vals, variant_chrom, t_start, t_res, track_bin_size,
skip_zeros=not floor_ok,
aggregation_method=agg_method
)
alt_features = _downsample_to_features(
alt_vals, variant_chrom, t_start, t_res, bin_size,
alt_vals, variant_chrom, t_start, t_res, track_bin_size,
skip_zeros=not floor_ok,
)
aggregation_method=agg_method
)

group_id = assay_id.replace(":", "_").replace(" ", "_")
if floor_ok:
Expand All @@ -322,6 +354,7 @@ def build_igv_html(
display_name = f"{ref_track.assay_type}:{ref_track.cell_type}"

# Merged overlay: ref (grey) + alt (coloured) on same panel
source_model = first.source_model
tracks.append({
"name": f"{display_name}{name_suffix}",
"type": "merged",
Expand All @@ -331,13 +364,15 @@ def build_igv_html(
"type": "wig",
"name": f"{display_name} ref",
"color": f"rgb({_REF_COLOR})",
"windowFunction": "max" if source_model in _HIGH_RES_ORACLES else "mean",
**scale_cfg,
"features": ref_features,
},
{
"type": "wig",
"name": f"{display_name} alt",
"color": f"rgb({rgb})",
"windowFunction": "max" if source_model in _HIGH_RES_ORACLES else "mean",
**scale_cfg,
"features": alt_features,
},
Expand Down Expand Up @@ -408,6 +443,7 @@ def _downsample_to_features(
resolution: int,
bin_size: int,
skip_zeros: bool = True,
aggregation_method: str = "mean"
) -> list[dict]:
"""Downsample a signal array into IGV wig features.

Expand All @@ -430,7 +466,11 @@ def _downsample_to_features(

for i in range(0, n, bins_per):
chunk = vals[i:i + bins_per]
v = float(np.mean(chunk))

if aggregation_method == "mean":
v = float(np.mean(chunk))
else:
v = float(np.max(chunk))

# Skip near-zero bins to reduce JSON size (only for raw data)
if skip_zeros and abs(v) < threshold * 0.1:
Expand Down
19 changes: 18 additions & 1 deletion chorus/analysis/multi_oracle_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,14 @@ def build_unified_igv_html(self) -> str:
window_bp = (
ref_t.prediction_interval.reference.end - t_start
)
bin_size = max(1, window_bp // 3000)

from ._igv_report import (
_calculate_track_bin_size,
_HIGH_RES_ORACLES
)
bin_size, agg_method = _calculate_track_bin_size(
t_res, window_bp, ref_t.source_model
)

ref_vals = ref_t.values
alt_vals = alt_t.values
Expand All @@ -375,10 +382,12 @@ def build_unified_igv_html(self) -> str:
ref_features = _downsample_to_features(
ref_vals, pred_chrom, t_start, t_res, bin_size,
skip_zeros=not floor_ok,
aggregation_method=agg_method
)
alt_features = _downsample_to_features(
alt_vals, pred_chrom, t_start, t_res, bin_size,
skip_zeros=not floor_ok,
aggregation_method=agg_method
)
if floor_ok:
scale_cfg = {"min": 0, "max": _DISPLAY_MAX,
Expand All @@ -391,6 +400,12 @@ def build_unified_igv_html(self) -> str:
# Prefix track label with oracle name so stacked panels
# are identifiable at a glance.
panel_label = f"{oracle_name} · {short}"

if oracle_name == "legnet":
# LentiMPRA uses per-track normalization (no per-bin background distribution).
panel_label = f"{panel_label} (per-track norm)"

source_model = ref_t.source_model
tracks.append({
"name": panel_label,
"type": "merged",
Expand All @@ -400,13 +415,15 @@ def build_unified_igv_html(self) -> str:
"type": "wig",
"name": f"{panel_label} ref",
"color": f"rgb({_REF_COLOR})",
"windowFunction": "max" if source_model in _HIGH_RES_ORACLES else "mean",
**scale_cfg,
"features": ref_features,
},
{
"type": "wig",
"name": f"{panel_label} alt",
"color": f"rgb({rgb})",
"windowFunction": "max" if source_model in _HIGH_RES_ORACLES else "mean",
**scale_cfg,
"features": alt_features,
},
Expand Down
67 changes: 62 additions & 5 deletions chorus/analysis/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,54 @@ def perbin_percentile_batch(
) -> np.ndarray | None:
"""Map per-bin values to genome-wide percentiles [0, 1] for visualization."""
return self._lookup_batch(oracle_name, track_id, "perbin_cdfs", raw_values, signed=False)

def _find_matching_cdf(self, entry: dict, idx: int, track_id: str) -> np.ndarray | None:
    """Retrieve the CDF array for the track at *idx*, falling back through CDF types.

    Tries the CDF matrices in order perbin_cdfs → summary_cdfs →
    effect_cdfs and returns the first non-empty row found.

    Args:
        entry: Loaded normalizer entry for one oracle.
        idx: Row index of the track within each CDF matrix.
        track_id: Track identifier, used only for log messages.

    Returns:
        The CDF array, or ``None`` when no matrix holds a usable row.
    """
    for cdf_key in ("perbin_cdfs", "summary_cdfs", "effect_cdfs"):
        cdf_matrix = entry.get(cdf_key)
        if cdf_matrix is None:
            continue

        try:
            cdf = cdf_matrix[idx]
        except (IndexError, TypeError):
            # Matrix exists but is too short or not indexable; try the
            # next CDF type rather than failing outright.
            continue

        if cdf is not None and len(cdf) > 0:
            # Lazy %-style args: skip string formatting unless DEBUG is on.
            logger.debug("Using %s for '%s'", cdf_key, track_id)
            return cdf

    logger.warning("No valid CDF found for '%s' (index %s)", track_id, idx)
    return None

def _match_track_id(self, track_id: str, track_index: dict) -> str | None:
"""Find *track_id* in *track_index*, trying common alternative formats.

Returns the matched key, or None if no match is found.
"""
if track_id in track_index:
return track_id

# Build candidate list from track_id components
parts = track_id.split(":")
candidates = [
track_id.replace(":", "_"),
track_id.replace("_", ":"),
]
if len(parts) >= 2:
candidates.append(parts[-1]) # Last component only

for candidate in candidates:
if candidate in track_index:
logger.debug(f"Track ID matched: '{track_id}' → '{candidate}'")
return candidate

logger.warning(f"Track '{track_id}' not found (candidates: {candidates})")
return None

def perbin_floor_rescale_batch(
self,
Expand Down Expand Up @@ -583,17 +631,26 @@ def perbin_floor_rescale_batch(
entry = self._ensure_loaded(oracle_name)
if entry is None:
return None
cdf_matrix = entry.get("perbin_cdfs")
if cdf_matrix is None:

# Match track ID with possible alternative formats
track_index = entry.get("track_index", {})
matched_id = self._match_track_id(track_id, track_index)
if matched_id is None:
return None
idx = entry["track_index"].get(track_id)
if idx is None:

idx = track_index[matched_id]

# Find appropriate CDF (perbin → summary → effect fallback)
cdf = self._find_matching_cdf(entry, idx, matched_id)
if cdf is None:
return None
cdf = cdf_matrix[idx]

# Compute thresholds and rescale
n = len(cdf)
floor = float(cdf[min(int(floor_pctile * n), n - 1)])
peak = float(cdf[min(int(peak_pctile * n), n - 1)])
denom = max(peak - floor, 1e-9)

out = (raw_values.astype(np.float64) - floor) / denom
return np.clip(out, 0.0, max_value)

Expand Down
Loading
Loading