From 812185f94cec57feb910d7a82f1bd43f01ee5b0b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 31 May 2024 23:38:16 +0200 Subject: [PATCH] Implement `WhisperTextStreamer` helper class --- src/generation/streamers.js | 81 ++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/generation/streamers.js b/src/generation/streamers.js index 5dcbd84df..84fdedbfc 100644 --- a/src/generation/streamers.js +++ b/src/generation/streamers.js @@ -39,7 +39,7 @@ export class TextStreamer extends BaseStreamer { constructor(tokenizer, { skip_prompt = false, callback_function = null, - ...decode_kwargs + decode_kwargs = {}, } = {}) { super(); this.tokenizer = tokenizer; @@ -124,3 +124,82 @@ export class TextStreamer extends BaseStreamer { } } } + +/** + * Utility class to handle streaming of tokens generated by whisper speech-to-text models. + * Callback functions are invoked when each of the following events occur: + * - A new chunk starts (on_chunk_start) + * - A new token is generated (callback_function) + * - A chunk ends (on_chunk_end) + * - The stream is finalized (on_finalize) + */ +export class WhisperTextStreamer extends TextStreamer { + /** + * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer + * @param {Object} options + * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens + * @param {function(string): void} [options.callback_function=null] Function to call when a new token is generated + * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts + * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends + * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized + * @param {number} [options.time_precision=0.02] Precision of the timestamps + * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding + * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method + */ + constructor(tokenizer, { + skip_prompt = false, + callback_function = null, + on_chunk_start = null, + on_chunk_end = null, + on_finalize = null, + time_precision = 0.02, + skip_special_tokens = true, + decode_kwargs = {}, + } = {}) { + super(tokenizer, { + skip_prompt, + callback_function, + decode_kwargs: { skip_special_tokens, ...decode_kwargs }, + }); + this.timestamp_begin = tokenizer.timestamp_begin; + + this.on_chunk_start = on_chunk_start; + this.on_chunk_end = on_chunk_end; + this.on_finalize = on_finalize; + + this.time_precision = time_precision; + + this.waiting_for_timestamp = false; + } + + /** + * @param {bigint[][]} value + */ + put(value) { + if (value.length > 1) { + throw Error('WhisperTextStreamer only supports batch size of 1'); + } + const tokens = value[0]; + + // Check if the token is a timestamp + if (tokens.length === 1) { + const offset = Number(tokens[0]) - this.timestamp_begin; + if (offset >= 0) { + const time = offset * this.time_precision; + if (this.waiting_for_timestamp) { + this.on_chunk_end?.(time); + } else { + this.on_chunk_start?.(time); + } + this.waiting_for_timestamp = !this.waiting_for_timestamp; // Toggle + value = [[]]; // Skip timestamp + } + } + return super.put(value); + } + + end() { + super.end(); + this.on_finalize?.(); + } +}