From 7fc4b6f689a49c3ec1d9322b1f8cfdf203ce730d Mon Sep 17 00:00:00 2001 From: Tolu Aina <7848930+toluaina@users.noreply.github.com> Date: Sun, 31 Jul 2022 19:52:37 +0100 Subject: [PATCH] remove base reference from QueryBuilder --- pgsync/base.py | 12 ++++---- pgsync/node.py | 19 +++++-------- pgsync/querybuilder.py | 5 ++-- pgsync/sync.py | 15 ++++++---- pgsync/view.py | 20 +++++++------- tests/test_base.py | 5 ++-- tests/test_node.py | 29 +++++++------------- tests/test_query_builder.py | 16 +++++------ tests/test_sync.py | 14 +++++----- tests/test_sync_single_child_fk_on_child.py | 6 ++-- tests/test_sync_single_child_fk_on_parent.py | 6 ++-- tests/test_utils.py | 4 +-- tests/test_view.py | 10 +++---- 13 files changed, 75 insertions(+), 86 deletions(-) diff --git a/pgsync/base.py b/pgsync/base.py index 23716526..9e0bb12a 100644 --- a/pgsync/base.py +++ b/pgsync/base.py @@ -78,7 +78,7 @@ def __init__(self, database: str, verbose: bool = False, *args, **kwargs): ) self.__schemas: Optional[dict] = None # models is a dict of f'{schema}.{table}' - self.models: Dict[str] = {} + self.__models: Dict[str] = {} self.__metadata: Dict[str] = {} self.__indices: Dict[str] = {} self.__views: Dict[str] = {} @@ -152,7 +152,7 @@ def has_permissions(self, username: str, permissions: List[str]) -> bool: ) # Tables... - def model(self, table: str, schema: str) -> sa.sql.Alias: + def models(self, table: str, schema: str) -> sa.sql.Alias: """Get an SQLAlchemy model representation from a table. Args: @@ -164,7 +164,7 @@ def model(self, table: str, schema: str) -> sa.sql.Alias: """ name: str = f"{schema}.{table}" - if name not in self.models: + if name not in self.__models: if schema not in self.__metadata: metadata = sa.MetaData(schema=schema) metadata.reflect(self.engine, views=True) @@ -189,9 +189,9 @@ def model(self, table: str, schema: str) -> sa.sql.Alias: "primary_keys", sorted([primary_key.key for primary_key in model.primary_key]), ) - self.models[f"{model.original}"] = model + self.__models[f"{model.original}"] = model - return self.models[name] + return self.__models[name] @property def conn(self): @@ -482,7 +482,7 @@ def create_view( ) -> None: create_view( self.engine, - self.model, + self.models, self.fetchall, schema, tables, diff --git a/pgsync/node.py b/pgsync/node.py index f817baed..cf726aba 100644 --- a/pgsync/node.py +++ b/pgsync/node.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple +from typing import Callable, Dict, List, Optional, Set, Tuple import sqlalchemy as sa @@ -18,7 +18,6 @@ ) from .exc import ( ColumnNotFoundError, - InvalidSchemaError, MultipleThroughTablesError, NodeAttributeError, RelationshipAttributeError, @@ -105,10 +104,9 @@ def __str__(self): @dataclass class Node(object): - base: "base.Base" + models: Callable table: str schema: str - materialized: bool = False primary_key: Optional[list] = None label: Optional[str] = None transform: Optional[dict] = None @@ -118,7 +116,7 @@ class Node(object): base_tables: Optional[list] = None def __post_init__(self): - self.model: sa.sql.Alias = self.base.model(self.table, self.schema) + self.model: sa.sql.Alias = self.models(self.table, self.schema) self.columns = self.columns or [] self.children: List[Node] = [] self.table_columns: List[str] = self.model.columns.keys() @@ -147,7 +145,7 @@ def __post_init__(self): for through_table in self.relationship.through_tables: self.relationship.through_nodes.append( Node( - base=self.base, + models=self.models, table=through_table, schema=self.schema, parent=self, @@ -257,7 +255,8 @@ def traverse_post_order(self) -> Node: @dataclass class Tree: - base: "base.Base" + + models: Callable def __post_init__(self): self.tables: Set[str] = set() @@ -272,15 +271,12 @@ def build(self, data: dict) -> Node: if table is None: raise TableNotInNodeError(f"Table not specified in node: {data}") - if schema and schema not in self.base.schemas: - raise InvalidSchemaError(f"Unknown schema name(s): {schema}") - if not set(data.keys()).issubset(set(NODE_ATTRIBUTES)): attrs = set(data.keys()).difference(set(NODE_ATTRIBUTES)) raise NodeAttributeError(f"Unknown node attribute(s): {attrs}") node = Node( - base=self.base, + models=self.models, table=table, schema=schema, primary_key=data.get("primary_key", []), @@ -289,7 +285,6 @@ def build(self, data: dict) -> Node: columns=data.get("columns", []), relationship=data.get("relationship", {}), base_tables=data.get("base_tables", []), - materialized=(table in self.base._materialized_views(schema)), ) self.tables.add(node.table) diff --git a/pgsync/querybuilder.py b/pgsync/querybuilder.py index c2252b04..acb0efdc 100644 --- a/pgsync/querybuilder.py +++ b/pgsync/querybuilder.py @@ -3,7 +3,7 @@ import sqlalchemy as sa -from .base import Base, compiled_query, get_foreign_keys +from .base import compiled_query, get_foreign_keys from .constants import OBJECT, ONE_TO_MANY, ONE_TO_ONE, SCALAR from .node import Node @@ -11,9 +11,8 @@ class QueryBuilder(object): """Query builder.""" - def __init__(self, base: Base, verbose: bool = False): + def __init__(self, verbose: bool = False): """Query builder constructor.""" - self.base: Base = base self.verbose: bool = verbose self.isouter: bool = True diff --git a/pgsync/sync.py b/pgsync/sync.py index 033afc2e..906fbea8 100644 --- a/pgsync/sync.py +++ b/pgsync/sync.py @@ -34,6 +34,7 @@ from .elastichelper import ElasticHelper from .exc import ( ForeignKeyError, + InvalidSchemaError, PrimaryKeyNotFoundError, RDSError, SchemaError, @@ -101,13 +102,11 @@ def __init__( CHECKPOINT_PATH, f".{self.__name}" ) self.redis: RedisQueue = RedisQueue(self.__name) - self.tree: Tree = Tree(self) + self.tree: Tree = Tree(self.models) if validate: self.validate(repl_slots=repl_slots) self.create_setting() - self.query_builder: QueryBuilder = QueryBuilder( - self, verbose=self.verbose - ) + self.query_builder: QueryBuilder = QueryBuilder(verbose=verbose) self.count: dict = dict(xlog=0, db=0, redis=0) def validate(self, repl_slots: bool = True) -> None: @@ -187,6 +186,12 @@ def validate(self, repl_slots: bool = True) -> None: self.root: Node = self.tree.build(self.nodes) self.root.display() for node in self.root.traverse_breadth_first(): + + if node.schema not in self.schemas: + raise InvalidSchemaError( + f"Unknown schema name(s): {node.schema}" + ) + # ensure all base tables have at least one primary_key for table in node.base_tables: model: sa.sql.Alias = self.model(table, node.schema) @@ -1143,7 +1148,7 @@ async def async_refresh_views(self) -> None: def _refresh_views(self) -> None: for node in self.root.traverse_breadth_first(): if node.table in self.views(node.schema): - if node.materialized: + if node.table in self.materialized_views(node.schema): self.refresh_view(node.table, node.schema) def on_publish(self, payloads: list) -> None: diff --git a/pgsync/view.py b/pgsync/view.py index 64280342..51fb59b8 100644 --- a/pgsync/view.py +++ b/pgsync/view.py @@ -125,7 +125,7 @@ def compile_drop_index( def _get_constraints( - model: Callable, + models: Callable, schema: str, tables: List[str], label: str, @@ -133,8 +133,8 @@ def _get_constraints( ) -> sa.sql.Select: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=sa.exc.SAWarning) - table_constraints = model("table_constraints", "information_schema") - key_column_usage = model("key_column_usage", "information_schema") + table_constraints = models("table_constraints", "information_schema") + key_column_usage = models("key_column_usage", "information_schema") return ( sa.select( [ @@ -168,10 +168,10 @@ def _get_constraints( def _primary_keys( - model: Callable, schema: str, tables: List[str] + models: Callable, schema: str, tables: List[str] ) -> sa.sql.Select: return _get_constraints( - model, + models, schema, tables, label="primary_keys", @@ -180,10 +180,10 @@ def _primary_keys( def _foreign_keys( - model: Callable, schema: str, tables: List[str] + models: Callable, schema: str, tables: List[str] ) -> sa.sql.Select: return _get_constraints( - model, + models, schema, tables, label="foreign_keys", @@ -193,7 +193,7 @@ def _foreign_keys( def create_view( engine: sa.engine.Engine, - model: Callable, + models: Callable, fetchall: Callable, schema: str, tables: list, @@ -243,7 +243,7 @@ def create_view( for table in set(tables): tables.add(f"{schema}.{table}") - for table_name, columns in fetchall(_primary_keys(model, schema, tables)): + for table_name, columns in fetchall(_primary_keys(models, schema, tables)): rows.setdefault( table_name, {"primary_keys": set(), "foreign_keys": set()}, @@ -251,7 +251,7 @@ def create_view( if columns: rows[table_name]["primary_keys"] |= set(columns) - for table_name, columns in fetchall(_foreign_keys(model, schema, tables)): + for table_name, columns in fetchall(_foreign_keys(models, schema, tables)): rows.setdefault( table_name, {"primary_keys": set(), "foreign_keys": set()}, diff --git a/tests/test_base.py b/tests/test_base.py index c3199b39..4a88847b 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -61,11 +61,10 @@ def test_has_permissions(self, connection): def test_model(self, connection): pg_base = Base(connection.engine.url.database) - model = pg_base.model("book", "public") + model = pg_base.models("book", "public") assert str(model.original) == "public.book" - assert pg_base.models["public.book"] == model with pytest.raises(TableNotFoundError) as excinfo: - pg_base.model("book", "bar") + pg_base.models("book", "bar") assert 'Table "bar.book" not found in registry' in str( excinfo.value ) diff --git a/tests/test_node.py b/tests/test_node.py index 707eb7bc..0ea61c05 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -3,7 +3,6 @@ from pgsync.base import Base from pgsync.exc import ( - InvalidSchemaError, MultipleThroughTablesError, NodeAttributeError, RelationshipAttributeError, @@ -114,7 +113,7 @@ def nodes(self): def test_node(self, connection): pg_base = Base(connection.engine.url.database) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", label="book_label", @@ -122,7 +121,7 @@ def test_node(self, connection): assert str(node) == "Node: public.book_label" def test_traverse_breadth_first(self, sync, nodes): - root = Tree(sync).build(nodes) + root = Tree(sync.models).build(nodes) root.display() for i, node in enumerate(root.traverse_breadth_first()): if i == 0: @@ -146,7 +145,7 @@ def test_traverse_breadth_first(self, sync, nodes): sync.es.close() def test_traverse_post_order(self, sync, nodes): - root = Tree(sync).build(nodes) + root = Tree(sync.models).build(nodes) root.display() for i, node in enumerate(root.traverse_post_order()): if i == 0: @@ -183,7 +182,7 @@ def test_relationship(self, sync): ], } with pytest.raises(RelationshipAttributeError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert "Relationship attribute " in str(excinfo.value) sync.es.close() @@ -200,7 +199,7 @@ def test_get_node(self, sync): }, ], } - tree = Tree(sync) + tree = Tree(sync.models) root: Node = tree.build(nodes) node = tree.get_node(root, "book", "public") assert str(node) == "Node: public.book" @@ -212,24 +211,16 @@ def test_get_node(self, sync): sync.es.close() def test_tree_build(self, sync): - with pytest.raises(InvalidSchemaError) as excinfo: - Tree(sync).build( - { - "table": "book", - "schema": "bar", - } - ) - assert "Unknown schema name(s)" in str(excinfo.value) with pytest.raises(TableNotInNodeError) as excinfo: - Tree(sync).build( + Tree(sync.models).build( { "table": None, } ) with pytest.raises(NodeAttributeError) as excinfo: - Tree(sync).build( + Tree(sync.models).build( { "table": "book", "foo": "bar", @@ -238,7 +229,7 @@ def test_tree_build(self, sync): assert "Unknown node attribute(s):" in str(excinfo.value) with pytest.raises(NodeAttributeError) as excinfo: - Tree(sync).build( + Tree(sync.models).build( { "table": "book", "children": [ @@ -257,7 +248,7 @@ def test_tree_build(self, sync): assert "Unknown node attribute(s):" in str(excinfo.value) with pytest.raises(TableNotInNodeError) as excinfo: - Tree(sync).build( + Tree(sync.models).build( { "table": "book", "children": [ @@ -273,7 +264,7 @@ def test_tree_build(self, sync): ) assert "Table not specified in node" in str(excinfo.value) - Tree(sync).build( + Tree(sync.models).build( { "table": "book", "columns": ["tags->0"], diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index 23adf59b..faa2fe12 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -13,13 +13,13 @@ class TestQueryBuilder(object): def test__json_build_object(self, connection): pg_base = Base(connection.engine.url.database) - query_builder = QueryBuilder(pg_base) + query_builder = QueryBuilder() with pytest.raises(RuntimeError) as excinfo: query_builder._json_build_object([]) assert "invalid expression" == str(excinfo.value) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) @@ -55,9 +55,9 @@ def test__json_build_object(self, connection): def test__get_foreign_keys(self, connection): pg_base = Base(connection.engine.url.database) - query_builder = QueryBuilder(pg_base) + query_builder = QueryBuilder() book = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) @@ -70,7 +70,7 @@ def test__get_foreign_keys(self, connection): ) assert expected in str(excinfo.value) publisher = Node( - base=pg_base, + models=pg_base.models, table="publisher", schema="public", ) @@ -81,7 +81,7 @@ def test__get_foreign_keys(self, connection): } subject = Node( - base=pg_base, + models=pg_base.models, table="subject", schema="public", relationship={ @@ -99,7 +99,7 @@ def test__get_foreign_keys(self, connection): def test__get_column_foreign_keys(self, connection): pg_base = Base(connection.engine.url.database) - query_builder = QueryBuilder(pg_base) + query_builder = QueryBuilder() foreign_keys = { "public.subject": ["column_a", "column_b", "column_X"], @@ -107,7 +107,7 @@ def test__get_column_foreign_keys(self, connection): } subject = Node( - base=pg_base, + models=pg_base.models, table="subject", schema="public", relationship={ diff --git a/tests/test_sync.py b/tests/test_sync.py index ec9bc5a4..610836bc 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -443,7 +443,7 @@ def test_sync_analyze(self, sync): def test__update_op(self, sync, connection): pg_base = Base(connection.engine.url.database) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) @@ -468,7 +468,7 @@ def test__update_op(self, sync, connection): def test__insert_op(self, sync, connection): pg_base = Base(connection.engine.url.database) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) @@ -488,7 +488,7 @@ def test__insert_op(self, sync, connection): assert len(docs["hits"]["hits"]) == 0 node = Node( - base=pg_base, + models=pg_base.models, table="publisher", schema="public", ) @@ -511,7 +511,7 @@ def test__insert_op(self, sync, connection): def test__delete_op(self, mock_es, sync, connection): pg_base = Base(connection.engine.url.database) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) @@ -537,7 +537,7 @@ def test__delete_op(self, mock_es, sync, connection): def test__truncate_op(self, mock_es, sync, connection): pg_base = Base(connection.engine.url.database) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) @@ -547,7 +547,7 @@ def test__truncate_op(self, mock_es, sync, connection): # truncate a non root table node = Node( - base=pg_base, + models=pg_base.models, table="publisher", schema="public", ) @@ -652,7 +652,7 @@ def test__payload(self, sync): def test__build_filters(self, sync, connection): pg_base = Base(connection.engine.url.database) node = Node( - base=pg_base, + models=pg_base.models, table="book", schema="public", ) diff --git a/tests/test_sync_single_child_fk_on_child.py b/tests/test_sync_single_child_fk_on_child.py index 0abd6206..18a5001e 100644 --- a/tests/test_sync_single_child_fk_on_child.py +++ b/tests/test_sync_single_child_fk_on_child.py @@ -486,7 +486,7 @@ def test_invalid_relationship_type(self, sync): } sync.es.close() with pytest.raises(RelationshipTypeError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert 'Relationship type "qwerty" is invalid' in str(excinfo.value) def test_invalid_relationship_variant(self, sync): @@ -504,7 +504,7 @@ def test_invalid_relationship_variant(self, sync): } sync.es.close() with pytest.raises(RelationshipVariantError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert 'Relationship variant "abcdefg" is invalid' in str( excinfo.value ) @@ -521,7 +521,7 @@ def test_invalid_relationship_attribute(self, sync): } sync.es.close() with pytest.raises(RelationshipAttributeError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert f"Relationship attribute {set(['foo'])} is invalid" in str( excinfo.value ) diff --git a/tests/test_sync_single_child_fk_on_parent.py b/tests/test_sync_single_child_fk_on_parent.py index ac00b9e0..2fb74e51 100644 --- a/tests/test_sync_single_child_fk_on_parent.py +++ b/tests/test_sync_single_child_fk_on_parent.py @@ -490,7 +490,7 @@ def test_invalid_relationship_type(self, sync): sync.es.close() sync.tree.__post_init__() with pytest.raises(RelationshipTypeError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert 'Relationship type "qwerty" is invalid' in str(excinfo.value) def test_invalid_relationship_variant(self, sync): @@ -509,7 +509,7 @@ def test_invalid_relationship_variant(self, sync): sync.es.close() sync.tree.__post_init__() with pytest.raises(RelationshipVariantError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert 'Relationship variant "abcdefg" is invalid' in str( excinfo.value ) @@ -527,7 +527,7 @@ def test_invalid_relationship_attribute(self, sync): sync.es.close() sync.tree.__post_init__() with pytest.raises(RelationshipAttributeError) as excinfo: - Tree(sync).build(nodes) + Tree(sync.models).build(nodes) assert f"Relationship attribute {set(['foo'])} is invalid" in str( excinfo.value ) diff --git a/tests/test_utils.py b/tests/test_utils.py index eabdd264..cd8e7078 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -97,7 +97,7 @@ def test_compiled_query_with_label( self, mock_logger, mock_sys, connection ): pg_base = Base(connection.engine.url.database) - model = pg_base.model("book", "public") + model = pg_base.models("book", "public") statement = sa.select([model.c.isbn]).select_from(model) compiled_query(statement, label="foo", literal_binds=True) mock_logger.debug.assert_called_once_with( @@ -112,7 +112,7 @@ def test_compiled_query_without_label( self, mock_logger, mock_sys, connection ): pg_base = Base(connection.engine.url.database) - model = pg_base.model("book", "public") + model = pg_base.models("book", "public") statement = sa.select([model.c.isbn]).select_from(model) compiled_query(statement, literal_binds=True) mock_logger.debug.assert_called_once_with( diff --git a/tests/test_view.py b/tests/test_view.py index 755c7f8c..e6106db5 100644 --- a/tests/test_view.py +++ b/tests/test_view.py @@ -134,7 +134,7 @@ def test_refresh_view(self, connection, sync, book_cls, data): """Test refresh materialized view.""" view = "test_view_refresh" pg_base = Base(connection.engine.url.database) - model = pg_base.model("book", "public") + model = pg_base.models("book", "public") statement = sa.select([model.c.isbn]).select_from(model) connection.engine.execute( CreateView(DEFAULT_SCHEMA, view, statement, materialized=True) @@ -207,7 +207,7 @@ def test_index(self, connection, sync, book_cls, data): @pytest.mark.usefixtures("table_creator") def test_primary_keys(self, connection, book_cls): pg_base = Base(connection.engine.url.database) - statement = _primary_keys(pg_base.model, DEFAULT_SCHEMA, ["book"]) + statement = _primary_keys(pg_base.models, DEFAULT_SCHEMA, ["book"]) rows = connection.execute(statement).fetchall() assert rows == [("book", ["isbn"])] @@ -215,7 +215,7 @@ def test_primary_keys(self, connection, book_cls): def test_foreign_keys(self, connection, book_cls): pg_base = Base(connection.engine.url.database) statement = _foreign_keys( - pg_base.model, DEFAULT_SCHEMA, ["book", "publisher"] + pg_base.models, DEFAULT_SCHEMA, ["book", "publisher"] ) rows = connection.execute(statement).fetchall() assert rows[0][0] == "book" @@ -233,7 +233,7 @@ def fetchall(statement): with patch("pgsync.view.logger") as mock_logger: create_view( connection.engine, - pg_base.model, + pg_base.models, fetchall, DEFAULT_SCHEMA, ["book", "publisher"], @@ -252,7 +252,7 @@ def fetchall(statement): with patch("pgsync.view.logger") as mock_logger: create_view( connection.engine, - pg_base.model, + pg_base.models, fetchall, "myschema", set(["book", "publisher"]),