Skip to content

Commit

Permalink
Merge pull request #468 from EvolvingLMMs-Lab/pufanyi/internvl2-numframe
Browse files (browse the repository at this point in the history)
[Fix] InternVL2 Num Frame
  • Loading branch information
KairuiHu authored Dec 20, 2024
2 parents 32b0753 + c61b430 commit d9b59bb
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions lmms_eval/models/internvl2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
return frame_indices


-def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+def load_video(video_path, bound=None, input_size=448, max_num=32, num_segments=32):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
Expand Down Expand Up @@ -134,13 +134,17 @@ def __init__(
device: str = "cuda:0",
device_map: str = "cuda:0",
batch_size: str = "1",
+num_frame: int = 32,
+num_segments: int = 32,
**kwargs,
):
super().__init__()

self.path = pretrained
self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
+self.num_frame = num_frame
+self.num_segments = num_segments

batch_size = int(batch_size)
assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}."
Expand Down Expand Up @@ -269,7 +273,7 @@ def generate_until(self, requests) -> List[str]:
elif self.modality == "video":
assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
video_path = visuals[0]
-pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
+pixel_values, num_patches_list = load_video(video_path, num_segments=self.num_segments, max_num=self.num_frame)
pixel_values = pixel_values.to(torch.bfloat16).cuda()
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
question = video_prefix + contexts
Expand Down

0 comments on commit d9b59bb

Please sign in to comment.