-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: add model_type validations (#52)
* chore: add model_type validations * chore: add model_type validations * doc: postgres installation required for local test * chore: add authors.md and its related automation script * chore: add a GitHub Action to automate update-authors.sh executions * chore: refactoring tests
- Loading branch information
Showing
7 changed files
with
306 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import Any, Optional | ||
|
||
from app.models.inferred_schema_dto import SupportedTypes | ||
|
||
|
||
def is_number(value: SupportedTypes) -> bool:
    """Return True when ``value`` is one of the numeric supported types.

    Args:
        value: The column type to check.

    Returns:
        True if ``value`` is ``SupportedTypes.int`` or
        ``SupportedTypes.float``, False otherwise.
    """
    numeric_types = (SupportedTypes.int, SupportedTypes.float)
    return value in numeric_types
|
||
|
||
def is_number_or_string(value: SupportedTypes) -> bool:
    """Return True when ``value`` is a numeric or string supported type.

    Args:
        value: The column type to check.

    Returns:
        True if ``value`` is ``SupportedTypes.int``, ``SupportedTypes.float``
        or ``SupportedTypes.string``, False otherwise.
    """
    accepted_types = (
        SupportedTypes.int,
        SupportedTypes.float,
        SupportedTypes.string,
    )
    return value in accepted_types
|
||
|
||
def is_optional_float(value: Optional[SupportedTypes] = None) -> bool:
    """Return True when ``value`` is missing or is the float supported type.

    Args:
        value: The column type to check; defaults to ``None``.

    Returns:
        True if ``value`` is ``None`` or ``SupportedTypes.float``, False
        otherwise.
    """
    accepted_values = (None, SupportedTypes.float)
    return value in accepted_values
|
||
|
||
def is_none(value: Any) -> bool:
    """Return True only when ``value`` is the ``None`` singleton.

    Falsy values other than ``None`` (``0``, ``''``, ``[]``) yield False.

    >>> is_none(None)
    True
    >>> is_none(0)
    False
    """
    return value is None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from app.models.inferred_schema_dto import SupportedTypes | ||
from app.models.model_dto import ( | ||
ColumnDefinition, | ||
DataType, | ||
Granularity, | ||
ModelType, | ||
OutputType, | ||
) | ||
|
||
|
||
def _rejected_label_type(model_type: ModelType) -> SupportedTypes:
    """Return the column type used to make a prediction/target column invalid."""
    if model_type == ModelType.MULTI_CLASS:
        # Strings are acceptable labels for MULTI_CLASS (see the multiclass
        # tests), so a datetime column is used to trigger the failure.
        return SupportedTypes.datetime
    if model_type in (ModelType.BINARY, ModelType.REGRESSION):
        return SupportedTypes.string
    # Unrecognised model types keep the regular int column.
    return SupportedTypes.int


def get_model_sample_wrong(fail_field: str, model_type: ModelType):
    """Build ``ModelIn`` keyword arguments with one deliberately invalid field.

    Args:
        fail_field: Field to invalidate: ``'outputs.prediction'``,
            ``'outputs.prediction_proba'``, ``'target'`` or ``'timestamp'``.
            Any other value leaves every column with its regular type.
        model_type: Model type the sample is built for; it selects which
            column type is used to invalidate ``fail_field``.

    Returns:
        A dict of keyword arguments intended for ``ModelIn(**kwargs)``.
    """
    prediction_type = SupportedTypes.int
    if fail_field == 'outputs.prediction':
        prediction_type = _rejected_label_type(model_type)
    prediction = ColumnDefinition(name='pred1', type=prediction_type)

    proba_fails = fail_field == 'outputs.prediction_proba'
    if model_type == ModelType.REGRESSION:
        # The regression test expects any prediction_proba to be rejected
        # ("must be None"), so only supply one when it is the field meant to
        # fail; otherwise the sample would be invalid on two fields at once.
        prediction_proba = (
            ColumnDefinition(name='prob1', type=SupportedTypes.float)
            if proba_fails
            else None
        )
    else:
        proba_is_wrong = proba_fails and model_type in (
            ModelType.BINARY,
            ModelType.MULTI_CLASS,
        )
        proba_type = SupportedTypes.int if proba_is_wrong else SupportedTypes.float
        prediction_proba = ColumnDefinition(name='prob1', type=proba_type)

    target_type = SupportedTypes.int
    if fail_field == 'target':
        target_type = _rejected_label_type(model_type)
    target = ColumnDefinition(name='target1', type=target_type)

    timestamp_type = (
        SupportedTypes.string if fail_field == 'timestamp' else SupportedTypes.datetime
    )
    timestamp = ColumnDefinition(name='timestamp', type=timestamp_type)

    return {
        'name': 'model_name',
        'model_type': model_type,
        'data_type': DataType.TEXT,
        'granularity': Granularity.DAY,
        'features': [ColumnDefinition(name='feature1', type=SupportedTypes.string)],
        'outputs': OutputType(
            prediction=prediction,
            prediction_proba=prediction_proba,
            output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
        ),
        'target': target,
        'timestamp': timestamp,
        'frameworks': None,
        'algorithm': None,
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from pydantic import ValidationError | ||
import pytest | ||
|
||
from app.models.model_dto import ModelIn, ModelType | ||
from tests.commons.modelin_factory import get_model_sample_wrong | ||
|
||
|
||
def test_timestamp_not_datetime():
    """Tests that validation fails when the timestamp column is not a datetime."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong(
        fail_field='timestamp', model_type=ModelType.BINARY
    )
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'timestamp must be a datetime' in str(excinfo.value)
|
||
|
||
def test_target_for_binary():
    """Tests that for ModelType.BINARY: target must be a number."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong('target', ModelType.BINARY)
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'target must be a number for a ModelType.BINARY' in str(excinfo.value)
|
||
|
||
def test_target_for_multiclass():
    """Tests that for ModelType.MULTI_CLASS: target must be a number or string."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong('target', ModelType.MULTI_CLASS)
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'target must be a number or string for a ModelType.MULTI_CLASS' in str(
        excinfo.value
    )
|
||
|
||
def test_target_for_regression():
    """Tests that for ModelType.REGRESSION: target must be a number."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong('target', ModelType.REGRESSION)
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'target must be a number for a ModelType.REGRESSION' in str(excinfo.value)
|
||
|
||
def test_prediction_for_binary():
    """Tests that for ModelType.BINARY: prediction must be a number."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong('outputs.prediction', ModelType.BINARY)
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'prediction must be a number for a ModelType.BINARY' in str(excinfo.value)
|
||
|
||
def test_prediction_for_multiclass():
    """Tests that for ModelType.MULTI_CLASS: prediction must be a number or string."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong('outputs.prediction', ModelType.MULTI_CLASS)
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'prediction must be a number or string for a ModelType.MULTI_CLASS' in str(
        excinfo.value
    )
|
||
|
||
def test_prediction_for_regression():
    """Tests that for ModelType.REGRESSION: prediction must be a number."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong('outputs.prediction', ModelType.REGRESSION)
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'prediction must be a number for a ModelType.REGRESSION' in str(
        excinfo.value
    )
|
||
|
||
def test_prediction_proba_for_binary():
    """Tests that for ModelType.BINARY: prediction_proba must be an optional float."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong(
        'outputs.prediction_proba', ModelType.BINARY
    )
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'prediction_proba must be an optional float for a ModelType.BINARY' in str(
        excinfo.value
    )
|
||
|
||
def test_prediction_proba_for_multiclass():
    """Tests that for ModelType.MULTI_CLASS: prediction_proba must be an optional float."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong(
        'outputs.prediction_proba', ModelType.MULTI_CLASS
    )
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert (
        'prediction_proba must be an optional float for a ModelType.MULTI_CLASS'
        in str(excinfo.value)
    )
|
||
|
||
def test_prediction_proba_for_regression():
    """Tests that for ModelType.REGRESSION: prediction_proba must be None."""
    # Build the sample outside ``pytest.raises`` so that only the ModelIn
    # validation itself can satisfy the expected exception.
    model_data = get_model_sample_wrong(
        'outputs.prediction_proba', ModelType.REGRESSION
    )
    with pytest.raises(ValidationError) as excinfo:
        ModelIn.model_validate(ModelIn(**model_data))
    assert 'prediction_proba must be None for a ModelType.REGRESSION' in str(
        excinfo.value
    )