diff --git a/tasktiger/logging.py b/tasktiger/logging.py index dd8442ca..7a8e6c0d 100644 --- a/tasktiger/logging.py +++ b/tasktiger/logging.py @@ -14,16 +14,3 @@ def tasktiger_processor(logger, method_name, event_dict): event_dict['task_id'] = g['current_batch_task'].id return event_dict - - -def batch_param_iterator(params): - """ - Helper to set current batch task. - - This helper should be used in conjunction with tasktiger_processor - to facilitate logging of task ids. - """ - for i, p in enumerate(params): - g['current_batch_task'] = g['current_tasks'][i] - yield p - g['current_batch_task'] = None diff --git a/tasktiger/tasktiger.py b/tasktiger/tasktiger.py index 9ba92ef5..8c135e62 100644 --- a/tasktiger/tasktiger.py +++ b/tasktiger/tasktiger.py @@ -226,8 +226,12 @@ def init(self, connection=None, config=None, setup_structlog=False): def _get_current_task(self): if g['current_tasks'] is None: raise RuntimeError('Must be accessed from within a task') - if g['current_task_is_batch']: - raise RuntimeError('Must use current_tasks in a batch task.') + + if g['current_task_is_batch'] and g['current_batch_task']: + return g['current_batch_task'] + elif g['current_task_is_batch'] and not g['current_batch_task']: + raise RuntimeError('Must use batch_param_iterator in batch task.') + return g['current_tasks'][0] def _get_current_tasks(self): diff --git a/tasktiger/utils.py b/tasktiger/utils.py new file mode 100644 index 00000000..55910be2 --- /dev/null +++ b/tasktiger/utils.py @@ -0,0 +1,14 @@ +from ._internal import g + + +def batch_param_iterator(params): + """ + Helper to set current batch task. + + This helper should be used in conjunction with tasktiger_processor + to facilitate logging of task ids. + """ + for i, p in enumerate(params): + g['current_batch_task'] = g['current_tasks'][i] + yield p + g['current_batch_task'] = None diff --git a/tests/tasks.py b/tests/tasks.py index 58da158b..5d4b6966 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -2,10 +2,12 @@ from math import ceil import time +import pytest import redis -from tasktiger import RetryException +from tasktiger import RetryException, g from tasktiger.retry import fixed +from tasktiger.utils import batch_param_iterator from .config import DELAY, TEST_DB, REDIS_HOST from .utils import get_tiger @@ -139,18 +141,24 @@ def verify_current_task(): @tiger.task(batch=True, queue='batch') -def verify_current_tasks(tasks): +def verify_current_tasks(params): with redis.Redis( host=REDIS_HOST, db=TEST_DB, decode_responses=True ) as conn: try: - tasks = tiger.current_task + tiger.current_task except RuntimeError: # This is expected (we need to use current_tasks) tasks = tiger.current_tasks conn.rpush('task_ids', *[t.id for t in tasks]) + for i, p in enumerate(batch_param_iterator(params)): + assert tiger.current_task.id == g['current_tasks'][i].id + + with pytest.raises(RuntimeError): + tiger.current_task.id + @tiger.task() def sleep_task(delay=10): diff --git a/tests/test_logging.py b/tests/test_logging.py index f30f21b4..8d4dc675 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -4,7 +4,8 @@ import structlog from tasktiger import TaskTiger, Worker, g -from tasktiger.logging import tasktiger_processor, batch_param_iterator +from tasktiger.logging import tasktiger_processor +from tasktiger.utils import batch_param_iterator from .test_base import BaseTestCase from .utils import get_tiger, get_redis