From 704cfacc9450953a230d5175433893e5ef16175e Mon Sep 17 00:00:00 2001 From: Andrew Berneshawi Date: Wed, 4 Dec 2024 20:13:53 -0800 Subject: [PATCH] Revert "Add bedrock test and v2 clis (#609)" This reverts commit bd906fa4828eb365f167b197342c9c49b50392aa. --- .github/workflows/ci.yml | 2 +- src/cohere/aws_client.py | 33 +------- src/cohere/bedrock_client.py | 23 +----- src/cohere/sagemaker_client.py | 26 +----- tests/test_aws_client.py | 140 +++++++++++++++++++++++++++++++++ tests/test_bedrock_client.py | 111 -------------------------- 6 files changed, 144 insertions(+), 191 deletions(-) create mode 100644 tests/test_aws_client.py delete mode 100644 tests/test_bedrock_client.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca25b3d1f..cc8a964a5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: version: 1.5.1 virtualenvs-in-project: false - name: Install dependencies - run: poetry install --extras aws + run: poetry install - name: Test run: poetry run pytest . env: diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index 79242d0fb..cdbeeedbe 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -12,40 +12,9 @@ from .client import Client, ClientEnvironment from .core import construct_type from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore -from .client_v2 import ClientV2 - -class AwsClient(Client): - def __init__( - self, - *, - aws_access_key: typing.Optional[str] = None, - aws_secret_key: typing.Optional[str] = None, - aws_session_token: typing.Optional[str] = None, - aws_region: typing.Optional[str] = None, - timeout: typing.Optional[float] = None, - service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]], - ): - Client.__init__( - self, - base_url="https://api.cohere.com", # this url is unused for BedrockClient - environment=ClientEnvironment.PRODUCTION, - client_name="n/a", - timeout=timeout, - api_key="n/a", - httpx_client=httpx.Client( - event_hooks=get_event_hooks( - service=service, - aws_access_key=aws_access_key, - aws_secret_key=aws_secret_key, - aws_session_token=aws_session_token, - aws_region=aws_region, - ), - timeout=timeout, - ), - ) -class AwsClientV2(ClientV2): +class AwsClient(Client): def __init__( self, *, diff --git a/src/cohere/bedrock_client.py b/src/cohere/bedrock_client.py index 0246a288b..bcc24786a 100644 --- a/src/cohere/bedrock_client.py +++ b/src/cohere/bedrock_client.py @@ -2,7 +2,7 @@ from tokenizers import Tokenizer # type: ignore -from .aws_client import AwsClient, AwsClientV2 +from .aws_client import AwsClient class BedrockClient(AwsClient): @@ -24,24 +24,3 @@ def __init__( aws_region=aws_region, timeout=timeout, ) - - -class BedrockClientV2(AwsClientV2): - def __init__( - self, - *, - aws_access_key: typing.Optional[str] = None, - aws_secret_key: typing.Optional[str] = None, - aws_session_token: typing.Optional[str] = None, - aws_region: typing.Optional[str] = None, - timeout: typing.Optional[float] = None, - ): - AwsClientV2.__init__( - self, - service="bedrock", - aws_access_key=aws_access_key, - aws_secret_key=aws_secret_key, - aws_session_token=aws_session_token, - aws_region=aws_region, - timeout=timeout, - ) diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py index 84d20a2ac..6d4236d53 100644 --- a/src/cohere/sagemaker_client.py +++ b/src/cohere/sagemaker_client.py @@ -1,6 +1,6 @@ import typing -from .aws_client import AwsClient, AwsClientV2 +from .aws_client import AwsClient from .manually_maintained.cohere_aws.client import Client from .manually_maintained.cohere_aws.mode import Mode @@ -26,28 +26,4 @@ def __init__( aws_region=aws_region, timeout=timeout, ) - self.sagemaker_finetuning = Client(aws_region=aws_region) - - -class SagemakerClientV2(AwsClientV2): - sagemaker_finetuning: Client - - def __init__( - self, - *, - aws_access_key: typing.Optional[str] = None, - aws_secret_key: typing.Optional[str] = None, - aws_session_token: typing.Optional[str] = None, - aws_region: typing.Optional[str] = None, - timeout: typing.Optional[float] = None, - ): - AwsClientV2.__init__( - self, - service="sagemaker", - aws_access_key=aws_access_key, - aws_secret_key=aws_secret_key, - aws_session_token=aws_session_token, - aws_region=aws_region, - timeout=timeout, - ) self.sagemaker_finetuning = Client(aws_region=aws_region) \ No newline at end of file diff --git a/tests/test_aws_client.py b/tests/test_aws_client.py new file mode 100644 index 000000000..d361bc184 --- /dev/null +++ b/tests/test_aws_client.py @@ -0,0 +1,140 @@ +# import os +# import unittest + +# import typing +# import cohere +# from parameterized import parameterized_class # type: ignore + +# package_dir = os.path.dirname(os.path.abspath(__file__)) +# embed_job = os.path.join(package_dir, 'embed_job.jsonl') + + +# model_mapping = { +# "bedrock": { +# "chat_model": "cohere.command-r-plus-v1:0", +# "embed_model": "cohere.embed-multilingual-v3", +# "generate_model": "cohere.command-text-v14", +# }, +# "sagemaker": { +# "chat_model": "cohere.command-r-plus-v1:0", +# "embed_model": "cohere.embed-multilingual-v3", +# "generate_model": "cohere-command-light", +# "rerank_model": "rerank", +# }, +# } + + +# @parameterized_class([ +# { +# "platform": "bedrock", +# "client": cohere.BedrockClient( +# timeout=10000, +# aws_region="us-east-1", +# aws_access_key="...", +# aws_secret_key="...", +# aws_session_token="...", +# ), +# "models": model_mapping["bedrock"], +# }, +# { +# "platform": "sagemaker", +# "client": cohere.SagemakerClient( +# timeout=10000, +# aws_region="us-east-2", +# aws_access_key="...", +# aws_secret_key="...", +# aws_session_token="...", +# ), +# "models": model_mapping["sagemaker"], +# } +# ]) +# @unittest.skip("skip tests until they work in CI") +# class TestClient(unittest.TestCase): +# platform: str +# client: cohere.AwsClient +# models: typing.Dict[str, str] + +# def test_rerank(self) -> None: +# if self.platform != "sagemaker": +# self.skipTest("Only sagemaker supports rerank") + +# docs = [ +# 'Carson City is the capital city of the American state of Nevada.', +# 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.', +# 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.', +# 'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.'] + +# response = self.client.rerank( +# model=self.models["rerank_model"], +# query='What is the capital of the United States?', +# documents=docs, +# top_n=3, +# ) + +# self.assertEqual(len(response.results), 3) + +# def test_embed(self) -> None: +# response = self.client.embed( +# model=self.models["embed_model"], +# texts=["I love Cohere!"], +# input_type="search_document", +# ) +# print(response) + +# def test_generate(self) -> None: +# response = self.client.generate( +# model=self.models["generate_model"], +# prompt='Please explain to me how LLMs work', +# ) +# print(response) + +# def test_generate_stream(self) -> None: +# response = self.client.generate_stream( +# model=self.models["generate_model"], +# prompt='Please explain to me how LLMs work', +# ) +# for event in response: +# print(event) +# if event.event_type == "text-generation": +# print(event.text, end='') + +# def test_chat(self) -> None: +# response = self.client.chat( +# model=self.models["chat_model"], +# message='Please explain to me how LLMs work', +# ) +# print(response) + +# self.assertIsNotNone(response.text) +# self.assertIsNotNone(response.generation_id) +# self.assertIsNotNone(response.finish_reason) + +# self.assertIsNotNone(response.meta) +# if response.meta is not None: +# self.assertIsNotNone(response.meta.tokens) +# if response.meta.tokens is not None: +# self.assertIsNotNone(response.meta.tokens.input_tokens) +# self.assertIsNotNone(response.meta.tokens.output_tokens) + +# self.assertIsNotNone(response.meta.billed_units) +# if response.meta.billed_units is not None: +# self.assertIsNotNone(response.meta.billed_units.input_tokens) +# self.assertIsNotNone(response.meta.billed_units.input_tokens) + +# def test_chat_stream(self) -> None: +# response_types = set() +# response = self.client.chat_stream( +# model=self.models["chat_model"], +# message='Please explain to me how LLMs work', +# ) +# for event in response: +# response_types.add(event.event_type) +# if event.event_type == "text-generation": +# print(event.text, end='') +# self.assertIsNotNone(event.text) +# if event.event_type == "stream-end": +# self.assertIsNotNone(event.finish_reason) +# self.assertIsNotNone(event.response) +# self.assertIsNotNone(event.response.text) + +# self.assertSetEqual(response_types, {"text-generation", "stream-end"}) diff --git a/tests/test_bedrock_client.py b/tests/test_bedrock_client.py deleted file mode 100644 index d588ca38c..000000000 --- a/tests/test_bedrock_client.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import unittest - -import typing -import cohere - -aws_access_key = os.getenv("AWS_ACCESS_KEY") -aws_secret_key = os.getenv("AWS_SECRET_KEY") -aws_session_token = os.getenv("AWS_SESSION_TOKEN") -aws_region = os.getenv("AWS_REGION") -endpoint_type = os.getenv("ENDPOINT_TYPE") - -@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") -class TestClient(unittest.TestCase): - platform: str = "bedrock" - client: cohere.AwsClient = cohere.BedrockClient( - aws_access_key=aws_access_key, - aws_secret_key=aws_secret_key, - aws_session_token=aws_session_token, - aws_region=aws_region, - ) - models: typing.Dict[str, str] = { - "chat_model": "cohere.command-r-plus-v1:0", - "embed_model": "cohere.embed-multilingual-v3", - "generate_model": "cohere.command-text-v14", - } - - def test_rerank(self) -> None: - if self.platform != "sagemaker": - self.skipTest("Only sagemaker supports rerank") - - docs = [ - 'Carson City is the capital city of the American state of Nevada.', - 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.', - 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.', - 'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.'] - - response = self.client.rerank( - model=self.models["rerank_model"], - query='What is the capital of the United States?', - documents=docs, - top_n=3, - ) - - self.assertEqual(len(response.results), 3) - - def test_embed(self) -> None: - response = self.client.embed( - model=self.models["embed_model"], - texts=["I love Cohere!"], - input_type="search_document", - ) - print(response) - - def test_generate(self) -> None: - response = self.client.generate( - model=self.models["generate_model"], - prompt='Please explain to me how LLMs work', - ) - print(response) - - def test_generate_stream(self) -> None: - response = self.client.generate_stream( - model=self.models["generate_model"], - prompt='Please explain to me how LLMs work', - ) - for event in response: - print(event) - if event.event_type == "text-generation": - print(event.text, end='') - - def test_chat(self) -> None: - response = self.client.chat( - model=self.models["chat_model"], - message='Please explain to me how LLMs work', - ) - print(response) - - self.assertIsNotNone(response.text) - self.assertIsNotNone(response.generation_id) - self.assertIsNotNone(response.finish_reason) - - self.assertIsNotNone(response.meta) - if response.meta is not None: - self.assertIsNotNone(response.meta.tokens) - if response.meta.tokens is not None: - self.assertIsNotNone(response.meta.tokens.input_tokens) - self.assertIsNotNone(response.meta.tokens.output_tokens) - - self.assertIsNotNone(response.meta.billed_units) - if response.meta.billed_units is not None: - self.assertIsNotNone(response.meta.billed_units.input_tokens) - self.assertIsNotNone(response.meta.billed_units.input_tokens) - - def test_chat_stream(self) -> None: - response_types = set() - response = self.client.chat_stream( - model=self.models["chat_model"], - message='Please explain to me how LLMs work', - ) - for event in response: - response_types.add(event.event_type) - if event.event_type == "text-generation": - print(event.text, end='') - self.assertIsNotNone(event.text) - if event.event_type == "stream-end": - self.assertIsNotNone(event.finish_reason) - self.assertIsNotNone(event.response) - self.assertIsNotNone(event.response.text) - - self.assertSetEqual(response_types, {"text-generation", "stream-end"})