Skip to content

Commit

Permalink
rerank 3.5 fixes (#611)
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere authored Dec 3, 2024
1 parent 3bc93c0 commit 066bf3d
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
Client.__init__(
ClientV2.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
environment=ClientEnvironment.PRODUCTION,
Expand Down Expand Up @@ -217,6 +217,8 @@ def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
del headers["connection"]


api_version = request.url.path.split("/")[-2]
endpoint = request.url.path.split("/")[-1]
body = json.loads(request.read())
model = body["model"]
Expand All @@ -230,6 +232,9 @@ def _event_hook(request: httpx.Request) -> None:
request.url = URL(url)
request.headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)

if "stream" in body:
del body["stream"]

Expand All @@ -255,20 +260,6 @@ def _event_hook(request: httpx.Request) -> None:
return _event_hook


def get_endpoint_from_url(
    url: str,
    chat_model: typing.Optional[str] = None,
    embed_model: typing.Optional[str] = None,
    generate_model: typing.Optional[str] = None,
) -> str:
    """Infer the endpoint name from which configured model id occurs in ``url``.

    Candidates are tried in the order chat, embed, generate. The first model
    id that is set (truthy) and appears as a substring of ``url`` decides the
    result.

    Args:
        url: The request URL to inspect.
        chat_model: Model id associated with the "chat" endpoint, if any.
        embed_model: Model id associated with the "embed" endpoint, if any.
        generate_model: Model id associated with the "generate" endpoint, if any.

    Returns:
        One of ``"chat"``, ``"embed"`` or ``"generate"``.

    Raises:
        ValueError: If none of the provided model ids occurs in ``url``.
    """
    # Order matters: an earlier entry wins when several model ids match.
    candidates = (
        ("chat", chat_model),
        ("embed", embed_model),
        ("generate", generate_model),
    )
    for endpoint_name, model_id in candidates:
        # Unset model ids (None or empty string) can never match.
        if model_id and model_id in url:
            return endpoint_name
    raise ValueError(f"Unknown endpoint in url: {url}")


def get_url(
*,
platform: str,
Expand All @@ -283,3 +274,12 @@ def get_url(
endpoint = "invocations" if not stream else "invocations-response-stream"
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
return ""


def get_api_version(*, version: str) -> int:
    """Map an API version label to its integer form.

    Args:
        version: Version label such as ``"v1"`` or ``"v2"`` (keyword-only).

    Returns:
        ``1`` for ``"v1"``, ``2`` for ``"v2"``, and ``1`` for any other label.
    """
    int_version = {
        "v1": 1,
        "v2": 2,
    }

    # Unrecognised labels fall back to 1 rather than raising.
    return int_version.get(version, 1)

0 comments on commit 066bf3d

Please sign in to comment.