From 5fef4b99fff719433c69a85a16e6ee0851e76d43 Mon Sep 17 00:00:00 2001
From: "Xida Ren (Cedar)"
Date: Wed, 23 Oct 2024 19:56:01 -0700
Subject: [PATCH] Add simple cache unit test and move server test to proper location (#317)

Also reorganizes the server test to exercise multiple batch sizes.
---
 .../ci_linux_x64_nogil-libshortfin.yml        |  2 +-
 .../tests/apps/llm/components/cache_test.py   | 94 +++++++++++++++++++
 .../apps}/llm/test_llm_server.py              | 51 +++++++++-
 3 files changed, 143 insertions(+), 4 deletions(-)
 create mode 100644 shortfin/tests/apps/llm/components/cache_test.py
 rename shortfin/{python/shortfin_apps => tests/apps}/llm/test_llm_server.py (73%)

diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
index 84b266fc7..ee9470fc0 100644
--- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -98,6 +98,6 @@ jobs:
       - name: Run shortfin Python tests (full)
         working-directory: ${{ env.LIBSHORTFIN_DIR }}
         run: |
-          pytest -s --ignore=tests/examples/fastapi_test.py
+          pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py
         # TODO: Enable further tests and switch to
         # pytest -s
diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py
new file mode 100644
index 000000000..169d082b1
--- /dev/null
+++ b/shortfin/tests/apps/llm/components/cache_test.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""
+Tests for the LLM KV cache component.
+"""
+
+import pytest
+import time
+import tempfile
+import shortfin as sf
+from _shortfin import lib as sfl
+from shortfin_apps.llm.components import cache
+from shortfin_apps.llm.components import config_struct
+import json
+from pathlib import Path
+
+
+@pytest.fixture
+def lsys():
+    sc = sfl.local.host.CPUSystemBuilder()
+    ls = sc.create_system()
+    yield ls
+    ls.shutdown()
+
+
+@pytest.fixture
+def fiber(lsys):
+    # TODO: Should adopt the main thread.
+    worker = lsys.create_worker("main")
+    return lsys.create_fiber(worker)
+
+
+@pytest.fixture
+def device(fiber):
+    return fiber.device(0)
+
+
+@pytest.fixture
+def model_params():
+    model_params = {
+        "module_name": "module",
+        "module_abi_version": 1,
+        "max_seq_len": 2048,
+        "attn_head_count": 32,
+        "attn_head_dim": 100,
+        "prefill_batch_sizes": [4],
+        "decode_batch_sizes": [4],
+        "transformer_block_count": 26,
+        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+    }
+
+    # Create a temporary file to store the JSON
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as tmp_file:
+        json.dump(model_params, tmp_file, indent=4)
+        tmp_path = Path(tmp_file.name)
+
+    try:
+        # Load the JSON using config_struct
+        model_params = config_struct.ModelParams.load_json(tmp_path)
+        yield model_params
+    finally:
+        tmp_path.unlink()
+
+
+@pytest.fixture
+def cache_fixture(fiber, model_params) -> cache.AttnPageCache:
+    # Create and return the cache object
+    return cache.AttnPageCache(
+        devices=fiber.devices_dict.values(), model_params=model_params
+    )
+
+
+@pytest.mark.parametrize("n_allocated", [1, 16, 255])
+def test_alloc(
+    cache_fixture: cache.AttnPageCache,
+    n_allocated,
+    model_params: config_struct.ModelParams,
+):
+    alloc_page_count = cache_fixture.page_tables[0].shape[0]
+
+    assert alloc_page_count == model_params.paged_kv_cache.device_block_count
+
+    pages = cache_fixture.acquire_free_pages(n_allocated)
+    last_page = alloc_page_count - 1
+    expected_indices = range(last_page, last_page - n_allocated, -1)
+    for p, expected_ix in zip(pages, expected_indices):
+        assert p.index == expected_ix
+        assert p.index > 0
diff --git a/shortfin/python/shortfin_apps/llm/test_llm_server.py b/shortfin/tests/apps/llm/test_llm_server.py
similarity index 73%
rename from shortfin/python/shortfin_apps/llm/test_llm_server.py
rename to shortfin/tests/apps/llm/test_llm_server.py
index 51474333f..e87b2797f 100644
--- a/shortfin/python/shortfin_apps/llm/test_llm_server.py
+++ b/shortfin/tests/apps/llm/test_llm_server.py
@@ -4,6 +4,9 @@
 import requests
 import os
 import json
+import uuid
+
+BATCH_SIZES = [1, 4]
 
 
 @pytest.fixture(scope="module")
@@ -34,6 +37,7 @@ def setup_environment():
     mlir_path = "/tmp/sharktank/llama/model.mlir"
     config_path = "/tmp/sharktank/llama/config.json"
     if not os.path.exists(mlir_path) or not os.path.exists(config_path):
+        bs_string = ",".join(map(str, BATCH_SIZES))
         subprocess.run(
             [
                 "python",
@@ -42,6 +46,7 @@
                 f"--gguf-file={model_path}",
                 f"--output-mlir={mlir_path}",
                 f"--output-config={config_path}",
+                f"--bs={bs_string}",
             ],
             check=True,
         )
@@ -70,8 +75,8 @@
         "max_seq_len": 2048,
         "attn_head_count": 32,
         "attn_head_dim": 100,
-        "prefill_batch_sizes": [4],
-        "decode_batch_sizes": [4],
+        "prefill_batch_sizes": BATCH_SIZES,
+        "decode_batch_sizes": BATCH_SIZES,
         "transformer_block_count": 26,
         "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
     }
@@ -96,7 +101,7 @@ def llm_server(setup_environment):
     )
 
     # Wait for server to start
-    time.sleep(5)
+    time.sleep(2)
 
     yield server_process
 
@@ -105,6 +110,43 @@ def llm_server(setup_environment):
     server_process.wait()
 
 
+def do_generate(prompt):
+
+    headers = {"Content-Type": "application/json"}
+    # Create a GenerateReqInput-like structure
+    data = {
+        "text": prompt,
+        "sampling_params": {"max_tokens": 50, "temperature": 0.7},
+        "rid": uuid.uuid4().hex,
+        "return_logprob": False,
+        "logprob_start_len": -1,
+        "top_logprobs_num": 0,
"return_text_in_logprobs": False, + "stream": False, + } + + print("Prompt text:") + print(data["text"]) + + BASE_URL = "http://localhost:8000" + + response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data) + print(f"Generate endpoint status code: {response.status_code}") + + if response.status_code == 200: + print("Generated text:") + data = response.text + assert data.startswith("data: ") + data = data[6:] + assert data.endswith("\n\n") + data = data[:-2] + + return data + else: + response.raise_for_status() + + +@pytest.mark.system("amdgpu") def test_llm_server(llm_server): # Here you would typically make requests to your server # and assert on the responses @@ -115,3 +157,6 @@ def test_llm_server(llm_server): # For now, we'll just check if the server process is running assert llm_server.poll() is None + output = do_generate("1 2 3 4 5 ") + print(output) + assert output.startswith("6 7 8")