diff --git a/src/tokenizers.js b/src/tokenizers.js
index 57ab21fa8..1d92bc074 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -3488,6 +3488,10 @@ export class M2M100Tokenizer extends PreTrainedTokenizer {
 export class WhisperTokenizer extends PreTrainedTokenizer {
     _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`;
 
+    get timestamp_begin() {
+        return this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1;
+    }
+
     /**
      * Decodes automatic speech recognition (ASR) sequences.
      * @param {Array<{tokens: bigint[], token_timestamps?: number[], stride: number[]}>} sequences The sequences to decode.
@@ -3534,7 +3538,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
         const chunks = [];
         let chunk = new_chunk();
         let time_offset = 0.0;
-        const timestamp_begin = this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1;
+        const timestamp_begin = this.timestamp_begin;
 
         let previous_tokens = [];
         let previous_token_timestamps = [];
@@ -3572,11 +3576,11 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
 
             if (stride_right) {
                 for (let i = token_ids.length - 1; i >= 0; --i) {
-                    const token = token_ids[i];
+                    const token = Number(token_ids[i]);
                     if (token >= timestamp_begin) {
                         // There can be several token in the right stride
                         // But the last one is ALWAYS going to be skipped
-                        if (last_timestamp !== null && (Number(token) - timestamp_begin) * time_precision < right_stride_start) {
+                        if (last_timestamp !== null && (token - timestamp_begin) * time_precision < right_stride_start) {
                             break;
                         }
                         last_timestamp = token;
@@ -3924,7 +3928,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
     ) {
         let text;
         // @ts-ignore
-        if (decode_args && decode_args.decode_with_timestamps) {
+        if (decode_args?.decode_with_timestamps) {
             if (token_ids instanceof Tensor) {
                 token_ids = prepareTensorForDecode(token_ids);
             }
@@ -3950,9 +3954,10 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
         const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1;
         /**@type {Array} */
         let outputs = [[]];
-        for (const token of token_ids) {
+        for (let token of token_ids) {
+            token = Number(token);
             if (token >= timestamp_begin) {
-                const timestamp = round((Number(token) - timestamp_begin) * time_precision, 2);
+                const timestamp = ((token - timestamp_begin) * time_precision).toFixed(2);
                 outputs.push(`<|${timestamp}|>`);
                 outputs.push([]);
             } else {
@@ -3960,13 +3965,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
             }
         }
         outputs = outputs.map(
-            s => {
-                if (typeof s === 'string') {
-                    return s;
-                } else {
-                    return super.decode(s, decode_args);
-                }
-            }
+            s => typeof s === 'string' ? s : super.decode(s, decode_args)
         )
 
         return outputs.join('');
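
A minimal call-site sketch (not part of the patch) of the decode path these hunks touch; the package name, model id, and token id values below are placeholders for illustration, not taken from the diff:

// sketch.mjs — assumes transformers.js is installed as '@xenova/transformers'
import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/whisper-tiny.en');

// Token ids pulled out of a Tensor arrive as BigInt values; the patched
// timestamp path now normalizes each one with Number() before doing
// (token - timestamp_begin) * time_precision.
const token_ids = [50257n, 50362n, 50363n, 2425n, 50413n]; // placeholder ids

const text = tokenizer.decode(token_ids, {
    decode_with_timestamps: true, // routed through the decode override shown in the diff
    time_precision: 0.02,         // seconds per timestamp token step
    skip_special_tokens: false,
});
console.log(text); // timestamp tokens render like "<|0.00|>" thanks to toFixed(2)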