
Commit

Revert back to b1dbf95 due to next commit breaking staging
andrewbcohere committed Dec 5, 2024
1 parent 61f6f71 commit acddc01
Showing 10 changed files with 58 additions and 78 deletions.
11 changes: 1 addition & 10 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "5.13.2"
version = "5.12.0"
description = ""
readme = "README.md"
authors = []
@@ -36,15 +36,6 @@ boto3 = { version="^1.34.0", optional = true}
fastavro = "^1.9.4"
httpx = ">=0.21.2"
httpx-sse = "0.4.0"
# Without specifying a version, the numpy and pandas will be 1.24.4 and 2.0.3 respectively
numpy = [
{ version="~1.24.4", python = "<3.12", optional = true },
{ version="~1.26", python = ">=3.12", optional = true }
]
pandas = [
{ version="~2.0.3", python = "<3.13", optional = true },
{ version="~2.2.3", python = ">=3.13", optional = true }
]
parameterized = "^0.9.0"
pydantic = ">= 1.9.2"
pydantic-core = "^2.18.2"
16 changes: 13 additions & 3 deletions reference.md
@@ -1793,7 +1793,7 @@ We recommend a maximum of 1,000 documents for optimal endpoint performance.
<dl>
<dd>

**model:** `typing.Optional[str]` — The identifier of the model to use, eg `rerank-v3.5`.
**model:** `typing.Optional[str]` — The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`

</dd>
</dl>
@@ -2382,6 +2382,7 @@ response = client.v2.chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
stream=True,
)
for chunk in response:
yield chunk
@@ -2654,6 +2655,7 @@ client.v2.chat(
content="messages",
)
],
stream=False,
)

```
@@ -2971,7 +2973,7 @@ Available models and corresponding embedding dimensions:

**embedding_types:** `typing.Sequence[EmbeddingType]`

Specifies the types of embeddings you want to get back. Can be one or more of the following types.
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.

* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -3084,7 +3086,15 @@ client.v2.rerank(
<dl>
<dd>

**model:** `str` — The identifier of the model to use, eg `rerank-v3.5`.
**model:** `str`

The identifier of the model to use.

Supported models:
- `rerank-english-v3.0`
- `rerank-multilingual-v3.0`
- `rerank-english-v2.0`
- `rerank-multilingual-v2.0`

</dd>
</dl>
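
To make the restored reference entry concrete, here is a minimal sketch of a v2 rerank call using one of the listed models. The query, documents, and `top_n` value are illustrative, and `client` is assumed to be a `cohere.Client` constructed with a valid API key, as in the other reference snippets.

```python
import cohere

client = cohere.Client(api_key="YOUR_API_KEY")  # hypothetical key

response = client.v2.rerank(
    model="rerank-english-v3.0",  # one of the supported models listed above
    query="What is the capital of the United States?",
    documents=[
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
    top_n=1,
)
print(response.results[0].index)  # index of the best-matching document
```
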
6 changes: 2 additions & 4 deletions src/cohere/__init__.py
@@ -245,7 +245,7 @@
)
from . import connectors, datasets, embed_jobs, finetuning, models, v2
from .aws_client import AwsClient
from .bedrock_client import BedrockClient, BedrockClientV2
from .bedrock_client import BedrockClient
from .client import AsyncClient, Client
from .client_v2 import AsyncClientV2, ClientV2
from .datasets import (
@@ -257,7 +257,7 @@
)
from .embed_jobs import CreateEmbedJobRequestTruncate
from .environment import ClientEnvironment
from .sagemaker_client import SagemakerClient, SagemakerClientV2
from .sagemaker_client import SagemakerClient
from .v2 import (
V2ChatRequestDocumentsItem,
V2ChatRequestSafetyMode,
@@ -287,7 +287,6 @@
"AwsClient",
"BadRequestError",
"BedrockClient",
"BedrockClientV2",
"ChatCitation",
"ChatCitationGenerationEvent",
"ChatConnector",
@@ -461,7 +460,6 @@
"ResponseFormat",
"ResponseFormatV2",
"SagemakerClient",
"SagemakerClientV2",
"SearchQueriesGenerationStreamedChatResponse",
"SearchResultsStreamedChatResponse",
"ServiceUnavailableError",
42 changes: 17 additions & 25 deletions src/cohere/aws_client.py
@@ -56,7 +56,7 @@ def __init__(
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
ClientV2.__init__(
Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
environment=ClientEnvironment.PRODUCTION,
@@ -196,13 +196,6 @@ def _hook(

return _hook

def get_boto3_session(
**kwargs: typing.Any,
):
non_none_args = {k: v for k, v in kwargs.items() if v is not None}
return lazy_boto3().Session(**non_none_args)



def map_request_to_bedrock(
service: str,
@@ -211,22 +204,19 @@
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
) -> EventHook:
session = get_boto3_session(
session = lazy_boto3().Session(
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
)
aws_region = session.region_name
credentials = session.get_credentials()
signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)
signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)

def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
del headers["connection"]


api_version = request.url.path.split("/")[-2]
endpoint = request.url.path.split("/")[-1]
body = json.loads(request.read())
model = body["model"]
@@ -240,9 +230,6 @@ def _event_hook(request: httpx.Request) -> None:
request.url = URL(url)
request.headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)

if "stream" in body:
del body["stream"]

@@ -268,6 +255,20 @@ def _event_hook(request: httpx.Request) -> None:
return _event_hook


def get_endpoint_from_url(url: str,
chat_model: typing.Optional[str] = None,
embed_model: typing.Optional[str] = None,
generate_model: typing.Optional[str] = None,
) -> str:
if chat_model and chat_model in url:
return "chat"
if embed_model and embed_model in url:
return "embed"
if generate_model and generate_model in url:
return "generate"
raise ValueError(f"Unknown endpoint in url: {url}")


def get_url(
*,
platform: str,
@@ -282,12 +283,3 @@ def get_url(
endpoint = "invocations" if not stream else "invocations-response-stream"
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
return ""


def get_api_version(*, version: str):
int_version = {
"v1": 1,
"v2": 2,
}

return int_version.get(version, 1)
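
As a quick illustration of the restored `get_endpoint_from_url` helper (based only on the body shown above): it matches a configured model name against the request URL and maps it to an endpoint name. The URL and model name below are made up.

```python
# Assumed import path, matching the file touched in this diff.
from cohere.aws_client import get_endpoint_from_url

url = "https://runtime.sagemaker.us-east-1.amazonaws.com/endpoints/my-embed-model/invocations"

# Returns "embed" because the configured embed model name appears in the URL;
# a URL that matches none of the configured models raises ValueError.
print(get_endpoint_from_url(url, embed_model="my-embed-model"))
```
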
4 changes: 2 additions & 2 deletions src/cohere/base_client.py
@@ -2012,7 +2012,7 @@ def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
The identifier of the model to use, eg `rerank-v3.5`.
The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
@@ -5047,7 +5047,7 @@ async def rerank(
We recommend a maximum of 1,000 documents for optimal endpoint performance.
model : typing.Optional[str]
The identifier of the model to use, eg `rerank-v3.5`.
The identifier of the model to use, one of : `rerank-english-v3.0`, `rerank-multilingual-v3.0`, `rerank-english-v2.0`, `rerank-multilingual-v2.0`
top_n : typing.Optional[int]
The number of most relevant documents or indices to return, defaults to the length of the documents
2 changes: 0 additions & 2 deletions src/cohere/bedrock_client.py
@@ -25,8 +25,6 @@ def __init__(
timeout=timeout,
)

def rerank(self, *, query, documents, model = ..., top_n = ..., rank_fields = ..., return_documents = ..., max_chunks_per_doc = ..., request_options = None):
raise NotImplementedError("Please use cohere.BedrockClientV2 instead: Rerank API on Bedrock is not supported with cohere.BedrockClient for this model.")

class BedrockClientV2(AwsClientV2):
def __init__(
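
For context, a minimal sketch of constructing the Bedrock client affected by the change above. The region and explicit credentials are illustrative assumptions; in the reverted `aws_client.py`, the values are passed through to `boto3.Session`, which can also fall back to its default credential chain.

```python
import cohere

# Hypothetical values; credentials may instead come from the standard AWS environment.
co = cohere.BedrockClient(
    aws_region="us-east-1",
    aws_access_key="...",     # illustrative
    aws_secret_key="...",     # illustrative
    aws_session_token="...",  # illustrative
)
```
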
2 changes: 1 addition & 1 deletion src/cohere/core/client_wrapper.py
@@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]:
headers: typing.Dict[str, str] = {
"X-Fern-Language": "Python",
"X-Fern-SDK-Name": "cohere",
"X-Fern-SDK-Version": "5.13.2",
"X-Fern-SDK-Version": "5.12.0",
}
if self._client_name is not None:
headers["X-Client-Name"] = self._client_name
10 changes: 2 additions & 8 deletions src/cohere/sagemaker_client.py
@@ -26,10 +26,7 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
try:
self.sagemaker_finetuning = Client(aws_region=aws_region)
except Exception:
pass
self.sagemaker_finetuning = Client(aws_region=aws_region)


class SagemakerClientV2(AwsClientV2):
@@ -53,7 +50,4 @@ def __init__(
aws_region=aws_region,
timeout=timeout,
)
try:
self.sagemaker_finetuning = Client(aws_region=aws_region)
except Exception:
pass
self.sagemaker_finetuning = Client(aws_region=aws_region)
27 changes: 20 additions & 7 deletions src/cohere/v2/client.py
@@ -222,6 +222,7 @@ def chat_stream(
p=1.1,
return_prompt=True,
logprobs=True,
stream=True,
)
for chunk in response:
yield chunk
@@ -534,6 +535,7 @@ def chat(
content="messages",
)
],
stream=False,
)
"""
_response = self._client_wrapper.httpx_client.request(
@@ -568,7 +570,6 @@ def chat(
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
"stream": False,
},
request_options=request_options,
omit=OMIT,
@@ -736,7 +737,7 @@ def embed(
input_type : EmbedInputType
embedding_types : typing.Sequence[EmbeddingType]
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -936,7 +937,13 @@ def rerank(
Parameters
----------
model : str
The identifier of the model to use, eg `rerank-v3.5`.
The identifier of the model to use.
Supported models:
- `rerank-english-v3.0`
- `rerank-multilingual-v3.0`
- `rerank-english-v2.0`
- `rerank-multilingual-v2.0`
query : str
The search query
@@ -1301,6 +1308,7 @@ async def main() -> None:
p=1.1,
return_prompt=True,
logprobs=True,
stream=True,
)
async for chunk in response:
yield chunk
@@ -1340,7 +1348,6 @@ async def main() -> None:
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
"stream": True,
},
request_options=request_options,
omit=OMIT,
@@ -1621,6 +1628,7 @@ async def main() -> None:
content="messages",
)
],
stream=False,
)
@@ -1658,7 +1666,6 @@ async def main() -> None:
"p": p,
"return_prompt": return_prompt,
"logprobs": logprobs,
"stream": False,
},
request_options=request_options,
omit=OMIT,
@@ -1826,7 +1833,7 @@ async def embed(
input_type : EmbedInputType
embedding_types : typing.Sequence[EmbeddingType]
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
Specifies the types of embeddings you want to get back. Not required and default is None, which returns the Embed Floats response type. Can be one or more of the following types.
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
@@ -2034,7 +2041,13 @@ async def rerank(
Parameters
----------
model : str
The identifier of the model to use, eg `rerank-v3.5`.
The identifier of the model to use.
Supported models:
- `rerank-english-v3.0`
- `rerank-multilingual-v3.0`
- `rerank-english-v2.0`
- `rerank-multilingual-v2.0`
query : str
The search query
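
To ground the `embedding_types` docstring above, here is a small usage sketch. The model name, texts, and input type are illustrative assumptions, and the client is constructed the same way as in the rerank sketch earlier.

```python
import cohere

client = cohere.Client(api_key="YOUR_API_KEY")  # hypothetical key

# Request two representations at once; per the docstring, "float" is valid for all
# models while "int8" is only valid for v3 embedding models.
response = client.v2.embed(
    model="embed-english-v3.0",  # assumed v3 embedding model
    texts=["hello", "goodbye"],
    input_type="classification",
    embedding_types=["float", "int8"],
)
print(response.embeddings)  # one set of vectors per requested embedding type
```
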
16 changes: 0 additions & 16 deletions tests/test_client_init.py

This file was deleted.
