diff --git a/README.md b/README.md index a3fa0115..d4593c92 100644 --- a/README.md +++ b/README.md @@ -388,6 +388,16 @@ print(smoothie.meta.prompts) } ``` +### Acknowledgements +Special thanks to those below for inspiring this project. Definitely recommend checking out the linked work below, and citing when applicable! + +- The authors of [Binding Language Models in Symbolic Languages](https://arxiv.org/abs/2210.02875) + - This paper was the primary inspiration for BlendSQL. +- The authors of [EHRXQA: A Multi-Modal Question Answering Dataset for Electronic Health Records with Chest X-ray Images](https://arxiv.org/pdf/2310.18652) + - As far as I can tell, the first publication to propose unifying model calls within SQL + - Served as the inspiration for the [vqa-ingredient.ipynb](./examples/vqa-ingredient.ipynb) example +- The authors of [Grammar Prompting for Domain-Specific Language Generation with Large Language Models](https://arxiv.org/abs/2305.19234) + # Documentation diff --git a/blendsql/_constants.py b/blendsql/_constants.py index cc46da60..45273bcc 100644 --- a/blendsql/_constants.py +++ b/blendsql/_constants.py @@ -1,4 +1,4 @@ -from enum import Enum, EnumMeta, auto +from enum import Enum, EnumMeta from dataclasses import dataclass HF_REPO_ID = "parkervg/blendsql-test-dbs" @@ -15,10 +15,10 @@ def __contains__(cls, item): class IngredientType(str, Enum, metaclass=StrInMeta): - MAP = auto() - STRING = auto() - QA = auto() - JOIN = auto() + MAP = "MAP" + STRING = "STRING" + QA = "QA" + JOIN = "JOIN" @dataclass diff --git a/blendsql/_exceptions.py b/blendsql/_exceptions.py new file mode 100644 index 00000000..4a2a3959 --- /dev/null +++ b/blendsql/_exceptions.py @@ -0,0 +1,6 @@ +class InvalidBlendSQL(ValueError): + pass + + +class IngredientException(ValueError): + pass diff --git a/blendsql/blend.py b/blendsql/blend.py index 31c756a8..e7bf0a59 100644 --- a/blendsql/blend.py +++ b/blendsql/blend.py @@ -15,6 +15,7 @@ Optional, Callable, Collection, + Union, ) from sqlite3 import OperationalError import sqlglot.expressions @@ -31,8 +32,9 @@ recover_blendsql, get_tablename_colname, ) +from ._exceptions import InvalidBlendSQL from .db import Database -from .db.utils import double_quote_escape, single_quote_escape +from .db.utils import double_quote_escape, select_all_from_table_query from ._sqlglot import ( MODIFIERS, get_first_child, @@ -428,7 +430,7 @@ def _blend( # Preliminary check - we can't have anything that modifies database state if _query.find(MODIFIERS): - raise ValueError("BlendSQL query cannot have `DELETE` clause!") + raise InvalidBlendSQL("BlendSQL query cannot have `DELETE` clause!") # If there's no `SELECT` and just a QAIngredient, wrap it in a `SELECT CASE` query if _query.find(exp.Select) is None: @@ -442,7 +444,7 @@ def _blend( # If we don't have any ingredient calls, execute as normal SQL if len(ingredients) == 0 or len(ingredient_alias_to_parsed_dict) == 0: return Smoothie( - df=db.execute_query(query), + df=db.execute_to_df(query), meta=SmoothieMeta( num_values_passed=0, num_prompt_tokens=0, @@ -541,7 +543,7 @@ def _blend( ) try: db.to_temp_table( - df=db.execute_query(abstracted_query), + df=db.execute_to_df(abstracted_query), tablename=_get_temp_subquery_table(tablename), ) except OperationalError as e: @@ -572,6 +574,7 @@ def _blend( if prev_subquery_has_ingredient: scm.set_node(scm.node.transform(maybe_set_subqueries_to_true)) + lazy_limit: Union[int, None] = scm.get_lazy_limit() # After above processing of AST, sync back to string repr subquery_str = scm.sql() # Now, 1) Find all ingredients to execute (e.g. '{{f(a, b, c)}}') @@ -741,14 +744,14 @@ def _blend( # On their left join merge command: https://github.com/HKUNLP/Binder/blob/9eede69186ef3f621d2a50572e1696bc418c0e77/nsql/database.py#L196 # We create a new temp table to avoid a potentially self-destructive operation base_tablename = tablename - _base_table: pd.DataFrame = db.execute_query( - f'SELECT * FROM "{double_quote_escape(base_tablename)}";' + _base_table: pd.DataFrame = db.execute_to_df( + select_all_from_table_query(base_tablename) ) base_table = _base_table if db.has_temp_table(_get_temp_session_table(tablename)): base_tablename = _get_temp_session_table(tablename) - base_table: pd.DataFrame = db.execute_query( - f"SELECT * FROM '{single_quote_escape(base_tablename)}';", + base_table: pd.DataFrame = db.execute_to_df( + select_all_from_table_query(base_tablename) ) previously_added_columns = base_table.columns.difference( _base_table.columns @@ -804,7 +807,7 @@ def _blend( ) logging.debug("") - df = db.execute_query(query) + df = db.execute_to_df(query) return Smoothie( df=df, @@ -880,6 +883,16 @@ def blend( schema_qualify=schema_qualify, ) except Exception as error: + if not isinstance(error, (InvalidBlendSQL, IngredientException)): + from .grammars.minEarley.parser import EarleyParser + from .grammars.utils import load_cfg_parser + + # Parse with CFG and try to get helpful recommendations + parser: EarleyParser = load_cfg_parser(ingredients) + try: + parser.parse(query) + except Exception as parser_error: + raise parser_error raise error finally: # In the case of a recursive `_blend()` call, diff --git a/blendsql/blend_cli.py b/blendsql/blend_cli.py index 51187441..48843d8b 100644 --- a/blendsql/blend_cli.py +++ b/blendsql/blend_cli.py @@ -1,16 +1,23 @@ import os import argparse import importlib -import json from blendsql import blend from blendsql.db import SQLite +from blendsql.db.utils import truncate_df_content from blendsql.utils import tabulate -from blendsql.models import LlamaCppLLM -from blendsql.ingredients.builtin import LLMQA, LLMMap, LLMJoin, DT +from blendsql.models import OpenaiLLM, TransformersLLM, AzureOpenaiLLM, LlamaCppLLM +from blendsql.ingredients.builtin import LLMQA, LLMMap, LLMJoin _has_readline = importlib.util.find_spec("readline") is not None +MODEL_TYPE_TO_CLASS = { + "openai": OpenaiLLM, + "azure_openai": AzureOpenaiLLM, + "llama_cpp": LlamaCppLLM, + "transformers": TransformersLLM, +} + def print_msg_box(msg, indent=1, width=None, title=None): """Print message-box with optional title.""" @@ -37,8 +44,21 @@ def main(): _ = readline parser = argparse.ArgumentParser() - parser.add_argument("db_url", nargs="?") - parser.add_argument("secrets_path", nargs="?", default="./secrets.json") + parser.add_argument("db_url", nargs="?", help="Database URL,") + parser.add_argument( + "model_type", + nargs="?", + default="openai", + choices=list(MODEL_TYPE_TO_CLASS.keys()), + help="Model type, for the Blender to use in executing the BlendSQL query.", + ) + parser.add_argument( + "model_name_or_path", + nargs="?", + default="gpt-3.5-turbo", + help="Model identifier to pass to the selected model_type class.", + ) + parser.add_argument("-v", action="store_true", help="Flag to run in verbose mode.") args = parser.parse_args() db = SQLite(db_url=args.db_url) @@ -61,30 +81,14 @@ def main(): smoothie = blend( query=text, db=db, - ingredients={LLMQA, LLMMap, LLMJoin, DT}, - blender=LlamaCppLLM( - "./lark-constrained-parsing/tinyllama-1.1b-chat-v1.0.Q2_K.gguf" + ingredients={LLMQA, LLMMap, LLMJoin}, + blender=MODEL_TYPE_TO_CLASS.get(args.model_type)( + args.model_name_or_path ), infer_gen_constraints=True, - verbose=True, + verbose=args.v, ) print() - print(tabulate(smoothie.df.iloc[:10])) - print() - print(json.dumps(smoothie.meta.prompts, indent=4)) + print(tabulate(truncate_df_content(smoothie.df, 50))) except Exception as error: print(error) - - -""" -SELECT "common name" AS 'State Flower' FROM w -WHERE state = {{ - LLMQA( - 'Which is the smallest state by area?', - (SELECT title, content FROM documents WHERE documents MATCH 'smallest OR state OR area' LIMIT 3), - options='w::state' - ) -}} - -SELECT Symbol, Description, Quantity FROM portfolio WHERE {{LLMMap('Do they manufacture cell phones?', 'portfolio::Description')}} = TRUE AND portfolio.Symbol in (SELECT Symbol FROM constituents WHERE Sector = 'Information Technology') -""" diff --git a/blendsql/db/_database.py b/blendsql/db/_database.py index cbcb9c35..584ff56f 100644 --- a/blendsql/db/_database.py +++ b/blendsql/db/_database.py @@ -1,4 +1,4 @@ -from typing import Generator, List, Dict, Collection +from typing import Generator, List, Dict, Collection, Type, Optional from typing import Iterable import pandas as pd from colorama import Fore @@ -86,7 +86,7 @@ def to_temp_table(self, df: pd.DataFrame, tablename: str): self.con.execute(text(create_table_stmt)) df.to_sql(name=tablename, con=self.con, if_exists="append", index=False) - def execute_query(self, query: str, params: dict = None) -> pd.DataFrame: + def execute_to_df(self, query: str, params: dict = None) -> pd.DataFrame: """ Execute the given query and return results as dataframe. @@ -106,3 +106,14 @@ def execute_query(self, query: str, params: dict = None) -> pd.DataFrame: ``` """ return pd.read_sql(text(query), self.con, params=params) + + def execute_to_list( + self, query: str, to_type: Optional[Type] = lambda x: x + ) -> list: + """A lower-level execute method that doesn't use the pandas processing logic. + Returns results as a tuple. + """ + res = [] + for row in self.con.execute(text(query)).fetchall(): + res.append(to_type(row[0])) + return res diff --git a/blendsql/db/_postgres.py b/blendsql/db/_postgres.py index 323d4131..779eaca8 100644 --- a/blendsql/db/_postgres.py +++ b/blendsql/db/_postgres.py @@ -36,7 +36,7 @@ def __init__(self, db_path: str): def has_temp_table(self, tablename: str) -> bool: return ( tablename - in self.execute_query( + in self.execute_to_df( "SELECT * FROM information_schema.tables WHERE table_schema LIKE 'pg_temp_%'" )["table_name"].unique() ) diff --git a/blendsql/db/_sqlite.py b/blendsql/db/_sqlite.py index 767d9fe3..a1077db3 100644 --- a/blendsql/db/_sqlite.py +++ b/blendsql/db/_sqlite.py @@ -23,7 +23,7 @@ def __init__(self, db_url: str): def has_temp_table(self, tablename: str) -> bool: return ( tablename - in self.execute_query( + in self.execute_to_df( "SELECT name FROM sqlite_temp_master WHERE type='table';" )["name"].unique() ) @@ -39,7 +39,7 @@ def get_sqlglot_schema(self) -> dict: schema = {} for tablename in self.tables(): schema[f'"{double_quote_escape(tablename)}"'] = {} - for _, row in self.execute_query( + for _, row in self.execute_to_df( f""" SELECT name, type FROM pragma_table_info(:t) """, diff --git a/blendsql/db/utils.py b/blendsql/db/utils.py index 2a0591d0..e1656e6c 100644 --- a/blendsql/db/utils.py +++ b/blendsql/db/utils.py @@ -14,9 +14,13 @@ def escape(s): return single_quote_escape(double_quote_escape(s)) +def select_all_from_table_query(tablename: str) -> str: + return f'SELECT * FROM "{double_quote_escape(tablename)}";' + + def truncate_df_content(df: pd.DataFrame, truncation_limit: int) -> pd.DataFrame: # Truncate long strings - return df.applymap( + return df.map( lambda x: f"{str(x)[:truncation_limit]}..." if isinstance(x, str) and len(str(x)) > truncation_limit else x diff --git a/blendsql/grammars/_cfg_grammar.lark b/blendsql/grammars/_cfg_grammar.lark index 99ee5646..24c34d3f 100644 --- a/blendsql/grammars/_cfg_grammar.lark +++ b/blendsql/grammars/_cfg_grammar.lark @@ -66,22 +66,22 @@ JOIN_TYPE: "INNER"i | "FULL"i ["OUTER"i] | "LEFT"i["OUTER"i] | "RIGHT"i ["OUTER" | "CASE"i (when_then)+ "ELSE"i expression_math "END"i -> case_expression | "CAST"i "(" expression_math "AS"i TYPENAME ")" -> as_type | "CAST"i "(" literal "AS"i TYPENAME ")" -> literal_cast - | AGGREGATION_FUNCTIONS expression_math ")" [window_form] -> sql_aggregation + | AGGREGATE_FUNCTIONS expression_math ")" [window_form] -> sql_aggregation | SCALAR_FUNCTIONS [(expression_math ",")*] expression_math ")" -> sql_scalar | blendsql_aggregation_expr -> blendsql_aggregation | "RANK"i "(" ")" window_form -> rank_expression | "DENSE_RANK"i "(" ")" window_form -> dense_rank_expression | "|" "|" expression_math -BLENDSQL_AGGREGATION: ("LLMQA("i | "LLMVerify("i) -BLENDSQL_JOIN: ("LLMJoin("i) +BLENDSQL_AGGREGATE_FUNCTIONS: $blendsql_aggregate_functions +BLENDSQL_JOIN_FUNCTIONS: $blendsql_join_functions left_on_arg: "left_on" "=" string right_on_arg: "right_on" "=" string blendsql_arg: (name "=" literal | literal | "(" start ")") blendsql_expression_math: blendsql_arg ("," blendsql_arg)* -blendsql_aggregation_expr: "{{" BLENDSQL_AGGREGATION blendsql_expression_math ")" "}}" -blendsql_join_expr: "{{" BLENDSQL_JOIN (left_on_arg "," right_on_arg|right_on_arg "," left_on_arg) ")" "}}" +blendsql_aggregation_expr: "{{" BLENDSQL_AGGREGATE_FUNCTIONS blendsql_expression_math ")" "}}" +blendsql_join_expr: "{{" BLENDSQL_JOIN_FUNCTIONS (left_on_arg "," right_on_arg|right_on_arg "," left_on_arg) ")" "}}" window_form: "OVER"i "(" ["PARTITION"i "BY"i (expression_math ",")* expression_math] ["ORDER"i "BY"i (order ",")* order [ row_range_clause ] ] ")" @@ -125,7 +125,7 @@ TYPENAME: "object"i | "string"i // https://www.sqlite.org/lang_expr.html#*funcinexpr -AGGREGATION_FUNCTIONS: ("sum("i | "avg("i | "min("i | "max("i | "count("i ["distinct"i] ) +AGGREGATE_FUNCTIONS: ("sum("i | "avg("i | "min("i | "max("i | "count("i ["distinct"i] ) SCALAR_FUNCTIONS: ("trim("i | "coalesce("i | "abs("i) alias: string -> alias_string diff --git a/blendsql/grammars/utils.py b/blendsql/grammars/utils.py new file mode 100644 index 00000000..3c6b1b60 --- /dev/null +++ b/blendsql/grammars/utils.py @@ -0,0 +1,61 @@ +from pathlib import Path +from typing import Optional, Collection, List, Dict +from string import Template + +from ..ingredients import Ingredient +from .._constants import IngredientType +from .minEarley.parser import EarleyParser + + +def format_ingredient_names_to_lark(names: List[str]) -> str: + """Formats list of ingredient names the way our Lark grammar expects. + + Examples: + ```python + format_ingredient_names_to_lark(["LLMQA", "LLMVerify"]) + >>> '("LLMQA("i | "LLMVerify("i)' + ``` + """ + return "(" + " | ".join([f'"{n}("i' for n in names]) + ")" + + +def load_cfg_parser(ingredients: Optional[Collection[Ingredient]]) -> EarleyParser: + """Loads BlendSQL CFG parser. + Dynamically modifies grammar string to include only valid ingredients. + """ + with open(Path(__file__).parent / "./_cfg_grammar.lark", encoding="utf-8") as f: + cfg_grammar = Template(f.read()) + blendsql_join_functions = [] + blendsql_aggregate_functions = [] + ingredient_type_to_function_type: Dict[str, List[str]] = { + IngredientType.JOIN: blendsql_join_functions, + IngredientType.QA: blendsql_aggregate_functions, + } + for ingredient in ingredients: + if ingredient.ingredient_type not in ingredient_type_to_function_type: + # TODO: handle these cases + continue + ingredient_type_to_function_type[ingredient.ingredient_type].append( + ingredient.__name__ + ) + cfg_grammar = cfg_grammar.substitute( + blendsql_join_functions=format_ingredient_names_to_lark( + blendsql_join_functions + ), + blendsql_aggregate_functions=format_ingredient_names_to_lark( + blendsql_aggregate_functions + ), + ) + return EarleyParser( + grammar=cfg_grammar, + start="start", + keep_all_tokens=True, + ) + + +if __name__ == "__main__": + from blendsql import LLMMap, LLMJoin, LLMValidate + + parser = load_cfg_parser({LLMMap, LLMJoin, LLMValidate}) + # parser.parse("{{LLMQA('what is the answer', (select * from w))}}") + print() diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index 1d4db4bc..9c29f8ed 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -16,14 +16,11 @@ import uuid from typeguard import check_type +from .._exceptions import IngredientException from .. import utils from .._constants import IngredientKwarg, IngredientType from ..db import Database -from ..db.utils import double_quote_escape - - -class IngredientException(ValueError): - pass +from ..db.utils import select_all_from_table_query def unpack_default_kwargs(**kwargs): @@ -33,17 +30,6 @@ def unpack_default_kwargs(**kwargs): ) -def align_to_real_columns(db: Database, colname: str, tablename: str) -> str: - table_columns = db.execute_query(f'SELECT * FROM "{tablename}" LIMIT 1').columns - if colname not in table_columns: - # Try to align with column, according to some normalization rules - cleaned_to_original = { - col.replace("\\n", " ").replace("\xa0", " "): col for col in table_columns - } - colname = cleaned_to_original[colname] - return colname - - @attrs class Ingredient(ABC): name: str = attrib() @@ -118,36 +104,33 @@ def __call__(self, question: str, context: str, *args, **kwargs) -> tuple: ): new_arg_column = "_" + new_arg_column - original_table = self.db.execute_query(f'SELECT * FROM "{tablename}"') + original_table = self.db.execute_to_df(select_all_from_table_query(tablename)) # Get a list of values to map # First, check if we've already dumped some `MapIngredient` output to the main session table if temp_session_table_exists: - temp_session_table = self.db.execute_query( - f'SELECT * FROM "{double_quote_escape(temp_session_tablename)}"' - ) - colname = align_to_real_columns( - db=self.db, colname=colname, tablename=temp_session_tablename + temp_session_table = self.db.execute_to_df( + select_all_from_table_query(temp_session_tablename) ) # We don't need to run this function on everything, # if a previous subquery already got to certain values if new_arg_column in temp_session_table.columns: - values = self.db.execute_query( + values = self.db.execute_to_list( f'SELECT DISTINCT "{colname}" FROM "{temp_session_tablename}" WHERE "{new_arg_column}" IS NULL', - )[colname].tolist() + to_type=str, + ) # Base case: this is the first time we've used this particular ingredient # BUT, temp_session_tablename still exists else: - values = self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{temp_session_tablename}"' - )[colname].tolist() + values = self.db.execute_to_list( + f'SELECT DISTINCT "{colname}" FROM "{temp_session_tablename}"', + to_type=str, + ) else: - colname = align_to_real_columns( - db=self.db, colname=colname, tablename=value_source_tablename + values = self.db.execute_to_list( + f'SELECT DISTINCT "{colname}" FROM "{value_source_tablename}"', + to_type=str, ) - values = self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{value_source_tablename}"' - )[colname].tolist() # No need to run ingredient if we have no values to map onto if len(values) == 0: @@ -211,10 +194,12 @@ def __call__( get_temp_subquery_table: Callable = kwargs.get("get_temp_subquery_table") get_temp_session_table: Callable = kwargs.get("get_temp_session_table") + # Depending on the size of the underlying data, it may be optimal to swap + # the order of 'left_on' and 'right_on' columns during processing + swapped = False values = [] original_lr_identifiers = [] modified_lr_identifiers = [] - left_values, right_values = [], [] mapping = {} for on_arg in [left_on, right_on]: tablename, colname = utils.get_tablename_colname(on_arg) @@ -225,13 +210,8 @@ def __call__( tablename=tablename, ) values.append( - set( - [ - str(i) - for i in self.db.execute_query( - f'SELECT DISTINCT "{colname}" FROM "{tablename}"' - )[colname].tolist() - ] + self.db.execute_to_list( + f'SELECT DISTINCT "{colname}" FROM "{tablename}"', to_type=str ) ) modified_lr_identifiers.append((tablename, colname)) @@ -239,16 +219,29 @@ def __call__( if question is None: # First, check which values we actually need to call Model on # We don't want to join when there's already an intuitive alignment + # First, make sure outer loop is shorter of the two lists + outer, inner = sorted(values, key=len) + _outer = [] + inner = set(inner) mapping = {} - left_values, right_values = values - for l in left_values: - if l in right_values: + for l in outer: + if l in inner: # Define this mapping, and remove from Model inference call mapping[l] = l + inner -= {l} + else: + _outer.append(l) + if len(inner) == 0: + break + to_compare = [inner, _outer] + else: + to_compare = values - processed_values = set(list(mapping.keys())) - left_values = left_values.difference(processed_values) - right_values = right_values.difference(processed_values) + # Finally, order by new (remaining) length and check if we swapped places from original + sorted_values = sorted(to_compare, key=len) + if sorted_values != values: + swapped = True + left_values, right_values = sorted_values kwargs["left_values"] = left_values kwargs["right_values"] = right_values @@ -274,16 +267,21 @@ def __call__( # Using mapped left/right values, create intermediary mapping table temp_join_tablename = get_temp_session_table(str(uuid.uuid4())[:4]) + # Below, we check to see if 'swapped' is True + # If so, we need to inverse what is 'left', and what is 'right' joined_values_df = pd.DataFrame( - data={"left": mapping.keys(), "right": mapping.values()} + data={ + "left" if not swapped else "right": mapping.keys(), + "right" if not swapped else "left": mapping.values(), + } ) self.db.to_temp_table(df=joined_values_df, tablename=temp_join_tablename) return ( left_tablename, right_tablename, f"""JOIN "{temp_join_tablename}" ON "{right_tablename}"."{right_colname}" = "{temp_join_tablename}".right - JOIN "{left_tablename}" ON "{left_tablename}"."{left_colname}" = "{temp_join_tablename}".left - """, + JOIN "{left_tablename}" ON "{left_tablename}"."{left_colname}" = "{temp_join_tablename}".left + """, temp_join_tablename, ) @@ -313,7 +311,7 @@ def __call__( if context is not None: if isinstance(context, str): tablename, colname = utils.get_tablename_colname(context) - subtable = self.db.execute_query( + subtable = self.db.execute_to_df( f'SELECT "{colname}" FROM "{tablename}"' ) elif not isinstance(context, pd.DataFrame): @@ -328,9 +326,9 @@ def __call__( try: tablename, colname = utils.get_tablename_colname(options) tablename = aliases_to_tablenames.get(tablename, tablename) - unpacked_options = self.db.execute_query( + unpacked_options = self.db.execute_to_list( f'SELECT DISTINCT "{colname}" FROM "{tablename}"' - )[colname].tolist() + ) except ValueError: unpacked_options = options.split(";") unpacked_options = set(unpacked_options) diff --git a/blendsql/models/local/_llama_cpp.py b/blendsql/models/local/_llama_cpp.py index b81a9fb2..ce849a8d 100644 --- a/blendsql/models/local/_llama_cpp.py +++ b/blendsql/models/local/_llama_cpp.py @@ -10,7 +10,7 @@ class LlamaCppLLM(Model): - """Class for Transformers local Model. + """Class for llama-cpp local Model. Args: model_name_or_path: Name of the model on HuggingFace, or the path to a local model @@ -19,7 +19,7 @@ class LlamaCppLLM(Model): def __init__(self, model_name_or_path: str, **kwargs): if not _has_llama_cpp: raise ImportError( - "Please install llama_cpp with `pip install llama_cpp`!" + "Please install llama_cpp with `pip install llama-cpp-python`!" ) from None super().__init__( diff --git a/blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql.py index 82510bd4..93d83e2f 100644 --- a/blendsql/nl_to_blendsql.py +++ b/blendsql/nl_to_blendsql.py @@ -2,19 +2,14 @@ from textwrap import dedent from guidance import gen, select from colorama import Fore -from pathlib import Path import logging from .ingredients import Ingredient, IngredientException from .models import Model from ._program import Program from .grammars.minEarley.parser import EarleyParser +from .grammars.utils import load_cfg_parser -CFG_PARSER = EarleyParser.open( - Path(__file__).parent / "./grammars/_cfg_grammar.lark", - start="start", - keep_all_tokens=True, -) PARSER_STOP_TOKENS = ["---", ";", "\n\n"] PARSER_SYSTEM_PROMPT = dedent( """ @@ -145,6 +140,7 @@ def nl_to_blendsql( logging.getLogger().setLevel(logging.DEBUG) else: logging.getLogger().setLevel(logging.ERROR) + parser: EarleyParser = load_cfg_parser(ingredients) system_prompt: str = create_system_prompt( ingredients=ingredients, few_shot_examples=few_shot_examples ) @@ -174,13 +170,13 @@ def nl_to_blendsql( partial_program_prediction + " " + residual_program_prediction ) - if validate_program(program_prediction, CFG_PARSER): + if validate_program(program_prediction, parser): ret_prediction = program_prediction continue # find the max score from a list of score prefix, candidates, pos_in_stream = obtain_correction_pairs( - program_prediction, CFG_PARSER + program_prediction, parser ) candidates = [i for i in candidates if i.strip() != ""] if len(candidates) == 0: @@ -204,13 +200,13 @@ def nl_to_blendsql( inserted_candidate = ( prefix + selected_candidate + program_prediction[pos_in_stream:] ) - if validate_program(inserted_candidate, CFG_PARSER): + if validate_program(inserted_candidate, parser): ret_prediction = inserted_candidate continue # 2) If rest of our query is also broken, we just keep up to the prefix + candidate partial_program_prediction = prefix + selected_candidate for p in {inserted_candidate, partial_program_prediction}: - if validate_program(p, CFG_PARSER): + if validate_program(p, parser): ret_prediction = p num_correction_left -= 1 @@ -224,24 +220,3 @@ def nl_to_blendsql( ret_prediction = initial_prediction logging.debug(Fore.GREEN + ret_prediction + Fore.RESET) return ret_prediction - - -if __name__ == "__main__": - from blendsql import nl_to_blendsql, LLMMap - from blendsql.models import TransformersLLM - from blendsql.db import SQLite - from blendsql.utils import fetch_from_hub - - model = TransformersLLM("Qwen/Qwen1.5-0.5B") - db = SQLite( - fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") - ) - - query = nl_to_blendsql( - question="Which venues in Sydney saw more than 30 points scored?", - model=model, - ingredients={LLMMap}, - serialized_db=db.to_serialized(num_rows=3, use_tables=["w", "documents"]), - verbose=True, - max_grammar_corrections=5, - ) diff --git a/docs/index.md b/docs/index.md index a59c283d..134285c2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -121,4 +121,14 @@ For a technical walkthrough of how a BlendSQL query is executed, check out [tech archivePrefix={arXiv}, primaryClass={cs.CL} } -``` \ No newline at end of file +``` + +### Acknowledgements +Special thanks to those below for inspiring this project. Definitely recommend checking out the linked work below, and citing when applicable! + +- The authors of [Binding Language Models in Symbolic Languages](https://arxiv.org/abs/2210.02875) + - This paper was the primary inspiration for BlendSQL. +- The authors of [EHRXQA: A Multi-Modal Question Answering Dataset for Electronic Health Records with Chest X-ray Images](https://arxiv.org/pdf/2310.18652) + - As far as I can tell, the first publication to propose unifying model calls within SQL + - Served as the inspiration for the [vqa-ingredient.ipynb](./examples/vqa-ingredient.ipynb) example +- The authors of [Grammar Prompting for Domain-Specific Language Generation with Large Language Models](https://arxiv.org/abs/2305.19234) diff --git a/docs/reference/blend-cli.md b/docs/reference/blend-cli.md new file mode 100644 index 00000000..b1fd4d76 --- /dev/null +++ b/docs/reference/blend-cli.md @@ -0,0 +1,24 @@ +# blend cli + +``` +usage: blendsql [-h] [-v] + [db_url] [{openai,azure_openai,llama_cpp,transformers}] + [model_name_or_path] + +positional arguments: + db_url Database URL, + {openai,azure_openai,llama_cpp,transformers} + Model type, for the Blender to use in executing the BlendSQL + query. + model_name_or_path Model identifier to pass to the selected model_type class. + +optional arguments: + -h, --help show this help message and exit + -v Flag to run in verbose mode. +``` + +Example Usage: + +```bash +blendsql mydb.db openai gpt-3.5-turbo -v +``` \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 14820919..fb76022f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -74,6 +74,7 @@ nav: - Teaching BlendSQL via In-Context Learning: reference/examples/teaching-blendsql-via-in-context-learning.ipynb - Documentation: - Execute a BlendSQL Query: reference/execute-blendsql.md + - Blend CLI: reference/blend-cli.md - Ingredients: - reference/ingredients/ingredients.md - Creating Custom Ingredients: reference/ingredients/creating-custom-ingredients.md diff --git a/research/run-debug.py b/research/run-debug.py deleted file mode 100644 index 6f90a5aa..00000000 --- a/research/run-debug.py +++ /dev/null @@ -1,35 +0,0 @@ -from blendsql import blend, LLMJoin -from blendsql.db import SQLite, PostgreSQL -from blendsql.models import OpenaiLLM -from blendsql.utils import fetch_from_hub, tabulate - -if __name__ == "__main__": - blendsql = """ - SELECT date, rival, score, documents.content AS "Team Description" FROM w - JOIN {{ - LLMJoin( - left_on='documents::title', - right_on='w::rival' - ) - }} WHERE rival = 'nsw waratahs' - """ - # Make our smoothie - the executed BlendSQL script - smoothie = blend( - query=blendsql, - db=SQLite( - fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") - ), - blender=OpenaiLLM("gpt-3.5-turbo"), - ingredients={LLMJoin}, - ) - print(tabulate(smoothie.df)) - - smoothie = blend( - query=blendsql, - db=PostgreSQL( - "blendsql@localhost:5432/1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1" - ), - blender=OpenaiLLM("gpt-3.5-turbo"), - ingredients={LLMJoin}, - ) - print(tabulate(smoothie.df)) diff --git a/research/utils/bridge_content_encoder.py b/research/utils/bridge_content_encoder.py index c4c68494..3005d379 100644 --- a/research/utils/bridge_content_encoder.py +++ b/research/utils/bridge_content_encoder.py @@ -245,7 +245,7 @@ def get_column_picklist_with_db(table_name: str, column_name: str, db) -> list: fetch_sql = 'SELECT DISTINCT `{}` FROM "{}"'.format( column_name, double_quote_escape(table_name) ) - picklist = set(db.execute_query(fetch_sql).values.flat) + picklist = set(db.execute_to_df(fetch_sql).values.flat) picklist = list(picklist) cache[key] = picklist return picklist diff --git a/research/utils/database.py b/research/utils/database.py index 85cc6fbc..e16d0a3e 100644 --- a/research/utils/database.py +++ b/research/utils/database.py @@ -50,7 +50,7 @@ def to_serialized( else: serialized_db.append(f"{num_rows} example rows:") serialized_db.append(f"{get_rows_query}") - rows = db.execute_query(get_rows_query) + rows = db.execute_to_df(get_rows_query) if truncate_content is not None: # Truncate long strings rows = rows.map( diff --git a/research/utils/ottqa/ottqa.py b/research/utils/ottqa/ottqa.py index e582e999..0bf27154 100644 --- a/research/utils/ottqa/ottqa.py +++ b/research/utils/ottqa/ottqa.py @@ -69,7 +69,7 @@ def ottqa_get_input( model_args: ModelArguments, ) -> Tuple[str, dict]: if "docs_tablesize" not in cache: - cache["docs_tablesize"] = db.execute_query( + cache["docs_tablesize"] = db.execute_to_df( f"SELECT COUNT(*) FROM {DOCS_TABLE_NAME}" ).values[0][0] cache["docs_tablesize"] diff --git a/run-debug.py b/run-debug.py new file mode 100644 index 00000000..633391da --- /dev/null +++ b/run-debug.py @@ -0,0 +1,71 @@ +from blendsql import blend, LLMJoin, LLMMap, LLMQA +from blendsql.db import SQLite +from blendsql.models import TransformersLLM +from blendsql.utils import fetch_from_hub +from tqdm import tqdm + +# TEST_QUERIES = [ +# """ +# SELECT DISTINCT venue FROM w +# WHERE city = 'sydney' AND {{ +# LLMMap( +# 'More than 30 total points?', +# 'w::score' +# ) +# }} = TRUE +# """, +# """ +# SELECT * FROM w +# WHERE city = {{ +# LLMQA( +# 'Which city is located 120 miles west of Sydney?', +# (SELECT * FROM documents WHERE documents MATCH 'sydney OR 120'), +# options='w::city' +# ) +# }} +# """, +# """ +# SELECT date, rival, score, documents.content AS "Team Description" FROM w +# JOIN {{ +# LLMJoin( +# left_on='documents::title', +# right_on='w::rival' +# ) +# }} +# """ +# ] + +TEST_QUERIES = [ + """ + SELECT title, player FROM w JOIN {{ + LLMJoin( + left_on='documents::title', + right_on='w::player' + ) + }} + """ +] +if __name__ == "__main__": + """ + Without cached LLM response (10 runs): + before: 3.16 + after: 1.91 + With cached LLM response (100 runs): + before: 0.0175 + after: 0.0166 + """ + times = [] + db = SQLite(fetch_from_hub("1966_NBA_Expansion_Draft_0.db")) + for _ in tqdm(range(100)): + for q in TEST_QUERIES: + # Make our smoothie - the executed BlendSQL script + smoothie = blend( + query=q, + db=db, + blender=TransformersLLM("Qwen/Qwen1.5-0.5B", caching=True), + verbose=False, + ingredients={LLMJoin, LLMMap, LLMQA}, + ) + times.append(smoothie.meta.process_time_seconds) + + print(f"Average time across {len(times)} runs: {sum(times) / len(times)}") diff --git a/setup.py b/setup.py index 4b956041..96cb7833 100644 --- a/setup.py +++ b/setup.py @@ -33,15 +33,17 @@ def find_version(*file_paths): long_description_content_type="text/markdown", license="Apache License 2.0", packages=find_packages(exclude=["examples", "research", "img"]), - data_files=["blendsql/grammars/_cfg_grammar.lark"], + include_package_data=True, install_requires=[ "guidance>=0.1.0", "pyparsing==3.1.1", - "pandas==1.5.3", + "pandas>=2.0.0", "bottleneck>=1.3.6", "python-dotenv==1.0.1", "sqlglot==18.13.0", "sqlalchemy>=2.0.0", + "lark", + "exrex", "platformdirs", "attrs", "tqdm", @@ -50,7 +52,6 @@ def find_version(*file_paths): "typeguard", ], extras_require={ - "nl_to_blendsql": ["exrex", "lark"], "research": [ "datasets==2.16.1", "nltk", diff --git a/tests/test_generic_blendsql.py b/tests/test_generic_blendsql.py index aedcc0d6..d0ac0375 100644 --- a/tests/test_generic_blendsql.py +++ b/tests/test_generic_blendsql.py @@ -4,7 +4,7 @@ from pathlib import Path from blendsql import blend from blendsql.db import SQLite -from blendsql.ingredients.ingredient import IngredientException +from blendsql._exceptions import IngredientException, InvalidBlendSQL @pytest.fixture(scope="session") @@ -23,7 +23,7 @@ def test_error_on_delete1(db): blendsql = """ DELETE FROM w WHERE TRUE; """ - with pytest.raises(ValueError): + with pytest.raises(InvalidBlendSQL): _ = blend( query=blendsql, db=db, @@ -35,7 +35,7 @@ def test_error_on_delete2(db): blendsql = """ DROP TABLE w; """ - with pytest.raises(ValueError): + with pytest.raises(InvalidBlendSQL): _ = blend( query=blendsql, db=db, diff --git a/tests/test_grammar.py b/tests/test_grammar.py index 2914b65a..fb7f2b0e 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -1,11 +1,12 @@ import pytest -from blendsql.grammars.minEarley.parser import EarleyParser +from blendsql import LLMQA, LLMJoin, LLMMap +from blendsql.grammars.utils import load_cfg_parser, EarleyParser from blendsql.grammars.minEarley.earley_exceptions import UnexpectedInput @pytest.fixture(scope="session") def parser() -> EarleyParser: - return EarleyParser.open("./blendsql/grammars/_cfg_grammar.lark", start="start") + return load_cfg_parser({LLMQA, LLMJoin, LLMMap}) accept_queries = [ diff --git a/tests/test_multi_table_blendsql.py b/tests/test_multi_table_blendsql.py index d04f6123..280d3285 100644 --- a/tests/test_multi_table_blendsql.py +++ b/tests/test_multi_table_blendsql.py @@ -62,10 +62,10 @@ def test_simple_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE Symbol in ( @@ -73,8 +73,8 @@ def test_simple_multi_exec(db, ingredients): WHERE sector = 'Information Technology' ) """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_join_multi_exec(db, ingredients): @@ -99,15 +99,15 @@ def test_join_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Name) FROM constituents WHERE Sector = 'Information Technology' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_join_not_qualified_multi_exec(db, ingredients): @@ -136,15 +136,15 @@ def test_join_not_qualified_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Name) FROM constituents WHERE Sector = 'Information Technology' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_select_multi_exec(db, ingredients): @@ -170,7 +170,7 @@ def test_select_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -198,7 +198,7 @@ def test_complex_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -229,7 +229,7 @@ def test_complex_not_qualified_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -251,7 +251,7 @@ def test_join_ingredient_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -269,7 +269,7 @@ def test_qa_equals_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -289,15 +289,15 @@ def test_table_alias_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_subquery_alias_multi_exec(db, ingredients): @@ -320,15 +320,15 @@ def test_subquery_alias_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_cte_qa_multi_exec(db, ingredients): @@ -355,15 +355,27 @@ def test_cte_qa_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - # passed_to_ingredient = db.execute_query( - # """ - # SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 - # """ - # ) - # assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + passed_to_map_ingredient = db.execute_to_list( + """ + SELECT COUNT(DISTINCT Symbol) FROM portfolio + """ + )[0] + # We also need to factor in what we passed to QA ingredient + passed_to_qa_ingredient = db.execute_to_list( + """ + WITH a AS ( + SELECT * FROM (SELECT DISTINCT * FROM portfolio) as w + WHERE w.Symbol LIKE 'F%' + ) SELECT COUNT(*) FROM a WHERE LENGTH(a.Symbol) > 2 + """ + )[0] + assert ( + smoothie.meta.num_values_passed + == passed_to_qa_ingredient + passed_to_map_ingredient + ) def test_cte_qa_named_multi_exec(db, ingredients): @@ -390,15 +402,25 @@ def test_cte_qa_named_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) # Make sure we only pass what's necessary to our ingredient - # passed_to_ingredient = db.execute_query( - # """ - # SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200 - # """ - # ) - # assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + passed_to_map_ingredient = db.execute_to_list( + """ + SELECT COUNT(DISTINCT Symbol) FROM portfolio + """ + )[0] + passed_to_qa_ingredient = db.execute_to_list( + """ + WITH a AS ( + SELECT * FROM (SELECT DISTINCT * FROM portfolio) as w + WHERE w.Symbol LIKE 'F%' + ) SELECT COUNT(*) FROM a WHERE LENGTH(a.Symbol) > 2 + """ + )[0] + assert smoothie.meta.num_values_passed == ( + passed_to_map_ingredient + passed_to_qa_ingredient + ) def test_ingredient_in_select_with_join_multi_exec(db, ingredients): @@ -422,17 +444,17 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol WHERE account_history.Action like "%dividend%" """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients): @@ -455,48 +477,48 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients) db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol WHERE account_history.Action like "%dividend%" """ + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient + + +def test_subquery_alias_with_join_multi_exec(db, ingredients): + blendsql = """ + SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w + JOIN {{ + do_join( + left_on='geographic::Symbol', + right_on='w::Symbol' + ) + }} WHERE {{starts_with('F', 'w::Symbol')}} + AND w."Percent of Account" < 0.2 + """ + + sql = """ + SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w + JOIN geographic ON w.Symbol = geographic.Symbol + WHERE w.Symbol LIKE 'F%' + AND w."Percent of Account" < 0.2 + """ + smoothie = blend( + query=blendsql, + db=db, + ingredients=ingredients, ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() - - -# def test_subquery_alias_with_join_multi_exec(db, ingredients): -# blendsql = """ -# SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w -# JOIN {{ -# do_join( -# left_on='w::Symbol', -# right_on='geographic::Symbol' -# ) -# }} WHERE {{starts_with('F', 'w::Symbol')}} -# AND w."Percent of Account" < 0.2 -# """ -# -# sql = """ -# SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w -# JOIN geographic ON w.Symbol = geographic.Symbol -# WHERE w.Symbol LIKE 'F%' -# AND w."Percent of Account" < 0.2 -# """ -# smoothie = blend( -# query=blendsql, -# db=db, -# ingredients=ingredients, -# ) -# sql_df = db.execute_query(sql) -# assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) -# # Make sure we only pass what's necessary to our ingredient -# # passed_to_ingredient = db.execute_query( -# # """ -# # SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 -# # """ -# # ) -# # assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + sql_df = db.execute_to_df(sql) + assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"]) + # Make sure we only pass what's necessary to our ingredient + passed_to_ingredient = db.execute_to_list( + """ + SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE (Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) AND "Percent of Account" < 0.2 + """ + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient diff --git a/tests/test_single_table_blendsql.py b/tests/test_single_table_blendsql.py index 129db9ab..3848328c 100644 --- a/tests/test_single_table_blendsql.py +++ b/tests/test_single_table_blendsql.py @@ -68,7 +68,7 @@ def test_simple_ingredient_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) @@ -84,7 +84,7 @@ def test_simple_ingredient_exec_at_end(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) @@ -98,12 +98,12 @@ def test_simple_ingredient_exec_in_select(db, ingredients): ingredients=ingredients, ) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE parent_category = 'Auto & Transport' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nested_ingredient_exec(db, ingredients): @@ -128,18 +128,25 @@ def test_nested_ingredient_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 100 """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nonexistent_column_exec(db, ingredients): + """ + NOTE: Converting to CNF would break this one + since we would get: + SELECT DISTINCT merchant, child_category FROM transactions WHERE + (child_category = 'Gifts' OR STRUCT(STRUCT(A())) = 1) AND + (child_category = 'Gifts' OR child_category = 'this does not exist') + """ blendsql = """ SELECT DISTINCT merchant, child_category FROM transactions WHERE ( @@ -161,15 +168,15 @@ def test_nonexistent_column_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'this does not exist' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nested_and_exec(db, ingredients): @@ -194,15 +201,15 @@ def test_nested_and_exec(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["O"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_multiple_nested_ingredients(db, ingredients): @@ -229,15 +236,15 @@ def test_multiple_nested_ingredients(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A", "T"]) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) + COUNT(DISTINCT child_category) FROM transactions WHERE parent_category = 'Food' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_length_ingredient(db, ingredients): @@ -256,15 +263,15 @@ def test_length_ingredient(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_max_length(db, ingredients): @@ -283,15 +290,15 @@ def test_max_length(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_limit(db, ingredients): @@ -312,15 +319,15 @@ def test_limit(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_select(db, ingredients): @@ -337,7 +344,7 @@ def test_select(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) @@ -353,15 +360,15 @@ def test_ingredient_in_select_stmt(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_ingredient_in_select_stmt_with_filter(db, ingredients): @@ -377,15 +384,15 @@ def test_ingredient_in_select_stmt_with_filter(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions WHERE child_category = 'Restaurants & Dining' """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_nested_duplicate_map_calls(db, ingredients): @@ -400,15 +407,15 @@ def test_nested_duplicate_map_calls(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT COUNT(DISTINCT merchant) FROM transactions """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_many_duplicate_map_calls(db, ingredients): @@ -433,10 +440,10 @@ def test_many_duplicate_map_calls(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT (SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 1300) @@ -444,8 +451,8 @@ def test_many_duplicate_map_calls(db, ingredients): + (SELECT COUNT(DISTINCT child_category) FROM transactions WHERE amount > 1300) + (SELECT COUNT(DISTINCT date) FROM transactions WHERE amount > 1300) """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_exists_isolated_qa_call(db, ingredients): @@ -472,15 +479,15 @@ def test_exists_isolated_qa_call(db, ingredients): db=db, ingredients=ingredients, ) - sql_df = db.execute_query(sql) + sql_df = db.execute_to_df(sql) assert_equality(smoothie=smoothie, sql_df=sql_df) # Make sure we only pass what's necessary to our ingredient - passed_to_ingredient = db.execute_query( + passed_to_ingredient = db.execute_to_list( """ SELECT (SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 500) + (SELECT COUNT(*) FROM transactions WHERE amount < 500) """ - ) - assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item() + )[0] + assert smoothie.meta.num_values_passed == passed_to_ingredient def test_query_options_arg(db, ingredients):