diff --git a/PolarPointBEV/metric.py b/PolarPointBEV/metric.py index 272654e..1f93f17 100644 --- a/PolarPointBEV/metric.py +++ b/PolarPointBEV/metric.py @@ -51,12 +51,9 @@ def __init__(self, num_classes): def update(self, a, b): n = self.num_classes if self.mat is None: - # 创建混淆矩阵 self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) with torch.no_grad(): - # 寻找GT中为目标的像素索引 k = (a >= 0) & (a < n) - # 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙) inds = n * a[k].to(torch.int64) + b[k] self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) @@ -89,4 +86,4 @@ def __str__(self): acc_global.item() * 100, ['{:.1f}'.format(i) for i in (acc * 100).tolist()], ['{:.1f}'.format(i) for i in (iu * 100).tolist()], - iu.mean().item() * 100) \ No newline at end of file + iu.mean().item() * 100)