Skip to content

Commit

Permalink
fix benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent authored and mryab committed Jul 13, 2024
1 parent f0358f7 commit 4ebc04f
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion benchmarks/benchmark_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
from hivemind.optim.optimizer import Optimizer
from hivemind.utils.crypto import RSAPrivateKey

from packaging import version

torch_version = torch.__version__.split("+")[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
from torch.amp import GradScaler
else:
from torch.cuda.amp import GradScaler


@dataclass(frozen=True)
class TrainingArguments:
Expand Down Expand Up @@ -98,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
grad_scaler = hivemind.GradScaler()
else:
# check that hivemind.Optimizer supports regular PyTorch grad scaler as well
grad_scaler = torch.amp.GradScaler(enabled=args.use_amp)
grad_scaler = GradScaler(enabled=args.use_amp)

prev_time = time.perf_counter()

Expand Down

0 comments on commit 4ebc04f

Please sign in to comment.