Skip to content

Commit

Permalink
[MAINTENANCE] Refactor and clean up. (#4008)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsherstinsky authored May 23, 2024
1 parent 7053966 commit 3b9192b
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions ludwig/encoders/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,32 +382,37 @@ def __init__(
raise ValueError("img_height and img_width should be identical.")
self._input_shape = (in_channels, img_height, img_width)

config_dict: dict
if use_pretrained and not saved_weights_in_checkpoint:
transformer = ViTModel.from_pretrained(pretrained_model)
config_dict = {
"pretrained_model_name_or_path": pretrained_model,
}
if output_attentions:
config_dict["attn_implementation"] = "eager"

transformer = ViTModel.from_pretrained(**config_dict)
else:
config = ViTConfig(
image_size=img_height,
num_channels=in_channels,
patch_size=patch_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
gradient_checkpointing=gradient_checkpointing,
)
config_dict = {
"image_size": img_height,
"num_channels": in_channels,
"patch_size": patch_size,
"hidden_size": hidden_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"intermediate_size": intermediate_size,
"hidden_act": hidden_act,
"hidden_dropout_prob": hidden_dropout_prob,
"attention_probs_dropout_prob": attention_probs_dropout_prob,
"initializer_range": initializer_range,
"layer_norm_eps": layer_norm_eps,
"gradient_checkpointing": gradient_checkpointing,
}
if output_attentions:
config_dict["attn_implementation"] = "eager"

config = ViTConfig(**config_dict)
transformer = ViTModel(config)

if output_attentions:
config_dict: dict = transformer.config.to_dict()
updated_config: ViTConfig = ViTConfig(**config_dict)
updated_config._attn_implementation = "eager"
transformer = ViTModel(updated_config)

self.transformer = FreezeModule(transformer, frozen=not trainable)

self._output_shape = (transformer.config.hidden_size,)
Expand Down

0 comments on commit 3b9192b

Please sign in to comment.