diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py
index aeb42cd513..02f563dadc 100644
--- a/tests/models/test_hf_model.py
+++ b/tests/models/test_hf_model.py
@@ -490,6 +490,16 @@ def get_lm_trainer(hf_model,
         should_save_peft_only=should_save_peft_only,
     )
 
+    # On torch 2.0, fsdp wrapped modules can not have both frozen and unfrozen params.
+    # On 2.1+, if you have use_orig_params=True, they can. So we need a special case for the tests here.
+    if version.parse(torch.__version__) < version.parse('2.1.0') and peft_config is not None:
+        for name, module in model.named_modules():
+            if 'lora' in name.lower() and 'default' in name.lower():
+                has_parameters = any(True for _ in module.parameters())
+                has_buffers = any(True for _ in module.buffers())
+                if has_parameters or has_buffers:
+                    module._fsdp_wrap = True  # type: ignore
+
     vocab_size = hf_model.config.vocab_size
     sequence_length = 4
     size = 4