Skip to content

Commit

Permalink
Merge pull request #144 from cwida/pagerank
Browse files Browse the repository at this point in the history
PageRank table function
  • Loading branch information
Dtenwolde authored Sep 4, 2024
2 parents 8b439f8 + 9c6b7bc commit 104247e
Show file tree
Hide file tree
Showing 13 changed files with 441 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/core/functions/function_data/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/cheapest_path_length_function_data.cpp
${CMAKE_CURRENT_SOURCE_DIR}/iterative_length_function_data.cpp
${CMAKE_CURRENT_SOURCE_DIR}/local_clustering_coefficient_function_data.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pagerank_function_data.cpp
${CMAKE_CURRENT_SOURCE_DIR}/weakly_connected_component_function_data.cpp

PARENT_SCOPE
Expand Down
72 changes: 72 additions & 0 deletions src/core/functions/function_data/pagerank_function_data.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "duckpgq/core/functions/function_data/pagerank_function_data.hpp"

namespace duckpgq {

namespace core {

// Creates the bind-time state for a PageRank run over the CSR with id `csr`.
// Both rank vectors start out empty; the damping factor defaults to 0.85 and
// the convergence threshold to 1e-6, with no iterations performed yet.
PageRankFunctionData::PageRankFunctionData(ClientContext &ctx, int32_t csr)
    : context(ctx),
      csr_id(csr),
      rank(),
      temp_rank(),
      damping_factor(0.85),
      convergence_threshold(1e-6),
      iteration_count(0),
      state_initialized(false),
      converged(false) {
}

// Bind callback of the `pagerank` scalar function.
//
// Evaluates the first argument (the CSR id) at bind time and stores it in a
// fresh PageRankFunctionData.
//
// Throws InvalidInputException when the CSR id is not a constant expression
// or when it evaluates to NULL.
unique_ptr<FunctionData>
PageRankFunctionData::PageRankBind(ClientContext &context,
                                   ScalarFunction &bound_function,
                                   vector<unique_ptr<Expression>> &arguments) {
  if (!arguments[0]->IsFoldable()) {
    throw InvalidInputException("Id must be constant.");
  }

  const Value csr_value =
      ExpressionExecutor::EvaluateScalar(context, *arguments[0]);
  // Reject NULL explicitly so the user gets an input error instead of
  // whatever extracting an integer from a NULL value would raise.
  if (csr_value.IsNull()) {
    throw InvalidInputException("Id must not be NULL.");
  }
  const int32_t csr_id = csr_value.GetValue<int32_t>();

  return make_uniq<PageRankFunctionData>(context, csr_id);
}

// Produces an independent duplicate of this function data: same context and
// CSR id, with the rank vectors, parameters and progress flags copied over.
// The duplicate owns its own freshly constructed `state_lock`, because a
// mutex cannot be copied.
unique_ptr<FunctionData> PageRankFunctionData::Copy() const {
  auto clone = make_uniq<PageRankFunctionData>(context, csr_id);

  // Algorithm parameters.
  clone->damping_factor = damping_factor;
  clone->convergence_threshold = convergence_threshold;

  // Progress made so far.
  clone->iteration_count = iteration_count;
  clone->state_initialized = state_initialized;
  clone->converged = converged;

  // Element-wise copies of the rank buffers.
  clone->rank = rank;
  clone->temp_rank = temp_rank;

  return clone;
}

// Equals method
bool PageRankFunctionData::Equals(const FunctionData &other_p) const {
auto &other = (const PageRankFunctionData &)other_p;
if (csr_id != other.csr_id) {
return false;
}
if (rank != other.rank) {
return false;
}
if (temp_rank != other.temp_rank) {
return false;
}
if (damping_factor != other.damping_factor) {
return false;
}
if (convergence_threshold != other.convergence_threshold) {
return false;
}
if (iteration_count != other.iteration_count) {
return false;
}
if (state_initialized != other.state_initialized) {
return false;
}
if (converged != other.converged) {
return false;
}
return true;
}
} // namespace core

} // namespace duckpgq
1 change: 1 addition & 0 deletions src/core/functions/scalar/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/iterativelength.cpp
${CMAKE_CURRENT_SOURCE_DIR}/iterativelength2.cpp
${CMAKE_CURRENT_SOURCE_DIR}/iterativelength_bidirectional.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pagerank.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reachability.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shortest_path.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csr_creation.cpp
Expand Down
130 changes: 130 additions & 0 deletions src/core/functions/scalar/pagerank.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckpgq/common.hpp"
#include "duckpgq/core/functions/function_data/pagerank_function_data.hpp"
#include <duckpgq/core/functions/scalar.hpp>
#include <duckpgq/core/functions/table/pagerank.hpp>
#include <duckpgq/core/utils/duckpgq_bitmap.hpp>
#include <duckpgq/core/utils/duckpgq_utils.hpp>

namespace duckpgq {
namespace core {

// Scalar function `pagerank(csr_id, node_id)`.
//
// Looks up the CSR selected at bind time, runs PageRank over it until the
// largest per-node change drops below the convergence threshold, and then
// writes the rank of every node id found in the second argument column into
// `result`. NULL or out-of-range node ids produce NULL.
//
// The ranks live in the bind data and are computed only once: the first call
// to take `state_lock` initialises the vectors and iterates to convergence,
// later calls find `converged` set and skip straight to the lookup.
//
// Throws ConstraintException when the CSR does not exist or has not had both
// its vertex and edge arrays initialised.
static void PageRankFunction(DataChunk &args,
                             ExpressionState &state,
                             Vector &result) {
  auto &func_expr = (BoundFunctionExpression &)state.expr;
  auto &info = (PageRankFunctionData &)*func_expr.bind_info;
  auto duckpgq_state = GetDuckPGQState(info.context);

  // Locate the CSR representation of the graph
  auto csr_entry = duckpgq_state->csr_list.find((uint64_t)info.csr_id);
  if (csr_entry == duckpgq_state->csr_list.end()) {
    throw ConstraintException("CSR not found. Is the graph populated?");
  }

  if (!(csr_entry->second->initialized_v && csr_entry->second->initialized_e)) {
    throw ConstraintException("Need to initialize CSR before running PageRank.");
  }

  // Reuse the iterator found above instead of looking the CSR up again.
  int64_t *v = (int64_t *)csr_entry->second->v;
  vector<int64_t> &e = csr_entry->second->e;
  const size_t v_size = csr_entry->second->vsize;

  {
    // Every read and write of the shared PageRank state happens under the
    // lock. Checking `state_initialized` / `converged` before locking would
    // let a second thread resize the vectors concurrently, or re-run an
    // iteration (and swap `rank`) while another thread is already reading
    // the final ranks.
    std::lock_guard<std::mutex> guard(info.state_lock);

    // State initialization (only once)
    if (!info.state_initialized) {
      info.rank.resize(v_size, 1.0 / v_size); // Uniform initial rank
      info.temp_rank.resize(v_size, 0.0);     // Scratch buffer for one iteration
      info.damping_factor = 0.85;             // Typical damping factor
      info.convergence_threshold = 1e-6;      // Convergence threshold
      info.state_initialized = true;
      info.converged = false;
      info.iteration_count = 0;
    }

    while (!info.converged) {
      std::fill(info.temp_rank.begin(), info.temp_rank.end(), 0.0);

      // Rank held by nodes without outgoing edges; redistributed evenly below.
      double total_dangling_rank = 0.0;

      for (size_t i = 0; i < v_size; i++) {
        const int64_t start_edge = v[i];
        // The last vertex has no successor offset, so its edges run to the
        // end of the edge array.
        const int64_t end_edge =
            (i + 1 < v_size) ? v[i + 1] : static_cast<int64_t>(e.size());
        if (end_edge > start_edge) {
          const double rank_contrib = info.rank[i] / (end_edge - start_edge);
          for (int64_t j = start_edge; j < end_edge; j++) {
            const int64_t neighbor = e[j];
            info.temp_rank[neighbor] += rank_contrib;
          }
        } else {
          total_dangling_rank += info.rank[i];
        }
      }

      // Apply damping factor and handle dangling node ranks
      const double correction_factor = total_dangling_rank / v_size;
      double max_delta = 0.0;
      for (size_t i = 0; i < v_size; i++) {
        info.temp_rank[i] =
            (1 - info.damping_factor) / v_size +
            info.damping_factor * (info.temp_rank[i] + correction_factor);
        max_delta =
            std::max(max_delta, std::abs(info.temp_rank[i] - info.rank[i]));
      }

      info.rank.swap(info.temp_rank);
      info.iteration_count++;
      if (max_delta < info.convergence_threshold) {
        info.converged = true;
      }
    }
  }
  // From here on `converged` is set, so no thread modifies `info.rank` again
  // and it can be read without holding the lock.

  // Get the source vector for the current DataChunk
  auto &src = args.data[1];
  UnifiedVectorFormat vdata_src;
  src.ToUnifiedFormat(args.size(), vdata_src);
  auto src_data = (int64_t *)vdata_src.data;

  // Prepare the result as a flat vector before touching its validity mask
  result.SetVectorType(VectorType::FLAT_VECTOR);
  ValidityMask &result_validity = FlatVector::Validity(result);
  auto result_data = FlatVector::GetData<double_t>(result);

  // Output the PageRank value corresponding to each source ID in the DataChunk
  for (idx_t i = 0; i < args.size(); i++) {
    auto id_pos = vdata_src.sel->get_index(i);
    if (!vdata_src.validity.RowIsValid(id_pos)) {
      result_validity.SetInvalid(i);
      continue; // NULL input -> NULL output
    }
    auto node_id = src_data[id_pos];
    if (node_id < 0 || node_id >= (int64_t)v_size) {
      result_validity.SetInvalid(i);
      continue; // Unknown node -> NULL output
    }
    result_data[i] = info.rank[node_id];
  }

  duckpgq_state->csr_to_delete.insert(info.csr_id);
}

//------------------------------------------------------------------------------
// Register functions
//------------------------------------------------------------------------------
// Registers the scalar function `pagerank(INTEGER, BIGINT) -> DOUBLE` with
// the database, using PageRankFunction as its implementation and
// PageRankFunctionData::PageRankBind as its bind callback.
void CoreScalarFunctions::RegisterPageRankScalarFunction(
    DatabaseInstance &db) {
  ScalarFunction pagerank_scalar(
      "pagerank", {LogicalType::INTEGER, LogicalType::BIGINT},
      LogicalType::DOUBLE, PageRankFunction,
      PageRankFunctionData::PageRankBind);

  ExtensionUtil::RegisterFunction(db, std::move(pagerank_scalar));
}

} // namespace core
} // namespace duckpgq
1 change: 1 addition & 0 deletions src/core/functions/table/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/drop_property_graph.cpp
${CMAKE_CURRENT_SOURCE_DIR}/local_clustering_coefficient.cpp
${CMAKE_CURRENT_SOURCE_DIR}/match.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pagerank.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pgq_scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/weakly_connected_component.cpp
${EXTENSION_SOURCES}
Expand Down
42 changes: 42 additions & 0 deletions src/core/functions/table/pagerank.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "duckpgq/core/functions/table/pagerank.hpp"
#include "duckdb/function/table_function.hpp"
#include "duckdb/parser/tableref/subqueryref.hpp"

#include <duckpgq/core/functions/table.hpp>
#include <duckpgq/core/utils/duckpgq_utils.hpp>
#include "duckdb/parser/tableref/basetableref.hpp"

namespace duckpgq {
namespace core {

// Bind-replace callback of the PageRank table function.
//
// Takes three string inputs (property graph name, node table name and edge
// table name, all lower-cased before use), validates them against the
// registered property graph and replaces the table function call with a
// subquery that builds a directed CSR in the `csr_cte` CTE and selects the
// `pagerank` scalar function over it.
//
// Returns the generated subquery, aliased `pagerank`.
unique_ptr<TableRef> PageRankFunction::PageRankBindReplace(ClientContext &context, TableFunctionBindInput &input) {
  auto pg_name = StringUtil::Lower(StringValue::Get(input.inputs[0]));
  auto node_table = StringUtil::Lower(StringValue::Get(input.inputs[1]));
  auto edge_table = StringUtil::Lower(StringValue::Get(input.inputs[2]));

  auto duckpgq_state = GetDuckPGQState(context);
  auto pg_info = GetPropertyGraphInfo(duckpgq_state, pg_name);
  auto edge_pg_entry = ValidateSourceNodeAndEdgeTable(pg_info, node_table, edge_table);

  auto select_node = CreateSelectNode(edge_pg_entry, "pagerank", "pagerank");

  select_node->cte_map.map["csr_cte"] = CreateDirectedCSRCTE(edge_pg_entry, "src", "edge", "dst");

  auto subquery = make_uniq<SelectStatement>();
  subquery->node = std::move(select_node);

  auto result = make_uniq<SubqueryRef>(std::move(subquery));
  // Name the subquery after this function rather than "wcc", which was a
  // leftover from the weakly connected component table function.
  result->alias = "pagerank";
  return std::move(result);
}

//------------------------------------------------------------------------------
// Register functions
//------------------------------------------------------------------------------
// Registers a default-constructed PageRank table function with the database.
void CoreTableFunctions::RegisterPageRankTableFunction(DatabaseInstance &db) {
  auto pagerank_table_function = PageRankFunction();
  ExtensionUtil::RegisterFunction(db, std::move(pagerank_table_function));
}

} // namespace core
} // namespace duckpgq
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//===----------------------------------------------------------------------===//
// DuckPGQ
//
// duckpgq/functions/function_data/cheapest_path_length_function_data.hpp
// duckpgq/core/functions/function_data/cheapest_path_length_function_data.hpp
//
//
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//===----------------------------------------------------------------------===//
// DuckPGQ
//
// duckpgq/functions/function_data/iterative_length_function_data.hpp
// duckpgq/core/functions/function_data/iterative_length_function_data.hpp
//
//
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===----------------------------------------------------------------------===//
// DuckPGQ
//
// duckpgq/core/functions/function_data/pagerank_function_data.hpp
//
//
//===----------------------------------------------------------------------===//


#pragma once
#include "duckdb/main/client_context.hpp"
#include "duckpgq/common.hpp"

namespace duckpgq {
namespace core {
// Bind data of the `pagerank` scalar function.
//
// Besides the id of the CSR to run on, it carries the complete state of the
// PageRank computation so that the ranks are computed once and then shared
// by every invocation of the function that uses this bind data.
struct PageRankFunctionData final : FunctionData {
  ClientContext &context;         // Client context the function was bound in
  int32_t csr_id;                 // Id of the CSR to compute PageRank on
  vector<double_t> rank;          // Current rank per node
  vector<double_t> temp_rank;     // Scratch buffer for the next iteration
  double_t damping_factor;        // Initialised to 0.85
  double_t convergence_threshold; // Initialised to 1e-6
  int64_t iteration_count;        // Number of iterations performed so far
  std::mutex state_lock;          // Serialises the PageRank iteration
  bool state_initialized;         // True once the rank vectors are sized
  bool converged;                 // True once the ranks have converged

  // Creates an uninitialised PageRank state for the CSR with the given id.
  PageRankFunctionData(ClientContext &context, int32_t csr_id);

  // Bind callback: evaluates the constant CSR id argument and creates the
  // function data. Throws InvalidInputException for a non-constant id.
  static unique_ptr<FunctionData>
  PageRankBind(ClientContext &context, ScalarFunction &bound_function,
               vector<unique_ptr<Expression>> &arguments);

  // Copies everything except the mutex, which is created fresh in the copy.
  unique_ptr<FunctionData> Copy() const override;
  // Compares the CSR id, parameters, progress flags and rank vectors.
  bool Equals(const FunctionData &other_p) const override;
};


} // namespace core

} // namespace duckpgq
3 changes: 3 additions & 0 deletions src/include/duckpgq/core/functions/scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct CoreScalarFunctions {
RegisterReachabilityScalarFunction(db);
RegisterShortestPathScalarFunction(db);
RegisterWeaklyConnectedComponentScalarFunction(db);
RegisterPageRankScalarFunction(db);
}

private:
Expand All @@ -32,6 +33,8 @@ struct CoreScalarFunctions {
static void RegisterReachabilityScalarFunction(DatabaseInstance &db);
static void RegisterShortestPathScalarFunction(DatabaseInstance &db);
static void RegisterWeaklyConnectedComponentScalarFunction(DatabaseInstance &db);
static void RegisterPageRankScalarFunction(DatabaseInstance &db);

};


Expand Down
2 changes: 2 additions & 0 deletions src/include/duckpgq/core/functions/table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct CoreTableFunctions {
RegisterLocalClusteringCoefficientTableFunction(db);
RegisterScanTableFunctions(db);
RegisterWeaklyConnectedComponentTableFunction(db);
RegisterPageRankTableFunction(db);
}

private:
Expand All @@ -24,6 +25,7 @@ struct CoreTableFunctions {
static void RegisterLocalClusteringCoefficientTableFunction(DatabaseInstance &db);
static void RegisterScanTableFunctions(DatabaseInstance &db);
static void RegisterWeaklyConnectedComponentTableFunction(DatabaseInstance &db);
static void RegisterPageRankTableFunction(DatabaseInstance &db);
};


Expand Down
Loading

0 comments on commit 104247e

Please sign in to comment.