Skip to content

Commit

Permalink
Merge pull request #378 from jhj0517/feature/enable-finetune
Browse files Browse the repository at this point in the history
Enable fine-tuned models for `insanely_fast_whisper`
  • Loading branch information
jhj0517 authored Nov 4, 2024
2 parents be000f4 + 9e5ed74 commit 02f3c63
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions modules/whisper/insanely_fast_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __init__(self,
self.model_dir = model_dir
os.makedirs(self.model_dir, exist_ok=True)

openai_models = whisper.available_models()
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
self.available_models = openai_models + distil_models
self.available_models = self.get_model_paths()

def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
Expand Down Expand Up @@ -146,31 +144,26 @@ def update_model(self,
model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
)

@staticmethod
def format_result(
transcribed_result: dict
) -> List[dict]:
def get_model_paths(self):
"""
Format the transcription result of insanely_fast_whisper as the same with other implementation.
Parameters
----------
transcribed_result: dict
Transcription result of the insanely_fast_whisper
Get available models from models path including fine-tuned model.
Returns
----------
result: List[dict]
Formatted result as the same with other implementation
Name set of models
"""
result = transcribed_result["chunks"]
for item in result:
start, end = item["timestamp"][0], item["timestamp"][1]
if end is None:
end = start
item["start"] = start
item["end"] = end
return result
openai_models = whisper.available_models()
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
default_models = openai_models + distil_models

existing_models = os.listdir(self.model_dir)
wrong_dirs = [".locks"]

available_models = default_models + existing_models
available_models = [model for model in available_models if model not in wrong_dirs]
available_models = sorted(set(available_models), key=available_models.index)

return available_models

@staticmethod
def download_model(
Expand Down

0 comments on commit 02f3c63

Please sign in to comment.