Skip to content

Commit

Permalink
Merge pull request #380 from Capsize-Games/develop
Browse files Browse the repository at this point in the history
better handling of content generation
  • Loading branch information
w4ffl35 authored Jan 17, 2024
2 parents c0135be + 191ea7b commit e9e36a6
Show file tree
Hide file tree
Showing 15 changed files with 389 additions and 425 deletions.
158 changes: 61 additions & 97 deletions src/airunner/aihandler/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from airunner.workers.worker import Worker
from airunner.aihandler.enums import EngineRequestCode, EngineResponseCode
from airunner.aihandler.image_processor import ImageProcessor
from airunner.aihandler.llm import LLMEngine
from airunner.aihandler.llm import LLMController
from airunner.aihandler.logger import Logger
from airunner.aihandler.runner import SDRunner
from airunner.aihandler.runner import SDController
from airunner.aihandler.speech_to_text import SpeechToText
from airunner.aihandler.tts import TTS

Expand All @@ -34,17 +34,19 @@ class Engine(QObject):
# Signals
request_signal_status = pyqtSignal(str)
hear_signal = pyqtSignal(str)
text_generated_signal = pyqtSignal(dict)
image_generated_signal = pyqtSignal(dict)

# Loaded flags
llm_loaded: bool = False
sd_loaded: bool = False

# Model controllers
llm = None
sd = None
tts = None
stt = None
ocr = None
llm_controller = None
sd_controller = None
tts_controller = None
stt_controller = None
ocr_controller = None

# Message properties for EngineResponseCode.TEXT_STREAMED
message = ""
Expand Down Expand Up @@ -80,7 +82,7 @@ def do_response(self, response):

# Cancel in-flight work: log the cancellation, cancel the Stable Diffusion
# side, then cancel the request worker.
# NOTE(review): `self.sd.cancel()` and `self.sd_controller.cancel()` appear
# back to back; they look like the before/after forms of one renamed call in
# this diff view -- confirm that only the `sd_controller` form is live.
# NOTE(review): another `def cancel(self)` appears later in this view. If both
# sit in the same class body, the later one rebinds the name and this version
# (including the request_worker.cancel() call) would never run -- confirm.
def cancel(self):
self.logger.info("Canceling")
self.sd.cancel()
self.sd_controller.cancel()
self.request_worker.cancel()

# END OFFLINE CLIENT
Expand Down Expand Up @@ -135,19 +137,20 @@ def response_worker_response_signal_slot(self, message):
EngineResponseCode.TEXT_STREAMED: self.handle_text_streamed,
EngineResponseCode.IMAGE_GENERATED: self.handle_image_generated,
EngineResponseCode.CAPTION_GENERATED: self.handle_caption_generated,
EngineResponseCode.CLEAR_MEMORY: self.clear_memory
}.get(message["code"], self.handle_default_response)(message["message"], message["code"])

# Handle a text-generation request: call move_sd_to_cpu() first, then forward
# message["message"] to the LLM object's do_request().
# NOTE(review): the `self.llm` / `self.llm_controller` lines look like the
# before/after forms of one renamed call -- confirm only one is live.
def handle_generate_text(self, message):
self.move_sd_to_cpu()
self.llm.do_request(message["message"])
self.llm_controller.do_request(message["message"])

# Handle an image-generation request: call unload_llm() with the request and
# two flags read from app.settings["memory_settings"], then hand the request
# to the Stable Diffusion object.
# NOTE(review): two spellings of the unload flag appear ("unload_unused_model"
# and "unload_unused_models"); they look like the before/after forms of one
# argument -- confirm which key the settings dict actually defines, since a
# wrong key raises KeyError.
# NOTE(review): `generator_sample(message)` receives the whole request while
# `do_request(message["message"])` receives only its "message" entry --
# confirm the controller expects the inner payload.
def handle_generate_image(self, message):
self.unload_llm(
message,
self.app.settings["memory_settings"]["unload_unused_model"],
self.app.settings["memory_settings"]["unload_unused_models"],
self.app.settings["memory_settings"]["move_unused_model_to_cpu"]
)
self.sd.generator_sample(message)
self.sd_controller.do_request(message["message"])

# Placeholder handler for caption-generation requests: accepts the request
# and does nothing.
def handle_generate_caption(self, message):
pass
Expand All @@ -162,29 +165,26 @@ def handle_text_streamed(self, message, code):
self.current_message = self.current_message.replace("</s>", "")
# check if sentence enders are in self.current_message
is_end_of_message = "</s>" in message
self.tts.add_text(message.replace("</s>", ""), is_end_of_message=is_end_of_message)
self.app.message_handler_signal.emit(dict(
code=EngineResponseCode.ADD_TO_CONVERSATION,
message=dict(
name=self.app.settings["llm_generator_settings"]["botname"],
text=message.replace("</s>", ""),
is_bot=True,
first_message=self.first_message,
last_message=is_end_of_message
)
self.tts_controller.add_text(message.replace("</s>", ""), is_end_of_message=is_end_of_message)
self.text_generated_signal.emit(dict(
name=self.app.settings["llm_generator_settings"]["botname"],
text=message.replace("</s>", ""),
is_bot=True,
first_message=self.first_message,
last_message=is_end_of_message
))
self.first_message = False
if is_end_of_message:
self.first_message = True
self.message = ""
self.current_message = ""

# self.stt.do_listen()
# self.stt_controller.do_listen()

# Response handler for generated images.
# Two forms are shown in this view:
#   * (self, message): calls self.send_message(message, code), but `code` is
#     not a parameter or local of that form.
#   * (self, message, code): emits image_generated_signal with `message`;
#     `code` is accepted but not used.
# NOTE(review): these look like the before/after versions of the method --
# confirm only the two-parameter form is live. The response dispatcher in
# this file calls its handlers with (message["message"], message["code"]).
def handle_image_generated(self, message):
self.send_message(message, code)
def handle_image_generated(self, message, code):
self.image_generated_signal.emit(message)

# Response handler for generated captions: forwards `message` and `code`
# through self.send_message().
# NOTE(review): two signatures are shown; the one-parameter form does not
# define `code`, which the body uses. They look like the before/after
# versions -- confirm only the (self, message, code) form is live.
def handle_caption_generated(self, message):
def handle_caption_generated(self, message, code):
self.send_message(message, code)

def __init__(self, **kwargs):
Expand All @@ -193,13 +193,23 @@ def __init__(self, **kwargs):
self.app = kwargs.get("app", None)
self.message_handler = kwargs.get("message_handler", None)
self.clear_memory()
self.initialize_llm() # Large language model
self.initialize_sd() # Art model
self.initialize_tts() # Text to speech model (voice)
self.initialize_stt() # Speech to text model (ears)
# self.initialize_ocr() # Vision to text model (eyes)

self.llm.response_signal.connect(self.do_response)
# Initialize Controllers
self.llm_controller = LLMController(engine=self)
self.sd_controller = SDController(engine=self)
#self.stt_controller = SpeechToText(engine=self, hear_signal=self.hear_signal, duration=10.0, fs=16000)
#self.hear_signal.connect(self.hear)
# self.listen_thread = threading.Thread(target=self.stt_controller.listen)
# self.listen_thread.start()

self.tts_controller = TTS(engine=self)
#self.tts_thread = threading.Thread(target=self.tts_controller.run)
#self.tts_thread.start()

# self.ocr_controller = ImageProcessor(engine=self)

self.llm_controller.response_signal.connect(self.do_response)
self.sd_controller.response_signal.connect(self.do_response)

# Request worker and thread
self.request_worker = Worker(prefix="RequestWorker")
Expand All @@ -222,8 +232,8 @@ def __init__(self, **kwargs):
# Fallback handler: logs an error that includes message['code'].
# Expects `message` to be subscriptable with a 'code' key.
def handle_default(self, message):
self.logger.error(f"Unknown code: {message['code']}")

# Fallback for response codes with no dedicated handler.
# Two forms are shown in this view:
#   * (self, message): logs an error built from message['code'].
#   * (self, message, code): forwards to self.app.send_message(code, message).
# NOTE(review): these look like the before/after versions -- confirm only the
# two-parameter form is live.
# NOTE(review): the call passes (code, message), the reverse of this class's
# own send_message(message, code=None) -- confirm app.send_message's
# parameter order.
def handle_default_response(self, message):
self.logger.error(f"handle_default_response Unknown code: {message['code']}")
def handle_default_response(self, message, code):
self.app.send_message(code, message)

# Return the result of qsize() on the request worker's queue.
def request_queue_size(self):
return self.request_worker.queue.qsize()
Expand All @@ -237,51 +247,6 @@ def send_message(self, message, code=None):
code=code,
message=message
))

def initialize_llm(self):
"""
Initialize the LLM.
"""
self.llm = LLMEngine(app=self.app, engine=self)

def initialize_sd(self):
"""
Initialize Stable Diffusion.
"""
self.sd = SDRunner(
app=self.app,
message_handler=self.message_handler,
engine=self
)

def initialize_stt(self):
"""
Initialize speech to text.
"""
self.stt = SpeechToText(
hear_signal=self.hear_signal,
engine=self,
duration=10.0,
fs=16000
)
self.hear_signal.connect(self.hear)
# self.listen_thread = threading.Thread(target=self.stt.listen)
# self.listen_thread.start()

def initialize_tts(self):
"""
Initialize text to speech.
"""
tts_settings = self.app.settings["tts_settings"]
self.tts = TTS(engine=self)
self.tts_thread = threading.Thread(target=self.tts.run)
self.tts_thread.start()

def initialize_ocr(self):
"""
Initialize vision to text.
"""
self.ocr = ImageProcessor(engine=self)

# def generator_sample(self, data: dict):
# """
Expand All @@ -300,22 +265,22 @@ def initialize_ocr(self):
# if not self.llm_loaded:
# self.logger.info("Preparing LLM")
# # if self.tts:
# # self.tts.move_model(to_cpu=False)
# # self.tts_controller.move_model(to_cpu=False)
# self.llm_loaded = True
# do_unload_model = data["request_data"].get("unload_unused_model", False)
# do_move_to_cpu = not do_unload_model and data["request_data"].get("move_unused_model_to_cpu", False)
# if do_move_to_cpu:
# self.move_sd_to_cpu()
# elif do_unload_model:
# self.sd.unload()
# self.sd_controller.unload()
# self.logger.info("Engine calling llm.do_generate")
# self.llm.do_generate(data)
# self.llm_controller.do_generate(data)

# def tts_generator_sample(self, data: dict):
# if "tts_request" not in data or not self.tts:
# return
# self.logger.info("Preparing TTS model...")
# # self.tts.move_model(to_cpu=False)
# # self.tts_controller.move_model(to_cpu=False)
# signal = data["request_data"].get("signal", None)
# message_object = data["request_data"].get("message_object", None)
# is_bot = data["request_data"].get("is_bot", False)
Expand All @@ -326,7 +291,7 @@ def initialize_ocr(self):
# # check if ends with a proper sentence ender, if not, add a period
# if not text.endswith((".", "?", "!", "...", "-", "—", )):
# text += "."
# generator = self.tts.add_text(text, "a", data["request_data"]["tts_settings"])
# generator = self.tts_controller.add_text(text, "a", data["request_data"]["tts_settings"])
# for success in generator:
# if signal and success:
# signal.emit(message_object, is_bot, first_message, last_message)
Expand All @@ -339,23 +304,23 @@ def initialize_ocr(self):
# self.sd_loaded = True
# self.do_unload_llm()
# self.logger.info("Engine calling sd.generator_sample")
# self.sd.generator_sample(data)
# self.sd_controller.generator_sample(data)

# Currently a no-op: the speech-to-text do_listen() call is commented out.
def do_listen(self):
# self.stt.do_listen()
# self.stt_controller.do_listen()
pass

# NOTE(review): an earlier `def cancel(self)` in this view also logs and calls
# self.request_worker.cancel(). If both are in the same class body, this later
# definition is the one Python keeps, so those extra steps would be lost --
# confirm whether the two should be merged.
# NOTE(review): `self.sd.cancel()` / `self.sd_controller.cancel()` look like
# the before/after forms of one renamed call -- confirm only one is live.
def cancel(self):
"""
Cancel Stable Diffusion request.
"""
self.sd.cancel()
self.sd_controller.cancel()

# Delegates to unload() on the Stable Diffusion object.
# NOTE(review): `self.sd.unload()` / `self.sd_controller.unload()` look like
# the before/after forms of one renamed call -- confirm only one is live.
def unload_stablediffusion(self):
"""
Unload the Stable Diffusion model from memory.
"""
self.sd.unload()
self.sd_controller.unload()

def parse_message(self, message):
if message:
Expand Down Expand Up @@ -385,9 +350,9 @@ def handle_tts(self, message: str):
# tts_settings=self.app.settings["tts_settings"]
# )
# )
self.tts.add_text(message)
self.tts_controller.add_text(message)

def clear_memory(self):
def clear_memory(self, *args, **kwargs):
"""
Clear the GPU ram.
"""
Expand All @@ -398,13 +363,13 @@ def clear_memory(self):

def clear_llm_history(self):
    """Clear the LLM controller's conversation history, if a controller exists.

    The guard checks ``llm_controller`` -- the same attribute the call uses.
    The previous guard tested ``self.llm``, an attribute name this file no
    longer defines once the controllers are renamed, so the check itself
    would raise ``AttributeError`` before any history was cleared.
    """
    # The class-level default for llm_controller is None, so nothing happens
    # until a controller has been assigned.
    if self.llm_controller is not None:
        self.llm_controller.clear_history()

# Shut down: log "Stopping", then stop the request worker and the response
# worker.
# NOTE(review): `self.stt.stop()` and the commented-out
# `#self.stt_controller.stop()` look like the before/after forms of one line;
# if so, speech-to-text is no longer stopped here -- confirm that is intended.
def stop(self):
self.logger.info("Stopping")
self.request_worker.stop()
self.response_worker.stop()
self.stt.stop()
#self.stt_controller.stop()

def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_model_to_cpu: bool):
"""
Expand All @@ -427,19 +392,18 @@ def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_mode

if do_move_to_cpu:
self.logger.info("Moving LLM to CPU")
self.llm.move_to_cpu()
self.llm_controller.move_to_cpu()
self.clear_memory()
elif do_unload_model:
self.do_unload_llm()

# Unload the LLM and then call clear_memory().
# Two forms of the unload step are shown: unload_model() + unload_tokenizer()
# on `self.llm`, and a single do_unload_llm() on `self.llm_controller`.
# NOTE(review): these look like the before/after versions -- confirm only the
# controller call is live.
def do_unload_llm(self):
self.logger.info("Unloading LLM")
self.llm.unload_model()
self.llm.unload_tokenizer()
self.llm_controller.do_unload_llm()
self.clear_memory()

# Move the Stable Diffusion pipe to the CPU and then call clear_memory().
# Returns early, doing nothing, when the pipe is already on the CPU
# (is_pipe_on_cpu) or no pipe exists (has_pipe is falsy).
# NOTE(review): each `self.sd...` line is followed by a `self.sd_controller...`
# twin; they look like before/after forms of the same statements -- confirm
# only the controller forms are live.
def move_sd_to_cpu(self):
if self.sd.is_pipe_on_cpu or not self.sd.has_pipe:
if self.sd_controller.is_pipe_on_cpu or not self.sd_controller.has_pipe:
return
self.sd.move_pipe_to_cpu()
self.sd_controller.move_pipe_to_cpu()
self.clear_memory()
2 changes: 2 additions & 0 deletions src/airunner/aihandler/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class EngineResponseCode(Enum):
TEXT_STREAMED = 701
CAPTION_GENERATED = 800
ADD_TO_CONVERSATION = 900
CLEAR_MEMORY = 1000
NSFW_CONTENT_DETECTED = 1100


class EngineRequestCode(Enum):
Expand Down
Loading

0 comments on commit e9e36a6

Please sign in to comment.