Skip to content

Commit

Permalink
feat: Add TextControl for Instruct and Chat models (#1104)
Browse files Browse the repository at this point in the history
* feat: Add text control to to_instruct_prompt

TASK: PHS-728

* Refactor: Refactor to_instruct_prompt and add test on non-text prompt with controls
TASK: PHS-728

* Refactor: Assert on length of prompt ranges per control type.
TASK: PHS-728

* refactor: consolidate helper functions so that only one is used for index shifts

---------

Co-authored-by: Johannes Wesch <[email protected]>
Co-authored-by: Sebastian Niehus <[email protected]>
Co-authored-by: Sebastian Niehus <[email protected]>
  • Loading branch information
4 people authored Oct 24, 2024
1 parent f5ee1bc commit 6e67df4
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

### Features
- Add support for Llama3InstructModel in PromptBasedClassify
- Add TextControl to 'to_instruct_prompt' for instruct models

### Fixes
...
Expand Down
94 changes: 86 additions & 8 deletions src/intelligence_layer/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import replace
from functools import lru_cache
from typing import Any, ClassVar, Literal, Optional

Expand All @@ -13,6 +14,7 @@
ExplanationResponse,
Prompt,
Text,
TextControl,
Tokens,
)
from pydantic import BaseModel, ConfigDict
Expand All @@ -22,7 +24,11 @@
AlephAlphaClientProtocol,
LimitedConcurrencyClient,
)
from intelligence_layer.core.prompt_template import PromptTemplate, RichPrompt
from intelligence_layer.core.prompt_template import (
PromptTemplate,
RichPrompt,
TextCursor,
)
from intelligence_layer.core.task import Task, Token
from intelligence_layer.core.tracer.tracer import TaskSpan, Tracer

Expand Down Expand Up @@ -408,21 +414,76 @@ def to_instruct_prompt(
instruction: str,
input: Optional[str] = None,
response_prefix: Optional[str] = None,
instruction_controls: Optional[Sequence[TextControl]] = None,
input_controls: Optional[Sequence[TextControl]] = None,
) -> RichPrompt:
"""Method to create an instruct-`RichPrompt` object to use with any `ControlModel`.
Allows the implementation of a custom prompt format for the specific model in use.
Args:
instruction: The task the model should fulfill, for example summarization
input: Any context necessary to solve the task, such as the text to be summarized
response_prefix: Optional argument to append a string to the beginning of the
final agent message to steer the generation
input_controls: TextControls for the input part of the prompt. Only for text prompts
instruction_controls: TextControls for the instruction part of the prompt. Only for text prompts
"""
return self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt(
rich_prompt = self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt(
instruction=instruction, input=input, response_prefix=response_prefix
)

prompt = rich_prompt.items[0]
ranges = rich_prompt.ranges

if not isinstance(prompt, Text):
raise ValueError("Text control only valid for text prompts.")

text_controls: list[TextControl] = []

if instruction_controls:
shifted_instruction_controls = self._shift_text_control_ranges(
instruction, instruction_controls, rich_prompt, "instruction"
)
for shifted_input_control_range in shifted_instruction_controls:
text_controls.append(shifted_input_control_range)

if input_controls and input:
shifted_input_control_ranges = self._shift_text_control_ranges(
input, input_controls, rich_prompt, "input"
)
for shifted_input_control_range in shifted_input_control_ranges:
text_controls.append(shifted_input_control_range)

prompt_with_controls = Prompt.from_text(prompt.text, text_controls)
return RichPrompt(items=prompt_with_controls.items, ranges=ranges)

def _shift_text_control_ranges(
    self,
    input: str,
    text_controls: Sequence[TextControl],
    rich_prompt: RichPrompt,
    control_type: str,
) -> Sequence[TextControl]:
    """Translate controls from offsets within `input` to offsets within the full prompt.

    Args:
        input: The raw text (instruction or input) the controls refer to.
        text_controls: Controls whose `start` is relative to the beginning of `input`.
        rich_prompt: The rendered prompt whose ranges provide the offset of `input`.
        control_type: Key of the prompt range that holds `input`.

    Returns:
        New controls with `start` shifted by the start position of the range.

    Raises:
        ValueError: If a control starts before the beginning of `input` or
            reaches beyond its end.
    """
    input_start = self._get_text_control_start_index(rich_prompt, control_type)
    shifted_controls = []
    for control in text_controls:
        # A negative start would pass the upper-bound check but, once shifted,
        # point into the template text that precedes `input`.
        if control.start < 0 or control.start + control.length > len(input):
            raise ValueError(f"TextControl is out of bounds for input {input}")
        shifted_controls.append(replace(control, start=control.start + input_start))
    return shifted_controls

def _get_text_control_start_index(
    self, rich_prompt: RichPrompt, control_type: str
) -> int:
    """Return the position at which the range named `control_type` starts.

    Args:
        rich_prompt: The rendered prompt whose `ranges` mapping is consulted.
        control_type: Key to look up in `rich_prompt.ranges`.

    Returns:
        The `position` of the text cursor that marks the start of the range.
    """
    ranges_for_type = rich_prompt.ranges.get(control_type)
    assert ranges_for_type is not None
    assert (
        len(ranges_for_type) == 1
    ), "There should always only be one prompt range per control type."

    range_start = ranges_for_type[0].start
    # Only text cursors carry a plain character position.
    assert isinstance(range_start, TextCursor)
    return range_start.position


class LuminousControlModel(ControlModel):
"""An Aleph Alpha control model of the second generation.
Expand Down Expand Up @@ -551,12 +612,12 @@ class AlephAlphaChatModel(ChatModel, ControlModel):
CHAT_PROMPT_TEMPLATE: PromptTemplate

def to_chat_prompt(
self, messages: list[Message], response_prefix: str | None = None
self,
messages: list[Message],
response_prefix: str | None = None,
) -> RichPrompt:
"""Method to create a chat-`RichPrompt` object to use with any `AlephAlphaModel`.
Allows the implementation of a custom prompt format for the specific model in use.
Args:
messages: A number of messages to use as prompt for the model
response_prefix: Append the given string to the beginning of the final agent message to
Expand Down Expand Up @@ -601,12 +662,29 @@ def to_instruct_prompt(
instruction: str,
input: Optional[str] = None,
response_prefix: Optional[str] = None,
instruction_controls: Optional[Sequence[TextControl]] = None,
input_controls: Optional[Sequence[TextControl]] = None,
) -> RichPrompt:
"""Method to use a chat model like an instruct model`.
Args:
instruction: The task the model should fulfill, for example summarization
input: Any context necessary to solve the task, such as the text to be summarized
response_prefix: Optional argument to append a string to the beginning of the
final agent message to steer the generation
instruction_controls: Instruction controls are not used but needed for the interface.
input_controls: Input controls are not used but needed for the interface
"""
if instruction_controls or input_controls:
warnings.warn(
"'instruction_controls' and 'input_controls' are not supported for 'ChatModel'. Parameter(s) will be ignored."
)

return self.to_chat_prompt(
[
Message(
role="user",
content=f"{instruction}\n\n{input}" if input else instruction,
content=(f"{instruction}\n\n{input}" if input else instruction),
)
],
response_prefix,
Expand Down
151 changes: 141 additions & 10 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import random
from collections.abc import Mapping, Sequence
from typing import Any
from unittest.mock import patch

import pytest
from aleph_alpha_client import Prompt, PromptGranularity, Text
from aleph_alpha_client import (
Prompt,
PromptGranularity,
Text,
TextControl,
)
from pytest import fixture

from intelligence_layer.connectors import AlephAlphaClientProtocol
Expand All @@ -20,6 +26,14 @@
Pharia1ChatModel,
)
from intelligence_layer.core.model import _cached_context_size, _cached_tokenizer
from intelligence_layer.core.prompt_template import PromptRange, PromptTemplate

# Instruction/input pair shared by the prompt tests in this module.
INSTRUCTION = "Who likes pizza?"
INPUT = "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much."
# Offsets used to compute where a control is expected to land in the rendered prompt.
INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN = (
    59  # Length of "prefix" of Llama3-Instruct prompt template
)
INPUT_PREFIX_LEN = 2  # Additional new lines after instruction


@fixture
Expand Down Expand Up @@ -67,15 +81,32 @@ def test_llama_3_instruct_model_works(no_op_tracer: NoOpTracer) -> None:
llama_3_model = Llama3InstructModel()

prompt = llama_3_model.to_instruct_prompt(
"Who likes pizza?",
"Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.",
INSTRUCTION,
INPUT,
)

complete_input = CompleteInput(prompt=prompt)
output = llama_3_model.complete(complete_input, no_op_tracer)
assert "Jessica" in output.completion


class FakeRichPrompt:
    """Stand-in for a rich prompt whose only item is an int instead of a `Text`."""

    items: Sequence[int] = [1]
    ranges: Mapping[str, Sequence[PromptRange]] = {}


def test_text_control_raises_error_for_non_text_prompt(
    no_op_tracer: NoOpTracer,
) -> None:
    """`to_instruct_prompt` raises when the rendered prompt's first item is not a `Text`."""
    model = Llama3InstructModel()
    expected_message = "Text control only valid for text prompts."
    fake_template_output = FakeRichPrompt()

    # Make the template return a prompt whose first item is an int.
    with patch.object(
        PromptTemplate, "to_rich_prompt", return_value=fake_template_output
    ), pytest.raises(ValueError, match=expected_message):
        model.to_instruct_prompt(INSTRUCTION)


def test_pharia_1_chat_model_disables_optimizations(no_op_tracer: NoOpTracer) -> None:
pharia_1_chat_model = Pharia1ChatModel()

Expand Down Expand Up @@ -113,33 +144,133 @@ def test_chat_model_can_do_completion(no_op_tracer: NoOpTracer) -> None:
llama_3_chat_model = Llama3ChatModel()

prompt = llama_3_chat_model.to_instruct_prompt(
"Who likes pizza?",
"Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.",
INSTRUCTION,
INPUT,
)

complete_input = CompleteInput(prompt=prompt)
output = llama_3_chat_model.complete(complete_input, no_op_tracer)
assert "Jessica" in output.completion


def test_chat_model_prompt_equals_instruct_prompt() -> None:
def test_chat_model_prompt_text_equals_instruct_prompt() -> None:
llama_3_model = Llama3InstructModel()
instruct_prompt = llama_3_model.to_instruct_prompt(
"Who likes pizza?",
"Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.",
INSTRUCTION,
INPUT,
).items[0]

llama_3_chat_model = Llama3ChatModel()
chat_prompt = llama_3_chat_model.to_instruct_prompt(
"Who likes pizza?",
"Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.",
INSTRUCTION,
INPUT,
).items[0]

assert isinstance(instruct_prompt, Text)
assert isinstance(chat_prompt, Text)
assert instruct_prompt == chat_prompt


def test_text_control_handles_only_instruction_controls(
    no_op_tracer: NoOpTracer,
) -> None:
    """An instruction control is shifted by the template prefix length and the prompt still completes."""
    focus_start = 5
    model = Llama3InstructModel()
    control = TextControl(start=focus_start, length=5, factor=10)

    rich_prompt = model.to_instruct_prompt(
        INSTRUCTION,
        INPUT,
        instruction_controls=[control],
        input_controls=[],
    )

    expected_start = INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + focus_start
    shifted_start = rich_prompt.items[0].controls[0].start  # type: ignore
    assert shifted_start == expected_start

    result = model.complete(CompleteInput(prompt=rich_prompt), no_op_tracer)
    assert "Jessica" in result.completion


def test_text_control_handles_only_input_controls(no_op_tracer: NoOpTracer) -> None:
    """An input control starting at 0 lands right after prefix, instruction and separator."""
    focus_start = 0
    model = Llama3InstructModel()

    rich_prompt = model.to_instruct_prompt(
        INSTRUCTION,
        INPUT,
        instruction_controls=[],
        input_controls=[TextControl(start=focus_start, length=5, factor=5)],
    )

    expected_start = (
        INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + len(INSTRUCTION) + INPUT_PREFIX_LEN
    )
    assert rich_prompt.items[0].controls[0].start == expected_start  # type: ignore


def test_text_control_changes_completion_result(no_op_tracer: NoOpTracer) -> None:
    """Controls on "likes" (factor 0.10) and "hated" (factor 10) yield a completion containing "hate"."""
    likes_control = TextControl(
        start=INSTRUCTION.index("likes"), length=5, factor=0.10
    )
    hated_control = TextControl(start=INPUT.index("hated"), length=5, factor=10)
    model = Llama3InstructModel()

    rich_prompt = model.to_instruct_prompt(
        INSTRUCTION,
        INPUT,
        instruction_controls=[likes_control],
        input_controls=[hated_control],
    )

    # Instruction controls come first in the resulting list, input controls second.
    controls = rich_prompt.items[0].controls  # type: ignore
    expected_instruction_start = (
        INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN + likes_control.start
    )
    expected_input_start = (
        INSTRUCTION_PROMPT_TEMPLATE_PREFIX_LEN
        + len(INSTRUCTION)
        + INPUT_PREFIX_LEN
        + hated_control.start
    )
    assert controls[0].start == expected_instruction_start
    assert controls[1].start == expected_input_start

    result = model.complete(CompleteInput(prompt=rich_prompt), no_op_tracer)
    assert "hate" in result.completion


def test_text_control_raises_error_when_out_of_instruction_boundaries() -> None:
    """A control one character longer than the instruction is rejected."""
    too_long_control = TextControl(start=0, length=len(INSTRUCTION) + 1, factor=3)
    model = Llama3InstructModel()

    with pytest.raises(ValueError):
        model.to_instruct_prompt(
            INSTRUCTION, INPUT, instruction_controls=[too_long_control]
        )


def test_text_control_raises_error_when_out_of_input_boundaries() -> None:
    """A control that starts past the end of the input is rejected."""
    start_beyond_input = len(INSTRUCTION + INPUT) + INPUT_PREFIX_LEN
    out_of_range_control = TextControl(
        start=start_beyond_input, length=100, factor=3
    )
    model = Llama3InstructModel()

    with pytest.raises(ValueError):
        model.to_instruct_prompt(
            INSTRUCTION, INPUT, input_controls=[out_of_range_control]
        )


def test_models_know_their_context_size(client: AlephAlphaClientProtocol) -> None:
assert (
LuminousControlModel(client=client, name="luminous-base-control").context_size
Expand Down

0 comments on commit 6e67df4

Please sign in to comment.