Skip to content

Commit b0f5cce

Browse files
authored
fix resource cleanup (#849)
1 parent 01f9ad3 commit b0f5cce

File tree

6 files changed

+126
-58
lines changed

6 files changed

+126
-58
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@livekit/agents": patch
3+
"@livekit/agents-plugin-deepgram": patch
4+
---
5+
6+
fix resource cleanup

agents/src/stt/stt.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter<STTCal
135135
* transcriptions
136136
*/
137137
abstract stream(): SpeechStream;
138+
139+
/**
 * Releases any resources held by this STT instance.
 *
 * The base implementation does nothing and resolves immediately. Subclasses
 * that own long-lived resources may override it.
 */
async close(): Promise<void> {
  // Nothing to release by default.
}
138142
}
139143

140144
/**

agents/src/tts/tts.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ export abstract class TTS extends (EventEmitter as new () => TypedEmitter<TTSCal
9494
* Returns a {@link SynthesizeStream} that can be used to push text and receive audio data
9595
*/
9696
abstract stream(): SynthesizeStream;
97+
98+
/**
 * Releases any resources held by this TTS instance.
 *
 * The base implementation does nothing and resolves immediately. Subclasses
 * that own long-lived resources may override it.
 */
async close(): Promise<void> {
  // Nothing to release by default.
}
97101
}
98102

99103
/**

agents/src/vad.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter<VADCal
8080
* Returns a {@link VADStream} that can be used to push audio frames and receive VAD events.
8181
*/
8282
abstract stream(): VADStream;
83+
84+
/**
 * Releases any resources held by this VAD instance.
 *
 * The base implementation does nothing and resolves immediately. Subclasses
 * that own long-lived resources may override it.
 */
async close(): Promise<void> {
  // Nothing to release by default.
}
8387
}
8488

8589
export abstract class VADStream implements AsyncIterableIterator<VADEvent> {

agents/src/voice/agent_activity.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2141,12 +2141,15 @@ export class AgentActivity implements RecognitionHooks {
21412141
}
21422142
if (this.stt instanceof STT) {
21432143
this.stt.off('metrics_collected', this.onMetricsCollected);
2144+
await this.stt.close();
21442145
}
21452146
if (this.tts instanceof TTS) {
21462147
this.tts.off('metrics_collected', this.onMetricsCollected);
2148+
await this.tts.close();
21472149
}
21482150
if (this.vad instanceof VAD) {
21492151
this.vad.off('metrics_collected', this.onMetricsCollected);
2152+
await this.vad.close();
21502153
}
21512154

21522155
this.detachAudioInput();

plugins/deepgram/src/stt.ts

Lines changed: 105 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ import {
66
AudioByteStream,
77
AudioEnergyFilter,
88
Future,
9+
Task,
910
log,
1011
stt,
12+
waitForAbort,
1113
} from '@livekit/agents';
1214
import type { AudioFrame } from '@livekit/rtc-node';
13-
import { type RawData, WebSocket } from 'ws';
15+
import { WebSocket } from 'ws';
1416
import { PeriodicCollector } from './_utils.js';
1517
import type { STTLanguages, STTModels } from './models.js';
1618

@@ -62,6 +64,7 @@ export class STT extends stt.STT {
6264
#opts: STTOptions;
6365
#logger = log();
6466
label = 'deepgram.STT';
67+
private abortController = new AbortController();
6568

6669
constructor(opts: Partial<STTOptions> = defaultSTTOptions) {
6770
super({
@@ -111,7 +114,11 @@ export class STT extends stt.STT {
111114
}
112115

113116
stream(): SpeechStream {
114-
return new SpeechStream(this, this.#opts);
117+
return new SpeechStream(this, this.#opts, this.abortController);
118+
}
119+
120+
/**
 * Aborts the controller that this STT hands to every SpeechStream it creates
 * in stream().
 *
 * NOTE(review): the controller is created once as a field initializer and is
 * not re-created here, so a stream obtained from stream() after close() would
 * presumably start with an already-aborted signal — confirm whether an STT
 * instance is ever reused after close().
 */
async close() {
  const { abortController } = this;
  abortController.abort();
}
116123
}
117124

@@ -125,7 +132,11 @@ export class SpeechStream extends stt.SpeechStream {
125132
#audioDurationCollector: PeriodicCollector<number>;
126133
label = 'deepgram.SpeechStream';
127134

128-
constructor(stt: STT, opts: STTOptions) {
135+
constructor(
136+
stt: STT,
137+
opts: STTOptions,
138+
private abortController: AbortController,
139+
) {
129140
super(stt, opts.sampleRate);
130141
this.#opts = opts;
131142
this.closed = false;
@@ -140,7 +151,8 @@ export class SpeechStream extends stt.SpeechStream {
140151
const maxRetry = 32;
141152
let retries = 0;
142153
let ws: WebSocket;
143-
while (!this.input.closed) {
154+
155+
while (!this.input.closed && !this.closed) {
144156
const streamURL = new URL(API_BASE_URL_V1);
145157
const params = {
146158
model: this.#opts.model,
@@ -185,17 +197,23 @@ export class SpeechStream extends stt.SpeechStream {
185197

186198
await this.#runWS(ws);
187199
} catch (e) {
188-
if (retries >= maxRetry) {
189-
throw new Error(`failed to connect to Deepgram after ${retries} attempts: ${e}`);
190-
}
200+
if (!this.closed && !this.input.closed) {
201+
if (retries >= maxRetry) {
202+
throw new Error(`failed to connect to Deepgram after ${retries} attempts: ${e}`);
203+
}
191204

192-
const delay = Math.min(retries * 5, 10);
193-
retries++;
205+
const delay = Math.min(retries * 5, 10);
206+
retries++;
194207

195-
this.#logger.warn(
196-
`failed to connect to Deepgram, retrying in ${delay} seconds: ${e} (${retries}/${maxRetry})`,
197-
);
198-
await new Promise((resolve) => setTimeout(resolve, delay * 1000));
208+
this.#logger.warn(
209+
`failed to connect to Deepgram, retrying in ${delay} seconds: ${e} (${retries}/${maxRetry})`,
210+
);
211+
await new Promise((resolve) => setTimeout(resolve, delay * 1000));
212+
} else {
213+
this.#logger.warn(
214+
`Deepgram disconnected, connection is closed: ${e} (inputClosed: ${this.input.closed}, isClosed: ${this.closed})`,
215+
);
216+
}
199217
}
200218
}
201219

@@ -220,6 +238,20 @@ export class SpeechStream extends stt.SpeechStream {
220238
}
221239
}, 5000);
222240

241+
// Watches the socket for an unexpected close and fails the task when one happens.
// The task is also cancelled when sendTask completes, so an orderly shutdown
// (closing === true) never surfaces as an error.
const wsMonitor = Task.from(async (controller) => {
  let rejectClosed: (reason: Error) => void = () => {};
  // Plain (non-async) executor: an async executor would swallow anything thrown
  // inside it instead of rejecting the promise.
  const closed = new Promise<void>((_, reject) => {
    rejectClosed = reject;
  });

  const onClose = (code: number, reason: Buffer) => {
    if (!closing) {
      this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
      rejectClosed(new Error('WebSocket closed'));
    }
  };
  ws.once('close', onClose);

  try {
    await Promise.race([closed, waitForAbort(controller.signal)]);
  } finally {
    // Detach the listener once the task settles so a cancelled monitor does not
    // keep a handler (and its closure) attached to the socket.
    ws.off('close', onClose);
  }
});
254+
223255
const sendTask = async () => {
224256
const samples100Ms = Math.floor(this.#opts.sampleRate / 10);
225257
const stream = new AudioByteStream(
@@ -228,48 +260,52 @@ export class SpeechStream extends stt.SpeechStream {
228260
samples100Ms,
229261
);
230262

231-
for await (const data of this.input) {
232-
let frames: AudioFrame[];
233-
if (data === SpeechStream.FLUSH_SENTINEL) {
234-
frames = stream.flush();
235-
this.#audioDurationCollector.flush();
236-
} else if (
237-
data.sampleRate === this.#opts.sampleRate ||
238-
data.channels === this.#opts.numChannels
239-
) {
240-
frames = stream.write(data.data.buffer);
241-
} else {
242-
throw new Error(`sample rate or channel count of frame does not match`);
243-
}
263+
try {
264+
while (!this.closed) {
265+
const result = await Promise.race([
266+
this.input.next(),
267+
waitForAbort(this.abortController.signal),
268+
]);
269+
270+
if (result === undefined) return; // aborted
271+
if (result.done) {
272+
break;
273+
}
274+
275+
const data = result.value;
276+
277+
let frames: AudioFrame[];
278+
if (data === SpeechStream.FLUSH_SENTINEL) {
279+
frames = stream.flush();
280+
this.#audioDurationCollector.flush();
281+
} else if (
282+
data.sampleRate === this.#opts.sampleRate ||
283+
data.channels === this.#opts.numChannels
284+
) {
285+
frames = stream.write(data.data.buffer as ArrayBuffer);
286+
} else {
287+
throw new Error(`sample rate or channel count of frame does not match`);
288+
}
244289

245-
for await (const frame of frames) {
246-
if (this.#audioEnergyFilter.pushFrame(frame)) {
247-
const frameDuration = frame.samplesPerChannel / frame.sampleRate;
248-
this.#audioDurationCollector.push(frameDuration);
249-
ws.send(frame.data.buffer);
290+
for await (const frame of frames) {
291+
if (this.#audioEnergyFilter.pushFrame(frame)) {
292+
const frameDuration = frame.samplesPerChannel / frame.sampleRate;
293+
this.#audioDurationCollector.push(frameDuration);
294+
ws.send(frame.data.buffer);
295+
}
250296
}
251297
}
298+
} finally {
299+
closing = true;
300+
ws.send(JSON.stringify({ type: 'CloseStream' }));
301+
wsMonitor.cancel();
252302
}
253-
254-
closing = true;
255-
ws.send(JSON.stringify({ type: 'CloseStream' }));
256303
};
257304

258-
const wsMonitor = new Promise<void>((_, reject) =>
259-
ws.once('close', (code, reason) => {
260-
if (!closing) {
261-
this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
262-
reject(new Error('WebSocket closed'));
263-
}
264-
}),
265-
);
266-
267-
const listenTask = async () => {
268-
while (!this.closed && !closing) {
269-
try {
270-
await new Promise<RawData>((resolve) => {
271-
ws.once('message', (data) => resolve(data));
272-
}).then((msg) => {
305+
const listenTask = Task.from(async (controller) => {
306+
const listenMessage = new Promise<void>((resolve, reject) => {
307+
ws.on('message', (msg) => {
308+
try {
273309
const json = JSON.parse(msg.toString());
274310
switch (json['type']) {
275311
case 'SpeechStarted': {
@@ -300,7 +336,9 @@ export class SpeechStream extends stt.SpeechStream {
300336
if (alternatives[0] && alternatives[0].text) {
301337
if (!this.#speaking) {
302338
this.#speaking = true;
303-
this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH });
339+
this.queue.put({
340+
type: stt.SpeechEventType.START_OF_SPEECH,
341+
});
304342
}
305343

306344
if (isFinal) {
@@ -334,15 +372,24 @@ export class SpeechStream extends stt.SpeechStream {
334372
break;
335373
}
336374
}
337-
});
338-
} catch (error) {
339-
this.#logger.child({ error }).warn('unrecoverable error, exiting');
340-
break;
341-
}
342-
}
343-
};
344375

345-
await Promise.race([this.#resetWS.await, Promise.all([sendTask(), listenTask(), wsMonitor])]);
376+
if (this.closed || closing) {
377+
resolve();
378+
}
379+
} catch (err) {
380+
this.#logger.error(`STT: Error processing message: ${msg}`);
381+
reject(err);
382+
}
383+
});
384+
});
385+
386+
await Promise.race([listenMessage, waitForAbort(controller.signal)]);
387+
}, this.abortController);
388+
389+
await Promise.race([
390+
this.#resetWS.await,
391+
Promise.all([sendTask(), listenTask.result, wsMonitor]),
392+
]);
346393
closing = true;
347394
ws.close();
348395
clearInterval(keepalive);

0 commit comments

Comments
 (0)