From f3b750e0fa89930f5d58f6b513aedadb6565baa1 Mon Sep 17 00:00:00 2001 From: Daniele Tria <36860433+dtria91@users.noreply.github.com> Date: Mon, 16 Dec 2024 17:33:29 +0100 Subject: [PATCH] feat: define text generation metrics api (#213) * add completion model quality endpoint * add get completion model quality endpoint * fix test --- .../db/dao/completion_dataset_metrics_dao.py | 36 ++++++++++ api/app/main.py | 4 ++ api/app/models/dataset_type.py | 1 + api/app/models/metrics/model_quality_dto.py | 48 +++++++++++++ api/app/routes/metrics_route.py | 12 ++++ api/app/services/metrics_service.py | 47 ++++++++++-- api/tests/commons/db_mock.py | 35 +++++++++ .../completion_dataset_metrics_dao_test.py | 56 +++++++++++++++ api/tests/routes/metrics_route_test.py | 23 ++++++ api/tests/services/metrics_service_test.py | 71 +++++++++++++++++++ 10 files changed, 329 insertions(+), 4 deletions(-) create mode 100644 api/app/db/dao/completion_dataset_metrics_dao.py create mode 100644 api/tests/dao/completion_dataset_metrics_dao_test.py diff --git a/api/app/db/dao/completion_dataset_metrics_dao.py b/api/app/db/dao/completion_dataset_metrics_dao.py new file mode 100644 index 00000000..dd4e8894 --- /dev/null +++ b/api/app/db/dao/completion_dataset_metrics_dao.py @@ -0,0 +1,36 @@ +from typing import Optional +from uuid import UUID + +from app.db.database import Database +from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics +from app.db.tables.completion_dataset_table import CompletionDataset + + +class CompletionDatasetMetricsDAO: + def __init__(self, database: Database) -> None: + self.db = database + + def insert_completion_metrics( + self, completion_metrics: CompletionDatasetMetrics + ) -> CompletionDatasetMetrics: + with self.db.begin_session() as session: + session.add(completion_metrics) + session.flush() + return completion_metrics + + def get_completion_metrics_by_model_uuid( + self, model_uuid: UUID, completion_uuid: UUID + ) -> Optional[CompletionDatasetMetrics]: + with self.db.begin_session() as session: + return ( + session.query(CompletionDatasetMetrics) + .join( + CompletionDataset, + CompletionDatasetMetrics.completion_uuid == CompletionDataset.uuid, + ) + .where( + CompletionDataset.model_uuid == model_uuid, + CompletionDataset.uuid == completion_uuid, + ) + .one_or_none() + ) diff --git a/api/app/main.py b/api/app/main.py index a221e746..130ad772 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -11,6 +11,7 @@ from app.core import get_config from app.db.dao.completion_dataset_dao import CompletionDatasetDAO +from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO from app.db.dao.model_dao import ModelDAO @@ -56,6 +57,7 @@ current_dataset_dao = CurrentDatasetDAO(database) current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database) completion_dataset_dao = CompletionDatasetDAO(database) +completion_dataset_metrics_dao = CompletionDatasetMetricsDAO(database) model_service = ModelService( model_dao=model_dao, @@ -94,6 +96,8 @@ reference_dataset_dao=reference_dataset_dao, current_dataset_metrics_dao=current_dataset_metrics_dao, current_dataset_dao=current_dataset_dao, + completion_dataset_metrics_dao=completion_dataset_metrics_dao, + completion_dataset_dao=completion_dataset_dao, model_service=model_service, ) spark_k8s_service = SparkK8SService(spark_k8s_client) diff --git a/api/app/models/dataset_type.py b/api/app/models/dataset_type.py index 845096eb..1ceb74b6 100644 --- a/api/app/models/dataset_type.py +++ b/api/app/models/dataset_type.py @@ -4,3 +4,4 @@ class DatasetType(str, Enum): REFERENCE = 'REFERENCE' CURRENT = 'CURRENT' + COMPLETION = 'COMPLETION' diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py index ae43ed54..3478a229 100644 --- a/api/app/models/metrics/model_quality_dto.py +++ b/api/app/models/metrics/model_quality_dto.py @@ -192,6 +192,39 @@ class CurrentRegressionModelQuality(BaseModel): model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) +class TokenProb(BaseModel): + prob: float + token: str + + +class TokenData(BaseModel): + id: str + probs: List[TokenProb] + + +class MeanPerFile(BaseModel): + prob_tot_mean: float + perplex_tot_mean: float + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class MeanPerPhrase(BaseModel): + id: str + prob_per_phrase: float + perplex_per_phrase: float + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class CompletionTextGenerationModelQuality(BaseModel): + tokens: List[TokenData] + mean_per_file: List[MeanPerFile] + mean_per_phrase: List[MeanPerPhrase] + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + class ModelQualityDTO(BaseModel): job_status: JobStatus model_quality: Optional[ @@ -201,6 +234,7 @@ class ModelQualityDTO(BaseModel): | CurrentMultiClassificationModelQuality | RegressionModelQuality | CurrentRegressionModelQuality + | CompletionTextGenerationModelQuality ] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -251,6 +285,10 @@ def _create_model_quality( return ModelQualityDTO._create_regression_model_quality( dataset_type=dataset_type, model_quality_data=model_quality_data ) + if model_type == ModelType.TEXT_GENERATION: + return ModelQualityDTO._create_text_generation_model_quality( + dataset_type=dataset_type, model_quality_data=model_quality_data + ) raise MetricsInternalError(f'Invalid model type {model_type}') @staticmethod @@ -288,3 +326,13 @@ def _create_regression_model_quality( if dataset_type == DatasetType.CURRENT: return CurrentRegressionModelQuality(**model_quality_data) raise MetricsInternalError(f'Invalid dataset type {dataset_type}') + + @staticmethod + def _create_text_generation_model_quality( + dataset_type: DatasetType, + model_quality_data: Dict, + ) -> CompletionTextGenerationModelQuality: + """Create a text generation model quality instance based on dataset type.""" + if dataset_type == DatasetType.COMPLETION: + return CompletionTextGenerationModelQuality(**model_quality_data) + raise MetricsInternalError(f'Invalid dataset type {dataset_type}') diff --git a/api/app/routes/metrics_route.py b/api/app/routes/metrics_route.py index c906d50c..0df1a6f1 100644 --- a/api/app/routes/metrics_route.py +++ b/api/app/routes/metrics_route.py @@ -151,4 +151,16 @@ def get_current_percentages_by_model_by_uuid( model_uuid, current_uuid ) + @router.get( + '/{model_uuid}/completion/{completion_uuid}/model-quality', + status_code=200, + response_model=ModelQualityDTO, + ) + def get_completion_model_quality_by_model_by_uuid( + model_uuid: UUID, completion_uuid: UUID + ): + return metrics_service.get_completion_model_quality_by_model_by_uuid( + model_uuid, completion_uuid + ) + return router diff --git a/api/app/services/metrics_service.py b/api/app/services/metrics_service.py index 973a4f06..55356660 100644 --- a/api/app/services/metrics_service.py +++ b/api/app/services/metrics_service.py @@ -2,10 +2,14 @@ from typing import Optional from uuid import UUID +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO +from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO from app.db.dao.reference_dataset_metrics_dao import ReferenceDatasetMetricsDAO +from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics @@ -29,12 +33,16 @@ def __init__( reference_dataset_dao: ReferenceDatasetDAO, current_dataset_metrics_dao: CurrentDatasetMetricsDAO, current_dataset_dao: CurrentDatasetDAO, + completion_dataset_metrics_dao: CompletionDatasetMetricsDAO, + completion_dataset_dao: CompletionDatasetDAO, model_service: ModelService, ): self.reference_dataset_metrics_dao = reference_dataset_metrics_dao self.reference_dataset_dao = reference_dataset_dao self.current_dataset_metrics_dao = current_dataset_metrics_dao self.current_dataset_dao = current_dataset_dao + self.completion_dataset_metrics_dao = completion_dataset_metrics_dao + self.completion_dataset_dao = completion_dataset_dao self.model_service = model_service def get_reference_statistics_by_model_by_uuid( @@ -83,6 +91,19 @@ def get_current_model_quality_by_model_by_uuid( missing_status=JobStatus.MISSING_CURRENT, ) + def get_completion_model_quality_by_model_by_uuid( + self, model_uuid: UUID, completion_uuid: UUID + ) -> ModelQualityDTO: + """Retrieve completion model quality for a model by its UUID.""" + return self._get_model_quality_by_model_uuid( + model_uuid=model_uuid, + dataset_and_metrics_getter=lambda uuid: self.check_and_get_completion_dataset_and_metrics( + uuid, completion_uuid + ), + dataset_type=DatasetType.COMPLETION, + missing_status=JobStatus.MISSING_COMPLETION, + ) + def get_reference_data_quality_by_model_by_uuid( self, model_uuid: UUID ) -> DataQualityDTO: @@ -162,12 +183,28 @@ def check_and_get_current_dataset_and_metrics( ), ) + def check_and_get_completion_dataset_and_metrics( + self, model_uuid: UUID, completion_uuid: UUID + ) -> tuple[Optional[CompletionDataset], Optional[CompletionDatasetMetrics]]: + """Check and retrieve the completion dataset and its metrics for a model by its UUID.""" + return self._check_and_get_dataset_and_metrics( + model_uuid=model_uuid, + dataset_getter=lambda uuid: self.completion_dataset_dao.get_completion_dataset_by_model_uuid( + uuid, completion_uuid + ), + metrics_getter=lambda uuid: self.completion_dataset_metrics_dao.get_completion_metrics_by_model_uuid( + uuid, completion_uuid + ), + ) + @staticmethod def _check_and_get_dataset_and_metrics( model_uuid: UUID, dataset_getter, metrics_getter ) -> tuple[ - Optional[ReferenceDataset | CurrentDataset], - Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics], + Optional[ReferenceDataset | CurrentDataset | CompletionDataset], + Optional[ + ReferenceDatasetMetrics | CurrentDatasetMetrics | CompletionDatasetMetrics + ], ]: """Check and retrieve the dataset and its metrics using the provided getters.""" dataset = dataset_getter(model_uuid) @@ -294,8 +331,10 @@ def _create_statistics_dto( def _create_model_quality_dto( dataset_type: DatasetType, model_type: ModelType, - dataset: Optional[ReferenceDataset | CurrentDataset], - metrics: Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics], + dataset: Optional[ReferenceDataset | CurrentDataset | CompletionDataset], + metrics: Optional[ + ReferenceDatasetMetrics | CurrentDatasetMetrics | CompletionDatasetMetrics + ], missing_status, ) -> ModelQualityDTO: """Create a ModelQualityDTO from the provided dataset and metrics.""" diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index 13f0bd0a..b1d6188b 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional import uuid +from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset @@ -545,6 +546,30 @@ def get_sample_completion_dataset( }, } +model_quality_completion_dict = { + 'tokens': [ + { + 'id': 'chatcmpl', + 'probs': [ + {'prob': 0.27718424797058105, 'token': 'Sky'}, + {'prob': 0.8951022028923035, 'token': ' is'}, + {'prob': 0.7038467526435852, 'token': ' blue'}, + {'prob': 0.9999753832817078, 'token': '.'}, + ], + } + ], + 'mean_per_file': [ + {'prob_tot_mean': 0.7190271615982056, 'perplex_tot_mean': 1.5469378232955933} + ], + 'mean_per_phrase': [ + { + 'id': 'chatcmpl', + 'prob_per_phrase': 0.7190271615982056, + 'perplex_per_phrase': 1.5469378232955933, + } + ], +} + def get_sample_reference_metrics( reference_uuid: uuid.UUID = REFERENCE_UUID, @@ -576,3 +601,13 @@ def get_sample_current_metrics( drift=drift, percentages=percentages, ) + + +def get_sample_completion_metrics( + completion_uuid: uuid.UUID = COMPLETION_UUID, + model_quality: Dict = model_quality_completion_dict, +) -> CompletionDatasetMetrics: + return CompletionDatasetMetrics( + completion_uuid=completion_uuid, + model_quality=model_quality, + ) diff --git a/api/tests/dao/completion_dataset_metrics_dao_test.py b/api/tests/dao/completion_dataset_metrics_dao_test.py new file mode 100644 index 00000000..fb9c82c4 --- /dev/null +++ b/api/tests/dao/completion_dataset_metrics_dao_test.py @@ -0,0 +1,56 @@ +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO +from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO +from app.db.dao.model_dao import ModelDAO +from app.models.model_dto import ModelType +from tests.commons.db_integration import DatabaseIntegration +from tests.commons.db_mock import ( + get_sample_completion_dataset, + get_sample_completion_metrics, + get_sample_model, +) + + +class CompletionDatasetMetricsDAOTest(DatabaseIntegration): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_dao = ModelDAO(cls.db) + cls.metrics_dao = CompletionDatasetMetricsDAO(cls.db) + cls.f_completion_dataset_dao = CompletionDatasetDAO(cls.db) + + def test_insert_completion_dataset_metrics(self): + self.model_dao.insert( + get_sample_model( + model_type=ModelType.TEXT_GENERATION, + features=None, + target=None, + outputs=None, + timestamp=None, + ) + ) + completion_upload = get_sample_completion_dataset() + self.f_completion_dataset_dao.insert_completion_dataset(completion_upload) + completion_metrics = get_sample_completion_metrics( + completion_uuid=completion_upload.uuid + ) + inserted = self.metrics_dao.insert_completion_metrics(completion_metrics) + assert inserted == completion_metrics + + def test_get_completion_metrics_by_model_uuid(self): + self.model_dao.insert( + get_sample_model( + model_type=ModelType.TEXT_GENERATION, + features=None, + target=None, + outputs=None, + timestamp=None, + ) + ) + completion = get_sample_completion_dataset() + self.f_completion_dataset_dao.insert_completion_dataset(completion) + metrics = get_sample_completion_metrics() + self.metrics_dao.insert_completion_metrics(metrics) + retrieved = self.metrics_dao.get_completion_metrics_by_model_uuid( + model_uuid=completion.model_uuid, completion_uuid=completion.uuid + ) + assert retrieved.uuid == metrics.uuid diff --git a/api/tests/routes/metrics_route_test.py b/api/tests/routes/metrics_route_test.py index 83f08a4e..a35cb822 100644 --- a/api/tests/routes/metrics_route_test.py +++ b/api/tests/routes/metrics_route_test.py @@ -200,3 +200,26 @@ def test_get_current_data_quality_by_model_by_uuid(self): self.metrics_service.get_current_data_quality_by_model_by_uuid.assert_called_once_with( model_uuid, current_uuid ) + + def test_get_completion_model_quality_by_model_by_uuid(self): + model_uuid = uuid.uuid4() + completion_uuid = uuid.uuid4() + completion_metrics = db_mock.get_sample_completion_metrics() + model_quality = ModelQualityDTO.from_dict( + dataset_type=DatasetType.COMPLETION, + model_type=ModelType.TEXT_GENERATION, + job_status=JobStatus.SUCCEEDED, + model_quality_data=completion_metrics.model_quality, + ) + self.metrics_service.get_completion_model_quality_by_model_by_uuid = MagicMock( + return_value=model_quality + ) + + res = self.client.get( + f'{self.prefix}/{model_uuid}/completion/{completion_uuid}/model-quality' + ) + assert res.status_code == 200 + assert jsonable_encoder(model_quality) == res.json() + self.metrics_service.get_completion_model_quality_by_model_by_uuid.assert_called_once_with( + model_uuid, completion_uuid + ) diff --git a/api/tests/services/metrics_service_test.py b/api/tests/services/metrics_service_test.py index c7808af8..ccc0c357 100644 --- a/api/tests/services/metrics_service_test.py +++ b/api/tests/services/metrics_service_test.py @@ -4,6 +4,8 @@ import pytest +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO +from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO @@ -36,12 +38,16 @@ def setUpClass(cls): cls.current_dataset_dao: CurrentDatasetDAO = MagicMock( spec_set=CurrentDatasetDAO ) + cls.completion_metrics_dao = MagicMock(spec_set=CompletionDatasetMetricsDAO) + cls.completion_dataset_dao = MagicMock(spec_set=CompletionDatasetDAO) cls.model_service: ModelService = MagicMock(spec_set=ModelService) cls.metrics_service = MetricsService( reference_dataset_metrics_dao=cls.reference_metrics_dao, reference_dataset_dao=cls.reference_dataset_dao, current_dataset_metrics_dao=cls.current_metrics_dao, current_dataset_dao=cls.current_dataset_dao, + completion_dataset_metrics_dao=cls.completion_metrics_dao, + completion_dataset_dao=cls.completion_dataset_dao, model_service=cls.model_service, ) cls.mocks = [ @@ -49,6 +55,8 @@ def setUpClass(cls): cls.reference_dataset_dao, cls.current_metrics_dao, cls.current_dataset_dao, + cls.completion_metrics_dao, + cls.completion_dataset_dao, ] def test_get_reference_statistics_by_model_by_uuid(self): @@ -666,6 +674,69 @@ def test_get_current_regression_model_quality_by_model_by_uuid(self): model_quality_data=current_metrics.model_quality, ) + def test_get_completion_text_generation_model_quality_by_model_by_uuid(self): + status = JobStatus.SUCCEEDED + completion_dataset = db_mock.get_sample_completion_dataset(status=status) + completion_metrics = db_mock.get_sample_completion_metrics() + model = db_mock.get_sample_model( + model_type=ModelType.TEXT_GENERATION, + target=None, + features=None, + outputs=None, + timestamp=None, + ) + self.model_service.get_model_by_uuid = MagicMock(return_value=model) + self.completion_dataset_dao.get_completion_dataset_by_model_uuid = MagicMock( + return_value=completion_dataset + ) + self.completion_metrics_dao.get_completion_metrics_by_model_uuid = MagicMock( + return_value=completion_metrics + ) + res = self.metrics_service.get_completion_model_quality_by_model_by_uuid( + model_uuid, completion_dataset.uuid + ) + self.completion_dataset_dao.get_completion_dataset_by_model_uuid.assert_called_once_with( + model_uuid, completion_dataset.uuid + ) + self.completion_metrics_dao.get_completion_metrics_by_model_uuid.assert_called_once_with( + model_uuid, completion_dataset.uuid + ) + + assert res == ModelQualityDTO.from_dict( + dataset_type=DatasetType.COMPLETION, + model_type=model.model_type, + job_status=completion_dataset.status, + model_quality_data=completion_metrics.model_quality, + ) + + def test_get_empty_completion_model_quality_by_model_by_uuid(self): + status = JobStatus.IMPORTING + completion_dataset = db_mock.get_sample_completion_dataset(status=status.value) + model = db_mock.get_sample_model( + model_type=ModelType.TEXT_GENERATION, + target=None, + features=None, + outputs=None, + timestamp=None, + ) + self.model_service.get_model_by_uuid = MagicMock(return_value=model) + self.completion_dataset_dao.get_completion_dataset_by_model_uuid = MagicMock( + return_value=completion_dataset + ) + res = self.metrics_service.get_completion_model_quality_by_model_by_uuid( + model_uuid, completion_dataset.uuid + ) + self.completion_dataset_dao.get_completion_dataset_by_model_uuid.assert_called_once_with( + model_uuid, completion_dataset.uuid + ) + + assert res == ModelQualityDTO.from_dict( + dataset_type=DatasetType.COMPLETION, + model_type=model.model_type, + job_status=completion_dataset.status, + model_quality_data=None, + ) + model_uuid = db_mock.MODEL_UUID current_uuid = db_mock.CURRENT_UUID