Skip to content

Commit

Permalink
feat: Support placeholders for TuningStep (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
ca-nguyen authored Nov 3, 2021
1 parent 2091850 commit 1ea8346
Show file tree
Hide file tree
Showing 3 changed files with 340 additions and 16 deletions.
20 changes: 13 additions & 7 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
where each instance is a different channel of training data.
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
"""
if wait_for_completion:
"""
Expand All @@ -483,19 +486,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateHyperParameterTuningJob)

parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()

if job_name is not None:
parameters['HyperParameterTuningJobName'] = job_name
tuning_parameters['HyperParameterTuningJobName'] = job_name

if 'S3Operations' in parameters:
del parameters['S3Operations']
if 'S3Operations' in tuning_parameters:
del tuning_parameters['S3Operations']

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update tuning parameters with input parameters
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])

kwargs[Field.Parameters.value] = tuning_parameters
super(TuningStep, self).__init__(state_id, **kwargs)


Expand Down
101 changes: 94 additions & 7 deletions tests/integ/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_training_step_with_placeholders(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn):
Expand Down Expand Up @@ -193,7 +192,7 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
model_name = get_resource_name_from_arn(execution_output.get("ModelArn")).split("/")[1]
delete_sagemaker_model(model_name, sagemaker_session)
# End of Cleanup



def test_model_step_with_placeholders(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
Expand Down Expand Up @@ -288,7 +287,6 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_transform_step_with_placeholder(trained_estimator, sfn_client, sfn_role_arn):
Expand Down Expand Up @@ -413,7 +411,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
delete_sagemaker_endpoint_config(endpoint_config_name, sagemaker_session)
delete_sagemaker_model(model.name, sagemaker_session)
# End of Cleanup


def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client, sagemaker_session, sfn_role_arn):
# Setup: Create model and endpoint config for trained estimator in SageMaker
Expand Down Expand Up @@ -456,7 +454,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
delete_sagemaker_endpoint_config(model.name, sagemaker_session)
delete_sagemaker_model(model.name, sagemaker_session)
# End of Cleanup


def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
job_name = generate_job_name()
Expand Down Expand Up @@ -507,7 +505,97 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
"""Integration test: run a TuningStep configured with ExecutionInput placeholders.

The step receives its job name and several tuning-job request fields as
placeholders (via the ``parameters`` argument); concrete values are passed
when the workflow is executed. The test passes when the execution output
reports ``HyperParameterTuningJobStatus == "Completed"``.
"""
# Estimator handed to the tuner.
kmeans = KMeans(
role=sagemaker_role_arn,
instance_count=1,
instance_type=INSTANCE_TYPE,
k=10
)

# Search space given to the tuner.
hyperparameter_ranges = {
"extra_center_factor": IntegerParameter(4, 10),
"mini_batch_size": IntegerParameter(10, 100),
"epochs": IntegerParameter(1, 2),
"init_method": CategoricalParameter(["kmeans++", "random"]),
}

# NOTE: the tuner is built with objective_type "Maximize" and max_parallel_jobs=1,
# while the execution inputs further down supply "Minimize" and 2 for the
# corresponding placeholder-backed fields in `parameters`.
tuner = HyperparameterTuner(
estimator=kmeans,
objective_metric_name="test:msd",
hyperparameter_ranges=hyperparameter_ranges,
objective_type="Maximize",
max_jobs=2,
max_parallel_jobs=1,
)

# Schema of the values that are provided at execution time; every key is
# referenced as a placeholder below (in `parameters` or as the job name).
execution_input = ExecutionInput(schema={
'job_name': str,
'objective_metric_name': str,
'objective_type': str,
'max_jobs': int,
'max_parallel_jobs': int,
'early_stopping_type': str,
'strategy': str,
})

# Request fields passed to TuningStep through `parameters`. Most values are
# execution_input placeholders; 'TrainingInputMode' is a literal.
parameters = {
'HyperParameterTuningJobConfig': {
'HyperParameterTuningJobObjective': {
'MetricName': execution_input['objective_metric_name'],
'Type': execution_input['objective_type']
},
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
'Strategy': execution_input['strategy'],
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
},
'TrainingJobDefinition': {
'AlgorithmSpecification': {
'TrainingInputMode': 'File'
}
}
}

# Build workflow definition
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
data=record_set_for_hyperparameter_tuning, parameters=parameters)
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
workflow_graph = Chain([tuning_step])

with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
# Create workflow and check definition
workflow = create_workflow_and_check_definition(
workflow_graph=workflow_graph,
workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
sfn_client=sfn_client,
sfn_role_arn=sfn_role_arn
)

job_name = generate_job_name()

# Concrete values for every key declared in the ExecutionInput schema.
inputs = {
'job_name': job_name,
'objective_metric_name': 'test:msd',
'objective_type': 'Minimize',
'max_jobs': 2,
'max_parallel_jobs': 2,
'early_stopping_type': 'Off',
'strategy': 'Bayesian',
}

# Execute workflow
execution = workflow.execute(inputs=inputs)
# wait=True is passed so the output is fetched once the execution has finished
# (presumably blocking until then -- confirm against Execution.get_output).
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)


def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
region = boto3.session.Session().region_name
Expand Down Expand Up @@ -561,7 +649,6 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
Expand Down
Loading

0 comments on commit 1ea8346

Please sign in to comment.