
Adding support for a static shape generate #28075

Closed
alessandropalla opened this issue Dec 15, 2023 · 6 comments · Fixed by #27931

Comments

@alessandropalla (Contributor)

Feature request

Many AI inference accelerators (Intel NPU, IPU, TPU, etc.) require static shapes to get maximum performance. Static shapes allow the NN graph compiler to improve memory management, scheduling, and overall network performance.

However, in transformers the generate function uses dynamic shapes and increases the size of the input (and KV cache) at every successive step. I opened this issue to implement a way to do LLM generation inference with the transformers API while maintaining static shapes:

The trick is to use left padding and to shift the KV-cached values left during inference. By setting the position_ids correctly, inference remains correct. Attached is a GIF that hopefully explains how it works:

(GIF: static-shapes-kv-cache)

For the first inference you pad left and run as usual; it is important to set the attention_mask and position_ids accordingly. In the KV-cached steps you only need to pass the new token with the proper position_ids and attention_mask, while the cached values are shifted left. This works because in the MHA block the new keys and values are concatenated to the right of the cached ones, so left padding makes the new token's key and value tensors adjacent to the cached values.
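
To make the layout concrete, below is a minimal sketch of how the tensors look for a 3-token prompt with max_length = 6; the token ids are made up for illustration and are not produced by a real tokenizer:

import torch

# Toy layout for a 3-token prompt [t1, t2, t3] with max_length = 6 and pad token P.
# The values are illustrative only.
P, t1, t2, t3 = 0, 101, 102, 103

# First (prefill) step: everything is left-padded to the static length.
input_ids      = torch.tensor([[P, P, P, t1, t2, t3]])
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1]])
position_ids   = torch.tensor([[0, 0, 0, 0, 1, 2]])

# Subsequent steps (with KV cache): only the new token t4 is fed, with
# position_ids = [[3]]. The attention mask is rolled left with its last
# slot set to 1, and each cached key/value tensor is trimmed from length 6
# to 5 so that concatenating t4's key/value keeps the cache at max_length.
attention_mask = torch.roll(attention_mask, shifts=-1, dims=-1)
attention_mask[0, -1] = 1
# attention_mask is now [[0, 0, 1, 1, 1, 1]]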

Here is a snippet with a function that implements this. The code is not production ready, but it is a POC that shows how it is supposed to work both with and without KV caching:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cpu" # use the accelerator that you have or use "cpu"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
# Load model
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Utility function to shift a tensor left and insert a value in the last position
def lshift_insert(tensor, value):
    tensor = torch.roll(tensor, shifts=-1, dims=-1)
    tensor[0, -1] = value
    return tensor

# Generate function 
@torch.no_grad()
def generate_with_static_shape(model, input_ids, attention_mask=None, max_length=None, use_past=True, pad_token_id=None, **kwargs):

    # Get batch size and sequence length
    batch, seq_length = input_ids.shape

    if pad_token_id is None:
        raise RuntimeError("pad_token_id is not set and is needed for static shape generation")

    # Padding attention mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.int32).to(model.device)
    attention_mask_padding = torch.zeros((batch, max_length - seq_length), dtype=input_ids.dtype, device=input_ids.device)
    attention_mask = torch.cat((attention_mask_padding, attention_mask), dim=-1)

    # Padding input_ids with left padding
    padding_input_ids = pad_token_id * torch.ones((batch, max_length - seq_length), dtype=input_ids.dtype, device=input_ids.device)
    input_ids = torch.cat((padding_input_ids, input_ids), dim=-1).to(model.device)

    # Set the proper position ids
    position_ids = kwargs.get('position_ids', None)
    if position_ids is None:
        position_ids = torch.tensor([[0] * (max_length - seq_length) + list(range(seq_length))], dtype=torch.int32).to(model.device)
    else:
        raise RuntimeError("Cannot set position_ids with in static shape generation")

    # past_key_values for KV-cache
    past_key_values = None

    for idx in range(seq_length, max_length):

        # Run the inference
        out = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)

        # Greedy search as an example; in general this is where you select the next token with your decoding algorithm of choice
        logits = out.logits
        new_token = torch.argmax(logits[0, -1, :])

        yield new_token

        if not use_past:
            # Shift left input and position ids and set the new token and idx to the proper values
            input_ids = lshift_insert(input_ids, new_token)
            position_ids = lshift_insert(position_ids, idx)
        else:
            # Set input_ids and position_ids to their new value
            input_ids = torch.tensor([[new_token]], dtype=input_ids.dtype).to(model.device)
            position_ids = torch.tensor([[idx]], dtype=input_ids.dtype).to(model.device)

            # Drop the oldest cache entry so that, after the new token's key/value is concatenated, the cache length stays at max_length
            past_key_values = [[item[:, :, 1:, :] for item in layer_past] for layer_past in out.past_key_values]

        # Shift left attention mask and set the last value to one
        attention_mask = lshift_insert(attention_mask, 1)


prompt = "List all numbers in the Fibonacci sequence: 1, 1, 2, 3, "
max_length = 512

# Tokenize 
input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(device)

print(prompt, end="", flush=True)
results = generate_with_static_shape(model, input_ids=input_ids, max_length=max_length, use_past=True, pad_token_id=tokenizer.pad_token_id)
for new_token_id in results:

    token = tokenizer.decode([new_token_id], skip_special_tokens=True)
    # Not ideal: depending on the tokenizer this may or may not add spaces
    print(token , end="", flush=True)
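
As a quick sanity check (a sketch built on top of the POC above, not part of it), the static-shape loop can be compared against the stock greedy generate output; in principle the two token sequences should match up to the point where generate stops at an EOS token:

# Sanity-check sketch: compare against greedy `generate`.
# Assumes the model, tokenizer and input_ids defined above; 64 is an
# arbitrary small static length to keep the check quick.
static_ids = [int(t) for t in generate_with_static_shape(
    model, input_ids=input_ids, max_length=64,
    use_past=True, pad_token_id=tokenizer.pad_token_id)]

reference = model.generate(input_ids, max_new_tokens=len(static_ids),
                           do_sample=False, pad_token_id=tokenizer.pad_token_id)
reference_ids = reference[0, input_ids.shape[-1]:].tolist()

# The static loop does not stop at EOS, so compare only the common prefix
print("match:", static_ids[:len(reference_ids)] == reference_ids)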

Motivation

Enabling AI inference accelerators to be used with the generate API.

Your contribution

I'd be happy to help integrate the code into the transformers library. Let me know how I can help.

@amyeroberts (Collaborator)

cc @gante

@alessandropalla (Contributor, Author)

Any update on this ticket?

@oobabooga (Contributor)

There is an open PR: #27931

@ArthurZucker (Collaborator)

Thanks @oobabooga 🤗 and yes, this is my main focus; hoping to ship by the end of the week.

@alessandropalla (Contributor, Author)

Many thanks! Do you need help with the PR? (Development, testing, writing examples on how to run a model with static shapes on the NPU?)

@ArthurZucker (Collaborator)

I don't currently have access to an NPU, so feel free to test it. It's still in draft mode, so stay tuned for when it's ready for review!
