From 1283b17ae5b5ce0cc97777b99062b3e0d60820e0 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 22 Sep 2024 16:14:59 +0200 Subject: [PATCH] Increase max_seq_len for mimi-pyo3 and make it configurable. --- rust/mimi-pyo3/Cargo.toml | 2 +- rust/mimi-pyo3/src/lib.rs | 16 ++++++++-------- rust/moshi-backend/Cargo.toml | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rust/mimi-pyo3/Cargo.toml b/rust/mimi-pyo3/Cargo.toml index 64c9534..f7c772f 100644 --- a/rust/mimi-pyo3/Cargo.toml +++ b/rust/mimi-pyo3/Cargo.toml @@ -16,4 +16,4 @@ crate-type = ["cdylib"] anyhow = "1" numpy = "0.21.0" pyo3 = "0.21.0" -moshi = { path = "../moshi-core", version = "0.2.1" } +moshi = { path = "../moshi-core", version = "0.2.2" } diff --git a/rust/mimi-pyo3/src/lib.rs b/rust/mimi-pyo3/src/lib.rs index 5f033ea..ee696aa 100644 --- a/rust/mimi-pyo3/src/lib.rs +++ b/rust/mimi-pyo3/src/lib.rs @@ -39,7 +39,7 @@ macro_rules! py_bail { }; } -fn encodec_cfg() -> encodec::Config { +fn encodec_cfg(max_seq_len: Option) -> encodec::Config { let seanet_cfg = seanet::Config { dimension: 512, channels: 1, @@ -82,7 +82,7 @@ fn encodec_cfg() -> encodec::Config { kv_repeat: 1, conv_layout: true, // see builders.py cross_attention: false, - max_seq_len: 4096, + max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins. }; encodec::Config { channels: 1, @@ -107,9 +107,9 @@ struct Tokenizer { #[pymethods] impl Tokenizer { - #[pyo3(signature = (path, *, dtype="f32"))] + #[pyo3(signature = (path, *, dtype="f32", max_seq_len=None))] #[new] - fn new(path: std::path::PathBuf, dtype: &str) -> PyResult { + fn new(path: std::path::PathBuf, dtype: &str, max_seq_len: Option) -> PyResult { let device = candle::Device::Cpu; let dtype = match dtype { "f32" => candle::DType::F32, @@ -119,7 +119,7 @@ impl Tokenizer { }; let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? }; - let cfg = encodec_cfg(); + let cfg = encodec_cfg(max_seq_len); let encodec = encodec::Encodec::new(cfg, vb).w()?; Ok(Self { encodec, device, dtype }) } @@ -240,9 +240,9 @@ struct StreamTokenizer { #[pymethods] impl StreamTokenizer { - #[pyo3(signature = (path, *, dtype="f32"))] + #[pyo3(signature = (path, *, dtype="f32", max_seq_len=None))] #[new] - fn new(path: std::path::PathBuf, dtype: &str) -> PyResult { + fn new(path: std::path::PathBuf, dtype: &str, max_seq_len: Option) -> PyResult { let device = candle::Device::Cpu; let dtype = match dtype { "f32" => candle::DType::F32, @@ -252,7 +252,7 @@ impl StreamTokenizer { }; let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? }; - let cfg = encodec_cfg(); + let cfg = encodec_cfg(max_seq_len); let mut e_encodec = encodec::Encodec::new(cfg, vb).w()?; let mut d_encodec = e_encodec.clone(); let (encoder_tx, e_rx) = std::sync::mpsc::channel::>(); diff --git a/rust/moshi-backend/Cargo.toml b/rust/moshi-backend/Cargo.toml index b15ff42..5de5d0b 100644 --- a/rust/moshi-backend/Cargo.toml +++ b/rust/moshi-backend/Cargo.toml @@ -26,7 +26,7 @@ rcgen = "0.13.1" http = "1.1.0" lazy_static = "1.5.0" log = "0.4.20" -moshi = { path = "../moshi-core", version = "0.2.1" } +moshi = { path = "../moshi-core", version = "0.2.2" } ogg = { version = "0.9.1", features = ["async"] } opus = "0.3.0" rand = { version = "0.8.5", features = ["getrandom"] }