feat: added pg_listener source plugin
This source plugin listens to notifications from pg_pub_sub. Also added blob_size to the generic source plugin, added final_payload to the generic source plugin, and made systemd-python conditional for macOS.
Showing 5 changed files with 519 additions and 48 deletions.
"""pg_listener.py. | ||
An ansible-rulebook event source plugin for reading events from | ||
pg_pub_sub | ||
Arguments: | ||
--------- | ||
dsn: The connection string/dsn for Postgres | ||
channel: The name of the channel to listen to | ||
Example: | ||
------- | ||
- ansible.eda.pg_listener: | ||
dsn: "host=localhost port=5432 dbname=mydb" | ||
channel: events | ||
""" | ||
|
||
import asyncio | ||
import json | ||
import logging | ||
from typing import Any | ||
|
||
import xxhash | ||
from psycopg import AsyncConnection, OperationalError | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
MESSAGE_CHUNKED_UUID = "_message_chunked_uuid" | ||
MESSAGE_CHUNK_COUNT = "_message_chunk_count" | ||
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence" | ||
MESSAGE_CHUNK = "_chunk" | ||
MESSAGE_LENGTH = "_message_length" | ||
MESSAGE_XX_HASH = "_message_xx_hash" | ||
REQUIRED_KEYS = ("dsn", "channel") | ||
|
||
|
||
class MissingRequiredArgumentError(Exception): | ||
"""Exception class for missing arguments.""" | ||
|
||
def __init__(self: "MissingRequiredArgumentError", key: str) -> None: | ||
"""Class constructor with the missing key.""" | ||
super().__init__(f"PG Listener {key} is a required argument") | ||
|
||
|
||
async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: | ||
"""Listen for events from a channel.""" | ||
for key in REQUIRED_KEYS: | ||
if key not in args: | ||
raise MissingRequiredArgumentError(key) | ||
|
||
try: | ||
async with await AsyncConnection.connect( | ||
conninfo=args["dsn"], | ||
autocommit=True, | ||
) as conn: | ||
chunked_cache = {} | ||
cursor = conn.cursor() | ||
await cursor.execute(f"LISTEN {args['channel']};") | ||
LOGGER.debug("Waiting for notifications on channel %s", args["channel"]) | ||
async for event in conn.notifies(): | ||
data = json.loads(event.payload) | ||
if MESSAGE_CHUNKED_UUID in data: | ||
await _handle_chunked_message(data, chunked_cache, queue) | ||
else: | ||
await queue.put(data) | ||
except json.decoder.JSONDecodeError: | ||
LOGGER.exception("Error decoding data, ignoring it") | ||
except OperationalError: | ||
LOGGER.exception("PG Listen operational error") | ||
|
||
|
||
async def _handle_chunked_message( | ||
data: dict, | ||
chunked_cache: dict, | ||
queue: asyncio.Queue, | ||
) -> None: | ||
message_uuid = data[MESSAGE_CHUNKED_UUID] | ||
number_of_chunks = data[MESSAGE_CHUNK_COUNT] | ||
message_length = data[MESSAGE_LENGTH] | ||
LOGGER.debug( | ||
"Received chunked message %s total chunks %d message length %d", | ||
message_uuid, | ||
number_of_chunks, | ||
message_length, | ||
) | ||
if message_uuid in chunked_cache: | ||
chunked_cache[message_uuid].append(data) | ||
else: | ||
chunked_cache[message_uuid] = [data] | ||
if ( | ||
len(chunked_cache[message_uuid]) | ||
== chunked_cache[message_uuid][0][MESSAGE_CHUNK_COUNT] | ||
): | ||
LOGGER.debug( | ||
"Received all chunks for message %s", | ||
message_uuid, | ||
) | ||
all_data = "" | ||
for chunk in chunked_cache[message_uuid]: | ||
all_data += chunk[MESSAGE_CHUNK] | ||
chunks = chunked_cache.pop(message_uuid) | ||
xx_hash = xxhash.xxh32(all_data.encode("utf-8")).hexdigest() | ||
LOGGER.debug("Computed XX Hash is %s", xx_hash) | ||
LOGGER.debug( | ||
"XX Hash expected %s", | ||
chunks[0][MESSAGE_XX_HASH], | ||
) | ||
if xx_hash == chunks[0][MESSAGE_XX_HASH]: | ||
data = json.loads(all_data) | ||
await queue.put(data) | ||
else: | ||
LOGGER.error("XX Hash of chunked payload doesn't match") | ||
else: | ||
LOGGER.debug( | ||
"Received %d chunks for message %s", | ||
len(chunked_cache[message_uuid]), | ||
message_uuid, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
# MockQueue if running directly | ||
|
||
class MockQueue: | ||
"""A fake queue.""" | ||
|
||
async def put(self: "MockQueue", event: dict) -> None: | ||
"""Print the event.""" | ||
print(event) # noqa: T201 | ||
|
||
asyncio.run( | ||
main( | ||
MockQueue(), | ||
{ | ||
"dsn": "host=localhost port=5432 dbname=eda " | ||
"user=postgres password=secret", | ||
"channel": "my_channel", | ||
}, | ||
), | ||
) |
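For a quick local test of the listener, a publisher along these lines can send a JSON payload on the same channel. This is a minimal sketch, not part of the commit: the connection string, channel name, and payload are placeholder values matching the example in the plugin's __main__ block, and it uses Postgres' built-in pg_notify() via psycopg. Payloads small enough to fit in a single notification are delivered as-is; larger messages would need the chunked format (_message_chunked_uuid, _chunk, _message_xx_hash, ...) that _handle_chunked_message reassembles.

# publisher_sketch.py -- hypothetical helper, not part of this commit
import json

import psycopg

# Placeholder connection string and channel; adjust to your environment.
DSN = "host=localhost port=5432 dbname=eda user=postgres password=secret"
CHANNEL = "my_channel"

# autocommit=True so the NOTIFY is delivered immediately, matching the
# listener side, which also connects with autocommit enabled.
with psycopg.connect(DSN, autocommit=True) as conn:
    payload = json.dumps({"alert": "disk_usage", "host": "server1", "value": 92})
    # pg_notify(channel, payload) is the parameterized form of NOTIFY.
    conn.execute("SELECT pg_notify(%s, %s)", (CHANNEL, payload))

Running the plugin file directly (python pg_listener.py) with its MockQueue would then print the decoded event.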