Merge branch 'main' into patch-5
KoichiYasuoka authored Dec 28, 2024
2 parents e51b48a + 5c75087 commit b247fb7
Showing 7 changed files with 17 additions and 6 deletions.
@@ -149,7 +149,7 @@ def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_pa
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
-    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
     args = parser.parse_args()

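For context, these flags are the script's whole CLI. A minimal sketch that rebuilds the parser from this hunk (illustration only, not the actual conversion script; the sample file names are hypothetical):

    import argparse

    # Rebuilt from the hunk above to show what parsing produces.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")

    # Hypothetical arguments, purely for demonstration.
    args = parser.parse_args(["--checkpoint_path", "ViT-B-32.pt", "--pytorch_dump_folder_path", "./clip-hf"])
    print(args.checkpoint_path)  # ViT-B-32.pt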
@@ -171,7 +171,7 @@ def __init__(
             logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
             backbone_config = CONFIG_MAPPING["swin"](
                 image_size=224,
-                in_channels=3,
+                num_channels=3,
                 patch_size=4,
                 embed_dim=96,
                 depths=[2, 2, 18, 2],
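For reference, the Swin config exposes the input-channel count as `num_channels`, so the old `in_channels` kwarg never reached the parameter it was meant to set (most likely it was just stored as a stray attribute while `num_channels` kept its default of 3). A minimal sketch of the corrected call:

    from transformers import SwinConfig

    # `num_channels` is the parameter the Swin config actually reads;
    # the rename in this diff makes the intended value take effect.
    config = SwinConfig(
        image_size=224,
        num_channels=3,
        patch_size=4,
        embed_dim=96,
        depths=[2, 2, 18, 2],
    )
    print(config.num_channels)  # 3

The two hunks below apply the same `in_channels` -> `num_channels` rename to other default Swin backbone configs.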
@@ -131,7 +131,7 @@ def __init__(
             # fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
             backbone_config = SwinConfig(
                 image_size=384,
-                in_channels=3,
+                num_channels=3,
                 patch_size=4,
                 embed_dim=128,
                 depths=[2, 2, 18, 2],
@@ -201,7 +201,7 @@ def __init__(
             logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
             backbone_config = CONFIG_MAPPING["swin"](
                 image_size=224,
-                in_channels=3,
+                num_channels=3,
                 patch_size=4,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
3 changes: 3 additions & 0 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -82,6 +82,9 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
     config_class = TimmWrapperConfig
     _no_split_modules = []

+    # used in Trainer to avoid passing `loss_kwargs` to model forward
+    accepts_loss_kwargs = False
+
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision", "timm"])
         super().__init__(*args, **kwargs)
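The flag exists because, per the added comment, the Trainer should not pass `loss_kwargs` into this model's forward; a wrapper whose `forward` takes `**kwargs` only to hand them to a backend library would otherwise look, by signature alone, as if it accepted loss kwargs. A hedged sketch of the opt-out pattern (class name and forward body hypothetical):

    from transformers import PreTrainedModel

    class MyBackendWrapper(PreTrainedModel):  # hypothetical wrapper model
        # `forward` accepts `**kwargs` only to forward them to a backend
        # library, so tell the Trainer not to route loss kwargs (such as
        # `num_items_in_batch`) through it.
        accepts_loss_kwargs = False

        def forward(self, pixel_values, **backend_kwargs):
            return self.backbone(pixel_values, **backend_kwargs)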
10 changes: 9 additions & 1 deletion src/transformers/trainer.py
@@ -622,7 +622,15 @@ def __init__(
             else unwrapped_model.get_base_model().forward
         )
         forward_params = inspect.signature(model_forward).parameters
-        self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
+
+        # Check whether the model states explicitly if it accepts loss kwargs;
+        # if not, fall back to checking for `**kwargs` in `model.forward`.
+        if hasattr(model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = model.accepts_loss_kwargs
+        else:
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )

         self.neftune_noise_alpha = args.neftune_noise_alpha

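The `else` branch preserves the previous heuristic: a model is assumed to accept loss kwargs when its `forward` has a `**kwargs`-style (VAR_KEYWORD) parameter. A self-contained illustration of that check:

    import inspect

    def forward_with_kwargs(input_ids, **kwargs):
        ...

    def forward_plain(input_ids, attention_mask=None):
        ...

    def accepts_var_kwargs(fn):
        # True when the signature contains a `**kwargs`-style parameter.
        params = inspect.signature(fn).parameters
        return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

    print(accepts_var_kwargs(forward_with_kwargs))  # True
    print(accepts_var_kwargs(forward_plain))        # False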
2 changes: 1 addition & 1 deletion src/transformers/training_args.py
@@ -2164,7 +2164,7 @@ def _setup_devices(self) -> "torch.device":
         if not is_accelerate_available():
             raise ImportError(
                 f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
-                "Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+                f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
             )
         # We delay the init of `PartialState` to the end for clarity
         accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
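The one-character fix above matters because a plain string leaves the `{ACCELERATE_MIN_VERSION}` placeholder in the error message verbatim; only an f-string interpolates it. A quick illustration (the version string here is made up, not the library's actual constant):

    ACCELERATE_MIN_VERSION = "0.26.0"  # illustrative value only

    plain = "pip install 'accelerate>={ACCELERATE_MIN_VERSION}'"
    fixed = f"pip install 'accelerate>={ACCELERATE_MIN_VERSION}'"

    print(plain)  # pip install 'accelerate>={ACCELERATE_MIN_VERSION}'  <- placeholder shown literally
    print(fixed)  # pip install 'accelerate>=0.26.0'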
