allows setting of the LoRA path and loads LoRAs in real time
w4ffl35 committed May 1, 2023
1 parent e2ff27f commit ef42551
Showing 2 changed files with 69 additions and 33 deletions.
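
The commit wires the LoRA directory path into a Qt-signal-backed settings variable, so editing the path rescans the directory and rebuilds the "LoRA" tab without restarting the app. Below is a minimal sketch of that pattern, not the extension's actual code: PathVar and the standalone refresh_lora are hypothetical stand-ins for aihandler.qtvar.StringVar and the extension's refresh_lora method, which may differ in detail.

from PyQt6.QtCore import QObject, pyqtSignal

class PathVar(QObject):
    # hypothetical stand-in for aihandler.qtvar.StringVar
    my_signal = pyqtSignal(str)

    def __init__(self, value=""):
        super().__init__()
        self._value = value

    def set(self, value):
        self._value = value
        self.my_signal.emit(value)  # notify listeners as soon as the value changes

    def get(self):
        return self._value

def refresh_lora(path):
    # in the extension this rescans the directory and rebuilds the "LoRA" tab
    print(f"rescanning LoRA directory: {path}")

lora_path = PathVar()
lora_path.my_signal.connect(refresh_lora)
lora_path.set("/models/lora")  # setting the path triggers refresh_lora immediately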
main.py (100 changes: 68 additions & 32 deletions)
@@ -1,5 +1,5 @@
 import os
-from PyQt6.QtWidgets import QWidget, QVBoxLayout
+from PyQt6.QtWidgets import QWidget, QVBoxLayout, QScrollArea
 from PyQt6.QtCore import pyqtSignal
 from airunner.extensions import BaseExtension
 from aihandler.qtvar import Var, StringVar, FloatVar, BooleanVar
@@ -8,6 +8,7 @@
 from diffusers.loaders import LoraLoaderMixin
 from safetensors.torch import load_file
 
+
 class LoraVar(Var):
     my_signal = pyqtSignal(str, float, bool)
 
@@ -42,40 +43,72 @@ class Extension(BaseExtension):
     extension_directory = "airunner-lora"
     lora_loaded = False
 
-    def __init__(self, settings_manager=None):
-        super().__init__(settings_manager)
+    def __init__(self, app, settings_manager=None):
+        super().__init__(app, settings_manager)
         self._available_loras = {}
+        # check if self.settings_manager.settings.lora_path is str
+        if "lora_path" not in self.settings_manager.settings.__dict__:
+            self.settings_manager.settings.lora_path = StringVar(app)
+        lora_path = self.settings_manager.settings.lora_path
+        if isinstance(lora_path, str):
+            self.settings_manager.settings.lora_path = StringVar(app)
+            self.settings_manager.settings.lora_path.set(lora_path)
         if "available_loras" not in self.settings_manager.settings.__dict__:
             self.settings_manager.settings.available_loras = {}
+        self.settings_manager.settings.lora_path.my_signal.connect(self.refresh_lora)
+
+    def refresh_lora(self):
+        self._available_loras = {}
+        for tab_name in self.app.tabs.keys():
+            tab = self.app.tabs[tab_name]
+            self.settings_manager.settings.available_loras[tab_name] = []
+
+            # find tab with name LoRA in tab.PromptTabsSection
+            for i in range(tab.PromptTabsSection.count()):
+                if tab.PromptTabsSection.tabText(i) == "LoRA":
+                    tab.PromptTabsSection.removeTab(i)
+                    break
+            self.generator_tab_injection(tab, tab_name)
+
+    def get_list_of_available_loras(self, tab_name, lora_path, lora_names=None):
+        if lora_names is None:
+            lora_names = []
+        if not os.path.exists(lora_path):
+            return lora_names
+        possible_line_endings = ["ckpt", "safetensors", "bin"]
+        for f in os.listdir(lora_path):
+            if os.path.isdir(os.path.join(lora_path, f)):
+                lora_names = self.get_list_of_available_loras(tab_name, os.path.join(lora_path, f), lora_names)
+            if f.split(".")[-1] in possible_line_endings:
+                name = f.split(".")[0]
+                scale = 100.0
+                enabled = True
+                # check if we have scale in self.settings_manager.settings.available_loras[tab_name]
+                if tab_name in self.settings_manager.settings.available_loras:
+                    loras = self.settings_manager.settings.available_loras[tab_name] or []
+                    for lora in loras:
+                        if lora["name"] == name:
+                            scale = lora["scale"]
+                            enabled = lora["enabled"]
+                            break
+                lora_names.append({
+                    "name": name,
+                    "scale": scale,
+                    "enabled": enabled
+                })
+        return lora_names
 
     def available_loras(self, tab_name):
+        base_path = self.settings_manager.settings.model_base_path.get()
+        lora_path = self.settings_manager.settings.lora_path.get() or "lora"
+        if lora_path == "lora":
+            lora_path = os.path.join(base_path, lora_path)
+        if not os.path.exists(lora_path):
+            return []
         if tab_name not in self._available_loras:
             self._available_loras[tab_name] = []
-            loras_path = os.path.join(self.model_base_path, "lora")
-            possible_line_endings = ["ckpt", "safetensors", "bin"]
             self.settings_manager.enable_save()
-
-            if not os.path.exists(loras_path):
-                os.makedirs(loras_path)
-
-            for f in os.listdir(loras_path):
-                if f.split(".")[-1] in possible_line_endings:
-                    name = f.split(".")[0]
-                    scale = 100.0
-                    enabled = True
-                    # check if we have scale in self.settings_manager.settings.available_loras[tab_name]
-                    if tab_name in self.settings_manager.settings.available_loras:
-                        loras = self.settings_manager.settings.available_loras[tab_name]
-                        for lora in loras:
-                            if lora["name"] == name:
-                                scale = lora["scale"]
-                                enabled = lora["enabled"]
-                                break
-                    self._available_loras[tab_name].append({
-                        "name": name,
-                        "scale": scale,
-                        "enabled": enabled
-                    })
+            self._available_loras[tab_name] = self.get_list_of_available_loras(tab_name, lora_path)
             self.settings_manager.settings.available_loras[tab_name] = self._available_loras[tab_name]
             self.settings_manager.save_settings()
         return self.settings_manager.settings.available_loras[tab_name]
@@ -94,16 +127,20 @@ def generator_tab_injection(self, tab, tab_name=None):
             container.layout().addWidget(lora_widget)
             lora_widget.scaleSlider.valueChanged.connect(
                 lambda value, _lora_widget=lora_widget, _lora=lora, _tab_name=tab_name:
-                self.handle_lora_slider(_lora, _lora_widget, value, _tab_name))
+                    self.handle_lora_slider(_lora, _lora_widget, value, _tab_name))
             lora_widget.scaleSpinBox.valueChanged.connect(
                 lambda value, _lora_widget=lora_widget, _lora=lora, _tab_name=tab_name:
-                self.handle_lora_spinbox(_lora, _lora_widget, value, _tab_name))
+                    self.handle_lora_spinbox(_lora, _lora_widget, value, _tab_name))
            lora_widget.enabledCheckbox.stateChanged.connect(
                 lambda value, _lora=lora, _tab_name=tab_name:
-                self.toggle_lora(lora, value, _tab_name))
+                    self.toggle_lora(lora, value, _tab_name))
         # add a vertical spacer to the end of the container
         container.layout().addStretch()
-        tab.PromptTabsSection.addTab(container, "LoRA")
+        # make the container scrollable
+        scroll = QScrollArea()
+        scroll.setWidget(container)
+        scroll.setWidgetResizable(True)
+        tab.PromptTabsSection.addTab(scroll, "LoRA")
 
     def toggle_lora(self, lora, value, tab_name):
         for n in range(len(self.available_loras(tab_name))):
@@ -129,7 +166,6 @@ def generate_data_injection(self, data):
         data["options"]["lora"] = []
         for lora in self.available_loras(data["action"]):
             if lora["enabled"]:
-                print(lora)
                 data["options"]["lora"].append((lora["name"], lora["scale"]))
         return data
 
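For context, downstream code receives the (name, scale) tuples collected in data["options"]["lora"]. A hedged sketch of how such tuples could be applied with the diffusers APIs imported at the top of main.py follows; the pipeline setup, model id, file paths, and the 0-100 to 0.0-1.0 scale conversion are illustrative assumptions, not AI Runner's actual handler.

import torch
from diffusers import StableDiffusionPipeline

def apply_loras(pipe, lora_dir, loras):
    # loras: list of (name, scale) tuples as built in generate_data_injection;
    # the extension stores scale as 0-100, diffusers expects roughly 0.0-1.0.
    # Note: older diffusers releases replace the previously loaded LoRA on each
    # load_lora_weights call, so this effectively applies the last enabled LoRA.
    scale = 1.0
    for name, value in loras:
        pipe.load_lora_weights(lora_dir, weight_name=f"{name}.safetensors")  # assumes .safetensors files
        scale = value / 100.0
    return scale

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
lora_scale = apply_loras(pipe, "/models/lora", [("my_style", 80.0)])
image = pipe("a photo of a cat", cross_attention_kwargs={"scale": lora_scale}).images[0]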
setup.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@

 setup(
     name='airunner-lora',
-    version="1.0.3",
+    version="1.0.4",
     author="Capsize LLC",
     description="LoRA extension for AI Runner",
     long_description=open("README.md", "r", encoding="utf-8").read(),
