diff --git a/pyproject.toml b/pyproject.toml
index 199ba4ded..9f4308e69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
-version = "5.13.2"
+version = "5.12.0"
description = ""
readme = "README.md"
authors = []
@@ -36,15 +36,6 @@ boto3 = { version="^1.34.0", optional = true}
fastavro = "^1.9.4"
httpx = ">=0.21.2"
httpx-sse = "0.4.0"
-# Without specifying a version, the numpy and pandas will be 1.24.4 and 2.0.3 respectively
-numpy = [
- { version="~1.24.4", python = "<3.12", optional = true },
- { version="~1.26", python = ">=3.12", optional = true }
-]
-pandas = [
- { version="~2.0.3", python = "<3.13", optional = true },
- { version="~2.2.3", python = ">=3.13", optional = true }
-]
parameterized = "^0.9.0"
pydantic = ">= 1.9.2"
pydantic-core = "^2.18.2"
diff --git a/reference.md b/reference.md
index 9cf1c2f5e..c1b7d5a60 100644
--- a/reference.md
+++ b/reference.md
@@ -1793,7 +1793,7 @@ We recommend a maximum of 1,000 documents for optimal endpoint performance.
-
-**model:** `typing.Optional[str]` — The identifier of the model to use, eg `rerank-v3.5`.
+**model:** `typing.Optional[str]` — The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
@@ -2382,6 +2382,7 @@ response = client.v2.chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
+ stream=True,
)
for chunk in response:
yield chunk
@@ -2654,6 +2655,7 @@ client.v2.chat(
content="messages",
)
],
+ stream=False,
)
```
@@ -2971,7 +2973,7 @@ Available models and corresponding embedding dimensions:
**embedding_types:** `typing.Sequence[EmbeddingType]`
-Specifies the types of embeddings you want to get back. Can be one or more of the following types.
+Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -3084,7 +3086,15 @@ client.v2.rerank(
-
-**model:** `str` — The identifier of the model to use, eg `rerank-v3.5`.
+**model:** `str`
+
+The identifier of the model to use.
+
+Supported models:
+ - `rerank-english-v3.0`
+ - `rerank-multilingual-v3.0`
+ - `rerank-english-v2.0`
+ - `rerank-multilingual-v2.0`
diff --git a/src/cohere/__init__.py b/src/cohere/__init__.py
index 70321646d..3a126e4f7 100644
--- a/src/cohere/__init__.py
+++ b/src/cohere/__init__.py
@@ -245,7 +245,7 @@
)
from . import connectors, datasets, embed_jobs, finetuning, models, v2
from .aws_client import AwsClient
-from .bedrock_client import BedrockClient, BedrockClientV2
+from .bedrock_client import BedrockClient
from .client import AsyncClient, Client
from .client_v2 import AsyncClientV2, ClientV2
from .datasets import (
@@ -257,7 +257,7 @@
)
from .embed_jobs import CreateEmbedJobRequestTruncate
from .environment import ClientEnvironment
-from .sagemaker_client import SagemakerClient, SagemakerClientV2
+from .sagemaker_client import SagemakerClient
from .v2 import (
V2ChatRequestDocumentsItem,
V2ChatRequestSafetyMode,
@@ -287,7 +287,6 @@
"AwsClient",
"BadRequestError",
"BedrockClient",
- "BedrockClientV2",
"ChatCitation",
"ChatCitationGenerationEvent",
"ChatConnector",
@@ -461,7 +460,6 @@
"ResponseFormat",
"ResponseFormatV2",
"SagemakerClient",
- "SagemakerClientV2",
"SearchQueriesGenerationStreamedChatResponse",
"SearchResultsStreamedChatResponse",
"ServiceUnavailableError",
diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py
index 8aea9d15c..79242d0fb 100644
--- a/src/cohere/aws_client.py
+++ b/src/cohere/aws_client.py
@@ -56,7 +56,7 @@ def __init__(
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
- ClientV2.__init__(
+ Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
environment=ClientEnvironment.PRODUCTION,
@@ -196,13 +196,6 @@ def _hook(
return _hook
-def get_boto3_session(
- **kwargs: typing.Any,
-):
- non_none_args = {k: v for k, v in kwargs.items() if v is not None}
- return lazy_boto3().Session(**non_none_args)
-
-
def map_request_to_bedrock(
service: str,
@@ -211,22 +204,19 @@ def map_request_to_bedrock(
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
) -> EventHook:
- session = get_boto3_session(
+ session = lazy_boto3().Session(
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
)
- aws_region = session.region_name
credentials = session.get_credentials()
- signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)
+ signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)
def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
del headers["connection"]
-
- api_version = request.url.path.split("/")[-2]
endpoint = request.url.path.split("/")[-1]
body = json.loads(request.read())
model = body["model"]
@@ -240,9 +230,6 @@ def _event_hook(request: httpx.Request) -> None:
request.url = URL(url)
request.headers["host"] = request.url.host
- if endpoint == "rerank":
- body["api_version"] = get_api_version(version=api_version)
-
if "stream" in body:
del body["stream"]
@@ -268,6 +255,20 @@ def _event_hook(request: httpx.Request) -> None:
return _event_hook
+def get_endpoint_from_url(url: str,
+ chat_model: typing.Optional[str] = None,
+ embed_model: typing.Optional[str] = None,
+ generate_model: typing.Optional[str] = None,
+ ) -> str:
+ if chat_model and chat_model in url:
+ return "chat"
+ if embed_model and embed_model in url:
+ return "embed"
+ if generate_model and generate_model in url:
+ return "generate"
+ raise ValueError(f"Unknown endpoint in url: {url}")
+
+
def get_url(
*,
platform: str,
@@ -282,12 +283,3 @@ def get_url(
endpoint = "invocations" if not stream else "invocations-response-stream"
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
return ""
-
-
-def get_api_version(*, version: str):
- int_version = {
- "v1": 1,
- "v2": 2,
- }
-
- return int_version.get(version, 1)
\ No newline at end of file
diff --git a/src/cohere/base_client.py b/src/cohere/base_client.py
index 50a638de6..b18ccb2f3 100644
--- a/src/cohere/base_client.py
+++ b/src/cohere/base_client.py
@@ -2012,7 +2012,7 @@ def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
- The identifier of the model to use, eg `rerank-v3.5`.
+ The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
@@ -5047,7 +5047,7 @@ async def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
- The identifier of the model to use, eg `rerank-v3.5`.
+ The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
diff --git a/src/cohere/bedrock_client.py b/src/cohere/bedrock_client.py
index 16f47b031..0246a288b 100644
--- a/src/cohere/bedrock_client.py
+++ b/src/cohere/bedrock_client.py
@@ -25,8 +25,6 @@ def __init__(
timeout=timeout,
)
- def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
- raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")
class BedrockClientV2(AwsClientV2):
def __init__(
diff --git a/src/cohere/core/client_wrapper.py b/src/cohere/core/client_wrapper.py
index 9fb4430e7..8213e6da3 100644
--- a/src/cohere/core/client_wrapper.py
+++ b/src/cohere/core/client_wrapper.py
@@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]:
headers: typing.Dict[str, str] = {
"X-Fern-Language": "Python",
"X-Fern-SDK-Name": "cohere",
- "X-Fern-SDK-Version": "5.13.2",
+ "X-Fern-SDK-Version": "5.12.0",
}
if self._client_name is not None:
headers["X-Client-Name"] = self._client_name
diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py
index 77f9f9115..84d20a2ac 100644
--- a/src/cohere/sagemaker_client.py
+++ b/src/cohere/sagemaker_client.py
@@ -26,10 +26,7 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
- try:
- self.sagemaker_finetuning = Client(aws_region=aws_region)
- except Exception:
- pass
+ self.sagemaker_finetuning = Client(aws_region=aws_region)
class SagemakerClientV2(AwsClientV2):
@@ -53,7 +50,4 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
- try:
- self.sagemaker_finetuning = Client(aws_region=aws_region)
- except Exception:
- pass
\ No newline at end of file
+ self.sagemaker_finetuning = Client(aws_region=aws_region)
\ No newline at end of file
diff --git a/src/cohere/v2/client.py b/src/cohere/v2/client.py
index c62edd8e0..a4e465e79 100644
--- a/src/cohere/v2/client.py
+++ b/src/cohere/v2/client.py
@@ -222,6 +222,7 @@ def chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
+ stream=True,
)
for chunk in response:
yield chunk
@@ -534,6 +535,7 @@ def chat(
content="messages",
)
],
+ stream=False,
)
"""
_response = self._client_wrapper.httpx_client.request(
@@ -568,7 +570,6 @@ def chat(
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
- "stream": False,
},
request_options=request_options,
omit=OMIT,
@@ -736,7 +737,7 @@ def embed(
input_type : EmbedInputType
embedding_types : typing.Sequence[EmbeddingType]
- Specifies the types of embeddings you want to get back. Can be one or more of the following types.
+ Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -936,7 +937,13 @@ def rerank(
Parameters
----------
model : str
- The identifier of the model to use, eg `rerank-v3.5`.
+ The identifier of the model to use.
+
+ Supported models:
+ - `rerank-english-v3.0`
+ - `rerank-multilingual-v3.0`
+ - `rerank-english-v2.0`
+ - `rerank-multilingual-v2.0`
query : str
The search query
@@ -1301,6 +1308,7 @@ async def main() -> None:
p=1.1,
return_prompt=True,
logprobs=True,
+ stream=True,
)
async for chunk in response:
yield chunk
@@ -1340,7 +1348,6 @@ async def main() -> None:
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
- "stream": True,
},
request_options=request_options,
omit=OMIT,
@@ -1621,6 +1628,7 @@ async def main() -> None:
content="messages",
)
],
+ stream=False,
)
@@ -1658,7 +1666,6 @@ async def main() -> None:
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
- "stream": False,
},
request_options=request_options,
omit=OMIT,
@@ -1826,7 +1833,7 @@ async def embed(
input_type : EmbedInputType
embedding_types : typing.Sequence[EmbeddingType]
- Specifies the types of embeddings you want to get back. Can be one or more of the following types.
+ Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -2034,7 +2041,13 @@ async def rerank(
Parameters
----------
model : str
- The identifier of the model to use, eg `rerank-v3.5`.
+ The identifier of the model to use.
+
+ Supported models:
+ - `rerank-english-v3.0`
+ - `rerank-multilingual-v3.0`
+ - `rerank-english-v2.0`
+ - `rerank-multilingual-v2.0`
query : str
The search query
diff --git a/tests/test_client_init.py b/tests/test_client_init.py
deleted file mode 100644
index 9783b2540..000000000
--- a/tests/test_client_init.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import os
-import typing
-import unittest
-
-import cohere
-from cohere import ToolMessage, UserMessage, AssistantMessage
-
-class TestClientInit(unittest.TestCase):
- def test_inits(self) -> None:
- cohere.BedrockClient()
- cohere.BedrockClientV2()
- cohere.SagemakerClient()
- cohere.SagemakerClientV2()
- cohere.Client(api_key="n/a")
- cohere.ClientV2(api_key="n/a")
-