Skip to content

Commit

Permalink
Enhance inference (#219)
Browse files Browse the repository at this point in the history
* fix: device initialization

* fix: v3 model architecture tasks bug

* add install_requires param

* comment out install_requires
  • Loading branch information
KevKibe authored Nov 25, 2024
1 parent 947cde5 commit 7635563
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
8 changes: 5 additions & 3 deletions DOCS/gettingstarted.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,11 @@ model = model_optimizer.load_transcription_model() # For fine-tuning v3 or v3-t
inference = SpeechTranscriptionPipeline(
audio_file_path=audiofile_dir,
task=task,
huggingface_token=huggingface_token
)
huggingface_token=huggingface_token,
chunk_size=10  # Duration of each audio chunk; shorter chunks improve accuracy but increase processing time.
)
# Optional parameter language: The language of the audio for transcription/translation.
# For fine-tuned v3 or v3-turbo models, or a fine-tuned version of them, specify is_v3_architecture=True


# To get transcriptions
transcription = inference.transcribe_audio(model=model)
Expand All @@ -160,7 +163,6 @@ print(transcription)
# To get transcriptions with speaker labels
alignment_result = inference.align_transcription(transcription) # Optional parameter alignment_model: if the default wav2vec alignment model is not available, e.g. thinkKenya/wav2vec2-large-xls-r-300m-sw
diarization_result = inference.diarize_audio(alignment_result)
print(diarization_result)

#To generate subtitles(.srt format), will be saved in root directory
inference.generate_subtitles(transcription, alignment_result, diarization_result)
Expand Down
33 changes: 25 additions & 8 deletions src/deployment/speech_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,28 +91,45 @@ class SpeechTranscriptionPipeline:
Class for handling speech transcription, alignment, and diarization tasks.
Attributes:
audio (AudioFile): The loaded audio file to be processed.
task (str): The type of task to perform (e.g., "transcription").
device (str): The device used for computation, either 'cpu' or the GPU (e.g., 'cuda:0').
batch_size (int): The number of audio segments to process in a single batch.
chunk_size (int): The duration (in seconds) of each audio chunk for processing.
Shorter chunks improve accuracy but may increase processing time.
huggingface_token (str): Token used to authenticate with the Hugging Face API.
language (str, optional): The language of the audio for transcription, if specified.
is_v3_architecture (bool): Specifies if the model uses the v3 architecture.
"""
def __init__(self,
             audio_file_path: str,
             task: str,
             huggingface_token: str,
             language: str = None,
             batch_size: int = 32,
             chunk_size: int = 30,
             is_v3_architecture: bool = False) -> None:
    """
    Initialize the speech transcription pipeline.

    Args:
        audio_file_path (str): Path to the audio file to load.
        task (str): Task to perform (e.g. "transcribe" or "translate").
        huggingface_token (str): Read token for authenticating with the
            Hugging Face API.
        language (str, optional): Language of the audio, if known.
            Defaults to None.
        batch_size (int): Number of audio segments to process per batch.
        chunk_size (int): Duration (in seconds) of each audio chunk;
            shorter chunks improve accuracy but increase processing time.
        is_v3_architecture (bool): Set True for v3/v3-turbo models (or
            fine-tuned versions of them) so task labels are remapped by
            fix_v3_architecture_tasks.
    """
    self.audio = load_audio(audio_file_path)
    self.task = task
    # Use the first CUDA device when available, otherwise fall back to CPU.
    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.batch_size = batch_size
    self.chunk_size = chunk_size
    self.huggingface_token = huggingface_token
    self.is_v3_architecture = is_v3_architecture
    self.language = language
    self._login_to_huggingface()
    # v3-architecture models need their transcribe/translate task labels
    # swapped before inference.
    self.fix_v3_architecture_tasks()


def fix_v3_architecture_tasks(self):
"""
Switches between 'transcribe' and 'translate' tasks if is_v3_architecture is True.
"""
if self.is_v3_architecture:
if self.task == "transcribe":
self.task = "translate"
elif self.task == "translate":
self.task = "transcribe"

def _login_to_huggingface(self) -> None:
"""
Expand Down

0 comments on commit 7635563

Please sign in to comment.