diff --git a/object_storage_api/core/exceptions.py b/object_storage_api/core/exceptions.py index 095c051..9afca44 100644 --- a/object_storage_api/core/exceptions.py +++ b/object_storage_api/core/exceptions.py @@ -6,10 +6,11 @@ # TODO: Some of this file is identical to the one in inventory-management-system-api - Use common repo? +from typing import Optional + + class BaseAPIException(Exception): - """ - Base exception for API errors. - """ + """Base exception for API errors.""" # Status code to return if this exception is raised status_code: int @@ -19,37 +20,56 @@ class BaseAPIException(Exception): detail: str - def __init__(self, detail: str): + def __init__(self, detail: str, response_detail: Optional[str] = None): """ Initialise the exception. :param detail: Specific detail of the exception (just like Exception would take - this will only be logged and not returned in a response). + :param response_detail: Generic detail of the exception that will be returned in a response. """ super().__init__(detail) self.detail = detail + if response_detail is not None: + self.response_detail = response_detail + class DatabaseError(BaseAPIException): - """ - Database related error. - """ + """Database related error.""" class InvalidObjectIdError(DatabaseError): - """ - The provided value is not a valid ObjectId. - """ + """The provided value is not a valid ObjectId.""" status_code = 422 response_detail = "Invalid ID given" class InvalidImageFileError(BaseAPIException): - """ - The provided image file is not valid. - """ + """The provided image file is not valid.""" status_code = 422 response_detail = "File given is not a valid image" + + +class MissingRecordError(DatabaseError): + """A specific database record was requested but could not be found.""" + + status_code = 404 + response_detail = "Requested record was not found" + + def __init__(self, detail: str, response_detail: Optional[str] = None, entity_name: Optional[str] = None): + """ + Initialise the exception. + + :param detail: Specific detail of the exception (just like Exception would take - this will only be logged + and not returned in a response). + :param response_detail: Generic detail of the exception to be returned in the response. + :param entity_name: Name of the entity to include in the response detail. + """ + super().__init__(detail, response_detail) + + if entity_name is not None: + self.response_detail = f"{entity_name.capitalize()} not found" diff --git a/object_storage_api/repositories/image.py b/object_storage_api/repositories/image.py index 09d4191..c01c3b6 100644 --- a/object_storage_api/repositories/image.py +++ b/object_storage_api/repositories/image.py @@ -10,6 +10,7 @@ from object_storage_api.core.custom_object_id import CustomObjectId from object_storage_api.core.database import DatabaseDep +from object_storage_api.core.exceptions import InvalidObjectIdError, MissingRecordError from object_storage_api.models.image import ImageIn, ImageOut logger = logging.getLogger() @@ -42,20 +43,27 @@ def create(self, image: ImageIn, session: ClientSession = None) -> ImageOut: result = self._images_collection.insert_one(image.model_dump(by_alias=True), session=session) return self.get(str(result.inserted_id), session=session) - def get(self, image_id: str, session: ClientSession = None) -> Optional[ImageOut]: + def get(self, image_id: str, session: ClientSession = None) -> ImageOut: """ Retrieve an image by its ID from a MongoDB database. :param image_id: ID of the image to retrieve. :param session: PyMongo ClientSession to use for database operations. - :return: Retrieved image or `None` if not found. + :return: Retrieved image if found. + :raises MissingRecordError: If the supplied `image_id` is non-existent. + :raises InvalidObjectIdError: If the supplied `image_id` is invalid. """ - image_id = CustomObjectId(image_id) logger.info("Retrieving image with ID: %s from the database", image_id) - image = self._images_collection.find_one({"_id": image_id}, session=session) + try: + image_id = CustomObjectId(image_id) + image = self._images_collection.find_one({"_id": image_id}, session=session) + except InvalidObjectIdError as exc: + exc.status_code = 404 + exc.response_detail = "Image not found" + raise exc if image: return ImageOut(**image) - return None + raise MissingRecordError(detail=f"No image found with ID: {image_id}", entity_name="image") def list(self, entity_id: Optional[str], primary: Optional[bool], session: ClientSession = None) -> list[ImageOut]: """ diff --git a/object_storage_api/routers/image.py b/object_storage_api/routers/image.py index 4c42415..73dcca9 100644 --- a/object_storage_api/routers/image.py +++ b/object_storage_api/routers/image.py @@ -6,9 +6,9 @@ import logging from typing import Annotated, Optional -from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status +from fastapi import APIRouter, Depends, File, Form, Path, Query, UploadFile, status -from object_storage_api.schemas.image import ImagePostMetadataSchema, ImageSchema +from object_storage_api.schemas.image import ImageMetadataSchema, ImagePostMetadataSchema, ImageSchema from object_storage_api.services.image import ImageService logger = logging.getLogger() @@ -36,7 +36,7 @@ def create_image( upload_file: Annotated[UploadFile, File(description="Image file")], title: Annotated[Optional[str], Form(description="Title of the image")] = None, description: Annotated[Optional[str], Form(description="Description of the image")] = None, -) -> ImageSchema: +) -> ImageMetadataSchema: # pylint: disable=missing-function-docstring logger.info("Creating a new image") @@ -57,7 +57,7 @@ def get_images( image_service: ImageServiceDep, entity_id: Annotated[Optional[str], Query(description="Filter images by entity ID")] = None, primary: Annotated[Optional[bool], Query(description="Filter images by primary")] = None, -) -> list[ImageSchema]: +) -> list[ImageMetadataSchema]: # pylint: disable=missing-function-docstring logger.info("Getting images") @@ -67,3 +67,14 @@ def get_images( logger.debug("Primary filter: '%s'", primary) return image_service.list(entity_id, primary) + + +@router.get(path="/{image_id}", summary="Get an image by ID", response_description="Single image") +def get_image( + image_id: Annotated[str, Path(description="ID of the image to get")], + image_service: ImageServiceDep, +) -> ImageSchema: + # pylint: disable=missing-function-docstring + logger.info("Getting image with ID: %s", image_id) + + return image_service.get(image_id) diff --git a/object_storage_api/schemas/image.py b/object_storage_api/schemas/image.py index e422100..2eea290 100644 --- a/object_storage_api/schemas/image.py +++ b/object_storage_api/schemas/image.py @@ -4,27 +4,29 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, HttpUrl from object_storage_api.schemas.mixins import CreatedModifiedSchemaMixin class ImagePostMetadataSchema(BaseModel): - """ - Base schema model for an image. - """ + """Base schema model for an image.""" entity_id: str = Field(description="ID of the entity the image relates to") title: Optional[str] = Field(default=None, description="Title of the image") description: Optional[str] = Field(default=None, description="Description of the image") -class ImageSchema(CreatedModifiedSchemaMixin, ImagePostMetadataSchema): - """ - Schema model for an image get request response. - """ +class ImageMetadataSchema(CreatedModifiedSchemaMixin, ImagePostMetadataSchema): + """Schema model for an image's metadata.""" id: str = Field(description="ID of the image") file_name: str = Field(description="File name of the image") primary: bool = Field(description="Whether the image is the primary for its related entity") thumbnail_base64: str = Field(description="Thumbnail of the image as a base64 encoded byte string") + + +class ImageSchema(ImageMetadataSchema): + """Schema model for an image get request response.""" + + url: HttpUrl = Field(description="Presigned get URL to get the image file") diff --git a/object_storage_api/services/image.py b/object_storage_api/services/image.py index 24420f3..1157510 100644 --- a/object_storage_api/services/image.py +++ b/object_storage_api/services/image.py @@ -13,7 +13,7 @@ from object_storage_api.core.image import generate_thumbnail_base64_str from object_storage_api.models.image import ImageIn from object_storage_api.repositories.image import ImageRepo -from object_storage_api.schemas.image import ImagePostMetadataSchema, ImageSchema +from object_storage_api.schemas.image import ImageMetadataSchema, ImagePostMetadataSchema, ImageSchema from object_storage_api.stores.image import ImageStore logger = logging.getLogger() @@ -38,7 +38,7 @@ def __init__( self._image_repository = image_repository self._image_store = image_store - def create(self, image_metadata: ImagePostMetadataSchema, upload_file: UploadFile) -> ImageSchema: + def create(self, image_metadata: ImagePostMetadataSchema, upload_file: UploadFile) -> ImageMetadataSchema: """ Create a new image. @@ -73,9 +73,20 @@ def create(self, image_metadata: ImagePostMetadataSchema, upload_file: UploadFil image_out = self._image_repository.create(image_in) - return ImageSchema(**image_out.model_dump()) + return ImageMetadataSchema(**image_out.model_dump()) - def list(self, entity_id: Optional[str] = None, primary: Optional[bool] = None) -> list[ImageSchema]: + def get(self, image_id: str) -> ImageSchema: + """ + Retrieve an image's metadata with its presigned get url by its ID. + + :param image_id: ID of the image to retrieve. + :return: An image's metadata with a presigned get url. + """ + image = self._image_repository.get(image_id=image_id) + presigned_url = self._image_store.create_presigned_get(image) + return ImageSchema(**image.model_dump(), url=presigned_url) + + def list(self, entity_id: Optional[str] = None, primary: Optional[bool] = None) -> list[ImageMetadataSchema]: """ Retrieve a list of images based on the provided filters. @@ -84,4 +95,4 @@ def list(self, entity_id: Optional[str] = None, primary: Optional[bool] = None) :return: List of images or an empty list if no images are retrieved. """ images = self._image_repository.list(entity_id, primary) - return [ImageSchema(**image.model_dump()) for image in images] + return [ImageMetadataSchema(**image.model_dump()) for image in images] diff --git a/object_storage_api/stores/image.py b/object_storage_api/stores/image.py index 8d01c31..26083c9 100644 --- a/object_storage_api/stores/image.py +++ b/object_storage_api/stores/image.py @@ -7,6 +7,7 @@ from fastapi import UploadFile from object_storage_api.core.object_store import object_storage_config, s3_client +from object_storage_api.models.image import ImageOut from object_storage_api.schemas.image import ImagePostMetadataSchema logger = logging.getLogger() @@ -37,3 +38,23 @@ def upload(self, image_id: str, image_metadata: ImagePostMetadataSchema, upload_ ) return object_key + + def create_presigned_get(self, image: ImageOut) -> str: + """ + Generate a presigned URL to share an S3 object. + + :param image: `ImageOut` model of the image. + :return: Presigned url to get the image. + """ + logger.info("Generating presigned url to get image from object storage") + response = s3_client.generate_presigned_url( + "get_object", + Params={ + "Bucket": object_storage_config.bucket_name.get_secret_value(), + "Key": image.object_key, + "ResponseContentDisposition": f'inline; filename="{image.file_name}"', + }, + ExpiresIn=object_storage_config.presigned_url_expiry_seconds, + ) + + return response diff --git a/test/e2e/test_image.py b/test/e2e/test_image.py index f070388..0c4bd84 100644 --- a/test/e2e/test_image.py +++ b/test/e2e/test_image.py @@ -4,7 +4,8 @@ from test.mock_data import ( IMAGE_GET_DATA_ALL_VALUES, - IMAGE_GET_DATA_REQUIRED_VALUES_ONLY, + IMAGE_GET_METADATA_ALL_VALUES, + IMAGE_GET_METADATA_REQUIRED_VALUES_ONLY, IMAGE_POST_METADATA_DATA_ALL_VALUES, IMAGE_POST_METADATA_DATA_REQUIRED_VALUES_ONLY, ) @@ -52,7 +53,7 @@ def check_post_image_success(self, expected_image_get_data: dict) -> None: Checks that a prior call to `post_image` gave a successful response with the expected data returned. :param expected_image_get_data: Dictionary containing the expected image data returned as would be - required for an `ImageSchema`. + required for an `ImageMetadataSchema`. """ assert self._post_response_image.status_code == 201 @@ -77,13 +78,13 @@ def test_create_with_only_required_values_provided(self): """Test creating an image with only required values provided.""" self.post_image(IMAGE_POST_METADATA_DATA_REQUIRED_VALUES_ONLY, "image.jpg") - self.check_post_image_success(IMAGE_GET_DATA_REQUIRED_VALUES_ONLY) + self.check_post_image_success(IMAGE_GET_METADATA_REQUIRED_VALUES_ONLY) def test_create_with_all_values_provided(self): """Test creating an image with all values provided.""" self.post_image(IMAGE_POST_METADATA_DATA_ALL_VALUES, "image.jpg") - self.check_post_image_success(IMAGE_GET_DATA_ALL_VALUES) + self.check_post_image_success(IMAGE_GET_METADATA_ALL_VALUES) def test_create_with_invalid_entity_id(self): """Test creating an image with an invalid `entity_id`.""" @@ -98,13 +99,60 @@ def test_create_with_invalid_image_file(self): self.check_post_image_failed_with_detail(422, "File given is not a valid image") -# pylint:disable=fixme -# TODO: Inherit from GetDSL when added -class ListDSL(CreateDSL): - """Base class for list tests.""" +class GetDSL(CreateDSL): + """Base class for get tests.""" _get_response_image: Response + def get_image(self, image_id: str) -> None: + """ + Gets an image with the given ID. + + :param image_id: The ID of the image to be obtained. + """ + self._get_response_image = self.test_client.get(f"/images/{image_id}") + + def check_get_image_success(self, expected_image_data: dict) -> None: + """ + Checks that a prior call to `get_image` gave a successful response with the expected data returned. + + :param expected_image_data: Dictionary containing the expected image data as would be required + for an `ImageMetadataSchema`. + """ + assert self._get_response_image.status_code == 200 + assert self._get_response_image.json() == expected_image_data + + def check_get_image_failed(self) -> None: + """Checks that prior call to `get_image` gave a failed response.""" + + assert self._get_response_image.status_code == 404 + assert self._get_response_image.json()["detail"] == "Image not found" + + +class TestGet(GetDSL): + """Tests for getting an image.""" + + def test_get_with_valid_image_id(self): + """Test getting an image with a valid image ID.""" + image_id = self.post_image(IMAGE_POST_METADATA_DATA_ALL_VALUES, "image.jpg") + self.get_image(image_id) + self.check_get_image_success(IMAGE_GET_DATA_ALL_VALUES) + + def test_get_with_invalid_image_id(self): + """Test getting an image with an invalid image ID.""" + self.get_image("sdfgfsdg") + self.check_get_image_failed() + + def test_get_with_non_existent_image_id(self): + """Test getting an image with a non-existent image ID.""" + image_id = str(ObjectId()) + self.get_image(image_id) + self.check_get_image_failed() + + +class ListDSL(GetDSL): + """Base class for list tests.""" + def get_images(self, filters: Optional[dict] = None) -> None: """ Gets a list of images with the given filters. @@ -118,7 +166,7 @@ def post_test_images(self) -> list[dict]: Posts three images. The first two images have the same entity ID, the last image has a different one. :return: List of dictionaries containing the expected item data returned from a get endpoint in - the form of an `ImageSchema`. + the form of an `ImageMetadataSchema`. """ entity_id_a, entity_id_b = (str(ObjectId()) for _ in range(2)) @@ -145,17 +193,17 @@ def post_test_images(self) -> list[dict]: return [ { - **IMAGE_GET_DATA_ALL_VALUES, + **IMAGE_GET_METADATA_ALL_VALUES, "entity_id": entity_id_a, "id": image_a_id, }, { - **IMAGE_GET_DATA_ALL_VALUES, + **IMAGE_GET_METADATA_ALL_VALUES, "entity_id": entity_id_a, "id": image_b_id, }, { - **IMAGE_GET_DATA_ALL_VALUES, + **IMAGE_GET_METADATA_ALL_VALUES, "entity_id": entity_id_b, "id": image_c_id, }, @@ -166,7 +214,7 @@ def check_get_images_success(self, expected_images_get_data: list[dict]) -> None Checks that a prior call to `get_images` gave a successful response with the expected data returned. :param expected_images_get_data: List of dictionaries containing the expected image data as would - be required for an `ImageSchema`. + be required for an `ImageMetadataSchema`. """ assert self._get_response_image.status_code == 200 assert self._get_response_image.json() == expected_images_get_data diff --git a/test/mock_data.py b/test/mock_data.py index a05e979..4b8ef76 100644 --- a/test/mock_data.py +++ b/test/mock_data.py @@ -111,7 +111,7 @@ "entity_id": str(ObjectId()), } -IMAGE_GET_DATA_REQUIRED_VALUES_ONLY = { +IMAGE_GET_METADATA_REQUIRED_VALUES_ONLY = { **IMAGE_POST_METADATA_DATA_REQUIRED_VALUES_ONLY, **CREATED_MODIFIED_GET_DATA_EXPECTED, "id": ANY, @@ -122,6 +122,12 @@ "description": None, } +IMAGE_GET_DATA_REQUIRED_VALUES_ONLY = { + **IMAGE_GET_METADATA_REQUIRED_VALUES_ONLY, + "url": ANY, +} + + IMAGE_POST_METADATA_DATA_ALL_VALUES = { **IMAGE_POST_METADATA_DATA_REQUIRED_VALUES_ONLY, "title": "Report Title", @@ -136,7 +142,7 @@ "thumbnail_base64": "UklGRjQAAABXRUJQVlA4ICgAAADQAQCdASoCAAEAAUAmJYwCdAEO/gOOAAD+qlQWHDxhNJOjVlqIb8AA", } -IMAGE_GET_DATA_ALL_VALUES = { +IMAGE_GET_METADATA_ALL_VALUES = { **IMAGE_POST_METADATA_DATA_ALL_VALUES, **CREATED_MODIFIED_GET_DATA_EXPECTED, "id": ANY, @@ -144,3 +150,8 @@ "primary": False, "thumbnail_base64": "UklGRjQAAABXRUJQVlA4ICgAAADQAQCdASoCAAEAAUAmJYwCdAEO/gOOAAD+qlQWHDxhNJOjVlqIb8AA", } + +IMAGE_GET_DATA_ALL_VALUES = { + **IMAGE_GET_METADATA_ALL_VALUES, + "url": ANY, +} diff --git a/test/unit/repositories/test_image.py b/test/unit/repositories/test_image.py index c50fdcc..94e1684 100644 --- a/test/unit/repositories/test_image.py +++ b/test/unit/repositories/test_image.py @@ -10,6 +10,7 @@ import pytest from bson import ObjectId +from object_storage_api.core.exceptions import InvalidObjectIdError, MissingRecordError from object_storage_api.models.image import ImageIn, ImageOut from object_storage_api.repositories.image import ImageRepo @@ -88,6 +89,109 @@ def test_create(self): self.check_create_success() +class GetDSL(ImageRepoDSL): + """Base class for `get` tests.""" + + _obtained_image_id: str + _expected_image_out: ImageOut + _obtained_image_out: ImageOut + _get_exception: pytest.ExceptionInfo + + def mock_get(self, image_id: str, image_in_data: dict) -> None: + """ + Mocks database methods appropriately to test the `get` repo method. + + :param image_id: ID of the image to obtain. + :param image_in_data: Dictionary containing the image data as would be required for an + `ImageIn` database model (i.e. no created and modified times required). + """ + if image_in_data: + image_in_data["id"] = image_id + self._expected_image_out = ImageOut(**ImageIn(**image_in_data).model_dump()) if image_in_data else None + + RepositoryTestHelpers.mock_find_one( + self.images_collection, self._expected_image_out.model_dump() if self._expected_image_out else None + ) + + def call_get(self, image_id: str) -> None: + """ + Calls the `ImageRepo` `get` method. + + :param image_id: The ID of the image to obtain. + """ + self._obtained_image_id = image_id + self._obtained_image_out = self.image_repository.get(image_id=image_id, session=self.mock_session) + + def call_get_expecting_error(self, image_id: str, error_type: type[BaseException]) -> None: + """ + Calls the `ImageRepo` `get` method with the appropriate data from a prior call to `mock_get` + while expecting an error to be raised. + + :param image_id: ID of the image to be obtained. + :param error_type: Expected exception to be raised. + """ + self._obtained_image_id = image_id + with pytest.raises(error_type) as exc: + self.image_repository.get(image_id, session=self.mock_session) + self._get_exception = exc + + def check_get_success(self) -> None: + """Checks that a prior call to `call_get` worked as expected.""" + + self.images_collection.find_one.assert_called_once_with( + {"_id": ObjectId(self._obtained_image_id)}, session=self.mock_session + ) + assert self._obtained_image_out == self._expected_image_out + + def check_get_failed_with_exception(self, message: str, assert_find: bool = False) -> None: + """ + Checks that a prior call to `call_get_expecting_error` worked as expected, raising an exception + with the correct message. + + :param image_id: ID of the expected image to appear in the exception detail. + :param assert_find: If `True` it asserts whether a `find_one` call was made, + else it asserts that no call was made. + """ + if assert_find: + self.images_collection.find_one.assert_called_once_with( + {"_id": ObjectId(self._obtained_image_id)}, session=self.mock_session + ) + else: + self.images_collection.find_one.assert_not_called() + + assert str(self._get_exception.value) == message + + +class TestGet(GetDSL): + """Tests for getting images.""" + + def test_get(self): + """Test getting an image.""" + + image_id = str(ObjectId()) + + self.mock_get(image_id, IMAGE_IN_DATA_ALL_VALUES) + self.call_get(image_id) + self.check_get_success() + + def test_get_with_non_existent_id(self): + """Test getting an image with a non-existent image ID.""" + + image_id = str(ObjectId()) + + self.mock_get(image_id, None) + self.call_get_expecting_error(image_id, MissingRecordError) + self.check_get_failed_with_exception(f"No image found with ID: {image_id}", True) + + def test_get_with_invalid_id(self): + """Test getting an image with an invalid image ID.""" + image_id = "invalid-id" + + self.mock_get(image_id, None) + self.call_get_expecting_error(image_id, InvalidObjectIdError) + self.check_get_failed_with_exception(f"Invalid ObjectId value '{image_id}'") + + class ListDSL(ImageRepoDSL): """Base class for `list` tests.""" diff --git a/test/unit/services/test_image.py b/test/unit/services/test_image.py index 107dd67..07bb88f 100644 --- a/test/unit/services/test_image.py +++ b/test/unit/services/test_image.py @@ -12,7 +12,7 @@ from object_storage_api.core.exceptions import InvalidObjectIdError from object_storage_api.models.image import ImageIn, ImageOut -from object_storage_api.schemas.image import ImagePostMetadataSchema, ImageSchema +from object_storage_api.schemas.image import ImageMetadataSchema, ImagePostMetadataSchema, ImageSchema from object_storage_api.services.image import ImageService @@ -58,8 +58,8 @@ class CreateDSL(ImageServiceDSL): _upload_file: UploadFile _expected_image_id: ObjectId _expected_image_in: ImageIn - _expected_image: ImageSchema - _created_image: ImageSchema + _expected_image: ImageMetadataSchema + _created_image: ImageMetadataSchema _create_exception: pytest.ExceptionInfo def mock_create(self, image_post_metadata_data: dict) -> None: @@ -98,7 +98,7 @@ def mock_create(self, image_post_metadata_data: dict) -> None: expected_image_out = ImageOut(**self._expected_image_in.model_dump(by_alias=True)) self.mock_image_repository.create.return_value = expected_image_out - self._expected_image = ImageSchema(**expected_image_out.model_dump()) + self._expected_image = ImageMetadataSchema(**expected_image_out.model_dump()) def call_create(self) -> None: """Calls the `ImageService` `create` method with the appropriate data from a prior call to @@ -163,13 +163,57 @@ def test_create_with_invalid_entity_id(self): self.check_create_failed_with_exception("Invalid ObjectId value 'invalid-id'") +class GetDSL(ImageServiceDSL): + """Base class for `get` tests.""" + + _obtained_image_id: str + _expected_image_out: ImageOut + _expected_image: ImageSchema + _obtained_image: ImageSchema + + def mock_get(self) -> None: + """Mocks repo methods appropriately to test the `get` service method.""" + + self._expected_image_out = ImageOut(**ImageIn(**IMAGE_IN_DATA_ALL_VALUES).model_dump()) + self.mock_image_repository.get.return_value = self._expected_image_out + self.mock_image_store.create_presigned_get.return_value = "https://fakepresignedurl.co.uk" + self._expected_image = ImageSchema( + **self._expected_image_out.model_dump(), url="https://fakepresignedurl.co.uk" + ) + + def call_get(self, image_id: str) -> None: + """ + Calls the `ImageService` `get` method. + + :param image_id: The ID of the image to obtain. + """ + self._obtained_image_id = image_id + self._obtained_image = self.image_service.get(image_id=image_id) + + def check_get_success(self) -> None: + """Checks that a prior call to `call_get` worked as expected.""" + self.mock_image_repository.get.assert_called_once_with(image_id=self._obtained_image_id) + self.mock_image_store.create_presigned_get.assert_called_once_with(self._expected_image_out) + assert self._obtained_image == self._expected_image + + +class TestGet(GetDSL): + """Tests for getting images.""" + + def test_get(self): + """Test getting images.""" + self.mock_get() + self.call_get(str(ObjectId())) + self.check_get_success() + + class ListDSL(ImageServiceDSL): """Base class for `list` tests.""" _entity_id_filter: Optional[str] _primary_filter: Optional[str] - _expected_images: List[ImageSchema] - _obtained_images: List[ImageSchema] + _expected_images: List[ImageMetadataSchema] + _obtained_images: List[ImageMetadataSchema] def mock_list(self) -> None: """Mocks repo methods appropriately to test the `list` service method.""" @@ -177,7 +221,7 @@ def mock_list(self) -> None: # Just returns the result after converting it to the schemas currently, so actual value doesn't matter here images_out = [ImageOut(**ImageIn(**IMAGE_IN_DATA_ALL_VALUES).model_dump())] self.mock_image_repository.list.return_value = images_out - self._expected_images = [ImageSchema(**image_out.model_dump()) for image_out in images_out] + self._expected_images = [ImageMetadataSchema(**image_out.model_dump()) for image_out in images_out] def call_list(self, entity_id: Optional[str] = None, primary: Optional[bool] = None) -> None: """Calls the `ImageService` `list` method. diff --git a/test/unit/stores/test_image.py b/test/unit/stores/test_image.py index 8a47124..55282b9 100644 --- a/test/unit/stores/test_image.py +++ b/test/unit/stores/test_image.py @@ -2,7 +2,7 @@ Unit tests for the `ImageStore` store. """ -from test.mock_data import IMAGE_POST_METADATA_DATA_ALL_VALUES +from test.mock_data import IMAGE_IN_DATA_ALL_VALUES, IMAGE_POST_METADATA_DATA_ALL_VALUES from unittest.mock import MagicMock, patch import pytest @@ -10,6 +10,7 @@ from fastapi import UploadFile from object_storage_api.core.object_store import object_storage_config +from object_storage_api.models.image import ImageIn, ImageOut from object_storage_api.schemas.image import ImagePostMetadataSchema from object_storage_api.stores.image import ImageStore @@ -82,3 +83,58 @@ def test_upload(self): self.mock_upload(IMAGE_POST_METADATA_DATA_ALL_VALUES) self.call_upload() self.check_upload_success() + + +class CreatePresignedURLDSL(ImageStoreDSL): + """Base class for `create` tests.""" + + _image_out: ImageOut + _expected_presigned_url: str + _obtained_presigned_url: str + + def mock_create_presigned_get(self, image_in_data: dict) -> None: + """ + Mocks object store methods appropriately to test the `create_presigned_get` store method. + + :param image_in_data: Dictionary containing the image data as would be required for an + `ImageIn`. + """ + self._image_out = ImageOut(**ImageIn(**image_in_data).model_dump()) + + # Mock presigned url generation + self._expected_presigned_url = "example_presigned_url" + self.mock_s3_client.generate_presigned_url.return_value = self._expected_presigned_url + + def call_create_presigned_get(self) -> None: + """ + Calls the `ImageStore` `create_presigned_get` method with the appropriate data from a prior call to + `mock_create_presigned_get`. + """ + + self._obtained_presigned_url = self.image_store.create_presigned_get(self._image_out) + + def check_create_presigned_get_success(self) -> None: + """Checks that a prior call to `call_create_presigned_get` worked as expected.""" + + self.mock_s3_client.generate_presigned_url.assert_called_once_with( + "get_object", + Params={ + "Bucket": object_storage_config.bucket_name.get_secret_value(), + "Key": self._image_out.object_key, + "ResponseContentDisposition": f'inline; filename="{self._image_out.file_name}"', + }, + ExpiresIn=object_storage_config.presigned_url_expiry_seconds, + ) + + assert self._obtained_presigned_url == self._expected_presigned_url + + +class TestCreatePresignedURL(CreatePresignedURLDSL): + """Tests for creating a presigned url for an image.""" + + def test_create_presigned_get(self): + """Test creating a presigned url for an image.""" + + self.mock_create_presigned_get(IMAGE_IN_DATA_ALL_VALUES) + self.call_create_presigned_get() + self.check_create_presigned_get_success()