Skip to content

Commit

Permalink
update Behavior from SimpleNamespace to a custom class to support typing
Browse files Browse the repository at this point in the history
  • Loading branch information
mikealfare committed Aug 24, 2024
1 parent 8fc6485 commit 6adff8c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 39 deletions.
68 changes: 36 additions & 32 deletions dbt_common/behavior_flags.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
from types import SimpleNamespace
from typing import Any, Dict, List, TypedDict

try:
Expand All @@ -12,6 +11,7 @@
from dbt_common.events.base_types import WarnLevel
from dbt_common.events.functions import fire_event
from dbt_common.events.types import BehaviorDeprecationEvent
from dbt_common.exceptions import DbtInternalError


class BehaviorFlag:
Expand Down Expand Up @@ -69,44 +69,48 @@ class RawBehaviorFlag(TypedDict):

# this is effectively a dictionary that supports dot notation
# it makes usage easy, e.g. adapter.behavior.my_flag
Behavior = SimpleNamespace


def register(
behavior_flags: List[RawBehaviorFlag],
user_overrides: Dict[str, Any],
) -> Behavior:
flags = {}
for raw_flag in behavior_flags:
flag = {
"name": raw_flag["name"],
"setting": raw_flag["default"],
}

# specifically evaluate for `None` since `False` and `None` should be treated differently
if user_overrides.get(raw_flag["name"]) is not None:
flag["setting"] = user_overrides[raw_flag["name"]]

event = BehaviorDeprecationEvent(
flag_name=raw_flag["name"],
flag_source=raw_flag.get("source", _default_source()),
deprecation_version=raw_flag.get("deprecation_version"),
deprecation_message=raw_flag.get("deprecation_message"),
docs_url=raw_flag.get("docs_url"),
)
flag["deprecation_event"] = event

flags[flag["name"]] = BehaviorFlag(**flag) # type: ignore

return Behavior(**flags) # type: ignore
class Behavior:
_flags: List[BehaviorFlag]

def __init__(
self,
behavior_flags: List[RawBehaviorFlag],
user_overrides: Dict[str, Any],
) -> None:
flags = []
for raw_flag in behavior_flags:
flags.append(
BehaviorFlag(
name=raw_flag["name"],
setting=user_overrides.get(raw_flag["name"], raw_flag["default"]),
deprecation_event=_behavior_deprecation_event(raw_flag),
)
)
self._flags = flags

def __getattr__(self, name: str) -> BehaviorFlag:
for flag in self._flags:
if flag.name == name:
return flag
raise DbtInternalError(f"The flag {name} has not be registered.")


def _behavior_deprecation_event(flag: RawBehaviorFlag) -> BehaviorDeprecationEvent:
return BehaviorDeprecationEvent(
flag_name=flag["name"],
flag_source=flag.get("source", _default_source()),
deprecation_version=flag.get("deprecation_version"),
deprecation_message=flag.get("deprecation_message"),
docs_url=flag.get("docs_url"),
)


def _default_source() -> str:
"""
If the maintainer did not provide a source, default to the module that called `register`.
For adapters, this will likely be `dbt.adapters.<foo>.impl` for `dbt-foo`.
"""
frame = inspect.stack()[2]
frame = inspect.stack()[3]
if module := inspect.getmodule(frame[0]):
return module.__name__
return "Unknown"
14 changes: 7 additions & 7 deletions tests/unit/test_behavior_flags.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from pytest_mock import MockerFixture # type: ignore

from dbt_common.behavior_flags import register
from dbt_common.behavior_flags import Behavior
from dbt_common.events.event_manager import EventManager
from dbt_common.events.event_manager_client import add_callback_to_manager

from tests.unit.utils import EventCatcher


def test_behavior_default():
behavior = register(
behavior = Behavior(
behavior_flags=[
{"name": "default_false_flag", "default": False},
{"name": "default_true_flag", "default": True},
Expand All @@ -22,7 +22,7 @@ def test_behavior_default():


def test_behavior_user_override():
behavior = register(
behavior = Behavior(
behavior_flags=[
{"name": "flag_default_false", "default": False},
{"name": "flag_default_false_override_false", "default": False},
Expand All @@ -48,7 +48,7 @@ def test_behavior_user_override():


def test_behavior_flag_can_be_used_as_conditional():
behavior = register(
behavior = Behavior(
behavior_flags=[
{"name": "flag_false", "default": False},
{"name": "flag_true", "default": True},
Expand All @@ -70,7 +70,7 @@ def event_catcher(mocker: MockerFixture) -> EventCatcher:


def test_behavior_flags_emit_deprecation_event_on_evaluation(event_catcher) -> None:
behavior = register(
behavior = Behavior(
behavior_flags=[
{"name": "flag_false", "default": False},
{"name": "flag_true", "default": True},
Expand All @@ -90,7 +90,7 @@ def test_behavior_flags_emit_deprecation_event_on_evaluation(event_catcher) -> N


def test_behavior_flags_emit_correct_deprecation(event_catcher) -> None:
behavior = register(
behavior = Behavior(
behavior_flags=[{"name": "flag_false", "default": False}],
user_overrides={},
)
Expand All @@ -106,7 +106,7 @@ def test_behavior_flags_emit_correct_deprecation(event_catcher) -> None:


def test_behavior_flags_no_deprecation_event_on_no_warn(event_catcher) -> None:
behavior = register(
behavior = Behavior(
behavior_flags=[
{"name": "flag_false", "default": False},
],
Expand Down

0 comments on commit 6adff8c

Please sign in to comment.