From 8ef00e033985819d29ca8470beacd85b1307afec Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 18 Jul 2024 19:44:29 +0000 Subject: [PATCH] fix weight_map --- optimum/fx/parallelization/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 01bb9259e49..772dfdccd1f 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -118,7 +118,7 @@ def parallelize_model( if os.path.isfile(index_path): with open(index_path) as f: index_dict = json.load(f) - parallel_ctx.weight_map = index_dict["weight_map"] + parallel_ctx.weight_map = {k : os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()} weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) if not use_safetensors: weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {}