From 3b81ea0058cf76c824300e4aa36721bb7bf86b2f Mon Sep 17 00:00:00 2001 From: Linchenn <40653845+Linchenn@users.noreply.github.com> Date: Thu, 11 Jul 2024 09:37:26 -0700 Subject: [PATCH] fix SDPA fx pass (#85) --- .../generative/fx_passes/remove_sdpa_zero_mask_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py index 5e9a2f42..86bbc69b 100644 --- a/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +++ b/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py @@ -40,7 +40,7 @@ def call(self, exported_program: torch.export.ExportedProgram): if self.is_zero_tensor_node(source): # Remove the mark_tensor call on the mask input by # replacing the target with an identity function. - node.target = lambda *args, **kwargs: args[0] + node.target = lambda *args, **kwargs: torch.zeros_like(args[0]) exported_program.graph_module.graph.lint() exported_program.graph_module.recompile()