Skip to content

Commit

Permalink
Throw if stream=True (#404)
Browse files Browse the repository at this point in the history
* Throw if stream=True

* Fix types
  • Loading branch information
billytrend-cohere authored Mar 18, 2024
1 parent d5546a7 commit 045ceba
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,23 @@
from .environment import CohereEnvironment


def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[[typing.Any], typing.Any]) -> None:
method = getattr(obj, method_name)

def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
check_fn(*args, **kwargs)
return method(*args, **kwargs)

setattr(obj, method_name, wrapped)


def throw_if_stream_is_true(*args, **kwargs) -> None:
if kwargs.get("stream") is True:
raise ValueError(
"Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
)


class Client(BaseCohere):
def __init__(
self,
Expand All @@ -27,6 +44,8 @@ def __init__(
httpx_client=httpx_client,
)

validate_args(self, "chat", throw_if_stream_is_true)


class AsyncClient(AsyncBaseCohere):
def __init__(
Expand All @@ -48,3 +67,5 @@ def __init__(
timeout=timeout,
httpx_client=httpx_client,
)

validate_args(self, "chat", throw_if_stream_is_true)
7 changes: 7 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def test_chat(self) -> None:

print(chat)

def test_stream_equals_true(self) -> None:
with self.assertRaises(ValueError):
co.chat(
stream=True, # type: ignore
message="What year was he born?",
)

def test_generate(self) -> None:
response = co.generate(
prompt='Please explain to me how LLMs work',
Expand Down

0 comments on commit 045ceba

Please sign in to comment.