Lint fixes test folders #1
jainapurva committed Dec 10, 2024
1 parent a6f8676 commit ae6b157
Showing 9 changed files with 113 additions and 82 deletions.
7 changes: 5 additions & 2 deletions ruff.toml
@@ -7,15 +7,18 @@ include = [
"torchao/quantization/**/*.py",
"torchao/dtypes/**/*.py",
"torchao/sparsity/**/*.py",
"torchao/profiler/**/*.py",
"torchao/testing/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"torchao/utils.py",
"torchao/ops.py",
"torchao/_executorch_ops.py",
# Test folders
"test/dora/**/*.py",
"test/dtypes/**/*.py",
"test/float8/**/*.py",
"test/galore/**/*.py",
"test/hqq/**/*.py",
"test/quantization/**/*.py",
"test/dtypes/**/*.py",
"test/sparsity/**/*.py",
"test/prototype/low_bit_optim/**.py",
]
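
For reference, a minimal sketch (not part of this commit) of listing the files the newly added test globs cover, run from the repository root. Python's glob module is used purely for illustration, and its recursive ** matching is assumed to approximate ruff's own include handling:

import glob

# A few of the "# Test folders" patterns added above; the selection is arbitrary.
patterns = ["test/dora/**/*.py", "test/hqq/**/*.py", "test/quantization/**/*.py"]
for pattern in patterns:
    for path in sorted(glob.glob(pattern, recursive=True)):
        print(path)
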
20 changes: 6 additions & 14 deletions test/dora/test_dora_fusion.py
@@ -1,19 +1,17 @@
import itertools
import sys

import pytest
import torch

from torchao.prototype.dora.kernels.matmul import triton_mm
from torchao.prototype.dora.kernels.smallk import triton_mm_small_k

if sys.version_info < (3, 11):
    pytest.skip("requires Python >= 3.11", allow_module_level=True)

triton = pytest.importorskip("triton", reason="requires triton")

import itertools

import torch

from torchao.prototype.dora.kernels.matmul import triton_mm
from torchao.prototype.dora.kernels.smallk import triton_mm_small_k

torch.manual_seed(0)

# Test configs
@@ -48,13 +46,7 @@ def _arg_to_id(arg):


def check(expected, actual, dtype):
    if dtype == torch.float32:
        atol = 1e-4
    elif dtype == torch.float16:
        atol = 1e-3
    elif dtype == torch.bfloat16:
        atol = 1e-2
    else:
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise ValueError(f"Unsupported dtype: {dtype}")
    diff = (expected - actual).abs().max()
    print(f"diff: {diff}")
16 changes: 5 additions & 11 deletions test/dora/test_dora_layer.py
@@ -1,32 +1,26 @@
import itertools
import sys

import pytest
import torch

from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear

if sys.version_info < (3, 11):
    pytest.skip("requires Python >= 3.11", allow_module_level=True)

bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes")
hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq")

import itertools

import torch

# Import modules as opposed to classes directly, otherwise pytest.importorskip always skips
Linear4bit = bnbnn.Linear4bit
BaseQuantizeConfig = hqq_core.BaseQuantizeConfig
HQQLinear = hqq_core.HQQLinear
from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear


def check(expected, actual, dtype):
    if dtype == torch.float32:
        atol = 1e-4
    elif dtype == torch.float16:
        atol = 1e-3
    elif dtype == torch.bfloat16:
        atol = 1e-2
    else:
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise ValueError(f"Unsupported dtype: {dtype}")
    diff = (expected - actual).abs().max()
    print(f"diff: {diff}")
2 changes: 1 addition & 1 deletion test/galore/memory_analysis_utils.py
@@ -50,7 +50,7 @@ def _convert_to_units(df, col):
    convert_cols_to_MB = {col: partial(_convert_to_units, col=col) for col in COL_NAMES}

    df = pd.DataFrame(
        [l[1:] for l in df.iloc[:, 1].to_list()], columns=COL_NAMES
        [row[1:] for row in df.iloc[:, 1].to_list()], columns=COL_NAMES
    ).assign(**convert_cols_to_MB)
    df["Total"] = df.sum(axis=1)
    return df
4 changes: 2 additions & 2 deletions test/galore/profile_memory_usage.py
@@ -86,7 +86,7 @@ def run(args, file_prefix):
    model_config = LlamaConfig()
    try:
        model_config_dict = getattr(model_configs, args.model_config.upper())
    except:
    except Exception:
        raise ValueError(f"Model config {args.model_config} not found")
    model_config.update(model_config_dict)
    model = LlamaForCausalLM(model_config).to("cuda")
@@ -163,7 +163,7 @@ def run(args, file_prefix):
    if args.torch_profiler:
        print(f"Finished profiling, outputs saved to {args.output_dir}/{file_prefix}*")
    else:
        print(f"Finished profiling")
        print("Finished profiling")


if __name__ == "__main__":
2 changes: 0 additions & 2 deletions test/galore/profiling_utils.py
@@ -75,7 +75,6 @@ def get_cuda_memory_usage(units="MB", show=True):


def export_memory_snapshot(prefix) -> None:

    # Prefix for file names.
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{prefix}_{timestamp}"
@@ -115,7 +114,6 @@ def trace_handler(
    export_memory_timeline=True,
    print_table=True,
):

    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}")

113 changes: 79 additions & 34 deletions test/hqq/test_hqq_affine.py
@@ -1,91 +1,136 @@
import unittest

import torch

from torchao.quantization import (
    ZeroPointDomain,
    MappingType,
    ZeroPointDomain,
    int4_weight_only,
    uintx_weight_only,
)

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_3,
)
from torchao.quantization import (
    uintx_weight_only,
    int4_weight_only,
)

cuda_available = torch.cuda.is_available()

#Parameters
device = 'cuda:0'
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) #axis=1
preserve_zero = False
# Parameters
device = "cuda:0"
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) # axis=1
preserve_zero = False
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_dtype = compute_dtype
inner_k_tiles = 8
in_features = 4096
out_features = 11800
torch_seed = 100
zero_point_dtype = compute_dtype
inner_k_tiles = 8
in_features = 4096
out_features = 11800
torch_seed = 100


def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
    torch.random.manual_seed(torch_seed)
    linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
    x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
    x = (
        torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)
        / 20.0
    )
    y_ref = linear_layer(x)
    W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
    return W, x, y_ref


def _eval_hqq(dtype):
    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
    W, x, y_ref = _init_data(
        in_features, out_features, compute_dtype, device, torch_seed
    )

    dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
    dummy_linear = torch.nn.Linear(
        in_features=in_features, out_features=out_features, bias=False
    )
    dummy_linear.weight.data = W
    if dtype == torch.uint4:
        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
            dummy_linear
        ).weight
    else:
        q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight
        q_tensor_hqq = uintx_weight_only(
            dtype, group_size=max(block_size), use_hqq=True
        )(dummy_linear).weight

    quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
    quant_linear_layer = torch.nn.Linear(
        W.shape[1], W.shape[0], bias=False, device=W.device
    )
    del quant_linear_layer.weight
    quant_linear_layer.weight = q_tensor_hqq
    dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
    dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
    dot_product_error = (
        (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
    )

    return dequantize_error, dot_product_error


@unittest.skipIf(not cuda_available, "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
class TestHQQ(unittest.TestCase):
    def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
        if(dtype is None): return
    def _test_hqq(
        self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None
    ):
        if dtype is None:
            return
        dequantize_error, dot_product_error = _eval_hqq(dtype)
        self.assertTrue(dequantize_error < ref_dequantize_error)
        self.assertTrue(dot_product_error < ref_dot_product_error)

    def test_hqq_plain_8bit(self):
        self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
        self._test_hqq(
            dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013
        )

    def test_hqq_plain_7bit(self):
        self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
        self._test_hqq(
            dtype=torch.uint7,
            ref_dequantize_error=6e-05,
            ref_dot_product_error=0.000193,
        )

    def test_hqq_plain_6bit(self):
        self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
        self._test_hqq(
            dtype=torch.uint6,
            ref_dequantize_error=0.0001131,
            ref_dot_product_error=0.000353,
        )

    def test_hqq_plain_5bit(self):
        self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
        self._test_hqq(
            dtype=torch.uint5,
            ref_dequantize_error=0.00023,
            ref_dot_product_error=0.000704,
        )

    def test_hqq_plain_4bit(self):
        self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
        self._test_hqq(
            dtype=torch.uint4,
            ref_dequantize_error=0.000487,
            ref_dot_product_error=0.001472,
        )

    def test_hqq_plain_3bit(self):
        self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
        self._test_hqq(
            dtype=torch.uint3,
            ref_dequantize_error=0.00101,
            ref_dot_product_error=0.003047,
        )

    def test_hqq_plain_2bit(self):
        self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
        self._test_hqq(
            dtype=torch.uint2,
            ref_dequantize_error=0.002366,
            ref_dot_product_error=0.007255,
        )


if __name__ == "__main__":
    unittest.main()
17 changes: 9 additions & 8 deletions test/hqq/test_triton_mm.py
@@ -1,20 +1,21 @@
# Skip entire test if following module not available, otherwise CI failure
import itertools

import pytest
import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

triton = pytest.importorskip(
"triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
)
hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
hqq_quantize = pytest.importorskip("hqq.core.quantize", reason="hqq required to run this test")
hqq_quantize = pytest.importorskip(
"hqq.core.quantize", reason="hqq required to run this test"
)
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools

import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

# Test configs
SHAPES = [
[16, 128, 128],
@@ -96,7 +97,7 @@ def test_mixed_mm(
    W_q = W_q.to(dtype=quant_dtype)
    W_q = (
        W_q.reshape(meta["shape"])
        if quant_config["weight_quant_params"]["bitpack"] == False
        if not quant_config["weight_quant_params"]["bitpack"]
        else W_q
    )
    W_dq = hqq_linear.dequantize()
14 changes: 6 additions & 8 deletions test/hqq/test_triton_qkv_fused.py
@@ -1,4 +1,9 @@
import itertools

import pytest
import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

triton = pytest.importorskip(
"triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
@@ -10,13 +15,6 @@
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, Quantizer

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

torch.manual_seed(0)
# N, K = shape
Q_SHAPES = [[4096, 4096]]
@@ -60,7 +58,7 @@ def quantize_helper(
    W_q = W_q.to(dtype=quant_dtype)
    W_q = (
        W_q.reshape(meta["shape"])
        if quant_config["weight_quant_params"]["bitpack"] == False
        if not quant_config["weight_quant_params"]["bitpack"]
        else W_q
    )

