From 0d2c06614b1e539981910825c5f32d1a93e816b4 Mon Sep 17 00:00:00 2001
From: Roberto Trevisan
Date: Fri, 4 Oct 2024 18:14:58 -0300
Subject: [PATCH] [WIP] Implementing Silence Processor #379

---
 lib/bumblebee/audio/speech_to_text_whisper.ex | 60 ++++++++++++-----
 .../text/generation/logits_processing.ex      | 64 +++++++++++++++++++
 mix.exs                                       |  2 +-
 3 files changed, 108 insertions(+), 18 deletions(-)

diff --git a/lib/bumblebee/audio/speech_to_text_whisper.ex b/lib/bumblebee/audio/speech_to_text_whisper.ex
index e52e8ce2..f758a2ad 100644
--- a/lib/bumblebee/audio/speech_to_text_whisper.ex
+++ b/lib/bumblebee/audio/speech_to_text_whisper.ex
@@ -23,7 +23,9 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
         defn_options: [],
         preallocate_params: false,
         task: :transcribe,
-        stream: false
+        stream: false,
+        logprob_threshold: 0.6,
+        no_speech_threshold: -1.0
       ])
 
     %{model: model, params: params, spec: spec} = model_info
@@ -59,7 +61,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
       context_num_seconds: context_num_seconds
     }
 
-    {generate_opts, generation_config} = generate_opts(generation_config, opts)
+    {generate_opts, generation_config} = generate_opts(model_info, generation_config, opts)
     generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)
 
     generate_fun = fn params, {inputs, seed} ->
@@ -210,27 +212,51 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
     end
   end
 
-  defp generate_opts(generation_config, opts) do
+  defp generate_opts(model_info, generation_config, opts) do
     forced_token_ids = forced_token_ids(opts, generation_config.extra_config)
     generation_config = %{generation_config | forced_token_ids: forced_token_ids}
 
     logits_processors =
-      if opts[:timestamps] do
-        [
-          &Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2,
-            eos_token_id: generation_config.eos_token_id,
-            forced_token_ids: generation_config.forced_token_ids,
-            no_timestamps_token_id: generation_config.extra_config.no_timestamps_token_id,
-            timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1
-          )
-        ]
-      else
-        []
-      end
+      []
+      |> add_timestamp_processor(opts, generation_config)
+      |> add_silence_processor(opts, model_info, generation_config)
 
-    opts = [logits_processors: logits_processors]
+    {[logits_processors: logits_processors], generation_config}
+  end
 
-    {opts, generation_config}
+  defp add_timestamp_processor(processors, opts, generation_config) do
+    if opts[:timestamps] do
+      [
+        (&Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2,
+           eos_token_id: generation_config.eos_token_id,
+           forced_token_ids: generation_config.forced_token_ids,
+           no_timestamps_token_id: generation_config.extra_config.no_timestamps_token_id,
+           timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1
+         ))
+        | processors
+      ]
+    else
+      processors
+    end
+  end
+
+  defp add_silence_processor(processors, opts, model_info, generation_config) do
+    no_speech_threshold = Keyword.get(opts, :no_speech_threshold)
+    logprob_threshold = Keyword.get(opts, :logprob_threshold)
+
+    if no_speech_threshold && logprob_threshold do
+      [
+        (&Bumblebee.Text.Generation.LogitsProcessing.whisper_silence_processor(&1, &2,
+           no_speech_threshold: no_speech_threshold,
+           logprob_threshold: logprob_threshold,
+           vocab_size: model_info.spec.vocab_size,
+           suppress_tokens: Nx.tensor(generation_config.suppressed_token_ids)
+         ))
+        | processors
+      ]
+    else
+      processors
+    end
   end
 
   defp forced_token_ids(opts, extra_config) do
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index eff38e52..b8c606a5 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -201,6 +201,70 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     end
   end
 
+  defn whisper_silence_processor(logits, context, opts \\ []) do
+    opts =
+      keyword!(opts, [:no_speech_threshold, :logprob_threshold, :vocab_size, :suppress_tokens])
+
+    # Convert to tensor
+    suppress_tokens = opts[:suppress_tokens]
+    no_speech_threshold = opts[:no_speech_threshold]
+    logprob_threshold = opts[:logprob_threshold]
+    vocab_size = opts[:vocab_size]
+
+    scores = Axon.Activations.log_softmax(logits)
+    no_speech_prob = compute_no_speech_probability(logits)
+    avg_logprob = compute_avg_logprob(scores, context.sequence)
+
+    Nx.select(
+      no_speech_prob > no_speech_threshold and avg_logprob < logprob_threshold,
+      suppress_logits(logits, vocab_size, suppress_tokens),
+      logits
+    )
+  end
+
+  defnp compute_no_speech_probability(logits) do
+    # In Whisper, the no_speech probability is typically the first token's probability
+    # We apply softmax to get probabilities from logits
+    probs = Axon.Activations.log_softmax(logits)
+    probs[0]
+  end
+
+  defnp compute_avg_logprob(scores, sequence) do
+    # We need to compute the average log probability of the sequence
+    # scores should be a list of log probabilities for each token
+    sequence_length = Nx.size(sequence)
+
+    # Sum the log probabilities of the generated tokens
+    total_logprob = Nx.sum(Nx.take(scores, sequence))
+
+    # Compute average log probability
+    Nx.divide(total_logprob, sequence_length)
+  end
+
+  defnp suppress_logits(logits, vocab_size, suppress_tokens) do
+    # Create a mask for tokens to suppress
+    suppress_mask = Nx.broadcast(Nx.tensor(false, type: {:u, 8}), {vocab_size})
+
+    # Reshape suppress_tokens to have shape {n, 1}
+    indices = Nx.new_axis(suppress_tokens, -1)
+
+    # Broadcast updates to match the leading dimensions of indices (shape {n})
+    updates = Nx.broadcast(Nx.tensor(true, type: {:u, 8}), Nx.shape(suppress_tokens))
+
+    # Set mask to true for tokens we want to suppress
+    suppress_mask = Nx.indexed_put(suppress_mask, indices, updates)
+
+    # Apply the suppression
+    suppressed_logits =
+      Nx.select(
+        suppress_mask,
+        Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.shape(logits)),
+        logits
+      )
+
+    suppressed_logits
+  end
+
   defnp force_timestamp_pair(logits, context, begin_idx, eos_token_id, timestamp_begin_id) do
     # Force timestamp tokens to appear in pairs, end followed by
     # start, except directly before the EOS token
diff --git a/mix.exs b/mix.exs
index 6c301497..9199445e 100644
--- a/mix.exs
+++ b/mix.exs
@@ -1,7 +1,7 @@
 defmodule Bumblebee.MixProject do
   use Mix.Project
 
-  @version "0.5.3"
+  @version "0.5.4"
   @description "Pre-trained and transformer Neural Network models in Axon"
 
   def project do