Skip to content

Commit

Permalink
feat: define text generation metrics api (#213)
Browse files Browse the repository at this point in the history
* add completion model quality endpoint

* add get completion model quality endpoint

* fix test
  • Loading branch information
dtria91 authored Dec 16, 2024
1 parent c4d34ea commit f3b750e
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 4 deletions.
36 changes: 36 additions & 0 deletions api/app/db/dao/completion_dataset_metrics_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional
from uuid import UUID

from app.db.database import Database
from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics
from app.db.tables.completion_dataset_table import CompletionDataset


class CompletionDatasetMetricsDAO:
def __init__(self, database: Database) -> None:
self.db = database

def insert_completion_metrics(
self, completion_metrics: CompletionDatasetMetrics
) -> CompletionDatasetMetrics:
with self.db.begin_session() as session:
session.add(completion_metrics)
session.flush()
return completion_metrics

def get_completion_metrics_by_model_uuid(
self, model_uuid: UUID, completion_uuid: UUID
) -> Optional[CompletionDatasetMetrics]:
with self.db.begin_session() as session:
return (
session.query(CompletionDatasetMetrics)
.join(
CompletionDataset,
CompletionDatasetMetrics.completion_uuid == CompletionDataset.uuid,
)
.where(
CompletionDataset.model_uuid == model_uuid,
CompletionDataset.uuid == completion_uuid,
)
.one_or_none()
)
4 changes: 4 additions & 0 deletions api/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from app.core import get_config
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
from app.db.dao.model_dao import ModelDAO
Expand Down Expand Up @@ -56,6 +57,7 @@
current_dataset_dao = CurrentDatasetDAO(database)
current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)
completion_dataset_dao = CompletionDatasetDAO(database)
completion_dataset_metrics_dao = CompletionDatasetMetricsDAO(database)

model_service = ModelService(
model_dao=model_dao,
Expand Down Expand Up @@ -94,6 +96,8 @@
reference_dataset_dao=reference_dataset_dao,
current_dataset_metrics_dao=current_dataset_metrics_dao,
current_dataset_dao=current_dataset_dao,
completion_dataset_metrics_dao=completion_dataset_metrics_dao,
completion_dataset_dao=completion_dataset_dao,
model_service=model_service,
)
spark_k8s_service = SparkK8SService(spark_k8s_client)
Expand Down
1 change: 1 addition & 0 deletions api/app/models/dataset_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
class DatasetType(str, Enum):
REFERENCE = 'REFERENCE'
CURRENT = 'CURRENT'
COMPLETION = 'COMPLETION'
48 changes: 48 additions & 0 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ class CurrentRegressionModelQuality(BaseModel):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class TokenProb(BaseModel):
prob: float
token: str


class TokenData(BaseModel):
id: str
probs: List[TokenProb]


class MeanPerFile(BaseModel):
prob_tot_mean: float
perplex_tot_mean: float

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MeanPerPhrase(BaseModel):
id: str
prob_per_phrase: float
perplex_per_phrase: float

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CompletionTextGenerationModelQuality(BaseModel):
tokens: List[TokenData]
mean_per_file: List[MeanPerFile]
mean_per_phrase: List[MeanPerPhrase]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
Expand All @@ -201,6 +234,7 @@ class ModelQualityDTO(BaseModel):
| CurrentMultiClassificationModelQuality
| RegressionModelQuality
| CurrentRegressionModelQuality
| CompletionTextGenerationModelQuality
]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand Down Expand Up @@ -251,6 +285,10 @@ def _create_model_quality(
return ModelQualityDTO._create_regression_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
)
if model_type == ModelType.TEXT_GENERATION:
return ModelQualityDTO._create_text_generation_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
)
raise MetricsInternalError(f'Invalid model type {model_type}')

@staticmethod
Expand Down Expand Up @@ -288,3 +326,13 @@ def _create_regression_model_quality(
if dataset_type == DatasetType.CURRENT:
return CurrentRegressionModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_text_generation_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> CompletionTextGenerationModelQuality:
"""Create a text generation model quality instance based on dataset type."""
if dataset_type == DatasetType.COMPLETION:
return CompletionTextGenerationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
12 changes: 12 additions & 0 deletions api/app/routes/metrics_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,16 @@ def get_current_percentages_by_model_by_uuid(
model_uuid, current_uuid
)

@router.get(
'/{model_uuid}/completion/{completion_uuid}/model-quality',
status_code=200,
response_model=ModelQualityDTO,
)
def get_completion_model_quality_by_model_by_uuid(
model_uuid: UUID, completion_uuid: UUID
):
return metrics_service.get_completion_model_quality_by_model_by_uuid(
model_uuid, completion_uuid
)

return router
47 changes: 43 additions & 4 deletions api/app/services/metrics_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from typing import Optional
from uuid import UUID

from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.db.dao.reference_dataset_metrics_dao import ReferenceDatasetMetricsDAO
from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics
from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics
Expand All @@ -29,12 +33,16 @@ def __init__(
reference_dataset_dao: ReferenceDatasetDAO,
current_dataset_metrics_dao: CurrentDatasetMetricsDAO,
current_dataset_dao: CurrentDatasetDAO,
completion_dataset_metrics_dao: CompletionDatasetMetricsDAO,
completion_dataset_dao: CompletionDatasetDAO,
model_service: ModelService,
):
self.reference_dataset_metrics_dao = reference_dataset_metrics_dao
self.reference_dataset_dao = reference_dataset_dao
self.current_dataset_metrics_dao = current_dataset_metrics_dao
self.current_dataset_dao = current_dataset_dao
self.completion_dataset_metrics_dao = completion_dataset_metrics_dao
self.completion_dataset_dao = completion_dataset_dao
self.model_service = model_service

def get_reference_statistics_by_model_by_uuid(
Expand Down Expand Up @@ -83,6 +91,19 @@ def get_current_model_quality_by_model_by_uuid(
missing_status=JobStatus.MISSING_CURRENT,
)

def get_completion_model_quality_by_model_by_uuid(
self, model_uuid: UUID, completion_uuid: UUID
) -> ModelQualityDTO:
"""Retrieve completion model quality for a model by its UUID."""
return self._get_model_quality_by_model_uuid(
model_uuid=model_uuid,
dataset_and_metrics_getter=lambda uuid: self.check_and_get_completion_dataset_and_metrics(
uuid, completion_uuid
),
dataset_type=DatasetType.COMPLETION,
missing_status=JobStatus.MISSING_COMPLETION,
)

def get_reference_data_quality_by_model_by_uuid(
self, model_uuid: UUID
) -> DataQualityDTO:
Expand Down Expand Up @@ -162,12 +183,28 @@ def check_and_get_current_dataset_and_metrics(
),
)

def check_and_get_completion_dataset_and_metrics(
self, model_uuid: UUID, completion_uuid: UUID
) -> tuple[Optional[CompletionDataset], Optional[CompletionDatasetMetrics]]:
"""Check and retrieve the completion dataset and its metrics for a model by its UUID."""
return self._check_and_get_dataset_and_metrics(
model_uuid=model_uuid,
dataset_getter=lambda uuid: self.completion_dataset_dao.get_completion_dataset_by_model_uuid(
uuid, completion_uuid
),
metrics_getter=lambda uuid: self.completion_dataset_metrics_dao.get_completion_metrics_by_model_uuid(
uuid, completion_uuid
),
)

@staticmethod
def _check_and_get_dataset_and_metrics(
model_uuid: UUID, dataset_getter, metrics_getter
) -> tuple[
Optional[ReferenceDataset | CurrentDataset],
Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics],
Optional[ReferenceDataset | CurrentDataset | CompletionDataset],
Optional[
ReferenceDatasetMetrics | CurrentDatasetMetrics | CompletionDatasetMetrics
],
]:
"""Check and retrieve the dataset and its metrics using the provided getters."""
dataset = dataset_getter(model_uuid)
Expand Down Expand Up @@ -294,8 +331,10 @@ def _create_statistics_dto(
def _create_model_quality_dto(
dataset_type: DatasetType,
model_type: ModelType,
dataset: Optional[ReferenceDataset | CurrentDataset],
metrics: Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics],
dataset: Optional[ReferenceDataset | CurrentDataset | CompletionDataset],
metrics: Optional[
ReferenceDatasetMetrics | CurrentDatasetMetrics | CompletionDatasetMetrics
],
missing_status,
) -> ModelQualityDTO:
"""Create a ModelQualityDTO from the provided dataset and metrics."""
Expand Down
35 changes: 35 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List, Optional
import uuid

from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics
from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics
from app.db.tables.current_dataset_table import CurrentDataset
Expand Down Expand Up @@ -545,6 +546,30 @@ def get_sample_completion_dataset(
},
}

model_quality_completion_dict = {
'tokens': [
{
'id': 'chatcmpl',
'probs': [
{'prob': 0.27718424797058105, 'token': 'Sky'},
{'prob': 0.8951022028923035, 'token': ' is'},
{'prob': 0.7038467526435852, 'token': ' blue'},
{'prob': 0.9999753832817078, 'token': '.'},
],
}
],
'mean_per_file': [
{'prob_tot_mean': 0.7190271615982056, 'perplex_tot_mean': 1.5469378232955933}
],
'mean_per_phrase': [
{
'id': 'chatcmpl',
'prob_per_phrase': 0.7190271615982056,
'perplex_per_phrase': 1.5469378232955933,
}
],
}


def get_sample_reference_metrics(
reference_uuid: uuid.UUID = REFERENCE_UUID,
Expand Down Expand Up @@ -576,3 +601,13 @@ def get_sample_current_metrics(
drift=drift,
percentages=percentages,
)


def get_sample_completion_metrics(
completion_uuid: uuid.UUID = COMPLETION_UUID,
model_quality: Dict = model_quality_completion_dict,
) -> CompletionDatasetMetrics:
return CompletionDatasetMetrics(
completion_uuid=completion_uuid,
model_quality=model_quality,
)
56 changes: 56 additions & 0 deletions api/tests/dao/completion_dataset_metrics_dao_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO
from app.db.dao.model_dao import ModelDAO
from app.models.model_dto import ModelType
from tests.commons.db_integration import DatabaseIntegration
from tests.commons.db_mock import (
get_sample_completion_dataset,
get_sample_completion_metrics,
get_sample_model,
)


class CompletionDatasetMetricsDAOTest(DatabaseIntegration):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_dao = ModelDAO(cls.db)
cls.metrics_dao = CompletionDatasetMetricsDAO(cls.db)
cls.f_completion_dataset_dao = CompletionDatasetDAO(cls.db)

def test_insert_completion_dataset_metrics(self):
self.model_dao.insert(
get_sample_model(
model_type=ModelType.TEXT_GENERATION,
features=None,
target=None,
outputs=None,
timestamp=None,
)
)
completion_upload = get_sample_completion_dataset()
self.f_completion_dataset_dao.insert_completion_dataset(completion_upload)
completion_metrics = get_sample_completion_metrics(
completion_uuid=completion_upload.uuid
)
inserted = self.metrics_dao.insert_completion_metrics(completion_metrics)
assert inserted == completion_metrics

def test_get_completion_metrics_by_model_uuid(self):
self.model_dao.insert(
get_sample_model(
model_type=ModelType.TEXT_GENERATION,
features=None,
target=None,
outputs=None,
timestamp=None,
)
)
completion = get_sample_completion_dataset()
self.f_completion_dataset_dao.insert_completion_dataset(completion)
metrics = get_sample_completion_metrics()
self.metrics_dao.insert_completion_metrics(metrics)
retrieved = self.metrics_dao.get_completion_metrics_by_model_uuid(
model_uuid=completion.model_uuid, completion_uuid=completion.uuid
)
assert retrieved.uuid == metrics.uuid
23 changes: 23 additions & 0 deletions api/tests/routes/metrics_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,26 @@ def test_get_current_data_quality_by_model_by_uuid(self):
self.metrics_service.get_current_data_quality_by_model_by_uuid.assert_called_once_with(
model_uuid, current_uuid
)

def test_get_completion_model_quality_by_model_by_uuid(self):
model_uuid = uuid.uuid4()
completion_uuid = uuid.uuid4()
completion_metrics = db_mock.get_sample_completion_metrics()
model_quality = ModelQualityDTO.from_dict(
dataset_type=DatasetType.COMPLETION,
model_type=ModelType.TEXT_GENERATION,
job_status=JobStatus.SUCCEEDED,
model_quality_data=completion_metrics.model_quality,
)
self.metrics_service.get_completion_model_quality_by_model_by_uuid = MagicMock(
return_value=model_quality
)

res = self.client.get(
f'{self.prefix}/{model_uuid}/completion/{completion_uuid}/model-quality'
)
assert res.status_code == 200
assert jsonable_encoder(model_quality) == res.json()
self.metrics_service.get_completion_model_quality_by_model_by_uuid.assert_called_once_with(
model_uuid, completion_uuid
)
Loading

0 comments on commit f3b750e

Please sign in to comment.