@@ -413,7 +413,8 @@ def _load_data(self):
413413 self .n_channels_loc = 0
414414
415415 # Clusters waveforms
416- if not np .all (self .spike_clusters == self .spike_templates ) and self .sparse_templates .cols is None :
416+ if not np .all (self .spike_clusters == self .spike_templates ) and \
417+ self .sparse_templates .cols is None :
417418 self .merge_map , self .nan_idx = self .get_merge_map ()
418419 self .sparse_clusters = self .cluster_waveforms ()
419420 self .n_clusters = self .spike_clusters .max () + 1
@@ -872,7 +873,8 @@ def _template_n_channels(self, template_id, n_channels):
872873 channel_ids += [- 1 ] * (n_channels - len (channel_ids ))
873874 return channel_ids
874875
875- def _get_template_dense (self , template_id , channel_ids = None , amplitude_threshold = None , unwhiten = True ):
876+ def _get_template_dense (self , template_id , channel_ids = None , amplitude_threshold = None ,
877+ unwhiten = True ):
876878 """Return data for one template."""
877879 if not self .sparse_templates :
878880 return
@@ -955,7 +957,8 @@ def get_template(self, template_id, channel_ids=None, amplitude_threshold=None,
955957 return self ._get_template_sparse (template_id , unwhiten = unwhiten )
956958 else :
957959 return self ._get_template_dense (
958- template_id , channel_ids = channel_ids , amplitude_threshold = amplitude_threshold , unwhiten = unwhiten )
960+ template_id , channel_ids = channel_ids , amplitude_threshold = amplitude_threshold ,
961+ unwhiten = unwhiten )
959962
960963 def get_waveforms (self , spike_ids , channel_ids = None ):
961964 """Return spike waveforms on specified channels."""
@@ -1212,7 +1215,8 @@ def get_cluster_mean_waveforms(self, cluster_id, unwhiten=True):
12121215 template = self .get_template (best_template , unwhiten = unwhiten )
12131216 channel_ids = template .channel_ids
12141217 # Get all templates from which this cluster stems from.
1215- templates = [self .get_template (template_id , unwhiten = unwhiten ) for template_id in template_ids ]
1218+ templates = [self .get_template (template_id , unwhiten = unwhiten )
1219+ for template_id in template_ids ]
12161220 # Construct the waveforms array.
12171221 ns = self .n_samples_waveforms
12181222 data = np .zeros ((len (template_ids ), ns , self .n_channels ))
@@ -1313,7 +1317,8 @@ def cluster_waveforms(self):
13131317 for clust , val in self .merge_map .items ():
13141318 if len (val ) > 1 :
13151319 mean_waveform = self .get_cluster_mean_waveforms (clust , unwhiten = False )
1316- data [clust , :, mean_waveform .channel_ids ] = np .swapaxes (mean_waveform .mean_waveforms , 0 , 1 )
1320+ data [clust , :, mean_waveform .channel_ids ] = \
1321+ np .swapaxes (mean_waveform .mean_waveforms , 0 , 1 )
13171322 elif len (val ) == 1 :
13181323 data [clust , :, :] = self .sparse_templates .data [val [0 ], :, :]
13191324
0 commit comments