Skip to content

Commit

Permalink
Implemented PR suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
movchan74 committed Nov 13, 2023
1 parent 3337777 commit c665b76
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 18 deletions.
2 changes: 1 addition & 1 deletion aana/models/core/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@dataclass
class Media:
"""
A base class representing a media.
A base class representing a media file.
It is used to represent images, medias, and audio files.
Expand Down
18 changes: 18 additions & 0 deletions aana/models/pydantic/image_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ class ImageInput(BaseModel):
description="The ID of the image. If not provided, it will be generated automatically.",
)

@validator("media_id")
def media_id_must_not_be_empty(cls, media_id):
"""
Validates that the media_id is not an empty string.
Args:
media_id (str): The value of the media_id field.
Raises:
ValueError: If the media_id is an empty string.
Returns:
str: The non-empty media_id value.
"""
if media_id == "":
raise ValueError("media_id cannot be an empty string")
return media_id

def set_file(self, file: bytes):
"""
If 'content' or 'numpy' is set to 'file',
Expand Down
26 changes: 22 additions & 4 deletions aana/models/pydantic/video_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ class VideoInput(BaseModel):
description="The ID of the video. If not provided, it will be generated automatically.",
)

@validator("media_id")
def media_id_must_not_be_empty(cls, media_id):
"""
Validates that the media_id is not an empty string.
Args:
media_id (str): The value of the media_id field.
Raises:
ValueError: If the media_id is an empty string.
Returns:
str: The non-empty media_id value.
"""
if media_id == "":
raise ValueError("media_id cannot be an empty string")
return media_id

@root_validator
def check_only_one_field(cls, values):
"""
Expand Down Expand Up @@ -145,22 +163,22 @@ class VideoInputList(BaseListModel):
__root__: List[VideoInput]

@validator("__root__", pre=True)
def check_non_empty(cls, v: List[VideoInput]) -> List[VideoInput]:
def check_non_empty(cls, videos: List[VideoInput]) -> List[VideoInput]:
"""
Check that the list of videos isn't empty.
Args:
v (List[VideoInput]): the list of videos
videos (List[VideoInput]): the list of videos
Returns:
List[VideoInput]: the list of videos
Raises:
ValueError: if the list of videos is empty
"""
if len(v) == 0:
if len(videos) == 0:
raise ValueError("The list of videos must not be empty.")
return v
return videos

def set_files(self, files: List[bytes]):
"""
Expand Down
8 changes: 8 additions & 0 deletions aana/tests/test_image_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def test_new_imageinput_success():
assert image_input.numpy == b"file"


def test_imageinput_invalid_media_id():
"""
Test that ImageInput can't be created if media_id is invalid.
"""
with pytest.raises(ValidationError):
ImageInput(path="image.png", media_id="")


def test_imageinput_check_only_one_field():
"""
Test that exactly one of 'path', 'url', 'content', or 'numpy' is provided.
Expand Down
18 changes: 13 additions & 5 deletions aana/tests/test_video_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def test_new_videoinput_success():
assert video_input.content == b"file"


def test_videoinput_invalid_media_id():
"""
Test that VideoInput can't be created if media_id is invalid.
"""
with pytest.raises(ValidationError):
VideoInput(path="video.mp4", media_id="")


def test_videoinput_check_only_one_field():
"""
Test that exactly one of 'path', 'url', or 'content' is provided.
Expand Down Expand Up @@ -131,9 +139,9 @@ def test_videoinput_convert_input_to_object(mock_download_file):
video_object.cleanup()


def test_videolistinput():
def test_videoinputlist():
"""
Test that VideoListInput can be created successfully.
Test that VideoInputList can be created successfully.
"""
videos = [
VideoInput(path="video.mp4"),
Expand All @@ -149,7 +157,7 @@ def test_videolistinput():
assert video_list_input[2] == videos[2]


def test_videolistinput_set_files():
def test_videoinputlist_set_files():
"""
Test that the files can be set for the video list.
"""
Expand Down Expand Up @@ -178,9 +186,9 @@ def test_videolistinput_set_files():
video_list_input.set_files(files)


def test_videolistinput_non_empty():
def test_videoinputlist_non_empty():
"""
Test that VideoListInput must not be empty.
Test that videoinputlist must not be empty.
"""
with pytest.raises(ValidationError):
VideoInputList(__root__=[])
18 changes: 10 additions & 8 deletions aana/utils/video.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from typing import Any, Dict
import decord
import numpy as np
from aana.exceptions.general import VideoReadingException
from aana.models.core.image import Image
from aana.models.core.video import Video
from aana.models.pydantic.video_params import VideoParams
from typing import List, TypedDict


def extract_frames_decord(video: Video, params: VideoParams) -> Dict[str, Any]:
class FramesDict(TypedDict):
frames: List[Image]
timestamps: List[float]
duration: float


def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict:
"""
Extract frames from a video using decord.
Expand All @@ -16,7 +22,7 @@ def extract_frames_decord(video: Video, params: VideoParams) -> Dict[str, Any]:
params (VideoParams): the parameters of the video extraction
Returns:
Dict[str, Any]: a dictionary containing the extracted frames, timestamps, and duration
FramesDict: a dictionary containing the extracted frames, timestamps, and duration
"""
device = decord.cpu(0)
num_threads = 1 # TODO: see if we can use more threads
Expand Down Expand Up @@ -47,8 +53,4 @@ def extract_frames_decord(video: Video, params: VideoParams) -> Dict[str, Any]:
img = Image(numpy=frame)
frames.append(img)

return {
"frames": frames,
"timestamps": timestamps,
"duration": duration,
}
return FramesDict(frames=frames, timestamps=timestamps, duration=duration)

0 comments on commit c665b76

Please sign in to comment.