diff --git a/.cursor/rules/simple.mdc b/.cursor/rules/simple.mdc index 92a5a10d9..509795c67 100644 --- a/.cursor/rules/simple.mdc +++ b/.cursor/rules/simple.mdc @@ -9,6 +9,11 @@ alwaysApply: true # Typescript - Avoid creating a bunch of types/interfaces if they are not shared. Especially for function props. Just inline them. +- After some amount of TypeScript changes, run `pnpm -r typecheck`. + +# Rust + +- After some amount of Rust changes, run `cargo check`. # Mutation - Never do manual state management for form/mutation. Things like setError is anti-pattern. use useForm(from tanstack-form) and useQuery/useMutation(from tanstack-query) for 99% cases. @@ -19,7 +24,6 @@ alwaysApply: true # Misc - Do not create summary docs or example code file if not requested. Plan is ok. -- After a significant amount of TypeScript changes, run `pnpm -r typecheck`. - If there are many classNames and they have conditional logic, use `cn` (import it with `import { cn } from "@hypr/utils"`). It is similar to `clsx`. Always pass an array. Split by logical grouping. - Use `motion/react` instead of `framer-motion`. diff --git a/Cargo.lock b/Cargo.lock index f9709850b..e70ba138a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7492,6 +7492,12 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "if_chain" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd62e6b5e86ea8eeeb8db1de02880a6abc01a397b2ebb64b5d74ac255318f5cb" + [[package]] name = "ignore" version = "0.4.25" @@ -11411,22 +11417,32 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ractor" -version = "0.15.9" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9500e0be6f12a0539cb1154d654ef2e888bf8529164e54aff4a097baad5bb001" +checksum = "1d65972a0286ef14c43c6daafbac6cf15e96496446147683b2905292c35cc178" dependencies = [ + "async-trait", "bon 2.3.0", "dashmap", "futures", - "js-sys", "once_cell", "strum 0.26.3", "tokio", - "tokio_with_wasm", "tracing", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-time", +] + +[[package]] +name = "ractor-supervisor" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d90830688ebfafdc226f3c9567c40fecf4c51a7513171181102ae66e4b57c15f" +dependencies = [ + "futures-util", + "if_chain", + "log", + "ractor", + "thiserror 2.0.17", + "uuid", ] [[package]] @@ -14656,6 +14672,7 @@ dependencies = [ "owhisper-client", "owhisper-interface", "ractor", + "ractor-supervisor", "rodio", "serde", "serde_json", @@ -15858,30 +15875,6 @@ dependencies = [ "webpki-roots 0.26.11", ] -[[package]] -name = "tokio_with_wasm" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dfba9b946459940fb564dcf576631074cdfb0bfe4c962acd4c31f0dca7897e6" -dependencies = [ - "js-sys", - "tokio", - "tokio_with_wasm_proc", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - -[[package]] -name = "tokio_with_wasm_proc" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e04c1865c281139e5ccf633cb9f76ffdaabeebfe53b703984cf82878e2aabb" -dependencies = [ - "quote", - "syn 2.0.108", -] - [[package]] name = "toml" version = "0.8.23" diff --git a/Cargo.toml b/Cargo.toml index 9d083bdfc..13533c6b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,7 +132,8 @@ async-stream = "0.3.6" futures-channel = "0.3.31" futures-core = "0.3.31" futures-util = "0.3.31" -ractor = "0.15" +ractor = { version = "0.14.3" } +ractor-supervisor = "0.1.9" reqwest = "0.12" reqwest-streams = "0.10.0" tokio = "1" diff --git a/Taskfile.yaml b/Taskfile.yaml index b77a6858e..ea3188aaf 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -72,3 +72,12 @@ tasks: cmds: - chmod +x ./apps/desktop/src-tauri/resources/stt-aarch64-apple-darwin - chmod +x ./apps/desktop/src-tauri/resources/passthrough-aarch64-apple-darwin + + db: + env: + DB: /Users/yujonglee/Library/Application Support/com.hyprnote.nightly/db.sqlite + cmds: + - | + sqlite3 -json "$DB" 'SELECT store FROM main LIMIT 1;' | + jq -r '.[0].store' | + jless diff --git a/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts b/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts index 4da042101..30672ac81 100644 --- a/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts +++ b/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts @@ -13,9 +13,11 @@ export function useFinalWords(transcriptId: string): (main.Word & { id: string } return []; } - return Object.entries(resultTable) + const ret = Object.entries(resultTable) .map(([wordId, row]) => ({ ...(row as unknown as main.Word), id: wordId })) .sort((a, b) => a.start_ms - b.start_ms); + + return ret; }, [resultTable]); } diff --git a/apps/desktop/src/hooks/useAutoEnhance.ts b/apps/desktop/src/hooks/useAutoEnhance.ts index 335e16d26..e3138a889 100644 --- a/apps/desktop/src/hooks/useAutoEnhance.ts +++ b/apps/desktop/src/hooks/useAutoEnhance.ts @@ -59,12 +59,7 @@ export function useAutoEnhance(tab: Extract) { if (listenerJustStopped) { startEnhance(); - } - }, [listenerStatus, prevListenerStatus, startEnhance]); - - useEffect(() => { - if (enhanceTask.status === "generating" && tab.state.editor !== "enhanced") { updateSessionTabState(tab, { editor: "enhanced" }); } - }, [enhanceTask.status, tab, updateSessionTabState]); + }, [listenerStatus, prevListenerStatus, startEnhance]); } diff --git a/crates/audio-utils/src/lib.rs b/crates/audio-utils/src/lib.rs index 557c48a4a..eebe6e504 100644 --- a/crates/audio-utils/src/lib.rs +++ b/crates/audio-utils/src/lib.rs @@ -1,3 +1,5 @@ +use std::convert::TryFrom; + use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{Stream, StreamExt}; use kalosm_sound::AsyncSource; @@ -6,11 +8,19 @@ mod error; pub use error::*; mod vorbis; pub use vorbis::*; +mod stream; +pub use stream::*; pub use rodio::Source; const I16_SCALE: f32 = 32768.0; +#[derive(Debug, Clone, Copy)] +pub struct AudioMetadata { + pub sample_rate: u32, + pub channels: u8, +} + impl AudioFormatExt for T {} pub trait AudioFormatExt: AsyncSource { @@ -81,6 +91,40 @@ pub fn source_from_path( Ok(decoder) } +fn metadata_from_source(source: &S) -> Result +where + S: Source, + S::Item: rodio::Sample, +{ + let sample_rate = source.sample_rate(); + if sample_rate == 0 { + return Err(crate::Error::InvalidSampleRate(sample_rate)); + } + + let channels_u16 = source.channels(); + if channels_u16 == 0 { + return Err(crate::Error::UnsupportedChannelCount { + count: channels_u16, + }); + } + let channels = + u8::try_from(channels_u16).map_err(|_| crate::Error::UnsupportedChannelCount { + count: channels_u16, + })?; + + Ok(AudioMetadata { + sample_rate, + channels, + }) +} + +pub fn audio_file_metadata( + path: impl AsRef, +) -> Result { + let source = source_from_path(path)?; + metadata_from_source(&source) +} + pub fn resample_audio(source: S, to_rate: u32) -> Result, crate::Error> where S: rodio::Source + Iterator, @@ -136,32 +180,48 @@ where pub struct ChunkedAudio { pub chunks: Vec, pub sample_count: usize, + pub frame_count: usize, + pub metadata: AudioMetadata, } pub fn chunk_audio_file( path: impl AsRef, - sample_rate: u32, - chunk_size: usize, + chunk_ms: u64, ) -> Result { let source = source_from_path(path)?; - let samples = resample_audio(source, sample_rate)?; + let metadata = metadata_from_source(&source)?; + let samples = resample_audio(source, metadata.sample_rate)?; if samples.is_empty() { return Ok(ChunkedAudio { chunks: Vec::new(), sample_count: 0, + frame_count: 0, + metadata, }); } - let chunk_size = chunk_size.max(1); + let channels = metadata.channels.max(1) as usize; + let frames_per_chunk = { + let frames = ((chunk_ms as u128).saturating_mul(metadata.sample_rate as u128) + 999) / 1000; + frames.max(1).min(usize::MAX as u128) as usize + }; + let samples_per_chunk = frames_per_chunk + .saturating_mul(channels) + .max(1) + .min(usize::MAX); + let sample_count = samples.len(); + let frame_count = sample_count / channels; let chunks = samples - .chunks(chunk_size) + .chunks(samples_per_chunk) .map(|chunk| f32_to_i16_bytes(chunk.iter().copied())) .collect(); Ok(ChunkedAudio { chunks, sample_count, + frame_count, + metadata, }) } diff --git a/crates/audio-utils/src/stream.rs b/crates/audio-utils/src/stream.rs new file mode 100644 index 000000000..70e591fa0 --- /dev/null +++ b/crates/audio-utils/src/stream.rs @@ -0,0 +1,278 @@ +use std::{ + collections::VecDeque, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, +}; + +use futures_util::Stream; +use kalosm_sound::AsyncSource; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DrainState { + pub processed: bool, + pub terminated: bool, +} + +pub fn drain_stream(mut stream: Pin<&mut S>, mut on_item: F) -> Result +where + S: Stream, + F: FnMut(S::Item) -> Result<(), E>, +{ + let waker = noop_waker(); + let mut ctx = Context::from_waker(&waker); + let mut processed = false; + + loop { + match stream.as_mut().poll_next(&mut ctx) { + Poll::Ready(Some(item)) => { + on_item(item)?; + processed = true; + } + Poll::Ready(None) => { + return Ok(DrainState { + processed, + terminated: true, + }) + } + Poll::Pending => { + return Ok(DrainState { + processed, + terminated: false, + }) + } + } + } +} + +pub fn poll_next_now(mut stream: Pin<&mut S>) -> Poll> +where + S: Stream, +{ + let waker = noop_waker(); + let mut ctx = Context::from_waker(&waker); + stream.as_mut().poll_next(&mut ctx) +} + +/// Push-driven audio source implementing [`AsyncSource`]. +#[derive(Debug)] +pub struct PushSource { + shared: Arc>, +} + +/// Handle for feeding data into [`PushSource`]. +#[derive(Debug, Clone)] +pub struct PushSourceHandle { + shared: Arc>, +} + +#[derive(Debug, Default)] +struct Shared { + queue: VecDeque, + current: Option, + index: usize, + sample_rate: u32, + closed: bool, +} + +#[derive(Debug)] +struct Chunk { + samples: Vec, + sample_rate: u32, +} + +impl PushSource { + /// Create a new push source with an initial sample rate. + pub fn new(initial_sample_rate: u32) -> (Self, PushSourceHandle) { + let shared = Arc::new(Mutex::new(Shared { + sample_rate: initial_sample_rate, + ..Default::default() + })); + + ( + Self { + shared: shared.clone(), + }, + PushSourceHandle { shared }, + ) + } +} + +impl PushSourceHandle { + /// Queue a chunk of samples produced at the provided sample rate. + pub fn push(&self, samples: Vec, sample_rate: u32) { + if samples.is_empty() || sample_rate == 0 { + return; + } + + let mut shared = self.shared.lock().unwrap(); + shared.queue.push_back(Chunk { + samples, + sample_rate, + }); + } + + /// Signal that no additional data will arrive. + pub fn close(&self) { + let mut shared = self.shared.lock().unwrap(); + shared.closed = true; + } +} + +struct PushSourceStream<'a> { + source: &'a mut PushSource, +} + +impl<'a> Stream for PushSourceStream<'a> { + type Item = f32; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let mut shared = self.source.shared.lock().unwrap(); + + loop { + if let Some(len) = shared.current.as_ref().map(|chunk| chunk.samples.len()) { + if shared.index < len { + let sample = shared.current.as_ref().unwrap().samples[shared.index]; + shared.index += 1; + + if shared.index == len { + shared.current = None; + shared.index = 0; + } + + return Poll::Ready(Some(sample)); + } + + shared.current = None; + shared.index = 0; + continue; + } + + if let Some(next) = shared.queue.pop_front() { + shared.sample_rate = next.sample_rate; + + if next.samples.is_empty() { + continue; + } + + shared.current = Some(next); + shared.index = 0; + continue; + } + + return if shared.closed { + Poll::Ready(None) + } else { + Poll::Pending + }; + } + } +} + +impl AsyncSource for PushSource { + fn as_stream(&mut self) -> impl Stream + '_ { + PushSourceStream { source: self } + } + + fn sample_rate(&self) -> u32 { + let shared = self.shared.lock().unwrap(); + + if let Some(current) = shared.current.as_ref() { + current.sample_rate + } else if let Some(next) = shared.queue.front() { + next.sample_rate + } else { + shared.sample_rate + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::pin::Pin; + use std::task::Poll; + + #[test] + fn streams_pushed_samples() { + let (mut source, handle) = PushSource::new(16_000); + handle.push(vec![0.0, 1.0], 16_000); + handle.push(vec![2.0, 3.0], 16_000); + handle.close(); + + let mut stream = source.as_stream(); + let mut collected = Vec::new(); + let state = drain_stream(Pin::new(&mut stream), |sample| { + collected.push(sample); + Ok::<_, ()>(()) + }) + .unwrap(); + + assert!(state.processed); + assert!(state.terminated); + assert_eq!(collected, vec![0.0, 1.0, 2.0, 3.0]); + } + + #[test] + fn updates_sample_rate_from_chunks() { + let (mut source, handle) = PushSource::new(8_000); + assert_eq!(source.sample_rate(), 8_000); + + handle.push(vec![0.0], 12_000); + assert_eq!(source.sample_rate(), 12_000); + + handle.push(vec![1.0, 2.0], 24_000); + handle.close(); + + { + let mut stream = source.as_stream(); + assert_eq!(poll_next_now(Pin::new(&mut stream)), Poll::Ready(Some(0.0))); + } + assert_eq!(source.sample_rate(), 24_000); + { + let mut stream = source.as_stream(); + assert_eq!(poll_next_now(Pin::new(&mut stream)), Poll::Ready(Some(1.0))); + } + { + let mut stream = source.as_stream(); + assert_eq!(poll_next_now(Pin::new(&mut stream)), Poll::Ready(Some(2.0))); + } + { + let mut stream = source.as_stream(); + assert_eq!(poll_next_now(Pin::new(&mut stream)), Poll::Ready(None)); + } + assert_eq!(source.sample_rate(), 24_000); + } + + #[test] + fn pending_until_closed() { + let (mut source, handle) = PushSource::new(16_000); + + { + let mut stream = source.as_stream(); + assert_eq!(poll_next_now(Pin::new(&mut stream)), Poll::Pending); + } + + handle.close(); + + { + let mut stream = source.as_stream(); + assert_eq!(poll_next_now(Pin::new(&mut stream)), Poll::Ready(None)); + } + } +} + +fn noop_waker() -> Waker { + unsafe fn clone(_: *const ()) -> RawWaker { + RawWaker::new(std::ptr::null(), &VTABLE) + } + + unsafe fn wake(_: *const ()) {} + + static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, wake); + + unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) } +} diff --git a/crates/audio/src/mic.rs b/crates/audio/src/mic.rs index d91233c7c..050f3fbaa 100644 --- a/crates/audio/src/mic.rs +++ b/crates/audio/src/mic.rs @@ -65,6 +65,10 @@ impl MicInput { config, }) } + + pub fn sample_rate(&self) -> u32 { + self.config.sample_rate().0 + } } impl MicInput { diff --git a/crates/audio/src/speaker/linux.rs b/crates/audio/src/speaker/linux.rs index 213b30e72..289b503f7 100644 --- a/crates/audio/src/speaker/linux.rs +++ b/crates/audio/src/speaker/linux.rs @@ -7,6 +7,10 @@ impl SpeakerInput { Self {} } + pub fn sample_rate(&self) -> u32 { + 16000 + } + pub fn stream(self) -> SpeakerStream { SpeakerStream::new() } diff --git a/crates/audio/src/speaker/macos.rs b/crates/audio/src/speaker/macos.rs index 82288a73b..b38cd9b92 100644 --- a/crates/audio/src/speaker/macos.rs +++ b/crates/audio/src/speaker/macos.rs @@ -91,6 +91,10 @@ impl SpeakerInput { Ok(Self { tap, agg_desc }) } + pub fn sample_rate(&self) -> u32 { + self.tap.asbd().unwrap().sample_rate as u32 + } + fn start_device( &self, ctx: &mut Box, diff --git a/crates/audio/src/speaker/mod.rs b/crates/audio/src/speaker/mod.rs index 41905a87c..cf9b1c50f 100644 --- a/crates/audio/src/speaker/mod.rs +++ b/crates/audio/src/speaker/mod.rs @@ -42,6 +42,16 @@ impl SpeakerInput { )) } + #[cfg(any(target_os = "macos", target_os = "windows"))] + pub fn sample_rate(&self) -> u32 { + self.inner.sample_rate() + } + + #[cfg(not(any(target_os = "macos", target_os = "windows")))] + pub fn sample_rate(&self) -> u32 { + 0 + } + #[cfg(any(target_os = "macos", target_os = "windows"))] pub fn stream(self) -> Result { let inner = self.inner.stream(); diff --git a/crates/audio/src/speaker/windows.rs b/crates/audio/src/speaker/windows.rs index 83e9f2d3c..60377c588 100644 --- a/crates/audio/src/speaker/windows.rs +++ b/crates/audio/src/speaker/windows.rs @@ -15,6 +15,10 @@ impl SpeakerInput { Ok(Self {}) } + pub fn sample_rate(&self) -> u32 { + 44100 + } + pub fn stream(self) -> SpeakerStream { let sample_queue = Arc::new(Mutex::new(VecDeque::new())); let waker_state = Arc::new(Mutex::new(WakerState { diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index fc78422b6..c0138e425 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -63,11 +63,11 @@ impl ListenClientBuilder { append_language_query(&mut query_pairs, ¶ms); let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let sample_rate = RESAMPLED_SAMPLE_RATE_HZ.to_string(); + let sample_rate = params.sample_rate.to_string(); query_pairs.append_pair("model", model); query_pairs.append_pair("encoding", "linear16"); - query_pairs.append_pair("sample_rate", &sample_rate); + // query_pairs.append_pair("sample_rate", &sample_rate); query_pairs.append_pair("diarize", "true"); query_pairs.append_pair("multichannel", "false"); query_pairs.append_pair("punctuate", "true"); @@ -104,7 +104,7 @@ impl ListenClientBuilder { let model = params.model.as_deref().unwrap_or("hypr-whisper"); let channel_string = channels.to_string(); - let sample_rate = RESAMPLED_SAMPLE_RATE_HZ.to_string(); + let sample_rate = params.sample_rate.to_string(); query_pairs.append_pair("model", model); query_pairs.append_pair("channels", &channel_string); diff --git a/owhisper/owhisper-interface/src/lib.rs b/owhisper/owhisper-interface/src/lib.rs index 351c955f7..424a3a0d6 100644 --- a/owhisper/owhisper-interface/src/lib.rs +++ b/owhisper/owhisper-interface/src/lib.rs @@ -137,6 +137,7 @@ common_derives! { #[serde(default)] pub model: Option, pub channels: u8, + pub sample_rate: u32, // https://docs.rs/axum-extra/0.10.1/axum_extra/extract/struct.Query.html#example-1 #[serde(default)] pub languages: Vec, @@ -152,6 +153,7 @@ impl Default for ListenParams { ListenParams { model: None, channels: 1, + sample_rate: 16000, languages: vec![], keywords: vec![], redemption_time_ms: None, diff --git a/packages/tiptap/src/editor/index.tsx b/packages/tiptap/src/editor/index.tsx index a1967e547..6662085b0 100644 --- a/packages/tiptap/src/editor/index.tsx +++ b/packages/tiptap/src/editor/index.tsx @@ -82,7 +82,7 @@ const Editor = forwardRef<{ editor: TiptapEditor | null }, EditorProps>( if (event.key === "Tab") { event.preventDefault(); - return true; + return false; } return false; diff --git a/packages/tiptap/src/shared/custom-list-keymap.ts b/packages/tiptap/src/shared/custom-list-keymap.ts index 53f7cfa52..5590930ae 100644 --- a/packages/tiptap/src/shared/custom-list-keymap.ts +++ b/packages/tiptap/src/shared/custom-list-keymap.ts @@ -6,6 +6,34 @@ export const CustomListKeymap = ListKeymap.extend({ const originalShortcuts = this.parent?.() ?? {}; const getListItemType = () => this.editor.schema.nodes.listItem; + const getSupportedListItemNames = () => { + const listTypes = this.options.listTypes ?? []; + + return listTypes + .map(({ itemName }) => this.editor.schema.nodes[itemName]?.name) + .filter((itemName): itemName is string => typeof itemName === "string"); + }; + const runListIndentCommand = (command: "sinkListItem" | "liftListItem") => { + const editor = this.editor; + const { state } = editor; + + for (const itemName of getSupportedListItemNames()) { + if (!isNodeActive(state, itemName)) { + continue; + } + + const chain = editor.chain().focus(undefined, { scrollIntoView: false }); + const executed = command === "sinkListItem" + ? chain.sinkListItem(itemName).run() + : chain.liftListItem(itemName).run(); + + if (executed) { + return true; + } + } + + return false; + }; return { ...originalShortcuts, @@ -47,6 +75,26 @@ export const CustomListKeymap = ListKeymap.extend({ return originalShortcuts.Backspace ? originalShortcuts.Backspace({ editor }) : false; }, + + Tab: () => { + const editor = this.editor; + + if (runListIndentCommand("sinkListItem")) { + return true; + } + + return originalShortcuts.Tab ? originalShortcuts.Tab({ editor }) : false; + }, + + "Shift-Tab": () => { + const editor = this.editor; + + if (runListIndentCommand("liftListItem")) { + return true; + } + + return originalShortcuts["Shift-Tab"] ? originalShortcuts["Shift-Tab"]({ editor }) : false; + }, }; }, }); diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 5a71ff253..40bf28bda 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -52,8 +52,10 @@ uuid = { workspace = true, features = ["v4"] } hound = { workspace = true } vorbis_rs = { workspace = true } -futures-util = { workspace = true } ractor = { workspace = true } +ractor-supervisor = { workspace = true } + +futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tokio-stream = { workspace = true } tokio-util = { workspace = true } diff --git a/plugins/listener/src/actors/batch.rs b/plugins/listener/src/actors/batch.rs index cd71c3084..235a26763 100644 --- a/plugins/listener/src/actors/batch.rs +++ b/plugins/listener/src/actors/batch.rs @@ -4,13 +4,11 @@ use std::time::Duration; use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; -use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SpawnErr}; use tauri_specta::Event; use tokio_stream::{self as tokio_stream, StreamExt as TokioStreamExt}; use crate::SessionEvent; - -const RESAMPLED_SAMPLE_RATE_HZ: u32 = 16_000; const BATCH_STREAM_TIMEOUT_SECS: u64 = 5; const DEFAULT_CHUNK_MS: u64 = 500; const DEFAULT_DELAY_MS: u64 = 20; @@ -91,6 +89,12 @@ impl BatchActor { } } +pub async fn spawn_batch_actor(args: BatchArgs) -> Result, SpawnErr> { + let (batch_ref, _) = Actor::spawn(Some(BatchActor::name()), BatchActor, args).await?; + Ok(batch_ref) +} + +#[ractor::async_trait] impl Actor for BatchActor { type Msg = BatchMsg; type State = BatchState; @@ -188,13 +192,6 @@ impl BatchStreamConfig { } } - fn chunk_samples(&self) -> usize { - let samples = - ((self.chunk_ms as u128).saturating_mul(RESAMPLED_SAMPLE_RATE_HZ as u128) + 999) / 1000; - let samples = samples.max(1); - samples.min(usize::MAX as u128) as usize - } - fn chunk_interval(&self) -> Duration { Duration::from_millis(self.delay_ms) } @@ -225,12 +222,10 @@ async fn spawn_batch_task( let stream_config = BatchStreamConfig::new(DEFAULT_CHUNK_MS, DEFAULT_DELAY_MS); let start_notifier = args.start_notifier.clone(); - let chunk_samples = stream_config.chunk_samples(); let chunk_result = tokio::task::spawn_blocking({ let path = PathBuf::from(&args.file_path); - move || { - hypr_audio_utils::chunk_audio_file(path, RESAMPLED_SAMPLE_RATE_HZ, chunk_samples) - } + let chunk_ms = stream_config.chunk_ms; + move || hypr_audio_utils::chunk_audio_file(path, chunk_ms) }) .await; @@ -258,20 +253,26 @@ async fn spawn_batch_task( } }; - let sample_count = chunked_audio.sample_count; - let audio_duration_secs = if sample_count == 0 { + let frame_count = chunked_audio.frame_count; + let metadata = chunked_audio.metadata; + let audio_duration_secs = if frame_count == 0 || metadata.sample_rate == 0 { 0.0 } else { - sample_count as f64 / RESAMPLED_SAMPLE_RATE_HZ as f64 + frame_count as f64 / metadata.sample_rate as f64 }; let _ = myself.send_message(BatchMsg::StreamAudioDuration(audio_duration_secs)); tracing::debug!("batch task: creating listen client"); - let channel_count = args.listen_params.channels.clamp(1, 2); + let channel_count = metadata.channels.clamp(1, 2); + let listen_params = owhisper_interface::ListenParams { + channels: metadata.channels, + sample_rate: metadata.sample_rate, + ..args.listen_params.clone() + }; let client = owhisper_client::ListenClient::builder() .api_base(args.base_url) .api_key(args.api_key) - .params(args.listen_params.clone()) + .params(listen_params) .build_with_channels(channel_count); let chunk_count = chunked_audio.chunks.len(); diff --git a/plugins/listener/src/actors/context.rs b/plugins/listener/src/actors/context.rs new file mode 100644 index 000000000..8b26592b0 --- /dev/null +++ b/plugins/listener/src/actors/context.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; +use tokio::sync::RwLock; + +use super::ChannelMode; + +pub type LiveContextHandle = Arc; + +pub struct LiveContext { + snapshot: RwLock, +} + +impl LiveContext { + pub fn new() -> Self { + let initial = LiveSnapshot::default(); + Self { + snapshot: RwLock::new(initial), + } + } + + pub async fn read(&self) -> LiveSnapshot { + self.snapshot.read().await.clone() + } + + pub async fn write(&self, f: F) + where + F: FnOnce(&mut LiveSnapshot), + { + let mut guard = self.snapshot.write().await; + f(&mut guard); + } +} + +impl Default for LiveContext { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +pub struct LiveSnapshot { + pub mode: ChannelMode, + pub sample_rate: u32, + pub device_id: Option, +} + +impl Default for LiveSnapshot { + fn default() -> Self { + Self { + mode: ChannelMode::Dual, + sample_rate: 16000, + device_id: None, + } + } +} diff --git a/plugins/listener/src/actors/controller.rs b/plugins/listener/src/actors/controller.rs new file mode 100644 index 000000000..150a2ba02 --- /dev/null +++ b/plugins/listener/src/actors/controller.rs @@ -0,0 +1,190 @@ +use std::sync::Arc; +use std::time::{Instant, SystemTime}; + +use tauri_specta::Event; + +use ractor::{call_t, registry, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; +use ractor_supervisor::supervisor::SupervisorMsg; + +use crate::{ + actors::{LiveContext, LiveContextHandle, LiveSupervisorArgs, SourceActor, SourceMsg}, + SessionEvent, +}; + +#[derive(Debug)] +pub enum ControllerMsg { + SetMicMute(bool), + GetMicMute(RpcReplyPort), + GetMicDeviceName(RpcReplyPort>), + ChangeMicDevice(Option), +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct SessionParams { + pub session_id: String, + pub languages: Vec, + pub onboarding: bool, + pub record_enabled: bool, + pub model: String, + pub base_url: String, + pub api_key: String, + pub keywords: Vec, +} + +pub struct SessionShared { + app: tauri::AppHandle, + params: SessionParams, + started_at_instant: Instant, + started_at_system: SystemTime, + live_ctx: Arc, +} + +impl SessionShared { + pub fn new(app: tauri::AppHandle, params: SessionParams) -> Arc { + Arc::new(Self { + app, + params, + started_at_instant: Instant::now(), + started_at_system: SystemTime::now(), + live_ctx: Arc::new(LiveContext::new()), + }) + } + + pub fn app(&self) -> &tauri::AppHandle { + &self.app + } + + pub fn live_ctx(&self) -> LiveContextHandle { + self.live_ctx.clone() + } + + pub fn live_supervisor_args(&self) -> LiveSupervisorArgs { + LiveSupervisorArgs { + app: self.app.clone(), + ctx: self.live_ctx(), + languages: self.params.languages.clone(), + onboarding: self.params.onboarding, + model: self.params.model.clone(), + base_url: self.params.base_url.clone(), + api_key: self.params.api_key.clone(), + keywords: self.params.keywords.clone(), + session_started_at: self.started_at_instant, + session_started_at_unix: self.started_at_system, + session_id: self.params.session_id.clone(), + record_enabled: self.params.record_enabled, + } + } +} + +pub struct ControllerActorArgs { + pub shared: Arc, + pub supervisor: ActorRef, +} + +pub struct ControllerState { + shared: Arc, + _supervisor: ActorRef, +} + +pub struct ControllerActor; + +impl ControllerActor { + pub fn name() -> ActorName { + "controller".into() + } +} + +#[ractor::async_trait] +impl Actor for ControllerActor { + type Msg = ControllerMsg; + type State = ControllerState; + type Arguments = ControllerActorArgs; + + async fn pre_start( + &self, + _myself: ActorRef, + args: Self::Arguments, + ) -> Result { + { + use tauri_plugin_tray::TrayPluginExt; + let _ = args.shared.app().set_start_disabled(true); + } + + SessionEvent::RunningActive {} + .emit(args.shared.app()) + .unwrap(); + + Ok(ControllerState { + shared: args.shared, + _supervisor: args.supervisor, + }) + } + + async fn handle( + &self, + _myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + ControllerMsg::SetMicMute(muted) => { + if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + actor.cast(SourceMsg::SetMicMute(muted))?; + } + SessionEvent::MicMuted { value: muted }.emit(state.shared.app())?; + } + + ControllerMsg::GetMicDeviceName(reply) => { + if !reply.is_closed() { + let device_name = if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + call_t!(actor, SourceMsg::GetMicDevice, 100).unwrap_or(None) + } else { + None + }; + + let _ = reply.send(device_name); + } + } + + ControllerMsg::GetMicMute(reply) => { + let muted = if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + call_t!(actor, SourceMsg::GetMicMute, 100)? + } else { + false + }; + + if !reply.is_closed() { + let _ = reply.send(muted); + } + } + + ControllerMsg::ChangeMicDevice(device) => { + if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + actor.cast(SourceMsg::SetMicDevice(device))?; + } + } + } + + Ok(()) + } + + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + { + use tauri_plugin_tray::TrayPluginExt; + let _ = state.shared.app().set_start_disabled(false); + } + + SessionEvent::Inactive {}.emit(state.shared.app())?; + tracing::info!("controller_actor_post_stop"); + + Ok(()) + } +} diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 733b0a204..a79f3ea6e 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -4,11 +4,13 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use futures_util::StreamExt; use tokio::time::error::Elapsed; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; +use ractor_supervisor::supervisor::SupervisorMsg; +use tauri_specta::Event; + use owhisper_client::hypr_ws; use owhisper_interface::stream::{Extra, StreamResponse}; use owhisper_interface::{ControlMessage, MixedMessage}; -use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; -use tauri_specta::Event; use crate::SessionEvent; @@ -22,7 +24,6 @@ pub enum ListenerMsg { StreamEnded, StreamTimeout(Elapsed), StreamStartFailed(String), - ChangeMode(crate::actors::ChannelMode), } #[derive(Clone)] @@ -35,6 +36,8 @@ pub struct ListenerArgs { pub api_key: String, pub keywords: Vec, pub mode: crate::actors::ChannelMode, + pub sample_rate: u32, + pub supervisor: ActorRef, pub session_started_at: Instant, pub session_started_at_unix: SystemTime, } @@ -44,6 +47,7 @@ pub struct ListenerState { tx: tokio::sync::mpsc::Sender>, rx_task: tokio::task::JoinHandle<()>, shutdown_tx: Option>, + supervisor: ActorRef, } pub struct ListenerActor; @@ -54,6 +58,7 @@ impl ListenerActor { } } +#[ractor::async_trait] impl Actor for ListenerActor { type Msg = ListenerMsg; type State = ListenerState; @@ -64,6 +69,12 @@ impl Actor for ListenerActor { myself: ActorRef, args: Self::Arguments, ) -> Result { + tracing::info!( + sample_rate = args.sample_rate, + mode = ?args.mode, + "listener_actor_pre_start" + ); + let supervisor = args.supervisor.clone(); let (tx, rx_task, shutdown_tx) = spawn_rx_task(args.clone(), myself).await?; let state = ListenerState { @@ -71,6 +82,7 @@ impl Actor for ListenerActor { tx, rx_task, shutdown_tx: Some(shutdown_tx), + supervisor, }; Ok(state) @@ -81,6 +93,7 @@ impl Actor for ListenerActor { _myself: ActorRef, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { + tracing::info!("listener_actor_post_stop"); if let Some(shutdown_tx) = state.shutdown_tx.take() { let _ = shutdown_tx.send(()); let _ = (&mut state.rx_task).await; @@ -105,45 +118,42 @@ impl Actor for ListenerActor { response.remap_channel_index(0, 2); } + if let StreamResponse::TranscriptResponse { is_final, .. } = &response { + if *is_final { + tracing::info!(response = ?response, "final_response"); + } + } + SessionEvent::StreamResponse { response }.emit(&state.args.app)?; } ListenerMsg::StreamStartFailed(error) => { tracing::error!("listen_ws_connect_failed: {}", error); + request_rest_for_one( + &state.supervisor, + ListenerActor::name(), + "stream_start_failed", + ); myself.stop(Some(format!("listen_ws_connect_failed: {}", error))); } ListenerMsg::StreamError(error) => { tracing::info!("listen_stream_error: {}", error); + request_rest_for_one(&state.supervisor, ListenerActor::name(), "stream_error"); myself.stop(None); } ListenerMsg::StreamEnded => { tracing::info!("listen_stream_ended"); + request_rest_for_one(&state.supervisor, ListenerActor::name(), "stream_ended"); myself.stop(None); } ListenerMsg::StreamTimeout(elapsed) => { tracing::info!("listen_stream_timeout: {}", elapsed); + request_rest_for_one(&state.supervisor, ListenerActor::name(), "stream_timeout"); myself.stop(None); } - - ListenerMsg::ChangeMode(new_mode) => { - tracing::info!(?new_mode, "listener_mode_change"); - - if let Some(shutdown_tx) = state.shutdown_tx.take() { - let _ = shutdown_tx.send(()); - let _ = (&mut state.rx_task).await; - } - - state.args.mode = new_mode; - - let (tx, rx_task, shutdown_tx) = - spawn_rx_task(state.args.clone(), myself.clone()).await?; - state.tx = tx; - state.rx_task = rx_task; - state.shutdown_tx = Some(shutdown_tx); - } } Ok(()) } @@ -154,12 +164,13 @@ impl Actor for ListenerActor { message: SupervisionEvent, _state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { - tracing::info!("supervisor_event: {:?}", message); + tracing::info!("listener_actor_supervision_event: {:?}", message); match message { SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} SupervisionEvent::ActorTerminated(_, _, _) => {} SupervisionEvent::ActorFailed(_cell, _) => { + tracing::error!("listener_actor_failed_event"); myself.stop(None); } } @@ -167,6 +178,34 @@ impl Actor for ListenerActor { } } +fn request_rest_for_one( + supervisor: &ActorRef, + child_id: ActorName, + reason: &'static str, +) { + let child_id_string = child_id.to_string(); + tracing::info!( + child = child_id_string, + reason, + "requesting_rest_for_one_spawn_from_listener" + ); + match supervisor.cast(SupervisorMsg::RestForOneSpawn { + child_id: child_id_string.clone(), + }) { + Ok(_) => tracing::info!( + child = child_id_string, + reason, + "requested_rest_for_one_spawn_from_listener" + ), + Err(error) => tracing::warn!( + ?error, + child = child_id_string, + reason, + "failed_to_request_rest_for_one_from_listener" + ), + } +} + async fn spawn_rx_task( args: ListenerArgs, myself: ActorRef, @@ -196,12 +235,15 @@ async fn spawn_rx_task( let rx_task = tokio::spawn(async move { use crate::actors::ChannelMode; + let app_handle = args.app.clone(); + if args.mode == ChannelMode::Single { let client = owhisper_client::ListenClient::builder() .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(owhisper_interface::ListenParams { model: Some(args.model.clone()), + sample_rate: args.sample_rate, languages: args.languages.clone(), redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), keywords: args.keywords.clone(), @@ -231,6 +273,7 @@ async fn spawn_rx_task( handle, myself, shutdown_rx, + app_handle.clone(), session_offset_secs, extra.clone(), ) @@ -241,6 +284,7 @@ async fn spawn_rx_task( .api_key(args.api_key) .params(owhisper_interface::ListenParams { model: Some(args.model), + sample_rate: args.sample_rate, languages: args.languages, redemption_time_ms: Some(if args.onboarding { 60 } else { 400 }), keywords: args.keywords, @@ -248,7 +292,15 @@ async fn spawn_rx_task( }) .build_dual(); - let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + let outbound = tokio_stream::StreamExt::map( + tokio_stream::wrappers::ReceiverStream::new(rx), + |msg| match msg { + MixedMessage::Audio((mic, spk)) => { + MixedMessage::Audio((spk, bytes::Bytes::from(vec![0; mic.len()]))) + } + MixedMessage::Control(c) => MixedMessage::Control(c), + }, + ); let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { Ok(res) => res, @@ -264,6 +316,7 @@ async fn spawn_rx_task( handle, myself, shutdown_rx, + app_handle.clone(), session_offset_secs, extra.clone(), ) @@ -279,6 +332,7 @@ async fn process_stream( handle: hypr_ws::client::WebSocketHandle, myself: ActorRef, mut shutdown_rx: tokio::sync::oneshot::Receiver<()>, + app: tauri::AppHandle, offset_secs: f64, extra: Extra, ) where @@ -290,6 +344,10 @@ async fn process_stream( _ = &mut shutdown_rx => { handle.finalize_with_text(serde_json::json!({"type": "Finalize"}).to_string().into()).await; + if let Err(err) = (SessionEvent::Finalizing {}).emit(&app) { + tracing::warn!(?err, "failed_to_emit_finalizing"); + } + let finalize_timeout = tokio::time::sleep(Duration::from_secs(5)); tokio::pin!(finalize_timeout); diff --git a/plugins/listener/src/actors/live_supervisor.rs b/plugins/listener/src/actors/live_supervisor.rs new file mode 100644 index 000000000..d525fa1c9 --- /dev/null +++ b/plugins/listener/src/actors/live_supervisor.rs @@ -0,0 +1,214 @@ +use std::sync::Arc; +use std::time::{Instant, SystemTime}; + +use ractor::{Actor, ActorRef}; +use ractor_supervisor::{ + supervisor::{ + Supervisor, SupervisorArguments, SupervisorMsg, SupervisorOptions, SupervisorStrategy, + }, + ChildSpec, Restart, SpawnFn, +}; +use tokio::time::Duration; + +use super::{ + ListenerActor, ListenerArgs, LiveContextHandle, RecArgs, RecorderActor, SessionShared, + SourceActor, SourceArgs, +}; + +pub struct LiveSupervisorArgs { + pub app: tauri::AppHandle, + pub ctx: LiveContextHandle, + pub languages: Vec, + pub onboarding: bool, + pub model: String, + pub base_url: String, + pub api_key: String, + pub keywords: Vec, + pub session_started_at: Instant, + pub session_started_at_unix: SystemTime, + pub session_id: String, + pub record_enabled: bool, +} + +pub fn live_supervisor_spec(shared: Arc) -> ChildSpec { + let supervisor_options = SupervisorOptions { + strategy: SupervisorStrategy::RestForOne, + max_restarts: 5, + max_window: Duration::from_secs(10), + reset_after: Some(Duration::from_secs(30)), + }; + + ChildSpec { + id: "live".to_string(), + restart: Restart::Transient, + spawn_fn: SpawnFn::new(move |supervisor_cell, child_id| { + let shared = shared.clone(); + let options = supervisor_options.clone(); + + async move { + let args = shared.live_supervisor_args(); + let child_specs = build_child_specs(&args); + + let (live_ref, _) = Supervisor::spawn_linked( + child_id.into(), + Supervisor, + SupervisorArguments { + child_specs, + options, + }, + supervisor_cell, + ) + .await?; + + Ok(live_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: Some(Duration::from_secs(30)), + } +} + +fn build_child_specs(args: &LiveSupervisorArgs) -> Vec { + let mut child_specs = Vec::with_capacity(3); + child_specs.push(build_source_spec(args)); + if let Some(recorder_spec) = build_recorder_spec(args) { + child_specs.push(recorder_spec); + } + child_specs.push(build_listener_spec(args)); + child_specs +} + +fn build_source_spec(args: &LiveSupervisorArgs) -> ChildSpec { + let ctx = args.ctx.clone(); + let app = args.app.clone(); + let onboarding = args.onboarding; + + ChildSpec { + id: SourceActor::name().to_string(), + restart: Restart::Transient, + spawn_fn: SpawnFn::new(move |supervisor_cell, child_id| { + let ctx = ctx.clone(); + let app = app.clone(); + + async move { + let snapshot = ctx.read().await; + let supervisor_ref: ActorRef = supervisor_cell.clone().into(); + let source_args = SourceArgs { + mic_device: snapshot.device_id.clone(), + onboarding, + app: app.clone(), + ctx: ctx.clone(), + supervisor: supervisor_ref, + }; + + let (source_ref, _) = Actor::spawn_linked( + Some(child_id.into()), + SourceActor, + source_args, + supervisor_cell, + ) + .await?; + + Ok(source_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: Some(Duration::from_secs(60)), + } +} + +fn build_recorder_spec(args: &LiveSupervisorArgs) -> Option { + if !args.record_enabled { + return None; + } + + let app = args.app.clone(); + let session_id = args.session_id.clone(); + + Some(ChildSpec { + id: RecorderActor::name().to_string(), + restart: Restart::Transient, + spawn_fn: SpawnFn::new(move |supervisor_cell, child_id| { + let app = app.clone(); + let session_id = session_id.clone(); + + async move { + let rec_args = RecArgs { + app_dir: tauri::Manager::path(&app).app_data_dir().unwrap(), + session_id: session_id.clone(), + }; + + let (rec_ref, _) = Actor::spawn_linked( + Some(child_id.into()), + RecorderActor, + rec_args, + supervisor_cell, + ) + .await?; + + Ok(rec_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: None, + }) +} + +fn build_listener_spec(args: &LiveSupervisorArgs) -> ChildSpec { + let ctx = args.ctx.clone(); + let app = args.app.clone(); + let languages = args.languages.clone(); + let onboarding = args.onboarding; + let model = args.model.clone(); + let base_url = args.base_url.clone(); + let api_key = args.api_key.clone(); + let keywords = args.keywords.clone(); + let session_started_at = args.session_started_at; + let session_started_at_unix = args.session_started_at_unix; + + ChildSpec { + id: ListenerActor::name().to_string(), + restart: Restart::Transient, + spawn_fn: SpawnFn::new(move |supervisor_cell, child_id| { + let ctx = ctx.clone(); + let app = app.clone(); + let languages = languages.clone(); + let model = model.clone(); + let base_url = base_url.clone(); + let api_key = api_key.clone(); + let keywords = keywords.clone(); + + let supervisor_ref: ActorRef = supervisor_cell.clone().into(); + + async move { + let snapshot = ctx.read().await; + let listener_args = ListenerArgs { + app: app.clone(), + languages: languages.clone(), + onboarding, + model: model.clone(), + base_url: base_url.clone(), + api_key: api_key.clone(), + keywords: keywords.clone(), + mode: snapshot.mode, + sample_rate: snapshot.sample_rate, + supervisor: supervisor_ref.clone(), + session_started_at, + session_started_at_unix, + }; + + let (listener_ref, _) = Actor::spawn_linked( + Some(child_id.into()), + ListenerActor, + listener_args, + supervisor_cell, + ) + .await?; + + Ok(listener_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: Some(Duration::from_secs(60)), + } +} diff --git a/plugins/listener/src/actors/mod.rs b/plugins/listener/src/actors/mod.rs index 86ea98a66..b1c8bf9ce 100644 --- a/plugins/listener/src/actors/mod.rs +++ b/plugins/listener/src/actors/mod.rs @@ -1,15 +1,19 @@ mod batch; +mod context; +mod controller; mod listener; -mod processor; +mod live_supervisor; mod recorder; -mod session; +mod session_supervisor; mod source; pub use batch::*; +pub use context::*; +pub use controller::*; pub use listener::*; -pub use processor::*; +pub use live_supervisor::*; pub use recorder::*; -pub use session::*; +pub use session_supervisor::*; pub use source::*; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -17,8 +21,3 @@ pub enum ChannelMode { Single, Dual, } - -#[derive(Clone)] -pub struct AudioChunk { - data: Vec, -} diff --git a/plugins/listener/src/actors/processor.rs b/plugins/listener/src/actors/processor.rs deleted file mode 100644 index fec275c9e..000000000 --- a/plugins/listener/src/actors/processor.rs +++ /dev/null @@ -1,243 +0,0 @@ -use std::{ - collections::VecDeque, - sync::Arc, - time::{Duration, Instant}, -}; - -use ractor::{registry, Actor, ActorName, ActorProcessingErr, ActorRef}; -use tauri_specta::Event; - -use crate::{ - actors::{AudioChunk, ChannelMode, ListenerActor, ListenerMsg, RecMsg, RecorderActor}, - SessionEvent, -}; - -const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); - -pub enum ProcMsg { - Mic(AudioChunk), - Speaker(AudioChunk), - SetMode(ChannelMode), - Reset, -} - -pub struct ProcArgs { - pub app: tauri::AppHandle, -} - -pub struct ProcState { - app: tauri::AppHandle, - agc_m: hypr_agc::Agc, - agc_s: hypr_agc::Agc, - joiner: Joiner, - last_sent_mic: Option>, - last_sent_spk: Option>, - last_amp_emit: Instant, - mode: ChannelMode, -} - -impl ProcState { - fn reset_pipeline(&mut self) { - self.joiner.reset(); - self.last_sent_mic = None; - self.last_sent_spk = None; - self.agc_m = hypr_agc::Agc::default(); - self.agc_s = hypr_agc::Agc::default(); - self.last_amp_emit = Instant::now(); - } -} - -pub struct ProcessorActor {} - -impl ProcessorActor { - pub fn name() -> ActorName { - "processor_actor".into() - } -} - -impl Actor for ProcessorActor { - type Msg = ProcMsg; - type State = ProcState; - type Arguments = ProcArgs; - - async fn pre_start( - &self, - _myself: ActorRef, - args: Self::Arguments, - ) -> Result { - Ok(ProcState { - app: args.app.clone(), - joiner: Joiner::new(), - agc_m: hypr_agc::Agc::default(), - agc_s: hypr_agc::Agc::default(), - last_sent_mic: None, - last_sent_spk: None, - last_amp_emit: Instant::now(), - mode: ChannelMode::Dual, - }) - } - - async fn handle( - &self, - _myself: ActorRef, - msg: Self::Msg, - st: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match msg { - ProcMsg::Mic(mut c) => { - st.agc_m.process(&mut c.data); - let arc = Arc::<[f32]>::from(c.data); - st.joiner.push_mic(arc); - process_ready(st).await; - } - ProcMsg::Speaker(mut c) => { - st.agc_s.process(&mut c.data); - let arc = Arc::<[f32]>::from(c.data); - st.joiner.push_spk(arc); - process_ready(st).await; - } - ProcMsg::SetMode(mode) => { - if st.mode != mode { - st.mode = mode; - st.reset_pipeline(); - } - } - ProcMsg::Reset => { - st.reset_pipeline(); - } - } - Ok(()) - } -} - -async fn process_ready(st: &mut ProcState) { - while let Some((mic, spk)) = st.joiner.pop_pair(st.mode) { - let mut audio_sent_successfully = false; - - if let Some(cell) = registry::where_is(RecorderActor::name()) { - let mixed: Vec = mic - .iter() - .zip(spk.iter()) - .map(|(m, s)| (m + s).clamp(-1.0, 1.0)) - .collect(); - - let actor: ActorRef = cell.into(); - if let Err(e) = actor.cast(RecMsg::Audio(mixed)) { - tracing::error!(error = ?e, "failed_to_send_audio_to_recorder"); - } - } - - if let Some(cell) = registry::where_is(ListenerActor::name()) { - let (mic_bytes, spk_bytes) = if st.mode == ChannelMode::Single { - let mixed: Vec = mic - .iter() - .zip(spk.iter()) - .map(|(m, s)| (m + s).clamp(-1.0, 1.0)) - .collect(); - let mixed_bytes = hypr_audio_utils::f32_to_i16_bytes(mixed.iter().copied()); - ( - hypr_audio_utils::f32_to_i16_bytes(mic.iter().copied()), - mixed_bytes, - ) - } else { - ( - hypr_audio_utils::f32_to_i16_bytes(mic.iter().copied()), - hypr_audio_utils::f32_to_i16_bytes(spk.iter().copied()), - ) - }; - - let actor: ActorRef = cell.into(); - if actor - .cast(ListenerMsg::Audio(mic_bytes.into(), spk_bytes.into())) - .is_ok() - { - audio_sent_successfully = true; - st.last_sent_mic = Some(mic.clone()); - st.last_sent_spk = Some(spk.clone()); - } else { - tracing::warn!(actor = ListenerActor::name(), "cast_failed"); - } - } else { - tracing::debug!(actor = ListenerActor::name(), "unavailable"); - } - - if audio_sent_successfully && st.last_amp_emit.elapsed() >= AUDIO_AMPLITUDE_THROTTLE { - if let (Some(mic_data), Some(spk_data)) = (&st.last_sent_mic, &st.last_sent_spk) { - if let Err(e) = - SessionEvent::from((mic_data.as_ref(), spk_data.as_ref())).emit(&st.app) - { - tracing::error!("{:?}", e); - } - st.last_amp_emit = Instant::now(); - } - } - } -} - -struct Joiner { - mic: VecDeque>, - spk: VecDeque>, - silence_cache: std::collections::HashMap>, -} - -impl Joiner { - const MAX_LAG: usize = 4; - const MAX_QUEUE_SIZE: usize = 30; - - fn new() -> Self { - Self { - mic: VecDeque::new(), - spk: VecDeque::new(), - silence_cache: std::collections::HashMap::new(), - } - } - - fn reset(&mut self) { - self.mic.clear(); - self.spk.clear(); - } - - fn get_silence(&mut self, len: usize) -> Arc<[f32]> { - self.silence_cache - .entry(len) - .or_insert_with(|| Arc::from(vec![0.0; len])) - .clone() - } - - fn push_mic(&mut self, data: Arc<[f32]>) { - self.mic.push_back(data); - if self.mic.len() > Self::MAX_QUEUE_SIZE { - tracing::warn!("mic_queue_overflow"); - self.mic.pop_front(); - } - } - - fn push_spk(&mut self, data: Arc<[f32]>) { - self.spk.push_back(data); - if self.spk.len() > Self::MAX_QUEUE_SIZE { - tracing::warn!("spk_queue_overflow"); - self.spk.pop_front(); - } - } - - fn pop_pair(&mut self, mode: ChannelMode) -> Option<(Arc<[f32]>, Arc<[f32]>)> { - match (self.mic.front(), self.spk.front()) { - (Some(_), Some(_)) => { - let mic = self.mic.pop_front()?; - let spk = self.spk.pop_front()?; - Some((mic, spk)) - } - (Some(_), None) if mode == ChannelMode::Single || self.mic.len() > Self::MAX_LAG => { - let mic = self.mic.pop_front()?; - let spk = self.get_silence(mic.len()); - Some((mic, spk)) - } - (None, Some(_)) if self.spk.len() > Self::MAX_LAG => { - let spk = self.spk.pop_front()?; - let mic = self.get_silence(spk.len()); - Some((mic, spk)) - } - _ => None, - } - } -} diff --git a/plugins/listener/src/actors/recorder.rs b/plugins/listener/src/actors/recorder.rs index 241d86bb7..7750b7dde 100644 --- a/plugins/listener/src/actors/recorder.rs +++ b/plugins/listener/src/actors/recorder.rs @@ -1,17 +1,21 @@ use std::fs::File; use std::io::BufWriter; use std::path::PathBuf; +use std::pin::Pin; use std::time::Instant; +use hypr_audio::ResampledAsyncSource; use hypr_audio_utils::{ - decode_vorbis_to_wav_file, encode_wav_to_vorbis_file, VorbisEncodeSettings, + decode_vorbis_to_wav_file, drain_stream, encode_wav_to_vorbis_file, PushSource, + PushSourceHandle, VorbisEncodeSettings, }; use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef}; const FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1000); +const TARGET_SAMPLE_RATE_HZ: u32 = 16_000; pub enum RecMsg { - Audio(Vec), + Audio { samples: Vec, sample_rate: u32 }, } pub struct RecArgs { @@ -24,6 +28,8 @@ pub struct RecState { wav_path: PathBuf, ogg_path: PathBuf, last_flush: Instant, + resampler: ResampledAsyncSource, + input_handle: PushSourceHandle, } pub struct RecorderActor; @@ -34,6 +40,7 @@ impl RecorderActor { } } +#[ractor::async_trait] impl Actor for RecorderActor { type Msg = RecMsg; type State = RecState; @@ -58,7 +65,7 @@ impl Actor for RecorderActor { let spec = hound::WavSpec { channels: 1, - sample_rate: 16000, + sample_rate: TARGET_SAMPLE_RATE_HZ, bits_per_sample: 32, sample_format: hound::SampleFormat::Float, }; @@ -69,11 +76,16 @@ impl Actor for RecorderActor { hound::WavWriter::create(&wav_path, spec)? }; + let (source, input_handle) = PushSource::new(TARGET_SAMPLE_RATE_HZ); + let resampler = ResampledAsyncSource::new(source, TARGET_SAMPLE_RATE_HZ); + Ok(RecState { writer: Some(writer), wav_path, ogg_path, last_flush: Instant::now(), + resampler, + input_handle, }) } @@ -84,17 +96,11 @@ impl Actor for RecorderActor { st: &mut Self::State, ) -> Result<(), ActorProcessingErr> { match msg { - RecMsg::Audio(v) => { - if let Some(ref mut writer) = st.writer { - for s in v { - writer.write_sample(s)?; - } - - if st.last_flush.elapsed() >= FLUSH_INTERVAL { - writer.flush()?; - st.last_flush = Instant::now(); - } - } + RecMsg::Audio { + samples, + sample_rate, + } => { + st.push_samples(samples, sample_rate)?; } } @@ -106,6 +112,9 @@ impl Actor for RecorderActor { _myself: ActorRef, st: &mut Self::State, ) -> Result<(), ActorProcessingErr> { + st.input_handle.close(); + st.drain_resampler()?; + if let Some(mut writer) = st.writer.take() { writer.flush()?; writer.finalize()?; @@ -140,3 +149,38 @@ impl Actor for RecorderActor { fn into_actor_err(err: hypr_audio_utils::Error) -> ActorProcessingErr { Box::new(err) } + +impl RecState { + fn push_samples( + &mut self, + samples: Vec, + sample_rate: u32, + ) -> Result<(), ActorProcessingErr> { + if samples.is_empty() || sample_rate == 0 { + return Ok(()); + } + + self.input_handle.push(samples, sample_rate); + self.drain_resampler() + } + + fn drain_resampler(&mut self) -> Result<(), ActorProcessingErr> { + let writer = match self.writer.as_mut() { + Some(writer) => writer, + None => return Ok(()), + }; + + let progress = drain_stream(Pin::new(&mut self.resampler), |sample| { + writer + .write_sample(sample) + .map_err(|err| Box::new(err) as ActorProcessingErr) + })?; + + if progress.processed && self.last_flush.elapsed() >= FLUSH_INTERVAL { + writer.flush()?; + self.last_flush = Instant::now(); + } + + Ok(()) + } +} diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs deleted file mode 100644 index 5d9d3c5f2..000000000 --- a/plugins/listener/src/actors/session.rs +++ /dev/null @@ -1,365 +0,0 @@ -use std::time::{Instant, SystemTime}; - -use tauri::Manager; -use tauri_specta::Event; - -use ractor::{ - call_t, concurrency, registry, Actor, ActorCell, ActorName, ActorProcessingErr, ActorRef, - RpcReplyPort, SupervisionEvent, -}; -use tokio_util::sync::CancellationToken; - -use crate::{ - actors::{ - ListenerActor, ListenerArgs, ListenerMsg, ProcArgs, ProcMsg, ProcessorActor, RecArgs, - RecMsg, RecorderActor, SourceActor, SourceArgs, SourceMsg, - }, - SessionEvent, -}; - -#[derive(Debug)] -pub enum SessionMsg { - SetMicMute(bool), - GetMicMute(RpcReplyPort), - GetMicDeviceName(RpcReplyPort>), - ChangeMicDevice(Option), -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] -pub struct SessionParams { - pub session_id: String, - pub languages: Vec, - pub onboarding: bool, - pub record_enabled: bool, - pub model: String, - pub base_url: String, - pub api_key: String, - pub keywords: Vec, -} - -pub struct SessionArgs { - pub app: tauri::AppHandle, - pub params: SessionParams, -} - -pub struct SessionState { - app: tauri::AppHandle, - token: CancellationToken, - params: SessionParams, - started_at_instant: Instant, - started_at_system: SystemTime, -} - -pub struct SessionActor; - -impl SessionActor { - pub fn name() -> ActorName { - "session".into() - } -} - -impl Actor for SessionActor { - type Msg = SessionMsg; - type State = SessionState; - type Arguments = SessionArgs; - - async fn pre_start( - &self, - myself: ActorRef, - args: Self::Arguments, - ) -> Result { - let cancellation_token = CancellationToken::new(); - let started_at_instant = Instant::now(); - let started_at_system = SystemTime::now(); - - { - use tauri_plugin_tray::TrayPluginExt; - let _ = args.app.set_start_disabled(true); - } - - let state = SessionState { - app: args.app, - token: cancellation_token, - params: args.params, - started_at_instant, - started_at_system, - }; - - { - let c = myself.get_cell(); - Self::start_all_actors(c, &state).await?; - } - - SessionEvent::RunningActive {}.emit(&state.app).unwrap(); - Ok(state) - } - - async fn handle( - &self, - _myself: ActorRef, - message: Self::Msg, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match message { - SessionMsg::SetMicMute(muted) => { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - actor.cast(SourceMsg::SetMicMute(muted))?; - } - SessionEvent::MicMuted { value: muted }.emit(&state.app)?; - } - - SessionMsg::GetMicDeviceName(reply) => { - if !reply.is_closed() { - let device_name = if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - call_t!(actor, SourceMsg::GetMicDevice, 100).unwrap_or(None) - } else { - None - }; - - let _ = reply.send(device_name); - } - } - - SessionMsg::GetMicMute(reply) => { - let muted = if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - call_t!(actor, SourceMsg::GetMicMute, 100)? - } else { - false - }; - - if !reply.is_closed() { - let _ = reply.send(muted); - } - } - - SessionMsg::ChangeMicDevice(device) => { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - actor.cast(SourceMsg::SetMicDevice(device))?; - } - } - } - - Ok(()) - } - - async fn handle_supervisor_evt( - &self, - myself: ActorRef, - event: SupervisionEvent, - _state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match event { - SupervisionEvent::ActorStarted(actor) => { - tracing::info!("{:?}_actor_started", actor.get_name()); - } - SupervisionEvent::ActorTerminated(actor, _maybe_state, exit_reason) => { - let actor_name = actor - .get_name() - .map(|n| n.to_string()) - .unwrap_or_else(|| "unknown".to_string()); - - tracing::error!( - actor = %actor_name, - reason = ?exit_reason, - "child_actor_terminated_stopping_session" - ); - - myself.stop(None); - } - SupervisionEvent::ActorFailed(_, _) => {} - _ => {} - } - - Ok(()) - } - - async fn post_stop( - &self, - _myself: ActorRef, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - state.token.cancel(); - - { - Self::stop_all_actors().await; - } - - { - use tauri_plugin_tray::TrayPluginExt; - let _ = state.app.set_start_disabled(false); - } - - SessionEvent::Inactive {}.emit(&state.app)?; - - Ok(()) - } -} - -impl SessionActor { - async fn start_all_actors( - supervisor: ActorCell, - state: &SessionState, - ) -> Result<(), ActorProcessingErr> { - Self::start_processor(supervisor.clone(), state).await?; - Self::start_source(supervisor.clone(), state).await?; - Self::start_listener(supervisor.clone(), state, None).await?; - - if state.params.record_enabled { - Self::start_recorder(supervisor, state).await?; - } - - Ok(()) - } - - async fn stop_all_actors() { - Self::stop_processor().await; - Self::stop_source().await; - Self::stop_listener().await; - Self::stop_recorder().await; - } - - async fn start_source( - supervisor: ActorCell, - state: &SessionState, - ) -> Result, ActorProcessingErr> { - let (ar, _) = Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - token: state.token.clone(), - mic_device: None, - onboarding: state.params.onboarding, - }, - supervisor, - ) - .await?; - Ok(ar) - } - - async fn stop_source() { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(3)), - ) - .await; - } - } - - async fn start_processor( - supervisor: ActorCell, - state: &SessionState, - ) -> Result, ActorProcessingErr> { - let (ar, _) = Actor::spawn_linked( - Some(ProcessorActor::name()), - ProcessorActor {}, - ProcArgs { - app: state.app.clone(), - }, - supervisor, - ) - .await?; - Ok(ar) - } - - async fn stop_processor() { - if let Some(cell) = registry::where_is(ProcessorActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(3)), - ) - .await; - } - } - - async fn start_recorder( - supervisor: ActorCell, - state: &SessionState, - ) -> Result, ActorProcessingErr> { - let (rec_ref, _) = Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir: state.app.path().app_data_dir().unwrap(), - session_id: state.params.session_id.clone(), - }, - supervisor, - ) - .await?; - Ok(rec_ref) - } - - async fn stop_recorder() { - if let Some(cell) = registry::where_is(RecorderActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(6)), - ) - .await; - } - } - - async fn start_listener( - supervisor: ActorCell, - session_state: &SessionState, - listener_args: Option, - ) -> Result, ActorProcessingErr> { - use crate::actors::ChannelMode; - - let mode = if listener_args.is_none() { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - match call_t!(actor, SourceMsg::GetMode, 500) { - Ok(m) => m, - Err(_) => ChannelMode::Dual, - } - } else { - ChannelMode::Dual - } - } else { - ChannelMode::Dual - }; - - let (listen_ref, _) = Actor::spawn_linked( - Some(ListenerActor::name()), - ListenerActor, - listener_args.unwrap_or(ListenerArgs { - app: session_state.app.clone(), - languages: session_state.params.languages.clone(), - onboarding: session_state.params.onboarding, - model: session_state.params.model.clone(), - base_url: session_state.params.base_url.clone(), - api_key: session_state.params.api_key.clone(), - keywords: session_state.params.keywords.clone(), - mode, - session_started_at: session_state.started_at_instant, - session_started_at_unix: session_state.started_at_system, - }), - supervisor, - ) - .await?; - Ok(listen_ref) - } - - async fn stop_listener() { - if let Some(cell) = registry::where_is(ListenerActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(3)), - ) - .await; - } - } -} diff --git a/plugins/listener/src/actors/session_supervisor.rs b/plugins/listener/src/actors/session_supervisor.rs new file mode 100644 index 000000000..51651ba38 --- /dev/null +++ b/plugins/listener/src/actors/session_supervisor.rs @@ -0,0 +1,72 @@ +use std::sync::Arc; +use tokio::time::Duration; + +use ractor::{Actor, ActorRef, SpawnErr}; +use ractor_supervisor::{ + ChildSpec, Restart, SpawnFn, Supervisor, SupervisorArguments, SupervisorMsg, SupervisorOptions, + SupervisorStrategy, +}; + +use super::{ + live_supervisor_spec, ControllerActor, ControllerActorArgs, SessionParams, SessionShared, +}; + +pub const SESSION_SUPERVISOR_NAME: &str = "session_supervisor"; + +pub async fn start_session_supervisor( + app: tauri::AppHandle, + params: SessionParams, +) -> Result<(ActorRef, ractor::concurrency::JoinHandle<()>), SpawnErr> { + let shared = SessionShared::new(app, params); + + let child_specs = vec![ + controller_actor_spec(shared.clone()), + live_supervisor_spec(shared.clone()), + ]; + + let supervisor_options = SupervisorOptions { + strategy: SupervisorStrategy::RestForOne, + max_restarts: 5, + max_window: Duration::from_secs(10), + reset_after: Some(Duration::from_secs(30)), + }; + + Supervisor::spawn( + SESSION_SUPERVISOR_NAME.into(), + SupervisorArguments { + child_specs, + options: supervisor_options, + }, + ) + .await +} + +fn controller_actor_spec(shared: Arc) -> ChildSpec { + ChildSpec { + id: ControllerActor::name().to_string(), + restart: Restart::Transient, + spawn_fn: SpawnFn::new(move |supervisor_cell, child_id| { + let shared = shared.clone(); + + async move { + let supervisor_ref: ActorRef = supervisor_cell.clone().into(); + let args = ControllerActorArgs { + shared, + supervisor: supervisor_ref, + }; + + let (controller_ref, _) = Actor::spawn_linked( + Some(child_id.into()), + ControllerActor, + args, + supervisor_cell, + ) + .await?; + + Ok(controller_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: Some(Duration::from_secs(60)), + } +} diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index 1f8378d5c..faf3bd8b0 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -1,45 +1,70 @@ +use std::collections::VecDeque; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::time::{Duration, Instant}; use futures_util::StreamExt; use ractor::{registry, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; -use tokio_util::sync::CancellationToken; +use ractor_supervisor::supervisor::SupervisorMsg; +use tauri_specta::Event; -use crate::actors::{AudioChunk, ChannelMode, ListenerActor, ListenerMsg, ProcMsg, ProcessorActor}; +use crate::actors::{ + ChannelMode, ListenerActor, ListenerMsg, LiveContextHandle, LiveSnapshot, RecMsg, RecorderActor, +}; +use crate::SessionEvent; use hypr_audio::{ - is_using_headphone, AudioInput, DeviceEvent, DeviceMonitor, DeviceMonitorHandle, - ResampledAsyncSource, + is_using_headphone, AudioInput, DeviceEvent, DeviceMonitor, DeviceMonitorHandle, MicInput, + SpeakerInput, }; -// We previously used AEC; it has been removed. Keep this constant to preserve chunking size. const AEC_BLOCK_SIZE: usize = 512; -const SAMPLE_RATE: u32 = 16000; +const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); pub enum SourceMsg { SetMicMute(bool), GetMicMute(RpcReplyPort), SetMicDevice(Option), GetMicDevice(RpcReplyPort>), - GetMode(RpcReplyPort), } pub struct SourceArgs { pub mic_device: Option, - pub token: CancellationToken, pub onboarding: bool, + pub app: tauri::AppHandle, + pub ctx: LiveContextHandle, + pub supervisor: ActorRef, } pub struct SourceState { mic_device: Option, - token: CancellationToken, onboarding: bool, + app: tauri::AppHandle, + ctx: LiveContextHandle, + supervisor: ActorRef, mic_muted: Arc, run_task: Option>, - stream_cancel_token: Option, _device_monitor_handle: Option, _silence_stream_tx: Option>, _device_event_thread: Option>, current_mode: ChannelMode, + sample_rate: u32, + agc_m: hypr_agc::Agc, + agc_s: hypr_agc::Agc, + joiner: Joiner, + last_sent_mic: Option>, + last_sent_spk: Option>, + last_amp_emit: Instant, +} + +impl SourceState { + fn reset_pipeline(&mut self) { + self.joiner.reset(); + self.last_sent_mic = None; + self.last_sent_spk = None; + self.agc_m = hypr_agc::Agc::default(); + self.agc_s = hypr_agc::Agc::default(); + self.last_amp_emit = Instant::now(); + } } pub struct SourceActor; @@ -50,6 +75,7 @@ impl SourceActor { } } +#[ractor::async_trait] impl Actor for SourceActor { type Msg = SourceMsg; type State = SourceState; @@ -109,17 +135,47 @@ impl Actor for SourceActor { .or_else(|| Some(AudioInput::get_default_device_name())); tracing::info!(mic_device = ?mic_device); + let sample_rate = MicInput::new(mic_device.clone()) + .map_err(|err| -> ActorProcessingErr { Box::new(err) })? + .sample_rate(); + tracing::info!(sample_rate, "mic_sample_rate_resolved"); + + #[cfg(any(target_os = "macos", target_os = "windows"))] + match SpeakerInput::new() { + Ok(input) => { + let speaker_rate_hz = input.sample_rate(); + if speaker_rate_hz != sample_rate { + tracing::warn!( + mic_sample_rate = sample_rate, + speaker_sample_rate = speaker_rate_hz, + "sample_rate_mismatch" + ); + } + } + Err(err) => { + tracing::warn!(error = ?err, "speaker_sample_rate_unavailable"); + } + } + let mut st = SourceState { mic_device, - token: args.token, onboarding: args.onboarding, + app: args.app, + ctx: args.ctx, + supervisor: args.supervisor, mic_muted: Arc::new(AtomicBool::new(false)), run_task: None, - stream_cancel_token: None, _device_monitor_handle: Some(device_monitor_handle), _silence_stream_tx: silence_stream_tx, _device_event_thread: Some(device_event_thread), current_mode: ChannelMode::Dual, + sample_rate, + agc_m: hypr_agc::Agc::default(), + agc_s: hypr_agc::Agc::default(), + joiner: Joiner::new(), + last_sent_mic: None, + last_sent_spk: None, + last_amp_emit: Instant::now(), }; start_source_loop(&myself, &mut st).await?; @@ -147,21 +203,24 @@ impl Actor for SourceActor { } } SourceMsg::SetMicDevice(dev) => { + tracing::info!(device = ?dev, "source_actor_set_mic_device"); st.mic_device = dev; + let device = st.mic_device.clone(); - if let Some(cancel_token) = st.stream_cancel_token.take() { - cancel_token.cancel(); - } + st.ctx + .write(|snap| { + snap.device_id = device.clone(); + }) + .await; if let Some(t) = st.run_task.take() { t.abort(); } - start_source_loop(&myself, st).await?; - } - SourceMsg::GetMode(reply) => { - if !reply.is_closed() { - let _ = reply.send(st.current_mode); - } + + request_rest_for_one(&st.supervisor, SourceActor::name()); + tracing::info!("source_actor_stopping_for_device_change"); + myself.stop(Some("device_change".into())); + return Ok(()); } } @@ -173,9 +232,7 @@ impl Actor for SourceActor { _myself: ActorRef, st: &mut Self::State, ) -> Result<(), ActorProcessingErr> { - if let Some(cancel_token) = st.stream_cancel_token.take() { - cancel_token.cancel(); - } + tracing::info!("source_actor_post_stop"); if let Some(task) = st.run_task.take() { task.abort(); } @@ -185,16 +242,18 @@ impl Actor for SourceActor { } async fn start_source_loop( - myself: &ActorRef, + _myself: &ActorRef, st: &mut SourceState, ) -> Result<(), ActorProcessingErr> { - let myself2 = myself.clone(); - let token = st.token.clone(); let mic_muted = st.mic_muted.clone(); let mic_device = st.mic_device.clone(); + let app = st.app.clone(); - let stream_cancel_token = CancellationToken::new(); - st.stream_cancel_token = Some(stream_cancel_token.clone()); + let previous_snapshot = st.ctx.read().await; + let default_snapshot = LiveSnapshot::default(); + let was_initialized = previous_snapshot.device_id.is_some() + || previous_snapshot.mode != default_snapshot.mode + || previous_snapshot.sample_rate != default_snapshot.sample_rate; #[cfg(target_os = "macos")] let new_mode = if !st.onboarding && !is_using_headphone() { @@ -206,157 +265,274 @@ async fn start_source_loop( #[cfg(not(target_os = "macos"))] let new_mode = ChannelMode::Dual; - let mode_changed = st.current_mode != new_mode; + let new_sample_rate = st.sample_rate; + + let mode_changed = previous_snapshot.mode != new_mode; + let rate_changed = previous_snapshot.sample_rate != new_sample_rate; st.current_mode = new_mode; - tracing::info!(?new_mode, mode_changed, "start_source_loop"); + tracing::info!( + ?new_mode, + mode_changed, + rate_changed, + sample_rate = new_sample_rate, + "start_source_loop" + ); + + let device_id = st.mic_device.clone(); + + st.ctx + .write(|snap| { + snap.mode = new_mode; + snap.sample_rate = new_sample_rate; + snap.device_id = device_id.clone(); + }) + .await; + + st.reset_pipeline(); - if let Some(cell) = registry::where_is(ProcessorActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(ProcMsg::Reset); + if was_initialized && (mode_changed || rate_changed) { + request_rest_for_one(&st.supervisor, ListenerActor::name()); } - if mode_changed { - if let Some(cell) = registry::where_is(ProcessorActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(ProcMsg::SetMode(new_mode)); - } + let handle = spawn_capture_task(mic_device, mic_muted, new_mode, new_sample_rate, app); - if let Some(cell) = registry::where_is(ListenerActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(ListenerMsg::ChangeMode(new_mode)); - } - } + st.run_task = Some(handle); + Ok(()) +} - let use_mixed = new_mode == ChannelMode::Single; - - let handle = if use_mixed { - #[cfg(target_os = "macos")] - { - tokio::spawn(async move { - let mic_stream = { - let mut mic_input = AudioInput::from_mic(mic_device).unwrap(); - ResampledAsyncSource::new(mic_input.stream(), SAMPLE_RATE) - .chunks(AEC_BLOCK_SIZE) - }; - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - let spk_stream = { - let mut spk_input = hypr_audio::AudioInput::from_speaker(); - ResampledAsyncSource::new(spk_input.stream(), SAMPLE_RATE) - .chunks(AEC_BLOCK_SIZE) - }; - - tokio::pin!(mic_stream); - tokio::pin!(spk_stream); - - loop { - let Some(cell) = registry::where_is(ProcessorActor::name()) else { - tracing::warn!("processor_actor_not_found"); - continue; - }; - let proc: ActorRef = cell.into(); - - tokio::select! { - _ = token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - myself2.stop(None); - return; - } - _ = stream_cancel_token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - return; - } - mic_next = mic_stream.next() => { - if let Some(data) = mic_next { - let output_data = if mic_muted.load(Ordering::Relaxed) { - vec![0.0; data.len()] - } else { - data - }; - let msg = ProcMsg::Mic(AudioChunk { data: output_data }); - let _ = proc.cast(msg); - } else { - break; - } - } - spk_next = spk_stream.next() => { - if let Some(data) = spk_next { - let msg = ProcMsg::Speaker(AudioChunk{ data }); - let _ = proc.cast(msg); - } else { - break; - } - } +fn spawn_capture_task( + mic_device: Option, + mic_muted: Arc, + mode: ChannelMode, + sample_rate: u32, + app: tauri::AppHandle, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mic_device_name = mic_device.clone(); + let mic_stream = match MicInput::new(mic_device) { + Ok(mic_input) => mic_input.stream().chunks(AEC_BLOCK_SIZE), + Err(err) => { + tracing::error!( + error = ?err, + mic_device = ?mic_device_name, + "mic_stream_init_failed" + ); + return; + } + }; + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + let spk_stream = match SpeakerInput::new().and_then(|input| input.stream()) { + Ok(stream) => stream.chunks(AEC_BLOCK_SIZE), + Err(err) => { + tracing::error!(error = ?err, "speaker_stream_init_failed"); + return; + } + }; + + tokio::pin!(mic_stream); + tokio::pin!(spk_stream); + + let mut agc_m = hypr_agc::Agc::default(); + let mut agc_s = hypr_agc::Agc::default(); + let mut joiner = Joiner::new(); + let mut last_sent_mic: Option> = None; + let mut last_sent_spk: Option> = None; + let mut last_amp_emit = Instant::now(); + + loop { + tokio::select! { + mic_next = mic_stream.next() => { + if let Some(mut data) = mic_next { + let output_data = if mic_muted.load(Ordering::Relaxed) { + vec![0.0; data.len()] + } else { + agc_m.process(&mut data); + data + }; + let arc = Arc::<[f32]>::from(output_data); + joiner.push_mic(arc); + process_ready_inline(&mut joiner, mode, sample_rate, &mut last_sent_mic, &mut last_sent_spk, &mut last_amp_emit, &app).await; + } else { + break; + } + } + spk_next = spk_stream.next() => { + if let Some(mut data) = spk_next { + agc_s.process(&mut data); + let arc = Arc::<[f32]>::from(data); + joiner.push_spk(arc); + process_ready_inline(&mut joiner, mode, sample_rate, &mut last_sent_mic, &mut last_sent_spk, &mut last_amp_emit, &app).await; + } else { + break; } } - }) + } } - #[cfg(not(target_os = "macos"))] - { - tokio::spawn(async move {}) + }) +} + +async fn process_ready_inline( + joiner: &mut Joiner, + mode: ChannelMode, + sample_rate: u32, + last_sent_mic: &mut Option>, + last_sent_spk: &mut Option>, + last_amp_emit: &mut Instant, + app: &tauri::AppHandle, +) { + while let Some((mic, spk)) = joiner.pop_pair(mode) { + let mut audio_sent_successfully = false; + + if let Some(cell) = registry::where_is(RecorderActor::name()) { + let mixed: Vec = mic + .iter() + .zip(spk.iter()) + .map(|(m, s)| (m + s).clamp(-1.0, 1.0)) + .collect(); + + let actor: ActorRef = cell.into(); + if let Err(e) = actor.cast(RecMsg::Audio { + samples: mixed, + sample_rate, + }) { + tracing::error!(error = ?e, "failed_to_send_audio_to_recorder"); + } } - } else { - tokio::spawn(async move { - let mic_stream = { - let mut mic_input = hypr_audio::AudioInput::from_mic(mic_device).unwrap(); - ResampledAsyncSource::new(mic_input.stream(), SAMPLE_RATE).chunks(AEC_BLOCK_SIZE) - }; - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - let spk_stream = { - let mut spk_input = hypr_audio::AudioInput::from_speaker(); - ResampledAsyncSource::new(spk_input.stream(), SAMPLE_RATE).chunks(AEC_BLOCK_SIZE) + + if let Some(cell) = registry::where_is(ListenerActor::name()) { + let (mic_bytes, spk_bytes) = if mode == ChannelMode::Single { + let mixed: Vec = mic + .iter() + .zip(spk.iter()) + .map(|(m, s)| (m + s).clamp(-1.0, 1.0)) + .collect(); + let mixed_bytes = hypr_audio_utils::f32_to_i16_bytes(mixed.iter().copied()); + ( + hypr_audio_utils::f32_to_i16_bytes(mic.iter().copied()), + mixed_bytes, + ) + } else { + ( + hypr_audio_utils::f32_to_i16_bytes(mic.iter().copied()), + hypr_audio_utils::f32_to_i16_bytes(spk.iter().copied()), + ) }; - tokio::pin!(mic_stream); - tokio::pin!(spk_stream); - loop { - let Some(cell) = registry::where_is(ProcessorActor::name()) else { - tracing::warn!("processor_actor_not_found"); - continue; - }; - let proc: ActorRef = cell.into(); - - tokio::select! { - _ = token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - myself2.stop(None); - return; - } - _ = stream_cancel_token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - return; - } - mic_next = mic_stream.next() => { - if let Some(data) = mic_next { - let output_data = if mic_muted.load(Ordering::Relaxed) { - vec![0.0; data.len()] - } else { - data - }; - - let msg = ProcMsg::Mic(AudioChunk{ data: output_data }); - let _ = proc.cast(msg); - } else { - break; - } - } - spk_next = spk_stream.next() => { - if let Some(data) = spk_next { - let msg = ProcMsg::Speaker(AudioChunk{ data }); - let _ = proc.cast(msg); - } else { - break; - } - } + let actor: ActorRef = cell.into(); + if actor + .cast(ListenerMsg::Audio(mic_bytes.into(), spk_bytes.into())) + .is_ok() + { + audio_sent_successfully = true; + *last_sent_mic = Some(mic.clone()); + *last_sent_spk = Some(spk.clone()); + } else { + tracing::warn!(actor = ListenerActor::name(), "cast_failed"); + } + } else { + tracing::debug!(actor = ListenerActor::name(), "unavailable"); + } + + if audio_sent_successfully && last_amp_emit.elapsed() >= AUDIO_AMPLITUDE_THROTTLE { + if let (Some(mic_data), Some(spk_data)) = + (last_sent_mic.as_ref(), last_sent_spk.as_ref()) + { + if let Err(e) = SessionEvent::from((mic_data.as_ref(), spk_data.as_ref())).emit(app) + { + tracing::error!("{:?}", e); } + *last_amp_emit = Instant::now(); } - }) - }; + } + } +} - st.run_task = Some(handle); - Ok(()) +fn request_rest_for_one(supervisor: &ActorRef, child_id: ActorName) { + let child_id_string = child_id.to_string(); + tracing::info!(child = child_id_string, "requesting_rest_for_one_spawn"); + match supervisor.cast(SupervisorMsg::RestForOneSpawn { + child_id: child_id_string.clone(), + }) { + Ok(_) => { + tracing::info!(child = child_id_string, "requested_rest_for_one_spawn"); + } + Err(error) => { + tracing::warn!( + ?error, + child = child_id_string, + "failed_to_request_rest_for_one" + ); + } + } +} + +struct Joiner { + mic: VecDeque>, + spk: VecDeque>, + silence_cache: std::collections::HashMap>, +} + +impl Joiner { + const MAX_LAG: usize = 4; + const MAX_QUEUE_SIZE: usize = 30; + + fn new() -> Self { + Self { + mic: VecDeque::new(), + spk: VecDeque::new(), + silence_cache: std::collections::HashMap::new(), + } + } + + fn reset(&mut self) { + self.mic.clear(); + self.spk.clear(); + } + + fn get_silence(&mut self, len: usize) -> Arc<[f32]> { + self.silence_cache + .entry(len) + .or_insert_with(|| Arc::from(vec![0.0; len])) + .clone() + } + + fn push_mic(&mut self, data: Arc<[f32]>) { + self.mic.push_back(data); + if self.mic.len() > Self::MAX_QUEUE_SIZE { + tracing::warn!("mic_queue_overflow"); + self.mic.pop_front(); + } + } + + fn push_spk(&mut self, data: Arc<[f32]>) { + self.spk.push_back(data); + if self.spk.len() > Self::MAX_QUEUE_SIZE { + tracing::warn!("spk_queue_overflow"); + self.spk.pop_front(); + } + } + + fn pop_pair(&mut self, mode: ChannelMode) -> Option<(Arc<[f32]>, Arc<[f32]>)> { + match (self.mic.front(), self.spk.front()) { + (Some(_), Some(_)) => { + let mic = self.mic.pop_front()?; + let spk = self.spk.pop_front()?; + Some((mic, spk)) + } + (Some(_), None) if mode == ChannelMode::Single || self.mic.len() > Self::MAX_LAG => { + let mic = self.mic.pop_front()?; + let spk = self.get_silence(mic.len()); + Some((mic, spk)) + } + (None, Some(_)) if self.spk.len() > Self::MAX_LAG => { + let spk = self.spk.pop_front()?; + let mic = self.get_silence(spk.len()); + Some((mic, spk)) + } + _ => None, + } + } } diff --git a/plugins/listener/src/events.rs b/plugins/listener/src/events.rs index 99b628a22..90dd2657c 100644 --- a/plugins/listener/src/events.rs +++ b/plugins/listener/src/events.rs @@ -40,6 +40,7 @@ impl From<(&[f32], &[f32])> for SessionEvent { let mic = (mic_chunk .iter() .map(|&x| x.abs()) + .filter(|x| x.is_finite()) .max_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap_or(0.0) * 100.0) as u16; @@ -47,6 +48,7 @@ impl From<(&[f32], &[f32])> for SessionEvent { let speaker = (speaker_chunk .iter() .map(|&x| x.abs()) + .filter(|x| x.is_finite()) .max_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap_or(0.0) * 100.0) as u16; diff --git a/plugins/listener/src/ext.rs b/plugins/listener/src/ext.rs index 8cfba9478..cecfbf636 100644 --- a/plugins/listener/src/ext.rs +++ b/plugins/listener/src/ext.rs @@ -1,11 +1,14 @@ use std::future::Future; use std::sync::{Arc, Mutex}; -use ractor::{call_t, concurrency, registry, Actor, ActorRef}; +use ractor::{call_t, concurrency, registry, ActorRef}; use tauri_specta::Event; use crate::{ - actors::{BatchActor, BatchArgs, SessionActor, SessionArgs, SessionMsg, SessionParams}, + actors::{ + spawn_batch_actor, start_session_supervisor, BatchArgs, ControllerActor, ControllerMsg, + SessionParams, SESSION_SUPERVISOR_NAME, + }, SessionEvent, }; @@ -60,15 +63,15 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn get_current_microphone_device(&self) -> Result, crate::Error> { - if let Some(cell) = registry::where_is(SessionActor::name()) { - let actor: ActorRef = cell.into(); + if let Some(cell) = registry::where_is(ControllerActor::name()) { + let actor: ActorRef = cell.into(); - match call_t!(actor, SessionMsg::GetMicDeviceName, 500) { + match call_t!(actor, ControllerMsg::GetMicDeviceName, 500) { Ok(device_name) => Ok(device_name), Err(_) => Ok(None), } } else { - Err(crate::Error::ActorNotFound(SessionActor::name())) + Err(crate::Error::ActorNotFound(ControllerActor::name())) } } @@ -77,9 +80,9 @@ impl> ListenerPluginExt for T { &self, device_name: impl Into, ) -> Result<(), crate::Error> { - if let Some(cell) = registry::where_is(SessionActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(SessionMsg::ChangeMicDevice(Some(device_name.into()))); + if let Some(cell) = registry::where_is(ControllerActor::name()) { + let actor: ActorRef = cell.into(); + let _ = actor.cast(ControllerMsg::ChangeMicDevice(Some(device_name.into()))); } Ok(()) @@ -87,7 +90,7 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn get_state(&self) -> crate::fsm::State { - if let Some(_) = registry::where_is(SessionActor::name()) { + if let Some(_) = registry::where_is(SESSION_SUPERVISOR_NAME.to_string()) { crate::fsm::State::RunningActive } else { crate::fsm::State::Inactive @@ -96,10 +99,10 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn get_mic_muted(&self) -> bool { - if let Some(cell) = registry::where_is(SessionActor::name()) { - let actor: ActorRef = cell.into(); + if let Some(cell) = registry::where_is(ControllerActor::name()) { + let actor: ActorRef = cell.into(); - match call_t!(actor, SessionMsg::GetMicMute, 100) { + match call_t!(actor, ControllerMsg::GetMicMute, 100) { Ok(muted) => muted, Err(_) => false, } @@ -110,51 +113,80 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn set_mic_muted(&self, muted: bool) { - if let Some(cell) = registry::where_is(SessionActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(SessionMsg::SetMicMute(muted)); + if let Some(cell) = registry::where_is(ControllerActor::name()) { + let actor: ActorRef = cell.into(); + let _ = actor.cast(ControllerMsg::SetMicMute(muted)); } } #[tracing::instrument(skip_all)] async fn start_session(&self, params: SessionParams) { + if registry::where_is(SESSION_SUPERVISOR_NAME.to_string()).is_some() { + return; + } + let state = self.state::(); let guard = state.lock().await; + let app = guard.app.clone(); + drop(guard); - let _ = Actor::spawn( - Some(SessionActor::name()), - SessionActor, - SessionArgs { - app: guard.app.clone(), - params, - }, - ) - .await; + if let Err(err) = start_session_supervisor(app, params).await { + tracing::error!(error = ?err, "failed_to_spawn_session_supervisor"); + } } #[tracing::instrument(skip_all)] async fn stop_session(&self) { - if let Some(cell) = registry::where_is(SessionActor::name()) { - { - let state = self.state::(); + if let Some(cell) = registry::where_is(SESSION_SUPERVISOR_NAME.to_string()) { + let state = self.state::(); + let app_handle = { let guard = state.lock().await; - SessionEvent::Finalizing {}.emit(&guard.app).unwrap(); - } + guard.app.clone() + }; - let actor: ActorRef = cell.into(); - let _ = actor + let actor: ActorRef = cell.into(); + tracing::info!("stop_session: requesting supervisor shutdown"); + let stop_result = actor .stop_and_wait(None, Some(concurrency::Duration::from_secs(10))) .await; + tracing::info!(?stop_result, "stop_session: supervisor shutdown complete"); + + if let Err(err) = (SessionEvent::Inactive {}).emit(&app_handle) { + tracing::warn!(?err, "failed_to_emit_inactive_fallback"); + } else { + tracing::info!("stop_session: emitted_inactive_fallback"); + } } } #[tracing::instrument(skip_all)] async fn run_batch(&self, params: BatchParams) -> Result<(), crate::Error> { - let channels = params.channels.unwrap_or(1); + let metadata = tokio::task::spawn_blocking({ + let path = params.file_path.clone(); + move || hypr_audio_utils::audio_file_metadata(path) + }) + .await + .map_err(|err| { + crate::Error::BatchStartFailed(format!("failed to join audio metadata task: {err:?}")) + })? + .map_err(|err| { + crate::Error::BatchStartFailed(format!("failed to read audio metadata: {err}")) + })?; + + if let Some(requested) = params.channels { + if requested != metadata.channels { + tracing::warn!( + requested, + actual = metadata.channels, + "batch params channel override ignored in favor of file metadata" + ); + } + } let listen_params = owhisper_interface::ListenParams { model: params.model.clone(), - channels, + channels: metadata.channels, + sample_rate: metadata.sample_rate, languages: params.languages.clone(), keywords: params.keywords.clone(), redemption_time_ms: None, @@ -171,20 +203,27 @@ impl> ListenerPluginExt for T { let app = guard.app.clone(); drop(guard); - match Actor::spawn( - Some(BatchActor::name()), - BatchActor, - BatchArgs { - app, - file_path: params.file_path.clone(), - base_url: params.base_url.clone(), - api_key: params.api_key.clone(), - listen_params, - start_notifier: start_notifier.clone(), - }, - ) - .await - { + if registry::where_is(SESSION_SUPERVISOR_NAME.to_string()).is_some() { + let error = "live session must be stopped before running batch".to_string(); + tracing::error!("{}", error); + if let Ok(mut notifier) = start_notifier.lock() { + if let Some(tx) = notifier.take() { + let _ = tx.send(Err(error.clone())); + } + } + return Err(crate::Error::BatchStartFailed(error)); + } + + let args = BatchArgs { + app, + file_path: params.file_path.clone(), + base_url: params.base_url.clone(), + api_key: params.api_key.clone(), + listen_params: listen_params.clone(), + start_notifier: start_notifier.clone(), + }; + + match spawn_batch_actor(args).await { Ok(_) => { tracing::info!("batch actor spawned successfully"); let state = self.state::(); @@ -196,11 +235,11 @@ impl> ListenerPluginExt for T { .unwrap(); } Err(e) => { - tracing::error!("batch actor spawn failed: {:?}", e); + tracing::error!("batch supervisor spawn failed: {:?}", e); if let Ok(mut notifier) = start_notifier.lock() { if let Some(tx) = notifier.take() { - let _ = - tx.send(Err(format!("failed to spawn batch actor: {:?}", e))); + let _ = tx + .send(Err(format!("failed to spawn batch supervisor: {e:?}"))); } } return Err(e.into()); diff --git a/plugins/listener/src/lib.rs b/plugins/listener/src/lib.rs index 0b9df4b1a..dbec71ad8 100644 --- a/plugins/listener/src/lib.rs +++ b/plugins/listener/src/lib.rs @@ -22,7 +22,7 @@ pub struct State { impl State { pub async fn get_state(&self) -> fsm::State { - if let Some(_) = ractor::registry::where_is(actors::SessionActor::name()) { + if let Some(_) = ractor::registry::where_is(actors::SESSION_SUPERVISOR_NAME.to_string()) { crate::fsm::State::RunningActive } else { crate::fsm::State::Inactive diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index d91e9e1b8..4f8d87ac3 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -37,6 +37,7 @@ impl ExternalSTTActor { } } +#[ractor::async_trait] impl Actor for ExternalSTTActor { type Msg = ExternalSTTMessage; type State = ExternalSTTState; @@ -61,6 +62,8 @@ impl Actor for ExternalSTTActor { let text = text.trim(); if !text.is_empty() && !text.contains("[WebSocket]") + && !text.contains("Sent interim text:") + && !text.contains("[TranscriptionHandler]") && !text.contains("/v1/status") { tracing::info!("{}", text); diff --git a/plugins/local-stt/src/server/internal.rs b/plugins/local-stt/src/server/internal.rs index 01d715101..90344ce7f 100644 --- a/plugins/local-stt/src/server/internal.rs +++ b/plugins/local-stt/src/server/internal.rs @@ -36,6 +36,7 @@ impl InternalSTTActor { } } +#[ractor::async_trait] impl Actor for InternalSTTActor { type Msg = InternalSTTMessage; type State = InternalSTTState;