diff --git a/lmms_eval/filters/extraction.py b/lmms_eval/filters/extraction.py index 392e21ad..9dbc212d 100755 --- a/lmms_eval/filters/extraction.py +++ b/lmms_eval/filters/extraction.py @@ -1,7 +1,10 @@ +import os import re import sys import unicodedata +import openai + from lmms_eval.api.filter import Filter diff --git a/lmms_eval/models/internvl2.py b/lmms_eval/models/internvl2.py index 5de4ce1f..ae4cc0c8 100644 --- a/lmms_eval/models/internvl2.py +++ b/lmms_eval/models/internvl2.py @@ -139,8 +139,8 @@ def __init__( super().__init__() self.path = pretrained - self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True).eval().cuda() - self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True) + self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval() + self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map) batch_size = int(batch_size) assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}." diff --git a/lmms_eval/models/llama_vision.py b/lmms_eval/models/llama_vision.py index 2051dd2c..c67f7965 100644 --- a/lmms_eval/models/llama_vision.py +++ b/lmms_eval/models/llama_vision.py @@ -15,6 +15,7 @@ from lmms_eval.api.instance import Instance from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model +from lmms_eval.models.model_utils.load_video import read_video_pyav_pil warnings.filterwarnings("ignore") @@ -25,22 +26,6 @@ @register_model("llama_vision") class LlamaVision(lmms): - """ - Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava - - Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py - - Example usage: - - accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \ - --model llava_hf \ - --model_args pretrained=llava-hf/llava-1.5-7b-hf \ - --tasks seedbench \ - --batch_size 1 \ - --output_path ./logs/ \ - --log_samples - """ - def __init__( self, pretrained: str = "meta-llama/Llama-3.2-11B-Vision", @@ -48,10 +33,12 @@ def __init__( device: str = "cuda", dtype: Optional[Union[str, torch.dtype]] = "auto", batch_size: int = 1, - trust_remote_code: Optional[bool] = False, + trust_remote_code: Optional[bool] = True, attn_implementation: Optional[str] = None, device_map: str = "", max_frames_num: Optional[int] = 32, + fps: Optional[int] = None, + max_image_size: Optional[int] = None, **kwargs, ) -> None: super().__init__() @@ -68,7 +55,9 @@ def __init__( if isinstance(dtype, str) and dtype != "auto": dtype = getattr(torch, dtype) + self.fps = fps self.max_frames_num = max_frames_num + self.max_image_size = max_image_size self._model = MllamaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation) self.model.eval() self.processor = AutoProcessor.from_pretrained(pretrained) @@ -193,9 +182,11 @@ def generate_until(self, requests: List[Instance]) -> List[str]: for visual in visuals: if isinstance(visual, str): - frames = self.load_video(visual, self.max_frames_num) - frames = torch.from_numpy(frames).permute(0, 3, 1, 2) - images.extend([to_pil_image(frame) for frame in frames]) + frames = read_video_pyav_pil(visual, num_frm=self.max_frames_num, fps=self.fps, max_image_size=self.max_image_size) + images.extend(frames) + # frames = self.load_video(visual, self.max_frames_num) + # frames = torch.from_numpy(frames).permute(0, 3, 1, 2) + # images.extend([to_pil_image(frame) for frame in frames]) elif isinstance(visual, PIL.Image.Image): images.append(visual) diff --git a/lmms_eval/models/llava_vid.py b/lmms_eval/models/llava_vid.py index fd3e9ae1..6c430325 100755 --- a/lmms_eval/models/llava_vid.py +++ b/lmms_eval/models/llava_vid.py @@ -90,7 +90,7 @@ def __init__( conv_template="vicuna_v1", use_cache=True, truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6 - max_frames_num: int = 3, + max_frames_num: int = 20, video_fps: int = 1, mm_resampler_type: str = "spatial_pool", mm_spatial_pool_stride: int = 2, diff --git a/lmms_eval/tasks/__init__.py b/lmms_eval/tasks/__init__.py index 3749248e..85537d0c 100755 --- a/lmms_eval/tasks/__init__.py +++ b/lmms_eval/tasks/__init__.py @@ -417,6 +417,8 @@ def _get_task_and_group(self, task_dir: str): "yaml_path": yaml_path, } elif self._config_is_group(config): + if f.endswith("mix_evals_image2text.yaml"): + print(config) # This is a group config tasks_and_groups[config["group"]] = { "type": "group", @@ -477,6 +479,7 @@ def _get_task_and_group(self, task_dir: str): else: self.logger.debug(f"File {f} in {root} could not be loaded as a task or group") + print(tasks_and_groups["mix_evals_image2text"]) return tasks_and_groups diff --git a/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2text.yaml b/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2text.yaml new file mode 100644 index 00000000..85b23377 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/audio2text/mix_evals_audio2text.yaml @@ -0,0 +1,3 @@ +group: mix_evals_audio2text +task: +- mix_evals_audio2_text_freeform diff --git a/lmms_eval/tasks/mix_evals/image2text/_default_template_yaml b/lmms_eval/tasks/mix_evals/image2text/_default_template_yaml new file mode 100644 index 00000000..ee3858f9 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/_default_template_yaml @@ -0,0 +1,13 @@ +dataset_path: MixEval/MixEval-X +dataset_kwargs: + video: true # a bit confusing, but this is because the official uses path to store image data, so we need to load it as a video dataset + cache_dir: mix_evals_image2text +lmms_eval_specific_kwargs: + default: + post_prompt: "" + pre_prompt: "" + gpt4v: + post_prompt: "" + pre_prompt: "" +metadata: + version: 0 diff --git a/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text.yaml b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text.yaml new file mode 100644 index 00000000..141c8c56 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text.yaml @@ -0,0 +1,4 @@ +group: mix_evals_image2text +task: +- mix_evals_image2text_mc +- mix_evals_image2text_freeform diff --git a/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_freeform.yaml b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_freeform.yaml new file mode 100644 index 00000000..e1e7cded --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_freeform.yaml @@ -0,0 +1,17 @@ +task: "mix_evals_image2text_freeform" +dataset_name: "image2text" +test_split: free_form +output_type: generate_until +doc_to_visual: !function utils.mix_evals_image2text_doc_to_visual +doc_to_text: !function utils.mix_evals_image2text_doc_to_text +doc_to_target: "{{reference_answer}}" +process_results: !function utils.mix_evals_image2text_process_results_freeform +metric_list: + - metric: gpt_eval + aggregation: !function utils.mix_evals_image2text_gpt_eval + higher_is_better: true + +generation_kwargs: + max_new_tokens: 1024 + +include: _default_template_yaml diff --git a/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_freeform_hard.yaml b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_freeform_hard.yaml new file mode 100644 index 00000000..24874364 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_freeform_hard.yaml @@ -0,0 +1,25 @@ +task: "mix_evals_image2text_freeform_hard" +dataset_name: "image2text" +test_split: free_form_hard +output_type: generate_until +doc_to_visual: !function utils.mix_evals_image2text_doc_to_visual +doc_to_text: !function utils.mix_evals_image2text_doc_to_text +doc_to_target: "{{reference_answer}}" +process_results: !function utils.mix_evals_image2text_process_results_freeform +metric_list: + - metric: gpt_eval + aggregation: !function utils.mix_evals_image2text_gpt_eval + higher_is_better: true + +generation_kwargs: + max_new_tokens: 1024 + +include: _default_template_yaml + +lmms_eval_specific_kwargs: + default: + pre_prompt: "Please answer the following questions about the image." + post_prompt: "" + gpt4v: + pre_prompt: "Please answer the following questions about the image." + post_prompt: "" diff --git a/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_hard.yaml b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_hard.yaml new file mode 100644 index 00000000..77d7f845 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_hard.yaml @@ -0,0 +1,5 @@ +group: mix_evals_image2text_hard +task: +- mix_evals_image2text_mc_hard +- mix_evals_image2text_freeform_hard +# - mix_evals_image2text_openended \ No newline at end of file diff --git a/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_mc.yaml b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_mc.yaml new file mode 100644 index 00000000..1100b539 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_mc.yaml @@ -0,0 +1,23 @@ +include: _default_template_yaml +task: "mix_evals_image2text_mc" +dataset_name: "image2text" +test_split: multiple_choice +output_type: generate_until +doc_to_visual: !function utils.mix_evals_image2text_doc_to_visual +doc_to_text: !function utils.mix_evals_image2text_doc_to_text +doc_to_target: "{{reference_answer}}" + +generation_kwargs: + max_new_tokens: 1024 + +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true + +filter_list: + - name: "flexible-extract" + filter: + - function: !function utils.GPTMultiChoiceFilter diff --git a/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_mc_hard.yaml b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_mc_hard.yaml new file mode 100644 index 00000000..8fd90184 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/mix_evals_image2text_mc_hard.yaml @@ -0,0 +1,23 @@ +include: _default_template_yaml +task: "mix_evals_image2text_mc_hard" +dataset_name: "image2text" +test_split: multiple_choice_hard +output_type: generate_until +doc_to_visual: !function utils.mix_evals_image2text_doc_to_visual +doc_to_text: !function utils.mix_evals_image2text_doc_to_text +doc_to_target: "{{reference_answer}}" + +generation_kwargs: + max_new_tokens: 1024 + +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true + +filter_list: + - name: "flexible-extract" + filter: + - function: !function utils.GPTMultiChoiceFilter diff --git a/lmms_eval/tasks/mix_evals/image2text/utils.py b/lmms_eval/tasks/mix_evals/image2text/utils.py new file mode 100644 index 00000000..50b181ae --- /dev/null +++ b/lmms_eval/tasks/mix_evals/image2text/utils.py @@ -0,0 +1,401 @@ +import ast +import datetime +import json +import os +import random +import re +import sys +import time +from pathlib import Path + +import openai +import yaml +from loguru import logger as eval_logger +from PIL import Image + +import lmms_eval.tasks._task_utils.file_utils as file_utils +from lmms_eval.filters import Filter + +with open(Path(__file__).parent / "_default_template_yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + + config = yaml.safe_load("".join(safe_data)) + +NUM_SECONDS_TO_SLEEP = 5 +API_TYPE = os.getenv("API_TYPE", "openai") +MODEL_VERSION = "gpt-3.5-turbo-0125" +MAX_NEW_TOKENS = 999 + +if API_TYPE == "openai": + client = openai.OpenAI() +elif API_TYPE == "azure": + if "AZURE_ENDPOINT" in os.environ: + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + else: + API_URL = os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + if "AZURE_OPENAI_API_KEY" in os.environ: + API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "YOUR_API_KEY") + else: + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + client = openai.AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL) + + +image2text_gpt_judge_for_closeended_freeform = lambda prompt, gold_ans, response: [ + {"role": "system", "content": f"In this task, I want you to act as a judge."}, + { + "role": "user", + "content": f"""You will be provided with a question, its golden answer(s), and the model's answer, while the context of the question, which is one or more images, is not given here. Your task is to judge how correct the model's answer is based on the golden answer(s), without seeing the input images of the question, and then give a correctness score. The correctness score should be one of the below numbers: 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Your should first briefly give your reasoning process regarding how the model's answer conforms to or contradicts the golden answer(s), and then give the correctness score. The correctness score must strictly follow this format: \"[[score]]\", e.g., \"The correctness score: [[0.5]]\". Below are some examples. + +Example 1: +Question: what is this advertising? +Golden Answer(s): garden annual; seeds; seeds; seeds; seeds; seeds; seeds; seeds; seeds; cole's garden annual +Model's Answer: Seed +Your Judgment: The golden answers consistently mention "seeds" suggesting an advertisement for a seed catalog. The model's answer, "Seed", aligns exactly with this description. The Correctness Score: [[1.0]] + +Example 2: +Question: Who is making a face? +Golden Answer: child +Model's Answer: A man. +Your Judgment: The golden answer specifies a "child" making a face, but the model answered "A man", which is incorrect as it refers to a different age group. The Correctness Score: [[0.0]] + +Example 3: +Question: what road is to the right? +Golden Answer: troublesome valley rd; troublesome valley rd.; troublesome valley; troublesome valley road; valley road; troublesome valley; troublesome valley road; troublesome valley ; troublesome valley rd; troublesome valley rd. +Model's Answer: troublesome road +Your Judgment: The golden answers all specify the name of the road as "troublesome valley rd" or variations of this phrase with consistent reference to "troublesome valley." The model's answer, "troublesome road," captures the "troublesome" aspect but omits the critical "valley" part of the name, which is crucial for full accuracy. Thus, the model's answer partially matches the golden answer but lacks complete specificity. The Correctness Score: [[0.6]] + +Note that each one of the golden answers is considered correct. Thus if the model's answer matches any one of the golden answers, it should be considered correct. Judge the below case, give the brief reasoning process and the correctness score. + +Question: {prompt} +Golden Answer(s): {gold_ans} +Model's Answer: {response} +Your Judgment: +""", + }, +] + +image2text_gpt_judge_for_closeended_multiplechoice = lambda prompt, options, response: [ + {"role": "system", "content": f"In this task, I want you to act as an option extractor."}, + { + "role": "user", + "content": f"""You will be provided with a multiple-choice question, its options, and the model's answer, while the context of the question, which is one or more images, is not given here. Your task is to extract or judge which option is chosen by the model based on its response, without seeing the context of the question. The extracted option should be one of the provided option letters. Your should first briefly give your reasoning process, and then give the extracted option letter. The extracted option must strictly follow this format: \"[[option letter]]\", e.g., \"The option chosen by the model: [[A]]\". +Below are some examples. + +Example 1: +Question: Where are the cast of the television show located in the image? +Options: +A. In the foreground +B. In the background +C. In the center +D. At the edges +Model's Answer: C. In the center +Your Judgment: The model's answer clearly states "C. In the center", indicating that the correct option, according to the model, is in the center. The option chosen by the model: [[C]]. + +Example 2: +Question: on the left was painted during the +Options: +A. first or second century C. E. +B. sixth or seventh century C. E. +C. tenth or eleventh century C.E. +D. fourteenth or fifteenth century C. E. +Model's Answer: The correct answer is option D, the fourteenth or fifteenth century C.E. +Your Judgment: The model's response specifies "option D, the fourteenth or fifteenth century C.E." directly as the correct answer. The option chosen by the model: [[D]]. + +Example 3: +Question: what does the diagram show's you information about +Options: +A. Photosynthesis +B. The plant getting fed +C. A picture of the plant +D. What happens to a plant daily +Model's Answer: The diagram shows the process of photosynthesis, which is the process by which plants convert sunlight, carbon dioxide, and water into oxygen and glucose. +Your Judgment: The model's answer mentions "the process of photosynthesis," which directly corresponds to option A, "Photosynthesis". Therefore, the correct option according to the model is photosynthesis. The option chosen by the model: [[A]]. + +Give the brief reasoning process and the extracted option for the below case: + +Question: {prompt} +Options: +{options} +Model's Answer: {response} +Your Judgment: +""", + }, +] + + +def get_score_from_judge(judge_response): + """ + Get the score from the judge response. + """ + one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") + one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + + match = re.search(one_score_pattern, judge_response) + if not match: + match = re.search(one_score_pattern_backup, judge_response) + + if match: + rating = ast.literal_eval(match.groups()[0]) + else: + rating = round(random.random(), 1) + + return float(rating) + + +def get_eval(question, model_response: str, ground_truth: str, max_tokens: int, retries: int = 5): + global client + messages = image2text_gpt_judge_for_closeended_freeform(prompt=question, gold_ans=ground_truth, response=model_response) + + payload = { + "model": MODEL_VERSION, + "messages": messages, + # "temperature": 0.2, + "max_tokens": max_tokens, + } + + for attempt in range(retries): + try: + # response = requests.post(API_URL, headers=headers, json=payload, timeout=60) + response = client.chat.completions.create(**payload) + # response.raise_for_status() + response_data = response.json() + + # content = response_data["choices"][0]["message"]["content"].strip() + content = response.choices[0].message.content.strip() + if content != "": + return content + break # If successful, break out of the loop + + except Exception as e: + eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}") + if attempt < retries: # If we have retries left, sleep and then continue to next attempt + time.sleep(NUM_SECONDS_TO_SLEEP) + else: # If this was the last attempt, log and return empty + eval_logger.error(f"All {retries} attempts failed. Last error message: {e}") + return "[[0.0]]" + return "[[0.0]]" + + +# A bit ugly here +# But the idea is that we will unzip all the zip files +# To HF HOME cache dir +# And load it here +HF_HOME = os.environ["HF_HOME"] +cache_dir = config["dataset_kwargs"]["cache_dir"] +cache_dir = os.path.join(HF_HOME, cache_dir) +cache_dir = os.path.join(cache_dir) + + +def mix_evals_image2text_doc_to_visual(doc): + visual = [] + for image_path in doc["input_file"]: + image_path = os.path.join(cache_dir, image_path) + if os.path.exists(image_path): + image_path = image_path + + visual.append(Image.open(image_path).convert("RGB")) + + return visual + + +# This is the place where you format your question +def mix_evals_image2text_doc_to_text(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + user_prompt = doc["query"] + + if "options" in doc and len(doc["options"]) > 1: + option_prompt = "Here are the options:\n" + for idx, option in enumerate(doc["options"]): + char_idx = chr(ord("A") + idx) + option = option.strip() + option_prompt += f"{char_idx}. {option}\n" + + option_prompt = option_prompt.rstrip("\n") + user_prompt = f"{user_prompt}\n{option_prompt}" + + if pre_prompt: + user_prompt = f"{pre_prompt}\n{user_prompt}" + + if post_prompt: + user_prompt = f"{user_prompt}\n{post_prompt}" + return user_prompt + + +OPEN_CONVS_PROMPT = """{PRE} +{FIRST} +{POST} +""" + + +def mix_evals_image2text_doc_to_text_open_convs(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + filtered_first_turn = re.sub(r"", "", doc["first_turn_user_prompt"]) + return OPEN_CONVS_PROMPT.format( + PRE=pre_prompt, + POST=post_prompt, + FIRST=filtered_first_turn, + ) + + +MODEL_CONVS_PROMPT = """{FIRST} +{MODEL_RESPONSE} +{PRE} +{SECOND} +{POST} +""" + + +def mix_evals_image2text_doc_to_text_open_2nd_convs(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + return MODEL_CONVS_PROMPT.format( + PRE=pre_prompt, + POST=post_prompt, + FIRST=doc["first_turn_user_prompt"], + SECOND=doc["second_turn_user_prompt"], + MODEL_RESPONSE=doc["model_response"], + ) + + +def mix_evals_image2text_process_results_open_convs(doc, result): + pred = result[0] + return {"submission": {"pred": pred, "question_idx": doc["question_index"], "first_turn_video_caption": doc["first_turn_video_caption"], "target": ""}} + + +def mix_evals_image2text_process_results_freeform(doc, result): + pred = result[0] + ground_truth_str = ", ".join([f'"{gt}"' for gt in doc["reference_answer"]]) + ground_truth_str = f"[{ground_truth_str}]" + content = image2text_gpt_judge_for_closeended_freeform(response=pred, gold_ans=ground_truth_str, prompt=doc["query"]) + eval_answer = get_eval(model_response=pred, ground_truth=ground_truth_str, max_tokens=MAX_NEW_TOKENS, question=doc["query"]) + return { + "submission": {"pred": pred, "question_idx": doc["id"], "target": doc["reference_answer"], "eval_answer": eval_answer, "gpt_prompt": content}, + "gpt_eval": {"pred": pred, "question_idx": doc["id"], "target": doc["reference_answer"], "eval_answer": eval_answer, "gpt_prompt": content}, + } + + +def mix_evals_image2text_aggregate_submissions(results, args, task): + now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + submission_file_name = f"mix_evals_image2text_{task}-{now_date_time}.json" + path = file_utils.generate_submission_file(submission_file_name, args) + with open(path, "w") as f: + json.dump(results, f) + eval_logger.info(f"Submission file saved to {path}") + + +def mix_evals_image2text_gpt_eval(results, args): + score = 0 + for result in results: + eval_answer = result["eval_answer"] + eval_score = get_score_from_judge(eval_answer) + score += eval_score + + return score / len(results) + + +# Factory into different aggregate +def mix_evals_image2text_aggregate_gen(results, args): + mix_evals_image2text_aggregate_submissions(results, args, "OpenConvs") + + +class GPTMultiChoiceFilter(Filter): + def __init__(self, gpt_version: str = "gpt-3.5-turbo-0125", retries: int = 5): + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + self.gpt_version = gpt_version + + if API_TYPE == "openai": + self.client = openai.OpenAI(api_key=API_KEY) + elif API_TYPE == "azure": + self.client = openai.AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL) + + self.retries = retries + + def apply(self, resps, docs): + """ + Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. + Should return the list of (filtered) response lists *in the same order as they were input*, e.g. + if pass in [, ] should return + [, ] + """ + results = [] + for response, doc in zip(resps, docs): + query = doc["query"] + options = "\n".join([f"{chr(ord('A') + idx)}. {option}" for idx, option in enumerate(doc["options"])]) + message = image2text_gpt_judge_for_closeended_multiplechoice(prompt=query, options=options, response=response) + payload = { + "model": self.gpt_version, + "messages": message, + "max_tokens": MAX_NEW_TOKENS, + } + result = 0 + for attempt in range(self.retries): + try: + # response = requests.post(API_URL, headers=headers, json=payload, timeout=60) + # print(payload) + response = client.chat.completions.create(**payload) + # print(response) + # response.raise_for_status() + + # content =["choices"][0]["message"]["content"].strip() + content = response.choices[0].message.content + # print("content:", content) + if content: + match = re.search(r"\[\[([A-Z])\]\]", content) + # print("match:", match) + if not match: + match = re.search(r"r'\b([A-Z])\.?\b'", content) + # print("match:", match) + if match: + # print("=====") + # print(match.group(1)) + result = ord(match.group(1)) - ord("A") + # print("result:", result) + # print("=====") + # print(content, result) + else: + result = 0 + break # If successful, break out of the loop + + except Exception as e: + eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}") + import traceback + + print(traceback.format_exc()) + if attempt < self.retries: # If we have retries left, sleep and then continue to next attempt + time.sleep(NUM_SECONDS_TO_SLEEP) + else: # If this was the last attempt, log and return empty + eval_logger.error(f"All {self.retries} attempts failed. Last error message: {e}") + result = 0 + break + results.append(str(result)) + return results diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml b/lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml deleted file mode 100644 index e49612a8..00000000 --- a/lmms_eval/tasks/mix_evals/mix_evals_video2text.yaml +++ /dev/null @@ -1,5 +0,0 @@ -group: mix_evals_video2text -task: -# - mix_evals_video2text_openconv -- mix_evals_video2text_mc -- mix_evals_video2text_freeform \ No newline at end of file diff --git a/lmms_eval/tasks/mix_evals/utils.py b/lmms_eval/tasks/mix_evals/utils.py deleted file mode 100644 index cd1d8e5b..00000000 --- a/lmms_eval/tasks/mix_evals/utils.py +++ /dev/null @@ -1,286 +0,0 @@ -import datetime -import json -import os -import re -import sys -import time -from pathlib import Path - -import requests -import yaml -from loguru import logger as eval_logger - -import lmms_eval.tasks._task_utils.file_utils as file_utils -from lmms_eval.filters.extraction import ExtendedRegexFilter - -with open(Path(__file__).parent / "_default_template_yaml", "r") as f: - raw_data = f.readlines() - safe_data = [] - for i, line in enumerate(raw_data): - # remove function definition since yaml load cannot handle it - if "!function" not in line: - safe_data.append(line) - - config = yaml.safe_load("".join(safe_data)) - -NUM_SECONDS_TO_SLEEP = 5 -GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"] -API_TYPE = os.getenv("API_TYPE", "openai") - -if API_TYPE == "openai": - API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") - API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") - headers = { - "Authorization": f"Bearer {API_KEY}", - "Content-Type": "application/json", - } -elif API_TYPE == "azure": - API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") - API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") - headers = { - "api-key": API_KEY, - "Content-Type": "application/json", - } - -eval_prompt = """You are an AI assistant who will help me to evaluate the quality of a model response to a few candidate ground truth answers. - -Some criterion -- Response that perfectly reflect the meaning of the ground truth: 1 point -- Response that reflect none of the key points in the ground truth: 0 point -- Some part in the response are correct but some parts in the ground truth are not mentioned in the response: 0.5 point -- Some part in the response are correct but other parts in the response are not mentioned in the ground truth: 0.5 point - -Here're some examples about the scoring criterion and format: -model response: Steam Cleaning Services -ground truth: ["steam clean", "steam clean", "cleaning", "car", "steam clean"], -Point: 1 - -model response: A cowboy action shooter. -ground truth: ["man"] -Point: 1 - -model response: I'm sorry, but I can't assist with that request. -ground truth: ["quality"] -Point: 0 - -Let's begin this task: -model response: {model_response} -ground truth: {ground_truth} -Point:""" - - -def get_eval(model_response: str, ground_truth: str, max_tokens: int, retries: int = 5): - global headers - content = eval_prompt.format(model_response=model_response, ground_truth=ground_truth) - - messages = [ - {"role": "user", "content": content}, - ] - - payload = { - "model": GPT_EVAL_MODEL_NAME, - "messages": messages, - "temperature": 0.2, - "max_tokens": max_tokens, - } - - for attempt in range(retries): - try: - response = requests.post(API_URL, headers=headers, json=payload, timeout=60) - response.raise_for_status() - response_data = response.json() - - content = response_data["choices"][0]["message"]["content"].strip() - if content != "": - return content, response_data["model"] - break # If successful, break out of the loop - - except Exception as e: - eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}") - if attempt < retries: # If we have retries left, sleep and then continue to next attempt - time.sleep(NUM_SECONDS_TO_SLEEP) - else: # If this was the last attempt, log and return empty - eval_logger.error(f"All {retries} attempts failed. Last error message: {e}") - return "", "" - return "", "" - - -# A bit ugly here -# But the idea is that we will unzip all the zip files -# To HF HOME cache dir -# And load it here -HF_HOME = os.environ["HF_HOME"] -cache_dir = config["dataset_kwargs"]["cache_dir"] -cache_dir = os.path.join(HF_HOME, cache_dir) -cache_dir = os.path.join(cache_dir) - - -# Pass in video path here -# Can only work correctly with video llm -def mix_evals_video2text_doc_to_visual(doc): - video_path = doc["video_path"] - video_path = os.path.join(cache_dir, video_path) - if os.path.exists(video_path): - video_path = video_path - elif os.path.exists(video_path.replace("mp4", "MP4")): - video_path = video_path.replace("mp4", "MP4") - else: - sys.exit(f"video path:{video_path} does not exist, please check") - return [video_path] - - -# This is the place where you format your question -def mix_evals_video2text_doc_to_text(doc, lmms_eval_specific_kwargs=None): - if lmms_eval_specific_kwargs is None: - lmms_eval_specific_kwargs = {} - pre_prompt = "" - post_prompt = "" - if "pre_prompt" in lmms_eval_specific_kwargs: - pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] - if "post_prompt" in lmms_eval_specific_kwargs: - post_prompt = lmms_eval_specific_kwargs["post_prompt"] - - user_prompt = doc["prompt"] - - if "options" in doc: - option_prompt = "Here are the options:\n" - for idx, option in enumerate(doc["options"]): - char_idx = chr(ord("A") + idx) - option = option.strip() - option_prompt += f"{char_idx}. {option}\n" - - option_prompt = option_prompt.rstrip("\n") - user_prompt = f"{user_prompt}\n{option_prompt}" - - if pre_prompt: - user_prompt = f"{pre_prompt}\n{user_prompt}" - - if post_prompt: - user_prompt = f"{user_prompt}\n{post_prompt}" - return user_prompt - - -OPEN_CONVS_PROMPT = """{PRE} -{FIRST} -{POST} -""" - - -def mix_evals_video2text_doc_to_text_open_convs(doc, lmms_eval_specific_kwargs=None): - if lmms_eval_specific_kwargs is None: - lmms_eval_specific_kwargs = {} - pre_prompt = "" - post_prompt = "" - if "pre_prompt" in lmms_eval_specific_kwargs: - pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] - if "post_prompt" in lmms_eval_specific_kwargs: - post_prompt = lmms_eval_specific_kwargs["post_prompt"] - - filtered_first_turn = re.sub(r"", "", doc["first_turn_user_prompt"]) - return OPEN_CONVS_PROMPT.format( - PRE=pre_prompt, - POST=post_prompt, - FIRST=filtered_first_turn, - ) - - -MODEL_CONVS_PROMPT = """{FIRST} -{MODEL_RESPONSE} -{PRE} -{SECOND} -{POST} -""" - - -def mix_evals_video2text_doc_to_text_open_2nd_convs(doc, lmms_eval_specific_kwargs=None): - if lmms_eval_specific_kwargs is None: - lmms_eval_specific_kwargs = {} - pre_prompt = "" - post_prompt = "" - if "pre_prompt" in lmms_eval_specific_kwargs: - pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] - if "post_prompt" in lmms_eval_specific_kwargs: - post_prompt = lmms_eval_specific_kwargs["post_prompt"] - - return MODEL_CONVS_PROMPT.format( - PRE=pre_prompt, - POST=post_prompt, - FIRST=doc["first_turn_user_prompt"], - SECOND=doc["second_turn_user_prompt"], - MODEL_RESPONSE=doc["model_response"], - ) - - -def mix_evals_video2text_process_results_open_convs(doc, result): - pred = result[0] - return {"submission": {"pred": pred, "question_idx": doc["question_index"], "first_turn_video_caption": doc["first_turn_video_caption"], "target": ""}} - - -def mix_evals_video2text_process_results_freeform(doc, result): - pred = result[0] - ground_truth_str = ", ".join([f'"{gt}"' for gt in doc["target"]]) - ground_truth_str = f"[{ground_truth_str}]" - content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str) - eval_answer, model_name = get_eval(model_response=pred, ground_truth=ground_truth_str, max_tokens=1024) - return { - "submission": {"pred": pred, "question_idx": doc["question_index"], "target": doc["target"], "eval_answer": eval_answer, "gpt_prompt": content}, - "gpt_eval": {"pred": pred, "question_idx": doc["question_index"], "target": doc["target"], "eval_answer": eval_answer, "gpt_prompt": content}, - } - - -def mix_evals_video2text_aggregate_submissions(results, args, task): - now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - submission_file_name = f"mix_evals_video2text_{task}-{now_date_time}.json" - path = file_utils.generate_submission_file(submission_file_name, args) - with open(path, "w") as f: - json.dump(results, f) - eval_logger.info(f"Submission file saved to {path}") - - -def mix_evals_video2text_gpt_eval(results, args): - score = 0 - for result in results: - eval_answer = result["eval_answer"] - eval_score = re.search(r"([0-9.]+)", eval_answer).group(1) - try: - eval_score = float(eval_score) - except Exception as e: - eval_logger.error(f"Error parsing eval_score: {e}") - eval_score = 0.0 - score += eval_score - - return score / len(results) - - -# Factory into different aggregate -def mix_evals_video2text_aggregate_gen(results, args): - mix_evals_video2text_aggregate_submissions(results, args, "OpenConvs") - - -class MultiChoiceRegexFilter(ExtendedRegexFilter): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def apply(self, resps, docs): - filtered_resps = [] - - for r, doc in zip(resps, docs): - # Regex to directly extract the option letter from the model response - option_letter_regex = re.compile(r"\b([A-Z])\.\s+([^\n]*)") - - # Process each response - filtered = [] - for resp in r: - # Try to match the option letter at the start of the response - match = option_letter_regex.match(resp) - if match: - # If a match is found, append the matched letter - filtered.append(match.group(1)) - else: - # If no match, return the original response - filtered.append(resp) - - # Assuming we need the first response that matches or the original response - filtered_resps.append(filtered[0]) - - return filtered_resps diff --git a/lmms_eval/tasks/mix_evals/_default_template_yaml b/lmms_eval/tasks/mix_evals/video2text/_default_template_yaml similarity index 84% rename from lmms_eval/tasks/mix_evals/_default_template_yaml rename to lmms_eval/tasks/mix_evals/video2text/_default_template_yaml index bda3f8e8..4c84f1a2 100644 --- a/lmms_eval/tasks/mix_evals/_default_template_yaml +++ b/lmms_eval/tasks/mix_evals/video2text/_default_template_yaml @@ -2,7 +2,7 @@ dataset_kwargs: cache_dir: mix_evals_video2text token: true video: true -dataset_path: lmms-lab/MixEvals_Video2Text +dataset_path: MixEval/MixEval-X lmms_eval_specific_kwargs: default: post_prompt: "" @@ -12,5 +12,4 @@ lmms_eval_specific_kwargs: pre_prompt: These are frames from a video. Please answer the following questions about the video. metadata: gpt_eval_model_name: gpt-4o-mini - modality: video version: 0 diff --git a/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text.yaml new file mode 100644 index 00000000..4f3e2c8a --- /dev/null +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text.yaml @@ -0,0 +1,5 @@ +group: mix_evals_video2text +task: +- mix_evals_video2text_mc +- mix_evals_video2text_freeform +# - mix_evals_video2text_openended \ No newline at end of file diff --git a/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_freeform.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_freeform.yaml new file mode 100644 index 00000000..35531b0f --- /dev/null +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_freeform.yaml @@ -0,0 +1,25 @@ +task: "mix_evals_video2text_freeform" +dataset_name: "video2text" +test_split: free_form +output_type: generate_until +doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual +doc_to_text: !function utils.mix_evals_video2text_doc_to_text +doc_to_target: "{{reference_answer}}" +process_results: !function utils.mix_evals_video2text_process_results_freeform +metric_list: + - metric: gpt_eval + aggregation: !function utils.mix_evals_video2text_gpt_eval + higher_is_better: true + +generation_kwargs: + max_new_tokens: 1024 + +include: _default_template_yaml + +lmms_eval_specific_kwargs: + default: + pre_prompt: "These are frames from a video. Please answer the following questions about the video." + post_prompt: "" + gpt4v: + pre_prompt: "These are frames from a video. Please answer the following questions about the video." + post_prompt: "" diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_freeform_hard.yaml similarity index 83% rename from lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml rename to lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_freeform_hard.yaml index e4495b50..059d2b28 100644 --- a/lmms_eval/tasks/mix_evals/mix_evals_video2text_freeform.yaml +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_freeform_hard.yaml @@ -1,10 +1,10 @@ -dataset_name: "video2text_closeended_free-form" -task: "mix_evals_video2text_freeform" -test_split: test +task: "mix_evals_video2text_freeform_hard" +dataset_name: "video2text" +test_split: free_form_hard output_type: generate_until doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual doc_to_text: !function utils.mix_evals_video2text_doc_to_text -doc_to_target: "{{target}}" +doc_to_target: "{{reference_answer}}" process_results: !function utils.mix_evals_video2text_process_results_freeform metric_list: - metric: gpt_eval @@ -12,7 +12,7 @@ metric_list: higher_is_better: true generation_kwargs: - max_new_tokens: 16 + max_new_tokens: 1024 include: _default_template_yaml diff --git a/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_hard.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_hard.yaml new file mode 100644 index 00000000..2817b420 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_hard.yaml @@ -0,0 +1,5 @@ +group: mix_evals_video2text_hard +task: +- mix_evals_video2text_mc_hard +- mix_evals_video2text_freeform_hard +# - mix_evals_video2text_openended \ No newline at end of file diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_mc.yaml similarity index 76% rename from lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml rename to lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_mc.yaml index fcca0731..c94a0c5a 100644 --- a/lmms_eval/tasks/mix_evals/mix_evals_video2text_mc.yaml +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_mc.yaml @@ -1,14 +1,14 @@ include: _default_template_yaml -dataset_name: "video2text_closeended_multiple-choice" task: "mix_evals_video2text_mc" -test_split: test +dataset_name: "video2text" +test_split: multiple_choice output_type: generate_until doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual doc_to_text: !function utils.mix_evals_video2text_doc_to_text -doc_to_target: "{{target}}" +doc_to_target: "{{reference_answer}}" generation_kwargs: - max_new_tokens: 5 + max_new_tokens: 1024 metric_list: - metric: exact_match @@ -20,10 +20,7 @@ metric_list: filter_list: - name: "flexible-extract" filter: - - function: !function utils.MultiChoiceRegexFilter - group_select: 0 - ignore_case: true - ignore_punctuation: true + - function: !function utils.GPTMultiChoiceFilter lmms_eval_specific_kwargs: default: diff --git a/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_mc_hard.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_mc_hard.yaml new file mode 100644 index 00000000..9cc3f2da --- /dev/null +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_mc_hard.yaml @@ -0,0 +1,31 @@ +include: _default_template_yaml +task: "mix_evals_video2text_mc_hard" +dataset_name: "video2text" +test_split: multiple_choice_hard +output_type: generate_until +doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual +doc_to_text: !function utils.mix_evals_video2text_doc_to_text +doc_to_target: "{{reference_answer}}" + +generation_kwargs: + max_new_tokens: 1024 + +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true + +filter_list: + - name: "flexible-extract" + filter: + - function: !function utils.GPTMultiChoiceFilter + +lmms_eval_specific_kwargs: + default: + pre_prompt: "These are frames from a video. Please answer the following questions about the video." + post_prompt: "Answer with the option's letter from the given choices directly." + gpt4v: + pre_prompt: "These are frames from a video. Please answer the following questions about the video." + post_prompt: "Answer with the option's letter from the given choices directly." diff --git a/lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_openended.yaml similarity index 88% rename from lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml rename to lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_openended.yaml index a62b2818..7d0baea8 100644 --- a/lmms_eval/tasks/mix_evals/mix_evals_video2text_openended.yaml +++ b/lmms_eval/tasks/mix_evals/video2text/mix_evals_video2text_openended.yaml @@ -1,7 +1,7 @@ include: _default_template_yaml -dataset_name: "video2text_openended" -task: "mix_evals_video2text_openconv" -test_split: test +dataset_name: "open_ended" +task: "mix_evals_video2text_openended" +test_split: video2text output_type: generate_until doc_to_visual: !function utils.mix_evals_video2text_doc_to_visual doc_to_text: !function utils.mix_evals_video2text_doc_to_text_open_convs diff --git a/lmms_eval/tasks/mix_evals/video2text/utils.py b/lmms_eval/tasks/mix_evals/video2text/utils.py new file mode 100644 index 00000000..d98f9a81 --- /dev/null +++ b/lmms_eval/tasks/mix_evals/video2text/utils.py @@ -0,0 +1,422 @@ +import ast +import datetime +import json +import os +import random +import re +import sys +import time +from pathlib import Path + +import openai +import requests +import yaml +from loguru import logger as eval_logger +from PIL import Image + +import lmms_eval.tasks._task_utils.file_utils as file_utils +from lmms_eval.filters import Filter + +with open(Path(__file__).parent / "_default_template_yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + + config = yaml.safe_load("".join(safe_data)) + +NUM_SECONDS_TO_SLEEP = 5 +API_TYPE = os.getenv("API_TYPE", "openai") +MODEL_VERSION = "gpt-3.5-turbo-0125" +MAX_NEW_TOKENS = 999 + +if API_TYPE == "openai": + client = openai.OpenAI() +elif API_TYPE == "azure": + if "AZURE_ENDPOINT" in os.environ: + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + else: + API_URL = os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + if "AZURE_OPENAI_API_KEY" in os.environ: + API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "YOUR_API_KEY") + else: + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + client = openai.AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL) + + +video2text_gpt_judge_for_closeended_freeform = lambda prompt, gold_ans, response: [ + {"role": "system", "content": f"In this task, I want you to act as a judge."}, + { + "role": "user", + "content": f"""You will be provided with a question, its golden answer(s), and the model's answer, while the context of the question, which is one or more videos, is not given here. Your task is to judge how correct the model's answer is based on the golden answer(s), without seeing the input videos of the question, and then give a correctness score. The correctness score should be one of the below numbers: 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Your should first briefly give your reasoning process regarding how the model's answer conforms to or contradicts the golden answer(s), and then give the correctness score. The correctness score must strictly follow this format: \"[[score]]\", e.g., \"The correctness score: [[0.5]]\". Below are some examples. + +Example 1: +Question: what does this video want to express +Golden Answer(s): introduce method of playing +Model's Answer: Volleyball serve \n +Your Judgment: The model's answer "Volleyball serve" suggests a specific action, which may be part of what the video demonstrates. However, it misses the broader educational intent implied by the golden answer "introduce method of playing". Therefore, the answer is partially correct. The Correctness Score: [[0.5]] + +Example 2: +Question: who do two other boys with surprised looks assist up? +Golden Answer(s): boy +Model's Answer: Boy. +Your Judgment: The model's answer "Boy." precisely matches the golden answer which states the two other boys assist a "boy". The Correctness Score: [[1.0]] + +Example 3: +Question: what did the lady do at the end of the video after their performance +Golden Answer(s): picks up her phone +Model's Answer: Nothing. +Your Judgment: The model's answer "Nothing." directly contradicts the golden answer which states that the lady "picks up her phone" at the end of the video after their performance. Since the model's response completely misses the specific action described in the golden answer, it is incorrect. The Correctness Score: [[0.0]] + +Note that each one of the golden answers is considered correct. Thus if the model's answer matches any one of the golden answers, it should be considered correct. Judge the below case, give the brief reasoning process and the correctness score. + +Question: {prompt} +Golden Answer(s): {gold_ans} +Model's Answer: {response} +Your Judgment: +""", + }, +] + + +def get_score_from_judge(judge_response): + """ + Get the score from the judge response. + """ + one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") + one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + + match = re.search(one_score_pattern, judge_response) + if not match: + match = re.search(one_score_pattern_backup, judge_response) + + if match: + rating = ast.literal_eval(match.groups()[0]) + else: + rating = round(random.random(), 1) + + return float(rating) + + +def get_eval(question, model_response: str, ground_truth: str, max_tokens: int, retries: int = 5): + global client + messages = video2text_gpt_judge_for_closeended_freeform(prompt=question, gold_ans=ground_truth, response=model_response) + + payload = { + "model": MODEL_VERSION, + "messages": messages, + # "temperature": 0.2, + "max_tokens": max_tokens, + } + + for attempt in range(retries): + try: + # response = requests.post(API_URL, headers=headers, json=payload, timeout=60) + response = client.chat.completions.create(**payload) + # response.raise_for_status() + response_data = response.json() + + # content = response_data["choices"][0]["message"]["content"].strip() + content = response.choices[0].message.content.strip() + if content != "": + return content + break # If successful, break out of the loop + + except Exception as e: + eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}") + if attempt < retries: # If we have retries left, sleep and then continue to next attempt + time.sleep(NUM_SECONDS_TO_SLEEP) + else: # If this was the last attempt, log and return empty + eval_logger.error(f"All {retries} attempts failed. Last error message: {e}") + return "[[0.0]]" + return "[[0.0]]" + + +# A bit ugly here +# But the idea is that we will unzip all the zip files +# To HF HOME cache dir +# And load it here +HF_HOME = os.environ["HF_HOME"] +cache_dir = config["dataset_kwargs"]["cache_dir"] +cache_dir = os.path.join(HF_HOME, cache_dir) +cache_dir = os.path.join(cache_dir) + + +def mix_evals_doc_to_visual(doc, modality): + visual = [] + for video_path in doc["input_file"]: + video_path = os.path.join(cache_dir, video_path) + if os.path.exists(video_path): + video_path = video_path + elif os.path.exists(video_path.replace("mp4", "MP4")): + video_path = video_path.replace("mp4", "MP4") + else: + sys.exit(f"video path:{video_path} does not exist, please check") + + if modality == "video": + visual.append(video_path) + elif modality == "image": + visual.append(Image.open(video_path).convert("RGB")) + else: + sys.exit(f"modality:{modality} is not supported, please check") + return visual + + +def mix_evals_video2text_doc_to_visual(doc): + return mix_evals_doc_to_visual(doc, "video") + + +def mix_evals_image2text_doc_to_visual(doc): + return mix_evals_doc_to_visual(doc, "image") + + +# This is the place where you format your question +def mix_evals_video2text_doc_to_text(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + user_prompt = doc["query"] + + if "options" in doc and len(doc["options"]) > 1: + option_prompt = "Here are the options:\n" + for idx, option in enumerate(doc["options"]): + char_idx = chr(ord("A") + idx) + option = option.strip() + option_prompt += f"{char_idx}. {option}\n" + + option_prompt = option_prompt.rstrip("\n") + user_prompt = f"{user_prompt}\n{option_prompt}" + + if pre_prompt: + user_prompt = f"{pre_prompt}\n{user_prompt}" + + if post_prompt: + user_prompt = f"{user_prompt}\n{post_prompt}" + return user_prompt + + +OPEN_CONVS_PROMPT = """{PRE} +{FIRST} +{POST} +""" + + +def mix_evals_video2text_doc_to_text_open_convs(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + filtered_first_turn = re.sub(r"", "", doc["first_turn_user_prompt"]) + return OPEN_CONVS_PROMPT.format( + PRE=pre_prompt, + POST=post_prompt, + FIRST=filtered_first_turn, + ) + + +MODEL_CONVS_PROMPT = """{FIRST} +{MODEL_RESPONSE} +{PRE} +{SECOND} +{POST} +""" + + +def mix_evals_video2text_doc_to_text_open_2nd_convs(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + return MODEL_CONVS_PROMPT.format( + PRE=pre_prompt, + POST=post_prompt, + FIRST=doc["first_turn_user_prompt"], + SECOND=doc["second_turn_user_prompt"], + MODEL_RESPONSE=doc["model_response"], + ) + + +def mix_evals_video2text_process_results_open_convs(doc, result): + pred = result[0] + return {"submission": {"pred": pred, "question_idx": doc["question_index"], "first_turn_video_caption": doc["first_turn_video_caption"], "target": ""}} + + +def mix_evals_video2text_process_results_freeform(doc, result): + pred = result[0] + ground_truth_str = ", ".join([f'"{gt}"' for gt in doc["reference_answer"]]) + ground_truth_str = f"[{ground_truth_str}]" + content = video2text_gpt_judge_for_closeended_freeform(response=pred, gold_ans=ground_truth_str, prompt=doc["query"]) + eval_answer = get_eval(model_response=pred, ground_truth=ground_truth_str, max_tokens=MAX_NEW_TOKENS, question=doc["query"]) + return { + "submission": {"pred": pred, "question_idx": doc["id"], "target": doc["reference_answer"], "eval_answer": eval_answer, "gpt_prompt": content}, + "gpt_eval": {"pred": pred, "question_idx": doc["id"], "target": doc["reference_answer"], "eval_answer": eval_answer, "gpt_prompt": content}, + } + + +def mix_evals_video2text_aggregate_submissions(results, args, task): + now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + submission_file_name = f"mix_evals_video2text_{task}-{now_date_time}.json" + path = file_utils.generate_submission_file(submission_file_name, args) + with open(path, "w") as f: + json.dump(results, f) + eval_logger.info(f"Submission file saved to {path}") + + +def mix_evals_video2text_gpt_eval(results, args): + score = 0 + for result in results: + eval_answer = result["eval_answer"] + eval_score = get_score_from_judge(eval_answer) + score += eval_score + + return score / len(results) + + +# Factory into different aggregate +def mix_evals_video2text_aggregate_gen(results, args): + mix_evals_video2text_aggregate_submissions(results, args, "OpenConvs") + + +video2text_gpt_judge_for_closeended_multiplechoice = lambda prompt, options, response: [ + {"role": "system", "content": f"In this task, I want you to act as an option extractor."}, + { + "role": "user", + "content": f"""You will be provided with a multiple-choice question, its options, and the model's answer, while the context of the question, which is one or more videos, is not given here. Your task is to extract or judge which option is chosen by the model based on its response, without seeing the context of the question. The extracted option should be one of the provided option letters. Your should first briefly give your reasoning process, and then give the extracted option letter. The extracted option must strictly follow this format: \"[[option letter]]\", e.g., \"The option chosen by the model: [[A]]\". +Below are some examples. + +Example 1: +Question: What did he do to the car? +Options: +A. Paint the car +B. Put plastic over the car +C. Put metal over the car +D. Cut the car +Model's Answer: put plastic over the car. +Your Judgment: The model's response directly aligns with option B, which is "Put plastic over the car." The response given is a paraphrase of this option without deviating in meaning. The option chosen by the model: [[B]] + +Example 2: +Question: How did Eddie know Pam and Justin before Justin was killed? +Options: +A. They were part of the theater company +B. They were high school friends +C. They went to college together +D. They were cousins +E. They were siblings +Model's Answer: A. +Your Judgment: The model's answer directly provides the option letter "A." The option chosen by the model: [[A]] + +Example 3: +Question: why do the people move in the same manner +Options: +A. uniform +B. dancing with the baby +C. exercising together +D. stay together +E. singing and dancing +Model's Answer: sing and dance +Your Judgment: The model's response "sing and dance" closely aligns with option E, which is "singing and dancing." The response provided is a direct paraphrase of this option, modifying only slightly the form of the words (from gerund to infinitive) but maintaining the same core activities described in the option. The option chosen by the model: [[E]] + +When you think that the model's answer does not match any of the given options, please choose the option that is the closest to the model's answer. +Give the brief reasoning process and the extracted option for the below case. + +Question: {prompt} +Options: +{options} +Model's Answer: {response} +Your Judgment: +""", + }, +] + + +class GPTMultiChoiceFilter(Filter): + def __init__(self, gpt_version: str = "gpt-3.5-turbo-0125", retries: int = 5): + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + self.gpt_version = gpt_version + + if API_TYPE == "openai": + self.client = openai.OpenAI(api_key=API_KEY) + elif API_TYPE == "azure": + self.client = openai.AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL) + + self.retries = retries + + def apply(self, resps, docs): + """ + Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. + Should return the list of (filtered) response lists *in the same order as they were input*, e.g. + if pass in [, ] should return + [, ] + """ + results = [] + for response, doc in zip(resps, docs): + query = doc["query"] + options = "\n".join([f"{chr(ord('A') + idx)}. {option}" for idx, option in enumerate(doc["options"])]) + message = video2text_gpt_judge_for_closeended_multiplechoice(prompt=query, options=options, response=response) + payload = { + "model": self.gpt_version, + "messages": message, + "max_tokens": MAX_NEW_TOKENS, + } + result = 0 + for attempt in range(self.retries): + try: + # response = requests.post(API_URL, headers=headers, json=payload, timeout=60) + # print(payload) + response = client.chat.completions.create(**payload) + # print(response) + # response.raise_for_status() + + # content =["choices"][0]["message"]["content"].strip() + content = response.choices[0].message.content + print("content:", content) + if content: + match = re.search(r"\[\[([A-Z])\]\]", content) + # print("match:", match) + if not match: + match = re.search(r"r'\b([A-Z])\.?\b'", content) + # print("match:", match) + if match: + # print("=====") + # print(match.group(1)) + result = ord(match.group(1)) - ord("A") + # print("result:", result) + # print("=====") + # print(content, result) + else: + result = 0 + break # If successful, break out of the loop + + except Exception as e: + eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}") + import traceback + + print(traceback.format_exc()) + if attempt < self.retries: # If we have retries left, sleep and then continue to next attempt + time.sleep(NUM_SECONDS_TO_SLEEP) + else: # If this was the last attempt, log and return empty + eval_logger.error(f"All {self.retries} attempts failed. Last error message: {e}") + result = 0 + break + results.append(str(result)) + return results