diff --git a/.changes/unreleased/Features-20240514-162052.yaml b/.changes/unreleased/Features-20240514-162052.yaml new file mode 100644 index 00000000..1503c479 --- /dev/null +++ b/.changes/unreleased/Features-20240514-162052.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support adding callbacks to the event manager +time: 2024-05-14T16:20:52.120336-07:00 +custom: + Author: QMalcolm + Issue: "131" diff --git a/dbt_common/events/base_types.py b/dbt_common/events/base_types.py index 2a90e78f..5ff6e235 100644 --- a/dbt_common/events/base_types.py +++ b/dbt_common/events/base_types.py @@ -6,7 +6,7 @@ from google.protobuf.json_format import ParseDict, MessageToDict, MessageToJson from google.protobuf.message import Message from dbt_common.events.helpers import get_json_string_utcnow -from typing import Optional +from typing import Callable, Optional from dbt_common.invocation import get_invocation_id @@ -128,6 +128,9 @@ class EventMsg(Protocol): data: Message +TCallback = Callable[[EventMsg], None] + + def msg_from_base_event(event: BaseEvent, level: Optional[EventLevel] = None): msg_class_name = f"{type(event).__name__}Msg" msg_cls = getattr(event.PROTO_TYPES_MODULE, msg_class_name) diff --git a/dbt_common/events/event_manager.py b/dbt_common/events/event_manager.py index 96e61f6b..507588f3 100644 --- a/dbt_common/events/event_manager.py +++ b/dbt_common/events/event_manager.py @@ -1,15 +1,15 @@ import os import traceback -from typing import Callable, List, Optional, Protocol, Tuple +from typing import List, Optional, Protocol, Tuple -from dbt_common.events.base_types import BaseEvent, EventLevel, msg_from_base_event, EventMsg +from dbt_common.events.base_types import BaseEvent, EventLevel, msg_from_base_event, TCallback from dbt_common.events.logger import LoggerConfig, _Logger, _TextLogger, _JsonLogger, LineFormat class EventManager: def __init__(self) -> None: self.loggers: List[_Logger] = [] - self.callbacks: List[Callable[[EventMsg], None]] = [] + self.callbacks: List[TCallback] = [] def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: msg = msg_from_base_event(e, level=level) @@ -37,13 +37,16 @@ def add_logger(self, config: LoggerConfig) -> None: ) self.loggers.append(logger) + def add_callback(self, callback: TCallback) -> None: + self.callbacks.append(callback) + def flush(self) -> None: for logger in self.loggers: logger.flush() class IEventManager(Protocol): - callbacks: List[Callable[[EventMsg], None]] + callbacks: List[TCallback] loggers: List[_Logger] def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: @@ -52,6 +55,9 @@ def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: def add_logger(self, config: LoggerConfig) -> None: ... + def add_callback(self, callback: TCallback) -> None: + ... + class TestEventManager(IEventManager): __test__ = False diff --git a/dbt_common/events/event_manager_client.py b/dbt_common/events/event_manager_client.py index 1b674f6e..538d3199 100644 --- a/dbt_common/events/event_manager_client.py +++ b/dbt_common/events/event_manager_client.py @@ -1,6 +1,7 @@ # Since dbt-rpc does not do its own log setup, and since some events can # currently fire before logs can be configured by setup_event_logger(), we # create a default configuration with default settings and no file output. +from dbt_common.events.base_types import TCallback from dbt_common.events.event_manager import IEventManager, EventManager _EVENT_MANAGER: IEventManager = EventManager() @@ -16,6 +17,11 @@ def add_logger_to_manager(logger) -> None: _EVENT_MANAGER.add_logger(logger) +def add_callback_to_manager(callback: TCallback) -> None: + global _EVENT_MANAGER + _EVENT_MANAGER.add_callback(callback) + + def ctx_set_event_manager(event_manager: IEventManager) -> None: global _EVENT_MANAGER _EVENT_MANAGER = event_manager diff --git a/pyproject.toml b/pyproject.toml index 93524a2c..c1f4f281 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ lint = [ ] test = [ "pytest>=7.3,<8.0", + "pytest-mock", "pytest-xdist>=3.2,<4.0", "pytest-cov>=4.1,<5.0", "hypothesis>=6.87,<7.0", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_event_manager.py b/tests/unit/test_event_manager.py new file mode 100644 index 00000000..3728b6bd --- /dev/null +++ b/tests/unit/test_event_manager.py @@ -0,0 +1,11 @@ +from dbt_common.events.event_manager import EventManager +from tests.unit.utils import EventCatcher + + +class TestEventManager: + def test_add_callback(self) -> None: + event_manager = EventManager() + assert len(event_manager.callbacks) == 0 + + event_manager.add_callback(EventCatcher().catch) + assert len(event_manager.callbacks) == 1 diff --git a/tests/unit/test_event_manager_client.py b/tests/unit/test_event_manager_client.py new file mode 100644 index 00000000..bd73d61a --- /dev/null +++ b/tests/unit/test_event_manager_client.py @@ -0,0 +1,15 @@ +from pytest_mock import MockerFixture + +from dbt_common.events.event_manager import EventManager +from dbt_common.events.event_manager_client import add_callback_to_manager, get_event_manager +from tests.unit.utils import EventCatcher + + +def test_add_callback_to_manager(mocker: MockerFixture) -> None: + # mock out the global event manager so the callback doesn't get added to all other tests + mocker.patch("dbt_common.events.event_manager_client._EVENT_MANAGER", EventManager()) + manager = get_event_manager() + assert len(manager.callbacks) == 0 + + add_callback_to_manager(EventCatcher().catch) + assert len(manager.callbacks) == 1 diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index 3484cd97..2583b5f8 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -27,7 +27,7 @@ def get_all_subclasses(cls): InfoLevel, ErrorLevel, DynamicLevel, - ] and not subclass.__module__.startswith("test_"): + ] and not subclass.__module__.startswith("tests."): all_subclasses.append(subclass) all_subclasses.extend(get_all_subclasses(subclass)) return set(all_subclasses) diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index fa240d1d..372b2bda 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -1,13 +1,13 @@ import pytest -from dataclasses import dataclass, field from dbt_common.events import functions -from dbt_common.events.base_types import EventLevel, EventMsg, WarnLevel +from dbt_common.events.base_types import EventLevel, WarnLevel from dbt_common.events.event_manager import EventManager from dbt_common.events.event_manager_client import ctx_set_event_manager from dbt_common.exceptions import EventCompilationError from dbt_common.helper_types import WarnErrorOptions -from typing import List, Set +from tests.unit.utils import EventCatcher +from typing import Set # Re-implementing `Note` event as a warn event for @@ -20,14 +20,6 @@ def message(self) -> str: return self.msg -@dataclass -class EventCatcher: - caught_events: List[EventMsg] = field(default_factory=list) - - def catch(self, event: EventMsg) -> None: - self.caught_events.append(event) - - @pytest.fixture(scope="function") def event_catcher() -> EventCatcher: return EventCatcher() diff --git a/tests/unit/utils.py b/tests/unit/utils.py new file mode 100644 index 00000000..c79ecc54 --- /dev/null +++ b/tests/unit/utils.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass, field +from typing import List + +from dbt_common.events.base_types import EventMsg + + +@dataclass +class EventCatcher: + caught_events: List[EventMsg] = field(default_factory=list) + + def catch(self, event: EventMsg) -> None: + self.caught_events.append(event)