Skip to content

Commit

Permalink
Add rerank support for sagemaker (#526)
Browse files Browse the repository at this point in the history
* Add rerank to sagemaker cli

* Restore

* Restore

* Fix

* Add skip
  • Loading branch information
billytrend-cohere authored Jun 11, 2024
1 parent 01b3c22 commit 8f5ff50
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tokenizers import Tokenizer # type: ignore

from . import GenerateStreamedResponse, Generation, \
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse
from .client import Client, ClientEnvironment
from .core import construct_type

Expand Down Expand Up @@ -97,6 +97,7 @@ def __iter__(self) -> typing.Iterator[bytes]:
"chat": NonStreamedChatResponse,
"embed": EmbedResponse,
"generate": Generation,
"rerank": RerankResponse
}

stream_response_mapping: typing.Dict[str, typing.Any] = {
Expand Down
29 changes: 26 additions & 3 deletions tests/test_aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
embed_job = os.path.join(package_dir, 'embed_job.jsonl')


models = {
model_mapping = {
"bedrock": {
"chat_model": "cohere.command-r-plus-v1:0",
"embed_model": "cohere.embed-multilingual-v3",
Expand All @@ -19,37 +19,60 @@
"chat_model": "cohere.command-r-plus-v1:0",
"embed_model": "cohere.embed-multilingual-v3",
"generate_model": "cohere-command-light",
"rerank_model": "rerank",
},
}


@parameterized_class([
{
"platform": "bedrock",
"client": cohere.BedrockClient(
timeout=10000,
aws_region="us-east-1",
aws_access_key="...",
aws_secret_key="...",
aws_session_token="...",
),
"models": models["bedrock"],
"models": model_mapping["bedrock"],
},
{
"platform": "sagemaker",
"client": cohere.SagemakerClient(
timeout=10000,
aws_region="us-east-1",
aws_access_key="...",
aws_secret_key="...",
aws_session_token="...",
),
"models": models["sagemaker"],
"models": model_mapping["sagemaker"],
}
])
@unittest.skip("skip tests until they work in CI")
class TestClient(unittest.TestCase):
platform: str
client: cohere.AwsClient
models: typing.Dict[str, str]

def test_rerank(self) -> None:
    """Rerank four short passages and check how many results come back.

    The test is skipped on every platform other than "sagemaker". It asks
    the client for the top three matches for a single query and asserts
    that exactly that many results are returned.
    """
    if self.platform != "sagemaker":
        self.skipTest("Only sagemaker supports rerank")

    # Keep the requested count in one place so the request and the
    # assertion cannot drift apart.
    top_n = 3
    query = 'What is the capital of the United States?'
    candidate_passages = [
        'Carson City is the capital city of the American state of Nevada.',
        'The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.',
        'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.',
        'Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.',
    ]

    reranked = self.client.rerank(
        model=self.models["rerank_model"],
        query=query,
        documents=candidate_passages,
        top_n=top_n,
    )

    self.assertEqual(len(reranked.results), top_n)

def test_embed(self) -> None:
response = self.client.embed(
model=self.models["embed_model"],
Expand Down

0 comments on commit 8f5ff50

Please sign in to comment.