diff --git a/src/maxdiffusion/loaders/lora_base.py b/src/maxdiffusion/loaders/lora_base.py index c4cc432..f22696d 100644 --- a/src/maxdiffusion/loaders/lora_base.py +++ b/src/maxdiffusion/loaders/lora_base.py @@ -44,7 +44,7 @@ def _fetch_state_dict( user_agent, allow_pickle, ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + from .lora_pipeline import LORA_WEIGHT_NAME_SAFE model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index f5d573d..82d32e8 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -14,8 +14,6 @@ limitations under the License. """ -import os - from typing import Union, Tuple, Optional import jax import jax.numpy as jnp diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 38e7078..5f02ec8 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from flax.linen import Partitioned from flax.traverse_util import flatten_dict, unflatten_dict -from flax.core.frozen_dict import unfreeze, freeze +from flax.core.frozen_dict import unfreeze from jax.random import PRNGKey from ..utils import logging diff --git a/src/maxdiffusion/models/resnet_flax.py b/src/maxdiffusion/models/resnet_flax.py index 9fa0a92..79ddcb3 100644 --- a/src/maxdiffusion/models/resnet_flax.py +++ b/src/maxdiffusion/models/resnet_flax.py @@ -163,7 +163,6 @@ def setup(self): ) def __call__(self, hidden_states, temb, deterministic=True, cross_attention_kwargs={}): - lora_scale = cross_attention_kwargs.get("scale", 0.0) residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states)