58 changes: 47 additions & 11 deletions doc/how_to/process_by_channel_group.rst
@@ -1,4 +1,4 @@
.. _recording-by-channel-group:
.. _process_by_group:

Process a recording by channel group
====================================
@@ -101,8 +101,8 @@ to any preprocessing function.
referenced_recording = spre.common_reference(filtered_recording)
good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)

We can then aggregate the recordings back together using the ``aggregate_channels`` function.
Note that we do not need to do this to sort the data (see :ref:`sorting-by-channel-group`).
If needed, we could aggregate the recordings back together using the ``aggregate_channels`` function.
Note: you do not need to do this to sort the data (see :ref:`sorting-by-channel-group`).

.. code-block:: python

@@ -145,14 +145,16 @@ Sorting a Recording by Channel Group
We can also sort a recording for each channel group separately. It is not necessary to preprocess
a recording by channel group in order to sort by channel group.

There are two ways to sort a recording by channel group. First, we can simply pass the output from
our preprocessing-by-group method above. Second, for more control, we can loop over the recordings
There are two ways to sort a recording by channel group. First, we can pass a dictionary to the
``run_sorter`` function. Since the preprocessing-by-group method above returns a dict, we can
simply pass this output. Alternatively, for more control, we can loop over the recordings
ourselves.

**Option 1 : Automatic splitting**
**Option 1 : Automatic splitting (Recommended)**

Simply pass the split recording to the `run_sorter` function, as if it was a non-split recording.
This will return a dict of sortings, with the keys corresponding to the groups.
Simply pass the split recording to the ``run_sorter`` function, as if it was a non-split recording.
This will return a dict of sortings, with the same keys as the dict of recordings that were
passed to ``run_sorter``.

.. code-block:: python

@@ -162,10 +164,10 @@
# do preprocessing if needed
pp_recording = spre.bandpass_filter(split_recording)

dict_of_sortings = run_sorter(
sorter_name='kilosort2',
dict_of_sortings = run_sorter(
sorter_name='kilosort4',
recording=pp_recording,
working_folder='working_path'
folder='my_kilosort4_sorting'
)


@@ -190,3 +192,37 @@ to different groups.
folder=f"folder_KS2_group{group}"
)
sortings[group] = sorting


Creating a SortingAnalyzer by Channel Group
-------------------------------------------

The code above generates a dictionary of recording objects and a dictionary of sorting objects.
When making a :ref:`SortingAnalyzer <modules/core:SortingAnalyzer>`, we can pass these dictionaries and
a single analyzer will be created, with the recordings and sortings appropriately aggregated.

Member:
How? Do keys of the dict need to match up? Will the user know how the sorting dict looks vs the recording dict?

Member:
Maybe we could check that keys are the same in the same order no?

Member:
Yeah we either do that for the users internally or need to explain it here. So I think doing it ourselves is fine. Chris is doing a dict key comparison below but I haven't tested it to see if it checks the order or just the presence of the same keys.

Member Author:
Updated to hopefully make things clearer.

The order of keys doesn't matter: when you aggregate, the link between the recording channels and unit ids is independent of this dict stuff I'm adding. Added a test to check this is true.


The dictionary of recordings and the dictionary of sortings must have the same keys. For example, if you
use ``split_by("group")``, the keys of your dict of recordings will be the values of the ``group``
property of the recording, and the dict of sortings should then have the same keys.
Note that if you use SpikeInterface functions, as in the code blocks below, you don't need to
keep track of the keys yourself: SpikeInterface does this for you automatically.
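
If you build the dictionaries by hand instead, a minimal sketch with matching keys might look like this
(the sub-recordings and sub-sortings here are hypothetical placeholders):

.. code-block:: python

# both dicts use the group values (here 0 and 1) as their keys
dict_of_recordings = {0: recording_group_0, 1: recording_group_1}
dict_of_sortings = {0: sorting_group_0, 1: sorting_group_1}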

The code for creating a ``SortingAnalyzer`` from dicts of recordings and sortings is very similar to that for
creating one from a single recording and sorting:

.. code-block:: python

dict_of_recordings = preprocessed_recording.split_by("group")
dict_of_sortings = run_sorter(sorter_name="mountainsort5", recording=dict_of_recordings)

analyzer = create_sorting_analyzer(sorting=dict_of_sortings, recording=dict_of_recordings)


The code above creates a *single* sorting analyzer called :code:`analyzer`. You can select the units
from one of the "group"s as follows:

.. code-block:: python

aggregation_keys = analyzer.get_sorting_property("aggregation_key")
unit_ids_group_0 = analyzer.unit_ids[aggregation_keys == 0]
group_0_analyzer = analyzer.select_units(unit_ids=unit_ids_group_0)
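
If you would like the sparsity of the analyzer to be set from the dict keys (i.e. per group),
``create_sorting_analyzer`` accepts a ``set_sparsity_by_dict_key`` argument when dicts are passed.
A minimal sketch, re-using the dicts from above:

.. code-block:: python

# sparsity is set based on the dict keys (the "aggregation_key" property)
analyzer = create_sorting_analyzer(
    sorting=dict_of_sortings, recording=dict_of_recordings, set_sparsity_by_dict_key=True
)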
10 changes: 8 additions & 2 deletions doc/modules/sorters.rst
@@ -338,9 +338,11 @@ Running spike sorting by group is indeed a very common need.

A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of
sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
The :py:func:`~spikeinterface.sorters.run_sorter` method can also accept the dictionary which is returned
The :py:func:`~spikeinterface.sorters.run_sorter` method can accept the dictionary which is returned
by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings.
In turn, these can be fed directly to :py:meth:`~spikeinterface.core.create_sorting_analyzer` to make
a SortingAnalyzer. For more control, you can loop over the dictionary returned by :py:meth:`~spikeinterface.core.BaseRecording.split_by`
and sequentially run spike sorting on these sub-recordings.

In this example, we create a 16-channel recording with 4 tetrodes:

@@ -396,6 +398,10 @@ In this example, we create a 16-channel recording with 4 tetrodes:
sortings[group] = sorting


Note: you can feed the dict of sortings and dict of recordings directly to :code:`create_sorting_analyzer` to make
a SortingAnalyzer from the split data: :ref:`read more <process_by_group>`.
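
For example, a minimal sketch (``sortings`` is the dict built in the loop above; ``split_recording_dict``
stands for the dict of sub-recordings returned by ``split_by("group")`` and is an illustrative name):

.. code-block:: python

from spikeinterface.core import create_sorting_analyzer

# the two dicts must share the same keys (here, the group values)
analyzer = create_sorting_analyzer(sorting=sortings, recording=split_recording_dict)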


Handling multi-segment recordings
---------------------------------

2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecordingsnippets.py
@@ -514,7 +514,7 @@ def split_by(self, property="group", outputs="dict"):
recordings = {}
for value in np.unique(values).tolist():
(inds,) = np.nonzero(values == value)
new_channel_ids = self.get_channel_ids()[inds]
new_channel_ids = self.channel_ids[inds]
subrec = self.select_channels(new_channel_ids)
subrec.set_annotation("split_by_property", value=property)
if outputs == "list":
41 changes: 41 additions & 0 deletions src/spikeinterface/core/basesorting.py
@@ -626,6 +626,47 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo

return self.frame_slice(start_frame=start_frame, end_frame=end_frame)

def split_by(self, property="group", outputs="dict"):
"""
Splits object based on a certain property (e.g. "group")

Parameters
----------
property : str, default: "group"
The property to use to split the object, default: "group"
outputs : "dict" | "list", default: "dict"
Whether to return a dict or a list

Returns
-------
dict or list
A dict or list with grouped objects based on property

Raises
------
ValueError
Raised when property is not present
"""
assert outputs in ("list", "dict")
values = self.get_property(property)
if values is None:
raise ValueError(f"property {property} is not set")

if outputs == "list":
sortings = []
elif outputs == "dict":
sortings = {}
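# one sub-sorting per unique value of the property; each sub-sorting is annotated
# with the property used for the split ("split_by_property")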
for value in np.unique(values).tolist():
(inds,) = np.nonzero(values == value)
new_unit_ids = self.unit_ids[inds]
subsort = self.select_units(new_unit_ids)
subsort.set_annotation("split_by_property", value=property)
if outputs == "list":
sortings.append(subsort)
elif outputs == "dict":
sortings[value] = subsort
return sortings

def time_to_sample_index(self, time, segment_index=0):
"""
Transform time in seconds into sample index
25 changes: 2 additions & 23 deletions src/spikeinterface/core/channelsaggregationrecording.py
@@ -37,10 +37,8 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record

self._recordings = recording_list

splitting_known = self._is_splitting_known()
if not splitting_known:
for group_id, recording in zip(recording_ids, recording_list):
recording.set_property("group", [group_id] * recording.get_num_channels())
for group_id, recording in zip(recording_ids, recording_list):
recording.set_property("aggregation_key", [group_id] * recording.get_num_channels())

self._perform_consistency_checks()
sampling_frequency = recording_list[0].get_sampling_frequency()
@@ -140,25 +138,6 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record
def recordings(self):
return self._recordings

def _is_splitting_known(self):

# If we have the `split_by_property` annotation, we know how the recording was split
if self._recordings[0].get_annotation("split_by_property") is not None:
return True

# Check if all 'group' properties are equal to 0
recording_groups = []
for recording in self._recordings:
if (group_labels := recording.get_property("group")) is not None:
recording_groups.extend(group_labels)
else:
recording_groups.extend([0])
# If so, we don't know the splitting
if np.all(np.unique(recording_groups) == np.array([0])):
return False
else:
return True

def _perform_consistency_checks(self):

# Check for consistent sampling frequency across recordings
43 changes: 37 additions & 6 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -21,8 +21,7 @@

import spikeinterface

from .baserecording import BaseRecording
from .basesorting import BaseSorting
from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units

from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match
from .core_tools import (
@@ -54,6 +53,7 @@ def create_sorting_analyzer(
folder=None,
sparse=True,
sparsity=None,
set_sparsity_by_dict_key=False,
return_scaled=None,
return_in_uV=True,
overwrite=False,
@@ -71,10 +71,10 @@

Parameters
----------
sorting : Sorting
The sorting object
recording : Recording
The recording object
sorting : Sorting | dict
The sorting object, or a dict of them
recording : Recording | dict
The recording object, or a dict of them
folder : str or Path or None, default: None
The folder where analyzer is cached
format : "memory | "binary_folder" | "zarr", default: "memory"
@@ -88,6 +88,9 @@
You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs)
sparsity : ChannelSparsity or None, default: None
The sparsity used to compute exensions. If this is given, `sparse` is ignored.
set_sparsity_by_dict_key : bool, default: False

Member:
I support this.
What about setting it to True?

Member Author (@chrishalcrow, Jul 25, 2025):
If it's False by default, some tetrode people will have un-sparsified tetrode bundles (not great, but ok)
If True by default, some silicon people will have badly-sparsified probes (worse!)

So I vote False!

If True and passing recording and sorting dicts, will set the sparsity based on the dict keys,
and other `sparsity_kwargs` are overwritten. If False, use other sparsity settings.
return_scaled : bool | None, default: None
DEPRECATED. Use return_in_uV instead.
All extensions that play with traces will use this global return_in_uV : "waveforms", "noise_levels", "templates".
@@ -139,6 +142,34 @@
In some situations, sparsity is not needed, so to make creation fast, you need to turn
sparsity off (or give external sparsity) like this.
"""

if isinstance(sorting, dict) and isinstance(recording, dict):

if sorting.keys() != recording.keys():
raise ValueError(
f"Keys of `sorting`, {sorting.keys()}, and `recording`, {recording.keys()}, dicts do not match."
)

aggregated_recording = aggregate_channels(recording)
aggregated_sorting = aggregate_units(sorting)
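# the aggregated recording and sorting carry an "aggregation_key" property holding the original dict keys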

if set_sparsity_by_dict_key:
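# sparsity will be computed per dict key, by grouping channels on the "aggregation_key" property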
sparsity_kwargs = {"method": "by_property", "by_property": "aggregation_key"}

return create_sorting_analyzer(
sorting=aggregated_sorting,
recording=aggregated_recording,
format=format,
folder=folder,
sparse=sparse,
sparsity=sparsity,
return_scaled=return_scaled,
return_in_uV=return_in_uV,
overwrite=overwrite,
backend_options=backend_options,
**sparsity_kwargs,
)

if format != "memory":
if format == "zarr":
if not is_path_remote(folder):
42 changes: 38 additions & 4 deletions src/spikeinterface/core/tests/test_channelsaggregationrecording.py
@@ -2,6 +2,7 @@

from spikeinterface.core import aggregate_channels
from spikeinterface.core import generate_recording
from spikeinterface.core.testing import check_recordings_equal


def test_channelsaggregationrecording():
@@ -114,6 +115,39 @@ def test_split_then_aggreate_preserve_user_property():
assert np.all(old_properties_ids_dict == new_properties_ids_dict)


def test_aggregation_split_by_and_manual():
"""
We can either split recordings automatically using "split_by" or manually by
constructing dictionaries. This test checks the two are equivalent. We skip
the annotations check since "split_by" also saves an annotation recording what
property we split by.
"""

rec1 = generate_recording(num_channels=6)
rec1_channel_ids = rec1.get_channel_ids()
rec1.set_property(key="brain_area", values=["a", "a", "b", "a", "b", "a"])

split_recs = rec1.split_by("brain_area")

aggregated_rec = aggregate_channels(split_recs)

rec_a_channel_ids = aggregated_rec.channel_ids[aggregated_rec.get_property("brain_area") == "a"]
rec_b_channel_ids = aggregated_rec.channel_ids[aggregated_rec.get_property("brain_area") == "b"]

assert np.all(rec_a_channel_ids == split_recs["a"].channel_ids)
assert np.all(rec_b_channel_ids == split_recs["b"].channel_ids)

split_recs_manual = {
"a": rec1.select_channels(channel_ids=rec1_channel_ids[rec1.get_property("brain_area") == "a"]),
"b": rec1.select_channels(channel_ids=rec1_channel_ids[rec1.get_property("brain_area") == "b"]),
}

aggregated_rec_manual = aggregate_channels(split_recs_manual)

assert np.all(aggregated_rec_manual.get_property("aggregation_key") == ["a", "a", "a", "a", "b", "b"])
check_recordings_equal(aggregated_rec, aggregated_rec_manual, check_annotations=False, check_properties=True)


def test_channel_aggregation_preserve_ids():

recording1 = generate_recording(num_channels=3, durations=[10], set_probe=False) # To avoid location check
@@ -132,9 +166,9 @@ def test_aggregation_labeling_for_lists():
recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False)
recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False)

# If we don't label at all, aggregation will add a 'group' label
# If we don't label at all, aggregation will add an 'aggregation_key' label
aggregated_recording = aggregate_channels([recording1, recording2])
group_property = aggregated_recording.get_property("group")
group_property = aggregated_recording.get_property("aggregation_key")
assert np.all(group_property == [0, 0, 0, 0, 1, 1])

# If we have different group labels, these should be respected
@@ -161,9 +195,9 @@ def test_aggretion_labelling_for_dicts():
recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False)
recording2 = generate_recording(num_channels=2, durations=[20], set_probe=False)

# If we don't label at all, aggregation will add a 'group' label based on the dict keys
# If we don't label at all, aggregation will add an 'aggregation_key' label based on the dict keys
aggregated_recording = aggregate_channels({0: recording1, "cat": recording2})
group_property = aggregated_recording.get_property("group")
group_property = aggregated_recording.get_property("aggregation_key")
assert np.all(group_property == [0, 0, 0, 0, "cat", "cat"])

# If we have different group labels, these should be respected