Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Sep 19, 2024
1 parent 914e296 commit 4cc09ff
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 45 deletions.
8 changes: 0 additions & 8 deletions src/cohere/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
from .detokenize_response import DetokenizeResponse
from .document import Document
from .document_content import DocumentContent
from .document_source import DocumentSource
from .embed_by_type_response import EmbedByTypeResponse
from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings
from .embed_floats_response import EmbedFloatsResponse
Expand Down Expand Up @@ -206,12 +205,8 @@
from .summarize_request_format import SummarizeRequestFormat
from .summarize_request_length import SummarizeRequestLength
from .summarize_response import SummarizeResponse
from .system_message import SystemMessage
from .system_message_content import SystemMessageContent
from .system_message_content_item import SystemMessageContentItem, TextSystemMessageContentItem
from .text_content import TextContent
from .text_response_format import TextResponseFormat
from .text_response_format_v2 import TextResponseFormatV2
from .texts import Texts
from .texts_truncate import TextsTruncate
from .tokenize_response import TokenizeResponse
Expand All @@ -222,20 +217,17 @@
from .tool_call_v2 import ToolCallV2
from .tool_call_v2function import ToolCallV2Function
from .tool_content import DocumentToolContent, TextToolContent, ToolContent
from .tool_message import ToolMessage
from .tool_message_v2 import ToolMessageV2
from .tool_message_v2tool_content import ToolMessageV2ToolContent
from .tool_parameter_definitions_value import ToolParameterDefinitionsValue
from .tool_result import ToolResult
from .tool_source import ToolSource
from .tool_v2 import ToolV2
from .tool_v2function import ToolV2Function
from .unprocessable_entity_error_body import UnprocessableEntityErrorBody
from .update_connector_response import UpdateConnectorResponse
from .usage import Usage
from .usage_billed_units import UsageBilledUnits
from .usage_tokens import UsageTokens
from .user_message import UserMessage
from .user_message_content import UserMessageContent

__all__ = [
Expand Down
10 changes: 5 additions & 5 deletions src/cohere/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests
from fastavro import parse_schema, reader, writer

from . import EmbedResponse, EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType, ApiMeta, \
from . import EmbedResponse, EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse, ApiMeta, \
EmbedByTypeResponseEmbeddings, ApiMetaBilledUnits, EmbedJob, CreateEmbedJobResponse, Dataset
from .datasets import DatasetsCreateResponse, DatasetsGetResponse
from .overrides import get_fields
Expand Down Expand Up @@ -194,23 +194,23 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons
]

if responses[0].response_type == "embeddings_floats":
embeddings_floats = typing.cast(typing.List[EmbedResponse_EmbeddingsFloats], responses)
embeddings_floats = typing.cast(typing.List[EmbeddingsFloatsEmbedResponse], responses)

embeddings = [
embedding
for embeddings_floats in embeddings_floats
for embedding in embeddings_floats.embeddings
]

return EmbedResponse_EmbeddingsFloats(
return EmbeddingsFloatsEmbedResponse(
response_type="embeddings_floats",
id=response_id,
texts=texts,
embeddings=embeddings,
meta=meta
)
else:
embeddings_type = typing.cast(typing.List[EmbedResponse_EmbeddingsByType], responses)
embeddings_type = typing.cast(typing.List[EmbeddingsByTypeEmbedResponse], responses)

embeddings_by_type = [
response.embeddings
Expand All @@ -231,7 +231,7 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons

embeddings_by_type_merged = EmbedByTypeResponseEmbeddings.parse_obj(merged_dicts)

return EmbedResponse_EmbeddingsByType(
return EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id=response_id,
embeddings=embeddings_by_type_merged,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import cohere
from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \
ToolParameterDefinitionsValue, ToolResult, Message_User, Message_Chatbot
ToolParameterDefinitionsValue, ToolResult, UserMessage, ChatbotMessage

package_dir = os.path.dirname(os.path.abspath(__file__))
embed_job = os.path.join(package_dir, 'embed_job.jsonl')
Expand All @@ -26,9 +26,9 @@ async def test_context_manager(self) -> None:
async def test_chat(self) -> None:
chat = await self.co.chat(
chat_history=[
Message_User(
UserMessage(
message="Who discovered gravity?"),
Message_Chatbot(message="The man who is widely credited with discovering "
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
Expand All @@ -40,9 +40,9 @@ async def test_chat(self) -> None:
async def test_chat_stream(self) -> None:
stream = self.co.chat_stream(
chat_history=[
Message_User(
UserMessage(
message="Who discovered gravity?"),
Message_Chatbot(message="The man who is widely credited with discovering "
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import cohere
from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \
ToolParameterDefinitionsValue, ToolResult, Message_Chatbot, Message_User, ResponseFormat_JsonObject
ToolParameterDefinitionsValue, ToolResult, ChatbotMessage, UserMessage, JsonObjectResponseFormat

co = cohere.Client(timeout=10000)

Expand All @@ -25,9 +25,9 @@ def test_context_manager(self) -> None:
def test_chat(self) -> None:
chat = co.chat(
chat_history=[
Message_User(
UserMessage(
message="Who discovered gravity?"),
Message_Chatbot(message="The man who is widely credited with discovering "
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
Expand All @@ -40,7 +40,7 @@ def test_chat(self) -> None:
def test_response_format(self) -> None:
chat = co.chat(
message="imagine a character from the tv show severance",
response_format=ResponseFormat_JsonObject(
response_format=JsonObjectResponseFormat(
schema={
"type": "object",
"properties": {
Expand All @@ -61,9 +61,9 @@ def test_response_format(self) -> None:
def test_chat_stream(self) -> None:
stream = co.chat_stream(
chat_history=[
Message_User(
UserMessage(
message="Who discovered gravity?"),
Message_Chatbot(message="The man who is widely credited with discovering "
ChatbotMessage(message="The man who is widely credited with discovering "
"gravity is Sir Isaac Newton")
],
message="What year was he born?",
Expand Down
20 changes: 10 additions & 10 deletions tests/test_client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest

import cohere
from cohere import ToolMessage2, UserMessage, AssistantMessage
from cohere import ToolMessage, UserMessage, AssistantMessage

co = cohere.ClientV2(timeout=10000)

Expand All @@ -14,12 +14,12 @@
class TestClientV2(unittest.TestCase):

def test_chat(self) -> None:
response = co.chat(model="command-r-plus", messages=[cohere.v2.ChatMessage2_User(content="hello world!")])
response = co.chat(model="command-r-plus", messages=[cohere.UserMessage(message="hello world!")])

print(response.message)

def test_chat_stream(self) -> None:
stream = co.chat_stream(model="command-r-plus", messages=[cohere.v2.ChatMessage2_User(content="hello world!")])
stream = co.chat_stream(model="command-r-plus", messages=[cohere.UserMessage(message="hello world!")])

events = set()

Expand All @@ -43,8 +43,8 @@ def test_chat_documents(self) -> None:
{"title": "widget sales 2021", "text": "4 million"},
]
response = co.chat(
messages=cohere.v2.UserMessage(
content=cohere.v2.TextContent(text="how many widges were sold in 2020?"),
messages=cohere.UserChatMessageV2(
content=cohere.TextContent(text="how many widges were sold in 2020?"),
documents=documents,
),
)
Expand All @@ -67,17 +67,17 @@ def test_chat_tools(self) -> None:
"required": ["location"],
},
}
tools = [cohere.v2.Tool2(type="function", function=get_weather_tool)]
messages: typing.List[typing.Union[UserMessage, AssistantMessage, None, ToolMessage2]] = [
cohere.v2.UserMessage(content="what is the weather in Toronto?")
tools = [cohere.ToolV2(type="function", function=get_weather_tool)]
messages: cohere.ChatMessages = [
cohere.UserChatMessageV2(content="what is the weather in Toronto?")
]
res = co.chat(model="command-r-plus", tools=tools, messages=messages)

# call the get_weather tool
tool_result = {"temperature": "30C"}
tool_content = [cohere.v2.Content(output=tool_result, text="The weather in Toronto is 30C")]
tool_content = [cohere.Content(output=tool_result, text="The weather in Toronto is 30C")]
messages.append(res.message)
messages.append(cohere.v2.ToolMessage2(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content))
messages.append(cohere.ToolChatMessageV2(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content))

res = co.chat(tools=tools, messages=messages)
print(res.message)
22 changes: 11 additions & 11 deletions tests/test_embed_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import unittest

from cohere import EmbedResponse_EmbeddingsByType, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \
ApiMetaApiVersion, EmbedResponse_EmbeddingsFloats
from cohere import EmbeddingsByTypeEmbedResponse, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \
ApiMetaApiVersion, EmbeddingsFloatsEmbedResponse
from cohere.utils import merge_embed_responses

ebt_1 = EmbedResponse_EmbeddingsByType(
ebt_1 = EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id="1",
embeddings=EmbedByTypeResponseEmbeddings(
Expand All @@ -27,7 +27,7 @@
)
)

ebt_2 = EmbedResponse_EmbeddingsByType(
ebt_2 = EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id="2",
embeddings=EmbedByTypeResponseEmbeddings(
Expand All @@ -50,7 +50,7 @@
)
)

ebt_partial_1 = EmbedResponse_EmbeddingsByType(
ebt_partial_1 = EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id="1",
embeddings=EmbedByTypeResponseEmbeddings(
Expand All @@ -71,7 +71,7 @@
)
)

ebt_partial_2 = EmbedResponse_EmbeddingsByType(
ebt_partial_2 = EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id="2",
embeddings=EmbedByTypeResponseEmbeddings(
Expand All @@ -92,7 +92,7 @@
)
)

ebf_1 = EmbedResponse_EmbeddingsFloats(
ebf_1 = EmbeddingsFloatsEmbedResponse(
response_type="embeddings_floats",
id="1",
texts=["hello", "goodbye"],
Expand All @@ -109,7 +109,7 @@
)
)

ebf_2 = EmbedResponse_EmbeddingsFloats(
ebf_2 = EmbeddingsFloatsEmbedResponse(
response_type="embeddings_floats",
id="2",
texts=["bye", "seeya"],
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_merge_embeddings_by_type(self) -> None:
raise Exception("this is just for mpy")

self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
self.assertEqual(resp, EmbedResponse_EmbeddingsByType(
self.assertEqual(resp, EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id="1, 2",
embeddings=EmbedByTypeResponseEmbeddings(
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_merge_embeddings_floats(self) -> None:
raise Exception("this is just for mpy")

self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
self.assertEqual(resp, EmbedResponse_EmbeddingsFloats(
self.assertEqual(resp, EmbeddingsFloatsEmbedResponse(
response_type="embeddings_floats",
id="1, 2",
texts=["hello", "goodbye", "bye", "seeya"],
Expand All @@ -199,7 +199,7 @@ def test_merge_partial_embeddings_floats(self) -> None:
raise Exception("this is just for mpy")

self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"})
self.assertEqual(resp, EmbedResponse_EmbeddingsByType(
self.assertEqual(resp, EmbeddingsByTypeEmbedResponse(
response_type="embeddings_by_type",
id="1, 2",
embeddings=EmbedByTypeResponseEmbeddings(
Expand Down

0 comments on commit 4cc09ff

Please sign in to comment.