Merge branch 'refacto' into ci
adefossez committed Sep 18, 2024
2 parents 95c5bac + 73d7249 commit 85bcac8
Showing 7 changed files with 60 additions and 43 deletions.
6 changes: 4 additions & 2 deletions moshi/moshi/models/compression.py
@@ -220,12 +220,14 @@ def __init__(
)

def _init_streaming_state(self, batch_size: int) -> _MimiState:
device = next(self.parameters()).device
disable = device.type != 'cuda'
graphed_tr_dec = None
graphed_tr_enc = None
if self.encoder_transformer is not None:
graphed_tr_enc = CUDAGraphed(self.encoder_transformer)
graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
if self.decoder_transformer is not None:
graphed_tr_dec = CUDAGraphed(self.decoder_transformer)
graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
return _MimiState(graphed_tr_enc, graphed_tr_dec)

@property
5 changes: 3 additions & 2 deletions moshi/moshi/models/lm.py
@@ -372,8 +372,9 @@ def _init_streaming_state(self, batch_size: int) -> _LMGenState:
dtype=torch.long,
)

graphed_main = CUDAGraphed(lm_model.forward_text)
graphed_depth = CUDAGraphed(self.depformer_step)
disable = lm_model.device.type != 'cuda'
graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)

return _LMGenState(cache, initial, graphed_main, graphed_depth)

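Both hunks above (compression.py and lm.py) follow the same pattern: look up the module's device and pass disable=True to CUDAGraphed whenever that device is not CUDA, so the wrapped callable degrades to a plain function call on CPU. Below is a minimal, self-contained sketch of that pattern; TinyModel and its step method are made up for illustration and are not part of the commit.

import torch
from torch import nn

from moshi.utils.compile import CUDAGraphed


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def step(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyModel()
# Same device check as in the diffs above: CUDA graphs are only meaningful on
# a CUDA device, so the wrapper is disabled everywhere else.
device = next(model.parameters()).device
disable = device.type != 'cuda'
graphed_step = CUDAGraphed(model.step, disable=disable)
out = graphed_step(torch.randn(2, 8, device=device))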
17 changes: 11 additions & 6 deletions moshi/moshi/modules/transformer.py
@@ -9,6 +9,7 @@
See `StreamingTransformer` for more information.
"""

from contextlib import ExitStack
from dataclasses import dataclass
import typing as tp

@@ -17,6 +18,7 @@
import torch.nn as nn
from torch.nn import functional as F

from ..utils.compile import no_compile
from .gating import make_gating
from .rope import RotaryEmbedding
from .streaming import StreamingModule, StreamingContainer
@@ -579,12 +581,15 @@ def _sa_block(self, x: torch.Tensor):
return x_orig + self.layer_scale_1(update)

def forward(self, x: torch.Tensor):
x = self._sa_block(x)
x = self._ff_block(x)
state = self._streaming_state
if state:
state.offset_cpu += x.shape[1]
return x
with ExitStack() as stack:
if x.device.type != 'cuda':
stack.enter_context(no_compile())
x = self._sa_block(x)
x = self._ff_block(x)
state = self._streaming_state
if state:
state.offset_cpu += x.shape[1]
return x


@dataclass
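The transformer.py change wraps the layer's forward in an ExitStack so that moshi's no_compile() context is entered only when the input is not on a CUDA device. The conditional-context idiom is easy to miss on a first read, so here is a hedged, stand-alone sketch of it; the doubling body is just a placeholder for the real attention and feed-forward blocks.

from contextlib import ExitStack

import torch

from moshi.utils.compile import no_compile


def forward_step(x: torch.Tensor) -> torch.Tensor:
    with ExitStack() as stack:
        if x.device.type != 'cuda':
            # On CPU (or any non-CUDA device), disable torch.compile for this
            # region; on CUDA the stack stays empty and nothing changes.
            stack.enter_context(no_compile())
        return x * 2.0  # placeholder for the _sa_block / _ff_block calls


print(forward_step(torch.ones(3)))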
11 changes: 7 additions & 4 deletions moshi/moshi/utils/compile.py
@@ -23,7 +23,7 @@

@contextmanager
def no_compile():
"""Disable torch.compile locally."""
"""Disable torch.compile locally. Now Pytorch 2.4 provides a function to do that."""
global _compile_disabled

prev_disabled = _compile_disabled
@@ -194,11 +194,14 @@ class CUDAGraphed:
be top level args, not nested in structures (tuples, dicts, etc). Keyword
arguments are NOT supported for simplicity.
warmup_steps: how many call to make normally before CUDA Graphing. In particular, this
allows torch.compiled functions to get properly compiled."""
allows torch.compiled functions to get properly compiled.
disabled: if True, just call the func directly, useful to quickly deactivate on CPU.
"""

def __init__(self, func: tp.Callable, warmup_steps: int = 1):
def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
self.func = func
self.warmup_steps = warmup_steps
self.disable = disable
self._graph: cuda.CUDAGraph | None = None
self._output: tuple | None = None
self._args: tuple | None = None
@@ -214,7 +217,7 @@ def reset(self, warmup_steps: int = 0) -> None:
def __call__(self, *args, **kwargs) -> tp.Any:
if kwargs:
raise RuntimeError("Named arguments not supported for now.")
if not _is_cuda_graph_enabled() or in_cuda_graph():
if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
return self.func(*args, **kwargs)

def _clone_tensors(args: tuple) -> tuple:
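Putting the compile.py change together: with disable=True the wrapper skips graph capture entirely and just calls the function, while on CUDA it still warms up for warmup_steps calls before capturing. A small usage sketch follows; decode_step and the tensor shapes are invented for illustration and are not part of the commit.

import torch

from moshi.utils.compile import CUDAGraphed


def decode_step(tokens: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-step function; in Moshi this role is played by
    # lm_model.forward_text or the depformer step.
    return tokens + 1


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
graphed = CUDAGraphed(decode_step, warmup_steps=2, disable=not use_cuda)

tokens = torch.zeros(1, 4, dtype=torch.long, device=device)
for _ in range(5):
    # With disable=True this is just decode_step(tokens); with CUDA available,
    # the first calls warm up and later calls replay the captured graph.
    tokens = graphed(tokens)
print(tokens)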
5 changes: 2 additions & 3 deletions moshi/pyproject.toml
@@ -18,14 +18,13 @@ maintainers = [{name="Laurent Mazaré", email="[email protected]"}]
license = {text = "MIT"}
dynamic = ["version"]

[tool.setuptools.dynamic]
version = {attr = "moshi.__version__"}

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["moshi", "moshi.utils", "moshi.modules", "moshi.models", "moshi.quantization"]

[tool.setuptools.dynamic]
version = {attr = "moshi.__version__"}

58 changes: 32 additions & 26 deletions scripts/mimi_streaming_test.py
@@ -3,20 +3,23 @@
# LICENSE file in the root directory of this source tree.

import argparse
import moshi
import random
import time
import torch

import numpy as np
import sphn
import torch
from torch.profiler import profile, ProfilerActivity
import numpy as np
import random

SAMPLE_RATE = moshi.models.moshi.SAMPLE_RATE
DEVICE = "cuda:0"
ENABLE_PROFILING = False
from moshi.models import loaders


parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str)
parser.add_argument("--weights", type=str, default=loaders.MIMI_V0_1)
parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO)
parser.add_argument("--device", type=str,
default='cuda' if torch.cuda.device_count() else 'cpu')
parser.add_argument("--profile", action='store_true')
args = parser.parse_args()


@@ -35,23 +38,26 @@ def seed_all(seed):


print("loading mimi")
ec = moshi.models.moshi.get_encodec(args.weights, DEVICE)
mimi = loaders.get_mimi(
loaders.resolve_model_checkpoint(args.weights, args.hf_repo),
args.device)
print("mimi loaded")


def encodec_streaming_test(ec, pcm_chunk_size=1920, max_duration_sec=10.0):
def mimi_streaming_test(mimi, pcm_chunk_size=1920, max_duration_sec=10.0):
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
sample_pcm, sample_sr = sphn.read("bria.mp3")
sample_rate = mimi.sample_rate
print("loaded pcm", sample_pcm.shape, sample_sr)
sample_pcm = sphn.resample(
sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=SAMPLE_RATE
sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
)
sample_pcm = torch.tensor(sample_pcm, device=DEVICE)
max_duration_len = int(SAMPLE_RATE * max_duration_sec)
sample_pcm = torch.tensor(sample_pcm, device=args.device)
max_duration_len = int(sample_rate * max_duration_sec)
if sample_pcm.shape[-1] > max_duration_len:
sample_pcm = sample_pcm[..., :max_duration_len]
print("resampled pcm", sample_pcm.shape, sample_sr)
sample_pcm = sample_pcm[None].to(device=DEVICE)
sample_pcm = sample_pcm[None].to(device=args.device)

print("streaming encoding...")
start_time = time.time()
@@ -61,34 +67,34 @@ def run_loop():
for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
chunk = sample_pcm[..., start_idx:end_idx]
codes, _scale = ec.encode(chunk)
codes = mimi.encode(chunk)
if codes.shape[-1]:
print(start_idx, codes.shape, end="\r")
all_codes.append(codes)

if ENABLE_PROFILING:
if args.profile:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
run_loop()
prof.export_chrome_trace("trace.json")
else:
run_loop()
all_codes = torch.cat(all_codes, dim=-1)
print(f"codes {all_codes.shape} generated in {time.time() - start_time:.2f}s")
all_codes_th = torch.cat(all_codes, dim=-1)
print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
print("streaming decoding...")
all_pcms = []
with ec.streaming():
for i in range(all_codes.shape[-1]):
codes = all_codes[..., i : i + 1]
pcm = ec.decode(codes, scale=None)
with mimi.streaming(1):
for i in range(all_codes_th.shape[-1]):
codes = all_codes_th[..., i : i + 1]
pcm = mimi.decode(codes)
print(i, pcm.shape, end="\r")
all_pcms.append(pcm)
all_pcms = torch.cat(all_pcms, dim=-1)
print("pcm", all_pcms.shape, all_pcms.dtype)
sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), SAMPLE_RATE)
pcm = ec.decode(all_codes, scale=None)
sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate)
pcm = mimi.decode(all_codes_th)
print("pcm", pcm.shape, pcm.dtype)
sphn.write_wav.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), SAMPLE_RATE)
sphn.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), sample_rate)


with torch.no_grad():
encodec_streaming_test(ec)
mimi_streaming_test(mimi)
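For orientation, the reworked script now resolves and loads Mimi through moshi.models.loaders instead of the old moshi.models.moshi helpers. A hedged sketch of that loading path plus a single round trip on one chunk of silence (the 1920-sample chunk size is taken from the script's pcm_chunk_size default; network access to the Hugging Face repo is assumed, and nothing here is part of the commit itself):

import torch

from moshi.models import loaders

device = 'cuda' if torch.cuda.device_count() else 'cpu'
# Resolve the default Mimi checkpoint from the Hugging Face repo, then load it.
checkpoint = loaders.resolve_model_checkpoint(loaders.MIMI_V0_1, loaders.HF_REPO)
mimi = loaders.get_mimi(checkpoint, device)

with torch.no_grad(), mimi.streaming(1):
    chunk = torch.zeros(1, 1, 1920, device=device)  # one pcm_chunk_size of silence
    codes = mimi.encode(chunk)
    pcm = mimi.decode(codes)
    print(codes.shape, pcm.shape)

Something like `python scripts/mimi_streaming_test.py --device cpu --profile` should exercise the same path end to end; that command line is inferred from the argparse flags above, not quoted from the repository.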
1 change: 1 addition & 0 deletions scripts/setup.cfg
@@ -3,3 +3,4 @@ max-line-length = 120

[flake8]
max-line-length = 120
ignore = E203,E704
