Skip to content

Commit

Permalink
Fixes offline mode and makes some improvements to the chatbot
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Mar 25, 2023
1 parent 287ebd3 commit 0878586
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 50 deletions.
44 changes: 25 additions & 19 deletions chatai/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,25 @@ def message_handler(self, *args, **kwargs):
response = message["response"]
response = response.replace("<pad>", "")
response = response.replace("<unk>", "")
incomplete = False
if "</s>" not in response:
# remove all tokens after </s> and </s> itself
incomplete = True
formatted_response = ""
special_character = "</s>"
if type == "do_action":
self.ui.generated_text.appendPlainText(response)
else:
response = response[: response.find("</s>")]
response = response.strip()

if not incomplete:
formatted_response = f"{botname} says: \"{response}\"\n"
self.conversation.add_message(botname, response)
self.ui.generated_text.appendPlainText(formatted_response)
else:
self.chatbot_generate()
incomplete = False
if special_character not in response:
# remove all tokens after </s> and </s> itself
incomplete = True
else:
response = response[: response.find(special_character)]
response = response.strip()

if not incomplete:
formatted_response = f"{botname} says: \"{response}\"\n"
self.conversation.add_message(botname, response)
self.ui.generated_text.appendPlainText(formatted_response)
else:
self.chatbot_generate()

self.stop_progress_bar()
self.enable_buttons()
Expand Down Expand Up @@ -144,13 +149,13 @@ def initialize_form(self):

def advanced_settings(self):
    """Open the advanced-settings dialog (pyqt/advanced_settings.ui) modally."""
    # Resolve the .ui file relative to this module so loading works
    # regardless of the current working directory.
    HERE = os.path.dirname(os.path.abspath(__file__))
    advanced_settings_window = uic.loadUi(os.path.join(HERE, "pyqt/advanced_settings.ui"))
    advanced_settings_window.exec()

def about(self):
    """Open the About dialog (pyqt/about.ui) modally, showing the app version."""
    # Resolve the .ui file relative to this module, not the CWD.
    HERE = os.path.dirname(os.path.abspath(__file__))
    about_window = uic.loadUi(os.path.join(HERE, "pyqt/about.ui"))
    # Plain string: the original used an f-string with no placeholders.
    about_window.setWindowTitle("About Chat AI")
    about_window.title.setText(f"Chat AI v{VERSION}")
    about_window.exec()
Expand Down Expand Up @@ -225,15 +230,16 @@ def chatbot_generate(self):
username = self.ui.username.text()
botname = self.ui.botname.text()
user_input = self.ui.prompt.text()
llm_action = self.ui.action.currentText()
self.ui.prompt.setText("")
self.conversation.add_message(username, user_input)
if llm_action == "action":
llm_action = "do_action"
else:
self.conversation.add_message(username, user_input)
self.ui.generated_text.setPlainText(f"{self.conversation.dialogue}")
self.ui.prompt.setText("")
properties = self.prep_properties()
# get current action and set it on properties
llm_action = self.ui.action.currentText()
if llm_action == "action":
llm_action = "do_action"
self.client.message = {
"action": "llm",
"type": llm_action,
Expand Down
48 changes: 41 additions & 7 deletions chatai/conversation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
class Conversation:
    """Tracks dialogue history, narrated actions, and a running summary."""
    # Number of dialogue entries after which a summary would be generated.
    summary_length = 20
    # Model control tokens stripped from generated text before display.
    special_characters = [
        "</s>",
        "<pad>",
        "<unk>"
    ]
    def __init__(self, app):
        # Back-reference to the owning application object.
        self.app = app
        # NOTE(review): attribute name has a typo ("converation") — kept as-is
        # because other methods in this class reference this exact spelling.
        self.converation_summary = ""
        # Raw dialogue entries: dicts with "username" plus either "message"
        # (a spoken line) or "action" (a pre-formatted narrated event).
        self._dialogue = []
        self._summaries = []

@property
def summary_prompt(self):
def strip_special_characters(self, string):
for special_character in self.special_characters:
string = string.replace(special_character, "")
return string

def summary_prompt(self, username, botname):
# print("Summarizing")
# self.app.runner.load_summarizer()
# prompt = self.dialogue
Expand All @@ -22,7 +31,7 @@ def summary_prompt(self):
#
# self.app.runner.load_model(self.app.runner.model_name)
# return summary
return f"{self.dialogue}\n\nSummarize the conversation: </s>"
return f"<extra_id_0>{self.dialogue} <extra_id_1>Summarize:"

@property
def dialogue_length(self):
Expand All @@ -33,6 +42,14 @@ def do_summary(self):
#return self.dialogue_length >= self.summary_length
return False

def add_action(self, username, action, do_summary=False):
self._dialogue.append({
"username": username,
"action": action
})



def add_message(self, username, message):
self._dialogue.append({
"username": username,
Expand All @@ -41,16 +58,29 @@ def add_message(self, username, message):

@property
def dialogue(self):
return "\n".join([f"{message['username']} says: \"{message['message']}\"" for message in self._dialogue])
formatted_messages = []
for message in self._dialogue:
if "action" in message:
formatted_messages.append(f" {message['action']}")
else:
formatted_messages.append(f"{message['username']} says: \"{message['message']}\"")
return "\n".join(formatted_messages)

    @dialogue.setter
    def dialogue(self, dialogue):
        # Replace the raw dialogue entry list wholesale.
        self._dialogue = dialogue

@property
def dialogue_no_action(self):
formatted_messages = []
for message in self._dialogue:
if "action" not in message:
formatted_messages.append(f"{message['username']} says: \"{message['message']}\"")
return "\n".join(formatted_messages)


    def update_summary(self, summary):
        """Store the latest summary and prune history to the two newest entries."""
        self.converation_summary = summary
        # keep last two items in dialogue
        self._dialogue = self._dialogue[-2:]

def format_user_sentiment_prompt(self, botname, username):
prompt = f"Context: {botname} and {username} are having a conversation. {self.converation_summary}\n\n{self.dialogue}\n\n{username}'s sentiment is "
Expand All @@ -60,11 +90,15 @@ def format_prompt(self, botname, username, mood, user_sentiment):
prompt = f"Context: {botname} and {username} are having a conversation.\n{self.converation_summary}\n{botname}'s mood: {mood}\n{username}'s sentiment: {user_sentiment}\n\n{self.dialogue}\n{botname} says: \""
return prompt

    def format_random_prompt(self, botname, username, mood, user_sentiment, random_prompt):
        """Build a T5-style (<extra_id_N>) prompt asking for a random next event.

        NOTE(review): ``mood`` and ``user_sentiment`` are accepted but not
        interpolated into the prompt — confirm whether they should be.
        """
        prompt = f"Context: {botname} and {username} are having a conversation.\n{self.converation_summary}\n<extra_id_0>Generate a random event that could happen next. <extra_id_1>{random_prompt}: "
        return prompt

    def format_conversation(self, user_name):
        # TODO: not yet implemented — intentionally a no-op stub.
        pass

def format_action_prompt(self, botname, username, action):
return f"Context: {botname} and {username} are having a conversation. suddenly {username} {action}. What does {botname} do? "
return f"<extra_id_0>Context: {botname} and {username} are having a conversation. {username} {action}. <extra_id_1>What happens next? "

    def get_bot_mood_prompt(self, botname, username):
        """Build a completion prompt asking the model for the bot's current mood."""
        return f"Context: {botname} and {username} are having a conversation. {self.converation_summary}\n\n{self.dialogue}\n\n{botname}'s mood is "
64 changes: 40 additions & 24 deletions chatai/llmrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def clear_gpu_cache(self):
torch.cuda.empty_cache()
gc.collect()

def load_model(self, model_name, pipeline_type=None, offline=False):
def load_model(self, model_name, pipeline_type=None, offline=True):
if self.model_name == model_name and self.model and self.tokenizer:
return
local_files_only = offline
Expand Down Expand Up @@ -225,21 +225,18 @@ def generate_chat_prompt(self, user_input, **kwargs):

# summarize the conversation if needed
if conversation.do_summary:
results = self.generate(prompt=conversation.summary_prompt, seed=self.seed, **properties, return_result=True)
results = self.generate(prompt=conversation.summary_prompt(username, botname), seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
summary = results
conversation.update_summary(summary)

# let's tokenize the conversation


# get the bot's mood
mood_prompt = conversation.get_bot_mood_prompt(botname, username)
mood_results = self.generate(mood_prompt, seed=self.seed, **properties, return_result=True)
mood_results = self.generate(mood_prompt, seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
mood = mood_results

# get the user's sentiment
sentiment_prompt = conversation.format_user_sentiment_prompt(botname, username)
sentiment_results = self.generate(sentiment_prompt, seed=self.seed, **properties, return_result=True)
sentiment_results = self.generate(sentiment_prompt, seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
user_sentiment = sentiment_results

# setup the prompt for the bot response
Expand All @@ -250,7 +247,33 @@ def generate_action_prompt(self, user_input, **kwargs):
botname = kwargs.get("botname")
username = kwargs.get("username")
conversation = kwargs.get("conversation")
return conversation.format_action_prompt(botname, username, user_input)
action_prompt = conversation.format_action_prompt(botname, username, user_input)
action_results = self.generate(action_prompt, seed=self.seed, **kwargs.get("properties", {}), return_result=True, skip_special_tokens=True)

# format action and reaction
formatted_action_result = f"{username} {user_input}."
action_results = action_results.strip()
# if not action_results.startswith(botname):
# formatted_action_result += f" {botname} {action_results}"
# else:
formatted_action_result += f" {action_results}"

conversation.add_action(username, formatted_action_result)
summary_prompt = conversation.summary_prompt(username, botname)
properties = kwargs.get("properties", {})
properties["temperature"] = 1
properties["top_k"] = 40
properties["repetition_penalty"] = 10.0
properties["top_p"] = 0.9
properties["num_beams"] = 3
summary_results = self.generate(prompt=summary_prompt, seed=self.seed, **properties, return_result=True, skip_special_tokens=True)
conversation.update_summary(summary_results)
# Add summary prompt to conversation
self.set_message({
"type": "do_action",
"botname": botname,
"response": formatted_action_result
})

def generate(
self,
Expand Down Expand Up @@ -315,27 +338,20 @@ def generate(
if not prompt:
if type == "chat":
prompt = self.generate_chat_prompt(user_input, properties=properties, botname=botname,username=username,conversation=conversation)
top_k = 20
top_k = 30
top_p = 0.9
num_beams = 3
repetition_penalty = 1.5
num_beams = 6
repetition_penalty = 20.0
early_stopping = True
max_length = 100
min_length = 0
max_length = 512
min_length = 30
temperature = 1.0
skip_special_tokens = False
elif type == "do_action":
prompt = self.generate_action_prompt(user_input, properties=properties, botname=botname,
username=username, conversation=conversation)
top_k = 20
top_p = 0.9
num_beams = 3
repetition_penalty = 1.5
early_stopping = True
max_length = 100
min_length = 0
temperature = 1.0
skip_special_tokens = False
self.generate_action_prompt(
user_input, properties=properties, botname=botname,
username=username, conversation=conversation)
return
elif type == "generate_characters":
self.generate_character_prompt(**properties)
return
Expand Down

0 comments on commit 0878586

Please sign in to comment.