-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added pg_listener source plugin
This source plugin listens to notifications from the pg_pub_sub Added blob_size to generic source plugin added final_payload in generic systemd-python; conditional for MacOS
- Loading branch information
Showing
4 changed files
with
286 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
"""pg_listener.py. | ||
An ansible-rulebook event source plugin for reading events from | ||
pg_pub_sub | ||
Arguments: | ||
--------- | ||
dsn: The connection string/dsn for Postgres | ||
channel: The name of the channel to listen to | ||
Example: | ||
------- | ||
- ansible.eda.pg_listener: | ||
dsn: "host=localhost port=5432 dbname=mydb" | ||
channel: events | ||
""" | ||
|
||
import asyncio | ||
import json | ||
import logging | ||
from typing import Any | ||
|
||
import xxhash | ||
from psycopg import AsyncConnection, OperationalError | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
MESSAGE_CHUNKED_UUID = "_message_chunked_uuid" | ||
MESSAGE_CHUNK_COUNT = "_message_chunk_count" | ||
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence" | ||
MESSAGE_CHUNK = "_chunk" | ||
MESSAGE_LENGTH = "_message_length" | ||
MESSAGE_XX_HASH = "_message_xx_hash" | ||
REQUIRED_KEYS = ("dsn", "channel") | ||
|
||
|
||
async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: | ||
"""Listen for events from a channel""" | ||
|
||
for key in REQUIRED_KEYS: | ||
if key not in args: | ||
raise ValueError(f"PG Listener {key} is a required argument") | ||
|
||
try: | ||
async with await AsyncConnection.connect( | ||
conninfo=args["dsn"], | ||
autocommit=True, | ||
) as conn: | ||
chunked_cache = {} | ||
async with await conn.cursor() as cursor: | ||
await cursor.execute(f"LISTEN {args['channel']};") | ||
LOGGER.debug("Waiting for notifications on channel %s", args["channel"]) | ||
async for event in conn.notifies(): | ||
try: | ||
data = json.loads(event.payload) | ||
if MESSAGE_CHUNKED_UUID in data: | ||
message_uuid = data[MESSAGE_CHUNKED_UUID] | ||
number_of_chunks = data[MESSAGE_CHUNK_COUNT] | ||
message_length = data[MESSAGE_LENGTH] | ||
LOGGER.debug( | ||
"Received chunked message %s" | ||
"total chunks %d" | ||
"message length %d", | ||
message_uuid, | ||
number_of_chunks, | ||
message_length, | ||
) | ||
if message_uuid in chunked_cache: | ||
chunked_cache[message_uuid].append(data) | ||
else: | ||
chunked_cache[message_uuid] = [data] | ||
if ( | ||
len(chunked_cache[message_uuid]) | ||
== chunked_cache[message_uuid][0][MESSAGE_CHUNK_COUNT] | ||
): | ||
LOGGER.debug( | ||
"Received all chunks for message %s", message_uuid | ||
) | ||
all_data = "" | ||
for chunk in chunked_cache[message_uuid]: | ||
all_data += chunk[MESSAGE_CHUNK] | ||
chunks = chunked_cache.pop(message_uuid) | ||
xx_hash = xxhash.xxh32( | ||
all_data.encode("utf-8") | ||
).hexdigest() | ||
LOGGER.debug("Computed XX Hash is %s", xx_hash) | ||
LOGGER.debug( | ||
"XX Hash expected %s", chunks[0][MESSAGE_XX_HASH] | ||
) | ||
if xx_hash == chunks[0][MESSAGE_XX_HASH]: | ||
data = json.loads(all_data) | ||
await queue.put(data) | ||
else: | ||
LOGGER.error( | ||
"XX Hash of chunked payload doesn't match" | ||
) | ||
else: | ||
LOGGER.debug( | ||
"Received %d chunks for message %s", | ||
len(chunked_cache[message_uuid]), | ||
message_uuid, | ||
) | ||
else: | ||
await queue.put(data) | ||
except json.decoder.JSONDecodeError: | ||
LOGGER.error("Error decoding data, ignoring it") | ||
except OperationalError as e: | ||
LOGGER.error("PG Listen operational error %s", str(e)) | ||
raise | ||
|
||
|
||
if __name__ == "__main__": | ||
# MockQueue if running directly | ||
|
||
class MockQueue: | ||
"""A fake queue.""" | ||
|
||
async def put(self: "MockQueue", event: dict) -> None: | ||
"""Print the event.""" | ||
print(event) # noqa: T201 | ||
|
||
asyncio.run( | ||
main( | ||
MockQueue(), | ||
{ | ||
"dsn": "host=localhost port=5432 dbname=eda " | ||
"user=postgres password=secret", | ||
"channel": "my_channel", | ||
}, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import asyncio | ||
import json | ||
import uuid | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
import xxhash | ||
|
||
from extensions.eda.plugins.event_source.pg_listener import ( | ||
MESSAGE_CHUNK, | ||
MESSAGE_CHUNK_COUNT, | ||
MESSAGE_CHUNK_SEQUENCE, | ||
MESSAGE_CHUNKED_UUID, | ||
MESSAGE_LENGTH, | ||
MESSAGE_XX_HASH, | ||
) | ||
from extensions.eda.plugins.event_source.pg_listener import main as pg_listener_main | ||
|
||
MAX_LENGTH = 7 * 1024 | ||
|
||
|
||
class MockQueue: | ||
def __init__(self): | ||
self.queue = [] | ||
|
||
async def put(self, event): | ||
self.queue.append(event) | ||
|
||
|
||
@pytest.fixture | ||
def myqueue(): | ||
return MockQueue() | ||
|
||
|
||
class AsyncIterator: | ||
def __init__(self, data): | ||
self.count = 0 | ||
self.data = data | ||
|
||
def __aiter__(self): | ||
return AsyncIterator(self.data) | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __anext__(self): | ||
if self.count >= len(self.data): | ||
raise StopAsyncIteration | ||
|
||
mock = MagicMock() | ||
mock.payload = self.data[self.count] | ||
self.count += 1 | ||
return mock | ||
|
||
|
||
def to_chunks(payload: str, result: list[str]): | ||
message_length = len(payload) | ||
if message_length >= MAX_LENGTH: | ||
xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest() | ||
message_uuid = str(uuid.uuid4()) | ||
number_of_chunks = int(message_length / MAX_LENGTH) + 1 | ||
chunked = { | ||
MESSAGE_CHUNKED_UUID: message_uuid, | ||
MESSAGE_CHUNK_COUNT: number_of_chunks, | ||
MESSAGE_LENGTH: message_length, | ||
MESSAGE_XX_HASH: xx_hash, | ||
} | ||
sequence = 1 | ||
for i in range(0, message_length, MAX_LENGTH): | ||
chunked[MESSAGE_CHUNK] = payload[i : i + MAX_LENGTH] | ||
chunked[MESSAGE_CHUNK_SEQUENCE] = sequence | ||
sequence += 1 | ||
result.append(json.dumps(chunked)) | ||
else: | ||
result.append(payload) | ||
|
||
|
||
class AsyncContextManager: | ||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, _exc_type, _exc_val, _exc_tb): | ||
pass | ||
|
||
|
||
TEST_PAYLOADS = [ | ||
[{"a": 1, "b": 2}, {"name": "Fred", "kids": ["Pebbles"]}], | ||
[{"blob": "x" * 9000, "huge": "h" * 9000}], | ||
[{"a": 1, "x": 2}, {"x": "y" * 20000, "fail": False, "pi": 3.14159}], | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("events", TEST_PAYLOADS) | ||
def test_receive_from_pg_listener(myqueue, events): | ||
notify_payload = [] | ||
for event in events: | ||
to_chunks(json.dumps(event), notify_payload) | ||
|
||
def my_iterator(): | ||
return AsyncIterator(notify_payload) | ||
|
||
with patch( | ||
"extensions.eda.plugins.event_source.pg_listener.AsyncConnection.connect", | ||
return_value=MagicMock(AsyncContextManager()), | ||
) as conn: | ||
conn.return_value.__aenter__.return_value.cursor = AsyncMock() | ||
conn.return_value.__aenter__.return_value.notifies = my_iterator | ||
asyncio.run( | ||
pg_listener_main( | ||
myqueue, | ||
{ | ||
"dsn": "host=localhost dbname=mydb user=postgres password=password", | ||
"channel": "test", | ||
}, | ||
) | ||
) | ||
|
||
assert len(myqueue.queue) == len(events) | ||
index = 0 | ||
for event in events: | ||
assert myqueue.queue[index] == event | ||
index += 1 |