Skip to content

Commit

Permalink
Revert "Add bedrock test and v2 clis (#609)"
Browse files Browse the repository at this point in the history
This reverts commit bd906fa.
  • Loading branch information
andrewbcohere committed Dec 5, 2024
1 parent acddc01 commit 704cfac
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 191 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
version: 1.5.1
virtualenvs-in-project: false
- name: Install dependencies
run: poetry install --extras aws
run: poetry install
- name: Test
run: poetry run pytest .
env:
Expand Down
33 changes: 1 addition & 32 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,9 @@
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2

class AwsClient(Client):
def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
environment=ClientEnvironment.PRODUCTION,
client_name="n/a",
timeout=timeout,
api_key="n/a",
httpx_client=httpx.Client(
event_hooks=get_event_hooks(
service=service,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
),
timeout=timeout,
),
)


class AwsClientV2(ClientV2):
class AwsClient(Client):
def __init__(
self,
*,
Expand Down
23 changes: 1 addition & 22 deletions src/cohere/bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from tokenizers import Tokenizer # type: ignore

from .aws_client import AwsClient, AwsClientV2
from .aws_client import AwsClient


class BedrockClient(AwsClient):
Expand All @@ -24,24 +24,3 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)


class BedrockClientV2(AwsClientV2):
def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
AwsClientV2.__init__(
self,
service="bedrock",
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
timeout=timeout,
)
26 changes: 1 addition & 25 deletions src/cohere/sagemaker_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing

from .aws_client import AwsClient, AwsClientV2
from .aws_client import AwsClient
from .manually_maintained.cohere_aws.client import Client
from .manually_maintained.cohere_aws.mode import Mode

Expand All @@ -26,28 +26,4 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
self.sagemaker_finetuning = Client(aws_region=aws_region)


class SagemakerClientV2(AwsClientV2):
sagemaker_finetuning: Client

def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
AwsClientV2.__init__(
self,
service="sagemaker",
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
timeout=timeout,
)
self.sagemaker_finetuning = Client(aws_region=aws_region)
140 changes: 140 additions & 0 deletions tests/test_aws_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# import os
# import unittest

# import typing
# import cohere
# from parameterized import parameterized_class # type: ignore

# package_dir = os.path.dirname(os.path.abspath(__file__))
# embed_job = os.path.join(package_dir, 'embed_job.jsonl')


# model_mapping = {
# "bedrock": {
# "chat_model": "cohere.command-r-plus-v1:0",
# "embed_model": "cohere.embed-multilingual-v3",
# "generate_model": "cohere.command-text-v14",
# },
# "sagemaker": {
# "chat_model": "cohere.command-r-plus-v1:0",
# "embed_model": "cohere.embed-multilingual-v3",
# "generate_model": "cohere-command-light",
# "rerank_model": "rerank",
# },
# }


# @parameterized_class([
# {
# "platform": "bedrock",
# "client": cohere.BedrockClient(
# timeout=10000,
# aws_region="us-east-1",
# aws_access_key="...",
# aws_secret_key="...",
# aws_session_token="...",
# ),
# "models": model_mapping["bedrock"],
# },
# {
# "platform": "sagemaker",
# "client": cohere.SagemakerClient(
# timeout=10000,
# aws_region="us-east-2",
# aws_access_key="...",
# aws_secret_key="...",
# aws_session_token="...",
# ),
# "models": model_mapping["sagemaker"],
# }
# ])
# @unittest.skip("skip tests until they work in CI")
# class TestClient(unittest.TestCase):
# platform: str
# client: cohere.AwsClient
# models: typing.Dict[str, str]

# def test_rerank(self) -> None:
# if self.platform != "sagemaker":
# self.skipTest("Only sagemaker supports rerank")

# docs = [
# 'Carson City is the capital city of the American state of Nevada.',
# 'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
# 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
# 'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.']

# response = self.client.rerank(
# model=self.models["rerank_model"],
# query='What is the capital of the United States?',
# documents=docs,
# top_n=3,
# )

# self.assertEqual(len(response.results), 3)

# def test_embed(self) -> None:
# response = self.client.embed(
# model=self.models["embed_model"],
# texts=["I love Cohere!"],
# input_type="search_document",
# )
# print(response)

# def test_generate(self) -> None:
# response = self.client.generate(
# model=self.models["generate_model"],
# prompt='Please explain to me how LLMs work',
# )
# print(response)

# def test_generate_stream(self) -> None:
# response = self.client.generate_stream(
# model=self.models["generate_model"],
# prompt='Please explain to me how LLMs work',
# )
# for event in response:
# print(event)
# if event.event_type == "text-generation":
# print(event.text, end='')

# def test_chat(self) -> None:
# response = self.client.chat(
# model=self.models["chat_model"],
# message='Please explain to me how LLMs work',
# )
# print(response)

# self.assertIsNotNone(response.text)
# self.assertIsNotNone(response.generation_id)
# self.assertIsNotNone(response.finish_reason)

# self.assertIsNotNone(response.meta)
# if response.meta is not None:
# self.assertIsNotNone(response.meta.tokens)
# if response.meta.tokens is not None:
# self.assertIsNotNone(response.meta.tokens.input_tokens)
# self.assertIsNotNone(response.meta.tokens.output_tokens)

# self.assertIsNotNone(response.meta.billed_units)
# if response.meta.billed_units is not None:
# self.assertIsNotNone(response.meta.billed_units.input_tokens)
# self.assertIsNotNone(response.meta.billed_units.input_tokens)

# def test_chat_stream(self) -> None:
# response_types = set()
# response = self.client.chat_stream(
# model=self.models["chat_model"],
# message='Please explain to me how LLMs work',
# )
# for event in response:
# response_types.add(event.event_type)
# if event.event_type == "text-generation":
# print(event.text, end='')
# self.assertIsNotNone(event.text)
# if event.event_type == "stream-end":
# self.assertIsNotNone(event.finish_reason)
# self.assertIsNotNone(event.response)
# self.assertIsNotNone(event.response.text)

# self.assertSetEqual(response_types, {"text-generation", "stream-end"})
111 changes: 0 additions & 111 deletions tests/test_bedrock_client.py

This file was deleted.

0 comments on commit 704cfac

Please sign in to comment.