From 4ebc04f135e0a5f66e23ba27e1c4d6fc5fab767e Mon Sep 17 00:00:00 2001
From: Luciferian Ink
Date: Sat, 13 Jul 2024 06:09:12 -0500
Subject: [PATCH] fix benchmarking

---
 benchmarks/benchmark_optimizer.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py
index b976f3347..dc81b4599 100644
--- a/benchmarks/benchmark_optimizer.py
+++ b/benchmarks/benchmark_optimizer.py
@@ -16,6 +16,15 @@
 from hivemind.optim.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
+from packaging import version
+
+torch_version = torch.__version__.split("+")[0]
+
+if version.parse(torch_version) >= version.parse("2.3.0"):
+    from torch.amp import GradScaler
+else:
+    from torch.cuda.amp import GradScaler
+
 
 
 @dataclass(frozen=True)
 class TrainingArguments:
@@ -98,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
         grad_scaler = hivemind.GradScaler()
     else:
         # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
-        grad_scaler = torch.amp.GradScaler(enabled=args.use_amp)
+        grad_scaler = GradScaler(enabled=args.use_amp)
 
     prev_time = time.perf_counter()
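
For reference, the compatibility shim from the patch can be reproduced as a
standalone, runnable sketch. It assumes only that torch and packaging are
installed; the trailing GradScaler(enabled=False) call is an illustrative
assumption, not part of the patch, chosen so the example also runs on
CPU-only machines:

    from packaging import version

    import torch

    # Drop any local build suffix, e.g. "2.3.1+cu121" -> "2.3.1",
    # so the string parses as a plain PEP 440 version.
    torch_version = torch.__version__.split("+")[0]

    # torch.amp.GradScaler (the device-agnostic scaler) is only available
    # from PyTorch 2.3.0 onward; older releases ship the CUDA-specific
    # torch.cuda.amp.GradScaler instead.
    if version.parse(torch_version) >= version.parse("2.3.0"):
        from torch.amp import GradScaler
    else:
        from torch.cuda.amp import GradScaler

    # Both classes accept `enabled`; with enabled=False the scaler is a
    # no-op, so constructing it does not require a CUDA device.
    grad_scaler = GradScaler(enabled=False)

Importing once at module level, as the patch does, keeps the rest of the
benchmark agnostic to which PyTorch release is installed: every later use
refers to the single GradScaler name rather than a version-specific path.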