From 8ec4149af0f3ead8ec6b4409f1e38b10a539744e Mon Sep 17 00:00:00 2001 From: Lorenzo D'Agostino <127778257+lorenzodagostinoradicalbit@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:20:51 +0200 Subject: [PATCH] fix: check that prediction_proba is not None before checking its type (#61) * fix: check that prediction_proba is not None before checking its type * fix: using matched value to populate ValueError text * fix: removed ModelType in ValueError text --------- Co-authored-by: lorenzodagostinoradicalbit --- app/models/model_dto.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/app/models/model_dto.py b/app/models/model_dto.py index f76fdbd6..60a9e441 100644 --- a/app/models/model_dto.py +++ b/app/models/model_dto.py @@ -72,19 +72,19 @@ def validate_target(self) -> Self: case ModelType.BINARY: if not is_number(self.target.type): raise ValueError( - f'target must be a number for a ModelType.BINARY, has been provided [{self.target}]' + f'target must be a number for a {checked_model_type}, has been provided [{self.target}]' ) return self case ModelType.MULTI_CLASS: if not is_number_or_string(self.target.type): raise ValueError( - f'target must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.target}]' + f'target must be a number or string for a {checked_model_type}, has been provided [{self.target}]' ) return self case ModelType.REGRESSION: if not is_number(self.target.type): raise ValueError( - f'target must be a number for a ModelType.REGRESSION, has been provided [{self.target}]' + f'target must be a number for a {checked_model_type}, has been provided [{self.target}]' ) return self case _: @@ -97,31 +97,37 @@ def validate_outputs(self) -> Self: case ModelType.BINARY: if not is_number(self.outputs.prediction.type): raise ValueError( - f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]' + f'prediction must be a number for a {checked_model_type}, has been provided [{self.outputs.prediction}]' ) - if not is_optional_float(self.outputs.prediction_proba.type): + if not is_none(self.outputs.prediction_proba) and not is_optional_float( + self.outputs.prediction_proba.type + ): raise ValueError( - f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]' + f'prediction_proba must be an optional float for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self case ModelType.MULTI_CLASS: if not is_number_or_string(self.outputs.prediction.type): raise ValueError( - f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]' + f'prediction must be a number or string for a {checked_model_type}, has been provided [{self.outputs.prediction}]' ) - if not is_optional_float(self.outputs.prediction_proba.type): + if not is_none(self.outputs.prediction_proba) and not is_optional_float( + self.outputs.prediction_proba.type + ): raise ValueError( - f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]' + f'prediction_proba must be an optional float for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self case ModelType.REGRESSION: if not is_number(self.outputs.prediction.type): raise ValueError( - f'prediction must be a number for a ModelType.REGRESSION, has been provided [{self.outputs.prediction}]' + f'prediction must be a number for a {checked_model_type}, has been provided [{self.outputs.prediction}]' ) - if not is_none(self.outputs.prediction_proba.type): + if not is_none(self.outputs.prediction_proba) and not is_none( + self.outputs.prediction_proba.type + ): raise ValueError( - f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]' + f'prediction_proba must be None for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self case _: