keypoint metrics
bw4sz committed Nov 11, 2024
1 parent c1960c3 commit 12a754d
Showing 5 changed files with 80 additions and 18 deletions.
25 changes: 11 additions & 14 deletions README.md
@@ -6,7 +6,17 @@

# Overview

The MillionTrees benchmark is designed to provide *open*, *reproducible*, and *rigorous* evaluation of tree detection algorithms. The dataset currently holds X images, X annotations, and X train/test evaluation splits. This repo is the Python package for rapid data sharing and evaluation.
The MillionTrees benchmark is designed to provide *open*, *reproducible*, and *rigorous* evaluation of tree detection algorithms. The dataset currently holds X images, X annotations, and X train/test evaluation splits. This repo is the Python package for rapid data sharing and evaluation. We anticipate the dataset will be ready by spring 2025.

## Current Status

There are three datasets available for the MillionTrees benchmark; a minimal loading sketch follows the list:

* TreeBoxes: A dataset of 282,288 tree crowns from 9 sources.

* TreePolygons: A dataset of 362,751 tree crowns from 8 sources.

* TreePoints: A dataset of 191,614 tree stems from 2 sources.
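
As a sketch of the intended workflow, loading one of these datasets might look like the following. This is a hypothetical example: the `get_dataset` entry point, the `"TreePoints"` name, and the split names are assumptions based on the package's WILDS-style layout, not confirmed by this commit.

```python
# Hypothetical usage sketch: `get_dataset`, "TreePoints", and the split
# names are assumed from the package layout, not confirmed by this commit
from milliontrees import get_dataset

# Download (if needed) and load the stem-point dataset
dataset = get_dataset("TreePoints", download=True)

# Standard train/test splits for benchmark evaluation
train_data = dataset.get_subset("train")
test_data = dataset.get_subset("test")
```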

## Why MillionTrees?

@@ -54,19 +64,6 @@ The MillionTrees package has ingested many contributed datasets and formatted th
| Polygon | | | | |
| Box | | | | |

## Underlying data contributions

Many datasets have been cleaned or altered to fit the desired data format. Here is an incomplete list of the current contributions.

# Using the MillionTrees package

## Downloading and training on the MillionTrees datasets

## Algorithms

## Evaluation

### Reproducibility

# Citing MillionTrees

64 changes: 64 additions & 0 deletions milliontrees/common/metrics/all_metrics.py
@@ -429,3 +429,67 @@ def _accuracy(self, src_boxes, pred_boxes, iou_threshold):

    def worst(self, metrics):
        return minimum(metrics)

class KeypointAccuracy(ElementwiseMetric):
    """Given a match-quality threshold on point proximity, determine the
    accuracy achieved for a one-class keypoint detector."""

    def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None):
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        if name is None:
            name = "keypoint_acc"
        super().__init__(name=name)

    def _compute_element_wise(self, y_pred, y_true):
        batch_results = []
        for target, batch_keypoints_predictions in zip(y_true, y_pred):
            # Concatenate all predicted keypoints and scores, then drop
            # low-confidence predictions
            pred_keypoints = torch.cat([image_results["y"] for image_results in batch_keypoints_predictions], dim=0)
            pred_scores = torch.cat([image_results["score"] for image_results in batch_keypoints_predictions], dim=0)
            pred_keypoints = pred_keypoints[pred_scores > self.score_threshold]
            src_keypoints = torch.cat([image_results["y"] for image_results in target], dim=0)

            # np.arange(0.5, 0.51, 0.05) yields a single threshold (0.5);
            # widen the range to average accuracy over several thresholds
            det_accuracy = torch.mean(
                torch.stack([
                    self._accuracy(src_keypoints, pred_keypoints, iou_thr)
                    for iou_thr in np.arange(0.5, 0.51, 0.05)
                ]))
            batch_results.append(det_accuracy)

        return torch.tensor(batch_results)

    def _point_iou(self, src_keypoints, pred_keypoints):
        # torch.cdist returns pairwise Euclidean distances, where lower is
        # better; Matcher expects a match-quality matrix where higher is
        # better, so convert each distance to a proximity score in (0, 1]
        distances = torch.cdist(src_keypoints, pred_keypoints, p=2)
        return 1 / (1 + distances)

    def _accuracy(self, src_keypoints, pred_keypoints, iou_threshold):
        total_gt = len(src_keypoints)
        total_pred = len(pred_keypoints)
        if total_gt > 0 and total_pred > 0:
            # Match predictions to ground truth using the proximity matrix
            matcher = Matcher(iou_threshold,
                              iou_threshold,
                              allow_low_quality_matches=False)
            match_quality_matrix = self._point_iou(src_keypoints, pred_keypoints)
            results = matcher(match_quality_matrix)
            true_positive = torch.count_nonzero(results.unique() != -1)
            matched_elements = results[results > -1]
            # Unmatched predictions, plus duplicate assignments to the same
            # ground-truth point, count as false positives
            false_positive = (
                torch.count_nonzero(results == -1) +
                (len(matched_elements) - len(matched_elements.unique())))
            false_negative = total_gt - true_positive
            return true_positive / (true_positive + false_positive +
                                    false_negative)
        elif total_gt == 0:
            if total_pred > 0:
                return torch.tensor(0.)
            else:
                return torch.tensor(1.)
        else:
            # total_gt > 0 and total_pred == 0: every ground truth is missed
            return torch.tensor(0.)

    def worst(self, metrics):
        return minimum(metrics)
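
As a quick, self-contained illustration of the matching logic above, here is a toy sketch with made-up coordinates; the `Matcher` import path is torchvision's, which this file already depends on:

```python
# Standalone sketch of the proximity-based point matching used by
# KeypointAccuracy; coordinates are toy values
import torch
from torchvision.models.detection._utils import Matcher

src = torch.tensor([[10.0, 10.0], [50.0, 50.0]])   # ground-truth stem points
pred = torch.tensor([[10.5, 10.2], [80.0, 80.0]])  # predicted stem points

# Convert pairwise Euclidean distance to a proximity score (higher is better)
quality = 1 / (1 + torch.cdist(src, pred, p=2))

matcher = Matcher(0.5, 0.5, allow_low_quality_matches=False)
matches = matcher(quality)  # one entry per prediction: matched gt index or -1

tp = int(torch.count_nonzero(matches.unique() != -1))
fp = int(torch.count_nonzero(matches == -1))
fn = len(src) - tp
print(tp, fp, fn)  # 1 true positive, 1 false positive, 1 missed ground truth
```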
4 changes: 2 additions & 2 deletions milliontrees/datasets/TreePoints.py
@@ -8,7 +8,7 @@

from milliontrees.datasets.milliontrees_dataset import MillionTreesDataset
from milliontrees.common.grouper import CombinatorialGrouper
from milliontrees.common.metrics.all_metrics import DetectionAccuracy
from milliontrees.common.metrics.all_metrics import KeypointAccuracy
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
@@ -131,7 +131,7 @@ def __init__(self,
        self._metadata_array = unique_sources.values
        self._metadata_fields = ['filename', 'source_id']

        self._metric = DetectionAccuracy()
        self._metric = KeypointAccuracy()
        self._collate = TreePointsDataset._collate_fn

        # eval grouper
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ furo
sphinx_markdown_tables
myst_parser
albumentations
torchmetrics
4 changes: 2 additions & 2 deletions tests/test_TreePoints.py
@@ -94,7 +94,7 @@ def test_TreePoints_eval(dataset):
    all_metadata = []
    # Get predictions for the full test set
    for metadata, x, y_true in test_loader:
        y_pred = [{'y': torch.tensor([[30, 70]]), 'label': torch.tensor([0]), 'score': torch.tensor([0.54])} for _ in range(x.shape[0])]
        y_pred = [{'y': torch.tensor([[30.0, 70.0]]), 'label': torch.tensor([0]), 'score': torch.tensor([0.54])} for _ in range(x.shape[0])]
        # Accumulate y_true, y_pred, metadata
        all_y_pred.append(y_pred)
        all_y_true.append(y_true)
@@ -104,7 +104,7 @@
    eval_results, eval_string = dataset.eval(all_y_pred, all_y_true, all_metadata)

    assert len(eval_results)
    assert "detection_acc_avg" in eval_results.keys()
    assert "keypoint_acc_avg" in eval_results.keys()

# Test structure with real annotation data to ensure format is correct
# Do not run on github actions, long running.
