@@ -307,3 +307,96 @@ def compress_spikes_dtypes(self):
307307 for attribute in ['templates' , 'clusters' ]:
308308 fn = next (self .out_path .glob (f'spikes.{ attribute } .*npy' ))
309309 np .save (fn , np .load (fn ).astype (np .uint16 ))
310+
311+
def alf2phy(alf_path: Path, target_path=None, use_symlinks: bool = True, s2v=2.34375e-06) -> Path:
    """Convert an ALF format dataset to phy format for visualization and manual curation.

    Files that only need renaming are symlinked into the target directory. The
    template waveforms are expanded from their sparse ALF representation to one
    column per channel of the whitening matrix, and the spike amplitudes are
    rescaled from voltage to arbitrary units.

    Args:
        alf_path (Path): Directory containing the ALF files. The following files are
            required: templates.waveforms.npy, templates.waveformsChannels.npy,
            spikes.clusters.npy, spikes.amps.npy and _kilosort_whitening.matrix.npy.
            cluster.metrics.pqt is optional.
        target_path (Path, optional): Directory where the phy files are written.
            If None, a '_phy' subdirectory of alf_path is used. Defaults to None.
        use_symlinks (bool, optional): If True, files that only need renaming are
            symlinked into target_path. Copying is not implemented: False raises
            NotImplementedError. Defaults to True.
        s2v (float, optional): Sample-to-voltage conversion factor; spike amplitudes
            are divided by it to go back to arbitrary units. Defaults to 2.34375e-06.

    Returns:
        Path: The target directory, containing amplitudes.npy, channel_map.npy,
            templates.npy, templates_ind.npy, the symlinked files, and
            cluster_group.tsv when cluster.metrics.pqt is available.

    Raises:
        NotImplementedError: If use_symlinks is False.
    """
    if not use_symlinks:
        # fail fast, before anything gets created on disk
        raise NotImplementedError('Copying files is not implemented, use use_symlinks=True')

    alf_path = Path(alf_path)
    target_path = Path(target_path) if target_path is not None else alf_path.joinpath('_phy')
    target_path.mkdir(parents=True, exist_ok=True)

    # those are the easy files that can be directly linked: (phy name, ALF name, ...)
    file_renames = [*_FILE_RENAMES, ('spike_times.npy', 'spikes.samples.npy', False)]
    for phy_name, alf_name, *_ in file_renames:
        source_file = alf_path.joinpath(alf_name)
        target_file = target_path.joinpath(phy_name)
        if not source_file.exists() or target_file.exists():
            continue
        # a relative link target is resolved from the directory holding the link,
        # so the source has to be absolute for the link to stay valid
        target_file.symlink_to(source_file.absolute())

    templates = {
        'waveforms': np.load(alf_path.joinpath('templates.waveforms.npy')),
        'waveformsChannels': np.load(alf_path.joinpath('templates.waveformsChannels.npy')),
    }
    spikes = {
        'clusters': np.load(alf_path.joinpath('spikes.clusters.npy')),
        'amps': np.load(alf_path.joinpath('spikes.amps.npy')),
    }

    # now we do the inverse processing to get to the AU amplitudes
    wm = np.load(alf_path.joinpath('_kilosort_whitening.matrix.npy'))
    nclu, nsamples = templates['waveforms'].shape[:2]
    nch = wm.shape[0]

    def get_waveforms_amp(waveforms):
        """Return, for each template, the largest peak-to-peak amplitude over its channels."""
        peak_to_peak = np.max(waveforms, axis=1) - np.min(waveforms, axis=1)
        return np.max(peak_to_peak, axis=1)

    # multiplying by the whitening matrix rows of each template's channels expands
    # the sparse templates to the full set of nch channels
    templates_phy = np.zeros([nclu, nsamples, nch], dtype=np.float32)
    for i in range(nclu):
        templates_phy[i] = np.matmul(templates['waveforms'][i], wm[templates['waveformsChannels'][i], :])

    # the original templates have a rms of 1.0, so here we just need to normalize by rms
    rms_templates = np.sum(np.sum(templates['waveforms'] ** 2, axis=1), axis=1) ** 0.5
    templates_phy = templates_phy / rms_templates[:, np.newaxis, np.newaxis]

    template_amps_au = get_waveforms_amp(templates['waveforms']) / rms_templates
    spikes_amps_au = spikes['amps'] / s2v / template_amps_au[spikes['clusters']]

    np.save(target_path.joinpath('amplitudes.npy'), spikes_amps_au)
    np.save(target_path.joinpath('channel_map.npy'), np.arange(nch))
    np.save(target_path.joinpath('templates.npy'), templates_phy)
    # templates are dense here, so each template is defined on channels 0..nch-1:
    # one row of channel indices per template, shape (nclu, nch)
    np.save(target_path.joinpath('templates_ind.npy'), np.tile(np.arange(nch)[np.newaxis, :], reps=[nclu, 1]))

    # if we have metrics information, output the ks2_label information
    metrics_file = alf_path.joinpath('cluster.metrics.pqt')
    if metrics_file.exists():
        import pandas as pd  # optional dependency
        df_cluster_metrics = pd.read_parquet(metrics_file)
        # index=False: only the cluster_id and ks2_label columns belong in the file
        df_cluster_metrics.loc[:, ['cluster_id', 'ks2_label']].to_csv(
            target_path.joinpath('cluster_group.tsv'), sep='\t', index=False)

    return target_path
0 commit comments