def alf2phy(alf_path: Path, target_path=None, use_symlinks: bool = True, s2v=2.34375e-06) -> Path:
    """Convert an ALF spike sorting folder into a phy / KiloSort-style folder.

    Files that only differ by name between the two layouts are linked (or copied)
    into the target folder. The amplitudes, templates and channel files are
    recomputed from the ALF arrays and written as new ``.npy`` files.

    Args:
        alf_path (Path): Directory containing the ALF files. The following files are
            required: ``templates.waveforms.npy``, ``templates.waveformsChannels.npy``,
            ``spikes.clusters.npy``, ``spikes.amps.npy`` and
            ``_kilosort_whitening.matrix.npy``.
        target_path (Path, optional): Directory where the phy files are written.
            If None, a ``_phy`` subdirectory of ``alf_path`` is used. The directory is
            created if it does not exist. Defaults to None.
        use_symlinks (bool, optional): If True, the files listed in ``_FILE_RENAMES``
            (plus ``spikes.samples.npy``) are exposed as symbolic links pointing to the
            resolved ALF files. If False, they are copied. Targets that already exist
            are left untouched in both cases. Defaults to True.
        s2v (float, optional): Factor the ALF spike amplitudes are divided by when
            converting them back to arbitrary units. Defaults to 2.34375e-06.

    Returns:
        Path: The target directory. This function writes ``amplitudes.npy``,
        ``channel_map.npy``, ``templates.npy``, ``templates_ind.npy`` and, when
        ``cluster.metrics.pqt`` exists in ``alf_path``, ``cluster_group.tsv``.

    Raises:
        FileNotFoundError: If one of the required ALF files listed above is missing.
    """
    import shutil  # local import, only needed when use_symlinks is False

    # accept plain strings as well as Path objects
    alf_path = Path(alf_path)
    target_path = alf_path.joinpath('_phy') if target_path is None else Path(target_path)
    target_path.mkdir(parents=True, exist_ok=True)

    # those are the easy files that can be directly copied / linked
    file_renames = _FILE_RENAMES.copy()
    file_renames.append(('spike_times.npy', 'spikes.samples.npy', False))
    for f in file_renames:
        source_file = alf_path.joinpath(f[1])
        target_file = target_path.joinpath(f[0])
        if not source_file.exists():
            continue
        # a dangling link left by a previous run reports exists() == False but would
        # still make symlink_to fail: remove it so that it gets re-created below
        if target_file.is_symlink() and not target_file.exists():
            target_file.unlink()
        if target_file.exists():
            continue
        if use_symlinks:
            # a relative link target is interpreted relative to the directory holding
            # the link, so the source is resolved to an absolute path first
            target_file.symlink_to(source_file.resolve())
        else:
            shutil.copy2(source_file, target_file)

    waveforms = np.load(alf_path.joinpath('templates.waveforms.npy'))
    waveforms_channels = np.load(alf_path.joinpath('templates.waveformsChannels.npy'))
    spikes_clusters = np.load(alf_path.joinpath('spikes.clusters.npy'))
    spikes_amps = np.load(alf_path.joinpath('spikes.amps.npy'))

    # now we do the inverse processing to get to the AU amplitudes
    wm = np.load(alf_path.joinpath('_kilosort_whitening.matrix.npy'))
    nclu = waveforms.shape[0]
    nch = wm.shape[0]

    def get_waveforms_amp(w):
        """Return, for each template, the largest peak-to-peak value (along axis 1) over axis 2."""
        cha = np.max(w, axis=1) - np.min(w, axis=1)
        return np.max(cha, axis=1)

    # each sparse template is multiplied by the rows of the whitening matrix matching its
    # channels: this expands the templates to the full (non-sparse) number of channels
    templates_phy = np.zeros([nclu, waveforms.shape[1], nch], dtype=np.float32)
    for i in range(nclu):
        templates_phy[i] = np.matmul(waveforms[i], wm[waveforms_channels[i], :])

    # the original templates are expected to have unit norm, so we scale by the norm
    # (square root of the sum of squares) of each ALF template
    rms_templates = np.sum(np.sum(waveforms ** 2, axis=1), axis=1) ** 0.5
    templates_phy = templates_phy / rms_templates[:, np.newaxis, np.newaxis]

    template_amps_au = get_waveforms_amp(waveforms) / rms_templates
    spikes_amps_au = spikes_amps / s2v / template_amps_au[spikes_clusters]

    np.save(target_path.joinpath('amplitudes.npy'), spikes_amps_au)
    np.save(target_path.joinpath('channel_map.npy'), np.arange(nch))
    np.save(target_path.joinpath('templates.npy'), templates_phy)
    # the templates are dense here, so every template is defined on channels 0..nch-1:
    # one row of channel indices per template, shape (nclu, nch)
    np.save(target_path.joinpath('templates_ind.npy'), np.tile(np.arange(nch), (nclu, 1)))

    # if we have metrics information, output the ks2_label information
    if alf_path.joinpath('cluster.metrics.pqt').exists():
        import pandas as pd  # optional dependency
        df_cluster_metrics = pd.read_parquet(alf_path.joinpath('cluster.metrics.pqt'))
        # index=False: the file must only hold the cluster_id and label columns, not an
        # extra unnamed column with the dataframe index
        df_cluster_metrics.loc[:, ['cluster_id', 'ks2_label']].to_csv(
            target_path.joinpath('cluster_group.tsv'), sep='\t', index=False)

    return target_path
def test_alf2phy(dataset):
    """Round trip phy -> ALF -> phy and check that the spike amplitudes are recovered."""
    phy_dir = Path(dataset.tmp_dir)
    alf_dir = phy_dir / 'alf'

    # forward conversion: export the phy dataset to ALF
    creator = EphysAlfCreator(TemplateModel(
        dir_path=phy_dir,
        dat_path=dataset.dat_path,
        sample_rate=2000,
        n_channels_dat=dataset.nc,
    ))
    creator.convert(alf_dir)

    # backward conversion: rebuild a phy folder from the ALF export
    # NOTE(review): s2v=1 presumably mirrors the scaling used by the forward export - confirm
    reverse_dir = alf2phy(alf_dir, s2v=1)

    recovered = np.load(reverse_dir.joinpath('amplitudes.npy'))
    expected = np.load(phy_dir.joinpath('amplitudes.npy'))
    np.testing.assert_allclose(recovered, expected, rtol=1e-4)