Commit
Merge pull request #168 from kyutai-labs/internal-imports
Merge some internal changes.
LaurentMazare authored Dec 10, 2024
2 parents 295758d + f93b191 commit d4cac37
Showing 5 changed files with 310 additions and 34 deletions.
1 change: 1 addition & 0 deletions rust/moshi-core/src/lib.rs
@@ -17,6 +17,7 @@ pub mod seanet;
pub mod streaming;
pub mod transformer;
pub mod tts;
pub mod tts_streaming;
pub mod wav;

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
12 changes: 12 additions & 0 deletions rust/moshi-core/src/lm.rs
@@ -515,6 +515,18 @@ impl LmModel {
}
}

pub fn forward_ca(
&mut self,
text_ids: Option<Tensor>,
audio_ids: Vec<Option<Tensor>>,
ca_src: &Tensor,
) -> candle::Result<(Tensor, Tensor)> {
match self {
Self::Lm(m) => m.forward_ca(text_ids, audio_ids, ca_src),
Self::QuantizedLm(m) => m.forward_ca(text_ids, audio_ids, ca_src),
}
}

pub fn depformer_sample(
&mut self,
step_idx: usize,
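
For orientation, here is a hedged calling sketch of the new forward_ca entry point, written as it might appear elsewhere inside moshi-core; it is not part of the commit, and the device, shapes, codebook count, and function name are illustrative assumptions. Only the argument and return types come from the signature above.

use candle::{DType, Device, Result, Tensor};
use crate::lm::LmModel;

fn step_with_cross_attention(model: &mut LmModel, d_model: usize) -> Result<(Tensor, Tensor)> {
    // Hypothetical conditioning sequence of length 32 on CPU; the signature only
    // requires some tensor for `ca_src`, not this particular shape or content.
    let ca_src = Tensor::zeros((1, 32, d_model), DType::F32, &Device::Cpu)?;
    // One step with no forced text token and eight unset audio codebooks
    // (the codebook count is illustrative, not taken from the commit).
    model.forward_ca(None, vec![None; 8], &ca_src)
}
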
18 changes: 8 additions & 10 deletions rust/moshi-core/src/quantized_transformer.rs
@@ -3,7 +3,7 @@
// LICENSE file in the root directory of this source tree.

use crate::streaming::{StreamTensor, StreamingModule};
-use crate::transformer::{get_mask, CrossAttention, PositionalEmbedding, RotaryEmbedding};
+use crate::transformer::{get_mask, PositionalEmbedding, RotaryEmbedding};

use candle::{DType, IndexOp, Module, Result, Tensor, D};
use candle_transformers::quantized_nn::{layer_norm, linear_b, Linear};
@@ -169,7 +169,7 @@ pub struct StreamingMultiheadCrossAttention {
}

impl StreamingMultiheadCrossAttention {
-pub fn new(_ca: CrossAttention, _cfg: &Config, _vb: VarBuilder) -> Result<Self> {
+pub fn new(_cfg: &Config, _vb: VarBuilder) -> Result<Self> {
candle::bail!("cross-attn is not supported at the moment")
}

@@ -347,14 +347,12 @@ impl StreamingTransformerLayer {
}
};
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
-let cross_attn = match cfg.cross_attention {
-Some(ca) => {
-let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
-let cross_attn =
-StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?;
-Some((norm_cross, cross_attn))
-}
-None => None,
+let cross_attn = if cfg.cross_attention.is_some() {
+let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
+let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
+Some((norm_cross, cross_attn))
+} else {
+None
};
Ok(Self {
self_attn,
116 changes: 92 additions & 24 deletions rust/moshi-core/src/transformer.rs
@@ -22,10 +22,86 @@ pub enum PositionalEmbedding {
None,
}

-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-pub enum CrossAttention {
+#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
+pub enum CrossAttentionGating {
// Configure Type of gating used at the output of vision cross-attention layers
Normal,
ConstantGatedTanh,
ConstantGatedSigmoid,
ConditionalGatedTanh,
ConditionalGatedSigmoid,
}
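
The new serde derives are what let the gating mode be read from a config file. A minimal sketch of how the variants round-trip by their literal names; this test is an assumption, not part of the commit, and it presumes serde_json is available as a dev-dependency.

#[cfg(test)]
mod cross_attention_gating_serde_sketch {
    use super::CrossAttentionGating;

    #[test]
    fn round_trips_by_variant_name() {
        // A unit variant with plain serde derives serializes as its literal name.
        let json = serde_json::to_string(&CrossAttentionGating::ConditionalGatedTanh).unwrap();
        assert_eq!(json, "\"ConditionalGatedTanh\"");
        let back: CrossAttentionGating = serde_json::from_str(&json).unwrap();
        assert_eq!(back, CrossAttentionGating::ConditionalGatedTanh);
    }
}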

#[derive(Debug, Clone)]
pub enum XaGate {
// Corresponding module
Normal,
Gated,
ConstantGated {
alpha: Tensor,
},
ConditionalGated {
in_proj: Linear,
out_proj: Linear,
activation: candle_nn::init::NonLinearity,
},
}

impl XaGate {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
match cfg.cross_attention {
// no cross attention - shouldn't occur here
None => candle::bail!("Invalid cross-attention config specified."),
// no gating
Some(CrossAttentionGating::Normal) => Ok(Self::Normal),
// constant (per-layer parameter) passed through tanh or sigmoid
Some(CrossAttentionGating::ConstantGatedTanh) => {
let alpha = vb.get((1, 1, 1), "gate.alpha")?.tanh()?;
Ok(Self::ConstantGated { alpha })
}
Some(CrossAttentionGating::ConstantGatedSigmoid) => {
let alpha = candle_nn::ops::sigmoid(&(vb.get((1, 1, 1), "gate.alpha")? - 4.0)?)?;
Ok(Self::ConstantGated { alpha })
}
// input conditional (small MLP)
Some(CrossAttentionGating::ConditionalGatedTanh)
| Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
let dim = cfg.d_model;
let hidden_dims = (0.125 * dim as f32).floor() as usize;
let in_proj = linear(dim, hidden_dims, false, vb.pp("gate.alpha.0"))?;
let out_proj = linear(hidden_dims, dim, false, vb.pp("gate.alpha.2"))?;
let activation = match cfg.cross_attention {
Some(CrossAttentionGating::ConditionalGatedTanh) => {
candle_nn::init::NonLinearity::Tanh
}
Some(CrossAttentionGating::ConditionalGatedSigmoid) => {
candle_nn::init::NonLinearity::Sigmoid
}
_ => candle::bail!("Invalid cross-attention config specified."),
};
Ok(Self::ConditionalGated { in_proj, out_proj, activation })
}
}
}
}
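
For concreteness: the conditional gate's hidden width is an eighth of the model width, so with d_model = 1024 the MLP is 1024 -> 128 -> 1024 (hidden_dims = floor(0.125 x 1024) = 128), with bias-free projections loaded from gate.alpha.0 and gate.alpha.2.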

impl Module for XaGate {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Normal => Ok(xs.clone()),
Self::ConstantGated { alpha } => xs.broadcast_mul(alpha),
Self::ConditionalGated { in_proj, out_proj, activation } => {
let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?;
let alpha = match activation {
candle_nn::init::NonLinearity::Tanh => alpha.tanh(),
candle_nn::init::NonLinearity::Sigmoid => {
candle_nn::ops::sigmoid(&(alpha - 4.0)?)
}
_ => candle::bail!("Invalid non-linearity specified in cross-attention gating"),
};
xs * alpha?
}
}
}
}
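
To make the gating behaviour concrete, here is a minimal sketch of the constant sigmoid variant; it is not part of the commit and assumes only the candle and candle_nn crates already used above. With gate.alpha initialised near zero, sigmoid(alpha - 4.0) is roughly 0.018, so the cross-attention contribution starts almost fully suppressed and opens up as alpha grows during training.

use candle::{DType, Device, Result, Tensor};

fn constant_sigmoid_gate_demo() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for a freshly initialised `gate.alpha` of shape (1, 1, 1).
    let alpha = Tensor::zeros((1, 1, 1), DType::F32, &dev)?;
    let gate = candle_nn::ops::sigmoid(&(alpha - 4.0)?)?;
    // Any (b, t, d) activation gets scaled elementwise by the gate value (~0.018 here).
    let xs = Tensor::ones((2, 5, 8), DType::F32, &dev)?;
    let gated = xs.broadcast_mul(&gate)?;
    println!("gate value: {:?}", gate.flatten_all()?.to_vec1::<f32>()?);
    println!("gated shape: {:?}", gated.shape());
    Ok(())
}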

#[derive(Debug, Clone)]
@@ -40,7 +116,7 @@ pub struct Config {
pub layer_scale: Option<f64>,
pub positional_embedding: PositionalEmbedding,
pub use_conv_block: bool,
-pub cross_attention: Option<CrossAttention>,
+pub cross_attention: Option<CrossAttentionGating>,
pub conv_kernel_size: usize,
pub use_conv_bias: bool,
pub gating: Option<candle_nn::Activation>,
@@ -248,12 +324,12 @@ pub struct StreamingMultiheadCrossAttention {
kv_repeat: usize,
num_heads: usize,
neg_inf: Tensor,
-tanh_gate_alpha: Option<Tensor>,
+gate: XaGate,
span: tracing::Span,
}

impl StreamingMultiheadCrossAttention {
-pub fn new(ca: CrossAttention, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_kv = cfg.num_heads / cfg.kv_repeat;
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
@@ -276,10 +352,7 @@ impl StreamingMultiheadCrossAttention {
let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
-let tanh_gate_alpha = match ca {
-CrossAttention::Gated => Some(vb.get((1, 1, 1), "tanh_gate.alpha")?.tanh()?),
-CrossAttention::Normal => None,
-};
+let gate = XaGate::new(cfg, vb)?;
Ok(Self {
in_proj_q,
in_proj_k,
@@ -288,7 +361,7 @@ impl StreamingMultiheadCrossAttention {
kv_repeat: cfg.kv_repeat,
num_heads: cfg.num_heads,
neg_inf,
-tanh_gate_alpha,
+gate,
span: tracing::span!(tracing::Level::TRACE, "mhca"),
})
}
@@ -331,11 +404,8 @@ impl StreamingMultiheadCrossAttention {
let xs = xs
.transpose(1, 2)? // b,t,h,d
.reshape((b, t, hd))?
-.apply(&self.out_proj)?;
-let xs = match self.tanh_gate_alpha.as_ref() {
-None => xs,
-Some(alpha) => xs.broadcast_mul(alpha)?,
-};
+.apply(&self.out_proj)?
+.apply(&self.gate)?;
Ok(xs)
}
}
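
Because XaGate implements candle's Module trait, it composes with Tensor::apply exactly like out_proj above. A generic sketch of the same pattern, with an illustrative function name that is not part of the commit:

use candle::{Module, Result, Tensor};

fn project_then_gate<P: Module, G: Module>(xs: &Tensor, proj: &P, gate: &G) -> Result<Tensor> {
    // Chain two Module implementations the same way the forward pass above chains
    // the output projection and the cross-attention gate.
    xs.apply(proj)?.apply(gate)
}
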
@@ -517,14 +587,12 @@ impl StreamingTransformerLayer {
}
};
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
-let cross_attn = match cfg.cross_attention {
-Some(ca) => {
-let norm_cross = LayerNorm::new_no_bias(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
-let cross_attn =
-StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?;
-Some((norm_cross, cross_attn))
-}
-None => None,
+let cross_attn = if cfg.cross_attention.is_some() {
+let norm_cross = LayerNorm::new_no_bias(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
+let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
+Some((norm_cross, cross_attn))
+} else {
+None
};
Ok(Self {
self_attn,