Skip to content

Commit 0fe4668

Browse files
authored
Merge pull request #36 from cortex-lab/alf_conversion_merge
Alf conversion merge
2 parents 83f23cc + 6db4487 commit 0fe4668

File tree

5 files changed

+240
-33
lines changed

5 files changed

+240
-33
lines changed

phylib/io/alf.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,20 @@ def make_cluster_objects(self):
174174
"""Create clusters.channels, clusters.peakToTrough and clusters.amps"""
175175
peak_channel_path = self.dir_path / 'clusters.channels.npy'
176176
if not peak_channel_path.exists():
177-
self._save_npy(peak_channel_path.name, self.model.templates_channels)
177+
# self._save_npy(peak_channel_path.name, self.model.templates_channels)
178+
self._save_npy(peak_channel_path.name, self.model.clusters_channels)
178179

179180
waveform_duration_path = self.dir_path / 'clusters.peakToTrough.npy'
180181
if not waveform_duration_path.exists():
181-
self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
182+
# self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
183+
waveform_duration = self.model.clusters_waveforms_durations
184+
waveform_duration[self.model.nan_idx] = np.nan
185+
self._save_npy(waveform_duration_path.name, waveform_duration)
182186

183187
# group by average over cluster number
184-
camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
185-
camps[self.cluster_ids] = self.model.templates_amplitudes
188+
# camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
189+
camps = np.zeros(self.model.clusters_channels.shape[0], ) * np.nan
190+
camps[self.cluster_ids] = self.model.clusters_amplitudes
186191
amps_path = self.dir_path / 'clusters.amps.npy'
187192
self._save_npy(amps_path.name, camps * self.ampfactor)
188193

@@ -216,6 +221,7 @@ def make_depths(self):
216221
n_clusters = cluster_channels.shape[0]
217222

218223
clusters_depths = channel_positions[cluster_channels, 1]
224+
clusters_depths[self.model.nan_idx] = np.nan
219225
assert clusters_depths.shape == (n_clusters,)
220226

221227
if self.model.sparse_features is None:
@@ -233,7 +239,7 @@ def make_template_and_spikes_objects(self):
233239
# and not seconds
234240
self._save_npy('spikes.times.npy', self.model.spike_times)
235241
self._save_npy('spikes.samples.npy', self.model.spike_samples)
236-
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor)
242+
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor, use='templates')
237243
self._save_npy('spikes.amps.npy', spike_amps)
238244
self._save_npy('templates.amps.npy', template_amps)
239245

@@ -257,9 +263,32 @@ def make_template_and_spikes_objects(self):
257263
templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]]
258264
np.save(self.out_path.joinpath('templates.waveforms'), templates)
259265
np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds)
266+
267+
_, clusters_v, cluster_amps = self.model.get_amplitudes_true(self.ampfactor, use='clusters')
268+
n_clusters, n_wavsamps, nchall = clusters_v.shape
269+
# for some datasets, 32 may be too much
270+
ncw = min(self.model.n_closest_channels, nchall)
271+
assert(n_clusters == self.model.n_clusters)
272+
templates = np.zeros((n_clusters, n_wavsamps, ncw), dtype=np.float32)
273+
templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32)
274+
# for each cluster, find the nearest channels to keep (on the same probe...)
275+
for t in np.arange(n_clusters):
276+
channels = self.model.clusters_channels
277+
278+
current_probe = self.model.channel_probes[channels[t]]
279+
channel_distance = np.sum(np.abs(
280+
self.model.channel_positions -
281+
self.model.channel_positions[channels[t]]), axis=1)
282+
channel_distance[self.model.channel_probes != current_probe] += np.inf
283+
templates_inds[t, :] = np.argsort(channel_distance)[:ncw]
284+
templates[t, ...] = clusters_v[t, :][:, templates_inds[t, :]]
260285
np.save(self.out_path.joinpath('clusters.waveforms'), templates)
261286
np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds)
262287

288+
# TODO check if we should save this here, will be inconsistent with what we have at the moment
289+
np.save(self.out_path.joinpath('clusters.amps'), cluster_amps)
290+
291+
263292
def rename_with_label(self):
264293
"""add the label as an ALF part name before the extension if any label provided"""
265294
if not self.label:

phylib/io/model.py

Lines changed: 105 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,17 @@ def _load_data(self):
412412
self.n_samples_waveforms = 0
413413
self.n_channels_loc = 0
414414

415+
# Clusters waveforms
416+
if not np.all(self.spike_clusters == self.spike_templates) and self.sparse_templates.cols is None:
417+
self.merge_map, self.nan_idx = self.get_merge_map()
418+
self.sparse_clusters = self.cluster_waveforms()
419+
self.n_clusters = self.spike_clusters.max() + 1
420+
else:
421+
self.merge_map = {}
422+
self.nan_idx = []
423+
self.sparse_clusters = self.sparse_templates
424+
self.n_clusters = self.spike_templates.max() + 1
425+
415426
# Spike waveforms (optional, otherwise fetched from raw data as needed).
416427
self.spike_waveforms = self._load_spike_waveforms()
417428

@@ -861,12 +872,12 @@ def _template_n_channels(self, template_id, n_channels):
861872
channel_ids += [-1] * (n_channels - len(channel_ids))
862873
return channel_ids
863874

864-
def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None):
875+
def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
865876
"""Return data for one template."""
866877
if not self.sparse_templates:
867878
return
868879
template_w = self.sparse_templates.data[template_id, ...]
869-
template = self._unwhiten(template_w).astype(np.float32)
880+
template = self._unwhiten(template_w).astype(np.float32) if unwhiten else template_w
870881
assert template.ndim == 2
871882
channel_ids_, amplitude, best_channel = self._find_best_channels(
872883
template, amplitude_threshold=amplitude_threshold)
@@ -881,7 +892,7 @@ def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold
881892
channel_ids=channel_ids,
882893
)
883894

884-
def _get_template_sparse(self, template_id):
895+
def _get_template_sparse(self, template_id, unwhiten=True):
885896
data, cols = self.sparse_templates.data, self.sparse_templates.cols
886897
assert cols is not None
887898
template_w, channel_ids = data[template_id], cols[template_id]
@@ -902,7 +913,7 @@ def _get_template_sparse(self, template_id):
902913
channel_ids = channel_ids.astype(np.uint32)
903914

904915
# Unwhiten.
905-
template = self._unwhiten(template_w, channel_ids=channel_ids)
916+
template = self._unwhiten(template_w, channel_ids=channel_ids) if unwhiten else template_w
906917
template = template.astype(np.float32)
907918
assert template.ndim == 2
908919
assert template.shape[1] == len(channel_ids)
@@ -920,17 +931,31 @@ def _get_template_sparse(self, template_id):
920931
)
921932
return out
922933

934+
def get_merge_map(self):
    """Map every cluster id to the templates that contribute spikes to it.

    Compares ``self.spike_clusters`` with ``self.spike_templates`` spike by
    spike to recover the merges and splits applied after spike sorting.

    :return: tuple ``(merge_map, nan_idx)`` where

        * ``merge_map`` is a dict with one key per integer in
          ``0 .. max(spike_clusters)``; each value is the list (ascending) of
          template ids that have at least one spike in that cluster. More
          than one entry means the cluster results from a merge; the same
          template appearing under several keys means it was split.
        * ``nan_idx`` is an integer array of the cluster ids that own no
          spike at all (empty list in ``merge_map``).
    """
    n_clusters = np.max(self.spike_clusters) + 1
    inverse_mapping_dict = {key: [] for key in range(n_clusters)}
    for temp in np.unique(self.spike_templates):
        # Clusters that received at least one spike from this template.
        clusters_of_temp = np.unique(self.spike_clusters[self.spike_templates == temp])
        for clu in clusters_of_temp:
            inverse_mapping_dict[clu].append(temp)

    # Force an integer dtype: with no empty cluster the list is empty and
    # np.array([]) would default to float64, which NumPy refuses to use as an
    # index array (callers do ``values[nan_idx] = np.nan``).
    nan_idx = np.array(
        [idx for idx, val in inverse_mapping_dict.items() if not val], dtype=np.int64)

    return inverse_mapping_dict, nan_idx
947+
923948
#--------------------------------------------------------------------------
924949
# Data access methods
925950
#--------------------------------------------------------------------------
926951

927-
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None):
952+
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
928953
"""Get data about a template."""
929954
if self.sparse_templates and self.sparse_templates.cols is not None:
930-
return self._get_template_sparse(template_id)
955+
return self._get_template_sparse(template_id, unwhiten=unwhiten)
931956
else:
932957
return self._get_template_dense(
933-
template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold)
958+
template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold, unwhiten=unwhiten)
934959

935960
def get_waveforms(self, spike_ids, channel_ids=None):
936961
"""Return spike waveforms on specified channels."""
@@ -1047,7 +1072,7 @@ def get_depths(self):
10471072
# take only first component
10481073
features = self.sparse_features.data[ispi, :, 0]
10491074
features = np.maximum(features, 0) ** 2 # takes only positive values into account
1050-
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32)
1075+
ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32)
10511076
# features = np.square(self.sparse_features.data[ispi, :, 0])
10521077
# ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64)
10531078
ypos = self.channel_positions[ichannels, 1]
@@ -1059,7 +1084,7 @@ def get_depths(self):
10591084
break
10601085
return spikes_depths
10611086

1062-
def get_amplitudes_true(self, sample2unit=1.):
1087+
def get_amplitudes_true(self, sample2unit=1., use='templates'):
10631088
"""Convert spike amplitude values to input amplitudes units
10641089
via scaling by unwhitened template waveform.
10651090
:param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
@@ -1074,26 +1099,35 @@ def get_amplitudes_true(self, sample2unit=1.):
10741099
# spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
10751100
# to rescale the template,
10761101

1102+
if use == 'clusters':
1103+
sparse = self.sparse_clusters
1104+
spikes = self.spike_clusters
1105+
n_wav = self.n_clusters
1106+
else:
1107+
sparse = self.sparse_templates
1108+
spikes = self.spike_templates
1109+
n_wav = self.n_templates
1110+
10771111
# unwhiten template waveforms on their channels of max amplitude
1078-
if self.sparse_templates.cols:
1112+
if sparse.cols:
10791113
raise NotImplementedError
10801114
# apply the inverse whitening matrix to the template
1081-
templates_wfs = np.zeros_like(self.sparse_templates.data) # nt, ns, nc
1082-
for n in np.arange(self.n_templates):
1083-
templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi)
1115+
templates_wfs = np.zeros_like(sparse.data) # nt, ns, nc
1116+
for n in np.arange(n_wav):
1117+
templates_wfs[n, :, :] = np.matmul(sparse.data[n, :, :], self.wmi)
10841118

10851119
# The amplitude on each channel is the positive peak minus the negative
10861120
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)
10871121

10881122
# The template arbitrary unit amplitude is the amplitude of its largest channel
10891123
# (but see below for true tempAmps)
10901124
templates_amps_au = np.max(templates_ch_amps, axis=1)
1091-
spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes
1125+
spike_amps = templates_amps_au[spikes] * self.amplitudes
10921126

10931127
with np.errstate(divide='ignore', invalid='ignore'):
10941128
# take the average spike amplitude per template
1095-
templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) /
1096-
np.bincount(self.spike_templates))
1129+
templates_amps_v = (np.bincount(spikes, weights=spike_amps) /
1130+
np.bincount(spikes))
10971131
# scale back the template according to the spikes units
10981132
templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
10991133
)[:, np.newaxis, np.newaxis]
@@ -1167,18 +1201,18 @@ def get_template_waveforms(self, template_id):
11671201
template = self.get_template(template_id)
11681202
return template.template if template else None
11691203

1170-
def get_cluster_mean_waveforms(self, cluster_id):
1204+
def get_cluster_mean_waveforms(self, cluster_id, unwhiten=True):
11711205
"""Return the mean template waveforms of a cluster, as a weighted average of the
11721206
template waveforms from which the cluster originates."""
11731207
count = self.get_template_counts(cluster_id)
11741208
best_template = np.argmax(count)
11751209
template_ids = np.nonzero(count)[0]
11761210
count = count[template_ids]
11771211
# Get local channels of the best template for the given cluster.
1178-
template = self.get_template(best_template)
1212+
template = self.get_template(best_template, unwhiten=unwhiten)
11791213
channel_ids = template.channel_ids
11801214
# Get all templates from which this cluster stems.
1181-
templates = [self.get_template(template_id) for template_id in template_ids]
1215+
templates = [self.get_template(template_id, unwhiten=unwhiten) for template_id in template_ids]
11821216
# Construct the waveforms array.
11831217
ns = self.n_samples_waveforms
11841218
data = np.zeros((len(template_ids), ns, self.n_channels))
@@ -1204,14 +1238,24 @@ def get_cluster_spike_waveforms(self, cluster_id):
12041238

12051239
@property
12061240
def templates_channels(self):
1207-
"""Returns a vector of peak channels for all templates"""
1208-
tmp = self.sparse_templates.data
1241+
"""Returns a vector of peak channels for all templates waveforms"""
1242+
return self._channels(self.sparse_templates)
1243+
1244+
@property
def clusters_channels(self):
    """Peak channel of every cluster waveform, as computed by ``_channels``
    on ``self.sparse_clusters``."""
    return self._channels(self.sparse_clusters)
1249+
1250+
def _channels(self, sparse):
1251+
""" Gets peak channels for each waveform"""
1252+
tmp = sparse.data
12091253
n_templates, n_samples, n_channels = tmp.shape
1210-
if self.sparse_templates.cols is None:
1254+
if sparse.cols is None:
12111255
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
12121256
else:
12131257
# when the templates are sparse, the first channel is the highest amplitude channel
1214-
template_peak_channels = self.sparse_templates.cols[:, 0]
1258+
template_peak_channels = sparse.cols[:, 0]
12151259
assert template_peak_channels.shape == (n_templates,)
12161260
return template_peak_channels
12171261

@@ -1223,16 +1267,33 @@ def templates_probes(self):
12231267
@property
12241268
def templates_amplitudes(self):
12251269
"""Returns the average amplitude per template"""
1226-
tid = np.unique(self.spike_templates)
1227-
n = np.bincount(self.spike_templates)[tid]
1228-
a = np.bincount(self.spike_templates, weights=self.amplitudes)[tid]
1270+
return self._amplitudes(self.spike_templates)
1271+
1272+
@property
def clusters_amplitudes(self):
    """Average spike amplitude of each cluster, obtained by passing
    ``self.spike_clusters`` to ``_amplitudes``."""
    per_cluster = self._amplitudes(self.spike_clusters)
    return per_cluster
1276+
1277+
def _amplitudes(self, tmp):
    """Compute the average of ``self.amplitudes`` for each id present in `tmp`.

    :param tmp: 1-D array of non-negative integer ids, one per spike
        (``np.bincount`` requires this), aligned with ``self.amplitudes``.
    :return: array of mean amplitudes, one entry per distinct id found in
        `tmp`, ordered by increasing id. Ids that never occur in `tmp` get no
        entry.
    """
    # Every id returned by np.unique occurs at least once, so the counts are
    # >= 1 and the division below is always well defined.
    tid, n = np.unique(tmp, return_counts=True)
    a = np.bincount(tmp, weights=self.amplitudes)[tid]
    return a / n
12311284

12321285
@property
12331286
def templates_waveforms_durations(self):
12341287
"""Returns a vector of waveform durations (ms) for all templates"""
1235-
tmp = self.sparse_templates.data
1288+
return self._waveform_durations(self.sparse_templates.data)
1289+
1290+
@property
def clusters_waveforms_durations(self):
    """Waveform durations (ms) of all clusters, computed by
    ``_waveform_durations`` from the data of ``self.sparse_clusters``."""
    return self._waveform_durations(self.sparse_clusters.data)
1295+
1296+
def _waveform_durations(self, tmp):
12361297
n_templates, n_samples, n_channels = tmp.shape
12371298
# Compute the peak channels for each template.
12381299
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
@@ -1241,6 +1302,23 @@ def templates_waveforms_durations(self):
12411302
(n_templates, n_channels), mode='raise', order='C')
12421303
return durations.flatten()[ind].astype(np.float64) / self.sample_rate * 1e3
12431304

1305+
def cluster_waveforms(self):
    """Build one dense waveform per cluster from the template waveforms.

    Walks ``self.merge_map`` (cluster id -> list of contributing template ids):

    * a single contributing template: its entry in
      ``self.sparse_templates.data`` is copied as is;
    * several contributing templates: the result of
      ``get_cluster_mean_waveforms(cluster, unwhiten=False)`` is written on
      the channels it reports;
    * no contributing template: the row is left filled with zeros.

    Only the dense (non-sparse) template layout is handled.

    :return: ``Bunch`` with ``data`` of shape
        ``(max(cluster_ids) + 1, n_samples_waveforms, n_channels)`` and
        ``cols=None``.
    """
    n_rows = np.max(self.cluster_ids) + 1
    waveforms = np.zeros((n_rows, self.n_samples_waveforms, self.n_channels))

    for cluster_id, source_templates in self.merge_map.items():
        n_sources = len(source_templates)
        if n_sources == 0:
            # Cluster without any spike: keep the zero-filled row.
            continue
        if n_sources == 1:
            waveforms[cluster_id, :, :] = self.sparse_templates.data[source_templates[0], :, :]
            continue
        # Merged cluster: combine the contributing templates.
        mean = self.get_cluster_mean_waveforms(cluster_id, unwhiten=False)
        # Indexing with (scalar, slice, array) puts the channel axis first,
        # hence the swap of the (samples, channels) mean waveform.
        waveforms[cluster_id, :, mean.channel_ids] = np.swapaxes(mean.mean_waveforms, 0, 1)

    return Bunch(data=waveforms, cols=None)
1321+
12441322
#--------------------------------------------------------------------------
12451323
# Saving methods
12461324
#--------------------------------------------------------------------------

phylib/io/tests/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True):
9898
_remove(tempdir / 'whitening_mat_inv.npy')
9999
_remove(tempdir / 'sim_binary.dat')
100100

101+
if param == 'merged':
102+
# remove this file to make templates dense
103+
_remove(tempdir / 'template_ind.npy')
104+
clus = np.load(tempdir / 'spike_clusters.npy')
105+
max_clus = np.max(clus)
106+
# merge cluster 0 and 1
107+
clus[np.bitwise_or(clus == 0, clus == 1)] = max_clus + 1
108+
# split cluster 9 into two clusters
109+
idx = np.where(clus == 9)[0]
110+
clus[idx[0:3]] = max_clus + 2
111+
clus[idx[3:]] = max_clus + 3
112+
np.save(tempdir / 'spike_clusters.npy', clus)
113+
101114
# Spike attributes.
102115
if has_spike_attributes:
103116
write_array(tempdir / 'spike_fail.npy', np.full(10, np.nan)) # wrong number of spikes
@@ -120,7 +133,7 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True):
120133
return template_path
121134

122135

123-
@fixture(scope='function', params=('dense', 'sparse', 'misc'))
136+
@fixture(scope='function', params=('dense', 'sparse', 'misc', 'merged'))
124137
def template_path_full(tempdir, request):
125138
return _make_dataset(tempdir, request.param)
126139

0 commit comments

Comments
 (0)