Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove dbt.flags.MP_CONTEXT usage in dbt/adapters #8931

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231101-102758.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: remove dbt.flags.MP_CONTEXT usage in dbt/adapters
time: 2023-11-01T10:27:58.790153-04:00
custom:
Author: michelleark
Issue: "8967"
6 changes: 3 additions & 3 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
from multiprocessing.context import SpawnContext
from threading import get_ident
from typing import (
Any,
Expand Down Expand Up @@ -49,7 +50,6 @@
RollbackFailed,
)
from dbt.common.events.contextvars import get_node_info
from dbt import flags
from dbt.common.utils import cast_to_str

SleepTime = Union[int, float] # As taken by time.sleep.
Expand All @@ -72,10 +72,10 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):

TYPE: str = NotImplemented

def __init__(self, profile: AdapterRequiredConfig) -> None:
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = flags.MP_CONTEXT.RLock()
self.lock: RLock = mp_context.RLock()
self.query_header: Optional[MacroQueryStringSetter] = None

def set_query_header(self, manifest: Manifest) -> None:
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypedDict,
Union,
)
from multiprocessing.context import SpawnContext

from dbt.adapters.capability import Capability, CapabilityDict
from dbt.contracts.graph.nodes import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint
Expand Down Expand Up @@ -241,10 +242,10 @@ class BaseAdapter(metaclass=AdapterMeta):
# implementations to indicate adapter support for optional capabilities.
_capabilities = CapabilityDict({})

def __init__(self, config) -> None:
def __init__(self, config, mp_context: SpawnContext) -> None:
self.config = config
self.cache = RelationsCache()
self.connections = self.ConnectionManager(config)
self.connections = self.ConnectionManager(config, mp_context)
self._macro_manifest_lazy: Optional[MacroManifest] = None

###
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dbt.include.global_project import PACKAGE_PATH as GLOBAL_PROJECT_PATH
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.semver import VersionSpecifier
from dbt.mp_context import get_mp_context

Adapter = AdapterProtocol

Expand Down Expand Up @@ -102,7 +103,7 @@ def register_adapter(self, config: AdapterRequiredConfig) -> None:
# this shouldn't really happen...
return

adapter: Adapter = adapter_type(config) # type: ignore
adapter: Adapter = adapter_type(config, get_mp_context()) # type: ignore
self.adapters[adapter_name] = adapter

def lookup_adapter(self, adapter_name: str) -> Adapter:
Expand Down
2 changes: 0 additions & 2 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys
from dataclasses import dataclass
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Any, Callable, Dict, List, Optional, Set, Union

Expand Down Expand Up @@ -224,7 +223,6 @@ def _assign_params(

# Set hard coded flags.
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Apply the lead/follow relationship between some parameters.
self._override_if_set("USE_COLORS", "USE_COLORS_FILE", params_assigned_from_default)
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
from dbt.common.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.common.events.contextvars import get_node_info
from dbt.node_types import NodeType, AccessType
from dbt.flags import get_flags, MP_CONTEXT
from dbt.flags import get_flags
from dbt.mp_context import get_mp_context
from dbt import tracking
import dbt.utils

Expand Down Expand Up @@ -823,7 +824,7 @@
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_lock: Lock = field(
default_factory=MP_CONTEXT.Lock,
default_factory=get_mp_context().Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

Expand All @@ -835,7 +836,7 @@

@classmethod
def __post_deserialize__(cls, obj):
obj._lock = MP_CONTEXT.Lock()
obj._lock = get_mp_context().Lock()

Check warning on line 839 in core/dbt/contracts/graph/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/manifest.py#L839

Added line #L839 was not covered by tests
return obj

def build_flat_graph(self):
Expand Down
4 changes: 0 additions & 4 deletions core/dbt/flags.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Do not import the os package because we expose this package in jinja
from os import getenv as os_getenv
from argparse import Namespace
from multiprocessing import get_context
from typing import Optional
from pathlib import Path

Expand All @@ -20,9 +19,6 @@ def env_set_truthy(key: str) -> Optional[str]:
# for setting up logger for legacy logger
ENABLE_LEGACY_LOGGER = env_set_truthy("DBT_ENABLE_LEGACY_LOGGER")

# This is not a flag, it's a place to store the lock
MP_CONTEXT = get_context()


# this roughly follows the patten of EVENT_MANAGER in dbt/common/events/functions.py
# During de-globlization, we'll need to handle both similarly
Expand Down
9 changes: 9 additions & 0 deletions core/dbt/mp_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from multiprocessing import get_context
from multiprocessing.context import SpawnContext


_MP_CONTEXT = get_context("spawn")


def get_mp_context() -> SpawnContext:
return _MP_CONTEXT
5 changes: 0 additions & 5 deletions tests/unit/test_cli_flags.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from multiprocessing import get_context
from pathlib import Path
from typing import List, Optional

Expand Down Expand Up @@ -34,10 +33,6 @@ def test_which(self, run_context):
flags = Flags(run_context)
assert flags.WHICH == "run"

def test_mp_context(self, run_context):
flags = Flags(run_context)
assert flags.MP_CONTEXT == get_context("spawn")

@pytest.mark.parametrize("param", cli.params)
def test_cli_group_flags_from_params(self, run_context, param):
flags = Flags(run_context)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def test_parser_var_not_defined(self):
class TestParseWrapper(unittest.TestCase):
def setUp(self):
self.mock_config = mock.MagicMock()
self.mock_mp_context = mock.MagicMock()
adapter_class = adapter_factory()
self.mock_adapter = adapter_class(self.mock_config)
self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context)
self.namespace = mock.MagicMock()
self.wrapper = providers.ParseDatabaseWrapper(self.mock_adapter, self.namespace)
self.responder = self.mock_adapter.responder
Expand All @@ -137,13 +138,14 @@ def test_wrapped_method(self):
class TestRuntimeWrapper(unittest.TestCase):
def setUp(self):
self.mock_config = mock.MagicMock()
self.mock_mp_context = mock.MagicMock()
self.mock_config.quoting = {
"database": True,
"schema": True,
"identifier": True,
}
adapter_class = adapter_factory()
self.mock_adapter = adapter_class(self.mock_config)
self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context)
self.namespace = mock.MagicMock()
self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace)
self.responder = self.mock_adapter.responder
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import agate
import decimal
from multiprocessing import get_context
import unittest
from unittest import mock

Expand Down Expand Up @@ -55,12 +56,13 @@ def setUp(self):
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.mp_context = get_context("spawn")
self._adapter = None

@property
def adapter(self):
if self._adapter is None:
self._adapter = PostgresAdapter(self.config)
self._adapter = PostgresAdapter(self.config, self.mp_context)
inject_adapter(self._adapter, PostgresPlugin)
return self._adapter

Expand Down Expand Up @@ -384,6 +386,7 @@ def setUp(self):
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.mp_context = get_context("spawn")

self.handle = mock.MagicMock(spec=psycopg2_extensions.connection)
self.cursor = self.handle.cursor.return_value
Expand All @@ -408,7 +411,7 @@ def _mock_state_check(self):
self.mock_state_check.side_effect = _mock_state_check

self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config)
self.adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter._macro_manifest_lazy
Expand Down Expand Up @@ -533,8 +536,9 @@ def test_dbname_verification_is_case_insensitive(self):
"config-version": 2,
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.mp_context = get_context("spawn")
self.adapter.cleanup_connections()
self._adapter = PostgresAdapter(self.config)
self._adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter.verify_database("postgres")


Expand Down