Skip to content

Commit

Permalink
Update neuron backend (EleutherAI#2314)
Browse files Browse the repository at this point in the history
* feat(neuron): align with latest optimum-neuron

* feat(neuron): support pre-exported neuron models

* fix(neuron): correctly use max_length

* fix(neuron): adapt loglikelihood

The evaluation of log likelihood was not working for neuron models
using continuous batching, such as all cached neuron LLama models.

* refactor(neuron): remove dead code
  • Loading branch information
dacorvo authored Sep 18, 2024
1 parent 88ea85b commit 9a092f3
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 179 deletions.
241 changes: 88 additions & 153 deletions lm_eval/models/neuron_optimum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import copy
import json
import logging
import subprocess
from collections import defaultdict
from typing import List, Optional, Union

Expand Down Expand Up @@ -33,54 +31,6 @@
logger = logging.getLogger(__name__)


def get_nc_count() -> Union[int, None]:
"""Returns the number of neuron cores on the current instance."""
try:
cmd = "neuron-ls --json-output"
result = subprocess.run(cmd, shell=True, capture_output=True)
print(f"inferring nc_count from `neuron-ls` {result.stdout}")
json_output = json.loads(result.stdout)
count = sum([x["nc_count"] for x in json_output])
print(f"nc_count={count}")
return count
except Exception:
return None


def wrap_constant_batch_size(func):
def _decorator(self, input_ids):
"""input_ids a 2D array with batch_size on dim=0
makes sure the func runs with self.batch_size
"""
# access a from TestSample
batch_size = input_ids.shape[0]

if batch_size < self.batch_size:
# handle the event of input_ids.shape[0] != batch_size
# Neuron cores expect constant batch_size
input_ids = torch.concat(
(
input_ids,
# add missing_batch_size dummy
torch.zeros(
[self.batch_size - batch_size, *input_ids.size()[1:]],
dtype=input_ids.dtype,
device=input_ids.device,
),
),
dim=0,
)
elif batch_size > self.batch_size:
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
)
# return the forward pass that requires constant batch size
return func(self, input_ids)[:batch_size]

return _decorator


class CustomNeuronModelForCausalLM(NeuronModelForCausalLM):
"""NeuronModelForCausalLM with `stopping_criteria` in `generate`"""

Expand Down Expand Up @@ -146,7 +96,7 @@ def generate(
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
)
elif batch_size < self.batch_size:
elif batch_size < self.batch_size and not self.continuous_batching:
logger.warning(
"Inputs will be padded to match the model static batch size. This will increase latency."
)
Expand All @@ -158,8 +108,6 @@ def generate(
if attention_mask is not None:
padding = torch.zeros(padding_shape, dtype=torch.int64)
padded_attention_mask = torch.cat([attention_mask, padding])
# Drop the current generation context and clear the Key/Value cache
self.reset_generation()

output_ids = self.generate_tokens(
padded_input_ids,
Expand All @@ -179,8 +127,6 @@ class NEURON_HF(TemplateLM):
Tested with neuron 2.17.0
"""

_DEFAULT_MAX_LENGTH = 2048

def __init__(
self,
pretrained: Optional[str] = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Expand All @@ -203,7 +149,7 @@ def __init__(
"please install neuron via pip install transformers-neuron ",
"also make sure you are running on an AWS inf2 instance",
)
if version.parse(optimum_neuron_version) != version.parse("0.0.17"):
if version.parse(optimum_neuron_version) != version.parse("0.0.24"):
logger.warning(
'`optimum-neuron` model requires `pip install "optimum[neuronx]>=0.0.17" '
"preferably using the Hugging Face Neuron Deep Learning AMI (Ubuntu 22.04) "
Expand All @@ -217,35 +163,16 @@ def __init__(

self.batch_size_per_gpu = int(batch_size)
batch_size = int(batch_size)
if tp_degree is None:
# execute `neuron-ls --json-output | jq '.[0].nc_count'``
# to get the number of neuron cores on your instance
tp_degree = get_nc_count()

assert isinstance(tp_degree, int), (
f"model_args must include tp_degree. tp_degree must be set to an integer,"
f" but is tp_degree=`{tp_degree}` with type=`{type(tp_degree)}`."
"Set it to number of neuron cores on your instance."
" For inf2.xlarge and inf2.8xlarge, set it to `2`."
" For inf2.24xlarge, set it to `12`."
" For inf2.48xlarge, set it to `24`."
)

revision = str(revision) # cast to string if not already one
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")

self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
)
torch_dtype = lm_eval.models.utils.get_dtype(dtype)

assert torch_dtype in [
torch.float16,
torch.bfloat16,
], "Only float16 and bfloat16 are supported"
revision = str(revision) # cast to string if not already one
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")

self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
Expand All @@ -254,45 +181,65 @@ def __init__(
use_fast=use_fast_tokenizer,
)

# Neuron specific code
if torch_dtype == torch.float16:
self.amp_dtype = "f16"
elif torch_dtype == torch.bfloat16:
self.amp_dtype = "bf16"
elif torch_dtype == torch.float32:
self.amp_dtype = "f32"
else:
raise NotImplementedError("Only float16 and bfloat16 are implemented.")

compiler_args = {"num_cores": tp_degree, "auto_cast_type": self.amp_dtype}
input_shapes = {
"batch_size": batch_size,
"sequence_length": self._DEFAULT_MAX_LENGTH,
}
neuron_config = getattr(self._config, "neuron", None)
if neuron_config is None:
# Check export parameters
if tp_degree is not None:
assert isinstance(tp_degree, int), (
f"tp_degree must be set to an integer,"
f" but is tp_degree=`{tp_degree}` with type=`{type(tp_degree)}`."
"Set it to a number lower than the number of neuron cores on your instance."
" For inf2.xlarge and inf2.8xlarge, set it to `2`."
" For inf2.24xlarge, set it <= `12`."
" For inf2.48xlarge, set it <= `24`."
)
torch_dtype = lm_eval.models.utils.get_dtype(dtype)

if torch_dtype == torch.float16:
self.amp_dtype = "f16"
elif torch_dtype == torch.bfloat16:
self.amp_dtype = "bf16"
elif torch_dtype == torch.float32:
self.amp_dtype = "f32"
else:
raise NotImplementedError(
"Only float16/bfloat16/float32 are supported."
)

print(
f"{'='*20} \n loading model to neuron with"
f" {compiler_args}, {input_shapes}..."
)
self.model = CustomNeuronModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=low_cpu_mem_usage,
export=True,
**compiler_args,
**input_shapes,
)
print(f"SUCCESS: neuron model compiled. \n {'='*20}")
print(f"{'='*20} \n exporting model to neuron")
self.model = CustomNeuronModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=low_cpu_mem_usage,
export=True,
batch_size=batch_size,
num_cores=tp_degree,
auto_cast_type=self.amp_dtype,
sequence_length=max_length,
)
neuron_config = self.model.config.neuron
print(
f"SUCCESS: neuron model exported with config {neuron_config}. \n {'='*20}"
)
else:
print(
f"{'='*20} \n loading neuron model with config" f" {neuron_config}..."
)
self.model = CustomNeuronModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=low_cpu_mem_usage,
)
print(f"SUCCESS: neuron model loaded. \n {'='*20}")

self.truncation = truncation

self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.add_bos_token = add_bos_token

self._max_length = max_length

self.batch_schedule = 1
self.batch_sizes = {}

Expand All @@ -313,17 +260,7 @@ def prefix_token_id(self):

@property
def max_length(self):
if self._max_length: # if max length manually set, return it
return self._max_length
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self.model.config, attr):
return getattr(self.model.config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
return self.model.max_length

@property
def max_gen_toks(self) -> int:
Expand Down Expand Up @@ -391,34 +328,6 @@ def tok_batch_encode(
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)

@wrap_constant_batch_size
def _model_call(self, input_ids: torch.Tensor):
"""
get logits for the entire sequence
:param input_ids: torch.Tensor
A torch tensor of shape [batch, sequence_cont]
the size of sequence may vary from call to call
:return
A torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model's decoder-lm head
"""
_, sequence_length = input_ids.shape

with torch.inference_mode():
cache_ids = torch.arange(0, sequence_length, dtype=torch.int32).split(1)
input_ids_split = input_ids.split(1, dim=1)

return torch.concat(
[
self.model.forward(
input_ids=input_id, cache_ids=cache_id, return_dict=False
)[0]
for input_id, cache_id in zip(input_ids_split, cache_ids)
],
dim=1,
)

def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
Expand Down Expand Up @@ -580,15 +489,41 @@ def _collate(x):
cont_toks_list.append(continuation_enc)
inplens.append(inplen)

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
# Add dummy inputs up to the model static batch size
if len(inps) < self.batch_size:
inps = inps + [
torch.zeros_like(inps[0]),
] * (self.batch_size - len(inps))

masks = [torch.ones_like(inp) for inp in inps]
batched_inps = lm_eval.models.utils.pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]

multi_logits = F.log_softmax(
self._model_call(batched_inps, **call_kwargs), dim=-1
) # [batch, padding_length (inp or cont), vocab]
batched_masks = lm_eval.models.utils.pad_and_concat(
padding_len_inp, masks, padding_side="right"
)
if self.model.model.neuron_config.output_all_logits:
inputs = self.model.prepare_inputs_for_prefill(
batched_inps, batched_masks
)
multi_logits = F.log_softmax(
self.model.forward(**inputs).logits, dim=-1
) # [batch, padding_length (inp or cont), vocab]
else:
# The model will only return the logits for the last input token, so we need
# to iterate over inputs to accumulate logits.
# To speed things up we use the KV cache as we would do when generating.
inputs = self.model.prepare_inputs_for_prefill(
batched_inps[:, :1], batched_masks[:, :1]
)
outputs = [self.model.forward(**inputs).logits]
for i in range(1, padding_len_inp):
inputs = self.model.prepare_inputs_for_decode(
batched_inps[:, : i + 1], batched_masks[:, : i + 1]
)
outputs.append(self.model.forward(**inputs).logits)
multi_logits = F.log_softmax(torch.concat(outputs, dim=1), dim=-1)

for (cache_key, _, _), logits, inplen, cont_toks in zip(
chunk, multi_logits, inplens, cont_toks_list
Expand Down
26 changes: 0 additions & 26 deletions tests/models/test_neuron_optimum.py

This file was deleted.

0 comments on commit 9a092f3

Please sign in to comment.