Enable more models to run inference with LoRA adapters (#3382)
Co-authored-by: Antoni Baum <[email protected]>
Showing 10 changed files with 402 additions and 45 deletions.
@@ -0,0 +1,108 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

from .conftest import cleanup

MODEL_PATH = "baichuan-inc/Baichuan-7B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm, lora_path: str, lora_id: int) -> List[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    # A falsy lora_id runs the unmodified base model.
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_baichuan_lora(baichuan_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]

    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]


@pytest.mark.skip("Requires multiple GPUs")
def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
    # Cannot check the device count here, as it would initialize
    # torch.cuda too early:
    # if torch.cuda.device_count() < 4:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=1,
                       trust_remote_code=True)
    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

    del llm_tp1
    cleanup()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=2,
                       trust_remote_code=True)
    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

    del llm_tp2
    cleanup()

    assert output_tp1 == output_tp2

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=4,
                       trust_remote_code=True)
    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

    del llm_tp4
    cleanup()

    assert output_tp1 == output_tp4
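The cleanup() helper imported from .conftest above is not part of this diff. It has to tear down the distributed state and free GPU memory so the next vllm.LLM in the tensor-parallel test can initialize. A minimal sketch of what such a helper plausibly looks like — the exact teardown calls are assumptions, not the repository's actual implementation:

# Hypothetical sketch of the conftest cleanup() used above; the real helper
# lives in the tests' conftest module and may differ.
import contextlib
import gc

import ray
import torch


def cleanup():
    # destroy_process_group() raises if no group was ever initialized,
    # which is the case for single-GPU runs.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()  # release cached allocator blocks to the driver
    ray.shutdown()            # tensor-parallel runs may have started Ray workers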
@@ -0,0 +1,57 @@
from typing import List

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "THUDM/chatglm3-6b"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm, lora_path: str, lora_id: int) -> List[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    # A falsy lora_id runs the unmodified base model.
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_chatglm3_lora(chatglm3_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age",
    ]

    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]
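The baichuan_lora_files and chatglm3_lora_files fixtures are likewise defined outside this diff. A plausible sketch, assuming both tests pull their text-to-SQL adapters from the Hugging Face Hub with huggingface_hub.snapshot_download — the repo ids shown are illustrative assumptions, not confirmed by the diff:

# Hypothetical fixtures for the LoRA adapter paths used in both test files;
# the real definitions live in the tests' conftest module.
import pytest
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def baichuan_lora_files():
    # Downloads once per session and returns the local directory path.
    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")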