Skip to content

Commit

Permalink
Update metric.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fyctime052 authored Jan 25, 2024
1 parent 6efea98 commit 384ff75
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions PolarPointBEV/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@ def __init__(self, num_classes):
def update(self, a, b):
n = self.num_classes
if self.mat is None:
# 创建混淆矩阵
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
# 寻找GT中为目标的像素索引
k = (a >= 0) & (a < n)
# 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

Expand Down Expand Up @@ -89,4 +86,4 @@ def __str__(self):
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()],
['{:.1f}'.format(i) for i in (iu * 100).tolist()],
iu.mean().item() * 100)
iu.mean().item() * 100)

0 comments on commit 384ff75

Please sign in to comment.