diff --git a/README.md b/README.md index 8c821433..463b1153 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,18 @@ to check. ## Test +Please install a PostgreSQL database locally. For example, on a macOS platform, execute: + +```bash +brew install postgresql +``` + +Note: If any errors occur during pytest runs, please stop the local database service by executing: + +```bash +brew services stop postgresql +``` + Tests are done with `pytest` Run diff --git a/app/models/model_dto.py b/app/models/model_dto.py index 85cd3738..f76fdbd6 100644 --- a/app/models/model_dto.py +++ b/app/models/model_dto.py @@ -1,12 +1,14 @@ import datetime from enum import Enum -from typing import List, Optional +from typing import List, Optional, Self from uuid import UUID -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from pydantic.alias_generators import to_camel from app.db.dao.model_dao import Model +from app.models.inferred_schema_dto import SupportedTypes +from app.models.utils import is_none, is_number, is_number_or_string, is_optional_float class ModelType(str, Enum): @@ -28,26 +30,25 @@ class Granularity(str, Enum): MONTH = 'MONTH' -class ColumnDefinition(BaseModel): +class ColumnDefinition(BaseModel, validate_assignment=True): name: str - type: str + type: SupportedTypes def to_dict(self): return self.model_dump() -class OutputType(BaseModel): +class OutputType(BaseModel, validate_assignment=True): prediction: ColumnDefinition prediction_proba: Optional[ColumnDefinition] = None output: List[ColumnDefinition] - model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) def to_dict(self): return self.model_dump() -class ModelIn(BaseModel): +class ModelIn(BaseModel, validate_assignment=True): name: str description: Optional[str] = None model_type: ModelType @@ -64,6 +65,74 @@ class ModelIn(BaseModel): populate_by_name=True, alias_generator=to_camel, protected_namespaces=() ) + @model_validator(mode='after') + 
def validate_target(self) -> Self: +        checked_model_type: ModelType = self.model_type +        match checked_model_type: +            case ModelType.BINARY: +                if not is_number(self.target.type): +                    raise ValueError( +                        f'target must be a number for a ModelType.BINARY, has been provided [{self.target}]' +                    ) +                return self +            case ModelType.MULTI_CLASS: +                if not is_number_or_string(self.target.type): +                    raise ValueError( +                        f'target must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.target}]' +                    ) +                return self +            case ModelType.REGRESSION: +                if not is_number(self.target.type): +                    raise ValueError( +                        f'target must be a number for a ModelType.REGRESSION, has been provided [{self.target}]' +                    ) +                return self +            case _: +                raise ValueError('not supported type for model_type') + +    @model_validator(mode='after') +    def validate_outputs(self) -> Self: +        checked_model_type: ModelType = self.model_type +        match checked_model_type: +            case ModelType.BINARY: +                if not is_number(self.outputs.prediction.type): +                    raise ValueError( +                        f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]' +                    ) +                if not is_optional_float(self.outputs.prediction_proba.type if self.outputs.prediction_proba else None): +                    raise ValueError( +                        f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]' +                    ) +                return self +            case ModelType.MULTI_CLASS: +                if not is_number_or_string(self.outputs.prediction.type): +                    raise ValueError( +                        f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]' +                    ) +                if not is_optional_float(self.outputs.prediction_proba.type if self.outputs.prediction_proba else None): +                    raise ValueError( +                        f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]' +                    ) +                return self +            case ModelType.REGRESSION: +                if not is_number(self.outputs.prediction.type): +                    raise ValueError( +                        f'prediction must be a number for a ModelType.REGRESSION, has been provided 
[{self.outputs.prediction}]' +                    ) +                if not is_none(self.outputs.prediction_proba): +                    raise ValueError( +                        f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]' +                    ) +                return self +            case _: +                raise ValueError('not supported type for model_type') + +    @model_validator(mode='after') +    def timestamp_must_be_datetime(self) -> Self: +        if not self.timestamp.type == SupportedTypes.datetime: +            raise ValueError('timestamp must be a datetime') +        return self +     def to_model(self) -> Model:         now = datetime.datetime.now(tz=datetime.UTC)         return Model( diff --git a/app/models/utils.py b/app/models/utils.py new file mode 100644 index 00000000..62d19ee9 --- /dev/null +++ b/app/models/utils.py @@ -0,0 +1,19 @@ +from typing import Any, Optional + +from app.models.inferred_schema_dto import SupportedTypes + + +def is_number(value: SupportedTypes): +    return value in (SupportedTypes.int, SupportedTypes.float) + + +def is_number_or_string(value: SupportedTypes): +    return value in (SupportedTypes.int, SupportedTypes.float, SupportedTypes.string) + + +def is_optional_float(value: Optional[SupportedTypes] = None) -> bool: +    return value in (None, SupportedTypes.float) + + +def is_none(value: Any) -> bool: +    return value is None diff --git a/tests/commons/db_mock.py b/tests/commons/db_mock.py index bd27b34f..b6adc8d0 100644 --- a/tests/commons/db_mock.py +++ b/tests/commons/db_mock.py @@ -8,7 +8,15 @@ from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics from app.db.tables.reference_dataset_table import ReferenceDataset from app.models.job_status import JobStatus -from app.models.model_dto import DataType, Granularity, ModelIn, ModelType +from app.models.model_dto import ( +    ColumnDefinition, +    DataType, +    Granularity, +    ModelIn, +    ModelType, +    OutputType, +    SupportedTypes, +) MODEL_UUID = uuid.uuid4() REFERENCE_UUID = uuid.uuid4() @@ -26,7 +34,7 @@ def get_sample_model( features: List[Dict] = [{'name': 
'feature1', 'type': 'string'}], outputs: Dict = { 'prediction': {'name': 'pred1', 'type': 'int'}, - 'prediction_proba': {'name': 'prob1', 'type': 'double'}, + 'prediction_proba': {'name': 'prob1', 'type': 'float'}, 'output': [{'name': 'output1', 'type': 'string'}], }, target: Dict = {'name': 'target1', 'type': 'string'}, @@ -59,14 +67,20 @@ def get_sample_model_in( model_type: str = ModelType.BINARY.value, data_type: str = DataType.TEXT.value, granularity: str = Granularity.DAY.value, - features: List[Dict] = [{'name': 'feature1', 'type': 'string'}], - outputs: Dict = { - 'prediction': {'name': 'pred1', 'type': 'int'}, - 'prediction_proba': {'name': 'prob1', 'type': 'double'}, - 'output': [{'name': 'output1', 'type': 'string'}], - }, - target: Dict = {'name': 'target1', 'type': 'string'}, - timestamp: Dict = {'name': 'timestamp', 'type': 'datetime'}, + features: List[ColumnDefinition] = [ + ColumnDefinition(name='feature1', type=SupportedTypes.string) + ], + outputs: OutputType = OutputType( + prediction=ColumnDefinition(name='pred1', type=SupportedTypes.int), + prediction_proba=ColumnDefinition(name='prob1', type=SupportedTypes.float), + output=[ColumnDefinition(name='output1', type=SupportedTypes.string)], + ), + target: ColumnDefinition = ColumnDefinition( + name='target1', type=SupportedTypes.int + ), + timestamp: ColumnDefinition = ColumnDefinition( + name='timestamp', type=SupportedTypes.datetime + ), frameworks: Optional[str] = None, algorithm: Optional[str] = None, ): diff --git a/tests/commons/modelin_factory.py b/tests/commons/modelin_factory.py new file mode 100644 index 00000000..631b8dc5 --- /dev/null +++ b/tests/commons/modelin_factory.py @@ -0,0 +1,68 @@ +from app.models.inferred_schema_dto import SupportedTypes +from app.models.model_dto import ( + ColumnDefinition, + DataType, + Granularity, + ModelType, + OutputType, +) + + +def get_model_sample_wrong(fail_field: str, model_type: ModelType): + prediction = None + prediction_proba = None + if 
fail_field == 'outputs.prediction' and model_type == ModelType.BINARY: + prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string) + elif fail_field == 'outputs.prediction' and model_type == ModelType.MULTI_CLASS: + prediction = ColumnDefinition(name='pred1', type=SupportedTypes.datetime) + elif fail_field == 'outputs.prediction' and model_type == ModelType.REGRESSION: + prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string) + else: + prediction = ColumnDefinition(name='pred1', type=SupportedTypes.int) + + if ( + fail_field == 'outputs.prediction_proba' + and model_type == ModelType.BINARY + or fail_field == 'outputs.prediction_proba' + and model_type == ModelType.MULTI_CLASS + ): + prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.int) + elif ( + fail_field == 'outputs.prediction_proba' and model_type == ModelType.REGRESSION + ): + prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float) + else: + prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float) + + target: ColumnDefinition = None + if fail_field == 'target' and model_type == ModelType.BINARY: + target = ColumnDefinition(name='target1', type=SupportedTypes.string) + elif fail_field == 'target' and model_type == ModelType.MULTI_CLASS: + target = ColumnDefinition(name='target1', type=SupportedTypes.datetime) + elif fail_field == 'target' and model_type == ModelType.REGRESSION: + target = ColumnDefinition(name='target1', type=SupportedTypes.string) + else: + target = ColumnDefinition(name='target1', type=SupportedTypes.int) + + timestamp: ColumnDefinition = None + if fail_field == 'timestamp': + timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.string) + else: + timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.datetime) + + return { + 'name': 'model_name', + 'model_type': model_type, + 'data_type': DataType.TEXT, + 'granularity': Granularity.DAY, + 'features': 
[ColumnDefinition(name='feature1', type=SupportedTypes.string)], + 'outputs': OutputType( + prediction=prediction, + prediction_proba=prediction_proba, + output=[ColumnDefinition(name='output1', type=SupportedTypes.string)], + ), + 'target': target, + 'timestamp': timestamp, + 'frameworks': None, + 'algorithm': None, + } diff --git a/tests/validation/__init__.py b/tests/validation/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/validation/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/validation/model_type_validator_test.py b/tests/validation/model_type_validator_test.py new file mode 100644 index 00000000..2a743d69 --- /dev/null +++ b/tests/validation/model_type_validator_test.py @@ -0,0 +1,106 @@ +from pydantic import ValidationError +import pytest + +from app.models.model_dto import ModelIn, ModelType +from tests.commons.modelin_factory import get_model_sample_wrong + + +def test_timestamp_not_datetime(): + """Tests that timestamp validator fails when timestamp is not valid.""" + with pytest.raises(ValidationError) as excinfo: + model_data = get_model_sample_wrong( + fail_field='timestamp', model_type=ModelType.BINARY + ) + ModelIn.model_validate(ModelIn(**model_data)) + assert 'timestamp must be a datetime' in str(excinfo.value) + + +def test_target_for_binary(): + """Tests that for ModelType.BINARY: target must be a number.""" + with pytest.raises(ValidationError) as excinfo: + model_data = get_model_sample_wrong('target', ModelType.BINARY) + ModelIn.model_validate(ModelIn(**model_data)) + assert 'target must be a number for a ModelType.BINARY' in str(excinfo.value) + + +def test_target_for_multiclass(): + """Tests that for ModelType.MULTI_CLASS: target must be a number or string.""" + with pytest.raises(ValidationError) as excinfo: + model_data = get_model_sample_wrong('target', ModelType.MULTI_CLASS) + ModelIn.model_validate(ModelIn(**model_data)) + assert 'target must be a number or string for a ModelType.MULTI_CLASS' in 
str( +        excinfo.value +    ) + + +def test_target_for_regression(): +    """Tests that for ModelType.REGRESSION: target must be a number.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong('target', ModelType.REGRESSION) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert 'target must be a number for a ModelType.REGRESSION' in str(excinfo.value) + + +def test_prediction_for_binary(): +    """Tests that for ModelType.BINARY: prediction must be a number.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong('outputs.prediction', ModelType.BINARY) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert 'prediction must be a number for a ModelType.BINARY' in str(excinfo.value) + + +def test_prediction_for_multiclass(): +    """Tests that for ModelType.MULTI_CLASS: prediction must be a number or string.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong('outputs.prediction', ModelType.MULTI_CLASS) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert 'prediction must be a number or string for a ModelType.MULTI_CLASS' in str( +        excinfo.value +    ) + + +def test_prediction_for_regression(): +    """Tests that for ModelType.REGRESSION: prediction must be a number.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong('outputs.prediction', ModelType.REGRESSION) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert 'prediction must be a number for a ModelType.REGRESSION' in str( +        excinfo.value +    ) + + +def test_prediction_proba_for_binary(): +    """Tests that for ModelType.BINARY: prediction_proba must be an optional float.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong( +            'outputs.prediction_proba', ModelType.BINARY +        ) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert 'prediction_proba must be an optional float for a ModelType.BINARY' in str( +        excinfo.value +    ) + + +def 
test_prediction_proba_for_multiclass(): +    """Tests that for ModelType.MULTI_CLASS: prediction_proba must be an optional float.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong( +            'outputs.prediction_proba', ModelType.MULTI_CLASS +        ) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert ( +        'prediction_proba must be an optional float for a ModelType.MULTI_CLASS' +        in str(excinfo.value) +    ) + + +def test_prediction_proba_for_regression(): +    """Tests that for ModelType.REGRESSION: prediction_proba must be None.""" +    with pytest.raises(ValidationError) as excinfo: +        model_data = get_model_sample_wrong( +            'outputs.prediction_proba', ModelType.REGRESSION +        ) +        ModelIn.model_validate(ModelIn(**model_data)) +    assert 'prediction_proba must be None for a ModelType.REGRESSION' in str( +        excinfo.value +    )