From d32820f07f01c04d6f8c1272248fa47cd131b14c Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 10 Dec 2023 23:01:32 -0800 Subject: [PATCH] Add support for UNIX domain sockets (#120) * added support for UNIX domain sockets * untested progress * moved cython server string parse to c * regression test for host parse * updated cppcheck suppression * golang bindings * update c_client cppcheck * removed duplicated line * removed github workflow changes * moved unix client to test_unix * missed file * remove atoi usage * warning explanation * removed unused cppcheck suppression * startall for cython valgrind * better string iteration * sorry to commit after requesting review * reverted server string parsing behavior * code hygiene * version bump * syntax credit * version bump not overwritten by pre-commit hook * timezone matching email * padding type * fix port interpretation * revert to existing distribution of responsibilities * deduplicate golang test cases * remove unnecessary inline in header * rewrote to match default branch parsing more closely * more descriptive names * apt-get update before install --- .github/workflows/cpp.yml | 1 + .github/workflows/golang.yml | 1 + .github/workflows/python.yml | 6 ++-- include/Common.h | 2 ++ include/Connection.h | 4 ++- include/Export.h | 6 ++++ include/c_client.h | 1 + libmc/__init__.py | 6 ++-- libmc/_client.pyx | 39 +++++++++++------------- misc/.cppcheck-supp | 4 +-- misc/memcached_server | 35 ++++++++++++++++++++- misc/travis/cpptest.sh | 4 +-- src/Common.cpp | 43 ++++++++++++++++++++++++++ src/Connection.cpp | 55 +++++++++++++++++++++++++++++---- src/c_client.cpp | 4 +++ src/golibmc.go | 26 ++++++---------- src/version.go | 4 +-- tests/test_common.h | 30 ++++++++++-------- tests/test_unix.cpp | 59 ++++++++++++++++++++++++++++++++++++ 19 files changed, 260 insertions(+), 70 deletions(-) create mode 100644 tests/test_unix.cpp diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 0eaee9c8..c82e2482 
100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -14,6 +14,7 @@ jobs: - uses: actions/checkout@v2 - name: Setup system dependencies run: | + sudo apt-get update sudo apt-get -y install cppcheck - name: Run cppcheck run: | diff --git a/.github/workflows/golang.yml b/.github/workflows/golang.yml index f0d441e7..c542e54e 100644 --- a/.github/workflows/golang.yml +++ b/.github/workflows/golang.yml @@ -19,6 +19,7 @@ jobs: - uses: actions/checkout@v2 - name: Setup system dependencies run: | + sudo apt-get update sudo apt-get -y install memcached g++ - name: Set up Golang uses: actions/setup-go@v2 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 0c8606c4..f364f304 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -19,6 +19,7 @@ jobs: - uses: actions/checkout@v2 - name: Setup system dependencies run: | + sudo apt-get update sudo apt-get -y install valgrind memcached g++ - name: Set up Python uses: actions/setup-python@v2 @@ -29,14 +30,14 @@ jobs: python -m pip install --upgrade pip pip install setuptools future pytest greenify gevent numpy - name: Start memcached servers - run: ./misc/memcached_server start + run: ./misc/memcached_server startall - name: Run unittest run: | if [[ ${{ matrix.compiler }} = "gcc" ]]; then export CC=gcc CXX=g++; fi if [[ ${{ matrix.compiler }} = "clang" ]]; then export CC=clang CXX=clang++; fi ./misc/travis/unittest.sh - name: Stop memcached servers - run: ./misc/memcached_server stop + run: ./misc/memcached_server stopall benchmark: runs-on: ubuntu-latest @@ -48,6 +49,7 @@ jobs: - uses: actions/checkout@v2 - name: Setup system dependencies run: | + sudo apt-get update sudo apt-get -y install memcached libmemcached-dev g++ - name: Set up Python uses: actions/setup-python@v2 diff --git a/include/Common.h b/include/Common.h index 06baf1a8..f7a22de8 100644 --- a/include/Common.h +++ b/include/Common.h @@ -226,6 +226,8 @@ typedef enum { } op_code_t; const char* 
/*
 * Result of splitting one server string into its parts (produced by
 * splitServerString).
 *
 * The members are not separately allocated: they point into the buffer that
 * was handed to splitServerString, which is modified in place. They are only
 * valid while that buffer is alive and must not be freed individually.
 */
typedef struct {
  char* host;   /* start of the input buffer: host name or socket path */
  char* port;   /* text after the first ':' found before the alias, or NULL */
  char* alias;  /* text after the first unescaped space, or NULL */
} server_string_split_t;
b/libmc/_client.pyx index 903a64e5..44719d0a 100644 --- a/libmc/_client.pyx +++ b/libmc/_client.pyx @@ -45,6 +45,8 @@ cdef extern from "Common.h" namespace "douban::mc": VERSION_OP QUIT_OP + server_string_split_t splitServerString(char* input) nogil + cdef extern from "Export.h": ctypedef enum config_options_t: @@ -99,6 +101,11 @@ cdef extern from "Export.h": RET_INCOMPLETE_BUFFER_ERR RET_OK + ctypedef struct server_string_split_t: + char* host + char* port + char* alias + ctypedef struct unsigned_result_t: char* key size_t key_len @@ -378,33 +385,21 @@ cdef class PyClient: servers_ = [] for srv in servers: - addr_alias = srv.split(' ') - addr = addr_alias[0] - if len(addr_alias) == 1: - alias = None - else: - alias = addr_alias[1] - - host_port = addr.split(':') - host = host_port[0] - if len(host_port) == 1: - port = MC_DEFAULT_PORT - else: - port = int(host_port[1]) if PY_MAJOR_VERSION > 2: - host = PyUnicode_AsUTF8String(host) - alias = PyUnicode_AsUTF8String(alias) if alias else None - servers_.append((host, port, alias)) + srv = PyUnicode_AsUTF8String(srv) + srv = PyString_AsString(srv) + servers_.append(srv) Py_INCREF(servers_) for i in range(n): - host, port, alias = servers_[i] - c_hosts[i] = PyString_AsString(host) - c_ports[i] = PyInt_AsLong(port) - if alias is None: - c_aliases[i] = NULL + c_split = splitServerString(servers_[i]) + + c_hosts[i] = c_split.host + c_aliases[i] = c_split.alias + if c_split.port == NULL: + c_ports[i] = MC_DEFAULT_PORT else: - c_aliases[i] = PyString_AsString(alias) + c_ports[i] = PyInt_AsLong(int(c_split.port)) if init: rv = self._imp.init(c_hosts, c_ports, n, c_aliases) diff --git a/misc/.cppcheck-supp b/misc/.cppcheck-supp index 5a2e13ef..0973d40e 100644 --- a/misc/.cppcheck-supp +++ b/misc/.cppcheck-supp @@ -45,5 +45,5 @@ unusedFunction:src/c_client.cpp:149 unusedFunction:src/c_client.cpp:154 unusedFunction:src/c_client.cpp:159 unusedFunction:src/Utility.cpp:43 -*:src/Connection.cpp:30 -*:src/Connection.cpp:35 
# Start a daemon that listens on a UNIX domain socket instead of a TCP port.
#   $1  base name used for the socket, pid and log files (default: unix_test)
# Reads $basedir, $cmd, $threads, $memory and $USER, which are set outside
# this function.
function unix()
{
    name="${1:-unix_test}"
    # create the log file (and its directory) the first time this name is used
    if [ ! -f $basedir/var/log/${name}.log ]; then
        mkdir -p $basedir/var/log
        touch $basedir/var/log/${name}.log
    fi
    mkdir -p $basedir/var/run
    # NOTE(review): assuming $cmd is the memcached binary, the flags mean
    # -d daemonize, -u run-as user, -s socket path, -t threads, -m memory,
    # -P pid file -- confirm against where $cmd is defined.
    $cmd -d -u $USER -s $basedir/var/run/${name}.socket -t $threads -m ${memory} -P $basedir/var/run/${name}.pid > $basedir/var/log/${name}.log 2>&1
    echo "Starting the memcached server on '$basedir/var/run/${name}.socket'... "
}
-./misc/memcached_server stop &>/dev/null & +./misc/memcached_server stopall &>/dev/null & diff --git a/src/Common.cpp b/src/Common.cpp index 674750a6..2e75ffca 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -50,5 +50,48 @@ const char* errCodeToString(err_code_t err) { } } +bool isLocalSocket(const char* host) { + // errors on the side of false negatives, allowing syntax expansion; + // starting slash to denote socket paths is from pylibmc + return host[0] == '/'; +} + +// modifies input string and output pointers reference input +server_string_split_t splitServerString(char* input) { + bool escaped = false; + server_string_split_t res = { input, NULL, NULL }; + for (;; input++) { + switch (*input) + { + case '\0': + return res; + case ':': + if (res.alias == NULL) { + *input = '\0'; + if (res.port == NULL) { + res.port = input + 1; + } + } + escaped = false; + continue; + case ' ': + if (!escaped) { + *input = '\0'; + if (res.alias == NULL) { + res.alias = input + 1; + continue; + } else { + return res; + } + } + default: + escaped = false; + continue; + case '\\': + escaped ^= 1; + } + } +} + } // namespace mc } // namespace douban diff --git a/src/Connection.cpp b/src/Connection.cpp index 2be7de55..0b95c8b3 100644 --- a/src/Connection.cpp +++ b/src/Connection.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -10,6 +11,7 @@ #include #include +#include "Common.h" #include "Connection.h" #include "Keywords.h" @@ -21,8 +23,8 @@ namespace mc { Connection::Connection() : m_counter(0), m_port(0), m_socketFd(-1), - m_alive(false), m_hasAlias(false), m_deadUntil(0), - m_connectTimeout(MC_DEFAULT_CONNECT_TIMEOUT), + m_alive(false), m_hasAlias(false), m_unixSocket(false), + m_deadUntil(0), m_connectTimeout(MC_DEFAULT_CONNECT_TIMEOUT), m_retryTimeout(MC_DEFAULT_RETRY_TIMEOUT), m_maxRetries(MC_DEFAULT_MAX_RETRIES), m_retires(0) { m_name[0] = '\0'; @@ -45,9 +47,14 @@ Connection::~Connection() { int Connection::init(const char* host, uint32_t 
port, const char* alias) { snprintf(m_host, sizeof m_host, "%s", host); m_port = port; + m_unixSocket = isLocalSocket(m_host); if (alias == NULL) { m_hasAlias = false; - snprintf(m_name, sizeof m_name, "%s:%u", m_host, m_port); + if (m_unixSocket) { + snprintf(m_name, sizeof m_name, "%s", m_host); + } else { + snprintf(m_name, sizeof m_name, "%s:%u", m_host, m_port); + } } else { m_hasAlias = true; snprintf(m_name, sizeof m_name, "%s", alias); @@ -59,6 +66,10 @@ int Connection::init(const char* host, uint32_t port, const char* alias) { int Connection::connect() { assert(!m_alive); this->close(); + if (m_unixSocket) { + return unixSocketConnect(); + } + struct addrinfo hints, *server_addrinfo = NULL, *ai_ptr = NULL; memset(&hints, 0, sizeof hints); hints.ai_family = AF_INET; @@ -104,7 +115,7 @@ int Connection::connect() { } // make sure the connection is established - if (connectPoll(fd, ai_ptr) == 0) { + if (connectPoll(fd, ai_ptr->ai_addr, ai_ptr->ai_addrlen) == 0) { m_socketFd = fd; m_alive = true; break; @@ -121,8 +132,40 @@ int Connection::connect() { return m_alive ? 
0 : -1; } -int Connection::connectPoll(int fd, struct addrinfo* ai_ptr) { - int conn_rv = ::connect(fd, ai_ptr->ai_addr, ai_ptr->ai_addrlen); +int Connection::unixSocketConnect() { + int fd, flags, opt_keepalive = 1; + + if ((fd = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) { + log_err("socket()"); + return -1; + } + + if ((flags = fcntl(fd, F_GETFL, 0)) < 0 || + fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) { + log_err("setting O_NONBLOCK"); + ::close(fd); + return -1; + } + + setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, (void *)&opt_keepalive, sizeof opt_keepalive); + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + // un.h UNIX_PATH_MAX < netdb.h NI_MAXHOST + // storing the unix path as a host doesn't limit the input but can overflow + strncpy(addr.sun_path, m_host, sizeof(addr.sun_path) - 1); + assert(strcmp(addr.sun_path, m_host) == 0); + if (connectPoll(fd, (const struct sockaddr *)&addr, sizeof addr) != 0) { + return -1; + } + m_socketFd = fd; + m_alive = true; + return 0; +} + +int Connection::connectPoll(int fd, const sockaddr* ai_addr, const socklen_t ai_addrlen) { + int conn_rv = ::connect(fd, ai_addr, ai_addrlen); if (conn_rv == 0) { return 0; } diff --git a/src/c_client.cpp b/src/c_client.cpp index 9fab71c9..88b9d128 100644 --- a/src/c_client.cpp +++ b/src/c_client.cpp @@ -159,3 +159,7 @@ err_code_t client_quit(void* client) { const char* err_code_to_string(err_code_t err) { return douban::mc::errCodeToString(err); } + +server_string_split_t splitServerString(char* input) { + return douban::mc::splitServerString(input); +} diff --git a/src/golibmc.go b/src/golibmc.go index dcdcc235..924e98ab 100644 --- a/src/golibmc.go +++ b/src/golibmc.go @@ -338,29 +338,21 @@ func (client *Client) newConn() (*conn, error) { cAliases := make([]*C.char, n) for i, srv := range client.servers { - addrAndAlias := strings.Split(srv, " ") + csrv := C.CString(srv) + defer C.free(unsafe.Pointer(csrv)) + split := C.splitServerString(csrv) - 
addr := addrAndAlias[0] - if len(addrAndAlias) == 2 { - cAlias := C.CString(addrAndAlias[1]) - defer C.free(unsafe.Pointer(cAlias)) - cAliases[i] = cAlias - } - - hostAndPort := strings.Split(addr, ":") - host := hostAndPort[0] - cHost := C.CString(host) - defer C.free(unsafe.Pointer(cHost)) - cHosts[i] = cHost + cAliases[i] = split.alias + cHosts[i] = split.host - if len(hostAndPort) == 2 { - port, err := strconv.Atoi(hostAndPort[1]) + if split.port == nil { + cPorts[i] = C.uint32_t(DefaultPort) + } else { + port, err := strconv.Atoi(C.GoString(split.port)) if err != nil { return nil, err } cPorts[i] = C.uint32_t(port) - } else { - cPorts[i] = C.uint32_t(DefaultPort) } } diff --git a/src/version.go b/src/version.go index 28acf41f..38f0f52d 100644 --- a/src/version.go +++ b/src/version.go @@ -1,9 +1,9 @@ package golibmc -const _Version = "v1.4.2" +const _Version = "v1.4.3" const _Author = "mckelvin" const _Email = "mckelvin@users.noreply.github.com" -const _Date = "Thu May 20 18:44:42 2021 +0800" +const _Date = "Fri Dec 1 07:43:12 2023 +0800" // Version of the package const Version = _Version diff --git a/tests/test_common.h b/tests/test_common.h index 7db98a01..765cf046 100644 --- a/tests/test_common.h +++ b/tests/test_common.h @@ -23,6 +23,23 @@ void gen_random(char *s, const int len) { } +mc::Client* md5Client(const char* const * hosts, const uint32_t* ports, const size_t n, + const char* const * aliases = NULL) { + mc::Client* client = new mc::Client(); + client->config(CFG_HASH_FUNCTION, OPT_HASH_MD5); + client->init(hosts, ports, n, aliases); + broadcast_result_t* results; + size_t nHosts; + int ret = client->version(&results, &nHosts); + client->destroyBroadcastResult(); + if (ret != 0) { + delete client; + return NULL; + } + return client; +} + + mc::Client* newClient(int n) { assert(n <= 20); const char * hosts[] = { @@ -73,18 +90,7 @@ mc::Client* newClient(int n) { "sierra", "tango" }; - mc::Client* client = new mc::Client(); - 
// Creates an MD5-hashed client for the single UNIX socket used by these tests.
// Returns whatever md5Client returns, i.e. NULL when the version probe fails.
Client* newUnixClient() {
  const char * hosts[] = { "/tmp/env_mc_dev/var/run/unix_test.socket" };
  // one port entry is still passed alongside the socket path
  const uint32_t ports[] = { 0 };
  struct stat info;
  // fails if ../misc/memcached_server wasn't started with startall or unix
  EXPECT_EQ(stat(hosts[0], &info), 0);

  return md5Client(hosts, ports, 1);
}

// A client can be created for the UNIX socket and passes the version probe.
TEST(test_unix, establish_connection) {
  Client* client = newUnixClient();
  ASSERT_TRUE(client != NULL);
  delete client;
}

// The "host:port alias" form is split into its three parts.
TEST(test_unix, host_parse_regression) {
  char test[] = "127.0.0.1:21211 testing";
  server_string_split_t out = splitServerString(test);
  ASSERT_STREQ(out.host, "127.0.0.1");
  ASSERT_STREQ(out.port, "21211");
  ASSERT_STREQ(out.alias, "testing");
}

// A backslash-escaped space stays inside the host; the backslash is kept.
TEST(test_unix, socket_path_spaces) {
  char test[] = "/tmp/spacey\\ path testing";
  server_string_split_t out = splitServerString(test);
  ASSERT_STREQ(out.host, "/tmp/spacey\\ path");
  ASSERT_EQ(out.port, nullptr);
  ASSERT_STREQ(out.alias, "testing");
}

// Two backslashes cancel out, so the following space is a separator; the
// second unescaped space then ends the alias.
TEST(test_unix, socket_path_escaping) {
  char test[] = "/tmp/spicy\\\\ path testing";
  server_string_split_t out = splitServerString(test);
  ASSERT_STREQ(out.host, "/tmp/spicy\\\\");
  ASSERT_EQ(out.port, nullptr);
  ASSERT_STREQ(out.alias, "path");
}

// Escaped spaces are honoured inside the alias as well.
TEST(test_unix, alias_space_escaping) {
  char test[] = "/tmp/path testing\\ alias";
  server_string_split_t out = splitServerString(test);
  ASSERT_STREQ(out.host, "/tmp/path");
  ASSERT_EQ(out.port, nullptr);
  ASSERT_STREQ(out.alias, "testing\\ alias");
}