From c665b76a123a05a7c65c807c7f95bb7f1d45189c Mon Sep 17 00:00:00 2001 From: movchan74 Date: Mon, 13 Nov 2023 16:20:23 +0000 Subject: [PATCH] Implemented PR suggestions --- aana/models/core/media.py | 2 +- aana/models/pydantic/image_input.py | 18 ++++++++++++++++++ aana/models/pydantic/video_input.py | 26 ++++++++++++++++++++++---- aana/tests/test_image_input.py | 8 ++++++++ aana/tests/test_video_input.py | 18 +++++++++++++----- aana/utils/video.py | 18 ++++++++++-------- 6 files changed, 72 insertions(+), 18 deletions(-) diff --git a/aana/models/core/media.py b/aana/models/core/media.py index 9e83ceda..caf7c6d3 100644 --- a/aana/models/core/media.py +++ b/aana/models/core/media.py @@ -11,7 +11,7 @@ @dataclass class Media: """ - A base class representing a media. + A base class representing a media file. It is used to represent images, medias, and audio files. diff --git a/aana/models/pydantic/image_input.py b/aana/models/pydantic/image_input.py index ca8a43b5..58d2e28a 100644 --- a/aana/models/pydantic/image_input.py +++ b/aana/models/pydantic/image_input.py @@ -50,6 +50,24 @@ class ImageInput(BaseModel): description="The ID of the image. If not provided, it will be generated automatically.", ) + @validator("media_id") + def media_id_must_not_be_empty(cls, media_id): + """ + Validates that the media_id is not an empty string. + + Args: + media_id (str): The value of the media_id field. + + Raises: + ValueError: If the media_id is an empty string. + + Returns: + str: The non-empty media_id value. + """ + if media_id == "": + raise ValueError("media_id cannot be an empty string") + return media_id + def set_file(self, file: bytes): """ If 'content' or 'numpy' is set to 'file', diff --git a/aana/models/pydantic/video_input.py b/aana/models/pydantic/video_input.py index 744173a4..e4a780c7 100644 --- a/aana/models/pydantic/video_input.py +++ b/aana/models/pydantic/video_input.py @@ -38,6 +38,24 @@ class VideoInput(BaseModel): description="The ID of the video. If not provided, it will be generated automatically.", ) + @validator("media_id") + def media_id_must_not_be_empty(cls, media_id): + """ + Validates that the media_id is not an empty string. + + Args: + media_id (str): The value of the media_id field. + + Raises: + ValueError: If the media_id is an empty string. + + Returns: + str: The non-empty media_id value. + """ + if media_id == "": + raise ValueError("media_id cannot be an empty string") + return media_id + @root_validator def check_only_one_field(cls, values): """ @@ -145,12 +163,12 @@ class VideoInputList(BaseListModel): __root__: List[VideoInput] @validator("__root__", pre=True) - def check_non_empty(cls, v: List[VideoInput]) -> List[VideoInput]: + def check_non_empty(cls, videos: List[VideoInput]) -> List[VideoInput]: """ Check that the list of videos isn't empty. Args: - v (List[VideoInput]): the list of videos + videos (List[VideoInput]): the list of videos Returns: List[VideoInput]: the list of videos @@ -158,9 +176,9 @@ def check_non_empty(cls, v: List[VideoInput]) -> List[VideoInput]: Raises: ValueError: if the list of videos is empty """ - if len(v) == 0: + if len(videos) == 0: raise ValueError("The list of videos must not be empty.") - return v + return videos def set_files(self, files: List[bytes]): """ diff --git a/aana/tests/test_image_input.py b/aana/tests/test_image_input.py index 83b3427a..a5a8b438 100644 --- a/aana/tests/test_image_input.py +++ b/aana/tests/test_image_input.py @@ -36,6 +36,14 @@ def test_new_imageinput_success(): assert image_input.numpy == b"file" +def test_imageinput_invalid_media_id(): + """ + Test that ImageInput can't be created if media_id is invalid. + """ + with pytest.raises(ValidationError): + ImageInput(path="image.png", media_id="") + + def test_imageinput_check_only_one_field(): """ Test that exactly one of 'path', 'url', 'content', or 'numpy' is provided. diff --git a/aana/tests/test_video_input.py b/aana/tests/test_video_input.py index f20a9d70..fba2df33 100644 --- a/aana/tests/test_video_input.py +++ b/aana/tests/test_video_input.py @@ -31,6 +31,14 @@ def test_new_videoinput_success(): assert video_input.content == b"file" +def test_videoinput_invalid_media_id(): + """ + Test that VideoInput can't be created if media_id is invalid. + """ + with pytest.raises(ValidationError): + VideoInput(path="video.mp4", media_id="") + + def test_videoinput_check_only_one_field(): """ Test that exactly one of 'path', 'url', or 'content' is provided. @@ -131,9 +139,9 @@ def test_videoinput_convert_input_to_object(mock_download_file): video_object.cleanup() -def test_videolistinput(): +def test_videoinputlist(): """ - Test that VideoListInput can be created successfully. + Test that VideoInputList can be created successfully. """ videos = [ VideoInput(path="video.mp4"), @@ -149,7 +157,7 @@ def test_videolistinput(): assert video_list_input[2] == videos[2] -def test_videolistinput_set_files(): +def test_videoinputlist_set_files(): """ Test that the files can be set for the video list. """ @@ -178,9 +186,9 @@ def test_videolistinput_set_files(): video_list_input.set_files(files) -def test_videolistinput_non_empty(): +def test_videoinputlist_non_empty(): """ - Test that VideoListInput must not be empty. + Test that videoinputlist must not be empty. """ with pytest.raises(ValidationError): VideoInputList(__root__=[]) diff --git a/aana/utils/video.py b/aana/utils/video.py index 8d7c948b..07c0a9e6 100644 --- a/aana/utils/video.py +++ b/aana/utils/video.py @@ -1,13 +1,19 @@ -from typing import Any, Dict import decord import numpy as np from aana.exceptions.general import VideoReadingException from aana.models.core.image import Image from aana.models.core.video import Video from aana.models.pydantic.video_params import VideoParams +from typing import List, TypedDict -def extract_frames_decord(video: Video, params: VideoParams) -> Dict[str, Any]: +class FramesDict(TypedDict): + frames: List[Image] + timestamps: List[float] + duration: float + + +def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict: """ Extract frames from a video using decord. @@ -16,7 +22,7 @@ def extract_frames_decord(video: Video, params: VideoParams) -> Dict[str, Any]: params (VideoParams): the parameters of the video extraction Returns: - Dict[str, Any]: a dictionary containing the extracted frames, timestamps, and duration + FramesDict: a dictionary containing the extracted frames, timestamps, and duration """ device = decord.cpu(0) num_threads = 1 # TODO: see if we can use more threads @@ -47,8 +53,4 @@ def extract_frames_decord(video: Video, params: VideoParams) -> Dict[str, Any]: img = Image(numpy=frame) frames.append(img) - return { - "frames": frames, - "timestamps": timestamps, - "duration": duration, - } + return FramesDict(frames=frames, timestamps=timestamps, duration=duration)