From ad3e8db646081250b0b695900608a0392c6f82e8 Mon Sep 17 00:00:00 2001 From: Carlos Brito Date: Tue, 21 Nov 2023 13:15:13 +0100 Subject: [PATCH] Add ThreadPool class --- examples/Variational/P1/P1Potential.cpp | 21 +++--- src/Rodin/Assembly/Multithreaded.cpp | 65 ++++++++++--------- src/Rodin/Assembly/Multithreaded.h | 7 +- .../{Assembly => Threads}/BS_thread_pool.hpp | 0 src/Rodin/Threads/ThreadPool.h | 32 +++++++++ 5 files changed, 81 insertions(+), 44 deletions(-) rename src/Rodin/{Assembly => Threads}/BS_thread_pool.hpp (100%) create mode 100644 src/Rodin/Threads/ThreadPool.h diff --git a/examples/Variational/P1/P1Potential.cpp b/examples/Variational/P1/P1Potential.cpp index 538fc254a..c2611308f 100644 --- a/examples/Variational/P1/P1Potential.cpp +++ b/examples/Variational/P1/P1Potential.cpp @@ -28,17 +28,16 @@ int main(int, char**) mesh.scale(2. / (n - 1)); mesh.displace(VectorFunction{-1, -1}); - auto it = mesh.getCell(); - // for (auto it = mesh.getVertex(); it; ++it) - // { - // const auto& coords = it->getCoordinates(); - // const Scalar x = coords.x(); - // const Scalar y = coords.y(); - // const Scalar u = x * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y); - // const Scalar v = y * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y); - // mesh.setVertexCoordinates(it->getIndex(), u, 0); - // mesh.setVertexCoordinates(it->getIndex(), v, 1); - // } + for (auto it = mesh.getVertex(); it; ++it) + { + const auto& coords = it->getCoordinates(); + const Scalar x = coords.x(); + const Scalar y = coords.y(); + const Scalar u = x * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y); + const Scalar v = y * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y); + mesh.setVertexCoordinates(it->getIndex(), u, 0); + mesh.setVertexCoordinates(it->getIndex(), v, 1); + } MMG::Optimizer().setHMax(0.04).setHMin(0.002).optimize(mesh); for (const Index& i : mesh.getRidges()) diff --git a/src/Rodin/Assembly/Multithreaded.cpp b/src/Rodin/Assembly/Multithreaded.cpp index a56e840fc..e7127a0b7 100644 --- a/src/Rodin/Assembly/Multithreaded.cpp +++ b/src/Rodin/Assembly/Multithreaded.cpp @@ -6,8 +6,6 @@ */ #include -#include "BS_thread_pool.hpp" - #include "Rodin/Variational/FiniteElementSpace.h" #include "Rodin/Variational/LinearFormIntegrator.h" #include "Rodin/Variational/BilinearFormIntegrator.h" @@ -23,7 +21,7 @@ namespace Rodin::Assembly Multithreaded> ::Multithreaded(size_t threadCount) - : m_threadCount(threadCount) + : m_assembly(threadCount) { assert(threadCount > 0); } @@ -31,21 +29,20 @@ namespace Rodin::Assembly Multithreaded> ::Multithreaded(const Multithreaded& other) : Parent(other), - m_threadCount(other.m_threadCount) + m_assembly(other.m_assembly) {} Multithreaded> ::Multithreaded(Multithreaded&& other) : Parent(std::move(other)), - m_threadCount(std::move(other.m_threadCount)) + m_assembly(std::move(other.m_assembly)) {} Math::SparseMatrix Multithreaded> ::execute(const BilinearAssemblyInput& input) const { - Multithreaded>>> assembly(m_threadCount); - const auto triplets = assembly.execute(input); + const auto triplets = m_assembly.execute(input); OperatorType res(input.testFES.getSize(), input.trialFES.getSize()); res.setFromTriplets(triplets.begin(), triplets.end()); return res; @@ -58,7 +55,8 @@ namespace Rodin::Assembly Multithreaded>>> ::Multithreaded(size_t threadCount) - : m_threadCount(threadCount) + : m_threadCount(threadCount), + m_pool(threadCount) { assert(threadCount > 0); } @@ -66,13 +64,15 @@ namespace Rodin::Assembly Multithreaded>>> ::Multithreaded(const Multithreaded& other) : Parent(other), - m_threadCount(other.m_threadCount) + m_threadCount(other.m_threadCount), + m_pool(m_threadCount) {} Multithreaded>>> ::Multithreaded(Multithreaded&& other) : Parent(std::move(other)), - m_threadCount(std::move(other.m_threadCount)) + m_threadCount(std::move(other.m_threadCount)), + m_pool(m_threadCount) {} void @@ -111,7 +111,7 @@ namespace Rodin::Assembly TripletVector res; res.clear(); res.reserve(input.testFES.getSize() * std::log(input.trialFES.getSize())); - BS::thread_pool threadPool(m_threadCount); + auto& threadPool = m_pool; for (auto& bfi : input.bfis) { const auto& attrs = bfi.getAttributes(); @@ -144,8 +144,8 @@ namespace Rodin::Assembly m_mutex.unlock(); tl_triplets.clear(); }; - threadPool.push_loop(0, input.mesh.getCellCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getCellCount(), loop); + threadPool.waitForTasks(); break; } case Variational::Integrator::Region::Faces: @@ -175,8 +175,8 @@ namespace Rodin::Assembly m_mutex.unlock(); tl_triplets.clear(); }; - threadPool.push_loop(0, input.mesh.getFaceCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getFaceCount(), loop); + threadPool.waitForTasks(); break; } case Variational::Integrator::Region::Boundary: @@ -209,8 +209,8 @@ namespace Rodin::Assembly m_mutex.unlock(); tl_triplets.clear(); }; - threadPool.push_loop(0, input.mesh.getFaceCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getFaceCount(), loop); + threadPool.waitForTasks(); break; } case Variational::Integrator::Region::Interface: @@ -243,8 +243,8 @@ namespace Rodin::Assembly m_mutex.unlock(); tl_triplets.clear(); }; - threadPool.push_loop(0, input.mesh.getFaceCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getFaceCount(), loop); + threadPool.waitForTasks(); break; } } @@ -267,7 +267,8 @@ namespace Rodin::Assembly Multithreaded> ::Multithreaded(size_t threadCount) - : m_threadCount(threadCount) + : m_threadCount(threadCount), + m_pool(threadCount) { assert(threadCount > 0); } @@ -275,13 +276,15 @@ namespace Rodin::Assembly Multithreaded> ::Multithreaded(const Multithreaded& other) : Parent(other), - m_threadCount(other.m_threadCount) + m_threadCount(other.m_threadCount), + m_pool(m_threadCount) {} Multithreaded> ::Multithreaded(Multithreaded&& other) : Parent(std::move(other)), - m_threadCount(std::move(other.m_threadCount)) + m_threadCount(std::move(other.m_threadCount)), + m_pool(m_threadCount) {} void @@ -298,9 +301,9 @@ namespace Rodin::Assembly Multithreaded> ::execute(const Input& input) const { - BS::thread_pool threadPool(m_threadCount); VectorType res(input.fes.getSize()); res.setZero(); + auto& threadPool = m_pool; for (auto& lfi : input.lfis) { const auto& attrs = lfi.getAttributes(); @@ -329,8 +332,8 @@ namespace Rodin::Assembly res += tl_res; m_mutex.unlock(); }; - threadPool.push_loop(0, input.mesh.getCellCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getCellCount(), loop); + threadPool.waitForTasks(); break; } case Variational::Integrator::Region::Faces: @@ -356,8 +359,8 @@ namespace Rodin::Assembly res += tl_res; m_mutex.unlock(); }; - threadPool.push_loop(0, input.mesh.getFaceCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getFaceCount(), loop); + threadPool.waitForTasks(); break; } case Variational::Integrator::Region::Boundary: @@ -386,8 +389,8 @@ namespace Rodin::Assembly res += tl_res; m_mutex.unlock(); }; - threadPool.push_loop(0, input.mesh.getFaceCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getFaceCount(), loop); + threadPool.waitForTasks(); break; } case Variational::Integrator::Region::Interface: @@ -416,8 +419,8 @@ namespace Rodin::Assembly res += tl_res; m_mutex.unlock(); }; - threadPool.push_loop(0, input.mesh.getFaceCount(), loop); - threadPool.wait_for_tasks(); + threadPool.pushLoop(0, input.mesh.getFaceCount(), loop); + threadPool.waitForTasks(); break; } } diff --git a/src/Rodin/Assembly/Multithreaded.h b/src/Rodin/Assembly/Multithreaded.h index f100ff140..c2a20b91d 100644 --- a/src/Rodin/Assembly/Multithreaded.h +++ b/src/Rodin/Assembly/Multithreaded.h @@ -10,6 +10,8 @@ #include "Rodin/Math/Vector.h" #include "Rodin/Math/SparseMatrix.h" +#include "Rodin/Threads/ThreadPool.h" + #include "Rodin/Threads/Mutex.h" #include "ForwardDecls.h" @@ -57,8 +59,8 @@ namespace Rodin::Assembly static thread_local std::unique_ptr tl_bfi; const size_t m_threadCount; - mutable Threads::Mutex m_mutex; + mutable Threads::ThreadPool m_pool; }; /** @@ -93,7 +95,7 @@ namespace Rodin::Assembly } private: - const size_t m_threadCount; + Multithreaded>>> m_assembly; }; /** @@ -136,6 +138,7 @@ namespace Rodin::Assembly const size_t m_threadCount; mutable Threads::Mutex m_mutex; + mutable Threads::ThreadPool m_pool; }; } diff --git a/src/Rodin/Assembly/BS_thread_pool.hpp b/src/Rodin/Threads/BS_thread_pool.hpp similarity index 100% rename from src/Rodin/Assembly/BS_thread_pool.hpp rename to src/Rodin/Threads/BS_thread_pool.hpp diff --git a/src/Rodin/Threads/ThreadPool.h b/src/Rodin/Threads/ThreadPool.h new file mode 100644 index 000000000..1d9ea022e --- /dev/null +++ b/src/Rodin/Threads/ThreadPool.h @@ -0,0 +1,32 @@ +#ifndef RODIN_THREADPOOL_H +#define RODIN_THREADPOOL_H + +#include "BS_thread_pool.hpp" + +namespace Rodin::Threads +{ + class ThreadPool + { + public: + ThreadPool(size_t numThreads) + : m_pool(numThreads) + {} + + template + void pushLoop(Args... args) + { + m_pool.push_loop(std::forward(args)...); + } + + void waitForTasks() + { + m_pool.wait_for_tasks(); + } + + private: + BS::thread_pool m_pool; + }; +} + +#endif +