diff --git a/gui/app.py b/gui/app.py index c2ad469..494b20c 100644 --- a/gui/app.py +++ b/gui/app.py @@ -36,7 +36,6 @@ from PyQt5.QtCore import Qt, QThread, pyqtSignal from PyQt5.QtGui import QDragEnterEvent, QDropEvent - STYLE_SHEET = """ QGroupBox { font-weight: bold; @@ -325,6 +324,127 @@ def get_color(row, col): self.error.emit(f"Error: {str(e)}\n{traceback.format_exc()}") +def _run_fusion_pipeline( + tiff_path, + do_registration, + blend_pixels, + downsample_factor, + fusion_mode, + flatfield=None, + darkfield=None, + registration_z=None, + registration_t=0, + registration_channel=0, + log_fn=None, +): + """Shared stitching pipeline used by both single and batch workers. + + Returns the output path string. Raises on failure. + """ + import gc + import json + import shutil + import time + + import numpy as np + from tilefusion import TileFusion + + def log(msg): + if log_fn: + log_fn(msg) + + p = Path(tiff_path) + output_path = p.parent / f"{p.stem}_fused.ome.zarr" + output_folder = p.parent / f"{p.stem}_fused" + + if output_path.exists(): + shutil.rmtree(output_path) + if output_folder.exists(): + shutil.rmtree(output_folder) + + metrics_path = p.parent / "metrics.json" + if metrics_path.exists(): + metrics_path.unlink() + for m in p.parent.glob("metrics_*.json"): + m.unlink() + + step_start = time.time() + tf = TileFusion( + tiff_path, + output_path=output_path, + blend_pixels=blend_pixels, + downsample_factors=(downsample_factor, downsample_factor), + flatfield=flatfield, + darkfield=darkfield, + registration_z=registration_z, + registration_t=registration_t, + channel_to_use=registration_channel, + ) + load_time = time.time() - step_start + log(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X}) [{load_time:.1f}s]") + + if len(tf._unique_regions) > 1: + log(f"Multi-region dataset: {tf._unique_regions}") + tf.stitch_all_regions() + return str(output_folder) + + step_start = time.time() + if do_registration: + log("Computing registration...") + tf.refine_tile_positions_with_cross_correlation() + tf.save_pairwise_metrics(metrics_path) + reg_time = time.time() - step_start + log(f"Registration complete: {len(tf.pairwise_metrics)} pairs [{reg_time:.1f}s]") + else: + tf.threshold = 1.0 + log("Using stage positions (no registration)") + + step_start = time.time() + log("Optimizing positions...") + tf.optimize_shifts(method="TWO_ROUND_ITERATIVE", rel_thresh=0.5, abs_thresh=2.0, iterative=True) + gc.collect() + + tf._tile_positions = [ + tuple(np.array(pos) + off * np.array(tf.pixel_size)) + for pos, off in zip(tf._tile_positions, tf.global_offsets) + ] + opt_time = time.time() - step_start + log(f"Positions optimized [{opt_time:.1f}s]") + + step_start = time.time() + log("Computing fused image space...") + tf._compute_fused_image_space() + tf._pad_to_chunk_multiple() + log(f"Output size: {tf.padded_shape[0]} x {tf.padded_shape[1]}") + + scale0 = output_path / "scale0" / "image" + scale0.parent.mkdir(parents=True, exist_ok=True) + tf._create_fused_tensorstore(output_path=scale0) + + mode_label = "direct placement" if fusion_mode == "direct" else "blended" + log(f"Fusing tiles ({mode_label})...") + tf._fuse_tiles(mode=fusion_mode) + fuse_time = time.time() - step_start + log(f"Tiles fused [{fuse_time:.1f}s]") + + ngff = { + "attributes": {"_ARRAY_DIMENSIONS": ["t", "c", "y", "x"]}, + "zarr_format": 3, + "node_type": "group", + } + with open(output_path / "scale0" / "zarr.json", "w") as f: + json.dump(ngff, f, indent=2) + + step_start = time.time() + log("Building multiscale pyramid...") + tf._create_multiscales(output_path, factors=tf.multiscale_factors) + tf._generate_ngff_zarr3_json(output_path, resolution_multiples=tf.resolution_multiples) + pyramid_time = time.time() - step_start + log(f"Pyramid built [{pyramid_time:.1f}s]") + + return str(output_path) + + class FusionWorker(QThread): """Worker thread for running tile fusion.""" @@ -360,146 +480,129 @@ def __init__( def run(self): try: - from tilefusion import TileFusion - import shutil import time - import json - import gc start_time = time.time() - self.progress.emit(f"Loading {self.tiff_path}...") - output_path = ( - Path(self.tiff_path).parent / f"{Path(self.tiff_path).stem}_fused.ome.zarr" - ) - # Multi-region output folder - output_folder = Path(self.tiff_path).parent / f"{Path(self.tiff_path).stem}_fused" - - # Remove existing outputs if present - if output_path.exists(): - shutil.rmtree(output_path) - if output_folder.exists(): - shutil.rmtree(output_folder) - - # Also remove metrics if not doing registration - metrics_path = Path(self.tiff_path).parent / "metrics.json" - if metrics_path.exists(): - metrics_path.unlink() - # Remove multi-region metrics - for m in Path(self.tiff_path).parent.glob("metrics_*.json"): - m.unlink() - - step_start = time.time() - tf = TileFusion( + self.output_path = _run_fusion_pipeline( self.tiff_path, - output_path=output_path, - blend_pixels=self.blend_pixels, - downsample_factors=(self.downsample_factor, self.downsample_factor), + self.do_registration, + self.blend_pixels, + self.downsample_factor, + self.fusion_mode, flatfield=self.flatfield, darkfield=self.darkfield, registration_z=self.registration_z, registration_t=self.registration_t, - channel_to_use=self.registration_channel, + registration_channel=self.registration_channel, + log_fn=self.progress.emit, ) - load_time = time.time() - step_start - self.progress.emit(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X} each) [{load_time:.1f}s]") - - # Check for multi-region dataset - if len(tf._unique_regions) > 1: - self.progress.emit(f"Multi-region dataset: {tf._unique_regions}") - tf.stitch_all_regions() - # Output folder for multi-region - output_folder = Path(self.tiff_path).parent / f"{Path(self.tiff_path).stem}_fused" - elapsed_time = time.time() - start_time - self.output_path = str(output_folder) - self.finished.emit(str(output_folder), elapsed_time) - return - # Registration step - step_start = time.time() - if self.do_registration: - self.progress.emit("Computing registration...") - tf.refine_tile_positions_with_cross_correlation() - tf.save_pairwise_metrics(metrics_path) - reg_time = time.time() - step_start - self.progress.emit( - f"Registration complete: {len(tf.pairwise_metrics)} pairs [{reg_time:.1f}s]" - ) - else: - tf.threshold = 1.0 # Skip registration - self.progress.emit("Using stage positions (no registration)") + elapsed_time = time.time() - start_time + self.finished.emit(self.output_path, elapsed_time) - # Optimize shifts - step_start = time.time() - self.progress.emit("Optimizing positions...") - tf.optimize_shifts( - method="TWO_ROUND_ITERATIVE", rel_thresh=0.5, abs_thresh=2.0, iterative=True - ) - gc.collect() + except Exception as e: + import traceback - import numpy as np + self.error.emit(f"Error: {str(e)}\n{traceback.format_exc()}") - tf._tile_positions = [ - tuple(np.array(pos) + off * np.array(tf.pixel_size)) - for pos, off in zip(tf._tile_positions, tf.global_offsets) - ] - opt_time = time.time() - step_start - self.progress.emit(f"Positions optimized [{opt_time:.1f}s]") - - # Compute fused space - step_start = time.time() - self.progress.emit("Computing fused image space...") - tf._compute_fused_image_space() - tf._pad_to_chunk_multiple() - self.progress.emit(f"Output size: {tf.padded_shape[0]} x {tf.padded_shape[1]}") - - # Create output store - scale0 = output_path / "scale0" / "image" - scale0.parent.mkdir(parents=True, exist_ok=True) - tf._create_fused_tensorstore(output_path=scale0) - - # Fuse tiles - mode_label = "direct placement" if self.fusion_mode == "direct" else "blended" - self.progress.emit(f"Fusing tiles ({mode_label})...") - tf._fuse_tiles(mode=self.fusion_mode) - fuse_time = time.time() - step_start - self.progress.emit(f"Tiles fused [{fuse_time:.1f}s]") - - # Write metadata - ngff = { - "attributes": {"_ARRAY_DIMENSIONS": ["t", "c", "y", "x"]}, - "zarr_format": 3, - "node_type": "group", - } - with open(output_path / "scale0" / "zarr.json", "w") as f: - json.dump(ngff, f, indent=2) - - # Build multiscales - step_start = time.time() - self.progress.emit("Building multiscale pyramid...") - tf._create_multiscales(output_path, factors=tf.multiscale_factors) - tf._generate_ngff_zarr3_json(output_path, resolution_multiples=tf.resolution_multiples) - pyramid_time = time.time() - step_start - self.progress.emit(f"Pyramid built [{pyramid_time:.1f}s]") - elapsed_time = time.time() - start_time - self.output_path = str(output_path) - self.finished.emit(str(output_path), elapsed_time) +class BatchFusionWorker(QThread): + """Worker thread for batch processing multiple folders/files.""" + + progress = pyqtSignal(str) + item_started = pyqtSignal(int, int, str) # (current_index, total, item_name) + item_finished = pyqtSignal(int, int) # (current_index, total) for progress bar + finished = pyqtSignal(int, int, float) # (succeeded, failed, total_time) + error = pyqtSignal(str) + + def __init__( + self, + paths, + do_registration, + blend_pixels, + downsample_factor, + fusion_mode="blended", + flatfield=None, + darkfield=None, + ): + super().__init__() + self.paths = paths + self.do_registration = do_registration + self.blend_pixels = blend_pixels + self.downsample_factor = downsample_factor + self.fusion_mode = fusion_mode + self.flatfield = flatfield + self.darkfield = darkfield + + def _log(self, index, total, name, message): + self.progress.emit(f"[{index + 1}/{total} {name}] {message}") + def run(self): + try: + self._run_batch() except Exception as e: import traceback - self.error.emit(f"Error: {str(e)}\n{traceback.format_exc()}") + self.error.emit(f"Batch processing failed: {e}\n{traceback.format_exc()}") + self.finished.emit(0, len(self.paths), 0.0) + + def _run_batch(self): + import time + + total = len(self.paths) + succeeded = 0 + failed = 0 + batch_start = time.time() + + for idx, tiff_path in enumerate(self.paths): + name = Path(tiff_path).name + self.item_started.emit(idx, total, name) + + try: + + def log_fn(msg, _idx=idx, _total=total, _name=name): + self._log(_idx, _total, _name, msg) + + _run_fusion_pipeline( + tiff_path, + self.do_registration, + self.blend_pixels, + self.downsample_factor, + self.fusion_mode, + flatfield=self.flatfield, + darkfield=self.darkfield, + log_fn=log_fn, + ) + succeeded += 1 + except MemoryError: + failed += 1 + self._log(idx, total, name, "FAILED: Out of memory. Stopping batch.") + self.item_finished.emit(idx, total) + break + except Exception as e: + import traceback + + failed += 1 + self._log(idx, total, name, f"FAILED: {e}") + self._log(idx, total, name, traceback.format_exc()) + + self.item_finished.emit(idx, total) + + total_time = time.time() - batch_start + self.finished.emit(succeeded, failed, total_time) class DropArea(QFrame): - """Drag and drop area for files or folders.""" + """Drag and drop area for files or folders. Supports single and multi-drop.""" fileDropped = pyqtSignal(str) + filesDropped = pyqtSignal(list) # list of path strings (directories or .tif/.tiff files) _default_style = "border: 2px dashed #888; border-radius: 8px; background: #fafafa;" _hover_style = "border: 2px dashed #0071e3; border-radius: 8px; background: #e8f4ff;" _active_style = "border: 2px solid #34c759; border-radius: 8px; background: #f0fff4;" + _warn_style = "border: 2px solid #ff9500; border-radius: 8px; background: #fff8f0;" def __init__(self): super().__init__() @@ -522,7 +625,11 @@ def __init__(self): self.label.setStyleSheet("border: none; background: transparent;") layout.addWidget(self.label) - self.file_path = None + self.file_paths = [] + + @property + def file_path(self): + return self.file_paths[0] if self.file_paths else None def dragEnterEvent(self, event: QDragEnterEvent): if event.mimeData().hasUrls(): @@ -535,18 +642,36 @@ def dragLeaveEvent(self, event): else: self.setStyleSheet(self._default_style) + def _is_valid_path(self, file_path): + """Check if a path is a valid folder or TIFF file.""" + path = Path(file_path) + return path.is_dir() or file_path.endswith((".tif", ".tiff")) + def dropEvent(self, event: QDropEvent): urls = event.mimeData().urls() - if urls: - file_path = urls[0].toLocalFile() - path = Path(file_path) - if path.is_dir() or file_path.endswith((".tif", ".tiff")): - self.setFile(file_path) - self.fileDropped.emit(file_path) + if not urls: + self.setStyleSheet(self._default_style) + return + + valid_paths = [] + invalid_names = [] + for url in urls: + file_path = url.toLocalFile() + if self._is_valid_path(file_path): + valid_paths.append(file_path) else: - self.setStyleSheet(self._default_style) - else: + invalid_names.append(Path(file_path).name) + + if not valid_paths: self.setStyleSheet(self._default_style) + return + + if len(valid_paths) == 1: + self.setFile(valid_paths[0]) + self.fileDropped.emit(valid_paths[0]) + else: + self.setFiles(valid_paths, invalid_names) + self.filesDropped.emit(valid_paths) def mousePressEvent(self, event): from PyQt5.QtWidgets import QMenu @@ -571,7 +696,7 @@ def mousePressEvent(self, event): self.fileDropped.emit(folder_path) def setFile(self, file_path): - self.file_path = file_path + self.file_paths = [file_path] path = Path(file_path) self.setStyleSheet(self._active_style) self.icon_label.setText("āœ…") @@ -580,6 +705,19 @@ def setFile(self, file_path): else: self.label.setText(path.name) + def setFiles(self, paths, invalid_names=None): + """Set multiple paths and update the display for batch mode.""" + self.file_paths = list(paths) + names = [Path(p).name for p in paths] + label_lines = f"šŸ“¦ {len(paths)} items selected:\n" + "\n".join(f" {n}" for n in names) + if invalid_names: + label_lines += f"\n⚠ Skipped: {', '.join(invalid_names)}" + self.setStyleSheet(self._warn_style) + else: + self.setStyleSheet(self._active_style) + self.icon_label.setText("āœ…") + self.label.setText(label_lines) + class FlatfieldDropArea(QFrame): """Small drag and drop area for flatfield .npy files.""" @@ -731,6 +869,9 @@ def __init__(self): self.regions = [] # List of region names for multi-region outputs self.is_multi_region = False + # Batch processing state + self.batch_paths = [] + # Flatfield correction state self.flatfield = None # Shape (C, Y, X) or None self.darkfield = None # Shape (C, Y, X) or None @@ -754,6 +895,7 @@ def setup_ui(self): # Input drop area (no wrapper group to avoid double border) self.drop_area = DropArea() self.drop_area.fileDropped.connect(self.on_file_dropped) + self.drop_area.filesDropped.connect(self.on_files_dropped) layout.addWidget(self.drop_area) # Preview section @@ -1017,7 +1159,33 @@ def setup_ui(self): layout.addStretch() + @property + def is_batch_mode(self): + return len(self.batch_paths) > 1 + + def _update_batch_mode_ui(self): + """Update UI to reflect batch vs single mode.""" + batch = self.is_batch_mode + self.preview_button.setEnabled(not batch) + self.calc_flatfield_button.setEnabled(not batch and self.drop_area.file_path is not None) + self.reg_zt_widget.setEnabled(not batch) + if batch: + self.preview_button.setToolTip("Preview is not available in batch mode") + self.calc_flatfield_button.setToolTip( + "Calculate flatfield from a single dataset first, then load it for batch" + ) + self.reg_zt_widget.setToolTip("Registration z/t/channel uses defaults in batch mode") + else: + self.preview_button.setToolTip("") + self.calc_flatfield_button.setToolTip("") + self.reg_zt_widget.setToolTip("") + self.napari_button.setToolTip("") + def on_file_dropped(self, file_path): + """Handle single file/folder drop — exits batch mode.""" + self.batch_paths = [] + self._update_batch_mode_ui() + path = Path(file_path) if path.is_dir(): self.log(f"Selected SQUID folder: {file_path}") @@ -1078,6 +1246,58 @@ def on_file_dropped(self, file_path): else: self.flatfield_checkbox.setChecked(False) + def on_files_dropped(self, paths): + """Handle multi-drop — validate each path and enter batch mode.""" + from tilefusion import TileFusion + + self.log_text.clear() + self.log(f"Validating {len(paths)} dropped items...") + + valid_paths = [] + invalid_names = [] + for p in paths: + name = Path(p).name + try: + with TileFusion(p): + pass + valid_paths.append(p) + self.log(f" āœ“ {name}") + except Exception as e: + invalid_names.append(name) + self.log(f" āœ— {name}: {e}") + + if not valid_paths: + self.log("No valid datasets found.") + self.run_button.setEnabled(False) + return + + if invalid_names: + self.log( + f"\n{len(valid_paths)} of {len(paths)} valid. " + f"Skipped: {', '.join(invalid_names)}" + ) + + # Single valid item — fall back to normal single-item flow + if len(valid_paths) == 1: + self.log(f"\nOnly 1 valid item — using single mode.") + self.drop_area.setFile(valid_paths[0]) + self.on_file_dropped(valid_paths[0]) + return + + # Multiple valid items — enter batch mode + self.drop_area.setFiles(valid_paths, invalid_names) + self.batch_paths = valid_paths + self._update_batch_mode_ui() + self.run_button.setEnabled(True) + + self.dataset_n_z = 1 + self.dataset_n_t = 1 + self.dataset_n_channels = 1 + self.dataset_channel_names = [] + + if not invalid_names: + self.log(f"\nAll {len(valid_paths)} items valid. Ready to run batch.") + def on_registration_toggled(self, checked): self.downsample_widget.setVisible(checked) self._update_reg_zt_controls() @@ -1346,6 +1566,12 @@ def run_stitching(self): flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None + if self.is_batch_mode: + self._run_batch(blend_pixels, fusion_mode, flatfield, darkfield) + else: + self._run_single(blend_pixels, fusion_mode, flatfield, darkfield) + + def _run_single(self, blend_pixels, fusion_mode, flatfield, darkfield): # Get registration z/t values (None means use default middle z) registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0 @@ -1370,6 +1596,52 @@ def run_stitching(self): self.worker.error.connect(self.on_fusion_error) self.worker.start() + def _run_batch(self, blend_pixels, fusion_mode, flatfield, darkfield): + total = len(self.batch_paths) + self.progress_bar.setRange(0, total) + self.progress_bar.setValue(0) + self.log(f"Starting batch processing: {total} items\n") + + self.worker = BatchFusionWorker( + self.batch_paths, + self.registration_checkbox.isChecked(), + blend_pixels, + self.downsample_spin.value(), + fusion_mode, + flatfield=flatfield, + darkfield=darkfield, + ) + self.worker.progress.connect(self.log) + self.worker.error.connect(self.on_fusion_error) + self.worker.item_started.connect(self._on_batch_item_started) + self.worker.item_finished.connect(self._on_batch_item_finished) + self.worker.finished.connect(self._on_batch_finished) + self.worker.start() + + def _on_batch_item_started(self, index, total, name): + self.log(f"\n{'='*40}") + self.log(f"Processing {index + 1}/{total}: {name}") + self.log(f"{'='*40}") + + def _on_batch_item_finished(self, index, total): + self.progress_bar.setValue(index + 1) + + def _on_batch_finished(self, succeeded, failed, total_time): + self.progress_bar.setVisible(False) + self.progress_bar.setRange(0, 0) # Reset to indeterminate for next run + self.batch_paths = [] + self.run_button.setEnabled(True) + self.napari_button.setEnabled(True) + self._update_batch_mode_ui() + + minutes = int(total_time // 60) + seconds = total_time % 60 + time_str = f"{minutes}m {seconds:.1f}s" if minutes > 0 else f"{seconds:.1f}s" + + self.log(f"\n{'='*40}") + self.log(f"Batch complete! {succeeded} succeeded, {failed} failed. Total time: {time_str}") + self.log(f"{'='*40}") + def on_fusion_finished(self, output_path, elapsed_time): self.output_path = output_path self.progress_bar.setVisible(False) @@ -1486,6 +1758,13 @@ def _on_region_slider_changed(self, value): def open_in_napari(self): if not self.output_path: + try: + import napari + + napari.Viewer() + napari.run() + except Exception as e: + self.log(f"Error opening Napari: {e}") return # Determine the actual zarr path to open diff --git a/scripts/view_in_napari.py b/scripts/view_in_napari.py index c38628c..a4697d8 100644 --- a/scripts/view_in_napari.py +++ b/scripts/view_in_napari.py @@ -3,6 +3,7 @@ Simple script to view fused OME-Zarr in napari. Works around napari-ome-zarr plugin issues with Zarr v3. """ + import sys from pathlib import Path