From 7bc6ea4e060478e0b7deac641a248f35370dabc1 Mon Sep 17 00:00:00 2001 From: Mustafa Date: Wed, 6 Nov 2024 07:12:02 -0600 Subject: [PATCH] copy k values as well for copy method --- torchao/prototype/low_bit_optim/subclass_fp8.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 6aed8b89a7..98a78e6ac6 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -110,12 +110,17 @@ def _(func, types, args, kwargs): assert dst.block_size == src.block_size dst.codes.copy_(src.codes) dst.scale.copy_(src.scale) + dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): - codes, scale = quantize_fp8(src, dst.block_size) + if src.dynamic_range_expansion: + codes, k = apply_dynamic_range_expansion(src, dst.block_size) + + codes, scale = quantize_fp8(codes, dst.block_size) + dst.codes.copy_(codes) dst.scale.copy_(scale) - + dst.k.copy_(k) else: dst.copy_(src.dequantize())