feat: added pg_listener source plugin
This source plugin listens for notifications from pg_pub_sub.
Added blob_size to the generic source plugin.
Added final_payload to the generic source plugin.
Made systemd-python conditional so it is not installed on macOS.
mkanoor committed Nov 28, 2023
1 parent 2d8817d commit 649fcad
Showing 4 changed files with 281 additions and 3 deletions.
28 changes: 26 additions & 2 deletions extensions/eda/plugins/event_source/generic.py
@@ -24,6 +24,14 @@
should be repeated. Default 0
repeat_count int Number of times each individual event in the payload
should be repeated. Default 1
blob_size int An arbitrary blob of blob_size bytes to be inserted
into every event payload. Default is 0, meaning no blob
is created
final_payload dict After all the events have been sent, this optional
final payload is sent; it can be used to trigger a shutdown
of the rulebook, especially when rulebooks are used to
forward messages to other running rulebooks.
"""

@@ -45,7 +53,7 @@
import random
import time
from datetime import datetime
-from typing import Any
+from typing import Any, Optional


async def main( # pylint: disable=R0914
@@ -65,6 +73,8 @@ async def main( # pylint: disable=R0914
repeat_delay = float(args.get("repeat_delay", 0))
loop_delay = float(args.get("loop_delay", 0))
shutdown_after = float(args.get("shutdown_after", 0))
blob_size = int(args.get("blob_size", 0))
blob = "x" * blob_size if blob_size > 0 else None

loop_count = int(args.get("loop_count", 1)) # -1 infinite
repeat_count = int(args.get("repeat_count", 1))
@@ -89,7 +99,9 @@ async def main( # pylint: disable=R0914
if not event:
continue
for _ignore in range(repeat_count):
-data = _create_data(create_index, index, add_timestamp, time_format)
+data = _create_data(
+    create_index, index, add_timestamp, time_format, blob,
+)

index += 1
data.update(event)
@@ -100,6 +112,15 @@ async def main( # pylint: disable=R0914

await asyncio.sleep(event_delay)
iteration += 1

if "final_payload" in args:
data = _create_data(create_index, index, add_timestamp, time_format, blob)

data.update(args["final_payload"])
if display:
print(data) # noqa: T201
await queue.put(data)

await asyncio.sleep(shutdown_after)


@@ -108,10 +129,13 @@ def _create_data(
index: int,
add_timestamp: str,
time_format: str,
blob: Optional[str],
) -> dict:
data = {}
if create_index:
data[create_index] = index
if blob:
data["blob"] = blob
if add_timestamp:
if time_format == "local":
data["timestamp"] = str(datetime.now()) # noqa: DTZ005
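The two new arguments can be exercised outside a rulebook by driving the plugin's main() coroutine with a stand-in queue, much like the __main__ harness in pg_listener.py below. A minimal sketch, assuming the collection checkout is on PYTHONPATH and that the events themselves come in through the plugin's usual payload argument (not part of these hunks); PrintQueue and the shutdown_now key are illustrative:

import asyncio

from extensions.eda.plugins.event_source.generic import main as generic_main


class PrintQueue:
    """A stand-in queue that prints every event it receives."""

    async def put(self, event: dict) -> None:
        print(event)


asyncio.run(
    generic_main(
        PrintQueue(),
        {
            "payload": [{"i1": "s1"}, {"i2": "s2"}],
            "blob_size": 16,  # each event gains data["blob"] == "x" * 16
            "final_payload": {"shutdown_now": True},  # sent once, after all events
        },
    )
)

Because final_payload flows through the same queue as ordinary events, a rulebook condition can match on it (here the hypothetical shutdown_now key) to trigger a shutdown action.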
135 changes: 135 additions & 0 deletions extensions/eda/plugins/event_source/pg_listener.py
@@ -0,0 +1,135 @@
"""pg_listener.py.
An ansible-rulebook event source plugin for reading events from
pg_pub_sub
Arguments:
---------
dsn: The connection string/dsn for Postgres
channel: The name of the channel to listen to
Example:
-------
- ansible.eda.pg_listener:
    dsn: "host=localhost port=5432 dbname=mydb"
    channel: events
"""

import asyncio
import json
import logging
from typing import Any, Self

import xxhash
from psycopg import AsyncConnection, OperationalError

LOGGER = logging.getLogger(__name__)

MESSAGE_CHUNKED_UUID = "_message_chunked_uuid"
MESSAGE_CHUNK_COUNT = "_message_chunk_count"
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence"
MESSAGE_CHUNK = "_chunk"
MESSAGE_LENGTH = "_message_length"
MESSAGE_XX_HASH = "_message_xx_hash"
REQUIRED_KEYS = ("dsn", "channel")


class MissingRequiredArgumentError(Exception):
"""Exception class for missing arguments."""

def __init__(self: Self, key: str) -> None:
"""Class constructor with the missing key."""
super().__init__(f"PG Listener {key} is a required argument")


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
"""Listen for events from a channel."""
for key in REQUIRED_KEYS:
if key not in args:
raise MissingRequiredArgumentError(key)

try:
async with await AsyncConnection.connect(
conninfo=args["dsn"],
autocommit=True,
) as conn:
chunked_cache = {}
cursor = conn.cursor()
await cursor.execute(f"LISTEN {args['channel']};")
LOGGER.debug("Waiting for notifications on channel %s", args["channel"])
async for event in conn.notifies():
data = json.loads(event.payload)
if MESSAGE_CHUNKED_UUID in data:
message_uuid = data[MESSAGE_CHUNKED_UUID]
number_of_chunks = data[MESSAGE_CHUNK_COUNT]
message_length = data[MESSAGE_LENGTH]
LOGGER.debug(
"Received chunked message %s"
"total chunks %d"
"message length %d",
message_uuid,
number_of_chunks,
message_length,
)
if message_uuid in chunked_cache:
chunked_cache[message_uuid].append(data)
else:
chunked_cache[message_uuid] = [data]
if (
len(chunked_cache[message_uuid])
== chunked_cache[message_uuid][0][MESSAGE_CHUNK_COUNT]
):
LOGGER.debug(
"Received all chunks for message %s",
message_uuid,
)
all_data = ""
for chunk in chunked_cache[message_uuid]:
all_data += chunk[MESSAGE_CHUNK]
chunks = chunked_cache.pop(message_uuid)
xx_hash = xxhash.xxh32(all_data.encode("utf-8")).hexdigest()
LOGGER.debug("Computed XX Hash is %s", xx_hash)
LOGGER.debug(
"XX Hash expected %s",
chunks[0][MESSAGE_XX_HASH],
)
if xx_hash == chunks[0][MESSAGE_XX_HASH]:
data = json.loads(all_data)
await queue.put(data)
else:
LOGGER.error("XX Hash of chunked payload doesn't match")
else:
LOGGER.debug(
"Received %d chunks for message %s",
len(chunked_cache[message_uuid]),
message_uuid,
)
else:
await queue.put(data)
except json.decoder.JSONDecodeError:
LOGGER.exception("Error decoding data, ignoring it")
except OperationalError:
LOGGER.exception("PG Listen operational error")


if __name__ == "__main__":
# MockQueue if running directly

class MockQueue:
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201

asyncio.run(
main(
MockQueue(),
{
"dsn": "host=localhost port=5432 dbname=eda "
"user=postgres password=secret",
"channel": "my_channel",
},
),
)
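This commit adds only the listening side; the MESSAGE_* keys above imply the shape of a compatible publisher. A hedged sketch of one, reusing the plugin's own constants and the same 7 KB chunk threshold as the unit tests below (Postgres caps a NOTIFY payload at roughly 8000 bytes); the notify() helper itself is illustrative, not part of the commit:

import json
import uuid

import xxhash
from psycopg import Connection

from extensions.eda.plugins.event_source.pg_listener import (
    MESSAGE_CHUNK,
    MESSAGE_CHUNK_COUNT,
    MESSAGE_CHUNK_SEQUENCE,
    MESSAGE_CHUNKED_UUID,
    MESSAGE_LENGTH,
    MESSAGE_XX_HASH,
)

MAX_LENGTH = 7 * 1024  # chunk threshold, mirroring the unit tests


def notify(conn: Connection, channel: str, event: dict) -> None:
    """Publish an event, chunking it when it exceeds MAX_LENGTH."""
    payload = json.dumps(event)
    with conn.cursor() as cursor:
        if len(payload) < MAX_LENGTH:
            # Small messages go out as a single notification.
            cursor.execute("SELECT pg_notify(%s, %s);", (channel, payload))
            return
        message_uuid = str(uuid.uuid4())
        xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest()
        chunks = [
            payload[i : i + MAX_LENGTH]
            for i in range(0, len(payload), MAX_LENGTH)
        ]
        for sequence, chunk in enumerate(chunks, start=1):
            envelope = {
                MESSAGE_CHUNKED_UUID: message_uuid,
                MESSAGE_CHUNK_COUNT: len(chunks),
                MESSAGE_CHUNK_SEQUENCE: sequence,
                MESSAGE_CHUNK: chunk,
                MESSAGE_LENGTH: len(payload),
                MESSAGE_XX_HASH: xx_hash,
            }
            cursor.execute(
                "SELECT pg_notify(%s, %s);", (channel, json.dumps(envelope))
            )

The listener reassembles chunks in arrival order and verifies the xxh32 digest before putting the decoded event on the queue, so a publisher only needs to keep the envelope keys and hash algorithm in sync.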
4 changes: 3 additions & 1 deletion requirements.txt
@@ -3,6 +3,8 @@ aiobotocore
aiohttp
aiokafka
watchdog
-systemd-python
+systemd-python; sys_platform != 'darwin'
dpath
pyyaml
psycopg
xxhash
117 changes: 117 additions & 0 deletions tests/unit/event_source/test_pg_listener.py
@@ -0,0 +1,117 @@
import asyncio
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import xxhash

from extensions.eda.plugins.event_source.pg_listener import (
MESSAGE_CHUNK,
MESSAGE_CHUNK_COUNT,
MESSAGE_CHUNK_SEQUENCE,
MESSAGE_CHUNKED_UUID,
MESSAGE_LENGTH,
MESSAGE_XX_HASH,
)
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main

MAX_LENGTH = 7 * 1024


class MockQueue:
def __init__(self):
self.queue = []

async def put(self, event):
self.queue.append(event)


@pytest.fixture
def myqueue():
return MockQueue()


class AsyncIterator:
def __init__(self, data):
self.count = 0
self.data = data

def __aiter__(self):
return AsyncIterator(self.data)

async def __aenter__(self):
return self

async def __anext__(self):
if self.count >= len(self.data):
raise StopAsyncIteration

mock = MagicMock()
mock.payload = self.data[self.count]
self.count += 1
return mock


def to_chunks(payload: str, result: list[str]):
message_length = len(payload)
if message_length >= MAX_LENGTH:
xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest()
message_uuid = str(uuid.uuid4())
number_of_chunks = int(message_length / MAX_LENGTH) + 1
chunked = {
MESSAGE_CHUNKED_UUID: message_uuid,
MESSAGE_CHUNK_COUNT: number_of_chunks,
MESSAGE_LENGTH: message_length,
MESSAGE_XX_HASH: xx_hash,
}
sequence = 1
for i in range(0, message_length, MAX_LENGTH):
chunked[MESSAGE_CHUNK] = payload[i : i + MAX_LENGTH]
chunked[MESSAGE_CHUNK_SEQUENCE] = sequence
sequence += 1
result.append(json.dumps(chunked))
else:
result.append(payload)


TEST_PAYLOADS = [
[{"a": 1, "b": 2}, {"name": "Fred", "kids": ["Pebbles"]}],
[{"blob": "x" * 9000, "huge": "h" * 9000}],
[{"a": 1, "x": 2}, {"x": "y" * 20000, "fail": False, "pi": 3.14159}],
]


@pytest.mark.parametrize("events", TEST_PAYLOADS)
def test_receive_from_pg_listener(myqueue, events):
notify_payload = []
for event in events:
to_chunks(json.dumps(event), notify_payload)

def my_iterator():
return AsyncIterator(notify_payload)

with patch(
"extensions.eda.plugins.event_source.pg_listener.AsyncConnection.connect"
) as conn:
mock_object = AsyncMock()
conn.return_value = mock_object
conn.return_value.__aenter__.return_value = mock_object
mock_object.cursor = AsyncMock
mock_object.notifies = my_iterator

asyncio.run(
pg_listener_main(
myqueue,
{
"dsn": "host=localhost dbname=mydb user=postgres password=password",
"channel": "test",
},
)
)

assert len(myqueue.queue) == len(events)
index = 0
for event in events:
assert myqueue.queue[index] == event
index += 1
