From cf8b12ef467d7cc34b363c43d2d1813b62fd785e Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 10 Dec 2024 19:07:55 +0100
Subject: [PATCH] Merge more internal changes.

---
 rust/moshi-backend/src/stream_both.rs         |  1 +
 .../moshi-core/src/lm_generate_multistream.rs | 41 ++++++++++++++++---
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/rust/moshi-backend/src/stream_both.rs b/rust/moshi-backend/src/stream_both.rs
index 9dabcc2..5266f9b 100644
--- a/rust/moshi-backend/src/stream_both.rs
+++ b/rust/moshi-backend/src/stream_both.rs
@@ -538,6 +538,7 @@ impl StreamingModel {
         );
         let mut state = moshi::lm_generate_multistream::State::new(
             lm_model,
+            None,
             self.session_config.max_steps,
             audio_lp,
             text_lp,
diff --git a/rust/moshi-core/src/lm_generate_multistream.rs b/rust/moshi-core/src/lm_generate_multistream.rs
index 9ac5ad4..15e9f21 100644
--- a/rust/moshi-core/src/lm_generate_multistream.rs
+++ b/rust/moshi-core/src/lm_generate_multistream.rs
@@ -43,6 +43,18 @@ impl Config {
         }
     }
 
+    pub fn v0_1_one_way() -> Self {
+        Self {
+            generated_audio_codebooks: 8,
+            input_audio_codebooks: 0,
+            audio_vocab_size: 2049,
+            acoustic_delay: 2,
+            text_eop_token: 0,
+            text_pad_token: 3,
+            text_start_token: 32000,
+        }
+    }
+
     pub fn audio_pad_token(&self) -> u32 {
         self.audio_vocab_size as u32 - 1
     }
@@ -54,6 +66,7 @@ impl Config {
 
 pub struct State {
     model: crate::lm::LmModel,
+    ca_src: Option<Tensor>,
     audio_tokens: Vec<Vec<u32>>,
     text_tokens: Vec<u32>,
     audio_lp: LogitsProcessor,
@@ -66,8 +79,10 @@ pub struct State {
 }
 
 impl State {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         model: crate::lm::LmModel,
+        ca_src: Option<Tensor>,
         max_step_idx: usize,
         audio_lp: LogitsProcessor,
         text_lp: LogitsProcessor,
@@ -82,6 +97,7 @@ impl State {
         let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay];
         Self {
             model,
+            ca_src,
             audio_tokens,
             text_tokens,
             audio_lp,
@@ -150,16 +166,16 @@ impl State {
 
     // The acoustic tokens are written with a delay, so this can create "gaps" of UNGENERATED
     // tokens in the case where we call `step_audio_prompt` *after* `step`.
-    pub fn step(
+    pub fn step_(
         &mut self,
-        text_token: u32,
+        text_token: Option<u32>,
         input_audio_tokens: &[u32],
         force_text_token: Option<u32>,
     ) -> candle::Result<u32> {
         let mut codes = Vec::with_capacity(self.config.total_audio_codebooks());
         let dev = self.model.device();
         for (c_idx, &t) in input_audio_tokens.iter().enumerate() {
-            self.audio_tokens[self.step_idx][c_idx + 8] = t
+            self.audio_tokens[self.step_idx][c_idx + self.config.generated_audio_codebooks] = t
         }
         for codebook in 0..self.config.total_audio_codebooks() {
             let t = if codebook == 0 || codebook == 8 {
@@ -179,8 +195,14 @@ impl State {
             let t = Tensor::new(&[t], dev)?.unsqueeze(0)?;
             codes.push(Some(t))
         }
-        let text_token = Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?);
-        let (text_logits, ys) = self.model.forward(text_token, codes)?;
+        let text_token = match text_token {
+            Some(text_token) => Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?),
+            None => None,
+        };
+        let (text_logits, ys) = match self.ca_src.as_ref() {
+            None => self.model.forward(text_token, codes)?,
+            Some(ca_src) => self.model.forward_ca(text_token, codes, ca_src)?,
+        };
         let text_logits = text_logits.i((0, 0))?;
         let text_logits = self.apply_repetition_penalty(text_logits)?;
         let text_token = match force_text_token {
@@ -222,6 +244,15 @@ impl State {
         Ok(text_token)
     }
 
+    pub fn step(
+        &mut self,
+        text_token: u32,
+        input_audio_tokens: &[u32],
+        force_text_token: Option<u32>,
+    ) -> candle::Result<u32> {
+        self.step_(Some(text_token), input_audio_tokens, force_text_token)
+    }
+
     /// If include_all is set, all the time steps are returned. Otherwise only the timesteps that
     /// have been generated are handled.
     pub fn audio_tokens(&self, include_all: bool) -> &[Vec<u32>] {
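
Usage sketch: one way a caller might drive the reworked stepping API, assuming a
`State` built through the extended `State::new` above (passing `Some(ca_src)`
routes through the `forward_ca` cross-attention path, `None` keeps the previous
behavior, as in the stream_both.rs change). The `drive` function, the 100-step
bound, and public visibility of the `Config` fields are illustrative
assumptions, not part of the patch.

    use moshi::lm_generate_multistream::{Config, State};

    // Hypothetical driver loop over the new API.
    fn drive(state: &mut State) -> candle::Result<()> {
        // v0_1_one_way() declares input_audio_codebooks: 0, so nothing is
        // fed back through `input_audio_tokens` here.
        let config = Config::v0_1_one_way();
        let mut text_token = config.text_start_token;
        for _ in 0..100 {
            // `step` now forwards to `step_(Some(text_token), ..)`; calling
            // `step_(None, ..)` instead runs a step with no text token.
            text_token = state.step(text_token, &[], None)?;
        }
        Ok(())
    }

Keeping `step` as a thin wrapper over `step_` preserves existing call sites
such as stream_both.rs while still exposing the optional text token.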