From cf0e73fe6e891047bdf47f359c8053b36ff46e81 Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Tue, 10 Dec 2024 12:03:50 -0800 Subject: [PATCH] Fix culprit test PiperOrigin-RevId: 704798083 --- ai_edge_torch/debug/test/test_culprit.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ai_edge_torch/debug/test/test_culprit.py b/ai_edge_torch/debug/test/test_culprit.py index df40c3b7..04e074a8 100644 --- a/ai_edge_torch/debug/test/test_culprit.py +++ b/ai_edge_torch/debug/test/test_culprit.py @@ -15,14 +15,14 @@ import ast -import io -import sys -from ai_edge_torch.debug import find_culprits +import ai_edge_torch.debug import torch from absl.testing import absltest as googletest +find_culprits = ai_edge_torch.debug.find_culprits + _test_culprit_lib = torch.library.Library("test_culprit", "DEF") _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor") @@ -52,6 +52,11 @@ def forward(self, x): class TestCulprit(googletest.TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(0) + torch._dynamo.reset() + def test_find_culprits(self): model = BadModel().eval() args = (torch.rand(10),)