diff --git a/pyproject.toml b/pyproject.toml index 3240b861..71ef119c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,44 +23,12 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "dbt-adapters", - "psycopg2~=2.9", - # installed via dbt-adapters but used directly, unpin minor to avoid version conflicts + "dbt-adapters>=0.1.0a6,<0.2.0", + "psycopg2>=2.9,<3.0", + # installed via dbt-adapters but used directly "dbt-common<1.0", - "agate<2.0", + "agate>=1.0,<2.0", ] -[project.optional-dependencies] -dev = [ - "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git", -] -lint = [ - "black", - "flake8", - "Flake8-pyproject", -] -typecheck = [ - "mypy", - "types-protobuf", - "types-pytz", -] -test = [ - # TODO: remove `dbt-core` dependencies from unit tests - "dbt-core @ git+https://github.com/dbt-labs/dbt-core.git#subdirectory=core", - "freezegun", - "pytest", - "pytest-dotenv", - "pytest-mock", - "pytest-xdist", -] -integration = [ - "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter", -] -build = [ - "wheel", - "twine", - "check-wheel-contents", -] - [project.urls] Homepage = "https://github.com/dbt-labs/dbt-postgres" Documentation = "https://docs.getdbt.com" @@ -72,11 +40,6 @@ Changelog = "https://github.com/dbt-labs/dbt-postgres/blob/main/CHANGELOG.md" requires = ["hatchling"] build-backend = "hatchling.build" -# TODO: this is needed to install from github in optoinal-dependencies -# alternatively, we can stick the github dependencies directly in the hatch envs -[tool.hatch.metadata] -allow-direct-references = true - [tool.hatch.build.targets.sdist] include = ["dbt"] @@ -87,44 +50,68 @@ packages = ["dbt"] path = "dbt/adapters/postgres/__version__.py" [tool.hatch.envs.default] -features = [ - "lint", - "typecheck", - "test", - "integration", - "build", +dependencies = [ + "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git", + "dbt_common @ git+https://github.com/dbt-labs/dbt-common.git", ] [tool.hatch.envs.lint] detached = true -features = ["lint"] +dependencies = [ + "black", + "flake8", + "Flake8-pyproject", +] [tool.hatch.envs.lint.scripts] -all = ["black", "flake8"] +all = [ + "black", + "flake8", +] black = "python -m black ." flake8 = "python -m flake8 ." [tool.hatch.envs.typecheck] -features = ["typecheck"] +dependencies = [ + "mypy", + "types-protobuf", + "types-pytz", +] [tool.hatch.envs.typecheck.scripts] all = "python -m mypy ." [tool.hatch.envs.unit-tests] -# TODO: confirm this works for production testing or add appropriate hatch envs -features = ["dev", "test"] +dependencies = [ + # TODO: remove `dbt-core` dependencies from unit tests + "dbt-core @ git+https://github.com/dbt-labs/dbt-core.git#subdirectory=core", + "freezegun", + "pytest", + "pytest-dotenv", + "pytest-mock", + "pytest-xdist", +] [tool.hatch.envs.unit-tests.scripts] all = "python -m pytest {args:tests/unit}" [tool.hatch.envs.integration-tests] -# TODO: confirm this works for production testing or add appropriate hatch envs -features = ["dev", "test", "integration"] +template = "unit-tests" +extra-dependencies = [ + "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter", +] [tool.hatch.envs.integration-tests.scripts] all = "python -m pytest {args:tests/functional}" [tool.hatch.envs.build] detached = true -features = ["build"] +dependencies = [ + "wheel", + "twine", + "check-wheel-contents", +] [tool.hatch.envs.build.scripts] -check-all = ["- check-wheel", "- check-sdist"] +check-all = [ + "- check-wheel", + "- check-sdist", +] check-wheel = [ "twine check dist/*", "find ./dist/dbt_postgres-*.whl -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/", diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index d73ed54c..1c375ac0 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1,6 +1,11 @@ +import dataclasses from multiprocessing import get_context from unittest import TestCase, mock +import agate +from dbt.adapters.base import BaseRelation +from dbt.adapters.contracts.relation import Path +from dbt_common.context import set_invocation_context from dbt_common.exceptions import DbtValidationError from dbt.adapters.postgres import Plugin as PostgresPlugin, PostgresAdapter @@ -302,3 +307,46 @@ def test_set_zero_keepalive(self, psycopg2): connect_timeout=10, application_name="dbt", ) + + @mock.patch.object(PostgresAdapter, "execute_macro") + @mock.patch.object(PostgresAdapter, "_get_catalog_relations") + def test_get_catalog_various_schemas(self, mock_get_relations, mock_execute): + self.catalog_test(mock_get_relations, mock_execute, False) + + @mock.patch.object(PostgresAdapter, "execute_macro") + @mock.patch.object(PostgresAdapter, "_get_catalog_relations") + def test_get_filtered_catalog(self, mock_get_relations, mock_execute): + self.catalog_test(mock_get_relations, mock_execute, True) + + def catalog_test(self, mock_get_relations, mock_execute, filtered=False): + column_names = ["table_database", "table_schema", "table_name"] + relations = [ + BaseRelation(path=Path(database="dbt", schema="foo", identifier="bar")), + BaseRelation(path=Path(database="dbt", schema="FOO", identifier="baz")), + BaseRelation(path=Path(database="dbt", schema=None, identifier="bar")), + BaseRelation(path=Path(database="dbt", schema="quux", identifier="bar")), + BaseRelation(path=Path(database="dbt", schema="skip", identifier="bar")), + ] + rows = list(map(lambda x: dataclasses.astuple(x.path), relations)) + mock_execute.return_value = agate.Table(rows=rows, column_names=column_names) + + mock_get_relations.return_value = relations + + relation_configs = [] + used_schemas = {("dbt", "foo"), ("dbt", "quux")} + + set_invocation_context({}) + if filtered: + catalog, exceptions = self.adapter.get_filtered_catalog( + relation_configs, used_schemas, set([relations[0], relations[3]]) + ) + else: + catalog, exceptions = self.adapter.get_catalog(relation_configs, used_schemas) + + tupled_catalog = set(map(tuple, catalog)) + if filtered: + self.assertEqual(tupled_catalog, {rows[0], rows[3]}) + else: + self.assertEqual(tupled_catalog, {rows[0], rows[1], rows[3]}) + + self.assertEqual(exceptions, [])