Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy fixes #6702

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from dbt.adapters.protocol import AdapterConfig, ConnectionManagerProtocol
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.manifest import AnyManifest, Manifest, MacroManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.events.functions import fire_event, warn_or_error
from dbt.events.types import (
Expand Down Expand Up @@ -349,10 +349,8 @@ def load_macro_manifest(self, base_macros_only=False) -> MacroManifest:
self.connections.set_query_header,
base_macros_only=base_macros_only,
)
# TODO CT-211
self._macro_manifest_lazy = manifest # type: ignore[assignment]
# TODO CT-211
return self._macro_manifest_lazy # type: ignore[return-value]
self._macro_manifest_lazy = manifest
return self._macro_manifest_lazy

def clear_macro_manifest(self):
if self._macro_manifest_lazy is not None:
Expand Down Expand Up @@ -983,7 +981,7 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[
def execute_macro(
self,
macro_name: str,
manifest: Optional[Manifest] = None,
manifest: Optional[AnyManifest] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Dict[str, Any] = None,
Expand All @@ -992,7 +990,7 @@ def execute_macro(
"""Look macro_name up in the manifest and execute its results.

:param macro_name: The name of the macro to execute.
:param manifest: The manifest to use for generating the base macro
:param provided_manifest: The manifest to use for generating the base macro
execution context. If none is provided, use the internal manifest.
:param project: The name of the project to search in, or None for the
first match.
Expand All @@ -1004,16 +1002,15 @@ def execute_macro(

if kwargs is None:
kwargs = {}

if context_override is None:
context_override = {}

if manifest is None:
# TODO CT-211
manifest = self._macro_manifest # type: ignore[assignment]
# TODO CT-211
macro = manifest.find_macro_by_name( # type: ignore[union-attr]
macro_name, self.config.project_name, project
)
manifest = self._macro_manifest

macro = manifest.find_macro_by_name(macro_name, self.config.project_name, project)

if macro is None:
if project is None:
package_name = "any package"
Expand All @@ -1036,6 +1033,7 @@ def execute_macro(
manifest=manifest, # type: ignore[arg-type]
package_name=project,
)

macro_context.update(context_override)

macro_function = MacroGenerator(macro, macro_context)
Expand Down
5 changes: 2 additions & 3 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ def compile_node(
ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol)
Relation_T = TypeVar("Relation_T", bound=RelationProtocol)
Column_T = TypeVar("Column_T", bound=ColumnProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol)
Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol, covariant=True)


# TODO CT-211
class AdapterProtocol( # type: ignore[misc]
class AdapterProtocol(
Protocol,
Generic[
AdapterConfig_T,
Expand Down
19 changes: 6 additions & 13 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,15 @@ def get_response(cls, cursor: Any) -> AdapterResponse:
)

@classmethod
def process_results(
cls, column_names: Iterable[str], rows: Iterable[Any]
) -> List[Dict[str, Any]]:
# TODO CT-211
unique_col_names = dict() # type: ignore[var-annotated]
# TODO CT-211
for idx in range(len(column_names)): # type: ignore[arg-type]
# TODO CT-211
col_name = column_names[idx] # type: ignore[index]
def process_results(cls, column_names: List[str], rows: Iterable[Any]) -> List[Dict[str, Any]]:
unique_col_names: Dict = dict()
for idx in range(len(column_names)):
col_name = column_names[idx]
if col_name in unique_col_names:
unique_col_names[col_name] += 1
# TODO CT-211
column_names[idx] = f"{col_name}_{unique_col_names[col_name]}" # type: ignore[index] # noqa
column_names[idx] = f"{col_name}_{unique_col_names[col_name]}"
else:
# TODO CT-211
unique_col_names[column_names[idx]] = 1 # type: ignore[index]
unique_col_names[column_names[idx]] = 1
return [dict(zip(column_names, row)) for row in rows]

@classmethod
Expand Down
3 changes: 1 addition & 2 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
# TODO CT-211
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined]
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float8" if decimals else "integer"

@classmethod
Expand Down
3 changes: 1 addition & 2 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def get_datetime_module_context() -> Dict[str, Any]:


def get_re_module_context() -> Dict[str, Any]:
# TODO CT-211
context_exports = re.__all__ # type: ignore[attr-defined]
context_exports = re.__all__

return {name: getattr(re, name) for name in context_exports}

Expand Down
18 changes: 8 additions & 10 deletions core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Iterator, Dict, Any, TypeVar, Generic
from typing import List, Iterator, Dict, Any, TypeVar, Generic, Union

from dbt.config import RuntimeConfig, Project, IsFQNResource
from dbt.contracts.graph.model_config import BaseConfig, get_config_for, _listify
Expand Down Expand Up @@ -131,7 +131,7 @@ def calculate_node_config(
project_name: str,
base: bool,
patch_config_dict: Dict[str, Any] = None,
) -> BaseConfig:
) -> T:
own_config = self.get_node_project(project_name)

result = self.initial_result(resource_type=resource_type, base=base)
Expand All @@ -155,8 +155,7 @@ def calculate_node_config(
result = self._update_from_config(result, fqn_config)

# this is mostly impactful in the snapshot config case
# TODO CT-211
return result # type: ignore[return-value]
return result

@abstractmethod
def calculate_node_config_dict(
Expand Down Expand Up @@ -227,15 +226,14 @@ def calculate_node_config_dict(
base: bool,
patch_config_dict: dict = None,
) -> Dict[str, Any]:
# TODO CT-211
return self.calculate_node_config(
config_call_dict=config_call_dict,
fqn=fqn,
resource_type=resource_type,
project_name=project_name,
base=base,
patch_config_dict=patch_config_dict,
) # type: ignore[return-value]
)

def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]:
return {}
Expand Down Expand Up @@ -321,11 +319,11 @@ def build_config_dict(
self, base: bool = False, *, rendered: bool = True, patch_config_dict: dict = None
) -> Dict[str, Any]:
if rendered:
# TODO CT-211
src = ContextConfigGenerator(self._active_project) # type: ignore[var-annotated]
src: Union[ContextConfigGenerator, UnrenderedConfigGenerator] = ContextConfigGenerator(
self._active_project
)
else:
# TODO CT-211
src = UnrenderedConfigGenerator(self._active_project) # type: ignore[assignment]
src = UnrenderedConfigGenerator(self._active_project)

return src.calculate_node_config_dict(
config_call_dict=self._config_call_dict,
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/context/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import Macro, ResultNode
from dbt.contracts.files import SourceFile

from dbt.context.base import contextmember
from dbt.context.configured import SchemaYamlContext
Expand Down Expand Up @@ -65,8 +66,8 @@ def doc(self, *args: str) -> str:
file_id = target_doc.file_id
if file_id in self.manifest.files:
source_file = self.manifest.files[file_id]
# TODO CT-211
source_file.add_node(self.node.unique_id) # type: ignore[union-attr]
if type(source_file) == SourceFile:
source_file.add_node(self.node.unique_id)
else:
raise DocTargetNotFoundError(
node=self.node, target_doc_name=doc_name, target_doc_package=doc_package_name
Expand Down
10 changes: 5 additions & 5 deletions core/dbt/context/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def __init__(
self.packages: Dict[str, FlatNamespace] = packages
self.global_project_namespace: FlatNamespace = global_project_namespace

def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]:
def _search_order(
self,
) -> Iterable[Union[FullNamespace, FlatNamespace, Dict[str, FlatNamespace]]]:
yield self.local_namespace # local package
yield self.global_namespace # root package
# TODO CT-211
yield self.packages # type: ignore[misc] # non-internal packages
yield self.packages
yield {
# TODO CT-211
GLOBAL_PROJECT_NAME: self.global_project_namespace, # type: ignore[misc] # dbt
GLOBAL_PROJECT_NAME: self.global_project_namespace,
}
yield self.global_project_namespace # other internal project besides dbt

Expand Down
40 changes: 18 additions & 22 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from dbt.context.macros import MacroNamespaceBuilder, MacroNamespace
from dbt.context.manifest import ManifestContext
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.files import SourceFile, SchemaSourceFile
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.graph.nodes import (
Macro,
Expand All @@ -39,6 +40,7 @@
ManifestNode,
RefArgs,
AccessType,
GenericTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
Expand Down Expand Up @@ -1290,9 +1292,8 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
if self.model.file_id in self.manifest.files:
source_file = self.manifest.files[self.model.file_id]
# Schema files should never get here
if source_file.parse_file_type != "schema":
# TODO CT-211
source_file.env_vars.append(var) # type: ignore[union-attr]
if source_file.parse_file_type != "schema" and type(source_file) == SourceFile:
source_file.env_vars.append(var)
return return_value
else:
raise EnvVarMissingError(var)
Expand Down Expand Up @@ -1353,36 +1354,28 @@ class ModelContext(ProviderContext):
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
return []
# TODO CT-211
return [
h.to_dict(omit_none=True) for h in self.model.config.pre_hook # type: ignore[union-attr] # noqa
]
return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook]

@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
return []
# TODO CT-211
return [
h.to_dict(omit_none=True) for h in self.model.config.post_hook # type: ignore[union-attr] # noqa
]
return [h.to_dict(omit_none=True) for h in self.model.config.post_hook]

@contextproperty
def sql(self) -> Optional[str]:
# only doing this in sql model for backward compatible
if (
getattr(self.model, "extra_ctes_injected", None)
and self.model.language == ModelLanguage.sql # type: ignore[union-attr]
and self.model.language == ModelLanguage.sql
):
# TODO CT-211
return self.model.compiled_code # type: ignore[union-attr]
return self.model.compiled_code
return None

@contextproperty
def compiled_code(self) -> Optional[str]:
if getattr(self.model, "extra_ctes_injected", None):
# TODO CT-211
return self.model.compiled_code # type: ignore[union-attr]
return self.model.compiled_code
return None

@contextproperty
Expand Down Expand Up @@ -1652,13 +1645,16 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
)
# the "model" should only be test nodes, but just in case, check
# TODO CT-211
if self.model.resource_type == NodeType.Test and self.model.file_key_name: # type: ignore[union-attr] # noqa
if (
self.model.resource_type == NodeType.Test
and type(self.model) == GenericTestNode
and self.model.file_key_name
):
source_file = self.manifest.files[self.model.file_id]
# TODO CT-211
(yaml_key, name) = self.model.file_key_name.split(".") # type: ignore[union-attr] # noqa
# TODO CT-211
source_file.add_env_var(var, yaml_key, name) # type: ignore[union-attr]

(yaml_key, name) = self.model.file_key_name.split(".")
if type(source_file) == SchemaSourceFile:
source_file.add_env_var(var, yaml_key, name)
return return_value
else:
raise EnvVarMissingError(var)
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,8 @@ class TestConfig(NodeAndTestConfig):
fail_calc: str = "count(*)"
warn_if: str = "!= 0"
error_if: str = "!= 0"
pre_hook: List = []
post_hook: List = []

@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ class SeedNode(ParsedNode): # No SQLDefaults!
root_path: Optional[str] = None
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
state_relation: Optional[StateRelation] = None
compiled_code = None

def same_seeds(self, other: "SeedNode") -> bool:
# for seeds, we check the hashes. If the hashes are different types,
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ def load_macros(
root_config: RuntimeConfig,
macro_hook: Callable[[Manifest], Any],
base_macros_only=False,
) -> Manifest:
) -> MacroManifest:
with PARSING_STATE:
# base_only/base_macros_only: for testing only,
# allows loading macros without running 'dbt deps' first
Expand Down
3 changes: 2 additions & 1 deletion third-party-stubs/agate/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence

from typing import Any, Optional, Callable, Iterable, Dict, Union
from typing import Any, Optional, Callable, Iterable, Dict, Union, Tuple, OrderedDict

from . import data_types as data_types
from .data_types import (
Expand Down Expand Up @@ -52,6 +52,7 @@ class Table:
def columns(self): ...
@property
def rows(self): ...
def aggregate(self, aggregations: Any) -> OrderedDict: ...
def print_csv(self, **kwargs: Any) -> None: ...
def print_json(self, **kwargs: Any) -> None: ...
def where(self, test: Callable[[Row], bool]) -> "Table": ...
Expand Down