Skip to content

Commit 705d194

Browse files
authored
Merge pull request #4036 from samuelgarcia/templates_scaled
Templates.is_scaled > Templates.is_in_uV
2 parents 076d781 + 3ca989d commit 705d194

File tree

18 files changed

+59
-55
lines changed

18 files changed

+59
-55
lines changed

src/spikeinterface/benchmark/tests/common_benchmark_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret
7474
channel_ids=recording.channel_ids,
7575
unit_ids=gt_sorting.unit_ids,
7676
probe=recording.get_probe(),
77-
is_scaled=return_scaled,
77+
is_in_uV=return_scaled,
7878
)
7979
return gt_templates
8080

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
698698
channel_ids=self.sorting_analyzer.channel_ids,
699699
unit_ids=unit_ids,
700700
probe=self.sorting_analyzer.get_probe(),
701-
is_scaled=self.sorting_analyzer.return_in_uV,
701+
is_in_uV=self.sorting_analyzer.return_in_uV,
702702
)
703703
else:
704704
raise ValueError("`outputs` must be 'numpy' or 'Templates'")

src/spikeinterface/core/sparsity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def from_snr(
439439
return_scaled = templates_or_sorting_analyzer.return_scaled
440440
elif isinstance(templates_or_sorting_analyzer, Templates):
441441
assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels"
442-
return_scaled = templates_or_sorting_analyzer.is_scaled
442+
return_scaled = templates_or_sorting_analyzer.is_in_uV
443443

444444
mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")
445445

@@ -491,9 +491,9 @@ def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode
491491
"You can set `return_scaled=True` when computing the templates."
492492
)
493493
elif isinstance(templates_or_sorting_analyzer, Templates):
494-
assert templates_or_sorting_analyzer.is_scaled, (
494+
assert templates_or_sorting_analyzer.is_in_uV, (
495495
"To compute sparsity from amplitude you need to have scaled templates. "
496-
"You can set `is_scaled=True` when creating the Templates object."
496+
"You can set `is_in_uV=True` when creating the Templates object."
497497
)
498498

499499
mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")

src/spikeinterface/core/template.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class Templates:
3131
Array of unit IDs. If `None`, defaults to an array of increasing integers.
3232
probe: Probe, default: None
3333
A `probeinterface.Probe` object
34-
is_scaled : bool, optional default: True
34+
is_in_uV : bool, optional default: True
3535
If True, it means that the templates are in uV, otherwise they are in raw ADC values.
3636
check_for_consistent_sparsity : bool, optional default: None
3737
When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
@@ -61,7 +61,7 @@ class Templates:
6161
templates_array: np.ndarray
6262
sampling_frequency: float
6363
nbefore: int
64-
is_scaled: bool = True
64+
is_in_uV: bool = True
6565

6666
sparsity_mask: np.ndarray = None
6767
channel_ids: np.ndarray = None
@@ -206,7 +206,7 @@ def to_sparse(self, sparsity):
206206
unit_ids=self.unit_ids,
207207
probe=self.probe,
208208
check_for_consistent_sparsity=self.check_for_consistent_sparsity,
209-
is_scaled=self.is_scaled,
209+
is_in_uV=self.is_in_uV,
210210
)
211211

212212
def get_one_template_dense(self, unit_index):
@@ -257,7 +257,7 @@ def to_dict(self):
257257
"unit_ids": self.unit_ids,
258258
"sampling_frequency": self.sampling_frequency,
259259
"nbefore": self.nbefore,
260-
"is_scaled": self.is_scaled,
260+
"is_in_uV": self.is_in_uV,
261261
"probe": self.probe.to_dict() if self.probe is not None else None,
262262
}
263263

@@ -270,7 +270,7 @@ def from_dict(cls, data):
270270
unit_ids=np.asarray(data["unit_ids"]),
271271
sampling_frequency=data["sampling_frequency"],
272272
nbefore=data["nbefore"],
273-
is_scaled=data["is_scaled"],
273+
is_in_uV=data["is_in_uV"],
274274
probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
275275
)
276276

@@ -304,7 +304,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:
304304

305305
zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
306306
zarr_group.attrs["nbefore"] = self.nbefore
307-
zarr_group.attrs["is_scaled"] = self.is_scaled
307+
zarr_group.attrs["is_in_uV"] = self.is_in_uV
308308

309309
if self.sparsity_mask is not None:
310310
zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
@@ -361,8 +361,12 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
361361
sampling_frequency = zarr_group.attrs["sampling_frequency"]
362362
nbefore = zarr_group.attrs["nbefore"]
363363

364-
# TODO: Consider eliminating the True and make it required
365-
is_scaled = zarr_group.attrs.get("is_scaled", True)
364+
if "is_scaled" in zarr_group.attrs:
365+
# prior to 0.103.0 "is_in_uV" was named "is_scaled", so for backward compatibility:
366+
is_in_uV = zarr_group.attrs["is_scaled"]
367+
else:
368+
# TODO: Consider eliminating the True and make it required
369+
is_in_uV = zarr_group.attrs.get("is_in_uV", True)
366370

367371
sparsity_mask = None
368372
if "sparsity_mask" in zarr_group:
@@ -380,7 +384,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
380384
channel_ids=channel_ids,
381385
unit_ids=unit_ids,
382386
probe=probe,
383-
is_scaled=is_scaled,
387+
is_in_uV=is_in_uV,
384388
)
385389

386390
@staticmethod

src/spikeinterface/core/template_tools.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc
2222
The dense templates (num_units, num_samples, num_channels)
2323
"""
2424
if isinstance(one_object, Templates):
25-
if return_scaled != one_object.is_scaled:
25+
if return_scaled != one_object.is_in_uV:
2626
raise ValueError(
2727
f"get_dense_templates_array: return_scaled={return_scaled} is not possible Templates has the reverse"
2828
)
@@ -165,15 +165,15 @@ def get_template_extremum_channel(
165165
channel_ids = templates_or_sorting_analyzer.channel_ids
166166

167167
# if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise
168-
# we use the Templates is_scaled
168+
# we use the Templates is_in_uV
169169
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
170170
# For backward compatibility
171171
if hasattr(templates_or_sorting_analyzer, "return_scaled"):
172172
return_scaled = templates_or_sorting_analyzer.return_scaled
173173
else:
174174
return_scaled = templates_or_sorting_analyzer.return_in_uV
175175
else:
176-
return_scaled = templates_or_sorting_analyzer.is_scaled
176+
return_scaled = templates_or_sorting_analyzer.is_in_uV
177177

178178
peak_values = get_template_amplitudes(
179179
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled
@@ -218,15 +218,15 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak
218218
shifts = {}
219219

220220
# We need to use the SortingAnalyzer return_scaled
221-
# We need to use the Templates is_scaled
221+
# We need to use the Templates is_in_uV
222222
if isinstance(templates_or_sorting_analyzer, SortingAnalyzer):
223223
# For backward compatibility
224224
if hasattr(templates_or_sorting_analyzer, "return_scaled"):
225225
return_scaled = templates_or_sorting_analyzer.return_scaled
226226
else:
227227
return_scaled = templates_or_sorting_analyzer.return_in_uV
228228
else:
229-
return_scaled = templates_or_sorting_analyzer.is_scaled
229+
return_scaled = templates_or_sorting_analyzer.is_in_uV
230230

231231
templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled)
232232

@@ -291,7 +291,7 @@ def get_template_extremum_amplitude(
291291
else:
292292
return_scaled = templates_or_sorting_analyzer.return_in_uV
293293
else:
294-
return_scaled = templates_or_sorting_analyzer.is_scaled
294+
return_scaled = templates_or_sorting_analyzer.is_in_uV
295295

296296
extremum_amplitudes = get_template_amplitudes(
297297
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_scaled=return_scaled, abs_value=abs_value

src/spikeinterface/core/tests/test_template_class.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from probeinterface import generate_multi_columns_probe
88

99

10-
def generate_test_template(template_type, is_scaled=True) -> Templates:
10+
def generate_test_template(template_type, is_in_uV=True) -> Templates:
1111
num_units = 3
1212
num_samples = 5
1313
num_channels = 4
@@ -28,7 +28,7 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
2828
probe=probe,
2929
unit_ids=unit_ids,
3030
channel_ids=channel_ids,
31-
is_scaled=is_scaled,
31+
is_in_uV=is_in_uV,
3232
)
3333
elif template_type == "sparse": # sparse with sparse templates
3434
sparsity_mask = np.array(
@@ -53,7 +53,7 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
5353
sampling_frequency=sampling_frequency,
5454
nbefore=nbefore,
5555
probe=probe,
56-
is_scaled=is_scaled,
56+
is_in_uV=is_in_uV,
5757
unit_ids=unit_ids,
5858
channel_ids=channel_ids,
5959
)
@@ -68,16 +68,16 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
6868
sampling_frequency=sampling_frequency,
6969
nbefore=nbefore,
7070
probe=probe,
71-
is_scaled=is_scaled,
71+
is_in_uV=is_in_uV,
7272
unit_ids=unit_ids,
7373
channel_ids=channel_ids,
7474
)
7575

7676

77-
@pytest.mark.parametrize("is_scaled", [True, False])
77+
@pytest.mark.parametrize("is_in_uV", [True, False])
7878
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
79-
def test_pickle_serialization(template_type, is_scaled, tmp_path):
80-
template = generate_test_template(template_type, is_scaled)
79+
def test_pickle_serialization(template_type, is_in_uV, tmp_path):
80+
template = generate_test_template(template_type, is_in_uV)
8181

8282
# Dump to pickle
8383
pkl_path = tmp_path / "templates.pkl"
@@ -91,21 +91,21 @@ def test_pickle_serialization(template_type, is_scaled, tmp_path):
9191
assert template == template_reloaded
9292

9393

94-
@pytest.mark.parametrize("is_scaled", [True, False])
94+
@pytest.mark.parametrize("is_in_uV", [True, False])
9595
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
96-
def test_json_serialization(template_type, is_scaled):
97-
template = generate_test_template(template_type, is_scaled)
96+
def test_json_serialization(template_type, is_in_uV):
97+
template = generate_test_template(template_type, is_in_uV)
9898

9999
json_str = template.to_json()
100100
template_reloaded_from_json = Templates.from_json(json_str)
101101

102102
assert template == template_reloaded_from_json
103103

104104

105-
@pytest.mark.parametrize("is_scaled", [True, False])
105+
@pytest.mark.parametrize("is_in_uV", [True, False])
106106
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
107-
def test_get_dense_templates(template_type, is_scaled):
108-
template = generate_test_template(template_type, is_scaled)
107+
def test_get_dense_templates(template_type, is_in_uV):
108+
template = generate_test_template(template_type, is_in_uV)
109109
dense_templates = template.get_dense_templates()
110110
assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels)
111111

@@ -115,10 +115,10 @@ def test_initialization_fail_with_dense_templates():
115115
template = generate_test_template(template_type="sparse_with_dense_templates")
116116

117117

118-
@pytest.mark.parametrize("is_scaled", [True, False])
118+
@pytest.mark.parametrize("is_in_uV", [True, False])
119119
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
120-
def test_save_and_load_zarr(template_type, is_scaled, tmp_path):
121-
original_template = generate_test_template(template_type, is_scaled)
120+
def test_save_and_load_zarr(template_type, is_in_uV, tmp_path):
121+
original_template = generate_test_template(template_type, is_in_uV)
122122

123123
zarr_path = tmp_path / "templates.zarr"
124124
original_template.to_zarr(str(zarr_path))
@@ -129,10 +129,10 @@ def test_save_and_load_zarr(template_type, is_scaled, tmp_path):
129129
assert original_template == loaded_template
130130

131131

132-
@pytest.mark.parametrize("is_scaled", [True, False])
132+
@pytest.mark.parametrize("is_in_uV", [True, False])
133133
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
134-
def test_select_units(template_type, is_scaled):
135-
template = generate_test_template(template_type, is_scaled)
134+
def test_select_units(template_type, is_in_uV):
135+
template = generate_test_template(template_type, is_in_uV)
136136
selected_unit_ids = ["unit_a", "unit_c"]
137137
selected_unit_ids_indices = [0, 2]
138138

@@ -149,10 +149,10 @@ def test_select_units(template_type, is_scaled):
149149
assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[selected_unit_ids_indices])
150150

151151

152-
@pytest.mark.parametrize("is_scaled", [True, False])
152+
@pytest.mark.parametrize("is_in_uV", [True, False])
153153
@pytest.mark.parametrize("template_type", ["dense"])
154-
def test_select_channels(template_type, is_scaled):
155-
template = generate_test_template(template_type, is_scaled)
154+
def test_select_channels(template_type, is_in_uV):
155+
template = generate_test_template(template_type, is_in_uV)
156156
selected_channel_ids = ["channel1", "channel3"]
157157
selected_channel_ids_indices = [0, 2]
158158

src/spikeinterface/core/tests/test_template_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
4848
sparsity_mask=None,
4949
channel_ids=sorting_analyzer.channel_ids,
5050
unit_ids=sorting_analyzer.unit_ids,
51-
is_scaled=sorting_analyzer.return_in_uV,
51+
is_in_uV=sorting_analyzer.return_in_uV,
5252
)
5353
return templates
5454

src/spikeinterface/generation/drift_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def from_static_templates(cls, templates: Templates):
166166
nbefore=templates.nbefore,
167167
probe=templates.probe,
168168
sparsity_mask=templates.sparsity_mask,
169-
is_scaled=templates.is_scaled,
169+
is_in_uV=templates.is_in_uV,
170170
unit_ids=templates.unit_ids,
171171
channel_ids=templates.channel_ids,
172172
)

src/spikeinterface/generation/drifting_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def generate_drifting_recording(
459459
sampling_frequency=sampling_frequency,
460460
nbefore=nbefore,
461461
probe=probe,
462-
is_scaled=True,
462+
is_in_uV=True,
463463
)
464464

465465
drifting_templates = DriftingTemplates.from_static_templates(templates)

src/spikeinterface/generation/tests/test_drift_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def make_some_templates():
6565
sampling_frequency=sampling_frequency,
6666
nbefore=nbefore,
6767
probe=probe,
68-
is_scaled=True,
68+
is_in_uV=True,
6969
)
7070

7171
return templates

0 commit comments

Comments
 (0)