diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py
index 31997177a262..9a368b7a0a25 100644
--- a/deepspeed/checkpoint/deepspeed_checkpoint.py
+++ b/deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -116,7 +116,7 @@ def show_transformer_file_map(self):
         self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')
 
     def _build_global_state(self):
-        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
         self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
         self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
 
@@ -137,14 +137,17 @@ def get_final_norm_layer_id(self):
 
     def get_iteration(self):
         if not ITERATION_KEY in self.global_state:
-            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
             self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
 
         return self.global_state[ITERATION_KEY]
 
     def get_embedding_state(self, tp_index: int) -> Dict:
         assert tp_index in self.tp_to_embedding_map.keys()
-        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
+        sd_list = [
+            torch.load(fname, map_location=torch.device('cpu'), weights_only=False)
+            for fname in self.tp_to_embedding_map[tp_index]
+        ]
         sd = self._merge_state_dicts(sd_list)
         return sd
 
@@ -154,7 +157,7 @@ def get_embedding_files(self, tp_index: int) -> list:
 
     def _get_checkpoint_value(self, key):
         if not key in self.global_state:
-            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
+            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
             self.global_state[key] = sd.get(key, None)
 
         return self.global_state[key]
@@ -169,7 +172,7 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
         assert tp_index < self.tp_degree
         assert pp_index < self.pp_degree
         fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
-        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+        sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
 
         merged_sd = None
         for sd in sd_list:
@@ -185,7 +188,7 @@ def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
         assert pp_index < self.pp_degree
         t_list = []
         for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
-            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
+            sd_list = [torch.load(fname, map_location=torch.device('cpu'), weights_only=False) for fname in fname_list]
             sd = self._merge_state_dicts(sd_list)
             t_list.append(sd)
         return t_list
@@ -196,7 +199,7 @@ def get_pp_transformer_map(self, pp_index: int) -> list:
 
     def get_final_norm_state(self, tp_index: int) -> Dict:
         assert tp_index in self.tp_to_final_norm_map.keys()
-        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
+        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'), weights_only=False)
         return sd
 
     def get_final_norm_files(self, tp_index: int) -> list:
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index e5974a30df22..f7b75eee66d0 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -150,7 +150,7 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
 
 
 def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
-    state_dict = torch.load(optim_files[dp_index], map_location='cpu')
+    state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)
 
     flat_state = dict(
         exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
@@ -214,7 +214,7 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
             raise ValueError(f"Cannot parse dp_rank from {p}")
 
     paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
-    shards = [torch.load(p) for p in paths]
+    shards = [torch.load(p, weights_only=False) for p in paths]
 
     if state == "step":
         assert all(v == shards[0] for v in shards), "All shards must have the same step value"
@@ -404,7 +404,7 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size):
 
 
 def _parse_model_states_stage3(files):
-    return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
+    return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES]
 
 
 def _save_optimizer_state(args, ds_checkpoint):
@@ -420,7 +420,7 @@ def _save_optimizer_state(args, ds_checkpoint):
 
 
 def _save_optimizer_state_stage3(args, optim_files):
-    sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
+    sd = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
     output_sd = sd[OPTIMIZER_STATE_DICT]
     output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
     zero_output_folder = os.path.join(args.output_folder, "zero")
@@ -446,7 +446,7 @@ def _get_checkpoint_files(checkpoint_dir, glob_pattern):
 
 
 def _get_zero_stage(optim_files):
-    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
+    state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'), weights_only=False)
     optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
     zero_stage = optimizer_state.get(ZERO_STAGE, 1)
     return zero_stage
@@ -454,7 +454,7 @@ def _get_zero_stage(optim_files):
 
 def _inject_missing_state(ds_checkpoint):
     if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
-        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
+        sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
         if UNIVERSAL_CHECKPOINT_INFO not in sd:
             ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
             ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
@@ -488,7 +488,7 @@ def main(args):
 
     slice_shapes = []
     for mp_rank_file in ds_checkpoint.mp_rank_files:
-        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
+        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'), weights_only=False)
         slice_shapes += mp_sd[PARAM_SHAPES]
 
     # fix back to normal flat dict, merge duplicates for tp>1
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index 064891a8bb54..266d5a063595 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -34,7 +34,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     step = None
     for key in hp_keys:
         ckpt_file = os.path.join(folder, f"{key}.pt")
-        ckpt_dict = torch.load(ckpt_file)
+        ckpt_dict = torch.load(ckpt_file, weights_only=False)
 
         if key == "step":
             step = ckpt_dict
diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py
index 6730b93dfd4f..c85f0241005d 100644
--- a/deepspeed/checkpoint/zero_checkpoint.py
+++ b/deepspeed/checkpoint/zero_checkpoint.py
@@ -54,7 +54,7 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], st
         state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
         merged_sd = None
         for state_file in state_file_list:
-            sd = torch.load(state_file, map_location=torch.device('cpu'))
+            sd = torch.load(state_file, map_location=torch.device('cpu'), weights_only=False)
             for key in keys_to_ignore:
                 sd.pop(key, None)
 
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 6574d49fb132..cfca1ff4fe4c 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -452,7 +452,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
             checkpoint = sd_loader['checkpoints']
 
         if type(checkpoint) is list:
-            self.sd = torch.load(checkpoint[0], map_location='cpu')
+            self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False)
             self.key_list = list(self.sd.keys())
 
             self.load_model_with_checkpoint(self.module)
@@ -460,7 +460,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
             for i in range(1, len(checkpoint)):
                 if not dist.is_initialized() or dist.get_rank() == 0:
                     print(f"loading checkpoint ({i})")
-                self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
+                self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False)
                 self.key_list = list(self.sd.keys())
                 self.load_model_with_checkpoint(self.module)
         else:
diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
index d88d99ebebfd..b17bb886838f 100644
--- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -80,7 +80,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
         else:
             model_param_json_fname = "pytorch_model.bin.index.json"
             model_file_fname = "pytorch_model.bin"
-            self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
+            self._checkpoint_load_fn = partial(torch.load, map_location="cpu", weights_only=False)
 
         model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)
 
diff --git a/deepspeed/inference/v2/model_implementations/inference_policy_base.py b/deepspeed/inference/v2/model_implementations/inference_policy_base.py
index d5a326c03599..2f4266a8cb88 100644
--- a/deepspeed/inference/v2/model_implementations/inference_policy_base.py
+++ b/deepspeed/inference/v2/model_implementations/inference_policy_base.py
@@ -205,7 +205,7 @@ def populate_model_parameters(self) -> None:
            buffer_path = make_param_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
            metadata_path = make_metadata_filename(self._inf_checkpoint_path, self.model.tp_rank, self.model.tp_size)
 
-            buffer = torch.load(buffer_path)
+            buffer = torch.load(buffer_path, weights_only=False)
            metadata = json.load(open(metadata_path, "r"))
            metadata = ModelMetadata.parse_raw(metadata)
 
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 1c5745dcf168..7afe6ca903fb 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -415,7 +415,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                 pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
 
             for i in range(len(checkpoint)):
-                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
+                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
                 load_model_with_checkpoint(replaced_module,
                                            sd,
                                            mp_replace,
@@ -437,7 +437,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                     os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                     for j in range(sd_count)
                 ]
-                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
+                sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
                 load_model_with_checkpoint(replaced_module,
                                            sds,
                                            mp_replace,
@@ -457,7 +457,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                     pbar.update(1)
 
                 ckpt_file = os.path.join(base_dir1, checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
-                sds = [torch.load(ckpt_file, map_location='cpu')]
+                sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
                 load_model_with_checkpoint(replaced_module,
                                            sds,
                                            mp_replace,
@@ -624,7 +624,7 @@ def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=No
        from safetensors.torch import load_file
        sd = load_file(checkpoint)
    else:
-        sd = torch.load(checkpoint, map_location='cpu')
+        sd = torch.load(checkpoint, map_location='cpu', weights_only=False)
 
    policy = {}
    if orig_class is not None:
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 6cfd66f1cc38..b8df7499450d 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -22,7 +22,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
         assert os.path.isfile(
             optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path)
+        optim_sd = torch.load(optim_state_path, weights_only=False)
 
         self._load_global_state(optim_sd)
 
diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
index e26e3243c4b5..e834bf0d22d7 100644
--- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
@@ -58,7 +58,7 @@ def load(self, path: str, map_location=None):
         if not self.enable_nebula_load and first_load_flag:
             self.tag_flag = tag
             logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
-            partition = torch.load(path, map_location=map_location)
+            partition = torch.load(path, map_location=map_location, weights_only=False)
             logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
             return partition
 
diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
index 5cd44864bb2e..076c638532ad 100644
--- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
@@ -25,7 +25,7 @@ def save(self, state_dict, path: str):
 
     def load(self, path: str, map_location=None):
         logger.info(f"[Torch] Loading checkpoint from {path}...")
-        partition = torch.load(path, map_location=map_location)
+        partition = torch.load(path, map_location=map_location, weights_only=False)
         logger.info(f"[Torch] Loaded checkpoint from {path}.")
         return partition
 
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 04d52319ae8c..99a5ecf41a2f 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2741,7 +2741,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
         assert os.path.isfile(
             optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
 
-        optim_sd = torch.load(optim_state_path)
+        optim_sd = torch.load(optim_state_path, weights_only=False)
         self._load_global_state_stage3(optim_sd)
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
@@ -2799,7 +2799,7 @@ def load_hp_checkpoint_state(self, folder, key):
         local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt")).view(-1)
+        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
 
         # Partition the loaded data according to the local rank
         world_size = dist.get_world_size(group=self.dp_process_group)
diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index c0768deae62b..e93cb1c95f15 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -102,7 +102,7 @@ def get_model_state_files(checkpoint_dir):
 def parse_model_states(files):
     zero_model_states = []
     for file in files:
-        state_dict = torch.load(file, map_location=device)
+        state_dict = torch.load(file, map_location=device, weights_only=False)
 
         if BUFFER_NAMES not in state_dict:
             raise ValueError(f"{file} is not a model state checkpoint")
@@ -149,7 +149,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
     for f in tqdm(files, desc='Loading checkpoint shards'):
-        state_dict = torch.load(f, map_location=device, mmap=True)
+        state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
         # and also handle the case where it was already removed by another helper script
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py
index 3fb13b214ea0..001c08f1a99f 100644
--- a/tests/unit/checkpoint/common.py
+++ b/tests/unit/checkpoint/common.py
@@ -218,7 +218,7 @@ def checkpoint_correctness_verification(config_dict,
         for root, _, files in os.walk(save_folder):
             for f in files:
                 if "_expert_" in f and "_model_states" in f:
-                    expert = torch.load(os.path.join(root, f))
+                    expert = torch.load(os.path.join(root, f), weights_only=False)
                     needed, storages = 0, {}
                     for name, tensor in expert.items():
                         needed += tensor.size().numel()
diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py
index 27ddf0cdef39..46d4294bdd0d 100644
--- a/tests/unit/checkpoint/test_universal_checkpoint.py
+++ b/tests/unit/checkpoint/test_universal_checkpoint.py
@@ -181,7 +181,7 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
         )
         hidden_dim = 10
 
-        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
+        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)
 
         ds_config["checkpoint"] = {"load_universal": True}
         univ_model = SimpleModel(hidden_dim)
diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py
index 84b4eca6e2ca..44966b331d0f 100644
--- a/tests/unit/checkpoint/test_zero_optimizer.py
+++ b/tests/unit/checkpoint/test_zero_optimizer.py
@@ -264,7 +264,7 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l
         model.load_checkpoint(tmpdir, load_optimizer_states=load_optim)
 
         if load_optim:
-            saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file))
+            saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file), weights_only=False)
             curr_sd = model.optimizer.optimizer.state_dict()
             compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys)
 
@@ -523,7 +523,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
         all_ckpt_folder = os.path.join(tmpdir, 'all_params')
         ds_engine.save_checkpoint(all_ckpt_folder)
         all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
-        loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
+        loaded_all_param_model = torch.load(all_params_ckpt_file, weights_only=False)['module']
         all_param_names = set([n for n, p in model.named_parameters()])
         assert set(loaded_all_param_model.keys()) == all_param_names
 
@@ -536,7 +536,7 @@ def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
 
         # Excluding frozen parameters should reduce checkpoint size
         assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)
-        loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
+        loaded_trainable_param_model = torch.load(trainable_ckpt_file, weights_only=False)['module']
         frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
         loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
         overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
@@ -575,7 +575,7 @@ def test_save_exclude_custom_frozen_weights(self, tmpdir, zero_stage):
         custom_state_dict_ckpt_file = get_model_ckpt_name_for_rank(
             os.path.join(custom_state_dict_ckpt_folder, 'global_step0'), '00')
 
-        loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file)['module']
+        loaded_custom_state_dict_param_model = torch.load(custom_state_dict_ckpt_file, weights_only=False)['module']
         loaded_custom_state_dict_param_names = set(loaded_custom_state_dict_param_model.keys())
 
         custom_state_dict_param_names = set([k for k, v in model.state_dict().items()])
@@ -618,7 +618,8 @@ def test_save_tensor_clone(self, tmpdir, zero_stage, use_cpu_device):
 
         clone_ckpt_file = os.path.join(tmpdir, 'clone_ckpt.pt')
         torch.save(clone_state_dict, clone_ckpt_file)
-        compare_state_dicts(torch.load(ref_ckpt_file), torch.load(clone_ckpt_file))
+        compare_state_dicts(torch.load(ref_ckpt_file, weights_only=False),
+                            torch.load(clone_ckpt_file, weights_only=False))
 
 
 class TestZeRONonDistributed(DistributedTest):
diff --git a/tests/unit/model_parallelism/test_configurable_parallel_mp.py b/tests/unit/model_parallelism/test_configurable_parallel_mp.py
index cca1ef3584ad..a7b0d3431ee9 100644
--- a/tests/unit/model_parallelism/test_configurable_parallel_mp.py
+++ b/tests/unit/model_parallelism/test_configurable_parallel_mp.py
@@ -170,7 +170,7 @@ def test(self, baseline_mp2, inputs, class_tmpdir):
         test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
         if dist.get_rank() == 0:
             load_path = os.path.join(class_tmpdir, "output.pt")
-            baseline = torch.load(load_path)
+            baseline = torch.load(load_path, weights_only=False)
             test = test.cpu()
             assert torch.allclose(
                 baseline, test,
diff --git a/tests/unit/model_parallelism/test_configurable_parallel_pp.py b/tests/unit/model_parallelism/test_configurable_parallel_pp.py
index e50fd18577b1..df469044e186 100644
--- a/tests/unit/model_parallelism/test_configurable_parallel_pp.py
+++ b/tests/unit/model_parallelism/test_configurable_parallel_pp.py
@@ -225,7 +225,7 @@ def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resiz
             assert torch.is_tensor(test[0][0])
             test = test[0][0].cpu()
             load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
-            baseline = torch.load(load_path)
+            baseline = torch.load(load_path, weights_only=False)
             assert torch.allclose(
                 baseline, test,
                 atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"
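
Context note (not part of the patch above): every hunk applies the same change, passing weights_only=False explicitly to torch.load. Starting with PyTorch 2.6 the default of weights_only flipped to True, and the restricted unpickler then rejects checkpoints that contain arbitrary pickled Python objects, which DeepSpeed checkpoints do (optimizer wrappers, argparse namespaces, metadata). Passing weights_only=False keeps the pre-2.6 behavior and should only be used on checkpoints from a trusted source. The sketch below reproduces the failure mode and the opt-out on a toy checkpoint; the file name demo_ckpt.pt is made up for illustration.

# Minimal sketch (assumes PyTorch >= 2.6, where torch.load defaults to weights_only=True).
import argparse
import torch

# A DeepSpeed-style checkpoint is more than tensors: it also pickles plain Python objects.
ckpt = {"module": {"w": torch.zeros(2)}, "args": argparse.Namespace(lr=0.1)}
torch.save(ckpt, "demo_ckpt.pt")

try:
    torch.load("demo_ckpt.pt")  # default weights_only=True rejects argparse.Namespace
except Exception as err:
    print(f"weights_only=True failed: {type(err).__name__}")

# Explicit opt-out, as in the patch; only do this for checkpoints you trust.
sd = torch.load("demo_ckpt.pt", map_location="cpu", weights_only=False)
print(sd["args"].lr)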