Skip to content

Commit

Permalink
Add ThreadPool class
Browse files Browse the repository at this point in the history
  • Loading branch information
cbritopacheco committed Nov 21, 2023
1 parent 1308aeb commit ad3e8db
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 44 deletions.
21 changes: 10 additions & 11 deletions examples/Variational/P1/P1Potential.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@ int main(int, char**)
mesh.scale(2. / (n - 1));
mesh.displace(VectorFunction{-1, -1});

auto it = mesh.getCell();
// for (auto it = mesh.getVertex(); it; ++it)
// {
// const auto& coords = it->getCoordinates();
// const Scalar x = coords.x();
// const Scalar y = coords.y();
// const Scalar u = x * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y);
// const Scalar v = y * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y);
// mesh.setVertexCoordinates(it->getIndex(), u, 0);
// mesh.setVertexCoordinates(it->getIndex(), v, 1);
// }
for (auto it = mesh.getVertex(); it; ++it)
{
const auto& coords = it->getCoordinates();
const Scalar x = coords.x();
const Scalar y = coords.y();
const Scalar u = x * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y);
const Scalar v = y * sqrt(x * x + y * y - x * x * y * y) / sqrt(x * x + y * y);
mesh.setVertexCoordinates(it->getIndex(), u, 0);
mesh.setVertexCoordinates(it->getIndex(), v, 1);
}
MMG::Optimizer().setHMax(0.04).setHMin(0.002).optimize(mesh);

for (const Index& i : mesh.getRidges())
Expand Down
65 changes: 34 additions & 31 deletions src/Rodin/Assembly/Multithreaded.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
*/
#include <thread>

#include "BS_thread_pool.hpp"

#include "Rodin/Variational/FiniteElementSpace.h"
#include "Rodin/Variational/LinearFormIntegrator.h"
#include "Rodin/Variational/BilinearFormIntegrator.h"
Expand All @@ -23,29 +21,28 @@ namespace Rodin::Assembly

Multithreaded<Variational::BilinearFormBase<Math::SparseMatrix>>
::Multithreaded(size_t threadCount)
: m_threadCount(threadCount)
: m_assembly(threadCount)

Check warning on line 24 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L24

Added line #L24 was not covered by tests
{
assert(threadCount > 0);
}

Multithreaded<Variational::BilinearFormBase<Math::SparseMatrix>>
::Multithreaded(const Multithreaded& other)
: Parent(other),
m_threadCount(other.m_threadCount)
m_assembly(other.m_assembly)

Check warning on line 32 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L32

Added line #L32 was not covered by tests
{}

Multithreaded<Variational::BilinearFormBase<Math::SparseMatrix>>
::Multithreaded(Multithreaded&& other)
: Parent(std::move(other)),
m_threadCount(std::move(other.m_threadCount))
m_assembly(std::move(other.m_assembly))

Check warning on line 38 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L38

Added line #L38 was not covered by tests
{}

Math::SparseMatrix
Multithreaded<Variational::BilinearFormBase<Math::SparseMatrix>>
::execute(const BilinearAssemblyInput& input) const
{
Multithreaded<Variational::BilinearFormBase<std::vector<Eigen::Triplet<Scalar>>>> assembly(m_threadCount);
const auto triplets = assembly.execute(input);
const auto triplets = m_assembly.execute(input);

Check warning on line 45 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L45

Added line #L45 was not covered by tests
OperatorType res(input.testFES.getSize(), input.trialFES.getSize());
res.setFromTriplets(triplets.begin(), triplets.end());
return res;
Expand All @@ -58,21 +55,24 @@ namespace Rodin::Assembly

Multithreaded<Variational::BilinearFormBase<std::vector<Eigen::Triplet<Scalar>>>>
::Multithreaded(size_t threadCount)
: m_threadCount(threadCount)
: m_threadCount(threadCount),
m_pool(threadCount)

Check warning on line 59 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L58-L59

Added lines #L58 - L59 were not covered by tests
{
assert(threadCount > 0);
}

Multithreaded<Variational::BilinearFormBase<std::vector<Eigen::Triplet<Scalar>>>>
::Multithreaded(const Multithreaded& other)
: Parent(other),
m_threadCount(other.m_threadCount)
m_threadCount(other.m_threadCount),
m_pool(m_threadCount)

Check warning on line 68 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L67-L68

Added lines #L67 - L68 were not covered by tests
{}

Multithreaded<Variational::BilinearFormBase<std::vector<Eigen::Triplet<Scalar>>>>
::Multithreaded(Multithreaded&& other)
: Parent(std::move(other)),
m_threadCount(std::move(other.m_threadCount))
m_threadCount(std::move(other.m_threadCount)),
m_pool(m_threadCount)

Check warning on line 75 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L74-L75

Added lines #L74 - L75 were not covered by tests
{}

void
Expand Down Expand Up @@ -111,7 +111,7 @@ namespace Rodin::Assembly
TripletVector res;
res.clear();
res.reserve(input.testFES.getSize() * std::log(input.trialFES.getSize()));
BS::thread_pool threadPool(m_threadCount);
auto& threadPool = m_pool;

Check warning on line 114 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L114

Added line #L114 was not covered by tests
for (auto& bfi : input.bfis)
{
const auto& attrs = bfi.getAttributes();
Expand Down Expand Up @@ -144,8 +144,8 @@ namespace Rodin::Assembly
m_mutex.unlock();
tl_triplets.clear();
};
threadPool.push_loop(0, input.mesh.getCellCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getCellCount(), loop);
threadPool.waitForTasks();

Check warning on line 148 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L147-L148

Added lines #L147 - L148 were not covered by tests
break;
}
case Variational::Integrator::Region::Faces:
Expand Down Expand Up @@ -175,8 +175,8 @@ namespace Rodin::Assembly
m_mutex.unlock();
tl_triplets.clear();
};
threadPool.push_loop(0, input.mesh.getFaceCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getFaceCount(), loop);
threadPool.waitForTasks();

Check warning on line 179 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L178-L179

Added lines #L178 - L179 were not covered by tests
break;
}
case Variational::Integrator::Region::Boundary:
Expand Down Expand Up @@ -209,8 +209,8 @@ namespace Rodin::Assembly
m_mutex.unlock();
tl_triplets.clear();
};
threadPool.push_loop(0, input.mesh.getFaceCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getFaceCount(), loop);
threadPool.waitForTasks();

Check warning on line 213 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L212-L213

Added lines #L212 - L213 were not covered by tests
break;
}
case Variational::Integrator::Region::Interface:
Expand Down Expand Up @@ -243,8 +243,8 @@ namespace Rodin::Assembly
m_mutex.unlock();
tl_triplets.clear();
};
threadPool.push_loop(0, input.mesh.getFaceCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getFaceCount(), loop);
threadPool.waitForTasks();

Check warning on line 247 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L246-L247

Added lines #L246 - L247 were not covered by tests
break;
}
}
Expand All @@ -267,21 +267,24 @@ namespace Rodin::Assembly

Multithreaded<Variational::LinearFormBase<Math::Vector>>
::Multithreaded(size_t threadCount)
: m_threadCount(threadCount)
: m_threadCount(threadCount),
m_pool(threadCount)
{
assert(threadCount > 0);
}

Multithreaded<Variational::LinearFormBase<Math::Vector>>
::Multithreaded(const Multithreaded& other)
: Parent(other),
m_threadCount(other.m_threadCount)
m_threadCount(other.m_threadCount),
m_pool(m_threadCount)

Check warning on line 280 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L279-L280

Added lines #L279 - L280 were not covered by tests
{}

Multithreaded<Variational::LinearFormBase<Math::Vector>>
::Multithreaded(Multithreaded&& other)
: Parent(std::move(other)),
m_threadCount(std::move(other.m_threadCount))
m_threadCount(std::move(other.m_threadCount)),
m_pool(m_threadCount)

Check warning on line 287 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L286-L287

Added lines #L286 - L287 were not covered by tests
{}

void
Expand All @@ -298,9 +301,9 @@ namespace Rodin::Assembly
Multithreaded<Variational::LinearFormBase<Math::Vector>>
::execute(const Input& input) const
{
BS::thread_pool threadPool(m_threadCount);
VectorType res(input.fes.getSize());
res.setZero();
auto& threadPool = m_pool;
for (auto& lfi : input.lfis)
{
const auto& attrs = lfi.getAttributes();
Expand Down Expand Up @@ -329,8 +332,8 @@ namespace Rodin::Assembly
res += tl_res;
m_mutex.unlock();
};
threadPool.push_loop(0, input.mesh.getCellCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getCellCount(), loop);
threadPool.waitForTasks();
break;
}
case Variational::Integrator::Region::Faces:
Expand All @@ -356,8 +359,8 @@ namespace Rodin::Assembly
res += tl_res;
m_mutex.unlock();
};
threadPool.push_loop(0, input.mesh.getFaceCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getFaceCount(), loop);
threadPool.waitForTasks();

Check warning on line 363 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L362-L363

Added lines #L362 - L363 were not covered by tests
break;
}
case Variational::Integrator::Region::Boundary:
Expand Down Expand Up @@ -386,8 +389,8 @@ namespace Rodin::Assembly
res += tl_res;
m_mutex.unlock();
};
threadPool.push_loop(0, input.mesh.getFaceCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getFaceCount(), loop);
threadPool.waitForTasks();

Check warning on line 393 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L392-L393

Added lines #L392 - L393 were not covered by tests
break;
}
case Variational::Integrator::Region::Interface:
Expand Down Expand Up @@ -416,8 +419,8 @@ namespace Rodin::Assembly
res += tl_res;
m_mutex.unlock();
};
threadPool.push_loop(0, input.mesh.getFaceCount(), loop);
threadPool.wait_for_tasks();
threadPool.pushLoop(0, input.mesh.getFaceCount(), loop);
threadPool.waitForTasks();

Check warning on line 423 in src/Rodin/Assembly/Multithreaded.cpp

View check run for this annotation

Codecov / codecov/patch

src/Rodin/Assembly/Multithreaded.cpp#L422-L423

Added lines #L422 - L423 were not covered by tests
break;
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/Rodin/Assembly/Multithreaded.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "Rodin/Math/Vector.h"
#include "Rodin/Math/SparseMatrix.h"

#include "Rodin/Threads/ThreadPool.h"

#include "Rodin/Threads/Mutex.h"

#include "ForwardDecls.h"
Expand Down Expand Up @@ -57,8 +59,8 @@ namespace Rodin::Assembly
static thread_local std::unique_ptr<Variational::BilinearFormIntegratorBase> tl_bfi;

const size_t m_threadCount;

mutable Threads::Mutex m_mutex;
mutable Threads::ThreadPool m_pool;
};

/**
Expand Down Expand Up @@ -93,7 +95,7 @@ namespace Rodin::Assembly
}

private:
const size_t m_threadCount;
Multithreaded<Variational::BilinearFormBase<std::vector<Eigen::Triplet<Scalar>>>> m_assembly;
};

/**
Expand Down Expand Up @@ -136,6 +138,7 @@ namespace Rodin::Assembly

const size_t m_threadCount;
mutable Threads::Mutex m_mutex;
mutable Threads::ThreadPool m_pool;
};
}

Expand Down
File renamed without changes.
32 changes: 32 additions & 0 deletions src/Rodin/Threads/ThreadPool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef RODIN_THREADPOOL_H
#define RODIN_THREADPOOL_H

#include "BS_thread_pool.hpp"

namespace Rodin::Threads
{
class ThreadPool
{
public:
ThreadPool(size_t numThreads)
: m_pool(numThreads)
{}

template <class ... Args>
void pushLoop(Args... args)
{
m_pool.push_loop(std::forward<Args>(args)...);
}

void waitForTasks()
{
m_pool.wait_for_tasks();
}

private:
BS::thread_pool m_pool;
};
}

#endif

0 comments on commit ad3e8db

Please sign in to comment.