Skip to content

Commit

Permalink
Merge branch 'reconnect'
Browse files Browse the repository at this point in the history
  • Loading branch information
evilkost committed Apr 15, 2011
2 parents 93cca47 + aeabd47 commit 64c9930
Show file tree
Hide file tree
Showing 3 changed files with 525 additions and 80 deletions.
193 changes: 118 additions & 75 deletions brukva/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,47 @@

log = logging.getLogger('brukva.client')

@contextlib.contextmanager
def forward_error(callbacks, cleanup=None):
try:
yield callbacks
except Exception, e:
log.error(e)
if isinstance(callbacks, Iterable):
for cb in callbacks:
cb(e)
class ForwardErrorManager(object):
def __init__(self, callbacks):
self.callbacks = callbacks
self.is_active = True

def __enter__(self):
return self

def __exit__(self, type, value, tb):
if type is None:
return True

if self.is_active:
if isinstance(self.callbacks, Iterable):
for cb in self.callbacks:
cb(value)
else:
self.callbacks(value)
return True
else:
callbacks(e)
finally:
if cleanup:
cleanup()
return False

def disable(self):
self.is_active = False

def enable(self):
self.is_active = True

def forward_error(callbacks):
"""
Syntax sugar.
If some error occurred inside with block,
it will be suppressed and forwarded to callbacks.
Error handling can be disabled using context.disable(),
and re enabled again using context.enable().
@type callbacks: callable or iterator over callables
@rtype: context
"""
return ForwardErrorManager(callbacks)

class Message(object):
def __init__(self, kind, channel, body):
Expand Down Expand Up @@ -73,13 +100,15 @@ def format_pipeline_request(command_stack):
return ''.join(format(c.cmd, *c.args, **c.kwargs) for c in command_stack)

class Connection(object):
def __init__(self, host, port, on_reconnect, timeout=None, io_loop=None):
def __init__(self, host, port, on_connect, on_disconnect, timeout=None, io_loop=None):
self.host = host
self.port = port
self.on_reconnect = on_reconnect
self.on_connect = on_connect
self.on_disconnect = on_disconnect
self.timeout = timeout
self._stream = None
self._io_loop = io_loop
self.try_left = 2

self.in_progress = False
self.read_queue = []
Expand All @@ -91,8 +120,10 @@ def connect(self):
sock.settimeout(self.timeout)
sock.connect((self.host, self.port))
self._stream = IOStream(sock, io_loop=self._io_loop)
self.connected()
except socket.error, e:
raise ConnectionError(str(e))
self.on_connect()

def disconnect(self):
if self._stream:
Expand All @@ -102,44 +133,41 @@ def disconnect(self):
pass
self._stream = None

def write(self, data):
def write(self, data, try_left=None):
if try_left is None:
try_left = self.try_left
if not self._stream:
self.on_reconnect()
self.connect()
if not self._stream:
raise ConnectionError('Tried to write to non-existent connection')
else:
self._stream.write(data)

def consume(self, length):
if not self._stream:
self.on_reconnect()
if not self._stream:
raise ConnectionError('Tried to consume from non-existent connection')
self._stream.read_bytes(length, NOOP_CB)
if try_left > 0:
try:
#print('try to write: %s'% data)
self._stream.write(data)
except IOError:
self.disconnect()
self.write(data, try_left - 1)
else:
raise ConnectionError('Tried to write to non-existent connection')

def read(self, length, callback):
try:
if not self._stream:
self.client._sudden_disconnect([callback])
self.on_reconnect()
if not self._stream:
raise ConnectionError('Tried to read from non-existent connection')
self.disconnect()
raise ConnectionError('Tried to read from non-existent connection')
self._stream.read_bytes(length, callback)
except IOError:
self.client._sudden_disconnect([callback])
self.on_reconnect()
self.on_disconnect()

def readline(self, callback):
try:
if not self._stream:
self.client._sudden_disconnect([callback])
self.on_reconnect()
if not self._stream:
raise ConnectionError('Tried to read from non-existent connection')
self.disconnect()
raise ConnectionError('Tried to read from non-existent connection')
self._stream.read_until('\r\n', callback)
except IOError:
self.client._sudden_disconnect([callback])
self.on_reconnect()
self.on_disconnect()

def try_to_perform_read(self):
if not self.in_progress and self.read_queue:
Expand Down Expand Up @@ -234,16 +262,18 @@ def __getattr__(self, item):


class Client(object):
def __init__(self, host='localhost', port=6379, password=None, reconnect=False, io_loop=None):
def __init__(self, host='localhost', port=6379, password=None,
selected_db=None, io_loop=None):
self._io_loop = io_loop or IOLoop.instance()

self.connection = Connection(host, port, self.on_reconnect, io_loop=self._io_loop)
self.connection = Connection(host, port,
self.on_connect, self.on_disconnect, io_loop=self._io_loop)
self.async = _AsyncWrapper(weakref.proxy(self))
self.queue = []
self.current_cmd_line = None
self.subscribed = False
self.password = password
self.reconnect = reconnect
self.selected_db = selected_db
self.write_try_num = 2
self.REPLY_MAP = dict_merge(
string_keys_to_dict('AUTH BGREWRITEAOF BGSAVE DEL EXISTS EXPIRE HDEL HEXISTS '
'HMSET MOVE MSET MSETNX SAVE SETNX',
Expand Down Expand Up @@ -280,22 +310,30 @@ def __repr__(self):

def pipeline(self, transactional=False):
if not self._pipeline:
self._pipeline = Pipeline(io_loop = self._io_loop, transactional=transactional)
self._pipeline = Pipeline(
selected_db=self.selected_db,
io_loop = self._io_loop,
transactional=transactional
)
self._pipeline.connection = self.connection
return self._pipeline

#### connection

def connect(self):
self.connection.connect()
if self.password:
self.auth(self.password)

def disconnect(self):
self.connection.disconnect()

def on_reconnect(self):
if self.reconnect:
self.connect()
def on_connect(self):
if self.password:
self.auth(self.password)
if self.selected_db:
self.select(self.selected_db)

def on_disconnect(self, callbacks):
raise ConnectionError("Socket closed on remote end")
####

#### formatting
Expand Down Expand Up @@ -332,22 +370,18 @@ def call_callbacks(self, callbacks, *args, **kwargs):
for cb in callbacks:
cb(*args, **kwargs)

def _sudden_disconnect(self, callbacks):
self.connection.disconnect()
raise ConnectionError("Socket closed on remote end")

@process
def execute_command(self, cmd, callbacks, *args, **kwargs):
result = None
with forward_error(callbacks):
if callbacks is None:
callbacks = []
elif not hasattr(callbacks, '__iter__'):
callbacks = [callbacks]

try:
if self.reconnect and not self.connection.connected():
self.connect()
self.connection.write(self.format(cmd, *args, **kwargs))
except IOError:
except IOError, e:
self._sudden_disconnect(callbacks)
except Exception, e:
self.connection.disconnect()
Expand All @@ -359,13 +393,15 @@ def execute_command(self, cmd, callbacks, *args, **kwargs):
data = yield async(self.connection.readline)()
if not data:
result = None
self.connection.read_done()
raise Exception('TODO: [no data from connection->readline')
else:
response = yield self.process_data(data, cmd_line)
result = self.format_reply(cmd_line, response)

self.connection.read_done()
self.call_callbacks(callbacks, result)
self.connection.read_done()

self.call_callbacks(callbacks, result)

@async
@process
Expand All @@ -379,7 +415,6 @@ def process_data(self, data, cmd_line, callback):
response = []
else:
if len(data) == 0:
self.on_reconnect()
raise IOError('Disconnected')
head, tail = data[0], data[1:]

Expand All @@ -398,7 +433,7 @@ def process_data(self, data, cmd_line, callback):
else:
raise ResponseError('Unknown response type %s' % head, cmd_line)

callback(response)
callback(response)

@async
@process
Expand All @@ -414,7 +449,7 @@ def consume_multibulk(self, length, cmd_line, callback):
)
token = yield self.process_data(data, cmd_line) #FIXME error
tokens.append( token )
callback(tokens)
callback(tokens)

@async
@process
Expand All @@ -427,7 +462,7 @@ def consume_bulk(self, length, callback):
raise ResponseError('EmptyResponse')
else:
data = data[:-2]
callback(data)
callback(data)
####

### MAINTENANCE
Expand All @@ -450,6 +485,7 @@ def info(self, callbacks=None):
self.execute_command('INFO', callbacks)

def select(self, db, callbacks=None):
self.selected_db = db
self.execute_command('SELECT', callbacks, db)

def shutdown(self, callbacks=None):
Expand Down Expand Up @@ -813,7 +849,7 @@ def publish(self, channel, message, callbacks=None):
@process
def listen(self, callbacks=None):
# 'LISTEN' is just for receiving information, it is not actually sent anywhere
with forward_error(callbacks):
with forward_error(callbacks) as forward:
callbacks = callbacks or []
if not hasattr(callbacks, '__iter__'):
callbacks = [callbacks]
Expand All @@ -829,8 +865,10 @@ def listen(self, callbacks=None):
if isinstance(response, Exception):
raise response
result = self.format_reply(cmd_listen, response)
self.call_callbacks(callbacks, result)

forward.disable()
self.call_callbacks(callbacks, result)
forward.enable()
### CAS
def watch(self, key, callbacks=None):
self.execute_command('WATCH', callbacks, key)
Expand All @@ -845,19 +883,26 @@ def __init__(self, transactional, *args, **kwargs):
self.command_stack = []

def execute_command(self, cmd, callbacks, *args, **kwargs):
if cmd in ('AUTH'):
raise Exception('403')
if cmd in ('AUTH', 'SELECT'):
raise RuntimeError('cmd %s must not be in pipe ' % cmd)
self.command_stack.append(CmdLine(cmd, *args, **kwargs))

def discard(self): # actually do nothing with redis-server, just flush command_stack
self.command_stack = []

def _sudden_disconnect(self, callbacks, error=None):
self.connection.disconnect()
raise error or ConnectionError("Socket closed on remote end")
###
def select(self, db, callbacks=None):
self.selected_db = db
super(Pipeline, self).execute_command('SELECT', callbacks, db)

def auth(self, password, callbacks=None):
super(Pipeline, self).execute_command('AUTH', callbacks, password)
###


@process
def execute(self, callbacks):
results = None
with forward_error(callbacks):
command_stack = self.command_stack
self.command_stack = []
Expand All @@ -871,16 +916,17 @@ def execute(self, callbacks):
command_stack = [CmdLine('MULTI')] + command_stack + [CmdLine('EXEC')]

request = format_pipeline_request(command_stack)

try:
if self.reconnect and not self.connection.connected():
self.connect()
self.connection.write(request)
except IOError:
self.command_stack = []
self._sudden_disconnect(callbacks)
self.connection.disconnect()
raise ConnectionError("Socket closed on remote end")
except Exception, e:
self.command_stack = []
self._sudden_disconnect(callbacks, e)
self.connection.disconnect()
raise e

yield self.connection.queue_wait()
responses = []
Expand All @@ -891,7 +937,6 @@ def execute(self, callbacks):
data = yield async(self.connection.readline)()
if not data:
raise ResponseError('Not enough data after EXEC')

try:
cmd_line = cmds.next()
if self.transactional and cmd_line.cmd != 'EXEC':
Expand Down Expand Up @@ -922,6 +967,4 @@ def format_replies(cmd_lines, responses):
else:
results = format_replies(command_stack, responses)

self.call_callbacks(callbacks, results)


self.call_callbacks(callbacks, results)
Loading

0 comments on commit 64c9930

Please sign in to comment.