Rewrite job pool to support nested parallelism for improved scalability #339

Open · wants to merge 3 commits into base: master
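No description was provided with the PR; the summary below is inferred from the diff.

Previously, wait_for_all() drained the entire queue and then blocked until every job in the pool had finished, so it could not safely be called from inside a job; basis_parallel_compress() worked around this by giving each task its own dummy single-threaded job_pool. This change introduces a lightweight job_pool::token: add_job() may tag a job with a token, and wait_for_all(token*) completes as soon as just those jobs are done, executing matching queued jobs on the calling thread instead of sleeping. Jobs can therefore spawn and wait on nested jobs in the same pool, and the parallel compressor now shares a single pool across all tasks.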
18 changes: 10 additions & 8 deletions encoder/basisu_comp.cpp
@@ -515,6 +515,8 @@ namespace basisu

std::atomic<uint32_t> total_blocks_processed;
total_blocks_processed = 0;

job_pool::token token{0};

const uint32_t N = 256;
for (uint32_t block_index_iter = 0; block_index_iter < total_blocks; block_index_iter += N)
@@ -646,13 +648,13 @@ namespace basisu
}

#ifndef __EMSCRIPTEN__
});
}, &token);
#endif

} // block_index_iter

#ifndef __EMSCRIPTEN__
m_params.m_pJob_pool->wait_for_all();
m_params.m_pJob_pool->wait_for_all(&token);
#endif

if (any_failures)
@@ -750,6 +752,8 @@ namespace basisu
std::atomic<uint32_t> total_blocks_processed;
total_blocks_processed = 0;

job_pool::token token{0};

const uint32_t N = 256;
for (uint32_t block_index_iter = 0; block_index_iter < total_blocks; block_index_iter += N)
{
Expand Down Expand Up @@ -791,13 +795,13 @@ namespace basisu
}

#ifndef __EMSCRIPTEN__
});
}, &token);
#endif

} // block_index_iter

#ifndef __EMSCRIPTEN__
m_params.m_pJob_pool->wait_for_all();
m_params.m_pJob_pool->wait_for_all(&token);
#endif

if (m_params.m_rdo_uastc)
@@ -3462,7 +3466,7 @@ namespace basisu

for (uint32_t pindex = 0; pindex < params_vec.size(); pindex++)
{
jpool.add_job([pindex, &params_vec, &results_vec, &result, &opencl_failed] {
jpool.add_job([pindex, &params_vec, &results_vec, &result, &opencl_failed, &jpool] {

basis_compressor_params params = params_vec[pindex];
parallel_results& results = results_vec[pindex];
@@ -3472,9 +3476,7 @@

basis_compressor c;

// Dummy job pool
job_pool task_jpool(1);
params.m_pJob_pool = &task_jpool;
params.m_pJob_pool = &jpool;
// TODO: Remove this flag entirely
params.m_multithreading = true;

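Note on the basisu_comp.cpp hunks above: each compression stage now declares a local job_pool::token, tags the jobs it enqueues with &token, and waits on that token alone, so a stage no longer assumes it is the pool's only client. In basis_parallel_compress(), the per-task dummy job_pool(1) is gone; each task lambda captures &jpool and reuses the shared pool, which is safe once a nested wait_for_all() can make progress by running tagged jobs inline instead of blocking a worker.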
109 changes: 58 additions & 51 deletions encoder/basisu_enc.cpp
@@ -2091,7 +2091,7 @@ namespace basisu
}

job_pool::job_pool(uint32_t num_threads) :
m_num_active_jobs(0),
m_num_pending_jobs(0),
m_kill_flag(false)
{
assert(num_threads >= 1U);
@@ -2112,75 +2112,98 @@ namespace basisu
debug_printf("job_pool::~job_pool\n");

// Notify all workers that they need to die right now.
m_kill_flag = true;

m_has_work.notify_all();
{
std::unique_lock<std::mutex> lock(m_mutex);
m_kill_flag = true;
m_has_work.notify_all();
}

// Wait for all workers to die.
for (uint32_t i = 0; i < m_threads.size(); i++)
m_threads[i].join();
}

void job_pool::add_job(const std::function<void()>& job)
void job_pool::add_job(std::function<void()> job, token* tok)
{
std::unique_lock<std::mutex> lock(m_mutex);
{
std::unique_lock<std::mutex> lock(m_mutex);

m_queue.emplace_back(job);
m_queue.push_back(item{ std::move(job), tok });

const size_t queue_size = m_queue.size();
if (tok)
(*tok)++;

lock.unlock();
m_num_pending_jobs++;
}

if (queue_size > 1)
m_has_work.notify_one();
m_has_work.notify_one();
}

void job_pool::add_job(std::function<void()>&& job)
void job_pool::wait_for_all(token* tok)
{
token* wait_token = tok ? tok : &m_num_pending_jobs;

std::unique_lock<std::mutex> lock(m_mutex);

m_queue.emplace_back(std::move(job));

const size_t queue_size = m_queue.size();
while (true)
{
if (*wait_token == 0)
return;

lock.unlock();
item job;
if (!job_steal(job, tok, lock))
break;

job_run(job, lock);
}

if (queue_size > 1)
m_job_done.wait(lock, [wait_token] { return *wait_token == 0; });
}

bool job_pool::job_steal(item& job, token* tok, std::unique_lock<std::mutex>&)
{
for (size_t i = m_queue.size(); i > 0; --i)
{
m_has_work.notify_one();
item& victim = m_queue[i - 1];

if (tok == nullptr || victim.tok == tok)
{
job = std::move(victim);
victim = std::move(m_queue.back());
m_queue.pop_back();

return true;
}
}

return false;
}

void job_pool::wait_for_all()
void job_pool::job_run(item& job, std::unique_lock<std::mutex>& lock)
{
std::unique_lock<std::mutex> lock(m_mutex);
lock.unlock();

// Drain the job queue on the calling thread.
while (!m_queue.empty())
{
std::function<void()> job(m_queue.back());
m_queue.pop_back();
job.fn();

lock.unlock();
lock.lock();

job();
if (job.tok)
(*job.tok)--;

lock.lock();
}
m_num_pending_jobs--;

// The queue is empty, now wait for all active jobs to finish up.
m_no_more_jobs.wait(lock, [this]{ return !m_num_active_jobs; } );
m_job_done.notify_all();
}

void job_pool::job_thread(uint32_t index)
{
BASISU_NOTE_UNUSED(index);
//debug_printf("job_pool::job_thread: starting %u\n", index);

std::unique_lock<std::mutex> lock(m_mutex);

while (true)
{
std::unique_lock<std::mutex> lock(m_mutex);

// Wait for any jobs to be issued.
m_has_work.wait(lock, [this] { return m_kill_flag || m_queue.size(); } );

@@ -2189,26 +2212,10 @@
break;

// Get the job and execute it.
std::function<void()> job(m_queue.back());
item job = std::move(m_queue.back());
m_queue.pop_back();

++m_num_active_jobs;

lock.unlock();

job();

lock.lock();

--m_num_active_jobs;

// Now check if there are no more jobs remaining.
const bool all_done = m_queue.empty() && !m_num_active_jobs;

lock.unlock();

if (all_done)
m_no_more_jobs.notify_all();
job_run(job, lock);
}

//debug_printf("job_pool::job_thread: exiting\n");
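To make the intended usage concrete, here is a minimal sketch (not part of the PR) written against the job_pool API as declared in basisu_enc.h below; the pool size, loop bounds, and function name are arbitrary illustration values:

#include "basisu_enc.h"

using namespace basisu;

static void nested_parallelism_example()
{
    job_pool pool(8); // total threads, including the calling thread

    job_pool::token outer{0};

    for (uint32_t i = 0; i < 4; i++)
    {
        pool.add_job([&pool] {
            // A job may now enqueue further jobs on the SAME pool...
            job_pool::token inner{0};

            for (uint32_t j = 0; j < 16; j++)
                pool.add_job([] { /* per-block work */ }, &inner);

            // ...and wait for just those jobs. Rather than blocking this
            // worker thread, wait_for_all(&inner) runs queued jobs tagged
            // with 'inner' inline, so waiting inside a job does not stall
            // the pool.
            pool.wait_for_all(&inner);
        }, &outer);
    }

    // Waits only for the four outer jobs (each of which has already
    // waited on its own inner jobs before returning).
    pool.wait_for_all(&outer);
}

Under the old implementation this pattern deadlocked: a job calling wait_for_all() itself counted toward m_num_active_jobs, so the no-more-jobs condition it waited on could never become true.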
34 changes: 21 additions & 13 deletions encoder/basisu_enc.h
@@ -17,7 +17,6 @@
#include "../transcoder/basisu_transcoder_internal.h"

#include <mutex>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <thread>
@@ -736,29 +735,36 @@ namespace basisu
BASISU_NO_EQUALS_OR_COPY_CONSTRUCT(job_pool);

public:
using token = uint32_t;

// num_threads is the TOTAL number of job pool threads, including the calling thread! So 2=1 new thread, 3=2 new threads, etc.
job_pool(uint32_t num_threads);
~job_pool();

void add_job(const std::function<void()>& job);
void add_job(std::function<void()>&& job);

void wait_for_all();
void add_job(std::function<void()> job, token* tok = nullptr);
void wait_for_all(token* tok = nullptr);

size_t get_total_threads() const { return 1 + m_threads.size(); }

private:
struct item
{
std::function<void()> fn;
token* tok;
};

std::vector<std::thread> m_threads;
std::vector<std::function<void()> > m_queue;
std::vector<item> m_queue;

std::mutex m_mutex;
std::condition_variable m_has_work;
std::condition_variable m_no_more_jobs;

uint32_t m_num_active_jobs;

std::atomic<bool> m_kill_flag;
std::condition_variable m_job_done;

uint32_t m_num_pending_jobs;
bool m_kill_flag;

bool job_steal(item& job, token* tok, std::unique_lock<std::mutex>& lock);
void job_run(item& job, std::unique_lock<std::mutex>& lock);
void job_thread(uint32_t index);
};

@@ -2017,6 +2023,8 @@ namespace basisu
basisu::vector<uint_vec> local_clusters[cMaxThreads];
basisu::vector<uint_vec> local_parent_clusters[cMaxThreads];

job_pool::token token{0};

for (uint32_t thread_iter = 0; thread_iter < max_threads; thread_iter++)
{
#ifndef __EMSCRIPTEN__
@@ -2063,13 +2071,13 @@
}

#ifndef __EMSCRIPTEN__
} );
}, &token);
#endif

} // thread_iter

#ifndef __EMSCRIPTEN__
pJob_pool->wait_for_all();
pJob_pool->wait_for_all(&token);
#endif

uint32_t total_clusters = 0, total_parent_clusters = 0;
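Two details of the basisu_enc.h change are worth calling out: a token is a plain uint32_t rather than an atomic, since every increment and decrement now happens while m_mutex is held (which is also why the <atomic> include is dropped and m_kill_flag becomes a plain bool), and m_num_pending_jobs doubles as the implicit token that wait_for_all() falls back to when called with no argument, preserving the old wait-for-everything behavior.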