diff --git a/extensions/eda/plugins/event_source/generic.py b/extensions/eda/plugins/event_source/generic.py index 536e5db5..e1e3ae5b 100644 --- a/extensions/eda/plugins/event_source/generic.py +++ b/extensions/eda/plugins/event_source/generic.py @@ -32,6 +32,10 @@ final payload which can be used to trigger a shutdown of the rulebook, especially when we are using rulebooks to forward messages to other running rulebooks. +check_env_vars dict Optionally check if all the defined env vars are set + before generating the events. If any of the env_var is missing + or the value doesn't match the source plugin will end + with an exception """ @@ -53,16 +57,35 @@ from __future__ import annotations import asyncio +import os import random import time from dataclasses import dataclass, fields from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, Dict, Optional import yaml +class MissingEnvVarError(Exception): + """Exception class for missing env var.""" + + def __init__(self: "MissingEnvVarError", env_var: str) -> None: + """Class constructor with the missing env_var.""" + super().__init__(f"Env Var {env_var} is required") + + +class EnvVarMismatchError(Exception): + """Exception class for mismatch in the env var value.""" + + def __init__( + self: "EnvVarMismatchError", env_var: str, value: str, expected: str + ) -> None: + """Class constructor with mismatch in env_var value.""" + super().__init__(f"Env Var {env_var} expected: {expected} passed in: {value}") + + @dataclass class Args: """Class to store all the passed in args.""" @@ -84,6 +107,7 @@ class ControlArgs: loop_count: int = 1 repeat_count: int = 1 timestamp: bool = False + check_env_vars: Optional[Dict[str, str]] = None @dataclass @@ -135,6 +159,7 @@ async def __call__(self: Generic) -> None: msg = "time_format must be one of local, iso8601, epoch" raise ValueError(msg) + await self._check_env_vars() await self._load_payload_from_file() if not isinstance(self.my_args.payload, list): @@ -174,6 +199,14 @@ async def _post_event(self: Generic, event: dict[str, Any], index: int) -> None: print(data) # noqa: T201 await self.queue.put(data) + async def _check_env_vars(self: Generic) -> None: + if self.control_args.check_env_vars: + for key, value in self.control_args.check_env_vars.items(): + if key not in os.environ: + raise MissingEnvVarError(key) + if os.environ[key] != value: + raise EnvVarMismatchError(key, os.environ[key], value) + async def _load_payload_from_file(self: Generic) -> None: if not self.my_args.payload_file: return diff --git a/tests/unit/event_source/test_generic.py b/tests/unit/event_source/test_generic.py index 909dc6ec..76ae907f 100644 --- a/tests/unit/event_source/test_generic.py +++ b/tests/unit/event_source/test_generic.py @@ -8,6 +8,10 @@ import pytest import yaml +from extensions.eda.plugins.event_source.generic import ( + EnvVarMismatchError, + MissingEnvVarError, +) from extensions.eda.plugins.event_source.generic import main as generic_main @@ -243,3 +247,57 @@ def test_generic_parsing_payload_file() -> None: }, ) ) + + +def test_env_vars_missing() -> None: + """Test missing env vars""" + myqueue = _MockQueue() + event = {"name": "fred"} + + with pytest.raises(MissingEnvVarError): + asyncio.run( + generic_main( + myqueue, + { + "payload": event, + "check_env_vars": {"NAME_MISSING": "Fred"}, + }, + ) + ) + + +def test_env_vars_mismatch() -> None: + """Test env vars with incorrect values""" + myqueue = _MockQueue() + event = {"name": "fred"} + + os.environ["TEST_ENV1"] = "Kaboom" + with pytest.raises(EnvVarMismatchError): + asyncio.run( + generic_main( + myqueue, + { + "payload": event, + "check_env_vars": {"TEST_ENV1": "Fred"}, + }, + ) + ) + + +def test_env_vars() -> None: + """Test env vars with correct values""" + myqueue = _MockQueue() + event = {"name": "fred"} + + os.environ["TEST_ENV1"] = "Fred" + asyncio.run( + generic_main( + myqueue, + { + "payload": event, + "check_env_vars": {"TEST_ENV1": "Fred"}, + }, + ) + ) + assert len(myqueue.queue) == 1 + assert myqueue.queue[0] == {"name": "fred"}