fixing some corner cases not caught in unit tests
Summary:

shapes need in_features or out_features divisible by 128 or they will not work with gemlite; layers that fail this are now skipped with a warning
need fp32 accumulation for group_size None on int4, since fp16 partial sums drift when the reduction spans the whole row
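A minimal, self-contained illustration of the accumulation point (this is not gemlite's kernel, just the fp16-vs-fp32 accumulator behavior that motivates the fix):

import torch

vals = torch.full((4096,), 0.1, dtype=torch.float16)

acc16 = torch.tensor(0.0, dtype=torch.float16)
for v in vals:
    acc16 = acc16 + v  # fp16 running sum: stalls at 256.0 once 0.1 rounds away against it
acc32 = vals.float().sum()  # fp32 accumulation: ~409.5, i.e. 4096 * float16(0.1)

print(acc16.item(), acc32.item())  # ~256.0 vs ~409.5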

Test Plan:

python test_integration.py -k "test_gemlite" (new test for non-divisible shapes)

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None  --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None  --write_result benchmark_results.txt

(previously these gave nonsense responses)
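(Judging by the matching test call gemlite_uintx_weight_only(None, 4, 32), the gemlite-32-4-None flag appears to encode packing bitwidth, weight bit width, and group size, so both commands exercise the group_size=None int4 path that needed the fp32 accumulation fix.)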

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Dec 18, 2024
1 parent 33d57af commit 1797c75
Showing 3 changed files with 81 additions and 25 deletions.
8 changes: 8 additions & 0 deletions test/integration/test_integration.py
@@ -962,6 +962,14 @@ def test_gemlite_layout(self, device, dtype):
test_shape=test_shape,
test_dtype=dtype,
)
# test that shapes not divisible by 128 don't cause errors
self._test_lin_weight_subclass_api_impl(
lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
device,
15,
test_shape=[1, 1025, 513],
test_dtype=dtype,
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
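For context, a minimal usage sketch of the API the new test exercises; the import path and the argument order (group_size, bit_width, packing_bitwidth) are assumptions read off the test call above rather than confirmed documentation:

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only  # import path assumed

# 1025 and 513 are deliberately not divisible by 128; after this commit such a
# layer is skipped with a warning instead of producing a broken kernel call
lin = torch.nn.Linear(1025, 513, bias=False, dtype=torch.float16, device="cuda")
quantize_(lin, gemlite_uintx_weight_only(None, 4, 32))  # group_size=None, 4-bit weights, 32-bit packing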
85 changes: 60 additions & 25 deletions torchao/_models/llama/generate.py
@@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import platform
import sys
@@ -18,11 +19,6 @@
import torchao
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
from torchao._models.utils import (
get_arch_name,
write_json_result_ossci,
write_json_result_local,
)

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

@@ -41,6 +37,14 @@ def elapsed_time(self, other_event):
return abs(other_event.event_time - self.event_time) * 1000


def get_arch_name() -> str:
if torch.cuda.is_available():
return torch.cuda.get_device_name()
else:
# This returns x86_64 or arm64 (for aarch64)
return platform.machine()


def device_timer(device):
if "cuda" in device:
return torch.cuda.Event(enable_timing=True)
@@ -61,6 +65,39 @@ def device_sync(device):
print(f"device={device} is not yet suppported")


def write_json_result(output_json_path, headers, row):
"""
Write the result into JSON format, so that it can be uploaded to the benchmark database
to be displayed on OSS dashboard. The JSON format is defined at
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
"""
mapping_headers = {headers[i]: v for i, v in enumerate(row)}
record = {
"benchmark": {
"name": "TorchAO benchmark",
"mode": "inference",
"dtype": mapping_headers["dtype"],
"extra_info": {
"device": mapping_headers["device"],
"arch": mapping_headers["arch"],
},
},
"model": {
"name": mapping_headers["name"],
"type": "model",
"origins": ["pytorch"],
},
"metric": {
"name": mapping_headers["metric"],
"benchmark_values": [mapping_headers["actual"]],
"target_value": mapping_headers["target"],
},
}

with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
print(json.dumps(record), file=f)
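
For reference, a sketch of how this helper gets fed at the end of main(); the row values are hypothetical, but the header layout matches the one constructed there:

headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
# hypothetical row: model name, quantization string, device, arch, metric, measured value, target
row = ["Meta-Llama-3-8B", "gemlite-32-4-None", "cuda", "NVIDIA A100", "tok/s", 104.2, None]
write_json_result("benchmark_results.txt", headers, row)  # appends one record to benchmark_results.json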


default_device = (
"cuda"
if torch.cuda.is_available()
@@ -135,7 +172,7 @@ def decode_n_tokens(
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
# in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
new_tokens.append(next_token.clone())
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob)
cur_token = next_token
@@ -279,7 +316,6 @@ def main(
precision=torch.bfloat16,
write_result: Optional[Path] = None,
output_json_path: Optional[Path] = None,
output_json_local: bool = False,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""

@@ -692,10 +728,20 @@ def ffn_or_attn_only(mod, fqn):
example_input=inputs,
)
if "autoquant-all" == quantization:
all_qtensor_classes = (
torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
+ torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+ torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
)
if torchao.utils.is_sm_89():
# these are the fp8-related subclasses; should be renamed
all_qtensor_classes += (
torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
)
model = autoquant(
model,
manual=True,
qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
qtensor_class_list=all_qtensor_classes,
example_input=inputs,
)
else:
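
Composing all_qtensor_classes explicitly, instead of using the previous all-in-one torchao.quantization.ALL_AUTOQUANT_CLASS_LIST, means the fp8-oriented OTHER_AUTOQUANT_CLASS_LIST entries are only tried on hardware reporting SM 8.9 or newer, where fp8 is supported.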
@@ -713,10 +759,6 @@ def ffn_or_attn_only(mod, fqn):

# do autoquantization
model.finalize_autoquant()
elif "codebook" in quantization:
from torchao.prototype.quantization.codebook import codebook_weight_only
model.to(device)
quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))

else:
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -936,14 +978,13 @@ def callback(x):
f.write(result_txt)
f.close()

headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or str(precision)
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
if output_json_path:
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or "noquant"
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
write_json_result = write_json_result_local if output_json_local else write_json_result_ossci
write_json_result(output_json_path, headers, memory_result)
write_json_result(output_json_path, headers, performance_result)

@@ -1045,11 +1086,6 @@ def callback(x):
default=None,
help="Path where to write the json result for dashboard",
)
parser.add_argument(
"--output_json_local",
action="store_true",
help="Whether to output json result for local machine or for CI machine, local option will fill in some dummy fields",
)

args = parser.parse_args()
print(args)
@@ -1077,5 +1113,4 @@ def callback(x):
args.precision,
args.write_result,
args.output_json_path,
args.output_json_local,
)
13 changes: 13 additions & 0 deletions torchao/dtypes/uintx/gemlite_layout.py
@@ -15,6 +15,7 @@
from torchao.dtypes.utils import Layout, is_device
from torchao.quantization.quant_primitives import quantize_affine
from torchao.utils import fill_defaults
import warnings

aten = torch.ops.aten

@@ -76,6 +77,14 @@ def apply_gemlite_quant(
out_features, in_features = weight.shape
group_size = in_features if group_size is None else group_size

if in_features % 128 != 0 and out_features % 128 != 0:
warnings.simplefilter("once", UserWarning)
warnings.warn(
"Gemlite only works for layers with in_features or out_features divisible by 128, "
+ "some layers have been skipped", UserWarning
)
return weight
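
Note that the guard uses and: a layer is skipped only when neither in_features nor out_features is divisible by 128, matching the warning's wording that at least one of the two must be divisible. The simplefilter("once", UserWarning) call keeps a model with many such layers from emitting the same warning repeatedly.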

quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)

layout = GemlitePackedLayout(
@@ -173,6 +182,10 @@ def from_plain(
exhaustive=False,
use_cuda_graph=False,
)
if _layout.group_size == None and _layout.bit_width == 4:
from gemlite.core import GEMLITE_ACC_DTYPE
from gemlite.dtypes import DType
GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32
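
This is the second fix from the summary: overriding gemlite's accumulation-dtype table so fp16 inputs accumulate in fp32 on the 4-bit, group_size=None path. With group_size=None the reduction spans the entire in_features dimension, exactly the regime where fp16 partial sums drift (see the illustration in the summary).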

out_features, in_features = int_data.shape
input_dtype, output_dtype = DType.FP16, DType.FP16