diff --git a/tests/test_complete.py b/tests/test_complete.py index acc2acd..863a9fe 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -82,23 +82,31 @@ def test_complete_with_token_ids(sync_client: Client, model_name: str): assert response.model_version is not None -# Re-add as system test and adjust once new behavior is in production -# @pytest.mark.system_test +@pytest.mark.system_test def test_complete_with_optimized_prompt( sync_client: Client, model_name: str, prompt_image: Image ): - prompt_text = " Hello World! " prompt_tokens = Tokens.from_token_ids([1, 2]) request = CompletionRequest( - prompt=Prompt([Text.from_text(prompt_text), prompt_image, prompt_tokens]), + prompt=Prompt( + [ + Text.from_text(" Hello "), + Text.from_text(" world! "), + prompt_image, + prompt_tokens, + Text.from_text(" The "), + Text.from_text(" end "), + ] + ), maximum_tokens=5, ) response = sync_client.complete(request, model=model_name) assert response.optimized_prompt is not None - assert response.optimized_prompt.items[0] == Text.from_text(prompt_text.strip()) + assert response.optimized_prompt.items[0] == Text.from_text("Hello world! ") assert response.optimized_prompt.items[2] == prompt_tokens + assert response.optimized_prompt.items[3] == Text.from_text(" The end") assert isinstance(response.optimized_prompt.items[1], Image)