Skip to content

Commit

Permalink
feat: handle latest completion dataset (#212)
Browse files Browse the repository at this point in the history
* handle latest completion dataset

* add completion dataset dao to model service

* fix: edit alert test

* fix: handle dataset table name argument

* fix: edit validation json file

* fix: ruff check
  • Loading branch information
dtria91 authored Dec 13, 2024
1 parent ff2f911 commit c4d34ea
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 90 deletions.
1 change: 1 addition & 0 deletions api/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
model_dao=model_dao,
reference_dataset_dao=reference_dataset_dao,
current_dataset_dao=current_dataset_dao,
completion_dataset_dao=completion_dataset_dao,
)
s3_config = get_config().s3_config

Expand Down
1 change: 1 addition & 0 deletions api/app/models/alert_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class AlertDTO(BaseModel):
model_uuid: UUID
reference_uuid: Optional[UUID]
current_uuid: Optional[UUID]
completion_uuid: Optional[UUID]
anomaly_type: AnomalyType
anomaly_features: List[str]

Expand Down
1 change: 1 addition & 0 deletions api/app/models/job_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class JobStatus(str, Enum):
ERROR = 'ERROR'
MISSING_REFERENCE = 'MISSING_REFERENCE'
MISSING_CURRENT = 'MISSING_CURRENT'
MISSING_COMPLETION = 'MISSING_COMPLETION'
52 changes: 33 additions & 19 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.db.dao.current_dataset_dao import CurrentDataset
from app.db.dao.model_dao import Model
from app.db.dao.reference_dataset_dao import ReferenceDataset
from app.db.tables.completion_dataset_table import CompletionDataset
from app.models.inferred_schema_dto import FieldType, SupportedTypes
from app.models.job_status import JobStatus
from app.models.metrics.percentages_dto import Percentages
Expand Down Expand Up @@ -244,8 +245,10 @@ class ModelOut(BaseModel):
updated_at: str
latest_reference_uuid: Optional[UUID]
latest_current_uuid: Optional[UUID]
latest_reference_job_status: JobStatus
latest_current_job_status: JobStatus
latest_completion_uuid: Optional[UUID]
latest_reference_job_status: Optional[JobStatus]
latest_current_job_status: Optional[JobStatus]
latest_completion_job_status: Optional[JobStatus]
percentages: Optional[Percentages]

model_config = ConfigDict(
Expand All @@ -257,25 +260,34 @@ def from_model(
model: Model,
latest_reference_dataset: Optional[ReferenceDataset] = None,
latest_current_dataset: Optional[CurrentDataset] = None,
latest_completion_dataset: Optional[CompletionDataset] = None,
percentages: Optional[Percentages] = None,
):
latest_reference_uuid = (
latest_reference_dataset.uuid if latest_reference_dataset else None
)
latest_current_uuid = (
latest_current_dataset.uuid if latest_current_dataset else None
)

latest_reference_job_status = (
latest_reference_dataset.status
if latest_reference_dataset
else JobStatus.MISSING_REFERENCE
)
latest_current_job_status = (
latest_current_dataset.status
if latest_current_dataset
else JobStatus.MISSING_CURRENT
)
latest_reference_uuid = None
latest_current_uuid = None
latest_completion_uuid = None
latest_reference_job_status = None
latest_current_job_status = None
latest_completion_job_status = None

if model.model_type == ModelType.TEXT_GENERATION:
if latest_completion_dataset:
latest_completion_uuid = latest_completion_dataset.uuid
latest_completion_job_status = latest_completion_dataset.status
else:
latest_completion_job_status = JobStatus.MISSING_COMPLETION
else:
if latest_reference_dataset:
latest_reference_uuid = latest_reference_dataset.uuid
latest_reference_job_status = latest_reference_dataset.status
else:
latest_reference_job_status = JobStatus.MISSING_REFERENCE

if latest_current_dataset:
latest_current_uuid = latest_current_dataset.uuid
latest_current_job_status = latest_current_dataset.status
else:
latest_current_job_status = JobStatus.MISSING_CURRENT

return ModelOut(
uuid=model.uuid,
Expand All @@ -294,7 +306,9 @@ def from_model(
updated_at=str(model.updated_at),
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
latest_completion_uuid=latest_completion_uuid,
latest_reference_job_status=latest_reference_job_status,
latest_current_job_status=latest_current_job_status,
latest_completion_job_status=latest_completion_job_status,
percentages=percentages,
)
15 changes: 10 additions & 5 deletions api/app/services/file_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
import datetime
from io import BytesIO
import json
import logging
import pathlib
Expand Down Expand Up @@ -344,13 +345,13 @@ def upload_completion_file(
logger.error('Model %s not found', model_uuid)
raise ModelNotFoundError(f'Model {model_uuid} not found')

self.validate_json_file(json_file)
_f_name = json_file.filename
validated_json_file = self.validate_json_file(json_file)
_f_name = validated_json_file.filename
_f_uuid = uuid4()
try:
object_name = f'{str(model_out.uuid)}/completion/{_f_uuid}/{_f_name}'
self.s3_client.upload_fileobj(
json_file.file,
validated_json_file.file,
self.bucket_name,
object_name,
ExtraArgs={
Expand Down Expand Up @@ -560,11 +561,15 @@ def validate_file(
csv_file.file.seek(0)

@staticmethod
def validate_json_file(json_file: UploadFile) -> None:
def validate_json_file(json_file: UploadFile):
try:
content = json_file.file.read().decode('utf-8')
json_data = json.loads(content)
CompletionResponses.model_validate(json_data)
validated_data = CompletionResponses.model_validate(json_data)
return UploadFile(
filename=json_file.filename,
file=BytesIO(validated_data.model_dump_json().encode()),
)
except ValidationError as e:
logger.error('Invalid json file: %s', str(e))
raise InvalidFileException(f'Invalid json file: {str(e)}') from e
Expand Down
104 changes: 74 additions & 30 deletions api/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

from fastapi_pagination import Page, Params

from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.model_table import Model
from app.db.tables.reference_dataset_table import ReferenceDataset
from app.models.alert_dto import AlertDTO, AnomalyType
from app.models.exceptions import ModelError, ModelInternalError, ModelNotFoundError
from app.models.metrics.tot_percentages_dto import TotPercentagesDTO
from app.models.model_dto import ModelFeatures, ModelIn, ModelOut
from app.models.model_dto import ModelFeatures, ModelIn, ModelOut, ModelType
from app.models.model_order import OrderType


Expand All @@ -23,10 +25,12 @@ def __init__(
model_dao: ModelDAO,
reference_dataset_dao: ReferenceDatasetDAO,
current_dataset_dao: CurrentDatasetDAO,
completion_dataset_dao: CompletionDatasetDAO,
):
self.model_dao = model_dao
self.rd_dao = reference_dataset_dao
self.cd_dao = current_dataset_dao
self.reference_dataset_dao = reference_dataset_dao
self.current_dataset_dao = current_dataset_dao
self.completion_dataset_dao = completion_dataset_dao

def create_model(self, model_in: ModelIn) -> ModelOut:
try:
Expand All @@ -40,22 +44,25 @@ def create_model(self, model_in: ModelIn) -> ModelOut:

def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
latest_reference_dataset, latest_current_dataset = self.get_latest_datasets(
model_uuid
latest_reference_dataset, latest_current_dataset, latest_completion_dataset = (
self._get_latest_datasets(model_uuid, model.model_type)
)
return ModelOut.from_model(
model=model,
latest_reference_dataset=latest_reference_dataset,
latest_current_dataset=latest_current_dataset,
latest_completion_dataset=latest_completion_dataset,
)

def update_model_features_by_uuid(
self, model_uuid: UUID, model_features: ModelFeatures
) -> bool:
last_reference = self.rd_dao.get_latest_reference_dataset_by_model_uuid(
model_uuid
latest_reference_dataset = (
self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid(
model_uuid
)
)
if last_reference is not None:
if latest_reference_dataset is not None:
raise ModelError(
'Model already has a reference dataset, could not be updated', 400
) from None
Expand All @@ -77,13 +84,16 @@ def get_all_models(
models = self.model_dao.get_all()
model_out_list = []
for model in models:
latest_reference_dataset, latest_current_dataset = self.get_latest_datasets(
model.uuid
)
(
latest_reference_dataset,
latest_current_dataset,
latest_completion_dataset,
) = self._get_latest_datasets(model.uuid, model.model_type)
model_out = ModelOut.from_model(
model=model,
latest_reference_dataset=latest_reference_dataset,
latest_current_dataset=latest_current_dataset,
latest_completion_dataset=latest_completion_dataset,
)
model_out_list.append(model_out)
return model_out_list
Expand Down Expand Up @@ -134,9 +144,11 @@ def get_last_n_alerts(self, n_alerts) -> List[AlertDTO]:
res = []
count_alerts = 0
for model, metrics in models:
latest_reference_dataset, latest_current_dataset = self.get_latest_datasets(
model.uuid
)
(
latest_reference_dataset,
latest_current_dataset,
latest_completion_dataset,
) = self._get_latest_datasets(model.uuid, model.model_type)
if metrics and metrics.percentages:
for perc in ['data_quality', 'model_quality', 'drift']:
if count_alerts == n_alerts:
Expand All @@ -152,6 +164,9 @@ def get_last_n_alerts(self, n_alerts) -> List[AlertDTO]:
current_uuid=latest_current_dataset.uuid
if latest_current_dataset
else None,
completion_uuid=latest_completion_dataset.uuid
if latest_completion_dataset
else None,
anomaly_type=AnomalyType[perc.upper()],
anomaly_features=[
x['feature_name']
Expand All @@ -172,13 +187,16 @@ def get_last_n_models_percentages(self, n_models) -> List[ModelOut]:
models = self.model_dao.get_last_n_percentages(n_models)
model_out_list_tmp = []
for model, metrics in models:
latest_reference_dataset, latest_current_dataset = self.get_latest_datasets(
model.uuid
)
(
latest_reference_dataset,
latest_current_dataset,
latest_completion_dataset,
) = self._get_latest_datasets(model.uuid, model.model_type)
model_out = ModelOut.from_model(
model=model,
latest_reference_dataset=latest_reference_dataset,
latest_current_dataset=latest_current_dataset,
latest_completion_dataset=latest_completion_dataset,
percentages=metrics.percentages if metrics else None,
)
model_out_list_tmp.append(model_out)
Expand All @@ -195,13 +213,16 @@ def get_all_models_paginated(
)
_items = []
for model, metrics in models.items:
latest_reference_dataset, latest_current_dataset = self.get_latest_datasets(
model.uuid
)
(
latest_reference_dataset,
latest_current_dataset,
latest_completion_dataset,
) = self._get_latest_datasets(model.uuid, model.model_type)
model_out = ModelOut.from_model(
model=model,
latest_reference_dataset=latest_reference_dataset,
latest_current_dataset=latest_current_dataset,
latest_completion_dataset=latest_completion_dataset,
percentages=metrics.percentages if metrics else None,
)
_items.append(model_out)
Expand All @@ -214,14 +235,37 @@ def check_and_get_model(self, model_uuid: UUID) -> Model:
raise ModelNotFoundError(f'Model {model_uuid} not found')
return model

def get_latest_datasets(
self, model_uuid: UUID
) -> (Optional[ReferenceDataset], Optional[CurrentDataset]):
latest_reference_dataset = (
self.rd_dao.get_latest_reference_dataset_by_model_uuid(model_uuid)
)
latest_current_dataset = self.cd_dao.get_latest_current_dataset_by_model_uuid(
model_uuid
)
def _get_latest_datasets(
self, model_uuid: UUID, model_type: ModelType
) -> (
Optional[ReferenceDataset],
Optional[CurrentDataset],
Optional[CompletionDataset],
):
latest_reference_dataset = None
latest_current_dataset = None
latest_completion_dataset = None

if model_type == ModelType.TEXT_GENERATION:
latest_completion_dataset = (
self.completion_dataset_dao.get_latest_completion_dataset_by_model_uuid(
model_uuid
)
)
else:
latest_reference_dataset = (
self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid(
model_uuid
)
)
latest_current_dataset = (
self.current_dataset_dao.get_latest_current_dataset_by_model_uuid(
model_uuid
)
)

return latest_reference_dataset, latest_current_dataset
return (
latest_reference_dataset,
latest_current_dataset,
latest_completion_dataset,
)
2 changes: 2 additions & 0 deletions api/tests/routes/model_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def test_get_last_n_alerts(self):
model_uuid=model1.uuid,
reference_uuid=None,
current_uuid=current1.uuid,
completion_uuid=None,
anomaly_type=AnomalyType.DRIFT,
anomaly_features=['num1', 'num2'],
),
Expand All @@ -183,6 +184,7 @@ def test_get_last_n_alerts(self):
model_uuid=model0.uuid,
reference_uuid=None,
current_uuid=current0.uuid,
completion_uuid=None,
anomaly_type=AnomalyType.DRIFT,
anomaly_features=['num1', 'num2'],
),
Expand Down
2 changes: 1 addition & 1 deletion api/tests/services/file_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_infer_schema_separator(self):

def test_validate_completion_json_file_ok(self):
json_file = json.get_completion_sample_json_file()
self.files_service.validate_json_file(json_file)
assert self.files_service.validate_json_file(json_file) is not None

def test_validate_completion_json_file_without_logprobs_field(self):
json_file = json.get_completion_sample_json_file_without_logprobs_field()
Expand Down
Loading

0 comments on commit c4d34ea

Please sign in to comment.