Skip to content

Commit

Permalink
Improve whisper tokenizer
Browse files — browse the repository at this point in the history
  • Loading branch information
xenova committed May 31, 2024
1 parent 812185f commit 1cfdf10
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -3488,6 +3488,10 @@ export class M2M100Tokenizer extends PreTrainedTokenizer {
export class WhisperTokenizer extends PreTrainedTokenizer {
_default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`;

get timestamp_begin() {
return this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1;
}

/**
* Decodes automatic speech recognition (ASR) sequences.
* @param {Array<{tokens: bigint[], token_timestamps?: number[], stride: number[]}>} sequences The sequences to decode.
Expand Down Expand Up @@ -3534,7 +3538,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
const chunks = [];
let chunk = new_chunk();
let time_offset = 0.0;
const timestamp_begin = this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1;
const timestamp_begin = this.timestamp_begin;

let previous_tokens = [];
let previous_token_timestamps = [];
Expand Down Expand Up @@ -3572,11 +3576,11 @@ export class WhisperTokenizer extends PreTrainedTokenizer {

if (stride_right) {
for (let i = token_ids.length - 1; i >= 0; --i) {
const token = token_ids[i];
const token = Number(token_ids[i]);
if (token >= timestamp_begin) {
// There can be several token in the right stride
// But the last one is ALWAYS going to be skipped
if (last_timestamp !== null && (Number(token) - timestamp_begin) * time_precision < right_stride_start) {
if (last_timestamp !== null && (token - timestamp_begin) * time_precision < right_stride_start) {
break;
}
last_timestamp = token;
Expand Down Expand Up @@ -3924,7 +3928,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
) {
let text;
// @ts-ignore
if (decode_args && decode_args.decode_with_timestamps) {
if (decode_args?.decode_with_timestamps) {
if (token_ids instanceof Tensor) {
token_ids = prepareTensorForDecode(token_ids);
}
Expand All @@ -3950,23 +3954,18 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1;
/**@type {Array} */
let outputs = [[]];
for (const token of token_ids) {
for (let token of token_ids) {
token = Number(token);
if (token >= timestamp_begin) {
const timestamp = round((Number(token) - timestamp_begin) * time_precision, 2);
const timestamp = ((token - timestamp_begin) * time_precision).toFixed(2);
outputs.push(`<|${timestamp}|>`);
outputs.push([]);
} else {
outputs[outputs.length - 1].push(token);
}
}
outputs = outputs.map(
s => {
if (typeof s === 'string') {
return s;
} else {
return super.decode(s, decode_args);
}
}
s => typeof s === 'string' ? s : super.decode(s, decode_args)
)

return outputs.join('');
Expand Down

0 comments on commit 1cfdf10

Please sign in to comment.