Commit f2d6373

Merge pull request #4037 from chrishalcrow/sa-by-dict

Allow `create_sorting_analyzer` to accept dicts

2 parents 8557c87 + 97073b6, commit f2d6373

10 files changed: +246 −48 lines
doc/how_to/process_by_channel_group.rst
Lines changed: 47 additions & 11 deletions

@@ -1,4 +1,4 @@
-.. _recording-by-channel-group:
+.. _process_by_group:
 
 Process a recording by channel group
 ====================================
@@ -101,8 +101,8 @@ to any preprocessing function.
     referenced_recording = spre.common_reference(filtered_recording)
     good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)
 
-We can then aggregate the recordings back together using the ``aggregate_channels`` function.
-Note that we do not need to do this to sort the data (see :ref:`sorting-by-channel-group`).
+If needed, we could aggregate the recordings back together using the ``aggregate_channels`` function.
+Note: you do not need to do this to sort the data (see :ref:`sorting-by-channel-group`).
 
 .. code-block:: python
 
@@ -145,14 +145,16 @@ Sorting a Recording by Channel Group
 We can also sort a recording for each channel group separately. It is not necessary to preprocess
 a recording by channel group in order to sort by channel group.
 
-There are two ways to sort a recording by channel group. First, we can simply pass the output from
-our preprocessing-by-group method above. Second, for more control, we can loop over the recordings
+There are two ways to sort a recording by channel group. First, we can pass a dictionary to the
+``run_sorter`` function. Since the preprocessing-by-group method above returns a dict, we can
+simply pass this output. Alternatively, for more control, we can loop over the recordings
 ourselves.
 
-**Option 1 : Automatic splitting**
+**Option 1 : Automatic splitting (Recommended)**
 
-Simply pass the split recording to the `run_sorter` function, as if it was a non-split recording.
-This will return a dict of sortings, with the keys corresponding to the groups.
+Simply pass the split recording to the ``run_sorter`` function, as if it was a non-split recording.
+This will return a dict of sortings, with the same keys as the dict of recordings that were
+passed to ``run_sorter``.
 
 .. code-block:: python
 
@@ -162,10 +164,10 @@ This will return a dict of sortings, with the keys corresponding to the groups.
     # do preprocessing if needed
     pp_recording = spre.bandpass_filter(split_recording)
 
-    dict_of_sortings = run_sorter(
-        sorter_name='kilosort2',
+    dict_of_sortings = run_sorter(
+        sorter_name='kilosort4',
         recording=pp_recording,
-        working_folder='working_path'
+        folder='my_kilosort4_sorting'
     )
 
@@ -190,3 +192,37 @@ to different groups.
             folder=f"folder_KS2_group{group}"
         )
         sortings[group] = sorting
+
+
+Creating a SortingAnalyzer by Channel Group
+-------------------------------------------
+
+The code above generates a dictionary of recording objects and a dictionary of sorting objects.
+When making a :ref:`SortingAnalyzer <modules/core:SortingAnalyzer>`, we can pass these dictionaries and
+a single analyzer will be created, with the recordings and sortings appropriately aggregated.
+
+The dictionary of recordings and dictionary of sortings must have the same keys. E.g. if you
+use ``split_by("group")``, the keys of your dict of recordings will be the values of the ``group``
+property of the recording. The dict of sortings should then also have these keys.
+Note that if you use the internal functions, like we do in the code-block below, you don't need to
+keep track of the keys yourself. SpikeInterface will do this for you automatically.
+
+The code for creating a ``SortingAnalyzer`` from dicts of recordings and sortings is very similar
+to that for creating a sorting analyzer from a single recording and sorting:
+
+.. code-block:: python
+
+    dict_of_recordings = preprocessed_recording.split_by("group")
+    dict_of_sortings = run_sorter(sorter_name="mountainsort5", recording=dict_of_recordings)
+
+    analyzer = create_sorting_analyzer(sorting=dict_of_sortings, recording=dict_of_recordings)
+
+
+The code above creates a *single* sorting analyzer called :code:`analyzer`. You can select the units
+from one of the "group"s as follows:
+
+.. code-block:: python
+
+    aggregation_keys = analyzer.get_sorting_property("aggregation_key")
+    unit_ids_group_0 = analyzer.unit_ids[aggregation_keys == 0]
+    group_0_analyzer = analyzer.select_units(unit_ids=unit_ids_group_0)

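Taken together, the updated how-to describes a fully dict-based workflow. A minimal end-to-end sketch (``recording`` is assumed to be an existing recording whose channels carry a ``group`` property; the sorter name is a placeholder):

    import spikeinterface.preprocessing as spre
    from spikeinterface.core import create_sorting_analyzer
    from spikeinterface.sorters import run_sorter

    # split into a dict of sub-recordings, keyed by the values of the "group" property
    split_recording = recording.split_by("group")

    # preprocessing functions map over the dict and return a dict with the same keys
    pp_recording = spre.bandpass_filter(split_recording)

    # run_sorter accepts the dict and returns a dict of sortings with matching keys
    dict_of_sortings = run_sorter(sorter_name="mountainsort5", recording=pp_recording)

    # the matching dicts are aggregated into a single SortingAnalyzer
    analyzer = create_sorting_analyzer(sorting=dict_of_sortings, recording=pp_recording)
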
doc/modules/sorters.rst
Lines changed: 8 additions & 2 deletions

@@ -338,9 +338,11 @@ Running spike sorting by group is indeed a very common need.
 
 A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of
 sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
-So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
-The :py:func:`~spikeinterface.sorters.run_sorter` method can also accept the dictionary which is returned
+The :py:func:`~spikeinterface.sorters.run_sorter` method can accept the dictionary which is returned
 by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings.
+In turn, these can be fed directly to :py:meth:`~spikeinterface.core.create_sorting_analyzer` to make
+a SortingAnalyzer. For more control, you can loop over the dictionary returned by :py:meth:`~spikeinterface.core.BaseRecording.split_by`
+and sequentially run spike sorting on these sub-recordings.
 
 In this example, we create a 16-channel recording with 4 tetrodes:
 
@@ -396,6 +398,10 @@ In this example, we create a 16-channel recording with 4 tetrodes:
         sortings[group] = sorting
 
 
+Note: you can feed the dict of sortings and dict of recordings directly to :code:`create_sorting_analyzer` to make
+a SortingAnalyzer from the split data: :ref:`read more <process_by_group>`.
+
+
 Handling multi-segment recordings
 ---------------------------------
 

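For the "more control" path mentioned in this section, a minimal sketch of the manual loop (the sorter name and folder pattern are placeholders, following the tetrode example in this file):

    from spikeinterface.sorters import run_sorter

    sortings = {}
    for group, sub_recording in recording.split_by("group").items():
        # one sorter run and one output folder per channel group
        sorting = run_sorter(
            sorter_name="mountainsort5",
            recording=sub_recording,
            folder=f"sorting_group_{group}",
        )
        sortings[group] = sorting
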
src/spikeinterface/core/baserecordingsnippets.py
Lines changed: 1 addition & 1 deletion

@@ -514,7 +514,7 @@ def split_by(self, property="group", outputs="dict"):
         recordings = {}
         for value in np.unique(values).tolist():
             (inds,) = np.nonzero(values == value)
-            new_channel_ids = self.get_channel_ids()[inds]
+            new_channel_ids = self.channel_ids[inds]
             subrec = self.select_channels(new_channel_ids)
             subrec.set_annotation("split_by_property", value=property)
             if outputs == "list":

src/spikeinterface/core/basesorting.py
Lines changed: 41 additions & 0 deletions

@@ -626,6 +626,47 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo
 
         return self.frame_slice(start_frame=start_frame, end_frame=end_frame)
 
+    def split_by(self, property="group", outputs="dict"):
+        """
+        Splits object based on a certain property (e.g. "group")
+
+        Parameters
+        ----------
+        property : str, default: "group"
+            The property to use to split the object, default: "group"
+        outputs : "dict" | "list", default: "dict"
+            Whether to return a dict or a list
+
+        Returns
+        -------
+        dict or list
+            A dict or list with grouped objects based on property
+
+        Raises
+        ------
+        ValueError
+            Raised when property is not present
+        """
+        assert outputs in ("list", "dict")
+        values = self.get_property(property)
+        if values is None:
+            raise ValueError(f"property {property} is not set")
+
+        if outputs == "list":
+            sortings = []
+        elif outputs == "dict":
+            sortings = {}
+        for value in np.unique(values).tolist():
+            (inds,) = np.nonzero(values == value)
+            new_unit_ids = self.unit_ids[inds]
+            subsort = self.select_units(new_unit_ids)
+            subsort.set_annotation("split_by_property", value=property)
+            if outputs == "list":
+                sortings.append(subsort)
+            elif outputs == "dict":
+                sortings[value] = subsort
+        return sortings
+
     def time_to_sample_index(self, time, segment_index=0):
         """
         Transform time in seconds into sample index

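The new ``BaseSorting.split_by`` mirrors the existing ``BaseRecording.split_by``. A usage sketch, assuming the units carry the property being split on (``generate_sorting`` is used purely as a stand-in source of units):

    import numpy as np
    from spikeinterface.core import generate_sorting

    sorting = generate_sorting(num_units=6)
    # assign a unit-level property to split on
    sorting.set_property("group", np.array([0, 0, 1, 0, 1, 1]))

    # dict output (default): one sub-sorting per unique property value
    dict_of_sortings = sorting.split_by("group")
    assert sorted(dict_of_sortings.keys()) == [0, 1]

    # list output
    list_of_sortings = sorting.split_by("group", outputs="list")
    assert len(list_of_sortings) == 2
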
src/spikeinterface/core/channelsaggregationrecording.py
Lines changed: 2 additions & 23 deletions

@@ -37,10 +37,8 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record
 
         self._recordings = recording_list
 
-        splitting_known = self._is_splitting_known()
-        if not splitting_known:
-            for group_id, recording in zip(recording_ids, recording_list):
-                recording.set_property("group", [group_id] * recording.get_num_channels())
+        for group_id, recording in zip(recording_ids, recording_list):
+            recording.set_property("aggregation_key", [group_id] * recording.get_num_channels())
 
         self._perform_consistency_checks()
         sampling_frequency = recording_list[0].get_sampling_frequency()
@@ -140,25 +138,6 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record
     def recordings(self):
         return self._recordings
 
-    def _is_splitting_known(self):
-
-        # If we have the `split_by_property` annotation, we know how the recording was split
-        if self._recordings[0].get_annotation("split_by_property") is not None:
-            return True
-
-        # Check if all 'group' properties are equal to 0
-        recording_groups = []
-        for recording in self._recordings:
-            if (group_labels := recording.get_property("group")) is not None:
-                recording_groups.extend(group_labels)
-            else:
-                recording_groups.extend([0])
-        # If so, we don't know the splitting
-        if np.all(np.unique(recording_groups) == np.array([0])):
-            return False
-        else:
-            return True
-
     def _perform_consistency_checks(self):
 
         # Check for consistent sampling frequency across recordings

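The upshot of this simplification: ``aggregate_channels`` now always records how the aggregation was keyed, under an ``aggregation_key`` property, instead of conditionally writing a ``group`` property. A small sketch consistent with the updated tests further down this page:

    from spikeinterface.core import aggregate_channels, generate_recording

    rec_a = generate_recording(num_channels=2, set_probe=False)
    rec_b = generate_recording(num_channels=2, set_probe=False)

    # dict keys become the per-channel "aggregation_key" values
    aggregated = aggregate_channels({"tetrode0": rec_a, "tetrode1": rec_b})
    # one key per channel: tetrode0, tetrode0, tetrode1, tetrode1
    print(aggregated.get_property("aggregation_key"))
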
src/spikeinterface/core/sortinganalyzer.py
Lines changed: 37 additions & 6 deletions

@@ -21,8 +21,7 @@
 
 import spikeinterface
 
-from .baserecording import BaseRecording
-from .basesorting import BaseSorting
+from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units
 
 from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match
 from .core_tools import (
@@ -54,6 +53,7 @@ def create_sorting_analyzer(
     folder=None,
     sparse=True,
     sparsity=None,
+    set_sparsity_by_dict_key=False,
     return_scaled=None,
     return_in_uV=True,
     overwrite=False,
@@ -71,10 +71,10 @@ def create_sorting_analyzer(
 
     Parameters
     ----------
-    sorting : Sorting
-        The sorting object
-    recording : Recording
-        The recording object
+    sorting : Sorting | dict
+        The sorting object, or a dict of them
+    recording : Recording | dict
+        The recording object, or a dict of them
     folder : str or Path or None, default: None
         The folder where analyzer is cached
     format : "memory" | "binary_folder" | "zarr", default: "memory"
@@ -88,6 +88,9 @@ def create_sorting_analyzer(
        You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs)
     sparsity : ChannelSparsity or None, default: None
         The sparsity used to compute extensions. If this is given, `sparse` is ignored.
+    set_sparsity_by_dict_key : bool, default: False
+        If True and passing recording and sorting dicts, will set the sparsity based on the dict keys,
+        and other `sparsity_kwargs` are overwritten. If False, use other sparsity settings.
     return_scaled : bool | None, default: None
        DEPRECATED. Use return_in_uV instead.
        All extensions that play with traces will use this global return_in_uV : "waveforms", "noise_levels", "templates".
@@ -139,6 +142,34 @@ def create_sorting_analyzer(
     In some situations, sparsity is not needed, so to make creation fast, you need to turn
     sparsity off (or give external sparsity) like this.
     """
+
+    if isinstance(sorting, dict) and isinstance(recording, dict):
+
+        if sorting.keys() != recording.keys():
+            raise ValueError(
+                f"Keys of `sorting`, {sorting.keys()}, and `recording`, {recording.keys()}, dicts do not match."
+            )
+
+        aggregated_recording = aggregate_channels(recording)
+        aggregated_sorting = aggregate_units(sorting)
+
+        if set_sparsity_by_dict_key:
+            sparsity_kwargs = {"method": "by_property", "by_property": "aggregation_key"}
+
+        return create_sorting_analyzer(
+            sorting=aggregated_sorting,
+            recording=aggregated_recording,
+            format=format,
+            folder=folder,
+            sparse=sparse,
+            sparsity=sparsity,
+            return_scaled=return_scaled,
+            return_in_uV=return_in_uV,
+            overwrite=overwrite,
+            backend_options=backend_options,
+            **sparsity_kwargs,
+        )
+
     if format != "memory":
         if format == "zarr":
             if not is_path_remote(folder):

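A sketch of the new dict branch in use; ``dict_of_recordings`` and ``dict_of_sortings`` are assumed to come from ``split_by`` and a by-group ``run_sorter`` call, as in the docs above:

    from spikeinterface.core import create_sorting_analyzer

    # the two dicts must share the same keys, otherwise a ValueError is raised
    analyzer = create_sorting_analyzer(
        sorting=dict_of_sortings,
        recording=dict_of_recordings,
        # optional: derive the sparsity from the dict keys, so each unit stays
        # sparse on the channels of the group it was sorted in
        set_sparsity_by_dict_key=True,
    )
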
src/spikeinterface/core/tests/test_channelsaggregationrecording.py
Lines changed: 38 additions & 4 deletions

@@ -2,6 +2,7 @@
 
 from spikeinterface.core import aggregate_channels
 from spikeinterface.core import generate_recording
+from spikeinterface.core.testing import check_recordings_equal
 
 
 def test_channelsaggregationrecording():
@@ -114,6 +115,39 @@ def test_split_then_aggreate_preserve_user_property():
     assert np.all(old_properties_ids_dict == new_properties_ids_dict)
 
 
+def test_aggregation_split_by_and_manual():
+    """
+    We can either split recordings automatically using "split_by" or manually by
+    constructing dictionaries. This test checks the two are equivalent. We skip
+    the annotations check since "split_by" also saves an annotation storing which
+    property we split by.
+    """
+
+    rec1 = generate_recording(num_channels=6)
+    rec1_channel_ids = rec1.get_channel_ids()
+    rec1.set_property(key="brain_area", values=["a", "a", "b", "a", "b", "a"])
+
+    split_recs = rec1.split_by("brain_area")
+
+    aggregated_rec = aggregate_channels(split_recs)
+
+    rec_a_channel_ids = aggregated_rec.channel_ids[aggregated_rec.get_property("brain_area") == "a"]
+    rec_b_channel_ids = aggregated_rec.channel_ids[aggregated_rec.get_property("brain_area") == "b"]
+
+    assert np.all(rec_a_channel_ids == split_recs["a"].channel_ids)
+    assert np.all(rec_b_channel_ids == split_recs["b"].channel_ids)
+
+    split_recs_manual = {
+        "a": rec1.select_channels(channel_ids=rec1_channel_ids[rec1.get_property("brain_area") == "a"]),
+        "b": rec1.select_channels(channel_ids=rec1_channel_ids[rec1.get_property("brain_area") == "b"]),
+    }
+
+    aggregated_rec_manual = aggregate_channels(split_recs_manual)
+
+    assert np.all(aggregated_rec_manual.get_property("aggregation_key") == ["a", "a", "a", "a", "b", "b"])
+    check_recordings_equal(aggregated_rec, aggregated_rec_manual, check_annotations=False, check_properties=True)
+
+
 def test_channel_aggregation_preserve_ids():
 
     recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False)  # To avoid location check
@@ -132,9 +166,9 @@ def test_aggregation_labeling_for_lists():
     recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False)
     recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False)
 
-    # If we don't label at all, aggregation will add a 'group' label
+    # If we don't label at all, aggregation will add an 'aggregation_key' label
     aggregated_recording = aggregate_channels([recording1, recording2])
-    group_property = aggregated_recording.get_property("group")
+    group_property = aggregated_recording.get_property("aggregation_key")
     assert np.all(group_property == [0, 0, 0, 0, 1, 1])
 
     # If we have different group labels, these should be respected
@@ -161,9 +195,9 @@ def test_aggretion_labelling_for_dicts():
     recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False)
     recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False)
 
-    # If we don't label at all, aggregation will add a 'group' label based on the dict keys
+    # If we don't label at all, aggregation will add an 'aggregation_key' label based on the dict keys
     aggregated_recording = aggregate_channels({0: recording1, "cat": recording2})
-    group_property = aggregated_recording.get_property("group")
+    group_property = aggregated_recording.get_property("aggregation_key")
     assert np.all(group_property == [0, 0, 0, 0, "cat", "cat"])
 
     # If we have different group labels, these should be respected
