From d126e15d94103de9bdcece79b85e377e71539ac5 Mon Sep 17 00:00:00 2001 From: Mauro Cortellazzi Date: Thu, 25 Jul 2024 13:18:48 +0200 Subject: [PATCH] feat(sdk): improved column definition class (#138) --- .../models/__init__.py | 4 + .../models/column_definition.py | 13 +- .../models/field_type.py | 23 ++ .../models/supported_types.py | 9 + sdk/tests/apis/model_test.py | 202 ++++++++++++++---- sdk/tests/client_test.py | 112 ++++++++-- sdk/tests/models/column_definition_test.py | 16 +- sdk/tests/models/model_definition_test.py | 47 ++-- 8 files changed, 345 insertions(+), 81 deletions(-) create mode 100644 sdk/radicalbit_platform_sdk/models/field_type.py create mode 100644 sdk/radicalbit_platform_sdk/models/supported_types.py diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py index 9f9a3964..fc04c03e 100644 --- a/sdk/radicalbit_platform_sdk/models/__init__.py +++ b/sdk/radicalbit_platform_sdk/models/__init__.py @@ -30,6 +30,7 @@ RegressionModelQuality, ) from .dataset_stats import DatasetStats +from .field_type import FieldType from .file_upload_result import CurrentFileUpload, FileReference, ReferenceFileUpload from .job_status import JobStatus from .model_definition import ( @@ -39,6 +40,7 @@ OutputType, ) from .model_type import ModelType +from .supported_types import SupportedTypes __all__ = [ 'OutputType', @@ -76,4 +78,6 @@ 'CurrentFileUpload', 'FileReference', 'AwsCredentials', + 'SupportedTypes', + 'FieldType', ] diff --git a/sdk/radicalbit_platform_sdk/models/column_definition.py b/sdk/radicalbit_platform_sdk/models/column_definition.py index f1b2c556..40d05a81 100644 --- a/sdk/radicalbit_platform_sdk/models/column_definition.py +++ b/sdk/radicalbit_platform_sdk/models/column_definition.py @@ -1,6 +1,13 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field +from pydantic.alias_generators import to_camel +from radicalbit_platform_sdk.models.field_type import FieldType +from radicalbit_platform_sdk.models.supported_types import SupportedTypes -class ColumnDefinition(BaseModel): + +class ColumnDefinition(BaseModel, validate_assignment=True): name: str - type: str + type: SupportedTypes + field_type: FieldType = Field(alias='fieldType') + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) diff --git a/sdk/radicalbit_platform_sdk/models/field_type.py b/sdk/radicalbit_platform_sdk/models/field_type.py new file mode 100644 index 00000000..abbf328e --- /dev/null +++ b/sdk/radicalbit_platform_sdk/models/field_type.py @@ -0,0 +1,23 @@ +from enum import Enum + +from radicalbit_platform_sdk.models.supported_types import SupportedTypes + + +class FieldType(str, Enum): + categorical = 'categorical' + numerical = 'numerical' + datetime = 'datetime' + + @staticmethod + def from_supported_type(value: SupportedTypes) -> 'FieldType': + match value: + case SupportedTypes.datetime: + return FieldType.datetime + case SupportedTypes.int: + return FieldType.numerical + case SupportedTypes.float: + return FieldType.numerical + case SupportedTypes.bool: + return FieldType.categorical + case SupportedTypes.string: + return FieldType.categorical diff --git a/sdk/radicalbit_platform_sdk/models/supported_types.py b/sdk/radicalbit_platform_sdk/models/supported_types.py new file mode 100644 index 00000000..0c475001 --- /dev/null +++ b/sdk/radicalbit_platform_sdk/models/supported_types.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class SupportedTypes(str, Enum): + string = 'string' + int = 'int' + float = 'float' + bool = 'bool' + datetime = 'datetime' diff --git a/sdk/tests/apis/model_test.py b/sdk/tests/apis/model_test.py index fd93c8b2..af9ce160 100644 --- a/sdk/tests/apis/model_test.py +++ b/sdk/tests/apis/model_test.py @@ -13,12 +13,14 @@ ColumnDefinition, CurrentFileUpload, DataType, + FieldType, Granularity, JobStatus, ModelDefinition, ModelType, OutputType, ReferenceFileUpload, + SupportedTypes, ) @@ -27,7 +29,9 @@ class ModelTest(unittest.TestCase): def test_delete_model(self): base_url = 'http://api:9000' model_id = uuid.uuid4() - column_def = ColumnDefinition(name='column', type='my_type') + column_def = ColumnDefinition( + name='column', type=SupportedTypes.string, field_type=FieldType.categorical + ) outputs = OutputType(prediction=column_def, output=[column_def]) model = Model( base_url, @@ -59,7 +63,9 @@ def test_load_reference_dataset_without_object_name(self): model_id = uuid.uuid4() bucket_name = 'test-bucket' file_name = 'test.txt' - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) expected_path = f's3://{bucket_name}/{model_id}/reference/{file_name}' conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) @@ -72,12 +78,28 @@ def test_load_reference_dataset_without_object_name(self): data_type=DataType.TABULAR, granularity=Granularity.HOUR, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -104,7 +126,9 @@ def test_load_reference_dataset_with_different_separator(self): model_id = uuid.uuid4() bucket_name = 'test-bucket' file_name = 'test.txt' - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) expected_path = f's3://{bucket_name}/{model_id}/reference/{file_name}' conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) @@ -117,12 +141,28 @@ def test_load_reference_dataset_with_different_separator(self): data_type=DataType.TABULAR, granularity=Granularity.DAY, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -149,7 +189,9 @@ def test_load_reference_dataset_with_object_name(self): model_id = uuid.uuid4() bucket_name = 'test-bucket' file_name = 'test.txt' - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) expected_path = f's3://{bucket_name}/{file_name}' conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) @@ -162,12 +204,28 @@ def test_load_reference_dataset_with_object_name(self): data_type=DataType.TABULAR, granularity=Granularity.WEEK, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -188,7 +246,9 @@ def test_load_reference_dataset_with_object_name(self): assert response.path() == expected_path def test_load_reference_dataset_wrong_headers(self): - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) model = Model( 'http://api:9000', ModelDefinition( @@ -198,12 +258,28 @@ def test_load_reference_dataset_wrong_headers(self): data_type=DataType.TABULAR, granularity=Granularity.MONTH, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -218,7 +294,9 @@ def test_load_current_dataset_without_object_name(self): model_id = uuid.uuid4() bucket_name = 'test-bucket' file_name = 'test.txt' - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) expected_path = f's3://{bucket_name}/{model_id}/reference/{file_name}' conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) @@ -231,12 +309,28 @@ def test_load_current_dataset_without_object_name(self): data_type=DataType.TABULAR, granularity=Granularity.DAY, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -269,7 +363,9 @@ def test_load_current_dataset_with_object_name(self): model_id = uuid.uuid4() bucket_name = 'test-bucket' file_name = 'test.txt' - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) expected_path = f's3://{bucket_name}/{file_name}' conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) @@ -282,12 +378,28 @@ def test_load_current_dataset_with_object_name(self): data_type=DataType.TABULAR, granularity=Granularity.HOUR, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -314,7 +426,9 @@ def test_load_current_dataset_with_object_name(self): assert response.path() == expected_path def test_load_current_dataset_wrong_headers(self): - column_def = ColumnDefinition(name='prediction', type='float') + column_def = ColumnDefinition( + name='prediction', type=SupportedTypes.float, field_type=FieldType.numerical + ) model = Model( 'http://api:9000', ModelDefinition( @@ -324,12 +438,28 @@ def test_load_current_dataset_wrong_headers(self): data_type=DataType.TABULAR, granularity=Granularity.MONTH, features=[ - ColumnDefinition(name='first_name', type='str'), - ColumnDefinition(name='age', type='int'), + ColumnDefinition( + name='first_name', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + ColumnDefinition( + name='age', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name='adult', type='bool'), - timestamp=ColumnDefinition(name='created_at', type='str'), + target=ColumnDefinition( + name='adult', + type=SupportedTypes.bool, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='created_at', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ), created_at=str(time.time()), updated_at=str(time.time()), ), diff --git a/sdk/tests/client_test.py b/sdk/tests/client_test.py index c07d28be..366afa2d 100644 --- a/sdk/tests/client_test.py +++ b/sdk/tests/client_test.py @@ -11,10 +11,12 @@ ColumnDefinition, CreateModel, DataType, + FieldType, Granularity, ModelDefinition, ModelType, OutputType, + SupportedTypes, ) @@ -32,12 +34,16 @@ def test_get_model(self): frameworks = 'mlflow' feature_name = 'age' feature_type = 'int' + feature_field_type = 'numerical' output_name = 'adult' output_type = 'bool' + output_field_type = 'categorical' target_name = 'adult' target_type = 'bool' + target_field_type = 'categorical' timestamp_name = 'when' - timestamp_type = 'str' + timestamp_type = 'datetime' + timestamp_field_type = 'datetime' ts = str(time.time()) json_string = f"""{{ "uuid": "{str(model_id)}", @@ -47,29 +53,35 @@ def test_get_model(self): "granularity": "{granularity.value}", "features": [{{ "name": "{feature_name}", - "type": "{feature_type}" + "type": "{feature_type}", + "fieldType": "{feature_field_type}" }}], "outputs": {{ "prediction": {{ "name": "{output_name}", - "type": "{output_type}" + "type": "{output_type}", + "fieldType": "{output_field_type}" }}, "predictionProba": {{ "name": "{output_name}", - "type": "{output_type}" + "type": "{output_type}", + "fieldType": "{output_field_type}" }}, "output": [{{ "name": "{output_name}", - "type": "{output_type}" + "type": "{output_type}", + "fieldType": "{output_field_type}" }}] }}, "target": {{ "name": "{target_name}", - "type": "{target_type}" + "type": "{target_type}", + "fieldType": "{target_field_type}" }}, "timestamp": {{ "name": "{timestamp_name}", - "type": "{timestamp_type}" + "type": "{timestamp_type}", + "fieldType": "{timestamp_field_type}" }}, "description": "{description}", "algorithm": "{algorithm}", @@ -95,19 +107,25 @@ def test_get_model(self): assert model.algorithm() == algorithm assert model.frameworks() == frameworks assert model.target().name == target_name - assert model.target().type == target_type + assert model.target().type == SupportedTypes.bool + assert model.target().field_type == FieldType.categorical assert model.timestamp().name == timestamp_name - assert model.timestamp().type == timestamp_type + assert model.timestamp().type == SupportedTypes.datetime + assert model.timestamp().field_type == FieldType.datetime assert len(model.features()) == 1 assert model.features()[0].name == feature_name - assert model.features()[0].type == feature_type + assert model.features()[0].type == SupportedTypes.int + assert model.features()[0].field_type == FieldType.numerical assert model.outputs().prediction.name == output_name - assert model.outputs().prediction.type == output_type + assert model.outputs().prediction.type == SupportedTypes.bool + assert model.outputs().prediction.field_type == FieldType.categorical assert model.outputs().prediction_proba.name == output_name - assert model.outputs().prediction_proba.type == output_type + assert model.outputs().prediction_proba.type == SupportedTypes.bool + assert model.outputs().prediction_proba.field_type == FieldType.categorical assert len(model.outputs().output) == 1 assert model.outputs().output[0].name == output_name - assert model.outputs().output[0].type == output_type + assert model.outputs().output[0].type == SupportedTypes.bool + assert model.outputs().output[0].field_type == FieldType.categorical @responses.activate def test_get_model_client_error(self): @@ -133,13 +151,37 @@ def test_create_model(self): model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.WEEK, - features=[ColumnDefinition(name='feature_column', type='string')], + features=[ + ColumnDefinition( + name='feature_column', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], outputs=OutputType( - prediction=ColumnDefinition(name='result_column', type='int'), - output=[ColumnDefinition(name='result_column', type='int')], + prediction=ColumnDefinition( + name='result_column', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), + output=[ + ColumnDefinition( + name='result_column', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ) + ], + ), + target=ColumnDefinition( + name='target_column', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='tst_column', + type=SupportedTypes.string, + field_type=FieldType.categorical, ), - target=ColumnDefinition(name='target_column', type='string'), - timestamp=ColumnDefinition(name='tst_column', type='string'), ) model_definition = ModelDefinition( name=model.name, @@ -184,13 +226,37 @@ def test_search_models(self): model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.DAY, - features=[ColumnDefinition(name='feature_column', type='string')], + features=[ + ColumnDefinition( + name='feature_column', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], outputs=OutputType( - prediction=ColumnDefinition(name='result_column', type='int'), - output=[ColumnDefinition(name='result_column', type='int')], + prediction=ColumnDefinition( + name='result_column', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ), + output=[ + ColumnDefinition( + name='result_column', + type=SupportedTypes.int, + field_type=FieldType.numerical, + ) + ], + ), + target=ColumnDefinition( + name='target_column', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ), + timestamp=ColumnDefinition( + name='tst_column', + type=SupportedTypes.string, + field_type=FieldType.categorical, ), - target=ColumnDefinition(name='target_column', type='string'), - timestamp=ColumnDefinition(name='tst_column', type='string'), created_at=str(time.time()), updated_at=str(time.time()), ) diff --git a/sdk/tests/models/column_definition_test.py b/sdk/tests/models/column_definition_test.py index 9da442cf..040a5a3f 100644 --- a/sdk/tests/models/column_definition_test.py +++ b/sdk/tests/models/column_definition_test.py @@ -1,14 +1,18 @@ import json import unittest -from radicalbit_platform_sdk.models import ColumnDefinition +from radicalbit_platform_sdk.models import ColumnDefinition, FieldType, SupportedTypes class ColumnDefinitionTest(unittest.TestCase): def test_from_json(self): - field_name = 'age' - field_type = 'int' - json_string = f'{{"name": "{field_name}", "type": "{field_type}"}}' + name = 'age' + type = 'int' + field_type = 'numerical' + json_string = ( + f'{{"name": "{name}", "type": "{type}", "fieldType": "{field_type}"}}' + ) column_definition = ColumnDefinition.model_validate(json.loads(json_string)) - assert column_definition.name == field_name - assert column_definition.type == field_type + assert column_definition.name == name + assert column_definition.type == SupportedTypes.int + assert column_definition.field_type == FieldType.numerical diff --git a/sdk/tests/models/model_definition_test.py b/sdk/tests/models/model_definition_test.py index c7a9a48b..4917f152 100644 --- a/sdk/tests/models/model_definition_test.py +++ b/sdk/tests/models/model_definition_test.py @@ -5,9 +5,11 @@ from radicalbit_platform_sdk.models import ( DataType, + FieldType, Granularity, ModelDefinition, ModelType, + SupportedTypes, ) @@ -23,12 +25,16 @@ def test_model_definition_from_json(self): frameworks = 'mlflow' feature_name = 'age' feature_type = 'int' + feature_field_type = 'numerical' output_name = 'adult' output_type = 'bool' + output_field_type = 'categorical' target_name = 'adult' target_type = 'bool' + target_field_type = 'categorical' timestamp_name = 'when' - timestamp_type = 'str' + timestamp_type = 'datetime' + timestamp_field_type = 'datetime' ts = str(time.time()) json_string = f"""{{ "uuid": "{str(id)}", @@ -38,29 +44,35 @@ def test_model_definition_from_json(self): "granularity": "{granularity.value}", "features": [{{ "name": "{feature_name}", - "type": "{feature_type}" + "type": "{feature_type}", + "fieldType": "{feature_field_type}" }}], "outputs": {{ "prediction": {{ "name": "{output_name}", - "type": "{output_type}" + "type": "{output_type}", + "fieldType": "{output_field_type}" }}, "predictionProba": {{ "name": "{output_name}", - "type": "{output_type}" + "type": "{output_type}", + "fieldType": "{output_field_type}" }}, "output": [{{ "name": "{output_name}", - "type": "{output_type}" + "type": "{output_type}", + "fieldType": "{output_field_type}" }}] }}, "target": {{ "name": "{target_name}", - "type": "{target_type}" + "type": "{target_type}", + "fieldType": "{target_field_type}" }}, "timestamp": {{ "name": "{timestamp_name}", - "type": "{timestamp_type}" + "type": "{timestamp_type}", + "fieldType": "{timestamp_field_type}" }}, "description": "{description}", "algorithm": "{algorithm}", @@ -81,15 +93,24 @@ def test_model_definition_from_json(self): assert model_definition.updated_at == ts assert len(model_definition.features) == 1 assert model_definition.features[0].name == feature_name - assert model_definition.features[0].type == feature_type + assert model_definition.features[0].type == SupportedTypes.int + assert model_definition.features[0].field_type == FieldType.numerical assert model_definition.outputs.prediction.name == output_name - assert model_definition.outputs.prediction.type == output_type + assert model_definition.outputs.prediction.type == SupportedTypes.bool + assert model_definition.outputs.prediction.field_type == FieldType.categorical assert model_definition.outputs.prediction_proba.name == output_name - assert model_definition.outputs.prediction_proba.type == output_type + assert model_definition.outputs.prediction_proba.type == SupportedTypes.bool + assert ( + model_definition.outputs.prediction_proba.field_type + == FieldType.categorical + ) assert len(model_definition.outputs.output) == 1 assert model_definition.outputs.output[0].name == output_name - assert model_definition.outputs.output[0].type == output_type + assert model_definition.outputs.output[0].type == SupportedTypes.bool + assert model_definition.outputs.output[0].field_type == FieldType.categorical assert model_definition.target.name == target_name - assert model_definition.target.type == target_type + assert model_definition.target.type == SupportedTypes.bool + assert model_definition.target.field_type == FieldType.categorical assert model_definition.timestamp.name == timestamp_name - assert model_definition.timestamp.type == timestamp_type + assert model_definition.timestamp.type == SupportedTypes.datetime + assert model_definition.timestamp.field_type == FieldType.datetime