From f1b92bedb751c4492d29a9e793a27c7941ff6f57 Mon Sep 17 00:00:00 2001 From: Alok Singhal Date: Wed, 31 Jul 2019 04:27:54 -0700 Subject: [PATCH] Handle exceptions in callback. (#15) --- hyperscan/hyperscanmodule.c | 14 ++++++++++++-- tests/test_hyperscan.py | 8 ++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/hyperscan/hyperscanmodule.c b/hyperscan/hyperscanmodule.c index 573ef0e..a2dead9 100644 --- a/hyperscan/hyperscanmodule.c +++ b/hyperscan/hyperscanmodule.c @@ -41,6 +41,7 @@ static PyTypeObject StreamType; typedef struct { PyObject *callback; PyObject *ctx; + int success; } py_scan_callback_ctx; typedef struct { @@ -73,7 +74,13 @@ static int match_handler(unsigned int id, unsigned long long from, gstate = PyGILState_Ensure(); PyObject *rv = PyObject_CallFunction(cctx->callback, "IIIIO", id, from, to, flags, cctx->ctx); - int halt = rv == Py_None ? 0 : PyObject_IsTrue(rv); + int halt = 1; + if (rv == NULL) { + cctx->success = 0; + } else { + halt = rv == Py_None ? 0 : PyObject_IsTrue(rv); + cctx->success = 1; + } PyGILState_Release(gstate); Py_XDECREF(rv); return halt; @@ -235,7 +242,7 @@ static PyObject* Database_scan(Database *self, PyObject *args, PyObject *kwds) { &data, &length, &ocallback, &flags, &octx, &oscratch)) return NULL; - py_scan_callback_ctx cctx = {ocallback, octx}; + py_scan_callback_ctx cctx = {ocallback, octx, 1}; Py_BEGIN_ALLOW_THREADS err = hs_scan( self->db, @@ -248,6 +255,9 @@ static PyObject* Database_scan(Database *self, PyObject *args, PyObject *kwds) { ocallback == Py_None ? NULL : (void*)&cctx ); Py_END_ALLOW_THREADS + if (!cctx.success) { + return NULL; + } HANDLE_HYPERSCAN_ERR(err, NULL); Py_RETURN_NONE; } diff --git a/tests/test_hyperscan.py b/tests/test_hyperscan.py index f8c4693..d725e40 100644 --- a/tests/test_hyperscan.py +++ b/tests/test_hyperscan.py @@ -1,5 +1,3 @@ -import sys - import pytest import hyperscan @@ -95,3 +93,9 @@ def test_database_deserialize(database_stream): serialized = hyperscan.dumps(database_stream) db = hyperscan.loads(bytearray(serialized)) assert id(db) != id(database_stream) + +def test_database_exception_in_callback(database_block, mocker): + callback = mocker.Mock(side_effect=RuntimeError('oops')) + + with pytest.raises(RuntimeError, match=r'^oops$'): + database_block.scan(b'foobar', match_event_handler=callback)