Dtype fix for backprop matmul (intel#104)
Co-authored-by: SarahByrneIntel <[email protected]>
SarahByrneIntel and SarahByrneIntel authored Jul 23, 2024
1 parent 0562410 commit 86e4e1e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion intel_npu_acceleration_library/nn/autograd.py
```diff
--- a/intel_npu_acceleration_library/nn/autograd.py
+++ b/intel_npu_acceleration_library/nn/autograd.py
@@ -63,6 +63,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Iterable[Union[torch.Tensor, None]]:

         dl_dx = run_matmul(grad_output, torch.transpose(w, -1, -2))
         dl_dw = run_matmul(
-            torch.transpose(grad_output, -1, -2), torch.transpose(x, -1, -2)
+            torch.transpose(grad_output, -1, -2),
+            torch.transpose(x, -1, -2).to(torch.float16),
         )
         return dl_dx, dl_dw, None
```
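The fix casts the second operand of the weight-gradient matmul to `torch.float16` so both inputs to the NPU matmul share a dtype. The gradients being computed are the standard matmul backward formulas; the library's own code applies extra transposes that reflect its weight layout. As a minimal sketch of the underlying math (shapes, the sum loss, and the plain `x @ w` layout are illustrative assumptions, with `run_matmul` replaced by NumPy's `@`), the formulas can be verified against a finite-difference check:

```python
import numpy as np

# Hypothetical shapes for a plain y = x @ w matmul; this is not the
# library's exact layout, which stores/transposes operands differently.
rng = np.random.default_rng(0)
n, d_in, d_out = 4, 3, 2
x = rng.standard_normal((n, d_in))
w = rng.standard_normal((d_in, d_out))

# Pretend the loss is sum(y), so grad_output (dL/dy) is all ones.
grad_output = np.ones((n, d_out))

# Standard matmul backward formulas:
dl_dx = grad_output @ w.T   # shape (n, d_in)
dl_dw = x.T @ grad_output   # shape (d_in, d_out)

# Finite-difference check on a single weight entry.
eps = 1e-6
w_pert = w.copy()
w_pert[0, 0] += eps
numeric = ((x @ w_pert).sum() - (x @ w).sum()) / eps
assert abs(numeric - dl_dw[0, 0]) < 1e-4
```

On a device path that only accepts half precision, the analogous cast would be `x.T.astype(np.float16)` before the product, which is what the `.to(torch.float16)` in the diff does for the NPU matmul.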
