diff --git a/GenAIEval/evaluation/lm_evaluation_harness/lm_eval/models/huggingface.py b/GenAIEval/evaluation/lm_evaluation_harness/lm_eval/models/huggingface.py
index 38f5d095..10554b23 100644
--- a/GenAIEval/evaluation/lm_evaluation_harness/lm_eval/models/huggingface.py
+++ b/GenAIEval/evaluation/lm_evaluation_harness/lm_eval/models/huggingface.py
@@ -36,6 +36,11 @@
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
 )
+from lm_eval.api.registry import register_model
+from lm_eval.api.model import CacheHook
+import requests as requests_obj
+from requests.exceptions import RequestException
+import json
 
 eval_logger = utils.eval_logger
 
@@ -1217,3 +1222,310 @@ def _model_call(self, inps):
         logits = logits[:, :-padding_length, :]
         logits = logits.to(torch.float32)
         return logits
+
+
+@register_model("genai-hf")
+class GenAI_HFLM(HFLM):
+    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+
+    def __init__(
+        self,
+        base_url=None,
+        logits_cache: bool = True,
+        tokenizer: Optional[str] = None,
+        revision: Optional[str] = "main",
+        batch_size: int = 1,
+        max_length: Optional[int] = None,
+        trust_remote_code: Optional[bool] = False,
+        use_fast_tokenizer: Optional[bool] = True,
+        add_bos_token: Optional[bool] = False,
+        prefix_token_id: Optional[int] = None,
+        **kwargs,
+    ):
+        self.base_url = base_url
+        assert self.base_url, "must pass `base_url` to use GenAI service!"
+        self._rank = 0
+        self._world_size = 1
+        # no local model is loaded; only the tokenizer and the remote service are used,
+        # so keep the inherited `config` property safe to query
+        self._config = None
+
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            tokenizer,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            use_fast=use_fast_tokenizer,
+        )
+
+        self.logits_cache = logits_cache
+        # select (or create) a pad token to use
+        if self.tokenizer.pad_token:
+            pass
+        elif self.tokenizer.unk_token:
+            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
+        elif self.tokenizer.eos_token:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        else:
+            if getattr(self.config, "model_type", None) == "qwen":
+                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
+                self.tokenizer.pad_token = "<|endoftext|>"
+            elif (
+                self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
+                or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
+            ):
+                # The RWKV world tokenizer does not allow for adding special tokens / setting the pad token (which is set as 0).
+                # The additional tokenizer name check is needed, as there exist rwkv4 models with a neox tokenizer.
+                # ---
+                # Note that the world tokenizer class name might change in the future for the final huggingface merge
+                # https://github.com/huggingface/transformers/pull/26963
+                assert self.tokenizer.pad_token_id == 0
+            else:
+                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+
+        # TODO: override this for Gemma
+        self.add_bos_token = add_bos_token
+        if "GemmaTokenizer" in self.tokenizer.__class__.__name__:
+            self.add_bos_token = True
+            eval_logger.info(
+                f"Tokenizer class is '{self.tokenizer.__class__.__name__}', a BOS token will be used as Gemma underperforms without it."
+            )
+
+        self._batch_size = int(batch_size)
+        self._max_length = max_length
+        self.custom_prefix_token_id = prefix_token_id
+        if prefix_token_id is not None:
+            eval_logger.info(
+                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
+            )
+        self.cache_hook = CacheHook(None)
+        self.headers = {"Content-Type": "application/json"}
+
+    @property
+    def max_length(self) -> int:
+        if self._max_length:
+            return self._max_length
+        else:
+            return self._DEFAULT_MAX_LENGTH
+
+    @property
+    def batch_size(self) -> int:
+        return self._batch_size
+
+    def _loglikelihood_tokens(
+        self,
+        task_requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
+        override_bs: int = None,
+    ) -> List[Tuple[float, bool]]:
+        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
+        res = []
+
+        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+            """Defines the key for the sorted method"""
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+
+            toks = req[1] + req[2]
+            return -len(toks), tuple(toks)
+
+        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+            """Defines the key to group and lookup one-token continuations"""
+            # Use with group_by="contexts" (optional)
+            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
+            # speeds up some multiple-choice tasks proportionally to the number of choices.
+            # groups requests by context+continuation[:-1] and infer on one request/group.
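+            # req[-2] is context_enc and req[-1] is continuation_enc, so requests that share a
+            # context plus all-but-the-last continuation token land in the same lookup group.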
+            return req[-2] + req[-1][:-1]
+
+        re_ord = Collator(
+            task_requests,
+            sort_fn=_collate,
+            group_by=None,
+            group_fn=_lookup_one_token_cont,
+        )
+
+        # automatic (variable) batch size detection for vectorization
+        # pull longest context sample from request
+        n_reordered_requests = len(re_ord)
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto"
+            and n_reordered_requests > 0
+            and not override_bs
+            else None
+        )
+
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
+        pbar = tqdm(
+            total=len(task_requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running loglikelihood requests",
+        )
+        for chunk in chunks:
+            inps = []
+            cont_toks_list = []
+            inplens = []
+
+            conts = []
+            encoder_attns = []
+
+            padding_len_inp = None
+            padding_len_cont = None
+            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
+            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
+            # again because vectorizing is annoying
+
+            for _, context_enc, continuation_enc in chunk:
+                # sanity check
+                assert len(context_enc) > 0
+                assert len(continuation_enc) > 0
+                assert len(continuation_enc) <= self.max_length
+
+                # how this all works (illustrated on a causal decoder-only setup):
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # model  \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
+                # cont_toks    4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
+
+                # when too long to fit in context, truncate from the left
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    inp = torch.tensor(
+                        (context_enc + continuation_enc)[-(self.max_length + 1) :],
+                        dtype=torch.long,
+                    )
+                    (inplen,) = inp.shape
+                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                    inp = torch.tensor(
+                        (context_enc)[-self.max_length :],
+                        dtype=torch.long,
+                    )
+                    (inplen,) = inp.shape
+
+                    # build encoder attn masks
+                    encoder_attns.append(torch.ones_like(inp))
+
+                    cont = torch.tensor(
+                        (continuation_enc)[-self.max_length :],
+                        # TODO: left-shift these?
+                        # TODO: our code assumes we never end up truncating conts for either model type
+                        dtype=torch.long,
+                    )
+                    (contlen,) = cont.shape
+
+                    conts.append(cont)
+
+                    padding_len_cont = (
+                        max(padding_len_cont, contlen)
+                        if padding_len_cont is not None
+                        else contlen
+                    )
+
+                padding_len_inp = (
+                    max(padding_len_inp, inplen)
+                    if padding_len_inp is not None
+                    else inplen
+                )
+
+                inps.append(inp)  # [1, inp_length]
+                cont_toks_list.append(continuation_enc)
+                inplens.append(inplen)
+
+            # create encoder attn mask and batched conts, if seq2seq
+            call_kwargs = {}
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                batched_inps = pad_and_concat(
+                    padding_len_inp, inps, padding_side="right"
+                )  # [batch, padding_len_inp]
+            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                # TODO: left-pad encoder inps and mask?
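+                # note: GenAI_HFLM pins AUTO_MODEL_CLASS to AutoModelForCausalLM, so this
+                # seq2seq branch is kept only for parity with HFLM._loglikelihood_tokens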
+                batched_inps = pad_and_concat(
+                    padding_len_inp, inps
+                )  # [batch, padding_len_inp]
+                batched_conts = pad_and_concat(
+                    padding_len_cont, conts
+                )  # [batch, padding_len_cont]
+                batched_encoder_mask = pad_and_concat(
+                    padding_len_inp, encoder_attns
+                )  # [batch, padding_len_inp]
+                call_kwargs = {
+                    "attn_mask": batched_encoder_mask,
+                    "labels": batched_conts,
+                }
+
+            data = {
+                "batched_inputs": batched_inps.tolist(),
+            }
+            try:
+                response = requests_obj.post(
+                    f"{self.base_url}/v1/completions",
+                    headers=self.headers,
+                    data=json.dumps(data),
+                )
+                response.raise_for_status()
+                response = response.json()
+            except RequestException as e:
+                eval_logger.error(f"RequestException: {e}")
+                # without a response the results below cannot be computed
+                raise
+
+            for (request_str, ctx_tokens, _), greedy_tokens, logprobs, inplen, cont_toks in zip(
+                chunk, response["greedy_tokens"], response["logprobs"], inplens, cont_toks_list
+            ):
+                # Slice to original seq length
+                contlen = len(cont_toks)
+                # take only logits in the continuation
+                # (discard context toks if decoder-only; discard right-padding)
+                # also discards + checks for "virtual tokens" in the causal LM's input window
+                # from prompt/prefix tuning tokens, if applicable
+                ctx_len = (
+                    inplen + (len(logprobs) - padding_len_inp)
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    else None
+                )
+                cont_toks = torch.tensor(
+                    cont_toks, dtype=torch.long
+                ).unsqueeze(0)  # [1, seq]
+                greedy_tokens = torch.tensor(
+                    self._select_cont_toks(greedy_tokens, contlen=contlen, inplen=ctx_len),
+                    dtype=torch.long,
+                ).unsqueeze(0)  # [1, seq]
+                max_equal = (greedy_tokens == cont_toks).all()
+                cont_logprobs = self._select_cont_toks(logprobs, contlen=contlen, inplen=ctx_len)
+
+                # Answer: (log prob, is-exact-match)
+                answer = (sum(cont_logprobs), bool(max_equal))
+
+                res.append(answer)
+
+                self.cache_hook.add_partial("loglikelihood", request_str, answer)
+                pbar.update(1)
+
+        pbar.close()
+
+        return re_ord.get_original(res)
+
+    def _model_call(self, inps):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        # Isn't used because we override generate_until
+        raise NotImplementedError()
+
+    @property
+    def device(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
+        raise NotImplementedError(
+            "loglikelihood_rolling not yet supported for GenAI service"
+        )
+
+    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
+        raise NotImplementedError("Not supported yet.")
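
Usage sketch (illustrative, not part of the patch): with this wrapper registered as "genai-hf", an evaluation can be driven through lm-eval's Python API roughly as below. The endpoint URL, tokenizer name, and task are placeholder assumptions; the service behind base_url is assumed to expose POST /v1/completions and to return per-sample "greedy_tokens" and "logprobs" for the submitted "batched_inputs", which is what _loglikelihood_tokens consumes.

    from lm_eval import evaluator

    results = evaluator.simple_evaluate(
        model="genai-hf",  # registered above via @register_model("genai-hf")
        model_args=(
            "base_url=http://localhost:8080,"          # hypothetical service endpoint
            "tokenizer=meta-llama/Llama-2-7b-hf,"      # hypothetical tokenizer repo
            "batch_size=4"
        ),
        # any loglikelihood-style task; generation and rolling loglikelihood are not supported yet
        tasks=["lambada_openai"],
    )
    print(results["results"])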