diff --git a/README.md b/README.md index 97c307199..d8dfdf207 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ --- -### Stable Diffusion and Kandinsky on your own hardware +### Stable Diffusion on your own hardware No web server to run, additional requirements to install or technical knowledge required. diff --git a/src/airunner/aihandler/llm.py b/src/airunner/aihandler/llm.py index 158ec85da..6319d7822 100644 --- a/src/airunner/aihandler/llm.py +++ b/src/airunner/aihandler/llm.py @@ -1,5 +1,13 @@ +import torch +import traceback + from airunner.aihandler.transformer_runner import TransformerRunner from airunner.aihandler.logger import Logger as logger +from airunner.aihandler.enums import MessageCode +import os +from jinja2 import Environment, FileSystemLoader + +from transformers.pipelines.conversational import Conversation class LLM(TransformerRunner): @@ -8,34 +16,122 @@ def clear_conversation(self): self.chain.clear() def do_generate(self, data): + Logger.info("Generating with LLM") self.process_data(data) self.handle_request() self.requested_generator_name = data["request_data"]["generator_name"] prompt = data["request_data"]["prompt"] model_path = data["request_data"]["model_path"] - self.generate( - app=self.app, - endpoint=data["request_data"]["generator_name"], - prompt=prompt, - model=model_path, - stream=data["request_data"]["stream"], - images=[data["request_data"]["image"]], + return self.generate( + # app=self.app, + # endpoint=data["request_data"]["generator_name"], + # prompt=prompt, + # model=model_path, + # stream=data["request_data"]["stream"], + # images=[data["request_data"]["image"]], ) + + history = [] + + def generate(self): + # Create a FileSystemLoader object with the directory of the template + HERE = os.path.dirname(os.path.abspath(__file__)) + file_loader = FileSystemLoader(os.path.join(HERE, "chat_templates")) + + # Create an Environment object with the FileSystemLoader object + env = Environment(loader=file_loader) - def generate(self, **kwargs): + # Load the template + # Load the template + chat_template = env.get_template('chat.j2') + + prompt = self.prompt + if prompt is None or prompt == "": + traceback.print_stack() + return + if self.generator.name == "casuallm": - prompt = kwargs.get("prompt", "") - logger.info(f"LLM requested with prompt {prompt}") - return self.chain.run(prompt) + history = [] + for message in self.history: + if message["role"] == "user": + history.append("[INST]" + self.username + ': "'+ message["content"] +'"[/INST]') + else: + history.append(self.botname + ': "'+ message["content"] +'"') + history = "\n".join(history) + if history == "": + history = None + + # Create a dictionary with the variables + variables = { + "username": self.username, + "botname": self.botname, + "history": history, + "input": prompt, + "bos_token": self.tokenizer.bos_token, + "botmood": "angry. He hates " + self.username + } + + self.history.append({ + "role": "user", + "content": prompt + }) + + # Render the template with the variables + rendered_template = chat_template.render(variables) + #print(rendered_template) + + # Encode the rendered template + encoded = self.tokenizer.encode(rendered_template, return_tensors="pt") + + model_inputs = encoded.to("cuda" if torch.cuda.is_available() else "cpu") + + # Generate the response + generated_ids = self.model.generate( + model_inputs, + min_length=0, + max_length=1000, + num_beams=1, + do_sample=True, + top_k=20, + eta_cutoff=10, + top_p=1.0, + num_return_sequences=self.sequences, + eos_token_id=self.tokenizer.eos_token_id, + early_stopping=True, + repetition_penalty=1.15, + temperature=0.7, + ) + + # Decode the new tokens + decoded = self.tokenizer.batch_decode(generated_ids)[0] + decoded = decoded.replace(self.tokenizer.batch_decode(model_inputs)[0], "") + decoded = decoded.replace("", "") + + # Extract the actual message content + start_index = decoded.find('"') + 1 + end_index = decoded.rfind('"') + decoded = decoded[start_index:end_index] + + self.history.append({ + "role": "assistant", + "content": decoded + }) + + # print(self.history) + + # print("*"*80) + # print(decoded) + + #return decoded + self.engine.send_message(decoded, code=MessageCode.TEXT_GENERATED) elif self.generator.name == "visualqa": inputs = self.processor( self.image, - self.prompt, + prompt, return_tensors="pt" ).to("cuda") out = self.model.generate( - **inputs, - **kwargs, + **inputs ) answers = [] diff --git a/src/airunner/aihandler/logger.py b/src/airunner/aihandler/logger.py index b9ecd5723..a22fe7dcf 100644 --- a/src/airunner/aihandler/logger.py +++ b/src/airunner/aihandler/logger.py @@ -76,4 +76,3 @@ def error(cls, msg): Logger.logger.addHandler(Logger.stream_handler) logging.getLogger("lightning").setLevel(logging.WARNING) logging.getLogger("lightning_fabric.utilities.seed").setLevel(logging.WARNING) -Logger.set_level(LOG_LEVEL) diff --git a/src/airunner/aihandler/runner.py b/src/airunner/aihandler/runner.py index c4a24b9a3..b5f16601c 100644 --- a/src/airunner/aihandler/runner.py +++ b/src/airunner/aihandler/runner.py @@ -1,6 +1,4 @@ import base64 -import gc -import os import re import traceback from io import BytesIO @@ -17,12 +15,11 @@ from diffusers.utils.torch_utils import randn_tensor #from diffusers import ConsistencyDecoderVAE from torchvision import transforms -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from airunner.aihandler.auto_pipeline import AutoImport from airunner.aihandler.enums import FilterType from airunner.aihandler.enums import MessageCode -from airunner.aihandler.logger import Logger as logger from airunner.aihandler.mixins.compel_mixin import CompelMixin from airunner.aihandler.mixins.embedding_mixin import EmbeddingMixin from airunner.aihandler.mixins.lora_mixin import LoraMixin @@ -34,7 +31,7 @@ from airunner.aihandler.settings_manager import SettingsManager from airunner.prompt_builder.prompt_data import PromptData from airunner.scripts.realesrgan.main import RealESRGAN - +from airunner.aihandler.logger import Logger torch.backends.cuda.matmul.allow_tf32 = True @@ -139,7 +136,7 @@ def local_files_only(self): @local_files_only.setter def local_files_only(self, value): - logger.info("Setting local_files_only to %s" % value) + Logger.info("Setting local_files_only to %s" % value) self._local_files_only = value @property @@ -526,7 +523,7 @@ def pipe(self): elif self.is_vid_action: return self.txt2vid else: - logger.warning(f"Invalid action {self.action} unable to get pipe") + Logger.warning(f"Invalid action {self.action} unable to get pipe") @pipe.setter def pipe(self, value): @@ -543,7 +540,7 @@ def pipe(self, value): elif self.is_vid_action: self.txt2vid = value else: - logger.warning(f"Invalid action {self.action} unable to set pipe") + Logger.warning(f"Invalid action {self.action} unable to set pipe") @property def cuda_is_available(self): @@ -695,7 +692,7 @@ def original_model_data(self): return self.options.get("original_model_data", {}) def __init__(self, **kwargs): - logger.set_level(LOG_LEVEL) + Logger.set_level(LOG_LEVEL) self.settings_manager = SettingsManager() self.safety_checker_model = self.settings_manager.models_by_pipeline_action("safety_checker") self.text_encoder_model = self.settings_manager.models_by_pipeline_action("text_encoder") @@ -770,9 +767,9 @@ def is_safetensor_file(model): return model.endswith(".safetensors") def initialize(self): - logger.info("trying to initialize") + Logger.info("trying to initialize") if not self.initialized or self.reload_model or self.pipe is None: - logger.info("Initializing") + Logger.info("Initializing") self.compel_proc = None self.prompt_embeds = None self.negative_prompt_embeds = None @@ -790,7 +787,7 @@ def generator(self, device=None, seed=None): return torch.Generator(device=device).manual_seed(seed) def prepare_options(self, data): - logger.info(f"Preparing options") + Logger.info(f"Preparing options") action = data["action"] options = data["options"] requested_model = options.get(f"model", None) @@ -825,11 +822,11 @@ def send_message(self, message, code=None): if code == MessageCode.ERROR: traceback.print_stack() - logger.error(message) + Logger.error(message) elif code == MessageCode.WARNING: - logger.warning(message) + Logger.warning(message) elif code == MessageCode.STATUS: - logger.info(message) + Logger.info(message) formatted_message = { "code": code, @@ -842,11 +839,10 @@ def send_message(self, message, code=None): def error_handler(self, error): message = str(error) - if "got an unexpected keyword argument 'image'" in message and self.action in ["outpaint", "pix2pix", - "depth2img"]: + if "got an unexpected keyword argument 'image'" in message and self.action in ["outpaint", "pix2pix", "depth2img"]: message = f"This model does not support {self.action}" traceback.print_exc() - logger.error(error) + Logger.error(error) self.send_message(message, MessageCode.ERROR) def initialize_safety_checker(self, local_files_only=None): @@ -875,16 +871,16 @@ def load_safety_checker(self): if not self.pipe: return if not self.do_nsfw_filter: - logger.info("Disabling safety checker") + Logger.info("Disabling safety checker") self.pipe.safety_checker = None elif self.pipe.safety_checker is None: - logger.info("Loading safety checker") + Logger.info("Loading safety checker") self.pipe.safety_checker = self.safety_checker if self.pipe.safety_checker: self.pipe.safety_checker.to(self.device) def do_sample(self, **kwargs): - logger.info(f"Sampling {self.action}") + Logger.info(f"Sampling {self.action}") message = f"Generating {'video' if self.is_vid_action else 'image'}" @@ -977,7 +973,7 @@ def call_pipe(self, **kwargs): "negative_prompt_embeds": self.negative_prompt_embeds, }) except Exception as _e: - logger.warning("Compel failed: " + str(_e)) + Logger.warning("Compel failed: " + str(_e)) args.update({ "prompt": self.prompt, "negative_prompt": self.negative_prompt, @@ -1004,7 +1000,7 @@ def call_pipe(self, **kwargs): args["generator"] = generator if self.enable_controlnet: - logger.info(f"Setting up controlnet") + Logger.info(f"Setting up controlnet") args = self.load_controlnet_arguments(**args) self.load_safety_checker() @@ -1043,7 +1039,7 @@ def call_pipe_txt2vid(self, **kwargs): ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1] frame_ids = list(range(ch_start, ch_end)) try: - logger.info(f"Generating video with {len(frame_ids)} frames") + Logger.info(f"Generating video with {len(frame_ids)} frames") self.send_message(f"Generating video, frames {cur_frame} to {cur_frame + len(frame_ids)-1} of {self.n_samples}") cur_frame += len(frame_ids) kwargs = { @@ -1122,7 +1118,7 @@ def prepare_extra_args(self, _data, image, mask): return extra_args def sample_diffusers_model(self, data: dict): - logger.info("sample_diffusers_model") + Logger.info("sample_diffusers_model") from pytorch_lightning import seed_everything image = self.image mask = self.mask @@ -1155,7 +1151,7 @@ def process_prompts(self, data, seed): data["options"][f"negative_prompt"] = [negative_prompt for _ in range(self.batch_size)] return data prompt_data = self.prompt_data - logger.info(f"Process prompt") + Logger.info(f"Process prompt") if self.deterministic_seed: prompt = data["options"][f"prompt"] if ".blend(" in prompt: @@ -1199,7 +1195,7 @@ def process_prompts(self, data, seed): def process_data(self, data: dict): import traceback - logger.info("Runner: process_data called") + Logger.info("Runner: process_data called") self.requested_data = data self.prepare_options(data) #self.prepare_scheduler() @@ -1211,7 +1207,7 @@ def process_data(self, data: dict): def generate(self, data: dict): if not self.pipe: return - logger.info("generate called") + Logger.info("generate called") self.do_cancel = False self.process_data(data) @@ -1314,7 +1310,7 @@ def unload_tokenizer(self): self.tokenizer = None def process_upscale(self, data: dict): - logger.info("Processing upscale") + Logger.info("Processing upscale") image = self.input_image results = [] if image: @@ -1345,7 +1341,7 @@ def generator_sample(self, data: dict): return if not self.pipe: - logger.info("pipe is None") + Logger.info("pipe is None") return self.send_message(f"Generating {'video' if self.is_vid_action else 'image'}") @@ -1355,7 +1351,7 @@ def generator_sample(self, data: dict): try: self.initialized = self.__dict__[action] is not None except KeyError: - logger.info(f"{action} model has not been initialized yet") + Logger.info(f"{action} model has not been initialized yet") self.initialized = False error = None @@ -1369,7 +1365,7 @@ def generator_sample(self, data: dict): error = e if "PYTORCH_CUDA_ALLOC_CONF" in str(e): error = self.cuda_error_message - self.clear_memory() + self.engine.clear_memory() self.reset_applied_memory_settings() else: error_message = f"Error during generation" @@ -1397,7 +1393,7 @@ def log_error(self, error, message=None): self.error_handler(message) def load_controlnet_from_ckpt(self, pipeline): - logger.info("Loading controlnet from ckpt") + Logger.info("Loading controlnet from ckpt") pipeline = self.controlnet_action_diffuser( vae=pipeline.vae, text_encoder=pipeline.text_encoder, @@ -1412,7 +1408,7 @@ def load_controlnet_from_ckpt(self, pipeline): return pipeline def load_controlnet(self): - logger.info(f"Loading controlnet {self.controlnet_type} self.controlnet_model {self.controlnet_model}") + Logger.info(f"Loading controlnet {self.controlnet_type} self.controlnet_model {self.controlnet_model}") self._controlnet = None self.current_controlnet_type = self.controlnet_type controlnet = self.from_pretrained( @@ -1424,17 +1420,17 @@ def load_controlnet(self): def preprocess_for_controlnet(self, image): if self.current_controlnet_type != self.controlnet_type or not self.processor: - logger.info("Loading controlnet processor " + self.controlnet_type) + Logger.info("Loading controlnet processor " + self.controlnet_type) self.current_controlnet_type = self.controlnet_type - logger.info("Controlnet: Processing image") + Logger.info("Controlnet: Processing image") self.processor = Processor(self.controlnet_type) if self.processor: - logger.info("Controlnet: Processing image") + Logger.info("Controlnet: Processing image") image = self.processor(image) # resize image to width and height image = image.resize((self.width, self.height)) return image - logger.error("No controlnet processor found") + Logger.error("No controlnet processor found") def load_controlnet_arguments(self, **kwargs): if not self.is_vid_action: @@ -1452,7 +1448,7 @@ def load_controlnet_arguments(self, **kwargs): return kwargs def unload_unused_models(self): - logger.info("Unloading unused models") + Logger.info("Unloading unused models") for action in [ "txt2img", "img2img", @@ -1472,7 +1468,7 @@ def unload_unused_models(self): self.reset_applied_memory_settings() def load_model(self): - logger.info("Loading model") + Logger.info("Loading model") self.torch_compile_applied = False self.lora_loaded = False self.embeds_loaded = False @@ -1497,7 +1493,7 @@ def load_model(self): if self.pipe is None or self.reload_model: kwargs["from_safetensors"] = self.is_safetensors - logger.info(f"Loading model from scratch {self.reload_model}") + Logger.info(f"Loading model from scratch {self.reload_model}") self.reset_applied_memory_settings() self.send_model_loading_message(self.model_path) @@ -1527,7 +1523,7 @@ def load_model(self): except OSError as e: return self.handle_missing_files(self.action) else: - logger.info(f"Loading model {self.model_path} from PRETRAINED") + Logger.info(f"Loading model {self.model_path} from PRETRAINED") scheduler = self.load_scheduler() if scheduler: kwargs["scheduler"] = scheduler @@ -1546,7 +1542,7 @@ def load_model(self): ) if self.pipe is None: - logger.error("Failed to load pipeline") + Logger.error("Failed to load pipeline") self.send_message("Failed to load model", MessageCode.ERROR) return @@ -1554,12 +1550,12 @@ def load_model(self): Initialize pipe for video to video zero """ if self.pipe and self.is_vid2vid: - logger.info("Initializing pipe for vid2vid") + Logger.info("Initializing pipe for vid2vid") self.pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) self.pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) if self.is_outpaint: - logger.info("Initializing vae for inpaint / outpaint") + Logger.info("Initializing vae for inpaint / outpaint") self.pipe.vae = self.from_pretrained( pipeline_action="inpaint_vae", model=self.inpaint_vae_model @@ -1584,7 +1580,7 @@ def load_model(self): #self.load_learned_embed_in_clip() def load_ckpt_model(self): - logger.info(f"Loading ckpt file {self.model_path}") + Logger.info(f"Loading ckpt file {self.model_path}") pipeline = self.download_from_original_stable_diffusion_ckpt(path=self.model_path) return pipeline @@ -1634,11 +1630,11 @@ def download_from_original_stable_diffusion_ckpt(self, path, local_files_only=No local_files_only=False ) except Exception as e: - logger.error(f"Failed to load model from ckpt: {e}") + Logger.error(f"Failed to load model from ckpt: {e}") return pipe def clear_controlnet(self): - logger.info("Clearing controlnet") + Logger.info("Clearing controlnet") self._controlnet = None self.engine.clear_memory() self.reset_applied_memory_settings() @@ -1651,14 +1647,14 @@ def load_vae(self): ) def reuse_pipeline(self, do_load_controlnet): - logger.info("Reusing pipeline") + Logger.info("Reusing pipeline") pipe = None if self.is_txt2img: pipe = self.img2img if self.txt2img is None else self.txt2img elif self.is_img2img: pipe = self.txt2img if self.img2img is None else self.img2img if pipe is None: - logger.warning("Failed to reuse pipeline") + Logger.warning("Failed to reuse pipeline") self.clear_controlnet() return kwargs = pipe.components @@ -1726,7 +1722,7 @@ def send_model_loading_message(self, model_name): self.send_message(message) def prepare_model(self): - logger.info("Prepare model") + Logger.info("Prepare model") # get model and switch to it # get models from database @@ -1746,7 +1742,7 @@ def prepare_model(self): def unload_controlnet(self): if self.pipe: - logger.info("Unloading controlnet") + Logger.info("Unloading controlnet") self.pipe.controlnet = None self.controlnet_loaded = False @@ -1778,17 +1774,17 @@ def from_pretrained(self, **kwargs): **kwargs ) except OSError as e: - logger.error(f"failed to load {model} from pretrained") + Logger.error(f"failed to load {model} from pretrained") return self.handle_missing_files(pipeline_action) def handle_missing_files(self, action): if not self.attempt_download: if self.is_ckpt_model or self.is_safetensors: - logger.info("Required files not found, attempting download") + Logger.info("Required files not found, attempting download") else: import traceback traceback.print_exc() - logger.info("Model not found, attempting download") + Logger.info("Model not found, attempting download") # check if we have an internet connection if self.allow_online_when_missing_files: self.send_message("Downloading model files") diff --git a/src/airunner/aihandler/settings_manager.py b/src/airunner/aihandler/settings_manager.py index 83fbe1140..6add3bd4d 100644 --- a/src/airunner/aihandler/settings_manager.py +++ b/src/airunner/aihandler/settings_manager.py @@ -211,21 +211,21 @@ def find_generator(self, generator_section, generator_name): if self.generator_settings_override_id: generator_settings = session.query(GeneratorSetting).filter_by( id=self.generator_settings_override_id - ).join(Settings).first() + ).first() else: generator_settings = session.query(GeneratorSetting).filter_by( - section=generator_section, - generator_name=generator_name - ).join(Settings).first() + is_preset=0 + ).first() if generator_settings is None: if not generator_section or generator_section == "" or not generator_name or generator_name == "": return None - generator_settings = GeneratorSetting( - section=generator_section, - generator_name=generator_name - ) - session.add(generator_settings) - session.commit() + # generator_settings = GeneratorSetting( + # section=generator_section, + # generator_name=generator_name, + # is_preset=False + # ) + # session.add(generator_settings) + # session.commit() return generator_settings def __init__(self, app=None, *args, **kwargs): diff --git a/src/airunner/aihandler/transformer_runner.py b/src/airunner/aihandler/transformer_runner.py index 4ed433913..4c668c4ce 100644 --- a/src/airunner/aihandler/transformer_runner.py +++ b/src/airunner/aihandler/transformer_runner.py @@ -4,20 +4,12 @@ from PyQt6.QtCore import QObject # import AutoTokenizer -from transformers import AutoTokenizer # import BitsAndBytesConfig from transformers import BitsAndBytesConfig -import transformers -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from transformers import InstructBlipForConditionalGeneration from transformers import InstructBlipProcessor -from langchain.llms.huggingface_pipeline import HuggingFacePipeline -from langchain.memory import ConversationBufferWindowMemory -from langchain.prompts import PromptTemplate -from langchain.chains import ConversationChain -from airunner.aihandler.enums import MessageCode -from airunner.aihandler.llm_api import LLMAPI from airunner.aihandler.settings_manager import SettingsManager from airunner.data.models import LLMGenerator from airunner.data.db import session @@ -67,13 +59,19 @@ class TransformerRunner(QObject): current_generator_name = "" do_quantize_model = True callback = None + template = None + #system_instructions = "" + system_instructions = "" @property def generator(self): - if not self._generator or self.current_generator_name != self.requested_generator_name: - self.current_generator_name = self.requested_generator_name - self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first() - return self._generator + try: + if not self._generator or self.current_generator_name != self.requested_generator_name: + self.current_generator_name = self.requested_generator_name + self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first() + return self._generator + except Exception as e: + Logger.error(e) @property def do_load_model(self): @@ -202,7 +200,7 @@ def load_model(self, local_files_only = None): auto_class_ = None if self.generator.name == "seq2seq": - auto_class_ = AutoModelForSeq2SeqLM + auto_class_ = AutoModelForSeq2SeqLM elif self.generator.name == "casuallm": auto_class_ = AutoModelForCausalLM elif self.generator.name == "visualqa": @@ -219,38 +217,14 @@ def load_model(self, local_files_only = None): else: Logger.error(e) - if self.generator.name == "casuallm": - self.pipeline = transformers.pipeline( - "text-generation", - model=self.model, - tokenizer=self.tokenizer, - torch_dtype=torch.float16 if self.dtype != "32bit" else torch.float32, - trust_remote_code=False, - device_map="auto", - min_length=0, - max_length=1000, - num_beams=1, - do_sample=True, - top_k=20, - eta_cutoff=10, - top_p=1.0, - num_return_sequences=self.sequences, - eos_token_id=self.tokenizer.eos_token_id, - early_stopping=True, - repetition_penalty=1.15, - temperature=0.7, - ) - Logger.info(f"Loading prompt template {self.prompt_template}") - self.llm=HuggingFacePipeline(pipeline=self.pipeline) - self.memory = ConversationBufferWindowMemory(k=5) - self.prompt = PromptTemplate.from_template( - self.prompt_template, - template_format="jinja2", - partial_variables={ - "username": self.username, - "botname": self.botname, - }) - self.chain = ConversationChain(llm=self.llm, prompt=self.prompt, memory=self.memory) + # if self.generator.name == "casuallm": + # self.pipeline = AutoModelForCausalLM.from_pretrained( + # self.model, + # torch_dtype=torch.float16 if self.dtype != "32bit" else torch.float32, + # trust_remote_code=False, + # device_map="auto" + # ) + def process_data(self, data): self.request_data = data.get("request_data", {}) @@ -266,6 +240,7 @@ def process_data(self, data): self.prompt = self.request_data.get("prompt", "") self.prompt_template = self.request_data.get("prompt_template", "") self.image = self.request_data.get("image", None) + self.system_instructions = f"Your name is {self.botname}. You are talking to {self.username}. You will always respond in character. You will always respond to the user. You will not give ethical or moral advice to the user unless it is something appropriate for {self.botname} to say." if self.image: self.image = self.image.convert("RGB") @@ -284,6 +259,7 @@ def clear_conversation(self): pass def prepare_input_args(self): + self.system_instructions = self.request_data.get("system_instructions", "") top_k = self.parameters.get("top_k", self.top_k) eta_cutoff = self.parameters.get("eta_cutoff", self.eta_cutoff) top_p = self.parameters.get("top_p", self.top_p) @@ -388,18 +364,20 @@ def handle_generate_request(self): if self.generator.name == "visualqa": self.load_processor() - self.engine.send_message("Generating output") - with torch.backends.cuda.sdp_kernel( - enable_flash=True, - enable_math=False, - enable_mem_efficient=False - ): - with torch.no_grad(): - value = self.generate(**kwargs) - if self.callback: - self.callback(value) - else: - self.engine.send_message(value, code=MessageCode.TEXT_GENERATED) + # self.engine.send_message("Generating output") + # with torch.backends.cuda.sdp_kernel( + # enable_flash=True, + # enable_math=False, + # enable_mem_efficient=False + # ): + # with torch.no_grad(): + # print("************** CALLING GENERATE") + # value = self.generate() + # print("VALUE", value) + # # if self.callback: + # # self.callback(value) + # # else: + # # self.engine.send_message(value, code=MessageCode.TEXT_GENERATED) self.enable_request_processing() def disable_request_processing(self): @@ -408,7 +386,7 @@ def disable_request_processing(self): def enable_request_processing(self): self._processing_request = True - def generate(self, **kwargs): + def generate(self): pass diff --git a/src/airunner/alembic/env.py b/src/airunner/alembic/env.py index 1e8633c2a..01ef487e6 100644 --- a/src/airunner/alembic/env.py +++ b/src/airunner/alembic/env.py @@ -11,8 +11,8 @@ # Interpret the config file for Python logging. # This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) +# if config.config_file_name is not None: +# fileConfig(config.config_file_name) # add your model's MetaData object here # for 'autogenerate' support diff --git a/src/airunner/alembic/versions/08728ca819b8_adds_prompt_template_to_llmgenerator.py b/src/airunner/alembic/versions/08728ca819b8_adds_prompt_template_to_llmgenerator.py new file mode 100644 index 000000000..eae30e2bf --- /dev/null +++ b/src/airunner/alembic/versions/08728ca819b8_adds_prompt_template_to_llmgenerator.py @@ -0,0 +1,30 @@ +"""Adds prompt_template to LLMGenerator + +Revision ID: 08728ca819b8 +Revises: 7c1b15cade9b +Create Date: 2024-01-08 13:46:24.790611 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '08728ca819b8' +down_revision: Union[str, None] = '7c1b15cade9b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('llm_generator', sa.Column('prompt_template', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('llm_generator', 'prompt_template') + # ### end Alembic commands ### diff --git a/src/airunner/alembic/versions/435acbb59892_added_version_to_generator_settings.py b/src/airunner/alembic/versions/435acbb59892_added_version_to_generator_settings.py new file mode 100644 index 000000000..4dbc8a1a5 --- /dev/null +++ b/src/airunner/alembic/versions/435acbb59892_added_version_to_generator_settings.py @@ -0,0 +1,30 @@ +"""Added version to generator settings + +Revision ID: 435acbb59892 +Revises: 6f4472b84071 +Create Date: 2024-01-02 06:15:11.656030 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '435acbb59892' +down_revision: Union[str, None] = '6f4472b84071' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('generator_settings', sa.Column('version', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('generator_settings', 'version') + # ### end Alembic commands ### diff --git a/src/airunner/alembic/versions/7c1b15cade9b_adds_is_preset_flag_to_generator_.py b/src/airunner/alembic/versions/7c1b15cade9b_adds_is_preset_flag_to_generator_.py new file mode 100644 index 000000000..0f561eefb --- /dev/null +++ b/src/airunner/alembic/versions/7c1b15cade9b_adds_is_preset_flag_to_generator_.py @@ -0,0 +1,30 @@ +"""Adds is_preset flag to generator settings table + +Revision ID: 7c1b15cade9b +Revises: 435acbb59892 +Create Date: 2024-01-02 06:52:02.924308 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '7c1b15cade9b' +down_revision: Union[str, None] = '435acbb59892' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('generator_settings', sa.Column('is_preset', sa.Boolean(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('generator_settings', 'is_preset') + # ### end Alembic commands ### diff --git a/src/airunner/data/db.py b/src/airunner/data/db.py index cba29c1bd..e210a868d 100644 --- a/src/airunner/data/db.py +++ b/src/airunner/data/db.py @@ -14,6 +14,7 @@ LLMGeneratorSetting, LLMGenerator, LLMModelVersion from airunner.utils import get_session from alembic.config import Config +from airunner.aihandler.logger import Logger from alembic import command import os import configparser @@ -495,7 +496,7 @@ def insert_variables(variables, prev_object=None): }, ] -HERE = os.path.abspath(os.path.dirname(__file__)) +HERE = os.path.abspath(os.path.dirname(__file__)) alembic_ini_path = os.path.join(HERE, "../alembic.ini") config = configparser.ConfigParser() @@ -505,9 +506,7 @@ def insert_variables(variables, prev_object=None): db_path = f'sqlite:///{home_dir}/.airunner/airunner.db' config.set('alembic', 'sqlalchemy.url', db_path) - with open(alembic_ini_path, 'w') as configfile: config.write(configfile) - alembic_cfg = Config(alembic_ini_path) -command.upgrade(alembic_cfg, "head") \ No newline at end of file +command.upgrade(alembic_cfg, "head") diff --git a/src/airunner/data/models.py b/src/airunner/data/models.py index edac49b49..1a31fc370 100644 --- a/src/airunner/data/models.py +++ b/src/airunner/data/models.py @@ -356,6 +356,8 @@ class GeneratorSetting(BaseModel): active_grid_border_color = Column(String, default="#00FF00") active_grid_fill_color = Column(String, default="#FF0000") brushes = relationship("Brush", back_populates='generator_setting') # modified line + version = Column(String, default="SD 1.5") + is_preset = Column(Boolean, default=False) class Brush(BaseModel): @@ -767,6 +769,7 @@ class LLMGenerator(BaseModel): message_type = Column(String, default="chat") bot_personality = Column(String, default="Nice") override_parameters = Column(Boolean, default=False) + prompt_template = Column(String, default="") class LLMGeneratorSetting(BaseModel): diff --git a/src/airunner/filters/color_balance.py b/src/airunner/filters/color_balance.py index 706e00276..3a30bf560 100644 --- a/src/airunner/filters/color_balance.py +++ b/src/airunner/filters/color_balance.py @@ -5,6 +5,7 @@ class ColorBalanceFilter(BaseFilter): def apply_filter(self, image, _do_reset): + image = image.convert("RGBA") red, green, blue, alpha = image.split() red = red.point(lambda i: i + (i * self.cyan_red)) green = green.point(lambda i: i + (i * self.magenta_green)) diff --git a/src/airunner/filters/registration_error.py b/src/airunner/filters/registration_error.py index 38c61c7df..0a461f850 100644 --- a/src/airunner/filters/registration_error.py +++ b/src/airunner/filters/registration_error.py @@ -5,6 +5,7 @@ class RegistrationErrorFilter(BaseFilter): def apply_filter(self, image, do_reset): + image = image.convert("RGBA") # first, split the image into its R G B channels r, g, b, a = image.split() diff --git a/src/airunner/filters/unsharp_mask.py b/src/airunner/filters/unsharp_mask.py index 0de8945c5..e3ae94983 100644 --- a/src/airunner/filters/unsharp_mask.py +++ b/src/airunner/filters/unsharp_mask.py @@ -4,5 +4,8 @@ class UnsharpMask(BaseFilter): + def __init__(self, radius, percent, threshold): + self.unsharp_mask = UnsharpMask(radius=radius, percent=percent, threshold=threshold) + def apply_filter(self, image, do_reset): - return image.filter(UnsharpMask(radius=self.radius)) + return image.filter(self.unsharp_mask) diff --git a/src/airunner/filters/windows/filter_base.py b/src/airunner/filters/windows/filter_base.py index 49db007d0..3f6aaab44 100644 --- a/src/airunner/filters/windows/filter_base.py +++ b/src/airunner/filters/windows/filter_base.py @@ -5,6 +5,7 @@ from PyQt6 import uic from airunner.widgets.slider.slider_widget import SliderWidget +from airunner.aihandler.settings_manager import SettingsManager class FilterBase: @@ -36,7 +37,7 @@ def __getattr__(self, item): def __setattr__(self, key, value): if key in self._filter_values: self._filter_values[key].value = str(value) - self.parent.settings_manager.save() + self.settings_manager.save() else: super().__setattr__(key, value) @@ -57,6 +58,8 @@ def __init__(self, parent, model_name): # filter_values are the names of the ImageFilterValue objects in the database. # when the filter is shown, the values are loaded from the database # and stored in this dictionary. + self.settings_manager = SettingsManager(app=parent) + self._filter_values = {} self.filter_window = None @@ -65,16 +68,14 @@ def __init__(self, parent, model_name): self.load_image_filter_data() def update_value(self, name, value): - print(name, value) self._filter_values[name].value = str(value) - self.parent.settings_manager.save() + self.settings_manager.save() def update_canvas(self): - if self.parent.canvas_is_active: - self.parent.current_canvas.update() + pass def load_image_filter_data(self): - self.image_filter_data = self.parent.settings_manager.get_image_filter(self.image_filter_model_name) + self.image_filter_data = self.settings_manager.get_image_filter(self.image_filter_model_name) for filter_value in self.image_filter_data.image_filter_values: self._filter_values[filter_value.name] = filter_value @@ -134,7 +135,7 @@ def show(self): def handle_auto_apply_toggle(self): self.image_filter_data.auto_apply = self.filter_window.auto_apply.isChecked() - self.parent.settings_manager.save() + self.settings_manager.save() def handle_slider_change(self, settings_property, val): self.update_value(settings_property, val) @@ -143,15 +144,15 @@ def handle_slider_change(self, settings_property, val): def cancel_filter(self): self.reject() - self.parent.current_canvas.cancel_filter() + self.parent.canvas_widget.cancel_filter() self.update_canvas() def apply_filter(self): self.accept() - self.parent.current_canvas.apply_filter(self.filter) + self.parent.canvas_widget.apply_filter(self.filter) self.filter_window.close() self.update_canvas() def preview_filter(self): - self.parent.current_canvas.preview_filter(self.filter) + self.parent.canvas_widget.preview_filter(self.filter) self.update_canvas() diff --git a/src/airunner/styles/dark_theme/styles.qss b/src/airunner/styles/dark_theme/styles.qss index 36e8a954b..cea971461 100644 --- a/src/airunner/styles/dark_theme/styles.qss +++ b/src/airunner/styles/dark_theme/styles.qss @@ -644,4 +644,11 @@ QTableWidget::item:selected { #delete_confirmation QPushButton { border: 1px solid rgba(0, 132, 185, 50); +} + +QWidget#message QPlainTextEdit { + border-radius: 5px; + border: 5px solid #1f1f1f; + background-color: #1f1f1f; + color: #ffffff; } \ No newline at end of file diff --git a/src/airunner/utils.py b/src/airunner/utils.py index 3e2b5262b..561bdf193 100644 --- a/src/airunner/utils.py +++ b/src/airunner/utils.py @@ -221,24 +221,6 @@ def get_version(): return "" -def get_latest_version(): - return - # get latest release from https://github.com/Capsize-Games/airunner/releases/latest - # follow the redirect to get the version number - import requests - import re - url = "https://github.com/Capsize-Games/airunner/releases/latest" - try: - r = requests.get(url) - if r.status_code == 200: - m = re.search(r"\/Capsize-Games\/airunner\/releases\/tag\/v([0-9\.]+)", r.text) - if m: - return m.group(1) - except ConnectionError: - return None - return None - - # def load_default_models(tab_section, section_name): # if section_name == "txt2img": # section_name = "generate" diff --git a/src/airunner/widgets/base_widget.py b/src/airunner/widgets/base_widget.py index 268443c17..cd314ef85 100644 --- a/src/airunner/widgets/base_widget.py +++ b/src/airunner/widgets/base_widget.py @@ -8,7 +8,7 @@ class BaseWidget(QWidget): widget_class_ = None - icons = {} + icons = () ui = None qss_filename = None @@ -26,25 +26,42 @@ def add_to_grid(self, widget, row, column, row_span=1, column_span=1): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.app = get_main_window() + self.app.loaded.connect(self.initialize) self.settings_manager = SettingsManager() if self.widget_class_: self.ui = self.widget_class_() if self.ui: self.ui.setupUi(self) - if self.qss_filename: - theme_name = "dark_theme" - here = os.path.dirname(os.path.realpath(__file__)) - with open(os.path.join(here, "..", "styles", theme_name, self.qss_filename), "r") as f: - stylesheet = f.read() - self.setStyleSheet(stylesheet) - - def set_stylesheet(self, is_dark=None, button_name=None, icon=None): - is_dark = self.is_dark if is_dark is None else is_dark - if button_name is None or icon is None: - for button_name, icon in self.icons.items(): - self.set_button_icon(is_dark, button_name, icon) - else: - self.set_button_icon(is_dark, button_name, icon) + # if self.qss_filename: + # theme_name = "dark_theme" + # here = os.path.dirname(os.path.realpath(__file__)) + # with open(os.path.join(here, "..", "styles", theme_name, self.qss_filename), "r") as f: + # stylesheet = f.read() + # self.setStyleSheet(stylesheet) + self.set_icons() + + def initialize(self): + """ + Triggered when the app is loaded. + Override this function in order to initialize the widget rather than + using __init__. + """ + pass + + def set_icons(self): + theme = "dark" if self.is_dark else "light" + for icon_data in self.icons: + icon_name = icon_data[0] + widget_name = icon_data[1] + print(icon_name, widget_name) + print(self.icons) + icon = QtGui.QIcon() + icon.addPixmap( + QtGui.QPixmap(f":/icons/{theme}/{icon_name}.svg"), + QtGui.QIcon.Mode.Normal, + QtGui.QIcon.State.Off) + getattr(self.ui, widget_name).setIcon(icon) + self.update() def set_button_icon(self, is_dark, button_name, icon): try: @@ -130,8 +147,12 @@ def set_form_value(self, element, settings_key_name): if not self.set_is_checked(element, target_val): raise Exception(f"Could not set value for {element} to {target_val}") - def set_form_property(self, element, property_name, settings_key_name): + def set_form_property(self, element, property_name, settings_key_name=None, settings=None): val = self.get_form_element(element).property(property_name) - target_val = self.settings_manager.get_value(settings_key_name) + if settings_key_name: + target_val = self.settings_manager.get_value(settings_key_name) + elif settings: + target_val = getattr(settings, property_name) + if val != target_val: self.get_form_element(element).setProperty(property_name, target_val) diff --git a/src/airunner/widgets/brushes/brushes_container.py b/src/airunner/widgets/brushes/brushes_container.py index edd6a585e..bc3b24e0d 100644 --- a/src/airunner/widgets/brushes/brushes_container.py +++ b/src/airunner/widgets/brushes/brushes_container.py @@ -133,6 +133,7 @@ def activate_brush(self, clicked_widget, brush, multiple): else: self.selected_brushes.remove(widget) widget.setStyleSheet("") + self.settings_manager.set_value("generator_settings_override_id", None) return for widget in self.selected_brushes: @@ -153,7 +154,7 @@ def activate_brush(self, clicked_widget, brush, multiple): widget.setStyleSheet(f""" border: 2px solid #ff0000; """) - self.app.ui.generator_widget.enable_preset(widget.brush.generator_setting_id) + self.settings_manager.set_value("generator_settings_override_id", widget.brush.generator_setting_id) def display_brush_menu(self, event, widget, brush): context_menu = QMenu(self) diff --git a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py index d5097f515..96d82312e 100644 --- a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py +++ b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py @@ -2,7 +2,6 @@ import math import subprocess from functools import partial -import pdb from PIL import Image, ImageGrab from PIL.ImageQt import ImageQt, QImage @@ -15,7 +14,7 @@ from airunner.cursors.circle_brush import CircleCursor from airunner.data.db import session from airunner.data.models import Layer, CanvasSettings, ActiveGridSettings -from airunner.utils import get_session, save_session +from airunner.utils import save_session from airunner.widgets.canvas_plus.canvas_base_widget import CanvasBaseWidget from airunner.widgets.canvas_plus.templates.canvas_plus_ui import Ui_canvas from airunner.utils import apply_opacity_to_image @@ -212,10 +211,6 @@ class CanvasPlusWidget(CanvasBaseWidget): initialized = False drawing = False - @property - def current_active_image_data(self): - return self.current_layer.image_data - def current_pixmap(self): draggable_pixmap = self.current_draggable_pixmap() if draggable_pixmap: @@ -227,10 +222,6 @@ def current_image(self): return None return Image.fromqpixmap(pixmap) - @current_active_image_data.setter - def current_active_image_data(self, value): - self.current_layer.image_data = value - @property def image_pivot_point(self): try: @@ -402,6 +393,16 @@ def draw_layers(self): layer.opacity / 100.0 ) + if layer.id in self.layers: + if not layer.visible: + if self.layers[layer.id] in self.scene.items(): + self.scene.removeItem(self.layers[layer.id]) + elif layer.visible: + if not self.layers[layer.id] in self.scene.items(): + self.scene.addItem(self.layers[layer.id]) + self.layers[layer.id].pixmap.convertFromImage(ImageQt(image)) + continue + draggable_pixmap = None if layer.id in self.layers: self.layers[layer.id].pixmap.convertFromImage(ImageQt(image)) @@ -453,6 +454,7 @@ def draw_active_grid_area_container(self): """ Draw a rectangle around the active grid area of """ + print(self.active_grid_area_rect) if not self.active_grid_area: self.active_grid_area = ActiveGridArea( parent=self, @@ -583,7 +585,7 @@ def load_image_from_object(self, image, is_outpaint=False, image_root_point=None def load_image(self, image_path): image = Image.open(image_path) - if self.app.settings_manager.resize_on_paste: + if self.settings_manager.resize_on_paste: image.thumbnail((self.settings_manager.working_width, self.settings_manager.working_height), Image.ANTIALIAS) self.add_image_to_scene(image) @@ -603,6 +605,8 @@ def cut_image(self): if not draggable_pixmap: return self.scene.removeItem(draggable_pixmap) + if self.current_layer.id in self.layers: + del self.layers[self.current_layer.id] self.update() def delete_image(self): @@ -614,17 +618,13 @@ def delete_image(self): self.update() def paste_image_from_clipboard(self): + Logger.info("paste image from clipboard") image = self.get_image_from_clipboard() if not image: Logger.info("No image in clipboard") return - if self.app.settings_manager.resize_on_paste: - Logger.info("Resizing image") - image.thumbnail( - (self.settings_manager.working_width, self.settings_manager.working_height), - Image.ANTIALIAS) self.create_image(image) def get_image_from_clipboard(self): @@ -653,19 +653,23 @@ def image_to_system_clipboard_linux(self, pixmap): subprocess.Popen(["xclip", "-selection", "clipboard", "-t", "image/png"], stdin=subprocess.PIPE).communicate(data) except FileNotFoundError: - pass + Logger.error("xclip not found. Please install xclip to copy image to clipboard.") def create_image(self, image): - if self.app.settings_manager.resize_on_paste: - image.thumbnail( - ( - self.settings_manager.working_width, - self.settings_manager.working_height - ), - Image.ANTIALIAS - ) + if self.settings_manager.resize_on_paste: + image = self.resize_image(image) self.add_image_to_scene(image) + def resize_image(self, image): + image.thumbnail( + ( + self.settings_manager.working_width, + self.settings_manager.working_height + ), + Image.ANTIALIAS + ) + return image + def remove_current_draggable_pixmap_from_scene(self): current_draggable_pixmap = self.current_draggable_pixmap() if current_draggable_pixmap: @@ -678,11 +682,6 @@ def switch_to_layer(self, layer_index): self.current_layer_index = layer_index def add_image_to_scene(self, image, is_outpaint=False, image_root_point=None): - # change size of self.current_active_image to match size of image - if self.current_active_image is None: - self.current_active_image = image - else: - self.current_active_image = self.current_active_image.resize(image.size) self.current_active_image = image if image_root_point is not None: self.current_layer.pos_x = image_root_point.x() @@ -751,3 +750,13 @@ def save_image(self, image_path, image=None): def update_image_canvas(self): print("TODO") + + def rotate_90_clockwise(self): + if self.current_active_image: + self.current_active_image = self.current_active_image.transpose(Image.ROTATE_270) + self.do_draw() + + def rotate_90_counterclockwise(self): + if self.current_active_image: + self.current_active_image = self.current_active_image.transpose(Image.ROTATE_90) + self.do_draw() \ No newline at end of file diff --git a/src/airunner/widgets/canvas_plus/standard_image_widget.py b/src/airunner/widgets/canvas_plus/standard_image_widget.py index 0a2465006..556ee284a 100644 --- a/src/airunner/widgets/canvas_plus/standard_image_widget.py +++ b/src/airunner/widgets/canvas_plus/standard_image_widget.py @@ -18,7 +18,7 @@ from airunner.utils import delete_image, load_metadata_from_image, prepare_metadata from airunner.settings import CONTROLNET_OPTIONS from airunner.widgets.slider.slider_widget import SliderWidget -from airunner.data.models import ActionScheduler, Pipeline +from airunner.data.models import ActionScheduler, Pipeline, GeneratorSetting class StandardImageWidget(StandardBaseWidget): widget_class_ = Ui_standard_image_widget @@ -29,6 +29,21 @@ class StandardImageWidget(StandardBaseWidget): image_label = None image_batch = None meta_data = None + _image = None + + @property + def image(self): + if self._image is None: + self.image = self.app.canvas_widget.current_layer.image + return self._image + + @image.setter + def image(self, image): + self._image = image + + @property + def canvas_widget(self): + return self.ui.canvas_widget def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -41,20 +56,11 @@ def __init__(self, *args, **kwargs): self.initialize() def set_controlnet_settings_properties(self): - self.ui.controlnet_settings.initialize( - self.settings_manager.generator_name, - self.settings_manager.generator_section - ) + self.ui.controlnet_settings.initialize() def set_input_image_widget_properties(self): - self.ui.input_image_widget.initialize( - self.settings_manager.generator_name, - self.settings_manager.generator_section - ) - self.ui.controlnet_settings.initialize( - self.settings_manager.generator_name, - self.settings_manager.generator_section - ) + self.ui.input_image_widget.initialize() + self.ui.controlnet_settings.initialize() def update_image_input_thumbnail(self): self.ui.input_image_widget.set_thumbnail() @@ -113,7 +119,6 @@ def set_pixmap(self, image_path=None, image=None): self.image_path = image_path self.image = image meta_data = image.info - print("META DATA", meta_data) self.meta_data = meta_data if meta_data is not None else load_metadata_from_image(image) return #size = self.ui.image_frame.width() - 20 @@ -204,18 +209,21 @@ def similar_image_with_prompt(self): """ Using the LLM, generate a description of the image """ - #self.app.describe_image(image=self.image, callback=self.handle_prompt_generated) - prompt = self.app.generator_tab_widget.ui.generator_form_stablediffusion.ui.prompt.toPlainText() - negative_prompt = self.app.generator_tab_widget.ui.generator_form_stablediffusion.ui.negative_prompt.toPlainText() - self.handle_prompt_generated([prompt], [negative_prompt]) + self.app.describe_image(image=self.image, callback=self.handle_prompt_generated) + # prompt = self.app.generator_tab_widget.ui.prompt.toPlainText() + # negative_prompt = self.app.generator_tab_widget.ui.negative_prompt.toPlainText() + # self.handle_prompt_generated([prompt], [negative_prompt]) def handle_prompt_generated(self, prompt, negative_prompt): meta_data = load_metadata_from_image(self.image) meta_data["prompt"] = prompt[0] meta_data["negative_prompt"] = negative_prompt[0] meta_data = prepare_metadata({ "options": meta_data }) - image = Image.open(self.image_path) - image.save(self.image_path, pnginfo=meta_data) + if self.image_path: + image = Image.open(self.image_path) + image.save(self.image_path, pnginfo=meta_data) + else: + image = self.app.canvas_widget.current_layer.image self.image = image self.meta_data = load_metadata_from_image(self.image) self.generate_similar_image() @@ -224,7 +232,7 @@ def similar_image(self): self.generate_similar_image() def generate_similar_image(self, batch_size=1): - meta_data = self.meta_data + meta_data = self.meta_data or {} prompt = meta_data.get("prompt", None) negative_prompt = meta_data.get("negative_prompt", None) @@ -232,7 +240,8 @@ def generate_similar_image(self, batch_size=1): negative_prompt = None if negative_prompt == "" else negative_prompt if prompt is None: - return self.similar_image_with_prompt() + #return self.similar_image_with_prompt() + prompt = "" if negative_prompt is None: meta_data["negative_prompt"] = "verybadimagenegative_v1.3, EasyNegative" @@ -251,7 +260,7 @@ def generate_similar_image(self, batch_size=1): meta_data["use_cropped_image"] = False meta_data["batch_size"] = batch_size - self.app.generator_tab_widget.current_generator_widget.call_generate( + self.app.generator_tab_widget.call_generate( image=self.image, override_data=meta_data ) @@ -263,7 +272,7 @@ def similar_batch(self): self.generate_similar_image(batch_size=4) def upscale_2x_clicked(self): - meta_data = self.meta_data + meta_data = self.meta_data or {} prompt = meta_data.get("prompt", None) negative_prompt = meta_data.get("negative_prompt", None) @@ -271,7 +280,7 @@ def upscale_2x_clicked(self): negative_prompt = None if negative_prompt == "" else negative_prompt if prompt is None: - return self.similar_image_with_prompt() + prompt = "" if negative_prompt is None: meta_data["negative_prompt"] = "verybadimagenegative_v1.3, EasyNegative" @@ -281,14 +290,13 @@ def upscale_2x_clicked(self): meta_data["face_enhance"] = self.settings_manager.standard_image_widget_settings.face_enhance meta_data["denoise_strength"] = 0.5 meta_data["action"] = "upscale" - meta_data["width"] = self.image.width - meta_data["height"] = self.image.height + meta_data["width"] = self.ui.canvas_widget.current_layer.image.width + meta_data["height"] = self.ui.canvas_widget.current_layer.image.height meta_data["enable_input_image"] = True meta_data["use_cropped_image"] = False - - self.app.generator_tab_widget.current_generator_widget.call_generate( - image=self.image, + self.app.generator_tab_widget.call_generate( + image=self.ui.canvas_widget.current_layer.image, override_data=meta_data ) @@ -322,20 +330,6 @@ def load_versions(self): self.ui.version.setCurrentText(current_version) self.ui.version.blockSignals(False) - def handle_pipeline_changed(self, val): - if val == "txt2img / img2img": - val = "txt2img" - elif val == "inpaint / outpaint": - val = "outpaint" - self.settings_manager.set_value(f"current_section_{self.settings_manager.current_image_generator}", val) - self.load_versions() - self.load_models() - - def handle_version_changed(self, val): - print("VERSION CHANGED", val) - self.settings_manager.set_value(f"current_version_{self.settings_manager.current_image_generator}", val) - self.load_models() - def load_models(self): session = get_session() self.ui.model.blockSignals(True) @@ -381,15 +375,37 @@ def load_schedulers(self): def clear_models(self): self.ui.model.clear() - def handle_settings_manager_changed(self, key, val, settings_manager): - print("handle_settings_manager_changed", key, val) - if settings_manager.generator_section == self.settings_manager.generator_section and settings_manager.generator_name == self.settings_manager.generator_name: + def initialize_generator_form(self, override_id=None): + if override_id: + self.ui.steps_widget.set_slider_and_spinbox_values(self.settings_manager.generator.steps) + self.ui.scale_widget.set_slider_and_spinbox_values(self.settings_manager.generator.scale * 100) + self.ui.clip_skip_slider_widget.set_slider_and_spinbox_values(self.settings_manager.generator.clip_skip) + + self.ui.pipeline.blockSignals(True) + self.ui.version.blockSignals(True) + self.ui.model.blockSignals(True) + self.ui.scheduler.blockSignals(True) + + self.ui.pipeline.setCurrentText(self.settings_manager.generator.section) + self.ui.version.setCurrentText(self.settings_manager.generator.version) + self.ui.model.setCurrentText(self.settings_manager.generator.model) + self.ui.scheduler.setCurrentText(self.settings_manager.generator.scheduler) + + self.ui.pipeline.blockSignals(False) + self.ui.version.blockSignals(False) + self.ui.model.blockSignals(False) + self.ui.scheduler.blockSignals(False) + else: self.set_form_values() self.load_pipelines() self.load_versions() self.load_models() self.load_schedulers() - + + def handle_settings_manager_changed(self, key, val, settings_manager): + if key == "generator_settings_override_id": + self.initialize_generator_form(val) + def initialize(self): self.set_form_values() self.load_pipelines() @@ -426,4 +442,28 @@ def initialize(self): # self.generator_section, # self.generator_name # ) - self.initialized = True \ No newline at end of file + self.initialized = True + + def handle_model_changed(self, name): + if not self.initialized: + return + self.settings_manager.set_value("generator.model", name) + + def handle_scheduler_changed(self, name): + if not self.initialized: + return + self.settings_manager.set_value("generator.scheduler", name) + + def handle_pipeline_changed(self, val): + if val == "txt2img / img2img": + val = "txt2img" + elif val == "inpaint / outpaint": + val = "outpaint" + self.settings_manager.set_value(f"current_section_{self.settings_manager.current_image_generator}", val) + self.load_versions() + self.load_models() + + def handle_version_changed(self, val): + print("VERSION CHANGED", val) + self.settings_manager.set_value(f"current_version_{self.settings_manager.current_image_generator}", val) + self.load_models() \ No newline at end of file diff --git a/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui b/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui index 77553b23c..7dae3c629 100644 --- a/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui +++ b/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui @@ -39,14 +39,14 @@ - + 0 0 - 450 + 0 0 @@ -80,14 +80,15 @@ Qt::LeftToRight - 0 + 4 - + + .. Settings @@ -118,40 +119,23 @@ 0 0 - 444 + 438 1211 - - - - - - - - - - - - generator.random_seed - - - - - - - - - - - - - generator.random_latents_seed - - - - + + + + Qt::Vertical + + + + 20 + 40 + + + @@ -225,24 +209,24 @@ - - + + - + - DDIM ETA + Samples handle_value_change - 1 + 0 - 10 + 500 - 10 + 500.000000000000000 false @@ -260,26 +244,23 @@ 1 - generator.ddim_eta + generator.n_samples - + - Frames - - - handle_value_change + Clip Skip 0 - 200 + 11 - 200.000000000000000 + 12.000000000000000 false @@ -291,36 +272,39 @@ 1 - 1 + 0 - 1 + 0 + + + handle_value_change - generator.n_samples + generator.clip_skip - - + + - + - Samples + DDIM ETA handle_value_change - 0 + 1 - 500 + 10 - 500.000000000000000 + 10 false @@ -338,23 +322,26 @@ 1 - generator.n_samples + generator.ddim_eta - + - Clip Skip + Frames + + + handle_value_change 0 - 11 + 200 - 12.000000000000000 + 200.000000000000000 false @@ -366,16 +353,13 @@ 1 - 0 + 1 - 0 - - - handle_value_change + 1 - generator.clip_skip + generator.n_samples @@ -469,33 +453,36 @@ - - + + + + + + + + + + + + generator.random_seed + + + - - - PointingHandCursor + + + - - Variation + + + + + generator.random_latents_seed - - - - Qt::Vertical - - - - 20 - 40 - - - - @@ -773,8 +760,8 @@ - 532 - 532 + 0 + 0 @@ -784,7 +771,7 @@ - PointingHandCursor + ArrowCursor true @@ -856,8 +843,8 @@ upscale_model_changed(QString) - 144 - 67 + 111 + 59 79 @@ -872,8 +859,8 @@ face_enhance_toggled(bool) - 306 - 66 + 327 + 58 286 @@ -888,8 +875,8 @@ upscale_number_changed(int) - 86 - 92 + 88 + 72 0 @@ -904,8 +891,8 @@ upscale_2x_clicked() - 249 - 100 + 327 + 89 228 @@ -920,8 +907,8 @@ handle_advanced_settings_checkbox(bool) - 58 - 58 + 60 + 37 24 @@ -936,8 +923,8 @@ similar_image() - 206 - 142 + 255 + 135 -3 @@ -952,8 +939,8 @@ similar_batch() - 350 - 143 + 399 + 135 259 @@ -961,6 +948,70 @@ + + pipeline + currentTextChanged(QString) + standard_image_widget + handle_pipeline_changed(QString) + + + 352 + 74 + + + 464 + -11 + + + + + version + currentTextChanged(QString) + standard_image_widget + handle_version_changed(QString) + + + 166 + 116 + + + 458 + -17 + + + + + model + currentTextChanged(QString) + standard_image_widget + handle_model_changed(QString) + + + 121 + 170 + + + 418 + -7 + + + + + scheduler + currentTextChanged(QString) + standard_image_widget + handle_scheduler_changed(QString) + + + 299 + 221 + + + 342 + -12 + + + image_to_canvas() @@ -981,5 +1032,9 @@ face_enhance_toggled(bool) handle_advanced_settings_checkbox(bool) upscale_number_changed(int) + handle_pipeline_changed(QString) + handle_version_changed(QString) + handle_model_changed(QString) + handle_scheduler_changed(QString) diff --git a/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py b/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py index 58a8df40b..4e7e3e8d0 100644 --- a/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py +++ b/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py @@ -21,12 +21,12 @@ def setupUi(self, standard_image_widget): self.splitter.setOrientation(QtCore.Qt.Orientation.Horizontal) self.splitter.setObjectName("splitter") self.sidebar = QtWidgets.QWidget(parent=self.splitter) - sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.Preferred) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.Preferred) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) sizePolicy.setHeightForWidth(self.sidebar.sizePolicy().hasHeightForWidth()) self.sidebar.setSizePolicy(sizePolicy) - self.sidebar.setMinimumSize(QtCore.QSize(450, 0)) + self.sidebar.setMinimumSize(QtCore.QSize(0, 0)) self.sidebar.setMaximumSize(QtCore.QSize(450, 16777215)) self.sidebar.setObjectName("sidebar") self.gridLayout_7 = QtWidgets.QGridLayout(self.sidebar) @@ -48,23 +48,12 @@ def setupUi(self, standard_image_widget): self.scrollArea_2.setWidgetResizable(True) self.scrollArea_2.setObjectName("scrollArea_2") self.scrollAreaWidgetContents_2 = QtWidgets.QWidget() - self.scrollAreaWidgetContents_2.setGeometry(QtCore.QRect(0, 0, 444, 1211)) + self.scrollAreaWidgetContents_2.setGeometry(QtCore.QRect(0, 0, 438, 1211)) self.scrollAreaWidgetContents_2.setObjectName("scrollAreaWidgetContents_2") self.gridLayout_10 = QtWidgets.QGridLayout(self.scrollAreaWidgetContents_2) self.gridLayout_10.setObjectName("gridLayout_10") - self.horizontalLayout_3 = QtWidgets.QHBoxLayout() - self.horizontalLayout_3.setObjectName("horizontalLayout_3") - self.seed_widget = SeedWidget(parent=self.scrollAreaWidgetContents_2) - self.seed_widget.setProperty("generator_section", "") - self.seed_widget.setProperty("generator_name", "") - self.seed_widget.setObjectName("seed_widget") - self.horizontalLayout_3.addWidget(self.seed_widget) - self.seed_widget_latents = SeedWidget(parent=self.scrollAreaWidgetContents_2) - self.seed_widget_latents.setProperty("generator_section", "") - self.seed_widget_latents.setProperty("generator_name", "") - self.seed_widget_latents.setObjectName("seed_widget_latents") - self.horizontalLayout_3.addWidget(self.seed_widget_latents) - self.gridLayout_10.addLayout(self.horizontalLayout_3, 1, 0, 1, 1) + spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding) + self.gridLayout_10.addItem(spacerItem, 5, 0, 1, 1) self.horizontalLayout_5 = QtWidgets.QHBoxLayout() self.horizontalLayout_5.setObjectName("horizontalLayout_5") self.steps_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2) @@ -92,6 +81,34 @@ def setupUi(self, standard_image_widget): self.scale_widget.setObjectName("scale_widget") self.horizontalLayout_5.addWidget(self.scale_widget) self.gridLayout_10.addLayout(self.horizontalLayout_5, 2, 0, 1, 1) + self.horizontalLayout_6 = QtWidgets.QHBoxLayout() + self.horizontalLayout_6.setObjectName("horizontalLayout_6") + self.samples_widget_2 = SliderWidget(parent=self.scrollAreaWidgetContents_2) + self.samples_widget_2.setProperty("slider_callback", "handle_value_change") + self.samples_widget_2.setProperty("current_value", 0) + self.samples_widget_2.setProperty("slider_maximum", 500) + self.samples_widget_2.setProperty("spinbox_maximum", 500.0) + self.samples_widget_2.setProperty("display_as_float", False) + self.samples_widget_2.setProperty("spinbox_single_step", 1) + self.samples_widget_2.setProperty("spinbox_page_step", 1) + self.samples_widget_2.setProperty("spinbox_minimum", 1) + self.samples_widget_2.setProperty("slider_minimum", 1) + self.samples_widget_2.setObjectName("samples_widget_2") + self.horizontalLayout_6.addWidget(self.samples_widget_2) + self.clip_skip_slider_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2) + self.clip_skip_slider_widget.setProperty("current_value", 0) + self.clip_skip_slider_widget.setProperty("slider_maximum", 11) + self.clip_skip_slider_widget.setProperty("spinbox_maximum", 12.0) + self.clip_skip_slider_widget.setProperty("display_as_float", False) + self.clip_skip_slider_widget.setProperty("spinbox_single_step", 1) + self.clip_skip_slider_widget.setProperty("spinbox_page_step", 1) + self.clip_skip_slider_widget.setProperty("spinbox_minimum", 0) + self.clip_skip_slider_widget.setProperty("slider_minimum", 0) + self.clip_skip_slider_widget.setProperty("slider_callback", "handle_value_change") + self.clip_skip_slider_widget.setProperty("settings_property", "generator.clip_skip") + self.clip_skip_slider_widget.setObjectName("clip_skip_slider_widget") + self.horizontalLayout_6.addWidget(self.clip_skip_slider_widget) + self.gridLayout_10.addLayout(self.horizontalLayout_6, 3, 0, 1, 1) self.ddim_frames = QtWidgets.QHBoxLayout() self.ddim_frames.setObjectName("ddim_frames") self.ddim_eta_slider_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2) @@ -121,34 +138,6 @@ def setupUi(self, standard_image_widget): self.frames_slider_widget.setObjectName("frames_slider_widget") self.ddim_frames.addWidget(self.frames_slider_widget) self.gridLayout_10.addLayout(self.ddim_frames, 4, 0, 1, 1) - self.horizontalLayout_6 = QtWidgets.QHBoxLayout() - self.horizontalLayout_6.setObjectName("horizontalLayout_6") - self.samples_widget_2 = SliderWidget(parent=self.scrollAreaWidgetContents_2) - self.samples_widget_2.setProperty("slider_callback", "handle_value_change") - self.samples_widget_2.setProperty("current_value", 0) - self.samples_widget_2.setProperty("slider_maximum", 500) - self.samples_widget_2.setProperty("spinbox_maximum", 500.0) - self.samples_widget_2.setProperty("display_as_float", False) - self.samples_widget_2.setProperty("spinbox_single_step", 1) - self.samples_widget_2.setProperty("spinbox_page_step", 1) - self.samples_widget_2.setProperty("spinbox_minimum", 1) - self.samples_widget_2.setProperty("slider_minimum", 1) - self.samples_widget_2.setObjectName("samples_widget_2") - self.horizontalLayout_6.addWidget(self.samples_widget_2) - self.clip_skip_slider_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2) - self.clip_skip_slider_widget.setProperty("current_value", 0) - self.clip_skip_slider_widget.setProperty("slider_maximum", 11) - self.clip_skip_slider_widget.setProperty("spinbox_maximum", 12.0) - self.clip_skip_slider_widget.setProperty("display_as_float", False) - self.clip_skip_slider_widget.setProperty("spinbox_single_step", 1) - self.clip_skip_slider_widget.setProperty("spinbox_page_step", 1) - self.clip_skip_slider_widget.setProperty("spinbox_minimum", 0) - self.clip_skip_slider_widget.setProperty("slider_minimum", 0) - self.clip_skip_slider_widget.setProperty("slider_callback", "handle_value_change") - self.clip_skip_slider_widget.setProperty("settings_property", "generator.clip_skip") - self.clip_skip_slider_widget.setObjectName("clip_skip_slider_widget") - self.horizontalLayout_6.addWidget(self.clip_skip_slider_widget) - self.gridLayout_10.addLayout(self.horizontalLayout_6, 3, 0, 1, 1) self.verticalLayout_3 = QtWidgets.QVBoxLayout() self.verticalLayout_3.setObjectName("verticalLayout_3") self.verticalLayout_4 = QtWidgets.QVBoxLayout() @@ -205,15 +194,19 @@ def setupUi(self, standard_image_widget): self.verticalLayout_2.addWidget(self.scheduler) self.verticalLayout_3.addLayout(self.verticalLayout_2) self.gridLayout_10.addLayout(self.verticalLayout_3, 0, 0, 1, 1) - self.verticalLayout_5 = QtWidgets.QVBoxLayout() - self.verticalLayout_5.setObjectName("verticalLayout_5") - self.variation_checkbox = QtWidgets.QCheckBox(parent=self.scrollAreaWidgetContents_2) - self.variation_checkbox.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) - self.variation_checkbox.setObjectName("variation_checkbox") - self.verticalLayout_5.addWidget(self.variation_checkbox) - self.gridLayout_10.addLayout(self.verticalLayout_5, 5, 0, 1, 1) - spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding) - self.gridLayout_10.addItem(spacerItem, 6, 0, 1, 1) + self.horizontalLayout_3 = QtWidgets.QHBoxLayout() + self.horizontalLayout_3.setObjectName("horizontalLayout_3") + self.seed_widget = SeedWidget(parent=self.scrollAreaWidgetContents_2) + self.seed_widget.setProperty("generator_section", "") + self.seed_widget.setProperty("generator_name", "") + self.seed_widget.setObjectName("seed_widget") + self.horizontalLayout_3.addWidget(self.seed_widget) + self.seed_widget_latents = SeedWidget(parent=self.scrollAreaWidgetContents_2) + self.seed_widget_latents.setProperty("generator_section", "") + self.seed_widget_latents.setProperty("generator_name", "") + self.seed_widget_latents.setObjectName("seed_widget_latents") + self.horizontalLayout_3.addWidget(self.seed_widget_latents) + self.gridLayout_10.addLayout(self.horizontalLayout_3, 1, 0, 1, 1) self.scrollArea_2.setWidget(self.scrollAreaWidgetContents_2) self.gridLayout_11.addWidget(self.scrollArea_2, 0, 0, 1, 1) icon = QtGui.QIcon.fromTheme("document-properties") @@ -336,16 +329,16 @@ def setupUi(self, standard_image_widget): sizePolicy.setVerticalStretch(0) sizePolicy.setHeightForWidth(self.canvas_widget.sizePolicy().hasHeightForWidth()) self.canvas_widget.setSizePolicy(sizePolicy) - self.canvas_widget.setMinimumSize(QtCore.QSize(532, 532)) + self.canvas_widget.setMinimumSize(QtCore.QSize(0, 0)) self.canvas_widget.setMaximumSize(QtCore.QSize(16777215, 16777215)) - self.canvas_widget.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) + self.canvas_widget.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.ArrowCursor)) self.canvas_widget.setAcceptDrops(True) self.canvas_widget.setObjectName("canvas_widget") self.gridLayout.addWidget(self.canvas_widget, 0, 0, 1, 1) self.gridLayout_2.addWidget(self.splitter, 0, 0, 1, 1) self.retranslateUi(standard_image_widget) - self.tabWidget.setCurrentIndex(0) + self.tabWidget.setCurrentIndex(4) self.upscale_model.currentTextChanged['QString'].connect(standard_image_widget.upscale_model_changed) # type: ignore self.face_enhance.clicked['bool'].connect(standard_image_widget.face_enhance_toggled) # type: ignore self.comboBox.currentIndexChanged['int'].connect(standard_image_widget.upscale_number_changed) # type: ignore @@ -353,25 +346,28 @@ def setupUi(self, standard_image_widget): self.advanced_settings_checkbox.clicked['bool'].connect(standard_image_widget.handle_advanced_settings_checkbox) # type: ignore self.generate_single_simillar_button.clicked.connect(standard_image_widget.similar_image) # type: ignore self.generate_batch_similar_button.clicked.connect(standard_image_widget.similar_batch) # type: ignore + self.pipeline.currentTextChanged['QString'].connect(standard_image_widget.handle_pipeline_changed) # type: ignore + self.version.currentTextChanged['QString'].connect(standard_image_widget.handle_version_changed) # type: ignore + self.model.currentTextChanged['QString'].connect(standard_image_widget.handle_model_changed) # type: ignore + self.scheduler.currentTextChanged['QString'].connect(standard_image_widget.handle_scheduler_changed) # type: ignore QtCore.QMetaObject.connectSlotsByName(standard_image_widget) def retranslateUi(self, standard_image_widget): _translate = QtCore.QCoreApplication.translate standard_image_widget.setWindowTitle(_translate("standard_image_widget", "Form")) - self.seed_widget.setProperty("property_name", _translate("standard_image_widget", "generator.random_seed")) - self.seed_widget_latents.setProperty("property_name", _translate("standard_image_widget", "generator.random_latents_seed")) self.steps_widget.setProperty("label_text", _translate("standard_image_widget", "Steps")) self.scale_widget.setProperty("label_text", _translate("standard_image_widget", "Scale")) - self.ddim_eta_slider_widget.setProperty("label_text", _translate("standard_image_widget", "DDIM ETA")) - self.frames_slider_widget.setProperty("label_text", _translate("standard_image_widget", "Frames")) self.samples_widget_2.setProperty("label_text", _translate("standard_image_widget", "Samples")) self.samples_widget_2.setProperty("settings_property", _translate("standard_image_widget", "generator.n_samples")) self.clip_skip_slider_widget.setProperty("label_text", _translate("standard_image_widget", "Clip Skip")) + self.ddim_eta_slider_widget.setProperty("label_text", _translate("standard_image_widget", "DDIM ETA")) + self.frames_slider_widget.setProperty("label_text", _translate("standard_image_widget", "Frames")) self.label_5.setText(_translate("standard_image_widget", "Pipeline")) self.label_6.setText(_translate("standard_image_widget", "Version")) self.label_3.setText(_translate("standard_image_widget", "Model")) self.label_4.setText(_translate("standard_image_widget", "Scheduler")) - self.variation_checkbox.setText(_translate("standard_image_widget", "Variation")) + self.seed_widget.setProperty("property_name", _translate("standard_image_widget", "generator.random_seed")) + self.seed_widget_latents.setProperty("property_name", _translate("standard_image_widget", "generator.random_latents_seed")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_6), _translate("standard_image_widget", "Settings")) self.tabWidget.setTabToolTip(self.tabWidget.indexOf(self.tab_6), _translate("standard_image_widget", "Stable Diffusion settings")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_5), _translate("standard_image_widget", "Presets")) diff --git a/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py b/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py index 86921d587..0222dad8f 100644 --- a/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py +++ b/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py @@ -245,15 +245,8 @@ def set_mask_thumbnail(self): # clear the image self.ui.mask_thumbnail.clear() - def initialize( - self, - generator_name, - generator_section - ): - super().initialize( - generator_name, - generator_section - ) + def initialize(self): + super().initialize() self.initialize_combobox() def initialize_combobox(self): diff --git a/src/airunner/widgets/deterministic/__init__.py b/src/airunner/widgets/deterministic/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/airunner/widgets/deterministic/deterministic_widget.py b/src/airunner/widgets/deterministic/deterministic_widget.py deleted file mode 100644 index 25e7191c0..000000000 --- a/src/airunner/widgets/deterministic/deterministic_widget.py +++ /dev/null @@ -1,23 +0,0 @@ -from airunner.widgets.base_widget import BaseWidget -from airunner.widgets.deterministic.templates.deterministic_widget_ui import Ui_deterministic_widget - - -class DeterministicWidget(BaseWidget): - widget_class_ = Ui_deterministic_widget - - @property - def batch_size(self): - return self.ui.images_per_batch.value() - - @property - def category(self): - return self.ui.category.text() - - def action_value_changed_images_per_batch(self, val): - self.settings_manager.set_value("determinisitic_settings.images_per_batch", val) - - def action_text_changed_category(self, val): - self.settings_manager.set_value("determinisitic_settings.category", val) - - def action_clicked_button_generate_batch(self): - print("generate batch clicked") diff --git a/src/airunner/widgets/deterministic/templates/__init__.py b/src/airunner/widgets/deterministic/templates/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/airunner/widgets/deterministic/templates/deterministic_widget.ui b/src/airunner/widgets/deterministic/templates/deterministic_widget.ui deleted file mode 100644 index 9f64b91fd..000000000 --- a/src/airunner/widgets/deterministic/templates/deterministic_widget.ui +++ /dev/null @@ -1,220 +0,0 @@ - - - deterministic_widget - - - - 0 - 0 - 322 - 256 - - - - - 16777215 - 16777215 - - - - - 8 - - - - Form - - - - - - true - - - - - 0 - 0 - 302 - 236 - - - - - - - - 8 - true - - - - Cateogry - - - - - - - 8 - true - - - - - - - - - - - Qt::Horizontal - - - - - - - - - - - 10 - true - - - - Deterministic generation - - - - - - - - 8 - true - - - - Images per-batch - - - - - - - 8 - false - - - - 1 - - - 16 - - - 4 - - - - - - - - - - - 8 - - - - Generate Deterministic Batch - - - - - - - Qt::Vertical - - - - 20 - 7 - - - - - - - - - - - - - SeedWidget - QWidget -
airunner/widgets/seed/seed_widget
- 1 -
-
- - - - category - currentTextChanged(QString) - deterministic_widget - action_text_changed_category(QString) - - - 59 - 92 - - - 15 - -12 - - - - - images_per_batch - valueChanged(int) - deterministic_widget - action_value_changed_images_per_batch(int) - - - 155 - 166 - - - 93 - -13 - - - - - generate_batches_button - clicked() - deterministic_widget - action_clicked_button_generate_batch() - - - 96 - 239 - - - 115 - -5 - - - - - - action_text_changed_category(QString) - action_value_changed_images_per_batch(int) - action_clicked_button_generate_batch() - -
diff --git a/src/airunner/widgets/deterministic/templates/deterministic_widget_ui.py b/src/airunner/widgets/deterministic/templates/deterministic_widget_ui.py deleted file mode 100644 index 50c2e3b01..000000000 --- a/src/airunner/widgets/deterministic/templates/deterministic_widget_ui.py +++ /dev/null @@ -1,104 +0,0 @@ -# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/deterministic/templates/deterministic_widget.ui' -# -# Created by: PyQt6 UI code generator 6.4.2 -# -# WARNING: Any manual changes made to this file will be lost when pyuic6 is -# run again. Do not edit this file unless you know what you are doing. - - -from PyQt6 import QtCore, QtGui, QtWidgets - - -class Ui_deterministic_widget(object): - def setupUi(self, deterministic_widget): - deterministic_widget.setObjectName("deterministic_widget") - deterministic_widget.resize(322, 256) - deterministic_widget.setMaximumSize(QtCore.QSize(16777215, 16777215)) - font = QtGui.QFont() - font.setPointSize(8) - deterministic_widget.setFont(font) - self.gridLayout_4 = QtWidgets.QGridLayout(deterministic_widget) - self.gridLayout_4.setObjectName("gridLayout_4") - self.scrollArea = QtWidgets.QScrollArea(parent=deterministic_widget) - self.scrollArea.setWidgetResizable(True) - self.scrollArea.setObjectName("scrollArea") - self.scrollAreaWidgetContents = QtWidgets.QWidget() - self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 302, 236)) - self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents") - self.gridLayout_2 = QtWidgets.QGridLayout(self.scrollAreaWidgetContents) - self.gridLayout_2.setObjectName("gridLayout_2") - self.groupBox_2 = QtWidgets.QGroupBox(parent=self.scrollAreaWidgetContents) - font = QtGui.QFont() - font.setPointSize(8) - font.setBold(True) - self.groupBox_2.setFont(font) - self.groupBox_2.setObjectName("groupBox_2") - self.gridLayout_3 = QtWidgets.QGridLayout(self.groupBox_2) - self.gridLayout_3.setObjectName("gridLayout_3") - self.category = QtWidgets.QComboBox(parent=self.groupBox_2) - font = QtGui.QFont() - font.setPointSize(8) - font.setBold(True) - self.category.setFont(font) - self.category.setObjectName("category") - self.gridLayout_3.addWidget(self.category, 0, 0, 1, 1) - self.gridLayout_2.addWidget(self.groupBox_2, 2, 0, 1, 1) - self.line = QtWidgets.QFrame(parent=self.scrollAreaWidgetContents) - self.line.setFrameShape(QtWidgets.QFrame.Shape.HLine) - self.line.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken) - self.line.setObjectName("line") - self.gridLayout_2.addWidget(self.line, 1, 0, 1, 1) - self.deterministic_seed = SeedWidget(parent=self.scrollAreaWidgetContents) - self.deterministic_seed.setObjectName("deterministic_seed") - self.gridLayout_2.addWidget(self.deterministic_seed, 4, 0, 1, 1) - self.label = QtWidgets.QLabel(parent=self.scrollAreaWidgetContents) - font = QtGui.QFont() - font.setPointSize(10) - font.setBold(True) - self.label.setFont(font) - self.label.setObjectName("label") - self.gridLayout_2.addWidget(self.label, 0, 0, 1, 1) - self.groupBox = QtWidgets.QGroupBox(parent=self.scrollAreaWidgetContents) - font = QtGui.QFont() - font.setPointSize(8) - font.setBold(True) - self.groupBox.setFont(font) - self.groupBox.setObjectName("groupBox") - self.gridLayout = QtWidgets.QGridLayout(self.groupBox) - self.gridLayout.setObjectName("gridLayout") - self.images_per_batch = QtWidgets.QSpinBox(parent=self.groupBox) - font = QtGui.QFont() - font.setPointSize(8) - font.setBold(False) - self.images_per_batch.setFont(font) - self.images_per_batch.setMinimum(1) - self.images_per_batch.setMaximum(16) - self.images_per_batch.setProperty("value", 4) - self.images_per_batch.setObjectName("images_per_batch") - self.gridLayout.addWidget(self.images_per_batch, 0, 0, 1, 1) - self.gridLayout_2.addWidget(self.groupBox, 3, 0, 1, 1) - self.generate_batches_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents) - font = QtGui.QFont() - font.setPointSize(8) - self.generate_batches_button.setFont(font) - self.generate_batches_button.setObjectName("generate_batches_button") - self.gridLayout_2.addWidget(self.generate_batches_button, 5, 0, 1, 1) - spacerItem = QtWidgets.QSpacerItem(20, 7, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding) - self.gridLayout_2.addItem(spacerItem, 6, 0, 1, 1) - self.scrollArea.setWidget(self.scrollAreaWidgetContents) - self.gridLayout_4.addWidget(self.scrollArea, 0, 0, 1, 1) - - self.retranslateUi(deterministic_widget) - self.category.currentTextChanged['QString'].connect(deterministic_widget.action_text_changed_category) # type: ignore - self.images_per_batch.valueChanged['int'].connect(deterministic_widget.action_value_changed_images_per_batch) # type: ignore - self.generate_batches_button.clicked.connect(deterministic_widget.action_clicked_button_generate_batch) # type: ignore - QtCore.QMetaObject.connectSlotsByName(deterministic_widget) - - def retranslateUi(self, deterministic_widget): - _translate = QtCore.QCoreApplication.translate - deterministic_widget.setWindowTitle(_translate("deterministic_widget", "Form")) - self.groupBox_2.setTitle(_translate("deterministic_widget", "Cateogry")) - self.label.setText(_translate("deterministic_widget", "Deterministic generation")) - self.groupBox.setTitle(_translate("deterministic_widget", "Images per-batch")) - self.generate_batches_button.setText(_translate("deterministic_widget", "Generate Deterministic Batch")) -from airunner.widgets.seed.seed_widget import SeedWidget diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index 7c9a5924c..5cf58d522 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -1,13 +1,11 @@ import random from PIL import Image -from PyQt6.QtCore import pyqtSignal, QRect, QTimer -from PyQt6.QtWidgets import QWidget +from PyQt6.QtCore import pyqtSignal, QRect from airunner.aihandler.settings import MAX_SEED from airunner.data.db import session -from airunner.data.models import AIModel, ActiveGridSettings, CanvasSettings, Pipeline, GeneratorSetting -from airunner.utils import get_session +from airunner.data.models import ActiveGridSettings, CanvasSettings from airunner.widgets.base_widget import BaseWidget from airunner.widgets.generator_form.templates.generatorform_ui import Ui_generator_form @@ -25,6 +23,9 @@ class GeneratorForm(BaseWidget): initialized = False parent = None generate_signal = pyqtSignal(dict) + icons = ( + ("artificial-intelligence-ai-chip-icon", "ai_button"), + ) @property def is_txt2img(self): @@ -56,19 +57,11 @@ def is_txt2vid(self): @property def generator_section(self): - try: - return self.property("generator_section") - except Exception as e: - print(e) - return None + return getattr(self.settings_manager, f"current_section_{self.settings_manager.current_image_generator}") @property def generator_name(self): - try: - return self.property("generator_name") - except Exception as e: - print(e) - return None + return self.settings_manager.current_image_generator @property def generator_settings(self): @@ -92,7 +85,7 @@ def latents_seed(self): @latents_seed.setter def latents_seed(self, val): self.settings_manager.set_value("generator.latents_seed", val) - self.app.ui.standard_image_widget.ui.seed_widget_latents.ui.lineEdit.setText(str(val)) + self.app.standard_image_panel.ui.seed_widget_latents.ui.lineEdit.setText(str(val)) @property def seed(self): @@ -101,7 +94,7 @@ def seed(self): @seed.setter def seed(self, val): self.settings_manager.set_value("generator.seed", val) - self.app.ui.standard_image_widget.ui.seed_widget.ui.lineEdit.setText(str(val)) + self.app.standard_image_panel.ui.seed_widget.ui.lineEdit.setText(str(val)) @property def image_scale(self): @@ -125,7 +118,7 @@ def enable_controlnet(self): @property def controlnet_image(self): - return self.app.ui.standard_image_widget.ui.controlnet_settings.current_controlnet_image + return self.app.standard_image_panel.ui.controlnet_settings.current_controlnet_image def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -133,11 +126,6 @@ def __init__(self, *args, **kwargs): self.canvas_settings = session.query(CanvasSettings).first() self.settings_manager.changed_signal.connect(self.handle_changed_signal) - self.initialize() - - def enable_preset(self, id): - self.settings_manager.set_value("generator_settings_override_id", id) - self.initialize() def toggle_advanced_generation(self): advanced_mode = self.settings_manager.enable_advanced_mode @@ -184,21 +172,6 @@ def handle_negative_prompt_changed(self): def toggle_prompt_builder_checkbox(self, toggled): pass - def handle_model_changed(self, name): - if not self.initialized: - return - self.settings_manager.set_value("generator.model", name) - self.changed_signal.emit("generator.model", name) - - def handle_scheduler_changed(self, name): - if not self.initialized: - return - self.settings_manager.set_value("generator.scheduler", name) - self.changed_signal.emit("generator.scheduler", name) - - def toggle_variation(self, toggled): - pass - def handle_generate_button_clicked(self): self.start_progress_bar() self.generate(image=self.app.current_active_image()) @@ -211,10 +184,10 @@ def handle_interrupt_button_clicked(self): def generate(self, image=None, seed=None): if seed is None: - seed = self.app.ui.standard_image_widget.ui.seed_widget.seed - if self.app.ui.standard_image_widget.ui.samples_widget.current_value > 1: + seed = self.app.standard_image_panel.ui.seed_widget.seed + if self.app.standard_image_panel.ui.samples_widget.current_value > 1: self.app.client.do_process_queue = False - total_samples = self.app.ui.standard_image_widget.ui.samples_widget.current_value if not self.is_txt2vid else 1 + total_samples = self.settings_manager.generator.n_samples for n in range(total_samples): if self.settings_manager.generator.use_prompt_builder and n > 0: seed = int(seed) + n @@ -242,7 +215,7 @@ def call_generate(self, image=None, seed=None, override_data=None): self.settings_manager.generator.enable_input_image ) if enable_input_image: - input_image = self.app.ui.standard_image_widget.ui.input_image_widget.current_input_image + input_image = self.app.standard_image_panel.ui.input_image_widget.current_input_image elif self.generator_section == "txt2img": input_image = override_data.get("input_image", None) image = input_image @@ -258,7 +231,7 @@ def call_generate(self, image=None, seed=None, override_data=None): if image is None: if self.is_txt2img: return self.do_generate(seed=seed, override_data=override_data) - # Create a transparent image the size of self.canvas.active_grid_area_rect + # Create a transparent image the size of self.app.canvas_widget.active_grid_area_rect width = self.settings_manager.working_width height = self.settings_manager.working_height image = Image.new("RGBA", (int(width), int(height)), (0, 0, 0, 0)) @@ -284,14 +257,14 @@ def call_generate(self, image=None, seed=None, override_data=None): # Get the cropped image cropped_outpaint_box_rect = self.active_rect # crop_location = ( - # cropped_outpaint_box_rect.x() - self.canvas.image_pivot_point.x(), - # cropped_outpaint_box_rect.y() - self.canvas.image_pivot_point.y(), - # cropped_outpaint_box_rect.width() - self.canvas.image_pivot_point.x(), - # cropped_outpaint_box_rect.height() - self.canvas.image_pivot_point.y() + # cropped_outpaint_box_rect.x() - self.app.canvas_widget.image_pivot_point.x(), + # cropped_outpaint_box_rect.y() - self.app.canvas_widget.image_pivot_point.y(), + # cropped_outpaint_box_rect.width() - self.app.canvas_widget.image_pivot_point.x(), + # cropped_outpaint_box_rect.height() - self.app.canvas_widget.image_pivot_point.y() # ) crop_location = ( - cropped_outpaint_box_rect.x() - self.canvas.current_layer.pos_x, - cropped_outpaint_box_rect.y() - self.canvas.current_layer.pos_y, + cropped_outpaint_box_rect.x() - self.app.canvas_widget.current_layer.pos_x, + cropped_outpaint_box_rect.y() - self.app.canvas_widget.current_layer.pos_y, cropped_outpaint_box_rect.width(), cropped_outpaint_box_rect.height() ) @@ -379,7 +352,6 @@ def do_generate(self, extra_options=None, seed=None, latents_seed=None, do_deter # get the model from the database - print(model_data) model = self.settings_manager.models.filter_by( name=model_data["name"] if "name" in model_data \ else self.settings_manager.generator.model @@ -474,8 +446,8 @@ def do_generate(self, extra_options=None, seed=None, latents_seed=None, do_deter options["controlnet_image"] = self.controlnet_image if action == "superresolution": - options["original_image_width"] = self.canvas.current_active_image_data.image.width - options["original_image_height"] = self.canvas.current_active_image_data.image.height + options["original_image_width"] = self.app.canvas_widget.current_active_image_data.image.width + options["original_image_height"] = self.app.canvas_widget.current_active_image_data.image.height if action in ["txt2img", "img2img", "outpaint", "depth2img"]: options[f"strength"] = strength @@ -500,25 +472,6 @@ def do_generate(self, extra_options=None, seed=None, latents_seed=None, do_deter } self.app.client.message = data - def do_deterministic_generation(self, extra_options): - action = self.deterministic_data["action"] - options = self.deterministic_data["options"] - options[f"prompt"] = self.deterministic_data[f"prompt"][self.deterministic_index] - memory_options = self.get_memory_options() - data = { - "action": action, - "options": { - **options, - **extra_options, - **memory_options, - "batch_size": self.settings_manager.deterministic_settings.batch_size, - "deterministic_generation": True, - "deterministic_seed": self.settings_manager.deterministic_settings.seed, - "deterministic_style": self.settings_manager.deterministic_settings.style, - } - } - self.app.client.message = data - def get_memory_options(self): return { "use_last_channels": self.settings_manager.memory_settings.use_last_channels, @@ -545,8 +498,8 @@ def set_seed(self, seed=None, latents_seed=None): self.update_seed() def update_seed(self): - self.app.ui.standard_image_widget.ui.seed_widget.update_seed() - self.app.ui.standard_image_widget.ui.seed_widget_latents.update_seed() + self.app.standard_image_panel.ui.seed_widget.update_seed() + self.app.standard_image_panel.ui.seed_widget_latents.update_seed() def set_primary_seed(self, seed=None): if self.deterministic_data: @@ -629,3 +582,24 @@ def new_batch(self, index, image, data): self.deterministic = False self.deterministic_data = None self.deterministic_images = None + + def set_progress_bar_value(self, tab_section, section, value): + progressbar = self.ui.progress_bar + if not progressbar: + return + if progressbar.maximum() == 0: + progressbar.setRange(0, 100) + progressbar.setValue(value) + + def stop_progress_bar(self, tab_section, section): + progressbar = self.ui.progress_bar + if not progressbar: + return + progressbar.setRange(0, 100) + progressbar.setValue(100) + + def update_prompt(self, prompt): + self.ui.prompt.setPlainText(prompt) + + def update_negative_prompt(self, prompt): + self.ui.negative_prompt.setPlainText(prompt) \ No newline at end of file diff --git a/src/airunner/widgets/generator_form/templates/generatorform.ui b/src/airunner/widgets/generator_form/templates/generatorform.ui index 4498a0c34..2ad953d8e 100644 --- a/src/airunner/widgets/generator_form/templates/generatorform.ui +++ b/src/airunner/widgets/generator_form/templates/generatorform.ui @@ -21,7 +21,7 @@ Form - + 0 @@ -34,152 +34,216 @@ 0 - - - - QFrame::NoFrame + + + + 0 - - QFrame::Plain - - - true - - - - - 0 - 0 - 361 - 1064 - - - + + + Tab 1 + + + + 0 + - 2 + 0 + + + 0 - 9 + 0 - - - - - - Generate - - - - - - + + + + QFrame::NoFrame + + + QFrame::Plain + + + true + + + + + 0 + 0 + 357 + 1038 + + + + 0 - - - - - - PointingHandCursor + + 0 - - Interrupt + + 0 - - - - - - - - Qt::Vertical - - - - 9 + 0 - + + + Qt::Vertical + + + + + 9 + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + Save Prompts + + + + + + + + + + + :/icons/light/artificial-intelligence-ai-chip-icon.svg:/icons/light/artificial-intelligence-ai-chip-icon.svg + + + true + + + + + + + + + + 8 + true + + + + Prompt + + + + + + + Enter a prompt... + + + + + + + + + 9 + + + + + + 8 + true + + + + Negative Prompt + + + + + + + Enter a negative prompt... + + + + + + + + + - - - Qt::Horizontal + + + Generate - - - 40 - 20 - + + + + + + 0 - + - + + + PointingHandCursor + - Save Prompts + Interrupt - - - - - 8 - true - - - - Prompt - - - - - - - Enter a prompt... - - - - - - - - - 9 - - - - - - 8 - true - - - - Negative Prompt - - - - - - - Enter a negative prompt... - - - + + + Tab 2 + + + + + + + + + + ChatPromptWidget + QWidget +
airunner/widgets/llm/chat_prompt_widget
+ 1 +
+
@@ -192,8 +256,8 @@ action_clicked_button_save_prompts() - 349 - 23 + 322 + 47 502 diff --git a/src/airunner/widgets/generator_form/templates/generatorform_ui.py b/src/airunner/widgets/generator_form/templates/generatorform_ui.py index 180de7232..a2236295d 100644 --- a/src/airunner/widgets/generator_form/templates/generatorform_ui.py +++ b/src/airunner/widgets/generator_form/templates/generatorform_ui.py @@ -17,34 +17,27 @@ def setupUi(self, generator_form): font.setPointSize(8) generator_form.setFont(font) generator_form.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) - self.gridLayout_7 = QtWidgets.QGridLayout(generator_form) - self.gridLayout_7.setContentsMargins(0, 0, 0, 0) - self.gridLayout_7.setObjectName("gridLayout_7") - self.scrollArea = QtWidgets.QScrollArea(parent=generator_form) + self.gridLayout_4 = QtWidgets.QGridLayout(generator_form) + self.gridLayout_4.setContentsMargins(0, 0, 0, 0) + self.gridLayout_4.setObjectName("gridLayout_4") + self.tabWidget = QtWidgets.QTabWidget(parent=generator_form) + self.tabWidget.setObjectName("tabWidget") + self.tab = QtWidgets.QWidget() + self.tab.setObjectName("tab") + self.gridLayout_3 = QtWidgets.QGridLayout(self.tab) + self.gridLayout_3.setContentsMargins(0, 0, 0, 0) + self.gridLayout_3.setObjectName("gridLayout_3") + self.scrollArea = QtWidgets.QScrollArea(parent=self.tab) self.scrollArea.setFrameShape(QtWidgets.QFrame.Shape.NoFrame) self.scrollArea.setFrameShadow(QtWidgets.QFrame.Shadow.Plain) self.scrollArea.setWidgetResizable(True) self.scrollArea.setObjectName("scrollArea") self.scrollAreaWidgetContents = QtWidgets.QWidget() - self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 361, 1064)) + self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 357, 1038)) self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents") self.gridLayout = QtWidgets.QGridLayout(self.scrollAreaWidgetContents) - self.gridLayout.setContentsMargins(-1, 2, -1, 9) + self.gridLayout.setContentsMargins(0, 0, 0, 0) self.gridLayout.setObjectName("gridLayout") - self.horizontalLayout_9 = QtWidgets.QHBoxLayout() - self.horizontalLayout_9.setObjectName("horizontalLayout_9") - self.generate_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents) - self.generate_button.setObjectName("generate_button") - self.horizontalLayout_9.addWidget(self.generate_button) - self.progress_bar = QtWidgets.QProgressBar(parent=self.scrollAreaWidgetContents) - self.progress_bar.setProperty("value", 0) - self.progress_bar.setObjectName("progress_bar") - self.horizontalLayout_9.addWidget(self.progress_bar) - self.interrupt_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents) - self.interrupt_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) - self.interrupt_button.setObjectName("interrupt_button") - self.horizontalLayout_9.addWidget(self.interrupt_button) - self.gridLayout.addLayout(self.horizontalLayout_9, 1, 0, 1, 1) self.splitter = QtWidgets.QSplitter(parent=self.scrollAreaWidgetContents) self.splitter.setOrientation(QtCore.Qt.Orientation.Vertical) self.splitter.setObjectName("splitter") @@ -60,6 +53,14 @@ def setupUi(self, generator_form): self.pushButton = QtWidgets.QPushButton(parent=self.layoutWidget) self.pushButton.setObjectName("pushButton") self.horizontalLayout_6.addWidget(self.pushButton) + self.ai_button = QtWidgets.QPushButton(parent=self.layoutWidget) + self.ai_button.setText("") + icon = QtGui.QIcon() + icon.addPixmap(QtGui.QPixmap(":/icons/light/artificial-intelligence-ai-chip-icon.svg"), QtGui.QIcon.Mode.Normal, QtGui.QIcon.State.Off) + self.ai_button.setIcon(icon) + self.ai_button.setCheckable(True) + self.ai_button.setObjectName("ai_button") + self.horizontalLayout_6.addWidget(self.ai_button) self.gridLayout_2.addLayout(self.horizontalLayout_6, 0, 1, 1, 1) self.label = QtWidgets.QLabel(parent=self.layoutWidget) font = QtGui.QFont() @@ -86,11 +87,36 @@ def setupUi(self, generator_form): self.negative_prompt = QtWidgets.QPlainTextEdit(parent=self.layoutWidget1) self.negative_prompt.setObjectName("negative_prompt") self.verticalLayout_6.addWidget(self.negative_prompt) - self.gridLayout.addWidget(self.splitter, 0, 0, 1, 1) + self.gridLayout.addWidget(self.splitter, 0, 1, 1, 1) + self.horizontalLayout_9 = QtWidgets.QHBoxLayout() + self.horizontalLayout_9.setObjectName("horizontalLayout_9") + self.generate_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents) + self.generate_button.setObjectName("generate_button") + self.horizontalLayout_9.addWidget(self.generate_button) + self.progress_bar = QtWidgets.QProgressBar(parent=self.scrollAreaWidgetContents) + self.progress_bar.setProperty("value", 0) + self.progress_bar.setObjectName("progress_bar") + self.horizontalLayout_9.addWidget(self.progress_bar) + self.interrupt_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents) + self.interrupt_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) + self.interrupt_button.setObjectName("interrupt_button") + self.horizontalLayout_9.addWidget(self.interrupt_button) + self.gridLayout.addLayout(self.horizontalLayout_9, 1, 1, 1, 1) self.scrollArea.setWidget(self.scrollAreaWidgetContents) - self.gridLayout_7.addWidget(self.scrollArea, 1, 0, 1, 1) + self.gridLayout_3.addWidget(self.scrollArea, 0, 0, 1, 1) + self.tabWidget.addTab(self.tab, "") + self.tab_2 = QtWidgets.QWidget() + self.tab_2.setObjectName("tab_2") + self.gridLayout_5 = QtWidgets.QGridLayout(self.tab_2) + self.gridLayout_5.setObjectName("gridLayout_5") + self.widget = ChatPromptWidget(parent=self.tab_2) + self.widget.setObjectName("widget") + self.gridLayout_5.addWidget(self.widget, 0, 0, 1, 1) + self.tabWidget.addTab(self.tab_2, "") + self.gridLayout_4.addWidget(self.tabWidget, 0, 0, 1, 1) self.retranslateUi(generator_form) + self.tabWidget.setCurrentIndex(0) self.pushButton.clicked.connect(generator_form.action_clicked_button_save_prompts) # type: ignore self.interrupt_button.clicked.connect(generator_form.handle_interrupt_button_clicked) # type: ignore self.generate_button.clicked.connect(generator_form.handle_generate_button_clicked) # type: ignore @@ -101,10 +127,13 @@ def setupUi(self, generator_form): def retranslateUi(self, generator_form): _translate = QtCore.QCoreApplication.translate generator_form.setWindowTitle(_translate("generator_form", "Form")) - self.generate_button.setText(_translate("generator_form", "Generate")) - self.interrupt_button.setText(_translate("generator_form", "Interrupt")) self.pushButton.setText(_translate("generator_form", "Save Prompts")) self.label.setText(_translate("generator_form", "Prompt")) self.prompt.setPlaceholderText(_translate("generator_form", "Enter a prompt...")) self.label_2.setText(_translate("generator_form", "Negative Prompt")) self.negative_prompt.setPlaceholderText(_translate("generator_form", "Enter a negative prompt...")) + self.generate_button.setText(_translate("generator_form", "Generate")) + self.interrupt_button.setText(_translate("generator_form", "Interrupt")) + self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab), _translate("generator_form", "Tab 1")) + self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_2), _translate("generator_form", "Tab 2")) +from airunner.widgets.llm.chat_prompt_widget import ChatPromptWidget diff --git a/src/airunner/widgets/image/image_widget.py b/src/airunner/widgets/image/image_widget.py index d4b027f6f..2e7a0c91b 100644 --- a/src/airunner/widgets/image/image_widget.py +++ b/src/airunner/widgets/image/image_widget.py @@ -244,7 +244,6 @@ def __init__(self, *args, **kwargs): def handle_label_clicked(self, event): # get the clicked object if event.button() == Qt.MouseButton.LeftButton: - # check if shift is down shift_pressed = event.modifiers() == Qt.KeyboardModifier.ShiftModifier self.container.activate_brush(self, self.brush, shift_pressed) elif event.button() == Qt.MouseButton.RightButton: diff --git a/src/airunner/widgets/input_image/input_image_settings_widget.py b/src/airunner/widgets/input_image/input_image_settings_widget.py index c58a28b5b..0558e51c9 100644 --- a/src/airunner/widgets/input_image/input_image_settings_widget.py +++ b/src/airunner/widgets/input_image/input_image_settings_widget.py @@ -43,15 +43,7 @@ def current_input_image(self): except AttributeError: return None - def initialize( - self, - generator_name, - generator_section - ): - self.setProperty("generator_name", generator_name) - self.setProperty("generator_section", generator_section) - self.settings_manager.generator_name = generator_name - self.settings_manager.generator_section = generator_section + def initialize(self): self.update_buttons() self.ui.groupBox.setTitle(self.property("checkbox_label")) self.ui.scale_slider_widget.initialize() diff --git a/src/airunner/widgets/layers/layer_container_widget.py b/src/airunner/widgets/layers/layer_container_widget.py index 4be82f88c..39e5f154c 100644 --- a/src/airunner/widgets/layers/layer_container_widget.py +++ b/src/airunner/widgets/layers/layer_container_widget.py @@ -26,10 +26,6 @@ def current_layer(self): except IndexError: Logger.error(f"No current layer for index {self.current_layer_index}") - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.app.loaded.connect(self.initialize) - def initialize(self): self.ui.scrollAreaWidgetContents.layout().addSpacerItem( QSpacerItem(0, 0, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding)) @@ -195,7 +191,7 @@ def delete_selected_layers(self): self.delete_layer(index=index, layer=layer) self.selected_layers = {} self.show_layers() - self.app.canvas.update() + self.app.standard_image_panel.canvas_widget.do_draw() def delete_layer(self, _value=False, index=None, layer=None): Logger.info(f"delete_layer requested index {index}") @@ -208,7 +204,7 @@ def delete_layer(self, _value=False, index=None, layer=None): if current_index is None: current_index = self.current_layer_index Logger.info(f"Deleting layer {current_index}") - self.app.canvas.delete_image() + self.app.standard_image_panel.canvas_widget.delete_image() self.app.history.add_event({ "event": "delete_layer", "layers": self.get_layers_copy(), @@ -302,6 +298,7 @@ def toggle_layer_visibility(self, layer, layer_obj): layer.visible = not layer.visible self.update() layer_obj.set_icon() + self.app.canvas_widget.do_draw() def handle_move_layer(self, event): point = QPoint( diff --git a/src/airunner/widgets/layers/layer_widget.py b/src/airunner/widgets/layers/layer_widget.py index 9dc4500aa..5a5067329 100644 --- a/src/airunner/widgets/layers/layer_widget.py +++ b/src/airunner/widgets/layers/layer_widget.py @@ -64,7 +64,7 @@ def action_clicked_button_toggle_layer_visibility(self, val): self.layer_data.visible = val session = get_session() session.commit() - self.app.canvas.do_draw() + self.app.canvas_widget.do_draw() def set_thumbnail(self): image = self.layer_data.image diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py new file mode 100644 index 000000000..cdb60c044 --- /dev/null +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -0,0 +1,307 @@ +from PyQt6.QtCore import pyqtSlot +from PyQt6.QtWidgets import QSpacerItem, QSizePolicy + +from airunner.aihandler.enums import MessageCode +from airunner.data.db import session +from airunner.data.models import Conversation, LLMPromptTemplate, Message +from airunner.widgets.base_widget import BaseWidget +from airunner.widgets.llm.templates.chat_prompt_ui import Ui_chat_prompt +from airunner.widgets.llm.message_widget import MessageWidget +from airunner.utils import save_session +from airunner.aihandler.logger import Logger + + +class ChatPromptWidget(BaseWidget): + widget_class_ = Ui_chat_prompt + conversation = None + is_modal = True + generating = False + prefix = "" + prompt = "" + suffix = "" + conversation_history = [] + spacer = None + + @property + def generator(self): + try: + return self.app.ui.llm_widget.generator + except Exception as e: + Logger.error(e) + import traceback + traceback.print_exc() + + @property + def generator_settings(self): + try: + return self.app.ui.llm_widget.generator_settings + except Exception as e: + Logger.error(e) + + @property + def instructions(self): + return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind." + + @property + def current_generator(self): + return self.settings_manager.current_llm_generator + + @property + def instructions(self): + return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind." + + def load_data(self): + self.conversation = session.query(Conversation).first() + if self.conversation is None: + self.conversation = Conversation() + session.add(self.conversation) + session.commit() + + def initialize(self): + self.load_data() + + self.app.token_signal.connect(self.handle_token_signal) + self.app.message_var.my_signal.connect(self.message_handler) + self.ui.prompt.returnPressed.connect(self.action_button_clicked_send) + self.ui.prompt.textChanged.connect(self.prompt_text_changed) + self.ui.conversation.hide() + self.ui.chat_container.show() + + def handle_token_signal(self, val): + if val != "[END]": + text = self.ui.conversation.toPlainText() + text += val + self.ui.conversation.setText(text) + else: + self.stop_progress_bar() + self.generating = False + self.enable_send_button() + + @pyqtSlot(dict) + def message_handler(self, response: dict): + try: + code = response["code"] + except TypeError: + return + message = response["message"] + + if code == MessageCode.TEXT_GENERATED: + self.handle_text_generated(message) + + def handle_text_generated(self, message): + self.stop_progress_bar() + + # # check if messages is string or list + # if isinstance(messages, str): + # messages = [messages] + + # #print("MESSAGES", messages) + + # if messages is None: + # return + + # # get last message + # message = messages[-1]["content"] + + # strip quotes from start and end of message + if not message: + return + if message.startswith("\""): + message = message[1:] + if message.endswith("\""): + message = message[:-1] + message_object = Message( + name=self.generator.botname, + message=message, + conversation=self.conversation + ) + session.add(message_object) + session.commit() + + if self.settings_manager.enable_tts: + # split on sentence enders + sentence_enders = [".", "?", "!", "\n"] + text = message_object.message + sentences = [] + # split text into sentences + current_sentence = "" + for char in text: + current_sentence += char + if char in sentence_enders: + sentences.append(current_sentence) + current_sentence = "" + if current_sentence != "": + sentences.append(current_sentence) + + for index, sentence in enumerate(sentences): + sentence = sentence.strip() + self.app.client.message = dict( + tts_request=True, + request_data=dict( + text=sentence, + message_object=Message( + name=message_object.name, + message=sentence, + ), + is_bot=True, + callback=self.add_message_to_conversation, + first_message=index == 0, + last_message=index == len(sentences) - 1, + ) + ) + + self.add_message_to_conversation(message_object, is_bot=True) + + self.generating = False + self.enable_send_button() + + def prompt_text_changed(self, val): + self.prompt = val + + def clear_prompt(self): + self.ui.prompt.setText("") + + def start_progress_bar(self): + self.ui.progressBar.setRange(0, 0) + self.ui.progressBar.setValue(0) + + def stop_progress_bar(self): + self.ui.progressBar.setRange(0, 1) + self.ui.progressBar.setValue(1) + self.ui.progressBar.reset() + + def disable_send_button(self): + self.ui.send_button.setEnabled(False) + + def enable_send_button(self): + self.ui.send_button.setEnabled(True) + + def response_text_changed(self): + pass + + def parent(self): + return self.app.ui.llm_widget + + def action_button_clicked_send(self, image_override=None, prompt_override=None, callback=None, generator_name="casuallm"): + if self.generating: + Logger.warning("Already generating") + return + + self.generating = True + self.disable_send_button() + #user_input = f"{self.generator.username} Says: \"{self.prompt}\"" + # conversation = "\n".join(self.conversation_history) + # suffix = "\n".join([self.suffix, f'{self.generator.botname} Says: ']) + # prompt = "\n".join([self.instructions, self.prefix, conversation, input, suffix]) + + image = self.app.current_active_image() if (image_override is None or image_override is False) else image_override + + prompt = self.prompt if (prompt_override is None or prompt_override == "") else prompt_override + if prompt is None or prompt == "": + Logger.warning("Prompt is empty") + return + + print(self.generator.prompt_template) + prompt_template = session.query(LLMPromptTemplate).filter( + LLMPromptTemplate.name == self.generator.prompt_template + ).first() + + data = { + "llm_request": True, + "request_data": { + "generator_name": generator_name, + "model_path": self.generator_settings.model_version, + "stream": True, + "prompt": prompt, + "do_summary": False, + "is_bot_alive": True, + "conversation_history": self.conversation_history, + "generator": self.generator, + "prefix": self.prefix, + "suffix": self.suffix, + "dtype": self.generator_settings.dtype, + "use_gpu": self.generator_settings.use_gpu, + "request_type": "image_caption_generator", + "username": self.generator.username, + "botname": self.generator.botname, + "prompt_template": prompt_template.template, + "parameters": { + "override_parameters": self.generator.override_parameters, + "top_p": self.generator_settings.top_p / 100.0, + "max_length": self.generator_settings.max_length, + "repetition_penalty": self.generator_settings.repetition_penalty / 100.0, + "min_length": self.generator_settings.min_length, + "length_penalty": self.generator_settings.length_penalty / 100, + "num_beams": self.generator_settings.num_beams, + "ngram_size": self.generator_settings.ngram_size, + "temperature": self.generator_settings.temperature / 10000.0, + "sequences": self.generator_settings.sequences, + "top_k": self.generator_settings.top_k, + "eta_cutoff": self.generator_settings.eta_cutoff / 100.0, + "seed": self.generator_settings.do_sample, + "early_stopping": self.generator_settings.early_stopping, + }, + "image": image, + "callback": callback + } + } + message_object = Message( + name=self.generator.username, + message=self.prompt, + conversation=self.conversation + ) + session.add(message_object) + session.commit() + self.app.client.message = data + self.add_message_to_conversation(message_object=message_object, is_bot=False) + self.clear_prompt() + self.start_progress_bar() + + def describe_image(self, image, callback): + self.action_button_clicked_send( + image_override=image, + prompt_override="What is in this picture?", + callback=callback, + generator_name="visualqa" + ) + + def add_message_to_conversation(self, message_object, is_bot, first_message=True, last_message=True): + # remove spacer from self.ui.chat_container + widget = MessageWidget(message=message_object, is_bot=is_bot) + self.ui.scrollAreaWidgetContents.layout().addWidget(widget) + + if self.spacer is not None: + self.ui.scrollAreaWidgetContents.layout().removeItem(self.spacer) + + message = "" + if first_message: + message = f"{message_object.name} Says: \"" + message += message_object.message + if last_message: + message += "\"" + + if first_message: + self.conversation_history.append(message) + if not first_message: + self.conversation_history[-1] += message + self.ui.conversation.undo() + self.ui.conversation.append(self.conversation_history[-1]) + + # add a vertical spacer to self.ui.chat_container + if self.spacer is None: + self.spacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding) + self.ui.scrollAreaWidgetContents.layout().addItem(self.spacer) + + def action_button_clicked_clear_conversation(self): + self.conversation_history = [] + self.ui.conversation.setText("") + self.app.client.message = { + "llm_request": True, + "request_data": { + "request_type": "clear_conversation", + } + } + + def message_type_text_changed(self, val): + self.generator.message_type = val + save_session() diff --git a/src/airunner/widgets/llm/llm_preferences_widget.py b/src/airunner/widgets/llm/llm_preferences_widget.py new file mode 100644 index 000000000..91fed15a1 --- /dev/null +++ b/src/airunner/widgets/llm/llm_preferences_widget.py @@ -0,0 +1,62 @@ +from airunner.widgets.base_widget import BaseWidget +from airunner.widgets.llm.templates.llm_preferences_ui import Ui_llm_preferences_widget +from airunner.utils import get_session, save_session +from airunner.data.models import LLMPromptTemplate +from airunner.widgets.base_widget import BaseWidget +from airunner.aihandler.logger import Logger + + + +class LLMPreferencesWidget(BaseWidget): + widget_class_ = Ui_llm_preferences_widget + + @property + def generator(self): + try: + return self.app.ui.llm_widget.generator + except Exception as e: + Logger.error(e) + import traceback + traceback.print_exc() + + @property + def generator_settings(self): + try: + return self.app.ui.llm_widget.generator_settings + except Exception as e: + Logger.error(e) + + def initialize(self): + self.ui.prefix.blockSignals(True) + self.ui.suffix.blockSignals(True) + self.ui.personality_type.blockSignals(True) + if self.generator: + self.ui.prefix.setPlainText(self.generator.prefix) + self.ui.suffix.setPlainText(self.generator.suffix) + self.ui.personality_type.setCurrentText(self.generator.bot_personality) + self.ui.prefix.blockSignals(False) + self.ui.suffix.blockSignals(False) + self.ui.personality_type.blockSignals(False) + + def action_button_clicked_generate_characters(self): + pass + + def personality_type_changed(self, val): + self.generator.bot_personality = val + save_session() + + def prefix_text_changed(self): + self.generator.prefix = self.ui.prefix.toPlainText() + save_session() + + def suffix_text_changed(self): + self.generator.suffix = self.ui.suffix.toPlainText() + save_session() + + def username_text_changed(self, val): + self.generator.username = val + save_session() + + def botname_text_changed(self, val): + self.generator.botname = val + save_session() \ No newline at end of file diff --git a/src/airunner/widgets/llm/llm_settings_widget.py b/src/airunner/widgets/llm/llm_settings_widget.py new file mode 100644 index 000000000..6af9d0187 --- /dev/null +++ b/src/airunner/widgets/llm/llm_settings_widget.py @@ -0,0 +1,263 @@ +""" +This class should be used to create a window widget for the LLM. +""" +from PyQt6.QtWidgets import QWidget + +from airunner.widgets.base_widget import BaseWidget +from airunner.widgets.llm.templates.llm_settings_ui import Ui_llm_settings_widget +from airunner.utils import save_session, get_session +from airunner.data.models import LLMGeneratorSetting, LLMGenerator, AIModel, LLMPromptTemplate +from airunner.aihandler.logger import Logger + + +class LLMSettingsWidget(BaseWidget): + widget_class_ = Ui_llm_settings_widget + current_generator = None + dtype_descriptions = { + "2bit": "Fastest, least amount of VRAM, GPU only, least accurate results.", + "4bit": "Faster, much less VRAM, GPU only, much less accurate results.", + "8bit": "Fast, less VRAM, GPU only, less accurate results.", + "16bit": "Normal speed, some VRAM, uses GPU, slightly less accurate results.", + "32bit": "Slow, no VRAM, uses CPU, most accurate results.", + } + + @property + def generator(self): + try: + return self.app.ui.llm_widget.generator + except Exception as e: + Logger.error(e) + import traceback + traceback.print_exc() + + @property + def generator_settings(self): + try: + return self.app.ui.llm_widget.generator_settings + except Exception as e: + Logger.error(e) + + @property + def current_generator(self): + return self.settings_manager.get_value("current_llm_generator") + + def initialize(self): + self.initialize_form() + + def early_stopping_toggled(self, val): + self.generator.generator_settings[0].early_stopping = val + save_session() + + def do_sample_toggled(self, val): + self.generator.generator_settings[0].do_sample = val + save_session() + + def toggle_leave_model_in_vram(self, val): + if val: + self.settings_manager.set_value("unload_unused_model", False) + self.settings_manager.set_value("move_unused_model_to_cpu", False) + + def initialize_form(self): + session = get_session() + self.ui.prompt_template.blockSignals(True) + self.ui.model.blockSignals(True) + self.ui.model_version.blockSignals(True) + self.ui.radio_button_2bit.blockSignals(True) + self.ui.radio_button_4bit.blockSignals(True) + self.ui.radio_button_8bit.blockSignals(True) + self.ui.radio_button_16bit.blockSignals(True) + self.ui.radio_button_32bit.blockSignals(True) + self.ui.random_seed.blockSignals(True) + self.ui.do_sample.blockSignals(True) + self.ui.early_stopping.blockSignals(True) + self.ui.use_gpu_checkbox.blockSignals(True) + self.ui.override_parameters.blockSignals(True) + self.ui.top_p.initialize_properties() + self.ui.max_length.initialize_properties() + self.ui.max_length.initialize_properties() + self.ui.repetition_penalty.initialize_properties() + self.ui.min_length.initialize_properties() + self.ui.length_penalty.initialize_properties() + self.ui.num_beams.initialize_properties() + self.ui.ngram_size.initialize_properties() + self.ui.temperature.initialize_properties() + self.ui.sequences.initialize_properties() + self.ui.top_k.initialize_properties() + + prompt_templates = [template.name for template in session.query(LLMPromptTemplate).all()] + self.ui.prompt_template.clear() + self.ui.prompt_template.addItems(prompt_templates) + + if self.generator: + self.ui.radio_button_2bit.setChecked(self.generator.generator_settings[0].dtype == "2bit") + self.ui.radio_button_4bit.setChecked(self.generator.generator_settings[0].dtype == "4bit") + self.ui.radio_button_8bit.setChecked(self.generator.generator_settings[0].dtype == "8bit") + self.ui.radio_button_16bit.setChecked(self.generator.generator_settings[0].dtype == "16bit") + self.ui.radio_button_32bit.setChecked(self.generator.generator_settings[0].dtype == "32bit") + self.set_dtype_by_gpu( self.generator.generator_settings[0].use_gpu) + self.set_dtype(self.generator.generator_settings[0].dtype) + + # get unique model names + model_names = [model.name for model in session.query(LLMGenerator).all()] + model_names = list(set(model_names)) + self.ui.model.clear() + self.ui.model.addItems(model_names) + self.ui.model.setCurrentText(self.current_generator) + self.update_model_version_combobox() + if self.generator: + self.ui.model_version.setCurrentText(self.generator.generator_settings[0].model_version) + self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed) + self.ui.do_sample.setChecked(self.generator.generator_settings[0].do_sample) + self.ui.early_stopping.setChecked(self.generator.generator_settings[0].early_stopping) + self.ui.use_gpu_checkbox.setChecked(self.generator.generator_settings[0].use_gpu) + self.ui.override_parameters.setChecked(self.generator.override_parameters) + + self.ui.model.blockSignals(False) + self.ui.model_version.blockSignals(False) + self.ui.radio_button_2bit.blockSignals(False) + self.ui.radio_button_4bit.blockSignals(False) + self.ui.radio_button_8bit.blockSignals(False) + self.ui.radio_button_16bit.blockSignals(False) + self.ui.radio_button_32bit.blockSignals(False) + self.ui.random_seed.blockSignals(False) + self.ui.do_sample.blockSignals(False) + self.ui.early_stopping.blockSignals(False) + self.ui.use_gpu_checkbox.blockSignals(False) + self.ui.override_parameters.blockSignals(False) + self.ui.prompt_template.blockSignals(False) + + def model_text_changed(self, val): + print("model_text_changed", val) + self.settings_manager.set_value("current_llm_generator", val) + self.update_model_version_combobox() + self.model_version_changed(self.ui.model_version.currentText()) + self.initialize_form() + + def model_version_changed(self, val): + self.generator.generator_settings[0].model_version = val + save_session() + + def toggle_move_model_to_cpu(self, val): + self.settings_manager.set_value("move_unused_model_to_cpu", val) + if val: + self.settings_manager.set_value("unload_unused_model", False) + + def override_parameters_toggled(self, val): + self.generator.override_parameters = val + save_session() + + def prompt_template_text_changed(self, value): + self.generator.prompt_template = value + save_session() + + def toggled_2bit(self, val): + if val: + self.set_dtype("2bit") + + def toggled_4bit(self, val): + if val: + self.set_dtype("4bit") + + def toggled_8bit(self, val): + if val: + self.set_dtype("8bit") + + def toggled_16bit(self, val): + if val: + self.set_dtype("16bit") + + def toggled_32bit(self, val): + if val: + self.set_dtype("32bit") + + def random_seed_toggled(self, val): + self.generator.generator_settings[0].random_seed = val + save_session() + + def seed_changed(self, val): + self.generator.generator_settings[0].seed = val + save_session() + + def toggle_unload_model(self, val): + self.settings_manager.set_value("unload_unused_model", val) + if val: + self.settings_manager.set_value("move_unused_model_to_cpu", False) + + def use_gpu_toggled(self, val): + self.generator.generator_settings[0].use_gpu = val + # toggle the 16bit radio button and disable 4bit and 8bit radio buttons + self.set_dtype_by_gpu(val) + save_session() + + def set_dtype_by_gpu(self, use_gpu): + if not use_gpu: + self.ui.radio_button_2bit.setEnabled(False) + self.ui.radio_button_4bit.setEnabled(False) + self.ui.radio_button_8bit.setEnabled(False) + self.ui.radio_button_32bit.setEnabled(True) + + if self.generator.generator_settings[0].dtype in ["4bit", "8bit"]: + self.ui.radio_button_16bit.setChecked(True) + else: + self.ui.radio_button_2bit.setEnabled(True) + self.ui.radio_button_4bit.setEnabled(True) + self.ui.radio_button_8bit.setEnabled(True) + self.ui.radio_button_32bit.setEnabled(False) + if self.generator.generator_settings[0].dtype == "32bit": + self.ui.radio_button_16bit.setChecked(True) + + def reset_settings_to_default_clicked(self): + self.generator.generator_settings[0].top_p = LLMGeneratorSetting.top_p.default.arg + self.generator.generator_settings[0].max_length = LLMGeneratorSetting.max_length.default.arg + self.generator.generator_settings[0].repetition_penalty = LLMGeneratorSetting.repetition_penalty.default.arg + self.generator.generator_settings[0].min_length = LLMGeneratorSetting.min_length.default.arg + self.generator.generator_settings[0].length_penalty = LLMGeneratorSetting.length_penalty.default.arg + self.generator.generator_settings[0].num_beams = LLMGeneratorSetting.num_beams.default.arg + self.generator.generator_settings[0].ngram_size = LLMGeneratorSetting.ngram_size.default.arg + self.generator.generator_settings[0].temperature = LLMGeneratorSetting.temperature.default.arg + self.generator.generator_settings[0].sequences = LLMGeneratorSetting.sequences.default.arg + self.generator.generator_settings[0].top_k = LLMGeneratorSetting.top_k.default.arg + self.generator.generator_settings[0].eta_cutoff = LLMGeneratorSetting.eta_cutoff.default.arg + self.generator.generator_settings[0].seed = LLMGeneratorSetting.seed.default.arg + self.generator.generator_settings[0].do_sample = LLMGeneratorSetting.do_sample.default.arg + self.generator.generator_settings[0].early_stopping = LLMGeneratorSetting.early_stopping.default.arg + self.generator.generator_settings[0].random_seed = LLMGeneratorSetting.random_seed.default.arg + self.generator.generator_settings[0].model_version = LLMGeneratorSetting.model_version.default.arg + self.generator.generator_settings[0].dtype = LLMGeneratorSetting.dtype.default.arg + self.generator.generator_settings[0].use_gpu = LLMGeneratorSetting.use_gpu.default.arg + save_session() + self.initialize_form() + self.ui.top_p.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_p) + self.ui.max_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].max_length) + self.ui.repetition_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].repetition_penalty) + self.ui.min_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].min_length) + self.ui.length_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].length_penalty) + self.ui.num_beams.set_slider_and_spinbox_values(self.generator.generator_settings[0].num_beams) + self.ui.ngram_size.set_slider_and_spinbox_values(self.generator.generator_settings[0].ngram_size) + self.ui.temperature.set_slider_and_spinbox_values(self.generator.generator_settings[0].temperature) + self.ui.sequences.set_slider_and_spinbox_values(self.generator.generator_settings[0].sequences) + self.ui.top_k.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_k) + self.ui.eta_cutoff.set_slider_and_spinbox_values(self.generator.generator_settings[0].eta_cutoff) + self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed) + + def set_dtype(self, dtype): + self.generator.generator_settings[0].dtype = dtype + save_session() + self.set_dtype_description(dtype) + + def set_dtype_description(self, dtype): + self.ui.dtype_description.setText(self.dtype_descriptions[dtype]) + + def update_model_version_combobox(self): + session = get_session() + self.ui.model_version.blockSignals(True) + self.ui.model_version.clear() + ai_model_paths = [model.path for model in session.query(AIModel).filter( + AIModel.pipeline_action == self.current_generator + )] + self.ui.model_version.addItems(ai_model_paths) + self.ui.model_version.blockSignals(False) + + def set_tab(self, tab_name): + index = self.ui.tabWidget.indexOf(self.ui.tabWidget.findChild(QWidget, tab_name)) + self.ui.tabWidget.setCurrentIndex(index) \ No newline at end of file diff --git a/src/airunner/widgets/llm/llm_widget.py b/src/airunner/widgets/llm/llm_widget.py index 1d6f7b862..997241db0 100644 --- a/src/airunner/widgets/llm/llm_widget.py +++ b/src/airunner/widgets/llm/llm_widget.py @@ -1,558 +1,43 @@ """ This class should be used to create a window widget for the LLM. """ -from PyQt6.QtCore import pyqtSlot -from PyQt6.QtWidgets import QSpacerItem, QSizePolicy -from PyQt6.QtWidgets import QWidget -from sqlalchemy import inspect -from functools import partial - -from airunner.aihandler.enums import MessageCode -from airunner.data.db import session -from airunner.data.models import AIModel, LLMGenerator, Conversation, LLMGeneratorSetting, LLMPromptTemplate, Message -from airunner.utils import save_session +from airunner.utils import get_session +from airunner.data.models import LLMGenerator, LLMGeneratorSetting from airunner.widgets.base_widget import BaseWidget from airunner.widgets.llm.templates.llm_widget_ui import Ui_llm_widget +from airunner.aihandler.logger import Logger class LLMWidget(BaseWidget): widget_class_ = Ui_llm_widget generator = None - conversation = None - is_modal = True - generating = False - prefix = "" - prompt = "" - suffix = "" - conversation_history = [] - - @property - def current_generator(self): - return self.settings_manager.current_llm_generator + _generator = None + _generator_settings = None @property - def instructions(self): - return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind." - - def load_data(self): - self.load_generator() - self.conversation = session.query(Conversation).first() - if self.conversation is None: - self.conversation = Conversation() - session.add(self.conversation) - session.commit() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # hide tab bar - #self.ui.tabWidget.tabBar().hide() - self.load_data() - - self.ui.prefix.blockSignals(True) - self.ui.suffix.blockSignals(True) - if self.generator: - self.ui.prefix.setPlainText(self.generator.prefix) - self.ui.suffix.setPlainText(self.generator.suffix) - self.initialize_form() - self.ui.prefix.blockSignals(False) - self.ui.suffix.blockSignals(False) - - self.app.token_signal.connect(self.handle_token_signal) - self.app.message_var.my_signal.connect(self.message_handler) - self.ui.prompt.returnPressed.connect(self.action_button_clicked_send) - self.ui.prompt.textChanged.connect(self.prompt_text_changed) - leave_in_vram = not self.settings_manager.move_unused_model_to_cpu and not self.settings_manager.unload_unused_model - self.ui.leave_in_vram.setChecked(leave_in_vram) - self.ui.move_to_cpu.setChecked(self.settings_manager.move_unused_model_to_cpu) - self.ui.unload_model.setChecked(self.settings_manager.unload_unused_model) - - prompt_templates = session.query(LLMPromptTemplate).all() - self.ui.prompt_template.blockSignals(True) - for prompt_template in prompt_templates: - self.ui.prompt_template.addItem(prompt_template.name) - self.ui.prompt_template.blockSignals(False) + def generator(self): + if self._generator is None: + session = get_session() + try: + self._generator = session.query(LLMGenerator).filter( + LLMGenerator.name == self.ui.llm_settings_widget.current_generator + ).first() + if self._generator is None: + Logger.error("Unable to locate generator by name " + self.ui.llm_settings_widget.current_generator if self.ui.llm_settings_widget.current_generator else "None") + except Exception as e: + Logger.error(e) + return self._generator - def prompt_template_text_changed(self, value): - print(value) - - def handle_token_signal(self, val): - if val != "[END]": - text = self.ui.conversation.toPlainText() - text += val - self.ui.conversation.setText(text) - else: - self.stop_progress_bar() - self.generating = False - self.enable_send_button() - - @pyqtSlot(dict) - def message_handler(self, response: dict): + @property + def generator_settings(self): try: - code = response["code"] - except TypeError: - return - message = response["message"] - - if code == MessageCode.TEXT_GENERATED: - print("MESSAGE HANDLER", response) - self.handle_text_generated(message) - - def initialize_form(self): - self.ui.model.blockSignals(True) - self.ui.model_version.blockSignals(True) - self.ui.prompt.blockSignals(True) - self.ui.botname.blockSignals(True) - self.ui.username.blockSignals(True) - self.ui.prefix.blockSignals(True) - self.ui.suffix.blockSignals(True) - self.ui.personality_type.blockSignals(True) - self.ui.radio_button_2bit.blockSignals(True) - self.ui.radio_button_4bit.blockSignals(True) - self.ui.radio_button_8bit.blockSignals(True) - self.ui.radio_button_16bit.blockSignals(True) - self.ui.radio_button_32bit.blockSignals(True) - self.ui.random_seed.blockSignals(True) - self.ui.do_sample.blockSignals(True) - self.ui.early_stopping.blockSignals(True) - self.ui.use_gpu_checkbox.blockSignals(True) - self.ui.override_parameters.blockSignals(True) - - self.ui.top_p.initialize_properties() - self.ui.max_length.initialize_properties() - self.ui.max_length.initialize_properties() - self.ui.repetition_penalty.initialize_properties() - self.ui.min_length.initialize_properties() - self.ui.length_penalty.initialize_properties() - self.ui.num_beams.initialize_properties() - self.ui.ngram_size.initialize_properties() - self.ui.temperature.initialize_properties() - self.ui.sequences.initialize_properties() - self.ui.top_k.initialize_properties() - - if self.generator: - self.ui.radio_button_2bit.setChecked(self.generator.generator_settings[0].dtype == "2bit") - self.ui.radio_button_4bit.setChecked(self.generator.generator_settings[0].dtype == "4bit") - self.ui.radio_button_8bit.setChecked(self.generator.generator_settings[0].dtype == "8bit") - self.ui.radio_button_16bit.setChecked(self.generator.generator_settings[0].dtype == "16bit") - self.ui.radio_button_32bit.setChecked(self.generator.generator_settings[0].dtype == "32bit") - self.set_dtype_by_gpu( self.generator.generator_settings[0].use_gpu) - self.set_dtype(self.generator.generator_settings[0].dtype) - - # get unique model names - model_names = [model.name for model in session.query(LLMGenerator).all()] - model_names = list(set(model_names)) - self.ui.model.clear() - self.ui.model.addItems(model_names) - self.ui.model.setCurrentText(self.current_generator) - if self.generator: - self.ui.username.setText(self.generator.username) - self.ui.botname.setText(self.generator.botname) - self.update_model_version_combobox() - if self.generator: - self.ui.model_version.setCurrentText(self.generator.generator_settings[0].model_version) - self.ui.personality_type.setCurrentText(self.generator.bot_personality) - self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed) - self.ui.do_sample.setChecked(self.generator.generator_settings[0].do_sample) - self.ui.early_stopping.setChecked(self.generator.generator_settings[0].early_stopping) - self.ui.use_gpu_checkbox.setChecked(self.generator.generator_settings[0].use_gpu) - self.ui.override_parameters.setChecked(self.generator.override_parameters) - - self.ui.model.blockSignals(False) - self.ui.model_version.blockSignals(False) - self.ui.prompt.blockSignals(False) - self.ui.botname.blockSignals(False) - self.ui.username.blockSignals(False) - self.ui.prefix.blockSignals(False) - self.ui.suffix.blockSignals(False) - self.ui.personality_type.blockSignals(False) - self.ui.radio_button_2bit.blockSignals(False) - self.ui.radio_button_4bit.blockSignals(False) - self.ui.radio_button_8bit.blockSignals(False) - self.ui.radio_button_16bit.blockSignals(False) - self.ui.radio_button_32bit.blockSignals(False) - self.ui.random_seed.blockSignals(False) - self.ui.do_sample.blockSignals(False) - self.ui.early_stopping.blockSignals(False) - self.ui.use_gpu_checkbox.blockSignals(False) - self.ui.override_parameters.blockSignals(False) - - def handle_text_generated(self, messages): - self.stop_progress_bar() - - # check if messages is string or list - if isinstance(messages, str): - messages = [messages] - - print("MESSAGES", messages) - - for message in messages: - - # strip quotes from start and end of message - if not message: - return - if message.startswith("\""): - message = message[1:] - if message.endswith("\""): - message = message[:-1] - message_object = Message( - name=self.generator.botname, - message=message, - conversation=self.conversation - ) - session.add(message_object) - session.commit() - - if self.settings_manager.enable_tts: - # split on sentence enders - sentence_enders = [".", "?", "!", "\n"] - text = message_object.message - sentences = [] - # split text into sentences - current_sentence = "" - for char in text: - current_sentence += char - if char in sentence_enders: - sentences.append(current_sentence) - current_sentence = "" - if current_sentence != "": - sentences.append(current_sentence) - - for index, sentence in enumerate(sentences): - sentence = sentence.strip() - self.app.client.message = dict( - tts_request=True, - request_data=dict( - text=sentence, - message_object=Message( - name=message_object.name, - message=sentence, - ), - is_bot=True, - callback=self.add_message_to_conversation, - first_message=index == 0, - last_message=index == len(sentences) - 1, - ) - ) - - if not self.settings_manager.enable_tts: - self.add_message_to_conversation(message_object, is_bot=True) - - self.generating = False - self.enable_send_button() - - def personality_type_changed(self, val): - self.generator.bot_personality = val - save_session() - - def prefix_text_changed(self): - self.generator.prefix = self.ui.prefix.toPlainText() - save_session() - - def prompt_text_changed(self, val): - self.prompt = val + return self.generator.generator_settings[0] + except Exception as e: + Logger.error(e) + return None - def suffix_text_changed(self): - self.generator.suffix = self.ui.suffix.toPlainText() - save_session() - - def clear_prompt(self): - self.ui.prompt.setText("") - - def start_progress_bar(self): - self.ui.progressBar.setRange(0, 0) - self.ui.progressBar.setValue(0) - - def stop_progress_bar(self): - self.ui.progressBar.setRange(0, 1) - self.ui.progressBar.setValue(1) - self.ui.progressBar.reset() - - def disable_send_button(self): - self.ui.send_button.setEnabled(False) - - def enable_send_button(self): - self.ui.send_button.setEnabled(True) - - def seed_changed(self, val): - self.generator.generator_settings[0].seed = val - save_session() - - def response_text_changed(self): - pass - - def username_text_changed(self, val): - self.generator.username = val - save_session() - - def random_seed_toggled(self, val): - self.generator.generator_settings[0].random_seed = val - save_session() - - def model_version_changed(self, val): - self.generator.generator_settings[0].model_version = val - save_session() - - def early_stopping_toggled(self, val): - self.generator.generator_settings[0].early_stopping = val - save_session() - - def do_sample_toggled(self, val): - self.generator.generator_settings[0].do_sample = val - save_session() - - def botname_text_changed(self, val): - self.generator.botname = val - save_session() - - def action_button_clicked_send(self, image_override=None, prompt_override=None, callback=None, generator_name="casuallm"): - if self.generating: - return - - self.load_generator() - self.generating = True - self.disable_send_button() - #user_input = f"{self.generator.username} Says: \"{self.prompt}\"" - # conversation = "\n".join(self.conversation_history) - # suffix = "\n".join([self.suffix, f'{self.generator.botname} Says: ']) - # prompt = "\n".join([self.instructions, self.prefix, conversation, input, suffix]) - prompt_template = session.query(LLMPromptTemplate).filter( - LLMPromptTemplate.name == self.ui.prompt_template.currentText() - ).first() - - image = self.app.current_active_image() if (image_override is None or image_override is False) else image_override - - prompt = self.prompt if prompt_override is None else prompt_override - - settings = self.generator.generator_settings[0] - data = { - "llm_request": True, - "request_data": { - "generator_name": generator_name, - "model_path": settings.model_version, - "stream": True, - "prompt": prompt, - "do_summary": False, - "is_bot_alive": True, - "conversation_history": self.conversation_history, - "generator": self.generator, - "prefix": self.prefix, - "suffix": self.suffix, - "dtype": settings.dtype, - "use_gpu": settings.use_gpu, - "request_type": "image_caption_generator", - "username": self.generator.username, - "botname": self.generator.botname, - "prompt_template": prompt_template.template, - "parameters": { - "override_parameters": self.generator.override_parameters, - "top_p": settings.top_p / 100.0, - "max_length": settings.max_length, - "repetition_penalty": settings.repetition_penalty / 100.0, - "min_length": settings.min_length, - "length_penalty": settings.length_penalty / 100, - "num_beams": settings.num_beams, - "ngram_size": settings.ngram_size, - "temperature": settings.temperature / 10000.0, - "sequences": settings.sequences, - "top_k": settings.top_k, - "eta_cutoff": settings.eta_cutoff / 100.0, - "seed": settings.do_sample, - "early_stopping": settings.early_stopping, - }, - "image": image, - "callback": callback - } - } - message_object = Message( - name=self.generator.username, - message=self.prompt, - conversation=self.conversation - ) - session.add(message_object) - session.commit() - self.app.client.message = data - self.add_message_to_conversation(message_object=message_object, is_bot=False) - self.clear_prompt() - self.start_progress_bar() - - def describe_image(self, image, callback): - print("DESCRIBE IMAGE", callback) - self.action_button_clicked_send( - image_override=image, - prompt_override="What is in this picture?", - callback=callback, - generator_name="visualqa" - ) - - def add_message_to_conversation(self, message_object, is_bot, first_message=True, last_message=True): - message = "" - if first_message: - message = f"{message_object.name} Says: \"" - message += message_object.message - if last_message: - message += "\"" - - if first_message: - self.conversation_history.append(message) - if not first_message: - self.conversation_history[-1] += message - self.ui.conversation.undo() - self.ui.conversation.append(self.conversation_history[-1]) - - - def action_button_clicked_generate_characters(self): - pass - - def action_button_clicked_clear_conversation(self): - self.conversation_history = [] - self.ui.conversation.setText("") - self.app.client.message = { - "llm_request": True, - "request_data": { - "request_type": "clear_conversation", - } - } - - def message_type_text_changed(self, val): - self.generator.message_type = val - save_session() - - dtype_descriptions = { - "2bit": "Fastest, least amount of VRAM, GPU only, least accurate results.", - "4bit": "Faster, much less VRAM, GPU only, much less accurate results.", - "8bit": "Fast, less VRAM, GPU only, less accurate results.", - "16bit": "Normal speed, some VRAM, uses GPU, slightly less accurate results.", - "32bit": "Slow, no VRAM, uses CPU, most accurate results.", - } - - def toggled_2bit(self, val): - if val: - self.set_dtype("2bit") - - def toggled_4bit(self, val): - if val: - self.set_dtype("4bit") - - def toggled_8bit(self, val): - if val: - self.set_dtype("8bit") - - def toggled_16bit(self, val): - if val: - self.set_dtype("16bit") - - def toggled_32bit(self, val): - if val: - self.set_dtype("32bit") - - def set_dtype(self, dtype): - self.generator.generator_settings[0].dtype = dtype - save_session() - self.set_dtype_description(dtype) - - def set_dtype_description(self, dtype): - self.ui.dtype_description.setText(self.dtype_descriptions[dtype]) - - def model_text_changed(self, val): - self.settings_manager.set_value("current_llm_generator", val) - self.load_generator() - self.generator.generator_settings[0].model = val - self.update_model_version_combobox() - self.model_version_changed(self.ui.model_version.currentText()) - print("MODEL TEXT CHANGED") - self.initialize_form() - - def update_model_version_combobox(self): - self.ui.model_version.blockSignals(True) - self.ui.model_version.clear() - ai_model_paths = [model.path for model in session.query(AIModel).filter( - AIModel.pipeline_action == self.current_generator - )] - self.ui.model_version.addItems(ai_model_paths) - self.ui.model_version.blockSignals(False) - - def load_generator(self): - self.generator = session.query(LLMGenerator).filter( - LLMGenerator.name == self.current_generator - ).first() - - def reset_settings_to_default_clicked(self): - self.generator.generator_settings[0].top_p = LLMGeneratorSetting.top_p.default.arg - self.generator.generator_settings[0].max_length = LLMGeneratorSetting.max_length.default.arg - self.generator.generator_settings[0].repetition_penalty = LLMGeneratorSetting.repetition_penalty.default.arg - self.generator.generator_settings[0].min_length = LLMGeneratorSetting.min_length.default.arg - self.generator.generator_settings[0].length_penalty = LLMGeneratorSetting.length_penalty.default.arg - self.generator.generator_settings[0].num_beams = LLMGeneratorSetting.num_beams.default.arg - self.generator.generator_settings[0].ngram_size = LLMGeneratorSetting.ngram_size.default.arg - self.generator.generator_settings[0].temperature = LLMGeneratorSetting.temperature.default.arg - self.generator.generator_settings[0].sequences = LLMGeneratorSetting.sequences.default.arg - self.generator.generator_settings[0].top_k = LLMGeneratorSetting.top_k.default.arg - self.generator.generator_settings[0].eta_cutoff = LLMGeneratorSetting.eta_cutoff.default.arg - self.generator.generator_settings[0].seed = LLMGeneratorSetting.seed.default.arg - self.generator.generator_settings[0].do_sample = LLMGeneratorSetting.do_sample.default.arg - self.generator.generator_settings[0].early_stopping = LLMGeneratorSetting.early_stopping.default.arg - self.generator.generator_settings[0].random_seed = LLMGeneratorSetting.random_seed.default.arg - self.generator.generator_settings[0].model_version = LLMGeneratorSetting.model_version.default.arg - self.generator.generator_settings[0].dtype = LLMGeneratorSetting.dtype.default.arg - self.generator.generator_settings[0].use_gpu = LLMGeneratorSetting.use_gpu.default.arg - save_session() - self.initialize_form() - self.ui.top_p.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_p) - self.ui.max_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].max_length) - self.ui.repetition_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].repetition_penalty) - self.ui.min_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].min_length) - self.ui.length_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].length_penalty) - self.ui.num_beams.set_slider_and_spinbox_values(self.generator.generator_settings[0].num_beams) - self.ui.ngram_size.set_slider_and_spinbox_values(self.generator.generator_settings[0].ngram_size) - self.ui.temperature.set_slider_and_spinbox_values(self.generator.generator_settings[0].temperature) - self.ui.sequences.set_slider_and_spinbox_values(self.generator.generator_settings[0].sequences) - self.ui.top_k.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_k) - self.ui.eta_cutoff.set_slider_and_spinbox_values(self.generator.generator_settings[0].eta_cutoff) - self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed) - - def use_gpu_toggled(self, val): - self.generator.generator_settings[0].use_gpu = val - # toggle the 16bit radio button and disable 4bit and 8bit radio buttons - self.set_dtype_by_gpu(val) - save_session() - - def set_dtype_by_gpu(self, use_gpu): - if not use_gpu: - self.ui.radio_button_2bit.setEnabled(False) - self.ui.radio_button_4bit.setEnabled(False) - self.ui.radio_button_8bit.setEnabled(False) - self.ui.radio_button_32bit.setEnabled(True) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - if self.generator.generator_settings[0].dtype in ["4bit", "8bit"]: - self.ui.radio_button_16bit.setChecked(True) - else: - self.ui.radio_button_2bit.setEnabled(True) - self.ui.radio_button_4bit.setEnabled(True) - self.ui.radio_button_8bit.setEnabled(True) - self.ui.radio_button_32bit.setEnabled(False) - if self.generator.generator_settings[0].dtype == "32bit": - self.ui.radio_button_16bit.setChecked(True) - - def override_parameters_toggled(self, val): - self.generator.override_parameters = val - save_session() - - def toggle_leave_model_in_vram(self, val): - print(val) - if val: - self.settings_manager.set_value("unload_unused_model", False) - self.settings_manager.set_value("move_unused_model_to_cpu", False) - - def toggle_move_model_to_cpu(self, val): - self.settings_manager.set_value("move_unused_model_to_cpu", val) - if val: - self.settings_manager.set_value("unload_unused_model", False) - - def toggle_unload_model(self, val): - self.settings_manager.set_value("unload_unused_model", val) - if val: - self.settings_manager.set_value("move_unused_model_to_cpu", False) - - def set_tab(self, tab_name): - index = self.ui.tabWidget.indexOf(self.ui.tabWidget.findChild(QWidget, tab_name)) - self.ui.tabWidget.setCurrentIndex(index) + # After the app is loaded, initialize other widgets + self.app.loaded.connect(self.initialize) \ No newline at end of file diff --git a/src/airunner/widgets/llm/message_widget.py b/src/airunner/widgets/llm/message_widget.py index b12075a0e..7a918b4eb 100644 --- a/src/airunner/widgets/llm/message_widget.py +++ b/src/airunner/widgets/llm/message_widget.py @@ -1,19 +1,25 @@ from airunner.widgets.base_widget import BaseWidget -from airunner.widgets.llm.templates.message_ui import Ui_message_widget +from airunner.widgets.llm.templates.message_ui import Ui_message class MessageWidget(BaseWidget): - widget_class_ = Ui_message_widget + widget_class_ = Ui_message def __init__(self, *args, **kwargs): self.is_bot = kwargs.pop("is_bot") self.message = kwargs.pop("message") super().__init__(*args, **kwargs) - self.ui.name.setText(f"{self.message.name}:") - self.ui.message.setPlainText(self.message.message) + self.ui.content.setPlainText(self.message.message) + name = self.message.name if self.is_bot: - self.ui.name.setStyleSheet("font-weight: normal;") + self.ui.bot_name.show() + self.ui.bot_name.setText(f"{name}") + self.ui.bot_name.setStyleSheet("font-weight: normal;") + self.ui.user_name.hide() else: - self.ui.name.setStyleSheet("font-weight: bold;") + self.ui.user_name.show() + self.ui.user_name.setText(f"{name}") + self.ui.user_name.setStyleSheet("font-weight: normal;") + self.ui.bot_name.hide() - self.ui.message.setStyleSheet("color: #f2f2f2;") + self.ui.content.setStyleSheet("color: #f2f2f2;") diff --git a/src/airunner/widgets/llm/templates/chat_prompt.ui b/src/airunner/widgets/llm/templates/chat_prompt.ui new file mode 100644 index 000000000..f9017ca49 --- /dev/null +++ b/src/airunner/widgets/llm/templates/chat_prompt.ui @@ -0,0 +1,179 @@ + + + chat_prompt + + + + 0 + 0 + 734 + 1077 + + + + Form + + + + + + true + + + + + + + true + + + + + 0 + 0 + 714 + 492 + + + + + + + + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Chat input + + + + + + + + Chat + + + + + Narrate + + + + + Generate Image + + + + + Summarize + + + + + Translate + + + + + + + + + + + + + Send + + + + + + + 0 + + + + + + + New Conversation + + + + + + + + + + + send_button + clicked() + chat_prompt + action_button_clicked_send() + + + 75 + 1057 + + + 211 + -12 + + + + + clear_conversatiion_button + clicked() + chat_prompt + action_button_clicked_clear_conversation() + + + 657 + 1052 + + + 565 + -13 + + + + + comboBox + currentTextChanged(QString) + chat_prompt + message_type_text_changed(QString) + + + 634 + 1018 + + + 465 + -3 + + + + + + action_button_clicked_send() + action_button_clicked_clear_conversation() + message_type_text_changed(QString) + + diff --git a/src/airunner/widgets/llm/templates/chat_prompt_ui.py b/src/airunner/widgets/llm/templates/chat_prompt_ui.py new file mode 100644 index 000000000..cba47dece --- /dev/null +++ b/src/airunner/widgets/llm/templates/chat_prompt_ui.py @@ -0,0 +1,79 @@ +# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/chat_prompt.ui' +# +# Created by: PyQt6 UI code generator 6.4.2 +# +# WARNING: Any manual changes made to this file will be lost when pyuic6 is +# run again. Do not edit this file unless you know what you are doing. + + +from PyQt6 import QtCore, QtGui, QtWidgets + + +class Ui_chat_prompt(object): + def setupUi(self, chat_prompt): + chat_prompt.setObjectName("chat_prompt") + chat_prompt.resize(734, 1077) + self.gridLayout = QtWidgets.QGridLayout(chat_prompt) + self.gridLayout.setObjectName("gridLayout") + self.conversation = QtWidgets.QTextEdit(parent=chat_prompt) + self.conversation.setReadOnly(True) + self.conversation.setObjectName("conversation") + self.gridLayout.addWidget(self.conversation, 0, 0, 1, 1) + self.chat_container = QtWidgets.QScrollArea(parent=chat_prompt) + self.chat_container.setWidgetResizable(True) + self.chat_container.setObjectName("chat_container") + self.scrollAreaWidgetContents = QtWidgets.QWidget() + self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 714, 492)) + self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents") + self.verticalLayout = QtWidgets.QVBoxLayout(self.scrollAreaWidgetContents) + self.verticalLayout.setObjectName("verticalLayout") + self.chat_container.setWidget(self.scrollAreaWidgetContents) + self.gridLayout.addWidget(self.chat_container, 1, 0, 1, 1) + self.widget_5 = QtWidgets.QWidget(parent=chat_prompt) + self.widget_5.setObjectName("widget_5") + self.horizontalLayout = QtWidgets.QHBoxLayout(self.widget_5) + self.horizontalLayout.setContentsMargins(0, 0, 0, 0) + self.horizontalLayout.setObjectName("horizontalLayout") + self.prompt = QtWidgets.QLineEdit(parent=self.widget_5) + self.prompt.setObjectName("prompt") + self.horizontalLayout.addWidget(self.prompt) + self.comboBox = QtWidgets.QComboBox(parent=self.widget_5) + self.comboBox.setObjectName("comboBox") + self.comboBox.addItem("") + self.comboBox.addItem("") + self.comboBox.addItem("") + self.comboBox.addItem("") + self.comboBox.addItem("") + self.horizontalLayout.addWidget(self.comboBox) + self.gridLayout.addWidget(self.widget_5, 2, 0, 1, 1) + self.horizontalLayout_2 = QtWidgets.QHBoxLayout() + self.horizontalLayout_2.setObjectName("horizontalLayout_2") + self.send_button = QtWidgets.QPushButton(parent=chat_prompt) + self.send_button.setObjectName("send_button") + self.horizontalLayout_2.addWidget(self.send_button) + self.progressBar = QtWidgets.QProgressBar(parent=chat_prompt) + self.progressBar.setProperty("value", 0) + self.progressBar.setObjectName("progressBar") + self.horizontalLayout_2.addWidget(self.progressBar) + self.clear_conversatiion_button = QtWidgets.QPushButton(parent=chat_prompt) + self.clear_conversatiion_button.setObjectName("clear_conversatiion_button") + self.horizontalLayout_2.addWidget(self.clear_conversatiion_button) + self.gridLayout.addLayout(self.horizontalLayout_2, 3, 0, 1, 1) + + self.retranslateUi(chat_prompt) + self.send_button.clicked.connect(chat_prompt.action_button_clicked_send) # type: ignore + self.clear_conversatiion_button.clicked.connect(chat_prompt.action_button_clicked_clear_conversation) # type: ignore + self.comboBox.currentTextChanged['QString'].connect(chat_prompt.message_type_text_changed) # type: ignore + QtCore.QMetaObject.connectSlotsByName(chat_prompt) + + def retranslateUi(self, chat_prompt): + _translate = QtCore.QCoreApplication.translate + chat_prompt.setWindowTitle(_translate("chat_prompt", "Form")) + self.prompt.setPlaceholderText(_translate("chat_prompt", "Chat input")) + self.comboBox.setItemText(0, _translate("chat_prompt", "Chat")) + self.comboBox.setItemText(1, _translate("chat_prompt", "Narrate")) + self.comboBox.setItemText(2, _translate("chat_prompt", "Generate Image")) + self.comboBox.setItemText(3, _translate("chat_prompt", "Summarize")) + self.comboBox.setItemText(4, _translate("chat_prompt", "Translate")) + self.send_button.setText(_translate("chat_prompt", "Send")) + self.clear_conversatiion_button.setText(_translate("chat_prompt", "New Conversation")) diff --git a/src/airunner/widgets/llm/templates/llm_preferences.ui b/src/airunner/widgets/llm/templates/llm_preferences.ui new file mode 100644 index 000000000..a1f768413 --- /dev/null +++ b/src/airunner/widgets/llm/templates/llm_preferences.ui @@ -0,0 +1,275 @@ + + + llm_preferences_widget + + + + 0 + 0 + 1371 + 957 + + + + Form + + + + + + Qt::Vertical + + + + Prefix + + + + + + + + + + Suffix + + + + + + + + + + + + + Bot details + + + + + + + + Name + + + + + + + ChatAI + + + 6 + + + + + + + + + + + Personality + + + + + + + + Nice + + + + + Mean + + + + + Weird + + + + + Insane + + + + + Random + + + + + + + + + + + + + User name + + + + + + User + + + + + + + + + + Generate Characters + + + + + + + prefix + suffix + botname + username + + + + + generate_characters_button + clicked() + llm_preferences_widget + action_button_clicked_generate_characters() + + + 362 + 962 + + + 477 + 0 + + + + + username + textEdited(QString) + llm_preferences_widget + username_text_changed(QString) + + + 226 + 919 + + + 118 + 0 + + + + + botname + textEdited(QString) + llm_preferences_widget + botname_text_changed(QString) + + + 298 + 843 + + + 80 + 0 + + + + + personality_type + textHighlighted(QString) + llm_preferences_widget + personality_type_changed(QString) + + + 597 + 843 + + + 412 + 0 + + + + + suffix + textChanged() + llm_preferences_widget + suffix_text_changed() + + + 182 + 669 + + + 206 + 5 + + + + + prefix + textChanged() + llm_preferences_widget + prefix_text_changed() + + + 131 + 108 + + + 48 + 6 + + + + + + prompt_text_changed(QString) + botname_text_changed(QString) + username_text_changed(QString) + model_version_changed(QString) + random_seed_toggled(bool) + response_text_changed(QString) + seed_changed(QString) + suffix_text_changed() + prefix_text_changed() + action_button_clicked_send() + action_button_clicked_generate_characters() + action_button_clicked_clear_conversation() + message_type_text_changed(QString) + personality_type_changed(QString) + toggled_4bit(bool) + toggled_8bit(bool) + toggled_16bit(bool) + toggled_32bit(bool) + do_sample_toggled(bool) + early_stopping_toggled(bool) + model_text_changed(QString) + reset_settings_to_default_clicked() + use_gpu_toggled(bool) + override_parameters_toggled(bool) + toggled_2bit(bool) + toggle_leave_model_in_vram(bool) + toggle_unload_model(bool) + toggle_move_model_to_cpu(bool) + prompt_template_text_changed(QString) + + diff --git a/src/airunner/widgets/llm/templates/llm_preferences_ui.py b/src/airunner/widgets/llm/templates/llm_preferences_ui.py new file mode 100644 index 000000000..7b1ca09a6 --- /dev/null +++ b/src/airunner/widgets/llm/templates/llm_preferences_ui.py @@ -0,0 +1,105 @@ +# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/llm_preferences.ui' +# +# Created by: PyQt6 UI code generator 6.4.2 +# +# WARNING: Any manual changes made to this file will be lost when pyuic6 is +# run again. Do not edit this file unless you know what you are doing. + + +from PyQt6 import QtCore, QtGui, QtWidgets + + +class Ui_llm_preferences_widget(object): + def setupUi(self, llm_preferences_widget): + llm_preferences_widget.setObjectName("llm_preferences_widget") + llm_preferences_widget.resize(1371, 957) + self.gridLayout = QtWidgets.QGridLayout(llm_preferences_widget) + self.gridLayout.setObjectName("gridLayout") + self.splitter = QtWidgets.QSplitter(parent=llm_preferences_widget) + self.splitter.setOrientation(QtCore.Qt.Orientation.Vertical) + self.splitter.setObjectName("splitter") + self.groupBox = QtWidgets.QGroupBox(parent=self.splitter) + self.groupBox.setObjectName("groupBox") + self.gridLayout_3 = QtWidgets.QGridLayout(self.groupBox) + self.gridLayout_3.setObjectName("gridLayout_3") + self.prefix = QtWidgets.QPlainTextEdit(parent=self.groupBox) + self.prefix.setObjectName("prefix") + self.gridLayout_3.addWidget(self.prefix, 0, 0, 1, 1) + self.groupBox_2 = QtWidgets.QGroupBox(parent=self.splitter) + self.groupBox_2.setObjectName("groupBox_2") + self.gridLayout_4 = QtWidgets.QGridLayout(self.groupBox_2) + self.gridLayout_4.setObjectName("gridLayout_4") + self.suffix = QtWidgets.QPlainTextEdit(parent=self.groupBox_2) + self.suffix.setObjectName("suffix") + self.gridLayout_4.addWidget(self.suffix, 0, 0, 1, 1) + self.gridLayout.addWidget(self.splitter, 0, 0, 1, 1) + self.groupBox_3 = QtWidgets.QGroupBox(parent=llm_preferences_widget) + self.groupBox_3.setObjectName("groupBox_3") + self.horizontalLayout_5 = QtWidgets.QHBoxLayout(self.groupBox_3) + self.horizontalLayout_5.setObjectName("horizontalLayout_5") + self.verticalLayout_2 = QtWidgets.QVBoxLayout() + self.verticalLayout_2.setObjectName("verticalLayout_2") + self.label = QtWidgets.QLabel(parent=self.groupBox_3) + self.label.setObjectName("label") + self.verticalLayout_2.addWidget(self.label) + self.botname = QtWidgets.QLineEdit(parent=self.groupBox_3) + self.botname.setCursorPosition(6) + self.botname.setObjectName("botname") + self.verticalLayout_2.addWidget(self.botname) + self.horizontalLayout_5.addLayout(self.verticalLayout_2) + self.verticalLayout_3 = QtWidgets.QVBoxLayout() + self.verticalLayout_3.setObjectName("verticalLayout_3") + self.label_2 = QtWidgets.QLabel(parent=self.groupBox_3) + self.label_2.setObjectName("label_2") + self.verticalLayout_3.addWidget(self.label_2) + self.personality_type = QtWidgets.QComboBox(parent=self.groupBox_3) + self.personality_type.setObjectName("personality_type") + self.personality_type.addItem("") + self.personality_type.addItem("") + self.personality_type.addItem("") + self.personality_type.addItem("") + self.personality_type.addItem("") + self.verticalLayout_3.addWidget(self.personality_type) + self.horizontalLayout_5.addLayout(self.verticalLayout_3) + self.gridLayout.addWidget(self.groupBox_3, 1, 0, 1, 1) + self.groupBox_4 = QtWidgets.QGroupBox(parent=llm_preferences_widget) + self.groupBox_4.setObjectName("groupBox_4") + self.gridLayout_13 = QtWidgets.QGridLayout(self.groupBox_4) + self.gridLayout_13.setObjectName("gridLayout_13") + self.username = QtWidgets.QLineEdit(parent=self.groupBox_4) + self.username.setObjectName("username") + self.gridLayout_13.addWidget(self.username, 0, 0, 1, 1) + self.gridLayout.addWidget(self.groupBox_4, 2, 0, 1, 1) + self.generate_characters_button = QtWidgets.QPushButton(parent=llm_preferences_widget) + self.generate_characters_button.setObjectName("generate_characters_button") + self.gridLayout.addWidget(self.generate_characters_button, 3, 0, 1, 1) + + self.retranslateUi(llm_preferences_widget) + self.generate_characters_button.clicked.connect(llm_preferences_widget.action_button_clicked_generate_characters) # type: ignore + self.username.textEdited['QString'].connect(llm_preferences_widget.username_text_changed) # type: ignore + self.botname.textEdited['QString'].connect(llm_preferences_widget.botname_text_changed) # type: ignore + self.personality_type.textHighlighted['QString'].connect(llm_preferences_widget.personality_type_changed) # type: ignore + self.suffix.textChanged.connect(llm_preferences_widget.suffix_text_changed) # type: ignore + self.prefix.textChanged.connect(llm_preferences_widget.prefix_text_changed) # type: ignore + QtCore.QMetaObject.connectSlotsByName(llm_preferences_widget) + llm_preferences_widget.setTabOrder(self.prefix, self.suffix) + llm_preferences_widget.setTabOrder(self.suffix, self.botname) + llm_preferences_widget.setTabOrder(self.botname, self.username) + + def retranslateUi(self, llm_preferences_widget): + _translate = QtCore.QCoreApplication.translate + llm_preferences_widget.setWindowTitle(_translate("llm_preferences_widget", "Form")) + self.groupBox.setTitle(_translate("llm_preferences_widget", "Prefix")) + self.groupBox_2.setTitle(_translate("llm_preferences_widget", "Suffix")) + self.groupBox_3.setTitle(_translate("llm_preferences_widget", "Bot details")) + self.label.setText(_translate("llm_preferences_widget", "Name")) + self.botname.setText(_translate("llm_preferences_widget", "ChatAI")) + self.label_2.setText(_translate("llm_preferences_widget", "Personality")) + self.personality_type.setItemText(0, _translate("llm_preferences_widget", "Nice")) + self.personality_type.setItemText(1, _translate("llm_preferences_widget", "Mean")) + self.personality_type.setItemText(2, _translate("llm_preferences_widget", "Weird")) + self.personality_type.setItemText(3, _translate("llm_preferences_widget", "Insane")) + self.personality_type.setItemText(4, _translate("llm_preferences_widget", "Random")) + self.groupBox_4.setTitle(_translate("llm_preferences_widget", "User name")) + self.username.setText(_translate("llm_preferences_widget", "User")) + self.generate_characters_button.setText(_translate("llm_preferences_widget", "Generate Characters")) diff --git a/src/airunner/widgets/llm/templates/llm_settings.ui b/src/airunner/widgets/llm/templates/llm_settings.ui new file mode 100644 index 000000000..e7129ede0 --- /dev/null +++ b/src/airunner/widgets/llm/templates/llm_settings.ui @@ -0,0 +1,1088 @@ + + + llm_settings_widget + + + + 0 + 0 + 1262 + 1059 + + + + Form + + + + + + Model Type + + + + + + + + + + + + Model Version + + + + + + -1 + + + + + + + + + + Prompt Template + + + + + + + + + + + + DType + + + + + + + + 2-bit + + + + + + + 4-bit + + + + + + + 8-bit + + + + + + + 16-bit + + + + + + + 32-bit + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + + + 9 + + + + Description + + + + + + + Use GPU + + + + + + + + + + Override Prameters + + + true + + + true + + + + + + Reset Settings to Default + + + + + + + Seed + + + + + + + + + Random seed + + + + + + + + + + + + Early stopping + + + + + + + Do sample + + + + + + + + + + + Repetition penalty + + + + + + 1 + + + 10000 + + + 0.010000000000000 + + + 100.000000000000000 + + + true + + + llm_generator_setting.repetition_penalty + + + 0 + + + 1 + + + 1.000000000000000 + + + 10.000000000000000 + + + handle_value_change + + + + + + + + + + Min length + + + + + + 1 + + + 2556 + + + 1.000000000000000 + + + 2556.000000000000000 + + + false + + + llm_generator_setting.min_length + + + 1 + + + 2556 + + + 1 + + + 2556 + + + handle_value_change + + + + + + + + + + + + + + Top P + + + + + + 1 + + + 100 + + + 0.000000000000000 + + + 1.000000000000000 + + + true + + + llm_generator_setting.top_p + + + 1 + + + 10 + + + 0.010000000000000 + + + 0.100000000000000 + + + handle_value_change + + + + + + + + + + Max length + + + + + + 1 + + + 2556 + + + 1.000000000000000 + + + 2556.000000000000000 + + + false + + + llm_generator_setting.max_length + + + 1 + + + 2556 + + + 1 + + + 2556 + + + handle_value_change + + + + + + + + + + + + + + Leave in VRAM + + + + + + + Move to CPU + + + + + + + Unload model + + + + + + + + + + true + + + + Model management + + + + + + + + + No repeat ngram size + + + + + + 0 + + + 20 + + + 0.000000000000000 + + + 20.000000000000000 + + + false + + + llm_generator_setting.ngram_size + + + 1 + + + 1 + + + 1.000000000000000 + + + 1.000000000000000 + + + handle_value_change + + + + + + + + + + Temperature + + + + + + 1 + + + 20000 + + + 0.000100000000000 + + + 2.000000000000000 + + + true + + + llm_generator_setting.temperature + + + 1 + + + 10 + + + 0.010000000000000 + + + 0.100000000000000 + + + handle_value_change + + + + + + + + + + + + + + Length penalty + + + + + + -100 + + + 100 + + + 0.000000000000000 + + + 1.000000000000000 + + + true + + + llm_generator_setting.length_penalty + + + 1 + + + 10 + + + 0.010000000000000 + + + 0.100000000000000 + + + handle_value_change + + + + + + + + + + Num beams + + + + + + 1 + + + 100 + + + 0.000000000000000 + + + 100.000000000000000 + + + false + + + llm_generator_setting.num_beams + + + 1 + + + 10 + + + 0.010000000000000 + + + 0.100000000000000 + + + handle_value_change + + + + + + + + + + + + Qt::Horizontal + + + + + + + + + Sequences to generate + + + + + + 1 + + + 100 + + + 0.000000000000000 + + + 100.000000000000000 + + + false + + + llm_generator_setting.sequences + + + 1 + + + 10 + + + 0.010000000000000 + + + 0.100000000000000 + + + handle_value_change + + + + + + + + + + Top k + + + + + + 0 + + + 256 + + + 0.000000000000000 + + + 256.000000000000000 + + + false + + + llm_generator_setting.top_k + + + 1 + + + 10 + + + 1 + + + 10 + + + handle_value_change + + + + + + + + + + + + + 9 + + + + How to treat model when not in use + + + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + SliderWidget + QWidget +
airunner/widgets/slider/slider_widget
+ 1 +
+
+ + model + model_version + use_gpu_checkbox + radio_button_4bit + radio_button_8bit + radio_button_16bit + radio_button_32bit + seed + random_seed + + + + + radio_button_2bit + toggled(bool) + llm_settings_widget + toggled_2bit(bool) + + + 78 + 318 + + + 0 + 324 + + + + + radio_button_16bit + toggled(bool) + llm_settings_widget + toggled_16bit(bool) + + + 276 + 318 + + + 3 + 136 + + + + + use_gpu_checkbox + toggled(bool) + llm_settings_widget + use_gpu_toggled(bool) + + + 462 + 288 + + + 34 + 0 + + + + + radio_button_4bit + toggled(bool) + llm_settings_widget + toggled_4bit(bool) + + + 141 + 318 + + + 3 + 43 + + + + + radio_button_8bit + toggled(bool) + llm_settings_widget + toggled_8bit(bool) + + + 204 + 318 + + + 2 + 88 + + + + + radio_button_32bit + toggled(bool) + llm_settings_widget + toggled_32bit(bool) + + + 348 + 318 + + + 0 + 179 + + + + + override_parameters + toggled(bool) + llm_settings_widget + override_parameters_toggled(bool) + + + 128 + 398 + + + 120 + 0 + + + + + seed + textEdited(QString) + llm_settings_widget + seed_changed(QString) + + + 301 + 756 + + + 2 + 382 + + + + + early_stopping + toggled(bool) + llm_settings_widget + early_stopping_toggled(bool) + + + 125 + 798 + + + 1 + 501 + + + + + random_seed + toggled(bool) + llm_settings_widget + random_seed_toggled(bool) + + + 1228 + 755 + + + 3 + 470 + + + + + pushButton + clicked() + llm_settings_widget + reset_settings_to_default_clicked() + + + 311 + 913 + + + 437 + 0 + + + + + move_to_cpu + toggled(bool) + llm_settings_widget + toggle_move_model_to_cpu(bool) + + + 831 + 881 + + + 0 + 704 + + + + + do_sample + toggled(bool) + llm_settings_widget + do_sample_toggled(bool) + + + 1231 + 798 + + + 1 + 427 + + + + + leave_in_vram + toggled(bool) + llm_settings_widget + toggle_leave_model_in_vram(bool) + + + 142 + 881 + + + 0 + 716 + + + + + unload_model + toggled(bool) + llm_settings_widget + toggle_unload_model(bool) + + + 1239 + 881 + + + 0 + 685 + + + + + prompt_template + currentTextChanged(QString) + llm_settings_widget + prompt_template_text_changed(QString) + + + 90 + 215 + + + 79 + -16 + + + + + model + currentTextChanged(QString) + llm_settings_widget + model_text_changed(QString) + + + 184 + 65 + + + 117 + 0 + + + + + model_version + currentTextChanged(QString) + llm_settings_widget + model_version_changed(QString) + + + 205 + 140 + + + 272 + 0 + + + + + + prompt_text_changed(QString) + botname_text_changed(QString) + username_text_changed(QString) + model_version_changed(QString) + random_seed_toggled(bool) + response_text_changed(QString) + seed_changed(QString) + suffix_text_changed() + prefix_text_changed() + action_button_clicked_send() + action_button_clicked_generate_characters() + action_button_clicked_clear_conversation() + message_type_text_changed(QString) + personality_type_changed(QString) + toggled_4bit(bool) + toggled_8bit(bool) + toggled_16bit(bool) + toggled_32bit(bool) + do_sample_toggled(bool) + early_stopping_toggled(bool) + model_text_changed(QString) + reset_settings_to_default_clicked() + use_gpu_toggled(bool) + override_parameters_toggled(bool) + toggled_2bit(bool) + toggle_leave_model_in_vram(bool) + toggle_unload_model(bool) + toggle_move_model_to_cpu(bool) + prompt_template_text_changed(QString) + +
diff --git a/src/airunner/widgets/llm/templates/llm_settings_ui.py b/src/airunner/widgets/llm/templates/llm_settings_ui.py new file mode 100644 index 000000000..bfb93997b --- /dev/null +++ b/src/airunner/widgets/llm/templates/llm_settings_ui.py @@ -0,0 +1,408 @@ +# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/llm_settings.ui' +# +# Created by: PyQt6 UI code generator 6.4.2 +# +# WARNING: Any manual changes made to this file will be lost when pyuic6 is +# run again. Do not edit this file unless you know what you are doing. + + +from PyQt6 import QtCore, QtGui, QtWidgets + + +class Ui_llm_settings_widget(object): + def setupUi(self, llm_settings_widget): + llm_settings_widget.setObjectName("llm_settings_widget") + llm_settings_widget.resize(1262, 1059) + self.gridLayout = QtWidgets.QGridLayout(llm_settings_widget) + self.gridLayout.setObjectName("gridLayout") + self.groupBox_7 = QtWidgets.QGroupBox(parent=llm_settings_widget) + self.groupBox_7.setObjectName("groupBox_7") + self.gridLayout_10 = QtWidgets.QGridLayout(self.groupBox_7) + self.gridLayout_10.setObjectName("gridLayout_10") + self.model = QtWidgets.QComboBox(parent=self.groupBox_7) + self.model.setObjectName("model") + self.gridLayout_10.addWidget(self.model, 0, 0, 1, 1) + self.gridLayout.addWidget(self.groupBox_7, 0, 0, 1, 1) + self.groupBox_8 = QtWidgets.QGroupBox(parent=llm_settings_widget) + self.groupBox_8.setObjectName("groupBox_8") + self.gridLayout_11 = QtWidgets.QGridLayout(self.groupBox_8) + self.gridLayout_11.setObjectName("gridLayout_11") + self.model_version = QtWidgets.QComboBox(parent=self.groupBox_8) + self.model_version.setObjectName("model_version") + self.gridLayout_11.addWidget(self.model_version, 0, 0, 1, 1) + self.gridLayout.addWidget(self.groupBox_8, 1, 0, 1, 1) + self.groupBox_14 = QtWidgets.QGroupBox(parent=llm_settings_widget) + self.groupBox_14.setObjectName("groupBox_14") + self.gridLayout_17 = QtWidgets.QGridLayout(self.groupBox_14) + self.gridLayout_17.setObjectName("gridLayout_17") + self.prompt_template = QtWidgets.QComboBox(parent=self.groupBox_14) + self.prompt_template.setObjectName("prompt_template") + self.gridLayout_17.addWidget(self.prompt_template, 0, 0, 1, 1) + self.gridLayout.addWidget(self.groupBox_14, 2, 0, 1, 1) + self.groupBox_6 = QtWidgets.QGroupBox(parent=llm_settings_widget) + self.groupBox_6.setObjectName("groupBox_6") + self.gridLayout_9 = QtWidgets.QGridLayout(self.groupBox_6) + self.gridLayout_9.setObjectName("gridLayout_9") + self.horizontalLayout_3 = QtWidgets.QHBoxLayout() + self.horizontalLayout_3.setObjectName("horizontalLayout_3") + self.radio_button_2bit = QtWidgets.QRadioButton(parent=self.groupBox_6) + self.radio_button_2bit.setObjectName("radio_button_2bit") + self.horizontalLayout_3.addWidget(self.radio_button_2bit) + self.radio_button_4bit = QtWidgets.QRadioButton(parent=self.groupBox_6) + self.radio_button_4bit.setObjectName("radio_button_4bit") + self.horizontalLayout_3.addWidget(self.radio_button_4bit) + self.radio_button_8bit = QtWidgets.QRadioButton(parent=self.groupBox_6) + self.radio_button_8bit.setObjectName("radio_button_8bit") + self.horizontalLayout_3.addWidget(self.radio_button_8bit) + self.radio_button_16bit = QtWidgets.QRadioButton(parent=self.groupBox_6) + self.radio_button_16bit.setObjectName("radio_button_16bit") + self.horizontalLayout_3.addWidget(self.radio_button_16bit) + self.radio_button_32bit = QtWidgets.QRadioButton(parent=self.groupBox_6) + self.radio_button_32bit.setObjectName("radio_button_32bit") + self.horizontalLayout_3.addWidget(self.radio_button_32bit) + spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Minimum) + self.horizontalLayout_3.addItem(spacerItem) + self.gridLayout_9.addLayout(self.horizontalLayout_3, 1, 0, 1, 1) + self.dtype_description = QtWidgets.QLabel(parent=self.groupBox_6) + font = QtGui.QFont() + font.setPointSize(9) + self.dtype_description.setFont(font) + self.dtype_description.setObjectName("dtype_description") + self.gridLayout_9.addWidget(self.dtype_description, 2, 0, 1, 1) + self.use_gpu_checkbox = QtWidgets.QCheckBox(parent=self.groupBox_6) + self.use_gpu_checkbox.setObjectName("use_gpu_checkbox") + self.gridLayout_9.addWidget(self.use_gpu_checkbox, 0, 0, 1, 1) + self.gridLayout.addWidget(self.groupBox_6, 3, 0, 1, 1) + self.override_parameters = QtWidgets.QGroupBox(parent=llm_settings_widget) + self.override_parameters.setCheckable(True) + self.override_parameters.setChecked(True) + self.override_parameters.setObjectName("override_parameters") + self.gridLayout_12 = QtWidgets.QGridLayout(self.override_parameters) + self.gridLayout_12.setObjectName("gridLayout_12") + self.pushButton = QtWidgets.QPushButton(parent=self.override_parameters) + self.pushButton.setObjectName("pushButton") + self.gridLayout_12.addWidget(self.pushButton, 14, 0, 1, 1) + self.groupBox_19 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_19.setObjectName("groupBox_19") + self.horizontalLayout_4 = QtWidgets.QHBoxLayout(self.groupBox_19) + self.horizontalLayout_4.setObjectName("horizontalLayout_4") + self.seed = QtWidgets.QLineEdit(parent=self.groupBox_19) + self.seed.setObjectName("seed") + self.horizontalLayout_4.addWidget(self.seed) + self.random_seed = QtWidgets.QCheckBox(parent=self.groupBox_19) + self.random_seed.setObjectName("random_seed") + self.horizontalLayout_4.addWidget(self.random_seed) + self.gridLayout_12.addWidget(self.groupBox_19, 5, 0, 1, 1) + self.horizontalLayout_6 = QtWidgets.QHBoxLayout() + self.horizontalLayout_6.setObjectName("horizontalLayout_6") + self.early_stopping = QtWidgets.QCheckBox(parent=self.override_parameters) + self.early_stopping.setObjectName("early_stopping") + self.horizontalLayout_6.addWidget(self.early_stopping) + self.do_sample = QtWidgets.QCheckBox(parent=self.override_parameters) + self.do_sample.setObjectName("do_sample") + self.horizontalLayout_6.addWidget(self.do_sample) + self.gridLayout_12.addLayout(self.horizontalLayout_6, 8, 0, 1, 1) + self.horizontalLayout_7 = QtWidgets.QHBoxLayout() + self.horizontalLayout_7.setObjectName("horizontalLayout_7") + self.groupBox_11 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_11.setObjectName("groupBox_11") + self.gridLayout_19 = QtWidgets.QGridLayout(self.groupBox_11) + self.gridLayout_19.setObjectName("gridLayout_19") + self.repetition_penalty = SliderWidget(parent=self.groupBox_11) + self.repetition_penalty.setProperty("slider_minimum", 1) + self.repetition_penalty.setProperty("slider_maximum", 10000) + self.repetition_penalty.setProperty("spinbox_minimum", 0.01) + self.repetition_penalty.setProperty("spinbox_maximum", 100.0) + self.repetition_penalty.setProperty("display_as_float", True) + self.repetition_penalty.setProperty("slider_single_step", 0) + self.repetition_penalty.setProperty("slider_page_step", 1) + self.repetition_penalty.setProperty("spinbox_single_step", 1.0) + self.repetition_penalty.setProperty("spinbox_page_step", 10.0) + self.repetition_penalty.setObjectName("repetition_penalty") + self.gridLayout_19.addWidget(self.repetition_penalty, 0, 0, 1, 1) + self.horizontalLayout_7.addWidget(self.groupBox_11) + self.groupBox_16 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_16.setObjectName("groupBox_16") + self.gridLayout_20 = QtWidgets.QGridLayout(self.groupBox_16) + self.gridLayout_20.setObjectName("gridLayout_20") + self.min_length = SliderWidget(parent=self.groupBox_16) + self.min_length.setProperty("slider_minimum", 1) + self.min_length.setProperty("slider_maximum", 2556) + self.min_length.setProperty("spinbox_minimum", 1.0) + self.min_length.setProperty("spinbox_maximum", 2556.0) + self.min_length.setProperty("display_as_float", False) + self.min_length.setProperty("slider_single_step", 1) + self.min_length.setProperty("slider_page_step", 2556) + self.min_length.setProperty("spinbox_single_step", 1) + self.min_length.setProperty("spinbox_page_step", 2556) + self.min_length.setObjectName("min_length") + self.gridLayout_20.addWidget(self.min_length, 0, 0, 1, 1) + self.horizontalLayout_7.addWidget(self.groupBox_16) + self.gridLayout_12.addLayout(self.horizontalLayout_7, 1, 0, 1, 1) + self.horizontalLayout_9 = QtWidgets.QHBoxLayout() + self.horizontalLayout_9.setObjectName("horizontalLayout_9") + self.groupBox_20 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_20.setObjectName("groupBox_20") + self.gridLayout_24 = QtWidgets.QGridLayout(self.groupBox_20) + self.gridLayout_24.setObjectName("gridLayout_24") + self.top_p = SliderWidget(parent=self.groupBox_20) + self.top_p.setProperty("slider_minimum", 1) + self.top_p.setProperty("slider_maximum", 100) + self.top_p.setProperty("spinbox_minimum", 0.0) + self.top_p.setProperty("spinbox_maximum", 1.0) + self.top_p.setProperty("display_as_float", True) + self.top_p.setProperty("slider_single_step", 1) + self.top_p.setProperty("slider_page_step", 10) + self.top_p.setProperty("spinbox_single_step", 0.01) + self.top_p.setProperty("spinbox_page_step", 0.1) + self.top_p.setObjectName("top_p") + self.gridLayout_24.addWidget(self.top_p, 0, 0, 1, 1) + self.horizontalLayout_9.addWidget(self.groupBox_20) + self.groupBox_21 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_21.setObjectName("groupBox_21") + self.gridLayout_25 = QtWidgets.QGridLayout(self.groupBox_21) + self.gridLayout_25.setObjectName("gridLayout_25") + self.max_length = SliderWidget(parent=self.groupBox_21) + self.max_length.setProperty("slider_minimum", 1) + self.max_length.setProperty("slider_maximum", 2556) + self.max_length.setProperty("spinbox_minimum", 1.0) + self.max_length.setProperty("spinbox_maximum", 2556.0) + self.max_length.setProperty("display_as_float", False) + self.max_length.setProperty("slider_single_step", 1) + self.max_length.setProperty("slider_page_step", 2556) + self.max_length.setProperty("spinbox_single_step", 1) + self.max_length.setProperty("spinbox_page_step", 2556) + self.max_length.setObjectName("max_length") + self.gridLayout_25.addWidget(self.max_length, 0, 0, 1, 1) + self.horizontalLayout_9.addWidget(self.groupBox_21) + self.gridLayout_12.addLayout(self.horizontalLayout_9, 0, 0, 1, 1) + self.horizontalLayout_12 = QtWidgets.QHBoxLayout() + self.horizontalLayout_12.setObjectName("horizontalLayout_12") + self.leave_in_vram = QtWidgets.QRadioButton(parent=self.override_parameters) + self.leave_in_vram.setObjectName("leave_in_vram") + self.horizontalLayout_12.addWidget(self.leave_in_vram) + self.move_to_cpu = QtWidgets.QRadioButton(parent=self.override_parameters) + self.move_to_cpu.setObjectName("move_to_cpu") + self.horizontalLayout_12.addWidget(self.move_to_cpu) + self.unload_model = QtWidgets.QRadioButton(parent=self.override_parameters) + self.unload_model.setObjectName("unload_model") + self.horizontalLayout_12.addWidget(self.unload_model) + self.gridLayout_12.addLayout(self.horizontalLayout_12, 12, 0, 1, 1) + self.label_3 = QtWidgets.QLabel(parent=self.override_parameters) + font = QtGui.QFont() + font.setBold(True) + self.label_3.setFont(font) + self.label_3.setObjectName("label_3") + self.gridLayout_12.addWidget(self.label_3, 10, 0, 1, 1) + self.horizontalLayout_8 = QtWidgets.QHBoxLayout() + self.horizontalLayout_8.setObjectName("horizontalLayout_8") + self.groupBox_17 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_17.setObjectName("groupBox_17") + self.gridLayout_21 = QtWidgets.QGridLayout(self.groupBox_17) + self.gridLayout_21.setObjectName("gridLayout_21") + self.ngram_size = SliderWidget(parent=self.groupBox_17) + self.ngram_size.setProperty("slider_minimum", 0) + self.ngram_size.setProperty("slider_maximum", 20) + self.ngram_size.setProperty("spinbox_minimum", 0.0) + self.ngram_size.setProperty("spinbox_maximum", 20.0) + self.ngram_size.setProperty("display_as_float", False) + self.ngram_size.setProperty("slider_single_step", 1) + self.ngram_size.setProperty("slider_page_step", 1) + self.ngram_size.setProperty("spinbox_single_step", 1.0) + self.ngram_size.setProperty("spinbox_page_step", 1.0) + self.ngram_size.setObjectName("ngram_size") + self.gridLayout_21.addWidget(self.ngram_size, 0, 0, 1, 1) + self.horizontalLayout_8.addWidget(self.groupBox_17) + self.groupBox_18 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_18.setObjectName("groupBox_18") + self.gridLayout_22 = QtWidgets.QGridLayout(self.groupBox_18) + self.gridLayout_22.setObjectName("gridLayout_22") + self.temperature = SliderWidget(parent=self.groupBox_18) + self.temperature.setProperty("slider_minimum", 1) + self.temperature.setProperty("slider_maximum", 20000) + self.temperature.setProperty("spinbox_minimum", 0.0001) + self.temperature.setProperty("spinbox_maximum", 2.0) + self.temperature.setProperty("display_as_float", True) + self.temperature.setProperty("slider_single_step", 1) + self.temperature.setProperty("slider_page_step", 10) + self.temperature.setProperty("spinbox_single_step", 0.01) + self.temperature.setProperty("spinbox_page_step", 0.1) + self.temperature.setObjectName("temperature") + self.gridLayout_22.addWidget(self.temperature, 0, 0, 1, 1) + self.horizontalLayout_8.addWidget(self.groupBox_18) + self.gridLayout_12.addLayout(self.horizontalLayout_8, 3, 0, 1, 1) + self.horizontalLayout_11 = QtWidgets.QHBoxLayout() + self.horizontalLayout_11.setObjectName("horizontalLayout_11") + self.groupBox_24 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_24.setObjectName("groupBox_24") + self.gridLayout_28 = QtWidgets.QGridLayout(self.groupBox_24) + self.gridLayout_28.setObjectName("gridLayout_28") + self.length_penalty = SliderWidget(parent=self.groupBox_24) + self.length_penalty.setProperty("slider_minimum", -100) + self.length_penalty.setProperty("slider_maximum", 100) + self.length_penalty.setProperty("spinbox_minimum", 0.0) + self.length_penalty.setProperty("spinbox_maximum", 1.0) + self.length_penalty.setProperty("display_as_float", True) + self.length_penalty.setProperty("slider_single_step", 1) + self.length_penalty.setProperty("slider_page_step", 10) + self.length_penalty.setProperty("spinbox_single_step", 0.01) + self.length_penalty.setProperty("spinbox_page_step", 0.1) + self.length_penalty.setObjectName("length_penalty") + self.gridLayout_28.addWidget(self.length_penalty, 0, 0, 1, 1) + self.horizontalLayout_11.addWidget(self.groupBox_24) + self.groupBox_25 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_25.setObjectName("groupBox_25") + self.gridLayout_29 = QtWidgets.QGridLayout(self.groupBox_25) + self.gridLayout_29.setObjectName("gridLayout_29") + self.num_beams = SliderWidget(parent=self.groupBox_25) + self.num_beams.setProperty("slider_minimum", 1) + self.num_beams.setProperty("slider_maximum", 100) + self.num_beams.setProperty("spinbox_minimum", 0.0) + self.num_beams.setProperty("spinbox_maximum", 100.0) + self.num_beams.setProperty("display_as_float", False) + self.num_beams.setProperty("slider_single_step", 1) + self.num_beams.setProperty("slider_page_step", 10) + self.num_beams.setProperty("spinbox_single_step", 0.01) + self.num_beams.setProperty("spinbox_page_step", 0.1) + self.num_beams.setObjectName("num_beams") + self.gridLayout_29.addWidget(self.num_beams, 0, 0, 1, 1) + self.horizontalLayout_11.addWidget(self.groupBox_25) + self.gridLayout_12.addLayout(self.horizontalLayout_11, 2, 0, 1, 1) + self.line = QtWidgets.QFrame(parent=self.override_parameters) + self.line.setFrameShape(QtWidgets.QFrame.Shape.HLine) + self.line.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken) + self.line.setObjectName("line") + self.gridLayout_12.addWidget(self.line, 9, 0, 1, 1) + self.horizontalLayout_10 = QtWidgets.QHBoxLayout() + self.horizontalLayout_10.setObjectName("horizontalLayout_10") + self.groupBox_22 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_22.setObjectName("groupBox_22") + self.gridLayout_26 = QtWidgets.QGridLayout(self.groupBox_22) + self.gridLayout_26.setObjectName("gridLayout_26") + self.sequences = SliderWidget(parent=self.groupBox_22) + self.sequences.setProperty("slider_minimum", 1) + self.sequences.setProperty("slider_maximum", 100) + self.sequences.setProperty("spinbox_minimum", 0.0) + self.sequences.setProperty("spinbox_maximum", 100.0) + self.sequences.setProperty("display_as_float", False) + self.sequences.setProperty("slider_single_step", 1) + self.sequences.setProperty("slider_page_step", 10) + self.sequences.setProperty("spinbox_single_step", 0.01) + self.sequences.setProperty("spinbox_page_step", 0.1) + self.sequences.setObjectName("sequences") + self.gridLayout_26.addWidget(self.sequences, 0, 0, 1, 1) + self.horizontalLayout_10.addWidget(self.groupBox_22) + self.groupBox_23 = QtWidgets.QGroupBox(parent=self.override_parameters) + self.groupBox_23.setObjectName("groupBox_23") + self.gridLayout_27 = QtWidgets.QGridLayout(self.groupBox_23) + self.gridLayout_27.setObjectName("gridLayout_27") + self.top_k = SliderWidget(parent=self.groupBox_23) + self.top_k.setProperty("slider_minimum", 0) + self.top_k.setProperty("slider_maximum", 256) + self.top_k.setProperty("spinbox_minimum", 0.0) + self.top_k.setProperty("spinbox_maximum", 256.0) + self.top_k.setProperty("display_as_float", False) + self.top_k.setProperty("slider_single_step", 1) + self.top_k.setProperty("slider_page_step", 10) + self.top_k.setProperty("spinbox_single_step", 1) + self.top_k.setProperty("spinbox_page_step", 10) + self.top_k.setObjectName("top_k") + self.gridLayout_27.addWidget(self.top_k, 0, 0, 1, 1) + self.horizontalLayout_10.addWidget(self.groupBox_23) + self.gridLayout_12.addLayout(self.horizontalLayout_10, 4, 0, 1, 1) + self.label_4 = QtWidgets.QLabel(parent=self.override_parameters) + font = QtGui.QFont() + font.setPointSize(9) + self.label_4.setFont(font) + self.label_4.setObjectName("label_4") + self.gridLayout_12.addWidget(self.label_4, 11, 0, 1, 1) + self.gridLayout.addWidget(self.override_parameters, 4, 0, 1, 1) + spacerItem1 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding) + self.gridLayout.addItem(spacerItem1, 5, 0, 1, 1) + + self.retranslateUi(llm_settings_widget) + self.model_version.setCurrentIndex(-1) + self.radio_button_2bit.toggled['bool'].connect(llm_settings_widget.toggled_2bit) # type: ignore + self.radio_button_16bit.toggled['bool'].connect(llm_settings_widget.toggled_16bit) # type: ignore + self.use_gpu_checkbox.toggled['bool'].connect(llm_settings_widget.use_gpu_toggled) # type: ignore + self.radio_button_4bit.toggled['bool'].connect(llm_settings_widget.toggled_4bit) # type: ignore + self.radio_button_8bit.toggled['bool'].connect(llm_settings_widget.toggled_8bit) # type: ignore + self.radio_button_32bit.toggled['bool'].connect(llm_settings_widget.toggled_32bit) # type: ignore + self.override_parameters.toggled['bool'].connect(llm_settings_widget.override_parameters_toggled) # type: ignore + self.seed.textEdited['QString'].connect(llm_settings_widget.seed_changed) # type: ignore + self.early_stopping.toggled['bool'].connect(llm_settings_widget.early_stopping_toggled) # type: ignore + self.random_seed.toggled['bool'].connect(llm_settings_widget.random_seed_toggled) # type: ignore + self.pushButton.clicked.connect(llm_settings_widget.reset_settings_to_default_clicked) # type: ignore + self.move_to_cpu.toggled['bool'].connect(llm_settings_widget.toggle_move_model_to_cpu) # type: ignore + self.do_sample.toggled['bool'].connect(llm_settings_widget.do_sample_toggled) # type: ignore + self.leave_in_vram.toggled['bool'].connect(llm_settings_widget.toggle_leave_model_in_vram) # type: ignore + self.unload_model.toggled['bool'].connect(llm_settings_widget.toggle_unload_model) # type: ignore + self.prompt_template.currentTextChanged['QString'].connect(llm_settings_widget.prompt_template_text_changed) # type: ignore + self.model.currentTextChanged['QString'].connect(llm_settings_widget.model_text_changed) # type: ignore + self.model_version.currentTextChanged['QString'].connect(llm_settings_widget.model_version_changed) # type: ignore + QtCore.QMetaObject.connectSlotsByName(llm_settings_widget) + llm_settings_widget.setTabOrder(self.model, self.model_version) + llm_settings_widget.setTabOrder(self.model_version, self.use_gpu_checkbox) + llm_settings_widget.setTabOrder(self.use_gpu_checkbox, self.radio_button_4bit) + llm_settings_widget.setTabOrder(self.radio_button_4bit, self.radio_button_8bit) + llm_settings_widget.setTabOrder(self.radio_button_8bit, self.radio_button_16bit) + llm_settings_widget.setTabOrder(self.radio_button_16bit, self.radio_button_32bit) + llm_settings_widget.setTabOrder(self.radio_button_32bit, self.seed) + llm_settings_widget.setTabOrder(self.seed, self.random_seed) + + def retranslateUi(self, llm_settings_widget): + _translate = QtCore.QCoreApplication.translate + llm_settings_widget.setWindowTitle(_translate("llm_settings_widget", "Form")) + self.groupBox_7.setTitle(_translate("llm_settings_widget", "Model Type")) + self.groupBox_8.setTitle(_translate("llm_settings_widget", "Model Version")) + self.groupBox_14.setTitle(_translate("llm_settings_widget", "Prompt Template")) + self.groupBox_6.setTitle(_translate("llm_settings_widget", "DType")) + self.radio_button_2bit.setText(_translate("llm_settings_widget", "2-bit")) + self.radio_button_4bit.setText(_translate("llm_settings_widget", "4-bit")) + self.radio_button_8bit.setText(_translate("llm_settings_widget", "8-bit")) + self.radio_button_16bit.setText(_translate("llm_settings_widget", "16-bit")) + self.radio_button_32bit.setText(_translate("llm_settings_widget", "32-bit")) + self.dtype_description.setText(_translate("llm_settings_widget", "Description")) + self.use_gpu_checkbox.setText(_translate("llm_settings_widget", "Use GPU")) + self.override_parameters.setTitle(_translate("llm_settings_widget", "Override Prameters")) + self.pushButton.setText(_translate("llm_settings_widget", "Reset Settings to Default")) + self.groupBox_19.setTitle(_translate("llm_settings_widget", "Seed")) + self.random_seed.setText(_translate("llm_settings_widget", "Random seed")) + self.early_stopping.setText(_translate("llm_settings_widget", "Early stopping")) + self.do_sample.setText(_translate("llm_settings_widget", "Do sample")) + self.groupBox_11.setTitle(_translate("llm_settings_widget", "Repetition penalty")) + self.repetition_penalty.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.repetition_penalty")) + self.repetition_penalty.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_16.setTitle(_translate("llm_settings_widget", "Min length")) + self.min_length.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.min_length")) + self.min_length.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_20.setTitle(_translate("llm_settings_widget", "Top P")) + self.top_p.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.top_p")) + self.top_p.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_21.setTitle(_translate("llm_settings_widget", "Max length")) + self.max_length.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.max_length")) + self.max_length.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.leave_in_vram.setText(_translate("llm_settings_widget", "Leave in VRAM")) + self.move_to_cpu.setText(_translate("llm_settings_widget", "Move to CPU")) + self.unload_model.setText(_translate("llm_settings_widget", "Unload model")) + self.label_3.setText(_translate("llm_settings_widget", "Model management")) + self.groupBox_17.setTitle(_translate("llm_settings_widget", "No repeat ngram size")) + self.ngram_size.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.ngram_size")) + self.ngram_size.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_18.setTitle(_translate("llm_settings_widget", "Temperature")) + self.temperature.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.temperature")) + self.temperature.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_24.setTitle(_translate("llm_settings_widget", "Length penalty")) + self.length_penalty.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.length_penalty")) + self.length_penalty.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_25.setTitle(_translate("llm_settings_widget", "Num beams")) + self.num_beams.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.num_beams")) + self.num_beams.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_22.setTitle(_translate("llm_settings_widget", "Sequences to generate")) + self.sequences.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.sequences")) + self.sequences.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.groupBox_23.setTitle(_translate("llm_settings_widget", "Top k")) + self.top_k.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.top_k")) + self.top_k.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change")) + self.label_4.setText(_translate("llm_settings_widget", "How to treat model when not in use")) +from airunner.widgets.slider.slider_widget import SliderWidget diff --git a/src/airunner/widgets/llm/templates/llm_widget.ui b/src/airunner/widgets/llm/templates/llm_widget.ui index acc764482..5fefdb191 100644 --- a/src/airunner/widgets/llm/templates/llm_widget.ui +++ b/src/airunner/widgets/llm/templates/llm_widget.ui @@ -53,90 +53,7 @@ 0
- - - true - - - - - - - - - Send - - - - - - - 0 - - - - - - - New Conversation - - - - - - - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - Chat input - - - - - - - - Chat - - - - - Narrate - - - - - Generate Image - - - - - Summarize - - - - - Translate - - - - - - + @@ -149,125 +66,20 @@ Preferences + + 0 + + + 0 + + + 0 + + + 0 + - - - Qt::Vertical - - - - Prefix - - - - - - - - - - Suffix - - - - - - - - - - - - - Bot details - - - - - - - - Name - - - - - - - ChatAI - - - 6 - - - - - - - - - - - Personality - - - - - - - - Nice - - - - - Mean - - - - - Weird - - - - - Insane - - - - - Random - - - - - - - - - - - - - User name - - - - - - User - - - - - - - - - - Generate Characters - - + @@ -280,733 +92,20 @@ Settings + + 0 + + + 0 + + + 0 + + + 0 + - - - Model Type - - - - - - - - - - - - DType - - - - - - - - 2-bit - - - - - - - 4-bit - - - - - - - 8-bit - - - - - - - 16-bit - - - - - - - 32-bit - - - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - - - - 9 - - - - Description - - - - - - - Use GPU - - - - - - - - - - Override Prameters - - - true - - - true - - - - - - Reset Settings to Default - - - - - - - Seed - - - - - - - - - Random seed - - - - - - - - - - - - Early stopping - - - - - - - Do sample - - - - - - - - - - - Repetition penalty - - - - - - 1 - - - 10000 - - - 0.010000000000000 - - - 100.000000000000000 - - - true - - - llm_generator_setting.repetition_penalty - - - 0 - - - 1 - - - 1.000000000000000 - - - 10.000000000000000 - - - handle_value_change - - - - - - - - - - Min length - - - - - - 1 - - - 2556 - - - 1.000000000000000 - - - 2556.000000000000000 - - - false - - - llm_generator_setting.min_length - - - 1 - - - 2556 - - - 1 - - - 2556 - - - handle_value_change - - - - - - - - - - - - - - Top P - - - - - - 1 - - - 100 - - - 0.000000000000000 - - - 1.000000000000000 - - - true - - - llm_generator_setting.top_p - - - 1 - - - 10 - - - 0.010000000000000 - - - 0.100000000000000 - - - handle_value_change - - - - - - - - - - Max length - - - - - - 1 - - - 2556 - - - 1.000000000000000 - - - 2556.000000000000000 - - - false - - - llm_generator_setting.max_length - - - 1 - - - 2556 - - - 1 - - - 2556 - - - handle_value_change - - - - - - - - - - - - - - Leave in VRAM - - - - - - - Move to CPU - - - - - - - Unload model - - - - - - - - - - true - - - - Model management - - - - - - - - - No repeat ngram size - - - - - - 0 - - - 20 - - - 0.000000000000000 - - - 20.000000000000000 - - - false - - - llm_generator_setting.ngram_size - - - 1 - - - 1 - - - 1.000000000000000 - - - 1.000000000000000 - - - handle_value_change - - - - - - - - - - Temperature - - - - - - 1 - - - 20000 - - - 0.000100000000000 - - - 2.000000000000000 - - - true - - - llm_generator_setting.temperature - - - 1 - - - 10 - - - 0.010000000000000 - - - 0.100000000000000 - - - handle_value_change - - - - - - - - - - - - - - Length penalty - - - - - - -100 - - - 100 - - - 0.000000000000000 - - - 1.000000000000000 - - - true - - - llm_generator_setting.length_penalty - - - 1 - - - 10 - - - 0.010000000000000 - - - 0.100000000000000 - - - handle_value_change - - - - - - - - - - Num beams - - - - - - 1 - - - 100 - - - 0.000000000000000 - - - 100.000000000000000 - - - false - - - llm_generator_setting.num_beams - - - 1 - - - 10 - - - 0.010000000000000 - - - 0.100000000000000 - - - handle_value_change - - - - - - - - - - - - Qt::Horizontal - - - - - - - - - Sequences to generate - - - - - - 1 - - - 100 - - - 0.000000000000000 - - - 100.000000000000000 - - - false - - - llm_generator_setting.sequences - - - 1 - - - 10 - - - 0.010000000000000 - - - 0.100000000000000 - - - handle_value_change - - - - - - - - - - Top k - - - - - - 0 - - - 256 - - - 0.000000000000000 - - - 256.000000000000000 - - - false - - - llm_generator_setting.top_k - - - 1 - - - 10 - - - 1 - - - 10 - - - handle_value_change - - - - - - - - - - - - - 9 - - - - How to treat model when not in use - - - - - - - - - - Model Version - - - - - - -1 - - - - - - - - - - Qt::Vertical - - - - 20 - 40 - - - - - - - - Prompt Template - - - - - - - + @@ -1016,455 +115,29 @@ - SliderWidget + ChatPromptWidget + QWidget +
airunner/widgets/llm/chat_prompt_widget
+ 1 +
+ + LLMSettingsWidget + QWidget +
airunner/widgets/llm/llm_settings_widget
+ 1 +
+ + LLMPreferencesWidget QWidget -
airunner/widgets/slider/slider_widget
+
airunner/widgets/llm/llm_preferences_widget
1
- prompt - comboBox - send_button - clear_conversatiion_button - prefix - suffix - botname - username - generate_characters_button - model - model_version - use_gpu_checkbox - radio_button_4bit - radio_button_8bit - radio_button_16bit - radio_button_32bit - override_parameters - seed - random_seed tabWidget - comboBox - prompt - - - prefix - textChanged() - llm_widget - prefix_text_changed() - - - 131 - 108 - - - 48 - 6 - - - - - suffix - textChanged() - llm_widget - suffix_text_changed() - - - 182 - 669 - - - 206 - 5 - - - - - send_button - clicked() - llm_widget - action_button_clicked_send() - - - 91 - 961 - - - 257 - 0 - - - - - clear_conversatiion_button - clicked() - llm_widget - action_button_clicked_clear_conversation() - - - 609 - 961 - - - 498 - 4 - - - - - radio_button_4bit - toggled(bool) - llm_widget - toggled_4bit(bool) - - - 126 - 347 - - - 3 - 43 - - - - - radio_button_8bit - toggled(bool) - llm_widget - toggled_8bit(bool) - - - 206 - 347 - - - 2 - 88 - - - - - radio_button_16bit - toggled(bool) - llm_widget - toggled_16bit(bool) - - - 215 - 347 - - - 3 - 136 - - - - - radio_button_32bit - toggled(bool) - llm_widget - toggled_32bit(bool) - - - 287 - 347 - - - 0 - 179 - - - - - model_version - currentTextChanged(QString) - llm_widget - model_version_changed(QString) - - - 184 - 169 - - - 272 - 0 - - - - - seed - textEdited(QString) - llm_widget - seed_changed(QString) - - - 268 - 785 - - - 2 - 382 - - - - - random_seed - toggled(bool) - llm_widget - random_seed_toggled(bool) - - - 586 - 784 - - - 3 - 470 - - - - - do_sample - toggled(bool) - llm_widget - do_sample_toggled(bool) - - - 597 - 827 - - - 1 - 427 - - - - - early_stopping - toggled(bool) - llm_widget - early_stopping_toggled(bool) - - - 103 - 827 - - - 1 - 501 - - - - - model - currentTextChanged(QString) - llm_widget - model_text_changed(QString) - - - 163 - 92 - - - 117 - 0 - - - - - pushButton - clicked() - llm_widget - reset_settings_to_default_clicked() - - - 290 - 942 - - - 437 - 0 - - - - - use_gpu_checkbox - toggled(bool) - llm_widget - use_gpu_toggled(bool) - - - 441 - 315 - - - 34 - 0 - - - - - override_parameters - toggled(bool) - llm_widget - override_parameters_toggled(bool) - - - 128 - 398 - - - 120 - 0 - - - - - botname - textEdited(QString) - llm_widget - botname_text_changed(QString) - - - 298 - 843 - - - 80 - 0 - - - - - username - textEdited(QString) - llm_widget - username_text_changed(QString) - - - 226 - 919 - - - 118 - 0 - - - - - personality_type - textHighlighted(QString) - llm_widget - personality_type_changed(QString) - - - 597 - 843 - - - 412 - 0 - - - - - generate_characters_button - clicked() - llm_widget - action_button_clicked_generate_characters() - - - 362 - 962 - - - 477 - 0 - - - - - radio_button_2bit - toggled(bool) - llm_widget - toggled_2bit(bool) - - - 63 - 333 - - - 0 - 324 - - - - - leave_in_vram - toggled(bool) - llm_widget - toggle_leave_model_in_vram(bool) - - - 120 - 910 - - - 0 - 716 - - - - - move_to_cpu - toggled(bool) - llm_widget - toggle_move_model_to_cpu(bool) - - - 404 - 910 - - - 0 - 704 - - - - - unload_model - toggled(bool) - llm_widget - toggle_unload_model(bool) - - - 597 - 910 - - - 0 - 685 - - - - - prompt_template - currentTextChanged(QString) - llm_widget - prompt_template_text_changed(QString) - - - 69 - 226 - - - 79 - -16 - - - - + prompt_text_changed(QString) botname_text_changed(QString) diff --git a/src/airunner/widgets/llm/templates/llm_widget_ui.py b/src/airunner/widgets/llm/templates/llm_widget_ui.py index efb6aee76..1ff1c26bd 100644 --- a/src/airunner/widgets/llm/templates/llm_widget_ui.py +++ b/src/airunner/widgets/llm/templates/llm_widget_ui.py @@ -23,548 +23,43 @@ def setupUi(self, llm_widget): self.gridLayout_8 = QtWidgets.QGridLayout(self.chat) self.gridLayout_8.setContentsMargins(0, 0, 0, 0) self.gridLayout_8.setObjectName("gridLayout_8") - self.conversation = QtWidgets.QTextEdit(parent=self.chat) - self.conversation.setReadOnly(True) - self.conversation.setObjectName("conversation") - self.gridLayout_8.addWidget(self.conversation, 0, 0, 1, 1) - self.horizontalLayout_2 = QtWidgets.QHBoxLayout() - self.horizontalLayout_2.setObjectName("horizontalLayout_2") - self.send_button = QtWidgets.QPushButton(parent=self.chat) - self.send_button.setObjectName("send_button") - self.horizontalLayout_2.addWidget(self.send_button) - self.progressBar = QtWidgets.QProgressBar(parent=self.chat) - self.progressBar.setProperty("value", 0) - self.progressBar.setObjectName("progressBar") - self.horizontalLayout_2.addWidget(self.progressBar) - self.clear_conversatiion_button = QtWidgets.QPushButton(parent=self.chat) - self.clear_conversatiion_button.setObjectName("clear_conversatiion_button") - self.horizontalLayout_2.addWidget(self.clear_conversatiion_button) - self.gridLayout_8.addLayout(self.horizontalLayout_2, 2, 0, 1, 1) - self.widget_5 = QtWidgets.QWidget(parent=self.chat) - self.widget_5.setObjectName("widget_5") - self.horizontalLayout = QtWidgets.QHBoxLayout(self.widget_5) - self.horizontalLayout.setContentsMargins(0, 0, 0, 0) - self.horizontalLayout.setObjectName("horizontalLayout") - self.prompt = QtWidgets.QLineEdit(parent=self.widget_5) - self.prompt.setObjectName("prompt") - self.horizontalLayout.addWidget(self.prompt) - self.comboBox = QtWidgets.QComboBox(parent=self.widget_5) - self.comboBox.setObjectName("comboBox") - self.comboBox.addItem("") - self.comboBox.addItem("") - self.comboBox.addItem("") - self.comboBox.addItem("") - self.comboBox.addItem("") - self.horizontalLayout.addWidget(self.comboBox) - self.gridLayout_8.addWidget(self.widget_5, 1, 0, 1, 1) + self.chat_prompt_widget = ChatPromptWidget(parent=self.chat) + self.chat_prompt_widget.setObjectName("chat_prompt_widget") + self.gridLayout_8.addWidget(self.chat_prompt_widget, 0, 0, 1, 1) icon = QtGui.QIcon.fromTheme("user-available") self.tabWidget.addTab(self.chat, icon, "") self.preferences = QtWidgets.QWidget() self.preferences.setObjectName("preferences") self.gridLayout_2 = QtWidgets.QGridLayout(self.preferences) + self.gridLayout_2.setContentsMargins(0, 0, 0, 0) self.gridLayout_2.setObjectName("gridLayout_2") - self.splitter = QtWidgets.QSplitter(parent=self.preferences) - self.splitter.setOrientation(QtCore.Qt.Orientation.Vertical) - self.splitter.setObjectName("splitter") - self.groupBox = QtWidgets.QGroupBox(parent=self.splitter) - self.groupBox.setObjectName("groupBox") - self.gridLayout_3 = QtWidgets.QGridLayout(self.groupBox) - self.gridLayout_3.setObjectName("gridLayout_3") - self.prefix = QtWidgets.QPlainTextEdit(parent=self.groupBox) - self.prefix.setObjectName("prefix") - self.gridLayout_3.addWidget(self.prefix, 0, 0, 1, 1) - self.groupBox_2 = QtWidgets.QGroupBox(parent=self.splitter) - self.groupBox_2.setObjectName("groupBox_2") - self.gridLayout_4 = QtWidgets.QGridLayout(self.groupBox_2) - self.gridLayout_4.setObjectName("gridLayout_4") - self.suffix = QtWidgets.QPlainTextEdit(parent=self.groupBox_2) - self.suffix.setObjectName("suffix") - self.gridLayout_4.addWidget(self.suffix, 0, 0, 1, 1) - self.gridLayout_2.addWidget(self.splitter, 0, 0, 1, 1) - self.groupBox_3 = QtWidgets.QGroupBox(parent=self.preferences) - self.groupBox_3.setObjectName("groupBox_3") - self.horizontalLayout_5 = QtWidgets.QHBoxLayout(self.groupBox_3) - self.horizontalLayout_5.setObjectName("horizontalLayout_5") - self.verticalLayout_2 = QtWidgets.QVBoxLayout() - self.verticalLayout_2.setObjectName("verticalLayout_2") - self.label = QtWidgets.QLabel(parent=self.groupBox_3) - self.label.setObjectName("label") - self.verticalLayout_2.addWidget(self.label) - self.botname = QtWidgets.QLineEdit(parent=self.groupBox_3) - self.botname.setCursorPosition(6) - self.botname.setObjectName("botname") - self.verticalLayout_2.addWidget(self.botname) - self.horizontalLayout_5.addLayout(self.verticalLayout_2) - self.verticalLayout_3 = QtWidgets.QVBoxLayout() - self.verticalLayout_3.setObjectName("verticalLayout_3") - self.label_2 = QtWidgets.QLabel(parent=self.groupBox_3) - self.label_2.setObjectName("label_2") - self.verticalLayout_3.addWidget(self.label_2) - self.personality_type = QtWidgets.QComboBox(parent=self.groupBox_3) - self.personality_type.setObjectName("personality_type") - self.personality_type.addItem("") - self.personality_type.addItem("") - self.personality_type.addItem("") - self.personality_type.addItem("") - self.personality_type.addItem("") - self.verticalLayout_3.addWidget(self.personality_type) - self.horizontalLayout_5.addLayout(self.verticalLayout_3) - self.gridLayout_2.addWidget(self.groupBox_3, 1, 0, 1, 1) - self.groupBox_4 = QtWidgets.QGroupBox(parent=self.preferences) - self.groupBox_4.setObjectName("groupBox_4") - self.gridLayout_13 = QtWidgets.QGridLayout(self.groupBox_4) - self.gridLayout_13.setObjectName("gridLayout_13") - self.username = QtWidgets.QLineEdit(parent=self.groupBox_4) - self.username.setObjectName("username") - self.gridLayout_13.addWidget(self.username, 0, 0, 1, 1) - self.gridLayout_2.addWidget(self.groupBox_4, 2, 0, 1, 1) - self.generate_characters_button = QtWidgets.QPushButton(parent=self.preferences) - self.generate_characters_button.setObjectName("generate_characters_button") - self.gridLayout_2.addWidget(self.generate_characters_button, 3, 0, 1, 1) + self.llm_preferences_widget = LLMPreferencesWidget(parent=self.preferences) + self.llm_preferences_widget.setObjectName("llm_preferences_widget") + self.gridLayout_2.addWidget(self.llm_preferences_widget, 0, 0, 1, 1) icon = QtGui.QIcon.fromTheme("preferences-desktop") self.tabWidget.addTab(self.preferences, icon, "") self.settings = QtWidgets.QWidget() self.settings.setObjectName("settings") self.gridLayout_5 = QtWidgets.QGridLayout(self.settings) + self.gridLayout_5.setContentsMargins(0, 0, 0, 0) self.gridLayout_5.setObjectName("gridLayout_5") - self.groupBox_7 = QtWidgets.QGroupBox(parent=self.settings) - self.groupBox_7.setObjectName("groupBox_7") - self.gridLayout_10 = QtWidgets.QGridLayout(self.groupBox_7) - self.gridLayout_10.setObjectName("gridLayout_10") - self.model = QtWidgets.QComboBox(parent=self.groupBox_7) - self.model.setObjectName("model") - self.gridLayout_10.addWidget(self.model, 0, 0, 1, 1) - self.gridLayout_5.addWidget(self.groupBox_7, 0, 0, 1, 1) - self.groupBox_6 = QtWidgets.QGroupBox(parent=self.settings) - self.groupBox_6.setObjectName("groupBox_6") - self.gridLayout_9 = QtWidgets.QGridLayout(self.groupBox_6) - self.gridLayout_9.setObjectName("gridLayout_9") - self.horizontalLayout_3 = QtWidgets.QHBoxLayout() - self.horizontalLayout_3.setObjectName("horizontalLayout_3") - self.radio_button_2bit = QtWidgets.QRadioButton(parent=self.groupBox_6) - self.radio_button_2bit.setObjectName("radio_button_2bit") - self.horizontalLayout_3.addWidget(self.radio_button_2bit) - self.radio_button_4bit = QtWidgets.QRadioButton(parent=self.groupBox_6) - self.radio_button_4bit.setObjectName("radio_button_4bit") - self.horizontalLayout_3.addWidget(self.radio_button_4bit) - self.radio_button_8bit = QtWidgets.QRadioButton(parent=self.groupBox_6) - self.radio_button_8bit.setObjectName("radio_button_8bit") - self.horizontalLayout_3.addWidget(self.radio_button_8bit) - self.radio_button_16bit = QtWidgets.QRadioButton(parent=self.groupBox_6) - self.radio_button_16bit.setObjectName("radio_button_16bit") - self.horizontalLayout_3.addWidget(self.radio_button_16bit) - self.radio_button_32bit = QtWidgets.QRadioButton(parent=self.groupBox_6) - self.radio_button_32bit.setObjectName("radio_button_32bit") - self.horizontalLayout_3.addWidget(self.radio_button_32bit) - spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Minimum) - self.horizontalLayout_3.addItem(spacerItem) - self.gridLayout_9.addLayout(self.horizontalLayout_3, 1, 0, 1, 1) - self.dtype_description = QtWidgets.QLabel(parent=self.groupBox_6) - font = QtGui.QFont() - font.setPointSize(9) - self.dtype_description.setFont(font) - self.dtype_description.setObjectName("dtype_description") - self.gridLayout_9.addWidget(self.dtype_description, 2, 0, 1, 1) - self.use_gpu_checkbox = QtWidgets.QCheckBox(parent=self.groupBox_6) - self.use_gpu_checkbox.setObjectName("use_gpu_checkbox") - self.gridLayout_9.addWidget(self.use_gpu_checkbox, 0, 0, 1, 1) - self.gridLayout_5.addWidget(self.groupBox_6, 3, 0, 1, 1) - self.override_parameters = QtWidgets.QGroupBox(parent=self.settings) - self.override_parameters.setCheckable(True) - self.override_parameters.setChecked(True) - self.override_parameters.setObjectName("override_parameters") - self.gridLayout_12 = QtWidgets.QGridLayout(self.override_parameters) - self.gridLayout_12.setObjectName("gridLayout_12") - self.pushButton = QtWidgets.QPushButton(parent=self.override_parameters) - self.pushButton.setObjectName("pushButton") - self.gridLayout_12.addWidget(self.pushButton, 14, 0, 1, 1) - self.groupBox_19 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_19.setObjectName("groupBox_19") - self.horizontalLayout_4 = QtWidgets.QHBoxLayout(self.groupBox_19) - self.horizontalLayout_4.setObjectName("horizontalLayout_4") - self.seed = QtWidgets.QLineEdit(parent=self.groupBox_19) - self.seed.setObjectName("seed") - self.horizontalLayout_4.addWidget(self.seed) - self.random_seed = QtWidgets.QCheckBox(parent=self.groupBox_19) - self.random_seed.setObjectName("random_seed") - self.horizontalLayout_4.addWidget(self.random_seed) - self.gridLayout_12.addWidget(self.groupBox_19, 5, 0, 1, 1) - self.horizontalLayout_6 = QtWidgets.QHBoxLayout() - self.horizontalLayout_6.setObjectName("horizontalLayout_6") - self.early_stopping = QtWidgets.QCheckBox(parent=self.override_parameters) - self.early_stopping.setObjectName("early_stopping") - self.horizontalLayout_6.addWidget(self.early_stopping) - self.do_sample = QtWidgets.QCheckBox(parent=self.override_parameters) - self.do_sample.setObjectName("do_sample") - self.horizontalLayout_6.addWidget(self.do_sample) - self.gridLayout_12.addLayout(self.horizontalLayout_6, 8, 0, 1, 1) - self.horizontalLayout_7 = QtWidgets.QHBoxLayout() - self.horizontalLayout_7.setObjectName("horizontalLayout_7") - self.groupBox_11 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_11.setObjectName("groupBox_11") - self.gridLayout_19 = QtWidgets.QGridLayout(self.groupBox_11) - self.gridLayout_19.setObjectName("gridLayout_19") - self.repetition_penalty = SliderWidget(parent=self.groupBox_11) - self.repetition_penalty.setProperty("slider_minimum", 1) - self.repetition_penalty.setProperty("slider_maximum", 10000) - self.repetition_penalty.setProperty("spinbox_minimum", 0.01) - self.repetition_penalty.setProperty("spinbox_maximum", 100.0) - self.repetition_penalty.setProperty("display_as_float", True) - self.repetition_penalty.setProperty("slider_single_step", 0) - self.repetition_penalty.setProperty("slider_page_step", 1) - self.repetition_penalty.setProperty("spinbox_single_step", 1.0) - self.repetition_penalty.setProperty("spinbox_page_step", 10.0) - self.repetition_penalty.setObjectName("repetition_penalty") - self.gridLayout_19.addWidget(self.repetition_penalty, 0, 0, 1, 1) - self.horizontalLayout_7.addWidget(self.groupBox_11) - self.groupBox_16 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_16.setObjectName("groupBox_16") - self.gridLayout_20 = QtWidgets.QGridLayout(self.groupBox_16) - self.gridLayout_20.setObjectName("gridLayout_20") - self.min_length = SliderWidget(parent=self.groupBox_16) - self.min_length.setProperty("slider_minimum", 1) - self.min_length.setProperty("slider_maximum", 2556) - self.min_length.setProperty("spinbox_minimum", 1.0) - self.min_length.setProperty("spinbox_maximum", 2556.0) - self.min_length.setProperty("display_as_float", False) - self.min_length.setProperty("slider_single_step", 1) - self.min_length.setProperty("slider_page_step", 2556) - self.min_length.setProperty("spinbox_single_step", 1) - self.min_length.setProperty("spinbox_page_step", 2556) - self.min_length.setObjectName("min_length") - self.gridLayout_20.addWidget(self.min_length, 0, 0, 1, 1) - self.horizontalLayout_7.addWidget(self.groupBox_16) - self.gridLayout_12.addLayout(self.horizontalLayout_7, 1, 0, 1, 1) - self.horizontalLayout_9 = QtWidgets.QHBoxLayout() - self.horizontalLayout_9.setObjectName("horizontalLayout_9") - self.groupBox_20 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_20.setObjectName("groupBox_20") - self.gridLayout_24 = QtWidgets.QGridLayout(self.groupBox_20) - self.gridLayout_24.setObjectName("gridLayout_24") - self.top_p = SliderWidget(parent=self.groupBox_20) - self.top_p.setProperty("slider_minimum", 1) - self.top_p.setProperty("slider_maximum", 100) - self.top_p.setProperty("spinbox_minimum", 0.0) - self.top_p.setProperty("spinbox_maximum", 1.0) - self.top_p.setProperty("display_as_float", True) - self.top_p.setProperty("slider_single_step", 1) - self.top_p.setProperty("slider_page_step", 10) - self.top_p.setProperty("spinbox_single_step", 0.01) - self.top_p.setProperty("spinbox_page_step", 0.1) - self.top_p.setObjectName("top_p") - self.gridLayout_24.addWidget(self.top_p, 0, 0, 1, 1) - self.horizontalLayout_9.addWidget(self.groupBox_20) - self.groupBox_21 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_21.setObjectName("groupBox_21") - self.gridLayout_25 = QtWidgets.QGridLayout(self.groupBox_21) - self.gridLayout_25.setObjectName("gridLayout_25") - self.max_length = SliderWidget(parent=self.groupBox_21) - self.max_length.setProperty("slider_minimum", 1) - self.max_length.setProperty("slider_maximum", 2556) - self.max_length.setProperty("spinbox_minimum", 1.0) - self.max_length.setProperty("spinbox_maximum", 2556.0) - self.max_length.setProperty("display_as_float", False) - self.max_length.setProperty("slider_single_step", 1) - self.max_length.setProperty("slider_page_step", 2556) - self.max_length.setProperty("spinbox_single_step", 1) - self.max_length.setProperty("spinbox_page_step", 2556) - self.max_length.setObjectName("max_length") - self.gridLayout_25.addWidget(self.max_length, 0, 0, 1, 1) - self.horizontalLayout_9.addWidget(self.groupBox_21) - self.gridLayout_12.addLayout(self.horizontalLayout_9, 0, 0, 1, 1) - self.horizontalLayout_12 = QtWidgets.QHBoxLayout() - self.horizontalLayout_12.setObjectName("horizontalLayout_12") - self.leave_in_vram = QtWidgets.QRadioButton(parent=self.override_parameters) - self.leave_in_vram.setObjectName("leave_in_vram") - self.horizontalLayout_12.addWidget(self.leave_in_vram) - self.move_to_cpu = QtWidgets.QRadioButton(parent=self.override_parameters) - self.move_to_cpu.setObjectName("move_to_cpu") - self.horizontalLayout_12.addWidget(self.move_to_cpu) - self.unload_model = QtWidgets.QRadioButton(parent=self.override_parameters) - self.unload_model.setObjectName("unload_model") - self.horizontalLayout_12.addWidget(self.unload_model) - self.gridLayout_12.addLayout(self.horizontalLayout_12, 12, 0, 1, 1) - self.label_3 = QtWidgets.QLabel(parent=self.override_parameters) - font = QtGui.QFont() - font.setBold(True) - self.label_3.setFont(font) - self.label_3.setObjectName("label_3") - self.gridLayout_12.addWidget(self.label_3, 10, 0, 1, 1) - self.horizontalLayout_8 = QtWidgets.QHBoxLayout() - self.horizontalLayout_8.setObjectName("horizontalLayout_8") - self.groupBox_17 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_17.setObjectName("groupBox_17") - self.gridLayout_21 = QtWidgets.QGridLayout(self.groupBox_17) - self.gridLayout_21.setObjectName("gridLayout_21") - self.ngram_size = SliderWidget(parent=self.groupBox_17) - self.ngram_size.setProperty("slider_minimum", 0) - self.ngram_size.setProperty("slider_maximum", 20) - self.ngram_size.setProperty("spinbox_minimum", 0.0) - self.ngram_size.setProperty("spinbox_maximum", 20.0) - self.ngram_size.setProperty("display_as_float", False) - self.ngram_size.setProperty("slider_single_step", 1) - self.ngram_size.setProperty("slider_page_step", 1) - self.ngram_size.setProperty("spinbox_single_step", 1.0) - self.ngram_size.setProperty("spinbox_page_step", 1.0) - self.ngram_size.setObjectName("ngram_size") - self.gridLayout_21.addWidget(self.ngram_size, 0, 0, 1, 1) - self.horizontalLayout_8.addWidget(self.groupBox_17) - self.groupBox_18 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_18.setObjectName("groupBox_18") - self.gridLayout_22 = QtWidgets.QGridLayout(self.groupBox_18) - self.gridLayout_22.setObjectName("gridLayout_22") - self.temperature = SliderWidget(parent=self.groupBox_18) - self.temperature.setProperty("slider_minimum", 1) - self.temperature.setProperty("slider_maximum", 20000) - self.temperature.setProperty("spinbox_minimum", 0.0001) - self.temperature.setProperty("spinbox_maximum", 2.0) - self.temperature.setProperty("display_as_float", True) - self.temperature.setProperty("slider_single_step", 1) - self.temperature.setProperty("slider_page_step", 10) - self.temperature.setProperty("spinbox_single_step", 0.01) - self.temperature.setProperty("spinbox_page_step", 0.1) - self.temperature.setObjectName("temperature") - self.gridLayout_22.addWidget(self.temperature, 0, 0, 1, 1) - self.horizontalLayout_8.addWidget(self.groupBox_18) - self.gridLayout_12.addLayout(self.horizontalLayout_8, 3, 0, 1, 1) - self.horizontalLayout_11 = QtWidgets.QHBoxLayout() - self.horizontalLayout_11.setObjectName("horizontalLayout_11") - self.groupBox_24 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_24.setObjectName("groupBox_24") - self.gridLayout_28 = QtWidgets.QGridLayout(self.groupBox_24) - self.gridLayout_28.setObjectName("gridLayout_28") - self.length_penalty = SliderWidget(parent=self.groupBox_24) - self.length_penalty.setProperty("slider_minimum", -100) - self.length_penalty.setProperty("slider_maximum", 100) - self.length_penalty.setProperty("spinbox_minimum", 0.0) - self.length_penalty.setProperty("spinbox_maximum", 1.0) - self.length_penalty.setProperty("display_as_float", True) - self.length_penalty.setProperty("slider_single_step", 1) - self.length_penalty.setProperty("slider_page_step", 10) - self.length_penalty.setProperty("spinbox_single_step", 0.01) - self.length_penalty.setProperty("spinbox_page_step", 0.1) - self.length_penalty.setObjectName("length_penalty") - self.gridLayout_28.addWidget(self.length_penalty, 0, 0, 1, 1) - self.horizontalLayout_11.addWidget(self.groupBox_24) - self.groupBox_25 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_25.setObjectName("groupBox_25") - self.gridLayout_29 = QtWidgets.QGridLayout(self.groupBox_25) - self.gridLayout_29.setObjectName("gridLayout_29") - self.num_beams = SliderWidget(parent=self.groupBox_25) - self.num_beams.setProperty("slider_minimum", 1) - self.num_beams.setProperty("slider_maximum", 100) - self.num_beams.setProperty("spinbox_minimum", 0.0) - self.num_beams.setProperty("spinbox_maximum", 100.0) - self.num_beams.setProperty("display_as_float", False) - self.num_beams.setProperty("slider_single_step", 1) - self.num_beams.setProperty("slider_page_step", 10) - self.num_beams.setProperty("spinbox_single_step", 0.01) - self.num_beams.setProperty("spinbox_page_step", 0.1) - self.num_beams.setObjectName("num_beams") - self.gridLayout_29.addWidget(self.num_beams, 0, 0, 1, 1) - self.horizontalLayout_11.addWidget(self.groupBox_25) - self.gridLayout_12.addLayout(self.horizontalLayout_11, 2, 0, 1, 1) - self.line = QtWidgets.QFrame(parent=self.override_parameters) - self.line.setFrameShape(QtWidgets.QFrame.Shape.HLine) - self.line.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken) - self.line.setObjectName("line") - self.gridLayout_12.addWidget(self.line, 9, 0, 1, 1) - self.horizontalLayout_10 = QtWidgets.QHBoxLayout() - self.horizontalLayout_10.setObjectName("horizontalLayout_10") - self.groupBox_22 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_22.setObjectName("groupBox_22") - self.gridLayout_26 = QtWidgets.QGridLayout(self.groupBox_22) - self.gridLayout_26.setObjectName("gridLayout_26") - self.sequences = SliderWidget(parent=self.groupBox_22) - self.sequences.setProperty("slider_minimum", 1) - self.sequences.setProperty("slider_maximum", 100) - self.sequences.setProperty("spinbox_minimum", 0.0) - self.sequences.setProperty("spinbox_maximum", 100.0) - self.sequences.setProperty("display_as_float", False) - self.sequences.setProperty("slider_single_step", 1) - self.sequences.setProperty("slider_page_step", 10) - self.sequences.setProperty("spinbox_single_step", 0.01) - self.sequences.setProperty("spinbox_page_step", 0.1) - self.sequences.setObjectName("sequences") - self.gridLayout_26.addWidget(self.sequences, 0, 0, 1, 1) - self.horizontalLayout_10.addWidget(self.groupBox_22) - self.groupBox_23 = QtWidgets.QGroupBox(parent=self.override_parameters) - self.groupBox_23.setObjectName("groupBox_23") - self.gridLayout_27 = QtWidgets.QGridLayout(self.groupBox_23) - self.gridLayout_27.setObjectName("gridLayout_27") - self.top_k = SliderWidget(parent=self.groupBox_23) - self.top_k.setProperty("slider_minimum", 0) - self.top_k.setProperty("slider_maximum", 256) - self.top_k.setProperty("spinbox_minimum", 0.0) - self.top_k.setProperty("spinbox_maximum", 256.0) - self.top_k.setProperty("display_as_float", False) - self.top_k.setProperty("slider_single_step", 1) - self.top_k.setProperty("slider_page_step", 10) - self.top_k.setProperty("spinbox_single_step", 1) - self.top_k.setProperty("spinbox_page_step", 10) - self.top_k.setObjectName("top_k") - self.gridLayout_27.addWidget(self.top_k, 0, 0, 1, 1) - self.horizontalLayout_10.addWidget(self.groupBox_23) - self.gridLayout_12.addLayout(self.horizontalLayout_10, 4, 0, 1, 1) - self.label_4 = QtWidgets.QLabel(parent=self.override_parameters) - font = QtGui.QFont() - font.setPointSize(9) - self.label_4.setFont(font) - self.label_4.setObjectName("label_4") - self.gridLayout_12.addWidget(self.label_4, 11, 0, 1, 1) - self.gridLayout_5.addWidget(self.override_parameters, 4, 0, 1, 1) - self.groupBox_8 = QtWidgets.QGroupBox(parent=self.settings) - self.groupBox_8.setObjectName("groupBox_8") - self.gridLayout_11 = QtWidgets.QGridLayout(self.groupBox_8) - self.gridLayout_11.setObjectName("gridLayout_11") - self.model_version = QtWidgets.QComboBox(parent=self.groupBox_8) - self.model_version.setObjectName("model_version") - self.gridLayout_11.addWidget(self.model_version, 0, 0, 1, 1) - self.gridLayout_5.addWidget(self.groupBox_8, 1, 0, 1, 1) - spacerItem1 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding) - self.gridLayout_5.addItem(spacerItem1, 16, 0, 1, 1) - self.groupBox_14 = QtWidgets.QGroupBox(parent=self.settings) - self.groupBox_14.setObjectName("groupBox_14") - self.gridLayout_17 = QtWidgets.QGridLayout(self.groupBox_14) - self.gridLayout_17.setObjectName("gridLayout_17") - self.prompt_template = QtWidgets.QComboBox(parent=self.groupBox_14) - self.prompt_template.setObjectName("prompt_template") - self.gridLayout_17.addWidget(self.prompt_template, 0, 0, 1, 1) - self.gridLayout_5.addWidget(self.groupBox_14, 2, 0, 1, 1) + self.llm_settings_widget = LLMSettingsWidget(parent=self.settings) + self.llm_settings_widget.setObjectName("llm_settings_widget") + self.gridLayout_5.addWidget(self.llm_settings_widget, 0, 0, 1, 1) icon = QtGui.QIcon.fromTheme("preferences-other") self.tabWidget.addTab(self.settings, icon, "") self.gridLayout.addWidget(self.tabWidget, 0, 0, 1, 1) self.retranslateUi(llm_widget) self.tabWidget.setCurrentIndex(0) - self.model_version.setCurrentIndex(-1) - self.prefix.textChanged.connect(llm_widget.prefix_text_changed) # type: ignore - self.suffix.textChanged.connect(llm_widget.suffix_text_changed) # type: ignore - self.send_button.clicked.connect(llm_widget.action_button_clicked_send) # type: ignore - self.clear_conversatiion_button.clicked.connect(llm_widget.action_button_clicked_clear_conversation) # type: ignore - self.radio_button_4bit.toggled['bool'].connect(llm_widget.toggled_4bit) # type: ignore - self.radio_button_8bit.toggled['bool'].connect(llm_widget.toggled_8bit) # type: ignore - self.radio_button_16bit.toggled['bool'].connect(llm_widget.toggled_16bit) # type: ignore - self.radio_button_32bit.toggled['bool'].connect(llm_widget.toggled_32bit) # type: ignore - self.model_version.currentTextChanged['QString'].connect(llm_widget.model_version_changed) # type: ignore - self.seed.textEdited['QString'].connect(llm_widget.seed_changed) # type: ignore - self.random_seed.toggled['bool'].connect(llm_widget.random_seed_toggled) # type: ignore - self.do_sample.toggled['bool'].connect(llm_widget.do_sample_toggled) # type: ignore - self.early_stopping.toggled['bool'].connect(llm_widget.early_stopping_toggled) # type: ignore - self.model.currentTextChanged['QString'].connect(llm_widget.model_text_changed) # type: ignore - self.pushButton.clicked.connect(llm_widget.reset_settings_to_default_clicked) # type: ignore - self.use_gpu_checkbox.toggled['bool'].connect(llm_widget.use_gpu_toggled) # type: ignore - self.override_parameters.toggled['bool'].connect(llm_widget.override_parameters_toggled) # type: ignore - self.botname.textEdited['QString'].connect(llm_widget.botname_text_changed) # type: ignore - self.username.textEdited['QString'].connect(llm_widget.username_text_changed) # type: ignore - self.personality_type.textHighlighted['QString'].connect(llm_widget.personality_type_changed) # type: ignore - self.generate_characters_button.clicked.connect(llm_widget.action_button_clicked_generate_characters) # type: ignore - self.radio_button_2bit.toggled['bool'].connect(llm_widget.toggled_2bit) # type: ignore - self.leave_in_vram.toggled['bool'].connect(llm_widget.toggle_leave_model_in_vram) # type: ignore - self.move_to_cpu.toggled['bool'].connect(llm_widget.toggle_move_model_to_cpu) # type: ignore - self.unload_model.toggled['bool'].connect(llm_widget.toggle_unload_model) # type: ignore - self.prompt_template.currentTextChanged['QString'].connect(llm_widget.prompt_template_text_changed) # type: ignore QtCore.QMetaObject.connectSlotsByName(llm_widget) - llm_widget.setTabOrder(self.prompt, self.comboBox) - llm_widget.setTabOrder(self.comboBox, self.send_button) - llm_widget.setTabOrder(self.send_button, self.clear_conversatiion_button) - llm_widget.setTabOrder(self.clear_conversatiion_button, self.prefix) - llm_widget.setTabOrder(self.prefix, self.suffix) - llm_widget.setTabOrder(self.suffix, self.botname) - llm_widget.setTabOrder(self.botname, self.username) - llm_widget.setTabOrder(self.username, self.generate_characters_button) - llm_widget.setTabOrder(self.generate_characters_button, self.model) - llm_widget.setTabOrder(self.model, self.model_version) - llm_widget.setTabOrder(self.model_version, self.use_gpu_checkbox) - llm_widget.setTabOrder(self.use_gpu_checkbox, self.radio_button_4bit) - llm_widget.setTabOrder(self.radio_button_4bit, self.radio_button_8bit) - llm_widget.setTabOrder(self.radio_button_8bit, self.radio_button_16bit) - llm_widget.setTabOrder(self.radio_button_16bit, self.radio_button_32bit) - llm_widget.setTabOrder(self.radio_button_32bit, self.override_parameters) - llm_widget.setTabOrder(self.override_parameters, self.seed) - llm_widget.setTabOrder(self.seed, self.random_seed) - llm_widget.setTabOrder(self.random_seed, self.tabWidget) - llm_widget.setTabOrder(self.tabWidget, self.comboBox) - llm_widget.setTabOrder(self.comboBox, self.prompt) def retranslateUi(self, llm_widget): _translate = QtCore.QCoreApplication.translate llm_widget.setWindowTitle(_translate("llm_widget", "Form")) - self.send_button.setText(_translate("llm_widget", "Send")) - self.clear_conversatiion_button.setText(_translate("llm_widget", "New Conversation")) - self.prompt.setPlaceholderText(_translate("llm_widget", "Chat input")) - self.comboBox.setItemText(0, _translate("llm_widget", "Chat")) - self.comboBox.setItemText(1, _translate("llm_widget", "Narrate")) - self.comboBox.setItemText(2, _translate("llm_widget", "Generate Image")) - self.comboBox.setItemText(3, _translate("llm_widget", "Summarize")) - self.comboBox.setItemText(4, _translate("llm_widget", "Translate")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.chat), _translate("llm_widget", "Chat")) - self.groupBox.setTitle(_translate("llm_widget", "Prefix")) - self.groupBox_2.setTitle(_translate("llm_widget", "Suffix")) - self.groupBox_3.setTitle(_translate("llm_widget", "Bot details")) - self.label.setText(_translate("llm_widget", "Name")) - self.botname.setText(_translate("llm_widget", "ChatAI")) - self.label_2.setText(_translate("llm_widget", "Personality")) - self.personality_type.setItemText(0, _translate("llm_widget", "Nice")) - self.personality_type.setItemText(1, _translate("llm_widget", "Mean")) - self.personality_type.setItemText(2, _translate("llm_widget", "Weird")) - self.personality_type.setItemText(3, _translate("llm_widget", "Insane")) - self.personality_type.setItemText(4, _translate("llm_widget", "Random")) - self.groupBox_4.setTitle(_translate("llm_widget", "User name")) - self.username.setText(_translate("llm_widget", "User")) - self.generate_characters_button.setText(_translate("llm_widget", "Generate Characters")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.preferences), _translate("llm_widget", "Preferences")) - self.groupBox_7.setTitle(_translate("llm_widget", "Model Type")) - self.groupBox_6.setTitle(_translate("llm_widget", "DType")) - self.radio_button_2bit.setText(_translate("llm_widget", "2-bit")) - self.radio_button_4bit.setText(_translate("llm_widget", "4-bit")) - self.radio_button_8bit.setText(_translate("llm_widget", "8-bit")) - self.radio_button_16bit.setText(_translate("llm_widget", "16-bit")) - self.radio_button_32bit.setText(_translate("llm_widget", "32-bit")) - self.dtype_description.setText(_translate("llm_widget", "Description")) - self.use_gpu_checkbox.setText(_translate("llm_widget", "Use GPU")) - self.override_parameters.setTitle(_translate("llm_widget", "Override Prameters")) - self.pushButton.setText(_translate("llm_widget", "Reset Settings to Default")) - self.groupBox_19.setTitle(_translate("llm_widget", "Seed")) - self.random_seed.setText(_translate("llm_widget", "Random seed")) - self.early_stopping.setText(_translate("llm_widget", "Early stopping")) - self.do_sample.setText(_translate("llm_widget", "Do sample")) - self.groupBox_11.setTitle(_translate("llm_widget", "Repetition penalty")) - self.repetition_penalty.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.repetition_penalty")) - self.repetition_penalty.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_16.setTitle(_translate("llm_widget", "Min length")) - self.min_length.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.min_length")) - self.min_length.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_20.setTitle(_translate("llm_widget", "Top P")) - self.top_p.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.top_p")) - self.top_p.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_21.setTitle(_translate("llm_widget", "Max length")) - self.max_length.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.max_length")) - self.max_length.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.leave_in_vram.setText(_translate("llm_widget", "Leave in VRAM")) - self.move_to_cpu.setText(_translate("llm_widget", "Move to CPU")) - self.unload_model.setText(_translate("llm_widget", "Unload model")) - self.label_3.setText(_translate("llm_widget", "Model management")) - self.groupBox_17.setTitle(_translate("llm_widget", "No repeat ngram size")) - self.ngram_size.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.ngram_size")) - self.ngram_size.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_18.setTitle(_translate("llm_widget", "Temperature")) - self.temperature.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.temperature")) - self.temperature.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_24.setTitle(_translate("llm_widget", "Length penalty")) - self.length_penalty.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.length_penalty")) - self.length_penalty.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_25.setTitle(_translate("llm_widget", "Num beams")) - self.num_beams.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.num_beams")) - self.num_beams.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_22.setTitle(_translate("llm_widget", "Sequences to generate")) - self.sequences.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.sequences")) - self.sequences.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.groupBox_23.setTitle(_translate("llm_widget", "Top k")) - self.top_k.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.top_k")) - self.top_k.setProperty("slider_callback", _translate("llm_widget", "handle_value_change")) - self.label_4.setText(_translate("llm_widget", "How to treat model when not in use")) - self.groupBox_8.setTitle(_translate("llm_widget", "Model Version")) - self.groupBox_14.setTitle(_translate("llm_widget", "Prompt Template")) self.tabWidget.setTabText(self.tabWidget.indexOf(self.settings), _translate("llm_widget", "Settings")) -from airunner.widgets.slider.slider_widget import SliderWidget +from airunner.widgets.llm.chat_prompt_widget import ChatPromptWidget +from airunner.widgets.llm.llm_preferences_widget import LLMPreferencesWidget +from airunner.widgets.llm.llm_settings_widget import LLMSettingsWidget diff --git a/src/airunner/widgets/llm/templates/message.ui b/src/airunner/widgets/llm/templates/message.ui new file mode 100644 index 000000000..f1ffdeab5 --- /dev/null +++ b/src/airunner/widgets/llm/templates/message.ui @@ -0,0 +1,87 @@ + + + message + + + + 0 + 0 + 394 + 92 + + + + Form + + + + + + + 0 + 0 + + + + + 0 + 50 + + + + + + + TextLabel + + + Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft + + + + + + + + 0 + 0 + + + + + 0 + 40 + + + + border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff; + + + QFrame::NoFrame + + + QFrame::Plain + + + true + + + + + + + TextLabel + + + Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft + + + + + + + + + + + diff --git a/src/airunner/widgets/llm/templates/message_ui.py b/src/airunner/widgets/llm/templates/message_ui.py new file mode 100644 index 000000000..bc0bd80d3 --- /dev/null +++ b/src/airunner/widgets/llm/templates/message_ui.py @@ -0,0 +1,58 @@ +# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/message.ui' +# +# Created by: PyQt6 UI code generator 6.4.2 +# +# WARNING: Any manual changes made to this file will be lost when pyuic6 is +# run again. Do not edit this file unless you know what you are doing. + + +from PyQt6 import QtCore, QtGui, QtWidgets + + +class Ui_message(object): + def setupUi(self, message): + message.setObjectName("message") + message.resize(394, 92) + self.gridLayout_2 = QtWidgets.QGridLayout(message) + self.gridLayout_2.setObjectName("gridLayout_2") + self.widget = QtWidgets.QWidget(parent=message) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Preferred, QtWidgets.QSizePolicy.Policy.MinimumExpanding) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.widget.sizePolicy().hasHeightForWidth()) + self.widget.setSizePolicy(sizePolicy) + self.widget.setMinimumSize(QtCore.QSize(0, 50)) + self.widget.setObjectName("widget") + self.gridLayout = QtWidgets.QGridLayout(self.widget) + self.gridLayout.setObjectName("gridLayout") + self.user_name = QtWidgets.QLabel(parent=self.widget) + self.user_name.setAlignment(QtCore.Qt.AlignmentFlag.AlignBottom|QtCore.Qt.AlignmentFlag.AlignLeading|QtCore.Qt.AlignmentFlag.AlignLeft) + self.user_name.setObjectName("user_name") + self.gridLayout.addWidget(self.user_name, 0, 0, 1, 1) + self.content = QtWidgets.QPlainTextEdit(parent=self.widget) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.MinimumExpanding) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.content.sizePolicy().hasHeightForWidth()) + self.content.setSizePolicy(sizePolicy) + self.content.setMinimumSize(QtCore.QSize(0, 40)) + self.content.setStyleSheet("border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff;") + self.content.setFrameShape(QtWidgets.QFrame.Shape.NoFrame) + self.content.setFrameShadow(QtWidgets.QFrame.Shadow.Plain) + self.content.setReadOnly(True) + self.content.setObjectName("content") + self.gridLayout.addWidget(self.content, 0, 1, 1, 1) + self.bot_name = QtWidgets.QLabel(parent=self.widget) + self.bot_name.setAlignment(QtCore.Qt.AlignmentFlag.AlignBottom|QtCore.Qt.AlignmentFlag.AlignLeading|QtCore.Qt.AlignmentFlag.AlignLeft) + self.bot_name.setObjectName("bot_name") + self.gridLayout.addWidget(self.bot_name, 0, 2, 1, 1) + self.gridLayout_2.addWidget(self.widget, 0, 0, 1, 1) + + self.retranslateUi(message) + QtCore.QMetaObject.connectSlotsByName(message) + + def retranslateUi(self, message): + _translate = QtCore.QCoreApplication.translate + message.setWindowTitle(_translate("message", "Form")) + self.user_name.setText(_translate("message", "TextLabel")) + self.bot_name.setText(_translate("message", "TextLabel")) diff --git a/src/airunner/widgets/model_manager/model_manager_widget.py b/src/airunner/widgets/model_manager/model_manager_widget.py index 7effc5aa7..bedfccf3a 100644 --- a/src/airunner/widgets/model_manager/model_manager_widget.py +++ b/src/airunner/widgets/model_manager/model_manager_widget.py @@ -23,21 +23,6 @@ class ModelManagerWidget(BaseWidget): current_model_data = None _current_model_object = None model_form = None - icons = { - "toolButton": "010-view", - "edit_button": "settings", - "delete_button": "006-trash", - } - - def set_stylesheet(self): - for key in self.model_widgets.keys(): - for model_widget in self.model_widgets[key]: - for button, icon in self.icons.items(): - getattr(model_widget, button).setIcon( - QtGui.QIcon( - os.path.join(f"src/icons/{icon}{'-light' if self.is_dark else ''}.png") - ) - ) @property def current_model_object(self): diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py index 25a903653..b01d6c5f4 100644 --- a/src/airunner/windows/main/main_window.py +++ b/src/airunner/windows/main/main_window.py @@ -30,7 +30,7 @@ from airunner.input_event_manager import InputEventManager from airunner.mixins.history_mixin import HistoryMixin from airunner.settings import BASE_PATH -from airunner.utils import get_version, get_latest_version, auto_export_image, save_session, \ +from airunner.utils import get_version, auto_export_image, save_session, \ create_airunner_paths, default_hf_cache_dir from airunner.widgets.status.status_widget import StatusWidget from airunner.windows.about.about import AboutWindow @@ -42,9 +42,10 @@ from airunner.windows.video import VideoPopup from airunner.data.models import TabSection from airunner.widgets.brushes.brushes_container import BrushesContainer - +from airunner.data.models import Document from airunner.utils import get_session + class MainWindow( QMainWindow, HistoryMixin @@ -65,8 +66,6 @@ class MainWindow( _settings_manager = None models = None client = None - override_current_generator = None - override_section = None _version = None _latest_version = None add_image_to_canvas_signal = pyqtSignal(dict) @@ -112,14 +111,6 @@ def settings_manager(self): self._settings_manager = SettingsManager(app=self) return self._settings_manager - # @property - # def current_prompt_generator_settings(self): - # """ - # Convenience property to get the current prompt generator settings - # :return: - # """ - # return self.settings_manager.prompt_generator_settings - @property def is_dark(self): return self.settings_manager.dark_mode_enabled @@ -135,6 +126,10 @@ def standard_image_panel(self): @property def generator_tab_widget(self): return self.ui.generator_widget + + @property + def canvas_widget(self): + return self.standard_image_panel.canvas_widget @property def toolbar_widget(self): @@ -148,39 +143,6 @@ def prompt_builder(self): def footer_widget(self): return self.ui.footer_widget - @property - def current_generator(self): - """ - Returns the current generator (stablediffusion, kandinksy, etc) as - determined by the selected generator tab in the - generator_tab_widget. This value can be override by setting - the override_current_generator property. - :return: string - """ - if self.override_current_generator: - return self.override_current_generator - return self.generator_tab_widget.current_generator - - @property - def current_section(self): - """ - Returns the current section (txt2img, outpaint, etc) as - determined by the selected sub-tab in the generator tab widget. - This value can be override by setting the override_section property. - :return: string - """ - if self.override_section: - return self.override_section - return self.generator_tab_widget.current_section - - @property - def tabs(self): - return self._tabs[self.current_generator] - - @tabs.setter - def tabs(self, val): - self._tabs[self.current_generator] = val - @property def generator_type(self): """ @@ -205,7 +167,7 @@ def latest_version(self, val): @property def document_name(self): - # name = f"{self._document_name}{'*' if self.canvas and self.canvas.is_dirty else ''}" + # name = f"{self._document_name}{'*' if self.canvas and self.canvas_widget.is_dirty else ''}" # return f"{name} - {self.version}" return "Untitled" @@ -234,13 +196,13 @@ def current_canvas(self): return self.standard_image_panel def describe_image(self, image, callback): - self.ui.generator_widget.current_generator_widget.ui.ai_tab_widget.describe_image( + self.generator_tab_widget.ui.ai_tab_widget.describe_image( image=image, callback=callback ) def current_active_image(self): - return self.ui.standard_image_widget.image.copy() if self.ui.standard_image_widget.image else None + return self.standard_image_panel.image.copy() if self.standard_image_panel.image else None def send_message(self, code, message): self.message_var.emit({ @@ -257,6 +219,7 @@ def available_model_names_by_section(self, section): def __init__(self, *args, **kwargs): logger.info("Starting AI Runnner") + self.ui = Ui_MainWindow() # qdarktheme.enable_hi_dpi() # set the api @@ -266,20 +229,17 @@ def __init__(self, *args, **kwargs): self.testing = kwargs.pop("testing", False) # initialize the document - from airunner.data.db import session - from airunner.data.models import Document + session = get_session() self.document = session.query(Document).first() super().__init__(*args, **kwargs) + self.ui.setupUi(self) + self.initialize() # on window resize: - # self.applicationStateChanged.connect(self.on_state_changed) - - if self.settings_manager.latest_version_check: - logger.info("Checking for latest version") - self.check_for_latest_version() + # self.windowStateChanged.connect(self.on_state_changed) # check for self.current_layer.lines every 100ms self.timer = self.startTimer(100) @@ -305,7 +265,6 @@ def __init__(self, *args, **kwargs): #self.ui.layer_widget.initialize() self.set_button_checked("toggle_grid", self.settings_manager.grid_settings.show_grid, False) - self.set_button_checked("safety_checker", self.settings_manager.nsfw_filter, False) # call a function after the window has finished loading: QTimer.singleShot(500, self.on_show) @@ -363,7 +322,6 @@ def mode_tab_index_changed(self, index): self.settings_manager.set_value("mode", self.ui.mode_tab_widget.tabText(index)) def on_show(self): - #self.ui.canvas_plus_widget.do_draw() pass def action_slider_changed(self, settings_property, value): @@ -412,23 +370,23 @@ def action_redo_triggered(self): def action_paste_image_triggered(self): if self.settings_manager.mode == Mode.IMAGE.value: - self.canvas.paste_image_from_clipboard() + self.canvas_widget.paste_image_from_clipboard() def action_copy_image_triggered(self): if self.settings_manager.mode == Mode.IMAGE.value: - self.canvas.copy_image(self.current_active_image()) + self.canvas_widget.copy_image(self.current_active_image()) def action_cut_image_triggered(self): if self.settings_manager.mode == Mode.IMAGE.value: - self.canvas.cut_image() + self.canvas_widget.cut_image() def action_rotate_90_clockwise_triggered(self): if self.settings_manager.mode == Mode.IMAGE.value: - self.canvas.rotate_90_clockwise() + self.canvas_widget.rotate_90_clockwise() def action_rotate_90_counterclockwise_triggered(self): if self.settings_manager.mode == Mode.IMAGE.value: - self.canvas.rotate_90_counterclockwise() + self.canvas_widget.rotate_90_counterclockwise() def action_show_prompt_browser_triggered(self): self.show_prompt_browser() @@ -549,7 +507,7 @@ def action_toggle_nsfw_filter_triggered(self, bool): def action_toggle_grid(self, active): self.settings_manager.set_value("grid_settings.show_grid", active) - # self.canvas.update() + # self.canvas_widget.update() def action_toggle_darkmode(self): self.set_stylesheet() @@ -624,10 +582,10 @@ def set_size_increment_levels(self): self.ui.height_slider_widget.slider_single_step = size self.ui.height_slider_widget.slider_tick_interval = size - self.canvas.update() + self.canvas_widget.update() def toggle_nsfw_filter(self): - # self.canvas.update() + # self.canvas_widget.update() self.set_nsfw_filter_tooltip() def set_nsfw_filter_tooltip(self): @@ -636,43 +594,43 @@ def set_nsfw_filter_tooltip(self): f"Click to {'enable' if not nsfw_filter else 'disable'} NSFW filter" ) - def resizeEvent(self, event): - if not self.is_started: - return - state = self.windowState() - if state == Qt.WindowState.WindowMaximized: - timer = QTimer(self) - timer.setSingleShot(True) - timer.timeout.connect(self.checkWindowState) - timer.start(100) - else: - self.checkWindowState() - - def checkWindowState(self): - state = self.windowState() - self.is_maximized = state == Qt.WindowState.WindowMaximized + # def resizeEvent(self, event): + # if not self.is_started: + # return + # state = self.windowState() + # if state == Qt.WindowState.WindowMaximized: + # timer = QTimer(self) + # timer.setSingleShot(True) + # timer.timeout.connect(self.checkWindowState) + # timer.start(100) + # else: + # self.checkWindowState() + + # def checkWindowState(self): + # state = self.windowState() + # self.is_maximized = state == Qt.WindowState.WindowMaximized def dragmode_pressed(self): - # self.canvas.is_canvas_drag_mode = True + # self.canvas_widget.is_canvas_drag_mode = True pass def dragmode_released(self): - # self.canvas.is_canvas_drag_mode = False + # self.canvas_widget.is_canvas_drag_mode = False pass def shift_pressed(self): - # self.canvas.shift_is_pressed = True + # self.canvas_widget.shift_is_pressed = True pass def shift_released(self): - # self.canvas.shift_is_pressed = False + # self.canvas_widget.shift_is_pressed = False pass def register_keypress(self): self.input_event_manager.register_keypress("fullscreen", self.toggle_fullscreen) self.input_event_manager.register_keypress("control_pressed", self.dragmode_pressed, self.dragmode_released) self.input_event_manager.register_keypress("shift_pressed", self.shift_pressed, self.shift_released) - #self.input_event_manager.register_keypress("delete_outside_active_grid_area", self.canvas.delete_outside_active_grid_area) + #self.input_event_manager.register_keypress("delete_outside_active_grid_area", self.canvas_widget.delete_outside_active_grid_area) def toggle_fullscreen(self): if self.isFullScreen(): @@ -688,49 +646,10 @@ def closeEvent(self, event): QApplication.quit() def timerEvent(self, event): - # self.canvas.timerEvent(event) + # self.canvas_widget.timerEvent(event) if self.status_widget: self.status_widget.update_system_stats(queue_size=self.client.queue.qsize()) - def check_for_latest_version(self): - self.version_thread = QThread() - class VersionCheckWorker(QObject): - version = None - finished = pyqtSignal() - def get_latest_version(self): - self.version = f"v{get_latest_version()}" - self.finished.emit() - self.version_worker = VersionCheckWorker() - self.version_worker.moveToThread(self.version_thread) - self.version_thread.started.connect(self.version_worker.get_latest_version) - self.version_worker.finished.connect(self.handle_latest_version) - self.version_thread.start() - - def handle_latest_version(self): - self.latest_version = self.version_worker.version - # call get_latest_version() in a separate thread - # to avoid blocking the UI, show a popup if version doesn't match self.version - # check if latest_version is greater than version using major, minor, patch - current_major, current_minor, current_patch = self.version[1:].split(".") - try: - latest_major, latest_minor, latest_patch = self.latest_version[1:].split(".") - except ValueError: - latest_major, latest_minor, latest_patch = 0, 0, 0 - - latest_major = int(latest_major) - latest_minor = int(latest_minor) - latest_patch = int(latest_patch) - current_major = int(current_major) - current_minor = int(current_minor) - current_patch = int(current_patch) - - if current_major == latest_major and current_minor == latest_minor and current_patch < latest_patch: - self.show_update_message() - elif current_major == latest_major and current_minor < latest_minor: - self.show_update_message() - elif current_major < latest_major: - self.show_update_message() - def show_update_message(self): self.set_status_label(f"New version available: {self.latest_version}") @@ -739,13 +658,13 @@ def show_update_popup(self): def reset_settings(self): logger.info("Resetting settings") - self.canvas.reset_settings() + self.canvas_widget.reset_settings() def on_state_changed(self, state): if state == Qt.ApplicationState.ApplicationActive: - self.canvas.pos_x = int(self.x() / 4) - self.canvas.pos_y = int(self.y() / 2) - self.canvas.update() + self.canvas_widget.pos_x = int(self.x() / 4) + self.canvas_widget.pos_y = int(self.y() / 2) + self.canvas_widget.update() def refresh_styles(self): self.set_stylesheet() @@ -771,7 +690,7 @@ def set_stylesheet(self): ("eraser-icon", "toggle_eraser_button"), ("frame-grid-icon", "toggle_grid_button"), ("circle-center-icon", "focus_button"), - ("adult-sign-icon", "safety_checker_button"), + ("artificial-intelligence-ai-chip-icon", "ai_button"), ("setting-line-icon", "settings_button"), ("chat-box-icon", "chat_button"), ("setting-line-icon", "llm_preferences_button"), @@ -842,7 +761,7 @@ def initialize_mixins(self): def connect_signals(self): logger.info("Connecting signals") - #self.canvas._is_dirty.connect(self.set_window_title) + #self.canvas_widget._is_dirty.connect(self.set_window_title) for signal, handler in self.registered_settings_handlers: getattr(self.settings_manager, signal).connect(handler) @@ -944,16 +863,13 @@ def set_splitter_sizes(self): def show_section(self, section): section_lists = { - "left": [self.ui.generator_widget.ui.generator_tabs.tabText(i) for i in range(self.ui.generator_widget.ui.generator_tabs.count())], "center": [self.ui.center_tab.tabText(i) for i in range(self.ui.center_tab.count())], "right": [self.ui.tool_tab_widget.tabText(i) for i in range(self.ui.tool_tab_widget.count())], "bottom": [self.ui.bottom_panel_tab_widget.tabText(i) for i in range(self.ui.bottom_panel_tab_widget.count())] } for k, v in section_lists.items(): if section in v: - if k == "left": - self.ui.generator_widget.ui.generator_tabs.setCurrentIndex(v.index(section)) - elif k == "right": + if k == "right": self.ui.tool_tab_widget.setCurrentIndex(v.index(section)) elif k == "bottom": self.ui.bottom_panel_tab_widget.setCurrentIndex(v.index(section)) @@ -992,7 +908,7 @@ def handle_value_change(self, attr_name, value=None, widget=None): self.settings_manager.set_value(attr_name, value) def handle_similar_slider_change(self, attr_name, value=None, widget=None): - self.ui.standard_image_widget.handle_similar_slider_change(value) + self.standard_image_panel.handle_similar_slider_change(value) def initialize_settings_manager(self): self.settings_manager.changed_signal.connect(self.handle_changed_signal) @@ -1003,11 +919,11 @@ def handle_changed_signal(self, key, value): elif key == "line_width": self.set_size_form_element_step_values() elif key == "show_grid": - self.canvas.update() + self.canvas_widget.update() elif key == "snap_to_grid": - self.canvas.update() + self.canvas_widget.update() elif key == "line_color": - self.canvas.update_grid_pen() + self.canvas_widget.update_grid_pen() elif key == "lora_path": self.refresh_lora() elif key == "model_base_path": @@ -1031,9 +947,6 @@ def initialize_handlers(self): self.message_var.my_signal.connect(self.message_handler) def initialize_window(self): - self.window = Ui_MainWindow() - self.window.setupUi(self) - self.ui = self.window self.center() self.set_window_title() @@ -1123,7 +1036,7 @@ def new_document(self): self._document_name = "Untitled" self.set_window_title() self.current_filter = None - #self.canvas.update() + #self.canvas_widget.update() self.ui.layer_widget.show_layers() def set_status_label(self, txt, error=False): @@ -1195,7 +1108,7 @@ def handle_image_generated(self, message): # get max progressbar value if nsfw_content_detected and self.settings_manager.nsfw_filter: self.message_handler({ - "message": "NSFW content detected, try again.", + "message": "Explicit content detected, try again.", "code": MessageCode.ERROR }) @@ -1261,7 +1174,7 @@ def set_size_form_element_step_values(self): def saveas_document(self): # get file path file_path, _ = QFileDialog.getSaveFileName( - self.window, "Save Document", "", "AI Runner Document (*.airunner)" + self.ui, "Save Document", "", "AI Runner Document (*.airunner)" ) if file_path == "": return @@ -1285,8 +1198,8 @@ def do_save(self, document_name): layers.append(layer) data = { "layers": layers, - "image_pivot_point": self.canvas.image_pivot_point, - "image_root_point": self.canvas.image_root_point, + "image_pivot_point": self.canvas_widget.image_pivot_point, + "image_root_point": self.canvas_widget.image_root_point, } with open(document_name, "wb") as f: pickle.dump(data, f) @@ -1298,10 +1211,10 @@ def do_save(self, document_name): self._document_name = document_name.split("/")[-1].split(".")[0] self.set_window_title() self.is_saved = True - self.canvas.is_dirty = False + self.canvas_widget.is_dirty = False def update(self): - self.ui.standard_image_widget.update_thumbnails() + self.standard_image_panel.update_thumbnails() def insert_into_prompt(self, text, negative_prompt=False): prompt_widget = self.generator_tab_widget.data[self.current_generator][self.current_section]["prompt_widget"] @@ -1334,28 +1247,10 @@ def change_content_widget(self): self.ui.center_tab.setCurrentIndex(tab_index) self.ui.center_tab.blockSignals(False) - def handle_generator_tab_changed(self): - self.generator_tab_changed_signal.emit() - self.change_content_widget() - - def handle_tab_section_changed(self): - self.tab_section_changed_signal.emit() - self.change_content_widget() - - def release_tab_overrides(self): - self.override_current_generator = None - self.override_section = None - def clear_all_prompts(self): - for tab_section in self._tabs.keys(): - self.override_current_generator = tab_section - for tab in self.tabs.keys(): - self.override_section = tab - self.prompt = "" - self.negative_prompt = "" - self.generator_tab_widget.clear_prompts(tab_section, tab) - self.override_current_generator = None - self.override_section = None + self.prompt = "" + self.negative_prompt = "" + self.generator_tab_widget.clear_prompts() def show_prompt_browser(self): PromptBrowser(settings_manager=self.settings_manager, app=self) @@ -1410,7 +1305,7 @@ def update_negative_prompt(self, prompt_value): self.generator_tab_widget.update_negative_prompt(prompt_value) def new_batch(self, index, image, data): - self.generator_tab_widget.current_generator.new_batch(index, image, data) + self.generator_tab_widget.new_batch(index, image, data) def image_generation_toggled(self): self.settings_manager.set_value("mode", Mode.IMAGE.value) @@ -1485,10 +1380,6 @@ def activate_image_generation_section(self): def activate_language_processing_section(self): self.ui.mode_tab_widget.setCurrentIndex(1) - # try: - # self.ui.generator_widget.current_generator_widget.ui.generator_form_tab_widget.setCurrentIndex(2) - # except AttributeError as e: - # pass self.toggle_tool_section_buttons_visibility() def activate_model_manager_section(self): @@ -1534,12 +1425,7 @@ def set_all_image_generator_buttons(self): def image_generators_toggled(self): self.image_generation_toggled() - current_tab = self.settings_manager.current_tab - current_tab = self.settings_manager.current_image_generator - self.settings_manager.set_value("current_tab", current_tab) self.settings_manager.set_value("mode", Mode.IMAGE.value) - self.settings_manager.set_value(f"current_section_{current_tab}", GeneratorSection.TXT2IMG.value) - self.generator_tab_widget.set_current_section_tab() self.settings_manager.set_value("generator_section", GeneratorSection.TXT2IMG.value) active_tab_obj = session.query(TabSection).filter(TabSection.panel == "center_tab").first() active_tab_obj.active_tab = "Canvas" @@ -1549,11 +1435,7 @@ def image_generators_toggled(self): def text_to_video_toggled(self): self.image_generation_toggled() - current_tab = "stablediffusion" - self.settings_manager.set_value("current_tab", current_tab) self.settings_manager.set_value("mode", Mode.IMAGE.value) - self.settings_manager.set_value(f"current_section_{current_tab}", GeneratorSection.TXT2VID.value) - self.generator_tab_widget.set_current_section_tab() self.settings_manager.set_value("generator_section", GeneratorSection.TXT2VID.value) active_tab_obj = session.query(TabSection).filter(TabSection.panel == "center_tab").first() active_tab_obj.active_tab = "Video" @@ -1563,9 +1445,6 @@ def text_to_video_toggled(self): def prompt_builder_toggled(self): self.image_generation_toggled() - current_tab = self.settings_manager.current_tab - self.settings_manager.set_value(f"current_section_{current_tab}", GeneratorSection.PROMPT_BUILDER.value) - self.generator_tab_widget.set_current_section_tab() self.settings_manager.set_value(f"generator_section", GeneratorSection.PROMPT_BUILDER.value) active_tab_obj = session.query(TabSection).filter(TabSection.panel == "center_tab").first() active_tab_obj.active_tab = "Prompt Builder" diff --git a/src/airunner/windows/main/templates/main_window.ui b/src/airunner/windows/main/templates/main_window.ui index bb5fdf015..514bdf4cd 100644 --- a/src/airunner/windows/main/templates/main_window.ui +++ b/src/airunner/windows/main/templates/main_window.ui @@ -7,12 +7,12 @@ 0 0 - 1085 - 1141 + 1144 + 955 - + 0 0 @@ -128,7 +128,7 @@ 0 0 - 1081 + 1140 50 @@ -1058,8 +1058,8 @@ 0 0 - 365 - 1017 + 386 + 831 @@ -1425,20 +1425,21 @@ - + - + 0 0 - - - 50 - 45 - + + Qt::Horizontal - + + + + + 50 45 @@ -1447,38 +1448,26 @@ PointingHandCursor - - Toggle NSFW - - :/icons/light/adult-sign-icon.svg:/icons/light/adult-sign-icon.svg + :/icons/light/artificial-intelligence-ai-chip-icon.svg:/icons/light/artificial-intelligence-ai-chip-icon.svg - 18 - 18 + 20 + 20 - - true - true - - - - 0 - 0 - - + Qt::Horizontal @@ -1602,7 +1591,7 @@ 0 0 - 1085 + 1144 22 @@ -3039,7 +3028,7 @@ action_new_document_triggered() - 857 + 916 65 @@ -3055,7 +3044,7 @@ action_undo_triggered() - 1034 + 1093 65 @@ -3071,7 +3060,7 @@ action_redo_triggered() - 1076 + 1135 65 @@ -3087,7 +3076,7 @@ action_export_image_triggered() - 983 + 1042 65 @@ -3240,22 +3229,6 @@ - - safety_checker_button - toggled(bool) - MainWindow - action_toggle_nsfw_filter_triggered(bool) - - - 1079 - 422 - - - 1002 - 624 - - - settings_button clicked() @@ -3263,8 +3236,8 @@ action_show_settings() - 1079 - 482 + 1138 + 426 1002 @@ -3279,8 +3252,8 @@ action_toggle_grid(bool) - 1079 - 320 + 1138 + 315 1002 @@ -3295,8 +3268,8 @@ action_toggle_active_grid_area(bool) - 1079 - 158 + 1138 + 153 1002 @@ -3311,8 +3284,8 @@ action_toggle_eraser(bool) - 1079 - 260 + 1138 + 255 0 @@ -3327,8 +3300,8 @@ action_toggle_brush(bool) - 1079 - 209 + 1138 + 204 0 diff --git a/src/airunner/windows/main/templates/main_window_ui.py b/src/airunner/windows/main/templates/main_window_ui.py index 793c73ef6..9597e9657 100644 --- a/src/airunner/windows/main/templates/main_window_ui.py +++ b/src/airunner/windows/main/templates/main_window_ui.py @@ -12,8 +12,8 @@ class Ui_MainWindow(object): def setupUi(self, MainWindow): MainWindow.setObjectName("MainWindow") - MainWindow.resize(1085, 1141) - sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.Preferred) + MainWindow.resize(1144, 955) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.Preferred) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth()) @@ -58,7 +58,7 @@ def setupUi(self, MainWindow): self.scrollArea_3.setWidgetResizable(True) self.scrollArea_3.setObjectName("scrollArea_3") self.scrollAreaWidgetContents_3 = QtWidgets.QWidget() - self.scrollAreaWidgetContents_3.setGeometry(QtCore.QRect(0, 0, 1081, 50)) + self.scrollAreaWidgetContents_3.setGeometry(QtCore.QRect(0, 0, 1140, 50)) self.scrollAreaWidgetContents_3.setObjectName("scrollAreaWidgetContents_3") self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.scrollAreaWidgetContents_3) self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0) @@ -385,7 +385,7 @@ def setupUi(self, MainWindow): self.scrollArea.setWidgetResizable(True) self.scrollArea.setObjectName("scrollArea") self.scrollAreaWidgetContents = QtWidgets.QWidget() - self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 365, 1017)) + self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 386, 831)) self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents") self.gridLayout = QtWidgets.QGridLayout(self.scrollAreaWidgetContents) self.gridLayout.setObjectName("gridLayout") @@ -538,24 +538,6 @@ def setupUi(self, MainWindow): self.focus_button.setFlat(True) self.focus_button.setObjectName("focus_button") self.verticalLayout.addWidget(self.focus_button) - self.safety_checker_button = QtWidgets.QPushButton(parent=self.button_menu) - sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.MinimumExpanding) - sizePolicy.setHorizontalStretch(0) - sizePolicy.setVerticalStretch(0) - sizePolicy.setHeightForWidth(self.safety_checker_button.sizePolicy().hasHeightForWidth()) - self.safety_checker_button.setSizePolicy(sizePolicy) - self.safety_checker_button.setMinimumSize(QtCore.QSize(50, 45)) - self.safety_checker_button.setMaximumSize(QtCore.QSize(50, 45)) - self.safety_checker_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) - self.safety_checker_button.setText("") - icon15 = QtGui.QIcon() - icon15.addPixmap(QtGui.QPixmap(":/icons/light/adult-sign-icon.svg"), QtGui.QIcon.Mode.Normal, QtGui.QIcon.State.Off) - self.safety_checker_button.setIcon(icon15) - self.safety_checker_button.setIconSize(QtCore.QSize(18, 18)) - self.safety_checker_button.setCheckable(True) - self.safety_checker_button.setFlat(True) - self.safety_checker_button.setObjectName("safety_checker_button") - self.verticalLayout.addWidget(self.safety_checker_button) self.line_4 = QtWidgets.QFrame(parent=self.button_menu) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.Fixed) sizePolicy.setHorizontalStretch(0) @@ -566,6 +548,22 @@ def setupUi(self, MainWindow): self.line_4.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken) self.line_4.setObjectName("line_4") self.verticalLayout.addWidget(self.line_4) + self.ai_button = QtWidgets.QPushButton(parent=self.button_menu) + self.ai_button.setMinimumSize(QtCore.QSize(50, 45)) + self.ai_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor)) + self.ai_button.setText("") + icon15 = QtGui.QIcon() + icon15.addPixmap(QtGui.QPixmap(":/icons/light/artificial-intelligence-ai-chip-icon.svg"), QtGui.QIcon.Mode.Normal, QtGui.QIcon.State.Off) + self.ai_button.setIcon(icon15) + self.ai_button.setIconSize(QtCore.QSize(20, 20)) + self.ai_button.setFlat(True) + self.ai_button.setObjectName("ai_button") + self.verticalLayout.addWidget(self.ai_button) + self.line_3 = QtWidgets.QFrame(parent=self.button_menu) + self.line_3.setFrameShape(QtWidgets.QFrame.Shape.HLine) + self.line_3.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken) + self.line_3.setObjectName("line_3") + self.verticalLayout.addWidget(self.line_3) self.settings_button = QtWidgets.QPushButton(parent=self.button_menu) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.MinimumExpanding) sizePolicy.setHorizontalStretch(0) @@ -608,7 +606,7 @@ def setupUi(self, MainWindow): self.header_widget.raise_() MainWindow.setCentralWidget(self.centralwidget) self.menubar = QtWidgets.QMenuBar(parent=MainWindow) - self.menubar.setGeometry(QtCore.QRect(0, 0, 1085, 22)) + self.menubar.setGeometry(QtCore.QRect(0, 0, 1144, 22)) font = QtGui.QFont() font.setPointSize(11) self.menubar.setFont(font) @@ -961,7 +959,6 @@ def setupUi(self, MainWindow): self.image_generators_button.released.connect(MainWindow.image_generators_toggled) # type: ignore self.txt2vid_button.released.connect(MainWindow.text_to_video_toggled) # type: ignore self.prompt_builder_button.released.connect(MainWindow.prompt_builder_toggled) # type: ignore - self.safety_checker_button.toggled['bool'].connect(MainWindow.action_toggle_nsfw_filter_triggered) # type: ignore self.settings_button.clicked.connect(MainWindow.action_show_settings) # type: ignore self.toggle_grid_button.toggled['bool'].connect(MainWindow.action_toggle_grid) # type: ignore self.toggle_active_grid_area_button.toggled['bool'].connect(MainWindow.action_toggle_active_grid_area) # type: ignore @@ -1000,7 +997,6 @@ def retranslateUi(self, MainWindow): self.toggle_eraser_button.setToolTip(_translate("MainWindow", "Eraser tool")) self.toggle_grid_button.setToolTip(_translate("MainWindow", "Toggle Grid")) self.focus_button.setToolTip(_translate("MainWindow", "Recenter canvas")) - self.safety_checker_button.setToolTip(_translate("MainWindow", "Toggle NSFW")) self.settings_button.setToolTip(_translate("MainWindow", "AI Runner Settings")) self.mode_tab_widget.setTabText(self.mode_tab_widget.indexOf(self.art), _translate("MainWindow", "Image Generation")) self.mode_tab_widget.setTabText(self.mode_tab_widget.indexOf(self.chat), _translate("MainWindow", "Text Generation"))