diff --git a/brukva/client.py b/brukva/client.py index b17f2c2..136aa7c 100644 --- a/brukva/client.py +++ b/brukva/client.py @@ -3,7 +3,7 @@ from functools import partial from itertools import izip import logging -from collections import Iterable +from collections import Iterable, defaultdict import weakref from tornado.ioloop import IOLoop @@ -21,13 +21,13 @@ def __init__(self, callbacks, error_wrapper=None): self.error_wrapper = error_wrapper self.is_active = True - def _call_callbacks(self, value): - if self.callbacks: - if isinstance(self.callbacks, Iterable): - for cb in self.callbacks: + def _call_callbacks(self, callbacks, value): + if callbacks: + if isinstance(callbacks, Iterable): + for cb in callbacks: cb(value) else: - self.callbacks(value) + callbacks(value) def __enter__(self): return self @@ -53,9 +53,15 @@ def enable(self): def ret_call(self, value): self.is_active = False - self._call_callbacks(value) + self._call_callbacks(self.callbacks, value) self.is_active = True + def safe_call(self, callbacks, value): + self.is_active = False + self._call_callbacks(callbacks, value) + self.is_active = True + + def execution_context(callbacks, error_wrapper=None): """ Syntax sugar. @@ -154,7 +160,6 @@ def write(self, data, try_left=None): if try_left > 0: try: - #print('try to write: %s'% data) self._stream.write(data) except IOError: self.disconnect() @@ -290,7 +295,6 @@ def __init__(self, host='localhost', port=6379, password=None, self.subscribed = False self.password = password self.selected_db = selected_db - self.write_try_num = 2 self.REPLY_MAP = dict_merge( string_keys_to_dict('AUTH BGREWRITEAOF BGSAVE DEL EXISTS EXPIRE HDEL HEXISTS ' 'HMSET MOVE MSET MSETNX SAVE SETNX', @@ -320,6 +324,7 @@ def __init__(self, host='localhost', port=6379, password=None, {'MULTI_PART': make_reply_assert_msg('QUEUED')}, ) + self._waiting_callbacks = defaultdict(list) self._pipeline = None def __repr__(self): @@ -408,8 +413,8 @@ def execute_command(self, cmd, callbacks, *args, **kwargs): self.connection.disconnect() raise e - if cmd == 'UNSUBSCRIBE' or self.subscribed and cmd == 'SUBSCRIBE': - ctx.ret_call(True) + if self.subscribed and cmd in ('SUBSCRIBE', 'UNSUBSCRIBE'): + self._waiting_callbacks[cmd].append(callbacks) return yield self.connection.queue_wait() @@ -847,7 +852,8 @@ def subscribe(self, channels, callbacks=None): callbacks = [callbacks] if isinstance(channels, basestring): channels = [channels] - callbacks = list(callbacks) + [self.on_subscribed] + if not self.subscribed: + callbacks = list(callbacks) + [self.on_subscribed] self.execute_command('SUBSCRIBE', callbacks, *channels) def on_subscribed(self, result): @@ -895,11 +901,15 @@ def error_wrapper(e): result = self.format_reply(cmd_listen, response) - if result.kind == 'unsubscribe' and result.body == 0: - self.on_unsubscribed() - self.connection.read_done() - ctx.ret_call(result) - break + if result.kind != 'message': + waiting_stack = self._waiting_callbacks[result.kind.upper()] + if len(waiting_stack) > 0: + ctx.safe_call(waiting_stack.pop(0), result) + + if result.kind == 'unsubscribe' and result.body == 0: + self.on_unsubscribed() + self.connection.read_done() + break else: ctx.ret_call(result) @@ -921,7 +931,7 @@ def execute_command(self, cmd, callbacks, *args, **kwargs): super(Pipeline, self).execute_command(cmd, callbacks, *args, **kwargs) elif cmd in PUB_SUB_COMMANDS: raise RequestError( - 'Client is not supposed to issue command %s in pipeline ' % cmd) + 'Client is not supposed to issue command %s in pipeline' % cmd) self.command_stack.append(CmdLine(cmd, *args, **kwargs)) def discard(self): # actually do nothing with redis-server, just flush command_stack