Skip to content

Commit

Permalink
condensce code
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Nov 22, 2024
1 parent 3f43d49 commit c05ee2c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 22 deletions.
20 changes: 20 additions & 0 deletions src/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,23 @@ def update_sys_path():
raise ValueError("Invalid strategy. Must be one of 'random', 'most-detections', or 'target-labels'.")

return chosen_images

def predict_and_divide(trained_detection_model, trained_classification_model, image_paths, patch_size, patch_overlap, confident_threshold):
predictions = detection.predict(
model=trained_detection_model,
crop_model=trained_classification_model,
image_paths=image_paths,
patch_size=patch_size,
patch_overlap=patch_overlap,
)
combined_predictions = pd.concat(predictions)

# Split predictions into confident and uncertain
uncertain_predictions = combined_predictions[
combined_predictions["score"] <= confident_threshold]

confident_predictions = combined_predictions[
~combined_predictions["image_path"].isin(
uncertain_predictions["image_path"])]

return confident_predictions, uncertain_predictions
45 changes: 24 additions & 21 deletions src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import pandas as pd
from omegaconf import DictConfig

from src.active_learning import choose_train_images, choose_test_images
from src.active_learning import choose_train_images, choose_test_images, predict_and_divide
from src import propagate
from src import label_studio
from src.classification import preprocess_and_train_classification
from src.data_processing import density_cropping
from src.detection import preprocess_and_train, predict
from src.detection import preprocess_and_train
from src.pipeline_evaluation import PipelineEvaluation
from src.reporting import Reporting

Expand Down Expand Up @@ -77,32 +77,35 @@ def run(self):
return None
else:
train_images_to_annotate = choose_train_images(
performance, trained_detection_model,
**self.config.active_learning)
test_images_to_annotate = choose_test_images(
performance, **self.config.active_testing)

predictions = predict(
m=trained_detection_model,
crop_model=trained_classification_model,
image_paths=train_images_to_annotate,
evaluation=performance,
image_dir=self.config.active_learning.image_dir,
model=trained_detection_model,
strategy=self.config.active_learning.strategy,
n=self.config.active_learning.n_images,
patch_size=self.config.active_learning.patch_size,
patch_overlap=self.config.active_learning.patch_overlap,
min_score=self.config.active_learning.min_score
)
combined_predictions = pd.concat(predictions)

# Split predictions into confident and uncertain
uncertain_predictions = combined_predictions[
combined_predictions["score"] <=
self.config.active_learning.confident_threshold]

confident_predictions = combined_predictions[
~combined_predictions["image_path"].isin(
uncertain_predictions["image_path"])]
test_images_to_annotate = choose_test_images(
image_dir=self.config.active_testing.image_dir,
model=trained_detection_model,
strategy=self.config.active_testing.strategy,
n=self.config.active_testing.n_images,
patch_size=self.config.active_testing.patch_size,
patch_overlap=self.config.active_testing.patch_overlap,
min_score=self.config.active_testing.min_score)


confident_predictions, uncertain_predictions = predict_and_divide(
trained_detection_model, trained_classification_model,
train_images_to_annotate, self.config.active_learning.patch_size,
self.config.active_learning.patch_overlap,
self.config.active_learning.confident_threshold)

reporter.confident_predictions = confident_predictions
reporter.uncertain_predictions = uncertain_predictions

print(f"Images requiring human review: {len(confident_predictions)}")
print(f"Images auto-annotated: {len(uncertain_predictions)}")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def test_choose_test_images(detection_model, config):
n=config.active_testing.n_images,
patch_size=config.active_testing.patch_size,
patch_overlap=config.active_testing.patch_overlap,
min_score=config.active_testing.min_score )
min_score=config.active_testing.min_score)
assert len(test_images_to_annotate) > 0

0 comments on commit c05ee2c

Please sign in to comment.