From 0db695827dedd28505f866342539ca0d32bf7cad Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 14 May 2024 14:52:54 -0400 Subject: [PATCH 01/21] Modifying _cfg_grammar.lark to use string.Template format --- blendsql/grammars/_cfg_grammar.lark | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/blendsql/grammars/_cfg_grammar.lark b/blendsql/grammars/_cfg_grammar.lark index 99ee5646..24c34d3f 100644 --- a/blendsql/grammars/_cfg_grammar.lark +++ b/blendsql/grammars/_cfg_grammar.lark @@ -66,22 +66,22 @@ JOIN_TYPE: "INNER"i | "FULL"i ["OUTER"i] | "LEFT"i["OUTER"i] | "RIGHT"i ["OUTER" | "CASE"i (when_then)+ "ELSE"i expression_math "END"i -> case_expression | "CAST"i "(" expression_math "AS"i TYPENAME ")" -> as_type | "CAST"i "(" literal "AS"i TYPENAME ")" -> literal_cast - | AGGREGATION_FUNCTIONS expression_math ")" [window_form] -> sql_aggregation + | AGGREGATE_FUNCTIONS expression_math ")" [window_form] -> sql_aggregation | SCALAR_FUNCTIONS [(expression_math ",")*] expression_math ")" -> sql_scalar | blendsql_aggregation_expr -> blendsql_aggregation | "RANK"i "(" ")" window_form -> rank_expression | "DENSE_RANK"i "(" ")" window_form -> dense_rank_expression | "|" "|" expression_math -BLENDSQL_AGGREGATION: ("LLMQA("i | "LLMVerify("i) -BLENDSQL_JOIN: ("LLMJoin("i) +BLENDSQL_AGGREGATE_FUNCTIONS: $blendsql_aggregate_functions +BLENDSQL_JOIN_FUNCTIONS: $blendsql_join_functions left_on_arg: "left_on" "=" string right_on_arg: "right_on" "=" string blendsql_arg: (name "=" literal | literal | "(" start ")") blendsql_expression_math: blendsql_arg ("," blendsql_arg)* -blendsql_aggregation_expr: "{{" BLENDSQL_AGGREGATION blendsql_expression_math ")" "}}" -blendsql_join_expr: "{{" BLENDSQL_JOIN (left_on_arg "," right_on_arg|right_on_arg "," left_on_arg) ")" "}}" +blendsql_aggregation_expr: "{{" BLENDSQL_AGGREGATE_FUNCTIONS blendsql_expression_math ")" "}}" +blendsql_join_expr: "{{" BLENDSQL_JOIN_FUNCTIONS (left_on_arg "," right_on_arg|right_on_arg "," left_on_arg) ")" "}}" window_form: "OVER"i "(" ["PARTITION"i "BY"i (expression_math ",")* expression_math] ["ORDER"i "BY"i (order ",")* order [ row_range_clause ] ] ")" @@ -125,7 +125,7 @@ TYPENAME: "object"i | "string"i // https://www.sqlite.org/lang_expr.html#*funcinexpr -AGGREGATION_FUNCTIONS: ("sum("i | "avg("i | "min("i | "max("i | "count("i ["distinct"i] ) +AGGREGATE_FUNCTIONS: ("sum("i | "avg("i | "min("i | "max("i | "count("i ["distinct"i] ) SCALAR_FUNCTIONS: ("trim("i | "coalesce("i | "abs("i) alias: string -> alias_string From 04cbac1cfc9b011c8cb2a2f862a027fad24230c4 Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 14 May 2024 14:53:29 -0400 Subject: [PATCH 02/21] load_cfg_parser --- blendsql/grammars/utils.py | 63 ++++++++++++++++++++++++++++++++++++++ blendsql/nl_to_blendsql.py | 37 ++++------------------ tests/test_grammar.py | 5 +-- 3 files changed, 72 insertions(+), 33 deletions(-) create mode 100644 blendsql/grammars/utils.py diff --git a/blendsql/grammars/utils.py b/blendsql/grammars/utils.py new file mode 100644 index 00000000..1c541c81 --- /dev/null +++ b/blendsql/grammars/utils.py @@ -0,0 +1,63 @@ +from pathlib import Path +from typing import Optional, Collection, List, Dict +from string import Template + +from ..ingredients import Ingredient +from .._constants import IngredientType +from .minEarley.parser import EarleyParser + + +def format_ingredient_names_to_lark(names: List[str]) -> str: + """Formats list of ingredient names the way our Lark grammar expects. + + Examples: + ```python + format_ingredient_names_to_lark(["LLMQA", "LLMVerify"]) + >>> '("LLMQA("i | "LLMVerify("i)' + ``` + """ + return "(" + " | ".join([f'"{n}("i' for n in names]) + ")" + + +def load_cfg_parser(ingredients: Optional[Collection[Ingredient]]) -> EarleyParser: + """Loads BlendSQL CFG parser. + Dynamically modifies grammar string to include only valid ingredients. + """ + with open(Path(__file__).parent / "./_cfg_grammar.lark", encoding="utf-8") as f: + cfg_grammar = Template(f.read()) + blendsql_join_functions = [] + blendsql_aggregate_functions = [] + ingredient_type_to_function_type: Dict[str, List[str]] = { + IngredientType.JOIN: blendsql_join_functions, + IngredientType.QA: blendsql_aggregate_functions, + } + for ingredient in ingredients: + if ingredient.ingredient_type not in ingredient_type_to_function_type: + print( + f"Not sure what to do with ingredient type '{ingredient.ingredient_type}'" + ) + continue + ingredient_type_to_function_type[ingredient.ingredient_type].append( + ingredient.__name__ + ) + cfg_grammar = cfg_grammar.substitute( + blendsql_join_functions=format_ingredient_names_to_lark( + blendsql_join_functions + ), + blendsql_aggregate_functions=format_ingredient_names_to_lark( + blendsql_aggregate_functions + ), + ) + return EarleyParser( + grammar=cfg_grammar, + start="start", + keep_all_tokens=True, + ) + + +if __name__ == "__main__": + from blendsql import LLMMap, LLMJoin, LLMValidate + + parser = load_cfg_parser({LLMMap, LLMJoin, LLMValidate}) + # parser.parse("{{LLMQA('what is the answer', (select * from w))}}") + print() diff --git a/blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql.py index 82510bd4..93d83e2f 100644 --- a/blendsql/nl_to_blendsql.py +++ b/blendsql/nl_to_blendsql.py @@ -2,19 +2,14 @@ from textwrap import dedent from guidance import gen, select from colorama import Fore -from pathlib import Path import logging from .ingredients import Ingredient, IngredientException from .models import Model from ._program import Program from .grammars.minEarley.parser import EarleyParser +from .grammars.utils import load_cfg_parser -CFG_PARSER = EarleyParser.open( - Path(__file__).parent / "./grammars/_cfg_grammar.lark", - start="start", - keep_all_tokens=True, -) PARSER_STOP_TOKENS = ["---", ";", "\n\n"] PARSER_SYSTEM_PROMPT = dedent( """ @@ -145,6 +140,7 @@ def nl_to_blendsql( logging.getLogger().setLevel(logging.DEBUG) else: logging.getLogger().setLevel(logging.ERROR) + parser: EarleyParser = load_cfg_parser(ingredients) system_prompt: str = create_system_prompt( ingredients=ingredients, few_shot_examples=few_shot_examples ) @@ -174,13 +170,13 @@ def nl_to_blendsql( partial_program_prediction + " " + residual_program_prediction ) - if validate_program(program_prediction, CFG_PARSER): + if validate_program(program_prediction, parser): ret_prediction = program_prediction continue # find the max score from a list of score prefix, candidates, pos_in_stream = obtain_correction_pairs( - program_prediction, CFG_PARSER + program_prediction, parser ) candidates = [i for i in candidates if i.strip() != ""] if len(candidates) == 0: @@ -204,13 +200,13 @@ def nl_to_blendsql( inserted_candidate = ( prefix + selected_candidate + program_prediction[pos_in_stream:] ) - if validate_program(inserted_candidate, CFG_PARSER): + if validate_program(inserted_candidate, parser): ret_prediction = inserted_candidate continue # 2) If rest of our query is also broken, we just keep up to the prefix + candidate partial_program_prediction = prefix + selected_candidate for p in {inserted_candidate, partial_program_prediction}: - if validate_program(p, CFG_PARSER): + if validate_program(p, parser): ret_prediction = p num_correction_left -= 1 @@ -224,24 +220,3 @@ def nl_to_blendsql( ret_prediction = initial_prediction logging.debug(Fore.GREEN + ret_prediction + Fore.RESET) return ret_prediction - - -if __name__ == "__main__": - from blendsql import nl_to_blendsql, LLMMap - from blendsql.models import TransformersLLM - from blendsql.db import SQLite - from blendsql.utils import fetch_from_hub - - model = TransformersLLM("Qwen/Qwen1.5-0.5B") - db = SQLite( - fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") - ) - - query = nl_to_blendsql( - question="Which venues in Sydney saw more than 30 points scored?", - model=model, - ingredients={LLMMap}, - serialized_db=db.to_serialized(num_rows=3, use_tables=["w", "documents"]), - verbose=True, - max_grammar_corrections=5, - ) diff --git a/tests/test_grammar.py b/tests/test_grammar.py index 2914b65a..fb7f2b0e 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -1,11 +1,12 @@ import pytest -from blendsql.grammars.minEarley.parser import EarleyParser +from blendsql import LLMQA, LLMJoin, LLMMap +from blendsql.grammars.utils import load_cfg_parser, EarleyParser from blendsql.grammars.minEarley.earley_exceptions import UnexpectedInput @pytest.fixture(scope="session") def parser() -> EarleyParser: - return EarleyParser.open("./blendsql/grammars/_cfg_grammar.lark", start="start") + return load_cfg_parser({LLMQA, LLMJoin, LLMMap}) accept_queries = [ From 294d86efaef49ef2849e5801fb201016055ca538 Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 14 May 2024 14:53:51 -0400 Subject: [PATCH 03/21] IngredientType enum switch to strings Better debugging messages --- blendsql/_constants.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/blendsql/_constants.py b/blendsql/_constants.py index cc46da60..45273bcc 100644 --- a/blendsql/_constants.py +++ b/blendsql/_constants.py @@ -1,4 +1,4 @@ -from enum import Enum, EnumMeta, auto +from enum import Enum, EnumMeta from dataclasses import dataclass HF_REPO_ID = "parkervg/blendsql-test-dbs" @@ -15,10 +15,10 @@ def __contains__(cls, item): class IngredientType(str, Enum, metaclass=StrInMeta): - MAP = auto() - STRING = auto() - QA = auto() - JOIN = auto() + MAP = "MAP" + STRING = "STRING" + QA = "QA" + JOIN = "JOIN" @dataclass From c2496442e2b09d6bbc3fa1401b80479f6b342a83 Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 14 May 2024 17:15:36 -0400 Subject: [PATCH 04/21] pandas >= 2.0.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b956041..43f28ae9 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def find_version(*file_paths): install_requires=[ "guidance>=0.1.0", "pyparsing==3.1.1", - "pandas==1.5.3", + "pandas>=2.0.0", "bottleneck>=1.3.6", "python-dotenv==1.0.1", "sqlglot==18.13.0", From 5c452686bd9438e519c55216ae4838ef536c7ee2 Mon Sep 17 00:00:00 2001 From: parkervg Date: Tue, 14 May 2024 17:16:51 -0400 Subject: [PATCH 05/21] Optimization of JoinIngredient logic by sorting values first On the 1966_NBA_Expansion_Draft_0 test db: Without cached LLM response (10 runs): before: 3.16 after: 1.91 With cached LLM response (100 runs): before: 0.0175 after: 0.0166 --- blendsql/ingredients/ingredient.py | 54 ++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index 1d4db4bc..670b02e3 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -211,10 +211,12 @@ def __call__( get_temp_subquery_table: Callable = kwargs.get("get_temp_subquery_table") get_temp_session_table: Callable = kwargs.get("get_temp_session_table") + # Depending on the size of the underlying data, it may be optimal to swap + # the order of 'left_on' and 'right_on' columns during processing + swapped = False values = [] original_lr_identifiers = [] modified_lr_identifiers = [] - left_values, right_values = [], [] mapping = {} for on_arg in [left_on, right_on]: tablename, colname = utils.get_tablename_colname(on_arg) @@ -225,30 +227,41 @@ def __call__( tablename=tablename, ) values.append( - set( - [ - str(i) - for i in self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{tablename}"' - )[colname].tolist() - ] - ) + [ + str(i) + for i in self.db.execute_query( + f'SELECT DISTINCT "{colname}" FROM "{tablename}"' + )[colname] + ] ) modified_lr_identifiers.append((tablename, colname)) if question is None: # First, check which values we actually need to call Model on # We don't want to join when there's already an intuitive alignment + # First, make sure outer loop is shorter of the two lists + outer, inner = sorted(values, key=len) + _outer = [] + inner = set(inner) mapping = {} - left_values, right_values = values - for l in left_values: - if l in right_values: + for l in outer: + if l in inner: # Define this mapping, and remove from Model inference call mapping[l] = l + inner -= {l} + else: + _outer.append(l) + if len(inner) == 0: + break + to_compare = [inner, _outer] + else: + to_compare = values - processed_values = set(list(mapping.keys())) - left_values = left_values.difference(processed_values) - right_values = right_values.difference(processed_values) + # Finally, order by new (remaining) length and check if we swapped places from original + sorted_values = sorted(to_compare, key=len) + if sorted_values != values: + swapped = True + left_values, right_values = sorted_values kwargs["left_values"] = left_values kwargs["right_values"] = right_values @@ -274,16 +287,21 @@ def __call__( # Using mapped left/right values, create intermediary mapping table temp_join_tablename = get_temp_session_table(str(uuid.uuid4())[:4]) + # Below, we check to see if 'swapped' is True + # If so, we need to inverse what is 'left', and what is 'right' joined_values_df = pd.DataFrame( - data={"left": mapping.keys(), "right": mapping.values()} + data={ + "left" if not swapped else "right": mapping.keys(), + "right" if not swapped else "left": mapping.values(), + } ) self.db.to_temp_table(df=joined_values_df, tablename=temp_join_tablename) return ( left_tablename, right_tablename, f"""JOIN "{temp_join_tablename}" ON "{right_tablename}"."{right_colname}" = "{temp_join_tablename}".right - JOIN "{left_tablename}" ON "{left_tablename}"."{left_colname}" = "{temp_join_tablename}".left - """, + JOIN "{left_tablename}" ON "{left_tablename}"."{left_colname}" = "{temp_join_tablename}".left + """, temp_join_tablename, ) From 175ee2132568df79b7053c29935949184b2892ac Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 16 May 2024 14:51:45 -0400 Subject: [PATCH 06/21] Removing 'align_to_real_columns' --- blendsql/ingredients/ingredient.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index 670b02e3..f1042027 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -33,17 +33,6 @@ def unpack_default_kwargs(**kwargs): ) -def align_to_real_columns(db: Database, colname: str, tablename: str) -> str: - table_columns = db.execute_query(f'SELECT * FROM "{tablename}" LIMIT 1').columns - if colname not in table_columns: - # Try to align with column, according to some normalization rules - cleaned_to_original = { - col.replace("\\n", " ").replace("\xa0", " "): col for col in table_columns - } - colname = cleaned_to_original[colname] - return colname - - @attrs class Ingredient(ABC): name: str = attrib() From 0c692c8b33580e98fc07f4d84889f08c5bc64645 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 16 May 2024 14:52:24 -0400 Subject: [PATCH 07/21] 'execute_to_df' vs 'execute_to_list' distinction --- blendsql/blend.py | 16 +++---- blendsql/db/_database.py | 15 ++++++- blendsql/db/_postgres.py | 2 +- blendsql/db/_sqlite.py | 4 +- blendsql/db/utils.py | 4 ++ blendsql/ingredients/ingredient.py | 46 +++++++++---------- research/utils/bridge_content_encoder.py | 2 +- research/utils/database.py | 2 +- research/utils/ottqa/ottqa.py | 2 +- tests/test_multi_table_blendsql.py | 42 +++++++++--------- tests/test_single_table_blendsql.py | 56 ++++++++++++------------ 11 files changed, 100 insertions(+), 91 deletions(-) diff --git a/blendsql/blend.py b/blendsql/blend.py index 31c756a8..4a0187be 100644 --- a/blendsql/blend.py +++ b/blendsql/blend.py @@ -32,7 +32,7 @@ get_tablename_colname, ) from .db import Database -from .db.utils import double_quote_escape, single_quote_escape +from .db.utils import double_quote_escape, select_all_from_table_query from ._sqlglot import ( MODIFIERS, get_first_child, @@ -442,7 +442,7 @@ def _blend( # If we don't have any ingredient calls, execute as normal SQL if len(ingredients) == 0 or len(ingredient_alias_to_parsed_dict) == 0: return Smoothie( - df=db.execute_query(query), + df=db.execute_to_df(query), meta=SmoothieMeta( num_values_passed=0, num_prompt_tokens=0, @@ -541,7 +541,7 @@ def _blend( ) try: db.to_temp_table( - df=db.execute_query(abstracted_query), + df=db.execute_to_df(abstracted_query), tablename=_get_temp_subquery_table(tablename), ) except OperationalError as e: @@ -741,14 +741,14 @@ def _blend( # On their left join merge command: https://github.com/HKUNLP/Binder/blob/9eede69186ef3f621d2a50572e1696bc418c0e77/nsql/database.py#L196 # We create a new temp table to avoid a potentially self-destructive operation base_tablename = tablename - _base_table: pd.DataFrame = db.execute_query( - f'SELECT * FROM "{double_quote_escape(base_tablename)}";' + _base_table: pd.DataFrame = db.execute_to_df( + select_all_from_table_query(base_tablename) ) base_table = _base_table if db.has_temp_table(_get_temp_session_table(tablename)): base_tablename = _get_temp_session_table(tablename) - base_table: pd.DataFrame = db.execute_query( - f"SELECT * FROM '{single_quote_escape(base_tablename)}';", + base_table: pd.DataFrame = db.execute_to_df( + select_all_from_table_query(base_tablename) ) previously_added_columns = base_table.columns.difference( _base_table.columns @@ -804,7 +804,7 @@ def _blend( ) logging.debug("") - df = db.execute_query(query) + df = db.execute_to_df(query) return Smoothie( df=df, diff --git a/blendsql/db/_database.py b/blendsql/db/_database.py index cbcb9c35..584ff56f 100644 --- a/blendsql/db/_database.py +++ b/blendsql/db/_database.py @@ -1,4 +1,4 @@ -from typing import Generator, List, Dict, Collection +from typing import Generator, List, Dict, Collection, Type, Optional from typing import Iterable import pandas as pd from colorama import Fore @@ -86,7 +86,7 @@ def to_temp_table(self, df: pd.DataFrame, tablename: str): self.con.execute(text(create_table_stmt)) df.to_sql(name=tablename, con=self.con, if_exists="append", index=False) - def execute_query(self, query: str, params: dict = None) -> pd.DataFrame: + def execute_to_df(self, query: str, params: dict = None) -> pd.DataFrame: """ Execute the given query and return results as dataframe. @@ -106,3 +106,14 @@ def execute_query(self, query: str, params: dict = None) -> pd.DataFrame: ``` """ return pd.read_sql(text(query), self.con, params=params) + + def execute_to_list( + self, query: str, to_type: Optional[Type] = lambda x: x + ) -> list: + """A lower-level execute method that doesn't use the pandas processing logic. + Returns results as a tuple. + """ + res = [] + for row in self.con.execute(text(query)).fetchall(): + res.append(to_type(row[0])) + return res diff --git a/blendsql/db/_postgres.py b/blendsql/db/_postgres.py index 323d4131..779eaca8 100644 --- a/blendsql/db/_postgres.py +++ b/blendsql/db/_postgres.py @@ -36,7 +36,7 @@ def __init__(self, db_path: str): def has_temp_table(self, tablename: str) -> bool: return ( tablename - in self.execute_query( + in self.execute_to_df( "SELECT * FROM information_schema.tables WHERE table_schema LIKE 'pg_temp_%'" )["table_name"].unique() ) diff --git a/blendsql/db/_sqlite.py b/blendsql/db/_sqlite.py index 767d9fe3..a1077db3 100644 --- a/blendsql/db/_sqlite.py +++ b/blendsql/db/_sqlite.py @@ -23,7 +23,7 @@ def __init__(self, db_url: str): def has_temp_table(self, tablename: str) -> bool: return ( tablename - in self.execute_query( + in self.execute_to_df( "SELECT name FROM sqlite_temp_master WHERE type='table';" )["name"].unique() ) @@ -39,7 +39,7 @@ def get_sqlglot_schema(self) -> dict: schema = {} for tablename in self.tables(): schema[f'"{double_quote_escape(tablename)}"'] = {} - for _, row in self.execute_query( + for _, row in self.execute_to_df( f""" SELECT name, type FROM pragma_table_info(:t) """, diff --git a/blendsql/db/utils.py b/blendsql/db/utils.py index 2a0591d0..87dccec6 100644 --- a/blendsql/db/utils.py +++ b/blendsql/db/utils.py @@ -14,6 +14,10 @@ def escape(s): return single_quote_escape(double_quote_escape(s)) +def select_all_from_table_query(tablename: str) -> str: + return f'SELECT * FROM "{double_quote_escape(tablename)}";' + + def truncate_df_content(df: pd.DataFrame, truncation_limit: int) -> pd.DataFrame: # Truncate long strings return df.applymap( diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index f1042027..4438d8f1 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -19,7 +19,7 @@ from .. import utils from .._constants import IngredientKwarg, IngredientType from ..db import Database -from ..db.utils import double_quote_escape +from ..db.utils import select_all_from_table_query class IngredientException(ValueError): @@ -107,36 +107,33 @@ def __call__(self, question: str, context: str, *args, **kwargs) -> tuple: ): new_arg_column = "_" + new_arg_column - original_table = self.db.execute_query(f'SELECT * FROM "{tablename}"') + original_table = self.db.execute_to_df(select_all_from_table_query(tablename)) # Get a list of values to map # First, check if we've already dumped some `MapIngredient` output to the main session table if temp_session_table_exists: - temp_session_table = self.db.execute_query( - f'SELECT * FROM "{double_quote_escape(temp_session_tablename)}"' - ) - colname = align_to_real_columns( - db=self.db, colname=colname, tablename=temp_session_tablename + temp_session_table = self.db.execute_to_df( + select_all_from_table_query(temp_session_tablename) ) # We don't need to run this function on everything, # if a previous subquery already got to certain values if new_arg_column in temp_session_table.columns: - values = self.db.execute_query( + values = self.db.execute_to_list( f'SELECT DISTINCT "{colname}" FROM "{temp_session_tablename}" WHERE "{new_arg_column}" IS NULL', - )[colname].tolist() + to_type=str, + ) # Base case: this is the first time we've used this particular ingredient # BUT, temp_session_tablename still exists else: - values = self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{temp_session_tablename}"' - )[colname].tolist() + values = self.db.execute_to_list( + f'SELECT DISTINCT "{colname}" FROM "{temp_session_tablename}"', + to_type=str, + ) else: - colname = align_to_real_columns( - db=self.db, colname=colname, tablename=value_source_tablename + values = self.db.execute_to_list( + f'SELECT DISTINCT "{colname}" FROM "{value_source_tablename}"', + to_type=str, ) - values = self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{value_source_tablename}"' - )[colname].tolist() # No need to run ingredient if we have no values to map onto if len(values) == 0: @@ -216,12 +213,9 @@ def __call__( tablename=tablename, ) values.append( - [ - str(i) - for i in self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{tablename}"' - )[colname] - ] + self.db.execute_to_list( + f'SELECT DISTINCT "{colname}" FROM "{tablename}"', to_type=str + ) ) modified_lr_identifiers.append((tablename, colname)) @@ -320,7 +314,7 @@ def __call__( if context is not None: if isinstance(context, str): tablename, colname = utils.get_tablename_colname(context) - subtable = self.db.execute_query( + subtable = self.db.execute_to_df( f'SELECT "{colname}" FROM "{tablename}"' ) elif not isinstance(context, pd.DataFrame): @@ -335,9 +329,9 @@ def __call__( try: tablename, colname = utils.get_tablename_colname(options) tablename = aliases_to_tablenames.get(tablename, tablename) - unpacked_options = self.db.execute_query( + unpacked_options = self.db.execute_to_list( f'SELECT DISTINCT "{colname}" FROM "{tablename}"' - )[colname].tolist() + ) except ValueError: unpacked_options = options.split(";") unpacked_options = set(unpacked_options) diff --git a/research/utils/bridge_content_encoder.py b/research/utils/bridge_content_encoder.py index c4c68494..3005d379 100644 --- a/research/utils/bridge_content_encoder.py +++ b/research/utils/bridge_content_encoder.py @@ -245,7 +245,7 @@ def get_column_picklist_with_db(table_name: str, column_name: str, db) -> list: fetch_sql = 'SELECT DISTINCT `{}` FROM "{}"'.format( column_name, double_quote_escape(table_name) ) - picklist = set(db.execute_query(fetch_sql).values.flat) + picklist = set(db.execute_to_df(fetch_sql).values.flat) picklist = list(picklist) cache[key] = picklist return picklist diff --git a/research/utils/database.py b/research/utils/database.py index 85cc6fbc..e16d0a3e 100644 --- a/research/utils/database.py +++ b/research/utils/database.py @@ -50,7 +50,7 @@ def to_serialized( else: serialized_db.append(f"{num_rows} example rows:") serialized_db.append(f"{get_rows_query}") - rows = db.execute_query(get_rows_query) + rows = db.execute_to_df(get_rows_query) if truncate_content is not None: # Truncate long strings rows = rows.map( diff --git a/research/utils/ottqa/ottqa.py b/research/utils/ottqa/ottqa.py index e582e999..0bf27154 100644 --- a/research/utils/ottqa/ottqa.py +++ b/research/utils/ottqa/ottqa.py @@ -69,7 +69,7 @@ def ottqa_get_input( model_args: ModelArguments, ) -> Tuple[str, dict]: if "docs_tablesize" not in cache: - cache["docs_tablesize"] = db.execute_query( + cache["docs_tablesize"] = db.execute_to_df( f"SELECT COUNT(*) FROM {DOCS_TABLE_NAME}" ).values[0][0] cache["docs_tablesize"] diff --git a/tests/test_multi_table_blendsql.py b/tests/test_multi_table_blendsql.py index d04f6123..845e0870 100644 --- a/tests/test_multi_table_blendsql.py +++ b/tests/test_multi_table_blendsql.py @@ -62,10 +62,10 @@ def test_simple_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE Symbol in ( @@ -99,10 +99,10 @@ def test_join_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT Name) FROM constituents WHERE Sector = 'Information Technology' """ @@ -136,10 +136,10 @@ def test_join_not_qualified_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT Name) FROM constituents WHERE Sector = 'Information Technology' """ @@ -170,7 +170,7 @@ def test_select_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -198,7 +198,7 @@ def test_complex_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -229,7 +229,7 @@ def test_complex_not_qualified_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -251,7 +251,7 @@ def test_join_ingredient_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -269,7 +269,7 @@ def test_qa_equals_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -289,10 +289,10 @@ def test_table_alias_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 """ @@ -320,10 +320,10 @@ def test_subquery_alias_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 """ @@ -355,7 +355,7 @@ def test_cte_qa_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient # passed_to_ingredient = db.execute_query( @@ -390,7 +390,7 @@ def test_cte_qa_named_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient # passed_to_ingredient = db.execute_query( @@ -422,10 +422,10 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol @@ -455,10 +455,10 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients) db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol diff --git a/tests/test_single_table_blendsql.py b/tests/test_single_table_blendsql.py index 129db9ab..7ea4f8bb 100644 --- a/tests/test_single_table_blendsql.py +++ b/tests/test_single_table_blendsql.py @@ -68,7 +68,7 @@ def test_simple_ingredient_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) @@ -84,7 +84,7 @@ def test_simple_ingredient_exec_at_end(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) @@ -98,7 +98,7 @@ def test_simple_ingredient_exec_in_select(db, ingredients): ingredients=ingredients, ) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE parent_category = 'Auto & Transport' """ @@ -128,10 +128,10 @@ def test_nested_ingredient_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 100 """ @@ -161,10 +161,10 @@ def test_nonexistent_column_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'this does not exist' """ @@ -194,10 +194,10 @@ def test_nested_and_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["O"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ @@ -229,10 +229,10 @@ def test_multiple_nested_ingredients(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A", "T"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) + COUNT(DISTINCT child_category) FROM transactions WHERE parent_category = 'Food' """ @@ -256,10 +256,10 @@ def test_length_ingredient(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ @@ -283,10 +283,10 @@ def test_max_length(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ @@ -312,10 +312,10 @@ def test_limit(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ @@ -337,7 +337,7 @@ def test_select(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -353,10 +353,10 @@ def test_ingredient_in_select_stmt(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ @@ -377,10 +377,10 @@ def test_ingredient_in_select_stmt_with_filter(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ @@ -400,10 +400,10 @@ def test_nested_duplicate_map_calls(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ @@ -433,10 +433,10 @@ def test_many_duplicate_map_calls(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT (SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 1300) @@ -472,10 +472,10 @@ def test_exists_isolated_qa_call(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_df( """ SELECT (SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 500) + (SELECT COUNT(*) FROM transactions WHERE amount < 500) """ From f8b55500562178c83829667e13f6b249f6b20270 Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 16 May 2024 14:54:12 -0400 Subject: [PATCH 08/21] run-debug join test --- research/run-debug.py | 35 --------------------- run-debug.py | 71 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 35 deletions(-) delete mode 100644 research/run-debug.py create mode 100644 run-debug.py diff --git a/research/run-debug.py b/research/run-debug.py deleted file mode 100644 index 6f90a5aa..00000000 --- a/research/run-debug.py +++ /dev/null @@ -1,35 +0,0 @@ -from blendsql import blend, LLMJoin -from blendsql.db import SQLite, PostgreSQL -from blendsql.models import OpenaiLLM -from blendsql.utils import fetch_from_hub, tabulate - -if __name__ == "__main__": - blendsql = """ - SELECT date, rival, score, documents.content AS "Team Description" FROM w - JOIN {{ - LLMJoin( - left_on='documents::title', - right_on='w::rival' - ) - }} WHERE rival = 'nsw waratahs' - """ - # Make our smoothie - the executed BlendSQL script - smoothie = blend( - query=blendsql, - db=SQLite( - fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") - ), - blender=OpenaiLLM("gpt-3.5-turbo"), - ingredients={LLMJoin}, - ) - print(tabulate(smoothie.df)) - - smoothie = blend( - query=blendsql, - db=PostgreSQL( - "blendsql@localhost:5432/1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1" - ), - blender=OpenaiLLM("gpt-3.5-turbo"), - ingredients={LLMJoin}, - ) - print(tabulate(smoothie.df)) diff --git a/run-debug.py b/run-debug.py new file mode 100644 index 00000000..633391da --- /dev/null +++ b/run-debug.py @@ -0,0 +1,71 @@ +from blendsql import blend, LLMJoin, LLMMap, LLMQA +from blendsql.db import SQLite +from blendsql.models import TransformersLLM +from blendsql.utils import fetch_from_hub +from tqdm import tqdm + +# TEST_QUERIES = [ +# """ +# SELECT DISTINCT venue FROM w +# WHERE city = 'sydney' AND {{ +# LLMMap( +# 'More than 30 total points?', +# 'w::score' +# ) +# }} = TRUE +# """, +# """ +# SELECT * FROM w +# WHERE city = {{ +# LLMQA( +# 'Which city is located 120 miles west of Sydney?', +# (SELECT * FROM documents WHERE documents MATCH 'sydney OR 120'), +# options='w::city' +# ) +# }} +# """, +# """ +# SELECT date, rival, score, documents.content AS "Team Description" FROM w +# JOIN {{ +# LLMJoin( +# left_on='documents::title', +# right_on='w::rival' +# ) +# }} +# """ +# ] + +TEST_QUERIES = [ + """ + SELECT title, player FROM w JOIN {{ + LLMJoin( + left_on='documents::title', + right_on='w::player' + ) + }} + """ +] +if __name__ == "__main__": + """ + Without cached LLM response (10 runs): + before: 3.16 + after: 1.91 + With cached LLM response (100 runs): + before: 0.0175 + after: 0.0166 + """ + times = [] + db = SQLite(fetch_from_hub("1966_NBA_Expansion_Draft_0.db")) + for _ in tqdm(range(100)): + for q in TEST_QUERIES: + # Make our smoothie - the executed BlendSQL script + smoothie = blend( + query=q, + db=db, + blender=TransformersLLM("Qwen/Qwen1.5-0.5B", caching=True), + verbose=False, + ingredients={LLMJoin, LLMMap, LLMQA}, + ) + times.append(smoothie.meta.process_time_seconds) + + print(f"Average time across {len(times)} runs: {sum(times) / len(times)}") From e4d18f48e5443a664e2317b7053f8304f3dbfbc7 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 11:50:38 -0400 Subject: [PATCH 09/21] Fixed some broken test logic --- tests/test_multi_table_blendsql.py | 110 +++++++++++++++++----------- tests/test_single_table_blendsql.py | 7 ++ 2 files changed, 73 insertions(+), 44 deletions(-) diff --git a/tests/test_multi_table_blendsql.py b/tests/test_multi_table_blendsql.py index 845e0870..4620dcd3 100644 --- a/tests/test_multi_table_blendsql.py +++ b/tests/test_multi_table_blendsql.py @@ -358,12 +358,24 @@ def test_cte_qa_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - # passed_to_ingredient = db.execute_query( - # """ - # SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 - # """ - # ) - # assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + passed_to_map_ingredient = db.execute_to_list( + """ + SELECT COUNT(DISTINCT Symbol) FROM portfolio + """ + )[0] + # We also need to factor in what we passed to QA ingredient + passed_to_qa_ingredient = db.execute_to_list( + """ + WITH a AS ( + SELECT * FROM (SELECT DISTINCT * FROM portfolio) as w + WHERE w.Symbol LIKE 'F%' + ) SELECT COUNT(*) FROM a WHERE LENGTH(a.Symbol) > 2 + """ + )[0] + assert ( + smoothie.meta.num_values_passed + == passed_to_qa_ingredient + passed_to_map_ingredient + ) def test_cte_qa_named_multi_exec(db, ingredients): @@ -393,12 +405,22 @@ def test_cte_qa_named_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - # passed_to_ingredient = db.execute_query( - # """ - # SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 - # """ - # ) - # assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + passed_to_map_ingredient = db.execute_to_list( + """ + SELECT COUNT(DISTINCT Symbol) FROM portfolio + """ + )[0] + passed_to_qa_ingredient = db.execute_to_list( + """ + WITH a AS ( + SELECT * FROM (SELECT DISTINCT * FROM portfolio) as w + WHERE w.Symbol LIKE 'F%' + ) SELECT COUNT(*) FROM a WHERE LENGTH(a.Symbol) > 2 + """ + )[0] + assert smoothie.meta.num_values_passed == ( + passed_to_map_ingredient + passed_to_qa_ingredient + ) def test_ingredient_in_select_with_join_multi_exec(db, ingredients): @@ -468,35 +490,35 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients) assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() -# def test_subquery_alias_with_join_multi_exec(db, ingredients): -# blendsql = """ -# SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w -# JOIN {{ -# do_join( -# left_on='w::Symbol', -# right_on='geographic::Symbol' -# ) -# }} WHERE {{starts_with('F', 'w::Symbol')}} -# AND w."Percent of Account" < 0.2 -# """ -# -# sql = """ -# SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w -# JOIN geographic ON w.Symbol = geographic.Symbol -# WHERE w.Symbol LIKE 'F%' -# AND w."Percent of Account" < 0.2 -# """ -# smoothie = blend( -# query=blendsql, -# db=db, -# ingredients=ingredients, -# ) -# sql_df = db.execute_query(sql) -# assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) -# # Make sure we only pass what's necessary to our ingredient -# # passed_to_ingredient = db.execute_query( -# # """ -# # SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 -# # """ -# # ) -# # assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() +def test_subquery_alias_with_join_multi_exec(db, ingredients): + blendsql = """ + SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w + JOIN {{ + do_join( + left_on='geographic::Symbol', + right_on='w::Symbol' + ) + }} WHERE {{starts_with('F', 'w::Symbol')}} + AND w."Percent of Account" < 0.2 + """ + + sql = """ + SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w + JOIN geographic ON w.Symbol = geographic.Symbol + WHERE w.Symbol LIKE 'F%' + AND w."Percent of Account" < 0.2 + """ + smoothie = blend( + query=blendsql, + db=db, + ingredients=ingredients, + ) + sql_df = db.execute_to_df(sql) + assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) + # Make sure we only pass what's necessary to our ingredient + passed_to_ingredient = db.execute_to_df( + """ + SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE (Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) AND "Percent of Account" < 0.2 + """ + ) + assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() diff --git a/tests/test_single_table_blendsql.py b/tests/test_single_table_blendsql.py index 7ea4f8bb..0a11f361 100644 --- a/tests/test_single_table_blendsql.py +++ b/tests/test_single_table_blendsql.py @@ -140,6 +140,13 @@ def test_nested_ingredient_exec(db, ingredients): def test_nonexistent_column_exec(db, ingredients): + """ + NOTE: Converting to CNF would break this one + since we would get: + SELECT DISTINCT merchant, child_category FROM transactions WHERE + (child_category = 'Gifts' OR STRUCT(STRUCT(A())) = 1) AND + (child_category = 'Gifts' OR child_category = 'this does not exist') + """ blendsql = """ SELECT DISTINCT merchant, child_category FROM transactions WHERE ( From 7527e194bfceb479f344c058dc69ef0428b7d6fe Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 11:55:30 -0400 Subject: [PATCH 10/21] Fixed some broken test logic --- tests/test_multi_table_blendsql.py | 48 +++++++++--------- tests/test_single_table_blendsql.py | 78 ++++++++++++++--------------- 2 files changed, 63 insertions(+), 63 deletions(-) diff --git a/tests/test_multi_table_blendsql.py b/tests/test_multi_table_blendsql.py index 4620dcd3..280d3285 100644 --- a/tests/test_multi_table_blendsql.py +++ b/tests/test_multi_table_blendsql.py @@ -65,7 +65,7 @@ def test_simple_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE Symbol in ( @@ -73,8 +73,8 @@ def test_simple_multi_exec(db, ingredients): WHERE sector = 'Information Technology' ) """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_join_multi_exec(db, ingredients): @@ -102,12 +102,12 @@ def test_join_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Name) FROM constituents WHERE Sector = 'Information Technology' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_join_not_qualified_multi_exec(db, ingredients): @@ -139,12 +139,12 @@ def test_join_not_qualified_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Name) FROM constituents WHERE Sector = 'Information Technology' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_select_multi_exec(db, ingredients): @@ -292,12 +292,12 @@ def test_table_alias_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_subquery_alias_multi_exec(db, ingredients): @@ -323,12 +323,12 @@ def test_subquery_alias_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_cte_qa_multi_exec(db, ingredients): @@ -447,14 +447,14 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol WHERE account_history.Action like "%dividend%" """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients): @@ -480,14 +480,14 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients) sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol WHERE account_history.Action like "%dividend%" """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_subquery_alias_with_join_multi_exec(db, ingredients): @@ -516,9 +516,9 @@ def test_subquery_alias_with_join_multi_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE (Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) AND "Percent of Account" < 0.2 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient diff --git a/tests/test_single_table_blendsql.py b/tests/test_single_table_blendsql.py index 0a11f361..3848328c 100644 --- a/tests/test_single_table_blendsql.py +++ b/tests/test_single_table_blendsql.py @@ -98,12 +98,12 @@ def test_simple_ingredient_exec_in_select(db, ingredients): ingredients=ingredients, ) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE parent_category = 'Auto & Transport' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nested_ingredient_exec(db, ingredients): @@ -131,12 +131,12 @@ def test_nested_ingredient_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 100 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nonexistent_column_exec(db, ingredients): @@ -171,12 +171,12 @@ def test_nonexistent_column_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'this does not exist' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nested_and_exec(db, ingredients): @@ -204,12 +204,12 @@ def test_nested_and_exec(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["O"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_multiple_nested_ingredients(db, ingredients): @@ -239,12 +239,12 @@ def test_multiple_nested_ingredients(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A", "T"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) + COUNT(DISTINCT child_category) FROM transactions WHERE parent_category = 'Food' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_length_ingredient(db, ingredients): @@ -266,12 +266,12 @@ def test_length_ingredient(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_max_length(db, ingredients): @@ -293,12 +293,12 @@ def test_max_length(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_limit(db, ingredients): @@ -322,12 +322,12 @@ def test_limit(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_select(db, ingredients): @@ -363,12 +363,12 @@ def test_ingredient_in_select_stmt(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_ingredient_in_select_stmt_with_filter(db, ingredients): @@ -387,12 +387,12 @@ def test_ingredient_in_select_stmt_with_filter(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nested_duplicate_map_calls(db, ingredients): @@ -410,12 +410,12 @@ def test_nested_duplicate_map_calls(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_many_duplicate_map_calls(db, ingredients): @@ -443,7 +443,7 @@ def test_many_duplicate_map_calls(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT (SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 1300) @@ -451,8 +451,8 @@ def test_many_duplicate_map_calls(db, ingredients): + (SELECT COUNT(DISTINCT child_category) FROM transactions WHERE amount > 1300) + (SELECT COUNT(DISTINCT date) FROM transactions WHERE amount > 1300) """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_exists_isolated_qa_call(db, ingredients): @@ -482,12 +482,12 @@ def test_exists_isolated_qa_call(db, ingredients): sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_to_df( + passed_to_ingredient = db.execute_to_list( """ SELECT (SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 500) + (SELECT COUNT(*) FROM transactions WHERE amount < 500) """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_query_options_arg(db, ingredients): From 2fa6f98cab7911cc39634806671e61234f8a3ba9 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:28:47 -0400 Subject: [PATCH 11/21] applymap -> map --- blendsql/db/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blendsql/db/utils.py b/blendsql/db/utils.py index 87dccec6..e1656e6c 100644 --- a/blendsql/db/utils.py +++ b/blendsql/db/utils.py @@ -20,7 +20,7 @@ def select_all_from_table_query(tablename: str) -> str: def truncate_df_content(df: pd.DataFrame, truncation_limit: int) -> pd.DataFrame: # Truncate long strings - return df.applymap( + return df.map( lambda x: f"{str(x)[:truncation_limit]}..." if isinstance(x, str) and len(str(x)) > truncation_limit else x From 385030b5e1e3c0be447f852233dd2d7218805961 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:48:06 -0400 Subject: [PATCH 12/21] Removing unnecessary print --- blendsql/grammars/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/blendsql/grammars/utils.py b/blendsql/grammars/utils.py index 1c541c81..3c6b1b60 100644 --- a/blendsql/grammars/utils.py +++ b/blendsql/grammars/utils.py @@ -33,9 +33,7 @@ def load_cfg_parser(ingredients: Optional[Collection[Ingredient]]) -> EarleyPars } for ingredient in ingredients: if ingredient.ingredient_type not in ingredient_type_to_function_type: - print( - f"Not sure what to do with ingredient type '{ingredient.ingredient_type}'" - ) + # TODO: handle these cases continue ingredient_type_to_function_type[ingredient.ingredient_type].append( ingredient.__name__ From 81392ed88defe01e2785260cf11d78581e7922a9 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:48:15 -0400 Subject: [PATCH 13/21] blend-cli docs --- docs/reference/blend-cli.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 docs/reference/blend-cli.md diff --git a/docs/reference/blend-cli.md b/docs/reference/blend-cli.md new file mode 100644 index 00000000..b1fd4d76 --- /dev/null +++ b/docs/reference/blend-cli.md @@ -0,0 +1,24 @@ +# blend cli + +``` +usage: blendsql [-h] [-v] + [db_url] [{openai,azure_openai,llama_cpp,transformers}] + [model_name_or_path] + +positional arguments: + db_url Database URL, + {openai,azure_openai,llama_cpp,transformers} + Model type, for the Blender to use in executing the BlendSQL + query. + model_name_or_path Model identifier to pass to the selected model_type class. + +optional arguments: + -h, --help show this help message and exit + -v Flag to run in verbose mode. +``` + +Example Usage: + +```bash +blendsql mydb.db openai gpt-3.5-turbo -v +``` \ No newline at end of file From 60012bd22093c3bfdf50773396a0a664435398b6 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:48:20 -0400 Subject: [PATCH 14/21] blend-cli docs --- mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/mkdocs.yml b/mkdocs.yml index 14820919..fb76022f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -74,6 +74,7 @@ nav: - Teaching BlendSQL via In-Context Learning: reference/examples/teaching-blendsql-via-in-context-learning.ipynb - Documentation: - Execute a BlendSQL Query: reference/execute-blendsql.md + - Blend CLI: reference/blend-cli.md - Ingredients: - reference/ingredients/ingredients.md - Creating Custom Ingredients: reference/ingredients/creating-custom-ingredients.md From f2cc83d4819c0daa57feee410fd1c692d66e30de Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:48:39 -0400 Subject: [PATCH 15/21] include_package_data = True --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 43f28ae9..37e59a46 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ def find_version(*file_paths): long_description_content_type="text/markdown", license="Apache License 2.0", packages=find_packages(exclude=["examples", "research", "img"]), - data_files=["blendsql/grammars/_cfg_grammar.lark"], + include_package_data=True, install_requires=[ "guidance>=0.1.0", "pyparsing==3.1.1", From 1daf05bb5eb868da2981607a8a3693f087e8aa32 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:49:05 -0400 Subject: [PATCH 16/21] Updating blend_cli.py to be more customizable --- blendsql/blend_cli.py | 56 +++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/blendsql/blend_cli.py b/blendsql/blend_cli.py index 51187441..48843d8b 100644 --- a/blendsql/blend_cli.py +++ b/blendsql/blend_cli.py @@ -1,16 +1,23 @@ import os import argparse import importlib -import json from blendsql import blend from blendsql.db import SQLite +from blendsql.db.utils import truncate_df_content from blendsql.utils import tabulate -from blendsql.models import LlamaCppLLM -from blendsql.ingredients.builtin import LLMQA, LLMMap, LLMJoin, DT +from blendsql.models import OpenaiLLM, TransformersLLM, AzureOpenaiLLM, LlamaCppLLM +from blendsql.ingredients.builtin import LLMQA, LLMMap, LLMJoin _has_readline = importlib.util.find_spec("readline") is not None +MODEL_TYPE_TO_CLASS = { + "openai": OpenaiLLM, + "azure_openai": AzureOpenaiLLM, + "llama_cpp": LlamaCppLLM, + "transformers": TransformersLLM, +} + def print_msg_box(msg, indent=1, width=None, title=None): """Print message-box with optional title.""" @@ -37,8 +44,21 @@ def main(): _ = readline parser = argparse.ArgumentParser() - parser.add_argument("db_url", nargs="?") - parser.add_argument("secrets_path", nargs="?", default="./secrets.json") + parser.add_argument("db_url", nargs="?", help="Database URL,") + parser.add_argument( + "model_type", + nargs="?", + default="openai", + choices=list(MODEL_TYPE_TO_CLASS.keys()), + help="Model type, for the Blender to use in executing the BlendSQL query.", + ) + parser.add_argument( + "model_name_or_path", + nargs="?", + default="gpt-3.5-turbo", + help="Model identifier to pass to the selected model_type class.", + ) + parser.add_argument("-v", action="store_true", help="Flag to run in verbose mode.") args = parser.parse_args() db = SQLite(db_url=args.db_url) @@ -61,30 +81,14 @@ def main(): smoothie = blend( query=text, db=db, - ingredients={LLMQA, LLMMap, LLMJoin, DT}, - blender=LlamaCppLLM( - "./lark-constrained-parsing/tinyllama-1.1b-chat-v1.0.Q2_K.gguf" + ingredients={LLMQA, LLMMap, LLMJoin}, + blender=MODEL_TYPE_TO_CLASS.get(args.model_type)( + args.model_name_or_path ), infer_gen_constraints=True, - verbose=True, + verbose=args.v, ) print() - print(tabulate(smoothie.df.iloc[:10])) - print() - print(json.dumps(smoothie.meta.prompts, indent=4)) + print(tabulate(truncate_df_content(smoothie.df, 50))) except Exception as error: print(error) - - -""" -SELECT "common name" AS 'State Flower' FROM w -WHERE state = {{ - LLMQA( - 'Which is the smallest state by area?', - (SELECT title, content FROM documents WHERE documents MATCH 'smallest OR state OR area' LIMIT 3), - options='w::state' - ) -}} - -SELECT Symbol, Description, Quantity FROM portfolio WHERE {{LLMMap('Do they manufacture cell phones?', 'portfolio::Description')}} = TRUE AND portfolio.Symbol in (SELECT Symbol FROM constituents WHERE Sector = 'Information Technology') -""" From 2c71f44bffa3ad5105647a975a6d242ea121df8b Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:49:26 -0400 Subject: [PATCH 17/21] _llama_cpp.py typos --- blendsql/models/local/_llama_cpp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blendsql/models/local/_llama_cpp.py b/blendsql/models/local/_llama_cpp.py index b81a9fb2..ce849a8d 100644 --- a/blendsql/models/local/_llama_cpp.py +++ b/blendsql/models/local/_llama_cpp.py @@ -10,7 +10,7 @@ class LlamaCppLLM(Model): - """Class for Transformers local Model. + """Class for llama-cpp local Model. Args: model_name_or_path: Name of the model on HuggingFace, or the path to a local model @@ -19,7 +19,7 @@ class LlamaCppLLM(Model): def __init__(self, model_name_or_path: str, **kwargs): if not _has_llama_cpp: raise ImportError( - "Please install llama_cpp with `pip install llama_cpp`!" + "Please install llama_cpp with `pip install llama-cpp-python`!" ) from None super().__init__( From ec790e9bc83bb25cec70ab90e675902ba523d024 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:52:42 -0400 Subject: [PATCH 18/21] exrex + lark to main requirements --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 37e59a46..96cb7833 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,8 @@ def find_version(*file_paths): "python-dotenv==1.0.1", "sqlglot==18.13.0", "sqlalchemy>=2.0.0", + "lark", + "exrex", "platformdirs", "attrs", "tqdm", @@ -50,7 +52,6 @@ def find_version(*file_paths): "typeguard", ], extras_require={ - "nl_to_blendsql": ["exrex", "lark"], "research": [ "datasets==2.16.1", "nltk", From 846cb8763c3f741b305efdeb75f0374e822d87a5 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 13:56:53 -0400 Subject: [PATCH 19/21] Prioritize exception from CFG parser, if exists This allows the user to get a more helpful grammar error, in the case of bad syntax. Prior to this, errors were often vague, pointing at the internal workings of the query optimization logic. --- blendsql/blend.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/blendsql/blend.py b/blendsql/blend.py index 4a0187be..3cdf6cd5 100644 --- a/blendsql/blend.py +++ b/blendsql/blend.py @@ -880,6 +880,15 @@ def blend( schema_qualify=schema_qualify, ) except Exception as error: + from .grammars.minEarley.parser import EarleyParser + from .grammars.utils import load_cfg_parser + + # Parse with CFG and try to get helpful recommendations + parser: EarleyParser = load_cfg_parser(ingredients) + try: + parser.parse(query) + except Exception as parser_error: + raise parser_error raise error finally: # In the case of a recursive `_blend()` call, From 4b29daf7d7f38ebe56a2d5f65f951315d0eaf711 Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 14:08:50 -0400 Subject: [PATCH 20/21] Creating dedicated _exceptions.py --- blendsql/_exceptions.py | 6 ++++++ blendsql/blend.py | 24 ++++++++++++++---------- blendsql/ingredients/ingredient.py | 5 +---- tests/test_generic_blendsql.py | 6 +++--- 4 files changed, 24 insertions(+), 17 deletions(-) create mode 100644 blendsql/_exceptions.py diff --git a/blendsql/_exceptions.py b/blendsql/_exceptions.py new file mode 100644 index 00000000..4a2a3959 --- /dev/null +++ b/blendsql/_exceptions.py @@ -0,0 +1,6 @@ +class InvalidBlendSQL(ValueError): + pass + + +class IngredientException(ValueError): + pass diff --git a/blendsql/blend.py b/blendsql/blend.py index 3cdf6cd5..e7bf0a59 100644 --- a/blendsql/blend.py +++ b/blendsql/blend.py @@ -15,6 +15,7 @@ Optional, Callable, Collection, + Union, ) from sqlite3 import OperationalError import sqlglot.expressions @@ -31,6 +32,7 @@ recover_blendsql, get_tablename_colname, ) +from ._exceptions import InvalidBlendSQL from .db import Database from .db.utils import double_quote_escape, select_all_from_table_query from ._sqlglot import ( @@ -428,7 +430,7 @@ def _blend( # Preliminary check - we can't have anything that modifies database state if _query.find(MODIFIERS): - raise ValueError("BlendSQL query cannot have `DELETE` clause!") + raise InvalidBlendSQL("BlendSQL query cannot have `DELETE` clause!") # If there's no `SELECT` and just a QAIngredient, wrap it in a `SELECT CASE` query if _query.find(exp.Select) is None: @@ -572,6 +574,7 @@ def _blend( if prev_subquery_has_ingredient: scm.set_node(scm.node.transform(maybe_set_subqueries_to_true)) + lazy_limit: Union[int, None] = scm.get_lazy_limit() # After above processing of AST, sync back to string repr subquery_str = scm.sql() # Now, 1) Find all ingredients to execute (e.g. '{{f(a, b, c)}}') @@ -880,15 +883,16 @@ def blend( schema_qualify=schema_qualify, ) except Exception as error: - from .grammars.minEarley.parser import EarleyParser - from .grammars.utils import load_cfg_parser - - # Parse with CFG and try to get helpful recommendations - parser: EarleyParser = load_cfg_parser(ingredients) - try: - parser.parse(query) - except Exception as parser_error: - raise parser_error + if not isinstance(error, (InvalidBlendSQL, IngredientException)): + from .grammars.minEarley.parser import EarleyParser + from .grammars.utils import load_cfg_parser + + # Parse with CFG and try to get helpful recommendations + parser: EarleyParser = load_cfg_parser(ingredients) + try: + parser.parse(query) + except Exception as parser_error: + raise parser_error raise error finally: # In the case of a recursive `_blend()` call, diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index 4438d8f1..9c29f8ed 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -16,16 +16,13 @@ import uuid from typeguard import check_type +from .._exceptions import IngredientException from .. import utils from .._constants import IngredientKwarg, IngredientType from ..db import Database from ..db.utils import select_all_from_table_query -class IngredientException(ValueError): - pass - - def unpack_default_kwargs(**kwargs): return ( kwargs.get("tablename"), diff --git a/tests/test_generic_blendsql.py b/tests/test_generic_blendsql.py index aedcc0d6..d0ac0375 100644 --- a/tests/test_generic_blendsql.py +++ b/tests/test_generic_blendsql.py @@ -4,7 +4,7 @@ from pathlib import Path from blendsql import blend from blendsql.db import SQLite -from blendsql.ingredients.ingredient import IngredientException +from blendsql._exceptions import IngredientException, InvalidBlendSQL @pytest.fixture(scope="session") @@ -23,7 +23,7 @@ def test_error_on_delete1(db): blendsql = """ DELETE FROM w WHERE TRUE; """ - with pytest.raises(ValueError): + with pytest.raises(InvalidBlendSQL): _ = blend( query=blendsql, db=db, @@ -35,7 +35,7 @@ def test_error_on_delete2(db): blendsql = """ DROP TABLE w; """ - with pytest.raises(ValueError): + with pytest.raises(InvalidBlendSQL): _ = blend( query=blendsql, db=db, From d3a6af62e345bf8319b93756ccbeab516c0a634f Mon Sep 17 00:00:00 2001 From: parkervg Date: Sat, 18 May 2024 14:10:02 -0400 Subject: [PATCH 21/21] Adding Acknowledgements section --- README.md | 10 ++++++++++ docs/index.md | 12 +++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a3fa0115..d4593c92 100644 --- a/README.md +++ b/README.md @@ -388,6 +388,16 @@ print(smoothie.meta.prompts) } ``` +### Acknowledgements +Special thanks to those below for inspiring this project. Definitely recommend checking out the linked work below, and citing when applicable! + +- The authors of [Binding Language Models in Symbolic Languages](https://arxiv.org/abs/2210.02875) + - This paper was the primary inspiration for BlendSQL. +- The authors of [EHRXQA: A Multi-Modal Question Answering Dataset for Electronic Health Records with Chest X-ray Images](https://arxiv.org/pdf/2310.18652) + - As far as I can tell, the first publication to propose unifying model calls within SQL + - Served as the inspiration for the [vqa-ingredient.ipynb](./examples/vqa-ingredient.ipynb) example +- The authors of [Grammar Prompting for Domain-Specific Language Generation with Large Language Models](https://arxiv.org/abs/2305.19234) + # Documentation diff --git a/docs/index.md b/docs/index.md index a59c283d..134285c2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -121,4 +121,14 @@ For a technical walkthrough of how a BlendSQL query is executed, check out [tech archivePrefix={arXiv}, primaryClass={cs.CL} } -``` \ No newline at end of file +``` + +### Acknowledgements +Special thanks to those below for inspiring this project. Definitely recommend checking out the linked work below, and citing when applicable! + +- The authors of [Binding Language Models in Symbolic Languages](https://arxiv.org/abs/2210.02875) + - This paper was the primary inspiration for BlendSQL. +- The authors of [EHRXQA: A Multi-Modal Question Answering Dataset for Electronic Health Records with Chest X-ray Images](https://arxiv.org/pdf/2310.18652) + - As far as I can tell, the first publication to propose unifying model calls within SQL + - Served as the inspiration for the [vqa-ingredient.ipynb](./examples/vqa-ingredient.ipynb) example +- The authors of [Grammar Prompting for Domain-Specific Language Generation with Large Language Models](https://arxiv.org/abs/2305.19234)