Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
potipot committed Oct 1, 2023
1 parent e9685bb commit d372ddc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 36 deletions.
34 changes: 2 additions & 32 deletions icevision/metrics/confusion_matrix/confusion_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@ class ObjectDetectionItem:
label_id: int
record_id: int
item_id: int
plate_quality: float
plate_exposure: float
plate_state: str
plate_layout: str
plate_special_stacked: str
plate_mirrored: str

# def __eq__(self, other):
# return (
Expand Down Expand Up @@ -155,24 +149,12 @@ def build_target_list(target: BaseRecord) -> List:
bbox=bbox,
label=label,
label_id=label_id,
plate_quality=pq,
plate_exposure=pe,
plate_state=ps,
plate_layout=pl,
plate_special_stacked=pss,
plate_mirrored=pm,
)
for item_id, (bbox, label, label_id, pq, pe, ps, pl, pss, pm) in enumerate(
for item_id, (bbox, label, label_id) in enumerate(
zip(
target.detection.bboxes,
target.detection.labels,
target.detection.label_ids,
target.classification.plate_quality,
target.classification.plate_exposure,
target.classification.plate_state,
target.classification.plate_layout,
target.classification.plate_special_stacked,
target.classification.plate_mirrored,
)
)
]
Expand All @@ -190,25 +172,13 @@ def build_prediction_list(prediction: BaseRecord) -> List:
label=label,
label_id=label_id,
score=score,
plate_quality=pq,
plate_exposure=pe,
plate_state=ps,
plate_layout=pl,
plate_special_stacked=pss,
plate_mirrored=pm,
)
for item_id, (bbox, label, label_id, score, pq, pe, ps, pl, pss, pm) in enumerate(
for item_id, (bbox, label, label_id, score) in enumerate(
zip(
prediction.detection.bboxes,
prediction.detection.labels,
prediction.detection.label_ids,
prediction.detection.scores,
prediction.classification.plate_quality,
prediction.classification.plate_exposure,
prediction.classification.plate_state,
prediction.classification.plate_layout,
prediction.classification.plate_special_stacked,
prediction.classification.plate_mirrored,
)
)
]
Expand Down
9 changes: 5 additions & 4 deletions tests/models/efficient_det/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@ def expected_confusion_matrix_output():


@pytest.mark.parametrize(
"metric, expected_output",
"metric, expected_output, detection_threshold",
[
(SimpleConfusionMatrix(print_summary=True), "expected_confusion_matrix_output"),
(COCOMetric(print_summary=True), "expected_coco_metric_output"),
(SimpleConfusionMatrix(print_summary=True), "expected_confusion_matrix_output", 0.5),
(COCOMetric(print_summary=True), "expected_coco_metric_output", 0.0),
],
)
def test_efficientdet_metrics(
fridge_efficientdet_model,
fridge_efficientdet_records,
metric,
expected_output,
detection_threshold,
request,
):
expected_output = request.getfixturevalue(expected_output)
Expand All @@ -57,7 +58,7 @@ def test_efficientdet_metrics(
batch=batch,
raw_preds=raw_preds["detections"],
records=fridge_efficientdet_records,
detection_threshold=0.0,
detection_threshold=detection_threshold,
)

metric.accumulate(preds)
Expand Down

0 comments on commit d372ddc

Please sign in to comment.