From e32224a1e9fbf226e412f5173fcdbfe62b2cac3b Mon Sep 17 00:00:00 2001 From: dtenwolde Date: Tue, 6 Feb 2024 13:42:51 +0100 Subject: [PATCH] WITH statements and subqueries now seem to work --- duckpgq/src/duckpgq_extension.cpp | 41 ++++++++-------- test/sql/with_statement_duckpgq.test | 73 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 20 deletions(-) create mode 100644 test/sql/with_statement_duckpgq.test diff --git a/duckpgq/src/duckpgq_extension.cpp b/duckpgq/src/duckpgq_extension.cpp index 9d419012..5f2a91f2 100644 --- a/duckpgq/src/duckpgq_extension.cpp +++ b/duckpgq/src/duckpgq_extension.cpp @@ -17,6 +17,7 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/statement/copy_statement.hpp" #include "duckdb/parser/parsed_data/create_table_info.hpp" +#include "duckdb/parser/tableref/joinref.hpp" #include "duckdb/parser/statement/extension_statement.hpp" @@ -109,24 +110,32 @@ BoundStatement duckpgq_bind(ClientContext &context, Binder &binder, throw BinderException("Unable to find DuckPGQ Parse Data"); } -ParserExtensionPlanResult duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { - if (statement->type == StatementType::SELECT_STATEMENT) { - auto select_statement = dynamic_cast(statement); - auto select_node = dynamic_cast(select_statement->node.get()); - auto from_table_function = - dynamic_cast(select_node->from_table.get()); - auto function = - dynamic_cast(from_table_function->function.get()); +void duckpgq_find_match_function(TableRef* table_ref, DuckPGQState &duckpgq_state) { + if (auto table_function_ref = dynamic_cast(table_ref)) { + // Handle TableFunctionRef case + auto function = dynamic_cast(table_function_ref->function.get()); if (function->function_name == "duckpgq_match") { duckpgq_state.transform_expression = std::move(std::move(function->children[0])); function->children.pop_back(); } + } else if (auto join_ref = dynamic_cast(table_ref)) { + // Handle JoinRef case + duckpgq_find_match_function(join_ref->left.get(), duckpgq_state); + duckpgq_find_match_function(join_ref->right.get(), duckpgq_state); + } +} + +ParserExtensionPlanResult duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { + if (statement->type == StatementType::SELECT_STATEMENT) { + auto select_statement = dynamic_cast(statement); + auto select_node = dynamic_cast(select_statement->node.get()); + duckpgq_find_match_function(select_node->from_table.get(), duckpgq_state); throw Exception("use duckpgq_bind instead"); } if (statement->type == StatementType::CREATE_STATEMENT) { - auto &create_statement = statement->Cast(); - auto create_property_graph = dynamic_cast(create_statement.info.get()); + const auto &create_statement = statement->Cast(); + const auto create_property_graph = dynamic_cast(create_statement.info.get()); if (create_property_graph) { ParserExtensionPlanResult result; result.function = CreatePropertyGraphFunction(); @@ -134,7 +143,7 @@ ParserExtensionPlanResult duckpgq_handle_statement(SQLStatement *statement, Duck result.return_type = StatementReturnType::QUERY_RESULT; return result; } - auto create_table = reinterpret_cast(create_statement.info.get()); + const auto create_table = reinterpret_cast(create_statement.info.get()); duckpgq_handle_statement(create_table->query.get(), duckpgq_state); } if (statement->type == StatementType::DROP_STATEMENT) { @@ -152,15 +161,7 @@ ParserExtensionPlanResult duckpgq_handle_statement(SQLStatement *statement, Duck if (statement->type == StatementType::COPY_STATEMENT) { auto ©_statement = statement->Cast(); auto select_node = dynamic_cast(copy_statement.select_statement.get()); - auto from_table_function = - dynamic_cast(select_node->from_table.get()); - auto function = - dynamic_cast(from_table_function->function.get()); - if (function->function_name == "duckpgq_match") { - duckpgq_state.transform_expression = - std::move(std::move(function->children[0])); - function->children.pop_back(); - } + duckpgq_find_match_function(select_node->from_table.get(), duckpgq_state); throw Exception("use duckpgq_bind instead"); } if (statement->type == StatementType::INSERT_STATEMENT) { diff --git a/test/sql/with_statement_duckpgq.test b/test/sql/with_statement_duckpgq.test new file mode 100644 index 00000000..e122fa2d --- /dev/null +++ b/test/sql/with_statement_duckpgq.test @@ -0,0 +1,73 @@ +# name: test/sql/sqlpgq/snb.test +# group: [duckpgq] + +require duckpgq + +statement ok +import database 'duckdb-pgq/data/SNB0.003'; + +statement ok +-CREATE PROPERTY GRAPH snb_projected +VERTEX TABLES (Message); + +query IIIIIII +-WITH message_count AS ( + SELECT count(*) as m_count + FROM Message m + WHERE m.creationDate < '2010-05-27 11:16:36.013' +) +SELECT year, isComment, + CASE WHEN m_length < 40 THEN 0 + WHEN m_length < 80 THEN 1 + WHEN m_length < 160 THEN 2 + ELSE 3 END as lengthCategory, + count(*) as messageCount, + avg(m_length) as averageMessageLength, + sum(m_length) as sumMessageLength, + count(*) / mc.m_count as percentageOfMessages +FROM GRAPH_TABLE(snb_projected + MATCH (message:Message where message.creationDate < '2010-05-27 11:16:36.013') + COLUMNS (date_part('year', message.creationDate::TIMESTAMP) as year, message.ImageFile is NULL as isComment, message.length as m_length, message.id) + ) tmp, message_count mc +GROUP BY year, isComment, lengthCategory, m_count +ORDER BY year DESC, isComment ASC, lengthCategory ASC; +---- +2010 false 0 63 0.0 0 0.9692307692307692 +2010 true 2 2 109.0 218 0.03076923076923077 + + +query II +-FROM GRAPH_TABLE (snb_projected + MATCH (m:message) + COLUMNS (m.id) + ) tmp, (SELECT id from message limit 1) +LIMIT 10; +---- +618475290624 618475290624 +343597383683 618475290624 +343597383684 618475290624 +962072674309 618475290624 +962072674310 618475290624 +962072674311 618475290624 +962072674312 618475290624 +962072674313 618475290624 +962072674314 618475290624 +962072674315 618475290624 + +query II +-FROM (SELECT id from message limit 1), GRAPH_TABLE (snb_projected + MATCH (m:message) + COLUMNS (m.id) + ) tmp +LIMIT 10; +---- +618475290624 618475290624 +618475290624 343597383683 +618475290624 343597383684 +618475290624 962072674309 +618475290624 962072674310 +618475290624 962072674311 +618475290624 962072674312 +618475290624 962072674313 +618475290624 962072674314 +618475290624 962072674315 \ No newline at end of file