From f91f2e9d9ea704178a08a0288bc02b00db483751 Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Fri, 6 Dec 2024 17:01:46 +0100 Subject: [PATCH 1/7] feat: add text-generation as new model type, handle the new model type, set schema fields as optional, edit test --- api/app/models/model_dto.py | 57 +++++++++++++----- api/tests/commons/db_mock.py | 20 ++++--- api/tests/commons/modelin_factory.py | 58 +++++++++++++++++++ api/tests/services/model_service_test.py | 23 +++++++- .../validation/model_type_validator_test.py | 34 ++++++++++- 5 files changed, 168 insertions(+), 24 deletions(-) diff --git a/api/app/models/model_dto.py b/api/app/models/model_dto.py index 0d174552..c56b7a5d 100644 --- a/api/app/models/model_dto.py +++ b/api/app/models/model_dto.py @@ -19,6 +19,7 @@ class ModelType(str, Enum): REGRESSION = 'REGRESSION' BINARY = 'BINARY' MULTI_CLASS = 'MULTI_CLASS' + TEXT_GENERATION = 'TEXT_GENERATION' class DataType(str, Enum): @@ -93,10 +94,10 @@ class ModelIn(BaseModel, validate_assignment=True): model_type: ModelType data_type: DataType granularity: Granularity - features: List[ColumnDefinition] - outputs: OutputType - target: ColumnDefinition - timestamp: ColumnDefinition + features: Optional[List[ColumnDefinition]] = None + outputs: Optional[OutputType] = None + target: Optional[ColumnDefinition] = None + timestamp: Optional[ColumnDefinition] = None frameworks: Optional[str] = None algorithm: Optional[str] = None @@ -104,9 +105,29 @@ class ModelIn(BaseModel, validate_assignment=True): populate_by_name=True, alias_generator=to_camel, protected_namespaces=() ) + @model_validator(mode='after') + def validate_fields(self) -> Self: + checked_model_type = self.model_type + if checked_model_type == ModelType.TEXT_GENERATION: + if any([self.target, self.features, self.outputs, self.timestamp]): + raise ValueError( + f'target, features, outputs and timestamp must not be provided for a {checked_model_type}' + ) + return self + if not self.features: + raise ValueError(f'features must be provided for a {checked_model_type}') + if not self.outputs: + raise ValueError(f'outputs must be provided for a {checked_model_type}') + if not self.target: + raise ValueError(f'target must be provided for a {checked_model_type}') + if not self.timestamp: + raise ValueError(f'timestamp must be provided for a {checked_model_type}') + + return self + @model_validator(mode='after') def validate_target(self) -> Self: - checked_model_type: ModelType = self.model_type + checked_model_type = self.model_type match checked_model_type: case ModelType.BINARY: if not is_number(self.target.type): @@ -126,12 +147,14 @@ def validate_target(self) -> Self: f'target must be a number for a {checked_model_type}, has been provided [{self.target}]' ) return self + case ModelType.TEXT_GENERATION: + return self case _: raise ValueError('not supported type for model_type') @model_validator(mode='after') def validate_outputs(self) -> Self: - checked_model_type: ModelType = self.model_type + checked_model_type = self.model_type match checked_model_type: case ModelType.BINARY: if not is_number(self.outputs.prediction.type): @@ -169,11 +192,15 @@ def validate_outputs(self) -> Self: f'prediction_proba must be None for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]' ) return self + case ModelType.TEXT_GENERATION: + return self case _: raise ValueError('not supported type for model_type') @model_validator(mode='after') def timestamp_must_be_datetime(self) -> Self: + if self.model_type == ModelType.TEXT_GENERATION: 
+ return self if not self.timestamp.type == SupportedTypes.datetime: raise ValueError('timestamp must be a datetime') return self @@ -187,10 +214,12 @@ def to_model(self) -> Model: model_type=self.model_type.value, data_type=self.data_type.value, granularity=self.granularity.value, - features=[feature.to_dict() for feature in self.features], - outputs=self.outputs.to_dict(), - target=self.target.to_dict(), - timestamp=self.timestamp.to_dict(), + features=[feature.to_dict() for feature in self.features] + if self.features + else None, + outputs=self.outputs.to_dict() if self.outputs else None, + target=self.target.to_dict() if self.target else None, + timestamp=self.timestamp.to_dict() if self.timestamp else None, frameworks=self.frameworks, algorithm=self.algorithm, created_at=now, @@ -205,10 +234,10 @@ class ModelOut(BaseModel): model_type: ModelType data_type: DataType granularity: Granularity - features: List[ColumnDefinition] - outputs: OutputType - target: ColumnDefinition - timestamp: ColumnDefinition + features: Optional[List[ColumnDefinition]] + outputs: Optional[OutputType] + target: Optional[ColumnDefinition] + timestamp: Optional[ColumnDefinition] frameworks: Optional[str] algorithm: Optional[str] created_at: str diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index e6b02d30..ba675605 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -33,10 +33,10 @@ def get_sample_model( model_type: str = ModelType.BINARY.value, data_type: str = DataType.TEXT.value, granularity: str = Granularity.DAY.value, - features: List[Dict] = [ + features: Optional[List[Dict]] = [ {'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'} ], - outputs: Dict = { + outputs: Optional[Dict] = { 'prediction': {'name': 'pred1', 'type': 'int', 'fieldType': 'numerical'}, 'prediction_proba': { 'name': 'prob1', @@ -45,8 +45,12 @@ def get_sample_model( }, 'output': [{'name': 'output1', 'type': 'string', 'fieldType': 'categorical'}], }, - target: Dict = {'name': 'target1', 'type': 'string', 'fieldType': 'categorical'}, - timestamp: Dict = { + target: Optional[Dict] = { + 'name': 'target1', + 'type': 'string', + 'fieldType': 'categorical', + }, + timestamp: Optional[Dict] = { 'name': 'timestamp', 'type': 'datetime', 'fieldType': 'datetime', @@ -91,14 +95,14 @@ def get_sample_model_in( model_type: str = ModelType.BINARY.value, data_type: str = DataType.TEXT.value, granularity: str = Granularity.DAY.value, - features: List[ColumnDefinition] = [ + features: Optional[List[ColumnDefinition]] = [ ColumnDefinition( name='feature1', type=SupportedTypes.string, field_type=FieldType.categorical, ) ], - outputs: OutputType = OutputType( + outputs: Optional[OutputType] = OutputType( prediction=ColumnDefinition( name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical ), @@ -113,10 +117,10 @@ def get_sample_model_in( ) ], ), - target: ColumnDefinition = ColumnDefinition( + target: Optional[ColumnDefinition] = ColumnDefinition( name='target1', type=SupportedTypes.int, field_type=FieldType.numerical ), - timestamp: ColumnDefinition = ColumnDefinition( + timestamp: Optional[ColumnDefinition] = ColumnDefinition( name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime ), frameworks: Optional[str] = None, diff --git a/api/tests/commons/modelin_factory.py b/api/tests/commons/modelin_factory.py index f2fc7091..145b5e4d 100644 --- a/api/tests/commons/modelin_factory.py +++ b/api/tests/commons/modelin_factory.py @@ -25,6 +25,64 @@ def 
get_model_sample_wrong(fail_fields: List[str], model_type: ModelType): name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime ) + if model_type == ModelType.TEXT_GENERATION: + if 'features' in fail_fields: + features = [ + ColumnDefinition( + name='feature1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ] + else: + features = None + + if 'outputs' in fail_fields: + outputs = OutputType( + prediction=prediction, + prediction_proba=prediction_proba, + output=[ + ColumnDefinition( + name='output1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], + ) + else: + outputs = None + + if 'target' in fail_fields: + target = ColumnDefinition( + name='target1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + else: + target = None + + if 'timestamp' in fail_fields: + timestamp = ColumnDefinition( + name='timestamp', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + else: + timestamp = None + + return { + 'name': 'text_generation_model', + 'model_type': model_type, + 'data_type': DataType.TEXT, + 'granularity': Granularity.DAY, + 'features': features, + 'outputs': outputs, + 'target': target, + 'timestamp': timestamp, + 'frameworks': None, + 'algorithm': None, + } + if 'outputs.prediction' in fail_fields: if model_type == ModelType.BINARY: prediction = ColumnDefinition( diff --git a/api/tests/services/model_service_test.py b/api/tests/services/model_service_test.py index 7463b029..60f196e7 100644 --- a/api/tests/services/model_service_test.py +++ b/api/tests/services/model_service_test.py @@ -11,7 +11,7 @@ from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO from app.models.alert_dto import AnomalyType from app.models.exceptions import ModelError, ModelNotFoundError -from app.models.model_dto import ModelOut +from app.models.model_dto import ModelOut, ModelType from app.models.model_order import OrderType from app.services.model_service import ModelService from tests.commons import db_mock @@ -42,6 +42,27 @@ def test_create_model_ok(self): assert res == ModelOut.from_model(model) + def test_create_text_generation_model_ok(self): + model = db_mock.get_sample_model( + model_type=ModelType.TEXT_GENERATION, + features=None, + target=None, + outputs=None, + timestamp=None, + ) + self.model_dao.insert = MagicMock(return_value=model) + model_in = db_mock.get_sample_model_in( + model_type=ModelType.TEXT_GENERATION, + features=None, + target=None, + outputs=None, + timestamp=None, + ) + res = self.model_service.create_model(model_in) + self.model_dao.insert.assert_called_once() + + assert res == ModelOut.from_model(model) + def test_get_model_by_uuid_ok(self): model = db_mock.get_sample_model() reference_dataset = db_mock.get_sample_reference_dataset(model_uuid=model.uuid) diff --git a/api/tests/validation/model_type_validator_test.py b/api/tests/validation/model_type_validator_test.py index add2bb89..cbec9ad3 100644 --- a/api/tests/validation/model_type_validator_test.py +++ b/api/tests/validation/model_type_validator_test.py @@ -1,7 +1,7 @@ from pydantic import ValidationError import pytest -from app.models.model_dto import ModelIn, ModelType +from app.models.model_dto import ModelIn, ModelType, DataType, Granularity from tests.commons.modelin_factory import get_model_sample_wrong @@ -108,3 +108,35 @@ def test_prediction_proba_for_regression(): assert 'prediction_proba must be None for a ModelType.REGRESSION' in str( excinfo.value ) + + +def 
test_text_generation_invalid_fields_provided(): + """Tests that TEXT_GENERATION fails if features, outputs, target, or timestamp are provided.""" + with pytest.raises(ValidationError) as excinfo: + model_data = get_model_sample_wrong( + fail_fields=['features', 'outputs', 'target', 'timestamp'], + model_type=ModelType.TEXT_GENERATION, + ) + ModelIn.model_validate(ModelIn(**model_data)) + assert ( + 'target, features, outputs and timestamp must not be provided for a ModelType.TEXT_GENERATION' + in str(excinfo.value) + ) + + +def test_text_generation_valid(): + """Tests that TEXT_GENERATION passes validation with no schema fields.""" + model_data = { + 'name': 'text_generation_model', + 'model_type': ModelType.TEXT_GENERATION, + 'data_type': DataType.TEXT, + 'granularity': Granularity.DAY, + 'frameworks': 'transformer', + 'algorithm': 'gpt-like', + } + model = ModelIn.model_validate(ModelIn(**model_data)) + assert model.model_type == ModelType.TEXT_GENERATION + assert model.features is None + assert model.outputs is None + assert model.target is None + assert model.timestamp is None From cba35a5b3176a57b9d99dd170107ae7652361c60 Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 9 Dec 2024 09:41:10 +0100 Subject: [PATCH 2/7] fix: ruff check --- api/tests/validation/model_type_validator_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/tests/validation/model_type_validator_test.py b/api/tests/validation/model_type_validator_test.py index cbec9ad3..e2a2651d 100644 --- a/api/tests/validation/model_type_validator_test.py +++ b/api/tests/validation/model_type_validator_test.py @@ -1,7 +1,7 @@ from pydantic import ValidationError import pytest -from app.models.model_dto import ModelIn, ModelType, DataType, Granularity +from app.models.model_dto import DataType, Granularity, ModelIn, ModelType from tests.commons.modelin_factory import get_model_sample_wrong From f2e72a1f909952180c044c8b8ddbbd996252c67d Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 9 Dec 2024 12:09:43 +0100 Subject: [PATCH 3/7] feat: add optional schema fields to models definition (sdk) --- api/tests/services/model_service_test.py | 2 +- sdk/radicalbit_platform_sdk/apis/model.py | 8 ++-- .../models/model_definition.py | 8 ++-- .../models/model_type.py | 1 + sdk/tests/client_test.py | 44 +++++++++++++++++++ 5 files changed, 54 insertions(+), 9 deletions(-) diff --git a/api/tests/services/model_service_test.py b/api/tests/services/model_service_test.py index 60f196e7..40a4e674 100644 --- a/api/tests/services/model_service_test.py +++ b/api/tests/services/model_service_test.py @@ -42,7 +42,7 @@ def test_create_model_ok(self): assert res == ModelOut.from_model(model) - def test_create_text_generation_model_ok(self): + def test_create_model_with_empty_schema_ok(self): model = db_mock.get_sample_model( model_type=ModelType.TEXT_GENERATION, features=None, diff --git a/sdk/radicalbit_platform_sdk/apis/model.py b/sdk/radicalbit_platform_sdk/apis/model.py index 62088621..16b59a80 100644 --- a/sdk/radicalbit_platform_sdk/apis/model.py +++ b/sdk/radicalbit_platform_sdk/apis/model.py @@ -60,16 +60,16 @@ def data_type(self) -> DataType: def granularity(self) -> Granularity: return self.__granularity - def features(self) -> List[ColumnDefinition]: + def features(self) -> Optional[List[ColumnDefinition]]: return self.__features - def target(self) -> ColumnDefinition: + def target(self) -> Optional[ColumnDefinition]: return self.__target - def timestamp(self) -> ColumnDefinition: + def timestamp(self) -> 
Optional[ColumnDefinition]: return self.__timestamp - def outputs(self) -> OutputType: + def outputs(self) -> Optional[OutputType]: return self.__outputs def frameworks(self) -> Optional[str]: diff --git a/sdk/radicalbit_platform_sdk/models/model_definition.py b/sdk/radicalbit_platform_sdk/models/model_definition.py index 4a1f6e6a..723fe69e 100644 --- a/sdk/radicalbit_platform_sdk/models/model_definition.py +++ b/sdk/radicalbit_platform_sdk/models/model_definition.py @@ -54,10 +54,10 @@ class BaseModelDefinition(BaseModel): model_type: ModelType data_type: DataType granularity: Granularity - features: List[ColumnDefinition] - outputs: OutputType - target: ColumnDefinition - timestamp: ColumnDefinition + features: Optional[List[ColumnDefinition]] = None + outputs: Optional[OutputType] = None + target: Optional[ColumnDefinition] = None + timestamp: Optional[ColumnDefinition] = None frameworks: Optional[str] = None algorithm: Optional[str] = None diff --git a/sdk/radicalbit_platform_sdk/models/model_type.py b/sdk/radicalbit_platform_sdk/models/model_type.py index 48cdb051..a22d29db 100644 --- a/sdk/radicalbit_platform_sdk/models/model_type.py +++ b/sdk/radicalbit_platform_sdk/models/model_type.py @@ -5,3 +5,4 @@ class ModelType(str, Enum): REGRESSION = 'REGRESSION' BINARY = 'BINARY' MULTI_CLASS = 'MULTI_CLASS' + TEXT_GENERATION = 'TEXT_GENERATION' diff --git a/sdk/tests/client_test.py b/sdk/tests/client_test.py index 366afa2d..6f560b95 100644 --- a/sdk/tests/client_test.py +++ b/sdk/tests/client_test.py @@ -217,6 +217,50 @@ def test_create_model(self): assert model.algorithm() is None assert model.frameworks() is None + @responses.activate + def test_create_model_with_empty_schema(self): + base_url = 'http://api:9000' + model = CreateModel( + name='My Model', + model_type=ModelType.TEXT_GENERATION, + data_type=DataType.TEXT, + granularity=Granularity.DAY, + features=None, + outputs=None, + target=None, + timestamp=None + ) + + model_definition = ModelDefinition( + name=model.name, + model_type=model.model_type, + data_type=model.data_type, + granularity=model.granularity, + created_at=str(time.time()), + updated_at=str(time.time()), + ) + responses.add( + method=responses.POST, + url=f'{base_url}/api/models', + body=model_definition.model_dump_json(), + status=201, + content_type='application/json', + ) + + client = Client(base_url) + model = client.create_model(model) + assert model.name() == model_definition.name + assert model.model_type() == model_definition.model_type + assert model.data_type() == model_definition.data_type + assert model.granularity() == model_definition.granularity + assert model.features() is None + assert model.outputs() is None + assert model.target() is None + assert model.timestamp() is None + assert model.description() is None + assert model.algorithm() is None + assert model.frameworks() is None + @responses.activate def test_search_models(self): base_url = 'http://api:9000' From d7673bce0be7a521b11b2627a8be1cf7730fc7ea Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 9 Dec 2024 15:09:35 +0100 Subject: [PATCH 4/7] feat: set optional fields model schema (spark side) --- sdk/tests/client_test.py | 2 +- spark/jobs/utils/models.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sdk/tests/client_test.py b/sdk/tests/client_test.py index 6f560b95..161f8b9e 100644 --- a/sdk/tests/client_test.py +++ b/sdk/tests/client_test.py @@ -228,7 +228,7 @@ def test_create_model_with_empty_schema(self): features=None, outputs=None, target=None, - 
timestamp=None + timestamp=None, ) model_definition = ModelDefinition( diff --git a/spark/jobs/utils/models.py b/spark/jobs/utils/models.py index 6dde137b..4a852565 100644 --- a/spark/jobs/utils/models.py +++ b/spark/jobs/utils/models.py @@ -35,6 +35,7 @@ class ModelType(str, Enum): REGRESSION = "REGRESSION" BINARY = "BINARY" MULTI_CLASS = "MULTI_CLASS" + TEXT_GENERATION = "TEXT_GENERATION" class DataType(str, Enum): @@ -89,10 +90,10 @@ class ModelOut(BaseModel): model_type: ModelType data_type: DataType granularity: Granularity - features: List[ColumnDefinition] - outputs: OutputType - target: ColumnDefinition - timestamp: ColumnDefinition + features: Optional[List[ColumnDefinition]] + outputs: Optional[OutputType] + target: Optional[ColumnDefinition] + timestamp: Optional[ColumnDefinition] frameworks: Optional[str] algorithm: Optional[str] created_at: str From b66a0fae193f5e0998dd2fea1768527c783b9616 Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Wed, 11 Dec 2024 14:34:17 +0100 Subject: [PATCH 5/7] handle json response --- api/alembic/env.py | 2 + ...a4cc_add_dataset_and_metrics_completion.py | 46 +++++++ api/app/core/config/config.py | 1 + api/app/db/dao/completion_dataset_dao.py | 87 ++++++++++++ .../completion_dataset_metrics_table.py | 26 ++++ api/app/db/tables/completion_dataset_table.py | 25 ++++ api/app/main.py | 3 + api/app/models/completion_response.py | 75 +++++++++++ api/app/models/dataset_dto.py | 23 ++++ api/app/routes/upload_dataset_route.py | 38 ++++++ api/app/services/file_service.py | 119 +++++++++++++++++ api/tests/commons/db_mock.py | 17 +++ api/tests/commons/json_file_mock.py | 66 +++++++++ api/tests/dao/completion_dataset_dao_test.py | 126 ++++++++++++++++++ api/tests/routes/upload_dataset_route_test.py | 85 +++++++++++- api/tests/services/file_service_test.py | 121 ++++++++++++++++- 16 files changed, 856 insertions(+), 4 deletions(-) create mode 100644 api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py create mode 100644 api/app/db/dao/completion_dataset_dao.py create mode 100644 api/app/db/tables/completion_dataset_metrics_table.py create mode 100644 api/app/db/tables/completion_dataset_table.py create mode 100644 api/app/models/completion_response.py create mode 100644 api/tests/commons/json_file_mock.py create mode 100644 api/tests/dao/completion_dataset_dao_test.py diff --git a/api/alembic/env.py b/api/alembic/env.py index 67be389e..3a031fc7 100644 --- a/api/alembic/env.py +++ b/api/alembic/env.py @@ -11,6 +11,8 @@ from app.db.tables.reference_dataset_metrics_table import * from app.db.tables.current_dataset_table import * from app.db.tables.current_dataset_metrics_table import * +from app.db.tables.completion_dataset_table import * +from app.db.tables.completion_dataset_metrics_table import * from app.db.tables.commons.json_encoded_dict import JSONEncodedDict from app.db.database import Database, BaseTable diff --git a/api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py b/api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py new file mode 100644 index 00000000..4fc7c8d5 --- /dev/null +++ b/api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py @@ -0,0 +1,46 @@ +"""add_dataset_and_metrics_completion + +Revision ID: e72dc7aaa4cc +Revises: dccb82489f4d +Create Date: 2024-12-11 13:33:38.759485 + +""" +from typing import Sequence, Union, Text + +from alembic import op +import sqlalchemy as sa +from app.db.tables.commons.json_encoded_dict import JSONEncodedDict + +# revision identifiers, 
used by Alembic. +revision: str = 'e72dc7aaa4cc' +down_revision: Union[str, None] = 'dccb82489f4d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('completion_dataset', + sa.Column('UUID', sa.UUID(), nullable=False), + sa.Column('MODEL_UUID', sa.UUID(), nullable=False), + sa.Column('PATH', sa.VARCHAR(), nullable=False), + sa.Column('DATE', sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column('STATUS', sa.VARCHAR(), nullable=False), + sa.ForeignKeyConstraint(['MODEL_UUID'], ['model.UUID'], name=op.f('fk_completion_dataset_MODEL_UUID_model')), + sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset')) + ) + op.create_table('completion_dataset_metrics', + sa.Column('UUID', sa.UUID(), nullable=False), + sa.Column('COMPLETION_UUID', sa.UUID(), nullable=False), + sa.Column('MODEL_QUALITY', JSONEncodedDict(astext_type=Text()), nullable=True), + sa.ForeignKeyConstraint(['COMPLETION_UUID'], ['completion_dataset.UUID'], name=op.f('fk_completion_dataset_metrics_COMPLETION_UUID_completion_dataset')), + sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset_metrics')) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('completion_dataset_metrics') + op.drop_table('completion_dataset') + # ### end Alembic commands ### diff --git a/api/app/core/config/config.py b/api/app/core/config/config.py index 08236d71..e5cf2318 100644 --- a/api/app/core/config/config.py +++ b/api/app/core/config/config.py @@ -56,6 +56,7 @@ class SparkConfig(BaseSettings): spark_image_pull_policy: str = 'IfNotPresent' spark_reference_app_path: str = 'local:///opt/spark/custom_jobs/reference_job.py' spark_current_app_path: str = 'local:///opt/spark/custom_jobs/current_job.py' + spark_completion_app_path: str = 'local:///opt/spark/custom_jobs/completion_job.py' spark_namespace: str = 'spark' spark_service_account: str = 'spark' diff --git a/api/app/db/dao/completion_dataset_dao.py b/api/app/db/dao/completion_dataset_dao.py new file mode 100644 index 00000000..75d4e90b --- /dev/null +++ b/api/app/db/dao/completion_dataset_dao.py @@ -0,0 +1,87 @@ +import re +from typing import List, Optional +from uuid import UUID + +from fastapi_pagination import Page, Params +from fastapi_pagination.ext.sqlalchemy import paginate +from sqlalchemy import asc, desc +from sqlalchemy.future import select as future_select + +from app.db.database import Database +from app.db.tables.completion_dataset_table import CompletionDataset +from app.models.dataset_dto import OrderType + + +class CompletionDatasetDAO: + def __init__(self, database: Database) -> None: + self.db = database + + def insert_completion_dataset( + self, completion_dataset: CompletionDataset + ) -> CompletionDataset: + with self.db.begin_session() as session: + session.add(completion_dataset) + session.flush() + return completion_dataset + + def get_completion_dataset_by_model_uuid( + self, model_uuid: UUID, completion_uuid: UUID + ) -> Optional[CompletionDataset]: + with self.db.begin_session() as session: + return ( + session.query(CompletionDataset) + .where( + CompletionDataset.model_uuid == model_uuid, + CompletionDataset.uuid == completion_uuid, + ) + .one_or_none() + ) + + def get_latest_completion_dataset_by_model_uuid( + self, model_uuid: UUID + ) -> Optional[CompletionDataset]: + with 
self.db.begin_session() as session: + return ( + session.query(CompletionDataset) + .order_by(desc(CompletionDataset.date)) + .where(CompletionDataset.model_uuid == model_uuid) + .limit(1) + .one_or_none() + ) + + def get_all_completion_datasets_by_model_uuid( + self, + model_uuid: UUID, + ) -> List[CompletionDataset]: + with self.db.begin_session() as session: + return ( + session.query(CompletionDataset) + .order_by(desc(CompletionDataset.date)) + .where(CompletionDataset.model_uuid == model_uuid) + ) + + def get_all_completion_datasets_by_model_uuid_paginated( + self, + model_uuid: UUID, + params: Params = Params(), + order: OrderType = OrderType.ASC, + sort: Optional[str] = None, + ) -> Page[CompletionDataset]: + def order_by_column_name(column_name): + return CompletionDataset.__getattribute__( + CompletionDataset, re.sub('(?=[A-Z])', '_', column_name).lower() + ) + + with self.db.begin_session() as session: + stmt = future_select(CompletionDataset).where( + CompletionDataset.model_uuid == model_uuid + ) + + if sort: + stmt = ( + stmt.order_by(asc(order_by_column_name(sort))) + if order == OrderType.ASC + else stmt.order_by(desc(order_by_column_name(sort))) + ) + + return paginate(session, stmt, params) diff --git a/api/app/db/tables/completion_dataset_metrics_table.py b/api/app/db/tables/completion_dataset_metrics_table.py new file mode 100644 index 00000000..c9f8788e --- /dev/null +++ b/api/app/db/tables/completion_dataset_metrics_table.py @@ -0,0 +1,26 @@ +from uuid import uuid4 + +from sqlalchemy import UUID, Column, ForeignKey + +from app.db.dao.base_dao import BaseDAO +from app.db.database import BaseTable, Reflected +from app.db.tables.commons.json_encoded_dict import JSONEncodedDict + + +class CompletionDatasetMetrics(Reflected, BaseTable, BaseDAO): + __tablename__ = 'completion_dataset_metrics' + + uuid = Column( + 'UUID', + UUID(as_uuid=True), + nullable=False, + default=uuid4, + primary_key=True, + ) + completion_uuid = Column( + 'COMPLETION_UUID', + UUID(as_uuid=True), + ForeignKey('completion_dataset.UUID'), + nullable=False, + ) + model_quality = Column('MODEL_QUALITY', JSONEncodedDict, nullable=True) diff --git a/api/app/db/tables/completion_dataset_table.py b/api/app/db/tables/completion_dataset_table.py new file mode 100644 index 00000000..ddea2e7b --- /dev/null +++ b/api/app/db/tables/completion_dataset_table.py @@ -0,0 +1,25 @@ +from uuid import uuid4 + +from sqlalchemy import TIMESTAMP, UUID, VARCHAR, Column, ForeignKey + +from app.db.dao.base_dao import BaseDAO +from app.db.database import BaseTable, Reflected +from app.models.job_status import JobStatus + + +class CompletionDataset(Reflected, BaseTable, BaseDAO): + __tablename__ = 'completion_dataset' + + uuid = Column( + 'UUID', + UUID(as_uuid=True), + nullable=False, + default=uuid4, + primary_key=True, + ) + model_uuid = Column( + 'MODEL_UUID', UUID(as_uuid=True), ForeignKey('model.UUID'), nullable=False + ) + path = Column('PATH', VARCHAR, nullable=False) + date = Column('DATE', TIMESTAMP(timezone=True), nullable=False) + status = Column('STATUS', VARCHAR, nullable=False, default=JobStatus.IMPORTING) diff --git a/api/app/main.py b/api/app/main.py index b40f4914..0d86e44d 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -10,6 +10,7 @@ from starlette.middleware.cors import CORSMiddleware from app.core import get_config +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.current_dataset_metrics_dao import 
CurrentDatasetMetricsDAO from app.db.dao.model_dao import ModelDAO @@ -54,6 +55,7 @@ reference_dataset_metrics_dao = ReferenceDatasetMetricsDAO(database) current_dataset_dao = CurrentDatasetDAO(database) current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database) +completion_dataset_dao = CompletionDatasetDAO(database) model_service = ModelService( model_dao=model_dao, @@ -81,6 +83,7 @@ file_service = FileService( reference_dataset_dao, current_dataset_dao, + completion_dataset_dao, model_service, s3_client, spark_k8s_client, diff --git a/api/app/models/completion_response.py b/api/app/models/completion_response.py new file mode 100644 index 00000000..f5c59c1e --- /dev/null +++ b/api/app/models/completion_response.py @@ -0,0 +1,75 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, model_validator + + +class TokenLogProbs(BaseModel): + token: str + bytes: List[int] + logprob: float + top_logprobs: List[Dict[str, float]] + + +class LogProbs(BaseModel): + content: List[TokenLogProbs] + refusal: Optional[str] = None + + +class Message(BaseModel): + content: str + refusal: Optional[str] = None + role: str + tool_calls: List = [] + parsed: Optional[dict] = None + + +class Choice(BaseModel): + finish_reason: str + index: int + logprobs: Optional[LogProbs] = None + message: Message + + +class UsageDetails(BaseModel): + accepted_prediction_tokens: int = 0 + reasoning_tokens: int = 0 + rejected_prediction_tokens: int = 0 + audio_tokens: Optional[int] = None + cached_tokens: Optional[int] = None + + +class Usage(BaseModel): + completion_tokens: int + prompt_tokens: int + total_tokens: int + completion_tokens_details: UsageDetails + prompt_tokens_details: Optional[UsageDetails] = None + + +class Response(BaseModel): + id: str + choices: List[Choice] + created: int + model: str + object: str + system_fingerprint: str + usage: Usage + + +class CompletionResponses(BaseModel): + responses: List[Response] + + @model_validator(mode='before') + @classmethod + def handle_single_response(cls, values): + if 'responses' not in values: + return {'responses': [values]} + return values + + @model_validator(mode='after') + def validate_responses_non_empty(self): + if not self.responses or len(self.responses) == 0: + raise ValueError( + "The 'responses' array must contain at least one response." 
+ ) + return self diff --git a/api/app/models/dataset_dto.py b/api/app/models/dataset_dto.py index 8f6feb4e..77c3094f 100644 --- a/api/app/models/dataset_dto.py +++ b/api/app/models/dataset_dto.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_camel +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_table import ReferenceDataset @@ -55,6 +56,28 @@ def from_current_dataset(cd: CurrentDataset) -> 'CurrentDatasetDTO': ) +class CompletionDatasetDTO(BaseModel): + uuid: UUID + model_uuid: UUID + path: str + date: str + status: str + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + @staticmethod + def from_completion_dataset(cd: CompletionDataset) -> 'CompletionDatasetDTO': + return CompletionDatasetDTO( + uuid=cd.uuid, + model_uuid=cd.model_uuid, + path=cd.path, + date=cd.date.isoformat(), + status=cd.status, + ) + + class FileReference(BaseModel): file_url: str separator: str = ',' diff --git a/api/app/routes/upload_dataset_route.py b/api/app/routes/upload_dataset_route.py index bd906f48..d7ef7998 100644 --- a/api/app/routes/upload_dataset_route.py +++ b/api/app/routes/upload_dataset_route.py @@ -6,6 +6,7 @@ from fastapi_pagination import Page, Params from app.models.dataset_dto import ( + CompletionDatasetDTO, CurrentDatasetDTO, FileReference, OrderType, @@ -64,6 +65,16 @@ def bind_current_file( ) -> CurrentDatasetDTO: return file_service.bind_current_file(model_uuid, file_ref) + @router.post( + '/{model_uuid}/completion/upload', + status_code=status.HTTP_200_OK, + response_model=CompletionDatasetDTO, + ) + def upload_completion_file( + model_uuid: UUID, json_file: UploadFile = File(...) 
+ ) -> CompletionDatasetDTO: + return file_service.upload_completion_file(model_uuid, json_file) + @router.get( '/{model_uuid}/reference', status_code=200, @@ -118,4 +129,31 @@ def get_all_current_datasets_by_model_uuid( ): return file_service.get_all_current_datasets_by_model_uuid(model_uuid) + @router.get( + '/{model_uuid}/completion', + status_code=200, + response_model=Page[CompletionDatasetDTO], + ) + def get_all_completion_datasets_by_model_uuid_paginated( + model_uuid: UUID, + _page: Annotated[int, Query()] = 1, + _limit: Annotated[int, Query()] = 50, + _order: Annotated[OrderType, Query()] = OrderType.ASC, + _sort: Annotated[Optional[str], Query()] = None, + ): + params = Params(page=_page, size=_limit) + return file_service.get_all_completion_datasets_by_model_uuid_paginated( + model_uuid, params=params, order=_order, sort=_sort + ) + + @router.get( + '/{model_uuid}/completion/all', + status_code=200, + response_model=List[CompletionDatasetDTO], + ) + def get_all_completion_datasets_by_model_uuid( + model_uuid: UUID, + ): + return file_service.get_all_completion_datasets_by_model_uuid(model_uuid) + return router diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py index e2c78707..23e07023 100644 --- a/api/app/services/file_service.py +++ b/api/app/services/file_service.py @@ -1,5 +1,6 @@ from copy import deepcopy import datetime +import json import logging import pathlib from typing import List, Optional @@ -10,17 +11,23 @@ from fastapi import HTTPException, UploadFile from fastapi_pagination import Page, Params import pandas as pd +from pydantic import ValidationError from spark_on_k8s.client import ExecutorInstances, PodResources, SparkOnK8S from spark_on_k8s.utils.configuration import Configuration from app.core.config.config import create_secrets, get_config +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO +from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics from app.db.tables.reference_dataset_table import ReferenceDataset +from app.models.completion_response import CompletionResponses from app.models.dataset_dto import ( + CompletionDatasetDTO, CurrentDatasetDTO, FileReference, OrderType, @@ -48,12 +55,14 @@ def __init__( self, reference_dataset_dao: ReferenceDatasetDAO, current_dataset_dao: CurrentDatasetDAO, + completion_dataset_dao: CompletionDatasetDAO, model_service: ModelService, s3_client: boto3.client, spark_k8s_client: SparkOnK8S, ) -> 'FileService': self.rd_dao = reference_dataset_dao self.cd_dao = current_dataset_dao + self.completion_dataset_dao = completion_dataset_dao self.model_svc = model_service self.s3_client = s3_client s3_config = get_config().s3_config @@ -321,6 +330,69 @@ def bind_current_file( except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e + def upload_completion_file( + self, + model_uuid: UUID, + json_file: UploadFile, + ) -> CompletionDatasetDTO: + model_out = self.model_svc.get_model_by_uuid(model_uuid) + if not model_out: + logger.error('Model %s not found', model_uuid) + raise ModelNotFoundError(f'Model {model_uuid} not 
found') + + self.validate_json_file(json_file) + _f_name = json_file.filename + _f_uuid = uuid4() + try: + object_name = f'{str(model_out.uuid)}/completion/{_f_uuid}/{_f_name}' + self.s3_client.upload_fileobj( + json_file.file, + self.bucket_name, + object_name, + ExtraArgs={ + 'Metadata': { + 'model_uuid': str(model_out.uuid), + 'model_name': model_out.name, + 'file_type': 'completion', + } + }, + ) + + path = f's3://{self.bucket_name}/{object_name}' + + inserted_file = self.completion_dataset_dao.insert_completion_dataset( + CompletionDataset( + uuid=_f_uuid, + model_uuid=model_uuid, + path=path, + date=datetime.datetime.now(tz=datetime.UTC), + status=JobStatus.IMPORTING, + ) + ) + + logger.debug('File %s has been correctly stored in the db', inserted_file) + + spark_config = get_config().spark_config + self.__submit_app( + app_name=str(model_out.uuid), + app_path=spark_config.spark_completion_app_path, + app_arguments=[ + model_out.model_dump_json(), + path.replace('s3://', 's3a://'), + str(inserted_file.uuid), + CompletionDatasetMetrics.__tablename__, + ], + ) + + return CompletionDatasetDTO.from_completion_dataset(inserted_file) + + except NoCredentialsError as nce: + raise HTTPException( + status_code=500, detail='S3 credentials not available' + ) from nce + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + def get_all_reference_datasets_by_model_uuid_paginated( self, model_uuid: UUID, @@ -382,6 +454,40 @@ def get_all_current_datasets_by_model_uuid( for current_dataset in currents ] + def get_all_completion_datasets_by_model_uuid_paginated( + self, + model_uuid: UUID, + params: Params = Params(), + order: OrderType = OrderType.ASC, + sort: Optional[str] = None, + ) -> Page[CompletionDatasetDTO]: + results: Page[CompletionDatasetDTO] = ( + self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid_paginated( + model_uuid, params=params, order=order, sort=sort + ) + ) + + _items = [ + CompletionDatasetDTO.from_completion_dataset(completion_dataset) + for completion_dataset in results.items + ] + + return Page.create(items=_items, params=params, total=results.total) + + def get_all_completion_datasets_by_model_uuid( + self, + model_uuid: UUID, + ) -> List[CompletionDatasetDTO]: + completions = ( + self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid( + model_uuid + ) + ) + return [ + CompletionDatasetDTO.from_completion_dataset(completion_dataset) + for completion_dataset in completions + ] + @staticmethod def infer_schema(csv_file: UploadFile, sep: str = ',') -> InferredSchemaDTO: FileService.validate_file(csv_file, sep) @@ -449,6 +555,19 @@ def validate_file( csv_file.file.flush() csv_file.file.seek(0) + @staticmethod + def validate_json_file(json_file: UploadFile) -> None: + try: + content = json_file.file.read().decode('utf-8') + data = json.loads(content) + CompletionResponses(**data) + except ValidationError as e: + logger.error('Invalid json file: %s', str(e)) + raise InvalidFileException(f'Invalid json file: {str(e)}') from e + except Exception as e: + logger.error('Error while reading the json file: %s', str(e)) + raise InvalidFileException(f'Invalid json file: {str(e)}') from e + def __submit_app( self, app_name: str, app_path: str, app_arguments: List[str] ) -> None: diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index ba675605..13f0bd0a 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional import uuid +from 
app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.model_table import Model @@ -23,6 +24,7 @@ MODEL_UUID = uuid.uuid4() REFERENCE_UUID = uuid.uuid4() CURRENT_UUID = uuid.uuid4() +COMPLETION_UUID = uuid.uuid4() def get_sample_model( @@ -172,6 +174,21 @@ def get_sample_current_dataset( ) +def get_sample_completion_dataset( + uuid: uuid.UUID = COMPLETION_UUID, + model_uuid: uuid.UUID = MODEL_UUID, + path: str = 'completion/json_file.json', + status: str = JobStatus.IMPORTING.value, +) -> CompletionDataset: + return CompletionDataset( + uuid=uuid, + model_uuid=model_uuid, + path=path, + date=datetime.datetime.now(tz=datetime.UTC), + status=status, + ) + + statistics_dict = { 'nVariables': 10, 'nObservations': 1000, diff --git a/api/tests/commons/json_file_mock.py b/api/tests/commons/json_file_mock.py new file mode 100644 index 00000000..a9a38747 --- /dev/null +++ b/api/tests/commons/json_file_mock.py @@ -0,0 +1,66 @@ +from io import BytesIO + +from fastapi import UploadFile + + +def get_completion_sample_json_file() -> UploadFile: + json_content = """ + { + "id": "chatcmpl-0120", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "token": "Sky", + "bytes": [83, 107, 121], + "logprob": -1.2830728, + "top_logprobs": [] + } + ], + "refusal": null + }, + "message": { + "content": "Sky is blue.", + "refusal": null, + "role": "assistant", + "tool_calls": [], + "parsed": null + } + } + ], + "created": 1733486708, + "model": "gpt-4o-2024-08-06", + "object": "chat.completion", + "system_fingerprint": "fp_c7ca0ebaca", + "usage": { + "completion_tokens": 4, + "prompt_tokens": 25, + "total_tokens": 29, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + } + } + }""" + return UploadFile(filename='test.json', file=BytesIO(json_content.encode())) + + +def get_incorrect_sample_json_file() -> UploadFile: + json_content = """ + { + "id": "chatcmpl-0120", + "created": 1733486708, + "model": "gpt-4o-2024-08-06", + "object": "chat.completion", + "system_fingerprint": "fp_c7ca0ebaca", + }""" + return UploadFile(filename='test.json', file=BytesIO(json_content.encode())) diff --git a/api/tests/dao/completion_dataset_dao_test.py b/api/tests/dao/completion_dataset_dao_test.py new file mode 100644 index 00000000..dd40cce8 --- /dev/null +++ b/api/tests/dao/completion_dataset_dao_test.py @@ -0,0 +1,126 @@ +import datetime +from uuid import uuid4 + +from fastapi_pagination import Params + +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO +from app.db.dao.model_dao import ModelDAO +from app.db.tables.completion_dataset_table import CompletionDataset +from tests.commons import db_mock +from tests.commons.db_integration import DatabaseIntegration + + +class CompletionDatasetDAOTest(DatabaseIntegration): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.completion_dataset_dao = CompletionDatasetDAO(cls.db) + cls.model_dao = ModelDAO(cls.db) + + def test_insert_completion_dataset_upload_result(self): + model = self.model_dao.insert(db_mock.get_sample_model()) + to_insert = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + 
+ inserted = self.completion_dataset_dao.insert_completion_dataset(to_insert) + assert inserted == to_insert + + def test_get_current_dataset_by_model_uuid(self): + model = self.model_dao.insert(db_mock.get_sample_model()) + to_insert = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + + inserted = self.completion_dataset_dao.insert_completion_dataset(to_insert) + retrieved = self.completion_dataset_dao.get_completion_dataset_by_model_uuid( + inserted.model_uuid, inserted.uuid + ) + assert inserted.uuid == retrieved.uuid + assert inserted.model_uuid == retrieved.model_uuid + assert inserted.path == retrieved.path + + def test_get_latest_completion_dataset_by_model_uuid(self): + model = self.model_dao.insert(db_mock.get_sample_model()) + completion_one = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + + self.completion_dataset_dao.insert_completion_dataset(completion_one) + + completion_two = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + + inserted_two = self.completion_dataset_dao.insert_completion_dataset( + completion_two + ) + + retrieved = ( + self.completion_dataset_dao.get_latest_completion_dataset_by_model_uuid( + model.uuid + ) + ) + assert inserted_two.uuid == retrieved.uuid + assert inserted_two.model_uuid == retrieved.model_uuid + assert inserted_two.path == retrieved.path + + def test_get_all_completion_datasets_by_model_uuid_paginated(self): + model = self.model_dao.insert(db_mock.get_sample_model()) + completion_upload_1 = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + completion_upload_2 = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + completion_upload_3 = CompletionDataset( + uuid=uuid4(), + model_uuid=model.uuid, + path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + inserted_1 = self.completion_dataset_dao.insert_completion_dataset( + completion_upload_1 + ) + inserted_2 = self.completion_dataset_dao.insert_completion_dataset( + completion_upload_2 + ) + inserted_3 = self.completion_dataset_dao.insert_completion_dataset( + completion_upload_3 + ) + + retrieved = self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid_paginated( + model.uuid, Params(page=1, size=10) + ) + + assert inserted_1.uuid == retrieved.items[0].uuid + assert inserted_1.model_uuid == retrieved.items[0].model_uuid + assert inserted_1.path == retrieved.items[0].path + + assert inserted_2.uuid == retrieved.items[1].uuid + assert inserted_2.model_uuid == retrieved.items[1].model_uuid + assert inserted_2.path == retrieved.items[1].path + + assert inserted_3.uuid == retrieved.items[2].uuid + assert inserted_3.model_uuid == retrieved.items[2].model_uuid + assert inserted_3.path == retrieved.items[2].path + + assert len(retrieved.items) == 3 diff --git a/api/tests/routes/upload_dataset_route_test.py b/api/tests/routes/upload_dataset_route_test.py index 994f8438..27b21cab 100644 --- a/api/tests/routes/upload_dataset_route_test.py +++ b/api/tests/routes/upload_dataset_route_test.py @@ -9,6 +9,7 @@ from starlette.testclient import TestClient from app.models.dataset_dto import ( + CompletionDatasetDTO, CurrentDatasetDTO, 
FileReference, OrderType, @@ -17,7 +18,7 @@ from app.models.job_status import JobStatus from app.routes.upload_dataset_route import UploadDatasetRoute from app.services.file_service import FileService -from tests.commons import csv_file_mock as csv, db_mock +from tests.commons import csv_file_mock as csv, db_mock, json_file_mock as json class UploadDatasetRouteTest(unittest.TestCase): @@ -113,6 +114,26 @@ def test_bind_current(self): assert res.status_code == 200 assert jsonable_encoder(upload_file_result) == res.json() + def test_upload_completion(self): + file = json.get_completion_sample_json_file() + model_uuid = uuid.uuid4() + upload_file_result = CompletionDatasetDTO( + uuid=uuid.uuid4(), + model_uuid=model_uuid, + path='test', + date=str(datetime.datetime.now(tz=datetime.UTC)), + status=JobStatus.IMPORTING, + ) + self.file_service.upload_completion_file = MagicMock( + return_value=upload_file_result + ) + res = self.client.post( + f'{self.prefix}/{model_uuid}/completion/upload', + files={'json_file': (file.filename, file.file)}, + ) + assert res.status_code == 200 + assert jsonable_encoder(upload_file_result) == res.json() + def test_get_all_reference_datasets_by_model_uuid_paginated(self): test_model_uuid = uuid.uuid4() reference_upload_1 = db_mock.get_sample_reference_dataset( @@ -181,6 +202,40 @@ def test_get_all_current_datasets_by_model_uuid_paginated(self): sort=None, ) + def test_get_all_completion_datasets_by_model_uuid_paginated(self): + test_model_uuid = uuid.uuid4() + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_3.json' + ) + + sample_results = [ + CompletionDatasetDTO.from_completion_dataset(completion_upload_1), + CompletionDatasetDTO.from_completion_dataset(completion_upload_2), + CompletionDatasetDTO.from_completion_dataset(completion_upload_3), + ] + page = Page.create( + items=sample_results, total=len(sample_results), params=Params() + ) + self.file_service.get_all_completion_datasets_by_model_uuid_paginated = ( + MagicMock(return_value=page) + ) + + res = self.client.get(f'{self.prefix}/{test_model_uuid}/completion') + assert res.status_code == 200 + assert jsonable_encoder(page) == res.json() + self.file_service.get_all_completion_datasets_by_model_uuid_paginated.assert_called_once_with( + test_model_uuid, + params=Params(page=1, size=50), + order=OrderType.ASC, + sort=None, + ) + def test_get_all_reference_datasets_by_model_uuid(self): test_model_uuid = uuid.uuid4() reference_upload_1 = db_mock.get_sample_reference_dataset( @@ -236,3 +291,31 @@ def test_get_all_current_datasets_by_model_uuid(self): self.file_service.get_all_current_datasets_by_model_uuid.assert_called_once_with( test_model_uuid, ) + + def test_get_all_completion_datasets_by_model_uuid(self): + test_model_uuid = uuid.uuid4() + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_3.json' + ) + + sample_results = [ + 
CompletionDatasetDTO.from_completion_dataset(completion_upload_1), + CompletionDatasetDTO.from_completion_dataset(completion_upload_2), + CompletionDatasetDTO.from_completion_dataset(completion_upload_3), + ] + self.file_service.get_all_completion_datasets_by_model_uuid = MagicMock( + return_value=sample_results + ) + + res = self.client.get(f'{self.prefix}/{test_model_uuid}/completion/all') + assert res.status_code == 200 + assert jsonable_encoder(sample_results) == res.json() + self.file_service.get_all_completion_datasets_by_model_uuid.assert_called_once_with( + test_model_uuid, + ) diff --git a/api/tests/services/file_service_test.py b/api/tests/services/file_service_test.py index 96231650..23447f56 100644 --- a/api/tests/services/file_service_test.py +++ b/api/tests/services/file_service_test.py @@ -8,17 +8,24 @@ from fastapi_pagination import Page, Params import pytest +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_table import ReferenceDataset -from app.models.dataset_dto import CurrentDatasetDTO, FileReference, ReferenceDatasetDTO +from app.models.dataset_dto import ( + CompletionDatasetDTO, + CurrentDatasetDTO, + FileReference, + ReferenceDatasetDTO, +) from app.models.exceptions import InvalidFileException, ModelNotFoundError from app.models.job_status import JobStatus from app.models.model_dto import ModelOut from app.services.file_service import FileService from app.services.model_service import ModelService -from tests.commons import csv_file_mock as csv, db_mock +from tests.commons import csv_file_mock as csv, db_mock, json_file_mock as json from tests.commons.db_mock import get_sample_reference_dataset @@ -27,15 +34,22 @@ class FileServiceTest(unittest.TestCase): def setUpClass(cls): cls.rd_dao = MagicMock(spec_set=ReferenceDatasetDAO) cls.cd_dao = MagicMock(spec_set=CurrentDatasetDAO) + cls.completion_dataset_dao = MagicMock(spec_set=CompletionDatasetDAO) cls.model_svc = MagicMock(spec_set=ModelService) cls.s3_client = MagicMock() cls.spark_k8s_client = MagicMock() cls.files_service = FileService( - cls.rd_dao, cls.cd_dao, cls.model_svc, cls.s3_client, cls.spark_k8s_client + cls.rd_dao, + cls.cd_dao, + cls.completion_dataset_dao, + cls.model_svc, + cls.s3_client, + cls.spark_k8s_client, ) cls.mocks = [ cls.rd_dao, cls.cd_dao, + cls.completion_dataset_dao, cls.model_svc, cls.s3_client, cls.spark_k8s_client, @@ -67,6 +81,15 @@ def test_infer_schema_separator(self): schema = FileService.infer_schema(file, sep=';') assert schema == csv.correct_schema() + def test_validate_completion_json_file_ok(self): + json_file = json.get_completion_sample_json_file() + self.files_service.validate_json_file(json_file) + + def test_validate_completion_json_file_error(self): + json_file = json.get_incorrect_sample_json_file() + with pytest.raises(InvalidFileException): + self.files_service.validate_json_file(json_file) + def test_upload_reference_file_ok(self): file = csv.get_correct_sample_csv_file() dataset_uuid = uuid4() @@ -253,6 +276,44 @@ def test_upload_current_file_reference_file_not_found(self): correlation_id_column, ) + def test_upload_completion_file_ok(self): + file = json.get_completion_sample_json_file() + model = db_mock.get_sample_model( + features=None, + 
outputs=None, + target=None, + timestamp=None, + ) + object_name = f'{str(model.uuid)}/completion/{file.filename}' + path = f's3://bucket/{object_name}' + inserted_file = CompletionDataset( + uuid=uuid4(), + model_uuid=model_uuid, + path=path, + date=datetime.datetime.now(tz=datetime.UTC), + status=JobStatus.IMPORTING, + ) + + self.model_svc.get_model_by_uuid = MagicMock( + return_value=ModelOut.from_model(model) + ) + self.s3_client.upload_fileobj = MagicMock() + self.completion_dataset_dao.insert_completion_dataset = MagicMock( + return_value=inserted_file + ) + self.spark_k8s_client.submit_app = MagicMock() + + result = self.files_service.upload_completion_file( + model.uuid, + file, + ) + + self.model_svc.get_model_by_uuid.assert_called_once() + self.completion_dataset_dao.insert_completion_dataset.assert_called_once() + self.s3_client.upload_fileobj.assert_called_once() + self.spark_k8s_client.submit_app.assert_called_once() + assert result == CompletionDatasetDTO.from_completion_dataset(inserted_file) + def test_get_all_reference_datasets_by_model_uuid_paginated(self): reference_upload_1 = db_mock.get_sample_reference_dataset( model_uuid=model_uuid, path='reference/test_1.csv' @@ -311,6 +372,35 @@ def test_get_all_current_datasets_by_model_uuid_paginated(self): assert result.items[1].model_uuid == model_uuid assert result.items[2].model_uuid == model_uuid + def test_get_all_completion_datasets_by_model_uuid_paginated(self): + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_3.json' + ) + + sample_results = [completion_upload_1, completion_upload_2, completion_upload_3] + page = Page.create( + sample_results, total=len(sample_results), params=Params(page=1, size=10) + ) + self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid_paginated = MagicMock( + return_value=page + ) + + result = self.files_service.get_all_completion_datasets_by_model_uuid_paginated( + model_uuid, Params(page=1, size=10) + ) + + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0].model_uuid == model_uuid + assert result.items[1].model_uuid == model_uuid + assert result.items[2].model_uuid == model_uuid + def test_get_all_reference_datasets_by_model_uuid(self): reference_upload_1 = db_mock.get_sample_reference_dataset( model_uuid=model_uuid, path='reference/test_1.csv' @@ -357,5 +447,30 @@ def test_get_all_current_datasets_by_model_uuid(self): assert result[1].model_uuid == model_uuid assert result[2].model_uuid == model_uuid + def test_get_all_completion_datasets_by_model_uuid(self): + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_3.json' + ) + + sample_results = [completion_upload_1, completion_upload_2, completion_upload_3] + self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid = ( + MagicMock(return_value=sample_results) + ) + + result = self.files_service.get_all_completion_datasets_by_model_uuid( + model_uuid + ) + + assert len(result) == 3 + assert 
result[0].model_uuid == model_uuid + assert result[1].model_uuid == model_uuid + assert result[2].model_uuid == model_uuid + model_uuid = db_mock.MODEL_UUID From 14ef8351ba2b79c5be8bccbb2eae04353168da82 Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Thu, 12 Dec 2024 16:06:41 +0100 Subject: [PATCH 6/7] feat: handle empty completion job --- spark/jobs/completion_job.py | 88 ++++++++++++++++++++++++++++++++++++ spark/jobs/current_job.py | 19 ++++---- spark/jobs/reference_job.py | 25 ++++++---- spark/jobs/utils/db.py | 13 ++++-- 4 files changed, 123 insertions(+), 22 deletions(-) create mode 100644 spark/jobs/completion_job.py diff --git a/spark/jobs/completion_job.py b/spark/jobs/completion_job.py new file mode 100644 index 00000000..d601c223 --- /dev/null +++ b/spark/jobs/completion_job.py @@ -0,0 +1,88 @@ +import sys +import os +import uuid + +from pyspark.sql.types import StructField, StructType, StringType +from utils.models import JobStatus +from utils.db import update_job_status, write_to_db + +from pyspark.sql import SparkSession, DataFrame + +import logging + + +def compute_metrics(df: DataFrame) -> dict: + complete_record = {} + # TODO: compute model quality metrics + return complete_record + + +def main( + spark_session: SparkSession, + completion_dataset_path: str, + completion_uuid: str, + metrics_table_name: str, + dataset_table_name: str, +): + spark_context = spark_session.sparkContext + + spark_context._jsc.hadoopConfiguration().set( + "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID") + ) + spark_context._jsc.hadoopConfiguration().set( + "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY") + ) + spark_context._jsc.hadoopConfiguration().set( + "fs.s3a.endpoint.region", os.getenv("AWS_REGION") + ) + if os.getenv("S3_ENDPOINT_URL"): + spark_context._jsc.hadoopConfiguration().set( + "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL") + ) + spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true") + spark_context._jsc.hadoopConfiguration().set( + "fs.s3a.connection.ssl.enabled", "false" + ) + df = spark_session.read.option("multiline", "true").json(completion_dataset_path) + complete_record = compute_metrics(df) + + complete_record.update( + {"UUID": str(uuid.uuid4()), "COMPLETION_UUID": completion_uuid} + ) + + schema = StructType( + [ + StructField("UUID", StringType(), True), + StructField("COMPLETION_UUID", StringType(), True), + StructField("MODEL_QUALITY", StringType(), True), + ] + ) + + write_to_db(spark_session, complete_record, schema, metrics_table_name) + update_job_status(completion_uuid, JobStatus.SUCCEEDED, dataset_table_name) + + +if __name__ == "__main__": + spark_session = SparkSession.builder.appName( + "radicalbit_completion_metrics" + ).getOrCreate() + + completion_dataset_path = sys.argv[1] + completion_uuid = sys.argv[2] + metrics_table_name = sys.argv[3] + dataset_table_name = sys.argv[4] + + try: + main( + spark_session, + completion_dataset_path, + completion_uuid, + metrics_table_name, + dataset_table_name, + ) + + except Exception as e: + logging.exception(e) + update_job_status(completion_uuid, JobStatus.ERROR, dataset_table_name) + finally: + spark_session.stop() diff --git a/spark/jobs/current_job.py b/spark/jobs/current_job.py index 02b37d73..ff0140da 100644 --- a/spark/jobs/current_job.py +++ b/spark/jobs/current_job.py @@ -124,7 +124,8 @@ def main( current_dataset_path: str, current_uuid: str, reference_dataset_path: str, - table_name: str, + metrics_table_name: str, + dataset_table_name: str, ): spark_context = 
spark_session.sparkContext
 
@@ -171,9 +172,8 @@ def main(
         ]
     )
 
-    write_to_db(spark_session, complete_record, schema, table_name)
-    # FIXME table name should come from parameters
-    update_job_status(current_uuid, JobStatus.SUCCEEDED, "current_dataset")
+    write_to_db(spark_session, complete_record, schema, metrics_table_name)
+    update_job_status(current_uuid, JobStatus.SUCCEEDED, dataset_table_name)
 
 
 if __name__ == "__main__":
@@ -189,8 +189,10 @@ def main(
     current_uuid = sys.argv[3]
     # Reference dataset s3 path is fourth param
     reference_dataset_path = sys.argv[4]
-    # Table name fifth param
-    table_name = sys.argv[5]
+    # Metrics table name fifth param
+    metrics_table_name = sys.argv[5]
+    # Dataset table name sixth param
+    dataset_table_name = sys.argv[6]
 
     try:
         main(
@@ -199,11 +201,11 @@ def main(
             current_dataset_path,
             current_uuid,
             reference_dataset_path,
-            table_name,
+            metrics_table_name,
+            dataset_table_name,
         )
     except Exception as e:
         logging.exception(e)
-        # FIXME table name should come from parameters
-        update_job_status(current_uuid, JobStatus.ERROR, "current_dataset")
+        update_job_status(current_uuid, JobStatus.ERROR, dataset_table_name)
     finally:
         spark_session.stop()
diff --git a/spark/jobs/reference_job.py b/spark/jobs/reference_job.py
index 59a5993a..4750ab3b 100644
--- a/spark/jobs/reference_job.py
+++ b/spark/jobs/reference_job.py
@@ -77,7 +77,8 @@ def main(
     model: ModelOut,
     reference_dataset_path: str,
     reference_uuid: str,
-    table_name: str,
+    metrics_table_name: str,
+    dataset_table_name: str,
 ):
     spark_context = spark_session.sparkContext
 
@@ -118,9 +119,8 @@ def main(
         ]
     )
 
-    write_to_db(spark_session, complete_record, schema, table_name)
-    # FIXME table name should come from parameters
-    update_job_status(reference_uuid, JobStatus.SUCCEEDED, "reference_dataset")
+    write_to_db(spark_session, complete_record, schema, metrics_table_name)
+    update_job_status(reference_uuid, JobStatus.SUCCEEDED, dataset_table_name)
 
 
 if __name__ == "__main__":
@@ -134,14 +134,22 @@ def main(
     reference_dataset_path = sys.argv[2]
     # Reference file uuid third param
     reference_uuid = sys.argv[3]
-    # Table name fourth param
-    table_name = sys.argv[4]
+    # Metrics table name fourth param
+    metrics_table_name = sys.argv[4]
+    # Dataset table name fifth param
+    dataset_table_name = sys.argv[5]
 
     try:
-        main(spark_session, model, reference_dataset_path, reference_uuid, table_name)
+        main(
+            spark_session,
+            model,
+            reference_dataset_path,
+            reference_uuid,
+            metrics_table_name,
+            dataset_table_name,
+        )
     except Exception as e:
         logging.exception(e)
-        # FIXME table name should come from parameters
-        update_job_status(reference_uuid, JobStatus.ERROR, "reference_dataset")
+        update_job_status(reference_uuid, JobStatus.ERROR, dataset_table_name)
     finally:
         spark_session.stop()
diff --git a/spark/jobs/utils/db.py b/spark/jobs/utils/db.py
index 495cd509..ae146cf9 100644
--- a/spark/jobs/utils/db.py
+++ b/spark/jobs/utils/db.py
@@ -17,7 +17,7 @@
 url = f"jdbc:postgresql://{db_host}:{db_port}/{db_name}"
 
 
-def update_job_status(file_uuid: str, status: str, table_name: str):
+def update_job_status(file_uuid: str, status: str, dataset_table_name: str):
     # Use psycopg2 to update the job status
     with psycopg2.connect(
         host=db_host,
@@ -30,7 +30,7 @@ def update_job_status(file_uuid: str, status: str, table_name: str):
         with conn.cursor() as cur:
             cur.execute(
                 f"""
-                UPDATE {table_name}
+                UPDATE {dataset_table_name}
                 SET "STATUS" = %s
                 WHERE "UUID" = %s
                 """,
@@ -40,7 +40,10 @@ def update_job_status(file_uuid: str, status: str, table_name: str):
 
 
 def write_to_db(
-    spark_session: SparkSession, record: Dict, schema: StructType, table_name: str
+    spark_session: SparkSession,
+    record: Dict,
+    schema: StructType,
+    metrics_table_name: str,
 ):
     out_df = spark_session.createDataFrame(data=[record], schema=schema)
 
@@ -49,4 +52,6 @@ def write_to_db(
         "stringtype", "unspecified"
     ).option("driver", "org.postgresql.Driver").option("user", user).option(
         "password", password
-    ).option("dbtable", f'"{postgres_schema}"."{table_name}"').mode("append").save()
+    ).option("dbtable", f'"{postgres_schema}"."{metrics_table_name}"').mode(
+        "append"
+    ).save()

From 75deb9d6340730b43a7b689b20fec1428025b9d0 Mon Sep 17 00:00:00 2001
From: Daniele Tria
Date: Thu, 12 Dec 2024 16:07:25 +0100
Subject: [PATCH 7/7] feat: refactoring spark job arguments

---
 api/app/services/file_service.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py
index 3558baae..e08756e4 100644
--- a/api/app/services/file_service.py
+++ b/api/app/services/file_service.py
@@ -126,6 +126,7 @@ def upload_reference_file(
                 path.replace('s3://', 's3a://'),
                 str(inserted_file.uuid),
                 ReferenceDatasetMetrics.__tablename__,
+                ReferenceDataset.__tablename__,
             ],
         )
 
@@ -175,6 +176,7 @@ def bind_reference_file(
                 file_ref.file_url.replace('s3://', 's3a://'),
                 str(inserted_file.uuid),
                 ReferenceDatasetMetrics.__tablename__,
+                ReferenceDataset.__tablename__,
             ],
         )
 
@@ -259,6 +261,7 @@ def upload_current_file(
                 str(inserted_file.uuid),
                 reference_dataset.path.replace('s3://', 's3a://'),
                 CurrentDatasetMetrics.__tablename__,
+                CurrentDataset.__tablename__,
             ],
         )
 
@@ -312,6 +315,7 @@ def bind_current_file(
                 str(inserted_file.uuid),
                 reference_dataset.path.replace('s3://', 's3a://'),
                 CurrentDatasetMetrics.__tablename__,
+                CurrentDataset.__tablename__,
             ],
         )
 
@@ -377,10 +381,10 @@ def upload_completion_file(
             app_name=str(model_out.uuid),
             app_path=spark_config.spark_completion_app_path,
             app_arguments=[
-                model_out.model_dump_json(),
                 path.replace('s3://', 's3a://'),
                 str(inserted_file.uuid),
                 CompletionDatasetMetrics.__tablename__,
+                CompletionDataset.__tablename__,
            ],
        )
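
Note on the compute_metrics placeholder in spark/jobs/completion_job.py: the sketch below is not part of this patch series, but illustrates one way the MODEL_QUALITY string column could be populated while still tolerating an empty completion dataset, which is what this series starts to handle. The JSON payload layout and the "choices" field name are assumptions for illustration only, not the project's actual completion schema.

    import json

    from pyspark.sql import DataFrame


    def compute_metrics(df: DataFrame) -> dict:
        complete_record = {}
        row_count = df.count()
        if row_count == 0:
            # Empty completion file: persist an explicitly empty metrics blob
            complete_record["MODEL_QUALITY"] = json.dumps({"row_count": 0})
            return complete_record
        # Hypothetical metric: count records exposing a non-null "choices" field
        with_choices = (
            df.filter(df["choices"].isNotNull()).count()
            if "choices" in df.columns
            else 0
        )
        complete_record["MODEL_QUALITY"] = json.dumps(
            {"row_count": row_count, "records_with_choices": with_choices}
        )
        return complete_record

Because the metrics row is written through write_to_db with a nullable StringType schema, returning an empty dict for the empty case, as the job does today, also remains valid.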