diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index c36b42e07d7f7..a98d92c09ab45 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -8,6 +8,7 @@ Feature: llama.cpp server
     And 42 as server seed
     And 32 KV cache size
     And 1 slots
+    And embeddings extraction
     And 32 server max tokens to predict
     Then the server is starting
     Then the server is healthy
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 3bfbd6824550a..ac1bbd6bebb65 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -4,6 +4,7 @@
 import re
 import socket
 import subprocess
+import time
 from contextlib import closing
 from re import RegexFlag
 
@@ -21,13 +22,14 @@ def step_server_config(context, server_fqdn, server_port):
 
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
 
-    context.server_continuous_batching = False
     context.model_alias = None
     context.n_ctx = None
     context.n_predict = None
     context.n_server_predict = None
     context.n_slots = None
     context.server_api_key = None
+    context.server_continuous_batching = False
+    context.server_embeddings = False
     context.server_seed = None
     context.user_api_key = None
 
@@ -70,15 +72,26 @@ def step_server_n_predict(context, n_predict):
 def step_server_continuous_batching(context):
     context.server_continuous_batching = True
 
+@step(u'embeddings extraction')
+def step_server_embeddings(context):
+    context.server_embeddings = True
+
 
 @step(u"the server is starting")
 def step_start_server(context):
     start_server_background(context)
+    attempts = 0
     while True:
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
             result = sock.connect_ex((context.server_fqdn, context.server_port))
             if result == 0:
+                print("server started!")
                 return
+            attempts += 1
+            if attempts > 20:
+                assert False, "server not started"
+            print("waiting for server to start...")
+            time.sleep(0.1)
 
 
 @step(u"the server is {expecting_status}")
@@ -301,6 +314,11 @@ def step_compute_embedding(context):
 @step(u'embeddings are generated')
 def step_compute_embeddings(context):
     assert len(context.embeddings) > 0
+    embeddings_computed = False
+    for emb in context.embeddings:
+        if emb != 0:
+            embeddings_computed = True
+    assert embeddings_computed, f"Embeddings: {context.embeddings}"
 
 
 @step(u'an OAI compatible embeddings computation request for')
@@ -436,7 +454,8 @@ async def oai_chat_completions(user_prompt,
                                     json=payload,
                                     headers=headers) as response:
                 if enable_streaming:
-                    print("payload", payload)
+                    # FIXME: does not work; the server is generating only one token
+                    print("DEBUG payload", payload)
                     assert response.status == 200
                     assert response.headers['Access-Control-Allow-Origin'] == origin
                     assert response.headers['Content-Type'] == "text/event-stream"
@@ -453,7 +472,7 @@ async def oai_chat_completions(user_prompt,
                                 if 'content' in delta:
                                     completion_response['content'] += delta['content']
                                     completion_response['timings']['predicted_n'] += 1
-                    print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
+                    print(f"DEBUG completion_response: {completion_response}")
                 else:
                     if expect_api_error is None or not expect_api_error:
                         assert response.status == 200
@@ -500,7 +519,7 @@ async def oai_chat_completions(user_prompt,
                 'predicted_n': chat_completion.usage.completion_tokens
             }
         }
-    print("OAI response formatted to llama.cpp", completion_response)
+    print("OAI response formatted to llama.cpp:", completion_response)
     return completion_response
 
 
@@ -567,7 +586,7 @@ async def wait_for_health_status(context,
                 # Sometimes health requests are triggered after completions are predicted
                 if expected_http_status_code == 503:
                     if len(context.completions) == 0:
-                        print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks,"
+                        print("\x1b[33;42mWARNING: forcing concurrents completions tasks,"
                               " busy health check missed\x1b[0m")
                         n_completions = await gather_concurrent_completions_tasks(context)
                         if n_completions > 0:
@@ -604,6 +623,8 @@ def start_server_background(context):
     ]
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
+    if context.server_embeddings:
+        server_args.append('--embedding')
     if context.model_alias is not None:
         server_args.extend(['--alias', context.model_alias])
     if context.server_seed is not None:
@@ -620,3 +641,4 @@ def start_server_background(context):
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
         close_fds=True)
+    print(f"server pid={context.server_process.pid}")