generated from duckdb/extension-template
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #144 from cwida/pagerank
PageRank table function
- Loading branch information
Showing
13 changed files
with
441 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
src/core/functions/function_data/pagerank_function_data.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#include "duckpgq/core/functions/function_data/pagerank_function_data.hpp" | ||
|
||
namespace duckpgq { | ||
|
||
namespace core { | ||
|
||
// Creates the bind data for a PageRank evaluation over the CSR identified by
// `csr`. The rank vectors are left empty and the mutex default-constructed;
// the numeric parameters start at damping 0.85 and threshold 1e-6, and the
// progress flags start cleared.
PageRankFunctionData::PageRankFunctionData(ClientContext &ctx, int32_t csr)
    : context(ctx),
      csr_id(csr),
      damping_factor(0.85),
      convergence_threshold(1e-6),
      iteration_count(0),
      state_initialized(false),
      converged(false) {
}
|
||
// Bind callback for the scalar `pagerank` function.
//
// The first argument must be a constant-foldable expression; it is evaluated
// once here and its value, read as int32_t, becomes the CSR id stored in the
// returned PageRankFunctionData.
//
// Throws InvalidInputException when the first argument is not foldable.
unique_ptr<FunctionData>
PageRankFunctionData::PageRankBind(ClientContext &context,
                                   ScalarFunction &bound_function,
                                   vector<unique_ptr<Expression>> &arguments) {
  auto &csr_id_expr = *arguments[0];
  if (!csr_id_expr.IsFoldable()) {
    throw InvalidInputException("Id must be constant.");
  }

  const auto csr_id_value =
      ExpressionExecutor::EvaluateScalar(context, csr_id_expr);
  return make_uniq<PageRankFunctionData>(context,
                                         csr_id_value.GetValue<int32_t>());
}
|
||
// Returns a new PageRankFunctionData bound to the same context and CSR id,
// carrying over the tuning parameters, the progress flags and element-wise
// copies of both rank vectors. The clone gets its own default-constructed
// state_lock because a mutex cannot be copied.
unique_ptr<FunctionData> PageRankFunctionData::Copy() const {
  auto clone = make_uniq<PageRankFunctionData>(context, csr_id);
  auto &target = *clone;

  // Tuning parameters.
  target.damping_factor = damping_factor;
  target.convergence_threshold = convergence_threshold;

  // Iteration progress.
  target.iteration_count = iteration_count;
  target.state_initialized = state_initialized;
  target.converged = converged;

  // Per-vertex rank buffers (vector assignment copies the elements).
  target.rank = rank;
  target.temp_rank = temp_rank;

  return clone;
}
|
||
// Equals method | ||
bool PageRankFunctionData::Equals(const FunctionData &other_p) const { | ||
auto &other = (const PageRankFunctionData &)other_p; | ||
if (csr_id != other.csr_id) { | ||
return false; | ||
} | ||
if (rank != other.rank) { | ||
return false; | ||
} | ||
if (temp_rank != other.temp_rank) { | ||
return false; | ||
} | ||
if (damping_factor != other.damping_factor) { | ||
return false; | ||
} | ||
if (convergence_threshold != other.convergence_threshold) { | ||
return false; | ||
} | ||
if (iteration_count != other.iteration_count) { | ||
return false; | ||
} | ||
if (state_initialized != other.state_initialized) { | ||
return false; | ||
} | ||
if (converged != other.converged) { | ||
return false; | ||
} | ||
return true; | ||
} | ||
} // namespace core | ||
|
||
} // namespace duckpgq |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#include "duckdb/planner/expression/bound_function_expression.hpp" | ||
#include "duckpgq/common.hpp" | ||
#include "duckpgq/core/functions/function_data/pagerank_function_data.hpp" | ||
#include <duckpgq/core/functions/scalar.hpp> | ||
#include <duckpgq/core/functions/table/pagerank.hpp> | ||
#include <duckpgq/core/utils/duckpgq_bitmap.hpp> | ||
#include <duckpgq/core/utils/duckpgq_utils.hpp> | ||
|
||
namespace duckpgq { | ||
namespace core { | ||
|
||
// Scalar implementation of `pagerank(csr_id, node_id)`.
//
// The first evaluation for a given bind runs PageRank to convergence over the
// CSR referenced by the bind data and caches the ranks in `info.rank`. Every
// evaluation then maps each node id in the second argument column to its
// cached rank. Rows whose id is NULL or outside [0, vsize) yield NULL.
//
// The bind data carries `state_lock`, so it may be reached from more than one
// thread. Initialization and iteration therefore both happen under the lock,
// and the `converged` flag is checked inside it: a thread that had to wait
// for the lock sees the finished state and does not iterate (and swap the
// rank buffers) again while another thread is already reading `info.rank`.
//
// Throws ConstraintException when the CSR cannot be found or has not had both
// its vertex and edge arrays initialized.
static void PageRankFunction(DataChunk &args, ExpressionState &state,
                             Vector &result) {
  auto &func_expr = static_cast<const BoundFunctionExpression &>(state.expr);
  auto &info = static_cast<PageRankFunctionData &>(*func_expr.bind_info);
  auto duckpgq_state = GetDuckPGQState(info.context);

  // Locate the CSR representation of the graph. The iterator is reused below
  // so the container is searched once and operator[] (which could insert a
  // new entry) is never called.
  auto csr_entry =
      duckpgq_state->csr_list.find(static_cast<uint64_t>(info.csr_id));
  if (csr_entry == duckpgq_state->csr_list.end()) {
    throw ConstraintException("CSR not found. Is the graph populated?");
  }
  auto &csr = csr_entry->second;

  if (!(csr->initialized_v && csr->initialized_e)) {
    throw ConstraintException("Need to initialize CSR before running PageRank.");
  }

  // The declared type of `csr->v` is not visible in this file, so the
  // original C-style conversion is kept exactly as it was.
  int64_t *v = (int64_t *)csr->v;
  vector<int64_t> &e = csr->e;
  const size_t v_size = csr->vsize;

  {
    std::lock_guard<std::mutex> guard(info.state_lock);

    // State initialization (only once).
    if (!info.state_initialized) {
      info.rank.resize(v_size, 1.0 / v_size); // Uniform initial rank
      info.temp_rank.resize(v_size, 0.0);     // Scratch buffer for next ranks
      info.damping_factor = 0.85;
      info.convergence_threshold = 1e-6;
      info.state_initialized = true;
      info.converged = false;
      info.iteration_count = 0;
    }

    // Iterate until the largest per-vertex change drops below the threshold.
    while (!info.converged) {
      std::fill(info.temp_rank.begin(), info.temp_rank.end(), 0.0);

      // Rank held by vertices without outgoing edges; it is redistributed
      // evenly over all vertices below.
      double total_dangling_rank = 0.0;

      for (size_t i = 0; i < v_size; i++) {
        const int64_t start_edge = v[i];
        // The last vertex's edge range ends at the end of the edge array.
        const int64_t end_edge =
            (i + 1 < v_size) ? v[i + 1] : static_cast<int64_t>(e.size());
        if (end_edge > start_edge) {
          const double rank_contrib = info.rank[i] / (end_edge - start_edge);
          for (int64_t j = start_edge; j < end_edge; j++) {
            const int64_t neighbor = e[j];
            info.temp_rank[neighbor] += rank_contrib;
          }
        } else {
          total_dangling_rank += info.rank[i];
        }
      }

      // Apply the damping factor and spread the dangling rank.
      const double correction_factor = total_dangling_rank / v_size;
      double max_delta = 0.0;
      for (size_t i = 0; i < v_size; i++) {
        info.temp_rank[i] =
            (1 - info.damping_factor) / v_size +
            info.damping_factor * (info.temp_rank[i] + correction_factor);
        max_delta =
            std::max(max_delta, std::abs(info.temp_rank[i] - info.rank[i]));
      }

      info.rank.swap(info.temp_rank);
      info.iteration_count++;
      if (max_delta < info.convergence_threshold) {
        info.converged = true;
      }
    }
  }

  // From here on `info.rank` is only read: once `converged` is set no thread
  // enters the loop above again.

  // Get the node-id column of the current DataChunk.
  auto &src = args.data[1];
  UnifiedVectorFormat vdata_src;
  src.ToUnifiedFormat(args.size(), vdata_src);
  auto src_data = reinterpret_cast<int64_t *>(vdata_src.data);

  // Prepare the flat result vector.
  result.SetVectorType(VectorType::FLAT_VECTOR);
  ValidityMask &result_validity = FlatVector::Validity(result);
  auto result_data = FlatVector::GetData<double_t>(result);

  // Output the PageRank value for each node id in the DataChunk.
  for (idx_t i = 0; i < args.size(); i++) {
    const auto id_pos = vdata_src.sel->get_index(i);
    if (!vdata_src.validity.RowIsValid(id_pos)) {
      result_validity.SetInvalid(i); // NULL id -> NULL rank
      continue;
    }
    const auto node_id = src_data[id_pos];
    if (node_id < 0 || node_id >= static_cast<int64_t>(v_size)) {
      result_validity.SetInvalid(i); // id outside the CSR -> NULL rank
      continue;
    }
    result_data[i] = info.rank[node_id];
  }

  // Record the CSR id in the state's csr_to_delete set (presumably so the
  // CSR is cleaned up after the query -- confirm against the state's owner).
  duckpgq_state->csr_to_delete.insert(info.csr_id);
}
|
||
//------------------------------------------------------------------------------ | ||
// Register functions | ||
//------------------------------------------------------------------------------ | ||
// Registers the scalar function
//   pagerank(INTEGER csr_id, BIGINT node_id) -> DOUBLE
// with the given database instance. `PageRankFunction` here is the file-local
// scalar implementation defined above; binding is delegated to
// PageRankFunctionData::PageRankBind.
void CoreScalarFunctions::RegisterPageRankScalarFunction(
    DatabaseInstance &db) {
  ScalarFunction pagerank_scalar("pagerank",
                                 {LogicalType::INTEGER, LogicalType::BIGINT},
                                 LogicalType::DOUBLE, PageRankFunction,
                                 PageRankFunctionData::PageRankBind);
  ExtensionUtil::RegisterFunction(db, pagerank_scalar);
}
|
||
} // namespace core | ||
} // namespace duckpgq |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include "duckpgq/core/functions/table/pagerank.hpp" | ||
#include "duckdb/function/table_function.hpp" | ||
#include "duckdb/parser/tableref/subqueryref.hpp" | ||
|
||
#include <duckpgq/core/functions/table.hpp> | ||
#include <duckpgq/core/utils/duckpgq_utils.hpp> | ||
#include "duckdb/parser/tableref/basetableref.hpp" | ||
|
||
namespace duckpgq { | ||
namespace core { | ||
|
||
// Bind-replace callback of the `pagerank` table function.
//
// Reads three string inputs -- property graph name, node table and edge
// table, each lower-cased -- validates them against the registered property
// graph, and returns a subquery that replaces the table function call. The
// subquery's select node is built by CreateSelectNode and is given a
// `csr_cte` common table expression produced by CreateDirectedCSRCTE.
//
// The subquery alias is "pagerank"; it previously read "wcc", which looks
// like a leftover from another algorithm's bind function.
unique_ptr<TableRef> PageRankFunction::PageRankBindReplace(ClientContext &context, TableFunctionBindInput &input) {
  const auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0]));
  const auto node_table = StringUtil::Lower(StringValue::Get(input.inputs[1]));
  const auto edge_table = StringUtil::Lower(StringValue::Get(input.inputs[2]));

  auto duckpgq_state = GetDuckPGQState(context);
  auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name);
  auto edge_pg_entry = ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table);

  auto select_node = CreateSelectNode(edge_pg_entry, "pagerank", "pagerank");

  // Attach the CTE that builds the directed CSR the scalar function reads.
  select_node->cte_map.map["csr_cte"] = CreateDirectedCSRCTE(edge_pg_entry, "src", "edge", "dst");

  auto subquery = make_uniq<SelectStatement>();
  subquery->node = std::move(select_node);

  auto result = make_uniq<SubqueryRef>(std::move(subquery));
  result->alias = "pagerank";
  // std::move is kept deliberately: the return converts
  // unique_ptr<SubqueryRef> into unique_ptr<TableRef>.
  return std::move(result);
}
|
||
//------------------------------------------------------------------------------ | ||
// Register functions | ||
//------------------------------------------------------------------------------ | ||
// Registers a default-constructed PageRankFunction table function with the
// given database instance.
void CoreTableFunctions::RegisterPageRankTableFunction(DatabaseInstance &db) {
  PageRankFunction pagerank_table_function;
  ExtensionUtil::RegisterFunction(db, pagerank_table_function);
}
|
||
} // namespace core | ||
} // namespace duckpgq |
2 changes: 1 addition & 1 deletion
2
src/include/duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
src/include/duckpgq/core/functions/function_data/iterative_length_function_data.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
src/include/duckpgq/core/functions/function_data/pagerank_function_data.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
//===----------------------------------------------------------------------===// | ||
// DuckPGQ | ||
// | ||
// duckpgq/core/functions/function_data/pagerank_function_data.hpp | ||
// | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
|
||
#pragma once | ||
#include "duckdb/main/client_context.hpp" | ||
#include "duckpgq/common.hpp" | ||
|
||
namespace duckpgq { | ||
namespace core { | ||
// Bind data for the scalar `pagerank` function. It identifies the CSR to run
// on and also holds the mutable PageRank state (rank vectors, parameters and
// progress flags) that the scalar implementation fills in and reuses.
struct PageRankFunctionData final : FunctionData { | ||
// Client context passed to the constructor; used to fetch the DuckPGQ state.
ClientContext &context; | ||
// Key used to look up the CSR in the DuckPGQ state's csr_list.
int32_t csr_id; | ||
// Current rank per vertex; sized to the CSR's vsize on first evaluation.
vector<double_t> rank; | ||
// Scratch buffer for the next iteration's ranks; swapped with `rank`.
vector<double_t> temp_rank; | ||
// Damping factor of the PageRank update (constructor default 0.85).
double_t damping_factor; | ||
// Iteration stops once the largest per-vertex rank change is below this
// value (constructor default 1e-6).
double_t convergence_threshold; | ||
// Number of PageRank iterations performed so far.
int64_t iteration_count; | ||
// Held by the scalar implementation while it iterates the ranks.
std::mutex state_lock; // Lock for state | ||
// True once the rank vectors have been sized and seeded.
bool state_initialized; | ||
// True once the iteration has reached the convergence threshold.
bool converged; | ||
|
||
// Creates bind data for the given CSR id with default parameters.
PageRankFunctionData(ClientContext &context, int32_t csr_id); | ||
// NOTE(review): no definition of this overload is visible in
// pagerank_function_data.cpp, and nothing in the visible code calls it;
// confirm whether the `componentId` variant is actually needed.
PageRankFunctionData(ClientContext &context, int32_t csr_id, const vector<int64_t> &componentId); | ||
// Bind callback: evaluates the constant first argument as the CSR id and
// returns a new PageRankFunctionData for it.
static unique_ptr<FunctionData> | ||
PageRankBind(ClientContext &context, ScalarFunction &bound_function, | ||
vector<unique_ptr<Expression>> &arguments); | ||
|
||
// Copies every field except the mutex into a new instance.
unique_ptr<FunctionData> Copy() const override; | ||
// Compares every field except the context reference and the mutex.
bool Equals(const FunctionData &other_p) const override; | ||
}; | ||
|
||
|
||
} // namespace core | ||
|
||
} // namespace duckpgq |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.