From cc4181c40752da145779fca90c688fbf0f858fcc Mon Sep 17 00:00:00 2001 From: fdrgsp Date: Fri, 29 May 2026 13:27:57 -0400 Subject: [PATCH] fix: refactor Cellpose processing logic and update documentation for batch size --- src/cali/detection/_detection_runner.py | 89 +++++++------------------ src/cali/gui/_detection_gui.py | 5 +- src/cali/sqlmodel/_model.py | 3 +- 3 files changed, 28 insertions(+), 69 deletions(-) diff --git a/src/cali/detection/_detection_runner.py b/src/cali/detection/_detection_runner.py index 8b6c9056..6d393318 100644 --- a/src/cali/detection/_detection_runner.py +++ b/src/cali/detection/_detection_runner.py @@ -221,57 +221,34 @@ def _run_cellpose_detection( "TiffCollectionReader instance." ) - # Process images in batches n_positions = len(position_indices) - n_batches = (n_positions + batch_size - 1) // batch_size - - cali_logger.info( - f"Processing {n_positions} positions in {n_batches} batches of {batch_size}" - ) + cali_logger.info(f"Processing {n_positions} positions") msg = ( f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]} - " "cali_logger - INFO - 🔍 Running Cellpose" ) - for batch_idx in tqdm(range(n_batches), desc=msg): + for pos_idx in tqdm(position_indices, desc=msg): if self._check_for_abort_requested(): return - # Load one batch of images - start_idx = batch_idx * batch_size - end_idx = min(start_idx + batch_size, n_positions) - batch_positions = position_indices[start_idx:end_idx] - - batch_images = [] - batch_metadata = [] - batch_pos_indices = [] - - for pos_idx in batch_positions: - if self._check_for_abort_requested(): - return + data, meta = dataset.isel(p=pos_idx, metadata=True) - data, meta = dataset.isel(p=pos_idx, metadata=True) - - # Preprocess data: max projection from half to end of stack - if data.ndim == 3: # (t, y, x) - data_half_to_end = data[data.shape[0] // 2 :, :, :] - image = data_half_to_end.max(axis=0) - else: # already 2D - image = data - - batch_images.append(image) - batch_metadata.append(meta) - batch_pos_indices.append(pos_idx) + # Preprocess data: max projection from half to end of stack + if data.ndim == 3: # (t, y, x) + image = data[data.shape[0] // 2 :, :, :].max(axis=0) + else: # already 2D + image = data if self._check_for_abort_requested(): return - # Process this batch - batch_masks = self._process_single_batch( + mask = self._run_cellpose_on_image( model=model, - images=batch_images, + image=image, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, + batch_size=batch_size, min_size=min_size, normalize=normalize, ) @@ -279,17 +256,9 @@ def _run_cellpose_detection( if self._check_for_abort_requested(): return - # Yield FOV objects for this batch - for pos_idx, meta, masks_2d in zip( - batch_pos_indices, batch_metadata, batch_masks - ): - if self._check_for_abort_requested(): - return - - fov_result = self._create_fov_with_rois(pos_idx, meta, masks_2d) - - if fov_result: - yield fov_result + fov_result = self._create_fov_with_rois(pos_idx, meta, mask) + if fov_result: + yield fov_result def _check_for_abort_requested(self) -> bool: """Check if cancellation has been requested.""" @@ -298,42 +267,30 @@ def _check_for_abort_requested(self) -> bool: return True return False - def _process_single_batch( + def _run_cellpose_on_image( self, model: CellposeModel, - images: list[np.ndarray], + image: np.ndarray, diameter: float | None, cellprob_threshold: float, flow_threshold: float, + batch_size: int, min_size: int, normalize: bool, - ) -> list[np.ndarray]: - """Process a single batch of images using Cellpose. - - Returns - ------- - list[np.ndarray] - List of 2D label masks, one per image - """ + ) -> np.ndarray: + """Run Cellpose on a single image and return the processed mask.""" from cellpose.utils import fill_holes_and_remove_small_masks - # Run Cellpose on batch - masks, _, _ = model.eval( - images, + mask, _, _ = model.eval( + image, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, normalize=normalize, - batch_size=len(images), + batch_size=batch_size, ) - # Post-process masks - processed_masks = [] - for mask in masks: - mask = fill_holes_and_remove_small_masks(mask, min_size=min_size) - processed_masks.append(mask) - - return processed_masks + return fill_holes_and_remove_small_masks(mask, min_size=min_size) def _create_fov_with_rois( self, diff --git a/src/cali/gui/_detection_gui.py b/src/cali/gui/_detection_gui.py index e872cb1d..b90a2777 100644 --- a/src/cali/gui/_detection_gui.py +++ b/src/cali/gui/_detection_gui.py @@ -402,8 +402,9 @@ def __init__(self, parent: QWidget | None = None) -> None: # BATCH SIZE WIDGET ----------------------------------------------------------- self._batch_wdg = QWidget(self) self._batch_wdg.setToolTip( - "Number of images to process per batch. Higher values are faster " - "but use more memory." + "Number of 256x256 image tiles processed simultaneously on the GPU. " + "Higher values are faster on large GPUs but use more GPU memory. " + "Reduce if you run out of GPU memory." ) batch_layout = QHBoxLayout(self._batch_wdg) batch_layout.setContentsMargins(0, 0, 0, 0) diff --git a/src/cali/sqlmodel/_model.py b/src/cali/sqlmodel/_model.py index 8d451a60..54e97c3e 100644 --- a/src/cali/sqlmodel/_model.py +++ b/src/cali/sqlmodel/_model.py @@ -776,7 +776,8 @@ class DetectionSettings(SQLModel, table=True): normalize : bool Whether to normalize images before detection batch_size : int - Number of images to process per batch. By default, 8. + Number of 256x256 GPU tiles processed simultaneously by Cellpose. + Higher values are faster on large GPUs but use more GPU memory. By default, 8. use_gpu : bool Whether to use GPU acceleration if available. By default, True. """