From c4d34ea37742b4e0fb5e2a4dc7dc27e15803596c Mon Sep 17 00:00:00 2001 From: Daniele Tria <36860433+dtria91@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:15:58 +0100 Subject: [PATCH] feat: handle latest completion dataset (#212) * handle latest completion dataset * add completion dataset dao to model service * fix: edi alert test * fix: handle dataset table name argument * fix: edit validation json file * fix: ruff check --- api/app/main.py | 1 + api/app/models/alert_dto.py | 1 + api/app/models/job_status.py | 1 + api/app/models/model_dto.py | 52 +++++++----- api/app/services/file_service.py | 15 ++-- api/app/services/model_service.py | 104 ++++++++++++++++------- api/tests/routes/model_route_test.py | 2 + api/tests/services/file_service_test.py | 2 +- api/tests/services/model_service_test.py | 94 ++++++++++++++------ spark/jobs/completion_job.py | 6 +- spark/jobs/current_job.py | 1 + spark/jobs/reference_job.py | 1 + spark/tests/completion_metrics_test.py | 8 +- 13 files changed, 198 insertions(+), 90 deletions(-) diff --git a/api/app/main.py b/api/app/main.py index 0d86e44d..a221e746 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -61,6 +61,7 @@ model_dao=model_dao, reference_dataset_dao=reference_dataset_dao, current_dataset_dao=current_dataset_dao, + completion_dataset_dao=completion_dataset_dao, ) s3_config = get_config().s3_config diff --git a/api/app/models/alert_dto.py b/api/app/models/alert_dto.py index 237f6d3c..dd38d8c5 100644 --- a/api/app/models/alert_dto.py +++ b/api/app/models/alert_dto.py @@ -17,6 +17,7 @@ class AlertDTO(BaseModel): model_uuid: UUID reference_uuid: Optional[UUID] current_uuid: Optional[UUID] + completion_uuid: Optional[UUID] anomaly_type: AnomalyType anomaly_features: List[str] diff --git a/api/app/models/job_status.py b/api/app/models/job_status.py index ec7b0fdd..3ca364fe 100644 --- a/api/app/models/job_status.py +++ b/api/app/models/job_status.py @@ -7,3 +7,4 @@ class JobStatus(str, Enum): ERROR = 'ERROR' MISSING_REFERENCE = 'MISSING_REFERENCE' MISSING_CURRENT = 'MISSING_CURRENT' + MISSING_COMPLETION = 'MISSING_COMPLETION' diff --git a/api/app/models/model_dto.py b/api/app/models/model_dto.py index c56b7a5d..ee976993 100644 --- a/api/app/models/model_dto.py +++ b/api/app/models/model_dto.py @@ -9,6 +9,7 @@ from app.db.dao.current_dataset_dao import CurrentDataset from app.db.dao.model_dao import Model from app.db.dao.reference_dataset_dao import ReferenceDataset +from app.db.tables.completion_dataset_table import CompletionDataset from app.models.inferred_schema_dto import FieldType, SupportedTypes from app.models.job_status import JobStatus from app.models.metrics.percentages_dto import Percentages @@ -244,8 +245,10 @@ class ModelOut(BaseModel): updated_at: str latest_reference_uuid: Optional[UUID] latest_current_uuid: Optional[UUID] - latest_reference_job_status: JobStatus - latest_current_job_status: JobStatus + latest_completion_uuid: Optional[UUID] + latest_reference_job_status: Optional[JobStatus] + latest_current_job_status: Optional[JobStatus] + latest_completion_job_status: Optional[JobStatus] percentages: Optional[Percentages] model_config = ConfigDict( @@ -257,25 +260,34 @@ def from_model( model: Model, latest_reference_dataset: Optional[ReferenceDataset] = None, latest_current_dataset: Optional[CurrentDataset] = None, + latest_completion_dataset: Optional[CompletionDataset] = None, percentages: Optional[Percentages] = None, ): - latest_reference_uuid = ( - latest_reference_dataset.uuid if latest_reference_dataset else None - ) - latest_current_uuid = ( - latest_current_dataset.uuid if latest_current_dataset else None - ) - - latest_reference_job_status = ( - latest_reference_dataset.status - if latest_reference_dataset - else JobStatus.MISSING_REFERENCE - ) - latest_current_job_status = ( - latest_current_dataset.status - if latest_current_dataset - else JobStatus.MISSING_CURRENT - ) + latest_reference_uuid = None + latest_current_uuid = None + latest_completion_uuid = None + latest_reference_job_status = None + latest_current_job_status = None + latest_completion_job_status = None + + if model.model_type == ModelType.TEXT_GENERATION: + if latest_completion_dataset: + latest_completion_uuid = latest_completion_dataset.uuid + latest_completion_job_status = latest_completion_dataset.status + else: + latest_completion_job_status = JobStatus.MISSING_COMPLETION + else: + if latest_reference_dataset: + latest_reference_uuid = latest_reference_dataset.uuid + latest_reference_job_status = latest_reference_dataset.status + else: + latest_reference_job_status = JobStatus.MISSING_REFERENCE + + if latest_current_dataset: + latest_current_uuid = latest_current_dataset.uuid + latest_current_job_status = latest_current_dataset.status + else: + latest_current_job_status = JobStatus.MISSING_CURRENT return ModelOut( uuid=model.uuid, @@ -294,7 +306,9 @@ def from_model( updated_at=str(model.updated_at), latest_reference_uuid=latest_reference_uuid, latest_current_uuid=latest_current_uuid, + latest_completion_uuid=latest_completion_uuid, latest_reference_job_status=latest_reference_job_status, latest_current_job_status=latest_current_job_status, + latest_completion_job_status=latest_completion_job_status, percentages=percentages, ) diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py index e08756e4..601db2a7 100644 --- a/api/app/services/file_service.py +++ b/api/app/services/file_service.py @@ -1,5 +1,6 @@ from copy import deepcopy import datetime +from io import BytesIO import json import logging import pathlib @@ -344,13 +345,13 @@ def upload_completion_file( logger.error('Model %s not found', model_uuid) raise ModelNotFoundError(f'Model {model_uuid} not found') - self.validate_json_file(json_file) - _f_name = json_file.filename + validated_json_file = self.validate_json_file(json_file) + _f_name = validated_json_file.filename _f_uuid = uuid4() try: object_name = f'{str(model_out.uuid)}/completion/{_f_uuid}/{_f_name}' self.s3_client.upload_fileobj( - json_file.file, + validated_json_file.file, self.bucket_name, object_name, ExtraArgs={ @@ -560,11 +561,15 @@ def validate_file( csv_file.file.seek(0) @staticmethod - def validate_json_file(json_file: UploadFile) -> None: + def validate_json_file(json_file: UploadFile): try: content = json_file.file.read().decode('utf-8') json_data = json.loads(content) - CompletionResponses.model_validate(json_data) + validated_data = CompletionResponses.model_validate(json_data) + return UploadFile( + filename=json_file.filename, + file=BytesIO(validated_data.model_dump_json().encode()), + ) except ValidationError as e: logger.error('Invalid json file: %s', str(e)) raise InvalidFileException(f'Invalid json file: {str(e)}') from e diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index ed4531f9..5d135081 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -3,9 +3,11 @@ from fastapi_pagination import Page, Params +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.model_dao import ModelDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO +from app.db.tables.completion_dataset_table import CompletionDataset from app.db.tables.current_dataset_metrics_table import CurrentDatasetMetrics from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.model_table import Model @@ -13,7 +15,7 @@ from app.models.alert_dto import AlertDTO, AnomalyType from app.models.exceptions import ModelError, ModelInternalError, ModelNotFoundError from app.models.metrics.tot_percentages_dto import TotPercentagesDTO -from app.models.model_dto import ModelFeatures, ModelIn, ModelOut +from app.models.model_dto import ModelFeatures, ModelIn, ModelOut, ModelType from app.models.model_order import OrderType @@ -23,10 +25,12 @@ def __init__( model_dao: ModelDAO, reference_dataset_dao: ReferenceDatasetDAO, current_dataset_dao: CurrentDatasetDAO, + completion_dataset_dao: CompletionDatasetDAO, ): self.model_dao = model_dao - self.rd_dao = reference_dataset_dao - self.cd_dao = current_dataset_dao + self.reference_dataset_dao = reference_dataset_dao + self.current_dataset_dao = current_dataset_dao + self.completion_dataset_dao = completion_dataset_dao def create_model(self, model_in: ModelIn) -> ModelOut: try: @@ -40,22 +44,25 @@ def create_model(self, model_in: ModelIn) -> ModelOut: def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]: model = self.check_and_get_model(model_uuid) - latest_reference_dataset, latest_current_dataset = self.get_latest_datasets( - model_uuid + latest_reference_dataset, latest_current_dataset, latest_completion_dataset = ( + self._get_latest_datasets(model_uuid, model.model_type) ) return ModelOut.from_model( model=model, latest_reference_dataset=latest_reference_dataset, latest_current_dataset=latest_current_dataset, + latest_completion_dataset=latest_completion_dataset, ) def update_model_features_by_uuid( self, model_uuid: UUID, model_features: ModelFeatures ) -> bool: - last_reference = self.rd_dao.get_latest_reference_dataset_by_model_uuid( - model_uuid + latest_reference_dataset = ( + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid( + model_uuid + ) ) - if last_reference is not None: + if latest_reference_dataset is not None: raise ModelError( 'Model already has a reference dataset, could not be updated', 400 ) from None @@ -77,13 +84,16 @@ def get_all_models( models = self.model_dao.get_all() model_out_list = [] for model in models: - latest_reference_dataset, latest_current_dataset = self.get_latest_datasets( - model.uuid - ) + ( + latest_reference_dataset, + latest_current_dataset, + latest_completion_dataset, + ) = self._get_latest_datasets(model.uuid, model.model_type) model_out = ModelOut.from_model( model=model, latest_reference_dataset=latest_reference_dataset, latest_current_dataset=latest_current_dataset, + latest_completion_dataset=latest_completion_dataset, ) model_out_list.append(model_out) return model_out_list @@ -134,9 +144,11 @@ def get_last_n_alerts(self, n_alerts) -> List[AlertDTO]: res = [] count_alerts = 0 for model, metrics in models: - latest_reference_dataset, latest_current_dataset = self.get_latest_datasets( - model.uuid - ) + ( + latest_reference_dataset, + latest_current_dataset, + latest_completion_dataset, + ) = self._get_latest_datasets(model.uuid, model.model_type) if metrics and metrics.percentages: for perc in ['data_quality', 'model_quality', 'drift']: if count_alerts == n_alerts: @@ -152,6 +164,9 @@ def get_last_n_alerts(self, n_alerts) -> List[AlertDTO]: current_uuid=latest_current_dataset.uuid if latest_current_dataset else None, + completion_uuid=latest_completion_dataset.uuid + if latest_completion_dataset + else None, anomaly_type=AnomalyType[perc.upper()], anomaly_features=[ x['feature_name'] @@ -172,13 +187,16 @@ def get_last_n_models_percentages(self, n_models) -> List[ModelOut]: models = self.model_dao.get_last_n_percentages(n_models) model_out_list_tmp = [] for model, metrics in models: - latest_reference_dataset, latest_current_dataset = self.get_latest_datasets( - model.uuid - ) + ( + latest_reference_dataset, + latest_current_dataset, + latest_completion_dataset, + ) = self._get_latest_datasets(model.uuid, model.model_type) model_out = ModelOut.from_model( model=model, latest_reference_dataset=latest_reference_dataset, latest_current_dataset=latest_current_dataset, + latest_completion_dataset=latest_completion_dataset, percentages=metrics.percentages if metrics else None, ) model_out_list_tmp.append(model_out) @@ -195,13 +213,16 @@ def get_all_models_paginated( ) _items = [] for model, metrics in models.items: - latest_reference_dataset, latest_current_dataset = self.get_latest_datasets( - model.uuid - ) + ( + latest_reference_dataset, + latest_current_dataset, + latest_completion_dataset, + ) = self._get_latest_datasets(model.uuid, model.model_type) model_out = ModelOut.from_model( model=model, latest_reference_dataset=latest_reference_dataset, latest_current_dataset=latest_current_dataset, + latest_completion_dataset=latest_completion_dataset, percentages=metrics.percentages if metrics else None, ) _items.append(model_out) @@ -214,14 +235,37 @@ def check_and_get_model(self, model_uuid: UUID) -> Model: raise ModelNotFoundError(f'Model {model_uuid} not found') return model - def get_latest_datasets( - self, model_uuid: UUID - ) -> (Optional[ReferenceDataset], Optional[CurrentDataset]): - latest_reference_dataset = ( - self.rd_dao.get_latest_reference_dataset_by_model_uuid(model_uuid) - ) - latest_current_dataset = self.cd_dao.get_latest_current_dataset_by_model_uuid( - model_uuid - ) + def _get_latest_datasets( + self, model_uuid: UUID, model_type: ModelType + ) -> ( + Optional[ReferenceDataset], + Optional[CurrentDataset], + Optional[CompletionDataset], + ): + latest_reference_dataset = None + latest_current_dataset = None + latest_completion_dataset = None + + if model_type == ModelType.TEXT_GENERATION: + latest_completion_dataset = ( + self.completion_dataset_dao.get_latest_completion_dataset_by_model_uuid( + model_uuid + ) + ) + else: + latest_reference_dataset = ( + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid( + model_uuid + ) + ) + latest_current_dataset = ( + self.current_dataset_dao.get_latest_current_dataset_by_model_uuid( + model_uuid + ) + ) - return latest_reference_dataset, latest_current_dataset + return ( + latest_reference_dataset, + latest_current_dataset, + latest_completion_dataset, + ) diff --git a/api/tests/routes/model_route_test.py b/api/tests/routes/model_route_test.py index ae8df3ed..b7288bb0 100644 --- a/api/tests/routes/model_route_test.py +++ b/api/tests/routes/model_route_test.py @@ -175,6 +175,7 @@ def test_get_last_n_alerts(self): model_uuid=model1.uuid, reference_uuid=None, current_uuid=current1.uuid, + completion_uuid=None, anomaly_type=AnomalyType.DRIFT, anomaly_features=['num1', 'num2'], ), @@ -183,6 +184,7 @@ def test_get_last_n_alerts(self): model_uuid=model0.uuid, reference_uuid=None, current_uuid=current0.uuid, + completion_uuid=None, anomaly_type=AnomalyType.DRIFT, anomaly_features=['num1', 'num2'], ), diff --git a/api/tests/services/file_service_test.py b/api/tests/services/file_service_test.py index e326041a..2cdddcd6 100644 --- a/api/tests/services/file_service_test.py +++ b/api/tests/services/file_service_test.py @@ -83,7 +83,7 @@ def test_infer_schema_separator(self): def test_validate_completion_json_file_ok(self): json_file = json.get_completion_sample_json_file() - self.files_service.validate_json_file(json_file) + assert self.files_service.validate_json_file(json_file) is not None def test_validate_completion_json_file_without_logprobs_field(self): json_file = json.get_completion_sample_json_file_without_logprobs_field() diff --git a/api/tests/services/model_service_test.py b/api/tests/services/model_service_test.py index 40a4e674..4186e476 100644 --- a/api/tests/services/model_service_test.py +++ b/api/tests/services/model_service_test.py @@ -5,6 +5,7 @@ from fastapi_pagination import Page, Params import pytest +from app.db.dao.completion_dataset_dao import CompletionDatasetDAO from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO from app.db.dao.model_dao import ModelDAO @@ -21,17 +22,30 @@ class ModelServiceTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.model_dao: ModelDAO = MagicMock(spec_set=ModelDAO) - cls.rd_dao: ReferenceDatasetDAO = MagicMock(spec_set=ReferenceDatasetDAO) - cls.cd_dao: CurrentDatasetDAO = MagicMock(spec_set=CurrentDatasetDAO) + cls.reference_dataset_dao: ReferenceDatasetDAO = MagicMock( + spec_set=ReferenceDatasetDAO + ) + cls.current_dataset_dao: CurrentDatasetDAO = MagicMock( + spec_set=CurrentDatasetDAO + ) + cls.completion_dataset_dao: CompletionDatasetDAO = MagicMock( + spec_set=CompletionDatasetDAO + ) cls.model_service = ModelService( model_dao=cls.model_dao, - reference_dataset_dao=cls.rd_dao, - current_dataset_dao=cls.cd_dao, + reference_dataset_dao=cls.reference_dataset_dao, + current_dataset_dao=cls.current_dataset_dao, + completion_dataset_dao=cls.completion_dataset_dao, ) cls.current_metrics_dao: CurrentDatasetMetricsDAO = MagicMock( spec_set=CurrentDatasetMetricsDAO ) - cls.mocks = [cls.model_dao, cls.rd_dao, cls.cd_dao] + cls.mocks = [ + cls.model_dao, + cls.reference_dataset_dao, + cls.current_dataset_dao, + cls.completion_dataset_dao, + ] def test_create_model_ok(self): model = db_mock.get_sample_model() @@ -68,16 +82,16 @@ def test_get_model_by_uuid_ok(self): reference_dataset = db_mock.get_sample_reference_dataset(model_uuid=model.uuid) current_dataset = db_mock.get_sample_current_dataset(model_uuid=model.uuid) self.model_dao.get_by_uuid = MagicMock(return_value=model) - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=reference_dataset + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=reference_dataset) ) - self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock( + self.current_dataset_dao.get_latest_current_dataset_by_model_uuid = MagicMock( return_value=current_dataset ) res = self.model_service.get_model_by_uuid(model_uuid) self.model_dao.get_by_uuid.assert_called_once() - self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once() - self.cd_dao.get_latest_current_dataset_by_model_uuid.assert_called_once() + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once() + self.current_dataset_dao.get_latest_current_dataset_by_model_uuid.assert_called_once() assert res == ModelOut.from_model( model=model, @@ -85,6 +99,30 @@ def test_get_model_by_uuid_ok(self): latest_current_dataset=current_dataset, ) + def test_get_model_by_uuid_without_schema_ok(self): + model = db_mock.get_sample_model( + model_type=ModelType.TEXT_GENERATION, + features=None, + target=None, + outputs=None, + timestamp=None, + ) + completion_dataset = db_mock.get_sample_completion_dataset( + model_uuid=model.uuid + ) + self.model_dao.get_by_uuid = MagicMock(return_value=model) + self.completion_dataset_dao.get_latest_completion_dataset_by_model_uuid = ( + MagicMock(return_value=completion_dataset) + ) + res = self.model_service.get_model_by_uuid(model_uuid) + self.model_dao.get_by_uuid.assert_called_once() + self.completion_dataset_dao.get_latest_completion_dataset_by_model_uuid.assert_called_once() + + assert res == ModelOut.from_model( + model=model, + latest_completion_dataset=completion_dataset, + ) + def test_get_model_by_uuid_not_found(self): self.model_dao.get_by_uuid = MagicMock(return_value=None) pytest.raises( @@ -94,15 +132,15 @@ def test_get_model_by_uuid_not_found(self): def test_update_model_ok(self): model_features = db_mock.get_sample_model_features() - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=None + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=None) ) self.model_dao.update_features = MagicMock(return_value=1) res = self.model_service.update_model_features_by_uuid( model_uuid, model_features ) feature_dict = [feature.to_dict() for feature in model_features.features] - self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( model_uuid ) self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict) @@ -111,15 +149,15 @@ def test_update_model_ok(self): def test_update_model_ko(self): model_features = db_mock.get_sample_model_features() - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=None + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=None) ) self.model_dao.update_features = MagicMock(return_value=0) res = self.model_service.update_model_features_by_uuid( model_uuid, model_features ) feature_dict = [feature.to_dict() for feature in model_features.features] - self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( model_uuid ) self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict) @@ -128,14 +166,14 @@ def test_update_model_ko(self): def test_update_model_freezed(self): model_features = db_mock.get_sample_model_features() - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=db_mock.get_sample_reference_dataset() + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=db_mock.get_sample_reference_dataset()) ) self.model_dao.update_features = MagicMock(return_value=0) with pytest.raises(ModelError): self.model_service.update_model_features_by_uuid(model_uuid, model_features) feature_dict = [feature.to_dict() for feature in model_features.features] - self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( model_uuid ) self.model_dao.update_features.assert_called_once_with( @@ -172,10 +210,10 @@ def test_get_all_models_paginated_ok(self): sort=None, ) self.model_dao.get_all_paginated = MagicMock(return_value=page) - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=None + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=None) ) - self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock( + self.current_dataset_dao.get_latest_current_dataset_by_model_uuid = MagicMock( return_value=None ) @@ -202,10 +240,10 @@ def test_get_all_models_ok(self): model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='model3') sample_models = [model1, model2, model3] self.model_dao.get_all = MagicMock(return_value=sample_models) - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=None + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=None) ) - self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock( + self.current_dataset_dao.get_latest_current_dataset_by_model_uuid = MagicMock( return_value=None ) @@ -233,10 +271,10 @@ def test_get_last_n_models_percentages(self): sample = [(model1, metrics1), (model0, metrics0)] self.model_dao.get_last_n_percentages = MagicMock(return_value=sample) - self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( - return_value=None + self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid = ( + MagicMock(return_value=None) ) - self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock( + self.current_dataset_dao.get_latest_current_dataset_by_model_uuid = MagicMock( return_value=None ) diff --git a/spark/jobs/completion_job.py b/spark/jobs/completion_job.py index 93ab9a91..36850e94 100644 --- a/spark/jobs/completion_job.py +++ b/spark/jobs/completion_job.py @@ -17,9 +17,9 @@ def compute_metrics(df: DataFrame) -> dict: complete_record = {} completion_service = CompletionMetrics() model_quality = completion_service.extract_metrics(df) - complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump(serialize_as_any=True)).decode( - "utf-8" - ) + complete_record["MODEL_QUALITY"] = orjson.dumps( + model_quality.model_dump(serialize_as_any=True) + ).decode("utf-8") return complete_record diff --git a/spark/jobs/current_job.py b/spark/jobs/current_job.py index ff0140da..4e7eb1ba 100644 --- a/spark/jobs/current_job.py +++ b/spark/jobs/current_job.py @@ -202,6 +202,7 @@ def main( current_uuid, reference_dataset_path, metrics_table_name, + dataset_table_name, ) except Exception as e: logging.exception(e) diff --git a/spark/jobs/reference_job.py b/spark/jobs/reference_job.py index 4750ab3b..8e7017eb 100644 --- a/spark/jobs/reference_job.py +++ b/spark/jobs/reference_job.py @@ -146,6 +146,7 @@ def main( reference_dataset_path, reference_uuid, metrics_table_name, + dataset_table_name, ) except Exception as e: logging.exception(e) diff --git a/spark/tests/completion_metrics_test.py b/spark/tests/completion_metrics_test.py index d4c81c43..19dcb446 100644 --- a/spark/tests/completion_metrics_test.py +++ b/spark/tests/completion_metrics_test.py @@ -31,7 +31,9 @@ def test_compute_prob(spark_fixture, input_file): def test_extract_metrics(spark_fixture, input_file): completion_metrics_service = CompletionMetrics() - completion_metrics_model: CompletionMetricsModel = completion_metrics_service.extract_metrics(input_file) + completion_metrics_model: CompletionMetricsModel = ( + completion_metrics_service.extract_metrics(input_file) + ) assert len(completion_metrics_model.tokens) > 0 assert len(completion_metrics_model.mean_per_phrase) > 0 assert len(completion_metrics_model.mean_per_file) > 0 @@ -40,6 +42,4 @@ def test_extract_metrics(spark_fixture, input_file): def test_compute_metrics(spark_fixture, input_file): complete_record = compute_metrics(input_file) model_quality = complete_record.get("MODEL_QUALITY") - assert model_quality == orjson.dumps(completion_metric_results).decode( - "utf-8" - ) + assert model_quality == orjson.dumps(completion_metric_results).decode("utf-8")