Rewrite _move helper in Lua (to add Redis 7 support) (#284)
* add Redis 7 support in configs

* add a Lua implementation of move_task

* remove execute_pipeline

* list all used keys in KEYS

* address PR comments
nsaje authored Jun 16, 2023
1 parent 62e5dff commit b1019e6
Showing 7 changed files with 194 additions and 259 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -30,7 +30,7 @@ jobs:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
os: ['ubuntu-20.04']
redis-version: [4, 5, "6.2.6"]
redis-version: [4, 5, "6.2.6", "7.0.9"]
# Do not cancel any jobs when a single job fails
fail-fast: false
name: Python ${{ matrix.python-version }} on ${{ matrix.os }} with Redis ${{ matrix.redis-version }}
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -1,7 +1,7 @@
version: "3.7"
services:
redis:
image: redis:4.0.6
image: redis:7.0.9
expose:
- 6379
tasktiger:
67 changes: 0 additions & 67 deletions tasktiger/lua/execute_pipeline.lua

This file was deleted.

91 changes: 91 additions & 0 deletions tasktiger/lua/move_task.lua
@@ -0,0 +1,91 @@
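-- Moves a task between states as a single atomic script.
-- The helpers called below (zadd_noupdate, zadd_update_min,
-- srem_if_not_exists, delete_if_not_in_zsets) are not defined in this
-- file; register_script_from_file() in redis_scripts.py prepends them
-- as local functions when the script is loaded.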
local function zadd_w_mode(key, score, member, mode)
if mode == "" then
redis.call('zadd', key, score, member)
elseif mode == "nx" then
zadd_noupdate({ key }, { score, member })
elseif mode == "min" then
zadd_update_min({ key }, { score, member })
else
error("mode " .. mode .. " unsupported")
end
end


local key_task_id = KEYS[1]
local key_task_id_executions = KEYS[2]
local key_task_id_executions_count = KEYS[3]
local key_from_state = KEYS[4]
local key_to_state = KEYS[5]
local key_active_queue = KEYS[6]
local key_queued_queue = KEYS[7]
local key_error_queue = KEYS[8]
local key_scheduled_queue = KEYS[9]
local key_activity = KEYS[10]

local id = ARGV[1]
local queue = ARGV[2]
local from_state = ARGV[3]
local to_state = ARGV[4]
local unique = ARGV[5]
local when = ARGV[6]
local mode = ARGV[7]
local publish_queued_tasks = ARGV[8]

local state_queues_keys_by_state = {
active = key_active_queue,
queued = key_queued_queue,
error = key_error_queue,
scheduled = key_scheduled_queue,
}
local key_from_state_queue = state_queues_keys_by_state[from_state]
local key_to_state_queue = state_queues_keys_by_state[to_state]

assert(redis.call('zscore', key_from_state_queue, id), '<FAIL_IF_NOT_IN_ZSET>')

if to_state ~= "" then
zadd_w_mode(key_to_state_queue, when, id, mode)
redis.call('sadd', key_to_state, queue)
end
redis.call('zrem', key_from_state_queue, id)

if to_state == "" then -- Remove the task if necessary
if unique == 'true' then
-- Delete executions if there were no errors
local to_delete = {
key_task_id_executions,
key_task_id_executions_count,
}
local keys = { unpack(to_delete) }
if from_state ~= 'error' then
table.insert(keys, key_error_queue)
end
-- keys=[to_delete + zsets], args=[len(to_delete), value]
delete_if_not_in_zsets(keys, { #to_delete, id })

-- Only delete task if it's not in any other queue
local to_delete = { key_task_id }
local zsets = {}
for _, v in ipairs({ 'active', 'queued', 'error', 'scheduled' }) do
if v ~= from_state then
table.insert(zsets, state_queues_keys_by_state[v])
end
end
-- keys=[to_delete + zsets], args=[len(to_delete), value]
delete_if_not_in_zsets({ unpack(to_delete), unpack(zsets) }, { #to_delete, id })
else
-- Safe to remove
redis.call(
'del',
key_task_id,
key_task_id_executions,
key_task_id_executions_count
)
end
end

-- keys=[key, other_key], args=[member]
srem_if_not_exists({ key_from_state, key_from_state_queue }, { queue })

if to_state == 'queued' and publish_queued_tasks == 'true' then
redis.call('publish', key_activity, queue)
end
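
The four helpers referenced by the script are injected at load time, since an EVAL script cannot call functions defined elsewhere. A minimal sketch of that assembly step as implemented by register_script_from_file in redis_scripts.py below; the helper body passed in here is a placeholder, not the real ZADD_NOUPDATE template:

# Template copied from redis_scripts.py below.
LOCAL_FUNC_TEMPLATE = """
local function {func_name}(KEYS, ARGV)
{func_body}
end
"""

def assemble_script(script: str, include_functions: dict) -> str:
    # Wrap each helper in a `local function` and prepend it, in sorted
    # name order, ahead of the script body.
    definitions = [
        LOCAL_FUNC_TEMPLATE.format(func_name=name, func_body=body)
        for name, body in sorted(include_functions.items())
    ]
    return "\n".join(definitions + [script])

lua = assemble_script(
    "zadd_noupdate({ KEYS[1] }, { ARGV[1], ARGV[2] })",
    {"zadd_noupdate": "    redis.call('zadd', KEYS[1], 'NX', ARGV[1], ARGV[2])"},
)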
167 changes: 90 additions & 77 deletions tasktiger/redis_scripts.py
@@ -1,10 +1,18 @@
import os
from typing import Any, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

from redis import Redis
from redis.client import Pipeline
from redis.commands.core import Script

from ._internal import ACTIVE, ERROR, QUEUED, SCHEDULED

LOCAL_FUNC_TEMPLATE = """
local function {func_name}(KEYS, ARGV)
{func_body}
end
"""

# ARGV = { score, member }
ZADD_NOUPDATE_TEMPLATE = """
if {condition} redis.call('zscore', {key}, {member}) then
@@ -313,8 +321,14 @@ def __init__(self, redis: Redis) -> None:

self._get_expired_tasks = redis.register_script(GET_EXPIRED_TASKS)

self._execute_pipeline = self.register_script_from_file(
"lua/execute_pipeline.lua"
self._move_task = self.register_script_from_file(
"lua/move_task.lua",
include_functions={
"zadd_noupdate": ZADD_NOUPDATE,
"zadd_update_min": ZADD_UPDATE_MIN,
"srem_if_not_exists": SREM_IF_NOT_EXISTS,
"delete_if_not_in_zsets": DELETE_IF_NOT_IN_ZSETS,
},
)

@property
@@ -330,11 +344,25 @@ def can_replicate_commands(self) -> bool:
self._can_replicate_commands = result
return self._can_replicate_commands

def register_script_from_file(self, filename: str) -> Script:
def register_script_from_file(
self, filename: str, include_functions: Optional[dict] = None
) -> Script:
with open(
os.path.join(os.path.dirname(os.path.realpath(__file__)), filename)
) as f:
return self.redis.register_script(f.read())
script = f.read()
if include_functions:
function_definitions = []
for func_name in sorted(include_functions.keys()):
function_definitions.append(
LOCAL_FUNC_TEMPLATE.format(
func_name=func_name,
func_body=include_functions[func_name],
)
)
script = "\n".join(function_definitions + [script])

return self.redis.register_script(script)

def zadd(
self,
@@ -538,77 +566,62 @@ def get_expired_tasks(
# [queue1, task1, queue2, task2] -> [(queue1, task1), (queue2, task2)]
return list(zip(result[::2], result[1::2]))

def execute_pipeline(
self, pipeline: Pipeline, client: Optional[Redis] = None
) -> List[Any]:
def move_task(
self,
id: str,
queue: str,
from_state: str,
to_state: Optional[str],
unique: bool,
when: float,
mode: Optional[str],
key_func: Callable[..., str],
publish_queued_tasks: bool,
client: Optional[Redis] = None,
) -> Any:
"""
Executes the given Redis pipeline as a Lua script. When an error
occurs, the transaction stops executing, and an exception is raised.
This differs from Redis transactions, where execution continues after an
error. On success, a list of results is returned. The pipeline is
cleared after execution and can no longer be reused.
Example:
p = conn.pipeline()
p.lrange('x', 0, -1)
p.set('success', 1)
# If "x" is empty or a list, an array [[...], True] is returned.
# Otherwise, ResponseError is raised and "success" is not set.
results = redis_scripts.execute_pipeline(p)
Refer to the task._move internal helper documentation.
"""

client = client or self.redis

executing_pipeline = None
try:

# Prepare args
stack = pipeline.command_stack
script_args = [int(self.can_replicate_commands), len(stack)]
for args, options in stack:
script_args += [len(args) - 1] + list(args)

# Run the pipeline
if self.can_replicate_commands: # Redis 3.2 or higher
# Make sure scripts exist
if pipeline.scripts:
pipeline.load_scripts()

raw_results = self._execute_pipeline(
args=script_args, client=client
)
else:
executing_pipeline = client.pipeline()

# Always load scripts to avoid issues when Redis loads data
# from AOF file / when replicating.
for s in pipeline.scripts:
executing_pipeline.script_load(s.script)

# Run actual pipeline lua script
self._execute_pipeline(
args=script_args, client=executing_pipeline
)

# Always load all scripts and run actual pipeline lua script
raw_results = executing_pipeline.execute()[-1]

# Run response callbacks on results.
results = []
response_callbacks = pipeline.response_callbacks
for ((args, options), result) in zip(stack, raw_results):
command_name = args[0]
if command_name in response_callbacks:
result = response_callbacks[command_name](
result, **options
)
results.append(result)

return results

finally:
if executing_pipeline:
executing_pipeline.reset()
pipeline.reset()
def _bool_to_str(v: bool) -> str:
return "true" if v else "false"

def _none_to_empty_str(v: Optional[str]) -> str:
return v or ""

key_task_id = key_func("task", id)
key_task_id_executions = key_func("task", id, "executions")
key_task_id_executions_count = key_func("task", id, "executions_count")
key_from_state = key_func(from_state)
key_to_state = key_func(to_state) if to_state else ""
key_active_queue = key_func(ACTIVE, queue)
key_queued_queue = key_func(QUEUED, queue)
key_error_queue = key_func(ERROR, queue)
key_scheduled_queue = key_func(SCHEDULED, queue)
key_activity = key_func("activity")

return self._move_task(
keys=[
key_task_id,
key_task_id_executions,
key_task_id_executions_count,
key_from_state,
key_to_state,
key_active_queue,
key_queued_queue,
key_error_queue,
key_scheduled_queue,
key_activity,
],
args=[
id,
queue,
from_state,
_none_to_empty_str(to_state),
_bool_to_str(unique),
when,
_none_to_empty_str(mode),
_bool_to_str(publish_queued_tasks),
],
client=client,
)
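
For reference, a hedged sketch of how a caller such as task._move might drive the new wrapper (assuming the enclosing class is RedisScripts); the key_func below is a stand-in for TaskTiger's configurable key builder, and the "t" prefix is an assumption rather than the real configured value:

from redis import Redis

from tasktiger.redis_scripts import RedisScripts

conn = Redis()
scripts = RedisScripts(conn)

def key_func(*parts: str) -> str:
    # Stand-in for TaskTiger's namespaced key builder; the "t" prefix
    # is assumed here, not read from the real configuration.
    return ":".join(("t",) + parts)

# Move a hypothetical task from SCHEDULED to QUEUED. The script asserts
# that the task is currently in the scheduled zset and raises otherwise;
# mode="min" keeps the earliest score if the task is already queued.
scripts.move_task(
    id="0123456789abcdef",  # hypothetical task ID
    queue="default",
    from_state="scheduled",
    to_state="queued",
    unique=False,
    when=1686900000.0,
    mode="min",
    key_func=key_func,
    publish_queued_tasks=True,
)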