Commit

fixes lora loading
w4ffl35 committed Apr 28, 2023
1 parent 8f4cdc9 commit d9e5926
Showing 1 changed file with 67 additions and 61 deletions.
main.py: 128 changes (67 additions & 61 deletions)
@@ -33,17 +33,9 @@ def emit(self):
         self.my_signal.emit(name, scale, enabled)


-class AvailableLorasVar(Var):
-    my_signal = pyqtSignal(list)
-
-    def __init__(self, app=None, loras=None):
-        super().__init__(app, None)
-        self.loras = loras
-
-
 class Settings:
     def __init__(self, app):
-        self.available_loras = AvailableLorasVar(app, [])
+        self.available_loras = []


 class Extension(BaseExtension):
@@ -52,81 +44,94 @@ class Extension(BaseExtension):

     def __init__(self, settings_manager=None):
         super().__init__(settings_manager)
-        # print stack trace
-        import traceback
-        traceback.print_stack()
-        self._available_loras = None
-        self.settings_manager.settings.available_loras = AvailableLorasVar(self, [])
-
-    @property
-    def available_loras(self):
-        if self._available_loras is None:
-            _available_loras = []
+        self._available_loras = {}
+        if "available_loras" not in self.settings_manager.settings.__dict__:
+            self.settings_manager.settings.available_loras = {}
+
+    def available_loras(self, tab_name):
+        if tab_name not in self._available_loras:
+            self._available_loras[tab_name] = []
             loras_path = os.path.join(self.model_base_path, "lora")
             possible_line_endings = ["ckpt", "safetensors", "bin"]
+            self.settings_manager.enable_save()

             for f in os.listdir(loras_path):
                 if f.split(".")[-1] in possible_line_endings:
-                    lora = LoraVar(
-                        name=f.split(".")[0],
-                        scale=1.0,
-                        enabled=True
-                    )
-                    """
-                    LoraVar ends up having the same name for all of the loras. This is because the name is set
-                    in the constructor and the constructor is called for each lora.
-                    We can fix this by using a lambda function to set the name of the lora when it is created.
-                    """
-                    _available_loras.append(lora)
-            self.settings_manager.settings.available_loras.set(_available_loras)
-        return self.settings_manager.settings.available_loras.get()
-
-    def generator_tab_injection(self, tab, name=None):
-        # use the lora.ui widget which contains
-        # - a QCheckBox labled enabledCheckbox
-        # - a QSlider labled scaleSlider (0 - 100)
-        # - a QDoubleSpinBox labled scaleSpinBox (0.0 - 1.0)
-        # we will disable the name of the lora and set all of the properties of the lora and store this in settings
-        # these lora.ui widgets will be added to a QScrollArea widget on the tab
+                    name = f.split(".")[0]
+                    scale = 100.0
+                    enabled = True
+                    # check if we have scale in self.settings_manager.settings.available_loras[tab_name]
+                    if tab_name in self.settings_manager.settings.available_loras:
+                        loras = self.settings_manager.settings.available_loras[tab_name]
+                        for lora in loras:
+                            if lora["name"] == name:
+                                scale = lora["scale"]
+                                enabled = lora["enabled"]
+                                break
+                    self._available_loras[tab_name].append({
+                        "name": name,
+                        "scale": scale,
+                        "enabled": enabled
+                    })
+            self.settings_manager.settings.available_loras[tab_name] = self._available_loras[tab_name]
+            self.settings_manager.save_settings()
+        return self.settings_manager.settings.available_loras[tab_name]
+
+    def generator_tab_injection(self, tab, tab_name=None):
         container = QWidget()
         container.setLayout(QVBoxLayout())
-        for lora in self.available_loras:
-            # load the lora.ui widget
+        for lora in self.available_loras(tab_name):
             lora_widget = self.load_template("lora")
-            lora_widget.enabledCheckbox.setText(lora.name.get())
-            lora_widget.scaleSlider.setValue(int(lora.scale.get() * 100))
-            lora_widget.scaleSpinBox.setValue(lora.scale.get())
-            lora_widget.enabledCheckbox.setChecked(lora.enabled.get())
+            lora_widget.enabledCheckbox.setText(lora["name"])
+            scale = lora["scale"]
+            enabled = lora["enabled"]
+            lora_widget.scaleSlider.setValue(int(scale))
+            lora_widget.scaleSpinBox.setValue(scale / 100)
+            lora_widget.enabledCheckbox.setChecked(enabled)
             container.layout().addWidget(lora_widget)

             # connect the signals to properties of the lora
             lora_widget.scaleSlider.valueChanged.connect(
-                lambda value, _lora_widget=lora_widget: _lora_widget.scaleSpinBox.setValue(value / 100))
+                lambda value, _lora_widget=lora_widget, _lora=lora, _tab_name=tab_name:
+                self.handle_lora_slider(_lora, _lora_widget, value, _tab_name))
             lora_widget.scaleSpinBox.valueChanged.connect(
-                lambda value, _lora_widget=lora_widget: _lora_widget.scaleSlider.setValue(int(value * 100)))
+                lambda value, _lora_widget=lora_widget, _lora=lora, _tab_name=tab_name:
+                self.handle_lora_spinbox(_lora, _lora_widget, value, _tab_name))
             lora_widget.enabledCheckbox.stateChanged.connect(
-                lambda value, _lora=lora: setattr(_lora, "enabled", value == 2))
-            lora_widget.scaleSlider.valueChanged.connect(lambda value, _lora=lora: setattr(_lora, "scale", value / 100))
-            lora_widget.scaleSpinBox.valueChanged.connect(lambda value, _lora=lora: setattr(_lora, "scale", value))
+                lambda value, _lora=lora, _tab_name=tab_name:
+                self.toggle_lora(lora, value, _tab_name))
         # add a vertical spacer to the end of the container
         container.layout().addStretch()

         # create a new tab called "LoRA" on the tab.PromptTabsSection which is a QTabWidget
         # add the container to the tab
         tab.PromptTabsSection.addTab(container, "LoRA")

+    def toggle_lora(self, lora, value, tab_name):
+        for n in range(len(self.available_loras(tab_name))):
+            if self.settings_manager.settings.available_loras[tab_name][n]["name"] == lora["name"]:
+                self.settings_manager.settings.available_loras[tab_name][n]["enabled"] = value == 2
+        self.settings_manager.save_settings()
+
+    def handle_lora_slider(self, lora, lora_widget, value, tab_name):
+        for n in range(len(self.available_loras(tab_name))):
+            if self.settings_manager.settings.available_loras[tab_name][n]["name"] == lora["name"]:
+                self.settings_manager.settings.available_loras[tab_name][n]["scale"] = value / 100
+        lora_widget.scaleSpinBox.setValue(lora["scale"])
+        self.settings_manager.save_settings()
+
+    def handle_lora_spinbox(self, lora, lora_widget, value, tab_name):
+        for n in range(len(self.available_loras(tab_name))):
+            if self.settings_manager.settings.available_loras[tab_name][n]["name"] == lora["name"]:
+                self.settings_manager.settings.available_loras[tab_name][n]["scale"] = value * 100
+        lora_widget.scaleSlider.setValue(int(lora["scale"]))
+        self.settings_manager.save_settings()

     def generate_data_injection(self, data):
-        for lora in self.available_loras:
-            if lora.enabled.get():
+        for lora in self.available_loras(data["action"]):
+            if lora["enabled"]:
                 data["options"]["lora"] = [(lora.name.get(), lora.scale.get())]
         return data

     def call_pipe(self, options, model_base_path, pipe, **kwargs):
         if not self.lora_loaded:
             for lora in options["lora"]:
                 path = os.path.join(model_base_path, "lora")
                 # find a file with the name of lora[0] in path with an extension of .ckpt or .pt or .bin or .safetensors

                 # find it first:
                 filepath = None
                 for root, dirs, files in os.walk(path):
                     for file in files:
@@ -140,6 +145,7 @@ def call_pipe(self, options, model_base_path, pipe, **kwargs):
            self.lora_loaded = True
        return pipe

+    # https://github.com/huggingface/diffusers/issues/3064
     def load_lora(self, pipeline, checkpoint_path, multiplier=1.0, device="cuda", dtype=torch.float16):
         LORA_PREFIX_UNET = "lora_unet"
         LORA_PREFIX_TEXT_ENCODER = "lora_te"
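
Note: after this change, available_loras is keyed by generator tab and persisted as plain dicts instead of Var objects. A minimal sketch of the resulting settings structure, assuming a hypothetical txt2img tab and two hypothetical files in <model_base_path>/lora:

# Hypothetical contents of settings_manager.settings.available_loras
# after the txt2img tab has been injected once:
available_loras = {
    "txt2img": [
        {"name": "detail_tweaker", "scale": 100.0, "enabled": True},
        {"name": "paper_cutout", "scale": 100.0, "enabled": False},
    ],
}

The default scale is stored as 100.0 on the slider's 0-100 range; the spin box displays scale / 100.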
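
Note: the docstring removed above and the rewritten signal connections both concern the same Python pitfall: a lambda created inside a loop closes over the loop variable itself, not its value at that iteration, so every callback ends up seeing the last lora. The new connects pin the current values with default arguments (_lora=lora, _tab_name=tab_name). A standalone illustration, not taken from the commit:

# Late binding: both callbacks see the final loop value.
callbacks = [lambda: print(lora) for lora in ("a", "b")]
for cb in callbacks:
    cb()  # prints "b" twice

# Binding the value as a default argument captures it per iteration,
# as the new valueChanged/stateChanged handlers do with _lora=lora.
callbacks = [lambda _lora=lora: print(_lora) for lora in ("a", "b")]
for cb in callbacks:
    cb()  # prints "a", then "b"

For the same reason, note that the new stateChanged handler passes the bare loop variable lora into self.toggle_lora rather than its bound default _lora, so that argument still resolves via late binding.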

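Note: the body of load_lora is collapsed in the diff above; the linked diffusers issue (#3064) describes merging kohya-style LoRA weights directly into the pipeline's UNet and text encoder. A condensed sketch of that approach, simplified from the issue and not necessarily identical to the collapsed code in this commit (alpha scaling and 3x3 conv kernels are omitted):

import torch
from safetensors.torch import load_file

def load_lora(pipeline, checkpoint_path, multiplier=1.0, device="cuda", dtype=torch.float16):
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    state_dict = load_file(checkpoint_path, device=device)
    visited = set()
    for key in state_dict:
        # Each LoRA layer contributes lora_up.weight / lora_down.weight
        # (plus an alpha scalar, ignored here); handle each pair once.
        if ".alpha" in key or key in visited:
            continue
        # Resolve the flattened key to a module in the text encoder or UNet.
        if key.startswith(LORA_PREFIX_TEXT_ENCODER):
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if not layer_infos:
                    break
                temp_name = layer_infos.pop(0)
            except AttributeError:
                # Module names themselves contain underscores; keep joining
                # tokens until the attribute lookup succeeds.
                temp_name += "_" + layer_infos.pop(0)
        # Pair up the down/up projection weights for this layer.
        if "lora_down" in key:
            pair_keys = [key.replace("lora_down", "lora_up"), key]
        else:
            pair_keys = [key, key.replace("lora_up", "lora_down")]
        weight_up = state_dict[pair_keys[0]].to(dtype)
        weight_down = state_dict[pair_keys[1]].to(dtype)
        if weight_up.dim() == 4:  # 1x1 conv: collapse to a matmul
            update = torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:  # linear layer
            update = torch.mm(weight_up, weight_down)
        curr_layer.weight.data += multiplier * update
        visited.update(pair_keys)
    return pipeline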