Skip to content

Commit

Permalink
[perf] optimize work steal queue
Browse files Browse the repository at this point in the history
  • Loading branch information
ChunelFeng committed Oct 22, 2023
1 parent 0cbe305 commit 6af3722
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
66 changes: 51 additions & 15 deletions src/UtilsCtrl/ThreadPool/Queue/UWorkStealingQueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,20 @@
#include <deque>

#include "UQueueObject.h"
#include "../Task/UTaskInclude.h"
#include "../Lock/ULockInclude.h"

CGRAPH_NAMESPACE_BEGIN

template<typename T>
class UWorkStealingQueue : public UQueueObject {
public:
/**
* 向队列中写入信息
* @param task
*/
CVoid push(UTask&& task) {
CVoid push(T&& task) {
while (true) {
if (lock_.try_lock()) {
deque_.emplace_front(std::move(task));
deque_.emplace_front(std::forward<T>(task));
lock_.unlock();
break;
} else {
Expand All @@ -42,10 +41,47 @@ class UWorkStealingQueue : public UQueueObject {
* @param task
* @return
*/
CBool tryPush(UTask&& task) {
CBool tryPush(T&& task) {
CBool result = false;
if (lock_.try_lock()) {
deque_.emplace_back(std::move(task));
deque_.emplace_back(std::forward<T>(task));
lock_.unlock();
result = true;
}
return result;
}


/**
* 向队列中写入信息
* @param task
*/
CVoid push(std::vector<T>& tasks) {
while (true) {
if (lock_.try_lock()) {
for (const auto& task : tasks) {
deque_.emplace_front(std::forward<T>(task));
}
lock_.unlock();
break;
} else {
std::this_thread::yield();
}
}
}


/**
* 尝试批量写入内容
* @param tasks
* @return
*/
CBool tryPush(std::vector<T>& tasks) {
CBool result = false;
if (lock_.try_lock()) {
for (const auto& task : tasks) {
deque_.emplace_back(std::forward<T>(task));
}
lock_.unlock();
result = true;
}
Expand All @@ -58,12 +94,12 @@ class UWorkStealingQueue : public UQueueObject {
* @param task
* @return
*/
CBool tryPop(UTask& task) {
CBool tryPop(T& task) {
// 这里不使用raii锁,主要是考虑到多线程的情况下,可能会重复进入
bool result = false;
if (!deque_.empty() && lock_.try_lock()) {
if (!deque_.empty()) {
task = std::move(deque_.front()); // 从前方弹出
task = std::forward<T>(deque_.front()); // 从前方弹出
deque_.pop_front();
result = true;
}
Expand All @@ -80,11 +116,11 @@ class UWorkStealingQueue : public UQueueObject {
* @param maxLocalBatchSize
* @return
*/
CBool tryPop(UTaskArrRef taskArr, int maxLocalBatchSize) {
CBool tryPop(std::vector<T>& taskArr, int maxLocalBatchSize) {
bool result = false;
if (!deque_.empty() && lock_.try_lock()) {
while (!deque_.empty() && maxLocalBatchSize--) {
taskArr.emplace_back(std::move(deque_.front()));
taskArr.emplace_back(std::forward<T>(deque_.front()));
deque_.pop_front();
result = true;
}
Expand All @@ -100,11 +136,11 @@ class UWorkStealingQueue : public UQueueObject {
* @param task
* @return
*/
CBool trySteal(UTask& task) {
CBool trySteal(T& task) {
bool result = false;
if (!deque_.empty() && lock_.try_lock()) {
if (!deque_.empty()) {
task = std::move(deque_.back()); // 从后方窃取
task = std::forward<T>(deque_.back()); // 从后方窃取
deque_.pop_back();
result = true;
}
Expand All @@ -120,11 +156,11 @@ class UWorkStealingQueue : public UQueueObject {
* @param taskArr
* @return
*/
CBool trySteal(UTaskArrRef taskArr, int maxStealBatchSize) {
CBool trySteal(std::vector<T>& taskArr, int maxStealBatchSize) {
bool result = false;
if (!deque_.empty() && lock_.try_lock()) {
while (!deque_.empty() && maxStealBatchSize--) {
taskArr.emplace_back(std::move(deque_.back()));
taskArr.emplace_back(std::forward<T>(deque_.back()));
deque_.pop_back();
result = true;
}
Expand All @@ -139,7 +175,7 @@ class UWorkStealingQueue : public UQueueObject {
CGRAPH_NO_ALLOWED_COPY(UWorkStealingQueue)

private:
std::deque<UTask> deque_; // 存放任务的双向队列
std::deque<T> deque_; // 存放任务的双向队列
std::mutex lock_; // 用于处理deque_的锁
};

Expand Down
6 changes: 4 additions & 2 deletions src/UtilsCtrl/ThreadPool/Thread/UThreadPrimary.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#ifndef CGRAPH_UTHREADPRIMARY_H
#define CGRAPH_UTHREADPRIMARY_H

#include <vector>

#include "UThreadBase.h"

CGRAPH_NAMESPACE_BEGIN
Expand Down Expand Up @@ -229,8 +231,8 @@ class UThreadPrimary : public UThreadBase {

private:
int index_; // 线程index
UWorkStealingQueue primary_queue_; // 内部队列信息
UWorkStealingQueue secondary_queue_; // 第二个队列,用于减少触锁概率,提升性能
UWorkStealingQueue<UTask> primary_queue_; // 内部队列信息
UWorkStealingQueue<UTask> secondary_queue_; // 第二个队列,用于减少触锁概率,提升性能
std::vector<UThreadPrimary *>* pool_threads_; // 用于存放线程池中的线程信息
std::vector<int> steal_targets_; // 被偷的目标信息

Expand Down

0 comments on commit 6af3722

Please sign in to comment.