Skip to content

Commit

Permalink
fix linting errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
jfacevedo-google committed Oct 24, 2024
1 parent 4e40024 commit 1413e86
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/maxdiffusion/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _fetch_state_dict(
user_agent,
allow_pickle,
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
from .lora_pipeline import LORA_WEIGHT_NAME_SAFE

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
Expand Down
2 changes: 0 additions & 2 deletions src/maxdiffusion/models/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
limitations under the License.
"""

import os

from typing import Union, Tuple, Optional
import jax
import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import jax.numpy as jnp
from flax.linen import Partitioned
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import unfreeze, freeze
from flax.core.frozen_dict import unfreeze
from jax.random import PRNGKey

from ..utils import logging
Expand Down
1 change: 0 additions & 1 deletion src/maxdiffusion/models/resnet_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def setup(self):
)

def __call__(self, hidden_states, temb, deterministic=True, cross_attention_kwargs={}):
lora_scale = cross_attention_kwargs.get("scale", 0.0)
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
Expand Down

0 comments on commit 1413e86

Please sign in to comment.