diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index fd4d260..7e15e30 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -1,4 +1,6 @@ import warnings + +from packaging import version from tokenizers import Tokenizer # type: ignore from types import TracebackType from typing import ( @@ -48,6 +50,7 @@ SemanticEmbeddingRequest, SemanticEmbeddingResponse, ) +from aleph_alpha_client.version import MIN_API_VERSION POOLING_OPTIONS = ["mean", "max", "last_token", "abs_max"] RETRY_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504}) @@ -80,6 +83,16 @@ def _raise_for_status(status_code: int, text: str): raise RuntimeError(status_code, text) +def _check_api_version(version_str: str): + api_ver = version.parse(MIN_API_VERSION) + ver = version.parse(version_str) + valid = api_ver.major == ver.major and api_ver <= ver + if not valid: + raise RuntimeError( + f"The aleph alpha client requires at least api version {api_ver}, found version {ver}" + ) + + AnyRequest = Union[ CompletionRequest, EmbeddingRequest, @@ -179,6 +192,10 @@ def __init__( self.session.mount("https://", adapter) self.session.mount("http://", adapter) + def validate_version(self) -> None: + """Gets version of the AlephAlpha HTTP API.""" + _check_api_version(self.get_version()) + def get_version(self) -> str: """Gets version of the AlephAlpha HTTP API.""" return self._get_request("version").text @@ -687,6 +704,9 @@ async def __aexit__( ): await self.session.__aexit__(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb) + async def validate_version(self) -> None: + _check_api_version(await self.get_version()) + async def get_version(self) -> str: """Gets version of the AlephAlpha HTTP API.""" return await self._get_request_text("version") diff --git a/aleph_alpha_client/version.py b/aleph_alpha_client/version.py index ba7be38..1e16b39 100644 --- a/aleph_alpha_client/version.py +++ b/aleph_alpha_client/version.py @@ -1 +1,2 @@ __version__ = "5.0.0" +MIN_API_VERSION = "1.15.0" diff --git a/setup.py b/setup.py index 274c374..386c79d 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ def version(): "Pillow >= 9.2.0", "tqdm >= v4.62.0", "python-liquid >= 1.9.4", + "packaging >= 23.2" ], tests_require=tests_require, extras_require={ diff --git a/tests/test_clients.py b/tests/test_clients.py index 6f454ba..d1b5f2e 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,6 +1,8 @@ from pytest_httpserver import HTTPServer import os import pytest + +from aleph_alpha_client.version import MIN_API_VERSION from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client from aleph_alpha_client.completion import ( CompletionRequest, @@ -11,6 +13,32 @@ from tests.common import model_name, sync_client, async_client +def test_api_version_mismatch_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data("0.0.0") + + with pytest.raises(RuntimeError): + Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version() + + +async def test_api_version_mismatch_async_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data("0.0.0") + + with pytest.raises(RuntimeError): + async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client: + await client.validate_version() + + +def test_api_version_correct_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) + Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version() + + +async def test_api_version_correct_async_client(httpserver: HTTPServer): + httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION) + async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client: + await client.validate_version() + + @pytest.mark.system_test async def test_can_use_async_client_without_context_manager(model_name: str): request = CompletionRequest(