Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): add field type as column definition property #133

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion api/app/models/inferred_schema_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pandas.core.arrays import boolean, floating, integer, string_
from pandas.core.dtypes import dtypes as pd_dtypes
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

from app.models.exceptions import UnsupportedSchemaException

Expand Down Expand Up @@ -34,11 +35,34 @@ def cast(value) -> 'SupportedTypes':
raise UnsupportedSchemaException(f'Unsupported type: {type(value)}')


class FieldType(str, Enum):
categorical = 'categorical'
numerical = 'numerical'
datetime = 'datetime'

@staticmethod
def from_supported_type(value: SupportedTypes) -> 'FieldType':
match value:
case SupportedTypes.datetime:
return FieldType.datetime
case SupportedTypes.int:
return FieldType.numerical
case SupportedTypes.float:
return FieldType.numerical
case SupportedTypes.bool:
return FieldType.categorical
case SupportedTypes.string:
return FieldType.categorical


class SchemaEntry(BaseModel):
    """A single schema column: its name, data type and field type."""

    name: str
    type: SupportedTypes
    field_type: FieldType

    # Single effective configuration (the earlier duplicate assignment was
    # overwritten and therefore dead). Aliases are generated with to_camel,
    # and populate_by_name is enabled alongside the generated aliases.
    model_config = ConfigDict(
        arbitrary_types_allowed=True, populate_by_name=True, alias_generator=to_camel
    )


class InferredSchemaDTO(BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.db.dao.current_dataset_dao import CurrentDataset
from app.db.dao.model_dao import Model
from app.db.dao.reference_dataset_dao import ReferenceDataset
from app.models.inferred_schema_dto import SupportedTypes
from app.models.inferred_schema_dto import FieldType, SupportedTypes
from app.models.job_status import JobStatus
from app.models.utils import is_none, is_number, is_number_or_string, is_optional_float

Expand All @@ -36,6 +36,9 @@ class Granularity(str, Enum):
class ColumnDefinition(BaseModel, validate_assignment=True):
name: str
type: SupportedTypes
field_type: FieldType

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

def to_dict(self):
return self.model_dump()
Expand Down
7 changes: 6 additions & 1 deletion api/app/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ModelNotFoundError,
)
from app.models.inferred_schema_dto import (
FieldType,
InferredSchemaDTO,
SchemaEntry,
SupportedTypes,
Expand Down Expand Up @@ -419,7 +420,11 @@ def schema_from_pandas(df: pd.DataFrame) -> InferredSchemaDTO:
data = data.loc[:, ~data.columns.str.contains('Unnamed')]
return InferredSchemaDTO(
inferred_schema=[
SchemaEntry(name=name.strip(), type=SupportedTypes.cast(type))
SchemaEntry(
name=name.strip(),
type=SupportedTypes.cast(type),
field_type=FieldType.from_supported_type(SupportedTypes.cast(type)),
)
for name, type in data.convert_dtypes(infer_objects=True).dtypes.items()
]
)
Expand Down
57 changes: 47 additions & 10 deletions api/tests/commons/csv_file_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from app.models.inferred_schema_dto import (
FieldType,
InferredSchemaDTO,
SchemaEntry,
SupportedTypes,
Expand All @@ -28,16 +29,52 @@ def get_dataframe_with_sep(sep: str) -> pd.DataFrame:

def correct_schema() -> InferredSchemaDTO:
    """Return an ``InferredSchemaDTO`` holding the ten reference columns.

    Each column is given as (name, type, field type) and turned into a
    ``SchemaEntry``. Only entries that carry a field type are kept; the
    stale entries without ``fieldType`` duplicated every column and omitted
    the ``field_type`` that ``SchemaEntry`` declares.
    """
    columns = [
        ('Name', SupportedTypes.string, FieldType.categorical),
        ('Age', SupportedTypes.int, FieldType.numerical),
        ('City', SupportedTypes.string, FieldType.categorical),
        ('Salary', SupportedTypes.float, FieldType.numerical),
        ('String', SupportedTypes.string, FieldType.categorical),
        ('Float/Int', SupportedTypes.float, FieldType.numerical),
        ('Boolean', SupportedTypes.bool, FieldType.categorical),
        ('Datetime', SupportedTypes.datetime, FieldType.datetime),
        ('Datetime2', SupportedTypes.datetime, FieldType.datetime),
        ('Datetime3', SupportedTypes.datetime, FieldType.datetime),
    ]
    # Entries are built through the same keys the fixture used originally
    # ('name', 'type', 'fieldType').
    entries = [
        SchemaEntry(**{'name': name, 'type': dtype, 'fieldType': field_type})
        for name, dtype, field_type in columns
    ]
    return InferredSchemaDTO(inferred_schema=entries)
Expand Down
49 changes: 37 additions & 12 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from app.models.model_dto import (
ColumnDefinition,
DataType,
FieldType,
Granularity,
ModelIn,
ModelType,
Expand All @@ -31,14 +32,24 @@ def get_sample_model(
model_type: str = ModelType.BINARY.value,
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[Dict] = [{'name': 'feature1', 'type': 'string'}],
features: List[Dict] = [
{'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'}
],
outputs: Dict = {
'prediction': {'name': 'pred1', 'type': 'int'},
'prediction_proba': {'name': 'prob1', 'type': 'float'},
'output': [{'name': 'output1', 'type': 'string'}],
'prediction': {'name': 'pred1', 'type': 'int', 'fieldType': 'numerical'},
'prediction_proba': {
'name': 'prob1',
'type': 'float',
'fieldType': 'numerical',
},
'output': [{'name': 'output1', 'type': 'string', 'fieldType': 'categorical'}],
},
target: Dict = {'name': 'target1', 'type': 'string', 'fieldType': 'categorical'},
timestamp: Dict = {
'name': 'timestamp',
'type': 'datetime',
'fieldType': 'datetime',
},
target: Dict = {'name': 'target1', 'type': 'string'},
timestamp: Dict = {'name': 'timestamp', 'type': 'datetime'},
frameworks: Optional[str] = None,
algorithm: Optional[str] = None,
) -> Model:
Expand Down Expand Up @@ -68,18 +79,32 @@ def get_sample_model_in(
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[ColumnDefinition] = [
ColumnDefinition(name='feature1', type=SupportedTypes.string)
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
outputs: OutputType = OutputType(
prediction=ColumnDefinition(name='pred1', type=SupportedTypes.int),
prediction_proba=ColumnDefinition(name='prob1', type=SupportedTypes.float),
output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
prediction=ColumnDefinition(
name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical
),
prediction_proba=ColumnDefinition(
name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical
),
output=[
ColumnDefinition(
name='output1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
),
target: ColumnDefinition = ColumnDefinition(
name='target1', type=SupportedTypes.int
name='target1', type=SupportedTypes.int, field_type=FieldType.numerical
),
timestamp: ColumnDefinition = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime
name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime
),
frameworks: Optional[str] = None,
algorithm: Optional[str] = None,
Expand Down
83 changes: 68 additions & 15 deletions api/tests/commons/modelin_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,106 @@
from app.models.model_dto import (
ColumnDefinition,
DataType,
FieldType,
Granularity,
ModelType,
OutputType,
)


def get_model_sample_wrong(fail_fields: List[str], model_type: ModelType):
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.int)
prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float)
target = ColumnDefinition(name='target1', type=SupportedTypes.int)
timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.datetime)
prediction = ColumnDefinition(
name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical
)
prediction_proba = ColumnDefinition(
name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical
)
target = ColumnDefinition(
name='target1', type=SupportedTypes.int, field_type=FieldType.numerical
)
timestamp = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime
)

if 'outputs.prediction' in fail_fields:
if model_type == ModelType.BINARY:
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string)
prediction = ColumnDefinition(
name='pred1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
elif model_type == ModelType.MULTI_CLASS:
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.datetime)
prediction = ColumnDefinition(
name='pred1',
type=SupportedTypes.datetime,
field_type=FieldType.datetime,
)
elif model_type == ModelType.REGRESSION:
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string)
prediction = ColumnDefinition(
name='pred1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)

if 'outputs.prediction_proba' in fail_fields:
if model_type in (ModelType.BINARY, ModelType.MULTI_CLASS):
prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.int)
prediction_proba = ColumnDefinition(
name='prob1', type=SupportedTypes.int, field_type=FieldType.numerical
)
elif model_type == ModelType.REGRESSION:
prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float)
prediction_proba = ColumnDefinition(
name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical
)

if 'target' in fail_fields:
if model_type == ModelType.BINARY:
target = ColumnDefinition(name='target1', type=SupportedTypes.string)
target = ColumnDefinition(
name='target1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
elif model_type == ModelType.MULTI_CLASS:
target = ColumnDefinition(name='target1', type=SupportedTypes.datetime)
target = ColumnDefinition(
name='target1',
type=SupportedTypes.datetime,
field_type=FieldType.datetime,
)
elif model_type == ModelType.REGRESSION:
target = ColumnDefinition(name='target1', type=SupportedTypes.string)
target = ColumnDefinition(
name='target1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)

if 'timestamp' in fail_fields:
timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.string)
timestamp = ColumnDefinition(
name='timestamp',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)

return {
'name': 'model_name',
'model_type': model_type,
'data_type': DataType.TEXT,
'granularity': Granularity.DAY,
'features': [ColumnDefinition(name='feature1', type=SupportedTypes.string)],
'features': [
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
'outputs': OutputType(
prediction=prediction,
prediction_proba=prediction_proba,
output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
output=[
ColumnDefinition(
name='output1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
),
'target': target,
'timestamp': timestamp,
Expand Down
20 changes: 14 additions & 6 deletions api/tests/services/file_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,22 @@ def test_bind_reference_file_already_exists(self):
def test_upload_current_file_ok(self):
file = csv.get_current_sample_csv_file()
model = db_mock.get_sample_model(
features=[{'name': 'num1', 'type': 'int'}],
features=[{'name': 'num1', 'type': 'int', 'fieldType': 'numerical'}],
outputs={
'prediction': {'name': 'prediction', 'type': 'int'},
'prediction_proba': {'name': 'prediction_proba', 'type': 'int'},
'output': [{'name': 'num2', 'type': 'int'}],
'prediction': {
'name': 'prediction',
'type': 'int',
'fieldType': 'numerical',
},
'prediction_proba': {
'name': 'prediction_proba',
'type': 'int',
'fieldType': 'numerical',
},
'output': [{'name': 'num2', 'type': 'int', 'fieldType': 'numerical'}],
},
target={'name': 'target', 'type': 'int'},
timestamp={'name': 'datetime', 'type': 'datetime'},
target={'name': 'target', 'type': 'int', 'fieldType': 'numerical'},
timestamp={'name': 'datetime', 'type': 'datetime', 'fieldType': 'datetime'},
)
object_name = f'{str(model.uuid)}/current/{file.filename}'
path = f's3://bucket/{object_name}'
Expand Down