Skip to content

Commit

Permalink
Refactor ci-tests setup and add sagemaker role creation
Browse files Browse the repository at this point in the history
  • Loading branch information
tkilias committed Nov 27, 2023
1 parent c47c756 commit 06ead37
Show file tree
Hide file tree
Showing 24 changed files with 246 additions and 136 deletions.
3 changes: 2 additions & 1 deletion ci-isolation/src/main/resources/s3-access.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"Action": [
"iam:CreatePolicy",
"iam:CreateRole",
"iam:GetRole",
"iam:AttachRolePolicy",
"iam:DeleteRole",
"iam:DeletePolicy",
Expand All @@ -21,4 +22,4 @@
"Resource": "*"
}
]
}
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ botocore = "1.29.163"
protobuf = ">=3.1,<=3.20.0"
sagemaker = "^2.59.1"
pyexasol = "^0.24.0"
localstack-client = "^1.25"
importlib-resources = "^5.2.0"
click = "^8.0.3"
typeguard = "^2.11.1"
Expand All @@ -40,6 +39,7 @@ coverage = "^6.3"
exasol-udf-mock-python = { git = "https://github.com/exasol/udf-mock-python.git", branch = "main" }
exasol-bucketfs = "^0.6.0"
poethepoet = "^0.13.1"
localstack-client = "^1.25"
boto3 = "^1.20.40"

[build-system]
Expand Down
4 changes: 2 additions & 2 deletions scripts/setup_integration_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ cnt_func=0
function install_localstack {
cnt_func=$((cnt_func+1))
echo -e "${YEL} Step-$cnt_func: ${GRA} Install Localstack packages${NC}"
pip install localstack=='0.12.18'
pip install localstack-client=="1.25"
pip install localstack
pip install localstack-client
}

function checkout_exasol_test_container {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def language_container():
completed_process = subprocess.run(
[script_dir], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
output = completed_process.stdout.decode("UTF-8")
print(output)
completed_process.check_returncode()

print(output)
lines = output.splitlines()

alter_session_selector = "ALTER SYSTEM SET SCRIPT_LANGUAGES='"
Expand Down
151 changes: 134 additions & 17 deletions tests/ci_tests/fixtures/prepare_environment_fixture.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import dataclasses
import os
from inspect import cleandoc

import boto3
import pyexasol
import pytest
from click.testing import CliRunner

from exasol_sagemaker_extension.deployment import deploy_cli
from tests.ci_tests.utils.parameters import db_params, aws_params, \
from tests.ci_tests.utils.parameters import db_params, \
reg_model_setup_params, cls_model_setup_params


Expand Down Expand Up @@ -50,40 +54,153 @@ def _setup_database(db_conn):
__insert_into_tables(db_conn, model_setup)


def _create_aws_connection(conn):
@pytest.fixture(scope="session")
def connection_object_for_aws_credentials(db_conn, aws_s3_bucket):
aws_conn_name = "test_aws_credentials_connection_name"
aws_region = os.environ["AWS_DEFAULT_REGION"]
aws_s3_uri = f"https://{aws_s3_bucket}.s3.{aws_region}.amazonaws.com"
query = "CREATE OR REPLACE CONNECTION {aws_conn_name} " \
"TO '{aws_s3_uri}' " \
"USER '{aws_key_id}' IDENTIFIED BY '{aws_access_key}'"\
.format(aws_conn_name=aws_params.aws_conn_name,
aws_s3_uri=aws_params.aws_s3_uri,
aws_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_access_key=os.environ["AWS_SECRET_ACCESS_KEY"])
conn.execute(query)
"USER '{aws_access_key_id}' IDENTIFIED BY '{aws_secret_access_key}'" \
.format(aws_conn_name=aws_conn_name,
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
aws_s3_uri=aws_s3_uri)
db_conn.execute(query)
print(query)
yield aws_conn_name
db_conn.execute(f"DROP CONNECTION {aws_conn_name};")


def _create_aws_s3_bucket():
s3_client = boto3.client('s3')
bucket_name = "ci-exasol-sagemaker-extension-bucket"
try:
s3_client.create_bucket(
Bucket=aws_params.aws_bucket,
Bucket=bucket_name,
CreateBucketConfiguration={
'LocationConstraint': os.environ["AWS_DEFAULT_REGION"]}
)
except s3_client.exceptions.BucketAlreadyOwnedByYou as ex:
print("'BucketAlreadyOwnedByYou' exception is handled")
return bucket_name


def _remove_aws_s3_bucket():
def _remove_aws_s3_bucket_content(bucket_name: str):
s3_client = boto3.resource('s3')

bucket = s3_client.Bucket(aws_params.aws_bucket)
bucket = s3_client.Bucket(bucket_name)
bucket.objects.all().delete()


@pytest.fixture(scope="session")
def prepare_ci_test_environment(db_conn):
def aws_s3_bucket():
bucket_name = _create_aws_s3_bucket()
yield bucket_name
_remove_aws_s3_bucket_content(bucket_name)


@pytest.fixture(scope="session")
def aws_sagemaker_role() -> str:
iam_client = boto3.client('iam')
role_name = _create_sagemaker_role(iam_client)
policy_arn = _create_sagemaker_policy(iam_client)
_attach_policy_to_role(iam_client,
policy_arn=policy_arn,
role_name=role_name)
_attach_policy_to_role(iam_client,
policy_arn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
role_name=role_name)
return role_name


def _attach_policy_to_role(iam_client, policy_arn, role_name):
response = iam_client.attach_role_policy(
PolicyArn=policy_arn,
RoleName=role_name,
)


def _create_sagemaker_role(iam_client):
role_name = "ci-exasol-sagemaker-extension-role"
try:
assume_policy_document = cleandoc("""
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"Service": "sagemaker.amazonaws.com"
},
"Action": "sts:AssumeRole"
}
]
}
""")
response = iam_client.create_role(
RoleName=role_name,
AssumeRolePolicyDocument=assume_policy_document,
Description='This role is used for the CI Tests of the exasol.sagemaker-extension',
)
except iam_client.exceptions.EntityAlreadyExistsException as ex:
print("'EntityAlreadyExistsException' exception is handled")
return role_name


def _create_sagemaker_policy(iam_client) -> str:
policy_name = "ci-exasol-sagemaker-extension-policy"
try:
policy_document = cleandoc("""
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:*"
],
"Resource": "*"
}
]
}
""")
response = iam_client.create_policy(
PolicyName=policy_name,
PolicyDocument=policy_document,
Description='This policy is used for the CI Tests of the exasol.sagemaker-extension',

)
print(response)
return response["Policy"]["ARN"] # FIXME got key error in this line
except iam_client.exceptions.EntityAlreadyExistsException as ex:
print("'EntityAlreadyExistsException' exception is handled")
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity()['Account']
policy_arn = f'arn:aws:iam::{account_id}:policy/{policy_name}'
return policy_arn


@dataclasses.dataclass
class CITestEnvironment:
db_conn: pyexasol.ExaConnection
aws_s3_bucket: str
aws_sagemaker_role: str
connection_object_for_aws_credentials: str
aws_region: str = os.environ["AWS_DEFAULT_REGION"]

@property
def aws_bucket_uri(self) -> str:
aws_bucket_uri = f"s3://{self.aws_s3_bucket}"
return aws_bucket_uri


@pytest.fixture(scope="session")
def prepare_ci_test_environment(db_conn,
aws_s3_bucket,
connection_object_for_aws_credentials,
aws_sagemaker_role) -> CITestEnvironment:
_setup_database(db_conn)
_create_aws_connection(db_conn)
_create_aws_s3_bucket()
yield db_conn
_remove_aws_s3_bucket()
yield CITestEnvironment(db_conn=db_conn,
aws_s3_bucket=aws_s3_bucket,
connection_object_for_aws_credentials=connection_object_for_aws_credentials,
aws_sagemaker_role=aws_sagemaker_role)
7 changes: 4 additions & 3 deletions tests/ci_tests/fixtures/setup_ci_test_environment.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest

from tests.ci_tests.fixtures.prepare_environment_fixture import CITestEnvironment


@pytest.fixture(scope="session")
def setup_ci_test_environment(register_language_container,
prepare_ci_test_environment):
db_conn = prepare_ci_test_environment
return db_conn
prepare_ci_test_environment) -> CITestEnvironment:
return prepare_ci_test_environment
18 changes: 10 additions & 8 deletions tests/ci_tests/test_deploying_autopilot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import time
import pytest
from datetime import datetime

import pytest

from tests.ci_tests.fixtures.prepare_environment_fixture import CITestEnvironment
from tests.ci_tests.utils import parameters
from tests.ci_tests.utils.autopilot_deployment import AutopilotTestDeployment
from tests.ci_tests.utils.autopilot_polling import AutopilotTestPolling
from tests.ci_tests.utils.autopilot_training import AutopilotTestTraining
from tests.ci_tests.utils.cleanup import cleanup
from tests.ci_tests.utils.queries import DatabaseQueries
from tests.ci_tests.utils.checkers import is_aws_credentials_not_set
from tests.ci_tests.utils.parameters import cls_model_setup_params
from tests.ci_tests.utils import parameters
from tests.ci_tests.utils.queries import DatabaseQueries


def _is_training_completed(status):
Expand All @@ -19,14 +21,14 @@ def _is_training_completed(status):


@cleanup
def _deploy_endpoint(job_name, endpoint_name, model_setup_params, db_conn):
def _deploy_endpoint(job_name, endpoint_name, model_setup_params, ci_test_env: CITestEnvironment, ):
# poll until the training is completed
timeout_time = time.time() + parameters.TIMEOUT
while True:
status = AutopilotTestPolling.poll_autopilot_job(
job_name,
model_setup_params.schema_name,
db_conn)
ci_test_env)
print(status)

if _is_training_completed(status):
Expand All @@ -42,12 +44,12 @@ def _deploy_endpoint(job_name, endpoint_name, model_setup_params, db_conn):
job_name,
endpoint_name,
model_setup_params,
db_conn
ci_test_env
)

# assertion
all_scripts = DatabaseQueries.get_all_scripts(
model_setup_params, db_conn)
model_setup_params, ci_test_env.db_conn)
assert endpoint_name in list(map(lambda x: x[0], all_scripts))


Expand Down
5 changes: 3 additions & 2 deletions tests/ci_tests/test_polling_autopilot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from datetime import datetime

import pytest

from tests.ci_tests.utils.autopilot_polling import AutopilotTestPolling
from tests.ci_tests.utils.autopilot_training import AutopilotTestTraining
from tests.ci_tests.utils.checkers import is_aws_credentials_not_set
from tests.ci_tests.utils.parameters import cls_model_setup_params


Expand Down
14 changes: 8 additions & 6 deletions tests/ci_tests/test_predicting_autopilot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import time
import pytest
from datetime import datetime

import pytest

from tests.ci_tests.fixtures.prepare_environment_fixture import CITestEnvironment
from tests.ci_tests.utils import parameters
from tests.ci_tests.utils.autopilot_deployment import AutopilotTestDeployment
from tests.ci_tests.utils.autopilot_polling import AutopilotTestPolling
from tests.ci_tests.utils.autopilot_prediction import AutopilotTestPrediction
from tests.ci_tests.utils.autopilot_training import AutopilotTestTraining
from tests.ci_tests.utils.checkers import is_aws_credentials_not_set
from tests.ci_tests.utils.cleanup import cleanup
from tests.ci_tests.utils.parameters import cls_model_setup_params, \
reg_model_setup_params
Expand All @@ -20,14 +22,14 @@ def _is_training_completed(status):


@cleanup
def _make_prediction(job_name, endpoint_name, model_setup_params, db_conn):
def _make_prediction(job_name, endpoint_name, model_setup_params, ci_test_env: CITestEnvironment):
# poll until the training is completed
timeout_time = time.time() + parameters.TIMEOUT
while True:
status = AutopilotTestPolling.poll_autopilot_job(
job_name,
model_setup_params.schema_name,
db_conn)
ci_test_env)
print(status)

if _is_training_completed(status):
Expand All @@ -43,12 +45,12 @@ def _make_prediction(job_name, endpoint_name, model_setup_params, db_conn):
job_name,
endpoint_name,
model_setup_params,
db_conn
ci_test_env
)

# assertion
predictions = AutopilotTestPrediction.predict(
endpoint_name, model_setup_params.schema_name, db_conn)
endpoint_name, model_setup_params.schema_name, ci_test_env)
assert predictions


Expand Down
Loading

0 comments on commit 06ead37

Please sign in to comment.