Skip to content

Commit 9522464

Browse files
authored
Merge pull request #3941 from chrishalcrow/add-sa-to-widgets
Allow for `SortingAnalyzer` or `BaseSorter` in `plot_*`
2 parents dc1ea7a + 4454b88 commit 9522464

File tree

3 files changed

+86
-46
lines changed

3 files changed

+86
-46
lines changed

src/spikeinterface/widgets/isi_distribution.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import numpy as np
44
from warnings import warn
55

6+
from spikeinterface.core import SortingAnalyzer, BaseSorting
7+
68
from .base import BaseWidget, to_attr
7-
from .utils import get_unit_colors
89

910

1011
class ISIDistributionWidget(BaseWidget):
@@ -13,18 +14,37 @@ class ISIDistributionWidget(BaseWidget):
1314
1415
Parameters
1516
----------
16-
sorting : SortingExtractor
17-
The sorting extractor object
18-
unit_ids : list
19-
List of unit ids
20-
bins_ms : int
21-
Bin size in ms
22-
window_ms : float
17+
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None, default: None
18+
The object containing the sorting information for the isi distribution plot
19+
unit_ids : list | None, default: None
20+
List of unit ids. If None, uses all unit ids.
21+
window_ms : float, default: 100.0
2322
Window size in ms
24-
23+
bins_ms : int, default: 1.0
24+
Bin size in ms
25+
sorting : SortingExtractor | None, default: None
26+
A sorting object. Deprecated.
2527
"""
2628

27-
def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs):
29+
def __init__(
30+
self,
31+
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
32+
unit_ids: list | None = None,
33+
window_ms: float = 100.0,
34+
bin_ms: float = 1.0,
35+
backend: str | None = None,
36+
sorting: BaseSorting | None = None,
37+
**backend_kwargs,
38+
):
39+
40+
if sorting is not None:
41+
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
42+
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
43+
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
44+
sorting_analyzer_or_sorting = sorting
45+
46+
sorting = self.ensure_sorting(sorting_analyzer_or_sorting)
47+
2848
if unit_ids is None:
2949
unit_ids = sorting.get_unit_ids()
3050

src/spikeinterface/widgets/rasters.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44
from warnings import warn
55

6-
from .base import BaseWidget, to_attr, default_backend_kwargs
6+
from spikeinterface.core import SortingAnalyzer, BaseSorting
7+
from .base import BaseWidget, to_attr
78
from .utils import get_some_colors
89

910

@@ -278,39 +279,46 @@ class RasterWidget(BaseRasterWidget):
278279
279280
Parameters
280281
----------
281-
sorting : SortingExtractor | None, default: None
282-
A sorting object
283-
sorting_analyzer : SortingAnalyzer | None, default: None
284-
A sorting analyzer object
285-
segment_index : None or int
286-
The segment index.
287-
unit_ids : list
288-
List of unit ids
289-
time_range : list
282+
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None, default: None
283+
The object containing the sorting information for the raster plot
284+
segment_index : None | int, default: None
285+
The segment index. If None, uses first segment.
286+
unit_ids : list | None, default: None
287+
List of unit ids. If None, uses all unit ids.
288+
time_range : list | None, default: None
290289
List with start time and end time
291-
color : matplotlib color
290+
color : matplotlib color, default: "k"
292291
The color to be used
292+
sorting : SortingExtractor | None, default: None
293+
A sorting object. Deprecated.
294+
sorting_analyzer : SortingAnalyzer | None, default: None
295+
A sorting analyzer object. Deprecated.
293296
"""
294297

295298
def __init__(
296299
self,
297-
sorting=None,
298-
sorting_analyzer=None,
299-
segment_index=None,
300-
unit_ids=None,
301-
time_range=None,
300+
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
301+
segment_index: int | None = None,
302+
unit_ids: list | None = None,
303+
time_range: list | None = None,
302304
color="k",
303-
backend=None,
305+
backend: str | None = None,
306+
sorting: BaseSorting | None = None,
307+
sorting_analyzer: SortingAnalyzer | None = None,
304308
**backend_kwargs,
305309
):
306-
if sorting is None and sorting_analyzer is None:
307-
raise Exception("Must supply either a sorting or a sorting_analyzer")
308-
elif sorting is not None and sorting_analyzer is not None:
309-
raise Exception("Should supply either a sorting or a sorting_analyzer, not both")
310-
elif sorting_analyzer is not None:
311-
sorting = sorting_analyzer.sorting
312-
313-
sorting = self.ensure_sorting(sorting)
310+
311+
if sorting is not None:
312+
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
313+
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
314+
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
315+
sorting_analyzer_or_sorting = sorting
316+
if sorting_analyzer is not None:
317+
deprecation_msg = "`sorting_analyzer` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
318+
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
319+
sorting_analyzer_or_sorting = sorting_analyzer
320+
321+
sorting = self.ensure_sorting(sorting_analyzer_or_sorting)
314322

315323
if sorting.get_num_segments() > 1:
316324
if segment_index is None:

src/spikeinterface/widgets/unit_presence.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
import numpy as np
4+
from warnings import warn
45

6+
from spikeinterface.core import SortingAnalyzer, BaseSorting
57
from .base import BaseWidget, to_attr
68

79

@@ -11,29 +13,39 @@ class UnitPresenceWidget(BaseWidget):
1113
1214
Parameters
1315
----------
14-
sorting : SortingExtractor
15-
The sorting extractor object
16-
segment_index : None or int
17-
The segment index.
16+
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None, default: None
17+
The object containing the sorting information for the raster plot
18+
segment_index : None or int, default: None
19+
The segment index. If None, uses first segment.
1820
time_range : list or None, default: None
1921
List with start time and end time
2022
bin_duration_s : float, default: 0.5
2123
Bin size (in seconds) for the heat map time axis
2224
smooth_sigma : float, default: 4.5
2325
Sigma for the Gaussian kernel (in number of bins)
26+
sorting : SortingExtractor | None, default: None
27+
A sorting object. Deprecated.
2428
"""
2529

2630
def __init__(
2731
self,
28-
sorting,
29-
segment_index=None,
30-
time_range=None,
31-
bin_duration_s=0.05,
32-
smooth_sigma=4.5,
33-
backend=None,
32+
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
33+
segment_index: int | None = None,
34+
time_range: list | None = None,
35+
bin_duration_s: float = 0.05,
36+
smooth_sigma: float = 4.5,
37+
backend: str | None = None,
38+
sorting: BaseSorting | None = None,
3439
**backend_kwargs,
3540
):
36-
sorting = self.ensure_sorting(sorting)
41+
42+
if sorting is not None:
43+
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
44+
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
45+
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
46+
sorting_analyzer_or_sorting = sorting
47+
48+
sorting = self.ensure_sorting(sorting_analyzer_or_sorting)
3749

3850
if segment_index is None:
3951
nseg = sorting.get_num_segments()

0 commit comments

Comments
 (0)