diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index c1f7cb2..2773166 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -32,9 +32,7 @@ def __init__(self, self.model_dir = model_dir os.makedirs(self.model_dir, exist_ok=True) - openai_models = whisper.available_models() - distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] - self.available_models = openai_models + distil_models + self.available_models = self.get_model_paths() def transcribe(self, audio: Union[str, np.ndarray, torch.Tensor], @@ -146,31 +144,26 @@ def update_model(self, model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}, ) - @staticmethod - def format_result( - transcribed_result: dict - ) -> List[dict]: + def get_model_paths(self): """ - Format the transcription result of insanely_fast_whisper as the same with other implementation. - - Parameters - ---------- - transcribed_result: dict - Transcription result of the insanely_fast_whisper + Get available models from models path including fine-tuned model. Returns ---------- - result: List[dict] - Formatted result as the same with other implementation + Name set of models """ - result = transcribed_result["chunks"] - for item in result: - start, end = item["timestamp"][0], item["timestamp"][1] - if end is None: - end = start - item["start"] = start - item["end"] = end - return result + openai_models = whisper.available_models() + distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] + default_models = openai_models + distil_models + + existing_models = os.listdir(self.model_dir) + wrong_dirs = [".locks"] + + available_models = default_models + existing_models + available_models = [model for model in available_models if model not in wrong_dirs] + available_models = sorted(set(available_models), key=available_models.index) + + return available_models @staticmethod def download_model(