diff --git a/api/alembic/versions/3ec04e609ae9_set_correlation_id_optional.py b/api/alembic/versions/3ec04e609ae9_set_correlation_id_optional.py new file mode 100644 index 00000000..feab27f6 --- /dev/null +++ b/api/alembic/versions/3ec04e609ae9_set_correlation_id_optional.py @@ -0,0 +1,42 @@ +"""set_correlation_id_optional + +Revision ID: 3ec04e609ae9 +Revises: 086f26392cc4 +Create Date: 2024-07-08 10:28:35.068312 + +""" +from typing import Sequence, Union, Text + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '3ec04e609ae9' +down_revision: Union[str, None] = '086f26392cc4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('current_dataset', 'CORRELATION_ID_COLUMN', + existing_type=sa.VARCHAR(), + nullable=True) + op.create_unique_constraint(None, 'current_dataset', ['UUID']) + op.create_unique_constraint(None, 'current_dataset_metrics', ['UUID']) + op.create_unique_constraint(None, 'reference_dataset', ['UUID']) + op.create_unique_constraint(None, 'reference_dataset_metrics', ['UUID']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'reference_dataset_metrics', type_='unique') + op.drop_constraint(None, 'reference_dataset', type_='unique') + op.drop_constraint(None, 'current_dataset_metrics', type_='unique') + op.drop_constraint(None, 'current_dataset', type_='unique') + op.alter_column('current_dataset', 'CORRELATION_ID_COLUMN', + existing_type=sa.VARCHAR(), + nullable=False) + # ### end Alembic commands ### diff --git a/api/app/db/tables/current_dataset_table.py b/api/app/db/tables/current_dataset_table.py index 40f6184b..30cf4acd 100644 --- a/api/app/db/tables/current_dataset_table.py +++ b/api/app/db/tables/current_dataset_table.py @@ -23,5 +23,5 @@ class CurrentDataset(Reflected, BaseTable, BaseDAO): ) path = Column('PATH', VARCHAR, nullable=False) date = Column('DATE', TIMESTAMP(timezone=True), nullable=False) - correlation_id_column = Column('CORRELATION_ID_COLUMN', VARCHAR, nullable=False) + correlation_id_column = Column('CORRELATION_ID_COLUMN', VARCHAR, nullable=True) status = Column('STATUS', VARCHAR, nullable=False, default=JobStatus.IMPORTING) diff --git a/api/app/models/dataset_dto.py b/api/app/models/dataset_dto.py index 3159118f..8f6feb4e 100644 --- a/api/app/models/dataset_dto.py +++ b/api/app/models/dataset_dto.py @@ -36,7 +36,7 @@ class CurrentDatasetDTO(BaseModel): model_uuid: UUID path: str date: str - correlation_id_column: str + correlation_id_column: Optional[str] status: str model_config = ConfigDict( diff --git a/api/app/routes/upload_dataset_route.py b/api/app/routes/upload_dataset_route.py index 90d0a7b1..bd906f48 100644 --- a/api/app/routes/upload_dataset_route.py +++ b/api/app/routes/upload_dataset_route.py @@ -48,7 +48,7 @@ def upload_current_file( model_uuid: UUID, csv_file: UploadFile = File(...), sep: str = Form(','), - correlation_id_column: str = Form(''), + correlation_id_column: Optional[str] = Form(None), ) -> CurrentDatasetDTO: return file_service.upload_current_file( model_uuid, csv_file, correlation_id_column, sep diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py index 537f9f1e..929b0bbb 100644 --- a/api/app/services/file_service.py +++ b/api/app/services/file_service.py @@ -198,7 +198,7 @@ def upload_current_file( self, model_uuid: UUID, csv_file: UploadFile, - correlation_id_column: str, + correlation_id_column: Optional[str] = None, sep: str = ',', columns=None, ) -> CurrentDatasetDTO: diff --git a/api/tests/routes/upload_dataset_route_test.py b/api/tests/routes/upload_dataset_route_test.py index 4b0ab28a..73c16524 100644 --- a/api/tests/routes/upload_dataset_route_test.py +++ b/api/tests/routes/upload_dataset_route_test.py @@ -73,6 +73,48 @@ def test_bind_reference(self): assert res.status_code == 200 assert jsonable_encoder(upload_file_result) == res.json() + def test_upload_current(self): + file = csv.get_correct_sample_csv_file() + model_uuid = uuid.uuid4() + upload_file_result = CurrentDatasetDTO( + uuid=uuid.uuid4(), + model_uuid=model_uuid, + path='test', + date=str(datetime.datetime.now(tz=datetime.UTC)), + status=JobStatus.IMPORTING, + correlation_id_column=None + ) + self.file_service.upload_current_file = MagicMock( + return_value=upload_file_result + ) + res = self.client.post( + f'{self.prefix}/{model_uuid}/current/upload', + files={'csv_file': (file.filename, file.file)}, + ) + assert res.status_code == 200 + assert jsonable_encoder(upload_file_result) == res.json() + + def test_bind_current(self): + file_ref = FileReference(file_url='/file') + model_uuid = uuid.uuid4() + upload_file_result = CurrentDatasetDTO( + uuid=uuid.uuid4(), + model_uuid=model_uuid, + path='test', + date=str(datetime.datetime.now(tz=datetime.UTC)), + status=JobStatus.IMPORTING, + correlation_id_column=None + ) + self.file_service.bind_current_file = MagicMock( + return_value=upload_file_result + ) + res = self.client.post( + f'{self.prefix}/{model_uuid}/current/bind', + json=jsonable_encoder(file_ref), + ) + assert res.status_code == 200 + assert jsonable_encoder(upload_file_result) == res.json() + def test_get_all_reference_datasets_by_model_uuid_paginated(self): test_model_uuid = uuid.uuid4() reference_upload_1 = db_mock.get_sample_reference_dataset( diff --git a/api/tests/services/file_service_test.py b/api/tests/services/file_service_test.py index 2d77a148..65efc577 100644 --- a/api/tests/services/file_service_test.py +++ b/api/tests/services/file_service_test.py @@ -199,13 +199,12 @@ def test_upload_current_file_ok(self): ) object_name = f'{str(model.uuid)}/current/{file.filename}' path = f's3://bucket/{object_name}' - correlation_id_column = 'correlation_id' inserted_file = CurrentDataset( uuid=uuid4(), model_uuid=model_uuid, path=path, date=datetime.datetime.now(tz=datetime.UTC), - correlation_id_column=correlation_id_column, + correlation_id_column=None, status=JobStatus.IMPORTING, ) reference_file = get_sample_reference_dataset(model_uuid=model_uuid) @@ -221,7 +220,7 @@ def test_upload_current_file_ok(self): self.spark_k8s_client.submit_app = MagicMock() result = self.files_service.upload_current_file( - model.uuid, file, correlation_id_column + model.uuid, file, ) self.model_svc.get_model_by_uuid.assert_called_once() diff --git a/sdk/radicalbit_platform_sdk/apis/model.py b/sdk/radicalbit_platform_sdk/apis/model.py index 271d5489..901f7de0 100644 --- a/sdk/radicalbit_platform_sdk/apis/model.py +++ b/sdk/radicalbit_platform_sdk/apis/model.py @@ -284,7 +284,7 @@ def load_current_dataset( self, file_name: str, bucket: str, - correlation_id_column: str, + correlation_id_column: Optional[str] = None, object_name: Optional[str] = None, aws_credentials: Optional[AwsCredentials] = None, separator: str = ',', @@ -307,7 +307,8 @@ def load_current_dataset( ).columns.tolist() required_headers = self.__required_headers() - required_headers.append(correlation_id_column) + if correlation_id_column: + required_headers.append(correlation_id_column) required_headers.append(self.__timestamp.name) if set(required_headers).issubset(file_headers): @@ -465,7 +466,7 @@ def __bind_current_dataset( self, dataset_url: str, separator: str, - correlation_id_column: str, + correlation_id_column: Optional[str] = None, ) -> ModelCurrentDataset: def __callback(response: requests.Response) -> ModelCurrentDataset: try: diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 06430527..3d7998d5 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -49,7 +49,7 @@ def uuid(self) -> UUID: def path(self) -> str: return self.__path - def correlation_id_column(self) -> str: + def correlation_id_column(self) -> Optional[str]: return self.__correlation_id_column def date(self) -> str: diff --git a/sdk/radicalbit_platform_sdk/models/file_upload_result.py b/sdk/radicalbit_platform_sdk/models/file_upload_result.py index 1e4b0545..88eb1af6 100644 --- a/sdk/radicalbit_platform_sdk/models/file_upload_result.py +++ b/sdk/radicalbit_platform_sdk/models/file_upload_result.py @@ -21,7 +21,7 @@ class ReferenceFileUpload(FileUploadResult): class CurrentFileUpload(FileUploadResult): - correlation_id_column: str + correlation_id_column: Optional[str] = None model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)