Commit

Lint fixes
jainapurva committed Jan 8, 2025
2 parents e51a1d7 + cc8e80b commit 58299ce
Showing 14 changed files with 790 additions and 155 deletions.
3 changes: 3 additions & 0 deletions examples/sam2_amg_server/compile_export_utils.py
@@ -332,6 +332,9 @@ def load_exported_model(
def set_fast(
mask_generator, task_type, loaded_exported_model=False, allow_recompiles=True
):
if task_type == "":
task_type = "amg"

assert task_type in TASK_TYPES, f"Expected {task_type} to be one of {TASK_TYPES}"
if not loaded_exported_model:
# TODO: Using CUDA graphs can cause numerical differences?
43 changes: 31 additions & 12 deletions examples/sam2_amg_server/server.py
@@ -324,6 +324,7 @@ def max_memory_allocated():
mib = stats["bytes"] >> 20
print(f"max_memory_allocated_bytes: {mib}MiB")
print(f"max_memory_allocated_percentage: {stats['percentage']}%")
return mib, stats["percentage"]


def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
@@ -385,16 +386,30 @@ def model_type_to_paths(checkpoint_path, model_type):
return sam2_checkpoint, model_cfg


def set_autoquant(mask_generator):
def set_autoquant(mask_generator, autoquant_type, min_sqnr):
import torchao
from torchao import autoquant

# NOTE: Not baseline feature
mask_generator.predictor.model.image_encoder = autoquant(
mask_generator.predictor.model.image_encoder,
qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
min_sqnr=40,
)
if autoquant_type == "autoquant":
mask_generator.predictor.model.image_encoder = autoquant(
mask_generator.predictor.model.image_encoder, min_sqnr=min_sqnr
)
elif autoquant_type == "autoquant-fp":
mask_generator.predictor.model.image_encoder = autoquant(
mask_generator.predictor.model.image_encoder,
qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
min_sqnr=min_sqnr,
)
elif autoquant_type == "autoquant-all":
mask_generator.predictor.model.image_encoder = autoquant(
mask_generator.predictor.model.image_encoder,
qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
min_sqnr=min_sqnr,
)
else:
raise ValueError(f"Unexpected autoquant type: {autoquant_type}")

mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision("high")
# NOTE: this fails when we run
@@ -409,7 +424,8 @@ def main(
baseline=False,
fast=False,
furious=False,
use_autoquant=False,
autoquant_type=None,
min_sqnr=None,
unittest=False,
benchmark=False,
profile=None,
@@ -491,9 +507,9 @@ def main(
set_fast(mask_generator, load_fast)

# since autoquant is replicating what furious mode is doing, don't use these two together
if use_autoquant:
if autoquant_type is not None:
assert not furious, "use autoquant can't be used together with furious"
set_autoquant(mask_generator)
set_autoquant(mask_generator, autoquant_type, min_sqnr)

with open("dog.jpg", "rb") as f:
output_format = "numpy" if baseline else "torch"
@@ -555,7 +571,7 @@ def main(
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = "sam2-" + model_type
arch = get_arch_name()
dtype = "autoquant" if use_autoquant else "noquant"
dtype = autoquant_type or "noquant"
(
avg_time_per_run,
max_memory_allocated_bytes,
@@ -657,18 +673,21 @@ async def upload_image(image: UploadFile = File(...)):
await request_queue.put((image_tensor, response_future))
masks = await response_future

# Save an example
plt.figure(
# Create figure and ensure it's closed after generating response
fig = plt.figure(
figsize=(image_tensor.shape[1] / 100.0, image_tensor.shape[0] / 100.0),
dpi=100,
)
plt.imshow(image_tensor)
show_anns(masks, rle_to_mask)
plt.axis("off")
plt.tight_layout()

buf = BytesIO()
plt.savefig(buf, format="png")
buf.seek(0)
plt.close(fig) # Close figure after we're done with it

return StreamingResponse(buf, media_type="image/png")

# uvicorn.run(app, host=host, port=port, log_level="info")
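
The new set_autoquant dispatch above only changes which qtensor_class_list is handed to torchao's autoquant; the wrapped module and the min_sqnr threshold are shared across all three branches. A minimal sketch of the same call pattern, assuming a CUDA device and a toy module standing in for mask_generator.predictor.model.image_encoder (the autoquant_type="autoquant-fp" branch with min_sqnr=40):

import torch
import torchao
from torchao import autoquant

# Toy stand-in for the SAM2 image encoder that server.py wraps.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()

# Restrict candidates to the float autoquant classes and only keep a swap
# when the quantized-vs-float SQNR for that layer clears the threshold.
model = autoquant(
    model,
    qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    min_sqnr=40,
)

# The first forward pass drives shape capture and kernel/layout selection.
model(torch.randn(8, 64, device="cuda"))
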
44 changes: 40 additions & 4 deletions torchao/_models/llama/generate.py
@@ -266,6 +266,7 @@ def main(
"checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
),
quantization: Optional[str] = None,
min_sqnr: Optional[float] = None,
sparsity: Optional[str] = None,
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
@@ -706,27 +707,31 @@ def ffn_or_attn_only(mod, fqn):
manual=True,
qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
example_input=inputs,
min_sqnr=min_sqnr,
)
elif "autoquant-float8" == quantization:
model = autoquant(
model,
manual=True,
qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST,
example_input=inputs,
min_sqnr=min_sqnr,
)
elif "autoquant-fp" == quantization:
model = autoquant(
model,
manual=True,
qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
example_input=inputs,
min_sqnr=min_sqnr,
)
elif "autoquant-sparse" == quantization:
model = autoquant(
model,
manual=True,
qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
example_input=inputs,
min_sqnr=min_sqnr,
)
elif "autoquant-gemlite-int4" == quantization:
import os
@@ -742,6 +747,7 @@ def ffn_or_attn_only(mod, fqn):
manual=True,
qtensor_class_list=torchao.quantization.GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
example_input=inputs,
min_sqnr=min_sqnr,
)
elif "autoquant-all" == quantization:
try:
@@ -761,9 +767,12 @@ def ffn_or_attn_only(mod, fqn):
manual=True,
qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
example_input=inputs,
min_sqnr=min_sqnr,
)
else:
model = autoquant(model, manual=True, example_input=inputs)
model = autoquant(
model, manual=True, example_input=inputs, min_sqnr=min_sqnr
)

generate(
model,
@@ -1015,12 +1024,30 @@ def callback(x):
f.close()

if output_json_path:
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
headers = [
"name",
"dtype",
"min_sqnr",
"device",
"arch",
"metric",
"actual",
"target",
]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or "noquant"
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
memory_result = [name, dtype, min_sqnr, device, arch, "mem/s", bandwidth, None]
performance_result = [
name,
dtype,
min_sqnr,
device,
arch,
"tok/s",
tokpersec,
None,
]
write_json_result = (
write_json_result_local if output_json_local else write_json_result_ossci
)
@@ -1073,6 +1100,14 @@ def callback(x):
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
),
)
parser.add_argument(
"--min_sqnr",
type=float,
default=None,
help=(
"min sqnr for quantizing v.s. not quantizing a layer, used in autoquant options",
),
)
parser.add_argument(
"-s",
"--sparsity",
@@ -1148,6 +1183,7 @@ def callback(x):
args.temperature,
args.checkpoint_path,
args.quantization,
args.min_sqnr,
args.sparsity,
args.kv_cache_quantization,
args.cache_size,
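
generate.py now threads a single --min_sqnr value into every autoquant variant. SQNR here is the usual signal-to-quantization-noise ratio, in dB, between a layer's float output and its quantized output; candidates that fall below the threshold are rejected and the layer is left unquantized. A rough illustration of the quantity being thresholded (torchao computes this internally; the helper below is only a sketch):

import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> float:
    # 10 * log10(signal power / quantization-noise power); higher is better.
    noise = reference - quantized
    return float(10 * torch.log10(reference.pow(2).mean() / noise.pow(2).mean()))

ref = torch.randn(1024)
approx = ref + 0.01 * torch.randn(1024)  # simulate roughly 1% RMS quantization error
print(sqnr_db(ref, approx))  # about 40 dB, so --min_sqnr 40 sits right at this error level
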
33 changes: 30 additions & 3 deletions torchao/_models/sam/eval_combo.py
@@ -284,6 +284,7 @@ def run(
use_compile="False",
use_compile_decoder=False,
compress=None,
min_sqnr=None,
num_workers=0,
use_rel_pos=True,
pad_input_image_batch=True,
@@ -457,31 +458,38 @@ def mlp_only(mod, name):
example_input=example_input,
manual=True,
qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
min_sqnr=min_sqnr,
)
elif "autoquant-float8" == compress:
autoquant(
predictor.model.image_encoder,
example_input=example_input,
manual=True,
qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST,
min_sqnr=min_sqnr,
)
elif "autoquant-sparse" == compress:
autoquant(
predictor.model.image_encoder,
example_input=example_input,
manual=True,
qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
min_sqnr=min_sqnr,
)
elif "autoquant-all" == compress:
autoquant(
predictor.model.image_encoder,
example_input=example_input,
manual=True,
qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
min_sqnr=min_sqnr,
)
else:
autoquant(
predictor.model.image_encoder, example_input=example_input, manual=True
predictor.model.image_encoder,
example_input=example_input,
manual=True,
min_sqnr=min_sqnr,
)
predictor.model.image_encoder(example_input)
predictor.model.image_encoder.finalize_autoquant()
@@ -630,20 +638,39 @@ def mlp_only(mod, name):
f.write(vals + "\n")

if output_json_path:
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
headers = [
"name",
"dtype",
"min_sqnr",
"device",
"arch",
"metric",
"actual",
"target",
]
name = sam_model_type
arch = get_arch_name()
dtype = compress or "noquant"
memory_result = [
name,
dtype,
min_sqnr,
device,
arch,
"memory(MiB)",
max_memory_allocated_bytes,
None,
]
performance_result = [name, dtype, device, arch, "img_s(avg)", img_s, None]
performance_result = [
name,
dtype,
min_sqnr,
device,
arch,
"img_s(avg)",
img_s,
None,
]
write_json_result = (
write_json_result_local if output_json_local else write_json_result_ossci
)
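
eval_combo.py uses autoquant in manual mode: the image encoder is wrapped, the example input is run once so candidate layouts can be profiled on real shapes, and finalize_autoquant() commits the selections, all now gated by the same min_sqnr threshold. A condensed sketch of that flow with a placeholder module instead of predictor.model.image_encoder (CUDA device assumed):

import torch
import torchao
from torchao import autoquant

# Placeholder encoder and batch; the real code wraps the SAM image encoder.
encoder = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU()).cuda()
example_input = torch.randn(4, 128, device="cuda")

autoquant(
    encoder,
    example_input=example_input,
    manual=True,
    qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
    min_sqnr=40,
)
encoder(example_input)        # profile candidates on the example shapes
encoder.finalize_autoquant()  # apply the winning layout for each layer
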
6 changes: 4 additions & 2 deletions torchao/_models/utils.py
@@ -30,6 +30,7 @@ def write_json_result_ossci(output_json_path, headers, row):
"name": "TorchAO benchmark",
"mode": "inference",
"dtype": mapping_headers["dtype"],
"min_sqnr": mapping_headers["min_sqnr"],
"extra_info": {
"device": mapping_headers["device"],
"arch": mapping_headers["arch"],
@@ -38,7 +39,7 @@ def write_json_result_ossci(output_json_path, headers, row):
"model": {
"name": mapping_headers["name"],
"type": "model",
"origins": ["torchao/_models"],
"origins": ["torchao"],
},
"metric": {
"name": mapping_headers["metric"],
@@ -79,6 +80,7 @@ def write_json_result_local(output_json_path, headers, row):
"name": "TorchAO benchmark",
"mode": "inference",
"dtype": mapping_headers["dtype"],
"min_sqnr": mapping_headers["min_sqnr"],
"extra_info": {
"device": mapping_headers["device"],
"arch": mapping_headers["arch"],
@@ -87,7 +89,7 @@ def write_json_result_local(output_json_path, headers, row):
"model": {
"name": mapping_headers["name"],
"type": "model",
"origins": ["torchao/_models"],
"origins": ["torchao"],
},
"metric": {
"name": mapping_headers["metric"],
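
The writers in utils.py read the new field through mapping_headers, which is why generate.py and eval_combo.py had to add min_sqnr at the same position in both their headers and their rows. A small sketch of how the column reaches the benchmark record; the dict(zip(...)) pairing is an assumption (only the mapping_headers lookups appear in the diff) and the values are made up:

headers = ["name", "dtype", "min_sqnr", "device", "arch", "metric", "actual", "target"]
row = ["sam2-large", "autoquant-fp", 40.0, "cuda", "NVIDIA-A100", "img_s(avg)", 22.5, None]

# Assumed pairing of headers with the row, mirroring the lookups in utils.py.
mapping_headers = dict(zip(headers, row))
benchmark_info = {
    "name": "TorchAO benchmark",
    "mode": "inference",
    "dtype": mapping_headers["dtype"],        # "autoquant-fp"
    "min_sqnr": mapping_headers["min_sqnr"],  # 40.0
}
print(benchmark_info)
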
13 changes: 8 additions & 5 deletions torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -401,10 +401,13 @@ __global__ void Marlin_24(
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
if constexpr (group_blocks != -1) {
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
@@ -429,7 +432,7 @@ __global__ void Marlin_24(
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
if constexpr (group_blocks != -1) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));