diff --git a/evalscope/benchmarks/arc/arc_adapter.py b/evalscope/benchmarks/arc/arc_adapter.py index e00cf78..eb470c9 100644 --- a/evalscope/benchmarks/arc/arc_adapter.py +++ b/evalscope/benchmarks/arc/arc_adapter.py @@ -4,6 +4,7 @@ import os from evalscope.benchmarks import Benchmark, DataAdapter +from evalscope.constants import EvalType from evalscope.metrics import WeightedAverageAccuracy, exact_match from evalscope.models import MultiChoiceModelAdapter from evalscope.utils import ResponseParser @@ -119,7 +120,7 @@ def get_gold_answer(self, input_d: dict) -> str: # Get the gold choice return input_d.get('answerKey', '') - def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = 'checkpoint') -> str: + def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str: """ Parse the model output to get the answer. Could be the best choice index. @@ -131,12 +132,12 @@ def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: st Returns: The parsed answer. Depending on the dataset. Usually a string for chat. """ - if eval_type == 'checkpoint': + if eval_type == EvalType.CHECKPOINT: return result - elif eval_type == 'service': + elif eval_type == EvalType.SERVICE: return ResponseParser.parse_first_option_with_choices( text=result, options=self.choices) # TODO: to be checked ! - elif eval_type == 'custom': + elif eval_type == EvalType.CUSTOM: return ResponseParser.parse_first_option_with_choices( text=result, options=self.choices) # TODO: to be checked ! else: diff --git a/evalscope/benchmarks/bbh/__init__.py b/evalscope/benchmarks/bbh/__init__.py index 7387c94..b937315 100644 --- a/evalscope/benchmarks/bbh/__init__.py +++ b/evalscope/benchmarks/bbh/__init__.py @@ -1,5 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.bbh.bbh_adapter import DATASET_ID, SUBSET_LIST -from evalscope.benchmarks.bbh.bbh_adapter import BBHAdapter as DataAdapterClass -from evalscope.models.model_adapter import ChatGenerationModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/ceval/__init__.py b/evalscope/benchmarks/ceval/__init__.py index b7532a3..b937315 100644 --- a/evalscope/benchmarks/ceval/__init__.py +++ b/evalscope/benchmarks/ceval/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.ceval.ceval_adapter import DATASET_ID, SUBJECT_MAPPING, SUBSET_LIST -from evalscope.benchmarks.ceval.ceval_adapter import CEVALAdapter -from evalscope.benchmarks.ceval.ceval_adapter import CEVALAdapter as DataAdapterClass -from evalscope.models.model_adapter import MultiChoiceModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/cmmlu/__init__.py b/evalscope/benchmarks/cmmlu/__init__.py index 864f846..b937315 100644 --- a/evalscope/benchmarks/cmmlu/__init__.py +++ b/evalscope/benchmarks/cmmlu/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.cmmlu.cmmlu_adapter import DATASET_ID, SUBJECT_MAPPING, SUBSET_LIST -from evalscope.benchmarks.cmmlu.cmmlu_adapter import CMMLUAdapter -from evalscope.benchmarks.cmmlu.cmmlu_adapter import CMMLUAdapter as DataAdapterClass -from evalscope.models.model_adapter import MultiChoiceModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/competition_math/__init__.py b/evalscope/benchmarks/competition_math/__init__.py index 85efbf4..b937315 100644 --- a/evalscope/benchmarks/competition_math/__init__.py +++ b/evalscope/benchmarks/competition_math/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.competition_math.competition_math_adapter import DATASET_ID, SUBSET_LIST -from evalscope.benchmarks.competition_math.competition_math_adapter import CompetitionMathAdapter -from evalscope.benchmarks.competition_math.competition_math_adapter import CompetitionMathAdapter as DataAdapterClass -from evalscope.models.model_adapter import ChatGenerationModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/data_adapter.py b/evalscope/benchmarks/data_adapter.py index 18f823e..468b47a 100644 --- a/evalscope/benchmarks/data_adapter.py +++ b/evalscope/benchmarks/data_adapter.py @@ -5,7 +5,7 @@ from modelscope.msdatasets import MsDataset from typing import Any, Optional -from evalscope.constants import DEFAULT_DATASET_CACHE_DIR, AnswerKeys, HubType +from evalscope.constants import DEFAULT_DATASET_CACHE_DIR, AnswerKeys, EvalType, HubType from evalscope.utils import normalize_score from evalscope.utils.logger import get_logger @@ -265,7 +265,7 @@ def get_gold_answer(self, input_d: Any) -> Any: raise NotImplementedError @abstractmethod - def parse_pred_result(self, result: Any, raw_input_d: dict = None, eval_type: str = 'checkpoint') -> Any: + def parse_pred_result(self, result: Any, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> Any: """ Parse the predicted result and extract proper answer. @@ -286,9 +286,9 @@ def match(self, gold: Any, pred: Any) -> Any: Args: gold (Any): The golden answer. Usually a string for chat/multiple-choice-questions. - e.g. 'A' + e.g. 'A', extracted from get_gold_answer method. pred (Any): The predicted answer. Usually a string for chat/multiple-choice-questions. - e.g. 'B' + e.g. 'B', extracted from parse_pred_result method. Returns: The match result. Usually a score (float) for chat/multiple-choice-questions. diff --git a/evalscope/benchmarks/general_qa/__init__.py b/evalscope/benchmarks/general_qa/__init__.py index 2e73200..b937315 100644 --- a/evalscope/benchmarks/general_qa/__init__.py +++ b/evalscope/benchmarks/general_qa/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.general_qa.general_qa_adapter import DATASET_ID, SUBSET_LIST -from evalscope.benchmarks.general_qa.general_qa_adapter import GeneralQAAdapter -from evalscope.benchmarks.general_qa.general_qa_adapter import GeneralQAAdapter as DataAdapterClass -from evalscope.models.model_adapter import ChatGenerationModelAdapter as ModelAdapterClass diff --git a/evalscope/benchmarks/hellaswag/hellaswag_adapter.py b/evalscope/benchmarks/hellaswag/hellaswag_adapter.py index afae557..faafc96 100644 --- a/evalscope/benchmarks/hellaswag/hellaswag_adapter.py +++ b/evalscope/benchmarks/hellaswag/hellaswag_adapter.py @@ -4,6 +4,7 @@ import re from evalscope.benchmarks import Benchmark, DataAdapter +from evalscope.constants import EvalType from evalscope.metrics import WeightedAverageAccuracy, exact_match from evalscope.models import ContinuationLogitsModelAdapter from evalscope.utils.io_utils import jsonl_to_list @@ -92,7 +93,7 @@ def get_gold_answer(self, input_d: dict) -> str: # Get the gold choice return input_d['label'] - def parse_pred_result(self, result: list, raw_input_d: dict = None, eval_type: str = 'checkpoint') -> str: + def parse_pred_result(self, result: list, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str: """ Parse the model output to get the answer. Could be the best choice index. @@ -104,7 +105,7 @@ def parse_pred_result(self, result: list, raw_input_d: dict = None, eval_type: s Returns: The parsed answer. Depending on the dataset. Usually a string for chat. """ - if eval_type == 'checkpoint': + if eval_type == EvalType.CHECKPOINT: # answer: in the form of [-2.3, -4.5, ...], len of self.choices result = np.array(result) endings: list = [self._preprocess(ending) for ending in raw_input_d['endings']] @@ -112,9 +113,9 @@ def parse_pred_result(self, result: list, raw_input_d: dict = None, eval_type: s best_choice_idx = np.argmax(result / completion_len) return str(best_choice_idx) - elif eval_type == 'service': + elif eval_type == EvalType.SERVICE: return result # TODO: to be supported ! - elif eval_type == 'custom': + elif eval_type == EvalType.CUSTOM: return result # TODO: to be supported ! else: raise ValueError(f'Invalid eval_type: {eval_type}') diff --git a/evalscope/benchmarks/humaneval/__init__.py b/evalscope/benchmarks/humaneval/__init__.py index 176dd8f..b937315 100644 --- a/evalscope/benchmarks/humaneval/__init__.py +++ b/evalscope/benchmarks/humaneval/__init__.py @@ -1,5 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.humaneval.humaneval_adapter import DATASET_ID, SUBSET_LIST -from evalscope.benchmarks.humaneval.humaneval_adapter import HumanevalAdapter as DataAdapterClass -from evalscope.models.model_adapter import ChatGenerationModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/mmlu/__init__.py b/evalscope/benchmarks/mmlu/__init__.py index c112533..b937315 100644 --- a/evalscope/benchmarks/mmlu/__init__.py +++ b/evalscope/benchmarks/mmlu/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.mmlu.mmlu_adapter import DATASET_ID, SUBJECT_MAPPING, SUBSET_LIST -from evalscope.benchmarks.mmlu.mmlu_adapter import MMLUAdapter -from evalscope.benchmarks.mmlu.mmlu_adapter import MMLUAdapter as DataAdapterClass -from evalscope.models.model_adapter import MultiChoiceModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/race/__init__.py b/evalscope/benchmarks/race/__init__.py index f4290c4..b937315 100644 --- a/evalscope/benchmarks/race/__init__.py +++ b/evalscope/benchmarks/race/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.race.race_adapter import DATASET_ID, SUBJECT_MAPPING, SUBSET_LIST -from evalscope.benchmarks.race.race_adapter import RACEAdapter -from evalscope.benchmarks.race.race_adapter import RACEAdapter as DataAdapterClass -from evalscope.models.model_adapter import MultiChoiceModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/trivia_qa/__init__.py b/evalscope/benchmarks/trivia_qa/__init__.py index 5087549..b937315 100644 --- a/evalscope/benchmarks/trivia_qa/__init__.py +++ b/evalscope/benchmarks/trivia_qa/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.trivia_qa.trivia_qa_adapter import DATASET_ID, SUBSET_LIST -from evalscope.benchmarks.trivia_qa.trivia_qa_adapter import TriviaQaAdapter -from evalscope.benchmarks.trivia_qa.trivia_qa_adapter import TriviaQaAdapter as DataAdapterClass -from evalscope.models.model_adapter import ChatGenerationModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/benchmarks/truthful_qa/__init__.py b/evalscope/benchmarks/truthful_qa/__init__.py index 1fbe887..b937315 100644 --- a/evalscope/benchmarks/truthful_qa/__init__.py +++ b/evalscope/benchmarks/truthful_qa/__init__.py @@ -1,6 +1 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from evalscope.benchmarks.truthful_qa.truthful_qa_adapter import DATASET_ID, SUBSET_LIST -from evalscope.benchmarks.truthful_qa.truthful_qa_adapter import TruthfulQaAdapter -from evalscope.benchmarks.truthful_qa.truthful_qa_adapter import TruthfulQaAdapter as DataAdapterClass -from evalscope.models.model_adapter import ContinuationLogitsModelAdapter as ModelAdapterClass # noqa diff --git a/evalscope/config.py b/evalscope/config.py index 3e8652c..a768bc2 100644 --- a/evalscope/config.py +++ b/evalscope/config.py @@ -72,9 +72,10 @@ def __post_init__(self): self.model_id = type(self.model).__name__ else: self.model_id = os.path.basename(self.model).rstrip(os.sep) + # Convert Enum to string + self.eval_backend = str(self.eval_backend) def to_dict(self): - # Note: to avoid serialization error for some model instance return self.__dict__ def __str__(self): @@ -129,6 +130,7 @@ def load(custom_model: CustomModel, tasks: List[str]) -> List['TaskConfig']: continue task.model = custom_model + task.model_args = custom_model.config task.model_id = type(custom_model).__name__ res_list.append(task) diff --git a/evalscope/constants.py b/evalscope/constants.py index be8d00e..f6152ac 100644 --- a/evalscope/constants.py +++ b/evalscope/constants.py @@ -135,7 +135,8 @@ class EvalStage: class EvalType: CUSTOM = 'custom' - CHECKPOINT = 'checkpoint' + CHECKPOINT = 'checkpoint' # native model checkpoint + SERVICE = 'service' # model service class EvalBackend: diff --git a/evalscope/evaluator/evaluator.py b/evalscope/evaluator/evaluator.py index f894411..88c1894 100644 --- a/evalscope/evaluator/evaluator.py +++ b/evalscope/evaluator/evaluator.py @@ -10,9 +10,8 @@ from evalscope.benchmarks import DataAdapter from evalscope.config import TaskConfig -from evalscope.constants import (DEFAULT_DATASET_CACHE_DIR, AnswerKeys, DumpMode, EvalStage, EvalType, HubType, - ReviewKeys) -from evalscope.models.model_adapter import BaseModelAdapter, CustomModelAdapter +from evalscope.constants import AnswerKeys, DumpMode, EvalStage, ReviewKeys +from evalscope.models import BaseModelAdapter, CustomModelAdapter from evalscope.tools.combine_reports import gen_table from evalscope.utils import dict_torch_dtype_to_str, gen_hash from evalscope.utils.io_utils import OutputsStructure, dump_jsonl_data, jsonl_to_list @@ -45,44 +44,36 @@ class Evaluator(object): def __init__(self, dataset_name_or_path: str, data_adapter: DataAdapter, - subset_list: Optional[list] = None, - model_adapter: Optional[BaseModelAdapter] = None, - use_cache: Optional[str] = None, - outputs: Optional[OutputsStructure] = None, - datasets_dir: Optional[str] = DEFAULT_DATASET_CACHE_DIR, - datasets_hub: Optional[str] = HubType.MODELSCOPE, - stage: Optional[str] = EvalStage.ALL, - eval_type: Optional[str] = EvalType.CHECKPOINT, - overall_task_cfg: Optional[TaskConfig] = None, + model_adapter: BaseModelAdapter, + subset_list: list = None, + outputs: OutputsStructure = None, + task_cfg: TaskConfig = None, **kwargs): self.dataset_name_or_path = os.path.expanduser(dataset_name_or_path) self.dataset_name = os.path.basename(self.dataset_name_or_path.rstrip(os.sep)).split('.')[0] - self.model_name = overall_task_cfg.model_id + self.model_name = task_cfg.model_id self.custom_task_name = f'{self.model_name}_{self.dataset_name}' - self.datasets_dir = os.path.expanduser(datasets_dir) + self.datasets_dir = os.path.expanduser(task_cfg.dataset_dir) self.kwargs = kwargs self.data_adapter = data_adapter self.model_adapter = model_adapter - self.eval_type = eval_type - self.stage = stage - self.use_cache = use_cache - self.overall_task_cfg = overall_task_cfg - if isinstance(self.model_adapter, CustomModelAdapter): - self.overall_task_cfg.model_args = self.model_adapter.custom_model.config - - self.model_cfg = self.model_adapter.model_cfg + self.eval_type = task_cfg.eval_type + self.stage = task_cfg.stage + self.use_cache = task_cfg.use_cache + self.task_cfg = task_cfg + self.model_cfg = model_adapter.model_cfg # Deal with the output paths self.outputs_structure = outputs # Load dataset self.dataset = self.data_adapter.load( - dataset_name_or_path=dataset_name_or_path, + dataset_name_or_path=self.dataset_name_or_path, subset_list=subset_list, work_dir=self.datasets_dir, - datasets_hub=datasets_hub, + datasets_hub=task_cfg.dataset_hub, **kwargs) # Get prompts from dataset diff --git a/evalscope/evaluator/reviewer/auto_reviewer.py b/evalscope/evaluator/reviewer/auto_reviewer.py index 4144f11..bd0e387 100644 --- a/evalscope/evaluator/reviewer/auto_reviewer.py +++ b/evalscope/evaluator/reviewer/auto_reviewer.py @@ -8,7 +8,7 @@ import time from abc import ABC, abstractmethod from functools import partial -from typing import Any, List +from typing import Any, List, Tuple from evalscope.constants import ArenaMode, EvalConfigKeys, FnCompletionParser, PositionBiasMitigation from evalscope.models.model import OpenAIModel @@ -240,7 +240,15 @@ def get_review_single(self, row: List[dict], dry_run: bool = False, **kwargs): review_text=review_text) return review_result - def _get_review_pair(self, model_a, model_b, question, category, ans1, ans2, dry_run=False, **kwargs) -> (str, Any): + def _get_review_pair(self, + model_a, + model_b, + question, + category, + ans1, + ans2, + dry_run=False, + **kwargs) -> Tuple[str, Any]: input_msg = dict(ques=question, category=category, ans1=ans1, ans2=ans2) if self.reference_list: @@ -263,7 +271,7 @@ def _get_review_pair(self, model_a, model_b, question, category, ans1, ans2, dry result = (result, None) return review_text, *result - def _get_review_single(self, model, question, category, answer, dry_run=False, **kwargs) -> (str, Any): + def _get_review_single(self, model, question, category, answer, dry_run=False, **kwargs) -> Tuple[str, Any]: input_msg = dict(ques=question, category=category, ans1=answer) if self.reference_list: diff --git a/evalscope/models/__init__.py b/evalscope/models/__init__.py index 8fc22eb..90f126e 100644 --- a/evalscope/models/__init__.py +++ b/evalscope/models/__init__.py @@ -1,5 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from evalscope.models.custom import * -from evalscope.models.model import BaseModel, ChatBaseModel -from evalscope.models.model_adapter import * +from evalscope.models.base_adapter import BaseModelAdapter +from evalscope.models.chat_adapter import ChatGenerationModelAdapter +from evalscope.models.choice_adapter import ContinuationLogitsModelAdapter, MultiChoiceModelAdapter +from evalscope.models.custom import CustomModel +from evalscope.models.custom_adapter import CustomModelAdapter +from evalscope.models.local_model import LocalModel +from evalscope.models.model import BaseModel, ChatBaseModel, OpenAIModel +from evalscope.models.server_adapter import ServerModelAdapter + +__all__ = [ + 'CustomModel', 'BaseModel', 'ChatBaseModel', 'OpenAIModel', 'BaseModelAdapter', 'ChatGenerationModelAdapter', + 'MultiChoiceModelAdapter', 'ContinuationLogitsModelAdapter', 'CustomModelAdapter', 'ServerModelAdapter', + 'LocalModel' +] diff --git a/evalscope/models/base_adapter.py b/evalscope/models/base_adapter.py new file mode 100644 index 0000000..ea00c7c --- /dev/null +++ b/evalscope/models/base_adapter.py @@ -0,0 +1,27 @@ +import torch +from abc import ABC, abstractmethod +from typing import Any, Union + +from evalscope.models.custom import CustomModel +from evalscope.models.local_model import LocalModel + + +class BaseModelAdapter(ABC): + + def __init__(self, model: Union[LocalModel, CustomModel], **kwargs): + if isinstance(model, LocalModel): + self.model = model.model + self.model_id = model.model_id + self.model_revision = model.model_revision + self.device = model.device + self.tokenizer = model.tokenizer + self.model_cfg = model.model_cfg + elif isinstance(model, CustomModel): + self.model_cfg = model.config + else: + raise ValueError(f'Unsupported model type: {type(model)}') + + @abstractmethod + @torch.no_grad() + def predict(self, *args, **kwargs) -> Any: + raise NotImplementedError diff --git a/evalscope/models/chat_adapter.py b/evalscope/models/chat_adapter.py new file mode 100644 index 0000000..033ee7f --- /dev/null +++ b/evalscope/models/chat_adapter.py @@ -0,0 +1,108 @@ +import os +import time +import torch +from modelscope import GenerationConfig +from typing import Union + +from evalscope.models.base_adapter import BaseModelAdapter +from evalscope.models.local_model import LocalModel +from evalscope.utils.chat_service import ChatCompletionResponse, ChatMessage +from evalscope.utils.logger import get_logger +from evalscope.utils.model_utils import fix_do_sample_warning + +logger = get_logger() + + +class ChatGenerationModelAdapter(BaseModelAdapter): + """ + Chat generation model adapter. + """ + + def __init__(self, model: LocalModel, **kwargs): + super().__init__(model) + + self.generation_config = self._parse_generation_config(self.tokenizer, self.model) + + custom_generation_config = kwargs.pop('generation_config', None) + custom_chat_template = kwargs.pop('chat_template', None) + + if custom_generation_config: + logger.info('Updating generation config ...') + self.generation_config.update(**custom_generation_config) + + if custom_chat_template: + self.tokenizer.chat_template = custom_chat_template + logger.info(f'Using custom chat template: {custom_chat_template}') + + def _parse_generation_config(self, tokenizer, model): + generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False)) + + try: + remote_config = GenerationConfig.from_pretrained( + self.model_id, revision=self.model_revision, trust_remote_code=True) + generation_config.update(**remote_config.to_dict()) + except Exception: + logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.') + + if isinstance(self.model_id, str) and os.path.exists(self.model_id): + logger.warning(f'Got local model dir: {self.model_id}') + + if tokenizer.eos_token_id is not None: + generation_config.eos_token_id = tokenizer.eos_token_id + if tokenizer.pad_token_id is not None: + generation_config.pad_token_id = tokenizer.pad_token_id + if generation_config.max_new_tokens is None: + generation_config.max_new_tokens = 2048 + + return generation_config + + def _model_generate(self, query: str, infer_cfg: dict) -> str: + messages = [ChatMessage(role='user', content=query)] + formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.tokenizer(formatted_prompt, return_tensors='pt', padding=True).to(self.device) + input_ids = inputs['input_ids'] + + # Process infer_cfg + if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1: + infer_cfg['do_sample'] = True + + # stop settings + stop = infer_cfg.get('stop', None) + eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \ + if stop else self.tokenizer.eos_token_id + + if eos_token_id is not None: + infer_cfg['eos_token_id'] = eos_token_id + infer_cfg['pad_token_id'] = eos_token_id # setting eos_token_id as pad token + + self.generation_config.update(**infer_cfg) + fix_do_sample_warning(self.generation_config) + + # Run inference + output_ids = self.model.generate(**inputs, generation_config=self.generation_config) + + response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], skip_special_tokens=True) + return response + + @torch.no_grad() + def predict(self, inputs: Union[str, dict, list], infer_cfg: dict = {}) -> dict: + + # Process inputs + if isinstance(inputs, str): + query = inputs + elif isinstance(inputs, dict): + query = inputs['data'][0] + elif isinstance(inputs, list): + query = '\n'.join(inputs) + else: + raise TypeError(f'Unsupported inputs type: {type(inputs)}') + + response = self._model_generate(query, infer_cfg) + + choices_list = [{'index': 0, 'message': {'content': response, 'role': 'assistant'}}] + + res_d = ChatCompletionResponse( + model=self.model_id, choices=choices_list, object='chat.completion', created=int(time.time()), + usage=None).model_dump(exclude_unset=True) + + return res_d diff --git a/evalscope/models/choice_adapter.py b/evalscope/models/choice_adapter.py new file mode 100644 index 0000000..b2d403e --- /dev/null +++ b/evalscope/models/choice_adapter.py @@ -0,0 +1,214 @@ +import numpy as np +import time +import torch +from typing import List + +from evalscope.models.base_adapter import BaseModelAdapter +from evalscope.models.local_model import LocalModel +from evalscope.utils.chat_service import ChatCompletionResponse + + +class MultiChoiceModelAdapter(BaseModelAdapter): + """ The multi-choice model adapter. """ + + _DEFAULT_MAX_LENGTH = 2048 + + def __init__(self, model: LocalModel, **kwargs): + super().__init__(model) + + self._max_length = kwargs.get('max_length') + + @property + def max_length(self): + if self._max_length: + return self._max_length + seqlen_config_attrs = ('n_positions', 'max_position_embeddings', 'n_ctx') + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, 'model_max_length'): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @torch.no_grad() + def predict(self, inputs: dict, infer_cfg: dict = None) -> dict: + """ + Multi-choice model prediction func. + + Args: + inputs (dict): The inputs for a doc. Format: + {'data': [full_prompt], 'multi_choices': ['A', 'B', 'C', 'D']} + + infer_cfg (dict): inference configuration. + + Returns: + res (dict): The model prediction results. Format: + { + 'choices': [ + { + 'index': 0, + 'message': { + 'content': [-14.9609, -13.6015, ...], # loglikelihood values for inputs context-continuation pairs. + 'role': 'assistant' + } + } + ], + 'created': 1677664795, + # For models on the ModelScope or HuggingFace, concat model_id and revision with "-". + 'model': 'gpt-3.5-turbo-0613', + 'object': 'chat.completion', + 'usage': { + 'completion_tokens': 17, + 'prompt_tokens': 57, + 'total_tokens': 74 + } + } + """ + infer_cfg = infer_cfg or {} + self.model.generation_config.update(**infer_cfg) + + input_data = inputs['data'] + multi_choices = inputs['multi_choices'] + + output, input_info = self._get_logits(self.tokenizer, self.model, input_data) + assert output.shape[0] == 1 + logits = output.flatten() + + choice_logits = [logits[self.tokenizer(ch)['input_ids'][-1:]] for ch in multi_choices] + softval = torch.nn.functional.softmax(torch.tensor(choice_logits).float(), dim=0) + + if softval.dtype in {torch.bfloat16, torch.float16}: + softval = softval.to(dtype=torch.float32) + probs = softval.detach().cpu().numpy() + pred: str = multi_choices[int(np.argmax(probs))] # Format: A or B or C or D + + res_d = ChatCompletionResponse( + model=self.model_id, + choices=[{ + 'index': 0, + 'message': { + 'content': pred, + 'role': 'assistant' + } + }], + object='chat.completion', + created=int(time.time()), + usage=None).model_dump(exclude_unset=True) + + return res_d + + @staticmethod + def _get_logits(tokenizer, model, inputs: List[str]): + input_ids = tokenizer(inputs, padding=False)['input_ids'] + input_ids = torch.tensor(input_ids, device=model.device) + tokens = {'input_ids': input_ids} + + outputs = model(input_ids)['logits'] + logits = outputs[:, -1, :] + log_probs = torch.nn.functional.softmax(logits, dim=-1) + return log_probs, {'tokens': tokens} + + +class ContinuationLogitsModelAdapter(MultiChoiceModelAdapter): + """ + Continuation-logits model adapter. + """ + + def __init__(self, model: LocalModel, **kwargs): + super().__init__(model, **kwargs) + + @torch.no_grad() + def predict(self, inputs: dict, infer_cfg: dict = None) -> dict: + """ + Multi-choice model prediction func. + Args: + inputs (dict): The inputs for a doc. Format: + {'data': [(context, continuation), ...]} + infer_cfg (dict): inference configuration. + Returns: + res (dict): The model prediction results. Format: + { + 'choices': [ + { + 'index': 0, + 'message': { + 'content': [-14.9609, -13.6015, ...], # loglikelihood values for inputs context-continuation pairs. + 'role': 'assistant' + } + } + ], + 'created': 1677664795, + # For models on the ModelScope or HuggingFace, concat model_id and revision with "-". + 'model': 'gpt-3.5-turbo-0613', + 'object': 'chat.completion', + 'usage': { + 'completion_tokens': 17, + 'prompt_tokens': 57, + 'total_tokens': 74 + } + } + """ + infer_cfg = infer_cfg or {} + + pred_list: list = self.loglikelihood(inputs=inputs['data'], infer_cfg=infer_cfg) + + res_d = ChatCompletionResponse( + model=self.model_id, + choices=[{ + 'index': 0, + 'message': { + 'content': pred_list, + 'role': 'assistant' + } + }], + object='chat.completion', + created=int(time.time()), + usage=None).model_dump(exclude_unset=True) + + return res_d + + def loglikelihood(self, inputs: list, infer_cfg: dict = None) -> list: + self.model.generation_config.update(**infer_cfg) + # To predict one doc + doc_ele_pred = [] + for ctx, continuation in inputs: + + # ctx_enc shape: [context_tok_len] cont_enc shape: [continuation_tok_len] + ctx_enc, cont_enc = self._encode_pair(ctx, continuation) + + inputs_tokens = torch.tensor( + (ctx_enc.tolist() + cont_enc.tolist())[-(self.max_length + 1):][:-1], + dtype=torch.long, + device=self.model.device).unsqueeze(0) + + logits = self.model(inputs_tokens)[0] + logits = torch.nn.functional.log_softmax(logits.float(), dim=-1) + + logits = logits[:, -len(cont_enc):, :] + cont_enc = cont_enc.unsqueeze(0).unsqueeze(-1) + logits = torch.gather(logits.cpu(), 2, cont_enc.cpu()).squeeze(-1) + + choice_score = float(logits.sum()) + doc_ele_pred.append(choice_score) + + # e.g. [-2.3, -9.2, -12.9, 1.1], length=len(choices) + return doc_ele_pred + + def _encode_pair(self, context, continuation): + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer(context + continuation, padding=False)['input_ids'] + whole_enc = torch.tensor(whole_enc, device=self.device) + + context_enc = self.tokenizer(context, padding=False)['input_ids'] + context_enc = torch.tensor(context_enc, device=self.device) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc diff --git a/evalscope/models/custom_adapter.py b/evalscope/models/custom_adapter.py new file mode 100644 index 0000000..fb279fe --- /dev/null +++ b/evalscope/models/custom_adapter.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Union + +from evalscope.models.base_adapter import BaseModelAdapter +from evalscope.models.custom import CustomModel + + +class CustomModelAdapter(BaseModelAdapter): + + def __init__(self, custom_model: CustomModel, **kwargs): + """ + Custom model adapter. + + Args: + custom_model: The custom model instance. + **kwargs: Other args. + """ + self.custom_model = custom_model + super(CustomModelAdapter, self).__init__(model=custom_model) + + def predict(self, inputs: Union[str, dict, list], **kwargs) -> List[Dict[str, Any]]: + """ + Model prediction func. + + Args: + inputs (Union[str, dict, list]): The input data. Depending on the specific model. + str: 'xxx' + dict: {'data': [full_prompt]} + list: ['xxx', 'yyy', 'zzz'] + **kwargs: kwargs + + Returns: + res (dict): The model prediction results. Format: + { + 'choices': [ + { + 'index': 0, + 'message': { + 'content': 'xxx', + 'role': 'assistant' + } + } + ], + 'created': 1677664795, + 'model': 'gpt-3.5-turbo-0613', # should be model_id + 'object': 'chat.completion', + 'usage': { + 'completion_tokens': 17, + 'prompt_tokens': 57, + 'total_tokens': 74 + } + } + """ + in_prompts = [] + + # Note: here we assume the inputs are all prompts for the benchmark. + for input_prompt in inputs: + if isinstance(input_prompt, str): + in_prompts.append(input_prompt) + elif isinstance(input_prompt, dict): + # TODO: to be supported for continuation list like truthful_qa + in_prompts.append(input_prompt['data'][0]) + elif isinstance(input_prompt, list): + in_prompts.append('\n'.join(input_prompt)) + else: + raise TypeError(f'Unsupported inputs type: {type(input_prompt)}') + + return self.custom_model.predict(prompts=in_prompts, **kwargs) diff --git a/evalscope/models/local_model.py b/evalscope/models/local_model.py new file mode 100644 index 0000000..3702781 --- /dev/null +++ b/evalscope/models/local_model.py @@ -0,0 +1,47 @@ +import torch +from modelscope import AutoModelForCausalLM, AutoTokenizer +from torch import dtype + +from evalscope.constants import DEFAULT_MODEL_CACHE_DIR +from evalscope.utils.logger import get_logger + +logger = get_logger() + + +class LocalModel: + + def __init__(self, + model_id: str, + model_revision: str = 'master', + device_map: str = 'auto', + torch_dtype: dtype = torch.bfloat16, + cache_dir: str = None, + **kwargs): + model_cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR + + self.model_id = model_id + self.model_revision = model_revision + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + logger.info(f'Device: {self.device}') + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id, + revision=model_revision, + trust_remote_code=True, + cache_dir=model_cache_dir, + ) + + self.model = AutoModelForCausalLM.from_pretrained( + self.model_id, + revision=model_revision, + device_map=device_map, + trust_remote_code=True, + torch_dtype=torch_dtype, + cache_dir=model_cache_dir, + ) + + self.model_cfg = { + 'model_id': model_id, + 'device_map': device_map, + 'torch_dtype': str(torch_dtype), + } diff --git a/evalscope/models/model_adapter.py b/evalscope/models/model_adapter.py deleted file mode 100644 index a3d5650..0000000 --- a/evalscope/models/model_adapter.py +++ /dev/null @@ -1,467 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright (c) EleutherAI, Inc. and its affiliates. -# flake8: noqa -import numpy as np -import os -import time -import torch -from abc import ABC, abstractmethod -from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from torch import dtype -from typing import Any, Dict, List, Union - -from evalscope.constants import DEFAULT_MODEL_CACHE_DIR -from evalscope.models.custom import CustomModel -from evalscope.utils.chat_service import ChatMessage -from evalscope.utils.logger import get_logger -from evalscope.utils.model_utils import fix_do_sample_warning - -logger = get_logger() - - -class LocalModel: - - def __init__(self, - model_id: str, - model_revision: str = 'master', - device_map: str = 'auto', - torch_dtype: dtype = torch.bfloat16, - cache_dir: str = None, - **kwargs): - """ - Args: - model_id: The model id on ModelScope, or local model_dir. - model_revision: The model revision on ModelScope. - device_map: The device map for model inference. - torch_dtype: The torch dtype for model inference. - cache_dir: Directory to cache the models. - """ - model_cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR - - self.model_id = model_id - self.model_revision = model_revision - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - logger.info(f'Device: {self.device}') - - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id, - revision=model_revision, - trust_remote_code=True, - cache_dir=model_cache_dir, - ) - - self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, - revision=model_revision, - device_map=device_map, - trust_remote_code=True, - torch_dtype=torch_dtype, - cache_dir=model_cache_dir, - ) - - self.model_cfg = { - 'model_id': model_id, - 'device_map': device_map, - 'torch_dtype': str(torch_dtype), - } - - -class BaseModelAdapter(ABC): - """ - Base class for model adapter. - """ - - def __init__(self, model: Union[LocalModel, CustomModel], **kwargs): - """ - Args: - model: The model instance which is compatible with - AutoModel/AutoModelForCausalLM/AutoModelForSeq2SeqLM of transformers. - """ - if isinstance(model, LocalModel): - self.model = model.model - self.model_id = model.model_id - self.model_revision = model.model_revision - self.device = model.device - self.tokenizer = model.tokenizer - self.model_cfg = model.model_cfg - elif isinstance(model, CustomModel): - pass - else: - raise ValueError(f'Unsupported model type: {type(model)}') - - @abstractmethod - @torch.no_grad() - def predict(self, *args, **kwargs) -> Any: - """ - Model prediction func. - """ - raise NotImplementedError - - -class MultiChoiceModelAdapter(BaseModelAdapter): - """ The multi-choice model adapter. """ - - _DEFAULT_MAX_LENGTH = 2048 - - def __init__(self, model: LocalModel, **kwargs): - super().__init__(model) - - self._max_length = kwargs.get('max_length') - - @property - def max_length(self): - if self._max_length: - return self._max_length - seqlen_config_attrs = ('n_positions', 'max_position_embeddings', 'n_ctx') - for attr in seqlen_config_attrs: - if hasattr(self.model.config, attr): - return getattr(self.model.config, attr) - if hasattr(self.tokenizer, 'model_max_length'): - if self.tokenizer.model_max_length == 1000000000000000019884624838656: - return self._DEFAULT_MAX_LENGTH - return self.tokenizer.model_max_length - return self._DEFAULT_MAX_LENGTH - - @torch.no_grad() - def predict(self, inputs: dict, infer_cfg: dict = None) -> dict: - """ - Multi-choice model prediction func. - - Args: - inputs (dict): The inputs for a doc. Format: - {'data': [full_prompt], 'multi_choices': ['A', 'B', 'C', 'D']} - - infer_cfg (dict): inference configuration. - - Returns: - res (dict): The model prediction results. Format: - { - 'choices': [ - { - 'index': 0, - 'message': { - 'content': [-14.9609, -13.6015, ...], # loglikelihood values for inputs context-continuation pairs. - 'role': 'assistant' - } - } - ], - 'created': 1677664795, - # For models on the ModelScope or HuggingFace, concat model_id and revision with "-". - 'model': 'gpt-3.5-turbo-0613', - 'object': 'chat.completion', - 'usage': { - 'completion_tokens': 17, - 'prompt_tokens': 57, - 'total_tokens': 74 - } - } - """ - infer_cfg = infer_cfg or {} - self.model.generation_config.update(**infer_cfg) - - input_data = inputs['data'] - multi_choices = inputs['multi_choices'] - - output, input_info = self._get_logits(self.tokenizer, self.model, input_data) - assert output.shape[0] == 1 - logits = output.flatten() - - choice_logits = [logits[self.tokenizer(ch)['input_ids'][-1:]] for ch in multi_choices] - softval = torch.nn.functional.softmax(torch.tensor(choice_logits).float(), dim=0) - - if softval.dtype in {torch.bfloat16, torch.float16}: - softval = softval.to(dtype=torch.float32) - probs = softval.detach().cpu().numpy() - pred: str = multi_choices[int(np.argmax(probs))] # Format: A or B or C or D - - res_d = { - 'choices': [{ - 'index': 0, - 'message': { - 'content': pred, - 'role': 'assistant' - } - }], - 'created': time.time(), - 'model': self.model_id, - 'object': 'chat.completion', - 'usage': {} - } - - return res_d - - @staticmethod - def _get_logits(tokenizer, model, inputs: List[str]): - input_ids = tokenizer(inputs, padding=False)['input_ids'] - input_ids = torch.tensor(input_ids, device=model.device) - tokens = {'input_ids': input_ids} - - outputs = model(input_ids)['logits'] - logits = outputs[:, -1, :] - log_probs = torch.nn.functional.softmax(logits, dim=-1) - return log_probs, {'tokens': tokens} - - -class ContinuationLogitsModelAdapter(MultiChoiceModelAdapter): - """ - Continuation-logits model adapter. - """ - - def __init__(self, model: LocalModel, **kwargs): - super().__init__(model, **kwargs) - - @torch.no_grad() - def predict(self, inputs: dict, infer_cfg: dict = None) -> dict: - """ - Multi-choice model prediction func. - Args: - inputs (dict): The inputs for a doc. Format: - {'data': [(context, continuation), ...]} - infer_cfg (dict): inference configuration. - Returns: - res (dict): The model prediction results. Format: - { - 'choices': [ - { - 'index': 0, - 'message': { - 'content': [-14.9609, -13.6015, ...], # loglikelihood values for inputs context-continuation pairs. - 'role': 'assistant' - } - } - ], - 'created': 1677664795, - # For models on the ModelScope or HuggingFace, concat model_id and revision with "-". - 'model': 'gpt-3.5-turbo-0613', - 'object': 'chat.completion', - 'usage': { - 'completion_tokens': 17, - 'prompt_tokens': 57, - 'total_tokens': 74 - } - } - """ - infer_cfg = infer_cfg or {} - - pred_list: list = self.loglikelihood(inputs=inputs['data'], infer_cfg=infer_cfg) - - res_d = { - 'choices': [{ - 'index': 0, - 'message': { - 'content': pred_list, - 'role': 'assistant' - } - }], - 'created': time.time(), - 'model': self.model_id, - 'object': 'chat.completion', - 'usage': {} - } - return res_d - - def loglikelihood(self, inputs: list, infer_cfg: dict = None) -> list: - self.model.generation_config.update(**infer_cfg) - # To predict one doc - doc_ele_pred = [] - for ctx, continuation in inputs: - - # ctx_enc shape: [context_tok_len] cont_enc shape: [continuation_tok_len] - ctx_enc, cont_enc = self._encode_pair(ctx, continuation) - - inputs_tokens = torch.tensor( - (ctx_enc.tolist() + cont_enc.tolist())[-(self.max_length + 1):][:-1], - dtype=torch.long, - device=self.model.device).unsqueeze(0) - - logits = self.model(inputs_tokens)[0] - logits = torch.nn.functional.log_softmax(logits.float(), dim=-1) - - logits = logits[:, -len(cont_enc):, :] - cont_enc = cont_enc.unsqueeze(0).unsqueeze(-1) - logits = torch.gather(logits.cpu(), 2, cont_enc.cpu()).squeeze(-1) - - choice_score = float(logits.sum()) - doc_ele_pred.append(choice_score) - - # e.g. [-2.3, -9.2, -12.9, 1.1], length=len(choices) - return doc_ele_pred - - def _encode_pair(self, context, continuation): - n_spaces = len(context) - len(context.rstrip()) - if n_spaces > 0: - continuation = context[-n_spaces:] + continuation - context = context[:-n_spaces] - - whole_enc = self.tokenizer(context + continuation, padding=False)['input_ids'] - whole_enc = torch.tensor(whole_enc, device=self.device) - - context_enc = self.tokenizer(context, padding=False)['input_ids'] - context_enc = torch.tensor(context_enc, device=self.device) - - context_enc_len = len(context_enc) - continuation_enc = whole_enc[context_enc_len:] - - return context_enc, continuation_enc - - -class ChatGenerationModelAdapter(BaseModelAdapter): - """ - Chat generation model adapter. - """ - - def __init__(self, model: LocalModel, **kwargs): - super().__init__(model) - - self.generation_config = self._parse_generation_config(self.tokenizer, self.model) - - custom_generation_config = kwargs.pop('generation_config', None) - custom_chat_template = kwargs.pop('chat_template', None) - - if custom_generation_config: - logger.info('Updating generation config ...') - self.generation_config.update(**custom_generation_config) - - if custom_chat_template: - self.tokenizer.chat_template = custom_chat_template - logger.info(f'Using custom chat template: {custom_chat_template}') - - def _parse_generation_config(self, tokenizer, model): - generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False)) - - try: - remote_config = GenerationConfig.from_pretrained( - self.model_id, revision=self.model_revision, trust_remote_code=True) - generation_config.update(**remote_config.to_dict()) - except: - logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.') - - if isinstance(self.model_id, str) and os.path.exists(self.model_id): - logger.warning(f'Got local model dir: {self.model_id}') - - if tokenizer.eos_token_id is not None: - generation_config.eos_token_id = tokenizer.eos_token_id - if tokenizer.pad_token_id is not None: - generation_config.pad_token_id = tokenizer.pad_token_id - if generation_config.max_new_tokens is None: - generation_config.max_new_tokens = 2048 - - return generation_config - - def _model_generate(self, query: str, infer_cfg: dict) -> str: - messages = [ChatMessage(role='user', content=query)] - formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - inputs = self.tokenizer(formatted_prompt, return_tensors='pt', padding=True).to(self.device) - input_ids = inputs['input_ids'] - - # Process infer_cfg - if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1: - infer_cfg['do_sample'] = True - - # stop settings - stop = infer_cfg.get('stop', None) - eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \ - if stop else self.tokenizer.eos_token_id - - if eos_token_id is not None: - infer_cfg['eos_token_id'] = eos_token_id - infer_cfg['pad_token_id'] = eos_token_id # setting eos_token_id as pad token - - self.generation_config.update(**infer_cfg) - fix_do_sample_warning(self.generation_config) - - # Run inference - output_ids = self.model.generate(**inputs, generation_config=self.generation_config) - - response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], skip_special_tokens=True) - return response - - @torch.no_grad() - def predict(self, inputs: Union[str, dict, list], infer_cfg: dict = {}) -> dict: - - # Process inputs - if isinstance(inputs, str): - query = inputs - elif isinstance(inputs, dict): - query = inputs['data'][0] - elif isinstance(inputs, list): - query = '\n'.join(inputs) - else: - raise TypeError(f'Unsupported inputs type: {type(inputs)}') - - response = self._model_generate(query, infer_cfg) - - choices_list = [{'index': 0, 'message': {'content': response, 'role': 'assistant'}}] - - res_d = { - 'choices': choices_list, - 'created': time.time(), - 'model': self.model_id, - 'object': 'chat.completion', - 'usage': {} - } - - return res_d - - -class CustomModelAdapter(BaseModelAdapter): - - def __init__(self, custom_model: CustomModel, **kwargs): - """ - Custom model adapter. - - Args: - custom_model: The custom model instance. - **kwargs: Other args. - """ - self.custom_model = custom_model - super(CustomModelAdapter, self).__init__(model=custom_model) - - def predict(self, inputs: Union[str, dict, list], **kwargs) -> List[Dict[str, Any]]: - """ - Model prediction func. - - Args: - inputs (Union[str, dict, list]): The input data. Depending on the specific model. - str: 'xxx' - dict: {'data': [full_prompt]} - list: ['xxx', 'yyy', 'zzz'] - **kwargs: kwargs - - Returns: - res (dict): The model prediction results. Format: - { - 'choices': [ - { - 'index': 0, - 'message': { - 'content': 'xxx', - 'role': 'assistant' - } - } - ], - 'created': 1677664795, - 'model': 'gpt-3.5-turbo-0613', # should be model_id - 'object': 'chat.completion', - 'usage': { - 'completion_tokens': 17, - 'prompt_tokens': 57, - 'total_tokens': 74 - } - } - """ - in_prompts = [] - - # Note: here we assume the inputs are all prompts for the benchmark. - for input_prompt in inputs: - if isinstance(input_prompt, str): - in_prompts.append(input_prompt) - elif isinstance(input_prompt, dict): - # TODO: to be supported for continuation list like truthful_qa - in_prompts.append(input_prompt['data'][0]) - elif isinstance(input_prompt, list): - in_prompts.append('\n'.join(input_prompt)) - else: - raise TypeError(f'Unsupported inputs type: {type(input_prompt)}') - - return self.custom_model.predict(prompts=in_prompts, **kwargs) diff --git a/evalscope/models/server_adapter.py b/evalscope/models/server_adapter.py new file mode 100644 index 0000000..4c75393 --- /dev/null +++ b/evalscope/models/server_adapter.py @@ -0,0 +1,80 @@ +import requests +import time +from typing import Union + +from evalscope.models.base_adapter import BaseModelAdapter +from evalscope.models.custom import CustomModel +from evalscope.models.local_model import LocalModel +from evalscope.utils.chat_service import ChatCompletionResponse + + +class ServerModelAdapter(BaseModelAdapter): + """ + Server model adapter to request remote API model and generate results. + """ + + def __init__(self, model: Union[LocalModel, CustomModel], api_url: str, **kwargs): + """ + Args: + model: The model instance. + api_url: The URL of the remote API model. + **kwargs: Other args. + """ + super().__init__(model, **kwargs) + self.api_url = api_url + + def predict(self, inputs: Union[str, dict, list], infer_cfg: dict = None) -> dict: + """ + Model prediction func. + + Args: + inputs (Union[str, dict, list]): The input data. + infer_cfg (dict): Inference configuration. + + Returns: + res (dict): The model prediction results. + """ + infer_cfg = infer_cfg or {} + + # Process inputs + if isinstance(inputs, str): + query = inputs + elif isinstance(inputs, dict): + # TODO: to be supported for continuation list like truthful_qa + query = inputs['data'][0] + elif isinstance(inputs, list): + query = '\n'.join(inputs) + else: + raise TypeError(f'Unsupported inputs type: {type(inputs)}') + + # Format request JSON according to OpenAI API format + request_json = { + 'model': self.model_id, + 'prompt': query, + 'max_tokens': infer_cfg.get('max_tokens', 2048), + 'temperature': infer_cfg.get('temperature', 1.0), + 'top_p': infer_cfg.get('top_p', 1.0), + 'n': infer_cfg.get('num_return_sequences', 1), + 'stop': infer_cfg.get('stop', None) + } + + # Request to remote API + response = requests.post(self.api_url, json=request_json) + response_data = response.json() + + choices_list = [{ + 'index': i, + 'message': { + 'content': choice['text'], + 'role': 'assistant' + } + } for i, choice in enumerate(response_data['choices'])] + + res_d = ChatCompletionResponse( + model=self.model_id, + choices=choices_list, + object='chat.completion', + created=int(time.time()), + usage=response_data.get('usage', None)).model_dump(exclude_unset=True) + + return res_d diff --git a/evalscope/run.py b/evalscope/run.py index e747329..069795d 100644 --- a/evalscope/run.py +++ b/evalscope/run.py @@ -120,18 +120,15 @@ def create_evaluator(task_cfg: TaskConfig, dataset_name: str, outputs: OutputsSt data_adapter=data_adapter, subset_list=benchmark.subset_list, model_adapter=model_adapter, - use_cache=task_cfg.use_cache, outputs=outputs, - datasets_dir=task_cfg.dataset_dir, - datasets_hub=task_cfg.dataset_hub, - stage=task_cfg.stage, - eval_type=task_cfg.eval_type, - overall_task_cfg=task_cfg, + task_cfg=task_cfg, ) def get_base_model(task_cfg: TaskConfig) -> Optional[LocalModel]: - """Get the base model for the task.""" + """Get the base local model for the task. If the task is not checkpoint-based, return None. + Avoids loading model multiple times for different datasets. + """ if task_cfg.eval_type != EvalType.CHECKPOINT: return None else: @@ -159,8 +156,11 @@ def initialize_model_adapter(task_cfg: TaskConfig, model_adapter_cls, base_model elif task_cfg.eval_type == EvalType.CUSTOM: if not isinstance(task_cfg.model, CustomModel): raise ValueError(f'Expected evalscope.models.custom.CustomModel, but got {type(task_cfg.model)}.') - from evalscope.models.model_adapter import CustomModelAdapter + from evalscope.models import CustomModelAdapter return CustomModelAdapter(custom_model=task_cfg.model) + elif task_cfg.eval_type == EvalType.SERVICE: + from evalscope.models import ServerModelAdapter + return ServerModelAdapter(url=task_cfg.model, model_id=task_cfg.model_id) else: return model_adapter_cls( model=base_model or get_base_model(task_cfg), diff --git a/evalscope/utils/chat_service.py b/evalscope/utils/chat_service.py index 6e4a4a7..6df4fd9 100644 --- a/evalscope/utils/chat_service.py +++ b/evalscope/utils/chat_service.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field from threading import Thread from transformers import TextIteratorStreamer -from typing import List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union class Usage(BaseModel): @@ -66,7 +66,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): class ChatCompletionResponse(BaseModel): model: str object: Literal['chat.completion', 'chat.completion.chunk'] - choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] + choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, Any]] created: Optional[int] = Field(default_factory=lambda: int(time.time())) usage: Optional[Usage]