diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 39ba49b..b014669 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -57,8 +57,8 @@ Before sending us a pull request, please ensure that: ### Running the Unit Tests 1. Install tox using `pip install tox` -1. Install test dependencies, including coverage, using `pip install .[test]` 1. cd into the aws-step-functions-data-science-sdk-python folder: `cd aws-step-functions-data-science-sdk-python` or `cd /environment/aws-step-functions-data-science-sdk-python` +1. Install test dependencies, including coverage, using `pip install ".[test]"` 1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit` You can also run a single test with the following command: `tox -e py36 -- -s -vv ::` @@ -80,7 +80,7 @@ You should only worry about manually running any new integration tests that you 1. Create a new git branch: ```shell - git checkout -b my-fix-branch master + git checkout -b my-fix-branch ``` 1. Make your changes, **including unit tests** and, if appropriate, integration tests. 1. Include unit tests when you contribute new features or make bug fixes, as they help to: diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index da29b4c..203ed47 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -12,8 +12,31 @@ # permissions and limitations under the License. from __future__ import absolute_import +from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn + +LAMBDA_SERVICE_NAME = "lambda" +GLUE_SERVICE_NAME = "glue" +ECS_SERVICE_NAME = "ecs" +BATCH_SERVICE_NAME = "batch" + + +class LambdaApi(Enum): + Invoke = "invoke" + + +class GlueApi(Enum): + StartJobRun = "startJobRun" + + +class EcsApi(Enum): + RunTask = "runTask" + + +class BatchApi(Enum): + SubmitJob = "submitJob" class LambdaStep(Task): @@ -37,10 +60,22 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ + if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke.waitForTaskToken' + """ + Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, + LambdaApi.Invoke, + IntegrationPattern.WaitForTaskToken) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke' + """ + Example resource arn: arn:aws:states:::lambda:invoke + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, LambdaApi.Invoke) + super(LambdaStep, self).__init__(state_id, **kwargs) @@ -67,9 +102,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun.sync' + """ + Example resource arn: arn:aws:states:::glue:startJobRun.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME, + GlueApi.StartJobRun, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun' + """ + Example resource arn: arn:aws:states:::glue:startJobRun + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME, + GlueApi.StartJobRun) super(GlueStartJobRunStep, self).__init__(state_id, **kwargs) @@ -96,9 +142,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob.sync' + """ + Example resource arn: arn:aws:states:::batch:submitJob.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME, + BatchApi.SubmitJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob' + """ + Example resource arn: arn:aws:states:::batch:submitJob + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME, + BatchApi.SubmitJob) super(BatchSubmitJobStep, self).__init__(state_id, **kwargs) @@ -125,8 +182,19 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask.sync' + """ + Example resource arn: arn:aws:states:::ecs:runTask.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, + EcsApi.RunTask, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask' + """ + Example resource arn: arn:aws:states:::ecs:runTask + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, + EcsApi.RunTask) super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/integration_resources.py b/src/stepfunctions/steps/integration_resources.py new file mode 100644 index 0000000..5223f14 --- /dev/null +++ b/src/stepfunctions/steps/integration_resources.py @@ -0,0 +1,46 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import absolute_import + +from enum import Enum +from stepfunctions.steps.utils import get_aws_partition + + +class IntegrationPattern(Enum): + """ + Integration pattern enum classes for task integration resource arn builder + """ + + WaitForTaskToken = "waitForTaskToken" + WaitForCompletion = "sync" + RequestResponse = "" + + +def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse): + + """ + ARN builder for task integration + Args: + service (str): The service name for the service integration + api (str): The api of the service integration + integration_pattern (IntegrationPattern, optional): The integration pattern for the task. (Default: IntegrationPattern.RequestResponse) + """ + arn = "" + if integration_pattern == IntegrationPattern.RequestResponse: + arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}" + else: + arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}" + return arn + + diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index acc7506..deb5176 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -12,15 +12,31 @@ # permissions and limitations under the License. from __future__ import absolute_import +from enum import Enum from stepfunctions.inputs import ExecutionInput, StepInput from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field from stepfunctions.steps.utils import tags_dict_to_kv_list +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig +SAGEMAKER_SERVICE_NAME = "sagemaker" + + +class SageMakerApi(Enum): + CreateTrainingJob = "createTrainingJob" + CreateTransformJob = "createTransformJob" + CreateModel = "createModel" + CreateEndpointConfig = "createEndpointConfig" + UpdateEndpoint = "updateEndpoint" + CreateEndpoint = "createEndpoint" + CreateHyperParameterTuningJob = "createHyperParameterTuningJob" + CreateProcessingJob = "createProcessingJob" + + class TrainingStep(Task): """ @@ -58,9 +74,20 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non self.job_name = job_name if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateTrainingJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob' + """ + Example resource arn: arn:aws:states:::sagemaker:createTrainingJob + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateTrainingJob) if isinstance(job_name, str): parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) @@ -141,9 +168,20 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateTransformJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob' + """ + Example resource arn: arn:aws:states:::sagemaker:createTransformJob + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateTransformJob) if isinstance(job_name, str): parameters = transform_config( @@ -225,7 +263,13 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No parameters['Tags'] = tags_dict_to_kv_list(tags) kwargs[Field.Parameters.value] = parameters - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel' + + """ + Example resource arn: arn:aws:states:::sagemaker:createModel + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateModel) super(ModelStep, self).__init__(state_id, **kwargs) @@ -266,7 +310,13 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig' + """ + Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateEndpointConfig) + kwargs[Field.Parameters.value] = parameters super(EndpointConfigStep, self).__init__(state_id, **kwargs) @@ -298,9 +348,19 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd parameters['Tags'] = tags_dict_to_kv_list(tags) if update: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint' + """ + Example resource arn: arn:aws:states:::sagemaker:updateEndpoint + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.UpdateEndpoint) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint' + """ + Example resource arn: arn:aws:states:::sagemaker:createEndpoint + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateEndpoint) kwargs[Field.Parameters.value] = parameters @@ -338,9 +398,20 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateHyperParameterTuningJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob' + """ + Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateHyperParameterTuningJob) parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() @@ -387,10 +458,21 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateProcessingJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob' - + """ + Example resource arn: arn:aws:states:::sagemaker:createProcessingJob + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, + SageMakerApi.CreateProcessingJob) + if isinstance(job_name, str): parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) else: diff --git a/src/stepfunctions/steps/service.py b/src/stepfunctions/steps/service.py index 20509f8..6bf155f 100644 --- a/src/stepfunctions/steps/service.py +++ b/src/stepfunctions/steps/service.py @@ -12,8 +12,40 @@ # permissions and limitations under the License. from __future__ import absolute_import +from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn + +DYNAMODB_SERVICE_NAME = "dynamodb" +SNS_SERVICE_NAME = "sns" +SQS_SERVICE_NAME = "sqs" +ELASTICMAPREDUCE_SERVICE_NAME = "elasticmapreduce" + + +class DynamoDBApi(Enum): + GetItem = "getItem" + PutItem = "putItem" + DeleteItem = "deleteItem" + UpdateItem = "updateItem" + + +class SnsApi(Enum): + Publish = "publish" + + +class SqsApi(Enum): + SendMessage = "sendMessage" + + +class ElasticMapReduceApi(Enum): + CreateCluster = "createCluster" + TerminateCluster = "terminateCluster" + AddStep = "addStep" + CancelStep = "cancelStep" + SetClusterTerminationProtection = "setClusterTerminationProtection" + ModifyInstanceFleetByName = "modifyInstanceFleetByName" + ModifyInstanceGroupByName = "modifyInstanceGroupByName" class DynamoDBGetItemStep(Task): @@ -35,7 +67,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:getItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:getItem + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, + DynamoDBApi.GetItem) super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs) @@ -59,7 +97,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:putItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:putItem + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, + DynamoDBApi.PutItem) super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs) @@ -83,7 +127,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:deleteItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:deleteItem + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, + DynamoDBApi.DeleteItem) super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs) @@ -107,7 +157,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:updateItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:updateItem + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, + DynamoDBApi.UpdateItem) super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs) @@ -133,9 +189,20 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish.waitForTaskToken' + """ + Example resource arn: arn:aws:states:::sns:publish.waitForTaskToken + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SNS_SERVICE_NAME, + SnsApi.Publish, + IntegrationPattern.WaitForTaskToken) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish' + """ + Example resource arn: arn:aws:states:::sns:publish + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SNS_SERVICE_NAME, + SnsApi.Publish) super(SnsPublishStep, self).__init__(state_id, **kwargs) @@ -162,9 +229,20 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage.waitForTaskToken' + """ + Example resource arn: arn:aws:states:::sqs:sendMessage.waitForTaskToken + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SQS_SERVICE_NAME, + SqsApi.SendMessage, + IntegrationPattern.WaitForTaskToken) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage' + """ + Example resource arn: arn:aws:states:::sqs:sendMessage + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(SQS_SERVICE_NAME, + SqsApi.SendMessage) super(SqsSendMessageStep, self).__init__(state_id, **kwargs) @@ -190,9 +268,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster.sync' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:createCluster.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.CreateCluster, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:createCluster + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.CreateCluster) super(EmrCreateClusterStep, self).__init__(state_id, **kwargs) @@ -218,9 +307,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster.sync' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.TerminateCluster, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.TerminateCluster) super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs) @@ -246,9 +346,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep.sync' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:addStep.sync + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.AddStep, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:addStep + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.AddStep) super(EmrAddStepStep, self).__init__(state_id, **kwargs) @@ -272,7 +383,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:cancelStep' + + """ + Example resource arn: arn:aws:states:::elasticmapreduce:cancelStep + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.CancelStep) super(EmrCancelStepStep, self).__init__(state_id, **kwargs) @@ -296,7 +413,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:setClusterTerminationProtection' + + """ + Example resource arn: arn:aws:states:::elasticmapreduce:setClusterTerminationProtection + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.SetClusterTerminationProtection) super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs) @@ -320,7 +443,13 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName' + + """ + Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.ModifyInstanceFleetByName) super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs) @@ -344,7 +473,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName' - super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) + """ + Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName + """ + + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, + ElasticMapReduceApi.ModifyInstanceGroupByName) + super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 8e35d36..6f44481 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -12,6 +12,36 @@ # permissions and limitations under the License. from __future__ import absolute_import +import boto3 +import logging + +logger = logging.getLogger('stepfunctions') + + def tags_dict_to_kv_list(tags_dict): - kv_list = [{"Key": k, "Value": v} for k,v in tags_dict.items()] - return kv_list \ No newline at end of file + kv_list = [{"Key": k, "Value": v} for k, v in tags_dict.items()] + return kv_list + + +def get_aws_partition(): + + """ + Returns the aws partition for the current boto3 session. + Defaults to 'aws' if the region could not be detected. + """ + + partitions = boto3.session.Session().get_available_partitions() + cur_region = boto3.session.Session().region_name + cur_partition = "aws" + + if cur_region is None: + logger.warning("No region detected for the boto3 session. Using default partition: aws") + return cur_partition + + for partition in partitions: + regions = boto3.session.Session().get_available_regions("stepfunctions", partition) + if cur_region in regions: + cur_partition = partition + return cur_partition + + return cur_partition diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index cb3bee0..b2bb979 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -21,6 +21,7 @@ from sagemaker import Session from sagemaker.amazon import pca from sagemaker.sklearn.processing import SKLearnProcessor +from stepfunctions.steps.utils import get_aws_partition from tests.integ import DATA_DIR @pytest.fixture(scope="session") @@ -43,11 +44,11 @@ def aws_account_id(): @pytest.fixture(scope="session") def sfn_role_arn(aws_account_id): - return "arn:aws:iam::{}:role/StepFunctionsMLWorkflowExecutionFullAccess".format(aws_account_id) + return f"arn:{get_aws_partition()}:iam::{aws_account_id}:role/StepFunctionsMLWorkflowExecutionFullAccess" @pytest.fixture(scope="session") def sagemaker_role_arn(aws_account_id): - return "arn:aws:iam::{}:role/SageMakerRole".format(aws_account_id) + return f"arn:{get_aws_partition()}:iam::{aws_account_id}:role/SageMakerRole" @pytest.fixture(scope="session") def pca_estimator_fixture(sagemaker_role_arn): diff --git a/tests/integ/test_state_machine_definition.py b/tests/integ/test_state_machine_definition.py index ac5e7ac..0fec9f5 100644 --- a/tests/integ/test_state_machine_definition.py +++ b/tests/integ/test_state_machine_definition.py @@ -19,8 +19,10 @@ from sagemaker.image_uris import retrieve from stepfunctions import steps from stepfunctions.workflow import Workflow +from stepfunctions.steps.utils import get_aws_partition from tests.integ.utils import state_machine_delete_wait + @pytest.fixture(scope="module") def training_job_parameters(sagemaker_session, sagemaker_role_arn, record_set_fixture): parameters = { @@ -62,6 +64,7 @@ def training_job_parameters(sagemaker_session, sagemaker_role_arn, record_set_fi return parameters + def workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, output_result, inputs=None): state_machine_arn = workflow.create() execution = workflow.execute(inputs=inputs) @@ -380,7 +383,7 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn): def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): task_state_name = "TaskState" final_state_name = "FinalState" - resource = "arn:aws:states:::sagemaker:createTrainingJob.sync" + resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" task_state_result = "Task State Result" asl_state_machine_definition = { "StartAt": task_state_name, @@ -426,7 +429,7 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par task_failed_state_name = "Task Failed End" all_error_state_name = "Catch All End" catch_state_result = "Catch Result" - task_resource = "arn:aws:states:::sagemaker:createTrainingJob.sync" + task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" # change the parameters to cause task state to fail training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" @@ -482,7 +485,7 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par interval_seconds = 1 max_attempts = 2 backoff_rate = 2 - task_resource = "arn:aws:states:::sagemaker:createTrainingJob.sync" + task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" # change the parameters to cause task state to fail training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" diff --git a/tests/unit/test_compute_steps.py b/tests/unit/test_compute_steps.py index 030cf35..368010a 100644 --- a/tests/unit/test_compute_steps.py +++ b/tests/unit/test_compute_steps.py @@ -13,10 +13,13 @@ from __future__ import absolute_import import pytest +import boto3 +from unittest.mock import patch from stepfunctions.steps.compute import LambdaStep, GlueStartJobRunStep, BatchSubmitJobStep, EcsRunTaskStep +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_lambda_step_creation(): step = LambdaStep('Echo') @@ -45,6 +48,8 @@ def test_lambda_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_glue_start_job_run_step_creation(): step = GlueStartJobRunStep('Glue Job', wait_for_completion=False) @@ -67,6 +72,8 @@ def test_glue_start_job_run_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_batch_submit_job_step_creation(): step = BatchSubmitJobStep('Batch Job', wait_for_completion=False) @@ -91,6 +98,8 @@ def test_batch_submit_job_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_ecs_run_task_step_creation(): step = EcsRunTaskStep('Ecs Job', wait_for_completion=False) diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 2123a3d..c7ab502 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -14,6 +14,7 @@ import pytest import sagemaker +import boto3 from sagemaker.sklearn.estimator import SKLearn from unittest.mock import MagicMock, patch @@ -27,6 +28,7 @@ PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1' LINEAR_LEARNER_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1' + @pytest.fixture def pca_estimator(): s3_output_location = 's3://sagemaker/models' @@ -52,6 +54,7 @@ def pca_estimator(): return pca + @pytest.fixture def sklearn_preprocessor(): script_path = 'sklearn_abalone_featurizer.py' @@ -75,6 +78,7 @@ def sklearn_preprocessor(): return sklearn_preprocessor + @pytest.fixture def linear_learner_estimator(): s3_output_location = 's3://sagemaker/models' @@ -101,7 +105,9 @@ def linear_learner_estimator(): return ll_estimator + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_pca_training_pipeline(pca_estimator): s3_inputs = { 'train': 's3://sagemaker/pca/train' @@ -228,7 +234,9 @@ def test_pca_training_pipeline(pca_estimator): workflow.execute.assert_called_with(name=job_name, inputs=inputs) + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator): s3_inputs = { 'train': 's3://sagemaker-us-east-1/inference/train' diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index d45ee79..2ce9d3b 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -37,6 +37,7 @@ DEFAULT_TAGS = {'Purpose': 'unittests'} DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}] + @pytest.fixture def pca_estimator(): s3_output_location = 's3://sagemaker/models' @@ -63,6 +64,7 @@ def pca_estimator(): return pca + @pytest.fixture def pca_estimator_with_debug_hook(): s3_output_location = 's3://sagemaker/models' @@ -139,6 +141,7 @@ def pca_estimator_with_falsy_debug_hook(): return pca + @pytest.fixture def pca_model(): model_data = 's3://sagemaker/models/pca.tar.gz' @@ -149,6 +152,7 @@ def pca_model(): name='pca-model' ) + @pytest.fixture def pca_transformer(pca_model): return Transformer( @@ -158,6 +162,7 @@ def pca_transformer(pca_model): output_path='s3://sagemaker/transform-output' ) + @pytest.fixture def tensorflow_estimator(): s3_output_location = 's3://sagemaker/models' @@ -190,6 +195,7 @@ def tensorflow_estimator(): return estimator + @pytest.fixture def sklearn_processor(): sagemaker_session = MagicMock() @@ -206,7 +212,9 @@ def sklearn_processor(): return processor + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation(pca_estimator): step = TrainingStep('Training', estimator=pca_estimator, @@ -256,7 +264,9 @@ def test_training_step_creation(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): step = TrainingStep('Training', estimator=pca_estimator_with_debug_hook, @@ -315,7 +325,9 @@ def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_debug_hook): step = TrainingStep('Training', estimator=pca_estimator_with_falsy_debug_hook, @@ -352,7 +364,9 @@ def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_d 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob') model_step = ModelStep('Training - Save Model', training_step.get_expected_model(model_name=training_step.output()['TrainingJobName'])) @@ -404,7 +418,9 @@ def test_training_step_creation_with_model(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_framework(tensorflow_estimator): step = TrainingStep('Training', estimator=tensorflow_estimator, @@ -466,6 +482,8 @@ def test_training_step_creation_with_framework(tensorflow_estimator): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_transform_step_creation(pca_transformer): step = TransformStep('Inference', transformer=pca_transformer, @@ -518,7 +536,9 @@ def test_transform_step_creation(pca_transformer): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_get_expected_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob') expected_model = training_step.get_expected_model() @@ -538,7 +558,9 @@ def test_get_expected_model(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_get_expected_model_with_framework_estimator(tensorflow_estimator): training_step = TrainingStep('Training', estimator=tensorflow_estimator, @@ -569,6 +591,8 @@ def test_get_expected_model_with_framework_estimator(tensorflow_estimator): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_model_step_creation(pca_model): step = ModelStep('Create model', model=pca_model, model_name='pca-model', tags=DEFAULT_TAGS) assert step.to_dict() == { @@ -587,6 +611,8 @@ def test_model_step_creation(pca_model): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_endpoint_config_step_creation(pca_model): data_capture_config = DataCaptureConfig( enable_capture=True, @@ -629,6 +655,8 @@ def test_endpoint_config_step_creation(pca_model): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_endpoint_step_creation(pca_model): step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', tags=DEFAULT_TAGS) assert step.to_dict() == { @@ -654,6 +682,8 @@ def test_endpoint_step_creation(pca_model): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_processing_step_creation(sklearn_processor): inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] outputs = [ diff --git a/tests/unit/test_service_steps.py b/tests/unit/test_service_steps.py index 64809e9..6576aaf 100644 --- a/tests/unit/test_service_steps.py +++ b/tests/unit/test_service_steps.py @@ -13,12 +13,15 @@ from __future__ import absolute_import import pytest +import boto3 +from unittest.mock import patch from stepfunctions.steps.service import DynamoDBGetItemStep, DynamoDBPutItemStep, DynamoDBUpdateItemStep, DynamoDBDeleteItemStep from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep from stepfunctions.steps.service import EmrCreateClusterStep, EmrTerminateClusterStep, EmrAddStepStep, EmrCancelStepStep, EmrSetClusterTerminationProtectionStep, EmrModifyInstanceFleetByNameStep, EmrModifyInstanceGroupByNameStep +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_sns_publish_step_creation(): step = SnsPublishStep('Publish to SNS', parameters={ 'TopicArn': 'arn:aws:sns:us-east-1:123456789012:myTopic', @@ -57,6 +60,7 @@ def test_sns_publish_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_sqs_send_message_step_creation(): step = SqsSendMessageStep('Send to SQS', parameters={ 'QueueUrl': 'https://sqs.us-east-1.amazonaws.com/123456789012/myQueue', @@ -95,6 +99,7 @@ def test_sqs_send_message_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_get_item_step_creation(): step = DynamoDBGetItemStep('Read Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -120,6 +125,7 @@ def test_dynamodb_get_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_put_item_step_creation(): step = DynamoDBPutItemStep('Add Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -145,6 +151,7 @@ def test_dynamodb_put_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_delete_item_step_creation(): step = DynamoDBDeleteItemStep('Delete Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -170,6 +177,7 @@ def test_dynamodb_delete_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_update_item_step_creation(): step = DynamoDBUpdateItemStep('Update Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -203,6 +211,7 @@ def test_dynamodb_update_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_create_cluster_step_creation(): step = EmrCreateClusterStep('Create EMR cluster', parameters={ 'Name': 'MyWorkflowCluster', @@ -372,6 +381,7 @@ def test_emr_create_cluster_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_terminate_cluster_step_creation(): step = EmrTerminateClusterStep('Terminate EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId' @@ -399,6 +409,8 @@ def test_emr_terminate_cluster_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_add_step_step_creation(): step = EmrAddStepStep('Add step to EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -498,6 +510,8 @@ def test_emr_add_step_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_cancel_step_step_creation(): step = EmrCancelStepStep('Cancel step from EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -514,6 +528,8 @@ def test_emr_cancel_step_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_set_cluster_termination_protection_step_creation(): step = EmrSetClusterTerminationProtectionStep('Set termination protection for EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -530,6 +546,8 @@ def test_emr_set_cluster_termination_protection_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_modify_instance_fleet_by_name_step_creation(): step = EmrModifyInstanceFleetByNameStep('Modify Instance Fleet by name for EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -554,6 +572,8 @@ def test_emr_modify_instance_fleet_by_name_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_modify_instance_group_by_name_step_creation(): step = EmrModifyInstanceGroupByNameStep('Modify Instance Group by name for EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index 5b338bc..91b5a0e 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -23,15 +23,16 @@ def test_to_pascalcase(): assert 'InputPath' == to_pascalcase('input_path') + def test_state_creation(): state = State( state_id='StartState', state_type='Void', - comment = 'This is a comment', - input_path = '$.Input', - output_path = '$.Output', - parameters = {'Key': 'Value'}, - result_path = '$.Result' + comment='This is a comment', + input_path='$.Input', + output_path='$.Output', + parameters={'Key': 'Value'}, + result_path='$.Result' ) assert state.to_dict() == { @@ -49,6 +50,7 @@ def test_state_creation(): with pytest.raises(TypeError): State(state_id='State', unknown_attribute=True) + def test_pass_state_creation(): pass_state = Pass('Pass', result='Pass') assert pass_state.state_id == 'Pass' @@ -58,6 +60,7 @@ def test_pass_state_creation(): 'End': True } + def test_verify_pass_state_fields(): pass_state = Pass( state_id='Pass', @@ -77,6 +80,7 @@ def test_verify_pass_state_fields(): with pytest.raises(TypeError): Pass('Pass', unknown_field='Unknown Field') + def test_succeed_state_creation(): succeed_state = Succeed( state_id='Succeed', @@ -89,10 +93,12 @@ def test_succeed_state_creation(): 'Comment': 'This is a comment' } + def test_verify_succeed_state_fields(): with pytest.raises(TypeError): Succeed('Succeed', unknown_field='Unknown Field') + def test_fail_creation(): fail_state = Fail( state_id='Fail', @@ -111,10 +117,12 @@ def test_fail_creation(): 'Cause': 'Kaiju attack' } + def test_verify_fail_state_fields(): with pytest.raises(TypeError): Fail('Succeed', unknown_field='Unknown Field') + def test_wait_state_creation(): wait_state = Wait( state_id='Wait', @@ -140,6 +148,7 @@ def test_wait_state_creation(): 'End': True } + def test_verify_wait_state_fields(): with pytest.raises(ValueError): Wait( @@ -148,6 +157,7 @@ def test_verify_wait_state_fields(): seconds_path='$.SecondsPath' ) + def test_choice_state_creation(): choice_state = Choice('Choice', input_path='$.Input') choice_state.add_choice(ChoiceRule.IsPresent("$.StringVariable1", True), Pass("End State 1")) @@ -183,6 +193,7 @@ def test_choice_state_creation(): with pytest.raises(TypeError): Choice('Choice', unknown_field='Unknown Field') + def test_task_state_creation(): task_state = Task('Task', resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda') task_state.add_retry(Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2)) @@ -215,6 +226,7 @@ def test_task_state_creation(): 'End': True } + def test_task_state_creation_with_dynamic_timeout(): task_state = Task( 'Task', @@ -230,6 +242,7 @@ def test_task_state_creation_with_dynamic_timeout(): 'End': True } + def test_task_state_create_fail_for_duplicated_dynamic_timeout_fields(): with pytest.raises(ValueError): Task( @@ -247,6 +260,7 @@ def test_task_state_create_fail_for_duplicated_dynamic_timeout_fields(): heartbeat_seconds_path='$.heartbeat', ) + def test_parallel_state_creation(): parallel_state = Parallel('Parallel') parallel_state.add_branch(Pass('Branch 1')) @@ -288,6 +302,7 @@ def test_parallel_state_creation(): 'End': True } + def test_map_state_creation(): map_state = Map('Map', iterator=Pass('FirstIteratorState'), items_path='$', max_concurrency=0) assert map_state.to_dict() == { @@ -306,9 +321,11 @@ def test_map_state_creation(): 'End': True } + def test_nested_chain_is_now_allowed(): chain = Chain([Chain([Pass('S1')])]) + def test_catch_creation(): catch = Catch(error_equals=['States.ALL'], next_step=Fail('End')) assert catch.to_dict() == { @@ -316,6 +333,7 @@ def test_catch_creation(): 'Next': 'End' } + def test_append_states_after_terminal_state_will_fail(): with pytest.raises(ValueError): chain = Chain() @@ -400,7 +418,7 @@ def test_chaining_choice_with_existing_default_overrides_value(caplog): assert s2_choice.default == s1_pass assert s2_choice.next_step is None # Choice steps do not have next_step - + def test_catch_fail_for_unsupported_state(): s1 = Pass('Step - One') @@ -409,7 +427,6 @@ def test_catch_fail_for_unsupported_state(): def test_retry_fail_for_unsupported_state(): - c1 = Choice('My Choice') with pytest.raises(ValueError): @@ -453,3 +470,6 @@ def test_default_paths_not_converted_to_null(): assert '"ResultPath": null' not in task_state.to_json() assert '"InputPath": null' not in task_state.to_json() assert '"OutputPath": null' not in task_state.to_json() + + + diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py new file mode 100644 index 0000000..6eb0885 --- /dev/null +++ b/tests/unit/test_steps_utils.py @@ -0,0 +1,53 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +# Test if boto3 session can fetch correct aws partition info from test environment + +from stepfunctions.steps.utils import get_aws_partition +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn +import boto3 +from unittest.mock import patch +from enum import Enum + + +testService = "sagemaker" + + +class TestApi(Enum): + CreateTrainingJob = "createTrainingJob" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_util_get_aws_partition_aws(): + cur_partition = get_aws_partition() + assert cur_partition == "aws" + + +@patch.object(boto3.session.Session, 'region_name', 'cn-north-1') +def test_util_get_aws_partition_aws_cn(): + cur_partition = get_aws_partition() + assert cur_partition == "aws-cn" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sagemaker_no_wait(): + arn = get_service_integration_arn(testService, TestApi.CreateTrainingJob) + assert arn == "arn:aws:states:::sagemaker:createTrainingJob" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sagemaker_wait_completion(): + arn = get_service_integration_arn(testService, TestApi.CreateTrainingJob, + IntegrationPattern.WaitForCompletion) + assert arn == "arn:aws:states:::sagemaker:createTrainingJob.sync" +