diff --git a/ruff.toml b/ruff.toml
index b20cab030c..c96d55f9bc 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -7,15 +7,18 @@ include = [
     "torchao/quantization/**/*.py",
     "torchao/dtypes/**/*.py",
     "torchao/sparsity/**/*.py",
-    "torchao/profiler/**/*.py",
     "torchao/testing/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
     "torchao/ops.py",
     "torchao/_executorch_ops.py",
+    # Test folders
+    "test/dora/**/*.py",
+    "test/dtypes/**/*.py",
     "test/float8/**/*.py",
+    "test/galore/**/*.py",
+    "test/hqq/**/*.py",
     "test/quantization/**/*.py",
-    "test/dtypes/**/*.py",
     "test/sparsity/**/*.py",
     "test/prototype/low_bit_optim/**.py",
 ]
diff --git a/test/dora/test_dora_fusion.py b/test/dora/test_dora_fusion.py
index a7959f85af..562a20f665 100644
--- a/test/dora/test_dora_fusion.py
+++ b/test/dora/test_dora_fusion.py
@@ -1,19 +1,17 @@
+import itertools
 import sys
 
 import pytest
+import torch
+
+from torchao.prototype.dora.kernels.matmul import triton_mm
+from torchao.prototype.dora.kernels.smallk import triton_mm_small_k
 
 if sys.version_info < (3, 11):
     pytest.skip("requires Python >= 3.11", allow_module_level=True)
 
 triton = pytest.importorskip("triton", reason="requires triton")
 
-import itertools
-
-import torch
-
-from torchao.prototype.dora.kernels.matmul import triton_mm
-from torchao.prototype.dora.kernels.smallk import triton_mm_small_k
-
 torch.manual_seed(0)
 
 # Test configs
@@ -48,13 +46,7 @@ def _arg_to_id(arg):
 
 
 def check(expected, actual, dtype):
-    if dtype == torch.float32:
-        atol = 1e-4
-    elif dtype == torch.float16:
-        atol = 1e-3
-    elif dtype == torch.bfloat16:
-        atol = 1e-2
-    else:
+    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
         raise ValueError(f"Unsupported dtype: {dtype}")
     diff = (expected - actual).abs().max()
     print(f"diff: {diff}")
diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py
index dd38cc8d6b..dfe89f30b3 100644
--- a/test/dora/test_dora_layer.py
+++ b/test/dora/test_dora_layer.py
@@ -1,6 +1,10 @@
+import itertools
 import sys
 
 import pytest
+import torch
+
+from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear
 
 if sys.version_info < (3, 11):
     pytest.skip("requires Python >= 3.11", allow_module_level=True)
@@ -8,25 +12,15 @@
 bnbnn = pytest.importorskip("bitsandbytes.nn", reason="requires bitsandbytes")
 hqq_core = pytest.importorskip("hqq.core.quantize", reason="requires hqq")
 
-import itertools
-
-import torch
 
 # Import modules as opposed to classes directly, otherwise pytest.importorskip always skips
 Linear4bit = bnbnn.Linear4bit
 BaseQuantizeConfig = hqq_core.BaseQuantizeConfig
 HQQLinear = hqq_core.HQQLinear
-from torchao.prototype.dora.dora_layer import BNBDoRALinear, DoRALinear, HQQDoRALinear
 
 
 def check(expected, actual, dtype):
-    if dtype == torch.float32:
-        atol = 1e-4
-    elif dtype == torch.float16:
-        atol = 1e-3
-    elif dtype == torch.bfloat16:
-        atol = 1e-2
-    else:
+    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
         raise ValueError(f"Unsupported dtype: {dtype}")
     diff = (expected - actual).abs().max()
     print(f"diff: {diff}")
diff --git a/test/galore/memory_analysis_utils.py b/test/galore/memory_analysis_utils.py
index 6e464e8766..68da0012d2 100644
--- a/test/galore/memory_analysis_utils.py
+++ b/test/galore/memory_analysis_utils.py
@@ -50,7 +50,7 @@ def _convert_to_units(df, col):
     convert_cols_to_MB = {col: partial(_convert_to_units, col=col) for col in COL_NAMES}
 
     df = pd.DataFrame(
-        [l[1:] for l in df.iloc[:, 1].to_list()], columns=COL_NAMES
+        [row[1:] for row in df.iloc[:, 1].to_list()], columns=COL_NAMES
     ).assign(**convert_cols_to_MB)
     df["Total"] = df.sum(axis=1)
     return df
diff --git a/test/galore/profile_memory_usage.py b/test/galore/profile_memory_usage.py
index fce4e18a87..94d70bdd0e 100644
--- a/test/galore/profile_memory_usage.py
+++ b/test/galore/profile_memory_usage.py
@@ -86,7 +86,7 @@ def run(args, file_prefix):
     model_config = LlamaConfig()
     try:
         model_config_dict = getattr(model_configs, args.model_config.upper())
-    except:
+    except Exception:
         raise ValueError(f"Model config {args.model_config} not found")
     model_config.update(model_config_dict)
     model = LlamaForCausalLM(model_config).to("cuda")
@@ -163,7 +163,7 @@ def run(args, file_prefix):
     if args.torch_profiler:
         print(f"Finished profiling, outputs saved to {args.output_dir}/{file_prefix}*")
     else:
-        print(f"Finished profiling")
+        print("Finished profiling")
 
 
 if __name__ == "__main__":
diff --git a/test/galore/profiling_utils.py b/test/galore/profiling_utils.py
index 80d6c03d84..033d33ab5a 100644
--- a/test/galore/profiling_utils.py
+++ b/test/galore/profiling_utils.py
@@ -75,7 +75,6 @@ def get_cuda_memory_usage(units="MB", show=True):
 
 
 def export_memory_snapshot(prefix) -> None:
-
     # Prefix for file names.
     timestamp = datetime.now().strftime(TIME_FORMAT_STR)
     file_prefix = f"{prefix}_{timestamp}"
@@ -115,7 +114,6 @@ def trace_handler(
     export_memory_timeline=True,
     print_table=True,
 ):
-
     timestamp = datetime.now().strftime(TIME_FORMAT_STR)
     file_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}")
 
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 2f231fbb31..381886d594 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -1,58 +1,73 @@
 import unittest
+
 import torch
+
 from torchao.quantization import (
-    ZeroPointDomain,
     MappingType,
+    ZeroPointDomain,
+    int4_weight_only,
+    uintx_weight_only,
 )
-
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
 )
-from torchao.quantization import (
-    uintx_weight_only,
-    int4_weight_only,
-)
 
 cuda_available = torch.cuda.is_available()
 
-#Parameters
-device            = 'cuda:0'
-compute_dtype     = torch.bfloat16
-group_size        = 64
-mapping_type      = MappingType.ASYMMETRIC
-block_size        = (1, group_size) #axis=1
-preserve_zero     = False
+# Parameters
+device = "cuda:0"
+compute_dtype = torch.bfloat16
+group_size = 64
+mapping_type = MappingType.ASYMMETRIC
+block_size = (1, group_size)  # axis=1
+preserve_zero = False
 zero_point_domain = ZeroPointDomain.FLOAT
-zero_point_dtype  = compute_dtype
-inner_k_tiles     = 8
-in_features       = 4096
-out_features      = 11800
-torch_seed        = 100
+zero_point_dtype = compute_dtype
+inner_k_tiles = 8
+in_features = 4096
+out_features = 11800
+torch_seed = 100
 
 
 def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
     torch.random.manual_seed(torch_seed)
     linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
-    x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
+    x = (
+        torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)
+        / 20.0
+    )
     y_ref = linear_layer(x)
     W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
     return W, x, y_ref
 
+
 def _eval_hqq(dtype):
-    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
+    W, x, y_ref = _init_data(
+        in_features, out_features, compute_dtype, device, torch_seed
+    )
 
-    dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+    dummy_linear = torch.nn.Linear(
+        in_features=in_features, out_features=out_features, bias=False
+    )
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
-        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
+        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
+            dummy_linear
+        ).weight
     else:
-        q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight
+        q_tensor_hqq = uintx_weight_only(
+            dtype, group_size=max(block_size), use_hqq=True
+        )(dummy_linear).weight
 
-    quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
+    quant_linear_layer = torch.nn.Linear(
+        W.shape[1], W.shape[0], bias=False, device=W.device
+    )
     del quant_linear_layer.weight
     quant_linear_layer.weight = q_tensor_hqq
 
     dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
-    dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
+    dot_product_error = (
+        (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
+    )
 
     return dequantize_error, dot_product_error
@@ -60,32 +75,62 @@ def _eval_hqq(dtype):
 @unittest.skipIf(not cuda_available, "Need CUDA available")
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
 class TestHQQ(unittest.TestCase):
-    def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
-        if(dtype is None): return
+    def _test_hqq(
+        self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None
+    ):
+        if dtype is None:
+            return
         dequantize_error, dot_product_error = _eval_hqq(dtype)
         self.assertTrue(dequantize_error < ref_dequantize_error)
         self.assertTrue(dot_product_error < ref_dot_product_error)
 
     def test_hqq_plain_8bit(self):
-        self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
+        self._test_hqq(
+            dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013
+        )
 
     def test_hqq_plain_7bit(self):
-        self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
+        self._test_hqq(
+            dtype=torch.uint7,
+            ref_dequantize_error=6e-05,
+            ref_dot_product_error=0.000193,
+        )
 
     def test_hqq_plain_6bit(self):
-        self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
+        self._test_hqq(
+            dtype=torch.uint6,
+            ref_dequantize_error=0.0001131,
+            ref_dot_product_error=0.000353,
+        )
 
     def test_hqq_plain_5bit(self):
-        self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
+        self._test_hqq(
+            dtype=torch.uint5,
+            ref_dequantize_error=0.00023,
+            ref_dot_product_error=0.000704,
+        )
 
     def test_hqq_plain_4bit(self):
-        self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
+        self._test_hqq(
+            dtype=torch.uint4,
+            ref_dequantize_error=0.000487,
+            ref_dot_product_error=0.001472,
+        )
 
     def test_hqq_plain_3bit(self):
-        self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
+        self._test_hqq(
+            dtype=torch.uint3,
+            ref_dequantize_error=0.00101,
+            ref_dot_product_error=0.003047,
+        )
 
     def test_hqq_plain_2bit(self):
-        self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
+        self._test_hqq(
+            dtype=torch.uint2,
+            ref_dequantize_error=0.002366,
+            ref_dot_product_error=0.007255,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py
index 4684f28221..6b1abbc587 100644
--- a/test/hqq/test_triton_mm.py
+++ b/test/hqq/test_triton_mm.py
@@ -1,20 +1,21 @@
 # Skip entire test if following module not available, otherwise CI failure
+import itertools
+
 import pytest
+import torch
+
+from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
 
 triton = pytest.importorskip(
     "triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
 )
 hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
-hqq_quantize = pytest.importorskip("hqq.core.quantize", reason="hqq required to run this test")
+hqq_quantize = pytest.importorskip(
+    "hqq.core.quantize", reason="hqq required to run this test"
+)
 HQQLinear = hqq_quantize.HQQLinear
 BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
 
-import itertools
-
-import torch
-
-from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
-
 # Test configs
 SHAPES = [
     [16, 128, 128],
@@ -96,7 +97,7 @@ def test_mixed_mm(
     W_q = W_q.to(dtype=quant_dtype)
     W_q = (
         W_q.reshape(meta["shape"])
-        if quant_config["weight_quant_params"]["bitpack"] == False
+        if not quant_config["weight_quant_params"]["bitpack"]
         else W_q
     )
     W_dq = hqq_linear.dequantize()
diff --git a/test/hqq/test_triton_qkv_fused.py b/test/hqq/test_triton_qkv_fused.py
index eda171a9ca..aa00e17f87 100644
--- a/test/hqq/test_triton_qkv_fused.py
+++ b/test/hqq/test_triton_qkv_fused.py
@@ -1,4 +1,9 @@
+import itertools
+
 import pytest
+import torch
+
+from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
 triton = pytest.importorskip(
     "triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
 )
@@ -10,13 +15,6 @@
 HQQLinear = hqq_quantize.HQQLinear
 BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
 
-import itertools
-
-import torch
-from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, Quantizer
-
-from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
-
 torch.manual_seed(0)
 # N, K = shape
 Q_SHAPES = [[4096, 4096]]
@@ -60,7 +58,7 @@ def quantize_helper(
     W_q = W_q.to(dtype=quant_dtype)
     W_q = (
         W_q.reshape(meta["shape"])
-        if quant_config["weight_quant_params"]["bitpack"] == False
+        if not quant_config["weight_quant_params"]["bitpack"]
         else W_q
     )