-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: define text generation metrics api (#213)
* add completion model quality endpoint * add get completion model quality endpoint * fix test
- Loading branch information
Showing
10 changed files
with
329 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Optional | ||
from uuid import UUID | ||
|
||
from app.db.database import Database | ||
from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics | ||
from app.db.tables.completion_dataset_table import CompletionDataset | ||
|
||
|
||
class CompletionDatasetMetricsDAO: | ||
def __init__(self, database: Database) -> None: | ||
self.db = database | ||
|
||
def insert_completion_metrics( | ||
self, completion_metrics: CompletionDatasetMetrics | ||
) -> CompletionDatasetMetrics: | ||
with self.db.begin_session() as session: | ||
session.add(completion_metrics) | ||
session.flush() | ||
return completion_metrics | ||
|
||
def get_completion_metrics_by_model_uuid( | ||
self, model_uuid: UUID, completion_uuid: UUID | ||
) -> Optional[CompletionDatasetMetrics]: | ||
with self.db.begin_session() as session: | ||
return ( | ||
session.query(CompletionDatasetMetrics) | ||
.join( | ||
CompletionDataset, | ||
CompletionDatasetMetrics.completion_uuid == CompletionDataset.uuid, | ||
) | ||
.where( | ||
CompletionDataset.model_uuid == model_uuid, | ||
CompletionDataset.uuid == completion_uuid, | ||
) | ||
.one_or_none() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
class DatasetType(str, Enum): | ||
REFERENCE = 'REFERENCE' | ||
CURRENT = 'CURRENT' | ||
COMPLETION = 'COMPLETION' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO | ||
from app.db.dao.completion_dataset_metrics_dao import CompletionDatasetMetricsDAO | ||
from app.db.dao.model_dao import ModelDAO | ||
from app.models.model_dto import ModelType | ||
from tests.commons.db_integration import DatabaseIntegration | ||
from tests.commons.db_mock import ( | ||
get_sample_completion_dataset, | ||
get_sample_completion_metrics, | ||
get_sample_model, | ||
) | ||
|
||
|
||
class CompletionDatasetMetricsDAOTest(DatabaseIntegration): | ||
@classmethod | ||
def setUpClass(cls): | ||
super().setUpClass() | ||
cls.model_dao = ModelDAO(cls.db) | ||
cls.metrics_dao = CompletionDatasetMetricsDAO(cls.db) | ||
cls.f_completion_dataset_dao = CompletionDatasetDAO(cls.db) | ||
|
||
def test_insert_completion_dataset_metrics(self): | ||
self.model_dao.insert( | ||
get_sample_model( | ||
model_type=ModelType.TEXT_GENERATION, | ||
features=None, | ||
target=None, | ||
outputs=None, | ||
timestamp=None, | ||
) | ||
) | ||
completion_upload = get_sample_completion_dataset() | ||
self.f_completion_dataset_dao.insert_completion_dataset(completion_upload) | ||
completion_metrics = get_sample_completion_metrics( | ||
completion_uuid=completion_upload.uuid | ||
) | ||
inserted = self.metrics_dao.insert_completion_metrics(completion_metrics) | ||
assert inserted == completion_metrics | ||
|
||
def test_get_completion_metrics_by_model_uuid(self): | ||
self.model_dao.insert( | ||
get_sample_model( | ||
model_type=ModelType.TEXT_GENERATION, | ||
features=None, | ||
target=None, | ||
outputs=None, | ||
timestamp=None, | ||
) | ||
) | ||
completion = get_sample_completion_dataset() | ||
self.f_completion_dataset_dao.insert_completion_dataset(completion) | ||
metrics = get_sample_completion_metrics() | ||
self.metrics_dao.insert_completion_metrics(metrics) | ||
retrieved = self.metrics_dao.get_completion_metrics_by_model_uuid( | ||
model_uuid=completion.model_uuid, completion_uuid=completion.uuid | ||
) | ||
assert retrieved.uuid == metrics.uuid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.