Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: define text generation metrics api #213

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions api/app/db/dao/completion_dataset_metrics_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional
from uuid import UUID

from app.db.database import Database
from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics
from app.db.tables.completion_dataset_table import CompletionDataset


class CompletionDatasetMetricsDAO:
def __init__(self, database: Database) -> None:
self.db = database

def insert_completion_metrics(
self, completion_metrics: CompletionDatasetMetrics
) -> CompletionDatasetMetrics:
with self.db.begin_session() as session:
session.add(completion_metrics)
session.flush()
return completion_metrics

def get_completion_metrics_by_model_uuid(
self, model_uuid: UUID, completion_uuid: UUID
) -> Optional[CompletionDatasetMetrics]:
with self.db.begin_session() as session:
return (
session.query(CompletionDatasetMetrics)
.join(
CompletionDataset,
CompletionDatasetMetrics.completion_uuid == CompletionDataset.uuid,
)
.where(
CompletionDataset.model_uuid == model_uuid,
CompletionDataset.uuid == completion_uuid,
)
.one_or_none()
)
4 changes: 4 additions & 0 deletions api/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from app.core import get_config
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
from app.db.dao.model_dao import ModelDAO
Expand Down Expand Up @@ -56,6 +57,7 @@
current_dataset_dao = CurrentDatasetDAO(database)
current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)
completion_dataset_dao = CompletionDatasetDAO(database)
completion_dataset_metrics_dao = CompletionDatasetMetricsDAO(database)

model_service = ModelService(
model_dao=model_dao,
Expand Down Expand Up @@ -94,6 +96,8 @@
reference_dataset_dao=reference_dataset_dao,
current_dataset_metrics_dao=current_dataset_metrics_dao,
current_dataset_dao=current_dataset_dao,
completion_dataset_metrics_dao=completion_dataset_metrics_dao,
completion_dataset_dao=completion_dataset_dao,
model_service=model_service,
)
spark_k8s_service = SparkK8SService(spark_k8s_client)
Expand Down
1 change: 1 addition & 0 deletions api/app/models/dataset_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
class DatasetType(str, Enum):
REFERENCE = 'REFERENCE'
CURRENT = 'CURRENT'
COMPLETION = 'COMPLETION'
48 changes: 48 additions & 0 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ class CurrentRegressionModelQuality(BaseModel):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class TokenProb(BaseModel):
prob: float
token: str


class TokenData(BaseModel):
id: str
probs: List[TokenProb]


class MeanPerFile(BaseModel):
prob_tot_mean: float
perplex_tot_mean: float

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MeanPerPhrase(BaseModel):
id: str
prob_per_phrase: float
perplex_per_phrase: float

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CompletionTextGenerationModelQuality(BaseModel):
tokens: List[TokenData]
mean_per_file: List[MeanPerFile]
mean_per_phrase: List[MeanPerPhrase]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
Expand All @@ -201,6 +234,7 @@ class ModelQualityDTO(BaseModel):
| CurrentMultiClassificationModelQuality
| RegressionModelQuality
| CurrentRegressionModelQuality
| CompletionTextGenerationModelQuality
]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand Down Expand Up @@ -251,6 +285,10 @@ def _create_model_quality(
return ModelQualityDTO._create_regression_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
)
if model_type == ModelType.TEXT_GENERATION:
return ModelQualityDTO._create_text_generation_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
)
raise MetricsInternalError(f'Invalid model type {model_type}')

@staticmethod
Expand Down Expand Up @@ -288,3 +326,13 @@ def _create_regression_model_quality(
if dataset_type == DatasetType.CURRENT:
return CurrentRegressionModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_text_generation_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> CompletionTextGenerationModelQuality:
"""Create a text generation model quality instance based on dataset type."""
if dataset_type == DatasetType.COMPLETION:
return CompletionTextGenerationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
12 changes: 12 additions & 0 deletions api/app/routes/metrics_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,16 @@ def get_current_percentages_by_model_by_uuid(
model_uuid, current_uuid
)

@router.get(
'/{model_uuid}/completion/{completion_uuid}/model-quality',
status_code=200,
response_model=ModelQualityDTO,
)
def get_completion_model_quality_by_model_by_uuid(
model_uuid: UUID, completion_uuid: UUID
):
return metrics_service.get_completion_model_quality_by_model_by_uuid(
model_uuid, completion_uuid
)

return router
47 changes: 43 additions & 4 deletions api/app/services/metrics_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from typing import Optional
from uuid import UUID

from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.db.dao.reference_dataset_metrics_dao import ReferenceDatasetMetricsDAO
from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics
from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics
Expand All @@ -29,12 +33,16 @@ def __init__(
reference_dataset_dao: ReferenceDatasetDAO,
current_dataset_metrics_dao: CurrentDatasetMetricsDAO,
current_dataset_dao: CurrentDatasetDAO,
completion_dataset_metrics_dao: CompletionDatasetMetricsDAO,
completion_dataset_dao: CompletionDatasetDAO,
model_service: ModelService,
):
self.reference_dataset_metrics_dao = reference_dataset_metrics_dao
self.reference_dataset_dao = reference_dataset_dao
self.current_dataset_metrics_dao = current_dataset_metrics_dao
self.current_dataset_dao = current_dataset_dao
self.completion_dataset_metrics_dao = completion_dataset_metrics_dao
self.completion_dataset_dao = completion_dataset_dao
self.model_service = model_service

def get_reference_statistics_by_model_by_uuid(
Expand Down Expand Up @@ -83,6 +91,19 @@ def get_current_model_quality_by_model_by_uuid(
missing_status=JobStatus.MISSING_CURRENT,
)

def get_completion_model_quality_by_model_by_uuid(
self, model_uuid: UUID, completion_uuid: UUID
) -> ModelQualityDTO:
"""Retrieve completion model quality for a model by its UUID."""
return self._get_model_quality_by_model_uuid(
model_uuid=model_uuid,
dataset_and_metrics_getter=lambda uuid: self.check_and_get_completion_dataset_and_metrics(
uuid, completion_uuid
),
dataset_type=DatasetType.COMPLETION,
missing_status=JobStatus.MISSING_COMPLETION,
)

def get_reference_data_quality_by_model_by_uuid(
self, model_uuid: UUID
) -> DataQualityDTO:
Expand Down Expand Up @@ -162,12 +183,28 @@ def check_and_get_current_dataset_and_metrics(
),
)

def check_and_get_completion_dataset_and_metrics(
self, model_uuid: UUID, completion_uuid: UUID
) -> tuple[Optional[CompletionDataset], Optional[CompletionDatasetMetrics]]:
"""Check and retrieve the completion dataset and its metrics for a model by its UUID."""
return self._check_and_get_dataset_and_metrics(
model_uuid=model_uuid,
dataset_getter=lambda uuid: self.completion_dataset_dao.get_completion_dataset_by_model_uuid(
uuid, completion_uuid
),
metrics_getter=lambda uuid: self.completion_dataset_metrics_dao.get_completion_metrics_by_model_uuid(
uuid, completion_uuid
),
)

@staticmethod
def _check_and_get_dataset_and_metrics(
model_uuid: UUID, dataset_getter, metrics_getter
) -> tuple[
Optional[ReferenceDataset | CurrentDataset],
Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics],
Optional[ReferenceDataset | CurrentDataset | CompletionDataset],
Optional[
ReferenceDatasetMetrics | CurrentDatasetMetrics | CompletionDatasetMetrics
],
]:
"""Check and retrieve the dataset and its metrics using the provided getters."""
dataset = dataset_getter(model_uuid)
Expand Down Expand Up @@ -294,8 +331,10 @@ def _create_statistics_dto(
def _create_model_quality_dto(
dataset_type: DatasetType,
model_type: ModelType,
dataset: Optional[ReferenceDataset | CurrentDataset],
metrics: Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics],
dataset: Optional[ReferenceDataset | CurrentDataset | CompletionDataset],
metrics: Optional[
ReferenceDatasetMetrics | CurrentDatasetMetrics | CompletionDatasetMetrics
],
missing_status,
) -> ModelQualityDTO:
"""Create a ModelQualityDTO from the provided dataset and metrics."""
Expand Down
35 changes: 35 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List, Optional
import uuid

from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics
from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics
from app.db.tables.current_dataset_table import CurrentDataset
Expand Down Expand Up @@ -545,6 +546,30 @@ def get_sample_completion_dataset(
},
}

model_quality_completion_dict = {
'tokens': [
{
'id': 'chatcmpl',
'probs': [
{'prob': 0.27718424797058105, 'token': 'Sky'},
{'prob': 0.8951022028923035, 'token': ' is'},
{'prob': 0.7038467526435852, 'token': ' blue'},
{'prob': 0.9999753832817078, 'token': '.'},
],
}
],
'mean_per_file': [
{'prob_tot_mean': 0.7190271615982056, 'perplex_tot_mean': 1.5469378232955933}
],
'mean_per_phrase': [
{
'id': 'chatcmpl',
'prob_per_phrase': 0.7190271615982056,
'perplex_per_phrase': 1.5469378232955933,
}
],
}


def get_sample_reference_metrics(
reference_uuid: uuid.UUID = REFERENCE_UUID,
Expand Down Expand Up @@ -576,3 +601,13 @@ def get_sample_current_metrics(
drift=drift,
percentages=percentages,
)


def get_sample_completion_metrics(
completion_uuid: uuid.UUID = COMPLETION_UUID,
model_quality: Dict = model_quality_completion_dict,
) -> CompletionDatasetMetrics:
return CompletionDatasetMetrics(
completion_uuid=completion_uuid,
model_quality=model_quality,
)
56 changes: 56 additions & 0 deletions api/tests/dao/completion_dataset_metrics_dao_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO
from app.db.dao.model_dao import ModelDAO
from app.models.model_dto import ModelType
from tests.commons.db_integration import DatabaseIntegration
from tests.commons.db_mock import (
get_sample_completion_dataset,
get_sample_completion_metrics,
get_sample_model,
)


class CompletionDatasetMetricsDAOTest(DatabaseIntegration):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_dao = ModelDAO(cls.db)
cls.metrics_dao = CompletionDatasetMetricsDAO(cls.db)
cls.f_completion_dataset_dao = CompletionDatasetDAO(cls.db)

def test_insert_completion_dataset_metrics(self):
self.model_dao.insert(
get_sample_model(
model_type=ModelType.TEXT_GENERATION,
features=None,
target=None,
outputs=None,
timestamp=None,
)
)
completion_upload = get_sample_completion_dataset()
self.f_completion_dataset_dao.insert_completion_dataset(completion_upload)
completion_metrics = get_sample_completion_metrics(
completion_uuid=completion_upload.uuid
)
inserted = self.metrics_dao.insert_completion_metrics(completion_metrics)
assert inserted == completion_metrics

def test_get_completion_metrics_by_model_uuid(self):
self.model_dao.insert(
get_sample_model(
model_type=ModelType.TEXT_GENERATION,
features=None,
target=None,
outputs=None,
timestamp=None,
)
)
completion = get_sample_completion_dataset()
self.f_completion_dataset_dao.insert_completion_dataset(completion)
metrics = get_sample_completion_metrics()
self.metrics_dao.insert_completion_metrics(metrics)
retrieved = self.metrics_dao.get_completion_metrics_by_model_uuid(
model_uuid=completion.model_uuid, completion_uuid=completion.uuid
)
assert retrieved.uuid == metrics.uuid
23 changes: 23 additions & 0 deletions api/tests/routes/metrics_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,26 @@ def test_get_current_data_quality_by_model_by_uuid(self):
self.metrics_service.get_current_data_quality_by_model_by_uuid.assert_called_once_with(
model_uuid, current_uuid
)

def test_get_completion_model_quality_by_model_by_uuid(self):
model_uuid = uuid.uuid4()
completion_uuid = uuid.uuid4()
completion_metrics = db_mock.get_sample_completion_metrics()
model_quality = ModelQualityDTO.from_dict(
dataset_type=DatasetType.COMPLETION,
model_type=ModelType.TEXT_GENERATION,
job_status=JobStatus.SUCCEEDED,
model_quality_data=completion_metrics.model_quality,
)
self.metrics_service.get_completion_model_quality_by_model_by_uuid = MagicMock(
return_value=model_quality
)

res = self.client.get(
f'{self.prefix}/{model_uuid}/completion/{completion_uuid}/model-quality'
)
assert res.status_code == 200
assert jsonable_encoder(model_quality) == res.json()
self.metrics_service.get_completion_model_quality_by_model_by_uuid.assert_called_once_with(
model_uuid, completion_uuid
)
Loading