Skip to content

Commit

Permalink
modify automl threshold params and add gpu optimizations in object
Browse files Browse the repository at this point in the history
detection ml-wrappers
  • Loading branch information
imatiach-msft committed Aug 15, 2023
1 parent 8ecf1c4 commit eb43f57
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 65 deletions.
2 changes: 1 addition & 1 deletion python/docs/object_detection_model_wrapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Example:
self._model = model
self._number_of_classes = number_of_classes
def predict(self, x, iou_thresh: float = 0.5, score_thresh: float = 0.5):
def predict(self, x, iou_threshold: float = 0.5, score_threshold: float = 0.5):
"""Create a list of detection records from the image predictions."""
detections = []
for image in x:
Expand Down
156 changes: 101 additions & 55 deletions python/ml_wrappers/model/image_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
BOXES = 'boxes'
LABELS = 'labels'
SCORES = 'scores'
COLON = ":"


def _is_fastai_model(model):
Expand All @@ -70,17 +71,17 @@ def _is_fastai_model(model):
return str(type(model)).endswith(FASTAI_MODEL_SUFFIX)


def _filter_score(orig_prediction: dict, score_thresh: float = 0.5):
"""Filter out predictions with confidence scores < score_thresh.
def _filter_score(orig_prediction: dict, score_threshold: float = 0.5):
"""Filter out predictions with confidence scores < score_threshold.
:param orig_prediction: Original model prediction
:type orig_prediction: dict
:param score_thresh: Score threshold to filter by
:type score_thresh: float
:return: Model predictions filtered out by score_thresh
:param score_threshold: Score threshold to filter by
:type score_threshold: float
:return: Model predictions filtered out by score_threshold
:rtype: dict
"""
keep = orig_prediction[SCORES] > score_thresh
keep = orig_prediction[SCORES] > score_threshold

filter_prediction = orig_prediction
filter_prediction[BOXES] = filter_prediction[BOXES][keep]
Expand All @@ -89,19 +90,19 @@ def _filter_score(orig_prediction: dict, score_thresh: float = 0.5):
return filter_prediction


def _apply_nms(orig_prediction: dict, iou_thresh: float = 0.5):
def _apply_nms(orig_prediction: dict, iou_threshold: float = 0.5):
"""Perform nms on the predictions based on their IoU.
:param orig_prediction: Original model prediction
:type orig_prediction: dict
:param iou_thresh: iou_threshold for nms
:type iou_thresh: float
:param iou_threshold: iou_threshold for nms
:type iou_threshold: float
:return: Model prediction after nms is applied
:rtype: dict
"""
keep = torchvision.ops.nms(orig_prediction[BOXES],
orig_prediction[SCORES],
iou_thresh)
iou_threshold)

nms_prediction = orig_prediction
nms_prediction[BOXES] = nms_prediction[BOXES][keep]
Expand Down Expand Up @@ -314,13 +315,19 @@ def _get_device(device: str) -> str:
"""
if (device in [member.value for member in Device]
or type(device) == int
or device.isdigit()
or device is None):
if device == Device.AUTO.value:
if torch.cuda.is_available():
return Device.CUDA.value
else:
return Device.CPU.value
return device
elif COLON in device:
split_vals = device.split(COLON)
if len(split_vals) == 2:
return COLON.join([_get_device(
split_val) for split_val in split_vals])
raise ValueError("Selected device is invalid")


Expand Down Expand Up @@ -490,7 +497,7 @@ def __init__(self,
self._model = model
self._number_of_classes = number_of_classes

def predict(self, x, iou_thresh: float = 0.5, score_thresh: float = 0.5):
def predict(self, x, iou_threshold: float = 0.5, score_threshold: float = 0.5):
"""Create a list of detection records from the image predictions.
:param x: Tensor of the image
Expand Down Expand Up @@ -588,8 +595,8 @@ def _mlflow_predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
predictions = self._model.predict(dataset)
return predictions

def predict(self, dataset: pd.DataFrame, iou_thresh: float = 0.5,
score_thresh: float = 0.5):
def predict(self, dataset: pd.DataFrame, iou_threshold: float = 0.5,
score_threshold: float = 0.5):
"""Create a list of detection records from the image predictions.
Below is example Label (y) representation for a cohort of 2 images,
Expand All @@ -609,13 +616,13 @@ def predict(self, dataset: pd.DataFrame, iou_thresh: float = 0.5,
:param dataset: The dataset to predict on.
:type dataset: pandas.DataFrame
:param iou_thresh: Intersection-over-Union (IoU) threshold for NMS (or
:param iou_threshold: Intersection-over-Union (IoU) threshold for NMS (or
the amount of acceptable error). Objects with error
scores higher than the threshold will be removed.
:type iou_thresh: float
:param score_thresh: Threshold to filter detections based on
:type iou_threshold: float
:param score_threshold: Threshold to filter detections based on
predicted confidence scores.
:type score_thresh: float
:type score_threshold: float
:return: Final detections from the object detector
:rtype: numpy array of Detection Records
"""
Expand All @@ -636,8 +643,8 @@ def predict(self, dataset: pd.DataFrame, iou_thresh: float = 0.5,
raw_detections = _process_automl_detections_to_raw_detections(
image_detections, self._label_dict, img_size)

raw_detections = _apply_nms(raw_detections, iou_thresh)
raw_detections = _filter_score(raw_detections, score_thresh)
raw_detections = _apply_nms(raw_detections, iou_threshold)
raw_detections = _filter_score(raw_detections, score_threshold)

image_predictions = torch.cat((
raw_detections["labels"].unsqueeze(1),
Expand All @@ -653,19 +660,19 @@ def predict(self, dataset: pd.DataFrame, iou_thresh: float = 0.5,
return np.array(detections)

def predict_proba(self, dataset: pd.DataFrame,
iou_thresh=0.1) -> np.ndarray:
iou_threshold=0.1) -> np.ndarray:
"""Predict the output probability using the MLflow model.
:param dataset: The dataset to predict_proba on.
:type dataset: pandas.DataFrame
:param iou_thresh: amount of acceptable error.
:param iou_threshold: amount of acceptable error.
objects with error scores higher than the threshold will be removed
:type iou_thresh: float
:type iou_threshold: float
:return: The predicted probabilities.
:rtype: numpy.ndarray
"""

predictions = self.predict(dataset, iou_thresh=iou_thresh)
predictions = self.predict(dataset, iou_threshold=iou_threshold)
prob_scores = [[pred[-1] for pred in image_prediction]
for image_prediction in predictions]
return prob_scores
Expand All @@ -682,59 +689,91 @@ class PytorchDRiseWrapper(GeneralObjectDetectionModelWrapper):
"""

def __init__(self, model, number_of_classes: int,
device=Device.AUTO.value):
device=Device.AUTO.value,
transforms=None,
iou_threshold=None,
score_threshold=None):
"""Initialize the PytorchDRiseWrapper.
:param model: Object detection model
:type model: PytorchFasterRCNN model
:param number_of_classes: Number of classes the model is predicting
:type number_of_classes: int
:param device: optional parameter specifying the device to move the
:param device: Optional parameter specifying the device to move the
model to. If not specified, then "auto" is the default, which selects cuda when available and cpu otherwise
:type device: str
:param transforms: Optional parameter specifying the transforms to
apply to the image before passing it to the model
:type transforms: torchvision.transforms
:param iou_threshold: Optional parameter specifying the iou_threshold
for nms. If not specified, the iou_threshold on the predict
method is used.
:type iou_threshold: float
:param score_threshold: Optional parameter specifying the
score_threshold for filtering detections. If not
specified, the score_threshold on the predict method is
used.
:type score_threshold: float
"""
self.transforms = transforms
self._device = torch.device(_get_device(device))
model.to(self._device)
model.eval()

self._model = model
self._number_of_classes = number_of_classes
self._iou_threshold = iou_threshold
self._score_threshold = score_threshold

def predict(self, x: Tensor):
def predict(self, x: Tensor, iou_threshold: float = 0.5,
score_threshold: float = 0.5):
"""Create a list of detection records from the image predictions.
:param x: Tensor of the image
:type x: torch.Tensor
:param iou_threshold: Intersection-over-Union (IoU) threshold for NMS (or
the amount of acceptable error). Objects with error
scores higher than the threshold will be removed.
:type iou_threshold: float
:param score_threshold: Threshold to filter detections based on
predicted confidence scores.
:type score_threshold: float
:return: Baseline detections to get saliency maps for
:rtype: List of Detection Records
"""
raw_detections = self._model(x)

detections = []
for raw_detection in raw_detections:
raw_detection = _apply_nms(raw_detection, 0.5)

# Note that FasterRCNN doesn't return a score for each class, only
# the predicted class. DRISE requires a score for each class.
# We approximate the score for each class
# by dividing (class score) evenly among the other classes.

raw_detection = _filter_score(raw_detection, 0.5)
expanded_class_scores = expand_class_scores(
raw_detection[SCORES],
raw_detection[LABELS],
self._number_of_classes)

detections.append(
od_common.DetectionRecord(
bounding_boxes=raw_detection[BOXES],
class_scores=expanded_class_scores,
objectness_scores=torch.tensor(
[1.0]*raw_detection[BOXES].shape[0]),
if self._iou_threshold is not None:
iou_threshold = self._iou_threshold
if self._score_threshold is not None:
score_threshold = self._score_threshold
with torch.no_grad():
raw_detections = self._model(x)

detections = []
for raw_detection in raw_detections:
raw_detection = _apply_nms(raw_detection, iou_threshold)

# Note that FasterRCNN doesn't return a score for each
# class, only the predicted class. DRISE requires a
# score for each class.
# We approximate the score for each class
# by dividing (class score) evenly among the other classes.

raw_detection = _filter_score(raw_detection, score_threshold)
expanded_class_scores = expand_class_scores(
raw_detection[SCORES],
raw_detection[LABELS],
self._number_of_classes)

detections.append(
od_common.DetectionRecord(
bounding_boxes=raw_detection[BOXES],
class_scores=expanded_class_scores,
objectness_scores=torch.tensor(
[1.0]*raw_detection[BOXES].shape[0]),
)
)
)

return detections
return detections


class MLflowDRiseWrapper():
Expand Down Expand Up @@ -774,12 +813,19 @@ def _mlflow_predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
predictions = self._model.predict(dataset)
return predictions

def predict(self, dataset: pd.DataFrame, iou_thresh: float = 0.25,
score_thresh: float = 0.5):
def predict(self, dataset: pd.DataFrame, iou_threshold: float = 0.25,
score_threshold: float = 0.5):
"""Predict the output value using the wrapped MLflow model.
:param dataset: The dataset to predict on.
:type dataset: pandas.DataFrame
:param iou_threshold: Intersection-over-Union (IoU) threshold for NMS (or
the amount of acceptable error). Objects with error
scores higher than the threshold will be removed.
:type iou_threshold: float
:param score_threshold: Threshold to filter detections based on
predicted confidence scores.
:type score_threshold: float
:return: The predicted values.
:rtype: numpy.ndarray
"""
Expand Down Expand Up @@ -809,8 +855,8 @@ def predict(self, dataset: pd.DataFrame, iou_thresh: float = 0.25,
detections.append(None)
continue

raw_detections = _apply_nms(raw_detections, iou_thresh)
raw_detections = _filter_score(raw_detections, score_thresh)
raw_detections = _apply_nms(raw_detections, iou_threshold)
raw_detections = _filter_score(raw_detections, score_threshold)

# Note that FasterRCNN doesn't return a score for each class, only
# the predicted class. DRISE requires a score for each class.
Expand Down
39 changes: 32 additions & 7 deletions tests/automl/test_automl_image_object_detection_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import azureml.automl.core.shared.constants as shared_constants
import mlflow
import pytest
import torch
from azureml.automl.dnn.vision.common.mlflow.mlflow_model_wrapper import \
MLFlowImagesModelWrapper
from azureml.automl.dnn.vision.common.model_export_utils import (
Expand All @@ -25,11 +24,19 @@
from common_vision_utils import load_base64_images, load_object_fridge_dataset
from ml_wrappers import wrap_model
from ml_wrappers.common.constants import ModelTask
from ml_wrappers.model.image_model_wrapper import MLflowDRiseWrapper
from ml_wrappers.model.image_model_wrapper import (
MLflowDRiseWrapper, PytorchDRiseWrapper)
from wrapper_validator import (
validate_wrapped_object_detection_mlflow_drise_model,
validate_wrapped_object_detection_custom_model,
validate_wrapped_object_detection_model)

try:
import torch
from torchvision import transforms as T
except ImportError:
print('Could not import torch, required if using a PyTorch model')


@pytest.mark.usefixtures('_clean_dir')
class TestImageModelWrapper(object):
Expand Down Expand Up @@ -143,7 +150,9 @@ def test_wrap_automl_object_detection_model(self):
sys.version_info >= (3, 9),
reason=('azureml-automl-dnn-vision not supported ' +
'for newer versions of python'))
def test_wrap_automl_object_detection_model_drise(self):
@pytest.mark.parametrize('extract_raw_model',
[True, False])
def test_wrap_automl_object_detection_model_drise(self, extract_raw_model):
data = load_object_fridge_dataset()[3:4]
model_name = ModelNames.FASTER_RCNN_RESNET50_FPN

Expand Down Expand Up @@ -220,7 +229,23 @@ def test_wrap_automl_object_detection_model_drise(self):
# load the paths as base64 images
data = load_base64_images(data, return_image_size=True)

wrapped_model = MLflowDRiseWrapper(mlflow_model,
classes=class_names)
validate_wrapped_object_detection_mlflow_drise_model(
wrapped_model, data)
if extract_raw_model:
python_model = mlflow_model._model._model_impl.python_model
automl_wrapper = python_model._model
inner_model = automl_wrapper._model
number_of_classes = automl_wrapper._number_of_classes
transforms = automl_wrapper.get_inference_transform()
wrapped_model = PytorchDRiseWrapper(
inner_model, number_of_classes,
transforms=transforms,
iou_threshold=0.25,
score_threshold=0.5)
validate_wrapped_object_detection_custom_model(
wrapped_model,
T.ToTensor()(data[0]).repeat(2, 1, 1, 1),
has_predict_proba=False)
else:
wrapped_model = MLflowDRiseWrapper(mlflow_model,
classes=class_names)
validate_wrapped_object_detection_mlflow_drise_model(
wrapped_model, data)
7 changes: 5 additions & 2 deletions tests/main/test_image_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
from torchvision import transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
except ImportError:
print('Could not import torchvision, \
required if using a vision PyTorch model')
print('Could not import torch, required if using a PyTorch model')


@pytest.mark.usefixtures('_clean_dir')
Expand Down Expand Up @@ -195,6 +194,10 @@ def test_get_device(self):
# wrap_model invocation in RAIVisionInsights
device = _get_device("auto")
assert device == "cpu" or device == "cuda"
device = _get_device("cuda:1")
assert device == "cuda:1"
device = _get_device("cpu")
assert device == "cpu"

def _set_up_OD_model(self):
"""Returns generic model and dataset for OD testing (FastRCNN)"""
Expand Down

0 comments on commit eb43f57

Please sign in to comment.