Skip to content

Commit

Permalink
update only features
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte committed Jul 26, 2024
1 parent f27c4f8 commit e1a5512
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 38 deletions.
6 changes: 3 additions & 3 deletions api/app/db/dao/model_dao.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import re
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
Expand Down Expand Up @@ -32,12 +32,12 @@ def get_by_uuid(self, uuid: UUID) -> Optional[Model]:
.one_or_none()
)

def update(self, uuid: UUID, model: Model):
def update_features(self, uuid: UUID, model_features: List[Dict]):
with self.db.begin_session() as session:
query = (
sqlalchemy.update(Model)
.where(Model.uuid == uuid)
.values(**model.attributes())
.values(features=model_features)
)
return session.execute(query).rowcount

Expand Down
8 changes: 8 additions & 0 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ def to_dict(self):
return self.model_dump()


class ModelFeatures(BaseModel):
features: List[ColumnDefinition]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class ModelIn(BaseModel, validate_assignment=True):
name: str
description: Optional[str] = None
Expand Down
8 changes: 5 additions & 3 deletions api/app/routes/model_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi_pagination import Page, Params

from app.core import get_config
from app.models.model_dto import ModelIn, ModelOut
from app.models.model_dto import ModelFeatures, ModelIn, ModelOut
from app.models.model_order import OrderType
from app.services.model_service import ModelService

Expand Down Expand Up @@ -52,8 +52,10 @@ def delete_model(model_uuid: UUID):
return model

@router.post('/{model_uuid}', status_code=200)
def update_model_by_uuid(model_uuid: UUID, model_in: ModelIn):
if model_service.update_model_by_uuid(model_uuid, model_in):
def update_model_features_by_uuid(
model_uuid: UUID, model_features: ModelFeatures
):
if model_service.update_model_features_by_uuid(model_uuid, model_features):
return Response(status_code=200)
return Response(status_code=404)

Expand Down
28 changes: 18 additions & 10 deletions api/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.model_table import Model
from app.db.tables.reference_dataset_table import ReferenceDataset
from app.models.exceptions import ModelInternalError, ModelNotFoundError
from app.models.model_dto import ModelIn, ModelOut
from app.models.exceptions import ModelError, ModelInternalError, ModelNotFoundError
from app.models.model_dto import ModelFeatures, ModelIn, ModelOut
from app.models.model_order import OrderType


Expand Down Expand Up @@ -46,14 +46,22 @@ def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]:
latest_current_dataset=latest_current_dataset,
)

def update_model_by_uuid(self, model_uuid: UUID, model_in: ModelIn) -> bool:
try:
to_update = model_in.to_model()
return self.model_dao.update(model_uuid, to_update) > 0
except Exception as e:
raise ModelInternalError(
f'An error occurred while updating the model: {e}'
) from e
def update_model_features_by_uuid(
self, model_uuid: UUID, model_features: ModelFeatures
) -> bool:
last_reference = self.rd_dao.get_latest_reference_dataset_by_model_uuid(
model_uuid
)
if last_reference is not None:
raise ModelError(
'Model already has a reference dataset, could not be updated', 400
) from None
return (
self.model_dao.update_features(
model_uuid, [feature.to_dict() for feature in model_features.features]
)
> 0
)

def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
Expand Down
13 changes: 13 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DataType,
FieldType,
Granularity,
ModelFeatures,
ModelIn,
ModelType,
OutputType,
Expand Down Expand Up @@ -72,6 +73,18 @@ def get_sample_model(
)


def get_sample_model_features(
features: List[ColumnDefinition] = [
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
):
return ModelFeatures(features=features)


def get_sample_model_in(
name: str = 'model_name',
description: Optional[str] = None,
Expand Down
13 changes: 8 additions & 5 deletions api/tests/dao/model_dao_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, List
import uuid

from app.db.dao.model_dao import ModelDAO
Expand Down Expand Up @@ -30,12 +31,14 @@ def test_get_by_uuid_empty(self):
def test_update(self):
model = db_mock.get_sample_model()
self.model_dao.insert(model)
model_to_update = model
model_to_update.name = 'updated_name'
rows = self.model_dao.update(model.uuid, model_to_update)
new_features = List[Dict] = [
{'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'},
{'name': 'feature2', 'type': 'int', 'fieldType': 'numerical'},
]
updated_rows = self.model_dao.update_features(model.uuid, new_features)
retrieved = self.model_dao.get_by_uuid(model.uuid)
assert rows == 1
assert retrieved.name == 'updated_name'
assert updated_rows == 1
assert retrieved.features == new_features

def test_delete(self):
model = db_mock.get_sample_model()
Expand Down
20 changes: 10 additions & 10 deletions api/tests/routes/model_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,31 @@ def test_create_model(self):
self.model_service.create_model.assert_called_once_with(model_in)

def test_update_model_ok(self):
model_in = db_mock.get_sample_model_in()
self.model_service.update_model_by_uuid = MagicMock(return_value=True)
model_features = db_mock.get_sample_model_features()
self.model_service.update_model_features_by_uuid = MagicMock(return_value=True)

res = self.client.post(
f'{self.prefix}/{db_mock.MODEL_UUID}',
json=jsonable_encoder(model_in),
json=jsonable_encoder(model_features),
)

assert res.status_code == 200
self.model_service.update_model_by_uuid.assert_called_once_with(
db_mock.MODEL_UUID, model_in
self.model_service.update_model_features_by_uuid.assert_called_once_with(
db_mock.MODEL_UUID, model_features
)

def test_update_model_ko(self):
model_in = db_mock.get_sample_model_in()
self.model_service.update_model_by_uuid = MagicMock(return_value=False)
model_features = db_mock.get_sample_model_features()
self.model_service.update_model_features_by_uuid = MagicMock(return_value=False)

res = self.client.post(
f'{self.prefix}/{db_mock.MODEL_UUID}',
json=jsonable_encoder(model_in),
json=jsonable_encoder(model_features),
)

assert res.status_code == 404
self.model_service.update_model_by_uuid.assert_called_once_with(
db_mock.MODEL_UUID, model_in
self.model_service.update_model_features_by_uuid.assert_called_once_with(
db_mock.MODEL_UUID, model_features
)

def test_get_model_by_uuid(self):
Expand Down
50 changes: 43 additions & 7 deletions api/tests/services/model_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.models.exceptions import ModelNotFoundError
from app.models.exceptions import ModelError, ModelNotFoundError
from app.models.model_dto import ModelOut
from app.models.model_order import OrderType
from app.services.model_service import ModelService
Expand Down Expand Up @@ -67,19 +67,55 @@ def test_get_model_by_uuid_not_found(self):
self.model_dao.get_by_uuid.assert_called_once()

def test_update_model_ok(self):
model_in = db_mock.get_sample_model_in()
self.model_dao.update = MagicMock(return_value=1)
res = self.model_service.update_model_by_uuid(model_uuid, model_in)
model_features = db_mock.get_sample_model_features()
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.model_dao.update_features = MagicMock(return_value=1)
res = self.model_service.update_model_features_by_uuid(
model_uuid, model_features
)
feature_dict = [feature.to_dict() for feature in model_features.features]
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with(
model_uuid
)
self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict)

assert res is True

def test_update_model_ko(self):
model_in = db_mock.get_sample_model_in()
self.model_dao.update = MagicMock(return_value=0)
res = self.model_service.update_model_by_uuid(model_uuid, model_in)
model_features = db_mock.get_sample_model_features()
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.model_dao.update_features = MagicMock(return_value=0)
res = self.model_service.update_model_features_by_uuid(
model_uuid, model_features
)
feature_dict = [feature.to_dict() for feature in model_features.features]
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with(
model_uuid
)
self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict)

assert res is False

def test_update_model_freezed(self):
model_features = db_mock.get_sample_model_features()
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=db_mock.get_sample_reference_dataset()
)
self.model_dao.update_features = MagicMock(return_value=0)
with pytest.raises(ModelError):
self.model_service.update_model_features_by_uuid(model_uuid, model_features)
feature_dict = [feature.to_dict() for feature in model_features.features]
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with(
model_uuid
)
self.model_dao.update_features.assert_called_once_with(
model_uuid, feature_dict
)

def test_delete_model_ok(self):
model = db_mock.get_sample_model()
self.model_dao.get_by_uuid = MagicMock(return_value=model)
Expand Down

0 comments on commit e1a5512

Please sign in to comment.