From c61b430bd7bdbb95ecbc6f9cfd203ce57674a28e Mon Sep 17 00:00:00 2001
From: Pu Fanyi
Date: Fri, 20 Dec 2024 15:08:21 +0800
Subject: [PATCH] update load_video function parameters for improved
 flexibility in frame and segment handling

---
 lmms_eval/models/internvl2.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/lmms_eval/models/internvl2.py b/lmms_eval/models/internvl2.py
index ae4cc0c8..b258ad67 100644
--- a/lmms_eval/models/internvl2.py
+++ b/lmms_eval/models/internvl2.py
@@ -100,7 +100,7 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
     return frame_indices
 
 
-def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+def load_video(video_path, bound=None, input_size=448, max_num=32, num_segments=32):
     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
     max_frame = len(vr) - 1
     fps = float(vr.get_avg_fps())
@@ -134,6 +134,8 @@ def __init__(
         device: str = "cuda:0",
         device_map: str = "cuda:0",
         batch_size: str = "1",
+        num_frame: int = 32,
+        num_segments: int = 32,
         **kwargs,
     ):
         super().__init__()
@@ -141,6 +143,8 @@ def __init__(
         self.path = pretrained
         self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
         self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
+        self.num_frame = num_frame
+        self.num_segments = num_segments
         batch_size = int(batch_size)
         assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}."
 
@@ -269,7 +273,7 @@ def generate_until(self, requests) -> List[str]:
             elif self.modality == "video":
                 assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
                 video_path = visuals[0]
-                pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
+                pixel_values, num_patches_list = load_video(video_path, num_segments=self.num_segments, max_num=self.num_frame)
                 pixel_values = pixel_values.to(torch.bfloat16).cuda()
                 video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
                 question = video_prefix + contexts