diff --git a/api/app/models/model_dto.py b/api/app/models/model_dto.py index f76fdbd6..60a9e441 100644 --- a/api/app/models/model_dto.py +++ b/api/app/models/model_dto.py @@ -72,19 +72,19 @@ def validate_target(self) -> Self: case ModelType.BINARY: if not is_number(self.target.type): raise ValueError( - f'target must be a number for a ModelType.BINARY, has been provided [{self.target}]' + f'target must be a number for a {checked_model_type}, has been provided [{self.target}]' ) return self case ModelType.MULTI_CLASS: if not is_number_or_string(self.target.type): raise ValueError( - f'target must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.target}]' + f'target must be a number or string for a {checked_model_type}, has been provided [{self.target}]' ) return self case ModelType.REGRESSION: if not is_number(self.target.type): raise ValueError( - f'target must be a number for a ModelType.REGRESSION, has been provided [{self.target}]' + f'target must be a number for a {checked_model_type}, has been provided [{self.target}]' ) return self case _: @@ -97,31 +97,37 @@ def validate_outputs(self) -> Self: case ModelType.BINARY: if not is_number(self.outputs.prediction.type): raise ValueError( - f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]' + f'prediction must be a number for a {checked_model_type}, has been provided [{self.outputs.prediction}]' ) - if not is_optional_float(self.outputs.prediction_proba.type): + if not is_none(self.outputs.prediction_proba) and not is_optional_float( + self.outputs.prediction_proba.type + ): raise ValueError( - f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]' + f'prediction_proba must be an optional float for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self case ModelType.MULTI_CLASS: if not is_number_or_string(self.outputs.prediction.type): raise ValueError( - f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]' + f'prediction must be a number or string for a {checked_model_type}, has been provided [{self.outputs.prediction}]' ) - if not is_optional_float(self.outputs.prediction_proba.type): + if not is_none(self.outputs.prediction_proba) and not is_optional_float( + self.outputs.prediction_proba.type + ): raise ValueError( - f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]' + f'prediction_proba must be an optional float for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self case ModelType.REGRESSION: if not is_number(self.outputs.prediction.type): raise ValueError( - f'prediction must be a number for a ModelType.REGRESSION, has been provided [{self.outputs.prediction}]' + f'prediction must be a number for a {checked_model_type}, has been provided [{self.outputs.prediction}]' ) - if not is_none(self.outputs.prediction_proba.type): + if not is_none(self.outputs.prediction_proba) and not is_none( + self.outputs.prediction_proba.type + ): raise ValueError( - f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]' + f'prediction_proba must be None for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self case _: