Reduce memory overhead from copy2d #13

Merged: 9 commits, Oct 6, 2024
Empty file added: Dockerfile
105 changes: 49 additions & 56 deletions fish_speech_core/lib/models/text2semantic/mod.rs
@@ -2,9 +2,9 @@ pub mod utils;
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{
-    embedding, ops::silu, ops::softmax_last_dim, Embedding, Linear, Module, RmsNorm, VarBuilder,
+    embedding, kv_cache::KvCache, ops::silu, ops::softmax_last_dim, Embedding, Linear, Module,
+    RmsNorm, VarBuilder,
 };
-use candle_transformers::utils::repeat_kv;
 use serde::Deserialize;
 use serde_json;
 use std::fs::File;
@@ -116,18 +116,6 @@ impl Module for FeedForward {
     }
 }
 
-pub struct Cache {
-    /// TODO: Does this require Arc<Mutex>>?
-    kvs: Option<(Tensor, Tensor)>,
-}
-
-impl Cache {
-    pub fn new() -> Result<Self> {
-        // Precompute freqs_cis
-        Ok(Self { kvs: None })
-    }
-}
-
 /// Returns (cos, sin) for the full possible batch size
 fn precompute_freqs_cis(
     config: &BaseModelArgs,
@@ -137,7 +125,7 @@ fn precompute_freqs_cis(
     let n_elem = config.dim / config.n_head;
     let theta: Vec<_> = (0..n_elem)
         .step_by(2)
-        .map(|i| 1f32 / (config.rope_base as f32).powf(i as f32 / n_elem as f32))
+        .map(|i| 1f32 / config.rope_base.powf(i as f32 / n_elem as f32))
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
     let idx_theta = Tensor::arange(0, config.max_seq_len as u32, device)?
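Note: the removed `as f32` cast was redundant because `rope_base` is already an `f32`. For reference, this loop builds the standard RoPE inverse-frequency ladder. A minimal standalone sketch with hypothetical values (rope_base = 10000, n_elem = 8), not code from this repo:

```rust
// theta_i = 1 / rope_base^(i / n_elem) for even i, mirroring the loop above.
fn main() {
    let rope_base: f32 = 10_000.0;
    let n_elem: usize = 8;
    let theta: Vec<f32> = (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / rope_base.powf(i as f32 / n_elem as f32))
        .collect();
    // Prints [1.0, 0.1, 0.01, 0.001]: each channel pair rotates at a
    // geometrically decreasing frequency.
    println!("{:?}", theta);
}
```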
@@ -173,17 +161,17 @@ pub struct Attention {
     dim: usize,
     wqkv: Linear,
     wo: Linear,
-    cache: Cache,
+    cache: KvCache,
 }
 
 impl Attention {
-    pub fn load(vb: &VarBuilder, config: &BaseModelArgs) -> Result<Self> {
+    pub fn load(vb: &VarBuilder, config: &BaseModelArgs, is_fast: bool) -> Result<Self> {
         let total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim;
         // KQV for all heads, but in a batch
         let wqkv = Linear::new(vb.get((total_head_dim, config.dim), "wqkv.weight")?, None);
         let wo = Linear::new(vb.get((config.dim, config.dim), "wo.weight")?, None);
 
-        let cache = Cache::new()?;
+        let cache = KvCache::new(2, if is_fast { config.num_codebooks } else { 1024 });
 
         Ok(Self {
             n_head: config.n_head,
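The hand-rolled `Cache` re-concatenated the full key/value history on every step; `candle_nn::kv_cache::KvCache` instead keeps a persistent buffer (first argument: the concat dimension, second: a capacity in positions) and writes each new entry into place. Sizing the fast-layer caches to `num_codebooks` makes sense here since the fast transformer presumably only ever attends across the codebook positions of one frame. A minimal sketch of the behavior the diff relies on, with hypothetical shapes:

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Concatenate along dim 2, with room for 16 positions.
    let mut cache = KvCache::new(2, 16);
    for step in 0..3 {
        let k = Tensor::zeros((1, 4, 1, 8), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 4, 1, 8), DType::F32, &dev)?;
        // append returns the accumulated keys/values, like the old
        // Tensor::cat path did, but without reallocating every step.
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!(k_all.dims(), &[1, 4, step + 1, 8]);
        assert_eq!(v_all.dims(), &[1, 4, step + 1, 8]);
    }
    // Equivalent to clear_cache() in the diff.
    cache.reset();
    Ok(())
}
```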
@@ -204,8 +192,8 @@ impl Attention {
         cos: &Tensor,
         sin: &Tensor,
     ) -> Result<(Tensor, Tensor)> {
-        let q_embed = candle_nn::rotary_emb::rope_i(&q, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope_i(&k, &cos, &sin)?;
+        let q_embed = candle_nn::rotary_emb::rope_i(q, cos, sin)?;
+        let k_embed = candle_nn::rotary_emb::rope_i(k, cos, sin)?;
         Ok((q_embed, k_embed))
     }
@@ -255,38 +243,38 @@ impl Attention {
 
         let query_states = query_states
             .reshape((bsz, seqlen, self.n_head, self.head_dim))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let key_states = key_states
             .reshape((bsz, seqlen, self.n_local_heads, self.head_dim))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let value_states = value_states
             .reshape((bsz, seqlen, self.n_local_heads, self.head_dim))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
 
-        // Logic copied from phi3.rs
-        let _seqlen_offset = match &self.cache.kvs {
-            None => 0,
-            Some((prev_k, _)) => prev_k.dim(2)?,
-        };
-        let (query_states, key_states) =
-            self.apply_rotary_emb_qkv(&query_states, &key_states, freqs_cis.0, freqs_cis.1)?;
-        let (key_states, value_states) = match &self.cache.kvs {
-            None => (key_states, value_states),
-            Some((prev_k, prev_v)) => {
-                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
-                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
-                (key_states, value_states)
-            }
-        };
-        self.cache.kvs = Some((key_states.clone(), value_states.clone()));
+        let (query_states, key_states) = self.apply_rotary_emb_qkv(
+            &query_states.contiguous()?,
+            &key_states.contiguous()?,
+            freqs_cis.0,
+            freqs_cis.1,
+        )?;
 
-        // Repeat KV cache
-        let key_states = repeat_kv(key_states, self.n_head / self.n_local_heads)?.contiguous()?;
-        let value_states =
-            repeat_kv(value_states, self.n_head / self.n_local_heads)?.contiguous()?;
+        let (key_states, value_states) = self
+            .cache
+            .append(&key_states.contiguous()?, &value_states.contiguous()?)?;
+
+        // Length changes after pulling
+        let kv_seqlen = key_states.dim(2)?;
+        let n_rep = self.n_head / self.n_local_heads;
+        // TODO: Consider whether there's a better way to do this with 2 copy2ds instead of ucopy
+        // https://github.com/huggingface/candle/pull/2043 got the duplication wrong but there might still be something there
+        let key_states = key_states
+            .unsqueeze(2)?
+            .expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
+            .reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
+        let value_states = value_states
+            .unsqueeze(2)?
+            .expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
+            .reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
 
         // TODO: Add optional flash attention
         let y =
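This hunk is the heart of the PR: instead of `repeat_kv`, whose copies decomposed into the many strided copy2d calls the title refers to, `unsqueeze` + `expand` build a zero-copy broadcast view and the final `reshape` materializes it in a single contiguous copy (the "ucopy" in the TODO). An isolated sketch of the shape bookkeeping, with hypothetical sizes (2 KV heads repeated 4x to serve 8 query heads):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (bsz, n_kv_heads, n_rep, kv_seqlen, head_dim) = (1, 2, 4, 5, 8);
    let kv = Tensor::randn(0f32, 1f32, (bsz, n_kv_heads, kv_seqlen, head_dim), &dev)?;
    // expand() is a zero-copy broadcast view over the new size-1 axis;
    // reshape() then materializes it with one contiguous copy.
    let repeated = kv
        .unsqueeze(2)?
        .expand((bsz, n_kv_heads, n_rep, kv_seqlen, head_dim))?
        .reshape((bsz, n_kv_heads * n_rep, kv_seqlen, head_dim))?;
    // Heads 0..4 of the result repeat KV head 0; heads 4..8 repeat KV head 1.
    assert_eq!(repeated.dims(), &[1, 8, 5, 8]);
    Ok(())
}
```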
@@ -297,7 +285,7 @@ impl Attention {
     }
 
     pub fn clear_cache(&mut self) {
-        self.cache.kvs = None;
+        self.cache.reset();
     }
 }

@@ -309,8 +297,8 @@ pub struct TransformerBlock {
 }
 
 impl TransformerBlock {
-    pub fn load(vb: &VarBuilder, cfg: &BaseModelArgs) -> Result<Self> {
-        let attention = Attention::load(&vb.pp("attention"), cfg)?;
+    pub fn load(vb: &VarBuilder, cfg: &BaseModelArgs, is_fast: bool) -> Result<Self> {
+        let attention = Attention::load(&vb.pp("attention"), cfg, is_fast)?;
         let feed_forward = FeedForward::load(&vb.pp("feed_forward"), cfg)?;
         let ffn_norm = RmsNorm::new(vb.get(cfg.dim, "ffn_norm.weight")?, cfg.norm_eps);
         let attention_norm = RmsNorm::new(vb.get(cfg.dim, "attention_norm.weight")?, cfg.norm_eps);
@@ -331,10 +319,9 @@ impl TransformerBlock {
     ) -> Result<Tensor> {
         let residual = x;
         let x = self.attention_norm.forward(x)?;
-        let x = (self.attention.forward(&x, mask, freqs_cis)? + residual)?;
+        let x = (residual + self.attention.forward(&x, mask, freqs_cis)?)?;
         let residual = &x;
-        let x = residual + self.feed_forward.forward(&self.ffn_norm.forward(&x)?);
-        x
+        residual + self.feed_forward.forward(&self.ffn_norm.forward(&x)?)
     }
 }

@@ -364,7 +351,7 @@ impl DualARTransformer {
             cfg.dim,
         );
         let layers: Result<Vec<TransformerBlock>> = (0..cfg.n_layer)
-            .map(|l| TransformerBlock::load(&vb.pp(format!("layers.{}", l)), cfg))
+            .map(|l| TransformerBlock::load(&vb.pp(format!("layers.{}", l)), cfg, false))
             .collect();
         let layers = layers?;
         let norm = RmsNorm::new(vb.get(cfg.dim, "norm.weight")?, cfg.norm_eps);
@@ -374,7 +361,7 @@ impl DualARTransformer {
             cfg.dim,
         );
         let fast_layers: Result<Vec<TransformerBlock>> = (0..cfg.n_fast_layer)
-            .map(|l| TransformerBlock::load(&vb.pp(format!("fast_layers.{}", l)), cfg))
+            .map(|l| TransformerBlock::load(&vb.pp(format!("fast_layers.{}", l)), cfg, true))
             .collect();
         let fast_layers = fast_layers?;
         let fast_norm = RmsNorm::new(vb.get(cfg.dim, "fast_norm.weight")?, cfg.norm_eps);
@@ -417,7 +404,7 @@ impl DualARTransformer {
         // Offset the ranges for each codebook so they don't overlap
         let codebook_tokens_shifted = codebook_tokens.broadcast_add(
             &Tensor::arange_step(
-                0 as u32,
+                0,
                 (self.cfg.num_codebooks * self.cfg.codebook_size) as u32,
                 self.cfg.codebook_size as u32,
                 x.device(),
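A worked example of the offset trick, with deliberately small hypothetical sizes (codebook_size = 4, num_codebooks = 3): `arange_step` yields one offset per codebook, so identical token ids in different codebooks map to disjoint rows of one shared embedding table. The explicit reshape here is part of the sketch, since the hunk above is truncated before the broadcast shapes are visible.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One token id per codebook, all within 0..codebook_size.
    let codebook_tokens = Tensor::new(&[[1u32], [1u32], [1u32]], &dev)?; // (3, 1)
    // Offsets [0, 4, 8]: codebook i gets its own slice of the id space.
    let offsets = Tensor::arange_step(0u32, 12u32, 4u32, &dev)?;
    let shifted = codebook_tokens.broadcast_add(&offsets.reshape((3, 1))?)?;
    // Codebook 0 keeps id 1, codebook 1 becomes 5, codebook 2 becomes 9.
    assert_eq!(shifted.to_vec2::<u32>()?, vec![vec![1], vec![5], vec![9]]);
    Ok(())
}
```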
@@ -444,7 +431,7 @@ impl DualARTransformer {
         let mask = get_mask(seq_len, x.device())?;
 
         let (cos_full, sin_full) = &self.freqs_cis;
-        for (_, layer) in self.layers.iter_mut().enumerate() {
+        for layer in self.layers.iter_mut() {
             x = layer.forward(
                 &x,
                 &mask,
@@ -494,4 +481,10 @@ impl DualARTransformer {
             layer.attention.clear_cache();
         }
     }
+
+    pub fn clear_slow_layer_caches(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.attention.clear_cache();
+        }
+    }
 }
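A plausible call-site sketch, stated as an assumption: the call sites are outside this diff, and the fast-cache method name is guessed from the context lines above. The split suggests fast-layer caches are reset often (per frame) while slow-layer caches persist across an utterance and now get an explicit reset between generations.

```rust
// Hypothetical usage, not code from this PR: reset all per-utterance state
// before reusing the same model for an unrelated generation.
fn reset_between_utterances(model: &mut DualARTransformer) {
    model.clear_fast_layer_caches(); // assumed name of the existing sibling method
    model.clear_slow_layer_caches(); // added in this PR
}
```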
2 changes: 1 addition & 1 deletion fish_speech_core/lib/models/text2semantic/utils/encode.rs
@@ -23,7 +23,7 @@ pub fn encode_tokens(
     let zeros = Tensor::zeros((num_codebooks, new_tokens.len()), DType::U32, device)?;
     let prompt = Tensor::cat(&[tokens, zeros], 0)?;
 
-    if let None = prompt_tokens {
+    if prompt_tokens.is_none() {
         return Ok(prompt);
     }
     let prompt_tokens = prompt_tokens.unwrap().to_dtype(DType::U32)?;
73 changes: 73 additions & 0 deletions fish_speech_core/lib/models/text2semantic/utils/mod.rs
@@ -1 +1,74 @@
+use candle_core::{DType, Device, Result, Tensor};
+use std::collections::{HashMap, VecDeque};
+
+pub mod encode;
+
+pub struct RepPenProcessor {
+    penalty_mask: Tensor,
+    one: Tensor,
+    penalty_amt: Tensor,
+    context: VecDeque<usize>,
+    tokens_seen: HashMap<usize, usize>,
+    max_ctxt_size: usize,
+    vocab_size: usize,
+}
+
+impl RepPenProcessor {
+    pub fn new(
+        vocab_size: usize,
+        max_ctxt_size: usize,
+        penalty_amt: f32,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let penalty_mask = Tensor::ones(vocab_size, dtype, device)?;
+        // Yes, this is inelegant, but there's no scalar interface to slice_set
+        let one = Tensor::ones(1, dtype, device)?;
+        let penalty_amt = Tensor::from_vec(vec![penalty_amt], 1, device)?.to_dtype(dtype)?;
+        Ok(Self {
+            penalty_mask,
+            one,
+            penalty_amt,
+            context: VecDeque::new(),
+            tokens_seen: HashMap::new(),
+            max_ctxt_size,
+            vocab_size,
+        })
+    }
+
+    pub fn apply(&mut self, logits: &Tensor, last_token: usize) -> Result<Tensor> {
+        if last_token >= self.vocab_size {
+            candle_core::bail!("Token must be within vocab size");
+        }
+
+        // Track how many copies of the token sit in the window; penalize on first sight
+        let count = self
+            .tokens_seen
+            .entry(last_token)
+            .and_modify(|c| *c += 1)
+            .or_insert(1);
+        if *count == 1 {
+            // This is the first time we're penalizing the token in this window, so add to mask
+            self.penalty_mask
+                .slice_set(&self.penalty_amt, 0, last_token)?;
+        }
+        self.context.push_front(last_token);
+
+        if self.context.len() > self.max_ctxt_size {
+            // If the token falling out of the window is the last of its kind, un-penalize it
+            if let Some(dropped_token) = self.context.pop_back() {
+                if let Some(count) = self.tokens_seen.get_mut(&dropped_token) {
+                    *count -= 1;
+                    if *count == 0 {
+                        self.tokens_seen.remove(&dropped_token);
+                        self.penalty_mask.slice_set(&self.one, 0, dropped_token)?;
+                    }
+                }
+            }
+        }
+
+        logits.broadcast_div(&self.penalty_mask)
+    }
+
+    pub fn clear_cache(&mut self) -> Result<()> {
+        self.penalty_mask = self.penalty_mask.ones_like()?;
+        self.context.clear();
+        Ok(())
+    }
+}
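A hypothetical sketch of where `RepPenProcessor` would sit in a sampling loop (the two closures are stand-ins for this repo's actual forward and sampling code, and the sketch assumes it lives in the same module as the struct): divide the logits of recently seen tokens by 1.5 over a sliding window of 64 tokens, then reset between utterances.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn sample_with_rep_pen(
    mut next_logits: impl FnMut() -> Result<Tensor>, // stand-in for a model forward pass
    mut pick_token: impl FnMut(&Tensor) -> Result<usize>, // stand-in for argmax/sampling
    vocab_size: usize,
    steps: usize,
    device: &Device,
) -> Result<Vec<usize>> {
    let mut rep_pen = RepPenProcessor::new(vocab_size, 64, 1.5, DType::F32, device)?;
    let mut out = Vec::with_capacity(steps);
    let mut last_token = 0usize;
    for step in 0..steps {
        let logits = next_logits()?;
        // No history to penalize on the first step.
        let logits = if step == 0 {
            logits
        } else {
            rep_pen.apply(&logits, last_token)?
        };
        last_token = pick_token(&logits)?;
        out.push(last_token);
    }
    rep_pen.clear_cache()?; // reset the mask before the next utterance
    Ok(out)
}
```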