From ef42551e0db962d8ed43f9891104255d7d79a9b7 Mon Sep 17 00:00:00 2001
From: w4ffl35 <25737761+w4ffl35@users.noreply.github.com>
Date: Mon, 1 May 2023 07:56:47 -0600
Subject: [PATCH] allows setting of lora path and loads them in real time

---
 main.py  | 100 +++++++++++++++++++++++++++++++++++++------------------
 setup.py |   2 +-
 2 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/main.py b/main.py
index 33d9f45..38cc358 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,5 @@
 import os
-from PyQt6.QtWidgets import QWidget, QVBoxLayout
+from PyQt6.QtWidgets import QWidget, QVBoxLayout, QScrollArea
 from PyQt6.QtCore import pyqtSignal
 from airunner.extensions import BaseExtension
 from aihandler.qtvar import Var, StringVar, FloatVar, BooleanVar
@@ -8,6 +8,7 @@
 from diffusers.loaders import LoraLoaderMixin
 from safetensors.torch import load_file
 
+
 class LoraVar(Var):
     my_signal = pyqtSignal(str, float, bool)
 
@@ -42,40 +43,72 @@ class Extension(BaseExtension):
     extension_directory = "airunner-lora"
     lora_loaded = False
 
-    def __init__(self, settings_manager=None):
-        super().__init__(settings_manager)
+    def __init__(self, app, settings_manager=None):
+        super().__init__(app, settings_manager)
         self._available_loras = {}
+        # check if self.settings_manager.settings.lora_path is str
+        if "lora_path" not in self.settings_manager.settings.__dict__:
+            self.settings_manager.settings.lora_path = StringVar(app)
+        lora_path = self.settings_manager.settings.lora_path
+        if isinstance(lora_path, str):
+            self.settings_manager.settings.lora_path = StringVar(app)
+            self.settings_manager.settings.lora_path.set(lora_path)
         if "available_loras" not in self.settings_manager.settings.__dict__:
             self.settings_manager.settings.available_loras = {}
+        self.settings_manager.settings.lora_path.my_signal.connect(self.refresh_lora)
+
+    def refresh_lora(self):
+        self._available_loras = {}
+        for tab_name in self.app.tabs.keys():
+            tab = self.app.tabs[tab_name]
+            self.settings_manager.settings.available_loras[tab_name] = []
+
+            # find tab with name LoRA in tab.PromptTabsSection
+            for i in range(tab.PromptTabsSection.count()):
+                if tab.PromptTabsSection.tabText(i) == "LoRA":
+                    tab.PromptTabsSection.removeTab(i)
+                    break
+            self.generator_tab_injection(tab, tab_name)
+
+    def get_list_of_available_loras(self, tab_name, lora_path, lora_names=None):
+        if lora_names is None:
+            lora_names = []
+        if not os.path.exists(lora_path):
+            return lora_names
+        possible_line_endings = ["ckpt", "safetensors", "bin"]
+        for f in os.listdir(lora_path):
+            if os.path.isdir(os.path.join(lora_path, f)):
+                lora_names = self.get_list_of_available_loras(tab_name, os.path.join(lora_path, f), lora_names)
+            if f.split(".")[-1] in possible_line_endings:
+                name = f.split(".")[0]
+                scale = 100.0
+                enabled = True
+                # check if we have scale in self.settings_manager.settings.available_loras[tab_name]
+                if tab_name in self.settings_manager.settings.available_loras:
+                    loras = self.settings_manager.settings.available_loras[tab_name] or []
+                    for lora in loras:
+                        if lora["name"] == name:
+                            scale = lora["scale"]
+                            enabled = lora["enabled"]
+                            break
+                lora_names.append({
+                    "name": name,
+                    "scale": scale,
+                    "enabled": enabled
+                })
+        return lora_names
 
     def available_loras(self, tab_name):
+        base_path = self.settings_manager.settings.model_base_path.get()
+        lora_path = self.settings_manager.settings.lora_path.get() or "lora"
+        if lora_path == "lora":
+            lora_path = os.path.join(base_path, lora_path)
+        if not os.path.exists(lora_path):
+            return []
         if tab_name not in self._available_loras:
             self._available_loras[tab_name] = []
-            loras_path = os.path.join(self.model_base_path, "lora")
-            possible_line_endings = ["ckpt", "safetensors", "bin"]
             self.settings_manager.enable_save()
-
-            if not os.path.exists(loras_path):
-                os.makedirs(loras_path)
-
-            for f in os.listdir(loras_path):
-                if f.split(".")[-1] in possible_line_endings:
-                    name = f.split(".")[0]
-                    scale = 100.0
-                    enabled = True
-                    # check if we have scale in self.settings_manager.settings.available_loras[tab_name]
-                    if tab_name in self.settings_manager.settings.available_loras:
-                        loras = self.settings_manager.settings.available_loras[tab_name]
-                        for lora in loras:
-                            if lora["name"] == name:
-                                scale = lora["scale"]
-                                enabled = lora["enabled"]
-                                break
-                    self._available_loras[tab_name].append({
-                        "name": name,
-                        "scale": scale,
-                        "enabled": enabled
-                    })
+            self._available_loras[tab_name] = self.get_list_of_available_loras(tab_name, lora_path)
             self.settings_manager.settings.available_loras[tab_name] = self._available_loras[tab_name]
             self.settings_manager.save_settings()
         return self.settings_manager.settings.available_loras[tab_name]
@@ -94,16 +127,20 @@ def generator_tab_injection(self, tab, tab_name=None):
             container.layout().addWidget(lora_widget)
             lora_widget.scaleSlider.valueChanged.connect(
                 lambda value, _lora_widget=lora_widget, _lora=lora, _tab_name=tab_name:
-                self.handle_lora_slider(_lora, _lora_widget, value, _tab_name))
+                    self.handle_lora_slider(_lora, _lora_widget, value, _tab_name))
             lora_widget.scaleSpinBox.valueChanged.connect(
                 lambda value, _lora_widget=lora_widget, _lora=lora, _tab_name=tab_name:
-                self.handle_lora_spinbox(_lora, _lora_widget, value, _tab_name))
+                    self.handle_lora_spinbox(_lora, _lora_widget, value, _tab_name))
             lora_widget.enabledCheckbox.stateChanged.connect(
                 lambda value, _lora=lora, _tab_name=tab_name:
-                self.toggle_lora(lora, value, _tab_name))
+                    self.toggle_lora(lora, value, _tab_name))
         # add a vertical spacer to the end of the container
         container.layout().addStretch()
-        tab.PromptTabsSection.addTab(container, "LoRA")
+        # make the container scrollable
+        scroll = QScrollArea()
+        scroll.setWidget(container)
+        scroll.setWidgetResizable(True)
+        tab.PromptTabsSection.addTab(scroll, "LoRA")
 
     def toggle_lora(self, lora, value, tab_name):
         for n in range(len(self.available_loras(tab_name))):
@@ -129,7 +166,6 @@ def generate_data_injection(self, data):
         data["options"]["lora"] = []
         for lora in self.available_loras(data["action"]):
             if lora["enabled"]:
-                print(lora)
                 data["options"]["lora"].append((lora["name"], lora["scale"]))
         return data
 
diff --git a/setup.py b/setup.py
index 37aecce..90b3105 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name='airunner-lora',
-    version="1.0.3",
+    version="1.0.4",
     author="Capsize LLC",
     description="LoRA extension for AI Runner",
     long_description=open("README.md", "r", encoding="utf-8").read(),
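
Note (not part of the patch): generate_data_injection() above emits (name, scale) pairs into data["options"]["lora"], with scale stored on a 0-100 range. Below is a minimal, hypothetical sketch of how a downstream consumer might apply such pairs with diffusers. The StableDiffusionPipeline object, the apply_lora helper, the directory layout, and the 0-100 to 0.0-1.0 conversion are assumptions for illustration only; AI Runner's aihandler backend may load LoRA weights differently.

import os
from diffusers import StableDiffusionPipeline


def apply_lora(pipe: StableDiffusionPipeline, lora_dir: str, name: str, scale: float) -> dict:
    # Hypothetical helper: find the weights file whose stem matches the stored
    # name, mirroring the extension's f.split(".")[0] naming convention.
    for f in os.listdir(lora_dir):
        if f.split(".")[0] == name and f.split(".")[-1] in ("ckpt", "safetensors", "bin"):
            pipe.load_lora_weights(lora_dir, weight_name=f)
            break
    # The extension stores scale as 0-100; diffusers expects roughly 0.0-1.0.
    return {"scale": scale / 100.0}


# usage sketch (assumed names):
# for name, scale in data["options"]["lora"]:
#     cross_attention_kwargs = apply_lora(pipe, lora_dir, name, scale)
# image = pipe(prompt, cross_attention_kwargs=cross_attention_kwargs).images[0]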