feat: Enable optional dynamic prompting for the FasterWhisperPipeline #913

Open · wants to merge 1 commit into main
whisperx/asr.py (16 changes: 5 additions & 11 deletions)
@@ -67,18 +67,14 @@ def decode_batch(tokens: List[List[int]]) -> str:
             res = []
             for tk in tokens:
                 res.append([token for token in tk if token < tokenizer.eot])
-            # text_tokens = [token for token in tokens if token < self.eot]
             return tokenizer.tokenizer.decode_batch(res)
 
         text = decode_batch(tokens_batch)
 
         return text
 
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
-        # When the model is running on multiple GPUs, the encoder output should be moved
-        # to the CPU since we don't know which GPU will handle the next job.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = faster_whisper.transcribe.get_ctranslate2_storage(features)
@@ -161,7 +157,6 @@ def get_iterator(
         dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
         if "TOKENIZERS_PARALLELISM" not in os.environ:
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
-        # TODO hack by collating feature_extractor and image_processor
 
         def stack(items):
             return {'inputs': torch.stack([x['inputs'] for x in items])}
@@ -171,16 +166,18 @@ def stack(items):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False, prompt=None
     ) -> TranscriptionResult:
+        if prompt:
+            self.options = self.options._replace(initial_prompt=prompt)
+
         if isinstance(audio, str):
             audio = load_audio(audio)
 
         def data(audio, segments):
             for seg in segments:
                 f1 = int(seg['start'] * SAMPLE_RATE)
                 f2 = int(seg['end'] * SAMPLE_RATE)
-                # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
         vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
@@ -231,17 +228,14 @@ def data(audio, segments):
                 }
             )
 
-        # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
 
-        # revert suppressed tokens if suppress_numerals is enabled
         if self.suppress_numerals:
             self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
 
         return {"segments": segments, "language": language}
 
-
     def detect_language(self, audio: np.ndarray):
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
@@ -354,4 +348,4 @@ def load_model(whisper_arch,
         language=language,
         suppress_numerals=suppress_numerals,
         vad_params=default_vad_options,
-    )
+    )
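
For reference, a minimal usage sketch of the per-call prompt this PR adds. The model name, audio path, and prompt text below are illustrative only; the rest follows the existing whisperx API shown in the diff:

import whisperx

# Load the batched faster-whisper pipeline as usual (device and compute_type are example values).
model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")
audio = whisperx.load_audio("meeting.wav")

# New in this PR: `prompt` is forwarded to faster-whisper as initial_prompt
# via self.options._replace(initial_prompt=prompt) before decoding.
result = model.transcribe(
    audio,
    batch_size=16,
    prompt="Acme Corp quarterly review: EBITDA, churn, ARR.",  # illustrative prompt
)
print(result["segments"][0]["text"])

One thing worth noting from the hunks shown above: the options replacement is not reverted at the end of transcribe (unlike suppress_tokens and the tokenizer), so a prompt passed once would persist on the pipeline for subsequent calls unless it is reset explicitly.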