diff --git a/src/cohere/client.py b/src/cohere/client.py index d503c5610..e7603a293 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -6,6 +6,23 @@ from .environment import CohereEnvironment +def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None: + method = getattr(obj, method_name) + + def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + check_fn(*args, **kwargs) + return method(*args, **kwargs) + + setattr(obj, method_name, wrapped) + + +def throw_if_stream_is_true(*args, **kwargs) -> None: + if kwargs.get("stream") is True: + raise ValueError( + "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)" + ) + + class Client(BaseCohere): def __init__( self, @@ -27,6 +44,8 @@ def __init__( httpx_client=httpx_client, ) + validate_args(self, "chat", throw_if_stream_is_true) + class AsyncClient(AsyncBaseCohere): def __init__( @@ -48,3 +67,5 @@ def __init__( timeout=timeout, httpx_client=httpx_client, ) + + validate_args(self, "chat", throw_if_stream_is_true) \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py index 58b3d5b66..8baff6af2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -28,6 +28,13 @@ def test_chat(self) -> None: print(chat) + def test_stream_equals_true(self) -> None: + with self.assertRaises(ValueError): + co.chat( + stream=True, # type: ignore + message="What year was he born?", + ) + def test_generate(self) -> None: response = co.generate( prompt='Please explain to me how LLMs work',