feat: added pg_listener source plugin #181

Merged (1 commit) on Feb 6, 2024
188 changes: 188 additions & 0 deletions extensions/eda/plugins/event_source/pg_listener.py
@@ -0,0 +1,188 @@
"""pg_listener.py.

An ansible-rulebook event source plugin for reading events from
Postgres pub/sub (LISTEN/NOTIFY)

Arguments:
---------
dsn: The connection string/dsn for Postgres
channels: The list of channels to listen on

Example:
-------
- ansible.eda.pg_listener:
dsn: "host=localhost port=5432 dbname=mydb"
channels:
- my_events
- my_alerts

Chunking:
---------
This section is informational only; a user doesn't have to do
anything special to enable chunking. The sender, which is the
pg_notify action from ansible-rulebook, decides whether chunking
is needed based on the size of the payload.
If a message is over 7KB the sender will chunk it into separate
payloads, each payload having the following keys:
* _message_chunked_uuid The unique message uuid
* _message_chunk_count The number of chunks for the message
* _message_chunk_sequence The sequence of the current chunk
* _chunk The actual chunk
* _message_length The total length of the message
* _message_xx_hash A hash for the entire message

The pg_listener source assembles the chunks and, once all chunks
have been received, delivers the entire payload to the rulebook
engine. Before the payload is delivered, the plugin validates that
the entire message was received by comparing its computed hash with
the hash sent by the sender.
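
A chunked notification payload might look like this (illustrative
values only):

{
    "_message_chunked_uuid": "6f0c2a9e-8f4b-4c1d-9e2a-3b5c7d9f1a2b",
    "_message_chunk_count": 3,
    "_message_chunk_sequence": 1,
    "_chunk": "{\"blob\": \"xxxx...",
    "_message_length": 18000,
    "_message_xx_hash": "2ee92ab9"
}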

"""

import asyncio
import json
import logging
from typing import Any

import xxhash
from psycopg import AsyncConnection, OperationalError

LOGGER = logging.getLogger(__name__)

MESSAGE_CHUNKED_UUID = "_message_chunked_uuid"
MESSAGE_CHUNK_COUNT = "_message_chunk_count"
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence"
MESSAGE_CHUNK = "_chunk"
MESSAGE_LENGTH = "_message_length"
MESSAGE_XX_HASH = "_message_xx_hash"
REQUIRED_KEYS = ("dsn", "channels")

REQUIRED_CHUNK_KEYS = (
MESSAGE_CHUNK_COUNT,
MESSAGE_CHUNK_SEQUENCE,
MESSAGE_CHUNK,
MESSAGE_LENGTH,
MESSAGE_XX_HASH,
)


class MissingRequiredArgumentError(Exception):
"""Exception class for missing arguments."""

def __init__(self: "MissingRequiredArgumentError", key: str) -> None:
"""Class constructor with the missing key."""
super().__init__(f"PG Listener {key} is a required argument")


class MissingChunkKeyError(Exception):
"""Exception class for missing chunking keys."""

def __init__(self: "MissingChunkKeyError", key: str) -> None:
"""Class constructor with the missing key."""
super().__init__(f"Chunked payload is missing required {key}")


def _validate_chunked_payload(payload: dict) -> None:
for key in REQUIRED_CHUNK_KEYS:
if key not in payload:
raise MissingChunkKeyError(key)


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
"""Listen for events from a channel."""
for key in REQUIRED_KEYS:
if key not in args:
raise MissingRequiredArgumentError(key)

try:
async with await AsyncConnection.connect(
conninfo=args["dsn"],
autocommit=True,
) as conn:
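        # autocommit is used so the LISTEN commands take effect
        # immediately instead of waiting for an explicit COMMIT.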
chunked_cache = {}
cursor = conn.cursor()
for channel in args["channels"]:
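            # Channel names are SQL identifiers, which cannot be bound as
            # query parameters, hence the f-string here.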
await cursor.execute(f"LISTEN {channel};")
LOGGER.debug("Waiting for notifications on channel %s", channel)
async for event in conn.notifies():
data = json.loads(event.payload)
if MESSAGE_CHUNKED_UUID in data:
_validate_chunked_payload(data)
await _handle_chunked_message(data, chunked_cache, queue)
else:
await queue.put(data)
except json.decoder.JSONDecodeError:
LOGGER.exception("Error decoding data, ignoring it")
except OperationalError:
LOGGER.exception("PG Listen operational error")


async def _handle_chunked_message(
data: dict,
chunked_cache: dict,
queue: asyncio.Queue,
) -> None:
message_uuid = data[MESSAGE_CHUNKED_UUID]
number_of_chunks = data[MESSAGE_CHUNK_COUNT]
message_length = data[MESSAGE_LENGTH]
LOGGER.debug(
"Received chunked message %s total chunks %d message length %d",
message_uuid,
number_of_chunks,
message_length,
)
if message_uuid in chunked_cache:
chunked_cache[message_uuid].append(data)
else:
chunked_cache[message_uuid] = [data]
if (
len(chunked_cache[message_uuid])
== chunked_cache[message_uuid][0][MESSAGE_CHUNK_COUNT]
):
LOGGER.debug(
"Received all chunks for message %s",
message_uuid,
)
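        # Concatenate chunks in arrival order; the sender emits them in
        # sequence order, and the hash check below catches any mismatch.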
all_data = ""
for chunk in chunked_cache[message_uuid]:
all_data += chunk[MESSAGE_CHUNK]
chunks = chunked_cache.pop(message_uuid)
xx_hash = xxhash.xxh32(all_data.encode("utf-8")).hexdigest()
LOGGER.debug("Computed XX Hash is %s", xx_hash)
LOGGER.debug(
"XX Hash expected %s",
chunks[0][MESSAGE_XX_HASH],
)
if xx_hash == chunks[0][MESSAGE_XX_HASH]:
data = json.loads(all_data)
await queue.put(data)
else:
LOGGER.error("XX Hash of chunked payload doesn't match")
else:
LOGGER.debug(
"Received %d chunks for message %s",
len(chunked_cache[message_uuid]),
message_uuid,
)


if __name__ == "__main__":
# MockQueue if running directly

class MockQueue:
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201

asyncio.run(
main(
MockQueue(),
{
"dsn": "host=localhost port=5432 dbname=eda "
"user=postgres password=secret",
"channels": ["my_channel"],
},
),
)
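
For a manual end-to-end check, a minimal sender sketch (hypothetical, not
part of this change; assumes psycopg 3 and the same local database as the
block above) could publish to the channel the plugin listens on:

# send_event.py - hypothetical sender, for illustration only.
import json

import psycopg

with psycopg.connect(
    "host=localhost port=5432 dbname=eda user=postgres password=secret",
    autocommit=True,
) as conn:
    event = json.dumps({"severity": "info", "message": "disk usage at 80%"})
    conn.execute("SELECT pg_notify(%s, %s);", ("my_channel", event))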
2 changes: 2 additions & 0 deletions requirements.txt
@@ -7,3 +7,5 @@ kafka-python
pyyaml
systemd-python; sys_platform != 'darwin'
watchdog
psycopg
xxhash
117 changes: 117 additions & 0 deletions tests/unit/event_source/test_pg_listener.py
@@ -0,0 +1,117 @@
""" Tests for pg_listener source plugin """

import asyncio
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import xxhash

from extensions.eda.plugins.event_source.pg_listener import (
MESSAGE_CHUNK,
MESSAGE_CHUNK_COUNT,
MESSAGE_CHUNK_SEQUENCE,
MESSAGE_CHUNKED_UUID,
MESSAGE_LENGTH,
MESSAGE_XX_HASH,
)
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main

MAX_LENGTH = 7 * 1024


class _MockQueue:
def __init__(self):
self.queue = []

async def put(self, event):
"""Put an event into the queue"""
self.queue.append(event)


class _AsyncIterator:
def __init__(self, data):
self.count = 0
self.data = data

def __aiter__(self):
return _AsyncIterator(self.data)

async def __aenter__(self):
return self

async def __anext__(self):
if self.count >= len(self.data):
raise StopAsyncIteration

mock = MagicMock()
mock.payload = self.data[self.count]
self.count += 1
return mock


def _to_chunks(payload: str, result: list[str]):
message_length = len(payload)
if message_length >= MAX_LENGTH:
xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest()
message_uuid = str(uuid.uuid4())
number_of_chunks = int(message_length / MAX_LENGTH) + 1
chunked = {
MESSAGE_CHUNKED_UUID: message_uuid,
MESSAGE_CHUNK_COUNT: number_of_chunks,
MESSAGE_LENGTH: message_length,
MESSAGE_XX_HASH: xx_hash,
}
sequence = 1
for i in range(0, message_length, MAX_LENGTH):
chunked[MESSAGE_CHUNK] = payload[i : i + MAX_LENGTH]
chunked[MESSAGE_CHUNK_SEQUENCE] = sequence
sequence += 1
result.append(json.dumps(chunked))
else:
result.append(payload)


TEST_PAYLOADS = [
[{"a": 1, "b": 2}, {"name": "Fred", "kids": ["Pebbles"]}],
[{"blob": "x" * 9000, "huge": "h" * 9000}],
[{"a": 1, "x": 2}, {"x": "y" * 20000, "fail": False, "pi": 3.14159}],
]


@pytest.mark.parametrize("events", TEST_PAYLOADS)
def test_receive_from_pg_listener(events):
"""Test receiving different payloads from pg notify."""
notify_payload = []
myqueue = _MockQueue()
for event in events:
_to_chunks(json.dumps(event), notify_payload)

def my_iterator():
return _AsyncIterator(notify_payload)

with patch(
"extensions.eda.plugins.event_source.pg_listener.AsyncConnection.connect"
) as conn:
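        # Stub the connection so main() runs without a real database:
        # cursor is replaced with the AsyncMock class so conn.cursor()
        # returns a mock synchronously, and notifies() returns the canned
        # async iterator of payloads.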
mock_object = AsyncMock()
conn.return_value = mock_object
conn.return_value.__aenter__.return_value = mock_object
mock_object.cursor = AsyncMock
mock_object.notifies = my_iterator

asyncio.run(
pg_listener_main(
myqueue,
{
"dsn": "host=localhost dbname=mydb user=postgres password=password",
"channels": ["test"],
},
)
)

    assert len(myqueue.queue) == len(events)
    for received, expected in zip(myqueue.queue, events):
        assert received == expected
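
These tests cover both the chunked and unchunked paths and can be run with,
for example: pytest tests/unit/event_source/test_pg_listener.py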
2 changes: 2 additions & 0 deletions tests/unit/requirements.txt
@@ -6,3 +6,5 @@ asyncmock
azure-servicebus
dpath
pytest-asyncio
psycopg
xxhash