Fix for weights-only load (#1228)
stack-info: PR: #1228, branch: drisspg/stack/19
drisspg authored Nov 6, 2024
1 parent 5632535 commit 71a442a
Showing 9 changed files with 218 additions and 67 deletions.
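Context for the commit title: the hunks rendered below are mostly mechanical reformatting plus the ruff include additions, while the weights-only-load change itself appears to sit in the optimizer-state subclass modules whose diffs did not render on this page. As a hedged sketch only (the explicit allowlisting call and the checkpoint filename are illustrative, not taken from this commit), loading such optimizer state with torch.load(weights_only=True) on recent PyTorch looks roughly like:

import torch
# module path taken from the imports in adam.py further down this diff
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

# Assumption: weights-only loading requires the tensor subclass to be allowlisted,
# either by the library itself (presumably what this commit does) or explicitly:
torch.serialization.add_safe_globals([OptimState8bit])
state_dict = torch.load("optim_state.pt", weights_only=True)  # hypothetical checkpoint file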
3 changes: 3 additions & 0 deletions ruff.toml
@@ -13,4 +13,7 @@ include = [
"test/dtypes/test_affine_quantized_float.py",
"torchao/quantization/weight_tensor_linear_activation_quantization.py",
"torchao/dtypes/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"test/prototype/low_bit_optim/**.py",

]
100 changes: 76 additions & 24 deletions test/prototype/test_low_bit_optim.py
@@ -19,7 +19,11 @@
quantize_4bit_with_qmap,
_fp32_to_bf16_sr,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_6,
)

try:
import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
x_rep = x.view(-1, 1).repeat(1, 100_000)

if compile:
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
x_rep
)
else:
x_rep_bf16 = _fp32_to_bf16_sr(x_rep)

@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):


class TestOptim(TestCase):
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize(
"optim_name",
["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
)
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
def test_optim_smoke(self, optim_name, dtype, device):
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1)

# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
optim2 = getattr(low_bit_optim, optim_name)(
model2.parameters(), block_size=block_size
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):

# this will not run in CI because we can't install lpmm
@pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
)
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1)

# lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
)
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1[0].requires_grad_(False) # make sure it can work in the presence of non-trainable params
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model1[0].requires_grad_(
False
) # make sure it can work in the presence of non-trainable params
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
model2.parameters(),
torch.optim.AdamW,
offload_gradients=offload_grad,
)

for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
optim1 = low_bit_optim.CPUOffloadOptimizer(
model1.parameters(), torch.optim.AdamW
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -253,7 +293,9 @@ def test_optim_cpu_offload_save_load(self):

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW
)
optim2.load_state_dict(state_dict)

for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
def test_optim_bf16_stochastic_round_correctness(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(2024)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1).bfloat16()

# small LR so that weight update is small
# when bf16_stochastic_round=False, the test will fail after 1 iteration
optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
optim2 = low_bit_optim._AdamW(
model2.parameters(), lr=1e-5, bf16_stochastic_round=True
)

# overfit on this sample
x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
optim2.step()
optim2.zero_grad()

torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
torch.testing.assert_close(
loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
)


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return 2

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
)
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
torch.distributed.all_reduce(
param.grad, op=torch.distributed.ReduceOp.AVG
)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

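The bf16 stochastic-rounding correctness test above corresponds to the following training pattern (a sketch with illustrative sizes; the small learning rate mirrors the test's comment that updates must stay small for the comparison to hold):

import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device).bfloat16()
# bf16 weights without fp32 master weights: updates are stochastically rounded into bf16
optim = low_bit_optim._AdamW(model.parameters(), lr=1e-5, bf16_stochastic_round=True)

x = torch.randn(4, 32, device=device, dtype=torch.bfloat16)
loss = model(x).abs().mean()
loss.backward()
optim.step()
optim.zero_grad()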
11 changes: 11 additions & 0 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,2 +1,13 @@
from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
from .cpu_offload import CPUOffloadOptimizer

__all__ = [
"Adam4bit",
"Adam8bit",
"AdamFp8",
"AdamW4bit",
"AdamW8bit",
"AdamWFp8",
"_AdamW",
"CPUOffloadOptimizer",
]
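With the exports above in place, the low-bit optimizers are constructed like their torch.optim counterparts. A minimal construction sketch (block_size=256 mirrors the bnb>=0.44 comparison test in this diff; the module and learning rate are illustrative):

import torch.nn as nn
from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Linear(1024, 1024)  # any nn.Module; large, block-aligned params get quantized optimizer state
optim = AdamW8bit(model.parameters(), lr=1e-3, block_size=256)  # intended as a drop-in for torch.optim.AdamW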
30 changes: 24 additions & 6 deletions torchao/prototype/low_bit_optim/adam.py
@@ -2,18 +2,28 @@

import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer

from .subclass_8bit import OptimState8bit
from .quant_utils import _fp32_to_bf16_sr
from .subclass_4bit import OptimState4bit
from .subclass_8bit import OptimState8bit
from .subclass_fp8 import OptimStateFp8
from .quant_utils import _fp32_to_bf16_sr


class _AdamBase(Optimizer):
def __init__(
self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw
self,
params,
lr,
betas,
eps,
weight_decay,
amsgrad,
*,
block_size,
bf16_stochastic_round,
is_adamw,
) -> None:
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@@ -23,7 +33,13 @@ def __init__(
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
defaults = dict(
lr=torch.tensor(lr),
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
)
super().__init__(params, defaults)
self.block_size = block_size
self.bf16_stochastic_round = bf16_stochastic_round
@@ -45,7 +61,9 @@ def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = DTensor.from_local(
local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
local_tensor=self._subclass_zeros(
p.to_local(), signed, self.block_size
),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
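For reference, the _new_buffer hunk above keeps the existing quantization gate: optimizer state is only stored in the low-bit subclasses when the parameter is large enough and block-aligned, and for DTensor parameters the quantized local shard is wrapped back into a DTensor with the same mesh and placements. An illustrative restatement of that gate (not code from this diff; block_size=256 mirrors the 8-bit tests above):

def uses_quantized_state(numel: int, block_size: int = 256) -> bool:
    # mirrors the condition in _AdamBase._new_buffer shown above
    return numel >= 4096 and numel % block_size == 0

print(uses_quantized_state(32 * 1024))  # True  -> OptimState8bit/4bit/Fp8 buffer
print(uses_quantized_state(128))        # False -> plain zeros buffer instead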
10 changes: 8 additions & 2 deletions torchao/prototype/low_bit_optim/cpu_offload.py
@@ -25,7 +25,11 @@ def __init__(
kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
"""
# default to fused CPU AdamW
if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
if (
optimizer_class is torch.optim.AdamW
and TORCH_VERSION_AT_LEAST_2_4
and "fused" not in kwargs
):
kwargs.update(fused=True)

param_groups = list(params)
@@ -77,7 +81,9 @@ def backward_hook(p_cuda):
self.param_cuda2cpu_map[p_cuda] = p_cpu

p_cuda.register_post_accumulate_grad_hook(backward_hook)
self.optim_dict[p_cuda] = optimizer_class([{"params": p_cpu, **param_group}], **kwargs)
self.optim_dict[p_cuda] = optimizer_class(
[{"params": p_cpu, **param_group}], **kwargs
)

@torch.no_grad()
def step(self, closure=None):
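The cpu_offload.py hunks above preserve two behaviors: defaulting to fused CPU AdamW on PyTorch >= 2.4, and attaching one base optimizer per CUDA parameter behind a post-accumulate-grad hook. A usage sketch mirroring test_optim_cpu_offload_correctness (layer sizes illustrative; requires a CUDA device):

import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
optim = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,       # base optimizer class; fused=True is added automatically on torch>=2.4
    offload_gradients=True,  # keep gradients on CPU as well, as in the (True, 1) test case above
)

x = torch.randn(4, 32, device="cuda")
model(x).sum().backward()
optim.step()
optim.zero_grad()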
13 changes: 8 additions & 5 deletions torchao/prototype/low_bit_optim/quant_utils.py
@@ -122,14 +122,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
# [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
#
# we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
rand_16bit = torch.randint(0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32)
rand_16bit = torch.randint(
0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
)
x_f32_bits = x_f32.view(torch.int32)
x_fraction = x_f32_bits & 0xFFFF # lower 16 bits
x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits
x_fraction = x_f32_bits & 0xFFFF # lower 16 bits
x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits

x_f32_bits = torch.where(
rand_16bit < x_fraction, # this is True with the probability of p_fraction
x_bf16_towards_zero + 0x10000, # this might overflow, which will result in UB due to signed integer
rand_16bit < x_fraction, # this is True with the probability of p_fraction
x_bf16_towards_zero
+ 0x10000, # this might overflow, which will result in UB due to signed integer
x_bf16_towards_zero,
)
# alternative, slightly faster
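The reformatted hunk above is the stochastic-rounding core: the lower 16 bits of each fp32 value are compared against a uniform 16-bit random integer, so each value is rounded to one of its two bf16 neighbours with probabilities that make the rounding unbiased in expectation. A quick self-check in the spirit of test_bf16_stochastic_round (the tolerance here is illustrative, not the one used in the test):

import torch
from torchao.prototype.low_bit_optim.quant_utils import _fp32_to_bf16_sr

x = torch.rand(8)                         # fp32 inputs in [0, 1)
x_rep = x.view(-1, 1).repeat(1, 100_000)  # many copies of each value
x_bf16 = _fp32_to_bf16_sr(x_rep)
# averaging many independently rounded copies should recover the fp32 inputs
torch.testing.assert_close(x_bf16.float().mean(1), x, atol=1e-3, rtol=1e-3)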
[The remaining file diffs in this commit did not render on this page.]
