diff --git a/depyf/explain/utils.py b/depyf/explain/utils.py index 4261678d..8fe67747 100644 --- a/depyf/explain/utils.py +++ b/depyf/explain/utils.py @@ -109,7 +109,7 @@ def __init__(self, original_code, module, cache): if not cpp_guard: # for old version of pytorch, # `guard_manager` is a plain python function - guard = guard_manager.code_parts + guard_codes = guard_manager.code_parts freevar_names = guard_manager.__code__.co_freevars freevar_values = [x.cell_contents for x in guard_manager.__closure__] else: @@ -118,20 +118,25 @@ def __init__(self, original_code, module, cache): tensor_aliasing_guard_seen = False def visit(root, ans): nonlocal tensor_aliasing_guard_seen - for x in root.get_leaf_guards(): - if isinstance(x, torch._C._dynamo.guards.NO_TENSOR_ALIASING): + for leaf_guard in root.get_leaf_guards(): + if isinstance(leaf_guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): if not tensor_aliasing_guard_seen: tensor_aliasing_guard_seen = True else: continue - for verbose_str in x.verbose_code_parts(): - verbose_str = verbose_str.strip() - ans.append(verbose_str) + append_guard_code(leaf_guard, ans) for child in root.get_child_managers(): visit(child, ans) - guard = [] + guard_codes = [] root = guard_manager.root - visit(root, guard) + + # Add guards in RootGuardManager + visit(root, guard_codes) + # Add guards in epilogue lambda guards + if hasattr(root, "get_epilogue_lambda_guards"): + for lambda_guard in root.get_epilogue_lambda_guards(): + append_guard_code(lambda_guard, guard_codes) + if guard_manager.closure_vars is None: freevar_names = tuple() freevar_values = [] @@ -139,7 +144,7 @@ def visit(root, ans): freevar_names = tuple(guard_manager.closure_vars.keys()) freevar_values = list(guard_manager.closure_vars.values()) - self.guard = guard + self.guard = guard_codes self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)} code = cache.code @@ -285,6 +290,11 @@ def remove_indentation(code: str): indent = len(lines[0]) - 
def append_guard_code(guard, ans: list) -> None:
    """Append the verbose code parts of ``guard`` to ``ans``.

    Every string yielded by ``guard.verbose_code_parts()`` is stripped of
    leading and trailing whitespace and appended to ``ans``, preserving the
    order in which the guard reports them.

    Args:
        guard: Any object exposing a ``verbose_code_parts()`` method that
            returns an iterable of strings.
        ans: List that is extended in place with the stripped strings.

    Returns:
        None. The result is delivered by mutating ``ans``.
    """
    # A single extend() replaces the manual append loop; the generator keeps
    # the same lazy, in-order evaluation as the original for-loop.
    ans.extend(part.strip() for part in guard.verbose_code_parts())