diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2227abb --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.bin binary \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6cd011f..5a99922 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,9 +24,6 @@ jobs: matrix: platform: - os: ubuntu-latest - - os: windows-latest - - os: macos-12 - - os: macos-14 steps: - uses: actions/checkout@v4 - name: Install stable Rust toolchain diff --git a/Cargo.lock b/Cargo.lock index b18dc33..ecf6dd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -101,25 +101,22 @@ checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "criterion" -version = "0.5.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" dependencies = [ "anes", "cast", "ciborium", "clap", "criterion-plot", - "is-terminal", "itertools", "num-traits", - "once_cell", "oorandom", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "walkdir", @@ -127,9 +124,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" dependencies = [ "cast", "itertools", @@ -171,6 +168,7 @@ name = "earshot" version = "0.1.0" dependencies = [ "criterion", + "libm", ] [[package]] @@ -189,28 +187,11 @@ dependencies = [ "crunchy", ] -[[package]] -name = 
"hermit-abi" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" - -[[package]] -name = "is-terminal" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "itertools" -version = "0.10.5" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -231,16 +212,16 @@ dependencies = [ ] [[package]] -name = "libc" -version = "0.2.158" +name = "libm" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "log" -version = "0.4.22" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "memchr" @@ -519,16 +500,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets", + "windows-sys", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 26bb8be..c637fe2 
100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,28 +1,26 @@ [package] name = "earshot" version = "0.1.0" -description = "Ridiculously fast voice activity detection in pure #[no_std] Rust" +description = "Ridiculously fast & accurate voice activity detection in pure Rust" repository = "https://github.com/pykeio/earshot" authors = [ "Carson M " ] -license = "BSD-3-Clause" +license = "MIT" edition = "2021" exclude = ["tests/data", ".github"] [features] -default = [ "std", "alloc" ] +default = [ "std", "embed-weights" ] # Currently just impls `std::error::Error` for the `Error` type. std = [] -# Allocates internal buffers on the heap instead of the stack. -alloc = [] +# Embed the default model weights in the binary. Enables `Default` for `QuantizedPredictor`. +embed-weights = [] [dependencies] +libm = "0.2" [dev-dependencies] -criterion = "0.5" +criterion = "0.7" -[[bench]] -name = "downsample" -harness = false [[bench]] name = "vad" harness = false diff --git a/LICENSE b/LICENSE index 54700b4..91c7f17 100644 --- a/LICENSE +++ b/LICENSE @@ -1,30 +1,21 @@ -Copyright (c) 2011, The WebRTC project authors. All rights reserved. -Copyright (c) 2024 pyke.io +MIT License -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: +Copyright (c) 2025 pyke.io - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
+Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. - * Neither the name of Google nor the names of its contributors may - be used to endorse or promote products derived from this software - without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 16c85a4..19207a6 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,7 @@ # Earshot -Ridiculously fast, only slightly bad voice activity detection in pure Rust. Port of the famous [WebRTC VAD](https://webrtc.googlesource.com/). +Ridiculously fast & accurate voice activity detection in pure Rust. -## Features -- `#![no_std]`, doesn't even require `alloc` - * Internal buffers can get pretty big when stored on the stack, so the `alloc` feature is enabled by default, which allocates them on the heap instead. -- Stupidly fast; uses only fixed-point arithmetic - * Achieves an RTF of ~3e-4 with 30 ms 48 KHz frames, ~3e-5 with 30 ms 8 KHz frames. - * Comparatively, Silero VAD v4 w/ [`ort`](https://ort.pyke.io/) achieves an RTF of ~3e-3 with 60 ms 16 KHz frames. -- Okay accuracy - * Great at distinguishing between silence and noise, but not between noise and speech. - * Earshot provides alternative models with slight accuracy gains compared to the base WebRTC model. +Achieves an RTF of 0.0014; 10x faster than Silero/TEN VAD. + +## Performance +Compiling with `RUSTFLAGS="-C target-cpu=native"` in release mode is highly recommended as it can cut processing time in half. 
diff --git a/benches/downsample.rs b/benches/downsample.rs
deleted file mode 100644
index 5753dc8..0000000
--- a/benches/downsample.rs
+++ /dev/null
@@ -1,19 +0,0 @@
-use std::{fs, slice};
-
-use criterion::{Criterion, criterion_group, criterion_main};
-use earshot::__internal_downsampling::downsample_2x;
-
-fn downsample_48khz_24khz(c: &mut Criterion) {
-    c.bench_function("Downsample 48KHz -> 24KHz", |b| {
-        let file = fs::read("tests/data/audio_tiny48.raw").unwrap();
-        let i16_samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::<i16>(), file.len() / 2) };
-        let mut out = vec![0; i16_samples.len() / 2];
-        let mut filter_state = [0, 0];
-        b.iter(|| {
-            downsample_2x(&i16_samples, &mut out, &mut filter_state);
-        })
-    });
-}
-
-criterion_group!(downsample, downsample_48khz_24khz);
-criterion_main!(downsample);
diff --git a/benches/vad.rs b/benches/vad.rs
index ae44f6d..df396cf 100644
--- a/benches/vad.rs
+++ b/benches/vad.rs
@@ -1,83 +1,23 @@
-use std::{fs, hint::black_box, slice};
+use std::hint::black_box;
 
 use criterion::{Criterion, criterion_group, criterion_main};
-use earshot::{VoiceActivityDetector, VoiceActivityModel, VoiceActivityProfile};
+use earshot::{Detector, QuantizedPredictor};
 
-fn bench_vad_8khz(c: &mut Criterion) {
-    let file = fs::read("tests/data/audio_tiny8.raw").unwrap();
-    let i16_samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::<i16>(), file.len() / 2) };
-    let mut vad = VoiceActivityDetector::new_with_model(VoiceActivityModel::ES_ALPHA, VoiceActivityProfile::VERY_AGGRESSIVE);
-    c.bench_function("VAD - 8 KHz (Real world)", |b| {
+fn bench_vad(c: &mut Criterion) {
+    let mut vad = Detector::<QuantizedPredictor>::default();
+    c.bench_function("Single frame - f32", |b| {
+        let frame = (0..256 as i16).map(|i| i.wrapping_mul(i) as f32).collect::<Vec<_>>();
         b.iter(|| {
-            for frame in i16_samples.chunks_exact(240) {
-                let _ = black_box(vad.predict_8khz(black_box(frame)));
-            }
+            let _ = black_box(vad.predict_f32(black_box(&frame)));
         })
     });
-    c.bench_function("VAD - 8 KHz (Single frame)", |b| {
-        let frame = (0..240 as i16).map(|i| i.wrapping_mul(i)).collect::<Vec<_>>();
+    c.bench_function("Single frame - i16", |b| {
+        let frame = (0..256 as i16).map(|i| i.wrapping_mul(i)).collect::<Vec<_>>();
         b.iter(|| {
-            let _ = black_box(vad.predict_8khz(black_box(&frame)));
+            let _ = black_box(vad.predict_i16(black_box(&frame)));
         })
     });
 }
 
-fn bench_vad_16khz(c: &mut Criterion) {
-    let file = fs::read("tests/data/audio_tiny16.raw").unwrap();
-    let i16_samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::<i16>(), file.len() / 2) };
-    let mut vad = VoiceActivityDetector::new_with_model(VoiceActivityModel::ES_ALPHA, VoiceActivityProfile::VERY_AGGRESSIVE);
-    c.bench_function("VAD - 16 KHz (Real world)", |b| {
-        b.iter(|| {
-            for frame in i16_samples.chunks_exact(240) {
-                let _ = black_box(vad.predict_16khz(black_box(frame)));
-            }
-        })
-    });
-    c.bench_function("VAD - 16 KHz (Single frame)", |b| {
-        let frame = (0..480 as i16).map(|i| i.wrapping_mul(i)).collect::<Vec<_>>();
-        b.iter(|| {
-            let _ = black_box(vad.predict_16khz(black_box(&frame)));
-        })
-    });
-}
-
-fn bench_vad_32khz(c: &mut Criterion) {
-    let file = fs::read("tests/data/audio_tiny32.raw").unwrap();
-    let i16_samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::<i16>(), file.len() / 2) };
-    let mut vad = VoiceActivityDetector::new_with_model(VoiceActivityModel::ES_ALPHA, VoiceActivityProfile::VERY_AGGRESSIVE);
-    c.bench_function("VAD - 32 KHz (Real world)", |b| {
-        b.iter(|| {
-            for frame in i16_samples.chunks_exact(240) {
-                let _ = black_box(vad.predict_32khz(black_box(frame)));
-            }
-        })
-    });
-    c.bench_function("VAD - 32 KHz (Single frame)", |b| {
-        let frame = (0..960 as i16).map(|i| i.wrapping_mul(i)).collect::<Vec<_>>();
-        b.iter(|| {
-            let _ = black_box(vad.predict_32khz(black_box(&frame)));
-        })
-    });
-}
-
-fn bench_vad_48khz(c: &mut Criterion) {
-    let file = fs::read("tests/data/audio_tiny48.raw").unwrap();
-    let i16_samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::<i16>(), file.len() / 2) };
-    let mut vad = VoiceActivityDetector::new_with_model(VoiceActivityModel::ES_ALPHA, VoiceActivityProfile::VERY_AGGRESSIVE);
-    c.bench_function("VAD - 48 KHz (Real world)", |b| {
-        b.iter(|| {
-            for frame in i16_samples.chunks_exact(240) {
-                let _ = black_box(vad.predict_48khz(black_box(frame)));
-            }
-        })
-    });
-    c.bench_function("VAD - 48 KHz (Single frame)", |b| {
-        let frame = (0..1440 as i16).map(|i| i.wrapping_mul(i)).collect::<Vec<_>>();
-        b.iter(|| {
-            let _ = black_box(vad.predict_48khz(black_box(&frame)));
-        })
-    });
-}
-
-criterion_group!(vad, bench_vad_8khz, bench_vad_16khz, bench_vad_32khz, bench_vad_48khz);
+criterion_group!(vad, bench_vad);
 criterion_main!(vad);
diff --git a/examples/extract-voice.rs b/examples/extract-voice.rs
new file mode 100644
index 0000000..4f5ff03
--- /dev/null
+++ b/examples/extract-voice.rs
@@ -0,0 +1,47 @@
+use core::{mem, ptr, slice};
+use std::{
+    env::args,
+    fs::{self, File},
+    io::Write
+};
+
+use earshot::{Detector, QuantizedPredictor};
+
+/// Usage line printed when either positional argument is missing.
+const USAGE: &str = "cargo run --example extract-voice -- [wav] [out]";
+
+/// Copies every 256-sample frame that the detector scores at 0.5 or above from `[wav]` into `[out]`.
+///
+/// NOTE(review): assumes a plain 44-byte WAV header followed by 16-bit little-endian PCM - confirm for your input.
+fn main() {
+    let mut args = args().skip(1);
+    let (Some(input), Some(output)) = (args.next(), args.next()) else {
+        eprintln!("{USAGE}");
+        return;
+    };
+
+    let mut detector = Detector::<QuantizedPredictor>::default();
+
+    let mut out = File::create(output).unwrap();
+
+    let wav = fs::read(input).unwrap();
+    // Skip the header; a file too short to contain one yields no frames instead of an out-of-range panic.
+    let pcm = wav.get(44..).unwrap_or_default();
+    // One frame buffer reused for the whole run instead of a fresh `Vec` per frame.
+    let mut samples = [0i16; 256];
+    for x in pcm.chunks_exact(512) {
+        for (sample, bytes) in samples.iter_mut().zip(x.chunks_exact(2)) {
+            *sample = i16::from_le_bytes([bytes[0], bytes[1]]);
+        }
+
+        let score = detector.predict_i16(&samples);
+        if score >= 0.5 {
+            println!("voice");
+            out.write_all(x).unwrap();
+        } else {
+            println!("silence {score}");
+        }
+    }
+
+    out.flush().unwrap();
+}
diff --git a/src/energy.rs b/src/energy.rs
deleted file mode 100644
index ecea29c..0000000
--- a/src/energy.rs
+++ /dev/null
@@ -1,37 +0,0 @@
-use crate::util::{norm_i32, size_in_bits};
-
-fn get_scaling_square(inv: &[i16], times: usize) -> u8 {
-    let n_bits = size_in_bits(times as i32);
-    let smax = inv.iter().map(|c| c.abs()).max().unwrap() as i32;
-    let t = norm_i32(smax * smax) as u8;
-    if smax != 0 {
-        if t > n_bits { 0 } else { n_bits - t }
-    } else {
-        0 // Since norm(0) returns 0
-    }
-}
-
-pub struct EnergyResult {
-    pub energy: i32,
-    pub scaling_factor: u8
-}
-
-pub fn energy(inv: &[i16]) -> EnergyResult {
-    let scaling_factor = get_scaling_square(inv, inv.len());
-    let energy = inv.iter().map(|x| (*x as i32 * *x as i32) >> scaling_factor).sum();
-    EnergyResult { energy, scaling_factor }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::EnergyResult;
-    use crate::energy::energy;
-
-    #[test]
-    fn test_energy() {
-        let inv = [1, 2, 33, 100];
-        let EnergyResult { energy, scaling_factor } = energy(&inv);
-        assert_eq!(energy, 11094);
-        assert_eq!(scaling_factor, 0);
-    }
-}
diff --git a/src/fft/mod.rs b/src/fft/mod.rs
new file mode 100644
index 0000000..9412274
--- /dev/null
+++ b/src/fft/mod.rs
@@ -0,0 +1,296 @@
+//! From https://gitlab.com/teskje/microfft-rs
+//! Copyright (c) 2020-2024 Jan Teske, MIT license
+//!
+//! This is actually 2.5x slower than realfft/rustfft, but it's much simpler and doesn't require `std`.
+//! Still, it only takes ~1.8us - only slightly slower than the first NN layer.
+
+use core::{
+    ops::{Add, Mul, Sub},
+    slice
+};
+
+mod tables;
+
+/// Single-precision complex number, laid out as two consecutive `f32`s (`re`, then `im`).
+///
+/// The `#[repr(C)]` layout is load-bearing: the real FFT reinterprets `f32` buffers as `Complex32` slices in place.
+#[derive(Debug, Clone, Copy, PartialEq)]
+#[repr(C)]
+pub struct Complex32 {
+    pub re: f32,
+    pub im: f32
+}
+
+impl Complex32 {
+    /// Builds a complex number from its real and imaginary parts.
+    pub const fn new(re: f32, im: f32) -> Self {
+        Self { re, im }
+    }
+
+    /// Squared magnitude, `re * re + im * im`.
+    pub const fn norm_sqr(&self) -> f32 {
+        self.re * self.re + self.im * self.im
+    }
+}
+
+impl Add for Complex32 {
+    type Output = Self;
+
+    #[inline]
+    fn add(self, rhs: Complex32) -> Self::Output {
+        Complex32::new(self.re + rhs.re, self.im + rhs.im)
+    }
+}
+impl Sub for Complex32 {
+    type Output = Self;
+
+    #[inline]
+    fn sub(self, rhs: Complex32) -> Self::Output {
+        Complex32::new(self.re - rhs.re, self.im - rhs.im)
+    }
+}
+impl Mul for Complex32 {
+    type Output = Self;
+
+    #[inline]
+    fn mul(self, rhs: Complex32) -> Self::Output {
+        let re = self.re * rhs.re - self.im * rhs.im;
+        let im = self.re * rhs.im + self.im * rhs.re;
+        Complex32::new(re, im)
+    }
+}
+impl Mul<f32> for Complex32 {
+    type Output = Self;
+
+    #[inline]
+    fn mul(self, rhs: f32) -> Self::Output {
+        Complex32::new(self.re * rhs, self.im * rhs)
+    }
+}
+
+/// In-place complex FFT over exactly `N` points, built recursively from two `Half`-sized transforms.
+pub(crate) trait CFft {
+    type Half: CFft;
+
+    const N: usize;
+    const LOG2_N: usize = Self::N.ilog2() as usize;
+
+    const BITREV_TABLE: &'static [u16] = tables::BITREV[Self::LOG2_N];
+
+    /// Runs the full transform on `x`: the bit-reversal reorder followed by the butterfly passes.
+    #[inline]
+    fn transform(x: &mut [Complex32]) -> &mut [Complex32] {
+        debug_assert_eq!(x.len(), Self::N);
+
+        Self::bit_reverse_reorder(x);
+        Self::compute_butterflies(x);
+        x
+    }
+
+    /// Reorders `x` by swapping each index with its `BITREV_TABLE` partner (entries equal to the index are skipped).
+    #[inline]
+    fn bit_reverse_reorder(x: &mut [Complex32]) {
+        debug_assert_eq!(x.len(), Self::N);
+
+        for i in 0..Self::N {
+            let j = Self::BITREV_TABLE[i] as usize;
+            if i != j {
+                x.swap(i, j);
+            }
+        }
+    }
+
+    /// Recursively transforms both halves of `x`, then merges them using twiddle factors read from `tables::SINE`.
+    #[inline]
+    fn compute_butterflies(x: &mut [Complex32]) {
+        debug_assert_eq!(x.len(), Self::N);
+
+        let m = Self::N / 2;
+        let u = m / 2;
+
+        let table_len = tables::SINE.len();
+        let table_stride = (table_len + 1) * 4 / Self::N;
+
+        Self::Half::compute_butterflies(&mut x[..m]);
+        Self::Half::compute_butterflies(&mut x[m..]);
+
+        // [k = 0] twiddle factor: `1 + 0i`
+        let (x_0, x_m) = (x[0], x[m]);
+        x[0] = x_0 + x_m;
+        x[m] = x_0 - x_m;
+
+        // [k in [1, m/2)] twiddle factor:
+        // - re from SINE table backwards and negative
+        // - im from SINE table directly
+        for k in 1..u {
+            let s = k * table_stride;
+            let re = tables::SINE[table_len - s] * -1.;
+            let im = tables::SINE[s - 1];
+            let twiddle = Complex32::new(re, im);
+
+            let (x_k, x_km) = (x[k], x[k + m]);
+            let y = twiddle * x_km;
+            x[k] = x_k + y;
+            x[k + m] = x_k - y;
+        }
+
+        // [k = m/2] twiddle factor: `0 - 1i`
+        let (x_u, x_um) = (x[u], x[u + m]);
+        let y = x_um * Complex32::new(0., -1.);
+        x[u] = x_u + y;
+        x[u + m] = x_u - y;
+
+        // [k in (m/2, m)] twiddle factor:
+        // - re from SINE table directly
+        // - im from SINE table backwards
+        for k in (u + 1)..m {
+            let s = (k - u) * table_stride;
+            let re = tables::SINE[s - 1];
+            let im = tables::SINE[table_len - s];
+            let twiddle = Complex32::new(re, im);
+
+            let (x_k, x_km) = (x[k], x[k + m]);
+            let y = twiddle * x_km;
+            x[k] = x_k + y;
+            x[k + m] = x_k - y;
+        }
+    }
+}
+
+/// Marker type selecting the `N`-point complex FFT implementation.
+pub(crate) struct CFftN<const N: usize>;
+
+impl CFft for CFftN<1> {
+    type Half = Self;
+
+    const N: usize = 1;
+
+    #[inline]
+    fn bit_reverse_reorder(x: &mut [Complex32]) {
+        debug_assert_eq!(x.len(), 1);
+    }
+
+    #[inline]
+    fn compute_butterflies(x: &mut [Complex32]) {
+        debug_assert_eq!(x.len(), 1);
+    }
+}
+
+impl CFft for CFftN<2> {
+    type Half = CFftN<1>;
+
+    const N: usize = 2;
+
+    #[inline]
+    fn compute_butterflies(x: &mut [Complex32]) {
+        debug_assert_eq!(x.len(), 2);
+
+        let (x_0, x_1) = (x[0], x[1]);
+        x[0] = x_0 + x_1;
+        x[1] = x_0 - x_1;
+    }
+}
+
+macro_rules! cfft_impls {
+    ($($N:expr),*) => {
+        $(
+            impl CFft for CFftN<$N> {
+                type Half = CFftN<{$N / 2}>;
+
+                const N: usize = $N;
+            }
+        )*
+    };
+}
+
+cfft_impls! { 4, 8, 16, 32, 64, 128, 256, 512, 1024 }
+
+/// Real-input FFT of `N` samples computed with an `N / 2`-point complex FFT (`Self::CFft`) followed by `recombine`.
+pub(crate) trait RFft {
+    type CFft: CFft;
+
+    const N: usize = Self::CFft::N * 2;
+
+    #[inline]
+    fn transform(x: &mut [f32]) -> &mut [Complex32] {
+        debug_assert_eq!(x.len(), Self::N);
+
+        let x = Self::pack_complex(x);
+
+        Self::CFft::transform(x);
+        Self::recombine(x);
+        x
+    }
+
+    #[inline]
+    fn pack_complex(x: &mut [f32]) -> &mut [Complex32] {
+        assert_eq!(x.len(), Self::N);
+
+        let len = Self::N / 2;
+        let data = x.as_mut_ptr().cast::<Complex32>();
+        // SAFETY: `Complex32` is `#[repr(C)]` with two `f32` fields, so it has `f32`'s alignment and twice its size; the
+        // assert above pins `x` to `Self::N` floats, which is exactly `len` complex values borrowed from `x`.
+        unsafe { slice::from_raw_parts_mut(data, len) }
+    }
+
+    #[inline]
+    fn recombine(x: &mut [Complex32]) {
+        let m = Self::CFft::N;
+        debug_assert_eq!(x.len(), m);
+
+        let table_len = tables::SINE.len();
+        let table_stride = (table_len + 1) * 4 / Self::N;
+
+        // The real part of the first element is the DC value.
+        // Additionally, the real-valued coefficient at the Nyquist frequency
+        // is stored in the imaginary part.
+        let x0 = x[0];
+        x[0] = Complex32::new(x0.re + x0.im, x0.re - x0.im);
+
+        let u = m / 2;
+        for k in 1..u {
+            let s = k * table_stride;
+            let twiddle_re = tables::SINE[table_len - s] * -1.;
+            let twiddle_im = tables::SINE[s - 1];
+
+            let (x_k, x_nk) = (x[k], x[m - k]);
+            // 20% speed boost just by replacing / 2 with * 0.5 here!
+            let sum = (x_k + x_nk) * 0.5;
+            let diff = (x_k - x_nk) * 0.5;
+
+            x[k] = Complex32::new(sum.re + twiddle_re * sum.im + twiddle_im * diff.re, diff.im + twiddle_im * sum.im - twiddle_re * diff.re);
+            x[m - k] = Complex32::new(sum.re - twiddle_re * sum.im - twiddle_im * diff.re, -diff.im + twiddle_im * sum.im - twiddle_re * diff.re);
+        }
+
+        let xu = x[u];
+        x[u] = Complex32::new(xu.re, -xu.im);
+    }
+}
+
+/// Marker type selecting the `N`-point real FFT implementation.
+struct RFftN<const N: usize>;
+
+impl RFft for RFftN<1024> {
+    type CFft = CFftN<512>;
+}
+
+/// Computes the 1024-point real FFT of `x[..1024]` in place and returns all 513 bins, DC through Nyquist.
+///
+/// `x` must hold exactly 1026 floats: 1024 samples plus two spare floats that receive the Nyquist bin.
+///
+/// # Panics
+/// Panics if `x.len() != 1026`.
+pub fn rfft_1024(x: &mut [f32]) -> &mut [Complex32] {
+    // Checked in release builds too: the slice built below covers all 1026 floats.
+    assert_eq!(x.len(), 1026);
+    RFftN::<1024>::transform(&mut x[..1024]);
+    // microfft packs Nyquist real into DC bin imaginary so the output can fit in the original 1024-wide buffer. we expect
+    // 513 values to get the mel spectrogram, so unpack them
+    // SAFETY: `x` is exactly 1026 `f32`s (asserted above), i.e. 513 `#[repr(C)]` `Complex32`s of identical alignment. The
+    // pointer is taken from `x` itself rather than from the 1024-float subslice so it is valid for the whole range.
+    let comp = unsafe { slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex32>(), 513) };
+    // The Nyquist coefficient of a real signal is purely real; don't leak whatever the spare floats held.
+    comp[512] = Complex32::new(comp[0].im, 0.0);
+    comp[0].im = 0.0;
+    comp
+}
diff --git a/src/fft/tables.rs b/src/fft/tables.rs
new file mode 100644
index 0000000..23a2ff9
--- /dev/null
+++ b/src/fft/tables.rs
@@ -0,0 +1,348 @@
+#![allow(clippy::excessive_precision)]
+#![allow(clippy::unreadable_literal)]
+
+pub(crate) const SINE: &[f32] = &[
+    -0.006135884649154475,
+    -0.012271538285719925,
+    -0.01840672990580482,
+    -0.024541228522912288,
+    -0.030674803176636626,
+    -0.03680722294135883,
+    -0.04293825693494082,
+    -0.049067674327418015,
+    -0.055195244349689934,
+    -0.06132073630220858,
+    -0.06744391956366405,
+    -0.07356456359966743,
+    -0.07968243797143013,
+    -0.0857973123444399,
+    -0.09190895649713272,
+    -0.0980171403295606,
+    -0.10412163387205459,
+    -0.11022220729388306,
+    -0.11631863091190475,
+    -0.1224106751992162,
+    -0.12849811079379317,
+    -0.13458070850712617,
+    -0.1406582393328492,
+    -0.14673047445536175,
+    -0.15279718525844344,
+    -0.15885814333386145,
+    -0.1649131204899699,
+    -0.17096188876030122,
+    -0.17700422041214875,
+    -0.18303988795514095,
+    -0.1890686641498062,
+    -0.19509032201612825,
+    -0.2011046348420919,
+    -0.20711137619221856,
+    -0.21311031991609136,
+    -0.2191012401568698,
+    -0.22508391135979283,
+    -0.2310581082806711,
+    -0.2370236059943672,
+    -0.24298017990326387,
+    -0.24892760574572015,
+    -0.25486565960451457,
+    -0.2607941179152755,
+    -0.26671275747489837,
+    -0.272621355449949,
+    -0.27851968938505306,
+    -0.2844075372112719,
+    -0.29028467725446233,
+    -0.2961508882436238,
+    -0.3020059493192281,
+    -0.30784964004153487,
+    -0.3136817403988915,
+    -0.3195020308160157,
+    -0.3253102921622629,
+    -0.33110630575987643,
+    -0.33688985339222005,
+    -0.3426607173119944,
+    -0.34841868024943456,
+    -0.35416352542049034,
+    -0.3598950365349881,
+    -0.36561299780477385,
+    -0.37131719395183754,
+    -0.37700741021641826,
-0.3826834323650898, + -0.38834504669882625, + -0.3939920400610481, + -0.3996241998456468, + -0.40524131400498986, + -0.4108431710579039, + -0.41642956009763715, + -0.4220002707997997, + -0.4275550934302821, + -0.43309381885315196, + -0.43861623853852766, + -0.4441221445704292, + -0.44961132965460654, + -0.45508358712634384, + -0.46053871095824, + -0.4659764957679662, + -0.47139673682599764, + -0.4767992300633221, + -0.4821837720791227, + -0.487550160148436, + -0.49289819222978404, + -0.4982276669727818, + -0.5035383837257176, + -0.508830142543107, + -0.5141027441932217, + -0.5193559901655896, + -0.524589682678469, + -0.5298036246862946, + -0.5349976198870972, + -0.5401714727298929, + -0.5453249884220465, + -0.5504579729366048, + -0.5555702330196022, + -0.560661576197336, + -0.5657318107836131, + -0.5707807458869673, + -0.5758081914178453, + -0.5808139580957645, + -0.5857978574564389, + -0.5907597018588742, + -0.5956993044924334, + -0.600616479383869, + -0.6055110414043255, + -0.6103828062763095, + -0.6152315905806268, + -0.6200572117632891, + -0.6248594881423863, + -0.629638238914927, + -0.6343932841636455, + -0.6391244448637757, + -0.6438315428897914, + -0.6485144010221124, + -0.6531728429537768, + -0.6578066932970786, + -0.6624157775901718, + -0.6669999223036375, + -0.6715589548470183, + -0.6760927035753159, + -0.680600997795453, + -0.6850836677727004, + -0.6895405447370668, + -0.6939714608896539, + -0.6983762494089728, + -0.7027547444572253, + -0.7071067811865476, + -0.7114321957452164, + -0.7157308252838186, + -0.7200025079613817, + -0.7242470829514669, + -0.7284643904482252, + -0.7326542716724128, + -0.7368165688773698, + -0.7409511253549591, + -0.745057785441466, + -0.7491363945234593, + -0.7531867990436125, + -0.7572088465064846, + -0.7612023854842618, + -0.765167265622459, + -0.7691033376455796, + -0.7730104533627369, + -0.7768884656732324, + -0.7807372285720945, + -0.7845565971555752, + -0.7883464276266062, + -0.7921065773002123, + -0.7958369046088835, + 
-0.799537269107905, + -0.8032075314806448, + -0.8068475535437992, + -0.8104571982525948, + -0.8140363297059483, + -0.8175848131515837, + -0.8211025149911046, + -0.8245893027850253, + -0.8280450452577558, + -0.8314696123025452, + -0.83486287498638, + -0.838224705554838, + -0.8415549774368983, + -0.844853565249707, + -0.8481203448032971, + -0.8513551931052652, + -0.8545579883654005, + -0.8577286100002721, + -0.8608669386377673, + -0.8639728561215867, + -0.8670462455156926, + -0.8700869911087113, + -0.8730949784182901, + -0.8760700941954066, + -0.8790122264286334, + -0.8819212643483549, + -0.8847970984309378, + -0.8876396204028539, + -0.8904487232447579, + -0.8932243011955153, + -0.8959662497561851, + -0.8986744656939538, + -0.901348847046022, + -0.9039892931234433, + -0.9065957045149153, + -0.9091679830905223, + -0.9117060320054299, + -0.9142097557035307, + -0.9166790599210427, + -0.9191138516900578, + -0.9215140393420419, + -0.9238795325112867, + -0.9262102421383113, + -0.9285060804732156, + -0.9307669610789837, + -0.9329927988347388, + -0.9351835099389475, + -0.937339011912575, + -0.9394592236021899, + -0.9415440651830208, + -0.9435934581619604, + -0.9456073253805213, + -0.9475855910177411, + -0.9495281805930367, + -0.9514350209690083, + -0.9533060403541938, + -0.9551411683057707, + -0.9569403357322089, + -0.9587034748958716, + -0.9604305194155658, + -0.9621214042690416, + -0.9637760657954398, + -0.9653944416976894, + -0.9669764710448521, + -0.9685220942744173, + -0.970031253194544, + -0.9715038909862518, + -0.9729399522055601, + -0.9743393827855759, + -0.9757021300385286, + -0.9770281426577544, + -0.9783173707196277, + -0.9795697656854405, + -0.9807852804032304, + -0.9819638691095552, + -0.9831054874312163, + -0.984210092386929, + -0.9852776423889412, + -0.9863080972445987, + -0.9873014181578584, + -0.9882575677307495, + -0.989176509964781, + -0.9900582102622971, + -0.99090263542778, + -0.9917097536690995, + -0.99247953459871, + -0.9932119492347945, + 
-0.9939069700023561, + -0.9945645707342554, + -0.9951847266721968, + -0.9957674144676598, + -0.996312612182778, + -0.9968202992911657, + -0.9972904566786902, + -0.9977230666441916, + -0.9981181129001492, + -0.9984755805732948, + -0.9987954562051724, + -0.9990777277526454, + -0.9993223845883495, + -0.9995294175010931, + -0.9996988186962042, + -0.9998305817958234, + -0.9999247018391445, + -0.9999811752826011 +]; + +pub(crate) const BITREV: &[&[u16]] = &[ + &[0], + &[0, 1], + &[0, 2, 2, 3], + &[0, 4, 2, 6, 4, 5, 6, 7], + &[0, 8, 4, 12, 4, 10, 6, 14, 8, 9, 10, 13, 12, 13, 14, 15], + &[ + 0, 16, 8, 24, 4, 20, 12, 28, 8, 18, 10, 26, 12, 22, 14, 30, 16, 17, 18, 25, 20, 21, 22, 29, 24, 25, 26, 27, 28, 29, 30, 31 + ], + &[ + 0, 32, 16, 48, 8, 40, 24, 56, 8, 36, 20, 52, 12, 44, 28, 60, 16, 34, 18, 50, 20, 42, 26, 58, 24, 38, 26, 54, 28, 46, 30, 62, 32, 33, 34, 49, 36, 41, + 38, 57, 40, 41, 42, 53, 44, 45, 46, 61, 48, 49, 50, 51, 52, 53, 54, 59, 56, 57, 58, 59, 60, 61, 62, 63 + ], + &[ + 0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, 16, 68, 36, 100, 20, 84, 52, 116, 24, 76, 44, 108, 28, 92, 60, 124, 32, 66, 34, 98, + 36, 82, 50, 114, 40, 74, 42, 106, 44, 90, 58, 122, 48, 70, 50, 102, 52, 86, 54, 118, 56, 78, 58, 110, 60, 94, 62, 126, 64, 65, 66, 97, 68, 81, 70, 113, + 72, 73, 74, 105, 76, 89, 78, 121, 80, 81, 82, 101, 84, 85, 86, 117, 88, 89, 90, 109, 92, 93, 94, 125, 96, 97, 98, 99, 100, 101, 102, 115, 104, 105, + 106, 107, 108, 109, 110, 123, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127 + ], + &[ + 0, 128, 64, 192, 32, 160, 96, 224, 16, 144, 80, 208, 48, 176, 112, 240, 16, 136, 72, 200, 40, 168, 104, 232, 24, 152, 88, 216, 56, 184, 120, 248, 32, + 132, 68, 196, 36, 164, 100, 228, 40, 148, 84, 212, 52, 180, 116, 244, 48, 140, 76, 204, 52, 172, 108, 236, 56, 156, 92, 220, 60, 188, 124, 252, 64, + 130, 66, 194, 68, 162, 98, 226, 72, 146, 82, 210, 76, 178, 114, 242, 80, 138, 82, 202, 84, 170, 106, 234, 88, 154, 90, 218, 92, 
186, 122, 250, 96, 134, + 98, 198, 100, 166, 102, 230, 104, 150, 106, 214, 108, 182, 118, 246, 112, 142, 114, 206, 116, 174, 118, 238, 120, 158, 122, 222, 124, 190, 126, 254, + 128, 129, 130, 193, 132, 161, 134, 225, 136, 145, 138, 209, 140, 177, 142, 241, 144, 145, 146, 201, 148, 169, 150, 233, 152, 153, 154, 217, 156, 185, + 158, 249, 160, 161, 162, 197, 164, 165, 166, 229, 168, 169, 170, 213, 172, 181, 174, 245, 176, 177, 178, 205, 180, 181, 182, 237, 184, 185, 186, 221, + 188, 189, 190, 253, 192, 193, 194, 195, 196, 197, 198, 227, 200, 201, 202, 211, 204, 205, 206, 243, 208, 209, 210, 211, 212, 213, 214, 235, 216, 217, + 218, 219, 220, 221, 222, 251, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 247, 240, 241, 242, 243, 244, 245, 246, 247, + 248, 249, 250, 251, 252, 253, 254, 255 + ], + &[ + 0, 256, 128, 384, 64, 320, 192, 448, 32, 288, 160, 416, 96, 352, 224, 480, 16, 272, 144, 400, 80, 336, 208, 464, 48, 304, 176, 432, 112, 368, 240, 496, + 32, 264, 136, 392, 72, 328, 200, 456, 40, 296, 168, 424, 104, 360, 232, 488, 48, 280, 152, 408, 88, 344, 216, 472, 56, 312, 184, 440, 120, 376, 248, + 504, 64, 260, 132, 388, 68, 324, 196, 452, 72, 292, 164, 420, 100, 356, 228, 484, 80, 276, 148, 404, 84, 340, 212, 468, 88, 308, 180, 436, 116, 372, + 244, 500, 96, 268, 140, 396, 100, 332, 204, 460, 104, 300, 172, 428, 108, 364, 236, 492, 112, 284, 156, 412, 116, 348, 220, 476, 120, 316, 188, 444, + 124, 380, 252, 508, 128, 258, 130, 386, 132, 322, 194, 450, 136, 290, 162, 418, 140, 354, 226, 482, 144, 274, 146, 402, 148, 338, 210, 466, 152, 306, + 178, 434, 156, 370, 242, 498, 160, 266, 162, 394, 164, 330, 202, 458, 168, 298, 170, 426, 172, 362, 234, 490, 176, 282, 178, 410, 180, 346, 218, 474, + 184, 314, 186, 442, 188, 378, 250, 506, 192, 262, 194, 390, 196, 326, 198, 454, 200, 294, 202, 422, 204, 358, 230, 486, 208, 278, 210, 406, 212, 342, + 214, 470, 216, 310, 218, 438, 220, 374, 246, 502, 224, 270, 226, 398, 228, 334, 230, 462, 232, 302, 
234, 430, 236, 366, 238, 494, 240, 286, 242, 414, + 244, 350, 246, 478, 248, 318, 250, 446, 252, 382, 254, 510, 256, 257, 258, 385, 260, 321, 262, 449, 264, 289, 266, 417, 268, 353, 270, 481, 272, 273, + 274, 401, 276, 337, 278, 465, 280, 305, 282, 433, 284, 369, 286, 497, 288, 289, 290, 393, 292, 329, 294, 457, 296, 297, 298, 425, 300, 361, 302, 489, + 304, 305, 306, 409, 308, 345, 310, 473, 312, 313, 314, 441, 316, 377, 318, 505, 320, 321, 322, 389, 324, 325, 326, 453, 328, 329, 330, 421, 332, 357, + 334, 485, 336, 337, 338, 405, 340, 341, 342, 469, 344, 345, 346, 437, 348, 373, 350, 501, 352, 353, 354, 397, 356, 357, 358, 461, 360, 361, 362, 429, + 364, 365, 366, 493, 368, 369, 370, 413, 372, 373, 374, 477, 376, 377, 378, 445, 380, 381, 382, 509, 384, 385, 386, 387, 388, 389, 390, 451, 392, 393, + 394, 419, 396, 397, 398, 483, 400, 401, 402, 403, 404, 405, 406, 467, 408, 409, 410, 435, 412, 413, 414, 499, 416, 417, 418, 419, 420, 421, 422, 459, + 424, 425, 426, 427, 428, 429, 430, 491, 432, 433, 434, 435, 436, 437, 438, 475, 440, 441, 442, 443, 444, 445, 446, 507, 448, 449, 450, 451, 452, 453, + 454, 455, 456, 457, 458, 459, 460, 461, 462, 487, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 503, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511 + ], + &[ + 0, 512, 256, 768, 128, 640, 384, 896, 64, 576, 320, 832, 192, 704, 448, 960, 32, 544, 288, 800, 160, 672, 416, 928, 96, 608, 352, 864, 224, 736, 480, + 992, 32, 528, 272, 784, 144, 656, 400, 912, 80, 592, 336, 848, 208, 720, 464, 976, 48, 560, 304, 816, 176, 688, 432, 944, 112, 624, 368, 880, 240, 752, + 496, 1008, 64, 520, 264, 776, 136, 648, 392, 904, 72, 584, 328, 840, 200, 712, 456, 968, 80, 552, 296, 808, 168, 680, 424, 936, 104, 616, 360, 872, + 232, 744, 488, 1000, 96, 536, 280, 792, 152, 664, 408, 920, 104, 600, 344, 856, 216, 728, 472, 984, 112, 568, 312, 824, 184, 
696, 440, 952, 120, 632, + 376, 888, 248, 760, 504, 1016, 128, 516, 260, 772, 132, 644, 388, 900, 136, 580, 324, 836, 196, 708, 452, 964, 144, 548, 292, 804, 164, 676, 420, 932, + 152, 612, 356, 868, 228, 740, 484, 996, 160, 532, 276, 788, 164, 660, 404, 916, 168, 596, 340, 852, 212, 724, 468, 980, 176, 564, 308, 820, 180, 692, + 436, 948, 184, 628, 372, 884, 244, 756, 500, 1012, 192, 524, 268, 780, 196, 652, 396, 908, 200, 588, 332, 844, 204, 716, 460, 972, 208, 556, 300, 812, + 212, 684, 428, 940, 216, 620, 364, 876, 236, 748, 492, 1004, 224, 540, 284, 796, 228, 668, 412, 924, 232, 604, 348, 860, 236, 732, 476, 988, 240, 572, + 316, 828, 244, 700, 444, 956, 248, 636, 380, 892, 252, 764, 508, 1020, 256, 514, 258, 770, 260, 642, 386, 898, 264, 578, 322, 834, 268, 706, 450, 962, + 272, 546, 290, 802, 276, 674, 418, 930, 280, 610, 354, 866, 284, 738, 482, 994, 288, 530, 290, 786, 292, 658, 402, 914, 296, 594, 338, 850, 300, 722, + 466, 978, 304, 562, 306, 818, 308, 690, 434, 946, 312, 626, 370, 882, 316, 754, 498, 1010, 320, 522, 322, 778, 324, 650, 394, 906, 328, 586, 330, 842, + 332, 714, 458, 970, 336, 554, 338, 810, 340, 682, 426, 938, 344, 618, 362, 874, 348, 746, 490, 1002, 352, 538, 354, 794, 356, 666, 410, 922, 360, 602, + 362, 858, 364, 730, 474, 986, 368, 570, 370, 826, 372, 698, 442, 954, 376, 634, 378, 890, 380, 762, 506, 1018, 384, 518, 386, 774, 388, 646, 390, 902, + 392, 582, 394, 838, 396, 710, 454, 966, 400, 550, 402, 806, 404, 678, 422, 934, 408, 614, 410, 870, 412, 742, 486, 998, 416, 534, 418, 790, 420, 662, + 422, 918, 424, 598, 426, 854, 428, 726, 470, 982, 432, 566, 434, 822, 436, 694, 438, 950, 440, 630, 442, 886, 444, 758, 502, 1014, 448, 526, 450, 782, + 452, 654, 454, 910, 456, 590, 458, 846, 460, 718, 462, 974, 464, 558, 466, 814, 468, 686, 470, 942, 472, 622, 474, 878, 476, 750, 494, 1006, 480, 542, + 482, 798, 484, 670, 486, 926, 488, 606, 490, 862, 492, 734, 494, 990, 496, 574, 498, 830, 500, 702, 502, 958, 504, 638, 506, 894, 508, 766, 
510, 1022, + 512, 513, 514, 769, 516, 641, 518, 897, 520, 577, 522, 833, 524, 705, 526, 961, 528, 545, 530, 801, 532, 673, 534, 929, 536, 609, 538, 865, 540, 737, + 542, 993, 544, 545, 546, 785, 548, 657, 550, 913, 552, 593, 554, 849, 556, 721, 558, 977, 560, 561, 562, 817, 564, 689, 566, 945, 568, 625, 570, 881, + 572, 753, 574, 1009, 576, 577, 578, 777, 580, 649, 582, 905, 584, 585, 586, 841, 588, 713, 590, 969, 592, 593, 594, 809, 596, 681, 598, 937, 600, 617, + 602, 873, 604, 745, 606, 1001, 608, 609, 610, 793, 612, 665, 614, 921, 616, 617, 618, 857, 620, 729, 622, 985, 624, 625, 626, 825, 628, 697, 630, 953, + 632, 633, 634, 889, 636, 761, 638, 1017, 640, 641, 642, 773, 644, 645, 646, 901, 648, 649, 650, 837, 652, 709, 654, 965, 656, 657, 658, 805, 660, 677, + 662, 933, 664, 665, 666, 869, 668, 741, 670, 997, 672, 673, 674, 789, 676, 677, 678, 917, 680, 681, 682, 853, 684, 725, 686, 981, 688, 689, 690, 821, + 692, 693, 694, 949, 696, 697, 698, 885, 700, 757, 702, 1013, 704, 705, 706, 781, 708, 709, 710, 909, 712, 713, 714, 845, 716, 717, 718, 973, 720, 721, + 722, 813, 724, 725, 726, 941, 728, 729, 730, 877, 732, 749, 734, 1005, 736, 737, 738, 797, 740, 741, 742, 925, 744, 745, 746, 861, 748, 749, 750, 989, + 752, 753, 754, 829, 756, 757, 758, 957, 760, 761, 762, 893, 764, 765, 766, 1021, 768, 769, 770, 771, 772, 773, 774, 899, 776, 777, 778, 835, 780, 781, + 782, 963, 784, 785, 786, 803, 788, 789, 790, 931, 792, 793, 794, 867, 796, 797, 798, 995, 800, 801, 802, 803, 804, 805, 806, 915, 808, 809, 810, 851, + 812, 813, 814, 979, 816, 817, 818, 819, 820, 821, 822, 947, 824, 825, 826, 883, 828, 829, 830, 1011, 832, 833, 834, 835, 836, 837, 838, 907, 840, 841, + 842, 843, 844, 845, 846, 971, 848, 849, 850, 851, 852, 853, 854, 939, 856, 857, 858, 875, 860, 861, 862, 1003, 864, 865, 866, 867, 868, 869, 870, 923, + 872, 873, 874, 875, 876, 877, 878, 987, 880, 881, 882, 883, 884, 885, 886, 955, 888, 889, 890, 891, 892, 893, 894, 1019, 896, 897, 898, 899, 900, 901, + 
902, 903, 904, 905, 906, 907, 908, 909, 910, 967, 912, 913, 914, 915, 916, 917, 918, 935, 920, 921, 922, 923, 924, 925, 926, 999, 928, 929, 930, 931, + 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 983, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 1015, 960, 961, + 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 1007, + 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, + 1018, 1019, 1020, 1021, 1022, 1023 + ] +]; diff --git a/src/filterbank.rs b/src/filterbank.rs deleted file mode 100644 index 79c6bd7..0000000 --- a/src/filterbank.rs +++ /dev/null @@ -1,304 +0,0 @@ -use crate::{ - energy::{EnergyResult, energy}, - util::norm_u32 -}; - -// High pass filtering, with a cut-off frequency at 80 Hz, if the |data_in| is -// sampled at 500 Hz. -// -// - data_in [i] : Input audio data sampled at 500 Hz. -// - data_length [i] : Length of input and output data. -// - filter_state [i/o] : State of the filter. -// - data_out [o] : Output audio data in the frequency interval 80 - 250 Hz. -fn highpass(inv: &[i16], out: &mut [i16], filter_state: &mut [i16]) { - const COEFFS_ZERO: [i32; 3] = [6631, -13262, 6631]; - const COEFFS_POLE: [i32; 2] = [-7756, 5620]; - assert_eq!(inv.len(), out.len()); - // The sum of the absolute values of the impulse response: - // The zero/pole-filter has a max amplification of a single sample of: 1.4546 - // Impulse response: 0.4047 -0.6179 -0.0266 0.1993 0.1035 -0.0194 - // The all-zero section has a max amplification of a single sample of: 1.6189 - // Impulse response: 0.4047 -0.8094 0.4047 0 0 0 - // The all-pole section has a max amplification of a single sample of: 1.9931 - // Impulse response: 1.0000 0.4734 -0.1189 -0.2187 -0.0627 0.04532 - for i in 0..inv.len() { - // All-zero section (filter coefficients in Q14). 
- let mut t = COEFFS_ZERO[0] * inv[i] as i32; - t += COEFFS_ZERO[1] * filter_state[0] as i32; - t += COEFFS_ZERO[2] * filter_state[1] as i32; - filter_state[1] = filter_state[0]; - filter_state[0] = inv[i]; - - // All-pole section (filter coefficients in Q14). - t -= COEFFS_POLE[0] * filter_state[2] as i32; - t -= COEFFS_POLE[1] * filter_state[3] as i32; - filter_state[3] = filter_state[2]; - filter_state[2] = (t >> 14) as i16; - out[i] = filter_state[2]; - } -} - -// All pass filtering of |data_in|, used before splitting the signal into two -// frequency bands (low pass vs high pass). -// Note that |data_in| and |data_out| can NOT correspond to the same address. -// -// - inv [i] : Input audio signal given in Q0. -// - out [o] : Output audio signal given in Q(-1). -// - filter_state [i/o] : State of the filter given in Q(-1). -// - filter_coefficient [i] : Given in Q15. -fn allpass(inv: &[i16], out: &mut [i16], filter_state: &mut i16, filter_coefficient: i16) { - // The filter can only cause overflow (in the i16 output variable) - // if more than 4 consecutive input numbers are of maximum value and - // has the the same sign as the impulse responses first taps. - // First 6 taps of the impulse response: - // 0.6399 0.5905 -0.3779 0.2418 -0.1547 0.0990 - let mut state32 = (*filter_state as i32) * (1 << 16); - let mut j = 0; - for i in 0..out.len() { - let tmp32 = state32 + filter_coefficient as i32 * inv[j] as i32; - let tmp16 = (tmp32 >> 16) as i16; - out[i] = tmp16; - state32 = (inv[j] as i32 * (1 << 14)) - filter_coefficient as i32 * tmp16 as i32; - state32 *= 2; - j += 2; - } - *filter_state = (state32 >> 16) as i16; -} - -// Splits |data_in| into |hp_data_out| and |lp_data_out| corresponding to -// an upper (high pass) part and a lower (low pass) part respectively. -// -// - inv [i] : Input audio data to be split into two frequency bands. -// - highpass_out [o] : Output audio data of the upper half of the spectrum. The length is |data_length| / 2. 
-// - upper_state [i/o] : State of the upper filter, given in Q(-1). -// - lowpass_out [o] : Output audio data of the lower half of the spectrum. The length is |data_length| / 2. -// - lower_state [i/o] : State of the lower filter, given in Q(-1). -fn split(inv: &[i16], highpass_out: &mut [i16], upper_state: &mut i16, lowpass_out: &mut [i16], lower_state: &mut i16) { - const HIGHPASS_COEFFICIENT: i16 = 20972; // 0.64 in Q15 - const LOWPASS_COEFFICIENT: i16 = 5571; // 0.17 in Q15 - assert_eq!(inv.len() / 2, highpass_out.len()); - assert_eq!(inv.len() / 2, lowpass_out.len()); - allpass(inv, highpass_out, upper_state, HIGHPASS_COEFFICIENT); - allpass(&inv[1..], lowpass_out, lower_state, LOWPASS_COEFFICIENT); - for i in 0..(inv.len() >> 1) { - let tmp = highpass_out[i]; - // TODO: do we want this to be wrapping? - highpass_out[i] = highpass_out[i].wrapping_sub(lowpass_out[i]); - lowpass_out[i] = lowpass_out[i].wrapping_add(tmp); - } -} - -// Calculates the energy of |data_in| in dB, and also updates an overall -// |total_energy| if necessary. -// -// - data_in [i] : Input audio data for energy calculation. -// - data_length [i] : Length of input data. -// - offset [i] : Offset value added to |log_energy|. -// - total_energy [i/o] : An external energy updated with the energy of |data_in|. NOTE: |total_energy| is only updated -// if |total_energy| <= |kMinEnergy|. -// - log_energy [o] : 10 * log10("energy of |data_in|") given in Q4. -// returns log_energy -fn log_of_energy(inv: &[i16], offset: i16, total_energy: &mut i16) -> i16 { - const LOG_CONST: i16 = 24660; // 160*log10(2) in Q9 - let EnergyResult { energy, scaling_factor } = energy(inv); - let mut energy = energy as u32; - let mut scaling_factor = scaling_factor as i8; - let mut log_energy = 0; - if energy != 0 { - // By construction, normalizing to 15 bits is equivalent with 17 leading - // zeros of an unsigned 32 bit value. 
- let normalizing_rshifts = 17 - norm_u32(energy) as i8; - // In a 15 bit representation the leading bit is 2^14. log2(2^14) in Q10 is - // (14 << 10), which is what we initialize |log2_energy| with. For a more - // detailed derivations, see below. - let mut log2_energy = 14336; // 14 in Q10 - - scaling_factor += normalizing_rshifts; - // Normalize |energy| to 15 bits. - // |tot_rshifts| is now the total number of right shifts performed on - // |energy| after normalization. This means that |energy| is in - // Q(-tot_rshifts). - if normalizing_rshifts < 0 { - energy <<= -normalizing_rshifts; - } else { - energy >>= normalizing_rshifts; - } - - // Calculate the energy of |data_in| in dB, in Q4. - // - // 10 * log10("true energy") in Q4 = 2^4 * 10 * log10("true energy") = - // 160 * log10(|energy| * 2^|tot_rshifts|) = - // 160 * log10(2) * log2(|energy| * 2^|tot_rshifts|) = - // 160 * log10(2) * (log2(|energy|) + log2(2^|tot_rshifts|)) = - // (160 * log10(2)) * (log2(|energy|) + |tot_rshifts|) = - // |kLogConst| * (|log2_energy| + |tot_rshifts|) - // - // We know by construction that |energy| is normalized to 15 bits. Hence, - // |energy| = 2^14 + frac_Q15, where frac_Q15 is a fractional part in Q15. - // Further, we'd like |log2_energy| in Q10 - // log2(|energy|) in Q10 = 2^10 * log2(2^14 + frac_Q15) = - // 2^10 * log2(2^14 * (1 + frac_Q15 * 2^-14)) = - // 2^10 * (14 + log2(1 + frac_Q15 * 2^-14)) ~= - // (14 << 10) + 2^10 * (frac_Q15 * 2^-14) = - // (14 << 10) + (frac_Q15 * 2^-4) = (14 << 10) + (frac_Q15 >> 4) - // - // Note that frac_Q15 = (|energy| & 0x00003FFF) - - // Calculate and add the fractional part to |log2_energy|. - log2_energy += ((energy & 0x00003FFF) >> 4) as i16; - - // |kLogConst| is in Q9, |log2_energy| in Q10 and |tot_rshifts| in Q0. - // Note that we in our derivation above have accounted for an output in Q4. 
- log_energy += (((LOG_CONST as i32 * log2_energy as i32) >> 19) + ((scaling_factor as i32 * LOG_CONST as i32) >> 9)) as i16; - if log_energy < 0 { - log_energy = 0; - } - } else { - log_energy = offset; - return log_energy; - } - - log_energy += offset; - - // Update the approximate |total_energy| with the energy of |data_in|, if - // |total_energy| has not exceeded |kMinEnergy|. |total_energy| is used as an - // energy indicator in WebRtcVad_GmmProbability() in vad_core.c. - if *total_energy <= 10 { - if scaling_factor >= 0 { - // We know by construction that the |energy| > |kMinEnergy| in Q0, so add - // an arbitrary value such that |total_energy| exceeds |kMinEnergy|. - *total_energy += 10 + 1; - } else { - // By construction |energy| is represented by 15 bits, hence any number of - // right shifted |energy| will fit in an int16_t. In addition, adding the - // value to |total_energy| is wrap around safe as long as - // |kMinEnergy| < 8192. - *total_energy += (energy >> -scaling_factor) as i16; - } - } - - log_energy -} - -const FEATURES_OFFSET_VECTOR: [i16; 6] = [368, 368, 272, 176, 176, 176]; - -pub fn calculate_features( - inv: &[i16], - features: &mut [i16], - split1_data: &mut [i16], - split2_data: &mut [i16], - upper_state: &mut [i16], - lower_state: &mut [i16], - hp_filter_state: &mut [i16] -) -> i16 { - let split1_size = inv.len() >> 1; - let split2_size = split1_size >> 1; - assert!(split1_data.len() >= split1_size * 2); - assert!(split2_data.len() >= split2_size * 2); - let (hp_120, lp_120) = split1_data[..split1_size << 1].split_at_mut(split1_size); - let (hp_60, lp_60) = split2_data[..split2_size << 1].split_at_mut(split2_size); - let mut total_energy = 0; - - // Split at 2000 Hz and downsample. - split(inv, hp_120, &mut upper_state[0], lp_120, &mut lower_state[0]); - - // For the upper band (2000 Hz - 4000 Hz), split at 3000 Hz and downsample. 
- split(hp_120, hp_60, &mut upper_state[1], lp_60, &mut lower_state[1]); - - // Energy in 3000 Hz - 4000 Hz. - features[5] = log_of_energy(&hp_60, FEATURES_OFFSET_VECTOR[5], &mut total_energy); - // Energy in 2000 Hz - 3000 Hz. - features[4] = log_of_energy(&lp_60, FEATURES_OFFSET_VECTOR[4], &mut total_energy); - - // For the lower band (0 Hz - 2000 Hz), split at 1000 Hz and downsample. - split(lp_120, hp_60, &mut upper_state[2], lp_60, &mut lower_state[2]); - - // Energy in 1000 Hz - 2000 Hz. - features[3] = log_of_energy(&hp_60, FEATURES_OFFSET_VECTOR[3], &mut total_energy); - - // For the lower band (0 Hz - 1000 Hz), split at 500 Hz and downsample. - let hp_30 = &mut hp_120[..split2_size >> 1]; - let lp_30 = &mut lp_120[..split2_size >> 1]; - split(lp_60, hp_30, &mut upper_state[3], lp_30, &mut lower_state[3]); - - // Energy in 500 Hz - 1000 Hz. - features[2] = log_of_energy(&hp_30, FEATURES_OFFSET_VECTOR[2], &mut total_energy); - - // For the lower band (0 Hz - 500 Hz) split at 250 Hz and downsample. - let hp_15 = &mut hp_60[..split2_size >> 2]; - let lp_15 = &mut lp_60[..split2_size >> 2]; - split(lp_30, hp_15, &mut upper_state[4], lp_15, &mut lower_state[4]); - - // Energy in 250 Hz - 500 Hz. - features[1] = log_of_energy(&hp_15, FEATURES_OFFSET_VECTOR[1], &mut total_energy); - - // Remove 0 Hz - 80 Hz by high pass filtering the lower band. - highpass(&lp_15, hp_15, hp_filter_state); - - // Energy in 80 Hz - 250 Hz. 
- features[0] = log_of_energy(&hp_15, FEATURES_OFFSET_VECTOR[0], &mut total_energy); - - total_energy -} - -#[cfg(test)] -mod tests { - use super::{FEATURES_OFFSET_VECTOR, calculate_features}; - - #[test] - fn test_calculate_features_reference() { - const ENERGIES: [i16; 3] = [48, 11, 11]; - const REFERENCES: [[i16; 6]; 3] = [[1213, 759, 587, 462, 434, 272], [1479, 1385, 1291, 1200, 1103, 1099], [1732, 1692, 1681, 1629, 1436, 1436]]; - - let mut features = [0; 6]; - let mut upper_state = [0; 5]; - let mut lower_state = [0; 5]; - let mut hp_filter_state = [0; 4]; - for (i, frame_length) in [80, 160, 240].into_iter().enumerate() { - let speech = (0..frame_length).map(|i| (i as i16).wrapping_mul(i as i16)).collect::>(); - let mut split1_data = vec![0; frame_length]; - let mut split2_data = vec![0; frame_length / 2]; - - let total_energy = - calculate_features(&speech, &mut features, &mut split1_data, &mut split2_data, &mut upper_state, &mut lower_state, &mut hp_filter_state); - assert_eq!(total_energy, ENERGIES[i]); - assert_eq!(features, REFERENCES[i]); - } - } - - #[test] - fn test_calculate_features_zeros() { - let mut features = [0; 6]; - let mut upper_state = [0; 5]; - let mut lower_state = [0; 5]; - let mut hp_filter_state = [0; 4]; - for frame_length in [80, 160, 240] { - let speech = vec![0; frame_length]; - let mut split1_data = vec![0; frame_length]; - let mut split2_data = vec![0; frame_length / 2]; - - let total_energy = - calculate_features(&speech, &mut features, &mut split1_data, &mut split2_data, &mut upper_state, &mut lower_state, &mut hp_filter_state); - assert_eq!(total_energy, 0); - assert_eq!(features, FEATURES_OFFSET_VECTOR); - } - } - - #[test] - fn test_calculate_features_ones() { - for frame_length in [80, 160, 240] { - let speech = vec![1; frame_length]; - let mut features = [0; 6]; - let mut upper_state = [0; 5]; - let mut lower_state = [0; 5]; - let mut hp_filter_state = [0; 4]; - let mut split1_data = vec![0; frame_length]; - let mut 
split2_data = vec![0; frame_length / 2]; - - let total_energy = - calculate_features(&speech, &mut features, &mut split1_data, &mut split2_data, &mut upper_state, &mut lower_state, &mut hp_filter_state); - assert_eq!(total_energy, 0); - assert_eq!(features, FEATURES_OFFSET_VECTOR); - } - } -} diff --git a/src/gmm.rs b/src/gmm.rs deleted file mode 100644 index 520961d..0000000 --- a/src/gmm.rs +++ /dev/null @@ -1,82 +0,0 @@ -use crate::util::div_i32_i16; - -// For a normal distribution, the probability of |input| is calculated and -// returned (in Q20). The formula for normal distributed probability is -// -// 1 / s * exp(-(x - m)^2 / (2 * s^2)) -// -// where the parameters are given in the following Q domains: -// m = |mean| (Q7) -// s = |std| (Q7) -// x = |input| (Q4) -// in addition to the probability we output |delta| (in Q11) used when updating -// the noise/speech model. -// -// Returns (probability, delta) -pub fn gaussian_probability(input: i16, mean: i16, std: i16) -> (i32, i16) { - const COMP_VAR: i32 = 22005; - const LOG2_EXP: i32 = 5909; // log2(exp(1)) in Q12. - - // Calculate |inv_std| = 1 / s, in Q10. - // 131072 = 1 in Q17, and (|std| >> 1) is for rounding instead of truncation. - // Q-domain: Q17 / Q7 = Q10. - let tmp32 = 131072 + (std >> 1) as i32; - let inv_std = div_i32_i16(tmp32, std) as i16; - - // Calculate |inv_std2| = 1 / s^2, in Q14. - let tmp16 = inv_std >> 2; // Q10 -> Q8. - // Q-domain: (Q8 * Q8) >> 2 = Q14. - let inv_std2 = ((tmp16 as i32 * tmp16 as i32) >> 2) as i16; - - let tmp16 = input << 3; // Q4 -> Q7 - let tmp16 = tmp16 - mean; // Q7 - Q7 = Q7 - - // To be used later, when updating noise/speech model. - // |delta| = (x - m) / s^2, in Q11. - // Q-domain: (Q14 * Q7) >> 10 = Q11. - let delta = ((inv_std2 as i32 * tmp16 as i32) >> 10) as i16; - - // Calculate the exponent |tmp32| = (x - m)^2 / (2 * s^2), in Q10. Replacing - // division by two with one shift. - // Q-domain: (Q11 * Q7) >> 8 = Q10. 
- let tmp32 = (delta as i32 * tmp16 as i32) >> 9; - - // If the exponent is small enough to give a non-zero probability we calculate - // |exp_value| ~= exp(-(x - m)^2 / (2 * s^2)) - // ~= exp2(-log2(exp(1)) * |tmp32|). - let exp_value = if tmp32 < COMP_VAR { - // Calculate |tmp16| = log2(exp(1)) * |tmp32|, in Q10. - // Q-domain: (Q12 * Q10) >> 12 = Q10. - let mut tmp16 = -((LOG2_EXP * tmp32) >> 12) as i16; - let exp_value = 0x0400 | (tmp16 & 0x03FF); - tmp16 ^= 0xFFFFu16 as i16; - tmp16 >>= 10; - tmp16 += 1; - // Get |exp_value| = exp(-|timp32|) in Q10. - exp_value.wrapping_shr(tmp16 as _) - } else { - 0 - }; - - // Calculate and return (1 / s) * exp(-(x - m)^2 / (2 * s^2)), in Q20. - // Q-domain: Q10 * Q10 = Q20. - (inv_std as i32 * exp_value as i32, delta) -} - -#[cfg(test)] -mod tests { - use crate::gmm::gaussian_probability; - - #[test] - fn test_gaussian_probability() { - assert_eq!((1048576, 0), gaussian_probability(0, 0, 128)); - assert_eq!((1048576, 0), gaussian_probability(16, 128, 128)); - assert_eq!((1048576, 0), gaussian_probability(-16, -128, 128)); - - assert_eq!((1024, 7552), gaussian_probability(59, 0, 128)); - assert_eq!((1024, 7552), gaussian_probability(75, 128, 128)); - assert_eq!((1024, -7552), gaussian_probability(-75, -128, 128)); - - assert_eq!((0, 13440), gaussian_probability(105, 0, 128)); - } -} diff --git a/src/lib.rs b/src/lib.rs index 506493a..d1799e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,821 +1,236 @@ #![cfg_attr(all(not(feature = "std"), not(test)), no_std)] -//! Earshot is a fast voice activity detection library. -//! -//! For more details, see [`VoiceActivityDetector`]. -//! -//! ``` -//! use earshot::{VoiceActivityDetector, VoiceActivityProfile}; -//! -//! let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); -//! -//! # let mut stream = std::iter::once(vec![0; 320]); -//! while let Some(frame) = stream.next() { -//! let is_speech_detected = vad.predict_16khz(&frame).unwrap(); -//! 
# assert_eq!(is_speech_detected, false); -//! } -//! ``` - -use core::{ - fmt, - ops::{Deref, DerefMut} -}; - -#[cfg(feature = "alloc")] extern crate alloc; -extern crate core; - -pub(crate) mod energy; -pub(crate) mod filterbank; -pub(crate) mod gmm; -pub(crate) mod resample; -pub(crate) mod sp; -pub(crate) mod util; - -#[doc(hidden)] -pub mod __internal_downsampling { - pub use crate::{resample::resample_48khz_to_8khz, sp::downsample_2x}; -} - -use self::{ - filterbank::calculate_features, - gmm::gaussian_probability, - resample::resample_48khz_to_8khz, - sp::{downsample_2x, find_minimum}, - util::{div_i32_i16, norm_i32, weighted_average} -}; - -const SPECTRUM_WEIGHT: [i16; 6] = [6, 8, 10, 12, 14, 16]; -const NOISE_UPDATE: i16 = 655; // Q15 -const SPEECH_UPDATE: i16 = 6554; // Q15 -const BACK_ETA: i16 = 154; // Q8 -const MINIMUM_DIFFERENCE: [i16; 6] = [544, 544, 576, 576, 576, 576]; // Q5 -const MAXIMUM_SPEECH: [i16; 6] = [11392, 11392, 11520, 11520, 11520, 11520]; // Q7 -const MAX_SPEECH_FRAMES: u8 = 6; -const MIN_STD: i16 = 384; - -#[doc(hidden)] -pub const NUM_GAUSSIANS: usize = 2; - -#[derive(Debug, Clone)] -pub struct VoiceActivityModel { - pub noise_weights: [[i16; 6]; NUM_GAUSSIANS], - pub speech_weights: [[i16; 6]; NUM_GAUSSIANS], - pub noise_means: [[i16; 6]; NUM_GAUSSIANS], - pub speech_means: [[i16; 6]; NUM_GAUSSIANS], - pub noise_stds: [[i16; 6]; NUM_GAUSSIANS], - pub speech_stds: [[i16; 6]; NUM_GAUSSIANS], - pub minimum_mean: [i16; NUM_GAUSSIANS], - pub maximum_noise: [i16; 6] -} -impl VoiceActivityModel { - /// The default VAD model from the original WebRTC source. 
- pub const WRTC: VoiceActivityModel = VoiceActivityModel { - noise_weights: [[34, 62, 72, 66, 53, 25], [94, 66, 56, 62, 75, 103]], - speech_weights: [[48, 82, 45, 87, 50, 47], [80, 46, 83, 41, 78, 81]], - noise_means: [[6738, 4892, 7065, 6715, 6771, 3369], [7646, 3863, 7820, 7266, 5020, 4362]], - speech_means: [[8306, 10085, 10078, 11823, 11843, 6309], [9473, 9571, 10879, 7581, 8180, 7483]], - noise_stds: [[378, 1064, 493, 582, 688, 593], [474, 697, 475, 688, 421, 455]], - speech_stds: [[555, 505, 567, 524, 585, 1231], [509, 828, 492, 1540, 1079, 850]], - minimum_mean: [640, 768], - maximum_noise: [9216, 9088, 8960, 8832, 8704, 8576] - }; - - /// A custom model with slightly better accuracy than the default WebRTC model. - pub const ES_ALPHA: VoiceActivityModel = VoiceActivityModel { - noise_weights: [[34, 52, 61, 72, 42, 17], [103, 68, 65, 54, 64, 80]], - speech_weights: [[43, 53, 18, 90, 30, 46], [56, 24, 76, 36, 52, 66]], - noise_means: [[6799, 4771, 7070, 6775, 6843, 3225], [7659, 3939, 7626, 7328, 5091, 4424]], - speech_means: [[8279, 10067, 10053, 11805, 11687, 6224], [9636, 9554, 10973, 7657, 8468, 7466]], - noise_stds: [[361, 1044, 514, 557, 718, 630], [394, 763, 476, 735, 422, 453]], - speech_stds: [[599, 451, 564, 483, 545, 1292], [409, 808, 526, 1447, 1089, 722]], - minimum_mean: [640, 768], - maximum_noise: [9232, 9101, 8952, 8830, 8653, 8555] - }; -} +use alloc::{boxed::Box, vec}; +use core::{f32, ptr}; + +mod fft; +mod quantized_predictor; +mod util; + +#[cfg(feature = "embed-weights")] +pub use self::quantized_predictor::default_weights as default_quantized_weights; +pub use self::quantized_predictor::{PackedWeights, QuantizedPredictor}; +use self::util::OnceLock; + +pub trait Predictor { + fn reset(&mut self); + fn predict(&mut self, features: &[f32], buffer: &mut [f32]) -> f32; +} + +const FFT_SIZE: usize = 1024; +const WINDOW_SIZE: usize = 768; +const N_MELS: usize = 40; +const N_FEATURES: usize = N_MELS + 1; +const N_CONTEXT_FRAMES: usize = 3; 
+const N_BINS: usize = FFT_SIZE / 2 + 1; +const PRE_EMPHASIS_COEFF: f32 = 0.97; +const POWER_FAC: f32 = 1. / (32768.0f32 * 32768.0); + +#[rustfmt::skip] +const FEATURE_MEANS: [f32; 40] = [ + -8.198236465454, -6.265716552734, -5.483818531036, -4.758691310883, + -4.417088985443, -4.142892837524, -3.912850379944, -3.845927953720, + -3.657090425491, -3.723418712616, -3.876134157181, -3.843890905380, + -3.690405130386, -3.756065845490, -3.698696136475, -3.650463104248, + -3.700468778610, -3.567321300507, -3.498900175095, -3.477807044983, + -3.458816051483, -3.444923877716, -3.401328563690, -3.306261301041, + -3.278556823730, -3.233250856400, -3.198616027832, -3.204526424408, + -3.208798646927, -3.257838010788, -3.381376743317, -3.534021377563, + -3.640867948532, -3.726858854294, -3.773730993271, -3.804667234421, + -3.832901000977, -3.871120452881, -3.990592956543, -4.480289459229 +]; + +#[rustfmt::skip] +const FEATURE_STDS: [f32; 40] = [ + 5.166063785553, 4.977209568024, 4.698895931244, 4.630621433258, + 4.634347915649, 4.641156196594, 4.640676498413, 4.666367053986, + 4.650534629822, 4.640020847321, 4.637400150299, 4.620099067688, + 4.596316337585, 4.562654972076, 4.554360389709, 4.566910743713, + 4.562489986420, 4.562412738800, 4.585299491882, 4.600179672241, + 4.592845916748, 4.585922718048, 4.583496570587, 4.626092910767, + 4.626957893372, 4.626289367676, 4.637005805969, 4.683015823364, + 4.726813793182, 4.734289646149, 4.753227233887, 4.849722862244, + 4.869434833527, 4.884482860565, 4.921327114105, 4.959212303162, + 4.996619224548, 5.044823646545, 5.072216987610, 5.096439361572 +]; + +struct Filters { + mel_coeffs: Box<[f32]>, + window: Box<[f32]> +} + +impl Filters { + pub fn new() -> Self { + let low_mel = 2595. * libm::log10f(1.0f32 + 0.0 / 700.); + let high_mel = 2595. * libm::log10f(1.0f32 + 8000. 
/ 700.); + + let mut bin_points = [0; 42]; + for i in 0..N_MELS + 2 { + let mel = i as f32 * (high_mel - low_mel) / (N_MELS as f32 + 1.0) + low_mel; + let hz = 700.0 * (libm::exp10f(mel / 2595.) - 1.); + bin_points[i] = ((FFT_SIZE as f32 + 1.) * hz / 16000.) as usize; + } -impl Default for VoiceActivityModel { - fn default() -> Self { - Self::WRTC - } -} + let mut mel_coeffs = vec![0.0; N_MELS * N_BINS].into_boxed_slice(); + for i in 0..N_MELS { + for j in bin_points[i]..bin_points[i + 1] { + mel_coeffs[(i * N_BINS) + j] = (j - bin_points[i]) as f32 / (bin_points[i + 1] - bin_points[i]) as f32; + } -#[derive(Debug, Clone)] -pub struct VoiceActivityProfile { - overhang_max_1: [i16; 3], - overhang_max_2: [i16; 3], - local_threshold: [i16; 3], - global_threshold: [i16; 3] -} + for j in bin_points[i + 1]..bin_points[i + 2] { + mel_coeffs[(i * N_BINS) + j] = (bin_points[i + 2] - j) as f32 / (bin_points[i + 2] - bin_points[i + 1]) as f32; + } + } -impl VoiceActivityProfile { - /// The least aggressive profile, tuned to preserve as much probable speech as possible. - pub const QUALITY: VoiceActivityProfile = VoiceActivityProfile::new([8, 4, 3], [14, 7, 5], [24, 21, 24], [57, 48, 57]); - /// Tuned for low bit rate scenarios. - pub const LBR: VoiceActivityProfile = VoiceActivityProfile::new([8, 4, 3], [14, 7, 5], [37, 32, 37], [100, 80, 100]); - /// Aggressive profile, tuned to minimize false positives. - pub const AGGRESSIVE: VoiceActivityProfile = VoiceActivityProfile::new([6, 3, 2], [9, 5, 3], [82, 78, 82], [285, 260, 285]); - /// Even more aggressive profile, tuned to provide the least amount of false positives. 
- pub const VERY_AGGRESSIVE: VoiceActivityProfile = VoiceActivityProfile::new([6, 3, 2], [9, 5, 3], [94, 94, 94], [1100, 1050, 1100]); - - #[doc(hidden)] - pub const fn new(overhang_max_1: [i16; 3], overhang_max_2: [i16; 3], local_threshold: [i16; 3], global_threshold: [i16; 3]) -> Self { - Self { - overhang_max_1, - overhang_max_2, - local_threshold, - global_threshold + // hann window + let mut window = vec![0.0; WINDOW_SIZE].into_boxed_slice(); + let df = f32::consts::PI / WINDOW_SIZE as f32; + for i in 0..WINDOW_SIZE { + let x = libm::sinf(df * i as f32); + window[i] = x * x; } - } -} -impl Default for VoiceActivityProfile { - fn default() -> Self { - Self::AGGRESSIVE + Self { mel_coeffs, window } } } -#[cfg(feature = "alloc")] -#[repr(transparent)] -#[derive(Debug)] -struct MaybeHeapAllocated(alloc::boxed::Box<[T]>); -#[cfg(not(feature = "alloc"))] -#[repr(transparent)] -struct MaybeHeapAllocated([T; N]); +static FILTERS: OnceLock = OnceLock::new(); -#[cfg(feature = "alloc")] -impl MaybeHeapAllocated { - pub fn new() -> Self { - Self(alloc::vec![T::default(); N].into_boxed_slice()) - } -} -#[cfg(not(feature = "alloc"))] -impl MaybeHeapAllocated { - pub fn new() -> Self { - Self([T::default(); N]) - } +pub struct Detector

{ + predictor: P, + prev_signal: f32, + sample_ring_buffer: Box<[f32]>, + features: Box<[f32]>, + buffer: Box<[f32]> } -impl Deref for MaybeHeapAllocated { - type Target = [T]; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl DerefMut for MaybeHeapAllocated { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 +impl Default for Detector

{ + fn default() -> Self { + Self::new(P::default()) } } -#[derive(Debug, Clone)] -#[non_exhaustive] -pub enum Error { - /// The size of the audio frame passed to [`VoiceActivityDetector::predict_8khz()`] (or similar methods) is not one - /// of the valid frame sizes for this sample rate. - InvalidFrameSize { frame_size: usize, valid_sizes: [u16; 3] } -} - -impl Error { - pub(crate) fn invalid_frame_size_multiplier(self, multiplier: u8) -> Self { - match self { - Error::InvalidFrameSize { frame_size, valid_sizes } => Error::InvalidFrameSize { - frame_size, - valid_sizes: [valid_sizes[0] * multiplier as u16, valid_sizes[1] * multiplier as u16, valid_sizes[2] * multiplier as u16] - }, - #[allow(unreachable_patterns)] - e => e - } - } -} +impl Detector

{ + pub fn new(predictor: P) -> Self { + // create filters now so we don't accidentally make the first `predict` take super long + FILTERS.get_or_init(Filters::new); -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::InvalidFrameSize { frame_size, valid_sizes } => f.write_fmt(format_args!( - "invalid frame size of {frame_size} samples; valid sizes are {} samples (10ms), {} samples (20ms), or {} samples (30ms)", - valid_sizes[0], valid_sizes[1], valid_sizes[2] - )) + Self { + predictor, + prev_signal: 0.0, + sample_ring_buffer: vec![0.0; 768].into_boxed_slice(), + features: vec![0.0; N_FEATURES * N_CONTEXT_FRAMES].into_boxed_slice(), + buffer: vec![0.0; 1026].into_boxed_slice() } } -} - -// TODO: `core::error::Error` was stabilized in 1.81. -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -pub struct VoiceActivityDetector { - downsampling_filter_states: MaybeHeapAllocated, - downsample_48khz_to_8khz_state: MaybeHeapAllocated, - downsample_tmp: MaybeHeapAllocated, - noise_means: MaybeHeapAllocated<[i16; 6], NUM_GAUSSIANS>, - speech_means: MaybeHeapAllocated<[i16; 6], NUM_GAUSSIANS>, - noise_stds: MaybeHeapAllocated<[i16; 6], NUM_GAUSSIANS>, - speech_stds: MaybeHeapAllocated<[i16; 6], NUM_GAUSSIANS>, - frame_counter: usize, - overhang: i16, - num_of_speech: i16, - age_vector: MaybeHeapAllocated, - low_value_vector: MaybeHeapAllocated, - mean_value: MaybeHeapAllocated, - split1_data: MaybeHeapAllocated, - split2_data: MaybeHeapAllocated, - upper_state: MaybeHeapAllocated, - lower_state: MaybeHeapAllocated, - hp_filter_state: MaybeHeapAllocated, - profile: VoiceActivityProfile, - feature_vector: MaybeHeapAllocated, - total_power: i16, - #[doc(hidden)] - pub model: VoiceActivityModel -} - -impl VoiceActivityDetector { - /// Creates a new [`VoiceActivityDetector`] with the default model ([`VoiceActivityModel::WRTC`]) and given profile. 
- /// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityProfile}; - /// let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - /// ``` - pub fn new(profile: VoiceActivityProfile) -> Self { - Self::new_with_model(VoiceActivityModel::default(), profile) - } - /// Creates a new [`VoiceActivityDetector`] with the given model and profile. + /// Resets the internal state of the voice activity detector. /// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityModel, VoiceActivityProfile}; - /// let mut vad = - /// VoiceActivityDetector::new_with_model(VoiceActivityModel::ES_ALPHA, VoiceActivityProfile::VERY_AGGRESSIVE); - /// ``` - pub fn new_with_model(model: VoiceActivityModel, profile: VoiceActivityProfile) -> Self { - let mut vad = Self { - downsampling_filter_states: MaybeHeapAllocated::new(), - downsample_48khz_to_8khz_state: MaybeHeapAllocated::new(), - downsample_tmp: MaybeHeapAllocated::new(), - noise_means: MaybeHeapAllocated::new(), - speech_means: MaybeHeapAllocated::new(), - noise_stds: MaybeHeapAllocated::new(), - speech_stds: MaybeHeapAllocated::new(), - frame_counter: 0, - overhang: 0, - num_of_speech: 0, - age_vector: MaybeHeapAllocated::new(), - low_value_vector: MaybeHeapAllocated::new(), - mean_value: MaybeHeapAllocated::new(), - split1_data: MaybeHeapAllocated::new(), - split2_data: MaybeHeapAllocated::new(), - upper_state: MaybeHeapAllocated::new(), - lower_state: MaybeHeapAllocated::new(), - hp_filter_state: MaybeHeapAllocated::new(), - profile, - feature_vector: MaybeHeapAllocated::new(), - total_power: 0, - model - }; - vad.reset(); - vad + /// The detector should be reset whenever: + /// - the recording device changes; or + /// - the detector is being used for a new audio sequence. + pub fn reset(&mut self) { + self.predictor.reset(); + self.prev_signal = 0.0; + self.sample_ring_buffer.fill(0.0); + self.features.fill(0.0); } - /// Resets the internal state of the VAD. 
- /// - /// Ideally, this should be called whenever a new audio stream begins. - /// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityProfile}; - /// let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); + /// Predicts the voice activity score of a single input frame of 16-bit PCM audio. /// - /// # let streams: [&[i16]; 0] = []; - /// for stream in streams { - /// let mut speech_frames = 0; - /// for frame in stream.chunks_exact(240) { - /// if let Ok(true) = vad.predict_8khz(frame) { - /// speech_frames += 1; - /// } - /// } + /// The frame: + /// - should be sampled at 16 KHz; + /// - should be exactly 256 samples (so 16 ms) in length. /// - /// vad.reset(); - /// } - /// ``` - pub fn reset(&mut self) { - self.frame_counter = 0; - self.overhang = 0; - self.num_of_speech = 0; - - self.downsampling_filter_states.fill(0); - self.downsample_48khz_to_8khz_state.fill(0); - - self.noise_means.copy_from_slice(&self.model.noise_means); - self.speech_means.copy_from_slice(&self.model.speech_means); - self.noise_stds.copy_from_slice(&self.model.noise_stds); - self.speech_stds.copy_from_slice(&self.model.speech_stds); - - self.low_value_vector.fill(10000); - self.age_vector.fill(0); - - self.upper_state.fill(0); - self.lower_state.fill(0); - - self.hp_filter_state.fill(0); + /// The output score is between `[0, 1]`. Scores over 0.5 can generally be considered voice, but the exact threshold + /// can be adjusted according to application-specific needs. 
+ pub fn predict_i16(&mut self, frame: &[i16]) -> f32 { + assert_eq!(frame.len(), 256); - self.mean_value.fill(1600); - } - - fn gmm(&mut self, frame_len: usize) -> Result { - const MIN_ENERGY: i16 = 10; - - let (overhead1, overhead2, individual_test, total_test) = match frame_len { - 80 => (self.profile.overhang_max_1[0], self.profile.overhang_max_2[0], self.profile.local_threshold[0], self.profile.global_threshold[0]), - 160 => (self.profile.overhang_max_1[1], self.profile.overhang_max_2[1], self.profile.local_threshold[1], self.profile.global_threshold[1]), - 240 => (self.profile.overhang_max_1[2], self.profile.overhang_max_2[2], self.profile.local_threshold[2], self.profile.global_threshold[2]), - _ => { - return Err(Error::InvalidFrameSize { - frame_size: frame_len, - valid_sizes: [80, 160, 240] - }); - } + unsafe { + ptr::copy(self.sample_ring_buffer.as_ptr().add(256), self.sample_ring_buffer.as_mut_ptr(), 512); }; - - let mut delta_noise = [[0; 6]; NUM_GAUSSIANS]; - let mut delta_speech = [[0; 6]; NUM_GAUSSIANS]; - let mut noise_probability = [0; NUM_GAUSSIANS]; - let mut speech_probability = [0; NUM_GAUSSIANS]; - let mut sum_log_likelihood_ratios = 0; - let mut ngprvec = [[0; 6]; NUM_GAUSSIANS]; - let mut sgprvec = [[0; 6]; NUM_GAUSSIANS]; - let mut vadflag = false; - - if self.total_power > MIN_ENERGY { - // The signal power of current frame is large enough for processing. The - // processing consists of two parts: - // 1) Calculating the likelihood of speech and thereby a VAD decision. - // 2) Updating the underlying model, w.r.t., the decision made. - - // The detection scheme is an LRT with hypothesis - // H0: Noise - // H1: Speech - // - // We combine a global LRT with local tests, for each frequency sub-band, - // here defined as |channel|. - for channel in 0..6 { - // For each channel we model the probability with a GMM consisting of - // |kNumGaussians|, with different means and standard deviations depending - // on H0 or H1. 
- let mut h0_test = 0; - let mut h1_test = 0; - for gaussian in 0..NUM_GAUSSIANS { - // Probability under H0, that is, probability of frame being noise. - // Value given in Q27 = Q7 * Q20. - let (tmp1_s32, delta) = - gaussian_probability(self.feature_vector[channel], self.noise_means[gaussian][channel], self.noise_stds[gaussian][channel]); - delta_noise[gaussian][channel] = delta; - noise_probability[gaussian] = self.model.noise_weights[gaussian][channel] as i32 * tmp1_s32; - h0_test += noise_probability[gaussian]; // Q27 - - // Probability under H1, that is, probability of frame being speech. - // Value given in Q27 = Q7 * Q20. - let (tmp1_s32, delta) = - gaussian_probability(self.feature_vector[channel], self.speech_means[gaussian][channel], self.speech_stds[gaussian][channel]); - delta_speech[gaussian][channel] = delta; - speech_probability[gaussian] = self.model.speech_weights[gaussian][channel] as i32 * tmp1_s32; - h1_test += speech_probability[gaussian]; // Q27 - } - - // Calculate the log likelihood ratio: log2(Pr{X|H1} / Pr{X|H1}). - // Approximation: - // log2(Pr{X|H1} / Pr{X|H1}) = log2(Pr{X|H1}*2^Q) - log2(Pr{X|H1}*2^Q) - // = log2(h1_test) - log2(h0_test) - // = log2(2^(31-shifts_h1)*(1+b1)) - // - log2(2^(31-shifts_h0)*(1+b0)) - // = shifts_h0 - shifts_h1 - // + log2(1+b1) - log2(1+b0) - // ~= shifts_h0 - shifts_h1 - // - // Note that b0 and b1 are values less than 1, hence, 0 <= log2(1+b0) < 1. - // Further, b0 and b1 are independent and on the average the two terms - // cancel. - let (mut shifts_h0, mut shifts_h1) = (norm_i32(h0_test), norm_i32(h1_test)); - if h0_test == 0 { - shifts_h0 = 31; - } - if h1_test == 0 { - shifts_h1 = 31; - } - let log_likelihood_ratio = shifts_h0 - shifts_h1; - - // Update |sum_log_likelihood_ratios| with spectrum weighting. This is - // used for the global VAD decision. - sum_log_likelihood_ratios += log_likelihood_ratio as i32 * SPECTRUM_WEIGHT[channel] as i32; - - // Local VAD decision. 
- if (log_likelihood_ratio as i16 * 4) > individual_test { - vadflag = true; - } - - // Calculate local noise probabilities used later when updating the GMM. - let h0 = (h0_test >> 12) as i16; - if h0 > 0 { - // High probability of noise. Assign conditional probabilities for each - // Gaussian in the GMM. - let tmp1_s32 = (noise_probability[0] & 0xFFFFF000u32 as i32) << 2; // Q29 - ngprvec[0][channel] = div_i32_i16(tmp1_s32, h0) as i16; // Q14 - for gaussian in 1..NUM_GAUSSIANS { - ngprvec[gaussian][channel] = 16384 - ngprvec[0][channel]; - } - } else { - // Low noise probability. Assign conditional probability 1 to the first - // Gaussian and 0 to the rest (which is already set at initialization). - ngprvec[0][channel] = 16384; - } - - // Calculate local speech probabilities used later when updating the GMM. - let h1 = (h1_test >> 12) as i16; - if h1 > 0 { - // High probability of speech. Assign conditional probabilities for each - // Gaussian in the GMM. Otherwise use the initialized values, i.e., 0. - let tmp1_s32 = (speech_probability[0] & 0xFFFFF000u32 as i32) << 2; // Q29 - sgprvec[0][channel] = div_i32_i16(tmp1_s32, h1) as i16; // Q14 - for gaussian in 1..NUM_GAUSSIANS { - sgprvec[gaussian][channel] = 16384 - sgprvec[0][channel]; - } - } - } - - // Make a global VAD decision. - vadflag |= sum_log_likelihood_ratios >= total_test as i32; - - // Update the model parameters. - let mut maxspe = 12800; - for channel in 0..6 { - // Get minimum value in past which is used for long term correction in Q4. - let feature_minimum = find_minimum( - &mut self.age_vector, - &mut self.low_value_vector, - self.frame_counter, - &mut self.mean_value, - self.feature_vector[channel], - channel - ); - - // Compute the "global" mean, that is the sum of the two means weighted. 
- let noise_global_mean = weighted_average(&mut self.noise_means, channel, 0, &self.model.noise_weights); - let tmp1_s16 = (noise_global_mean >> 6) as i16; // Q8 - for gaussian in 0..NUM_GAUSSIANS { - let nmk = self.noise_means[gaussian][channel]; - let smk = self.speech_means[gaussian][channel]; - let nsk = self.noise_stds[gaussian][channel]; - let ssk = self.speech_stds[gaussian][channel]; - - // Update noise mean vector if the frame consists of noise only. - let mut nmk2 = nmk; - if !vadflag { - // (Q14 * Q11 >> 11) = Q14. - let delt = ((ngprvec[gaussian][channel] as i32 * delta_noise[gaussian][channel] as i32) >> 11) as i16; - // Q7 + (Q14 * Q15 >> 22) = Q7. - nmk2 = nmk + ((delt as i32 * NOISE_UPDATE as i32) >> 22) as i16; - } - - // Long term correction of the noise mean. - // Q8 - Q8 = Q8. - let ndelt = (feature_minimum << 4) - tmp1_s16; - // Q7 + (Q8 * Q8) >> 9 = Q7. - let nmk3 = nmk2 + ((ndelt as i32 * BACK_ETA as i32) >> 9) as i16; - - // Control that the noise mean does not drift too much. - self.noise_means[gaussian][channel] = nmk3.clamp((gaussian + 5 << 7) as i16, ((72 + gaussian - channel) << 7) as i16); - - if vadflag { - // Update speech mean vector: - // |deltaS| = (x-mu)/sigma^2 - // sgprvec[k] = |speech_probability[k]| / - // (|speech_probability[0]| + |speech_probability[1]|) - - // (Q14 * Q11) >> 11 = Q14. - let delt = ((sgprvec[gaussian][channel] as i32 * delta_speech[gaussian][channel] as i32) >> 11) as i16; - // Q14 * Q15 >> 21 = Q8 - let tmp_s16 = ((delt as i32 * SPEECH_UPDATE as i32) >> 21) as i16; - // Q7 + (Q8 >> 1) = Q7 - let smk2 = smk + ((tmp_s16 + 1) >> 1); - - // Control that the max speech mean does not drift too much. 
- self.speech_means[gaussian][channel] = smk2.clamp(self.model.minimum_mean[gaussian], maxspe + 640); - - // (Q7 >> 3) = Q4 - let mut tmp_s16 = (smk + 4) >> 3; - - tmp_s16 = self.feature_vector[channel] - tmp_s16; - // (Q11 * Q4 >> 3) = Q12 - let tmp1_s32 = (delta_speech[gaussian][channel] as i32 * tmp_s16 as i32) >> 3; - let tmp2_s32 = tmp1_s32 - 4096; - tmp_s16 = (sgprvec[gaussian][channel] >> 2) as i16; - // (Q14 >> 2) * Q12 = Q24. - let tmp1_s32 = tmp_s16 as i32 * tmp2_s32; - - let tmp2_s32 = tmp1_s32 >> 4; // Q20 - - // 0.1 * Q20 / Q7 = Q13 - let mut tmp_s16 = if tmp2_s32 > 0 { div_i32_i16(tmp2_s32, ssk * 10) } else { -div_i32_i16(-tmp2_s32, ssk * 10) } as i16; - - // Divide by 4 giving an update factor of 0.025 (= 0.1 / 4). - // Note that division by 4 equals shift by 2, hence, - // (Q13 >> 8) = (Q13 >> 6) / 4 = Q7. - tmp_s16 += 128; - let ssk = (ssk + (tmp_s16 >> 8)).min(MIN_STD); - self.speech_stds[gaussian][channel] = ssk; - } else { - // Update GMM variance vectors. - // deltaN * (features[channel] - nmk) - 1 - // Q4 - (Q7 >> 3) = Q4. - let tmp_s16 = self.feature_vector[channel] - (nmk >> 3); - // (Q11 * Q4 >> 3) = Q12. - let mut tmp1_s32 = (delta_noise[gaussian][channel] as i32 * tmp_s16 as i32) >> 3; - tmp1_s32 -= 4096; - - // (Q14 >> 2) * Q12 = Q24. - let tmp_s16 = (ngprvec[gaussian][channel] + 2) >> 2; - let tmp2_s32 = (tmp_s16 as i32).saturating_mul(tmp1_s32); - // Q20 * approx 0.001 (2^-10=0.0009766), hence, - // (Q24 >> 14) = (Q24 >> 4) / 2^10 = Q20. - tmp1_s32 = tmp2_s32 >> 14; - - // Q20 / Q7 = Q13. - let mut tmp_s16 = if tmp1_s32 > 0 { div_i32_i16(tmp1_s32, nsk) } else { -div_i32_i16(-tmp1_s32, nsk) } as i16; - tmp_s16 += 32; // Rounding - let nsk = (nsk + (tmp_s16 >> 6)).min(MIN_STD); - self.noise_stds[gaussian][channel] = nsk; - } - } - - // Separate models if they are too close. - // |noise_global_mean| in Q14 (= Q7 * Q7). 
- let mut noise_global_mean = weighted_average(&mut self.noise_means, channel, 0, &self.model.noise_weights); - // |speech_global_mean| in Q14 (= Q7 * Q7). - let mut speech_global_mean = weighted_average(&mut self.speech_means, channel, 0, &self.model.speech_weights); - - // |diff| = "global" speech mean - "global" noise mean. - // (Q14 >> 9) - (Q14 >> 9) = Q5. - let diff = (speech_global_mean >> 9) as i16 - (noise_global_mean >> 9) as i16; - if diff < MINIMUM_DIFFERENCE[channel] { - let tmp_s16 = MINIMUM_DIFFERENCE[channel] - diff; - - // |tmp1_s16| = ~0.8 * (kMinimumDifference - diff) in Q7. - // |tmp2_s16| = ~0.2 * (kMinimumDifference - diff) in Q7. - let tmp1_s16 = (13 * tmp_s16) >> 2; - let tmp2_s16 = (3 * tmp_s16) >> 2; - - // Move Gaussian means for speech model by |tmp1_s16| and update - // |speech_global_mean|. Note that |self->speech_means[channel]| is - // changed after the call. - speech_global_mean = weighted_average(&mut self.speech_means, channel, tmp1_s16, &self.model.speech_weights); - - // Move Gaussian means for noise model by -|tmp2_s16| and update - // |noise_global_mean|. Note that |self->noise_means[channel]| is - // changed after the call. - noise_global_mean = weighted_average(&mut self.noise_means, channel, -tmp2_s16, &self.model.noise_weights); - } - - // Control that the speech & noise means do not drift to much. - maxspe = MAXIMUM_SPEECH[channel]; - let mut tmp2_s16 = (speech_global_mean >> 7) as i16; - if tmp2_s16 > maxspe { - // Upper limit of speech model. 
- tmp2_s16 -= maxspe; - - for gaussian in 0..NUM_GAUSSIANS { - self.speech_means[gaussian][channel] -= tmp2_s16; - } - } - - tmp2_s16 = (noise_global_mean >> 7) as i16; - if tmp2_s16 > self.model.maximum_noise[channel] { - tmp2_s16 -= self.model.maximum_noise[channel]; - for gaussian in 0..NUM_GAUSSIANS { - self.noise_means[gaussian][channel] -= tmp2_s16; - } - } - } - - self.frame_counter += 1; - } - - if !vadflag { - if self.overhang > 0 { - vadflag = true; - self.overhang -= 1; - } - self.num_of_speech = 0; - } else { - self.num_of_speech += 1; - if self.num_of_speech > MAX_SPEECH_FRAMES as i16 { - self.num_of_speech = MAX_SPEECH_FRAMES as i16; - self.overhang = overhead2; - } else { - self.overhang = overhead1; - } + for (emph, sample) in (&mut self.sample_ring_buffer[512..]).iter_mut().zip(frame.iter()) { + let sample = *sample as f32; + *emph = sample - PRE_EMPHASIS_COEFF * self.prev_signal; + self.prev_signal = sample; } - Ok(vadflag) + self.predict_inner() } - /// Run VAD prediction on a single frame of 48 KHz signed 16-bit mono PCM audio. Returns `Ok(true)` if the model - /// predicts that this frame contains speech. + /// Predicts the voice activity score of a single input frame of 32-bit floating-point PCM audio. /// - /// The frame must be 10ms (480 samples), 20ms (960 samples), or 30ms (1440 samples) in length. An `Err` is returned - /// if the frame size is invalid. + /// The frame: + /// - should be sampled at 16 KHz; + /// - should be exactly 256 samples (so 16 ms) in length; + /// - should consist only of samples in the range [-1, 1]. 
/// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityProfile}; - /// let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - /// - /// # let mut stream = std::iter::once(vec![0; 960]); - /// while let Some(frame) = stream.next() { - /// let is_speech_detected = vad.predict_48khz(&frame).unwrap(); - /// # assert_eq!(is_speech_detected, false); - /// } - /// ``` - pub fn predict_48khz(&mut self, frame: &[i16]) -> Result { - let mut out_frame = [0; 240]; - for (i, subframe) in frame.chunks_exact(480).enumerate() { - let out_chunk_size = 80 * i; - resample_48khz_to_8khz( - subframe, - &mut out_frame[out_chunk_size..out_chunk_size + 80], - &mut self.downsample_48khz_to_8khz_state, - &mut self.downsample_tmp - ); - } - self.predict_8khz(&out_frame).map_err(|e| e.invalid_frame_size_multiplier(6)) - } - - /// Run VAD prediction on a single frame of 32 KHz signed 16-bit mono PCM audio. Returns `Ok(true)` if the model - /// predicts that this frame contains speech. - /// - /// The frame must be 10ms (320 samples), 20ms (640 samples), or 30ms (960 samples) in length. An `Err` is returned - /// if the frame size is invalid. - /// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityProfile}; - /// let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - /// - /// # let mut stream = std::iter::once(vec![0; 640]); - /// while let Some(frame) = stream.next() { - /// let is_speech_detected = vad.predict_32khz(&frame).unwrap(); - /// # assert_eq!(is_speech_detected, false); - /// } - /// ``` - pub fn predict_32khz(&mut self, frame: &[i16]) -> Result { - let mut out_frame = [0; 480]; - downsample_2x(frame, &mut out_frame, &mut self.downsampling_filter_states[2..]); - self.predict_16khz(&out_frame).map_err(|e| e.invalid_frame_size_multiplier(2)) - } + /// The output score is between `[0, 1]`. 
Scores over 0.5 can generally be considered voice, but the exact threshold + /// can be adjusted according to application-specific needs. + pub fn predict_f32(&mut self, frame: &[f32]) -> f32 { + assert_eq!(frame.len(), 256); - /// Run VAD prediction on a single frame of 16 KHz signed 16-bit mono PCM audio. Returns `Ok(true)` if the model - /// predicts that this frame contains speech. - /// - /// The frame must be 10ms (160 samples), 20ms (320 samples), or 30ms (480 samples) in length. An `Err` is returned - /// if the frame size is invalid. - /// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityProfile}; - /// let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - /// - /// # let mut stream = std::iter::once(vec![0; 320]); - /// while let Some(frame) = stream.next() { - /// let is_speech_detected = vad.predict_16khz(&frame).unwrap(); - /// # assert_eq!(is_speech_detected, false); - /// } - /// ``` - pub fn predict_16khz(&mut self, frame: &[i16]) -> Result { - let mut out_frame = [0; 240]; - downsample_2x(frame, &mut out_frame, &mut self.downsampling_filter_states[0..2]); - self.predict_8khz(&out_frame).map_err(|e| e.invalid_frame_size_multiplier(2)) - } - - /// Run VAD prediction on a single frame of 8 KHz signed 16-bit mono PCM audio. Returns `Ok(true)` if the model - /// predicts that this frame contains speech. - /// - /// The frame must be 10ms (80 samples), 20ms (160 samples), or 30ms (240 samples) in length. An `Err` is returned - /// if the frame size is invalid. 
- /// - /// ``` - /// # use earshot::{VoiceActivityDetector, VoiceActivityProfile}; - /// let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - /// - /// # let mut stream = std::iter::once(vec![0; 160]); - /// while let Some(frame) = stream.next() { - /// let is_speech_detected = vad.predict_8khz(&frame).unwrap(); - /// # assert_eq!(is_speech_detected, false); - /// } - /// ``` - pub fn predict_8khz(&mut self, frame: &[i16]) -> Result { - self.total_power = calculate_features( - frame, - &mut self.feature_vector, - &mut self.split1_data, - &mut self.split2_data, - &mut self.upper_state, - &mut self.lower_state, - &mut self.hp_filter_state + // Debug-only sanity check on sample magnitudes: comparing `abs()` rejects samples below -1 as well as + // above 1, and a NaN sample fails the check instead of being silently skipped. + debug_assert!( + frame.iter().all(|sample| sample.abs() <= 1.0), + "input frame should be in the range [-1, 1]" ); - self.gmm(frame.len()) - } -} -#[cfg(test)] -mod tests { - use core::slice; - use std::fs; - - use crate::{Error, VoiceActivityDetector, VoiceActivityProfile}; - - #[test] - fn test_vad_synthetic() -> Result<(), Error> { - for profile in [ - VoiceActivityProfile::QUALITY, - VoiceActivityProfile::LBR, - VoiceActivityProfile::AGGRESSIVE, - VoiceActivityProfile::VERY_AGGRESSIVE - ] { - let mut vad = VoiceActivityDetector::new(profile.clone()); - for frame_size in [80, 160, 240] { - let speech = vec![0; frame_size]; - assert_eq!(false, vad.predict_8khz(&speech)?); - } - } - for profile in [ - VoiceActivityProfile::QUALITY, - VoiceActivityProfile::LBR, - VoiceActivityProfile::AGGRESSIVE, - VoiceActivityProfile::VERY_AGGRESSIVE - ] { - let mut vad = VoiceActivityDetector::new(profile.clone()); - for frame_size in [80, 160, 240] { - let speech = (0..frame_size as i16).map(|i| i.wrapping_mul(i)).collect::>(); - assert_eq!(true, vad.predict_8khz(&speech)?); - } - } - Ok(()) - } + /// We perform the FFT at i16 scale and scale down afterwards; doing the FFT at f32 scale ([-1, 1]) loses a lot + /// of precision.
+ const SCALE: f32 = 32768.0; - #[test] - fn test_real_8khz() -> Result<(), Error> { - let file = fs::read("tests/data/audio_tiny8.raw").unwrap(); - let samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::(), file.len() / 2) }; - let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - let mut voice_frames = 0; - for frame in samples.chunks_exact(240) { - if vad.predict_8khz(frame)? { - voice_frames += 1; - } + unsafe { + ptr::copy(self.sample_ring_buffer.as_ptr().add(256), self.sample_ring_buffer.as_mut_ptr(), 512); + }; + for (emph, sample) in (&mut self.sample_ring_buffer[512..]).iter_mut().zip(frame.iter()) { + let sample = *sample * SCALE; + *emph = sample - PRE_EMPHASIS_COEFF * self.prev_signal; + self.prev_signal = sample; } - assert!((162..169).contains(&voice_frames)); - Ok(()) + + self.predict_inner() } - #[test] - fn test_real_16khz() -> Result<(), Error> { - let file = fs::read("tests/data/audio_tiny16.raw").unwrap(); - let samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::(), file.len() / 2) }; - let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - let mut voice_frames = 0; - for frame in samples.chunks_exact(480) { - if vad.predict_16khz(frame)? 
{ - voice_frames += 1; - } + fn predict_inner(&mut self) -> f32 { + let filters = FILTERS.get_or_init(Filters::new); + + // windowize for FFT + for i in 0..WINDOW_SIZE { + self.buffer[i] = self.sample_ring_buffer[i] * filters.window[i]; } - assert!((162..169).contains(&voice_frames)); - Ok(()) - } + // FFT size is 1024 but window size is 768, so fill the rest with zeros (+2 to store nyquist frequency) + unsafe { + ptr::write_bytes(self.buffer.as_mut_ptr().add(WINDOW_SIZE), 0, 256 + 2); + }; - #[test] - fn test_real_32khz() -> Result<(), Error> { - let file = fs::read("tests/data/audio_tiny32.raw").unwrap(); - let samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::(), file.len() / 2) }; - let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - let mut voice_frames = 0; - for frame in samples.chunks_exact(960) { - if vad.predict_32khz(frame)? { - voice_frames += 1; - } + fft::rfft_1024(&mut self.buffer); + for i in 0..N_BINS { + let j = i * 2; + self.buffer[i] = fft::Complex32::new(self.buffer[j], self.buffer[j + 1]).norm_sqr() + // downscale from i16 scale + * POWER_FAC; } - assert!((162..169).contains(&voice_frames)); - Ok(()) - } - #[test] - fn test_real_48khz() -> Result<(), Error> { - let file = fs::read("tests/data/audio_tiny48.raw").unwrap(); - let samples = unsafe { slice::from_raw_parts(file.as_ptr().cast::(), file.len() / 2) }; - let mut vad = VoiceActivityDetector::new(VoiceActivityProfile::VERY_AGGRESSIVE); - let mut voice_frames = 0; - for frame in samples.chunks_exact(1440) { - if vad.predict_48khz(frame)? 
{ - voice_frames += 1; + unsafe { + ptr::copy(self.features.as_ptr().add(N_FEATURES), self.features.as_mut_ptr(), N_FEATURES * (N_CONTEXT_FRAMES - 1)); + }; + let cur_frame_features = &mut self.features[(N_FEATURES * (N_CONTEXT_FRAMES - 1))..]; + for i in 0..N_MELS { + let mut per_band_value = 0.; + for j in 0..N_BINS { + per_band_value += self.buffer[j] * filters.mel_coeffs[(i * N_BINS) + j]; } + + per_band_value = libm::logf(per_band_value + 1e-20); + cur_frame_features[i] = (per_band_value - FEATURE_MEANS[i]) / FEATURE_STDS[i]; } - assert!((162..169).contains(&voice_frames)); - Ok(()) + + self.predictor.predict(&self.features, &mut self.buffer) } } diff --git a/src/quantized-model.bin b/src/quantized-model.bin new file mode 100644 index 0000000..b6b77fe Binary files /dev/null and b/src/quantized-model.bin differ diff --git a/src/quantized_predictor.rs b/src/quantized_predictor.rs new file mode 100644 index 0000000..010d69b --- /dev/null +++ b/src/quantized_predictor.rs @@ -0,0 +1,489 @@ +use alloc::{boxed::Box, vec}; +use core::{mem, slice}; + +use super::{Predictor, util::OnceLock}; + +struct BitBufferReader<'d> { + pub buf: &'d [u8], + idx: usize, + bit_buffer: u32, + n_bits: u32 +} + +impl<'d> BitBufferReader<'d> { + pub fn new(buffer: &'d [u8]) -> Self { + Self { + buf: buffer, + idx: 0, + bit_buffer: 0, + n_bits: 0 + } + } + + pub fn read(&mut self, len: u32) -> i32 { + while self.n_bits < len { + let byte = self.buf[self.idx]; + self.idx += 1; + + self.bit_buffer |= (byte as u32) << self.n_bits; + self.n_bits += 8; + } + + let bits = self.bit_buffer & ((1 << len) - 1); + self.bit_buffer >>= len; + self.n_bits -= len; + + let sign = (bits & (1 << (len - 1))) != 0; + (if sign { -1 << len } else { 0 }) | bits as i32 + } + + pub fn read_array(&mut self, bit_len: u32, cnt: usize) -> Box<[T]> { + (0..cnt).map(|_| T::from_i32(self.read(bit_len))).collect() + } +} + +trait FromI32 { + fn from_i32(x: i32) -> Self; +} + +impl FromI32 for i8 { + fn from_i32(x: i32) 
-> Self { + x as i8 + } +} +impl FromI32 for i16 { + fn from_i32(x: i32) -> Self { + x as i16 + } +} +impl FromI32 for i32 { + fn from_i32(x: i32) -> Self { + x + } +} + +pub struct PackedWeights { + layer1_kernel: Box<[i16]>, + layer1_weight: Box<[i16]>, + layer1_bias: Box<[i16]>, + layer2_kernel: Box<[i16]>, + layer2_weight: Box<[i16]>, + layer2_bias: Box<[i16]>, + layer3_kernel: Box<[i16]>, + layer3_weight: Box<[i16]>, + layer3_bias: Box<[i16]>, + lstm1_ih: Box<[i16]>, + lstm1_hh: Box<[i16]>, + lstm1_bias: Box<[i16]>, + lstm2_ih: Box<[i16]>, + lstm2_hh: Box<[i16]>, + lstm2_bias: Box<[i16]>, + out1_weight: Box<[i16]>, + out1_bias: Box<[i16]>, + out2_weight: Box<[i16]>, + out2_bias: i8 +} + +impl PackedWeights { + pub fn new(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), 135783, "invalid length for packed QuantizedPredictor weights"); + let mut reader = BitBufferReader::new(bytes); + Self { + layer1_kernel: reader.read_array(14, 9), + layer1_weight: reader.read_array(14, 16), + layer1_bias: reader.read_array(12, 16), + layer2_kernel: reader.read_array(15, 48), + layer2_weight: reader.read_array(16, 256), + layer2_bias: reader.read_array(14, 16), + layer3_kernel: reader.read_array(14, 48), + layer3_weight: reader.read_array(15, 256), + layer3_bias: reader.read_array(12, 16), + lstm1_ih: reader.read_array(15, 20480), + lstm1_hh: reader.read_array(14, 16384), + lstm1_bias: reader.read_array(12, 256), + lstm2_ih: reader.read_array(15, 16384), + lstm2_hh: reader.read_array(14, 16384), + lstm2_bias: reader.read_array(12, 256), + out1_weight: reader.read_array(14, 4096), + out1_bias: reader.read_array(11, 32), + out2_weight: reader.read_array(13, 32), + out2_bias: reader.read(4) as i8 + } + } +} + +#[cfg(feature = "embed-weights")] +static DEFAULT_WEIGHT_BYTES: &[u8] = include_bytes!("quantized-model.bin"); +#[cfg(feature = "embed-weights")] +static DEFAULT_WEIGHTS: OnceLock = OnceLock::new(); + +#[cfg(feature = "embed-weights")] +pub fn default_weights() -> &'static 
PackedWeights { + DEFAULT_WEIGHTS.get_or_init(|| PackedWeights::new(DEFAULT_WEIGHT_BYTES)) +} + +pub struct ActivationTables { + sigmoid: Box<[i32]>, + tanh: Box<[i32]> +} + +impl ActivationTables { + const Q11_SCALE: i32 = 2048; // 2 ** 11 + const Q11_SCALE_FLOAT: f32 = 2048.; + const SIGMOID_MAX: i32 = Self::Q11_SCALE * 6; // sigmoid goes asymptotic < -6 or > 6, so limit computation to between these values + const TANH_MAX: i32 = Self::Q11_SCALE * 4; // ^ 4 for tanh + pub const OUT_SCALE: f32 = 65536.; // 2 ** 16, outputs in Q16 + + pub fn new() -> Self { + let sigmoid_len = Self::SIGMOID_MAX * 2 + 1; + let mut sigmoid_table = vec![0; sigmoid_len as usize].into_boxed_slice(); + for i in 0..sigmoid_len { + let v = Self::_real_sigmoid((i - (Self::SIGMOID_MAX)) as f32 / Self::Q11_SCALE_FLOAT); + sigmoid_table[i as usize] = libm::roundevenf(v * Self::OUT_SCALE) as i32; + } + let tanh_len = Self::TANH_MAX * 2 + 1; + let mut tanh_table = vec![0; tanh_len as usize].into_boxed_slice(); + for i in 0..tanh_len { + let v = libm::tanhf((i - (Self::TANH_MAX)) as f32 / Self::Q11_SCALE_FLOAT); + tanh_table[i as usize] = libm::roundevenf(v * Self::OUT_SCALE) as i32; + } + + Self { + sigmoid: sigmoid_table, + tanh: tanh_table + } + } + + #[inline] + fn _real_sigmoid(x: f32) -> f32 { + 1. / (1. 
+ libm::expf(-x)) + } + + #[inline] + pub fn sigmoid(&self, x: i32) -> i32 { + unsafe { + *self + .sigmoid + .get_unchecked((x + Self::SIGMOID_MAX).clamp(0, Self::SIGMOID_MAX * 2) as usize) + } + } + #[inline] + pub fn tanh(&self, x: i32) -> i32 { + unsafe { *self.tanh.get_unchecked((x + Self::TANH_MAX).clamp(0, Self::TANH_MAX * 2) as usize) } + } +} + +static ACTIVATION_TABLES: OnceLock = OnceLock::new(); + +pub struct QuantizedPredictor<'w> { + weights: &'w PackedWeights, + state: Box<[i32]> +} + +impl<'w> QuantizedPredictor<'w> { + pub fn new(weights: &'w PackedWeights) -> Self { + Self { + weights, + state: vec![0; 256].into_boxed_slice() + } + } +} + +#[cfg(feature = "embed-weights")] +impl Default for QuantizedPredictor<'static> { + fn default() -> Self { + Self::new(default_weights()) + } +} + +impl Predictor for QuantizedPredictor<'_> { + fn reset(&mut self) { + self.state.fill(0); + } + + fn predict(&mut self, features: &[f32], buffer: &mut [f32]) -> f32 { + assert_eq!(features.len(), 41 * 3); + assert!(buffer.len() > 464); + + let buffer = unsafe { mem::transmute::<&mut [f32], &mut [i32]>(buffer) }; + + let buffer_ptr = buffer.as_mut_ptr(); + input_layer1(features, &self.weights.layer1_kernel, &self.weights.layer1_weight, &self.weights.layer1_bias, &mut buffer[..304]); + input_layer2(&buffer[..304], &self.weights.layer2_kernel, &self.weights.layer2_weight, &self.weights.layer2_bias, unsafe { + slice::from_raw_parts_mut(buffer_ptr.add(304), 160) + }); + input_layer3(&buffer[304..], &self.weights.layer3_kernel, &self.weights.layer3_weight, &self.weights.layer3_bias, unsafe { + slice::from_raw_parts_mut(buffer_ptr, 80) + }); + lstm::<80, { 80 * 256 }>( + &buffer[..80], + &self.state[..64], + &self.state[64..128], + &self.weights.lstm1_ih, + &self.weights.lstm1_hh, + &self.weights.lstm1_bias, + unsafe { slice::from_raw_parts_mut(buffer_ptr.add(80), 256) } + ); + self.state[..128].copy_from_slice(&buffer[80..208]); + lstm::<64, { 64 * 256 }>( + 
&self.state[..128], + &self.state[128..192], + &self.state[192..], + &self.weights.lstm2_ih, + &self.weights.lstm2_hh, + &self.weights.lstm2_bias, + &mut buffer[..256] + ); + self.state[128..].copy_from_slice(&buffer[..128]); + output( + &self.state[..128], + &self.state[128..], + &self.weights.out1_weight, + &self.weights.out1_bias, + &self.weights.out2_weight, + self.weights.out2_bias + ) + } +} + +#[inline(never)] +fn input_layer1(features: &[f32], kernel: &[i16], weight: &[i16], bias: &[i16], output: &mut [i32]) { + const NUM_FRAMES: usize = 3; + const NUM_FEATURES: usize = 41; + const FEATURES_INPUT: usize = const { NUM_FRAMES * NUM_FEATURES }; + + const KERNEL_SIZE: usize = 3; + const { + assert!((NUM_FRAMES - KERNEL_SIZE) / 1 + 1 == 1); + }; + const DEPTHWISE_NUM_FEATURES: usize = (NUM_FEATURES - KERNEL_SIZE) / 1 + 1; + const OUT_CHANNELS: usize = 16; + + const POOL_KERNEL_SIZE: usize = 3; + const POOL_STRIDE: usize = 2; + const SCALE_FACTOR: f32 = (1 << 16) as f32; + const POOLED_COLS: usize = (DEPTHWISE_NUM_FEATURES - POOL_KERNEL_SIZE) / POOL_STRIDE + 1; + + output.fill(0); + + assert_eq!(features.len(), FEATURES_INPUT); + + let mut tmp = [0i32; FEATURES_INPUT]; + // doing this conversion in the convolution loop kills performance + for i in 0..FEATURES_INPUT { + unsafe { + // convert to Q16 + *tmp.get_unchecked_mut(i) = libm::floorf(*features.get_unchecked(i) * SCALE_FACTOR) as i32; + }; + } + + let mut row = [0; DEPTHWISE_NUM_FEATURES]; + for c in 0..OUT_CHANNELS { + for ox in 0..DEPTHWISE_NUM_FEATURES { + // depthwise conv + let mut sum = 0; + for kh in 0..KERNEL_SIZE { + for kw in 0..KERNEL_SIZE { + let w = ox + kw; + let input_idx = (kh * NUM_FEATURES) + w; + unsafe { + // Q16 * Q13 = Q29 + sum += *tmp.get_unchecked(input_idx) as i64 * *kernel.get_unchecked((kh * KERNEL_SIZE) + kw) as i64; + } + } + } + + // pointwise conv + unsafe { + // Q29 * Q13 = Q42. 
bias is Q12 so shift left by 42-12=30 + let x = (sum * *weight.get_unchecked(c) as i64) + ((*bias.get_unchecked(c) as i64) << 30); + // shift down to Q16 + *row.get_unchecked_mut(ox) = (x >> 26) as i32; + } + } + + // max pool over row + let out_row_offs = POOLED_COLS * c; + for q in 0..POOLED_COLS { + for x in 0..POOL_KERNEL_SIZE { + let out_q = unsafe { output.as_mut_ptr().add(out_row_offs + q) }; + // `output` is initially zeroed, so this also acts as ReLU + unsafe { *out_q = (*out_q).max(*row.get_unchecked((q * POOL_STRIDE) + x)) }; + } + } + } +} + +#[inline(never)] +fn input_layer2(features: &[i32], kernel: &[i16], weight: &[i16], bias: &[i16], output: &mut [i32]) { + const HORIZONTAL_KERNEL_SIZE: usize = 3; + const STRIDE: usize = 2; + const CHANNELS: usize = 16; + + const IN_FEATURES: usize = 19; + const OUT_FEATURES: usize = 10; + + output.fill(0); + + for ox in 0..OUT_FEATURES { + let mut row = [0; CHANNELS]; + for c in 0..CHANNELS { + // depthwise conv + let mut sum = 0; + for kw in 0..HORIZONTAL_KERNEL_SIZE { + let ix = (ox * STRIDE + kw) as isize - 1; + if ix < 0 || ix >= IN_FEATURES as isize { + continue; + } + + // Q16 * Q13 = Q29 + unsafe { + sum += *features.get_unchecked((c * IN_FEATURES) + ix as usize) as i64 * *kernel.get_unchecked((c * HORIZONTAL_KERNEL_SIZE) + kw) as i64; + } + } + + // pointwise conv + for oc in 0..CHANNELS { + unsafe { + // Q29 * Q13 = Q42 + let r = sum * *weight.get_unchecked((oc * CHANNELS) + c) as i64; + *row.get_unchecked_mut(oc) += r; + } + } + } + + // apply pointwise conv bias + relu + for oc in 0..CHANNELS { + unsafe { + // bias is Q12 so shift left by 42-12=30 + let br = *row.get_unchecked(oc) + ((*bias.get_unchecked(oc) as i64) << 30); + // shift down to Q16 + *output.get_unchecked_mut((oc * OUT_FEATURES) + ox) = ((br >> 26) as i32).max(0); + } + } + } +} + +#[inline(never)] +fn input_layer3(features: &[i32], kernel: &[i16], weight: &[i16], bias: &[i16], output: &mut [i32]) { + const HORIZONTAL_KERNEL_SIZE: usize = 
3; + const STRIDE: usize = 2; + const CHANNELS: usize = 16; + + const IN_FEATURES: usize = 10; + const OUT_FEATURES: usize = 5; + + output.fill(0); + + for ox in 0..OUT_FEATURES { + let mut row = [0; CHANNELS]; + for c in 0..CHANNELS { + // depthwise conv + let mut sum = 0i64; + for kw in 0..HORIZONTAL_KERNEL_SIZE { + let ix = ox * STRIDE + kw; // layer 3 does not use left padding + if ix >= IN_FEATURES { + continue; + } + unsafe { + // Q16 * Q13 = Q29 + sum += *features.get_unchecked((c * IN_FEATURES) + ix as usize) as i64 * *kernel.get_unchecked((c * HORIZONTAL_KERNEL_SIZE) + kw) as i64; + } + } + + // pointwise conv + for oc in 0..CHANNELS { + unsafe { + // Q29 * Q13 = Q42 + *row.get_unchecked_mut(oc) += sum * *weight.get_unchecked((oc * CHANNELS) + c) as i64; + } + } + } + + // apply pointwise conv bias + relu + for oc in 0..CHANNELS { + unsafe { + // bias is Q12 so shift left by 42-12=30 + let r = *row.get_unchecked_mut(oc) + ((*bias.get_unchecked(oc) as i64) << 30); + let ptr = output.get_unchecked_mut((ox * CHANNELS) + oc); + // shift down to Q16 + *ptr = (r >> 26).max(0) as i32; + } + } + } +} + +#[inline(never)] +fn lstm(features: &[i32], h: &[i32], c: &[i32], weight_ih: &[i16], weight_hh: &[i16], bias: &[i16], out: &mut [i32]) { + for d in 0..256 { + // init with Q10 bias, shifted left by 18 to get 10+18=Q28 + let mut o = (unsafe { *bias.get_unchecked(d) } as i64) << 18; + let (ri, rh) = (d * IN_DIM, d * 64); + + for f in 0..IN_DIM { + unsafe { + // Q16 * Q12 = Q28 + o += *features.get_unchecked(f) as i64 * *weight_ih.get_unchecked(ri + f) as i64; + }; + } + + for g in 0..64 { + unsafe { + // Q16 * Q12 = Q28 + o += *h.get_unchecked(g) as i64 * *weight_hh.get_unchecked(rh + g) as i64; + } + } + + unsafe { + // shift down from Q28 to Q11 + *out.get_unchecked_mut(d) = (o >> 17) as i32; + }; + } + + let act = ACTIVATION_TABLES.get_or_init(ActivationTables::new); + for i in 0..64 { + unsafe { + // layout is [input, output, forget, cell] + let ix = 
act.sigmoid(*out.get_unchecked(i)) as i64; + let fx = act.sigmoid(*out.get_unchecked(128 + i)) as i64; + let cx = act.tanh(*out.get_unchecked(192 + i)) as i64; + // Q16 * Q16 = Q32; Q16 * Q16 = Q32 + let x = (fx * *c.get_unchecked(i) as i64) + (ix * cx); + let xt = act.tanh((x >> 21) as i32) as i64; // Q11 in, Q16 out + // arrange outputs as [hidden, cell] + let o = act.sigmoid(mem::replace(out.get_unchecked_mut(64 + i), (x >> 16) as i32)) as i64; + // Q16 * Q16 = Q32, shift down to Q16 + *out.get_unchecked_mut(i) = ((o * xt) >> 16) as i32; + }; + } +} + +#[inline(never)] +fn output(out_1: &[i32], out_2: &[i32], weight_1: &[i16], bias_1: &[i16], weight_2: &[i16], bias_2: i8) -> f32 { + let mut temp = [0; 32]; + for h in 0..64 { + for f in 0..32 { + unsafe { + // Q16 * Q12 = Q28, shift down to Q19 + let mut o = *out_2.get_unchecked(h) as i64 * *weight_1.get_unchecked(h * 32 + f) as i64; + o += *out_1.get_unchecked(h) as i64 * *weight_1.get_unchecked((h + 64) * 32 + f) as i64; + *temp.get_unchecked_mut(f) += (o >> 9) as i32; + } + } + } + + let mut out = 0; + for f in 0..32 { + unsafe { + // bias is Q10 so shift left by 19-10=9 + let q = *temp.get_unchecked(f) as i64 + ((*bias_1.get_unchecked(f) as i64) << 9); + // Q19 * Q13 = Q32 + out += q.max(0) * *weight_2.get_unchecked(f) as i64; + } + } + // bias is Q9 so shift left by 32-9=23 + out += (bias_2 as i64) << 23; + // shift down to Q11 + out >>= 21; + ACTIVATION_TABLES.get_or_init(ActivationTables::new).sigmoid(out as i32) as f32 / ActivationTables::OUT_SCALE +} diff --git a/src/resample.rs b/src/resample.rs deleted file mode 100644 index a192cce..0000000 --- a/src/resample.rs +++ /dev/null @@ -1,366 +0,0 @@ -use core::{ptr, slice}; - -// Resampling ratio: 2/3 -// input: i32 (normalized, not saturated) :: size (3 * K) + 6 -// output: i32 (shifted 15 positions to the left, + offset 16384) :: size 2 * K -// K: number of blocks -fn resample_48khz_to_32khz(inv: &[i32], out: &mut [i32]) { - const COEFFICIENTS: [i32; 8] = 
[778, -2050, 1087, 23285, 12903, -3783, 441, 222]; - - let block_size = out.len() / 2; - assert!(inv.len() >= (block_size * 3) + 6); - - let mut ip = 0; - let mut op = 0; - for _ in 0..block_size { - let mut tmp = 1 << 14; - tmp += COEFFICIENTS[0] * inv[ip]; - tmp += COEFFICIENTS[1] * inv[ip + 1]; - tmp += COEFFICIENTS[2] * inv[ip + 2]; - tmp += COEFFICIENTS[3] * inv[ip + 3]; - tmp += COEFFICIENTS[4] * inv[ip + 4]; - tmp += COEFFICIENTS[5] * inv[ip + 5]; - tmp += COEFFICIENTS[6] * inv[ip + 6]; - tmp += COEFFICIENTS[7] * inv[ip + 7]; - out[op] = tmp; - - tmp = 1 << 14; - tmp += COEFFICIENTS[7] * inv[ip + 1]; - tmp += COEFFICIENTS[6] * inv[ip + 2]; - tmp += COEFFICIENTS[5] * inv[ip + 3]; - tmp += COEFFICIENTS[4] * inv[ip + 4]; - tmp += COEFFICIENTS[3] * inv[ip + 5]; - tmp += COEFFICIENTS[2] * inv[ip + 6]; - tmp += COEFFICIENTS[1] * inv[ip + 7]; - tmp += COEFFICIENTS[0] * inv[ip + 8]; - out[op + 1] = tmp; - - ip += 3; - op += 2; - } -} - -const RESAMPLE_ALLPASS_COEFFS: [i32; 6] = [3050, 9368, 15063, 821, 6110, 12382]; - -fn down_x2_i32_i16(inv: &mut [i32], out: &mut [i16], state: &mut [i32]) { - let len = inv.len() >> 1; - // lower allpass filter (operates on even input samples) - for i in 0..len { - let tmp0 = inv[i << 1]; - let mut diff = tmp0 - state[1]; - - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[0] + diff * RESAMPLE_ALLPASS_COEFFS[0]; - state[0] = tmp0; - diff = tmp1 - state[2]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[1] + diff * RESAMPLE_ALLPASS_COEFFS[1]; - state[1] = tmp1; - diff = tmp0 - state[3]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[3] = state[2] + diff * RESAMPLE_ALLPASS_COEFFS[2]; - state[2] = tmp0; - - // divide by two and store temporarily - inv[i << 1] = state[3] >> 1; - } - - // upper allpass filter (operates on odd input samples) - let inv2 = &mut inv[1..]; - for i in 0..len { - let tmp0 = inv2[i << 1]; - 
let mut diff = tmp0 - state[5]; - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[4] + diff * RESAMPLE_ALLPASS_COEFFS[3]; - state[4] = tmp0; - diff = tmp1 - state[6]; - - // scale down and round - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[5] + diff * RESAMPLE_ALLPASS_COEFFS[4]; - state[5] = tmp1; - diff = tmp0 - state[7]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[7] = state[6] + diff * RESAMPLE_ALLPASS_COEFFS[5]; - state[6] = tmp0; - - // divide by two and store temporarily - inv2[i << 1] = state[7] >> 1; - } - - // combine allpass outputs - for i in (0..len).step_by(2) { - // divide by two, add both allpass outputs and round - let tmp0 = (inv[i << 1] + inv[(i << 1) + 1]) >> 15; - let tmp1 = (inv[(i << 1) + 2] + inv[(i << 1) + 3]) >> 15; - out[i] = tmp0.clamp(i16::MIN as i32, i16::MAX as i32) as i16; - out[i + 1] = tmp1.clamp(i16::MIN as i32, i16::MAX as i32) as i16; - } -} - -fn down_x2_i16_i32(inv: &[i16], out: &mut [i32], state: &mut [i32]) { - let len = inv.len() >> 1; - - // lower allpass filter (operates on even input samples) - for i in 0..len { - let tmp0 = ((inv[1 << 1] as i32) << 15) + (1 << 14); - let mut diff = tmp0 - state[1]; - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[0] + diff * RESAMPLE_ALLPASS_COEFFS[0]; - state[0] = tmp0; - diff = tmp1 - state[2]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[1] + diff * RESAMPLE_ALLPASS_COEFFS[1]; - state[1] = tmp1; - diff = tmp0 - state[3]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[3] = state[2] + diff * RESAMPLE_ALLPASS_COEFFS[2]; - state[2] = tmp0; - - out[i] = state[3] >> 1; - } - - // upper allpass filter (operates on odd input samples) - let inv2 = &inv[1..]; - for i in 0..len { - let tmp0 = ((inv2[i << 1] as i32) << 15) + (1 << 14); - let mut diff = tmp0 - state[5]; - // scale 
down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[4] + diff * RESAMPLE_ALLPASS_COEFFS[3]; - state[4] = tmp0; - diff = tmp1 - state[6]; - - // scale down and round - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[5] + diff * RESAMPLE_ALLPASS_COEFFS[4]; - state[5] = tmp1; - diff = tmp0 - state[7]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[7] = state[6] + diff * RESAMPLE_ALLPASS_COEFFS[5]; - state[6] = tmp0; - - out[i] += state[7] >> 1; - } -} - -fn lowpass_2x_i32_i32(inv: &[i32], out: &mut [i32], state: &mut [i32]) { - let len = inv.len() >> 1; - - let inv1 = &inv[1..]; - let mut v0 = state[12]; - for i in 0..len { - let mut diff = v0 - state[1]; - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[0] + diff * RESAMPLE_ALLPASS_COEFFS[0]; - state[0] = v0; - diff = tmp1 - state[2]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[1] + diff * RESAMPLE_ALLPASS_COEFFS[1]; - state[1] = tmp1; - diff = tmp0 - state[3]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[3] = state[2] + diff * RESAMPLE_ALLPASS_COEFFS[2]; - state[2] = tmp0; - - // scale down, round and store - out[i << 1] = state[3] >> 1; - v0 = inv1[i << 1]; - } - - // upper allpass filter: even input -> even output samples - for i in 0..len { - let tmp0 = inv[i << 1]; - let mut diff = tmp0 - state[5]; - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[4] + diff * RESAMPLE_ALLPASS_COEFFS[3]; - state[4] = tmp0; - diff = tmp1 - state[6]; - - // scale down and round - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[5] + diff * RESAMPLE_ALLPASS_COEFFS[4]; - state[5] = tmp1; - diff = tmp0 - state[7]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[7] = state[6] + diff * RESAMPLE_ALLPASS_COEFFS[5]; - state[6] = tmp0; - - // average the two 
allpass outputs, scale down and store - out[i << 1] = (out[i << 1] + (state[7] >> 1)) >> 15; - } - - // lower allpass filter: even input -> odd output samples - let out = &mut out[1..]; - for i in 0..len { - let tmp0 = inv[i << 1]; - let mut diff = tmp0 - state[9]; - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[8] + diff * RESAMPLE_ALLPASS_COEFFS[0]; - state[8] = tmp0; - diff = tmp1 - state[10]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[9] + diff * RESAMPLE_ALLPASS_COEFFS[1]; - state[9] = tmp1; - diff = tmp0 - state[11]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[11] = state[10] + diff * RESAMPLE_ALLPASS_COEFFS[2]; - state[10] = tmp0; - - // scale down, round and store - out[i << 1] = state[11] >> 1; - } - - // upper allpass filter: odd input -> odd output samples - let inv = &inv[1..]; - for i in 0..len { - let tmp0 = inv[i << 1]; - let mut diff = tmp0 - state[13]; - // scale down and round - diff = (diff + (1 << 13)) >> 14; - let tmp1 = state[12] + diff * RESAMPLE_ALLPASS_COEFFS[3]; - state[12] = tmp0; - diff = tmp1 - state[14]; - - // scale down and round - diff >>= 14; - if diff < 0 { - diff += 1; - } - let tmp0 = state[13] + diff * RESAMPLE_ALLPASS_COEFFS[4]; - state[13] = tmp1; - diff = tmp0 - state[15]; - - // scale down and truncate - diff >>= 14; - if diff < 0 { - diff += 1; - } - state[15] = state[14] + diff * RESAMPLE_ALLPASS_COEFFS[5]; - state[14] = tmp0; - - // average the two allpass outputs, scale down and store - out[i << 1] = (out[i << 1] + (state[15] >> 1)) >> 15; - } -} - -pub fn resample_48khz_to_8khz(inv: &[i16], out: &mut [i16], states: &mut [i32], tmp: &mut [i32]) { - assert_eq!(tmp.len(), 480 + 256); - assert_eq!(inv.len(), 480); - - down_x2_i16_i32(inv, &mut tmp[256..], &mut states[0..8]); - lowpass_2x_i32_i32(&tmp[256..256 + 240], unsafe { slice::from_raw_parts_mut(tmp.as_ptr().cast_mut().add(16), 240) }, 
&mut states[8..24]); - // uses states[24..32] as a buffer to store the last 8 samples to put at the beginning of the next frame, because - // `resample_48khz_to_32khz` wants to read a couple samples beyond the frame - unsafe { - ptr::copy_nonoverlapping(states.as_ptr().add(24), tmp.as_mut_ptr().add(8), 8); - ptr::copy_nonoverlapping(tmp.as_ptr().add(248), states.as_mut_ptr().add(24), 8); - }; - // obviously horrible in many ways - resample_48khz_to_32khz(&tmp[8..256], unsafe { slice::from_raw_parts_mut(tmp.as_ptr().cast_mut(), 160) }); - down_x2_i32_i16(&mut tmp[..160], out, &mut states[32..40]); -} - -#[cfg(test)] -mod tests { - use super::resample_48khz_to_32khz; - - #[test] - fn test_resample48_saturated() { - #[rustfmt::skip] - let vector_saturated = [ - -32768, -32768, -32768, -32768, -32768, -32768, -32768, -32768, - -32768, -32768, -32768, -32768, -32768, -32768, -32768, -32768, - -32768, -32768, -32768, -32768, -32768, -32768, -32768, -32768, - 32767, 32767, 32767, 32767, 32767, 32767, 32767, 32767, - 32767, 32767, 32767, 32767, 32767, 32767, 32767, 32767, - 32767, 32767, 32767, 32767, 32767, 32767, 32767, 32767, - 32767, 32767, 32767, 32767, 32767, 32767 - ]; - - let ref1 = -1077493760; - let ref2 = 1077493645; - - let mut out_vector = [0; 2 * 16]; - - resample_48khz_to_32khz(&vector_saturated, &mut out_vector); - - // values at position 12-15 are skipped to account for the filter lag. - for i in 0..12 { - assert_eq!(out_vector[i], ref1); - } - for i in 16..32 { - assert_eq!(out_vector[i], ref2); - } - } -} diff --git a/src/sp.rs b/src/sp.rs deleted file mode 100644 index a600a81..0000000 --- a/src/sp.rs +++ /dev/null @@ -1,163 +0,0 @@ -// Downsampling filter based on splitting filter and allpass functions. 
-pub fn downsample_2x(inv: &[i16], out: &mut [i16], filter_state: &mut [i32]) { - const UPPER_COEFFICIENT: i32 = 5243; // 0.64 in Q13 - const LOWER_COEFFICIENT: i32 = 1392; // 0.17 in Q13 - - let half_len = inv.len() >> 1; - let mut tmpd1 = filter_state[0]; - let mut tmpd2 = filter_state[1]; - // Filter coefficients in Q13, filter state in Q0. - for n in 0..half_len { - // All-pass filtering upper branch. - let tmpw1 = ((tmpd1 >> 1) + ((UPPER_COEFFICIENT * inv[n * 2] as i32) >> 14)) as i16; - out[n] = tmpw1; - tmpd1 = (inv[n * 2] as i32) - ((UPPER_COEFFICIENT * tmpw1 as i32) >> 12); - - // All-pass filtering lower branch. - let tmpw2 = ((tmpd2 >> 1) + ((LOWER_COEFFICIENT * inv[n * 2 + 1] as i32) >> 14)) as i16; - out[n] = out[n].saturating_add(tmpw2); // C source originally wrapped but that introduced popping - tmpd2 = (inv[n * 2 + 1] as i32) - ((LOWER_COEFFICIENT * tmpw2 as i32) >> 12); - } - filter_state[0] = tmpd1; - filter_state[1] = tmpd2; -} - -// Inserts |feature_value| into |low_value_vector|, if it is one of the 16 -// smallest values the last 100 frames. Then calculates and returns the median -// of the five smallest values. -pub fn find_minimum(age: &mut [i16], low_value: &mut [i16], frame_counter: usize, mean_values: &mut [i16], feature_value: i16, channel: usize) -> i16 { - let offset = channel << 4; - let age = &mut age[offset..offset + 16]; - let smallest_values = &mut low_value[offset..offset + 16]; - - // Each value in |smallest_values| is getting 1 loop older. Update |age|, and - // remove old values. - for i in 0..16 { - if age[i] != 100 { - age[i] += 1; - } else { - // Too old value. Remove from memory and shift larger values downwards. - for j in i..15 { - smallest_values[j] = smallest_values[j + 1]; - age[j] = age[j + 1]; - } - age[15] = 101; - smallest_values[15] = 10000; - } - } - - // Check if |feature_value| is smaller than any of the values in - // |smallest_values|. 
If so, find the |position| where to insert the new value - // (|feature_value|). - let position = if feature_value < smallest_values[7] { - if feature_value < smallest_values[3] { - if feature_value < smallest_values[1] { - if feature_value < smallest_values[0] { 0 } else { 1 } - } else if feature_value < smallest_values[2] { - 2 - } else { - 3 - } - } else if feature_value < smallest_values[5] { - if feature_value < smallest_values[4] { 4 } else { 5 } - } else if feature_value < smallest_values[6] { - 6 - } else { - 7 - } - } else if feature_value < smallest_values[15] { - if feature_value < smallest_values[11] { - if feature_value < smallest_values[9] { - if feature_value < smallest_values[8] { 8 } else { 9 } - } else if feature_value < smallest_values[10] { - 10 - } else { - 11 - } - } else if feature_value < smallest_values[13] { - if feature_value < smallest_values[12] { 12 } else { 13 } - } else if feature_value < smallest_values[14] { - 14 - } else { - 15 - } - } else { - -1 - }; - - // If we have detected a new small value, insert it at the correct position - // and shift larger values up. - if position > -1 { - let position = position as usize; - let mut i = 15; - while i > position { - smallest_values[i] = smallest_values[i - 1]; - age[i] = age[i - 1]; - i -= 1; - } - smallest_values[position] = feature_value; - age[position] = 1; - } - - let current_median = if frame_counter > 2 { - smallest_values[2] - } else if frame_counter > 0 { - smallest_values[0] - } else { - 1600 - }; - - // Smooth the median value. 
- const SMOOTHING_DOWN: i16 = 6553; // 0.2 in Q15 - const SMOOTHING_UP: i16 = 32439; // 0.99 in Q15 - let alpha = if frame_counter > 0 { - if current_median < mean_values[channel] { SMOOTHING_DOWN } else { SMOOTHING_UP } - } else { - 0 - }; - - let smoothed = ((alpha + 1) as i32 * mean_values[channel] as i32) + ((i16::MAX - alpha) as i32 * current_median as i32) + 16384; - mean_values[channel] = (smoothed >> 15) as i16; - mean_values[channel] -} - -#[cfg(test)] -mod tests { - use super::downsample_2x; - use crate::sp::find_minimum; - - #[test] - fn test_downsample() { - let zeros = vec![0; 960]; - let mut out = vec![0; 960 / 2]; - let mut filter_state = [0, 0]; - downsample_2x(&zeros, &mut out, &mut filter_state); - assert_eq!(filter_state[0], 0); - assert_eq!(filter_state[1], 0); - - let inv = (0..960i16).map(|c| c.wrapping_mul(c)).collect::>(); - downsample_2x(&inv, &mut out, &mut filter_state); - assert_eq!(filter_state[0], 207); - assert_eq!(filter_state[1], 2270); - } - - #[test] - fn test_min() { - let reference = [ - 1600, 720, 509, 512, 532, 552, 570, 588, 606, 624, 642, 659, 675, 691, 707, 723, 1600, 544, 502, 522, 542, 561, 579, 597, 615, 633, 651, 667, 683, - 699, 715, 731 - ]; - let mut age = vec![0; 6 * 16].into_boxed_slice(); - let mut low_value = vec![10000; 6 * 16].into_boxed_slice(); - let mut mean_values = vec![1600; 6].into_boxed_slice(); - let mut frame_counter = 0; - for i in 0..16 { - let value = 500 * (i as i16 + 1); - for j in 0..6 { - assert_eq!(reference[i], find_minimum(&mut age, &mut low_value, frame_counter, &mut mean_values, value, j)); - assert_eq!(reference[i + 16], find_minimum(&mut age, &mut low_value, frame_counter, &mut mean_values, 12000, j)); - } - frame_counter += 1; - } - } -} diff --git a/src/util.rs b/src/util.rs index 2acd7f8..2846e93 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,47 +1,89 @@ -#[inline] -pub fn div_i32_i16(a: i32, b: i16) -> i32 { - a.checked_div(b as i32).unwrap_or(i32::MAX) -} +use core::{ + 
cell::UnsafeCell, + hint::spin_loop, + marker::PhantomData, + mem::{MaybeUninit, forget}, + sync::atomic::{AtomicU8, Ordering} +}; -#[inline] -pub fn size_in_bits(n: i32) -> u8 { - 32 - n.leading_zeros() as u8 +pub(crate) struct OnceLock { + data: UnsafeCell>, + status: AtomicU8, + phantom: PhantomData } -/// Return the number of steps `a` can be left-shifted without overflow, or `0` if `a == 0`. -#[inline] -pub fn norm_i32(a: i32) -> i8 { - if a != 0 { - (if a < 0 { !a } else { a }).leading_zeros() as i8 - 1 // sub 1 for the sign bit - } else { - 0 +unsafe impl Send for OnceLock {} +unsafe impl Sync for OnceLock {} + +const STATUS_UNINITIALIZED: u8 = 0; +const STATUS_RUNNING: u8 = 1; +const STATUS_INITIALIZED: u8 = 2; + +impl OnceLock { + pub const fn new() -> Self { + Self { + data: UnsafeCell::new(MaybeUninit::uninit()), + status: AtomicU8::new(STATUS_UNINITIALIZED), + phantom: PhantomData + } } -} -/// Return the number of steps `a` can be left-shifted without overflow, or `0` if `a == 0`. 
-#[inline] -pub fn norm_u32(a: u32) -> u8 { - if a != 0 { a.leading_zeros() as u8 } else { 0 } -} + #[inline] + unsafe fn get_unchecked(&self) -> &T { + &*(*self.data.get()).as_ptr() + } -pub fn weighted_average(data: &mut [[i16; 6]], channel: usize, offset: i16, weights: &[[i16; 6]]) -> i32 { - let mut weighted_average = 0; - for gaussian in 0..data.len() { - data[gaussian][channel] += offset; - weighted_average += data[gaussian][channel] as i32 * weights[gaussian][channel] as i32; + #[inline] + pub fn get(&self) -> Option<&T> { + match self.status.load(Ordering::Acquire) { + STATUS_INITIALIZED => Some(unsafe { self.get_unchecked() }), + _ => None + } } - weighted_average -} -#[cfg(test)] -mod tests { - use crate::util::{norm_u32, size_in_bits}; + #[inline] + pub fn get_or_init T>(&self, f: F) -> &T { + if let Some(value) = self.get() { value } else { self.init_inner(f) } + } + + #[cold] + fn init_inner T>(&self, f: F) -> &T { + 'a: loop { + match self + .status + .compare_exchange(STATUS_UNINITIALIZED, STATUS_RUNNING, Ordering::Acquire, Ordering::Acquire) + { + Ok(_) => { + struct SetStatusOnPanic<'a> { + status: &'a AtomicU8 + } + impl Drop for SetStatusOnPanic<'_> { + fn drop(&mut self) { + self.status.store(STATUS_UNINITIALIZED, Ordering::SeqCst); + } + } + + let panic_catcher = SetStatusOnPanic { status: &self.status }; + let val = f(); + unsafe { + (*self.data.get()).as_mut_ptr().write(val); + }; + forget(panic_catcher); + + self.status.store(STATUS_INITIALIZED, Ordering::Release); - #[test] - fn test_norm() { - assert_eq!(17, size_in_bits(111121)); - assert_eq!(0, norm_u32(0)); - assert_eq!(0, norm_u32(u32::MAX)); - assert_eq!(15, norm_u32(111121)); + return unsafe { self.get_unchecked() }; + } + Err(STATUS_INITIALIZED) => return unsafe { self.get_unchecked() }, + Err(STATUS_RUNNING) => loop { + match self.status.load(Ordering::Acquire) { + STATUS_RUNNING => spin_loop(), + STATUS_INITIALIZED => return unsafe { self.get_unchecked() }, + _ => continue 'a + } 
+ }, + _ => continue + } + } } }