diff --git a/ai_edge_torch/debug/test/test_culprit.py b/ai_edge_torch/debug/test/test_culprit.py index df40c3b7..04e074a8 100644 --- a/ai_edge_torch/debug/test/test_culprit.py +++ b/ai_edge_torch/debug/test/test_culprit.py @@ -15,14 +15,14 @@ import ast -import io -import sys -from ai_edge_torch.debug import find_culprits +import ai_edge_torch.debug import torch from absl.testing import absltest as googletest +find_culprits = ai_edge_torch.debug.find_culprits + _test_culprit_lib = torch.library.Library("test_culprit", "DEF") _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor") @@ -52,6 +52,11 @@ def forward(self, x): class TestCulprit(googletest.TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(0) + torch._dynamo.reset() + def test_find_culprits(self): model = BadModel().eval() args = (torch.rand(10),)