From a676f6ed6de3468a2c4bd9a9a9057ee6c3e2282c Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 11 Dec 2024 08:55:28 -0500 Subject: [PATCH 1/3] (chore) Remove unused dotfiles (#1445) --- .buckconfig | 0 .style.yapf | 13 ------------- 2 files changed, 13 deletions(-) delete mode 100644 .buckconfig delete mode 100644 .style.yapf diff --git a/.buckconfig b/.buckconfig deleted file mode 100644 index e69de29bb..000000000 diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index e60ac16e5..000000000 --- a/.style.yapf +++ /dev/null @@ -1,13 +0,0 @@ -[style] -ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = True -ALLOW_MULTILINE_LAMBDAS = True -BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = True -COLUMN_LIMIT = 88 -COALESCE_BRACKETS = True -SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True -SPACES_BEFORE_COMMENT = 2 -SPLIT_BEFORE_BITWISE_OPERATOR = True -SPLIT_BEFORE_FIRST_ARGUMENT = True -SPLIT_BEFORE_LOGICAL_OPERATOR = True -SPLIT_BEFORE_NAMED_ASSIGNS = True -SPLIT_COMPLEX_COMPREHENSION = True From 032beb953e7ddbf001d244489850a79c281e8217 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 17 Dec 2024 11:30:27 -0500 Subject: [PATCH 2/3] Remove triton.ops, copy necessary bits here (#1413) * Remove triton.ops, copy necessary bits here Summary: Triton upstream removed `triton.ops` and moved it to a semi-unmaintained `kernels` repo. Since all that's needed here is the perf model, just add those bits here. * Add source reference/license comment --- .../triton/int8_matmul_mixed_dequantize.py | 3 +- .../triton/int8_matmul_rowwise_dequantize.py | 3 +- bitsandbytes/triton/matmul_perf_model.py | 211 ++++++++++++++++++ 3 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 bitsandbytes/triton/matmul_perf_model.py diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index 583371d91..5fcb927d4 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -9,7 +9,8 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + from .matmul_perf_model import early_config_prune, estimate_matmul_time # This is a matmul kernel based on triton.ops.matmul # It is modified to support rowwise quantized input and global quantized weight diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index e3d192ded..05e30a4c9 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -9,7 +9,8 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + from .matmul_perf_model import early_config_prune, estimate_matmul_time # This is a matmul kernel based on triton.ops.matmul # It is modified to support rowwise quantized input and columnwise quantized weight diff --git a/bitsandbytes/triton/matmul_perf_model.py b/bitsandbytes/triton/matmul_perf_model.py new file mode 100644 index 000000000..199ceb1a3 --- /dev/null +++ b/bitsandbytes/triton/matmul_perf_model.py @@ -0,0 +1,211 @@ +# Adapted from https://github.com/triton-lang/kernels/blob/eeeebdd8be7d13629de22d600621e6234057eed3/kernels/matmul_perf_model.py 
+# https://github.com/triton-lang/kernels is licensed under the MIT License. + +import functools +import heapq + +import torch + +from triton import cdiv +from triton.runtime import driver +from triton.testing import ( + get_dram_gbps, + get_max_simd_tflops, + get_max_tensorcore_tflops, + nvsmi, +) + + +@functools.lru_cache +def get_clock_rate_in_khz(): + try: + return nvsmi(["clocks.max.sm"])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = ( + min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, + num_stages, # + A, + B, + C, # + M, + N, + K, # + BLOCK_M, + BLOCK_N, + BLOCK_K, + SPLIT_K, # + debug=False, + **kwargs, # +): + """return estimated running time in ms + = max(compute, loading) + store""" + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + num_cta_m = cdiv(M, BLOCK_M) + num_cta_n = cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
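+ # A tiles are re-read by every column of CTAs and B tiles by every row of CTAs,
+ # so the re-read traffic below is scaled by (num_cta_n - 1) and (num_cta_m - 1).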
+ # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print( + f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " + f"loading time: {load_ms}ms, store time: {store_ms}ms, " + f"Activate CTAs: {active_cta_ratio*100}%" + ) + return total_time_ms + + +def early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args["A"].element_size() + dtype = named_args["A"].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + ) + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + kw["SPLIT_K"], + config.num_warps, + config.num_stages, + ) + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? 
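+ # Rough heuristic: pick the pipeline depth whose per-stage MMA work best
+ # hides the assumed ~300-cycle global-load (LDGSTS) latency.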
+ optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, + v, + key=lambda x: ( + 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 + else x[1] - optimal_num_stages + ), + ) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs From 5b015890be1880f910219c9ea966054b8f98affc Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 17 Dec 2024 21:30:02 +0000 Subject: [PATCH 3/3] chore: migrate config files to `pyproject.toml` (#1373) * chore: move configs to pyproject.toml * fix: drop file from CI workflow * feat: reorder pytest markers * chore: retain comments * chore(build): migrate build data to pyproject Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Aarni Koskela * chore: move configs to pyproject.toml * Apply suggestions from code review Co-authored-by: Aarni Koskela * bump ruff --------- Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Co-authored-by: Aarni Koskela --- .github/workflows/python-package.yml | 1 - pyproject.toml | 89 +++++++++++++++++++++++++++- pytest.ini | 14 ----- requirements-ci.txt | 6 -- requirements-dev.txt | 9 --- setup.py | 36 +---------- 6 files changed, 89 insertions(+), 66 deletions(-) delete mode 100644 pytest.ini delete mode 100644 requirements-ci.txt delete mode 100644 requirements-dev.txt diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 9b166794f..5cd956574 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -14,7 +14,6 @@ on: - "requirements*.txt" - "setup.py" - "pyproject.toml" - - "pytest.ini" release: types: [published] workflow_dispatch: {} # Allow manual trigger diff --git a/pyproject.toml b/pyproject.toml index f0ac2e4b7..3573b10f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,94 @@ [build-system] -requires = [ "setuptools", "wheel" ] +requires = ["setuptools >= 63.0.0"] build-backend = "setuptools.build_meta" +[project] +name = "bitsandbytes" +dynamic = ["version"] +description = "k-bit optimizers and matrix multiplication routines." 
+authors = [{name="Tim Dettmers", email="dettmers@cs.washington.edu"}] +requires-python = ">=3.8" +readme = "README.md" +license = {file="LICENSE"} +keywords = [ + "gpu", + "optimizers", + "optimization", + "8-bit", + "quantization", + "compression" +] +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: MIT License", + "Environment :: GPU :: NVIDIA CUDA :: 11", + "Environment :: GPU :: NVIDIA CUDA :: 12", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Programming Language :: C++", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence" +] +dependencies = [ + "torch>=1.11,!=1.12.0", + "numpy>=1.17" +] + +[project.optional-dependencies] +benchmark = ["pandas", "matplotlib"] +docs = ["hf-doc-builder==0.5.0"] +dev = [ + "bitsandbytes[test]", + "build>=1.0.0,<2", + "ruff==0.6.9", + "pre-commit>=3.5.0,<4", + "wheel>=0.42,<1" +] +test = [ + "einops~=0.6.0", + "lion-pytorch==0.0.6", + "pytest~=7.4", + "scipy>=1.10.1,<2; python_version < '3.9'", + "scipy>=1.11.4,<2; python_version >= '3.9'", + "transformers>=4.30.1,<5" +] +triton = ["triton~=2.0.0; sys_platform=='linux' and platform_machine=='x86_64'"] + +[project.urls] +homepage = "https://github.com/TimDettmers/bitsandbytes" +changelog = "https://github.com/TimDettmers/bitsandbytes/blob/main/CHANGELOG.md" +docs = "https://huggingface.co/docs/bitsandbytes/main" +issues = "https://github.com/TimDettmers/bitsandbytes/issues" + +[tool.setuptools] +package-data = { "*" = ["libbitsandbytes*.*"] } + +[tool.setuptools.dynamic] +version = {attr = "bitsandbytes.__version__"} + +[tool.pytest.ini_options] +addopts = "-rP" +# ; --cov=bitsandbytes +# ; # contexts: record which test ran which line; can be seen in html coverage report +# ; --cov-context=test +# ; --cov-report html +log_cli = true +log_cli_level = "INFO" +log_file = "logs/pytest.log" +markers = [ + "benchmark: mark test as a benchmark", + "deprecated: mark test as covering a deprecated feature", + "slow: mark test as slow", +] + [tool.ruff] src = [ "bitsandbytes", diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 0090e0ca7..000000000 --- a/pytest.ini +++ /dev/null @@ -1,14 +0,0 @@ -[pytest] -addopts = -rP - ; --cov=bitsandbytes - ; # contexts: record which test ran which line; can be seen in html coverage report - ; --cov-context=test - ; --cov-report html - -log_cli = True -log_cli_level = INFO -log_file = logs/pytest.log -markers = - benchmark: mark test as benchmark - slow: mark test as slow - deprecated: mark test as covering a deprecated feature diff --git a/requirements-ci.txt b/requirements-ci.txt deleted file mode 100644 index 25ff67295..000000000 --- a/requirements-ci.txt +++ /dev/null @@ -1,6 +0,0 @@ -# Requirements used for GitHub actions -pytest==8.3.3 -einops==0.8.0 -lion-pytorch==0.2.2 -scipy==1.10.1; python_version < "3.9" -scipy==1.14.1; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index aedd07966..000000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Requirements used for local development -setuptools>=63 -pytest~=8.3.3 
-einops~=0.8.0 -wheel~=0.44.0 -lion-pytorch~=0.2.2 -scipy~=1.14.1 -pandas~=2.2.2 -matplotlib~=3.9.2 diff --git a/setup.py b/setup.py index 435641467..7d70bbc17 100644 --- a/setup.py +++ b/setup.py @@ -2,20 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import glob -import os - from setuptools import find_packages, setup from setuptools.dist import Distribution -libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.*")) -libs = [os.path.basename(p) for p in libs] -print("libs:", libs) - - -def read(fname): - return open(os.path.join(os.path.dirname(__file__), fname), encoding="utf8").read() - # Tested with wheel v0.29.0 class BinaryDistribution(Distribution): @@ -23,27 +12,4 @@ def has_ext_modules(self): return True -setup( - name="bitsandbytes", - version="0.45.1.dev0", - author="Tim Dettmers", - author_email="dettmers@cs.washington.edu", - description="k-bit optimizers and matrix multiplication routines.", - license="MIT", - keywords="gpu optimizers optimization 8-bit quantization compression", - url="https://github.com/bitsandbytes-foundation/bitsandbytes", - packages=find_packages(), - package_data={"": libs}, - install_requires=["torch", "numpy", "typing_extensions>=4.8.0"], - extras_require={ - "benchmark": ["pandas", "matplotlib"], - "test": ["scipy", "lion_pytorch"], - }, - long_description=read("README.md"), - long_description_content_type="text/markdown", - classifiers=[ - "Development Status :: 4 - Beta", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - ], - distclass=BinaryDistribution, -) +setup(version="0.45.1.dev0", packages=find_packages(), distclass=BinaryDistribution)
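With the pytest configuration now living in [tool.pytest.ini_options], the benchmark, deprecated, and slow markers are declared in pyproject.toml instead of the removed pytest.ini, and the deleted requirements-dev.txt / requirements-ci.txt are replaced by the dev and test dependency groups. A minimal sketch of how tests would opt into those markers (the test names here are illustrative, not taken from the repository):

    import pytest

    @pytest.mark.benchmark
    @pytest.mark.slow
    def test_matmul_benchmark():
        ...

    @pytest.mark.deprecated
    def test_legacy_quantization_api():
        ...

Marked tests can then be filtered on the command line, e.g. pytest -m "not slow" or pytest -m benchmark, and the new dependency groups can be installed with pip install -e ".[dev]" or pip install -e ".[test]" (an assumed workflow; the patch itself does not prescribe one).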