diff --git a/api/app/db/dao/model_dao.py b/api/app/db/dao/model_dao.py index a83b362e..7eaa56bc 100644 --- a/api/app/db/dao/model_dao.py +++ b/api/app/db/dao/model_dao.py @@ -1,6 +1,6 @@ import datetime import re -from typing import List, Optional +from typing import Dict, List, Optional from uuid import UUID from fastapi_pagination import Page, Params @@ -32,12 +32,12 @@ def get_by_uuid(self, uuid: UUID) -> Optional[Model]: .one_or_none() ) - def update(self, uuid: UUID, model: Model): + def update_features(self, uuid: UUID, model_features: List[Dict]): with self.db.begin_session() as session: query = ( sqlalchemy.update(Model) .where(Model.uuid == uuid) - .values(**model.attributes()) + .values(features=model_features) ) return session.execute(query).rowcount diff --git a/api/app/models/model_dto.py b/api/app/models/model_dto.py index 648f848b..6358026e 100644 --- a/api/app/models/model_dto.py +++ b/api/app/models/model_dto.py @@ -77,6 +77,14 @@ def to_dict(self): return self.model_dump() +class ModelFeatures(BaseModel): + features: List[ColumnDefinition] + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + class ModelIn(BaseModel, validate_assignment=True): name: str description: Optional[str] = None diff --git a/api/app/routes/model_route.py b/api/app/routes/model_route.py index 56419dd6..26d0c429 100644 --- a/api/app/routes/model_route.py +++ b/api/app/routes/model_route.py @@ -7,7 +7,7 @@ from fastapi_pagination import Page, Params from app.core import get_config -from app.models.model_dto import ModelIn, ModelOut +from app.models.model_dto import ModelFeatures, ModelIn, ModelOut from app.models.model_order import OrderType from app.services.model_service import ModelService @@ -52,8 +52,10 @@ def delete_model(model_uuid: UUID): return model @router.post('/{model_uuid}', status_code=200) - def update_model_by_uuid(model_uuid: UUID, model_in: ModelIn): - if model_service.update_model_by_uuid(model_uuid, model_in): + def update_model_features_by_uuid( + model_uuid: UUID, model_features: ModelFeatures + ): + if model_service.update_model_features_by_uuid(model_uuid, model_features): return Response(status_code=200) return Response(status_code=404) diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 5af163f8..c79b6c91 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -9,8 +9,8 @@ from app.db.tables.current_dataset_table import CurrentDataset from app.db.tables.model_table import Model from app.db.tables.reference_dataset_table import ReferenceDataset -from app.models.exceptions import ModelInternalError, ModelNotFoundError -from app.models.model_dto import ModelIn, ModelOut +from app.models.exceptions import ModelError, ModelInternalError, ModelNotFoundError +from app.models.model_dto import ModelFeatures, ModelIn, ModelOut from app.models.model_order import OrderType @@ -46,14 +46,22 @@ def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]: latest_current_dataset=latest_current_dataset, ) - def update_model_by_uuid(self, model_uuid: UUID, model_in: ModelIn) -> bool: - try: - to_update = model_in.to_model() - return self.model_dao.update(model_uuid, to_update) > 0 - except Exception as e: - raise ModelInternalError( - f'An error occurred while updating the model: {e}' - ) from e + def update_model_features_by_uuid( + self, model_uuid: UUID, model_features: ModelFeatures + ) -> bool: + last_reference = self.rd_dao.get_latest_reference_dataset_by_model_uuid( + model_uuid + ) + if last_reference is not None: + raise ModelError( + 'Model already has a reference dataset, could not be updated', 400 + ) from None + return ( + self.model_dao.update_features( + model_uuid, [feature.to_dict() for feature in model_features.features] + ) + > 0 + ) def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]: model = self.check_and_get_model(model_uuid) diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index e585db82..39c57da3 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -13,6 +13,7 @@ DataType, FieldType, Granularity, + ModelFeatures, ModelIn, ModelType, OutputType, @@ -72,6 +73,18 @@ def get_sample_model( ) +def get_sample_model_features( + features: List[ColumnDefinition] = [ + ColumnDefinition( + name='feature1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], +): + return ModelFeatures(features=features) + + def get_sample_model_in( name: str = 'model_name', description: Optional[str] = None, diff --git a/api/tests/dao/model_dao_test.py b/api/tests/dao/model_dao_test.py index da776dbe..ec98ac96 100644 --- a/api/tests/dao/model_dao_test.py +++ b/api/tests/dao/model_dao_test.py @@ -1,3 +1,4 @@ +from typing import Dict, List import uuid from app.db.dao.model_dao import ModelDAO @@ -30,12 +31,14 @@ def test_get_by_uuid_empty(self): def test_update(self): model = db_mock.get_sample_model() self.model_dao.insert(model) - model_to_update = model - model_to_update.name = 'updated_name' - rows = self.model_dao.update(model.uuid, model_to_update) + new_features = List[Dict] = [ + {'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'}, + {'name': 'feature2', 'type': 'int', 'fieldType': 'numerical'}, + ] + updated_rows = self.model_dao.update_features(model.uuid, new_features) retrieved = self.model_dao.get_by_uuid(model.uuid) - assert rows == 1 - assert retrieved.name == 'updated_name' + assert updated_rows == 1 + assert retrieved.features == new_features def test_delete(self): model = db_mock.get_sample_model() diff --git a/api/tests/routes/model_route_test.py b/api/tests/routes/model_route_test.py index 726a57fb..94bee31f 100644 --- a/api/tests/routes/model_route_test.py +++ b/api/tests/routes/model_route_test.py @@ -51,31 +51,31 @@ def test_create_model(self): self.model_service.create_model.assert_called_once_with(model_in) def test_update_model_ok(self): - model_in = db_mock.get_sample_model_in() - self.model_service.update_model_by_uuid = MagicMock(return_value=True) + model_features = db_mock.get_sample_model_features() + self.model_service.update_model_features_by_uuid = MagicMock(return_value=True) res = self.client.post( f'{self.prefix}/{db_mock.MODEL_UUID}', - json=jsonable_encoder(model_in), + json=jsonable_encoder(model_features), ) assert res.status_code == 200 - self.model_service.update_model_by_uuid.assert_called_once_with( - db_mock.MODEL_UUID, model_in + self.model_service.update_model_features_by_uuid.assert_called_once_with( + db_mock.MODEL_UUID, model_features ) def test_update_model_ko(self): - model_in = db_mock.get_sample_model_in() - self.model_service.update_model_by_uuid = MagicMock(return_value=False) + model_features = db_mock.get_sample_model_features() + self.model_service.update_model_features_by_uuid = MagicMock(return_value=False) res = self.client.post( f'{self.prefix}/{db_mock.MODEL_UUID}', - json=jsonable_encoder(model_in), + json=jsonable_encoder(model_features), ) assert res.status_code == 404 - self.model_service.update_model_by_uuid.assert_called_once_with( - db_mock.MODEL_UUID, model_in + self.model_service.update_model_features_by_uuid.assert_called_once_with( + db_mock.MODEL_UUID, model_features ) def test_get_model_by_uuid(self): diff --git a/api/tests/services/model_service_test.py b/api/tests/services/model_service_test.py index dfa6274c..f0be09fe 100644 --- a/api/tests/services/model_service_test.py +++ b/api/tests/services/model_service_test.py @@ -8,7 +8,7 @@ from app.db.dao.current_dataset_dao import CurrentDatasetDAO from app.db.dao.model_dao import ModelDAO from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO -from app.models.exceptions import ModelNotFoundError +from app.models.exceptions import ModelError, ModelNotFoundError from app.models.model_dto import ModelOut from app.models.model_order import OrderType from app.services.model_service import ModelService @@ -67,19 +67,55 @@ def test_get_model_by_uuid_not_found(self): self.model_dao.get_by_uuid.assert_called_once() def test_update_model_ok(self): - model_in = db_mock.get_sample_model_in() - self.model_dao.update = MagicMock(return_value=1) - res = self.model_service.update_model_by_uuid(model_uuid, model_in) + model_features = db_mock.get_sample_model_features() + self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( + return_value=None + ) + self.model_dao.update_features = MagicMock(return_value=1) + res = self.model_service.update_model_features_by_uuid( + model_uuid, model_features + ) + feature_dict = [feature.to_dict() for feature in model_features.features] + self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + model_uuid + ) + self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict) assert res is True def test_update_model_ko(self): - model_in = db_mock.get_sample_model_in() - self.model_dao.update = MagicMock(return_value=0) - res = self.model_service.update_model_by_uuid(model_uuid, model_in) + model_features = db_mock.get_sample_model_features() + self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( + return_value=None + ) + self.model_dao.update_features = MagicMock(return_value=0) + res = self.model_service.update_model_features_by_uuid( + model_uuid, model_features + ) + feature_dict = [feature.to_dict() for feature in model_features.features] + self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + model_uuid + ) + self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict) assert res is False + def test_update_model_freezed(self): + model_features = db_mock.get_sample_model_features() + self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock( + return_value=db_mock.get_sample_reference_dataset() + ) + self.model_dao.update_features = MagicMock(return_value=0) + with pytest.raises(ModelError): + self.model_service.update_model_features_by_uuid(model_uuid, model_features) + feature_dict = [feature.to_dict() for feature in model_features.features] + self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with( + model_uuid + ) + self.model_dao.update_features.assert_called_once_with( + model_uuid, feature_dict + ) + def test_delete_model_ok(self): model = db_mock.get_sample_model() self.model_dao.get_by_uuid = MagicMock(return_value=model)