Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert sdk commits blocking staging #617

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
version: 1.5.1
virtualenvs-in-project: false
- name: Install dependencies
run: poetry install --extras aws
run: poetry install
- name: Test
run: poetry run pytest .
env:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "5.13.2"
version = "5.12.0"
description = ""
readme = "README.md"
authors = []
Expand Down
16 changes: 13 additions & 3 deletions reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,7 @@ We recommend a maximum of 1,000 documents for optimal endpoint performance.
<dl>
<dd>

**model:** `typing.Optional[str]` — The identifier of the model to use, eg `rerank-v3.5`.
**model:** `typing.Optional[str]` — The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`

</dd>
</dl>
Expand Down Expand Up @@ -2382,6 +2382,7 @@ response = client.v2.chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
stream=True,
)
for chunk in response:
yield chunk
Expand Down Expand Up @@ -2654,6 +2655,7 @@ client.v2.chat(
content="messages",
)
],
stream=False,
)

```
Expand Down Expand Up @@ -2971,7 +2973,7 @@ Available models and corresponding embedding dimensions:

**embedding_types:** `typing.Sequence[EmbeddingType]`

Specifies the types of embeddings you want to get back. Can be one or more of the following types.
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.

* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
Expand Down Expand Up @@ -3084,7 +3086,15 @@ client.v2.rerank(
<dl>
<dd>

**model:** `str` — The identifier of the model to use, eg `rerank-v3.5`.
**model:** `str`

The identifier of the model to use.

Supported models:
- `rerank-english-v3.0`
- `rerank-multilingual-v3.0`
- `rerank-english-v2.0`
- `rerank-multilingual-v2.0`

</dd>
</dl>
Expand Down
6 changes: 2 additions & 4 deletions src/cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@
)
from . import connectors, datasets, embed_jobs, finetuning, models, v2
from .aws_client import AwsClient
from .bedrock_client import BedrockClient, BedrockClientV2
from .bedrock_client import BedrockClient
from .client import AsyncClient, Client
from .client_v2 import AsyncClientV2, ClientV2
from .datasets import (
Expand All @@ -257,7 +257,7 @@
)
from .embed_jobs import CreateEmbedJobRequestTruncate
from .environment import ClientEnvironment
from .sagemaker_client import SagemakerClient, SagemakerClientV2
from .sagemaker_client import SagemakerClient
from .v2 import (
V2ChatRequestDocumentsItem,
V2ChatRequestSafetyMode,
Expand Down Expand Up @@ -287,7 +287,6 @@
"AwsClient",
"BadRequestError",
"BedrockClient",
"BedrockClientV2",
"ChatCitation",
"ChatCitationGenerationEvent",
"ChatConnector",
Expand Down Expand Up @@ -461,7 +460,6 @@
"ResponseFormat",
"ResponseFormatV2",
"SagemakerClient",
"SagemakerClientV2",
"SearchQueriesGenerationStreamedChatResponse",
"SearchResultsStreamedChatResponse",
"ServiceUnavailableError",
Expand Down
73 changes: 17 additions & 56 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2


class AwsClient(Client):
def __init__(
Expand Down Expand Up @@ -45,37 +45,6 @@ def __init__(
)


class AwsClientV2(ClientV2):
def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
ClientV2.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
environment=ClientEnvironment.PRODUCTION,
client_name="n/a",
timeout=timeout,
api_key="n/a",
httpx_client=httpx.Client(
event_hooks=get_event_hooks(
service=service,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
),
timeout=timeout,
),
)


EventHook = typing.Callable[..., typing.Any]


Expand Down Expand Up @@ -196,13 +165,6 @@ def _hook(

return _hook

def get_boto3_session(
**kwargs: typing.Any,
):
non_none_args = {k: v for k, v in kwargs.items() if v is not None}
return lazy_boto3().Session(**non_none_args)



def map_request_to_bedrock(
service: str,
Expand All @@ -211,22 +173,19 @@ def map_request_to_bedrock(
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
) -> EventHook:
session = get_boto3_session(
session = lazy_boto3().Session(
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
)
aws_region = session.region_name
credentials = session.get_credentials()
signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)
signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)

def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
del headers["connection"]


api_version = request.url.path.split("/")[-2]
endpoint = request.url.path.split("/")[-1]
body = json.loads(request.read())
model = body["model"]
Expand All @@ -240,9 +199,6 @@ def _event_hook(request: httpx.Request) -> None:
request.url = URL(url)
request.headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)

if "stream" in body:
del body["stream"]

Expand All @@ -268,6 +224,20 @@ def _event_hook(request: httpx.Request) -> None:
return _event_hook


def get_endpoint_from_url(url: str,
chat_model: typing.Optional[str] = None,
embed_model: typing.Optional[str] = None,
generate_model: typing.Optional[str] = None,
) -> str:
if chat_model and chat_model in url:
return "chat"
if embed_model and embed_model in url:
return "embed"
if generate_model and generate_model in url:
return "generate"
raise ValueError(f"Unknown endpoint in url: {url}")


def get_url(
*,
platform: str,
Expand All @@ -282,12 +252,3 @@ def get_url(
endpoint = "invocations" if not stream else "invocations-response-stream"
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
return ""


def get_api_version(*, version: str):
int_version = {
"v1": 1,
"v2": 2,
}

return int_version.get(version, 1)
4 changes: 2 additions & 2 deletions src/cohere/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,7 +2012,7 @@ def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.

model : typing.Optional[str]
The identifier of the model to use, eg `rerank-v3.5`.
The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`

top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
Expand Down Expand Up @@ -5047,7 +5047,7 @@ async def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.

model : typing.Optional[str]
The identifier of the model to use, eg `rerank-v3.5`.
The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`

top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
Expand Down
25 changes: 1 addition & 24 deletions src/cohere/bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from tokenizers import Tokenizer # type: ignore

from .aws_client import AwsClient, AwsClientV2
from .aws_client import AwsClient


class BedrockClient(AwsClient):
Expand All @@ -24,26 +24,3 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)

def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")

class BedrockClientV2(AwsClientV2):
def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
AwsClientV2.__init__(
self,
service="bedrock",
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
timeout=timeout,
)
2 changes: 1 addition & 1 deletion src/cohere/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]:
headers: typing.Dict[str, str] = {
"X-Fern-Language": "Python",
"X-Fern-SDK-Name": "cohere",
"X-Fern-SDK-Version": "5.13.2",
"X-Fern-SDK-Version": "5.12.0",
}
if self._client_name is not None:
headers["X-Client-Name"] = self._client_name
Expand Down
34 changes: 2 additions & 32 deletions src/cohere/sagemaker_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing

from .aws_client import AwsClient, AwsClientV2
from .aws_client import AwsClient
from .manually_maintained.cohere_aws.client import Client
from .manually_maintained.cohere_aws.mode import Mode

Expand All @@ -26,34 +26,4 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
try:
self.sagemaker_finetuning = Client(aws_region=aws_region)
except Exception:
pass


class SagemakerClientV2(AwsClientV2):
sagemaker_finetuning: Client

def __init__(
self,
*,
aws_access_key: typing.Optional[str] = None,
aws_secret_key: typing.Optional[str] = None,
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
):
AwsClientV2.__init__(
self,
service="sagemaker",
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
timeout=timeout,
)
try:
self.sagemaker_finetuning = Client(aws_region=aws_region)
except Exception:
pass
self.sagemaker_finetuning = Client(aws_region=aws_region)
Loading