From ad6baa121f4c6fbc7622d83d606210ac22549c78 Mon Sep 17 00:00:00 2001 From: sewenew Date: Sun, 31 Dec 2023 22:34:24 +0800 Subject: [PATCH] add shard pub/sub support --- src/sw/redis++/async_redis.h | 9 ++++ src/sw/redis++/async_redis_cluster.cpp | 13 ++++++ src/sw/redis++/async_redis_cluster.h | 11 +++++ src/sw/redis++/async_subscriber.cpp | 18 ++++++++ src/sw/redis++/async_subscriber.h | 50 ++++++++++++++++++++ src/sw/redis++/async_subscriber_impl.cpp | 38 ++++++++++++++++ src/sw/redis++/async_subscriber_impl.h | 9 ++++ src/sw/redis++/cmd_formatter.h | 38 ++++++++++++++++ src/sw/redis++/command.h | 44 ++++++++++++++++++ src/sw/redis++/redis.cpp | 6 +++ src/sw/redis++/redis.h | 2 + src/sw/redis++/redis_cluster.cpp | 15 ++++++ src/sw/redis++/redis_cluster.h | 4 ++ src/sw/redis++/subscriber.cpp | 44 ++++++++++++++++++ src/sw/redis++/subscriber.h | 58 ++++++++++++++++++++++++ 15 files changed, 359 insertions(+) diff --git a/src/sw/redis++/async_redis.h b/src/sw/redis++/async_redis.h index 904c4e4..211c0c6 100644 --- a/src/sw/redis++/async_redis.h +++ b/src/sw/redis++/async_redis.h @@ -1616,6 +1616,15 @@ class AsyncRedis { _callback_fmt_command(std::forward(cb), fmt::publish, channel, message); } + Future spublish(const StringView &channel, const StringView &message) { + return _command(fmt::spublish, channel, message); + } + + template + void spublish(const StringView &channel, const StringView &message, Callback &&cb) { + _callback_fmt_command(std::forward(cb), fmt::spublish, channel, message); + } + // co_command* are used internally. DO NOT use them. template diff --git a/src/sw/redis++/async_redis_cluster.cpp b/src/sw/redis++/async_redis_cluster.cpp index 0b1d54c..4b0dc6f 100644 --- a/src/sw/redis++/async_redis_cluster.cpp +++ b/src/sw/redis++/async_redis_cluster.cpp @@ -62,6 +62,19 @@ AsyncSubscriber AsyncRedisCluster::subscriber() { return AsyncSubscriber(_loop, std::move(connection)); } +AsyncSubscriber AsyncRedisCluster::subscriber(const StringView &hash_tag) { + assert(_pool); + + _pool->update(); + + auto opts = _pool->connection_options(hash_tag); + + auto connection = std::make_shared(opts, _loop.get()); + connection->set_subscriber_mode(); + + return AsyncSubscriber(_loop, std::move(connection)); +} + } } diff --git a/src/sw/redis++/async_redis_cluster.h b/src/sw/redis++/async_redis_cluster.h index c0a9c93..fc6a1cc 100644 --- a/src/sw/redis++/async_redis_cluster.h +++ b/src/sw/redis++/async_redis_cluster.h @@ -53,6 +53,8 @@ class AsyncRedisCluster { AsyncSubscriber subscriber(); + AsyncSubscriber subscriber(const StringView &hash_tag); + template auto command(const StringView &cmd_name, const StringView &key, Args &&...args) -> typename std::enable_if::type, @@ -1070,6 +1072,15 @@ class AsyncRedisCluster { _callback_fmt_command(std::forward(cb), fmt::publish, channel, message); } + Future spublish(const StringView &channel, const StringView &message) { + return _command(fmt::spublish, channel, message); + } + + template + void spublish(const StringView &channel, const StringView &message, Callback &&cb) { + _callback_fmt_command(std::forward(cb), fmt::spublish, channel, message); + } + // co_command* are used internally. DO NOT use them. template diff --git a/src/sw/redis++/async_subscriber.cpp b/src/sw/redis++/async_subscriber.cpp index 7dbe35c..889e45b 100644 --- a/src/sw/redis++/async_subscriber.cpp +++ b/src/sw/redis++/async_subscriber.cpp @@ -91,6 +91,24 @@ Future AsyncSubscriber::punsubscribe(const StringView &channel) { return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::punsubscribe(channel)))); } +Future AsyncSubscriber::ssubscribe(const StringView &channel) { + _check_connection(); + + return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::ssubscribe(channel)))); +} + +Future AsyncSubscriber::sunsubscribe() { + _check_connection(); + + return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::sunsubscribe()))); +} + +Future AsyncSubscriber::sunsubscribe(const StringView &channel) { + _check_connection(); + + return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::sunsubscribe(channel)))); +} + void AsyncSubscriber::_check_connection() { if (!_connection || _connection->broken()) { throw Error("Connection is broken"); diff --git a/src/sw/redis++/async_subscriber.h b/src/sw/redis++/async_subscriber.h index d058fca..9039f85 100644 --- a/src/sw/redis++/async_subscriber.h +++ b/src/sw/redis++/async_subscriber.h @@ -54,6 +54,9 @@ class AsyncSubscriber { template void on_pmessage(PMsgCb &&pmsg_callback); + template + void on_smessage(SMsgCb &&smsg_callback); + template void on_meta(MetaCb &&meta_callback); @@ -104,6 +107,28 @@ class AsyncSubscriber { return punsubscribe(channels.begin(), channels.end()); } + Future ssubscribe(const StringView &channel); + + template + Future ssubscribe(Input first, Input last); + + template + Future ssubscribe(std::initializer_list channels) { + return ssubscribe(channels.begin(), channels.end()); + } + + Future sunsubscribe(); + + Future sunsubscribe(const StringView &channel); + + template + Future sunsubscribe(Input first, Input last); + + template + Future sunsubscribe(std::initializer_list channels) { + return sunsubscribe(channels.begin(), channels.end()); + } + private: friend class AsyncRedis; @@ -134,6 +159,13 @@ void AsyncSubscriber::on_pmessage(PMsgCb &&pmsg_callback) { _connection->subscriber().on_pmessage(std::forward(pmsg_callback)); } +template +void AsyncSubscriber::on_smessage(SMsgCb &&smsg_callback) { + _check_connection(); + + _connection->subscriber().on_smessage(std::forward(smsg_callback)); +} + template void AsyncSubscriber::on_meta(MetaCb &&meta_callback) { _check_connection(); @@ -184,6 +216,24 @@ Future AsyncSubscriber::punsubscribe(Input first, Input last) { return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::punsubscribe_range(first, last)))); } +template +Future AsyncSubscriber::ssubscribe(Input first, Input last) { + range_check("ssubscribe", first, last); + + _check_connection(); + + return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::ssubscribe_range(first, last)))); +} + +template +Future AsyncSubscriber::sunsubscribe(Input first, Input last) { + range_check("sunsubscribe", first, last); + + _check_connection(); + + return _send(SubscribeEventUPtr(new SubscribeEvent(fmt::sunsubscribe_range(first, last)))); +} + } } diff --git a/src/sw/redis++/async_subscriber_impl.cpp b/src/sw/redis++/async_subscriber_impl.cpp index 1162135..26ba7c9 100644 --- a/src/sw/redis++/async_subscriber_impl.cpp +++ b/src/sw/redis++/async_subscriber_impl.cpp @@ -66,10 +66,16 @@ void AsyncSubscriberImpl::_run_callback(redisReply &reply) { _handle_pmessage(reply); break; + case Subscriber::MsgType::SMESSAGE: + _handle_smessage(reply); + break; + case Subscriber::MsgType::SUBSCRIBE: case Subscriber::MsgType::UNSUBSCRIBE: case Subscriber::MsgType::PSUBSCRIBE: case Subscriber::MsgType::PUNSUBSCRIBE: + case Subscriber::MsgType::SSUBSCRIBE: + case Subscriber::MsgType::SUNSUBSCRIBE: _handle_meta(type, reply); break; @@ -93,6 +99,8 @@ Subscriber::MsgType AsyncSubscriberImpl::_msg_type(const std::string &type) cons return Subscriber::MsgType::MESSAGE; } else if ("pmessage" == type) { return Subscriber::MsgType::PMESSAGE; + } else if ("smessage" == type) { + return Subscriber::MsgType::SMESSAGE; } else if ("subscribe" == type) { return Subscriber::MsgType::SUBSCRIBE; } else if ("unsubscribe" == type) { @@ -101,6 +109,10 @@ Subscriber::MsgType AsyncSubscriberImpl::_msg_type(const std::string &type) cons return Subscriber::MsgType::PSUBSCRIBE; } else if ("punsubscribe" == type) { return Subscriber::MsgType::PUNSUBSCRIBE; + } else if ("ssubscribe" == type) { + return Subscriber::MsgType::SSUBSCRIBE; + } else if ("sunsubscribe" == type) { + return Subscriber::MsgType::SUNSUBSCRIBE; } else { return Subscriber::MsgType::UNKNOWN; } @@ -164,6 +176,32 @@ void AsyncSubscriberImpl::_handle_pmessage(redisReply &reply) { _pmsg_callback(std::move(pattern), std::move(channel), std::move(msg)); } +void AsyncSubscriberImpl::_handle_smessage(redisReply &reply) { + if (_smsg_callback == nullptr) { + return; + } + + if (reply.elements != 3) { + throw ProtoError("Expect 3 sub replies"); + } + + assert(reply.element != nullptr); + + auto *channel_reply = reply.element[1]; + if (channel_reply == nullptr) { + throw ProtoError("Null channel reply"); + } + auto channel = reply::parse(*channel_reply); + + auto *msg_reply = reply.element[2]; + if (msg_reply == nullptr) { + throw ProtoError("Null message reply"); + } + auto msg = reply::parse(*msg_reply); + + _smsg_callback(std::move(channel), std::move(msg)); +} + void AsyncSubscriberImpl::_handle_meta(Subscriber::MsgType type, redisReply &reply) { if (_meta_callback == nullptr) { return; diff --git a/src/sw/redis++/async_subscriber_impl.h b/src/sw/redis++/async_subscriber_impl.h index ea213d6..ca9e019 100644 --- a/src/sw/redis++/async_subscriber_impl.h +++ b/src/sw/redis++/async_subscriber_impl.h @@ -40,6 +40,11 @@ class AsyncSubscriberImpl { _pmsg_callback = std::forward(pmsg_callback); } + template + void on_smessage(SMsgCb &&smsg_callback) { + _smsg_callback = std::forward(smsg_callback); + } + template void on_meta(MetaCb &&meta_callback) { _meta_callback = std::forward(meta_callback); @@ -63,6 +68,8 @@ class AsyncSubscriberImpl { void _handle_pmessage(redisReply &reply); + void _handle_smessage(redisReply &reply); + void _handle_meta(Subscriber::MsgType type, redisReply &reply); std::function _msg_callback; @@ -70,6 +77,8 @@ class AsyncSubscriberImpl { std::function _pmsg_callback; + std::function _smsg_callback; + std::function _meta_callback; diff --git a/src/sw/redis++/cmd_formatter.h b/src/sw/redis++/cmd_formatter.h index 0c0f0e9..abf1517 100644 --- a/src/sw/redis++/cmd_formatter.h +++ b/src/sw/redis++/cmd_formatter.h @@ -813,6 +813,12 @@ inline FormattedCommand publish(const StringView &channel, const StringView &mes message.data(), message.size()); } +inline FormattedCommand spublish(const StringView &channel, const StringView &message) { + return format_cmd("SPUBLISH %b %b", + channel.data(), channel.size(), + message.data(), message.size()); +} + inline FormattedCommand punsubscribe() { return format_cmd("PUNSUBSCRIBE"); } @@ -863,6 +869,38 @@ inline FormattedCommand unsubscribe_range(Input first, Input last) { return format_cmd(args); } +inline FormattedCommand ssubscribe(const StringView &channel) { + return format_cmd("SSUBSCRIBE %b", channel.data(), channel.size()); +} + +template +inline FormattedCommand ssubscribe_range(Input first, Input last) { + assert(first != last); + + CmdArgs args; + args << "SSUBSCRIBE" << std::make_pair(first, last); + + return format_cmd(args); +} + +inline FormattedCommand sunsubscribe() { + return format_cmd("SUNSUBSCRIBE"); +} + +inline FormattedCommand sunsubscribe(const StringView &channel) { + return format_cmd("SUNSUBSCRIBE %b", channel.data(), channel.size()); +} + +template +inline FormattedCommand sunsubscribe_range(Input first, Input last) { + assert(first != last); + + CmdArgs args; + args << "SUNSUBSCRIBE" << std::make_pair(first, last); + + return format_cmd(args); +} + } } diff --git a/src/sw/redis++/command.h b/src/sw/redis++/command.h index 761381f..3a170cf 100644 --- a/src/sw/redis++/command.h +++ b/src/sw/redis++/command.h @@ -1594,6 +1594,14 @@ inline void publish(Connection &connection, message.data(), message.size()); } +inline void spublish(Connection &connection, + const StringView &channel, + const StringView &message) { + connection.send("SPUBLISH %b %b", + channel.data(), channel.size(), + message.data(), message.size()); +} + inline void punsubscribe(Connection &connection) { connection.send("PUNSUBSCRIBE"); } @@ -1650,6 +1658,42 @@ inline void unsubscribe_range(Connection &connection, Input first, Input last) { connection.send(args); } +inline void ssubscribe(Connection &connection, const StringView &channel) { + connection.send("SSUBSCRIBE %b", channel.data(), channel.size()); +} + +template +inline void ssubscribe_range(Connection &connection, Input first, Input last) { + if (first == last) { + throw Error("SSUBSCRIBE: no key specified"); + } + + CmdArgs args; + args << "SSUBSCRIBE" << std::make_pair(first, last); + + connection.send(args); +} + +inline void sunsubscribe(Connection &connection) { + connection.send("SUNSUBSCRIBE"); +} + +inline void sunsubscribe(Connection &connection, const StringView &channel) { + connection.send("SUNSUBSCRIBE %b", channel.data(), channel.size()); +} + +template +inline void sunsubscribe_range(Connection &connection, Input first, Input last) { + if (first == last) { + throw Error("SUNSUBSCRIBE: no key specified"); + } + + CmdArgs args; + args << "SUNSUBSCRIBE" << std::make_pair(first, last); + + connection.send(args); +} + // Transaction commands. inline void discard(Connection &connection) { diff --git a/src/sw/redis++/redis.cpp b/src/sw/redis++/redis.cpp index e4e6ddc..3c0ed4f 100644 --- a/src/sw/redis++/redis.cpp +++ b/src/sw/redis++/redis.cpp @@ -863,6 +863,12 @@ long long Redis::publish(const StringView &channel, const StringView &message) { return reply::parse(*reply); } +long long Redis::spublish(const StringView &channel, const StringView &message) { + auto reply = command(cmd::spublish, channel, message); + + return reply::parse(*reply); +} + // Transaction commands. void Redis::watch(const StringView &key) { diff --git a/src/sw/redis++/redis.h b/src/sw/redis++/redis.h index 84069c3..69b1cf2 100644 --- a/src/sw/redis++/redis.h +++ b/src/sw/redis++/redis.h @@ -3252,6 +3252,8 @@ class Redis { long long publish(const StringView &channel, const StringView &message); + long long spublish(const StringView &channel, const StringView &message); + // Transaction commands. void watch(const StringView &key); diff --git a/src/sw/redis++/redis_cluster.cpp b/src/sw/redis++/redis_cluster.cpp index a35e824..cc37305 100644 --- a/src/sw/redis++/redis_cluster.cpp +++ b/src/sw/redis++/redis_cluster.cpp @@ -80,6 +80,15 @@ Subscriber RedisCluster::subscriber() { return Subscriber(Connection(opts)); } +Subscriber RedisCluster::subscriber(const StringView &hash_tag) { + assert(_pool); + + _pool->async_update(); + + auto opts = _pool->connection_options(hash_tag); + return Subscriber(Connection(opts)); +} + // KEY commands. long long RedisCluster::del(const StringView &key) { @@ -755,6 +764,12 @@ long long RedisCluster::publish(const StringView &channel, const StringView &mes return reply::parse(*reply); } +long long RedisCluster::spublish(const StringView &channel, const StringView &message) { + auto reply = command(cmd::spublish, channel, message); + + return reply::parse(*reply); +} + // Stream commands. long long RedisCluster::xack(const StringView &key, const StringView &group, const StringView &id) { diff --git a/src/sw/redis++/redis_cluster.h b/src/sw/redis++/redis_cluster.h index 9f851fc..23cb171 100644 --- a/src/sw/redis++/redis_cluster.h +++ b/src/sw/redis++/redis_cluster.h @@ -68,6 +68,8 @@ class RedisCluster { Subscriber subscriber(); + Subscriber subscriber(const StringView &hash_tag); + /// @brief Run the given callback with each node in the cluster. /// The following is the prototype of the callback: void (Redis &r); /// @@ -1066,6 +1068,8 @@ class RedisCluster { long long publish(const StringView &channel, const StringView &message); + long long spublish(const StringView &channel, const StringView &message); + // Stream commands. long long xack(const StringView &key, const StringView &group, const StringView &id); diff --git a/src/sw/redis++/subscriber.cpp b/src/sw/redis++/subscriber.cpp index 06575ec..f6f5612 100644 --- a/src/sw/redis++/subscriber.cpp +++ b/src/sw/redis++/subscriber.cpp @@ -40,6 +40,12 @@ void Subscriber::subscribe(const StringView &channel) { cmd::subscribe(_connection, channel); } +void Subscriber::ssubscribe(const StringView &channel) { + _check_connection(); + + cmd::ssubscribe(_connection, channel); +} + void Subscriber::unsubscribe() { _check_connection(); @@ -101,10 +107,16 @@ void Subscriber::consume() { _handle_pmessage(*reply); break; + case MsgType::SMESSAGE: + _handle_smessage(*reply); + break; + case MsgType::SUBSCRIBE: case MsgType::UNSUBSCRIBE: case MsgType::PSUBSCRIBE: case MsgType::PUNSUBSCRIBE: + case MsgType::SSUBSCRIBE: + case MsgType::SUNSUBSCRIBE: _handle_meta(type, *reply); break; @@ -129,6 +141,8 @@ Subscriber::MsgType Subscriber::_msg_type(std::string const& type) const return MsgType::MESSAGE; } else if ("pmessage" == type) { return MsgType::PMESSAGE; + } else if ("smessage" == type) { + return MsgType::SMESSAGE; } else if ("subscribe" == type) { return MsgType::SUBSCRIBE; } else if ("unsubscribe" == type) { @@ -137,6 +151,10 @@ Subscriber::MsgType Subscriber::_msg_type(std::string const& type) const return MsgType::PSUBSCRIBE; } else if ("punsubscribe" == type) { return MsgType::PUNSUBSCRIBE; + } else if ("ssubscribe" == type) { + return MsgType::SSUBSCRIBE; + } else if ("sunsubscribe" == type) { + return MsgType::SUNSUBSCRIBE; } else { return MsgType::UNKNOWN; } @@ -174,6 +192,32 @@ void Subscriber::_handle_message(redisReply &reply) { _msg_callback(std::move(channel), std::move(msg)); } +void Subscriber::_handle_smessage(redisReply &reply) { + if (_smsg_callback == nullptr) { + return; + } + + if (reply.elements != 3) { + throw ProtoError("Expect 3 sub replies"); + } + + assert(reply.element != nullptr); + + auto *channel_reply = reply.element[1]; + if (channel_reply == nullptr) { + throw ProtoError("Null channel reply"); + } + auto channel = reply::parse(*channel_reply); + + auto *msg_reply = reply.element[2]; + if (msg_reply == nullptr) { + throw ProtoError("Null message reply"); + } + auto msg = reply::parse(*msg_reply); + + _smsg_callback(std::move(channel), std::move(msg)); +} + void Subscriber::_handle_pmessage(redisReply &reply) { if (_pmsg_callback == nullptr) { return; diff --git a/src/sw/redis++/subscriber.h b/src/sw/redis++/subscriber.h index 654f4b4..c8b8d78 100644 --- a/src/sw/redis++/subscriber.h +++ b/src/sw/redis++/subscriber.h @@ -77,6 +77,9 @@ class Subscriber { PUNSUBSCRIBE, MESSAGE, PMESSAGE, + SSUBSCRIBE, + SUNSUBSCRIBE, + SMESSAGE, UNKNOWN }; @@ -86,6 +89,9 @@ class Subscriber { template void on_pmessage(PMsgCb pmsg_callback); + template + void on_smessage(SMsgCb smsg_callback); + template void on_meta(MetaCb meta_callback); @@ -133,6 +139,28 @@ class Subscriber { punsubscribe(channels.begin(), channels.end()); } + void ssubscribe(const StringView &channel); + + template + void ssubscribe(Input first, Input last); + + template + void ssubscribe(std::initializer_list channels) { + ssubscribe(channels.begin(), channels.end()); + } + + void sunsubscribe(); + + void sunsubscribe(const StringView &channel); + + template + void sunsubscribe(Input first, Input last); + + template + void sunsubscribe(std::initializer_list channels) { + sunsubscribe(channels.begin(), channels.end()); + } + void consume(); private: @@ -151,6 +179,8 @@ class Subscriber { void _handle_pmessage(redisReply &reply); + void _handle_smessage(redisReply &reply); + void _handle_meta(MsgType type, redisReply &reply); using MsgCallback = std::function; @@ -159,6 +189,9 @@ class Subscriber { std::string channel, std::string msg)>; + using SMsgCallback = std::function; + + using MetaCallback = std::function; @@ -169,6 +202,8 @@ class Subscriber { PatternMsgCallback _pmsg_callback = nullptr; + SMsgCallback _smsg_callback = nullptr; + MetaCallback _meta_callback = nullptr; }; @@ -182,6 +217,11 @@ void Subscriber::on_pmessage(PMsgCb pmsg_callback) { _pmsg_callback = pmsg_callback; } +template +void Subscriber::on_smessage(SMsgCb smsg_callback) { + _smsg_callback = smsg_callback; +} + template void Subscriber::on_meta(MetaCb meta_callback) { _meta_callback = meta_callback; @@ -223,6 +263,24 @@ void Subscriber::punsubscribe(Input first, Input last) { cmd::punsubscribe_range(_connection, first, last); } +template +void Subscriber::ssubscribe(Input first, Input last) { + if (first == last) { + return; + } + + _check_connection(); + + cmd::ssubscribe_range(_connection, first, last); +} + +template +void Subscriber::sunsubscribe(Input first, Input last) { + _check_connection(); + + cmd::sunsubscribe_range(_connection, first, last); +} + } }