From 1a42737552d0879b115789e4af577f821a913e0c Mon Sep 17 00:00:00 2001 From: Amir Gholami Date: Mon, 9 Aug 2021 21:35:57 -0700 Subject: [PATCH] fix trace computation for cases with negative curvature --- pyhessian/hessian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhessian/hessian.py b/pyhessian/hessian.py index 442ac4c..70be97f 100644 --- a/pyhessian/hessian.py +++ b/pyhessian/hessian.py @@ -183,7 +183,7 @@ def trace(self, maxIter=100, tol=1e-3): else: Hv = hessian_vector_product(self.gradsH, self.params, v) trace_vhv.append(group_product(Hv, v).cpu().item()) - if abs(np.mean(trace_vhv) - trace) / (trace + 1e-6) < tol: + if abs(np.mean(trace_vhv) - trace) / (abs(trace) + 1e-6) < tol: return trace_vhv else: trace = np.mean(trace_vhv)