From 08785865db53b24254d4c0162e6ff5b1235e3536 Mon Sep 17 00:00:00 2001
From: w4ffl35
Date: Fri, 24 Mar 2023 21:24:53 -0600
Subject: [PATCH] fixes offline mode and makes some improvements to chatbot

---
 chatai/chatbot.py      | 44 +++++++++++++++++-------------
 chatai/conversation.py | 48 ++++++++++++++++++++++++++-----
 chatai/llmrunner.py    | 64 ++++++++++++++++++++++++++----------------
 3 files changed, 106 insertions(+), 50 deletions(-)

diff --git a/chatai/chatbot.py b/chatai/chatbot.py
index 352dfb6..53b8329 100644
--- a/chatai/chatbot.py
+++ b/chatai/chatbot.py
@@ -26,20 +26,25 @@ def message_handler(self, *args, **kwargs):
         response = message["response"]
         response = response.replace("", "")
         response = response.replace("", "")
-        incomplete = False
-        if "" not in response:
-            # remove all tokens after and itself
-            incomplete = True
+        formatted_response = ""
+        special_character = ""
+        if type == "do_action":
+            self.ui.generated_text.appendPlainText(response)
         else:
-            response = response[: response.find("")]
-        response = response.strip()
-
-        if not incomplete:
-            formatted_response = f"{botname} says: \"{response}\"\n"
-            self.conversation.add_message(botname, response)
-            self.ui.generated_text.appendPlainText(formatted_response)
-        else:
-            self.chatbot_generate()
+            incomplete = False
+            if special_character not in response:
+                # remove all tokens after and itself
+                incomplete = True
+            else:
+                response = response[: response.find(special_character)]
+            response = response.strip()
+
+            if not incomplete:
+                formatted_response = f"{botname} says: \"{response}\"\n"
+                self.conversation.add_message(botname, response)
+                self.ui.generated_text.appendPlainText(formatted_response)
+            else:
+                self.chatbot_generate()
         self.stop_progress_bar()
         self.enable_buttons()
 
@@ -144,13 +149,13 @@ def initialize_form(self):
 
     def advanced_settings(self):
         HERE = os.path.dirname(os.path.abspath(__file__))
-        advanced_settings_window = uic.loadUi(os.path.join(HERE, "pyqt/llmrunner/advanced_settings.ui"))
+        advanced_settings_window = uic.loadUi(os.path.join(HERE, "pyqt/advanced_settings.ui"))
         advanced_settings_window.exec()
 
     def about(self):
         # display pyqt/about.ui popup window
         HERE = os.path.dirname(os.path.abspath(__file__))
-        about_window = uic.loadUi(os.path.join(HERE, "pyqt/llmrunner/about.ui"))
+        about_window = uic.loadUi(os.path.join(HERE, "pyqt/about.ui"))
         about_window.setWindowTitle(f"About Chat AI")
         about_window.title.setText(f"Chat AI v{VERSION}")
         about_window.exec()
@@ -225,15 +230,16 @@ def chatbot_generate(self):
         username = self.ui.username.text()
         botname = self.ui.botname.text()
         user_input = self.ui.prompt.text()
+        llm_action = self.ui.action.currentText()
        self.ui.prompt.setText("")
-        self.conversation.add_message(username, user_input)
+        if llm_action == "action":
+            llm_action = "do_action"
+        else:
+            self.conversation.add_message(username, user_input)
         self.ui.generated_text.setPlainText(f"{self.conversation.dialogue}")
         self.ui.prompt.setText("")
         properties = self.prep_properties()
         # get current action and set it on properties
-        llm_action = self.ui.action.currentText()
-        if llm_action == "action":
-            llm_action = "do_action"
         self.client.message = {
             "action": "llm",
             "type": llm_action,
diff --git a/chatai/conversation.py b/chatai/conversation.py
index a1b2eef..d230d41 100644
--- a/chatai/conversation.py
+++ b/chatai/conversation.py
@@ -1,13 +1,22 @@
 class Conversation:
     summary_length = 20
+    special_characters = [
+        "",
+        "",
+        ""
+    ]
     def __init__(self, app):
         self.app = app
         self.converation_summary = ""
         self._dialogue = []
         self._summaries = []
 
-    @property
-    def summary_prompt(self):
+    def strip_special_characters(self, string):
+        for special_character in self.special_characters:
+            string = string.replace(special_character, "")
+        return string
+
+    def summary_prompt(self, username, botname):
         # print("Summarizing")
         # self.app.runner.load_summarizer()
         # prompt = self.dialogue
@@ -22,7 +31,7 @@ def summary_prompt(self):
         #
         # self.app.runner.load_model(self.app.runner.model_name)
         # return summary
-        return f"{self.dialogue}\n\nSummarize the conversation: "
+        return f"{self.dialogue} Summarize:"
 
     @property
     def dialogue_length(self):
@@ -33,6 +42,14 @@ def do_summary(self):
         #return self.dialogue_length >= self.summary_length
         return False
 
+    def add_action(self, username, action, do_summary=False):
+        self._dialogue.append({
+            "username": username,
+            "action": action
+        })
+
+
+
     def add_message(self, username, message):
         self._dialogue.append({
             "username": username,
@@ -41,16 +58,29 @@ def add_message(self, username, message):
 
     @property
     def dialogue(self):
-        return "\n".join([f"{message['username']} says: \"{message['message']}\"" for message in self._dialogue])
+        formatted_messages = []
+        for message in self._dialogue:
+            if "action" in message:
+                formatted_messages.append(f" {message['action']}")
+            else:
+                formatted_messages.append(f"{message['username']} says: \"{message['message']}\"")
+        return "\n".join(formatted_messages)
 
     @dialogue.setter
     def dialogue(self, dialogue):
         self._dialogue = dialogue
+    @property
+    def dialogue_no_action(self):
+        formatted_messages = []
+        for message in self._dialogue:
+            if "action" not in message:
+                formatted_messages.append(f"{message['username']} says: \"{message['message']}\"")
+        return "\n".join(formatted_messages)
+
+
 
     def update_summary(self, summary):
         self.converation_summary = summary
-        # keep last two items in dialogue
-        self._dialogue = self._dialogue[-2:]
 
     def format_user_sentiment_prompt(self, botname, username):
         prompt = f"Context: {botname} and {username} are having a conversation. {self.converation_summary}\n\n{self.dialogue}\n\n{username}'s sentiment is "
@@ -60,11 +90,15 @@ def format_prompt(self, botname, username, mood, user_sentiment):
         prompt = f"Context: {botname} and {username} are having a conversation.\n{self.converation_summary}\n{botname}'s mood: {mood}\n{username}'s sentiment: {user_sentiment}\n\n{self.dialogue}\n{botname} says: \""
         return prompt
 
+    def format_random_prompt(self, botname, username, mood, user_sentiment, random_prompt):
+        prompt = f"Context: {botname} and {username} are having a conversation.\n{self.converation_summary}\nGenerate a random event that could happen next. {random_prompt}: "
+        return prompt
+
     def format_conversation(self, user_name):
         pass
 
     def format_action_prompt(self, botname, username, action):
-        return f"Context: {botname} and {username} are having a conversation. suddenly {username} {action}. What does {botname} do? "
+        return f"Context: {botname} and {username} are having a conversation. {username} {action}. What happens next? "
 
     def get_bot_mood_prompt(self, botname, username):
         return f"Context: {botname} and {username} are having a conversation. {self.converation_summary}\n\n{self.dialogue}\n\n{botname}'s mood is "
diff --git a/chatai/llmrunner.py b/chatai/llmrunner.py
index e979971..a6811fe 100644
--- a/chatai/llmrunner.py
+++ b/chatai/llmrunner.py
@@ -115,7 +115,7 @@ def clear_gpu_cache(self):
         torch.cuda.empty_cache()
         gc.collect()
 
-    def load_model(self, model_name, pipeline_type=None, offline=False):
+    def load_model(self, model_name, pipeline_type=None, offline=True):
         if self.model_name == model_name and self.model and self.tokenizer:
             return
         local_files_only = offline
@@ -225,21 +225,18 @@ def generate_chat_prompt(self, user_input, **kwargs):
 
         # summarize the conversation if needed
        if conversation.do_summary:
-            results = self.generate(prompt=conversation.summary_prompt, seed=self.seed, **properties, return_result=True)
+            results = self.generate(prompt=conversation.summary_prompt(username, botname), seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
             summary = results
             conversation.update_summary(summary)
 
-        # let's tokenize the conversation
-
-
         # get the bot's mood
         mood_prompt = conversation.get_bot_mood_prompt(botname, username)
-        mood_results = self.generate(mood_prompt, seed=self.seed, **properties, return_result=True)
+        mood_results = self.generate(mood_prompt, seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
         mood = mood_results
 
         # get the user's sentiment
         sentiment_prompt = conversation.format_user_sentiment_prompt(botname, username)
-        sentiment_results = self.generate(sentiment_prompt, seed=self.seed, **properties, return_result=True)
+        sentiment_results = self.generate(sentiment_prompt, seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
         user_sentiment = sentiment_results
 
         # setup the prompt for the bot response
@@ -250,7 +247,33 @@ def generate_action_prompt(self, user_input, **kwargs):
         botname = kwargs.get("botname")
         username = kwargs.get("username")
         conversation = kwargs.get("conversation")
-        return conversation.format_action_prompt(botname, username, user_input)
+        action_prompt = conversation.format_action_prompt(botname, username, user_input)
+        action_results = self.generate(action_prompt, seed=self.seed, **kwargs.get("properties", {}), return_result=True, skip_special_tokens=True)
+
+        # format action and reaction
+        formatted_action_result = f"{username} {user_input}."
+        action_results = action_results.strip()
+        # if not action_results.startswith(botname):
+        #     formatted_action_result += f" {botname} {action_results}"
+        # else:
+        formatted_action_result += f" {action_results}"
+
+        conversation.add_action(username, formatted_action_result)
+        summary_prompt = conversation.summary_prompt(username, botname)
+        properties = kwargs.get("properties", {})
+        properties["temperature"] = 1
+        properties["top_k"] = 40
+        properties["repetition_penalty"] = 10.0
+        properties["top_p"] = 0.9
+        properties["num_beams"] = 3
+        summary_results = self.generate(prompt=summary_prompt, seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
+        conversation.update_summary(summary_results)
+        # Add summary prompt to conversation
+        self.set_message({
+            "type": "do_action",
+            "botname": botname,
+            "response": formatted_action_result
+        })
 
     def generate(
         self,
@@ -315,27 +338,20 @@ def generate(
         if not prompt:
             if type == "chat":
                 prompt = self.generate_chat_prompt(user_input, properties=properties, botname=botname,username=username,conversation=conversation)
-                top_k = 20
+                top_k = 30
                 top_p = 0.9
-                num_beams = 3
+                num_beams = 6
-                repetition_penalty = 1.5
+                repetition_penalty = 20.0
                 early_stopping = True
-                max_length = 100
+                max_length = 512
-                min_length = 0
+                min_length = 30
                 temperature = 1.0
                 skip_special_tokens = False
             elif type == "do_action":
-                prompt = self.generate_action_prompt(user_input, properties=properties, botname=botname,
-                                                     username=username, conversation=conversation)
-                top_k = 20
-                top_p = 0.9
-                num_beams = 3
-                repetition_penalty = 1.5
-                early_stopping = True
-                max_length = 100
-                min_length = 0
-                temperature = 1.0
-                skip_special_tokens = False
+                self.generate_action_prompt(
+                    user_input, properties=properties, botname=botname,
+                    username=username, conversation=conversation)
+                return
             elif type == "generate_characters":
                 self.generate_character_prompt(**properties)
                 return