
Commit

classification tests pass
bw4sz committed Nov 21, 2024
1 parent b365be9 · commit b953b7e
Showing 5 changed files with 394 additions and 73 deletions.
99 changes: 30 additions & 69 deletions src/classification.py
@@ -8,11 +8,7 @@

# Local imports
from src.label_studio import gather_data

-import tempfile
-from omegaconf import OmegaConf
-from pytorch_lightning.loggers import CometLogger
-from deepforest import visualize


def create_train_test(annotations):
@@ -29,85 +25,43 @@ def get_latest_checkpoint(checkpoint_dir, annotations):
m = CropModel.load_from_checkpoint(checkpoint)
else:
warnings.warn("No checkpoints found in {}".format(checkpoint_dir))
-            label_dict = {value: index for index, value in enumerate(annotations.label.unique())}
-            m = CropModel()
+            m = CropModel(num_classes=len(annotations["label"].unique()))
else:
os.makedirs(checkpoint_dir)
-        m = CropModel()
+        m = CropModel(num_classes=len(annotations["label"].unique()))

return m

-def train(model, train_annotations, test_annotations, train_image_dir, comet_project=None, comet_workspace=None, config_args=None):
+def train(model, train_dir, val_dir, comet_project=None, comet_workspace=None, fast_dev_run=False):
"""Train a model on labeled images.
Args:
-        image_paths (list): A list of image paths.
-        train_annotations (pd.DataFrame): A DataFrame containing annotations.
-        test_annotations (pd.DataFrame): A DataFrame containing annotations.
-        train_image_dir (str): The directory containing the training images.
+        model (CropModel): A CropModel object.
+        train_dir (str): The directory containing the training images.
+        val_dir (str): The directory containing the validation images.
comet_project (str): The comet project name for logging. Defaults to None.
comet_workspace (str): The comet workspace for logging. Defaults to None.
-        config_args (dict): A dictionary of configuration arguments to update the model.config. Defaults to None.
Returns:
main.deepforest: A trained deepforest model.
"""
-    tmpdir = tempfile.gettempdir()
-
-    train_annotations.to_csv(os.path.join(tmpdir,"train.csv"), index=False)
-
-    # Set config
-    model.config["train"]["csv_file"] = os.path.join(tmpdir,"train.csv")
-    model.config["train"]["root_dir"] = train_image_dir
-
-    # Loop through all keys in model.config and set them to the value of the key in model.config
-    config_args = OmegaConf.to_container(config_args)
-    if config_args:
-        for key, value in config_args.items():
-            if isinstance(value, dict):
-                for subkey, subvalue in value.items():
-                    model.config[key][subkey] = subvalue
-            else:
-                model.config[key] = value
-
-    # Update
if comet_project:
comet_logger = CometLogger(project_name=comet_project, workspace=comet_workspace)
comet_logger.experiment.log_parameters(model.config)
comet_logger.experiment.log_table("train.csv", train_annotations)
comet_logger.experiment.log_table("test.csv", test_annotations)

model.create_trainer(logger=comet_logger)
model.create_trainer(logger=comet_logger, fast_dev_run=fast_dev_run)
else:
-        model.create_trainer()

-    with comet_logger.experiment.context_manager("train_images"):
-        non_empty_train_annotations = train_annotations[~(train_annotations.xmax==0)]
-        if non_empty_train_annotations.empty:
-            pass
-        else:
-            sample_train_annotations = non_empty_train_annotations[non_empty_train_annotations.image_path.isin(non_empty_train_annotations.image_path.head(5))]
-            for filename in sample_train_annotations.image_path:
-                sample_train_annotations_for_image = sample_train_annotations[sample_train_annotations.image_path == filename]
-                sample_train_annotations_for_image.root_dir = train_image_dir
-                visualize.plot_results(sample_train_annotations_for_image, savedir=tmpdir)
-                comet_logger.experiment.log_image(os.path.join(tmpdir, filename))
+        model.create_trainer(fast_dev_run=fast_dev_run)

+    # Get the data stored from the write_crops step above.
+    model.load_from_disk(train_dir=train_dir, val_dir=val_dir)
model.trainer.fit(model)

with comet_logger.experiment.context_manager("post-training prediction"):
for image_path in test_annotations.image_path.head(5):
prediction = model.predict_image(path = os.path.join(train_image_dir, image_path))
if prediction is None:
continue
visualize.plot_results(prediction, savedir=tmpdir)
comet_logger.experiment.log_image(os.path.join(tmpdir, image_path))

return model

def preprocess_images(model, annotations, root_dir, save_dir):
-    for image_path in annotations["image_path"]:
-        boxes = annotations[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist()
-        image_path = os.path.join(root_dir, image_path)
-        model.write_crops(boxes=boxes, labels=annotations.label.values, image_path=image_path, savedir=save_dir)
+    boxes = annotations[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist()
+    images = annotations["image_path"].values
+    labels = annotations["label"].values
+    model.write_crops(boxes=boxes, root_dir=root_dir, images=images, labels=labels, savedir=save_dir)

def preprocess_and_train_classification(config, validation_df=None):
"""Preprocess data and train a crop model.
@@ -131,7 +85,7 @@ def preprocess_and_train_classification(config, validation_df=None):

# Load existing model
if config.classification_model.checkpoint:
-        loaded_model = CropModel(config.classification_model.checkpoint)
+        loaded_model = CropModel(config.classification_model.checkpoint, num_classes=len(train_df["label"].unique()))

elif os.path.exists(config.classification_model.checkpoint_dir):
loaded_model = get_latest_checkpoint(
@@ -141,16 +95,23 @@


# Preprocess train and validation data
-    preprocess_images(model=loaded_model, annotations=train_df, root_dir=config.classification_model.train_image_dir, save_dir=config.classification_model.crop_image_dir)
-    preprocess_images(model=loaded_model, annotations=validation_df, root_dir=config.classification_model.train_image_dir, save_dir=config.classification_model.crop_image_dir)
+    preprocess_images(
+        model=loaded_model,
+        annotations=train_df,
+        root_dir=config.classification_model.train_image_dir,
+        save_dir=config.classification_model.crop_image_dir)
+    preprocess_images(
+        model=loaded_model,
+        annotations=validation_df,
+        root_dir=config.classification_model.train_image_dir,
+        save_dir=config.classification_model.crop_image_dir)

trained_model = train(
-        train_annotations=train_df,
-        test_annotations=validation_df,
-        train_image_dir=config.classification_model.crop_image_dir,
+        train_dir=config.classification_model.crop_image_dir,
+        val_dir=config.classification_model.crop_image_dir,
model=loaded_model,
comet_project=config.comet.project,
comet_workspace=config.comet.workspace,
-        config_args=config.deepforest)
+        fast_dev_run=config.classification_model.fast_dev_run)

return trained_model
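
For orientation, a minimal sketch of how the updated entry point might be driven after this commit. It assumes an OmegaConf config exposing only the keys visible in this diff (classification_model.checkpoint, checkpoint_dir, train_image_dir, crop_image_dir, fast_dev_run, and the comet block); the repository's real config also carries keys read by the elided parts of the file, and every path and project name below is a hypothetical placeholder.

# Hypothetical driver script; the config layout beyond the keys used in this
# diff, and all paths/names below, are assumptions rather than repo contents.
from omegaconf import OmegaConf

from src.classification import preprocess_and_train_classification

config = OmegaConf.create({
    "classification_model": {
        "checkpoint": None,                 # or a path to an existing .ckpt
        "checkpoint_dir": "checkpoints/",   # scanned by get_latest_checkpoint
        "train_image_dir": "data/images/",  # full frames referenced by annotations
        "crop_image_dir": "data/crops/",    # where write_crops saves per-box crops
        "fast_dev_run": True,               # single-batch run, handy for tests
    },
    "comet": {"project": "my-project", "workspace": "my-workspace"},
})

trained_model = preprocess_and_train_classification(config)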
10 changes: 9 additions & 1 deletion src/reporting.py
@@ -1,7 +1,7 @@
import pandas as pd
import os
from datetime import datetime

+from src.visualization import PredictionVisualizer
class Reporting:
def __init__(self, report_dir, pipeline_monitor):
"""Initialize reporting class"""
@@ -17,6 +17,14 @@ def get_coco_datasets(self):
"""Get coco datasets"""
self.pipeline_monitor.mAP.get_coco_datasets()

+    def generate_video(self):
+        """Generate a video from the predictions"""
+        visualizer = PredictionVisualizer()
+        visualizer.create_video(
+            predictions_list=self.pipeline_monitor.predictions,
+            output_path=f"{self.report_dir}/predictions.mp4"
+        )

def write_metrics(self):
"""Write metrics to a csv file
