fix: bump clients test base url to llama (huggingface#1751)
This PR bumps the client tests from `google/flan-t5-xxl` to
`meta-llama/Llama-2-7b-chat-hf` to resolve failures that occur when the
endpoint is called while `google/flan-t5-xxl` is unavailable.

Run with:
```bash
make python-client-tests

clients/python/tests/test_client.py ..............     [ 43%]
clients/python/tests/test_errors.py ..........         [ 75%]
clients/python/tests/test_inference_api.py ......      [ 93%]
clients/python/tests/test_types.py ..                  [100%]
```

**Note:** the `google/flan-t5-xxl` fixture is currently unused but is still included in `conftest.py`.
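
For orientation, here is a minimal sketch of the call these tests exercise against the new model. The base URL and token handling below are assumptions; in the test suite they come from the `base_url` and `hf_headers` fixtures in `conftest.py`:

```python
import os

from text_generation import Client

# Assumed endpoint layout "{base_url}/{model_id}", mirroring the
# llama_7b_url fixture added in conftest.py.
base_url = "https://api-inference.huggingface.co/models"  # assumption
model_id = "meta-llama/Llama-2-7b-chat-hf"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}  # assumption

client = Client(f"{base_url}/{model_id}", headers=headers)
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
print(response.generated_text)   # the updated tests expect "_"
print(response.details.prefill)  # now two tokens: <s> BOS plus the prompt token
```
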
drbh authored and kdamaszk committed Jun 3, 2024
1 parent 65748c7 commit fea0f2f
Showing 2 changed files with 45 additions and 32 deletions.
10 changes: 10 additions & 0 deletions clients/python/tests/conftest.py

```diff
@@ -9,6 +9,11 @@ def flan_t5_xxl():
     return "google/flan-t5-xxl"
 
 
+@pytest.fixture
+def llama_7b():
+    return "meta-llama/Llama-2-7b-chat-hf"
+
+
 @pytest.fixture
 def fake_model():
     return "fake/model"
@@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl):
     return f"{base_url}/{flan_t5_xxl}"
 
 
+@pytest.fixture
+def llama_7b_url(base_url, llama_7b):
+    return f"{base_url}/{llama_7b}"
+
+
 @pytest.fixture
 def fake_url(base_url, fake_model):
     return f"{base_url}/{fake_model}"
```
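
As a usage note, any test that requests the new fixture receives the fully qualified model URL. A hypothetical example (not part of this diff):

```python
# Hypothetical test illustrating how the fixtures compose; not part of the diff.
def test_llama_7b_url_shape(llama_7b_url):
    assert llama_7b_url.endswith("/meta-llama/Llama-2-7b-chat-hf")
```
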
67 changes: 35 additions & 32 deletions clients/python/tests/test_client.py

```diff
@@ -5,24 +5,24 @@
 from text_generation.types import FinishReason, InputToken
 
 
-def test_generate(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
 
-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.prefill) == 2
+    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
     assert len(response.details.tokens) == 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
+    assert response.details.tokens[0].id == 29918
+    assert response.details.tokens[0].text == "_"
     assert not response.details.tokens[0].special
 
 
-def test_generate_best_of(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_best_of(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     response = client.generate(
         "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
@@ -39,22 +39,22 @@ def test_generate_not_found(fake_url, hf_headers):
         client.generate("test")
 
 
-def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_validation_error(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         client.generate("test", max_new_tokens=10_000)
 
 
-def test_generate_stream(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_stream(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
 
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
         list(client.generate_stream("test"))
 
 
-def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_stream_validation_error(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         list(client.generate_stream("test", max_new_tokens=10_000))
 
 
 @pytest.mark.asyncio
-async def test_generate_async(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     response = await client.generate(
         "test", max_new_tokens=1, decoder_input_details=True
     )
 
-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.prefill) == 2
+    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
+    assert response.details.prefill[1] == InputToken(
+        id=1243, text="test", logprob=-10.96875
+    )
     assert len(response.details.tokens) == 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
+    assert response.details.tokens[0].id == 29918
+    assert response.details.tokens[0].text == "_"
     assert not response.details.tokens[0].special
 
 
 @pytest.mark.asyncio
-async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async_best_of(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     response = await client.generate(
         "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
@@ -112,23 +115,23 @@ async def test_generate_async_not_found(fake_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async_validation_error(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         await client.generate("test", max_new_tokens=10_000)
 
 
 @pytest.mark.asyncio
-async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_stream_async(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
 
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         async for _ in client.generate_stream("test", max_new_tokens=10_000):
             pass
```
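
The changed assertions reflect the Llama-2 tokenizer: it prepends a `<s>` BOS token (id 1), so the prefill now contains two tokens, and the first token generated for the prompt "test" is id 29918 (`_`). A quick local check of the prefill ids, assuming the `transformers` library and access to the gated model:

```python
from transformers import AutoTokenizer

# Gated repository: assumes access to meta-llama/Llama-2-7b-chat-hf is granted.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
ids = tok("test").input_ids
print(ids)  # expected [1, 1243]: the <s> BOS token plus the "test" token,
            # matching the updated prefill assertions
```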
