diff --git a/backend/dataall/base/feature_toggle_checker.py b/backend/dataall/base/feature_toggle_checker.py
index 3fa3f70f3..5e945beab 100644
--- a/backend/dataall/base/feature_toggle_checker.py
+++ b/backend/dataall/base/feature_toggle_checker.py
@@ -2,6 +2,7 @@
 Contains decorators that check if a feature has been enabled or not
 """
+import functools
 from typing import List, Any, Optional, Callable
 
 from dataall.base.config import config
 
@@ -12,6 +13,7 @@ def is_feature_enabled(config_property: str):
     def decorator(f):
         fn, fn_decorator = process_func(f)
 
+        @functools.wraps(fn)
         def decorated(*args, **kwargs):
             value = config.get_property(config_property)
             if not value:
@@ -33,6 +35,7 @@ def is_feature_enabled_for_allowed_values(
     def decorator(f):
         fn, fn_decorator = process_func(f)
 
+        @functools.wraps(fn)
         def decorated(*args, **kwargs):
             config_property_value = None
             if config_property is None and resolve_property is None:
diff --git a/backend/dataall/core/stacks/db/target_type_repositories.py b/backend/dataall/core/stacks/db/target_type_repositories.py
index c175ed307..ca5770649 100644
--- a/backend/dataall/core/stacks/db/target_type_repositories.py
+++ b/backend/dataall/core/stacks/db/target_type_repositories.py
@@ -5,6 +5,7 @@
     GET_ENVIRONMENT,
     UPDATE_ENVIRONMENT,
 )
+from dataall.core.permissions.services.tenant_permissions import MANAGE_ENVIRONMENTS
 
 logger = logging.getLogger(__name__)
 
@@ -14,10 +15,11 @@ class TargetType:
 
     _TARGET_TYPES = {}
 
-    def __init__(self, name, read_permission, write_permission):
+    def __init__(self, name, read_permission, write_permission, tenant_permission):
         self.name = name
         self.read_permission = read_permission
         self.write_permission = write_permission
+        self.tenant_permission = tenant_permission
 
         TargetType._TARGET_TYPES[name] = self
 
@@ -31,6 +33,11 @@ def get_resource_read_permission_name(target_type):
         TargetType.is_supported_target_type(target_type)
         return TargetType._TARGET_TYPES[target_type].read_permission
 
+    @staticmethod
+    def get_resource_tenant_permission_name(target_type):
+        TargetType.is_supported_target_type(target_type)
+        return TargetType._TARGET_TYPES[target_type].tenant_permission
+
     @staticmethod
     def is_supported_target_type(target_type):
         if target_type not in TargetType._TARGET_TYPES:
@@ -41,4 +48,4 @@ def is_supported_target_type(target_type):
             )
 
 
-TargetType('environment', GET_ENVIRONMENT, UPDATE_ENVIRONMENT)
+TargetType('environment', GET_ENVIRONMENT, UPDATE_ENVIRONMENT, MANAGE_ENVIRONMENTS)
diff --git a/backend/dataall/core/stacks/services/stack_service.py b/backend/dataall/core/stacks/services/stack_service.py
index c409a640d..d02d9ba48 100644
--- a/backend/dataall/core/stacks/services/stack_service.py
+++ b/backend/dataall/core/stacks/services/stack_service.py
@@ -6,6 +6,7 @@
 from dataall.base.db import exceptions
 from dataall.base.feature_toggle_checker import is_feature_enabled_for_allowed_values
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
+from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
 from dataall.core.stacks.aws.cloudformation import CloudFormation
 from dataall.core.stacks.services.keyvaluetag_service import KeyValueTagService
 from dataall.core.tasks.service_handlers import Worker
@@ -181,6 +182,13 @@ def update_stack_by_target_uri(target_uri, target_type):
         StackRequestVerifier.verify_target_type_and_uri(target_uri, target_type)
         context = get_context()
         with context.db_engine.scoped_session() as session:
+            TenantPolicyService.check_user_tenant_permission(
+                session=session,
+                username=context.username,
+                groups=context.groups,
+                permission_name=TargetType.get_resource_tenant_permission_name(target_type),
+                tenant_name=TenantPolicyService.TENANT_NAME,
+            )
             ResourcePolicyService.check_user_resource_permission(
                 session=session,
                 username=context.username,
@@ -196,6 +204,23 @@ def update_stack_tags(input):
         StackRequestVerifier.validate_update_tag_input(input)
         target_uri = input.get('targetUri')
+        target_type = input.get('targetType')
+        context = get_context()
+        with context.db_engine.scoped_session() as session:
+            TenantPolicyService.check_user_tenant_permission(
+                session=session,
+                username=context.username,
+                groups=context.groups,
+                permission_name=TargetType.get_resource_tenant_permission_name(target_type),
+                tenant_name=TenantPolicyService.TENANT_NAME,
+            )
+            ResourcePolicyService.check_user_resource_permission(
+                session=session,
+                username=context.username,
+                groups=context.groups,
+                resource_uri=target_uri,
+                permission_name=TargetType.get_resource_update_permission_name(target_type),
+            )
         kv_tags = KeyValueTagService.update_key_value_tags(
             uri=target_uri,
             data=input,
diff --git a/backend/dataall/modules/dashboards/services/dashboard_quicksight_service.py b/backend/dataall/modules/dashboards/services/dashboard_quicksight_service.py
index 8b7eb9a9e..1b60b2122 100644
--- a/backend/dataall/modules/dashboards/services/dashboard_quicksight_service.py
+++ b/backend/dataall/modules/dashboards/services/dashboard_quicksight_service.py
@@ -8,10 +8,11 @@
 from dataall.base.db.exceptions import UnauthorizedOperation, TenantUnauthorized, AWSResourceNotFound
 from dataall.core.permissions.services.tenant_permissions import TENANT_ALL
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
+from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
 from dataall.modules.dashboards.db.dashboard_repositories import DashboardRepository
 from dataall.modules.dashboards.db.dashboard_models import Dashboard
 from dataall.modules.dashboards.aws.dashboard_quicksight_client import DashboardQuicksightClient
-from dataall.modules.dashboards.services.dashboard_permissions import GET_DASHBOARD, CREATE_DASHBOARD
+from dataall.modules.dashboards.services.dashboard_permissions import GET_DASHBOARD, CREATE_DASHBOARD, MANAGE_DASHBOARDS
 from dataall.base.utils import Parameter
 
 
@@ -58,6 +59,7 @@ def get_quicksight_reader_url(cls, uri):
         return client.get_anonymous_session(dashboard_id=dash.DashboardId)
 
     @classmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DASHBOARDS)
     @ResourcePolicyService.has_resource_permission(CREATE_DASHBOARD)
     def get_quicksight_designer_url(cls, uri: str):
         context = get_context()
diff --git a/backend/dataall/modules/datapipelines/__init__.py b/backend/dataall/modules/datapipelines/__init__.py
index ad96a73c1..7b1a56334 100644
--- a/backend/dataall/modules/datapipelines/__init__.py
+++ b/backend/dataall/modules/datapipelines/__init__.py
@@ -28,14 +28,17 @@ def __init__(self):
         from dataall.modules.feed.api.registry import FeedRegistry, FeedDefinition
         from dataall.modules.datapipelines.db.datapipelines_models import DataPipeline
         from dataall.modules.datapipelines.db.datapipelines_repositories import DatapipelinesRepository
-        from dataall.modules.datapipelines.services.datapipelines_permissions import GET_PIPELINE, UPDATE_PIPELINE
-
+        from dataall.modules.datapipelines.services.datapipelines_permissions import (
+            GET_PIPELINE,
+            UPDATE_PIPELINE,
+            MANAGE_PIPELINES,
+        )
         import dataall.modules.datapipelines.api
 
         FeedRegistry.register(FeedDefinition('DataPipeline', DataPipeline))
 
-        TargetType('pipeline', GET_PIPELINE, UPDATE_PIPELINE)
-        TargetType('cdkpipeline', GET_PIPELINE, UPDATE_PIPELINE)
+        TargetType('pipeline', GET_PIPELINE, UPDATE_PIPELINE, MANAGE_PIPELINES)
+        TargetType('cdkpipeline', GET_PIPELINE, UPDATE_PIPELINE, MANAGE_PIPELINES)
 
         EnvironmentResourceManager.register(DatapipelinesRepository())
diff --git a/backend/dataall/modules/mlstudio/__init__.py b/backend/dataall/modules/mlstudio/__init__.py
index e3d1d2f19..2e50fb64e 100644
--- a/backend/dataall/modules/mlstudio/__init__.py
+++ b/backend/dataall/modules/mlstudio/__init__.py
@@ -20,9 +20,13 @@ def __init__(self):
         from dataall.core.stacks.db.target_type_repositories import TargetType
         import dataall.modules.mlstudio.api
         from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource
-        from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER
+        from dataall.modules.mlstudio.services.mlstudio_permissions import (
+            GET_SGMSTUDIO_USER,
+            UPDATE_SGMSTUDIO_USER,
+            MANAGE_SGMSTUDIO_USERS,
+        )
 
-        TargetType('mlstudio', GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER)
+        TargetType('mlstudio', GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER, MANAGE_SGMSTUDIO_USERS)
 
         EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource())
diff --git a/backend/dataall/modules/notebooks/__init__.py b/backend/dataall/modules/notebooks/__init__.py
index 5fd8900da..0fc22ea07 100644
--- a/backend/dataall/modules/notebooks/__init__.py
+++ b/backend/dataall/modules/notebooks/__init__.py
@@ -17,9 +17,13 @@ def is_supported(modes):
     def __init__(self):
         import dataall.modules.notebooks.api
         from dataall.core.stacks.db.target_type_repositories import TargetType
-        from dataall.modules.notebooks.services.notebook_permissions import GET_NOTEBOOK, UPDATE_NOTEBOOK
+        from dataall.modules.notebooks.services.notebook_permissions import (
+            GET_NOTEBOOK,
+            UPDATE_NOTEBOOK,
+            MANAGE_NOTEBOOKS,
+        )
 
-        TargetType('notebook', GET_NOTEBOOK, UPDATE_NOTEBOOK)
+        TargetType('notebook', GET_NOTEBOOK, UPDATE_NOTEBOOK, MANAGE_NOTEBOOKS)
 
         log.info('API of sagemaker notebooks has been imported')
diff --git a/backend/dataall/modules/s3_datasets/__init__.py b/backend/dataall/modules/s3_datasets/__init__.py
index f0b73a6d0..dbd4f458c 100644
--- a/backend/dataall/modules/s3_datasets/__init__.py
+++ b/backend/dataall/modules/s3_datasets/__init__.py
@@ -41,7 +41,11 @@ def __init__(self):
         from dataall.modules.s3_datasets.indexers.table_indexer import DatasetTableIndexer
 
         import dataall.modules.s3_datasets.api
-        from dataall.modules.s3_datasets.services.dataset_permissions import GET_DATASET, UPDATE_DATASET
+        from dataall.modules.s3_datasets.services.dataset_permissions import (
+            GET_DATASET,
+            UPDATE_DATASET,
+            MANAGE_DATASETS,
+        )
         from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
         from dataall.modules.s3_datasets.db.dataset_models import DatasetStorageLocation, DatasetTable, S3Dataset
@@ -73,7 +77,7 @@ def __init__(self):
 
         add_vote_type('dataset', DatasetIndexer)
 
-        TargetType('dataset', GET_DATASET, UPDATE_DATASET)
+        TargetType('dataset', GET_DATASET, UPDATE_DATASET, MANAGE_DATASETS)
 
         EnvironmentResourceManager.register(DatasetRepository())
diff --git a/backend/dataall/modules/s3_datasets/api/profiling/resolvers.py b/backend/dataall/modules/s3_datasets/api/profiling/resolvers.py
index b92b8d065..5eae47bb5 100644
--- a/backend/dataall/modules/s3_datasets/api/profiling/resolvers.py
+++ b/backend/dataall/modules/s3_datasets/api/profiling/resolvers.py
@@ -10,6 +10,11 @@
 log = logging.getLogger(__name__)
 
 
+def _validate_uri(uri):
+    if not uri:
+        raise RequiredParameter('URI')
+
+
 def resolve_dataset(context, source: DatasetProfilingRun):
     if not source:
         return None
@@ -17,8 +22,7 @@ def resolve_dataset(context, source: DatasetProfilingRun):
 
 
 def start_profiling_run(context: Context, source, input: dict = None):
-    if 'datasetUri' not in input:
-        raise RequiredParameter('datasetUri')
+    _validate_uri(input.get('datasetUri'))
     return DatasetProfilingService.start_profiling_run(
         uri=input['datasetUri'], table_uri=input.get('tableUri'), glue_table_name=input.get('GlueTableName')
     )
diff --git a/backend/dataall/modules/s3_datasets/api/storage_location/resolvers.py b/backend/dataall/modules/s3_datasets/api/storage_location/resolvers.py
index 212332652..3d6847029 100644
--- a/backend/dataall/modules/s3_datasets/api/storage_location/resolvers.py
+++ b/backend/dataall/modules/s3_datasets/api/storage_location/resolvers.py
@@ -6,13 +6,16 @@
 from dataall.modules.s3_datasets.db.dataset_models import DatasetStorageLocation, S3Dataset
 
 
-@is_feature_enabled('modules.s3_datasets.features.file_actions')
-def create_storage_location(context, source, datasetUri: str = None, input: dict = None):
-    if 'prefix' not in input:
-        raise RequiredParameter('prefix')
+def _validate_input(input: dict):
     if 'label' not in input:
         raise RequiredParameter('label')
+    if 'prefix' not in input:
+        raise RequiredParameter('prefix')
 
+
+@is_feature_enabled('modules.s3_datasets.features.file_actions')
+def create_storage_location(context, source, datasetUri: str = None, input: dict = None):
+    _validate_input(input)
     return DatasetLocationService.create_storage_location(uri=datasetUri, data=input)
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_column_service.py b/backend/dataall/modules/s3_datasets/services/dataset_column_service.py
index 987b855a4..77d94b271 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_column_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_column_service.py
@@ -1,4 +1,5 @@
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
+from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
 from dataall.core.tasks.service_handlers import Worker
 from dataall.base.aws.sts import SessionHelper
 from dataall.base.context import get_context
@@ -7,11 +8,10 @@
 from dataall.modules.s3_datasets.aws.glue_table_client import GlueTableClient
 from dataall.modules.s3_datasets.db.dataset_column_repositories import DatasetColumnRepository
 from dataall.modules.s3_datasets.db.dataset_table_repositories import DatasetTableRepository
-from dataall.modules.s3_datasets.services.dataset_permissions import UPDATE_DATASET_TABLE
+from dataall.modules.s3_datasets.services.dataset_permissions import UPDATE_DATASET_TABLE, MANAGE_DATASETS
 from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, DatasetTableColumn
 from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
 from dataall.modules.datasets_base.services.datasets_enums import ConfidentialityClassification
-from dataall.modules.s3_datasets.services.dataset_permissions import PREVIEW_DATASET_TABLE
 
 
 class DatasetColumnService:
@@ -42,6 +42,7 @@ def paginate_active_columns_for_table(uri: str, filter=None):
             return DatasetColumnRepository.paginate_active_columns_for_table(session, uri, filter)
 
     @classmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS)
     @ResourcePolicyService.has_resource_permission(
         UPDATE_DATASET_TABLE, parent_resource=_get_dataset_uri, param_name='table_uri'
     )
@@ -56,6 +57,7 @@ def sync_table_columns(cls, table_uri: str):
         return cls.paginate_active_columns_for_table(uri=table_uri, filter={})
 
     @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS)
     @ResourcePolicyService.has_resource_permission(
         UPDATE_DATASET_TABLE, parent_resource=_get_dataset_uri_for_column, param_name='column_uri'
     )
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_profiling_service.py b/backend/dataall/modules/s3_datasets/services/dataset_profiling_service.py
index 7b8e5f0f8..063d30959 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_profiling_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_profiling_service.py
@@ -2,6 +2,7 @@
 
 from dataall.base.feature_toggle_checker import is_feature_enabled
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
+from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
 from dataall.core.tasks.service_handlers import Worker
 from dataall.base.context import get_context
 from dataall.base.db import exceptions
@@ -13,7 +14,7 @@
 from dataall.modules.s3_datasets.aws.s3_profiler_client import S3ProfilerClient
 from dataall.modules.s3_datasets.db.dataset_profiling_repositories import DatasetProfilingRepository
 from dataall.modules.s3_datasets.db.dataset_table_repositories import DatasetTableRepository
-from dataall.modules.s3_datasets.services.dataset_permissions import PROFILE_DATASET_TABLE, GET_DATASET
+from dataall.modules.s3_datasets.services.dataset_permissions import PROFILE_DATASET_TABLE, GET_DATASET, MANAGE_DATASETS
 from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
 from dataall.modules.datasets_base.services.datasets_enums import ConfidentialityClassification
 from dataall.modules.s3_datasets.db.dataset_models import DatasetProfilingRun, DatasetTable
@@ -22,6 +23,7 @@ class DatasetProfilingService:
     @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS)
     @ResourcePolicyService.has_resource_permission(PROFILE_DATASET_TABLE)
     @is_feature_enabled('modules.s3_datasets.features.metrics_data')
     def start_profiling_run(uri, table_uri, glue_table_name):
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_table_data_filter_service.py b/backend/dataall/modules/s3_datasets/services/dataset_table_data_filter_service.py
index a6a75996f..e848a8afb 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_table_data_filter_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_table_data_filter_service.py
@@ -2,6 +2,7 @@
 import re
 from dataall.base.context import get_context
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
+from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
 from dataall.modules.s3_datasets.db.dataset_table_data_filter_repositories import DatasetTableDataFilterRepository
 from dataall.modules.s3_datasets.db.dataset_table_repositories import DatasetTableRepository
 from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
@@ -11,6 +12,7 @@
     CREATE_TABLE_DATA_FILTER,
     DELETE_TABLE_DATA_FILTER,
     LIST_TABLE_DATA_FILTERS,
+    MANAGE_DATASETS,
 )
 from dataall.base.db import exceptions
 from dataall.modules.s3_datasets.aws.lf_data_filter_client import LakeFormationDataFilterClient
@@ -70,6 +72,7 @@ def _get_table_uri_from_filter(session, uri):
         return data_filter.tableUri
 
     @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS)
     @ResourcePolicyService.has_resource_permission(CREATE_TABLE_DATA_FILTER)
     def create_table_data_filter(uri: str, data: dict):
         DatasetTableDataFilterRequestValidationService.validate_creation_data_filter_params(uri, data)
@@ -93,6 +96,7 @@ def create_table_data_filter(uri: str, data: dict):
         return data_filter
 
     @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS)
     @ResourcePolicyService.has_resource_permission(DELETE_TABLE_DATA_FILTER, parent_resource=_get_table_uri_from_filter)
     def delete_table_data_filter(uri: str):
         with get_context().db_engine.scoped_session() as session:
diff --git a/backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py b/backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py
index f5220901e..309750265 100644
--- a/backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py
+++ b/backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py
@@ -193,6 +193,7 @@ def list_shared_tables_by_env_dataset(dataset_uri: str, env_uri: str):
         ]
 
     @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS)
     @ResourcePolicyService.has_resource_permission(CREDENTIALS_DATASET)
     def get_dataset_shared_assume_role_url(uri):
         context = get_context()
diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py
index 97e4c9edf..280cf468c 100644
--- a/backend/dataall/modules/worksheets/api/resolvers.py
+++ b/backend/dataall/modules/worksheets/api/resolvers.py
@@ -11,8 +11,6 @@ def create_worksheet(context: Context, source, input: dict = None):
         raise exceptions.RequiredParameter(input)
     if not input.get('SamlAdminGroupName'):
         raise exceptions.RequiredParameter('groupUri')
-    if input.get('SamlAdminGroupName') not in context.groups:
-        raise exceptions.InvalidInput('groupUri', input.get('SamlAdminGroupName'), " a user's groups")
     if not input.get('label'):
         raise exceptions.RequiredParameter('label')
 
diff --git a/backend/dataall/modules/worksheets/services/worksheet_service.py b/backend/dataall/modules/worksheets/services/worksheet_service.py
index e07ce63ce..faf10f93b 100644
--- a/backend/dataall/modules/worksheets/services/worksheet_service.py
+++ b/backend/dataall/modules/worksheets/services/worksheet_service.py
@@ -36,6 +36,10 @@ def _get_worksheet_by_uri(session, uri: str) -> Worksheet:
     @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS)
     def create_worksheet(data=None) -> Worksheet:
         context = get_context()
+        if data['SamlAdminGroupName'] not in context.groups:
+            raise exceptions.UnauthorizedOperation(
+                'CREATE_WORKSHEET', f"user {context.username} does not belong to group {data['SamlAdminGroupName']}"
+            )
         with context.db_engine.scoped_session() as session:
             worksheet = Worksheet(
                 owner=context.username,
@@ -126,6 +130,7 @@ def delete_worksheet(uri) -> bool:
         return True
 
     @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS)
     @ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY)
     def run_sql_query(uri, worksheetUri, sqlQuery):
         with get_context().db_engine.scoped_session() as session:
diff --git a/tests/conftest.py b/tests/conftest.py
index 78c96a26e..50af73a38 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -93,18 +93,24 @@ def user3():
     yield User('david')
 
 
-def _create_group(db, tenant, name, user):
+@pytest.fixture(scope='module', autouse=True)
+def userNoTenantPermissions():
+    yield User('noPermissionsUser')
+
+
+def _create_group(db, tenant, name, user, attach_permissions=True):
     with db.scoped_session() as session:
         group = Group(name=name, label=name, owner=user.username)
         session.add(group)
         session.commit()
-        TenantPolicyService.attach_group_tenant_policy(
-            session=session,
-            group=name,
-            permissions=TENANT_ALL,
-            tenant_name=tenant.name,
-        )
+        if attach_permissions:
+            TenantPolicyService.attach_group_tenant_policy(
+                session=session,
+                group=name,
+                permissions=TENANT_ALL,
+                tenant_name=tenant.name,
+            )
         return group
 
 
@@ -133,6 +139,11 @@ def not_in_org_group(db, tenant, user):
     yield _create_group(db, tenant, 'NotInOrgGroup', user)
 
 
+@pytest.fixture(scope='module')
+def groupNoTenantPermissions(db, tenant, userNoTenantPermissions):
+    yield _create_group(db, tenant, 'groupNoTenantPermissions', userNoTenantPermissions, attach_permissions=False)
+
+
 @pytest.fixture(scope='module', autouse=True)
 def tenant(db, permissions):
     with db.scoped_session() as session:
diff --git a/tests/test_tenant_unauthorized.py b/tests/test_tenant_unauthorized.py
new file mode 100644
index 000000000..f87b5dde7
--- /dev/null
+++ b/tests/test_tenant_unauthorized.py
@@ -0,0 +1,88 @@
+from unittest.mock import MagicMock, patch
+import pytest
+from assertpy import assert_that
+from dataall.base.api import bootstrap
+from dataall.base.loader import load_modules, ImportMode
+from dataall.base.context import RequestContext
+from dataall.base.db.exceptions import TenantUnauthorized
+import inspect
+
+
+load_modules(modes={ImportMode.API})
+
+OPT_OUT_MUTATIONS = {
+    'Mutation.updateGroupTenantPermissions': 'admin action. No need for tenant permission check',
+    'Mutation.updateSSMParameter': 'admin action. No need for tenant permission check',
+    'Mutation.createQuicksightDataSourceSet': 'admin action. No need for tenant permission check',
+    'Mutation.startMaintenanceWindow': 'admin action. No need for tenant permission check',
+    'Mutation.stopMaintenanceWindow': 'admin action. No need for tenant permission check',
+    'Mutation.startReindexCatalog': 'admin action. No need for tenant permission check',
+    'Mutation.markNotificationAsRead': 'tenant permissions do not apply to support notifications',
+    'Mutation.deleteNotification': 'tenant permissions do not apply to support notifications',
+    'Mutation.postFeedMessage': 'tenant permissions do not apply to support feed messages',
+    'Mutation.upVote': 'tenant permissions do not apply to support votes',
+    'Mutation.createAttachedMetadataForm': 'outside of this PR to be able to backport to v2.6.2',
+    'Mutation.deleteAttachedMetadataForm': 'outside of this PR to be able to backport to v2.6.2',
+    'Mutation.createRedshiftConnection': 'outside of this PR to be able to backport to v2.6.2',
+    'Mutation.deleteRedshiftConnection': 'outside of this PR to be able to backport to v2.6.2',
+    'Mutation.addConnectionGroupPermission': 'outside of this PR to be able to backport to v2.6.2',
+    'Mutation.deleteConnectionGroupPermission': 'outside of this PR to be able to backport to v2.6.2',
+}
+
+OPT_IN_QUERIES = [
+    'Query.generateEnvironmentAccessToken',
+    'Query.getEnvironmentAssumeRoleUrl',
+    'Query.getSagemakerStudioUserPresignedUrl',
+    'Query.getSagemakerNotebookPresignedUrl',
+    'Query.getDatasetAssumeRoleUrl',
+    'Query.getDatasetPresignedUrl',
+    'Query.getAuthorSession',
+    'Query.getDatasetSharedAssumeRoleUrl',
+    'Query.runAthenaSqlQuery',
+]
+
+ALL_RESOLVERS = {(_type, field) for _type in bootstrap().types for field in _type.fields if field.resolver}
+
+
+@pytest.fixture(scope='function')
+def mock_input_validation(mocker):
+    mocker.patch('dataall.modules.mlstudio.api.resolvers.RequestValidator', MagicMock())
+    mocker.patch(
+        'dataall.modules.mlstudio.services.mlstudio_service.SagemakerStudioCreationRequest.from_dict', MagicMock()
+    )
+    mocker.patch('dataall.modules.notebooks.api.resolvers.RequestValidator', MagicMock())
+    mocker.patch('dataall.modules.notebooks.services.notebook_service.NotebookCreationRequest.from_dict', MagicMock())
+    mocker.patch('dataall.modules.s3_datasets.api.profiling.resolvers._validate_uri', MagicMock())
+    mocker.patch('dataall.modules.s3_datasets.api.storage_location.resolvers._validate_input', MagicMock())
+    mocker.patch('dataall.modules.s3_datasets.api.dataset.resolvers.RequestValidator', MagicMock())
+    mocker.patch(
+        'dataall.core.stacks.db.target_type_repositories.TargetType.get_resource_tenant_permission_name',
+        return_value='MANAGE_ENVIRONMENTS',
+    )
+    mocker.patch('dataall.modules.shares_base.api.resolvers.RequestValidator', MagicMock())
+
+
+@pytest.mark.parametrize(
+    '_type,field',
+    [
+        pytest.param(_type, field, id=f'{_type.name}.{field.name}')
+        for _type, field in ALL_RESOLVERS
+        if _type.name in ['Query', 'Mutation']
+    ],
+)
+@patch('dataall.base.context._request_storage')
+def test_unauthorized_tenant_permissions(
+    mock_local, _type, field, mock_input_validation, db, userNoTenantPermissions, groupNoTenantPermissions
+):
+    if _type.name == 'Mutation' and f'{_type.name}.{field.name}' in OPT_OUT_MUTATIONS.keys():
+        pytest.skip(f'Skipping test for {field.name}: {OPT_OUT_MUTATIONS[f"{_type.name}.{field.name}"]}')
+    if _type.name == 'Query' and f'{_type.name}.{field.name}' not in OPT_IN_QUERIES:
+        pytest.skip(f'Skipping test for {field.name}: This Query does not require a tenant permission check.')
+    assert_that(field.resolver).is_not_none()
+    mock_local.context = RequestContext(
+        db, userNoTenantPermissions.username, [groupNoTenantPermissions.groupUri], userNoTenantPermissions
+    )
+    # Mocking arguments
+    iargs = {arg: MagicMock() for arg in inspect.signature(field.resolver).parameters.keys()}
+    # Assert Unauthorized exception is raised
+    assert_that(field.resolver).raises(TenantUnauthorized).when_called_with(**iargs).contains('UnauthorizedOperation')