Skip to content

Commit

Permalink
allow func is None, and allow pause
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao committed Nov 28, 2023
1 parent b6ec652 commit 62e8da1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
29 changes: 20 additions & 9 deletions depyf/explain/enable_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,12 @@ def enable_bytecode_hook(hook):


@contextlib.contextmanager
def prepare_debug(func, dump_src_dir, clean_wild_fx_code=True):
def prepare_debug(func, dump_src_dir, clean_wild_fx_code=True, pause=True):
"""
clean_wild_fx_code: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
Args:
func: the function to debug, can be `None`. If it is `None`, do not dump all the source code in `full_code.py`.
clean_wild_fx_code: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
pause: whether to pause the program after the source code is dumped.
"""
import os
import torch
Expand Down Expand Up @@ -115,18 +118,26 @@ def prepare_debug(func, dump_src_dir, clean_wild_fx_code=True):
try:
yield
finally:
from depyf.explain import dump_src, _extract_artifacts, _collect_compiled_subgraphs
full_src = dump_src(func)
filename = os.path.join(dump_src_dir, f"full_code.py")
with open(filename, "w") as f:
f.write(full_src)
if clean_wild_fx_code:
for file in os.listdir(dump_src_dir):
if file.split(
os.path.sep)[-1].startswith("fx_graph_code"):
os.remove(os.path.join(dump_src_dir, file))
input(
f"Please check the full source code in {filename}, and set breakpoints for functions in {dump_src_dir} according to the hash value. Then press enter to continue.")

if func is None:
if pause:
input(
f"Please set breakpoints in {dump_src_dir}. Then press enter to continue.")
else:
from depyf.explain import dump_src
full_src = dump_src(func)
filename = os.path.join(dump_src_dir, f"full_code.py")
with open(filename, "w") as f:
f.write(full_src)

if pause:
input(
f"Please check the full source code in {filename}, and set breakpoints for functions in {dump_src_dir} according to the function name. Then press enter to continue.")


@contextlib.contextmanager
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pytorch/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def call():
target(*input1)

description = f"{usage_type}_{compile_type}_{backend}"
description += "_dynamic_shape" if dynamic_shape else "_without_dynamic_shape"
description += "_with_dynamic_shape" if dynamic_shape else "_without_dynamic_shape"
description += "_with_grad" if requires_grad else "_without_grad"

if usage_type == "dump":
Expand Down

0 comments on commit 62e8da1

Please sign in to comment.