diff --git a/PERFORMANCE_PLAN.md b/PERFORMANCE_PLAN.md new file mode 100644 index 00000000..5dfad240 --- /dev/null +++ b/PERFORMANCE_PLAN.md @@ -0,0 +1,295 @@ +# Plottr Performance & UX Improvements + +This document summarizes the changes in this PR, the profiling that motivated them, +and suggestions for future work. + +--- + +## Part 1: Implemented — Pipeline Performance (datadict, nodes, gridding) + +### Problem + +Plottr's data pipeline copied data excessively as it flowed through nodes. Each node +defensively deep-copied all data, and internal methods (`structure()`, `validate()`, +`copy()`) added further redundant copies. For a 100x100x100 MeshgridDataDict (~38 MB), +a single `copy()` took 92 ms and `validate()` took 43 ms. + +### What Changed + +**`plottr/data/datadict.py`** (core data container): +- New `_copy_field()` helper with per-key copy semantics: numpy `.copy()` for arrays, + `list()` for axes, `deepcopy` only for mutable metadata +- Rewrote `copy(deep=True/False)` — no longer chains through `structure()` → `validate()` + → `deepcopy`. 
New `deep=False` shares arrays (xarray-style API, backward compatible) +- `_build_structure()` private helper that skips redundant validation +- `MeshgridDataDict.validate()` monotonicity check: replaced `np.unique(np.sign(np.diff(...)))` + with direct min/max checks — same coverage, no sort/allocate +- `mask_invalid()` fast-path: skips masking entirely when data has no invalid entries +- `shapes()` uses `np.shape()` instead of `np.array(...).shape` +- `datasets_are_equal()` shape short-circuit + set-based comparison +- `remove_invalid_entries()` fixed O(n²) `np.append` pattern + fixed crash on inhomogeneous arrays +- `meshgrid_to_datadict()` / `datadict_to_dataframe()`: `ravel()` instead of `flatten()` + +**`plottr/utils/num.py`** (numerical utilities): +- `largest_numtype()`: dtype check instead of iterating every element as Python object (~15,000× faster) +- `is_invalid()`: skip zero-array allocation for non-float types +- `guess_grid_from_sweep_direction()`: convert with `np.asarray()` once instead of 4× +- `_find_switches()`: compute `is_invalid()` once (was 3×), single `np.percentile([lo,hi])` call + (was 2 separate sorts), vectorized boolean filter, `np.nanmean` for NaN-safe sweep direction + +**`plottr/node/node.py`**: Defer `structure()` call to only when structure actually changes (50× faster steady-state) + +**`plottr/node/dim_reducer.py`**: Removed redundant `copy()` in `XYSelector.process()` + +**`plottr/node/grid.py`**: Pass `copy=False` to `datadict_to_meshgrid()` since gridder already copies input + +**`plottr/plot/base.py`**: `dataclasses.replace` instead of `deepcopy` for complex plot splitting + +### Bugs Fixed +- `copy()` now properly deep-copies global mutable metadata (was sharing references) +- `remove_invalid_entries()` no longer crashes when dependents have different numbers of invalid entries + +### Benchmark Results + +**Micro-benchmarks (key functions):** + +| Function | Before | After | Speedup | +|---|---|---|---| +| `largest_numtype` 
(500K float) | 29.8 ms | 0.002 ms | ~15,000× | +| `mesh_500k_copy()` | 42.2 ms | 2.9 ms | 14.8× | +| `node_process` (500K mesh, steady state) | 7.4 ms | 0.15 ms | 50× | +| `_find_switches` (640K pts) | 80 ms | 31 ms | 2.6× | +| `datadict_to_meshgrid` (640K pts) | 175 ms | 71 ms | 2.5× | +| `mesh_500k_validate()` | 20.5 ms | 14.1 ms | 1.5× | + +**Real experimental data (large qcodes database, steady-state refresh):** + +| Dataset | Data Size | Before | After | Speedup | +|---|---|---|---|---| +| QDstability (14400×251, 16 deps) | 223 MB | 555 ms | 189 ms | 2.93× | +| TopogapStage2 (41×33×5×81, 21 deps) | 152 MB | 439 ms | 161 ms | 2.73× | +| QDtuning (7440×121, 16 deps) | 14 MB | 31 ms | 11 ms | 2.73× | + +**Interactive actions (simulated user operations on large datasets):** + +| Action | Before | After | Speedup | +|---|---|---|---| +| Toggle subtract average (15 MB 2D) | 293 ms | 29 ms | 10.2× | +| Swap XY axes (18 MB 2D) | 790 ms | 241 ms | 3.3× | +| Switch dependent (61 MB 1D) | 2,287 ms | 977 ms | 2.3× | +| Data refresh (15 MB 2D) | 697 ms | 199 ms | 3.5× | + +### Tests Added + +221 new tests across 4 test files: +- `test_datadict_copy_semantics.py` — copy isolation, edge cases, pipeline integrity +- `test_pipeline_coverage.py` — per-node tests, hypothesis property-based, various dtypes +- `test_round2_optimizations.py` — is_invalid, largest_numtype, remove_invalid_entries +- `test_gridder_comprehensive.py` — all GridOption paths, shapes, edge cases + +--- + +## Part 2: Implemented — Inspectr Loading & UX + +### Problem + +Opening a large QCoDeS database (1496 runs) in inspectr took 15+ minutes because the +`experiments()` + `data_sets()` enumeration in QCoDeS is O(N²). Clicking any dataset +froze the UI for ~1 second while the snapshot (up to 6 MB of JSON) was parsed into +thousands of tree widget items. 
+ +### What Changed + +**Fast database overview** (`plottr/data/qcodes_db_overview.py`, new module): +- Single SQL JOIN query fetching run metadata directly from runs + experiments tables +- Skips snapshot and run_description blobs entirely +- Reads `inspectr_tag` directly as a column from the runs table +- Intended for eventual contribution to QCoDeS + +**Lazy snapshot loading** (`plottr/apps/inspectr.py`): +- Snapshot tree built only when user expands the "QCoDeS Snapshot" section +- Info pane sections collapsed by default +- Smooth pixel-based scrolling for tall rows (e.g., exception tracebacks) + +**Incremental refresh**: +- `refreshDB()` only loads runs newer than the last known run_id +- Merges incremental results into existing dataframe + +**Loading UX**: +- Live progress indicator: "Loading database... (142/1496 datasets)" +- Contextual messages: "Select a date...", "No datasets found...", "No datasets match filter..." +- Wider default window (960×640) + +**Fallback chain**: SQL direct → `load_by_id` loop → original `experiments()` API + +### Benchmark + +| Approach | 23 runs | 1496 runs (projected) | +|---|---|---| +| Old (experiments + data_sets) | 103 ms | 15+ minutes | +| load_by_id loop | 90 ms | ~5 seconds | +| **SQL direct** (new) | **14 ms** | **~10 ms** | +| Incremental (3 new runs) | - | **~4 ms** | + +Snapshot click: 951 ms → 0.3 ms (3,554× faster) + +--- + +## Part 3: Implemented — Plot UI Improvements + +### What Changed + +**Grid layout for pyqtgraph subplots** (`plottr/plot/pyqtgraph/autoplot.py`): +- Replaced single-column `QSplitter` with `QGridLayout` using near-square grid + (same formula as matplotlib: `nrows = int(n^0.5 + 0.5)`) +- Many subplots now arrange as 2×2, 2×3, 4×4 etc. 
instead of stacking vertically + +**Scrollable plot area** (both backends): +- "Scrollable" checkbox + min-height spinbox in the plot toolbar +- Off by default; when enabled, plot area expands and becomes scrollable +- Min height per row configurable (40–2000 px, default 75 px pyqtgraph / 100 px mpl) + +**Plot backend selector** (`plottr/apps/inspectr.py`): +- Combo box in inspectr toolbar to switch between matplotlib and pyqtgraph +- Default: matplotlib. Applies to newly opened plot windows. + +--- + +## Part 4: Not Implemented — Future Suggestions + +These were identified during analysis but not implemented in this PR. + +### HDF5 Data Loading (datadict_storage.py) +- Lines 274 and 305 read the **entire HDF5 dataset into memory** just to get its shape +- Fix: `ds.shape` instead of `ds[:].shape` — would reduce load time by 50–80% + +### Signal Emission Overhead (node.py) +- Up to 7 Qt signals emitted per node per data update +- `dataFieldsChanged` is redundant (axes + deps) +- Could consolidate to 1–2 batched signals + +### Fitter / Histogrammer / ScaleUnits Memoization +- These nodes recompute results on every update even when inputs haven't changed +- Could cache results keyed on data hash + parameters + +### Pipeline Change Detection +- No concept of "what changed" — every update re-processes all data through all nodes +- For append-only monitoring, nodes could process only new data + +### QCoDeS API Suggestion +The ideal API for inspectr would be a single function returning lightweight run metadata +for all or a range of runs without creating full DataSet objects: +```python +get_run_overview(conn, start_id=None, end_id=None) +# Returns: [{run_id, exp_name, sample_name, name, timestamps, guid, result_counter, metadata_keys}] +``` +This would be a single SQL query completing in <1 ms for any database size. 
+ +--- + +## Part 5: Profiling with Real Data (963×1001 complex RF measurement) + +Profiled using a real 963×1001 complex128 2D gate-gate sweep measurement +(~12.5 MB on disk, ~15 MB in memory as complex128). + +### Timing Summary + +| Operation | Time (ms) | Notes | +|---|---|---| +| `ds_to_datadict` (first call) | 2,588 | 1,500 ms is xarray/cf_xarray import (one-time) | +| `ds_to_datadict` (steady state) | 999 | qcodes SQLite → numpy deserialization | +| `datadict_to_meshgrid` | 122 | `guess_grid_from_sweep_direction` dominates | +| Pipeline steady state (sel+grid) | 51 | Per re-trigger with same data | +| Switch dependent variable | 172 | selector + gridding + pyqtgraph `eq()` | +| Complex: real only | 8.5 | `copy()` + `.real.copy()` | +| Complex: real+imag | 11.6 | `copy()` + `.real` + `.imag` | +| Complex: mag+phase | 30.8 | `copy()` + `np.abs()` + `np.angle()` | +| `copy()` deep | 5.1 | Already fast after our optimization | +| `copy()` shallow | 0.1 | Zero-copy array sharing | +| `validate()` | 0.2 | Already fast | +| `structure()` | 0.4 | Already fast | +| `is_invalid()` on 963k complex | 44.6 | **`a == None` comparison is 44× slower than `np.isnan`** | +| `np.isnan()` on 963k complex | 1.0 | What `is_invalid` should use for numeric dtypes | + +### Bottleneck Analysis + +#### 1. `is_invalid()` — 44× slower than needed (LOW-HANGING FRUIT) + +The current implementation does `a == None` for all arrays, which triggers Python object +comparison on every element. For numeric arrays (float/complex), this is always `False` +and is pure waste. Replacing with `np.isnan()` directly for numeric dtypes would cut +`is_invalid` from 44.6 ms → ~1 ms. + +This cascades through `_find_switches()` (which calls `is_invalid` on each 963k-element +axis), making `datadict_to_meshgrid` ~90 ms faster. + +**Fix**: In `is_invalid()`, check dtype first — if it's a numeric type, skip the `== None` +check entirely and return just `np.isnan(a)`. + +#### 2. 
`ds_to_datadict()` — 999 ms steady state (MEDIUM EFFORT) + +The qcodes `DataSetCacheDeferred` loads data via xarray round-trip. The actual SQLite +read + numpy deserialization (`_convert_array` → `numpy.read_array` → `ast.literal_eval` +for headers) takes ~1 second for 963k × 3 parameters. + +This is largely inside qcodes, so fixes would be upstream. However, plottr could: +- Cache the loaded DataDict and skip reload when the dataset hasn't changed +- Use `load_by_id(...).cache.data()` directly instead of going through `ds_to_datadict` + which re-wraps the data +- For completed datasets (known from metadata), cache the DataDict permanently + +#### 3. `datadict_to_meshgrid` with `guessShape` — 122 ms (AVOIDABLE) + +When shape metadata exists in the QCodes `RunDescriber` (this dataset has +`shapes={'rf_wrapper_ch6_Vrf_6': (1001, 1001)}`), the gridder should use +`GridOption.metadataShape` and skip the expensive `guess_grid_from_sweep_direction`. + +The autoplot code already does this (`autoplot.py:298`), but the grid widget default +is `noGrid`, so if the user starts from the widget rather than autoplot, they get +`guessShape` which runs the full sweep-direction analysis on every re-trigger. + +**Fix**: Default the grid widget to `metadataShape` when shape metadata is available. + +#### 4. `np.abs()` + `np.angle()` for complex mag+phase — 30.8 ms (INHERENT) + +This is inherent computational cost for computing magnitude and phase of 963k complex128 +values. Not much to optimize here, but could be deferred (only compute when the plot +backend actually needs to render). + +#### 5. pyqtgraph `Terminal.setValue` → `eq()` — 12 ms per node (MEDIUM) + +pyqtgraph's flowchart compares old and new terminal values using a recursive `eq()` +function. For large DataDicts this recurses into all arrays and does element-wise +comparison. This adds ~24 ms per pipeline trigger (12 ms per node, 2 nodes). 
+ +**Fix**: Override `eq()` on DataDictBase to do a cheap identity or shape check +instead of element-wise comparison, or set terminal values without comparison. + +### Suggested Priority (remaining) + +Items 1, 2, and 6 have been implemented. Remaining potential improvements: + +1. ~~**Fix `is_invalid()`**~~ ✅ Done — 44x faster (44.6ms → 1.0ms) +2. ~~**Default to `metadataShape`**~~ ✅ Done — avoids 122ms gridding when shape metadata exists +3. **Cache loaded DataDict** for completed datasets — avoids 999 ms reload on each refresh +4. **Override pyqtgraph `eq()`** for DataDictBase — saves ~24 ms per pipeline trigger +5. **Lazy complex splitting** — compute mag/phase only when needed by the plot backend +6. ~~**Fix mpl double-replot**~~ ✅ Done — ~20% faster mpl steady-state (919ms → 754ms) +7. **Matplotlib artist-level updates** — Instead of `fig.clear()` + full recreation on every + `setData()`, reuse existing Line2D/QuadMesh/colorbar artists and update their data. + The pyqtgraph backend already does this via `clearWidget=False`; bringing the same + pattern to mpl could reduce steady-state replot from ~750ms to ~200ms. + +### Backend Comparison After Optimizations (963×1001 complex128) + +| Operation | matplotlib | pyqtgraph | +|---|---|---| +| First plot | 1,428 ms | 175 ms | +| Steady replot | 754 ms | 80 ms | +| Complex real | 394 ms | 118 ms | +| Complex realAndImag | 687 ms | 114 ms | +| Complex magAndPhase | 730 ms | 108 ms | + +The pyqtgraph backend is ~10x faster for steady-state replots because it reuses +plot widget objects when only data changes. The matplotlib backend's remaining +cost is dominated by `fig.clear()` + subplot/artist recreation + agg rendering. 
diff --git a/plottr/apps/autoplot.py b/plottr/apps/autoplot.py index 0bece004..8e339c51 100644 --- a/plottr/apps/autoplot.py +++ b/plottr/apps/autoplot.py @@ -7,7 +7,6 @@ import time import argparse from typing import Union, Tuple, Optional, Type, List, Any, Type -from packaging import version from .. import QtCore, Flowchart, Signal, Slot, QtWidgets, QtGui from .. import log as plottrlog @@ -249,7 +248,10 @@ def setDefaults(self, data: DataDictBase) -> None: try: self.fc.nodes()['Data selection'].selectedData = selected - self.fc.nodes()['Grid'].grid = GridOption.guessShape, {} + if data.meta_val('qcodes_shape') is not None: + self.fc.nodes()['Grid'].grid = GridOption.metadataShape, {} + else: + self.fc.nodes()['Grid'].grid = GridOption.guessShape, {} self.fc.nodes()['Dimension assignment'].dimensionRoles = drs # FIXME: this is maybe a bit excessive, but trying to set all the defaults # like this can result in many types of errors. @@ -291,17 +293,12 @@ def __init__(self, fc: Flowchart, def setDefaults(self, data: DataDictBase) -> None: super().setDefaults(data) - import qcodes as qc - qcodes_support = (version.parse(qc.__version__) >= - version.parse("0.20.0")) - if data.meta_val('qcodes_shape') is not None and qcodes_support: - self.fc.nodes()['Grid'].grid = GridOption.metadataShape, {} - else: - self.fc.nodes()['Grid'].grid = GridOption.guessShape, {} + def autoplotQcodesDataset(log: bool = False, - pathAndId: Union[Tuple[str, int], None] = None) \ + pathAndId: Union[Tuple[str, int], None] = None, + plotWidgetClass: Optional[Type[PlotWidget]] = None) \ -> Tuple[Flowchart, QCAutoPlotMainWindow]: """ Sets up a simple flowchart consisting of a data selector, @@ -331,7 +328,8 @@ def autoplotQcodesDataset(log: bool = False, win = QCAutoPlotMainWindow(fc, pathAndId=pathAndId, widgetOptions=widgetOptions, monitor=True, - loaderName='Data loader') + loaderName='Data loader', + plotWidgetClass=plotWidgetClass) win.show() return fc, win diff --git 
a/plottr/apps/inspectr.py b/plottr/apps/inspectr.py index d5e41132..fb6fa94b 100644 --- a/plottr/apps/inspectr.py +++ b/plottr/apps/inspectr.py @@ -17,7 +17,7 @@ import sys import argparse import logging -from typing import Optional, Sequence, List, Dict, Iterable, Union, cast, Tuple, Mapping +from typing import Any, Optional, Sequence, List, Dict, Iterable, Union, cast, Tuple, Mapping from typing_extensions import TypedDict @@ -28,7 +28,9 @@ from .. import log as plottrlog from ..data.qcodes_dataset import (get_runs_from_db_as_dataframe, + get_runs_from_db, get_runs_from_db_fast, get_ds_structure, load_dataset_from) +from ..data.qcodes_db_overview import get_db_overview from plottr.gui.widgets import MonitorIntervalInput, FormLayoutWrapper, dictToTreeWidgetItems from .autoplot import autoplotQcodesDataset, QCAutoPlotMainWindow @@ -39,6 +41,31 @@ LOGGER = plottrlog.getLogger('plottr.apps.inspectr') +#: Hint text shown in the run list when no date is selected. +_SELECT_DATE_HINT = "Select a date on the left to browse datasets." + +#: Mapping of display names to plot widget classes for the backend selector. +#: Populated lazily on first access. 
+_PLOT_BACKENDS: Dict[str, type] = {} + + +def _get_plot_backends() -> Dict[str, type]: + """Lazily populate and return the backend mapping.""" + if not _PLOT_BACKENDS: + from plottr.plot.mpl.autoplot import AutoPlot as MPLAutoPlot + from plottr.plot.pyqtgraph.autoplot import AutoPlot as PGAutoPlot + _PLOT_BACKENDS['matplotlib'] = MPLAutoPlot + _PLOT_BACKENDS['pyqtgraph'] = PGAutoPlot + return _PLOT_BACKENDS + + +def _backend_name_for_class(cls: Optional[type]) -> Optional[str]: + """Return the display name for a plot widget class, or None if unknown.""" + for name, backend_cls in _get_plot_backends().items(): + if backend_cls is cls: + return name + return None + ### Database inspector tool @@ -142,9 +169,28 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self.itemSelectionChanged.connect(self.selectRun) self.itemActivated.connect(self.activateRun) + # Overlay label for status messages + self._overlayLabel = QtWidgets.QLabel(self.viewport()) + self._overlayLabel.setAlignment(QtCore.Qt.AlignCenter) + self._overlayLabel.setWordWrap(True) + self._overlayLabel.setStyleSheet( + "color: gray; font-size: 13pt; padding: 40px;" + ) + self._overlayLabel.setAttribute(QtCore.Qt.WA_TransparentForMouseEvents) + self.setOverlayText(_SELECT_DATE_HINT) + self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) self.customContextMenuRequested.connect(self.showContextMenu) + def setOverlayText(self, text: str) -> None: + """Show a centered overlay message. 
Pass empty string to hide.""" + self._overlayLabel.setText(text) + self._overlayLabel.setVisible(bool(text)) + + def resizeEvent(self, event: QtGui.QResizeEvent) -> None: + super().resizeEvent(event) + self._overlayLabel.setGeometry(self.viewport().rect()) + @Slot(QtCore.QPoint) def showContextMenu(self, position: QtCore.QPoint) -> None: model_index = self.indexAt(position) @@ -160,6 +206,7 @@ def showContextMenu(self, position: QtCore.QPoint) -> None: window = cast(QCodesDBInspector, self.window()) starAction: QtWidgets.QAction = window.starAction + starAction.setText('Star' if current_tag_char != self.tag_dict['star'] else 'Unstar') menu.addAction(starAction) @@ -167,6 +214,7 @@ def showContextMenu(self, position: QtCore.QPoint) -> None: crossAction.setText( "Cross" if current_tag_char != self.tag_dict["cross"] else "Uncross" ) + menu.addAction(crossAction) action = menu.exec_(self.mapToGlobal(position)) @@ -191,22 +239,28 @@ def addRun(self, runId: int, **vals: str) -> None: def setRuns(self, selection: Mapping[int, Mapping[str, str]], show_only_star: bool, show_also_cross: bool) -> None: self.clear() + self.setOverlayText('') # disable sorting before inserting values to avoid performance hit self.setSortingEnabled(False) + count = 0 for runId, record in selection.items(): tag = record.get('inspectr_tag', '') if show_only_star and tag == '': continue elif show_also_cross or tag != 'cross': self.addRun(runId, **record) + count += 1 self.setSortingEnabled(True) for i in range(len(self.cols)): self.resizeColumnToContents(i) + if count == 0: + self.setOverlayText("No datasets match the current filter.") + def updateRuns(self, selection: Mapping[int, Mapping[str, str]]) -> None: run_added = False @@ -254,24 +308,89 @@ class RunInfo(QtWidgets.QTreeWidget): When sending information in form of a dictionary, it will create a tree view of that dictionary and display that. 
+ + Snapshot data is loaded lazily: a placeholder item is shown, and the full + snapshot tree is built only when the user expands it. """ + #: Signal emitted when the snapshot section needs to be loaded. + #: Argument is the QTreeWidgetItem to populate. + _snapshotRequested = Signal(object) + def __init__(self, parent: Optional[QtWidgets.QWidget] = None): super().__init__(parent) self.setHeaderLabels(['Key', 'Value']) self.setColumnCount(2) + # Smooth pixel-based scrolling so tall rows (e.g., long tracebacks) + # can be scrolled through without jumping to the next row. + self.setVerticalScrollMode(QtWidgets.QAbstractItemView.ScrollPerPixel) + + self._snapshotItem: Optional[QtWidgets.QTreeWidgetItem] = None + self._snapshotData: Optional[dict] = None + self._snapshotLoaded = False + + self.itemExpanded.connect(self._onItemExpanded) + @Slot(dict) def setInfo(self, infoDict: Dict[str, Union[dict, str]]) -> None: self.clear() + self._snapshotItem = None + self._snapshotData = None + self._snapshotLoaded = False + + for key, value in infoDict.items(): + if key == 'QCoDeS Snapshot': + # Create a placeholder for the snapshot — don't build the tree yet + self._snapshotItem = QtWidgets.QTreeWidgetItem([key, '(click to expand)']) + # Add a dummy child so the expand arrow appears + self._snapshotItem.addChild(QtWidgets.QTreeWidgetItem(['(loading...)', ''])) + self._snapshotData = value if isinstance(value, dict) else None + self.addTopLevelItem(self._snapshotItem) + self._snapshotItem.setExpanded(False) + else: + if not isinstance(value, dict): + item = QtWidgets.QTreeWidgetItem([str(key), str(value)]) + else: + item = QtWidgets.QTreeWidgetItem([key, '']) + for child in dictToTreeWidgetItems(value): + item.addChild(child) + self.addTopLevelItem(item) + item.setExpanded(False) + + for i in range(2): + self.resizeColumnToContents(i) + + @Slot(QtWidgets.QTreeWidgetItem) + def _onItemExpanded(self, item: QtWidgets.QTreeWidgetItem) -> None: + if item is self._snapshotItem and not 
self._snapshotLoaded: + self._loadSnapshot() + + def _loadSnapshot(self) -> None: + """Replace the placeholder with the actual snapshot tree.""" + if self._snapshotItem is None: + return + + self._snapshotLoaded = True + snap_data = self._snapshotData + + # Remove placeholder children + self._snapshotItem.takeChildren() + + if snap_data is None: + self._snapshotItem.setText(1, '(no snapshot)') + return + + self._snapshotItem.setText(1, '') - items = dictToTreeWidgetItems(infoDict) - for item in items: - self.addTopLevelItem(item) - item.setExpanded(True) + if isinstance(snap_data, dict): + for child in dictToTreeWidgetItems(snap_data): + self._snapshotItem.addChild(child) + else: + self._snapshotItem.addChild( + QtWidgets.QTreeWidgetItem([str(snap_data), ''])) - self.expandAll() for i in range(2): self.resizeColumnToContents(i) @@ -281,18 +400,60 @@ class LoadDBProcess(QtCore.QObject): Worker object for getting a qcodes db overview as pandas dataframe. It's good to have this in a separate thread because it can be a bit slow for large databases. + + Uses ``get_db_overview`` (direct SQL) by default for maximum speed. + Falls back to ``get_runs_from_db_fast`` (qcodes public API) if the + SQL approach fails. """ dbdfLoaded = Signal(object) + progressUpdated = Signal(int, int) # (current, total) pathSet = Signal() - def setPath(self, path: str) -> None: + #: If True, use direct SQL queries (fast). If False, use qcodes API. 
+ use_fast_sql: bool = True + + def __init__(self) -> None: + super().__init__() + self.path: Optional[str] = None + self.start_run_id: int = 1 + + def setPath(self, path: str, start_run_id: int = 1) -> None: self.path = path + self.start_run_id = start_run_id self.pathSet.emit() def loadDB(self) -> None: - dbdf = get_runs_from_db_as_dataframe(self.path) + assert self.path is not None + + overview: Optional[Dict[int, Any]] = None + if self.use_fast_sql: + try: + # start_run_id uses > comparison, so subtract 1 for inclusive + overview = get_db_overview( + self.path, + start_run_id=self.start_run_id - 1, + ) + except Exception as e: + LOGGER.warning(f"Fast SQL overview failed, falling back to " + f"qcodes API: {e}") + overview = None + + if overview is None: + overview = get_runs_from_db_fast( + self.path, + start_run_id=self.start_run_id, + progress_callback=self._onProgress, + ) + + if overview: + dbdf = pandas.DataFrame.from_dict(overview, orient='index') + else: + dbdf = pandas.DataFrame() self.dbdfLoaded.emit(dbdf) + def _onProgress(self, current: int, total: int) -> None: + self.progressUpdated.emit(current, total) + class QCodesDBInspector(QtWidgets.QMainWindow): """ @@ -308,11 +469,13 @@ class QCodesDBInspector(QtWidgets.QMainWindow): _sendInfo = Signal(dict) def __init__(self, parent: Optional[QtWidgets.QWidget] = None, - dbPath: Optional[str] = None): + dbPath: Optional[str] = None, + plotWidgetClass: Optional[type] = None): """Constructor for :class:`QCodesDBInspector`.""" super().__init__(parent) self._plotWindows: Dict[int, WindowDict] = {} + self._plotWidgetClass = plotWidgetClass self.filepath = dbPath self.dbdf: Optional[pandas.DataFrame] = None @@ -370,6 +533,31 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None, self.autoLaunchPlots.setToolTip(tt) self.toolbar.addWidget(self.autoLaunchPlots) + self.toolbar.addSeparator() + + # toolbar item: plot backend selector + backendLabel = QtWidgets.QLabel(" Plot backend: ") + 
self.toolbar.addWidget(backendLabel) + self.plotBackendSelector = QtWidgets.QComboBox() + backends = _get_plot_backends() + self.plotBackendSelector.addItems(list(backends.keys())) + self.plotBackendSelector.setToolTip('Choose plotting backend for new plot windows') + if plotWidgetClass is not None: + known_name = _backend_name_for_class(plotWidgetClass) + if known_name is not None: + self.plotBackendSelector.setCurrentText(known_name) + else: + # Unknown class: add it to the selector with its class name + label = plotWidgetClass.__name__ + self.plotBackendSelector.addItem(label) + self.plotBackendSelector.setCurrentText(label) + self.plotBackendSelector.currentTextChanged.connect(self._onBackendChanged) + self.toolbar.addWidget(self.plotBackendSelector) + # Sync the class with the initial combo selection + self._onBackendChanged(self.plotBackendSelector.currentText()) + + self.toolbar.addSeparator() + self.showOnlyStarAction = self.toolbar.addAction(RunList.tag_dict['star']) self.showOnlyStarAction.setToolTip('Show only starred runs') self.showOnlyStarAction.setCheckable(True) @@ -408,8 +596,8 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None, self.addAction(self.crossAction) # sizing - scaledSize = int(640 * rint(self.logicalDpiX() / 96.0)) - self.resize(scaledSize, scaledSize) + scaledDpi = rint(self.logicalDpiX() / 96.0) + self.resize(int(960 * scaledDpi), int(640 * scaledDpi)) ### Thread workers @@ -420,6 +608,7 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None, self.loadDBProcess.pathSet.connect(self.loadDBThread.start) self.loadDBProcess.dbdfLoaded.connect(self.DBLoaded) self.loadDBProcess.dbdfLoaded.connect(self.loadDBThread.quit) + self.loadDBProcess.progressUpdated.connect(self.onLoadProgress) self.loadDBThread.started.connect(self.loadDBProcess.loadDB) ### connect signals/slots @@ -490,18 +679,49 @@ def loadFullDB(self, path: Optional[str] = None) -> None: if self.filepath is not None: if not self.loadDBThread.isRunning(): - 
self.loadDBProcess.setPath(self.filepath) + self.runList.setOverlayText("Loading database...") + self.loadDBProcess.setPath(self.filepath, start_run_id=1) + + @Slot(int, int) + def onLoadProgress(self, current: int, total: int) -> None: + self.runList.setOverlayText( + f"Loading database... ({current}/{total} datasets)") def DBLoaded(self, dbdf: pandas.DataFrame) -> None: - if self.dbdf is not None and dbdf.equals(self.dbdf): - LOGGER.debug('DB reloaded with no changes. Skipping update') + if dbdf.size == 0 and self.dbdf is not None: + LOGGER.debug('DB reloaded with no new data. Skipping update.') + self.runList.setOverlayText( + _SELECT_DATE_HINT) return None - self.dbdf = dbdf + + if self.latestRunId is not None and self.dbdf is not None and dbdf.size > 0: + # Incremental load: merge new rows into existing dataframe + existing_mask = dbdf.index.isin(self.dbdf.index) + # Update existing rows (e.g., completed_date may have changed) + if existing_mask.any(): + self.dbdf.update(dbdf.loc[existing_mask]) + # Append all truly-new rows in a single concat + new_rows = dbdf.loc[~existing_mask] + if not new_rows.empty: + self.dbdf = pandas.concat([self.dbdf, new_rows]) + elif dbdf.size > 0: + self.dbdf = dbdf + else: + self.dbdf = dbdf + self.dbdfUpdated.emit() self.dateList.sendSelectedDates() - LOGGER.debug('DB reloaded') + LOGGER.debug('DB loaded/refreshed') + + # Set appropriate overlay text after loading completes + if self.dbdf is None or self.dbdf.size == 0: + self.runList.setOverlayText( + "No datasets found in this database.") + elif self.runList.topLevelItemCount() == 0: + self.runList.setOverlayText( + _SELECT_DATE_HINT) - if self.latestRunId is not None: + if self.latestRunId is not None and self.dbdf is not None and self.dbdf.size > 0: idxs = self.dbdf.index.values newIdxs = idxs[idxs > self.latestRunId] @@ -526,11 +746,15 @@ def refreshDB(self) -> None: if self.loadDBThread.isRunning(): return if self.dbdf is not None and self.dbdf.size > 0: - self.latestRunId 
= self.dbdf.index.values.max() + self.latestRunId = int(self.dbdf.index.values.max()) else: self.latestRunId = -1 - self.loadFullDB() + # Incremental refresh: only load runs newer than what we have. + start_run_id = self.latestRunId + 1 if self.latestRunId is not None and self.latestRunId > 0 else 1 + if self.filepath is not None: + if not self.loadDBThread.isRunning(): + self.loadDBProcess.setPath(self.filepath, start_run_id=start_run_id) @Slot(float) def setMonitorInterval(self, val: float) -> None: @@ -577,6 +801,8 @@ def setDateSelection(self, dates: Sequence[str]) -> None: else: self._selected_dates = () self.runList.clear() + self.runList.setOverlayText( + _SELECT_DATE_HINT) @Slot(int) def setRunSelection(self, runId: int) -> None: @@ -601,13 +827,21 @@ def setRunSelection(self, runId: int) -> None: @Slot(int) def plotRun(self, runId: int) -> None: assert self.filepath is not None - fc, win = autoplotQcodesDataset(pathAndId=(self.filepath, runId)) + fc, win = autoplotQcodesDataset( + pathAndId=(self.filepath, runId), + plotWidgetClass=self._plotWidgetClass, + ) self._plotWindows[runId] = { 'flowchart': fc, 'window': win, } win.showTime() + @Slot(str) + def _onBackendChanged(self, backend: str) -> None: + backends = _get_plot_backends() + self._plotWidgetClass = backends.get(backend, self._plotWidgetClass) + def setTag(self, item: QtWidgets.QTreeWidgetItem, tag: str) -> None: # set tag in the database assert self.filepath is not None @@ -652,16 +886,18 @@ class WindowDict(TypedDict): window: QCAutoPlotMainWindow -def inspectr(dbPath: Optional[str] = None) -> QCodesDBInspector: - win = QCodesDBInspector(dbPath=dbPath) +def inspectr(dbPath: Optional[str] = None, + plotWidgetClass: Optional[type] = None) -> QCodesDBInspector: + win = QCodesDBInspector(dbPath=dbPath, plotWidgetClass=plotWidgetClass) return win -def main(dbPath: Optional[str], log_level: Union[int, str] = logging.WARNING) -> None: +def main(dbPath: Optional[str], log_level: Union[int, str] = 
logging.WARNING, + plotWidgetClass: Optional[type] = None) -> None: app = QtWidgets.QApplication([]) plottrlog.enableStreamHandler(True, log_level) - win = inspectr(dbPath=dbPath) + win = inspectr(dbPath=dbPath, plotWidgetClass=plotWidgetClass) win.show() if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'): @@ -679,3 +915,19 @@ def script() -> None: default="WARNING") args = parser.parse_args() main(args.dbpath, args.console_log_level) + + +def script_pyqtgraph() -> None: + """Entry point for inspectr using the pyqtgraph plotting backend.""" + from plottr.plot.pyqtgraph.autoplot import AutoPlot as PGAutoPlot + + parser = argparse.ArgumentParser( + description='inspectr -- sifting through qcodes data (pyqtgraph backend).' + ) + parser.add_argument('--dbpath', help='path to qcodes .db file', + default=None) + parser.add_argument("--console-log-level", + choices=("ERROR", "WARNING", "INFO", "DEBUG"), + default="WARNING") + args = parser.parse_args() + main(args.dbpath, args.console_log_level, plotWidgetClass=PGAutoPlot) diff --git a/plottr/data/datadict.py b/plottr/data/datadict.py index a9c49202..0ab0e157 100644 --- a/plottr/data/datadict.py +++ b/plottr/data/datadict.py @@ -117,6 +117,43 @@ def _meta_key_to_name(key: str) -> str: def _meta_name_to_key(name: str) -> str: return meta_name_to_key(name) + @staticmethod + def _copy_field(field: Dict[str, Any], copy_values: bool = True, + empty_values: bool = False) -> Dict[str, Any]: + """Create a copy of a data field dict with targeted copy semantics. + + Always creates a new dict and a new 'axes' list (mutation-safe). + For 'values': copies the array if *copy_values* is True, shares the + reference if False, or sets to ``[]`` if *empty_values* is True. + Scalar keys (unit, label) are passed through (immutable strings). + Meta keys (``__name__``) are deep-copied (may be mutable). + All other keys are deep-copied for safety. 
+ """ + new_field: Dict[str, Any] = {} + for fk, fv in field.items(): + if fk == 'values': + if empty_values: + new_field[fk] = [] + elif copy_values: + # use numpy-optimized copy for arrays + if isinstance(fv, (np.ndarray, np.ma.core.MaskedArray)): + new_field[fk] = fv.copy() + elif isinstance(fv, list): + new_field[fk] = fv.copy() + else: + new_field[fk] = cp.deepcopy(fv) + else: + new_field[fk] = fv # shared reference + elif fk == 'axes': + new_field[fk] = list(fv) # always new list + elif fk in ('unit', 'label'): + new_field[fk] = fv # immutable strings + elif is_meta_key(fk): + new_field[fk] = cp.deepcopy(fv) # may be mutable + else: + new_field[fk] = cp.deepcopy(fv) # unknown keys: safe default + return new_field + @staticmethod def to_records(**data: Any) -> Dict[str, np.ndarray]: """Convert data to records that can be added to the ``DataDict``. @@ -344,7 +381,7 @@ def extract(self: T, data: List[str], include_meta: bool = True, ret = self.__class__() for d in data: if copy: - ret[d] = cp.deepcopy(self[d]) + ret[d] = self._copy_field(self[d], copy_values=True) else: ret[d] = self[d] @@ -426,29 +463,38 @@ def structure(self: T, add_shape: bool = False, remove_data = [] if self.validate(): - s = self.__class__() - for n, v in self.data_items(): - if n not in remove_data: - v2 = v.copy() - v2['values'] = [] - s[n] = cp.deepcopy(v2) - if 'axes' in s[n]: - for r in remove_data: - if r in s[n]['axes']: - i = s[n]['axes'].index(r) - s[n]['axes'].pop(i) - - if include_meta: - for n, v in self.meta_items(): - s.add_meta(n, v) - else: - s.clear_meta() + return self._build_structure( + include_meta=include_meta, same_type=same_type, + remove_data=remove_data) + return None - if same_type: - s = self.__class__(**s) + def _build_structure(self: T, include_meta: bool = True, + same_type: bool = False, + remove_data: Optional[List[str]] = None) -> T: + """Build a structure-only copy. 
Caller must ensure data is validated.""" + if remove_data is None: + remove_data = [] - return s - return None + s = self.__class__() + for n, v in self.data_items(): + if n not in remove_data: + s[n] = self._copy_field(v, empty_values=True) + if 'axes' in s[n]: + for r in remove_data: + if r in s[n]['axes']: + i = s[n]['axes'].index(r) + s[n]['axes'].pop(i) + + if include_meta: + for n, v in self.meta_items(): + s.add_meta(n, cp.deepcopy(v)) + else: + s.clear_meta() + + if same_type: + s = self.__class__(**s) + + return s def nbytes(self, name: Optional[str]=None) -> Optional[int]: @@ -477,20 +523,19 @@ def label(self, name: str) -> Optional[str]: :param name: Name of the data field. :return: Labelled name. """ - if self.validate(): - if name not in self: - raise ValueError("No field '{}' present.".format(name)) - - if self[name]['label'] != '': - n = self[name]['label'] - else: - n = name + if name not in self: + raise ValueError("No field '{}' present.".format(name)) - if self[name]['unit'] != '': - n += ' ({})'.format(self[name]['unit']) + field = self[name] + n = field.get('label', '') or name + if n == '': + n = name - return n - return None + unit = field.get('unit', '') + if unit: + n += ' ({})'.format(unit) + + return n def axes_are_compatible(self) -> bool: """ @@ -560,7 +605,7 @@ def shapes(self) -> Dict[str, Tuple[int, ...]]: """ shapes = {} for k, v in self.data_items(): - shapes[k] = np.array(self.data_vals(k)).shape + shapes[k] = np.shape(self.data_vals(k)) return shapes @@ -692,18 +737,25 @@ def reorder_axes(self: T, data_names: Union[str, Sequence[str], None] = None, self.validate() return self - def copy(self: T) -> T: + def copy(self: T, deep: bool = True) -> T: """ Make a copy of the dataset. + :param deep: If ``True`` (default), all data arrays are independently + copied. The returned dataset is fully independent of the original. + If ``False``, the returned dataset shares data array references + with the original. 
Modifying array *contents* in the copy will + affect the original; *replacing* an array only affects the copy. + Field metadata (axes, unit, label) is always independently copied. :return: A copy of the dataset. """ - logger.debug(f'copying a dataset with size {self.nbytes()}') - ret = self.structure() - assert ret is not None + ret = self.__class__() + for k, v in self.items(): + if self._is_meta_key(k): + ret[k] = cp.deepcopy(v) + else: + ret[k] = self._copy_field(v, copy_values=deep) - for k, v in self.data_items(): - ret[k]['values'] = self.data_vals(k).copy() return ret def astype(self: T, dtype: np.dtype) -> T: @@ -728,7 +780,10 @@ def mask_invalid(self: T) -> T: """ for d, _ in self.data_items(): arr = self.data_vals(d) - vals = np.ma.masked_where(num.is_invalid(arr), arr, copy=True) + invalid_mask = num.is_invalid(arr) + if not np.any(invalid_mask): + continue # no invalid entries, skip masking + vals = np.ma.masked_where(invalid_mask, arr, copy=True) try: vals.fill_value = np.nan except TypeError: @@ -793,7 +848,7 @@ def __add__(self, newdata: 'DataDict') -> 'DataDict': """ # FIXME: remove shape - s = misc.unwrap_optional(self.structure(add_shape=False)) + s = self._build_structure() if DataDictBase.same_structure(self, newdata): for k, v in self.data_items(): val0 = self[k]['values'] @@ -843,7 +898,7 @@ def add_data(self, **kw: Any) -> None: :param kw: one array per data field (none can be omitted). 
""" - dd = misc.unwrap_optional(self.structure(same_type=True)) + dd = self._build_structure(same_type=True) for name, _ in dd.data_items(): if name not in kw: kw[name] = None @@ -925,7 +980,7 @@ def expand(self) -> 'DataDict': self.validate() if not self.is_expandable(): raise ValueError('Data cannot be expanded.') - struct = misc.unwrap_optional(self.structure(add_shape=False)) + struct = self._build_structure() ret = DataDict(**struct) if self.is_expanded(): @@ -1011,14 +1066,15 @@ def remove_invalid_entries(self) -> 'DataDict': datavals = self.data_vals(d) rows = datavals.reshape(-1, int(np.prod(ishp[d]))) - _idxs: np.ndarray = np.array([]) + _idxs_parts: list = [] # get indices of all rows that are fully None if len(ishp[d]) == 0: _newidxs = np.atleast_1d(np.asarray(rows is None)).nonzero()[0] else: _newidxs = np.atleast_1d(np.asarray(np.all(rows is None, axis=-1))).nonzero()[0] - _idxs = np.append(_idxs, _newidxs) + if _newidxs.size > 0: + _idxs_parts.append(_newidxs) # get indices for all rows that are fully NaN. works only # for some dtypes, so except TypeErrors. @@ -1027,15 +1083,16 @@ def remove_invalid_entries(self) -> 'DataDict': _newidxs = np.where(np.isnan(rows))[0] else: _newidxs = np.where(np.all(np.isnan(rows), axis=-1))[0] - _idxs = np.append(_idxs, _newidxs) + if _newidxs.size > 0: + _idxs_parts.append(_newidxs) except TypeError: pass + _idxs = np.concatenate(_idxs_parts) if _idxs_parts else np.array([], dtype=int) idxs.append(_idxs) if len(idxs) > 0: - remove_idxs = reduce(np.intersect1d, - tuple(np.array(idxs).astype(int))) + remove_idxs = reduce(np.intersect1d, idxs) for k, v in self.data_items(): v['values'] = np.delete(v['values'], remove_idxs, axis=0) @@ -1120,17 +1177,24 @@ def validate(self) -> bool: try: if axis_data.shape[axis_num] > 1: - steps = np.unique(np.sign(np.diff(axis_data, axis=axis_num))) - - # for incomplete data, there maybe nan steps -- we need to remove those, - # doesn't mean anything is wrong. 
- steps = steps[~np.isnan(steps)] - - if 0 in steps: - msg += (f"Malformed data: {na} is expected to be {axis_num}th " - "axis but has no variation along that axis.\n") - if steps.size > 1: - msg += (f"Malformed data: axis {na} is not monotonous.\n") + diffs: np.ndarray = np.diff(axis_data, axis=axis_num) + + # for incomplete data, there may be nan steps -- we need to + # ignore those, doesn't mean anything is wrong. + if np.issubdtype(diffs.dtype, np.floating): + nan_mask = np.isnan(diffs) + if np.all(nan_mask): + continue # all NaN, can't check + valid: np.ndarray = diffs[~nan_mask] + else: + valid = diffs.ravel() + + if valid.size > 0: + if np.any(valid == 0): + msg += (f"Malformed data: {na} is expected to be {axis_num}th " + "axis but has no variation along that axis.\n") + if not (np.all(valid > 0) or np.all(valid < 0)): + msg += (f"Malformed data: axis {na} is not monotonous.\n") # can happen if we have bad shapes. but that should already have been caught. except IndexError: @@ -1214,7 +1278,7 @@ def _mesh_mean(data: MeshgridDataDict, ax: str) -> MeshgridDataDict: :return: averaged data """ iax = data.axes().index(ax) - new_data = data.structure(remove_data=[ax]) + new_data = data._build_structure(remove_data=[ax]) assert isinstance(new_data, MeshgridDataDict) for d, v in data.data_items(): @@ -1237,7 +1301,7 @@ def _mesh_slice(data: MeshgridDataDict, **kwargs: Dict[str, Union[slice, int]]) for ax, val in kwargs.items(): i = data.axes().index(ax) slices[i] = val - ret = data.structure() + ret = data._build_structure() assert isinstance(ret, MeshgridDataDict) for d, _ in data.data_items(): @@ -1329,7 +1393,7 @@ def datadict_to_meshgrid(data: DataDict, inner_axis_order, target_shape = ret # construct new data - newdata = MeshgridDataDict(**misc.unwrap_optional(data.structure(add_shape=False))) + newdata = MeshgridDataDict(**data._build_structure()) axlist = data.axes(data.dependents()[0]) for k, v in data.data_items(): @@ -1356,9 +1420,9 @@ def 
meshgrid_to_datadict(data: MeshgridDataDict) -> DataDict: :param data: Input ``MeshgridDataDict``. :return: Flattened ``DataDict``. """ - newdata = DataDict(**misc.unwrap_optional(data.structure(add_shape=False))) + newdata = DataDict(**data._build_structure()) for k, v in data.data_items(): - val = v['values'].copy().reshape(-1) + val = v['values'].ravel().copy() newdata[k]['values'] = val newdata = newdata.sanitize() @@ -1412,6 +1476,7 @@ def combine_datadicts(*dicts: DataDict) -> Union[DataDictBase, DataDict]: ret: Union[DataDictBase, None] = None rettype: Union[type[DataDictBase], None] = None + for d in dicts: if ret is None: ret = d.copy() @@ -1605,48 +1670,38 @@ def datasets_are_equal(a: DataDictBase, b: DataDictBase, return False if not ignore_meta: - # are all meta data of a also in b, and are they the same value? - for k, v in a.meta_items(): - if k not in [kk for kk, vv in b.meta_items()]: - return False - elif b.meta_val(k) != v: - return False - - # are all meta data of b also in a? - for k, v in b.meta_items(): - if k not in [kk for kk, vv in a.meta_items()]: + a_meta = dict(a.meta_items()) + b_meta = dict(b.meta_items()) + if a_meta.keys() != b_meta.keys(): + return False + for k, v in a_meta.items(): + if b_meta[k] != v: return False - # check all data fields in a - for dn, dv in a.data_items(): + # check all data fields + a_fields = set(dn for dn, _ in a.data_items()) + b_fields = set(dn for dn, _ in b.data_items()) + if a_fields != b_fields: + return False - # are all fields also present in b? 
- if dn not in [dnn for dnn, dvv in b.data_items()]: - return False + for dn in a_fields: + a_vals = a.data_vals(dn) + b_vals = b.data_vals(dn) - # check if data is equal - if not num.arrays_equal( - np.array(a.data_vals(dn)), - np.array(b.data_vals(dn)), - ): + # fast shape check before expensive value comparison + if np.shape(a_vals) != np.shape(b_vals): return False - if not ignore_meta: - # check meta data - for k, v in a.meta_items(dn): - if k not in [kk for kk, vv in b.meta_items(dn)]: - return False - elif v != b.meta_val(k, dn): - return False - - # only thing left to check is whether there are items in b but not a - for dn, dv in b.data_items(): - if dn not in [dnn for dnn, dvv in a.data_items()]: + if not num.arrays_equal(np.asarray(a_vals), np.asarray(b_vals)): return False if not ignore_meta: - for k, v in b.meta_items(dn): - if k not in [kk for kk, vv in a.meta_items(dn)]: + a_fmeta = dict(a.meta_items(dn)) + b_fmeta = dict(b.meta_items(dn)) + if a_fmeta.keys() != b_fmeta.keys(): + return False + for k, v in a_fmeta.items(): + if v != b_fmeta[k]: return False return True @@ -1682,14 +1737,14 @@ def datadict_to_dataframe(data: DataDict) -> pd.DataFrame: # if the dimension of all variables are the same, directly flat the array if dimension_check: for key, value in data.data_items(): - data_set[key] = (data.data_vals(key)).flatten() + data_set[key] = (data.data_vals(key)).ravel() # if the dimension is different between variables, match their dimension to the highest one else: for key, value in data.data_items(): repeated_time = int(max_ele/np.size(data.data_vals(key))) value_array = np.repeat(data.data_vals(key), repeated_time) - data_set[key] = value_array.flatten('F') + data_set[key] = value_array.ravel(order='F') # convert organized data to DataFrame and return it return pd.DataFrame(data=data_set) diff --git a/plottr/data/qcodes_dataset.py b/plottr/data/qcodes_dataset.py index cd497575..8b6911fd 100644 --- a/plottr/data/qcodes_dataset.py +++ 
b/plottr/data/qcodes_dataset.py @@ -6,6 +6,7 @@ import os import sys from contextlib import closing +from datetime import datetime from itertools import chain from operator import attrgetter from typing import Dict, List, Set, Union, TYPE_CHECKING, Any, Tuple, Optional, cast @@ -17,6 +18,7 @@ from qcodes.dataset.data_set import load_by_id from qcodes.dataset.experiment_container import experiments from qcodes.dataset.sqlite.database import conn_from_dbpath_or_conn, initialise_or_create_database_at +from qcodes.dataset.sqlite.queries import get_last_run from .datadict import DataDictBase, DataDict, combine_datadicts from ..node.node import Node, updateOption @@ -42,6 +44,24 @@ def _get_names_of_standalone_parameters(paramspecs: List['ParamSpec'] return standalones +def _split_timestamp(ts: Optional[str]) -> Tuple[str, str]: + """Split a qcodes timestamp string into (date, time) components. + + Uses datetime parsing instead of string slicing for robustness. + + :param ts: timestamp string as returned by ``ds.run_timestamp()`` + (typically ``"YYYY-MM-DD HH:MM:SS"``), or None. + :returns: (date_str, time_str) or ('', '') if ts is None or unparsable. + """ + if ts is None: + return '', '' + try: + dt = datetime.fromisoformat(ts) + return dt.strftime('%Y-%m-%d'), dt.strftime('%H:%M:%S') + except (ValueError, TypeError): + return '', '' + + class IndependentParameterDict(TypedDict): unit: str label: str @@ -125,20 +145,10 @@ def get_ds_info(ds: 'DataSetProtocol', get_structure: bool = True) -> DataSetInf as well (key is `structure' then). 
""" _complete_ts = ds.completed_timestamp() - if _complete_ts is not None: - completed_date = _complete_ts[:10] - completed_time = _complete_ts[11:] - else: - completed_date = '' - completed_time = '' + completed_date, completed_time = _split_timestamp(_complete_ts) _start_ts = ds.run_timestamp() - if _start_ts is not None: - started_date = _start_ts[:10] - started_time = _start_ts[11:] - else: - started_date = '' - started_time = '' + started_date, started_time = _split_timestamp(_start_ts) if get_structure: structure: Optional[DataSetStructureDict] = get_ds_structure(ds) @@ -222,6 +232,66 @@ def get_runs_from_db_as_dataframe(path: str) -> pd.DataFrame: return df +def _ds_to_info_dict(ds: 'DataSetProtocol') -> DataSetInfoDict: + """Extract inspectr-relevant info from a dataset without loading data or snapshot.""" + started_date, started_time = _split_timestamp(ds.run_timestamp()) + completed_date, completed_time = _split_timestamp(ds.completed_timestamp()) + return DataSetInfoDict( + experiment=ds.exp_name, + sample=ds.sample_name, + name=ds.name, + started_date=started_date, + started_time=started_time, + completed_date=completed_date, + completed_time=completed_time, + structure=None, + records=ds.number_of_results, + guid=ds.guid, + inspectr_tag=ds.metadata.get('inspectr_tag', ''), + ) + + +def get_runs_from_db_fast(path: str, + start_run_id: int = 1, + progress_callback: Optional[Any] = None, + ) -> Dict[int, DataSetInfoDict]: + """Fast alternative to ``get_runs_from_db`` that avoids the expensive + ``experiments()`` + ``data_sets()`` enumeration. + + Uses ``load_by_id`` directly for each run_id, which is O(1) per run + instead of O(N) for the experiment/dataset iteration approach. + + :param path: path to the qcodes .db file. + :param start_run_id: first run_id to load (inclusive). Use for incremental + loading: pass the last known run_id + 1 to load only new runs. + :param progress_callback: optional callable(current, total) for progress. 
+ :returns: dictionary mapping run_id to dataset info. + """ + if sys.version_info >= (3, 11): + conn = conn_from_dbpath_or_conn(conn=None, path_to_db=path, read_only=True) + else: + conn = conn_from_dbpath_or_conn(conn=None, path_to_db=path) + + overview: Dict[int, DataSetInfoDict] = {} + with closing(conn) as conn_: + last = get_last_run(conn_) + if last is None: + return overview + + total = last - start_run_id + 1 + for i, run_id in enumerate(range(start_run_id, last + 1)): + try: + ds = load_by_id(run_id, conn=conn_) + overview[run_id] = _ds_to_info_dict(ds) + except Exception: + pass # skip missing/corrupt runs + + if progress_callback is not None and (i % 10 == 0 or i == total - 1): + progress_callback(i + 1, total) + + return overview + + # Extracting data def ds_to_datadicts(ds: 'DataSetProtocol') -> Dict[str, DataDict]: diff --git a/plottr/data/qcodes_db_overview.py b/plottr/data/qcodes_db_overview.py new file mode 100644 index 00000000..0b965561 --- /dev/null +++ b/plottr/data/qcodes_db_overview.py @@ -0,0 +1,190 @@ +""" +plottr.data.qcodes_db_overview — Fast database overview queries. + +This module provides optimized functions for listing QCoDeS dataset metadata +without loading full DataSet objects. It uses direct SQLite queries on the +QCoDeS database schema, avoiding the expensive experiments()/data_sets() +enumeration. + +**Intended for eventual contribution to QCoDeS.** The queries here rely on the +stable QCoDeS database schema (runs + experiments tables) which has not changed +across many QCoDeS versions. +""" +import json +import sys +import time +import logging +from contextlib import closing +from typing import Dict, Optional, Tuple + +from typing_extensions import TypedDict + +from qcodes.dataset.sqlite.database import conn_from_dbpath_or_conn + +logger = logging.getLogger(__name__) + + +def _records_from_run_description(run_description_json: Optional[str]) -> int: + """Extract record count from run_description shapes field. 
+ + QCoDeS run_description may contain a ``shapes`` dict mapping dependent + parameter names to their shape tuples. The total data-point count is the + product of shape dimensions summed across all parameter trees — matching + the semantics of ``DataSet.number_of_results``. + """ + if not run_description_json: + return 0 + try: + desc = json.loads(run_description_json) + shapes = desc.get('shapes') + if not shapes: + return 0 + total = 0 + for shape in shapes.values(): + if isinstance(shape, (list, tuple)) and len(shape) > 0: + n = 1 + for dim in shape: + n *= dim + # Each parameter tree contributes n_values * n_params_in_tree + # But shapes only has dependent params, and number_of_results + # counts all values including axes. For display purposes, + # the product of the shape is the most useful number. + total += n + return total + except (json.JSONDecodeError, TypeError, KeyError): + return 0 + + +class RunOverviewDict(TypedDict): + """Lightweight run overview — no snapshot, no data, no full DataSet.""" + run_id: int + experiment: str + sample: str + name: str + started_date: str + started_time: str + completed_date: str + completed_time: str + records: int + guid: str + inspectr_tag: str + + +def _format_timestamp(ts: Optional[float]) -> Tuple[str, str]: + """Convert a unix timestamp float to (date, time) strings.""" + if ts is None or ts == 0: + return '', '' + try: + t = time.localtime(ts) + return time.strftime('%Y-%m-%d', t), time.strftime('%H:%M:%S', t) + except (OSError, ValueError, OverflowError): + return '', '' + + +def get_db_overview(db_path: str, + start_run_id: int = 0, + ) -> Dict[int, RunOverviewDict]: + """Get a lightweight overview of all runs in a QCoDeS database. + + Uses a single SQL JOIN query to fetch run metadata from the ``runs`` and + ``experiments`` tables, avoiding the expensive ``experiments()`` + + ``data_sets()`` enumeration that QCoDeS uses internally. 
+ + For a database with 1500 runs, this completes in ~10ms vs 15+ minutes + with the standard QCoDeS API. + + :param db_path: path to the .db file. + :param start_run_id: only return runs with run_id > start_run_id. + Use 0 to get all runs. Pass the last known run_id for incremental + refresh. + :returns: dict mapping run_id to RunOverviewDict. + """ + overview: Dict[int, RunOverviewDict] = {} + + if sys.version_info >= (3, 11): + conn = conn_from_dbpath_or_conn(conn=None, path_to_db=db_path, read_only=True) + else: + conn = conn_from_dbpath_or_conn(conn=None, path_to_db=db_path) + + with closing(conn) as c: + # Check which ad-hoc metadata columns exist in the runs table. + # QCoDeS stores metadata added via ds.add_metadata() as extra columns. + try: + col_info = c.execute('PRAGMA table_info(runs)').fetchall() + col_names = {col[1] for col in col_info} + except Exception: + col_names = set() + + has_inspectr_tag = 'inspectr_tag' in col_names + + # Build query: include inspectr_tag column if it exists. + # Includes run_description to extract shape info for record count. + # Deliberately excludes snapshot (large blob). + tag_col = ", r.inspectr_tag" if has_inspectr_tag else "" + query = f""" + SELECT r.run_id, e.name, e.sample_name, r.name, + r.run_timestamp, r.completed_timestamp, + r.result_counter, r.guid, r.result_table_name, + r.run_description{tag_col} + FROM runs r + JOIN experiments e ON r.exp_id = e.exp_id + WHERE r.run_id > ? + ORDER BY r.run_id + """ + + try: + rows = c.execute(query, (start_run_id,)).fetchall() + except Exception as e: + logger.warning(f"Could not query database overview: {e}") + return overview + + # Build a map of actual row counts from each results table. + # result_counter in the runs table counts INSERT calls, not data points. + # For array paramtype one INSERT can contain thousands of data points, + # so result_counter can be much smaller than the real data point count. 
+ results_tables: set[str] = set() + for row in rows: + tbl = row[8] # result_table_name + if tbl: + results_tables.add(tbl) + row_counts: dict[str, int] = {} + for tbl in results_tables: + try: + cnt = c.execute( + f'SELECT COUNT(*) FROM "{tbl}"' + ).fetchone() + row_counts[tbl] = cnt[0] if cnt else 0 + except Exception: + pass # table may not exist (e.g., qdwsdk downloads) + + tag_col_idx = 10 if has_inspectr_tag else -1 + for row in rows: + run_id = row[0] + started_date, started_time = _format_timestamp(row[4]) + completed_date, completed_time = _format_timestamp(row[5]) + tag = row[tag_col_idx] if tag_col_idx > 0 and len(row) > tag_col_idx and row[tag_col_idx] else '' + result_table = row[8] or '' + + # Determine record count: prefer results table row count, + # then try shape info from run_description, then result_counter. + records = row_counts.get(result_table, 0) + if records == 0: + records = _records_from_run_description(row[9]) + if records == 0: + records = row[6] or 0 + + overview[run_id] = RunOverviewDict( + run_id=run_id, + experiment=row[1] or '', + sample=row[2] or '', + name=row[3] or '', + started_date=started_date, + started_time=started_time, + completed_date=completed_date, + completed_time=completed_time, + records=records, + guid=row[7] or '', + inspectr_tag=tag, + ) + + return overview diff --git a/plottr/gui/data_display.py b/plottr/gui/data_display.py index 3ff24813..c688b6e2 100644 --- a/plottr/gui/data_display.py +++ b/plottr/gui/data_display.py @@ -26,10 +26,15 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None, self._dataStructure = DataDictBase() self._dataShapes: Dict[str, Tuple[int, ...]] = {} self._readonly = readonly + self._batchUpdate = False self.setSelectionMode(self.MultiSelection) self.itemSelectionChanged.connect(self.emitSelection) + def _ndims(self, name: str) -> int: + """Return the number of independent axes for a dependent field.""" + return len(self._dataStructure.axes(name)) + def _makeItem(self, 
name: str) -> QtWidgets.QTreeWidgetItem: shape = self._dataShapes.get(name, tuple()) label = f"{self._dataStructure.label(name)}" @@ -111,6 +116,49 @@ def setSelectedData(self, vals: List[str]) -> None: for n, w in self.dataItems.items(): w.setSelected(n in vals) + def setBatchSelectedData(self, vals: List[str]) -> None: + """Batch-select items with a single signal emission. + + Used by select-all / 1D / 2D buttons to avoid per-item replot. + """ + if self._batchUpdate: + return + self._batchUpdate = True + try: + self.blockSignals(True) + for n, w in self.dataItems.items(): + w.setSelected(n in vals) + self.blockSignals(False) + self.dataSelectionMade.emit(self.getSelectedData()) + finally: + self._batchUpdate = False + + def selectAll(self) -> None: + """Select all enabled dependent fields. Single signal emission.""" + enabled = [n for n, w in self.dataItems.items() if not w.isDisabled()] + self.setBatchSelectedData(enabled) + + def selectFirst(self) -> None: + """Select only the first dependent (default view).""" + deps = list(self.dataItems.keys()) + self.setBatchSelectedData(deps[:1] if deps else []) + + def selectByNdims(self, ndims: int) -> None: + """Select all dependents with exactly *ndims* independent axes. + Resets any existing selection. 
Single signal emission.""" + matching = [n for n in self._dataStructure.dependents() + if self._ndims(n) == ndims + and n in self.dataItems + and not self.dataItems[n].isDisabled()] + self.setBatchSelectedData(matching) + + def has_dependents_with_ndims(self, ndims: int) -> bool: + """Check if the dataset has any dependent with exactly *ndims* axes.""" + for n in self._dataStructure.dependents(): + if self._ndims(n) == ndims: + return True + return False + def emitSelection(self) -> None: """emit the signal ``selectionChanged`` with the current selection""" self.dataSelectionMade.emit(self.getSelectedData()) diff --git a/plottr/node/autonode.py b/plottr/node/autonode.py index 734b10c0..ae2df83d 100644 --- a/plottr/node/autonode.py +++ b/plottr/node/autonode.py @@ -64,6 +64,7 @@ def addOption(self, name: str, specs: Dict[str, Any], confirm: bool) -> None: if optionType in self.widgetConnection.keys() else None ) + if func is not None: widget = func(self, name, specs, confirm) layout = cast(QtWidgets.QFormLayout, self.layout()) diff --git a/plottr/node/data_selector.py b/plottr/node/data_selector.py index 082d486a..06b21d6d 100644 --- a/plottr/node/data_selector.py +++ b/plottr/node/data_selector.py @@ -11,6 +11,7 @@ from ..data.datadict import DataDictBase, DataDict from ..gui.data_display import DataSelectionWidget from plottr.icons import get_dataColumnsIcon +from .. 
import QtWidgets from ..utils import num __author__ = 'Wolfgang Pfaff' @@ -36,6 +37,55 @@ def __init__(self, node: Optional[Node] = None): self.widget.dataSelectionMade.connect( lambda x: self.signalOption('selectedData')) + # Selection buttons + btnLayout = QtWidgets.QHBoxLayout() + btnLayout.setContentsMargins(0, 0, 0, 0) + btnLayout.setSpacing(4) + + self._selectAllBtn = QtWidgets.QPushButton("Select all") + self._selectAllBtn.clicked.connect(self._onSelectAll) + btnLayout.addWidget(self._selectAllBtn) + + self._selectFirstBtn = QtWidgets.QPushButton("Select first only") + self._selectFirstBtn.clicked.connect(self._onSelectFirst) + btnLayout.addWidget(self._selectFirstBtn) + + self._select1dBtn = QtWidgets.QPushButton("Select all 1D") + self._select1dBtn.clicked.connect(self._onSelect1D) + btnLayout.addWidget(self._select1dBtn) + + self._select2dBtn = QtWidgets.QPushButton("Select all 2D") + self._select2dBtn.clicked.connect(self._onSelect2D) + btnLayout.addWidget(self._select2dBtn) + + btnLayout.addStretch() + + layout = self.layout() + assert isinstance(layout, QtWidgets.QVBoxLayout) + layout.addLayout(btnLayout) + + def _onSelectAll(self) -> None: + assert self.widget is not None + self.widget.selectAll() + + def _onSelectFirst(self) -> None: + assert self.widget is not None + self.widget.selectFirst() + + def _onSelect1D(self) -> None: + assert self.widget is not None + self.widget.selectByNdims(1) + + def _onSelect2D(self) -> None: + assert self.widget is not None + self.widget.selectByNdims(2) + + def _updateDimButtons(self) -> None: + """Show/hide 1D/2D buttons based on what dimensions exist in the data.""" + assert self.widget is not None + self._select1dBtn.setVisible(self.widget.has_dependents_with_ndims(1)) + self._select2dBtn.setVisible(self.widget.has_dependents_with_ndims(2)) + def setSelected(self, vals: Sequence[str]) -> None: assert self.widget is not None self.widget.setSelectedData(vals) @@ -49,6 +99,7 @@ def setData(self, structure: 
DataDictBase, shapes: Dict[str, Tuple[int, ...]], _: Any) -> None: assert self.widget is not None self.widget.setData(structure, shapes) + self._updateDimButtons() def setShape(self, shapes: Dict[str, Tuple[int, ...]]) -> None: assert self.widget is not None diff --git a/plottr/node/dim_reducer.py b/plottr/node/dim_reducer.py index 2cf35e3e..611f0745 100644 --- a/plottr/node/dim_reducer.py +++ b/plottr/node/dim_reducer.py @@ -898,7 +898,7 @@ def process( return None dataout = data['dataOut'] assert dataout is not None - data = dataout.copy() + data = dataout # parent DimensionReducer.process() already copied if self._xyAxes[0] is not None and self._xyAxes[1] is not None: _kw = {self._xyAxes[0]: 0, self._xyAxes[1]: 1} diff --git a/plottr/node/grid.py b/plottr/node/grid.py index 88213aca..8d672abc 100644 --- a/plottr/node/grid.py +++ b/plottr/node/grid.py @@ -482,16 +482,18 @@ def process( if method is GridOption.noGrid: dout = data.expand() elif method is GridOption.guessShape: - dout = dd.datadict_to_meshgrid(data) + dout = dd.datadict_to_meshgrid(data, copy=False) elif method is GridOption.specifyShape: dout = dd.datadict_to_meshgrid( data, target_shape=opts['shape'], inner_axis_order=order, + copy=False, ) elif method is GridOption.metadataShape: try: dout = dd.datadict_to_meshgrid( - data, use_existing_shape=True + data, use_existing_shape=True, + copy=False, ) except ValueError as err: if "Malformed data" in str(err): @@ -499,7 +501,7 @@ def process( "Shape/Setpoint order does" " not match data. 
Falling back to guessing shape" ) - dout = dd.datadict_to_meshgrid(data) + dout = dd.datadict_to_meshgrid(data, copy=False) else: raise err except GriddingError: diff --git a/plottr/node/node.py b/plottr/node/node.py index 99d1d510..52c8ec6c 100644 --- a/plottr/node/node.py +++ b/plottr/node/node.py @@ -279,7 +279,6 @@ def process(self, dataIn: Optional[DataDictBase]=None) -> Optional[Dict[str, Opt daxes = dataIn.axes() ddeps = dataIn.dependents() dshapes = dataIn.shapes() - dstruct = dataIn.structure(add_shape=False) if None in [self.dataAxes, self.dataDependents, self.dataType, self.dataShapes]: _axesChanged = True @@ -311,7 +310,10 @@ def process(self, dataIn: Optional[DataDictBase]=None) -> Optional[Dict[str, Opt self.dataDependents = ddeps self.dataType = dtype self.dataShapes = dshapes - self.dataStructure = dstruct + + # Only compute structure snapshot when it actually changed + if _structChanged: + self.dataStructure = dataIn._build_structure() if _axesChanged: self.dataAxesChanged.emit(daxes) diff --git a/plottr/node/scaleunits.py b/plottr/node/scaleunits.py index 04dfe6c8..ca0654b6 100644 --- a/plottr/node/scaleunits.py +++ b/plottr/node/scaleunits.py @@ -1,7 +1,16 @@ from enum import Enum, unique from typing import Dict, Optional -from qcodes.plotting import find_scale_and_prefix +try: + from qcodes.plotting.axis_labels import find_scale_and_prefix +except ImportError: + try: + # fallback for qcodes < 0.46 where the function lived under utils + from qcodes.utils.plotting import find_scale_and_prefix # type: ignore[import-not-found, no-redef] + except ImportError: + # fallback when qcodes is not installed (it is an optional dependency) + from plottr.utils.find_scale_and_prefix import find_scale_and_prefix # type: ignore[no-redef] + from plottr import QtWidgets, Signal, Slot from plottr.data.datadict import DataDictBase diff --git a/plottr/plot/base.py b/plottr/plot/base.py index 8b7217d8..9c2ec07a 100644 --- a/plottr/plot/base.py +++ b/plottr/plot/base.py 
@@ -5,7 +5,7 @@ from collections import OrderedDict from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, replace as dc_replace from enum import Enum, unique, auto from types import TracebackType from typing import Dict, List, Type, Tuple, Optional, Any, \ @@ -453,7 +453,9 @@ def _splitComplexData(self, plotItem: PlotItem) -> List[PlotItem]: re_label, im_label = label + ' (Real)', label + ' (Imag)' re_plotItem = plotItem - im_plotItem = deepcopy(re_plotItem) + im_plotItem = dc_replace(re_plotItem, + data=list(re_plotItem.data), + labels=list(re_plotItem.labels) if re_plotItem.labels else None) re_plotItem.data[-1] = re_data im_plotItem.data[-1] = im_data @@ -485,7 +487,9 @@ def _splitComplexData(self, plotItem: PlotItem) -> List[PlotItem]: mag_label, phase_label = label + ' 20*log10(Mag)', label + ' (Phase)' mag_plotItem = plotItem - phase_plotItem = deepcopy(mag_plotItem) + phase_plotItem = dc_replace(mag_plotItem, + data=list(mag_plotItem.data), + labels=list(mag_plotItem.labels) if mag_plotItem.labels else None) mag_plotItem.data[-1] = mag_data phase_plotItem.data[-1] = phase_data @@ -514,7 +518,9 @@ def _splitComplexData(self, plotItem: PlotItem) -> List[PlotItem]: mag_label, phase_label = label + ' (Mag)', label + ' (Phase)' mag_plotItem = plotItem - phase_plotItem = deepcopy(mag_plotItem) + phase_plotItem = dc_replace(mag_plotItem, + data=list(mag_plotItem.data), + labels=list(mag_plotItem.labels) if mag_plotItem.labels else None) mag_plotItem.data[-1] = mag_data phase_plotItem.data[-1] = phase_data diff --git a/plottr/plot/mpl/autoplot.py b/plottr/plot/mpl/autoplot.py index 387bec37..2305fbf5 100644 --- a/plottr/plot/mpl/autoplot.py +++ b/plottr/plot/mpl/autoplot.py @@ -125,6 +125,7 @@ def plotLine(self, plotItem: PlotItem) -> Optional[List[ScalarMappable]]: assert plotItem.plotOptions is not None return axes[0].plot(x, y, label=lbl, **plotItem.plotOptions) + def plotImage(self, plotItem: PlotItem) -> 
Optional[ScalarMappable]: assert len(plotItem.data) == 3 x, y, z = plotItem.data @@ -154,6 +155,9 @@ class AutoPlotToolBar(QtWidgets.QToolBar): #: signal emitted when the complex data option has been changed complexRepresentationSelected = Signal(ComplexRepresentation) + #: signal emitted when the colormap has been changed + cmapChanged = Signal(str) + def __init__(self, name: str, parent: Optional[QtWidgets.QWidget] = None): """Constructor for :class:`AutoPlotToolBar`""" @@ -229,6 +233,38 @@ def __init__(self, name: str, parent: Optional[QtWidgets.QWidget] = None): ComplexRepresentation.magAndPhase: self.plotMagPhase }) + self.addSeparator() + self.scrollableAction = self.addAction('Scrollable') + self.scrollableAction.setCheckable(True) + self.scrollableAction.setChecked(False) + self.scrollableAction.setToolTip('Enable scrollable plot area for many subplots') + + self.minHeightSpin = QtWidgets.QSpinBox() + self.minHeightSpin.setRange(40, 2000) + self.minHeightSpin.setValue(100) + self.minHeightSpin.setSuffix(" px") + self.minHeightSpin.setToolTip("Minimum height per subplot row") + self.minHeightSpin.setEnabled(False) + self.addWidget(self.minHeightSpin) + + self.scrollableAction.triggered.connect( + lambda: self.minHeightSpin.setEnabled(self.scrollableAction.isChecked()) + ) + + # Colormap selector + self.addSeparator() + self._cmapLabel = QtWidgets.QLabel(" Colormap: ") + self.addWidget(self._cmapLabel) + self.cmapCombo = QtWidgets.QComboBox() + self.cmapCombo.setToolTip("Select colormap for 2D plots") + self.cmapCombo.setSizeAdjustPolicy( + QtWidgets.QComboBox.AdjustToContents) + self._populateColormaps() + self.addWidget(self.cmapCombo) + + #: signal emitted when the colormap has been changed + self.cmapCombo.currentTextChanged.connect(self._onCmapChanged) + self._currentPlotType = PlotType.empty self._currentlyAllowedPlotTypes: Tuple[PlotType, ...] 
= () @@ -236,6 +272,30 @@ def __init__(self, name: str, parent: Optional[QtWidgets.QWidget] = None): self.ComplexActions[self._currentComplex].setChecked(True) self._currentlyAllowedComplexTypes: Tuple[ComplexRepresentation, ...] = () + def _populateColormaps(self) -> None: + """Fill the colormap combo box with matplotlib's available colormaps.""" + import matplotlib as mpl + # Curated list of popular colormaps first, then all others + popular = ['viridis', 'magma', 'inferno', 'plasma', 'cividis', + 'coolwarm', 'RdBu_r', 'RdYlBu_r', 'Spectral_r', + 'hot', 'bone', 'gray'] + all_cmaps = sorted(mpl.colormaps()) + # Put popular ones first, then the rest (no duplicates) + ordered = [c for c in popular if c in all_cmaps] + ordered += [c for c in all_cmaps if c not in ordered and not c.endswith('_r')] + self.cmapCombo.addItems(ordered) + # Set current to the matplotlib default + default = mpl.rcParams.get('image.cmap', 'viridis') + idx = self.cmapCombo.findText(default) + if idx >= 0: + self.cmapCombo.setCurrentIndex(idx) + + def _onCmapChanged(self, name: str) -> None: + """Update the matplotlib RC param and signal a replot.""" + import matplotlib as mpl + mpl.rcParams['image.cmap'] = name + self.cmapChanged.emit(name) + def selectPlotType(self, plotType: PlotType) -> None: """makes sure that the selected `plotType` is active (checked), all others are not active. @@ -353,6 +413,7 @@ def __init__(self, parent: Optional[PlotWidgetContainer] = None): self.plotDataType = PlotDataType.unknown self.plotType = PlotType.empty + self._inSetData = False # The default complex behavior is set here. 
self.complexRepresentation = ComplexRepresentation.realAndImag @@ -368,6 +429,13 @@ def __init__(self, parent: Optional[PlotWidgetContainer] = None): self.plotOptionsToolBar.complexRepresentationSelected.connect( self._complexPreferenceFromToolBar ) + self.plotOptionsToolBar.scrollableAction.triggered.connect( + self._scrollableFromToolBar + ) + self.plotOptionsToolBar.minHeightSpin.editingFinished.connect( + self._scrollableFromToolBar + ) + self.plotOptionsToolBar.cmapChanged.connect(self._cmapFromToolBar) scaling = dpiScalingFactor(self) iconSize = int(36 + 8*(scaling - 1)) @@ -384,9 +452,17 @@ def setData(self, data: Optional[DataDictBase]) -> None: :param data: input data """ super().setData(data) + if data is None: + self.plot.fig.clear() + self.updatePlot() + return self.plotDataType = determinePlotDataType(data) + # Flag to suppress redundant _plotData calls from toolbar signals + # triggered by _processPlotTypeOptions / _processComplexTypeOptions. + self._inSetData = True self._processPlotTypeOptions() self._processComplexTypeOptions() + self._inSetData = False self._plotData() def _processPlotTypeOptions(self) -> None: @@ -429,14 +505,27 @@ def _processComplexTypeOptions(self) -> None: def _plotTypeFromToolBar(self, plotType: PlotType) -> None: if plotType is not self.plotType: self.plotType = plotType - self._plotData() + if not self._inSetData: + self._plotData() @Slot(ComplexRepresentation) def _complexPreferenceFromToolBar(self, complexRepresentation: ComplexRepresentation) -> None: if complexRepresentation is not self.complexRepresentation: self.complexRepresentation = complexRepresentation + if not self._inSetData: + self._plotData() + + @Slot(str) + def _cmapFromToolBar(self, _cmap: str) -> None: + if not self._inSetData: self._plotData() + @Slot() + def _scrollableFromToolBar(self) -> None: + scrollable = self.plotOptionsToolBar.scrollableAction.isChecked() + self.setScrollable(scrollable) + self._plotData() + def _plotData(self) -> None: """Plot 
the data using previously determined data and plot types.""" @@ -466,5 +555,16 @@ def _plotData(self) -> None: plotDataType=self.plotDataType, **kw) + nSubPlots = fm.nSubPlots() + + # Set canvas minimum height for scrollable mode + scrollable = self.plotOptionsToolBar.scrollableAction.isChecked() + if scrollable and nSubPlots > 2: + nrows = int(nSubPlots ** 0.5 + 0.5) + min_h = self.plotOptionsToolBar.minHeightSpin.value() + self.plot.setMinimumHeight(max(nrows * min_h, 400)) + else: + self.plot.setMinimumHeight(0) + self.setMeta(self.data) self.updatePlot() diff --git a/plottr/plot/mpl/widgets.py b/plottr/plot/mpl/widgets.py index 0f4c3227..fbf6ba12 100644 --- a/plottr/plot/mpl/widgets.py +++ b/plottr/plot/mpl/widgets.py @@ -159,11 +159,22 @@ def __init__(self, parent: Optional[PlotWidgetContainer] = None): self.addMplBarOptions() defaultIconSize = int(16 * dpiScalingFactor(self)) self.mplBar.setIconSize(QtCore.QSize(defaultIconSize, defaultIconSize)) + + #: scroll area for the canvas (enabled by default) + self._scrollArea = QtWidgets.QScrollArea() + self._scrollArea.setWidgetResizable(True) + self._scrollArea.setWidget(self.plot) + layout = QtWidgets.QVBoxLayout(self) - layout.addWidget(self.plot) + layout.addWidget(self._scrollArea) layout.addWidget(self.mplBar) self.setLayout(layout) + def setScrollable(self, scrollable: bool) -> None: + """Enable or disable scrollable canvas for many subplots.""" + if not scrollable: + self.plot.setMinimumHeight(0) + def setMeta(self, data: DataDictBase) -> None: """Add meta info contained in the data to the figure. 
diff --git a/plottr/plot/pyqtgraph/autoplot.py b/plottr/plot/pyqtgraph/autoplot.py index 31b9e054..8f740bc6 100644 --- a/plottr/plot/pyqtgraph/autoplot.py +++ b/plottr/plot/pyqtgraph/autoplot.py @@ -20,6 +20,7 @@ from plottr import QtWidgets, QtCore, Signal, Slot, \ config_entry as getcfg from plottr.data.datadict import DataDictBase +from plottr.utils.latex import latex_to_html from .plots import Plot, PlotWithColorbar, PlotBase from ..base import AutoFigureMaker as BaseFM, PlotDataType, \ PlotItem, ComplexRepresentation, determinePlotDataType, \ @@ -33,7 +34,8 @@ class FigureWidget(QtWidgets.QWidget): """Widget that contains all plots generated by :class:`.FigureMaker`. - Widget has a vertical layout, and plots can be added in a single column. + Plots are arranged on a near-square grid (like matplotlib's GridSpec), + so that many subplots remain readable. """ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): @@ -44,18 +46,27 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): super().__init__(parent=parent) self.subPlots: List[PlotBase] = [] + self._minPlotHeight: int = 75 self.title = QtWidgets.QLabel(parent=self) self.title.setAlignment(QtCore.Qt.AlignHCenter) - self.split = QtWidgets.QSplitter(parent=self) - self.split.setOrientation(QtCore.Qt.Vertical) + self._gridWidget = QtWidgets.QWidget(parent=self) + self._gridLayout = QtWidgets.QGridLayout(self._gridWidget) + self._gridLayout.setContentsMargins(0, 0, 0, 0) + self._gridLayout.setSpacing(2) + self._gridWidget.setLayout(self._gridLayout) + + # Wrap the grid in a scroll area so very many plots are still accessible + self._scrollArea = QtWidgets.QScrollArea(parent=self) + self._scrollArea.setWidgetResizable(True) + self._scrollArea.setWidget(self._gridWidget) layout = QtWidgets.QVBoxLayout() layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(2) layout.addWidget(self.title) - layout.addWidget(self.split) + layout.addWidget(self._scrollArea) self.setLayout(layout) 
self.setTitle('') @@ -64,10 +75,62 @@ def addPlot(self, plot: PlotBase) -> None: """Add a :class:`.PlotBase` widget. :param plot: plot widget - :param title: title of the plot """ - self.split.addWidget(plot) self.subPlots.append(plot) + # Don't add to layout yet — _arrangeGrid() is called after all plots are added + + def _arrangeGrid(self, min_plot_height: Optional[int] = None) -> None: + """Arrange subplots on a near-square grid, matching matplotlib's layout.""" + n = len(self.subPlots) + + # Remove existing items before re-adding to avoid stale layout entries + while self._gridLayout.count(): + self._gridLayout.takeAt(0) + + # Reset all row/column stretches from previous arrangement + for r in range(self._gridLayout.rowCount()): + self._gridLayout.setRowStretch(r, 0) + for c in range(self._gridLayout.columnCount()): + self._gridLayout.setColumnStretch(c, 0) + + if n == 0: + self._gridWidget.setMinimumHeight(0) + return + + if min_plot_height is None: + min_plot_height = self._minPlotHeight + + nrows = max(1, int(n ** 0.5 + 0.5)) + ncols = max(1, int(np.ceil(n / nrows))) + + self._gridWidget.setMinimumHeight(nrows * min_plot_height) + + for i, plot in enumerate(self.subPlots): + row = i // ncols + col = i % ncols + self._gridLayout.addWidget(plot, row, col) + + # Set equal stretch so all rows/columns get the same space + for r in range(nrows): + self._gridLayout.setRowStretch(r, 1) + for c in range(ncols): + self._gridLayout.setColumnStretch(c, 1) + + def setScrollable(self, scrollable: bool) -> None: + """Enable or disable scroll area around the plot grid.""" + if scrollable: + self._scrollArea.setWidgetResizable(True) + # Re-apply grid min height if we have plots + if self.subPlots: + n = len(self.subPlots) + nrows = max(1, int(n ** 0.5 + 0.5)) + self._gridWidget.setMinimumHeight(nrows * self._minPlotHeight) + else: + self._gridWidget.setMinimumHeight(0) + else: + # Disable scrolling: widget resizes with the scroll area + 
self._scrollArea.setWidgetResizable(True) + self._gridWidget.setMinimumHeight(0) def clearAllPlots(self) -> None: """Clear all plot contents.""" @@ -149,6 +212,7 @@ def makeSubPlots(self, nSubPlots: int) -> List[PlotBase]: elif max(self.dataDimensionsInSubPlot(i).values()) == 2: plot = PlotWithColorbar(self.widget) self.widget.addPlot(plot) + self.widget._arrangeGrid() else: self.widget.clearAllPlots() @@ -165,17 +229,17 @@ def formatSubPlot(self, subPlotId: int) -> None: # label the x axis if there's only one x label if isinstance(subPlot, Plot): if len(set(labels[0])) == 1: - subPlot.plot.setLabel("bottom", labels[0][0]) + subPlot.plot.setLabel("bottom", latex_to_html(labels[0][0])) if isinstance(subPlot, PlotWithColorbar): if len(set(labels[0])) == 1: - subPlot.plot.setLabel("bottom", labels[0][0]) + subPlot.plot.setLabel("bottom", latex_to_html(labels[0][0])) if len(set(labels[1])) == 1: - subPlot.plot.setLabel('left', labels[1][0]) + subPlot.plot.setLabel('left', latex_to_html(labels[1][0])) if len(set(labels[2])) == 1: - subPlot.colorbar.setLabel('left', labels[2][0]) + subPlot.colorbar.setLabel('left', latex_to_html(labels[2][0])) def plot(self, plotItem: PlotItem) -> None: """Plot the given item.""" @@ -288,6 +352,9 @@ def setData(self, data: Optional[DataDictBase]) -> None: """ super().setData(data) if self.data is None: + if self.fmWidget is not None: + self.fmWidget.deleteAllPlots() + self.fmWidget._arrangeGrid() return fmKwargs = {} # {'widget': self.fmWidget} @@ -328,12 +395,16 @@ def _plotData(self, **kwargs: Any) -> None: self.figConfig.figCopied.connect(self.onfigCopied) self.figConfig.figSaved.connect(self.onfigSaved) + self.fmWidget.setScrollable(self.figOptions.scrollablePlots) + self.fmWidget._minPlotHeight = self.figOptions.minPlotHeight + if self.data.has_meta('title'): self.fmWidget.setTitle(self.data.meta_val('title')) self.title = self.data.meta_val('title') #update FigOptions numAxes and imagData self.figOptions.numAxes = len(inds) + 
self.figOptions.imagData = False #define imagData for single and multiple value data for val in dvals: @@ -407,6 +478,12 @@ class FigureOptions: #: whether the dependent data contains any instance of imaginary data imagData: bool = False + #: whether to enable scrollable plot area (useful for many subplots) + scrollablePlots: bool = False + + #: minimum height per subplot row in pixels (when scrollable) + minPlotHeight: int = 75 + class FigureConfigToolBar(QtWidgets.QToolBar): """Simple toolbar to configure the figure.""" @@ -440,6 +517,33 @@ def __init__(self, options: FigureOptions, lambda: self._setOption('combineLinePlots', combineLinePlots.isChecked()) ) + + scrollablePlots = self.addAction("Scrollable") + scrollablePlots.setCheckable(True) + scrollablePlots.setChecked(self.options.scrollablePlots) + scrollablePlots.setToolTip("Enable scrollable plot area for many subplots") + scrollablePlots.triggered.connect( + lambda: self._setOption('scrollablePlots', + scrollablePlots.isChecked()) + ) + + self._minHeightSpin = QtWidgets.QSpinBox() + self._minHeightSpin.setRange(40, 2000) + self._minHeightSpin.setValue(self.options.minPlotHeight) + self._minHeightSpin.setSuffix(" px") + self._minHeightSpin.setToolTip("Minimum height per subplot row") + self._minHeightSpin.setEnabled(self.options.scrollablePlots) + self._minHeightSpin.editingFinished.connect( + lambda: self._setOption('minPlotHeight', + self._minHeightSpin.value()) + ) + self.addWidget(self._minHeightSpin) + + # Keep spinbox enabled state in sync with scrollable toggle + scrollablePlots.triggered.connect( + lambda: self._minHeightSpin.setEnabled(scrollablePlots.isChecked()) + ) + complexOptions = QtWidgets.QMenu(parent=self) complexGroup = QtWidgets.QActionGroup(complexOptions) complexGroup.setExclusive(True) diff --git a/plottr/plot/pyqtgraph/plots.py b/plottr/plot/pyqtgraph/plots.py index 5aece289..1e9e44ac 100644 --- a/plottr/plot/pyqtgraph/plots.py +++ b/plottr/plot/pyqtgraph/plots.py @@ -36,6 +36,7 @@ 
def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: #: ``pyqtgraph`` plot item self.plot: pg.PlotItem = self.graphicsLayout.addPlot() + self.setMinimumSize(40, 40) def clearPlot(self) -> None: """Clear all plot contents (but do not delete plot elements, like axis @@ -113,7 +114,10 @@ def setImage(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None: self.img = pg.ImageItem() self.plot.addItem(self.img) - self.img.setImage(z) + # Transpose z to match matplotlib convention: the first axis of the + # meshgrid (labeled on bottom/x) maps to the horizontal display axis. + # pyqtgraph ImageItem displays array[col, row], so z.T is needed. + self.img.setImage(z.T) self.img.setRect(QtCore.QRectF(x.min(), y.min(), x.max() - x.min(), y.max() - y.min())) self.colorbar.setImageItem(self.img) diff --git a/plottr/utils/latex.py b/plottr/utils/latex.py new file mode 100644 index 00000000..c9345f47 --- /dev/null +++ b/plottr/utils/latex.py @@ -0,0 +1,73 @@ +""" +plottr.utils.latex — Lightweight LaTeX-to-HTML conversion for plot labels. + +Converts common LaTeX notation used in physics labels into HTML that Qt's +rich text renderer can display (for pyqtgraph axis labels, titles, etc.). + +Uses ``unicodeit`` for Greek letters and math symbols, then converts +subscript/superscript braces to HTML ````/```` tags. +""" +import re + +import unicodeit + + +_LATEX_INDICATOR = re.compile( + r'\\[a-zA-Z]' # backslash command (\alpha, \frac, …) + r'|\$' # dollar-sign math delimiter + r'|_\{' # braced subscript _{...} + r'|\^\{' # braced superscript ^{...} +) + + +def latex_to_html(text: str) -> str: + """Convert LaTeX-like notation in *text* to HTML suitable for Qt rich text. + + The conversion is only applied when the string contains recognisable LaTeX + syntax — backslash commands (``\\alpha``), dollar-sign delimiters + (``$…$``), or braced sub/superscripts (``_{…}``, ``^{…}``). Plain text + with ordinary underscores (e.g. ``gate_voltage``) passes through unchanged. 
+ + Handles: + - Greek letters: ``\\alpha`` → α, ``\\Omega`` → Ω, etc. (via unicodeit) + - Math symbols: ``\\hbar`` → ℏ, ``\\partial`` → ∂, ``\\infty`` → ∞, etc. + - Subscripts: ``V_{gate}`` → ``Vgate`` + - Superscripts: ``x^{2}`` → ``x2`` + - Fractions: ``\\frac{dI}{dV}`` → ``dI/dV`` + - Square root: ``\\sqrt{x}`` → ``√x`` + - Dollar-sign math delimiters are stripped: ``$...$`` → contents + + :param text: input string, possibly containing LaTeX notation. + :returns: HTML string suitable for Qt ``setHtml()`` or pyqtgraph labels. + """ + if not text: + return text + + # Only enter the conversion pipeline when the string looks like LaTeX. + if not _LATEX_INDICATOR.search(text): + return text + + s = text + + # Strip dollar-sign math delimiters + s = re.sub(r'\$([^$]*)\$', r'\1', s) + + # Convert \frac{a}{b} -> a/b + s = re.sub(r'\\frac\{([^}]*)\}\{([^}]*)\}', r'\1/\2', s) + + # Convert \sqrt{x} -> √x + s = re.sub(r'\\sqrt\{([^}]*)\}', '\u221a\\1', s) + + # Convert \overline{x} -> x̅, \bar{x} -> x̅ + s = re.sub(r'\\(?:overline|bar)\{([^}]*)\}', '\\1\u0305', s) + + # Convert braced subscripts and superscripts to HTML BEFORE unicodeit, + # so unicodeit doesn't turn them into Unicode sub/superscript chars. + # Only braced forms (_{...}, ^{...}) — bare underscores are left alone. + s = re.sub(r'_\{([^}]*)\}', r'\1', s) + s = re.sub(r'\^\{([^}]*)\}', r'\1', s) + + # Apply unicodeit for Greek letters and math symbols. + s = unicodeit.replace(s) + + return s diff --git a/plottr/utils/num.py b/plottr/utils/num.py index 3d841c49..5a8b09b7 100644 --- a/plottr/utils/num.py +++ b/plottr/utils/num.py @@ -25,7 +25,26 @@ def largest_numtype(arr: np.ndarray, include_integers: bool = True) \ only integers in the the data. :return: type if possible. None if no numeric data in array. 
""" - types = {type(a) for a in np.array(arr).flatten()} + arr = np.asarray(arr) + + # Fast path: use numpy's dtype for homogeneous numeric arrays + if arr.size == 0: + return None + if arr.dtype != object: + if np.issubdtype(arr.dtype, np.complexfloating): + return arr.dtype.type + elif np.issubdtype(arr.dtype, np.floating): + return arr.dtype.type + elif np.issubdtype(arr.dtype, np.integer): + if include_integers: + return arr.dtype.type + else: + return float + else: + return None + + # Slow path for object arrays: inspect element types + types = {type(a) for a in arr.ravel() if a is not None} curidx = -1 if include_integers: ok_types = NUMTYPES @@ -55,14 +74,25 @@ def _are_equal(a: np.ndarray, b: np.ndarray) -> np.ndarray: def is_invalid(a: np.ndarray) -> np.ndarray: - # really use == None to do an element wise - # check for None - isnone = a == None - if a.dtype in FLOATTYPES: - isnan = np.isnan(a) + """Check element-wise for invalid entries (None or NaN). + + For numeric dtypes (int, float, complex), only NaN is checked — + numeric arrays can never contain None. + For object arrays, also checks for None. 
+ """ + if a.dtype.kind in ('f', 'c'): + # float or complex: None is impossible, only NaN + return np.isnan(a) + elif a.dtype.kind in ('i', 'u', 'b'): + # integer, unsigned, bool: can never be invalid + return np.zeros(a.shape, dtype=bool) else: - isnan = np.zeros(a.shape, dtype=bool) - return isnone | isnan + # object arrays: check for None and NaN + isnone = a == None # noqa: E711 — element-wise check + try: + return isnone | np.isnan(a) + except (TypeError, ValueError): + return isnone def _are_invalid(a: np.ndarray, b: np.ndarray) -> np.ndarray: @@ -139,17 +169,29 @@ def array1d_to_meshgrid(arr: Union[List, np.ndarray], def _find_switches(arr: np.ndarray, rth: float = 25, ztol: float = 1e-15) -> np.ndarray: - arr_: np.ndarray = np.ma.MaskedArray(arr, is_invalid(arr)) - deltas = arr_[1:] - arr_[:-1] - hi = np.percentile(arr[~is_invalid(arr)], 100.-rth) - lo = np.percentile(arr[~is_invalid(arr)], rth) - diff = np.abs(hi-lo) + # Compute invalid mask once, reuse everywhere + invalid = is_invalid(arr) + valid_mask = ~invalid + + # Use np.diff directly — for entries adjacent to invalid values, the delta + # will be nan. We handle this by using nan-aware operations below. + deltas = arr[1:] - arr[:-1] + + # Compute percentile range of valid data + valid_data = arr[valid_mask] + if valid_data.size == 0: + return np.array([]) + + lo, hi = np.percentile(valid_data, [rth, 100. - rth]) + diff = np.abs(hi - lo) if not diff > ztol: return np.array([]) # first step: suspected switches are where we have 'large' jumps in value. - switch_candidates = np.where(np.abs(deltas) >= diff)[0] + # Use nan-safe abs: nan deltas will produce nan >= diff which is False + abs_deltas = np.abs(deltas) + switch_candidates = np.where(abs_deltas >= diff)[0] switch_candidates = switch_candidates[switch_candidates > 0] if not len(switch_candidates) > 0: return np.array([]) @@ -157,15 +199,13 @@ def _find_switches(arr: np.ndarray, # importantly: switches have to opposite to the sweep direction. 
# we check the sweep direction by looking at the values prior to the # first suspected switch - sweep_direction = np.sign(np.mean(deltas[:switch_candidates[0]])) + sweep_direction = np.sign(np.nanmean(deltas[:switch_candidates[0]])) # real switches are then those where the delta is opposite to the sweep - # direction. + # direction. Vectorized filter instead of list comprehension. switch_candidate_vals = deltas[switch_candidates] - switches = [s for (s, v) in zip(switch_candidates, switch_candidate_vals) - if np.sign(v) == -sweep_direction] - - return np.array(switches) + mask = np.sign(switch_candidate_vals) == -sweep_direction + return switch_candidates[mask] def find_direction_period(vals: np.ndarray, ignore_last: bool = False) \ @@ -233,13 +273,14 @@ def guess_grid_from_sweep_direction(**axes: np.ndarray) \ raise ValueError("Empty input.") for name, vals in axes.items(): - if len(np.array(vals).shape) > 1: + vals_arr = np.asarray(vals) + if vals_arr.ndim > 1: raise ValueError( - f"Expect 1-dimensional axis data, not {np.array(vals).shape}") + f"Expect 1-dimensional axis data, not {vals_arr.shape}") if size is None: - size = np.array(vals).size + size = vals_arr.size else: - if size != np.array(vals).size: + if size != vals_arr.size: raise ValueError("Non-matching array sizes.") # first step: find repeating patterns in the data. 
@@ -265,7 +306,7 @@ def guess_grid_from_sweep_direction(**axes: np.ndarray) \ else: if mean == 0: mean = max(np.abs(vals.max()), np.abs(vals.min())) - cost = 1./np.abs(np.std(vals)/mean) + cost = 1./np.abs(std/mean) sorting.append(size + cost) else: return None diff --git a/pyproject.toml b/pyproject.toml index b46e709c..c3bfc4b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "psutil", "watchdog", "pyzmq", + "unicodeit>=0.7.5", ] dynamic = ["version"] @@ -102,10 +103,20 @@ module = [ "matplotlib.*", "pyqtgraph.*", "xhistogram.*", - "ruamel.*" + "ruamel.*", + "unicodeit", ] ignore_missing_imports = true +# These modules contain type: ignore comments that may be unused +# depending on the Qt stubs version installed. +[[tool.mypy.overrides]] +module = [ + "plottr.node.autonode", + "plottr.node.scaleunits", +] +warn_unused_ignores = false + [tool.versioningit] default-version = "0.0" diff --git a/test/pytest/test_data_selector.py b/test/pytest/test_data_selector.py index 54d9a7d8..389cc4c7 100644 --- a/test/pytest/test_data_selector.py +++ b/test/pytest/test_data_selector.py @@ -83,3 +83,91 @@ def test_incompatible_sets(qtbot): node.selectedData = data.dependents()[1] assert fc.output()['dataOut'].dependents() == [data.dependents()[1]] + + +# -- Selection buttons (select all, deselect, 1D, 2D) -- + +class TestSelectionButtons: + """Verify Select All / Deselect / 1D / 2D in DataSelectionWidget.""" + + @staticmethod + def _mixed(): + from plottr.data.datadict import DataDictBase + return DataDictBase( + trace1d=dict(values=np.arange(10.0), axes=['x']), + trace1d_b=dict(values=np.arange(10.0), axes=['x']), + x=dict(values=np.arange(10.0)), + map2d=dict(values=np.arange(20.0), axes=['x', 'y']), + map2d_b=dict(values=np.arange(20.0), axes=['x', 'y']), + y=dict(values=np.arange(20.0)), + ) + + def test_select_all(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = 
self._mixed(); w.setData(dd, dd.shapes()) + w.selectAll() + assert set(w.getSelectedData()) == set(dd.dependents()) + + def test_select_first(self, qtbot): + """selectFirst should select only the first dependent.""" + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = self._mixed(); w.setData(dd, dd.shapes()) + w.selectAll() + w.selectFirst() + selected = w.getSelectedData() + assert len(selected) == 1 + assert selected[0] == dd.dependents()[0] + + def test_select_1d(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = self._mixed(); w.setData(dd, dd.shapes()) + w.selectByNdims(1) + sel = w.getSelectedData() + assert 'trace1d' in sel and 'trace1d_b' in sel + assert 'map2d' not in sel + + def test_select_2d(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = self._mixed(); w.setData(dd, dd.shapes()) + w.selectByNdims(2) + sel = w.getSelectedData() + assert 'map2d' in sel and 'map2d_b' in sel + assert 'trace1d' not in sel + + def test_select_resets_previous(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = self._mixed(); w.setData(dd, dd.shapes()) + w.selectAll() + w.selectByNdims(1) + for name in w.getSelectedData(): + assert len(dd.axes(name)) == 1 + + def test_has_dependents_with_ndims(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = self._mixed(); w.setData(dd, dd.shapes()) + assert w.has_dependents_with_ndims(1) + assert w.has_dependents_with_ndims(2) + assert not w.has_dependents_with_ndims(3) + + def test_batch_emits_single_signal(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + w = DataSelectionWidget(); qtbot.addWidget(w) + dd = self._mixed(); w.setData(dd, dd.shapes()) + 
count = [0] + w.dataSelectionMade.connect(lambda _: count.__setitem__(0, count[0] + 1)) + w.selectAll() + assert count[0] == 1 + + def test_empty_dataset(self, qtbot): + from plottr.gui.data_display import DataSelectionWidget + from plottr.data.datadict import DataDictBase + w = DataSelectionWidget(); qtbot.addWidget(w) + w.setData(DataDictBase(), {}) + w.selectAll() + assert w.getSelectedData() == [] diff --git a/test/pytest/test_datadict_copy_semantics.py b/test/pytest/test_datadict_copy_semantics.py new file mode 100644 index 00000000..048803f5 --- /dev/null +++ b/test/pytest/test_datadict_copy_semantics.py @@ -0,0 +1,709 @@ +""" +test_datadict_copy_semantics.py + +Comprehensive tests for DataDict copy semantics, data integrity through pipeline +operations, and edge cases. These tests serve as a safety net before making +performance optimizations to the DataDict implementation. +""" +import copy as cp + +import numpy as np +import pytest + +from plottr.data.datadict import ( + DataDict, + DataDictBase, + MeshgridDataDict, + datadict_to_meshgrid, + meshgrid_to_datadict, + datasets_are_equal, +) +from plottr.utils import num + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_datadict(npts: int = 100) -> DataDict: + """Simple 1D DataDict: x -> y, z.""" + return DataDict( + x=dict(values=np.arange(npts, dtype=float), unit='V', label='x'), + y=dict(values=np.random.randn(npts), axes=['x'], unit='A', label='y'), + z=dict(values=np.random.randn(npts), axes=['x'], unit='A', label='z'), + ) + + +def make_meshgrid(shape: tuple = (10, 8), ndeps: int = 2) -> MeshgridDataDict: + """Gridded data with given shape.""" + naxes = len(shape) + dd = MeshgridDataDict() + ax_names = [f'ax{i}' for i in range(naxes)] + grids = np.meshgrid(*[np.linspace(0, 1, s) for s in shape], indexing='ij') + for i, ax in enumerate(ax_names): + dd[ax] = 
dict(values=grids[i], axes=[], unit='V', label=ax) + for i in range(ndeps): + dd[f'dep{i}'] = dict( + values=np.random.randn(*shape), + axes=ax_names.copy(), + unit='A', + label=f'dep{i}', + ) + dd.validate() + return dd + + +# =========================================================================== +# 1. COPY ISOLATION TESTS +# =========================================================================== + +class TestCopyIsolation: + """Verify that copy() produces fully independent data.""" + + def test_copy_values_independent(self): + """Modifying copied values must not affect the original.""" + dd = make_datadict() + dd2 = dd.copy() + dd2['y']['values'][0] = 999.0 + assert dd['y']['values'][0] != 999.0 + + def test_copy_axes_independent(self): + """Modifying copied axes list must not affect the original.""" + dd = make_datadict() + dd2 = dd.copy() + dd2['y']['axes'].append('extra') + assert 'extra' not in dd['y']['axes'] + + def test_copy_unit_independent(self): + """Changing unit on copy must not affect original.""" + dd = make_datadict() + dd2 = dd.copy() + dd2['y']['unit'] = 'mA' + assert dd['y']['unit'] == 'A' + + def test_copy_meta_independent(self): + """Modifying mutable metadata on copy must not affect original. + + This was previously broken (copy() via structure() did not deepcopy + global mutable metadata). Fixed by the Phase 1a copy() rewrite. + """ + dd = make_datadict() + dd.add_meta('info', {'key': 'value'}) + dd2 = dd.copy() + dd2.meta_val('info')['key'] = 'changed' + assert dd.meta_val('info')['key'] == 'value' + + def test_copy_field_meta_independent(self): + """Per-field mutable metadata should be independent after copy. + + Note: this works because structure() calls cp.deepcopy on each field dict, + which catches per-field meta. However, global meta is NOT deepcopied + (see test_copy_meta_independent above). 
+ """ + dd = make_datadict() + dd.add_meta('cal', [1, 2, 3], data='y') + dd2 = dd.copy() + dd2.meta_val('cal', 'y').append(4) + assert dd.meta_val('cal', 'y') == [1, 2, 3] + + def test_copy_preserves_type_datadict(self): + dd = make_datadict() + dd2 = dd.copy() + assert type(dd2) is DataDict + + def test_copy_preserves_type_meshgrid(self): + dd = make_meshgrid() + dd2 = dd.copy() + assert type(dd2) is MeshgridDataDict + + def test_copy_preserves_equality(self): + dd = make_datadict() + dd.add_meta('info', 'test') + dd2 = dd.copy() + assert dd == dd2 + + def test_meshgrid_copy_values_independent(self): + dd = make_meshgrid((10, 8)) + dd2 = dd.copy() + dd2['dep0']['values'][0, 0] = 999.0 + assert dd['dep0']['values'][0, 0] != 999.0 + + def test_meshgrid_copy_axes_independent(self): + dd = make_meshgrid((10, 8)) + dd2 = dd.copy() + original_axes = dd['dep0']['axes'].copy() + dd2['dep0']['axes'].pop() + assert dd['dep0']['axes'] == original_axes + + +# =========================================================================== +# 2. 
EXTRACT ISOLATION TESTS +# =========================================================================== + +class TestExtractIsolation: + """Verify that extract() produces independent data when copy=True.""" + + def test_extract_copy_true_values_independent(self): + dd = make_datadict() + ex = dd.extract(['y'], copy=True) + ex['y']['values'][0] = 999.0 + assert dd['y']['values'][0] != 999.0 + + def test_extract_copy_true_axes_independent(self): + dd = make_datadict() + ex = dd.extract(['y'], copy=True) + ex['y']['axes'].append('extra') + assert 'extra' not in dd['y']['axes'] + + def test_extract_copy_false_shares_values(self): + dd = make_datadict() + ex = dd.extract(['y'], copy=False) + # With copy=False, arrays are shared + assert np.shares_memory(ex['y']['values'], dd['y']['values']) + + def test_extract_includes_axes_fields(self): + dd = make_datadict() + ex = dd.extract(['y']) + assert 'x' in ex + assert 'y' in ex + assert 'z' not in ex + + def test_extract_includes_meta(self): + dd = make_datadict() + dd.add_meta('info', 'hello') + ex = dd.extract(['y']) + assert ex.has_meta('info') + + def test_extract_preserves_field_meta(self): + dd = make_datadict() + dd.add_meta('cal', 42, data='y') + ex = dd.extract(['y']) + assert ex.meta_val('cal', 'y') == 42 + + +# =========================================================================== +# 3. 
STRUCTURE TESTS +# =========================================================================== + +class TestStructure: + """Verify structure() correctness and independence.""" + + def test_structure_has_empty_values(self): + dd = make_datadict() + s = dd.structure() + for _, v in s.data_items(): + assert len(v['values']) == 0 + + def test_structure_preserves_axes(self): + dd = make_datadict() + s = dd.structure() + assert s['y']['axes'] == ['x'] + + def test_structure_preserves_units(self): + dd = make_datadict() + s = dd.structure() + assert s['y']['unit'] == 'A' + + def test_structure_preserves_meta(self): + dd = make_datadict() + dd.add_meta('info', 'test') + s = dd.structure() + assert s.meta_val('info') == 'test' + + def test_structure_axes_independent(self): + """Mutating axes in structure must not affect original.""" + dd = make_datadict() + s = dd.structure() + s['y']['axes'].append('extra') + assert 'extra' not in dd['y']['axes'] + + def test_structure_preserves_custom_field_keys(self): + """Custom keys in field dicts must be preserved.""" + dd = make_datadict() + dd['y']['__shape__'] = (100,) + dd['y']['__custom_meta__'] = 'hello' + s = dd.structure() + assert '__shape__' in s['y'] + assert '__custom_meta__' in s['y'] + + def test_structure_with_remove_data(self): + dd = make_meshgrid((5, 4)) + s = dd.structure(remove_data=['ax0']) + assert 'ax0' not in s + for dep in s.dependents(): + assert 'ax0' not in s[dep]['axes'] + + +# =========================================================================== +# 4. 
EDGE CASES: DATA TYPES +# =========================================================================== + +class TestEdgeCaseDataTypes: + """Tests with unusual data types.""" + + def test_object_array_with_none(self): + """DataDict with object arrays containing None values.""" + dd = DataDict( + x=dict(values=np.array([1, 2, 3, 4, 5], dtype=object)), + y=dict(values=np.array([1.0, None, 3.0, None, 5.0], dtype=object), + axes=['x']), + ) + assert dd.validate() + dd2 = dd.copy() + assert dd == dd2 + + def test_complex_array(self): + """DataDict with complex-valued data.""" + dd = DataDict( + x=dict(values=np.arange(10, dtype=float)), + y=dict(values=np.random.randn(10) + 1j * np.random.randn(10), + axes=['x']), + ) + assert dd.validate() + dd2 = dd.copy() + assert dd == dd2 + dd2['y']['values'][0] = 999 + 0j + assert dd['y']['values'][0] != 999 + 0j + + def test_integer_array(self): + """DataDict with integer data (no NaN possible).""" + dd = DataDict( + x=dict(values=np.arange(10)), + y=dict(values=np.arange(10, 20), axes=['x']), + ) + assert dd.validate() + dd2 = dd.copy() + assert dd == dd2 + + def test_masked_array_values(self): + """DataDict where values are already MaskedArrays.""" + vals = np.ma.MaskedArray([1.0, 2.0, 3.0], mask=[False, True, False]) + dd = DataDict( + x=dict(values=np.arange(3, dtype=float)), + y=dict(values=vals, axes=['x']), + ) + assert dd.validate() + dd2 = dd.copy() + assert np.ma.is_masked(dd2['y']['values']) + + def test_empty_datadict(self): + """Empty DataDict operations.""" + dd = DataDict() + s = dd.structure() + assert s is not None + dd2 = dd.copy() + assert dd == dd2 + + def test_single_point(self): + """DataDict with a single data point.""" + dd = DataDict( + x=dict(values=np.array([1.0])), + y=dict(values=np.array([2.0]), axes=['x']), + ) + assert dd.validate() + dd2 = dd.copy() + assert dd == dd2 + + +# =========================================================================== +# 5. 
MASK_INVALID TESTS +# =========================================================================== + +class TestMaskInvalid: + """Tests for mask_invalid() behavior with different data.""" + + def test_mask_invalid_clean_float_data(self): + """Clean float data — all values valid.""" + dd = make_datadict() + dd2 = dd.copy() + dd2 = dd2.mask_invalid() + # Values should be unchanged (though possibly wrapped in MaskedArray) + for name, _ in dd2.data_items(): + assert np.allclose( + np.asarray(dd.data_vals(name)), + np.asarray(dd2.data_vals(name)), + ) + + def test_mask_invalid_with_nan(self): + """Float data with NaN values should be masked.""" + dd = DataDict( + x=dict(values=np.array([1.0, 2.0, 3.0])), + y=dict(values=np.array([1.0, np.nan, 3.0]), axes=['x']), + ) + dd = dd.mask_invalid() + y_vals = dd.data_vals('y') + assert isinstance(y_vals, np.ma.MaskedArray) + assert y_vals.mask[1] == True + + def test_mask_invalid_with_none_objects(self): + """Object array with None values should be masked.""" + dd = DataDict( + x=dict(values=np.array([1, 2, 3], dtype=object)), + y=dict(values=np.array([1.0, None, 3.0], dtype=object), axes=['x']), + ) + dd = dd.mask_invalid() + y_vals = dd.data_vals('y') + assert isinstance(y_vals, np.ma.MaskedArray) + + def test_mask_invalid_preserves_structure(self): + """Structure should be unchanged after masking.""" + dd = make_meshgrid() + s_before = dd.structure() + dd = dd.mask_invalid() + s_after = dd.structure() + assert DataDictBase.same_structure( + s_before, s_after + ) + + +# =========================================================================== +# 6. 
MESHGRID CONVERSION TESTS +# =========================================================================== + +class TestMeshgridConversions: + """Test conversions between DataDict and MeshgridDataDict.""" + + def test_roundtrip_datadict_meshgrid_datadict(self): + """Tabular → grid → tabular should preserve data.""" + x = np.linspace(0, 1, 10) + y = np.arange(5, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + zz = xx * yy + + dd = DataDict( + x=dict(values=xx.ravel()), + y=dict(values=yy.ravel()), + z=dict(values=zz.ravel(), axes=['x', 'y']), + ) + mesh = datadict_to_meshgrid(dd) + assert isinstance(mesh, MeshgridDataDict) + assert mesh.shape() == (10, 5) + + dd2 = meshgrid_to_datadict(mesh) + assert isinstance(dd2, DataDict) + assert dd2.nrecords() == 50 + + def test_datadict_to_meshgrid_copy_true(self): + """copy=True should produce independent arrays.""" + x = np.arange(6, dtype=float) + y = np.tile(np.arange(3, dtype=float), 2) + dd = DataDict( + x=dict(values=x), + y=dict(values=y), + z=dict(values=np.arange(6, dtype=float), axes=['x', 'y']), + ) + mesh = datadict_to_meshgrid(dd, target_shape=(2, 3), copy=True) + mesh['z']['values'][0, 0] = 999.0 + assert dd['z']['values'][0] != 999.0 + + def test_datadict_to_meshgrid_preserves_meta(self): + """Conversion should preserve global metadata.""" + x = np.arange(6, dtype=float) + y = np.tile(np.arange(3, dtype=float), 2) + dd = DataDict( + x=dict(values=x), + y=dict(values=y), + z=dict(values=np.arange(6, dtype=float), axes=['x', 'y']), + __info__='test_meta', + ) + mesh = datadict_to_meshgrid(dd, target_shape=(2, 3)) + assert mesh.meta_val('info') == 'test_meta' + + def test_meshgrid_to_datadict_independent(self): + """meshgrid_to_datadict should not share arrays with original.""" + mesh = make_meshgrid((5, 4)) + dd = meshgrid_to_datadict(mesh) + dd['dep0']['values'][0] = 999.0 + assert mesh['dep0']['values'].ravel()[0] != 999.0 + + +# 
=========================================================================== +# 7. MESHGRID VALIDATION TESTS +# =========================================================================== + +class TestMeshgridValidation: + """Test MeshgridDataDict validation, especially monotonicity checks.""" + + def test_valid_monotonic_increasing(self): + dd = make_meshgrid((5, 4)) + assert dd.validate() + + def test_valid_monotonic_decreasing(self): + """Axes that decrease monotonically are valid.""" + dd = MeshgridDataDict() + x = np.linspace(1, 0, 5) # decreasing + y = np.linspace(0, 1, 4) + xx, yy = np.meshgrid(x, y, indexing='ij') + dd['x'] = dict(values=xx, axes=[], unit='V', label='x') + dd['y'] = dict(values=yy, axes=[], unit='V', label='y') + dd['z'] = dict(values=xx + yy, axes=['x', 'y'], unit='A', label='z') + assert dd.validate() + + def test_invalid_non_monotonic(self): + """Axis that goes up then down should fail.""" + dd = MeshgridDataDict() + x_vals = np.array([0, 1, 2, 1, 0], dtype=float) + y_vals = np.arange(3, dtype=float) + xx, yy = np.meshgrid(x_vals, y_vals, indexing='ij') + dd['x'] = dict(values=xx, axes=[], unit='V', label='x') + dd['y'] = dict(values=yy, axes=[], unit='V', label='y') + dd['z'] = dict(values=np.random.randn(5, 3), axes=['x', 'y'], + unit='A', label='z') + with pytest.raises(ValueError, match="not monotonous"): + dd.validate() + + def test_invalid_flat_axis(self): + """Axis with no variation should fail.""" + dd = MeshgridDataDict() + x = np.array([1.0, 1.0, 1.0]) + y = np.arange(4, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + dd['x'] = dict(values=xx, axes=[], unit='V', label='x') + dd['y'] = dict(values=yy, axes=[], unit='V', label='y') + dd['z'] = dict(values=np.random.randn(3, 4), axes=['x', 'y'], + unit='A', label='z') + with pytest.raises(ValueError, match="no variation"): + dd.validate() + + def test_valid_with_nan_in_axis(self): + """Axis with NaN values (incomplete data) should still validate + if the non-NaN values 
are monotonic.""" + dd = make_meshgrid((5, 4)) + dd['ax0']['values'][3, :] = np.nan + dd['ax0']['values'][4, :] = np.nan + # Should not raise — NaN steps are ignored + assert dd.validate() + + def test_valid_3d_meshgrid(self): + """3D meshgrid should validate correctly.""" + dd = make_meshgrid((5, 4, 3)) + assert dd.validate() + + def test_shape_mismatch_fails(self): + """Different shapes across fields should fail.""" + dd = MeshgridDataDict() + dd['x'] = dict(values=np.arange(10, dtype=float).reshape(2, 5), + axes=[]) + dd['z'] = dict(values=np.arange(12, dtype=float).reshape(3, 4), + axes=['x']) + with pytest.raises(ValueError): + dd.validate() + + +# =========================================================================== +# 8. SHAPES() EDGE CASES +# =========================================================================== + +class TestShapes: + """Test shapes() with various input states.""" + + def test_shapes_after_validation(self): + dd = make_datadict(50) + dd.validate() + shapes = dd.shapes() + assert shapes['x'] == (50,) + assert shapes['y'] == (50,) + + def test_shapes_with_list_values(self): + """shapes() should work even before validation when values are lists.""" + dd = DataDictBase( + x=dict(values=[1, 2, 3]), + y=dict(values=[4, 5, 6], axes=['x']), + ) + # Should not crash, even without validate() + shapes = dd.shapes() + assert shapes['x'] == (3,) + + def test_shapes_meshgrid(self): + dd = make_meshgrid((10, 8)) + shapes = dd.shapes() + for name in dd.dependents() + dd.axes(): + assert shapes[name] == (10, 8) + + +# =========================================================================== +# 9. 
PIPELINE DATA INTEGRITY TESTS +# =========================================================================== + +class TestPipelineIntegrity: + """Simulate pipeline operations and verify input is not mutated.""" + + def _simulate_data_selector(self, data: DataDictBase) -> DataDictBase: + """Simulate DataSelector.process() — extract a subset.""" + selected = data.extract(data.dependents()[:1]) + if isinstance(selected, DataDictBase): + selected = DataDict(**selected) + selected.validate() + return selected + + def _simulate_gridder(self, data: DataDict) -> MeshgridDataDict: + """Simulate DataGridder.process() — copy + grid.""" + data_copy = data.copy() + return datadict_to_meshgrid(data_copy) + + def _simulate_dim_reducer(self, data: MeshgridDataDict) -> MeshgridDataDict: + """Simulate DimensionReducer.process() — copy + mask.""" + data_copy = data.copy() + return data_copy.mask_invalid() + + def test_pipeline_does_not_mutate_input(self): + """Full pipeline must not modify the original input data.""" + # Create griddable data + x = np.linspace(0, 1, 10) + y = np.arange(5, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + + original = DataDict( + x=dict(values=xx.ravel()), + y=dict(values=yy.ravel()), + z=dict(values=(xx * yy).ravel(), axes=['x', 'y']), + ) + original.validate() + + # Save a reference-safe copy for comparison (cp.deepcopy fails on + # DataDict due to _DataAccess inner class, so use the built-in copy) + reference = original.copy() + + # Run simulated pipeline + selected = self._simulate_data_selector(original) + gridded = self._simulate_gridder(selected) + reduced = self._simulate_dim_reducer(gridded) + + # Verify original is unchanged + assert datasets_are_equal(original, reference) + + def test_pipeline_output_types(self): + """Pipeline stages should produce the expected types.""" + x = np.linspace(0, 1, 10) + y = np.arange(5, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + + dd = DataDict( + x=dict(values=xx.ravel()), + 
y=dict(values=yy.ravel()), + z=dict(values=(xx * yy).ravel(), axes=['x', 'y']), + ) + + selected = self._simulate_data_selector(dd) + assert isinstance(selected, DataDict) + + gridded = self._simulate_gridder(selected) + assert isinstance(gridded, MeshgridDataDict) + + reduced = self._simulate_dim_reducer(gridded) + assert isinstance(reduced, MeshgridDataDict) + + +# =========================================================================== +# 10. MESHGRID OPERATIONS: mean, slice +# =========================================================================== + +class TestMeshgridOperations: + """Test mean and slice operations on MeshgridDataDict.""" + + def test_mean_reduces_axis(self): + dd = make_meshgrid((10, 8)) + result = dd.mean('ax0') + assert result.shape() == (8,) + assert 'ax0' not in result + + def test_mean_does_not_mutate_original(self): + dd = make_meshgrid((10, 8)) + original_shape = dd.shape() + _ = dd.mean('ax0') + assert dd.shape() == original_shape + + def test_slice_reduces_shape(self): + dd = make_meshgrid((10, 8)) + result = dd.slice(ax0=slice(2, 5)) + assert result.shape() == (3, 8) + + def test_slice_does_not_mutate_original(self): + dd = make_meshgrid((10, 8)) + original_shape = dd.shape() + _ = dd.slice(ax0=slice(2, 5)) + assert dd.shape() == original_shape + + def test_slice_integer_selects_single_element(self): + """Integer indexing on a meshgrid axis selects a single element, + but _mesh_slice does NOT remove the axis — it creates a size-1 dim. + Using a length-1 slice keeps the axis valid.""" + dd = make_meshgrid((10, 8)) + result = dd.slice(ax0=slice(3, 4)) + assert result.shape() == (1, 8) + + +# =========================================================================== +# 11. 
CUSTOM FIELD KEY PRESERVATION +# =========================================================================== + +class TestCustomFieldKeys: + """Verify that custom field keys are preserved through operations.""" + + def test_copy_preserves_shape_key(self): + dd = make_meshgrid((5, 4)) + dd['dep0']['__shape__'] = (5, 4) + dd2 = dd.copy() + assert dd2['dep0']['__shape__'] == (5, 4) + + def test_copy_preserves_per_field_meta(self): + dd = make_datadict() + dd['y']['__calibration__'] = {'gain': 1.5} + dd2 = dd.copy() + assert dd2['y']['__calibration__'] == {'gain': 1.5} + + def test_structure_preserves_shape_key(self): + dd = make_meshgrid((5, 4)) + dd['dep0']['__shape__'] = (5, 4) + s = dd.structure() + assert '__shape__' in s['dep0'] + + def test_extract_preserves_per_field_meta(self): + dd = make_datadict() + dd['y']['__calibration__'] = {'gain': 1.5} + ex = dd.extract(['y']) + assert ex['y']['__calibration__'] == {'gain': 1.5} + + +# =========================================================================== +# 12. 
DATASETS_ARE_EQUAL TESTS +# =========================================================================== + +class TestDatasetsAreEqual: + """Additional equality checks including edge cases.""" + + def test_equal_meshgrids(self): + dd = make_meshgrid() + dd2 = dd.copy() + assert datasets_are_equal(dd, dd2) + + def test_not_equal_different_values(self): + dd = make_meshgrid() + dd2 = dd.copy() + dd2['dep0']['values'][0, 0] += 1.0 + assert not datasets_are_equal(dd, dd2) + + def test_not_equal_different_types(self): + dd = make_datadict() + mesh = make_meshgrid() + assert not datasets_are_equal(dd, mesh) + + def test_not_equal_different_shape(self): + dd1 = make_meshgrid((5, 4)) + dd2 = make_meshgrid((5, 3)) + assert not datasets_are_equal(dd1, dd2) + + def test_equal_with_meta(self): + dd = make_datadict() + dd.add_meta('info', 'value') + dd2 = dd.copy() + assert datasets_are_equal(dd, dd2) + assert datasets_are_equal(dd, dd2, ignore_meta=True) + + def test_not_equal_meta_differs(self): + dd = make_datadict() + dd.add_meta('info', 'value') + dd2 = dd.copy() + dd2.set_meta('info', 'different') + assert not datasets_are_equal(dd, dd2) + assert datasets_are_equal(dd, dd2, ignore_meta=True) diff --git a/test/pytest/test_gridder_comprehensive.py b/test/pytest/test_gridder_comprehensive.py new file mode 100644 index 00000000..a870f4ca --- /dev/null +++ b/test/pytest/test_gridder_comprehensive.py @@ -0,0 +1,471 @@ +""" +test_gridder_comprehensive.py + +Comprehensive tests for the DataGridder node and underlying gridding functions. +Covers all GridOption paths, various data shapes, edge cases, and input types. 
import numpy as np
import pytest

from plottr.data.datadict import (
    DataDict, MeshgridDataDict, DataDictBase,
    datadict_to_meshgrid, meshgrid_to_datadict,
    guess_shape_from_datadict, GriddingError,
)
from plottr.node.tools import linearFlowchart
from plottr.node.grid import DataGridder, GridOption
from plottr.utils.num import (
    guess_grid_from_sweep_direction, find_direction_period,
    _find_switches, array1d_to_meshgrid,
)

# Run the gridder node headless: no UI class is instantiated in these tests.
DataGridder.useUi = False
DataGridder.uiClass = None


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def make_griddable(shape, ndeps=1, noise=0.0):
    """Create a griddable DataDict from a meshgrid shape."""
    naxes = len(shape)
    ax_names = [f'ax{i}' for i in range(naxes)]
    axes_1d = [np.linspace(0, 1, s) for s in shape]
    grids = np.meshgrid(*axes_1d, indexing='ij')
    dd = DataDict()
    for i, ax in enumerate(ax_names):
        vals = grids[i].ravel()
        if noise > 0:
            # Optional jitter on the axis values to simulate measured sweeps.
            vals = vals + np.random.randn(vals.size) * noise
        dd[ax] = dict(values=vals, axes=[], unit='V', label=ax)
    for j in range(ndeps):
        dd[f'dep{j}'] = dict(values=np.random.randn(int(np.prod(shape))),
                             axes=ax_names[:], unit='A', label=f'dep{j}')
    dd.validate()
    return dd


def make_mesh(shape, ndeps=1):
    """Create a MeshgridDataDict of the given shape with `ndeps` dependents."""
    naxes = len(shape)
    ax_names = [f'ax{i}' for i in range(naxes)]
    axes_1d = [np.linspace(0, 1, s) for s in shape]
    grids = np.meshgrid(*axes_1d, indexing='ij')
    dd = MeshgridDataDict()
    for i, ax in enumerate(ax_names):
        dd[ax] = dict(values=grids[i], axes=[], unit='V', label=ax)
    for j in range(ndeps):
        dd[f'dep{j}'] = dict(values=np.random.randn(*shape),
                             axes=ax_names[:], unit='A', label=f'dep{j}')
    dd.validate()
    return dd


# ===========================================================================
# _find_switches
# ===========================================================================

class TestFindSwitches:
    def test_monotonic_no_switches(self):
        arr = np.linspace(0, 10, 100)
        assert len(_find_switches(arr)) == 0

    def test_single_sawtooth(self):
        arr = np.concatenate([np.arange(10), np.arange(10)])
        switches = _find_switches(arr)
        assert len(switches) >= 1

    def test_flat_array(self):
        arr = np.ones(50)
        assert len(_find_switches(arr)) == 0

    def test_with_nan(self):
        arr = np.linspace(0, 10, 100)
        arr[50] = np.nan
        switches = _find_switches(arr)
        assert isinstance(switches, np.ndarray)

    def test_short_array(self):
        arr = np.array([1.0, 2.0])
        switches = _find_switches(arr)
        assert isinstance(switches, np.ndarray)

    def test_single_element(self):
        arr = np.array([1.0])
        switches = _find_switches(arr)
        assert len(switches) == 0


# ===========================================================================
# find_direction_period
# ===========================================================================

class TestFindDirectionPeriod:
    def test_repeating_pattern(self):
        # 0,1,2,3,4, 0,1,2,3,4, 0,1,2,3,4
        arr = np.tile(np.arange(5, dtype=float), 3)
        period = find_direction_period(arr)
        assert period == 5

    def test_no_repetition(self):
        arr = np.linspace(0, 10, 100)
        period = find_direction_period(arr)
        assert period == np.inf

    def test_incomplete_last_period(self):
        arr = np.concatenate([np.tile(np.arange(5, dtype=float), 3),
                              np.arange(3, dtype=float)])
        period = find_direction_period(arr, ignore_last=True)
        assert period == 5

    def test_single_value(self):
        arr = np.array([1.0])
        period = find_direction_period(arr)
        assert period is not None  # should handle gracefully


# ===========================================================================
# guess_grid_from_sweep_direction
# ===========================================================================

class TestGuessGrid:
    @pytest.mark.parametrize("shape", [
        (10,), (5, 4), (3, 4, 2), (10, 10), (20, 15),
    ])
    def test_correct_shape_guessed(self, shape):
        naxes = len(shape)
        ax_names = [f'ax{i}' for i in range(naxes)]
        axes_1d = [np.linspace(0, 1, s) for s in shape]
        grids = np.meshgrid(*axes_1d, indexing='ij')
        kwargs = {ax_names[i]: grids[i].ravel() for i in range(naxes)}
        result = guess_grid_from_sweep_direction(**kwargs)
        assert result is not None
        _, guessed_shape = result
        assert guessed_shape == shape

    def test_noisy_axes(self):
        # Tiny jitter on one axis must not break the shape guess.
        shape = (10, 8)
        grids = np.meshgrid(np.linspace(0, 1, 10), np.linspace(0, 1, 8), indexing='ij')
        x = grids[0].ravel() + np.random.randn(80) * 1e-6
        y = grids[1].ravel()
        result = guess_grid_from_sweep_direction(x=x, y=y)
        assert result is not None
        _, guessed = result
        assert guessed == shape

    def test_single_axis(self):
        x = np.linspace(0, 1, 50)
        result = guess_grid_from_sweep_direction(x=x)
        assert result is not None
        _, shape = result
        assert shape == (50,)

    def test_empty_raises(self):
        with pytest.raises(ValueError):
            guess_grid_from_sweep_direction()

    def test_mismatched_sizes_raises(self):
        with pytest.raises(ValueError):
            guess_grid_from_sweep_direction(x=np.arange(10, dtype=float),
                                            y=np.arange(5, dtype=float))


# ===========================================================================
# array1d_to_meshgrid
# ===========================================================================

class TestArray1dToMeshgrid:
    def test_exact_reshape(self):
        arr = np.arange(12, dtype=float)
        result = array1d_to_meshgrid(arr, (3, 4))
        assert result.shape == (3, 4)

    def test_padding_with_nan(self):
        arr = np.arange(10, dtype=float)
        result = array1d_to_meshgrid(arr, (4, 4))  # needs 16, has 10
        assert result.shape == (4, 4)
        assert np.isnan(result.ravel()[-1])

    def test_truncation(self):
        arr = np.arange(20, dtype=float)
        result = array1d_to_meshgrid(arr, (3, 4))  # needs 12, has 20
        assert result.shape == (3, 4)

    def test_copy_true_independent(self):
        arr = np.arange(12, dtype=float)
        result = array1d_to_meshgrid(arr, (3, 4), copy=True)
        result[0, 0] = 999
        assert arr[0] != 999

    def test_copy_false_may_share(self):
        arr = np.arange(12, dtype=float)
        result = array1d_to_meshgrid(arr, (3, 4), copy=False)
        assert result.shape == (3, 4)

    def test_object_array_padding(self):
        arr = np.array([1, 2, 3], dtype=object)
        result = array1d_to_meshgrid(arr, (2, 3))  # needs 6, has 3
        assert result.shape == (2, 3)


# ===========================================================================
# guess_shape_from_datadict
# ===========================================================================

class TestGuessShapeFromDatadict:
    @pytest.mark.parametrize("shape", [
        (10, 5), (3, 4, 2), (20, 15),
    ])
    def test_guesses_correct_shape(self, shape):
        dd = make_griddable(shape)
        shapes = guess_shape_from_datadict(dd)
        for dep in dd.dependents():
            assert shapes[dep] is not None
            _, guessed = shapes[dep]
            assert guessed == shape

    def test_with_multiple_deps(self):
        dd = make_griddable((10, 8), ndeps=3)
        shapes = guess_shape_from_datadict(dd)
        assert len(shapes) == 3
        for dep in dd.dependents():
            assert shapes[dep] is not None


# ===========================================================================
# datadict_to_meshgrid
# ===========================================================================

class TestDatadictToMeshgrid:
    @pytest.mark.parametrize("shape", [
        (5,), (5, 4), (10, 10), (3, 4, 2),
    ])
    def test_produces_correct_shape(self, shape):
        dd = make_griddable(shape)
        mesh = datadict_to_meshgrid(dd)
        assert isinstance(mesh, MeshgridDataDict)
        assert mesh.shape() == shape

    def test_with_target_shape(self):
        dd = make_griddable((5, 4))
        mesh = datadict_to_meshgrid(dd, target_shape=(5, 4))
        assert mesh.shape() == (5, 4)

    def test_with_inner_axis_order(self):
        # Create data where inner order doesn't match axes order
        x = np.arange(5, dtype=float)
        y = np.linspace(0, 1, 4)
        xx, yy = np.meshgrid(x, y, indexing='xy')  # xy order
        dd = DataDict(
            x=dict(values=xx.ravel()), y=dict(values=yy.ravel()),
            z=dict(values=(xx * yy).ravel(), axes=['x', 'y']),
        )
        dd.validate()
        mesh = datadict_to_meshgrid(dd, target_shape=(4, 5),
                                    inner_axis_order=['y', 'x'])
        assert isinstance(mesh, MeshgridDataDict)

    def test_use_existing_shape(self):
        """use_existing_shape works when data already has the right shape."""
        # Need data with nested array shapes matching target
        x = np.arange(5, dtype=float)
        y = np.linspace(0, 1, 4)
        xx, yy = np.meshgrid(x, y, indexing='ij')
        dd = DataDict(
            x=dict(values=xx),  # already (5,4) shaped
            y=dict(values=yy),
            z=dict(values=xx * yy, axes=['x', 'y']),
        )
        dd.validate()
        mesh = datadict_to_meshgrid(dd, use_existing_shape=True)
        assert isinstance(mesh, MeshgridDataDict)
        assert mesh.shape() == (5, 4)

    def test_copy_false(self):
        dd = make_griddable((5, 4))
        mesh = datadict_to_meshgrid(dd, copy=False)
        assert isinstance(mesh, MeshgridDataDict)

    def test_preserves_meta(self):
        dd = make_griddable((5, 4))
        dd.add_meta('info', 'test')
        mesh = datadict_to_meshgrid(dd)
        assert mesh.meta_val('info') == 'test'

    def test_incompatible_axes_raises(self):
        # Two independent axis systems (x->y, z->w) cannot share one grid.
        dd = DataDict(
            x=dict(values=np.arange(10, dtype=float)),
            y=dict(values=np.arange(10, dtype=float), axes=['x']),
            z=dict(values=np.arange(10, dtype=float)),
            w=dict(values=np.arange(10, dtype=float), axes=['z']),
        )
        dd.validate()
        with pytest.raises(GriddingError):
            datadict_to_meshgrid(dd)

    def test_empty_datadict(self):
        dd = DataDict()
        dd.validate()
        mesh = datadict_to_meshgrid(dd)
        assert isinstance(mesh, MeshgridDataDict)

    def test_incomplete_data_pads_with_nan(self):
        # 5x4 grid but only 18 of 20 points
        shape = (5, 4)
        grids = np.meshgrid(np.linspace(0, 1, 5), np.linspace(0, 1, 4), indexing='ij')
        dd = DataDict(
            x=dict(values=grids[0].ravel()[:18]),
            y=dict(values=grids[1].ravel()[:18]),
            z=dict(values=np.random.randn(18), axes=['x', 'y']),
        )
        dd.validate()
        mesh = datadict_to_meshgrid(dd, target_shape=shape)
        assert mesh.shape() == shape
        # Last 2 values should be NaN
        assert np.isnan(mesh.data_vals('z').ravel()[-1])


# ===========================================================================
# meshgrid_to_datadict
# ===========================================================================

class TestMeshgridToDatadict:
    @pytest.mark.parametrize("shape", [
        (5, 4), (10, 10), (3, 4, 2),
    ])
    def test_produces_flat(self, shape):
        mesh = make_mesh(shape)
        dd = meshgrid_to_datadict(mesh)
        assert isinstance(dd, DataDict)
        assert dd.nrecords() == int(np.prod(shape))


# ===========================================================================
# DataGridder node — all GridOption paths
# ===========================================================================

class TestDataGridderNode:

    # --- DataDict input ---

    @pytest.mark.parametrize("shape", [
        (10,), (5, 4), (3, 4, 2),
    ])
    def test_noGrid_tabular_passthrough(self, qtbot, shape):
        dd = make_griddable(shape)
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.noGrid, {}
        out = fc.outputValues()['dataOut']
        assert isinstance(out, DataDict)

    @pytest.mark.parametrize("shape", [
        (10,), (5, 4), (10, 10), (50, 3), (3, 4, 2),
    ])
    def test_guessShape_tabular(self, qtbot, shape):
        dd = make_griddable(shape)
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.guessShape, {}
        out = fc.outputValues()['dataOut']
        assert isinstance(out, MeshgridDataDict)
        assert out.shape() == shape

    def test_specifyShape_tabular(self, qtbot):
        dd = make_griddable((5, 4))
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.specifyShape, dict(
            shape=(5, 4), order=['ax0', 'ax1'])
        out = fc.outputValues()['dataOut']
        assert isinstance(out, MeshgridDataDict)
        assert out.shape() == (5, 4)

    def test_metadataShape_tabular(self, qtbot):
        dd = make_griddable((5, 4))
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.metadataShape, {}
        out = fc.outputValues()['dataOut']
        # metadataShape uses existing shape from data arrays
        assert out is not None

    # --- MeshgridDataDict input ---

    def test_noGrid_meshgrid_flattens(self, qtbot):
        mesh = make_mesh((5, 4))
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=mesh)
        fc.nodes()['g'].grid = GridOption.noGrid, {}
        out = fc.outputValues()['dataOut']
        assert isinstance(out, DataDict)
        assert out.nrecords() == 20  # 5 * 4 records after flattening

    def test_guessShape_meshgrid_passthrough(self, qtbot):
        mesh = make_mesh((5, 4))
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=mesh)
        fc.nodes()['g'].grid = GridOption.guessShape, {}
        out = fc.outputValues()['dataOut']
        assert isinstance(out, MeshgridDataDict)

    def test_specifyShape_meshgrid_warns(self, qtbot):
        mesh = make_mesh((5, 4))
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=mesh)
        fc.nodes()['g'].grid = GridOption.specifyShape, dict(shape=(5, 4))
        out = fc.outputValues()['dataOut']
        # Should pass through with warning
        assert isinstance(out, MeshgridDataDict)

    def test_metadataShape_meshgrid_passthrough(self, qtbot):
        mesh = make_mesh((5, 4))
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=mesh)
        fc.nodes()['g'].grid = GridOption.metadataShape, {}
        out = fc.outputValues()['dataOut']
        assert isinstance(out, MeshgridDataDict)

    # --- Edge cases ---

    def test_gridding_error_falls_back(self, qtbot):
        """Data that can't be gridded should fall back to noGrid."""
        dd = DataDict(
            x=dict(values=np.array([1.0, 1.0, 2.0, 2.0, 3.0])),
            y=dict(values=np.array([1.0, 2.0, 1.0, 2.0, 1.0])),
            z=dict(values=np.random.randn(5), axes=['x', 'y']),
        )
        dd.validate()
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.guessShape, {}
        out = fc.outputValues()['dataOut']
        # Should not crash; may fall back to expanded DataDict
        assert out is not None

    def test_does_not_mutate_input(self, qtbot):
        dd = make_griddable((10, 8))
        ref_vals = {k: v['values'].copy() for k, v in dd.data_items()}
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.guessShape, {}
        _ = fc.outputValues()['dataOut']
        for k, orig in ref_vals.items():
            assert np.array_equal(dd.data_vals(k), orig), f"{k} was mutated"

    def test_multiple_deps(self, qtbot):
        dd = make_griddable((5, 4), ndeps=3)
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.guessShape, {}
        out = fc.outputValues()['dataOut']
        assert len(out.dependents()) == 3

    def test_with_noisy_axes(self, qtbot):
        dd = make_griddable((10, 8), noise=1e-6)
        fc = linearFlowchart(('g', DataGridder))
        fc.setInput(dataIn=dd)
        fc.nodes()['g'].grid = GridOption.guessShape, {}
        out = fc.outputValues()['dataOut']
        assert out is not None
        assert isinstance(out, MeshgridDataDict)
diff --git a/test/pytest/test_latex.py b/test/pytest/test_latex.py
new file mode 100644
index 00000000..def0b01b
--- /dev/null
+++ b/test/pytest/test_latex.py
@@ -0,0 +1,172 @@
"""Tests for plottr.utils.latex — LaTeX to HTML conversion."""
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st

from plottr.utils.latex import latex_to_html


class TestGreekLetters:
    def test_alpha(self):
        assert latex_to_html(r'\alpha') == '\u03b1'

    def test_beta(self):
        assert latex_to_html(r'\beta') == '\u03b2'

    def test_gamma(self):
        assert latex_to_html(r'\gamma') == '\u03b3'

    def test_omega_upper(self):
        assert latex_to_html(r'\Omega') == '\u03a9'

    def test_mu(self):
        assert latex_to_html(r'\mu') ==
'\u03bc' + + def test_pi(self): + assert latex_to_html(r'\pi') == '\u03c0' + + +class TestMathSymbols: + def test_hbar(self): + result = latex_to_html(r'\hbar') + # unicodeit may return ℏ (U+210F) or ħ (U+0127) depending on version + assert result in ('\u0127', '\u210f') + + def test_partial(self): + assert latex_to_html(r'\partial') == '\u2202' + + def test_infty(self): + assert latex_to_html(r'\infty') == '\u221e' + + def test_int(self): + assert latex_to_html(r'\int') == '\u222b' + + def test_sum(self): + assert latex_to_html(r'\sum') == '\u2211' + + +class TestSubscripts: + def test_braced_text(self): + assert latex_to_html(r'V_{gate}') == 'V<sub>gate</sub>' + + def test_braced_numbers(self): + assert latex_to_html(r'g_{11}') == 'g<sub>11</sub>' + + def test_braced_multi(self): + assert latex_to_html(r'I_{DS}') == 'I<sub>DS</sub>' + + def test_mixed(self): + result = latex_to_html(r'V_{SD}') + assert '<sub>' in result + assert 'SD' in result + + +class TestSuperscripts: + def test_braced(self): + assert latex_to_html(r'x^{2}') == 'x<sup>2</sup>' + + def test_braced_text(self): + result = latex_to_html(r'e^{i\pi}') + assert '<sup>' in result + assert '\u03c0' in result + + +class TestFractions: + def test_simple(self): + assert latex_to_html(r'\frac{dI}{dV}') == 'dI/dV' + + def test_with_symbols(self): + result = latex_to_html(r'\frac{\partial I}{\partial V}') + assert 'I' in result and 'V' in result and '/' in result + + +class TestSqrt: + def test_simple(self): + result = latex_to_html(r'\sqrt{x}') + assert result == '\u221ax' + + +class TestDollarDelimiters: + def test_stripped(self): + result = latex_to_html(r'$\alpha$') + assert result == '\u03b1' + + def test_inline(self): + result = latex_to_html(r'Signal ($\mu$V)') + assert '\u03bc' in result + assert '$' not in result + + +class TestPassthrough: + def test_plain_text(self): + assert latex_to_html('voltage') == 'voltage' + + def test_empty(self): + assert latex_to_html('') == '' + + def test_units(self): + assert latex_to_html('mV') == 'mV' + + def 
test_with_parens(self): + assert latex_to_html('amplitude (V)') == 'amplitude (V)' + + def test_plain_underscore(self): + """Plain underscores (no braces) should NOT become subscripts.""" + assert latex_to_html('gate_voltage') == 'gate_voltage' + + def test_multiple_underscores(self): + assert latex_to_html('my_long_variable_name') == 'my_long_variable_name' + + def test_snake_case_with_numbers(self): + assert latex_to_html('channel_1_amplitude') == 'channel_1_amplitude' + + def test_plain_caret(self): + """Plain carets (no braces) in non-LaTeX strings pass through.""" + assert latex_to_html('x^2') == 'x^2' + + def test_plain_underscore_single(self): + assert latex_to_html('x_0') == 'x_0' + + +class TestRealWorldLabels: + """Labels commonly seen in quantum physics experiments.""" + + def test_conductance(self): + result = latex_to_html(r'g_{11}') + assert '<sub>' in result and '11' in result + + def test_gate_voltage(self): + result = latex_to_html(r'V_{gate}') + assert 'gate' in result + assert '<sub>' in result + + def test_bias_voltage(self): + result = latex_to_html(r'V_{SD}') + assert 'SD' in result + + def test_differential_conductance(self): + result = latex_to_html(r'$\frac{dI}{dV}$') + assert 'dI/dV' in result + + def test_magnetic_field(self): + result = latex_to_html(r'B_{field} (T)') + assert '<sub>' in result + assert '(T)' in result + + +class TestHypothesis: + @given(st.text(min_size=0, max_size=100)) + @settings(max_examples=200) + def test_never_crashes(self, text): + """latex_to_html should never raise on any input.""" + result = latex_to_html(text) + assert isinstance(result, str) + + @given(st.text(alphabet='abcdefghijklmnopqrstuvwxyz0123456789 .,()', + min_size=0, max_size=50)) + @settings(max_examples=100) + def test_plain_text_passthrough(self, text): + """Text without LaTeX indicators should pass through unchanged.""" + result = latex_to_html(text) + # Without backslash-letter, $, _{, or ^{, text is returned as-is. 
+ assert result == text diff --git a/test/pytest/test_pipeline_coverage.py b/test/pytest/test_pipeline_coverage.py new file mode 100644 index 00000000..5551f69a --- /dev/null +++ b/test/pytest/test_pipeline_coverage.py @@ -0,0 +1,632 @@ +""" +test_pipeline_coverage.py + +Comprehensive pipeline tests exercising every plottr node with various data +shapes, structures, and dtypes. + +Uses two approaches: +- hypothesis @given for pure DataDict/MeshgridDataDict operations (no Qt needed) +- pytest parametrize + qtbot for flowchart-based node tests (needs QApplication) +""" +import numpy as np +import pytest +from hypothesis import given, settings, HealthCheck +from hypothesis import strategies as st + +from plottr.data.datadict import ( + DataDict, + DataDictBase, + MeshgridDataDict, + datadict_to_meshgrid, + meshgrid_to_datadict, +) +from plottr.node.tools import linearFlowchart +from plottr.node.node import Node +from plottr.node.data_selector import DataSelector +from plottr.node.grid import DataGridder, GridOption +from plottr.node.dim_reducer import DimensionReducer, XYSelector, ReductionMethod +from plottr.node.scaleunits import ScaleUnits +from plottr.node.filter.correct_offset import SubtractAverage +from plottr.node.histogram import Histogrammer +from plottr.utils import num + + +# --------------------------------------------------------------------------- +# Disable UI for all node classes within this module's tests only. +# We save/restore originals via a session-scoped fixture. 
+# --------------------------------------------------------------------------- + +_ORIGINAL_UI_SETTINGS = {} + +@pytest.fixture(autouse=True, scope="module") +def _disable_ui_for_module(): + """Temporarily disable UIs for all node classes during this module's tests.""" + classes = [DataSelector, DataGridder, DimensionReducer, XYSelector, + ScaleUnits, SubtractAverage, Histogrammer] + for cls in classes: + _ORIGINAL_UI_SETTINGS[cls] = (cls.useUi, cls.uiClass) + cls.useUi = False + cls.uiClass = None + yield + for cls in classes: + cls.useUi, cls.uiClass = _ORIGINAL_UI_SETTINGS[cls] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_griddable_dd(shape, ndeps=1): + """Create a DataDict from a meshgrid shape (flattened).""" + naxes = len(shape) + ax_names = [f'ax{i}' for i in range(naxes)] + axes_1d = [np.linspace(0, 1, s) for s in shape] + grids = np.meshgrid(*axes_1d, indexing='ij') + + dd = DataDict() + for i, ax in enumerate(ax_names): + dd[ax] = dict(values=grids[i].ravel(), axes=[], unit='V', label=ax) + for j in range(ndeps): + dd[f'dep{j}'] = dict(values=np.random.randn(int(np.prod(shape))), + axes=ax_names.copy(), unit='A', label=f'dep{j}') + dd.validate() + return dd + + +def make_mesh(shape, ndeps=1): + """Create a MeshgridDataDict.""" + naxes = len(shape) + ax_names = [f'ax{i}' for i in range(naxes)] + axes_1d = [np.linspace(0, 1, s) for s in shape] + grids = np.meshgrid(*axes_1d, indexing='ij') + + dd = MeshgridDataDict() + for i, ax in enumerate(ax_names): + dd[ax] = dict(values=grids[i], axes=[], unit='V', label=ax) + for j in range(ndeps): + dd[f'dep{j}'] = dict(values=np.random.randn(*shape), + axes=ax_names.copy(), unit='A', label=f'dep{j}') + dd.validate() + return dd + + +def snapshot_values(dd): + return {k: v['values'].copy() for k, v in dd.data_items()} + + +def assert_not_mutated(dd, snap): + for k, orig in 
snap.items(): + assert num.arrays_equal(np.asarray(orig), np.asarray(dd.data_vals(k))), \ + f"Field {k} was mutated" + + +# --------------------------------------------------------------------------- +# Hypothesis strategies for pure data operations +# --------------------------------------------------------------------------- + +@st.composite +def griddable_datadict_st(draw, min_axis_len=2, max_axis_len=12, + min_axes=1, max_axes=3, min_deps=1, max_deps=3): + naxes = draw(st.integers(min_value=min_axes, max_value=max_axes)) + ndeps = draw(st.integers(min_value=min_deps, max_value=max_deps)) + shape = tuple(draw(st.integers(min_value=min_axis_len, max_value=max_axis_len)) + for _ in range(naxes)) + return make_griddable_dd(shape, ndeps) + + +@st.composite +def meshgrid_st(draw, min_axis_len=2, max_axis_len=12, + min_axes=1, max_axes=3, min_deps=1, max_deps=3): + naxes = draw(st.integers(min_value=min_axes, max_value=max_axes)) + ndeps = draw(st.integers(min_value=min_deps, max_value=max_deps)) + shape = tuple(draw(st.integers(min_value=min_axis_len, max_value=max_axis_len)) + for _ in range(naxes)) + return make_mesh(shape, ndeps) + + +# =========================================================================== +# PART A: HYPOTHESIS TESTS — pure DataDict operations (no Qt) +# =========================================================================== + +class TestDataDictOperationsHypothesis: + """Property-based tests for DataDict operations that don't need a QApplication.""" + + @given(data=griddable_datadict_st(min_axes=1, max_axes=3, min_axis_len=3)) + @settings(max_examples=50, deadline=10000) + def test_gridding_roundtrip_structure(self, data): + """Gridding should produce a MeshgridDataDict with matching structure.""" + try: + mesh = datadict_to_meshgrid(data) + except (ValueError, Exception): + # Some shapes may not grid cleanly; that's expected for edge cases + return + assert isinstance(mesh, MeshgridDataDict) + assert set(mesh.axes()) == set(data.axes()) 
+ assert set(mesh.dependents()) == set(data.dependents()) + + @given(data=meshgrid_st(min_axes=1, max_axes=3)) + @settings(max_examples=50, deadline=5000) + def test_flatten_roundtrip(self, data): + """Flatten to DataDict and back should preserve shapes.""" + flat = meshgrid_to_datadict(data) + assert isinstance(flat, DataDict) + assert flat.nrecords() == int(np.prod(data.shape())) + + @given(data=griddable_datadict_st(min_axes=1, max_axes=3)) + @settings(max_examples=30, deadline=5000) + def test_copy_preserves_equality(self, data): + """copy() should produce an equal dataset.""" + data2 = data.copy() + assert data == data2 + + @given(data=griddable_datadict_st(min_deps=2, max_deps=4)) + @settings(max_examples=30, deadline=5000) + def test_extract_produces_subset(self, data): + """extract() should return only the requested deps and their axes.""" + dep = data.dependents()[0] + ex = data.extract([dep]) + assert ex.dependents() == [dep] + assert set(ex.axes()) == set(data.axes(dep)) + + @given(data=meshgrid_st(min_axes=2, max_axes=3)) + @settings(max_examples=30, deadline=5000) + def test_meshgrid_copy_independent(self, data): + """Copied MeshgridDataDict must be independent.""" + data2 = data.copy() + data2[data2.dependents()[0]]['values'].flat[0] = 999.0 + assert data[data.dependents()[0]]['values'].flat[0] != 999.0 + + @given(data=meshgrid_st(min_axes=2, max_axes=3)) + @settings(max_examples=20, deadline=5000) + def test_mask_invalid_clean_data(self, data): + """mask_invalid on clean data should not change values.""" + data2 = data.copy() + data2 = data2.mask_invalid() + for dep in data.dependents(): + assert np.allclose( + np.asarray(data.data_vals(dep)), + np.asarray(data2.data_vals(dep)), + ) + + @given(data=meshgrid_st(min_axes=2, max_axes=2)) + @settings(max_examples=20, deadline=5000) + def test_mean_removes_axis(self, data): + """mean() should remove the averaged axis.""" + ax = data.axes()[0] + result = data.mean(ax) + assert ax not in result.axes() + + 
@given(data=meshgrid_st(min_axes=2, max_axes=2, min_axis_len=4)) + @settings(max_examples=20, deadline=5000) + def test_slice_preserves_validity(self, data): + """Slicing should produce valid data.""" + ax = data.axes()[0] + result = data.slice(**{ax: slice(1, 3)}) + assert result.validate() + + +# =========================================================================== +# PART B: FLOWCHART-BASED NODE TESTS (need qtbot for QApplication) +# =========================================================================== + +# --- Node base --- + +def test_node_passthrough(qtbot): + data = make_griddable_dd((5, 4)) + fc = linearFlowchart(('n', Node)) + fc.setInput(dataIn=data) + assert fc.outputValues()['dataOut'] is data + + +# --- DataSelector --- + +class TestDataSelectorFC: + + @pytest.mark.parametrize("shape,ndeps", [ + ((10,), 2), ((5, 4), 2), ((3, 3, 2), 3), + ]) + def test_select_single_dep(self, qtbot, shape, ndeps): + data = make_griddable_dd(shape, ndeps) + fc = linearFlowchart(('sel', DataSelector)) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = [data.dependents()[0]] + out = fc.outputValues()['dataOut'] + assert out is not None + assert out.dependents() == [data.dependents()[0]] + + @pytest.mark.parametrize("shape,ndeps", [ + ((10,), 3), ((5, 4), 2), + ]) + def test_select_multiple_deps(self, qtbot, shape, ndeps): + data = make_griddable_dd(shape, ndeps) + fc = linearFlowchart(('sel', DataSelector)) + fc.setInput(dataIn=data) + deps = data.dependents()[:2] + fc.nodes()['sel'].selectedData = deps + out = fc.outputValues()['dataOut'] + assert out is not None + assert set(out.dependents()) == set(deps) + + def test_select_does_not_mutate(self, qtbot): + data = make_griddable_dd((10,), 2) + snap = snapshot_values(data) + fc = linearFlowchart(('sel', DataSelector)) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = [data.dependents()[0]] + _ = fc.outputValues()['dataOut'] + assert_not_mutated(data, snap) + + +# --- DataGridder --- + 
+class TestDataGridderFC: + + @pytest.mark.parametrize("shape", [ + (5,), (5, 4), (10, 10), (3, 4, 2), (5, 5, 5), + ]) + def test_guess_shape(self, qtbot, shape): + data = make_griddable_dd(shape) + fc = linearFlowchart(('grid', DataGridder)) + fc.setInput(dataIn=data) + fc.nodes()['grid'].grid = GridOption.guessShape, {} + out = fc.outputValues()['dataOut'] + assert out is not None + assert isinstance(out, MeshgridDataDict) + assert out.shape() == shape + + @pytest.mark.parametrize("shape", [ + (5, 4), (3, 3, 2), + ]) + def test_specify_shape(self, qtbot, shape): + data = make_griddable_dd(shape) + ax_names = data.axes(data.dependents()[0]) + fc = linearFlowchart(('grid', DataGridder)) + fc.setInput(dataIn=data) + fc.nodes()['grid'].grid = GridOption.specifyShape, dict( + shape=shape, order=ax_names, + ) + out = fc.outputValues()['dataOut'] + assert out is not None + assert isinstance(out, MeshgridDataDict) + assert out.shape() == shape + + def test_nogrid_passthrough(self, qtbot): + data = make_griddable_dd((5, 4)) + fc = linearFlowchart(('grid', DataGridder)) + fc.setInput(dataIn=data) + fc.nodes()['grid'].grid = GridOption.noGrid, {} + out = fc.outputValues()['dataOut'] + assert out is not None + assert isinstance(out, DataDict) + + def test_meshgrid_passthrough_guess(self, qtbot): + data = make_mesh((5, 4)) + fc = linearFlowchart(('grid', DataGridder)) + fc.setInput(dataIn=data) + fc.nodes()['grid'].grid = GridOption.guessShape, {} + out = fc.outputValues()['dataOut'] + assert isinstance(out, MeshgridDataDict) + + def test_meshgrid_to_flat_nogrid(self, qtbot): + data = make_mesh((5, 4)) + fc = linearFlowchart(('grid', DataGridder)) + fc.setInput(dataIn=data) + fc.nodes()['grid'].grid = GridOption.noGrid, {} + out = fc.outputValues()['dataOut'] + assert isinstance(out, DataDict) + assert out.nrecords() == 20 + + def test_gridder_does_not_mutate(self, qtbot): + data = make_griddable_dd((5, 4)) + snap = snapshot_values(data) + fc = linearFlowchart(('grid', 
DataGridder)) + fc.setInput(dataIn=data) + fc.nodes()['grid'].grid = GridOption.guessShape, {} + _ = fc.outputValues()['dataOut'] + assert_not_mutated(data, snap) + + +# --- DimensionReducer --- + +class TestDimensionReducerFC: + + @pytest.mark.parametrize("shape", [ + (5, 4), (4, 3, 2), + ]) + def test_element_selection(self, qtbot, shape): + data = make_mesh(shape) + fc = linearFlowchart(('red', DimensionReducer)) + fc.setInput(dataIn=data) + last_ax = data.axes()[-1] + fc.nodes()['red'].reductions = { + last_ax: (ReductionMethod.elementSelection, [], {'index': 0}) + } + out = fc.outputValues()['dataOut'] + assert out is not None + assert last_ax not in out.axes() + + @pytest.mark.parametrize("shape", [ + (5, 4), (4, 3, 2), + ]) + def test_average_reduction(self, qtbot, shape): + data = make_mesh(shape) + fc = linearFlowchart(('red', DimensionReducer)) + fc.setInput(dataIn=data) + last_ax = data.axes()[-1] + fc.nodes()['red'].reductions = {last_ax: (ReductionMethod.average,)} + out = fc.outputValues()['dataOut'] + assert out is not None + assert last_ax not in out.axes() + + def test_reducer_does_not_mutate(self, qtbot): + data = make_mesh((5, 4)) + snap = snapshot_values(data) + fc = linearFlowchart(('red', DimensionReducer)) + fc.setInput(dataIn=data) + fc.nodes()['red'].reductions = { + 'ax1': (ReductionMethod.elementSelection, [], {'index': 0}) + } + _ = fc.outputValues()['dataOut'] + assert_not_mutated(data, snap) + + +# --- XYSelector --- + +class TestXYSelectorFC: + + @pytest.mark.parametrize("shape", [ + (5, 4), (8, 6), (4, 3, 2), (5, 5, 5), + ]) + def test_xy_produces_2d(self, qtbot, shape): + data = make_mesh(shape) + axes = data.axes() + fc = linearFlowchart(('xy', XYSelector)) + fc.setInput(dataIn=data) + fc.nodes()['xy'].xyAxes = (axes[0], axes[1]) + out = fc.outputValues()['dataOut'] + assert out is not None + for dep in out.dependents(): + assert out.data_vals(dep).ndim == 2 + + def test_xy_1d_x_only(self, qtbot): + data = make_mesh((10,)) + fc = 
linearFlowchart(('xy', XYSelector)) + fc.setInput(dataIn=data) + fc.nodes()['xy'].xyAxes = ('ax0', None) + out = fc.outputValues()['dataOut'] + assert out is not None + for dep in out.dependents(): + assert out.data_vals(dep).ndim == 1 + + def test_xy_no_axes_returns_none(self, qtbot): + data = make_mesh((5, 4)) + fc = linearFlowchart(('xy', XYSelector)) + fc.setInput(dataIn=data) + assert fc.outputValues()['dataOut'] is None + + def test_xy_does_not_mutate(self, qtbot): + data = make_mesh((5, 4, 3)) + snap = snapshot_values(data) + fc = linearFlowchart(('xy', XYSelector)) + fc.setInput(dataIn=data) + fc.nodes()['xy'].xyAxes = ('ax0', 'ax1') + _ = fc.outputValues()['dataOut'] + assert_not_mutated(data, snap) + + +# --- ScaleUnits --- + +class TestScaleUnitsFC: + + @pytest.mark.parametrize("scale,prefix_substr", [ + (1e-9, 'n'), (1e-6, '\u03bc'), (1e-3, 'm'), (1e6, 'M'), (1e9, 'G'), + ]) + def test_si_prefix(self, qtbot, scale, prefix_substr): + dd = DataDict( + x=dict(values=np.arange(5, dtype=float) * scale, unit='V'), + y=dict(values=np.arange(5, dtype=float), axes=['x'], unit='A'), + ) + dd.validate() + fc = linearFlowchart(('su', ScaleUnits)) + fc.setInput(dataIn=dd) + out = fc.outputValues()['dataOut'] + assert prefix_substr in out['x']['unit'] + + def test_does_not_mutate(self, qtbot): + dd = DataDict( + x=dict(values=np.arange(5, dtype=float) * 1e-9, unit='V'), + y=dict(values=np.arange(5, dtype=float), axes=['x'], unit='A'), + ) + dd.validate() + snap = snapshot_values(dd) + fc = linearFlowchart(('su', ScaleUnits)) + fc.setInput(dataIn=dd) + _ = fc.outputValues()['dataOut'] + assert_not_mutated(dd, snap) + + +# --- SubtractAverage --- + +class TestSubtractAverageFC: + + @pytest.mark.parametrize("shape", [ + (10, 5), (5, 4, 3), + ]) + def test_subtract_axis(self, qtbot, shape): + data = make_mesh(shape) + ax = data.axes()[-1] + fc = linearFlowchart(('sa', SubtractAverage)) + fc.setInput(dataIn=data) + fc.nodes()['sa'].averagingAxis = ax + out = 
fc.outputValues()['dataOut'] + assert out is not None + # After subtraction, mean along that axis should be ~0 + ax_idx = data.axes().index(ax) + for dep in out.dependents(): + avg = out.data_vals(dep).mean(axis=ax_idx) + assert np.allclose(avg, 0, atol=1e-10) + + def test_no_axis_passthrough(self, qtbot): + data = make_mesh((5, 4)) + fc = linearFlowchart(('sa', SubtractAverage)) + fc.setInput(dataIn=data) + out = fc.outputValues()['dataOut'] + assert out is not None + + def test_does_not_mutate(self, qtbot): + data = make_mesh((5, 4)) + snap = snapshot_values(data) + fc = linearFlowchart(('sa', SubtractAverage)) + fc.setInput(dataIn=data) + fc.nodes()['sa'].averagingAxis = 'ax1' + _ = fc.outputValues()['dataOut'] + assert_not_mutated(data, snap) + + +# --- Histogrammer --- + +class TestHistogrammerFC: + + @pytest.mark.parametrize("shape,hist_ax", [ + ((20, 5), 'ax0'), + ((10, 8), 'ax1'), + ((5, 4, 3), 'ax0'), + ]) + def test_histogram_produces_counts(self, qtbot, shape, hist_ax): + data = make_mesh(shape) + fc = linearFlowchart(('h', Histogrammer)) + fc.setInput(dataIn=data) + fc.nodes()['h'].nbins = 10 + fc.nodes()['h'].histogramAxis = hist_ax + out = fc.outputValues()['dataOut'] + assert out is not None + assert any('count' in d for d in out.dependents()) + + def test_no_axis_passthrough(self, qtbot): + data = make_mesh((10, 5)) + fc = linearFlowchart(('h', Histogrammer)) + fc.setInput(dataIn=data) + out = fc.outputValues()['dataOut'] + assert out is not None + + +# =========================================================================== +# FULL PIPELINE INTEGRATION +# =========================================================================== + +class TestFullPipelineFC: + + @pytest.mark.parametrize("shape", [ + (5, 4), (10, 10), (8, 6), (3, 4, 2), + ]) + def test_selector_gridder_xy(self, qtbot, shape): + data = make_griddable_dd(shape, ndeps=2) + fc = linearFlowchart( + ('sel', DataSelector), + ('grid', DataGridder), + ('xy', XYSelector), + ) + 
fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = [data.dependents()[0]] + fc.nodes()['grid'].grid = GridOption.guessShape, {} + axes = data.axes(data.dependents()[0]) + fc.nodes()['xy'].xyAxes = (axes[0], axes[1]) + + out = fc.outputValues()['dataOut'] + assert out is not None + assert isinstance(out, MeshgridDataDict) + + def test_full_pipeline_does_not_mutate(self, qtbot): + data = make_griddable_dd((8, 6)) + snap = snapshot_values(data) + fc = linearFlowchart( + ('sel', DataSelector), + ('grid', DataGridder), + ('xy', XYSelector), + ) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = [data.dependents()[0]] + fc.nodes()['grid'].grid = GridOption.guessShape, {} + fc.nodes()['xy'].xyAxes = ('ax0', 'ax1') + _ = fc.outputValues()['dataOut'] + assert_not_mutated(data, snap) + + def test_full_with_scale_and_subtract(self, qtbot): + data = make_griddable_dd((6, 5)) + # Give units to exercise ScaleUnits + data['ax0']['unit'] = 'V' + data['ax0']['values'] *= 1e-9 + data['dep0']['unit'] = 'A' + + fc = linearFlowchart( + ('sel', DataSelector), + ('grid', DataGridder), + ('xy', XYSelector), + ('sa', SubtractAverage), + ('su', ScaleUnits), + ) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = ['dep0'] + fc.nodes()['grid'].grid = GridOption.guessShape, {} + fc.nodes()['xy'].xyAxes = ('ax0', 'ax1') + fc.nodes()['sa'].averagingAxis = 'ax1' + + out = fc.outputValues()['dataOut'] + assert out is not None + + @pytest.mark.parametrize("dtype", [ + np.float64, np.float32, np.complex128, + ]) + def test_pipeline_various_dtypes(self, qtbot, dtype): + shape = (5, 4) + data = make_griddable_dd(shape) + z = np.random.randn(20).astype(dtype) + if np.issubdtype(dtype, np.complexfloating): + z = z + 1j * np.random.randn(20).astype(dtype) + data['dep0']['values'] = z + + fc = linearFlowchart( + ('sel', DataSelector), + ('grid', DataGridder), + ) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = ['dep0'] + fc.nodes()['grid'].grid = 
GridOption.guessShape, {} + out = fc.outputValues()['dataOut'] + assert out is not None + assert np.issubdtype(out.data_vals('dep0').dtype, dtype) + + def test_pipeline_with_nan_data(self, qtbot): + """Pipeline with incomplete data (NaN values).""" + data = make_griddable_dd((6, 5)) + # Inject NaN at end (simulating incomplete sweep) + data['dep0']['values'][-5:] = np.nan + data['ax0']['values'][-5:] = np.nan + data['ax1']['values'][-5:] = np.nan + + fc = linearFlowchart( + ('sel', DataSelector), + ('grid', DataGridder), + ) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = ['dep0'] + fc.nodes()['grid'].grid = GridOption.guessShape, {} + out = fc.outputValues()['dataOut'] + # Should handle NaN gracefully (either grid or fall back) + assert out is not None + + def test_pipeline_with_multiple_deps(self, qtbot): + """Pipeline selecting multiple compatible dependents.""" + data = make_griddable_dd((5, 4), ndeps=3) + fc = linearFlowchart( + ('sel', DataSelector), + ('grid', DataGridder), + ('xy', XYSelector), + ) + fc.setInput(dataIn=data) + fc.nodes()['sel'].selectedData = data.dependents()[:2] + fc.nodes()['grid'].grid = GridOption.guessShape, {} + fc.nodes()['xy'].xyAxes = ('ax0', 'ax1') + out = fc.outputValues()['dataOut'] + assert out is not None + assert len(out.dependents()) == 2 diff --git a/test/pytest/test_plotting.py b/test/pytest/test_plotting.py index 29dc1f75..db482919 100644 --- a/test/pytest/test_plotting.py +++ b/test/pytest/test_plotting.py @@ -1,5 +1,11 @@ import matplotlib.pyplot as plt import numpy as np +import os +import pytest + +os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + +from plottr.data.datadict import MeshgridDataDict, DataDict from plottr.plot.mpl.plotting import PlotType, colorplot2d @@ -28,3 +34,274 @@ def test_colorplot2d_scatter_rgba_error(): y = np.array([[0.0, 0.0, 0.0]]) z = np.array([[5.08907021, 4.93923391, 5.11400073]]) colorplot2d(ax, x, y, z, PlotType.scatter2d) + + +# -- Axis orientation tests -- + +def 
_make_asymmetric_meshgrid(): + """5×3 meshgrid with unique Z per position.""" + x = np.linspace(-2, 2, 5) + y = np.linspace(10, 30, 3) + xx, yy = np.meshgrid(x, y, indexing='ij') + zz = xx + 100 * yy + dd = MeshgridDataDict( + z=dict(values=zz, axes=['x', 'y']), + x=dict(values=xx), y=dict(values=yy), + ) + dd.validate() + return dd, xx, yy, zz + + +class TestAxisOrientation: + """Verify that 2D image plots have correct X/Y axis orientation.""" + + def test_pyqtgraph_image_data_shape(self, qtbot): + from plottr.plot.pyqtgraph.plots import PlotWithColorbar + _, xx, yy, zz = _make_asymmetric_meshgrid() + plot = PlotWithColorbar() + qtbot.addWidget(plot) + plot.setImage(xx, yy, zz) + # z is transposed for display: input (5, 3) → ImageItem (3, 5) + assert plot.img.image.shape == (3, 5) + + def test_pyqtgraph_image_rect(self, qtbot): + from plottr.plot.pyqtgraph.plots import PlotWithColorbar + from plottr import QtCore + _, xx, yy, zz = _make_asymmetric_meshgrid() + plot = PlotWithColorbar() + qtbot.addWidget(plot) + plot.setImage(xx, yy, zz) + expected = QtCore.QRectF( + xx.min(), yy.min(), xx.max() - xx.min(), yy.max() - yy.min() + ) + assert abs(expected.width() - (xx.max() - xx.min())) < 0.01 + assert abs(expected.height() - (yy.max() - yy.min())) < 0.01 + + def test_pyqtgraph_reversed_x(self, qtbot): + from plottr.plot.pyqtgraph.plots import PlotWithColorbar + x = np.linspace(2, -2, 5) + y = np.linspace(10, 30, 3) + xx, yy = np.meshgrid(x, y, indexing='ij') + zz = xx + 100 * yy + plot = PlotWithColorbar() + qtbot.addWidget(plot) + plot.setImage(xx, yy, zz) + assert plot.img.image.shape == (3, 5) # transposed + + def test_mpl_and_pyqtgraph_consistency(self, qtbot): + _, xx, yy, zz = _make_asymmetric_meshgrid() + from plottr.plot.mpl.plotting import plotImage + fig, ax = plt.subplots() + plotImage(ax, xx, yy, zz) + plt.close(fig) + from plottr.plot.pyqtgraph.plots import PlotWithColorbar + plot = PlotWithColorbar() + qtbot.addWidget(plot) + plot.setImage(xx, yy, zz) 
+ assert plot.img is not None and plot.img.image is not None + + +# -- Complex splitting tests -- + +class TestComplexSplitting: + """Verify complex data is split correctly for 1D and 2D.""" + + @staticmethod + def _make_complex_1d(): + x = np.linspace(0, 10, 50) + z = np.sin(x) + 1j * np.cos(x) + dd = DataDict(z=dict(values=z, axes=['x']), x=dict(values=x)) + dd.validate() + return dd + + def test_detected(self): + assert np.iscomplexobj(self._make_complex_1d().data_vals('z')) + + def test_split_real(self): + from plottr.plot.base import ComplexRepresentation, PlotItem, PlotDataType, AutoFigureMaker + dd = self._make_complex_1d() + item = PlotItem([dd.data_vals('x'), dd.data_vals('z')], 0, 0, + PlotDataType.line1d, ['x', 'z'], None) + fm = AutoFigureMaker() + fm.complexRepresentation = ComplexRepresentation.real + result = fm._splitComplexData(item) + assert len(result) == 1 + assert not np.iscomplexobj(result[0].data[-1]) + + def test_split_real_and_imag(self): + from plottr.plot.base import ComplexRepresentation, PlotItem, PlotDataType, AutoFigureMaker + dd = self._make_complex_1d() + z = dd.data_vals('z') + item = PlotItem([dd.data_vals('x'), z], 0, 0, + PlotDataType.line1d, ['x', 'z'], None) + fm = AutoFigureMaker() + fm.complexRepresentation = ComplexRepresentation.realAndImag + result = fm._splitComplexData(item) + assert len(result) == 2 + np.testing.assert_array_equal(result[0].data[-1], z.real) + np.testing.assert_array_equal(result[1].data[-1], z.imag) + + def test_split_mag_and_phase(self): + from plottr.plot.base import ComplexRepresentation, PlotItem, PlotDataType, AutoFigureMaker + dd = self._make_complex_1d() + z = dd.data_vals('z') + item = PlotItem([dd.data_vals('x'), z], 0, 0, + PlotDataType.line1d, ['x', 'z'], None) + fm = AutoFigureMaker() + fm.complexRepresentation = ComplexRepresentation.magAndPhase + result = fm._splitComplexData(item) + assert len(result) == 2 + np.testing.assert_array_almost_equal(result[0].data[-1], np.abs(z)) + 
np.testing.assert_array_almost_equal(result[1].data[-1], np.angle(z)) + + +# -- Matplotlib first-plot-not-blank tests -- + +class TestMplFirstPlot: + """Verify mpl backend renders on first setData (plotType is set).""" + + def test_2d_sets_plotType(self, qtbot): + from plottr.plot.mpl.autoplot import AutoPlot + w = AutoPlot() + qtbot.addWidget(w) + x = np.linspace(-1, 1, 10) + y = np.linspace(0, 5, 8) + xx, yy = np.meshgrid(x, y, indexing='ij') + data = MeshgridDataDict( + z=dict(values=xx**2 + yy, axes=['x', 'y']), + x=dict(values=xx), y=dict(values=yy), + ) + w.setData(data) + assert w.plotType is not PlotType.empty + + def test_1d_sets_plotType(self, qtbot): + from plottr.plot.mpl.autoplot import AutoPlot + w = AutoPlot() + qtbot.addWidget(w) + x = np.linspace(0, 10, 50) + data = MeshgridDataDict( + y=dict(values=np.sin(x), axes=['x']), x=dict(values=x), + ) + w.setData(data) + assert w.plotType is not PlotType.empty + + def test_repeated_setData(self, qtbot): + from plottr.plot.mpl.autoplot import AutoPlot + w = AutoPlot() + qtbot.addWidget(w) + x = np.linspace(-1, 1, 10) + y = np.linspace(0, 5, 8) + xx, yy = np.meshgrid(x, y, indexing='ij') + data = MeshgridDataDict( + z=dict(values=xx**2 + yy, axes=['x', 'y']), + x=dict(values=xx), y=dict(values=yy), + ) + w.setData(data) + t1 = w.plotType + w.setData(data) + assert w.plotType == t1 + + +# -- Pyqtgraph complex mode switching tests -- + +class TestPyqtgraphComplexModes: + """Verify pyqtgraph backend handles complex mode switching for 1D data.""" + + @staticmethod + def _make_complex_1d(): + x = np.linspace(0, 10, 50) + z = np.sin(x) + 1j * np.cos(x) + return MeshgridDataDict( + z=dict(values=z, axes=['x']), x=dict(values=x), + ) + + def test_complex_detected_as_imagData(self, qtbot): + """1D complex data should set imagData=True.""" + from plottr.plot.pyqtgraph.autoplot import AutoPlot + w = AutoPlot(parent=None) + qtbot.addWidget(w) + w.setData(self._make_complex_1d()) + assert w.figOptions.imagData is True + 
+ def test_all_complex_options_available(self, qtbot): + """All complex representations should be in the toolbar menu.""" + from plottr.plot.pyqtgraph.autoplot import AutoPlot + from plottr.plot.base import ComplexRepresentation + w = AutoPlot(parent=None) + qtbot.addWidget(w) + w.setData(self._make_complex_1d()) + + # Find the Complex button's menu + menu_labels = self._get_complex_menu_labels(w) + assert ComplexRepresentation.real.label in menu_labels + assert ComplexRepresentation.realAndImag.label in menu_labels + assert ComplexRepresentation.realAndImagSeparate.label in menu_labels + assert ComplexRepresentation.magAndPhase.label in menu_labels + + def test_switch_to_real_and_back(self, qtbot): + """After switching to Real, should be able to switch back to Real/Imag.""" + from plottr.plot.pyqtgraph.autoplot import AutoPlot + from plottr.plot.base import ComplexRepresentation + w = AutoPlot(parent=None) + qtbot.addWidget(w) + w.setData(self._make_complex_1d()) + + # Switch to Real + w.figOptions.complexRepresentation = ComplexRepresentation.real + w._refreshPlot() + + # imagData should still be True (data is still complex) + assert w.figOptions.imagData is True + + # All options should still be available + menu_labels = self._get_complex_menu_labels(w) + assert ComplexRepresentation.realAndImag.label in menu_labels + + # Switch back + w.figOptions.complexRepresentation = ComplexRepresentation.realAndImag + w._refreshPlot() + assert w.figOptions.complexRepresentation == ComplexRepresentation.realAndImag + + def test_separate_re_im_mode(self, qtbot): + """realAndImagSeparate should create 2 subplots for 1D data.""" + from plottr.plot.pyqtgraph.autoplot import AutoPlot + from plottr.plot.base import ComplexRepresentation + w = AutoPlot(parent=None) + qtbot.addWidget(w) + w.setData(self._make_complex_1d()) + + w.figOptions.complexRepresentation = ComplexRepresentation.realAndImagSeparate + w._refreshPlot() + + # Should have 2 subplots (one for Real, one for Imag) + 
assert w.fmWidget is not None + assert len(w.fmWidget.subPlots) == 2 + + def test_non_complex_only_shows_real(self, qtbot): + """Non-complex 1D data should only offer Real in the menu.""" + from plottr.plot.pyqtgraph.autoplot import AutoPlot + from plottr.plot.base import ComplexRepresentation + w = AutoPlot(parent=None) + qtbot.addWidget(w) + x = np.linspace(0, 10, 50) + data = MeshgridDataDict( + y=dict(values=np.sin(x), axes=['x']), x=dict(values=x), + ) + w.setData(data) + assert w.figOptions.imagData is False + menu_labels = self._get_complex_menu_labels(w) + assert menu_labels == [ComplexRepresentation.real.label] + + @staticmethod + def _get_complex_menu_labels(w): + """Extract labels from the Complex button's popup menu.""" + # The Complex button is at action index 1 in the toolbar + toolbar = w.figConfig + actions = toolbar.actions() + for a in actions: + widget = toolbar.widgetForAction(a) + if isinstance(widget, __import__('plottr').QtWidgets.QToolButton): + menu = widget.menu() + if menu is not None: + return [ma.text() for ma in menu.actions()] + return [] diff --git a/test/pytest/test_qcodes_data.py b/test/pytest/test_qcodes_data.py index 2ffdaeb9..ba9ce9c2 100644 --- a/test/pytest/test_qcodes_data.py +++ b/test/pytest/test_qcodes_data.py @@ -312,3 +312,144 @@ def check(): # break # check() # check() + + +# -- Records counter tests (qcodes_db_overview) -- + +def _make_qcodes_db_with_runs(db_path: str, n_runs: int = 1) -> str: + """Helper: create a QCodes DB with n_runs simple numeric datasets.""" + try: + from qcodes.parameters import ParamSpecBase + except ImportError: + from qcodes.dataset.descriptions.param_spec import ParamSpecBase + from qcodes.dataset.descriptions.dependencies import InterDependencies_ + + initialise_or_create_database_at(db_path) + exp = load_or_create_experiment("test_exp", sample_name="test_sample") + p_x = ParamSpecBase("x", "numeric") + p_y = ParamSpecBase("y", "numeric") + interdeps = InterDependencies_(dependencies={p_y: 
(p_x,)}) + + for r in range(n_runs): + ds = qc.new_data_set(f"run_{r + 1}") + ds.set_interdependencies(interdeps) + ds.mark_started() + for i in range(10): + ds.add_results([{p_x.name: float(i), p_y.name: float(i ** 2)}]) + ds.mark_completed() + return db_path + + +class TestRecordsCounter: + """Verify records counter shows actual data point count.""" + + def test_counts_result_rows(self, tmp_path): + """Overview should count rows from the results table.""" + import sqlite3 + from plottr.data.qcodes_db_overview import get_db_overview + + db_path = str(tmp_path / "test.db") + _make_qcodes_db_with_runs(db_path, n_runs=3) + overview = get_db_overview(db_path) + conn = sqlite3.connect(db_path) + + for run_id, info in overview.items(): + row = conn.execute( + "SELECT result_table_name FROM runs WHERE run_id=?", + (run_id,) + ).fetchone() + if row and row[0]: + try: + actual = conn.execute( + f'SELECT COUNT(*) FROM "{row[0]}"' + ).fetchone()[0] + except Exception: + continue + assert info['records'] == actual, \ + f"Run {run_id}: overview={info['records']}, actual={actual}" + conn.close() + + def test_records_from_shapes(self): + """Shape info in run_description should produce correct count.""" + import json + from plottr.data.qcodes_db_overview import _records_from_run_description + + desc = json.dumps({"version": 3, "shapes": {"dep1": [100, 50]}}) + assert _records_from_run_description(desc) == 5000 + assert _records_from_run_description(json.dumps({"version": 3})) == 0 + assert _records_from_run_description(None) == 0 + assert _records_from_run_description("") == 0 + + +# -- Dataset refresh tests (inspectr incremental load) -- + +class TestDatasetRefresh: + """Verify incremental DB refresh detects new runs.""" + + def test_incremental_overview(self, tmp_path): + """get_db_overview with start_run_id should find newly added runs.""" + from plottr.data.qcodes_db_overview import get_db_overview + try: + from qcodes.parameters import ParamSpecBase + except ImportError: + 
from qcodes.dataset.descriptions.param_spec import ParamSpecBase + from qcodes.dataset.descriptions.dependencies import InterDependencies_ + + db_path = str(tmp_path / "test.db") + _make_qcodes_db_with_runs(db_path, n_runs=2) + + assert set(get_db_overview(db_path).keys()) == {1, 2} + assert len(get_db_overview(db_path, start_run_id=2)) == 0 + + # Add a third run + initialise_or_create_database_at(db_path) + exp = load_or_create_experiment("test_exp2", sample_name="s2") + p_x = ParamSpecBase("x", "numeric") + p_y = ParamSpecBase("y", "numeric") + interdeps = InterDependencies_(dependencies={p_y: (p_x,)}) + ds = qc.new_data_set("run_3") + ds.set_interdependencies(interdeps) + ds.mark_started() + ds.add_results([{p_x.name: 1.0, p_y.name: 2.0}]) + ds.mark_completed() + + assert 3 in get_db_overview(db_path, start_run_id=2) + + def test_inspectr_refresh(self, qtbot, tmp_path): + """QCodesDBInspector.refreshDB should detect new runs.""" + import os + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + from plottr.apps.inspectr import QCodesDBInspector + try: + from qcodes.parameters import ParamSpecBase + except ImportError: + from qcodes.dataset.descriptions.param_spec import ParamSpecBase + from qcodes.dataset.descriptions.dependencies import InterDependencies_ + + db_path = str(tmp_path / "test.db") + _make_qcodes_db_with_runs(db_path, n_runs=1) + + inspector = QCodesDBInspector(dbPath=db_path) + qtbot.addWidget(inspector) + + def initial_load_done(): + return inspector.dbdf is not None and inspector.dbdf.size > 0 + qtbot.waitUntil(initial_load_done, timeout=5000) + assert list(inspector.dbdf.index) == [1] + + # Add run 2 + initialise_or_create_database_at(db_path) + p_x = ParamSpecBase("x", "numeric") + p_y = ParamSpecBase("y", "numeric") + interdeps = InterDependencies_(dependencies={p_y: (p_x,)}) + ds = qc.new_data_set("run_2") + ds.set_interdependencies(interdeps) + ds.mark_started() + ds.add_results([{p_x.name: 1.0, p_y.name: 2.0}]) + ds.mark_completed() + + 
inspector.refreshDB() + def refresh_done(): + return (inspector.dbdf is not None and 2 in inspector.dbdf.index) + qtbot.waitUntil(refresh_done, timeout=5000) + assert 2 in inspector.dbdf.index diff --git a/test/pytest/test_round2_optimizations.py b/test/pytest/test_round2_optimizations.py new file mode 100644 index 00000000..00044147 --- /dev/null +++ b/test/pytest/test_round2_optimizations.py @@ -0,0 +1,350 @@ +""" +test_round2_optimizations.py + +Tests for round 2 performance optimizations: is_invalid, largest_numtype, +guess_grid, remove_invalid_entries, Node.process structure deferral, +complex plot splitting, flatten->ravel. +""" +import numpy as np +import pytest + +from plottr.data.datadict import ( + DataDict, MeshgridDataDict, meshgrid_to_datadict, datadict_to_dataframe, +) +from plottr.utils.num import is_invalid, largest_numtype, guess_grid_from_sweep_direction + + +# =========================================================================== +# is_invalid() +# =========================================================================== + +class TestIsInvalid: + def test_float_with_nan(self): + arr = np.array([1.0, np.nan, 3.0]) + result = is_invalid(arr) + assert result.tolist() == [False, True, False] + + def test_float_clean(self): + arr = np.array([1.0, 2.0, 3.0]) + result = is_invalid(arr) + assert not np.any(result) + + def test_int_array(self): + arr = np.arange(10) + result = is_invalid(arr) + assert not np.any(result) + + def test_object_with_none(self): + arr = np.array([1.0, None, 3.0], dtype=object) + result = is_invalid(arr) + assert result[1] == True + assert result[0] == False + + def test_complex_with_nan(self): + arr = np.array([1+1j, np.nan+0j, 3+0j]) + result = is_invalid(arr) + assert result[1] == True + + def test_empty_array(self): + arr = np.array([], dtype=float) + result = is_invalid(arr) + assert result.shape == (0,) + + def test_bool_array(self): + arr = np.array([True, False, True]) + result = is_invalid(arr) + assert not 
np.any(result) + + def test_2d_float(self): + arr = np.array([[1.0, np.nan], [3.0, 4.0]]) + result = is_invalid(arr) + assert result.shape == (2, 2) + assert result[0, 1] == True + assert result[1, 0] == False + + +# =========================================================================== +# largest_numtype() +# =========================================================================== + +class TestLargestNumtype: + def test_float_array(self): + arr = np.array([1.0, 2.0, 3.0]) + result = largest_numtype(arr) + assert issubclass(result, (float, np.floating)) + + def test_int_array_include_integers(self): + arr = np.arange(10) + result = largest_numtype(arr, include_integers=True) + # Should return float (promotion) or int + assert result in (float, int, np.int64, np.int32, np.float64) + + def test_int_array_exclude_integers(self): + arr = np.arange(10) + result = largest_numtype(arr, include_integers=False) + # With include_integers=False, int arrays are promoted to float + assert result == float + + def test_complex_array(self): + arr = np.array([1+1j, 2+2j]) + result = largest_numtype(arr) + assert issubclass(result, (complex, np.complexfloating)) + + def test_object_array_with_floats(self): + arr = np.array([1.0, 2.0, 3.0], dtype=object) + result = largest_numtype(arr) + assert result == float + + def test_string_array(self): + arr = np.array(['a', 'b', 'c']) + result = largest_numtype(arr) + assert result is None + + def test_object_with_none_and_floats(self): + arr = np.array([1.0, None, 3.0], dtype=object) + result = largest_numtype(arr) + # Should find float as the largest type + assert result == float + + def test_empty_array(self): + arr = np.array([]) + result = largest_numtype(arr) + # Empty array has no elements to inspect + assert result is None + + +# =========================================================================== +# guess_grid_from_sweep_direction() +# =========================================================================== + +class 
TestGuessGrid: + def test_simple_2d_grid(self): + x = np.repeat(np.arange(5, dtype=float), 4) + y = np.tile(np.arange(4, dtype=float), 5) + result = guess_grid_from_sweep_direction(x=x, y=y) + assert result is not None + order, shape = result + assert set(order) == {'x', 'y'} + assert 5 in shape and 4 in shape + + def test_1d_sweep(self): + x = np.arange(10, dtype=float) + result = guess_grid_from_sweep_direction(x=x) + assert result is not None + _, shape = result + assert shape == (10,) + + def test_single_point(self): + x = np.array([1.0]) + result = guess_grid_from_sweep_direction(x=x) + assert result is not None + + def test_with_noise(self): + x = np.repeat(np.linspace(0, 1, 10), 8) + np.random.randn(80) * 1e-6 + y = np.tile(np.linspace(0, 1, 8), 10) + result = guess_grid_from_sweep_direction(x=x, y=y) + assert result is not None + _, shape = result + assert 10 in shape and 8 in shape + + +# =========================================================================== +# remove_invalid_entries() +# =========================================================================== + +class TestRemoveInvalidEntries: + def test_removes_nan_rows(self): + dd = DataDict( + x=dict(values=np.arange(5, dtype=float)), + y=dict(values=np.array([1.0, np.nan, 3.0, np.nan, 5.0]), axes=['x']), + ) + dd.validate() + dd2 = dd.remove_invalid_entries() + assert dd2.nrecords() == 3 + assert np.allclose(dd2.data_vals('y'), [1.0, 3.0, 5.0]) + + def test_preserves_clean_data(self): + dd = DataDict( + x=dict(values=np.arange(10, dtype=float)), + y=dict(values=np.arange(10, dtype=float), axes=['x']), + ) + dd.validate() + dd2 = dd.remove_invalid_entries() + assert dd2.nrecords() == 10 + + def test_removes_none_in_object_array(self): + dd = DataDict( + x=dict(values=np.array([1, 2, 3], dtype=object)), + y=dict(values=np.array([1.0, None, 3.0], dtype=object), axes=['x']), + ) + dd.validate() + dd2 = dd.remove_invalid_entries() + # Only row where ALL dependents are invalid gets removed + # Row 1 
has None in y -> removed only if x is also invalid + # Actually remove_invalid_entries removes rows where ALL deps are invalid + assert dd2.nrecords() <= 3 + + def test_multiple_dependents(self): + """remove_invalid_entries removes rows where ALL dependents are invalid. + + Note: this previously crashed with np.array(idxs) on inhomogeneous + arrays. Fixed by using np.concatenate instead of np.append. + """ + dd = DataDict( + x=dict(values=np.arange(5, dtype=float)), + y=dict(values=np.array([1.0, np.nan, 3.0, np.nan, 5.0]), axes=['x']), + z=dict(values=np.array([np.nan, 2.0, np.nan, np.nan, 5.0]), axes=['x']), + ) + dd.validate() + dd2 = dd.remove_invalid_entries() + assert dd2.nrecords() == 4 + + +# =========================================================================== +# meshgrid_to_datadict (flatten->ravel) +# =========================================================================== + +class TestMeshgridToDatadict: + def test_basic_conversion(self): + x = np.linspace(0, 1, 5) + y = np.arange(3, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + mesh = MeshgridDataDict( + x=dict(values=xx), y=dict(values=yy), + z=dict(values=xx*yy, axes=['x', 'y']), + ) + mesh.validate() + dd = meshgrid_to_datadict(mesh) + assert isinstance(dd, DataDict) + assert dd.nrecords() == 15 + + def test_values_match(self): + x = np.linspace(0, 1, 4) + y = np.arange(3, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + zz = xx + yy + mesh = MeshgridDataDict( + x=dict(values=xx), y=dict(values=yy), + z=dict(values=zz, axes=['x', 'y']), + ) + mesh.validate() + dd = meshgrid_to_datadict(mesh) + assert np.allclose(dd.data_vals('z'), zz.ravel()) + + def test_3d_conversion(self): + shape = (3, 4, 2) + grids = np.meshgrid(*[np.linspace(0, 1, s) for s in shape], indexing='ij') + mesh = MeshgridDataDict( + a=dict(values=grids[0]), b=dict(values=grids[1]), c=dict(values=grids[2]), + z=dict(values=np.random.randn(*shape), axes=['a', 'b', 'c']), + ) + mesh.validate() + dd = 
meshgrid_to_datadict(mesh) + assert dd.nrecords() == 24 + + +# =========================================================================== +# datadict_to_dataframe +# =========================================================================== + +class TestDatadictToDataframe: + def test_basic(self): + dd = DataDict( + x=dict(values=np.arange(5, dtype=float)), + y=dict(values=np.arange(5, dtype=float) * 2, axes=['x']), + ) + dd.validate() + df = datadict_to_dataframe(dd) + assert len(df) == 5 + assert list(df.columns) == ['x', 'y'] + + +# =========================================================================== +# Node.process() structure deferral +# =========================================================================== + +class TestNodeProcessStructure: + def test_node_process_returns_data(self, qtbot): + from plottr.node.node import Node + from plottr.node.tools import linearFlowchart + Node.useUi = False; Node.uiClass = None + + mesh = MeshgridDataDict() + x = np.linspace(0, 1, 5) + y = np.arange(3, dtype=float) + xx, yy = np.meshgrid(x, y, indexing='ij') + mesh['x'] = dict(values=xx, axes=[]) + mesh['y'] = dict(values=yy, axes=[]) + mesh['z'] = dict(values=xx + yy, axes=['x', 'y']) + mesh.validate() + + fc = linearFlowchart(('n', Node)) + fc.setInput(dataIn=mesh) + out = fc.outputValues()['dataOut'] + assert out is mesh + + def test_node_detects_structure_change(self, qtbot): + from plottr.node.node import Node + from plottr.node.tools import linearFlowchart + Node.useUi = False; Node.uiClass = None + + dd1 = DataDict( + x=dict(values=np.arange(5, dtype=float)), + y=dict(values=np.arange(5, dtype=float), axes=['x']), + ) + dd1.validate() + + dd2 = DataDict( + x=dict(values=np.arange(5, dtype=float)), + y=dict(values=np.arange(5, dtype=float), axes=['x']), + z=dict(values=np.arange(5, dtype=float), axes=['x']), + ) + dd2.validate() + + fc = linearFlowchart(('n', Node)) + fc.setInput(dataIn=dd1) + node = fc.nodes()['n'] + assert node.dataDependents == ['y'] 
+ + fc.setInput(dataIn=dd2) + assert node.dataDependents == ['y', 'z'] + + +# =========================================================================== +# Complex plot deepcopy +# =========================================================================== + +class TestComplexPlotSplit: + def test_split_produces_correct_items(self): + from plottr.plot.base import AutoFigureMaker, PlotDataType, PlotItem, ComplexRepresentation + + class DummyFM(AutoFigureMaker): + def makeSubPlots(self, n): return [None]*n + def plot(self, item): return None + def formatSubPlot(self, id): pass + + fm = DummyFM() + fm.complexRepresentation = ComplexRepresentation.realAndImag + data = np.array([1+2j, 3+4j, 5+6j]) + pi = PlotItem(data=[np.arange(3, dtype=float), data], + id=0, subPlot=0, labels=['x', 'z']) + result = fm._splitComplexData(pi) + assert len(result) == 2 + assert np.allclose(result[0].data[-1], data.real) + assert np.allclose(result[1].data[-1], data.imag) + + def test_split_real_data_unchanged(self): + from plottr.plot.base import AutoFigureMaker, PlotDataType, PlotItem + + class DummyFM(AutoFigureMaker): + def makeSubPlots(self, n): return [None]*n + def plot(self, item): return None + def formatSubPlot(self, id): pass + + fm = DummyFM() + data = np.array([1.0, 2.0, 3.0]) + pi = PlotItem(data=[np.arange(3, dtype=float), data], + id=0, subPlot=0, labels=['x', 'z']) + result = fm._splitComplexData(pi) + assert len(result) == 1 + assert np.allclose(result[0].data[-1], data) diff --git a/test_requirements.txt b/test_requirements.txt index 081141b3..23f7860c 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,6 +1,7 @@ qcodes pytest pytest-qt +hypothesis mypy==1.20.2 PyQt5-stubs==5.15.6.0 pandas-stubs