Merge pull request #169 from kyutai-labs/merge-more-changes
Merge more internal changes.
LaurentMazare authored Dec 10, 2024
2 parents d4cac37 + cf8b12e commit af1e918
Showing 2 changed files with 37 additions and 5 deletions.
1 change: 1 addition & 0 deletions rust/moshi-backend/src/stream_both.rs
@@ -538,6 +538,7 @@ impl StreamingModel {
         );
         let mut state = moshi::lm_generate_multistream::State::new(
             lm_model,
+            None,
             self.session_config.max_steps,
             audio_lp,
             text_lp,
41 changes: 36 additions & 5 deletions rust/moshi-core/src/lm_generate_multistream.rs
@@ -43,6 +43,18 @@ impl Config {
         }
     }

+    pub fn v0_1_one_way() -> Self {
+        Self {
+            generated_audio_codebooks: 8,
+            input_audio_codebooks: 0,
+            audio_vocab_size: 2049,
+            acoustic_delay: 2,
+            text_eop_token: 0,
+            text_pad_token: 3,
+            text_start_token: 32000,
+        }
+    }
+
     pub fn audio_pad_token(&self) -> u32 {
         self.audio_vocab_size as u32 - 1
     }
@@ -54,6 +66,7 @@ impl Config {

 pub struct State {
     model: crate::lm::LmModel,
+    ca_src: Option<Tensor>,
     audio_tokens: Vec<Vec<u32>>,
     text_tokens: Vec<u32>,
     audio_lp: LogitsProcessor,
@@ -66,8 +79,10 @@ pub struct State {
 }

 impl State {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         model: crate::lm::LmModel,
+        ca_src: Option<Tensor>,
         max_step_idx: usize,
         audio_lp: LogitsProcessor,
         text_lp: LogitsProcessor,
@@ -82,6 +97,7 @@ impl State {
         let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay];
         Self {
             model,
+            ca_src,
             audio_tokens,
             text_tokens,
             audio_lp,
@@ -150,16 +166,16 @@ impl State {

     // The acoustic tokens are written with a delay, so this can create "gaps" of UNGENERATED
     // tokens in the case where we call `step_audio_prompt` *after* `step`.
-    pub fn step(
+    pub fn step_(
         &mut self,
-        text_token: u32,
+        text_token: Option<u32>,
         input_audio_tokens: &[u32],
         force_text_token: Option<u32>,
     ) -> candle::Result<u32> {
         let mut codes = Vec::with_capacity(self.config.total_audio_codebooks());
         let dev = self.model.device();
         for (c_idx, &t) in input_audio_tokens.iter().enumerate() {
-            self.audio_tokens[self.step_idx][c_idx + 8] = t
+            self.audio_tokens[self.step_idx][c_idx + self.config.generated_audio_codebooks] = t
         }
         for codebook in 0..self.config.total_audio_codebooks() {
             let t = if codebook == 0 || codebook == 8 {
@@ -179,8 +195,14 @@ impl State {
             let t = Tensor::new(&[t], dev)?.unsqueeze(0)?;
             codes.push(Some(t))
         }
-        let text_token = Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?);
-        let (text_logits, ys) = self.model.forward(text_token, codes)?;
+        let text_token = match text_token {
+            Some(text_token) => Some(Tensor::from_vec(vec![text_token], (1, 1), dev)?),
+            None => None,
+        };
+        let (text_logits, ys) = match self.ca_src.as_ref() {
+            None => self.model.forward(text_token, codes)?,
+            Some(ca_src) => self.model.forward_ca(text_token, codes, ca_src)?,
+        };
         let text_logits = text_logits.i((0, 0))?;
         let text_logits = self.apply_repetition_penalty(text_logits)?;
         let text_token = match force_text_token {
@@ -222,6 +244,15 @@ impl State {
         Ok(text_token)
     }

+    pub fn step(
+        &mut self,
+        text_token: u32,
+        input_audio_tokens: &[u32],
+        force_text_token: Option<u32>,
+    ) -> candle::Result<u32> {
+        self.step_(Some(text_token), input_audio_tokens, force_text_token)
+    }
+
     /// If include_all is set, all the time steps are returned. Otherwise only the timesteps that
     /// have been generated are handled.
     pub fn audio_tokens(&self, include_all: bool) -> &[Vec<u32>] {
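The pattern introduced by the second file can be summarized with a standalone sketch. This is not the moshi-core API: `Tensor`, `Model`, `forward` and `forward_ca` below are hypothetical stand-ins, and the audio-token plumbing is omitted. It only illustrates the shape of the change: an optional cross-attention source stored on the state selects between two forward paths, the text token becomes optional in the generalized `step_`, and the old `step` keeps its signature by delegating.

// Minimal sketch of the step_/step split and the optional ca_src dispatch.
// All types here are stand-ins, not the real candle/moshi ones.
struct Tensor;
struct Model;

impl Model {
    fn forward(&self, _text: Option<u32>) -> u32 {
        0 // placeholder: plain forward pass
    }
    fn forward_ca(&self, _text: Option<u32>, _ca_src: &Tensor) -> u32 {
        0 // placeholder: forward pass conditioned on a cross-attention source
    }
}

struct State {
    model: Model,
    ca_src: Option<Tensor>, // new optional cross-attention conditioning
}

impl State {
    // Generalized step: the text token is now optional, and the presence of
    // ca_src decides which forward path is taken.
    fn step_(&mut self, text_token: Option<u32>) -> u32 {
        match self.ca_src.as_ref() {
            None => self.model.forward(text_token),
            Some(ca_src) => self.model.forward_ca(text_token, ca_src),
        }
    }

    // The previous entry point keeps its signature and delegates to step_.
    fn step(&mut self, text_token: u32) -> u32 {
        self.step_(Some(text_token))
    }
}

fn main() {
    // With ca_src set to None, existing callers (such as the one-line change
    // in stream_both.rs above) behave exactly as before.
    let mut state = State { model: Model, ca_src: None };
    let _ = state.step(42);
}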
