
Enable multiple LoRa adapters #2010

Merged: 41 commits from lora-internal into main, Jun 25, 2024
Conversation

@drbh drbh (Collaborator) commented Jun 4, 2024

This PR is a work in progress to add support for multiple LoRAs to be loaded at startup, so that a request can then use zero or one adapter by specifying the adapter id.

Example usage

Download the adapters without auto-merging:

text-generation-server download-weights predibase/dbpedia
text-generation-server download-weights predibase/customer_support

Start the server with multiple LoRA adapters:

LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia \
text-generation-launcher --model-id mistralai/Mistral-7B-v0.1

Send a request without an adapter id:

curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
      "inputs": "What are 3 unique words that describe you?",
      "parameters": {
        "max_new_tokens": 40
      }
    }'
{
  "generated_text": "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
}

With the first LoRA adapter specified:

curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
      "inputs": "What are 3 unique words that describe you?",
      "parameters": {
        "max_new_tokens": 40,
        "adapter_id": "predibase/customer_support"
      }
    }'
{
  "generated_text": "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3."
}

With the second LoRA adapter specified (predibase/dbpedia appears to be a topic-classification adapter, so it responds with a class label rather than free text):

curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
      "inputs": "You are given the title and the body of an article below. Please determine the type of the article.### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008.",
      "parameters": {
        "max_new_tokens": 40,
        "adapter_id": "predibase/dbpedia"
      }
    }'
{
  "generated_text": "8"
}

@flozi00 flozi00 (Contributor) commented Jun 4, 2024

Hey guys, since I switched to LoRAX and started contributing there a lot after the first license change, I'm happy to see this PR got opened.

I'd be glad if you're open to some questions and discussion about this.
What do you think about following the LoRAX style for the API so it could be a drop-in replacement, especially for the OpenAI endpoint?
Furthermore, are you open to kernel optimizations like Punica, used in LoRAX and vLLM, to minimize overhead and enable efficient batching of different adapters?

I'd be happy to contribute here too.

@drbh drbh (Collaborator, Author) commented Jun 5, 2024

hi @flozi00, thanks for the feedback! Can you share more about the LoRAX-style API? I see that in LoRAX you can specify the adapter via the model field in the chat endpoint; is that the feature you're referring to? Regarding kernel optimizations: yes, we are very interested and plan on using optimized kernels/Punica (planning on diving into this specific thing tomorrow/this week). Any suggestions/contributions/patches related to this are always appreciated!

@flozi00 flozi00 (Contributor) commented Jun 5, 2024

Yes, I mean the "adapter_id" inside "parameters" for the TGI API (as you did it, I see now), and the "model" field in the OpenAI API :)
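
For reference, a hedged sketch of what such an OpenAI-style request could look like against TGI's /v1/chat/completions endpoint, with the adapter id passed in the model field. The endpoint path and field names follow the OpenAI chat schema; whether the model field is actually routed to adapter selection is exactly what is being discussed here, so treat this as an illustration rather than documented behavior:

curl 127.0.0.1:3000/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
      "model": "predibase/customer_support",
      "messages": [
        {"role": "user", "content": "What are 3 unique words that describe you?"}
      ],
      "max_tokens": 40
    }'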

@drbh drbh (Collaborator, Author) commented Jun 6, 2024

Update:

This PR's implementation has been updated to align with the great work done by the LoRAX team. It tries to use the same layers where possible and only diverges to accommodate TGI's recent updates/improvements, and it limits LoRA loading to startup. The current changes allow weights to be loaded similarly to LoRAX; however, there are still generation issues to resolve, plus other refactors.

@flozi00 flozi00 (Contributor) commented Jun 6, 2024

Looks like you are successfully adopting the LoRAX code.
Please tell me if you need any help with this feature.

@drbh drbh (Collaborator, Author) commented Jun 6, 2024

@flozi00 generation with LoRAs is mostly stable; just focusing on the rebase and then refactors now. And thank you 🙂 a review once the PR is ready would be super helpful!

@drbh drbh force-pushed the lora-internal branch from 091f2dc to d103264 on June 6, 2024 19:52
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@tgaddair tgaddair (Contributor) commented Jun 7, 2024

Thanks for the shoutout in the docs! It's quite interesting to see things come full circle; maybe we should chat about merging our projects.

@drbh drbh marked this pull request as ready for review June 7, 2024 12:24
@drbh drbh (Collaborator, Author) commented Jun 7, 2024

Of course @tgaddair, thank you for the awesome work! That's an interesting idea, and we are always aiming to improve TGI. We appreciate any contributions/discussions about features that may be helpful to our users.

@drbh drbh changed the title from "Lora internal" to "Enable multiple LoRa adapters" on Jun 7, 2024
@flozi00 flozi00 (Contributor) commented Jun 7, 2024

> Thanks for the shoutout in the docs! It's quite interesting to see things come full circle; maybe we should chat about merging our projects.

I'd love to migrate to TGI again 👍 and of course I'll try to contribute here too @tgaddair

@NielsRogge NielsRogge mentioned this pull request Jun 10, 2024
@drbh drbh (Collaborator, Author) commented Jun 18, 2024

hi @xiadingZ, in this PR LoRA adapters are loaded from the HUGGINGFACE_HUB_CACHE directory. Once you've downloaded the LoRA locally, you can specify its id like LORA_ADAPTERS=predibase/customer_support, which will use the local LoRA model.

Once this initial LoRA work is merged, we'll follow up with other improvements, such as easier ways to specify the LoRA path.
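
To make that flow concrete, a hedged end-to-end sketch; the cache path /data/hub is an assumption chosen for the example, not something this PR prescribes:

# download both adapters into the hub cache
HUGGINGFACE_HUB_CACHE=/data/hub \
    text-generation-server download-weights predibase/customer_support
HUGGINGFACE_HUB_CACHE=/data/hub \
    text-generation-server download-weights predibase/dbpedia

# launch with both adapters available; a request then selects one via adapter_id
HUGGINGFACE_HUB_CACHE=/data/hub \
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia \
text-generation-launcher --model-id mistralai/Mistral-7B-v0.1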

@xiadingZ xiadingZ commented:

> [quoting drbh's reply above about loading adapters from HUGGINGFACE_HUB_CACHE]

Hi @drbh, I can try your method with a downloaded LoRA, but I have a LoRA adapter trained locally. It doesn't have a directory structure such as blobs, refs, snapshots from the Hugging Face hub. How can I place it in HUGGINGFACE_HUB_CACHE and load it?

I set HUGGINGFACE_HUB_CACHE to /root/lora_adapters and my directory structure is: [screenshot: 20240619-181354]
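
For context on the question above, an illustration of the standard Hugging Face hub cache layout that the blobs/refs/snapshots directories come from; the adapter file names are the usual PEFT outputs and the commit hash is a placeholder, so this is background on the layout rather than a maintainer-confirmed recipe for loading a locally trained adapter:

/root/lora_adapters/
`-- models--predibase--customer_support/
    |-- blobs/
    |-- refs/
    |   `-- main                      (text file containing the snapshot commit hash)
    `-- snapshots/
        `-- <commit-hash>/
            |-- adapter_config.json
            `-- adapter_model.safetensors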

Review threads (resolved) on:
server/Makefile-lorax-punica
server/text_generation_server/adapters/__init__.py
server/text_generation_server/adapters/config.py
server/text_generation_server/utils/adapter.py
server/text_generation_server/utils/peft.py
@danieldk danieldk (Member) commented:

Forgot to add: we probably want an integration test as well.

danieldk (Member) previously approved these changes Jun 25, 2024 and left a comment:

Thanks for all the changes! Looks ready to merge to me after the small nit that breaks CI is fixed.

@drbh drbh (Collaborator, Author) commented Jun 25, 2024

@danieldk thanks for the review! I've fixed the nits and CI passes. Going to go ahead and merge based on your last approval.

@drbh drbh merged commit 04e1af9 into main Jun 25, 2024
9 checks passed
@drbh drbh deleted the lora-internal branch June 25, 2024 18:46
@mhou7712 mhou7712 commented Jul 1, 2024

Hi @flozi00, is it possible for you to look at my issue #2143 and let me know if you have any suggestions? Thanks.

glegendre01 pushed a commit that referenced this pull request Jul 2, 2024
* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: perfer loraxs custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support if vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------

Co-authored-by: Derek <[email protected]>
yuanwu2017 pushed a commit to yuanwu2017/tgi-gaudi that referenced this pull request Sep 26, 2024 (same commit message as above)