From 8c5c80d81ee1f0287b44f3fb1ffdd6d9d7fa9b30 Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 22 Oct 2024 16:08:35 -0400 Subject: [PATCH] output_type handling on LLMQA --- blendsql/ingredients/builtin/qa/main.py | 69 ++++++++++++++++--------- blendsql/ingredients/ingredient.py | 2 +- blendsql/parse/_parse.py | 4 +- 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/blendsql/ingredients/builtin/qa/main.py b/blendsql/ingredients/builtin/qa/main.py index c98c70e..aec6eb5 100644 --- a/blendsql/ingredients/builtin/qa/main.py +++ b/blendsql/ingredients/builtin/qa/main.py @@ -33,6 +33,10 @@ ] +def format_tuple(value): + return "(" + ",".join(repr(v) for v in value) + ")" + + class QAProgram(Program): def __call__( self, @@ -89,27 +93,31 @@ def __call__( context_formatter, list_options=list_options_in_prompt ) prompt = lm._current_prompt() - with guidance.assistant(): - if options is not None: - core_f = guidance.select( - options=options_with_aliases, - list_append=bool(modifier is not None), - name="response", - ) - else: - core_f = guidance.gen( - max_tokens=max_tokens or 10, - regex=regex, - list_append=bool(modifier is not None), - name="response", - ) - if modifier == "*": - core_f = guidance.zero_or_more(core_f) - elif modifier == "+": - core_f = guidance.one_or_more(core_f) - response = (lm + guidance.capture(core_f, name="response"))._variables[ - "response" - ] + if options is not None: + gen_f = guidance.select( + options=options_with_aliases, + list_append=bool(modifier is not None), + name="response", + ) + else: + gen_f = guidance.gen( + max_tokens=max_tokens or 200, + regex=regex, + list_append=bool(modifier is not None), + name="response", + ) + if modifier == "*": + gen_f = guidance.zero_or_more(gen_f) + elif modifier == "+": + gen_f = guidance.one_or_more(gen_f) + + @guidance(stateless=True, dedent=False) + def make_predictions(lm, gen_f) -> guidance.models.Model: + with guidance.assistant(): + lm += guidance.capture(gen_f, name="response") + return lm + + response = (lm + make_predictions(gen_f))["response"] else: messages = [] intro_prompt = MAIN_INSTRUCTION @@ -157,7 +165,10 @@ def __call__( response = cast_responses_to_datatypes([response])[0] # Map from modified options to original, as they appear in DB if isinstance(response, str): - response: str = options_alias_to_original.get(response, response) + response = [response] + response: List[str] = [options_alias_to_original.get(r, r) for r in response] + if len(response) == 1 and not modifier: + response = response[0] if options and response not in options: print( Fore.RED @@ -165,6 +176,8 @@ def __call__( + Fore.RESET ) response = f"'{response}'" + else: + response = format_tuple(tuple(response)) return (response, prompt) @@ -291,7 +304,7 @@ def run( value_limit: Optional[int] = None, long_answer: bool = False, **kwargs, - ) -> Union[str, int, float, tuple]: + ) -> Union[str, int, float]: """ Args: question: The question to map onto the values. Will also be the new column name @@ -322,12 +335,20 @@ def run( if context is not None: if value_limit is not None: context = context.iloc[:value_limit] + resolved_output_type = None + if modifier and output_type: + if not output_type.startswith("List"): + resolved_output_type = f"List[{output_type}]" + elif modifier: + resolved_output_type = "list" + else: + resolved_output_type = output_type current_example = QAExample( **{ "question": question, "context": context, "options": options, - "output_type": "list" if modifier and not output_type else output_type, + "output_type": resolved_output_type, } ) few_shot_examples: List[AnnotatedQAExample] = few_shot_retriever( diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index 41553bf..2c5e108 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -543,7 +543,7 @@ def run(self, model: Model, context: pd.DataFrame, **kwargs) -> str: ''' ingredient_type: str = IngredientType.QA.value - allowed_output_types: Tuple[Type] = (Union[str, int, float, tuple],) + allowed_output_types: Tuple[Type] = (Union[str, int, float],) def __call__( self, diff --git a/blendsql/parse/_parse.py b/blendsql/parse/_parse.py index d11f2e3..326a4d3 100644 --- a/blendsql/parse/_parse.py +++ b/blendsql/parse/_parse.py @@ -493,7 +493,9 @@ def create_regex( ): output_type = "float" # Use 'float' as default numeric regex, since it's more expressive than 'integer' if output_type is not None: - added_kwargs["output_type"] = output_type + added_kwargs["output_type"] = ( + output_type if modifier is None else f"List[{output_type}]" + ) added_kwargs[IngredientKwarg.REGEX] = create_regex(output_type) added_kwargs["modifier"] = modifier return added_kwargs