[WIP] Implementing No Speech Processor elixir-nx#379
tubedude committed Oct 6, 2024
1 parent 3b56c7f commit 0271ac1
Showing 4 changed files with 126 additions and 20 deletions.
77 changes: 60 additions & 17 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
@@ -23,7 +23,9 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
         defn_options: [],
         preallocate_params: false,
         task: :transcribe,
-        stream: false
+        stream: false,
+        no_speech_threshold: 0.6,
+        logprob_threshold: -1.0
       ])

     %{model: model, params: params, spec: spec} = model_info
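The defaults mirror openai/whisper's transcribe settings: `no_speech_threshold` is compared against the probability of the no-speech token (0.6 in the reference), while `logprob_threshold` is an average log-probability cutoff (-1.0 in the reference); the commit originally had the two values swapped. A minimal usage sketch, assuming the standard Bumblebee serving setup; the model name, file name, and option values here are illustrative only:

```elixir
# Sketch: build a Whisper serving with the options added in this commit.
{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
  Bumblebee.Audio.speech_to_text_whisper(
    model_info,
    featurizer,
    tokenizer,
    generation_config,
    no_speech_threshold: 0.6,
    logprob_threshold: -1.0
  )

# With both thresholds set, the no-speech processor below is enabled
Nx.Serving.run(serving, {:file, "sample.wav"})
```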
@@ -59,7 +61,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
       context_num_seconds: context_num_seconds
     }

-    {generate_opts, generation_config} = generate_opts(generation_config, opts)
+    {generate_opts, generation_config} = generate_opts(model_info, generation_config, opts)
     generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)

     generate_fun = fn params, {inputs, seed} ->
@@ -210,27 +212,68 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
     end
   end

-  defp generate_opts(generation_config, opts) do
+  defp generate_opts(model_info, generation_config, opts) do
     forced_token_ids = forced_token_ids(opts, generation_config.extra_config)
     generation_config = %{generation_config | forced_token_ids: forced_token_ids}

     logits_processors =
-      if opts[:timestamps] do
-        [
-          &Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2,
-            eos_token_id: generation_config.eos_token_id,
-            forced_token_ids: generation_config.forced_token_ids,
-            no_timestamps_token_id: generation_config.extra_config.no_timestamps_token_id,
-            timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1
-          )
-        ]
-      else
-        []
-      end
+      []
+      |> add_no_speech_detection_processor(opts, model_info, generation_config)
+      |> add_timestamp_processor(opts, generation_config)
+      |> add_suppress_tokens_at_begin_processor(opts, model_info)

-    opts = [logits_processors: logits_processors]
+    {[logits_processors: logits_processors], generation_config}
+  end
+
+  defp add_timestamp_processor(processors, opts, generation_config) do
+    timestamp_processor =
+      &Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2,
+        eos_token_id: generation_config.eos_token_id,
+        forced_token_ids: generation_config.forced_token_ids,
+        no_timestamps_token_id: generation_config.extra_config.no_timestamps_token_id,
+        timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1
+      )
+
+    if opts[:timestamps] do
+      [timestamp_processor | processors]
+    else
+      processors
+    end
+  end

-    {opts, generation_config}
+  defp add_suppress_tokens_at_begin_processor(processors, opts, model_info) do
+    suppress_tokens_at_begin_processor =
+      &Bumblebee.Text.Generation.LogitsProcessing.suppress_tokens_at_begin_processor(
+        &1,
+        &2,
+        begin_suppress_tokens:
+          Nx.tensor(model_info.spec.begin_suppress_tokens) |> Nx.new_axis(-1),
+        begin_index: opts[:begin_index] || 1
+      )
+
+    [suppress_tokens_at_begin_processor | processors]
+  end
+
+  defp add_no_speech_detection_processor(processors, opts, model_info, generation_config) do
+    no_speech_detection_processor = fn logits, context ->
+      %{logits: processed_logits, skip: should_skip} =
+        Bumblebee.Text.Generation.LogitsProcessing.whisper_no_speech_detection(
+          logits,
+          context,
+          no_speech_token: List.first(model_info.spec.begin_suppress_tokens),
+          forced_token_ids: generation_config.forced_token_ids,
+          no_speech_threshold: opts[:no_speech_threshold],
+          logprob_threshold: opts[:logprob_threshold]
+        )
+
+      Nx.select(should_skip, logits, processed_logits)
+    end
+
+    if opts[:no_speech_threshold] && opts[:logprob_threshold] do
+      [no_speech_detection_processor | processors]
+    else
+      processors
+    end
   end

   defp forced_token_ids(opts, extra_config) do
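For readers new to this part of Bumblebee: each entry in `logits_processors` is a two-arity function taking the current logits and the generation context (the `&1, &2` captures above), and the helpers conditionally prepend entries to the list. A minimal custom processor under those assumptions might look like this sketch; the token id and length cutoff are made up:

```elixir
# Sketch of the processor contract used above: take logits plus the
# generation context, return logits of the same shape.
min_length_processor = fn logits, context ->
  # Hypothetical values for illustration only
  eos_token_id = 50257
  min_length = 5

  # Candidate logits with EOS masked out
  suppressed =
    Nx.indexed_put(
      logits,
      Nx.tensor([[eos_token_id]]),
      Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {1})
    )

  # context.length is a tensor inside the compiled generation loop,
  # so branch with Nx.select rather than if/else
  Nx.select(Nx.less(context.length, min_length), suppressed, logits)
end
```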
14 changes: 12 additions & 2 deletions lib/bumblebee/audio/whisper.ex
@@ -166,7 +166,13 @@ defmodule Bumblebee.Audio.Whisper do
   #{Shared.options_doc(options)}
   """

-  defstruct [architecture: :base] ++ Shared.option_defaults(options)
+  defstruct [
+    architecture: :base,
+    suppress_tokens: [],
+    begin_suppress_tokens: [],
+    forced_decoder_ids: [],
+    no_timestamps_token_id: nil
+  ] ++ Shared.option_defaults(options)

   @behaviour Bumblebee.ModelSpec
   @behaviour Bumblebee.Configurable
@@ -520,7 +526,11 @@ defmodule Bumblebee.Audio.Whisper do
           dropout_rate: {"dropout", number()},
           attention_dropout_rate: {"attention_dropout", number()},
           activation_dropout_rate: {"activation_dropout", number()},
-          initializer_scale: {"init_std", number()}
+          initializer_scale: {"init_std", number()},
+          suppress_tokens: {"suppress_tokens", list(number())},
+          begin_suppress_tokens: {"begin_suppress_tokens", list(number())},
+          forced_decoder_ids: {"forced_decoder_ids", list(list(number()))},
+          no_timestamps_token_id: {"no_timestamps_token", number()}
         ) ++ Shared.common_options_from_transformers(data, spec)

       @for.config(spec, opts)
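These mappings read the token lists straight from the checkpoint's config.json into the new spec fields. A quick way to eyeball what was loaded (the ids shown are an assumed example and vary by checkpoint):

```elixir
{:ok, %{spec: spec}} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})

spec.begin_suppress_tokens
#=> token ids suppressed at the first decode step, e.g. [220, 50257]
#   on upstream checkpoints (assumption - check your config.json)

spec.no_timestamps_token_id
#=> nil unless the checkpoint's config defines "no_timestamps_token"
```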
53 changes: 53 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
@@ -201,6 +201,43 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     end
   end

+  defn whisper_no_speech_detection(logits, context, opts \\ []) do
+    opts =
+      keyword!(opts, [
+        :no_speech_token,
+        :forced_token_ids,
+        :no_speech_threshold,
+        :logprob_threshold
+      ])
+
+    no_speech_token = opts[:no_speech_token]
+    begin_idx = begin_idx(opts[:forced_token_ids])
+    no_speech_threshold = opts[:no_speech_threshold]
+    logprob_threshold = opts[:logprob_threshold]
+
+    if context.length == begin_idx do
+      scores = Axon.Activations.log_softmax(logits)
+      no_speech_logprob = scores[no_speech_token]
+      avg_logprob = compute_avg_logprob(scores, context.sequence)
+
+      should_skip =
+        Nx.logical_and(
+          Nx.greater(no_speech_logprob, Nx.log(no_speech_threshold)),
+          Nx.less(avg_logprob, logprob_threshold)
+        )
+
+      %{logits: logits, skip: should_skip}
+    else
+      %{logits: logits, skip: Nx.tensor(false)}
+    end
+  end
+
+  defnp compute_avg_logprob(scores, sequence) do
+    sequence_length = Nx.size(sequence)
+    total_logprob = Nx.sum(Nx.take(scores, sequence))
+    Nx.divide(total_logprob, sequence_length)
+  end
+
   defnp force_timestamp_pair(logits, context, begin_idx, eos_token_id, timestamp_begin_id) do
     # Force timestamp tokens to appear in pairs, end followed by
     # start, except directly before the EOS token
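The two conditions in `whisper_no_speech_detection` compare in log space: `no_speech_logprob > Nx.log(no_speech_threshold)` is equivalent to testing the token's probability against the threshold directly, while `logprob_threshold` is already an average log-probability, so it is compared as-is (the commit originally applied `Nx.log` to it a second time, which is NaN for the -1.0 default). A standalone check with made-up numbers:

```elixir
logits = Nx.tensor([2.0, 0.5, -1.0, 3.0])
scores = Axon.Activations.log_softmax(logits)

# probability test done in log space: p > 0.6 <=> log(p) > log(0.6)
no_speech_logprob = scores[0]
Nx.greater(no_speech_logprob, Nx.log(Nx.tensor(0.6)))

# average logprob of a (made up) decoded sequence, compared directly
# against the -1.0 threshold, which is already in log space
sequence = Nx.tensor([3, 1])
avg_logprob = Nx.divide(Nx.sum(Nx.take(scores, sequence)), Nx.size(sequence))
Nx.less(avg_logprob, -1.0)
```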
@@ -253,6 +290,22 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     Nx.select(ignore_mask, Nx.Constants.neg_infinity(Nx.type(logits)), logits)
   end

+  defn suppress_tokens_at_begin_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:begin_suppress_tokens, :begin_index])
+    begin_suppress_tokens = opts[:begin_suppress_tokens]
+    begin_index = opts[:begin_index]
+
+    if context.length == begin_index do
+      # Set the logits for the suppressed tokens to negative infinity
+      values =
+        Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(begin_suppress_tokens)})
+
+      Nx.indexed_put(logits, begin_suppress_tokens, values)
+    else
+      logits
+    end
+  end
+
   deftransformp begin_idx(forced_token_ids) do
     case List.last(forced_token_ids) do
       nil -> 1
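One shape detail worth noting: `Nx.indexed_put/3` expects indices shaped `{n, 1}` when scattering into rank-1 logits, which is why the caller in speech_to_text_whisper.ex pipes the token list through `Nx.new_axis(-1)`. A toy-vocabulary illustration:

```elixir
logits = Nx.tensor([0.1, 0.2, 0.3, 0.4, 0.5])

# ids to suppress, shaped {2, 1} to index a rank-1 tensor
begin_suppress_tokens = Nx.tensor([1, 3]) |> Nx.new_axis(-1)

values =
  Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(begin_suppress_tokens)})

Nx.indexed_put(logits, begin_suppress_tokens, values)
#=> positions 1 and 3 become -inf, so those tokens cannot be sampled
```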
2 changes: 1 addition & 1 deletion mix.exs
@@ -1,7 +1,7 @@
 defmodule Bumblebee.MixProject do
   use Mix.Project

-  @version "0.5.3"
+  @version "0.5.4"
   @description "Pre-trained and transformer Neural Network models in Axon"

   def project do