From 88cbbb47ea0e18a91752170778438b3cda0e8086 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:22:25 +0900 Subject: [PATCH 1/4] Add `offload()` to transcription pipeline --- modules/whisper/base_transcription_pipeline.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 2791dc6..64bf67f 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -8,6 +8,7 @@ import numpy as np from datetime import datetime from faster_whisper.vad import VadOptions +import gc from modules.uvr.music_separator import MusicSeparator from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, @@ -414,6 +415,15 @@ def get_available_compute_type(self): else: return list(ctranslate2.get_supported_compute_types("cpu")) + def offload(self): + """Offload the model and free up the memory""" + if self.model is not None: + del self.model + self.model = None + if self.device == "cuda": + torch.cuda.empty_cache() + gc.collect() + @staticmethod def format_time(elapsed_time: float) -> str: """ From 85a217c9b3dd6cf9a9d57b9a14c2dc8aa86e5702 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:25:12 +0900 Subject: [PATCH 2/4] Update cuda release logic --- modules/whisper/base_transcription_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 64bf67f..e016f5c 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -421,7 +421,7 @@ def offload(self): del self.model self.model = None if self.device == "cuda": - torch.cuda.empty_cache() + self.release_cuda_memory() gc.collect() 
@staticmethod From 9524fb03faebc41983c54abd4ef02041a8dccc75 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:25:36 +0900 Subject: [PATCH 3/4] Add `offload()` to translation model --- modules/translation/translation_base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py index 6087767..08297a4 100644 --- a/modules/translation/translation_base.py +++ b/modules/translation/translation_base.py @@ -2,6 +2,7 @@ import torch import gradio as gr from abc import ABC, abstractmethod +import gc from typing import List from datetime import datetime @@ -128,6 +129,15 @@ def translate_file(self, finally: self.release_cuda_memory() + def offload(self): + """Offload the model and free up the memory""" + if self.model is not None: + del self.model + self.model = None + if self.device == "cuda": + self.release_cuda_memory() + gc.collect() + @staticmethod def get_device(): if torch.cuda.is_available(): From be456590e46c82d361b1c34257b61fdfa69e7216 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:27:49 +0900 Subject: [PATCH 4/4] Add `offload()` to diarizer --- modules/diarize/diarizer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py index 38e150a..a16a888 100644 --- a/modules/diarize/diarizer.py +++ b/modules/diarize/diarizer.py @@ -4,6 +4,7 @@ import numpy as np import time import logging +import gc from modules.utils.paths import DIARIZATION_MODELS_DIR from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers @@ -121,6 +122,16 @@ def update_pipe(self, ) logger.disabled = False + def offload(self): + """Offload the model and free up the memory""" + if self.pipe is not None: + del self.pipe + self.pipe = None + if self.device == "cuda": + torch.cuda.empty_cache() + 
torch.cuda.reset_peak_memory_stats() + gc.collect() + @staticmethod def get_device(): if torch.cuda.is_available():