Skip to content

Commit

Permalink
remove base reference from QueryBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
toluaina committed Jul 31, 2022
1 parent 8ef67b5 commit 7fc4b6f
Show file tree
Hide file tree
Showing 13 changed files with 75 additions and 86 deletions.
12 changes: 6 additions & 6 deletions pgsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, database: str, verbose: bool = False, *args, **kwargs):
)
self.__schemas: Optional[dict] = None
# models is a dict of f'{schema}.{table}'
self.models: Dict[str] = {}
self.__models: Dict[str] = {}
self.__metadata: Dict[str] = {}
self.__indices: Dict[str] = {}
self.__views: Dict[str] = {}
Expand Down Expand Up @@ -152,7 +152,7 @@ def has_permissions(self, username: str, permissions: List[str]) -> bool:
)

# Tables...
def model(self, table: str, schema: str) -> sa.sql.Alias:
def models(self, table: str, schema: str) -> sa.sql.Alias:
"""Get an SQLAlchemy model representation from a table.
Args:
Expand All @@ -164,7 +164,7 @@ def model(self, table: str, schema: str) -> sa.sql.Alias:
"""
name: str = f"{schema}.{table}"
if name not in self.models:
if name not in self.__models:
if schema not in self.__metadata:
metadata = sa.MetaData(schema=schema)
metadata.reflect(self.engine, views=True)
Expand All @@ -189,9 +189,9 @@ def model(self, table: str, schema: str) -> sa.sql.Alias:
"primary_keys",
sorted([primary_key.key for primary_key in model.primary_key]),
)
self.models[f"{model.original}"] = model
self.__models[f"{model.original}"] = model

return self.models[name]
return self.__models[name]

@property
def conn(self):
Expand Down Expand Up @@ -482,7 +482,7 @@ def create_view(
) -> None:
create_view(
self.engine,
self.model,
self.models,
self.fetchall,
schema,
tables,
Expand Down
19 changes: 7 additions & 12 deletions pgsync/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, List, Optional, Set, Tuple

import sqlalchemy as sa

Expand All @@ -18,7 +18,6 @@
)
from .exc import (
ColumnNotFoundError,
InvalidSchemaError,
MultipleThroughTablesError,
NodeAttributeError,
RelationshipAttributeError,
Expand Down Expand Up @@ -105,10 +104,9 @@ def __str__(self):
@dataclass
class Node(object):

base: "base.Base"
models: Callable
table: str
schema: str
materialized: bool = False
primary_key: Optional[list] = None
label: Optional[str] = None
transform: Optional[dict] = None
Expand All @@ -118,7 +116,7 @@ class Node(object):
base_tables: Optional[list] = None

def __post_init__(self):
self.model: sa.sql.Alias = self.base.model(self.table, self.schema)
self.model: sa.sql.Alias = self.models(self.table, self.schema)
self.columns = self.columns or []
self.children: List[Node] = []
self.table_columns: List[str] = self.model.columns.keys()
Expand Down Expand Up @@ -147,7 +145,7 @@ def __post_init__(self):
for through_table in self.relationship.through_tables:
self.relationship.through_nodes.append(
Node(
base=self.base,
models=self.models,
table=through_table,
schema=self.schema,
parent=self,
Expand Down Expand Up @@ -257,7 +255,8 @@ def traverse_post_order(self) -> Node:

@dataclass
class Tree:
base: "base.Base"

models: Callable

def __post_init__(self):
self.tables: Set[str] = set()
Expand All @@ -272,15 +271,12 @@ def build(self, data: dict) -> Node:
if table is None:
raise TableNotInNodeError(f"Table not specified in node: {data}")

if schema and schema not in self.base.schemas:
raise InvalidSchemaError(f"Unknown schema name(s): {schema}")

if not set(data.keys()).issubset(set(NODE_ATTRIBUTES)):
attrs = set(data.keys()).difference(set(NODE_ATTRIBUTES))
raise NodeAttributeError(f"Unknown node attribute(s): {attrs}")

node = Node(
base=self.base,
models=self.models,
table=table,
schema=schema,
primary_key=data.get("primary_key", []),
Expand All @@ -289,7 +285,6 @@ def build(self, data: dict) -> Node:
columns=data.get("columns", []),
relationship=data.get("relationship", {}),
base_tables=data.get("base_tables", []),
materialized=(table in self.base._materialized_views(schema)),
)

self.tables.add(node.table)
Expand Down
5 changes: 2 additions & 3 deletions pgsync/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

import sqlalchemy as sa

from .base import Base, compiled_query, get_foreign_keys
from .base import compiled_query, get_foreign_keys
from .constants import OBJECT, ONE_TO_MANY, ONE_TO_ONE, SCALAR
from .node import Node


class QueryBuilder(object):
"""Query builder."""

def __init__(self, base: Base, verbose: bool = False):
def __init__(self, verbose: bool = False):
"""Query builder constructor."""
self.base: Base = base
self.verbose: bool = verbose
self.isouter: bool = True

Expand Down
15 changes: 10 additions & 5 deletions pgsync/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .elastichelper import ElasticHelper
from .exc import (
ForeignKeyError,
InvalidSchemaError,
PrimaryKeyNotFoundError,
RDSError,
SchemaError,
Expand Down Expand Up @@ -101,13 +102,11 @@ def __init__(
CHECKPOINT_PATH, f".{self.__name}"
)
self.redis: RedisQueue = RedisQueue(self.__name)
self.tree: Tree = Tree(self)
self.tree: Tree = Tree(self.models)
if validate:
self.validate(repl_slots=repl_slots)
self.create_setting()
self.query_builder: QueryBuilder = QueryBuilder(
self, verbose=self.verbose
)
self.query_builder: QueryBuilder = QueryBuilder(verbose=verbose)
self.count: dict = dict(xlog=0, db=0, redis=0)

def validate(self, repl_slots: bool = True) -> None:
Expand Down Expand Up @@ -187,6 +186,12 @@ def validate(self, repl_slots: bool = True) -> None:
self.root: Node = self.tree.build(self.nodes)
self.root.display()
for node in self.root.traverse_breadth_first():

if node.schema not in self.schemas:
raise InvalidSchemaError(
f"Unknown schema name(s): {node.schema}"
)

# ensure all base tables have at least one primary_key
for table in node.base_tables:
model: sa.sql.Alias = self.model(table, node.schema)
Expand Down Expand Up @@ -1143,7 +1148,7 @@ async def async_refresh_views(self) -> None:
def _refresh_views(self) -> None:
for node in self.root.traverse_breadth_first():
if node.table in self.views(node.schema):
if node.materialized:
if node.table in self.materialized_views(node.schema):
self.refresh_view(node.table, node.schema)

def on_publish(self, payloads: list) -> None:
Expand Down
20 changes: 10 additions & 10 deletions pgsync/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ def compile_drop_index(


def _get_constraints(
model: Callable,
models: Callable,
schema: str,
tables: List[str],
label: str,
constraint_type: str,
) -> sa.sql.Select:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=sa.exc.SAWarning)
table_constraints = model("table_constraints", "information_schema")
key_column_usage = model("key_column_usage", "information_schema")
table_constraints = models("table_constraints", "information_schema")
key_column_usage = models("key_column_usage", "information_schema")
return (
sa.select(
[
Expand Down Expand Up @@ -168,10 +168,10 @@ def _get_constraints(


def _primary_keys(
model: Callable, schema: str, tables: List[str]
models: Callable, schema: str, tables: List[str]
) -> sa.sql.Select:
return _get_constraints(
model,
models,
schema,
tables,
label="primary_keys",
Expand All @@ -180,10 +180,10 @@ def _primary_keys(


def _foreign_keys(
model: Callable, schema: str, tables: List[str]
models: Callable, schema: str, tables: List[str]
) -> sa.sql.Select:
return _get_constraints(
model,
models,
schema,
tables,
label="foreign_keys",
Expand All @@ -193,7 +193,7 @@ def _foreign_keys(

def create_view(
engine: sa.engine.Engine,
model: Callable,
models: Callable,
fetchall: Callable,
schema: str,
tables: list,
Expand Down Expand Up @@ -243,15 +243,15 @@ def create_view(
for table in set(tables):
tables.add(f"{schema}.{table}")

for table_name, columns in fetchall(_primary_keys(model, schema, tables)):
for table_name, columns in fetchall(_primary_keys(models, schema, tables)):
rows.setdefault(
table_name,
{"primary_keys": set(), "foreign_keys": set()},
)
if columns:
rows[table_name]["primary_keys"] |= set(columns)

for table_name, columns in fetchall(_foreign_keys(model, schema, tables)):
for table_name, columns in fetchall(_foreign_keys(models, schema, tables)):
rows.setdefault(
table_name,
{"primary_keys": set(), "foreign_keys": set()},
Expand Down
5 changes: 2 additions & 3 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def test_has_permissions(self, connection):

def test_model(self, connection):
pg_base = Base(connection.engine.url.database)
model = pg_base.model("book", "public")
model = pg_base.models("book", "public")
assert str(model.original) == "public.book"
assert pg_base.models["public.book"] == model
with pytest.raises(TableNotFoundError) as excinfo:
pg_base.model("book", "bar")
pg_base.models("book", "bar")
assert 'Table "bar.book" not found in registry' in str(
excinfo.value
)
Expand Down
29 changes: 10 additions & 19 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from pgsync.base import Base
from pgsync.exc import (
InvalidSchemaError,
MultipleThroughTablesError,
NodeAttributeError,
RelationshipAttributeError,
Expand Down Expand Up @@ -114,15 +113,15 @@ def nodes(self):
def test_node(self, connection):
pg_base = Base(connection.engine.url.database)
node = Node(
base=pg_base,
models=pg_base.models,
table="book",
schema="public",
label="book_label",
)
assert str(node) == "Node: public.book_label"

def test_traverse_breadth_first(self, sync, nodes):
root = Tree(sync).build(nodes)
root = Tree(sync.models).build(nodes)
root.display()
for i, node in enumerate(root.traverse_breadth_first()):
if i == 0:
Expand All @@ -146,7 +145,7 @@ def test_traverse_breadth_first(self, sync, nodes):
sync.es.close()

def test_traverse_post_order(self, sync, nodes):
root = Tree(sync).build(nodes)
root = Tree(sync.models).build(nodes)
root.display()
for i, node in enumerate(root.traverse_post_order()):
if i == 0:
Expand Down Expand Up @@ -183,7 +182,7 @@ def test_relationship(self, sync):
],
}
with pytest.raises(RelationshipAttributeError) as excinfo:
Tree(sync).build(nodes)
Tree(sync.models).build(nodes)
assert "Relationship attribute " in str(excinfo.value)
sync.es.close()

Expand All @@ -200,7 +199,7 @@ def test_get_node(self, sync):
},
],
}
tree = Tree(sync)
tree = Tree(sync.models)
root: Node = tree.build(nodes)
node = tree.get_node(root, "book", "public")
assert str(node) == "Node: public.book"
Expand All @@ -212,24 +211,16 @@ def test_get_node(self, sync):
sync.es.close()

def test_tree_build(self, sync):
with pytest.raises(InvalidSchemaError) as excinfo:
Tree(sync).build(
{
"table": "book",
"schema": "bar",
}
)
assert "Unknown schema name(s)" in str(excinfo.value)

with pytest.raises(TableNotInNodeError) as excinfo:
Tree(sync).build(
Tree(sync.models).build(
{
"table": None,
}
)

with pytest.raises(NodeAttributeError) as excinfo:
Tree(sync).build(
Tree(sync.models).build(
{
"table": "book",
"foo": "bar",
Expand All @@ -238,7 +229,7 @@ def test_tree_build(self, sync):
assert "Unknown node attribute(s):" in str(excinfo.value)

with pytest.raises(NodeAttributeError) as excinfo:
Tree(sync).build(
Tree(sync.models).build(
{
"table": "book",
"children": [
Expand All @@ -257,7 +248,7 @@ def test_tree_build(self, sync):
assert "Unknown node attribute(s):" in str(excinfo.value)

with pytest.raises(TableNotInNodeError) as excinfo:
Tree(sync).build(
Tree(sync.models).build(
{
"table": "book",
"children": [
Expand All @@ -273,7 +264,7 @@ def test_tree_build(self, sync):
)
assert "Table not specified in node" in str(excinfo.value)

Tree(sync).build(
Tree(sync.models).build(
{
"table": "book",
"columns": ["tags->0"],
Expand Down
Loading

0 comments on commit 7fc4b6f

Please sign in to comment.