diff --git a/README.md b/README.md
index 97c307199..d8dfdf207 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@
---
-### Stable Diffusion and Kandinsky on your own hardware
+### Stable Diffusion on your own hardware
No web server to run, additional requirements to install or technical knowledge required.
diff --git a/src/airunner/aihandler/llm.py b/src/airunner/aihandler/llm.py
index 158ec85da..6319d7822 100644
--- a/src/airunner/aihandler/llm.py
+++ b/src/airunner/aihandler/llm.py
@@ -1,5 +1,13 @@
+import torch
+import traceback
+
from airunner.aihandler.transformer_runner import TransformerRunner
from airunner.aihandler.logger import Logger as logger
+from airunner.aihandler.enums import MessageCode
+import os
+from jinja2 import Environment, FileSystemLoader
+
+from transformers.pipelines.conversational import Conversation
class LLM(TransformerRunner):
@@ -8,34 +16,122 @@ def clear_conversation(self):
self.chain.clear()
def do_generate(self, data):
+ Logger.info("Generating with LLM")
self.process_data(data)
self.handle_request()
self.requested_generator_name = data["request_data"]["generator_name"]
prompt = data["request_data"]["prompt"]
model_path = data["request_data"]["model_path"]
- self.generate(
- app=self.app,
- endpoint=data["request_data"]["generator_name"],
- prompt=prompt,
- model=model_path,
- stream=data["request_data"]["stream"],
- images=[data["request_data"]["image"]],
+ return self.generate(
+ # app=self.app,
+ # endpoint=data["request_data"]["generator_name"],
+ # prompt=prompt,
+ # model=model_path,
+ # stream=data["request_data"]["stream"],
+ # images=[data["request_data"]["image"]],
)
+
+ history = []
+
+ def generate(self):
+ # Create a FileSystemLoader object with the directory of the template
+ HERE = os.path.dirname(os.path.abspath(__file__))
+ file_loader = FileSystemLoader(os.path.join(HERE, "chat_templates"))
+
+ # Create an Environment object with the FileSystemLoader object
+ env = Environment(loader=file_loader)
- def generate(self, **kwargs):
+ # Load the template
+ # Load the template
+ chat_template = env.get_template('chat.j2')
+
+ prompt = self.prompt
+ if prompt is None or prompt == "":
+ traceback.print_stack()
+ return
+
if self.generator.name == "casuallm":
- prompt = kwargs.get("prompt", "")
- logger.info(f"LLM requested with prompt {prompt}")
- return self.chain.run(prompt)
+ history = []
+ for message in self.history:
+ if message["role"] == "user":
+ history.append("[INST]" + self.username + ': "'+ message["content"] +'"[/INST]')
+ else:
+ history.append(self.botname + ': "'+ message["content"] +'"')
+ history = "\n".join(history)
+ if history == "":
+ history = None
+
+ # Create a dictionary with the variables
+ variables = {
+ "username": self.username,
+ "botname": self.botname,
+ "history": history,
+ "input": prompt,
+ "bos_token": self.tokenizer.bos_token,
+ "botmood": "angry. He hates " + self.username
+ }
+
+ self.history.append({
+ "role": "user",
+ "content": prompt
+ })
+
+ # Render the template with the variables
+ rendered_template = chat_template.render(variables)
+ #print(rendered_template)
+
+ # Encode the rendered template
+ encoded = self.tokenizer.encode(rendered_template, return_tensors="pt")
+
+ model_inputs = encoded.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Generate the response
+ generated_ids = self.model.generate(
+ model_inputs,
+ min_length=0,
+ max_length=1000,
+ num_beams=1,
+ do_sample=True,
+ top_k=20,
+ eta_cutoff=10,
+ top_p=1.0,
+ num_return_sequences=self.sequences,
+ eos_token_id=self.tokenizer.eos_token_id,
+ early_stopping=True,
+ repetition_penalty=1.15,
+ temperature=0.7,
+ )
+
+ # Decode the new tokens
+ decoded = self.tokenizer.batch_decode(generated_ids)[0]
+ decoded = decoded.replace(self.tokenizer.batch_decode(model_inputs)[0], "")
+ decoded = decoded.replace("", "")
+
+ # Extract the actual message content
+ start_index = decoded.find('"') + 1
+ end_index = decoded.rfind('"')
+ decoded = decoded[start_index:end_index]
+
+ self.history.append({
+ "role": "assistant",
+ "content": decoded
+ })
+
+ # print(self.history)
+
+ # print("*"*80)
+ # print(decoded)
+
+ #return decoded
+ self.engine.send_message(decoded, code=MessageCode.TEXT_GENERATED)
elif self.generator.name == "visualqa":
inputs = self.processor(
self.image,
- self.prompt,
+ prompt,
return_tensors="pt"
).to("cuda")
out = self.model.generate(
- **inputs,
- **kwargs,
+ **inputs
)
answers = []
diff --git a/src/airunner/aihandler/logger.py b/src/airunner/aihandler/logger.py
index b9ecd5723..a22fe7dcf 100644
--- a/src/airunner/aihandler/logger.py
+++ b/src/airunner/aihandler/logger.py
@@ -76,4 +76,3 @@ def error(cls, msg):
Logger.logger.addHandler(Logger.stream_handler)
logging.getLogger("lightning").setLevel(logging.WARNING)
logging.getLogger("lightning_fabric.utilities.seed").setLevel(logging.WARNING)
-Logger.set_level(LOG_LEVEL)
diff --git a/src/airunner/aihandler/runner.py b/src/airunner/aihandler/runner.py
index c4a24b9a3..b5f16601c 100644
--- a/src/airunner/aihandler/runner.py
+++ b/src/airunner/aihandler/runner.py
@@ -1,6 +1,4 @@
import base64
-import gc
-import os
import re
import traceback
from io import BytesIO
@@ -17,12 +15,11 @@
from diffusers.utils.torch_utils import randn_tensor
#from diffusers import ConsistencyDecoderVAE
from torchvision import transforms
-from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DiffusionPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from airunner.aihandler.auto_pipeline import AutoImport
from airunner.aihandler.enums import FilterType
from airunner.aihandler.enums import MessageCode
-from airunner.aihandler.logger import Logger as logger
from airunner.aihandler.mixins.compel_mixin import CompelMixin
from airunner.aihandler.mixins.embedding_mixin import EmbeddingMixin
from airunner.aihandler.mixins.lora_mixin import LoraMixin
@@ -34,7 +31,7 @@
from airunner.aihandler.settings_manager import SettingsManager
from airunner.prompt_builder.prompt_data import PromptData
from airunner.scripts.realesrgan.main import RealESRGAN
-
+from airunner.aihandler.logger import Logger
torch.backends.cuda.matmul.allow_tf32 = True
@@ -139,7 +136,7 @@ def local_files_only(self):
@local_files_only.setter
def local_files_only(self, value):
- logger.info("Setting local_files_only to %s" % value)
+ Logger.info("Setting local_files_only to %s" % value)
self._local_files_only = value
@property
@@ -526,7 +523,7 @@ def pipe(self):
elif self.is_vid_action:
return self.txt2vid
else:
- logger.warning(f"Invalid action {self.action} unable to get pipe")
+ Logger.warning(f"Invalid action {self.action} unable to get pipe")
@pipe.setter
def pipe(self, value):
@@ -543,7 +540,7 @@ def pipe(self, value):
elif self.is_vid_action:
self.txt2vid = value
else:
- logger.warning(f"Invalid action {self.action} unable to set pipe")
+ Logger.warning(f"Invalid action {self.action} unable to set pipe")
@property
def cuda_is_available(self):
@@ -695,7 +692,7 @@ def original_model_data(self):
return self.options.get("original_model_data", {})
def __init__(self, **kwargs):
- logger.set_level(LOG_LEVEL)
+ Logger.set_level(LOG_LEVEL)
self.settings_manager = SettingsManager()
self.safety_checker_model = self.settings_manager.models_by_pipeline_action("safety_checker")
self.text_encoder_model = self.settings_manager.models_by_pipeline_action("text_encoder")
@@ -770,9 +767,9 @@ def is_safetensor_file(model):
return model.endswith(".safetensors")
def initialize(self):
- logger.info("trying to initialize")
+ Logger.info("trying to initialize")
if not self.initialized or self.reload_model or self.pipe is None:
- logger.info("Initializing")
+ Logger.info("Initializing")
self.compel_proc = None
self.prompt_embeds = None
self.negative_prompt_embeds = None
@@ -790,7 +787,7 @@ def generator(self, device=None, seed=None):
return torch.Generator(device=device).manual_seed(seed)
def prepare_options(self, data):
- logger.info(f"Preparing options")
+ Logger.info(f"Preparing options")
action = data["action"]
options = data["options"]
requested_model = options.get(f"model", None)
@@ -825,11 +822,11 @@ def send_message(self, message, code=None):
if code == MessageCode.ERROR:
traceback.print_stack()
- logger.error(message)
+ Logger.error(message)
elif code == MessageCode.WARNING:
- logger.warning(message)
+ Logger.warning(message)
elif code == MessageCode.STATUS:
- logger.info(message)
+ Logger.info(message)
formatted_message = {
"code": code,
@@ -842,11 +839,10 @@ def send_message(self, message, code=None):
def error_handler(self, error):
message = str(error)
- if "got an unexpected keyword argument 'image'" in message and self.action in ["outpaint", "pix2pix",
- "depth2img"]:
+ if "got an unexpected keyword argument 'image'" in message and self.action in ["outpaint", "pix2pix", "depth2img"]:
message = f"This model does not support {self.action}"
traceback.print_exc()
- logger.error(error)
+ Logger.error(error)
self.send_message(message, MessageCode.ERROR)
def initialize_safety_checker(self, local_files_only=None):
@@ -875,16 +871,16 @@ def load_safety_checker(self):
if not self.pipe:
return
if not self.do_nsfw_filter:
- logger.info("Disabling safety checker")
+ Logger.info("Disabling safety checker")
self.pipe.safety_checker = None
elif self.pipe.safety_checker is None:
- logger.info("Loading safety checker")
+ Logger.info("Loading safety checker")
self.pipe.safety_checker = self.safety_checker
if self.pipe.safety_checker:
self.pipe.safety_checker.to(self.device)
def do_sample(self, **kwargs):
- logger.info(f"Sampling {self.action}")
+ Logger.info(f"Sampling {self.action}")
message = f"Generating {'video' if self.is_vid_action else 'image'}"
@@ -977,7 +973,7 @@ def call_pipe(self, **kwargs):
"negative_prompt_embeds": self.negative_prompt_embeds,
})
except Exception as _e:
- logger.warning("Compel failed: " + str(_e))
+ Logger.warning("Compel failed: " + str(_e))
args.update({
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
@@ -1004,7 +1000,7 @@ def call_pipe(self, **kwargs):
args["generator"] = generator
if self.enable_controlnet:
- logger.info(f"Setting up controlnet")
+ Logger.info(f"Setting up controlnet")
args = self.load_controlnet_arguments(**args)
self.load_safety_checker()
@@ -1043,7 +1039,7 @@ def call_pipe_txt2vid(self, **kwargs):
ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
frame_ids = list(range(ch_start, ch_end))
try:
- logger.info(f"Generating video with {len(frame_ids)} frames")
+ Logger.info(f"Generating video with {len(frame_ids)} frames")
self.send_message(f"Generating video, frames {cur_frame} to {cur_frame + len(frame_ids)-1} of {self.n_samples}")
cur_frame += len(frame_ids)
kwargs = {
@@ -1122,7 +1118,7 @@ def prepare_extra_args(self, _data, image, mask):
return extra_args
def sample_diffusers_model(self, data: dict):
- logger.info("sample_diffusers_model")
+ Logger.info("sample_diffusers_model")
from pytorch_lightning import seed_everything
image = self.image
mask = self.mask
@@ -1155,7 +1151,7 @@ def process_prompts(self, data, seed):
data["options"][f"negative_prompt"] = [negative_prompt for _ in range(self.batch_size)]
return data
prompt_data = self.prompt_data
- logger.info(f"Process prompt")
+ Logger.info(f"Process prompt")
if self.deterministic_seed:
prompt = data["options"][f"prompt"]
if ".blend(" in prompt:
@@ -1199,7 +1195,7 @@ def process_prompts(self, data, seed):
def process_data(self, data: dict):
import traceback
- logger.info("Runner: process_data called")
+ Logger.info("Runner: process_data called")
self.requested_data = data
self.prepare_options(data)
#self.prepare_scheduler()
@@ -1211,7 +1207,7 @@ def process_data(self, data: dict):
def generate(self, data: dict):
if not self.pipe:
return
- logger.info("generate called")
+ Logger.info("generate called")
self.do_cancel = False
self.process_data(data)
@@ -1314,7 +1310,7 @@ def unload_tokenizer(self):
self.tokenizer = None
def process_upscale(self, data: dict):
- logger.info("Processing upscale")
+ Logger.info("Processing upscale")
image = self.input_image
results = []
if image:
@@ -1345,7 +1341,7 @@ def generator_sample(self, data: dict):
return
if not self.pipe:
- logger.info("pipe is None")
+ Logger.info("pipe is None")
return
self.send_message(f"Generating {'video' if self.is_vid_action else 'image'}")
@@ -1355,7 +1351,7 @@ def generator_sample(self, data: dict):
try:
self.initialized = self.__dict__[action] is not None
except KeyError:
- logger.info(f"{action} model has not been initialized yet")
+ Logger.info(f"{action} model has not been initialized yet")
self.initialized = False
error = None
@@ -1369,7 +1365,7 @@ def generator_sample(self, data: dict):
error = e
if "PYTORCH_CUDA_ALLOC_CONF" in str(e):
error = self.cuda_error_message
- self.clear_memory()
+ self.engine.clear_memory()
self.reset_applied_memory_settings()
else:
error_message = f"Error during generation"
@@ -1397,7 +1393,7 @@ def log_error(self, error, message=None):
self.error_handler(message)
def load_controlnet_from_ckpt(self, pipeline):
- logger.info("Loading controlnet from ckpt")
+ Logger.info("Loading controlnet from ckpt")
pipeline = self.controlnet_action_diffuser(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
@@ -1412,7 +1408,7 @@ def load_controlnet_from_ckpt(self, pipeline):
return pipeline
def load_controlnet(self):
- logger.info(f"Loading controlnet {self.controlnet_type} self.controlnet_model {self.controlnet_model}")
+ Logger.info(f"Loading controlnet {self.controlnet_type} self.controlnet_model {self.controlnet_model}")
self._controlnet = None
self.current_controlnet_type = self.controlnet_type
controlnet = self.from_pretrained(
@@ -1424,17 +1420,17 @@ def load_controlnet(self):
def preprocess_for_controlnet(self, image):
if self.current_controlnet_type != self.controlnet_type or not self.processor:
- logger.info("Loading controlnet processor " + self.controlnet_type)
+ Logger.info("Loading controlnet processor " + self.controlnet_type)
self.current_controlnet_type = self.controlnet_type
- logger.info("Controlnet: Processing image")
+ Logger.info("Controlnet: Processing image")
self.processor = Processor(self.controlnet_type)
if self.processor:
- logger.info("Controlnet: Processing image")
+ Logger.info("Controlnet: Processing image")
image = self.processor(image)
# resize image to width and height
image = image.resize((self.width, self.height))
return image
- logger.error("No controlnet processor found")
+ Logger.error("No controlnet processor found")
def load_controlnet_arguments(self, **kwargs):
if not self.is_vid_action:
@@ -1452,7 +1448,7 @@ def load_controlnet_arguments(self, **kwargs):
return kwargs
def unload_unused_models(self):
- logger.info("Unloading unused models")
+ Logger.info("Unloading unused models")
for action in [
"txt2img",
"img2img",
@@ -1472,7 +1468,7 @@ def unload_unused_models(self):
self.reset_applied_memory_settings()
def load_model(self):
- logger.info("Loading model")
+ Logger.info("Loading model")
self.torch_compile_applied = False
self.lora_loaded = False
self.embeds_loaded = False
@@ -1497,7 +1493,7 @@ def load_model(self):
if self.pipe is None or self.reload_model:
kwargs["from_safetensors"] = self.is_safetensors
- logger.info(f"Loading model from scratch {self.reload_model}")
+ Logger.info(f"Loading model from scratch {self.reload_model}")
self.reset_applied_memory_settings()
self.send_model_loading_message(self.model_path)
@@ -1527,7 +1523,7 @@ def load_model(self):
except OSError as e:
return self.handle_missing_files(self.action)
else:
- logger.info(f"Loading model {self.model_path} from PRETRAINED")
+ Logger.info(f"Loading model {self.model_path} from PRETRAINED")
scheduler = self.load_scheduler()
if scheduler:
kwargs["scheduler"] = scheduler
@@ -1546,7 +1542,7 @@ def load_model(self):
)
if self.pipe is None:
- logger.error("Failed to load pipeline")
+ Logger.error("Failed to load pipeline")
self.send_message("Failed to load model", MessageCode.ERROR)
return
@@ -1554,12 +1550,12 @@ def load_model(self):
Initialize pipe for video to video zero
"""
if self.pipe and self.is_vid2vid:
- logger.info("Initializing pipe for vid2vid")
+ Logger.info("Initializing pipe for vid2vid")
self.pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
self.pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
if self.is_outpaint:
- logger.info("Initializing vae for inpaint / outpaint")
+ Logger.info("Initializing vae for inpaint / outpaint")
self.pipe.vae = self.from_pretrained(
pipeline_action="inpaint_vae",
model=self.inpaint_vae_model
@@ -1584,7 +1580,7 @@ def load_model(self):
#self.load_learned_embed_in_clip()
def load_ckpt_model(self):
- logger.info(f"Loading ckpt file {self.model_path}")
+ Logger.info(f"Loading ckpt file {self.model_path}")
pipeline = self.download_from_original_stable_diffusion_ckpt(path=self.model_path)
return pipeline
@@ -1634,11 +1630,11 @@ def download_from_original_stable_diffusion_ckpt(self, path, local_files_only=No
local_files_only=False
)
except Exception as e:
- logger.error(f"Failed to load model from ckpt: {e}")
+ Logger.error(f"Failed to load model from ckpt: {e}")
return pipe
def clear_controlnet(self):
- logger.info("Clearing controlnet")
+ Logger.info("Clearing controlnet")
self._controlnet = None
self.engine.clear_memory()
self.reset_applied_memory_settings()
@@ -1651,14 +1647,14 @@ def load_vae(self):
)
def reuse_pipeline(self, do_load_controlnet):
- logger.info("Reusing pipeline")
+ Logger.info("Reusing pipeline")
pipe = None
if self.is_txt2img:
pipe = self.img2img if self.txt2img is None else self.txt2img
elif self.is_img2img:
pipe = self.txt2img if self.img2img is None else self.img2img
if pipe is None:
- logger.warning("Failed to reuse pipeline")
+ Logger.warning("Failed to reuse pipeline")
self.clear_controlnet()
return
kwargs = pipe.components
@@ -1726,7 +1722,7 @@ def send_model_loading_message(self, model_name):
self.send_message(message)
def prepare_model(self):
- logger.info("Prepare model")
+ Logger.info("Prepare model")
# get model and switch to it
# get models from database
@@ -1746,7 +1742,7 @@ def prepare_model(self):
def unload_controlnet(self):
if self.pipe:
- logger.info("Unloading controlnet")
+ Logger.info("Unloading controlnet")
self.pipe.controlnet = None
self.controlnet_loaded = False
@@ -1778,17 +1774,17 @@ def from_pretrained(self, **kwargs):
**kwargs
)
except OSError as e:
- logger.error(f"failed to load {model} from pretrained")
+ Logger.error(f"failed to load {model} from pretrained")
return self.handle_missing_files(pipeline_action)
def handle_missing_files(self, action):
if not self.attempt_download:
if self.is_ckpt_model or self.is_safetensors:
- logger.info("Required files not found, attempting download")
+ Logger.info("Required files not found, attempting download")
else:
import traceback
traceback.print_exc()
- logger.info("Model not found, attempting download")
+ Logger.info("Model not found, attempting download")
# check if we have an internet connection
if self.allow_online_when_missing_files:
self.send_message("Downloading model files")
diff --git a/src/airunner/aihandler/settings_manager.py b/src/airunner/aihandler/settings_manager.py
index 83fbe1140..6add3bd4d 100644
--- a/src/airunner/aihandler/settings_manager.py
+++ b/src/airunner/aihandler/settings_manager.py
@@ -211,21 +211,21 @@ def find_generator(self, generator_section, generator_name):
if self.generator_settings_override_id:
generator_settings = session.query(GeneratorSetting).filter_by(
id=self.generator_settings_override_id
- ).join(Settings).first()
+ ).first()
else:
generator_settings = session.query(GeneratorSetting).filter_by(
- section=generator_section,
- generator_name=generator_name
- ).join(Settings).first()
+ is_preset=0
+ ).first()
if generator_settings is None:
if not generator_section or generator_section == "" or not generator_name or generator_name == "":
return None
- generator_settings = GeneratorSetting(
- section=generator_section,
- generator_name=generator_name
- )
- session.add(generator_settings)
- session.commit()
+ # generator_settings = GeneratorSetting(
+ # section=generator_section,
+ # generator_name=generator_name,
+ # is_preset=False
+ # )
+ # session.add(generator_settings)
+ # session.commit()
return generator_settings
def __init__(self, app=None, *args, **kwargs):
diff --git a/src/airunner/aihandler/transformer_runner.py b/src/airunner/aihandler/transformer_runner.py
index 4ed433913..4c668c4ce 100644
--- a/src/airunner/aihandler/transformer_runner.py
+++ b/src/airunner/aihandler/transformer_runner.py
@@ -4,20 +4,12 @@
from PyQt6.QtCore import QObject
# import AutoTokenizer
-from transformers import AutoTokenizer
# import BitsAndBytesConfig
from transformers import BitsAndBytesConfig
-import transformers
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import InstructBlipForConditionalGeneration
from transformers import InstructBlipProcessor
-from langchain.llms.huggingface_pipeline import HuggingFacePipeline
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.prompts import PromptTemplate
-from langchain.chains import ConversationChain
-from airunner.aihandler.enums import MessageCode
-from airunner.aihandler.llm_api import LLMAPI
from airunner.aihandler.settings_manager import SettingsManager
from airunner.data.models import LLMGenerator
from airunner.data.db import session
@@ -67,13 +59,19 @@ class TransformerRunner(QObject):
current_generator_name = ""
do_quantize_model = True
callback = None
+ template = None
+ #system_instructions = ""
+ system_instructions = ""
@property
def generator(self):
- if not self._generator or self.current_generator_name != self.requested_generator_name:
- self.current_generator_name = self.requested_generator_name
- self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first()
- return self._generator
+ try:
+ if not self._generator or self.current_generator_name != self.requested_generator_name:
+ self.current_generator_name = self.requested_generator_name
+ self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first()
+ return self._generator
+ except Exception as e:
+ Logger.error(e)
@property
def do_load_model(self):
@@ -202,7 +200,7 @@ def load_model(self, local_files_only = None):
auto_class_ = None
if self.generator.name == "seq2seq":
- auto_class_ = AutoModelForSeq2SeqLM
+ auto_class_ = AutoModelForSeq2SeqLM
elif self.generator.name == "casuallm":
auto_class_ = AutoModelForCausalLM
elif self.generator.name == "visualqa":
@@ -219,38 +217,14 @@ def load_model(self, local_files_only = None):
else:
Logger.error(e)
- if self.generator.name == "casuallm":
- self.pipeline = transformers.pipeline(
- "text-generation",
- model=self.model,
- tokenizer=self.tokenizer,
- torch_dtype=torch.float16 if self.dtype != "32bit" else torch.float32,
- trust_remote_code=False,
- device_map="auto",
- min_length=0,
- max_length=1000,
- num_beams=1,
- do_sample=True,
- top_k=20,
- eta_cutoff=10,
- top_p=1.0,
- num_return_sequences=self.sequences,
- eos_token_id=self.tokenizer.eos_token_id,
- early_stopping=True,
- repetition_penalty=1.15,
- temperature=0.7,
- )
- Logger.info(f"Loading prompt template {self.prompt_template}")
- self.llm=HuggingFacePipeline(pipeline=self.pipeline)
- self.memory = ConversationBufferWindowMemory(k=5)
- self.prompt = PromptTemplate.from_template(
- self.prompt_template,
- template_format="jinja2",
- partial_variables={
- "username": self.username,
- "botname": self.botname,
- })
- self.chain = ConversationChain(llm=self.llm, prompt=self.prompt, memory=self.memory)
+ # if self.generator.name == "casuallm":
+ # self.pipeline = AutoModelForCausalLM.from_pretrained(
+ # self.model,
+ # torch_dtype=torch.float16 if self.dtype != "32bit" else torch.float32,
+ # trust_remote_code=False,
+ # device_map="auto"
+ # )
+
def process_data(self, data):
self.request_data = data.get("request_data", {})
@@ -266,6 +240,7 @@ def process_data(self, data):
self.prompt = self.request_data.get("prompt", "")
self.prompt_template = self.request_data.get("prompt_template", "")
self.image = self.request_data.get("image", None)
+ self.system_instructions = f"Your name is {self.botname}. You are talking to {self.username}. You will always respond in character. You will always respond to the user. You will not give ethical or moral advice to the user unless it is something appropriate for {self.botname} to say."
if self.image:
self.image = self.image.convert("RGB")
@@ -284,6 +259,7 @@ def clear_conversation(self):
pass
def prepare_input_args(self):
+ self.system_instructions = self.request_data.get("system_instructions", "")
top_k = self.parameters.get("top_k", self.top_k)
eta_cutoff = self.parameters.get("eta_cutoff", self.eta_cutoff)
top_p = self.parameters.get("top_p", self.top_p)
@@ -388,18 +364,20 @@ def handle_generate_request(self):
if self.generator.name == "visualqa":
self.load_processor()
- self.engine.send_message("Generating output")
- with torch.backends.cuda.sdp_kernel(
- enable_flash=True,
- enable_math=False,
- enable_mem_efficient=False
- ):
- with torch.no_grad():
- value = self.generate(**kwargs)
- if self.callback:
- self.callback(value)
- else:
- self.engine.send_message(value, code=MessageCode.TEXT_GENERATED)
+ # self.engine.send_message("Generating output")
+ # with torch.backends.cuda.sdp_kernel(
+ # enable_flash=True,
+ # enable_math=False,
+ # enable_mem_efficient=False
+ # ):
+ # with torch.no_grad():
+ # print("************** CALLING GENERATE")
+ # value = self.generate()
+ # print("VALUE", value)
+ # # if self.callback:
+ # # self.callback(value)
+ # # else:
+ # # self.engine.send_message(value, code=MessageCode.TEXT_GENERATED)
self.enable_request_processing()
def disable_request_processing(self):
@@ -408,7 +386,7 @@ def disable_request_processing(self):
def enable_request_processing(self):
self._processing_request = True
- def generate(self, **kwargs):
+ def generate(self):
pass
diff --git a/src/airunner/alembic/env.py b/src/airunner/alembic/env.py
index 1e8633c2a..01ef487e6 100644
--- a/src/airunner/alembic/env.py
+++ b/src/airunner/alembic/env.py
@@ -11,8 +11,8 @@
# Interpret the config file for Python logging.
# This line sets up loggers basically.
-if config.config_file_name is not None:
- fileConfig(config.config_file_name)
+# if config.config_file_name is not None:
+# fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
diff --git a/src/airunner/alembic/versions/08728ca819b8_adds_prompt_template_to_llmgenerator.py b/src/airunner/alembic/versions/08728ca819b8_adds_prompt_template_to_llmgenerator.py
new file mode 100644
index 000000000..eae30e2bf
--- /dev/null
+++ b/src/airunner/alembic/versions/08728ca819b8_adds_prompt_template_to_llmgenerator.py
@@ -0,0 +1,30 @@
+"""Adds prompt_template to LLMGenerator
+
+Revision ID: 08728ca819b8
+Revises: 7c1b15cade9b
+Create Date: 2024-01-08 13:46:24.790611
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '08728ca819b8'
+down_revision: Union[str, None] = '7c1b15cade9b'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('llm_generator', sa.Column('prompt_template', sa.String(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('llm_generator', 'prompt_template')
+ # ### end Alembic commands ###
diff --git a/src/airunner/alembic/versions/435acbb59892_added_version_to_generator_settings.py b/src/airunner/alembic/versions/435acbb59892_added_version_to_generator_settings.py
new file mode 100644
index 000000000..4dbc8a1a5
--- /dev/null
+++ b/src/airunner/alembic/versions/435acbb59892_added_version_to_generator_settings.py
@@ -0,0 +1,30 @@
+"""Added version to generator settings
+
+Revision ID: 435acbb59892
+Revises: 6f4472b84071
+Create Date: 2024-01-02 06:15:11.656030
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '435acbb59892'
+down_revision: Union[str, None] = '6f4472b84071'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('generator_settings', sa.Column('version', sa.String(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('generator_settings', 'version')
+ # ### end Alembic commands ###
diff --git a/src/airunner/alembic/versions/7c1b15cade9b_adds_is_preset_flag_to_generator_.py b/src/airunner/alembic/versions/7c1b15cade9b_adds_is_preset_flag_to_generator_.py
new file mode 100644
index 000000000..0f561eefb
--- /dev/null
+++ b/src/airunner/alembic/versions/7c1b15cade9b_adds_is_preset_flag_to_generator_.py
@@ -0,0 +1,30 @@
+"""Adds is_preset flag to generator settings table
+
+Revision ID: 7c1b15cade9b
+Revises: 435acbb59892
+Create Date: 2024-01-02 06:52:02.924308
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '7c1b15cade9b'
+down_revision: Union[str, None] = '435acbb59892'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('generator_settings', sa.Column('is_preset', sa.Boolean(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('generator_settings', 'is_preset')
+ # ### end Alembic commands ###
diff --git a/src/airunner/data/db.py b/src/airunner/data/db.py
index cba29c1bd..e210a868d 100644
--- a/src/airunner/data/db.py
+++ b/src/airunner/data/db.py
@@ -14,6 +14,7 @@
LLMGeneratorSetting, LLMGenerator, LLMModelVersion
from airunner.utils import get_session
from alembic.config import Config
+from airunner.aihandler.logger import Logger
from alembic import command
import os
import configparser
@@ -495,7 +496,7 @@ def insert_variables(variables, prev_object=None):
},
]
-HERE = os.path.abspath(os.path.dirname(__file__))
+HERE = os.path.abspath(os.path.dirname(__file__))
alembic_ini_path = os.path.join(HERE, "../alembic.ini")
config = configparser.ConfigParser()
@@ -505,9 +506,7 @@ def insert_variables(variables, prev_object=None):
db_path = f'sqlite:///{home_dir}/.airunner/airunner.db'
config.set('alembic', 'sqlalchemy.url', db_path)
-
with open(alembic_ini_path, 'w') as configfile:
config.write(configfile)
-
alembic_cfg = Config(alembic_ini_path)
-command.upgrade(alembic_cfg, "head")
\ No newline at end of file
+command.upgrade(alembic_cfg, "head")
diff --git a/src/airunner/data/models.py b/src/airunner/data/models.py
index edac49b49..1a31fc370 100644
--- a/src/airunner/data/models.py
+++ b/src/airunner/data/models.py
@@ -356,6 +356,8 @@ class GeneratorSetting(BaseModel):
active_grid_border_color = Column(String, default="#00FF00")
active_grid_fill_color = Column(String, default="#FF0000")
brushes = relationship("Brush", back_populates='generator_setting') # modified line
+ version = Column(String, default="SD 1.5")
+ is_preset = Column(Boolean, default=False)
class Brush(BaseModel):
@@ -767,6 +769,7 @@ class LLMGenerator(BaseModel):
message_type = Column(String, default="chat")
bot_personality = Column(String, default="Nice")
override_parameters = Column(Boolean, default=False)
+ prompt_template = Column(String, default="")
class LLMGeneratorSetting(BaseModel):
diff --git a/src/airunner/filters/color_balance.py b/src/airunner/filters/color_balance.py
index 706e00276..3a30bf560 100644
--- a/src/airunner/filters/color_balance.py
+++ b/src/airunner/filters/color_balance.py
@@ -5,6 +5,7 @@
class ColorBalanceFilter(BaseFilter):
def apply_filter(self, image, _do_reset):
+ image = image.convert("RGBA")
red, green, blue, alpha = image.split()
red = red.point(lambda i: i + (i * self.cyan_red))
green = green.point(lambda i: i + (i * self.magenta_green))
diff --git a/src/airunner/filters/registration_error.py b/src/airunner/filters/registration_error.py
index 38c61c7df..0a461f850 100644
--- a/src/airunner/filters/registration_error.py
+++ b/src/airunner/filters/registration_error.py
@@ -5,6 +5,7 @@
class RegistrationErrorFilter(BaseFilter):
def apply_filter(self, image, do_reset):
+ image = image.convert("RGBA")
# first, split the image into its R G B channels
r, g, b, a = image.split()
diff --git a/src/airunner/filters/unsharp_mask.py b/src/airunner/filters/unsharp_mask.py
index 0de8945c5..e3ae94983 100644
--- a/src/airunner/filters/unsharp_mask.py
+++ b/src/airunner/filters/unsharp_mask.py
@@ -4,5 +4,8 @@
class UnsharpMask(BaseFilter):
+ def __init__(self, radius, percent, threshold):
+ self.unsharp_mask = UnsharpMask(radius=radius, percent=percent, threshold=threshold)
+
def apply_filter(self, image, do_reset):
- return image.filter(UnsharpMask(radius=self.radius))
+ return image.filter(self.unsharp_mask)
diff --git a/src/airunner/filters/windows/filter_base.py b/src/airunner/filters/windows/filter_base.py
index 49db007d0..3f6aaab44 100644
--- a/src/airunner/filters/windows/filter_base.py
+++ b/src/airunner/filters/windows/filter_base.py
@@ -5,6 +5,7 @@
from PyQt6 import uic
from airunner.widgets.slider.slider_widget import SliderWidget
+from airunner.aihandler.settings_manager import SettingsManager
class FilterBase:
@@ -36,7 +37,7 @@ def __getattr__(self, item):
def __setattr__(self, key, value):
if key in self._filter_values:
self._filter_values[key].value = str(value)
- self.parent.settings_manager.save()
+ self.settings_manager.save()
else:
super().__setattr__(key, value)
@@ -57,6 +58,8 @@ def __init__(self, parent, model_name):
# filter_values are the names of the ImageFilterValue objects in the database.
# when the filter is shown, the values are loaded from the database
# and stored in this dictionary.
+ self.settings_manager = SettingsManager(app=parent)
+
self._filter_values = {}
self.filter_window = None
@@ -65,16 +68,14 @@ def __init__(self, parent, model_name):
self.load_image_filter_data()
def update_value(self, name, value):
- print(name, value)
self._filter_values[name].value = str(value)
- self.parent.settings_manager.save()
+ self.settings_manager.save()
def update_canvas(self):
- if self.parent.canvas_is_active:
- self.parent.current_canvas.update()
+ pass
def load_image_filter_data(self):
- self.image_filter_data = self.parent.settings_manager.get_image_filter(self.image_filter_model_name)
+ self.image_filter_data = self.settings_manager.get_image_filter(self.image_filter_model_name)
for filter_value in self.image_filter_data.image_filter_values:
self._filter_values[filter_value.name] = filter_value
@@ -134,7 +135,7 @@ def show(self):
def handle_auto_apply_toggle(self):
self.image_filter_data.auto_apply = self.filter_window.auto_apply.isChecked()
- self.parent.settings_manager.save()
+ self.settings_manager.save()
def handle_slider_change(self, settings_property, val):
self.update_value(settings_property, val)
@@ -143,15 +144,15 @@ def handle_slider_change(self, settings_property, val):
def cancel_filter(self):
self.reject()
- self.parent.current_canvas.cancel_filter()
+ self.parent.canvas_widget.cancel_filter()
self.update_canvas()
def apply_filter(self):
self.accept()
- self.parent.current_canvas.apply_filter(self.filter)
+ self.parent.canvas_widget.apply_filter(self.filter)
self.filter_window.close()
self.update_canvas()
def preview_filter(self):
- self.parent.current_canvas.preview_filter(self.filter)
+ self.parent.canvas_widget.preview_filter(self.filter)
self.update_canvas()
diff --git a/src/airunner/styles/dark_theme/styles.qss b/src/airunner/styles/dark_theme/styles.qss
index 36e8a954b..cea971461 100644
--- a/src/airunner/styles/dark_theme/styles.qss
+++ b/src/airunner/styles/dark_theme/styles.qss
@@ -644,4 +644,11 @@ QTableWidget::item:selected {
#delete_confirmation QPushButton {
border: 1px solid rgba(0, 132, 185, 50);
+}
+
+QWidget#message QPlainTextEdit {
+ border-radius: 5px;
+ border: 5px solid #1f1f1f;
+ background-color: #1f1f1f;
+ color: #ffffff;
}
\ No newline at end of file
diff --git a/src/airunner/utils.py b/src/airunner/utils.py
index 3e2b5262b..561bdf193 100644
--- a/src/airunner/utils.py
+++ b/src/airunner/utils.py
@@ -221,24 +221,6 @@ def get_version():
return ""
-def get_latest_version():
- return
- # get latest release from https://github.com/Capsize-Games/airunner/releases/latest
- # follow the redirect to get the version number
- import requests
- import re
- url = "https://github.com/Capsize-Games/airunner/releases/latest"
- try:
- r = requests.get(url)
- if r.status_code == 200:
- m = re.search(r"\/Capsize-Games\/airunner\/releases\/tag\/v([0-9\.]+)", r.text)
- if m:
- return m.group(1)
- except ConnectionError:
- return None
- return None
-
-
# def load_default_models(tab_section, section_name):
# if section_name == "txt2img":
# section_name = "generate"
diff --git a/src/airunner/widgets/base_widget.py b/src/airunner/widgets/base_widget.py
index 268443c17..cd314ef85 100644
--- a/src/airunner/widgets/base_widget.py
+++ b/src/airunner/widgets/base_widget.py
@@ -8,7 +8,7 @@
class BaseWidget(QWidget):
widget_class_ = None
- icons = {}
+ icons = ()
ui = None
qss_filename = None
@@ -26,25 +26,42 @@ def add_to_grid(self, widget, row, column, row_span=1, column_span=1):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.app = get_main_window()
+ self.app.loaded.connect(self.initialize)
self.settings_manager = SettingsManager()
if self.widget_class_:
self.ui = self.widget_class_()
if self.ui:
self.ui.setupUi(self)
- if self.qss_filename:
- theme_name = "dark_theme"
- here = os.path.dirname(os.path.realpath(__file__))
- with open(os.path.join(here, "..", "styles", theme_name, self.qss_filename), "r") as f:
- stylesheet = f.read()
- self.setStyleSheet(stylesheet)
-
- def set_stylesheet(self, is_dark=None, button_name=None, icon=None):
- is_dark = self.is_dark if is_dark is None else is_dark
- if button_name is None or icon is None:
- for button_name, icon in self.icons.items():
- self.set_button_icon(is_dark, button_name, icon)
- else:
- self.set_button_icon(is_dark, button_name, icon)
+ # if self.qss_filename:
+ # theme_name = "dark_theme"
+ # here = os.path.dirname(os.path.realpath(__file__))
+ # with open(os.path.join(here, "..", "styles", theme_name, self.qss_filename), "r") as f:
+ # stylesheet = f.read()
+ # self.setStyleSheet(stylesheet)
+ self.set_icons()
+
+ def initialize(self):
+ """
+ Triggered when the app is loaded.
+ Override this function in order to initialize the widget rather than
+ using __init__.
+ """
+ pass
+
+ def set_icons(self):
+ theme = "dark" if self.is_dark else "light"
+ for icon_data in self.icons:
+ icon_name = icon_data[0]
+ widget_name = icon_data[1]
+ print(icon_name, widget_name)
+ print(self.icons)
+ icon = QtGui.QIcon()
+ icon.addPixmap(
+ QtGui.QPixmap(f":/icons/{theme}/{icon_name}.svg"),
+ QtGui.QIcon.Mode.Normal,
+ QtGui.QIcon.State.Off)
+ getattr(self.ui, widget_name).setIcon(icon)
+ self.update()
def set_button_icon(self, is_dark, button_name, icon):
try:
@@ -130,8 +147,12 @@ def set_form_value(self, element, settings_key_name):
if not self.set_is_checked(element, target_val):
raise Exception(f"Could not set value for {element} to {target_val}")
- def set_form_property(self, element, property_name, settings_key_name):
+ def set_form_property(self, element, property_name, settings_key_name=None, settings=None):
val = self.get_form_element(element).property(property_name)
- target_val = self.settings_manager.get_value(settings_key_name)
+ if settings_key_name:
+ target_val = self.settings_manager.get_value(settings_key_name)
+ elif settings:
+ target_val = getattr(settings, property_name)
+
if val != target_val:
self.get_form_element(element).setProperty(property_name, target_val)
diff --git a/src/airunner/widgets/brushes/brushes_container.py b/src/airunner/widgets/brushes/brushes_container.py
index edd6a585e..bc3b24e0d 100644
--- a/src/airunner/widgets/brushes/brushes_container.py
+++ b/src/airunner/widgets/brushes/brushes_container.py
@@ -133,6 +133,7 @@ def activate_brush(self, clicked_widget, brush, multiple):
else:
self.selected_brushes.remove(widget)
widget.setStyleSheet("")
+ self.settings_manager.set_value("generator_settings_override_id", None)
return
for widget in self.selected_brushes:
@@ -153,7 +154,7 @@ def activate_brush(self, clicked_widget, brush, multiple):
widget.setStyleSheet(f"""
border: 2px solid #ff0000;
""")
- self.app.ui.generator_widget.enable_preset(widget.brush.generator_setting_id)
+ self.settings_manager.set_value("generator_settings_override_id", widget.brush.generator_setting_id)
def display_brush_menu(self, event, widget, brush):
context_menu = QMenu(self)
diff --git a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py
index d5097f515..96d82312e 100644
--- a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py
+++ b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py
@@ -2,7 +2,6 @@
import math
import subprocess
from functools import partial
-import pdb
from PIL import Image, ImageGrab
from PIL.ImageQt import ImageQt, QImage
@@ -15,7 +14,7 @@
from airunner.cursors.circle_brush import CircleCursor
from airunner.data.db import session
from airunner.data.models import Layer, CanvasSettings, ActiveGridSettings
-from airunner.utils import get_session, save_session
+from airunner.utils import save_session
from airunner.widgets.canvas_plus.canvas_base_widget import CanvasBaseWidget
from airunner.widgets.canvas_plus.templates.canvas_plus_ui import Ui_canvas
from airunner.utils import apply_opacity_to_image
@@ -212,10 +211,6 @@ class CanvasPlusWidget(CanvasBaseWidget):
initialized = False
drawing = False
- @property
- def current_active_image_data(self):
- return self.current_layer.image_data
-
def current_pixmap(self):
draggable_pixmap = self.current_draggable_pixmap()
if draggable_pixmap:
@@ -227,10 +222,6 @@ def current_image(self):
return None
return Image.fromqpixmap(pixmap)
- @current_active_image_data.setter
- def current_active_image_data(self, value):
- self.current_layer.image_data = value
-
@property
def image_pivot_point(self):
try:
@@ -402,6 +393,16 @@ def draw_layers(self):
layer.opacity / 100.0
)
+ if layer.id in self.layers:
+ if not layer.visible:
+ if self.layers[layer.id] in self.scene.items():
+ self.scene.removeItem(self.layers[layer.id])
+ elif layer.visible:
+ if not self.layers[layer.id] in self.scene.items():
+ self.scene.addItem(self.layers[layer.id])
+ self.layers[layer.id].pixmap.convertFromImage(ImageQt(image))
+ continue
+
draggable_pixmap = None
if layer.id in self.layers:
self.layers[layer.id].pixmap.convertFromImage(ImageQt(image))
@@ -453,6 +454,7 @@ def draw_active_grid_area_container(self):
"""
Draw a rectangle around the active grid area of
"""
+ print(self.active_grid_area_rect)
if not self.active_grid_area:
self.active_grid_area = ActiveGridArea(
parent=self,
@@ -583,7 +585,7 @@ def load_image_from_object(self, image, is_outpaint=False, image_root_point=None
def load_image(self, image_path):
image = Image.open(image_path)
- if self.app.settings_manager.resize_on_paste:
+ if self.settings_manager.resize_on_paste:
image.thumbnail((self.settings_manager.working_width,
self.settings_manager.working_height), Image.ANTIALIAS)
self.add_image_to_scene(image)
@@ -603,6 +605,8 @@ def cut_image(self):
if not draggable_pixmap:
return
self.scene.removeItem(draggable_pixmap)
+ if self.current_layer.id in self.layers:
+ del self.layers[self.current_layer.id]
self.update()
def delete_image(self):
@@ -614,17 +618,13 @@ def delete_image(self):
self.update()
def paste_image_from_clipboard(self):
+ Logger.info("paste image from clipboard")
image = self.get_image_from_clipboard()
if not image:
Logger.info("No image in clipboard")
return
- if self.app.settings_manager.resize_on_paste:
- Logger.info("Resizing image")
- image.thumbnail(
- (self.settings_manager.working_width, self.settings_manager.working_height),
- Image.ANTIALIAS)
self.create_image(image)
def get_image_from_clipboard(self):
@@ -653,19 +653,23 @@ def image_to_system_clipboard_linux(self, pixmap):
subprocess.Popen(["xclip", "-selection", "clipboard", "-t", "image/png"],
stdin=subprocess.PIPE).communicate(data)
except FileNotFoundError:
- pass
+ Logger.error("xclip not found. Please install xclip to copy image to clipboard.")
def create_image(self, image):
- if self.app.settings_manager.resize_on_paste:
- image.thumbnail(
- (
- self.settings_manager.working_width,
- self.settings_manager.working_height
- ),
- Image.ANTIALIAS
- )
+ if self.settings_manager.resize_on_paste:
+ image = self.resize_image(image)
self.add_image_to_scene(image)
+ def resize_image(self, image):
+ image.thumbnail(
+ (
+ self.settings_manager.working_width,
+ self.settings_manager.working_height
+ ),
+ Image.ANTIALIAS
+ )
+ return image
+
def remove_current_draggable_pixmap_from_scene(self):
current_draggable_pixmap = self.current_draggable_pixmap()
if current_draggable_pixmap:
@@ -678,11 +682,6 @@ def switch_to_layer(self, layer_index):
self.current_layer_index = layer_index
def add_image_to_scene(self, image, is_outpaint=False, image_root_point=None):
- # change size of self.current_active_image to match size of image
- if self.current_active_image is None:
- self.current_active_image = image
- else:
- self.current_active_image = self.current_active_image.resize(image.size)
self.current_active_image = image
if image_root_point is not None:
self.current_layer.pos_x = image_root_point.x()
@@ -751,3 +750,13 @@ def save_image(self, image_path, image=None):
def update_image_canvas(self):
print("TODO")
+
+ def rotate_90_clockwise(self):
+ if self.current_active_image:
+ self.current_active_image = self.current_active_image.transpose(Image.ROTATE_270)
+ self.do_draw()
+
+ def rotate_90_counterclockwise(self):
+ if self.current_active_image:
+ self.current_active_image = self.current_active_image.transpose(Image.ROTATE_90)
+ self.do_draw()
\ No newline at end of file
diff --git a/src/airunner/widgets/canvas_plus/standard_image_widget.py b/src/airunner/widgets/canvas_plus/standard_image_widget.py
index 0a2465006..556ee284a 100644
--- a/src/airunner/widgets/canvas_plus/standard_image_widget.py
+++ b/src/airunner/widgets/canvas_plus/standard_image_widget.py
@@ -18,7 +18,7 @@
from airunner.utils import delete_image, load_metadata_from_image, prepare_metadata
from airunner.settings import CONTROLNET_OPTIONS
from airunner.widgets.slider.slider_widget import SliderWidget
-from airunner.data.models import ActionScheduler, Pipeline
+from airunner.data.models import ActionScheduler, Pipeline, GeneratorSetting
class StandardImageWidget(StandardBaseWidget):
widget_class_ = Ui_standard_image_widget
@@ -29,6 +29,21 @@ class StandardImageWidget(StandardBaseWidget):
image_label = None
image_batch = None
meta_data = None
+ _image = None
+
+ @property
+ def image(self):
+ if self._image is None:
+ self.image = self.app.canvas_widget.current_layer.image
+ return self._image
+
+ @image.setter
+ def image(self, image):
+ self._image = image
+
+ @property
+ def canvas_widget(self):
+ return self.ui.canvas_widget
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -41,20 +56,11 @@ def __init__(self, *args, **kwargs):
self.initialize()
def set_controlnet_settings_properties(self):
- self.ui.controlnet_settings.initialize(
- self.settings_manager.generator_name,
- self.settings_manager.generator_section
- )
+ self.ui.controlnet_settings.initialize()
def set_input_image_widget_properties(self):
- self.ui.input_image_widget.initialize(
- self.settings_manager.generator_name,
- self.settings_manager.generator_section
- )
- self.ui.controlnet_settings.initialize(
- self.settings_manager.generator_name,
- self.settings_manager.generator_section
- )
+ self.ui.input_image_widget.initialize()
+ self.ui.controlnet_settings.initialize()
def update_image_input_thumbnail(self):
self.ui.input_image_widget.set_thumbnail()
@@ -113,7 +119,6 @@ def set_pixmap(self, image_path=None, image=None):
self.image_path = image_path
self.image = image
meta_data = image.info
- print("META DATA", meta_data)
self.meta_data = meta_data if meta_data is not None else load_metadata_from_image(image)
return
#size = self.ui.image_frame.width() - 20
@@ -204,18 +209,21 @@ def similar_image_with_prompt(self):
"""
Using the LLM, generate a description of the image
"""
- #self.app.describe_image(image=self.image, callback=self.handle_prompt_generated)
- prompt = self.app.generator_tab_widget.ui.generator_form_stablediffusion.ui.prompt.toPlainText()
- negative_prompt = self.app.generator_tab_widget.ui.generator_form_stablediffusion.ui.negative_prompt.toPlainText()
- self.handle_prompt_generated([prompt], [negative_prompt])
+ self.app.describe_image(image=self.image, callback=self.handle_prompt_generated)
+ # prompt = self.app.generator_tab_widget.ui.prompt.toPlainText()
+ # negative_prompt = self.app.generator_tab_widget.ui.negative_prompt.toPlainText()
+ # self.handle_prompt_generated([prompt], [negative_prompt])
def handle_prompt_generated(self, prompt, negative_prompt):
meta_data = load_metadata_from_image(self.image)
meta_data["prompt"] = prompt[0]
meta_data["negative_prompt"] = negative_prompt[0]
meta_data = prepare_metadata({ "options": meta_data })
- image = Image.open(self.image_path)
- image.save(self.image_path, pnginfo=meta_data)
+ if self.image_path:
+ image = Image.open(self.image_path)
+ image.save(self.image_path, pnginfo=meta_data)
+ else:
+ image = self.app.canvas_widget.current_layer.image
self.image = image
self.meta_data = load_metadata_from_image(self.image)
self.generate_similar_image()
@@ -224,7 +232,7 @@ def similar_image(self):
self.generate_similar_image()
def generate_similar_image(self, batch_size=1):
- meta_data = self.meta_data
+ meta_data = self.meta_data or {}
prompt = meta_data.get("prompt", None)
negative_prompt = meta_data.get("negative_prompt", None)
@@ -232,7 +240,8 @@ def generate_similar_image(self, batch_size=1):
negative_prompt = None if negative_prompt == "" else negative_prompt
if prompt is None:
- return self.similar_image_with_prompt()
+ #return self.similar_image_with_prompt()
+ prompt = ""
if negative_prompt is None:
meta_data["negative_prompt"] = "verybadimagenegative_v1.3, EasyNegative"
@@ -251,7 +260,7 @@ def generate_similar_image(self, batch_size=1):
meta_data["use_cropped_image"] = False
meta_data["batch_size"] = batch_size
- self.app.generator_tab_widget.current_generator_widget.call_generate(
+ self.app.generator_tab_widget.call_generate(
image=self.image,
override_data=meta_data
)
@@ -263,7 +272,7 @@ def similar_batch(self):
self.generate_similar_image(batch_size=4)
def upscale_2x_clicked(self):
- meta_data = self.meta_data
+ meta_data = self.meta_data or {}
prompt = meta_data.get("prompt", None)
negative_prompt = meta_data.get("negative_prompt", None)
@@ -271,7 +280,7 @@ def upscale_2x_clicked(self):
negative_prompt = None if negative_prompt == "" else negative_prompt
if prompt is None:
- return self.similar_image_with_prompt()
+ prompt = ""
if negative_prompt is None:
meta_data["negative_prompt"] = "verybadimagenegative_v1.3, EasyNegative"
@@ -281,14 +290,13 @@ def upscale_2x_clicked(self):
meta_data["face_enhance"] = self.settings_manager.standard_image_widget_settings.face_enhance
meta_data["denoise_strength"] = 0.5
meta_data["action"] = "upscale"
- meta_data["width"] = self.image.width
- meta_data["height"] = self.image.height
+ meta_data["width"] = self.ui.canvas_widget.current_layer.image.width
+ meta_data["height"] = self.ui.canvas_widget.current_layer.image.height
meta_data["enable_input_image"] = True
meta_data["use_cropped_image"] = False
-
- self.app.generator_tab_widget.current_generator_widget.call_generate(
- image=self.image,
+ self.app.generator_tab_widget.call_generate(
+ image=self.ui.canvas_widget.current_layer.image,
override_data=meta_data
)
@@ -322,20 +330,6 @@ def load_versions(self):
self.ui.version.setCurrentText(current_version)
self.ui.version.blockSignals(False)
- def handle_pipeline_changed(self, val):
- if val == "txt2img / img2img":
- val = "txt2img"
- elif val == "inpaint / outpaint":
- val = "outpaint"
- self.settings_manager.set_value(f"current_section_{self.settings_manager.current_image_generator}", val)
- self.load_versions()
- self.load_models()
-
- def handle_version_changed(self, val):
- print("VERSION CHANGED", val)
- self.settings_manager.set_value(f"current_version_{self.settings_manager.current_image_generator}", val)
- self.load_models()
-
def load_models(self):
session = get_session()
self.ui.model.blockSignals(True)
@@ -381,15 +375,37 @@ def load_schedulers(self):
def clear_models(self):
self.ui.model.clear()
- def handle_settings_manager_changed(self, key, val, settings_manager):
- print("handle_settings_manager_changed", key, val)
- if settings_manager.generator_section == self.settings_manager.generator_section and settings_manager.generator_name == self.settings_manager.generator_name:
+ def initialize_generator_form(self, override_id=None):
+ if override_id:
+ self.ui.steps_widget.set_slider_and_spinbox_values(self.settings_manager.generator.steps)
+ self.ui.scale_widget.set_slider_and_spinbox_values(self.settings_manager.generator.scale * 100)
+ self.ui.clip_skip_slider_widget.set_slider_and_spinbox_values(self.settings_manager.generator.clip_skip)
+
+ self.ui.pipeline.blockSignals(True)
+ self.ui.version.blockSignals(True)
+ self.ui.model.blockSignals(True)
+ self.ui.scheduler.blockSignals(True)
+
+ self.ui.pipeline.setCurrentText(self.settings_manager.generator.section)
+ self.ui.version.setCurrentText(self.settings_manager.generator.version)
+ self.ui.model.setCurrentText(self.settings_manager.generator.model)
+ self.ui.scheduler.setCurrentText(self.settings_manager.generator.scheduler)
+
+ self.ui.pipeline.blockSignals(False)
+ self.ui.version.blockSignals(False)
+ self.ui.model.blockSignals(False)
+ self.ui.scheduler.blockSignals(False)
+ else:
self.set_form_values()
self.load_pipelines()
self.load_versions()
self.load_models()
self.load_schedulers()
-
+
+ def handle_settings_manager_changed(self, key, val, settings_manager):
+ if key == "generator_settings_override_id":
+ self.initialize_generator_form(val)
+
def initialize(self):
self.set_form_values()
self.load_pipelines()
@@ -426,4 +442,28 @@ def initialize(self):
# self.generator_section,
# self.generator_name
# )
- self.initialized = True
\ No newline at end of file
+ self.initialized = True
+
+ def handle_model_changed(self, name):
+ if not self.initialized:
+ return
+ self.settings_manager.set_value("generator.model", name)
+
+ def handle_scheduler_changed(self, name):
+ if not self.initialized:
+ return
+ self.settings_manager.set_value("generator.scheduler", name)
+
+ def handle_pipeline_changed(self, val):
+ if val == "txt2img / img2img":
+ val = "txt2img"
+ elif val == "inpaint / outpaint":
+ val = "outpaint"
+ self.settings_manager.set_value(f"current_section_{self.settings_manager.current_image_generator}", val)
+ self.load_versions()
+ self.load_models()
+
+ def handle_version_changed(self, val):
+ print("VERSION CHANGED", val)
+ self.settings_manager.set_value(f"current_version_{self.settings_manager.current_image_generator}", val)
+ self.load_models()
\ No newline at end of file
diff --git a/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui b/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui
index 77553b23c..7dae3c629 100644
--- a/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui
+++ b/src/airunner/widgets/canvas_plus/templates/standard_image_widget.ui
@@ -39,14 +39,14 @@
-
+
0
0
- 450
+ 0
0
@@ -80,14 +80,15 @@
Qt::LeftToRight
- 0
+ 4
-
+
+ ..
Settings
@@ -118,40 +119,23 @@
0
0
- 444
+ 438
1211
- -
-
-
-
-
-
-
-
-
-
-
-
- generator.random_seed
-
-
-
- -
-
-
-
-
-
-
-
-
- generator.random_latents_seed
-
-
-
-
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
-
@@ -225,24 +209,24 @@
- -
-
+
-
+
-
-
+
- DDIM ETA
+ Samples
handle_value_change
- 1
+ 0
- 10
+ 500
- 10
+ 500.000000000000000
false
@@ -260,26 +244,23 @@
1
- generator.ddim_eta
+ generator.n_samples
-
-
+
- Frames
-
-
- handle_value_change
+ Clip Skip
0
- 200
+ 11
- 200.000000000000000
+ 12.000000000000000
false
@@ -291,36 +272,39 @@
1
- 1
+ 0
- 1
+ 0
+
+
+ handle_value_change
- generator.n_samples
+ generator.clip_skip
- -
-
+
-
+
-
-
+
- Samples
+ DDIM ETA
handle_value_change
- 0
+ 1
- 500
+ 10
- 500.000000000000000
+ 10
false
@@ -338,23 +322,26 @@
1
- generator.n_samples
+ generator.ddim_eta
-
-
+
- Clip Skip
+ Frames
+
+
+ handle_value_change
0
- 11
+ 200
- 12.000000000000000
+ 200.000000000000000
false
@@ -366,16 +353,13 @@
1
- 0
+ 1
- 0
-
-
- handle_value_change
+ 1
- generator.clip_skip
+ generator.n_samples
@@ -469,33 +453,36 @@
- -
-
+
-
+
+
-
+
+
+
+
+
+
+
+
+ generator.random_seed
+
+
+
-
-
-
- PointingHandCursor
+
+
+
-
- Variation
+
+
+
+
+ generator.random_latents_seed
- -
-
-
- Qt::Vertical
-
-
-
- 20
- 40
-
-
-
-
@@ -773,8 +760,8 @@
- 532
- 532
+ 0
+ 0
@@ -784,7 +771,7 @@
- PointingHandCursor
+ ArrowCursor
true
@@ -856,8 +843,8 @@
upscale_model_changed(QString)
- 144
- 67
+ 111
+ 59
79
@@ -872,8 +859,8 @@
face_enhance_toggled(bool)
- 306
- 66
+ 327
+ 58
286
@@ -888,8 +875,8 @@
upscale_number_changed(int)
- 86
- 92
+ 88
+ 72
0
@@ -904,8 +891,8 @@
upscale_2x_clicked()
- 249
- 100
+ 327
+ 89
228
@@ -920,8 +907,8 @@
handle_advanced_settings_checkbox(bool)
- 58
- 58
+ 60
+ 37
24
@@ -936,8 +923,8 @@
similar_image()
- 206
- 142
+ 255
+ 135
-3
@@ -952,8 +939,8 @@
similar_batch()
- 350
- 143
+ 399
+ 135
259
@@ -961,6 +948,70 @@
+
+ pipeline
+ currentTextChanged(QString)
+ standard_image_widget
+ handle_pipeline_changed(QString)
+
+
+ 352
+ 74
+
+
+ 464
+ -11
+
+
+
+
+ version
+ currentTextChanged(QString)
+ standard_image_widget
+ handle_version_changed(QString)
+
+
+ 166
+ 116
+
+
+ 458
+ -17
+
+
+
+
+ model
+ currentTextChanged(QString)
+ standard_image_widget
+ handle_model_changed(QString)
+
+
+ 121
+ 170
+
+
+ 418
+ -7
+
+
+
+
+ scheduler
+ currentTextChanged(QString)
+ standard_image_widget
+ handle_scheduler_changed(QString)
+
+
+ 299
+ 221
+
+
+ 342
+ -12
+
+
+
image_to_canvas()
@@ -981,5 +1032,9 @@
face_enhance_toggled(bool)
handle_advanced_settings_checkbox(bool)
upscale_number_changed(int)
+ handle_pipeline_changed(QString)
+ handle_version_changed(QString)
+ handle_model_changed(QString)
+ handle_scheduler_changed(QString)
diff --git a/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py b/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py
index 58a8df40b..4e7e3e8d0 100644
--- a/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py
+++ b/src/airunner/widgets/canvas_plus/templates/standard_image_widget_ui.py
@@ -21,12 +21,12 @@ def setupUi(self, standard_image_widget):
self.splitter.setOrientation(QtCore.Qt.Orientation.Horizontal)
self.splitter.setObjectName("splitter")
self.sidebar = QtWidgets.QWidget(parent=self.splitter)
- sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.Preferred)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.Preferred)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(self.sidebar.sizePolicy().hasHeightForWidth())
self.sidebar.setSizePolicy(sizePolicy)
- self.sidebar.setMinimumSize(QtCore.QSize(450, 0))
+ self.sidebar.setMinimumSize(QtCore.QSize(0, 0))
self.sidebar.setMaximumSize(QtCore.QSize(450, 16777215))
self.sidebar.setObjectName("sidebar")
self.gridLayout_7 = QtWidgets.QGridLayout(self.sidebar)
@@ -48,23 +48,12 @@ def setupUi(self, standard_image_widget):
self.scrollArea_2.setWidgetResizable(True)
self.scrollArea_2.setObjectName("scrollArea_2")
self.scrollAreaWidgetContents_2 = QtWidgets.QWidget()
- self.scrollAreaWidgetContents_2.setGeometry(QtCore.QRect(0, 0, 444, 1211))
+ self.scrollAreaWidgetContents_2.setGeometry(QtCore.QRect(0, 0, 438, 1211))
self.scrollAreaWidgetContents_2.setObjectName("scrollAreaWidgetContents_2")
self.gridLayout_10 = QtWidgets.QGridLayout(self.scrollAreaWidgetContents_2)
self.gridLayout_10.setObjectName("gridLayout_10")
- self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_3.setObjectName("horizontalLayout_3")
- self.seed_widget = SeedWidget(parent=self.scrollAreaWidgetContents_2)
- self.seed_widget.setProperty("generator_section", "")
- self.seed_widget.setProperty("generator_name", "")
- self.seed_widget.setObjectName("seed_widget")
- self.horizontalLayout_3.addWidget(self.seed_widget)
- self.seed_widget_latents = SeedWidget(parent=self.scrollAreaWidgetContents_2)
- self.seed_widget_latents.setProperty("generator_section", "")
- self.seed_widget_latents.setProperty("generator_name", "")
- self.seed_widget_latents.setObjectName("seed_widget_latents")
- self.horizontalLayout_3.addWidget(self.seed_widget_latents)
- self.gridLayout_10.addLayout(self.horizontalLayout_3, 1, 0, 1, 1)
+ spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding)
+ self.gridLayout_10.addItem(spacerItem, 5, 0, 1, 1)
self.horizontalLayout_5 = QtWidgets.QHBoxLayout()
self.horizontalLayout_5.setObjectName("horizontalLayout_5")
self.steps_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2)
@@ -92,6 +81,34 @@ def setupUi(self, standard_image_widget):
self.scale_widget.setObjectName("scale_widget")
self.horizontalLayout_5.addWidget(self.scale_widget)
self.gridLayout_10.addLayout(self.horizontalLayout_5, 2, 0, 1, 1)
+ self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_6.setObjectName("horizontalLayout_6")
+ self.samples_widget_2 = SliderWidget(parent=self.scrollAreaWidgetContents_2)
+ self.samples_widget_2.setProperty("slider_callback", "handle_value_change")
+ self.samples_widget_2.setProperty("current_value", 0)
+ self.samples_widget_2.setProperty("slider_maximum", 500)
+ self.samples_widget_2.setProperty("spinbox_maximum", 500.0)
+ self.samples_widget_2.setProperty("display_as_float", False)
+ self.samples_widget_2.setProperty("spinbox_single_step", 1)
+ self.samples_widget_2.setProperty("spinbox_page_step", 1)
+ self.samples_widget_2.setProperty("spinbox_minimum", 1)
+ self.samples_widget_2.setProperty("slider_minimum", 1)
+ self.samples_widget_2.setObjectName("samples_widget_2")
+ self.horizontalLayout_6.addWidget(self.samples_widget_2)
+ self.clip_skip_slider_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2)
+ self.clip_skip_slider_widget.setProperty("current_value", 0)
+ self.clip_skip_slider_widget.setProperty("slider_maximum", 11)
+ self.clip_skip_slider_widget.setProperty("spinbox_maximum", 12.0)
+ self.clip_skip_slider_widget.setProperty("display_as_float", False)
+ self.clip_skip_slider_widget.setProperty("spinbox_single_step", 1)
+ self.clip_skip_slider_widget.setProperty("spinbox_page_step", 1)
+ self.clip_skip_slider_widget.setProperty("spinbox_minimum", 0)
+ self.clip_skip_slider_widget.setProperty("slider_minimum", 0)
+ self.clip_skip_slider_widget.setProperty("slider_callback", "handle_value_change")
+ self.clip_skip_slider_widget.setProperty("settings_property", "generator.clip_skip")
+ self.clip_skip_slider_widget.setObjectName("clip_skip_slider_widget")
+ self.horizontalLayout_6.addWidget(self.clip_skip_slider_widget)
+ self.gridLayout_10.addLayout(self.horizontalLayout_6, 3, 0, 1, 1)
self.ddim_frames = QtWidgets.QHBoxLayout()
self.ddim_frames.setObjectName("ddim_frames")
self.ddim_eta_slider_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2)
@@ -121,34 +138,6 @@ def setupUi(self, standard_image_widget):
self.frames_slider_widget.setObjectName("frames_slider_widget")
self.ddim_frames.addWidget(self.frames_slider_widget)
self.gridLayout_10.addLayout(self.ddim_frames, 4, 0, 1, 1)
- self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_6.setObjectName("horizontalLayout_6")
- self.samples_widget_2 = SliderWidget(parent=self.scrollAreaWidgetContents_2)
- self.samples_widget_2.setProperty("slider_callback", "handle_value_change")
- self.samples_widget_2.setProperty("current_value", 0)
- self.samples_widget_2.setProperty("slider_maximum", 500)
- self.samples_widget_2.setProperty("spinbox_maximum", 500.0)
- self.samples_widget_2.setProperty("display_as_float", False)
- self.samples_widget_2.setProperty("spinbox_single_step", 1)
- self.samples_widget_2.setProperty("spinbox_page_step", 1)
- self.samples_widget_2.setProperty("spinbox_minimum", 1)
- self.samples_widget_2.setProperty("slider_minimum", 1)
- self.samples_widget_2.setObjectName("samples_widget_2")
- self.horizontalLayout_6.addWidget(self.samples_widget_2)
- self.clip_skip_slider_widget = SliderWidget(parent=self.scrollAreaWidgetContents_2)
- self.clip_skip_slider_widget.setProperty("current_value", 0)
- self.clip_skip_slider_widget.setProperty("slider_maximum", 11)
- self.clip_skip_slider_widget.setProperty("spinbox_maximum", 12.0)
- self.clip_skip_slider_widget.setProperty("display_as_float", False)
- self.clip_skip_slider_widget.setProperty("spinbox_single_step", 1)
- self.clip_skip_slider_widget.setProperty("spinbox_page_step", 1)
- self.clip_skip_slider_widget.setProperty("spinbox_minimum", 0)
- self.clip_skip_slider_widget.setProperty("slider_minimum", 0)
- self.clip_skip_slider_widget.setProperty("slider_callback", "handle_value_change")
- self.clip_skip_slider_widget.setProperty("settings_property", "generator.clip_skip")
- self.clip_skip_slider_widget.setObjectName("clip_skip_slider_widget")
- self.horizontalLayout_6.addWidget(self.clip_skip_slider_widget)
- self.gridLayout_10.addLayout(self.horizontalLayout_6, 3, 0, 1, 1)
self.verticalLayout_3 = QtWidgets.QVBoxLayout()
self.verticalLayout_3.setObjectName("verticalLayout_3")
self.verticalLayout_4 = QtWidgets.QVBoxLayout()
@@ -205,15 +194,19 @@ def setupUi(self, standard_image_widget):
self.verticalLayout_2.addWidget(self.scheduler)
self.verticalLayout_3.addLayout(self.verticalLayout_2)
self.gridLayout_10.addLayout(self.verticalLayout_3, 0, 0, 1, 1)
- self.verticalLayout_5 = QtWidgets.QVBoxLayout()
- self.verticalLayout_5.setObjectName("verticalLayout_5")
- self.variation_checkbox = QtWidgets.QCheckBox(parent=self.scrollAreaWidgetContents_2)
- self.variation_checkbox.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
- self.variation_checkbox.setObjectName("variation_checkbox")
- self.verticalLayout_5.addWidget(self.variation_checkbox)
- self.gridLayout_10.addLayout(self.verticalLayout_5, 5, 0, 1, 1)
- spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding)
- self.gridLayout_10.addItem(spacerItem, 6, 0, 1, 1)
+ self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_3.setObjectName("horizontalLayout_3")
+ self.seed_widget = SeedWidget(parent=self.scrollAreaWidgetContents_2)
+ self.seed_widget.setProperty("generator_section", "")
+ self.seed_widget.setProperty("generator_name", "")
+ self.seed_widget.setObjectName("seed_widget")
+ self.horizontalLayout_3.addWidget(self.seed_widget)
+ self.seed_widget_latents = SeedWidget(parent=self.scrollAreaWidgetContents_2)
+ self.seed_widget_latents.setProperty("generator_section", "")
+ self.seed_widget_latents.setProperty("generator_name", "")
+ self.seed_widget_latents.setObjectName("seed_widget_latents")
+ self.horizontalLayout_3.addWidget(self.seed_widget_latents)
+ self.gridLayout_10.addLayout(self.horizontalLayout_3, 1, 0, 1, 1)
self.scrollArea_2.setWidget(self.scrollAreaWidgetContents_2)
self.gridLayout_11.addWidget(self.scrollArea_2, 0, 0, 1, 1)
icon = QtGui.QIcon.fromTheme("document-properties")
@@ -336,16 +329,16 @@ def setupUi(self, standard_image_widget):
sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(self.canvas_widget.sizePolicy().hasHeightForWidth())
self.canvas_widget.setSizePolicy(sizePolicy)
- self.canvas_widget.setMinimumSize(QtCore.QSize(532, 532))
+ self.canvas_widget.setMinimumSize(QtCore.QSize(0, 0))
self.canvas_widget.setMaximumSize(QtCore.QSize(16777215, 16777215))
- self.canvas_widget.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
+ self.canvas_widget.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.ArrowCursor))
self.canvas_widget.setAcceptDrops(True)
self.canvas_widget.setObjectName("canvas_widget")
self.gridLayout.addWidget(self.canvas_widget, 0, 0, 1, 1)
self.gridLayout_2.addWidget(self.splitter, 0, 0, 1, 1)
self.retranslateUi(standard_image_widget)
- self.tabWidget.setCurrentIndex(0)
+ self.tabWidget.setCurrentIndex(4)
self.upscale_model.currentTextChanged['QString'].connect(standard_image_widget.upscale_model_changed) # type: ignore
self.face_enhance.clicked['bool'].connect(standard_image_widget.face_enhance_toggled) # type: ignore
self.comboBox.currentIndexChanged['int'].connect(standard_image_widget.upscale_number_changed) # type: ignore
@@ -353,25 +346,28 @@ def setupUi(self, standard_image_widget):
self.advanced_settings_checkbox.clicked['bool'].connect(standard_image_widget.handle_advanced_settings_checkbox) # type: ignore
self.generate_single_simillar_button.clicked.connect(standard_image_widget.similar_image) # type: ignore
self.generate_batch_similar_button.clicked.connect(standard_image_widget.similar_batch) # type: ignore
+ self.pipeline.currentTextChanged['QString'].connect(standard_image_widget.handle_pipeline_changed) # type: ignore
+ self.version.currentTextChanged['QString'].connect(standard_image_widget.handle_version_changed) # type: ignore
+ self.model.currentTextChanged['QString'].connect(standard_image_widget.handle_model_changed) # type: ignore
+ self.scheduler.currentTextChanged['QString'].connect(standard_image_widget.handle_scheduler_changed) # type: ignore
QtCore.QMetaObject.connectSlotsByName(standard_image_widget)
def retranslateUi(self, standard_image_widget):
_translate = QtCore.QCoreApplication.translate
standard_image_widget.setWindowTitle(_translate("standard_image_widget", "Form"))
- self.seed_widget.setProperty("property_name", _translate("standard_image_widget", "generator.random_seed"))
- self.seed_widget_latents.setProperty("property_name", _translate("standard_image_widget", "generator.random_latents_seed"))
self.steps_widget.setProperty("label_text", _translate("standard_image_widget", "Steps"))
self.scale_widget.setProperty("label_text", _translate("standard_image_widget", "Scale"))
- self.ddim_eta_slider_widget.setProperty("label_text", _translate("standard_image_widget", "DDIM ETA"))
- self.frames_slider_widget.setProperty("label_text", _translate("standard_image_widget", "Frames"))
self.samples_widget_2.setProperty("label_text", _translate("standard_image_widget", "Samples"))
self.samples_widget_2.setProperty("settings_property", _translate("standard_image_widget", "generator.n_samples"))
self.clip_skip_slider_widget.setProperty("label_text", _translate("standard_image_widget", "Clip Skip"))
+ self.ddim_eta_slider_widget.setProperty("label_text", _translate("standard_image_widget", "DDIM ETA"))
+ self.frames_slider_widget.setProperty("label_text", _translate("standard_image_widget", "Frames"))
self.label_5.setText(_translate("standard_image_widget", "Pipeline"))
self.label_6.setText(_translate("standard_image_widget", "Version"))
self.label_3.setText(_translate("standard_image_widget", "Model"))
self.label_4.setText(_translate("standard_image_widget", "Scheduler"))
- self.variation_checkbox.setText(_translate("standard_image_widget", "Variation"))
+ self.seed_widget.setProperty("property_name", _translate("standard_image_widget", "generator.random_seed"))
+ self.seed_widget_latents.setProperty("property_name", _translate("standard_image_widget", "generator.random_latents_seed"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_6), _translate("standard_image_widget", "Settings"))
self.tabWidget.setTabToolTip(self.tabWidget.indexOf(self.tab_6), _translate("standard_image_widget", "Stable Diffusion settings"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_5), _translate("standard_image_widget", "Presets"))
diff --git a/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py b/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py
index 86921d587..0222dad8f 100644
--- a/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py
+++ b/src/airunner/widgets/controlnet_settings/controlnet_settings_widget.py
@@ -245,15 +245,8 @@ def set_mask_thumbnail(self):
# clear the image
self.ui.mask_thumbnail.clear()
- def initialize(
- self,
- generator_name,
- generator_section
- ):
- super().initialize(
- generator_name,
- generator_section
- )
+ def initialize(self):
+ super().initialize()
self.initialize_combobox()
def initialize_combobox(self):
diff --git a/src/airunner/widgets/deterministic/__init__.py b/src/airunner/widgets/deterministic/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/airunner/widgets/deterministic/deterministic_widget.py b/src/airunner/widgets/deterministic/deterministic_widget.py
deleted file mode 100644
index 25e7191c0..000000000
--- a/src/airunner/widgets/deterministic/deterministic_widget.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from airunner.widgets.base_widget import BaseWidget
-from airunner.widgets.deterministic.templates.deterministic_widget_ui import Ui_deterministic_widget
-
-
-class DeterministicWidget(BaseWidget):
- widget_class_ = Ui_deterministic_widget
-
- @property
- def batch_size(self):
- return self.ui.images_per_batch.value()
-
- @property
- def category(self):
- return self.ui.category.text()
-
- def action_value_changed_images_per_batch(self, val):
- self.settings_manager.set_value("determinisitic_settings.images_per_batch", val)
-
- def action_text_changed_category(self, val):
- self.settings_manager.set_value("determinisitic_settings.category", val)
-
- def action_clicked_button_generate_batch(self):
- print("generate batch clicked")
diff --git a/src/airunner/widgets/deterministic/templates/__init__.py b/src/airunner/widgets/deterministic/templates/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/airunner/widgets/deterministic/templates/deterministic_widget.ui b/src/airunner/widgets/deterministic/templates/deterministic_widget.ui
deleted file mode 100644
index 9f64b91fd..000000000
--- a/src/airunner/widgets/deterministic/templates/deterministic_widget.ui
+++ /dev/null
@@ -1,220 +0,0 @@
-
-
- deterministic_widget
-
-
-
- 0
- 0
- 322
- 256
-
-
-
-
- 16777215
- 16777215
-
-
-
-
- 8
-
-
-
- Form
-
-
- -
-
-
- true
-
-
-
-
- 0
- 0
- 302
- 236
-
-
-
-
-
-
-
-
- 8
- true
-
-
-
- Cateogry
-
-
-
-
-
-
-
- 8
- true
-
-
-
-
-
-
-
- -
-
-
- Qt::Horizontal
-
-
-
- -
-
-
- -
-
-
-
- 10
- true
-
-
-
- Deterministic generation
-
-
-
- -
-
-
-
- 8
- true
-
-
-
- Images per-batch
-
-
-
-
-
-
-
- 8
- false
-
-
-
- 1
-
-
- 16
-
-
- 4
-
-
-
-
-
-
- -
-
-
-
- 8
-
-
-
- Generate Deterministic Batch
-
-
-
- -
-
-
- Qt::Vertical
-
-
-
- 20
- 7
-
-
-
-
-
-
-
-
-
-
-
-
- SeedWidget
- QWidget
- airunner/widgets/seed/seed_widget
- 1
-
-
-
-
-
- category
- currentTextChanged(QString)
- deterministic_widget
- action_text_changed_category(QString)
-
-
- 59
- 92
-
-
- 15
- -12
-
-
-
-
- images_per_batch
- valueChanged(int)
- deterministic_widget
- action_value_changed_images_per_batch(int)
-
-
- 155
- 166
-
-
- 93
- -13
-
-
-
-
- generate_batches_button
- clicked()
- deterministic_widget
- action_clicked_button_generate_batch()
-
-
- 96
- 239
-
-
- 115
- -5
-
-
-
-
-
- action_text_changed_category(QString)
- action_value_changed_images_per_batch(int)
- action_clicked_button_generate_batch()
-
-
diff --git a/src/airunner/widgets/deterministic/templates/deterministic_widget_ui.py b/src/airunner/widgets/deterministic/templates/deterministic_widget_ui.py
deleted file mode 100644
index 50c2e3b01..000000000
--- a/src/airunner/widgets/deterministic/templates/deterministic_widget_ui.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/deterministic/templates/deterministic_widget.ui'
-#
-# Created by: PyQt6 UI code generator 6.4.2
-#
-# WARNING: Any manual changes made to this file will be lost when pyuic6 is
-# run again. Do not edit this file unless you know what you are doing.
-
-
-from PyQt6 import QtCore, QtGui, QtWidgets
-
-
-class Ui_deterministic_widget(object):
- def setupUi(self, deterministic_widget):
- deterministic_widget.setObjectName("deterministic_widget")
- deterministic_widget.resize(322, 256)
- deterministic_widget.setMaximumSize(QtCore.QSize(16777215, 16777215))
- font = QtGui.QFont()
- font.setPointSize(8)
- deterministic_widget.setFont(font)
- self.gridLayout_4 = QtWidgets.QGridLayout(deterministic_widget)
- self.gridLayout_4.setObjectName("gridLayout_4")
- self.scrollArea = QtWidgets.QScrollArea(parent=deterministic_widget)
- self.scrollArea.setWidgetResizable(True)
- self.scrollArea.setObjectName("scrollArea")
- self.scrollAreaWidgetContents = QtWidgets.QWidget()
- self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 302, 236))
- self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents")
- self.gridLayout_2 = QtWidgets.QGridLayout(self.scrollAreaWidgetContents)
- self.gridLayout_2.setObjectName("gridLayout_2")
- self.groupBox_2 = QtWidgets.QGroupBox(parent=self.scrollAreaWidgetContents)
- font = QtGui.QFont()
- font.setPointSize(8)
- font.setBold(True)
- self.groupBox_2.setFont(font)
- self.groupBox_2.setObjectName("groupBox_2")
- self.gridLayout_3 = QtWidgets.QGridLayout(self.groupBox_2)
- self.gridLayout_3.setObjectName("gridLayout_3")
- self.category = QtWidgets.QComboBox(parent=self.groupBox_2)
- font = QtGui.QFont()
- font.setPointSize(8)
- font.setBold(True)
- self.category.setFont(font)
- self.category.setObjectName("category")
- self.gridLayout_3.addWidget(self.category, 0, 0, 1, 1)
- self.gridLayout_2.addWidget(self.groupBox_2, 2, 0, 1, 1)
- self.line = QtWidgets.QFrame(parent=self.scrollAreaWidgetContents)
- self.line.setFrameShape(QtWidgets.QFrame.Shape.HLine)
- self.line.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
- self.line.setObjectName("line")
- self.gridLayout_2.addWidget(self.line, 1, 0, 1, 1)
- self.deterministic_seed = SeedWidget(parent=self.scrollAreaWidgetContents)
- self.deterministic_seed.setObjectName("deterministic_seed")
- self.gridLayout_2.addWidget(self.deterministic_seed, 4, 0, 1, 1)
- self.label = QtWidgets.QLabel(parent=self.scrollAreaWidgetContents)
- font = QtGui.QFont()
- font.setPointSize(10)
- font.setBold(True)
- self.label.setFont(font)
- self.label.setObjectName("label")
- self.gridLayout_2.addWidget(self.label, 0, 0, 1, 1)
- self.groupBox = QtWidgets.QGroupBox(parent=self.scrollAreaWidgetContents)
- font = QtGui.QFont()
- font.setPointSize(8)
- font.setBold(True)
- self.groupBox.setFont(font)
- self.groupBox.setObjectName("groupBox")
- self.gridLayout = QtWidgets.QGridLayout(self.groupBox)
- self.gridLayout.setObjectName("gridLayout")
- self.images_per_batch = QtWidgets.QSpinBox(parent=self.groupBox)
- font = QtGui.QFont()
- font.setPointSize(8)
- font.setBold(False)
- self.images_per_batch.setFont(font)
- self.images_per_batch.setMinimum(1)
- self.images_per_batch.setMaximum(16)
- self.images_per_batch.setProperty("value", 4)
- self.images_per_batch.setObjectName("images_per_batch")
- self.gridLayout.addWidget(self.images_per_batch, 0, 0, 1, 1)
- self.gridLayout_2.addWidget(self.groupBox, 3, 0, 1, 1)
- self.generate_batches_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents)
- font = QtGui.QFont()
- font.setPointSize(8)
- self.generate_batches_button.setFont(font)
- self.generate_batches_button.setObjectName("generate_batches_button")
- self.gridLayout_2.addWidget(self.generate_batches_button, 5, 0, 1, 1)
- spacerItem = QtWidgets.QSpacerItem(20, 7, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding)
- self.gridLayout_2.addItem(spacerItem, 6, 0, 1, 1)
- self.scrollArea.setWidget(self.scrollAreaWidgetContents)
- self.gridLayout_4.addWidget(self.scrollArea, 0, 0, 1, 1)
-
- self.retranslateUi(deterministic_widget)
- self.category.currentTextChanged['QString'].connect(deterministic_widget.action_text_changed_category) # type: ignore
- self.images_per_batch.valueChanged['int'].connect(deterministic_widget.action_value_changed_images_per_batch) # type: ignore
- self.generate_batches_button.clicked.connect(deterministic_widget.action_clicked_button_generate_batch) # type: ignore
- QtCore.QMetaObject.connectSlotsByName(deterministic_widget)
-
- def retranslateUi(self, deterministic_widget):
- _translate = QtCore.QCoreApplication.translate
- deterministic_widget.setWindowTitle(_translate("deterministic_widget", "Form"))
- self.groupBox_2.setTitle(_translate("deterministic_widget", "Cateogry"))
- self.label.setText(_translate("deterministic_widget", "Deterministic generation"))
- self.groupBox.setTitle(_translate("deterministic_widget", "Images per-batch"))
- self.generate_batches_button.setText(_translate("deterministic_widget", "Generate Deterministic Batch"))
-from airunner.widgets.seed.seed_widget import SeedWidget
diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py
index 7c9a5924c..5cf58d522 100644
--- a/src/airunner/widgets/generator_form/generator_form_widget.py
+++ b/src/airunner/widgets/generator_form/generator_form_widget.py
@@ -1,13 +1,11 @@
import random
from PIL import Image
-from PyQt6.QtCore import pyqtSignal, QRect, QTimer
-from PyQt6.QtWidgets import QWidget
+from PyQt6.QtCore import pyqtSignal, QRect
from airunner.aihandler.settings import MAX_SEED
from airunner.data.db import session
-from airunner.data.models import AIModel, ActiveGridSettings, CanvasSettings, Pipeline, GeneratorSetting
-from airunner.utils import get_session
+from airunner.data.models import ActiveGridSettings, CanvasSettings
from airunner.widgets.base_widget import BaseWidget
from airunner.widgets.generator_form.templates.generatorform_ui import Ui_generator_form
@@ -25,6 +23,9 @@ class GeneratorForm(BaseWidget):
initialized = False
parent = None
generate_signal = pyqtSignal(dict)
+ icons = (
+ ("artificial-intelligence-ai-chip-icon", "ai_button"),
+ )
@property
def is_txt2img(self):
@@ -56,19 +57,11 @@ def is_txt2vid(self):
@property
def generator_section(self):
- try:
- return self.property("generator_section")
- except Exception as e:
- print(e)
- return None
+ return getattr(self.settings_manager, f"current_section_{self.settings_manager.current_image_generator}")
@property
def generator_name(self):
- try:
- return self.property("generator_name")
- except Exception as e:
- print(e)
- return None
+ return self.settings_manager.current_image_generator
@property
def generator_settings(self):
@@ -92,7 +85,7 @@ def latents_seed(self):
@latents_seed.setter
def latents_seed(self, val):
self.settings_manager.set_value("generator.latents_seed", val)
- self.app.ui.standard_image_widget.ui.seed_widget_latents.ui.lineEdit.setText(str(val))
+ self.app.standard_image_panel.ui.seed_widget_latents.ui.lineEdit.setText(str(val))
@property
def seed(self):
@@ -101,7 +94,7 @@ def seed(self):
@seed.setter
def seed(self, val):
self.settings_manager.set_value("generator.seed", val)
- self.app.ui.standard_image_widget.ui.seed_widget.ui.lineEdit.setText(str(val))
+ self.app.standard_image_panel.ui.seed_widget.ui.lineEdit.setText(str(val))
@property
def image_scale(self):
@@ -125,7 +118,7 @@ def enable_controlnet(self):
@property
def controlnet_image(self):
- return self.app.ui.standard_image_widget.ui.controlnet_settings.current_controlnet_image
+ return self.app.standard_image_panel.ui.controlnet_settings.current_controlnet_image
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -133,11 +126,6 @@ def __init__(self, *args, **kwargs):
self.canvas_settings = session.query(CanvasSettings).first()
self.settings_manager.changed_signal.connect(self.handle_changed_signal)
- self.initialize()
-
- def enable_preset(self, id):
- self.settings_manager.set_value("generator_settings_override_id", id)
- self.initialize()
def toggle_advanced_generation(self):
advanced_mode = self.settings_manager.enable_advanced_mode
@@ -184,21 +172,6 @@ def handle_negative_prompt_changed(self):
def toggle_prompt_builder_checkbox(self, toggled):
pass
- def handle_model_changed(self, name):
- if not self.initialized:
- return
- self.settings_manager.set_value("generator.model", name)
- self.changed_signal.emit("generator.model", name)
-
- def handle_scheduler_changed(self, name):
- if not self.initialized:
- return
- self.settings_manager.set_value("generator.scheduler", name)
- self.changed_signal.emit("generator.scheduler", name)
-
- def toggle_variation(self, toggled):
- pass
-
def handle_generate_button_clicked(self):
self.start_progress_bar()
self.generate(image=self.app.current_active_image())
@@ -211,10 +184,10 @@ def handle_interrupt_button_clicked(self):
def generate(self, image=None, seed=None):
if seed is None:
- seed = self.app.ui.standard_image_widget.ui.seed_widget.seed
- if self.app.ui.standard_image_widget.ui.samples_widget.current_value > 1:
+ seed = self.app.standard_image_panel.ui.seed_widget.seed
+ if self.app.standard_image_panel.ui.samples_widget.current_value > 1:
self.app.client.do_process_queue = False
- total_samples = self.app.ui.standard_image_widget.ui.samples_widget.current_value if not self.is_txt2vid else 1
+ total_samples = self.settings_manager.generator.n_samples
for n in range(total_samples):
if self.settings_manager.generator.use_prompt_builder and n > 0:
seed = int(seed) + n
@@ -242,7 +215,7 @@ def call_generate(self, image=None, seed=None, override_data=None):
self.settings_manager.generator.enable_input_image
)
if enable_input_image:
- input_image = self.app.ui.standard_image_widget.ui.input_image_widget.current_input_image
+ input_image = self.app.standard_image_panel.ui.input_image_widget.current_input_image
elif self.generator_section == "txt2img":
input_image = override_data.get("input_image", None)
image = input_image
@@ -258,7 +231,7 @@ def call_generate(self, image=None, seed=None, override_data=None):
if image is None:
if self.is_txt2img:
return self.do_generate(seed=seed, override_data=override_data)
- # Create a transparent image the size of self.canvas.active_grid_area_rect
+ # Create a transparent image the size of self.app.canvas_widget.active_grid_area_rect
width = self.settings_manager.working_width
height = self.settings_manager.working_height
image = Image.new("RGBA", (int(width), int(height)), (0, 0, 0, 0))
@@ -284,14 +257,14 @@ def call_generate(self, image=None, seed=None, override_data=None):
# Get the cropped image
cropped_outpaint_box_rect = self.active_rect
# crop_location = (
- # cropped_outpaint_box_rect.x() - self.canvas.image_pivot_point.x(),
- # cropped_outpaint_box_rect.y() - self.canvas.image_pivot_point.y(),
- # cropped_outpaint_box_rect.width() - self.canvas.image_pivot_point.x(),
- # cropped_outpaint_box_rect.height() - self.canvas.image_pivot_point.y()
+ # cropped_outpaint_box_rect.x() - self.app.canvas_widget.image_pivot_point.x(),
+ # cropped_outpaint_box_rect.y() - self.app.canvas_widget.image_pivot_point.y(),
+ # cropped_outpaint_box_rect.width() - self.app.canvas_widget.image_pivot_point.x(),
+ # cropped_outpaint_box_rect.height() - self.app.canvas_widget.image_pivot_point.y()
# )
crop_location = (
- cropped_outpaint_box_rect.x() - self.canvas.current_layer.pos_x,
- cropped_outpaint_box_rect.y() - self.canvas.current_layer.pos_y,
+ cropped_outpaint_box_rect.x() - self.app.canvas_widget.current_layer.pos_x,
+ cropped_outpaint_box_rect.y() - self.app.canvas_widget.current_layer.pos_y,
cropped_outpaint_box_rect.width(),
cropped_outpaint_box_rect.height()
)
@@ -379,7 +352,6 @@ def do_generate(self, extra_options=None, seed=None, latents_seed=None, do_deter
# get the model from the database
- print(model_data)
model = self.settings_manager.models.filter_by(
name=model_data["name"] if "name" in model_data \
else self.settings_manager.generator.model
@@ -474,8 +446,8 @@ def do_generate(self, extra_options=None, seed=None, latents_seed=None, do_deter
options["controlnet_image"] = self.controlnet_image
if action == "superresolution":
- options["original_image_width"] = self.canvas.current_active_image_data.image.width
- options["original_image_height"] = self.canvas.current_active_image_data.image.height
+ options["original_image_width"] = self.app.canvas_widget.current_active_image_data.image.width
+ options["original_image_height"] = self.app.canvas_widget.current_active_image_data.image.height
if action in ["txt2img", "img2img", "outpaint", "depth2img"]:
options[f"strength"] = strength
@@ -500,25 +472,6 @@ def do_generate(self, extra_options=None, seed=None, latents_seed=None, do_deter
}
self.app.client.message = data
- def do_deterministic_generation(self, extra_options):
- action = self.deterministic_data["action"]
- options = self.deterministic_data["options"]
- options[f"prompt"] = self.deterministic_data[f"prompt"][self.deterministic_index]
- memory_options = self.get_memory_options()
- data = {
- "action": action,
- "options": {
- **options,
- **extra_options,
- **memory_options,
- "batch_size": self.settings_manager.deterministic_settings.batch_size,
- "deterministic_generation": True,
- "deterministic_seed": self.settings_manager.deterministic_settings.seed,
- "deterministic_style": self.settings_manager.deterministic_settings.style,
- }
- }
- self.app.client.message = data
-
def get_memory_options(self):
return {
"use_last_channels": self.settings_manager.memory_settings.use_last_channels,
@@ -545,8 +498,8 @@ def set_seed(self, seed=None, latents_seed=None):
self.update_seed()
def update_seed(self):
- self.app.ui.standard_image_widget.ui.seed_widget.update_seed()
- self.app.ui.standard_image_widget.ui.seed_widget_latents.update_seed()
+ self.app.standard_image_panel.ui.seed_widget.update_seed()
+ self.app.standard_image_panel.ui.seed_widget_latents.update_seed()
def set_primary_seed(self, seed=None):
if self.deterministic_data:
@@ -629,3 +582,24 @@ def new_batch(self, index, image, data):
self.deterministic = False
self.deterministic_data = None
self.deterministic_images = None
+
+ def set_progress_bar_value(self, tab_section, section, value):
+ progressbar = self.ui.progress_bar
+ if not progressbar:
+ return
+ if progressbar.maximum() == 0:
+ progressbar.setRange(0, 100)
+ progressbar.setValue(value)
+
+ def stop_progress_bar(self, tab_section, section):
+ progressbar = self.ui.progress_bar
+ if not progressbar:
+ return
+ progressbar.setRange(0, 100)
+ progressbar.setValue(100)
+
+ def update_prompt(self, prompt):
+ self.ui.prompt.setPlainText(prompt)
+
+ def update_negative_prompt(self, prompt):
+ self.ui.negative_prompt.setPlainText(prompt)
\ No newline at end of file
diff --git a/src/airunner/widgets/generator_form/templates/generatorform.ui b/src/airunner/widgets/generator_form/templates/generatorform.ui
index 4498a0c34..2ad953d8e 100644
--- a/src/airunner/widgets/generator_form/templates/generatorform.ui
+++ b/src/airunner/widgets/generator_form/templates/generatorform.ui
@@ -21,7 +21,7 @@
Form
-
+
0
@@ -34,152 +34,216 @@
0
- -
-
-
- QFrame::NoFrame
+
-
+
+
+ 0
-
- QFrame::Plain
-
-
- true
-
-
-
-
- 0
- 0
- 361
- 1064
-
-
-
+
+
+ Tab 1
+
+
+
+ 0
+
- 2
+ 0
+
+
+ 0
- 9
+ 0
-
-
-
-
-
-
-
- Generate
-
-
-
- -
-
-
+
-
+
+
+ QFrame::NoFrame
+
+
+ QFrame::Plain
+
+
+ true
+
+
+
+
+ 0
+ 0
+ 357
+ 1038
+
+
+
+
0
-
-
- -
-
-
- PointingHandCursor
+
+ 0
-
- Interrupt
+
+ 0
-
-
-
-
- -
-
-
- Qt::Vertical
-
-
-
- 9
+ 0
-
-
+
+
+ Qt::Vertical
+
+
+
+
+ 9
+
+
-
+
+
-
+
+
+ Qt::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+ Save Prompts
+
+
+
+ -
+
+
+
+
+
+
+ :/icons/light/artificial-intelligence-ai-chip-icon.svg:/icons/light/artificial-intelligence-ai-chip-icon.svg
+
+
+ true
+
+
+
+
+
+ -
+
+
+
+ 8
+ true
+
+
+
+ Prompt
+
+
+
+ -
+
+
+ Enter a prompt...
+
+
+
+
+
+
+
+
+ 9
+
+ -
+
+
+
+ 8
+ true
+
+
+
+ Negative Prompt
+
+
+
+ -
+
+
+ Enter a negative prompt...
+
+
+
+
+
+
+
+ -
+
-
-
-
- Qt::Horizontal
+
+
+ Generate
-
-
- 40
- 20
-
+
+
+ -
+
+
+ 0
-
+
-
-
+
+
+ PointingHandCursor
+
- Save Prompts
+ Interrupt
- -
-
-
-
- 8
- true
-
-
-
- Prompt
-
-
-
- -
-
-
- Enter a prompt...
-
-
-
-
-
-
-
-
- 9
-
- -
-
-
-
- 8
- true
-
-
-
- Negative Prompt
-
-
-
- -
-
-
- Enter a negative prompt...
-
-
-
+
+
+ Tab 2
+
+
+ -
+
+
+
+
+
+
+ ChatPromptWidget
+ QWidget
+ airunner/widgets/llm/chat_prompt_widget
+ 1
+
+
@@ -192,8 +256,8 @@
action_clicked_button_save_prompts()
- 349
- 23
+ 322
+ 47
502
diff --git a/src/airunner/widgets/generator_form/templates/generatorform_ui.py b/src/airunner/widgets/generator_form/templates/generatorform_ui.py
index 180de7232..a2236295d 100644
--- a/src/airunner/widgets/generator_form/templates/generatorform_ui.py
+++ b/src/airunner/widgets/generator_form/templates/generatorform_ui.py
@@ -17,34 +17,27 @@ def setupUi(self, generator_form):
font.setPointSize(8)
generator_form.setFont(font)
generator_form.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
- self.gridLayout_7 = QtWidgets.QGridLayout(generator_form)
- self.gridLayout_7.setContentsMargins(0, 0, 0, 0)
- self.gridLayout_7.setObjectName("gridLayout_7")
- self.scrollArea = QtWidgets.QScrollArea(parent=generator_form)
+ self.gridLayout_4 = QtWidgets.QGridLayout(generator_form)
+ self.gridLayout_4.setContentsMargins(0, 0, 0, 0)
+ self.gridLayout_4.setObjectName("gridLayout_4")
+ self.tabWidget = QtWidgets.QTabWidget(parent=generator_form)
+ self.tabWidget.setObjectName("tabWidget")
+ self.tab = QtWidgets.QWidget()
+ self.tab.setObjectName("tab")
+ self.gridLayout_3 = QtWidgets.QGridLayout(self.tab)
+ self.gridLayout_3.setContentsMargins(0, 0, 0, 0)
+ self.gridLayout_3.setObjectName("gridLayout_3")
+ self.scrollArea = QtWidgets.QScrollArea(parent=self.tab)
self.scrollArea.setFrameShape(QtWidgets.QFrame.Shape.NoFrame)
self.scrollArea.setFrameShadow(QtWidgets.QFrame.Shadow.Plain)
self.scrollArea.setWidgetResizable(True)
self.scrollArea.setObjectName("scrollArea")
self.scrollAreaWidgetContents = QtWidgets.QWidget()
- self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 361, 1064))
+ self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 357, 1038))
self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents")
self.gridLayout = QtWidgets.QGridLayout(self.scrollAreaWidgetContents)
- self.gridLayout.setContentsMargins(-1, 2, -1, 9)
+ self.gridLayout.setContentsMargins(0, 0, 0, 0)
self.gridLayout.setObjectName("gridLayout")
- self.horizontalLayout_9 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_9.setObjectName("horizontalLayout_9")
- self.generate_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents)
- self.generate_button.setObjectName("generate_button")
- self.horizontalLayout_9.addWidget(self.generate_button)
- self.progress_bar = QtWidgets.QProgressBar(parent=self.scrollAreaWidgetContents)
- self.progress_bar.setProperty("value", 0)
- self.progress_bar.setObjectName("progress_bar")
- self.horizontalLayout_9.addWidget(self.progress_bar)
- self.interrupt_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents)
- self.interrupt_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
- self.interrupt_button.setObjectName("interrupt_button")
- self.horizontalLayout_9.addWidget(self.interrupt_button)
- self.gridLayout.addLayout(self.horizontalLayout_9, 1, 0, 1, 1)
self.splitter = QtWidgets.QSplitter(parent=self.scrollAreaWidgetContents)
self.splitter.setOrientation(QtCore.Qt.Orientation.Vertical)
self.splitter.setObjectName("splitter")
@@ -60,6 +53,14 @@ def setupUi(self, generator_form):
self.pushButton = QtWidgets.QPushButton(parent=self.layoutWidget)
self.pushButton.setObjectName("pushButton")
self.horizontalLayout_6.addWidget(self.pushButton)
+ self.ai_button = QtWidgets.QPushButton(parent=self.layoutWidget)
+ self.ai_button.setText("")
+ icon = QtGui.QIcon()
+ icon.addPixmap(QtGui.QPixmap(":/icons/light/artificial-intelligence-ai-chip-icon.svg"), QtGui.QIcon.Mode.Normal, QtGui.QIcon.State.Off)
+ self.ai_button.setIcon(icon)
+ self.ai_button.setCheckable(True)
+ self.ai_button.setObjectName("ai_button")
+ self.horizontalLayout_6.addWidget(self.ai_button)
self.gridLayout_2.addLayout(self.horizontalLayout_6, 0, 1, 1, 1)
self.label = QtWidgets.QLabel(parent=self.layoutWidget)
font = QtGui.QFont()
@@ -86,11 +87,36 @@ def setupUi(self, generator_form):
self.negative_prompt = QtWidgets.QPlainTextEdit(parent=self.layoutWidget1)
self.negative_prompt.setObjectName("negative_prompt")
self.verticalLayout_6.addWidget(self.negative_prompt)
- self.gridLayout.addWidget(self.splitter, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.splitter, 0, 1, 1, 1)
+ self.horizontalLayout_9 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_9.setObjectName("horizontalLayout_9")
+ self.generate_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents)
+ self.generate_button.setObjectName("generate_button")
+ self.horizontalLayout_9.addWidget(self.generate_button)
+ self.progress_bar = QtWidgets.QProgressBar(parent=self.scrollAreaWidgetContents)
+ self.progress_bar.setProperty("value", 0)
+ self.progress_bar.setObjectName("progress_bar")
+ self.horizontalLayout_9.addWidget(self.progress_bar)
+ self.interrupt_button = QtWidgets.QPushButton(parent=self.scrollAreaWidgetContents)
+ self.interrupt_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
+ self.interrupt_button.setObjectName("interrupt_button")
+ self.horizontalLayout_9.addWidget(self.interrupt_button)
+ self.gridLayout.addLayout(self.horizontalLayout_9, 1, 1, 1, 1)
self.scrollArea.setWidget(self.scrollAreaWidgetContents)
- self.gridLayout_7.addWidget(self.scrollArea, 1, 0, 1, 1)
+ self.gridLayout_3.addWidget(self.scrollArea, 0, 0, 1, 1)
+ self.tabWidget.addTab(self.tab, "")
+ self.tab_2 = QtWidgets.QWidget()
+ self.tab_2.setObjectName("tab_2")
+ self.gridLayout_5 = QtWidgets.QGridLayout(self.tab_2)
+ self.gridLayout_5.setObjectName("gridLayout_5")
+ self.widget = ChatPromptWidget(parent=self.tab_2)
+ self.widget.setObjectName("widget")
+ self.gridLayout_5.addWidget(self.widget, 0, 0, 1, 1)
+ self.tabWidget.addTab(self.tab_2, "")
+ self.gridLayout_4.addWidget(self.tabWidget, 0, 0, 1, 1)
self.retranslateUi(generator_form)
+ self.tabWidget.setCurrentIndex(0)
self.pushButton.clicked.connect(generator_form.action_clicked_button_save_prompts) # type: ignore
self.interrupt_button.clicked.connect(generator_form.handle_interrupt_button_clicked) # type: ignore
self.generate_button.clicked.connect(generator_form.handle_generate_button_clicked) # type: ignore
@@ -101,10 +127,13 @@ def setupUi(self, generator_form):
def retranslateUi(self, generator_form):
_translate = QtCore.QCoreApplication.translate
generator_form.setWindowTitle(_translate("generator_form", "Form"))
- self.generate_button.setText(_translate("generator_form", "Generate"))
- self.interrupt_button.setText(_translate("generator_form", "Interrupt"))
self.pushButton.setText(_translate("generator_form", "Save Prompts"))
self.label.setText(_translate("generator_form", "Prompt"))
self.prompt.setPlaceholderText(_translate("generator_form", "Enter a prompt..."))
self.label_2.setText(_translate("generator_form", "Negative Prompt"))
self.negative_prompt.setPlaceholderText(_translate("generator_form", "Enter a negative prompt..."))
+ self.generate_button.setText(_translate("generator_form", "Generate"))
+ self.interrupt_button.setText(_translate("generator_form", "Interrupt"))
+ self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab), _translate("generator_form", "Tab 1"))
+ self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_2), _translate("generator_form", "Tab 2"))
+from airunner.widgets.llm.chat_prompt_widget import ChatPromptWidget
diff --git a/src/airunner/widgets/image/image_widget.py b/src/airunner/widgets/image/image_widget.py
index d4b027f6f..2e7a0c91b 100644
--- a/src/airunner/widgets/image/image_widget.py
+++ b/src/airunner/widgets/image/image_widget.py
@@ -244,7 +244,6 @@ def __init__(self, *args, **kwargs):
def handle_label_clicked(self, event):
# get the clicked object
if event.button() == Qt.MouseButton.LeftButton:
- # check if shift is down
shift_pressed = event.modifiers() == Qt.KeyboardModifier.ShiftModifier
self.container.activate_brush(self, self.brush, shift_pressed)
elif event.button() == Qt.MouseButton.RightButton:
diff --git a/src/airunner/widgets/input_image/input_image_settings_widget.py b/src/airunner/widgets/input_image/input_image_settings_widget.py
index c58a28b5b..0558e51c9 100644
--- a/src/airunner/widgets/input_image/input_image_settings_widget.py
+++ b/src/airunner/widgets/input_image/input_image_settings_widget.py
@@ -43,15 +43,7 @@ def current_input_image(self):
except AttributeError:
return None
- def initialize(
- self,
- generator_name,
- generator_section
- ):
- self.setProperty("generator_name", generator_name)
- self.setProperty("generator_section", generator_section)
- self.settings_manager.generator_name = generator_name
- self.settings_manager.generator_section = generator_section
+ def initialize(self):
self.update_buttons()
self.ui.groupBox.setTitle(self.property("checkbox_label"))
self.ui.scale_slider_widget.initialize()
diff --git a/src/airunner/widgets/layers/layer_container_widget.py b/src/airunner/widgets/layers/layer_container_widget.py
index 4be82f88c..39e5f154c 100644
--- a/src/airunner/widgets/layers/layer_container_widget.py
+++ b/src/airunner/widgets/layers/layer_container_widget.py
@@ -26,10 +26,6 @@ def current_layer(self):
except IndexError:
Logger.error(f"No current layer for index {self.current_layer_index}")
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.app.loaded.connect(self.initialize)
-
def initialize(self):
self.ui.scrollAreaWidgetContents.layout().addSpacerItem(
QSpacerItem(0, 0, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding))
@@ -195,7 +191,7 @@ def delete_selected_layers(self):
self.delete_layer(index=index, layer=layer)
self.selected_layers = {}
self.show_layers()
- self.app.canvas.update()
+ self.app.standard_image_panel.canvas_widget.do_draw()
def delete_layer(self, _value=False, index=None, layer=None):
Logger.info(f"delete_layer requested index {index}")
@@ -208,7 +204,7 @@ def delete_layer(self, _value=False, index=None, layer=None):
if current_index is None:
current_index = self.current_layer_index
Logger.info(f"Deleting layer {current_index}")
- self.app.canvas.delete_image()
+ self.app.standard_image_panel.canvas_widget.delete_image()
self.app.history.add_event({
"event": "delete_layer",
"layers": self.get_layers_copy(),
@@ -302,6 +298,7 @@ def toggle_layer_visibility(self, layer, layer_obj):
layer.visible = not layer.visible
self.update()
layer_obj.set_icon()
+ self.app.canvas_widget.do_draw()
def handle_move_layer(self, event):
point = QPoint(
diff --git a/src/airunner/widgets/layers/layer_widget.py b/src/airunner/widgets/layers/layer_widget.py
index 9dc4500aa..5a5067329 100644
--- a/src/airunner/widgets/layers/layer_widget.py
+++ b/src/airunner/widgets/layers/layer_widget.py
@@ -64,7 +64,7 @@ def action_clicked_button_toggle_layer_visibility(self, val):
self.layer_data.visible = val
session = get_session()
session.commit()
- self.app.canvas.do_draw()
+ self.app.canvas_widget.do_draw()
def set_thumbnail(self):
image = self.layer_data.image
diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py
new file mode 100644
index 000000000..cdb60c044
--- /dev/null
+++ b/src/airunner/widgets/llm/chat_prompt_widget.py
@@ -0,0 +1,307 @@
+from PyQt6.QtCore import pyqtSlot
+from PyQt6.QtWidgets import QSpacerItem, QSizePolicy
+
+from airunner.aihandler.enums import MessageCode
+from airunner.data.db import session
+from airunner.data.models import Conversation, LLMPromptTemplate, Message
+from airunner.widgets.base_widget import BaseWidget
+from airunner.widgets.llm.templates.chat_prompt_ui import Ui_chat_prompt
+from airunner.widgets.llm.message_widget import MessageWidget
+from airunner.utils import save_session
+from airunner.aihandler.logger import Logger
+
+
+class ChatPromptWidget(BaseWidget):
+ widget_class_ = Ui_chat_prompt
+ conversation = None
+ is_modal = True
+ generating = False
+ prefix = ""
+ prompt = ""
+ suffix = ""
+ conversation_history = []
+ spacer = None
+
+ @property
+ def generator(self):
+ try:
+ return self.app.ui.llm_widget.generator
+ except Exception as e:
+ Logger.error(e)
+ import traceback
+ traceback.print_exc()
+
+ @property
+ def generator_settings(self):
+ try:
+ return self.app.ui.llm_widget.generator_settings
+ except Exception as e:
+ Logger.error(e)
+
+ @property
+ def instructions(self):
+ return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind."
+
+ @property
+ def current_generator(self):
+ return self.settings_manager.current_llm_generator
+
+ @property
+ def instructions(self):
+ return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind."
+
+ def load_data(self):
+ self.conversation = session.query(Conversation).first()
+ if self.conversation is None:
+ self.conversation = Conversation()
+ session.add(self.conversation)
+ session.commit()
+
+ def initialize(self):
+ self.load_data()
+
+ self.app.token_signal.connect(self.handle_token_signal)
+ self.app.message_var.my_signal.connect(self.message_handler)
+ self.ui.prompt.returnPressed.connect(self.action_button_clicked_send)
+ self.ui.prompt.textChanged.connect(self.prompt_text_changed)
+ self.ui.conversation.hide()
+ self.ui.chat_container.show()
+
+ def handle_token_signal(self, val):
+ if val != "[END]":
+ text = self.ui.conversation.toPlainText()
+ text += val
+ self.ui.conversation.setText(text)
+ else:
+ self.stop_progress_bar()
+ self.generating = False
+ self.enable_send_button()
+
+ @pyqtSlot(dict)
+ def message_handler(self, response: dict):
+ try:
+ code = response["code"]
+ except TypeError:
+ return
+ message = response["message"]
+
+ if code == MessageCode.TEXT_GENERATED:
+ self.handle_text_generated(message)
+
+ def handle_text_generated(self, message):
+ self.stop_progress_bar()
+
+ # # check if messages is string or list
+ # if isinstance(messages, str):
+ # messages = [messages]
+
+ # #print("MESSAGES", messages)
+
+ # if messages is None:
+ # return
+
+ # # get last message
+ # message = messages[-1]["content"]
+
+ # strip quotes from start and end of message
+ if not message:
+ return
+ if message.startswith("\""):
+ message = message[1:]
+ if message.endswith("\""):
+ message = message[:-1]
+ message_object = Message(
+ name=self.generator.botname,
+ message=message,
+ conversation=self.conversation
+ )
+ session.add(message_object)
+ session.commit()
+
+ if self.settings_manager.enable_tts:
+ # split on sentence enders
+ sentence_enders = [".", "?", "!", "\n"]
+ text = message_object.message
+ sentences = []
+ # split text into sentences
+ current_sentence = ""
+ for char in text:
+ current_sentence += char
+ if char in sentence_enders:
+ sentences.append(current_sentence)
+ current_sentence = ""
+ if current_sentence != "":
+ sentences.append(current_sentence)
+
+ for index, sentence in enumerate(sentences):
+ sentence = sentence.strip()
+ self.app.client.message = dict(
+ tts_request=True,
+ request_data=dict(
+ text=sentence,
+ message_object=Message(
+ name=message_object.name,
+ message=sentence,
+ ),
+ is_bot=True,
+ callback=self.add_message_to_conversation,
+ first_message=index == 0,
+ last_message=index == len(sentences) - 1,
+ )
+ )
+
+ self.add_message_to_conversation(message_object, is_bot=True)
+
+ self.generating = False
+ self.enable_send_button()
+
+ def prompt_text_changed(self, val):
+ self.prompt = val
+
+ def clear_prompt(self):
+ self.ui.prompt.setText("")
+
+ def start_progress_bar(self):
+ self.ui.progressBar.setRange(0, 0)
+ self.ui.progressBar.setValue(0)
+
+ def stop_progress_bar(self):
+ self.ui.progressBar.setRange(0, 1)
+ self.ui.progressBar.setValue(1)
+ self.ui.progressBar.reset()
+
+ def disable_send_button(self):
+ self.ui.send_button.setEnabled(False)
+
+ def enable_send_button(self):
+ self.ui.send_button.setEnabled(True)
+
+ def response_text_changed(self):
+ pass
+
+ def parent(self):
+ return self.app.ui.llm_widget
+
+ def action_button_clicked_send(self, image_override=None, prompt_override=None, callback=None, generator_name="casuallm"):
+ if self.generating:
+ Logger.warning("Already generating")
+ return
+
+ self.generating = True
+ self.disable_send_button()
+ #user_input = f"{self.generator.username} Says: \"{self.prompt}\""
+ # conversation = "\n".join(self.conversation_history)
+ # suffix = "\n".join([self.suffix, f'{self.generator.botname} Says: '])
+ # prompt = "\n".join([self.instructions, self.prefix, conversation, input, suffix])
+
+ image = self.app.current_active_image() if (image_override is None or image_override is False) else image_override
+
+ prompt = self.prompt if (prompt_override is None or prompt_override == "") else prompt_override
+ if prompt is None or prompt == "":
+ Logger.warning("Prompt is empty")
+ return
+
+ print(self.generator.prompt_template)
+ prompt_template = session.query(LLMPromptTemplate).filter(
+ LLMPromptTemplate.name == self.generator.prompt_template
+ ).first()
+
+ data = {
+ "llm_request": True,
+ "request_data": {
+ "generator_name": generator_name,
+ "model_path": self.generator_settings.model_version,
+ "stream": True,
+ "prompt": prompt,
+ "do_summary": False,
+ "is_bot_alive": True,
+ "conversation_history": self.conversation_history,
+ "generator": self.generator,
+ "prefix": self.prefix,
+ "suffix": self.suffix,
+ "dtype": self.generator_settings.dtype,
+ "use_gpu": self.generator_settings.use_gpu,
+ "request_type": "image_caption_generator",
+ "username": self.generator.username,
+ "botname": self.generator.botname,
+ "prompt_template": prompt_template.template,
+ "parameters": {
+ "override_parameters": self.generator.override_parameters,
+ "top_p": self.generator_settings.top_p / 100.0,
+ "max_length": self.generator_settings.max_length,
+ "repetition_penalty": self.generator_settings.repetition_penalty / 100.0,
+ "min_length": self.generator_settings.min_length,
+ "length_penalty": self.generator_settings.length_penalty / 100,
+ "num_beams": self.generator_settings.num_beams,
+ "ngram_size": self.generator_settings.ngram_size,
+ "temperature": self.generator_settings.temperature / 10000.0,
+ "sequences": self.generator_settings.sequences,
+ "top_k": self.generator_settings.top_k,
+ "eta_cutoff": self.generator_settings.eta_cutoff / 100.0,
+ "seed": self.generator_settings.do_sample,
+ "early_stopping": self.generator_settings.early_stopping,
+ },
+ "image": image,
+ "callback": callback
+ }
+ }
+ message_object = Message(
+ name=self.generator.username,
+ message=self.prompt,
+ conversation=self.conversation
+ )
+ session.add(message_object)
+ session.commit()
+ self.app.client.message = data
+ self.add_message_to_conversation(message_object=message_object, is_bot=False)
+ self.clear_prompt()
+ self.start_progress_bar()
+
+ def describe_image(self, image, callback):
+ self.action_button_clicked_send(
+ image_override=image,
+ prompt_override="What is in this picture?",
+ callback=callback,
+ generator_name="visualqa"
+ )
+
+ def add_message_to_conversation(self, message_object, is_bot, first_message=True, last_message=True):
+ # remove spacer from self.ui.chat_container
+ widget = MessageWidget(message=message_object, is_bot=is_bot)
+ self.ui.scrollAreaWidgetContents.layout().addWidget(widget)
+
+ if self.spacer is not None:
+ self.ui.scrollAreaWidgetContents.layout().removeItem(self.spacer)
+
+ message = ""
+ if first_message:
+ message = f"{message_object.name} Says: \""
+ message += message_object.message
+ if last_message:
+ message += "\""
+
+ if first_message:
+ self.conversation_history.append(message)
+ if not first_message:
+ self.conversation_history[-1] += message
+ self.ui.conversation.undo()
+ self.ui.conversation.append(self.conversation_history[-1])
+
+ # add a vertical spacer to self.ui.chat_container
+ if self.spacer is None:
+ self.spacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding)
+ self.ui.scrollAreaWidgetContents.layout().addItem(self.spacer)
+
+ def action_button_clicked_clear_conversation(self):
+ self.conversation_history = []
+ self.ui.conversation.setText("")
+ self.app.client.message = {
+ "llm_request": True,
+ "request_data": {
+ "request_type": "clear_conversation",
+ }
+ }
+
+ def message_type_text_changed(self, val):
+ self.generator.message_type = val
+ save_session()
diff --git a/src/airunner/widgets/llm/llm_preferences_widget.py b/src/airunner/widgets/llm/llm_preferences_widget.py
new file mode 100644
index 000000000..91fed15a1
--- /dev/null
+++ b/src/airunner/widgets/llm/llm_preferences_widget.py
@@ -0,0 +1,62 @@
+from airunner.widgets.base_widget import BaseWidget
+from airunner.widgets.llm.templates.llm_preferences_ui import Ui_llm_preferences_widget
+from airunner.utils import get_session, save_session
+from airunner.data.models import LLMPromptTemplate
+from airunner.widgets.base_widget import BaseWidget
+from airunner.aihandler.logger import Logger
+
+
+
+class LLMPreferencesWidget(BaseWidget):
+ widget_class_ = Ui_llm_preferences_widget
+
+ @property
+ def generator(self):
+ try:
+ return self.app.ui.llm_widget.generator
+ except Exception as e:
+ Logger.error(e)
+ import traceback
+ traceback.print_exc()
+
+ @property
+ def generator_settings(self):
+ try:
+ return self.app.ui.llm_widget.generator_settings
+ except Exception as e:
+ Logger.error(e)
+
+ def initialize(self):
+ self.ui.prefix.blockSignals(True)
+ self.ui.suffix.blockSignals(True)
+ self.ui.personality_type.blockSignals(True)
+ if self.generator:
+ self.ui.prefix.setPlainText(self.generator.prefix)
+ self.ui.suffix.setPlainText(self.generator.suffix)
+ self.ui.personality_type.setCurrentText(self.generator.bot_personality)
+ self.ui.prefix.blockSignals(False)
+ self.ui.suffix.blockSignals(False)
+ self.ui.personality_type.blockSignals(False)
+
+ def action_button_clicked_generate_characters(self):
+ pass
+
+ def personality_type_changed(self, val):
+ self.generator.bot_personality = val
+ save_session()
+
+ def prefix_text_changed(self):
+ self.generator.prefix = self.ui.prefix.toPlainText()
+ save_session()
+
+ def suffix_text_changed(self):
+ self.generator.suffix = self.ui.suffix.toPlainText()
+ save_session()
+
+ def username_text_changed(self, val):
+ self.generator.username = val
+ save_session()
+
+ def botname_text_changed(self, val):
+ self.generator.botname = val
+ save_session()
\ No newline at end of file
diff --git a/src/airunner/widgets/llm/llm_settings_widget.py b/src/airunner/widgets/llm/llm_settings_widget.py
new file mode 100644
index 000000000..6af9d0187
--- /dev/null
+++ b/src/airunner/widgets/llm/llm_settings_widget.py
@@ -0,0 +1,263 @@
+"""
+This class should be used to create a window widget for the LLM.
+"""
+from PyQt6.QtWidgets import QWidget
+
+from airunner.widgets.base_widget import BaseWidget
+from airunner.widgets.llm.templates.llm_settings_ui import Ui_llm_settings_widget
+from airunner.utils import save_session, get_session
+from airunner.data.models import LLMGeneratorSetting, LLMGenerator, AIModel, LLMPromptTemplate
+from airunner.aihandler.logger import Logger
+
+
+class LLMSettingsWidget(BaseWidget):
+ widget_class_ = Ui_llm_settings_widget
+ current_generator = None
+ dtype_descriptions = {
+ "2bit": "Fastest, least amount of VRAM, GPU only, least accurate results.",
+ "4bit": "Faster, much less VRAM, GPU only, much less accurate results.",
+ "8bit": "Fast, less VRAM, GPU only, less accurate results.",
+ "16bit": "Normal speed, some VRAM, uses GPU, slightly less accurate results.",
+ "32bit": "Slow, no VRAM, uses CPU, most accurate results.",
+ }
+
+ @property
+ def generator(self):
+ try:
+ return self.app.ui.llm_widget.generator
+ except Exception as e:
+ Logger.error(e)
+ import traceback
+ traceback.print_exc()
+
+ @property
+ def generator_settings(self):
+ try:
+ return self.app.ui.llm_widget.generator_settings
+ except Exception as e:
+ Logger.error(e)
+
+ @property
+ def current_generator(self):
+ return self.settings_manager.get_value("current_llm_generator")
+
+ def initialize(self):
+ self.initialize_form()
+
+ def early_stopping_toggled(self, val):
+ self.generator.generator_settings[0].early_stopping = val
+ save_session()
+
+ def do_sample_toggled(self, val):
+ self.generator.generator_settings[0].do_sample = val
+ save_session()
+
+ def toggle_leave_model_in_vram(self, val):
+ if val:
+ self.settings_manager.set_value("unload_unused_model", False)
+ self.settings_manager.set_value("move_unused_model_to_cpu", False)
+
+ def initialize_form(self):
+ session = get_session()
+ self.ui.prompt_template.blockSignals(True)
+ self.ui.model.blockSignals(True)
+ self.ui.model_version.blockSignals(True)
+ self.ui.radio_button_2bit.blockSignals(True)
+ self.ui.radio_button_4bit.blockSignals(True)
+ self.ui.radio_button_8bit.blockSignals(True)
+ self.ui.radio_button_16bit.blockSignals(True)
+ self.ui.radio_button_32bit.blockSignals(True)
+ self.ui.random_seed.blockSignals(True)
+ self.ui.do_sample.blockSignals(True)
+ self.ui.early_stopping.blockSignals(True)
+ self.ui.use_gpu_checkbox.blockSignals(True)
+ self.ui.override_parameters.blockSignals(True)
+ self.ui.top_p.initialize_properties()
+ self.ui.max_length.initialize_properties()
+ self.ui.max_length.initialize_properties()
+ self.ui.repetition_penalty.initialize_properties()
+ self.ui.min_length.initialize_properties()
+ self.ui.length_penalty.initialize_properties()
+ self.ui.num_beams.initialize_properties()
+ self.ui.ngram_size.initialize_properties()
+ self.ui.temperature.initialize_properties()
+ self.ui.sequences.initialize_properties()
+ self.ui.top_k.initialize_properties()
+
+ prompt_templates = [template.name for template in session.query(LLMPromptTemplate).all()]
+ self.ui.prompt_template.clear()
+ self.ui.prompt_template.addItems(prompt_templates)
+
+ if self.generator:
+ self.ui.radio_button_2bit.setChecked(self.generator.generator_settings[0].dtype == "2bit")
+ self.ui.radio_button_4bit.setChecked(self.generator.generator_settings[0].dtype == "4bit")
+ self.ui.radio_button_8bit.setChecked(self.generator.generator_settings[0].dtype == "8bit")
+ self.ui.radio_button_16bit.setChecked(self.generator.generator_settings[0].dtype == "16bit")
+ self.ui.radio_button_32bit.setChecked(self.generator.generator_settings[0].dtype == "32bit")
+ self.set_dtype_by_gpu( self.generator.generator_settings[0].use_gpu)
+ self.set_dtype(self.generator.generator_settings[0].dtype)
+
+ # get unique model names
+ model_names = [model.name for model in session.query(LLMGenerator).all()]
+ model_names = list(set(model_names))
+ self.ui.model.clear()
+ self.ui.model.addItems(model_names)
+ self.ui.model.setCurrentText(self.current_generator)
+ self.update_model_version_combobox()
+ if self.generator:
+ self.ui.model_version.setCurrentText(self.generator.generator_settings[0].model_version)
+ self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed)
+ self.ui.do_sample.setChecked(self.generator.generator_settings[0].do_sample)
+ self.ui.early_stopping.setChecked(self.generator.generator_settings[0].early_stopping)
+ self.ui.use_gpu_checkbox.setChecked(self.generator.generator_settings[0].use_gpu)
+ self.ui.override_parameters.setChecked(self.generator.override_parameters)
+
+ self.ui.model.blockSignals(False)
+ self.ui.model_version.blockSignals(False)
+ self.ui.radio_button_2bit.blockSignals(False)
+ self.ui.radio_button_4bit.blockSignals(False)
+ self.ui.radio_button_8bit.blockSignals(False)
+ self.ui.radio_button_16bit.blockSignals(False)
+ self.ui.radio_button_32bit.blockSignals(False)
+ self.ui.random_seed.blockSignals(False)
+ self.ui.do_sample.blockSignals(False)
+ self.ui.early_stopping.blockSignals(False)
+ self.ui.use_gpu_checkbox.blockSignals(False)
+ self.ui.override_parameters.blockSignals(False)
+ self.ui.prompt_template.blockSignals(False)
+
+ def model_text_changed(self, val):
+ print("model_text_changed", val)
+ self.settings_manager.set_value("current_llm_generator", val)
+ self.update_model_version_combobox()
+ self.model_version_changed(self.ui.model_version.currentText())
+ self.initialize_form()
+
+ def model_version_changed(self, val):
+ self.generator.generator_settings[0].model_version = val
+ save_session()
+
+ def toggle_move_model_to_cpu(self, val):
+ self.settings_manager.set_value("move_unused_model_to_cpu", val)
+ if val:
+ self.settings_manager.set_value("unload_unused_model", False)
+
+ def override_parameters_toggled(self, val):
+ self.generator.override_parameters = val
+ save_session()
+
+ def prompt_template_text_changed(self, value):
+ self.generator.prompt_template = value
+ save_session()
+
+ def toggled_2bit(self, val):
+ if val:
+ self.set_dtype("2bit")
+
+ def toggled_4bit(self, val):
+ if val:
+ self.set_dtype("4bit")
+
+ def toggled_8bit(self, val):
+ if val:
+ self.set_dtype("8bit")
+
+ def toggled_16bit(self, val):
+ if val:
+ self.set_dtype("16bit")
+
+ def toggled_32bit(self, val):
+ if val:
+ self.set_dtype("32bit")
+
+ def random_seed_toggled(self, val):
+ self.generator.generator_settings[0].random_seed = val
+ save_session()
+
+ def seed_changed(self, val):
+ self.generator.generator_settings[0].seed = val
+ save_session()
+
+ def toggle_unload_model(self, val):
+ self.settings_manager.set_value("unload_unused_model", val)
+ if val:
+ self.settings_manager.set_value("move_unused_model_to_cpu", False)
+
+ def use_gpu_toggled(self, val):
+ self.generator.generator_settings[0].use_gpu = val
+ # toggle the 16bit radio button and disable 4bit and 8bit radio buttons
+ self.set_dtype_by_gpu(val)
+ save_session()
+
+ def set_dtype_by_gpu(self, use_gpu):
+ if not use_gpu:
+ self.ui.radio_button_2bit.setEnabled(False)
+ self.ui.radio_button_4bit.setEnabled(False)
+ self.ui.radio_button_8bit.setEnabled(False)
+ self.ui.radio_button_32bit.setEnabled(True)
+
+ if self.generator.generator_settings[0].dtype in ["4bit", "8bit"]:
+ self.ui.radio_button_16bit.setChecked(True)
+ else:
+ self.ui.radio_button_2bit.setEnabled(True)
+ self.ui.radio_button_4bit.setEnabled(True)
+ self.ui.radio_button_8bit.setEnabled(True)
+ self.ui.radio_button_32bit.setEnabled(False)
+ if self.generator.generator_settings[0].dtype == "32bit":
+ self.ui.radio_button_16bit.setChecked(True)
+
+ def reset_settings_to_default_clicked(self):
+ self.generator.generator_settings[0].top_p = LLMGeneratorSetting.top_p.default.arg
+ self.generator.generator_settings[0].max_length = LLMGeneratorSetting.max_length.default.arg
+ self.generator.generator_settings[0].repetition_penalty = LLMGeneratorSetting.repetition_penalty.default.arg
+ self.generator.generator_settings[0].min_length = LLMGeneratorSetting.min_length.default.arg
+ self.generator.generator_settings[0].length_penalty = LLMGeneratorSetting.length_penalty.default.arg
+ self.generator.generator_settings[0].num_beams = LLMGeneratorSetting.num_beams.default.arg
+ self.generator.generator_settings[0].ngram_size = LLMGeneratorSetting.ngram_size.default.arg
+ self.generator.generator_settings[0].temperature = LLMGeneratorSetting.temperature.default.arg
+ self.generator.generator_settings[0].sequences = LLMGeneratorSetting.sequences.default.arg
+ self.generator.generator_settings[0].top_k = LLMGeneratorSetting.top_k.default.arg
+ self.generator.generator_settings[0].eta_cutoff = LLMGeneratorSetting.eta_cutoff.default.arg
+ self.generator.generator_settings[0].seed = LLMGeneratorSetting.seed.default.arg
+ self.generator.generator_settings[0].do_sample = LLMGeneratorSetting.do_sample.default.arg
+ self.generator.generator_settings[0].early_stopping = LLMGeneratorSetting.early_stopping.default.arg
+ self.generator.generator_settings[0].random_seed = LLMGeneratorSetting.random_seed.default.arg
+ self.generator.generator_settings[0].model_version = LLMGeneratorSetting.model_version.default.arg
+ self.generator.generator_settings[0].dtype = LLMGeneratorSetting.dtype.default.arg
+ self.generator.generator_settings[0].use_gpu = LLMGeneratorSetting.use_gpu.default.arg
+ save_session()
+ self.initialize_form()
+ self.ui.top_p.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_p)
+ self.ui.max_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].max_length)
+ self.ui.repetition_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].repetition_penalty)
+ self.ui.min_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].min_length)
+ self.ui.length_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].length_penalty)
+ self.ui.num_beams.set_slider_and_spinbox_values(self.generator.generator_settings[0].num_beams)
+ self.ui.ngram_size.set_slider_and_spinbox_values(self.generator.generator_settings[0].ngram_size)
+ self.ui.temperature.set_slider_and_spinbox_values(self.generator.generator_settings[0].temperature)
+ self.ui.sequences.set_slider_and_spinbox_values(self.generator.generator_settings[0].sequences)
+ self.ui.top_k.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_k)
+ self.ui.eta_cutoff.set_slider_and_spinbox_values(self.generator.generator_settings[0].eta_cutoff)
+ self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed)
+
+ def set_dtype(self, dtype):
+ self.generator.generator_settings[0].dtype = dtype
+ save_session()
+ self.set_dtype_description(dtype)
+
+ def set_dtype_description(self, dtype):
+ self.ui.dtype_description.setText(self.dtype_descriptions[dtype])
+
+ def update_model_version_combobox(self):
+ session = get_session()
+ self.ui.model_version.blockSignals(True)
+ self.ui.model_version.clear()
+ ai_model_paths = [model.path for model in session.query(AIModel).filter(
+ AIModel.pipeline_action == self.current_generator
+ )]
+ self.ui.model_version.addItems(ai_model_paths)
+ self.ui.model_version.blockSignals(False)
+
+ def set_tab(self, tab_name):
+ index = self.ui.tabWidget.indexOf(self.ui.tabWidget.findChild(QWidget, tab_name))
+ self.ui.tabWidget.setCurrentIndex(index)
\ No newline at end of file
diff --git a/src/airunner/widgets/llm/llm_widget.py b/src/airunner/widgets/llm/llm_widget.py
index 1d6f7b862..997241db0 100644
--- a/src/airunner/widgets/llm/llm_widget.py
+++ b/src/airunner/widgets/llm/llm_widget.py
@@ -1,558 +1,43 @@
"""
This class should be used to create a window widget for the LLM.
"""
-from PyQt6.QtCore import pyqtSlot
-from PyQt6.QtWidgets import QSpacerItem, QSizePolicy
-from PyQt6.QtWidgets import QWidget
-from sqlalchemy import inspect
-from functools import partial
-
-from airunner.aihandler.enums import MessageCode
-from airunner.data.db import session
-from airunner.data.models import AIModel, LLMGenerator, Conversation, LLMGeneratorSetting, LLMPromptTemplate, Message
-from airunner.utils import save_session
+from airunner.utils import get_session
+from airunner.data.models import LLMGenerator, LLMGeneratorSetting
from airunner.widgets.base_widget import BaseWidget
from airunner.widgets.llm.templates.llm_widget_ui import Ui_llm_widget
+from airunner.aihandler.logger import Logger
class LLMWidget(BaseWidget):
widget_class_ = Ui_llm_widget
generator = None
- conversation = None
- is_modal = True
- generating = False
- prefix = ""
- prompt = ""
- suffix = ""
- conversation_history = []
-
- @property
- def current_generator(self):
- return self.settings_manager.current_llm_generator
+ _generator = None
+ _generator_settings = None
@property
- def instructions(self):
- return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind."
-
- def load_data(self):
- self.load_generator()
- self.conversation = session.query(Conversation).first()
- if self.conversation is None:
- self.conversation = Conversation()
- session.add(self.conversation)
- session.commit()
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- # hide tab bar
- #self.ui.tabWidget.tabBar().hide()
- self.load_data()
-
- self.ui.prefix.blockSignals(True)
- self.ui.suffix.blockSignals(True)
- if self.generator:
- self.ui.prefix.setPlainText(self.generator.prefix)
- self.ui.suffix.setPlainText(self.generator.suffix)
- self.initialize_form()
- self.ui.prefix.blockSignals(False)
- self.ui.suffix.blockSignals(False)
-
- self.app.token_signal.connect(self.handle_token_signal)
- self.app.message_var.my_signal.connect(self.message_handler)
- self.ui.prompt.returnPressed.connect(self.action_button_clicked_send)
- self.ui.prompt.textChanged.connect(self.prompt_text_changed)
- leave_in_vram = not self.settings_manager.move_unused_model_to_cpu and not self.settings_manager.unload_unused_model
- self.ui.leave_in_vram.setChecked(leave_in_vram)
- self.ui.move_to_cpu.setChecked(self.settings_manager.move_unused_model_to_cpu)
- self.ui.unload_model.setChecked(self.settings_manager.unload_unused_model)
-
- prompt_templates = session.query(LLMPromptTemplate).all()
- self.ui.prompt_template.blockSignals(True)
- for prompt_template in prompt_templates:
- self.ui.prompt_template.addItem(prompt_template.name)
- self.ui.prompt_template.blockSignals(False)
+ def generator(self):
+ if self._generator is None:
+ session = get_session()
+ try:
+ self._generator = session.query(LLMGenerator).filter(
+ LLMGenerator.name == self.ui.llm_settings_widget.current_generator
+ ).first()
+ if self._generator is None:
+ Logger.error("Unable to locate generator by name " + self.ui.llm_settings_widget.current_generator if self.ui.llm_settings_widget.current_generator else "None")
+ except Exception as e:
+ Logger.error(e)
+ return self._generator
- def prompt_template_text_changed(self, value):
- print(value)
-
- def handle_token_signal(self, val):
- if val != "[END]":
- text = self.ui.conversation.toPlainText()
- text += val
- self.ui.conversation.setText(text)
- else:
- self.stop_progress_bar()
- self.generating = False
- self.enable_send_button()
-
- @pyqtSlot(dict)
- def message_handler(self, response: dict):
+ @property
+ def generator_settings(self):
try:
- code = response["code"]
- except TypeError:
- return
- message = response["message"]
-
- if code == MessageCode.TEXT_GENERATED:
- print("MESSAGE HANDLER", response)
- self.handle_text_generated(message)
-
- def initialize_form(self):
- self.ui.model.blockSignals(True)
- self.ui.model_version.blockSignals(True)
- self.ui.prompt.blockSignals(True)
- self.ui.botname.blockSignals(True)
- self.ui.username.blockSignals(True)
- self.ui.prefix.blockSignals(True)
- self.ui.suffix.blockSignals(True)
- self.ui.personality_type.blockSignals(True)
- self.ui.radio_button_2bit.blockSignals(True)
- self.ui.radio_button_4bit.blockSignals(True)
- self.ui.radio_button_8bit.blockSignals(True)
- self.ui.radio_button_16bit.blockSignals(True)
- self.ui.radio_button_32bit.blockSignals(True)
- self.ui.random_seed.blockSignals(True)
- self.ui.do_sample.blockSignals(True)
- self.ui.early_stopping.blockSignals(True)
- self.ui.use_gpu_checkbox.blockSignals(True)
- self.ui.override_parameters.blockSignals(True)
-
- self.ui.top_p.initialize_properties()
- self.ui.max_length.initialize_properties()
- self.ui.max_length.initialize_properties()
- self.ui.repetition_penalty.initialize_properties()
- self.ui.min_length.initialize_properties()
- self.ui.length_penalty.initialize_properties()
- self.ui.num_beams.initialize_properties()
- self.ui.ngram_size.initialize_properties()
- self.ui.temperature.initialize_properties()
- self.ui.sequences.initialize_properties()
- self.ui.top_k.initialize_properties()
-
- if self.generator:
- self.ui.radio_button_2bit.setChecked(self.generator.generator_settings[0].dtype == "2bit")
- self.ui.radio_button_4bit.setChecked(self.generator.generator_settings[0].dtype == "4bit")
- self.ui.radio_button_8bit.setChecked(self.generator.generator_settings[0].dtype == "8bit")
- self.ui.radio_button_16bit.setChecked(self.generator.generator_settings[0].dtype == "16bit")
- self.ui.radio_button_32bit.setChecked(self.generator.generator_settings[0].dtype == "32bit")
- self.set_dtype_by_gpu( self.generator.generator_settings[0].use_gpu)
- self.set_dtype(self.generator.generator_settings[0].dtype)
-
- # get unique model names
- model_names = [model.name for model in session.query(LLMGenerator).all()]
- model_names = list(set(model_names))
- self.ui.model.clear()
- self.ui.model.addItems(model_names)
- self.ui.model.setCurrentText(self.current_generator)
- if self.generator:
- self.ui.username.setText(self.generator.username)
- self.ui.botname.setText(self.generator.botname)
- self.update_model_version_combobox()
- if self.generator:
- self.ui.model_version.setCurrentText(self.generator.generator_settings[0].model_version)
- self.ui.personality_type.setCurrentText(self.generator.bot_personality)
- self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed)
- self.ui.do_sample.setChecked(self.generator.generator_settings[0].do_sample)
- self.ui.early_stopping.setChecked(self.generator.generator_settings[0].early_stopping)
- self.ui.use_gpu_checkbox.setChecked(self.generator.generator_settings[0].use_gpu)
- self.ui.override_parameters.setChecked(self.generator.override_parameters)
-
- self.ui.model.blockSignals(False)
- self.ui.model_version.blockSignals(False)
- self.ui.prompt.blockSignals(False)
- self.ui.botname.blockSignals(False)
- self.ui.username.blockSignals(False)
- self.ui.prefix.blockSignals(False)
- self.ui.suffix.blockSignals(False)
- self.ui.personality_type.blockSignals(False)
- self.ui.radio_button_2bit.blockSignals(False)
- self.ui.radio_button_4bit.blockSignals(False)
- self.ui.radio_button_8bit.blockSignals(False)
- self.ui.radio_button_16bit.blockSignals(False)
- self.ui.radio_button_32bit.blockSignals(False)
- self.ui.random_seed.blockSignals(False)
- self.ui.do_sample.blockSignals(False)
- self.ui.early_stopping.blockSignals(False)
- self.ui.use_gpu_checkbox.blockSignals(False)
- self.ui.override_parameters.blockSignals(False)
-
- def handle_text_generated(self, messages):
- self.stop_progress_bar()
-
- # check if messages is string or list
- if isinstance(messages, str):
- messages = [messages]
-
- print("MESSAGES", messages)
-
- for message in messages:
-
- # strip quotes from start and end of message
- if not message:
- return
- if message.startswith("\""):
- message = message[1:]
- if message.endswith("\""):
- message = message[:-1]
- message_object = Message(
- name=self.generator.botname,
- message=message,
- conversation=self.conversation
- )
- session.add(message_object)
- session.commit()
-
- if self.settings_manager.enable_tts:
- # split on sentence enders
- sentence_enders = [".", "?", "!", "\n"]
- text = message_object.message
- sentences = []
- # split text into sentences
- current_sentence = ""
- for char in text:
- current_sentence += char
- if char in sentence_enders:
- sentences.append(current_sentence)
- current_sentence = ""
- if current_sentence != "":
- sentences.append(current_sentence)
-
- for index, sentence in enumerate(sentences):
- sentence = sentence.strip()
- self.app.client.message = dict(
- tts_request=True,
- request_data=dict(
- text=sentence,
- message_object=Message(
- name=message_object.name,
- message=sentence,
- ),
- is_bot=True,
- callback=self.add_message_to_conversation,
- first_message=index == 0,
- last_message=index == len(sentences) - 1,
- )
- )
-
- if not self.settings_manager.enable_tts:
- self.add_message_to_conversation(message_object, is_bot=True)
-
- self.generating = False
- self.enable_send_button()
-
- def personality_type_changed(self, val):
- self.generator.bot_personality = val
- save_session()
-
- def prefix_text_changed(self):
- self.generator.prefix = self.ui.prefix.toPlainText()
- save_session()
-
- def prompt_text_changed(self, val):
- self.prompt = val
+ return self.generator.generator_settings[0]
+ except Exception as e:
+ Logger.error(e)
+ return None
- def suffix_text_changed(self):
- self.generator.suffix = self.ui.suffix.toPlainText()
- save_session()
-
- def clear_prompt(self):
- self.ui.prompt.setText("")
-
- def start_progress_bar(self):
- self.ui.progressBar.setRange(0, 0)
- self.ui.progressBar.setValue(0)
-
- def stop_progress_bar(self):
- self.ui.progressBar.setRange(0, 1)
- self.ui.progressBar.setValue(1)
- self.ui.progressBar.reset()
-
- def disable_send_button(self):
- self.ui.send_button.setEnabled(False)
-
- def enable_send_button(self):
- self.ui.send_button.setEnabled(True)
-
- def seed_changed(self, val):
- self.generator.generator_settings[0].seed = val
- save_session()
-
- def response_text_changed(self):
- pass
-
- def username_text_changed(self, val):
- self.generator.username = val
- save_session()
-
- def random_seed_toggled(self, val):
- self.generator.generator_settings[0].random_seed = val
- save_session()
-
- def model_version_changed(self, val):
- self.generator.generator_settings[0].model_version = val
- save_session()
-
- def early_stopping_toggled(self, val):
- self.generator.generator_settings[0].early_stopping = val
- save_session()
-
- def do_sample_toggled(self, val):
- self.generator.generator_settings[0].do_sample = val
- save_session()
-
- def botname_text_changed(self, val):
- self.generator.botname = val
- save_session()
-
- def action_button_clicked_send(self, image_override=None, prompt_override=None, callback=None, generator_name="casuallm"):
- if self.generating:
- return
-
- self.load_generator()
- self.generating = True
- self.disable_send_button()
- #user_input = f"{self.generator.username} Says: \"{self.prompt}\""
- # conversation = "\n".join(self.conversation_history)
- # suffix = "\n".join([self.suffix, f'{self.generator.botname} Says: '])
- # prompt = "\n".join([self.instructions, self.prefix, conversation, input, suffix])
- prompt_template = session.query(LLMPromptTemplate).filter(
- LLMPromptTemplate.name == self.ui.prompt_template.currentText()
- ).first()
-
- image = self.app.current_active_image() if (image_override is None or image_override is False) else image_override
-
- prompt = self.prompt if prompt_override is None else prompt_override
-
- settings = self.generator.generator_settings[0]
- data = {
- "llm_request": True,
- "request_data": {
- "generator_name": generator_name,
- "model_path": settings.model_version,
- "stream": True,
- "prompt": prompt,
- "do_summary": False,
- "is_bot_alive": True,
- "conversation_history": self.conversation_history,
- "generator": self.generator,
- "prefix": self.prefix,
- "suffix": self.suffix,
- "dtype": settings.dtype,
- "use_gpu": settings.use_gpu,
- "request_type": "image_caption_generator",
- "username": self.generator.username,
- "botname": self.generator.botname,
- "prompt_template": prompt_template.template,
- "parameters": {
- "override_parameters": self.generator.override_parameters,
- "top_p": settings.top_p / 100.0,
- "max_length": settings.max_length,
- "repetition_penalty": settings.repetition_penalty / 100.0,
- "min_length": settings.min_length,
- "length_penalty": settings.length_penalty / 100,
- "num_beams": settings.num_beams,
- "ngram_size": settings.ngram_size,
- "temperature": settings.temperature / 10000.0,
- "sequences": settings.sequences,
- "top_k": settings.top_k,
- "eta_cutoff": settings.eta_cutoff / 100.0,
- "seed": settings.do_sample,
- "early_stopping": settings.early_stopping,
- },
- "image": image,
- "callback": callback
- }
- }
- message_object = Message(
- name=self.generator.username,
- message=self.prompt,
- conversation=self.conversation
- )
- session.add(message_object)
- session.commit()
- self.app.client.message = data
- self.add_message_to_conversation(message_object=message_object, is_bot=False)
- self.clear_prompt()
- self.start_progress_bar()
-
- def describe_image(self, image, callback):
- print("DESCRIBE IMAGE", callback)
- self.action_button_clicked_send(
- image_override=image,
- prompt_override="What is in this picture?",
- callback=callback,
- generator_name="visualqa"
- )
-
- def add_message_to_conversation(self, message_object, is_bot, first_message=True, last_message=True):
- message = ""
- if first_message:
- message = f"{message_object.name} Says: \""
- message += message_object.message
- if last_message:
- message += "\""
-
- if first_message:
- self.conversation_history.append(message)
- if not first_message:
- self.conversation_history[-1] += message
- self.ui.conversation.undo()
- self.ui.conversation.append(self.conversation_history[-1])
-
-
- def action_button_clicked_generate_characters(self):
- pass
-
- def action_button_clicked_clear_conversation(self):
- self.conversation_history = []
- self.ui.conversation.setText("")
- self.app.client.message = {
- "llm_request": True,
- "request_data": {
- "request_type": "clear_conversation",
- }
- }
-
- def message_type_text_changed(self, val):
- self.generator.message_type = val
- save_session()
-
- dtype_descriptions = {
- "2bit": "Fastest, least amount of VRAM, GPU only, least accurate results.",
- "4bit": "Faster, much less VRAM, GPU only, much less accurate results.",
- "8bit": "Fast, less VRAM, GPU only, less accurate results.",
- "16bit": "Normal speed, some VRAM, uses GPU, slightly less accurate results.",
- "32bit": "Slow, no VRAM, uses CPU, most accurate results.",
- }
-
- def toggled_2bit(self, val):
- if val:
- self.set_dtype("2bit")
-
- def toggled_4bit(self, val):
- if val:
- self.set_dtype("4bit")
-
- def toggled_8bit(self, val):
- if val:
- self.set_dtype("8bit")
-
- def toggled_16bit(self, val):
- if val:
- self.set_dtype("16bit")
-
- def toggled_32bit(self, val):
- if val:
- self.set_dtype("32bit")
-
- def set_dtype(self, dtype):
- self.generator.generator_settings[0].dtype = dtype
- save_session()
- self.set_dtype_description(dtype)
-
- def set_dtype_description(self, dtype):
- self.ui.dtype_description.setText(self.dtype_descriptions[dtype])
-
- def model_text_changed(self, val):
- self.settings_manager.set_value("current_llm_generator", val)
- self.load_generator()
- self.generator.generator_settings[0].model = val
- self.update_model_version_combobox()
- self.model_version_changed(self.ui.model_version.currentText())
- print("MODEL TEXT CHANGED")
- self.initialize_form()
-
- def update_model_version_combobox(self):
- self.ui.model_version.blockSignals(True)
- self.ui.model_version.clear()
- ai_model_paths = [model.path for model in session.query(AIModel).filter(
- AIModel.pipeline_action == self.current_generator
- )]
- self.ui.model_version.addItems(ai_model_paths)
- self.ui.model_version.blockSignals(False)
-
- def load_generator(self):
- self.generator = session.query(LLMGenerator).filter(
- LLMGenerator.name == self.current_generator
- ).first()
-
- def reset_settings_to_default_clicked(self):
- self.generator.generator_settings[0].top_p = LLMGeneratorSetting.top_p.default.arg
- self.generator.generator_settings[0].max_length = LLMGeneratorSetting.max_length.default.arg
- self.generator.generator_settings[0].repetition_penalty = LLMGeneratorSetting.repetition_penalty.default.arg
- self.generator.generator_settings[0].min_length = LLMGeneratorSetting.min_length.default.arg
- self.generator.generator_settings[0].length_penalty = LLMGeneratorSetting.length_penalty.default.arg
- self.generator.generator_settings[0].num_beams = LLMGeneratorSetting.num_beams.default.arg
- self.generator.generator_settings[0].ngram_size = LLMGeneratorSetting.ngram_size.default.arg
- self.generator.generator_settings[0].temperature = LLMGeneratorSetting.temperature.default.arg
- self.generator.generator_settings[0].sequences = LLMGeneratorSetting.sequences.default.arg
- self.generator.generator_settings[0].top_k = LLMGeneratorSetting.top_k.default.arg
- self.generator.generator_settings[0].eta_cutoff = LLMGeneratorSetting.eta_cutoff.default.arg
- self.generator.generator_settings[0].seed = LLMGeneratorSetting.seed.default.arg
- self.generator.generator_settings[0].do_sample = LLMGeneratorSetting.do_sample.default.arg
- self.generator.generator_settings[0].early_stopping = LLMGeneratorSetting.early_stopping.default.arg
- self.generator.generator_settings[0].random_seed = LLMGeneratorSetting.random_seed.default.arg
- self.generator.generator_settings[0].model_version = LLMGeneratorSetting.model_version.default.arg
- self.generator.generator_settings[0].dtype = LLMGeneratorSetting.dtype.default.arg
- self.generator.generator_settings[0].use_gpu = LLMGeneratorSetting.use_gpu.default.arg
- save_session()
- self.initialize_form()
- self.ui.top_p.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_p)
- self.ui.max_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].max_length)
- self.ui.repetition_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].repetition_penalty)
- self.ui.min_length.set_slider_and_spinbox_values(self.generator.generator_settings[0].min_length)
- self.ui.length_penalty.set_slider_and_spinbox_values(self.generator.generator_settings[0].length_penalty)
- self.ui.num_beams.set_slider_and_spinbox_values(self.generator.generator_settings[0].num_beams)
- self.ui.ngram_size.set_slider_and_spinbox_values(self.generator.generator_settings[0].ngram_size)
- self.ui.temperature.set_slider_and_spinbox_values(self.generator.generator_settings[0].temperature)
- self.ui.sequences.set_slider_and_spinbox_values(self.generator.generator_settings[0].sequences)
- self.ui.top_k.set_slider_and_spinbox_values(self.generator.generator_settings[0].top_k)
- self.ui.eta_cutoff.set_slider_and_spinbox_values(self.generator.generator_settings[0].eta_cutoff)
- self.ui.random_seed.setChecked(self.generator.generator_settings[0].random_seed)
-
- def use_gpu_toggled(self, val):
- self.generator.generator_settings[0].use_gpu = val
- # toggle the 16bit radio button and disable 4bit and 8bit radio buttons
- self.set_dtype_by_gpu(val)
- save_session()
-
- def set_dtype_by_gpu(self, use_gpu):
- if not use_gpu:
- self.ui.radio_button_2bit.setEnabled(False)
- self.ui.radio_button_4bit.setEnabled(False)
- self.ui.radio_button_8bit.setEnabled(False)
- self.ui.radio_button_32bit.setEnabled(True)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
- if self.generator.generator_settings[0].dtype in ["4bit", "8bit"]:
- self.ui.radio_button_16bit.setChecked(True)
- else:
- self.ui.radio_button_2bit.setEnabled(True)
- self.ui.radio_button_4bit.setEnabled(True)
- self.ui.radio_button_8bit.setEnabled(True)
- self.ui.radio_button_32bit.setEnabled(False)
- if self.generator.generator_settings[0].dtype == "32bit":
- self.ui.radio_button_16bit.setChecked(True)
-
- def override_parameters_toggled(self, val):
- self.generator.override_parameters = val
- save_session()
-
- def toggle_leave_model_in_vram(self, val):
- print(val)
- if val:
- self.settings_manager.set_value("unload_unused_model", False)
- self.settings_manager.set_value("move_unused_model_to_cpu", False)
-
- def toggle_move_model_to_cpu(self, val):
- self.settings_manager.set_value("move_unused_model_to_cpu", val)
- if val:
- self.settings_manager.set_value("unload_unused_model", False)
-
- def toggle_unload_model(self, val):
- self.settings_manager.set_value("unload_unused_model", val)
- if val:
- self.settings_manager.set_value("move_unused_model_to_cpu", False)
-
- def set_tab(self, tab_name):
- index = self.ui.tabWidget.indexOf(self.ui.tabWidget.findChild(QWidget, tab_name))
- self.ui.tabWidget.setCurrentIndex(index)
+ # After the app is loaded, initialize other widgets
+ self.app.loaded.connect(self.initialize)
\ No newline at end of file
diff --git a/src/airunner/widgets/llm/message_widget.py b/src/airunner/widgets/llm/message_widget.py
index b12075a0e..7a918b4eb 100644
--- a/src/airunner/widgets/llm/message_widget.py
+++ b/src/airunner/widgets/llm/message_widget.py
@@ -1,19 +1,25 @@
from airunner.widgets.base_widget import BaseWidget
-from airunner.widgets.llm.templates.message_ui import Ui_message_widget
+from airunner.widgets.llm.templates.message_ui import Ui_message
class MessageWidget(BaseWidget):
- widget_class_ = Ui_message_widget
+ widget_class_ = Ui_message
def __init__(self, *args, **kwargs):
self.is_bot = kwargs.pop("is_bot")
self.message = kwargs.pop("message")
super().__init__(*args, **kwargs)
- self.ui.name.setText(f"{self.message.name}:")
- self.ui.message.setPlainText(self.message.message)
+ self.ui.content.setPlainText(self.message.message)
+ name = self.message.name
if self.is_bot:
- self.ui.name.setStyleSheet("font-weight: normal;")
+ self.ui.bot_name.show()
+ self.ui.bot_name.setText(f"{name}")
+ self.ui.bot_name.setStyleSheet("font-weight: normal;")
+ self.ui.user_name.hide()
else:
- self.ui.name.setStyleSheet("font-weight: bold;")
+ self.ui.user_name.show()
+ self.ui.user_name.setText(f"{name}")
+ self.ui.user_name.setStyleSheet("font-weight: normal;")
+ self.ui.bot_name.hide()
- self.ui.message.setStyleSheet("color: #f2f2f2;")
+ self.ui.content.setStyleSheet("color: #f2f2f2;")
diff --git a/src/airunner/widgets/llm/templates/chat_prompt.ui b/src/airunner/widgets/llm/templates/chat_prompt.ui
new file mode 100644
index 000000000..f9017ca49
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/chat_prompt.ui
@@ -0,0 +1,179 @@
+
+
+ chat_prompt
+
+
+
+ 0
+ 0
+ 734
+ 1077
+
+
+
+ Form
+
+
+ -
+
+
+ true
+
+
+
+ -
+
+
+ true
+
+
+
+
+ 0
+ 0
+ 714
+ 492
+
+
+
+
+
+
+ -
+
+
+
+ 0
+
+
+ 0
+
+
+ 0
+
+
+ 0
+
+
-
+
+
+ Chat input
+
+
+
+ -
+
+
-
+
+ Chat
+
+
+ -
+
+ Narrate
+
+
+ -
+
+ Generate Image
+
+
+ -
+
+ Summarize
+
+
+ -
+
+ Translate
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+ Send
+
+
+
+ -
+
+
+ 0
+
+
+
+ -
+
+
+ New Conversation
+
+
+
+
+
+
+
+
+
+
+ send_button
+ clicked()
+ chat_prompt
+ action_button_clicked_send()
+
+
+ 75
+ 1057
+
+
+ 211
+ -12
+
+
+
+
+ clear_conversatiion_button
+ clicked()
+ chat_prompt
+ action_button_clicked_clear_conversation()
+
+
+ 657
+ 1052
+
+
+ 565
+ -13
+
+
+
+
+ comboBox
+ currentTextChanged(QString)
+ chat_prompt
+ message_type_text_changed(QString)
+
+
+ 634
+ 1018
+
+
+ 465
+ -3
+
+
+
+
+
+ action_button_clicked_send()
+ action_button_clicked_clear_conversation()
+ message_type_text_changed(QString)
+
+
diff --git a/src/airunner/widgets/llm/templates/chat_prompt_ui.py b/src/airunner/widgets/llm/templates/chat_prompt_ui.py
new file mode 100644
index 000000000..cba47dece
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/chat_prompt_ui.py
@@ -0,0 +1,79 @@
+# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/chat_prompt.ui'
+#
+# Created by: PyQt6 UI code generator 6.4.2
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic6 is
+# run again. Do not edit this file unless you know what you are doing.
+
+
+from PyQt6 import QtCore, QtGui, QtWidgets
+
+
+class Ui_chat_prompt(object):
+ def setupUi(self, chat_prompt):
+ chat_prompt.setObjectName("chat_prompt")
+ chat_prompt.resize(734, 1077)
+ self.gridLayout = QtWidgets.QGridLayout(chat_prompt)
+ self.gridLayout.setObjectName("gridLayout")
+ self.conversation = QtWidgets.QTextEdit(parent=chat_prompt)
+ self.conversation.setReadOnly(True)
+ self.conversation.setObjectName("conversation")
+ self.gridLayout.addWidget(self.conversation, 0, 0, 1, 1)
+ self.chat_container = QtWidgets.QScrollArea(parent=chat_prompt)
+ self.chat_container.setWidgetResizable(True)
+ self.chat_container.setObjectName("chat_container")
+ self.scrollAreaWidgetContents = QtWidgets.QWidget()
+ self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 714, 492))
+ self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents")
+ self.verticalLayout = QtWidgets.QVBoxLayout(self.scrollAreaWidgetContents)
+ self.verticalLayout.setObjectName("verticalLayout")
+ self.chat_container.setWidget(self.scrollAreaWidgetContents)
+ self.gridLayout.addWidget(self.chat_container, 1, 0, 1, 1)
+ self.widget_5 = QtWidgets.QWidget(parent=chat_prompt)
+ self.widget_5.setObjectName("widget_5")
+ self.horizontalLayout = QtWidgets.QHBoxLayout(self.widget_5)
+ self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
+ self.horizontalLayout.setObjectName("horizontalLayout")
+ self.prompt = QtWidgets.QLineEdit(parent=self.widget_5)
+ self.prompt.setObjectName("prompt")
+ self.horizontalLayout.addWidget(self.prompt)
+ self.comboBox = QtWidgets.QComboBox(parent=self.widget_5)
+ self.comboBox.setObjectName("comboBox")
+ self.comboBox.addItem("")
+ self.comboBox.addItem("")
+ self.comboBox.addItem("")
+ self.comboBox.addItem("")
+ self.comboBox.addItem("")
+ self.horizontalLayout.addWidget(self.comboBox)
+ self.gridLayout.addWidget(self.widget_5, 2, 0, 1, 1)
+ self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_2.setObjectName("horizontalLayout_2")
+ self.send_button = QtWidgets.QPushButton(parent=chat_prompt)
+ self.send_button.setObjectName("send_button")
+ self.horizontalLayout_2.addWidget(self.send_button)
+ self.progressBar = QtWidgets.QProgressBar(parent=chat_prompt)
+ self.progressBar.setProperty("value", 0)
+ self.progressBar.setObjectName("progressBar")
+ self.horizontalLayout_2.addWidget(self.progressBar)
+ self.clear_conversatiion_button = QtWidgets.QPushButton(parent=chat_prompt)
+ self.clear_conversatiion_button.setObjectName("clear_conversatiion_button")
+ self.horizontalLayout_2.addWidget(self.clear_conversatiion_button)
+ self.gridLayout.addLayout(self.horizontalLayout_2, 3, 0, 1, 1)
+
+ self.retranslateUi(chat_prompt)
+ self.send_button.clicked.connect(chat_prompt.action_button_clicked_send) # type: ignore
+ self.clear_conversatiion_button.clicked.connect(chat_prompt.action_button_clicked_clear_conversation) # type: ignore
+ self.comboBox.currentTextChanged['QString'].connect(chat_prompt.message_type_text_changed) # type: ignore
+ QtCore.QMetaObject.connectSlotsByName(chat_prompt)
+
+ def retranslateUi(self, chat_prompt):
+ _translate = QtCore.QCoreApplication.translate
+ chat_prompt.setWindowTitle(_translate("chat_prompt", "Form"))
+ self.prompt.setPlaceholderText(_translate("chat_prompt", "Chat input"))
+ self.comboBox.setItemText(0, _translate("chat_prompt", "Chat"))
+ self.comboBox.setItemText(1, _translate("chat_prompt", "Narrate"))
+ self.comboBox.setItemText(2, _translate("chat_prompt", "Generate Image"))
+ self.comboBox.setItemText(3, _translate("chat_prompt", "Summarize"))
+ self.comboBox.setItemText(4, _translate("chat_prompt", "Translate"))
+ self.send_button.setText(_translate("chat_prompt", "Send"))
+ self.clear_conversatiion_button.setText(_translate("chat_prompt", "New Conversation"))
diff --git a/src/airunner/widgets/llm/templates/llm_preferences.ui b/src/airunner/widgets/llm/templates/llm_preferences.ui
new file mode 100644
index 000000000..a1f768413
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/llm_preferences.ui
@@ -0,0 +1,275 @@
+
+
+ llm_preferences_widget
+
+
+
+ 0
+ 0
+ 1371
+ 957
+
+
+
+ Form
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ Prefix
+
+
+
-
+
+
+
+
+
+
+ Suffix
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Bot details
+
+
+
-
+
+
-
+
+
+ Name
+
+
+
+ -
+
+
+ ChatAI
+
+
+ 6
+
+
+
+
+
+ -
+
+
-
+
+
+ Personality
+
+
+
+ -
+
+
-
+
+ Nice
+
+
+ -
+
+ Mean
+
+
+ -
+
+ Weird
+
+
+ -
+
+ Insane
+
+
+ -
+
+ Random
+
+
+
+
+
+
+
+
+
+ -
+
+
+ User name
+
+
+
-
+
+
+ User
+
+
+
+
+
+
+ -
+
+
+ Generate Characters
+
+
+
+
+
+
+ prefix
+ suffix
+ botname
+ username
+
+
+
+
+ generate_characters_button
+ clicked()
+ llm_preferences_widget
+ action_button_clicked_generate_characters()
+
+
+ 362
+ 962
+
+
+ 477
+ 0
+
+
+
+
+ username
+ textEdited(QString)
+ llm_preferences_widget
+ username_text_changed(QString)
+
+
+ 226
+ 919
+
+
+ 118
+ 0
+
+
+
+
+ botname
+ textEdited(QString)
+ llm_preferences_widget
+ botname_text_changed(QString)
+
+
+ 298
+ 843
+
+
+ 80
+ 0
+
+
+
+
+ personality_type
+ textHighlighted(QString)
+ llm_preferences_widget
+ personality_type_changed(QString)
+
+
+ 597
+ 843
+
+
+ 412
+ 0
+
+
+
+
+ suffix
+ textChanged()
+ llm_preferences_widget
+ suffix_text_changed()
+
+
+ 182
+ 669
+
+
+ 206
+ 5
+
+
+
+
+ prefix
+ textChanged()
+ llm_preferences_widget
+ prefix_text_changed()
+
+
+ 131
+ 108
+
+
+ 48
+ 6
+
+
+
+
+
+ prompt_text_changed(QString)
+ botname_text_changed(QString)
+ username_text_changed(QString)
+ model_version_changed(QString)
+ random_seed_toggled(bool)
+ response_text_changed(QString)
+ seed_changed(QString)
+ suffix_text_changed()
+ prefix_text_changed()
+ action_button_clicked_send()
+ action_button_clicked_generate_characters()
+ action_button_clicked_clear_conversation()
+ message_type_text_changed(QString)
+ personality_type_changed(QString)
+ toggled_4bit(bool)
+ toggled_8bit(bool)
+ toggled_16bit(bool)
+ toggled_32bit(bool)
+ do_sample_toggled(bool)
+ early_stopping_toggled(bool)
+ model_text_changed(QString)
+ reset_settings_to_default_clicked()
+ use_gpu_toggled(bool)
+ override_parameters_toggled(bool)
+ toggled_2bit(bool)
+ toggle_leave_model_in_vram(bool)
+ toggle_unload_model(bool)
+ toggle_move_model_to_cpu(bool)
+ prompt_template_text_changed(QString)
+
+
diff --git a/src/airunner/widgets/llm/templates/llm_preferences_ui.py b/src/airunner/widgets/llm/templates/llm_preferences_ui.py
new file mode 100644
index 000000000..7b1ca09a6
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/llm_preferences_ui.py
@@ -0,0 +1,105 @@
+# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/llm_preferences.ui'
+#
+# Created by: PyQt6 UI code generator 6.4.2
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic6 is
+# run again. Do not edit this file unless you know what you are doing.
+
+
+from PyQt6 import QtCore, QtGui, QtWidgets
+
+
+class Ui_llm_preferences_widget(object):
+ def setupUi(self, llm_preferences_widget):
+ llm_preferences_widget.setObjectName("llm_preferences_widget")
+ llm_preferences_widget.resize(1371, 957)
+ self.gridLayout = QtWidgets.QGridLayout(llm_preferences_widget)
+ self.gridLayout.setObjectName("gridLayout")
+ self.splitter = QtWidgets.QSplitter(parent=llm_preferences_widget)
+ self.splitter.setOrientation(QtCore.Qt.Orientation.Vertical)
+ self.splitter.setObjectName("splitter")
+ self.groupBox = QtWidgets.QGroupBox(parent=self.splitter)
+ self.groupBox.setObjectName("groupBox")
+ self.gridLayout_3 = QtWidgets.QGridLayout(self.groupBox)
+ self.gridLayout_3.setObjectName("gridLayout_3")
+ self.prefix = QtWidgets.QPlainTextEdit(parent=self.groupBox)
+ self.prefix.setObjectName("prefix")
+ self.gridLayout_3.addWidget(self.prefix, 0, 0, 1, 1)
+ self.groupBox_2 = QtWidgets.QGroupBox(parent=self.splitter)
+ self.groupBox_2.setObjectName("groupBox_2")
+ self.gridLayout_4 = QtWidgets.QGridLayout(self.groupBox_2)
+ self.gridLayout_4.setObjectName("gridLayout_4")
+ self.suffix = QtWidgets.QPlainTextEdit(parent=self.groupBox_2)
+ self.suffix.setObjectName("suffix")
+ self.gridLayout_4.addWidget(self.suffix, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.splitter, 0, 0, 1, 1)
+ self.groupBox_3 = QtWidgets.QGroupBox(parent=llm_preferences_widget)
+ self.groupBox_3.setObjectName("groupBox_3")
+ self.horizontalLayout_5 = QtWidgets.QHBoxLayout(self.groupBox_3)
+ self.horizontalLayout_5.setObjectName("horizontalLayout_5")
+ self.verticalLayout_2 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_2.setObjectName("verticalLayout_2")
+ self.label = QtWidgets.QLabel(parent=self.groupBox_3)
+ self.label.setObjectName("label")
+ self.verticalLayout_2.addWidget(self.label)
+ self.botname = QtWidgets.QLineEdit(parent=self.groupBox_3)
+ self.botname.setCursorPosition(6)
+ self.botname.setObjectName("botname")
+ self.verticalLayout_2.addWidget(self.botname)
+ self.horizontalLayout_5.addLayout(self.verticalLayout_2)
+ self.verticalLayout_3 = QtWidgets.QVBoxLayout()
+ self.verticalLayout_3.setObjectName("verticalLayout_3")
+ self.label_2 = QtWidgets.QLabel(parent=self.groupBox_3)
+ self.label_2.setObjectName("label_2")
+ self.verticalLayout_3.addWidget(self.label_2)
+ self.personality_type = QtWidgets.QComboBox(parent=self.groupBox_3)
+ self.personality_type.setObjectName("personality_type")
+ self.personality_type.addItem("")
+ self.personality_type.addItem("")
+ self.personality_type.addItem("")
+ self.personality_type.addItem("")
+ self.personality_type.addItem("")
+ self.verticalLayout_3.addWidget(self.personality_type)
+ self.horizontalLayout_5.addLayout(self.verticalLayout_3)
+ self.gridLayout.addWidget(self.groupBox_3, 1, 0, 1, 1)
+ self.groupBox_4 = QtWidgets.QGroupBox(parent=llm_preferences_widget)
+ self.groupBox_4.setObjectName("groupBox_4")
+ self.gridLayout_13 = QtWidgets.QGridLayout(self.groupBox_4)
+ self.gridLayout_13.setObjectName("gridLayout_13")
+ self.username = QtWidgets.QLineEdit(parent=self.groupBox_4)
+ self.username.setObjectName("username")
+ self.gridLayout_13.addWidget(self.username, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.groupBox_4, 2, 0, 1, 1)
+ self.generate_characters_button = QtWidgets.QPushButton(parent=llm_preferences_widget)
+ self.generate_characters_button.setObjectName("generate_characters_button")
+ self.gridLayout.addWidget(self.generate_characters_button, 3, 0, 1, 1)
+
+ self.retranslateUi(llm_preferences_widget)
+ self.generate_characters_button.clicked.connect(llm_preferences_widget.action_button_clicked_generate_characters) # type: ignore
+ self.username.textEdited['QString'].connect(llm_preferences_widget.username_text_changed) # type: ignore
+ self.botname.textEdited['QString'].connect(llm_preferences_widget.botname_text_changed) # type: ignore
+ self.personality_type.textHighlighted['QString'].connect(llm_preferences_widget.personality_type_changed) # type: ignore
+ self.suffix.textChanged.connect(llm_preferences_widget.suffix_text_changed) # type: ignore
+ self.prefix.textChanged.connect(llm_preferences_widget.prefix_text_changed) # type: ignore
+ QtCore.QMetaObject.connectSlotsByName(llm_preferences_widget)
+ llm_preferences_widget.setTabOrder(self.prefix, self.suffix)
+ llm_preferences_widget.setTabOrder(self.suffix, self.botname)
+ llm_preferences_widget.setTabOrder(self.botname, self.username)
+
+ def retranslateUi(self, llm_preferences_widget):
+ _translate = QtCore.QCoreApplication.translate
+ llm_preferences_widget.setWindowTitle(_translate("llm_preferences_widget", "Form"))
+ self.groupBox.setTitle(_translate("llm_preferences_widget", "Prefix"))
+ self.groupBox_2.setTitle(_translate("llm_preferences_widget", "Suffix"))
+ self.groupBox_3.setTitle(_translate("llm_preferences_widget", "Bot details"))
+ self.label.setText(_translate("llm_preferences_widget", "Name"))
+ self.botname.setText(_translate("llm_preferences_widget", "ChatAI"))
+ self.label_2.setText(_translate("llm_preferences_widget", "Personality"))
+ self.personality_type.setItemText(0, _translate("llm_preferences_widget", "Nice"))
+ self.personality_type.setItemText(1, _translate("llm_preferences_widget", "Mean"))
+ self.personality_type.setItemText(2, _translate("llm_preferences_widget", "Weird"))
+ self.personality_type.setItemText(3, _translate("llm_preferences_widget", "Insane"))
+ self.personality_type.setItemText(4, _translate("llm_preferences_widget", "Random"))
+ self.groupBox_4.setTitle(_translate("llm_preferences_widget", "User name"))
+ self.username.setText(_translate("llm_preferences_widget", "User"))
+ self.generate_characters_button.setText(_translate("llm_preferences_widget", "Generate Characters"))
diff --git a/src/airunner/widgets/llm/templates/llm_settings.ui b/src/airunner/widgets/llm/templates/llm_settings.ui
new file mode 100644
index 000000000..e7129ede0
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/llm_settings.ui
@@ -0,0 +1,1088 @@
+
+
+ llm_settings_widget
+
+
+
+ 0
+ 0
+ 1262
+ 1059
+
+
+
+ Form
+
+
+ -
+
+
+ Model Type
+
+
+
-
+
+
+
+
+
+ -
+
+
+ Model Version
+
+
+
-
+
+
+ -1
+
+
+
+
+
+
+ -
+
+
+ Prompt Template
+
+
+
-
+
+
+
+
+
+ -
+
+
+ DType
+
+
+
-
+
+
-
+
+
+ 2-bit
+
+
+
+ -
+
+
+ 4-bit
+
+
+
+ -
+
+
+ 8-bit
+
+
+
+ -
+
+
+ 16-bit
+
+
+
+ -
+
+
+ 32-bit
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+
+
+ -
+
+
+
+ 9
+
+
+
+ Description
+
+
+
+ -
+
+
+ Use GPU
+
+
+
+
+
+
+ -
+
+
+ Override Prameters
+
+
+ true
+
+
+ true
+
+
+
-
+
+
+ Reset Settings to Default
+
+
+
+ -
+
+
+ Seed
+
+
+
-
+
+
+ -
+
+
+ Random seed
+
+
+
+
+
+
+ -
+
+
-
+
+
+ Early stopping
+
+
+
+ -
+
+
+ Do sample
+
+
+
+
+
+ -
+
+
-
+
+
+ Repetition penalty
+
+
+
-
+
+
+ 1
+
+
+ 10000
+
+
+ 0.010000000000000
+
+
+ 100.000000000000000
+
+
+ true
+
+
+ llm_generator_setting.repetition_penalty
+
+
+ 0
+
+
+ 1
+
+
+ 1.000000000000000
+
+
+ 10.000000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+ -
+
+
+ Min length
+
+
+
-
+
+
+ 1
+
+
+ 2556
+
+
+ 1.000000000000000
+
+
+ 2556.000000000000000
+
+
+ false
+
+
+ llm_generator_setting.min_length
+
+
+ 1
+
+
+ 2556
+
+
+ 1
+
+
+ 2556
+
+
+ handle_value_change
+
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+ Top P
+
+
+
-
+
+
+ 1
+
+
+ 100
+
+
+ 0.000000000000000
+
+
+ 1.000000000000000
+
+
+ true
+
+
+ llm_generator_setting.top_p
+
+
+ 1
+
+
+ 10
+
+
+ 0.010000000000000
+
+
+ 0.100000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+ -
+
+
+ Max length
+
+
+
-
+
+
+ 1
+
+
+ 2556
+
+
+ 1.000000000000000
+
+
+ 2556.000000000000000
+
+
+ false
+
+
+ llm_generator_setting.max_length
+
+
+ 1
+
+
+ 2556
+
+
+ 1
+
+
+ 2556
+
+
+ handle_value_change
+
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+ Leave in VRAM
+
+
+
+ -
+
+
+ Move to CPU
+
+
+
+ -
+
+
+ Unload model
+
+
+
+
+
+ -
+
+
+
+ true
+
+
+
+ Model management
+
+
+
+ -
+
+
-
+
+
+ No repeat ngram size
+
+
+
-
+
+
+ 0
+
+
+ 20
+
+
+ 0.000000000000000
+
+
+ 20.000000000000000
+
+
+ false
+
+
+ llm_generator_setting.ngram_size
+
+
+ 1
+
+
+ 1
+
+
+ 1.000000000000000
+
+
+ 1.000000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+ -
+
+
+ Temperature
+
+
+
-
+
+
+ 1
+
+
+ 20000
+
+
+ 0.000100000000000
+
+
+ 2.000000000000000
+
+
+ true
+
+
+ llm_generator_setting.temperature
+
+
+ 1
+
+
+ 10
+
+
+ 0.010000000000000
+
+
+ 0.100000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+
+
+ -
+
+
-
+
+
+ Length penalty
+
+
+
-
+
+
+ -100
+
+
+ 100
+
+
+ 0.000000000000000
+
+
+ 1.000000000000000
+
+
+ true
+
+
+ llm_generator_setting.length_penalty
+
+
+ 1
+
+
+ 10
+
+
+ 0.010000000000000
+
+
+ 0.100000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+ -
+
+
+ Num beams
+
+
+
-
+
+
+ 1
+
+
+ 100
+
+
+ 0.000000000000000
+
+
+ 100.000000000000000
+
+
+ false
+
+
+ llm_generator_setting.num_beams
+
+
+ 1
+
+
+ 10
+
+
+ 0.010000000000000
+
+
+ 0.100000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+ -
+
+
-
+
+
+ Sequences to generate
+
+
+
-
+
+
+ 1
+
+
+ 100
+
+
+ 0.000000000000000
+
+
+ 100.000000000000000
+
+
+ false
+
+
+ llm_generator_setting.sequences
+
+
+ 1
+
+
+ 10
+
+
+ 0.010000000000000
+
+
+ 0.100000000000000
+
+
+ handle_value_change
+
+
+
+
+
+
+ -
+
+
+ Top k
+
+
+
-
+
+
+ 0
+
+
+ 256
+
+
+ 0.000000000000000
+
+
+ 256.000000000000000
+
+
+ false
+
+
+ llm_generator_setting.top_k
+
+
+ 1
+
+
+ 10
+
+
+ 1
+
+
+ 10
+
+
+ handle_value_change
+
+
+
+
+
+
+
+
+ -
+
+
+
+ 9
+
+
+
+ How to treat model when not in use
+
+
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+ SliderWidget
+ QWidget
+ airunner/widgets/slider/slider_widget
+ 1
+
+
+
+ model
+ model_version
+ use_gpu_checkbox
+ radio_button_4bit
+ radio_button_8bit
+ radio_button_16bit
+ radio_button_32bit
+ seed
+ random_seed
+
+
+
+
+ radio_button_2bit
+ toggled(bool)
+ llm_settings_widget
+ toggled_2bit(bool)
+
+
+ 78
+ 318
+
+
+ 0
+ 324
+
+
+
+
+ radio_button_16bit
+ toggled(bool)
+ llm_settings_widget
+ toggled_16bit(bool)
+
+
+ 276
+ 318
+
+
+ 3
+ 136
+
+
+
+
+ use_gpu_checkbox
+ toggled(bool)
+ llm_settings_widget
+ use_gpu_toggled(bool)
+
+
+ 462
+ 288
+
+
+ 34
+ 0
+
+
+
+
+ radio_button_4bit
+ toggled(bool)
+ llm_settings_widget
+ toggled_4bit(bool)
+
+
+ 141
+ 318
+
+
+ 3
+ 43
+
+
+
+
+ radio_button_8bit
+ toggled(bool)
+ llm_settings_widget
+ toggled_8bit(bool)
+
+
+ 204
+ 318
+
+
+ 2
+ 88
+
+
+
+
+ radio_button_32bit
+ toggled(bool)
+ llm_settings_widget
+ toggled_32bit(bool)
+
+
+ 348
+ 318
+
+
+ 0
+ 179
+
+
+
+
+ override_parameters
+ toggled(bool)
+ llm_settings_widget
+ override_parameters_toggled(bool)
+
+
+ 128
+ 398
+
+
+ 120
+ 0
+
+
+
+
+ seed
+ textEdited(QString)
+ llm_settings_widget
+ seed_changed(QString)
+
+
+ 301
+ 756
+
+
+ 2
+ 382
+
+
+
+
+ early_stopping
+ toggled(bool)
+ llm_settings_widget
+ early_stopping_toggled(bool)
+
+
+ 125
+ 798
+
+
+ 1
+ 501
+
+
+
+
+ random_seed
+ toggled(bool)
+ llm_settings_widget
+ random_seed_toggled(bool)
+
+
+ 1228
+ 755
+
+
+ 3
+ 470
+
+
+
+
+ pushButton
+ clicked()
+ llm_settings_widget
+ reset_settings_to_default_clicked()
+
+
+ 311
+ 913
+
+
+ 437
+ 0
+
+
+
+
+ move_to_cpu
+ toggled(bool)
+ llm_settings_widget
+ toggle_move_model_to_cpu(bool)
+
+
+ 831
+ 881
+
+
+ 0
+ 704
+
+
+
+
+ do_sample
+ toggled(bool)
+ llm_settings_widget
+ do_sample_toggled(bool)
+
+
+ 1231
+ 798
+
+
+ 1
+ 427
+
+
+
+
+ leave_in_vram
+ toggled(bool)
+ llm_settings_widget
+ toggle_leave_model_in_vram(bool)
+
+
+ 142
+ 881
+
+
+ 0
+ 716
+
+
+
+
+ unload_model
+ toggled(bool)
+ llm_settings_widget
+ toggle_unload_model(bool)
+
+
+ 1239
+ 881
+
+
+ 0
+ 685
+
+
+
+
+ prompt_template
+ currentTextChanged(QString)
+ llm_settings_widget
+ prompt_template_text_changed(QString)
+
+
+ 90
+ 215
+
+
+ 79
+ -16
+
+
+
+
+ model
+ currentTextChanged(QString)
+ llm_settings_widget
+ model_text_changed(QString)
+
+
+ 184
+ 65
+
+
+ 117
+ 0
+
+
+
+
+ model_version
+ currentTextChanged(QString)
+ llm_settings_widget
+ model_version_changed(QString)
+
+
+ 205
+ 140
+
+
+ 272
+ 0
+
+
+
+
+
+ prompt_text_changed(QString)
+ botname_text_changed(QString)
+ username_text_changed(QString)
+ model_version_changed(QString)
+ random_seed_toggled(bool)
+ response_text_changed(QString)
+ seed_changed(QString)
+ suffix_text_changed()
+ prefix_text_changed()
+ action_button_clicked_send()
+ action_button_clicked_generate_characters()
+ action_button_clicked_clear_conversation()
+ message_type_text_changed(QString)
+ personality_type_changed(QString)
+ toggled_4bit(bool)
+ toggled_8bit(bool)
+ toggled_16bit(bool)
+ toggled_32bit(bool)
+ do_sample_toggled(bool)
+ early_stopping_toggled(bool)
+ model_text_changed(QString)
+ reset_settings_to_default_clicked()
+ use_gpu_toggled(bool)
+ override_parameters_toggled(bool)
+ toggled_2bit(bool)
+ toggle_leave_model_in_vram(bool)
+ toggle_unload_model(bool)
+ toggle_move_model_to_cpu(bool)
+ prompt_template_text_changed(QString)
+
+
diff --git a/src/airunner/widgets/llm/templates/llm_settings_ui.py b/src/airunner/widgets/llm/templates/llm_settings_ui.py
new file mode 100644
index 000000000..bfb93997b
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/llm_settings_ui.py
@@ -0,0 +1,408 @@
+# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/llm_settings.ui'
+#
+# Created by: PyQt6 UI code generator 6.4.2
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic6 is
+# run again. Do not edit this file unless you know what you are doing.
+
+
+from PyQt6 import QtCore, QtGui, QtWidgets
+
+
+class Ui_llm_settings_widget(object):
+ def setupUi(self, llm_settings_widget):
+ llm_settings_widget.setObjectName("llm_settings_widget")
+ llm_settings_widget.resize(1262, 1059)
+ self.gridLayout = QtWidgets.QGridLayout(llm_settings_widget)
+ self.gridLayout.setObjectName("gridLayout")
+ self.groupBox_7 = QtWidgets.QGroupBox(parent=llm_settings_widget)
+ self.groupBox_7.setObjectName("groupBox_7")
+ self.gridLayout_10 = QtWidgets.QGridLayout(self.groupBox_7)
+ self.gridLayout_10.setObjectName("gridLayout_10")
+ self.model = QtWidgets.QComboBox(parent=self.groupBox_7)
+ self.model.setObjectName("model")
+ self.gridLayout_10.addWidget(self.model, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.groupBox_7, 0, 0, 1, 1)
+ self.groupBox_8 = QtWidgets.QGroupBox(parent=llm_settings_widget)
+ self.groupBox_8.setObjectName("groupBox_8")
+ self.gridLayout_11 = QtWidgets.QGridLayout(self.groupBox_8)
+ self.gridLayout_11.setObjectName("gridLayout_11")
+ self.model_version = QtWidgets.QComboBox(parent=self.groupBox_8)
+ self.model_version.setObjectName("model_version")
+ self.gridLayout_11.addWidget(self.model_version, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.groupBox_8, 1, 0, 1, 1)
+ self.groupBox_14 = QtWidgets.QGroupBox(parent=llm_settings_widget)
+ self.groupBox_14.setObjectName("groupBox_14")
+ self.gridLayout_17 = QtWidgets.QGridLayout(self.groupBox_14)
+ self.gridLayout_17.setObjectName("gridLayout_17")
+ self.prompt_template = QtWidgets.QComboBox(parent=self.groupBox_14)
+ self.prompt_template.setObjectName("prompt_template")
+ self.gridLayout_17.addWidget(self.prompt_template, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.groupBox_14, 2, 0, 1, 1)
+ self.groupBox_6 = QtWidgets.QGroupBox(parent=llm_settings_widget)
+ self.groupBox_6.setObjectName("groupBox_6")
+ self.gridLayout_9 = QtWidgets.QGridLayout(self.groupBox_6)
+ self.gridLayout_9.setObjectName("gridLayout_9")
+ self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_3.setObjectName("horizontalLayout_3")
+ self.radio_button_2bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
+ self.radio_button_2bit.setObjectName("radio_button_2bit")
+ self.horizontalLayout_3.addWidget(self.radio_button_2bit)
+ self.radio_button_4bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
+ self.radio_button_4bit.setObjectName("radio_button_4bit")
+ self.horizontalLayout_3.addWidget(self.radio_button_4bit)
+ self.radio_button_8bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
+ self.radio_button_8bit.setObjectName("radio_button_8bit")
+ self.horizontalLayout_3.addWidget(self.radio_button_8bit)
+ self.radio_button_16bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
+ self.radio_button_16bit.setObjectName("radio_button_16bit")
+ self.horizontalLayout_3.addWidget(self.radio_button_16bit)
+ self.radio_button_32bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
+ self.radio_button_32bit.setObjectName("radio_button_32bit")
+ self.horizontalLayout_3.addWidget(self.radio_button_32bit)
+ spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Minimum)
+ self.horizontalLayout_3.addItem(spacerItem)
+ self.gridLayout_9.addLayout(self.horizontalLayout_3, 1, 0, 1, 1)
+ self.dtype_description = QtWidgets.QLabel(parent=self.groupBox_6)
+ font = QtGui.QFont()
+ font.setPointSize(9)
+ self.dtype_description.setFont(font)
+ self.dtype_description.setObjectName("dtype_description")
+ self.gridLayout_9.addWidget(self.dtype_description, 2, 0, 1, 1)
+ self.use_gpu_checkbox = QtWidgets.QCheckBox(parent=self.groupBox_6)
+ self.use_gpu_checkbox.setObjectName("use_gpu_checkbox")
+ self.gridLayout_9.addWidget(self.use_gpu_checkbox, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.groupBox_6, 3, 0, 1, 1)
+ self.override_parameters = QtWidgets.QGroupBox(parent=llm_settings_widget)
+ self.override_parameters.setCheckable(True)
+ self.override_parameters.setChecked(True)
+ self.override_parameters.setObjectName("override_parameters")
+ self.gridLayout_12 = QtWidgets.QGridLayout(self.override_parameters)
+ self.gridLayout_12.setObjectName("gridLayout_12")
+ self.pushButton = QtWidgets.QPushButton(parent=self.override_parameters)
+ self.pushButton.setObjectName("pushButton")
+ self.gridLayout_12.addWidget(self.pushButton, 14, 0, 1, 1)
+ self.groupBox_19 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_19.setObjectName("groupBox_19")
+ self.horizontalLayout_4 = QtWidgets.QHBoxLayout(self.groupBox_19)
+ self.horizontalLayout_4.setObjectName("horizontalLayout_4")
+ self.seed = QtWidgets.QLineEdit(parent=self.groupBox_19)
+ self.seed.setObjectName("seed")
+ self.horizontalLayout_4.addWidget(self.seed)
+ self.random_seed = QtWidgets.QCheckBox(parent=self.groupBox_19)
+ self.random_seed.setObjectName("random_seed")
+ self.horizontalLayout_4.addWidget(self.random_seed)
+ self.gridLayout_12.addWidget(self.groupBox_19, 5, 0, 1, 1)
+ self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_6.setObjectName("horizontalLayout_6")
+ self.early_stopping = QtWidgets.QCheckBox(parent=self.override_parameters)
+ self.early_stopping.setObjectName("early_stopping")
+ self.horizontalLayout_6.addWidget(self.early_stopping)
+ self.do_sample = QtWidgets.QCheckBox(parent=self.override_parameters)
+ self.do_sample.setObjectName("do_sample")
+ self.horizontalLayout_6.addWidget(self.do_sample)
+ self.gridLayout_12.addLayout(self.horizontalLayout_6, 8, 0, 1, 1)
+ self.horizontalLayout_7 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_7.setObjectName("horizontalLayout_7")
+ self.groupBox_11 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_11.setObjectName("groupBox_11")
+ self.gridLayout_19 = QtWidgets.QGridLayout(self.groupBox_11)
+ self.gridLayout_19.setObjectName("gridLayout_19")
+ self.repetition_penalty = SliderWidget(parent=self.groupBox_11)
+ self.repetition_penalty.setProperty("slider_minimum", 1)
+ self.repetition_penalty.setProperty("slider_maximum", 10000)
+ self.repetition_penalty.setProperty("spinbox_minimum", 0.01)
+ self.repetition_penalty.setProperty("spinbox_maximum", 100.0)
+ self.repetition_penalty.setProperty("display_as_float", True)
+ self.repetition_penalty.setProperty("slider_single_step", 0)
+ self.repetition_penalty.setProperty("slider_page_step", 1)
+ self.repetition_penalty.setProperty("spinbox_single_step", 1.0)
+ self.repetition_penalty.setProperty("spinbox_page_step", 10.0)
+ self.repetition_penalty.setObjectName("repetition_penalty")
+ self.gridLayout_19.addWidget(self.repetition_penalty, 0, 0, 1, 1)
+ self.horizontalLayout_7.addWidget(self.groupBox_11)
+ self.groupBox_16 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_16.setObjectName("groupBox_16")
+ self.gridLayout_20 = QtWidgets.QGridLayout(self.groupBox_16)
+ self.gridLayout_20.setObjectName("gridLayout_20")
+ self.min_length = SliderWidget(parent=self.groupBox_16)
+ self.min_length.setProperty("slider_minimum", 1)
+ self.min_length.setProperty("slider_maximum", 2556)
+ self.min_length.setProperty("spinbox_minimum", 1.0)
+ self.min_length.setProperty("spinbox_maximum", 2556.0)
+ self.min_length.setProperty("display_as_float", False)
+ self.min_length.setProperty("slider_single_step", 1)
+ self.min_length.setProperty("slider_page_step", 2556)
+ self.min_length.setProperty("spinbox_single_step", 1)
+ self.min_length.setProperty("spinbox_page_step", 2556)
+ self.min_length.setObjectName("min_length")
+ self.gridLayout_20.addWidget(self.min_length, 0, 0, 1, 1)
+ self.horizontalLayout_7.addWidget(self.groupBox_16)
+ self.gridLayout_12.addLayout(self.horizontalLayout_7, 1, 0, 1, 1)
+ self.horizontalLayout_9 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_9.setObjectName("horizontalLayout_9")
+ self.groupBox_20 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_20.setObjectName("groupBox_20")
+ self.gridLayout_24 = QtWidgets.QGridLayout(self.groupBox_20)
+ self.gridLayout_24.setObjectName("gridLayout_24")
+ self.top_p = SliderWidget(parent=self.groupBox_20)
+ self.top_p.setProperty("slider_minimum", 1)
+ self.top_p.setProperty("slider_maximum", 100)
+ self.top_p.setProperty("spinbox_minimum", 0.0)
+ self.top_p.setProperty("spinbox_maximum", 1.0)
+ self.top_p.setProperty("display_as_float", True)
+ self.top_p.setProperty("slider_single_step", 1)
+ self.top_p.setProperty("slider_page_step", 10)
+ self.top_p.setProperty("spinbox_single_step", 0.01)
+ self.top_p.setProperty("spinbox_page_step", 0.1)
+ self.top_p.setObjectName("top_p")
+ self.gridLayout_24.addWidget(self.top_p, 0, 0, 1, 1)
+ self.horizontalLayout_9.addWidget(self.groupBox_20)
+ self.groupBox_21 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_21.setObjectName("groupBox_21")
+ self.gridLayout_25 = QtWidgets.QGridLayout(self.groupBox_21)
+ self.gridLayout_25.setObjectName("gridLayout_25")
+ self.max_length = SliderWidget(parent=self.groupBox_21)
+ self.max_length.setProperty("slider_minimum", 1)
+ self.max_length.setProperty("slider_maximum", 2556)
+ self.max_length.setProperty("spinbox_minimum", 1.0)
+ self.max_length.setProperty("spinbox_maximum", 2556.0)
+ self.max_length.setProperty("display_as_float", False)
+ self.max_length.setProperty("slider_single_step", 1)
+ self.max_length.setProperty("slider_page_step", 2556)
+ self.max_length.setProperty("spinbox_single_step", 1)
+ self.max_length.setProperty("spinbox_page_step", 2556)
+ self.max_length.setObjectName("max_length")
+ self.gridLayout_25.addWidget(self.max_length, 0, 0, 1, 1)
+ self.horizontalLayout_9.addWidget(self.groupBox_21)
+ self.gridLayout_12.addLayout(self.horizontalLayout_9, 0, 0, 1, 1)
+ self.horizontalLayout_12 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_12.setObjectName("horizontalLayout_12")
+ self.leave_in_vram = QtWidgets.QRadioButton(parent=self.override_parameters)
+ self.leave_in_vram.setObjectName("leave_in_vram")
+ self.horizontalLayout_12.addWidget(self.leave_in_vram)
+ self.move_to_cpu = QtWidgets.QRadioButton(parent=self.override_parameters)
+ self.move_to_cpu.setObjectName("move_to_cpu")
+ self.horizontalLayout_12.addWidget(self.move_to_cpu)
+ self.unload_model = QtWidgets.QRadioButton(parent=self.override_parameters)
+ self.unload_model.setObjectName("unload_model")
+ self.horizontalLayout_12.addWidget(self.unload_model)
+ self.gridLayout_12.addLayout(self.horizontalLayout_12, 12, 0, 1, 1)
+ self.label_3 = QtWidgets.QLabel(parent=self.override_parameters)
+ font = QtGui.QFont()
+ font.setBold(True)
+ self.label_3.setFont(font)
+ self.label_3.setObjectName("label_3")
+ self.gridLayout_12.addWidget(self.label_3, 10, 0, 1, 1)
+ self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_8.setObjectName("horizontalLayout_8")
+ self.groupBox_17 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_17.setObjectName("groupBox_17")
+ self.gridLayout_21 = QtWidgets.QGridLayout(self.groupBox_17)
+ self.gridLayout_21.setObjectName("gridLayout_21")
+ self.ngram_size = SliderWidget(parent=self.groupBox_17)
+ self.ngram_size.setProperty("slider_minimum", 0)
+ self.ngram_size.setProperty("slider_maximum", 20)
+ self.ngram_size.setProperty("spinbox_minimum", 0.0)
+ self.ngram_size.setProperty("spinbox_maximum", 20.0)
+ self.ngram_size.setProperty("display_as_float", False)
+ self.ngram_size.setProperty("slider_single_step", 1)
+ self.ngram_size.setProperty("slider_page_step", 1)
+ self.ngram_size.setProperty("spinbox_single_step", 1.0)
+ self.ngram_size.setProperty("spinbox_page_step", 1.0)
+ self.ngram_size.setObjectName("ngram_size")
+ self.gridLayout_21.addWidget(self.ngram_size, 0, 0, 1, 1)
+ self.horizontalLayout_8.addWidget(self.groupBox_17)
+ self.groupBox_18 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_18.setObjectName("groupBox_18")
+ self.gridLayout_22 = QtWidgets.QGridLayout(self.groupBox_18)
+ self.gridLayout_22.setObjectName("gridLayout_22")
+ self.temperature = SliderWidget(parent=self.groupBox_18)
+ self.temperature.setProperty("slider_minimum", 1)
+ self.temperature.setProperty("slider_maximum", 20000)
+ self.temperature.setProperty("spinbox_minimum", 0.0001)
+ self.temperature.setProperty("spinbox_maximum", 2.0)
+ self.temperature.setProperty("display_as_float", True)
+ self.temperature.setProperty("slider_single_step", 1)
+ self.temperature.setProperty("slider_page_step", 10)
+ self.temperature.setProperty("spinbox_single_step", 0.01)
+ self.temperature.setProperty("spinbox_page_step", 0.1)
+ self.temperature.setObjectName("temperature")
+ self.gridLayout_22.addWidget(self.temperature, 0, 0, 1, 1)
+ self.horizontalLayout_8.addWidget(self.groupBox_18)
+ self.gridLayout_12.addLayout(self.horizontalLayout_8, 3, 0, 1, 1)
+ self.horizontalLayout_11 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_11.setObjectName("horizontalLayout_11")
+ self.groupBox_24 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_24.setObjectName("groupBox_24")
+ self.gridLayout_28 = QtWidgets.QGridLayout(self.groupBox_24)
+ self.gridLayout_28.setObjectName("gridLayout_28")
+ self.length_penalty = SliderWidget(parent=self.groupBox_24)
+ self.length_penalty.setProperty("slider_minimum", -100)
+ self.length_penalty.setProperty("slider_maximum", 100)
+ self.length_penalty.setProperty("spinbox_minimum", 0.0)
+ self.length_penalty.setProperty("spinbox_maximum", 1.0)
+ self.length_penalty.setProperty("display_as_float", True)
+ self.length_penalty.setProperty("slider_single_step", 1)
+ self.length_penalty.setProperty("slider_page_step", 10)
+ self.length_penalty.setProperty("spinbox_single_step", 0.01)
+ self.length_penalty.setProperty("spinbox_page_step", 0.1)
+ self.length_penalty.setObjectName("length_penalty")
+ self.gridLayout_28.addWidget(self.length_penalty, 0, 0, 1, 1)
+ self.horizontalLayout_11.addWidget(self.groupBox_24)
+ self.groupBox_25 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_25.setObjectName("groupBox_25")
+ self.gridLayout_29 = QtWidgets.QGridLayout(self.groupBox_25)
+ self.gridLayout_29.setObjectName("gridLayout_29")
+ self.num_beams = SliderWidget(parent=self.groupBox_25)
+ self.num_beams.setProperty("slider_minimum", 1)
+ self.num_beams.setProperty("slider_maximum", 100)
+ self.num_beams.setProperty("spinbox_minimum", 0.0)
+ self.num_beams.setProperty("spinbox_maximum", 100.0)
+ self.num_beams.setProperty("display_as_float", False)
+ self.num_beams.setProperty("slider_single_step", 1)
+ self.num_beams.setProperty("slider_page_step", 10)
+ self.num_beams.setProperty("spinbox_single_step", 0.01)
+ self.num_beams.setProperty("spinbox_page_step", 0.1)
+ self.num_beams.setObjectName("num_beams")
+ self.gridLayout_29.addWidget(self.num_beams, 0, 0, 1, 1)
+ self.horizontalLayout_11.addWidget(self.groupBox_25)
+ self.gridLayout_12.addLayout(self.horizontalLayout_11, 2, 0, 1, 1)
+ self.line = QtWidgets.QFrame(parent=self.override_parameters)
+ self.line.setFrameShape(QtWidgets.QFrame.Shape.HLine)
+ self.line.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
+ self.line.setObjectName("line")
+ self.gridLayout_12.addWidget(self.line, 9, 0, 1, 1)
+ self.horizontalLayout_10 = QtWidgets.QHBoxLayout()
+ self.horizontalLayout_10.setObjectName("horizontalLayout_10")
+ self.groupBox_22 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_22.setObjectName("groupBox_22")
+ self.gridLayout_26 = QtWidgets.QGridLayout(self.groupBox_22)
+ self.gridLayout_26.setObjectName("gridLayout_26")
+ self.sequences = SliderWidget(parent=self.groupBox_22)
+ self.sequences.setProperty("slider_minimum", 1)
+ self.sequences.setProperty("slider_maximum", 100)
+ self.sequences.setProperty("spinbox_minimum", 0.0)
+ self.sequences.setProperty("spinbox_maximum", 100.0)
+ self.sequences.setProperty("display_as_float", False)
+ self.sequences.setProperty("slider_single_step", 1)
+ self.sequences.setProperty("slider_page_step", 10)
+ self.sequences.setProperty("spinbox_single_step", 0.01)
+ self.sequences.setProperty("spinbox_page_step", 0.1)
+ self.sequences.setObjectName("sequences")
+ self.gridLayout_26.addWidget(self.sequences, 0, 0, 1, 1)
+ self.horizontalLayout_10.addWidget(self.groupBox_22)
+ self.groupBox_23 = QtWidgets.QGroupBox(parent=self.override_parameters)
+ self.groupBox_23.setObjectName("groupBox_23")
+ self.gridLayout_27 = QtWidgets.QGridLayout(self.groupBox_23)
+ self.gridLayout_27.setObjectName("gridLayout_27")
+ self.top_k = SliderWidget(parent=self.groupBox_23)
+ self.top_k.setProperty("slider_minimum", 0)
+ self.top_k.setProperty("slider_maximum", 256)
+ self.top_k.setProperty("spinbox_minimum", 0.0)
+ self.top_k.setProperty("spinbox_maximum", 256.0)
+ self.top_k.setProperty("display_as_float", False)
+ self.top_k.setProperty("slider_single_step", 1)
+ self.top_k.setProperty("slider_page_step", 10)
+ self.top_k.setProperty("spinbox_single_step", 1)
+ self.top_k.setProperty("spinbox_page_step", 10)
+ self.top_k.setObjectName("top_k")
+ self.gridLayout_27.addWidget(self.top_k, 0, 0, 1, 1)
+ self.horizontalLayout_10.addWidget(self.groupBox_23)
+ self.gridLayout_12.addLayout(self.horizontalLayout_10, 4, 0, 1, 1)
+ self.label_4 = QtWidgets.QLabel(parent=self.override_parameters)
+ font = QtGui.QFont()
+ font.setPointSize(9)
+ self.label_4.setFont(font)
+ self.label_4.setObjectName("label_4")
+ self.gridLayout_12.addWidget(self.label_4, 11, 0, 1, 1)
+ self.gridLayout.addWidget(self.override_parameters, 4, 0, 1, 1)
+ spacerItem1 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding)
+ self.gridLayout.addItem(spacerItem1, 5, 0, 1, 1)
+
+ self.retranslateUi(llm_settings_widget)
+ self.model_version.setCurrentIndex(-1)
+ self.radio_button_2bit.toggled['bool'].connect(llm_settings_widget.toggled_2bit) # type: ignore
+ self.radio_button_16bit.toggled['bool'].connect(llm_settings_widget.toggled_16bit) # type: ignore
+ self.use_gpu_checkbox.toggled['bool'].connect(llm_settings_widget.use_gpu_toggled) # type: ignore
+ self.radio_button_4bit.toggled['bool'].connect(llm_settings_widget.toggled_4bit) # type: ignore
+ self.radio_button_8bit.toggled['bool'].connect(llm_settings_widget.toggled_8bit) # type: ignore
+ self.radio_button_32bit.toggled['bool'].connect(llm_settings_widget.toggled_32bit) # type: ignore
+ self.override_parameters.toggled['bool'].connect(llm_settings_widget.override_parameters_toggled) # type: ignore
+ self.seed.textEdited['QString'].connect(llm_settings_widget.seed_changed) # type: ignore
+ self.early_stopping.toggled['bool'].connect(llm_settings_widget.early_stopping_toggled) # type: ignore
+ self.random_seed.toggled['bool'].connect(llm_settings_widget.random_seed_toggled) # type: ignore
+ self.pushButton.clicked.connect(llm_settings_widget.reset_settings_to_default_clicked) # type: ignore
+ self.move_to_cpu.toggled['bool'].connect(llm_settings_widget.toggle_move_model_to_cpu) # type: ignore
+ self.do_sample.toggled['bool'].connect(llm_settings_widget.do_sample_toggled) # type: ignore
+ self.leave_in_vram.toggled['bool'].connect(llm_settings_widget.toggle_leave_model_in_vram) # type: ignore
+ self.unload_model.toggled['bool'].connect(llm_settings_widget.toggle_unload_model) # type: ignore
+ self.prompt_template.currentTextChanged['QString'].connect(llm_settings_widget.prompt_template_text_changed) # type: ignore
+ self.model.currentTextChanged['QString'].connect(llm_settings_widget.model_text_changed) # type: ignore
+ self.model_version.currentTextChanged['QString'].connect(llm_settings_widget.model_version_changed) # type: ignore
+ QtCore.QMetaObject.connectSlotsByName(llm_settings_widget)
+ llm_settings_widget.setTabOrder(self.model, self.model_version)
+ llm_settings_widget.setTabOrder(self.model_version, self.use_gpu_checkbox)
+ llm_settings_widget.setTabOrder(self.use_gpu_checkbox, self.radio_button_4bit)
+ llm_settings_widget.setTabOrder(self.radio_button_4bit, self.radio_button_8bit)
+ llm_settings_widget.setTabOrder(self.radio_button_8bit, self.radio_button_16bit)
+ llm_settings_widget.setTabOrder(self.radio_button_16bit, self.radio_button_32bit)
+ llm_settings_widget.setTabOrder(self.radio_button_32bit, self.seed)
+ llm_settings_widget.setTabOrder(self.seed, self.random_seed)
+
+ def retranslateUi(self, llm_settings_widget):
+ _translate = QtCore.QCoreApplication.translate
+ llm_settings_widget.setWindowTitle(_translate("llm_settings_widget", "Form"))
+ self.groupBox_7.setTitle(_translate("llm_settings_widget", "Model Type"))
+ self.groupBox_8.setTitle(_translate("llm_settings_widget", "Model Version"))
+ self.groupBox_14.setTitle(_translate("llm_settings_widget", "Prompt Template"))
+ self.groupBox_6.setTitle(_translate("llm_settings_widget", "DType"))
+ self.radio_button_2bit.setText(_translate("llm_settings_widget", "2-bit"))
+ self.radio_button_4bit.setText(_translate("llm_settings_widget", "4-bit"))
+ self.radio_button_8bit.setText(_translate("llm_settings_widget", "8-bit"))
+ self.radio_button_16bit.setText(_translate("llm_settings_widget", "16-bit"))
+ self.radio_button_32bit.setText(_translate("llm_settings_widget", "32-bit"))
+ self.dtype_description.setText(_translate("llm_settings_widget", "Description"))
+ self.use_gpu_checkbox.setText(_translate("llm_settings_widget", "Use GPU"))
+ self.override_parameters.setTitle(_translate("llm_settings_widget", "Override Prameters"))
+ self.pushButton.setText(_translate("llm_settings_widget", "Reset Settings to Default"))
+ self.groupBox_19.setTitle(_translate("llm_settings_widget", "Seed"))
+ self.random_seed.setText(_translate("llm_settings_widget", "Random seed"))
+ self.early_stopping.setText(_translate("llm_settings_widget", "Early stopping"))
+ self.do_sample.setText(_translate("llm_settings_widget", "Do sample"))
+ self.groupBox_11.setTitle(_translate("llm_settings_widget", "Repetition penalty"))
+ self.repetition_penalty.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.repetition_penalty"))
+ self.repetition_penalty.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_16.setTitle(_translate("llm_settings_widget", "Min length"))
+ self.min_length.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.min_length"))
+ self.min_length.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_20.setTitle(_translate("llm_settings_widget", "Top P"))
+ self.top_p.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.top_p"))
+ self.top_p.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_21.setTitle(_translate("llm_settings_widget", "Max length"))
+ self.max_length.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.max_length"))
+ self.max_length.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.leave_in_vram.setText(_translate("llm_settings_widget", "Leave in VRAM"))
+ self.move_to_cpu.setText(_translate("llm_settings_widget", "Move to CPU"))
+ self.unload_model.setText(_translate("llm_settings_widget", "Unload model"))
+ self.label_3.setText(_translate("llm_settings_widget", "Model management"))
+ self.groupBox_17.setTitle(_translate("llm_settings_widget", "No repeat ngram size"))
+ self.ngram_size.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.ngram_size"))
+ self.ngram_size.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_18.setTitle(_translate("llm_settings_widget", "Temperature"))
+ self.temperature.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.temperature"))
+ self.temperature.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_24.setTitle(_translate("llm_settings_widget", "Length penalty"))
+ self.length_penalty.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.length_penalty"))
+ self.length_penalty.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_25.setTitle(_translate("llm_settings_widget", "Num beams"))
+ self.num_beams.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.num_beams"))
+ self.num_beams.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_22.setTitle(_translate("llm_settings_widget", "Sequences to generate"))
+ self.sequences.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.sequences"))
+ self.sequences.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.groupBox_23.setTitle(_translate("llm_settings_widget", "Top k"))
+ self.top_k.setProperty("settings_property", _translate("llm_settings_widget", "llm_generator_setting.top_k"))
+ self.top_k.setProperty("slider_callback", _translate("llm_settings_widget", "handle_value_change"))
+ self.label_4.setText(_translate("llm_settings_widget", "How to treat model when not in use"))
+from airunner.widgets.slider.slider_widget import SliderWidget
diff --git a/src/airunner/widgets/llm/templates/llm_widget.ui b/src/airunner/widgets/llm/templates/llm_widget.ui
index acc764482..5fefdb191 100644
--- a/src/airunner/widgets/llm/templates/llm_widget.ui
+++ b/src/airunner/widgets/llm/templates/llm_widget.ui
@@ -53,90 +53,7 @@
0
-
-
-
- true
-
-
-
- -
-
-
-
-
-
- Send
-
-
-
- -
-
-
- 0
-
-
-
- -
-
-
- New Conversation
-
-
-
-
-
- -
-
-
-
- 0
-
-
- 0
-
-
- 0
-
-
- 0
-
-
-
-
-
- Chat input
-
-
-
- -
-
-
-
-
- Chat
-
-
- -
-
- Narrate
-
-
- -
-
- Generate Image
-
-
- -
-
- Summarize
-
-
- -
-
- Translate
-
-
-
-
-
-
+
@@ -149,125 +66,20 @@
Preferences
+
+ 0
+
+
+ 0
+
+
+ 0
+
+
+ 0
+
-
-
-
- Qt::Vertical
-
-
-
- Prefix
-
-
-
-
-
-
-
-
-
-
- Suffix
-
-
- -
-
-
-
-
-
-
- -
-
-
- Bot details
-
-
-
-
-
-
-
-
-
- Name
-
-
-
- -
-
-
- ChatAI
-
-
- 6
-
-
-
-
-
- -
-
-
-
-
-
- Personality
-
-
-
- -
-
-
-
-
- Nice
-
-
- -
-
- Mean
-
-
- -
-
- Weird
-
-
- -
-
- Insane
-
-
- -
-
- Random
-
-
-
-
-
-
-
-
-
- -
-
-
- User name
-
-
-
-
-
-
- User
-
-
-
-
-
-
- -
-
-
- Generate Characters
-
-
+
@@ -280,733 +92,20 @@
Settings
+
+ 0
+
+
+ 0
+
+
+ 0
+
+
+ 0
+
-
-
-
- Model Type
-
-
-
-
-
-
-
-
-
- -
-
-
- DType
-
-
-
-
-
-
-
-
-
- 2-bit
-
-
-
- -
-
-
- 4-bit
-
-
-
- -
-
-
- 8-bit
-
-
-
- -
-
-
- 16-bit
-
-
-
- -
-
-
- 32-bit
-
-
-
- -
-
-
- Qt::Horizontal
-
-
-
- 40
- 20
-
-
-
-
-
-
- -
-
-
-
- 9
-
-
-
- Description
-
-
-
- -
-
-
- Use GPU
-
-
-
-
-
-
- -
-
-
- Override Prameters
-
-
- true
-
-
- true
-
-
-
-
-
-
- Reset Settings to Default
-
-
-
- -
-
-
- Seed
-
-
-
-
-
-
- -
-
-
- Random seed
-
-
-
-
-
-
- -
-
-
-
-
-
- Early stopping
-
-
-
- -
-
-
- Do sample
-
-
-
-
-
- -
-
-
-
-
-
- Repetition penalty
-
-
-
-
-
-
- 1
-
-
- 10000
-
-
- 0.010000000000000
-
-
- 100.000000000000000
-
-
- true
-
-
- llm_generator_setting.repetition_penalty
-
-
- 0
-
-
- 1
-
-
- 1.000000000000000
-
-
- 10.000000000000000
-
-
- handle_value_change
-
-
-
-
-
-
- -
-
-
- Min length
-
-
-
-
-
-
- 1
-
-
- 2556
-
-
- 1.000000000000000
-
-
- 2556.000000000000000
-
-
- false
-
-
- llm_generator_setting.min_length
-
-
- 1
-
-
- 2556
-
-
- 1
-
-
- 2556
-
-
- handle_value_change
-
-
-
-
-
-
-
-
- -
-
-
-
-
-
- Top P
-
-
-
-
-
-
- 1
-
-
- 100
-
-
- 0.000000000000000
-
-
- 1.000000000000000
-
-
- true
-
-
- llm_generator_setting.top_p
-
-
- 1
-
-
- 10
-
-
- 0.010000000000000
-
-
- 0.100000000000000
-
-
- handle_value_change
-
-
-
-
-
-
- -
-
-
- Max length
-
-
-
-
-
-
- 1
-
-
- 2556
-
-
- 1.000000000000000
-
-
- 2556.000000000000000
-
-
- false
-
-
- llm_generator_setting.max_length
-
-
- 1
-
-
- 2556
-
-
- 1
-
-
- 2556
-
-
- handle_value_change
-
-
-
-
-
-
-
-
- -
-
-
-
-
-
- Leave in VRAM
-
-
-
- -
-
-
- Move to CPU
-
-
-
- -
-
-
- Unload model
-
-
-
-
-
- -
-
-
-
- true
-
-
-
- Model management
-
-
-
- -
-
-
-
-
-
- No repeat ngram size
-
-
-
-
-
-
- 0
-
-
- 20
-
-
- 0.000000000000000
-
-
- 20.000000000000000
-
-
- false
-
-
- llm_generator_setting.ngram_size
-
-
- 1
-
-
- 1
-
-
- 1.000000000000000
-
-
- 1.000000000000000
-
-
- handle_value_change
-
-
-
-
-
-
- -
-
-
- Temperature
-
-
-
-
-
-
- 1
-
-
- 20000
-
-
- 0.000100000000000
-
-
- 2.000000000000000
-
-
- true
-
-
- llm_generator_setting.temperature
-
-
- 1
-
-
- 10
-
-
- 0.010000000000000
-
-
- 0.100000000000000
-
-
- handle_value_change
-
-
-
-
-
-
-
-
- -
-
-
-
-
-
- Length penalty
-
-
-
-
-
-
- -100
-
-
- 100
-
-
- 0.000000000000000
-
-
- 1.000000000000000
-
-
- true
-
-
- llm_generator_setting.length_penalty
-
-
- 1
-
-
- 10
-
-
- 0.010000000000000
-
-
- 0.100000000000000
-
-
- handle_value_change
-
-
-
-
-
-
- -
-
-
- Num beams
-
-
-
-
-
-
- 1
-
-
- 100
-
-
- 0.000000000000000
-
-
- 100.000000000000000
-
-
- false
-
-
- llm_generator_setting.num_beams
-
-
- 1
-
-
- 10
-
-
- 0.010000000000000
-
-
- 0.100000000000000
-
-
- handle_value_change
-
-
-
-
-
-
-
-
- -
-
-
- Qt::Horizontal
-
-
-
- -
-
-
-
-
-
- Sequences to generate
-
-
-
-
-
-
- 1
-
-
- 100
-
-
- 0.000000000000000
-
-
- 100.000000000000000
-
-
- false
-
-
- llm_generator_setting.sequences
-
-
- 1
-
-
- 10
-
-
- 0.010000000000000
-
-
- 0.100000000000000
-
-
- handle_value_change
-
-
-
-
-
-
- -
-
-
- Top k
-
-
-
-
-
-
- 0
-
-
- 256
-
-
- 0.000000000000000
-
-
- 256.000000000000000
-
-
- false
-
-
- llm_generator_setting.top_k
-
-
- 1
-
-
- 10
-
-
- 1
-
-
- 10
-
-
- handle_value_change
-
-
-
-
-
-
-
-
- -
-
-
-
- 9
-
-
-
- How to treat model when not in use
-
-
-
-
-
-
- -
-
-
- Model Version
-
-
-
-
-
-
- -1
-
-
-
-
-
-
- -
-
-
- Qt::Vertical
-
-
-
- 20
- 40
-
-
-
-
- -
-
-
- Prompt Template
-
-
-
-
-
-
-
-
+
@@ -1016,455 +115,29 @@
- SliderWidget
+ ChatPromptWidget
+ QWidget
+ airunner/widgets/llm/chat_prompt_widget
+ 1
+
+
+ LLMSettingsWidget
+ QWidget
+ airunner/widgets/llm/llm_settings_widget
+ 1
+
+
+ LLMPreferencesWidget
QWidget
- airunner/widgets/slider/slider_widget
+ airunner/widgets/llm/llm_preferences_widget
1
- prompt
- comboBox
- send_button
- clear_conversatiion_button
- prefix
- suffix
- botname
- username
- generate_characters_button
- model
- model_version
- use_gpu_checkbox
- radio_button_4bit
- radio_button_8bit
- radio_button_16bit
- radio_button_32bit
- override_parameters
- seed
- random_seed
tabWidget
- comboBox
- prompt
-
-
- prefix
- textChanged()
- llm_widget
- prefix_text_changed()
-
-
- 131
- 108
-
-
- 48
- 6
-
-
-
-
- suffix
- textChanged()
- llm_widget
- suffix_text_changed()
-
-
- 182
- 669
-
-
- 206
- 5
-
-
-
-
- send_button
- clicked()
- llm_widget
- action_button_clicked_send()
-
-
- 91
- 961
-
-
- 257
- 0
-
-
-
-
- clear_conversatiion_button
- clicked()
- llm_widget
- action_button_clicked_clear_conversation()
-
-
- 609
- 961
-
-
- 498
- 4
-
-
-
-
- radio_button_4bit
- toggled(bool)
- llm_widget
- toggled_4bit(bool)
-
-
- 126
- 347
-
-
- 3
- 43
-
-
-
-
- radio_button_8bit
- toggled(bool)
- llm_widget
- toggled_8bit(bool)
-
-
- 206
- 347
-
-
- 2
- 88
-
-
-
-
- radio_button_16bit
- toggled(bool)
- llm_widget
- toggled_16bit(bool)
-
-
- 215
- 347
-
-
- 3
- 136
-
-
-
-
- radio_button_32bit
- toggled(bool)
- llm_widget
- toggled_32bit(bool)
-
-
- 287
- 347
-
-
- 0
- 179
-
-
-
-
- model_version
- currentTextChanged(QString)
- llm_widget
- model_version_changed(QString)
-
-
- 184
- 169
-
-
- 272
- 0
-
-
-
-
- seed
- textEdited(QString)
- llm_widget
- seed_changed(QString)
-
-
- 268
- 785
-
-
- 2
- 382
-
-
-
-
- random_seed
- toggled(bool)
- llm_widget
- random_seed_toggled(bool)
-
-
- 586
- 784
-
-
- 3
- 470
-
-
-
-
- do_sample
- toggled(bool)
- llm_widget
- do_sample_toggled(bool)
-
-
- 597
- 827
-
-
- 1
- 427
-
-
-
-
- early_stopping
- toggled(bool)
- llm_widget
- early_stopping_toggled(bool)
-
-
- 103
- 827
-
-
- 1
- 501
-
-
-
-
- model
- currentTextChanged(QString)
- llm_widget
- model_text_changed(QString)
-
-
- 163
- 92
-
-
- 117
- 0
-
-
-
-
- pushButton
- clicked()
- llm_widget
- reset_settings_to_default_clicked()
-
-
- 290
- 942
-
-
- 437
- 0
-
-
-
-
- use_gpu_checkbox
- toggled(bool)
- llm_widget
- use_gpu_toggled(bool)
-
-
- 441
- 315
-
-
- 34
- 0
-
-
-
-
- override_parameters
- toggled(bool)
- llm_widget
- override_parameters_toggled(bool)
-
-
- 128
- 398
-
-
- 120
- 0
-
-
-
-
- botname
- textEdited(QString)
- llm_widget
- botname_text_changed(QString)
-
-
- 298
- 843
-
-
- 80
- 0
-
-
-
-
- username
- textEdited(QString)
- llm_widget
- username_text_changed(QString)
-
-
- 226
- 919
-
-
- 118
- 0
-
-
-
-
- personality_type
- textHighlighted(QString)
- llm_widget
- personality_type_changed(QString)
-
-
- 597
- 843
-
-
- 412
- 0
-
-
-
-
- generate_characters_button
- clicked()
- llm_widget
- action_button_clicked_generate_characters()
-
-
- 362
- 962
-
-
- 477
- 0
-
-
-
-
- radio_button_2bit
- toggled(bool)
- llm_widget
- toggled_2bit(bool)
-
-
- 63
- 333
-
-
- 0
- 324
-
-
-
-
- leave_in_vram
- toggled(bool)
- llm_widget
- toggle_leave_model_in_vram(bool)
-
-
- 120
- 910
-
-
- 0
- 716
-
-
-
-
- move_to_cpu
- toggled(bool)
- llm_widget
- toggle_move_model_to_cpu(bool)
-
-
- 404
- 910
-
-
- 0
- 704
-
-
-
-
- unload_model
- toggled(bool)
- llm_widget
- toggle_unload_model(bool)
-
-
- 597
- 910
-
-
- 0
- 685
-
-
-
-
- prompt_template
- currentTextChanged(QString)
- llm_widget
- prompt_template_text_changed(QString)
-
-
- 69
- 226
-
-
- 79
- -16
-
-
-
-
+
prompt_text_changed(QString)
botname_text_changed(QString)
diff --git a/src/airunner/widgets/llm/templates/llm_widget_ui.py b/src/airunner/widgets/llm/templates/llm_widget_ui.py
index efb6aee76..1ff1c26bd 100644
--- a/src/airunner/widgets/llm/templates/llm_widget_ui.py
+++ b/src/airunner/widgets/llm/templates/llm_widget_ui.py
@@ -23,548 +23,43 @@ def setupUi(self, llm_widget):
self.gridLayout_8 = QtWidgets.QGridLayout(self.chat)
self.gridLayout_8.setContentsMargins(0, 0, 0, 0)
self.gridLayout_8.setObjectName("gridLayout_8")
- self.conversation = QtWidgets.QTextEdit(parent=self.chat)
- self.conversation.setReadOnly(True)
- self.conversation.setObjectName("conversation")
- self.gridLayout_8.addWidget(self.conversation, 0, 0, 1, 1)
- self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_2.setObjectName("horizontalLayout_2")
- self.send_button = QtWidgets.QPushButton(parent=self.chat)
- self.send_button.setObjectName("send_button")
- self.horizontalLayout_2.addWidget(self.send_button)
- self.progressBar = QtWidgets.QProgressBar(parent=self.chat)
- self.progressBar.setProperty("value", 0)
- self.progressBar.setObjectName("progressBar")
- self.horizontalLayout_2.addWidget(self.progressBar)
- self.clear_conversatiion_button = QtWidgets.QPushButton(parent=self.chat)
- self.clear_conversatiion_button.setObjectName("clear_conversatiion_button")
- self.horizontalLayout_2.addWidget(self.clear_conversatiion_button)
- self.gridLayout_8.addLayout(self.horizontalLayout_2, 2, 0, 1, 1)
- self.widget_5 = QtWidgets.QWidget(parent=self.chat)
- self.widget_5.setObjectName("widget_5")
- self.horizontalLayout = QtWidgets.QHBoxLayout(self.widget_5)
- self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
- self.horizontalLayout.setObjectName("horizontalLayout")
- self.prompt = QtWidgets.QLineEdit(parent=self.widget_5)
- self.prompt.setObjectName("prompt")
- self.horizontalLayout.addWidget(self.prompt)
- self.comboBox = QtWidgets.QComboBox(parent=self.widget_5)
- self.comboBox.setObjectName("comboBox")
- self.comboBox.addItem("")
- self.comboBox.addItem("")
- self.comboBox.addItem("")
- self.comboBox.addItem("")
- self.comboBox.addItem("")
- self.horizontalLayout.addWidget(self.comboBox)
- self.gridLayout_8.addWidget(self.widget_5, 1, 0, 1, 1)
+ self.chat_prompt_widget = ChatPromptWidget(parent=self.chat)
+ self.chat_prompt_widget.setObjectName("chat_prompt_widget")
+ self.gridLayout_8.addWidget(self.chat_prompt_widget, 0, 0, 1, 1)
icon = QtGui.QIcon.fromTheme("user-available")
self.tabWidget.addTab(self.chat, icon, "")
self.preferences = QtWidgets.QWidget()
self.preferences.setObjectName("preferences")
self.gridLayout_2 = QtWidgets.QGridLayout(self.preferences)
+ self.gridLayout_2.setContentsMargins(0, 0, 0, 0)
self.gridLayout_2.setObjectName("gridLayout_2")
- self.splitter = QtWidgets.QSplitter(parent=self.preferences)
- self.splitter.setOrientation(QtCore.Qt.Orientation.Vertical)
- self.splitter.setObjectName("splitter")
- self.groupBox = QtWidgets.QGroupBox(parent=self.splitter)
- self.groupBox.setObjectName("groupBox")
- self.gridLayout_3 = QtWidgets.QGridLayout(self.groupBox)
- self.gridLayout_3.setObjectName("gridLayout_3")
- self.prefix = QtWidgets.QPlainTextEdit(parent=self.groupBox)
- self.prefix.setObjectName("prefix")
- self.gridLayout_3.addWidget(self.prefix, 0, 0, 1, 1)
- self.groupBox_2 = QtWidgets.QGroupBox(parent=self.splitter)
- self.groupBox_2.setObjectName("groupBox_2")
- self.gridLayout_4 = QtWidgets.QGridLayout(self.groupBox_2)
- self.gridLayout_4.setObjectName("gridLayout_4")
- self.suffix = QtWidgets.QPlainTextEdit(parent=self.groupBox_2)
- self.suffix.setObjectName("suffix")
- self.gridLayout_4.addWidget(self.suffix, 0, 0, 1, 1)
- self.gridLayout_2.addWidget(self.splitter, 0, 0, 1, 1)
- self.groupBox_3 = QtWidgets.QGroupBox(parent=self.preferences)
- self.groupBox_3.setObjectName("groupBox_3")
- self.horizontalLayout_5 = QtWidgets.QHBoxLayout(self.groupBox_3)
- self.horizontalLayout_5.setObjectName("horizontalLayout_5")
- self.verticalLayout_2 = QtWidgets.QVBoxLayout()
- self.verticalLayout_2.setObjectName("verticalLayout_2")
- self.label = QtWidgets.QLabel(parent=self.groupBox_3)
- self.label.setObjectName("label")
- self.verticalLayout_2.addWidget(self.label)
- self.botname = QtWidgets.QLineEdit(parent=self.groupBox_3)
- self.botname.setCursorPosition(6)
- self.botname.setObjectName("botname")
- self.verticalLayout_2.addWidget(self.botname)
- self.horizontalLayout_5.addLayout(self.verticalLayout_2)
- self.verticalLayout_3 = QtWidgets.QVBoxLayout()
- self.verticalLayout_3.setObjectName("verticalLayout_3")
- self.label_2 = QtWidgets.QLabel(parent=self.groupBox_3)
- self.label_2.setObjectName("label_2")
- self.verticalLayout_3.addWidget(self.label_2)
- self.personality_type = QtWidgets.QComboBox(parent=self.groupBox_3)
- self.personality_type.setObjectName("personality_type")
- self.personality_type.addItem("")
- self.personality_type.addItem("")
- self.personality_type.addItem("")
- self.personality_type.addItem("")
- self.personality_type.addItem("")
- self.verticalLayout_3.addWidget(self.personality_type)
- self.horizontalLayout_5.addLayout(self.verticalLayout_3)
- self.gridLayout_2.addWidget(self.groupBox_3, 1, 0, 1, 1)
- self.groupBox_4 = QtWidgets.QGroupBox(parent=self.preferences)
- self.groupBox_4.setObjectName("groupBox_4")
- self.gridLayout_13 = QtWidgets.QGridLayout(self.groupBox_4)
- self.gridLayout_13.setObjectName("gridLayout_13")
- self.username = QtWidgets.QLineEdit(parent=self.groupBox_4)
- self.username.setObjectName("username")
- self.gridLayout_13.addWidget(self.username, 0, 0, 1, 1)
- self.gridLayout_2.addWidget(self.groupBox_4, 2, 0, 1, 1)
- self.generate_characters_button = QtWidgets.QPushButton(parent=self.preferences)
- self.generate_characters_button.setObjectName("generate_characters_button")
- self.gridLayout_2.addWidget(self.generate_characters_button, 3, 0, 1, 1)
+ self.llm_preferences_widget = LLMPreferencesWidget(parent=self.preferences)
+ self.llm_preferences_widget.setObjectName("llm_preferences_widget")
+ self.gridLayout_2.addWidget(self.llm_preferences_widget, 0, 0, 1, 1)
icon = QtGui.QIcon.fromTheme("preferences-desktop")
self.tabWidget.addTab(self.preferences, icon, "")
self.settings = QtWidgets.QWidget()
self.settings.setObjectName("settings")
self.gridLayout_5 = QtWidgets.QGridLayout(self.settings)
+ self.gridLayout_5.setContentsMargins(0, 0, 0, 0)
self.gridLayout_5.setObjectName("gridLayout_5")
- self.groupBox_7 = QtWidgets.QGroupBox(parent=self.settings)
- self.groupBox_7.setObjectName("groupBox_7")
- self.gridLayout_10 = QtWidgets.QGridLayout(self.groupBox_7)
- self.gridLayout_10.setObjectName("gridLayout_10")
- self.model = QtWidgets.QComboBox(parent=self.groupBox_7)
- self.model.setObjectName("model")
- self.gridLayout_10.addWidget(self.model, 0, 0, 1, 1)
- self.gridLayout_5.addWidget(self.groupBox_7, 0, 0, 1, 1)
- self.groupBox_6 = QtWidgets.QGroupBox(parent=self.settings)
- self.groupBox_6.setObjectName("groupBox_6")
- self.gridLayout_9 = QtWidgets.QGridLayout(self.groupBox_6)
- self.gridLayout_9.setObjectName("gridLayout_9")
- self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_3.setObjectName("horizontalLayout_3")
- self.radio_button_2bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
- self.radio_button_2bit.setObjectName("radio_button_2bit")
- self.horizontalLayout_3.addWidget(self.radio_button_2bit)
- self.radio_button_4bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
- self.radio_button_4bit.setObjectName("radio_button_4bit")
- self.horizontalLayout_3.addWidget(self.radio_button_4bit)
- self.radio_button_8bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
- self.radio_button_8bit.setObjectName("radio_button_8bit")
- self.horizontalLayout_3.addWidget(self.radio_button_8bit)
- self.radio_button_16bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
- self.radio_button_16bit.setObjectName("radio_button_16bit")
- self.horizontalLayout_3.addWidget(self.radio_button_16bit)
- self.radio_button_32bit = QtWidgets.QRadioButton(parent=self.groupBox_6)
- self.radio_button_32bit.setObjectName("radio_button_32bit")
- self.horizontalLayout_3.addWidget(self.radio_button_32bit)
- spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Minimum)
- self.horizontalLayout_3.addItem(spacerItem)
- self.gridLayout_9.addLayout(self.horizontalLayout_3, 1, 0, 1, 1)
- self.dtype_description = QtWidgets.QLabel(parent=self.groupBox_6)
- font = QtGui.QFont()
- font.setPointSize(9)
- self.dtype_description.setFont(font)
- self.dtype_description.setObjectName("dtype_description")
- self.gridLayout_9.addWidget(self.dtype_description, 2, 0, 1, 1)
- self.use_gpu_checkbox = QtWidgets.QCheckBox(parent=self.groupBox_6)
- self.use_gpu_checkbox.setObjectName("use_gpu_checkbox")
- self.gridLayout_9.addWidget(self.use_gpu_checkbox, 0, 0, 1, 1)
- self.gridLayout_5.addWidget(self.groupBox_6, 3, 0, 1, 1)
- self.override_parameters = QtWidgets.QGroupBox(parent=self.settings)
- self.override_parameters.setCheckable(True)
- self.override_parameters.setChecked(True)
- self.override_parameters.setObjectName("override_parameters")
- self.gridLayout_12 = QtWidgets.QGridLayout(self.override_parameters)
- self.gridLayout_12.setObjectName("gridLayout_12")
- self.pushButton = QtWidgets.QPushButton(parent=self.override_parameters)
- self.pushButton.setObjectName("pushButton")
- self.gridLayout_12.addWidget(self.pushButton, 14, 0, 1, 1)
- self.groupBox_19 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_19.setObjectName("groupBox_19")
- self.horizontalLayout_4 = QtWidgets.QHBoxLayout(self.groupBox_19)
- self.horizontalLayout_4.setObjectName("horizontalLayout_4")
- self.seed = QtWidgets.QLineEdit(parent=self.groupBox_19)
- self.seed.setObjectName("seed")
- self.horizontalLayout_4.addWidget(self.seed)
- self.random_seed = QtWidgets.QCheckBox(parent=self.groupBox_19)
- self.random_seed.setObjectName("random_seed")
- self.horizontalLayout_4.addWidget(self.random_seed)
- self.gridLayout_12.addWidget(self.groupBox_19, 5, 0, 1, 1)
- self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_6.setObjectName("horizontalLayout_6")
- self.early_stopping = QtWidgets.QCheckBox(parent=self.override_parameters)
- self.early_stopping.setObjectName("early_stopping")
- self.horizontalLayout_6.addWidget(self.early_stopping)
- self.do_sample = QtWidgets.QCheckBox(parent=self.override_parameters)
- self.do_sample.setObjectName("do_sample")
- self.horizontalLayout_6.addWidget(self.do_sample)
- self.gridLayout_12.addLayout(self.horizontalLayout_6, 8, 0, 1, 1)
- self.horizontalLayout_7 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_7.setObjectName("horizontalLayout_7")
- self.groupBox_11 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_11.setObjectName("groupBox_11")
- self.gridLayout_19 = QtWidgets.QGridLayout(self.groupBox_11)
- self.gridLayout_19.setObjectName("gridLayout_19")
- self.repetition_penalty = SliderWidget(parent=self.groupBox_11)
- self.repetition_penalty.setProperty("slider_minimum", 1)
- self.repetition_penalty.setProperty("slider_maximum", 10000)
- self.repetition_penalty.setProperty("spinbox_minimum", 0.01)
- self.repetition_penalty.setProperty("spinbox_maximum", 100.0)
- self.repetition_penalty.setProperty("display_as_float", True)
- self.repetition_penalty.setProperty("slider_single_step", 0)
- self.repetition_penalty.setProperty("slider_page_step", 1)
- self.repetition_penalty.setProperty("spinbox_single_step", 1.0)
- self.repetition_penalty.setProperty("spinbox_page_step", 10.0)
- self.repetition_penalty.setObjectName("repetition_penalty")
- self.gridLayout_19.addWidget(self.repetition_penalty, 0, 0, 1, 1)
- self.horizontalLayout_7.addWidget(self.groupBox_11)
- self.groupBox_16 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_16.setObjectName("groupBox_16")
- self.gridLayout_20 = QtWidgets.QGridLayout(self.groupBox_16)
- self.gridLayout_20.setObjectName("gridLayout_20")
- self.min_length = SliderWidget(parent=self.groupBox_16)
- self.min_length.setProperty("slider_minimum", 1)
- self.min_length.setProperty("slider_maximum", 2556)
- self.min_length.setProperty("spinbox_minimum", 1.0)
- self.min_length.setProperty("spinbox_maximum", 2556.0)
- self.min_length.setProperty("display_as_float", False)
- self.min_length.setProperty("slider_single_step", 1)
- self.min_length.setProperty("slider_page_step", 2556)
- self.min_length.setProperty("spinbox_single_step", 1)
- self.min_length.setProperty("spinbox_page_step", 2556)
- self.min_length.setObjectName("min_length")
- self.gridLayout_20.addWidget(self.min_length, 0, 0, 1, 1)
- self.horizontalLayout_7.addWidget(self.groupBox_16)
- self.gridLayout_12.addLayout(self.horizontalLayout_7, 1, 0, 1, 1)
- self.horizontalLayout_9 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_9.setObjectName("horizontalLayout_9")
- self.groupBox_20 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_20.setObjectName("groupBox_20")
- self.gridLayout_24 = QtWidgets.QGridLayout(self.groupBox_20)
- self.gridLayout_24.setObjectName("gridLayout_24")
- self.top_p = SliderWidget(parent=self.groupBox_20)
- self.top_p.setProperty("slider_minimum", 1)
- self.top_p.setProperty("slider_maximum", 100)
- self.top_p.setProperty("spinbox_minimum", 0.0)
- self.top_p.setProperty("spinbox_maximum", 1.0)
- self.top_p.setProperty("display_as_float", True)
- self.top_p.setProperty("slider_single_step", 1)
- self.top_p.setProperty("slider_page_step", 10)
- self.top_p.setProperty("spinbox_single_step", 0.01)
- self.top_p.setProperty("spinbox_page_step", 0.1)
- self.top_p.setObjectName("top_p")
- self.gridLayout_24.addWidget(self.top_p, 0, 0, 1, 1)
- self.horizontalLayout_9.addWidget(self.groupBox_20)
- self.groupBox_21 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_21.setObjectName("groupBox_21")
- self.gridLayout_25 = QtWidgets.QGridLayout(self.groupBox_21)
- self.gridLayout_25.setObjectName("gridLayout_25")
- self.max_length = SliderWidget(parent=self.groupBox_21)
- self.max_length.setProperty("slider_minimum", 1)
- self.max_length.setProperty("slider_maximum", 2556)
- self.max_length.setProperty("spinbox_minimum", 1.0)
- self.max_length.setProperty("spinbox_maximum", 2556.0)
- self.max_length.setProperty("display_as_float", False)
- self.max_length.setProperty("slider_single_step", 1)
- self.max_length.setProperty("slider_page_step", 2556)
- self.max_length.setProperty("spinbox_single_step", 1)
- self.max_length.setProperty("spinbox_page_step", 2556)
- self.max_length.setObjectName("max_length")
- self.gridLayout_25.addWidget(self.max_length, 0, 0, 1, 1)
- self.horizontalLayout_9.addWidget(self.groupBox_21)
- self.gridLayout_12.addLayout(self.horizontalLayout_9, 0, 0, 1, 1)
- self.horizontalLayout_12 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_12.setObjectName("horizontalLayout_12")
- self.leave_in_vram = QtWidgets.QRadioButton(parent=self.override_parameters)
- self.leave_in_vram.setObjectName("leave_in_vram")
- self.horizontalLayout_12.addWidget(self.leave_in_vram)
- self.move_to_cpu = QtWidgets.QRadioButton(parent=self.override_parameters)
- self.move_to_cpu.setObjectName("move_to_cpu")
- self.horizontalLayout_12.addWidget(self.move_to_cpu)
- self.unload_model = QtWidgets.QRadioButton(parent=self.override_parameters)
- self.unload_model.setObjectName("unload_model")
- self.horizontalLayout_12.addWidget(self.unload_model)
- self.gridLayout_12.addLayout(self.horizontalLayout_12, 12, 0, 1, 1)
- self.label_3 = QtWidgets.QLabel(parent=self.override_parameters)
- font = QtGui.QFont()
- font.setBold(True)
- self.label_3.setFont(font)
- self.label_3.setObjectName("label_3")
- self.gridLayout_12.addWidget(self.label_3, 10, 0, 1, 1)
- self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_8.setObjectName("horizontalLayout_8")
- self.groupBox_17 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_17.setObjectName("groupBox_17")
- self.gridLayout_21 = QtWidgets.QGridLayout(self.groupBox_17)
- self.gridLayout_21.setObjectName("gridLayout_21")
- self.ngram_size = SliderWidget(parent=self.groupBox_17)
- self.ngram_size.setProperty("slider_minimum", 0)
- self.ngram_size.setProperty("slider_maximum", 20)
- self.ngram_size.setProperty("spinbox_minimum", 0.0)
- self.ngram_size.setProperty("spinbox_maximum", 20.0)
- self.ngram_size.setProperty("display_as_float", False)
- self.ngram_size.setProperty("slider_single_step", 1)
- self.ngram_size.setProperty("slider_page_step", 1)
- self.ngram_size.setProperty("spinbox_single_step", 1.0)
- self.ngram_size.setProperty("spinbox_page_step", 1.0)
- self.ngram_size.setObjectName("ngram_size")
- self.gridLayout_21.addWidget(self.ngram_size, 0, 0, 1, 1)
- self.horizontalLayout_8.addWidget(self.groupBox_17)
- self.groupBox_18 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_18.setObjectName("groupBox_18")
- self.gridLayout_22 = QtWidgets.QGridLayout(self.groupBox_18)
- self.gridLayout_22.setObjectName("gridLayout_22")
- self.temperature = SliderWidget(parent=self.groupBox_18)
- self.temperature.setProperty("slider_minimum", 1)
- self.temperature.setProperty("slider_maximum", 20000)
- self.temperature.setProperty("spinbox_minimum", 0.0001)
- self.temperature.setProperty("spinbox_maximum", 2.0)
- self.temperature.setProperty("display_as_float", True)
- self.temperature.setProperty("slider_single_step", 1)
- self.temperature.setProperty("slider_page_step", 10)
- self.temperature.setProperty("spinbox_single_step", 0.01)
- self.temperature.setProperty("spinbox_page_step", 0.1)
- self.temperature.setObjectName("temperature")
- self.gridLayout_22.addWidget(self.temperature, 0, 0, 1, 1)
- self.horizontalLayout_8.addWidget(self.groupBox_18)
- self.gridLayout_12.addLayout(self.horizontalLayout_8, 3, 0, 1, 1)
- self.horizontalLayout_11 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_11.setObjectName("horizontalLayout_11")
- self.groupBox_24 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_24.setObjectName("groupBox_24")
- self.gridLayout_28 = QtWidgets.QGridLayout(self.groupBox_24)
- self.gridLayout_28.setObjectName("gridLayout_28")
- self.length_penalty = SliderWidget(parent=self.groupBox_24)
- self.length_penalty.setProperty("slider_minimum", -100)
- self.length_penalty.setProperty("slider_maximum", 100)
- self.length_penalty.setProperty("spinbox_minimum", 0.0)
- self.length_penalty.setProperty("spinbox_maximum", 1.0)
- self.length_penalty.setProperty("display_as_float", True)
- self.length_penalty.setProperty("slider_single_step", 1)
- self.length_penalty.setProperty("slider_page_step", 10)
- self.length_penalty.setProperty("spinbox_single_step", 0.01)
- self.length_penalty.setProperty("spinbox_page_step", 0.1)
- self.length_penalty.setObjectName("length_penalty")
- self.gridLayout_28.addWidget(self.length_penalty, 0, 0, 1, 1)
- self.horizontalLayout_11.addWidget(self.groupBox_24)
- self.groupBox_25 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_25.setObjectName("groupBox_25")
- self.gridLayout_29 = QtWidgets.QGridLayout(self.groupBox_25)
- self.gridLayout_29.setObjectName("gridLayout_29")
- self.num_beams = SliderWidget(parent=self.groupBox_25)
- self.num_beams.setProperty("slider_minimum", 1)
- self.num_beams.setProperty("slider_maximum", 100)
- self.num_beams.setProperty("spinbox_minimum", 0.0)
- self.num_beams.setProperty("spinbox_maximum", 100.0)
- self.num_beams.setProperty("display_as_float", False)
- self.num_beams.setProperty("slider_single_step", 1)
- self.num_beams.setProperty("slider_page_step", 10)
- self.num_beams.setProperty("spinbox_single_step", 0.01)
- self.num_beams.setProperty("spinbox_page_step", 0.1)
- self.num_beams.setObjectName("num_beams")
- self.gridLayout_29.addWidget(self.num_beams, 0, 0, 1, 1)
- self.horizontalLayout_11.addWidget(self.groupBox_25)
- self.gridLayout_12.addLayout(self.horizontalLayout_11, 2, 0, 1, 1)
- self.line = QtWidgets.QFrame(parent=self.override_parameters)
- self.line.setFrameShape(QtWidgets.QFrame.Shape.HLine)
- self.line.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
- self.line.setObjectName("line")
- self.gridLayout_12.addWidget(self.line, 9, 0, 1, 1)
- self.horizontalLayout_10 = QtWidgets.QHBoxLayout()
- self.horizontalLayout_10.setObjectName("horizontalLayout_10")
- self.groupBox_22 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_22.setObjectName("groupBox_22")
- self.gridLayout_26 = QtWidgets.QGridLayout(self.groupBox_22)
- self.gridLayout_26.setObjectName("gridLayout_26")
- self.sequences = SliderWidget(parent=self.groupBox_22)
- self.sequences.setProperty("slider_minimum", 1)
- self.sequences.setProperty("slider_maximum", 100)
- self.sequences.setProperty("spinbox_minimum", 0.0)
- self.sequences.setProperty("spinbox_maximum", 100.0)
- self.sequences.setProperty("display_as_float", False)
- self.sequences.setProperty("slider_single_step", 1)
- self.sequences.setProperty("slider_page_step", 10)
- self.sequences.setProperty("spinbox_single_step", 0.01)
- self.sequences.setProperty("spinbox_page_step", 0.1)
- self.sequences.setObjectName("sequences")
- self.gridLayout_26.addWidget(self.sequences, 0, 0, 1, 1)
- self.horizontalLayout_10.addWidget(self.groupBox_22)
- self.groupBox_23 = QtWidgets.QGroupBox(parent=self.override_parameters)
- self.groupBox_23.setObjectName("groupBox_23")
- self.gridLayout_27 = QtWidgets.QGridLayout(self.groupBox_23)
- self.gridLayout_27.setObjectName("gridLayout_27")
- self.top_k = SliderWidget(parent=self.groupBox_23)
- self.top_k.setProperty("slider_minimum", 0)
- self.top_k.setProperty("slider_maximum", 256)
- self.top_k.setProperty("spinbox_minimum", 0.0)
- self.top_k.setProperty("spinbox_maximum", 256.0)
- self.top_k.setProperty("display_as_float", False)
- self.top_k.setProperty("slider_single_step", 1)
- self.top_k.setProperty("slider_page_step", 10)
- self.top_k.setProperty("spinbox_single_step", 1)
- self.top_k.setProperty("spinbox_page_step", 10)
- self.top_k.setObjectName("top_k")
- self.gridLayout_27.addWidget(self.top_k, 0, 0, 1, 1)
- self.horizontalLayout_10.addWidget(self.groupBox_23)
- self.gridLayout_12.addLayout(self.horizontalLayout_10, 4, 0, 1, 1)
- self.label_4 = QtWidgets.QLabel(parent=self.override_parameters)
- font = QtGui.QFont()
- font.setPointSize(9)
- self.label_4.setFont(font)
- self.label_4.setObjectName("label_4")
- self.gridLayout_12.addWidget(self.label_4, 11, 0, 1, 1)
- self.gridLayout_5.addWidget(self.override_parameters, 4, 0, 1, 1)
- self.groupBox_8 = QtWidgets.QGroupBox(parent=self.settings)
- self.groupBox_8.setObjectName("groupBox_8")
- self.gridLayout_11 = QtWidgets.QGridLayout(self.groupBox_8)
- self.gridLayout_11.setObjectName("gridLayout_11")
- self.model_version = QtWidgets.QComboBox(parent=self.groupBox_8)
- self.model_version.setObjectName("model_version")
- self.gridLayout_11.addWidget(self.model_version, 0, 0, 1, 1)
- self.gridLayout_5.addWidget(self.groupBox_8, 1, 0, 1, 1)
- spacerItem1 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Policy.Minimum, QtWidgets.QSizePolicy.Policy.Expanding)
- self.gridLayout_5.addItem(spacerItem1, 16, 0, 1, 1)
- self.groupBox_14 = QtWidgets.QGroupBox(parent=self.settings)
- self.groupBox_14.setObjectName("groupBox_14")
- self.gridLayout_17 = QtWidgets.QGridLayout(self.groupBox_14)
- self.gridLayout_17.setObjectName("gridLayout_17")
- self.prompt_template = QtWidgets.QComboBox(parent=self.groupBox_14)
- self.prompt_template.setObjectName("prompt_template")
- self.gridLayout_17.addWidget(self.prompt_template, 0, 0, 1, 1)
- self.gridLayout_5.addWidget(self.groupBox_14, 2, 0, 1, 1)
+ self.llm_settings_widget = LLMSettingsWidget(parent=self.settings)
+ self.llm_settings_widget.setObjectName("llm_settings_widget")
+ self.gridLayout_5.addWidget(self.llm_settings_widget, 0, 0, 1, 1)
icon = QtGui.QIcon.fromTheme("preferences-other")
self.tabWidget.addTab(self.settings, icon, "")
self.gridLayout.addWidget(self.tabWidget, 0, 0, 1, 1)
self.retranslateUi(llm_widget)
self.tabWidget.setCurrentIndex(0)
- self.model_version.setCurrentIndex(-1)
- self.prefix.textChanged.connect(llm_widget.prefix_text_changed) # type: ignore
- self.suffix.textChanged.connect(llm_widget.suffix_text_changed) # type: ignore
- self.send_button.clicked.connect(llm_widget.action_button_clicked_send) # type: ignore
- self.clear_conversatiion_button.clicked.connect(llm_widget.action_button_clicked_clear_conversation) # type: ignore
- self.radio_button_4bit.toggled['bool'].connect(llm_widget.toggled_4bit) # type: ignore
- self.radio_button_8bit.toggled['bool'].connect(llm_widget.toggled_8bit) # type: ignore
- self.radio_button_16bit.toggled['bool'].connect(llm_widget.toggled_16bit) # type: ignore
- self.radio_button_32bit.toggled['bool'].connect(llm_widget.toggled_32bit) # type: ignore
- self.model_version.currentTextChanged['QString'].connect(llm_widget.model_version_changed) # type: ignore
- self.seed.textEdited['QString'].connect(llm_widget.seed_changed) # type: ignore
- self.random_seed.toggled['bool'].connect(llm_widget.random_seed_toggled) # type: ignore
- self.do_sample.toggled['bool'].connect(llm_widget.do_sample_toggled) # type: ignore
- self.early_stopping.toggled['bool'].connect(llm_widget.early_stopping_toggled) # type: ignore
- self.model.currentTextChanged['QString'].connect(llm_widget.model_text_changed) # type: ignore
- self.pushButton.clicked.connect(llm_widget.reset_settings_to_default_clicked) # type: ignore
- self.use_gpu_checkbox.toggled['bool'].connect(llm_widget.use_gpu_toggled) # type: ignore
- self.override_parameters.toggled['bool'].connect(llm_widget.override_parameters_toggled) # type: ignore
- self.botname.textEdited['QString'].connect(llm_widget.botname_text_changed) # type: ignore
- self.username.textEdited['QString'].connect(llm_widget.username_text_changed) # type: ignore
- self.personality_type.textHighlighted['QString'].connect(llm_widget.personality_type_changed) # type: ignore
- self.generate_characters_button.clicked.connect(llm_widget.action_button_clicked_generate_characters) # type: ignore
- self.radio_button_2bit.toggled['bool'].connect(llm_widget.toggled_2bit) # type: ignore
- self.leave_in_vram.toggled['bool'].connect(llm_widget.toggle_leave_model_in_vram) # type: ignore
- self.move_to_cpu.toggled['bool'].connect(llm_widget.toggle_move_model_to_cpu) # type: ignore
- self.unload_model.toggled['bool'].connect(llm_widget.toggle_unload_model) # type: ignore
- self.prompt_template.currentTextChanged['QString'].connect(llm_widget.prompt_template_text_changed) # type: ignore
QtCore.QMetaObject.connectSlotsByName(llm_widget)
- llm_widget.setTabOrder(self.prompt, self.comboBox)
- llm_widget.setTabOrder(self.comboBox, self.send_button)
- llm_widget.setTabOrder(self.send_button, self.clear_conversatiion_button)
- llm_widget.setTabOrder(self.clear_conversatiion_button, self.prefix)
- llm_widget.setTabOrder(self.prefix, self.suffix)
- llm_widget.setTabOrder(self.suffix, self.botname)
- llm_widget.setTabOrder(self.botname, self.username)
- llm_widget.setTabOrder(self.username, self.generate_characters_button)
- llm_widget.setTabOrder(self.generate_characters_button, self.model)
- llm_widget.setTabOrder(self.model, self.model_version)
- llm_widget.setTabOrder(self.model_version, self.use_gpu_checkbox)
- llm_widget.setTabOrder(self.use_gpu_checkbox, self.radio_button_4bit)
- llm_widget.setTabOrder(self.radio_button_4bit, self.radio_button_8bit)
- llm_widget.setTabOrder(self.radio_button_8bit, self.radio_button_16bit)
- llm_widget.setTabOrder(self.radio_button_16bit, self.radio_button_32bit)
- llm_widget.setTabOrder(self.radio_button_32bit, self.override_parameters)
- llm_widget.setTabOrder(self.override_parameters, self.seed)
- llm_widget.setTabOrder(self.seed, self.random_seed)
- llm_widget.setTabOrder(self.random_seed, self.tabWidget)
- llm_widget.setTabOrder(self.tabWidget, self.comboBox)
- llm_widget.setTabOrder(self.comboBox, self.prompt)
def retranslateUi(self, llm_widget):
_translate = QtCore.QCoreApplication.translate
llm_widget.setWindowTitle(_translate("llm_widget", "Form"))
- self.send_button.setText(_translate("llm_widget", "Send"))
- self.clear_conversatiion_button.setText(_translate("llm_widget", "New Conversation"))
- self.prompt.setPlaceholderText(_translate("llm_widget", "Chat input"))
- self.comboBox.setItemText(0, _translate("llm_widget", "Chat"))
- self.comboBox.setItemText(1, _translate("llm_widget", "Narrate"))
- self.comboBox.setItemText(2, _translate("llm_widget", "Generate Image"))
- self.comboBox.setItemText(3, _translate("llm_widget", "Summarize"))
- self.comboBox.setItemText(4, _translate("llm_widget", "Translate"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.chat), _translate("llm_widget", "Chat"))
- self.groupBox.setTitle(_translate("llm_widget", "Prefix"))
- self.groupBox_2.setTitle(_translate("llm_widget", "Suffix"))
- self.groupBox_3.setTitle(_translate("llm_widget", "Bot details"))
- self.label.setText(_translate("llm_widget", "Name"))
- self.botname.setText(_translate("llm_widget", "ChatAI"))
- self.label_2.setText(_translate("llm_widget", "Personality"))
- self.personality_type.setItemText(0, _translate("llm_widget", "Nice"))
- self.personality_type.setItemText(1, _translate("llm_widget", "Mean"))
- self.personality_type.setItemText(2, _translate("llm_widget", "Weird"))
- self.personality_type.setItemText(3, _translate("llm_widget", "Insane"))
- self.personality_type.setItemText(4, _translate("llm_widget", "Random"))
- self.groupBox_4.setTitle(_translate("llm_widget", "User name"))
- self.username.setText(_translate("llm_widget", "User"))
- self.generate_characters_button.setText(_translate("llm_widget", "Generate Characters"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.preferences), _translate("llm_widget", "Preferences"))
- self.groupBox_7.setTitle(_translate("llm_widget", "Model Type"))
- self.groupBox_6.setTitle(_translate("llm_widget", "DType"))
- self.radio_button_2bit.setText(_translate("llm_widget", "2-bit"))
- self.radio_button_4bit.setText(_translate("llm_widget", "4-bit"))
- self.radio_button_8bit.setText(_translate("llm_widget", "8-bit"))
- self.radio_button_16bit.setText(_translate("llm_widget", "16-bit"))
- self.radio_button_32bit.setText(_translate("llm_widget", "32-bit"))
- self.dtype_description.setText(_translate("llm_widget", "Description"))
- self.use_gpu_checkbox.setText(_translate("llm_widget", "Use GPU"))
- self.override_parameters.setTitle(_translate("llm_widget", "Override Prameters"))
- self.pushButton.setText(_translate("llm_widget", "Reset Settings to Default"))
- self.groupBox_19.setTitle(_translate("llm_widget", "Seed"))
- self.random_seed.setText(_translate("llm_widget", "Random seed"))
- self.early_stopping.setText(_translate("llm_widget", "Early stopping"))
- self.do_sample.setText(_translate("llm_widget", "Do sample"))
- self.groupBox_11.setTitle(_translate("llm_widget", "Repetition penalty"))
- self.repetition_penalty.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.repetition_penalty"))
- self.repetition_penalty.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_16.setTitle(_translate("llm_widget", "Min length"))
- self.min_length.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.min_length"))
- self.min_length.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_20.setTitle(_translate("llm_widget", "Top P"))
- self.top_p.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.top_p"))
- self.top_p.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_21.setTitle(_translate("llm_widget", "Max length"))
- self.max_length.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.max_length"))
- self.max_length.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.leave_in_vram.setText(_translate("llm_widget", "Leave in VRAM"))
- self.move_to_cpu.setText(_translate("llm_widget", "Move to CPU"))
- self.unload_model.setText(_translate("llm_widget", "Unload model"))
- self.label_3.setText(_translate("llm_widget", "Model management"))
- self.groupBox_17.setTitle(_translate("llm_widget", "No repeat ngram size"))
- self.ngram_size.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.ngram_size"))
- self.ngram_size.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_18.setTitle(_translate("llm_widget", "Temperature"))
- self.temperature.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.temperature"))
- self.temperature.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_24.setTitle(_translate("llm_widget", "Length penalty"))
- self.length_penalty.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.length_penalty"))
- self.length_penalty.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_25.setTitle(_translate("llm_widget", "Num beams"))
- self.num_beams.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.num_beams"))
- self.num_beams.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_22.setTitle(_translate("llm_widget", "Sequences to generate"))
- self.sequences.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.sequences"))
- self.sequences.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.groupBox_23.setTitle(_translate("llm_widget", "Top k"))
- self.top_k.setProperty("settings_property", _translate("llm_widget", "llm_generator_setting.top_k"))
- self.top_k.setProperty("slider_callback", _translate("llm_widget", "handle_value_change"))
- self.label_4.setText(_translate("llm_widget", "How to treat model when not in use"))
- self.groupBox_8.setTitle(_translate("llm_widget", "Model Version"))
- self.groupBox_14.setTitle(_translate("llm_widget", "Prompt Template"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.settings), _translate("llm_widget", "Settings"))
-from airunner.widgets.slider.slider_widget import SliderWidget
+from airunner.widgets.llm.chat_prompt_widget import ChatPromptWidget
+from airunner.widgets.llm.llm_preferences_widget import LLMPreferencesWidget
+from airunner.widgets.llm.llm_settings_widget import LLMSettingsWidget
diff --git a/src/airunner/widgets/llm/templates/message.ui b/src/airunner/widgets/llm/templates/message.ui
new file mode 100644
index 000000000..f1ffdeab5
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/message.ui
@@ -0,0 +1,87 @@
+
+
+ message
+
+
+
+ 0
+ 0
+ 394
+ 92
+
+
+
+ Form
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+ 0
+ 50
+
+
+
+
-
+
+
+ TextLabel
+
+
+ Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+ 0
+ 40
+
+
+
+ border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff;
+
+
+ QFrame::NoFrame
+
+
+ QFrame::Plain
+
+
+ true
+
+
+
+ -
+
+
+ TextLabel
+
+
+ Qt::AlignBottom|Qt::AlignLeading|Qt::AlignLeft
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/airunner/widgets/llm/templates/message_ui.py b/src/airunner/widgets/llm/templates/message_ui.py
new file mode 100644
index 000000000..bc0bd80d3
--- /dev/null
+++ b/src/airunner/widgets/llm/templates/message_ui.py
@@ -0,0 +1,58 @@
+# Form implementation generated from reading ui file '/home/joe/Projects/imagetopixel/airunner/src/airunner/../../src/airunner/widgets/llm/templates/message.ui'
+#
+# Created by: PyQt6 UI code generator 6.4.2
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic6 is
+# run again. Do not edit this file unless you know what you are doing.
+
+
+from PyQt6 import QtCore, QtGui, QtWidgets
+
+
+class Ui_message(object):
+ def setupUi(self, message):
+ message.setObjectName("message")
+ message.resize(394, 92)
+ self.gridLayout_2 = QtWidgets.QGridLayout(message)
+ self.gridLayout_2.setObjectName("gridLayout_2")
+ self.widget = QtWidgets.QWidget(parent=message)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Preferred, QtWidgets.QSizePolicy.Policy.MinimumExpanding)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.widget.sizePolicy().hasHeightForWidth())
+ self.widget.setSizePolicy(sizePolicy)
+ self.widget.setMinimumSize(QtCore.QSize(0, 50))
+ self.widget.setObjectName("widget")
+ self.gridLayout = QtWidgets.QGridLayout(self.widget)
+ self.gridLayout.setObjectName("gridLayout")
+ self.user_name = QtWidgets.QLabel(parent=self.widget)
+ self.user_name.setAlignment(QtCore.Qt.AlignmentFlag.AlignBottom|QtCore.Qt.AlignmentFlag.AlignLeading|QtCore.Qt.AlignmentFlag.AlignLeft)
+ self.user_name.setObjectName("user_name")
+ self.gridLayout.addWidget(self.user_name, 0, 0, 1, 1)
+ self.content = QtWidgets.QPlainTextEdit(parent=self.widget)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.MinimumExpanding)
+ sizePolicy.setHorizontalStretch(0)
+ sizePolicy.setVerticalStretch(0)
+ sizePolicy.setHeightForWidth(self.content.sizePolicy().hasHeightForWidth())
+ self.content.setSizePolicy(sizePolicy)
+ self.content.setMinimumSize(QtCore.QSize(0, 40))
+ self.content.setStyleSheet("border-radius: 5px; border: 5px solid #1f1f1f; background-color: #1f1f1f; color: #ffffff;")
+ self.content.setFrameShape(QtWidgets.QFrame.Shape.NoFrame)
+ self.content.setFrameShadow(QtWidgets.QFrame.Shadow.Plain)
+ self.content.setReadOnly(True)
+ self.content.setObjectName("content")
+ self.gridLayout.addWidget(self.content, 0, 1, 1, 1)
+ self.bot_name = QtWidgets.QLabel(parent=self.widget)
+ self.bot_name.setAlignment(QtCore.Qt.AlignmentFlag.AlignBottom|QtCore.Qt.AlignmentFlag.AlignLeading|QtCore.Qt.AlignmentFlag.AlignLeft)
+ self.bot_name.setObjectName("bot_name")
+ self.gridLayout.addWidget(self.bot_name, 0, 2, 1, 1)
+ self.gridLayout_2.addWidget(self.widget, 0, 0, 1, 1)
+
+ self.retranslateUi(message)
+ QtCore.QMetaObject.connectSlotsByName(message)
+
+ def retranslateUi(self, message):
+ _translate = QtCore.QCoreApplication.translate
+ message.setWindowTitle(_translate("message", "Form"))
+ self.user_name.setText(_translate("message", "TextLabel"))
+ self.bot_name.setText(_translate("message", "TextLabel"))
diff --git a/src/airunner/widgets/model_manager/model_manager_widget.py b/src/airunner/widgets/model_manager/model_manager_widget.py
index 7effc5aa7..bedfccf3a 100644
--- a/src/airunner/widgets/model_manager/model_manager_widget.py
+++ b/src/airunner/widgets/model_manager/model_manager_widget.py
@@ -23,21 +23,6 @@ class ModelManagerWidget(BaseWidget):
current_model_data = None
_current_model_object = None
model_form = None
- icons = {
- "toolButton": "010-view",
- "edit_button": "settings",
- "delete_button": "006-trash",
- }
-
- def set_stylesheet(self):
- for key in self.model_widgets.keys():
- for model_widget in self.model_widgets[key]:
- for button, icon in self.icons.items():
- getattr(model_widget, button).setIcon(
- QtGui.QIcon(
- os.path.join(f"src/icons/{icon}{'-light' if self.is_dark else ''}.png")
- )
- )
@property
def current_model_object(self):
diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py
index 25a903653..b01d6c5f4 100644
--- a/src/airunner/windows/main/main_window.py
+++ b/src/airunner/windows/main/main_window.py
@@ -30,7 +30,7 @@
from airunner.input_event_manager import InputEventManager
from airunner.mixins.history_mixin import HistoryMixin
from airunner.settings import BASE_PATH
-from airunner.utils import get_version, get_latest_version, auto_export_image, save_session, \
+from airunner.utils import get_version, auto_export_image, save_session, \
create_airunner_paths, default_hf_cache_dir
from airunner.widgets.status.status_widget import StatusWidget
from airunner.windows.about.about import AboutWindow
@@ -42,9 +42,10 @@
from airunner.windows.video import VideoPopup
from airunner.data.models import TabSection
from airunner.widgets.brushes.brushes_container import BrushesContainer
-
+from airunner.data.models import Document
from airunner.utils import get_session
+
class MainWindow(
QMainWindow,
HistoryMixin
@@ -65,8 +66,6 @@ class MainWindow(
_settings_manager = None
models = None
client = None
- override_current_generator = None
- override_section = None
_version = None
_latest_version = None
add_image_to_canvas_signal = pyqtSignal(dict)
@@ -112,14 +111,6 @@ def settings_manager(self):
self._settings_manager = SettingsManager(app=self)
return self._settings_manager
- # @property
- # def current_prompt_generator_settings(self):
- # """
- # Convenience property to get the current prompt generator settings
- # :return:
- # """
- # return self.settings_manager.prompt_generator_settings
-
@property
def is_dark(self):
return self.settings_manager.dark_mode_enabled
@@ -135,6 +126,10 @@ def standard_image_panel(self):
@property
def generator_tab_widget(self):
return self.ui.generator_widget
+
+ @property
+ def canvas_widget(self):
+ return self.standard_image_panel.canvas_widget
@property
def toolbar_widget(self):
@@ -148,39 +143,6 @@ def prompt_builder(self):
def footer_widget(self):
return self.ui.footer_widget
- @property
- def current_generator(self):
- """
- Returns the current generator (stablediffusion, kandinksy, etc) as
- determined by the selected generator tab in the
- generator_tab_widget. This value can be override by setting
- the override_current_generator property.
- :return: string
- """
- if self.override_current_generator:
- return self.override_current_generator
- return self.generator_tab_widget.current_generator
-
- @property
- def current_section(self):
- """
- Returns the current section (txt2img, outpaint, etc) as
- determined by the selected sub-tab in the generator tab widget.
- This value can be override by setting the override_section property.
- :return: string
- """
- if self.override_section:
- return self.override_section
- return self.generator_tab_widget.current_section
-
- @property
- def tabs(self):
- return self._tabs[self.current_generator]
-
- @tabs.setter
- def tabs(self, val):
- self._tabs[self.current_generator] = val
-
@property
def generator_type(self):
"""
@@ -205,7 +167,7 @@ def latest_version(self, val):
@property
def document_name(self):
- # name = f"{self._document_name}{'*' if self.canvas and self.canvas.is_dirty else ''}"
+ # name = f"{self._document_name}{'*' if self.canvas and self.canvas_widget.is_dirty else ''}"
# return f"{name} - {self.version}"
return "Untitled"
@@ -234,13 +196,13 @@ def current_canvas(self):
return self.standard_image_panel
def describe_image(self, image, callback):
- self.ui.generator_widget.current_generator_widget.ui.ai_tab_widget.describe_image(
+ self.generator_tab_widget.ui.ai_tab_widget.describe_image(
image=image,
callback=callback
)
def current_active_image(self):
- return self.ui.standard_image_widget.image.copy() if self.ui.standard_image_widget.image else None
+ return self.standard_image_panel.image.copy() if self.standard_image_panel.image else None
def send_message(self, code, message):
self.message_var.emit({
@@ -257,6 +219,7 @@ def available_model_names_by_section(self, section):
def __init__(self, *args, **kwargs):
logger.info("Starting AI Runnner")
+ self.ui = Ui_MainWindow()
# qdarktheme.enable_hi_dpi()
# set the api
@@ -266,20 +229,17 @@ def __init__(self, *args, **kwargs):
self.testing = kwargs.pop("testing", False)
# initialize the document
- from airunner.data.db import session
- from airunner.data.models import Document
+ session = get_session()
self.document = session.query(Document).first()
super().__init__(*args, **kwargs)
+ self.ui.setupUi(self)
+
self.initialize()
# on window resize:
- # self.applicationStateChanged.connect(self.on_state_changed)
-
- if self.settings_manager.latest_version_check:
- logger.info("Checking for latest version")
- self.check_for_latest_version()
+ # self.windowStateChanged.connect(self.on_state_changed)
# check for self.current_layer.lines every 100ms
self.timer = self.startTimer(100)
@@ -305,7 +265,6 @@ def __init__(self, *args, **kwargs):
#self.ui.layer_widget.initialize()
self.set_button_checked("toggle_grid", self.settings_manager.grid_settings.show_grid, False)
- self.set_button_checked("safety_checker", self.settings_manager.nsfw_filter, False)
# call a function after the window has finished loading:
QTimer.singleShot(500, self.on_show)
@@ -363,7 +322,6 @@ def mode_tab_index_changed(self, index):
self.settings_manager.set_value("mode", self.ui.mode_tab_widget.tabText(index))
def on_show(self):
- #self.ui.canvas_plus_widget.do_draw()
pass
def action_slider_changed(self, settings_property, value):
@@ -412,23 +370,23 @@ def action_redo_triggered(self):
def action_paste_image_triggered(self):
if self.settings_manager.mode == Mode.IMAGE.value:
- self.canvas.paste_image_from_clipboard()
+ self.canvas_widget.paste_image_from_clipboard()
def action_copy_image_triggered(self):
if self.settings_manager.mode == Mode.IMAGE.value:
- self.canvas.copy_image(self.current_active_image())
+ self.canvas_widget.copy_image(self.current_active_image())
def action_cut_image_triggered(self):
if self.settings_manager.mode == Mode.IMAGE.value:
- self.canvas.cut_image()
+ self.canvas_widget.cut_image()
def action_rotate_90_clockwise_triggered(self):
if self.settings_manager.mode == Mode.IMAGE.value:
- self.canvas.rotate_90_clockwise()
+ self.canvas_widget.rotate_90_clockwise()
def action_rotate_90_counterclockwise_triggered(self):
if self.settings_manager.mode == Mode.IMAGE.value:
- self.canvas.rotate_90_counterclockwise()
+ self.canvas_widget.rotate_90_counterclockwise()
def action_show_prompt_browser_triggered(self):
self.show_prompt_browser()
@@ -549,7 +507,7 @@ def action_toggle_nsfw_filter_triggered(self, bool):
def action_toggle_grid(self, active):
self.settings_manager.set_value("grid_settings.show_grid", active)
- # self.canvas.update()
+ # self.canvas_widget.update()
def action_toggle_darkmode(self):
self.set_stylesheet()
@@ -624,10 +582,10 @@ def set_size_increment_levels(self):
self.ui.height_slider_widget.slider_single_step = size
self.ui.height_slider_widget.slider_tick_interval = size
- self.canvas.update()
+ self.canvas_widget.update()
def toggle_nsfw_filter(self):
- # self.canvas.update()
+ # self.canvas_widget.update()
self.set_nsfw_filter_tooltip()
def set_nsfw_filter_tooltip(self):
@@ -636,43 +594,43 @@ def set_nsfw_filter_tooltip(self):
f"Click to {'enable' if not nsfw_filter else 'disable'} NSFW filter"
)
- def resizeEvent(self, event):
- if not self.is_started:
- return
- state = self.windowState()
- if state == Qt.WindowState.WindowMaximized:
- timer = QTimer(self)
- timer.setSingleShot(True)
- timer.timeout.connect(self.checkWindowState)
- timer.start(100)
- else:
- self.checkWindowState()
-
- def checkWindowState(self):
- state = self.windowState()
- self.is_maximized = state == Qt.WindowState.WindowMaximized
+ # def resizeEvent(self, event):
+ # if not self.is_started:
+ # return
+ # state = self.windowState()
+ # if state == Qt.WindowState.WindowMaximized:
+ # timer = QTimer(self)
+ # timer.setSingleShot(True)
+ # timer.timeout.connect(self.checkWindowState)
+ # timer.start(100)
+ # else:
+ # self.checkWindowState()
+
+ # def checkWindowState(self):
+ # state = self.windowState()
+ # self.is_maximized = state == Qt.WindowState.WindowMaximized
def dragmode_pressed(self):
- # self.canvas.is_canvas_drag_mode = True
+ # self.canvas_widget.is_canvas_drag_mode = True
pass
def dragmode_released(self):
- # self.canvas.is_canvas_drag_mode = False
+ # self.canvas_widget.is_canvas_drag_mode = False
pass
def shift_pressed(self):
- # self.canvas.shift_is_pressed = True
+ # self.canvas_widget.shift_is_pressed = True
pass
def shift_released(self):
- # self.canvas.shift_is_pressed = False
+ # self.canvas_widget.shift_is_pressed = False
pass
def register_keypress(self):
self.input_event_manager.register_keypress("fullscreen", self.toggle_fullscreen)
self.input_event_manager.register_keypress("control_pressed", self.dragmode_pressed, self.dragmode_released)
self.input_event_manager.register_keypress("shift_pressed", self.shift_pressed, self.shift_released)
- #self.input_event_manager.register_keypress("delete_outside_active_grid_area", self.canvas.delete_outside_active_grid_area)
+ #self.input_event_manager.register_keypress("delete_outside_active_grid_area", self.canvas_widget.delete_outside_active_grid_area)
def toggle_fullscreen(self):
if self.isFullScreen():
@@ -688,49 +646,10 @@ def closeEvent(self, event):
QApplication.quit()
def timerEvent(self, event):
- # self.canvas.timerEvent(event)
+ # self.canvas_widget.timerEvent(event)
if self.status_widget:
self.status_widget.update_system_stats(queue_size=self.client.queue.qsize())
- def check_for_latest_version(self):
- self.version_thread = QThread()
- class VersionCheckWorker(QObject):
- version = None
- finished = pyqtSignal()
- def get_latest_version(self):
- self.version = f"v{get_latest_version()}"
- self.finished.emit()
- self.version_worker = VersionCheckWorker()
- self.version_worker.moveToThread(self.version_thread)
- self.version_thread.started.connect(self.version_worker.get_latest_version)
- self.version_worker.finished.connect(self.handle_latest_version)
- self.version_thread.start()
-
- def handle_latest_version(self):
- self.latest_version = self.version_worker.version
- # call get_latest_version() in a separate thread
- # to avoid blocking the UI, show a popup if version doesn't match self.version
- # check if latest_version is greater than version using major, minor, patch
- current_major, current_minor, current_patch = self.version[1:].split(".")
- try:
- latest_major, latest_minor, latest_patch = self.latest_version[1:].split(".")
- except ValueError:
- latest_major, latest_minor, latest_patch = 0, 0, 0
-
- latest_major = int(latest_major)
- latest_minor = int(latest_minor)
- latest_patch = int(latest_patch)
- current_major = int(current_major)
- current_minor = int(current_minor)
- current_patch = int(current_patch)
-
- if current_major == latest_major and current_minor == latest_minor and current_patch < latest_patch:
- self.show_update_message()
- elif current_major == latest_major and current_minor < latest_minor:
- self.show_update_message()
- elif current_major < latest_major:
- self.show_update_message()
-
def show_update_message(self):
self.set_status_label(f"New version available: {self.latest_version}")
@@ -739,13 +658,13 @@ def show_update_popup(self):
def reset_settings(self):
logger.info("Resetting settings")
- self.canvas.reset_settings()
+ self.canvas_widget.reset_settings()
def on_state_changed(self, state):
if state == Qt.ApplicationState.ApplicationActive:
- self.canvas.pos_x = int(self.x() / 4)
- self.canvas.pos_y = int(self.y() / 2)
- self.canvas.update()
+ self.canvas_widget.pos_x = int(self.x() / 4)
+ self.canvas_widget.pos_y = int(self.y() / 2)
+ self.canvas_widget.update()
def refresh_styles(self):
self.set_stylesheet()
@@ -771,7 +690,7 @@ def set_stylesheet(self):
("eraser-icon", "toggle_eraser_button"),
("frame-grid-icon", "toggle_grid_button"),
("circle-center-icon", "focus_button"),
- ("adult-sign-icon", "safety_checker_button"),
+ ("artificial-intelligence-ai-chip-icon", "ai_button"),
("setting-line-icon", "settings_button"),
("chat-box-icon", "chat_button"),
("setting-line-icon", "llm_preferences_button"),
@@ -842,7 +761,7 @@ def initialize_mixins(self):
def connect_signals(self):
logger.info("Connecting signals")
- #self.canvas._is_dirty.connect(self.set_window_title)
+ #self.canvas_widget._is_dirty.connect(self.set_window_title)
for signal, handler in self.registered_settings_handlers:
getattr(self.settings_manager, signal).connect(handler)
@@ -944,16 +863,13 @@ def set_splitter_sizes(self):
def show_section(self, section):
section_lists = {
- "left": [self.ui.generator_widget.ui.generator_tabs.tabText(i) for i in range(self.ui.generator_widget.ui.generator_tabs.count())],
"center": [self.ui.center_tab.tabText(i) for i in range(self.ui.center_tab.count())],
"right": [self.ui.tool_tab_widget.tabText(i) for i in range(self.ui.tool_tab_widget.count())],
"bottom": [self.ui.bottom_panel_tab_widget.tabText(i) for i in range(self.ui.bottom_panel_tab_widget.count())]
}
for k, v in section_lists.items():
if section in v:
- if k == "left":
- self.ui.generator_widget.ui.generator_tabs.setCurrentIndex(v.index(section))
- elif k == "right":
+ if k == "right":
self.ui.tool_tab_widget.setCurrentIndex(v.index(section))
elif k == "bottom":
self.ui.bottom_panel_tab_widget.setCurrentIndex(v.index(section))
@@ -992,7 +908,7 @@ def handle_value_change(self, attr_name, value=None, widget=None):
self.settings_manager.set_value(attr_name, value)
def handle_similar_slider_change(self, attr_name, value=None, widget=None):
- self.ui.standard_image_widget.handle_similar_slider_change(value)
+ self.standard_image_panel.handle_similar_slider_change(value)
def initialize_settings_manager(self):
self.settings_manager.changed_signal.connect(self.handle_changed_signal)
@@ -1003,11 +919,11 @@ def handle_changed_signal(self, key, value):
elif key == "line_width":
self.set_size_form_element_step_values()
elif key == "show_grid":
- self.canvas.update()
+ self.canvas_widget.update()
elif key == "snap_to_grid":
- self.canvas.update()
+ self.canvas_widget.update()
elif key == "line_color":
- self.canvas.update_grid_pen()
+ self.canvas_widget.update_grid_pen()
elif key == "lora_path":
self.refresh_lora()
elif key == "model_base_path":
@@ -1031,9 +947,6 @@ def initialize_handlers(self):
self.message_var.my_signal.connect(self.message_handler)
def initialize_window(self):
- self.window = Ui_MainWindow()
- self.window.setupUi(self)
- self.ui = self.window
self.center()
self.set_window_title()
@@ -1123,7 +1036,7 @@ def new_document(self):
self._document_name = "Untitled"
self.set_window_title()
self.current_filter = None
- #self.canvas.update()
+ #self.canvas_widget.update()
self.ui.layer_widget.show_layers()
def set_status_label(self, txt, error=False):
@@ -1195,7 +1108,7 @@ def handle_image_generated(self, message):
# get max progressbar value
if nsfw_content_detected and self.settings_manager.nsfw_filter:
self.message_handler({
- "message": "NSFW content detected, try again.",
+ "message": "Explicit content detected, try again.",
"code": MessageCode.ERROR
})
@@ -1261,7 +1174,7 @@ def set_size_form_element_step_values(self):
def saveas_document(self):
# get file path
file_path, _ = QFileDialog.getSaveFileName(
- self.window, "Save Document", "", "AI Runner Document (*.airunner)"
+ self.ui, "Save Document", "", "AI Runner Document (*.airunner)"
)
if file_path == "":
return
@@ -1285,8 +1198,8 @@ def do_save(self, document_name):
layers.append(layer)
data = {
"layers": layers,
- "image_pivot_point": self.canvas.image_pivot_point,
- "image_root_point": self.canvas.image_root_point,
+ "image_pivot_point": self.canvas_widget.image_pivot_point,
+ "image_root_point": self.canvas_widget.image_root_point,
}
with open(document_name, "wb") as f:
pickle.dump(data, f)
@@ -1298,10 +1211,10 @@ def do_save(self, document_name):
self._document_name = document_name.split("/")[-1].split(".")[0]
self.set_window_title()
self.is_saved = True
- self.canvas.is_dirty = False
+ self.canvas_widget.is_dirty = False
def update(self):
- self.ui.standard_image_widget.update_thumbnails()
+ self.standard_image_panel.update_thumbnails()
def insert_into_prompt(self, text, negative_prompt=False):
prompt_widget = self.generator_tab_widget.data[self.current_generator][self.current_section]["prompt_widget"]
@@ -1334,28 +1247,10 @@ def change_content_widget(self):
self.ui.center_tab.setCurrentIndex(tab_index)
self.ui.center_tab.blockSignals(False)
- def handle_generator_tab_changed(self):
- self.generator_tab_changed_signal.emit()
- self.change_content_widget()
-
- def handle_tab_section_changed(self):
- self.tab_section_changed_signal.emit()
- self.change_content_widget()
-
- def release_tab_overrides(self):
- self.override_current_generator = None
- self.override_section = None
-
def clear_all_prompts(self):
- for tab_section in self._tabs.keys():
- self.override_current_generator = tab_section
- for tab in self.tabs.keys():
- self.override_section = tab
- self.prompt = ""
- self.negative_prompt = ""
- self.generator_tab_widget.clear_prompts(tab_section, tab)
- self.override_current_generator = None
- self.override_section = None
+ self.prompt = ""
+ self.negative_prompt = ""
+ self.generator_tab_widget.clear_prompts()
def show_prompt_browser(self):
PromptBrowser(settings_manager=self.settings_manager, app=self)
@@ -1410,7 +1305,7 @@ def update_negative_prompt(self, prompt_value):
self.generator_tab_widget.update_negative_prompt(prompt_value)
def new_batch(self, index, image, data):
- self.generator_tab_widget.current_generator.new_batch(index, image, data)
+ self.generator_tab_widget.new_batch(index, image, data)
def image_generation_toggled(self):
self.settings_manager.set_value("mode", Mode.IMAGE.value)
@@ -1485,10 +1380,6 @@ def activate_image_generation_section(self):
def activate_language_processing_section(self):
self.ui.mode_tab_widget.setCurrentIndex(1)
- # try:
- # self.ui.generator_widget.current_generator_widget.ui.generator_form_tab_widget.setCurrentIndex(2)
- # except AttributeError as e:
- # pass
self.toggle_tool_section_buttons_visibility()
def activate_model_manager_section(self):
@@ -1534,12 +1425,7 @@ def set_all_image_generator_buttons(self):
def image_generators_toggled(self):
self.image_generation_toggled()
- current_tab = self.settings_manager.current_tab
- current_tab = self.settings_manager.current_image_generator
- self.settings_manager.set_value("current_tab", current_tab)
self.settings_manager.set_value("mode", Mode.IMAGE.value)
- self.settings_manager.set_value(f"current_section_{current_tab}", GeneratorSection.TXT2IMG.value)
- self.generator_tab_widget.set_current_section_tab()
self.settings_manager.set_value("generator_section", GeneratorSection.TXT2IMG.value)
active_tab_obj = session.query(TabSection).filter(TabSection.panel == "center_tab").first()
active_tab_obj.active_tab = "Canvas"
@@ -1549,11 +1435,7 @@ def image_generators_toggled(self):
def text_to_video_toggled(self):
self.image_generation_toggled()
- current_tab = "stablediffusion"
- self.settings_manager.set_value("current_tab", current_tab)
self.settings_manager.set_value("mode", Mode.IMAGE.value)
- self.settings_manager.set_value(f"current_section_{current_tab}", GeneratorSection.TXT2VID.value)
- self.generator_tab_widget.set_current_section_tab()
self.settings_manager.set_value("generator_section", GeneratorSection.TXT2VID.value)
active_tab_obj = session.query(TabSection).filter(TabSection.panel == "center_tab").first()
active_tab_obj.active_tab = "Video"
@@ -1563,9 +1445,6 @@ def text_to_video_toggled(self):
def prompt_builder_toggled(self):
self.image_generation_toggled()
- current_tab = self.settings_manager.current_tab
- self.settings_manager.set_value(f"current_section_{current_tab}", GeneratorSection.PROMPT_BUILDER.value)
- self.generator_tab_widget.set_current_section_tab()
self.settings_manager.set_value(f"generator_section", GeneratorSection.PROMPT_BUILDER.value)
active_tab_obj = session.query(TabSection).filter(TabSection.panel == "center_tab").first()
active_tab_obj.active_tab = "Prompt Builder"
diff --git a/src/airunner/windows/main/templates/main_window.ui b/src/airunner/windows/main/templates/main_window.ui
index bb5fdf015..514bdf4cd 100644
--- a/src/airunner/windows/main/templates/main_window.ui
+++ b/src/airunner/windows/main/templates/main_window.ui
@@ -7,12 +7,12 @@
0
0
- 1085
- 1141
+ 1144
+ 955
-
+
0
0
@@ -128,7 +128,7 @@
0
0
- 1081
+ 1140
50
@@ -1058,8 +1058,8 @@
0
0
- 365
- 1017
+ 386
+ 831
@@ -1425,20 +1425,21 @@
-
-
+
-
+
0
0
-
-
- 50
- 45
-
+
+ Qt::Horizontal
-
+
+
+ -
+
+
50
45
@@ -1447,38 +1448,26 @@
PointingHandCursor
-
- Toggle NSFW
-
- :/icons/light/adult-sign-icon.svg:/icons/light/adult-sign-icon.svg
+ :/icons/light/artificial-intelligence-ai-chip-icon.svg:/icons/light/artificial-intelligence-ai-chip-icon.svg
- 18
- 18
+ 20
+ 20
-
- true
-
true
-
-
-
-
- 0
- 0
-
-
+
Qt::Horizontal
@@ -1602,7 +1591,7 @@
0
0
- 1085
+ 1144
22
@@ -3039,7 +3028,7 @@
action_new_document_triggered()
- 857
+ 916
65
@@ -3055,7 +3044,7 @@
action_undo_triggered()
- 1034
+ 1093
65
@@ -3071,7 +3060,7 @@
action_redo_triggered()
- 1076
+ 1135
65
@@ -3087,7 +3076,7 @@
action_export_image_triggered()
- 983
+ 1042
65
@@ -3240,22 +3229,6 @@
-
- safety_checker_button
- toggled(bool)
- MainWindow
- action_toggle_nsfw_filter_triggered(bool)
-
-
- 1079
- 422
-
-
- 1002
- 624
-
-
-
settings_button
clicked()
@@ -3263,8 +3236,8 @@
action_show_settings()
- 1079
- 482
+ 1138
+ 426
1002
@@ -3279,8 +3252,8 @@
action_toggle_grid(bool)
- 1079
- 320
+ 1138
+ 315
1002
@@ -3295,8 +3268,8 @@
action_toggle_active_grid_area(bool)
- 1079
- 158
+ 1138
+ 153
1002
@@ -3311,8 +3284,8 @@
action_toggle_eraser(bool)
- 1079
- 260
+ 1138
+ 255
0
@@ -3327,8 +3300,8 @@
action_toggle_brush(bool)
- 1079
- 209
+ 1138
+ 204
0
diff --git a/src/airunner/windows/main/templates/main_window_ui.py b/src/airunner/windows/main/templates/main_window_ui.py
index 793c73ef6..9597e9657 100644
--- a/src/airunner/windows/main/templates/main_window_ui.py
+++ b/src/airunner/windows/main/templates/main_window_ui.py
@@ -12,8 +12,8 @@
class Ui_MainWindow(object):
def setupUi(self, MainWindow):
MainWindow.setObjectName("MainWindow")
- MainWindow.resize(1085, 1141)
- sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.Preferred)
+ MainWindow.resize(1144, 955)
+ sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.Preferred)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth())
@@ -58,7 +58,7 @@ def setupUi(self, MainWindow):
self.scrollArea_3.setWidgetResizable(True)
self.scrollArea_3.setObjectName("scrollArea_3")
self.scrollAreaWidgetContents_3 = QtWidgets.QWidget()
- self.scrollAreaWidgetContents_3.setGeometry(QtCore.QRect(0, 0, 1081, 50))
+ self.scrollAreaWidgetContents_3.setGeometry(QtCore.QRect(0, 0, 1140, 50))
self.scrollAreaWidgetContents_3.setObjectName("scrollAreaWidgetContents_3")
self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.scrollAreaWidgetContents_3)
self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0)
@@ -385,7 +385,7 @@ def setupUi(self, MainWindow):
self.scrollArea.setWidgetResizable(True)
self.scrollArea.setObjectName("scrollArea")
self.scrollAreaWidgetContents = QtWidgets.QWidget()
- self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 365, 1017))
+ self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 386, 831))
self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents")
self.gridLayout = QtWidgets.QGridLayout(self.scrollAreaWidgetContents)
self.gridLayout.setObjectName("gridLayout")
@@ -538,24 +538,6 @@ def setupUi(self, MainWindow):
self.focus_button.setFlat(True)
self.focus_button.setObjectName("focus_button")
self.verticalLayout.addWidget(self.focus_button)
- self.safety_checker_button = QtWidgets.QPushButton(parent=self.button_menu)
- sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.MinimumExpanding)
- sizePolicy.setHorizontalStretch(0)
- sizePolicy.setVerticalStretch(0)
- sizePolicy.setHeightForWidth(self.safety_checker_button.sizePolicy().hasHeightForWidth())
- self.safety_checker_button.setSizePolicy(sizePolicy)
- self.safety_checker_button.setMinimumSize(QtCore.QSize(50, 45))
- self.safety_checker_button.setMaximumSize(QtCore.QSize(50, 45))
- self.safety_checker_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
- self.safety_checker_button.setText("")
- icon15 = QtGui.QIcon()
- icon15.addPixmap(QtGui.QPixmap(":/icons/light/adult-sign-icon.svg"), QtGui.QIcon.Mode.Normal, QtGui.QIcon.State.Off)
- self.safety_checker_button.setIcon(icon15)
- self.safety_checker_button.setIconSize(QtCore.QSize(18, 18))
- self.safety_checker_button.setCheckable(True)
- self.safety_checker_button.setFlat(True)
- self.safety_checker_button.setObjectName("safety_checker_button")
- self.verticalLayout.addWidget(self.safety_checker_button)
self.line_4 = QtWidgets.QFrame(parent=self.button_menu)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.Fixed)
sizePolicy.setHorizontalStretch(0)
@@ -566,6 +548,22 @@ def setupUi(self, MainWindow):
self.line_4.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
self.line_4.setObjectName("line_4")
self.verticalLayout.addWidget(self.line_4)
+ self.ai_button = QtWidgets.QPushButton(parent=self.button_menu)
+ self.ai_button.setMinimumSize(QtCore.QSize(50, 45))
+ self.ai_button.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.PointingHandCursor))
+ self.ai_button.setText("")
+ icon15 = QtGui.QIcon()
+ icon15.addPixmap(QtGui.QPixmap(":/icons/light/artificial-intelligence-ai-chip-icon.svg"), QtGui.QIcon.Mode.Normal, QtGui.QIcon.State.Off)
+ self.ai_button.setIcon(icon15)
+ self.ai_button.setIconSize(QtCore.QSize(20, 20))
+ self.ai_button.setFlat(True)
+ self.ai_button.setObjectName("ai_button")
+ self.verticalLayout.addWidget(self.ai_button)
+ self.line_3 = QtWidgets.QFrame(parent=self.button_menu)
+ self.line_3.setFrameShape(QtWidgets.QFrame.Shape.HLine)
+ self.line_3.setFrameShadow(QtWidgets.QFrame.Shadow.Sunken)
+ self.line_3.setObjectName("line_3")
+ self.verticalLayout.addWidget(self.line_3)
self.settings_button = QtWidgets.QPushButton(parent=self.button_menu)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Maximum, QtWidgets.QSizePolicy.Policy.MinimumExpanding)
sizePolicy.setHorizontalStretch(0)
@@ -608,7 +606,7 @@ def setupUi(self, MainWindow):
self.header_widget.raise_()
MainWindow.setCentralWidget(self.centralwidget)
self.menubar = QtWidgets.QMenuBar(parent=MainWindow)
- self.menubar.setGeometry(QtCore.QRect(0, 0, 1085, 22))
+ self.menubar.setGeometry(QtCore.QRect(0, 0, 1144, 22))
font = QtGui.QFont()
font.setPointSize(11)
self.menubar.setFont(font)
@@ -961,7 +959,6 @@ def setupUi(self, MainWindow):
self.image_generators_button.released.connect(MainWindow.image_generators_toggled) # type: ignore
self.txt2vid_button.released.connect(MainWindow.text_to_video_toggled) # type: ignore
self.prompt_builder_button.released.connect(MainWindow.prompt_builder_toggled) # type: ignore
- self.safety_checker_button.toggled['bool'].connect(MainWindow.action_toggle_nsfw_filter_triggered) # type: ignore
self.settings_button.clicked.connect(MainWindow.action_show_settings) # type: ignore
self.toggle_grid_button.toggled['bool'].connect(MainWindow.action_toggle_grid) # type: ignore
self.toggle_active_grid_area_button.toggled['bool'].connect(MainWindow.action_toggle_active_grid_area) # type: ignore
@@ -1000,7 +997,6 @@ def retranslateUi(self, MainWindow):
self.toggle_eraser_button.setToolTip(_translate("MainWindow", "Eraser tool"))
self.toggle_grid_button.setToolTip(_translate("MainWindow", "Toggle Grid"))
self.focus_button.setToolTip(_translate("MainWindow", "Recenter canvas"))
- self.safety_checker_button.setToolTip(_translate("MainWindow", "Toggle NSFW"))
self.settings_button.setToolTip(_translate("MainWindow", "AI Runner Settings"))
self.mode_tab_widget.setTabText(self.mode_tab_widget.indexOf(self.art), _translate("MainWindow", "Image Generation"))
self.mode_tab_widget.setTabText(self.mode_tab_widget.indexOf(self.chat), _translate("MainWindow", "Text Generation"))