diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py index a8e31ca..95d5b40 100644 --- a/aleph_alpha_client/completion.py +++ b/aleph_alpha_client/completion.py @@ -177,8 +177,8 @@ class CompletionRequest: return the optimized completion in the completion field of the CompletionResponse. The raw completion, if returned, will contain the un-optimized completion. - steering_concepts_to_apply (Optional[list[str]], default None) - Names of the steering vectors to apply on this task. This steers the output in the + steering_concepts (Optional[list[str]], default None) + Names of the steering vectors to apply on this task. This steers the output in the direction given by positive examples, and away from negative examples if provided. Examples: @@ -219,7 +219,7 @@ class CompletionRequest: control_log_additive: Optional[bool] = True repetition_penalties_include_completion: bool = True raw_completion: bool = False - steering_concepts_to_apply: Optional[List[str]] = None + steering_concepts: Optional[List[str]] = None def to_json(self) -> Mapping[str, Any]: payload = {k: v for k, v in self._asdict().items() if v is not None} diff --git a/tests/test_complete.py b/tests/test_complete.py index cc5eca3..41db92a 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -188,3 +188,21 @@ def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str) number_tokens_completion = len(completion_result.completion_tokens) assert response.num_tokens_generated == best_of * number_tokens_completion + + +@pytest.mark.system_test +def test_steering_completion(sync_client: Client, chat_model_name: str): + request = CompletionRequest( + prompt=Prompt.from_text( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ), + steering_concepts=["shakespeare"], + maximum_tokens=16, + ) + + response = sync_client.complete(request, model=chat_model_name) + completion_result = response.completions[0] + assert completion_result.completion is not None + assert ( + "art" in completion_result.completion + ), "Steered completion should contain Shakespearean language like 'art' for this particular phrase."