
Commit

Merge pull request #98 from OpenGPTX/megatronlm_client_fixes
Megatronlm client fixes
KlaudiaTH authored Nov 5, 2023
2 parents c6579ae + 9740758 commit b05c53d
Showing 3 changed files with 16 additions and 6 deletions.
1 change: 1 addition & 0 deletions lm_eval/base.py
@@ -179,6 +179,7 @@ def loglikelihood(self, requests):
             context_enc = self.tok_encode(context)
 
             continuation_enc = self.tok_encode(continuation)
+            # continuation_enc = self.tok_encode(continuation, is_continuation=True)
 
             new_reqs.append(((context, continuation), context_enc, continuation_enc))

2 changes: 1 addition & 1 deletion lm_eval/models/__init__.py
@@ -14,7 +14,7 @@
     "gpt2": gpt2.GPT2LM,
     "gpt3": gpt3.GPT3LM,
     "anthropic": anthropic_llms.AnthropicLM,
-    "megatronlm": megatronlm.MegatronServerLM,
+    "megatronlm": megatronlm.MegatronLMClient,
     "textsynth": textsynth.TextSynthLM,
     "dummy": dummy.DummyLM,
 }

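For context on the change above: the registry maps the CLI model name to the implementing class, so "megatronlm" now resolves to the renamed MegatronLMClient. A minimal usage sketch, assuming this fork keeps the upstream get_model() helper; the server_url and model_name values are hypothetical, not taken from this PR:

from lm_eval import models

# "megatronlm" now resolves to the renamed MegatronLMClient class.
lm_class = models.get_model("megatronlm")
lm = lm_class(
    server_url="http://localhost:5000",  # hypothetical server address
    model_name="my-megatron-model",      # hypothetical model identifier
)
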
19 changes: 14 additions & 5 deletions lm_eval/models/megatronlm.py
@@ -46,7 +46,7 @@ def get_result(logprobs, is_max_logprobs, ctxlen):
     return sum(continuation_logprobs), is_greedy
 
 
-class MegatronServerLM(BaseLM):
+class MegatronLMClient(BaseLM):
     def __init__(self, server_url, model_name, batch_size=20, truncate=False):
         """
@@ -72,7 +72,7 @@ def __init__(self, server_url, model_name, batch_size=20, truncate=False):
         meglm_metadata = self.megatron_metadata()
         self.vocab_size = meglm_metadata["vocab_size"]
         self._eod_token_id = meglm_metadata["eod_token_id"]
-        self._max_length = meglm_metadata["max_length"] - self.max_gen_toks
+        self._max_length = meglm_metadata["max_length"]
 
     @property
     def eot_token_id(self):
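
Sketched with illustrative numbers (2048 and 256 are assumptions, not values from this PR), the __init__ change above stops reserving max_gen_toks inside self._max_length; the generation path subtracts it itself (see the last hunk below), so its prompt budget stays the same while loglikelihood scoring sees the full window:

# Illustrative numbers only (not taken from the PR).
server_max_length = 2048
max_gen_toks = 256

# Before this commit: max_gen_toks was already subtracted in __init__.
old_max_length = server_max_length - max_gen_toks         # 1792

# After this commit: the full window is exposed; generation subtracts
# max_gen_toks where it builds its inputs, so its prompt budget is unchanged.
new_max_length = server_max_length                         # 2048
generation_prompt_budget = new_max_length - max_gen_toks   # still 1792
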
@@ -176,13 +176,17 @@ def _collate(x):
             inps = []
             ctxlens = []
             for cache_key, context_enc, continuation_enc in chunk:
-                # max_length+1 because the API takes up to 2049 tokens, including the first context token
-                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
-                # TODO: the logic is much simpler if we just look at the length of continuation tokens
+                inp = (context_enc + continuation_enc)[-self.max_length :]
+
                 ctxlen = len(context_enc) - max(
                     0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
                 )
 
+                if len(context_enc) + len(continuation_enc) > len(inp):
+                    print(
+                        f"WARNING: Length of concatenated context ...{repr(cache_key[0][-20:])} and continuation {repr(cache_key[1])} exceeds max length {self.max_length + 1}"
+                    )
+
                 inps.append(inp)
                 ctxlens.append(ctxlen)
 
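A standalone sketch of the revised loglikelihood truncation above, using made-up token ids and a made-up max_length (none of these values come from the PR):

# Standalone sketch of the new loglikelihood truncation (illustrative values).
max_length = 8
context_enc = list(range(6))            # 6 hypothetical context token ids
continuation_enc = list(range(6, 11))   # 5 hypothetical continuation token ids

inp = (context_enc + continuation_enc)[-max_length:]  # now keeps max_length tokens (no +1)
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - (max_length + 1)
)

# The new warning fires whenever the request had to be truncated.
if len(context_enc) + len(continuation_enc) > len(inp):
    print("WARNING: request truncated")

print(inp)     # [3, 4, 5, 6, 7, 8, 9, 10]
print(ctxlen)  # 6 - max(0, 11 - 9) = 4
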
@@ -247,6 +251,11 @@ def sameuntil_chunks(xs, size):
                 inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)
 
+                if len(context_enc) > len(inp):
+                    print(
+                        f"WARNING: Length of context ...{repr(context[-20:])} exceeds max length {self.max_length - self.max_gen_toks}"
+                    )
+
             if isinstance(until, str):
                 until = [until]
 
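Similarly, a rough sketch of the greedy_until prompt truncation that the new warning above guards; the numbers are illustrative, not from this PR:

# Standalone sketch of the generation-side prompt truncation (illustrative values).
max_length = 2048
max_gen_toks = 256
context_enc = list(range(3000))  # hypothetical over-long prompt

inp = context_enc[-(max_length - max_gen_toks):]  # keep the last 1792 tokens

if len(context_enc) > len(inp):
    print("WARNING: prompt truncated to max_length - max_gen_toks tokens")
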
