diff --git a/README.md b/README.md
index 5f1291bab..d7c573232 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 
 Meet Marvin: a batteries-included library for building AI-powered software. Marvin's job is to integrate AI directly into your codebase by making it look and feel like any other function.
 
-Marvin introduces a new concept called [**AI Functions**](https://www.askmarvin.ai/guide/concepts/ai_functions.md). These functions differ from conventional ones in that they don’t rely on source code, but instead generate their outputs on-demand through AI. With AI functions, you don't have to write complex code for tasks like extracting entities from web pages, scoring sentiment, or categorizing items in your database. Just describe your needs, call the function, and you're done!
+Marvin introduces a new concept called [**AI Functions**](https://askmarvin.ai/guide/concepts/ai_functions). These functions differ from conventional ones in that they don’t rely on source code, but instead generate their outputs on-demand through AI. With AI functions, you don't have to write complex code for tasks like extracting entities from web pages, scoring sentiment, or categorizing items in your database. Just describe your needs, call the function, and you're done!
 
 AI functions work with native data types, so you can seamlessly integrate them into any codebase and chain them into sophisticated pipelines. Technically speaking, Marvin transforms the signature of using AI from `(str) -> str` to `(**kwargs) -> Any`. We call this **"functional prompt engineering."**
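The README hunk above describes AI Functions abstractly; here is a minimal sketch of the pattern it refers to, mirroring the `@ai_fn` usage in the test diffs below. The function name and the sample output are hypothetical, chosen only for illustration:

```python
from marvin import ai_fn


@ai_fn
def list_colors(text: str) -> list[str]:
    """Returns a list of the colors mentioned in `text`."""


# The body is intentionally empty: Marvin prompts an LLM with the signature
# and docstring and parses the result into the declared type, list[str].
# A hypothetical call like list_colors("a red barn under a blue sky")
# might return ["red", "blue"], though the output is model-dependent.
```

This is what the `(str) -> str` to `(**kwargs) -> Any` framing means in practice: the caller sees a typed Python function, not a prompt.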
diff --git a/docs/img/heroes/dont_panic_center.png b/docs/img/heroes/dont_panic_center.png
new file mode 100644
index 000000000..e95eb7e8e
Binary files /dev/null and b/docs/img/heroes/dont_panic_center.png differ
diff --git a/src/marvin/bots/response_formatters.py b/src/marvin/bots/response_formatters.py
index fb0819971..ff5318716 100644
--- a/src/marvin/bots/response_formatters.py
+++ b/src/marvin/bots/response_formatters.py
@@ -1,5 +1,6 @@
 import json
 import re
+import warnings
 from types import GenericAlias
 from typing import Any, Literal
 
@@ -9,6 +10,7 @@ import marvin
 from marvin.utilities.types import (
     DiscriminatedUnionType,
+    LoggerMixin,
     format_type_str,
     genericalias_contains,
     safe_issubclass,
@@ -17,7 +19,7 @@
 SENTINEL = "__SENTINEL__"
 
 
-class ResponseFormatter(DiscriminatedUnionType):
+class ResponseFormatter(DiscriminatedUnionType, LoggerMixin):
     format: str = Field(None, description="The format of the response")
     on_error: Literal["reformat", "raise", "ignore"] = "reformat"
 
@@ -67,17 +69,31 @@ def __init__(self, type_: type = SENTINEL, **kwargs):
         if not isinstance(type_, (type, GenericAlias)):
             raise ValueError(f"Expected a type or GenericAlias, got {type_}")
 
+        # warn if the type is a set or tuple with GPT 3.5
+        if marvin.settings.openai_model_name.startswith("gpt-3.5"):
+            if safe_issubclass(type_, (set, tuple)) or genericalias_contains(
+                type_, (set, tuple)
+            ):
+                warnings.warn(
+                    (
+                        "GPT-3.5 often fails with `set` or `tuple` types. Consider"
+                        " using `list` instead."
+                    ),
+                    UserWarning,
+                )
+
         schema = marvin.utilities.types.type_to_schema(type_)
         kwargs.update(
             type_schema=schema,
             format=(
-                "A valid JSON object that matches this simple type"
-                f" signature: ```{format_type_str(type_)}``` and equivalent OpenAI"
-                f" schema: ```{json.dumps(schema)}```. Make sure your response is"
-                " valid JSON, so use lists instead of sets or tuples; literal"
-                " `true` and `false` instead of `True` and `False`; literal `null`"
-                " instead of `None`; and double quotes instead of single quotes."
+                "A valid JSON object that satisfies this OpenAPI schema:"
+                f" ```{json.dumps(schema)}```. The JSON object will be coerced to"
+                f" the following type signature: ```{format_type_str(type_)}```."
+                " Make sure your response is valid JSON, which means you must use"
+                " lists instead of tuples or sets; literal `true` and `false`"
+                " instead of `True` and `False`; literal `null` instead of `None`;"
+                " and double quotes instead of single quotes."
             ),
         )
         super().__init__(**kwargs)
@@ -97,8 +113,10 @@ def get_type(self) -> type | GenericAlias:
     def parse_response(self, response):
         type_ = self.get_type()
 
-        # handle GenericAlias and containers
-        if isinstance(type_, GenericAlias):
+        # handle GenericAlias and containers like dicts
+        if isinstance(type_, GenericAlias) or safe_issubclass(
+            type_, (list, dict, set, tuple)
+        ):
             return pydantic.parse_raw_as(type_, response)
 
         # handle basic types
diff --git a/src/marvin/utilities/tests.py b/src/marvin/utilities/tests.py
index ce9590bee..e8b7a8900 100644
--- a/src/marvin/utilities/tests.py
+++ b/src/marvin/utilities/tests.py
@@ -26,18 +26,33 @@ def assert_approx_equal(statement_1: str, statement_2: str):
 
 
 @ai_fn()
-def assert_llm(output: Any, expectation: Any) -> bool:
+def _assert_llm(output: Any, expectation: Any) -> bool:
     """
-    Given the `output` of an LLM and an expectation, determines whether the
-    output satisfies the expectation.
+    This function is used to unit test LLM outputs. The LLM `output` is compared
+    to an `expectation` of what the output is, contains, or represents. The
+    function returns `true` if the output satisfies the expectation and `false`
+    otherwise. The expectation does not need to be matched exactly. If the
+    expectation and output are semantically the same, the function should return
+    true.
+
     For example:
-    `assert_llm(5, "output == 5")` will return `True`
-    `assert_llm(["red", "orange"], "a list of colors")` will return `True`
-    `assert_llm(["red", "house"], "a list of colors")` will return `False`
+    assert_llm(5, "5") -> True
+    assert_llm("Greetings, friend!", "Hello, how are you?") -> True
+    assert_llm("Hello, friend!", "a greeting") -> True
+    assert_llm("I'm good, thanks!", "Hello, how are you?") -> False
+    assert_llm(["red", "orange"], "a list of colors") -> True
+    assert_llm(["red", "house"], "a list of colors") -> False
     """
 
 
+def assert_llm(output: Any, expectation: Any):
+    if not _assert_llm(output, expectation):
+        raise AssertionError(
+            f"Output {output} does not satisfy expectation {expectation}"
+        )
+
+
 @asynccontextmanager
 async def timer():
     start_time = asyncio.get_running_loop().time()
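Because `assert_llm` is now a raising wrapper around the AI function `_assert_llm`, call sites read like plain assertions instead of `assert assert_llm(...)`. A quick sketch of the new contract, reusing the docstring's own examples (these make real LLM calls, like everything else under `llm_tests`):

```python
import pytest

from marvin.utilities.tests import assert_llm

# Passes silently: the output semantically satisfies the expectation.
assert_llm(["red", "orange"], "a list of colors")

# Raises AssertionError with a descriptive message, because _assert_llm
# judges that "house" is not a color and returns False.
with pytest.raises(AssertionError):
    assert_llm(["red", "house"], "a list of colors")
```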
""" + if isinstance(target_type, tuple): + return any(genericalias_contains(genericalias, t) for t in target_type) + if isinstance(genericalias, GenericAlias): + if safe_issubclass(genericalias.__origin__, target_type): + return True for arg in genericalias.__args__: if genericalias_contains(arg, target_type): return True diff --git a/tests/llm_tests/bots/test_ai_functions.py b/tests/llm_tests/bots/test_ai_functions.py index 998a29e02..3bb0809f3 100644 --- a/tests/llm_tests/bots/test_ai_functions.py +++ b/tests/llm_tests/bots/test_ai_functions.py @@ -1,6 +1,8 @@ from typing import Optional +import marvin import pydantic +import pytest from marvin import ai_fn from marvin.utilities.tests import assert_llm @@ -59,7 +61,7 @@ def fake_people(n: int) -> list[dict]: assert all(isinstance(person, dict) for person in x) assert all("name" in person for person in x) assert all("age" in person for person in x) - assert_llm(x, "a list of fake people") + assert_llm(x, "a list of people data including name and age") def test_generate_rhyming_words(self): @ai_fn @@ -69,7 +71,7 @@ def rhymes(word: str) -> str: x = rhymes("blue") assert isinstance(x, str) assert x != "blue" - assert_llm(x, "a word that rhymes with blue") + assert_llm(x, "the output is any word that rhymes with blue") def test_generate_rhyming_words_with_n(self): @ai_fn @@ -81,7 +83,13 @@ def rhymes(word: str, n: int) -> list[str]: assert len(x) == 3 assert all(isinstance(word, str) for word in x) assert all(word != "blue" for word in x) - assert_llm(x, "a list of words that rhyme with blue") + assert_llm( + x, + ( + "the output is a list of words, each one rhyming with 'blue'. For" + " example ['clue', 'dew', 'flew']" + ), + ) class TestBool: @@ -124,6 +132,94 @@ def list_questions(email_body: str) -> list[str]: assert x == ["What is your favorite color?"] +class TestContainers: + """tests untyped containers""" + + def test_dict(self): + @ai_fn + def dict_response() -> dict: + """ + Returns a dictionary that contains + - name: str + - age: int + """ + + response = dict_response() + assert isinstance(response, dict) + assert isinstance(response["name"], str) + assert isinstance(response["age"], int) + + def test_list(self): + @ai_fn + def list_response() -> list: + """ + Returns a list that contains two numbers + """ + + response = list_response() + assert isinstance(response, list) + assert len(response) == 2 + assert isinstance(response[0], (int, float)) + assert isinstance(response[1], (int, float)) + + def test_set(self): + @ai_fn + def set_response() -> set[int]: + """ + Returns a set that contains two numbers, such as {3, 5} + """ + + if marvin.settings.openai_model_name.startswith("gpt-3.5"): + with pytest.warns(UserWarning): + response = set_response() + assert isinstance(response, set) + # its unclear what will be in the set + + else: + response = set_response() + assert isinstance(response, set) + assert len(response) == 2 + assert isinstance(response.pop(), (int, float)) + assert isinstance(response.pop(), (int, float)) + + def test_tuple(self): + @ai_fn + def tuple_response() -> tuple: + """ + Returns a tuple that contains two numbers + """ + + if marvin.settings.openai_model_name.startswith("gpt-3.5"): + with pytest.warns(UserWarning): + response = tuple_response() + assert isinstance(response, tuple) + # its unclear what will be in the tuple + + else: + response = tuple_response() + assert isinstance(response, tuple) + assert len(response) == 2 + assert isinstance(response[0], (int, float)) + assert isinstance(response[1], 
diff --git a/tests/llm_tests/bots/test_ai_functions.py b/tests/llm_tests/bots/test_ai_functions.py
index 998a29e02..3bb0809f3 100644
--- a/tests/llm_tests/bots/test_ai_functions.py
+++ b/tests/llm_tests/bots/test_ai_functions.py
@@ -1,6 +1,8 @@
 from typing import Optional
 
+import marvin
 import pydantic
+import pytest
 from marvin import ai_fn
 from marvin.utilities.tests import assert_llm
 
@@ -59,7 +61,7 @@ def fake_people(n: int) -> list[dict]:
         assert all(isinstance(person, dict) for person in x)
         assert all("name" in person for person in x)
         assert all("age" in person for person in x)
-        assert_llm(x, "a list of fake people")
+        assert_llm(x, "a list of people data including name and age")
 
     def test_generate_rhyming_words(self):
         @ai_fn
@@ -69,7 +71,7 @@ def rhymes(word: str) -> str:
         x = rhymes("blue")
         assert isinstance(x, str)
         assert x != "blue"
-        assert_llm(x, "a word that rhymes with blue")
+        assert_llm(x, "the output is any word that rhymes with blue")
 
     def test_generate_rhyming_words_with_n(self):
         @ai_fn
@@ -81,7 +83,13 @@ def rhymes(word: str, n: int) -> list[str]:
         assert len(x) == 3
         assert all(isinstance(word, str) for word in x)
         assert all(word != "blue" for word in x)
-        assert_llm(x, "a list of words that rhyme with blue")
+        assert_llm(
+            x,
+            (
+                "the output is a list of words, each one rhyming with 'blue'. For"
+                " example ['clue', 'dew', 'flew']"
+            ),
+        )
 
 
 class TestBool:
@@ -124,6 +132,94 @@ def list_questions(email_body: str) -> list[str]:
         assert x == ["What is your favorite color?"]
 
 
+class TestContainers:
+    """tests untyped containers"""
+
+    def test_dict(self):
+        @ai_fn
+        def dict_response() -> dict:
+            """
+            Returns a dictionary that contains
+            - name: str
+            - age: int
+            """
+
+        response = dict_response()
+        assert isinstance(response, dict)
+        assert isinstance(response["name"], str)
+        assert isinstance(response["age"], int)
+
+    def test_list(self):
+        @ai_fn
+        def list_response() -> list:
+            """
+            Returns a list that contains two numbers
+            """
+
+        response = list_response()
+        assert isinstance(response, list)
+        assert len(response) == 2
+        assert isinstance(response[0], (int, float))
+        assert isinstance(response[1], (int, float))
+
+    def test_set(self):
+        @ai_fn
+        def set_response() -> set[int]:
+            """
+            Returns a set that contains two numbers, such as {3, 5}
+            """
+
+        if marvin.settings.openai_model_name.startswith("gpt-3.5"):
+            with pytest.warns(UserWarning):
+                response = set_response()
+            assert isinstance(response, set)
+            # it's unclear what will be in the set
+
+        else:
+            response = set_response()
+            assert isinstance(response, set)
+            assert len(response) == 2
+            assert isinstance(response.pop(), (int, float))
+            assert isinstance(response.pop(), (int, float))
+
+    def test_tuple(self):
+        @ai_fn
+        def tuple_response() -> tuple:
+            """
+            Returns a tuple that contains two numbers
+            """
+
+        if marvin.settings.openai_model_name.startswith("gpt-3.5"):
+            with pytest.warns(UserWarning):
+                response = tuple_response()
+            assert isinstance(response, tuple)
+            # it's unclear what will be in the tuple
+
+        else:
+            response = tuple_response()
+            assert isinstance(response, tuple)
+            assert len(response) == 2
+            assert isinstance(response[0], (int, float))
+            assert isinstance(response[1], (int, float))
+
+    def test_list_of_dicts(self):
+        @ai_fn
+        def list_of_dicts_response() -> list[dict]:
+            """
+            Returns a list of 2 dictionaries that each contain
+            - name: str
+            - age: int
+            """
+
+        response = list_of_dicts_response()
+        assert isinstance(response, list)
+        assert len(response) == 2
+        for i in [0, 1]:
+            assert isinstance(response[i], dict)
+            assert isinstance(response[i]["name"], str)
+            assert isinstance(response[i]["age"], int)
+
+
 class TestSet:
     def test_set_response(self):
         # https://github.com/PrefectHQ/marvin/issues/54
diff --git a/tests/llm_tests/bots/test_bots.py b/tests/llm_tests/bots/test_bots.py
index 47a22d888..221535ee7 100644
--- a/tests/llm_tests/bots/test_bots.py
+++ b/tests/llm_tests/bots/test_bots.py
@@ -6,7 +6,7 @@ class TestBotResponse:
     @pytest.mark.parametrize(
         "message,expected_response",
-        [("hello", "Greetings. How may I assist you today?")],
+        [("Say only the word 'red'", "Red")],
     )
     async def test_simple_response(self, message, expected_response):
         bot = Bot()
@@ -15,11 +15,14 @@ async def test_simple_response(self, message, expected_response):
 
     async def test_memory(self):
         bot = Bot()
-        response = await bot.say("My favorite color is blue")
+        response = await bot.say("Hello, my favorite color is blue")
         response = await bot.say("What is my favorite color?")
         assert_llm(
             response.content,
-            "You told me that your favorite color is blue",
+            (
+                "Based on your previous message, you mentioned that your favorite color"
+                " is blue. Is that still correct?"
+            ),
         )
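Finally, a note on why the widened `parse_response` branch can satisfy these container tests even though JSON has no set or tuple type: `pydantic.parse_raw_as` parses the model's JSON array and coerces it into the annotated container. A rough sketch, assuming pydantic v1 (the API this diff calls) on a Python version that supports builtin generics:

```python
import pydantic

# The format instructions force the model to emit JSON lists; parse_raw_as
# then coerces the parsed array into the requested container type.
assert pydantic.parse_raw_as(list, "[3, 5]") == [3, 5]
assert pydantic.parse_raw_as(set[int], "[3, 5]") == {3, 5}
assert pydantic.parse_raw_as(tuple[int, int], "[3, 5]") == (3, 5)
```

This is also why the new warning steers GPT-3.5 users toward `list`: the coercion only produces a faithful `set` or `tuple` when the model emits a clean JSON array, which, per the warning above, GPT-3.5 often fails to do for those signatures.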