Skip to content

Commit

Permalink
fix: check that prediction_proba is not None before checking its type
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzodagostinoradicalbit committed Jul 1, 2024
1 parent 7b33b62 commit 5508730
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def validate_outputs(self) -> Self:
raise ValueError(
f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
if not is_none(self.outputs.prediction_proba) and not is_optional_float(
self.outputs.prediction_proba.type
):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]'
)
Expand All @@ -109,7 +111,9 @@ def validate_outputs(self) -> Self:
raise ValueError(
f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
if not is_none(self.outputs.prediction_proba) and not is_optional_float(
self.outputs.prediction_proba.type
):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]'
)
Expand All @@ -119,7 +123,7 @@ def validate_outputs(self) -> Self:
raise ValueError(
f'prediction must be a number for a ModelType.REGRESSION, has been provided [{self.outputs.prediction}]'
)
if not is_none(self.outputs.prediction_proba.type):
if not is_none(self.outputs.prediction_proba) and not is_none(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]'
)
Expand Down

0 comments on commit 5508730

Please sign in to comment.