Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a user profile selector. #6550

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 97 additions & 9 deletions modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_encoded_length,
get_max_prompt_length
)
from modules.utils import delete_file, get_available_characters, save_file
from modules.utils import delete_file, get_available_characters, get_available_users, save_file

# Copied from the Transformers library
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
Expand Down Expand Up @@ -675,6 +675,13 @@ def update_character_menu_after_deletion(idx):
return gr.update(choices=characters, value=characters[idx])


def update_user_menu_after_deletion(idx):
users = utils.get_available_users()
idx = min(int(idx), len(users) - 1)
idx = max(0, idx)
return gr.update(choices=users, value=users[idx])


def load_history(unique_id, character, mode):
p = get_history_file_path(unique_id, character, mode)

Expand Down Expand Up @@ -739,6 +746,8 @@ def load_character(character, name1, name2):
context = greeting = ""
greeting_field = 'greeting'
picture = None
your_picture = None
profile_name = name1

filepath = None
for extension in ["yml", "yaml", "json"]:
Expand All @@ -759,6 +768,7 @@ def load_character(character, name1, name2):
path.unlink()

picture = generate_pfp_cache(character)
your_picture = upload_your_profile_picture(picture, profile_name)

# Finding the bot's name
for k in ['name', 'bot', '<|bot|>', 'char_name']:
Expand Down Expand Up @@ -874,18 +884,27 @@ def check_tavern_character(img):
return _json['name'], _json['description'], _json, gr.update(interactive=True)


def upload_your_profile_picture(img):
def upload_your_profile_picture(img, profile_name):
cache_folder = Path(shared.args.disk_cache_dir)
if not cache_folder.exists():
cache_folder.mkdir()

if img is None:
if Path(f"{cache_folder}/pfp_me.png").exists():
Path(f"{cache_folder}/pfp_me.png").unlink()
if profile_name is None:
if img is None:
if Path(f"{cache_folder}/pfp_me.png").exists():
Path(f"{cache_folder}/pfp_me.png").unlink()
else:
img = make_thumbnail(img)
img.save(Path(f'{cache_folder}/pfp_me.png'))
logger.info(f'Profile picture saved to "{cache_folder}/pfp_me.png"')
else:
img = make_thumbnail(img)
img.save(Path(f'{cache_folder}/pfp_me.png'))
logger.info(f'Profile picture saved to "{cache_folder}/pfp_me.png"')
for img_path in [Path(f"users/{profile_name}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
if img_path.exists():
with Image.open(img_path) as img:
img = make_thumbnail(img)
img.save(Path(f'{cache_folder}/pfp_me.png'), format='PNG')

return None


def generate_character_yaml(name, greeting, context):
Expand All @@ -899,6 +918,16 @@ def generate_character_yaml(name, greeting, context):
return yaml.dump(data, sort_keys=False, width=float("inf"))


def generate_user_yaml(name1, user_bio):
data = {
'name': name1,
'user_bio': user_bio,
}

data = {k: v for k, v in data.items() if v} # Strip falsy
return yaml.dump(data, sort_keys=False, width=float("inf"))


def generate_instruction_template_yaml(instruction_template):
data = {
'instruction_template': instruction_template
Expand All @@ -921,13 +950,34 @@ def save_character(name, greeting, context, picture, filename):
logger.info(f'Saved {path_to_img}.')


def save_user(name1, user_bio, your_picture, filename):
if filename == "":
logger.error("The filename is empty, so the user will not be saved.")
return

data = generate_user_yaml(name1, user_bio)
filepath = Path(f'users/{filename}.yaml')
save_file(filepath, data)
path_to_img = Path(f'users/{filename}.png')
if your_picture is not None:
your_picture.save(path_to_img)
logger.info(f'Saved {path_to_img}.')


def delete_character(name, instruct=False):
for extension in ["yml", "yaml", "json"]:
delete_file(Path(f'characters/{name}.{extension}'))

delete_file(Path(f'characters/{name}.png'))


def delete_user(name, instruct=False):
for extension in ["yml", "yaml", "json"]:
delete_file(Path(f'users/{name}.{extension}'))

delete_file(Path(f'users/{name}.png'))


def jinja_template_from_old_format(params, verbose=False):
MASTER_TEMPLATE = """
{%- set ns = namespace(found=false) -%}
Expand Down Expand Up @@ -1188,6 +1238,37 @@ def handle_character_menu_change(state):
]


def load_user_profile(profile_name):
user_folder = Path('users')
base_file = user_folder / profile_name
yaml_file = next((base_file.with_suffix(ext) for ext in ['.yaml', '.yml'] if base_file.with_suffix(ext).exists()), None)
img_file = base_file.with_suffix('.png')

data = {}
if yaml_file:
with open(yaml_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)

picture = None
if img_file.exists():
from PIL import Image
picture = Image.open(img_file)
upload_your_profile_picture(picture, profile_name)

return {
'name1': data.get('name', ''),
'user_bio': data.get('user_bio', ''),
'your_picture': picture
}


def update_user_fields(profile_name):
if profile_name:
user_data = load_user_profile(profile_name)
return user_data['name1'], user_data['user_bio'], user_data['your_picture']
return '', '', None


def handle_mode_change(state):
history = load_latest_history(state)
histories = find_all_histories_with_first_prompts(state)
Expand Down Expand Up @@ -1216,6 +1297,13 @@ def handle_save_character_click(name2):
]


def handle_save_user_click(name1):
return [
name1,
gr.update(visible=True)
]


def handle_load_template_click(instruction_template):
output = load_instruction_template(instruction_template)
return [
Expand Down Expand Up @@ -1243,7 +1331,7 @@ def handle_delete_template_click(template):


def handle_your_picture_change(picture, state):
upload_your_profile_picture(picture)
upload_your_profile_picture(picture, profile_name=None)
html = redraw_html(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], reset_cache=True)

return html
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
'prompt-default': 'QA',
'prompt-notebook': 'QA',
'character': 'Assistant',
'user': 'You',
'name1': 'You',
'user_bio': '',
'custom_system_message': '',
Expand Down
2 changes: 2 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def list_interface_input_elements():
'chat_style',
'chat-instruct_command',
'character_menu',
'user_menu',
'name2',
'context',
'greeting',
Expand Down Expand Up @@ -295,6 +296,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state):
output['prompt-default'] = state['prompt_menu-default']
output['prompt-notebook'] = state['prompt_menu-notebook']
output['character'] = state['character_menu']
output['user'] = state['user_menu']
output['default_extensions'] = extensions_list
output['seed'] = int(output['seed'])
output['show_controls'] = show_controls
Expand Down
17 changes: 15 additions & 2 deletions modules/ui_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from modules.utils import gradio

inputs = ('Chat input', 'interface_state')
reload_arr = ('history', 'name1', 'name2', 'mode', 'chat_style', 'character_menu')
reload_arr = ('history', 'name1', 'name2', 'mode', 'chat_style', 'character_menu', 'user_menu')


def create_ui():
Expand Down Expand Up @@ -114,7 +114,13 @@ def create_chat_settings_ui():
shared.gradio['greeting'] = gr.Textbox(value='', lines=5, label='Greeting', elem_classes=['add_scrollbar'])

with gr.Tab("User"):
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Name')
with gr.Row():
shared.gradio['user_menu'] = gr.Dropdown(value=shared.settings['name1'], choices=utils.get_available_users(), label='User', elem_id='user-menu', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['user_menu'], lambda: None, lambda: {'choices': utils.get_available_users()}, 'refresh-button', interactive=not mu)
shared.gradio['save_user'] = gr.Button('💾', elem_classes='refresh-button', elem_id="save-user", interactive=not mu)
shared.gradio['delete_user'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu)

shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='User\'s name')
shared.gradio['user_bio'] = gr.Textbox(value=shared.settings['user_bio'], lines=10, label='Description', info='Here you can optionally write a description of yourself.', placeholder='{{user}}\'s personality: ...', elem_classes=['add_scrollbar'])

with gr.Tab('Chat history'):
Expand Down Expand Up @@ -281,6 +287,11 @@ def create_event_handlers():
chat.handle_character_menu_change, gradio('interface_state'), gradio('history', 'display', 'name1', 'name2', 'character_picture', 'greeting', 'context', 'unique_id'), show_progress=False).then(
None, None, None, js=f'() => {{{ui.update_big_picture_js}; updateBigPicture()}}')

shared.gradio['user_menu'].change(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.update_user_fields, inputs=shared.gradio['user_menu'], outputs=[shared.gradio['name1'], shared.gradio['user_bio'], shared.gradio['your_picture']]
)

shared.gradio['mode'].change(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_mode_change, gradio('interface_state'), gradio('history', 'display', 'chat_style', 'chat-instruct_command', 'unique_id'), show_progress=False).then(
Expand All @@ -292,6 +303,8 @@ def create_event_handlers():
# Save/delete a character
shared.gradio['save_character'].click(chat.handle_save_character_click, gradio('name2'), gradio('save_character_filename', 'character_saver'), show_progress=False)
shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'), show_progress=False)
shared.gradio['save_user'].click(chat.handle_save_user_click, gradio('name1'), gradio('save_user_filename', 'user_saver'), show_progress=False)
shared.gradio['delete_user'].click(lambda: gr.update(visible=True), None, gradio('user_deleter'), show_progress=False)
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False)
shared.gradio['save_template'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
Expand Down
47 changes: 47 additions & 0 deletions modules/ui_file_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ def create_ui():
shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)

# User saver/deleter
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']:
shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info='The user will be saved to your users/ folder with this base filename.')
with gr.Row():
shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)

with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_deleter']:
gr.Markdown('Confirm the user deletion?')
with gr.Row():
shared.gradio['delete_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['delete_user_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)

# Preset saver
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']:
shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info='The preset will be saved to your presets/ folder with this base filename.')
Expand All @@ -62,12 +75,16 @@ def create_event_handlers():
shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root', 'delete_filename'), gradio('file_deleter'), show_progress=False)
shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False)
shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False)
shared.gradio['save_user_confirm'].click(handle_save_user_confirm_click, gradio('name1', 'user_bio', 'your_picture', 'save_user_filename'), gradio('user_menu', 'user_saver'), show_progress=False)
shared.gradio['delete_user_confirm'].click(handle_delete_user_confirm_click, gradio('user_menu'), gradio('user_menu', 'user_deleter'), show_progress=False)

shared.gradio['save_preset_cancel'].click(lambda: gr.update(visible=False), None, gradio('preset_saver'), show_progress=False)
shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver'))
shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter'))
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'), show_progress=False)
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'), show_progress=False)
shared.gradio['save_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_saver'), show_progress=False)
shared.gradio['delete_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_deleter'), show_progress=False)


def handle_save_preset_confirm_click(filename, contents):
Expand Down Expand Up @@ -118,6 +135,21 @@ def handle_save_character_confirm_click(name2, greeting, context, character_pict
]


def handle_save_user_confirm_click(name1, user_bio, your_picture, filename):
try:
chat.save_user(name1, user_bio, your_picture, filename)
available_users = utils.get_available_users()
output = gr.update(choices=available_users, value=filename)
except Exception:
output = gr.update()
traceback.print_exc()

return [
output,
gr.update(visible=False)
]


def handle_delete_character_confirm_click(character):
try:
index = str(utils.get_available_characters().index(character))
Expand All @@ -133,6 +165,21 @@ def handle_delete_character_confirm_click(character):
]


def handle_delete_user_confirm_click(user):
try:
index = str(utils.get_available_users().index(user))
chat.delete_user(user)
output = chat.update_user_menu_after_deletion(index)
except Exception:
output = gr.update()
traceback.print_exc()

return [
output,
gr.update(visible=False)
]


def handle_save_preset_click(state):
contents = presets.generate_preset_yaml(state)
return [
Expand Down
8 changes: 8 additions & 0 deletions modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def get_available_characters():
return sorted(set((k.stem for k in paths)), key=natural_keys)


def get_available_users():
if not Path('users').exists():
Path('users').mkdir()

paths = (x for x in Path('users').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return sorted(set((k.stem for k in paths)), key=natural_keys)


def get_available_instruction_templates():
path = "instruction-templates"
paths = []
Expand Down