diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst
index 0e6ae49d37..9ff8215aba 100644
--- a/doc/how_to/process_by_channel_group.rst
+++ b/doc/how_to/process_by_channel_group.rst
@@ -99,7 +99,8 @@ to any preprocessing function.
     referenced_recording = spre.common_reference(filtered_recording)
     good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)
 
-We can then aggregate the recordings back together using the ``aggregate_channels`` function
+We can then aggregate the recordings back together using the ``aggregate_channels`` function.
+Note that we do not need to do this to sort the data (see :ref:`sorting-by-channel-group`).
 
 .. code-block:: python
 
@@ -134,6 +135,8 @@ back together under the hood).
 In general, it is not recommended to apply :py:func:`~aggregate_channels` more than once.
 This will slow down :py:func:`~get_traces` calls and may result in unpredictable behaviour.
 
+.. _sorting-by-channel-group:
+
 Sorting a Recording by Channel Group
 ------------------------------------
 
@@ -141,16 +144,39 @@ Sorting a Recording by Channel Group
 We can also sort a recording for each channel group separately. It is not necessary to preprocess
 a recording by channel group in order to sort by channel group.
 
-There are two ways to sort a recording by channel group. First, we can split the preprocessed
-recording (or, if it was already split during preprocessing as above, skip the :py:func:`~aggregate_channels` step
-directly use the :py:func:`~split_recording_dict`).
+There are two ways to sort a recording by channel group. First, we can simply pass the output from
+our preprocessing-by-group method above. Second, for more control, we can loop over the recordings
+ourselves.
 
-**Option 1: Manual splitting**
+**Option 1: Automatic splitting**
 
-In this example, similar to above we loop over all preprocessed recordings that
+Simply pass the split recording to the ``run_sorter`` function, as if it were a non-split recording.
+This will return a dict of sortings, with the keys corresponding to the groups.
+
+.. code-block:: python
+
+    split_recording = raw_recording.split_by("group")
+    # is a dict of recordings
+
+    # do preprocessing if needed
+    pp_recording = spre.bandpass_filter(split_recording)
+
+    dict_of_sortings = run_sorter(
+        sorter_name='kilosort2',
+        recording=pp_recording,
+        folder='working_path'
+    )
+
+
+**Option 2: Manual splitting**
+
+In this example, we loop over all preprocessed recordings that
 are grouped by channel, and apply the sorting separately. We store the
 sorting objects in a dictionary for later use.
 
+You might do this if you want extra control, e.g. to apply bespoke steps
+to different groups.
+
 .. code-block:: python
 
     split_preprocessed_recording = preprocessed_recording.split_by("group")
@@ -163,16 +189,3 @@ sorting objects in a dictionary for later use.
             folder=f"folder_KS2_group{group}"
         )
         sortings[group] = sorting
-
-**Option 2 : Automatic splitting**
-
-Alternatively, SpikeInterface provides a convenience function to sort the recording by property:
-
-.. code-block:: python
-
-    aggregate_sorting = run_sorter_by_property(
-        sorter_name='kilosort2',
-        recording=preprocessed_recording,
-        grouping_property='group',
-        working_folder='working_path'
-    )
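
To make the new dict-in/dict-out workflow above concrete, here is a minimal end-to-end sketch.
It is not part of the patch: it uses ``generate_ground_truth_recording`` and the built-in
``simple`` sorter (the same sorter the test below uses) so it runs without an external sorter,
and it assumes a SpikeInterface version where preprocessing functions map over dicts, as the
how-to describes.

.. code-block:: python

    import spikeinterface.full as si

    # toy recording with a "group" property: two tetrode-like groups of 4 channels
    recording, _ = si.generate_ground_truth_recording(num_channels=8, durations=[5.0])
    recording.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1])

    split_recording = recording.split_by("group")        # {0: recording, 1: recording}
    pp_recording = si.bandpass_filter(split_recording)   # preprocessing maps over the dict

    # run_sorter accepts the dict and returns {group_key: BaseSorting}
    dict_of_sortings = si.run_sorter(
        sorter_name="simple",
        recording=pp_recording,
        folder="sorting_by_group",
    )
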
diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst
index 6bf3a60e46..393a357cef 100644
--- a/doc/modules/sorters.rst
+++ b/doc/modules/sorters.rst
@@ -339,8 +339,8 @@ Running spike sorting by group is indeed a very common need.
 A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of
 sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
 So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
-SpikeInterface also provides a high-level function to automate the process of splitting the
-recording and then aggregating the results with the :py:func:`~spikeinterface.sorters.run_sorter_by_property` function.
+The :py:func:`~spikeinterface.sorters.run_sorter` function can also accept the dictionary returned
+by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings.
 
 In this example, we create a 16-channel recording with 4 tetrodes:
 
@@ -368,7 +368,19 @@ In this example, we create a 16-channel recording with 4 tetrodes:
 
     # >>> [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]
 
-**Option 1: Manual splitting**
+**Option 1: Automatic splitting**
+
+.. code-block:: python
+
+    # here the result is a dict of sortings
+    dict_of_sortings = run_sorter(
+        sorter_name='kilosort2',
+        recording=recording_4_tetrodes,
+        folder='working_path'
+    )
+
+
+**Option 2: Manual splitting**
 
 .. code-block:: python
 
@@ -383,15 +395,6 @@ In this example, we create a 16-channel recording with 4 tetrodes:
         sorting = run_sorter(sorter_name='kilosort2', recording=recording, folder=f"folder_KS2_group{group}")
         sortings[group] = sorting
 
-**Option 2 : Automatic splitting**
-
-.. code-block:: python
-
-    # here the result is one sorting that aggregates all sub sorting objects
-    aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes,
-                                               grouping_property='group',
-                                               folder='working_path')
-
 Handling multi-segment recordings
 ---------------------------------
 
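
The dictionary returned by the automatic path is plain: each value is an ordinary
``BaseSorting``, keyed by the value of the grouping property. A short sketch of how the
tetrode example's result might be consumed (``dict_of_sortings`` is the name from the docs
above; a real sorter must be installed for the preceding call to run):

.. code-block:: python

    # one sorting per tetrode, keyed by the values of the "group" property (0, 1, 2, 3)
    for group, sorting in dict_of_sortings.items():
        print(f"tetrode {group}: {sorting.get_num_units()} units")

On disk, ``run_sorter`` writes one sub-folder per key inside ``folder``, plus a small
``spikeinterface_info.json`` describing the group; the loading changes below rely on exactly
this layout.
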
diff --git a/src/spikeinterface/core/loading.py b/src/spikeinterface/core/loading.py
index d5845f033e..97f104d08f 100644
--- a/src/spikeinterface/core/loading.py
+++ b/src/spikeinterface/core/loading.py
@@ -196,6 +196,12 @@ def _guess_object_from_local_folder(folder):
         with open(folder / "spikeinterface_info.json", "r") as f:
             spikeinterface_info = json.load(f)
         return _guess_object_from_dict(spikeinterface_info)
+    elif (
+        (folder / "sorter_output").is_dir()
+        and (folder / "spikeinterface_params.json").is_file()
+        and (folder / "spikeinterface_log.json").is_file()
+    ):
+        return "SorterFolder"
     elif (folder / "waveforms").is_dir():
         # before the SortingAnlazer, it was WaveformExtractor (v<0.101)
         return "WaveformExtractor"
@@ -212,13 +218,20 @@ def _guess_object_from_local_folder(folder):
     return "Recording|Sorting"
 
 
-def _load_object_from_folder(folder, object_type, **kwargs):
+def _load_object_from_folder(folder, object_type: str, **kwargs):
+
     if object_type == "SortingAnalyzer":
         from .sortinganalyzer import load_sorting_analyzer
 
         analyzer = load_sorting_analyzer(folder, **kwargs)
         return analyzer
 
+    elif object_type == "SorterFolder":
+        from spikeinterface.sorters import read_sorter_folder
+
+        sorting = read_sorter_folder(folder)
+        return sorting
+
     elif object_type == "Motion":
         from spikeinterface.core.motion import Motion
 
@@ -244,6 +257,16 @@ def _load_object_from_folder(folder, object_type, **kwargs):
             si_file = f
         return BaseExtractor.load(si_file, base_folder=folder)
 
+    elif object_type.startswith("Group"):
+
+        sub_object_type = object_type.split("[")[1].split("]")[0]
+        with open(folder / "spikeinterface_info.json", "r") as f:
+            spikeinterface_info = json.load(f)
+        group_keys = spikeinterface_info.get("dict_keys")
+
+        group_of_objects = {key: _load_object_from_folder(folder / str(key), sub_object_type) for key in group_keys}
+        return group_of_objects
+
 
 def _guess_object_from_zarr(zarr_folder):
     # here it can be a zarr folder for Recording|Sorting|SortingAnalyzer|Template
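
Together, these two additions give ``load`` a round trip for folders written by the dict
code path in ``run_sorter`` (next file): the top-level ``spikeinterface_info.json`` identifies
the folder as ``Group[SorterFolder]``, and each listed key resolves to a sub-folder that is
in turn detected as a ``SorterFolder`` and read with ``read_sorter_folder``. A sketch of the
intended usage, assuming a previous dict run wrote its output to ``working_path``:

.. code-block:: python

    from spikeinterface import load

    # top level: object == "Group[SorterFolder]" -> returns a dict of sortings
    dict_of_sortings = load("working_path")

    # each value was loaded from working_path/<key>/ via read_sorter_folder
    for key, sorting in dict_of_sortings.items():
        print(key, sorting.unit_ids)
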
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index 5c44db2d58..408bec65c2 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -70,7 +70,7 @@
     ----------
     sorter_name : str
         The sorter name
-    recording : RecordingExtractor
+    recording : RecordingExtractor | dict of RecordingExtractor
         The recording extractor to be spike sorted
     folder : str or Path
         Path to output folder
@@ -98,16 +98,12 @@
     **sorter_params : keyword args
         Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`)
 
-    Returns
-    -------
-    BaseSorting | None
-        The spike sorted data (it `with_output` is True) or None (if `with_output` is False)
 """
 
 
 def run_sorter(
     sorter_name: str,
-    recording: BaseRecording,
+    recording: BaseRecording | dict,
     folder: Optional[str] = None,
     remove_existing_folder: bool = False,
     delete_output_folder: bool = False,
@@ -121,8 +117,12 @@ def run_sorter(
 ):
     """
     Generic function to run a sorter via function approach.
-    {}
 
+    {}
+    Returns
+    -------
+    BaseSorting | dict of BaseSorting | None
+        The spike sorted data (if `with_output` is True) or None (if `with_output` is False)
 
     Examples
     --------
@@ -141,6 +141,21 @@ def run_sorter(
         **sorter_params,
     )
 
+    if isinstance(recording, dict):
+
+        all_kwargs = common_kwargs.copy()
+        all_kwargs.update(
+            dict(
+                docker_image=docker_image,
+                singularity_image=singularity_image,
+                delete_container_files=delete_container_files,
+            )
+        )
+        all_kwargs.pop("recording")
+
+        dict_of_sorters = _run_sorter_by_dict(dict_of_recordings=recording, **all_kwargs)
+        return dict_of_sorters
+
     if docker_image or singularity_image:
         common_kwargs.update(dict(delete_container_files=delete_container_files))
         if docker_image:
@@ -191,6 +206,46 @@ def run_sorter(
 run_sorter.__doc__ = run_sorter.__doc__.format(_common_param_doc)
 
 
+def _run_sorter_by_dict(dict_of_recordings: dict, folder: str | Path | None = None, **run_sorter_params):
+    """
+    Applies `run_sorter` to each recording in a dict of recordings and saves
+    the results.
+    {}
+    Returns
+    -------
+    dict
+        Dictionary of `BaseSorting`s, with the same keys as the input dict of `BaseRecording`s.
+    """
+
+    sorter_name = run_sorter_params.get("sorter_name")
+    remove_existing_folder = run_sorter_params.get("remove_existing_folder")
+
+    if folder is None:
+        folder = Path(sorter_name + "_output")
+
+    folder = Path(folder)
+    folder.mkdir(exist_ok=remove_existing_folder)
+
+    sorter_dict = {}
+    for group_key, recording in dict_of_recordings.items():
+        sorter_dict[group_key] = run_sorter(recording=recording, folder=folder / f"{group_key}", **run_sorter_params)
+
+    info_file = folder / "spikeinterface_info.json"
+    info = dict(
+        version=spikeinterface.__version__,
+        dev_mode=spikeinterface.DEV_MODE,
+        object="Group[SorterFolder]",
+        dict_keys=list(dict_of_recordings.keys()),
+    )
+    with open(info_file, mode="w") as f:
+        json.dump(check_json(info), f, indent=4)
+
+    return sorter_dict
+
+
+_run_sorter_by_dict.__doc__ = _run_sorter_by_dict.__doc__.format(_common_param_doc)
+
+
 def run_sorter_local(
     sorter_name,
     recording,
diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py
index 1f2ec373a9..7b88de0266 100644
--- a/src/spikeinterface/sorters/tests/test_runsorter.py
+++ b/src/spikeinterface/sorters/tests/test_runsorter.py
@@ -4,8 +4,10 @@
 from pathlib import Path
 import shutil
 from packaging.version import parse
+import json
+import numpy as np
 
-from spikeinterface import generate_ground_truth_recording
+from spikeinterface import generate_ground_truth_recording, load
 from spikeinterface.sorters import run_sorter
 
 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
@@ -45,6 +47,54 @@ def test_run_sorter_local(generate_recording, create_cache_folder):
     print(sorting)
 
 
+def test_run_sorter_dict(generate_recording, create_cache_folder):
+    recording = generate_recording
+    cache_folder = create_cache_folder
+
+    recording = recording.time_slice(start_time=0, end_time=3)
+
+    recording.set_property(key="split_property", values=[4, 4, "g", "g", 4, 4, 4, "g"])
+    dict_of_recordings = recording.split_by("split_property")
+
+    sorter_params = {"detection": {"detect_threshold": 4.9}}
+
+    folder = cache_folder / "sorting_tdc_local_dict"
+
+    dict_of_sortings = run_sorter(
+        "simple",
+        dict_of_recordings,
+        folder=folder,
+        remove_existing_folder=True,
+        delete_output_folder=False,
+        verbose=True,
+        raise_error=True,
+        **sorter_params,
+    )
+
+    assert set(list(dict_of_sortings.keys())) == set(["g", "4"])
+    assert (folder / "g").is_dir()
+    assert (folder / "4").is_dir()
+
+    assert dict_of_sortings["g"]._recording.get_num_channels() == 3
+    assert dict_of_sortings["4"]._recording.get_num_channels() == 5
+
+    info_filepath = folder / "spikeinterface_info.json"
+    assert info_filepath.is_file()
+
+    with open(info_filepath) as f:
+        spikeinterface_info = json.load(f)
+
+    si_info_keys = spikeinterface_info.keys()
+    for key in ["version", "dev_mode", "object"]:
+        assert key in si_info_keys
+
+    loaded_sortings = load(folder)
+    assert loaded_sortings.keys() == dict_of_sortings.keys()
+    for key, sorting in loaded_sortings.items():
+        assert np.all(sorting.unit_ids == dict_of_sortings[key].unit_ids)
+        assert np.all(sorting.to_spike_vector() == dict_of_sortings[key].to_spike_vector())
+
+
 @pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally")
 def test_run_sorter_docker(generate_recording, create_cache_folder):
     recording = generate_recording
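
One detail of the test worth spelling out: the expected keys are the strings ``"4"`` and
``"g"``, not the integer ``4``. ``set_property`` stores its values as a numpy array, so the
mixed list coerces to a string dtype before ``split_by`` builds the group keys, and the
sub-folder names follow suit. A quick standalone check:

.. code-block:: python

    import numpy as np

    # the same values the test passes to set_property
    values = np.array([4, 4, "g", "g", 4, 4, 4, "g"])
    print(values.dtype)         # a unicode string dtype, e.g. <U21
    print(sorted(set(values)))  # ['4', 'g'] -> the dict keys and sub-folder names
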