Skip to content

Commit

Permalink
output_type handling on LLMQA
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Oct 22, 2024
1 parent 6569a19 commit 8c5c80d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 26 deletions.
69 changes: 45 additions & 24 deletions blendsql/ingredients/builtin/qa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
]


def format_tuple(value: tuple) -> str:
    """Render the items of ``value`` as a parenthesised, comma-separated string.

    Each item is converted with ``repr`` and the pieces are joined with a bare
    comma (no space), e.g. ``("a", 1)`` becomes ``"('a',1)"``.

    Unlike ``repr(tuple)``, a single-item input has no trailing comma
    (``("a",)`` becomes ``"('a')"``) and an empty input yields ``"()"``.

    Args:
        value: The items to render. Any iterable works, since the items are
            only iterated once.

    Returns:
        The formatted string, always wrapped in parentheses.
    """
    # NOTE(review): quoting follows Python's ``repr`` rules, so a string that
    # contains a single quote is emitted in double quotes ("it's"). If this
    # output is embedded in SQL, confirm that is acceptable for the dialect.
    rendered_items = [repr(item) for item in value]
    return "(" + ",".join(rendered_items) + ")"


class QAProgram(Program):
def __call__(
self,
Expand Down Expand Up @@ -89,27 +93,31 @@ def __call__(
context_formatter, list_options=list_options_in_prompt
)
prompt = lm._current_prompt()
with guidance.assistant():
if options is not None:
core_f = guidance.select(
options=options_with_aliases,
list_append=bool(modifier is not None),
name="response",
)
else:
core_f = guidance.gen(
max_tokens=max_tokens or 10,
regex=regex,
list_append=bool(modifier is not None),
name="response",
)
if modifier == "*":
core_f = guidance.zero_or_more(core_f)
elif modifier == "+":
core_f = guidance.one_or_more(core_f)
response = (lm + guidance.capture(core_f, name="response"))._variables[
"response"
]
if options is not None:
gen_f = guidance.select(
options=options_with_aliases,
list_append=bool(modifier is not None),
name="response",
)
else:
gen_f = guidance.gen(
max_tokens=max_tokens or 200,
regex=regex,
list_append=bool(modifier is not None),
name="response",
)
if modifier == "*":
gen_f = guidance.zero_or_more(gen_f)
elif modifier == "+":
gen_f = guidance.one_or_more(gen_f)

@guidance(stateless=True, dedent=False)
def make_predictions(lm, gen_f) -> guidance.models.Model:
with guidance.assistant():
lm += guidance.capture(gen_f, name="response")
return lm

response = (lm + make_predictions(gen_f))["response"]
else:
messages = []
intro_prompt = MAIN_INSTRUCTION
Expand Down Expand Up @@ -157,14 +165,19 @@ def __call__(
response = cast_responses_to_datatypes([response])[0]
# Map from modified options to original, as they appear in DB
if isinstance(response, str):
response: str = options_alias_to_original.get(response, response)
response = [response]
response: List[str] = [options_alias_to_original.get(r, r) for r in response]
if len(response) == 1 and not modifier:
response = response[0]
if options and response not in options:
print(
Fore.RED
+ f"Model did not select from a valid option!\nExpected one of {options}, got '{response}'"
+ Fore.RESET
)
response = f"'{response}'"
else:
response = format_tuple(tuple(response))
return (response, prompt)


Expand Down Expand Up @@ -291,7 +304,7 @@ def run(
value_limit: Optional[int] = None,
long_answer: bool = False,
**kwargs,
) -> Union[str, int, float, tuple]:
) -> Union[str, int, float]:
"""
Args:
question: The question to map onto the values. Will also be the new column name
Expand Down Expand Up @@ -322,12 +335,20 @@ def run(
if context is not None:
if value_limit is not None:
context = context.iloc[:value_limit]
resolved_output_type = None
if modifier and output_type:
if not output_type.startswith("List"):
resolved_output_type = f"List[{output_type}]"
elif modifier:
resolved_output_type = "list"
else:
resolved_output_type = output_type
current_example = QAExample(
**{
"question": question,
"context": context,
"options": options,
"output_type": "list" if modifier and not output_type else output_type,
"output_type": resolved_output_type,
}
)
few_shot_examples: List[AnnotatedQAExample] = few_shot_retriever(
Expand Down
2 changes: 1 addition & 1 deletion blendsql/ingredients/ingredient.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def run(self, model: Model, context: pd.DataFrame, **kwargs) -> str:
'''

ingredient_type: str = IngredientType.QA.value
allowed_output_types: Tuple[Type] = (Union[str, int, float, tuple],)
allowed_output_types: Tuple[Type] = (Union[str, int, float],)

def __call__(
self,
Expand Down
4 changes: 3 additions & 1 deletion blendsql/parse/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ def create_regex(
):
output_type = "float" # Use 'float' as default numeric regex, since it's more expressive than 'integer'
if output_type is not None:
added_kwargs["output_type"] = output_type
added_kwargs["output_type"] = (
output_type if modifier is None else f"List[{output_type}]"
)
added_kwargs[IngredientKwarg.REGEX] = create_regex(output_type)
added_kwargs["modifier"] = modifier
return added_kwargs
Expand Down

0 comments on commit 8c5c80d

Please sign in to comment.