Skip to content

Commit

Permalink
Tags
Browse files Browse the repository at this point in the history
  • Loading branch information
WieslerTNG committed Feb 21, 2024
1 parent 94aae05 commit a44773f
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 10 deletions.
14 changes: 14 additions & 0 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class Client:
verify_ssl(bool, optional, default True)
Setting this to False will disable checking for SSL when doing requests.
tags(Optional[Sequence[str]], optional, default None)
Internal feature.
Example usage:
>>> request = CompletionRequest(
prompt=Prompt.from_text(f"Request"), maximum_tokens=64
Expand All @@ -163,6 +166,7 @@ def __init__(
total_retries: int = 8,
nice: bool = False,
verify_ssl=True,
tags: Optional[Sequence[str]] = None,
) -> None:
if host[-1] != "/":
host += "/"
Expand All @@ -171,6 +175,7 @@ def __init__(
self.request_timeout_seconds = request_timeout_seconds
self.token = token
self.nice = nice
self.tags = tags

retry_strategy = Retry(
total=total_retries,
Expand Down Expand Up @@ -242,6 +247,8 @@ def _build_json_body(
json_body["model"] = model
if self.hosting is not None:
json_body["hosting"] = self.hosting
if self.tags is not None:
json_body["tags"] = self.tags
return json_body

def models(self) -> List[Mapping[str, Any]]:
Expand Down Expand Up @@ -632,6 +639,9 @@ class AsyncClient:
verify_ssl(bool, optional, default True)
Setting this to False will disable checking for SSL when doing requests.
tags(Optional[Sequence[str]], optional, default None)
Internal feature.
Example usage:
>>> request = CompletionRequest(prompt=Prompt.from_text(f"Request"), maximum_tokens=64)
>>> async with AsyncClient(token=os.environ["AA_TOKEN"]) as client:
Expand All @@ -647,6 +657,7 @@ def __init__(
total_retries: int = 8,
nice: bool = False,
verify_ssl=True,
tags: Optional[Sequence[str]] = None,
) -> None:
if host[-1] != "/":
host += "/"
Expand All @@ -655,6 +666,7 @@ def __init__(
self.request_timeout_seconds = request_timeout_seconds
self.token = token
self.nice = nice
self.tags = tags

retry_options = ExponentialRetry(
attempts=total_retries + 1,
Expand Down Expand Up @@ -762,6 +774,8 @@ def _build_json_body(
json_body["model"] = model
if self.hosting is not None:
json_body["hosting"] = self.hosting
if self.tags is not None:
json_body["tags"] = self.tags
return json_body

async def models(self) -> List[Mapping[str, Any]]:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def version():
"Pillow >= 9.2.0",
"tqdm >= v4.62.0",
"python-liquid >= 1.9.4",
"packaging >= 23.2"
"packaging >= 23.2",
],
tests_require=tests_require,
extras_require={
Expand Down
62 changes: 59 additions & 3 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def test_can_use_async_client_without_context_manager(model_name: str):


def test_nice_flag_on_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("OK")
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)

httpserver.expect_request(
"/complete", query_string={"nice": "true"}
Expand All @@ -79,7 +79,7 @@ def test_nice_flag_on_client(httpserver: HTTPServer):


async def test_nice_flag_on_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("OK")
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)

httpserver.expect_request(
"/complete",
Expand All @@ -96,7 +96,63 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):
request = CompletionRequest(prompt=Prompt.from_text("Hello world"))

async with AsyncClient(
host=httpserver.url_for(""), token="AA_TOKEN", nice=True
host=httpserver.url_for(""),
token="AA_TOKEN",
nice=True,
request_timeout_seconds=1,
) as client:
await client.complete(request, model="luminous")


def test_tags_on_client(httpserver: HTTPServer):
    """Tags configured on a sync Client are sent in the /complete request body."""
    httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)

    completion_request = CompletionRequest(prompt=Prompt.from_text("Hello world"))

    # Expected JSON body: the request fields (with unset ones dropped) plus the
    # client-level tags and the model name.
    expected_body = {
        key: value
        for key, value in completion_request.to_json().items()
        if value is not None
    }
    expected_body["tags"] = ["tim-tagger"]
    expected_body["model"] = "luminous"

    canned_response = CompletionResponse(
        "model_version",
        [CompletionResult(log_probs=[], completion="foo")],
        num_tokens_prompt_total=2,
        num_tokens_generated=1,
    )
    httpserver.expect_request("/complete", json=expected_body).respond_with_json(
        canned_response.to_json()
    )

    client = Client(
        host=httpserver.url_for(""),
        request_timeout_seconds=1,
        token="AA_TOKEN",
        tags=["tim-tagger"],
    )
    client.complete(completion_request, model="luminous")


async def test_tags_on_async_client(httpserver: HTTPServer):
    """Tags configured on an AsyncClient are sent in the /complete request body."""
    httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)

    completion_request = CompletionRequest(prompt=Prompt.from_text("Hello world"))

    # Expected JSON body: the request fields (with unset ones dropped) plus the
    # client-level tags and the model name.
    expected_body = {
        key: value
        for key, value in completion_request.to_json().items()
        if value is not None
    }
    expected_body["tags"] = ["tim-tagger"]
    expected_body["model"] = "luminous"

    canned_response = CompletionResponse(
        "model_version",
        [CompletionResult(log_probs=[], completion="foo")],
        num_tokens_prompt_total=2,
        num_tokens_generated=1,
    )
    httpserver.expect_request("/complete", json=expected_body).respond_with_json(
        canned_response.to_json()
    )

    async with AsyncClient(
        host=httpserver.url_for(""), token="AA_TOKEN", tags=["tim-tagger"]
    ) as client:
        await client.complete(completion_request, model="luminous")

Expand Down
24 changes: 19 additions & 5 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from aleph_alpha_client.embedding import (
BatchSemanticEmbeddingRequest,
SemanticEmbeddingRequest,
SemanticRepresentation, BatchSemanticEmbeddingResponse,
SemanticRepresentation,
BatchSemanticEmbeddingResponse,
)
from aleph_alpha_client.prompt import Prompt
from tests.common import (
Expand Down Expand Up @@ -61,15 +62,20 @@ async def test_batch_embed_semantic_with_async_client(
):
words = ["car", "elephant", "kitchen sink", "rubber", "sun"]
r = random.Random(4082)
prompts = list([Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)])
prompts = list(
[Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)]
)

request = BatchSemanticEmbeddingRequest(
prompts=prompts,
representation=SemanticRepresentation.Symmetric,
compress_to_size=128,
)
result = await async_client.batch_semantic_embed(
request=request, num_concurrent_requests=10, batch_size=batch_size, model="luminous-base"
request=request,
num_concurrent_requests=10,
batch_size=batch_size,
model="luminous-base",
)

# We have no control over the exact tokenizer used in the backend, so we cannot know the exact
Expand Down Expand Up @@ -127,7 +133,11 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ
}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
).respond_with_json(
BatchSemanticEmbeddingResponse(
model_version="1", embeddings=[], num_tokens_prompt_total=1
).to_json()
)
async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1)
await async_client.batch_semantic_embed(request, model=model_name)

Expand Down Expand Up @@ -226,6 +236,10 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer):
expected_body = {**request.to_json(), "model": model_name}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
).respond_with_json(
BatchSemanticEmbeddingResponse(
model_version="1", embeddings=[], num_tokens_prompt_total=1
).to_json()
)
sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1)
sync_client.batch_semantic_embed(request, model=model_name)
7 changes: 6 additions & 1 deletion tests/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ def expect_retryable_error(

def expect_valid_completion(httpserver: HTTPServer) -> None:
    """Queue one successful (empty) completion response on the /complete endpoint."""
    empty_completion = CompletionResponse(
        model_version="1",
        completions=[],
        num_tokens_prompt_total=0,
        num_tokens_generated=0,
    )
    httpserver.expect_ordered_request("/complete").respond_with_json(
        empty_completion.to_json()
    )


Expand Down

0 comments on commit a44773f

Please sign in to comment.