Skip to content

Commit

Permalink
Use autocast depending on the version
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Nov 11, 2024
1 parent f9ae24c commit 58d95fc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmark_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
torch_version = torch.__version__.split("+")[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
from torch.amp import GradScaler
from torch.amp import GradScaler, autocast
else:
from torch.cuda.amp import GradScaler
from torch.cuda.amp import GradScaler, autocast


@dataclass(frozen=True)
Expand Down Expand Up @@ -115,7 +115,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:

batch = torch.randint(0, len(X_train), (batch_size,))

with torch.amp.autocast() if args.use_amp else nullcontext():
with autocast() if args.use_amp else nullcontext():
loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
grad_scaler.scale(loss).backward()

Expand Down

0 comments on commit 58d95fc

Please sign in to comment.