Updated T0 and GPT-J
tttyuntian committed Apr 28, 2022
1 parent b62d1be commit 29bff88
Showing 2 changed files with 51 additions and 18 deletions.
35 changes: 25 additions & 10 deletions lm_eval/models/gptj.py
@@ -8,6 +8,7 @@ def __init__(
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()
@@ -35,9 +36,11 @@ def __init__(
         self.batch_size_per_gpu = batch_size # todo: adaptive batch size

         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gptj = nn.DataParallel(self.gptj)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)

     @property
     def eot_token(self):
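
For context on the hunk above: parallelize() is the transformers model-parallelism API for GPT-J (since deprecated in favor of device_map-based loading). It shards the model's blocks across all visible GPUs, and inputs must then live on cuda:0, which is why _device is pinned there. A minimal usage sketch follows; the class name GPTJ is an assumption, as the diff does not show it.

# Hypothetical usage of the new flag (class name GPTJ is assumed, not shown
# in this diff; the constructor arguments match the signature above).
from lm_eval.models.gptj import GPTJ

lm = GPTJ(device="cuda", batch_size=1)               # weights on one GPU, as before
lm_parallel = GPTJ(device="cuda", parallelize=True)  # blocks sharded across GPUs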
@@ -113,11 +116,23 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
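
The branch above picks a stopping strategy by shot count: zero-shot prompts stop on the model's EOS token, while few-shot prompts use the custom stopping criteria, since a few-shot continuation should stop at a prompt delimiter that EOS alone would not catch. Below is a minimal sketch of the two paths against the Hugging Face generate API; gpt2 stands in for GPT-J only to keep the example small, and the newline delimiter is an assumption.

# Illustrative sketch, not part of this commit. "gpt2" stands in for
# EleutherAI/gpt-j-6B to keep the example light; the stopping logic is the point.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
context = tok("Q: What is 2 + 2?\nA:", return_tensors="pt").input_ids

# Zero-shot path: stop as soon as the model emits its EOS token.
zero_shot = model.generate(
    context, max_length=32, eos_token_id=tok.eos_token_id, do_sample=False
)

# Few-shot path: stop on a task delimiter instead (assumed here: a newline),
# since a few-shot continuation rarely ends with a natural EOS.
class StopOnToken(StoppingCriteria):
    def __init__(self, stop_id):
        self.stop_id = stop_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1].item() == self.stop_id

newline_id = tok("\n").input_ids[0]
few_shot = model.generate(
    context,
    max_length=32,
    stopping_criteria=StoppingCriteriaList([StopOnToken(newline_id)]),
    do_sample=False,
)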
34 changes: 26 additions & 8 deletions lm_eval/models/t0.py
@@ -56,7 +56,7 @@ def max_length(self):

     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):
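
The hunk above replaces the tokenizer's model_max_length, which describes the model's input context limit rather than a generation budget, with a fixed cap of 256 generated tokens. A quick sketch to inspect the old value, assuming T0's tokenizer reports the usual T5-family limit:

# Sketch (not part of the commit): inspect the value max_gen_toks used to return.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigscience/T0_3B")
print(tok.model_max_length)  # typically 512 for T5-family tokenizers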
@@ -94,6 +94,14 @@ def loglikelihood(self, requests):

                 inputs, targets = zip(*chunk)

+                # Fill in empty encoder inputs with eos_token
+                inputs = (
+                    f"{self.eot_token}"
+                    if len(input_) == 0
+                    else input_
+                    for input_ in inputs
+                )
+
                 inputs_tok = self.tokenizer(
                     list(inputs),
                     max_length=self.max_length,
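
The fill-in above guards the encoder side: T0 is an encoder-decoder model, and an empty encoder input leaves nothing to attend over, so empty strings are replaced with the end-of-text token before tokenization. A standalone sketch of the pattern, where "</s>" stands in for self.eot_token and is an assumption:

# Standalone sketch of the fill-in (not part of the commit). "</s>" is
# assumed as the eot_token; any empty input is replaced before tokenization.
eot_token = "</s>"
inputs = ["Review: great film. Positive or negative?", "", "Summarize: ..."]
inputs = (eot_token if len(input_) == 0 else input_ for input_ in inputs)
print(list(inputs))
# ['Review: great film. Positive or negative?', '</s>', 'Summarize: ...']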
@@ -172,11 +180,21 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
