Skip to content

Commit

Permalink
perf: optimise table metadata build
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesmeneghello committed Nov 4, 2024
1 parent 6efa7d0 commit c641747
Showing 1 changed file with 136 additions and 1 deletion.
137 changes: 136 additions & 1 deletion tap_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import sqlalchemy.types
from psycopg2 import extras
from singer_sdk import SQLConnector, SQLStream
from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
from singer_sdk.connectors.sql import SQLToJSONSchema
from singer_sdk.helpers._state import increment_state
from singer_sdk.helpers._typing import TypeConformanceLevel
Expand All @@ -31,7 +33,12 @@
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import Engine
from sqlalchemy.engine.reflection import Inspector

from sqlalchemy.types import (
ReflectedColumn,
ReflectedIndex,
ReflectedPrimaryKeyConstraint,
TableKey,
)

class PostgresSQLToJSONSchema(SQLToJSONSchema):
"""Custom SQL to JSON Schema conversion for Postgres."""
Expand Down Expand Up @@ -181,6 +188,134 @@ def get_schema_names(self, engine: Engine, inspected: Inspector) -> list[str]:
return self.config["filter_schemas"]
return super().get_schema_names(engine, inspected)

# Uses information_schema for speed.
def discover_catalog_entry( # noqa: PLR0913
self,
engine: Engine,
inspected: Inspector,
schema_name: str,
table_name: str,
is_view: bool,
table_data: dict[TableKey, list[ReflectedColumn]],
pk_data: dict[TableKey, ReflectedPrimaryKeyConstraint],
index_data: dict[TableKey, list[ReflectedIndex]]
) -> CatalogEntry:
"""Create `CatalogEntry` object for the given table or a view.
Args:
engine: SQLAlchemy engine
inspected: SQLAlchemy inspector instance for engine
schema_name: Schema name to inspect
table_name: Name of the table or a view
is_view: Flag whether this object is a view, returned by `get_object_names`
table_data: Cached inspector data for the relevant tables
pk_data: Cached inspector data for the relevant primary keys
index_data: Cached inspector data for the relevant indexes
Returns:
`CatalogEntry` object for the given table or a view
"""
# Initialize unique stream name
unique_stream_id = f"{schema_name}-{table_name}"
table_key = (schema_name, table_name)

# Detect key properties
possible_primary_keys: list[list[str]] = []
pk_def = pk_data.get(table_key, {})
if pk_def and "constrained_columns" in pk_def: # type: ignore[redundant-expr]
possible_primary_keys.append(pk_def["constrained_columns"])

# An element of the columns list is ``None`` if it's an expression and is
# returned in the ``expressions`` list of the reflected index.
possible_primary_keys.extend(
index_def["column_names"] # type: ignore[misc]
for index_def in index_data.get(table_key, [])
if index_def.get("unique", False)
)

key_properties = next(iter(possible_primary_keys), None)

# Initialize columns list
table_schema = th.PropertiesList()
for column_def in table_data.get(table_key, []):
column_name = column_def["name"]
is_nullable = column_def.get("nullable", False)
jsonschema_type: dict = self.to_jsonschema_type(column_def["type"])
table_schema.append(
th.Property(
name=column_name,
wrapped=th.CustomType(jsonschema_type),
nullable=is_nullable,
required=column_name in key_properties if key_properties else False,
),
)
schema = table_schema.to_dict()

# Initialize available replication methods
addl_replication_methods: list[str] = [""] # By default an empty list.
# Notes regarding replication methods:
# - 'INCREMENTAL' replication must be enabled by the user by specifying
# a replication_key value.
# - 'LOG_BASED' replication must be enabled by the developer, according
# to source-specific implementation capabilities.
replication_method = next(reversed(["FULL_TABLE", *addl_replication_methods]))

# Create the catalog entry object
return CatalogEntry(
tap_stream_id=unique_stream_id,
stream=unique_stream_id,
table=table_name,
key_properties=key_properties,
schema=Schema.from_dict(schema),
is_view=is_view,
replication_method=replication_method,
metadata=MetadataMapping.get_standard_metadata(
schema_name=schema_name,
schema=schema,
replication_method=replication_method,
key_properties=key_properties,
valid_replication_keys=None, # Must be defined by user
),
database=None, # Expects single-database context
row_count=None,
stream_alias=None,
replication_key=None, # Must be defined by user
)

def discover_catalog_entries(self) -> list[dict]:
"""Return a list of catalog entries from discovery.
Returns:
The discovered catalog entries as a list.
"""
result: list[dict] = []
engine = self._engine
inspected = sa.inspect(engine)
for schema_name in self.get_schema_names(engine, inspected):
# Use get_multi_* data here instead of pulling per-table
table_data = inspected.get_multi_columns(schema=schema_name)
pk_data = inspected.get_multi_pk_constraint(schema=schema_name)
index_data = inspected.get_multi_indexes(schema=schema_name)

# Iterate through each table and view
for table_name, is_view in self.get_object_names(
engine,
inspected,
schema_name,
):
catalog_entry = self.discover_catalog_entry(
engine,
inspected,
schema_name,
table_name,
is_view,
table_data,
pk_data,
index_data
)
result.append(catalog_entry.to_dict())

return result

class PostgresStream(SQLStream):
"""Stream class for Postgres streams."""
Expand Down

0 comments on commit c641747

Please sign in to comment.