diff --git a/examples/scripts/ddpo.py b/examples/scripts/ddpo.py index 9a42c28a0f..eb06bf548f 100644 --- a/examples/scripts/ddpo.py +++ b/examples/scripts/ddpo.py @@ -99,7 +99,7 @@ def __init__(self, *, dtype, model_id, model_filename): cached_path = hf_hub_download(model_id, model_filename) except EntryNotFoundError: cached_path = os.path.join(model_id, model_filename) - state_dict = torch.load(cached_path) + state_dict = torch.load(cached_path, map_location=torch.device("cpu")) self.mlp.load_state_dict(state_dict) self.dtype = dtype self.eval()