diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 01bb9259e49..772dfdccd1f 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -118,7 +118,7 @@ def parallelize_model( if os.path.isfile(index_path): with open(index_path) as f: index_dict = json.load(f) - parallel_ctx.weight_map = index_dict["weight_map"] + parallel_ctx.weight_map = {k : os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()} weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) if not use_safetensors: weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {}