Skip to content

Commit

Permalink
fix for stt-server (got broken by webserver update)
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Dec 18, 2024
1 parent c1a67e6 commit 4e74b5c
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 31 deletions.
38 changes: 27 additions & 11 deletions RealtimeSTT/audio_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,13 @@ def run(self):
device_index=self.gpu_device_index,
download_root=self.download_root,
)
# Create a short dummy audio array, for example 1 second of silence at 16 kHz
if self.batch_size > 0:
model = BatchedInferencePipeline(model=model)

# Run a warm-up transcription
dummy_audio = np.zeros(16000, dtype=np.float32)
model.transcribe(dummy_audio, language="en", beam_size=1)
except Exception as e:
logging.exception(f"Error initializing main faster_whisper transcription model: {e}")
raise
Expand Down Expand Up @@ -281,6 +286,7 @@ def __init__(self,
buffer_size: int = BUFFER_SIZE,
sample_rate: int = SAMPLE_RATE,
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
initial_prompt_realtime: Optional[Union[str, Iterable[int]]] = None,
suppress_tokens: Optional[List[int]] = [-1],
print_transcription_time: bool = False,
early_transcription_on_silence: int = 0,
Expand All @@ -294,12 +300,12 @@ def __init__(self,
Args:
- model (str, default="tiny"): Specifies the size of the transcription
model to use or the path to a converted model directory.
Valid options are 'tiny', 'tiny.en', 'base', 'base.en',
'small', 'small.en', 'medium', 'medium.en', 'large-v1',
'large-v2'.
If a specific size is provided, the model is downloaded
from the Hugging Face Hub.
model to use or the path to a converted model directory.
Valid options are 'tiny', 'tiny.en', 'base', 'base.en',
'small', 'small.en', 'medium', 'medium.en', 'large-v1',
'large-v2'.
If a specific size is provided, the model is downloaded
from the Hugging Face Hub.
- download_root (str, default=None): Specifies the root path where the Whisper models
are downloaded to. When empty, the default is used.
- language (str, default=""): Language code for speech-to-text engine.
Expand Down Expand Up @@ -472,7 +478,9 @@ def __init__(self,
recording. Changing this will very probably break functionality (as the
WebRTC VAD model is very sensitive towards the sample rate).
- initial_prompt (str or iterable of int, default=None): Initial
prompt to be fed to the transcription models.
prompt to be fed to the main transcription model.
- initial_prompt_realtime (str or iterable of int, default=None):
Initial prompt to be fed to the real-time transcription model.
- suppress_tokens (list of int, default=[-1]): Tokens to be suppressed
from the transcription output.
- print_transcription_time (bool, default=False): Logs processing time
Expand Down Expand Up @@ -533,6 +541,8 @@ def __init__(self,
self.enable_realtime_transcription = enable_realtime_transcription
self.use_main_model_for_realtime = use_main_model_for_realtime
self.main_model_type = model
if not download_root:
download_root = None
self.download_root = download_root
self.realtime_model_type = realtime_model_type
self.realtime_processing_pause = realtime_processing_pause
Expand Down Expand Up @@ -583,6 +593,7 @@ def __init__(self,
self.last_transcription_bytes = None
self.last_transcription_bytes_b64 = None
self.initial_prompt = initial_prompt
self.initial_prompt_realtime = initial_prompt_realtime
self.suppress_tokens = suppress_tokens
self.use_wake_words = wake_words or wakeword_backend in {'oww', 'openwakeword', 'openwakewords'}
self.detected_language = None
Expand Down Expand Up @@ -697,7 +708,11 @@ def __init__(self,
if self.enable_realtime_transcription and not self.use_main_model_for_realtime:
try:
logging.info("Initializing faster_whisper realtime "
f"transcription model {self.realtime_model_type}"
f"transcription model {self.realtime_model_type}, "
f"default device: {self.device}, "
f"compute type: {self.compute_type}, "
f"device index: {self.gpu_device_index}, "
f"download root: {self.download_root}"
)
self.realtime_model_type = faster_whisper.WhisperModel(
model_size_or_path=self.realtime_model_type,
Expand All @@ -708,7 +723,8 @@ def __init__(self,
)
if self.realtime_batch_size > 0:
self.realtime_model_type = BatchedInferencePipeline(model=self.realtime_model_type)

dummy_audio = np.zeros(16000, dtype=np.float32)
self.realtime_model_type.transcribe(dummy_audio, language="en", beam_size=1)
except Exception as e:
logging.exception("Error initializing faster_whisper "
f"realtime transcription model: {e}"
Expand Down Expand Up @@ -2104,7 +2120,7 @@ def _realtime_worker(self):
audio_array,
language=self.language if self.language else None,
beam_size=self.beam_size_realtime,
initial_prompt=self.initial_prompt,
initial_prompt=self.initial_prompt_realtime,
suppress_tokens=self.suppress_tokens,
batch_size=self.realtime_batch_size
)
Expand All @@ -2113,7 +2129,7 @@ def _realtime_worker(self):
audio_array,
language=self.language if self.language else None,
beam_size=self.beam_size_realtime,
initial_prompt=self.initial_prompt,
initial_prompt=self.initial_prompt_realtime,
suppress_tokens=self.suppress_tokens
)

Expand Down
48 changes: 48 additions & 0 deletions RealtimeSTT/audio_recorder_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
INIT_MODEL_TRANSCRIPTION = "tiny"
INIT_MODEL_TRANSCRIPTION_REALTIME = "tiny"
INIT_REALTIME_PROCESSING_PAUSE = 0.2
INIT_REALTIME_INITIAL_PAUSE = 0.2
INIT_SILERO_SENSITIVITY = 0.4
INIT_WEBRTC_SENSITIVITY = 3
INIT_POST_SPEECH_SILENCE_DURATION = 0.6
Expand Down Expand Up @@ -68,6 +69,7 @@ class AudioToTextRecorderClient:

def __init__(self,
model: str = INIT_MODEL_TRANSCRIPTION,
download_root: str = None,
language: str = "",
compute_type: str = "default",
input_device_index: int = None,
Expand All @@ -81,14 +83,17 @@ def __init__(self,
use_microphone=True,
spinner=True,
level=logging.WARNING,
batch_size: int = 16,

# Realtime transcription parameters
enable_realtime_transcription=False,
use_main_model_for_realtime=False,
realtime_model_type=INIT_MODEL_TRANSCRIPTION_REALTIME,
realtime_processing_pause=INIT_REALTIME_PROCESSING_PAUSE,
init_realtime_after_seconds=INIT_REALTIME_INITIAL_PAUSE,
on_realtime_transcription_update=None,
on_realtime_transcription_stabilized=None,
realtime_batch_size: int = 16,

# Voice activation parameters
silero_sensitivity: float = INIT_SILERO_SENSITIVITY,
Expand Down Expand Up @@ -133,6 +138,7 @@ def __init__(self,
buffer_size: int = BUFFER_SIZE,
sample_rate: int = SAMPLE_RATE,
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
initial_prompt_realtime: Optional[Union[str, Iterable[int]]] = None,
suppress_tokens: Optional[List[int]] = [-1],
print_transcription_time: bool = False,
early_transcription_on_silence: int = 0,
Expand Down Expand Up @@ -162,10 +168,14 @@ def __init__(self,
self.use_microphone = use_microphone
self.spinner = spinner
self.level = level
self.batch_size = batch_size
self.init_realtime_after_seconds = init_realtime_after_seconds
self.realtime_batch_size = realtime_batch_size

# Real-time transcription parameters
self.enable_realtime_transcription = enable_realtime_transcription
self.use_main_model_for_realtime = use_main_model_for_realtime
self.download_root = download_root
self.realtime_model_type = realtime_model_type
self.realtime_processing_pause = realtime_processing_pause
self.on_realtime_transcription_update = on_realtime_transcription_update
Expand Down Expand Up @@ -204,6 +214,7 @@ def __init__(self,
self.buffer_size = buffer_size
self.sample_rate = sample_rate
self.initial_prompt = initial_prompt
self.initial_prompt_realtime = initial_prompt_realtime
self.suppress_tokens = suppress_tokens
self.print_transcription_time = print_transcription_time
self.early_transcription_on_silence = early_transcription_on_silence
Expand Down Expand Up @@ -376,6 +387,43 @@ def start_server(self):
args += ['--model', self.model]
if self.realtime_model_type:
args += ['--realtime_model_type', self.realtime_model_type]
if self.download_root:
args += ['--root', self.download_root]
if self.batch_size is not None:
args += ['--batch', str(self.batch_size)]
if self.realtime_batch_size is not None:
args += ['--realtime_batch_size', str(self.realtime_batch_size)]
if self.init_realtime_after_seconds is not None:
args += ['--init_realtime_after_seconds', str(self.init_realtime_after_seconds)]
if self.initial_prompt_realtime:
sanitized_prompt = self.initial_prompt_realtime.replace("\n", "\\n")
args += ['--initial_prompt_realtime', sanitized_prompt]

# if self.compute_type:
# args += ['--compute_type', self.compute_type]
# if self.input_device_index is not None:
# args += ['--input_device_index', str(self.input_device_index)]
# if self.gpu_device_index is not None:
# args += ['--gpu_device_index', str(self.gpu_device_index)]
# if self.device:
# args += ['--device', self.device]
# if self.spinner:
# args.append('--spinner') # flag, no need for True/False
# if self.enable_realtime_transcription:
# args.append('--enable_realtime_transcription') # flag, no need for True/False
# if self.handle_buffer_overflow:
# args.append('--handle_buffer_overflow') # flag, no need for True/False
# if self.suppress_tokens:
# args += ['--suppress_tokens', str(self.suppress_tokens)]
# if self.print_transcription_time:
# args.append('--print_transcription_time') # flag, no need for True/False
# if self.allowed_latency_limit is not None:
# args += ['--allowed_latency_limit', str(self.allowed_latency_limit)]
# if self.no_log_file:
# args.append('--no_log_file') # flag, no need for True
# if self.debug_mode:
# args.append('--debug') # flag, no need for True/False

if self.language:
args += ['--language', self.language]
if self.silero_sensitivity is not None:
Expand Down
Loading

0 comments on commit 4e74b5c

Please sign in to comment.