diff --git a/aiven_db_migrate/migrate/pgmigrate.py b/aiven_db_migrate/migrate/pgmigrate.py index 22110b7..ef02c1b 100644 --- a/aiven_db_migrate/migrate/pgmigrate.py +++ b/aiven_db_migrate/migrate/pgmigrate.py @@ -114,6 +114,7 @@ def __init__( conn_info: Union[str, Dict[str, Any]], filtered_db: Optional[str] = None, excluded_roles: Optional[str] = None, + excluded_extensions: Optional[str] = None, mangle: bool = False, ): self.log = logging.getLogger(self.__class__.__name__) @@ -130,6 +131,7 @@ def __init__( self._pg_roles = dict() self.filtered_db = filtered_db.split(",") if filtered_db else [] self.excluded_roles = excluded_roles.split(",") if excluded_roles else [] + self.excluded_extensions = excluded_extensions.split(",") if excluded_extensions else [] if "application_name" not in self.conn_info: self.conn_info["application_name"] = f"aiven-db-migrate/{__version__}" self._mangle = mangle @@ -797,6 +799,7 @@ def __init__( mangle: bool = False, filtered_db: Optional[str] = None, excluded_roles: Optional[str] = None, + excluded_extensions: Optional[str] = None, skip_tables: Optional[List[str]] = None, with_tables: Optional[List[str]] = None, replicate_extensions: bool = True, @@ -808,12 +811,14 @@ def __init__( conn_info=source_conn_info, filtered_db=filtered_db, excluded_roles=excluded_roles, + excluded_extensions=excluded_extensions, mangle=mangle, ) self.target = PGTarget( conn_info=target_conn_info, filtered_db=filtered_db, excluded_roles=excluded_roles, + excluded_extensions=excluded_extensions, mangle=mangle, ) self.skip_tables = self._convert_table_names(skip_tables) @@ -918,6 +923,13 @@ def filter_tables(self, db: PGDatabase) -> Optional[List[str]]: quoted.append(name) return quoted + def filter_extensions(self, db: PGDatabase) -> Optional[List[str]]: + """ + Given a database, return installed extensions on the source without + the ones that have explicitly been excluded from the migration. + """ + return [e.name for e in self.source.databases[db.dbname].pg_ext if e.name not in self.target.excluded_extensions] + def _check_different_servers(self) -> None: """Check if source and target are different servers.""" source = (self.source.conn_info["host"], self.source.conn_info.get("port")) @@ -977,6 +989,10 @@ def _check_pg_ext(self): continue dbname = source_db.dbname for source_ext in source_db.pg_ext: + if source_ext.name in self.target.excluded_extensions: + self.log.info("Extension %r will not be installed in target", source_ext.name) + continue + if dbname in self.target.databases: target_db = self.target.databases[dbname] try: @@ -1194,6 +1210,11 @@ def _dump_schema( self.source.conn_str(dbname=dbname), ] + # PG 13 and older versions do not support `--extension` option. + # The migration still succeeds with some unharmful error messages in the output. + if db and self.source.version >= LooseVersion("14"): + pg_dump_cmd.extend([f"--extension={ext}" for ext in self.filter_extensions(db)]) + if self.createdb: pg_dump_cmd.insert(-1, "--create") # db is created and connected @@ -1214,6 +1235,12 @@ def _dump_data(self, *, db: PGDatabase) -> PGMigrateStatus: ] tables = self.filter_tables(db) or [] pg_dump_cmd.extend([f"--table={w}" for w in tables]) + + # PG 13 and older versions do not support `--extension` option. + # The migration still succeeds with some unharmful error messages in the output. + if self.source.version >= LooseVersion("14"): + pg_dump_cmd.extend([f"--extension={ext}" for ext in self.filter_extensions(db)]) + subtask: PGSubTask = self._pg_dump_pipe_psql( pg_dump_cmd=pg_dump_cmd, target_conn_str=self.target.conn_str(dbname=dbname) ) @@ -1410,6 +1437,12 @@ def main(args=None, *, prog="pg_migrate"): help="Comma separated list of database roles to exclude during migrations", required=False, ) + parser.add_argument( + "-xe", + "--excluded-extensions", + help="Comma separated list of database extensions to exclude during migrations", + required=False, + ) parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output.") parser.add_argument( "--no-createdb", action="store_false", dest="createdb", help="Don't automatically create database(s) in target." @@ -1503,6 +1536,7 @@ def main(args=None, *, prog="pg_migrate"): verbose=args.verbose, filtered_db=args.filtered_db, excluded_roles=args.excluded_roles, + excluded_extensions=args.excluded_extensions, mangle=args.mangle, skip_tables=args.skip_table, with_tables=args.with_table, diff --git a/test/test_pg_extensions.py b/test/test_pg_extensions.py index 1cbd6b7..6d31ffd 100644 --- a/test/test_pg_extensions.py +++ b/test/test_pg_extensions.py @@ -58,6 +58,28 @@ def test_extension_requires_superuser(pg_source_and_target: Tuple[PGRunner, PGRu assert str(err.value) == f"Installing extension '{extname}' in target requires superuser" +def test_migration_succeeds_when_extensions_that_require_superuser_are_excluded( + pg_source_and_target: Tuple[PGRunner, PGRunner] +) -> None: + source, target = pg_source_and_target + dbname = random_string() + extensions = {"pg_freespacemap", "pg_visibility"} + + source.create_db(dbname=dbname) + for extname in extensions: + source.create_extension(extname=extname, dbname=dbname) + + pg_mig = PGMigrate( + source_conn_info=source.conn_info(), + target_conn_info=target.conn_info(), + verbose=True, + excluded_extensions=",".join(extensions), + ) + assert set(pg_mig.target.excluded_extensions) == extensions + + pg_mig.validate() + + @pytest.mark.parametrize("createdb", [True, False]) def test_extension_superuser(pg_source_and_target: Tuple[PGRunner, PGRunner], createdb: bool): source, target = pg_source_and_target diff --git a/test/utils.py b/test/utils.py index a2a5691..5ea1360 100644 --- a/test/utils.py +++ b/test/utils.py @@ -452,7 +452,7 @@ def create_extension(self, *, extname: str, extversion: str = None, dbname: str, grantee = self.testuser sql = f"CREATE EXTENSION IF NOT EXISTS {extname}" if extversion: - sql += f" WITH VERSION {extversion}" + sql += f" WITH VERSION '{extversion}'" if LooseVersion(self.pgversion) > "9.5": sql += " CASCADE" try: