diff --git a/api/app/db/dao/model_dao.py b/api/app/db/dao/model_dao.py index 0ea13f90..7eaa56bc 100644 --- a/api/app/db/dao/model_dao.py +++ b/api/app/db/dao/model_dao.py @@ -1,6 +1,6 @@ import datetime import re -from typing import List, Optional +from typing import Dict, List, Optional from uuid import UUID from fastapi_pagination import Page, Params @@ -32,6 +32,15 @@ def get_by_uuid(self, uuid: UUID) -> Optional[Model]: .one_or_none() ) + def update_features(self, uuid: UUID, model_features: List[Dict]): + with self.db.begin_session() as session: + query = ( + sqlalchemy.update(Model) + .where(Model.uuid == uuid) + .values(features=model_features) + ) + return session.execute(query).rowcount + def delete(self, uuid: UUID) -> int: with self.db.begin_session() as session: deleted_at = datetime.datetime.now(tz=datetime.UTC) diff --git a/api/app/models/model_dto.py b/api/app/models/model_dto.py index 648f848b..6358026e 100644 --- a/api/app/models/model_dto.py +++ b/api/app/models/model_dto.py @@ -77,6 +77,14 @@ def to_dict(self): return self.model_dump() +class ModelFeatures(BaseModel): + features: List[ColumnDefinition] + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + class ModelIn(BaseModel, validate_assignment=True): name: str description: Optional[str] = None diff --git a/api/app/routes/model_route.py b/api/app/routes/model_route.py index d19978c2..26d0c429 100644 --- a/api/app/routes/model_route.py +++ b/api/app/routes/model_route.py @@ -2,12 +2,12 @@ from typing import Annotated, List, Optional from uuid import UUID -from fastapi import APIRouter +from fastapi import APIRouter, Response from fastapi.params import Query from fastapi_pagination import Page, Params from app.core import get_config -from app.models.model_dto import ModelIn, ModelOut +from app.models.model_dto import ModelFeatures, ModelIn, ModelOut from app.models.model_order import OrderType from app.services.model_service import ModelService @@ -51,4 +51,12 @@ def delete_model(model_uuid: UUID): logger.info('Model %s with name %s deleted.', model.uuid, model.name) return model + @router.post('/{model_uuid}', status_code=200) + def update_model_features_by_uuid( + model_uuid: UUID, model_features: ModelFeatures + ): + if model_service.update_model_features_by_uuid(model_uuid, model_features): + return Response(status_code=200) + return Response(status_code=404) + return router diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 9dd77758..c79b6c91 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -9,8 +9,8 @@ from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.model_table import Model from app.db.tables.reference_dataset_table import ReferenceDataset -from app.models.exceptions import ModelInternalError, ModelNotFoundError -from app.models.model_dto import ModelIn, ModelOut +from app.models.exceptions import ModelError, ModelInternalError, ModelNotFoundError +from app.models.model_dto import ModelFeatures, ModelIn, ModelOut from app.models.model_order import OrderType @@ -46,6 +46,23 @@ def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]: latest_current_dataset=latest_current_dataset, ) + def update_model_features_by_uuid( + self, model_uuid: UUID, model_features: ModelFeatures + ) -> bool: + last_reference = self.rd_dao.get_latest_reference_dataset_by_model_uuid( + model_uuid + ) + if last_reference is not None: + raise ModelError( + 'Model already has a reference dataset, could not be updated', 400 + ) from None + return ( + self.model_dao.update_features( + model_uuid, [feature.to_dict() for feature in model_features.features] + ) + > 0 + ) + def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]: model = self.check_and_get_model(model_uuid) self.model_dao.delete(model_uuid) diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index e585db82..39c57da3 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -13,6 +13,7 @@ DataType, FieldType, Granularity, + ModelFeatures, ModelIn, ModelType, OutputType, @@ -72,6 +73,18 @@ def get_sample_model( ) +def get_sample_model_features( + features: List[ColumnDefinition] = [ + ColumnDefinition( + name='feature1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], +): + return ModelFeatures(features=features) + + def get_sample_model_in( name: str = 'model_name', description: Optional[str] = None, diff --git a/api/tests/dao/model_dao_test.py b/api/tests/dao/model_dao_test.py index 5d7f14d2..f388f052 100644 --- a/api/tests/dao/model_dao_test.py +++ b/api/tests/dao/model_dao_test.py @@ -27,6 +27,18 @@ def test_get_by_uuid_empty(self): retrieved = self.model_dao.get_by_uuid(uuid.uuid4()) assert retrieved is None + def test_update(self): + model = db_mock.get_sample_model() + self.model_dao.insert(model) + new_features = [ + {'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'}, + {'name': 'feature2', 'type': 'int', 'fieldType': 'numerical'}, + ] + updated_rows = self.model_dao.update_features(model.uuid, new_features) + retrieved = self.model_dao.get_by_uuid(model.uuid) + assert updated_rows == 1 + assert retrieved.features == new_features + def test_delete(self): model = db_mock.get_sample_model() self.model_dao.insert(model) diff --git a/api/tests/routes/model_route_test.py b/api/tests/routes/model_route_test.py index 6f5917e9..94bee31f 100644 --- a/api/tests/routes/model_route_test.py +++ b/api/tests/routes/model_route_test.py @@ -50,6 +50,34 @@ def test_create_model(self): assert jsonable_encoder(model_out) == res.json() self.model_service.create_model.assert_called_once_with(model_in) + def test_update_model_ok(self): + model_features = db_mock.get_sample_model_features() + self.model_service.update_model_features_by_uuid = MagicMock(return_value=True) + + res = self.client.post( + f'{self.prefix}/{db_mock.MODEL_UUID}', + json=jsonable_encoder(model_features), + ) + + assert res.status_code == 200 + self.model_service.update_model_features_by_uuid.assert_called_once_with( + db_mock.MODEL_UUID, model_features + ) + + def test_update_model_ko(self): + model_features = db_mock.get_sample_model_features() + self.model_service.update_model_features_by_uuid = MagicMock(return_value=False) + + res = self.client.post( + f'{self.prefix}/{db_mock.MODEL_UUID}', + json=jsonable_encoder(model_features), + ) + + assert res.status_code == 404 + self.model_service.update_model_features_by_uuid.assert_called_once_with( + db_mock.MODEL_UUID, model_features + ) + def test_get_model_by_uuid(self): model = db_mock.get_sample_model() model_out = ModelOut.from_model(model) diff --git a/api/tests/services/model_service_test.py b/api/tests/services/model_service_test.py index 02e6a669..f0be09fe 100644 --- a/api/tests/services/model_service_test.py +++ b/api/tests/services/model_service_test.py @@ -8,7 +8,7 @@ from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.model_dao import ModelDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO -from app.models.exceptions import ModelNotFoundError +from app.models.exceptions import ModelError, ModelNotFoundError from app.models.model_dto import ModelOut from app.models.model_order import OrderType from app.services.model_service import ModelService @@ -66,6 +66,56 @@ def test_get_model_by_uuid_not_found(self): ) self.model_dao.get_by_uuid.assert_called_once() + def test_update_model_ok(self): + model_features = db_mock.get_sample_model_features() + self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( + return_value=None + ) + self.model_dao.update_features = MagicMock(return_value=1) + res = self.model_service.update_model_features_by_uuid( + model_uuid, model_features + ) + feature_dict = [feature.to_dict() for feature in model_features.features] + self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + model_uuid + ) + self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict) + + assert res is True + + def test_update_model_ko(self): + model_features = db_mock.get_sample_model_features() + self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( + return_value=None + ) + self.model_dao.update_features = MagicMock(return_value=0) + res = self.model_service.update_model_features_by_uuid( + model_uuid, model_features + ) + feature_dict = [feature.to_dict() for feature in model_features.features] + self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + model_uuid + ) + self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict) + + assert res is False + + def test_update_model_freezed(self): + model_features = db_mock.get_sample_model_features() + self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( + return_value=db_mock.get_sample_reference_dataset() + ) + self.model_dao.update_features = MagicMock(return_value=0) + with pytest.raises(ModelError): + self.model_service.update_model_features_by_uuid(model_uuid, model_features) + feature_dict = [feature.to_dict() for feature in model_features.features] + self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + model_uuid + ) + self.model_dao.update_features.assert_called_once_with( + model_uuid, feature_dict + ) + def test_delete_model_ok(self): model = db_mock.get_sample_model() self.model_dao.get_by_uuid = MagicMock(return_value=model)