-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <[email protected]> Signed-off-by: Max de Bayser <[email protected]> Signed-off-by: Flavia Beo <[email protected]> Co-authored-by: Flavia Beo <[email protected]>
- Loading branch information
1 parent
49628fe
commit 214efc2
Showing
28 changed files
with
1,370 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
"""Examples Python client Score for Cross Encoder Models | ||
""" | ||
|
||
import argparse | ||
import json | ||
import pprint | ||
|
||
import requests | ||
|
||
|
||
def post_http_request(prompt: json, api_url: str) -> requests.Response: | ||
headers = {"User-Agent": "Test Client"} | ||
response = requests.post(api_url, headers=headers, json=prompt) | ||
return response | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--host", type=str, default="localhost") | ||
parser.add_argument("--port", type=int, default=8000) | ||
parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3") | ||
args = parser.parse_args() | ||
api_url = f"http://{args.host}:{args.port}/v1/score" | ||
|
||
model_name = args.model | ||
|
||
text_1 = "What is the capital of France?" | ||
text_2 = [ | ||
"The capital of Brazil is Brasilia.", "The capital of France is Paris." | ||
] | ||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} | ||
score_response = post_http_request(prompt=prompt, api_url=api_url) | ||
print("Prompt for text_1 is string and text_2 is a list:") | ||
pprint.pprint(prompt) | ||
print("Score Response:") | ||
pprint.pprint(score_response.data) | ||
|
||
text_1 = [ | ||
"What is the capital of Brazil?", "What is the capital of France?" | ||
] | ||
text_2 = [ | ||
"The capital of Brazil is Brasilia.", "The capital of France is Paris." | ||
] | ||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} | ||
score_response = post_http_request(prompt=prompt, api_url=api_url) | ||
print("Prompt for text_1 and text_2 are lists:") | ||
pprint.pprint(prompt) | ||
print("Score Response:") | ||
pprint.pprint(score_response.data) | ||
|
||
text_1 = "What is the capital of Brazil?" | ||
text_2 = "The capital of Brazil is Brasilia." | ||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} | ||
score_response = post_http_request(prompt=prompt, api_url=api_url) | ||
print("Prompt for text_1 and text_2 are strings:") | ||
pprint.pprint(prompt) | ||
print("Score Response:") | ||
pprint.pprint(score_response.data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import pytest | ||
import requests | ||
|
||
from vllm.entrypoints.openai.protocol import ScoreResponse | ||
|
||
from ...utils import RemoteOpenAIServer | ||
|
||
MODEL_NAME = "BAAI/bge-reranker-v2-m3" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def server(): | ||
args = [ | ||
"--enforce-eager", | ||
] | ||
|
||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: | ||
yield remote_server | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("model_name", [MODEL_NAME]) | ||
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer, | ||
model_name: str): | ||
text_1 = "What is the capital of France?" | ||
text_2 = [ | ||
"The capital of Brazil is Brasilia.", "The capital of France is Paris." | ||
] | ||
|
||
score_response = requests.post(server.url_for("v1/score"), | ||
json={ | ||
"model": model_name, | ||
"text_1": text_1, | ||
"text_2": text_2, | ||
}) | ||
score_response.raise_for_status() | ||
score = ScoreResponse.model_validate(score_response.json()) | ||
|
||
assert score.id is not None | ||
assert score.data is not None | ||
assert len(score.data) == 2 | ||
assert score.data[0].score[0] <= 0.01 | ||
assert score.data[1].score[0] >= 0.9 | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("model_name", [MODEL_NAME]) | ||
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer, | ||
model_name: str): | ||
text_1 = [ | ||
"What is the capital of the United States?", | ||
"What is the capital of France?" | ||
] | ||
text_2 = [ | ||
"The capital of Brazil is Brasilia.", "The capital of France is Paris." | ||
] | ||
|
||
score_response = requests.post(server.url_for("v1/score"), | ||
json={ | ||
"model": model_name, | ||
"text_1": text_1, | ||
"text_2": text_2, | ||
}) | ||
score_response.raise_for_status() | ||
score = ScoreResponse.model_validate(score_response.json()) | ||
|
||
assert score.id is not None | ||
assert score.data is not None | ||
assert len(score.data) == 2 | ||
assert score.data[0].score[0] <= 0.01 | ||
assert score.data[1].score[0] >= 0.9 | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("model_name", [MODEL_NAME]) | ||
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer, | ||
model_name: str): | ||
text_1 = "What is the capital of France?" | ||
text_2 = "The capital of France is Paris." | ||
|
||
score_response = requests.post(server.url_for("v1/score"), | ||
json={ | ||
"model": model_name, | ||
"text_1": text_1, | ||
"text_2": text_2, | ||
}) | ||
score_response.raise_for_status() | ||
score = ScoreResponse.model_validate(score_response.json()) | ||
|
||
assert score.id is not None | ||
assert score.data is not None | ||
assert len(score.data) == 1 | ||
assert score.data[0].score[0] >= 0.9 |
Oops, something went wrong.