diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index 946d7a958..6630a1b58 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -43,8 +43,8 @@ def test_chat_documents(self) -> None: {"title": "widget sales 2021", "text": "4 million"}, ] response = co.chat( - messages=cohere.UserMessage( - message=cohere.TextContent(text="how many widges were sold in 2020?"), + messages=cohere.UserV2ChatMessage( + content=cohere.TextContent(text="how many widges were sold in 2020?"), documents=documents, ), ) @@ -68,8 +68,8 @@ def test_chat_tools(self) -> None: }, } tools = [cohere.V2Tool(type="function", function=get_weather_tool)] - messages: typing.List[typing.Union[UserMessage, AssistantMessage, None, cohere.V2ToolMessage]] = [ - UserMessage(message="what is the weather in Toronto?") + messages: cohere.ChatMessages = [ + cohere.UserV2ChatMessage(content="what is the weather in Toronto?") ] res = co.chat(model="command-r-plus", tools=tools, messages=messages) @@ -77,7 +77,7 @@ def test_chat_tools(self) -> None: tool_result = {"temperature": "30C"} tool_content = [cohere.Content(output=tool_result, text="The weather in Toronto is 30C")] messages.append(res.message) - messages.append(cohere.V2ToolMessage(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content)) + messages.append(cohere.ToolV2ChatMessage(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content)) res = co.chat(tools=tools, messages=messages) print(res.message)