diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 6aed8b89a7..98a78e6ac6 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -110,12 +110,17 @@ def _(func, types, args, kwargs): assert dst.block_size == src.block_size dst.codes.copy_(src.codes) dst.scale.copy_(src.scale) + dst.k.copy_(src.k) elif isinstance(dst, OptimStateFp8): - codes, scale = quantize_fp8(src, dst.block_size) + if src.dynamic_range_expansion: + codes, k = apply_dynamic_range_expansion(src, dst.block_size) + + codes, scale = quantize_fp8(codes, dst.block_size) + dst.codes.copy_(codes) dst.scale.copy_(scale) - + dst.k.copy_(k) else: dst.copy_(src.dequantize())