diff --git a/README.md b/README.md
index 8bc0cb2..f862692 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,17 @@
 
 # Overview
 
-The MillionTrees benchmark is designed to provie *open*, *reproducible* and *rigorous* evaluation of tree detection algorithms. The dataset currently holds X images, X annotations and X train/test evaluation splits. This repo is the python package for rapid data sharing and evaluation.
+The MillionTrees benchmark is designed to provide *open*, *reproducible* and *rigorous* evaluation of tree detection algorithms. The dataset currently holds X images, X annotations and X train/test evaluation splits. This repo is the Python package for rapid data sharing and evaluation. We anticipate the dataset will be ready by spring 2025.
+
+## Current Status
+
+Three datasets are currently available for the MillionTrees benchmark:
+
+* TreeBoxes: A dataset of 282,288 tree crowns from 9 sources.
+
+* TreePolygons: A dataset of 362,751 tree crowns from 8 sources.
+
+* TreePoints: A dataset of 191,614 tree stems from 2 sources.
 
 ## Why MillionTrees?
 
@@ -54,19 +64,6 @@ The MillionTrees package has ingested many contributed datasets and formatted th
 | Polygon | | | | |
 | Box | | | | |
 
-## Underlying data contributions
-
-Many datasets have been cleaned or altered to fit the desired data format. Here is an incomplete list of the current contributions.
-
-# Using the MillionTrees package
-
-## Downloading and training on the MillionTrees datasets
-
-## Algorithms
-
-## Evaluation
-
-### Reproducibility
 
 # Citing MillionTrees
 
diff --git a/milliontrees/common/metrics/all_metrics.py b/milliontrees/common/metrics/all_metrics.py
index a13be5c..a453262 100644
--- a/milliontrees/common/metrics/all_metrics.py
+++ b/milliontrees/common/metrics/all_metrics.py
@@ -429,3 +429,67 @@ def _accuracy(self, src_boxes, pred_boxes, iou_threshold):
 
     def worst(self, metrics):
         return minimum(metrics)
+
+class KeypointAccuracy(ElementwiseMetric):
+    """Given a distance-based matching threshold, determine the accuracy
+    achieved for a one-class keypoint detector."""
+
+    def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None):
+        self.iou_threshold = iou_threshold
+        self.score_threshold = score_threshold
+        if name is None:
+            name = "keypoint_acc"
+        super().__init__(name=name)
+
+    def _compute_element_wise(self, y_pred, y_true):
+        batch_results = []
+        for target, batch_keypoints_predictions in zip(y_true, y_pred):
+            # Concatenate predicted keypoints and scores across the batch
+            pred_keypoints = torch.cat([image_results["y"] for image_results in batch_keypoints_predictions], dim=0)
+            pred_scores = torch.cat([image_results["score"] for image_results in batch_keypoints_predictions], dim=0)
+            pred_keypoints = pred_keypoints[pred_scores > self.score_threshold]
+            src_keypoints = torch.cat([image_results["y"] for image_results in target], dim=0)
+
+            det_accuracy = torch.mean(
+                torch.stack([
+                    self._accuracy(src_keypoints, pred_keypoints, iou_thr)
+                    for iou_thr in np.arange(0.5, 0.51, 0.05)
+                ]))
+            batch_results.append(det_accuracy)
+
+        return torch.tensor(batch_results)
+
+    def _point_iou(self, src_keypoints, pred_keypoints):
+        return torch.cdist(src_keypoints, pred_keypoints, p=2)
+
+    def _accuracy(self, src_keypoints, pred_keypoints, iou_threshold):
+        total_gt = len(src_keypoints)
+        total_pred = len(pred_keypoints)
+        if total_gt > 0 and total_pred > 0:
+            # Match predictions to ground truth using the pairwise distance matrix
+            matcher = Matcher(iou_threshold,
+                              iou_threshold,
+                              allow_low_quality_matches=False)
+            match_quality_matrix = self._point_iou(src_keypoints,
+                                                   pred_keypoints)
+            results = matcher(match_quality_matrix)
+            true_positive = torch.count_nonzero(results.unique() != -1)
+            matched_elements = results[results > -1]
+            # Unmatched predictions and duplicate matches to the same
+            # ground-truth point both count as false positives
+            false_positive = (
+                torch.count_nonzero(results == -1) +
+                (len(matched_elements) - len(matched_elements.unique())))
+            false_negative = total_gt - true_positive
+            acc = true_positive / (true_positive + false_positive +
+                                   false_negative)
+            return acc
+        elif total_gt == 0:
+            if total_pred > 0:
+                return torch.tensor(0.)
+            else:
+                return torch.tensor(1.)
+        elif total_gt > 0 and total_pred == 0:
+            return torch.tensor(0.)
+
+    def worst(self, metrics):
+        return minimum(metrics)
diff --git a/milliontrees/datasets/TreePoints.py b/milliontrees/datasets/TreePoints.py
index e5dbb8c..d647c62 100644
--- a/milliontrees/datasets/TreePoints.py
+++ b/milliontrees/datasets/TreePoints.py
@@ -8,7 +8,7 @@
 from milliontrees.datasets.milliontrees_dataset import MillionTreesDataset
 from milliontrees.common.grouper import CombinatorialGrouper
-from milliontrees.common.metrics.all_metrics import DetectionAccuracy
+from milliontrees.common.metrics.all_metrics import KeypointAccuracy
 from PIL import Image
 import albumentations as A
 from albumentations.pytorch import ToTensorV2
 
@@ -131,7 +131,7 @@ def __init__(self,
 
         self._metadata_array = unique_sources.values
         self._metadata_fields = ['filename','source_id']
-        self._metric = DetectionAccuracy()
+        self._metric = KeypointAccuracy()
         self._collate = TreePointsDataset._collate_fn
 
         # eval grouper
diff --git a/requirements.txt b/requirements.txt
index e4476a0..c50dc04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ furo
 sphinx_markdown_tables
 myst_parser
 albumentations
+torchmetrics
diff --git a/tests/test_TreePoints.py b/tests/test_TreePoints.py
index 9313b11..c66181f 100644
--- a/tests/test_TreePoints.py
+++ b/tests/test_TreePoints.py
@@ -94,7 +94,7 @@ def test_TreePoints_eval(dataset):
     all_metadata = []
     # Get predictions for the full test set
     for metadata, x, y_true in test_loader:
-        y_pred = [{'y': torch.tensor([[30, 70]]), 'label': torch.tensor([0]), 'score': torch.tensor([0.54])} for _ in range(x.shape[0])]
+        y_pred = [{'y': torch.tensor([[30.0, 70.0]]), 'label': torch.tensor([0]), 'score': torch.tensor([0.54])} for _ in range(x.shape[0])]
         # Accumulate y_true, y_pred, metadata
         all_y_pred.append(y_pred)
         all_y_true.append(y_true)
@@ -104,7 +104,7 @@
 
     eval_results, eval_string = dataset.eval(all_y_pred, all_y_true, all_metadata)
     assert len(eval_results)
-    assert "detection_acc_avg" in eval_results.keys()
+    assert "keypoint_acc_avg" in eval_results.keys()
 
 # Test structure with real annotation data to ensure format is correct
 # Do not run on github actions, long running.