From 9eca3e8d807b4b6136f14cf15718a53ab358dc1c Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Tue, 3 Dec 2024 14:42:40 +0000 Subject: [PATCH] rerank 3.5 fixes --- src/cohere/aws_client.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index 79242d0fb..26ce11514 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -56,7 +56,7 @@ def __init__( timeout: typing.Optional[float] = None, service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]], ): - Client.__init__( + ClientV2.__init__( self, base_url="https://api.cohere.com", # this url is unused for BedrockClient environment=ClientEnvironment.PRODUCTION, @@ -217,6 +217,8 @@ def _event_hook(request: httpx.Request) -> None: headers = request.headers.copy() del headers["connection"] + + api_version = request.url.path.split("/")[-2] endpoint = request.url.path.split("/")[-1] body = json.loads(request.read()) model = body["model"] @@ -230,6 +232,9 @@ def _event_hook(request: httpx.Request) -> None: request.url = URL(url) request.headers["host"] = request.url.host + if endpoint == "rerank": + body["api_version"] = get_api_version(version=api_version) + if "stream" in body: del body["stream"] @@ -255,20 +260,6 @@ def _event_hook(request: httpx.Request) -> None: return _event_hook -def get_endpoint_from_url(url: str, - chat_model: typing.Optional[str] = None, - embed_model: typing.Optional[str] = None, - generate_model: typing.Optional[str] = None, - ) -> str: - if chat_model and chat_model in url: - return "chat" - if embed_model and embed_model in url: - return "embed" - if generate_model and generate_model in url: - return "generate" - raise ValueError(f"Unknown endpoint in url: {url}") - - def get_url( *, platform: str, @@ -283,3 +274,12 @@ def get_url( endpoint = "invocations" if not stream else "invocations-response-stream" return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}" return "" + + +def get_api_version(*, version: str): + int_version = { + "v1": 1, + "v2": 2, + } + + return int_version.get(version, 1) \ No newline at end of file