diff --git a/.github/workflows/daft-profiling.yml b/.github/workflows/daft-profiling.yml index 704592348e..e431e347c4 100644 --- a/.github/workflows/daft-profiling.yml +++ b/.github/workflows/daft-profiling.yml @@ -62,13 +62,13 @@ jobs: env: DAFT_DEVELOPER_USE_THREAD_POOL: '0' run: | - py-spy record --native --function -o tpch-${{github.run_id}}.txt -f speedscope -- python benchmarking/tpch/__main__.py --scale_factor=${{ env.TPCH_SCALE_FACTOR }} --num_parts=${{ env.TPCH_NUM_PARTS }} --skip_warmup || true + py-spy record --native --function -o tpch-${{github.run_id}}.txt -f speedscope -- python benchmarking/tpch/__main__.py --scale_factor=${{ env.TPCH_SCALE_FACTOR }} --num_parts=${{ env.TPCH_NUM_PARTS }} --skip_warmup --skip_questions=11,12,13,14,15,16,17,18,19,20,21,22 || true - name: Run GIL Profiling on TPCH Benchmark env: DAFT_DEVELOPER_USE_THREAD_POOL: '0' run: | - py-spy record --native --function --gil -o tpch-gil-${{github.run_id}}.txt -f speedscope -- python benchmarking/tpch/__main__.py --scale_factor=${{ env.TPCH_SCALE_FACTOR }} --num_parts=${{ env.TPCH_NUM_PARTS }} --skip_warmup || true + py-spy record --native --function --gil -o tpch-gil-${{github.run_id}}.txt -f speedscope -- python benchmarking/tpch/__main__.py --scale_factor=${{ env.TPCH_SCALE_FACTOR }} --num_parts=${{ env.TPCH_NUM_PARTS }} --skip_warmup --skip_questions=11,12,13,14,15,16,17,18,19,20,21,22 || true - name: Upload Profile diff --git a/benchmarking/tpch/__main__.py b/benchmarking/tpch/__main__.py index 785501383f..6a7d24b290 100644 --- a/benchmarking/tpch/__main__.py +++ b/benchmarking/tpch/__main__.py @@ -37,7 +37,7 @@ class MetricsBuilder: - NUM_TPCH_QUESTIONS = 10 + NUM_TPCH_QUESTIONS = 22 HEADERS = [ "started_at", @@ -133,7 +133,7 @@ def run_all_benchmarks( daft_context = get_context() metrics_builder = MetricsBuilder(daft_context.runner_config.name) - for i in range(1, 11): + for i in range(1, 23): if i in skip_questions: logger.warning("Skipping TPC-H q%s", i) continue diff --git a/benchmarking/tpch/answers.py b/benchmarking/tpch/answers.py index d3184e0634..90e8b212c7 100644 --- a/benchmarking/tpch/answers.py +++ b/benchmarking/tpch/answers.py @@ -3,7 +3,7 @@ import datetime from typing import Callable -from daft import DataFrame, col +from daft import DataFrame, col, lit GetDFFunc = Callable[[str], DataFrame] @@ -333,3 +333,347 @@ def decrease(x, y): ) return daft_df + + +def q11(get_df: GetDFFunc) -> DataFrame: + partsupp = get_df("partsupp") + supplier = get_df("supplier") + nation = get_df("nation") + + var_1 = "GERMANY" + var_2 = 0.0001 / 1 + + res_1 = ( + partsupp.join(supplier, left_on=col("PS_SUPPKEY"), right_on=col("S_SUPPKEY")) + .join(nation, left_on=col("S_NATIONKEY"), right_on=col("N_NATIONKEY")) + .where(col("N_NAME") == var_1) + ) + + res_2 = res_1.agg((col("PS_SUPPLYCOST") * col("PS_AVAILQTY")).sum().alias("tmp")).select( + col("tmp") * var_2, lit(1).alias("lit") + ) + + daft_df = ( + res_1.groupby("PS_PARTKEY") + .agg( + (col("PS_SUPPLYCOST") * col("PS_AVAILQTY")).sum().alias("value"), + ) + .with_column("lit", lit(1)) + .join(res_2, on="lit") + .where(col("value") > col("tmp")) + .select(col("PS_PARTKEY"), col("value").round(2)) + .sort(col("value"), desc=True) + ) + + return daft_df + + +def q12(get_df: GetDFFunc) -> DataFrame: + orders = get_df("orders") + lineitem = get_df("lineitem") + + daft_df = ( + orders.join(lineitem, left_on=col("O_ORDERKEY"), right_on=col("L_ORDERKEY")) + .where( + col("L_SHIPMODE").is_in(["MAIL", "SHIP"]) + & (col("L_COMMITDATE") < col("L_RECEIPTDATE")) + & (col("L_SHIPDATE") < col("L_COMMITDATE")) + & (col("L_RECEIPTDATE") >= datetime.date(1994, 1, 1)) + & (col("L_RECEIPTDATE") < datetime.date(1995, 1, 1)) + ) + .groupby(col("L_SHIPMODE")) + .agg( + col("O_ORDERPRIORITY").is_in(["1-URGENT", "2-HIGH"]).if_else(1, 0).sum().alias("high_line_count"), + (~col("O_ORDERPRIORITY").is_in(["1-URGENT", "2-HIGH"])).if_else(1, 0).sum().alias("low_line_count"), + ) + .sort(col("L_SHIPMODE")) + ) + + return daft_df + + +def q13(get_df: GetDFFunc) -> DataFrame: + customers = get_df("customer") + orders = get_df("orders") + + daft_df = ( + customers.join( + orders.where(~col("O_COMMENT").str.match(".*special.*requests.*")), + left_on="C_CUSTKEY", + right_on="O_CUSTKEY", + how="left", + ) + .groupby(col("C_CUSTKEY")) + .agg(col("O_ORDERKEY").count().alias("c_count")) + .sort("C_CUSTKEY") + .groupby("c_count") + .agg(col("c_count").count().alias("custdist")) + .sort(["custdist", "c_count"], desc=[True, True]) + ) + + return daft_df + + +def q14(get_df: GetDFFunc) -> DataFrame: + lineitem = get_df("lineitem") + part = get_df("part") + + daft_df = ( + lineitem.join(part, left_on=col("L_PARTKEY"), right_on=col("P_PARTKEY")) + .where((col("L_SHIPDATE") >= datetime.date(1995, 9, 1)) & (col("L_SHIPDATE") < datetime.date(1995, 10, 1))) + .agg( + col("P_TYPE") + .str.startswith("PROMO") + .if_else(col("L_EXTENDEDPRICE") * (1 - col("L_DISCOUNT")), 0) + .sum() + .alias("tmp_1"), + (col("L_EXTENDEDPRICE") * (1 - col("L_DISCOUNT"))).sum().alias("tmp_2"), + ) + .select(100.00 * (col("tmp_1") / col("tmp_2")).alias("promo_revenue")) + ) + + return daft_df + + +def q15(get_df: GetDFFunc) -> DataFrame: + lineitem = get_df("lineitem") + + revenue = ( + lineitem.where( + (col("L_SHIPDATE") >= datetime.date(1996, 1, 1)) & (col("L_SHIPDATE") < datetime.date(1996, 4, 1)) + ) + .groupby(col("L_SUPPKEY")) + .agg((col("L_EXTENDEDPRICE") * (1 - col("L_DISCOUNT"))).sum().alias("total_revenue")) + .select(col("L_SUPPKEY").alias("supplier_no"), "total_revenue") + ) + + revenue = revenue.join(revenue.max("total_revenue"), on="total_revenue") + + supplier = get_df("supplier") + + daft_df = ( + supplier.join(revenue, left_on=col("S_SUPPKEY"), right_on=col("supplier_no")) + .select("S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_PHONE", "total_revenue") + .sort("S_SUPPKEY") + ) + + return daft_df + + +def q16(get_df: GetDFFunc) -> DataFrame: + part = get_df("part") + partsupp = get_df("partsupp") + + supplier = get_df("supplier") + + suppkeys = supplier.where(col("S_COMMENT").str.match(".*Customer.*Complaints.*")).select( + col("S_SUPPKEY"), col("S_SUPPKEY").alias("PS_SUPPKEY_RIGHT") + ) + + daft_df = ( + part.join(partsupp, left_on=col("P_PARTKEY"), right_on=col("PS_PARTKEY")) + .where( + (col("P_BRAND") != "Brand#45") + & ~col("P_TYPE").str.startswith("MEDIUM POLISHED") + & (col("P_SIZE").is_in([49, 14, 23, 45, 19, 3, 36, 9])) + ) + .join(suppkeys, left_on="PS_SUPPKEY", right_on="S_SUPPKEY", how="left") + .where(col("PS_SUPPKEY_RIGHT").is_null()) + .select("P_BRAND", "P_TYPE", "P_SIZE", "PS_SUPPKEY") + .distinct() + .groupby("P_BRAND", "P_TYPE", "P_SIZE") + .agg(col("PS_SUPPKEY").count().alias("supplier_cnt")) + .sort(["supplier_cnt", "P_BRAND", "P_TYPE", "P_SIZE"], desc=[True, False, False, False]) + ) + + return daft_df + + +def q17(get_df: GetDFFunc) -> DataFrame: + lineitem = get_df("lineitem") + part = get_df("part") + + res_1 = part.where((col("P_BRAND") == "Brand#23") & (col("P_CONTAINER") == "MED BOX")).join( + lineitem, left_on="P_PARTKEY", right_on="L_PARTKEY", how="left" + ) + + daft_df = ( + res_1.groupby("P_PARTKEY") + .agg((0.2 * col("L_QUANTITY")).mean().alias("avg_quantity")) + .select(col("P_PARTKEY").alias("key"), col("avg_quantity")) + .join(res_1, left_on="key", right_on="P_PARTKEY") + .where(col("L_QUANTITY") < col("avg_quantity")) + .agg((col("L_EXTENDEDPRICE") / 7.0).sum().alias("avg_yearly")) + ) + + return daft_df + + +def q18(get_df: GetDFFunc) -> DataFrame: + customer = get_df("customer") + orders = get_df("orders") + lineitem = get_df("lineitem") + + res_1 = lineitem.groupby("L_ORDERKEY").agg(col("L_QUANTITY").sum().alias("sum_qty")).where(col("sum_qty") > 300) + + daft_df = ( + orders.join(res_1, left_on=col("O_ORDERKEY"), right_on=col("L_ORDERKEY")) + .join(customer, left_on=col("O_CUSTKEY"), right_on=col("C_CUSTKEY")) + .join(lineitem, left_on=col("O_ORDERKEY"), right_on=col("L_ORDERKEY")) + .groupby("C_NAME", "C_CUSTKEY", "O_ORDERKEY", "O_ORDERDATE", "O_TOTALPRICE") + .agg(col("L_QUANTITY").sum().alias("sum")) + .select("C_NAME", "C_CUSTKEY", "O_ORDERKEY", col("O_ORDERDATE").alias("O_ORDERDAT"), "O_TOTALPRICE", "sum") + .sort(["O_TOTALPRICE", "O_ORDERDAT"], desc=[True, False]) + .limit(100) + ) + + return daft_df + + +def q19(get_df: GetDFFunc) -> DataFrame: + lineitem = get_df("lineitem") + part = get_df("part") + + daft_df = ( + lineitem.join(part, left_on=col("L_PARTKEY"), right_on=col("P_PARTKEY")) + .where( + ( + (col("P_BRAND") == "Brand#12") + & col("P_CONTAINER").is_in(["SM CASE", "SM BOX", "SM PACK", "SM PKG"]) + & (col("L_QUANTITY") >= 1) + & (col("L_QUANTITY") <= 11) + & (col("P_SIZE") >= 1) + & (col("P_SIZE") <= 5) + & col("L_SHIPMODE").is_in(["AIR", "AIR REG"]) + & (col("L_SHIPINSTRUCT") == "DELIVER IN PERSON") + ) + | ( + (col("P_BRAND") == "Brand#23") + & col("P_CONTAINER").is_in(["MED BAG", "MED BOX", "MED PKG", "MED PACK"]) + & (col("L_QUANTITY") >= 10) + & (col("L_QUANTITY") <= 20) + & (col("P_SIZE") >= 1) + & (col("P_SIZE") <= 10) + & col("L_SHIPMODE").is_in(["AIR", "AIR REG"]) + & (col("L_SHIPINSTRUCT") == "DELIVER IN PERSON") + ) + | ( + (col("P_BRAND") == "Brand#34") + & col("P_CONTAINER").is_in(["LG CASE", "LG BOX", "LG PACK", "LG PKG"]) + & (col("L_QUANTITY") >= 20) + & (col("L_QUANTITY") <= 30) + & (col("P_SIZE") >= 1) + & (col("P_SIZE") <= 15) + & col("L_SHIPMODE").is_in(["AIR", "AIR REG"]) + & (col("L_SHIPINSTRUCT") == "DELIVER IN PERSON") + ) + ) + .agg((col("L_EXTENDEDPRICE") * (1 - col("L_DISCOUNT"))).sum().alias("revenue")) + ) + + return daft_df + + +def q20(get_df: GetDFFunc) -> DataFrame: + supplier = get_df("supplier") + nation = get_df("nation") + part = get_df("part") + partsupp = get_df("partsupp") + lineitem = get_df("lineitem") + + res_1 = ( + lineitem.where( + (col("L_SHIPDATE") >= datetime.date(1994, 1, 1)) & (col("L_SHIPDATE") < datetime.date(1995, 1, 1)) + ) + .groupby("L_PARTKEY", "L_SUPPKEY") + .agg(((col("L_QUANTITY") * 0.5).sum()).alias("sum_quantity")) + ) + + res_2 = nation.where(col("N_NAME") == "CANADA") + res_3 = supplier.join(res_2, left_on="S_NATIONKEY", right_on="N_NATIONKEY") + + daft_df = ( + part.where(col("P_NAME").str.startswith("forest")) + .select("P_PARTKEY") + .distinct() + .join(partsupp, left_on="P_PARTKEY", right_on="PS_PARTKEY") + .join( + res_1, + left_on=["PS_SUPPKEY", "P_PARTKEY"], + right_on=["L_SUPPKEY", "L_PARTKEY"], + ) + .where(col("PS_AVAILQTY") > col("sum_quantity")) + .select("PS_SUPPKEY") + .distinct() + .join(res_3, left_on="PS_SUPPKEY", right_on="S_SUPPKEY") + .select("S_NAME", "S_ADDRESS") + .sort("S_NAME") + ) + + return daft_df + + +def q21(get_df: GetDFFunc) -> DataFrame: + supplier = get_df("supplier") + nation = get_df("nation") + lineitem = get_df("lineitem") + orders = get_df("orders") + + res_1 = ( + lineitem.select("L_SUPPKEY", "L_ORDERKEY") + .distinct() + .groupby("L_ORDERKEY") + .agg(col("L_SUPPKEY").count().alias("nunique_col")) + .where(col("nunique_col") > 1) + .join(lineitem.where(col("L_RECEIPTDATE") > col("L_COMMITDATE")), on="L_ORDERKEY") + ) + + daft_df = ( + res_1.select("L_SUPPKEY", "L_ORDERKEY") + .groupby("L_ORDERKEY") + .agg(col("L_SUPPKEY").count().alias("nunique_col")) + .join(res_1, on="L_ORDERKEY") + .join(supplier, left_on="L_SUPPKEY", right_on="S_SUPPKEY") + .join(nation, left_on="S_NATIONKEY", right_on="N_NATIONKEY") + .join(orders, left_on="L_ORDERKEY", right_on="O_ORDERKEY") + .where((col("nunique_col") == 1) & (col("N_NAME") == "SAUDI ARABIA") & (col("O_ORDERSTATUS") == "F")) + .groupby("S_NAME") + .agg(col("O_ORDERKEY").count().alias("numwait")) + .sort(["numwait", "S_NAME"], desc=[True, False]) + .limit(100) + ) + + return daft_df + + +def q22(get_df: GetDFFunc) -> DataFrame: + orders = get_df("orders") + customer = get_df("customer") + + res_1 = ( + customer.with_column("cntrycode", col("C_PHONE").str.left(2)) + .where(col("cntrycode").is_in(["13", "31", "23", "29", "30", "18", "17"])) + .select("C_ACCTBAL", "C_CUSTKEY", "cntrycode") + ) + + res_2 = ( + res_1.where(col("C_ACCTBAL") > 0).agg(col("C_ACCTBAL").mean().alias("avg_acctbal")).with_column("lit", lit(1)) + ) + + res_3 = orders.select("O_CUSTKEY") + + daft_df = ( + res_1.join(res_3, left_on="C_CUSTKEY", right_on="O_CUSTKEY", how="left") + .where(col("O_CUSTKEY").is_null()) + .with_column("lit", lit(1)) + .join(res_2, on="lit") + .where(col("C_ACCTBAL") > col("avg_acctbal")) + .groupby("cntrycode") + .agg( + col("C_ACCTBAL").count().alias("numcust"), + col("C_ACCTBAL").sum().alias("totacctbal"), + ) + .sort("cntrycode") + ) + + return daft_df diff --git a/tests/assets/tpch-sqlite-queries/11.sql b/tests/assets/tpch-sqlite-queries/11.sql new file mode 100644 index 0000000000..c8ab171d22 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/11.sql @@ -0,0 +1,30 @@ +-- using 1433771997 as a seed to the RNG + + +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc; diff --git a/tests/assets/tpch-sqlite-queries/12.sql b/tests/assets/tpch-sqlite-queries/12.sql new file mode 100644 index 0000000000..cf3511b352 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/12.sql @@ -0,0 +1,31 @@ +-- using 1433771997 as a seed to the RNG + + +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date('1994-01-01') + and l_receiptdate < date('1994-01-01', '+1 year') +group by + l_shipmode +order by + l_shipmode; diff --git a/tests/assets/tpch-sqlite-queries/13.sql b/tests/assets/tpch-sqlite-queries/13.sql new file mode 100644 index 0000000000..4d5ce720e0 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/13.sql @@ -0,0 +1,23 @@ +-- using 1433771997 as a seed to the RNG + + +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) as c_count + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders +group by + c_count +order by + custdist desc, + c_count desc; diff --git a/tests/assets/tpch-sqlite-queries/14.sql b/tests/assets/tpch-sqlite-queries/14.sql new file mode 100644 index 0000000000..1fef06fedd --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/14.sql @@ -0,0 +1,16 @@ +-- using 1433771997 as a seed to the RNG + + +select + 100.00 * sum(cast(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end as number)) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date('1995-09-01') + and l_shipdate < date('1995-09-01', '+1 month'); diff --git a/tests/assets/tpch-sqlite-queries/15.sql b/tests/assets/tpch-sqlite-queries/15.sql new file mode 100644 index 0000000000..563609df5a --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/15.sql @@ -0,0 +1,34 @@ +-- using 1433771997 as a seed to the RNG + +create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date('1996-01-01') + and l_shipdate < date('1996-01-01', '+3 months') + group by + l_suppkey; + + +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey; diff --git a/tests/assets/tpch-sqlite-queries/16.sql b/tests/assets/tpch-sqlite-queries/16.sql new file mode 100644 index 0000000000..33e44d51b4 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/16.sql @@ -0,0 +1,33 @@ +-- using 1433771997 as a seed to the RNG + + +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; diff --git a/tests/assets/tpch-sqlite-queries/17.sql b/tests/assets/tpch-sqlite-queries/17.sql new file mode 100644 index 0000000000..1e504ad2a0 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/17.sql @@ -0,0 +1,20 @@ +-- using 1433771997 as a seed to the RNG + + +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); diff --git a/tests/assets/tpch-sqlite-queries/18.sql b/tests/assets/tpch-sqlite-queries/18.sql new file mode 100644 index 0000000000..e6ae06b522 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/18.sql @@ -0,0 +1,36 @@ +-- using 1433771997 as a seed to the RNG + + +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100; diff --git a/tests/assets/tpch-sqlite-queries/19.sql b/tests/assets/tpch-sqlite-queries/19.sql new file mode 100644 index 0000000000..bbe04840be --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/19.sql @@ -0,0 +1,38 @@ +-- using 1433771997 as a seed to the RNG + + +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); diff --git a/tests/assets/tpch-sqlite-queries/20.sql b/tests/assets/tpch-sqlite-queries/20.sql new file mode 100644 index 0000000000..1ad0cdaba8 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/20.sql @@ -0,0 +1,40 @@ +-- using 1433771997 as a seed to the RNG + + +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1994-01-01') + and l_shipdate < date('1994-01-01', '+1 year') + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name; diff --git a/tests/assets/tpch-sqlite-queries/21.sql b/tests/assets/tpch-sqlite-queries/21.sql new file mode 100644 index 0000000000..70cd50557f --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/21.sql @@ -0,0 +1,43 @@ +-- using 1433771997 as a seed to the RNG + + +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100; diff --git a/tests/assets/tpch-sqlite-queries/22.sql b/tests/assets/tpch-sqlite-queries/22.sql new file mode 100644 index 0000000000..5cd6102879 --- /dev/null +++ b/tests/assets/tpch-sqlite-queries/22.sql @@ -0,0 +1,40 @@ +-- using 1433771997 as a seed to the RNG + + +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substr(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substr(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substr(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a15fb13364..0508cddbe0 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -37,7 +37,10 @@ def _check_answer(daft_pd_df: pd.DataFrame, tpch_question: int, tmp_path: str): query = pathlib.Path(f"{TPCH_QUERIES}/{tpch_question}.sql").read_text() conn = sqlite3.connect(sqlite_db_file_path, detect_types=sqlite3.PARSE_DECLTYPES) cursor = conn.cursor() - res = cursor.execute(query) + queries = query.split(";") + for q in queries: + if not q.isspace(): + res = cursor.execute(q) sqlite_results = res.fetchall() sqlite_pd_results = pd.DataFrame.from_records(sqlite_results, columns=daft_pd_df.columns) assert_df_equals( diff --git a/tests/integration/test_tpch.py b/tests/integration/test_tpch.py index bc138616e4..1b03141875 100644 --- a/tests/integration/test_tpch.py +++ b/tests/integration/test_tpch.py @@ -102,3 +102,87 @@ def test_tpch_q10(tmp_path, check_answer, get_df): daft_df = answers.q10(get_df) daft_pd_df = daft_df.to_pandas() check_answer(daft_pd_df, 10, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q11(tmp_path, check_answer, get_df): + daft_df = answers.q11(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 11, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q12(tmp_path, check_answer, get_df): + daft_df = answers.q12(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 12, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q13(tmp_path, check_answer, get_df): + daft_df = answers.q13(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 13, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q14(tmp_path, check_answer, get_df): + daft_df = answers.q14(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 14, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q15(tmp_path, check_answer, get_df): + daft_df = answers.q15(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 15, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q16(tmp_path, check_answer, get_df): + daft_df = answers.q16(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 16, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q17(tmp_path, check_answer, get_df): + daft_df = answers.q17(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 17, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q18(tmp_path, check_answer, get_df): + daft_df = answers.q18(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 18, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q19(tmp_path, check_answer, get_df): + daft_df = answers.q19(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 19, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q20(tmp_path, check_answer, get_df): + daft_df = answers.q20(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 20, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q21(tmp_path, check_answer, get_df): + daft_df = answers.q21(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 21, tmp_path) + + +@pytest.mark.skip(reason="Running all TPC-H queries is too slow") +def test_tpch_q22(tmp_path, check_answer, get_df): + daft_df = answers.q22(get_df) + daft_pd_df = daft_df.to_pandas() + check_answer(daft_pd_df, 22, tmp_path)