@@ -412,6 +412,17 @@ def _load_data(self):
412412 self .n_samples_waveforms = 0
413413 self .n_channels_loc = 0
414414
415+ # Clusters waveforms
416+ if not np .all (self .spike_clusters == self .spike_templates ) and self .sparse_templates .cols is None :
417+ self .merge_map , self .nan_idx = self .get_merge_map ()
418+ self .sparse_clusters = self .cluster_waveforms ()
419+ self .n_clusters = self .spike_clusters .max () + 1
420+ else :
421+ self .merge_map = {}
422+ self .nan_idx = []
423+ self .sparse_clusters = self .sparse_templates
424+ self .n_clusters = self .spike_templates .max () + 1
425+
415426 # Spike waveforms (optional, otherwise fetched from raw data as needed).
416427 self .spike_waveforms = self ._load_spike_waveforms ()
417428
@@ -861,12 +872,12 @@ def _template_n_channels(self, template_id, n_channels):
861872 channel_ids += [- 1 ] * (n_channels - len (channel_ids ))
862873 return channel_ids
863874
864- def _get_template_dense (self , template_id , channel_ids = None , amplitude_threshold = None ):
875+ def _get_template_dense (self , template_id , channel_ids = None , amplitude_threshold = None , unwhiten = True ):
865876 """Return data for one template."""
866877 if not self .sparse_templates :
867878 return
868879 template_w = self .sparse_templates .data [template_id , ...]
869- template = self ._unwhiten (template_w ).astype (np .float32 )
880+ template = self ._unwhiten (template_w ).astype (np .float32 ) if unwhiten else template_w
870881 assert template .ndim == 2
871882 channel_ids_ , amplitude , best_channel = self ._find_best_channels (
872883 template , amplitude_threshold = amplitude_threshold )
@@ -881,7 +892,7 @@ def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold
881892 channel_ids = channel_ids ,
882893 )
883894
884- def _get_template_sparse (self , template_id ):
895+ def _get_template_sparse (self , template_id , unwhiten = True ):
885896 data , cols = self .sparse_templates .data , self .sparse_templates .cols
886897 assert cols is not None
887898 template_w , channel_ids = data [template_id ], cols [template_id ]
@@ -902,7 +913,7 @@ def _get_template_sparse(self, template_id):
902913 channel_ids = channel_ids .astype (np .uint32 )
903914
904915 # Unwhiten.
905- template = self ._unwhiten (template_w , channel_ids = channel_ids )
916+ template = self ._unwhiten (template_w , channel_ids = channel_ids ) if unwhiten else template_w
906917 template = template .astype (np .float32 )
907918 assert template .ndim == 2
908919 assert template .shape [1 ] == len (channel_ids )
@@ -920,17 +931,31 @@ def _get_template_sparse(self, template_id):
920931 )
921932 return out
922933
def get_merge_map(self):
    """Map every cluster id to the template ids its spikes originate from.

    The mapping is derived from `self.spike_clusters` and `self.spike_templates`:
    a template is listed under a cluster as soon as one spike carries both labels.

    :return: tuple `(inverse_mapping_dict, nan_idx)` where
        `inverse_mapping_dict` maps each integer in
        `range(max(spike_clusters) + 1)` to the list of contributing template
        ids, in increasing order (empty list when the cluster has no spike), and
        `nan_idx` is an integer array with the cluster ids whose list is empty.
    """
    spike_clusters = np.asarray(self.spike_clusters)
    spike_templates = np.asarray(self.spike_templates)

    n_clusters = int(np.max(spike_clusters)) + 1
    inverse_mapping_dict = {cluster: [] for cluster in range(n_clusters)}

    # One pass over the distinct (template, cluster) pairs instead of scanning
    # all the spikes once per template. Rows come back sorted by template first,
    # so every per-cluster list ends up in increasing template order.
    pairs = np.unique(np.column_stack((spike_templates, spike_clusters)), axis=0)
    for template, cluster in pairs.tolist():
        inverse_mapping_dict[cluster].append(template)

    # Explicit integer dtype: an empty list would otherwise produce a float64
    # array, which cannot be used as an index.
    nan_idx = np.array(
        [cluster for cluster, templates in inverse_mapping_dict.items() if not templates],
        dtype=int,
    )

    return inverse_mapping_dict, nan_idx
947+
923948 #--------------------------------------------------------------------------
924949 # Data access methods
925950 #--------------------------------------------------------------------------
926951
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
    """Get data about a template.

    Dispatches to the sparse reader when `self.sparse_templates` is set and has
    a `cols` array, and to the dense reader otherwise. `channel_ids` and
    `amplitude_threshold` are only forwarded to the dense reader; `unwhiten`
    is forwarded to both.
    """
    templates = self.sparse_templates
    if not templates or templates.cols is None:
        return self._get_template_dense(
            template_id,
            channel_ids=channel_ids,
            amplitude_threshold=amplitude_threshold,
            unwhiten=unwhiten,
        )
    return self._get_template_sparse(template_id, unwhiten=unwhiten)
934959
935960 def get_waveforms (self , spike_ids , channel_ids = None ):
936961 """Return spike waveforms on specified channels."""
@@ -1047,7 +1072,7 @@ def get_depths(self):
10471072 # take only first component
10481073 features = self .sparse_features .data [ispi , :, 0 ]
10491074 features = np .maximum (features , 0 ) ** 2 # takes only positive values into account
1050- ichannels = self .sparse_features .cols [self .spike_clusters [ispi ]].astype (np .uint32 )
1075+ ichannels = self .sparse_features .cols [self .spike_templates [ispi ]].astype (np .uint32 )
10511076 # features = np.square(self.sparse_features.data[ispi, :, 0])
10521077 # ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64)
10531078 ypos = self .channel_positions [ichannels , 1 ]
@@ -1059,7 +1084,7 @@ def get_depths(self):
10591084 break
10601085 return spikes_depths
10611086
1062- def get_amplitudes_true (self , sample2unit = 1. ):
1087+ def get_amplitudes_true (self , sample2unit = 1. , use = 'templates' ):
10631088 """Convert spike amplitude values to input amplitudes units
10641089 via scaling by unwhitened template waveform.
10651090 :param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
@@ -1074,26 +1099,35 @@ def get_amplitudes_true(self, sample2unit=1.):
10741099 # spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
10751100 # to rescale the template,
10761101
1102+ if use == 'clusters' :
1103+ sparse = self .sparse_clusters
1104+ spikes = self .spike_clusters
1105+ n_wav = self .n_clusters
1106+ else :
1107+ sparse = self .sparse_templates
1108+ spikes = self .spike_templates
1109+ n_wav = self .n_templates
1110+
10771111 # unwhiten template waveforms on their channels of max amplitude
1078- if self . sparse_templates .cols :
1112+ if sparse .cols :
10791113 raise NotImplementedError
10801114 # apply the inverse whitening matrix to the template
1081- templates_wfs = np .zeros_like (self . sparse_templates .data ) # nt, ns, nc
1082- for n in np .arange (self . n_templates ):
1083- templates_wfs [n , :, :] = np .matmul (self . sparse_templates .data [n , :, :], self .wmi )
1115+ templates_wfs = np .zeros_like (sparse .data ) # nt, ns, nc
1116+ for n in np .arange (n_wav ):
1117+ templates_wfs [n , :, :] = np .matmul (sparse .data [n , :, :], self .wmi )
10841118
10851119 # The amplitude on each channel is the positive peak minus the negative
10861120 templates_ch_amps = np .max (templates_wfs , axis = 1 ) - np .min (templates_wfs , axis = 1 )
10871121
10881122 # The template arbitrary unit amplitude is the amplitude of its largest channel
10891123 # (but see below for true tempAmps)
10901124 templates_amps_au = np .max (templates_ch_amps , axis = 1 )
1091- spike_amps = templates_amps_au [self . spike_templates ] * self .amplitudes
1125+ spike_amps = templates_amps_au [spikes ] * self .amplitudes
10921126
10931127 with np .errstate (divide = 'ignore' , invalid = 'ignore' ):
10941128 # take the average spike amplitude per template
1095- templates_amps_v = (np .bincount (self . spike_templates , weights = spike_amps ) /
1096- np .bincount (self . spike_templates ))
1129+ templates_amps_v = (np .bincount (spikes , weights = spike_amps ) /
1130+ np .bincount (spikes ))
10971131 # scale back the template according to the spikes units
10981132 templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
10991133 )[:, np .newaxis , np .newaxis ]
@@ -1167,18 +1201,18 @@ def get_template_waveforms(self, template_id):
11671201 template = self .get_template (template_id )
11681202 return template .template if template else None
11691203
1170- def get_cluster_mean_waveforms (self , cluster_id ):
1204+ def get_cluster_mean_waveforms (self , cluster_id , unwhiten = True ):
11711205 """Return the mean template waveforms of a cluster, as a weighted average of the
11721206 template waveforms from which the cluster originates from."""
11731207 count = self .get_template_counts (cluster_id )
11741208 best_template = np .argmax (count )
11751209 template_ids = np .nonzero (count )[0 ]
11761210 count = count [template_ids ]
11771211 # Get local channels of the best template for the given cluster.
1178- template = self .get_template (best_template )
1212+ template = self .get_template (best_template , unwhiten = unwhiten )
11791213 channel_ids = template .channel_ids
11801214 # Get all templates from which this cluster stems from.
1181- templates = [self .get_template (template_id ) for template_id in template_ids ]
1215+ templates = [self .get_template (template_id , unwhiten = unwhiten ) for template_id in template_ids ]
11821216 # Construct the waveforms array.
11831217 ns = self .n_samples_waveforms
11841218 data = np .zeros ((len (template_ids ), ns , self .n_channels ))
@@ -1204,14 +1238,24 @@ def get_cluster_spike_waveforms(self, cluster_id):
12041238
@property
def templates_channels(self):
    """Vector of peak channels, one per template waveform.

    Computed by `self._channels` on `self.sparse_templates`.
    """
    peak_channels = self._channels(self.sparse_templates)
    return peak_channels
1243+
@property
def clusters_channels(self):
    """Vector of peak channels, one per cluster waveform.

    Computed by `self._channels` on `self.sparse_clusters`.
    """
    return self._channels(self.sparse_clusters)
1249+
1250+ def _channels (self , sparse ):
1251+ """ Gets peak channels for each waveform"""
1252+ tmp = sparse .data
12091253 n_templates , n_samples , n_channels = tmp .shape
1210- if self . sparse_templates .cols is None :
1254+ if sparse .cols is None :
12111255 template_peak_channels = np .argmax (tmp .max (axis = 1 ) - tmp .min (axis = 1 ), axis = 1 )
12121256 else :
12131257 # when the templates are sparse, the first channel is the highest amplitude channel
1214- template_peak_channels = self . sparse_templates .cols [:, 0 ]
1258+ template_peak_channels = sparse .cols [:, 0 ]
12151259 assert template_peak_channels .shape == (n_templates ,)
12161260 return template_peak_channels
12171261
@@ -1223,16 +1267,33 @@ def templates_probes(self):
@property
def templates_amplitudes(self):
    """Average spike amplitude per template.

    Computed by `self._amplitudes` with the spikes grouped by `self.spike_templates`.
    """
    mean_amplitudes = self._amplitudes(self.spike_templates)
    return mean_amplitudes
1271+
@property
def clusters_amplitudes(self):
    """Average spike amplitude per cluster.

    Computed by `self._amplitudes` with the spikes grouped by `self.spike_clusters`.
    """
    mean_amplitudes = self._amplitudes(self.spike_clusters)
    return mean_amplitudes
1276+
1277+ def _amplitudes (self , tmp ):
1278+ """ Compute average amplitude for spikes"""
1279+ tid = np .unique (tmp )
1280+ n = np .bincount (tmp )[tid ]
1281+ a = np .bincount (tmp , weights = self .amplitudes )[tid ]
12291282 n [np .isnan (n )] = 1
12301283 return a / n
12311284
@property
def templates_waveforms_durations(self):
    """Returns a vector of waveform durations (ms) for all templates"""
    template_data = self.sparse_templates.data
    return self._waveform_durations(template_data)
1289+
@property
def clusters_waveforms_durations(self):
    """Returns a vector of waveform durations (ms) for all clusters"""
    return self._waveform_durations(self.sparse_clusters.data)
1295+
1296+ def _waveform_durations (self , tmp ):
12361297 n_templates , n_samples , n_channels = tmp .shape
12371298 # Compute the peak channels for each template.
12381299 template_peak_channels = np .argmax (tmp .max (axis = 1 ) - tmp .min (axis = 1 ), axis = 1 )
@@ -1241,6 +1302,23 @@ def templates_waveforms_durations(self):
12411302 (n_templates , n_channels ), mode = 'raise' , order = 'C' )
12421303 return durations .flatten ()[ind ].astype (np .float64 ) / self .sample_rate * 1e3
12431304
def cluster_waveforms(self):
    """Build one dense waveform per cluster from `self.merge_map`.

    For every `(cluster, template ids)` entry of `self.merge_map`:

    - several templates: the cluster gets the output of
      `get_cluster_mean_waveforms(cluster, unwhiten=False)`;
    - exactly one template: the cluster gets a copy of that template's
      waveform from `self.sparse_templates.data`;
    - no template: the cluster row stays filled with zeros.

    :return: Bunch with `data` of shape
        `(max(cluster_ids) + 1, n_samples_waveforms, n_channels)` and `cols=None`.
    """
    # Only non sparse implementation
    n_samples = self.n_samples_waveforms
    n_rows = np.max(self.cluster_ids) + 1
    waveforms = np.zeros((n_rows, n_samples, self.n_channels))
    for cluster_id, template_ids in self.merge_map.items():
        n_sources = len(template_ids)
        if n_sources == 0:
            continue
        if n_sources == 1:
            waveforms[cluster_id, :, :] = self.sparse_templates.data[template_ids[0], :, :]
            continue
        mean = self.get_cluster_mean_waveforms(cluster_id, unwhiten=False)
        # NOTE(review): assumes `mean.channel_ids` is a sequence of channel indices
        # and `mean.mean_waveforms` is laid out (samples, channels) - confirm.
        # With an integer and an index array separated by a slice, NumPy places
        # the indexed (channel) axis first in the target, hence the axis swap.
        waveforms[cluster_id, :, mean.channel_ids] = np.swapaxes(mean.mean_waveforms, 0, 1)

    return Bunch(data=waveforms, cols=None)
1321+
12441322 #--------------------------------------------------------------------------
12451323 # Saving methods
12461324 #--------------------------------------------------------------------------
0 commit comments