Skip to content

Commit

Permalink
Merge pull request #110 from PrefectHQ/typeformatter-works-with-conta…
Browse files Browse the repository at this point in the history
…iners
  • Loading branch information
jlowin authored Mar 30, 2023
2 parents b82755a + 433516e commit 77639f1
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 10 deletions.
Binary file added docs/img/heroes/dont_panic_center.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
36 changes: 27 additions & 9 deletions src/marvin/bots/response_formatters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import re
import warnings
from types import GenericAlias
from typing import Any, Literal

Expand All @@ -9,6 +10,7 @@
import marvin
from marvin.utilities.types import (
DiscriminatedUnionType,
LoggerMixin,
format_type_str,
genericalias_contains,
safe_issubclass,
Expand All @@ -17,7 +19,7 @@
SENTINEL = "__SENTINEL__"


class ResponseFormatter(DiscriminatedUnionType):
class ResponseFormatter(DiscriminatedUnionType, LoggerMixin):
format: str = Field(None, description="The format of the response")
on_error: Literal["reformat", "raise", "ignore"] = "reformat"

Expand Down Expand Up @@ -67,17 +69,31 @@ def __init__(self, type_: type = SENTINEL, **kwargs):
if not isinstance(type_, (type, GenericAlias)):
raise ValueError(f"Expected a type or GenericAlias, got {type_}")

# warn if the type is a set or tuple with GPT 3.5
if marvin.settings.openai_model_name.startswith("gpt-3.5"):
if safe_issubclass(type_, (set, tuple)) or genericalias_contains(
type_, (set, tuple)
):
warnings.warn(
(
"GPT-3.5 often fails with `set` or `tuple` types. Consider"
" using `list` instead."
),
UserWarning,
)

schema = marvin.utilities.types.type_to_schema(type_)

kwargs.update(
type_schema=schema,
format=(
"A valid JSON object that matches this simple type"
f" signature: ```{format_type_str(type_)}``` and equivalent OpenAI"
f" schema: ```{json.dumps(schema)}```. Make sure your response is"
" valid JSON, so use lists instead of sets or tuples; literal"
" `true` and `false` instead of `True` and `False`; literal `null`"
" instead of `None`; and double quotes instead of single quotes."
"A valid JSON object that satisfies this OpenAPI schema:"
f" ```{json.dumps(schema)}```. The JSON object will be coerced to"
f" the following type signature: ```{format_type_str(type_)}```."
" Make sure your response is valid JSON, which means you must use"
" lists instead of tuples or sets; literal `true` and `false`"
" instead of `True` and `False`; literal `null` instead of `None`;"
" and double quotes instead of single quotes."
),
)
super().__init__(**kwargs)
Expand All @@ -97,8 +113,10 @@ def get_type(self) -> type | GenericAlias:
def parse_response(self, response):
type_ = self.get_type()

# handle GenericAlias and containers
if isinstance(type_, GenericAlias):
# handle GenericAlias and containers like dicts
if isinstance(type_, GenericAlias) or safe_issubclass(
type_, (list, dict, set, tuple)
):
return pydantic.parse_raw_as(type_, response)

# handle basic types
Expand Down
8 changes: 7 additions & 1 deletion src/marvin/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,17 @@ def replace_class(generic_alias, old_class, new_class):

def genericalias_contains(genericalias, target_type):
"""
Explore whether a type or generic alias contains a target type.
Explore whether a type or generic alias contains a target type. The target
types can be a single type or a tuple of types.
Useful for seeing if a type contains a pydantic model, for example.
"""
if isinstance(target_type, tuple):
return any(genericalias_contains(genericalias, t) for t in target_type)

if isinstance(genericalias, GenericAlias):
if safe_issubclass(genericalias.__origin__, target_type):
return True
for arg in genericalias.__args__:
if genericalias_contains(arg, target_type):
return True
Expand Down
90 changes: 90 additions & 0 deletions tests/llm_tests/bots/test_ai_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

import marvin
import pydantic
import pytest
from marvin import ai_fn
from marvin.utilities.tests import assert_llm

Expand Down Expand Up @@ -130,6 +132,94 @@ def list_questions(email_body: str) -> list[str]:
assert x == ["What is your favorite color?"]


class TestContainers:
    """Check that ai_fn results come back as the annotated container type."""

    def test_dict(self):
        @ai_fn
        def dict_response() -> dict:
            """
            Returns a dictionary that contains
            - name: str
            - age: int
            """

        person = dict_response()
        assert isinstance(person, dict)
        assert isinstance(person["name"], str)
        assert isinstance(person["age"], int)

    def test_list(self):
        @ai_fn
        def list_response() -> list:
            """
            Returns a list that contains two numbers
            """

        numbers = list_response()
        assert isinstance(numbers, list)
        assert len(numbers) == 2
        # Exactly two items are guaranteed by the length check above.
        first, second = numbers
        assert isinstance(first, (int, float))
        assert isinstance(second, (int, float))

    def test_set(self):
        @ai_fn
        def set_response() -> set[int]:
            """
            Returns a set that contains two numbers, such as {3, 5}
            """

        if marvin.settings.openai_model_name.startswith("gpt-3.5"):
            # With a GPT-3.5 model a UserWarning is expected; only the
            # container type is checked because it's unclear what will be
            # in the set.
            with pytest.warns(UserWarning):
                numbers = set_response()
                assert isinstance(numbers, set)
            return

        numbers = set_response()
        assert isinstance(numbers, set)
        assert len(numbers) == 2
        # Sets are not indexable, so drain both members with pop().
        for _ in range(2):
            assert isinstance(numbers.pop(), (int, float))

    def test_tuple(self):
        @ai_fn
        def tuple_response() -> tuple:
            """
            Returns a tuple that contains two numbers
            """

        if marvin.settings.openai_model_name.startswith("gpt-3.5"):
            # With a GPT-3.5 model a UserWarning is expected; only the
            # container type is checked because it's unclear what will be
            # in the tuple.
            with pytest.warns(UserWarning):
                pair = tuple_response()
                assert isinstance(pair, tuple)
            return

        pair = tuple_response()
        assert isinstance(pair, tuple)
        assert len(pair) == 2
        first, second = pair
        assert isinstance(first, (int, float))
        assert isinstance(second, (int, float))

    def test_list_of_dicts(self):
        @ai_fn
        def list_of_dicts_response() -> list[dict]:
            """
            Returns a list of 2 dictionaries that each contain
            - name: str
            - age: int
            """

        people = list_of_dicts_response()
        assert isinstance(people, list)
        assert len(people) == 2
        # The length check above pins this loop to exactly two entries.
        for person in people:
            assert isinstance(person, dict)
            assert isinstance(person["name"], str)
            assert isinstance(person["age"], int)


class TestSet:
def test_set_response(self):
# https://github.com/PrefectHQ/marvin/issues/54
Expand Down

0 comments on commit 77639f1

Please sign in to comment.