Skip to content

Commit

Permalink
Add v2 clients
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Nov 29, 2024
1 parent 061f005 commit 245949d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
33 changes: 32 additions & 1 deletion src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore

from .client_v2 import ClientV2

class AwsClient(Client):
def __init__(
Expand Down Expand Up @@ -45,6 +45,37 @@ def __init__(
)


class AwsClientV2(ClientV2):
def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
environment=ClientEnvironment.PRODUCTION,
client_name="n/a",
timeout=timeout,
api_key="n/a",
httpx_client=httpx.Client(
event_hooks=get_event_hooks(
service=service,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
),
timeout=timeout,
),
)


EventHook = typing.Callable[..., typing.Any]


Expand Down
23 changes: 22 additions & 1 deletion src/cohere/bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from tokenizers import Tokenizer # type: ignore

from .aws_client import AwsClient
from .aws_client import AwsClient, AwsClientV2


class BedrockClient(AwsClient):
Expand All @@ -24,3 +24,24 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)


class BedrockClientV2(AwsClientV2):
def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
AwsClientV2.__init__(
self,
service="bedrock",
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
timeout=timeout,
)
26 changes: 25 additions & 1 deletion src/cohere/sagemaker_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing

from .aws_client import AwsClient
from .aws_client import AwsClient, AwsClientV2
from .manually_maintained.cohere_aws.client import Client
from .manually_maintained.cohere_aws.mode import Mode

Expand All @@ -26,4 +26,28 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
self.sagemaker_finetuning = Client(aws_region=aws_region)


class SagemakerClientV2(AwsClientV2):
sagemaker_finetuning: Client

def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
AwsClientV2.__init__(
self,
service="sagemaker",
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
timeout=timeout,
)
self.sagemaker_finetuning = Client(aws_region=aws_region)

0 comments on commit 245949d

Please sign in to comment.