From 99fd51661c81e60c3e98b3bb2693f78f4fb7365e Mon Sep 17 00:00:00 2001
From: Merlin Kallenborn
Date: Fri, 28 Jun 2024 09:30:22 +0200
Subject: [PATCH] feat: Client exposes HTTPAdapter pool size

TASK: PHS-622 (Studio Team)
---
 aleph_alpha_client/aleph_alpha_client.py |  7 +++++-
 aleph_alpha_client/completion.py         |  8 ++++---
 aleph_alpha_client/explanation.py        | 30 ++++++++++++++----------
 3 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index dd0520d..ce50cb5 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -167,6 +167,7 @@ def __init__(
         nice: bool = False,
         verify_ssl=True,
         tags: Optional[Sequence[str]] = None,
+        pool_size: int = 10,
     ) -> None:
         if host[-1] != "/":
             host += "/"
@@ -184,7 +185,11 @@ def __init__(
             allowed_methods=["POST", "GET"],
             raise_on_status=False,
         )
-        adapter = HTTPAdapter(max_retries=retry_strategy)
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy,
+            pool_connections=pool_size,
+            pool_maxsize=pool_size,
+        )
         self.session = requests.Session()
         self.session.verify = verify_ssl
         self.session.headers = CaseInsensitiveDict(
diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py
index 128d003..e373fc1 100644
--- a/aleph_alpha_client/completion.py
+++ b/aleph_alpha_client/completion.py
@@ -284,9 +284,11 @@ def from_json(json: Dict[str, Any]) -> "CompletionResponse":
             ],
             num_tokens_prompt_total=json["num_tokens_prompt_total"],
             num_tokens_generated=json["num_tokens_generated"],
-            optimized_prompt=Prompt.from_json(optimized_prompt_json)
-            if optimized_prompt_json
-            else None,
+            optimized_prompt=(
+                Prompt.from_json(optimized_prompt_json)
+                if optimized_prompt_json
+                else None
+            ),
         )
 
     def to_json(self) -> Mapping[str, Any]:
diff --git a/aleph_alpha_client/explanation.py b/aleph_alpha_client/explanation.py
index 9c962c9..f7ae023 100644
--- a/aleph_alpha_client/explanation.py
+++ b/aleph_alpha_client/explanation.py
@@ -172,9 +172,9 @@ class ExplanationRequest:
     control_factor: Optional[float] = None
     control_token_overlap: Optional[ControlTokenOverlap] = None
     control_log_additive: Optional[bool] = None
-    prompt_granularity: Optional[
-        Union[PromptGranularity, str, CustomGranularity]
-    ] = None
+    prompt_granularity: Optional[Union[PromptGranularity, str, CustomGranularity]] = (
+        None
+    )
     target_granularity: Optional[TargetGranularity] = None
     postprocessing: Optional[ExplanationPostprocessing] = None
     normalize: Optional[bool] = None
@@ -357,9 +357,11 @@ def from_json(item: Dict[str, Any]) -> "TextPromptItemExplanation":
     def with_text(self, prompt: Text) -> "TextPromptItemExplanation":
         return TextPromptItemExplanation(
             scores=[
-                TextScoreWithRaw.from_text_score(score, prompt)
-                if isinstance(score, TextScore)
-                else score
+                (
+                    TextScoreWithRaw.from_text_score(score, prompt)
+                    if isinstance(score, TextScore)
+                    else score
+                )
                 for score in self.scores
             ]
         )
@@ -386,9 +388,11 @@ def from_json(item: Dict[str, Any]) -> "TargetPromptItemExplanation":
     def with_text(self, prompt: str) -> "TargetPromptItemExplanation":
         return TargetPromptItemExplanation(
             scores=[
-                TargetScoreWithRaw.from_target_score(score, prompt)
-                if isinstance(score, TargetScore)
-                else score
+                (
+                    TargetScoreWithRaw.from_target_score(score, prompt)
+                    if isinstance(score, TargetScore)
+                    else score
+                )
                 for score in self.scores
             ]
         )
@@ -461,9 +465,11 @@ def with_image_prompt_items_in_pixels(self, prompt: Prompt) -> "Explanation":
         return Explanation(
             target=self.target,
             items=[
-                item.in_pixels(prompt.items[item_index])
-                if isinstance(item, ImagePromptItemExplanation)
-                else item
+                (
+                    item.in_pixels(prompt.items[item_index])
+                    if isinstance(item, ImagePromptItemExplanation)
+                    else item
+                )
                 for item_index, item in enumerate(self.items)
             ],
         )
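
Not part of the patch itself: a minimal usage sketch of the new pool_size parameter on the
synchronous Client. The token, model name, prompt text, and worker count are placeholders,
and the thread-pool setup is only an assumption about how a caller might share one client
across threads; it is not prescribed by this change.

from concurrent.futures import ThreadPoolExecutor

from aleph_alpha_client import Client, CompletionRequest, Prompt

# pool_size is forwarded to the underlying requests HTTPAdapter as both
# pool_connections and pool_maxsize, so it should be at least as large as the
# number of threads issuing requests through this client at the same time.
client = Client(
    token="AA_TOKEN",  # placeholder API token
    pool_size=32,
)


def complete(text: str) -> str:
    # Placeholder request and model name, just to exercise the shared session.
    request = CompletionRequest(prompt=Prompt.from_text(text), maximum_tokens=20)
    response = client.complete(request, model="luminous-base")
    return response.completions[0].completion or ""


# 32 worker threads can reuse the shared session without exhausting the
# urllib3 connection pool, because the pool was sized to match.
with ThreadPoolExecutor(max_workers=32) as executor:
    results = list(executor.map(complete, ["An apple a day"] * 8))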