feat: added pg_listener source plugin
This source plugin listens for notifications from pg_pub_sub.
mkanoor committed Feb 6, 2024
1 parent 0967bac commit c638082
Showing 4 changed files with 309 additions and 0 deletions.
188 changes: 188 additions & 0 deletions extensions/eda/plugins/event_source/pg_listener.py
@@ -0,0 +1,188 @@
"""pg_listener.py.
An ansible-rulebook event source plugin for reading events from
pg_pub_sub.
Arguments:
---------
dsn: The connection string/dsn for Postgres
channels: The list of channels to listen on
Example:
-------
- ansible.eda.pg_listener:
dsn: "host=localhost port=5432 dbname=mydb"
channels:
- my_events
- my_alerts
Chunking:
---------
This is informational; a user doesn't have to do anything
special to enable chunking. The sender, which is the pg_notify
action from ansible-rulebook, decides whether chunking is
needed based on the size of the payload.
If a message is over 7KB the sender splits it into separate
payloads, each payload carrying the following keys:
* _message_chunked_uuid: The unique message uuid
* _message_chunk_count: The number of chunks for the message
* _message_chunk_sequence: The sequence number of the current chunk
* _chunk: The actual chunk data
* _message_length: The total length of the message
* _message_xx_hash: A hash of the entire message
The pg_listener source assembles the chunks, and once all the
chunks have been received it delivers the entire payload to the
rulebook engine. Before the payload is delivered, its hash is
computed and compared against the received hash to validate that
the entire message arrived intact.
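For illustration, a single chunk payload might look like this
(all values are hypothetical):
    {
        "_message_chunked_uuid": "8e84a57c-...",
        "_message_chunk_count": 3,
        "_message_chunk_sequence": 2,
        "_chunk": "...up to 7KB of the serialized event...",
        "_message_length": 18231,
        "_message_xx_hash": "c2a0b5f4"
    }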
"""

import asyncio
import json
import logging
from typing import Any

import xxhash
from psycopg import AsyncConnection, OperationalError

LOGGER = logging.getLogger(__name__)

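# Keys used by the chunking protocol; they must match the keys emitted
# by the sender (the pg_notify action in ansible-rulebook).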
MESSAGE_CHUNKED_UUID = "_message_chunked_uuid"
MESSAGE_CHUNK_COUNT = "_message_chunk_count"
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence"
MESSAGE_CHUNK = "_chunk"
MESSAGE_LENGTH = "_message_length"
MESSAGE_XX_HASH = "_message_xx_hash"
REQUIRED_KEYS = ("dsn", "channels")

REQUIRED_CHUNK_KEYS = (
MESSAGE_CHUNK_COUNT,
MESSAGE_CHUNK_SEQUENCE,
MESSAGE_CHUNK,
MESSAGE_LENGTH,
MESSAGE_XX_HASH,
)


class MissingRequiredArgumentError(Exception):
"""Exception class for missing arguments."""

def __init__(self: "MissingRequiredArgumentError", key: str) -> None:
"""Class constructor with the missing key."""
super().__init__(f"PG Listener {key} is a required argument")


class MissingChunkKeyError(Exception):
"""Exception class for missing chunking keys."""

def __init__(self: "MissingChunkKeyError", key: str) -> None:
"""Class constructor with the missing key."""
super().__init__(f"Chunked payload is missing required {key}")


def _validate_chunked_payload(payload: dict) -> None:
for key in REQUIRED_CHUNK_KEYS:
if key not in payload:
raise MissingChunkKeyError(key)


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
"""Listen for events from a channel."""
for key in REQUIRED_KEYS:
if key not in args:
raise MissingRequiredArgumentError(key)

try:
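# autocommit is required so that each LISTEN takes effect immediately
# instead of waiting for an explicit COMMIT.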
async with await AsyncConnection.connect(
conninfo=args["dsn"],
autocommit=True,
) as conn:
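# Cache of partially received chunked messages, keyed by message uuid.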
chunked_cache = {}
cursor = conn.cursor()
for channel in args["channels"]:
await cursor.execute(f"LISTEN {channel};")
LOGGER.debug("Waiting for notifications on channel %s", channel)
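# notifies() yields notifications from all channels this connection listens on.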
async for event in conn.notifies():
data = json.loads(event.payload)
if MESSAGE_CHUNKED_UUID in data:
_validate_chunked_payload(data)
await _handle_chunked_message(data, chunked_cache, queue)
else:
await queue.put(data)
except json.decoder.JSONDecodeError:
LOGGER.exception("Error decoding data, ignoring it")
except OperationalError:
LOGGER.exception("PG Listen operational error")


async def _handle_chunked_message(
data: dict,
chunked_cache: dict,
queue: asyncio.Queue,
) -> None:
message_uuid = data[MESSAGE_CHUNKED_UUID]
number_of_chunks = data[MESSAGE_CHUNK_COUNT]
message_length = data[MESSAGE_LENGTH]
LOGGER.debug(
"Received chunked message %s total chunks %d message length %d",
message_uuid,
number_of_chunks,
message_length,
)
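# Accumulate chunks for this message until all of them have arrived.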
if message_uuid in chunked_cache:
chunked_cache[message_uuid].append(data)
else:
chunked_cache[message_uuid] = [data]
if (
len(chunked_cache[message_uuid])
== chunked_cache[message_uuid][0][MESSAGE_CHUNK_COUNT]
):
LOGGER.debug(
"Received all chunks for message %s",
message_uuid,
)
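# All chunks received: reassemble the payload and verify its hash.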
chunks = chunked_cache.pop(message_uuid)
all_data = "".join(chunk[MESSAGE_CHUNK] for chunk in chunks)
xx_hash = xxhash.xxh32(all_data.encode("utf-8")).hexdigest()
LOGGER.debug("Computed XX Hash is %s", xx_hash)
LOGGER.debug(
"XX Hash expected %s",
chunks[0][MESSAGE_XX_HASH],
)
if xx_hash == chunks[0][MESSAGE_XX_HASH]:
data = json.loads(all_data)
await queue.put(data)
else:
LOGGER.error("XX Hash of chunked payload doesn't match")
else:
LOGGER.debug(
"Received %d chunks for message %s",
len(chunked_cache[message_uuid]),
message_uuid,
)


if __name__ == "__main__":
# Use a mock queue when running this plugin directly.

class MockQueue:
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
"""Print the event."""
print(event) # noqa: T201

asyncio.run(
main(
MockQueue(),
{
"dsn": "host=localhost port=5432 dbname=eda "
"user=postgres password=secret",
"channels": ["my_channel"],
},
),
)
2 changes: 2 additions & 0 deletions requirements.txt
@@ -7,3 +7,5 @@ kafka-python
pyyaml
systemd-python; sys_platform != 'darwin'
watchdog
psycopg
xxhash
117 changes: 117 additions & 0 deletions tests/unit/event_source/test_pg_listener.py
@@ -0,0 +1,117 @@
""" Tests for pg_listener source plugin """

import asyncio
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import xxhash

from extensions.eda.plugins.event_source.pg_listener import (
MESSAGE_CHUNK,
MESSAGE_CHUNK_COUNT,
MESSAGE_CHUNK_SEQUENCE,
MESSAGE_CHUNKED_UUID,
MESSAGE_LENGTH,
MESSAGE_XX_HASH,
)
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main

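# Mirrors the 7KB chunking threshold used by the pg_notify sender.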
MAX_LENGTH = 7 * 1024


class _MockQueue:
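"""A queue that records events for test assertions."""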
def __init__(self):
self.queue = []

async def put(self, event):
"""Put an event into the queue"""
self.queue.append(event)


class _AsyncIterator:
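"""Stands in for AsyncConnection.notifies(), yielding mock notifications."""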
def __init__(self, data):
self.count = 0
self.data = data

def __aiter__(self):
return _AsyncIterator(self.data)

async def __aenter__(self):
return self

async def __anext__(self):
if self.count >= len(self.data):
raise StopAsyncIteration

mock = MagicMock()
mock.payload = self.data[self.count]
self.count += 1
return mock


def _to_chunks(payload: str, result: list[str]):
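"""Split a payload into notify messages, mimicking the pg_notify sender's chunking."""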
message_length = len(payload)
if message_length >= MAX_LENGTH:
xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest()
message_uuid = str(uuid.uuid4())
number_of_chunks = (message_length + MAX_LENGTH - 1) // MAX_LENGTH  # ceiling division; int(n / MAX_LENGTH) + 1 overcounts exact multiples
chunked = {
MESSAGE_CHUNKED_UUID: message_uuid,
MESSAGE_CHUNK_COUNT: number_of_chunks,
MESSAGE_LENGTH: message_length,
MESSAGE_XX_HASH: xx_hash,
}
sequence = 1
for i in range(0, message_length, MAX_LENGTH):
chunked[MESSAGE_CHUNK] = payload[i : i + MAX_LENGTH]
chunked[MESSAGE_CHUNK_SEQUENCE] = sequence
sequence += 1
result.append(json.dumps(chunked))
else:
result.append(payload)


TEST_PAYLOADS = [
[{"a": 1, "b": 2}, {"name": "Fred", "kids": ["Pebbles"]}],
[{"blob": "x" * 9000, "huge": "h" * 9000}],
[{"a": 1, "x": 2}, {"x": "y" * 20000, "fail": False, "pi": 3.14159}],
]


@pytest.mark.parametrize("events", TEST_PAYLOADS)
def test_receive_from_pg_listener(events):
"""Test receiving different payloads from pg notify."""
notify_payload = []
myqueue = _MockQueue()
for event in events:
_to_chunks(json.dumps(event), notify_payload)

def my_iterator():
return _AsyncIterator(notify_payload)

with patch(
"extensions.eda.plugins.event_source.pg_listener.AsyncConnection.connect"
) as conn:
mock_object = AsyncMock()
conn.return_value = mock_object
conn.return_value.__aenter__.return_value = mock_object
mock_object.cursor = AsyncMock
mock_object.notifies = my_iterator

asyncio.run(
pg_listener_main(
myqueue,
{
"dsn": "host=localhost dbname=mydb user=postgres password=password",
"channels": ["test"],
},
)
)

assert len(myqueue.queue) == len(events)
for index, event in enumerate(events):
assert myqueue.queue[index] == event
2 changes: 2 additions & 0 deletions tests/unit/requirements.txt
@@ -6,3 +6,5 @@ asyncmock
azure-servicebus
dpath
pytest-asyncio
psycopg
xxhash
