Skip to content

Commit

Permalink
Update to all weights only False
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams committed Nov 18, 2024
1 parent a34652e commit 703e38f
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
for root, _, files in os.walk(save_folder):
for f in files:
if "_expert_" in f and "_model_states" in f:
expert = torch.load(os.path.join(root, f), weights_only=True)
expert = torch.load(os.path.join(root, f), weights_only=False)
needed, storages = 0, {}
for name, tensor in expert.items():
needed += tensor.size().numel()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/checkpoint/test_universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
)

hidden_dim = 10
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=True)
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)

ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/checkpoint/test_zero_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
all_ckpt_folder = os.path.join(tmpdir, 'all_params')
ds_engine.save_checkpoint(all_ckpt_folder)
all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=True)['module']
loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
all_param_names = set([n for n, p in model.named_parameters()])
assert set(loaded_all_param_model.keys()) == all_param_names

Expand All @@ -536,7 +536,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
# Excluding frozen parameters should reduce checkpoint size
assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)

loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=True)['module']
loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
Expand Down Expand Up @@ -575,7 +575,7 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage):

custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=True)['module']
loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())

custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir):
test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
if dist.get_rank() == 0:
load_path = os.path.join(class_tmpdir, "output.pt")
baseline = torch.load(load_path, weights_only=True)
baseline = torch.load(load_path, weights_only=False)
test = test.cpu()
assert torch.allclose(
baseline, test,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz
assert torch.is_tensor(test[0][0])
test = test[0][0].cpu()
load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
baseline = torch.load(load_path, weights_only=True)
baseline = torch.load(load_path, weights_only=False)
assert torch.allclose(
baseline, test,
atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"
Expand Down

0 comments on commit 703e38f

Please sign in to comment.