feat: added pg_listener source plugin
This source plugin listens to notifications from the pg_pub_sub
Added blob_size to the generic source plugin
Added final_payload to the generic source plugin
Made systemd-python conditional so it is skipped on macOS
mkanoor committed Dec 5, 2023
1 parent c75cf2f commit 15d6e9e
Showing 5 changed files with 519 additions and 48 deletions.
129 changes: 82 additions & 47 deletions extensions/eda/plugins/event_source/generic.py
@@ -6,7 +6,7 @@
Optional Parameters:
randomize True|False Randomize the events in the payload, default False
display True|False Display the event data in stdout, default False
-add_timestamp True|False Add an event timestamp, default False
+timestamp True|False Add an event timestamp, default False
time_format local|iso8601|epoch The time format of event timestamp,
default local
create_index str The index to create for each event, starting at 0
@@ -24,6 +24,14 @@
should be repeated. Default 0
repeat_count int Number of times each individual event in the payload
should be repeated. Default 1
+blob_size int An arbitrary blob of blob_size bytes to be inserted
+ into every event payload. Default is 0, which means
+ no blob is created
+final_payload dict After all the events have been sent, we send the optional
+ final payload, which can be used to trigger a shutdown of
+ the rulebook, especially when we are using rulebooks to
+ forward messages to other running rulebooks.
"""

@@ -41,94 +49,121 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from __future__ import annotations

import asyncio
import random
import time
+from dataclasses import dataclass
from datetime import datetime
from typing import Any


+@dataclass
+class Args:
+    """Class to store all the passed in args."""
+
+    payload: Any
+    time_format: str = "local"
+    display: bool = False
+    timestamp: bool = False
+    create_index: str = ""
+    randomize: bool = False
+    startup_delay: float = 0
+    event_delay: float = 0
+    repeat_delay: float = 0
+    loop_delay: float = 0
+    shutdown_after: float = 0
+    blob_size: int = 0
+    loop_count: int = 1
+    repeat_count: int = 1
+    final_payload: Any = None


async def main(  # pylint: disable=R0914
    queue: asyncio.Queue,
    args: dict[str, Any],
) -> None:
    """Insert event data into the queue."""
-    payload = args.get("payload")
-    randomize = args.get("randomize", False)
-    display = args.get("display", False)
-    add_timestamp = args.get("timestamp", False)
-    time_format = args.get("time_format", "local")
-    create_index = args.get("create_index", "")
-
-    startup_delay = float(args.get("startup_delay", 0))
-    event_delay = float(args.get("event_delay", 0))
-    repeat_delay = float(args.get("repeat_delay", 0))
-    loop_delay = float(args.get("loop_delay", 0))
-    shutdown_after = float(args.get("shutdown_after", 0))
-
-    loop_count = int(args.get("loop_count", 1))  # -1 infinite
-    repeat_count = int(args.get("repeat_count", 1))
-    if time_format not in ["local", "iso8601", "epoch"]:
+    my_args = Args(**args)
+    blob = "x" * my_args.blob_size if my_args.blob_size > 0 else None
+
+    if my_args.timestamp and my_args.time_format not in ["local", "iso8601", "epoch"]:
        msg = "time_format must be one of local, iso8601, epoch"
        raise ValueError(msg)

-    if not isinstance(payload, list):
-        payload = [payload]
+    if not isinstance(my_args.payload, list):
+        my_args.payload = [my_args.payload]

    iteration = 0
    index = 0

-    await asyncio.sleep(startup_delay)
+    await asyncio.sleep(my_args.startup_delay)

-    while iteration != loop_count:
-        if loop_delay > 0 and iteration > 0:
-            await asyncio.sleep(loop_delay)
-        if randomize:
-            random.shuffle(payload)
-        for event in payload:
+    while iteration != my_args.loop_count:
+        if my_args.loop_delay > 0 and iteration > 0:
+            await asyncio.sleep(my_args.loop_delay)
+        if my_args.randomize:
+            random.shuffle(my_args.payload)
+        for event in my_args.payload:
            if not event:
                continue
-            for _ignore in range(repeat_count):
-                data = _create_data(create_index, index, add_timestamp, time_format)
+            for _ignore in range(my_args.repeat_count):
+                await _post_event(my_args, queue, event, index, blob)
                index += 1
-                data.update(event)
-                if display:
-                    print(data)  # noqa: T201
-                await queue.put(data)
-                await asyncio.sleep(repeat_delay)
+                await asyncio.sleep(my_args.repeat_delay)

-            await asyncio.sleep(event_delay)
+            await asyncio.sleep(my_args.event_delay)
        iteration += 1
-    await asyncio.sleep(shutdown_after)
+
+    if isinstance(my_args.final_payload, dict):
+        await _post_event(my_args, queue, my_args.final_payload, index, blob)
+
+    await asyncio.sleep(my_args.shutdown_after)
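
# For example, loop_count=2 with a two-event payload and repeat_count=3 emits
# 2 * 2 * 3 = 12 events, followed by final_payload (if provided) as one more.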


+async def _post_event(
+    my_args: Args,
+    queue: asyncio.Queue,
+    event: dict,
+    index: int,
+    blob: str | None,
+) -> None:
+    """Merge the generated fields with the event and put it on the queue."""
+    data = _create_data(my_args, index, blob)
+
+    data.update(event)
+    if my_args.display:
+        print(data)  # noqa: T201
+    await queue.put(data)


def _create_data(
-    create_index: str,
+    my_args: Args,
    index: int,
-    add_timestamp: str,
-    time_format: str,
+    blob: str | None,
) -> dict:
    data = {}
-    if create_index:
-        data[create_index] = index
-    if add_timestamp:
-        if time_format == "local":
+    if my_args.create_index:
+        data[my_args.create_index] = index
+    if blob:
+        data["blob"] = blob
+    if my_args.timestamp:
+        if my_args.time_format == "local":
            data["timestamp"] = str(datetime.now())  # noqa: DTZ005
-        elif time_format == "epoch":
+        elif my_args.time_format == "epoch":
            data["timestamp"] = int(time.time())
-        elif time_format == "iso8601":
+        elif my_args.time_format == "iso8601":
            data["timestamp"] = datetime.now(tz=None).isoformat()  # noqa: DTZ005
    return data
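
# For example, with create_index="i", timestamp=True, and time_format="epoch",
# _create_data() returns something like {"i": 0, "timestamp": 1701734400}
# (epoch value illustrative); the event payload is then merged on top.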


if __name__ == "__main__":
    """MockQueue if running directly."""

    class MockQueue:
        """A fake queue."""

-        async def put(self: "MockQueue", event: dict) -> None:
+        async def put(self: MockQueue, event: dict) -> None:
            """Print the event."""
            print(event)  # noqa: T201

141 changes: 141 additions & 0 deletions extensions/eda/plugins/event_source/pg_listener.py
@@ -0,0 +1,141 @@
"""pg_listener.py.
An ansible-rulebook event source plugin for reading events from
pg_pub_sub.
Arguments:
---------
dsn: The connection string/dsn for Postgres
channel: The name of the channel to listen to
Example:
-------
- ansible.eda.pg_listener:
    dsn: "host=localhost port=5432 dbname=mydb"
    channel: events
"""

import asyncio
import json
import logging
from typing import Any

import xxhash
from psycopg import AsyncConnection, OperationalError

LOGGER = logging.getLogger(__name__)

MESSAGE_CHUNKED_UUID = "_message_chunked_uuid"
MESSAGE_CHUNK_COUNT = "_message_chunk_count"
MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence"
MESSAGE_CHUNK = "_chunk"
MESSAGE_LENGTH = "_message_length"
MESSAGE_XX_HASH = "_message_xx_hash"
REQUIRED_KEYS = ("dsn", "channel")


class MissingRequiredArgumentError(Exception):
    """Exception class for missing arguments."""

    def __init__(self: "MissingRequiredArgumentError", key: str) -> None:
        """Class constructor with the missing key."""
        super().__init__(f"PG Listener {key} is a required argument")


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
    """Listen for events from a channel."""
    for key in REQUIRED_KEYS:
        if key not in args:
            raise MissingRequiredArgumentError(key)

    try:
        async with await AsyncConnection.connect(
            conninfo=args["dsn"],
            autocommit=True,
        ) as conn:
            chunked_cache = {}
            cursor = conn.cursor()
            await cursor.execute(f"LISTEN {args['channel']};")
            LOGGER.debug("Waiting for notifications on channel %s", args["channel"])
            async for event in conn.notifies():
                data = json.loads(event.payload)
                if MESSAGE_CHUNKED_UUID in data:
                    await _handle_chunked_message(data, chunked_cache, queue)
                else:
                    await queue.put(data)
    except json.decoder.JSONDecodeError:
        LOGGER.exception("Error decoding data, ignoring it")
    except OperationalError:
        LOGGER.exception("PG Listen operational error")


async def _handle_chunked_message(
    data: dict,
    chunked_cache: dict,
    queue: asyncio.Queue,
) -> None:
    message_uuid = data[MESSAGE_CHUNKED_UUID]
    number_of_chunks = data[MESSAGE_CHUNK_COUNT]
    message_length = data[MESSAGE_LENGTH]
    LOGGER.debug(
        "Received chunked message %s total chunks %d message length %d",
        message_uuid,
        number_of_chunks,
        message_length,
    )
    if message_uuid in chunked_cache:
        chunked_cache[message_uuid].append(data)
    else:
        chunked_cache[message_uuid] = [data]
    if (
        len(chunked_cache[message_uuid])
        == chunked_cache[message_uuid][0][MESSAGE_CHUNK_COUNT]
    ):
        LOGGER.debug(
            "Received all chunks for message %s",
            message_uuid,
        )
        all_data = ""
        for chunk in chunked_cache[message_uuid]:
            all_data += chunk[MESSAGE_CHUNK]
        chunks = chunked_cache.pop(message_uuid)
        xx_hash = xxhash.xxh32(all_data.encode("utf-8")).hexdigest()
        LOGGER.debug("Computed XX Hash is %s", xx_hash)
        LOGGER.debug(
            "XX Hash expected %s",
            chunks[0][MESSAGE_XX_HASH],
        )
        if xx_hash == chunks[0][MESSAGE_XX_HASH]:
            data = json.loads(all_data)
            await queue.put(data)
        else:
            LOGGER.error("XX Hash of chunked payload doesn't match")
    else:
        LOGGER.debug(
            "Received %d chunks for message %s",
            len(chunked_cache[message_uuid]),
            message_uuid,
        )
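

# Illustrative sender-side sketch (an assumption, not part of this plugin):
# shows how a publisher could split a large serialized payload into chunk
# envelopes that _handle_chunked_message() can reassemble and verify. The
# helper name and chunk size are hypothetical.
def _chunk_payload_example(payload: str, chunk_size: int = 7000) -> list[dict]:
    """Split a serialized payload into chunk envelope dicts."""
    import uuid  # local import to keep the sketch self-contained

    message_uuid = str(uuid.uuid4())
    # Hash the full payload so the receiver can verify the reassembled text.
    xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest()
    chunks = [payload[i : i + chunk_size] for i in range(0, len(payload), chunk_size)]
    return [
        {
            MESSAGE_CHUNKED_UUID: message_uuid,
            MESSAGE_CHUNK_COUNT: len(chunks),
            MESSAGE_CHUNK_SEQUENCE: sequence,
            MESSAGE_CHUNK: chunk,
            MESSAGE_LENGTH: len(payload),
            MESSAGE_XX_HASH: xx_hash,
        }
        for sequence, chunk in enumerate(chunks, start=1)
    ]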


if __name__ == "__main__":
    # MockQueue if running directly

    class MockQueue:
        """A fake queue."""

        async def put(self: "MockQueue", event: dict) -> None:
            """Print the event."""
            print(event)  # noqa: T201

    asyncio.run(
        main(
            MockQueue(),
            {
                "dsn": "host=localhost port=5432 dbname=eda "
                "user=postgres password=secret",
                "channel": "my_channel",
            },
        ),
    )
4 changes: 3 additions & 1 deletion requirements.txt
@@ -3,6 +3,8 @@ aiobotocore
aiohttp
aiokafka
watchdog
-systemd-python
+systemd-python; sys_platform != 'darwin'
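# sys_platform != 'darwin' is a PEP 508 environment marker: pip skips
# systemd-python on macOS, where systemd is not available.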
dpath
pyyaml
+psycopg
+xxhash
