Skip to content

Commit

Permalink
Add router for bgm separation
Browse files Browse the repository at this point in the history
  • Loading branch information
jhj0517 committed Nov 1, 2024
1 parent 8f3a502 commit 6bfe32b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
61 changes: 61 additions & 0 deletions backend/bgm_separation/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import functools
import numpy as np
from fastapi import (
File,
UploadFile,
)
import gradio as gr
from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
from typing import List, Dict, Tuple

from modules.whisper.data_classes import *
from modules.uvr.music_separator import MusicSeparator
from ..util.audio import read_audio
from ..util.schemas import QueueResponse
from ..util.config_loader import load_server_config

@functools.lru_cache
def init_bgm_separation_inferencer() -> 'MusicSeparator':
config = load_server_config()["bgm_separation"]
inferencer = MusicSeparator()
inferencer.update_model(
model_name=config["model_size"],
device=config["compute_type"]
)
return inferencer

bgm_separation_router = APIRouter()
bgm_separation_inferencer = init_bgm_separation_inferencer()


async def run_bgm_separation(
audio: np.ndarray,
params: BGMSeparationParams
) -> Tuple[np.ndarray, np.ndarray]:
instrumental, vocal, filepaths = bgm_separation_inferencer.separate(
audio=audio,
model_name=params.model_size,
device=params.device,
segment_size=params.segment_size,
save_file=False,
progress=gr.Progress()
)
return instrumental, vocal


@bgm_separation_router.post("/bgm", tags=["bgm-separation"])
async def transcription(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="Audio or video file to separate background music."),
params: TranscriptionPipelineParams = Depends()
) -> QueueResponse:
if not isinstance(file, np.ndarray):
audio = await read_audio(file=file)
else:
audio = file

background_tasks.add_task(run_bgm_separation, audio=audio, params=params)

return QueueResponse(message="Transcription task queued")


2 changes: 1 addition & 1 deletion backend/vad/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def run_vad(
@vad_router.post("/vad", tags=["vad"])
async def vad(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="Audio or video file to transcribe."),
file: UploadFile = File(..., description="Audio or video file to detect voices."),
params: VadParams = Depends()
) -> QueueResponse:
if not isinstance(file, np.ndarray):
Expand Down

0 comments on commit 6bfe32b

Please sign in to comment.