diff --git a/depyf/explain/enable_debugging.py b/depyf/explain/enable_debugging.py index 0118a7c7..b9237c65 100644 --- a/depyf/explain/enable_debugging.py +++ b/depyf/explain/enable_debugging.py @@ -75,9 +75,12 @@ def enable_bytecode_hook(hook): @contextlib.contextmanager -def prepare_debug(func, dump_src_dir, clean_wild_fx_code=True): +def prepare_debug(func, dump_src_dir, clean_wild_fx_code=True, pause=True): """ - clean_wild_fx_code: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally. + Args: + func: the function to debug, can be `None`. If it is `None`, do not dump all the source code in `full_code.py`. + clean_wild_fx_code: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally. + pause: whether to pause the program after the source code is dumped. """ import os import torch @@ -115,18 +118,26 @@ def prepare_debug(func, dump_src_dir, clean_wild_fx_code=True): try: yield finally: - from depyf.explain import dump_src, _extract_artifacts, _collect_compiled_subgraphs - full_src = dump_src(func) - filename = os.path.join(dump_src_dir, f"full_code.py") - with open(filename, "w") as f: - f.write(full_src) if clean_wild_fx_code: for file in os.listdir(dump_src_dir): if file.split( os.path.sep)[-1].startswith("fx_graph_code"): os.remove(os.path.join(dump_src_dir, file)) - input( - f"Please check the full source code in {filename}, and set breakpoints for functions in {dump_src_dir} according to the hash value. Then press enter to continue.") + + if func is None: + if pause: + input( + f"Please set breakpoints in {dump_src_dir}. Then press enter to continue.") + else: + from depyf.explain import dump_src + full_src = dump_src(func) + filename = os.path.join(dump_src_dir, f"full_code.py") + with open(filename, "w") as f: + f.write(full_src) + + if pause: + input( + f"Please check the full source code in {filename}, and set breakpoints for functions in {dump_src_dir} according to the function name. Then press enter to continue.") @contextlib.contextmanager diff --git a/tests/test_pytorch/test_pytorch.py b/tests/test_pytorch/test_pytorch.py index 52f4efb6..c80788ae 100644 --- a/tests/test_pytorch/test_pytorch.py +++ b/tests/test_pytorch/test_pytorch.py @@ -49,7 +49,7 @@ def call(): target(*input1) description = f"{usage_type}_{compile_type}_{backend}" -description += "_dynamic_shape" if dynamic_shape else "_without_dynamic_shape" +description += "_with_dynamic_shape" if dynamic_shape else "_without_dynamic_shape" description += "_with_grad" if requires_grad else "_without_grad" if usage_type == "dump":