three input data streams
bw4sz committed Nov 13, 2024
1 parent e21a404 commit 1a15ea1
Showing 4 changed files with 67 additions and 43 deletions.
conf/config.yaml (8 additions, 4 deletions)
```diff
@@ -14,10 +14,13 @@ label_studio:
   api_key: "${oc.env:LABEL_STUDIO_API_KEY}"
   folder_name: "/pgsql/retrieverdash/everglades-label-studio/everglades-data"
 
-pipeline:
-  confidence_threshold: 0.5
+predict:
   patch_size: 450
   patch_overlap: 0
+  min_score: 0.5
 
+pipeline:
+  confidence_threshold: 0.5
+  limit_empty_frac: 0.1
 
 train:
@@ -33,8 +36,9 @@ train:
   - "Bird"
 
 pipeline_evaluation:
-  detection_annotations_dir:
-  classification_annotations_dir:
+  detect_ground_truth_dir:
+  classify_confident_ground_truth_dir:
+  classify_uncertain_ground_truth_dir:
   detection_true_positive_threshold: 0.8
   detection_false_positive_threshold: 0.5
   classification_avg_score: 0.5
```
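This change splits prediction-time tiling parameters out of `pipeline` into a new top-level `predict` block, adds a `limit_empty_frac` setting (by its name, a cap on the fraction of empty frames), and replaces the two annotation directories with the commit's three ground-truth streams. A minimal sketch of reading the new keys, assuming the project loads this file with OmegaConf (the `${oc.env:...}` resolver above is OmegaConf syntax); the loading call here is illustrative, since Hydra may manage it in the real entry point:

```python
from omegaconf import OmegaConf

# Illustrative: load the config directly rather than through Hydra.
cfg = OmegaConf.load("conf/config.yaml")

# Prediction-time tiling settings now live under their own `predict` key,
# separate from the pipeline's decision thresholds.
patch_size = cfg.predict.patch_size        # 450
patch_overlap = cfg.predict.patch_overlap  # 0
min_score = cfg.predict.min_score          # 0.5

# Pipeline-level settings remain under `pipeline`.
confidence_threshold = cfg.pipeline.confidence_threshold  # 0.5
limit_empty_frac = cfg.pipeline.limit_empty_frac          # 0.1
```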
src/label_studio.py (4 additions, 4 deletions)
```diff
@@ -157,18 +157,18 @@ def move_images(annotations, src_dir, dst_dir):
         except FileNotFoundError:
             continue
 
-def gather_data(train_dir, labels=None):
+def gather_data(annotation_dir, labels=[None]):
     """Gather data from a directory of CSV files.
     Args:
-        train_dir (str): The directory containing the CSV files.
+        annotation_dir (str): The directory containing the CSV files.
         labels (list): A list of labels to filter by.
     Returns:
         pd.DataFrame: A DataFrame containing the data.
     """
-    train_csvs = glob.glob(os.path.join(train_dir,"*.csv"))
+    csvs = glob.glob(os.path.join(annotation_dir,"*.csv"))
     df = []
-    for x in train_csvs:
+    for x in csvs:
         df.append(pd.read_csv(x))
     df = pd.concat(df)
     df.drop_duplicates(inplace=True)
```
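The rename from `train_dir` to `annotation_dir` reflects that `gather_data` now serves all three annotation streams rather than just training data. A hypothetical usage sketch; the directory paths are illustrative stand-ins for the config values:

```python
from src.label_studio import gather_data  # import path assumed from the file layout

# Any folder of annotation CSVs works; these paths mirror the three streams.
detection_df = gather_data("annotations/detection")
confident_df = gather_data("annotations/classification_confident")
uncertain_df = gather_data("annotations/classification_uncertain")

print(len(detection_df), len(confident_df), len(uncertain_df))
```

One caution on the new signature: `labels=[None]` is a mutable default argument, which Python evaluates once and shares across calls; a `None` sentinel checked inside the function is the more common idiom.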
src/pipeline_evaluation.py (45 additions, 32 deletions)
```diff
@@ -2,16 +2,35 @@
 from torchmetrics.detection import MeanAveragePrecision
 from torchmetrics.classification import Accuracy
 from torchmetrics.functional import confusion_matrix
+import pandas as pd
 
 class PipelineEvaluation:
-    def __init__(self, model, detection_annotations_dir=None, classification_annotations_dir=None, detection_true_positive_threshold=0.8, detection_false_positive_threshold=0.5, classification_avg_score=0.5, target_labels=None):
-        """Initialize pipeline evaluation"""
+    def __init__(self, model, detect_ground_truth_dir=None, classify_confident_ground_truth_dir=None, classify_uncertain_ground_truth_dir=None, detection_true_positive_threshold=0.8, detection_false_positive_threshold=0.5, classification_avg_score=0.5, target_labels=None, patch_size=450, patch_overlap=0, min_score=0.5):
+        """Initialize pipeline evaluation.
+
+        Args:
+            model: Trained model for making predictions
+            detect_ground_truth_dir (str): Directory containing detection ground truth annotation CSV files
+            classify_confident_ground_truth_dir (str): Directory containing confident classification ground truth annotation CSV files
+            classify_uncertain_ground_truth_dir (str): Directory containing uncertain classification ground truth annotation CSV files
+            detection_true_positive_threshold (float): IoU threshold for considering a detection a true positive
+            detection_false_positive_threshold (float): IoU threshold for considering a detection a false positive
+            classification_avg_score (float): Threshold for classification confidence score
+            target_labels (list): List of target class labels to evaluate
+            patch_size (int): Size of image patches for prediction
+            patch_overlap (int): Overlap between patches
+            min_score (float): Minimum confidence score threshold for predictions
+        """
         self.detection_true_positive_threshold = detection_true_positive_threshold
         self.detection_false_positive_threshold = detection_false_positive_threshold
         self.classification_avg_score = classification_avg_score
+        self.patch_size = patch_size
+        self.patch_overlap = patch_overlap
+        self.min_score = min_score
 
-        self.detection_annotations_df = gather_data(detection_annotations_dir)
-        self.classification_annotations_df = gather_data(classification_annotations_dir)
+        self.detection_annotations_df = gather_data(detect_ground_truth_dir)
+        self.classification_confident_annotations_df = gather_data(classify_confident_ground_truth_dir)
+        self.classification_uncertain_annotations_df = gather_data(classify_uncertain_ground_truth_dir)
 
         self.model = model
 
@@ -28,54 +47,48 @@ def _format_targets(self, annotations_df):
         return targets
 
     def evaluate_detection(self):
-        preds = self.model.predict(self.detection_annotations_df)
+        preds = self.model.predict(
+            self.detection_annotations_df.image_path.tolist(),
+            patch_size=self.patch_size,
+            patch_overlap=self.patch_overlap,
+            min_score=self.min_score
+        )
         targets = self._format_targets(self.detection_annotations_df)
 
         self.mAP.update(preds=preds, target=targets)
 
         return self.mAP.compute()
 
-    def classification_accuracy(self):
-        self.classification_accuracy.update(self.classification_annotations_df)
+    def confident_classification_accuracy(self):
+        self.classification_accuracy.update(self.classification_confident_annotations_df)
         return self.classification_accuracy.compute()
 
-    def confusion_matrix(self):
-        return confusion_matrix(self.classification_annotations_df)
+    def uncertain_classification_accuracy(self):
+        self.classification_accuracy.update(self.classification_uncertain_annotations_df)
+        return self.classification_accuracy.compute()
 
     def target_classification_accuracy(self):
+        # Combine confident and uncertain classifications
+        combined_annotations_df = pd.concat([self.classification_confident_annotations_df, self.classification_uncertain_annotations_df])
         if self.target_classes is not None:
-            self.classification_accuracy.update(self.classification_annotations_df, self.target_classes)
+            self.classification_accuracy.update(combined_annotations_df, self.target_classes)
             return self.classification_accuracy.compute()
         else:
            return None
 
-    def evaluate_pipeline(self, predictions, ground_truth):
+    def evaluate_pipeline(self):
         """
         Evaluate pipeline performance for both detection and classification
-
-        Args:
-            predictions: List of dictionaries containing predicted boxes and classes
-                Each dict should have 'bbox' (x,y,w,h) and 'class_label'
-            ground_truth: List of dictionaries containing ground truth annotations
-                Each dict should have 'bbox' (x,y,w,h) and 'class_label'
         """
         detection_results = self.evaluate_detection()
-        classification_results = self.classification_accuracy()
+        confident_classification_results = self.confident_classification_accuracy()
+        uncertain_classification_results = self.uncertain_classification_accuracy()
 
         results = {
-            'detection': {
-                'precision': detection_results["precision"],
-                'recall': detection_results["recall"],
-                'f1_score': detection_results["f1_score"],
-                'total_predictions': detection_results["total_predictions"],
-                'total_ground_truth': detection_results["total_ground_truth"],
-                'true_positives': detection_results["true_positives"]
-            },
-            'classification': {
-                'accuracy': classification_results["accuracy"],
-                'correct_classifications': classification_results["correct_classifications"],
-                'total_correct_detections': classification_results["total_correct_detections"]
-            }
+            'detection': detection_results,
+            'confident_classification': confident_classification_results,
+            'uncertain_classification': uncertain_classification_results
         }
 
         return results
```
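Taken together: `evaluate_detection` now forwards the tiling settings to `model.predict`, and `evaluate_pipeline` reports one entry per input stream instead of hand-assembled sub-dicts. A sketch of wiring the evaluator to the new config; the import path, the `run_evaluation` helper, and the shape of `model` are assumptions inferred from the diff:

```python
from omegaconf import OmegaConf
from src.pipeline_evaluation import PipelineEvaluation  # import path assumed

def run_evaluation(model):
    """Evaluate a trained detector against all three ground-truth streams.

    `model` is assumed to expose predict(paths, patch_size=..., patch_overlap=...,
    min_score=...), as implied by evaluate_detection above.
    """
    cfg = OmegaConf.load("conf/config.yaml")  # illustrative loading call
    evaluator = PipelineEvaluation(
        model=model,
        detect_ground_truth_dir=cfg.pipeline_evaluation.detect_ground_truth_dir,
        classify_confident_ground_truth_dir=cfg.pipeline_evaluation.classify_confident_ground_truth_dir,
        classify_uncertain_ground_truth_dir=cfg.pipeline_evaluation.classify_uncertain_ground_truth_dir,
        patch_size=cfg.predict.patch_size,
        patch_overlap=cfg.predict.patch_overlap,
        min_score=cfg.predict.min_score,
    )
    # One result entry per stream, replacing the old hand-assembled sub-dicts.
    return evaluator.evaluate_pipeline()
```

Note that `results["detection"]` is now whatever torchmetrics' `MeanAveragePrecision.compute()` returns, so downstream consumers should expect torchmetrics key names such as `map`, `map_50`, and `mar_100` rather than the removed hand-built `precision`/`recall`/`f1_score` keys.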
tests/conftest.py (10 additions, 3 deletions)
```diff
@@ -62,12 +62,19 @@ def config(tmpdir_factory):
     cfg.train.checkpoint_dir = tmpdir_factory.mktemp("checkpoints").strpath
 
     # Create detection annotations
-    cfg.pipeline_evaluation.detection_annotations_dir = tmpdir_factory.mktemp("detection_annotations").strpath
-    csv_path = os.path.join(cfg.pipeline_evaluation.detection_annotations_dir, 'detection_annotations.csv')
+    cfg.pipeline_evaluation.detect_ground_truth_dir = tmpdir_factory.mktemp("detection_annotations").strpath
+    csv_path = os.path.join(cfg.pipeline_evaluation.detect_ground_truth_dir, 'detection_annotations.csv')
     df.to_csv(csv_path, index=False)
 
     # Create classification annotations
-    cfg.pipeline_evaluation.classification_annotations_dir = tmpdir_factory.mktemp("classification_annotations").strpath
+    cfg.pipeline_evaluation.classify_confident_ground_truth_dir = tmpdir_factory.mktemp("confident_classification_annotations").strpath
+    csv_path = os.path.join(cfg.pipeline_evaluation.classify_confident_ground_truth_dir, 'confident_classification_annotations.csv')
+    df.to_csv(csv_path, index=False)
+
+    cfg.pipeline_evaluation.classify_uncertain_ground_truth_dir = tmpdir_factory.mktemp("uncertain_classification_annotations").strpath
+    csv_path = os.path.join(cfg.pipeline_evaluation.classify_uncertain_ground_truth_dir, 'uncertain_classification_annotations.csv')
+    df.to_csv(csv_path, index=False)
 
     return cfg
 
 @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
```
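The fixture now provisions all three ground-truth directories with the same DataFrame, so any test can exercise the three-stream constructor. A hypothetical test consuming it; the `model` fixture, the test name, and the import path are stand-ins for whatever the suite actually provides:

```python
from src.pipeline_evaluation import PipelineEvaluation  # import path assumed

def test_ground_truth_streams_load(config, model):
    """All three ground-truth directories provisioned by the fixture should load."""
    evaluator = PipelineEvaluation(
        model=model,
        detect_ground_truth_dir=config.pipeline_evaluation.detect_ground_truth_dir,
        classify_confident_ground_truth_dir=config.pipeline_evaluation.classify_confident_ground_truth_dir,
        classify_uncertain_ground_truth_dir=config.pipeline_evaluation.classify_uncertain_ground_truth_dir,
    )
    # gather_data returns a DataFrame per stream; each should be non-empty.
    assert not evaluator.detection_annotations_df.empty
    assert not evaluator.classification_confident_annotations_df.empty
    assert not evaluator.classification_uncertain_annotations_df.empty
```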
