diff --git a/README.md b/README.md
index 1670ba9..255b695 100644
--- a/README.md
+++ b/README.md
@@ -44,9 +44,11 @@
 Alternative, less restrictive, but slower backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped).
 
 Thirdly, it's also possible to run this software from the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/audio/createTranscription). This solution is fast and requires no GPU, just a small VM will suffice, but you will need to pay OpenAI for api access. Also note that, since each audio fragment is processed multiple times, the [price](https://openai.com/pricing) will be higher than obvious from the pricing page, so keep an eye on costs while using. Setting a higher chunk-size will reduce costs significantly. Install with: `pip install openai` , [requires Python >=3.8](https://pypi.org/project/openai/).
 
-For running with the openai-api backend, make sure that your [OpenAI api key](https://platform.openai.com/api-keys) is set in the `OPENAI_API_KEY` environment variable. For example, before running, do: `export OPENAI_API_KEY=sk-xxx` with *sk-xxx* replaced with your api key.
+Fourthly, another efficient backend is the [MLX Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper) library, optimized specifically for Apple Silicon. It leverages the performance of Apple chips (M1, M2, ...) to deliver fast transcription without an NVIDIA GPU. Install with: `pip install mlx-whisper`. All the main Whisper models have been converted to MLX format and are listed on [Hugging Face](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc).
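+
+For example, a minimal streaming run with the MLX backend might look like this (the audio file name is a placeholder; `large-v2` is one of the model sizes mapped by this backend):
+
+```
+python3 whisper_online.py audio.wav --backend mlx-whisper --model large-v2 --language en --min-chunk-size 1 > out.txt
+```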
+
+
 The backend is loaded only when chosen. The unused one does not have to be installed.
 
 3) For voice activity controller: `pip install torch torchaudio`. Optional, but very recommended.
diff --git a/whisper_online.py b/whisper_online.py
index c11e53c..ce61a52 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -156,6 +156,107 @@ def use_vad(self):
         self.transcribe_kargs["vad_filter"] = True
 
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
 
+class MLXWhisper(ASRBase):
+    """
+    Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
+    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
+    Significantly faster than faster-whisper (without CUDA) on Apple M1.
+    """
+
+    sep = " "
+
+    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        """
+        Loads the MLX-compatible Whisper model.
+
+        Args:
+            modelsize (str, optional): The size or name of the Whisper model to load.
+                If provided, it is translated to an MLX-compatible model path using the `translate_model_name` method.
+                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
+            cache_dir (str, optional): Path to the directory for caching models.
+                **Note**: This is not supported by MLX Whisper and is ignored.
+            model_dir (str, optional): Direct path to a custom model directory.
+                If specified, it overrides the `modelsize` parameter.
+        """
+        from mlx_whisper import transcribe
+
+        if model_dir is not None:
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
+            model_size_or_path = model_dir
+        elif modelsize is not None:
+            model_size_or_path = self.translate_model_name(modelsize)
+            logger.debug(f"Loading whisper model {modelsize}. The mlx-whisper backend will use {model_size_or_path}.")
+        else:
+            raise ValueError("Either modelsize or model_dir must be set.")
+
+        self.model_size_or_path = model_size_or_path
+        return transcribe
+
+    def translate_model_name(self, model_name):
+        """
+        Translates a given model name to its corresponding MLX-compatible model path.
+
+        Args:
+            model_name (str): The name of the model to translate.
+
+        Returns:
+            str: The MLX-compatible model path.
+
+        Raises:
+            ValueError: If the model name is not recognized or not supported.
+        """
+        # Dictionary mapping model names to MLX-compatible paths
+        model_mapping = {
+            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
+            "tiny": "mlx-community/whisper-tiny-mlx",
+            "base.en": "mlx-community/whisper-base.en-mlx",
+            "base": "mlx-community/whisper-base-mlx",
+            "small.en": "mlx-community/whisper-small.en-mlx",
+            "small": "mlx-community/whisper-small-mlx",
+            "medium.en": "mlx-community/whisper-medium.en-mlx",
+            "medium": "mlx-community/whisper-medium-mlx",
+            "large-v1": "mlx-community/whisper-large-v1-mlx",
+            "large-v2": "mlx-community/whisper-large-v2-mlx",
+            "large-v3": "mlx-community/whisper-large-v3-mlx",
+            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+            "large": "mlx-community/whisper-large-mlx",
+        }
+
+        # Retrieve the corresponding MLX model path
+        mlx_model_path = model_mapping.get(model_name)
+
+        if mlx_model_path:
+            return mlx_model_path
+        else:
+            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
+
+    def transcribe(self, audio, init_prompt=""):
+        segments = self.model(
+            audio,
+            language=self.original_language,
+            initial_prompt=init_prompt,
+            word_timestamps=True,
+            condition_on_previous_text=True,
+            path_or_hf_repo=self.model_size_or_path,
+            **self.transcribe_kargs,
+        )
+        return segments.get("segments", [])
+
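+    # Note on the shape of `segments`: mlx_whisper.transcribe() mirrors the
+    # reference openai-whisper output, so each segment dict carries
+    # "no_speech_prob" and (because word_timestamps=True above) a "words"
+    # list of {"word", "start", "end"} dicts, which ts_words() consumes.
+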
+    def ts_words(self, segments):
+        """
+        Extracts timestamped words from transcription segments, skipping any
+        segment whose no-speech probability exceeds 0.9.
+        """
+        return [
+            (word["start"], word["end"], word["word"])
+            for segment in segments
+            for word in segment.get("words", [])
+            if segment.get("no_speech_prob", 0) <= 0.9
+        ]
+
+    def segments_end_ts(self, res):
+        return [s['end'] for s in res]
+
+    def use_vad(self):
+        # mlx_whisper.transcribe() has no built-in VAD filter; a "vad_filter"
+        # kwarg would be forwarded to DecodingOptions and raise a TypeError.
+        # The option is therefore ignored; use the --vac controller instead.
+        logger.warning("VAD is not supported by the mlx-whisper backend and is ignored.")
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
 
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
@@ -660,7 +761,7 @@ def add_shared_args(parser):
     parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
     parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
-    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
+    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
     parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
     parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
@@ -679,6 +780,8 @@ def asr_factory(args, logfile=sys.stderr):
     else:
         if backend == "faster-whisper":
             asr_cls = FasterWhisperASR
+        elif backend == "mlx-whisper":
+            asr_cls = MLXWhisper
         else:
             asr_cls = WhisperTimestampedASR
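For reviewers: a minimal smoke test of the new backend on an Apple Silicon machine. This is a sketch, assuming `mlx-whisper` is installed and that `ASRBase.__init__(lan, modelsize, ...)` wires up `load_model()` as it does for the other backends; the one-second silent buffer is synthetic, so an empty word list is the expected output.

```python
import numpy as np

from whisper_online import MLXWhisper

# "large-v3-turbo" is translated by translate_model_name() to the
# "mlx-community/whisper-large-v3-turbo" repo on the Hugging Face Hub.
asr = MLXWhisper(lan="en", modelsize="large-v3-turbo")

# The backends consume 16 kHz mono float32 audio.
audio = np.zeros(16000, dtype=np.float32)

segments = asr.transcribe(audio)
print(asr.ts_words(segments))  # [(start_s, end_s, word), ...]; [] for silence
```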