From 0271ac1a922dbe2dbad0328b26b3b33e18140439 Mon Sep 17 00:00:00 2001 From: Roberto Trevisan Date: Fri, 4 Oct 2024 18:14:58 -0300 Subject: [PATCH] [WIP] Implementing No Speech Processor #379 --- lib/bumblebee/audio/speech_to_text_whisper.ex | 77 +++++++++++++++---- lib/bumblebee/audio/whisper.ex | 14 +++- .../text/generation/logits_processing.ex | 53 +++++++++++++ mix.exs | 2 +- 4 files changed, 126 insertions(+), 20 deletions(-) diff --git a/lib/bumblebee/audio/speech_to_text_whisper.ex b/lib/bumblebee/audio/speech_to_text_whisper.ex index e52e8ce2..8fe36773 100644 --- a/lib/bumblebee/audio/speech_to_text_whisper.ex +++ b/lib/bumblebee/audio/speech_to_text_whisper.ex @@ -23,7 +23,9 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do defn_options: [], preallocate_params: false, task: :transcribe, - stream: false + stream: false, + logprob_threshold: -1.0, + no_speech_threshold: 0.6 ]) %{model: model, params: params, spec: spec} = model_info @@ -59,7 +61,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do context_num_seconds: context_num_seconds } - {generate_opts, generation_config} = generate_opts(generation_config, opts) + {generate_opts, generation_config} = generate_opts(model_info, generation_config, opts) generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts) generate_fun = fn params, {inputs, seed} -> @@ -210,27 +212,68 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do end end - defp generate_opts(generation_config, opts) do + defp generate_opts(model_info, generation_config, opts) do forced_token_ids = forced_token_ids(opts, generation_config.extra_config) generation_config = %{generation_config | forced_token_ids: forced_token_ids} logits_processors = - if opts[:timestamps] do - [ - &Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2, - eos_token_id: generation_config.eos_token_id, - forced_token_ids: generation_config.forced_token_ids, - no_timestamps_token_id: 
generation_config.extra_config.no_timestamps_token_id, - timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1 - ) - ] - else - [] - end + [] + |> add_no_speech_detection_processor(opts, model_info, generation_config) + |> add_timestamp_processor(opts, generation_config) + |> add_suppress_tokens_at_begin_processor(opts, model_info) - opts = [logits_processors: logits_processors] + {[logits_processors: logits_processors], generation_config} + end + + defp add_timestamp_processor(processors, opts, generation_config) do + timestamp_processor = + &Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2, + eos_token_id: generation_config.eos_token_id, + forced_token_ids: generation_config.forced_token_ids, + no_timestamps_token_id: generation_config.extra_config.no_timestamps_token_id, + timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1 + ) + + if opts[:timestamps] do + [timestamp_processor | processors] + else + processors + end + end - {opts, generation_config} + defp add_suppress_tokens_at_begin_processor(processors, opts, model_info) do + suppress_tokens_at_begin_processor = + &Bumblebee.Text.Generation.LogitsProcessing.suppress_tokens_at_begin_processor( + &1, + &2, + begin_suppress_tokens: + Nx.tensor(model_info.spec.begin_suppress_tokens) |> Nx.new_axis(-1), + begin_index: opts[:begin_index] || 1 + ) + + [suppress_tokens_at_begin_processor | processors] + end + + defp add_no_speech_detection_processor(processors, opts, model_info, generation_config) do + no_speech_detection_processor = fn logits, context -> + %{logits: processed_logits, skip: should_skip} = + Bumblebee.Text.Generation.LogitsProcessing.whisper_no_speech_detection( + logits, + context, + no_speech_token: List.first(model_info.spec.begin_suppress_tokens), + forced_token_ids: generation_config.forced_token_ids, + no_speech_threshold: opts[:no_speech_threshold], + logprob_threshold: opts[:logprob_threshold] + ) + + 
Nx.select(should_skip, logits, processed_logits) + end + + if opts[:no_speech_threshold] && opts[:logprob_threshold] do + [no_speech_detection_processor | processors] + else + processors + end end defp forced_token_ids(opts, extra_config) do diff --git a/lib/bumblebee/audio/whisper.ex b/lib/bumblebee/audio/whisper.ex index 9b78c29d..01de12f5 100644 --- a/lib/bumblebee/audio/whisper.ex +++ b/lib/bumblebee/audio/whisper.ex @@ -166,7 +166,13 @@ defmodule Bumblebee.Audio.Whisper do #{Shared.options_doc(options)} """ - defstruct [architecture: :base] ++ Shared.option_defaults(options) + defstruct [ + architecture: :base, + suppress_tokens: [], + begin_suppress_tokens: [], + forced_decoder_ids: [], + no_timestamps_token_id: nil + ] ++ Shared.option_defaults(options) @behaviour Bumblebee.ModelSpec @behaviour Bumblebee.Configurable @@ -520,7 +526,11 @@ defmodule Bumblebee.Audio.Whisper do dropout_rate: {"dropout", number()}, attention_dropout_rate: {"attention_dropout", number()}, activation_dropout_rate: {"activation_dropout", number()}, - initializer_scale: {"init_std", number()} + initializer_scale: {"init_std", number()}, + suppress_tokens: {"suppress_tokens", list(number())}, + begin_suppress_tokens: {"begin_suppress_tokens", list(number())}, + forced_decoder_ids: {"forced_decoder_ids", list(list(number()))}, + no_timestamps_token_id: {"no_timestamps_token", number()} ) ++ Shared.common_options_from_transformers(data, spec) @for.config(spec, opts) diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex index eff38e52..d3f9ff4f 100644 --- a/lib/bumblebee/text/generation/logits_processing.ex +++ b/lib/bumblebee/text/generation/logits_processing.ex @@ -201,6 +201,43 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do end end + defn whisper_no_speech_detection(logits, context, opts \\ []) do + opts = + keyword!(opts, [ + :no_speech_token, + :forced_token_ids, + :no_speech_threshold, + :logprob_threshold + 
]) + + no_speech_token = opts[:no_speech_token] + begin_idx = begin_idx(opts[:forced_token_ids]) + no_speech_threshold = opts[:no_speech_threshold] + logprob_threshold = opts[:logprob_threshold] + + if context.length == begin_idx do + scores = Axon.Activations.log_softmax(logits) + no_speech_logprob = scores[no_speech_token] + avg_logprob = compute_avg_logprob(scores, context.sequence) + + should_skip = + Nx.logical_and( + Nx.greater(no_speech_logprob, Nx.log(no_speech_threshold)), + Nx.less(avg_logprob, logprob_threshold) + ) + + %{logits: logits, skip: should_skip} + else + %{logits: logits, skip: Nx.tensor(false)} + end + end + + defnp compute_avg_logprob(scores, sequence) do + sequence_length = Nx.size(sequence) + total_logprob = Nx.sum(Nx.take(scores, sequence)) + Nx.divide(total_logprob, sequence_length) + end + defnp force_timestamp_pair(logits, context, begin_idx, eos_token_id, timestamp_begin_id) do # Force timestamp tokens to appear in pairs, end followed by # start, except directly before the EOS token @@ -253,6 +290,22 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do Nx.select(ignore_mask, Nx.Constants.neg_infinity(Nx.type(logits)), logits) end + defn suppress_tokens_at_begin_processor(logits, context, opts \\ []) do + opts = keyword!(opts, [:begin_suppress_tokens, :begin_index]) + begin_suppress_tokens = opts[:begin_suppress_tokens] + begin_index = opts[:begin_index] + + if context.length == begin_index do + # Set the logits for the suppressed tokens to negative infinity + values = + Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(begin_suppress_tokens)}) + + Nx.indexed_put(logits, begin_suppress_tokens, values) + else + logits + end + end + deftransformp begin_idx(forced_token_ids) do case List.last(forced_token_ids) do nil -> 1 diff --git a/mix.exs b/mix.exs index 6c301497..9199445e 100644 --- a/mix.exs +++ b/mix.exs @@ -1,7 +1,7 @@ defmodule Bumblebee.MixProject do use Mix.Project - @version "0.5.3" + @version 
"0.5.4" @description "Pre-trained and transformer Neural Network models in Axon" def project do