Skip to content

Commit

Permalink
Cleaning up guidance functions for list generation
Browse files Browse the repository at this point in the history
Unifying the pattern for generating n items with a given regex/option list
  • Loading branch information
parkervg committed Oct 23, 2024
1 parent 3a9fc90 commit 809484b
Showing 1 changed file with 94 additions and 68 deletions.
162 changes: 94 additions & 68 deletions blendsql/ingredients/builtin/qa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
cast_responses_to_datatypes,
partialclass,
)
from blendsql._constants import ModifierType
from blendsql._constants import ModifierType, RegexPatterns
from .examples import QAExample, AnnotatedQAExample

MAIN_INSTRUCTION = "Answer the question given the table context.\n"
Expand All @@ -35,6 +35,68 @@
]


def get_modifier_wrapper(
    modifier: ModifierType,
) -> Callable[[guidance.models.Model], guidance.models.Model]:
    """Translate a regex-style repetition modifier into a guidance wrapper.

    Args:
        modifier: ``None`` (no repetition), ``"*"`` (zero or more),
            ``"+"`` (one or more), ``"{n}"`` (exactly ``n``) or
            ``"{m,n}"`` (between ``m`` and ``n``).

    Returns:
        A callable that takes a guidance grammar and returns it wrapped in
        the requested repetition. For ``None``, and for any string that is
        not one of the forms above, the grammar is returned unchanged.
    """
    if modifier is None:
        return lambda x: x
    if modifier == "*":
        return guidance.zero_or_more
    if modifier == "+":
        return guidance.one_or_more
    # Raw string avoids the invalid "\d" escape, and the pattern accepts
    # both "{n}" and "{m,n}" so the two-bound form is no longer ignored.
    bounds = re.fullmatch(r"\{(\d+)(?:,\s*(\d+))?\}", modifier)
    if bounds is None:
        return lambda x: x
    min_length = int(bounds.group(1))
    # A single number means "exactly n": use it as both bounds.
    upper = bounds.group(2)
    max_length = min_length if upper is None else int(upper)
    return lambda f: guidance.sequence(
        f, min_length=min_length, max_length=max_length
    )


@guidance(stateless=True)
def gen_list(
    lm, add_quotes: bool, modifier=None, options: Optional[List[str]] = None, regex=None
):
    """Append a bracketed, comma-separated list of generated items to ``lm``.

    Args:
        lm: The guidance model state to extend.
        add_quotes: If true, wrap every item in single quotes.
        modifier: Repetition modifier passed to ``get_modifier_wrapper`` to
            control how many items are produced.
        options: If non-empty, each item is selected from these values;
            otherwise each item is generated freely.
        regex: Pattern constraining a freely generated item. Defaults to one
            or more characters that are neither ``]`` nor ``,``.

    Returns:
        ``lm`` extended with ``[`` + the repeated item grammar + ``]``. Each
        item is captured under the name ``"response"`` with ``list_append``.
    """
    if options:
        single_item = guidance.select(options, list_append=True, name="response")
    else:
        single_item = guidance.gen(
            max_tokens=20,
            # The "+" quantifier is required: without it the default pattern
            # matches exactly one character per list item.
            regex=regex or "[^],]+",
            list_append=True,
            name="response",
        )
    if add_quotes:
        single_item = "'" + single_item + "'"
    # Separator is optional so the final item is not forced to end in ", ".
    single_item += guidance.optional(", ")
    return lm + "[" + get_modifier_wrapper(modifier)(single_item) + "]"


def get_option_aliases(options: Optional[List[str]], is_list_output: bool):
    """Expand ``options`` with spelling variants and map each variant back.

    For every option the title-case, lower-case and upper-case forms are
    added as aliases. If every option starts with a distinct first word, that
    first word is added as an alias too, so that e.g. 'Frank' can be aligned
    to 'Frank Smith'.

    Args:
        options: The allowed values, or ``None``. The input is never mutated.
        is_list_output: Accepted for interface compatibility; this function
            does not read it.

    Returns:
        A 2-tuple ``(options_with_aliases, options_alias_to_original)``.
        ``options_with_aliases`` holds the originals followed by their
        aliases; it is a ``set`` when ``options`` is a set and a ``list``
        otherwise. When ``options`` is ``None`` or empty it is returned as-is
        together with an empty mapping. ``options_alias_to_original`` maps
        each alias to the string form of the option it was derived from.
    """
    if not options:
        # Nothing to alias; hand the caller's value (None or empty) back.
        return options, {}

    # Coerce once, up front, so non-string options do not break `.split`.
    originals = [str(option) for option in options]
    original_set = set(originals)

    # Dict keys act as an insertion-ordered, de-duplicating accumulator that
    # works whether the caller passed a list or a set.
    aliased = dict.fromkeys(options)
    options_alias_to_original = {}

    # Sometimes the model generates 'Frank' instead of 'Frank Smith'; that is
    # only unambiguous when no two options share a first word.
    first_words = [option.split(" ")[0] for option in originals]
    add_first_word = len(set(first_words)) == len(originals)

    for option, first_word in zip(originals, first_words):
        candidates = [option.title(), option.lower(), option.upper()]
        if add_first_word:
            candidates.append(first_word)
        for option_alias in candidates:
            # Never let an alias shadow a different genuine option
            # (e.g. options 'a' and 'A' must each map to themselves).
            if option_alias != option and option_alias in original_set:
                continue
            aliased[option_alias] = None
            options_alias_to_original[option_alias] = option

    if isinstance(options, (set, frozenset)):
        return set(aliased), options_alias_to_original
    return list(aliased), options_alias_to_original


class QAProgram(Program):
def __call__(
self,
Expand All @@ -51,44 +113,23 @@ def __call__(
) -> Tuple[str, str]:
if isinstance(model, LocalModel):
lm: guidance.models.Model = model.model_obj
context_formatter(
current_example.context
) if current_example.context is not None else ""
is_list_output = (
modifier is not None
or current_example.output_type is not None
and "list" in current_example.output_type.lower()
)
options_alias_to_original = {}
# Resolve regex, if we haven't been passed one
if regex is None and current_example.output_type:
if "integer" in current_example.output_type:
regex = RegexPatterns.INTEGER
elif "boolean" in current_example.output_type:
regex = RegexPatterns.BOOLEAN
elif "float" in current_example.output_type:
regex = RegexPatterns.FLOAT
options = current_example.options
if options is not None:
# Since 'options' is a mutable list, create a copy to retain the originals
options_with_aliases = copy.deepcopy(options)
# Below we check to see if our options have a unique first word
# sometimes, the model will generate 'Frank' instead of 'Frank Smith'
# We still want to align that, in this case
add_first_word = False
if len(set([i.split(" ")[0] for i in options])) == len(options):
add_first_word = True
for option in options:
option = str(option)
for option_alias in [option.title(), option.lower(), option.upper()]:
options_with_aliases.add(option_alias)
options_alias_to_original[option_alias] = option
if is_list_output:
for option_alias in [
f"'{option}'",
f"'{option}', ",
f"{option}, ",
f"'{option}']",
f"{option}]",
]:
options_with_aliases.add(option_alias)
options_alias_to_original[option_alias] = option
if add_first_word:
option_alias = option.split(" ")[0]
options_alias_to_original[option_alias] = option
options_with_aliases.add(option_alias)
options_with_aliases, options_alias_to_original = get_option_aliases(
options, is_list_output=is_list_output
)
if isinstance(model, LocalModel):
with guidance.user():
lm += MAIN_INSTRUCTION
Expand All @@ -106,44 +147,27 @@ def __call__(
context_formatter, list_options=list_options_in_prompt
)
prompt = lm._current_prompt()
if options is not None:
gen_f = guidance.select(
if is_list_output:
lm += gen_list(
add_quotes=bool(
current_example.output_type
and "str" in current_example.output_type
),
options=options_with_aliases,
list_append=bool(modifier is not None),
name="response",
modifier=modifier,
)
else:
gen_f = guidance.gen(
max_tokens=max_tokens or 200,
regex=regex,
list_append=bool(modifier is not None),
name="response",
stop=["]", "\n"] if is_list_output else ["\n"],
)
# Parse the modifier arg
if modifier is not None:
if modifier == "*":
gen_f = guidance.zero_or_more(gen_f)
elif modifier == "+":
gen_f = guidance.one_or_more(gen_f)
elif re.match("{\d+}", modifier):
repeats = [
int(i)
for i in modifier.replace("}", "").replace("{", "").split(",")
]
if len(repeats) == 1:
repeats = repeats * 2
min_length, max_length = repeats
gen_f = guidance.sequence(
gen_f, min_length=min_length, max_length=max_length
)
if options:
lm += guidance.select(options=options, name="response")
else:
raise IngredientException(
f"Invalid modifier arg {modifier}\dValid values are '+', '*', or any string matching the '{{\d+}}' pattern"
lm += guidance.gen(
max_tokens=max_tokens or 200,
regex=regex,
list_append=bool(modifier is not None),
name="response",
stop=["\n"],
)
with guidance.assistant():
answer_prefix = "[" if is_list_output else ""
response = (lm + answer_prefix + gen_f)["response"]
response = lm["response"]
else:
messages = []
intro_prompt = MAIN_INSTRUCTION
Expand Down Expand Up @@ -190,9 +214,11 @@ def __call__(
else:
response = cast_responses_to_datatypes([response])[0]
# Map from modified options to original, as they appear in DB
if isinstance(response, str):
if not isinstance(response, (list, tuple, set)):
response = [response]
response: List[str] = [options_alias_to_original.get(r, r) for r in response]
response: List[str] = [
options_alias_to_original.get(str(r), r) for r in response
]
if len(response) == 1 and not is_list_output:
response = response[0]
if options and response not in options:
Expand Down

0 comments on commit 809484b

Please sign in to comment.