[SD3 LoRA] Fix list index out of range (#8584)
* fix

* add check

* key present is checked before

* test case draft

* apply suggestions

* changed testing repo, back to old class

* forgot docstring

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Jun 21, 2024
1 parent 8eb1731 commit e7b9a07
Showing 2 changed files with 28 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/diffusers/loaders/lora.py
@@ -30,6 +30,7 @@
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
+    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -1543,6 +1544,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
         }

         if len(state_dict.keys()) > 0:
+            # check the first key to see whether the state dict is already in the PEFT format
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
             if adapter_name in getattr(transformer, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
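For context, diffusers-format SD3 LoRA checkpoints use key suffixes like .lora.down.weight / .lora.up.weight, while PEFT expects lora_A / lora_B. The added guard inspects one key and, if it does not contain lora_A, routes the state dict through convert_unet_state_dict_to_peft. Below is a minimal illustrative sketch of that idea, not the library's actual implementation: the to_peft_keys helper and the example keys are placeholders, and the real converter handles many more patterns.

import torch

# Hypothetical sketch of the key renaming performed by convert_unet_state_dict_to_peft;
# the real helper in diffusers covers more key patterns than shown here.
def to_peft_keys(state_dict):
    converted = {}
    for key, tensor in state_dict.items():
        new_key = key.replace(".lora.down.weight", ".lora_A.weight")
        new_key = new_key.replace(".lora.up.weight", ".lora_B.weight")
        converted[new_key] = tensor
    return converted


# Example: the first key contains no "lora_A", so the dict gets converted.
state_dict = {
    "transformer_blocks.0.attn.to_q.lora.down.weight": torch.zeros(4, 8),
    "transformer_blocks.0.attn.to_q.lora.up.weight": torch.zeros(8, 4),
}
first_key = next(iter(state_dict))
if "lora_A" not in first_key:
    state_dict = to_peft_keys(state_dict)
print(sorted(state_dict))  # keys now end in lora_A.weight / lora_B.weight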
23 changes: 22 additions & 1 deletion tests/lora/test_lora_layers_sd3.py
@@ -27,7 +27,7 @@
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
-from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, torch_device
+from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device


 if is_peft_available():
@@ -287,3 +287,24 @@ def test_simple_inference_with_transformer_fuse_unfuse(self):
         self.assertTrue(
             np.allclose(ouput_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )
+
+    @require_torch_gpu
+    def test_sd3_lora(self):
+        """
+        Test loading LoRAs saved in the diffusers and PEFT formats.
+        Related PR: https://github.com/huggingface/diffusers/pull/8584
+        """
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        lora_model_id = "hf-internal-testing/tiny-sd3-loras"
+
+        lora_filename = "lora_diffusers_format.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        pipe.unload_lora_weights()
+
+        lora_filename = "lora_peft_format.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
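With this change, pipe.load_lora_weights accepts SD3 LoRA files in either key format. A minimal usage sketch follows, assuming a hypothetical repository my-user/my-sd3-lora that contains one file per format; the repository and file names are placeholders, not real artifacts.

import torch
from diffusers import StableDiffusion3Pipeline

# Illustrative usage only: "my-user/my-sd3-lora" and the two weight file
# names are hypothetical placeholders.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# LoRA saved with diffusers-style keys (lora.down / lora.up)
pipe.load_lora_weights("my-user/my-sd3-lora", weight_name="lora_diffusers_format.safetensors")
pipe.unload_lora_weights()

# LoRA saved with PEFT-style keys (lora_A / lora_B)
pipe.load_lora_weights("my-user/my-sd3-lora", weight_name="lora_peft_format.safetensors")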
