@@ -6,11 +6,13 @@ import {
66 AudioByteStream ,
77 AudioEnergyFilter ,
88 Future ,
9+ Task ,
910 log ,
1011 stt ,
12+ waitForAbort ,
1113} from '@livekit/agents' ;
1214import type { AudioFrame } from '@livekit/rtc-node' ;
13- import { type RawData , WebSocket } from 'ws' ;
15+ import { WebSocket } from 'ws' ;
1416import { PeriodicCollector } from './_utils.js' ;
1517import type { STTLanguages , STTModels } from './models.js' ;
1618
@@ -62,6 +64,7 @@ export class STT extends stt.STT {
6264 #opts: STTOptions ;
6365 #logger = log ( ) ;
6466 label = 'deepgram.STT' ;
67+ private abortController = new AbortController ( ) ;
6568
6669 constructor ( opts : Partial < STTOptions > = defaultSTTOptions ) {
6770 super ( {
@@ -111,7 +114,11 @@ export class STT extends stt.STT {
111114 }
112115
// Creates a new SpeechStream bound to this STT instance and its options.
// In this diff the removed (-) line built the stream from `this` and `#opts` only;
// the added (+) line additionally passes this instance's `abortController`, so
// every stream created here receives the same controller object.
113116 stream ( ) : SpeechStream {
114- return new SpeechStream ( this , this . #opts) ;
117+ return new SpeechStream ( this , this . #opts, this . abortController ) ;
118+ }
119+
// Aborts the instance-level `abortController`, i.e. the same controller that
// `stream()` hands to each SpeechStream it creates.
// NOTE(review): the controller is created once as a field initializer and is not
// replaced here, so a stream created after `close()` would presumably receive an
// already-aborted signal — confirm that is the intended lifecycle.
120+ async close ( ) {
121+ this . abortController . abort ( ) ;
115122 }
116123}
117124
@@ -125,7 +132,11 @@ export class SpeechStream extends stt.SpeechStream {
125132 #audioDurationCollector: PeriodicCollector < number > ;
126133 label = 'deepgram.SpeechStream' ;
127134
128- constructor ( stt : STT , opts : STTOptions ) {
135+ constructor (
136+ stt : STT ,
137+ opts : STTOptions ,
138+ private abortController : AbortController ,
139+ ) {
129140 super ( stt , opts . sampleRate ) ;
130141 this . #opts = opts ;
131142 this . closed = false ;
@@ -140,7 +151,8 @@ export class SpeechStream extends stt.SpeechStream {
140151 const maxRetry = 32 ;
141152 let retries = 0 ;
142153 let ws : WebSocket ;
143- while ( ! this . input . closed ) {
154+
155+ while ( ! this . input . closed && ! this . closed ) {
144156 const streamURL = new URL ( API_BASE_URL_V1 ) ;
145157 const params = {
146158 model : this . #opts. model ,
@@ -185,17 +197,23 @@ export class SpeechStream extends stt.SpeechStream {
185197
186198 await this . #runWS( ws ) ;
187199 } catch ( e ) {
188- if ( retries >= maxRetry ) {
189- throw new Error ( `failed to connect to Deepgram after ${ retries } attempts: ${ e } ` ) ;
190- }
200+ if ( ! this . closed && ! this . input . closed ) {
201+ if ( retries >= maxRetry ) {
202+ throw new Error ( `failed to connect to Deepgram after ${ retries } attempts: ${ e } ` ) ;
203+ }
191204
192- const delay = Math . min ( retries * 5 , 10 ) ;
193- retries ++ ;
205+ const delay = Math . min ( retries * 5 , 10 ) ;
206+ retries ++ ;
194207
195- this . #logger. warn (
196- `failed to connect to Deepgram, retrying in ${ delay } seconds: ${ e } (${ retries } /${ maxRetry } )` ,
197- ) ;
198- await new Promise ( ( resolve ) => setTimeout ( resolve , delay * 1000 ) ) ;
208+ this . #logger. warn (
209+ `failed to connect to Deepgram, retrying in ${ delay } seconds: ${ e } (${ retries } /${ maxRetry } )` ,
210+ ) ;
211+ await new Promise ( ( resolve ) => setTimeout ( resolve , delay * 1000 ) ) ;
212+ } else {
213+ this . #logger. warn (
214+ `Deepgram disconnected, connection is closed: ${ e } (inputClosed: ${ this . input . closed } , isClosed: ${ this . closed } )` ,
215+ ) ;
216+ }
199217 }
200218 }
201219
@@ -220,6 +238,20 @@ export class SpeechStream extends stt.SpeechStream {
220238 }
221239 } , 5000 ) ;
222240
// Watches the socket for a close that this code did not initiate while the
// send/listen work is running. The task is also cancelled from sendTask's
// `finally` block once sending is complete.
// NOTE(review): the `Promise.all` at the end of this method is given `wsMonitor`
// itself, whereas `listenTask` is passed as `listenTask.result`. If `Task` is not
// thenable, the rejection raised here is never observed by that `Promise.all` —
// confirm, and pass `wsMonitor.result` there if so.
const wsMonitor = Task.from(async (controller) => {
  // The executor is intentionally synchronous: with an `async` executor, anything
  // thrown inside it rejects a promise nobody holds instead of rejecting `closed`.
  // `closed` never resolves; it only rejects on an unexpected close.
  const closed = new Promise<void>((_, reject) => {
    ws.once('close', (code, reason) => {
      // `closing` is set to true before this code sends CloseStream or calls
      // ws.close(), so only closes we did not start are reported as errors.
      if (!closing) {
        this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
        reject(new Error('WebSocket closed'));
      }
    });
  });

  // Finish on whichever comes first: an unexpected close (rejects) or the task's
  // abort signal (presumably what `waitForAbort` settles on — confirm).
  await Promise.race([closed, waitForAbort(controller.signal)]);
});
254+
223255 const sendTask = async ( ) => {
224256 const samples100Ms = Math . floor ( this . #opts. sampleRate / 10 ) ;
225257 const stream = new AudioByteStream (
@@ -228,48 +260,52 @@ export class SpeechStream extends stt.SpeechStream {
228260 samples100Ms ,
229261 ) ;
230262
231- for await ( const data of this . input ) {
232- let frames : AudioFrame [ ] ;
233- if ( data === SpeechStream . FLUSH_SENTINEL ) {
234- frames = stream . flush ( ) ;
235- this . #audioDurationCollector. flush ( ) ;
236- } else if (
237- data . sampleRate === this . #opts. sampleRate ||
238- data . channels === this . #opts. numChannels
239- ) {
240- frames = stream . write ( data . data . buffer ) ;
241- } else {
242- throw new Error ( `sample rate or channel count of frame does not match` ) ;
243- }
263+ try {
264+ while ( ! this . closed ) {
265+ const result = await Promise . race ( [
266+ this . input . next ( ) ,
267+ waitForAbort ( this . abortController . signal ) ,
268+ ] ) ;
269+
270+ if ( result === undefined ) return ; // aborted
271+ if ( result . done ) {
272+ break ;
273+ }
274+
275+ const data = result . value ;
276+
277+ let frames : AudioFrame [ ] ;
278+ if ( data === SpeechStream . FLUSH_SENTINEL ) {
279+ frames = stream . flush ( ) ;
280+ this . #audioDurationCollector. flush ( ) ;
281+ } else if (
282+ data . sampleRate === this . #opts. sampleRate ||
283+ data . channels === this . #opts. numChannels
284+ ) {
285+ frames = stream . write ( data . data . buffer as ArrayBuffer ) ;
286+ } else {
287+ throw new Error ( `sample rate or channel count of frame does not match` ) ;
288+ }
244289
245- for await ( const frame of frames ) {
246- if ( this . #audioEnergyFilter. pushFrame ( frame ) ) {
247- const frameDuration = frame . samplesPerChannel / frame . sampleRate ;
248- this . #audioDurationCollector. push ( frameDuration ) ;
249- ws . send ( frame . data . buffer ) ;
290+ for await ( const frame of frames ) {
291+ if ( this . #audioEnergyFilter. pushFrame ( frame ) ) {
292+ const frameDuration = frame . samplesPerChannel / frame . sampleRate ;
293+ this . #audioDurationCollector. push ( frameDuration ) ;
294+ ws . send ( frame . data . buffer ) ;
295+ }
250296 }
251297 }
298+ } finally {
299+ closing = true ;
300+ ws . send ( JSON . stringify ( { type : 'CloseStream' } ) ) ;
301+ wsMonitor . cancel ( ) ;
252302 }
253-
254- closing = true ;
255- ws . send ( JSON . stringify ( { type : 'CloseStream' } ) ) ;
256303 } ;
257304
258- const wsMonitor = new Promise < void > ( ( _ , reject ) =>
259- ws . once ( 'close' , ( code , reason ) => {
260- if ( ! closing ) {
261- this . #logger. error ( `WebSocket closed with code ${ code } : ${ reason } ` ) ;
262- reject ( new Error ( 'WebSocket closed' ) ) ;
263- }
264- } ) ,
265- ) ;
266-
267- const listenTask = async ( ) => {
268- while ( ! this . closed && ! closing ) {
269- try {
270- await new Promise < RawData > ( ( resolve ) => {
271- ws . once ( 'message' , ( data ) => resolve ( data ) ) ;
272- } ) . then ( ( msg ) => {
305+ const listenTask = Task . from ( async ( controller ) => {
306+ const listenMessage = new Promise < void > ( ( resolve , reject ) => {
307+ ws . on ( 'message' , ( msg ) => {
308+ try {
273309 const json = JSON . parse ( msg . toString ( ) ) ;
274310 switch ( json [ 'type' ] ) {
275311 case 'SpeechStarted' : {
@@ -300,7 +336,9 @@ export class SpeechStream extends stt.SpeechStream {
300336 if ( alternatives [ 0 ] && alternatives [ 0 ] . text ) {
301337 if ( ! this . #speaking) {
302338 this . #speaking = true ;
303- this . queue . put ( { type : stt . SpeechEventType . START_OF_SPEECH } ) ;
339+ this . queue . put ( {
340+ type : stt . SpeechEventType . START_OF_SPEECH ,
341+ } ) ;
304342 }
305343
306344 if ( isFinal ) {
@@ -334,15 +372,24 @@ export class SpeechStream extends stt.SpeechStream {
334372 break ;
335373 }
336374 }
337- } ) ;
338- } catch ( error ) {
339- this . #logger. child ( { error } ) . warn ( 'unrecoverable error, exiting' ) ;
340- break ;
341- }
342- }
343- } ;
344375
345- await Promise . race ( [ this . #resetWS. await , Promise . all ( [ sendTask ( ) , listenTask ( ) , wsMonitor ] ) ] ) ;
376+ if ( this . closed || closing ) {
377+ resolve ( ) ;
378+ }
379+ } catch ( err ) {
380+ this . #logger. error ( `STT: Error processing message: ${ msg } ` ) ;
381+ reject ( err ) ;
382+ }
383+ } ) ;
384+ } ) ;
385+
386+ await Promise . race ( [ listenMessage , waitForAbort ( controller . signal ) ] ) ;
387+ } , this . abortController ) ;
388+
389+ await Promise . race ( [
390+ this . #resetWS. await ,
391+ Promise . all ( [ sendTask ( ) , listenTask . result , wsMonitor ] ) ,
392+ ] ) ;
346393 closing = true ;
347394 ws . close ( ) ;
348395 clearInterval ( keepalive ) ;
0 commit comments