Merged

Changes from 8 commits
48 changes: 29 additions & 19 deletions doc/how_to/process_by_channel_group.rst
@@ -99,7 +99,8 @@ to any preprocessing function.
referenced_recording = spre.common_reference(filtered_recording)
good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)

We can then aggregate the recordings back together using the ``aggregate_channels`` function
We can then aggregate the recordings back together using the ``aggregate_channels`` function.
Note that we do not need to do this to sort the data (see :ref:`sorting a recording by channel group`).

.. code-block:: python

@@ -141,16 +142,38 @@ Sorting a Recording by Channel Group
We can also sort a recording for each channel group separately. It is not necessary to preprocess
a recording by channel group in order to sort by channel group.

There are two ways to sort a recording by channel group. First, we can split the preprocessed
recording (or, if it was already split during preprocessing as above, skip the :py:func:`~aggregate_channels` step
directly use the :py:func:`~split_recording_dict`).
There are two ways to sort a recording by channel group. First, we can simply pass the output from
our preprocessing-by-group method above. Second, for more control, we can loop over the recordings
ourselves.

**Option 1: Manual splitting**
**Option 1: Automatic splitting**

In this example, similar to above we loop over all preprocessed recordings that
Simply pass the split recording to the ``run_sorter`` function, as if it were a non-split recording.
This will return a dict of sortings, with the keys corresponding to the groups.

.. code-block:: python

split_recording = raw_recording.split_by("group")

# do preprocessing if needed
pp_recording = spre.bandpass_filter(split_recording)

dict_of_sortings = run_sorter(
sorter_name='kilosort2',
recording=pp_recording,
folder='working_path'
)
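
Each entry in the returned dict is a regular sorting object. As a minimal
sketch (assuming the ``dict_of_sortings`` from above), each group's sorting
can be saved separately:

.. code-block:: python

    for group, sorting in dict_of_sortings.items():
        sorting.save(folder=f"sorting_group_{group}")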


**Option 2: Manual splitting**

In this example, we loop over all preprocessed recordings that
are grouped by channel, and apply the sorting separately. We store the
sorting objects in a dictionary for later use.

You might do this if you want extra control, e.g. to apply bespoke steps
to different groups.

.. code-block:: python

split_preprocessed_recording = preprocessed_recording.split_by("group")
@@ -163,16 +186,3 @@ sorting objects in a dictionary for later use.
output_folder=f"folder_KS2_group{group}"
)
sortings[group] = sorting

**Option 2 : Automatic splitting**

Alternatively, SpikeInterface provides a convenience function to sort the recording by property:

.. code-block:: python

aggregate_sorting = run_sorter_by_property(
sorter_name='kilosort2',
recording=preprocessed_recording,
grouping_property='group',
working_folder='working_path'
)
27 changes: 15 additions & 12 deletions doc/modules/sorters.rst
@@ -339,8 +339,8 @@ Running spike sorting by group is indeed a very common need.
A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of
sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
SpikeInterface also provides a high-level function to automate the process of splitting the
recording and then aggregating the results with the :py:func:`~spikeinterface.sorters.run_sorter_by_property` function.
The :py:func:`~spikeinterface.sorters.run_sorter` function can also accept the dictionary returned
by :py:meth:`~spikeinterface.core.BaseRecording.split_by`, in which case it returns a dictionary of sortings.

In this example, we create a 16-channel recording with 4 tetrodes:

@@ -368,7 +368,19 @@ In this example, we create a 16-channel recording with 4 tetrodes:
# >>> [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]


**Option 1: Manual splitting**
**Option 1: Automatic splitting**

.. code-block:: python

# split the recording into a dict of sub-recordings;
# here the result is a dict of sortings
dict_of_recordings = recording_4_tetrodes.split_by("group")

dict_of_sortings = run_sorter(
sorter_name='kilosort2',
recording=dict_of_recordings,
folder='working_path'
)
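
The keys of the returned dict are the values of the ``group`` property. As a
minimal sketch, the result can be inspected per group:

.. code-block:: python

    for group, sorting in dict_of_sortings.items():
        print(f"group {group}: {sorting.get_num_units()} units")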


**Option 2: Manual splitting**

.. code-block:: python

@@ -383,15 +395,6 @@ In this example, we create a 16-channel recording with 4 tetrodes:
sorting = run_sorter(sorter_name='kilosort2', recording=recording, output_folder=f"folder_KS2_group{group}")
sortings[group] = sorting

**Option 2 : Automatic splitting**

.. code-block:: python

# here the result is one sorting that aggregates all sub sorting objects
aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes,
grouping_property='group',
working_folder='working_path')


Handling multi-segment recordings
---------------------------------
25 changes: 24 additions & 1 deletion src/spikeinterface/core/loading.py
@@ -196,6 +196,12 @@ def _guess_object_from_local_folder(folder):
with open(folder / "spikeinterface_info.json", "r") as f:
spikeinterface_info = json.load(f)
return _guess_object_from_dict(spikeinterface_info)
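# a folder written by run_sorter contains a "sorter_output" subfolder
# plus params and log JSON files: treat such a folder as a sorter output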
elif (
(folder / "sorter_output").is_dir()
and (folder / "spikeinterface_params.json").is_file()
and (folder / "spikeinterface_log.json").is_file()
):
return "SorterOutput"
elif (folder / "waveforms").is_dir():
# before the SortingAnalyzer, it was the WaveformExtractor (v<0.101)
return "WaveformExtractor"
@@ -212,13 +218,20 @@ def _guess_object_from_local_folder(folder):
return "Recording|Sorting"


def _load_object_from_folder(folder, object_type, **kwargs):
def _load_object_from_folder(folder, object_type: str, **kwargs):

if object_type == "SortingAnalyzer":
from .sortinganalyzer import load_sorting_analyzer

analyzer = load_sorting_analyzer(folder, **kwargs)
return analyzer

elif object_type == "SorterOutput":
from spikeinterface.sorters import read_sorter_folder

sorting = read_sorter_folder(folder)
return sorting

elif object_type == "Motion":
from spikeinterface.core.motion import Motion

@@ -244,6 +257,16 @@ def _load_object_from_folder(folder, object_type, **kwargs):
si_file = f
return BaseExtractor.load(si_file, base_folder=folder)

elif object_type.startswith("Group"):
# Member Author (review comment): Feedback on these lines of code very welcome!
# Wasn't sure best way to proceed.
sub_object_type = object_type.split("[")[1].split("]")[0]
with open(folder / "spikeinterface_info.json", "r") as f:
spikeinterface_info = json.load(f)
group_keys = spikeinterface_info.get("dict_keys")

group_of_objects = {key: _load_object_from_folder(folder / str(key), sub_object_type) for key in group_keys}
return group_of_objects
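
# A minimal sketch of the branch above: for a folder saved with
# object="Group[SorterOutput]" and dict_keys=["0", "1"], the call
#   _load_object_from_folder(folder, "Group[SorterOutput]")
# returns {"0": <sorting for group 0>, "1": <sorting for group 1>}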


def _guess_object_from_zarr(zarr_folder):
# here it can be a zarr folder for Recording|Sorting|SortingAnalyzer|Template
77 changes: 70 additions & 7 deletions src/spikeinterface/sorters/runsorter.py
@@ -70,7 +70,7 @@
----------
sorter_name : str
The sorter name
recording : RecordingExtractor
recording : RecordingExtractor | dict of RecordingExtractor
The recording extractor to be spike sorted. If a dict of recordings is passed, each entry is sorted separately
folder : str or Path
Path to output folder
@@ -100,16 +100,12 @@
**sorter_params : keyword args
Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`)

Returns
-------
BaseSorting | None
The spike sorted data (if `with_output` is True) or None (if `with_output` is False)
"""


def run_sorter(
sorter_name: str,
recording: BaseRecording,
recording: BaseRecording | dict,
folder: Optional[str] = None,
remove_existing_folder: bool = False,
delete_output_folder: bool = False,
@@ -124,8 +120,11 @@ def run_sorter(
):
"""
Generic function to run a sorter via the function approach.

{}
Returns
-------
BaseSorting | dict of BaseSorting | None
The spike sorted data (if `with_output` is True) or None (if `with_output` is False)

Examples
--------
@@ -151,6 +150,20 @@
**sorter_params,
)

if isinstance(recording, dict):

all_kwargs = dict(common_kwargs)  # copy, so the shared kwargs dict is not mutated
all_kwargs.update(
dict(
docker_image=docker_image,
singularity_image=singularity_image,
delete_container_files=delete_container_files,
)
)

dict_of_sorters = _run_sorter_by_dict(recording, **all_kwargs)
return dict_of_sorters

if docker_image or singularity_image:
common_kwargs.update(dict(delete_container_files=delete_container_files))
if docker_image:
@@ -201,6 +214,56 @@ run_sorter.__doc__ = run_sorter.__doc__.format(_common_param_doc)
run_sorter.__doc__ = run_sorter.__doc__.format(_common_param_doc)


def _run_sorter_by_dict(dict_of_recordings: dict, folder: str | Path | None = None, **run_sorter_params):
"""
Applies `run_sorter` to each recording in a dict of recordings and saves
the results.
{}
Returns
-------
dict
Dictionary of `BaseSorting`s, with the same keys as the input dict of `BaseRecording`s.
"""

sorter_name = run_sorter_params["sorter_name"]
remove_existing_folder = run_sorter_params["remove_existing_folder"]

if folder is None:
folder = Path(sorter_name + "_output")

folder = Path(folder)
folder.mkdir(exist_ok=remove_existing_folder)

# If we know how the recording was split, save this in the info file
first_recording = next(iter(dict_of_recordings.values()))
split_by_property = first_recording.get_annotation("split_by_property")
if split_by_property is None:
split_by_property = "Unknown"

info_file = folder / "spikeinterface_info.json"
info = dict(
version=spikeinterface.__version__,
dev_mode=spikeinterface.DEV_MODE,
object="Group[SorterOutput]",
dict_keys=list(dict_of_recordings.keys()),
split_by_property=split_by_property,
)
with open(info_file, mode="w") as f:
json.dump(check_json(info), f, indent=4)

sorter_dict = {}
for group_key, recording in dict_of_recordings.items():

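# run_sorter forwards its common kwargs here, including the original dict
# as "recording"; drop it so each sub-recording can be passed instead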
if "recording" in run_sorter_params:
run_sorter_params.pop("recording")

sorter_dict[group_key] = run_sorter(recording=recording, folder=folder / f"{group_key}", **run_sorter_params)

return sorter_dict


_run_sorter_by_dict.__doc__ = _run_sorter_by_dict.__doc__.format(_common_param_doc)

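# For illustration, with a recording split into groups "0" and "1" and
# folder="my_output", _run_sorter_by_dict produces a layout like:
#
#   my_output/
#       spikeinterface_info.json   (object="Group[SorterOutput]", dict_keys=["0", "1"])
#       0/                         (run_sorter output for group "0")
#       1/                         (run_sorter output for group "1")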

def run_sorter_local(
sorter_name,
recording,
52 changes: 51 additions & 1 deletion src/spikeinterface/sorters/tests/test_runsorter.py
@@ -4,8 +4,10 @@
from pathlib import Path
import shutil
from packaging.version import parse
import json
import numpy as np

from spikeinterface import generate_ground_truth_recording
from spikeinterface import generate_ground_truth_recording, load
from spikeinterface.sorters import run_sorter

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
@@ -45,6 +47,54 @@ def test_run_sorter_local(generate_recording, create_cache_folder):
print(sorting)


def test_run_sorter_dict(generate_recording, create_cache_folder):
recording = generate_recording
cache_folder = create_cache_folder

recording = recording.time_slice(start_time=0, end_time=3)

recording.set_property(key="split_property", values=[4, 4, "g", "g", 4, 4, 4, "g"])
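# mixing ints and strings coerces the property to an array of strings,
# so the group keys become "4" and "g"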
dict_of_recordings = recording.split_by("split_property")

sorter_params = {"detection": {"detect_threshold": 4.9}}

output_folder = cache_folder / "sorting_tdc_local_dict"

dict_of_sortings = run_sorter(
"simple",
dict_of_recordings,
output_folder=output_folder,
remove_existing_folder=True,
delete_output_folder=False,
verbose=True,
raise_error=True,
**sorter_params,
)

assert set(list(dict_of_sortings.keys())) == set(["g", "4"])
assert (output_folder / "g").is_dir()
assert (output_folder / "4").is_dir()

assert dict_of_sortings["g"]._recording.get_num_channels() == 3
assert dict_of_sortings["4"]._recording.get_num_channels() == 5

info_filepath = output_folder / "spikeinterface_info.json"
assert info_filepath.is_file()

with open(info_filepath) as f:
spikeinterface_info = json.load(f)

si_info_keys = spikeinterface_info.keys()
for key in ["version", "dev_mode", "object"]:
assert key in si_info_keys

loaded_sortings = load(output_folder)
assert loaded_sortings.keys() == dict_of_sortings.keys()
for key, sorting in loaded_sortings.items():
assert np.all(sorting.unit_ids == dict_of_sortings[key].unit_ids)
assert np.all(sorting.to_spike_vector() == dict_of_sortings[key].to_spike_vector())


@pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally")
def test_run_sorter_docker(generate_recording, create_cache_folder):
recording = generate_recording