From 8fdf8d9a8d35c8781171cc1e20f1666f0c87cde4 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Fri, 18 Oct 2024 17:11:42 +0100 Subject: [PATCH] Gate deps --- src/cohere/aws_client.py | 13 ++++++++--- src/cohere/bedrock_client.py | 3 --- .../manually_maintained/cohere_aws/client.py | 23 +++++++++++-------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index 08a52d927..95099682f 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -3,10 +3,7 @@ import re import typing -import boto3 # type: ignore import httpx -from botocore.auth import SigV4Auth # type: ignore -from botocore.awsrequest import AWSRequest # type: ignore from httpx import URL, SyncByteStream, ByteStream from tokenizers import Tokenizer # type: ignore @@ -17,6 +14,14 @@ from .core import construct_type +try: + import boto3 # type: ignore + from botocore.auth import SigV4Auth # type: ignore + from botocore.awsrequest import AWSRequest # type: ignore + AWS_DEPS_AVAILABLE = True + except ImportError: + AWS_DEPS_AVAILABLE = False + class AwsClient(Client): def __init__( self, @@ -28,6 +33,8 @@ def __init__( timeout: typing.Optional[float] = None, service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]], ): + if not AWS_DEPS_AVAILABLE: + raise ImportError("AWS dependencies not available. Please install boto3 and botocore.") Client.__init__( self, base_url="https://api.cohere.com", # this url is unused for BedrockClient diff --git a/src/cohere/bedrock_client.py b/src/cohere/bedrock_client.py index 038941147..bcc24786a 100644 --- a/src/cohere/bedrock_client.py +++ b/src/cohere/bedrock_client.py @@ -1,8 +1,5 @@ import typing -import boto3 # type: ignore -from botocore.auth import SigV4Auth # type: ignore -from botocore.awsrequest import AWSRequest # type: ignore from tokenizers import Tokenizer # type: ignore from .aws_client import AwsClient diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index bacb3a97e..5d2a852c4 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -5,12 +5,6 @@ import time from typing import Any, Dict, List, Optional, Tuple, Union -import boto3 -import sagemaker as sage -from botocore.exceptions import (ClientError, EndpointConnectionError, - ParamValidationError) -from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url - from .classification import Classification, Classifications from .embeddings import Embeddings from .error import CohereError @@ -23,7 +17,17 @@ from .mode import Mode import typing -class Client: +# Try to import sagemaker and related modules +try: + import sagemaker as sage + from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url + import boto3 + from botocore.exceptions import ( + ClientError, EndpointConnectionError, ParamValidationError) + AWS_DEPS_AVAILABLE = True + except ImportError: + AWS_DEPS_AVAILABLE = False + def __init__( self, aws_region: typing.Optional[str] = None, @@ -32,8 +36,9 @@ def __init__( By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with `aws configure set region us-west-2` or override it with `region_name` parameter. """ - self._client = boto3.client("sagemaker-runtime", region_name=aws_region) - self._service_client = boto3.client("sagemaker", region_name=aws_region) + if not AWS_DEPS_AVAILABLE: + raise CohereError("AWS dependencies not available. Please install boto3 and sagemaker.") + self._client = boto3.client( if os.environ.get('AWS_DEFAULT_REGION') is None: os.environ['AWS_DEFAULT_REGION'] = aws_region self._sess = sage.Session(sagemaker_client=self._service_client)