diff --git a/brukva/client.py b/brukva/client.py index c7f9e30..45bcd9a 100644 --- a/brukva/client.py +++ b/brukva/client.py @@ -83,10 +83,20 @@ def execution_context(callbacks, error_wrapper=None): return ExecutionContext(callbacks, error_wrapper) class Message(object): - def __init__(self, kind, channel, body): - self.kind = kind - self.channel = channel - self.body = body + ''' Wrapper Message object. + kind = command + channel = channel from which the message was received + pattern = subscription pattern + body = message body + ''' + def __init__(self, *args): + if len(args) == 3: + (self.kind, self.channel, self.body) = args + self.pattern = self.channel + elif len(args) == 4: + (self.kind, self.channel, self.pattern, self.body) = args + else: + raise ValueError('Invalid number of arguments') class CmdLine(object): def __init__(self, cmd, *args, **kwargs): @@ -275,7 +285,9 @@ def reply_ttl(r, *args, **kwargs): PUB_SUB_COMMANDS = set([ 'SUBSCRIBE', + 'PSUBSCRIBE', 'UNSUBSCRIBE', + 'PUNSUBSCRIBE', 'LISTEN', ]) @@ -315,7 +327,8 @@ def __init__(self, host='localhost', port=6379, password=None, reply_dict_from_pairs), string_keys_to_dict('HGET', reply_str), - string_keys_to_dict('SUBSCRIBE UNSUBSCRIBE LISTEN', + string_keys_to_dict('SUBSCRIBE UNSUBSCRIBE LISTEN ' + 'PSUBSCRIBE UNSUBSCRIBE', reply_pubsub_message), string_keys_to_dict('ZRANK ZREVRANK', reply_int), @@ -854,6 +867,12 @@ def hvals(self, key, callbacks=None): ### PUBSUB def subscribe(self, channels, callbacks=None): + self._subscribe('SUBSCRIBE', channels, callbacks) + + def psubscribe(self, channels, callbacks=None): + self._subscribe('PSUBSCRIBE', channels, callbacks) + + def _subscribe(self, cmd, channels, callbacks=None): callbacks = callbacks or [] if not isinstance(callbacks, Iterable): callbacks = [callbacks] @@ -861,19 +880,25 @@ def subscribe(self, channels, callbacks=None): channels = [channels] if not self.subscribed: callbacks = list(callbacks) + [self.on_subscribed] - self.execute_command('SUBSCRIBE', callbacks, *channels) + self.execute_command(cmd, callbacks, *channels) def on_subscribed(self, result): self.subscribed = True def unsubscribe(self, channels, callbacks=None): + self._unsubscribe('UNSUBSCRIBE', channels, callbacks) + + def punsubscribe(self, channels, callbacks=None): + self._unsubscribe('UNSUBSCRIBE', channels, callbacks) + + def _unsubscribe(self, cmd, channels, callbacks=None): callbacks = callbacks or [] if not isinstance(callbacks, Iterable): callbacks = [callbacks] if isinstance(channels, basestring): channels = [channels] callbacks = list(callbacks) - self.execute_command('UNSUBSCRIBE', callbacks, *channels) + self.execute_command(cmd, callbacks, *channels) def on_unsubscribed(self, *args, **kwargs): self.subscribed = False @@ -908,7 +933,7 @@ def error_wrapper(e): result = self.format_reply(cmd_listen, response) - if result.kind != 'message': + if result.kind not in ('message', 'pmessage'): waiting_stack = self._waiting_callbacks[result.kind.upper()] if len(waiting_stack) > 0: ctx.safe_call(waiting_stack.pop(0), result)