Skip to content

Commit

Permalink
Add epilogue lambda guards (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
imShZh authored Nov 22, 2024
1 parent 2ba8467 commit cbfbc00
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions depyf/explain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(self, original_code, module, cache):
if not cpp_guard:
# for old version of pytorch,
# `guard_manager` is a plain python function
guard = guard_manager.code_parts
guard_codes = guard_manager.code_parts
freevar_names = guard_manager.__code__.co_freevars
freevar_values = [x.cell_contents for x in guard_manager.__closure__]
else:
Expand All @@ -118,28 +118,33 @@ def __init__(self, original_code, module, cache):
tensor_aliasing_guard_seen = False
def visit(root, ans):
nonlocal tensor_aliasing_guard_seen
for x in root.get_leaf_guards():
if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING):
for leaf_guard in root.get_leaf_guards():
if isinstance(guard_codes, torch._C._dynamo.guards.NO_TENSOR_ALIASING):
if not tensor_aliasing_guard_seen:
tensor_aliasing_guard_seen = True
else:
continue
for verbose_str in x.verbose_code_parts():
verbose_str = verbose_str.strip()
ans.append(verbose_str)
append_guard_code(leaf_guard, ans)
for child in root.get_child_managers():
visit(child, ans)
guard = []
guard_codes = []
root = guard_manager.root
visit(root, guard)

# Add guards in RootGuardManager
visit(root, guard_codes)
# Add guards in epilogue lambda guards
if hasattr(root, "get_epilogue_lambda_guards"):
for lambda_guard in root.get_epilogue_lambda_guards():
append_guard_code(lambda_guard, guard_codes)

if guard_manager.closure_vars is None:
freevar_names = tuple()
freevar_values = []
else:
freevar_names = tuple(guard_manager.closure_vars.keys())
freevar_values = list(guard_manager.closure_vars.values())

self.guard = guard
self.guard = guard_codes
self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)}
code = cache.code

Expand Down Expand Up @@ -285,6 +290,11 @@ def remove_indentation(code: str):
indent = len(lines[0]) - len(lines[0].lstrip())
return "".join([line[indent:] + "\n" for line in lines])

def append_guard_code(guard, ans):
for verbose_str in guard.verbose_code_parts():
verbose_str = verbose_str.strip()
ans.append(verbose_str)

from contextlib import contextmanager

@contextmanager
Expand Down

0 comments on commit cbfbc00

Please sign in to comment.