diff --git a/moshi/README.md b/moshi/README.md new file mode 100644 index 0000000..022ce8d --- /dev/null +++ b/moshi/README.md @@ -0,0 +1 @@ +# moshi - pytorch diff --git a/moshi/moshi/client.py b/moshi/moshi/client.py index 2872738..cba9048 100644 --- a/moshi/moshi/client.py +++ b/moshi/moshi/client.py @@ -1,16 +1,17 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Client for the Moshi server.""" import argparse import asyncio import queue import sys +import aiohttp import numpy as np import sphn import sounddevice as sd -import aiohttp from .client_utils import AnyPrinter, Printer, RawPrinter @@ -141,7 +142,27 @@ async def run(self) -> None: async def run(printer: AnyPrinter, args): - uri = f"ws://{args.host}:{args.port}/api/chat" + if args.url is None: + proto = "ws" + if args.https: + proto += "s" + uri = f"{proto}://{args.host}:{args.port}/api/chat" + else: + proto = "wss" + if '://' in args.url: + proto, without_proto = args.url.split('://', 1) + if proto in ['ws', 'http']: + proto = "ws" + elif proto in ['wss', 'https']: + proto = "wss" + else: + printer.log("error", "The provided URL {args.url} seems to contain a protocol but it is unknown.") + sys.exit(1) + else: + without_proto = args.url + uri = f"{proto}://{without_proto}/api/chat" + + printer.log("info", "Connecting to {uri}.") async with aiohttp.ClientSession() as session: async with session.ws_connect(uri) as ws: printer.log("info", "connected!") @@ -152,8 +173,11 @@ async def run(printer: AnyPrinter, args): def main(): parser = argparse.ArgumentParser("client_opus") - parser.add_argument("--host", default="localhost", type=str) - parser.add_argument("--port", default=8998, type=int) + parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.") + parser.add_argument("--port", default=8998, type=int, help="Port to connect to.") + parser.add_argument("--https", action='store_true', + help="Set this flag for using a https connection.") + parser.add_argument("--url", type=str, help='Provides directly a URL, e.g. to a gradio tunnel.') args = parser.parse_args() printer: AnyPrinter diff --git a/moshi/moshi/client_utils.py b/moshi/moshi/client_utils.py index c7aaa92..0bc37f5 100644 --- a/moshi/moshi/client_utils.py +++ b/moshi/moshi/client_utils.py @@ -1,6 +1,8 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Utilities for the command line client, in particular for handling interactions with the terminal. +""" from dataclasses import dataclass import sys @@ -14,11 +16,11 @@ def colorize(text, color): def make_log(level: str, msg: str) -> str: if level == "warning": - prefix = colorize("Warning:", "1;31") + prefix = colorize("[Warn]", "1;31") elif level == "info": - prefix = colorize("Info:", "1;34") + prefix = colorize("[Info]", "1;34") elif level == "error": - prefix = colorize("Error:", "1;31") + prefix = colorize("[Err ]", "1;31") else: raise ValueError(f"Unknown level {level}") return prefix + " " + msg diff --git a/moshi/moshi/models/__init__.py b/moshi/moshi/models/__init__.py index 1fcf526..5501848 100644 --- a/moshi/moshi/models/__init__.py +++ b/moshi/moshi/models/__init__.py @@ -1,20 +1,14 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. """ -Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. +Models for the compression model Moshi, """ # flake8: noqa -from .encodec import ( +from .compression import ( CompressionModel, - EncodecModel, + MimiModel, ) from .lm import LMModel, LMGen -from .moshi_ import get_encodec, get_lm +from .loaders import get_mimi, get_moshi_lm diff --git a/moshi/moshi/models/encodec.py b/moshi/moshi/models/compression.py similarity index 94% rename from moshi/moshi/models/encodec.py rename to moshi/moshi/models/compression.py index 02d1723..7a790d5 100644 --- a/moshi/moshi/models/encodec.py +++ b/moshi/moshi/models/compression.py @@ -2,13 +2,15 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft +# released under the following license. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -"""Compression models or wrapper around existing models. -Also defines the main interface that a model must follow to be usable as an audio tokenizer. +"""Compression models or wrapper around existing models. In particular, provides the implementation +for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer. """ from abc import abstractmethod @@ -19,7 +21,6 @@ import torch from torch import nn -from torch.nn import functional as F from ..quantization import ( @@ -46,12 +47,12 @@ def forward(self, x: torch.Tensor) -> QuantizedResult: ... @abstractmethod def encode(self, x: torch.Tensor) -> torch.Tensor: - """See `EncodecModel.encode`.""" + """See `MimiModel.encode`.""" ... @abstractmethod def decode(self, codes: torch.Tensor) -> torch.Tensor: - """See `EncodecModel.decode`.""" + """See `MimiModel.decode`.""" ... @abstractmethod @@ -90,7 +91,7 @@ def set_num_codebooks(self, n: int): @dataclass -class _EncodecState: +class _MimiState: graphed_tr_enc: CUDAGraphed | None graphed_tr_dec: CUDAGraphed | None @@ -98,8 +99,8 @@ def reset(self): pass -class EncodecModel(CompressionModel[_EncodecState]): - """Encodec model operating on the raw waveform. +class MimiModel(CompressionModel[_MimiState]): + """Mimi model operating on the raw waveform. Args: encoder (nn.Module): Encoder network. @@ -122,6 +123,7 @@ class EncodecModel(CompressionModel[_EncodecState]): torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder. Deactivated by default for training as this is incompatible at the moment with weight norm. See https://github.com/pytorch/pytorch/issues/121902 + Also this seems to work well with 2.2.0, but completely fail with 2.4.0. """ def __init__( @@ -217,14 +219,16 @@ def __init__( channel_wise=upsample_channel_wise_bug, ) - def _init_streaming_state(self, batch_size: int) -> _EncodecState: + def _init_streaming_state(self, batch_size: int) -> _MimiState: + device = next(self.parameters()).device + disable = device.type != 'cuda' graphed_tr_dec = None graphed_tr_enc = None if self.encoder_transformer is not None: - graphed_tr_enc = CUDAGraphed(self.encoder_transformer) + graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable) if self.decoder_transformer is not None: - graphed_tr_dec = CUDAGraphed(self.decoder_transformer) - return _EncodecState(graphed_tr_enc, graphed_tr_dec) + graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable) + return _MimiState(graphed_tr_enc, graphed_tr_dec) @property def channels(self) -> int: @@ -368,7 +372,8 @@ def encode(self, x: torch.Tensor) -> torch.Tensor: x (torch.Tensor): Float tensor of shape [B, C, T] Returns: - codes (torch.Tensor): an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. + codes (torch.Tensor): an int tensor of shape [B, K, T] + with K the number of codebooks used and T the timestep. """ emb = self._encode_to_unquantized_latent(x) codes = self.quantizer.encode(emb) diff --git a/moshi/moshi/models/lm.py b/moshi/moshi/models/lm.py index 58ad922..8cca181 100644 --- a/moshi/moshi/models/lm.py +++ b/moshi/moshi/models/lm.py @@ -372,8 +372,9 @@ def _init_streaming_state(self, batch_size: int) -> _LMGenState: dtype=torch.long, ) - graphed_main = CUDAGraphed(lm_model.forward_text) - graphed_depth = CUDAGraphed(self.depformer_step) + disable = lm_model.device.type != 'cuda' + graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable) + graphed_depth = CUDAGraphed(self.depformer_step, disable=disable) return _LMGenState(cache, initial, graphed_main, graphed_depth) diff --git a/moshi/moshi/models/moshi_.py b/moshi/moshi/models/loaders.py similarity index 61% rename from moshi/moshi/models/moshi_.py rename to moshi/moshi/models/loaders.py index d136200..1917694 100644 --- a/moshi/moshi/models/moshi_.py +++ b/moshi/moshi/models/loaders.py @@ -1,21 +1,29 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Retrieves the pretrained models for Moshi and Mimi.""" +from pathlib import Path -from ..modules import SEANetEncoder, SEANetDecoder, transformer -from .encodec import EncodecModel +from huggingface_hub import hf_hub_download +from safetensors.torch import load_model +import sentencepiece +import torch + +from .compression import MimiModel from .lm import LMModel +from ..modules import SEANetEncoder, SEANetDecoder, transformer from ..quantization import SplitResidualVectorQuantizer -import torch -from safetensors.torch import load_model -from pathlib import Path -import typing as tp SAMPLE_RATE = 24000 FRAME_RATE = 12.5 +HF_REPO = 'kmhf/msh-v0.1' +MIMI_V0_1 = 'tokenizer-e351c8d8-checkpoint125.safetensors' +MOSHIKO_V0_1 = 'moshiko_pt_301e30bf@120.safetensors' +MOSHIKA_V0_1 = 'moshika_pt_3d736a96@120.safetensors' +TEXT_TOKENIZER_V0_1 = 'tokenizer_spm_32k_3.model' -seanet_kwargs = { +_seanet_kwargs = { "channels": 1, "dimension": 512, "causal": True, @@ -35,15 +43,15 @@ "ratios": [8, 6, 5, 4], "true_skip": True, } -quantizer_kwargs = { +_quantizer_kwargs = { "dimension": 256, "n_q": 32, "bins": 2048, - "input_dimension": seanet_kwargs["dimension"], - "output_dimension": seanet_kwargs["dimension"], + "input_dimension": _seanet_kwargs["dimension"], + "output_dimension": _seanet_kwargs["dimension"], } -transformer_kwargs = { - "d_model": seanet_kwargs["dimension"], +_transformer_kwargs = { + "d_model": _seanet_kwargs["dimension"], "num_heads": 8, "num_layers": 8, "causal": True, @@ -55,17 +63,17 @@ "norm": "layer_norm", "positional_embedding": "rope", "dim_feedforward": 2048, - "input_dimension": seanet_kwargs["dimension"], - "output_dimensions": [seanet_kwargs["dimension"]], + "input_dimension": _seanet_kwargs["dimension"], + "output_dimensions": [_seanet_kwargs["dimension"]], } -lm_kwargs = { +_lm_kwargs = { "dim": 4096, "text_card": 32000, "existing_text_padding_id": 3, "n_q": 16, "dep_q": 8, - "card": quantizer_kwargs["bins"], + "card": _quantizer_kwargs["bins"], "num_heads": 32, "num_layers": 32, "hidden_scale": 4.125, @@ -92,24 +100,40 @@ } -def _is_safetensors(filename: tp.Union[str, Path]) -> bool: - filename = Path(filename) - return filename.suffix in (".safetensors", ".sft", ".sfts") +def _is_safetensors(path: Path | str) -> bool: + return Path(path).suffix in (".safetensors", ".sft", ".sfts") -def get_encodec(filename: tp.Union[str, Path], device): - encoder = SEANetEncoder(**seanet_kwargs) - decoder = SEANetDecoder(**seanet_kwargs) +def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = True) -> Path: + """Load a model checkpoint from HF. + If `allow_local_file` is True, then if a file `name` exists, it will be used instead. + """ + if allow_local_file and Path(name).exists(): + return Path(name) + else: + filename = name + return Path(hf_hub_download(hf_repo, filename)) + + +def get_text_tokenizer(filename: str | Path) -> sentencepiece.SentencePieceProcessor: + return sentencepiece.SentencePieceProcessor(str(filename)) # type: ignore + + +def get_mimi(filename: str | Path, + device: torch.device | str = 'cpu') -> MimiModel: + """Return a pretrained Mimi model.""" + encoder = SEANetEncoder(**_seanet_kwargs) + decoder = SEANetDecoder(**_seanet_kwargs) encoder_transformer = transformer.ProjectedTransformer( - device=device, **transformer_kwargs + device=device, **_transformer_kwargs ) decoder_transformer = transformer.ProjectedTransformer( - device=device, **transformer_kwargs + device=device, **_transformer_kwargs ) quantizer = SplitResidualVectorQuantizer( - **quantizer_kwargs, + **_quantizer_kwargs, ) - model = EncodecModel( + model = MimiModel( encoder, decoder, quantizer, @@ -126,21 +150,19 @@ def get_encodec(filename: tp.Union[str, Path], device): if _is_safetensors(filename): load_model(model, filename) else: - pkg = torch.load( - filename, - "cpu", - ) + pkg = torch.load(filename, "cpu") model.load_state_dict(pkg["model"]) model.set_num_codebooks(8) return model -def get_lm(filename: tp.Union[str, Path], device): +def get_moshi_lm(filename: str | Path, + device: torch.device | str = 'cpu') -> LMModel: dtype = torch.bfloat16 model = LMModel( device=device, dtype=dtype, - **lm_kwargs, + **_lm_kwargs, ).to(device=device, dtype=dtype) model.eval() if _is_safetensors(filename): diff --git a/moshi/moshi/modules/seanet.py b/moshi/moshi/modules/seanet.py index 0fe706b..1d8ff28 100644 --- a/moshi/moshi/modules/seanet.py +++ b/moshi/moshi/modules/seanet.py @@ -159,8 +159,7 @@ def __init__( self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks assert ( - self.disable_norm_outer_blocks >= 0 - and self.disable_norm_outer_blocks <= self.n_blocks + self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks ), ( "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." @@ -307,8 +306,7 @@ def __init__( self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks assert ( - self.disable_norm_outer_blocks >= 0 - and self.disable_norm_outer_blocks <= self.n_blocks + self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks ), ( "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py index d3d0ede..84d1952 100644 --- a/moshi/moshi/modules/transformer.py +++ b/moshi/moshi/modules/transformer.py @@ -9,6 +9,7 @@ See `StreamingTransformer` for more information. """ +from contextlib import ExitStack from dataclasses import dataclass import typing as tp @@ -17,6 +18,7 @@ import torch.nn as nn from torch.nn import functional as F +from ..utils.compile import no_compile from .gating import make_gating from .rope import RotaryEmbedding from .streaming import StreamingModule, StreamingContainer @@ -240,10 +242,7 @@ def reset(self): def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult: assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape) B, H, T, D = k.shape - indexes = ( - torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) - + self.end_offset - ) + indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset indexes = indexes % self.capacity self.cache[0].index_copy_(2, indexes, k) self.cache[1].index_copy_(2, indexes, v) @@ -485,8 +484,8 @@ def __init__( context=context, rope=rope, weights_per_step=weights_per_step, - **attn_kwargs, - **factory_kwargs, + **attn_kwargs, # type: ignore + **factory_kwargs, # type: ignore ) # type: ignore self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) @@ -542,8 +541,8 @@ def __init__( self.layer_scale_1 = nn.Identity() self.layer_scale_2 = nn.Identity() else: - self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) - self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) + self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore + self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore def _init_streaming_state(self, batch_size: int) -> _LayerState: return _LayerState(offset_cpu=0) @@ -582,12 +581,15 @@ def _sa_block(self, x: torch.Tensor): return x_orig + self.layer_scale_1(update) def forward(self, x: torch.Tensor): - x = self._sa_block(x) - x = self._ff_block(x) - state = self._streaming_state - if state: - state.offset_cpu += x.shape[1] - return x + with ExitStack() as stack: + if x.device.type != 'cuda': + stack.enter_context(no_compile()) + x = self._sa_block(x) + x = self._ff_block(x) + state = self._streaming_state + if state: + state.offset_cpu += x.shape[1] + return x @dataclass diff --git a/moshi/moshi/quantization/base.py b/moshi/moshi/quantization/base.py index e8f0ad4..02228a9 100644 --- a/moshi/moshi/quantization/base.py +++ b/moshi/moshi/quantization/base.py @@ -68,7 +68,7 @@ def num_codebooks(self) -> int: raise NotImplementedError() @property - def semantic_quantizer(self): + def semantic_quantizer(self) -> 'BaseQuantizer': """This returns the quantizer that models the first level of the hierarchy (typically semantic). In this case, it's the quantizer itself. @@ -76,7 +76,7 @@ def semantic_quantizer(self): return self @property - def acoustic_quantizer(self): + def acoustic_quantizer(self) -> 'BaseQuantizer': """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic). In this case, it's the quantizer itself. diff --git a/moshi/moshi/quantization/core_vq.py b/moshi/moshi/quantization/core_vq.py index 670b3a9..54abb5b 100644 --- a/moshi/moshi/quantization/core_vq.py +++ b/moshi/moshi/quantization/core_vq.py @@ -8,10 +8,9 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import math import typing as tp -from einops import rearrange, repeat +from einops import rearrange import torch from torch import nn from torch import distributed @@ -339,7 +338,7 @@ def forward( n_q = n_q or len(self.layers) previous_layer_is_initialized = True - for i, layer in enumerate(self.layers[:n_q]): + for i, layer in enumerate(self.layers[:n_q]): # type: ignore quantized, codes, loss, metrics = layer( residual, initialize=previous_layer_is_initialized ) @@ -366,7 +365,7 @@ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: residual = x all_indices = [] n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: + for layer in self.layers[:n_q]: # type: ignore indices = layer.encode(residual) quantized = layer.decode(indices) residual = residual - quantized diff --git a/moshi/moshi/quantization/vq.py b/moshi/moshi/quantization/vq.py index 0e436c1..4fa5b0a 100644 --- a/moshi/moshi/quantization/vq.py +++ b/moshi/moshi/quantization/vq.py @@ -321,12 +321,12 @@ def dimension(self): return self.rvq_first.dimension @property - def semantic_quantizer(self): + def semantic_quantizer(self) -> ResidualVectorQuantizer: """This returns the quantizer that models the first level of the hierarchy (typically semantic).""" return self.rvq_first @property - def acoustic_quantizer(self): + def acoustic_quantizer(self) -> ResidualVectorQuantizer: """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).""" return self.rvq_rest diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py index 70dfbff..a67a332 100644 --- a/moshi/moshi/server.py +++ b/moshi/moshi/server.py @@ -5,67 +5,29 @@ import argparse import asyncio from dataclasses import dataclass -from pathlib import Path import random +import os +from pathlib import Path import tarfile import time +import secrets +import sys -import os +import aiohttp +from aiohttp import web +from huggingface_hub import hf_hub_download import numpy as np import sentencepiece import sphn import torch -import aiohttp -from aiohttp import web - -from huggingface_hub import hf_hub_download - -from .models import moshi_, EncodecModel, LMGen -SAMPLE_RATE = moshi_.SAMPLE_RATE -DEVICE = "cuda:0" -ENABLE_PROFILING = False - -def colorize(text, color): - code = f"\033[{color}m" - restore = "\033[0m" - return "".join([code, text, restore]) +from .client_utils import make_log +from .models import loaders, MimiModel, LMModel, LMGen def log(level: str, msg: str): - if level == "warning": - prefix = colorize("[Warn]", "1;31") - elif level == "info": - prefix = colorize("[Info]", "1;34") - elif level == "error": - prefix = colorize("[Err ]", "1;31") - else: - raise ValueError(f"Unknown level {level}") - print(prefix + " " + msg) - - -parser = argparse.ArgumentParser() -parser.add_argument("--host", default="localhost", type=str) -parser.add_argument("--port", default=8998, type=int) -parser.add_argument("--static", type=str) -parser.add_argument("--tokenizer", type=str) -parser.add_argument("--moshi-weights", type=str) -parser.add_argument("--mimi-weights", type=str) -parser.add_argument("--hf-repo", type=str, default="kmhf/msh-v0.1") - -args = parser.parse_args() - -if args.tokenizer is None: - args.tokenizer = hf_hub_download(args.hf_repo, "tokenizer_spm_32k_3.model") -if args.moshi_weights is None: - args.moshi_weights = hf_hub_download( - args.hf_repo, "moshiko_pt_301e30bf@120.safetensors" - ) -if args.mimi_weights is None: - args.mimi_weights = hf_hub_download( - args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors" - ) + print(make_log(level, msg)) def seed_all(seed): @@ -79,43 +41,35 @@ def seed_all(seed): torch.backends.cudnn.benchmark = False -seed_all(42424242) - - @dataclass class ServerState: - ec: EncodecModel + mimi: MimiModel text_tokenizer: sentencepiece.SentencePieceProcessor lm_gen: LMGen lock: asyncio.Lock - def __init__(self): - log("info", "loading mimi") - self.ec = moshi_.get_encodec(args.mimi_weights, DEVICE) - log("info", "mimi loaded") - self.text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer) - log("info", "loading moshi") - lm = moshi_.get_lm(args.moshi_weights, DEVICE) + def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor, + lm: LMModel, device: str | torch.device): + self.mimi = mimi + self.text_tokenizer = text_tokenizer self.lm_gen = LMGen(lm) - self.frame_size = int(self.ec.sample_rate / self.ec.frame_rate) + self.device = device + self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate) self.lock = asyncio.Lock() - self.ec.streaming_forever(1) + self.mimi.streaming_forever(1) self.lm_gen.streaming_forever(1) - log("info", "lm loaded") def warmup(self): for chunk in range(4): - chunk = torch.zeros( - 1, 1, self.frame_size, dtype=torch.float32, device=DEVICE - ) - codes = self.ec.encode(chunk) + chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device) + codes = self.mimi.encode(chunk) for c in range(codes.shape[-1]): - tokens = self.lm_gen.step(codes[:, :, c : c + 1]) + tokens = self.lm_gen.step(codes[:, :, c: c + 1]) if tokens is None: continue - _ = self.ec.decode(tokens[:, 1:]) + _ = self.mimi.decode(tokens[:, 1:]) torch.cuda.synchronize() async def handle_chat(self, request): @@ -168,21 +122,21 @@ async def opus_loop(): while all_pcm_data.shape[-1] >= self.frame_size: be = time.time() chunk = all_pcm_data[: self.frame_size] - all_pcm_data = all_pcm_data[self.frame_size :] + all_pcm_data = all_pcm_data[self.frame_size:] chunk = torch.from_numpy(chunk) - chunk = chunk.to(device=DEVICE)[None, None] - codes = self.ec.encode(chunk) + chunk = chunk.to(device=self.device)[None, None] + codes = self.mimi.encode(chunk) for c in range(codes.shape[-1]): - tokens = self.lm_gen.step(codes[:, :, c : c + 1]) + tokens = self.lm_gen.step(codes[:, :, c: c + 1]) if tokens is None: continue assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1 - main_pcm = self.ec.decode(tokens[:, 1:]) + main_pcm = self.mimi.decode(tokens[:, 1:]) main_pcm = main_pcm.cpu() opus_writer.append_pcm(main_pcm[0, 0].numpy()) text_token = tokens[0, 0, 0].item() if text_token not in (0, 3): - _text = self.text_tokenizer.id_to_piece(text_token) + _text = self.text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") msg = b"\x02" + bytes(_text, encoding="utf8") log("info", f"text token '{_text}'") @@ -201,9 +155,9 @@ async def send_loop(): log("info", "accepted connection") close = False async with self.lock: - opus_writer = sphn.OpusStreamWriter(self.ec.sample_rate) - opus_reader = sphn.OpusStreamReader(self.ec.sample_rate) - self.ec.reset_streaming() + opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate) + opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate) + self.mimi.reset_streaming() self.lm_gen.reset_streaming() # Send the handshake. await ws.send_bytes(b"\x00") @@ -213,14 +167,63 @@ async def send_loop(): def main(): - state = ServerState() + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="localhost", type=str) + parser.add_argument("--port", default=8998, type=int) + parser.add_argument("--static", type=str) + parser.add_argument("--gradio_tunnel", action='store_true', help='Activate a gradio tunnel.') + parser.add_argument("--gradio_tunnel_token", + help='Provide a custom (secret) token here to keep getting the same URL.') + + parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1, + help="Name of the text tokenizer file in the given HF repo, or path to a local file.") + parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1, + help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.") + parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1, + help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.") + parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO, + help="HF repo to look into, defaults to Kyutai official one.") + parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.") + + args = parser.parse_args() + seed_all(42424242) + + setup_tunnel = None + tunnel_token = '' + if args.gradio_tunnel: + try: + from gradio import networking # type: ignore + except ImportError: + log("error", "Cannot find gradio which is required to activate a tunnel. " + "Please install with `pip install gradio`.") + sys.exit(1) + setup_tunnel = networking.setup_tunnel + if args.gradio_tunnel_token is None: + tunnel_token = secrets.token_urlsafe(32) + else: + tunnel_token = args.gradio_tunnel_token + + log("info", "loading mimi") + mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo) + mimi = loaders.get_mimi(mimi_path, args.device) + log("info", "mimi loaded") + + tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo) + text_tokenizer = loaders.get_text_tokenizer(tokenizer_path) + + log("info", "loading moshi") + moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo) + lm = loaders.get_moshi_lm(moshi_path, args.device) + log("info", "moshi loaded") + + state = ServerState(mimi, text_tokenizer, lm, args.device) log("info", "warming up the model") state.warmup() app = web.Application() app.router.add_get("/api/chat", state.handle_chat) static_path: None | str = None if args.static is None: - log("info", f"retrieving the static content") + log("info", "retrieving the static content") dist_tgz = hf_hub_download(args.hf_repo, "dist.tgz") dist_tgz = Path(dist_tgz) dist = dist_tgz.parent / "dist" @@ -232,7 +235,6 @@ def main(): # When set to the "none" string, we don't serve any static content. static_path = args.static if static_path is not None: - async def handle_root(_): return web.FileResponse(os.path.join(static_path, "index.html")) @@ -241,7 +243,11 @@ async def handle_root(_): app.router.add_static( "/", path=static_path, follow_symlinks=True, name="static" ) - log("info", f"listening to ws://{args.host}:{args.port}") + log("info", f"Access the Web UI directly at http://{args.host}:{args.port}") + if setup_tunnel is not None: + tunnel = setup_tunnel('localhost', args.port, tunnel_token, None) + log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.") + log("info", f"Note that this tunnel goes through the US and you might experience high latency in Europe.") web.run_app(app, port=args.port) diff --git a/moshi/moshi/utils/compile.py b/moshi/moshi/utils/compile.py index b47e987..780513b 100644 --- a/moshi/moshi/utils/compile.py +++ b/moshi/moshi/utils/compile.py @@ -23,7 +23,7 @@ @contextmanager def no_compile(): - """Disable torch.compile locally.""" + """Disable torch.compile locally. Now Pytorch 2.4 provides a function to do that.""" global _compile_disabled prev_disabled = _compile_disabled @@ -194,11 +194,14 @@ class CUDAGraphed: be top level args, not nested in structures (tuples, dicts, etc). Keyword arguments are NOT supported for simplicity. warmup_steps: how many call to make normally before CUDA Graphing. In particular, this - allows torch.compiled functions to get properly compiled.""" + allows torch.compiled functions to get properly compiled. + disabled: if True, just call the func directly, useful to quickly deactivate on CPU. + """ - def __init__(self, func: tp.Callable, warmup_steps: int = 1): + def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False): self.func = func self.warmup_steps = warmup_steps + self.disable = disable self._graph: cuda.CUDAGraph | None = None self._output: tuple | None = None self._args: tuple | None = None @@ -214,7 +217,7 @@ def reset(self, warmup_steps: int = 0) -> None: def __call__(self, *args, **kwargs) -> tp.Any: if kwargs: raise RuntimeError("Named arguments not supported for now.") - if not _is_cuda_graph_enabled() or in_cuda_graph(): + if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph(): return self.func(*args, **kwargs) def _clone_tensors(args: tuple) -> tuple: diff --git a/moshi/pyproject.toml b/moshi/pyproject.toml index 61d4c54..d5b7578 100644 --- a/moshi/pyproject.toml +++ b/moshi/pyproject.toml @@ -1,10 +1,9 @@ [project] name = "moshi" -version = "0.0.1" requires-python = ">= 3.10" description = "Moshi is moshi" dependencies = [ - "numpy >= 2.1.0, < 2.2", + "numpy >= 1.26, < 2.2", "safetensors >= 0.4.0, < 0.5", "huggingface-hub >= 0.24, < 0.25", "einops == 0.7", @@ -17,11 +16,18 @@ dependencies = [ authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] license = {text = "MIT"} +dynamic = ["version"] +[tool.setuptools.dynamic] +version = {attr = "moshi.__version__"} [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" -[tool.setuptools] -packages = ["moshi", "moshi.utils", "moshi.modules", "moshi.models", "moshi.quantization"] +[project.optional-dependencies] +dev = [ + "pyright", + "flake8", + "pre-commit", +] diff --git a/moshi/requirements.txt b/moshi/requirements.txt index 9a93905..876de9d 100644 --- a/moshi/requirements.txt +++ b/moshi/requirements.txt @@ -5,5 +5,6 @@ sounddevice==0.5.0 soundfile==0.12.1 sphn==0.1.4 torch==2.2.0 +numpy==1.26.4 aiohttp>=3.10.5, <3.11 huggingface-hub==0.24.6 diff --git a/moshi/setup.cfg b/moshi/setup.cfg index dc7aa4b..5bccac4 100644 --- a/moshi/setup.cfg +++ b/moshi/setup.cfg @@ -3,3 +3,4 @@ max-line-length = 120 [flake8] max-line-length = 120 +ignore = E203,E704 diff --git a/scripts/mimi_streaming_test.py b/scripts/mimi_streaming_test.py index d71f07c..54865a3 100644 --- a/scripts/mimi_streaming_test.py +++ b/scripts/mimi_streaming_test.py @@ -3,20 +3,23 @@ # LICENSE file in the root directory of this source tree. import argparse -import moshi +import random import time -import torch + +import numpy as np import sphn +import torch from torch.profiler import profile, ProfilerActivity -import numpy as np -import random -SAMPLE_RATE = moshi.models.moshi.SAMPLE_RATE -DEVICE = "cuda:0" -ENABLE_PROFILING = False +from moshi.models import loaders + parser = argparse.ArgumentParser() -parser.add_argument("--weights", type=str) +parser.add_argument("--weights", type=str, default=loaders.MIMI_V0_1) +parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO) +parser.add_argument("--device", type=str, + default='cuda' if torch.cuda.device_count() else 'cpu') +parser.add_argument("--profile", action='store_true') args = parser.parse_args() @@ -35,23 +38,27 @@ def seed_all(seed): print("loading mimi") -ec = moshi.models.moshi.get_encodec(args.weights, DEVICE) +mimi = loaders.get_mimi( + loaders.resolve_model_checkpoint(args.weights, args.hf_repo), + args.device) print("mimi loaded") -def encodec_streaming_test(ec, pcm_chunk_size=1920, max_duration_sec=10.0): +def mimi_streaming_test(mimi, max_duration_sec=10.0): + pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3 sample_pcm, sample_sr = sphn.read("bria.mp3") + sample_rate = mimi.sample_rate print("loaded pcm", sample_pcm.shape, sample_sr) sample_pcm = sphn.resample( - sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=SAMPLE_RATE + sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate ) - sample_pcm = torch.tensor(sample_pcm, device=DEVICE) - max_duration_len = int(SAMPLE_RATE * max_duration_sec) + sample_pcm = torch.tensor(sample_pcm, device=args.device) + max_duration_len = int(sample_rate * max_duration_sec) if sample_pcm.shape[-1] > max_duration_len: sample_pcm = sample_pcm[..., :max_duration_len] print("resampled pcm", sample_pcm.shape, sample_sr) - sample_pcm = sample_pcm[None].to(device=DEVICE) + sample_pcm = sample_pcm[None].to(device=args.device) print("streaming encoding...") start_time = time.time() @@ -61,34 +68,34 @@ def run_loop(): for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size): end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size) chunk = sample_pcm[..., start_idx:end_idx] - codes, _scale = ec.encode(chunk) + codes = mimi.encode(chunk) if codes.shape[-1]: print(start_idx, codes.shape, end="\r") all_codes.append(codes) - if ENABLE_PROFILING: + if args.profile: with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: run_loop() prof.export_chrome_trace("trace.json") else: run_loop() - all_codes = torch.cat(all_codes, dim=-1) - print(f"codes {all_codes.shape} generated in {time.time() - start_time:.2f}s") + all_codes_th = torch.cat(all_codes, dim=-1) + print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s") print("streaming decoding...") all_pcms = [] - with ec.streaming(): - for i in range(all_codes.shape[-1]): - codes = all_codes[..., i : i + 1] - pcm = ec.decode(codes, scale=None) + with mimi.streaming(1): + for i in range(all_codes_th.shape[-1]): + codes = all_codes_th[..., i : i + 1] + pcm = mimi.decode(codes) print(i, pcm.shape, end="\r") all_pcms.append(pcm) all_pcms = torch.cat(all_pcms, dim=-1) print("pcm", all_pcms.shape, all_pcms.dtype) - sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), SAMPLE_RATE) - pcm = ec.decode(all_codes, scale=None) + sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate) + pcm = mimi.decode(all_codes_th) print("pcm", pcm.shape, pcm.dtype) - sphn.write_wav.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), SAMPLE_RATE) + sphn.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), sample_rate) with torch.no_grad(): - encodec_streaming_test(ec) + mimi_streaming_test(mimi) diff --git a/scripts/moshi_benchmark.py b/scripts/moshi_benchmark.py index 056542f..0bd015b 100644 --- a/scripts/moshi_benchmark.py +++ b/scripts/moshi_benchmark.py @@ -3,26 +3,30 @@ # LICENSE file in the root directory of this source tree. import argparse -import moshi -import sentencepiece -import torch -import sphn -import numpy as np import random import time +import numpy as np +import sentencepiece +import sphn +import torch from torch.profiler import profile, ProfilerActivity -SAMPLE_RATE = moshi.models.moshi.SAMPLE_RATE -DEVICE = "cuda:0" -ENABLE_PROFILING = False +from moshi.models import loaders, LMGen + parser = argparse.ArgumentParser() -parser.add_argument("--tokenizer", type=str) -parser.add_argument("--moshi-weights", type=str) -parser.add_argument("--mimi-weights", type=str) +parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1, + help="Name of the text tokenizer file in the given HF repo, or path to a local file.") +parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1, + help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.") +parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1, + help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.") +parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO, + help="HF repo to look into, defaults to Kyutai official one.") parser.add_argument("--steps", default=100, type=int) parser.add_argument("--profile", action="store_true") +parser.add_argument("--device", type=str, default='cuda') args = parser.parse_args() @@ -39,52 +43,53 @@ def seed_all(seed): seed_all(42424242) +tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo) +text_tokenizer = loaders.get_text_tokenizer(tokenizer_path) print("loading mimi") -ec = moshi.models.moshi.get_encodec(args.mimi_weights, DEVICE) +mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo) +mimi = loaders.get_mimi(mimi_path, args.device) print("mimi loaded") -text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer) print("loading moshi") -lm = moshi.models.moshi.get_lm(args.moshi_weights, DEVICE) -lm.to(torch.bfloat16) +moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo) +lm = loaders.get_moshi_lm(moshi_path, args.device) +lm_gen = LMGen(lm) print("lm loaded") -lm_gen = moshi.models.LMGen(lm) - def cb(step, total): print(f"{step:06d} / {total:06d}", end="\r") def streaming_test(bs): - main_audio = [] main_text = [] + frame_size = int(mimi.sample_rate / mimi.frame_rate) + def run_step(): start_time = time.time() # Chunk should contain the pcm data from the user, single channel with a sample rate of 24000. - chunk = torch.zeros((bs, 1, 1920), dtype=torch.float, device=DEVICE) - codes = ec.encode(chunk) + chunk = torch.zeros((bs, 1, frame_size), dtype=torch.float, device=args.device) + codes = mimi.encode(chunk) assert codes.shape[-1] == 1 - for c in range(codes.shape[-1]): - be = time.time() - ev = torch.cuda.Event(enable_timing=True) - ev.record() - tokens = lm_gen.step(codes[:, :, c : c + 1]) - if tokens is None: - print("Skipping") - return - evb = torch.cuda.Event(enable_timing=True) - evb.record() - dt_step = time.time() - be - text_tokens = tokens[:, 0, 0] - audio_tokens = tokens[:, 1:, :] - main_pcm = ec.decode(audio_tokens) - # main_pcm is the audio to be played back to the user, here we just append it and store it in - # a file once the loop is finished. - main_audio.append(main_pcm[0]) + be = time.time() + ev = torch.cuda.Event(enable_timing=True) + ev.record() + tokens = lm_gen.step(codes[:, :, :1]) + if tokens is None: + print("Skipping") + return + evb = torch.cuda.Event(enable_timing=True) + evb.record() + dt_step = time.time() - be + text_tokens = tokens[:, 0, 0] + audio_tokens = tokens[:, 1:, :] + main_pcm = mimi.decode(audio_tokens) + # main_pcm is the audio to be played back to the user, here we just append it and store it in + # a file once the loop is finished. + main_audio.append(main_pcm[0]) evb.synchronize() dg = ev.elapsed_time(evb) torch.cuda.synchronize() @@ -109,17 +114,17 @@ def run_step(): run_step() print() prof.export_chrome_trace("trace.json") - main_audio = torch.cat(main_audio, dim=-1) - print(main_audio.shape) + main_audio_th = torch.cat(main_audio, dim=-1) + print(main_audio_th.shape) print("generated text:") print("".join(main_text)) sphn.write_wav( - "gen_main.wav", main_audio[0].cpu().numpy().astype(np.float32), SAMPLE_RATE + "gen_main.wav", main_audio_th[0].cpu().numpy().astype(np.float32), mimi.sample_rate ) print("streaming test") bs = 1 with torch.no_grad(): - with ec.streaming(bs), lm_gen.streaming(bs): + with mimi.streaming(bs), lm_gen.streaming(bs): streaming_test(bs) diff --git a/scripts/setup.cfg b/scripts/setup.cfg index dc7aa4b..5bccac4 100755 --- a/scripts/setup.cfg +++ b/scripts/setup.cfg @@ -3,3 +3,4 @@ max-line-length = 120 [flake8] max-line-length = 120 +ignore = E203,E704