Skip to content

Commit

Permalink
Fix media_id generation in generate_frames function and update corres…
Browse files Browse the repository at this point in the history
…ponding tests
  • Loading branch information
Aleksandr Movchan committed Dec 6, 2024
1 parent 696ae56 commit 3f0d594
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aana/integrations/external/decord.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def generate_frames(
batch_frames_array = video_reader.get_batch(batch).asnumpy()
batch_frames = []
for frame_id, frame in enumerate(batch_frames_array):
img = Image(numpy=frame, media_id=f"{video.media_id}_frame_{frame_id}")
img = Image(numpy=frame, media_id=f"{video.media_id}_frame_{i+frame_id}")
batch_frames.append(img)

batch_timestamps = timestamps[i : i + batch_size]
Expand Down
5 changes: 5 additions & 0 deletions aana/tests/units/test_frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def test_generate_frames_success(
gen_frame = generate_frames(video=video, params=params, batch_size=1)
total_frames = 0
frame_ids = []
frames_media_ids = []
for result in gen_frame:
assert "frames" in result
assert "timestamps" in result
Expand All @@ -110,8 +111,12 @@ def test_generate_frames_success(
total_frames += 1
assert result["duration"] == expected_duration
frame_ids.extend(result["frame_ids"])
frames_media_ids.extend([frame.media_id for frame in result["frames"]])

assert frame_ids == list(range(expected_num_frames))
assert frames_media_ids == [
f"{video.media_id}_frame_{frame_id}" for frame_id in range(expected_num_frames)
]
assert total_frames == expected_num_frames


Expand Down

0 comments on commit 3f0d594

Please sign in to comment.