diff --git a/README.md b/README.md
index f1d2911..01c0f60 100644
--- a/README.md
+++ b/README.md
@@ -19,8 +19,8 @@ import os
 from aleph_alpha_client import Client, CompletionRequest, Prompt
 
 client = Client(
-    token=os.getenv("AA_TOKEN"),
-    host="https://inference-api.your-domain.com",
+    token=os.environ["TEST_TOKEN"],
+    host=os.environ["TEST_API_URL"],
 )
 request = CompletionRequest(
     prompt=Prompt.from_text("Provide a short description of AI:"),
@@ -39,8 +39,8 @@ from aleph_alpha_client import AsyncClient, CompletionRequest, Prompt
 
 # Can enter context manager within an async function
 async with AsyncClient(
-    token=os.environ["AA_TOKEN"]
-    host="https://inference-api.your-domain.com",
+    token=os.environ["TEST_TOKEN"],
+    host=os.environ["TEST_API_URL"],
 ) as client:
     request = CompletionRequest(
         prompt=Prompt.from_text("Provide a short description of AI:"),
diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index 20b0d78..b375d05 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -165,8 +165,8 @@ class Client:
         Example usage:
             >>> request = CompletionRequest(prompt=Prompt.from_text(f"Request"), maximum_tokens=64)
             >>> client = Client(
-                    token=os.environ["AA_TOKEN"],
-                    host="https://inference-api.your-domain.com",
+                    token=os.environ["TEST_TOKEN"],
+                    host=os.environ["TEST_API_URL"],
                 )
             >>> response: CompletionResponse = client.complete(request, "pharia-1-llm-7b-control")
     """
@@ -743,8 +743,8 @@ class AsyncClient:
         Example usage:
             >>> request = CompletionRequest(prompt=Prompt.from_text(f"Request"), maximum_tokens=64)
             >>> async with AsyncClient(
-                    token=os.environ["AA_TOKEN"],
-                    host="https://inference-api.your-domain.com"
+                    token=os.environ["TEST_TOKEN"],
+                    host=os.environ["TEST_API_URL"],
                 ) as client:
                     response: CompletionResponse = await client.complete(request, "pharia-1-llm-7b-control")
     """
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8339ef5..16a6d57 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -18,7 +18,7 @@ Synchronous client.
     from aleph_alpha_client import Client, CompletionRequest, Prompt
     import os
 
-    client = Client(token=os.getenv("AA_TOKEN"), host="https://inference-api.your-domain.com")
+    client = Client(token=os.environ["TEST_TOKEN"], host=os.environ["TEST_API_URL"])
     prompt = Prompt.from_text("Provide a short description of AI:")
     request = CompletionRequest(prompt=prompt, maximum_tokens=20)
     result = client.complete(request, model="luminous-extended")
@@ -32,7 +32,7 @@ Synchronous client with prompt containing an image.
     from aleph_alpha_client import Client, CompletionRequest, PromptTemplate, Image
     import os
 
-    client = Client(token=os.getenv("AA_TOKEN"), host="https://inference-api.your-domain.com")
+    client = Client(token=os.environ["TEST_TOKEN"], host=os.environ["TEST_API_URL"])
     image = Image.from_file("path-to-an-image")
     prompt_template = PromptTemplate("{{image}}This picture shows ")
     prompt = prompt_template.to_prompt(image=prompt_template.placeholder(image))
@@ -50,7 +50,7 @@ Asynchronous client.
     from aleph_alpha_client import AsyncClient, CompletionRequest, Prompt
 
     # Can enter context manager within an async function
-    async with AsyncClient(token=os.environ["AA_TOKEN"], host="https://inference-api.your-domain.com") as client:
+    async with AsyncClient(token=os.environ["TEST_TOKEN"], host=os.environ["TEST_API_URL"]) as client:
         request = CompletionRequest(
             prompt=Prompt.from_text("Request"),
             maximum_tokens=64,
diff --git a/tests/test_clients.py b/tests/test_clients.py
index 32dd135..3ea7704 100644
--- a/tests/test_clients.py
+++ b/tests/test_clients.py
@@ -16,25 +16,27 @@ def test_api_version_mismatch_client(httpserver: HTTPServer):
     httpserver.expect_request("/version").respond_with_data("0.0.0")
 
     with pytest.raises(RuntimeError):
-        Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()
+        Client(host=httpserver.url_for(""), token="TEST_TOKEN").validate_version()
 
 
 async def test_api_version_mismatch_async_client(httpserver: HTTPServer):
     httpserver.expect_request("/version").respond_with_data("0.0.0")
 
     with pytest.raises(RuntimeError):
-        async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
+        async with AsyncClient(
+            host=httpserver.url_for(""), token="TEST_TOKEN"
+        ) as client:
             await client.validate_version()
 
 
 def test_api_version_correct_client(httpserver: HTTPServer):
     httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
 
-    Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()
+    Client(host=httpserver.url_for(""), token="TEST_TOKEN").validate_version()
 
 
 async def test_api_version_correct_async_client(httpserver: HTTPServer):
     httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
 
-    async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
+    async with AsyncClient(host=httpserver.url_for(""), token="TEST_TOKEN") as client:
         await client.validate_version()
@@ -71,7 +73,7 @@ def test_nice_flag_on_client(httpserver: HTTPServer):
         ).to_json()
     )
 
-    client = Client(host=httpserver.url_for(""), token="AA_TOKEN", nice=True)
+    client = Client(host=httpserver.url_for(""), token="TEST_TOKEN", nice=True)
 
     request = CompletionRequest(prompt=Prompt.from_text("Hello world"))
     client.complete(request, model="luminous")
@@ -96,7 +98,7 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):
 
     async with AsyncClient(
         host=httpserver.url_for(""),
-        token="AA_TOKEN",
+        token="TEST_TOKEN",
         nice=True,
         request_timeout_seconds=1,
     ) as client:
@@ -127,7 +129,7 @@ def test_tags_on_client(httpserver: HTTPServer):
 
     client = Client(
         host=httpserver.url_for(""),
         request_timeout_seconds=1,
-        token="AA_TOKEN",
+        token="TEST_TOKEN",
         tags=["tim-tagger"],
     )
@@ -151,7 +153,7 @@ async def test_tags_on_async_client(httpserver: HTTPServer):
     )
 
     async with AsyncClient(
-        host=httpserver.url_for(""), token="AA_TOKEN", tags=["tim-tagger"]
+        host=httpserver.url_for(""), token="TEST_TOKEN", tags=["tim-tagger"]
     ) as client:
         await client.complete(request, model="luminous")

diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py
index 73a197e..949263c 100644
--- a/tests/test_error_handling.py
+++ b/tests/test_error_handling.py
@@ -29,7 +29,7 @@ def test_translate_errors():
 def test_retry_sync(httpserver: HTTPServer):
     num_retries = 2
     client = Client(
-        token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
+        token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
     )
     expect_retryable_error(httpserver, num_calls_expected=num_retries)
     expect_valid_version(httpserver)
@@ -40,7 +40,7 @@ def test_retry_sync(httpserver: HTTPServer):
 def test_retry_sync_post(httpserver: HTTPServer):
     num_retries = 2
     client = Client(
-        host=httpserver.url_for(""), token="AA_TOKEN", total_retries=num_retries
+        host=httpserver.url_for(""), token="TEST_TOKEN", total_retries=num_retries
     )
     expect_retryable_error(httpserver, num_calls_expected=num_retries)
     expect_valid_completion(httpserver)
@@ -52,7 +52,7 @@ def test_retry_sync_post(httpserver: HTTPServer):
 def test_exhaust_retries_sync(httpserver: HTTPServer):
     num_retries = 1
     client = Client(
-        token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
+        token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
     )
     expect_retryable_error(
         httpserver,
@@ -69,7 +69,7 @@ async def test_retry_async(httpserver: HTTPServer):
     expect_valid_version(httpserver)
 
     async with AsyncClient(
-        token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
+        token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
    ) as client:
         await client.get_version()
@@ -80,7 +80,7 @@ async def test_retry_async_post(httpserver: HTTPServer):
     expect_valid_completion(httpserver)
 
     async with AsyncClient(
-        token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
+        token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
     ) as client:
         request = CompletionRequest(prompt=Prompt.from_text(""), maximum_tokens=7)
         await client.complete(request, model="FOO")
@@ -95,7 +95,7 @@ async def test_exhaust_retries_async(httpserver: HTTPServer):
     )
     with pytest.raises(BusyError):
         async with AsyncClient(
-            token="AA_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
+            token="TEST_TOKEN", host=httpserver.url_for(""), total_retries=num_retries
         ) as client:
             await client.get_version()