Skip to content

Commit

Permalink
Integration Tests fixes (#1744)
Browse files Browse the repository at this point in the history
### Feature or Bugfix
<!-- please choose -->
- Bugfix


### Detail
- Fixed attribute names for the new 'restricted' section in queries
- Don't expect GQL exceptions
- `test_get_folder_unauthorized` removed, since we have no access
control for this query

### Relates
- <URL or Ticket>

### Security
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
  - Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
  - Is injection prevented by parametrizing queries?
  - Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
  - Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use standard, proven implementations?
  - Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
  - Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.

---------

Co-authored-by: Sofia Sazonova <[email protected]>
  • Loading branch information
SofiaSazonova and Sofia Sazonova authored Dec 19, 2024
1 parent 3d97d9e commit 9432a4e
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 73 deletions.
1 change: 1 addition & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
gql.Field(name='region', type=gql.String),
gql.Field(name='S3BucketName', type=gql.String),
gql.Field(name='GlueDatabaseName', type=gql.String),
gql.Field(name='GlueCrawlerName', type=gql.String),
gql.Field(name='IAMDatasetAdminRoleArn', type=gql.String),
gql.Field(name='KmsAlias', type=gql.String),
gql.Field(name='importedS3Bucket', type=gql.Boolean),
Expand Down
15 changes: 11 additions & 4 deletions tests_new/integration_tests/modules/s3_datasets/global_conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
sync_tables,
create_folder,
create_table_data_filter,
list_dataset_tables,
)

from tests_new.integration_tests.modules.datasets_base.queries import list_datasets
from integration_tests.aws_clients.s3 import S3Client as S3CommonClient
from integration_tests.modules.s3_datasets.aws_clients import S3Client, KMSClient, GlueClient, LakeFormationClient
Expand Down Expand Up @@ -179,8 +181,8 @@ def create_tables(client, dataset):
aws_session_token=creds['sessionToken'],
)
file_path = os.path.join(os.path.dirname(__file__), 'sample_data/csv_table/csv_sample.csv')
s3_client = S3Client(dataset_session, dataset.region)
glue_client = GlueClient(dataset_session, dataset.region)
s3_client = S3Client(dataset_session, dataset.restricted.region)
glue_client = GlueClient(dataset_session, dataset.restricted.region)
s3_client.upload_file_to_prefix(
local_file_path=file_path, s3_path=f'{dataset.restricted.S3BucketName}/integrationtest1'
)
Expand All @@ -198,8 +200,13 @@ def create_tables(client, dataset):
table_name='integrationtest2',
bucket=dataset.restricted.S3BucketName,
)
response = sync_tables(client, datasetUri=dataset.datasetUri)
return [table for table in response.get('nodes', []) if table.GlueTableName.startswith('integrationtest')]
sync_tables(client, datasetUri=dataset.datasetUri)
response = list_dataset_tables(client, datasetUri=dataset.datasetUri)
return [
table
for table in response.tables.get('nodes', [])
if table.restricted.GlueTableName.startswith('integrationtest')
]


def create_folders(client, dataset):
Expand Down
4 changes: 4 additions & 0 deletions tests_new/integration_tests/modules/s3_datasets/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
KmsAlias
S3BucketName
GlueDatabaseName
GlueCrawlerName
IAMDatasetAdminRoleArn
}
environment {
Expand Down Expand Up @@ -352,6 +353,7 @@ def update_folder(client, locationUri, input):
mutation updateDatasetStorageLocation($locationUri: String!, $input: ModifyDatasetStorageLocationInput!) {{
updateDatasetStorageLocation(locationUri: $locationUri, input: $input) {{
locationUri
label
}}
}}
""",
Expand Down Expand Up @@ -500,6 +502,8 @@ def list_dataset_tables(client, datasetUri):
tables {{
count
nodes {{
tableUri
label
restricted {{
GlueTableName
}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_start_crawler(client1, dataset_fixture_name, request):
dataset = request.getfixturevalue(dataset_fixture_name)
dataset_uri = dataset.datasetUri
response = start_glue_crawler(client1, datasetUri=dataset_uri, input={})
assert_that(response.Name).is_equal_to(dataset.GlueCrawlerName)
assert_that(response.Name).is_equal_to(dataset.restricted.GlueCrawlerName)
# TODO: check it can run successfully + check sending prefix - We should first implement it in API


Expand Down
12 changes: 0 additions & 12 deletions tests_new/integration_tests/modules/s3_datasets/test_s3_folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,6 @@ def test_get_folder(client1, folders_fixture_name, request):
assert_that(response.label).is_equal_to('labelSessionFolderA')


@pytest.mark.parametrize(
'folders_fixture_name',
['session_s3_dataset1_folders'],
)
def test_get_folder_unauthorized(client2, folders_fixture_name, request):
folders = request.getfixturevalue(folders_fixture_name)
folder = folders[0]
assert_that(get_folder).raises(GqlError).when_called_with(client2, locationUri=folder.locationUri).contains(
'UnauthorizedOperation', 'GET_DATASET_FOLDER', folder.locationUri
)


@pytest.mark.parametrize(*FOLDERS_FIXTURES_PARAMS)
def test_update_folder(client1, folders_fixture_name, request):
folders = request.getfixturevalue(folders_fixture_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def test_list_dataset_tables(client1, dataset_fixture_name, request):
dataset = request.getfixturevalue(dataset_fixture_name)
response = list_dataset_tables(client1, dataset.datasetUri)
assert_that(response.tables.count).is_greater_than_or_equal_to(2)
tables = [table for table in response.tables.get('nodes', []) if table.GlueTableName.startswith('integrationtest')]
tables = [
table
for table in response.tables.get('nodes', [])
if table.restricted.GlueTableName.startswith('integrationtest')
]
assert_that(len(tables)).is_equal_to(2)


Expand Down Expand Up @@ -116,11 +120,12 @@ def test_delete_table(client1, dataset_fixture_name, request):
aws_secret_access_key=creds['SessionKey'],
aws_session_token=creds['sessionToken'],
)
GlueClient(dataset_session, dataset.region).create_table(
GlueClient(dataset_session, dataset.restricted.region).create_table(
database_name=dataset.restricted.GlueDatabaseName, table_name='todelete', bucket=dataset.restricted.S3BucketName
)
response = sync_tables(client1, datasetUri=dataset.datasetUri)
table_uri = [table.tableUri for table in response.get('nodes', []) if table.label == 'todelete'][0]
sync_tables(client1, datasetUri=dataset.datasetUri)
response = list_dataset_tables(client1, datasetUri=dataset.datasetUri)
table_uri = [table.tableUri for table in response.tables.get('nodes', []) if table.label == 'todelete'][0]
response = delete_table(client1, table_uri)
assert_that(response).is_true()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def test_start_table_profiling(client1, dataset_fixture_name, tables_fixture_nam
table = tables[0]
dataset_uri = dataset.datasetUri
response = start_dataset_profiling_run(
client1, input={'datasetUri': dataset_uri, 'tableUri': table.tableUri, 'GlueTableName': table.GlueTableName}
client1,
input={'datasetUri': dataset_uri, 'tableUri': table.tableUri, 'GlueTableName': table.restricted.GlueTableName},
)
assert_that(response.datasetUri).is_equal_to(dataset_uri)
assert_that(response.GlueTableName).is_equal_to(table.GlueTableName)
assert_that(response.GlueTableName).is_equal_to(table.restricted.GlueTableName)


@pytest.mark.parametrize('dataset_fixture_name', ['session_s3_dataset1'])
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_get_table_profiling_run_by_confidentiality(client2, tables_fixture_name
table_uri = tables[0].tableUri
if confidentiality in ['Unclassified']:
response = get_table_profiling_run(client2, tableUri=table_uri)
assert_that(response.GlueTableName).is_equal_to(tables[0].GlueTableName)
assert_that(response.GlueTableName).is_equal_to(tables[0].restricted.GlueTableName)
else:
assert_that(get_table_profiling_run).raises(GqlError).when_called_with(client2, table_uri).contains(
'UnauthorizedOperation', 'GET_TABLE_PROFILING_METRICS'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def principal1(request, group5, session_consumption_role_1):


@pytest.fixture(params=['Group', 'ConsumptionRole'])
def share_params_main(request, session_share_1, session_share_consrole_1, session_s3_dataset1):
def share_params_main(request, session_share_1, session_cross_acc_env_1, session_share_consrole_1, session_s3_dataset1):
if request.param == 'Group':
yield session_share_1, session_s3_dataset1
yield session_share_1, session_s3_dataset1, session_cross_acc_env_1
else:
yield session_share_consrole_1, session_s3_dataset1
yield session_share_consrole_1, session_s3_dataset1, session_cross_acc_env_1


@pytest.fixture(params=[(False, 'Group'), (True, 'Group'), (False, 'ConsumptionRole'), (True, 'ConsumptionRole')])
Expand Down Expand Up @@ -315,8 +315,10 @@ def persistent_role_share_1(


@pytest.fixture(params=['Group', 'ConsumptionRole'])
def persistent_share_params_main(request, persistent_role_share_1, persistent_group_share_1):
def persistent_share_params_main(
request, persistent_cross_acc_env_1, persistent_role_share_1, persistent_group_share_1
):
if request.param == 'Group':
yield persistent_group_share_1
yield persistent_group_share_1, persistent_cross_acc_env_1
else:
yield persistent_role_share_1
yield persistent_role_share_1, persistent_cross_acc_env_1
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,28 @@ def check_bucket_access(client, s3_client, bucket_name, should_have_access):


def check_accesspoint_access(client, s3_client, access_point_arn, item_uri, should_have_access):
folder = get_folder(client, item_uri)
if should_have_access:
folder = get_folder(client, item_uri)
assert_that(s3_client.list_accesspoint_folder_objects(access_point_arn, folder.S3Prefix + '/')).is_not_none()
else:
assert_that(get_folder).raises(Exception).when_called_with(client, item_uri).contains(
'is not authorized to perform: GET_DATASET_FOLDER'
)
assert_that(s3_client.list_accesspoint_folder_objects).raises(ClientError).when_called_with(
access_point_arn, folder.S3Prefix + '/'
).contains('AccessDenied')


def check_share_items_access(
client,
group,
shareUri,
share_environment,
consumption_role,
env_client,
):
share = get_share_object(client, shareUri)
dataset = share.dataset
principal_type = share.principal.principalType
if principal_type == 'Group':
credentials_str = get_environment_access_token(client, share.environment.environmentUri, group)
credentials_str = get_environment_access_token(client, share_environment.environmentUri, group)
credentials = json.loads(credentials_str)
session = boto3.Session(
aws_access_key_id=credentials['AccessKey'],
Expand All @@ -169,7 +170,7 @@ def check_share_items_access(
f'arn:aws:s3:{dataset.region}:{dataset.AwsAccountId}:accesspoint/{consumption_data.s3AccessPointName}'
)
if principal_type == 'Group':
workgroup = athena_client.get_env_work_group(share.environment.label)
workgroup = athena_client.get_env_work_group(share_environment.label)
athena_workgroup_output_location = None
else:
workgroup = 'primary'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_reject_share(client1, client5, session_cross_acc_env_1, session_s3_data


def test_change_share_purpose(client5, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
change_request_purpose = update_share_request_reason(client5, share.shareUri, 'new purpose')
assert_that(change_request_purpose).is_true()
updated_share = get_share_object(client5, share.shareUri)
Expand All @@ -153,37 +153,42 @@ def test_submit_object(client5, share_params_all):

@pytest.mark.dependency(name='share_approved', depends=['share_submitted'])
def test_approve_share(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_approve_share_object(client1, share.shareUri)


@pytest.mark.dependency(name='share_succeeded', depends=['share_approved'])
def test_share_succeeded(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_share_succeeded(client1, share.shareUri, check_contains_all_item_types=True)


@pytest.mark.dependency(name='share_verified', depends=['share_succeeded'])
def test_verify_share_items(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_verify_share_items(client1, share.shareUri)


@pytest.mark.dependency(depends=['share_verified'])
def test_check_item_access(
client5, session_cross_acc_env_1_aws_client, share_params_main, group5, session_consumption_role_1
):
share, dataset = share_params_main
share, dataset, share_environment = share_params_main
check_share_items_access(
client5, group5, share.shareUri, session_consumption_role_1, session_cross_acc_env_1_aws_client
client5,
group5,
share.shareUri,
share_environment,
session_consumption_role_1,
session_cross_acc_env_1_aws_client,
)


@pytest.mark.dependency(name='unhealthy_items', depends=['share_verified'])
def test_unhealthy_items(
client5, session_cross_acc_env_1_aws_client, session_cross_acc_env_1_integration_role_arn, share_params_main
):
share, _ = share_params_main
share, _, _ = share_params_main
iam = session_cross_acc_env_1_aws_client.resource('iam')
principal_role = iam.Role(share.principal.principalRoleName)
# break s3 by removing policies
Expand All @@ -209,7 +214,7 @@ def test_unhealthy_items(

@pytest.mark.dependency(depends=['share_approved'])
def test_reapply_unauthoried(client5, share_params_main):
share, _ = share_params_main
share, _, _ = share_params_main
share_uri = share.shareUri
share_object = get_share_object(client5, share_uri)
item_uris = [item.shareItemUri for item in share_object['items'].nodes]
Expand All @@ -220,7 +225,7 @@ def test_reapply_unauthoried(client5, share_params_main):

@pytest.mark.dependency(depends=['share_approved'])
def test_reapply(client1, share_params_main):
share, _ = share_params_main
share, _, _ = share_params_main
share_uri = share.shareUri
share_object = get_share_object(client1, share_uri)
item_uris = [item.shareItemUri for item in share_object['items'].nodes]
Expand All @@ -233,7 +238,7 @@ def test_reapply(client1, share_params_main):

@pytest.mark.dependency(name='share_revoked', depends=['share_succeeded'])
def test_revoke_share(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_share_ready(client1, share.shareUri)
revoke_and_check_all_shared_items(client1, share.shareUri, check_contains_all_item_types=True)

Expand All @@ -242,8 +247,13 @@ def test_revoke_share(client1, share_params_main):
def test_revoke_succeeded(
client1, client5, session_cross_acc_env_1_aws_client, share_params_main, group5, session_consumption_role_1
):
share, dataset = share_params_main
share, dataset, share_environment = share_params_main
check_all_items_revoke_job_succeeded(client1, share.shareUri, check_contains_all_item_types=True)
check_share_items_access(
client5, group5, share.shareUri, session_consumption_role_1, session_cross_acc_env_1_aws_client
client5,
group5,
share.shareUri,
share_environment,
session_consumption_role_1,
session_cross_acc_env_1_aws_client,
)
Loading

0 comments on commit 9432a4e

Please sign in to comment.