From 8126e73b425bfd332f984f0244ff552b535bbf28 Mon Sep 17 00:00:00 2001 From: Abel Essiane <167091830+coessiane@users.noreply.github.com> Date: Tue, 10 Dec 2024 12:23:29 -0500 Subject: [PATCH] fix: Allows clients to pass in an HTTP Session (#59) This pull request introduces changes to the `__init__` method of the `compass.py` module, which is responsible for initializing a compass client to interact with the Compass API. The modifications primarily involve updating the handling of the `default_timeout` and introducing a new `http_session` parameter. ## Changes: - The `default_timeout` parameter's type hint has been updated from `int` to `int | None`, indicating that it can now accept `None` as a valid value. - A new parameter, `http_session`, has been added to the `__init__` method, allowing for the specification of an HTTP session with a custom timeout. - The instantiation of the `SessionWithDefaultTimeout` class has been replaced with a conditional assignment, setting `self.session` to either the provided `http_session` or a new instance of `requests.Session()`. - A warning message is logged if `default_timeout` is not `None`, indicating that the variable is deprecated and will not have any effect. This message also provides guidance on using the `http_session` parameter for specifying HTTP request timeouts. --- cohere/compass/clients/compass.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/cohere/compass/clients/compass.py b/cohere/compass/clients/compass.py index 3331b4c..e5cf54a 100644 --- a/cohere/compass/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -65,22 +65,9 @@ class RetryResult: error: Optional[str] = None -_DEFAULT_TIMEOUT = 30 - - logger = logging.getLogger(__name__) -class SessionWithDefaultTimeout(requests.Session): - def __init__(self, timeout: int): - self._timeout = timeout - super().__init__() - - def request(self, *args: Any, **kwargs: Any): - kwargs.setdefault("timeout", self._timeout) - return super().request(*args, **kwargs) - - class CompassClient: def __init__( self, @@ -89,7 +76,7 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, bearer_token: Optional[str] = None, - default_timeout: int = _DEFAULT_TIMEOUT, + http_session: Optional[requests.Session] = None, ): """ A compass client to interact with the Compass API @@ -100,7 +87,7 @@ def __init__( self.index_url = index_url self.username = username or os.getenv("COHERE_COMPASS_USERNAME") self.password = password or os.getenv("COHERE_COMPASS_PASSWORD") - self.session = SessionWithDefaultTimeout(default_timeout) + self.session = http_session or requests.Session() self.bearer_token = bearer_token self.api_method = {