Add simple cache unit test and move server test to proper location (#317)

Also reorganizes the server test to exercise multiple batch sizes.
renxida authored Oct 24, 2024
1 parent 3b5dc8a commit 5fef4b9
Showing 3 changed files with 143 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -98,6 +98,6 @@ jobs:
- name: Run shortfin Python tests (full)
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- pytest -s --ignore=tests/examples/fastapi_test.py
+ pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py
# TODO: Enable further tests and switch to
# pytest -s
94 changes: 94 additions & 0 deletions shortfin/tests/apps/llm/components/cache_test.py
@@ -0,0 +1,94 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Tests for the LLM KV cache component.
"""

import pytest
import time
import tempfile
import shortfin as sf
from _shortfin import lib as sfl
from shortfin_apps.llm.components import cache
from shortfin_apps.llm.components import config_struct
import json
from pathlib import Path


@pytest.fixture
def lsys():
sc = sfl.local.host.CPUSystemBuilder()
ls = sc.create_system()
yield ls
ls.shutdown()


@pytest.fixture
def fiber(lsys):
# TODO: Should adopt the main thread.
worker = lsys.create_worker("main")
return lsys.create_fiber(worker)


@pytest.fixture
def device(fiber):
return fiber.device(0)


@pytest.fixture
def model_params():
model_params = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 2048,
"attn_head_count": 32,
"attn_head_dim": 100,
"prefill_batch_sizes": [4],
"decode_batch_sizes": [4],
"transformer_block_count": 26,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
}

# Create a temporary file to store the JSON
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
json.dump(model_params, tmp_file, indent=4)
tmp_path = Path(tmp_file.name)

try:
# Load the JSON using config_struct
model_params = config_struct.ModelParams.load_json(tmp_path)
yield model_params
finally:
        tmp_path.unlink()


@pytest.fixture
def cache_fixture(fiber, model_params) -> cache.AttnPageCache:
# Create and return the cache object
return cache.AttnPageCache(
devices=fiber.devices_dict.values(), model_params=model_params
)


@pytest.mark.parametrize("n_allocated", [1, 16, 255])
def test_alloc(
cache_fixture: cache.AttnPageCache,
n_allocated,
model_params: config_struct.ModelParams,
):
alloc_page_count = cache_fixture.page_tables[0].shape[0]

assert alloc_page_count == model_params.paged_kv_cache.device_block_count

pages = cache_fixture.acquire_free_pages(n_allocated)
last_page = alloc_page_count - 1
expected_indices = range(last_page, last_page - n_allocated, -1)
for p, expected_ix in zip(pages, expected_indices):
assert p.index == expected_ix
assert p.index > 0
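As a side note, the parametrized assertion above encodes the assumption that acquire_free_pages hands pages out from the top of a 256-page pool (the device_block_count in the fixture). A minimal standalone sketch of that index arithmetic, assuming those values (not part of the diff):

# Hypothetical standalone check of the expected-index arithmetic in test_alloc;
# 256 mirrors device_block_count from the model_params fixture above.
alloc_page_count = 256
for n_allocated in (1, 16, 255):
    last_page = alloc_page_count - 1
    expected = list(range(last_page, last_page - n_allocated, -1))
    assert len(expected) == n_allocated
    assert expected[0] == last_page
    assert expected[-1] == alloc_page_count - n_allocated
    assert all(ix > 0 for ix in expected)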
51 changes: 48 additions & 3 deletions (the relocated server test)
@@ -4,6 +4,9 @@
import requests
import os
import json
import uuid

BATCH_SIZES = [1, 4]


@pytest.fixture(scope="module")
@@ -34,6 +37,7 @@ def setup_environment():
mlir_path = "/tmp/sharktank/llama/model.mlir"
config_path = "/tmp/sharktank/llama/config.json"
if not os.path.exists(mlir_path) or not os.path.exists(config_path):
bs_string = ",".join(map(str, BATCH_SIZES))
subprocess.run(
[
"python",
@@ -42,6 +46,7 @@ def setup_environment():
f"--gguf-file={model_path}",
f"--output-mlir={mlir_path}",
f"--output-config={config_path}",
f"--bs={bs_string}",
],
check=True,
)
@@ -70,8 +75,8 @@ def setup_environment():
"max_seq_len": 2048,
"attn_head_count": 32,
"attn_head_dim": 100,
"prefill_batch_sizes": [4],
"decode_batch_sizes": [4],
"prefill_batch_sizes": BATCH_SIZES,
"decode_batch_sizes": BATCH_SIZES,
"transformer_block_count": 26,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
}
@@ -96,7 +101,7 @@ def llm_server(setup_environment):
)

# Wait for server to start
- time.sleep(5)
+ time.sleep(2)

yield server_process

@@ -105,6 +110,43 @@ def test_llm_server(llm_server):
server_process.wait()


def do_generate(prompt):

headers = {"Content-Type": "application/json"}
# Create a GenerateReqInput-like structure
data = {
"text": prompt,
"sampling_params": {"max_tokens": 50, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"return_logprob": False,
"logprob_start_len": -1,
"top_logprobs_num": 0,
"return_text_in_logprobs": False,
"stream": False,
}

print("Prompt text:")
print(data["text"])

BASE_URL = "http://localhost:8000"

response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
print(f"Generate endpoint status code: {response.status_code}")

if response.status_code == 200:
print("Generated text:")
data = response.text
assert data.startswith("data: ")
data = data[6:]
assert data.endswith("\n\n")
data = data[:-2]

return data
else:
response.raise_for_status()


@pytest.mark.system("amdgpu")
def test_llm_server(llm_server):
# Here you would typically make requests to your server
# and assert on the responses
@@ -115,3 +157,6 @@ def test_llm_server(llm_server):

# For now, we'll just check if the server process is running
assert llm_server.poll() is None
output = do_generate("1 2 3 4 5 ")
print(output)
assert output.startswith("6 7 8")
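As an aside, do_generate above strips an SSE-style frame from the response body before the "6 7 8" check; a tiny sketch of that framing, with an assumed example payload (not captured from a real run):

# Assumed example body: the test expects the /generate response to be wrapped
# in a "data: ...\n\n" frame, which do_generate() strips off.
raw = "data: 6 7 8 9 10\n\n"
assert raw.startswith("data: ") and raw.endswith("\n\n")
text = raw[len("data: "):-len("\n\n")]
assert text == "6 7 8 9 10"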
