From b953b7ebd91d94ab655a435d192c5bcce3fd6a0c Mon Sep 17 00:00:00 2001 From: bw4sz Date: Thu, 21 Nov 2024 11:30:55 -0800 Subject: [PATCH] classification tests pass --- src/classification.py | 99 +++++----------- src/reporting.py | 10 +- src/visualization.py | 229 ++++++++++++++++++++++++++++++++++++ tests/conftest.py | 7 +- tests/test_visualization.py | 122 +++++++++++++++++++ 5 files changed, 394 insertions(+), 73 deletions(-) create mode 100644 src/visualization.py create mode 100644 tests/test_visualization.py diff --git a/src/classification.py b/src/classification.py index d7bb15f..a6b59a8 100644 --- a/src/classification.py +++ b/src/classification.py @@ -8,11 +8,7 @@ # Local imports from src.label_studio import gather_data - -import tempfile -from omegaconf import OmegaConf from pytorch_lightning.loggers import CometLogger -from deepforest import visualize def create_train_test(annotations): @@ -29,85 +25,43 @@ def get_latest_checkpoint(checkpoint_dir, annotations): m = CropModel.load_from_checkpoint(checkpoint) else: warnings.warn("No checkpoints found in {}".format(checkpoint_dir)) - label_dict = {value: index for index, value in enumerate(annotations.label.unique())} - m = CropModel() + m = CropModel(num_classes=len(annotations["label"].unique())) else: os.makedirs(checkpoint_dir) - m = CropModel() + m = CropModel(num_classes=len(annotations["label"].unique())) return m -def train(model, train_annotations, test_annotations, train_image_dir, comet_project=None, comet_workspace=None, config_args=None): +def train(model, train_dir, val_dir, comet_project=None, comet_workspace=None, fast_dev_run=False): """Train a model on labeled images. Args: - image_paths (list): A list of image paths. - train_annotations (pd.DataFrame): A DataFrame containing annotations. - test_annotations (pd.DataFrame): A DataFrame containing annotations. - train_image_dir (str): The directory containing the training images. + model (CropModel): A CropModel object. + train_dir (str): The directory containing the training images. + val_dir (str): The directory containing the validation images. comet_project (str): The comet project name for logging. Defaults to None. comet_workspace (str): The comet workspace for logging. Defaults to None. - config_args (dict): A dictionary of configuration arguments to update the model.config. Defaults to None. - + Returns: main.deepforest: A trained deepforest model. 
""" - tmpdir = tempfile.gettempdir() - - train_annotations.to_csv(os.path.join(tmpdir,"train.csv"), index=False) - - # Set config - model.config["train"]["csv_file"] = os.path.join(tmpdir,"train.csv") - model.config["train"]["root_dir"] = train_image_dir - - # Loop through all keys in model.config and set them to the value of the key in model.config - config_args = OmegaConf.to_container(config_args) - if config_args: - for key, value in config_args.items(): - if isinstance(value, dict): - for subkey, subvalue in value.items(): - model.config[key][subkey] = subvalue - else: - model.config[key] = value - + # Update if comet_project: comet_logger = CometLogger(project_name=comet_project, workspace=comet_workspace) - comet_logger.experiment.log_parameters(model.config) - comet_logger.experiment.log_table("train.csv", train_annotations) - comet_logger.experiment.log_table("test.csv", test_annotations) - - model.create_trainer(logger=comet_logger) + model.create_trainer(logger=comet_logger, fast_dev_run=fast_dev_run) else: - model.create_trainer() - - with comet_logger.experiment.context_manager("train_images"): - non_empty_train_annotations = train_annotations[~(train_annotations.xmax==0)] - if non_empty_train_annotations.empty: - pass - else: - sample_train_annotations = non_empty_train_annotations[non_empty_train_annotations.image_path.isin(non_empty_train_annotations.image_path.head(5))] - for filename in sample_train_annotations.image_path: - sample_train_annotations_for_image = sample_train_annotations[sample_train_annotations.image_path == filename] - sample_train_annotations_for_image.root_dir = train_image_dir - visualize.plot_results(sample_train_annotations_for_image, savedir=tmpdir) - comet_logger.experiment.log_image(os.path.join(tmpdir, filename)) + model.create_trainer(fast_dev_run=fast_dev_run) + # Get the data stored from the write_crops step above. + model.load_from_disk(train_dir=train_dir, val_dir=val_dir) model.trainer.fit(model) - with comet_logger.experiment.context_manager("post-training prediction"): - for image_path in test_annotations.image_path.head(5): - prediction = model.predict_image(path = os.path.join(train_image_dir, image_path)) - if prediction is None: - continue - visualize.plot_results(prediction, savedir=tmpdir) - comet_logger.experiment.log_image(os.path.join(tmpdir, image_path)) - return model def preprocess_images(model, annotations, root_dir, save_dir): - for image_path in annotations["image_path"]: - boxes = annotations[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist() - image_path = os.path.join(root_dir, image_path) - model.write_crops(boxes=boxes, labels=annotations.label.values, image_path=image_path, savedir=save_dir) + boxes = annotations[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist() + images = annotations["image_path"].values + labels = annotations["label"].values + model.write_crops(boxes=boxes, root_dir=root_dir, images=images, labels=labels, savedir=save_dir) def preprocess_and_train_classification(config, validation_df=None): """Preprocess data and train a crop model. 
@@ -131,7 +85,7 @@ def preprocess_and_train_classification(config, validation_df=None):
 
     # Load existing model
     if config.classification_model.checkpoint:
-        loaded_model = CropModel(config.classification_model.checkpoint)
+        loaded_model = CropModel.load_from_checkpoint(config.classification_model.checkpoint, num_classes=len(train_df["label"].unique()))
 
     elif os.path.exists(config.classification_model.checkpoint_dir):
         loaded_model = get_latest_checkpoint(
@@ -141,16 +95,23 @@ def preprocess_and_train_classification(config, validation_df=None):
 
     # Preprocess train and validation data
-    preprocess_images(model=loaded_model, annotations=train_df, root_dir=config.classification_model.train_image_dir, save_dir=config.classification_model.crop_image_dir)
-    preprocess_images(model=loaded_model, annotations=validation_df, root_dir=config.classification_model.train_image_dir, save_dir=config.classification_model.crop_image_dir)
+    preprocess_images(
+        model=loaded_model,
+        annotations=train_df,
+        root_dir=config.classification_model.train_image_dir,
+        save_dir=config.classification_model.crop_image_dir)
+    preprocess_images(
+        model=loaded_model,
+        annotations=validation_df,
+        root_dir=config.classification_model.train_image_dir,
+        save_dir=config.classification_model.crop_image_dir)
 
     trained_model = train(
-        train_annotations=train_df,
-        test_annotations=validation_df,
-        train_image_dir=config.classification_model.crop_image_dir,
+        train_dir=config.classification_model.crop_image_dir,
+        val_dir=config.classification_model.crop_image_dir,
         model=loaded_model,
         comet_project=config.comet.project,
         comet_workspace=config.comet.workspace,
-        config_args=config.deepforest)
+        fast_dev_run=config.classification_model.fast_dev_run)
 
     return trained_model
 
diff --git a/src/reporting.py b/src/reporting.py
index aa46e50..74a432e 100644
--- a/src/reporting.py
+++ b/src/reporting.py
@@ -1,7 +1,7 @@
 import pandas as pd
 import os
 from datetime import datetime
-
+from src.visualization import PredictionVisualizer
 class Reporting:
     def __init__(self, report_dir, pipeline_monitor):
         """Initialize reporting class"""
@@ -17,6 +17,14 @@ def get_coco_datasets(self):
         """Get coco datasets"""
         self.pipeline_monitor.mAP.get_coco_datasets()
 
+    def generate_video(self):
+        """Generate a video from the predictions"""
+        visualizer = PredictionVisualizer()
+        visualizer.create_video(
+            predictions_list=self.pipeline_monitor.predictions,
+            output_path=f"{self.report_dir}/predictions.mp4"
+        )
+
     def write_metrics(self):
         """Write metrics to a csv file
 
diff --git a/src/visualization.py b/src/visualization.py
new file mode 100644
index 0000000..4e96c73
--- /dev/null
+++ b/src/visualization.py
@@ -0,0 +1,229 @@
+import cv2
+import numpy as np
+import pandas as pd
+from pathlib import Path
+from typing import List, Optional, Tuple
+import os
+from deepforest.model import CropModel
+from tqdm import tqdm
+
+class PredictionVisualizer:
+    def __init__(
+        self,
+        model: CropModel,
+        output_dir: str,
+        fps: int = 30,
+        frame_size: Tuple[int, int] = (1920, 1080),
+        thin_factor: int = 10
+    ):
+        """
+        Initialize the prediction visualizer.
+ + Args: + model: Trained CropModel instance + output_dir: Directory to save visualization outputs + fps: Frames per second for output video + frame_size: Output video frame size (width, height) + thin_factor: Take every nth image from sorted list + """ + self.model = model + self.output_dir = Path(output_dir) + self.fps = fps + self.frame_size = frame_size + self.thin_factor = thin_factor + + # Create output directory + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Define colors for different classes + self.colors = { + 'Bird': (0, 255, 0), # Green + 'Empty': (128, 128, 128) # Gray + } + + def draw_predictions( + self, + image: np.ndarray, + predictions: pd.DataFrame, + confidence_threshold: float = 0.5 + ) -> np.ndarray: + """ + Draw bounding boxes and labels on image. + + Args: + image: Input image as numpy array + predictions: DataFrame with predictions + confidence_threshold: Minimum confidence to show prediction + + Returns: + Image with drawn predictions + """ + img_with_boxes = image.copy() + + # Filter predictions by confidence + confident_preds = predictions[predictions['score'] >= confidence_threshold] + + for _, pred in confident_preds.iterrows(): + # Get coordinates + xmin, ymin = int(pred['xmin']), int(pred['ymin']) + xmax, ymax = int(pred['xmax']), int(pred['ymax']) + + # Get color for class + color = self.colors.get(pred['label'], (255, 255, 255)) + + # Draw box + cv2.rectangle(img_with_boxes, (xmin, ymin), (xmax, ymax), color, 2) + + # Draw label with confidence + label = f"{pred['label']}: {pred['score']:.2f}" + cv2.putText( + img_with_boxes, + label, + (xmin, ymin - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 2 + ) + + return img_with_boxes + + def create_visualization( + self, + image_dir: str, + output_name: str = "predictions.mp4", + confidence_threshold: float = 0.5 + ) -> str: + """ + Create video visualization of predictions on image sequence. 
+ + Args: + image_dir: Directory containing images + output_name: Name of output video file + confidence_threshold: Minimum confidence to show prediction + + Returns: + Path to output video file + """ + # Get sorted list of images + image_paths = sorted([ + f for f in os.listdir(image_dir) + if f.lower().endswith(('.png', '.jpg', '.jpeg')) + ]) + + # Thin the image list + image_paths = image_paths[::self.thin_factor] + + if not image_paths: + raise ValueError(f"No images found in {image_dir}") + + # Create video writer + output_path = str(self.output_dir / output_name) + first_image = cv2.imread(os.path.join(image_dir, image_paths[0])) + if first_image is None: + raise ValueError(f"Could not read first image: {image_paths[0]}") + + height, width = first_image.shape[:2] + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_writer = cv2.VideoWriter( + output_path, + fourcc, + self.fps, + (width, height) + ) + + try: + # Process each image + for img_name in tqdm(image_paths[:1000], desc="Creating visualization"): + img_path = os.path.join(image_dir, img_name) + image = cv2.imread(img_path) + + if image is None: + print(f"Warning: Could not read image {img_path}") + continue + + # Get predictions + predictions = self.model.predict_image(img_path) + + # Draw predictions + annotated_image = self.draw_predictions( + image, + predictions, + confidence_threshold + ) + + # Write frame + video_writer.write(annotated_image) + + finally: + video_writer.release() + + return output_path + + def create_summary_image( + self, + predictions_list: List[pd.DataFrame], + image_size: Tuple[int, int] = (800, 600) + ) -> np.ndarray: + """ + Create a summary image showing prediction statistics. + + Args: + predictions_list: List of prediction DataFrames + image_size: Size of output image + + Returns: + Summary image as numpy array + """ + # Create blank image + summary = np.ones((image_size[1], image_size[0], 3), dtype=np.uint8) * 255 + + # Compile statistics + total_predictions = sum(len(preds) for preds in predictions_list) + class_counts = {} + confidence_scores = [] + + for preds in predictions_list: + for _, pred in preds.iterrows(): + class_counts[pred['label']] = class_counts.get(pred['label'], 0) + 1 + confidence_scores.append(pred['score']) + + # Draw statistics + y_pos = 30 + cv2.putText( + summary, + f"Total Predictions: {total_predictions}", + (20, y_pos), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2 + ) + + y_pos += 40 + for label, count in class_counts.items(): + cv2.putText( + summary, + f"{label}: {count}", + (20, y_pos), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + self.colors.get(label, (0, 0, 0)), + 2 + ) + y_pos += 40 + + if confidence_scores: + avg_confidence = np.mean(confidence_scores) + cv2.putText( + summary, + f"Average Confidence: {avg_confidence:.2f}", + (20, y_pos), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2 + ) + + return summary \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index c78333f..fdd344e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,14 +35,14 @@ def config(tmpdir_factory): for f in os.listdir("tests/data/"): if f != '.DS_Store': shutil.copy("tests/data/" + f, cfg.detection_model.train_image_dir) - + shutil.copy("tests/data/" + f, cfg.classification_model.train_image_dir) # Create sample bounding box annotations train_data = { 'image_path': ['empty.jpg', 'birds.jpg', "birds.jpg"], 'xmin': [0, 200, 150], 'ymin': [0, 300, 250], - 'xmax': [0, 300, 250], - 'ymax': [0, 400, 350], + 'xmax': [20, 300, 250], + 'ymax': [20, 400, 350], 
'label': ['Bird', 'Bird1', 'Bird2'], 'annotator': ['test_user', 'test_user', 'test_user'] } @@ -83,6 +83,7 @@ def config(tmpdir_factory): cfg.detection_model.validation_csv_path = val_csv_path cfg.detection_model.fast_dev_run = True + cfg.classification_model.fast_dev_run = True cfg.detection_model.checkpoint = "bird" cfg.detection_model.checkpoint_dir = tmpdir_factory.mktemp( "checkpoints").strpath diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..27e53f2 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,122 @@ +import pytest +import numpy as np +import pandas as pd +import cv2 +from pathlib import Path +from src.visualization import PredictionVisualizer + +@pytest.fixture +def mock_model(): + """Create a mock model that returns predictable predictions.""" + class MockModel: + def predict_image(self, image_path): + return pd.DataFrame({ + 'xmin': [100], + 'ymin': [100], + 'xmax': [200], + 'ymax': [200], + 'label': ['Bird'], + 'score': [0.9] + }) + return MockModel() + +@pytest.fixture +def test_image(): + """Create a test image.""" + return np.ones((600, 800, 3), dtype=np.uint8) * 255 + +@pytest.fixture +def test_predictions(): + """Create test predictions.""" + return pd.DataFrame({ + 'xmin': [100, 200], + 'ymin': [100, 200], + 'xmax': [200, 300], + 'ymax': [200, 300], + 'label': ['Bird', 'Bird'], + 'score': [0.9, 0.8] + }) + +def test_visualizer_initialization(mock_model, tmp_path): + """Test PredictionVisualizer initialization.""" + visualizer = PredictionVisualizer(mock_model, tmp_path) + assert visualizer.model == mock_model + assert visualizer.output_dir == tmp_path + assert visualizer.fps == 30 + +def test_draw_predictions(mock_model, tmp_path, test_image, test_predictions): + """Test drawing predictions on image.""" + visualizer = PredictionVisualizer(mock_model, tmp_path) + result = visualizer.draw_predictions(test_image, test_predictions) + + assert isinstance(result, np.ndarray) + assert result.shape == test_image.shape + # Image should be different from original due to drawn boxes + assert not np.array_equal(result, test_image) + +def test_create_visualization(mock_model, tmp_path): + """Test video creation from image sequence.""" + # Create test images + image_dir = tmp_path / "images" + image_dir.mkdir() + + for i in range(5): + img = np.ones((600, 800, 3), dtype=np.uint8) * 255 + cv2.imwrite(str(image_dir / f"image_{i:03d}.jpg"), img) + + visualizer = PredictionVisualizer(mock_model, tmp_path) + output_path = visualizer.create_visualization(str(image_dir)) + + assert Path(output_path).exists() + assert output_path.endswith('.mp4') + +def test_create_summary_image(mock_model, tmp_path): + """Test creation of summary statistics image.""" + visualizer = PredictionVisualizer(mock_model, tmp_path) + + predictions_list = [ + pd.DataFrame({ + 'label': ['Bird', 'Bird'], + 'score': [0.9, 0.8] + }), + pd.DataFrame({ + 'label': ['Bird'], + 'score': [0.95] + }) + ] + + summary = visualizer.create_summary_image(predictions_list) + assert isinstance(summary, np.ndarray) + assert summary.shape == (600, 800, 3) + +def test_empty_image_dir(mock_model, tmp_path): + """Test handling of empty image directory.""" + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + + visualizer = PredictionVisualizer(mock_model, tmp_path) + with pytest.raises(ValueError, match="No images found"): + visualizer.create_visualization(str(empty_dir)) + +def test_invalid_image(mock_model, tmp_path): + """Test handling of invalid image file.""" + 
image_dir = tmp_path / "images" + image_dir.mkdir() + + # Create invalid image file + (image_dir / "invalid.jpg").write_text("not an image") + + visualizer = PredictionVisualizer(mock_model, tmp_path) + with pytest.raises(ValueError, match="Could not read first image"): + visualizer.create_visualization(str(image_dir)) + +@pytest.mark.parametrize("confidence_threshold", [0.3, 0.7, 0.9]) +def test_confidence_thresholds(mock_model, tmp_path, test_image, test_predictions, confidence_threshold): + """Test different confidence thresholds.""" + visualizer = PredictionVisualizer(mock_model, tmp_path) + result = visualizer.draw_predictions( + test_image, + test_predictions, + confidence_threshold=confidence_threshold + ) + assert isinstance(result, np.ndarray) \ No newline at end of file
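
For context, a minimal sketch of the updated classification training flow introduced above (the CSV and directory paths are placeholders, not part of the patch):

    import pandas as pd
    from deepforest.model import CropModel
    from src.classification import preprocess_images, train

    # Placeholder annotations CSV with image_path, xmin, ymin, xmax, ymax, label
    # columns (the same schema used in tests/conftest.py).
    annotations = pd.read_csv("annotations.csv")

    # One class per unique label, as in get_latest_checkpoint().
    model = CropModel(num_classes=len(annotations["label"].unique()))

    # Write one crop per box to disk, then train on the crop directory.
    preprocess_images(model=model, annotations=annotations,
                      root_dir="images/", save_dir="crops/")
    model = train(model=model, train_dir="crops/", val_dir="crops/", fast_dev_run=True)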
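
Likewise, a minimal sketch of driving the new PredictionVisualizer outside the test suite (the checkpoint and directory paths are placeholders):

    from deepforest.model import CropModel
    from src.visualization import PredictionVisualizer

    # Placeholder checkpoint for a trained classification model.
    model = CropModel.load_from_checkpoint("checkpoints/classification.ckpt")

    visualizer = PredictionVisualizer(
        model=model,
        output_dir="outputs/visualizations",  # created if it does not exist
        fps=30,
        thin_factor=10,  # keep every 10th image from the sorted listing
    )

    # Writes an annotated MP4 and returns its path.
    video_path = visualizer.create_visualization(
        image_dir="data/flight_images",  # placeholder image directory
        output_name="predictions.mp4",
        confidence_threshold=0.5,
    )
    print(video_path)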