Add MLX Whisper Backend for Apple Silicon #147

Open · wants to merge 3 commits into base: main
README.md: 4 changes (3 additions, 1 deletion)
@@ -44,9 +44,11 @@ Alternative, less restrictive, but slower backend is [whisper-timestamped](https

Thirdly, it's also possible to run this software with the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/audio/createTranscription). This solution is fast and requires no GPU; a small VM suffices, but you will need to pay OpenAI for API access. Also note that, since each audio fragment is processed multiple times, the [price](https://openai.com/pricing) will be higher than the pricing page suggests, so keep an eye on costs while using it. Setting a higher chunk-size reduces costs significantly.
Install with `pip install openai` ([requires Python >=3.8](https://pypi.org/project/openai/)).

For running with the openai-api backend, make sure that your [OpenAI API key](https://platform.openai.com/api-keys) is set in the `OPENAI_API_KEY` environment variable. For example, before running, do `export OPENAI_API_KEY=sk-xxx`, with *sk-xxx* replaced by your API key.

Fourthly, another efficient backend is the [Whisper MLX](https://github.com/ml-explore/mlx-examples/tree/main/whisper) library, optimized specifically for Apple Silicon. Whisper MLX leverages the performance of Apple chips (M1, M2, ...) to deliver faster transcription without requiring a GPU. Install with `pip install mlx-whisper`. All the main Whisper models have been converted to the MLX format and are listed on [Hugging Face Whisper mlx](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc). A minimal standalone usage sketch is shown after this diff.


The backend is loaded only when it is chosen. The unused ones do not have to be installed.

3) For the voice activity controller: `pip install torch torchaudio`. Optional, but highly recommended.
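As a quick sanity check before wiring the library into whisper_online.py, here is a minimal standalone sketch of calling `mlx_whisper.transcribe` directly. The model repository name comes from the mapping added in this PR; the audio path and the printed fields are illustrative placeholders.

```python
# Minimal standalone sketch (assumes `pip install mlx-whisper` on an Apple Silicon Mac).
import mlx_whisper

result = mlx_whisper.transcribe(
    "audio.wav",                                             # placeholder: any file ffmpeg can decode
    path_or_hf_repo="mlx-community/whisper-large-v3-turbo",  # model repo from the mapping in this PR
    word_timestamps=True,                                    # per-word timing, as the new backend requests
)

for segment in result["segments"]:
    print(f"{segment['start']:.2f}-{segment['end']:.2f}: {segment['text']}")
```
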
whisper_online.py: 105 changes (104 additions, 1 deletion)
@@ -156,6 +156,107 @@ def use_vad(self):
    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

class MLXWhisper(ASRBase):
    """
    Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
    Significantly faster than faster-whisper (without CUDA) on Apple M1.
    """

    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        """
        Loads the MLX-compatible Whisper model.

        Args:
            modelsize (str, optional): The size or name of the Whisper model to load.
                If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
            cache_dir (str, optional): Path to the directory for caching models.
                **Note**: This is not supported by MLX Whisper and will be ignored.
            model_dir (str, optional): Direct path to a custom model directory.
                If specified, it overrides the `modelsize` parameter.
        """
        from mlx_whisper import transcribe

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(f"Loading whisper model {modelsize}. MLX Whisper will use {model_size_or_path}.")
        else:
            # Neither a model size nor a model directory was given; fail early instead of
            # hitting an undefined variable below.
            raise ValueError("Either modelsize or model_dir must be set.")

        self.model_size_or_path = model_size_or_path
        return transcribe

    def translate_model_name(self, model_name):
        """
        Translates a given model name to its corresponding MLX-compatible model path.

        Args:
            model_name (str): The name of the model to translate.

        Returns:
            str: The MLX-compatible model path.
        """
        # Dictionary mapping model names to MLX-compatible paths
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx"
        }

        # Retrieve the corresponding MLX model path
        mlx_model_path = model_mapping.get(model_name)

        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

def transcribe(self, audio, init_prompt=""):
segments = self.model(
audio,
language=self.original_language,
initial_prompt=init_prompt,
word_timestamps=True,
condition_on_previous_text=True,
path_or_hf_repo=self.model_size_or_path,
**self.transcribe_kargs
)
return segments.get("segments", [])


    def ts_words(self, segments):
        """
        Extracts timestamped words from transcription segments, skipping words in segments
        with a high no-speech probability.
        """
        return [
            (word["start"], word["end"], word["word"])
            for segment in segments
            for word in segment.get("words", [])
            if segment.get("no_speech_prob", 0) <= 0.9
        ]

    def segments_end_ts(self, res):
        return [s['end'] for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for audio transcription."""
@@ -660,7 +761,7 @@ def add_shared_args(parser):
    parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
    parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
    parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
    parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
    parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
    parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
@@ -679,6 +780,8 @@ def asr_factory(args, logfile=sys.stderr):
    else:
        if backend == "faster-whisper":
            asr_cls = FasterWhisperASR
        elif backend == "mlx-whisper":
            asr_cls = MLXWhisper
        else:
            asr_cls = WhisperTimestampedASR

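For context, here is a rough sketch of how the new class could be exercised directly, outside of `asr_factory`. The constructor signature `(lan, modelsize, cache_dir, model_dir)` is assumed to match the other `ASRBase` backends and is not part of this diff; the audio array is a placeholder.

```python
# Hypothetical direct-usage sketch; assumes ASRBase.__init__(lan, modelsize=None,
# cache_dir=None, model_dir=None, ...) as used by the other backends in whisper_online.py.
import numpy as np
from whisper_online import MLXWhisper

asr = MLXWhisper("en", modelsize="large-v3-turbo")  # mapped to mlx-community/whisper-large-v3-turbo

audio = np.zeros(16000, dtype=np.float32)           # placeholder: 1 s of 16 kHz silence
segments = asr.transcribe(audio, init_prompt="")    # raw segments from mlx_whisper.transcribe
words = asr.ts_words(segments)                      # [(start, end, word), ...] after no-speech filtering
ends = asr.segments_end_ts(segments)                # segment end timestamps
```

On the command line, the new backend would then be selected with the added flag, e.g. `python3 whisper_online.py audio.wav --backend mlx-whisper --language en --model large-v3`, assuming the rest of the existing whisper_online.py CLI stays the same.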