diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 7e15e30..43afeb6 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -163,6 +163,7 @@ def __init__( total_retries: int = 8, nice: bool = False, verify_ssl=True, + tags: Optional[Sequence[str]] = None ) -> None: if host[-1] != "/": host += "/" @@ -171,6 +172,7 @@ def __init__( self.request_timeout_seconds = request_timeout_seconds self.token = token self.nice = nice + self.tags = tags retry_strategy = Retry( total=total_retries, @@ -242,6 +244,8 @@ def _build_json_body( json_body["model"] = model if self.hosting is not None: json_body["hosting"] = self.hosting + if self.tags is not None: + json_body["tags"] = self.tags return json_body def models(self) -> List[Mapping[str, Any]]: @@ -647,6 +651,7 @@ def __init__( total_retries: int = 8, nice: bool = False, verify_ssl=True, + tags: Optional[Sequence[str]] = None ) -> None: if host[-1] != "/": host += "/" @@ -655,6 +660,7 @@ def __init__( self.request_timeout_seconds = request_timeout_seconds self.token = token self.nice = nice + self.tags = tags retry_options = ExponentialRetry( attempts=total_retries + 1, @@ -762,6 +768,8 @@ def _build_json_body( json_body["model"] = model if self.hosting is not None: json_body["hosting"] = self.hosting + if self.tags is not None: + json_body["tags"] = self.tags return json_body async def models(self) -> List[Mapping[str, Any]]: diff --git a/tests/test_clients.py b/tests/test_clients.py index d1b5f2e..1e920af 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -54,7 +54,7 @@ async def test_can_use_async_client_without_context_manager(model_name: str): def test_nice_flag_on_client(httpserver: HTTPServer): - httpserver.expect_request("/version").respond_with_data("OK") + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) httpserver.expect_request( "/complete", query_string={"nice": "true"} @@ -79,7 +79,7 @@ def test_nice_flag_on_client(httpserver: HTTPServer): async def test_nice_flag_on_async_client(httpserver: HTTPServer): - httpserver.expect_request("/version").respond_with_data("OK") + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) httpserver.expect_request( "/complete", @@ -96,10 +96,60 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer): request = CompletionRequest(prompt=Prompt.from_text("Hello world")) async with AsyncClient( - host=httpserver.url_for(""), token="AA_TOKEN", nice=True + host=httpserver.url_for(""), token="AA_TOKEN", nice=True, request_timeout_seconds=1 ) as client: await client.complete(request, model="luminous") +def test_tags_on_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) + + request = CompletionRequest(prompt=Prompt.from_text("Hello world")) + body = request.to_json() + body["tags"] = ["tim-tagger"] + body["model"] = "luminous" + httpserver.expect_request( + "/complete", json=body + ).respond_with_json( + CompletionResponse( + "model_version", + [ + CompletionResult( + log_probs=[], + completion="foo", + ) + ], + num_tokens_prompt_total=2, + num_tokens_generated=1, + ).to_json() + ) + + client = Client(host=httpserver.url_for(""), request_timeout_seconds=1, token="AA_TOKEN", tags=["tim-tagger"]) + + client.complete(request, model="luminous") + + +async def test_tags_on_async_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) + + request = CompletionRequest(prompt=Prompt.from_text("Hello world")) + body = request.to_json() + body["tags"] = ["tim-tagger"] + body["model"] = "luminous" + httpserver.expect_request( + "/complete", json=body + ).respond_with_json( + CompletionResponse( + "model_version", + [CompletionResult(log_probs=[], completion="foo")], + num_tokens_prompt_total=2, + num_tokens_generated=1, + ).to_json() + ) + + async with AsyncClient( + host=httpserver.url_for(""), token="AA_TOKEN", tags=["tim-tagger"] + ) as client: + await client.complete(request, model="luminous") @pytest.mark.system_test def test_available_models_sync_client(sync_client: Client, model_name: str):