Commit 465a539: Tags

WieslerTNG committed Feb 21, 2024
1 parent 94aae05
Showing 2 changed files with 61 additions and 3 deletions.
aleph_alpha_client/aleph_alpha_client.py (8 additions, 0 deletions)
@@ -163,6 +163,7 @@ def __init__(
         total_retries: int = 8,
         nice: bool = False,
         verify_ssl=True,
+        tags: Optional[Sequence[str]] = None
     ) -> None:
         if host[-1] != "/":
             host += "/"
@@ -171,6 +172,7 @@
         self.request_timeout_seconds = request_timeout_seconds
         self.token = token
         self.nice = nice
+        self.tags = tags

         retry_strategy = Retry(
             total=total_retries,
@@ -242,6 +244,8 @@ def _build_json_body(
             json_body["model"] = model
         if self.hosting is not None:
             json_body["hosting"] = self.hosting
+        if self.tags is not None:
+            json_body["tags"] = self.tags
         return json_body

     def models(self) -> List[Mapping[str, Any]]:
@@ -647,6 +651,7 @@ def __init__(
         total_retries: int = 8,
         nice: bool = False,
         verify_ssl=True,
+        tags: Optional[Sequence[str]] = None
     ) -> None:
         if host[-1] != "/":
             host += "/"
@@ -655,6 +660,7 @@
         self.request_timeout_seconds = request_timeout_seconds
         self.token = token
         self.nice = nice
+        self.tags = tags

         retry_options = ExponentialRetry(
             attempts=total_retries + 1,
@@ -762,6 +768,8 @@ def _build_json_body(
             json_body["model"] = model
         if self.hosting is not None:
             json_body["hosting"] = self.hosting
+        if self.tags is not None:
+            json_body["tags"] = self.tags
         return json_body

     async def models(self) -> List[Mapping[str, Any]]:
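
With these additions, tags configured once on a client are copied into the JSON body of every request by _build_json_body. A minimal usage sketch for the synchronous client, assuming the package exports Client, CompletionRequest, and Prompt at the top level; the host, token, and tag values are placeholders, and the call pattern mirrors the tests below:

    from aleph_alpha_client import Client, CompletionRequest, Prompt

    # Tags set on the client are attached to every request payload
    # (json_body["tags"] = self.tags in _build_json_body).
    client = Client(
        host="https://example.invalid/",  # placeholder host
        token="AA_TOKEN",                 # placeholder token
        tags=["my-application"],          # the new per-client tags
    )

    request = CompletionRequest(prompt=Prompt.from_text("Hello world"))
    response = client.complete(request, model="luminous")
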
tests/test_clients.py (53 additions, 3 deletions)
@@ -54,7 +54,7 @@ async def test_can_use_async_client_without_context_manager(model_name: str):


 def test_nice_flag_on_client(httpserver: HTTPServer):
-    httpserver.expect_request("/version").respond_with_data("OK")
+    httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)

     httpserver.expect_request(
         "/complete", query_string={"nice": "true"}
@@ -79,7 +79,7 @@ def test_nice_flag_on_client(httpserver: HTTPServer):


 async def test_nice_flag_on_async_client(httpserver: HTTPServer):
-    httpserver.expect_request("/version").respond_with_data("OK")
+    httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)

     httpserver.expect_request(
         "/complete",
@@ -96,10 +96,60 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):
     request = CompletionRequest(prompt=Prompt.from_text("Hello world"))

     async with AsyncClient(
-        host=httpserver.url_for(""), token="AA_TOKEN", nice=True
+        host=httpserver.url_for(""), token="AA_TOKEN", nice=True, request_timeout_seconds=1
     ) as client:
         await client.complete(request, model="luminous")

+def test_tags_on_client(httpserver: HTTPServer):
+    httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
+
+    request = CompletionRequest(prompt=Prompt.from_text("Hello world"))
+    body = request.to_json()
+    body["tags"] = ["tim-tagger"]
+    body["model"] = "luminous"
+    httpserver.expect_request(
+        "/complete", json=body
+    ).respond_with_json(
+        CompletionResponse(
+            "model_version",
+            [
+                CompletionResult(
+                    log_probs=[],
+                    completion="foo",
+                )
+            ],
+            num_tokens_prompt_total=2,
+            num_tokens_generated=1,
+        ).to_json()
+    )
+
+    client = Client(host=httpserver.url_for(""), request_timeout_seconds=1, token="AA_TOKEN", tags=["tim-tagger"])
+
+    client.complete(request, model="luminous")
+
+
+async def test_tags_on_async_client(httpserver: HTTPServer):
+    httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
+
+    request = CompletionRequest(prompt=Prompt.from_text("Hello world"))
+    body = request.to_json()
+    body["tags"] = ["tim-tagger"]
+    body["model"] = "luminous"
+    httpserver.expect_request(
+        "/complete", json=body
+    ).respond_with_json(
+        CompletionResponse(
+            "model_version",
+            [CompletionResult(log_probs=[], completion="foo")],
+            num_tokens_prompt_total=2,
+            num_tokens_generated=1,
+        ).to_json()
+    )
+
+    async with AsyncClient(
+        host=httpserver.url_for(""), token="AA_TOKEN", tags=["tim-tagger"]
+    ) as client:
+        await client.complete(request, model="luminous")

 @pytest.mark.system_test
 def test_available_models_sync_client(sync_client: Client, model_name: str):
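
The asynchronous client follows the same pattern. A sketch mirroring test_tags_on_async_client above, again with placeholder host, token, and tag values and an assumed top-level import path:

    import asyncio

    from aleph_alpha_client import AsyncClient, CompletionRequest, Prompt

    async def main() -> None:
        request = CompletionRequest(prompt=Prompt.from_text("Hello world"))
        # tags passed here are merged into the JSON body of each request,
        # alongside "model" (and "hosting" when set).
        async with AsyncClient(
            host="https://example.invalid/",  # placeholder host
            token="AA_TOKEN",                 # placeholder token
            tags=["my-application"],
        ) as client:
            await client.complete(request, model="luminous")

    asyncio.run(main())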
