Skip to content

Commit

Permalink
Allow for calling lift_fresh_copy manually (pytorch#113923)
Browse files Browse the repository at this point in the history
In this case, the input could be fake!  Just treat it normally in that case.

Fixes pytorch#113331

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#113923
Approved by: https://github.com/eellison, https://github.com/bdhirsh, https://github.com/leslie-fang-intel
  • Loading branch information
ezyang authored and pytorchmergebot committed Nov 19, 2023
1 parent 72a8329 commit edc5ae3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
28 changes: 28 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,34 @@ def f():
self.assertEqual(a.size(), torch.Size([3, 4]))
self.assertEqual(b.size(), torch.Size([3, 4]))

def test_export_then_compile_tensor_ctor(self):
class M(torch.nn.Module):
def __init__(self,):
super().__init__()

def forward(self, scores, mask):
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
) # (bs, n_heads, q_length, k_length)
return scores

tensor_cpu = torch.randn(2, 4)
mask_cpu = torch.BoolTensor(
[[False, True, False, False],
[False, False, False, False]]
)

m = M().eval()
# res_ref = m(tensor_cpu, mask_cpu)
# print("res_ref is: {}".format(res_ref), flush=True)

exported_model = capture_pre_autograd_graph(
m,
(tensor_cpu, mask_cpu),
)
optimized_model = torch.compile(exported_model)
optimized_model(tensor_cpu, mask_cpu)


if __name__ == '__main__':
run_tests()
16 changes: 9 additions & 7 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,11 +1465,14 @@ def maybe_to_constant(t):
return t

# To constant propagate through these functions:
# 1, If this is a lift, the input tensor is guaranteed to be a
# 1, If this is a lift due to a torch.tensor call,
# the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument along so
# we can query it if we're asked to item() it at some later point
# we can query it if we're asked to item() it at some later point.
# (Note that you can always call a lift fn manually, so we do
# have to check if there are any fake tensors!)
# 2, Some functions that allow Python numbers to bind to Tensors, e.g, torch.div
if func in self.lift_fns or (
if (func in self.lift_fns and not flat_arg_fake_tensors) or (
should_allow_numbers_as_tensors(func)
and not has_symbolic_sizes
and not flat_arg_fake_tensors
Expand Down Expand Up @@ -1509,11 +1512,10 @@ def maybe_to_constant(t):
# this is generated from torch.tensor(), which does not use the
# dispatcher, to allow wrapper subclasses to wrap the new tensor
if func in self.lift_fns:
assert (
len(kwargs) == 0 and len(args) == 1 and type(args[0]) is torch.Tensor
), f"{args} {kwargs}"
assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}"

return converter(self, args[0])
if type(args[0]) is torch.Tensor:
return converter(self, args[0])

# Recompute flat_arg_fake_tensors here again in case some of the inputs
# were real tensors and fakified in validate_and_convert_non_fake_tensors
Expand Down

0 comments on commit edc5ae3

Please sign in to comment.