diff --git a/benchmarks/bench_galore_fused_kernels.py b/benchmarks/bench_galore_fused_kernels.py index c05f31e921..261c98acb1 100644 --- a/benchmarks/bench_galore_fused_kernels.py +++ b/benchmarks/bench_galore_fused_kernels.py @@ -8,9 +8,7 @@ def run(args): dtype = getattr(torch, args.dtype) allow_tf32 = args.allow_tf32 - fp8_fast_accum = False torch.backends.cuda.matmul.allow_tf32 = allow_tf32 - kernel = args.kernel M, N = args.M, args.N rank = args.rank diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index ebf9e1e738..e0ab170bd0 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -1,23 +1,26 @@ -"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs -""" +"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs""" + +import copy + import torch + +from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, +) from torchao.quantization.subclass import ( - Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, + unwrap_tensor_subclass, ) -from torchao.quantization.quant_api import ( - int4_weight_only, - int8_weight_only, - int8_dynamic_activation_int8_weight, - quantize_, - _replace_with_custom_fn_if_matches_filter, -) -import copy -from torchao.utils import unwrap_tensor_subclass + def _int8wo_api(mod, **kwargs): if TORCH_VERSION_AT_LEAST_2_4: @@ -27,14 +30,20 @@ def _int8wo_api(mod, **kwargs): else: change_linear_weights_to_int8_woqtensors(mod, **kwargs) + def _int8da_int8w_api(mod, **kwargs): if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False) + quantize_( + mod, + int8_dynamic_activation_int8_weight(**kwargs), + set_inductor_config=False, + ) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) else: change_linear_weights_to_int8_dqtensors(mod, **kwargs) + def _int4wo_api(mod, **kwargs): if TORCH_VERSION_AT_LEAST_2_4: kwargs_copy = kwargs.copy() @@ -47,31 +56,43 @@ def _int4wo_api(mod, **kwargs): else: change_linear_weights_to_int4_woqtensors(mod, **kwargs) + class ToyLinearModel(torch.nn.Module): - """Single linear for m * k * n problem size - """ - def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"): + """Single linear for m * k * n problem size""" + + def __init__( + self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda" + ): super().__init__() self.m = m self.dtype = dtype self.device = device - self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device) + self.linear = torch.nn.Linear(k, n, bias=has_bias).to( + dtype=self.dtype, device=self.device + ) def example_inputs(self): - return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),) + return ( + torch.randn( + self.m, self.linear.in_features, dtype=self.dtype, device=self.device + ), + ) def forward(self, x): x = self.linear(x) return x + def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): """ The deprecated implementation for int8 dynamic quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import 
_in_features_greater_than_16 - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import ( + _get_subclass_inserter, + _in_features_greater_than_16, + _is_linear, + ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight if filter_fn is None: @@ -80,40 +101,54 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs ) _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + model, + _get_subclass_inserter( + Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs + ), + filter_fn, ) + def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ The deprecated implementation for weight only quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs), + _get_subclass_inserter( + deprecated_tenosr_subclass, enable_parametrization=True, **kwargs + ), filter_fn, ) return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) + +_ref_change_linear_weights_to_int8_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) +) +_ref_change_linear_weights_to_int4_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) +) torch._dynamo.config.cache_size_limit = 50000 + @torch.no_grad def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): if kwargs is None: kwargs = {} - m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval() + m = ToyLinearModel( + M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda" + ).eval() m_bf16 = copy.deepcopy(m) m_ref = copy.deepcopy(m) example_inputs = m.example_inputs() @@ -130,26 +165,30 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): # perf comparison from torchao.utils import benchmark_model + # warmup WARMUP = 20 RUNS = 100 torch._dynamo.reset() - m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True) + m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True) benchmark_model(m_ref, WARMUP, example_inputs) ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs) torch._dynamo.reset() - m = torch.compile(m, mode='max-autotune', fullgraph=True) + m = torch.compile(m, mode="max-autotune", fullgraph=True) benchmark_model(m, WARMUP, example_inputs) elapsed_time = benchmark_model(m, RUNS, example_inputs) torch._dynamo.reset() - m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True) + m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) benchmark_model(m_bf16, WARMUP, example_inputs) bf16_elapsed_time = benchmark_model(m_bf16, 
RUNS, example_inputs) - print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}") + print( + f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" + ) + if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available(): all_shapes = [ @@ -158,16 +197,25 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): print("_int8da_int8w_api") from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors + for M, N, K in all_shapes: - _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K) + _bench_quantized_tensor_subclass_perf( + _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K + ) print("_int8wo_api") from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors + for M, N, K in all_shapes: - _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K) + _bench_quantized_tensor_subclass_perf( + _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K + ) print("_int4wo_api") kwargs = {"groupsize": 32} from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors + for M, N, K in all_shapes: - _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs) + _bench_quantized_tensor_subclass_perf( + _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs + ) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 25967baa25..c20599532c 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -1,17 +1,22 @@ -import torch import pandas as pd +import torch import torch.nn.functional as F +from tqdm import tqdm + from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout +from torchao.dtypes.floatx import FloatxTensorCoreLayout from torchao.utils import benchmark_torch_function_in_microseconds -from tqdm import tqdm def benchmark(m: int, k: int, n: int): float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda") float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") - fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2)) - fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2)) + fp6_weight_fp16 = to_affine_quantized_fpx( + float_data_fp16, FloatxTensorCoreLayout(3, 2) + ) + fp6_weight_bf16 = to_affine_quantized_fpx( + float_data_bf16, FloatxTensorCoreLayout(3, 2) + ) fp16_weight = fp6_weight_fp16.dequantize(torch.float16) bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16) @@ -22,15 +27,27 @@ def benchmark(m: int, k: int, n: int): fp16_output = F.linear(fp16_act, fp16_weight) bf16_output = F.linear(bf16_act, bf16_weight) - fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight) - bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight) - fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16) - fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16) + fp16_time = benchmark_torch_function_in_microseconds( + F.linear, fp16_act, fp16_weight + ) + bf16_time = 
benchmark_torch_function_in_microseconds( + F.linear, bf16_act, bf16_weight + ) + fp6_time_fp16 = benchmark_torch_function_in_microseconds( + F.linear, fp16_act, fp6_weight_fp16 + ) + fp6_time_bf16 = benchmark_torch_function_in_microseconds( + F.linear, bf16_act, fp6_weight_bf16 + ) # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py # doesn't seem to be the right way to check for correctness - correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 - correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2 + correct_fp16 = ( + fp6_output_fp16 - fp16_output + ).abs().mean() / fp16_output.abs().mean() < 1e-3 + correct_bf16 = ( + fp6_output_bf16 - bf16_output + ).abs().mean() / bf16_output.abs().mean() < 1e-2 return { "m": m, diff --git a/benchmarks/benchmark_gpu_sparsity.py b/benchmarks/benchmark_gpu_sparsity.py index 5a579ed1a6..9e22f6d43a 100644 --- a/benchmarks/benchmark_gpu_sparsity.py +++ b/benchmarks/benchmark_gpu_sparsity.py @@ -1,19 +1,18 @@ import argparse -import random import pandas as pd import torch -import torch.utils.benchmark as benchmark import torch.nn.functional as F -from torch import nn from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured - from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm -from torchao.utils import benchmark_model, profiler_runner -from torchao.sparsity.utils import create_semi_structured_tensor, create_block_sparse_tensor - from tqdm import tqdm +from torchao.sparsity.utils import ( + create_block_sparse_tensor, + create_semi_structured_tensor, +) +from torchao.utils import benchmark_model + torch.set_printoptions( precision=2, threshold=None, @@ -23,10 +22,11 @@ sci_mode=False, ) + def benchmark_model_with_warmup(func, x, N_WARMUP=3): benchmark_model(func, N_WARMUP, device_type="cuda") return benchmark_model(func, 10, device_type="cuda") - + def run_gpu_sparse_benchmark(m, k, n, args): with torch.no_grad(): @@ -40,21 +40,33 @@ def run_gpu_sparse_benchmark(m, k, n, args): A = create_semi_structured_tensor(m, k, dtype) A_sparse = to_sparse_semi_structured(A) elif args.sparsity == "block-sparse": - A = create_block_sparse_tensor(m, k, args.block_size, args.sparsity_level, dtype) + A = create_block_sparse_tensor( + m, k, args.block_size, args.sparsity_level, dtype + ) A_sparse = A.to_sparse_bsr(blocksize=args.block_size) # BSR kernel tuning if args.bsr_autotune: print("Tuning kernel params") - optimize_bsr_dense_addmm(m, k, n, args.block_size, args.block_size, - dtype=dtype, sparsity=args.sparsity_level, verbose=True) + optimize_bsr_dense_addmm( + m, + k, + n, + args.block_size, + args.block_size, + dtype=dtype, + sparsity=args.sparsity_level, + verbose=True, + ) else: raise ValueError(f"Unknown sparsity: {args.sparsity}") if args.eval_fn == "linear": b = torch.randn(m, dtype=dtype).cuda() + # can't use lambda def dense_func(): return F.linear(x, A, b) + def sparse_func(): return F.linear(x, A_sparse, b) @@ -66,26 +78,41 @@ def sparse_func(): scale_b = torch.tensor([1.0], device="cuda") def dense_func(): - return torch._scaled_mm(A, x, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16) + return torch._scaled_mm( + A, x, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16 + ) + def sparse_func(): - return torch._scaled_mm(A_sparse, x, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16) + return torch._scaled_mm( + A_sparse, + x, + scale_a=scale_a, 
+ scale_b=scale_b, + out_dtype=torch.bfloat16, + ) else: x = x.t() + def dense_func(): return torch.mm(A, x) + def sparse_func(): return torch.mm(A_sparse, x) else: raise ValueError(f"Unknown eval_fn: {args.eval_fn}") - dense_time = benchmark_model_with_warmup(dense_func, 'dense.json.gz') - sparse_time = benchmark_model_with_warmup(sparse_func, 'sparse.json.gz') + dense_time = benchmark_model_with_warmup(dense_func, "dense.json.gz") + sparse_time = benchmark_model_with_warmup(sparse_func, "sparse.json.gz") dense_func_c = torch.compile(dense_func, mode="max-autotune") - dense_time_c = benchmark_model_with_warmup(dense_func_c, 'dense_compile.json.gz') + dense_time_c = benchmark_model_with_warmup( + dense_func_c, "dense_compile.json.gz" + ) sparse_func_c = torch.compile(sparse_func, mode="max-autotune") - sparse_time_c = benchmark_model_with_warmup(sparse_func_c, 'sparse_compile.json.gz') + sparse_time_c = benchmark_model_with_warmup( + sparse_func_c, "sparse_compile.json.gz" + ) torch._dynamo.reset() @@ -99,7 +126,8 @@ def sparse_func(): "dense": dense_time, "dense_c": dense_time_c, "sparse_c": sparse_time_c, - "speedup (d/s)": min(dense_time, dense_time_c) / min(sparse_time, sparse_time_c), + "speedup (d/s)": min(dense_time, dense_time_c) + / min(sparse_time, sparse_time_c), } @@ -135,27 +163,25 @@ def sparse_func(): 16, 32, 64, - ] + ], ) parser.add_argument( "--dtype", type=str, - choices=[ - "int8", - "float16", - "bfloat16", - "float32", - "float8_e4m3fn" - ], + choices=["int8", "float16", "bfloat16", "float32", "float8_e4m3fn"], default="bfloat16", ) parser.add_argument( "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt" ) - parser.add_argument("--eval-fn", type=str, choices=["linear", "mm"], default="linear") + parser.add_argument( + "--eval-fn", type=str, choices=["linear", "mm"], default="linear" + ) parser.add_argument("-contiguous", action="store_true") parser.add_argument("-save", action="store_true") - parser.add_argument("-bsr-autotune", action="store_true", help="Tune BSR kernel parameters") + parser.add_argument( + "-bsr-autotune", action="store_true", help="Tune BSR kernel parameters" + ) args = parser.parse_args() print(f"Started benchmark: {args}") @@ -170,8 +196,7 @@ def sparse_func(): (4096, 16384, 2816), ] results = ( - run_gpu_sparse_benchmark(m, k, n, args) - for (m, n, k) in tqdm(mm_shapes) + run_gpu_sparse_benchmark(m, k, n, args) for (m, n, k) in tqdm(mm_shapes) ) elif args.mode == "llama3-8b-w": mm_shapes = [ @@ -186,11 +211,10 @@ def sparse_func(): (8192, 11008, 4096), ] results = ( - run_gpu_sparse_benchmark(m, k, n, args) - for (m, k, n) in tqdm(mm_shapes) + run_gpu_sparse_benchmark(m, k, n, args) for (m, k, n) in tqdm(mm_shapes) ) elif args.mode == "vit-mlp": - vit_shapes= [ + vit_shapes = [ # vit-base (768, 3072, 50432), (3072, 3072, 50432), @@ -199,8 +223,7 @@ def sparse_func(): (5120, 1280, 65792), ] results = ( - run_gpu_sparse_benchmark(m, k, n, args) - for (m, k, n) in tqdm(vit_shapes) + run_gpu_sparse_benchmark(m, k, n, args) for (m, k, n) in tqdm(vit_shapes) ) elif args.mode == "nvidia-fixed-k": mn_vals = [ @@ -224,8 +247,7 @@ def sparse_func(): 20480, ] results = ( - run_gpu_sparse_benchmark(mn, 10240, mn, args) - for mn in tqdm(mn_vals) + run_gpu_sparse_benchmark(mn, 10240, mn, args) for mn in tqdm(mn_vals) ) elif args.mode == "nvidia-fixed-mn": k_vals = [ @@ -246,14 +268,12 @@ def sparse_func(): 20480, ] results = ( - run_gpu_sparse_benchmark(10240, k, 10240, args) - for k in tqdm(k_vals) + run_gpu_sparse_benchmark(10240, 
k, 10240, args) for k in tqdm(k_vals) ) else: raise ValueError(f"Unknown mode: {args.mode}") - df = pd.DataFrame.from_records(results) if args.save: save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv" diff --git a/benchmarks/benchmark_hqq.py b/benchmarks/benchmark_hqq.py index 123e2e5f52..d86f16cb57 100644 --- a/benchmarks/benchmark_hqq.py +++ b/benchmarks/benchmark_hqq.py @@ -1,5 +1,5 @@ try: - import hqq + import hqq # noqa: F401 import triton if int(triton.__version__.split(".")[0]) < 3: @@ -54,10 +54,16 @@ def fn(): return t -def bench_hqq(x, hqq_linear: HQQLinear | HQQLinearTorchWeightOnlyInt4, transposed=False, tinygemm=False): +def bench_hqq( + x, + hqq_linear: HQQLinear | HQQLinearTorchWeightOnlyInt4, + transposed=False, + tinygemm=False, +): def reference_fn(): W_dq = hqq_linear.dequantize() _ = x @ W_dq.T if not transposed else x @ W_dq + fn = reference_fn if not tinygemm else lambda: hqq_linear(x) t = do_bench(fn) @@ -138,9 +144,9 @@ def run_benchmark( [1024, 4096, 4096], ] -DTYPES = [torch.bfloat16] #[torch.float16, torch.bfloat16] +DTYPES = [torch.bfloat16] # [torch.float16, torch.bfloat16] GROUP_SIZES = [128] -TRANSPOSED = [True] #[False, True] +TRANSPOSED = [True] # [False, True] HEADERS = [ "M", @@ -171,4 +177,4 @@ def run_benchmark( df = pd.DataFrame(data, columns=HEADERS) df.to_csv(output, index=False) print(output.getvalue()) - # df.to_csv("benchmark_hqq_tinygemm.csv", index=False) \ No newline at end of file + # df.to_csv("benchmark_hqq_tinygemm.csv", index=False) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 986cc58b4f..ba5f109c0d 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -31,11 +31,11 @@ import torch.nn.functional as F import wandb from torch.utils.data import DataLoader -from torchao.utils import get_available_devices from torchvision.transforms import v2 from tqdm import tqdm from torchao.prototype import low_bit_optim +from torchao.utils import get_available_devices _DEVICE = get_available_devices()[-1] assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)" @@ -157,7 +157,7 @@ def evaluate_model(model, args): all_labels = [] all_preds = [] - for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"): + for batch in tqdm(val_dloader, dynamic_ncols=True, desc="Evaluating"): all_labels.append(batch["label"].clone()) if args.full_bf16: batch["image"] = batch["image"].bfloat16() diff --git a/benchmarks/benchmark_marlin_qqq.py b/benchmarks/benchmark_marlin_qqq.py index 295d089680..51df84abe4 100644 --- a/benchmarks/benchmark_marlin_qqq.py +++ b/benchmarks/benchmark_marlin_qqq.py @@ -1,9 +1,10 @@ -import torch import pandas as pd -from torchao.utils import benchmark_torch_function_in_microseconds +import torch +from tqdm import tqdm + from torchao.ops import marlin_qqq_gemm from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq -from tqdm import tqdm +from torchao.utils import benchmark_torch_function_in_microseconds def get_problem(m, n, k, groupsize=-1): diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_s8s4_cutlass.py index 397544b658..fbf07ebb35 100644 --- a/benchmarks/benchmark_s8s4_cutlass.py +++ b/benchmarks/benchmark_s8s4_cutlass.py @@ -1,13 +1,12 @@ -import torch import pandas as pd -from torchao.utils import benchmark_torch_function_in_microseconds -from torchao.ops import s8s4_linear_cutlass +import torch from tqdm import tqdm +from torchao.ops import 
s8s4_linear_cutlass +from torchao.utils import benchmark_torch_function_in_microseconds -def get_problem(m, n, k): - groupsize = k +def get_problem(m, n, k): dev = torch.device("cuda") A_ref = torch.randn((m, k), dtype=torch.half, device=dev) B_ref = torch.randn((k, n), dtype=torch.half, device=dev) diff --git a/benchmarks/benchmark_semi_sparse_training.py b/benchmarks/benchmark_semi_sparse_training.py index 35796157bd..bf075c0e51 100644 --- a/benchmarks/benchmark_semi_sparse_training.py +++ b/benchmarks/benchmark_semi_sparse_training.py @@ -7,21 +7,21 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import argparse -import itertools import gc +import itertools -from typing import Tuple - +import pandas as pd import torch -import torch.nn.functional as F +from segment_anything_fast import sam_model_registry +from torch.sparse import to_sparse_semi_structured from torch.utils import benchmark -from torch.sparse import to_sparse_semi_structured -from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear +from torchao.sparsity.training import ( + SemiSparseLinear, + swap_linear_with_semi_sparse_linear, +) from torchao.sparsity.training.autograd import semi_structured_sparsify -from segment_anything_fast import sam_model_registry -import pandas as pd def product_dict(**kwargs): keys = kwargs.keys() @@ -29,20 +29,24 @@ def product_dict(**kwargs): for instance in itertools.product(*vals): yield dict(zip(keys, instance)) + def benchmark_helper( - functions, - cases, + functions, + cases, fw: bool = False, bw: bool = False, cuda_graph: bool = False, compile: bool = False, - blocked_autorange = False, + blocked_autorange=False, ): assert fw or bw assert not (cuda_graph and compile) - print(f"Running benchmarks with: fw={fw}, bw={bw}, cuda_graph={cuda_graph}, compile={compile}: ") + print( + f"Running benchmarks with: fw={fw}, bw={bw}, cuda_graph={cuda_graph}, compile={compile}: " + ) results = [] + def handle_case(**case): for sparsity_config, benchmark_cls in functions.items(): result = { @@ -69,9 +73,11 @@ def run_one(): g.replay() if compile: - benchmark_object.model = torch.compile(benchmark_object.model, mode="max-autotune") + benchmark_object.model = torch.compile( + benchmark_object.model, mode="max-autotune" + ) - #benchmark + # benchmark torch.cuda.reset_peak_memory_stats() t0 = benchmark.Timer( stmt="fn()", @@ -83,18 +89,23 @@ def run_one(): if blocked_autorange: res = t0.blocked_autorange() else: - res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20) - result.update({'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9}) + res = t0.adaptive_autorange(0.03, min_run_time=0.2, max_run_time=20) + result.update( + { + "time": res.median * 1e3, + "memory": torch.cuda.max_memory_allocated() / 1e9, + } + ) except Exception as e: if "CUDA out of memory" not in str(e): raise else: - result.update({'time': 'OOM', 'memory': 'OOM'}) + result.update({"time": "OOM", "memory": "OOM"}) finally: # clean up - if 'benchmark_object' in locals(): + if "benchmark_object" in locals(): del benchmark_object - if 'g' in locals(): + if "g" in locals(): del g gc.collect() torch.cuda.empty_cache() @@ -104,13 +115,16 @@ def run_one(): handle_case(**case) return pd.DataFrame(results) + # test classes for Linear class LinearTest(torch.nn.Module): def __init__(self, mkn): super().__init__() m, k, n = mkn self.model = torch.nn.Linear(k, n).cuda().half() - self.input = 
torch.randn([m, k], device='cuda', dtype=torch.half, requires_grad=True) + self.input = torch.randn( + [m, k], device="cuda", dtype=torch.half, requires_grad=True + ) self.grad = torch.randn([m, n], device="cuda", dtype=torch.half) def fw(self): @@ -119,23 +133,30 @@ def fw(self): def bw(self): self.out.backward(self.grad, retain_graph=True) + class SemiSparseLinearOfflineCompressionTest(torch.nn.Module): def __init__(self, mkn): super().__init__() m, k, n = mkn self.model = torch.nn.Linear(k, n).cuda().half() - self.model.weight = torch.nn.Parameter(to_sparse_semi_structured(self.model.weight)) - self.input = torch.randn([m, k], device='cuda', dtype=torch.half, requires_grad=True) + self.model.weight = torch.nn.Parameter( + to_sparse_semi_structured(self.model.weight) + ) + self.input = torch.randn( + [m, k], device="cuda", dtype=torch.half, requires_grad=True + ) self.grad = torch.randn([m, n], device="cuda", dtype=torch.half) def fw(self): self.out = self.model(self.input) + class SemiSparseLinearTest(LinearTest): def __init__(self, mkn): super().__init__(mkn) self.model = SemiSparseLinear.from_dense(self.model) + class SemiSparseKernelTest(LinearTest): def __init__(self, mkn): super().__init__(mkn) @@ -145,15 +166,27 @@ def fw(self): def bw(self): pass - + + # test class for ViT (SAM image encoder) class SAMTest(torch.nn.Module): - def __init__(self, model_type, batch_size): super().__init__() - self.model = sam_model_registry[model_type]().image_encoder.cuda().half().train() - self.input = torch.randn(batch_size, 3, 1024, 1024, device='cuda', dtype=torch.half, requires_grad=True) - self.grad = torch.randn([batch_size, 256, 64, 64], device="cuda", dtype=torch.half) + self.model = ( + sam_model_registry[model_type]().image_encoder.cuda().half().train() + ) + self.input = torch.randn( + batch_size, + 3, + 1024, + 1024, + device="cuda", + dtype=torch.half, + requires_grad=True, + ) + self.grad = torch.randn( + [batch_size, 256, 64, 64], device="cuda", dtype=torch.half + ) def fw(self): self.out = self.model(self.input) @@ -161,16 +194,18 @@ def fw(self): def bw(self): self.out.backward(self.grad, retain_graph=True) + class SAM_W24_MLP_ONLY(SAMTest): def __init__(self, model_type, batch_size): super().__init__(model_type, batch_size) # Apply to just MLP linear layers of SAM image encoder (ViT) sparse_config = {} for name, mod in self.model.named_modules(): - if isinstance(mod, torch.nn.Linear) and 'mlp' in name: + if isinstance(mod, torch.nn.Linear) and "mlp" in name: sparse_config[name] = SemiSparseLinear swap_linear_with_semi_sparse_linear(self.model, sparse_config) + class SAM_W24_ALL(SAMTest): def __init__(self, model_type, batch_size): super().__init__(model_type, batch_size) @@ -181,17 +216,26 @@ def __init__(self, model_type, batch_size): sparse_config[name] = SemiSparseLinear swap_linear_with_semi_sparse_linear(self.model, sparse_config) + if __name__ == "__main__": print("BENCHMARKING") - parser = argparse.ArgumentParser(description='run semi-structured sparse training benchmarks') - parser.add_argument('--mode', type=str, choices=["linear", "llama3-8b", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit") - parser.add_argument('--save', action="store_true", help="save benchmarking results") + parser = argparse.ArgumentParser( + description="run semi-structured sparse training benchmarks" + ) + parser.add_argument( + "--mode", + type=str, + choices=["linear", "llama3-8b", "vit"], + help="nn.Linear/ViT-e2e benchmarking", + default="vit", + ) + parser.add_argument("--save", 
action="store_true", help="save benchmarking results") args = parser.parse_args() if args.mode == "linear": functions = { "dense_linear": LinearTest, "semi_sparse_linear": SemiSparseLinearTest, - "semi_sparse_prune+compress_time_only": SemiSparseKernelTest, + "semi_sparse_prune+compress_time_only": SemiSparseKernelTest, } cases = list( product_dict( @@ -205,12 +249,8 @@ def __init__(self, model_type, batch_size): ) df = benchmark_helper( - functions, - cases, - fw=True, - bw=True, - cuda_graph=True, - blocked_autorange=True) + functions, cases, fw=True, bw=True, cuda_graph=True, blocked_autorange=True + ) elif args.mode == "llama3-8b": functions = { "dense_linear": LinearTest, @@ -233,12 +273,8 @@ def __init__(self, model_type, batch_size): ) df = benchmark_helper( - functions, - cases, - fw=True, - bw=False, - cuda_graph=True, - blocked_autorange=True) + functions, cases, fw=True, bw=False, cuda_graph=True, blocked_autorange=True + ) elif args.mode == "vit": functions = { @@ -246,19 +282,9 @@ def __init__(self, model_type, batch_size): "ViT MLP weight 2:4 sparse": SAM_W24_MLP_ONLY, # "ViT all(MLP+ATN) Linear weight 2:4 sparse": SAM_W24_ALL } - cases = list( - product_dict( - model_type=['vit_l'], - batch_size=[8] - ) - ) + cases = list(product_dict(model_type=["vit_l"], batch_size=[8])) - df = benchmark_helper( - functions, - cases, - fw=True, - bw=True, - compile=True) + df = benchmark_helper(functions, cases, fw=True, bw=True, compile=True) print(df) if args.save: diff --git a/benchmarks/benchmark_uintx.py b/benchmarks/benchmark_uintx.py index 9887fb8b46..78debbca82 100644 --- a/benchmarks/benchmark_uintx.py +++ b/benchmarks/benchmark_uintx.py @@ -1,24 +1,27 @@ -from math import log from copy import deepcopy import torch -from torchao.utils import unwrap_tensor_subclass -from torchao.prototype.uintx import uintx_affine_weight_only, pack, unpack, pack_cpu, unpack_cpu + +from torchao.prototype.uintx import ( + uintx_affine_weight_only, + unpack_cpu, +) from torchao.quantization.quant_api import quantize_ - + + class Linear16(torch.nn.Module): def __init__(self, scale): super().__init__() self.net = torch.nn.Sequential( - torch.nn.Linear(scale*2, scale, bias=True, dtype=torch.float16).cuda(), + torch.nn.Linear(scale * 2, scale, bias=True, dtype=torch.float16).cuda(), torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(), - torch.nn.Linear(scale, scale//2, bias=True, dtype=torch.float16).cuda(), + torch.nn.Linear(scale, scale // 2, bias=True, dtype=torch.float16).cuda(), ) def forward(self, x): return self.net(x) - + def benchmark(function, args, num_runs): # warmup torch._dynamo.reset() @@ -38,25 +41,26 @@ def benchmark(function, args, num_runs): def profile_bitpack(): - from torch.profiler import profile, record_function, ProfilerActivity - fake_tensor = [torch.randint(2**8, (512,512), dtype=torch.uint8).cuda()] + from torch.profiler import ProfilerActivity, profile + + fake_tensor = [torch.randint(2**8, (512, 512), dtype=torch.uint8).cuda()] func = torch.compile(unpack_cpu, fullgraph=True) - with profile(activities=[ - ProfilerActivity.CPU, - ProfilerActivity.CUDA], + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, - with_stack=True - ) as prof: - + with_stack=True, + ) as prof: for _ in range(1000): - unpacked = func(fake_tensor, 4) - + func(fake_tensor, 4) + # Print a summary with open("profile-bitpack.txt", "a") as f: - print(f'{func}',file=f) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), 
file=f) + print(f"{func}", file=f) + print( + prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f + ) prof.export_chrome_trace("trace.json") - ''' + """ CPU perf: unpack_gpu Self CPU time total: 602.501ms @@ -71,39 +75,37 @@ def profile_bitpack(): unpack_cpu: Self CPU time total: 96.947ms Self CUDA time total: 5.253ms - ''' - -def uintx_vs_fp16(nbits= [1,2,3,4,5,6,7], scales=[256, 512, 1024], repeats=30): - results = [] + """ + + +def uintx_vs_fp16(nbits=[1, 2, 3, 4, 5, 6, 7], scales=[256, 512, 1024], repeats=30): + results = [] nbits.sort() scales.sort() for scale in scales: - test_input = torch.randn(scale*2, dtype=torch.float16).cuda() + test_input = torch.randn(scale * 2, dtype=torch.float16).cuda() forward_args = [test_input] times = [scale] - + fp16 = Linear16(scale) fp16c = torch.compile(fp16, fullgraph=True) fp16_time = benchmark(fp16c.forward, forward_args, repeats) times.append(fp16_time) for bit_size in nbits: m = deepcopy(fp16) - quantize_(m, uintx_affine_weight_only(bit_size)) + quantize_(m, uintx_affine_weight_only(bit_size)) m = torch.compile(m, fullgraph=True) uintx_time = benchmark(m.forward, forward_args, repeats) times.append(uintx_time) - print(f'scale={scale} done') - + print(f"scale={scale} done") + results.append(times) print("----------- benchmark results -----------") for result in results: print(f"scale: {result[0]} fp16 time:{result[1]: .2f}ms speedups:") for i in range(2, len(result)): print(f"int{nbits[i-2]}: {result[1]/result[i]: .2f}x") - - - -if __name__ == "__main__": - uintx_vs_fp16(nbits=[4,7]) - - \ No newline at end of file + + +if __name__ == "__main__": + uintx_vs_fp16(nbits=[4, 7]) diff --git a/benchmarks/dora/bench_utils.py b/benchmarks/dora/bench_utils.py index 2de4fa637a..9e3aed0e5d 100644 --- a/benchmarks/dora/bench_utils.py +++ b/benchmarks/dora/bench_utils.py @@ -1,7 +1,6 @@ import torch from bitsandbytes.nn import Linear4bit from hqq.core.quantize import BaseQuantizeConfig, HQQLinear - from prototypes.dora.dora_layer import BNBDoRALinear, HQQDoRALinear from prototypes.dora.kernels.matmul import triton_mm from prototypes.dora.kernels.smallk import triton_mm_small_k diff --git a/benchmarks/dora/dora_bench.py b/benchmarks/dora/dora_bench.py index 305cfbdb15..217f0b1871 100644 --- a/benchmarks/dora/dora_bench.py +++ b/benchmarks/dora/dora_bench.py @@ -17,9 +17,9 @@ ) from triton.testing import do_bench +from torchao.prototype.common.profiling_tools import pivot_df from torchao.prototype.dora.kernels.matmul import triton_mm from torchao.prototype.dora.kernels.smallk import triton_mm_small_k -from torchao.prototype.common.profiling_tools import pivot_df def run_colnorm_bench(args): diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index f92303c627..d160d7241d 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -11,14 +11,16 @@ from typing import Callable, List, Optional, Tuple import pandas as pd - import torch import torch.utils.benchmark as benchmark +from tqdm import tqdm +from utils import get_name_to_shapes_iter + from torchao.float8.config import ( - CastConfig, - Float8LinearConfig, - ScalingType, + CastConfig, + Float8LinearConfig, ScalingGranularity, + ScalingType, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -26,8 +28,6 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_tensor import ScaledMMConfig -from utils import get_name_to_shapes_iter 
-from tqdm import tqdm # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N @@ -105,7 +105,7 @@ def main( fast_accum_filter: Optional[bool] = None, shape_name_filter: Optional[str] = None, *, - shape_gen_name: str = 'llama', + shape_gen_name: str = "llama", M: Optional[int] = None, K: Optional[int] = None, N: Optional[int] = None, @@ -123,35 +123,35 @@ def main( scaling_granularity = ScalingGranularity(scaling_granularity) if scaling_type_input is ScalingType.STATIC: - cast_config_input=CastConfig( + cast_config_input = CastConfig( scaling_type=scaling_type_input, static_scale=torch.tensor([1.0], device="cuda"), scaling_granularity=scaling_granularity, ) else: - cast_config_input=CastConfig( + cast_config_input = CastConfig( scaling_type=scaling_type_input, scaling_granularity=scaling_granularity, ) if scaling_type_weight is ScalingType.STATIC: - cast_config_weight=CastConfig( + cast_config_weight = CastConfig( scaling_type=scaling_type_weight, static_scale=torch.tensor([1.0], device="cuda"), scaling_granularity=scaling_granularity, ) else: - cast_config_weight=CastConfig( + cast_config_weight = CastConfig( scaling_type=scaling_type_weight, scaling_granularity=scaling_granularity, ) if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output=CastConfig( + cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, static_scale=torch.tensor([1.0], device="cuda"), scaling_granularity=scaling_granularity, ) else: - cast_config_grad_output=CastConfig( + cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, scaling_granularity=scaling_granularity, ) @@ -169,7 +169,6 @@ def main( else: use_fast_accum = [True, False] if shape_name_filter is not None: - k = shape_name_filter name_to_shapes = ((k, v) for (k, v) in name_to_shapes if k == shape_name_filter) experiment_list: List[Experiment] = [] dtype = torch.bfloat16 @@ -336,11 +335,11 @@ def invoke_main() -> None: if args.shape_gen_name is not None: kwargs["shape_gen_name"] = args.shape_gen_name if args.M is not None: - kwargs["M"] = args.M, + kwargs["M"] = (args.M,) if args.K is not None: - kwargs["K"] = args.K, + kwargs["K"] = (args.K,) if args.N is not None: - kwargs["N"] = args.N, + kwargs["N"] = (args.N,) if args.scaling_type_input is not None: kwargs["scaling_type_input"] = args.scaling_type_input if args.scaling_type_weight is not None: diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index e969846b28..3d48853754 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -8,19 +8,16 @@ import fire import pandas as pd - import torch import torch.nn as nn import torch.utils.benchmark as benchmark - -from torchao.float8.config import ScalingGranularity - from utils import ( - get_name_to_shapes_iter, - profiler_output_to_filtered_time_by_kernel_name, get_gpu_kernel_gemm_time_s, + get_name_to_shapes_iter, ) +from torchao.float8.config import ScalingGranularity + # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N @@ -50,11 +47,11 @@ def benchmark_fn_in_sec(f, *args, **kwargs): def do_benchmarks( - tops, - peak_tops, - use_gpu_kernel_time, - f, - *args, + tops, + peak_tops, + use_gpu_kernel_time, + f, + *args, **kwargs, ): if use_gpu_kernel_time: @@ -71,7 +68,7 @@ def do_benchmarks( @torch.inference_mode() def run( n_limit: Optional[int] = None, - shape_gen_name: str = 'llama', + 
shape_gen_name: str = "llama", out_filename: Optional[str] = None, M: Optional[int] = None, K: Optional[int] = None, @@ -81,7 +78,16 @@ def run( ): device = "cuda" - headers = ("fast_accum", "name", "M", "K", "N", "ref_time_s", "fp8_time_s", "fp8_speedup") + headers = ( + "fast_accum", + "name", + "M", + "K", + "N", + "ref_time_s", + "fp8_time_s", + "fp8_speedup", + ) results = [] dtype = torch.bfloat16 @@ -89,7 +95,9 @@ def run( fast_accum_vals = [True, False] scaling_granularity = ScalingGranularity(scaling_granularity) - for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)): + for idx, (fast_accum, (name, (M, K, N))) in enumerate( + itertools.product(fast_accum_vals, name_to_shapes) + ): if n_limit is not None and idx >= n_limit: break @@ -156,6 +164,7 @@ def do_matmul(A, B): if out_filename is not None: data_df.to_csv(out_filename) + def main() -> None: fire.Fire(run) diff --git a/benchmarks/float8/bench_multi_gpu.py b/benchmarks/float8/bench_multi_gpu.py index 44c758d1b5..34a690edbe 100644 --- a/benchmarks/float8/bench_multi_gpu.py +++ b/benchmarks/float8/bench_multi_gpu.py @@ -8,19 +8,18 @@ from typing import Callable import fire - import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.utils.benchmark as benchmark +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, sync_float8_amax_and_scale_history, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - torch.manual_seed(0) diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py index 9777553433..7fc451641e 100644 --- a/benchmarks/float8/bench_padding.py +++ b/benchmarks/float8/bench_padding.py @@ -2,18 +2,18 @@ from typing import Optional import fire - import torch +from tabulate import tabulate +from torch._inductor.utils import do_bench_using_profiling +from tqdm import tqdm + from torchao.float8.float8_tensor import ( GemmInputRole, - hp_tensor_and_scale_to_float8, LinearMMConfig, ScaledMMConfig, + hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import pad_tensor_for_matmul -from tabulate import tabulate -from torch._inductor.utils import do_bench_using_profiling -from tqdm import tqdm # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 19c6cc21bc..2b3f631d8c 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -7,11 +7,11 @@ """ This is a script to estimate the benefit from converting a `torch.nn.Linear` layer to float8, by estimating the difference in e2e GPU kernel time between: -1. bf16 gemms in fwd and bwd, and +1. bf16 gemms in fwd and bwd, and 2. float8 gemms in fwd and bwd, and float8 overhead The gemm times are estimated either from direct measurements via benchmarks, -or with a roofline estimation based on TOPS and peak compute bandwidth of an +or with a roofline estimation based on TOPS and peak compute bandwidth of an NVIDIA H100. The float8 overhead times are estimated by counting memory reads and writes @@ -39,38 +39,35 @@ 5. 
assume no float8 all-gather (TODO model it) """ -import csv import copy import json import os -import time from typing import Optional import fire import pandas as pd import sympy -import tqdm - import torch import torch.utils.benchmark as benchmark -from torch.profiler import profile, ProfilerActivity, record_function - +import tqdm +from torch.profiler import ProfilerActivity, profile from utils import ( - get_name_to_shapes_iter, - get_gpu_kernel_gemm_time_s, + get_gpu_kernel_gemm_time_s, + get_name_to_shapes_iter, profiler_output_to_filtered_time_by_kernel_name, ) -from torchao.float8.roofline_utils import ( - get_gemm_time_sympy, - get_float8_mem_sympy, -) + from torchao.float8 import ( - convert_to_float8_training, - Float8LinearConfig, - ScalingType, CastConfig, + Float8LinearConfig, + ScalingType, + convert_to_float8_training, +) +from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config +from torchao.float8.roofline_utils import ( + get_float8_mem_sympy, + get_gemm_time_sympy, ) -from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName class LNLinearSigmoid(torch.nn.Module): @@ -85,6 +82,8 @@ def forward(self, x): x = self.fc(x) x = self.sigmoid(x) return x + + # TODO(next): hook this up @@ -103,7 +102,7 @@ def get_gpu_kernel_time(m, x): # warm up for _ in range(2): m(x).sum().backward() - + # capture a profiling run activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] n_iter = 5 @@ -114,18 +113,19 @@ def get_gpu_kernel_time(m, x): # get the gpu kernel time and aggregate it num_leaf_tensors = 1 + len(list(m.parameters())) ref_times = profiler_output_to_filtered_time_by_kernel_name( - prof, n_iter, num_leaf_tensors) + prof, n_iter, num_leaf_tensors + ) total_time_s = sum(v for v in ref_times.values()) / 1e6 / n_iter return total_time_s -def get_gemm_times(M, K, N, fast_accum, cache_filename=None): +def get_gemm_times(M, K, N, fast_accum, cache_filename=None): # Note: this is definitely not the best way to build a cache, # but it will do for now. 
if cache_filename is not None: if os.path.isfile(cache_filename): # cache already exists, use it - with open(cache_filename, 'r') as f: + with open(cache_filename, "r") as f: cache = json.load(f) else: # cache does not exist yet, create it @@ -136,7 +136,7 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None): if key in cache: return cache[key] - device = torch.device('cuda') + device = torch.device("cuda") # bf16 time x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) @@ -154,6 +154,7 @@ def do_matmul(A, B): return torch._scaled_mm( A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum ) + f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) scale_a = torch.ones(M, 1, device=device) @@ -164,11 +165,12 @@ def do_matmul(A, B): # save to cache if needed if cache_filename is not None: cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s] - with open(cache_filename, 'w') as f: + with open(cache_filename, "w") as f: json.dump(cache, f) return bf16_time_s, f8_time_s, f8_axs_time_s + def run( outfile: str, gemm_time_strategy: str = "benchmarks", @@ -190,37 +192,47 @@ def run( * `n_limit (optional)`: if specified, only runs `n_limit` iterations """ - print(f'gemm_time_strategy: {gemm_time_strategy}') - print(f'shape_gen_name: {shape_gen_name}') + print(f"gemm_time_strategy: {gemm_time_strategy}") + print(f"shape_gen_name: {shape_gen_name}") - assert gemm_time_strategy in ("benchmarks", "roofline"), \ - "`gemm_time_strategy` must be 'benchmarks' or 'roofline'" + assert gemm_time_strategy in ( + "benchmarks", + "roofline", + ), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'" - M, K, N = sympy.symbols('M K N') + M, K, N = sympy.symbols("M K N") fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy( - M, K, N, + M, + K, + N, model_torch_compile_limitations=True, scaling_type_input="dynamic", scaling_type_weight="dynamic", scaling_type_grad_output="dynamic", ) fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy( - M, K, N, + M, + K, + N, model_torch_compile_limitations=False, scaling_type_input="dynamic", scaling_type_weight="dynamic", scaling_type_grad_output="dynamic", ) fp8_mem_time_sympy_del_limit = get_float8_mem_sympy( - M, K, N, + M, + K, + N, model_torch_compile_limitations=True, scaling_type_input="delayed", scaling_type_weight="delayed", scaling_type_grad_output="delayed", ) fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy( - M, K, N, + M, + K, + N, model_torch_compile_limitations=False, scaling_type_input="delayed", scaling_type_weight="delayed", @@ -229,28 +241,39 @@ def run( if gemm_time_strategy == "roofline": bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16) - print('bf16_gemm_time_sympy', bf16_gemm_time_sympy) + print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn) - print('fp8_gemm_time_sympy', fp8_gemm_time_sympy) + print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) print() else: print() headers = [ - 'fwd_M', 'fwd_K', 'fwd_N', + "fwd_M", + "fwd_K", + "fwd_N", # gemm microbenchmarks - 'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s', + "bf16_gemm_s", + "fp8_gemm_s", + "fp8_axs_gemm_time_s", # roofline memory overhead estimates - 'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit', - 'fp8_oh_del_limit', 'fp8_oh_del_nolimit', + "fp8_oh_dyn_limit", + "fp8_oh_dyn_nolimit", + "fp8_oh_del_limit", + "fp8_oh_del_nolimit", # actual e2e measurements - 'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s', + "bf16_s", + "fp8_dyn_s", + "fp8_del_s", + "fp8_dyn_axs_s", # 
'fp8_lw_s', - 'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp', + "fp8_dyn_sp", + "fp8_del_sp", + "fp8_dyn_axs_sp", # 'fp8_lw_sp', ] results = [] - + name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None) for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)): @@ -258,31 +281,47 @@ def run( break if gemm_time_strategy == "benchmarks": - bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename) - bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename) - bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename) + bf16_g1, f8_g1, f8_g1_axs = get_gemm_times( + M_val, K_val, N_val, True, gemm_cache_filename + ) + bf16_g2, f8_g2, f8_g2_axs = get_gemm_times( + M_val, N_val, K_val, False, gemm_cache_filename + ) + bf16_g3, f8_g3, f8_g3_axs = get_gemm_times( + K_val, M_val, N_val, False, gemm_cache_filename + ) bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3 fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3 fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs else: assert gemm_time_strategy == "roofline", "unsupported" - bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) - fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + bf16_time_val = ( + bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) + fp8_gemm_time_s = ( + fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) # for now, assume axiswise gemm is similar to tensorwise - fp8_axs_gemm_time_s = fp8_gemm_time_s + fp8_axs_gemm_time_s = fp8_gemm_time_s - fp8_mem_time_dyn_limit_s = \ + fp8_mem_time_dyn_limit_s = ( fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - fp8_mem_time_dyn_nolimit_s = \ + ) + fp8_mem_time_dyn_nolimit_s = ( fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - fp8_mem_time_del_limit_s = \ + ) + fp8_mem_time_del_limit_s = ( fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - fp8_mem_time_del_nolimit_s = \ + ) + fp8_mem_time_del_nolimit_s = ( fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) # create the model m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() - x = torch.randn(M_val, K_val, dtype=torch.bfloat16, device="cuda").requires_grad_() + x = torch.randn( + M_val, K_val, dtype=torch.bfloat16, device="cuda" + ).requires_grad_() # get the bf16 gpu kernel time torch._dynamo.reset() @@ -324,29 +363,38 @@ def run( # m_fp8_lw = torch.compile(m_fp8_lw) # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) - results.append([ - M_val, K_val, N_val, - # gemm microbenchmarks - bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s, - # roofline overhead estimates - fp8_mem_time_dyn_limit_s, - fp8_mem_time_dyn_nolimit_s, - fp8_mem_time_del_limit_s, - fp8_mem_time_del_nolimit_s, - # e2e numbers - bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s, - fp8_dyn_axs_time_actual_s, - # fp8_lw_time_actual_s, - bf16_time_actual_s / fp8_dyn_time_actual_s, - bf16_time_actual_s / fp8_del_time_actual_s, - bf16_time_actual_s / fp8_dyn_axs_time_actual_s, - # bf16_time_actual_s / fp8_lw_time_actual_s, - ]) + results.append( + [ + M_val, + K_val, + N_val, + # gemm microbenchmarks + bf16_time_val, + fp8_gemm_time_s, + fp8_axs_gemm_time_s, + # roofline overhead estimates + fp8_mem_time_dyn_limit_s, + fp8_mem_time_dyn_nolimit_s, + fp8_mem_time_del_limit_s, + fp8_mem_time_del_nolimit_s, + # e2e numbers 
+ bf16_time_actual_s, + fp8_dyn_time_actual_s, + fp8_del_time_actual_s, + fp8_dyn_axs_time_actual_s, + # fp8_lw_time_actual_s, + bf16_time_actual_s / fp8_dyn_time_actual_s, + bf16_time_actual_s / fp8_del_time_actual_s, + bf16_time_actual_s / fp8_dyn_axs_time_actual_s, + # bf16_time_actual_s / fp8_lw_time_actual_s, + ] + ) df = pd.DataFrame(results, columns=headers) print(df) df.to_csv(outfile) - print('done') + print("done") + -if __name__ == '__main__': +if __name__ == "__main__": fire.Fire(run) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index e545ea4665..38a8c5e875 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -5,35 +5,41 @@ # LICENSE file in the root directory of this source tree. import copy -import io import functools +import io import os +import pathlib import random from contextlib import nullcontext, redirect_stdout from dataclasses import dataclass, field -import pathlib from typing import Callable, Optional import fire import pandas as pd # disable inductor FX cache, so we can can always study the inductor output logs -os.environ['TORCHINDUCTOR_FORCE_DISABLE_CACHES'] = '1' +os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" import torch import torch.nn as nn import torch.nn.functional as F +from torch.profiler import ProfilerActivity, profile, record_function from torch.utils.checkpoint import ( - checkpoint, - create_selective_checkpoint_contexts, CheckpointPolicy, + checkpoint, + create_selective_checkpoint_contexts, +) +from utils import ( + kernel_name_to_category, + parse_bw_and_kernel_name, + profiler_output_to_filtered_time_by_kernel_name, + profiler_output_to_gpu_time_for_key, + update_triton_kernels_in_prof_chome_trace_with_torch_logs, ) + from torchao.float8.config import ( - CastConfig, - Float8LinearConfig, - ScalingType, - ScalingGranularity, Float8LinearRecipeName, + ScalingType, recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( @@ -42,14 +48,6 @@ sync_float8_amax_and_scale_history, ) from torchao.testing.float8.test_utils import get_test_float8_linear_config -from torch.profiler import profile, ProfilerActivity, record_function -from utils import ( - kernel_name_to_category, - parse_bw_and_kernel_name, - profiler_output_to_gpu_time_for_key, - profiler_output_to_filtered_time_by_kernel_name, - update_triton_kernels_in_prof_chome_trace_with_torch_logs, -) # don't truncate long kernel names pd.options.display.max_colwidth = 100 @@ -57,7 +55,6 @@ pd.set_option("display.float_format", "{:.3f}".format) - class LNLinear(torch.nn.Module): def __init__(self, fc_dim1, fc_dim2): super().__init__() @@ -184,10 +181,10 @@ class ProfileConfig: def profile_function( - config: ProfileConfig, - func: Callable, + config: ProfileConfig, + func: Callable, add_inductor_metadata_to_trace: bool, - *args, + *args, **kwargs, ) -> torch.profiler.profile: """Profile a torch function and save the result to a file""" @@ -197,8 +194,10 @@ def profile_function( if add_inductor_metadata_to_trace: # ensure we aren't interfering with other torch_log settings - if os.environ.get('TORCH_LOGS', '') != '': - raise AssertionError('using TORCH_LOGS together with add_inductor_metadata_to_trace is not supported yet') + if os.environ.get("TORCH_LOGS", "") != "": + raise AssertionError( + "using TORCH_LOGS together with add_inductor_metadata_to_trace is not supported yet" + ) # save torch.compile logs to a file specific to this benchmark run # TODO(future): 
can we hack torch.compile to print to file only and not stdout? @@ -209,7 +208,6 @@ def profile_function( if os.path.isfile(config.logs_file_path): pathlib.Path.unlink(config.logs_file_path) torch._logging._init_logs(log_file_name=config.logs_file_path) - activities = [ProfilerActivity.CPU] if config.cuda: @@ -267,12 +265,14 @@ def profile_function( torch.ops.aten.max.default, ] + def policy_fn(ctx, op, *args, **kwargs): if op in ops_to_save: return CheckpointPolicy.MUST_SAVE else: return CheckpointPolicy.PREFER_RECOMPUTE + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) @@ -289,7 +289,12 @@ def main( enable_sync_amax_history: bool = True, enable_activation_checkpointing: bool = False, ): - assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported" + assert model_type in ( + "linear", + "ln_linear", + "norm_ffn_norm", + "norm_ffn_norm_small", + ), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") scaling_type_input = ScalingType(scaling_type_input) @@ -317,7 +322,9 @@ def main( print(f"Compile is set to | {compile}") print(f"model_type is set to | {model_type}") print(f"scaling_repr is set to | {scaling_repr}") - print(f"enable_activation_checkpointing is set to {enable_activation_checkpointing}") + print( + f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" + ) device = "cuda" ref_dtype = torch.bfloat16 @@ -406,7 +413,7 @@ def float8_forw_backward_wrapper(x): # to populate triton kernel bandwidth further down in the script if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "": context = nullcontext() - f = None + f = None else: f = io.StringIO() context = redirect_stdout(f) @@ -425,16 +432,31 @@ def float8_forw_backward_wrapper(x): ref_logs_suffix = f"_{model_type}_ref_compile_{compile}.txt" trace_ref_path = profile_path_prefix + ref_trace_suffix log_ref_path = profile_path_prefix + ref_logs_suffix - trace_ref_modified_path = trace_ref_path.replace(".json", "_modified.json") + trace_ref_modified_path = trace_ref_path.replace( + ".json", "_modified.json" + ) profile_config = ProfileConfig( - trace_ref_path, log_ref_path, trace_ref_modified_path, ref_trace_suffix, iters=profile_iters, warmup_iters=2, sync=True + trace_ref_path, + log_ref_path, + trace_ref_modified_path, + ref_trace_suffix, + iters=profile_iters, + warmup_iters=2, + sync=True, + ) + p = profile_function( + profile_config, + ref_forw_backward, + add_inductor_metadata_to_trace, + input_tensor, ) - p = profile_function(profile_config, ref_forw_backward, add_inductor_metadata_to_trace, input_tensor) print(f"saved profiling trace to {trace_ref_path}") if add_inductor_metadata_to_trace: print(f"saved torch logs to {log_ref_path}") print(f"saved modified trace to {trace_ref_modified_path}") - ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors) + ref_times = profiler_output_to_filtered_time_by_kernel_name( + p, profile_iters, num_leaf_tensors + ) total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters for k, v in ref_times.items(): v_ms = v / 1e3 / profile_iters @@ -460,7 +482,9 @@ def float8_forw_backward_wrapper(x): ) trace_float8_path = profile_path_prefix + float8_trace_suffix log_float8_path = profile_path_prefix + float8_log_suffix - trace_float8_modified_path = trace_float8_path.replace(".json", "_modified.json") + trace_float8_modified_path = trace_float8_path.replace( + ".json", "_modified.json" + ) profile_config = ProfileConfig( 
trace_float8_path, log_float8_path, @@ -471,14 +495,21 @@ def float8_forw_backward_wrapper(x): sync=True, ) p = profile_function( - profile_config, float8_forw_backward_wrapper, add_inductor_metadata_to_trace, input_tensor + profile_config, + float8_forw_backward_wrapper, + add_inductor_metadata_to_trace, + input_tensor, ) print(f"saved profiling trace to {trace_float8_path}") if add_inductor_metadata_to_trace: print(f"saved torch logs to {log_float8_path}") print(f"saved modified trace to {trace_float8_modified_path}") - float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors) - total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters + float8_times = profiler_output_to_filtered_time_by_kernel_name( + p, profile_iters, num_leaf_tensors + ) + total_time_ms = ( + sum(v for v in float8_times.values()) / 1e3 / profile_iters + ) for k, v in float8_times.items(): v_ms = v / 1e3 / profile_iters data.append( diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index 6567f69497..60e402e60e 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -9,15 +9,16 @@ import re from typing import Optional -from torch.profiler import profile, ProfilerActivity, record_function +from torch.profiler import ProfilerActivity, profile + def profiler_output_to_filtered_time_by_kernel_name( - prof, + prof, num_iter: int, num_leaf_tensors: int, ): """ - Input: + Input: * `prof`: a profiler with captured events * `num_iter`: number of iterations used to capture `prof` * `num_leaf_tensors`: number of leaf tensors to accumulate gradients to @@ -28,7 +29,7 @@ def profiler_output_to_filtered_time_by_kernel_name( set up as follows: # - # Forward pass + # Forward pass # # Expected GPU kernel overhead: none @@ -59,7 +60,6 @@ def profiler_output_to_filtered_time_by_kernel_name( thresh = 1e-10 kernel_name_to_gpu_time_us = collections.defaultdict(float) for e in key_averages: - # manually filter top-level CPU events with attributed CUDA time # example CPU event row from printing `key_averages`: # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1 @@ -69,23 +69,25 @@ def profiler_output_to_filtered_time_by_kernel_name( continue # manually filter expected microbenchmarking overhead, in order of execution - if e.key == 'aten::sum': + if e.key == "aten::sum": # forward pass sum - assert e.count == num_iter, f'unexpected number of iter for {e.key}' + assert e.count == num_iter, f"unexpected number of iter for {e.key}" continue - elif e.key == 'aten::fill_': + elif e.key == "aten::fill_": # filling the forward pass sum with 1.0 - assert e.count == num_iter, f'unexpected number of iter for {e.key}' + assert e.count == num_iter, f"unexpected number of iter for {e.key}" continue - elif e.key == 'aten::copy_': + elif e.key == "aten::copy_": # copying 1.0 from grad_out of `sum` to grad_out of next op - assert e.count == num_iter, f'unexpected number of iter for {e.key}' + assert e.count == num_iter, f"unexpected number of iter for {e.key}" continue - elif e.key == 'aten::add_': + elif e.key == "aten::add_": # accumulating gradients into leaf tensors - assert e.count == (num_iter * num_leaf_tensors), f'unexpected number of iter for {e.key}' + assert e.count == ( + num_iter * num_leaf_tensors + ), f"unexpected number of iter for {e.key}" continue - elif e.key == 'cudaDeviceSynchronize': + elif e.key == "cudaDeviceSynchronize": continue kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total @@ -148,9 
+150,10 @@ def get_name_to_shapes_iter( K: Optional[int], N: Optional[int], ): - if shape_gen_name == 'llama': - assert M == K == N == None, \ - f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}' + if shape_gen_name == "llama": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" bsz, seq_len = 4, 4096 M = bsz * seq_len # LLaMa 2 70B single-node weight shapes @@ -164,43 +167,47 @@ def get_name_to_shapes_iter( } return name_to_shapes_70b.items() - elif shape_gen_name == 'square': - assert M == K == N == None, \ - f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}' + elif shape_gen_name == "square": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" name_to_shapes = {} min_power_of_2 = 8 # 256 max_power_of_2 = 15 # 32,768 for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): - val = 2 ** power_of_2 + val = 2**power_of_2 name_to_shapes[idx] = val, val, val return name_to_shapes.items() - elif shape_gen_name == 'sweep': - assert M == K == N == None, \ - f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}' + elif shape_gen_name == "sweep": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" name_to_shapes = {} min_p2 = 8 # 256 max_p2 = 15 # 32,768 counter = 0 for M_p2 in range(min_p2, max_p2 + 1): - M = 2 ** M_p2 + M = 2**M_p2 for K_p2 in range(min_p2, max_p2 + 1): - K = 2 ** K_p2 + K = 2**K_p2 for N_p2 in range(min_p2, max_p2 + 1): - N = 2 ** N_p2 + N = 2**N_p2 name_to_shapes[counter] = M, K, N counter += 1 return name_to_shapes.items() - elif shape_gen_name == 'custom': - assert M is not None and K is not None and N is not None, \ - 'M, K, N must be specified for custom shape_gen' + elif shape_gen_name == "custom": + assert ( + M is not None and K is not None and N is not None + ), "M, K, N must be specified for custom shape_gen" name_to_shapes = { 1: (M, K, N), } return name_to_shapes.items() - raise AssertionError(f'unknown shape_gen_name {shape_gen_name}') + raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") + # copy-pasta from https://github.com/vkuzo/pytorch_scripts/blob/main/add_inductor_metadata_to_perf_trace.py def update_triton_kernels_in_prof_chome_trace_with_torch_logs( @@ -209,34 +216,31 @@ def update_triton_kernels_in_prof_chome_trace_with_torch_logs( modified_perf_trace_file: str, ): """ - Input 1: a perf trace generated by using `torch.profiler.profile` inside of + Input 1: a perf trace generated by using `torch.profiler.profile` inside of some_program.py, and containing torch.compile + inductor kernels - Input 2: a text file with the output of + Input 2: a text file with the output of TORCH_LOGS="output_code" python some_program.py Input 3: filename for the modified perf trace This script does the following for each triton kernel in input 1: - navigate to the kernel information in the logs from input 2 - - copy over the kernel metadata (aten graph, triton code, etc) to the JSON + - copy over the kernel metadata (aten graph, triton code, etc) to the JSON in input 1 - The end result is that Input 1 is modified so that the kernel metadata is + The end result is that Input 1 is modified so that the kernel metadata is directly visible in tools like chrome://tracing and perfetto. 
""" - external_id_to_cpu_ops = dict() - external_id_to_kernels = dict() - # open the torch logs file torch_logs_str = None - with open(torch_logs_file, 'r') as f: + with open(torch_logs_file, "r") as f: torch_logs_str = f.readlines() # strip away the torch_logs prefix torch_logs_only = [] for line in torch_logs_str: - line = line.replace('\n', '') - match = re.match('.* \[__output_code\] (.*)', line) + line = line.replace("\n", "") + match = re.match(".* \[__output_code\] (.*)", line) if match: torch_logs_only.append(match.group(1)) @@ -253,7 +257,7 @@ def update_triton_kernels_in_prof_chome_trace_with_torch_logs( name_to_start_end = {} cur_start, cur_end, cur_name = None, None, None for line_num, line in enumerate(torch_logs_only): - match_start = re.match('\# kernel path: .*', line) + match_start = re.match("\# kernel path: .*", line) if match_start: cur_start = line_num @@ -279,14 +283,14 @@ def update_triton_kernels_in_prof_chome_trace_with_torch_logs( # ... # // CPU ops, with names matchable to triton kernels from inductor output code # { - # # "cat": "cpu_op", + # # "cat": "cpu_op", # # "name": "triton_red_fused_LayerNorm_abs_max_0", # # "args": {"External id": 1030, ...}, # # ... # }, # // Inductor kernels, with wall time # { - # # "cat": "kernel", + # # "cat": "kernel", # # "name": "triton_", // we don't depend on this name, including for context # # "args": {"External id": 1030, ...}, # # "ts": 4275686082015.124, // start time @@ -300,35 +304,37 @@ def update_triton_kernels_in_prof_chome_trace_with_torch_logs( # 2. Using 1, add the metadata to triton kernels # open the perf trace json - with open(perf_trace_file, 'r') as f: + with open(perf_trace_file, "r") as f: perf_trace = json.load(f) # find mapping of cpu_op to external_id external_id_to_cpu_op = dict() - for record in perf_trace['traceEvents']: + for record in perf_trace["traceEvents"]: # print(record) - is_cpu_op = record.get('cat') == 'cpu_op' + is_cpu_op = record.get("cat") == "cpu_op" if is_cpu_op: - external_id_to_cpu_op[record['args']['External id']] = record['name'] + external_id_to_cpu_op[record["args"]["External id"]] = record["name"] # add the metadata to triton kernels - for record in perf_trace['traceEvents']: - is_triton_kernel = record.get('cat') == 'kernel' and 'triton' in record.get('name', '') + for record in perf_trace["traceEvents"]: + is_triton_kernel = record.get("cat") == "kernel" and "triton" in record.get( + "name", "" + ) if not is_triton_kernel: continue - op_name = external_id_to_cpu_op.get(record['args']['External id']) + op_name = external_id_to_cpu_op.get(record["args"]["External id"]) if op_name is None: continue start, end = name_to_start_end[op_name] - triton_code = torch_logs_only[start:end+1] - s = '' + triton_code = torch_logs_only[start : end + 1] + s = "" for line in triton_code: - s += f'{line}\n' - record['args']['triton_code'] = s + s += f"{line}\n" + record["args"]["triton_code"] = s # write the modified file # out_file = perf_trace_file.replace('.json', '') + '_with_metadata.json' - with open(modified_perf_trace_file, 'w') as f: + with open(modified_perf_trace_file, "w") as f: json.dump(perf_trace, f) @@ -338,8 +344,10 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs): n_iter = 5 with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: for idx in range(n_iter): - f(*args, **kwargs) - data = profiler_output_to_filtered_time_by_kernel_name(prof, n_iter, num_leaf_tensors=0) + f(*args, **kwargs) + data = profiler_output_to_filtered_time_by_kernel_name( + prof, 
n_iter, num_leaf_tensors=0 + ) # there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds assert len(data) == 1 if "aten::mm" in data: @@ -348,5 +356,3 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs): return data["aten::_scaled_mm"] / 1e6 / n_iter else: raise AssertionError("unexpected format of data") - - diff --git a/benchmarks/fused_benchmark_utils.py b/benchmarks/fused_benchmark_utils.py index 5456154c30..da0cf99ec3 100644 --- a/benchmarks/fused_benchmark_utils.py +++ b/benchmarks/fused_benchmark_utils.py @@ -47,7 +47,6 @@ def _ref_op( step_size=STEP_SIZE, **kwargs, ): - # Step 1: Down proj grad M, N = grad.shape if M >= N: diff --git a/benchmarks/intmm.py b/benchmarks/intmm.py index 5879f14053..ffad3cc27a 100644 --- a/benchmarks/intmm.py +++ b/benchmarks/intmm.py @@ -1,21 +1,16 @@ import argparse import csv import itertools -import math -import sys import pathlib +import sys import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_2 - # Check if CUDA is available, if not, exit the script if not torch.cuda.is_available(): print("CUDA is not available. Exiting the script.") sys.exit(0) -import torch.nn.functional as F -import torch.utils.benchmark as benchmark from torchao.kernel.intmm import int_matmul, int_scaled_matmul torch._dynamo.config.cache_size_limit = 128 @@ -72,7 +67,6 @@ def run_int_scaled_mm_benchmark(x, w, b): def run_benchmarks(shapes): print("fn,m,k,n,fp_time,int_mm_time,ratio") - positives = [] dtype = torch.bfloat16 device = "cuda" for fn, (m, k, n) in itertools.product( @@ -90,7 +84,9 @@ def run_benchmarks(shapes): if __name__ == "__main__": parser = argparse.ArgumentParser(description="integer matmul benchmarks") - parser.add_argument("--file_path", type=str, required=True, help="Path to csv file with shapes") + parser.add_argument( + "--file_path", type=str, required=True, help="Path to csv file with shapes" + ) args = parser.parse_args() # Access the file path provided as an argument file_path = args.file_path diff --git a/benchmarks/print_config_shapes.py b/benchmarks/print_config_shapes.py index 7c27378352..4cf232c862 100644 --- a/benchmarks/print_config_shapes.py +++ b/benchmarks/print_config_shapes.py @@ -1,5 +1,3 @@ -import torchao - from torchao.kernel import autotuner configs = autotuner._load_best_configs() diff --git a/benchmarks/quantized_training/benchmark_int8mm.py b/benchmarks/quantized_training/benchmark_int8mm.py index 0e85cd8313..cc2d5cf0db 100644 --- a/benchmarks/quantized_training/benchmark_int8mm.py +++ b/benchmarks/quantized_training/benchmark_int8mm.py @@ -41,5 +41,7 @@ def bench_f(f, *args): sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time] data.append(sample) -df = pd.DataFrame(data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"]) +df = pd.DataFrame( + data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"] +) print(df.to_markdown()) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 0e6b79f60c..25b37921b6 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -23,7 +23,12 @@ from tqdm import tqdm from torchao import quantize_ -from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs, RMSNorm +from torchao._models.llama.model import ( + ModelArgs, + RMSNorm, + Transformer, + transformer_configs, +) from torchao.prototype import low_bit_optim from 
torchao.prototype.quantized_training import ( bitnet_training, @@ -75,9 +80,13 @@ def get_tinystories(): tokens_list = [] chunk_size = 10_000 - for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"): + for i in tqdm( + range(0, len(stories), chunk_size), desc="Tokenizing TinyStories" + ): chunk = stories[i : min(i + chunk_size, len(stories))] - tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4)) + tokens_list.extend( + tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4) + ) total_size = sum(len(x) for x in tokens_list) mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size) @@ -155,13 +164,21 @@ def insert_rmsnorm(module: torch.nn.Module): # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. # TODO: might want to do the same for int8_weight_only to standardize. if args.quantize == "int8_weight_only": - quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + quantize_( + model, int8_weight_only_quantized_training(), set_inductor_config=False + ) elif args.quantize == "int8_mixed_precision": - quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) + quantize_( + model.layers, int8_mixed_precision_training(), set_inductor_config=False + ) elif args.quantize == "int8_mixed_precision_module_swap": - quantize_(model.layers, int8_mixed_precision_training(module_swap=True), set_inductor_config=False) + quantize_( + model.layers, + int8_mixed_precision_training(module_swap=True), + set_inductor_config=False, + ) elif args.quantize == "bitnet": quantize_(model.layers, bitnet_training(), set_inductor_config=False) @@ -195,8 +212,14 @@ def insert_rmsnorm(module: torch.nn.Module): while step < args.n_steps: # randomly select a continuous chunk, then reshape it - idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() - batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() + idx = torch.randint( + 0, data.shape[0] - args.batch_size * args.seq_len, (1,) + ).item() + batch = ( + data[idx : idx + args.batch_size * args.seq_len] + .view(args.batch_size, args.seq_len) + .long() + ) with torch.autocast("cuda", torch.bfloat16, enabled=args.bf16_amp): loss = _get_loss(model, batch) @@ -220,7 +243,10 @@ def insert_rmsnorm(module: torch.nn.Module): if step % args.log_interval == 0: time1 = time.time() - log_dict = dict(tokens_per_second=(args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)) + log_dict = dict( + tokens_per_second=(args.log_interval * args.batch_size * args.seq_len) + / (time1 - time0) + ) time0 = time1 run.log(log_dict, step=step) diff --git a/examples/sam2_amg_server/compile_export_utils.py b/examples/sam2_amg_server/compile_export_utils.py index 9237a2dd58..27c4ee7b01 100644 --- a/examples/sam2_amg_server/compile_export_utils.py +++ b/examples/sam2_amg_server/compile_export_utils.py @@ -1,7 +1,9 @@ -import torch import time from pathlib import Path from typing import Optional + +import torch + from torchao._models.sam2.sam2_image_predictor import SAM2ImagePredictor # Tools used to avoid compilation cold start and dynamo cache lookups @@ -13,6 +15,7 @@ TASK_TYPES = ["amg", "sps", "mps"] + # NOTE: We have to declare a separate class, because torch.export demands it. 
# We build this explicitly for the sole purpose of exporting _predict_masks # We made sure _predict_masks is fullgraph=True compileable so it can be exported @@ -20,12 +23,14 @@ # any expected recompilations. We'll add in guards to prevent unexpectedly # large inputs. class SAM2ImagePredictor_predict_masks(torch.nn.Module): - def __init__(self, - predictor: Optional[SAM2ImagePredictor], - batch_size=1, - points_per_batch=1024, - aoti_compiled_model=None, - furious=False): + def __init__( + self, + predictor: Optional[SAM2ImagePredictor], + batch_size=1, + points_per_batch=1024, + aoti_compiled_model=None, + furious=False, + ): super().__init__() self.predictor = predictor self.batch_size = batch_size @@ -33,16 +38,18 @@ def __init__(self, self.aoti_compiled_model = aoti_compiled_model self.furious = furious - def forward(self, - high_res_feats, - image_embed, - image_pe, - point_coords, - point_labels, - boxes: Optional[torch.Tensor] = None, - mask_input: Optional[torch.Tensor] = None, - multimask_output: bool = True, - img_idx: int = -1): + def forward( + self, + high_res_feats, + image_embed, + image_pe, + point_coords, + point_labels, + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + img_idx: int = -1, + ): assert high_res_feats[0].size() == (self.batch_size, 32, 256, 256) assert high_res_feats[1].size() == (self.batch_size, 64, 128, 128) if self.furious: @@ -69,33 +76,39 @@ def forward(self, assert img_idx == -1 if self.predictor is None: assert self.aoti_compiled_model is not None - return self.aoti_compiled_model(high_res_feats, - image_embed, - image_pe, - point_coords, - point_labels, - boxes=boxes, - mask_input=mask_input, - multimask_output=multimask_output, - img_idx=img_idx) - return self.predictor._predict_masks(high_res_feats, - image_embed, - image_pe, - point_coords, - point_labels, - boxes=boxes, - mask_input=mask_input, - multimask_output=multimask_output, - img_idx=img_idx) - - -def aot_compile(model_directory, - name, - fn, - sample_args, - sample_kwargs=None, - options=None, - overwrite=False): + return self.aoti_compiled_model( + high_res_feats, + image_embed, + image_pe, + point_coords, + point_labels, + boxes=boxes, + mask_input=mask_input, + multimask_output=multimask_output, + img_idx=img_idx, + ) + return self.predictor._predict_masks( + high_res_feats, + image_embed, + image_pe, + point_coords, + point_labels, + boxes=boxes, + mask_input=mask_input, + multimask_output=multimask_output, + img_idx=img_idx, + ) + + +def aot_compile( + model_directory, + name, + fn, + sample_args, + sample_kwargs=None, + options=None, + overwrite=False, +): path = Path(model_directory) / Path(f"{name}.pt2") if path.exists() and not overwrite: raise ValueError(f"{path} already exists and overwrite is {overwrite}") @@ -107,6 +120,7 @@ def aot_compile(model_directory, } from torch.export import export_for_inference + exported = export_for_inference(fn, sample_args, sample_kwargs) output_path = torch._inductor.aoti_compile_and_package( exported, @@ -121,7 +135,6 @@ def aot_load(path): class FunctionModel(torch.nn.Module): - def __init__(self, module, fn_name): super().__init__() self.module = module @@ -131,69 +144,123 @@ def forward(self, *args): return getattr(self.module, self.fn_name)(*args) -def export_model(mask_generator, - model_directory, - task_type, - furious=False, - batch_size=1, - points_per_batch=None, - overwrite=False): +def export_model( + mask_generator, + model_directory, + task_type, + 
furious=False, + batch_size=1, + points_per_batch=None, + overwrite=False, +): if furious: set_furious(mask_generator) assert task_type in TASK_TYPES, f"Expected {task_type} to be one of {TASK_TYPES}" if task_type in ["sps", "amg"]: - assert points_per_batch is not None, f"Specify points_per_batch for task {task_type}" + assert ( + points_per_batch is not None + ), f"Specify points_per_batch for task {task_type}" if task_type == "sps": - assert points_per_batch == 1, f"Expected points_per_batch set to 1 for {task_type} but got {points_per_batch}" - + assert ( + points_per_batch == 1 + ), f"Expected points_per_batch set to 1 for {task_type} but got {points_per_batch}" example_input = torch.empty(batch_size, 3, 1024, 1024) example_input = example_input.to(mask_generator.predictor._image_dtype) example_input = (example_input.to(mask_generator.predictor.device),) - aot_compile(model_directory, - "sam2_image_encoder", - mask_generator.predictor.model.image_encoder, - example_input, - overwrite=overwrite) + aot_compile( + model_directory, + "sam2_image_encoder", + mask_generator.predictor.model.image_encoder, + example_input, + overwrite=overwrite, + ) print(f"{task_type} cannot export _predict_masks") return if task_type in ["sps"]: - example_input_high_res_feats = [torch.randn(batch_size, 32, 256, 256, dtype=mask_generator.predictor._image_dtype, device=mask_generator.predictor.device), - torch.randn(batch_size, 64, 128, 128, dtype=mask_generator.predictor._image_dtype, device=mask_generator.predictor.device)] - example_input_image_embed = torch.randn(batch_size, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) - example_input_image_pe = torch.randn(batch_size, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) - example_input_point_coords = torch.randn(points_per_batch, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) - example_input_point_labels = torch.ones(points_per_batch, 1, dtype=torch.int32, device=mask_generator.predictor.device) - example_input_args = (example_input_high_res_feats, - example_input_image_embed, - example_input_image_pe, - example_input_point_coords, - example_input_point_labels) - - example_input_kwargs = {"boxes": None, - "mask_input": None, - "multimask_output": True, - "img_idx": -1, - } - - sam2_image_predict_masks = SAM2ImagePredictor_predict_masks(mask_generator.predictor, - batch_size=batch_size, - points_per_batch=points_per_batch, - furious=furious) - aot_compile(model_directory, - "sam2_image_predict_masks", - sam2_image_predict_masks, - example_input_args, - sample_kwargs=example_input_kwargs, - overwrite=overwrite) + example_input_high_res_feats = [ + torch.randn( + batch_size, + 32, + 256, + 256, + dtype=mask_generator.predictor._image_dtype, + device=mask_generator.predictor.device, + ), + torch.randn( + batch_size, + 64, + 128, + 128, + dtype=mask_generator.predictor._image_dtype, + device=mask_generator.predictor.device, + ), + ] + example_input_image_embed = torch.randn( + batch_size, + 256, + 64, + 64, + dtype=torch.float32, + device=mask_generator.predictor.device, + ) + example_input_image_pe = torch.randn( + batch_size, + 256, + 64, + 64, + dtype=torch.float32, + device=mask_generator.predictor.device, + ) + example_input_point_coords = torch.randn( + points_per_batch, + 1, + 2, + dtype=torch.float32, + device=mask_generator.predictor.device, + ) + example_input_point_labels = torch.ones( + points_per_batch, + 1, + dtype=torch.int32, + device=mask_generator.predictor.device, + ) + 
example_input_args = ( + example_input_high_res_feats, + example_input_image_embed, + example_input_image_pe, + example_input_point_coords, + example_input_point_labels, + ) + + example_input_kwargs = { + "boxes": None, + "mask_input": None, + "multimask_output": True, + "img_idx": -1, + } + + sam2_image_predict_masks = SAM2ImagePredictor_predict_masks( + mask_generator.predictor, + batch_size=batch_size, + points_per_batch=points_per_batch, + furious=furious, + ) + aot_compile( + model_directory, + "sam2_image_predict_masks", + sam2_image_predict_masks, + example_input_args, + sample_kwargs=example_input_kwargs, + overwrite=overwrite, + ) else: print(f"{task_type} cannot export _predict_masks") class LoadedModel(torch.nn.Module): - def __init__(self, aoti_compiled_model): super().__init__() self.aoti_compiled_model = aoti_compiled_model @@ -203,7 +270,6 @@ def forward(self, *args, **kwargs): class LoadedDecoder(torch.nn.Module): - def __init__(self, aoti_compiled_model, other): super().__init__() self.aoti_compiled_model = aoti_compiled_model @@ -216,12 +282,14 @@ def get_dense_pe(self, *args, **kwargs) -> torch.Tensor: return self.other.get_dense_pe(*args, **kwargs) -def load_exported_model(mask_generator, - model_directory, - task_type, - furious=False, - batch_size=1, - points_per_batch=1024): +def load_exported_model( + mask_generator, + model_directory, + task_type, + furious=False, + batch_size=1, + points_per_batch=1024, +): if furious: set_furious(mask_generator) assert task_type in TASK_TYPES, f"Expected {task_type} to be one of {TASK_TYPES}" @@ -239,7 +307,7 @@ def load_exported_model(mask_generator, if task_type in ["amg", "mps"]: return mask_generator - path = Path(model_directory) / Path(f"sam2_image_predict_masks.pt2") + path = Path(model_directory) / Path("sam2_image_predict_masks.pt2") assert path.exists(), f"Expected {path} to exist" print(f"Start load from {path}") pkg = torch._inductor.aoti_load_package(str(path)) @@ -249,17 +317,24 @@ def load_exported_model(mask_generator, assert points_per_batch == 1 if task_type == "mps": assert points_per_batch is None - pkg_m = SAM2ImagePredictor_predict_masks(None, - batch_size=batch_size, - points_per_batch=points_per_batch, - aoti_compiled_model=pkg, - furious=furious) + pkg_m = SAM2ImagePredictor_predict_masks( + None, + batch_size=batch_size, + points_per_batch=points_per_batch, + aoti_compiled_model=pkg, + furious=furious, + ) mask_generator.predictor._predict_masks = pkg_m.forward print(f"End load image encoder and predict masks. Took {time.time() - t0}s") -def set_fast(mask_generator, task_type, loaded_exported_model=False, allow_recompiles=True): +def set_fast( + mask_generator, task_type, loaded_exported_model=False, allow_recompiles=True +): + if task_type == "": + task_type = "amg" + assert task_type in TASK_TYPES, f"Expected {task_type} to be one of {TASK_TYPES}" if not loaded_exported_model: # TODO: Using CUDA graphs can cause numerical differences? @@ -296,22 +371,35 @@ def set_fast(mask_generator, task_type, loaded_exported_model=False, allow_recom ) import torchao + if allow_recompiles: # A bunch of extra compiles at module level # Note that this can cause recompilations! 
# We might want to guard on that - torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0 = torch.compile(fullgraph=True, dynamic=True)(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0) - torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1 = torch.compile(fullgraph=True, dynamic=True)(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1) - mask_generator.calculate_stability_score = torch.compile(fullgraph=True, dynamic=True)(mask_generator.calculate_stability_score) - mask_generator.batched_mask_to_box = torch.compile(fullgraph=True, dynamic=True)(mask_generator.batched_mask_to_box) + torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0 = torch.compile( + fullgraph=True, dynamic=True + )(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0) + torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1 = torch.compile( + fullgraph=True, dynamic=True + )(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1) + mask_generator.calculate_stability_score = torch.compile( + fullgraph=True, dynamic=True + )(mask_generator.calculate_stability_score) + mask_generator.batched_mask_to_box = torch.compile( + fullgraph=True, dynamic=True + )(mask_generator.batched_mask_to_box) def set_furious(mask_generator): - mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16) + mask_generator.predictor.model.image_encoder = ( + mask_generator.predictor.model.image_encoder.to(torch.float16) + ) # NOTE: Not baseline feature mask_generator.predictor._image_dtype = torch.float16 mask_generator.predictor._transforms_device = mask_generator.predictor.device - torch.set_float32_matmul_precision('high') - mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16) + torch.set_float32_matmul_precision("high") + mask_generator.predictor.model.sam_mask_decoder = ( + mask_generator.predictor.model.sam_mask_decoder.to(torch.float16) + ) # NOTE: Not baseline feature mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16 diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index 2fa216176e..8d0f4c8133 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -1,45 +1,37 @@ -import itertools -import requests -import uvicorn -import fire -import tempfile +import asyncio +import json import logging -import sys import time -import json +from contextlib import asynccontextmanager +from io import BytesIO from pathlib import Path -from typing import List, Optional +import cv2 +import fire +import matplotlib.pyplot as plt +import numpy as np +import requests import torch import torch._dynamo.config import torch._inductor.config -from fastapi.responses import Response +import uvicorn +from compile_export_utils import ( + export_model, + load_exported_model, + set_fast, + set_furious, +) from fastapi import FastAPI, File, UploadFile -from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware -from io import BytesIO -import shutil -from pydantic import BaseModel -import cv2 - -import matplotlib.pyplot as plt -import numpy as np +from fastapi.responses import StreamingResponse +from torch._inductor import config as inductorconfig -import asyncio -from contextlib import asynccontextmanager -import contextlib from torchao._models.utils import ( get_arch_name, - write_json_result_ossci, write_json_result_local, + write_json_result_ossci, ) -from compile_export_utils import 
set_fast -from compile_export_utils import set_furious -from compile_export_utils import load_exported_model -from compile_export_utils import export_model - -from torch._inductor import config as inductorconfig inductorconfig.triton.unique_kernel_names = True inductorconfig.coordinate_descent_tuning = True inductorconfig.coordinate_descent_check_all_directions = True @@ -48,12 +40,13 @@ # torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True + def download_file(url, download_dir): # Create the directory if it doesn't exist download_dir = Path(download_dir) download_dir.mkdir(parents=True, exist_ok=True) # Extract the file name from the URL - file_name = url.split('/')[-1] + file_name = url.split("/")[-1] # Define the full path for the downloaded file file_path = download_dir / file_name # Download the file @@ -61,110 +54,123 @@ def download_file(url, download_dir): response.raise_for_status() # Raise an error for bad responses # Write the file to the specified directory print(f"Downloading '{file_name}' to '{download_dir}'") - with open(file_path, 'wb') as file: + with open(file_path, "wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) print(f"Downloaded '{file_name}' to '{download_dir}'") + def example_shapes(): - return [(848, 480, 3), - (720, 1280, 3), - (848, 480, 3), - (1280, 720, 3), - (480, 848, 3), - (1080, 1920, 3), - (1280, 720, 3), - (1280, 720, 3), - (720, 1280, 3), - (848, 480, 3), - (480, 848, 3), - (864, 480, 3), - (1920, 1080, 3), - (1920, 1080, 3), - (1280, 720, 3), - (1232, 672, 3), - (848, 480, 3), - (848, 480, 3), - (1920, 1080, 3), - (1080, 1920, 3), - (480, 848, 3), - (848, 480, 3), - (480, 848, 3), - (480, 848, 3), - (720, 1280, 3), - (720, 1280, 3), - (900, 720, 3), - (848, 480, 3), - (864, 480, 3), - (360, 640, 3), - (360, 640, 3), - (864, 480, 3)] + return [ + (848, 480, 3), + (720, 1280, 3), + (848, 480, 3), + (1280, 720, 3), + (480, 848, 3), + (1080, 1920, 3), + (1280, 720, 3), + (1280, 720, 3), + (720, 1280, 3), + (848, 480, 3), + (480, 848, 3), + (864, 480, 3), + (1920, 1080, 3), + (1920, 1080, 3), + (1280, 720, 3), + (1232, 672, 3), + (848, 480, 3), + (848, 480, 3), + (1920, 1080, 3), + (1080, 1920, 3), + (480, 848, 3), + (848, 480, 3), + (480, 848, 3), + (480, 848, 3), + (720, 1280, 3), + (720, 1280, 3), + (900, 720, 3), + (848, 480, 3), + (864, 480, 3), + (360, 640, 3), + (360, 640, 3), + (864, 480, 3), + ] def example_shapes_2(): - return [(1080, 1920, 3), - (1920, 1080, 3), - (1920, 1080, 3), - (1080, 1920, 3), - (848, 480, 3), - (864, 480, 3), - (720, 1280, 3), - (864, 480, 3), - (848, 480, 3), - (848, 480, 3), - (848, 480, 3), - (848, 480, 3), - (720, 1280, 3), - (864, 480, 3), - (480, 848, 3), - (1280, 720, 3), - (720, 1280, 3), - (1080, 1920, 3), - (1080, 1920, 3), - (1280, 720, 3), - (1080, 1920, 3), - (1080, 1920, 3), - (720, 1280, 3), - (720, 1280, 3), - (1280, 720, 3), - (360, 640, 3), - (864, 480, 3), - (1920, 1080, 3), - (1080, 1920, 3), - (1920, 1080, 3), - (1920, 1080, 3), - (1080, 1920, 3)] + return [ + (1080, 1920, 3), + (1920, 1080, 3), + (1920, 1080, 3), + (1080, 1920, 3), + (848, 480, 3), + (864, 480, 3), + (720, 1280, 3), + (864, 480, 3), + (848, 480, 3), + (848, 480, 3), + (848, 480, 3), + (848, 480, 3), + (720, 1280, 3), + (864, 480, 3), + (480, 848, 3), + (1280, 720, 3), + (720, 1280, 3), + (1080, 1920, 3), + (1080, 1920, 3), + (1280, 720, 3), + (1080, 1920, 3), + (1080, 1920, 3), + (720, 1280, 3), + (720, 1280, 3), + (1280, 720, 3), 
+ (360, 640, 3), + (864, 480, 3), + (1920, 1080, 3), + (1080, 1920, 3), + (1920, 1080, 3), + (1920, 1080, 3), + (1080, 1920, 3), + ] + # torch.set_float32_matmul_precision('high') + def iou(mask1, mask2): assert mask1.dim() == 2 assert mask2.dim() == 2 intersection = torch.logical_and(mask1, mask2) union = torch.logical_or(mask1, mask2) - return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2))) + return intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)) def show_anns(anns, rle_to_mask, sort_by_area=True, seed=None): if len(anns) == 0: return if sort_by_area: - sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True) else: sorted_anns = anns ax = plt.gca() ax.set_autoscale_on(False) for ann in sorted_anns: - ann['segmentation'] = rle_to_mask(ann['segmentation']) - - img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) - img[:,:,3] = 0 + ann["segmentation"] = rle_to_mask(ann["segmentation"]) + + img = np.ones( + ( + sorted_anns[0]["segmentation"].shape[0], + sorted_anns[0]["segmentation"].shape[1], + 4, + ) + ) + img[:, :, 3] = 0 np.random.seed(seed) ms = [] for ann in sorted_anns: - m = ann['segmentation'] + m = ann["segmentation"] ms.append(torch.as_tensor(m)) color_mask = np.concatenate([np.random.random(3), [0.35]]) img[m] = color_mask @@ -174,9 +180,12 @@ def show_anns(anns, rle_to_mask, sort_by_area=True, seed=None): def profiler_runner(path, fn, *args, **kwargs): with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA], - record_shapes=True) as prof: + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + ) as prof: result = fn(*args, **kwargs) prof.export_chrome_trace(path) return result @@ -186,16 +195,15 @@ def memory_runner(path, fn, *args, **kwargs): print("Start memory recording") torch.cuda.synchronize() torch.cuda.memory._record_memory_history( - True, - trace_alloc_max_entries=100000, - trace_alloc_record_context=True + True, trace_alloc_max_entries=100000, trace_alloc_record_context=True ) result = fn(*args, **kwargs) torch.cuda.synchronize() snapshot = torch.cuda.memory._snapshot() print("Finish memory recording") import pickle - with open(path, 'wb') as f: + + with open(path, "wb") as f: pickle.dump(snapshot, f) # Use to convert pickle file into html # python torch/cuda/_memory_viz.py trace_plot .pickle -o .html @@ -218,9 +226,11 @@ def file_bytes_to_image_tensor(file_bytes, output_format="numpy"): if output_format == "numpy": return example_image if output_format not in ["torch"]: - raise ValueError("Expected output_format to be numpy or torch," - f" but got {output_format}") + raise ValueError( + "Expected output_format to be numpy or torch," f" but got {output_format}" + ) from torchvision.transforms import ToTensor + return ToTensor()(example_image) @@ -257,7 +267,6 @@ async def batch_worker(mask_generator, batch_size, *, pad_batch=True, furious=Fa batch.append(await request_queue.get()) if batch: - padded_batch = batch if pad_batch: padded_batch = batch + ([batch[-1]] * (batch_size - len(batch))) @@ -274,7 +283,9 @@ async def lifespan(app: FastAPI): mask_generator = app.state.mask_generator batch_size = app.state.batch_size furious = app.state.furious - task = asyncio.create_task(batch_worker(mask_generator, batch_size, furious=furious)) + task = asyncio.create_task( + batch_worker(mask_generator, 
batch_size, furious=furious) + ) yield # Shutdown logic (if needed) task.cancel() @@ -290,7 +301,7 @@ def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10): t = time.time() for _ in range(runs): func(inp, mask_generator) - avg_time_per_run = (time.time() - t)/runs + avg_time_per_run = (time.time() - t) / runs print(f"Benchmark took {avg_time_per_run}s per iteration.") max_memory_allocated_bytes, max_memory_allocated_percentage = max_memory_allocated() return avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage @@ -299,9 +310,13 @@ def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10): def max_memory_allocated_stats(): max_memory_allocated_bytes = torch.cuda.max_memory_allocated() _, total_memory = torch.cuda.mem_get_info() - max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) - return {"bytes": max_memory_allocated_bytes, - "percentage": max_memory_allocated_percentage} + max_memory_allocated_percentage = int( + 100 * (max_memory_allocated_bytes / total_memory) + ) + return { + "bytes": max_memory_allocated_bytes, + "percentage": max_memory_allocated_percentage, + } def max_memory_allocated(): @@ -309,11 +324,15 @@ def max_memory_allocated(): mib = stats["bytes"] >> 20 print(f"max_memory_allocated_bytes: {mib}MiB") print(f"max_memory_allocated_percentage: {stats['percentage']}%") + return mib, stats["percentage"] def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False): from compare_rle_lists import compare_masks - miou, equal_count = compare_masks(masks, ref_masks, order_by_area=order_by_area, verbose=verbose) + + miou, equal_count = compare_masks( + masks, ref_masks, order_by_area=order_by_area, verbose=verbose + ) if equal_count == len(masks): print("Masks exactly match reference.") else: @@ -321,26 +340,26 @@ def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False): MODEL_TYPES_TO_CONFIG = { - "tiny": "sam2.1_hiera_t.yaml", - "small": "sam2.1_hiera_s.yaml", - "plus": "sam2.1_hiera_b+.yaml", - "large": "sam2.1_hiera_l.yaml", - } + "tiny": "sam2.1_hiera_t.yaml", + "small": "sam2.1_hiera_s.yaml", + "plus": "sam2.1_hiera_b+.yaml", + "large": "sam2.1_hiera_l.yaml", +} MODEL_TYPES_TO_MODEL = { - "tiny": "sam2.1_hiera_tiny.pt", - "small": "sam2.1_hiera_small.pt", - "plus": "sam2.1_hiera_base_plus.pt", - "large": "sam2.1_hiera_large.pt", - } + "tiny": "sam2.1_hiera_tiny.pt", + "small": "sam2.1_hiera_small.pt", + "plus": "sam2.1_hiera_base_plus.pt", + "large": "sam2.1_hiera_large.pt", +} MODEL_TYPES_TO_URL = { - "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", - "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", - "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", - "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", - } + "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", +} def main_docstring(): @@ -353,107 +372,150 @@ def main_docstring(): def model_type_to_paths(checkpoint_path, model_type): if model_type not in MODEL_TYPES_TO_CONFIG.keys(): - raise ValueError(f"Expected model_type to be one of {', 
'.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}") + raise ValueError( + f"Expected model_type to be one of {', '.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}" + ) sam2_checkpoint = Path(checkpoint_path) / Path(MODEL_TYPES_TO_MODEL[model_type]) if not sam2_checkpoint.exists(): - print(f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading.") + print( + f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading." + ) download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path) assert sam2_checkpoint.exists(), "Can't find downloaded file. Please open an issue." model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}" return sam2_checkpoint, model_cfg -def set_autoquant(mask_generator): +def set_autoquant(mask_generator, autoquant_type, min_sqnr): import torchao from torchao import autoquant + # NOTE: Not baseline feature - mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + if autoquant_type == "autoquant": + mask_generator.predictor.model.image_encoder = autoquant( + mask_generator.predictor.model.image_encoder, min_sqnr=min_sqnr + ) + elif autoquant_type == "autoquant-fp": + mask_generator.predictor.model.image_encoder = autoquant( + mask_generator.predictor.model.image_encoder, + qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + min_sqnr=min_sqnr, + ) + elif autoquant_type == "autoquant-all": + mask_generator.predictor.model.image_encoder = autoquant( + mask_generator.predictor.model.image_encoder, + qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST, + min_sqnr=min_sqnr, + ) + else: + raise ValueError(f"Unexpected autoquant type: {autoquant_type}") + mask_generator.predictor._transforms_device = mask_generator.predictor.device - torch.set_float32_matmul_precision('high') + torch.set_float32_matmul_precision("high") # NOTE: this fails when we run # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) -def main(checkpoint_path, - model_type, - baseline=False, - fast=False, - furious=False, - use_autoquant=False, - unittest=False, - benchmark=False, - profile=None, - memory_profile=None, - verbose=False, - points_per_batch=64, - port=5000, - host="127.0.0.1", - dry=False, - batch_size=1, - load_fast="", - save_fast="", - output_json_path=None, - output_json_local=False): +def main( + checkpoint_path, + model_type, + baseline=False, + fast=False, + furious=False, + autoquant_type=None, + min_sqnr=None, + unittest=False, + benchmark=False, + profile=None, + memory_profile=None, + verbose=False, + points_per_batch=64, + port=5000, + host="127.0.0.1", + dry=False, + batch_size=1, + load_fast="", + save_fast="", + output_json_path=None, + output_json_local=False, +): if verbose: - logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S') + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) logging.info(f"Running with fast set to {fast} and furious set to {furious}") logging.info(f"Running with 
port {port} and host {host}") logging.info(f"Running with batch size {batch_size}") if baseline: assert batch_size == 1, "baseline only supports batch size 1." - logging.info(f"Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2") - from sam2.build_sam import build_sam2 + logging.info( + "Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2" + ) from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + from sam2.build_sam import build_sam2 from sam2.utils.amg import rle_to_mask else: + from torchao._models.sam2.automatic_mask_generator import ( + SAM2AutomaticMaskGenerator, + ) from torchao._models.sam2.build_sam import build_sam2 - from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from torchao._models.sam2.utils.amg import rle_to_mask device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}") - sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) + sam2 = build_sam2( + model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False + ) logging.info(f"Using {points_per_batch} points_per_batch") - mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") + mask_generator = SAM2AutomaticMaskGenerator( + sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle" + ) if load_fast != "": - load_exported_model(mask_generator, load_fast, "amg", furious, batch_size, points_per_batch) + load_exported_model( + mask_generator, load_fast, "amg", furious, batch_size, points_per_batch + ) if furious: set_furious(mask_generator) if save_fast != "": - assert load_fast == "", "Can't save compiled models while loading them with --load-fast." + assert ( + load_fast == "" + ), "Can't save compiled models while loading them with --load-fast." assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." print(f"Saving compiled models under directory {save_fast}") - export_model(mask_generator, - save_fast, - "amg", - furious=furious, - batch_size=batch_size, - points_per_batch=points_per_batch) + export_model( + mask_generator, + save_fast, + "amg", + furious=furious, + batch_size=batch_size, + points_per_batch=points_per_batch, + ) if fast: assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." 
set_fast(mask_generator, load_fast) # since autoquant is replicating what furious mode is doing, don't use these two together - if use_autoquant: + if autoquant_type is not None: assert not furious, "use autoquant can't be used together with furious" - set_autoquant(mask_generator) + set_autoquant(mask_generator, autoquant_type, min_sqnr) - with open('dog.jpg', 'rb') as f: + with open("dog.jpg", "rb") as f: output_format = "numpy" if baseline else "torch" - image_tensor = file_bytes_to_image_tensor(bytearray(f.read()), - output_format=output_format) + image_tensor = file_bytes_to_image_tensor( + bytearray(f.read()), output_format=output_format + ) # from torchvision import io as tio # img_bytes_tensor = tio.read_file('dog.jpg') @@ -469,7 +531,9 @@ def main(checkpoint_path, else: # TODO: Transpose dog image to create diversity in input image shape logging.info(f"batch size {batch_size} unittest") - all_masks = image_tensors_to_masks([image_tensor] * batch_size, mask_generator) + all_masks = image_tensors_to_masks( + [image_tensor] * batch_size, mask_generator + ) all_masks = [masks_to_rle_dict(masks) for masks in all_masks] ref_masks = json.loads(open("dog_rle.json").read()) for masks in all_masks: @@ -480,17 +544,24 @@ def main(checkpoint_path, if batch_size == 1: result = benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator) else: - result = benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator) + result = benchmark_fn( + image_tensors_to_masks, [image_tensor] * batch_size, mask_generator + ) for i, shapes in enumerate([example_shapes(), example_shapes_2()]): print(f"batch size {batch_size} example shapes {i} benchmark") - random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes] + random_images = [ + np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes + ] if batch_size > len(random_images): num_repeat = (len(random_images) + batch_size) // batch_size random_images = num_repeat * random_images if batch_size == 1: - [benchmark_fn(image_tensor_to_masks, r, mask_generator) for r in random_images] + [ + benchmark_fn(image_tensor_to_masks, r, mask_generator) + for r in random_images + ] else: random_images = random_images[:batch_size] print("len(random_images): ", len(random_images)) @@ -500,12 +571,44 @@ def main(checkpoint_path, headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"] name = "sam2-" + model_type arch = get_arch_name() - dtype = "autoquant" if use_autoquant else "noquant" - avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage = result - memory_result = [name, dtype, device, arch, "memory(MiB)", max_memory_allocated_bytes, None] - memory_percent_result = [name, dtype, device, arch, "memory(%)", max_memory_allocated_percentage, None] - performance_result = [name, dtype, device, arch, "time_s(avg)", avg_time_per_run, None] - write_json_result = write_json_result_local if output_json_local else write_json_result_ossci + dtype = autoquant_type or "noquant" + ( + avg_time_per_run, + max_memory_allocated_bytes, + max_memory_allocated_percentage, + ) = result + memory_result = [ + name, + dtype, + device, + arch, + "memory(MiB)", + max_memory_allocated_bytes, + None, + ] + memory_percent_result = [ + name, + dtype, + device, + arch, + "memory(%)", + max_memory_allocated_percentage, + None, + ] + performance_result = [ + name, + dtype, + device, + arch, + "time_s(avg)", + avg_time_per_run, + None, + ] + write_json_result = ( + write_json_result_local + 
if output_json_local + else write_json_result_ossci + ) write_json_result(output_json_path, headers, memory_result) write_json_result(output_json_path, headers, memory_percent_result) write_json_result(output_json_path, headers, performance_result) @@ -513,16 +616,30 @@ def main(checkpoint_path, if profile is not None: print(f"Saving profile under {profile}") if batch_size == 1: - profiler_runner(profile, image_tensor_to_masks, image_tensor, mask_generator) + profiler_runner( + profile, image_tensor_to_masks, image_tensor, mask_generator + ) else: - profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator) + profiler_runner( + profile, + image_tensors_to_masks, + [image_tensor] * batch_size, + mask_generator, + ) if memory_profile is not None: print(f"Saving memory profile under {memory_profile}") if batch_size == 1: - memory_runner(memory_profile, image_tensor_to_masks, image_tensor, mask_generator) + memory_runner( + memory_profile, image_tensor_to_masks, image_tensor, mask_generator + ) else: - memory_runner(memory_profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator) + memory_runner( + memory_profile, + image_tensors_to_masks, + [image_tensor] * batch_size, + mask_generator, + ) if dry: return @@ -555,24 +672,25 @@ async def upload_image(image: UploadFile = File(...)): response_future = asyncio.Future() await request_queue.put((image_tensor, response_future)) masks = await response_future - + # Create figure and ensure it's closed after generating response fig = plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100) plt.imshow(image_tensor) show_anns(masks, rle_to_mask) - plt.axis('off') + plt.axis("off") plt.tight_layout() - + buf = BytesIO() - plt.savefig(buf, format='png') + plt.savefig(buf, format="png") buf.seek(0) plt.close(fig) # Close figure after we're done with it - + return StreamingResponse(buf, media_type="image/png") # uvicorn.run(app, host=host, port=port, log_level="info") uvicorn.run(app, host=host, port=port) + main.__doc__ = main_docstring() if __name__ == "__main__": fire.Fire(main) diff --git a/ruff.toml b/ruff.toml index 3df55e96ea..a044e95b34 100644 --- a/ruff.toml +++ b/ruff.toml @@ -5,6 +5,7 @@ include = [ "torchao/**/*.py", "test/**/*.py", + "benchmarks/**/*.py", ] exclude = [ diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index a111e3e7c8..3e466a5d1c 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -266,6 +266,7 @@ def main( "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" ), quantization: Optional[str] = None, + min_sqnr: Optional[float] = None, sparsity: Optional[str] = None, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, @@ -706,6 +707,7 @@ def ffn_or_attn_only(mod, fqn): manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, + min_sqnr=min_sqnr, ) elif "autoquant-float8" == quantization: model = autoquant( @@ -713,6 +715,7 @@ def ffn_or_attn_only(mod, fqn): manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, + min_sqnr=min_sqnr, ) elif "autoquant-fp" == quantization: model = autoquant( @@ -720,6 +723,7 @@ def ffn_or_attn_only(mod, fqn): manual=True, qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, + min_sqnr=min_sqnr, ) elif "autoquant-sparse" == quantization: model = autoquant( @@ -727,6 +731,7 @@ 
def ffn_or_attn_only(mod, fqn): manual=True, qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST, example_input=inputs, + min_sqnr=min_sqnr, ) elif "autoquant-gemlite-int4" == quantization: import os @@ -742,6 +747,7 @@ def ffn_or_attn_only(mod, fqn): manual=True, qtensor_class_list=torchao.quantization.GEMLITE_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, + min_sqnr=min_sqnr, ) elif "autoquant-all" == quantization: try: @@ -761,9 +767,12 @@ def ffn_or_attn_only(mod, fqn): manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST, example_input=inputs, + min_sqnr=min_sqnr, ) else: - model = autoquant(model, manual=True, example_input=inputs) + model = autoquant( + model, manual=True, example_input=inputs, min_sqnr=min_sqnr + ) generate( model, @@ -1015,12 +1024,30 @@ def callback(x): f.close() if output_json_path: - headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"] + headers = [ + "name", + "dtype", + "min_sqnr", + "device", + "arch", + "metric", + "actual", + "target", + ] name = checkpoint_path.parent.name arch = get_arch_name() dtype = quantization or "noquant" - memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None] - performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None] + memory_result = [name, dtype, min_sqnr, device, arch, "mem/s", bandwidth, None] + performance_result = [ + name, + dtype, + min_sqnr, + device, + arch, + "tok/s", + tokpersec, + None, + ] write_json_result = ( write_json_result_local if output_json_local else write_json_result_ossci ) @@ -1073,6 +1100,14 @@ def callback(x): + "embed-int8wo, marlin_qqq, gemlite---, int8adq-int4w-symm" ), ) + parser.add_argument( + "--min_sqnr", + type=float, + default=None, + help=( + "min sqnr for quantizing v.s. 
     parser.add_argument(
         "-s",
         "--sparsity",
@@ -1148,6 +1183,7 @@ def callback(x):
         args.temperature,
         args.checkpoint_path,
         args.quantization,
+        args.min_sqnr,
         args.sparsity,
         args.kv_cache_quantization,
         args.cache_size,
diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py
index 4dd976bee6..1a082d47b0 100644
--- a/torchao/_models/sam/eval_combo.py
+++ b/torchao/_models/sam/eval_combo.py
@@ -284,6 +284,7 @@ def run(
     use_compile="False",
     use_compile_decoder=False,
     compress=None,
+    min_sqnr=None,
     num_workers=0,
     use_rel_pos=True,
     pad_input_image_batch=True,
@@ -457,6 +458,7 @@ def mlp_only(mod, name):
                 example_input=example_input,
                 manual=True,
                 qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
+                min_sqnr=min_sqnr,
             )
         elif "autoquant-float8" == compress:
             autoquant(
@@ -464,6 +466,7 @@ def mlp_only(mod, name):
                 example_input=example_input,
                 manual=True,
                 qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST,
+                min_sqnr=min_sqnr,
             )
         elif "autoquant-sparse" == compress:
             autoquant(
@@ -471,6 +474,7 @@ def mlp_only(mod, name):
                 example_input=example_input,
                 manual=True,
                 qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
+                min_sqnr=min_sqnr,
             )
         elif "autoquant-all" == compress:
             autoquant(
@@ -478,10 +482,14 @@ def mlp_only(mod, name):
                 example_input=example_input,
                 manual=True,
                 qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
+                min_sqnr=min_sqnr,
             )
         else:
             autoquant(
-                predictor.model.image_encoder, example_input=example_input, manual=True
+                predictor.model.image_encoder,
+                example_input=example_input,
+                manual=True,
+                min_sqnr=min_sqnr,
             )
         predictor.model.image_encoder(example_input)
         predictor.model.image_encoder.finalize_autoquant()
@@ -630,20 +638,39 @@ def mlp_only(mod, name):
                 f.write(vals + "\n")

     if output_json_path:
-        headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
+        headers = [
+            "name",
+            "dtype",
+            "min_sqnr",
+            "device",
+            "arch",
+            "metric",
+            "actual",
+            "target",
+        ]
         name = sam_model_type
         arch = get_arch_name()
         dtype = compress or "noquant"
         memory_result = [
             name,
             dtype,
+            min_sqnr,
             device,
             arch,
             "memory(MiB)",
             max_memory_allocated_bytes,
             None,
         ]
-        performance_result = [name, dtype, device, arch, "img_s(avg)", img_s, None]
+        performance_result = [
+            name,
+            dtype,
+            min_sqnr,
+            device,
+            arch,
+            "img_s(avg)",
+            img_s,
+            None,
+        ]
         write_json_result = (
             write_json_result_local if output_json_local else write_json_result_ossci
         )
diff --git a/torchao/_models/utils.py b/torchao/_models/utils.py
index a5fa1576d6..bdbb439037 100644
--- a/torchao/_models/utils.py
+++ b/torchao/_models/utils.py
@@ -30,6 +30,7 @@ def write_json_result_ossci(output_json_path, headers, row):
         "name": "TorchAO benchmark",
         "mode": "inference",
         "dtype": mapping_headers["dtype"],
+        "min_sqnr": mapping_headers["min_sqnr"],
         "extra_info": {
             "device": mapping_headers["device"],
             "arch": mapping_headers["arch"],
@@ -38,7 +39,7 @@ def write_json_result_ossci(output_json_path, headers, row):
         "model": {
             "name": mapping_headers["name"],
             "type": "model",
-            "origins": ["torchao/_models"],
+            "origins": ["torchao"],
         },
         "metric": {
             "name": mapping_headers["metric"],
@@ -79,6 +80,7 @@ def write_json_result_local(output_json_path, headers, row):
         "name": "TorchAO benchmark",
         "mode": "inference",
         "dtype": mapping_headers["dtype"],
+        "min_sqnr": mapping_headers["min_sqnr"],
         "extra_info": {
             "device": mapping_headers["device"],
             "arch": mapping_headers["arch"],
@@ -87,7 +89,7 @@ def write_json_result_local(output_json_path, headers, row):
         "model": {
             "name": mapping_headers["name"],
             "type": "model",
-            "origins": ["torchao/_models"],
+            "origins": ["torchao"],
         },
         "metric": {
             "name": mapping_headers["metric"],
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index e7806f07ad..4b6a1d1d71 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -415,14 +415,14 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
     @classmethod
     def from_float(cls, weight):
+        if weight.dim() != 2:
+            return weight
+
         # TODO test if this is valid
         # in_features = weight.shape[1]
         # int8 dynamic quantization only has benefit when in_feature > 16
         # if in_features <= 16:
-        # return weight
-
-        if weight.dim() != 2:
-            return weight
+        # return weight

         # avoid circular dep
         from torchao.dtypes import to_affine_quantized_intx
@@ -522,7 +522,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
 class AQInt8DynamicallyQuantizedSemiSparseLinearWeight(
     AQInt8DynamicallyQuantizedLinearWeight
 ):
-    layout: Layout = SemiSparseLayout()
+    aq_layout: Layout = SemiSparseLayout()

     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
@@ -627,6 +627,15 @@ def from_float(cls, weight):
         if weight.shape[-1] % group_size != 0:
             return weight

+        if (
+            isinstance(_layout, TensorCoreTiledLayout)
+            and weight.dtype != torch.bfloat16
+        ):
+            return weight
+
+        if isinstance(_layout, MarlinSparseLayout) and weight.dtype != torch.float16:
+            return weight
+
         use_hqq = True
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
@@ -690,6 +699,9 @@ class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQM
     @classmethod
     def from_float(cls, weight):
+        if weight.dtype != torch.float16:
+            return weight
+
         from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

         bit_width = 4
@@ -1022,7 +1034,9 @@ def get_weight_block_size(x):

 DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [
     AQDefaultLinearWeight,
+    # TODO: investigate why there are some problems when adding sparse kernels for sam2
     AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
+    # some errors when calling cusparse kernels when running on sam2
     AQInt8DynamicallyQuantizedSemiSparseLinearWeight,
 ]
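
Usage note: the sketch below shows how the min_sqnr threshold wired through the diffs above could be exercised directly through autoquant(). It is a minimal illustration under stated assumptions, not part of the patch; the toy model, shapes, threshold value, and import path are assumed for the example.

    import torch

    from torchao.quantization import autoquant  # assumed public import path

    # Toy stand-in model and input; assumes a CUDA device is available.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
    example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

    # min_sqnr is the minimum signal-to-quantization-noise ratio a quantized kernel
    # must preserve relative to the unquantized layer for autoquant to select it;
    # layers that fall below the threshold are left unquantized (40 is an arbitrary
    # example value, not a recommendation from the patch).
    model = autoquant(model, example_input=example_input, manual=True, min_sqnr=40.0)

    model(example_input)        # run once so autoquant can benchmark and pick kernels
    model.finalize_autoquant()  # commit the selected configurations, as in the scripts above

The same knob is exposed on the benchmark drivers touched above, e.g. the --min_sqnr flag of torchao/_models/llama/generate.py and the min_sqnr argument of run() in torchao/_models/sam/eval_combo.py.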