|
# Registry of default keyword arguments; metric definitions in this module
# add their own entry keyed by metric name.
_default_params = {}
38 | 38 |
|
39 | 39 |
|
def compute_noise_cutoffs(sorting_analyzer, high_quantile=0.25, low_quantile=0.1, n_bins=100, unit_ids=None):
    """
    A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming
    a Gaussian distribution.

    Based on the histogram of the (transformed) amplitude:

    1. This method compares counts in the lower-amplitude bins to counts in the higher-amplitude bins.
       It computes the mean and std of the bin counts in the upper part of the distribution, and calculates
       how many standard deviations away from that mean the mean count of the lower bins lies.

    2. The method also compares the counts in the lower-amplitude bins to the count of the most populated bin
       and returns their ratio.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        A SortingAnalyzer object.
    high_quantile : float, default: 0.25
        Fraction of the amplitudes treated as "high" (e.g. 0.25 = top 25%), the reference region.
    low_quantile : float, default: 0.1
        Fraction of the amplitudes treated as "low" (e.g. 0.1 = lower 10%), the test region.
    n_bins : int, default: 100
        The number of bins to use to compute the amplitude histogram.
    unit_ids : list or None, default: None
        List of unit ids to compute the noise cutoffs. If None, all units are used.

    Returns
    -------
    cutoff_metrics : namedtuple
        Named tuple with two fields, each a dict keyed by unit id:

        * noise_cutoff : the cutoff value (in standard deviations) for each unit
        * noise_ratio : the ratio to the most populated bin for each unit

        All values are NaN when the "spike_amplitudes" extension has not been computed.

    Raises
    ------
    TypeError
        If the "spike_amplitudes" extension was computed with `peak_sign="both"`.

    References
    ----------
    Inspired by metric described in [IBL2024]_

    """
    res = namedtuple("cutoff_metrics", ["noise_cutoff", "noise_ratio"])
    if unit_ids is None:
        unit_ids = sorting_analyzer.unit_ids

    if not sorting_analyzer.has_extension("spike_amplitudes"):
        warnings.warn(
            "`compute_noise_cutoffs` requires the 'spike_amplitudes' extension. "
            "Please run sorting_analyzer.compute('spike_amplitudes') to be able to compute `noise_cutoff`"
        )
        # Without amplitudes nothing can be estimated: report NaN for every requested unit.
        return res(
            {unit_id: np.nan for unit_id in unit_ids},
            {unit_id: np.nan for unit_id in unit_ids},
        )

    amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes")
    peak_sign = amplitude_extension.params["peak_sign"]
    if peak_sign == "both":
        # With mixed-sign amplitudes there is no single tail that approaches zero.
        raise TypeError(
            '`peak_sign` should either be "pos" or "neg". You can set `peak_sign` as an argument when you compute spike_amplitudes.'
        )

    amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign)

    noise_cutoff_dict = {}
    noise_ratio_dict = {}
    for unit_id in unit_ids:
        amplitudes = amplitudes_by_units[unit_id]

        # We assume the noise (zero values) is on the lower tail of the amplitude distribution.
        # But if peak_sign == 'neg', the noise will be on the higher tail, so we flip the distribution.
        if peak_sign == "neg":
            amplitudes = -amplitudes

        cutoff, ratio = _noise_cutoff(amplitudes, high_quantile=high_quantile, low_quantile=low_quantile, n_bins=n_bins)
        noise_cutoff_dict[unit_id] = cutoff
        noise_ratio_dict[unit_id] = ratio

    return res(noise_cutoff_dict, noise_ratio_dict)
| 112 | + |
| 113 | + |
# Defaults mirror the keyword defaults of `compute_noise_cutoffs`.
_default_params["noise_cutoff"] = {"high_quantile": 0.25, "low_quantile": 0.1, "n_bins": 100}
| 115 | + |
| 116 | + |
def _noise_cutoff(amps, high_quantile=0.25, low_quantile=0.1, n_bins=100):
    """
    A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming
    a Gaussian distribution.

    Based on the histogram of the (transformed) amplitude:

    1. This method compares counts in the lower-amplitude bins to counts in the higher-amplitude bins.
       It computes the mean and std of the bin counts in the upper part of the distribution, and calculates
       how many standard deviations away from that mean the mean count of the lower bins lies.

    2. The method also compares the counts in the lower-amplitude bins to the count of the most populated bin
       and returns their ratio.

    Parameters
    ----------
    amps : array-like
        Spike amplitudes.
    high_quantile : float, default: 0.25
        Fraction of the amplitudes treated as "high" (e.g. 0.25 = top 25%). Histogram bins lying entirely at or
        above the `1 - high_quantile` quantile form the reference region.
    low_quantile : float, default: 0.1
        Fraction of the amplitudes treated as "low" (e.g. 0.1 = lower 10%). Histogram bins lying entirely at or
        below the `low_quantile` quantile form the test region.
    n_bins : int, default: 100
        The number of bins to use to compute the amplitude histogram.

    Returns
    -------
    cutoff : float
        (mean(lower_bins_count) - mean(high_bins_count)) / std(high_bins_count)
        NaN when `amps` is empty, when no low bin is selected, when fewer than two high bins are selected,
        or when all high bins have the same count.
    ratio : float
        mean(lower_bins_count) / max(bins_count)
        NaN when `amps` is empty or when no low bin is selected.

    """
    amps = np.asarray(amps)
    if amps.size == 0:
        # Quantiles of an empty sample are undefined; bail out explicitly instead of relying on
        # numpy's empty-input behavior.
        warnings.warn("No amplitudes were provided. Setting noise cutoff and ratio to NaN")
        return np.nan, np.nan

    n_per_bin, bin_edges = np.histogram(amps, bins=n_bins)
    maximum_bin_height = np.max(n_per_bin)

    # Low-amplitude bins: right edge not above the low quantile, i.e. the bin is fully inside the low tail.
    low_quantile_value = np.quantile(amps, q=low_quantile)
    low_indices = np.where(bin_edges[1:] <= low_quantile_value)[0]

    # High-amplitude bins: left edge not below the high quantile, i.e. the bin is fully inside the high tail.
    high_quantile_value = np.quantile(amps, q=1 - high_quantile)
    high_indices = np.where(bin_edges[:-1] >= high_quantile_value)[0]

    if low_indices.size == 0:
        warnings.warn(
            "No bin is selected to test cutoff. Please increase low_quantile. Setting noise cutoff and ratio to NaN"
        )
        return np.nan, np.nan

    # compute ratio between low-amplitude bins and the largest bin
    mean_low_counts = np.mean(n_per_bin[low_indices])
    ratio = mean_low_counts / maximum_bin_height

    if high_indices.size == 0:
        warnings.warn(
            "No bin is selected as the reference region. Please increase high_quantile. Setting noise cutoff to NaN"
        )
        return np.nan, ratio

    if high_indices.size == 1:
        warnings.warn(
            "Only one bin is selected as the reference region, and thus the standard deviation cannot be computed. "
            "Please increase high_quantile. Setting noise cutoff to NaN"
        )
        return np.nan, ratio

    # compute cutoff from low-amplitude and high-amplitude bins
    high_counts = n_per_bin[high_indices]
    mean_high_counts = np.mean(high_counts)
    std_high_counts = np.std(high_counts)
    if std_high_counts == 0:
        # A zero spread would make the z-score-like cutoff divide by zero.
        warnings.warn(
            "All the high-amplitude bins have the same size. Please consider changing n_bins. "
            "Setting noise cutoff to NaN"
        )
        return np.nan, ratio

    cutoff = (mean_low_counts - mean_high_counts) / std_high_counts
    return cutoff, ratio
| 199 | + |
| 200 | + |
40 | 201 | def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): |
41 | 202 | """ |
42 | 203 | Compute the number of spike across segments. |
|
0 commit comments