From 384ff7569482411c55fa69ef71f0a5bdc6f7b9e6 Mon Sep 17 00:00:00 2001 From: Yuchao Feng <99387180+fyctime052@users.noreply.github.com> Date: Thu, 25 Jan 2024 19:58:38 +0800 Subject: [PATCH] Update metric.py --- PolarPointBEV/metric.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/PolarPointBEV/metric.py b/PolarPointBEV/metric.py index 272654e..1f93f17 100644 --- a/PolarPointBEV/metric.py +++ b/PolarPointBEV/metric.py @@ -51,12 +51,9 @@ def __init__(self, num_classes): def update(self, a, b): n = self.num_classes if self.mat is None: - # 创建混淆矩阵 self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) with torch.no_grad(): - # 寻找GT中为目标的像素索引 k = (a >= 0) & (a < n) - # 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙) inds = n * a[k].to(torch.int64) + b[k] self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) @@ -89,4 +86,4 @@ def __str__(self): acc_global.item() * 100, ['{:.1f}'.format(i) for i in (acc * 100).tolist()], ['{:.1f}'.format(i) for i in (iu * 100).tolist()], - iu.mean().item() * 100) \ No newline at end of file + iu.mean().item() * 100)