From c039e1d8a093405929be22e29416a10a12a037be Mon Sep 17 00:00:00 2001 From: pabloiyu Date: Thu, 19 Dec 2024 13:19:41 +0000 Subject: [PATCH] Add test to check steering functionality --- aleph_alpha_client/completion.py | 4 ++-- tests/test_complete.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py index a8e31ca..701d088 100644 --- a/aleph_alpha_client/completion.py +++ b/aleph_alpha_client/completion.py @@ -177,7 +177,7 @@ class CompletionRequest: return the optimized completion in the completion field of the CompletionResponse. The raw completion, if returned, will contain the un-optimized completion. - steering_concepts_to_apply (Optional[list[str]], default None) + steering_concepts (Optional[list[str]], default None) Names of the steering vectors to apply on this task. This steers the output in the direction given by positive examples, and away from negative examples if provided. @@ -219,7 +219,7 @@ class CompletionRequest: control_log_additive: Optional[bool] = True repetition_penalties_include_completion: bool = True raw_completion: bool = False - steering_concepts_to_apply: Optional[List[str]] = None + steering_concepts: Optional[List[str]] = None def to_json(self) -> Mapping[str, Any]: payload = {k: v for k, v in self._asdict().items() if v is not None} diff --git a/tests/test_complete.py b/tests/test_complete.py index cc5eca3..7fc75f8 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -188,3 +188,16 @@ def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str) number_tokens_completion = len(completion_result.completion_tokens) assert response.num_tokens_generated == best_of * number_tokens_completion + + +@pytest.mark.system_test +def test_steering_completion(sync_client: Client, chat_model_name: str): + request = CompletionRequest( + prompt=Prompt.from_text("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"), + steering_concepts=["shakespeare"], + maximum_tokens=16, + ) + + response = sync_client.complete(request, model=chat_model_name) + completion_result = response.completions[0] + assert ("art" in completion_result.completion), "Steered completion should contain Shakespearean language like 'art' for this particular phrase."