
Commit abd45a7

Merge pull request #4005 from chrishalcrow/sort-dict-minimal
Allow `run_sorter` to accept dicts
2 parents 8c2fd1a + f927649 commit abd45a7

5 files changed: +182 additions, -40 deletions

doc/how_to/process_by_channel_group.rst

Lines changed: 31 additions & 19 deletions
@@ -99,7 +99,8 @@ to any preprocessing function.
     referenced_recording = spre.common_reference(filtered_recording)
     good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)

-We can then aggregate the recordings back together using the ``aggregate_channels`` function
+We can then aggregate the recordings back together using the ``aggregate_channels`` function.
+Note that we do not need to do this to sort the data (see :ref:`sorting-by-channel-group`).

 .. code-block:: python

@@ -134,23 +135,47 @@ back together under the hood).
 In general, it is not recommended to apply :py:func:`~aggregate_channels` more than once.
 This will slow down :py:func:`~get_traces` calls and may result in unpredictable behaviour.

+.. _sorting-by-channel-group:

 Sorting a Recording by Channel Group
 ------------------------------------

 We can also sort a recording for each channel group separately. It is not necessary to preprocess
 a recording by channel group in order to sort by channel group.

-There are two ways to sort a recording by channel group. First, we can split the preprocessed
-recording (or, if it was already split during preprocessing as above, skip the :py:func:`~aggregate_channels` step
-directly use the :py:func:`~split_recording_dict`).
+There are two ways to sort a recording by channel group. First, we can simply pass the output from
+our preprocessing-by-group method above. Second, for more control, we can loop over the recordings
+ourselves.

-**Option 1: Manual splitting**
+**Option 1 : Automatic splitting**

-In this example, similar to above we loop over all preprocessed recordings that
+Simply pass the split recording to the `run_sorter` function, as if it was a non-split recording.
+This will return a dict of sortings, with the keys corresponding to the groups.
+
+.. code-block:: python
+
+    split_recording = raw_recording.split_by("group")
+    # is a dict of recordings
+
+    # do preprocessing if needed
+    pp_recording = spre.bandpass_filter(split_recording)
+
+    dict_of_sortings = run_sorter(
+        sorter_name='kilosort2',
+        recording=pp_recording,
+        working_folder='working_path'
+    )
+
+
+**Option 2: Manual splitting**
+
+In this example, we loop over all preprocessed recordings that
 are grouped by channel, and apply the sorting separately. We store the
 sorting objects in a dictionary for later use.

+You might do this if you want extra control e.g. to apply bespoke steps
+to different groups.
+
 .. code-block:: python

     split_preprocessed_recording = preprocessed_recording.split_by("group")

@@ -163,16 +188,3 @@ sorting objects in a dictionary for later use.
             folder=f"folder_KS2_group{group}"
         )
         sortings[group] = sorting
-
-**Option 2 : Automatic splitting**
-
-Alternatively, SpikeInterface provides a convenience function to sort the recording by property:
-
-.. code-block:: python
-
-    aggregate_sorting = run_sorter_by_property(
-        sorter_name='kilosort2',
-        recording=preprocessed_recording,
-        grouping_property='group',
-        working_folder='working_path'
-    )
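The flow documented above relies on ``split_by`` returning a plain dict keyed by the stringified property value, which ``run_sorter`` then maps over. As a rough illustration of that grouping contract only (``split_by_property`` and the channel ids are hypothetical stand-ins, not SpikeInterface code):

```python
from collections import defaultdict

def split_by_property(channel_ids, property_values):
    """Toy stand-in for BaseRecording.split_by: partition channel ids
    into a dict keyed by each channel's property value."""
    groups = defaultdict(list)
    for chan, value in zip(channel_ids, property_values):
        # keys are stringified, so mixed-type properties yield keys like "4" and "g"
        groups[str(value)].append(chan)
    return dict(groups)

channels = ["ch0", "ch1", "ch2", "ch3", "ch4"]
group_property = [0, 0, 1, 1, 0]
split = split_by_property(channels, group_property)
# split == {"0": ["ch0", "ch1", "ch4"], "1": ["ch2", "ch3"]}
```

The stringified keys match what the test in this commit asserts (``set(["g", "4"])`` after splitting on a property with values ``4`` and ``"g"``).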

doc/modules/sorters.rst

Lines changed: 15 additions & 12 deletions
@@ -339,8 +339,8 @@ Running spike sorting by group is indeed a very common need.
 A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of
 sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
 So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
-SpikeInterface also provides a high-level function to automate the process of splitting the
-recording and then aggregating the results with the :py:func:`~spikeinterface.sorters.run_sorter_by_property` function.
+The :py:func:`~spikeinterface.sorters.run_sorter` method can also accept the dictionary which is returned
+by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings.

 In this example, we create a 16-channel recording with 4 tetrodes:

@@ -368,7 +368,19 @@ In this example, we create a 16-channel recording with 4 tetrodes:
     # >>> [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]

-**Option 1: Manual splitting**
+**Option 1 : Automatic splitting**
+
+.. code-block:: python
+
+    # here the result is a dict of sortings
+    dict_of_sortings = run_sorter(
+        sorter_name='kilosort2',
+        recording=recording_4_tetrodes,
+        working_folder='working_path'
+    )
+
+
+**Option 2: Manual splitting**

 .. code-block:: python

@@ -383,15 +395,6 @@ In this example, we create a 16-channel recording with 4 tetrodes:
     sorting = run_sorter(sorter_name='kilosort2', recording=recording, folder=f"folder_KS2_group{group}")
     sortings[group] = sorting

-**Option 2 : Automatic splitting**
-
-.. code-block:: python
-
-    # here the result is one sorting that aggregates all sub sorting objects
-    aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes,
-                                               grouping_property='group',
-                                               folder='working_path')

 Handling multi-segment recordings
 ---------------------------------
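The "dict in, dict of sortings out" behavior documented above boils down to a recursive dispatch: when handed a dict, run each recording into a per-key subfolder and return results under the same keys. A minimal sketch of that pattern under assumed names (``run_sorter_sketch`` and ``sort_one`` are illustrative stand-ins, not the real API):

```python
from pathlib import Path

def sort_one(recording, folder):
    # stand-in for a real per-recording sorter run that writes into `folder`
    return f"sorting({recording})"

def run_sorter_sketch(recording, folder="sorter_output"):
    # dict of recordings: recurse per key into a per-key subfolder and
    # return a dict of sortings with the same keys
    if isinstance(recording, dict):
        base = Path(folder)
        return {
            key: run_sorter_sketch(rec, folder=base / str(key))
            for key, rec in recording.items()
        }
    # single recording: sort as before
    return sort_one(recording, folder)

result = run_sorter_sketch({"0": "tetrode0", "1": "tetrode1"})
# result == {"0": "sorting(tetrode0)", "1": "sorting(tetrode1)"}
```

The real implementation (``_run_sorter_by_dict`` in this commit) additionally writes a provenance file so the group can be reloaded later.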

src/spikeinterface/core/loading.py

Lines changed: 24 additions & 1 deletion
@@ -196,6 +196,12 @@ def _guess_object_from_local_folder(folder):
         with open(folder / "spikeinterface_info.json", "r") as f:
             spikeinterface_info = json.load(f)
         return _guess_object_from_dict(spikeinterface_info)
+    elif (
+        (folder / "sorter_output").is_dir()
+        and (folder / "spikeinterface_params.json").is_file()
+        and (folder / "spikeinterface_log.json").is_file()
+    ):
+        return "SorterFolder"
     elif (folder / "waveforms").is_dir():
         # before the SortingAnlazer, it was WaveformExtractor (v<0.101)
         return "WaveformExtractor"

@@ -212,13 +218,20 @@ def _guess_object_from_local_folder(folder):
         return "Recording|Sorting"


-def _load_object_from_folder(folder, object_type, **kwargs):
+def _load_object_from_folder(folder, object_type: str, **kwargs):
+
     if object_type == "SortingAnalyzer":
         from .sortinganalyzer import load_sorting_analyzer

         analyzer = load_sorting_analyzer(folder, **kwargs)
         return analyzer

+    elif object_type == "SorterFolder":
+        from spikeinterface.sorters import read_sorter_folder
+
+        sorting = read_sorter_folder(folder)
+        return sorting
+
     elif object_type == "Motion":
         from spikeinterface.core.motion import Motion

@@ -244,6 +257,16 @@ def _load_object_from_folder(folder, object_type, **kwargs):
             si_file = f
         return BaseExtractor.load(si_file, base_folder=folder)

+    elif object_type.startswith("Group"):
+
+        sub_object_type = object_type.split("[")[1].split("]")[0]
+        with open(folder / "spikeinterface_info.json", "r") as f:
+            spikeinterface_info = json.load(f)
+        group_keys = spikeinterface_info.get("dict_keys")
+
+        group_of_objects = {key: _load_object_from_folder(folder / str(key), sub_object_type) for key in group_keys}
+        return group_of_objects
+

 def _guess_object_from_zarr(zarr_folder):
     # here it can be a zarr folder for Recording|Sorting|SortingAnalyzer|Template
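The ``Group[...]`` branch above recovers the inner object type from the composite type string, then loads each keyed sub-folder with that inner type. The string handling can be sketched in isolation (``parse_group_type`` is a hypothetical helper mirroring the ``split("[")``/``split("]")`` parsing in the diff):

```python
def parse_group_type(object_type: str):
    """Return the inner type of a composite 'Group[...]' type string,
    or None if the string is not a group type."""
    if object_type.startswith("Group") and "[" in object_type:
        # "Group[SorterFolder]" -> "SorterFolder"
        return object_type.split("[")[1].split("]")[0]
    return None

inner = parse_group_type("Group[SorterFolder]")
# inner == "SorterFolder"
flat = parse_group_type("SortingAnalyzer")
# flat is None
```

With the inner type in hand, the loader simply maps ``_load_object_from_folder`` over the keys listed in ``spikeinterface_info.json``, yielding a dict of loaded objects.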

src/spikeinterface/sorters/runsorter.py

Lines changed: 61 additions & 7 deletions
@@ -70,7 +70,7 @@
     ----------
     sorter_name : str
         The sorter name
-    recording : RecordingExtractor
+    recording : RecordingExtractor | dict of RecordingExtractor
         The recording extractor to be spike sorted
     folder : str or Path
         Path to output folder

@@ -98,16 +98,12 @@
     **sorter_params : keyword args
         Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`)

-    Returns
-    -------
-    BaseSorting | None
-        The spike sorted data (it `with_output` is True) or None (if `with_output` is False)
     """


 def run_sorter(
     sorter_name: str,
-    recording: BaseRecording,
+    recording: BaseRecording | dict,
     folder: Optional[str] = None,
     remove_existing_folder: bool = False,
     delete_output_folder: bool = False,

@@ -121,8 +117,11 @@ def run_sorter(
 ):
     """
     Generic function to run a sorter via function approach.
-
     {}
+    Returns
+    -------
+    BaseSorting | dict of BaseSorting | None
+        The spike sorted data (it `with_output` is True) or None (if `with_output` is False)

     Examples
     --------

@@ -141,6 +140,21 @@ def run_sorter(
         **sorter_params,
     )

+    if isinstance(recording, dict):
+
+        all_kwargs = common_kwargs
+        all_kwargs.update(
+            dict(
+                docker_image=docker_image,
+                singularity_image=singularity_image,
+                delete_container_files=delete_container_files,
+            )
+        )
+        all_kwargs.pop("recording")
+
+        dict_of_sorters = _run_sorter_by_dict(dict_of_recordings=recording, **all_kwargs)
+        return dict_of_sorters
+
     if docker_image or singularity_image:
         common_kwargs.update(dict(delete_container_files=delete_container_files))
         if docker_image:

@@ -191,6 +205,46 @@ def run_sorter(
 run_sorter.__doc__ = run_sorter.__doc__.format(_common_param_doc)


+def _run_sorter_by_dict(dict_of_recordings: dict, folder: str | Path | None = None, **run_sorter_params):
+    """
+    Applies `run_sorter` to each recording in a dict of recordings and saves
+    the results.
+    {}
+    Returns
+    -------
+    dict
+        Dictionary of `BaseSorting`s, with the same keys as the input dict of `BaseRecording`s.
+    """
+
+    sorter_name = run_sorter_params.get("sorter_name")
+    remove_existing_folder = run_sorter_params.get("remove_existing_folder")
+
+    if folder is None:
+        folder = Path(sorter_name + "_output")
+
+    folder = Path(folder)
+    folder.mkdir(exist_ok=remove_existing_folder)
+
+    sorter_dict = {}
+    for group_key, recording in dict_of_recordings.items():
+        sorter_dict[group_key] = run_sorter(recording=recording, folder=folder / f"{group_key}", **run_sorter_params)
+
+    info_file = folder / "spikeinterface_info.json"
+    info = dict(
+        version=spikeinterface.__version__,
+        dev_mode=spikeinterface.DEV_MODE,
+        object="Group[SorterFolder]",
+        dict_keys=list(dict_of_recordings.keys()),
+    )
+    with open(info_file, mode="w") as f:
+        json.dump(check_json(info), f, indent=4)
+
+    return sorter_dict
+
+
+_run_sorter_by_dict.__doc__ = _run_sorter_by_dict.__doc__.format(_common_param_doc)
+
+
 def run_sorter_local(
     sorter_name,
     recording,
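``_run_sorter_by_dict`` records the group structure in a ``spikeinterface_info.json`` at the top of the output folder, which is what later lets ``load`` rebuild the dict of sortings. A stripped-down sketch of that round trip (field names follow the diff; the ``write_group_info`` helper and version string are made up for illustration):

```python
import json
import tempfile
from pathlib import Path

def write_group_info(folder, keys, version="0.0.0"):
    """Write a provenance file describing a folder of per-group sorter outputs."""
    info = {
        "version": version,
        "object": "Group[SorterFolder]",  # composite type tag read back by the loader
        "dict_keys": list(keys),          # one sub-folder per key
    }
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    with open(folder / "spikeinterface_info.json", mode="w") as f:
        json.dump(info, f, indent=4)
    return info

with tempfile.TemporaryDirectory() as tmp:
    written = write_group_info(tmp, keys=["4", "g"])
    with open(Path(tmp) / "spikeinterface_info.json") as f:
        loaded = json.load(f)
# loaded round-trips: same object tag and dict_keys as written
```

The real file also carries ``dev_mode`` and is passed through ``check_json``; the loader then maps over ``dict_keys`` to read each sub-folder.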

src/spikeinterface/sorters/tests/test_runsorter.py

Lines changed: 51 additions & 1 deletion
@@ -4,8 +4,10 @@
 from pathlib import Path
 import shutil
 from packaging.version import parse
+import json
+import numpy as np

-from spikeinterface import generate_ground_truth_recording
+from spikeinterface import generate_ground_truth_recording, load
 from spikeinterface.sorters import run_sorter

 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))

@@ -45,6 +47,54 @@ def test_run_sorter_local(generate_recording, create_cache_folder):
     print(sorting)


+def test_run_sorter_dict(generate_recording, create_cache_folder):
+    recording = generate_recording
+    cache_folder = create_cache_folder
+
+    recording = recording.time_slice(start_time=0, end_time=3)
+
+    recording.set_property(key="split_property", values=[4, 4, "g", "g", 4, 4, 4, "g"])
+    dict_of_recordings = recording.split_by("split_property")
+
+    sorter_params = {"detection": {"detect_threshold": 4.9}}
+
+    folder = cache_folder / "sorting_tdc_local_dict"
+
+    dict_of_sortings = run_sorter(
+        "simple",
+        dict_of_recordings,
+        folder=folder,
+        remove_existing_folder=True,
+        delete_output_folder=False,
+        verbose=True,
+        raise_error=True,
+        **sorter_params,
+    )
+
+    assert set(list(dict_of_sortings.keys())) == set(["g", "4"])
+    assert (folder / "g").is_dir()
+    assert (folder / "4").is_dir()
+
+    assert dict_of_sortings["g"]._recording.get_num_channels() == 3
+    assert dict_of_sortings["4"]._recording.get_num_channels() == 5
+
+    info_filepath = folder / "spikeinterface_info.json"
+    assert info_filepath.is_file()
+
+    with open(info_filepath) as f:
+        spikeinterface_info = json.load(f)
+
+    si_info_keys = spikeinterface_info.keys()
+    for key in ["version", "dev_mode", "object"]:
+        assert key in si_info_keys
+
+    loaded_sortings = load(folder)
+    assert loaded_sortings.keys() == dict_of_sortings.keys()
+    for key, sorting in loaded_sortings.items():
+        assert np.all(sorting.unit_ids == dict_of_sortings[key].unit_ids)
+        assert np.all(sorting.to_spike_vector() == dict_of_sortings[key].to_spike_vector())
+
+
 @pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally")
 def test_run_sorter_docker(generate_recording, create_cache_folder):
     recording = generate_recording
