Skip to content

Commit

Permalink
added default for ranges and fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pitneitemeier committed Feb 20, 2024
1 parent 9c66971 commit d10d1fa
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/intelligence_layer/core/prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from dataclasses import dataclass, replace
from dataclasses import dataclass, field, replace
from itertools import chain
from re import finditer
from sys import intern
Expand Down Expand Up @@ -84,7 +84,7 @@ class RichPrompt(Prompt):
ranges: A mapping of range name to a `Sequence` of corresponding `PromptRange` instances.
"""

ranges: Mapping[str, Sequence[PromptRange]]
ranges: Mapping[str, Sequence[PromptRange]] = field(default_factory=dict)


PROMPT_RANGE_TAG = intern("promptrange")
Expand Down
44 changes: 26 additions & 18 deletions tests/core/test_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
from textwrap import dedent
from typing import List

from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, Tokens
from aleph_alpha_client import Prompt
from aleph_alpha_client.prompt import Image, PromptItem, Text, Tokens
from liquid.exceptions import LiquidSyntaxError, LiquidTypeError
from pytest import raises

from intelligence_layer.core.prompt_template import (
PromptItemCursor,
PromptRange,
PromptTemplate,
RichPrompt,
TextCursor,
)


def rich_prompt_from_text(text: str) -> RichPrompt:
return RichPrompt(items=[Text.from_text(text)])


def test_to_prompt_with_text_array() -> None:
template = PromptTemplate(
"""
Expand All @@ -27,7 +33,7 @@ def test_to_prompt_with_text_array() -> None:
prompt = template.to_rich_prompt(names=names)

expected = "".join([f"Hello {name}!\n" for name in names])
assert prompt == Prompt.from_text(expected)
assert prompt == rich_prompt_from_text(expected)


def test_to_prompt_with_invalid_input() -> None:
Expand All @@ -53,12 +59,13 @@ def test_to_prompt_with_single_image(prompt_image: Image) -> None:

prompt = template.to_rich_prompt(whatever=template.placeholder(prompt_image))

expected = Prompt(
expected = RichPrompt(
[
Text.from_text("Some Text.\n"),
prompt_image,
Text.from_text("\nMore Text\n"),
]
],
ranges={},
)
assert prompt == expected

Expand All @@ -76,7 +83,7 @@ def test_to_prompt_with_image_sequence(prompt_image: Image) -> None:
images=[template.placeholder(prompt_image), template.placeholder(prompt_image)]
)

expected = Prompt([prompt_image, prompt_image])
expected = RichPrompt([prompt_image, prompt_image])
assert prompt == expected


Expand All @@ -87,7 +94,7 @@ def test_to_prompt_with_mixed_modality_variables(prompt_image: Image) -> None:
image=template.placeholder(prompt_image), name="whatever"
)

expected = Prompt([prompt_image, Text.from_text("whatever"), prompt_image])
expected = RichPrompt([prompt_image, Text.from_text("whatever"), prompt_image])
assert prompt == expected


Expand All @@ -96,7 +103,8 @@ def test_to_prompt_with_unused_image(prompt_image: Image) -> None:

prompt = template.to_rich_prompt(images=template.placeholder(prompt_image))

assert prompt == Prompt.from_text("cool")
assert prompt == rich_prompt_from_text("cool")
assert prompt == RichPrompt(items=[Text("cool", controls=[])])


def test_to_prompt_with_multiple_different_images(prompt_image: Image) -> None:
Expand All @@ -110,11 +118,11 @@ def test_to_prompt_with_multiple_different_images(prompt_image: Image) -> None:
image_2=template.placeholder(second_image),
)

assert prompt == Prompt([prompt_image, second_image])
assert prompt == RichPrompt([prompt_image, second_image])


def test_to_prompt_with_embedded_prompt(prompt_image: Image) -> None:
user_prompt = Prompt([Text.from_text("Cool"), prompt_image])
user_prompt = RichPrompt([Text.from_text("Cool"), prompt_image])

template = PromptTemplate("""{{user_prompt}}""")

Expand All @@ -124,7 +132,7 @@ def test_to_prompt_with_embedded_prompt(prompt_image: Image) -> None:


def test_to_prompt_does_not_add_whitespace_after_image(prompt_image: Image) -> None:
user_prompt = Prompt([prompt_image, Text.from_text("Cool"), prompt_image])
user_prompt = RichPrompt([prompt_image, Text.from_text("Cool"), prompt_image])

template = PromptTemplate("{{user_prompt}}")

Expand All @@ -134,31 +142,31 @@ def test_to_prompt_does_not_add_whitespace_after_image(prompt_image: Image) -> N


def test_to_prompt_skips_empty_strings() -> None:
user_prompt = Prompt(
user_prompt = RichPrompt(
[Text.from_text("Cool"), Text.from_text(""), Text.from_text("Also cool")]
)

template = PromptTemplate("{{user_prompt}}")

prompt = template.to_rich_prompt(user_prompt=template.embed_prompt(user_prompt))

assert prompt == Prompt([Text.from_text("Cool Also cool")])
assert prompt == RichPrompt([Text.from_text("Cool Also cool")])


def test_to_prompt_adds_whitespaces() -> None:
user_prompt = Prompt(
user_prompt = RichPrompt(
[Text.from_text("start "), Text.from_text("middle"), Text.from_text(" end")]
)

template = PromptTemplate("{{user_prompt}}")

prompt = template.to_rich_prompt(user_prompt=template.embed_prompt(user_prompt))

assert prompt == Prompt([Text.from_text("start middle end")])
assert prompt == RichPrompt([Text.from_text("start middle end")])


def test_to_prompt_works_with_tokens() -> None:
user_prompt = Prompt(
user_prompt = RichPrompt(
[
Tokens.from_token_ids([1, 2, 3]),
Text.from_text("cool"),
Expand All @@ -174,7 +182,7 @@ def test_to_prompt_works_with_tokens() -> None:


def test_to_prompt_with_empty_template() -> None:
assert PromptTemplate("").to_rich_prompt() == Prompt([])
assert PromptTemplate("").to_rich_prompt() == RichPrompt([])


def test_to_prompt_resets_template(prompt_image: Image) -> None:
Expand Down Expand Up @@ -202,11 +210,11 @@ def test_to_prompt_data_returns_ranges(prompt_image: Image) -> None:
)

prompt_data = template.to_rich_prompt(
prefix_items=template.embed_prompt(Prompt(prefix_items + [prefix_merged])),
prefix_items=template.embed_prompt(RichPrompt(prefix_items + [prefix_merged])),
prefix_text=prefix_text,
embedded_text=embedded_text,
embedded_items=template.embed_prompt(
Prompt([embedded_merged] + embedded_items)
RichPrompt([embedded_merged] + embedded_items)
),
)

Expand Down

0 comments on commit d10d1fa

Please sign in to comment.