From fb03c5d7539992b23fbc094d0f3869276ae7a1d8 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Wed, 5 Jun 2024 00:34:04 -0700 Subject: [PATCH] Multithreading via C++ thread pool of clients (#125) --- CMakeLists.txt | 4 +- include/ClientPool.h | 106 +++++++++++++++++ include/Export.h | 14 ++- include/LockPool.h | 100 ++++++++++++++++ libmc/__init__.py | 41 ++++++- libmc/_client.pyx | 230 +++++++++++++++++++++++++++---------- misc/.cppcheck-supp | 3 +- misc/runbench.py | 183 ++++++++++++++++++++++++++--- setup.py | 1 + src/Client.cpp | 7 ++ src/ClientPool.cpp | 156 +++++++++++++++++++++++++ src/golibmc.go | 2 +- src/version.go | 4 +- tests/CMakeLists.txt | 2 +- tests/test_client_pool.cpp | 96 ++++++++++++++++ tests/test_client_pool.py | 112 ++++++++++++++++++ 16 files changed, 977 insertions(+), 84 deletions(-) create mode 100644 include/ClientPool.h create mode 100644 include/LockPool.h create mode 100644 src/ClientPool.cpp create mode 100644 tests/test_client_pool.cpp create mode 100644 tests/test_client_pool.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 04a4799a..c27ecce0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ set(CMAKE_MACOSX_RPATH 1) set (MC_VERSION_MAJOR 1) set (MC_VERSION_MINOR 4) -set (MC_VERSION_PATCH 1) +set (MC_VERSION_PATCH 4) set (MC_VERSION ${MC_VERSION_MAJOR}.${MC_VERSION_MINOR}) set (MC_APIVERSION ${MC_VERSION}.${MC_VERSION_PATCH}) @@ -15,7 +15,7 @@ if (NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release." FORCE) endif (NOT CMAKE_BUILD_TYPE) -set(CMAKE_CXX_FLAGS_COMMON "-Wall -fno-rtti -fno-exceptions") +set(CMAKE_CXX_FLAGS_COMMON "-Wall -fno-rtti -fno-exceptions -std=c++17") set(CMAKE_CXX_FLAGS_DEBUG "-DDEBUG -g2 ${CMAKE_CXX_FLAGS_COMMON}" CACHE STRING "CXX DEBUG FLAGS" FORCE) set(CMAKE_CXX_FLAGS_RELEASE "-DNDEBUG -O3 ${CMAKE_CXX_FLAGS_COMMON}" CACHE STRING "CXX RELEASE FLAGS" FORCE) set(CMAKE_INSTALL_INCLUDE include CACHE PATH "Output directory for header files") diff --git a/include/ClientPool.h b/include/ClientPool.h new file mode 100644 index 00000000..eb3fdea6 --- /dev/null +++ b/include/ClientPool.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include + +#include "Client.h" +#include "LockPool.h" + +namespace douban { +namespace mc { + +template +void duplicate_strings(const char* const * strs, const size_t n, + std::deque >& out, std::vector& refs) { + out.resize(n); + refs.resize(n); + for (size_t i = 0; i < n; i++) { + if (strs == NULL || strs[i] == NULL) { + out[i][0] = '\0'; + refs[i] = NULL; + continue; + } + std::snprintf(out[i].data(), N, "%s", strs[i]); + refs[i] = out[i].data(); + } +} + +class irange { + int i; + +public: + using value_type = int; + using pointer = const int*; + using reference = const int&; + using difference_type = int; + using iterator_category = std::random_access_iterator_tag; + + explicit irange(int i) : i(i) {} + + reference operator*() const { return i; } + pointer operator->() const { return &i; } + value_type operator[](int n) const { return i + n; } + friend bool operator< (const irange& lhs, const irange& rhs) { return lhs.i < rhs.i; } + friend bool operator> (const irange& lhs, const irange& rhs) { return rhs < lhs; } + friend bool operator<=(const irange& lhs, const irange& rhs) { return !(lhs > rhs); } + friend bool operator>=(const irange& lhs, const irange& rhs) { return !(lhs < rhs); } + friend bool operator==(const irange& lhs, const irange& rhs) { return lhs.i == rhs.i; } + friend bool operator!=(const irange& lhs, const irange& rhs) { return !(lhs == rhs); } + irange& operator++() { ++i; return *this; } + irange& operator--() { --i; return *this; } + irange operator++(int) { irange tmp = *this; ++tmp; return tmp; } + irange operator--(int) { irange tmp = *this; --tmp; return tmp; } + irange& operator+=(difference_type n) { i += n; return *this; } + irange& operator-=(difference_type n) { i -= n; return *this; } + friend irange operator+(const irange& lhs, difference_type n) { irange tmp = lhs; tmp += n; return tmp; } + friend irange operator+(difference_type n, const irange& rhs) { return rhs + n; } + friend irange operator-(const irange& lhs, difference_type n) { irange tmp = lhs; tmp -= n; return tmp; } + friend difference_type operator-(const irange& lhs, const irange& rhs) { return lhs.i - rhs.i; } +}; + +typedef struct { + Client c; + int index; +} IndexedClient; + +class ClientPool : LockPool { +public: + ClientPool(); + ~ClientPool(); + void config(config_options_t opt, int val); + int init(const char* const * hosts, const uint32_t* ports, + const size_t n, const char* const * aliases = NULL); + int updateServers(const char* const * hosts, const uint32_t* ports, + const size_t n, const char* const * aliases = NULL); + IndexedClient* _acquire(); + void _release(const IndexedClient* idx); + Client* acquire(); + void release(const Client* ref); + +private: + int growPool(size_t by); + int setup(Client* c); + inline bool shouldGrowUnsafe(); + int autoGrow(); + + bool m_opt_changed[CLIENT_CONFIG_OPTION_COUNT]; + int m_opt_value[CLIENT_CONFIG_OPTION_COUNT]; + std::deque m_clients; + size_t m_initial_clients; + size_t m_max_clients; + size_t m_max_growth; + + std::deque > m_hosts_data; + std::deque > m_aliases_data; + std::vector m_ports; + + std::vector m_hosts; + std::vector m_aliases; + + std::mutex m_pool_lock; + mutable std::shared_mutex m_acquiring_growth; +}; + +} // namespace mc +} // namespace douban diff --git a/include/Export.h b/include/Export.h index 97a026c0..e7db3e58 100644 --- a/include/Export.h +++ b/include/Export.h @@ -9,11 +9,21 @@ typedef enum { - CFG_POLL_TIMEOUT, + // Client config options + CFG_POLL_TIMEOUT = 0, CFG_CONNECT_TIMEOUT, CFG_RETRY_TIMEOUT, CFG_HASH_FUNCTION, - CFG_MAX_RETRIES + CFG_MAX_RETRIES, + CFG_SET_FAILOVER, + + // type separator to track number of Client config options to save + CLIENT_CONFIG_OPTION_COUNT, + + // ClientPool config options + CFG_INITIAL_CLIENTS, + CFG_MAX_CLIENTS, + CFG_MAX_GROWTH } config_options_t; diff --git a/include/LockPool.h b/include/LockPool.h new file mode 100644 index 00000000..89130c00 --- /dev/null +++ b/include/LockPool.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace douban { +namespace mc { + +// https://stackoverflow.com/a/14792685/3476782 +class OrderedLock { + std::queue m_fifo_locks; +protected: + std::mutex m_fifo_access; + std::atomic m_locked; + +protected: + OrderedLock() : m_locked(true) {}; + std::unique_lock lock() { + std::unique_lock acquire(m_fifo_access); + if (m_locked) { + std::condition_variable signal; + m_fifo_locks.emplace(&signal); + signal.wait(acquire); + m_fifo_locks.pop(); + } else { + m_locked = true; + } + return acquire; + } + + void unlock() { + if (m_fifo_locks.empty()) { + m_locked = false; + } else { + m_fifo_locks.front()->notify_all(); + } + } +}; + +class LockPool : public OrderedLock { + std::deque m_available; + std::list m_muxes; + std::list m_mux_mallocs; + +protected: + std::deque m_thread_workers; + + LockPool() {} + ~LockPool() { + std::lock_guard freeing(m_fifo_access); + for (auto worker : m_thread_workers) { + std::lock_guard freeing_worker(*worker); + } + for (auto mem : m_muxes) { + mem->std::mutex::~mutex(); + } + for (auto mem : m_mux_mallocs) { + delete[] mem; + } + } + + void addWorkers(size_t n) { + std::unique_lock growing_pool(m_fifo_access); + const auto from = m_thread_workers.size(); + const auto muxes = new std::mutex[n]; + m_mux_mallocs.push_back(muxes); + for (size_t i = 0; i < n; i++) { + m_available.push_back(from + i); + m_muxes.push_back(&muxes[i]); + } + // static_cast needed for some versions of C++ + std::transform( + muxes, muxes + n, std::back_inserter(m_thread_workers), + static_cast(std::addressof)); + unlock(); + } + + int acquireWorker() { + auto fifo_lock = lock(); + const auto res = m_available.front(); + m_available.pop_front(); + if (!m_available.empty()) { + unlock(); + } + return res; + } + + void releaseWorker(int worker) { + std::unique_lock growing_pool(m_fifo_access); + m_available.push_front(worker); + unlock(); + } +}; + +} // namespace mc +} // namespace douban diff --git a/libmc/__init__.py b/libmc/__init__.py index a4395a6b..b0f05eaa 100644 --- a/libmc/__init__.py +++ b/libmc/__init__.py @@ -1,13 +1,19 @@ import os +import functools from ._client import ( PyClient, ThreadUnsafe, encode_value, decode_value, + PyClientPool, PyClientUnsafe as ClientUnsafe, MC_DEFAULT_EXPTIME, MC_POLL_TIMEOUT, MC_CONNECT_TIMEOUT, MC_RETRY_TIMEOUT, + MC_SET_FAILOVER, + MC_INITIAL_CLIENTS, + MC_MAX_CLIENTS, + MC_MAX_GROWTH, MC_HASH_MD5, MC_HASH_FNV1_32, @@ -27,25 +33,52 @@ __file__ as _libmc_so_file ) -__VERSION__ = "1.4.3" -__version__ = "v1.4.3" +__VERSION__ = "1.4.4" +__version__ = "v1.4.4" __author__ = "mckelvin" __email__ = "mckelvin@users.noreply.github.com" -__date__ = "Fri Dec 1 07:43:12 2023 +0800" +__date__ = "Sat Jun 1 05:10:05 2024 +0800" class Client(PyClient): pass +class ClientPool(PyClientPool): + pass + +class ThreadedClient: + def __init__(self, *args, **kwargs): + self._client_pool = ClientPool(*args, **kwargs) + + def update_servers(self, servers): + return self._client_pool.update_servers(servers) + + def config(self, opt, val): + self._client_pool.config(opt, val) + + def __getattr__(self, key): + if not hasattr(Client, key): + raise AttributeError + result = getattr(Client, key) + if callable(result): + @functools.wraps(result) + def wrapper(*args, **kwargs): + with self._client_pool.client() as mc: + return getattr(mc, key)(*args, **kwargs) + return wrapper + return result + DYNAMIC_LIBRARIES = [os.path.abspath(_libmc_so_file)] __all__ = [ 'Client', 'ThreadUnsafe', '__VERSION__', 'encode_value', 'decode_value', + 'ClientUnsafe', 'ClientPool', 'ThreadedClient', 'MC_DEFAULT_EXPTIME', 'MC_POLL_TIMEOUT', 'MC_CONNECT_TIMEOUT', - 'MC_RETRY_TIMEOUT', + 'MC_RETRY_TIMEOUT', 'MC_SET_FAILOVER', 'MC_INITIAL_CLIENTS', + 'MC_MAX_CLIENTS', 'MC_MAX_GROWTH', 'MC_HASH_MD5', 'MC_HASH_FNV1_32', 'MC_HASH_FNV1A_32', 'MC_HASH_CRC_32', diff --git a/libmc/_client.pyx b/libmc/_client.pyx index 44719d0a..1bfd7f71 100644 --- a/libmc/_client.pyx +++ b/libmc/_client.pyx @@ -24,7 +24,7 @@ import threading import zlib import marshal import warnings - +from contextlib import contextmanager cdef extern from "Common.h" namespace "douban::mc": ctypedef enum op_code_t: @@ -55,6 +55,11 @@ cdef extern from "Export.h": CFG_RETRY_TIMEOUT CFG_HASH_FUNCTION CFG_MAX_RETRIES + CFG_SET_FAILOVER + + CFG_INITIAL_CLIENTS + CFG_MAX_CLIENTS + CFG_MAX_GROWTH ctypedef enum hash_function_options_t: OPT_HASH_MD5 @@ -217,6 +222,26 @@ cdef extern from "Client.h" namespace "douban::mc": const char* errCodeToString(err_code_t err) nogil + +cdef extern from "ClientPool.h" namespace "douban::mc": + ctypedef struct IndexedClient: + Client c + int index + + cdef cppclass ClientPool: + ClientPool() + void config(config_options_t opt, int val) nogil + int init(const char* const * hosts, const uint32_t* ports, size_t n, + const char* const * aliases) nogil + int updateServers(const char* const * hosts, const uint32_t* ports, size_t n, + const char* const * aliases) nogil + IndexedClient* _acquire() nogil + void _release(const IndexedClient* ref) nogil + +ctypedef fused Configurable: + Client + ClientPool + cdef uint32_t MC_DEFAULT_PORT = 11211 cdef flags_t _FLAG_EMPTY = 0 cdef flags_t _FLAG_PICKLE = 1 << 0 @@ -237,6 +262,10 @@ MC_POLL_TIMEOUT = PyInt_FromLong(CFG_POLL_TIMEOUT) MC_CONNECT_TIMEOUT = PyInt_FromLong(CFG_CONNECT_TIMEOUT) MC_RETRY_TIMEOUT = PyInt_FromLong(CFG_RETRY_TIMEOUT) MC_MAX_RETRIES = PyInt_FromLong(CFG_MAX_RETRIES) +MC_SET_FAILOVER = PyInt_FromLong(CFG_SET_FAILOVER) +MC_INITIAL_CLIENTS = PyInt_FromLong(CFG_INITIAL_CLIENTS) +MC_MAX_CLIENTS = PyInt_FromLong(CFG_MAX_CLIENTS) +MC_MAX_GROWTH = PyInt_FromLong(CFG_MAX_GROWTH) MC_HASH_MD5 = PyInt_FromLong(OPT_HASH_MD5) @@ -335,32 +364,20 @@ class ThreadUnsafe(Exception): pass -cdef class PyClient: +cdef class PyClientSettings: cdef readonly list servers cdef readonly int comp_threshold - cdef Client* _imp cdef bool_t do_split cdef bool_t noreply cdef bytes prefix cdef hash_function_options_t hash_fn cdef bool_t failover cdef basestring encoding - cdef err_code_t last_error - cdef object _thread_ident - cdef object _created_stack def __cinit__(self, list servers, bool_t do_split=True, int comp_threshold=0, noreply=False, basestring prefix=None, hash_function_options_t hash_fn=OPT_HASH_MD5, failover=False, encoding='utf8'): self.servers = servers - self._imp = new Client() - self._imp.config(CFG_HASH_FUNCTION, hash_fn) - rv = self._update_servers(servers, True) - if failover: - self._imp.enableConsistentFailover() - else: - self._imp.disableConsistentFailover() - self.do_split = do_split self.comp_threshold = comp_threshold self.noreply = noreply @@ -372,53 +389,64 @@ cdef class PyClient: else: self.prefix = None - self.last_error = RET_OK - self._thread_ident = None - self._created_stack = traceback.extract_stack() - - cdef _update_servers(self, list servers, bool_t init): - cdef int rv = 0 - cdef size_t n = len(servers) - cdef char** c_hosts = PyMem_Malloc(n * sizeof(char*)) - cdef uint32_t* c_ports = PyMem_Malloc(n * sizeof(uint32_t)) - cdef char** c_aliases = PyMem_Malloc(n * sizeof(char*)) - - servers_ = [] - for srv in servers: - if PY_MAJOR_VERSION > 2: - srv = PyUnicode_AsUTF8String(srv) - srv = PyString_AsString(srv) - servers_.append(srv) - - Py_INCREF(servers_) - for i in range(n): - c_split = splitServerString(servers_[i]) + def _args(self): + return (self.servers, self.do_split, self.comp_threshold, self.noreply, + self.prefix, self.hash_fn, self.failover, self.encoding) - c_hosts[i] = c_split.host - c_aliases[i] = c_split.alias - if c_split.port == NULL: - c_ports[i] = MC_DEFAULT_PORT - else: - c_ports[i] = PyInt_AsLong(int(c_split.port)) - - if init: - rv = self._imp.init(c_hosts, c_ports, n, c_aliases) + def __reduce__(self): + return (self.__class__, self._args()) + +cdef _update_servers(Configurable* imp, list servers, bool_t init): + cdef int rv = 0 + cdef size_t n = len(servers) + cdef char** c_hosts = PyMem_Malloc(n * sizeof(char*)) + cdef uint32_t* c_ports = PyMem_Malloc(n * sizeof(uint32_t)) + cdef char** c_aliases = PyMem_Malloc(n * sizeof(char*)) + + servers_ = [] + for srv in servers: + if PY_MAJOR_VERSION > 2: + srv = PyUnicode_AsUTF8String(srv) + srv = PyString_AsString(srv) + servers_.append(srv) + + Py_INCREF(servers_) + for i in range(n): + c_split = splitServerString(servers_[i]) + + c_hosts[i] = c_split.host + c_aliases[i] = c_split.alias + if c_split.port == NULL: + c_ports[i] = MC_DEFAULT_PORT else: - rv = self._imp.updateServers(c_hosts, c_ports, n, c_aliases) + c_ports[i] = PyInt_AsLong(int(c_split.port)) - PyMem_Free(c_hosts) - PyMem_Free(c_ports) - PyMem_Free(c_aliases) + if init: + rv = imp.init(c_hosts, c_ports, n, c_aliases) + else: + rv = imp.updateServers(c_hosts, c_ports, n, c_aliases) - Py_DECREF(servers_) + PyMem_Free(c_hosts) + PyMem_Free(c_ports) + PyMem_Free(c_aliases) - return rv + Py_DECREF(servers_) - def __dealloc__(self): - del self._imp + if rv + len(servers) == 0: + return True + elif init: + warnings.warn("Client failed to initialize") + return False - def __reduce__(self): - return (PyClient, (self.servers, self.do_split, self.comp_threshold, self.noreply, self.prefix, self.hash_fn, self.failover, self.encoding)) +cdef class PyClientShell(PyClientSettings): + cdef Client* _imp + cdef err_code_t last_error + cdef object _thread_ident + cdef object _created_stack + + def __cinit__(self): + self.last_error = RET_OK + self._thread_ident = None def config(self, int opt, int val): self._imp.config(opt, val) @@ -1082,12 +1110,98 @@ cdef class PyClient: def get_last_error(self): return self.last_error + def get_last_strerror(self): + return errCodeToString(self.last_error) + +cdef class PyClient(PyClientShell): + def __cinit__(self): + self._created_stack = traceback.extract_stack() + self._imp = new Client() + self._imp.config(CFG_HASH_FUNCTION, self.hash_fn) + self.connect() + + if self.failover: + self._imp.enableConsistentFailover() + else: + self._imp.disableConsistentFailover() + + cdef connect(self): + return _update_servers(self._imp, self.servers, True) + def update_servers(self, servers): - rv = self._update_servers(servers, False) - if rv + len(servers) == 0: + if _update_servers(self._imp, servers, False): self.servers = servers return True return False - def get_last_strerror(self): - return errCodeToString(self.last_error) + def __dealloc__(self): + del self._imp + +cdef class PyClientUnsafe(PyClient): + def _check_thread_ident(self): + pass + +cdef class PyPoolClient(PyClientShell): + cdef IndexedClient* _indexed + + def _record_thread_ident(self): + pass + + def _check_thread_ident(self): + pass + +cdef class PyClientPool(PyClientSettings): + cdef list clients + cdef ClientPool* _imp + + def __cinit__(self): + self._imp = new ClientPool() + self.config(CFG_HASH_FUNCTION, self.hash_fn) + self.clients = [] + + if self.failover: + self.config(CFG_SET_FAILOVER, 1) + else: + self.config(CFG_SET_FAILOVER, 0) + + self.connect() + + def config(self, int opt, int val): + self._imp.config(opt, val) + + cdef setup(self, IndexedClient* imp): + worker = PyPoolClient(*self._args()) + worker._indexed = imp + worker._imp = &imp.c + return worker + + cdef acquire(self): + with nogil: + worker = self._imp._acquire() + return self.setup(worker) + + cdef release(self, PyPoolClient worker): + with nogil: + self._imp._release(worker._indexed) + + @contextmanager + def client(self): + worker = self.acquire() + try: + yield worker + finally: + self.release(worker) + + # repeated from PyClient because cython can't handle fused types in classes + # https://github.com/cython/cython/issues/3283 + cdef connect(self): + return _update_servers(self._imp, self.servers, True) + + def update_servers(self, servers): + if _update_servers(self._imp, servers, False): + self.servers = servers + return True + return False + + def __dealloc__(self): + del self._imp diff --git a/misc/.cppcheck-supp b/misc/.cppcheck-supp index 0973d40e..36e0a410 100644 --- a/misc/.cppcheck-supp +++ b/misc/.cppcheck-supp @@ -15,7 +15,8 @@ *:include/llvm/SmallVector.h:735 *:include/llvm/SmallVector.h:796 constParameter:include/BufferReader.h:99 -unusedFunction:src/Client.cpp:232 +unreadVariable:include/LockPool.h +unusedFunction:src/Client.cpp:239 unusedFunction:src/c_client.cpp:8 unusedFunction:src/c_client.cpp:13 unusedFunction:src/c_client.cpp:25 diff --git a/misc/runbench.py b/misc/runbench.py index 40b2b49a..0135b7d5 100644 --- a/misc/runbench.py +++ b/misc/runbench.py @@ -5,9 +5,11 @@ import sys import math import logging +import threading from functools import wraps from collections import namedtuple from contextlib import contextmanager +from queue import Queue import pylibmc import libmc @@ -20,10 +22,16 @@ logger = logging.getLogger('libmc.bench') Benchmark = namedtuple('Benchmark', 'name f args kwargs') -Participant = namedtuple('Participant', 'name factory') +Participant = namedtuple('Participant', 'name factory threads', defaults=(1,)) BENCH_TIME = 1.0 N_SERVERS = 20 +NTHREADS = 4 +POOL_SIZE = 4 +# setting (eg) NTHREADS to 40 and POOL_SIZE to 4 illustrates a failure case of a +# simpler python solution to thread pools for clients +#NTHREADS = 40 +#POOL_SIZE = 4 class Prefix(object): '''add prefix for key in mc command''' @@ -96,7 +104,7 @@ def __init__(self): def __unicode__(self): m = self.mean() d = self.stddev() - fmt = u"%.3gs, σ=%.3g, n=%d, snr=%.3g:%.3g".__mod__ + fmt = u"%.3gs, \u03C3=%.3g, n=%d, snr=%.3g:%.3g".__mod__ return fmt((m, d, len(self.laps)) + ratio(m, d)) __str__ = __unicode__ @@ -121,6 +129,33 @@ def timing(self): self.laps.append(te - t0) +class DelayedStopwatch(Stopwatch): + def __init__(self, laps=None, bound=0): + super().__init__() + self.laps = laps or [] + self._bound = bound + + @property + def bound(self): + return self._bound or sum(self.laps) + + def timing(self): + self.t0 = process_time() + self.timing = super().timing + return super().timing() + + def __add__(self, other): + bound = ((self.bound or other.bound) + (other.bound or self.bound)) / 2 + return DelayedStopwatch(self.laps + other.laps, bound) + + def mean(self): + return self.bound / len(self.laps) + + def stddev(self): + boundless = DelayedStopwatch(self.laps) if self._bound else super() + return boundless.stddev() + + def benchmark_method(f): "decorator to turn f into a factory of benchmarks" @@ -139,7 +174,7 @@ def bench_get(mc, key, data): @benchmark_method def bench_set(mc, key, data): - if isinstance(mc.mc, libmc.Client): + if any(isinstance(mc.mc, client) for client in libmc_clients): if not mc.set(key, data): logger.warn('%r.set(%r, ...) fail', mc, key) else: @@ -156,7 +191,7 @@ def bench_get_multi(mc, keys, pairs): @benchmark_method def bench_set_multi(mc, keys, pairs): ret = mc.set_multi(pairs) - if isinstance(mc.mc, libmc.Client): + if any(isinstance(mc.mc, client) for client in libmc_clients): if not ret: logger.warn('%r.set_multi fail', mc) else: @@ -203,9 +238,113 @@ def make_pylibmc_client(servers, **kw): return Prefix(__import__('pylibmc').Client(servers_, **kw), prefix) +class Pool: + ''' adapted from pylibmc ''' + + client = libmc.ClientUnsafe + + def __init__(self, *args, **kwargs): + self.args, self.kwargs = args, kwargs + + def clone(self): + return self.client(*self.args, **self.kwargs) + + def __getattr__(self, key): + if not hasattr(libmc.Client, key): + raise AttributeError + result = getattr(libmc.Client, key) + if callable(result): + @wraps(result) + def wrapper(*args, **kwargs): + with self.reserve() as mc: + return getattr(mc, key)(*args, **kwargs) + return wrapper + return result + + +class ThreadMappedPool(Pool): + client = libmc.Client + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.clients = {} + + @property + def current_key(self): + return threading.current_thread().native_id + + @contextmanager + def reserve(self): + key = self.current_key + mc = self.clients.pop(key, None) + if mc is None: + mc = self.clone() + try: + yield mc + finally: + self.clients[key] = mc + + +class ThreadPool(Pool): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.clients = Queue() + for _ in range(POOL_SIZE): + self.clients.put(self.clone()) + + @contextmanager + def reserve(self): + mc = self.clients.get() + try: + yield mc + finally: + self.clients.put(mc) + + +class BenchmarkThreadedClient(libmc.ThreadedClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.config(libmc.MC_INITIAL_CLIENTS, POOL_SIZE) + self.config(libmc.MC_MAX_CLIENTS, POOL_SIZE) + + +class FIFOThreadPool(ThreadPool): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.waiting = Queue() + self.semaphore = Queue(1) # sorry + self.semaphore.put(1) + + @contextmanager + def reserve(self): + try: + self.semaphore.get() + mc = self.clients.get(False) + self.semaphore.put(1) + except: + channel = Queue(1) + self.waiting.put(channel) + self.semaphore.put(1) + channel.get() + mc = self.clients.get(False) + self.semaphore.put(1) + try: + yield mc + finally: + self.semaphore.get() + self.clients.put(mc) + try: + self.waiting.get(False).put(1) + except: + self.semaphore.put(1) + + host = '127.0.0.1' servers = ['%s:%d' % (host, 21211 + i) for i in range(N_SERVERS)] +libmc_clients = (libmc.Client, BenchmarkThreadedClient, ThreadMappedPool, ThreadPool) +libmc_kwargs = {"servers": servers, "comp_threshold": 4000} + participants = [ Participant( name='pylibmc (md5 / ketama)', @@ -233,11 +372,15 @@ def make_pylibmc_client(servers, **kw): Participant(name='python-memcached', factory=lambda: Prefix(__import__('memcache').Client(servers), 'memcache1')), Participant( name='libmc(md5 / ketama / nodelay / nonblocking, from douban)', - factory=lambda: Prefix(__import__('libmc').Client(servers, comp_threshold=4000), 'libmc1') + factory=lambda: Prefix(__import__('libmc').Client(**libmc_kwargs), 'libmc1') + ), + Participant( + name='libmc(md5 / ketama / nodelay / nonblocking / C++ thread pool, from douban)', + factory=lambda: Prefix(BenchmarkThreadedClient(**libmc_kwargs), 'libmc2'), + threads=NTHREADS ), ] - def bench(participants=participants, benchmarks=benchmarks, bench_time=BENCH_TIME): """Do you even lift?""" @@ -252,20 +395,32 @@ def bench(participants=participants, benchmarks=benchmarks, bench_time=BENCH_TIM logger.info('%s', benchmark_name) for i, (participant, mc) in enumerate(zip(participants, mcs)): + def loop(sw): + while sw.total() < bench_time: + with sw.timing(): + fn(mc, *args, **kwargs) # FIXME: set before bench for get if 'get' in fn.__name__: last_fn(mc, *args, **kwargs) - sw = Stopwatch() - while sw.total() < bench_time: - with sw.timing(): - fn(mc, *args, **kwargs) + if participant.threads == 1: + sw = [DelayedStopwatch()] + loop(sw[0]) + else: + sw = [DelayedStopwatch() for i in range(participant.threads)] + ts = [threading.Thread(target=loop, args=[i]) for i in sw] + for t in ts: + t.start() + + for t in ts: + t.join() - means[i].append(sw.mean()) - stddevs[i].append(sw.stddev()) + total = sum(sw, DelayedStopwatch()) + means[i].append(total.mean()) + stddevs[i].append(total.stddev()) - logger.info(u'%76s: %s', participant.name, sw) + logger.info(u'%76s: %s', participant.name, total) last_fn = fn return means, stddevs @@ -274,6 +429,8 @@ def bench(participants=participants, benchmarks=benchmarks, bench_time=BENCH_TIM def main(args=sys.argv[1:]): logger.info('pylibmc: %s', pylibmc.__file__) logger.info('libmc: %s', libmc.__file__) + logger.info('Running %s servers, %s threads, and a %s client pool', + N_SERVERS, NTHREADS, POOL_SIZE) ps = [p for p in participants if p.name in args] ps = ps if ps else participants diff --git a/setup.py b/setup.py index c9807b95..b404f128 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ "-DMC_USE_SMALL_VECTOR", "-O3", "-DNDEBUG", + "-std=c++17", ] diff --git a/src/Client.cpp b/src/Client.cpp index 889c9800..7db3ad40 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -33,6 +33,13 @@ void Client::config(config_options_t opt, int val) { case CFG_MAX_RETRIES: setMaxRetries(val); break; + case CFG_SET_FAILOVER: + assert(val == 0 || val == 1); + if (val == 0) { + disableConsistentFailover(); + } else { + enableConsistentFailover(); + } default: break; } diff --git a/src/ClientPool.cpp b/src/ClientPool.cpp new file mode 100644 index 00000000..a7ff44c0 --- /dev/null +++ b/src/ClientPool.cpp @@ -0,0 +1,156 @@ +//#include +#include +#include "ClientPool.h" + +namespace douban { +namespace mc { + +// default max of 4 clients to match memcached's default of 4 worker threads +ClientPool::ClientPool(): m_initial_clients(1), m_max_clients(4), m_max_growth(4) { + memset(m_opt_changed, false, sizeof m_opt_changed); + memset(m_opt_value, 0, sizeof m_opt_value); +} + +ClientPool::~ClientPool() { +} + +void ClientPool::config(config_options_t opt, int val) { + std::lock_guard config_pool(m_pool_lock); + if (opt < CLIENT_CONFIG_OPTION_COUNT) { + m_opt_changed[opt] = true; + m_opt_value[opt] = val; + for (auto &client : m_clients) { + client.c.config(opt, val); + } + return; + } + std::unique_lock initializing(m_acquiring_growth); + switch (opt) { + case CFG_INITIAL_CLIENTS: + m_initial_clients = val; + if (m_initial_clients > m_max_clients) { + m_max_clients = m_initial_clients; + } + if (m_clients.size() < m_initial_clients) { + growPool(m_initial_clients); + } + break; + case CFG_MAX_CLIENTS: + m_max_clients = val; + break; + case CFG_MAX_GROWTH: + m_max_growth = val; + break; + default: + break; + } +} + +int ClientPool::init(const char* const * hosts, const uint32_t* ports, + const size_t n, const char* const * aliases) { + updateServers(hosts, ports, n, aliases); + std::unique_lock initializing(m_acquiring_growth); + std::lock_guard config_pool(m_pool_lock); + return growPool(m_initial_clients); +} + +int ClientPool::updateServers(const char* const* hosts, const uint32_t* ports, + const size_t n, const char* const* aliases) { + std::lock_guard updating_clients(m_pool_lock); + duplicate_strings(hosts, n, m_hosts_data, m_hosts); + duplicate_strings(aliases, n, m_aliases_data, m_aliases); + + m_ports.resize(n); + std::copy(ports, ports + n, m_ports.begin()); + + std::atomic rv = 0; + std::lock_guard updating(m_fifo_access); + std::for_each(irange(0), irange(m_clients.size()), + //std::for_each(std::execution::par_unseq, irange(0), irange(m_clients.size()), + [this, &rv](int i) { + std::lock_guard updating_worker(*m_thread_workers[i]); + const int err = m_clients[i].c.updateServers( + m_hosts.data(), m_ports.data(), m_hosts.size(), m_aliases.data()); + if (err != 0) { + rv.store(err, std::memory_order_relaxed); + } + }); + return rv; +} + +int ClientPool::setup(Client* c) { + for (int i = 0; i < CLIENT_CONFIG_OPTION_COUNT; i++) { + if (m_opt_changed[i]) { + c->config(static_cast(i), m_opt_value[i]); + } + } + return c->init(m_hosts.data(), m_ports.data(), m_hosts.size(), m_aliases.data()); +} + +// needs to hold both m_acquiring_growth and m_pool_lock +int ClientPool::growPool(size_t by) { + assert(by > 0); + size_t from = m_clients.size(); + m_clients.resize(from + by); + std::atomic rv = 0; + std::for_each(irange(from), irange(from + by), + //std::for_each(std::execution::par_unseq, irange(from), irange(from + by), + [this, &rv](int i) { + const int err = setup(&m_clients[i].c); + m_clients[i].index = i; + if (err != 0) { + rv.store(err, std::memory_order_relaxed); + } + }); + // adds workers with non-zero return values + // if changed, acquire should probably raise rather than hang + addWorkers(by); + return rv; +} + +inline bool ClientPool::shouldGrowUnsafe() { + return m_clients.size() < m_max_clients && m_locked; +} + +int ClientPool::autoGrow() { + std::unique_lock growing(m_acquiring_growth); + if (shouldGrowUnsafe()) { + std::lock_guard growing_pool(m_pool_lock); + return growPool(MIN(m_max_clients - m_clients.size(), + MIN(m_max_growth, m_clients.size()))); + } + return 0; +} + +IndexedClient* ClientPool::_acquire() { + m_acquiring_growth.lock_shared(); + const auto growing = shouldGrowUnsafe(); + m_acquiring_growth.unlock_shared(); + if (growing) { + std::thread acquire_overflow(&ClientPool::autoGrow, this); + acquire_overflow.detach(); + } + + int idx = acquireWorker(); + m_thread_workers[idx]->lock(); + return &m_clients[idx]; +} + +void ClientPool::_release(const IndexedClient* idx) { + std::mutex* const * mux = &m_thread_workers[idx->index]; + (**mux).unlock(); + releaseWorker(idx->index); +} + +Client* ClientPool::acquire() { + return &_acquire()->c; +} + +void ClientPool::release(const Client* ref) { + // C std 6.7.2.1-13 + auto idx = reinterpret_cast(ref); + return _release(idx); +} + +} // namespace mc +} // namespace douban diff --git a/src/golibmc.go b/src/golibmc.go index 924e98ab..45eb3fe2 100644 --- a/src/golibmc.go +++ b/src/golibmc.go @@ -2,7 +2,7 @@ package golibmc /* #cgo CFLAGS: -I ./../include -#cgo CXXFLAGS: -I ./../include +#cgo CXXFLAGS: -std=c++17 -I ./../include #include "c_client.h" */ import "C" diff --git a/src/version.go b/src/version.go index 38f0f52d..978dc32d 100644 --- a/src/version.go +++ b/src/version.go @@ -1,9 +1,9 @@ package golibmc -const _Version = "v1.4.3" +const _Version = "v1.4.4" const _Author = "mckelvin" const _Email = "mckelvin@users.noreply.github.com" -const _Date = "Fri Dec 1 07:43:12 2023 +0800" +const _Date = "Sat Jun 1 05:10:05 2024 +0800" // Version of the package const Version = _Version diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9af3d51e..9f301081 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,7 +2,7 @@ file(GLOB TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/test_*.cpp) foreach(SRC ${TEST_SRC_FILES}) get_filename_component(test_name ${SRC} NAME_WE) - set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS "-fexceptions") + set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS "-fexceptions -std=c++17") add_executable(${test_name} ${SRC}) add_dependencies(${test_name} gtest) target_link_libraries(${test_name} diff --git a/tests/test_client_pool.cpp b/tests/test_client_pool.cpp new file mode 100644 index 00000000..092fe7e4 --- /dev/null +++ b/tests/test_client_pool.cpp @@ -0,0 +1,96 @@ +#include "ClientPool.h" +#include "test_common.h" + +#include +#include "gtest/gtest.h" + +using douban::mc::ClientPool; +using douban::mc::tests::gen_random; + +const unsigned int n_ops = 5; +const unsigned int data_size = 10; +const unsigned int n_servers = 20; +const unsigned int start_port = 21211; +const char host[] = "127.0.0.1"; +unsigned int n_threads = 8; + +void inner_test_loop(ClientPool* pool) { + retrieval_result_t **r_results = NULL; + message_result_t **m_results = NULL; + size_t nResults = 0; + flags_t flags[] = {}; + size_t data_lens[] = {data_size}; + exptime_t exptime = 0; + char key[data_size + 1]; + char value[data_size + 1]; + const char* keys = &key[0]; + const char* values = &value[0]; + + for (unsigned int j = 0; j < n_ops; j++) { + gen_random(key, data_size); + gen_random(value, data_size); + auto c = pool->acquire(); + c->set(&keys, data_lens, flags, exptime, NULL, 0, &values, data_lens, 1, &m_results, &nResults); + c->destroyMessageResult(); + c->get(&keys, data_lens, 1, &r_results, &nResults); + EXPECT_EQ(nResults, 1); + ASSERT_N_STREQ(r_results[0]->data_block, values, data_size); + c->destroyRetrievalResult(); + pool->release(c); + } +} + +bool check_availability(ClientPool* pool) { + auto c = pool->acquire(); + broadcast_result_t* results; + size_t nHosts; + int ret = c->version(&results, &nHosts); + c->destroyBroadcastResult(); + pool->release(c); + return ret == 0; +} + +TEST(test_client_pool, simple_set_get) { + uint32_t ports[n_servers]; + const char* hosts[n_servers]; + for (unsigned int i = 0; i < n_servers; i++) { + ports[i] = start_port + i; + hosts[i] = host; + } + + ClientPool* pool = new ClientPool(); + pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32); + pool->init(hosts, ports, n_servers); + ASSERT_TRUE(check_availability(pool)); + + for (unsigned int j = 0; j < n_threads; j++) { + inner_test_loop(pool); + } + + delete pool; +} + +TEST(test_client_pool, threaded_set_get) { + uint32_t ports[n_servers]; + const char* hosts[n_servers]; + for (unsigned int i = 0; i < n_servers; i++) { + ports[i] = start_port + i; + hosts[i] = host; + } + + std::thread* threads = new std::thread[n_threads]; + ClientPool* pool = new ClientPool(); + pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32); + //pool->config(CFG_INITIAL_CLIENTS, 4); + pool->init(hosts, ports, n_servers); + ASSERT_TRUE(check_availability(pool)); + + for (unsigned int i = 0; i < n_threads; i++) { + threads[i] = std::thread([&pool] { inner_test_loop(pool); }); + } + for (unsigned int i = 0; i < n_threads; i++) { + threads[i].join(); + } + delete[] threads; + delete pool; +} diff --git a/tests/test_client_pool.py b/tests/test_client_pool.py new file mode 100644 index 00000000..d1a0ce45 --- /dev/null +++ b/tests/test_client_pool.py @@ -0,0 +1,112 @@ +# coding: utf-8 +import unittest +import threading +import functools +import os +from libmc import ClientPool, ThreadedClient, MC_MAX_CLIENTS + +def setup_loging(f): + g = None + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return g(*args, **kwargs) + + @functools.wraps(f) + def begin(*args, **kwargs): + nonlocal g + with open("/tmp/debug.log", "w+") as fp: + fp.write("") + g = f + return wrapper(*args, **kwargs) + + g = begin + return wrapper + +@functools.wraps(print) +@setup_loging +def threaded_print(*args, **kwargs): + with open('/tmp/debug.log', 'a+') as fp: + print(*args, **kwargs, file=fp) + +class ClientOps: + nthreads=8 + ops = 100 + + def tid(self, mc): + return (os.getpid(), threading.current_thread().native_id) + + def client_misc(self, mc, i=0): + tid = self.tid(mc) + (i,) + tid = "_".join(map(str, tid)) + f, t = 'foo_' + tid, 'tuiche_' + tid + mc.get_multi([f, t]) + mc.delete(f) + mc.delete(t) + assert mc.get(f) is None + assert mc.get(t) is None + + mc.set(f, 'biu') + mc.set(t, 'bb') + assert mc.get(f) == 'biu' + assert mc.get(t) == 'bb' + assert (mc.get_multi([f, t]) == + {f: 'biu', t: 'bb'}) + mc.set_multi({f: 1024, t: '8964'}) + assert (mc.get_multi([f, t]) == + {f: 1024, t: '8964'}) + + def client_threads(self, target): + errs = [] + def passthrough(args): + _, e, tb, t = args + if hasattr(e, "add_note"): + e.add_note("Occurred in thread " + str(t)) + errs.append(e.with_traceback(tb)) + + threading.excepthook = passthrough + ts = [threading.Thread(target=target) for i in range(self.nthreads)] + for t in ts: + t.start() + + for t in ts: + t.join() + + if errs: + e = errs[0] + if hasattr(e, "add_note"): + e.add_note(f"Along with {len(errs)} errors in other threads") + raise e + +class ThreadedSingleServerCase(unittest.TestCase, ClientOps): + def setUp(self): + self.pool = ClientPool(["127.0.0.1:21211"]) + + def misc(self): + for i in range(self.ops): + self.test_pool_client_misc(i) + + def test_pool_client_misc(self, i=0): + with self.pool.client() as mc: + self.client_misc(mc, i) + + def test_acquire(self): + with self.pool.client() as mc: + pass + + def test_pool_client_threaded(self): + self.client_threads(self.misc) + +class ThreadedClientOps(ClientOps): + def misc(self): + for i in range(self.ops): + self.client_misc(self.imp, i) + + +class ThreadedClientWrapperCheck(unittest.TestCase, ThreadedClientOps): + def setUp(self): + self.imp = ThreadedClient(["127.0.0.1:21211"]) + + def test_many_threads(self): + self.client_threads(self.misc) +