diff --git a/src/benchmarks/fsdp/test.py b/src/benchmarks/fsdp/test.py index ef2e87dd..80a781fb 100644 --- a/src/benchmarks/fsdp/test.py +++ b/src/benchmarks/fsdp/test.py @@ -60,7 +60,7 @@ def main( print_rank0(f"Loading OLMo-core FSDP checkpoint from {checkpoint_dir}...") load_model_and_optim_state(checkpoint_dir, olmo_model, olmo_optim) - print_rank0(f"Checking state dict...") + print_rank0("Checking state dict...") with TorchFSDP.summon_full_params(torch_model), olmo_model.summon_full_params(): torch_state_dict = {k.replace("_fsdp_wrapped_module.", ""): v for k, v in torch_model.state_dict().items()} olmo_state_dict = olmo_model.state_dict()