pin cpu data, cuda streams ver, non-blocking safe
mobicham committed Mar 18, 2024
1 parent 1e8048e commit 9e7249e
Showing 3 changed files with 134 additions and 55 deletions.
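For readers skimming the diff: the commit moves offloaded meta-data into pinned (page-locked) CPU memory and introduces per-layer CUDA streams so that host-to-device copies and dequantization can overlap. A minimal sketch of that pattern, with illustrative tensor names and shapes (not HQQ's actual API):

```python
import torch

# Sketch of the pinned-memory / stream pattern this commit relies on (names are illustrative).
stream = torch.cuda.Stream()

# Page-locked CPU storage: only pinned sources make host-to-device copies truly asynchronous.
meta_cpu = torch.randn(2, 4096, dtype=torch.float16).contiguous().pin_memory()

with torch.cuda.stream(stream):
    # The copy is queued on `stream` and can overlap with work on the default stream.
    meta_gpu = meta_cpu.to(device="cuda", non_blocking=True)

# Anything consuming meta_gpu outside `stream` must wait for the copy to finish.
torch.cuda.synchronize()
```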
4 changes: 2 additions & 2 deletions hqq/core/peft.py
@@ -115,10 +115,10 @@ def __init__(self, linear_layer: nn.Module, peft_config: dict):
+ ")"
)
self.lora_A.data = peft_config["lora_init"]["lora_A"].to(
device=self.device, dtype=self.train_dtype, non_blocking=True
device=self.device, dtype=self.train_dtype
)
self.lora_B.data = peft_config["lora_init"]["lora_B"].to(
device=self.device, dtype=self.train_dtype, non_blocking=True
device=self.device, dtype=self.train_dtype
)
else:
# Init weights, as in the original LoRA implementation
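The peft.py change drops non_blocking=True from the lora_A / lora_B initialization copies, matching the "non-blocking safe" part of the commit message: with an ordinary pageable CPU source the flag buys nothing, and asynchronous copies that are read before a synchronization point can misbehave, so the plain blocking .to() is the conservative choice. A small illustrative comparison (assumed tensor names, not the library's code):

```python
import torch

pageable = torch.randn(1024, 1024)              # ordinary (pageable) CPU tensor
pinned = torch.randn(1024, 1024).pin_memory()   # page-locked CPU tensor

# From pageable memory, non_blocking=True has no effect: the copy still synchronizes.
a = pageable.to("cuda", non_blocking=True)

# From pinned memory the copy is asynchronous, so it must be synchronized (or kept on
# the same stream) before the result is relied upon elsewhere.
b = pinned.to("cuda", non_blocking=True)
torch.cuda.synchronize()
```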
171 changes: 122 additions & 49 deletions hqq/core/quantize.py
@@ -378,6 +378,10 @@ def __init__(

self.linear_layer = linear_layer

# Create CUDA streams (used by dequantize_aten_with_streams to overlap zero/scale dequantization)
self.stream_zero = torch.cuda.Stream()
self.stream_scale = torch.cuda.Stream()

if initialize:
self.initialize()

@@ -433,6 +437,11 @@ def cuda(self, device):
else:
self.meta["scale"] = self.meta["scale"].to(device)

# # Using zero/scale with streams for dequantization is faster than packing them into "zero_scale"
# for key in ["zero", "zero_q", "scale", "scale_q"]:
# if((key in self.meta) and self.offload_meta):
# self.meta[key] = self.meta[key].contiguous().cpu().pin_memory()

if self.offload_meta:
if "zero_scale" not in self.meta:
if self.meta["quant_scale"] and self.meta["quant_zero"]:
@@ -446,10 +455,12 @@ def cuda(self, device):
).to(self.compute_dtype)
del self.meta["scale"], self.meta["zero"]

self.meta["zero_scale"] = self.meta["zero_scale"].contiguous().cpu()
self.meta["zero_scale"] = (
self.meta["zero_scale"].contiguous().cpu().pin_memory()
)

if self.bias is not None:
self.bias = self.bias.to(self.compute_dtype).cuda(device)
self.bias = self.bias.to(device=device, dtype=self.compute_dtype)

self.W_q = nn.Parameter(self.W_q, requires_grad=False)
self.device = device
@@ -496,11 +507,13 @@ def load_state_dict(self, state_dict, strict=True, assign=False):
self.meta = state_dict["meta"]
self.bias = state_dict["bias"] if ("bias" in state_dict) else None

# Meta-data offloading
self.offload_meta = False
if "zero_scale" in self.meta:
self.offload_meta = (
True if (self.meta["zero_scale"].device.type == "cpu") else False
)
for key in ["zero", "zero_q", "scale", "scale_q", "zero_scale"]:
if key in self.meta:
if self.meta[key].device.type == "cpu":
self.offload_meta = True
self.meta[key] = self.meta[key].contiguous().pin_memory()

# Float view settings
if "unpack_view_dtype" not in self.meta:
@@ -575,30 +588,31 @@ def quantize(
def dequantize(self):
assert self.ready, "model was not quantized"
W_q, meta = self.W_q, self.meta
device = W_q.device
del_keys = set()

del_keys = []

# Zero/Scale packed together
if "zero_scale" in meta:
zero_scale = meta["zero_scale"].to(self.W_q.device)
zero_scale = meta["zero_scale"].to(device=device)

if zero_scale.dtype == uint8:
meta["zero_q"] = zero_scale[0]
del_keys.append("zero_q")
meta["scale_q"] = zero_scale[1]
del_keys.append("scale_q")
meta["zero_q"], meta["scale_q"] = zero_scale[0], zero_scale[1]
del_keys.update({"zero_q", "scale_q"})
else:
meta["zero"] = zero_scale[0]
del_keys.append("zero")
meta["scale"] = zero_scale[1]
del_keys.append("scale")
meta["zero"], meta["scale"] = zero_scale[0], zero_scale[1]
del_keys.update({"zero", "scale"})

if meta["quant_zero"]:
meta["zero"] = Quantizer.dequantize(meta["zero_q"], meta["meta_zero"])
del_keys.append("zero")
meta["zero"] = Quantizer.dequantize(
meta["zero_q"].to(device=device), meta["meta_zero"]
)
del_keys.add("zero")

if meta["quant_scale"]:
meta["scale"] = Quantizer.dequantize(meta["scale_q"], meta["meta_scale"])
del_keys.append("scale")
meta["scale"] = Quantizer.dequantize(
meta["scale_q"].to(device=device), meta["meta_scale"]
)
del_keys.add("scale")

W_est = Quantizer.dequantize(W_q, meta)

@@ -631,14 +645,15 @@ def forward_pytorch(self, x: Tensor) -> Tensor:
def forward_pytorch_compile(self, x: Tensor) -> Tensor:
return self.forward_pytorch(x)

##############################################
# Experimental
#############################################
############################################################################################
# ATen C++ / CUDA Backend
############################################################################################
# Requires building the aten backend
@torch.jit.ignore
def dequantize_Wq_aten(self, W_q: Tensor, meta: dict):
if meta["view_as_float"]:
W_q = W_q.view(meta["unpack_view_dtype"])

return hqq_aten.dequantize(
W_q,
meta["scale"],
@@ -654,45 +669,103 @@ def dequantize_aten(self):
# Dequantize
assert self.ready, "model was not quantized"
W_q, meta = self.W_q, self.meta
device = W_q.device
del_keys = set()

del_keys = []

# Zero/Scale packed together
if "zero_scale" in meta:
zero_scale = meta["zero_scale"].to(self.W_q.device)

zero_scale = meta["zero_scale"].to(device=device, non_blocking=True)
if zero_scale.dtype == uint8:
meta["zero_q"] = zero_scale[0]
del_keys.append("zero_q")
meta["scale_q"] = zero_scale[1]
del_keys.append("scale_q")
meta["zero_q"], meta["scale_q"] = zero_scale[0], zero_scale[1]
del_keys.update({"zero_q", "scale_q"})
else:
meta["zero"] = zero_scale[0]
del_keys.append("zero")
meta["scale"] = zero_scale[1]
del_keys.append("scale")
meta["zero"], meta["scale"] = zero_scale[0], zero_scale[1]
del_keys.update({"zero", "scale"})

# Dequantize zero_q / scale_q with device loading
if meta["quant_zero"]:
if meta["meta_zero"]["group_size"]:
meta["zero"] = self.dequantize_Wq_aten(
meta["zero_q"].to(device=device), meta["meta_zero"]
)
else:
meta["zero"] = Quantizer.dequantize(
meta["zero_q"].to(device=device), meta["meta_zero"]
)
del_keys.add("zero")

if meta["quant_scale"]:
if meta["meta_scale"]["group_size"]:
meta["scale"] = self.dequantize_Wq_aten(
meta["scale_q"], meta["meta_scale"]
meta["scale_q"].to(device=device), meta["meta_scale"]
)
del_keys.append("scale")
else:
meta["scale"] = Quantizer.dequantize(
meta["scale_q"], meta["meta_scale"]
meta["scale_q"].to(device=device), meta["meta_scale"]
)
del_keys.append("scale")
del_keys.add("scale")

if meta["quant_zero"]:
if meta["meta_zero"]["group_size"]:
meta["zero"] = self.dequantize_Wq_aten(
meta["zero_q"], meta["meta_zero"]
)
del_keys.append("zero")
# Reconstruct the weights
W_est = self.dequantize_Wq_aten(W_q, meta)

# Cleanup
for key in del_keys:
del meta[key]

return W_est

# Much faster when zero_q/scale_q are offloaded, but uses more VRAM
def dequantize_aten_with_streams(self):
# Dequantize
assert self.ready, "model was not quantized"
W_q, meta = self.W_q, self.meta
device = W_q.device
del_keys = set()

# Zero/Scale packed together
if "zero_scale" in meta:
zero_scale = meta["zero_scale"].to(device=device, non_blocking=True)
if zero_scale.dtype == uint8:
meta["zero_q"], meta["scale_q"] = zero_scale[0], zero_scale[1]
del_keys.update({"zero_q", "scale_q"})
else:
meta["zero"] = Quantizer.dequantize(meta["zero_q"], meta["meta_zero"])
del_keys.append("zero")
meta["zero"], meta["scale"] = zero_scale[0], zero_scale[1]
del_keys.update({"zero", "scale"})

# Using non_blocking=False for the moment, otherwise it can result in strange behaviors
non_blocking = False
with torch.cuda.stream(self.stream_zero):
if meta["quant_zero"]:
if meta["meta_zero"]["group_size"]:
meta["zero"] = self.dequantize_Wq_aten(
meta["zero_q"].to(device=device, non_blocking=non_blocking),
meta["meta_zero"],
)
else:
meta["zero"] = Quantizer.dequantize(
meta["zero_q"].to(device=device, non_blocking=non_blocking),
meta["meta_zero"],
)
del_keys.add("zero")

with torch.cuda.stream(self.stream_scale):
if meta["quant_scale"]:
if meta["meta_scale"]["group_size"]:
meta["scale"] = self.dequantize_Wq_aten(
meta["scale_q"].to(device=device, non_blocking=non_blocking),
meta["meta_scale"],
)
else:
meta["scale"] = Quantizer.dequantize(
meta["scale_q"].to(device=device, non_blocking=non_blocking),
meta["meta_scale"],
)
del_keys.add("scale")

# Wait for streams to finish
torch.cuda.synchronize()

# Reconstruct the weights
W_est = self.dequantize_Wq_aten(W_q, meta)

# Cleanup
Expand All @@ -717,7 +790,7 @@ def hqq_base_quant_config(
group_size: int = 64,
quant_zero: bool = True,
quant_scale: bool = False,
offload_meta: bool = False,
offload_meta: bool = False, # meta-data should be quantized with the same settings to use offload_meta
view_as_float: bool = False,
):
assert (
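The new dequantize_aten_with_streams path above dequantizes zero and scale on two separate CUDA streams and only synchronizes once before reconstructing the weights. A condensed, self-contained sketch of that control flow; the toy affine dequantization below stands in for Quantizer.dequantize / hqq_aten.dequantize:

```python
import torch

stream_zero, stream_scale = torch.cuda.Stream(), torch.cuda.Stream()

def toy_dequant(q: torch.Tensor, scale: float, zero: float) -> torch.Tensor:
    # Stand-in for the real dequantization kernels.
    return (q.float() - zero) * scale

# Pinned uint8 meta tensors, as prepared in HQQLinear.cuda() / load_state_dict().
zero_q = torch.randint(0, 256, (4096,), dtype=torch.uint8).pin_memory()
scale_q = torch.randint(0, 256, (4096,), dtype=torch.uint8).pin_memory()

with torch.cuda.stream(stream_zero):
    zero = toy_dequant(zero_q.to("cuda", non_blocking=False), scale=0.01, zero=8.0)

with torch.cuda.stream(stream_scale):
    scale = toy_dequant(scale_q.to("cuda", non_blocking=False), scale=0.02, zero=8.0)

# Both side streams must finish before the default stream uses zero/scale to rebuild W.
torch.cuda.synchronize()
```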
14 changes: 10 additions & 4 deletions hqq/models/base.py
@@ -14,7 +14,7 @@
from ..core.quantize import HQQLinear

# Defines what qualifies as a "linear layer"
_LINEAR_LAYERS = [nn.Linear]
_QUANT_LAYERS = [nn.Linear]
_IGNORE_LINEAR = ["lm_head"]


@@ -53,7 +53,7 @@ def name_to_linear_tag(name: str) -> str:
def get_linear_tags_from_model(model, ignore: list) -> list:
linear_tags = set()
for name, module in model.named_modules():
if (type(module) in _LINEAR_LAYERS) and (name.split(".")[-1] not in ignore):
if (type(module) in _QUANT_LAYERS) and (name.split(".")[-1] not in ignore):
linear_tags.add(name_to_linear_tag(name))
return list(linear_tags)

@@ -71,7 +71,7 @@ def patch_nonlinearlayers(

tmp_mapping = {}
for name, module in model.named_modules():
if (type(module) not in _LINEAR_LAYERS) and (name not in ignore_tags):
if (type(module) not in _QUANT_LAYERS) and (name not in ignore_tags):
tmp_mapping[name] = module

for name in tqdm(tmp_mapping, disable=not verbose):
@@ -90,7 +90,7 @@ def patch_linearlayers(

tmp_mapping = {}
for name, module in model.named_modules():
if (type(module) in _LINEAR_LAYERS) and (name not in ignore_tags):
if (type(module) in _QUANT_LAYERS) and (name not in ignore_tags):
tmp_mapping[name] = module

for name in tqdm(tmp_mapping, disable=not verbose):
@@ -241,6 +241,9 @@ def _patch_linear(linear_layer, quant_config):
# Set base class
model.base_class = cls

# Sync
torch.cuda.synchronize()

return model

# Prepares model weights by iterating through modules. It might miss some parameters that are NOT modules, like model.param1
@@ -363,4 +366,7 @@ def _load_module(module, params=None):
# Set base class
model.base_class = cls

# Sync
torch.cuda.synchronize()

return model
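Finally, the inline comment added to hqq_base_quant_config is worth spelling out: offload_meta packs zero and scale into a single stacked, pinned "zero_scale" tensor (see the cuda() changes above), which only works when both are quantized with the same settings. A usage sketch, assuming the public BaseQuantizeConfig / HQQLinear entry points; treat the exact constructor kwargs as illustrative:

```python
import torch
import torch.nn as nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

# Quantize both zero and scale (same settings) so they can be stacked into one
# pinned "zero_scale" CPU tensor when offload_meta=True.
quant_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
    offload_meta=True,
)

layer = nn.Linear(4096, 4096, bias=False)
hqq_layer = HQQLinear(layer, quant_config, compute_dtype=torch.float16)
```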
