From 03c992b9fc2f1e7170d1ebf60f18f29f65097ecd Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:20:42 -0500 Subject: [PATCH] Pin checkpoint state dict flattening patch (#3700) --- composer/utils/checkpoint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index d0a8cd8115..baade3dbea 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -595,7 +595,10 @@ def dist_cp_load( storage_reader: StorageReader, load_planner: Optional[LoadPlanner] = None, ): - if version.parse(torch.__version__) >= version.parse('2.4.0'): + if ( + version.parse(torch.__version__) >= version.parse('2.4.0') and + version.parse(torch.__version__) < version.parse('2.5.0') + ): from torch.distributed.checkpoint.utils import CheckpointException try: dist_cp.load(