Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hyper SDXL Lora support #127

Merged
merged 16 commits into from
Oct 31, 2024
Merged

Hyper SDXL Lora support #127

merged 16 commits into from
Oct 31, 2024

Conversation

entrpn
Copy link
Collaborator

@entrpn entrpn commented Oct 24, 2024

This PR establishes LoRA support and includes Hyper-SD XL LoRA loading for inference.

@entrpn entrpn requested a review from anfals October 24, 2024 22:21
Copy link
Collaborator

@anfals anfals left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review in progress

Still need to look at src/maxdiffusion/models/modeling_flax_pytorch_utils.py and src/maxdiffusion/loaders/lora_conversion_utils.py

src/maxdiffusion/loaders/lora_base.py Show resolved Hide resolved

@classmethod
@validate_hf_hub_args
def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think calling this something like get_lora_state_dict might be clearer about what the method is doing?

src/maxdiffusion/loaders/lora_pipeline.py Show resolved Hide resolved
anfals
anfals previously approved these changes Oct 30, 2024
Copy link
Collaborator

@anfals anfals left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Approving, but I did have a question on the infinite loop avoidance logic.

src/maxdiffusion/loaders/lora_pipeline.py Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mostly just skimming this file since it seems to just be a lot of wrangling of PyTorch state_dict to Flax. Someone who works more on MaxDiffusion might have more thoughts

@@ -26,4 +26,6 @@ git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint>=0.5.20
tokenizers==0.20.0
huggingface_hub==0.24.7
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated line?

src/maxdiffusion/loaders/lora_pipeline.py Show resolved Hide resolved
@entrpn entrpn merged commit 1deeca5 into main Oct 31, 2024
3 checks passed
@entrpn entrpn deleted the lora_support branch October 31, 2024 18:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants