From 93bdbf895de1bed283ba55d9d508f4406d41e03f Mon Sep 17 00:00:00 2001 From: Advitya Gemawat Date: Mon, 21 Aug 2023 10:56:07 -0400 Subject: [PATCH] Updated return type logic for OD detections (#147) * Added logic to return list * lint fixes --- python/ml_wrappers/model/image_model_wrapper.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/ml_wrappers/model/image_model_wrapper.py b/python/ml_wrappers/model/image_model_wrapper.py index 75f6e43..de31767 100644 --- a/python/ml_wrappers/model/image_model_wrapper.py +++ b/python/ml_wrappers/model/image_model_wrapper.py @@ -314,7 +314,7 @@ def _get_device(device: str) -> str: :rtype: str """ if (device in [member.value for member in Device] - or type(device) == int + or type(device) is int or device.isdigit() or device is None): if device == Device.AUTO.value: @@ -525,7 +525,7 @@ def predict(self, x, iou_threshold: float = 0.5, score_threshold: float = 0.5): """ detections = [] for image in x: - if type(image) == Tensor: + if type(image) is Tensor: raw_detections = self._model( image.to(self._device).unsqueeze(0)) else: @@ -543,7 +543,10 @@ def predict(self, x, iou_threshold: float = 0.5, score_threshold: float = 0.5): detections.append(image_predictions.detach().cpu().numpy() .tolist()) - return np.array(detections) + try: + return np.array(detections) + except ValueError: + return detections def predict_proba(self, dataset, iou_threshold=0.1): """Predict the output probability using the wrapped model. @@ -657,7 +660,10 @@ def predict(self, dataset: pd.DataFrame, iou_threshold: float = 0.5, detections.append( image_predictions.detach().cpu().numpy().tolist()) - return np.array(detections) + try: + return np.array(detections) + except ValueError: + return detections def predict_proba(self, dataset: pd.DataFrame, iou_threshold=0.1) -> np.ndarray: