Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 8, 2025
1 parent 0818011 commit f11f214
Showing 1 changed file with 58 additions and 57 deletions.
115 changes: 58 additions & 57 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,63 +105,64 @@ def main():
print("Validation successful!\n")

print("Generating candidate tuning specs...")
test_tuner = TestTuner()
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.simple_num_model_candidates,
)
with TunerContext() as tuner_context:
test_tuner = TestTuner(tuner_context)
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers, test_tuner
)
print(f"Stored candidate tuning specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling dispatch candidates...")
compiled_candidates = libtuner.compile(
args, path_config, candidates, candidate_trackers, test_tuner
)
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled dispatch candidates...")
top_candidates = libtuner.benchmark(
args,
path_config,
compiled_candidates,
candidate_trackers,
test_tuner,
args.simple_num_dispatch_candidates,
)
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={args.simple_hip_target}",
]
compiled_model_candidates = libtuner.compile(
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.simple_model_file,
)
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
compiled_model_candidates,
candidate_trackers,
test_tuner,
args.simple_num_model_candidates,
)

print(f"Top model candidates: {top_model_candidates}")

Expand Down

0 comments on commit f11f214

Please sign in to comment.