add GenAI_HFLM class to support microservice.
lkk12014402 committed May 22, 2024
1 parent 924d47d commit 6131a3c
Showing 1 changed file with 312 additions and 0 deletions.
@@ -36,6 +36,11 @@
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from lm_eval.api.registry import register_model
from lm_eval.api.model import CacheHook
import requests as requests_obj
from requests.exceptions import RequestException
import json

eval_logger = utils.eval_logger

@@ -1217,3 +1222,310 @@ def _model_call(self, inps):
logits = logits[:, :-padding_length, :]
logits = logits.to(torch.float32)
return logits


@register_model("genai-hf")
class GenAI_HFLM(HFLM):
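    """HFLM-compatible wrapper that scores requests against a remote GenAI
    completion microservice instead of a locally loaded model. Only the
    tokenizer is instantiated on the client side; batched, padded token ids
    are POSTed to `{base_url}/v1/completions`, and the returned greedy tokens
    and log-probabilities are used to compute loglikelihoods.
    """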
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

def __init__(
self,
base_url=None,
logits_cache: bool = True,
tokenizer: Optional[str] = None,
revision: Optional[str] = "main",
batch_size: int = 1,
max_length: Optional[int] = None,
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
add_bos_token: Optional[bool] = False,
prefix_token_id: Optional[int] = None,
**kwargs):
        self.base_url = base_url
        assert self.base_url, "must pass `base_url` to use the GenAI service!"
self._rank = 0
self._world_size = 1

self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
)

self.logits_cache = logits_cache
# select (or create) a pad token to use
if self.tokenizer.pad_token:
pass
elif self.tokenizer.unk_token:
self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
elif self.tokenizer.eos_token:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
else:
            if self.tokenizer.__class__.__name__ == "QWenTokenizer":
                # no model config is available here (the model is served remotely), so
                # detect Qwen via its tokenizer class; Qwen's trust_remote_code
                # tokenizer does not allow for adding special tokens
                self.tokenizer.pad_token = "<|endoftext|>"
elif (
self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
):
                # The RWKV world tokenizer does not allow adding special tokens / setting the pad token (which is set as 0)
                # The additional tokenizer name check is needed, as there exist rwkv4 models with a neox tokenizer
                # ---
                # Note that the world tokenizer class name might change in the future for the final huggingface merge
                # https://github.com/huggingface/transformers/pull/26963
assert self.tokenizer.pad_token_id == 0
else:
self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

# TODO: override this for Gemma
self.add_bos_token = add_bos_token
if "GemmaTokenizer" in self.tokenizer.__class__.__name__:
self.add_bos_token = True
            eval_logger.info(
                f"Tokenizer is '{self.tokenizer.__class__.__name__}', a BOS token will be used as Gemma underperforms without it."
            )

self._batch_size = int(batch_size)
self._max_length = max_length
self.custom_prefix_token_id = prefix_token_id
if prefix_token_id is not None:
eval_logger.info(
f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
)
self.cache_hook = CacheHook(None)
self.headers = {"Content-Type": "application/json"}

@property
def max_length(self) -> int:
if self._max_length:
return self._max_length
else:
return self._DEFAULT_MAX_LENGTH

@property
def batch_size(self) -> int:
return self._batch_size

def _loglikelihood_tokens(
self,
task_requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
disable_tqdm: bool = False,
override_bs: int = None,
) -> List[Tuple[float, bool]]:
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []

def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end

toks = req[1] + req[2]
return -len(toks), tuple(toks)

def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group.
return req[-2] + req[-1][:-1]

re_ord = Collator(
task_requests,
sort_fn=_collate,
group_by=None,
group_fn=_lookup_one_token_cont,
)

# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
n_reordered_requests = len(re_ord)
batch_size = (
self.batch_size
if self.batch_size != "auto"
else override_bs
if override_bs is not None
else 0
)
batch_fn = (
self._batch_scheduler
if self.batch_size == "auto"
and n_reordered_requests > 0
and not override_bs
else None
)

chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
pbar = tqdm(
total=len(task_requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running loglikelihood requests",
)
for chunk in chunks:
inps = []
cont_toks_list = []
inplens = []

conts = []
encoder_attns = []

padding_len_inp = None
padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair
            # to padded tensors, then pack them together into a batch, call the model,
            # and pick it all apart again afterwards

for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length

# how this all works (illustrated on a causal decoder-only setup):
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

# when too long to fit in context, truncate from the left
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :],
dtype=torch.long,
)
(inplen,) = inp.shape
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
)
(inplen,) = inp.shape

# build encoder attn masks
encoder_attns.append(torch.ones_like(inp))

cont = torch.tensor(
(continuation_enc)[-self.max_length :],
# TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long,
)
(contlen,) = cont.shape

conts.append(cont)

padding_len_cont = (
max(padding_len_cont, contlen)
if padding_len_cont is not None
else contlen
)

padding_len_inp = (
max(padding_len_inp, inplen)
if padding_len_inp is not None
else inplen
)

inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
batched_inps = pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: left-pad encoder inps and mask?
batched_inps = pad_and_concat(
padding_len_inp, inps
) # [batch, padding_len_inp]
batched_conts = pad_and_concat(
padding_len_cont, conts
) # [batch, padding_len_cont]
batched_encoder_mask = pad_and_concat(
padding_len_inp, encoder_attns
) # [batch, padding_len_inp]
call_kwargs = {
"attn_mask": batched_encoder_mask,
"labels": batched_conts,
}

data = {
"batched_inputs": batched_inps.tolist(),
}
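            # NOTE (assumption about the service contract): the GenAI endpoint at
            # `{self.base_url}/v1/completions` is expected to return, for every row of
            # "batched_inputs", the greedy (argmax) token ids under "greedy_tokens" and
            # per-position log-probabilities under "logprobs", aligned with the padded
            # inputs; that is how the response is unpacked below.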
try:
response = requests_obj.post(
f"{self.base_url}/v1/completions",
headers=self.headers,
data=json.dumps(data),
)
response.raise_for_status()
response = response.json()
            except RequestException as e:
                eval_logger.error(f"RequestException: {e}")
                raise

            for (request_str, ctx_tokens, _), greedy_tokens, logprobs, inplen, cont_toks in zip(
                chunk, response["greedy_tokens"], response["logprobs"], inplens, cont_toks_list
            ):
# Slice to original seq length
contlen = len(cont_toks)
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
# also discards + checks for "virtual tokens" in the causal LM's input window
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (len(logprobs) - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
else None
)
cont_toks = torch.tensor(
cont_toks, dtype=torch.long
).unsqueeze(0) # [1, seq]
greedy_tokens = torch.tensor(
self._select_cont_toks(greedy_tokens, contlen=contlen, inplen=ctx_len),
dtype=torch.long
).unsqueeze(0) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
cont_logprobs = self._select_cont_toks(logprobs, contlen=contlen, inplen=ctx_len)

# Answer: (log prob, is-exact-match)
answer = (sum(cont_logprobs), bool(max_equal))

res.append(answer)

self.cache_hook.add_partial("loglikelihood", request_str, answer)
pbar.update(1)

pbar.close()

return re_ord.get_original(res)

def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()

def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override generate_until
raise NotImplementedError()

@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()

def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError(
"loglikelihood_rolling not yet supported for GenAI service"
)

def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
raise NotImplementedError("Not supported yet.")
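For context, a minimal usage sketch (not part of this commit): once a compatible GenAI completion service is reachable, the registered "genai-hf" model can be driven through lm_eval's standard Python entry point. The base_url and tokenizer values below are placeholders, not values defined by this change.

import lm_eval

# Placeholder endpoint and tokenizer; substitute the address of a running
# GenAI completion service and the tokenizer that matches the served model.
results = lm_eval.simple_evaluate(
    model="genai-hf",
    model_args="base_url=http://localhost:8080,tokenizer=Intel/neural-chat-7b-v3-3",
    tasks=["lambada_openai"],
    batch_size=1,
)
print(results["results"])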
