diff --git a/aana/integrations/external/decord.py b/aana/integrations/external/decord.py index 991db2cb..424b1dc6 100644 --- a/aana/integrations/external/decord.py +++ b/aana/integrations/external/decord.py @@ -176,13 +176,13 @@ def generate_frames( batch_frames_array = video_reader.get_batch(batch).asnumpy() batch_frames = [] for frame_id, frame in enumerate(batch_frames_array): - img = Image(numpy=frame, media_id=f"{video.media_id}_frame_{frame_id}") + img = Image(numpy=frame, media_id=f"{video.media_id}_frame_{i+frame_id}") batch_frames.append(img) batch_timestamps = timestamps[i : i + batch_size] yield FramesDict( frames=batch_frames, - frame_ids=list(range(len(batch_frames))), + frame_ids=list(range(i, i + len(batch_frames))), timestamps=batch_timestamps, duration=duration, ) diff --git a/aana/tests/units/test_frame_extraction.py b/aana/tests/units/test_frame_extraction.py index 8d3835f5..bc78355b 100644 --- a/aana/tests/units/test_frame_extraction.py +++ b/aana/tests/units/test_frame_extraction.py @@ -41,6 +41,7 @@ def test_extract_frames_success( assert isinstance(result["frames"][0], Image) assert result["duration"] == expected_duration assert len(result["frames"]) == expected_num_frames + assert result["frame_ids"] == list(range(expected_num_frames)) assert len(result["timestamps"]) == expected_num_frames @@ -93,6 +94,8 @@ def test_generate_frames_success( params = VideoParams(extract_fps=extract_fps, fast_mode_enabled=fast_mode_enabled) gen_frame = generate_frames(video=video, params=params, batch_size=1) total_frames = 0 + frame_ids = [] + frames_media_ids = [] for result in gen_frame: assert "frames" in result assert "timestamps" in result @@ -107,7 +110,13 @@ def test_generate_frames_success( assert len(result["timestamps"]) == 1 # batch_size = 1 total_frames += 1 assert result["duration"] == expected_duration + frame_ids.extend(result["frame_ids"]) + frames_media_ids.extend([frame.media_id for frame in result["frames"]]) + assert frame_ids == list(range(expected_num_frames)) + assert frames_media_ids == [ + f"{video.media_id}_frame_{frame_id}" for frame_id in range(expected_num_frames) + ] assert total_frames == expected_num_frames