Skip to content

Commit

Permalink
remove flags.MP_CONTEXT
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Oct 27, 2023
1 parent af91666 commit 4fe18f2
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 25 deletions.
6 changes: 3 additions & 3 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# multiprocessing.RLock is a function returning this type
from multiprocessing.synchronize import RLock
from multiprocessing.context import SpawnContext
from threading import get_ident
from typing import (
Any,
Expand Down Expand Up @@ -49,7 +50,6 @@
RollbackFailed,
)
from dbt.common.events.contextvars import get_node_info
from dbt import flags
from dbt.common.utils import cast_to_str

SleepTime = Union[int, float] # As taken by time.sleep.
Expand All @@ -72,10 +72,10 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):

TYPE: str = NotImplemented

def __init__(self, profile: AdapterRequiredConfig) -> None:
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = flags.MP_CONTEXT.RLock()
self.lock: RLock = mp_context.RLock()
self.query_header: Optional[MacroQueryStringSetter] = None

def set_query_header(self, manifest: Manifest) -> None:
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypedDict,
Union,
)
from multiprocessing.context import SpawnContext

from dbt.adapters.capability import Capability, CapabilityDict
from dbt.contracts.graph.nodes import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint
Expand Down Expand Up @@ -241,10 +242,10 @@ class BaseAdapter(metaclass=AdapterMeta):
# implementations to indicate adapter support for optional capabilities.
_capabilities = CapabilityDict({})

def __init__(self, config) -> None:
def __init__(self, config, mp_context: SpawnContext) -> None:
self.config = config
self.cache = RelationsCache()
self.connections = self.ConnectionManager(config)
self.connections = self.ConnectionManager(config, mp_context)
self._macro_manifest_lazy: Optional[MacroManifest] = None

###
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dbt.include.global_project import PACKAGE_PATH as GLOBAL_PROJECT_PATH
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.semver import VersionSpecifier
from dbt.mp_context import get_mp_context

Adapter = AdapterProtocol

Expand Down Expand Up @@ -102,7 +103,7 @@ def register_adapter(self, config: AdapterRequiredConfig) -> None:
# this shouldn't really happen...
return

adapter: Adapter = adapter_type(config) # type: ignore
adapter: Adapter = adapter_type(config, get_mp_context()) # type: ignore
self.adapters[adapter_name] = adapter

def lookup_adapter(self, adapter_name: str) -> Adapter:
Expand Down
2 changes: 0 additions & 2 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys
from dataclasses import dataclass
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Any, Callable, Dict, List, Optional, Set, Union

Expand Down Expand Up @@ -224,7 +223,6 @@ def _assign_params(

# Set hard coded flags.
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Apply the lead/follow relationship between some parameters.
self._override_if_set("USE_COLORS", "USE_COLORS_FILE", params_assigned_from_default)
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
from dbt.common.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.common.events.contextvars import get_node_info
from dbt.node_types import NodeType, AccessType
from dbt.flags import get_flags, MP_CONTEXT
from dbt.flags import get_flags
from dbt.mp_context import get_mp_context
from dbt import tracking
import dbt.utils

Expand Down Expand Up @@ -823,7 +824,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_lock: Lock = field(
default_factory=MP_CONTEXT.Lock,
default_factory=get_mp_context().Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)

Expand All @@ -835,7 +836,7 @@ def __pre_serialize__(self):

@classmethod
def __post_deserialize__(cls, obj):
obj._lock = MP_CONTEXT.Lock()
obj._lock = get_mp_context().Lock()

Check warning on line 839 in core/dbt/contracts/graph/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/manifest.py#L839

Added line #L839 was not covered by tests
return obj

def build_flat_graph(self):
Expand Down
4 changes: 0 additions & 4 deletions core/dbt/flags.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Do not import the os package because we expose this package in jinja
from os import getenv as os_getenv
from argparse import Namespace
from multiprocessing import get_context
from typing import Optional
from pathlib import Path

Expand All @@ -20,9 +19,6 @@ def env_set_truthy(key: str) -> Optional[str]:
# for setting up logger for legacy logger
ENABLE_LEGACY_LOGGER = env_set_truthy("DBT_ENABLE_LEGACY_LOGGER")

# This is not a flag, it's a place to store the lock
MP_CONTEXT = get_context()


# this roughly follows the patten of EVENT_MANAGER in dbt/common/events/functions.py
# During de-globlization, we'll need to handle both similarly
Expand Down
9 changes: 9 additions & 0 deletions core/dbt/mp_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from multiprocessing import get_context
from multiprocessing.context import SpawnContext


_MP_CONTEXT = get_context("spawn")


def get_mp_context() -> SpawnContext:
return _MP_CONTEXT
5 changes: 0 additions & 5 deletions tests/unit/test_cli_flags.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from multiprocessing import get_context
from pathlib import Path
from typing import List, Optional

Expand Down Expand Up @@ -34,10 +33,6 @@ def test_which(self, run_context):
flags = Flags(run_context)
assert flags.WHICH == "run"

def test_mp_context(self, run_context):
flags = Flags(run_context)
assert flags.MP_CONTEXT == get_context("spawn")

@pytest.mark.parametrize("param", cli.params)
def test_cli_group_flags_from_params(self, run_context, param):
flags = Flags(run_context)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def test_parser_var_not_defined(self):
class TestParseWrapper(unittest.TestCase):
def setUp(self):
self.mock_config = mock.MagicMock()
self.mock_mp_context = mock.MagicMock()
adapter_class = adapter_factory()
self.mock_adapter = adapter_class(self.mock_config)
self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context)
self.namespace = mock.MagicMock()
self.wrapper = providers.ParseDatabaseWrapper(self.mock_adapter, self.namespace)
self.responder = self.mock_adapter.responder
Expand All @@ -137,13 +138,14 @@ def test_wrapped_method(self):
class TestRuntimeWrapper(unittest.TestCase):
def setUp(self):
self.mock_config = mock.MagicMock()
self.mock_mp_context = mock.MagicMock()
self.mock_config.quoting = {
"database": True,
"schema": True,
"identifier": True,
}
adapter_class = adapter_factory()
self.mock_adapter = adapter_class(self.mock_config)
self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context)
self.namespace = mock.MagicMock()
self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace)
self.responder = self.mock_adapter.responder
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import agate
import decimal
from multiprocessing import get_context
import unittest
from unittest import mock

Expand Down Expand Up @@ -55,12 +56,13 @@ def setUp(self):
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.mp_context = get_context("spawn")
self._adapter = None

@property
def adapter(self):
if self._adapter is None:
self._adapter = PostgresAdapter(self.config)
self._adapter = PostgresAdapter(self.config, self.mp_context)
inject_adapter(self._adapter, PostgresPlugin)
return self._adapter

Expand Down Expand Up @@ -384,6 +386,7 @@ def setUp(self):
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.mp_context = get_context("spawn")

self.handle = mock.MagicMock(spec=psycopg2_extensions.connection)
self.cursor = self.handle.cursor.return_value
Expand All @@ -408,7 +411,7 @@ def _mock_state_check(self):
self.mock_state_check.side_effect = _mock_state_check

self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config)
self.adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter._macro_manifest_lazy
Expand Down Expand Up @@ -533,8 +536,9 @@ def test_dbname_verification_is_case_insensitive(self):
"config-version": 2,
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.mp_context = get_context("spawn")
self.adapter.cleanup_connections()
self._adapter = PostgresAdapter(self.config)
self._adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter.verify_database("postgres")


Expand Down

0 comments on commit 4fe18f2

Please sign in to comment.