From a44773fbb1008cf0b4bd2aeb5b32589a96625f42 Mon Sep 17 00:00:00 2001 From: Julian Wiesler Date: Wed, 21 Feb 2024 10:33:44 +0100 Subject: [PATCH] Tags --- aleph_alpha_client/aleph_alpha_client.py | 14 ++++++ setup.py | 2 +- tests/test_clients.py | 62 ++++++++++++++++++++++-- tests/test_embed.py | 24 +++++++-- tests/test_error_handling.py | 7 ++- 5 files changed, 99 insertions(+), 10 deletions(-) diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 7e15e30..dd0520d 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -146,6 +146,9 @@ class Client: verify_ssl(bool, optional, default True) Setting this to False will disable checking for SSL when doing requests. + tags(Optional[Sequence[str]], optional, default None) + Internal feature. + Example usage: >>> request = CompletionRequest( prompt=Prompt.from_text(f"Request"), maximum_tokens=64 @@ -163,6 +166,7 @@ def __init__( total_retries: int = 8, nice: bool = False, verify_ssl=True, + tags: Optional[Sequence[str]] = None, ) -> None: if host[-1] != "/": host += "/" @@ -171,6 +175,7 @@ def __init__( self.request_timeout_seconds = request_timeout_seconds self.token = token self.nice = nice + self.tags = tags retry_strategy = Retry( total=total_retries, @@ -242,6 +247,8 @@ def _build_json_body( json_body["model"] = model if self.hosting is not None: json_body["hosting"] = self.hosting + if self.tags is not None: + json_body["tags"] = self.tags return json_body def models(self) -> List[Mapping[str, Any]]: @@ -632,6 +639,9 @@ class AsyncClient: verify_ssl(bool, optional, default True) Setting this to False will disable checking for SSL when doing requests. + tags(Optional[Sequence[str]], optional, default None) + Internal feature. + Example usage: >>> request = CompletionRequest(prompt=Prompt.from_text(f"Request"), maximum_tokens=64) >>> async with AsyncClient(token=os.environ["AA_TOKEN"]) as client: @@ -647,6 +657,7 @@ def __init__( total_retries: int = 8, nice: bool = False, verify_ssl=True, + tags: Optional[Sequence[str]] = None, ) -> None: if host[-1] != "/": host += "/" @@ -655,6 +666,7 @@ def __init__( self.request_timeout_seconds = request_timeout_seconds self.token = token self.nice = nice + self.tags = tags retry_options = ExponentialRetry( attempts=total_retries + 1, @@ -762,6 +774,8 @@ def _build_json_body( json_body["model"] = model if self.hosting is not None: json_body["hosting"] = self.hosting + if self.tags is not None: + json_body["tags"] = self.tags return json_body async def models(self) -> List[Mapping[str, Any]]: diff --git a/setup.py b/setup.py index 386c79d..5f1c847 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def version(): "Pillow >= 9.2.0", "tqdm >= v4.62.0", "python-liquid >= 1.9.4", - "packaging >= 23.2" + "packaging >= 23.2", ], tests_require=tests_require, extras_require={ diff --git a/tests/test_clients.py b/tests/test_clients.py index d1b5f2e..c22030b 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -54,7 +54,7 @@ async def test_can_use_async_client_without_context_manager(model_name: str): def test_nice_flag_on_client(httpserver: HTTPServer): - httpserver.expect_request("/version").respond_with_data("OK") + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) httpserver.expect_request( "/complete", query_string={"nice": "true"} @@ -79,7 +79,7 @@ def test_nice_flag_on_client(httpserver: HTTPServer): async def test_nice_flag_on_async_client(httpserver: HTTPServer): - httpserver.expect_request("/version").respond_with_data("OK") + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) httpserver.expect_request( "/complete", @@ -96,7 +96,63 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer): request = CompletionRequest(prompt=Prompt.from_text("Hello world")) async with AsyncClient( - host=httpserver.url_for(""), token="AA_TOKEN", nice=True + host=httpserver.url_for(""), + token="AA_TOKEN", + nice=True, + request_timeout_seconds=1, + ) as client: + await client.complete(request, model="luminous") + + +def test_tags_on_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) + + request = CompletionRequest(prompt=Prompt.from_text("Hello world")) + body = {k: v for k, v in request.to_json().items() if v is not None} + body["tags"] = ["tim-tagger"] + body["model"] = "luminous" + httpserver.expect_request("/complete", json=body).respond_with_json( + CompletionResponse( + "model_version", + [ + CompletionResult( + log_probs=[], + completion="foo", + ) + ], + num_tokens_prompt_total=2, + num_tokens_generated=1, + ).to_json() + ) + + client = Client( + host=httpserver.url_for(""), + request_timeout_seconds=1, + token="AA_TOKEN", + tags=["tim-tagger"], + ) + + client.complete(request, model="luminous") + + +async def test_tags_on_async_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) + + request = CompletionRequest(prompt=Prompt.from_text("Hello world")) + body = {k: v for k, v in request.to_json().items() if v is not None} + body["tags"] = ["tim-tagger"] + body["model"] = "luminous" + httpserver.expect_request("/complete", json=body).respond_with_json( + CompletionResponse( + "model_version", + [CompletionResult(log_probs=[], completion="foo")], + num_tokens_prompt_total=2, + num_tokens_generated=1, + ).to_json() + ) + + async with AsyncClient( + host=httpserver.url_for(""), token="AA_TOKEN", tags=["tim-tagger"] ) as client: await client.complete(request, model="luminous") diff --git a/tests/test_embed.py b/tests/test_embed.py index dc4751c..34b46f5 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -10,7 +10,8 @@ from aleph_alpha_client.embedding import ( BatchSemanticEmbeddingRequest, SemanticEmbeddingRequest, - SemanticRepresentation, BatchSemanticEmbeddingResponse, + SemanticRepresentation, + BatchSemanticEmbeddingResponse, ) from aleph_alpha_client.prompt import Prompt from tests.common import ( @@ -61,7 +62,9 @@ async def test_batch_embed_semantic_with_async_client( ): words = ["car", "elephant", "kitchen sink", "rubber", "sun"] r = random.Random(4082) - prompts = list([Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)]) + prompts = list( + [Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)] + ) request = BatchSemanticEmbeddingRequest( prompts=prompts, @@ -69,7 +72,10 @@ async def test_batch_embed_semantic_with_async_client( compress_to_size=128, ) result = await async_client.batch_semantic_embed( - request=request, num_concurrent_requests=10, batch_size=batch_size, model="luminous-base" + request=request, + num_concurrent_requests=10, + batch_size=batch_size, + model="luminous-base", ) # We have no control over the exact tokenizer used in the backend, so we cannot know the exact @@ -127,7 +133,11 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ } httpserver.expect_ordered_request( "/batch_semantic_embed", method="POST", data=json.dumps(expected_body) - ).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json()) + ).respond_with_json( + BatchSemanticEmbeddingResponse( + model_version="1", embeddings=[], num_tokens_prompt_total=1 + ).to_json() + ) async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1) await async_client.batch_semantic_embed(request, model=model_name) @@ -226,6 +236,10 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer): expected_body = {**request.to_json(), "model": model_name} httpserver.expect_ordered_request( "/batch_semantic_embed", method="POST", data=json.dumps(expected_body) - ).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json()) + ).respond_with_json( + BatchSemanticEmbeddingResponse( + model_version="1", embeddings=[], num_tokens_prompt_total=1 + ).to_json() + ) sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1) sync_client.batch_semantic_embed(request, model=model_name) diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 4c38c61..73a197e 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -111,7 +111,12 @@ def expect_retryable_error( def expect_valid_completion(httpserver: HTTPServer) -> None: httpserver.expect_ordered_request("/complete").respond_with_json( - CompletionResponse(model_version="1", completions=[], num_tokens_prompt_total=0, num_tokens_generated=0).to_json() + CompletionResponse( + model_version="1", + completions=[], + num_tokens_prompt_total=0, + num_tokens_generated=0, + ).to_json() )