Skip to content

Commit

Permalink
Handle exceptions in callback. (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
gandalf013 authored and darvid committed Jul 31, 2019
1 parent 1ebabf6 commit f1b92be
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
14 changes: 12 additions & 2 deletions hyperscan/hyperscanmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ static PyTypeObject StreamType;
typedef struct {
PyObject *callback;
PyObject *ctx;
int success;
} py_scan_callback_ctx;

typedef struct {
Expand Down Expand Up @@ -73,7 +74,13 @@ static int match_handler(unsigned int id, unsigned long long from,
gstate = PyGILState_Ensure();
PyObject *rv = PyObject_CallFunction(cctx->callback, "IIIIO", id,
from, to, flags, cctx->ctx);
int halt = rv == Py_None ? 0 : PyObject_IsTrue(rv);
int halt = 1;
if (rv == NULL) {
cctx->success = 0;
} else {
halt = rv == Py_None ? 0 : PyObject_IsTrue(rv);
cctx->success = 1;
}
PyGILState_Release(gstate);
Py_XDECREF(rv);
return halt;
Expand Down Expand Up @@ -235,7 +242,7 @@ static PyObject* Database_scan(Database *self, PyObject *args, PyObject *kwds) {
&data, &length, &ocallback, &flags,
&octx, &oscratch))
return NULL;
py_scan_callback_ctx cctx = {ocallback, octx};
py_scan_callback_ctx cctx = {ocallback, octx, 1};
Py_BEGIN_ALLOW_THREADS
err = hs_scan(
self->db,
Expand All @@ -248,6 +255,9 @@ static PyObject* Database_scan(Database *self, PyObject *args, PyObject *kwds) {
ocallback == Py_None ? NULL : (void*)&cctx
);
Py_END_ALLOW_THREADS
if (!cctx.success) {
return NULL;
}
HANDLE_HYPERSCAN_ERR(err, NULL);
Py_RETURN_NONE;
}
Expand Down
8 changes: 6 additions & 2 deletions tests/test_hyperscan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import sys

import pytest

import hyperscan
Expand Down Expand Up @@ -95,3 +93,9 @@ def test_database_deserialize(database_stream):
serialized = hyperscan.dumps(database_stream)
db = hyperscan.loads(bytearray(serialized))
assert id(db) != id(database_stream)

def test_database_exception_in_callback(database_block, mocker):
callback = mocker.Mock(side_effect=RuntimeError('oops'))

with pytest.raises(RuntimeError, match=r'^oops$'):
database_block.scan(b'foobar', match_event_handler=callback)

0 comments on commit f1b92be

Please sign in to comment.