From e0dec2727245136b593d3abb11a4b529491cf7f2 Mon Sep 17 00:00:00 2001 From: zhangsibo1129 <134488188+zhangsibo1129@users.noreply.github.com> Date: Sat, 23 Dec 2023 17:13:38 +0800 Subject: [PATCH] reformatted (#1129) --- examples/scripts/ddpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/ddpo.py b/examples/scripts/ddpo.py index 9a42c28a0f..eb06bf548f 100644 --- a/examples/scripts/ddpo.py +++ b/examples/scripts/ddpo.py @@ -99,7 +99,7 @@ def __init__(self, *, dtype, model_id, model_filename): cached_path = hf_hub_download(model_id, model_filename) except EntryNotFoundError: cached_path = os.path.join(model_id, model_filename) - state_dict = torch.load(cached_path) + state_dict = torch.load(cached_path, map_location=torch.device("cpu")) self.mlp.load_state_dict(state_dict) self.dtype = dtype self.eval()