Skip to content

Commit

Permalink
[Feature] Refactor Python code (#8)
Browse files Browse the repository at this point in the history
* Auto-format by https://ultralytics.com/actions

* Refactor Python code

* Auto-format by https://ultralytics.com/actions

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
glenn-jocher and UltralyticsAssistant authored Jun 9, 2024
1 parent f7774a4 commit 2350a65
Show file tree
Hide file tree
Showing 19 changed files with 850 additions and 1,053 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ OnnxSlim can help you slim your onnx model, with less operators, but same accura
- 🚀 Rank 1st in the [AICAS 2024 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-head
- 🚀 OnnxSlim is merged into [ultralytics](https://github.com/ultralytics/ultralytics) ❤️❤️❤️


# Installation

## Using Prebuilt
Expand Down
46 changes: 22 additions & 24 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,28 @@ def slim(

from onnxslim.core.slim import (
convert_data_format,
freeze,
input_shape_modification,
optimize,
output_modification,
shape_infer,
optimize,
freeze,
)

from onnxslim.utils import (
check_onnx,
check_point,
check_result,
dump_model_info_to_disk,
init_logging,
model_save_as_external_data,
onnxruntime_inference,
print_model_info_as_table,
model_save_as_external_data,
summarize_model,
init_logging,
check_result,
check_onnx,
check_point,
save,
summarize_model,
)

init_logging(verbose)

MAX_ITER = 10 if not os.getenv("ONNXSLIM_MAX_ITER") else int(os.getenv("ONNXSLIM_MAX_ITER"))
MAX_ITER = int(os.getenv("ONNXSLIM_MAX_ITER")) if os.getenv("ONNXSLIM_MAX_ITER") else 10

if isinstance(model, str):
model_name = Path(model).name
Expand Down Expand Up @@ -146,20 +145,19 @@ def slim(

if not output_model:
return model
else:
slimmed_info = summarize_model(model)
save(model, output_model, model_check)
if slimmed_info["model_size"] >= onnx.checker.MAXIMUM_PROTOBUF:
model_size = model.ByteSize()
slimmed_info["model_size"] = [model_size, slimmed_info["model_size"]]
end_time = time.time()
elapsed_time = end_time - start_time

print_model_info_as_table(
model_name,
[float_info, slimmed_info],
elapsed_time,
)
slimmed_info = summarize_model(model)
save(model, output_model, model_check)
if slimmed_info["model_size"] >= onnx.checker.MAXIMUM_PROTOBUF:
model_size = model.ByteSize()
slimmed_info["model_size"] = [model_size, slimmed_info["model_size"]]
end_time = time.time()
elapsed_time = end_time - start_time

print_model_info_as_table(
model_name,
[float_info, slimmed_info],
elapsed_time,
)


def main():
Expand Down
Loading

0 comments on commit 2350a65

Please sign in to comment.