Skip to content

Commit

Permalink
chore: add model_type validations (#52)
Browse files Browse the repository at this point in the history
* chore: add model_type validations

* chore: add model_type validations

* doc: postgres installation required for local test

* chore: add authors.md and its related automation script

* chore: add a Github Actions in order to automate update-authors.sh executions

* chore: refactoring tests
  • Loading branch information
bigmoby authored Jul 1, 2024
1 parent 8b484f6 commit 922ea7d
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 17 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ to check.

## Test

Please install a PostgreSQL database locally. For example, on a macOS platform, execute:

```bash
brew install postgresql
```

Note: If any errors occur while running `pytest`, stop the local PostgreSQL service by executing:

```bash
brew services stop postgresql
```

Tests are run with `pytest`.

Run
Expand Down
83 changes: 76 additions & 7 deletions app/models/model_dto.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import datetime
from enum import Enum
from typing import List, Optional
from typing import List, Optional, Self
from uuid import UUID

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic.alias_generators import to_camel

from app.db.dao.model_dao import Model
from app.models.inferred_schema_dto import SupportedTypes
from app.models.utils import is_none, is_number, is_number_or_string, is_optional_float


class ModelType(str, Enum):
Expand All @@ -28,26 +30,25 @@ class Granularity(str, Enum):
MONTH = 'MONTH'


class ColumnDefinition(BaseModel):
class ColumnDefinition(BaseModel, validate_assignment=True):
name: str
type: str
type: SupportedTypes

def to_dict(self):
return self.model_dump()


class OutputType(BaseModel):
class OutputType(BaseModel, validate_assignment=True):
prediction: ColumnDefinition
prediction_proba: Optional[ColumnDefinition] = None
output: List[ColumnDefinition]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

def to_dict(self):
return self.model_dump()


class ModelIn(BaseModel):
class ModelIn(BaseModel, validate_assignment=True):
name: str
description: Optional[str] = None
model_type: ModelType
Expand All @@ -64,6 +65,74 @@ class ModelIn(BaseModel):
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)

@model_validator(mode='after')
def validate_target(self) -> Self:
    """Ensure the target column type is compatible with ``model_type``.

    BINARY and REGRESSION models need a numeric target; MULTI_CLASS models
    accept a numeric or string target.

    Returns:
        The validated instance itself.

    Raises:
        ValueError: if the target type does not fit the model type, or if
            the model type is not one of the handled ones.
    """
    kind: ModelType = self.model_type

    if kind == ModelType.BINARY:
        if is_number(self.target.type):
            return self
        raise ValueError(
            f'target must be a number for a ModelType.BINARY, has been provided [{self.target}]'
        )

    if kind == ModelType.MULTI_CLASS:
        if is_number_or_string(self.target.type):
            return self
        raise ValueError(
            f'target must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.target}]'
        )

    if kind == ModelType.REGRESSION:
        if is_number(self.target.type):
            return self
        raise ValueError(
            f'target must be a number for a ModelType.REGRESSION, has been provided [{self.target}]'
        )

    # Any model type without an explicit rule above is rejected.
    raise ValueError('not supported type for model_type')

@model_validator(mode='after')
def validate_outputs(self) -> Self:
checked_model_type: ModelType = self.model_type
match checked_model_type:
case ModelType.BINARY:
if not is_number(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.MULTI_CLASS:
if not is_number_or_string(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.REGRESSION:
if not is_number(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number for a ModelType.REGRESSION, has been provided [{self.outputs.prediction}]'
)
if not is_none(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]'
)
return self
case _:
raise ValueError('not supported type for model_type')

@model_validator(mode='after')
def timestamp_must_be_datetime(self) -> Self:
    """Reject models whose timestamp column is not declared as datetime.

    Returns:
        The validated instance itself.

    Raises:
        ValueError: if ``timestamp.type`` is not ``SupportedTypes.datetime``.
    """
    if self.timestamp.type == SupportedTypes.datetime:
        return self
    raise ValueError('timestamp must be a datetime')

def to_model(self) -> Model:
now = datetime.datetime.now(tz=datetime.UTC)
return Model(
Expand Down
19 changes: 19 additions & 0 deletions app/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any, Optional

from app.models.inferred_schema_dto import SupportedTypes


def is_number(value: SupportedTypes):
    """Return True when ``value`` is one of the numeric types (int or float)."""
    numeric_types = (SupportedTypes.int, SupportedTypes.float)
    return value in numeric_types


def is_number_or_string(value: SupportedTypes):
    """Return True when ``value`` is int, float or string."""
    accepted_types = (
        SupportedTypes.int,
        SupportedTypes.float,
        SupportedTypes.string,
    )
    return value in accepted_types


def is_optional_float(value: Optional[SupportedTypes] = None) -> bool:
    """Return True when ``value`` is either missing (``None``) or the float type."""
    accepted = (None, SupportedTypes.float)
    return value in accepted


def is_none(value: Any) -> bool:
    """Tell whether ``value`` is the ``None`` singleton.

    This is an identity check, not a falsiness check: ``0``, ``''`` and
    empty containers all yield ``False``.
    """
    if value is None:
        return True
    return False
34 changes: 24 additions & 10 deletions tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@
from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics
from app.db.tables.reference_dataset_table import ReferenceDataset
from app.models.job_status import JobStatus
from app.models.model_dto import DataType, Granularity, ModelIn, ModelType
from app.models.model_dto import (
ColumnDefinition,
DataType,
Granularity,
ModelIn,
ModelType,
OutputType,
SupportedTypes,
)

MODEL_UUID = uuid.uuid4()
REFERENCE_UUID = uuid.uuid4()
Expand All @@ -26,7 +34,7 @@ def get_sample_model(
features: List[Dict] = [{'name': 'feature1', 'type': 'string'}],
outputs: Dict = {
'prediction': {'name': 'pred1', 'type': 'int'},
'prediction_proba': {'name': 'prob1', 'type': 'double'},
'prediction_proba': {'name': 'prob1', 'type': 'float'},
'output': [{'name': 'output1', 'type': 'string'}],
},
target: Dict = {'name': 'target1', 'type': 'string'},
Expand Down Expand Up @@ -59,14 +67,20 @@ def get_sample_model_in(
model_type: str = ModelType.BINARY.value,
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[Dict] = [{'name': 'feature1', 'type': 'string'}],
outputs: Dict = {
'prediction': {'name': 'pred1', 'type': 'int'},
'prediction_proba': {'name': 'prob1', 'type': 'double'},
'output': [{'name': 'output1', 'type': 'string'}],
},
target: Dict = {'name': 'target1', 'type': 'string'},
timestamp: Dict = {'name': 'timestamp', 'type': 'datetime'},
features: List[ColumnDefinition] = [
ColumnDefinition(name='feature1', type=SupportedTypes.string)
],
outputs: OutputType = OutputType(
prediction=ColumnDefinition(name='pred1', type=SupportedTypes.int),
prediction_proba=ColumnDefinition(name='prob1', type=SupportedTypes.float),
output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
),
target: ColumnDefinition = ColumnDefinition(
name='target1', type=SupportedTypes.int
),
timestamp: ColumnDefinition = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime
),
frameworks: Optional[str] = None,
algorithm: Optional[str] = None,
):
Expand Down
68 changes: 68 additions & 0 deletions tests/commons/modelin_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from app.models.inferred_schema_dto import SupportedTypes
from app.models.model_dto import (
ColumnDefinition,
DataType,
Granularity,
ModelType,
OutputType,
)


def get_model_sample_wrong(fail_field: str, model_type: ModelType):
    """Build ``ModelIn`` keyword arguments with one deliberately invalid field.

    Args:
        fail_field: which field gets an unsuitable column type. Recognised
            values are ``'outputs.prediction'``, ``'outputs.prediction_proba'``,
            ``'target'`` and ``'timestamp'``; any other value leaves every
            field at its default type.
        model_type: the model type the payload is built for; it selects
            which column type is used for the failing field.

    Returns:
        A dict of keyword arguments suitable for ``ModelIn(**kwargs)``.
    """
    breaks_prediction = fail_field == 'outputs.prediction'
    breaks_proba = fail_field == 'outputs.prediction_proba'
    breaks_target = fail_field == 'target'

    # prediction: defaults to int, swapped for a type chosen per model type.
    prediction_type = SupportedTypes.int
    if breaks_prediction:
        if model_type == ModelType.BINARY:
            prediction_type = SupportedTypes.string
        elif model_type == ModelType.MULTI_CLASS:
            prediction_type = SupportedTypes.datetime
        elif model_type == ModelType.REGRESSION:
            prediction_type = SupportedTypes.string

    # prediction_proba: int for BINARY / MULTI_CLASS failures; every other
    # case (including the REGRESSION failure case) uses a float column.
    proba_type = SupportedTypes.float
    if breaks_proba and (
        model_type == ModelType.BINARY or model_type == ModelType.MULTI_CLASS
    ):
        proba_type = SupportedTypes.int

    # target: same selection scheme as prediction.
    target_type = SupportedTypes.int
    if breaks_target:
        if model_type == ModelType.BINARY:
            target_type = SupportedTypes.string
        elif model_type == ModelType.MULTI_CLASS:
            target_type = SupportedTypes.datetime
        elif model_type == ModelType.REGRESSION:
            target_type = SupportedTypes.string

    # timestamp: a string column is used when it is the failing field.
    timestamp_type = (
        SupportedTypes.string
        if fail_field == 'timestamp'
        else SupportedTypes.datetime
    )

    outputs = OutputType(
        prediction=ColumnDefinition(name='pred1', type=prediction_type),
        prediction_proba=ColumnDefinition(name='prob1', type=proba_type),
        output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
    )

    return {
        'name': 'model_name',
        'model_type': model_type,
        'data_type': DataType.TEXT,
        'granularity': Granularity.DAY,
        'features': [ColumnDefinition(name='feature1', type=SupportedTypes.string)],
        'outputs': outputs,
        'target': ColumnDefinition(name='target1', type=target_type),
        'timestamp': ColumnDefinition(name='timestamp', type=timestamp_type),
        'frameworks': None,
        'algorithm': None,
    }
1 change: 1 addition & 0 deletions tests/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

106 changes: 106 additions & 0 deletions tests/validation/model_type_validator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from pydantic import ValidationError
import pytest

from app.models.model_dto import ModelIn, ModelType
from tests.commons.modelin_factory import get_model_sample_wrong


def test_timestamp_not_datetime():
    """A timestamp column that is not a datetime must be rejected."""
    payload = get_model_sample_wrong(
        fail_field='timestamp', model_type=ModelType.BINARY
    )
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'timestamp must be a datetime' in message


def test_target_for_binary():
    """ModelType.BINARY rejects a target that is not a number."""
    payload = get_model_sample_wrong('target', ModelType.BINARY)
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'target must be a number for a ModelType.BINARY' in message


def test_target_for_multiclass():
    """ModelType.MULTI_CLASS rejects a target that is neither number nor string."""
    payload = get_model_sample_wrong('target', ModelType.MULTI_CLASS)
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'target must be a number or string for a ModelType.MULTI_CLASS' in message


def test_target_for_regression():
    """ModelType.REGRESSION rejects a target that is not a number."""
    payload = get_model_sample_wrong('target', ModelType.REGRESSION)
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'target must be a number for a ModelType.REGRESSION' in message


def test_prediction_for_binary():
    """ModelType.BINARY rejects a prediction that is not a number."""
    payload = get_model_sample_wrong('outputs.prediction', ModelType.BINARY)
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'prediction must be a number for a ModelType.BINARY' in message


def test_prediction_for_multiclass():
    """ModelType.MULTI_CLASS rejects a prediction that is neither number nor string."""
    payload = get_model_sample_wrong('outputs.prediction', ModelType.MULTI_CLASS)
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'prediction must be a number or string for a ModelType.MULTI_CLASS' in message


def test_prediction_for_regression():
    """ModelType.REGRESSION rejects a prediction that is not a number."""
    payload = get_model_sample_wrong('outputs.prediction', ModelType.REGRESSION)
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'prediction must be a number for a ModelType.REGRESSION' in message


def test_prediction_proba_for_binary():
    """ModelType.BINARY rejects a prediction_proba that is not an optional float."""
    payload = get_model_sample_wrong(
        'outputs.prediction_proba', ModelType.BINARY
    )
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'prediction_proba must be an optional float for a ModelType.BINARY' in message


def test_prediction_proba_for_multiclass():
    """ModelType.MULTI_CLASS rejects a prediction_proba that is not an optional float."""
    payload = get_model_sample_wrong(
        'outputs.prediction_proba', ModelType.MULTI_CLASS
    )
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    expected = 'prediction_proba must be an optional float for a ModelType.MULTI_CLASS'
    assert expected in message


def test_prediction_proba_for_regression():
    """ModelType.REGRESSION rejects any provided prediction_proba (it must be None)."""
    payload = get_model_sample_wrong(
        'outputs.prediction_proba', ModelType.REGRESSION
    )
    with pytest.raises(ValidationError) as raised:
        ModelIn.model_validate(ModelIn(**payload))
    message = str(raised.value)
    assert 'prediction_proba must be None for a ModelType.REGRESSION' in message

0 comments on commit 922ea7d

Please sign in to comment.