Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash constructing BM3D_FilterData, data race with std::unordered_map #29

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/BM3D.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ struct BM3D_FilterData

return *this;
}

static std::mutex s_fftw_planner_mutex;
};


Expand Down
23 changes: 17 additions & 6 deletions include/BM3D_Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define BM3D_BASE_H_


#include <mutex>
#include <unordered_map>
#include <thread>
#include "BM3D.h"
Expand Down Expand Up @@ -54,6 +55,7 @@ class BM3D_Data_Base
_Mypara para;
std::vector<BM3D_FilterData> f;

std::mutex mutex0, mutex1, mutex2;
std::unordered_map<std::thread::id, FLType *> buffer0, buffer1, buffer2;

public:
Expand All @@ -72,17 +74,26 @@ class BM3D_Data_Base
{
if (rdef && rnode) vsapi->freeNode(rnode);

for (auto &e : buffer0)
{
AlignedFree(e.second);
std::lock_guard<std::mutex> guard0(mutex0);
for (auto &e : buffer0)
{
AlignedFree(e.second);
}
}
for (auto &e : buffer1)
{
AlignedFree(e.second);
std::lock_guard<std::mutex> guard1(mutex1);
for (auto &e : buffer1)
{
AlignedFree(e.second);
}
}
for (auto &e : buffer2)
{
AlignedFree(e.second);
std::lock_guard<std::mutex> guard2(mutex2);
for (auto &e : buffer2)
{
AlignedFree(e.second);
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions source/BM3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ void BM3D_Para::thMSE_Default()
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Functions of struct BM3D_FilterData

std::mutex BM3D_FilterData::s_fftw_planner_mutex;

BM3D_FilterData::BM3D_FilterData(bool wiener, double sigma, PCType GroupSize, PCType BlockSize, double lambda)
: fp(GroupSize), bp(GroupSize), finalAMP(GroupSize), thrTable(wiener ? 0 : GroupSize),
Expand All @@ -143,6 +144,9 @@ BM3D_FilterData::BM3D_FilterData(bool wiener, double sigma, PCType GroupSize, PC

FLType *temp = nullptr;

// Executing a plan is thread-safe, but creating a plan is not because it mutates the planner's "wisdom".
// Consider that any number of instances of the (V)BM3D filter may exist in a single VapourSynth pipeline.
std::lock_guard<std::mutex> guard(s_fftw_planner_mutex);
for (PCType i = 1; i <= GroupSize; ++i)
{
AlignedMalloc(temp, i * BlockSize * BlockSize);
Expand Down
70 changes: 41 additions & 29 deletions source/BM3D_Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,17 @@ void BM3D_Process_Base::Kernel(FLType *dst, const FLType *src, const FLType *ref
std::thread::id threadId = std::this_thread::get_id();
FLType *ResNum = dst, *ResDen = nullptr;

if (!d.buffer0.count(threadId))
{
AlignedMalloc(ResDen, dst_pcount[0]);
d.buffer0.emplace(threadId, ResDen);
}
else
{
ResDen = d.buffer0.at(threadId);
std::lock_guard<std::mutex> guard(d.mutex0);
if (!d.buffer0.count(threadId))
{
AlignedMalloc(ResDen, dst_pcount[0]);
d.buffer0.emplace(threadId, ResDen);
}
else
{
ResDen = d.buffer0.at(threadId);
}
}

memset(ResNum, 0, sizeof(FLType) * dst_pcount[0]);
Expand Down Expand Up @@ -363,14 +366,17 @@ void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV,

if (d.process[0])
{
if (!d.buffer0.count(threadId))
{
AlignedMalloc(ResDenY, dst_pcount[0]);
d.buffer0.emplace(threadId, ResDenY);
}
else
{
ResDenY = d.buffer0.at(threadId);
std::lock_guard<std::mutex> guard(d.mutex0);
if (!d.buffer0.count(threadId))
{
AlignedMalloc(ResDenY, dst_pcount[0]);
d.buffer0.emplace(threadId, ResDenY);
}
else
{
ResDenY = d.buffer0.at(threadId);
}
}

memset(ResNumY, 0, sizeof(FLType) * dst_pcount[0]);
Expand All @@ -379,30 +385,36 @@ void BM3D_Process_Base::Kernel(FLType *dstY, FLType *dstU, FLType *dstV,

if (d.process[1])
{
if (!d.buffer1.count(threadId))
{
AlignedMalloc(ResDenU, dst_pcount[1]);
d.buffer1.emplace(threadId, ResDenU);
}
else
{
ResDenU = d.buffer1.at(threadId);
std::lock_guard<std::mutex> guard(d.mutex1);
if (!d.buffer1.count(threadId))
{
AlignedMalloc(ResDenU, dst_pcount[1]);
d.buffer1.emplace(threadId, ResDenU);
}
else
{
ResDenU = d.buffer1.at(threadId);
}
}

memset(ResNumU, 0, sizeof(FLType) * dst_pcount[1]);
memset(ResDenU, 0, sizeof(FLType) * dst_pcount[1]);
}

if (d.process[2])
{
if (!d.buffer2.count(threadId))
{
AlignedMalloc(ResDenV, dst_pcount[2]);
d.buffer2.emplace(threadId, ResDenV);
}
else
{
ResDenV = d.buffer2.at(threadId);
std::lock_guard<std::mutex> guard(d.mutex2);
if (!d.buffer2.count(threadId))
{
AlignedMalloc(ResDenV, dst_pcount[2]);
d.buffer2.emplace(threadId, ResDenV);
}
else
{
ResDenV = d.buffer2.at(threadId);
}
}

memset(ResNumV, 0, sizeof(FLType) * dst_pcount[2]);
Expand Down