Skip to content

Commit fcf2284

Browse files
authored
Merge pull request #3982 from alejoe91/remove-cast-unsigned
Remove auto_cast_uint, cast_unsigned, and modify fix_dtype
2 parents 9522464 + 3d0a98b commit fcf2284

File tree

6 files changed

+37
-127
lines changed

6 files changed

+37
-127
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ def get_traces(
297297
order: "C" | "F" | None = None,
298298
return_scaled: bool | None = None,
299299
return_in_uV: bool = False,
300-
cast_unsigned: bool = False,
301300
) -> np.ndarray:
302301
"""Returns traces from recording.
303302
@@ -320,9 +319,6 @@ def get_traces(
320319
return_in_uV : bool, default: False
321320
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
322321
traces are scaled to uV
323-
cast_unsigned : bool, default: False
324-
If True and the traces are unsigned, they are cast to integer and centered
325-
(an offset of (2**nbits) is subtracted)
326322
327323
Returns
328324
-------
@@ -345,17 +341,6 @@ def get_traces(
345341
assert order in ["C", "F"]
346342
traces = np.asanyarray(traces, order=order)
347343

348-
if cast_unsigned:
349-
dtype = traces.dtype
350-
# if dtype is unsigned, return centered signed signal
351-
if dtype.kind == "u":
352-
itemsize = dtype.itemsize
353-
assert itemsize < 8, "Cannot upcast uint64!"
354-
nbits = dtype.itemsize * 8
355-
# upcast to int with double itemsize
356-
traces = traces.astype(f"int{2 * (dtype.itemsize) * 8}") - 2 ** (nbits - 1)
357-
traces = traces.astype(f"int{dtype.itemsize * 8}")
358-
359344
# Handle deprecated return_scaled parameter
360345
if return_scaled is not None:
361346
warnings.warn(

src/spikeinterface/core/recording_tools.py

Lines changed: 9 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,12 @@ def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0):
5454

5555

5656
# used by write_binary_recording + ChunkRecordingExecutor
57-
def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsigned):
57+
def _init_binary_worker(recording, file_path_dict, dtype, byte_offest):
5858
# create a local dict per worker
5959
worker_ctx = {}
6060
worker_ctx["recording"] = recording
6161
worker_ctx["byte_offset"] = byte_offest
6262
worker_ctx["dtype"] = np.dtype(dtype)
63-
worker_ctx["cast_unsigned"] = cast_unsigned
6463

6564
file_dict = {segment_index: open(file_path, "r+") for segment_index, file_path in file_path_dict.items()}
6665
worker_ctx["file_dict"] = file_dict
@@ -74,7 +73,6 @@ def write_binary_recording(
7473
dtype: np.typing.DTypeLike = None,
7574
add_file_extension: bool = True,
7675
byte_offset: int = 0,
77-
auto_cast_uint: bool = True,
7876
verbose: bool = False,
7977
**job_kwargs,
8078
):
@@ -98,9 +96,6 @@ def write_binary_recording(
9896
byte_offset : int, default: 0
9997
Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data
10098
to an existing file where you wrote a header or other data before.
101-
auto_cast_uint : bool, default: True
102-
If True, unsigned integers are automatically cast to int if the specified dtype is signed
103-
.. deprecated:: 0.103, use the `unsigned_to_signed` function instead.
10499
verbose : bool
105100
This is the verbosity of the ChunkRecordingExecutor
106101
{}
@@ -117,12 +112,6 @@ def write_binary_recording(
117112
file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list]
118113

119114
dtype = dtype if dtype is not None else recording.get_dtype()
120-
if auto_cast_uint:
121-
cast_unsigned = determine_cast_unsigned(recording, dtype)
122-
warning_message = (
123-
"auto_cast_uint is deprecated and will be removed in 0.103. Use the `unsigned_to_signed` function instead."
124-
)
125-
warnings.warn(warning_message, DeprecationWarning, stacklevel=2)
126115

127116
dtype_size_bytes = np.dtype(dtype).itemsize
128117
num_channels = recording.get_num_channels()
@@ -144,7 +133,7 @@ def write_binary_recording(
144133
# use executor (loop or workers)
145134
func = _write_binary_chunk
146135
init_func = _init_binary_worker
147-
init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned)
136+
init_args = (recording, file_path_dict, dtype, byte_offset)
148137
executor = ChunkRecordingExecutor(
149138
recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs
150139
)
@@ -157,7 +146,6 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
157146
recording = worker_ctx["recording"]
158147
dtype = worker_ctx["dtype"]
159148
byte_offset = worker_ctx["byte_offset"]
160-
cast_unsigned = worker_ctx["cast_unsigned"]
161149
file = worker_ctx["file_dict"][segment_index]
162150

163151
num_channels = recording.get_num_channels()
@@ -181,9 +169,7 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
181169
memmap_array = np.ndarray(shape=shape, dtype=dtype, buffer=memmap_obj, offset=start_offset)
182170

183171
# Extract the traces and store them in the memmap array
184-
traces = recording.get_traces(
185-
start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned
186-
)
172+
traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
187173

188174
if traces.dtype != dtype:
189175
traces = traces.astype(dtype, copy=False)
@@ -243,7 +229,7 @@ def write_binary_recording_file_handle(
243229

244230

245231
# used by write_memory_recording
246-
def _init_memory_worker(recording, arrays, shm_names, shapes, dtype, cast_unsigned):
232+
def _init_memory_worker(recording, arrays, shm_names, shapes, dtype):
247233
# create a local dict per worker
248234
worker_ctx = {}
249235
if isinstance(recording, dict):
@@ -269,7 +255,6 @@ def _init_memory_worker(recording, arrays, shm_names, shapes, dtype, cast_unsign
269255
arrays.append(arr)
270256

271257
worker_ctx["arrays"] = arrays
272-
worker_ctx["cast_unsigned"] = cast_unsigned
273258

274259
return worker_ctx
275260

@@ -280,17 +265,14 @@ def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx):
280265
recording = worker_ctx["recording"]
281266
dtype = worker_ctx["dtype"]
282267
arr = worker_ctx["arrays"][segment_index]
283-
cast_unsigned = worker_ctx["cast_unsigned"]
284268

285269
# apply function
286-
traces = recording.get_traces(
287-
start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned
288-
)
270+
traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
289271
traces = traces.astype(dtype, copy=False)
290272
arr[start_frame:end_frame, :] = traces
291273

292274

293-
def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, buffer_type="auto", **job_kwargs):
275+
def write_memory_recording(recording, dtype=None, verbose=False, buffer_type="auto", **job_kwargs):
294276
"""
295277
Save the traces into numpy arrays (memory).
296278
try to use the SharedMemory introduce in py3.8 if n_jobs > 1
@@ -303,8 +285,6 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
303285
Type of the saved data
304286
verbose : bool, default: False
305287
If True, output is verbose (when chunks are used)
306-
auto_cast_uint : bool, default: True
307-
If True, unsigned integers are automatically cast to int if the specified dtype is signed
308288
buffer_type : "auto" | "numpy" | "sharedmem"
309289
{}
310290
@@ -316,10 +296,6 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
316296

317297
if dtype is None:
318298
dtype = recording.get_dtype()
319-
if auto_cast_uint:
320-
cast_unsigned = determine_cast_unsigned(recording, dtype)
321-
else:
322-
cast_unsigned = False
323299

324300
# create sharedmmep
325301
arrays = []
@@ -352,9 +328,9 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
352328
func = _write_memory_chunk
353329
init_func = _init_memory_worker
354330
if n_jobs > 1:
355-
init_args = (recording, None, shm_names, shapes, dtype, cast_unsigned)
331+
init_args = (recording, None, shm_names, shapes, dtype)
356332
else:
357-
init_args = (recording, arrays, None, None, dtype, cast_unsigned)
333+
init_args = (recording, arrays, None, None, dtype)
358334

359335
executor = ChunkRecordingExecutor(
360336
recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs
@@ -379,7 +355,6 @@ def write_to_h5_dataset_format(
379355
chunk_size=None,
380356
chunk_memory="500M",
381357
verbose=False,
382-
auto_cast_uint=True,
383358
return_scaled=None,
384359
return_in_uV=False,
385360
):
@@ -413,8 +388,6 @@ def write_to_h5_dataset_format(
413388
Chunk size in bytes must end with "k", "M" or "G"
414389
verbose : bool, default: False
415390
If True, output is verbose (when chunks are used)
416-
auto_cast_uint : bool, default: True
417-
If True, unsigned integers are automatically cast to int if the specified dtype is signed
418391
return_scaled : bool | None, default: None
419392
DEPRECATED. Use return_in_uV instead.
420393
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
@@ -446,10 +419,6 @@ def write_to_h5_dataset_format(
446419
dtype_file = recording.get_dtype()
447420
else:
448421
dtype_file = dtype
449-
if auto_cast_uint:
450-
cast_unsigned = determine_cast_unsigned(recording, dtype)
451-
else:
452-
cast_unsigned = False
453422

454423
if single_axis:
455424
shape = (num_frames,)
@@ -472,7 +441,7 @@ def write_to_h5_dataset_format(
472441
)
473442
return_in_uV = return_scaled
474443

475-
traces = recording.get_traces(cast_unsigned=cast_unsigned, return_scaled=return_in_uV)
444+
traces = recording.get_traces(return_scaled=return_in_uV)
476445
if dtype is not None:
477446
traces = traces.astype(dtype_file, copy=False)
478447
if time_axis == 1:
@@ -496,7 +465,6 @@ def write_to_h5_dataset_format(
496465
segment_index=segment_index,
497466
start_frame=i * chunk_size,
498467
end_frame=min((i + 1) * chunk_size, num_frames),
499-
cast_unsigned=cast_unsigned,
500468
return_scaled=return_in_uV if return_scaled is None else return_scaled,
501469
)
502470
chunk_frames = traces.shape[0]
@@ -517,16 +485,6 @@ def write_to_h5_dataset_format(
517485
return save_path
518486

519487

520-
def determine_cast_unsigned(recording, dtype):
521-
recording_dtype = np.dtype(recording.get_dtype())
522-
523-
if np.dtype(dtype) != recording_dtype and recording_dtype.kind == "u" and np.dtype(dtype).kind == "i":
524-
cast_unsigned = True
525-
else:
526-
cast_unsigned = False
527-
return cast_unsigned
528-
529-
530488
def get_random_recording_slices(
531489
recording,
532490
method="full_random",

src/spikeinterface/core/tests/test_baserecording.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -253,22 +253,6 @@ def test_BaseRecording(create_cache_folder):
253253
# Verify both parameters produce the same result
254254
assert np.array_equal(traces_float32_old, traces_float32_new)
255255

256-
# test cast unsigned
257-
tr_u = rec_uint16.get_traces(cast_unsigned=False)
258-
assert tr_u.dtype.kind == "u"
259-
tr_i = rec_uint16.get_traces(cast_unsigned=True)
260-
assert tr_i.dtype.kind == "i"
261-
folder = cache_folder / "recording_unsigned"
262-
rec_u = rec_uint16.save(folder=folder)
263-
rec_u.get_dtype() == "uint16"
264-
folder = cache_folder / "recording_signed"
265-
rec_i = rec_uint16.save(folder=folder, dtype="int16")
266-
rec_i.get_dtype() == "int16"
267-
assert np.allclose(
268-
rec_u.get_traces(cast_unsigned=False).astype("float") - (2**15), rec_i.get_traces().astype("float")
269-
)
270-
assert np.allclose(rec_u.get_traces(cast_unsigned=True), rec_i.get_traces().astype("float"))
271-
272256
# test cast with dtype
273257
rec_float32 = rec_int16.astype("float32")
274258
assert rec_float32.get_dtype() == "float32"
@@ -361,16 +345,6 @@ def test_BaseRecording(create_cache_folder):
361345
assert rec2.get_annotation(annotation_name) == rec_zarr2.get_annotation(annotation_name)
362346
assert rec2.get_annotation(annotation_name) == rec_zarr2_loaded.get_annotation(annotation_name)
363347

364-
# test cast unsigned
365-
rec_u = rec_uint16.save(format="zarr", folder=cache_folder / "rec_u")
366-
rec_u.get_dtype() == "uint16"
367-
rec_i = rec_uint16.save(format="zarr", folder=cache_folder / "rec_i", dtype="int16")
368-
rec_i.get_dtype() == "int16"
369-
assert np.allclose(
370-
rec_u.get_traces(cast_unsigned=False).astype("float") - (2**15), rec_i.get_traces().astype("float")
371-
)
372-
assert np.allclose(rec_u.get_traces(cast_unsigned=True), rec_i.get_traces().astype("float"))
373-
374348

375349
def test_interleaved_probegroups():
376350
recording = generate_recording(durations=[1.0], num_channels=16)

src/spikeinterface/core/zarrextractors.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .basesorting import BaseSorting, SpikeVectorSortingSegment, minimum_spike_dtype
1212
from .core_tools import define_function_from_class, check_json
1313
from .job_tools import split_job_kwargs
14-
from .recording_tools import determine_cast_unsigned
1514
from .core_tools import is_path_remote
1615

1716

@@ -446,7 +445,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G
446445

447446
# Recording
448447
def add_recording_to_zarr_group(
449-
recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, dtype=None, **kwargs
448+
recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, dtype=None, **kwargs
450449
):
451450
zarr_kwargs, job_kwargs = split_job_kwargs(kwargs)
452451

@@ -478,7 +477,6 @@ def add_recording_to_zarr_group(
478477
filters=filters_traces,
479478
dtype=dtype,
480479
channel_chunk_size=channel_chunk_size,
481-
auto_cast_uint=auto_cast_uint,
482480
verbose=verbose,
483481
**job_kwargs,
484482
)
@@ -522,7 +520,6 @@ def add_traces_to_zarr(
522520
compressor=None,
523521
filters=None,
524522
verbose=False,
525-
auto_cast_uint=True,
526523
**job_kwargs,
527524
):
528525
"""
@@ -546,8 +543,6 @@ def add_traces_to_zarr(
546543
List of zarr filters
547544
verbose : bool, default: False
548545
If True, output is verbose (when chunks are used)
549-
auto_cast_uint : bool, default: True
550-
If True, unsigned integers are automatically cast to int if the specified dtype is signed
551546
{}
552547
"""
553548
from .job_tools import (
@@ -564,10 +559,6 @@ def add_traces_to_zarr(
564559

565560
if dtype is None:
566561
dtype = recording.get_dtype()
567-
if auto_cast_uint:
568-
cast_unsigned = determine_cast_unsigned(recording, dtype)
569-
else:
570-
cast_unsigned = False
571562

572563
job_kwargs = fix_job_kwargs(job_kwargs)
573564
chunk_size = ensure_chunk_size(recording, **job_kwargs)
@@ -593,23 +584,22 @@ def add_traces_to_zarr(
593584
# use executor (loop or workers)
594585
func = _write_zarr_chunk
595586
init_func = _init_zarr_worker
596-
init_args = (recording, zarr_datasets, dtype, cast_unsigned)
587+
init_args = (recording, zarr_datasets, dtype)
597588
executor = ChunkRecordingExecutor(
598589
recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs
599590
)
600591
executor.run()
601592

602593

603594
# used by write_zarr_recording + ChunkRecordingExecutor
604-
def _init_zarr_worker(recording, zarr_datasets, dtype, cast_unsigned):
595+
def _init_zarr_worker(recording, zarr_datasets, dtype):
605596
import zarr
606597

607598
# create a local dict per worker
608599
worker_ctx = {}
609600
worker_ctx["recording"] = recording
610601
worker_ctx["zarr_datasets"] = zarr_datasets
611602
worker_ctx["dtype"] = np.dtype(dtype)
612-
worker_ctx["cast_unsigned"] = cast_unsigned
613603

614604
return worker_ctx
615605

@@ -622,11 +612,12 @@ def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx):
622612
recording = worker_ctx["recording"]
623613
dtype = worker_ctx["dtype"]
624614
zarr_dataset = worker_ctx["zarr_datasets"][segment_index]
625-
cast_unsigned = worker_ctx["cast_unsigned"]
626615

627616
# apply function
628617
traces = recording.get_traces(
629-
start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned
618+
start_frame=start_frame,
619+
end_frame=end_frame,
620+
segment_index=segment_index,
630621
)
631622
traces = traces.astype(dtype)
632623
zarr_dataset[start_frame:end_frame, :] = traces

0 commit comments

Comments
 (0)