Skip to content

Commit

Permalink
copy k values as well for copy method
Browse files Browse the repository at this point in the history
  • Loading branch information
MirMustafaAli committed Nov 8, 2024
1 parent 9d1c00c commit 7bc6ea4
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,17 @@ def _(func, types, args, kwargs):
assert dst.block_size == src.block_size
dst.codes.copy_(src.codes)
dst.scale.copy_(src.scale)
dst.k.copy_(src.k)

elif isinstance(dst, OptimStateFp8):
codes, scale = quantize_fp8(src, dst.block_size)
if src.dynamic_range_expansion:
codes, k = apply_dynamic_range_expansion(src, dst.block_size)

codes, scale = quantize_fp8(codes, dst.block_size)

dst.codes.copy_(codes)
dst.scale.copy_(scale)

dst.k.copy_(k)
else:
dst.copy_(src.dequantize())

Expand Down

0 comments on commit 7bc6ea4

Please sign in to comment.