From 5a82c900451c4a38d73635bbf7a080a3ef3e7121 Mon Sep 17 00:00:00 2001
From: Daniele Tria <36860433+dtria91@users.noreply.github.com>
Date: Thu, 12 Dec 2024 15:17:10 +0100
Subject: [PATCH] feat: handle JSON files containing responses from a generative model (#208)

* feat: add text-generation as a new model type, handle the new model type, set schema fields as optional, edit tests
* fix: ruff check
* feat: add optional schema fields to models definition (sdk)
* feat: set optional fields in model schema (spark side)
* handle json response
* feat: edit json handling
---
 api/alembic/env.py                             |   2 +
 ...a4cc_add_dataset_and_metrics_completion.py  |  46 +++++++
 api/app/core/config/config.py                  |   1 +
 api/app/db/dao/completion_dataset_dao.py       |  87 ++++++++++++
 .../completion_dataset_metrics_table.py        |  26 ++++
 api/app/db/tables/completion_dataset_table.py  |  25 ++++
 api/app/main.py                                |   3 +
 api/app/models/completion_response.py          |  78 +++++++++++
 api/app/models/dataset_dto.py                  |  23 ++++
 api/app/routes/upload_dataset_route.py         |  38 +++++
 api/app/services/file_service.py               | 119 ++++++++++++++++
 api/tests/commons/db_mock.py                   |  17 +++
 api/tests/commons/json_file_mock.py            | 106 ++++++++++++++
 api/tests/dao/completion_dataset_dao_test.py   | 126 +++++++++++++++++
 api/tests/routes/upload_dataset_route_test.py  |  85 +++++++++++-
 api/tests/services/file_service_test.py        | 130 +++++++++++++++++-
 16 files changed, 908 insertions(+), 4 deletions(-)
 create mode 100644 api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py
 create mode 100644 api/app/db/dao/completion_dataset_dao.py
 create mode 100644 api/app/db/tables/completion_dataset_metrics_table.py
 create mode 100644 api/app/db/tables/completion_dataset_table.py
 create mode 100644 api/app/models/completion_response.py
 create mode 100644 api/tests/commons/json_file_mock.py
 create mode 100644 api/tests/dao/completion_dataset_dao_test.py

diff --git a/api/alembic/env.py b/api/alembic/env.py
index 67be389e..3a031fc7 100644
--- a/api/alembic/env.py
+++ b/api/alembic/env.py
@@ -11,6 +11,8 @@
 from app.db.tables.reference_dataset_metrics_table import *
 from app.db.tables.current_dataset_table import *
 from app.db.tables.current_dataset_metrics_table import *
+from app.db.tables.completion_dataset_table import *
+from app.db.tables.completion_dataset_metrics_table import *
 from app.db.tables.commons.json_encoded_dict import JSONEncodedDict
 from app.db.database import Database, BaseTable
diff --git a/api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py b/api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py
new file mode 100644
index 00000000..4fc7c8d5
--- /dev/null
+++ b/api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py
@@ -0,0 +1,46 @@
+"""add_dataset_and_metrics_completion
+
+Revision ID: e72dc7aaa4cc
+Revises: dccb82489f4d
+Create Date: 2024-12-11 13:33:38.759485
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from app.db.tables.commons.json_encoded_dict import JSONEncodedDict
+
+# revision identifiers, used by Alembic.
+revision: str = 'e72dc7aaa4cc'
+down_revision: Union[str, None] = 'dccb82489f4d'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('completion_dataset',
+    sa.Column('UUID', sa.UUID(), nullable=False),
+    sa.Column('MODEL_UUID', sa.UUID(), nullable=False),
+    sa.Column('PATH', sa.VARCHAR(), nullable=False),
+    sa.Column('DATE', sa.TIMESTAMP(timezone=True), nullable=False),
+    sa.Column('STATUS', sa.VARCHAR(), nullable=False),
+    sa.ForeignKeyConstraint(['MODEL_UUID'], ['model.UUID'], name=op.f('fk_completion_dataset_MODEL_UUID_model')),
+    sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset'))
+    )
+    op.create_table('completion_dataset_metrics',
+    sa.Column('UUID', sa.UUID(), nullable=False),
+    sa.Column('COMPLETION_UUID', sa.UUID(), nullable=False),
+    sa.Column('MODEL_QUALITY', JSONEncodedDict(astext_type=sa.Text()), nullable=True),
+    sa.ForeignKeyConstraint(['COMPLETION_UUID'], ['completion_dataset.UUID'], name=op.f('fk_completion_dataset_metrics_COMPLETION_UUID_completion_dataset')),
+    sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset_metrics'))
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('completion_dataset_metrics')
+    op.drop_table('completion_dataset')
+    # ### end Alembic commands ###
diff --git a/api/app/core/config/config.py b/api/app/core/config/config.py
index 08236d71..e5cf2318 100644
--- a/api/app/core/config/config.py
+++ b/api/app/core/config/config.py
@@ -56,6 +56,7 @@ class SparkConfig(BaseSettings):
     spark_image_pull_policy: str = 'IfNotPresent'
     spark_reference_app_path: str = 'local:///opt/spark/custom_jobs/reference_job.py'
     spark_current_app_path: str = 'local:///opt/spark/custom_jobs/current_job.py'
+    spark_completion_app_path: str = 'local:///opt/spark/custom_jobs/completion_job.py'
     spark_namespace: str = 'spark'
     spark_service_account: str = 'spark'
diff --git a/api/app/db/dao/completion_dataset_dao.py b/api/app/db/dao/completion_dataset_dao.py
new file mode 100644
index 00000000..75d4e90b
--- /dev/null
+++ b/api/app/db/dao/completion_dataset_dao.py
@@ -0,0 +1,87 @@
+import re
+from typing import List, Optional
+from uuid import UUID
+
+from fastapi_pagination import Page, Params
+from fastapi_pagination.ext.sqlalchemy import paginate
+from sqlalchemy import asc, desc
+from sqlalchemy.future import select as future_select
+
+from app.db.database import Database
+from app.db.tables.completion_dataset_table import CompletionDataset
+from app.models.dataset_dto import OrderType
+
+
+class CompletionDatasetDAO:
+    def __init__(self, database: Database) -> None:
+        self.db = database
+
+    def insert_completion_dataset(
+        self, completion_dataset: CompletionDataset
+    ) -> CompletionDataset:
+        with self.db.begin_session() as session:
+            session.add(completion_dataset)
+            session.flush()
+            return completion_dataset
+
+    def get_completion_dataset_by_model_uuid(
+        self, model_uuid: UUID, completion_uuid: UUID
+    ) -> Optional[CompletionDataset]:
+        with self.db.begin_session() as session:
+            return (
+                session.query(CompletionDataset)
+                .where(
+                    CompletionDataset.model_uuid == model_uuid,
+                    CompletionDataset.uuid == completion_uuid,
+                )
+                .one_or_none()
+            )
+
+    def get_latest_completion_dataset_by_model_uuid(
+        self, model_uuid: UUID
+    ) -> Optional[CompletionDataset]:
+        with self.db.begin_session() as session:
+            return (
+                session.query(CompletionDataset)
+                .order_by(desc(CompletionDataset.date))
+                .where(CompletionDataset.model_uuid == model_uuid)
+                .limit(1)
+                .one_or_none()
+            )
+
+    def get_all_completion_datasets_by_model_uuid(
+        self,
+        model_uuid: UUID,
+    ) -> 
List[CompletionDataset]:
+        with self.db.begin_session() as session:
+            return (
+                session.query(CompletionDataset)
+                .order_by(desc(CompletionDataset.date))
+                .where(CompletionDataset.model_uuid == model_uuid)
+                .all()
+            )
+
+    def get_all_completion_datasets_by_model_uuid_paginated(
+        self,
+        model_uuid: UUID,
+        params: Params = Params(),
+        order: OrderType = OrderType.ASC,
+        sort: Optional[str] = None,
+    ) -> Page[CompletionDataset]:
+        def order_by_column_name(column_name):
+            return getattr(
+                CompletionDataset, re.sub('(?=[A-Z])', '_', column_name).lower()
+            )
+
+        with self.db.begin_session() as session:
+            stmt = future_select(CompletionDataset).where(
+                CompletionDataset.model_uuid == model_uuid
+            )
+
+            if sort:
+                stmt = (
+                    stmt.order_by(asc(order_by_column_name(sort)))
+                    if order == OrderType.ASC
+                    else stmt.order_by(desc(order_by_column_name(sort)))
+                )
+
+            return paginate(session, stmt, params)
diff --git a/api/app/db/tables/completion_dataset_metrics_table.py b/api/app/db/tables/completion_dataset_metrics_table.py
new file mode 100644
index 00000000..c9f8788e
--- /dev/null
+++ b/api/app/db/tables/completion_dataset_metrics_table.py
@@ -0,0 +1,26 @@
+from uuid import uuid4
+
+from sqlalchemy import UUID, Column, ForeignKey
+
+from app.db.dao.base_dao import BaseDAO
+from app.db.database import BaseTable, Reflected
+from app.db.tables.commons.json_encoded_dict import JSONEncodedDict
+
+
+class CompletionDatasetMetrics(Reflected, BaseTable, BaseDAO):
+    __tablename__ = 'completion_dataset_metrics'
+
+    uuid = Column(
+        'UUID',
+        UUID(as_uuid=True),
+        nullable=False,
+        default=uuid4,
+        primary_key=True,
+    )
+    completion_uuid = Column(
+        'COMPLETION_UUID',
+        UUID(as_uuid=True),
+        ForeignKey('completion_dataset.UUID'),
+        nullable=False,
+    )
+    model_quality = Column('MODEL_QUALITY', JSONEncodedDict, nullable=True)
diff --git a/api/app/db/tables/completion_dataset_table.py b/api/app/db/tables/completion_dataset_table.py
new file mode 100644
index 00000000..ddea2e7b
--- /dev/null
+++ b/api/app/db/tables/completion_dataset_table.py
@@ -0,0 +1,25 @@
+from uuid import uuid4
+
+from sqlalchemy import TIMESTAMP, UUID, VARCHAR, Column, ForeignKey
+
+from app.db.dao.base_dao import BaseDAO
+from app.db.database import BaseTable, Reflected
+from app.models.job_status import JobStatus
+
+
+class CompletionDataset(Reflected, BaseTable, BaseDAO):
+    __tablename__ = 'completion_dataset'
+
+    uuid = Column(
+        'UUID',
+        UUID(as_uuid=True),
+        nullable=False,
+        default=uuid4,
+        primary_key=True,
+    )
+    model_uuid = Column(
+        'MODEL_UUID', UUID(as_uuid=True), ForeignKey('model.UUID'), nullable=False
+    )
+    path = Column('PATH', VARCHAR, nullable=False)
+    date = Column('DATE', TIMESTAMP(timezone=True), nullable=False)
+    status = Column('STATUS', VARCHAR, nullable=False, default=JobStatus.IMPORTING)
diff --git a/api/app/main.py b/api/app/main.py
index b40f4914..0d86e44d 100644
--- a/api/app/main.py
+++ b/api/app/main.py
@@ -10,6 +10,7 @@
 from starlette.middleware.cors import CORSMiddleware
 
 from app.core import get_config
+from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
 from app.db.dao.current_dataset_dao import CurrentDatasetDAO
 from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
 from app.db.dao.model_dao import ModelDAO
@@ -54,6 +55,7 @@
 reference_dataset_metrics_dao = ReferenceDatasetMetricsDAO(database)
 current_dataset_dao = CurrentDatasetDAO(database)
 current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)
+completion_dataset_dao = 
CompletionDatasetDAO(database) model_service = ModelService( model_dao=model_dao, @@ -81,6 +83,7 @@ file_service = FileService( reference_dataset_dao, current_dataset_dao, + completion_dataset_dao, model_service, s3_client, spark_k8s_client, diff --git a/api/app/models/completion_response.py b/api/app/models/completion_response.py new file mode 100644 index 00000000..6886cf3b --- /dev/null +++ b/api/app/models/completion_response.py @@ -0,0 +1,78 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, RootModel, model_validator + + +class TokenLogProbs(BaseModel): + token: str + bytes: List[int] + logprob: float + top_logprobs: List[Dict[str, float]] + + +class LogProbs(BaseModel): + content: List[TokenLogProbs] + refusal: Optional[str] = None + + +class Message(BaseModel): + content: str + refusal: Optional[str] = None + role: str + tool_calls: List = [] + parsed: Optional[dict] = None + + +class Choice(BaseModel): + finish_reason: str + index: int + logprobs: Optional[LogProbs] = None + message: Message + + @model_validator(mode='after') + def validate_logprobs(self): + if self.logprobs is None: + raise ValueError( + "the 'logprobs' field cannot be empty, metrics could not be computed." + ) + return self + + +class UsageDetails(BaseModel): + accepted_prediction_tokens: int = 0 + reasoning_tokens: int = 0 + rejected_prediction_tokens: int = 0 + audio_tokens: Optional[int] = None + cached_tokens: Optional[int] = None + + +class Usage(BaseModel): + completion_tokens: int + prompt_tokens: int + total_tokens: int + completion_tokens_details: UsageDetails + prompt_tokens_details: Optional[UsageDetails] = None + + +class Completion(BaseModel): + id: str + choices: List[Choice] + created: int + model: str + object: str + system_fingerprint: str + usage: Usage + + +class CompletionResponses(RootModel[List[Completion]]): + @model_validator(mode='before') + @classmethod + def handle_single_completion(cls, data): + """If a single object is passed instead of a list, wrap it into a list.""" + if isinstance(data, dict): + return [data] + if isinstance(data, list): + return data + raise ValueError( + 'Input file must be a list of completion json or a single completion json' + ) diff --git a/api/app/models/dataset_dto.py b/api/app/models/dataset_dto.py index 8f6feb4e..77c3094f 100644 --- a/api/app/models/dataset_dto.py +++ b/api/app/models/dataset_dto.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_camel +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_table import ReferenceDataset @@ -55,6 +56,28 @@ def from_current_dataset(cd: CurrentDataset) -> 'CurrentDatasetDTO': ) +class CompletionDatasetDTO(BaseModel): + uuid: UUID + model_uuid: UUID + path: str + date: str + status: str + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + @staticmethod + def from_completion_dataset(cd: CompletionDataset) -> 'CompletionDatasetDTO': + return CompletionDatasetDTO( + uuid=cd.uuid, + model_uuid=cd.model_uuid, + path=cd.path, + date=cd.date.isoformat(), + status=cd.status, + ) + + class FileReference(BaseModel): file_url: str separator: str = ',' diff --git a/api/app/routes/upload_dataset_route.py b/api/app/routes/upload_dataset_route.py index bd906f48..d7ef7998 100644 --- a/api/app/routes/upload_dataset_route.py +++ b/api/app/routes/upload_dataset_route.py @@ 
-6,6 +6,7 @@ from fastapi_pagination import Page, Params from app.models.dataset_dto import ( + CompletionDatasetDTO, CurrentDatasetDTO, FileReference, OrderType, @@ -64,6 +65,16 @@ def bind_current_file( ) -> CurrentDatasetDTO: return file_service.bind_current_file(model_uuid, file_ref) + @router.post( + '/{model_uuid}/completion/upload', + status_code=status.HTTP_200_OK, + response_model=CompletionDatasetDTO, + ) + def upload_completion_file( + model_uuid: UUID, json_file: UploadFile = File(...) + ) -> CompletionDatasetDTO: + return file_service.upload_completion_file(model_uuid, json_file) + @router.get( '/{model_uuid}/reference', status_code=200, @@ -118,4 +129,31 @@ def get_all_current_datasets_by_model_uuid( ): return file_service.get_all_current_datasets_by_model_uuid(model_uuid) + @router.get( + '/{model_uuid}/completion', + status_code=200, + response_model=Page[CompletionDatasetDTO], + ) + def get_all_completion_datasets_by_model_uuid_paginated( + model_uuid: UUID, + _page: Annotated[int, Query()] = 1, + _limit: Annotated[int, Query()] = 50, + _order: Annotated[OrderType, Query()] = OrderType.ASC, + _sort: Annotated[Optional[str], Query()] = None, + ): + params = Params(page=_page, size=_limit) + return file_service.get_all_completion_datasets_by_model_uuid_paginated( + model_uuid, params=params, order=_order, sort=_sort + ) + + @router.get( + '/{model_uuid}/completion/all', + status_code=200, + response_model=List[CompletionDatasetDTO], + ) + def get_all_completion_datasets_by_model_uuid( + model_uuid: UUID, + ): + return file_service.get_all_completion_datasets_by_model_uuid(model_uuid) + return router diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py index e2c78707..3558baae 100644 --- a/api/app/services/file_service.py +++ b/api/app/services/file_service.py @@ -1,5 +1,6 @@ from copy import deepcopy import datetime +import json import logging import pathlib from typing import List, Optional @@ -10,17 +11,23 @@ from fastapi import HTTPException, UploadFile from fastapi_pagination import Page, Params import pandas as pd +from pydantic import ValidationError from spark_on_k8s.client import ExecutorInstances, PodResources, SparkOnK8S from spark_on_k8s.utils.configuration import Configuration from app.core.config.config import create_secrets, get_config +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO +from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics from app.db.tables.reference_dataset_table import ReferenceDataset +from app.models.completion_response import CompletionResponses from app.models.dataset_dto import ( + CompletionDatasetDTO, CurrentDatasetDTO, FileReference, OrderType, @@ -48,12 +55,14 @@ def __init__( self, reference_dataset_dao: ReferenceDatasetDAO, current_dataset_dao: CurrentDatasetDAO, + completion_dataset_dao: CompletionDatasetDAO, model_service: ModelService, s3_client: boto3.client, spark_k8s_client: SparkOnK8S, ) -> 'FileService': self.rd_dao = reference_dataset_dao self.cd_dao = current_dataset_dao + self.completion_dataset_dao = completion_dataset_dao 
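+        # completion datasets (stored LLM responses) get their own DAO,
+        # wired in app.main alongside the reference/current dataset DAOs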
         self.model_svc = model_service
         self.s3_client = s3_client
         s3_config = get_config().s3_config
@@ -321,6 +330,69 @@ def bind_current_file(
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e)) from e
 
+    def upload_completion_file(
+        self,
+        model_uuid: UUID,
+        json_file: UploadFile,
+    ) -> CompletionDatasetDTO:
+        model_out = self.model_svc.get_model_by_uuid(model_uuid)
+        if not model_out:
+            logger.error('Model %s not found', model_uuid)
+            raise ModelNotFoundError(f'Model {model_uuid} not found')
+
+        self.validate_json_file(json_file)
+        _f_name = json_file.filename
+        _f_uuid = uuid4()
+        try:
+            object_name = f'{str(model_out.uuid)}/completion/{_f_uuid}/{_f_name}'
+            self.s3_client.upload_fileobj(
+                json_file.file,
+                self.bucket_name,
+                object_name,
+                ExtraArgs={
+                    'Metadata': {
+                        'model_uuid': str(model_out.uuid),
+                        'model_name': model_out.name,
+                        'file_type': 'completion',
+                    }
+                },
+            )
+
+            path = f's3://{self.bucket_name}/{object_name}'
+
+            inserted_file = self.completion_dataset_dao.insert_completion_dataset(
+                CompletionDataset(
+                    uuid=_f_uuid,
+                    model_uuid=model_uuid,
+                    path=path,
+                    date=datetime.datetime.now(tz=datetime.UTC),
+                    status=JobStatus.IMPORTING,
+                )
+            )
+
+            logger.debug('File %s has been correctly stored in the db', inserted_file)
+
+            spark_config = get_config().spark_config
+            self.__submit_app(
+                app_name=str(model_out.uuid),
+                app_path=spark_config.spark_completion_app_path,
+                app_arguments=[
+                    model_out.model_dump_json(),
+                    path.replace('s3://', 's3a://'),
+                    str(inserted_file.uuid),
+                    CompletionDatasetMetrics.__tablename__,
+                ],
+            )
+
+            return CompletionDatasetDTO.from_completion_dataset(inserted_file)
+
+        except NoCredentialsError as nce:
+            raise HTTPException(
+                status_code=500, detail='S3 credentials not available'
+            ) from nce
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e)) from e
+
     def get_all_reference_datasets_by_model_uuid_paginated(
         self,
         model_uuid: UUID,
@@ -382,6 +454,40 @@ def get_all_current_datasets_by_model_uuid(
             for current_dataset in currents
         ]
 
+    def get_all_completion_datasets_by_model_uuid_paginated(
+        self,
+        model_uuid: UUID,
+        params: Params = Params(),
+        order: OrderType = OrderType.ASC,
+        sort: Optional[str] = None,
+    ) -> Page[CompletionDatasetDTO]:
+        results: Page[CompletionDataset] = (
+            self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid_paginated(
+                model_uuid, params=params, order=order, sort=sort
+            )
+        )
+
+        _items = [
+            CompletionDatasetDTO.from_completion_dataset(completion_dataset)
+            for completion_dataset in results.items
+        ]
+
+        return Page.create(items=_items, params=params, total=results.total)
+
+    def get_all_completion_datasets_by_model_uuid(
+        self,
+        model_uuid: UUID,
+    ) -> List[CompletionDatasetDTO]:
+        completions = (
+            self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid(
+                model_uuid
+            )
+        )
+        return [
+            CompletionDatasetDTO.from_completion_dataset(completion_dataset)
+            for completion_dataset in completions
+        ]
+
     @staticmethod
     def infer_schema(csv_file: UploadFile, sep: str = ',') -> InferredSchemaDTO:
         FileService.validate_file(csv_file, sep)
@@ -449,6 +555,19 @@ def validate_file(
         csv_file.file.flush()
         csv_file.file.seek(0)
 
+    @staticmethod
+    def validate_json_file(json_file: UploadFile) -> None:
+        try:
+            content = json_file.file.read().decode('utf-8')
+            json_data = json.loads(content)
+            CompletionResponses.model_validate(json_data)
+            # rewind after reading, so the subsequent S3 upload sends the whole file
+            json_file.file.seek(0)
+        except ValidationError as e:
+            logger.error('Invalid json file: %s', str(e))
+            raise 
InvalidFileException(f'Invalid json file: {str(e)}') from e + except Exception as e: + logger.error('Error while reading the json file: %s', str(e)) + raise InvalidFileException(f'Invalid json file: {str(e)}') from e + def __submit_app( self, app_name: str, app_path: str, app_arguments: List[str] ) -> None: diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index ba675605..13f0bd0a 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional import uuid +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.model_table import Model @@ -23,6 +24,7 @@ MODEL_UUID = uuid.uuid4() REFERENCE_UUID = uuid.uuid4() CURRENT_UUID = uuid.uuid4() +COMPLETION_UUID = uuid.uuid4() def get_sample_model( @@ -172,6 +174,21 @@ def get_sample_current_dataset( ) +def get_sample_completion_dataset( + uuid: uuid.UUID = COMPLETION_UUID, + model_uuid: uuid.UUID = MODEL_UUID, + path: str = 'completion/json_file.json', + status: str = JobStatus.IMPORTING.value, +) -> CompletionDataset: + return CompletionDataset( + uuid=uuid, + model_uuid=model_uuid, + path=path, + date=datetime.datetime.now(tz=datetime.UTC), + status=status, + ) + + statistics_dict = { 'nVariables': 10, 'nObservations': 1000, diff --git a/api/tests/commons/json_file_mock.py b/api/tests/commons/json_file_mock.py new file mode 100644 index 00000000..c42ace1f --- /dev/null +++ b/api/tests/commons/json_file_mock.py @@ -0,0 +1,106 @@ +from io import BytesIO + +from fastapi import UploadFile + + +def get_completion_sample_json_file() -> UploadFile: + json_content = """ + { + "id": "chatcmpl-0120", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "token": "Sky", + "bytes": [83, 107, 121], + "logprob": -1.2830728, + "top_logprobs": [] + } + ], + "refusal": null + }, + "message": { + "content": "Sky is blue.", + "refusal": null, + "role": "assistant", + "tool_calls": [], + "parsed": null + } + } + ], + "created": 1733486708, + "model": "gpt-4o-2024-08-06", + "object": "chat.completion", + "system_fingerprint": "fp_c7ca0ebaca", + "usage": { + "completion_tokens": 4, + "prompt_tokens": 25, + "total_tokens": 29, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + } + } + }""" + return UploadFile(filename='test.json', file=BytesIO(json_content.encode())) + + +def get_completion_sample_json_file_without_logprobs_field() -> UploadFile: + json_content = """ + { + "id": "chatcmpl-0120", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Sky is blue.", + "refusal": null, + "role": "assistant", + "tool_calls": [], + "parsed": null + } + } + ], + "created": 1733486708, + "model": "gpt-4o-2024-08-06", + "object": "chat.completion", + "system_fingerprint": "fp_c7ca0ebaca", + "usage": { + "completion_tokens": 4, + "prompt_tokens": 25, + "total_tokens": 29, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + } + } + }""" + return UploadFile(filename='test.json', 
file=BytesIO(json_content.encode()))


def get_incorrect_sample_json_file() -> UploadFile:
    json_content = """
    {
        "id": "chatcmpl-0120",
        "created": 1733486708,
        "model": "gpt-4o-2024-08-06",
        "object": "chat.completion",
        "system_fingerprint": "fp_c7ca0ebaca",
    }"""
    return UploadFile(filename='test.json', file=BytesIO(json_content.encode()))
diff --git a/api/tests/dao/completion_dataset_dao_test.py b/api/tests/dao/completion_dataset_dao_test.py
new file mode 100644
index 00000000..dd40cce8
--- /dev/null
+++ b/api/tests/dao/completion_dataset_dao_test.py
@@ -0,0 +1,126 @@
import datetime
from uuid import uuid4

from fastapi_pagination import Params

from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.tables.completion_dataset_table import CompletionDataset
from tests.commons import db_mock
from tests.commons.db_integration import DatabaseIntegration


class CompletionDatasetDAOTest(DatabaseIntegration):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.completion_dataset_dao = CompletionDatasetDAO(cls.db)
        cls.model_dao = ModelDAO(cls.db)

    def test_insert_completion_dataset_upload_result(self):
        model = self.model_dao.insert(db_mock.get_sample_model())
        to_insert = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
            path='json_file.json',
            date=datetime.datetime.now(tz=datetime.UTC),
        )

        inserted = self.completion_dataset_dao.insert_completion_dataset(to_insert)
        assert inserted == to_insert

    def test_get_completion_dataset_by_model_uuid(self):
        model = self.model_dao.insert(db_mock.get_sample_model())
        to_insert = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
            path='json_file.json',
            date=datetime.datetime.now(tz=datetime.UTC),
        )

        inserted = self.completion_dataset_dao.insert_completion_dataset(to_insert)
        retrieved = self.completion_dataset_dao.get_completion_dataset_by_model_uuid(
            inserted.model_uuid, inserted.uuid
        )
        assert inserted.uuid == retrieved.uuid
        assert inserted.model_uuid == retrieved.model_uuid
        assert inserted.path == retrieved.path

    def test_get_latest_completion_dataset_by_model_uuid(self):
        model = self.model_dao.insert(db_mock.get_sample_model())
        completion_one = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
            path='json_file.json',
            date=datetime.datetime.now(tz=datetime.UTC),
        )

        self.completion_dataset_dao.insert_completion_dataset(completion_one)

        completion_two = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
            path='json_file.json',
            date=datetime.datetime.now(tz=datetime.UTC),
        )

        inserted_two = self.completion_dataset_dao.insert_completion_dataset(
            completion_two
        )

        retrieved = (
            self.completion_dataset_dao.get_latest_completion_dataset_by_model_uuid(
                model.uuid
            )
        )
        assert inserted_two.uuid == retrieved.uuid
        assert inserted_two.model_uuid == retrieved.model_uuid
        assert inserted_two.path == retrieved.path

    def test_get_all_completion_datasets_by_model_uuid_paginated(self):
        model = self.model_dao.insert(db_mock.get_sample_model())
        completion_upload_1 = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
            path='json_file.json',
            date=datetime.datetime.now(tz=datetime.UTC),
        )
        completion_upload_2 = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
            path='json_file.json',
            date=datetime.datetime.now(tz=datetime.UTC),
        )
        completion_upload_3 = CompletionDataset(
            uuid=uuid4(),
            model_uuid=model.uuid,
path='json_file.json', + date=datetime.datetime.now(tz=datetime.UTC), + ) + inserted_1 = self.completion_dataset_dao.insert_completion_dataset( + completion_upload_1 + ) + inserted_2 = self.completion_dataset_dao.insert_completion_dataset( + completion_upload_2 + ) + inserted_3 = self.completion_dataset_dao.insert_completion_dataset( + completion_upload_3 + ) + + retrieved = self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid_paginated( + model.uuid, Params(page=1, size=10) + ) + + assert inserted_1.uuid == retrieved.items[0].uuid + assert inserted_1.model_uuid == retrieved.items[0].model_uuid + assert inserted_1.path == retrieved.items[0].path + + assert inserted_2.uuid == retrieved.items[1].uuid + assert inserted_2.model_uuid == retrieved.items[1].model_uuid + assert inserted_2.path == retrieved.items[1].path + + assert inserted_3.uuid == retrieved.items[2].uuid + assert inserted_3.model_uuid == retrieved.items[2].model_uuid + assert inserted_3.path == retrieved.items[2].path + + assert len(retrieved.items) == 3 diff --git a/api/tests/routes/upload_dataset_route_test.py b/api/tests/routes/upload_dataset_route_test.py index 994f8438..27b21cab 100644 --- a/api/tests/routes/upload_dataset_route_test.py +++ b/api/tests/routes/upload_dataset_route_test.py @@ -9,6 +9,7 @@ from starlette.testclient import TestClient from app.models.dataset_dto import ( + CompletionDatasetDTO, CurrentDatasetDTO, FileReference, OrderType, @@ -17,7 +18,7 @@ from app.models.job_status import JobStatus from app.routes.upload_dataset_route import UploadDatasetRoute from app.services.file_service import FileService -from tests.commons import csv_file_mock as csv, db_mock +from tests.commons import csv_file_mock as csv, db_mock, json_file_mock as json class UploadDatasetRouteTest(unittest.TestCase): @@ -113,6 +114,26 @@ def test_bind_current(self): assert res.status_code == 200 assert jsonable_encoder(upload_file_result) == res.json() + def test_upload_completion(self): + file = json.get_completion_sample_json_file() + model_uuid = uuid.uuid4() + upload_file_result = CompletionDatasetDTO( + uuid=uuid.uuid4(), + model_uuid=model_uuid, + path='test', + date=str(datetime.datetime.now(tz=datetime.UTC)), + status=JobStatus.IMPORTING, + ) + self.file_service.upload_completion_file = MagicMock( + return_value=upload_file_result + ) + res = self.client.post( + f'{self.prefix}/{model_uuid}/completion/upload', + files={'json_file': (file.filename, file.file)}, + ) + assert res.status_code == 200 + assert jsonable_encoder(upload_file_result) == res.json() + def test_get_all_reference_datasets_by_model_uuid_paginated(self): test_model_uuid = uuid.uuid4() reference_upload_1 = db_mock.get_sample_reference_dataset( @@ -181,6 +202,40 @@ def test_get_all_current_datasets_by_model_uuid_paginated(self): sort=None, ) + def test_get_all_completion_datasets_by_model_uuid_paginated(self): + test_model_uuid = uuid.uuid4() + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_3.json' + ) + + sample_results = [ + CompletionDatasetDTO.from_completion_dataset(completion_upload_1), + CompletionDatasetDTO.from_completion_dataset(completion_upload_2), + CompletionDatasetDTO.from_completion_dataset(completion_upload_3), + ] 
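+        # build the fastapi-pagination Page the mocked service returns;
+        # the route should serialize it unchanged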
+ page = Page.create( + items=sample_results, total=len(sample_results), params=Params() + ) + self.file_service.get_all_completion_datasets_by_model_uuid_paginated = ( + MagicMock(return_value=page) + ) + + res = self.client.get(f'{self.prefix}/{test_model_uuid}/completion') + assert res.status_code == 200 + assert jsonable_encoder(page) == res.json() + self.file_service.get_all_completion_datasets_by_model_uuid_paginated.assert_called_once_with( + test_model_uuid, + params=Params(page=1, size=50), + order=OrderType.ASC, + sort=None, + ) + def test_get_all_reference_datasets_by_model_uuid(self): test_model_uuid = uuid.uuid4() reference_upload_1 = db_mock.get_sample_reference_dataset( @@ -236,3 +291,31 @@ def test_get_all_current_datasets_by_model_uuid(self): self.file_service.get_all_current_datasets_by_model_uuid.assert_called_once_with( test_model_uuid, ) + + def test_get_all_completion_datasets_by_model_uuid(self): + test_model_uuid = uuid.uuid4() + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=test_model_uuid, path='completion/test_3.json' + ) + + sample_results = [ + CompletionDatasetDTO.from_completion_dataset(completion_upload_1), + CompletionDatasetDTO.from_completion_dataset(completion_upload_2), + CompletionDatasetDTO.from_completion_dataset(completion_upload_3), + ] + self.file_service.get_all_completion_datasets_by_model_uuid = MagicMock( + return_value=sample_results + ) + + res = self.client.get(f'{self.prefix}/{test_model_uuid}/completion/all') + assert res.status_code == 200 + assert jsonable_encoder(sample_results) == res.json() + self.file_service.get_all_completion_datasets_by_model_uuid.assert_called_once_with( + test_model_uuid, + ) diff --git a/api/tests/services/file_service_test.py b/api/tests/services/file_service_test.py index 96231650..e326041a 100644 --- a/api/tests/services/file_service_test.py +++ b/api/tests/services/file_service_test.py @@ -8,17 +8,24 @@ from fastapi_pagination import Page, Params import pytest +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_table import ReferenceDataset -from app.models.dataset_dto import CurrentDatasetDTO, FileReference, ReferenceDatasetDTO +from app.models.dataset_dto import ( + CompletionDatasetDTO, + CurrentDatasetDTO, + FileReference, + ReferenceDatasetDTO, +) from app.models.exceptions import InvalidFileException, ModelNotFoundError from app.models.job_status import JobStatus from app.models.model_dto import ModelOut from app.services.file_service import FileService from app.services.model_service import ModelService -from tests.commons import csv_file_mock as csv, db_mock +from tests.commons import csv_file_mock as csv, db_mock, json_file_mock as json from tests.commons.db_mock import get_sample_reference_dataset @@ -27,15 +34,22 @@ class FileServiceTest(unittest.TestCase): def setUpClass(cls): cls.rd_dao = MagicMock(spec_set=ReferenceDatasetDAO) cls.cd_dao = MagicMock(spec_set=CurrentDatasetDAO) + cls.completion_dataset_dao = 
MagicMock(spec_set=CompletionDatasetDAO) cls.model_svc = MagicMock(spec_set=ModelService) cls.s3_client = MagicMock() cls.spark_k8s_client = MagicMock() cls.files_service = FileService( - cls.rd_dao, cls.cd_dao, cls.model_svc, cls.s3_client, cls.spark_k8s_client + cls.rd_dao, + cls.cd_dao, + cls.completion_dataset_dao, + cls.model_svc, + cls.s3_client, + cls.spark_k8s_client, ) cls.mocks = [ cls.rd_dao, cls.cd_dao, + cls.completion_dataset_dao, cls.model_svc, cls.s3_client, cls.spark_k8s_client, @@ -67,6 +81,24 @@ def test_infer_schema_separator(self): schema = FileService.infer_schema(file, sep=';') assert schema == csv.correct_schema() + def test_validate_completion_json_file_ok(self): + json_file = json.get_completion_sample_json_file() + self.files_service.validate_json_file(json_file) + + def test_validate_completion_json_file_without_logprobs_field(self): + json_file = json.get_completion_sample_json_file_without_logprobs_field() + with pytest.raises(InvalidFileException) as ex: + self.files_service.validate_json_file(json_file) + assert ( + "the 'logprobs' field cannot be empty, metrics could not be computed." + in str(ex.value) + ) + + def test_validate_completion_json_file_error(self): + json_file = json.get_incorrect_sample_json_file() + with pytest.raises(InvalidFileException): + self.files_service.validate_json_file(json_file) + def test_upload_reference_file_ok(self): file = csv.get_correct_sample_csv_file() dataset_uuid = uuid4() @@ -253,6 +285,44 @@ def test_upload_current_file_reference_file_not_found(self): correlation_id_column, ) + def test_upload_completion_file_ok(self): + file = json.get_completion_sample_json_file() + model = db_mock.get_sample_model( + features=None, + outputs=None, + target=None, + timestamp=None, + ) + object_name = f'{str(model.uuid)}/completion/{file.filename}' + path = f's3://bucket/{object_name}' + inserted_file = CompletionDataset( + uuid=uuid4(), + model_uuid=model_uuid, + path=path, + date=datetime.datetime.now(tz=datetime.UTC), + status=JobStatus.IMPORTING, + ) + + self.model_svc.get_model_by_uuid = MagicMock( + return_value=ModelOut.from_model(model) + ) + self.s3_client.upload_fileobj = MagicMock() + self.completion_dataset_dao.insert_completion_dataset = MagicMock( + return_value=inserted_file + ) + self.spark_k8s_client.submit_app = MagicMock() + + result = self.files_service.upload_completion_file( + model.uuid, + file, + ) + + self.model_svc.get_model_by_uuid.assert_called_once() + self.completion_dataset_dao.insert_completion_dataset.assert_called_once() + self.s3_client.upload_fileobj.assert_called_once() + self.spark_k8s_client.submit_app.assert_called_once() + assert result == CompletionDatasetDTO.from_completion_dataset(inserted_file) + def test_get_all_reference_datasets_by_model_uuid_paginated(self): reference_upload_1 = db_mock.get_sample_reference_dataset( model_uuid=model_uuid, path='reference/test_1.csv' @@ -311,6 +381,35 @@ def test_get_all_current_datasets_by_model_uuid_paginated(self): assert result.items[1].model_uuid == model_uuid assert result.items[2].model_uuid == model_uuid + def test_get_all_completion_datasets_by_model_uuid_paginated(self): + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_3.json' + ) + + 
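        # the DAO is mocked to return a page of entities; the service maps them to DTOs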
sample_results = [completion_upload_1, completion_upload_2, completion_upload_3] + page = Page.create( + sample_results, total=len(sample_results), params=Params(page=1, size=10) + ) + self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid_paginated = MagicMock( + return_value=page + ) + + result = self.files_service.get_all_completion_datasets_by_model_uuid_paginated( + model_uuid, Params(page=1, size=10) + ) + + assert result.total == 3 + assert len(result.items) == 3 + assert result.items[0].model_uuid == model_uuid + assert result.items[1].model_uuid == model_uuid + assert result.items[2].model_uuid == model_uuid + def test_get_all_reference_datasets_by_model_uuid(self): reference_upload_1 = db_mock.get_sample_reference_dataset( model_uuid=model_uuid, path='reference/test_1.csv' @@ -357,5 +456,30 @@ def test_get_all_current_datasets_by_model_uuid(self): assert result[1].model_uuid == model_uuid assert result[2].model_uuid == model_uuid + def test_get_all_completion_datasets_by_model_uuid(self): + completion_upload_1 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_1.json' + ) + completion_upload_2 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_2.json' + ) + completion_upload_3 = db_mock.get_sample_completion_dataset( + model_uuid=model_uuid, path='completion/test_3.json' + ) + + sample_results = [completion_upload_1, completion_upload_2, completion_upload_3] + self.completion_dataset_dao.get_all_completion_datasets_by_model_uuid = ( + MagicMock(return_value=sample_results) + ) + + result = self.files_service.get_all_completion_datasets_by_model_uuid( + model_uuid + ) + + assert len(result) == 3 + assert result[0].model_uuid == model_uuid + assert result[1].model_uuid == model_uuid + assert result[2].model_uuid == model_uuid + model_uuid = db_mock.MODEL_UUID
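
Note on the payload contract introduced above: CompletionResponses accepts either a single completion object or a list of completion objects, and Choice.validate_logprobs rejects any choice whose logprobs is missing (the completion Spark job presumably needs token log-probabilities to compute model-quality metrics). A minimal sketch of that behavior, assuming only that the api package is importable as app, as in the tests above:

from copy import deepcopy

from pydantic import ValidationError

from app.models.completion_response import CompletionResponses

single = {
    'id': 'chatcmpl-0120',
    'choices': [
        {
            'finish_reason': 'stop',
            'index': 0,
            'logprobs': {
                'content': [
                    {
                        'token': 'Sky',
                        'bytes': [83, 107, 121],
                        'logprob': -1.2830728,
                        'top_logprobs': [],
                    }
                ]
            },
            'message': {'content': 'Sky is blue.', 'role': 'assistant'},
        }
    ],
    'created': 1733486708,
    'model': 'gpt-4o-2024-08-06',
    'object': 'chat.completion',
    'system_fingerprint': 'fp_c7ca0ebaca',
    'usage': {
        'completion_tokens': 4,
        'prompt_tokens': 25,
        'total_tokens': 29,
        'completion_tokens_details': {},  # every UsageDetails field has a default
    },
}

# a bare completion object is wrapped into a one-element list by the
# handle_single_completion root validator
assert len(CompletionResponses.model_validate(single).root) == 1

# removing logprobs trips the validate_logprobs check on Choice
broken = deepcopy(single)
del broken['choices'][0]['logprobs']
try:
    CompletionResponses.model_validate(broken)
except ValidationError as err:
    print(err)  # wraps the custom "the 'logprobs' field cannot be empty" message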
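
End to end, the new route accepts a multipart upload whose json_file part carries the completion payload; FileService validates it, stores it at s3://<bucket>/<model_uuid>/completion/<file_uuid>/<file_name>, inserts a completion_dataset row with status IMPORTING, and submits the completion Spark job. A hypothetical client call follows; the host, port, and router prefix (/api/models here) are assumptions, since the router mounting is outside this diff:

import requests

model_uuid = '3fa85f64-5717-4562-b3fc-2c963f66afa6'  # placeholder UUID
with open('completions.json', 'rb') as fh:
    res = requests.post(
        f'http://localhost:9000/api/models/{model_uuid}/completion/upload',
        files={'json_file': ('completions.json', fh, 'application/json')},
    )
res.raise_for_status()
print(res.json())  # CompletionDatasetDTO fields: uuid, modelUuid, path, date, status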