Skip to content

Commit

Permalink
Merge pull request #54 from aiven/ettanany-support-excluding-db-exten…
Browse files Browse the repository at this point in the history
…sions

Support excluding database extensions

#54
  • Loading branch information
packi authored Aug 15, 2024
2 parents 14bbc70 + 4afdaa1 commit eea90a4
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
34 changes: 34 additions & 0 deletions aiven_db_migrate/migrate/pgmigrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
conn_info: Union[str, Dict[str, Any]],
filtered_db: Optional[str] = None,
excluded_roles: Optional[str] = None,
excluded_extensions: Optional[str] = None,
mangle: bool = False,
):
self.log = logging.getLogger(self.__class__.__name__)
Expand All @@ -130,6 +131,7 @@ def __init__(
self._pg_roles = dict()
self.filtered_db = filtered_db.split(",") if filtered_db else []
self.excluded_roles = excluded_roles.split(",") if excluded_roles else []
self.excluded_extensions = excluded_extensions.split(",") if excluded_extensions else []
if "application_name" not in self.conn_info:
self.conn_info["application_name"] = f"aiven-db-migrate/{__version__}"
self._mangle = mangle
Expand Down Expand Up @@ -797,6 +799,7 @@ def __init__(
mangle: bool = False,
filtered_db: Optional[str] = None,
excluded_roles: Optional[str] = None,
excluded_extensions: Optional[str] = None,
skip_tables: Optional[List[str]] = None,
with_tables: Optional[List[str]] = None,
replicate_extensions: bool = True,
Expand All @@ -808,12 +811,14 @@ def __init__(
conn_info=source_conn_info,
filtered_db=filtered_db,
excluded_roles=excluded_roles,
excluded_extensions=excluded_extensions,
mangle=mangle,
)
self.target = PGTarget(
conn_info=target_conn_info,
filtered_db=filtered_db,
excluded_roles=excluded_roles,
excluded_extensions=excluded_extensions,
mangle=mangle,
)
self.skip_tables = self._convert_table_names(skip_tables)
Expand Down Expand Up @@ -918,6 +923,13 @@ def filter_tables(self, db: PGDatabase) -> Optional[List[str]]:
quoted.append(name)
return quoted

def filter_extensions(self, db: PGDatabase) -> Optional[List[str]]:
"""
Given a database, return installed extensions on the source without
the ones that have explicitly been excluded from the migration.
"""
return [e.name for e in self.source.databases[db.dbname].pg_ext if e.name not in self.target.excluded_extensions]

def _check_different_servers(self) -> None:
"""Check if source and target are different servers."""
source = (self.source.conn_info["host"], self.source.conn_info.get("port"))
Expand Down Expand Up @@ -977,6 +989,10 @@ def _check_pg_ext(self):
continue
dbname = source_db.dbname
for source_ext in source_db.pg_ext:
if source_ext.name in self.target.excluded_extensions:
self.log.info("Extension %r will not be installed in target", source_ext.name)
continue

if dbname in self.target.databases:
target_db = self.target.databases[dbname]
try:
Expand Down Expand Up @@ -1194,6 +1210,11 @@ def _dump_schema(
self.source.conn_str(dbname=dbname),
]

# PG 13 and older versions do not support `--extension` option.
# The migration still succeeds with some unharmful error messages in the output.
if db and self.source.version >= LooseVersion("14"):
pg_dump_cmd.extend([f"--extension={ext}" for ext in self.filter_extensions(db)])

if self.createdb:
pg_dump_cmd.insert(-1, "--create")
# db is created and connected
Expand All @@ -1214,6 +1235,12 @@ def _dump_data(self, *, db: PGDatabase) -> PGMigrateStatus:
]
tables = self.filter_tables(db) or []
pg_dump_cmd.extend([f"--table={w}" for w in tables])

# PG 13 and older versions do not support `--extension` option.
# The migration still succeeds with some unharmful error messages in the output.
if self.source.version >= LooseVersion("14"):
pg_dump_cmd.extend([f"--extension={ext}" for ext in self.filter_extensions(db)])

subtask: PGSubTask = self._pg_dump_pipe_psql(
pg_dump_cmd=pg_dump_cmd, target_conn_str=self.target.conn_str(dbname=dbname)
)
Expand Down Expand Up @@ -1410,6 +1437,12 @@ def main(args=None, *, prog="pg_migrate"):
help="Comma separated list of database roles to exclude during migrations",
required=False,
)
parser.add_argument(
"-xe",
"--excluded-extensions",
help="Comma separated list of database extensions to exclude during migrations",
required=False,
)
parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output.")
parser.add_argument(
"--no-createdb", action="store_false", dest="createdb", help="Don't automatically create database(s) in target."
Expand Down Expand Up @@ -1503,6 +1536,7 @@ def main(args=None, *, prog="pg_migrate"):
verbose=args.verbose,
filtered_db=args.filtered_db,
excluded_roles=args.excluded_roles,
excluded_extensions=args.excluded_extensions,
mangle=args.mangle,
skip_tables=args.skip_table,
with_tables=args.with_table,
Expand Down
22 changes: 22 additions & 0 deletions test/test_pg_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@ def test_extension_requires_superuser(pg_source_and_target: Tuple[PGRunner, PGRu
assert str(err.value) == f"Installing extension '{extname}' in target requires superuser"


def test_migration_succeeds_when_extensions_that_require_superuser_are_excluded(
pg_source_and_target: Tuple[PGRunner, PGRunner]
) -> None:
source, target = pg_source_and_target
dbname = random_string()
extensions = {"pg_freespacemap", "pg_visibility"}

source.create_db(dbname=dbname)
for extname in extensions:
source.create_extension(extname=extname, dbname=dbname)

pg_mig = PGMigrate(
source_conn_info=source.conn_info(),
target_conn_info=target.conn_info(),
verbose=True,
excluded_extensions=",".join(extensions),
)
assert set(pg_mig.target.excluded_extensions) == extensions

pg_mig.validate()


@pytest.mark.parametrize("createdb", [True, False])
def test_extension_superuser(pg_source_and_target: Tuple[PGRunner, PGRunner], createdb: bool):
source, target = pg_source_and_target
Expand Down
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def create_extension(self, *, extname: str, extversion: str = None, dbname: str,
grantee = self.testuser
sql = f"CREATE EXTENSION IF NOT EXISTS {extname}"
if extversion:
sql += f" WITH VERSION {extversion}"
sql += f" WITH VERSION '{extversion}'"
if LooseVersion(self.pgversion) > "9.5":
sql += " CASCADE"
try:
Expand Down

0 comments on commit eea90a4

Please sign in to comment.