diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ca25b3d1f..cc8a964a5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
version: 1.5.1
virtualenvs-in-project: false
- name: Install dependencies
- run: poetry install --extras aws
+ run: poetry install
- name: Test
run: poetry run pytest .
env:
diff --git a/pyproject.toml b/pyproject.toml
index f2b80c00e..9f4308e69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
-version = "5.13.2"
+version = "5.12.0"
description = ""
readme = "README.md"
authors = []
diff --git a/reference.md b/reference.md
index 9cf1c2f5e..c1b7d5a60 100644
--- a/reference.md
+++ b/reference.md
@@ -1793,7 +1793,7 @@ We recommend a maximum of 1,000 documents for optimal endpoint performance.
-
-**model:** `typing.Optional[str]` — The identifier of the model to use, eg `rerank-v3.5`.
+**model:** `typing.Optional[str]` — The identifier of the model to use, one of: `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
@@ -2382,6 +2382,7 @@ response = client.v2.chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
+ stream=True,
)
for chunk in response:
yield chunk
@@ -2654,6 +2655,7 @@ client.v2.chat(
content="messages",
)
],
+ stream=False,
)
```
@@ -2971,7 +2973,7 @@ Available models and corresponding embedding dimensions:
**embedding_types:** `typing.Sequence[EmbeddingType]`
-Specifies the types of embeddings you want to get back. Can be one or more of the following types.
+Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -3084,7 +3086,15 @@ client.v2.rerank(
-
-**model:** `str` — The identifier of the model to use, eg `rerank-v3.5`.
+**model:** `str`
+
+The identifier of the model to use.
+
+Supported models:
+ - `rerank-english-v3.0`
+ - `rerank-multilingual-v3.0`
+ - `rerank-english-v2.0`
+ - `rerank-multilingual-v2.0`
diff --git a/src/cohere/__init__.py b/src/cohere/__init__.py
index 70321646d..3a126e4f7 100644
--- a/src/cohere/__init__.py
+++ b/src/cohere/__init__.py
@@ -245,7 +245,7 @@
)
from . import connectors, datasets, embed_jobs, finetuning, models, v2
from .aws_client import AwsClient
-from .bedrock_client import BedrockClient, BedrockClientV2
+from .bedrock_client import BedrockClient
from .client import AsyncClient, Client
from .client_v2 import AsyncClientV2, ClientV2
from .datasets import (
@@ -257,7 +257,7 @@
)
from .embed_jobs import CreateEmbedJobRequestTruncate
from .environment import ClientEnvironment
-from .sagemaker_client import SagemakerClient, SagemakerClientV2
+from .sagemaker_client import SagemakerClient
from .v2 import (
V2ChatRequestDocumentsItem,
V2ChatRequestSafetyMode,
@@ -287,7 +287,6 @@
"AwsClient",
"BadRequestError",
"BedrockClient",
- "BedrockClientV2",
"ChatCitation",
"ChatCitationGenerationEvent",
"ChatConnector",
@@ -461,7 +460,6 @@
"ResponseFormat",
"ResponseFormatV2",
"SagemakerClient",
- "SagemakerClientV2",
"SearchQueriesGenerationStreamedChatResponse",
"SearchResultsStreamedChatResponse",
"ServiceUnavailableError",
diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py
index 8aea9d15c..cdbeeedbe 100644
--- a/src/cohere/aws_client.py
+++ b/src/cohere/aws_client.py
@@ -12,7 +12,7 @@
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
-from .client_v2 import ClientV2
+
class AwsClient(Client):
def __init__(
@@ -45,37 +45,6 @@ def __init__(
)
-class AwsClientV2(ClientV2):
- def __init__(
- self,
- *,
- aws_access_key: typing.Optional[str] = None,
- aws_secret_key: typing.Optional[str] = None,
- aws_session_token: typing.Optional[str] = None,
- aws_region: typing.Optional[str] = None,
- timeout: typing.Optional[float] = None,
- service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
- ):
- ClientV2.__init__(
- self,
- base_url="https://api.cohere.com", # this url is unused for BedrockClient
- environment=ClientEnvironment.PRODUCTION,
- client_name="n/a",
- timeout=timeout,
- api_key="n/a",
- httpx_client=httpx.Client(
- event_hooks=get_event_hooks(
- service=service,
- aws_access_key=aws_access_key,
- aws_secret_key=aws_secret_key,
- aws_session_token=aws_session_token,
- aws_region=aws_region,
- ),
- timeout=timeout,
- ),
- )
-
-
EventHook = typing.Callable[..., typing.Any]
@@ -196,13 +165,6 @@ def _hook(
return _hook
-def get_boto3_session(
- **kwargs: typing.Any,
-):
- non_none_args = {k: v for k, v in kwargs.items() if v is not None}
- return lazy_boto3().Session(**non_none_args)
-
-
def map_request_to_bedrock(
service: str,
@@ -211,22 +173,19 @@ def map_request_to_bedrock(
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
) -> EventHook:
- session = get_boto3_session(
+ session = lazy_boto3().Session(
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
)
- aws_region = session.region_name
credentials = session.get_credentials()
- signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)
+ signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)
def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
del headers["connection"]
-
- api_version = request.url.path.split("/")[-2]
endpoint = request.url.path.split("/")[-1]
body = json.loads(request.read())
model = body["model"]
@@ -240,9 +199,6 @@ def _event_hook(request: httpx.Request) -> None:
request.url = URL(url)
request.headers["host"] = request.url.host
- if endpoint == "rerank":
- body["api_version"] = get_api_version(version=api_version)
-
if "stream" in body:
del body["stream"]
@@ -268,6 +224,20 @@ def _event_hook(request: httpx.Request) -> None:
return _event_hook
+def get_endpoint_from_url(url: str,
+ chat_model: typing.Optional[str] = None,
+ embed_model: typing.Optional[str] = None,
+ generate_model: typing.Optional[str] = None,
+ ) -> str:
+ if chat_model and chat_model in url:
+ return "chat"
+ if embed_model and embed_model in url:
+ return "embed"
+ if generate_model and generate_model in url:
+ return "generate"
+ raise ValueError(f"Unknown endpoint in url: {url}")
+
+
def get_url(
*,
platform: str,
@@ -282,12 +252,3 @@ def get_url(
endpoint = "invocations" if not stream else "invocations-response-stream"
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
return ""
-
-
-def get_api_version(*, version: str):
- int_version = {
- "v1": 1,
- "v2": 2,
- }
-
- return int_version.get(version, 1)
\ No newline at end of file
diff --git a/src/cohere/base_client.py b/src/cohere/base_client.py
index 50a638de6..b18ccb2f3 100644
--- a/src/cohere/base_client.py
+++ b/src/cohere/base_client.py
@@ -2012,7 +2012,7 @@ def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
- The identifier of the model to use, eg `rerank-v3.5`.
+            The identifier of the model to use, one of: `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
@@ -5047,7 +5047,7 @@ async def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
- The identifier of the model to use, eg `rerank-v3.5`.
+            The identifier of the model to use, one of: `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
diff --git a/src/cohere/bedrock_client.py b/src/cohere/bedrock_client.py
index 16f47b031..bcc24786a 100644
--- a/src/cohere/bedrock_client.py
+++ b/src/cohere/bedrock_client.py
@@ -2,7 +2,7 @@
from tokenizers import Tokenizer # type: ignore
-from .aws_client import AwsClient, AwsClientV2
+from .aws_client import AwsClient
class BedrockClient(AwsClient):
@@ -24,26 +24,3 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
-
- def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
- raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")
-
-class BedrockClientV2(AwsClientV2):
- def __init__(
- self,
- *,
- aws_access_key: typing.Optional[str] = None,
- aws_secret_key: typing.Optional[str] = None,
- aws_session_token: typing.Optional[str] = None,
- aws_region: typing.Optional[str] = None,
- timeout: typing.Optional[float] = None,
- ):
- AwsClientV2.__init__(
- self,
- service="bedrock",
- aws_access_key=aws_access_key,
- aws_secret_key=aws_secret_key,
- aws_session_token=aws_session_token,
- aws_region=aws_region,
- timeout=timeout,
- )
diff --git a/src/cohere/core/client_wrapper.py b/src/cohere/core/client_wrapper.py
index 9fb4430e7..8213e6da3 100644
--- a/src/cohere/core/client_wrapper.py
+++ b/src/cohere/core/client_wrapper.py
@@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]:
headers: typing.Dict[str, str] = {
"X-Fern-Language": "Python",
"X-Fern-SDK-Name": "cohere",
- "X-Fern-SDK-Version": "5.13.2",
+ "X-Fern-SDK-Version": "5.12.0",
}
if self._client_name is not None:
headers["X-Client-Name"] = self._client_name
diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py
index 77f9f9115..6d4236d53 100644
--- a/src/cohere/sagemaker_client.py
+++ b/src/cohere/sagemaker_client.py
@@ -1,6 +1,6 @@
import typing
-from .aws_client import AwsClient, AwsClientV2
+from .aws_client import AwsClient
from .manually_maintained.cohere_aws.client import Client
from .manually_maintained.cohere_aws.mode import Mode
@@ -26,34 +26,4 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
- try:
- self.sagemaker_finetuning = Client(aws_region=aws_region)
- except Exception:
- pass
-
-
-class SagemakerClientV2(AwsClientV2):
- sagemaker_finetuning: Client
-
- def __init__(
- self,
- *,
- aws_access_key: typing.Optional[str] = None,
- aws_secret_key: typing.Optional[str] = None,
- aws_session_token: typing.Optional[str] = None,
- aws_region: typing.Optional[str] = None,
- timeout: typing.Optional[float] = None,
- ):
- AwsClientV2.__init__(
- self,
- service="sagemaker",
- aws_access_key=aws_access_key,
- aws_secret_key=aws_secret_key,
- aws_session_token=aws_session_token,
- aws_region=aws_region,
- timeout=timeout,
- )
- try:
- self.sagemaker_finetuning = Client(aws_region=aws_region)
- except Exception:
- pass
\ No newline at end of file
+ self.sagemaker_finetuning = Client(aws_region=aws_region)
\ No newline at end of file
diff --git a/src/cohere/v2/client.py b/src/cohere/v2/client.py
index c62edd8e0..a4e465e79 100644
--- a/src/cohere/v2/client.py
+++ b/src/cohere/v2/client.py
@@ -222,6 +222,7 @@ def chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
+ stream=True,
)
for chunk in response:
yield chunk
@@ -534,6 +535,7 @@ def chat(
content="messages",
)
],
+ stream=False,
)
"""
_response = self._client_wrapper.httpx_client.request(
@@ -568,7 +570,6 @@ def chat(
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
- "stream": False,
},
request_options=request_options,
omit=OMIT,
@@ -736,7 +737,7 @@ def embed(
input_type : EmbedInputType
embedding_types : typing.Sequence[EmbeddingType]
- Specifies the types of embeddings you want to get back. Can be one or more of the following types.
+ Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -936,7 +937,13 @@ def rerank(
Parameters
----------
model : str
- The identifier of the model to use, eg `rerank-v3.5`.
+ The identifier of the model to use.
+
+ Supported models:
+ - `rerank-english-v3.0`
+ - `rerank-multilingual-v3.0`
+ - `rerank-english-v2.0`
+ - `rerank-multilingual-v2.0`
query : str
The search query
@@ -1301,6 +1308,7 @@ async def main() -> None:
p=1.1,
return_prompt=True,
logprobs=True,
+ stream=True,
)
async for chunk in response:
yield chunk
@@ -1340,7 +1348,6 @@ async def main() -> None:
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
- "stream": True,
},
request_options=request_options,
omit=OMIT,
@@ -1621,6 +1628,7 @@ async def main() -> None:
content="messages",
)
],
+ stream=False,
)
@@ -1658,7 +1666,6 @@ async def main() -> None:
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
- "stream": False,
},
request_options=request_options,
omit=OMIT,
@@ -1826,7 +1833,7 @@ async def embed(
input_type : EmbedInputType
embedding_types : typing.Sequence[EmbeddingType]
- Specifies the types of embeddings you want to get back. Can be one or more of the following types.
+ Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -2034,7 +2041,13 @@ async def rerank(
Parameters
----------
model : str
- The identifier of the model to use, eg `rerank-v3.5`.
+ The identifier of the model to use.
+
+ Supported models:
+ - `rerank-english-v3.0`
+ - `rerank-multilingual-v3.0`
+ - `rerank-english-v2.0`
+ - `rerank-multilingual-v2.0`
query : str
The search query
diff --git a/tests/test_aws_client.py b/tests/test_aws_client.py
new file mode 100644
index 000000000..d361bc184
--- /dev/null
+++ b/tests/test_aws_client.py
@@ -0,0 +1,140 @@
+# import os
+# import unittest
+
+# import typing
+# import cohere
+# from parameterized import parameterized_class # type: ignore
+
+# package_dir = os.path.dirname(os.path.abspath(__file__))
+# embed_job = os.path.join(package_dir, 'embed_job.jsonl')
+
+
+# model_mapping = {
+# "bedrock": {
+# "chat_model": "cohere.command-r-plus-v1:0",
+# "embed_model": "cohere.embed-multilingual-v3",
+# "generate_model": "cohere.command-text-v14",
+# },
+# "sagemaker": {
+# "chat_model": "cohere.command-r-plus-v1:0",
+# "embed_model": "cohere.embed-multilingual-v3",
+# "generate_model": "cohere-command-light",
+# "rerank_model": "rerank",
+# },
+# }
+
+
+# @parameterized_class([
+# {
+# "platform": "bedrock",
+# "client": cohere.BedrockClient(
+# timeout=10000,
+# aws_region="us-east-1",
+# aws_access_key="...",
+# aws_secret_key="...",
+# aws_session_token="...",
+# ),
+# "models": model_mapping["bedrock"],
+# },
+# {
+# "platform": "sagemaker",
+# "client": cohere.SagemakerClient(
+# timeout=10000,
+# aws_region="us-east-2",
+# aws_access_key="...",
+# aws_secret_key="...",
+# aws_session_token="...",
+# ),
+# "models": model_mapping["sagemaker"],
+# }
+# ])
+# @unittest.skip("skip tests until they work in CI")
+# class TestClient(unittest.TestCase):
+# platform: str
+# client: cohere.AwsClient
+# models: typing.Dict[str, str]
+
+# def test_rerank(self) -> None:
+# if self.platform != "sagemaker":
+# self.skipTest("Only sagemaker supports rerank")
+
+# docs = [
+# 'Carson City is the capital city of the American state of Nevada.',
+# 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
+# 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
+#             'Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
+
+# response = self.client.rerank(
+# model=self.models["rerank_model"],
+# query='What is the capital of the United States?',
+# documents=docs,
+# top_n=3,
+# )
+
+# self.assertEqual(len(response.results), 3)
+
+# def test_embed(self) -> None:
+# response = self.client.embed(
+# model=self.models["embed_model"],
+# texts=["I love Cohere!"],
+# input_type="search_document",
+# )
+# print(response)
+
+# def test_generate(self) -> None:
+# response = self.client.generate(
+# model=self.models["generate_model"],
+# prompt='Please explain to me how LLMs work',
+# )
+# print(response)
+
+# def test_generate_stream(self) -> None:
+# response = self.client.generate_stream(
+# model=self.models["generate_model"],
+# prompt='Please explain to me how LLMs work',
+# )
+# for event in response:
+# print(event)
+# if event.event_type == "text-generation":
+# print(event.text, end='')
+
+# def test_chat(self) -> None:
+# response = self.client.chat(
+# model=self.models["chat_model"],
+# message='Please explain to me how LLMs work',
+# )
+# print(response)
+
+# self.assertIsNotNone(response.text)
+# self.assertIsNotNone(response.generation_id)
+# self.assertIsNotNone(response.finish_reason)
+
+# self.assertIsNotNone(response.meta)
+# if response.meta is not None:
+# self.assertIsNotNone(response.meta.tokens)
+# if response.meta.tokens is not None:
+# self.assertIsNotNone(response.meta.tokens.input_tokens)
+# self.assertIsNotNone(response.meta.tokens.output_tokens)
+
+# self.assertIsNotNone(response.meta.billed_units)
+# if response.meta.billed_units is not None:
+# self.assertIsNotNone(response.meta.billed_units.input_tokens)
+# self.assertIsNotNone(response.meta.billed_units.input_tokens)
+
+# def test_chat_stream(self) -> None:
+# response_types = set()
+# response = self.client.chat_stream(
+# model=self.models["chat_model"],
+# message='Please explain to me how LLMs work',
+# )
+# for event in response:
+# response_types.add(event.event_type)
+# if event.event_type == "text-generation":
+# print(event.text, end='')
+# self.assertIsNotNone(event.text)
+# if event.event_type == "stream-end":
+# self.assertIsNotNone(event.finish_reason)
+# self.assertIsNotNone(event.response)
+# self.assertIsNotNone(event.response.text)
+
+# self.assertSetEqual(response_types, {"text-generation", "stream-end"})
diff --git a/tests/test_bedrock_client.py b/tests/test_bedrock_client.py
deleted file mode 100644
index d588ca38c..000000000
--- a/tests/test_bedrock_client.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import os
-import unittest
-
-import typing
-import cohere
-
-aws_access_key = os.getenv("AWS_ACCESS_KEY")
-aws_secret_key = os.getenv("AWS_SECRET_KEY")
-aws_session_token = os.getenv("AWS_SESSION_TOKEN")
-aws_region = os.getenv("AWS_REGION")
-endpoint_type = os.getenv("ENDPOINT_TYPE")
-
-@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set")
-class TestClient(unittest.TestCase):
- platform: str = "bedrock"
- client: cohere.AwsClient = cohere.BedrockClient(
- aws_access_key=aws_access_key,
- aws_secret_key=aws_secret_key,
- aws_session_token=aws_session_token,
- aws_region=aws_region,
- )
- models: typing.Dict[str, str] = {
- "chat_model": "cohere.command-r-plus-v1:0",
- "embed_model": "cohere.embed-multilingual-v3",
- "generate_model": "cohere.command-text-v14",
- }
-
- def test_rerank(self) -> None:
- if self.platform != "sagemaker":
- self.skipTest("Only sagemaker supports rerank")
-
- docs = [
- 'Carson City is the capital city of the American state of Nevada.',
- 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
- 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
- 'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']
-
- response = self.client.rerank(
- model=self.models["rerank_model"],
- query='What is the capital of the United States?',
- documents=docs,
- top_n=3,
- )
-
- self.assertEqual(len(response.results), 3)
-
- def test_embed(self) -> None:
- response = self.client.embed(
- model=self.models["embed_model"],
- texts=["I love Cohere!"],
- input_type="search_document",
- )
- print(response)
-
- def test_generate(self) -> None:
- response = self.client.generate(
- model=self.models["generate_model"],
- prompt='Please explain to me how LLMs work',
- )
- print(response)
-
- def test_generate_stream(self) -> None:
- response = self.client.generate_stream(
- model=self.models["generate_model"],
- prompt='Please explain to me how LLMs work',
- )
- for event in response:
- print(event)
- if event.event_type == "text-generation":
- print(event.text, end='')
-
- def test_chat(self) -> None:
- response = self.client.chat(
- model=self.models["chat_model"],
- message='Please explain to me how LLMs work',
- )
- print(response)
-
- self.assertIsNotNone(response.text)
- self.assertIsNotNone(response.generation_id)
- self.assertIsNotNone(response.finish_reason)
-
- self.assertIsNotNone(response.meta)
- if response.meta is not None:
- self.assertIsNotNone(response.meta.tokens)
- if response.meta.tokens is not None:
- self.assertIsNotNone(response.meta.tokens.input_tokens)
- self.assertIsNotNone(response.meta.tokens.output_tokens)
-
- self.assertIsNotNone(response.meta.billed_units)
- if response.meta.billed_units is not None:
- self.assertIsNotNone(response.meta.billed_units.input_tokens)
- self.assertIsNotNone(response.meta.billed_units.input_tokens)
-
- def test_chat_stream(self) -> None:
- response_types = set()
- response = self.client.chat_stream(
- model=self.models["chat_model"],
- message='Please explain to me how LLMs work',
- )
- for event in response:
- response_types.add(event.event_type)
- if event.event_type == "text-generation":
- print(event.text, end='')
- self.assertIsNotNone(event.text)
- if event.event_type == "stream-end":
- self.assertIsNotNone(event.finish_reason)
- self.assertIsNotNone(event.response)
- self.assertIsNotNone(event.response.text)
-
- self.assertSetEqual(response_types, {"text-generation", "stream-end"})
diff --git a/tests/test_client_init.py b/tests/test_client_init.py
deleted file mode 100644
index 9783b2540..000000000
--- a/tests/test_client_init.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import os
-import typing
-import unittest
-
-import cohere
-from cohere import ToolMessage, UserMessage, AssistantMessage
-
-class TestClientInit(unittest.TestCase):
- def test_inits(self) -> None:
- cohere.BedrockClient()
- cohere.BedrockClientV2()
- cohere.SagemakerClient()
- cohere.SagemakerClientV2()
- cohere.Client(api_key="n/a")
- cohere.ClientV2(api_key="n/a")
-