diff --git a/RealtimeSTT/audio_recorder.py b/RealtimeSTT/audio_recorder.py
index 688bfc1..dcbb4e5 100644
--- a/RealtimeSTT/audio_recorder.py
+++ b/RealtimeSTT/audio_recorder.py
@@ -139,8 +139,13 @@ def run(self):
                     device_index=self.gpu_device_index,
                     download_root=self.download_root,
                 )
                 if self.batch_size > 0:
                     model = BatchedInferencePipeline(model=model)
+
+                # Warm-up transcription: feed a short dummy audio array
+                # (one second of silence at 16 kHz) through the model once.
+                dummy_audio = np.zeros(16000, dtype=np.float32)
+                model.transcribe(dummy_audio, language="en", beam_size=1)
             except Exception as e:
                 logging.exception(f"Error initializing main faster_whisper transcription model: {e}")
                 raise
@@ -281,6 +286,7 @@ def __init__(self,
                  buffer_size: int = BUFFER_SIZE,
                  sample_rate: int = SAMPLE_RATE,
                  initial_prompt: Optional[Union[str, Iterable[int]]] = None,
+                 initial_prompt_realtime: Optional[Union[str, Iterable[int]]] = None,
                  suppress_tokens: Optional[List[int]] = [-1],
                  print_transcription_time: bool = False,
                  early_transcription_on_silence: int = 0,
@@ -294,12 +300,12 @@ def __init__(self,
         Args:
         - model (str, default="tiny"): Specifies the size of the transcription
-            model to use or the path to a converted model directory.
-            Valid options are 'tiny', 'tiny.en', 'base', 'base.en',
-            'small', 'small.en', 'medium', 'medium.en', 'large-v1',
-            'large-v2'.
-            If a specific size is provided, the model is downloaded
-            from the Hugging Face Hub.
+          model to use or the path to a converted model directory.
+          Valid options are 'tiny', 'tiny.en', 'base', 'base.en',
+          'small', 'small.en', 'medium', 'medium.en', 'large-v1',
+          'large-v2'.
+          If a specific size is provided, the model is downloaded
+          from the Hugging Face Hub.
         - download_root (str, default=None): Specifies the root path were the
           Whisper models are downloaded to. When empty, the default is used.
         - language (str, default=""): Language code for speech-to-text engine.
@@ -472,7 +478,9 @@ def __init__(self,
           recording. Changing this will very probably functionality (as the
           WebRTC VAD model is very sensitive towards the sample rate).
         - initial_prompt (str or iterable of int, default=None): Initial
-          prompt to be fed to the transcription models.
+          prompt to be fed to the main transcription model.
+        - initial_prompt_realtime (str or iterable of int, default=None):
+          Initial prompt to be fed to the real-time transcription model.
         - suppress_tokens (list of int, default=[-1]): Tokens to be suppressed
           from the transcription output.
         - print_transcription_time (bool, default=False): Logs processing time
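The warm-up added in the first hunk matters because the first transcribe() call on a freshly loaded model is typically much slower than later calls (CUDA context creation, kernel compilation, buffer allocation); running one second of silence through the model at startup moves that cost out of the first real request. A minimal standalone sketch of the same technique, assuming faster-whisper is installed (model size and device are illustrative); note that transcribe() returns a lazy segment generator, so consuming it is what actually forces the decode:

    import numpy as np
    from faster_whisper import WhisperModel

    model = WhisperModel("tiny", device="cpu", compute_type="default")
    dummy_audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    segments, info = model.transcribe(dummy_audio, language="en", beam_size=1)
    list(segments)  # exhaust the generator so the warm-up decode really runs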
@@ -533,6 +541,8 @@ def __init__(self,
         self.enable_realtime_transcription = enable_realtime_transcription
         self.use_main_model_for_realtime = use_main_model_for_realtime
         self.main_model_type = model
+        if not download_root:
+            download_root = None
         self.download_root = download_root
         self.realtime_model_type = realtime_model_type
         self.realtime_processing_pause = realtime_processing_pause
@@ -583,6 +593,7 @@ def __init__(self,
         self.last_transcription_bytes = None
         self.last_transcription_bytes_b64 = None
         self.initial_prompt = initial_prompt
+        self.initial_prompt_realtime = initial_prompt_realtime
         self.suppress_tokens = suppress_tokens
         self.use_wake_words = wake_words or wakeword_backend in {'oww', 'openwakeword', 'openwakewords'}
         self.detected_language = None
@@ -697,7 +708,11 @@ def __init__(self,
         if self.enable_realtime_transcription and not self.use_main_model_for_realtime:
             try:
                 logging.info("Initializing faster_whisper realtime "
-                             f"transcription model {self.realtime_model_type}"
+                             f"transcription model {self.realtime_model_type}, "
+                             f"default device: {self.device}, "
+                             f"compute type: {self.compute_type}, "
+                             f"device index: {self.gpu_device_index}, "
+                             f"download root: {self.download_root}"
                              )
                 self.realtime_model_type = faster_whisper.WhisperModel(
                     model_size_or_path=self.realtime_model_type,
@@ -708,7 +723,8 @@ def __init__(self,
                 )
                 if self.realtime_batch_size > 0:
                     self.realtime_model_type = BatchedInferencePipeline(model=self.realtime_model_type)
-
+                dummy_audio = np.zeros(16000, dtype=np.float32)
+                self.realtime_model_type.transcribe(dummy_audio, language="en", beam_size=1)
             except Exception as e:
                 logging.exception("Error initializing faster_whisper "
                                   f"realtime transcription model: {e}"
@@ -2104,7 +2120,7 @@ def _realtime_worker(self):
                         audio_array,
                         language=self.language if self.language else None,
                         beam_size=self.beam_size_realtime,
-                        initial_prompt=self.initial_prompt,
+                        initial_prompt=self.initial_prompt_realtime,
                         suppress_tokens=self.suppress_tokens,
                         batch_size=self.realtime_batch_size
                     )
@@ -2113,7 +2129,7 @@ def _realtime_worker(self):
                         audio_array,
                         language=self.language if self.language else None,
                         beam_size=self.beam_size_realtime,
-                        initial_prompt=self.initial_prompt,
+                        initial_prompt=self.initial_prompt_realtime,
                         suppress_tokens=self.suppress_tokens
                     )
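With the hunks above, the two models are steered independently: initial_prompt only reaches the final transcription pass, while initial_prompt_realtime is what _realtime_worker() now feeds to the live-preview model. A short usage sketch, assuming the public AudioToTextRecorder constructor (model choices and prompt wording are illustrative):

    from RealtimeSTT import AudioToTextRecorder

    recorder = AudioToTextRecorder(
        model="large-v2",                 # main (final) transcription model
        realtime_model_type="tiny",       # fast live-preview model
        enable_realtime_transcription=True,
        # Prompt fed only to the main model:
        initial_prompt="Glossary: RealtimeSTT, faster-whisper, CTranslate2.",
        # Prompt fed only to the real-time model:
        initial_prompt_realtime="End incomplete sentences with ellipses.",
    )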
diff --git a/RealtimeSTT/audio_recorder_client.py b/RealtimeSTT/audio_recorder_client.py
index 36451a4..15514a9 100644
--- a/RealtimeSTT/audio_recorder_client.py
+++ b/RealtimeSTT/audio_recorder_client.py
@@ -28,6 +28,7 @@
 INIT_MODEL_TRANSCRIPTION = "tiny"
 INIT_MODEL_TRANSCRIPTION_REALTIME = "tiny"
 INIT_REALTIME_PROCESSING_PAUSE = 0.2
+INIT_REALTIME_INITIAL_PAUSE = 0.2
 INIT_SILERO_SENSITIVITY = 0.4
 INIT_WEBRTC_SENSITIVITY = 3
 INIT_POST_SPEECH_SILENCE_DURATION = 0.6
@@ -68,6 +69,7 @@ class AudioToTextRecorderClient:
     def __init__(self,
                  model: str = INIT_MODEL_TRANSCRIPTION,
+                 download_root: str = None,
                  language: str = "",
                  compute_type: str = "default",
                  input_device_index: int = None,
@@ -81,14 +83,17 @@ def __init__(self,
                  use_microphone=True,
                  spinner=True,
                  level=logging.WARNING,
+                 batch_size: int = 16,

                  # Realtime transcription parameters
                  enable_realtime_transcription=False,
                  use_main_model_for_realtime=False,
                  realtime_model_type=INIT_MODEL_TRANSCRIPTION_REALTIME,
                  realtime_processing_pause=INIT_REALTIME_PROCESSING_PAUSE,
+                 init_realtime_after_seconds=INIT_REALTIME_INITIAL_PAUSE,
                  on_realtime_transcription_update=None,
                  on_realtime_transcription_stabilized=None,
+                 realtime_batch_size: int = 16,

                  # Voice activation parameters
                  silero_sensitivity: float = INIT_SILERO_SENSITIVITY,
@@ -133,6 +138,7 @@ def __init__(self,
                  buffer_size: int = BUFFER_SIZE,
                  sample_rate: int = SAMPLE_RATE,
                  initial_prompt: Optional[Union[str, Iterable[int]]] = None,
+                 initial_prompt_realtime: Optional[Union[str, Iterable[int]]] = None,
                  suppress_tokens: Optional[List[int]] = [-1],
                  print_transcription_time: bool = False,
                  early_transcription_on_silence: int = 0,
@@ -162,10 +168,14 @@ def __init__(self,
         self.use_microphone = use_microphone
         self.spinner = spinner
         self.level = level
+        self.batch_size = batch_size
+        self.init_realtime_after_seconds = init_realtime_after_seconds
+        self.realtime_batch_size = realtime_batch_size

         # Real-time transcription parameters
         self.enable_realtime_transcription = enable_realtime_transcription
         self.use_main_model_for_realtime = use_main_model_for_realtime
+        self.download_root = download_root
         self.realtime_model_type = realtime_model_type
         self.realtime_processing_pause = realtime_processing_pause
         self.on_realtime_transcription_update = on_realtime_transcription_update
@@ -204,6 +214,7 @@ def __init__(self,
         self.buffer_size = buffer_size
         self.sample_rate = sample_rate
         self.initial_prompt = initial_prompt
+        self.initial_prompt_realtime = initial_prompt_realtime
         self.suppress_tokens = suppress_tokens
         self.print_transcription_time = print_transcription_time
         self.early_transcription_on_silence = early_transcription_on_silence
@@ -376,6 +387,43 @@ def start_server(self):
             args += ['--model', self.model]
         if self.realtime_model_type:
             args += ['--realtime_model_type', self.realtime_model_type]
+        if self.download_root:
+            args += ['--root', self.download_root]
+        if self.batch_size is not None:
+            args += ['--batch', str(self.batch_size)]
+        if self.realtime_batch_size is not None:
+            args += ['--realtime_batch_size', str(self.realtime_batch_size)]
+        if self.init_realtime_after_seconds is not None:
+            args += ['--init_realtime_after_seconds', str(self.init_realtime_after_seconds)]
+        if self.initial_prompt_realtime:
+            sanitized_prompt = self.initial_prompt_realtime.replace("\n", "\\n")
+            args += ['--initial_prompt_realtime', sanitized_prompt]
+
+        # if self.compute_type:
+        #     args += ['--compute_type', self.compute_type]
+        # if self.input_device_index is not None:
+        #     args += ['--input_device_index', str(self.input_device_index)]
+        # if self.gpu_device_index is not None:
+        #     args += ['--gpu_device_index', str(self.gpu_device_index)]
+        # if self.device:
+        #     args += ['--device', self.device]
+        # if self.spinner:
+        #     args.append('--spinner')  # flag, no need for True/False
+        # if self.enable_realtime_transcription:
+        #     args.append('--enable_realtime_transcription')  # flag, no need for True/False
+        # if self.handle_buffer_overflow:
+        #     args.append('--handle_buffer_overflow')  # flag, no need for True/False
+        # if self.suppress_tokens:
+        #     args += ['--suppress_tokens', str(self.suppress_tokens)]
+        # if self.print_transcription_time:
+        #     args.append('--print_transcription_time')  # flag, no need for True/False
+        # if self.allowed_latency_limit is not None:
+        #     args += ['--allowed_latency_limit', str(self.allowed_latency_limit)]
+        # if self.no_log_file:
+        #     args.append('--no_log_file')  # flag, no need for True
+        # if self.debug_mode:
+        #     args.append('--debug')  # flag, no need for True/False
+
         if self.language:
             args += ['--language', self.language]
         if self.silero_sensitivity is not None:
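start_server() serializes the client's constructor parameters into CLI flags for the server process; note the newline escaping for the realtime prompt, which is reversed by the server-side .replace("\\n", "\n") in parse_arguments(). A reduced sketch of the pattern (the "stt-server" command name and parameter set are illustrative, not the full flag list):

    import subprocess

    def build_server_args(download_root=None, batch_size=16,
                          initial_prompt_realtime=None):
        args = ["stt-server"]
        if download_root:
            args += ["--root", download_root]
        if batch_size is not None:
            args += ["--batch", str(batch_size)]
        if initial_prompt_realtime:
            # Escape real newlines so the prompt survives argv; the server
            # undoes this with .replace("\\n", "\n") after parsing.
            args += ["--initial_prompt_realtime",
                     initial_prompt_realtime.replace("\n", "\\n")]
        return args

    # subprocess.Popen(build_server_args(batch_size=8))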
diff --git a/RealtimeSTT_server/stt_server.py b/RealtimeSTT_server/stt_server.py
index 2f0d4f7..1ef3eb4 100644
--- a/RealtimeSTT_server/stt_server.py
+++ b/RealtimeSTT_server/stt_server.py
@@ -27,6 +27,8 @@
     - `-D, --debug`: Enable debug logging.
     - `-W, --write`: Save audio to WAV file.
     - `-s, --silence_timing`: Enable dynamic silence duration for sentence detection; default True.
+    - `-b, --batch, --batch_size`: Batch size for inference; default 16.
+    - `--root, --download_root`: Specifies the root path where the Whisper models are downloaded to.
     - `--silero_sensitivity`: Silero VAD sensitivity (0-1); default 0.05.
     - `--silero_use_onnx`: Use Silero ONNX model; default False.
     - `--webrtc_sensitivity`: WebRTC VAD sensitivity (0-3); default 3.
@@ -38,7 +40,10 @@
     - `--early_transcription_on_silence`: Start transcription after silence in seconds; default 0.2.
     - `--beam_size`: Beam size for main model; default 5.
     - `--beam_size_realtime`: Beam size for real-time model; default 3.
-    - `--initial_prompt`: Initial transcription guidance prompt.
+    - `--init_realtime_after_seconds`: Initial waiting time for realtime transcription; default 0.2.
+    - `--realtime_batch_size`: Batch size for the real-time transcription model; default 16.
+    - `--initial_prompt`: Initial main transcription guidance prompt.
+    - `--initial_prompt_realtime`: Initial realtime transcription guidance prompt.
     - `--end_of_sentence_detection_pause`: Silence duration for sentence end detection; default 0.45.
     - `--unknown_sentence_detection_pause`: Pause duration for incomplete sentence detection; default 0.7.
     - `--mid_sentence_detection_pause`: Pause for mid-sentence break; default 2.0.
@@ -52,6 +57,14 @@
     - `--use_main_model_for_realtime`: Use main model for real-time transcription.
     - `--use_extended_logging`: Enable extensive log messages.
     - `--logchunks`: Log incoming audio chunks.
+    - `--compute_type`: Type of computation to use.
+    - `--input_device_index`: Index of the audio input device.
+    - `--gpu_device_index`: Index of the GPU device.
+    - `--device`: Device to use for computation.
+    - `--handle_buffer_overflow`: Handle buffer overflow during transcription.
+    - `--suppress_tokens`: Suppress tokens during transcription.
+    - `--allowed_latency_limit`: Allowed latency limit for real-time transcription.
+
     ### WebSocket Interface:
     The server supports two WebSocket connections:
@@ -364,7 +377,7 @@ def parse_arguments():
     parser.add_argument('-l', '--lang', '--language', type=str, default='en',
                         help='Language code for the STT model to transcribe in a specific language. Leave this empty for auto-detection based on input audio. Default is en. List of supported language codes: https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L11-L110')

-    parser.add_argument('-i', '--input-device', '--input_device_index', type=int, default=1,
+    parser.add_argument('-i', '--input-device', '--input-device-index', type=int, default=1,
                         help='Index of the audio input device to use. Use this option to specify a particular microphone or audio input device based on your system. Default is 1.')

     parser.add_argument('-c', '--control', '--control_port', type=int, default=8011,
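Several of the new options register multiple option strings, e.g. '-b', '--batch', '--batch_size'. With argparse, the destination attribute is derived from the first long option, so every alias lands on the same attribute (args.batch here, and args.input_device for the renamed '--input-device-index' alias above). A small self-contained sketch of both idioms used in these hunks:

    import argparse

    parser = argparse.ArgumentParser()
    # dest is taken from the first long option string: args.batch
    parser.add_argument('-b', '--batch', '--batch_size', type=int, default=16)
    # nargs='*' collects zero or more ints; argparse accepts the bare "-1"
    # as a value because no registered option looks like a negative number.
    parser.add_argument('--suppress_tokens', type=int, default=[-1], nargs='*')

    args = parser.parse_args(['--batch_size', '8', '--suppress_tokens', '-1', '50257'])
    print(args.batch, args.suppress_tokens)  # 8 [-1, 50257]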
@@ -378,12 +391,23 @@ def parse_arguments():
     parser.add_argument('-D', '--debug', action='store_true',
                         help='Enable debug logging for detailed server operations')

-    parser.add_argument("-W", "--write", metavar="FILE",
-                        help="Save received audio to a WAV file")
+    parser.add_argument('-W', '--write', metavar='FILE', help='Save received audio to a WAV file')
+
+    parser.add_argument('-b', '--batch', '--batch_size', type=int, default=16, help='Batch size for inference. This parameter controls the number of audio chunks processed in parallel during transcription. Default is 16.')
+
+    parser.add_argument('--root', '--download_root', type=str, default=None, help='Specifies the root path where the Whisper models are downloaded to. Default is None.')

     parser.add_argument('-s', '--silence_timing', action='store_true', default=True,
                         help='Enable dynamic adjustment of silence duration for sentence detection. Adjusts post-speech silence duration based on detected sentence structure and punctuation. Default is False.')

+    parser.add_argument('--init_realtime_after_seconds', type=float, default=0.2,
+                        help='The initial waiting time in seconds before real-time transcription starts. This delay helps prevent false positives at the beginning of a session. Default is 0.2 seconds.')
+
+    parser.add_argument('--realtime_batch_size', type=int, default=16,
+                        help='Batch size for the real-time transcription model. This parameter controls the number of audio chunks processed in parallel during real-time transcription. Default is 16.')
+
+    parser.add_argument('--initial_prompt_realtime', type=str, default="", help='Initial prompt that guides the real-time transcription model to produce transcriptions in a particular style or format.')
+
     parser.add_argument('--silero_sensitivity', type=float, default=0.05,
                         help='Sensitivity level for Silero Voice Activity Detection (VAD), with a range from 0 to 1. Lower values make the model less sensitive, useful for noisy environments. Default is 0.05.')
@@ -457,6 +481,23 @@ def parse_arguments():
     parser.add_argument('--use_extended_logging', action='store_true',
                         help='Writes extensive log messages for the recording worker, that processes the audio chunks.')

+    parser.add_argument('--compute_type', type=str, default='default',
+                        help='Type of computation to use. See https://opennmt.net/CTranslate2/quantization.html')
+
+    parser.add_argument('--gpu_device_index', type=int, default=0,
+                        help='Index of the GPU device to use. Default is 0.')
+
+    parser.add_argument('--device', type=str, default='cuda',
+                        help='Device for model to use. Can either be "cuda" or "cpu". Default is cuda.')
+
+    parser.add_argument('--handle_buffer_overflow', action='store_true',
+                        help='Handle buffer overflow during transcription. Default is False.')
+
+    parser.add_argument('--suppress_tokens', type=int, default=[-1], nargs='*', help='Suppress tokens during transcription. Default is [-1].')
+
+    parser.add_argument('--allowed_latency_limit', type=int, default=100,
+                        help='Maximal number of chunks that can be unprocessed in the queue before discarding chunks. Default is 100.')
+
     parser.add_argument('--logchunks', action='store_true',
                         help='Enable logging of incoming audio chunks (periods)')

     # Parse arguments
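The new --suppress_tokens option is passed straight through to faster-whisper's transcribe(). There, the special value -1 expands to a default set of non-speech symbol tokens defined by the model configuration, while an empty list disables suppression entirely. A brief sketch, assuming faster-whisper is installed ("audio.wav" is a placeholder path):

    from faster_whisper import WhisperModel

    model = WhisperModel("tiny")
    # -1 expands to the model's default non-speech symbol set.
    segments, _ = model.transcribe("audio.wav", suppress_tokens=[-1])
    # An empty list would disable token suppression entirely:
    # segments, _ = model.transcribe("audio.wav", suppress_tokens=[])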
@@ -479,6 +520,9 @@ def parse_arguments():
     if args.initial_prompt:
         args.initial_prompt = args.initial_prompt.replace("\\n", "\n")

+    if args.initial_prompt_realtime:
+        args.initial_prompt_realtime = args.initial_prompt_realtime.replace("\\n", "\n")
+
     return args

 def _recorder_thread(loop):
@@ -534,7 +578,7 @@ def decode_and_resample(
     return resampled_audio.astype(np.int16).tobytes()

-async def control_handler(websocket, path):
+async def control_handler(websocket):
     debug_print(f"New control connection from {websocket.remote_address}")
     print(f"{bcolors.OKGREEN}Control client connected{bcolors.ENDC}")
     global recorder
@@ -629,7 +673,7 @@ async def control_handler(websocket, path):
     finally:
         control_connections.remove(websocket)

-async def data_handler(websocket, path):
+async def data_handler(websocket):
     global writechunks, wav_file
     print(f"{bcolors.OKGREEN}Data client connected{bcolors.ENDC}")
     data_connections.add(websocket)
@@ -700,8 +744,13 @@ async def main_async():
     recorder_config = {
         'model': args.model,
+        'download_root': args.root,
         'realtime_model_type': args.rt_model,
         'language': args.lang,
+        'batch_size': args.batch,
+        'init_realtime_after_seconds': args.init_realtime_after_seconds,
+        'realtime_batch_size': args.realtime_batch_size,
+        'initial_prompt_realtime': args.initial_prompt_realtime,
         'input_device_index': args.input_device,
         'silero_sensitivity': args.silero_sensitivity,
         'silero_use_onnx': args.silero_use_onnx,
@@ -740,6 +789,12 @@ async def main_async():
         'no_log_file': True,  # Disable logging to file
         'use_extended_logging': args.use_extended_logging,
         'level': loglevel,
+        'compute_type': args.compute_type,
+        'gpu_device_index': args.gpu_device_index,
+        'device': args.device,
+        'handle_buffer_overflow': args.handle_buffer_overflow,
+        'suppress_tokens': args.suppress_tokens,
+        'allowed_latency_limit': args.allowed_latency_limit,
     }

     try:
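The dropped path parameter on control_handler and data_handler tracks the websockets package: newer releases call the connection handler with a single connection object, the legacy (websocket, path) signature having been deprecated and later removed. If the request path is still needed, recent versions expose it on the connection itself (websocket.request.path). A minimal sketch of a new-style handler, with an illustrative echo server on the control port:

    import asyncio
    import websockets

    async def echo_handler(websocket):
        # New-style handler: one argument, the connection object.
        async for message in websocket:
            await websocket.send(message)

    async def main():
        async with websockets.serve(echo_handler, "localhost", 8011):
            await asyncio.Future()  # run until cancelled

    # asyncio.run(main())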
diff --git a/setup.py b/setup.py
index f2cd9a8..85adffa 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
 setuptools.setup(
     name="RealtimeSTT",
-    version="0.3.92",
+    version="0.3.93",
     author="Kolja Beigel",
     author_email="kolja.beigel@web.de",
     description="A fast Voice Activity Detection and Transcription System",
diff --git a/tests/realtimestt_speechendpoint_binary_classified.py b/tests/realtimestt_speechendpoint_binary_classified.py
index dc74915..01fac23 100644
--- a/tests/realtimestt_speechendpoint_binary_classified.py
+++ b/tests/realtimestt_speechendpoint_binary_classified.py
@@ -50,6 +50,8 @@
 tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
 classification_model = DistilBertForSequenceClassification.from_pretrained(model_dir)
+# tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir, force_download=True)
+# classification_model = DistilBertForSequenceClassification.from_pretrained(model_dir, force_download=True)
 classification_model.to(device)
 classification_model.eval()
@@ -85,7 +87,7 @@ def get_completion_probability(sentence, model, tokenizer, device, max_length):
 anchor_points = [
     (0.0, 1.0),
     (1.0, 0)
-]
+]
 # anchor_points = [
 #     (0.0, 0.4),
 #     (0.5, 0.3),
@@ -144,10 +146,9 @@ def is_speech_finished(text):
 text_time_deque = deque()

 # Default values
-#rapid_sentence_end_detection = 0.2
 end_of_sentence_detection_pause = 0.3
 unknown_sentence_detection_pause = 0.8
-mid_sentence_detection_pause = 2.0
+mid_sentence_detection_pause = 1.7
 hard_break_even_on_background_noise = 3.0
 hard_break_even_on_background_noise_min_texts = 3
 hard_break_even_on_background_noise_min_chars = 15
@@ -172,12 +173,13 @@ def text_detected(text):
 def additional_pause_based_on_words(text):
     word_count = len(text.split())
     pauses = {
-        1: 0.6,
-        2: 0.5,
-        3: 0.4,
-        4: 0.3,
-        5: 0.2,
-        6: 0.1,
+        0: 0.35,
+        1: 0.3,
+        2: 0.25,
+        3: 0.2,
+        4: 0.15,
+        5: 0.1,
+        6: 0.05,
     }
     return pauses.get(word_count, 0.0)
@@ -185,13 +187,31 @@ def process_queue():
     global recorder, full_sentences, prev_text, displayed_text, rich_text_stored, text_time_deque, abrupt_stop, rapid_sentence_end_detection

     while True:
+        text = None  # Initialize text to ensure it's defined
+
         try:
+            # Attempt to retrieve the first item, blocking with timeout
             text = text_queue.get(timeout=1)
         except queue.Empty:
-            continue
+            continue  # No item retrieved, continue the loop
+
+        if text is None:
+            # Exit signal received
+            break
+
+        # Drain the queue to get the latest text
+        try:
+            while True:
+                latest_text = text_queue.get_nowait()
+                if latest_text is None:
+                    text = None
+                    break
+                text = latest_text
+        except queue.Empty:
+            pass  # No more items to retrieve

         if text is None:
-            # Exit
+            # Exit signal received after draining
             break

         text = preprocess_text(text)
@@ -274,7 +294,7 @@ def process_queue():

 def process_text(text):
     global recorder, full_sentences, prev_text, abrupt_stop
-    #if IS_DEBUG: print(f"SENTENCE: post_speech_silence_duration: {recorder.post_speech_silence_duration}")
+    if IS_DEBUG: print(f"SENTENCE: post_speech_silence_duration: {recorder.post_speech_silence_duration}")
     recorder.post_speech_silence_duration = unknown_sentence_detection_pause
     text = preprocess_text(text)
     text = text.rstrip()
@@ -312,7 +332,7 @@ def process_text(text):
         'beam_size': 5,
         'beam_size_realtime': 3,
         'no_log_file': True,
-        'initial_prompt': (
+        'initial_prompt_realtime': (
             "End incomplete sentences with ellipses.\n"
             "Examples:\n"
             "Complete: The sky is blue.\n"
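The reworked process_queue() above implements a keep-only-the-latest consumer: block for the first item, then drain the queue with get_nowait() so stale partial transcriptions are skipped, while a None sentinel still terminates the loop even if it arrives mid-drain. The sketch below isolates that pattern (queue contents and function names are illustrative):

    import queue
    import threading
    import time

    def latest_only_consumer(q):
        while True:
            try:
                item = q.get(timeout=1)   # block for the first item
            except queue.Empty:
                continue
            if item is None:              # sentinel: shut down
                break
            try:
                while True:               # drain everything already queued
                    newer = q.get_nowait()
                    if newer is None:     # sentinel may arrive mid-drain
                        return
                    item = newer          # keep only the newest item
            except queue.Empty:
                pass
            print("processing:", item)    # only the latest item survives

    q = queue.Queue()
    t = threading.Thread(target=latest_only_consumer, args=(q,))
    t.start()
    for i in range(5):
        q.put(i)                          # 0..3 are skipped, 4 is processed
    time.sleep(1.5)
    q.put(None)
    t.join()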