diff --git a/src/core/functions/function_data/CMakeLists.txt b/src/core/functions/function_data/CMakeLists.txt index fa64f0f8..0147d5e8 100644 --- a/src/core/functions/function_data/CMakeLists.txt +++ b/src/core/functions/function_data/CMakeLists.txt @@ -3,6 +3,7 @@ set(EXTENSION_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/cheapest_path_length_function_data.cpp ${CMAKE_CURRENT_SOURCE_DIR}/iterative_length_function_data.cpp ${CMAKE_CURRENT_SOURCE_DIR}/local_clustering_coefficient_function_data.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/pagerank_function_data.cpp ${CMAKE_CURRENT_SOURCE_DIR}/weakly_connected_component_function_data.cpp PARENT_SCOPE diff --git a/src/core/functions/function_data/pagerank_function_data.cpp b/src/core/functions/function_data/pagerank_function_data.cpp new file mode 100644 index 00000000..f7562183 --- /dev/null +++ b/src/core/functions/function_data/pagerank_function_data.cpp @@ -0,0 +1,72 @@ +#include "duckpgq/core/functions/function_data/pagerank_function_data.hpp" + +namespace duckpgq { + +namespace core { + +// Constructor +PageRankFunctionData::PageRankFunctionData(ClientContext &ctx, int32_t csr) + : context(ctx), csr_id(csr), damping_factor(0.85), + convergence_threshold(1e-6), iteration_count(0), state_initialized(false), + converged(false) {} + +unique_ptr +PageRankFunctionData::PageRankBind(ClientContext &context, + ScalarFunction &bound_function, + vector> &arguments) { + if (!arguments[0]->IsFoldable()) { + throw InvalidInputException("Id must be constant."); + } + + int32_t csr_id = ExpressionExecutor::EvaluateScalar(context, *arguments[0]) + .GetValue(); + + return make_uniq(context, csr_id); +} + +// Copy method +unique_ptr PageRankFunctionData::Copy() const { + auto result = make_uniq(context, csr_id); + result->rank = rank; // Deep copy of rank vector + result->temp_rank = temp_rank; // Deep copy of temp_rank vector + result->damping_factor = damping_factor; + result->convergence_threshold = convergence_threshold; + result->iteration_count = iteration_count; + result->state_initialized = state_initialized; + result->converged = converged; + // Note: state_lock is not copied as mutexes are not copyable + return result; +} + +// Equals method +bool PageRankFunctionData::Equals(const FunctionData &other_p) const { + auto &other = (const PageRankFunctionData &)other_p; + if (csr_id != other.csr_id) { + return false; + } + if (rank != other.rank) { + return false; + } + if (temp_rank != other.temp_rank) { + return false; + } + if (damping_factor != other.damping_factor) { + return false; + } + if (convergence_threshold != other.convergence_threshold) { + return false; + } + if (iteration_count != other.iteration_count) { + return false; + } + if (state_initialized != other.state_initialized) { + return false; + } + if (converged != other.converged) { + return false; + } + return true; +} +} // namespace core + +} // namespace duckpgq \ No newline at end of file diff --git a/src/core/functions/scalar/CMakeLists.txt b/src/core/functions/scalar/CMakeLists.txt index 269f01b2..19c9d810 100644 --- a/src/core/functions/scalar/CMakeLists.txt +++ b/src/core/functions/scalar/CMakeLists.txt @@ -7,6 +7,7 @@ set(EXTENSION_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/iterativelength.cpp ${CMAKE_CURRENT_SOURCE_DIR}/iterativelength2.cpp ${CMAKE_CURRENT_SOURCE_DIR}/iterativelength_bidirectional.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/pagerank.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reachability.cpp ${CMAKE_CURRENT_SOURCE_DIR}/shortest_path.cpp ${CMAKE_CURRENT_SOURCE_DIR}/csr_creation.cpp diff --git a/src/core/functions/scalar/pagerank.cpp b/src/core/functions/scalar/pagerank.cpp new file mode 100644 index 00000000..889d453e --- /dev/null +++ b/src/core/functions/scalar/pagerank.cpp @@ -0,0 +1,130 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckpgq/common.hpp" +#include "duckpgq/core/functions/function_data/pagerank_function_data.hpp" +#include +#include +#include +#include + +namespace duckpgq { +namespace core { + +static void PageRankFunction(DataChunk &args, + ExpressionState &state, + Vector &result) { + auto &func_expr = (BoundFunctionExpression &)state.expr; + auto &info = (PageRankFunctionData &)*func_expr.bind_info; + auto duckpgq_state = GetDuckPGQState(info.context); + + // Locate the CSR representation of the graph + auto csr_entry = duckpgq_state->csr_list.find((uint64_t)info.csr_id); + if (csr_entry == duckpgq_state->csr_list.end()) { + throw ConstraintException("CSR not found. Is the graph populated?"); + } + + if (!(csr_entry->second->initialized_v && csr_entry->second->initialized_e)) { + throw ConstraintException("Need to initialize CSR before running PageRank."); + } + + int64_t *v = (int64_t *)duckpgq_state->csr_list[info.csr_id]->v; + vector &e = duckpgq_state->csr_list[info.csr_id]->e; + size_t v_size = duckpgq_state->csr_list[info.csr_id]->vsize; + + // State initialization (only once) + if (!info.state_initialized) { + info.rank.resize(v_size, 1.0 / v_size); // Initial rank for each node + info.temp_rank.resize(v_size, 0.0); // Temporary storage for ranks during iteration + info.damping_factor = 0.85; // Typical damping factor + info.convergence_threshold = 1e-6; // Convergence threshold + info.state_initialized = true; + info.converged = false; + info.iteration_count = 0; + } + + // Check if already converged + if (!info.converged) { + std::lock_guard guard(info.state_lock); // Thread safety + + bool continue_iteration = true; + while (continue_iteration) { + fill(info.temp_rank.begin(), info.temp_rank.end(), 0.0); + + double total_dangling_rank = 0.0; // For dangling nodes + + for (size_t i = 0; i < v_size; i++) { + int64_t start_edge = v[i]; + int64_t end_edge = (i + 1 < v_size) ? v[i + 1] : e.size(); // Adjust end_edge + if (end_edge > start_edge) { + double rank_contrib = info.rank[i] / (end_edge - start_edge); + for (int64_t j = start_edge; j < end_edge; j++) { + int64_t neighbor = e[j]; + info.temp_rank[neighbor] += rank_contrib; + } + } else { + total_dangling_rank += info.rank[i]; + } + } + + // Apply damping factor and handle dangling node ranks + double correction_factor = total_dangling_rank / v_size; + double max_delta = 0.0; + for (size_t i = 0; i < v_size; i++) { + info.temp_rank[i] = (1 - info.damping_factor) / v_size + + info.damping_factor * (info.temp_rank[i] + correction_factor); + max_delta = std::max(max_delta, std::abs(info.temp_rank[i] - info.rank[i])); + } + + info.rank.swap(info.temp_rank); + info.iteration_count++; + if (max_delta < info.convergence_threshold) { + info.converged = true; + continue_iteration = false; + } + } + } + + // Get the source vector for the current DataChunk + auto &src = args.data[1]; + UnifiedVectorFormat vdata_src; + src.ToUnifiedFormat(args.size(), vdata_src); + auto src_data = (int64_t *)vdata_src.data; + + // Create result vector + ValidityMask &result_validity = FlatVector::Validity(result); + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + + // Output the PageRank value corresponding to each source ID in the DataChunk + for (idx_t i = 0; i < args.size(); i++) { + auto id_pos = vdata_src.sel->get_index(i); + if (!vdata_src.validity.RowIsValid(id_pos)) { + result_validity.SetInvalid(i); + continue; // Skip invalid rows + } + auto node_id = src_data[id_pos]; + if (node_id < 0 || node_id >= (int64_t)v_size) { + result_validity.SetInvalid(i); + continue; + } + result_data[i] = info.rank[node_id]; + } + + duckpgq_state->csr_to_delete.insert(info.csr_id); +} + +//------------------------------------------------------------------------------ +// Register functions +//------------------------------------------------------------------------------ +void CoreScalarFunctions::RegisterPageRankScalarFunction( + DatabaseInstance &db) { + ExtensionUtil::RegisterFunction( + db, + ScalarFunction( + "pagerank", + {LogicalType::INTEGER, LogicalType::BIGINT}, + LogicalType::DOUBLE, PageRankFunction, + PageRankFunctionData::PageRankBind)); +} + +} // namespace core +} // namespace duckpgq diff --git a/src/core/functions/table/CMakeLists.txt b/src/core/functions/table/CMakeLists.txt index f89fbb14..647a9d1e 100644 --- a/src/core/functions/table/CMakeLists.txt +++ b/src/core/functions/table/CMakeLists.txt @@ -4,6 +4,7 @@ set(EXTENSION_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/drop_property_graph.cpp ${CMAKE_CURRENT_SOURCE_DIR}/local_clustering_coefficient.cpp ${CMAKE_CURRENT_SOURCE_DIR}/match.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/pagerank.cpp ${CMAKE_CURRENT_SOURCE_DIR}/pgq_scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/weakly_connected_component.cpp ${EXTENSION_SOURCES} diff --git a/src/core/functions/table/pagerank.cpp b/src/core/functions/table/pagerank.cpp new file mode 100644 index 00000000..40fa6d42 --- /dev/null +++ b/src/core/functions/table/pagerank.cpp @@ -0,0 +1,42 @@ +#include "duckpgq/core/functions/table/pagerank.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +#include +#include +#include "duckdb/parser/tableref/basetableref.hpp" + +namespace duckpgq { +namespace core { + +// Main binding function +unique_ptr PageRankFunction::PageRankBindReplace(ClientContext &context, TableFunctionBindInput &input) { + auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0])); + auto node_table = StringUtil::Lower(StringValue::Get(input.inputs[1])); + auto edge_table = StringUtil::Lower(StringValue::Get(input.inputs[2])); + + auto duckpgq_state = GetDuckPGQState(context); + auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name); + auto edge_pg_entry = ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table); + + auto select_node = CreateSelectNode(edge_pg_entry, "pagerank", "pagerank"); + + select_node->cte_map.map["csr_cte"] = CreateDirectedCSRCTE(edge_pg_entry, "src", "edge", "dst"); + + auto subquery = make_uniq(); + subquery->node = std::move(select_node); + + auto result = make_uniq(std::move(subquery)); + result->alias = "wcc"; + return std::move(result); +} + +//------------------------------------------------------------------------------ +// Register functions +//------------------------------------------------------------------------------ +void CoreTableFunctions::RegisterPageRankTableFunction(DatabaseInstance &db) { + ExtensionUtil::RegisterFunction(db, PageRankFunction()); +} + +} // namespace core +} // namespace duckpgq \ No newline at end of file diff --git a/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp b/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp index 065e0f2f..b9c5e993 100644 --- a/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp +++ b/src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckPGQ // -// duckpgq/functions/function_data/cheapest_path_length_function_data.hpp +// duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp // // //===----------------------------------------------------------------------===// diff --git a/src/include/duckpgq/core/functions/function_data/iterative_length_function_data.hpp b/src/include/duckpgq/core/functions/function_data/iterative_length_function_data.hpp index 49b2a4b2..a6e4f80a 100644 --- a/src/include/duckpgq/core/functions/function_data/iterative_length_function_data.hpp +++ b/src/include/duckpgq/core/functions/function_data/iterative_length_function_data.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckPGQ // -// duckpgq/functions/function_data/iterative_length_function_data.hpp +// duckpgq/core/functions/function_data/iterative_length_function_data.hpp // // //===----------------------------------------------------------------------===// diff --git a/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp b/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp new file mode 100644 index 00000000..7e44de75 --- /dev/null +++ b/src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckPGQ +// +// duckpgq/core/functions/function_data/pagerank_function_data.hpp +// +// +//===----------------------------------------------------------------------===// + + +#pragma once +#include "duckdb/main/client_context.hpp" +#include "duckpgq/common.hpp" + +namespace duckpgq { +namespace core { +struct PageRankFunctionData final : FunctionData { + ClientContext &context; + int32_t csr_id; + vector rank; + vector temp_rank; + double_t damping_factor; + double_t convergence_threshold; + int64_t iteration_count; + std::mutex state_lock; // Lock for state + bool state_initialized; + bool converged; + + PageRankFunctionData(ClientContext &context, int32_t csr_id); + PageRankFunctionData(ClientContext &context, int32_t csr_id, const vector &componentId); + static unique_ptr + PageRankBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other_p) const override; +}; + + +} // namespace core + +} // namespace duckpgq \ No newline at end of file diff --git a/src/include/duckpgq/core/functions/scalar.hpp b/src/include/duckpgq/core/functions/scalar.hpp index 3ac3a679..25d063a9 100644 --- a/src/include/duckpgq/core/functions/scalar.hpp +++ b/src/include/duckpgq/core/functions/scalar.hpp @@ -18,6 +18,7 @@ struct CoreScalarFunctions { RegisterReachabilityScalarFunction(db); RegisterShortestPathScalarFunction(db); RegisterWeaklyConnectedComponentScalarFunction(db); + RegisterPageRankScalarFunction(db); } private: @@ -32,6 +33,8 @@ struct CoreScalarFunctions { static void RegisterReachabilityScalarFunction(DatabaseInstance &db); static void RegisterShortestPathScalarFunction(DatabaseInstance &db); static void RegisterWeaklyConnectedComponentScalarFunction(DatabaseInstance &db); + static void RegisterPageRankScalarFunction(DatabaseInstance &db); + }; diff --git a/src/include/duckpgq/core/functions/table.hpp b/src/include/duckpgq/core/functions/table.hpp index 3a47dd6d..6dff40f5 100644 --- a/src/include/duckpgq/core/functions/table.hpp +++ b/src/include/duckpgq/core/functions/table.hpp @@ -14,6 +14,7 @@ struct CoreTableFunctions { RegisterLocalClusteringCoefficientTableFunction(db); RegisterScanTableFunctions(db); RegisterWeaklyConnectedComponentTableFunction(db); + RegisterPageRankTableFunction(db); } private: @@ -24,6 +25,7 @@ struct CoreTableFunctions { static void RegisterLocalClusteringCoefficientTableFunction(DatabaseInstance &db); static void RegisterScanTableFunctions(DatabaseInstance &db); static void RegisterWeaklyConnectedComponentTableFunction(DatabaseInstance &db); + static void RegisterPageRankTableFunction(DatabaseInstance &db); }; diff --git a/src/include/duckpgq/core/functions/table/pagerank.hpp b/src/include/duckpgq/core/functions/table/pagerank.hpp new file mode 100644 index 00000000..65518d64 --- /dev/null +++ b/src/include/duckpgq/core/functions/table/pagerank.hpp @@ -0,0 +1,59 @@ +//===----------------------------------------------------------------------===// +// DuckPGQ +// +// duckpgq/core/functions/table/pagerank.hpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckpgq/common.hpp" + +namespace duckpgq { +namespace core { + +class PageRankFunction : public TableFunction { +public: + PageRankFunction() { + name = "pagerank"; + arguments = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; + bind_replace = PageRankBindReplace; + } + + static unique_ptr PageRankBindReplace(ClientContext &context, + TableFunctionBindInput &input); + +}; + +struct PageRankData : TableFunctionData { + static unique_ptr + PageRankBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(); + result->pg_name = StringValue::Get(input.inputs[0]); + result->node_table = StringValue::Get(input.inputs[1]); + result->edge_table = StringValue::Get(input.inputs[2]); + return_types.emplace_back(LogicalType::BIGINT); + return_types.emplace_back(LogicalType::BIGINT); + names.emplace_back("rowid"); + names.emplace_back("pagerank"); + return std::move(result); + } + + string pg_name; + string node_table; + string edge_table; +}; + + +struct PageRankScanState : GlobalTableFunctionState { + static unique_ptr + Init(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + return std::move(result); + } + + bool finished = false; +}; + +} // namespace core +} // namespace duckpgq \ No newline at end of file diff --git a/test/sql/scalar/pagerank.test b/test/sql/scalar/pagerank.test new file mode 100644 index 00000000..b3376651 --- /dev/null +++ b/test/sql/scalar/pagerank.test @@ -0,0 +1,87 @@ +# name: test/sql/scalar/pagerank.test +# description: Testing the pagerank implementation +# group: [duckpgq_sql_scalar] + +require duckpgq + +statement ok +CREATE TABLE Student(id BIGINT, name VARCHAR);INSERT INTO Student VALUES (0, 'Daniel'), (1, 'Tavneet'), (2, 'Gabor'), (3, 'Peter'), (4, 'David'); + +statement ok +CREATE TABLE know(src BIGINT, dst BIGINT, createDate BIGINT);INSERT INTO know VALUES (0,1, 10), (0,2, 11), (0,3, 12), (3,0, 13), (1,2, 14), (1,3, 15), (2,3, 16), (4,3, 17); + +statement ok +-CREATE PROPERTY GRAPH pg +VERTEX TABLES ( + Student + ) +EDGE TABLES ( + know SOURCE KEY ( src ) REFERENCES Student ( id ) + DESTINATION KEY ( dst ) REFERENCES Student ( id ) + ); + +query II +select id, pagerank from pagerank(pg, student, know); +---- +0 0.30722555839452875 +1 0.11534940106637968 +2 0.16437299553018173 +3 0.32814638463154105 +4 0.028301886792456276 + + +statement ok +CREATE OR REPLACE TABLE Student ( + id BIGINT +); + +statement ok +INSERT INTO Student (id) VALUES +(0), +(1), +(2), +(3), +(4); + +statement ok +CREATE OR REPLACE TABLE know ( + src BIGINT, + dst BIGINT, + edge BIGINT +); + +statement ok +INSERT INTO know (src, dst, edge) VALUES +(2, 1, 4), +(3, 1, 5), +(3, 2, 6), +(1, 2, 4), +(1, 0, 0), +(2, 0, 1), +(3, 0, 2), +(0, 1, 0), +(4, 3, 7), +(0, 3, 3), +(1, 3, 5), +(2, 3, 6), +(3, 4, 7), +(0, 2, 1); + +statement ok +-CREATE OR REPLACE PROPERTY GRAPH pg +VERTEX TABLES ( + Student + ) +EDGE TABLES ( + know SOURCE KEY ( src ) REFERENCES Student ( id ) + DESTINATION KEY ( dst ) REFERENCES Student ( id ) + ); + +query II +select id, pagerank from pagerank(pg, student, know); +---- +0 0.19672392385442233 +1 0.19672392385442233 +2 0.19672392385442233 +3 0.26797750004549203 +4 0.08524695480585476 \ No newline at end of file