diff --git a/.changeset/ninety-geckos-jam.md b/.changeset/ninety-geckos-jam.md new file mode 100644 index 000000000..3aa339f0b --- /dev/null +++ b/.changeset/ninety-geckos-jam.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Fix voice interruption transcript spill, add ConnectionPool for inference websockets, and log TTS websocket pool misses. diff --git a/agents/src/connection_pool.test.ts b/agents/src/connection_pool.test.ts new file mode 100644 index 000000000..b6002815a --- /dev/null +++ b/agents/src/connection_pool.test.ts @@ -0,0 +1,346 @@ +// SPDX-FileCopyrightText: 2025 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it, vi } from 'vitest'; +import { ConnectionPool } from './connection_pool.js'; + +describe('ConnectionPool', () => { + const makeConnectCb = () => { + let n = 0; + return vi.fn(async (_timeout: number): Promise => `conn_${++n}`); + }; + + describe('basic operations', () => { + it('should create and return a connection', async () => { + const connections: string[] = []; + const connectCb = vi.fn(async (_timeout: number): Promise => { + const conn = `conn_${connections.length}`; + connections.push(conn); + return conn; + }); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + const conn = await pool.get(); + expect(conn).toBe('conn_0'); + expect(connectCb).toHaveBeenCalledTimes(1); + + pool.put(conn); + const conn2 = await pool.get(); + expect(conn2).toBe('conn_0'); // Should reuse + expect(connectCb).toHaveBeenCalledTimes(1); + }); + + it('should create new connection when none available', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + const conn1 = await pool.get(); + pool.put(conn1); + const conn2 = await pool.get(); + expect(conn1).toBe(conn2); // Should reuse + expect(connectCb).toHaveBeenCalledTimes(1); + }); + + it('should remove connection from pool', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + const conn = await pool.get(); + pool.put(conn); + pool.remove(conn); + + const conn2 = await pool.get(); + expect(conn2).not.toBe(conn); // Should create new connection + expect(connectCb).toHaveBeenCalledTimes(2); + expect(closeCb).toHaveBeenCalledTimes(1); + }); + }); + + describe('maxSessionDuration', () => { + it('should expire connections after maxSessionDuration', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + maxSessionDuration: 100, // 100ms + }); + + const conn1 = await pool.get(); + pool.put(conn1); + + // Wait for expiration + await new Promise((resolve) => setTimeout(resolve, 150)); + + const conn2 = await pool.get(); + expect(conn2).not.toBe(conn1); // Should create new connection + expect(connectCb).toHaveBeenCalledTimes(2); + expect(closeCb).toHaveBeenCalledTimes(1); + }); + + it('should refresh connection timestamp when markRefreshedOnGet is true', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + maxSessionDuration: 200, // 200ms + markRefreshedOnGet: true, + }); + + const conn1 = await pool.get(); + pool.put(conn1); + + // Wait 100ms (less than expiration) + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Get again - should refresh timestamp + const conn2 = await pool.get(); + expect(conn2).toBe(conn1); // Should reuse + pool.put(conn2); + + // Wait another 100ms (total 200ms, but refreshed at 100ms) + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Should still be valid + const conn3 = await pool.get(); + expect(conn3).toBe(conn1); // Should still reuse + expect(connectCb).toHaveBeenCalledTimes(1); + }); + }); + + describe('withConnection', () => { + it('should return connection to pool on success', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + let capturedConn: string | undefined; + await pool.withConnection(async (conn) => { + capturedConn = conn; + return 'result'; + }); + + // Connection should be returned to pool + const conn2 = await pool.get(); + expect(conn2).toBe(capturedConn); // Should reuse + expect(connectCb).toHaveBeenCalledTimes(1); + }); + + it('should remove connection from pool on error', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + let capturedConn: string | undefined; + try { + await pool.withConnection(async (conn) => { + capturedConn = conn; + throw new Error('test error'); + }); + } catch (e) { + // Expected + } + + // Connection should be removed from pool + const conn2 = await pool.get(); + expect(conn2).not.toBe(capturedConn); // Should create new connection + expect(connectCb).toHaveBeenCalledTimes(2); + expect(closeCb).toHaveBeenCalledTimes(1); + }); + + it('should handle abort signal', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + const abortController = new AbortController(); + let capturedConn: string | undefined; + + const promise = pool.withConnection( + async (conn) => { + capturedConn = conn; + await new Promise((resolve) => setTimeout(resolve, 1000)); + return 'result'; + }, + { signal: abortController.signal }, + ); + + // Abort after a short delay + setTimeout(() => abortController.abort(), 10); + + await expect(promise).rejects.toThrow(); + + // Connection should be removed from pool + const conn2 = await pool.get(); + expect(conn2).not.toBe(capturedConn); // Should create new connection + expect(closeCb).toHaveBeenCalledTimes(1); + }); + }); + + describe('prewarm', () => { + it('should create connection in background', async () => { + let n = 0; + const connectCb = vi.fn(async (_timeout: number): Promise => { + await new Promise((resolve) => setTimeout(resolve, 50)); + return `conn_${++n}`; + }); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + pool.prewarm(); + + // Wait for prewarm to complete + await new Promise((resolve) => setTimeout(resolve, 100)); + + const conn = await pool.get(); + expect(conn).toBeDefined(); + expect(connectCb).toHaveBeenCalledTimes(1); + }); + + it('should not prewarm if connections already exist', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + // Create a connection first + const conn1 = await pool.get(); + pool.put(conn1); + + pool.prewarm(); // Should not create new connection + + const conn2 = await pool.get(); + expect(conn2).toBe(conn1); // Should reuse existing + expect(connectCb).toHaveBeenCalledTimes(1); + }); + }); + + describe('close', () => { + it('should close all connections', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + // Create two distinct connections by checking out both before returning either. + const conn1 = await pool.get(); + const conn2 = await pool.get(); + pool.put(conn1); + pool.put(conn2); + + await pool.close(); + + expect(closeCb).toHaveBeenCalledTimes(2); + }); + + it('should invalidate all connections', async () => { + const connectCb = makeConnectCb(); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + // Create two distinct connections by checking out both before returning either. + const conn1 = await pool.get(); + const conn2 = await pool.get(); + pool.put(conn1); + pool.put(conn2); + + pool.invalidate(); + await pool.close(); // Drain to close + + expect(closeCb).toHaveBeenCalledTimes(2); + }); + }); + + describe('concurrent access', () => { + it('should handle concurrent get requests', async () => { + const connectCb = vi.fn(async (_timeout: number): Promise => { + await new Promise((resolve) => setTimeout(resolve, 10)); + return `conn_${Date.now()}_${Math.random()}`; + }); + const closeCb = vi.fn(async (_conn: string) => { + // Mock close + }); + + const pool = new ConnectionPool({ + connectCb, + closeCb, + }); + + const promises = Array.from({ length: 5 }, () => pool.get()); + const connections = await Promise.all(promises); + + // All should be different connections + const uniqueConnections = new Set(connections); + expect(uniqueConnections.size).toBe(5); + expect(connectCb).toHaveBeenCalledTimes(5); + }); + }); +}); diff --git a/agents/src/connection_pool.ts b/agents/src/connection_pool.ts new file mode 100644 index 000000000..b2f13c798 --- /dev/null +++ b/agents/src/connection_pool.ts @@ -0,0 +1,307 @@ +// SPDX-FileCopyrightText: 2025 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { Mutex } from '@livekit/mutex'; +import { waitForAbort } from './utils.js'; + +/** + * Helper class to manage persistent connections like websockets. + */ +export interface ConnectionPoolOptions { + /** + * Maximum duration in milliseconds before forcing reconnection. + * If not set, connections will never expire based on duration. + */ + maxSessionDuration?: number; + + /** + * If true, the session will be marked as fresh when get() is called. + * Only used when maxSessionDuration is set. + */ + markRefreshedOnGet?: boolean; + + /** + * Async callback to create new connections. + * @param timeout - Connection timeout in milliseconds + * @returns A new connection object + */ + connectCb: (timeout: number) => Promise; + + /** + * Optional async callback to close connections. + * @param conn - The connection to close + */ + closeCb?: (conn: T) => Promise; + + /** + * Default connection timeout in milliseconds. + * Defaults to 10000 (10 seconds). + */ + connectTimeout?: number; +} + +/** + * Connection pool for managing persistent WebSocket connections. + * + * Reuses connections efficiently and automatically refreshes them after max duration. + * Prevents creating too many connections in a single conversation. + */ +export class ConnectionPool { + private readonly maxSessionDuration?: number; + private readonly markRefreshedOnGet: boolean; + private readonly connectCb: (timeout: number) => Promise; + private readonly closeCb?: (conn: T) => Promise; + private readonly connectTimeout: number; + + // Track connections and their creation timestamps + private readonly connections: Map = new Map(); + // Available connections ready for reuse + private readonly available: Set = new Set(); + // Connections queued for closing + private readonly toClose: Set = new Set(); + // Mutex for connection operations + private readonly connectLock = new Mutex(); + // Prewarm task reference + private prewarmController?: AbortController; + + constructor(options: ConnectionPoolOptions) { + this.maxSessionDuration = options.maxSessionDuration; + this.markRefreshedOnGet = options.markRefreshedOnGet ?? false; + this.connectCb = options.connectCb; + this.closeCb = options.closeCb; + this.connectTimeout = options.connectTimeout ?? 10_000; + } + + /** + * Create a new connection. + * + * @param timeout - Connection timeout in milliseconds + * @returns The new connection object + * @throws If connectCb is not provided or connection fails + */ + private async _connect(timeout: number): Promise { + const connection = await this.connectCb(timeout); + this.connections.set(connection, Date.now()); + return connection; + } + + /** + * Drain and close all connections queued for closing. + */ + private async _drainToClose(): Promise { + const connectionsToClose = Array.from(this.toClose); + this.toClose.clear(); + + for (const conn of connectionsToClose) { + await this._maybeCloseConnection(conn); + } + } + + /** + * Close a connection if closeCb is provided. + * + * @param conn - The connection to close + */ + private async _maybeCloseConnection(conn: T): Promise { + if (this.closeCb) { + await this.closeCb(conn); + } + } + + private _abortError(): Error { + const error = new Error('The operation was aborted.'); + error.name = 'AbortError'; + return error; + } + + /** + * Get an available connection or create a new one if needed. + * + * @param timeout - Connection timeout in milliseconds + * @returns An active connection object + */ + async get(timeout?: number): Promise { + const unlock = await this.connectLock.lock(); + try { + await this._drainToClose(); + const now = Date.now(); + + // Try to reuse an available connection that hasn't expired + while (this.available.size > 0) { + const conn = this.available.values().next().value as T; + this.available.delete(conn); + + if ( + this.maxSessionDuration === undefined || + now - (this.connections.get(conn) ?? 0) <= this.maxSessionDuration + ) { + if (this.markRefreshedOnGet) { + this.connections.set(conn, now); + } + return conn; + } + + // Connection expired; close it now so callers observing get() see it closed promptly. + // (Also makes tests deterministic: closeCb should have been called by the time get() resolves.) + if (this.connections.has(conn)) { + this.connections.delete(conn); + } + this.toClose.delete(conn); + await this._maybeCloseConnection(conn); + } + + return await this._connect(timeout ?? this.connectTimeout); + } finally { + unlock(); + } + } + + /** + * Mark a connection as available for reuse. + * + * If connection has been removed, it will not be added to the pool. + * + * @param conn - The connection to make available + */ + put(conn: T): void { + if (this.connections.has(conn)) { + this.available.add(conn); + return; + } + } + + /** + * Remove a specific connection from the pool. + * + * Marks the connection to be closed during the next drain cycle. + * + * @param conn - The connection to remove + */ + remove(conn: T): void { + this.available.delete(conn); + if (this.connections.has(conn)) { + this.toClose.add(conn); + this.connections.delete(conn); + // Important for Node websockets: if we just "mark to close later" but remove listeners, + // the ws library can buffer incoming frames in memory. Close ASAP in background. + void (async () => { + const unlock = await this.connectLock.lock(); + try { + if (!this.toClose.has(conn)) return; + await this._maybeCloseConnection(conn); + this.toClose.delete(conn); + } finally { + unlock(); + } + })(); + } + } + + /** + * Clear all existing connections. + * + * Marks all current connections to be closed during the next drain cycle. + */ + invalidate(): void { + for (const conn of this.connections.keys()) { + this.toClose.add(conn); + } + this.connections.clear(); + this.available.clear(); + } + + /** + * Initiate prewarming of the connection pool without blocking. + * + * This method starts a background task that creates a new connection if none exist. + * The task automatically cleans itself up when the connection pool is closed. + */ + prewarm(): void { + if (this.prewarmController || this.connections.size > 0) { + return; + } + + const controller = new AbortController(); + this.prewarmController = controller; + + // Start prewarm in background + this._prewarmImpl(controller.signal).catch(() => { + // Ignore errors during prewarm + }); + } + + private async _prewarmImpl(signal: AbortSignal): Promise { + const unlock = await this.connectLock.lock(); + try { + if (signal.aborted) { + return; + } + + if (this.connections.size === 0) { + const conn = await this._connect(this.connectTimeout); + this.available.add(conn); + } + } finally { + unlock(); + } + } + + /** + * Get a connection from the pool and automatically return it when done. + * Handles abort signals and ensures proper cleanup. + * + * @param fn - Function to execute with the connection + * @param options - Options including timeout and abort signal + * @returns The result of the function + */ + async withConnection( + fn: (conn: T) => Promise, + options?: { + timeout?: number; + signal?: AbortSignal; + }, + ): Promise { + // Check if already aborted before getting connection + if (options?.signal?.aborted) { + throw this._abortError(); + } + + const conn = await this.get(options?.timeout); + + const signal = options?.signal; + + try { + const fnPromise = fn(conn); + const result = signal + ? await Promise.race([ + fnPromise.then((value) => ({ type: 'result' as const, value })), + waitForAbort(signal).then(() => ({ type: 'abort' as const })), + ]).then((r) => { + if (r.type === 'abort') throw this._abortError(); + return r.value; + }) + : await fnPromise; + // Return connection to pool on success + this.put(conn); + return result; + } catch (error) { + // Remove connection from pool on error (don't return it) + this.remove(conn); + throw error; + } + } + + /** + * Close all connections, draining any pending connection closures. + */ + async close(): Promise { + // Cancel prewarm task if running + if (this.prewarmController) { + this.prewarmController.abort(); + this.prewarmController = undefined; + } + + this.invalidate(); + await this._drainToClose(); + } +} diff --git a/agents/src/index.ts b/agents/src/index.ts index a92d4bf3c..57ace0c7a 100644 --- a/agents/src/index.ts +++ b/agents/src/index.ts @@ -23,6 +23,7 @@ import * as voice from './voice/index.js'; export * from './_exceptions.js'; export * from './audio.js'; +export * from './connection_pool.js'; export * from './generator.js'; export * from './inference_runner.js'; export * from './job.js'; diff --git a/agents/src/inference/tts.ts b/agents/src/inference/tts.ts index 3c6541764..327eae1ed 100644 --- a/agents/src/inference/tts.ts +++ b/agents/src/inference/tts.ts @@ -5,13 +5,14 @@ import type { AudioFrame } from '@livekit/rtc-node'; import { WebSocket } from 'ws'; import { APIError, APIStatusError } from '../_exceptions.js'; import { AudioByteStream } from '../audio.js'; +import { ConnectionPool } from '../connection_pool.js'; import { log } from '../log.js'; import { createStreamChannel } from '../stream/stream_channel.js'; import { basic as tokenizeBasic } from '../tokenize/index.js'; import type { ChunkedStream } from '../tts/index.js'; import { SynthesizeStream as BaseSynthesizeStream, TTS as BaseTTS } from '../tts/index.js'; import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; -import { shortuuid } from '../utils.js'; +import { Event, Future, Task, cancelAndWait, shortuuid } from '../utils.js'; import { type TtsClientEvent, type TtsServerEvent, @@ -95,6 +96,7 @@ export interface InferenceTTSOptions { export class TTS extends BaseTTS { private opts: InferenceTTSOptions; private streams: Set> = new Set(); + pool: ConnectionPool; #logger = log(); @@ -165,6 +167,15 @@ export class TTS extends BaseTTS { apiSecret: lkApiSecret, modelOptions, }; + + // Initialize connection pool + this.pool = new ConnectionPool({ + connectCb: (timeout) => this.connectWs(timeout), + closeCb: (ws) => this.closeWs(ws), + maxSessionDuration: 300_000, + markRefreshedOnGet: true, + connectTimeout: 10_000, // 10 seconds default + }); } get label() { @@ -218,6 +229,7 @@ export class TTS extends BaseTTS { if (this.opts.model) params.model = this.opts.model; if (this.opts.language) params.language = this.opts.language; + this.#logger.debug({ url }, 'inference.TTS creating new websocket connection (pool miss)'); const socket = await connectWs(url, headers, timeout); socket.send(JSON.stringify(params)); return socket; @@ -227,11 +239,16 @@ export class TTS extends BaseTTS { await ws.close(); } + prewarm(): void { + this.pool.prewarm(); + } + async close() { for (const stream of this.streams) { await stream.close(); } this.streams.clear(); + await this.pool.close(); } } @@ -256,30 +273,31 @@ export class SynthesizeStream extends BaseSynthesizeSt } protected async run(): Promise { - let ws: WebSocket | null = null; let closing = false; - let finalReceived = false; let lastFrame: AudioFrame | undefined; const sendTokenizerStream = new tokenizeBasic.SentenceTokenizer().stream(); const eventChannel = createStreamChannel(); const requestId = shortuuid('tts_request_'); + const inputSentEvent = new Event(); + + // Signal for protocol-driven completion (when 'done' message is received) + const completionFuture = new Future(); - const resourceCleanup = () => { + const resourceCleanup = async () => { if (closing) return; closing = true; sendTokenizerStream.close(); - eventChannel.close(); - ws?.removeAllListeners(); - ws?.close(); + // close() returns a promise; don't leak it + await eventChannel.close(); }; - const sendClientEvent = async (event: TtsClientEvent) => { + const sendClientEvent = async (event: TtsClientEvent, ws: WebSocket, signal: AbortSignal) => { // Don't send events to a closed WebSocket or aborted controller - if (this.abortController.signal.aborted || closing) return; + if (signal.aborted || closing) return; const validatedEvent = await ttsClientEventSchema.parseAsync(event); - if (!ws || ws.readyState !== WebSocket.OPEN) { + if (ws.readyState !== WebSocket.OPEN) { this.#logger.warn('Trying to send client TTS event to a closed WebSocket'); return; } @@ -293,9 +311,9 @@ export class SynthesizeStream extends BaseSynthesizeSt } }; - const createInputTask = async () => { + const createInputTask = async (signal: AbortSignal) => { for await (const data of this.input) { - if (this.abortController.signal.aborted || closing) break; + if (signal.aborted || closing) break; if (data === SynthesizeStream.FLUSH_SENTINEL) { sendTokenizerStream.flush(); continue; @@ -308,55 +326,108 @@ export class SynthesizeStream extends BaseSynthesizeSt } }; - const createSentenceStreamTask = async () => { + const createSentenceStreamTask = async (ws: WebSocket, signal: AbortSignal) => { for await (const ev of sendTokenizerStream) { - if (this.abortController.signal.aborted) break; - - sendClientEvent({ - type: 'input_transcript', - transcript: ev.token + ' ', - }); + if (signal.aborted || closing) break; + + await sendClientEvent( + { + type: 'input_transcript', + transcript: ev.token + ' ', + }, + ws, + signal, + ); + inputSentEvent.set(); } - sendClientEvent({ type: 'session.flush' }); + await sendClientEvent({ type: 'session.flush' }, ws, signal); + // needed in case empty input is sent + inputSentEvent.set(); }; - const createWsListenerTask = async (ws: WebSocket) => { - return new Promise((resolve, reject) => { - this.abortController.signal.addEventListener('abort', () => { - resourceCleanup(); - resolve(); // Abort is triggered by close(), which is a normal shutdown, not an error - }); - - ws.on('message', async (data) => { + // Handles WebSocket message routing and error handling + // Completes based on protocol messages, NOT on ws.close() + const createWsListenerTask = async (ws: WebSocket, signal: AbortSignal) => { + const onMessage = (data: Buffer) => { + try { const eventJson = JSON.parse(data.toString()) as Record; const validatedEvent = ttsServerEventSchema.parse(eventJson); - eventChannel.write(validatedEvent); - }); - - ws.on('error', (e) => { - this.#logger.error({ error: e }, 'WebSocket error'); - resourceCleanup(); - reject(e); - }); - - ws.on('close', () => { - resourceCleanup(); - - if (!closing) return this.#logger.error('WebSocket closed unexpectedly'); - if (finalReceived) return resolve(); + // writer.write returns a promise; avoid unhandled rejections if stream is closed + void eventChannel.write(validatedEvent).catch((error) => { + this.#logger.debug( + { error }, + 'Failed writing TTS event to stream channel (likely closed)', + ); + }); + } catch (e) { + this.#logger.error({ error: e }, 'Error parsing WebSocket message'); + } + }; - reject( + const onError = (e: Error) => { + this.#logger.error({ error: e }, 'WebSocket error'); + void resourceCleanup(); + try { + // If the ws is misbehaving, hard-stop it immediately to avoid buffering. + ws.terminate?.(); + } catch { + // ignore + } + // Ensure this ws is not reused + this.tts.pool.remove(ws); + completionFuture.reject(e); + }; + + const onClose = () => { + // WebSocket closed unexpectedly (not by us) + if (!closing) { + this.#logger.error('WebSocket closed unexpectedly'); + void resourceCleanup(); + // Ensure this ws is not reused + this.tts.pool.remove(ws); + completionFuture.reject( new APIStatusError({ message: 'Gateway connection closed unexpectedly', options: { requestId }, }), ); - }); - }); + } + }; + + const onAbort = () => { + void resourceCleanup(); + try { + // On interruption/abort, close the websocket immediately so the server stops streaming + // and the ws library doesn't buffer unread frames in memory. + ws.terminate?.(); + } catch { + // ignore + } + this.tts.pool.remove(ws); + inputSentEvent.set(); + completionFuture.resolve(); + }; + + // Attach listeners + ws.on('message', onMessage); + ws.on('error', onError); + ws.on('close', onClose); + signal.addEventListener('abort', onAbort); + + try { + // Wait for protocol-driven completion or error + await completionFuture.await; + } finally { + // IMPORTANT: Remove listeners so connection can be reused + ws.off('message', onMessage); + ws.off('error', onError); + ws.off('close', onClose); + signal.removeEventListener('abort', onAbort); + } }; - const createRecvTask = async () => { + const createRecvTask = async (signal: AbortSignal) => { let currentSessionId: string | null = null; const bstream = new AudioByteStream(this.opts.sampleRate, NUM_CHANNELS); @@ -364,9 +435,11 @@ export class SynthesizeStream extends BaseSynthesizeSt const reader = serverEventStream.getReader(); try { - while (!this.closed && !this.abortController.signal.aborted) { + await inputSentEvent.wait(); + + while (!this.closed && !signal.aborted) { const result = await reader.read(); - if (this.abortController.signal.aborted) return; + if (signal.aborted) return; if (result.done) return; const serverEvent = result.value; @@ -382,24 +455,29 @@ export class SynthesizeStream extends BaseSynthesizeSt } break; case 'done': - finalReceived = true; for (const frame of bstream.flush()) { sendLastFrame(currentSessionId!, false); lastFrame = frame; } sendLastFrame(currentSessionId!, true); this.queue.put(SynthesizeStream.END_OF_STREAM); - break; + await resourceCleanup(); + completionFuture.resolve(); + return; case 'session.closed': - resourceCleanup(); - break; + await resourceCleanup(); + completionFuture.resolve(); + return; case 'error': this.#logger.error( { serverEvent }, 'Received error message from LiveKit TTS WebSocket', ); - resourceCleanup(); - throw new APIError(`LiveKit TTS returned error: ${serverEvent.message}`); + await resourceCleanup(); + completionFuture.reject( + new APIError(`LiveKit TTS returned error: ${serverEvent.message}`), + ); + return; default: this.#logger.warn('Unexpected message %s', serverEvent); break; @@ -416,16 +494,100 @@ export class SynthesizeStream extends BaseSynthesizeSt }; try { - ws = await this.tts.connectWs(this.connOptions.timeoutMs); - - await Promise.all([ - createInputTask(), - createSentenceStreamTask(), - createWsListenerTask(ws), - createRecvTask(), - ]); + await this.tts.pool.withConnection( + async (ws: WebSocket) => { + try { + // IMPORTANT: don't cancel the stream's controller on normal completion, + // otherwise the pool will remove+close the ws and every run becomes a pool miss. + const runController = new AbortController(); + const onStreamAbort = () => runController.abort(this.abortController.signal.reason); + this.abortController.signal.addEventListener('abort', onStreamAbort, { once: true }); + + const combineSignals = (a: AbortSignal, b: AbortSignal): AbortSignal => { + const c = new AbortController(); + const abortFrom = (s: AbortSignal) => { + if (c.signal.aborted) return; + c.abort(s.reason); + }; + if (a.aborted) { + abortFrom(a); + } else { + a.addEventListener('abort', () => abortFrom(a), { once: true }); + } + if (b.aborted) { + abortFrom(b); + } else { + b.addEventListener('abort', () => abortFrom(b), { once: true }); + } + return c.signal; + }; + + const tasks = [ + Task.from( + async (controller) => { + const combined = combineSignals(runController.signal, controller.signal); + await createInputTask(combined); + }, + undefined, + 'inference-tts-input', + ), + Task.from( + async (controller) => { + const combined = combineSignals(runController.signal, controller.signal); + await createSentenceStreamTask(ws, combined); + }, + undefined, + 'inference-tts-sentence', + ), + Task.from( + async (controller) => { + const combined = combineSignals(runController.signal, controller.signal); + await createWsListenerTask(ws, combined); + }, + undefined, + 'inference-tts-ws-listener', + ), + Task.from( + async (controller) => { + const combined = combineSignals(runController.signal, controller.signal); + await createRecvTask(combined); + }, + undefined, + 'inference-tts-recv', + ), + ]; + + try { + await Promise.all(tasks.map((t) => t.result)); + } finally { + // Mirror python finally: unblock recv and cancel all tasks. + inputSentEvent.set(); + await resourceCleanup(); + await cancelAndWait(tasks, 5000); + this.abortController.signal.removeEventListener('abort', onStreamAbort); + } + } catch (e) { + // If aborted, don't throw - let cleanup handle it + if (e instanceof Error && e.name === 'AbortError') { + return; + } + throw e; + } + }, + { + timeout: this.connOptions.timeoutMs, + }, + ); + } catch (e) { + // Handle connection errors + if (e instanceof Error && e.name === 'AbortError') { + // Abort is expected during normal shutdown + return; + } + throw e; } finally { - resourceCleanup(); + // Ensure cleanup always runs (and don't leak the promise) + await resourceCleanup(); } } } diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index a1e3bf1d6..3ac2cc649 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -1449,6 +1449,13 @@ export class AgentActivity implements RecognitionHooks { { speech_id: speechHandle.id }, 'Aborting all pipeline reply tasks due to interruption', ); + + // Stop playout ASAP (don't wait for cancellations), otherwise the segment may finish and we + // will correctly (but undesirably) commit a long transcript even though the user said "stop". + if (audioOutput) { + audioOutput.clearBuffer(); + } + replyAbortController.abort(); await Promise.allSettled( tasks.map((task) => task.cancelAndWait(AgentActivity.REPLY_TASK_CANCEL_TIMEOUT)), @@ -1457,7 +1464,6 @@ export class AgentActivity implements RecognitionHooks { let forwardedText = textOut?.text || ''; if (audioOutput) { - audioOutput.clearBuffer(); const playbackEv = await audioOutput.waitForPlayout(); if (audioOut?.firstFrameFut.done) { // playback EV is valid only if the first frame was already played diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index 21850bb70..ad349a122 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -527,7 +527,10 @@ export class AgentSession< newAgentId: agent.id, }), ); - this.logger.debug({ previousActivity, agent }, 'Agent handoff inserted into chat context'); + this.logger.debug( + { previousAgentId: previousActivity?.agent.id, newAgentId: agent.id }, + 'Agent handoff inserted into chat context', + ); await this.activity.start(); diff --git a/agents/src/voice/transcription/synchronizer.ts b/agents/src/voice/transcription/synchronizer.ts index 9eb2ff469..c86426a13 100644 --- a/agents/src/voice/transcription/synchronizer.ts +++ b/agents/src/voice/transcription/synchronizer.ts @@ -151,7 +151,14 @@ class SegmentSynchronizerImpl { return; } - if (!interrupted) { + const playbackPosition = _playbackPosition; + const epsilonSeconds = 0.05; + const nearEnd = playbackPosition >= Math.max(0, this.audioData.pushedDuration - epsilonSeconds); + + // Only mark as fully completed if playback reached (roughly) the end. + // This prevents returning the full transcript in cases where the sink reports interrupted=false + // but playbackPosition indicates partial playout. + if (!interrupted && nearEnd) { this.playbackCompleted = true; } }