diff --git a/.changeset/tender-buses-yell.md b/.changeset/tender-buses-yell.md new file mode 100644 index 00000000..321a6b22 --- /dev/null +++ b/.changeset/tender-buses-yell.md @@ -0,0 +1,5 @@ +"@elevenlabs/client": minor +"@elevenlabs/react": minor +"@elevenlabs/types": minor + +Expose input and output audio streams on conversations and React hooks. diff --git a/packages/client/src/AudioStream.ts b/packages/client/src/AudioStream.ts new file mode 100644 index 00000000..10715f0e --- /dev/null +++ b/packages/client/src/AudioStream.ts @@ -0,0 +1 @@ +export type AudioStreamListener = (stream: MediaStream | null) => void; diff --git a/packages/client/src/BaseConversation.test.ts b/packages/client/src/BaseConversation.test.ts index c8b493a3..11a1f9d1 100644 --- a/packages/client/src/BaseConversation.test.ts +++ b/packages/client/src/BaseConversation.test.ts @@ -42,6 +42,12 @@ class TestConversation extends BaseConversation { public getOutputByteFrequencyData(): Uint8Array { return new Uint8Array(0); } + public getInputAudioStream(): MediaStream | null { + return null; + } + public getOutputAudioStream(): MediaStream | null { + return null; + } public getInputVolume(): number { return 0; } diff --git a/packages/client/src/BaseConversation.ts b/packages/client/src/BaseConversation.ts index 3e696a9a..09119ebd 100644 --- a/packages/client/src/BaseConversation.ts +++ b/packages/client/src/BaseConversation.ts @@ -156,6 +156,10 @@ export abstract class BaseConversation { return this.endSessionWithDetails({ reason: "user" }); } + public abstract getInputAudioStream(): MediaStream | null; + + public abstract getOutputAudioStream(): MediaStream | null; + private endSessionWithDetails = async (details: DisconnectionDetails) => { if (this.status !== "connected" && this.status !== "connecting") return; this.updateStatus("disconnecting"); diff --git a/packages/client/src/InputController.ts b/packages/client/src/InputController.ts index a1855c3d..da7c308e 100644 --- a/packages/client/src/InputController.ts +++ b/packages/client/src/InputController.ts @@ -1,4 +1,5 @@ import type { FormatConfig } from "./utils/BaseConnection.js"; +import type { AudioStreamListener } from "./AudioStream.js"; export type InputDeviceConfig = { inputDeviceId?: string; @@ -10,6 +11,10 @@ export interface InputController { setDevice(config?: Partial & InputDeviceConfig): Promise; setMuted(isMuted: boolean): Promise; isMuted(): boolean; + /** Returns the user input audio stream, if one is available. */ + getAudioStream(): MediaStream | null; + addAudioStreamListener(listener: AudioStreamListener): void; + removeAudioStreamListener(listener: AudioStreamListener): void; /** * @deprecated AnalyserNode is a web-only API and will not work on all diff --git a/packages/client/src/OutputController.ts b/packages/client/src/OutputController.ts index 700f34c7..3755d4fe 100644 --- a/packages/client/src/OutputController.ts +++ b/packages/client/src/OutputController.ts @@ -1,4 +1,5 @@ import type { FormatConfig } from "./utils/BaseConnection.js"; +import type { AudioStreamListener } from "./AudioStream.js"; export type OutputDeviceConfig = { outputDeviceId?: string; @@ -9,6 +10,10 @@ export interface OutputController { setDevice(config?: Partial & OutputDeviceConfig): Promise; setVolume(volume: number): void; interrupt(resetDuration?: number): void; + /** Returns the assistant output audio stream, if one is available. */ + getAudioStream(): MediaStream | null; + addAudioStreamListener(listener: AudioStreamListener): void; + removeAudioStreamListener(listener: AudioStreamListener): void; /** * @deprecated AnalyserNode is a web-only API and will not work on all diff --git a/packages/client/src/TextConversation.ts b/packages/client/src/TextConversation.ts index 84b05441..f420d56e 100644 --- a/packages/client/src/TextConversation.ts +++ b/packages/client/src/TextConversation.ts @@ -24,6 +24,14 @@ export class TextConversation extends BaseConversation { return EMPTY_FREQUENCY_DATA; } + public getInputAudioStream(): MediaStream | null { + return null; + } + + public getOutputAudioStream(): MediaStream | null { + return null; + } + public getInputVolume(): number { return 0; } diff --git a/packages/client/src/VoiceConversation.ts b/packages/client/src/VoiceConversation.ts index b0a379d4..e3247cdc 100644 --- a/packages/client/src/VoiceConversation.ts +++ b/packages/client/src/VoiceConversation.ts @@ -13,6 +13,7 @@ import { type PartialOptions, } from "./BaseConversation.js"; import type { InputController } from "./InputController.js"; +import type { AudioStreamListener } from "./AudioStream.js"; import type { OutputController } from "./OutputController.js"; import { setupStrategy } from "./platform/VoiceSessionSetup.js"; @@ -115,6 +116,14 @@ export class VoiceConversation extends BaseConversation { } }; + private handleInputAudioStream: AudioStreamListener = stream => { + this.options.onInputAudioStream?.(stream); + }; + + private handleOutputAudioStream: AudioStreamListener = stream => { + this.options.onOutputAudioStream?.(stream); + }; + protected constructor( options: Options, connection: BaseConnection, @@ -128,6 +137,9 @@ export class VoiceConversation extends BaseConversation { playbackEventTarget?.addListener(this.handlePlaybackEvent); + input.addAudioStreamListener(this.handleInputAudioStream); + output.addAudioStreamListener(this.handleOutputAudioStream); + if (wakeLock) { // Wake locks are automatically released when a page is hidden like when switching tabs // so attempt to re-acquire lock when page becomes visible again @@ -149,6 +161,8 @@ export class VoiceConversation extends BaseConversation { this.cleanUp(); this.playbackEventTarget?.removeListener(this.handlePlaybackEvent); this.playbackEventTarget = null; + this.input.removeAudioStreamListener(this.handleInputAudioStream); + this.output.removeAudioStreamListener(this.handleOutputAudioStream); await super.handleEndSession(); if (this.visibilityChangeHandler) { @@ -225,6 +239,14 @@ export class VoiceConversation extends BaseConversation { return this.output.getVolume(); } + public getInputAudioStream(): MediaStream | null { + return this.input.getAudioStream(); + } + + public getOutputAudioStream(): MediaStream | null { + return this.output.getAudioStream(); + } + public async changeInputDevice({ sampleRate, format, diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 8091dc77..f8c9bf2b 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -21,6 +21,7 @@ export type { OutputController, OutputDeviceConfig, } from "./OutputController.js"; +export type { AudioStreamListener } from "./AudioStream.js"; export type { InputConfig } from "./utils/input.js"; export type { OutputConfig } from "./utils/output.js"; export type { diff --git a/packages/client/src/utils/WebRTCConnection.test.ts b/packages/client/src/utils/WebRTCConnection.test.ts index 4bc275a8..4e64d51c 100644 --- a/packages/client/src/utils/WebRTCConnection.test.ts +++ b/packages/client/src/utils/WebRTCConnection.test.ts @@ -174,7 +174,9 @@ describe("WebRTCConnection", () => { vi.stubGlobal("AudioContext", MockAudioContext); vi.stubGlobal( "MediaStream", - vi.fn((tracks: unknown[]) => ({ getTracks: () => tracks })) + vi.fn(function MediaStream(tracks: unknown[]) { + return { getTracks: () => tracks }; + }) ); const connection = await WebRTCConnection.create({ @@ -298,4 +300,74 @@ describe("WebRTCConnection", () => { } } ); + + it("notifies output audio stream listeners when the stream changes", async () => { + const mockRoom = new Room() as any; + (mockRoom.on as ReturnType).mockImplementation( + (event: string, callback: () => void) => { + if (event === "connected") { + queueMicrotask(callback); + } + } + ); + (mockRoom.once as ReturnType).mockImplementation( + (event: string, callback: () => void) => { + if (event === "signalConnected") { + queueMicrotask(callback); + } + } + ); + + const connection = await WebRTCConnection.create({ + conversationToken: "test-token", + connectionType: "webrtc", + }); + const listener = vi.fn(); + + connection.output.addAudioStreamListener(listener); + expect(listener).toHaveBeenCalledWith(null); + + const stream = { getTracks: () => [] } as unknown as MediaStream; + connection["setOutputAudioStream"](stream); + expect(connection.output.getAudioStream()).toBe(stream); + expect(listener).toHaveBeenLastCalledWith(stream); + + connection.close(); + expect(listener).toHaveBeenLastCalledWith(null); + }); + + it("notifies input audio stream listeners when the stream changes", async () => { + const mockRoom = new Room() as any; + (mockRoom.on as ReturnType).mockImplementation( + (event: string, callback: () => void) => { + if (event === "connected") { + queueMicrotask(callback); + } + } + ); + (mockRoom.once as ReturnType).mockImplementation( + (event: string, callback: () => void) => { + if (event === "signalConnected") { + queueMicrotask(callback); + } + } + ); + + const connection = await WebRTCConnection.create({ + conversationToken: "test-token", + connectionType: "webrtc", + }); + const listener = vi.fn(); + + connection.input.addAudioStreamListener(listener); + expect(listener).toHaveBeenCalledWith(null); + + const stream = { getTracks: () => [] } as unknown as MediaStream; + connection["setInputAudioStream"](stream); + expect(connection.input.getAudioStream()).toBe(stream); + expect(listener).toHaveBeenLastCalledWith(stream); + + connection.close(); + expect(listener).toHaveBeenLastCalledWith(null); + }); }); diff --git a/packages/client/src/utils/WebRTCConnection.ts b/packages/client/src/utils/WebRTCConnection.ts index 76f5f419..eb54ca74 100644 --- a/packages/client/src/utils/WebRTCConnection.ts +++ b/packages/client/src/utils/WebRTCConnection.ts @@ -27,6 +27,7 @@ import { import { arrayBufferToBase64 } from "./audio.js"; import { loadRawAudioProcessor } from "./rawAudioProcessor.generated.js"; import type { InputController, InputDeviceConfig } from "../InputController.js"; +import type { AudioStreamListener } from "../AudioStream.js"; import type { OutputController, OutputDeviceConfig, @@ -60,6 +61,10 @@ export class WebRTCConnection extends BaseConnection { private audioCaptureContext: AudioContext | null = null; private audioElements: HTMLAudioElement[] = []; private outputDeviceId: string | null = null; + private inputAudioStream: MediaStream | null = null; + private inputAudioStreamListeners = new Set(); + private outputAudioStream: MediaStream | null = null; + private outputAudioStreamListeners = new Set(); private inputAnalyser: AnalyserNode | null = null; private inputAudioContext: AudioContext | null = null; @@ -155,6 +160,14 @@ export class WebRTCConnection extends BaseConnection { } }, isMuted: () => this._isMuted, + getAudioStream: () => this.inputAudioStream, + addAudioStreamListener: (listener: AudioStreamListener) => { + this.inputAudioStreamListeners.add(listener); + listener(this.inputAudioStream); + }, + removeAudioStreamListener: (listener: AudioStreamListener) => { + this.inputAudioStreamListeners.delete(listener); + }, getAnalyser: () => this.inputAnalyser ?? undefined, getVolume: () => { if (this._isMuted) return 0; @@ -200,6 +213,14 @@ export class WebRTCConnection extends BaseConnection { // Audio interruption is managed by the server/agent }, getAnalyser: () => this.outputAnalyser ?? undefined, + getAudioStream: () => this.outputAudioStream, + addAudioStreamListener: (listener: AudioStreamListener) => { + this.outputAudioStreamListeners.add(listener); + listener(this.outputAudioStream); + }, + removeAudioStreamListener: (listener: AudioStreamListener) => { + this.outputAudioStreamListeners.delete(listener); + }, getVolume: () => this.outputVolumeProvider.getVolume(), getByteFrequencyData: (buffer: Uint8Array) => { this.outputVolumeProvider.getByteFrequencyData(buffer); @@ -330,6 +351,7 @@ export class WebRTCConnection extends BaseConnection { Track.Source.Microphone )?.track; if (micTrack) { + connection.setInputAudioStreamFromTrack(micTrack.mediaStreamTrack); connection.setupInputAnalyser(micTrack.mediaStreamTrack); } @@ -438,6 +460,10 @@ export class WebRTCConnection extends BaseConnection { // Store reference for volume control this.audioElements.push(audioElement); + // Expose the agent's remote track immediately; audio capture below is + // best-effort and may fail in non-browser environments. + this.setOutputAudioStreamFromTrack(remoteAudioTrack.mediaStreamTrack); + // Apply current volume if it exists (for when volume was set before audio track arrived) if (this.audioElements.length === 1) { // First audio element - trigger a callback to sync with current volume @@ -499,6 +525,7 @@ export class WebRTCConnection extends BaseConnection { this.inputAudioContext = null; this.inputAnalyser = null; } + this.setInputAudioStream(null); // Clean up audio capture context (non-blocking) if (this.audioCaptureContext) { @@ -507,6 +534,7 @@ export class WebRTCConnection extends BaseConnection { }); this.audioCaptureContext = null; } + this.setOutputAudioStream(null); // Clean up audio elements this.audioElements.forEach(element => { @@ -601,6 +629,43 @@ export class WebRTCConnection extends BaseConnection { this.outputVolumeProvider = provider; } + private createMediaStream( + mediaStreamTrack: MediaStreamTrack + ): MediaStream | null { + if (typeof MediaStream === "undefined") { + return null; + } + return new MediaStream([mediaStreamTrack]); + } + + private setInputAudioStreamFromTrack( + mediaStreamTrack: MediaStreamTrack + ): void { + this.setInputAudioStream(this.createMediaStream(mediaStreamTrack)); + } + + private setOutputAudioStreamFromTrack( + mediaStreamTrack: MediaStreamTrack + ): void { + this.setOutputAudioStream(this.createMediaStream(mediaStreamTrack)); + } + + private setInputAudioStream(stream: MediaStream | null): void { + if (this.inputAudioStream === stream) { + return; + } + this.inputAudioStream = stream; + this.inputAudioStreamListeners.forEach(listener => listener(stream)); + } + + private setOutputAudioStream(stream: MediaStream | null): void { + if (this.outputAudioStream === stream) { + return; + } + this.outputAudioStream = stream; + this.outputAudioStreamListeners.forEach(listener => listener(stream)); + } + private async setupAudioCapture(track: RemoteAudioTrack) { try { // Create audio context for processing @@ -738,6 +803,7 @@ export class WebRTCConnection extends BaseConnection { }); // Reconnect the input analyser to the new track + this.setInputAudioStreamFromTrack(audioTrack.mediaStreamTrack); this.setupInputAnalyser(audioTrack.mediaStreamTrack); } catch (error) { console.error("Failed to change input device:", error); diff --git a/packages/client/src/utils/input.ts b/packages/client/src/utils/input.ts index 1be80d62..82ece356 100644 --- a/packages/client/src/utils/input.ts +++ b/packages/client/src/utils/input.ts @@ -4,6 +4,7 @@ import { isIosDevice } from "./compatibility.js"; import type { AudioWorkletConfig } from "../BaseConversation.js"; import { addLibsamplerateModule } from "./addLibsamplerateModule.js"; import type { InputController, InputDeviceConfig } from "../InputController.js"; +import type { AudioStreamListener } from "../AudioStream.js"; import { createAnalyserVolumeProvider, type VolumeProvider, @@ -135,6 +136,7 @@ export class MediaDeviceInput implements InputController, InputEventTarget { private muted = false; private readonly volumeProvider: VolumeProvider; + private readonly inputAudioStreamListeners = new Set(); private constructor( private readonly context: AudioContext, @@ -175,6 +177,19 @@ export class MediaDeviceInput implements InputController, InputEventTarget { this.volumeProvider.getByteFrequencyData(buffer); } + public getAudioStream(): MediaStream { + return this.inputStream; + } + + public addAudioStreamListener(listener: AudioStreamListener): void { + this.inputAudioStreamListeners.add(listener); + listener(this.inputStream); + } + + public removeAudioStreamListener(listener: AudioStreamListener): void { + this.inputAudioStreamListeners.delete(listener); + } + public isMuted(): boolean { return this.muted; } @@ -200,6 +215,8 @@ export class MediaDeviceInput implements InputController, InputEventTarget { "change", this.handlePermissionsChange ); + this.inputAudioStreamListeners.forEach(listener => listener(null)); + this.inputAudioStreamListeners.clear(); await this.context.close(); } @@ -250,6 +267,9 @@ export class MediaDeviceInput implements InputController, InputEventTarget { this.inputStream = newInputStream; this.mediaStreamSource = this.context.createMediaStreamSource(newInputStream); + this.inputAudioStreamListeners.forEach(listener => + listener(newInputStream) + ); // Reconnect the audio graph this.mediaStreamSource.connect(this.analyser); diff --git a/packages/client/src/utils/output.ts b/packages/client/src/utils/output.ts index bdf82af1..5ec5469a 100644 --- a/packages/client/src/utils/output.ts +++ b/packages/client/src/utils/output.ts @@ -2,6 +2,7 @@ import { loadAudioConcatProcessor } from "./audioConcatProcessor.generated.js"; import type { FormatConfig } from "./connection.js"; import type { AudioWorkletConfig } from "../BaseConversation.js"; import { addLibsamplerateModule } from "./addLibsamplerateModule.js"; +import type { AudioStreamListener } from "../AudioStream.js"; import type { OutputController, OutputDeviceConfig, @@ -104,7 +105,8 @@ export class MediaDeviceOutput analyser, gain, worklet, - audioElement + audioElement, + destination.stream ); return newOutput; @@ -126,13 +128,15 @@ export class MediaDeviceOutput private interrupted = false; private interruptTimeout: ReturnType | null = null; private readonly volumeProvider: VolumeProvider; + private readonly outputAudioStreamListeners = new Set(); private constructor( private readonly context: AudioContext, private readonly analyser: AnalyserNode, private readonly gain: GainNode, private readonly worklet: AudioWorkletNode, - private readonly audioElement: HTMLAudioElement + private readonly audioElement: HTMLAudioElement, + private readonly audioStream: MediaStream ) { // Start the MessagePort to enable addEventListener to work // (required when using addEventListener instead of onmessage) @@ -155,6 +159,19 @@ export class MediaDeviceOutput this.volumeProvider.getByteFrequencyData(buffer); } + public getAudioStream(): MediaStream { + return this.audioStream; + } + + public addAudioStreamListener(listener: AudioStreamListener): void { + this.outputAudioStreamListeners.add(listener); + listener(this.audioStream); + } + + public removeAudioStreamListener(listener: AudioStreamListener): void { + this.outputAudioStreamListeners.delete(listener); + } + public addListener(listener: PlaybackListener): void { this.worklet.port.addEventListener("message", listener); } @@ -235,6 +252,8 @@ export class MediaDeviceOutput if (this.audioElement.parentNode) { this.audioElement.parentNode.removeChild(this.audioElement); } + this.outputAudioStreamListeners.forEach(listener => listener(null)); + this.outputAudioStreamListeners.clear(); this.audioElement.pause(); await this.context.close(); } diff --git a/packages/react/src/conversation/ConversationAudioStream.test.tsx b/packages/react/src/conversation/ConversationAudioStream.test.tsx new file mode 100644 index 00000000..da443ec1 --- /dev/null +++ b/packages/react/src/conversation/ConversationAudioStream.test.tsx @@ -0,0 +1,152 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import React, { useContext } from "react"; +import { renderHook, act } from "@testing-library/react"; +import { + Conversation, + type Callbacks, + type ConversationLifecycleOptions, +} from "@elevenlabs/client"; +import { ConversationProvider } from "./ConversationProvider.js"; +import { + ConversationContext, + type ConversationContextValue, +} from "./ConversationContext.js"; +import { useConversationAudioStream } from "./ConversationAudioStream.js"; +import { useConversation } from "./useConversation.js"; + +vi.mock("@elevenlabs/client", async importOriginal => { + const actual = await importOriginal(); + return { ...actual, Conversation: { startSession: vi.fn() } }; +}); + +const createMockConversation = ( + inputAudioStream: MediaStream | null = null, + outputAudioStream: MediaStream | null = null +) => + ({ + getId: vi.fn().mockReturnValue("test-id"), + endSession: vi.fn().mockResolvedValue(undefined), + setMicMuted: vi.fn(), + setVolume: vi.fn(), + getInputAudioStream: vi.fn().mockReturnValue(inputAudioStream), + getOutputAudioStream: vi.fn().mockReturnValue(outputAudioStream), + }) as unknown as Conversation; + +function useTestHook() { + const ctx = useContext(ConversationContext) as ConversationContextValue; + const streams = useConversationAudioStream(); + return { startSession: ctx.startSession, streams }; +} + +function createWrapper(props: Record = {}) { + return function Wrapper({ children }: React.PropsWithChildren) { + return {children}; + }; +} + +type MockStartSessionOptions = Partial & + Record; + +describe("useConversationAudioStream", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("throws when used outside a ConversationProvider", () => { + expect(() => renderHook(() => useConversationAudioStream())).toThrow( + "useConversationAudioStream must be used within a ConversationProvider" + ); + }); + + it("returns null initially", () => { + const { result } = renderHook(() => useConversationAudioStream(), { + wrapper: createWrapper(), + }); + + expect(result.current.inputAudioStream).toBeNull(); + expect(result.current.outputAudioStream).toBeNull(); + }); + + it("updates when input and output audio stream callbacks fire", async () => { + const mockConversation = createMockConversation(); + vi.mocked(Conversation.startSession).mockResolvedValue(mockConversation); + + const { result } = renderHook(() => useTestHook(), { + wrapper: createWrapper(), + }); + + await act(async () => { + result.current.startSession(); + }); + + const [[opts]] = vi.mocked(Conversation.startSession).mock + .calls as unknown as [[MockStartSessionOptions]]; + const inputStream = {} as MediaStream; + const outputStream = {} as MediaStream; + + act(() => { + opts.onInputAudioStream?.(inputStream); + opts.onOutputAudioStream?.(outputStream); + }); + + expect(result.current.streams.inputAudioStream).toBe(inputStream); + expect(result.current.streams.outputAudioStream).toBe(outputStream); + }); + + it("clears streams when the session disconnects", async () => { + const inputStream = {} as MediaStream; + const outputStream = {} as MediaStream; + const mockConversation = createMockConversation(inputStream, outputStream); + vi.mocked(Conversation.startSession).mockResolvedValue(mockConversation); + + const { result } = renderHook(() => useTestHook(), { + wrapper: createWrapper(), + }); + + await act(async () => { + result.current.startSession(); + }); + + const [[opts]] = vi.mocked(Conversation.startSession).mock + .calls as unknown as [[MockStartSessionOptions]]; + + act(() => { + opts.onInputAudioStream?.(inputStream); + opts.onOutputAudioStream?.(outputStream); + }); + expect(result.current.streams.inputAudioStream).toBe(inputStream); + expect(result.current.streams.outputAudioStream).toBe(outputStream); + + act(() => { + opts.onDisconnect?.({ reason: "user" }); + }); + expect(result.current.streams.inputAudioStream).toBeNull(); + expect(result.current.streams.outputAudioStream).toBeNull(); + }); + + it("is included in useConversation", async () => { + const mockConversation = createMockConversation(); + vi.mocked(Conversation.startSession).mockResolvedValue(mockConversation); + + const { result } = renderHook(() => useConversation(), { + wrapper: createWrapper(), + }); + + await act(async () => { + result.current.startSession(); + }); + + const [[opts]] = vi.mocked(Conversation.startSession).mock + .calls as unknown as [[MockStartSessionOptions]]; + const inputStream = {} as MediaStream; + const outputStream = {} as MediaStream; + + act(() => { + opts.onInputAudioStream?.(inputStream); + opts.onOutputAudioStream?.(outputStream); + }); + + expect(result.current.inputAudioStream).toBe(inputStream); + expect(result.current.outputAudioStream).toBe(outputStream); + }); +}); diff --git a/packages/react/src/conversation/ConversationAudioStream.tsx b/packages/react/src/conversation/ConversationAudioStream.tsx new file mode 100644 index 00000000..ee7d7349 --- /dev/null +++ b/packages/react/src/conversation/ConversationAudioStream.tsx @@ -0,0 +1,66 @@ +import { createContext, useContext, useMemo, useState } from "react"; +import { useRegisterCallbacks } from "./ConversationContext.js"; + +export type ConversationAudioStreamValue = { + inputAudioStream: MediaStream | null; + outputAudioStream: MediaStream | null; +}; + +const ConversationAudioStreamContext = + createContext(null); + +/** + * Tracks input and output audio streams exposed by the active conversation. + * Must be rendered inside a `ConversationProvider`. + */ +export function ConversationAudioStreamProvider({ + children, +}: React.PropsWithChildren) { + const [inputAudioStream, setInputAudioStream] = + useState(null); + const [outputAudioStream, setOutputAudioStream] = + useState(null); + + useRegisterCallbacks({ + onInputAudioStream(stream) { + setInputAudioStream(stream); + }, + onOutputAudioStream(stream) { + setOutputAudioStream(stream); + }, + onDisconnect() { + setInputAudioStream(null); + setOutputAudioStream(null); + }, + }); + + const value = useMemo( + () => ({ + inputAudioStream, + outputAudioStream, + }), + [inputAudioStream, outputAudioStream] + ); + + return ( + + {children} + + ); +} + +/** + * Returns the user input and assistant output audio streams, or `null` before + * each stream is available. Re-renders when either stream changes. + * + * Must be used within a `ConversationProvider`. + */ +export function useConversationAudioStream(): ConversationAudioStreamValue { + const ctx = useContext(ConversationAudioStreamContext); + if (!ctx) { + throw new Error( + "useConversationAudioStream must be used within a ConversationProvider" + ); + } + return ctx; +} diff --git a/packages/react/src/conversation/ConversationProvider.tsx b/packages/react/src/conversation/ConversationProvider.tsx index ca76890a..d360f279 100644 --- a/packages/react/src/conversation/ConversationProvider.tsx +++ b/packages/react/src/conversation/ConversationProvider.tsx @@ -33,6 +33,7 @@ import { } from "./ConversationInput.js"; import { ConversationModeProvider } from "./ConversationMode.js"; import { ConversationFeedbackProvider } from "./ConversationFeedback.js"; +import { ConversationAudioStreamProvider } from "./ConversationAudioStream.js"; import { ConversationClientToolsProvider, buildClientTools, @@ -50,6 +51,7 @@ const SUB_PROVIDERS_WITHOUT_PROPS: React.ComponentType[ ConversationStatusProvider, ConversationModeProvider, ConversationFeedbackProvider, + ConversationAudioStreamProvider, ConversationClientToolsProvider, ]; diff --git a/packages/react/src/conversation/types.ts b/packages/react/src/conversation/types.ts index d3c3b870..c8b43a19 100644 --- a/packages/react/src/conversation/types.ts +++ b/packages/react/src/conversation/types.ts @@ -26,6 +26,8 @@ export type HookCallbacks = Pick< | "onError" | "onMessage" | "onAudio" + | "onInputAudioStream" + | "onOutputAudioStream" | "onModeChange" | "onStatusChange" | "onCanSendFeedbackChange" diff --git a/packages/react/src/conversation/useConversation.ts b/packages/react/src/conversation/useConversation.ts index 30c1c101..c8008da2 100644 --- a/packages/react/src/conversation/useConversation.ts +++ b/packages/react/src/conversation/useConversation.ts @@ -6,6 +6,7 @@ import { useConversationStatus } from "./ConversationStatus.js"; import { useConversationInput } from "./ConversationInput.js"; import { useConversationMode } from "./ConversationMode.js"; import { useConversationFeedback } from "./ConversationFeedback.js"; +import { useConversationAudioStream } from "./ConversationAudioStream.js"; import { useRawConversation, useRegisterCallbacks, @@ -45,6 +46,7 @@ export function useConversation(props: UseConversationOptions = {}) { const { isMuted, setMuted } = useConversationInput(); const { mode, isSpeaking, isListening } = useConversationMode(); const { canSendFeedback, sendFeedback } = useConversationFeedback(); + const { inputAudioStream, outputAudioStream } = useConversationAudioStream(); const startSession = useCallback( (options?: HookOptions) => { @@ -64,7 +66,7 @@ export function useConversation(props: UseConversationOptions = {}) { ...options, } as HookOptions); }, - [controls, hookOptionsRef] + [controls] ); const conversation = useRawConversation(); @@ -91,6 +93,8 @@ export function useConversation(props: UseConversationOptions = {}) { mode, isSpeaking, isListening, + inputAudioStream, + outputAudioStream, canSendFeedback, sendFeedback, }; diff --git a/packages/react/src/index.ts b/packages/react/src/index.ts index d560b121..cfb177d2 100644 --- a/packages/react/src/index.ts +++ b/packages/react/src/index.ts @@ -27,6 +27,7 @@ export { ConversationProvider } from "./conversation/ConversationProvider.js"; export { useConversationControls } from "./conversation/ConversationControls.js"; export { useConversationStatus } from "./conversation/ConversationStatus.js"; export { useConversationInput } from "./conversation/ConversationInput.js"; +export { useConversationAudioStream } from "./conversation/ConversationAudioStream.js"; export { useConversationMode } from "./conversation/ConversationMode.js"; export { useConversationFeedback } from "./conversation/ConversationFeedback.js"; export { useRawConversation } from "./conversation/ConversationContext.js"; @@ -35,6 +36,7 @@ export { useConversationClientTool } from "./conversation/ConversationClientTool export type { UseConversationOptions } from "./conversation/useConversation.js"; export type { ConversationControlsValue } from "./conversation/ConversationControls.js"; export type { ConversationInputValue } from "./conversation/ConversationInput.js"; +export type { ConversationAudioStreamValue } from "./conversation/ConversationAudioStream.js"; export type { ConversationStatus, ConversationStatusValue, diff --git a/packages/react/src/test-globals.d.ts b/packages/react/src/test-globals.d.ts index a21dd352..fb2bc40d 100644 --- a/packages/react/src/test-globals.d.ts +++ b/packages/react/src/test-globals.d.ts @@ -1 +1,3 @@ +/// + declare const console: Pick; diff --git a/packages/types/src/types.ts b/packages/types/src/types.ts index 22a8a729..d0645b93 100644 --- a/packages/types/src/types.ts +++ b/packages/types/src/types.ts @@ -61,6 +61,8 @@ export type Callbacks = { onError?: (message: string, context?: any) => void; onMessage?: (props: MessagePayload) => void; onAudio?: (base64Audio: string) => void; + onInputAudioStream?: (stream: MediaStream | null) => void; + onOutputAudioStream?: (stream: MediaStream | null) => void; onModeChange?: (prop: { mode: Mode }) => void; onStatusChange?: (prop: { status: Status }) => void; onCanSendFeedbackChange?: (prop: { canSendFeedback: boolean }) => void; @@ -110,6 +112,8 @@ export const CALLBACK_KEYS = [ "onError", "onMessage", "onAudio", + "onInputAudioStream", + "onOutputAudioStream", "onModeChange", "onStatusChange", "onCanSendFeedbackChange",