diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 59e2a0a93a6..75d69f66731 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -46,7 +46,7 @@ from dbt.adapters.protocol import AdapterConfig, ConnectionManagerProtocol from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows from dbt.clients.jinja import MacroGenerator -from dbt.contracts.graph.manifest import Manifest, MacroManifest +from dbt.contracts.graph.manifest import AnyManifest, Manifest, MacroManifest from dbt.contracts.graph.nodes import ResultNode from dbt.events.functions import fire_event, warn_or_error from dbt.events.types import ( @@ -349,10 +349,8 @@ def load_macro_manifest(self, base_macros_only=False) -> MacroManifest: self.connections.set_query_header, base_macros_only=base_macros_only, ) - # TODO CT-211 - self._macro_manifest_lazy = manifest # type: ignore[assignment] - # TODO CT-211 - return self._macro_manifest_lazy # type: ignore[return-value] + self._macro_manifest_lazy = manifest + return self._macro_manifest_lazy def clear_macro_manifest(self): if self._macro_manifest_lazy is not None: @@ -983,7 +981,7 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[ def execute_macro( self, macro_name: str, - manifest: Optional[Manifest] = None, + manifest: Optional[AnyManifest] = None, project: Optional[str] = None, context_override: Optional[Dict[str, Any]] = None, kwargs: Dict[str, Any] = None, @@ -992,7 +990,7 @@ def execute_macro( """Look macro_name up in the manifest and execute its results. :param macro_name: The name of the macro to execute. - :param manifest: The manifest to use for generating the base macro + :param provided_manifest: The manifest to use for generating the base macro execution context. If none is provided, use the internal manifest. :param project: The name of the project to search in, or None for the first match. @@ -1004,16 +1002,15 @@ def execute_macro( if kwargs is None: kwargs = {} + if context_override is None: context_override = {} if manifest is None: - # TODO CT-211 - manifest = self._macro_manifest # type: ignore[assignment] - # TODO CT-211 - macro = manifest.find_macro_by_name( # type: ignore[union-attr] - macro_name, self.config.project_name, project - ) + manifest = self._macro_manifest + + macro = manifest.find_macro_by_name(macro_name, self.config.project_name, project) + if macro is None: if project is None: package_name = "any package" @@ -1036,6 +1033,7 @@ def execute_macro( manifest=manifest, # type: ignore[arg-type] package_name=project, ) + macro_context.update(context_override) macro_function = MacroGenerator(macro, macro_context) diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 13b9bd79968..1d2a8b960c7 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -67,11 +67,10 @@ def compile_node( ConnectionManager_T = TypeVar("ConnectionManager_T", bound=ConnectionManagerProtocol) Relation_T = TypeVar("Relation_T", bound=RelationProtocol) Column_T = TypeVar("Column_T", bound=ColumnProtocol) -Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol) +Compiler_T = TypeVar("Compiler_T", bound=CompilerProtocol, covariant=True) -# TODO CT-211 -class AdapterProtocol( # type: ignore[misc] +class AdapterProtocol( Protocol, Generic[ AdapterConfig_T, diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index 464c07871a0..b3ab135eb48 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -99,22 +99,15 @@ def get_response(cls, cursor: Any) -> AdapterResponse: ) @classmethod - def process_results( - cls, column_names: Iterable[str], rows: Iterable[Any] - ) -> List[Dict[str, Any]]: - # TODO CT-211 - unique_col_names = dict() # type: ignore[var-annotated] - # TODO CT-211 - for idx in range(len(column_names)): # type: ignore[arg-type] - # TODO CT-211 - col_name = column_names[idx] # type: ignore[index] + def process_results(cls, column_names: List[str], rows: Iterable[Any]) -> List[Dict[str, Any]]: + unique_col_names: Dict = dict() + for idx in range(len(column_names)): + col_name = column_names[idx] if col_name in unique_col_names: unique_col_names[col_name] += 1 - # TODO CT-211 - column_names[idx] = f"{col_name}_{unique_col_names[col_name]}" # type: ignore[index] # noqa + column_names[idx] = f"{col_name}_{unique_col_names[col_name]}" else: - # TODO CT-211 - unique_col_names[column_names[idx]] = 1 # type: ignore[index] + unique_col_names[column_names[idx]] = 1 return [dict(zip(column_names, row)) for row in rows] @classmethod diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 835302a9b0d..993425066b7 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -70,8 +70,7 @@ def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: @classmethod def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: - # TODO CT-211 - decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined] + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "float8" if decimals else "integer" @classmethod diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 1ac9fc239f0..ab71df182d2 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -46,8 +46,7 @@ def get_datetime_module_context() -> Dict[str, Any]: def get_re_module_context() -> Dict[str, Any]: - # TODO CT-211 - context_exports = re.__all__ # type: ignore[attr-defined] + context_exports = re.__all__ return {name: getattr(re, name) for name in context_exports} diff --git a/core/dbt/context/context_config.py b/core/dbt/context/context_config.py index b497887ab45..685cc8f7ca7 100644 --- a/core/dbt/context/context_config.py +++ b/core/dbt/context/context_config.py @@ -1,7 +1,7 @@ from abc import abstractmethod from copy import deepcopy from dataclasses import dataclass -from typing import List, Iterator, Dict, Any, TypeVar, Generic +from typing import List, Iterator, Dict, Any, TypeVar, Generic, Union from dbt.config import RuntimeConfig, Project, IsFQNResource from dbt.contracts.graph.model_config import BaseConfig, get_config_for, _listify @@ -131,7 +131,7 @@ def calculate_node_config( project_name: str, base: bool, patch_config_dict: Dict[str, Any] = None, - ) -> BaseConfig: + ) -> T: own_config = self.get_node_project(project_name) result = self.initial_result(resource_type=resource_type, base=base) @@ -155,8 +155,7 @@ def calculate_node_config( result = self._update_from_config(result, fqn_config) # this is mostly impactful in the snapshot config case - # TODO CT-211 - return result # type: ignore[return-value] + return result @abstractmethod def calculate_node_config_dict( @@ -227,7 +226,6 @@ def calculate_node_config_dict( base: bool, patch_config_dict: dict = None, ) -> Dict[str, Any]: - # TODO CT-211 return self.calculate_node_config( config_call_dict=config_call_dict, fqn=fqn, @@ -235,7 +233,7 @@ def calculate_node_config_dict( project_name=project_name, base=base, patch_config_dict=patch_config_dict, - ) # type: ignore[return-value] + ) def initial_result(self, resource_type: NodeType, base: bool) -> Dict[str, Any]: return {} @@ -321,11 +319,11 @@ def build_config_dict( self, base: bool = False, *, rendered: bool = True, patch_config_dict: dict = None ) -> Dict[str, Any]: if rendered: - # TODO CT-211 - src = ContextConfigGenerator(self._active_project) # type: ignore[var-annotated] + src: Union[ContextConfigGenerator, UnrenderedConfigGenerator] = ContextConfigGenerator( + self._active_project + ) else: - # TODO CT-211 - src = UnrenderedConfigGenerator(self._active_project) # type: ignore[assignment] + src = UnrenderedConfigGenerator(self._active_project) return src.calculate_node_config_dict( config_call_dict=self._config_call_dict, diff --git a/core/dbt/context/docs.py b/core/dbt/context/docs.py index 3d5abf42e11..b3e87a161c5 100644 --- a/core/dbt/context/docs.py +++ b/core/dbt/context/docs.py @@ -7,6 +7,7 @@ from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import Macro, ResultNode +from dbt.contracts.files import SourceFile from dbt.context.base import contextmember from dbt.context.configured import SchemaYamlContext @@ -65,8 +66,8 @@ def doc(self, *args: str) -> str: file_id = target_doc.file_id if file_id in self.manifest.files: source_file = self.manifest.files[file_id] - # TODO CT-211 - source_file.add_node(self.node.unique_id) # type: ignore[union-attr] + if type(source_file) == SourceFile: + source_file.add_node(self.node.unique_id) else: raise DocTargetNotFoundError( node=self.node, target_doc_name=doc_name, target_doc_package=doc_package_name diff --git a/core/dbt/context/macros.py b/core/dbt/context/macros.py index 1c61e564e06..387d1e112a8 100644 --- a/core/dbt/context/macros.py +++ b/core/dbt/context/macros.py @@ -32,14 +32,14 @@ def __init__( self.packages: Dict[str, FlatNamespace] = packages self.global_project_namespace: FlatNamespace = global_project_namespace - def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]: + def _search_order( + self, + ) -> Iterable[Union[FullNamespace, FlatNamespace, Dict[str, FlatNamespace]]]: yield self.local_namespace # local package yield self.global_namespace # root package - # TODO CT-211 - yield self.packages # type: ignore[misc] # non-internal packages + yield self.packages yield { - # TODO CT-211 - GLOBAL_PROJECT_NAME: self.global_project_namespace, # type: ignore[misc] # dbt + GLOBAL_PROJECT_NAME: self.global_project_namespace, } yield self.global_project_namespace # other internal project besides dbt diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index fe279a7fd3f..e4159bf8115 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -28,6 +28,7 @@ from dbt.context.macros import MacroNamespaceBuilder, MacroNamespace from dbt.context.manifest import ManifestContext from dbt.contracts.connection import AdapterResponse +from dbt.contracts.files import SourceFile, SchemaSourceFile from dbt.contracts.graph.manifest import Manifest, Disabled from dbt.contracts.graph.nodes import ( Macro, @@ -39,6 +40,7 @@ ManifestNode, RefArgs, AccessType, + GenericTestNode, ) from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference from dbt.contracts.graph.unparsed import NodeVersion @@ -1290,9 +1292,8 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: if self.model.file_id in self.manifest.files: source_file = self.manifest.files[self.model.file_id] # Schema files should never get here - if source_file.parse_file_type != "schema": - # TODO CT-211 - source_file.env_vars.append(var) # type: ignore[union-attr] + if source_file.parse_file_type != "schema" and type(source_file) == SourceFile: + source_file.env_vars.append(var) return return_value else: raise EnvVarMissingError(var) @@ -1353,36 +1354,28 @@ class ModelContext(ProviderContext): def pre_hooks(self) -> List[Dict[str, Any]]: if self.model.resource_type in [NodeType.Source, NodeType.Test]: return [] - # TODO CT-211 - return [ - h.to_dict(omit_none=True) for h in self.model.config.pre_hook # type: ignore[union-attr] # noqa - ] + return [h.to_dict(omit_none=True) for h in self.model.config.pre_hook] @contextproperty def post_hooks(self) -> List[Dict[str, Any]]: if self.model.resource_type in [NodeType.Source, NodeType.Test]: return [] - # TODO CT-211 - return [ - h.to_dict(omit_none=True) for h in self.model.config.post_hook # type: ignore[union-attr] # noqa - ] + return [h.to_dict(omit_none=True) for h in self.model.config.post_hook] @contextproperty def sql(self) -> Optional[str]: # only doing this in sql model for backward compatible if ( getattr(self.model, "extra_ctes_injected", None) - and self.model.language == ModelLanguage.sql # type: ignore[union-attr] + and self.model.language == ModelLanguage.sql ): - # TODO CT-211 - return self.model.compiled_code # type: ignore[union-attr] + return self.model.compiled_code return None @contextproperty def compiled_code(self) -> Optional[str]: if getattr(self.model, "extra_ctes_injected", None): - # TODO CT-211 - return self.model.compiled_code # type: ignore[union-attr] + return self.model.compiled_code return None @contextproperty @@ -1652,13 +1645,16 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER ) # the "model" should only be test nodes, but just in case, check - # TODO CT-211 - if self.model.resource_type == NodeType.Test and self.model.file_key_name: # type: ignore[union-attr] # noqa + if ( + self.model.resource_type == NodeType.Test + and type(self.model) == GenericTestNode + and self.model.file_key_name + ): source_file = self.manifest.files[self.model.file_id] - # TODO CT-211 - (yaml_key, name) = self.model.file_key_name.split(".") # type: ignore[union-attr] # noqa - # TODO CT-211 - source_file.add_env_var(var, yaml_key, name) # type: ignore[union-attr] + + (yaml_key, name) = self.model.file_key_name.split(".") + if type(source_file) == SchemaSourceFile: + source_file.add_env_var(var, yaml_key, name) return return_value else: raise EnvVarMissingError(var) diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index 89471df8d5b..443a582d3d3 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -563,6 +563,8 @@ class TestConfig(NodeAndTestConfig): fail_calc: str = "count(*)" warn_if: str = "!= 0" error_if: str = "!= 0" + pre_hook: List = [] + post_hook: List = [] @classmethod def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool: diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 5f3513fbda3..b80e18f6761 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -786,6 +786,7 @@ class SeedNode(ParsedNode): # No SQLDefaults! root_path: Optional[str] = None depends_on: MacroDependsOn = field(default_factory=MacroDependsOn) state_relation: Optional[StateRelation] = None + compiled_code = None def same_seeds(self, other: "SeedNode") -> bool: # for seeds, we check the hashes. If the hashes are different types, diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 1dbf39e01ad..69c010051ab 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -1117,7 +1117,7 @@ def load_macros( root_config: RuntimeConfig, macro_hook: Callable[[Manifest], Any], base_macros_only=False, - ) -> Manifest: + ) -> MacroManifest: with PARSING_STATE: # base_only/base_macros_only: for testing only, # allows loading macros without running 'dbt deps' first diff --git a/third-party-stubs/agate/__init__.pyi b/third-party-stubs/agate/__init__.pyi index c773cc7d7f4..5153c983a34 100644 --- a/third-party-stubs/agate/__init__.pyi +++ b/third-party-stubs/agate/__init__.pyi @@ -1,6 +1,6 @@ from collections.abc import Sequence -from typing import Any, Optional, Callable, Iterable, Dict, Union +from typing import Any, Optional, Callable, Iterable, Dict, Union, Tuple, OrderedDict from . import data_types as data_types from .data_types import ( @@ -52,6 +52,7 @@ class Table: def columns(self): ... @property def rows(self): ... + def aggregate(self, aggregations: Any) -> OrderedDict: ... def print_csv(self, **kwargs: Any) -> None: ... def print_json(self, **kwargs: Any) -> None: ... def where(self, test: Callable[[Row], bool]) -> "Table": ...