Skip to content

Commit

Permalink
update: label map
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhe committed May 21, 2021
1 parent 6225b49 commit bb3a9a7
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 35 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
*.pyc
__pycache__/
__pycache__/*
.idea/
.DS_Store
results/*
24 changes: 13 additions & 11 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def load_all_files(self):
print('Incorrect threshold value! It should be in [0, 1]. Please check and retry ~')
return 0
pre_files = [x for x in os.listdir(prediction_path) if (x.endswith('txt') or x.endswith('xml'))]
print ("Num of prediction files: ", len(pre_files))
print("Num of prediction files: ", len(pre_files))
gt_files = os.listdir(gt_path)
print ("Num of ground truth files: ", len(gt_files))
print("Num of ground truth files: ", len(gt_files))
if len(pre_files) != len(gt_files):
print("groundtruths' size does not match predictions' size please check ~ ")
print("ground truths' size does not match predictions' size, please check ~ ")
return 0
elif len(pre_files) < 1:
print('No files! Please check~')
Expand Down Expand Up @@ -226,19 +226,19 @@ def computeAp(self, label):
plt.ylabel('recall')
plt.draw() # 显示绘图
# plt.pause(5) # 显示5秒
plt.savefig("class_{}_roc.jpg".format(label)) # 保存图象
plt.savefig("class_{}_roc.png".format(label)) # 保存图象
plt.close()

if self.pr:
# 画roc曲线图
# 画pr曲线图
plt.figure('Draw_pr')
plt.plot(rec, prec) # plot绘制折线图
plt.grid(True)
plt.xlabel('recall')
plt.ylabel('precision')
plt.draw() # 显示绘图
# plt.pause(5) # 显示5秒
plt.savefig("class_{}_pr.jpg".format(label)) # 保存图象
plt.savefig("class_{}_pr.png".format(label)) # 保存图象
plt.close()

fppi = 0
Expand Down Expand Up @@ -276,11 +276,12 @@ def run(self):
prediction_path, gt_path, predictions, groundtruths, file_format = self.load_all_files()
aps = 0

# temp
class_map_temp = {1: 'Person', 2: 'Vehicle', 3: 'Dryer'}
# modify as you need
# list your label names as below ['class 1', 'class 2'......]
class_names = ['face']

for label in range(1, self.cls):
semantic_label = class_map_temp[label]
for label in class_names:
semantic_label = label
print('Processing label: {}'.format(semantic_label))
self.get_tp_fp(gt_path, prediction_path, groundtruths, predictions, semantic_label, file_format)
precision, recall, fppi, fppw, ap = self.computeAp(semantic_label)
Expand All @@ -294,8 +295,9 @@ def run(self):
if self.FPPIW:
print('FPPW: ', fppw, 'FPPI', fppi)
aps += ap

mAp = aps / (self.cls - 1)
print ("mAp: ", mAp)
print("mAp: ", mAp)

return 0

18 changes: 13 additions & 5 deletions io_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,24 @@


# modify 'class_map' as you need
class_map = {'Person': 'Person', 'Vehicle': 'Vehicle', 'Dryer': 'Dryer'}
# {id or label name in gt label: class name}
class_map = {'face': 'face'}


# parse pascal voc style label file
def parse_xml(xml_path):
dom = xml.dom.minidom.parse(xml_path)
gts = []
try:
dom = xml.dom.minidom.parse(xml_path)
print('{} parse failed! Use empty label instead \n'.format(xml_path))
except:
return gts
root = dom.documentElement
objects = root.getElementsByTagName('object')
gts = []
for index, obj in enumerate(objects):
name = obj.getElementsByTagName('name')[0].firstChild.data.decode('utf8')
name = obj.getElementsByTagName('name')[0].firstChild.data.strip("\ufeff")
if name not in class_map:
continue
label = class_map[name]
bndbox = obj.getElementsByTagName('bndbox')[0]
x1 = int(bndbox.getElementsByTagName('xmin')[0].firstChild.data)
Expand All @@ -21,4 +29,4 @@ def parse_xml(xml_path):
y2 = int(bndbox.getElementsByTagName('ymax')[0].firstChild.data)
gt_one = [label, x1, y1, x2, y2]
gts.append(gt_one)
return gts
return gts
Binary file removed results/class_1_pr.jpg
Binary file not shown.
Binary file removed results/class_1_roc.jpg
Binary file not shown.
Binary file removed results/class_2_pr.jpg
Binary file not shown.
Binary file removed results/class_2_roc.jpg
Binary file not shown.
Binary file removed results/class_3_pr.jpg
Binary file not shown.
Binary file removed results/class_3_roc.jpg
Binary file not shown.
Binary file removed sample/.DS_Store
Binary file not shown.
40 changes: 21 additions & 19 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@


cfg = {'file_dir': './',
'overlapRatio': 0.5,
'cls': 2,
'presicion': False,
'recall': False,
'threshold': 0.5,
'FPPIW': False,
'roc': False,
'pr': False}
'overlapRatio': 0.5, # iou between predicted bounding box and ground truth bounding box
'cls': 2, # background id included
'precision': False, # calculate precision with 'threshold' or not
'recall': False, # calculate precision with 'threshold' or not
'threshold': 0.5, # confidence threshold used in calculating precision and recall
'FPPIW': False, # FPPI: false positive per image; FPPW: false positive per window(bounding box)
'roc': False, # draw roc curve or not
'pr': False # draw pr curve or not
}


def parse_args():
Expand Down Expand Up @@ -43,19 +44,20 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()
args.dir = ['/Users/wangzhe/data/zhengdanongmu/eval/voc_coco_only/converted_voc_coco_only', '/Users/wangzhe/data/zhengdanongmu/eval/voc_coco_only/gt_voc_coco']
args.dir = ['/data/guanlang/video_clips/metric_used_xml/face_pred_txt', # prediction path
'/data/guanlang/video_clips/metric_used_xml/gt'] # gt path

args.cls = 4
args.overlapRatio = 0.5
args.threshold = 0.5
# args.cls = 2
# args.overlapRatio = 0.3
# args.threshold = 0.94
len(sys.argv)
print ("Your Folder's path: {}".format(args.dir))
print ("Overlap Ratio: {}".format(args.overlapRatio))
print ("Threshold: {}".format(args.threshold))
print ("Num of Categories: {}".format(args.cls))
print ("Precision: {}".format(args.precision))
print ("Recall: {}".format(args.recall))
print ("FPPIW: {}".format(args.FPPIW))
print("Your Folder's path: {}".format(args.dir))
print("Overlap Ratio: {}".format(args.overlapRatio))
print("Threshold: {}".format(args.threshold))
print("Num of Categories: {}".format(args.cls))
print("Precision: {}".format(args.precision))
print("Recall: {}".format(args.recall))
print("FPPIW: {}".format(args.FPPIW))

print("Calculating......")

Expand Down

0 comments on commit bb3a9a7

Please sign in to comment.