diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index 07d0162a29..f1dbb80a42 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -1,4 +1,4 @@ -.. _recording-by-channel-group: +.. _process_by_group: Process a recording by channel group ==================================== @@ -101,8 +101,8 @@ to any preprocessing function. referenced_recording = spre.common_reference(filtered_recording) good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording) -We can then aggregate the recordings back together using the ``aggregate_channels`` function. -Note that we do not need to do this to sort the data (see :ref:`sorting-by-channel-group`). +If needed, we could aggregate the recordings back together using the ``aggregate_channels`` function. +Note: you do not need to do this to sort the data (see :ref:`sorting-by-channel-group`). .. code-block:: python @@ -145,14 +145,16 @@ Sorting a Recording by Channel Group We can also sort a recording for each channel group separately. It is not necessary to preprocess a recording by channel group in order to sort by channel group. -There are two ways to sort a recording by channel group. First, we can simply pass the output from -our preprocessing-by-group method above. Second, for more control, we can loop over the recordings +There are two ways to sort a recording by channel group. First, we can pass a dictionary to the +``run_sorter`` function. Since the preprocessing-by-group method above returns a dict, we can +simply pass this output. Alternatively, for more control, we can loop over the recordings ourselves. -**Option 1 : Automatic splitting** +**Option 1 : Automatic splitting (Recommended)** -Simply pass the split recording to the `run_sorter` function, as if it was a non-split recording. -This will return a dict of sortings, with the keys corresponding to the groups. +Simply pass the split recording to the ``run_sorter`` function, as if it was a non-split recording. +This will return a dict of sortings, with the same keys as the dict of recordings that were +passed to ``run_sorter``. .. code-block:: python @@ -162,10 +164,10 @@ This will return a dict of sortings, with the keys corresponding to the groups. # do preprocessing if needed pp_recording = spre.bandpass_filter(split_recording) - dict_of_sortings = run_sorter( - sorter_name='kilosort2', + dict_of_sortings = run_sorter( + sorter_name='kilosort4', recording=pp_recording, - working_folder='working_path' + folder='my_kilosort4_sorting' ) @@ -190,3 +192,37 @@ to different groups. folder=f"folder_KS2_group{group}" ) sortings[group] = sorting + + +Creating a SortingAnalyzer by Channel Group +------------------------------------------- + +The code above generates a dictionary of recording objects and a dictionary of sorting objects. +When making a :ref:`SortingAnalyzer `, we can pass these dictionaries and +a single analyzer will be created, with the recordings and sortings appropriately aggregated. + +The dictionary of recordings and dictionary of sortings must have the same keys. E.g. if you +use ``split_by("group")``, the keys of your dict of recordings will be the values of the ``group`` +property of the recording. Then the dict of sortings should also have these keys. +Note that if you use the internal functions, like we do in the code-block below, you don't need to +keep track of keys yourself. SpikeInterface will do this for you automatically. 
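+
+If you do keep track of the keys yourself, a minimal hand-rolled sketch might look like the
+following (a hypothetical illustration, assuming ``preprocessed_recording`` has a ``group``
+property taking the values ``0`` and ``1``, and that the ``mountainsort5`` sorter is installed);
+the automatic equivalent follows below:
+
+.. code-block:: python
+
+    # build the two dicts by hand, taking care to use the same keys in both
+    group_labels = preprocessed_recording.get_property("group")
+    dict_of_recordings = {
+        group: preprocessed_recording.select_channels(preprocessed_recording.channel_ids[group_labels == group])
+        for group in (0, 1)
+    }
+    dict_of_sortings = {
+        group: run_sorter(sorter_name="mountainsort5", recording=rec, folder=f"sorting_group_{group}")
+        for group, rec in dict_of_recordings.items()
+    }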
+
+The code for creating a ``SortingAnalyzer`` from dicts of recordings and sortings is very similar to that for
+creating a sorting analyzer from a single recording and sorting:
+
+.. code-block:: python
+
+    dict_of_recordings = preprocessed_recording.split_by("group")
+    dict_of_sortings = run_sorter(sorter_name="mountainsort5", recording=dict_of_recordings)
+
+    analyzer = create_sorting_analyzer(sorting=dict_of_sortings, recording=dict_of_recordings)
+
+
+The code above creates a *single* sorting analyzer called :code:`analyzer`. You can select the units
+from one of the groups as follows:
+
+.. code-block:: python
+
+    aggregation_keys = analyzer.get_sorting_property("aggregation_key")
+    unit_ids_group_0 = analyzer.unit_ids[aggregation_keys == 0]
+    group_0_analyzer = analyzer.select_units(unit_ids=unit_ids_group_0)
diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst
index fd2423a26e..4045dd4674 100644
--- a/doc/modules/sorters.rst
+++ b/doc/modules/sorters.rst
@@ -338,9 +338,11 @@ Running spike sorting by group is indeed a very common need.
 
 A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself
 into a dictionary of sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
-So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
-The :py:func:`~spikeinterface.sorters.run_sorter` method can also accept the dictionary which is returned
+The :py:func:`~spikeinterface.sorters.run_sorter` function can accept the dictionary which is returned
 by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings.
+In turn, these can be fed directly to :py:func:`~spikeinterface.core.create_sorting_analyzer` to make
+a SortingAnalyzer. For more control, you can loop over the dictionary returned by :py:meth:`~spikeinterface.core.BaseRecording.split_by`
+and sequentially run spike sorting on these sub-recordings.
 
 In this example, we create a 16-channel recording with 4 tetrodes:
 
@@ -396,6 +398,10 @@ In this example, we create a 16-channel recording with 4 tetrodes:
         sortings[group] = sorting
 
 
+Note: you can feed the dict of sortings and dict of recordings directly to :code:`create_sorting_analyzer` to make
+a SortingAnalyzer from the split data: :ref:`read more `.
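+
+As a minimal, illustrative sketch of that end-to-end flow (assuming the tetrode recording above is
+held in a variable named ``recording`` with its ``"group"`` property set, and that the
+``mountainsort5`` sorter is installed):
+
+.. code-block:: python
+
+    from spikeinterface.core import create_sorting_analyzer
+    from spikeinterface.sorters import run_sorter
+
+    split_recording = recording.split_by("group")
+    dict_of_sortings = run_sorter(sorter_name="mountainsort5", recording=split_recording)
+    analyzer = create_sorting_analyzer(sorting=dict_of_sortings, recording=split_recording)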
+ + Handling multi-segment recordings --------------------------------- diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 6be1766dbc..56f558bce7 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -514,7 +514,7 @@ def split_by(self, property="group", outputs="dict"): recordings = {} for value in np.unique(values).tolist(): (inds,) = np.nonzero(values == value) - new_channel_ids = self.get_channel_ids()[inds] + new_channel_ids = self.channel_ids[inds] subrec = self.select_channels(new_channel_ids) subrec.set_annotation("split_by_property", value=property) if outputs == "list": diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8a1fa9cf1b..98159fb646 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -626,6 +626,47 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def split_by(self, property="group", outputs="dict"): + """ + Splits object based on a certain property (e.g. "group") + + Parameters + ---------- + property : str, default: "group" + The property to use to split the object, default: "group" + outputs : "dict" | "list", default: "dict" + Whether to return a dict or a list + + Returns + ------- + dict or list + A dict or list with grouped objects based on property + + Raises + ------ + ValueError + Raised when property is not present + """ + assert outputs in ("list", "dict") + values = self.get_property(property) + if values is None: + raise ValueError(f"property {property} is not set") + + if outputs == "list": + sortings = [] + elif outputs == "dict": + sortings = {} + for value in np.unique(values).tolist(): + (inds,) = np.nonzero(values == value) + new_unit_ids = self.unit_ids[inds] + subsort = self.select_units(new_unit_ids) + subsort.set_annotation("split_by_property", value=property) + if outputs == "list": + sortings.append(subsort) + elif outputs == "dict": + sortings[value] = subsort + return sortings + def time_to_sample_index(self, time, segment_index=0): """ Transform time in seconds into sample index diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 9116437775..996772fd33 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -37,10 +37,8 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record self._recordings = recording_list - splitting_known = self._is_splitting_known() - if not splitting_known: - for group_id, recording in zip(recording_ids, recording_list): - recording.set_property("group", [group_id] * recording.get_num_channels()) + for group_id, recording in zip(recording_ids, recording_list): + recording.set_property("aggregation_key", [group_id] * recording.get_num_channels()) self._perform_consistency_checks() sampling_frequency = recording_list[0].get_sampling_frequency() @@ -140,25 +138,6 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record def recordings(self): return self._recordings - def _is_splitting_known(self): - - # If we have the `split_by_property` annotation, we know how the recording was split - if self._recordings[0].get_annotation("split_by_property") is not None: - return True - - # Check 
if all 'group' properties are equal to 0 - recording_groups = [] - for recording in self._recordings: - if (group_labels := recording.get_property("group")) is not None: - recording_groups.extend(group_labels) - else: - recording_groups.extend([0]) - # If so, we don't know the splitting - if np.all(np.unique(recording_groups) == np.array([0])): - return False - else: - return True - def _perform_consistency_checks(self): # Check for consistent sampling frequency across recordings diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index f8d11dd157..8f82754ac7 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -21,8 +21,7 @@ import spikeinterface -from .baserecording import BaseRecording -from .basesorting import BaseSorting +from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match from .core_tools import ( @@ -54,6 +53,7 @@ def create_sorting_analyzer( folder=None, sparse=True, sparsity=None, + set_sparsity_by_dict_key=False, return_scaled=None, return_in_uV=True, overwrite=False, @@ -71,10 +71,10 @@ def create_sorting_analyzer( Parameters ---------- - sorting : Sorting - The sorting object - recording : Recording - The recording object + sorting : Sorting | dict + The sorting object, or a dict of them + recording : Recording | dict + The recording object, or a dict of them folder : str or Path or None, default: None The folder where analyzer is cached format : "memory | "binary_folder" | "zarr", default: "memory" @@ -88,6 +88,9 @@ def create_sorting_analyzer( You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs) sparsity : ChannelSparsity or None, default: None The sparsity used to compute exensions. If this is given, `sparse` is ignored. + set_sparsity_by_dict_key : bool, default: False + If True and passing recording and sorting dicts, will set the sparsity based on the dict keys, + and other `sparsity_kwargs` are overwritten. If False, use other sparsity settings. return_scaled : bool | None, default: None DEPRECATED. Use return_in_uV instead. All extensions that play with traces will use this global return_in_uV : "waveforms", "noise_levels", "templates". @@ -139,6 +142,34 @@ def create_sorting_analyzer( In some situation, sparsity is not needed, so to make it fast creation, you need to turn sparsity off (or give external sparsity) like this. """ + + if isinstance(sorting, dict) and isinstance(recording, dict): + + if sorting.keys() != recording.keys(): + raise ValueError( + f"Keys of `sorting`, {sorting.keys()}, and `recording`, {recording.keys()}, dicts do not match." 
+ ) + + aggregated_recording = aggregate_channels(recording) + aggregated_sorting = aggregate_units(sorting) + + if set_sparsity_by_dict_key: + sparsity_kwargs = {"method": "by_property", "by_property": "aggregation_key"} + + return create_sorting_analyzer( + sorting=aggregated_sorting, + recording=aggregated_recording, + format=format, + folder=folder, + sparse=sparse, + sparsity=sparsity, + return_scaled=return_scaled, + return_in_uV=return_in_uV, + overwrite=overwrite, + backend_options=backend_options, + **sparsity_kwargs, + ) + if format != "memory": if format == "zarr": if not is_path_remote(folder): diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index a9bb51dfed..119ab1d598 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -2,6 +2,7 @@ from spikeinterface.core import aggregate_channels from spikeinterface.core import generate_recording +from spikeinterface.core.testing import check_recordings_equal def test_channelsaggregationrecording(): @@ -114,6 +115,39 @@ def test_split_then_aggreate_preserve_user_property(): assert np.all(old_properties_ids_dict == new_properties_ids_dict) +def test_aggregation_split_by_and_manual(): + """ + We can either split recordings automatically using "split_by" or manually by + constructing dictionaries. This test checks the two are equivalent. We skip + the annoations check since the "split_by" also saves an annotation to save what + property we split by. + """ + + rec1 = generate_recording(num_channels=6) + rec1_channel_ids = rec1.get_channel_ids() + rec1.set_property(key="brain_area", values=["a", "a", "b", "a", "b", "a"]) + + split_recs = rec1.split_by("brain_area") + + aggregated_rec = aggregate_channels(split_recs) + + rec_a_channel_ids = aggregated_rec.channel_ids[aggregated_rec.get_property("brain_area") == "a"] + rec_b_channel_ids = aggregated_rec.channel_ids[aggregated_rec.get_property("brain_area") == "b"] + + assert np.all(rec_a_channel_ids == split_recs["a"].channel_ids) + assert np.all(rec_b_channel_ids == split_recs["b"].channel_ids) + + split_recs_manual = { + "a": rec1.select_channels(channel_ids=rec1_channel_ids[rec1.get_property("brain_area") == "a"]), + "b": rec1.select_channels(channel_ids=rec1_channel_ids[rec1.get_property("brain_area") == "b"]), + } + + aggregated_rec_manual = aggregate_channels(split_recs_manual) + + assert np.all(aggregated_rec_manual.get_property("aggregation_key") == ["a", "a", "a", "a", "b", "b"]) + check_recordings_equal(aggregated_rec, aggregated_rec_manual, check_annotations=False, check_properties=True) + + def test_channel_aggregation_preserve_ids(): recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check @@ -132,9 +166,9 @@ def test_aggregation_labeling_for_lists(): recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) - # If we don't label at all, aggregation will add a 'group' label + # If we don't label at all, aggregation will add a 'aggregation_key' label aggregated_recording = aggregate_channels([recording1, recording2]) - group_property = aggregated_recording.get_property("group") + group_property = aggregated_recording.get_property("aggregation_key") assert np.all(group_property == [0, 0, 0, 0, 1, 1]) # If we have different group 
labels, these should be respected @@ -161,9 +195,9 @@ def test_aggretion_labelling_for_dicts(): recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False) - # If we don't label at all, aggregation will add a 'group' label based on the dict keys + # If we don't label at all, aggregation will add a 'aggregation_key' label based on the dict keys aggregated_recording = aggregate_channels({0: recording1, "cat": recording2}) - group_property = aggregated_recording.get_property("group") + group_property = aggregated_recording.get_property("aggregation_key") assert np.all(group_property == [0, 0, 0, 0, "cat", "cat"]) # If we have different group labels, these should be respected diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index ce248f00f6..f645f3416f 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -201,6 +201,50 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): assert "number" in sorting_analyzer.sorting.get_property_keys() +def test_create_by_dict(): + """ + Generates a recording and sorting which are split into dicts and fed to create_sorting_analyzer. + Interally, this aggregates the dicts of recordings and sortings. This test checks that the + unit structure is maintained from the dicts to the analyzer. Then checks that the function + fails if the dict keys are different for the recordings and the sortings. + """ + + rec, sort = generate_ground_truth_recording(num_channels=6) + + rec.set_property(key="group", values=[1, 2, 1, 1, 2, 2]) + sort.set_property(key="group", values=[2, 2, 2, 1, 2, 2, 2, 1, 2, 1]) + + unit_ids = sort.unit_ids + split_sort = sort.split_by("group") + split_rec = rec.split_by("group") + analyzer = create_sorting_analyzer(split_sort, split_rec) + analyzer_unit_ids = analyzer.unit_ids + + assert set(analyzer.unit_ids) == set(sort.unit_ids) + assert np.all(analyzer_unit_ids[analyzer.get_sorting_property("group") == 1] == split_sort[1].unit_ids) + assert np.all(analyzer_unit_ids[analyzer.get_sorting_property("group") == 2] == split_sort[2].unit_ids) + + assert np.all(sort.get_unit_spike_train(unit_id="5") == analyzer.sorting.get_unit_spike_train(unit_id="5")) + + # make a dict of sortings with keys which don't match the recordings keys + split_sort_bad_keys = { + bad_key: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == key]) + for bad_key, key in zip([3, 4], [1, 2]) + } + + with pytest.raises(ValueError): + analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group")) + + # make a dict of sortings, in a different order than the recording. 
This should + # still work + split_sort_different_order = { + 2: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 2]), + 1: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 1]), + } + combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group")) + assert np.all(sort.get_unit_spike_train(unit_id="5") == combined_analyzer.sorting.get_unit_spike_train(unit_id="5")) + + def test_load_without_runtime_info(tmp_path, dataset): import zarr diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py index fadac094aa..c04c66d60c 100644 --- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py +++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py @@ -149,6 +149,21 @@ def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] +def test_aggregation_of_dicts(): + """ + Tests `aggregate_units` when the input is a dict of sortings. Checks that + the unit structure is maintained by the aggregation. + """ + + sorting1 = generate_sorting(num_units=4) + sorting2 = generate_sorting(num_units=2) + + aggregated_sorting = aggregate_units({"a": sorting1, "b": sorting2}) + + assert aggregated_sorting.get_num_units() == 6 + assert np.all(aggregated_sorting.get_property(key="aggregation_key") == np.array(["a", "a", "a", "a", "b", "b"])) + + def test_sampling_frequency_max_diff(): """Test that the sampling frequency max diff is respected.""" sorting1 = generate_sorting(sampling_frequency=30000, num_units=3) diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 838660df46..404bae5924 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -16,7 +16,7 @@ class UnitsAggregationSorting(BaseSorting): Parameters ---------- - sorting_list: list + sorting_list: list | dict List of BaseSorting objects to aggregate renamed_unit_ids: array-like If given, unit ids are renamed as provided. If None, unit ids are sequential integers. @@ -32,6 +32,11 @@ class UnitsAggregationSorting(BaseSorting): def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_diff=0): unit_map = {} + sorting_keys = [] + if isinstance(sorting_list, dict): + sorting_keys = list(sorting_list.keys()) + sorting_list = list(sorting_list.values()) + num_all_units = sum([sort.get_num_units() for sort in sorting_list]) if renamed_unit_ids is not None: assert len(np.unique(renamed_unit_ids)) == num_all_units, ( @@ -122,6 +127,13 @@ def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_d except Exception as ext: warnings.warn(f"Skipping property '{prop_name}' as numpy cannot concatente. Numpy error: {ext}") + # add a label to each unit, with which sorting it came from + if len(sorting_keys) > 0: + aggregation_keys = [] + for sort_key, sort in zip(sorting_keys, sorting_list): + aggregation_keys += [sort_key] * sort.get_num_units() + self.set_property(key="aggregation_key", values=aggregation_keys) + # add segments for i_seg in range(num_segments): parent_segments = [sort._sorting_segments[i_seg] for sort in sorting_list]