diff --git a/duckpgq/src/duckpgq/functions/scalar/shortest_path_lowerbound.cpp b/duckpgq/src/duckpgq/functions/scalar/shortest_path_lowerbound.cpp index 91f852ea..97086ba1 100644 --- a/duckpgq/src/duckpgq/functions/scalar/shortest_path_lowerbound.cpp +++ b/duckpgq/src/duckpgq/functions/scalar/shortest_path_lowerbound.cpp @@ -11,7 +11,8 @@ namespace duckdb { -static bool IterativeLengthLowerBound(int64_t v_size, int64_t iter, bool seen_check, +template +static bool IterativeLengthLowerBound(int64_t v_size, int64_t iter, int64_t *V, vector &E, int16_t lane_to_num[LANE_LIMIT], vector &edge_ids, @@ -42,7 +43,7 @@ static bool IterativeLengthLowerBound(int64_t v_size, int64_t iter, bool seen_ch } for (auto v = 0; v < v_size; v++) { - if (seen_check) { + if (SEEN_CHECK) { next[v] = next[v] & ~seen[v]; seen[v] = seen[v] | next[v]; } @@ -98,8 +99,6 @@ static void ShortestPathLowerBoundFunction(DataChunk &args, auto result_data = FlatVector::GetData(result); ValidityMask &result_validity = FlatVector::Validity(result); - bool seen_check = false; - // create temp SIMD arrays vector> seen(v_size); vector> visit1(v_size); @@ -119,8 +118,6 @@ static void ShortestPathLowerBoundFunction(DataChunk &args, idx_t started_searches = 0; while (started_searches < args.size()) { - seen_check = false; - // empty visit vectors for (auto i = 0; i < v_size; i++) { seen[i] = 0; @@ -151,17 +148,19 @@ static void ShortestPathLowerBoundFunction(DataChunk &args, //! make passes while a lane is still active for (int64_t iter = 1; active && iter <= upper_bound; iter++) { - if (iter >= lower_bound) { - seen_check = true; - } - //! Perform one step of bfs exploration - if (!IterativeLengthLowerBound( - v_size, iter, seen_check, v, e, lane_to_num, edge_ids, paths_v, paths_e, seen, - (iter & 1) ? visit1 : visit2, (iter & 1) ? visit2 : visit1)) { - break; - } if (iter < lower_bound) { + if (!IterativeLengthLowerBound( + v_size, iter, v, e, lane_to_num, edge_ids, paths_v, paths_e, seen, + (iter & 1) ? visit1 : visit2, (iter & 1) ? visit2 : visit1)) { + break; + } continue; + } else { + if (!IterativeLengthLowerBound( + v_size, iter, v, e, lane_to_num, edge_ids, paths_v, paths_e, seen, + (iter & 1) ? visit1 : visit2, (iter & 1) ? visit2 : visit1)) { + break; + } } // detect lanes that finished for (int64_t lane = 0; lane < LANE_LIMIT; lane++) {