diff --git a/cohere/compass/clients/compass.py b/cohere/compass/clients/compass.py index 3331b4c..e5cf54a 100644 --- a/cohere/compass/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -65,22 +65,9 @@ class RetryResult: error: Optional[str] = None -_DEFAULT_TIMEOUT = 30 - - logger = logging.getLogger(__name__) -class SessionWithDefaultTimeout(requests.Session): - def __init__(self, timeout: int): - self._timeout = timeout - super().__init__() - - def request(self, *args: Any, **kwargs: Any): - kwargs.setdefault("timeout", self._timeout) - return super().request(*args, **kwargs) - - class CompassClient: def __init__( self, @@ -89,7 +76,7 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, bearer_token: Optional[str] = None, - default_timeout: int = _DEFAULT_TIMEOUT, + http_session: Optional[requests.Session] = None, ): """ A compass client to interact with the Compass API @@ -100,7 +87,7 @@ def __init__( self.index_url = index_url self.username = username or os.getenv("COHERE_COMPASS_USERNAME") self.password = password or os.getenv("COHERE_COMPASS_PASSWORD") - self.session = SessionWithDefaultTimeout(default_timeout) + self.session = http_session or requests.Session() self.bearer_token = bearer_token self.api_method = {