Develop #372

Merged
merged 4 commits on Jan 16, 2024
48 changes: 29 additions & 19 deletions src/airunner/aihandler/engine.py
@@ -13,6 +13,8 @@
from airunner.aihandler.speech_to_text import SpeechToText
from airunner.aihandler.tts import TTS

logger = Logger(prefix="Engine")

class Message:
def __init__(self, *args, **kwargs):
self.name = kwargs.get("name")
@@ -101,7 +103,7 @@ def initialize_ocr(self):
self.ocr = ImageProcessor(engine=self)

def move_pipe_to_cpu(self):
Logger.info("Moving pipe to CPU")
logger.info("Moving pipe to CPU")
self.sd.move_pipe_to_cpu()
self.clear_memory()

@@ -111,7 +113,7 @@ def generator_sample(self, data: dict):
:param data:
:return:
"""
Logger.info("generator_sample called")
logger.info("generator_sample called")
self.llm_generator_sample(data)
self.tts_generator_sample(data)
self.sd_generator_sample(data)
@@ -120,7 +122,7 @@ def llm_generator_sample(self, data: dict):
if "llm_request" not in data or not self.llm:
return
if self.model_type != "llm":
Logger.info("Preparing LLM model...")
logger.info("Preparing LLM model...")
# if self.tts:
# self.tts.move_model(to_cpu=False)
self.clear_memory()
@@ -133,13 +135,13 @@ def llm_generator_sample(self, data: dict):
self.sd.unload_model()
self.sd.unload_tokenizer()
self.clear_memory()
Logger.info("Engine calling llm.do_generate")
logger.info("Engine calling llm.do_generate")
self.llm.do_generate(data)

def tts_generator_sample(self, data: dict):
if "tts_request" not in data or not self.tts:
return
Logger.info("Preparing TTS model...")
logger.info("Preparing TTS model...")
# self.tts.move_model(to_cpu=False)
signal = data["request_data"].get("signal", None)
message_object = data["request_data"].get("message_object", None)
Expand All @@ -157,15 +159,13 @@ def tts_generator_sample(self, data: dict):
signal.emit(message_object, is_bot, first_message, last_message)

def sd_generator_sample(self, data:dict):
if "sd_request" not in data or not self.sd:
if "options" not in data or "sd_request" not in data["options"] or not self.sd:
return
if self.model_type != "art":
Logger.info("Preparing Art model...")
do_unload_model = data["options"].get("unload_unused_model", False)
move_unused_model_to_cpu = data["options"].get("move_unused_model_to_cpu", False)
logger.info("Preparing Art model...")
self.model_type = "art"
-        self.do_unload_llm(data["request_data"], do_unload_model, move_unused_model_to_cpu)
-        Logger.info("Engine calling sd.generator_sample")
+        self.do_unload_llm()
+        logger.info("Engine calling sd.generator_sample")
self.sd.generator_sample(data)

def do_listen(self):
@@ -190,14 +190,14 @@ def unload_llm(self, request_data: dict, do_unload_model: bool, move_unused_mode
do_move_to_cpu = False

if do_move_to_cpu:
Logger.info("Moving LLM to CPU")
logger.info("Moving LLM to CPU")
self.llm.move_to_cpu()
self.clear_memory()
elif do_unload_model:
self.do_unload_llm()

def do_unload_llm(self):
Logger.info("Unloading LLM")
logger.info("Unloading LLM")
self.llm.unload_model()
self.llm.unload_tokenizer()
self.clear_memory()
@@ -218,11 +218,11 @@ def handle_message_code(self, message, code):
code = code or MessageCode.STATUS
if code == MessageCode.ERROR:
traceback.print_stack()
-            Logger.error(message)
+            logger.error(message)
elif code == MessageCode.WARNING:
-            Logger.warning(message)
+            logger.warning(message)
elif code == MessageCode.STATUS:
-            Logger.info(message)
+            logger.info(message)

message = ""
first_message = True
@@ -243,7 +243,7 @@ def send_message(self, message, code=None):
is_bot=True,
)
})
-        if code == MessageCode.TEXT_STREAMED:
+        elif code == MessageCode.TEXT_STREAMED:
self.message += message
self.current_message += message
self.message = self.message.replace("</s>", "")
@@ -273,6 +273,8 @@ def send_message(self, message, code=None):
self.first_message = True
self.message = ""
self.current_message = ""

+            # self.stt.do_listen()

# if is_end_of_sentence and not is_end_of_message:
# # split on all sentence enders
@@ -297,6 +299,11 @@ def send_message(self, message, code=None):
# })
# self.message = ""
# self.current_message = ""
+        elif code == MessageCode.IMAGE_GENERATED:
+            self.message_var.emit(dict(
+                code=code,
+                message=message
+            ))

current_message = ""

@@ -334,11 +341,14 @@ def clear_memory(self):
"""
Clear the GPU ram.
"""
Logger.info("Clearing memory")
logger.info("Clearing memory")
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()

def clear_llm_history(self):
+        if self.llm:
+            self.llm.clear_history()
-        self.llm.clear_history()

def stop(self):
self.stt.stop()
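
Reviewer note: the engine.py changes swap every classmethod-style Logger.info(...) call for a module-level instance created once at import time, so each subsystem tags its own messages. A minimal sketch of the pattern as this diff uses it (only the prefix keyword is taken from the diff; the handler wiring lives in logger.py below):

    from airunner.aihandler.logger import Logger

    logger = Logger(prefix="Engine")   # one instance per module, created at import time

    logger.info("Moving pipe to CPU")  # rendered with the "Engine" prefix via %(prefix)s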
30 changes: 3 additions & 27 deletions src/airunner/aihandler/llm.py
@@ -348,28 +348,12 @@ def handle_generate_request(self):
self._processing_request = True
kwargs = self.prepare_input_args()
self.do_set_seed(kwargs.get("seed"))

self.current_model_path = self.model_path
self.load_tokenizer()
self.load_streamer()
self.load_model()
if self.requested_generator_name == "visualqa":
self.load_processor()

# self.engine.send_message("Generating output")
# with torch.backends.cuda.sdp_kernel(
# enable_flash=True,
# enable_math=False,
# enable_mem_efficient=False
# ):
# with torch.no_grad():
# print("************** CALLING GENERATE")
# value = self.generate()
# print("VALUE", value)
# if self.callback:
# self.callback(value)
# else:
# self.engine.send_message(value, code=MessageCode.TEXT_GENERATED)
self._processing_request = True

def clear_conversation(self):
@@ -381,14 +365,7 @@ def do_generate(self, data):
self.process_data(data)
self.handle_request()
self.requested_generator_name = data["request_data"]["generator_name"]
-        return self.generate(
-            # app=self.app,
-            # endpoint=data["request_data"]["generator_name"],
-            # prompt=prompt,
-            # model=model_path,
-            # stream=data["request_data"]["stream"],
-            # images=[data["request_data"]["image"]],
-        )
+        return self.generate()

def generate(self):
Logger.info("Generating with LLM " + self.requested_generator_name)
@@ -399,7 +376,6 @@ def generate(self):
# Create an Environment object with the FileSystemLoader object
env = Environment(loader=file_loader)

-        # Load the template
# Load the template
chat_template = self.prompt_template#env.get_template('chat.j2')

@@ -481,8 +457,8 @@ def generate(self):
replaced = False
for new_text in self.streamer:
# strip all newlines from new_text
-            new_text = new_text.replace("\n", " ")
-            streamed_template += new_text
+            parsed_new_text = new_text.replace("\n", " ")
+            streamed_template += parsed_new_text
streamed_template = streamed_template.replace("<s> [INST]", "<s>[INST]")
# iterate over every character in rendered_template and
# check if we have the same character in streamed_template
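
The llm.py side is mostly dead-code removal: do_generate() now returns generate() with no arguments, and the streaming loop accumulates a newline-stripped copy of each chunk under a new name instead of clobbering new_text. A hedged sketch of the loop's resulting shape (the streamer and template names come from the diff; the rest is illustrative):

    streamed_template = ""
    for new_text in self.streamer:                      # yields decoded text chunks
        parsed_new_text = new_text.replace("\n", " ")   # flatten newlines before matching
        streamed_template += parsed_new_text
        streamed_template = streamed_template.replace("<s> [INST]", "<s>[INST]")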
68 changes: 42 additions & 26 deletions src/airunner/aihandler/logger.py
@@ -1,6 +1,17 @@
import logging
from airunner.aihandler.settings import LOG_LEVEL
import warnings
+import time


+class PrefixFilter(logging.Filter):
+    def __init__(self, prefix=''):
+        super().__init__()
+        self.prefix = prefix
+
+    def filter(self, record):
+        record.prefix = self.prefix
+        return True


class Logger:
@@ -16,63 +27,68 @@ class Logger:
WARNING = logging.WARNING
ERROR = logging.ERROR
FATAL = logging.FATAL

+    def __init__(self, *args, **kwargs):
+        self.prefix = kwargs.pop("prefix", "")
+        self.name = kwargs.pop("name", "AI Runner")
+        # Append current time to name to make it unique
+        self.name += f'_{time.time()}'
+        super().__init__()
+        self.logger = logging.getLogger(self.name)
+        self.formatter = logging.Formatter("%(asctime)s - AI RUNNER - %(levelname)s - %(prefix)s - %(message)s - %(lineno)d")
+        self.stream_handler = logging.StreamHandler()
+        self.stream_handler.setFormatter(self.formatter)

-    logger = logging.getLogger("AI Runner")
-    stream_handler = logging.StreamHandler()
+        # Add the prefix filter
+        self.stream_handler.addFilter(PrefixFilter(self.prefix))

-    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(lineno)d")
+        # Check if StreamHandler is already added
+        if not any(isinstance(handler, logging.StreamHandler) for handler in self.logger.handlers):
+            self.logger.addHandler(self.stream_handler)

-    @classmethod
-    def set_level(cls, level):
+        self.set_level(LOG_LEVEL)
+        logging.getLogger("lightning").setLevel(logging.WARNING)
+        logging.getLogger("lightning_fabric.utilities.seed").setLevel(logging.WARNING)

+    def set_level(self, level):
"""
Set the logging level
:param level:
:return: None
"""
if level is None:
level = logging.DEBUG
-        cls.logger.setLevel(level)
-        cls.stream_handler.setLevel(level)
+        self.logger.setLevel(level)
+        self.stream_handler.setLevel(level)

-    @classmethod
-    def debug(cls, msg):
+    def debug(self, msg):
"""
Log info message
:param msg:
:return: None
"""
-        cls.logger.debug(msg)
+        self.logger.debug(msg)

-    @classmethod
-    def info(cls, msg):
+    def info(self, msg):
"""
Log info message
:param msg:
:return: None
"""
-        cls.logger.info(msg)
+        self.logger.info(msg)

-    @classmethod
-    def warning(cls, msg):
+    def warning(self, msg):
"""
Log warning message
:param msg:
:return: None
"""
-        cls.logger.warning(msg)
+        self.logger.warning(msg)

-    @classmethod
-    def error(cls, msg):
+    def error(self, msg):
"""
Log error message
:param msg:
:return: None
"""
-        cls.logger.error(msg)
-
-
-Logger.set_level(LOG_LEVEL)
-Logger.stream_handler.setFormatter(Logger.formatter)
-Logger.logger.addHandler(Logger.stream_handler)
-logging.getLogger("lightning").setLevel(logging.WARNING)
-logging.getLogger("lightning_fabric.utilities.seed").setLevel(logging.WARNING)
+        self.logger.error(msg)
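
The core of the logger rewrite is the filter: a logging.Filter may annotate the record inside filter() and return True to keep it, which is what makes %(prefix)s resolvable in the format string. A self-contained sketch of that mechanism using only the standard library (names mirror the diff; the bootstrap below is illustrative, not the project's exact wiring):

    import logging

    class PrefixFilter(logging.Filter):
        def __init__(self, prefix=''):
            super().__init__()
            self.prefix = prefix

        def filter(self, record):
            record.prefix = self.prefix   # expose the prefix to the formatter
            return True                   # keep every record

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        "%(asctime)s - AI RUNNER - %(levelname)s - %(prefix)s - %(message)s - %(lineno)d"))
    handler.addFilter(PrefixFilter("Engine"))

    log = logging.getLogger("demo")
    log.addHandler(handler)
    log.setLevel(logging.INFO)
    log.info("Moving pipe to CPU")
    # -> ... - AI RUNNER - INFO - Engine - Moving pipe to CPU - <lineno>

One side effect worth noting: because __init__ appends time.time() to the logger name, every Logger() call binds a fresh underlying logging.Logger, so the duplicate-StreamHandler guard will rarely find an existing handler.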
7 changes: 3 additions & 4 deletions src/airunner/aihandler/mixins/lora_mixin.py
@@ -1,9 +1,8 @@
import os
-from airunner.aihandler.settings import LOG_LEVEL
-from airunner.aihandler.logger import Logger as logger
-import logging
-logging.disable(LOG_LEVEL)
-logger.set_level(logger.DEBUG)
+from airunner.aihandler.logger import Logger

logger = Logger(prefix="LoraMixin")


class LoraMixin:
8 changes: 2 additions & 6 deletions src/airunner/aihandler/mixins/memory_efficient_mixin.py
@@ -1,15 +1,11 @@
import functools
import os
import torch
-from airunner.aihandler.settings import LOG_LEVEL
-from airunner.aihandler.logger import Logger as logger
-import logging
+from airunner.aihandler.logger import Logger
from dataclasses import dataclass
import tomesd


-logging.disable(LOG_LEVEL)
-logger.set_level(logger.DEBUG)
+logger = Logger(prefix="MemoryEfficientMixin")



8 changes: 3 additions & 5 deletions src/airunner/aihandler/mixins/merge_mixin.py
@@ -1,9 +1,7 @@
import os
-from airunner.aihandler.settings import LOG_LEVEL
-from airunner.aihandler.logger import Logger as logger
-import logging
-logging.disable(LOG_LEVEL)
-logger.set_level(logger.DEBUG)
+from airunner.aihandler.logger import Logger

logger = Logger(prefix="MergeMixin")


class MergeMixin:
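
All three mixins get the same header treatment: the old modules called logging.disable(LOG_LEVEL), which silences every logging call at or below that level process-wide, and then set a level on the aliased Logger class; the new ones simply instantiate a prefixed logger. A before/after sketch (module names as in the diff; the debug message is illustrative):

    # Before: process-wide suppression plus classmethod-style configuration
    #   logging.disable(LOG_LEVEL)
    #   logger.set_level(logger.DEBUG)

    # After: one prefixed instance per module
    from airunner.aihandler.logger import Logger

    logger = Logger(prefix="MergeMixin")
    logger.debug("Merging checkpoint weights")   # illustrative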