Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 23 additions & 66 deletions src/cali/detection/_detection_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,75 +221,44 @@ def _run_cellpose_detection(
"TiffCollectionReader instance."
)

# Process images in batches
n_positions = len(position_indices)
n_batches = (n_positions + batch_size - 1) // batch_size

cali_logger.info(
f"Processing {n_positions} positions in {n_batches} batches of {batch_size}"
)
cali_logger.info(f"Processing {n_positions} positions")
msg = (
f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]} - "
"cali_logger - INFO - 🔍 Running Cellpose"
)
for batch_idx in tqdm(range(n_batches), desc=msg):
for pos_idx in tqdm(position_indices, desc=msg):
if self._check_for_abort_requested():
return

# Load one batch of images
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, n_positions)
batch_positions = position_indices[start_idx:end_idx]

batch_images = []
batch_metadata = []
batch_pos_indices = []

for pos_idx in batch_positions:
if self._check_for_abort_requested():
return
data, meta = dataset.isel(p=pos_idx, metadata=True)

data, meta = dataset.isel(p=pos_idx, metadata=True)

# Preprocess data: max projection from half to end of stack
if data.ndim == 3: # (t, y, x)
data_half_to_end = data[data.shape[0] // 2 :, :, :]
image = data_half_to_end.max(axis=0)
else: # already 2D
image = data

batch_images.append(image)
batch_metadata.append(meta)
batch_pos_indices.append(pos_idx)
# Preprocess data: max projection from half to end of stack
if data.ndim == 3: # (t, y, x)
image = data[data.shape[0] // 2 :, :, :].max(axis=0)
else: # already 2D
image = data

if self._check_for_abort_requested():
return

# Process this batch
batch_masks = self._process_single_batch(
mask = self._run_cellpose_on_image(
model=model,
images=batch_images,
image=image,
diameter=diameter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold,
batch_size=batch_size,
min_size=min_size,
normalize=normalize,
)

if self._check_for_abort_requested():
return

# Yield FOV objects for this batch
for pos_idx, meta, masks_2d in zip(
batch_pos_indices, batch_metadata, batch_masks
):
if self._check_for_abort_requested():
return

fov_result = self._create_fov_with_rois(pos_idx, meta, masks_2d)

if fov_result:
yield fov_result
fov_result = self._create_fov_with_rois(pos_idx, meta, mask)
if fov_result:
yield fov_result

def _check_for_abort_requested(self) -> bool:
"""Check if cancellation has been requested."""
Expand All @@ -298,42 +267,30 @@ def _check_for_abort_requested(self) -> bool:
return True
return False

def _process_single_batch(
def _run_cellpose_on_image(
self,
model: CellposeModel,
images: list[np.ndarray],
image: np.ndarray,
diameter: float | None,
cellprob_threshold: float,
flow_threshold: float,
batch_size: int,
min_size: int,
normalize: bool,
) -> list[np.ndarray]:
"""Process a single batch of images using Cellpose.

Returns
-------
list[np.ndarray]
List of 2D label masks, one per image
"""
) -> np.ndarray:
"""Run Cellpose on a single image and return the processed mask."""
from cellpose.utils import fill_holes_and_remove_small_masks

# Run Cellpose on batch
masks, _, _ = model.eval(
images,
mask, _, _ = model.eval(
image,
diameter=diameter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold,
normalize=normalize,
batch_size=len(images),
batch_size=batch_size,
)

# Post-process masks
processed_masks = []
for mask in masks:
mask = fill_holes_and_remove_small_masks(mask, min_size=min_size)
processed_masks.append(mask)

return processed_masks
return fill_holes_and_remove_small_masks(mask, min_size=min_size)

def _create_fov_with_rois(
self,
Expand Down
5 changes: 3 additions & 2 deletions src/cali/gui/_detection_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,9 @@ def __init__(self, parent: QWidget | None = None) -> None:
# BATCH SIZE WIDGET -----------------------------------------------------------
self._batch_wdg = QWidget(self)
self._batch_wdg.setToolTip(
"Number of images to process per batch. Higher values are faster "
"but use more memory."
"Number of 256x256 image tiles processed simultaneously on the GPU. "
"Higher values are faster on large GPUs but use more GPU memory. "
"Reduce if you run out of GPU memory."
)
batch_layout = QHBoxLayout(self._batch_wdg)
batch_layout.setContentsMargins(0, 0, 0, 0)
Expand Down
3 changes: 2 additions & 1 deletion src/cali/sqlmodel/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,8 @@ class DetectionSettings(SQLModel, table=True):
normalize : bool
Whether to normalize images before detection
batch_size : int
Number of images to process per batch. By default, 8.
Number of 256x256 GPU tiles processed simultaneously by Cellpose.
Higher values are faster on large GPUs but use more GPU memory. By default, 8.
use_gpu : bool
Whether to use GPU acceleration if available. By default, True.
"""
Expand Down
Loading