diff --git a/software/control/core/multi_point_utils.py b/software/control/core/multi_point_utils.py index 067ce06a6..589aec854 100644 --- a/software/control/core/multi_point_utils.py +++ b/software/control/core/multi_point_utils.py @@ -8,6 +8,7 @@ from squid.abc import CameraFrame if TYPE_CHECKING: + from control.core.qc import FOVMetrics, PolicyDecision from control.slack_notifier import TimepointStats, AcquisitionStats @@ -125,3 +126,6 @@ class MultiPointControllerFunctions: # Zarr frame written callback - called when subprocess completes writing a frame # Args: (fov, time_point, z_index, channel_name, region_idx) signal_zarr_frame_written: Callable[[int, int, int, str, int], None] = lambda *a, **kw: None + # QC callbacks + signal_qc_metrics_updated: Callable[["FOVMetrics"], None] = lambda *a, **kw: None + signal_qc_policy_decision: Callable[["PolicyDecision"], None] = lambda *a, **kw: None diff --git a/software/control/core/multi_point_worker.py b/software/control/core/multi_point_worker.py index fa17650d7..73fceae27 100644 --- a/software/control/core/multi_point_worker.py +++ b/software/control/core/multi_point_worker.py @@ -55,6 +55,7 @@ ensure_plate_resolution_in_well_resolutions, ) from control.core.backpressure import BackpressureController, BackpressureValues +from control.core.qc import QCConfig, QCJob, QCPolicy, QCPolicyConfig, QCResult, TimepointMetricsStore from squid.config import CameraPixelFormat # Module-level logger for static methods @@ -86,6 +87,8 @@ def __init__( slack_notifier=None, prewarmed_job_runner: Optional[JobRunner] = None, prewarmed_bp_values: Optional["BackpressureValues"] = None, + qc_config: Optional[QCConfig] = None, + qc_policy_config: Optional[QCPolicyConfig] = None, ): self._log = squid.logging.get_logger(__class__.__name__) self._timing = utils.TimingManager("MultiPointWorker Timer Manager") @@ -161,6 +164,14 @@ def __init__( self.num_fovs = 0 self.total_scans = 0 self._last_time_point_z_pos = {} + self._qc_config = qc_config 
or QCConfig() + self._qc_policy_config = qc_policy_config or QCPolicyConfig() + if self._qc_policy_config.enabled and not self._qc_config.enabled: + self._log.warning("QC policy is enabled but QC metrics collection is disabled — policy checks will not run") + self._qc_policy = QCPolicy(self._qc_policy_config) if self._qc_policy_config.enabled else None + self._metrics_store: Optional[TimepointMetricsStore] = None + self._qc_pause_event = threading.Event() + self._qc_pause_event.set() # starts unpaused self.scan_region_fov_coords_mm = ( acquisition_parameters.scan_position_information.scan_region_fov_coords_mm.copy() ) @@ -264,7 +275,7 @@ def __init__( # For now, use 1 runner per job class. There's no real reason/rationale behind this, though. The runners # can all run any job type. But 1 per is a reasonable arbitrary arrangement while we don't have a lot # of job types. If we have a lot of custom jobs, this could cause problems via resource hogging. - self._job_runners: List[Tuple[Type[Job], JobRunner]] = [] + self._job_runners: List[Tuple[Type[Job], Optional[JobRunner]]] = [] self._log.info(f"Acquisition.USE_MULTIPROCESSING = {Acquisition.USE_MULTIPROCESSING}") # Get the current log file path to share with subprocess workers @@ -370,6 +381,10 @@ def __init__( # Subprocess starts warming up in background - don't block here self._job_runners.append((job_class, job_runner)) + + if self._qc_config.enabled: + self._job_runners.append((QCJob, None)) + self._abort_on_failed_job = abort_on_failed_jobs self._first_job_dispatched = False # Track if we've waited for subprocess warmup @@ -644,6 +659,8 @@ def run_single_time_point(self): self._timepoint_fov_count = 0 self._laser_af_successes = 0 self._laser_af_failures = 0 + if self._qc_config.enabled: + self._metrics_store = TimepointMetricsStore(timepoint_index=self.time_point) self.microcontroller.enable_joystick(False) self._log.debug("multipoint acquisition - time point " + str(self.time_point + 1)) @@ -663,6 +680,26 @@ 
def run_single_time_point(self): with self._timing.get_timer("run_coordinate_acquisition"): self.run_coordinate_acquisition(current_path) + # QC policy check + if self._qc_policy is not None and self._qc_policy_config.check_after_timepoint: + if self._metrics_store is not None: + try: + decision = self._qc_policy.check_timepoint(self._metrics_store) + self.callbacks.signal_qc_policy_decision(decision) + if decision.should_pause: + self._log.info( + f"QC policy flagged {len(decision.flagged_fovs)} FOVs — " + f"pausing acquisition. Call resume_from_qc_pause() to continue." + ) + self._qc_pause_event.clear() + # Block until resumed or aborted + while not self._qc_pause_event.is_set(): + if self.abort_requested_fn(): + break + self._qc_pause_event.wait(timeout=0.5) + except Exception as e: + self._log.error(f"QC policy evaluation failed for timepoint {self.time_point}: {e}") + # Save plate view for this timepoint if self._generate_downsampled_views and self._downsampled_view_manager is not None: # Wait for pending downsampled view jobs to complete @@ -678,6 +715,14 @@ def run_single_time_point(self): # finished region scan self.coordinates_pd.to_csv(os.path.join(current_path, "coordinates.csv"), index=False, header=True) + # Save QC metrics + if self._qc_config.enabled and self._metrics_store is not None: + qc_csv_path = os.path.join(current_path, "qc_metrics.csv") + try: + self._metrics_store.save(qc_csv_path) + except OSError as e: + self._log.error(f"Failed to save QC metrics to {qc_csv_path}: {e}") + # Send Slack timepoint notification via callback (allows main thread to capture screenshot) if self._slack_notifier is not None: try: @@ -805,6 +850,23 @@ def _summarize_runner_outputs(self, drain_all: bool = False) -> SummarizeResult: return SummarizeResult(none_failed=none_failed, had_results=had_results) + def resume_from_qc_pause(self) -> None: + """Resume acquisition after QC policy pause. 
Called by UI.""" + self._log.info("Resuming acquisition from QC pause") + self._qc_pause_event.set() + + def _handle_qc_result(self, qc_result: QCResult) -> None: + """Store QC metrics and emit signal.""" + if qc_result.error: + self._log.error( + f"QC metric calculation failed for region={qc_result.metrics.fov_id.region_id} " + f"fov={qc_result.metrics.fov_id.fov_index}: {qc_result.error}" + ) + # Always store metrics (positional data is valid even on partial failure) + if self._metrics_store is not None: + self._metrics_store.add(qc_result.metrics) + self.callbacks.signal_qc_metrics_updated(qc_result.metrics) + def _summarize_job_result(self, job_result: JobResult) -> bool: """ Prints a summary, then returns True if the result was successful or False otherwise. @@ -833,6 +895,9 @@ def _summarize_job_result(self, job_result: JobResult) -> bool: elif isinstance(job_result.result, ZarrWriteResult): r = job_result.result self.callbacks.signal_zarr_frame_written(r.fov, r.time_point, r.z_index, r.channel_name, r.region_idx) + # Handle QCResult - store metrics and emit signal + elif isinstance(job_result.result, QCResult): + self._handle_qc_result(job_result.result) return True def _handle_downsampled_view_result(self, result: DownsampledViewResult) -> None: @@ -888,9 +953,31 @@ def _create_job(self, job_class: Type[Job], info: CaptureInfo, image: np.ndarray """ if job_class == DownsampledViewJob: return self._create_downsampled_view_job(info, image) + elif job_class == QCJob: + return self._create_qc_job(info, image) else: return job_class(capture_info=info, capture_image=JobImage(image_array=image)) + def _create_qc_job(self, info: CaptureInfo, image: np.ndarray) -> Optional[QCJob]: + """Create a QCJob for the given capture. + + Returns None for non-canonical frames to avoid overwriting metrics. + Only the configured channel and z-slice is used for QC. 
+ """ + if info.z_index != self._qc_config.qc_z_index or info.configuration_idx != self._qc_config.qc_channel_index: + return None + previous_z = None + if self._qc_config.calculate_z_diff_from_last_timepoint and self.time_point > 0: + fov_key = (info.region_id, info.fov) + if fov_key in self._last_time_point_z_pos: + previous_z = self._last_time_point_z_pos[fov_key] * 1000 # mm -> um + return QCJob( + capture_info=info, + capture_image=JobImage(image_array=image), + qc_config=self._qc_config, + previous_timepoint_z=previous_z, + ) + def _create_downsampled_view_job(self, info: CaptureInfo, image: np.ndarray) -> Optional[DownsampledViewJob]: """Create a DownsampledViewJob for the given capture. @@ -1424,9 +1511,9 @@ def _image_callback(self, camera_frame: CameraFrame): return else: try: - # NOTE(imo): We don't have any way of people using results, so for now just - # grab and ignore it. result = job.run() + if isinstance(result, QCResult): + self._handle_qc_result(result) except Exception: self._log.exception("Failed to execute job, abandoning acquisition!") self.request_abort_fn() diff --git a/software/control/core/qc.py b/software/control/core/qc.py new file mode 100644 index 000000000..afd1358b2 --- /dev/null +++ b/software/control/core/qc.py @@ -0,0 +1,306 @@ +"""Quality Control system for acquisition. + +Collects per-FOV metrics during acquisition, stores them per-timepoint, +and applies configurable policies to flag FOVs and optionally pause. 
+""" + +from __future__ import annotations + +import csv +import enum +import threading +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import cv2 +import numpy as np + +from control.core.job_processing import CaptureInfo, Job, JobImage + + +class FocusScoreMethod(str, enum.Enum): + """Available focus score calculation methods.""" + + LAPLACIAN_VARIANCE = "laplacian_variance" + NORMALIZED_VARIANCE = "normalized_variance" + GRADIENT_MAGNITUDE = "gradient_magnitude" + FFT_HIGH_FREQ = "fft_high_freq" + + +class QCMetricField(str, enum.Enum): + """Valid metric field names on FOVMetrics for outlier detection.""" + + FOCUS_SCORE = "focus_score" + LASER_AF_DISPLACEMENT_UM = "laser_af_displacement_um" + Z_DIFF_FROM_LAST_TIMEPOINT_UM = "z_diff_from_last_timepoint_um" + + +def calculate_focus_score(image: np.ndarray, method: FocusScoreMethod = FocusScoreMethod.LAPLACIAN_VARIANCE) -> float: + """Calculate focus score for an image. + + Args: + image: 2D grayscale or multichannel image (first channel used if multichannel). + method: Focus score method to use. + + Returns: + Focus score — higher means more in focus. 
+ """ + if image.ndim == 3: + image = image[:, :, 0] + + method = FocusScoreMethod(method) # accept string or enum + + if method == FocusScoreMethod.LAPLACIAN_VARIANCE: + laplacian = cv2.Laplacian(image, cv2.CV_64F) + return float(laplacian.var()) + + elif method == FocusScoreMethod.NORMALIZED_VARIANCE: + mean = image.mean() + if mean == 0: + return 0.0 + return float(image.var() / mean) + + elif method == FocusScoreMethod.GRADIENT_MAGNITUDE: + img_f = image.astype(np.float64) + gy = np.gradient(img_f, axis=0) + gx = np.gradient(img_f, axis=1) + return float(np.sqrt(gx**2 + gy**2).mean()) + + elif method == FocusScoreMethod.FFT_HIGH_FREQ: + fft = np.fft.fft2(image.astype(np.float64)) + fft_shift = np.fft.fftshift(fft) + h, w = image.shape[:2] + cy, cx = h // 2, w // 2 + mask_size = min(h, w) // 8 + fft_shift[cy - mask_size : cy + mask_size, cx - mask_size : cx + mask_size] = 0 + return float(np.abs(fft_shift).mean()) + + else: + raise ValueError(f"Unknown focus method: {method}") + + +@dataclass(frozen=True) +class FOVIdentifier: + """Identifies a single FOV within an acquisition.""" + + region_id: str + fov_index: int + + +@dataclass +class FOVMetrics: + """QC metrics for a single FOV.""" + + fov_id: FOVIdentifier + timestamp: float + z_position_um: float + + focus_score: Optional[float] = None + laser_af_displacement_um: Optional[float] = None + z_diff_from_last_timepoint_um: Optional[float] = None + + +@dataclass +class QCConfig: + """Configuration for QC metrics collection.""" + + enabled: bool = False + calculate_focus_score: bool = True + record_laser_af_displacement: bool = False + calculate_z_diff_from_last_timepoint: bool = False + focus_score_method: FocusScoreMethod = FocusScoreMethod.LAPLACIAN_VARIANCE + # Which channel and z-slice to run QC on + qc_channel_index: int = 0 + qc_z_index: int = 0 + + +@dataclass +class QCResult: + """Result from QC job.""" + + metrics: FOVMetrics + error: Optional[str] = None + + +@dataclass +class QCJob(Job[QCResult]): + 
"""Quality control job for a single FOV. + + Calculates configured metrics and returns them as QCResult. + Runs in JobRunner subprocess (when multiprocessing enabled) or inline. + """ + + qc_config: QCConfig = field(default_factory=QCConfig) + previous_timepoint_z: Optional[float] = None + + def run(self) -> QCResult: + image = self.image_array() + metrics = FOVMetrics( + fov_id=FOVIdentifier( + region_id=str(self.capture_info.region_id), + fov_index=self.capture_info.fov, + ), + timestamp=self.capture_info.capture_time, + z_position_um=self.capture_info.position.z_mm * 1000, + ) + + try: + if self.qc_config.calculate_focus_score: + metrics.focus_score = calculate_focus_score(image, self.qc_config.focus_score_method) + + if self.qc_config.record_laser_af_displacement: + metrics.laser_af_displacement_um = self.capture_info.z_piezo_um + + if self.previous_timepoint_z is not None: + metrics.z_diff_from_last_timepoint_um = metrics.z_position_um - self.previous_timepoint_z + except Exception as e: + return QCResult(metrics=metrics, error=f"QC metric calculation failed: {e}") + + return QCResult(metrics=metrics) + + +@dataclass +class QCPolicyConfig: + """Configuration for QC policy decisions.""" + + enabled: bool = False + check_after_timepoint: bool = True + focus_score_min: Optional[float] = None + z_drift_max_um: Optional[float] = None + detect_outliers: bool = False + outlier_metric: QCMetricField = QCMetricField.FOCUS_SCORE + outlier_std_threshold: float = 2.0 + pause_if_any_flagged: bool = True + + +class TimepointMetricsStore: + """Stores QC metrics for a single timepoint. 
Thread-safe.""" + + def __init__(self, timepoint_index: int): + self._timepoint = timepoint_index + self._metrics: Dict[FOVIdentifier, FOVMetrics] = {} + self._lock = threading.Lock() + + def add(self, metrics: FOVMetrics) -> None: + with self._lock: + self._metrics[metrics.fov_id] = metrics + + def get(self, fov_id: FOVIdentifier) -> Optional[FOVMetrics]: + with self._lock: + return self._metrics.get(fov_id) + + def get_all(self) -> List[FOVMetrics]: + with self._lock: + return list(self._metrics.values()) + + def get_metric_values(self, metric: QCMetricField) -> Dict[FOVIdentifier, float]: + metric = QCMetricField(metric) # validate — raises ValueError on bad input + with self._lock: + result = {} + for fov_id, m in self._metrics.items(): + value = getattr(m, metric.value, None) + if value is not None: + result[fov_id] = value + return result + + def save(self, path: str) -> None: + """Save metrics to CSV.""" + with self._lock: + metrics_list = list(self._metrics.values()) + if not metrics_list: + return + # Keep in sync with FOVMetrics fields (flattening fov_id into region_id + fov_index) + fieldnames = [ + "region_id", + "fov_index", + "timestamp", + "z_position_um", + "focus_score", + "laser_af_displacement_um", + "z_diff_from_last_timepoint_um", + ] + with open(path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for m in metrics_list: + writer.writerow( + { + "region_id": m.fov_id.region_id, + "fov_index": m.fov_id.fov_index, + "timestamp": m.timestamp, + "z_position_um": m.z_position_um, + "focus_score": m.focus_score, + "laser_af_displacement_um": m.laser_af_displacement_um, + "z_diff_from_last_timepoint_um": m.z_diff_from_last_timepoint_um, + } + ) + + +@dataclass +class PolicyDecision: + """Result of QC policy evaluation.""" + + flagged_fovs: List[FOVIdentifier] + flag_reasons: Dict[FOVIdentifier, List[str]] + should_pause: bool + + +class QCPolicy: + """Evaluates QC metrics against configured 
rules.""" + + def __init__(self, config: QCPolicyConfig): + self._config = config + + def check_timepoint(self, metrics_store: TimepointMetricsStore) -> PolicyDecision: + if not self._config.enabled: + return PolicyDecision(flagged_fovs=[], flag_reasons={}, should_pause=False) + + flagged_set: set = set() + reasons: Dict[FOVIdentifier, List[str]] = {} + all_metrics = metrics_store.get_all() + + if self._config.focus_score_min is not None: + for m in all_metrics: + if m.focus_score is not None and m.focus_score < self._config.focus_score_min: + flagged_set.add(m.fov_id) + reasons.setdefault(m.fov_id, []).append( + f"focus_score={m.focus_score:.2f} < {self._config.focus_score_min}" + ) + + if self._config.z_drift_max_um is not None: + for m in all_metrics: + if m.z_diff_from_last_timepoint_um is not None: + if abs(m.z_diff_from_last_timepoint_um) > self._config.z_drift_max_um: + flagged_set.add(m.fov_id) + reasons.setdefault(m.fov_id, []).append( + f"z_drift={m.z_diff_from_last_timepoint_um:.2f}um > {self._config.z_drift_max_um}" + ) + + if self._config.detect_outliers: + for fov_id in self._detect_outliers( + metrics_store, self._config.outlier_metric, self._config.outlier_std_threshold + ): + flagged_set.add(fov_id) + reasons.setdefault(fov_id, []).append(f"outlier in {self._config.outlier_metric}") + + should_pause = self._config.pause_if_any_flagged and len(flagged_set) > 0 + return PolicyDecision(flagged_fovs=list(flagged_set), flag_reasons=reasons, should_pause=should_pause) + + def _detect_outliers( + self, metrics_store: TimepointMetricsStore, metric: QCMetricField, std_threshold: float + ) -> List[FOVIdentifier]: + values = metrics_store.get_metric_values(metric) + if len(values) < 3: + return [] + arr = np.array(list(values.values())) + finite_mask = np.isfinite(arr) + if not finite_mask.all(): + arr = arr[finite_mask] + # Rebuild values dict keeping only finite entries + values = {fov_id: v for fov_id, v in values.items() if np.isfinite(v)} + if len(arr) < 
3: + return [] + mean, std = arr.mean(), arr.std() + if std == 0: + return [] + return [fov_id for fov_id, value in values.items() if abs(value - mean) > std_threshold * std] diff --git a/software/tests/control/core/test_qc.py b/software/tests/control/core/test_qc.py new file mode 100644 index 000000000..d67ceb385 --- /dev/null +++ b/software/tests/control/core/test_qc.py @@ -0,0 +1,419 @@ +import time + +import numpy as np +import pytest + +import squid.abc +from control.core.job_processing import CaptureInfo, JobImage +from control.core.multi_point_utils import MultiPointControllerFunctions +from control.core.qc import ( + FOVIdentifier, + FOVMetrics, + FocusScoreMethod, + PolicyDecision, + QCConfig, + QCJob, + QCMetricField, + QCPolicy, + QCPolicyConfig, + QCResult, + TimepointMetricsStore, + calculate_focus_score, +) +from control.models import AcquisitionChannel, CameraSettings, IlluminationSettings + + +def make_test_capture_info(region_id="A1", fov=0, z_mm=1.0, z_piezo_um=None) -> CaptureInfo: + return CaptureInfo( + position=squid.abc.Pos(x_mm=0.0, y_mm=0.0, z_mm=z_mm, theta_rad=None), + z_index=0, + capture_time=time.time(), + configuration=AcquisitionChannel( + name="BF LED matrix full", + display_color="#FFFFFF", + camera=1, + illumination_settings=IlluminationSettings( + illumination_channel="BF LED matrix full", + intensity=50.0, + ), + camera_settings=CameraSettings(exposure_time_ms=10.0, gain_mode=1.0), + z_offset_um=0.0, + ), + save_directory="/tmp/test", + file_id="test_0_0", + region_id=region_id, + fov=fov, + configuration_idx=0, + z_piezo_um=z_piezo_um, + ) + + +class TestFOVIdentifier: + def test_create(self): + fov_id = FOVIdentifier(region_id="A1", fov_index=3) + assert fov_id.region_id == "A1" + assert fov_id.fov_index == 3 + + def test_hashable_as_dict_key(self): + a = FOVIdentifier(region_id="A1", fov_index=0) + b = FOVIdentifier(region_id="A1", fov_index=0) + assert a == b + assert hash(a) == hash(b) + assert {a: "val"}[b] == "val" + + 
    def test_different_fovs_not_equal(self):
        assert FOVIdentifier("A1", 0) != FOVIdentifier("A1", 1)


class TestFOVMetrics:
    def test_required_fields_only(self):
        # Optional metric fields must default to None until computed.
        m = FOVMetrics(fov_id=FOVIdentifier("A1", 0), timestamp=1000.0, z_position_um=100.0)
        assert m.focus_score is None
        assert m.laser_af_displacement_um is None
        assert m.z_diff_from_last_timepoint_um is None

    def test_all_fields(self):
        m = FOVMetrics(
            fov_id=FOVIdentifier("B2", 5),
            timestamp=1000.0,
            z_position_um=150.0,
            focus_score=42.5,
            laser_af_displacement_um=0.3,
            z_diff_from_last_timepoint_um=-1.2,
        )
        assert m.focus_score == 42.5
        assert m.laser_af_displacement_um == 0.3
        assert m.z_diff_from_last_timepoint_um == -1.2


class TestQCConfig:
    def test_defaults(self):
        # QC collection is opt-in: everything off by default except focus score.
        c = QCConfig()
        assert c.enabled is False
        assert c.calculate_focus_score is True
        assert c.record_laser_af_displacement is False
        assert c.calculate_z_diff_from_last_timepoint is False
        assert c.focus_score_method == FocusScoreMethod.LAPLACIAN_VARIANCE
        assert c.qc_channel_index == 0


class TestQCPolicyConfig:
    def test_defaults(self):
        # Policy is opt-in; no thresholds configured by default.
        c = QCPolicyConfig()
        assert c.enabled is False
        assert c.check_after_timepoint is True
        assert c.focus_score_min is None
        assert c.z_drift_max_um is None
        assert c.detect_outliers is False
        assert c.outlier_metric == QCMetricField.FOCUS_SCORE
        assert c.outlier_std_threshold == 2.0
        assert c.pause_if_any_flagged is True


class TestCalculateFocusScore:
    def _sharp_image(self):
        # Alternating black/white rows: maximal high-frequency content.
        img = np.zeros((100, 100), dtype=np.uint8)
        img[::2, :] = 255
        return img

    def _uniform_image(self):
        return np.ones((100, 100), dtype=np.uint8) * 128

    def test_laplacian_variance_positive_for_sharp(self):
        assert calculate_focus_score(self._sharp_image(), method="laplacian_variance") > 0

    def test_laplacian_variance_near_zero_for_uniform(self):
        assert calculate_focus_score(self._uniform_image(), method="laplacian_variance") < 1.0

    def test_normalized_variance(self):
        assert calculate_focus_score(self._sharp_image(), method="normalized_variance") > 0

    def test_normalized_variance_zero_mean_returns_zero(self):
        # All-black image: guard against division by zero.
        assert calculate_focus_score(np.zeros((100, 100), dtype=np.uint8), method="normalized_variance") == 0.0

    def test_gradient_magnitude(self):
        assert calculate_focus_score(self._sharp_image(), method="gradient_magnitude") > 0

    def test_fft_high_freq(self):
        assert calculate_focus_score(self._sharp_image(), method="fft_high_freq") > 0

    def test_unknown_method_raises(self):
        with pytest.raises(ValueError):
            calculate_focus_score(np.zeros((10, 10), dtype=np.uint8), method="nonexistent")

    def test_sharp_scores_higher_than_uniform(self):
        # Sanity ordering: any default method ranks sharp above uniform.
        assert calculate_focus_score(self._sharp_image()) > calculate_focus_score(self._uniform_image())

    def test_multichannel_uses_first_channel(self):
        # Signal only in channel 0; score must reflect it.
        rgb = np.zeros((100, 100, 3), dtype=np.uint8)
        rgb[::2, :, 0] = 255
        score = calculate_focus_score(rgb)
        assert score > 0


class TestQCJob:
    def test_run_calculates_focus_score(self):
        image = np.zeros((100, 100), dtype=np.uint8)
        image[::2, :] = 255
        job = QCJob(
            capture_info=make_test_capture_info(region_id="A1", fov=3, z_mm=1.5),
            capture_image=JobImage(image_array=image),
            qc_config=QCConfig(enabled=True, calculate_focus_score=True),
        )
        result = job.run()
        assert isinstance(result, QCResult)
        assert result.metrics.fov_id == FOVIdentifier(region_id="A1", fov_index=3)
        # z_mm=1.5 -> 1500 um.
        assert result.metrics.z_position_um == 1500.0
        assert result.metrics.focus_score > 0
        assert result.error is None

    def test_run_without_focus_score(self):
        job = QCJob(
            capture_info=make_test_capture_info(),
            capture_image=JobImage(image_array=np.zeros((10, 10), dtype=np.uint8)),
            qc_config=QCConfig(enabled=True, calculate_focus_score=False),
        )
        assert job.run().metrics.focus_score is None

    def test_run_records_laser_af_displacement(self):
        job = QCJob(
            capture_info=make_test_capture_info(z_piezo_um=2.5),
            capture_image=JobImage(image_array=np.zeros((10, 10), dtype=np.uint8)),
            qc_config=QCConfig(enabled=True, record_laser_af_displacement=True, calculate_focus_score=False),
        )
        assert job.run().metrics.laser_af_displacement_um == 2.5

    def test_run_calculates_z_diff(self):
        # 1.5 mm current z vs 1490 um previous -> +10 um drift.
        job = QCJob(
            capture_info=make_test_capture_info(z_mm=1.5),
            capture_image=JobImage(image_array=np.zeros((10, 10), dtype=np.uint8)),
            qc_config=QCConfig(enabled=True, calculate_focus_score=False),
            previous_timepoint_z=1490.0,
        )
        assert job.run().metrics.z_diff_from_last_timepoint_um == pytest.approx(10.0)

    def test_run_no_z_diff_without_previous(self):
        job = QCJob(
            capture_info=make_test_capture_info(z_mm=1.5),
            capture_image=JobImage(image_array=np.zeros((10, 10), dtype=np.uint8)),
            qc_config=QCConfig(enabled=True, calculate_focus_score=False),
        )
        assert job.run().metrics.z_diff_from_last_timepoint_um is None

    def test_runs_in_job_runner(self):
        """QCJob must work through JobRunner subprocess (picklable)."""
        from control.core.job_processing import JobRunner

        image = np.zeros((50, 50), dtype=np.uint8)
        image[::2, :] = 255
        job = QCJob(
            capture_info=make_test_capture_info(),
            capture_image=JobImage(image_array=image),
            qc_config=QCConfig(enabled=True),
        )
        runner = JobRunner()
        runner.daemon = True
        runner.start()
        assert runner.wait_ready(timeout_s=5.0)
        runner.dispatch(job)
        result = runner.output_queue().get(timeout=5.0)
        runner.shutdown(timeout_s=2.0)
        assert result.exception is None
        assert result.result.metrics.focus_score > 0


def _make_metrics(region_id="A1", fov_index=0, focus_score=100.0, z_um=1000.0, z_diff=None):
    # Convenience factory for store/policy tests.
    return FOVMetrics(
        fov_id=FOVIdentifier(region_id=region_id, fov_index=fov_index),
        timestamp=time.time(),
        z_position_um=z_um,
        focus_score=focus_score,
        z_diff_from_last_timepoint_um=z_diff,
    )


class TestTimepointMetricsStore:
    def test_add_and_get(self):
        store = TimepointMetricsStore(timepoint_index=0)
        m = _make_metrics("A1", 0)
        store.add(m)
        # get() should return the exact object, not a copy.
        assert store.get(FOVIdentifier("A1", 0)) is m

    def test_get_missing_returns_none(self):
        store = TimepointMetricsStore(timepoint_index=0)
        assert store.get(FOVIdentifier("A1", 99)) is None

    def test_get_all(self):
        store = TimepointMetricsStore(timepoint_index=0)
        m1 = _make_metrics("A1", 0)
        m2 = _make_metrics("A1", 1)
        store.add(m1)
        store.add(m2)
        all_m = store.get_all()
        assert len(all_m) == 2
        assert m1 in all_m and m2 in all_m

    def test_get_metric_values_skips_none(self):
        store = TimepointMetricsStore(timepoint_index=0)
        store.add(_make_metrics("A1", 0, focus_score=100.0))
        store.add(_make_metrics("A1", 1, focus_score=200.0))
        store.add(_make_metrics("A1", 2, focus_score=None))
        values = store.get_metric_values("focus_score")
        assert len(values) == 2
        assert values[FOVIdentifier("A1", 0)] == 100.0
        assert values[FOVIdentifier("A1", 1)] == 200.0

    def test_overwrite_on_duplicate_fov(self):
        # Re-adding the same FOV replaces the earlier entry.
        store = TimepointMetricsStore(timepoint_index=0)
        store.add(_make_metrics("A1", 0, focus_score=100.0))
        store.add(_make_metrics("A1", 0, focus_score=200.0))
        assert store.get(FOVIdentifier("A1", 0)).focus_score == 200.0
        assert len(store.get_all()) == 1

    def test_save_csv(self, tmp_path):
        import csv

        store = TimepointMetricsStore(timepoint_index=0)
        store.add(_make_metrics("A1", 0, focus_score=100.0, z_um=1500.0))
        store.add(_make_metrics("A1", 1, focus_score=200.0, z_um=1510.0))
        csv_path = str(tmp_path / "qc_metrics.csv")
        store.save(csv_path)

        with open(csv_path) as f:
            rows = list(csv.DictReader(f))
        assert len(rows) == 2
        assert set(rows[0].keys()) >= {"region_id", "fov_index", "focus_score", "z_position_um"}


class TestQCPolicy:
    def _store_with(self, metrics_list):
        # Helper: wrap a list of FOVMetrics in a populated store.
        store = TimepointMetricsStore(timepoint_index=0)
        for m in metrics_list:
            store.add(m)
        return store

    def test_no_rules_no_flags(self):
        policy = QCPolicy(QCPolicyConfig(enabled=True))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, focus_score=50.0),
                    _make_metrics("A1", 1, focus_score=100.0),
                ]
            )
        )
        assert decision.flagged_fovs == []
        assert decision.should_pause is False

    def test_focus_score_threshold(self):
        policy = QCPolicy(QCPolicyConfig(enabled=True, focus_score_min=80.0))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, focus_score=50.0),
                    _make_metrics("A1", 1, focus_score=100.0),
                    _make_metrics("A1", 2, focus_score=79.9),
                ]
            )
        )
        assert len(decision.flagged_fovs) == 2
        assert FOVIdentifier("A1", 0) in decision.flagged_fovs
        assert FOVIdentifier("A1", 2) in decision.flagged_fovs
        assert decision.should_pause is True

    def test_z_drift_threshold(self):
        # None z_diff (first timepoint) must not be flagged.
        policy = QCPolicy(QCPolicyConfig(enabled=True, z_drift_max_um=5.0))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, z_diff=2.0),
                    _make_metrics("A1", 1, z_diff=-6.0),
                    _make_metrics("A1", 2, z_diff=None),
                ]
            )
        )
        assert decision.flagged_fovs == [FOVIdentifier("A1", 1)]

    def test_outlier_detection(self):
        policy = QCPolicy(
            QCPolicyConfig(
                enabled=True,
                detect_outliers=True,
                outlier_metric="focus_score",
                outlier_std_threshold=2.0,
            )
        )
        # Nine identical scores plus one far-off value: only the latter is an outlier.
        metrics = [_make_metrics("A1", i, focus_score=100.0) for i in range(9)]
        metrics.append(_make_metrics("A1", 9, focus_score=10.0))
        decision = policy.check_timepoint(self._store_with(metrics))
        assert FOVIdentifier("A1", 9) in decision.flagged_fovs

    def test_outlier_needs_minimum_3_fovs(self):
        policy = QCPolicy(QCPolicyConfig(enabled=True, detect_outliers=True))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, focus_score=100.0),
                    _make_metrics("A1", 1, focus_score=10.0),
                ]
            )
        )
        assert decision.flagged_fovs == []

    def test_pause_if_any_flagged_false(self):
        # Flags still reported, but pause suppressed.
        policy = QCPolicy(QCPolicyConfig(enabled=True, focus_score_min=80.0, pause_if_any_flagged=False))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, focus_score=50.0),
                ]
            )
        )
        assert len(decision.flagged_fovs) == 1
        assert decision.should_pause is False

    def test_flag_reasons_populated(self):
        policy = QCPolicy(QCPolicyConfig(enabled=True, focus_score_min=80.0, z_drift_max_um=5.0))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, focus_score=50.0, z_diff=10.0),
                ]
            )
        )
        reasons = decision.flag_reasons[FOVIdentifier("A1", 0)]
        assert len(reasons) == 2
        assert any("focus_score" in r for r in reasons)
        assert any("z_drift" in r for r in reasons)

    def test_fov_not_duplicated_across_rules(self):
        """An FOV failing multiple rules should appear once in flagged_fovs."""
        policy = QCPolicy(QCPolicyConfig(enabled=True, focus_score_min=80.0, z_drift_max_um=5.0))
        decision = policy.check_timepoint(
            self._store_with(
                [
                    _make_metrics("A1", 0, focus_score=50.0, z_diff=10.0),
                ]
            )
        )
        assert decision.flagged_fovs.count(FOVIdentifier("A1", 0)) == 1


class TestQCSignals:
    def test_qc_signals_have_noop_defaults(self):
        """New QC signals must default to no-ops so existing callers don't break."""
        callbacks = MultiPointControllerFunctions(
            signal_acquisition_start=lambda *a, **kw: None,
            signal_acquisition_finished=lambda *a, **kw: None,
            signal_new_image=lambda *a, **kw: None,
            signal_current_configuration=lambda *a, **kw: None,
            signal_current_fov=lambda *a, **kw: None,
            signal_overall_progress=lambda *a, **kw: None,
            signal_region_progress=lambda *a, **kw: None,
        )
        # Should be callable without error
        m = FOVMetrics(fov_id=FOVIdentifier("A1", 0), timestamp=0.0, z_position_um=0.0)
        callbacks.signal_qc_metrics_updated(m)

        d = PolicyDecision(flagged_fovs=[], flag_reasons={}, should_pause=False)
        callbacks.signal_qc_policy_decision(d)