From d10d1fa110ccff238d927d39a1930bfe0f807f57 Mon Sep 17 00:00:00 2001 From: pitneitemeier Date: Tue, 20 Feb 2024 17:27:46 +0100 Subject: [PATCH] added default for ranges and fixed tests --- .../core/prompt_template.py | 4 +- tests/core/test_prompt_template.py | 44 +++++++++++-------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/intelligence_layer/core/prompt_template.py b/src/intelligence_layer/core/prompt_template.py index 6d2970871..df1282ddb 100644 --- a/src/intelligence_layer/core/prompt_template.py +++ b/src/intelligence_layer/core/prompt_template.py @@ -1,5 +1,5 @@ from collections import defaultdict -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from itertools import chain from re import finditer from sys import intern @@ -84,7 +84,7 @@ class RichPrompt(Prompt): ranges: A mapping of range name to a `Sequence` of corresponding `PromptRange` instances. """ - ranges: Mapping[str, Sequence[PromptRange]] + ranges: Mapping[str, Sequence[PromptRange]] = field(default_factory=dict) PROMPT_RANGE_TAG = intern("promptrange") diff --git a/tests/core/test_prompt_template.py b/tests/core/test_prompt_template.py index 74db3f3d3..36aea2efa 100644 --- a/tests/core/test_prompt_template.py +++ b/tests/core/test_prompt_template.py @@ -2,7 +2,8 @@ from textwrap import dedent from typing import List -from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, Tokens +from aleph_alpha_client import Prompt +from aleph_alpha_client.prompt import Image, PromptItem, Text, Tokens from liquid.exceptions import LiquidSyntaxError, LiquidTypeError from pytest import raises @@ -10,10 +11,15 @@ PromptItemCursor, PromptRange, PromptTemplate, + RichPrompt, TextCursor, ) +def rich_prompt_from_text(text: str) -> RichPrompt: + return RichPrompt(items=[Text.from_text(text)]) + + def test_to_prompt_with_text_array() -> None: template = PromptTemplate( """ @@ -27,7 +33,7 @@ def test_to_prompt_with_text_array() -> None: prompt = template.to_rich_prompt(names=names) expected = "".join([f"Hello {name}!\n" for name in names]) - assert prompt == Prompt.from_text(expected) + assert prompt == rich_prompt_from_text(expected) def test_to_prompt_with_invalid_input() -> None: @@ -53,12 +59,13 @@ def test_to_prompt_with_single_image(prompt_image: Image) -> None: prompt = template.to_rich_prompt(whatever=template.placeholder(prompt_image)) - expected = Prompt( + expected = RichPrompt( [ Text.from_text("Some Text.\n"), prompt_image, Text.from_text("\nMore Text\n"), - ] + ], + ranges={}, ) assert prompt == expected @@ -76,7 +83,7 @@ def test_to_prompt_with_image_sequence(prompt_image: Image) -> None: images=[template.placeholder(prompt_image), template.placeholder(prompt_image)] ) - expected = Prompt([prompt_image, prompt_image]) + expected = RichPrompt([prompt_image, prompt_image]) assert prompt == expected @@ -87,7 +94,7 @@ def test_to_prompt_with_mixed_modality_variables(prompt_image: Image) -> None: image=template.placeholder(prompt_image), name="whatever" ) - expected = Prompt([prompt_image, Text.from_text("whatever"), prompt_image]) + expected = RichPrompt([prompt_image, Text.from_text("whatever"), prompt_image]) assert prompt == expected @@ -96,7 +103,8 @@ def test_to_prompt_with_unused_image(prompt_image: Image) -> None: prompt = template.to_rich_prompt(images=template.placeholder(prompt_image)) - assert prompt == Prompt.from_text("cool") + assert prompt == rich_prompt_from_text("cool") + assert prompt == RichPrompt(items=[Text("cool", controls=[])]) def test_to_prompt_with_multiple_different_images(prompt_image: Image) -> None: @@ -110,11 +118,11 @@ def test_to_prompt_with_multiple_different_images(prompt_image: Image) -> None: image_2=template.placeholder(second_image), ) - assert prompt == Prompt([prompt_image, second_image]) + assert prompt == RichPrompt([prompt_image, second_image]) def test_to_prompt_with_embedded_prompt(prompt_image: Image) -> None: - user_prompt = Prompt([Text.from_text("Cool"), prompt_image]) + user_prompt = RichPrompt([Text.from_text("Cool"), prompt_image]) template = PromptTemplate("""{{user_prompt}}""") @@ -124,7 +132,7 @@ def test_to_prompt_with_embedded_prompt(prompt_image: Image) -> None: def test_to_prompt_does_not_add_whitespace_after_image(prompt_image: Image) -> None: - user_prompt = Prompt([prompt_image, Text.from_text("Cool"), prompt_image]) + user_prompt = RichPrompt([prompt_image, Text.from_text("Cool"), prompt_image]) template = PromptTemplate("{{user_prompt}}") @@ -134,7 +142,7 @@ def test_to_prompt_does_not_add_whitespace_after_image(prompt_image: Image) -> N def test_to_prompt_skips_empty_strings() -> None: - user_prompt = Prompt( + user_prompt = RichPrompt( [Text.from_text("Cool"), Text.from_text(""), Text.from_text("Also cool")] ) @@ -142,11 +150,11 @@ def test_to_prompt_skips_empty_strings() -> None: prompt = template.to_rich_prompt(user_prompt=template.embed_prompt(user_prompt)) - assert prompt == Prompt([Text.from_text("Cool Also cool")]) + assert prompt == RichPrompt([Text.from_text("Cool Also cool")]) def test_to_prompt_adds_whitespaces() -> None: - user_prompt = Prompt( + user_prompt = RichPrompt( [Text.from_text("start "), Text.from_text("middle"), Text.from_text(" end")] ) @@ -154,11 +162,11 @@ def test_to_prompt_adds_whitespaces() -> None: prompt = template.to_rich_prompt(user_prompt=template.embed_prompt(user_prompt)) - assert prompt == Prompt([Text.from_text("start middle end")]) + assert prompt == RichPrompt([Text.from_text("start middle end")]) def test_to_prompt_works_with_tokens() -> None: - user_prompt = Prompt( + user_prompt = RichPrompt( [ Tokens.from_token_ids([1, 2, 3]), Text.from_text("cool"), @@ -174,7 +182,7 @@ def test_to_prompt_works_with_tokens() -> None: def test_to_prompt_with_empty_template() -> None: - assert PromptTemplate("").to_rich_prompt() == Prompt([]) + assert PromptTemplate("").to_rich_prompt() == RichPrompt([]) def test_to_prompt_resets_template(prompt_image: Image) -> None: @@ -202,11 +210,11 @@ def test_to_prompt_data_returns_ranges(prompt_image: Image) -> None: ) prompt_data = template.to_rich_prompt( - prefix_items=template.embed_prompt(Prompt(prefix_items + [prefix_merged])), + prefix_items=template.embed_prompt(RichPrompt(prefix_items + [prefix_merged])), prefix_text=prefix_text, embedded_text=embedded_text, embedded_items=template.embed_prompt( - Prompt([embedded_merged] + embedded_items) + RichPrompt([embedded_merged] + embedded_items) ), )