Skip to content

Commit

Permalink
Merge pull request #212 from mobiusml/generate_frames_ids_order_fix
Browse files Browse the repository at this point in the history
Fix frame_ids Generation in generate_frames Function
  • Loading branch information
movchan74 authored Dec 6, 2024
2 parents c263f06 + 3f0d594 commit d9bc745
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
4 changes: 2 additions & 2 deletions aana/integrations/external/decord.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ def generate_frames(
batch_frames_array = video_reader.get_batch(batch).asnumpy()
batch_frames = []
for frame_id, frame in enumerate(batch_frames_array):
img = Image(numpy=frame, media_id=f"{video.media_id}_frame_{frame_id}")
img = Image(numpy=frame, media_id=f"{video.media_id}_frame_{i+frame_id}")
batch_frames.append(img)

batch_timestamps = timestamps[i : i + batch_size]
yield FramesDict(
frames=batch_frames,
frame_ids=list(range(len(batch_frames))),
frame_ids=list(range(i, i + len(batch_frames))),
timestamps=batch_timestamps,
duration=duration,
)
Expand Down
9 changes: 9 additions & 0 deletions aana/tests/units/test_frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_extract_frames_success(
assert isinstance(result["frames"][0], Image)
assert result["duration"] == expected_duration
assert len(result["frames"]) == expected_num_frames
assert result["frame_ids"] == list(range(expected_num_frames))
assert len(result["timestamps"]) == expected_num_frames


Expand Down Expand Up @@ -93,6 +94,8 @@ def test_generate_frames_success(
params = VideoParams(extract_fps=extract_fps, fast_mode_enabled=fast_mode_enabled)
gen_frame = generate_frames(video=video, params=params, batch_size=1)
total_frames = 0
frame_ids = []
frames_media_ids = []
for result in gen_frame:
assert "frames" in result
assert "timestamps" in result
Expand All @@ -107,7 +110,13 @@ def test_generate_frames_success(
assert len(result["timestamps"]) == 1 # batch_size = 1
total_frames += 1
assert result["duration"] == expected_duration
frame_ids.extend(result["frame_ids"])
frames_media_ids.extend([frame.media_id for frame in result["frames"]])

assert frame_ids == list(range(expected_num_frames))
assert frames_media_ids == [
f"{video.media_id}_frame_{frame_id}" for frame_id in range(expected_num_frames)
]
assert total_frames == expected_num_frames


Expand Down

0 comments on commit d9bc745

Please sign in to comment.