diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c2037e5..2987fa35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ### Features - Add support for Llama3InstructModel in PromptBasedClassify +- Add TextControl to 'to_instruct_prompt' for instruct models ### Fixes ... diff --git a/src/intelligence_layer/core/model.py b/src/intelligence_layer/core/model.py index a4e095e9..63d6626b 100644 --- a/src/intelligence_layer/core/model.py +++ b/src/intelligence_layer/core/model.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from copy import deepcopy +from dataclasses import replace from functools import lru_cache from typing import Any, ClassVar, Literal, Optional @@ -13,6 +14,7 @@ ExplanationResponse, Prompt, Text, + TextControl, Tokens, ) from pydantic import BaseModel, ConfigDict @@ -22,7 +24,11 @@ AlephAlphaClientProtocol, LimitedConcurrencyClient, ) -from intelligence_layer.core.prompt_template import PromptTemplate, RichPrompt +from intelligence_layer.core.prompt_template import ( + PromptTemplate, + RichPrompt, + TextCursor, +) from intelligence_layer.core.task import Task, Token from intelligence_layer.core.tracer.tracer import TaskSpan, Tracer @@ -408,21 +414,76 @@ def to_instruct_prompt( instruction: str, input: Optional[str] = None, response_prefix: Optional[str] = None, + instruction_controls: Optional[Sequence[TextControl]] = None, + input_controls: Optional[Sequence[TextControl]] = None, ) -> RichPrompt: """Method to create an instruct-`RichPrompt` object to use with any `ControlModel`. - Allows the implementation of a custom prompt format for the specific model in use. - Args: instruction: The task the model should fulfill, for example summarization input: Any context necessary to solve the task, such as the text to be summarize response_prefix: Optional argument to append a string to the beginning of the final agent message to steer the generation + input_controls: TextControls for the input part of the prompt. Only for text prompts + instruction_controls: TextControls for the instruction part of the prompt. Only for text prompts """ - return self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt( + rich_prompt = self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt( instruction=instruction, input=input, response_prefix=response_prefix ) + prompt = rich_prompt.items[0] + ranges = rich_prompt.ranges + + if not isinstance(prompt, Text): + raise ValueError("Text control only valid for text prompts.") + + text_controls: list[TextControl] = [] + + if instruction_controls: + shifted_instruction_controls = self._shift_text_control_ranges( + instruction, instruction_controls, rich_prompt, "instruction" + ) + for shifted_input_control_range in shifted_instruction_controls: + text_controls.append(shifted_input_control_range) + + if input_controls and input: + shifted_input_control_ranges = self._shift_text_control_ranges( + input, input_controls, rich_prompt, "input" + ) + for shifted_input_control_range in shifted_input_control_ranges: + text_controls.append(shifted_input_control_range) + + prompt_with_controls = Prompt.from_text(prompt.text, text_controls) + return RichPrompt(items=prompt_with_controls.items, ranges=ranges) + + def _shift_text_control_ranges( + self, + input: str, + text_controls: Sequence[TextControl], + rich_prompt: RichPrompt, + control_type: str, + ) -> Sequence[TextControl]: + input_start = self._get_text_control_start_index(rich_prompt, control_type) + shifted_controls = [] + for control in text_controls: + if control.start + control.length > len(input): + raise ValueError(f"TextControl is out of bounds for input {input}") + shifted_controls.append(replace(control, start=control.start + input_start)) + return shifted_controls + + def _get_text_control_start_index( + self, rich_prompt: RichPrompt, control_type: str + ) -> int: + prompt_ranges = rich_prompt.ranges.get(control_type) + assert prompt_ranges is not None + assert ( + len(prompt_ranges) == 1 + ), "There should always only be one prompt range per control type." + + assert isinstance(prompt_ranges[0].start, TextCursor) + cursor_start = prompt_ranges[0].start.position + return cursor_start + class LuminousControlModel(ControlModel): """An Aleph Alpha control model of the second generation. @@ -551,12 +612,12 @@ class AlephAlphaChatModel(ChatModel, ControlModel): CHAT_PROMPT_TEMPLATE: PromptTemplate def to_chat_prompt( - self, messages: list[Message], response_prefix: str | None = None + self, + messages: list[Message], + response_prefix: str | None = None, ) -> RichPrompt: """Method to create a chat-`RichPrompt` object to use with any `AlephAlphaModel`. - Allows the implementation of a custom prompt format for the specific model in use. - Args: messages: A number of messages to use as prompt for the model response_prefix: Append the given string to the beginning of the final agent message to @@ -601,12 +662,29 @@ def to_instruct_prompt( instruction: str, input: Optional[str] = None, response_prefix: Optional[str] = None, + instruction_controls: Optional[Sequence[TextControl]] = None, + input_controls: Optional[Sequence[TextControl]] = None, ) -> RichPrompt: + """Method to use a chat model like an instruct model`. + + Args: + instruction: The task the model should fulfill, for example summarization + input: Any context necessary to solve the task, such as the text to be summarized + response_prefix: Optional argument to append a string to the beginning of the + final agent message to steer the generation + instruction_controls: Instruction controls are not used but needed for the interface. + input_controls: Input controls are not used but needed for the interface + """ + if instruction_controls or input_controls: + warnings.warn( + "'instruction_controls' and 'input_controls' are not supported for 'ChatModel'. Parameter(s) will be ignored." + ) + return self.to_chat_prompt( [ Message( role="user", - content=f"{instruction}\n\n{input}" if input else instruction, + content=(f"{instruction}\n\n{input}" if input else instruction), ) ], response_prefix, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index ffee438b..b4d77023 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1,9 +1,15 @@ import random from collections.abc import Mapping, Sequence from typing import Any +from unittest.mock import patch import pytest -from aleph_alpha_client import Prompt, PromptGranularity, Text +from aleph_alpha_client import ( + Prompt, + PromptGranularity, + Text, + TextControl, +) from pytest import fixture from intelligence_layer.connectors import AlephAlphaClientProtocol @@ -20,6 +26,14 @@ Pharia1ChatModel, ) from intelligence_layer.core.model import _cached_context_size, _cached_tokenizer +from intelligence_layer.core.prompt_template import PromptRange, PromptTemplate + +INSTRUCTION = "Who likes pizza?" +INPUT = "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much." +INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN = ( + 59 # Length of "prefix" of Llama3-Instruct prompt template +) +INPUT_PREFIX_LEN = 2 # Additional new lines after instruction @fixture @@ -67,8 +81,8 @@ def test_llama_3_instruct_model_works(no_op_tracer: NoOpTracer) -> None: llama_3_model = Llama3InstructModel() prompt = llama_3_model.to_instruct_prompt( - "Who likes pizza?", - "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.", + INSTRUCTION, + INPUT, ) complete_input = CompleteInput(prompt=prompt) @@ -76,6 +90,23 @@ def test_llama_3_instruct_model_works(no_op_tracer: NoOpTracer) -> None: assert "Jessica" in output.completion +class FakeRichPrompt: + items: Sequence[int] = [1] + ranges: Mapping[str, Sequence[PromptRange]] = {} + + +def test_text_control_raises_error_for_non_text_prompt( + no_op_tracer: NoOpTracer, +) -> None: + llama_3_model = Llama3InstructModel() + + with patch.object(PromptTemplate, "to_rich_prompt", return_value=FakeRichPrompt()): # noqa: SIM117 + with pytest.raises( + ValueError, match="Text control only valid for text prompts." + ): + llama_3_model.to_instruct_prompt(INSTRUCTION) + + def test_pharia_1_chat_model_disables_optimizations(no_op_tracer: NoOpTracer) -> None: pharia_1_chat_model = Pharia1ChatModel() @@ -113,8 +144,8 @@ def test_chat_model_can_do_completion(no_op_tracer: NoOpTracer) -> None: llama_3_chat_model = Llama3ChatModel() prompt = llama_3_chat_model.to_instruct_prompt( - "Who likes pizza?", - "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.", + INSTRUCTION, + INPUT, ) complete_input = CompleteInput(prompt=prompt) @@ -122,17 +153,17 @@ def test_chat_model_can_do_completion(no_op_tracer: NoOpTracer) -> None: assert "Jessica" in output.completion -def test_chat_model_prompt_equals_instruct_prompt() -> None: +def test_chat_model_prompt_text_equals_instruct_prompt() -> None: llama_3_model = Llama3InstructModel() instruct_prompt = llama_3_model.to_instruct_prompt( - "Who likes pizza?", - "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.", + INSTRUCTION, + INPUT, ).items[0] llama_3_chat_model = Llama3ChatModel() chat_prompt = llama_3_chat_model.to_instruct_prompt( - "Who likes pizza?", - "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.", + INSTRUCTION, + INPUT, ).items[0] assert isinstance(instruct_prompt, Text) @@ -140,6 +171,106 @@ def test_chat_model_prompt_equals_instruct_prompt() -> None: assert instruct_prompt == chat_prompt +def test_text_control_handles_only_instruction_controls( + no_op_tracer: NoOpTracer, +) -> None: + index_of_focused_word = 5 + text_control = TextControl(start=index_of_focused_word, length=5, factor=10) + llama_3_model = Llama3InstructModel() + + prompt_with_control = llama_3_model.to_instruct_prompt( + INSTRUCTION, + INPUT, + instruction_controls=[text_control], + input_controls=[], + ) + assert ( + prompt_with_control.items[0].controls[0].start # type: ignore + == INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + index_of_focused_word + ) + + complete_input = CompleteInput(prompt=prompt_with_control) + output = llama_3_model.complete(complete_input, no_op_tracer) + assert "Jessica" in output.completion + + +def test_text_control_handles_only_input_controls(no_op_tracer: NoOpTracer) -> None: + index_of_focused_word = 0 + text_control = TextControl(start=index_of_focused_word, length=5, factor=5) + llama_3_model = Llama3InstructModel() + + prompt_with_control = llama_3_model.to_instruct_prompt( + INSTRUCTION, + INPUT, + instruction_controls=[], + input_controls=[text_control], + ) + assert ( + prompt_with_control.items[0].controls[0].start # type: ignore + == INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + len(INSTRUCTION) + INPUT_PREFIX_LEN + ) + + +def test_text_control_changes_completion_result(no_op_tracer: NoOpTracer) -> None: + instruction_focus_index = INSTRUCTION.index("likes") + instruction_control = TextControl( + start=instruction_focus_index, length=5, factor=0.10 + ) + input_focus_index = INPUT.index("hated") + input_control = TextControl(start=input_focus_index, length=5, factor=10) + llama_3_model = Llama3InstructModel() + + prompt_with_control = llama_3_model.to_instruct_prompt( + INSTRUCTION, + INPUT, + instruction_controls=[instruction_control], + input_controls=[input_control], + ) + + assert ( + prompt_with_control.items[0].controls[0].start # type: ignore + == INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + instruction_control.start + ) + + assert ( + prompt_with_control.items[0].controls[1].start # type: ignore + == INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + + len(INSTRUCTION) + + INPUT_PREFIX_LEN + + input_control.start + ) + + complete_input = CompleteInput(prompt=prompt_with_control) + output = llama_3_model.complete(complete_input, no_op_tracer) + assert "hate" in output.completion + + +def test_text_control_raises_error_when_out_of_instruction_boundaries() -> None: + text_control = TextControl(start=0, length=len(INSTRUCTION) + 1, factor=3) + llama_3_model = Llama3InstructModel() + + with pytest.raises(ValueError): + llama_3_model.to_instruct_prompt( + INSTRUCTION, + INPUT, + instruction_controls=[text_control], + ) + + +def test_text_control_raises_error_when_out_of_input_boundaries() -> None: + text_control = TextControl( + start=len(INSTRUCTION + INPUT) + INPUT_PREFIX_LEN, length=100, factor=3 + ) + llama_3_model = Llama3InstructModel() + + with pytest.raises(ValueError): + llama_3_model.to_instruct_prompt( + INSTRUCTION, + INPUT, + input_controls=[text_control], + ) + + def test_models_know_their_context_size(client: AlephAlphaClientProtocol) -> None: assert ( LuminousControlModel(client=client, name="luminous-base-control").context_size