From 8e0cf6b9f2a530b96d4f6441917a95cc24cf56c0 Mon Sep 17 00:00:00 2001 From: Mikhail Beck Date: Wed, 20 Dec 2023 08:03:00 +0000 Subject: [PATCH] Scripts deployment refactoring [run aws tests] (#100) --- .../deployment/deploy_cli.py | 3 +- .../deployment/deploy_create_statements.py | 57 +++++++++++++------ .../test_deploy_create_statements.py | 4 +- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/exasol_sagemaker_extension/deployment/deploy_cli.py b/exasol_sagemaker_extension/deployment/deploy_cli.py index 892d35b3..f4e14e74 100644 --- a/exasol_sagemaker_extension/deployment/deploy_cli.py +++ b/exasol_sagemaker_extension/deployment/deploy_cli.py @@ -20,7 +20,7 @@ def main(host: str, port: str, user: str, pwd: str, schema: str, logging.basicConfig(format='%(asctime)s - %(module)s - %(message)s', level=logging.DEBUG) - deployment = DeployCreateStatements( + DeployCreateStatements.create_and_run( db_host=host, db_port=port, db_user=user, @@ -29,7 +29,6 @@ def main(host: str, port: str, user: str, pwd: str, schema: str, to_print=verbose, develop=develop ) - deployment.run() if __name__ == "__main__": diff --git a/exasol_sagemaker_extension/deployment/deploy_create_statements.py b/exasol_sagemaker_extension/deployment/deploy_create_statements.py index fec588e3..a6e38256 100644 --- a/exasol_sagemaker_extension/deployment/deploy_create_statements.py +++ b/exasol_sagemaker_extension/deployment/deploy_create_statements.py @@ -25,26 +25,12 @@ class DeployCreateStatements: that generate scripts deploying the sagemaker-extension project. """ - def __init__(self, db_host: str, db_port: str, db_user: str, db_pass: str, + def __init__(self, exasol_conn: pyexasol.ExaConnection, schema: str, to_print: bool, develop: bool): - self._db_host = db_host - self._db_port = db_port - self._db_user = db_user - self._db_pass = db_pass self._schema = schema self._to_print = to_print self._develop = develop - self.__exasol_conn = pyexasol.connect( - dsn="{host}:{port}".format( - host=self._db_host, port=self._db_port), - user=self._db_user, - password=self._db_pass, - compression=True, - encryption=True, - websocket_sslopt={ - "cert_reqs": ssl.CERT_NONE, - } - ) + self.__exasol_conn = exasol_conn @property def statement_maps(self): @@ -131,3 +117,42 @@ def create_statements(): stmt_generator.save_statement() logger.debug(f"{stmt_generator.__class__.__name__} " "is created and saved.") + + @classmethod + def create_and_run(cls, + db_host: str, + db_port: str, + db_user: str, + db_pass: str, + schema: str, + to_print: bool, + develop: bool): + """ + Creates a database connection object based on the provided credentials + Creates an instance of the DeployCreateStatements passing the connection + object to it and calls its run method. + + Parameters: + db_host - database host address + db_port - database port + db_user - database username + db_pass - the user password + schema - schema where the scripts should be created + to_print - if True the script creation SQL commands will be + printed rather than executed + develop - if True the scripts will be generated from scratch + """ + + exasol_conn = pyexasol.connect( + dsn=f"{db_host}:{db_port}", + user=db_user, + password=db_pass, + compression=True, + encryption=True, + websocket_sslopt={ + "cert_reqs": ssl.CERT_NONE, + } + ) + + deployer = cls(exasol_conn, schema, to_print, develop) + deployer.run() diff --git a/tests/deployment/test_deploy_create_statements.py b/tests/deployment/test_deploy_create_statements.py index bd4a8702..8a90c549 100644 --- a/tests/deployment/test_deploy_create_statements.py +++ b/tests/deployment/test_deploy_create_statements.py @@ -27,7 +27,7 @@ def get_all_scripts(db_conn): def test_deploy_create_statements(db_conn, register_language_container): - deployer = DeployCreateStatements( + DeployCreateStatements.create_and_run( db_host=db_params.host, db_port=db_params.port, db_user=db_params.user, @@ -37,8 +37,6 @@ def test_deploy_create_statements(db_conn, register_language_container): develop=False ) - deployer.run() - all_schemas = get_all_schemas(db_conn) all_scripts = get_all_scripts(db_conn)