Skip to content

Commit

Permalink
separate checksum config check and workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianyi Wang committed Sep 29, 2024
1 parent b93f439 commit 72c3ba6
Show file tree
Hide file tree
Showing 40 changed files with 236 additions and 575 deletions.
2 changes: 1 addition & 1 deletion .changelog/9ebe24c4791541e0840da49eab6f9d97.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"id": "9ebe24c4-7915-41e0-840d-a49eab6f9d97",
"type": "feature",
"description": "add client cfg to opt-in/out checksum behavior and change its default algorithm",
"description": "This feature adds new client cfg fields so user can opt-in/out request/response checksum calculation/validation for operation modeled with checksum trait. The default MD5 checksum algorithm is replaced with CRC32.",
"modules": [
".",
"config",
Expand Down
29 changes: 8 additions & 21 deletions aws/checksum.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,26 @@
package aws

// RequestChecksumCalculation controls request checksum calculation workflow
type RequestChecksumCalculation string
type RequestChecksumCalculation int

const (
// RequestChecksumCalculationWhenSupported indicates request checksum should be calculated if modeled
RequestChecksumCalculationWhenSupported RequestChecksumCalculation = "whenSupported"
// RequestChecksumCalculationWhenSupported indicates request checksum should be calculated if
// client operation model has request checksum trait
RequestChecksumCalculationWhenSupported RequestChecksumCalculation = 1

// RequestChecksumCalculationWhenRequired indicates request checksum should be calculated
// if modeled and user set an algorithm
RequestChecksumCalculationWhenRequired = "whenRequired"
RequestChecksumCalculationWhenRequired = 2
)

// ResponseChecksumValidation controls response checksum validation workflow
type ResponseChecksumValidation string
type ResponseChecksumValidation int

const (
// ResponseChecksumValidationWhenSupported indicates response checksum should be validated if modeled
ResponseChecksumValidationWhenSupported ResponseChecksumValidation = "whenSupported"
ResponseChecksumValidationWhenSupported ResponseChecksumValidation = 1

// ResponseChecksumValidationWhenRequired indicates response checksum should be validated if modeled
// and user enable that in vlidation mode cfg
ResponseChecksumValidationWhenRequired = "whenRequired"
)

// RequireChecksum indicates if a checksum needs calculated/validated for a request/response
type RequireChecksum string

const (
// RequireChecksumTrue indicates checksum should be calculated/validated
RequireChecksumTrue RequireChecksum = "true"

// RequireChecksumFalse indicates checksum should not be calculated/validated
RequireChecksumFalse RequireChecksum = "false"

// RequireChecksumPending indicates further check is needed to decide
RequireChecksumPending RequireChecksum = "pending"
ResponseChecksumValidationWhenRequired = 2
)
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ private void writeInputMiddlewareHelper(
writer.write("""
return $T(stack, $T{
GetAlgorithm: $L,
RequireChecksum: $T,
RequireChecksum: $L,
RequestChecksumCalculation: options.RequestChecksumCalculation,
EnableTrailingChecksum: $L,
EnableComputeSHA256PayloadHash: true,
Expand All @@ -275,23 +275,14 @@ private void writeInputMiddlewareHelper(
AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(),
hasRequestAlgorithmMember ?
getRequestAlgorithmAccessorFuncName(operationName) : "nil",
getInputRequireChecksum(isRequestChecksumRequired, hasRequestAlgorithmMember),
isRequestChecksumRequired,
supportsRequestTrailingChecksum,
supportsDecodedContentLengthHeader);
}
);
writer.insertTrailingNewline();
}

private Symbol getInputRequireChecksum(boolean isRequestChecksumRequired, boolean hasRequestAlgorithmMember) {
if (isRequestChecksumRequired) {
return SdkGoTypes.Aws.RequireChecksumTrue;
} else if (hasRequestAlgorithmMember) {
return SdkGoTypes.Aws.RequireChecksumPending;
}
return SdkGoTypes.Aws.RequireChecksumFalse;
}

private void writeOutputMiddlewareHelper(
GoWriter writer,
Model model,
Expand All @@ -314,7 +305,6 @@ private void writeOutputMiddlewareHelper(
writer.write("""
return $T(stack, $T{
GetValidationMode: $L,
RequireChecksum: $T,
ResponseChecksumValidation: options.ResponseChecksumValidation,
ValidationAlgorithms: $L,
IgnoreMultipartValidation: $L,
Expand All @@ -326,20 +316,13 @@ private void writeOutputMiddlewareHelper(
SymbolUtils.createValueSymbolBuilder("OutputMiddlewareOptions",
AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(),
getRequestValidationModeAccessorFuncName(operationName),
getOutputRequireChecksum(responseAlgorithms),
convertToGoStringList(responseAlgorithms),
ignoreMultipartChecksumValidationMap.getOrDefault(
service.toShapeId(), new HashSet<>()).contains(operation.toShapeId())
);
});
writer.insertTrailingNewline();
}
private Symbol getOutputRequireChecksum(List<String> responseAlgorithms) {
if (responseAlgorithms.isEmpty()) {
return SdkGoTypes.Aws.RequireChecksumFalse;
}
return SdkGoTypes.Aws.RequireChecksumPending;
}

private String convertToGoStringList(List<String> list) {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ public static final class Aws {
public static final Symbol RequestChecksumCalculation = AwsGoDependency.AWS_CORE.valueSymbol("RequestChecksumCalculation");
public static final Symbol ResponseChecksumValidation = AwsGoDependency.AWS_CORE.valueSymbol("ResponseChecksumValidation");

public static final Symbol RequireChecksumTrue = AwsGoDependency.AWS_CORE.valueSymbol("RequireChecksumTrue");
public static final Symbol RequireChecksumFalse = AwsGoDependency.AWS_CORE.valueSymbol("RequireChecksumFalse");
public static final Symbol RequireChecksumPending = AwsGoDependency.AWS_CORE.valueSymbol("RequireChecksumPending");


public static final class Middleware {
public static final Symbol GetRequiresLegacyEndpoints = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetRequiresLegacyEndpoints");
public static final Symbol GetSigningName = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetSigningName");
Expand Down
4 changes: 2 additions & 2 deletions config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,11 @@ func (c EnvConfig) getAccountIDEndpointMode(context.Context) (aws.AccountIDEndpo
}

func (c EnvConfig) getRequestChecksumCalculation(context.Context) (aws.RequestChecksumCalculation, bool, error) {
return c.RequestChecksumCalculation, len(c.RequestChecksumCalculation) > 0, nil
return c.RequestChecksumCalculation, c.RequestChecksumCalculation > 0, nil
}

func (c EnvConfig) getResponseChecksumValidation(context.Context) (aws.ResponseChecksumValidation, bool, error) {
return c.ResponseChecksumValidation, len(c.ResponseChecksumValidation) > 0, nil
return c.ResponseChecksumValidation, c.ResponseChecksumValidation > 0, nil
}

// GetRetryMaxAttempts returns the value of AWS_MAX_ATTEMPTS if was specified,
Expand Down
8 changes: 4 additions & 4 deletions config/load_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,11 @@ func (o LoadOptions) getAccountIDEndpointMode(ctx context.Context) (aws.AccountI
}

func (o LoadOptions) getRequestChecksumCalculation(ctx context.Context) (aws.RequestChecksumCalculation, bool, error) {
return o.RequestChecksumCalculation, len(o.RequestChecksumCalculation) > 0, nil
return o.RequestChecksumCalculation, o.RequestChecksumCalculation > 0, nil
}

func (o LoadOptions) getResponseChecksumValidation(ctx context.Context) (aws.ResponseChecksumValidation, bool, error) {
return o.ResponseChecksumValidation, len(o.ResponseChecksumValidation) > 0, nil
return o.ResponseChecksumValidation, o.ResponseChecksumValidation > 0, nil
}

// WithRegion is a helper function to construct functional options
Expand Down Expand Up @@ -359,7 +359,7 @@ func WithAccountIDEndpointMode(m aws.AccountIDEndpointMode) LoadOptionsFunc {
// that sets RequestChecksumCalculation on config's LoadOptions
func WithRequestChecksumCalculation(c aws.RequestChecksumCalculation) LoadOptionsFunc {
return func(o *LoadOptions) error {
if c != "" {
if c > 0 {
o.RequestChecksumCalculation = c
}
return nil
Expand All @@ -370,7 +370,7 @@ func WithRequestChecksumCalculation(c aws.RequestChecksumCalculation) LoadOption
// that sets ResponseChecksumValidation on config's LoadOptions
func WithResponseChecksumValidation(v aws.ResponseChecksumValidation) LoadOptionsFunc {
return func(o *LoadOptions) error {
if v != "" {
if v > 0 {
o.ResponseChecksumValidation = v
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1278,11 +1278,11 @@ func (c SharedConfig) getAccountIDEndpointMode(ctx context.Context) (aws.Account
}

func (c SharedConfig) getRequestChecksumCalculation(ctx context.Context) (aws.RequestChecksumCalculation, bool, error) {
return c.RequestChecksumCalculation, len(c.RequestChecksumCalculation) > 0, nil
return c.RequestChecksumCalculation, c.RequestChecksumCalculation > 0, nil
}

func (c SharedConfig) getResponseChecksumValidation(ctx context.Context) (aws.ResponseChecksumValidation, bool, error) {
return c.ResponseChecksumValidation, len(c.ResponseChecksumValidation) > 0, nil
return c.ResponseChecksumValidation, c.ResponseChecksumValidation > 0, nil
}

func updateDefaultsMode(mode *aws.DefaultsMode, section ini.Section, key string) error {
Expand Down
26 changes: 10 additions & 16 deletions service/internal/checksum/middleware_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ type InputMiddlewareOptions struct {
// and true, or false if no algorithm is specified.
GetAlgorithm func(interface{}) (string, bool)

// Whether operation model forces middleware to compute the input payload's checksum. The
// request will fail if the algorithm is not specified or unable to compute
// the checksum.
RequireChecksum aws.RequireChecksum
// Whether operation model forces middleware to compute the input payload's checksum.
RequireChecksum bool

// User config to opt-in/out request checksum calculation
RequestChecksumCalculation aws.RequestChecksumCalculation
Expand Down Expand Up @@ -76,7 +74,9 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions)

// Initial checksum configuration look up middleware
err = stack.Initialize.Add(&setupInputContext{
GetAlgorithm: options.GetAlgorithm,
GetAlgorithm: options.GetAlgorithm,
RequireChecksum: options.RequireChecksum,
RequestChecksumCalculation: options.RequestChecksumCalculation,
}, middleware.Before)
if err != nil {
return err
Expand All @@ -85,8 +85,6 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions)
stack.Build.Remove("ContentChecksum")

inputChecksum := &computeInputPayloadChecksum{
RequireChecksum: options.RequireChecksum,
RequestChecksumCalculation: options.RequestChecksumCalculation,
EnableTrailingChecksum: options.EnableTrailingChecksum,
EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash,
EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader,
Expand All @@ -99,8 +97,6 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions)
if options.EnableTrailingChecksum {
trailerMiddleware := &addInputChecksumTrailer{
EnableTrailingChecksum: inputChecksum.EnableTrailingChecksum,
RequireChecksum: inputChecksum.RequireChecksum,
RequestChecksumCalculation: inputChecksum.RequestChecksumCalculation,
EnableComputePayloadHash: inputChecksum.EnableComputePayloadHash,
EnableDecodedContentLengthHeader: inputChecksum.EnableDecodedContentLengthHeader,
}
Expand Down Expand Up @@ -128,12 +124,11 @@ type OutputMiddlewareOptions struct {
// GetValidationMode is a function to get the checksum validation
// mode of the output payload from the input parameters.
//
// Given the input parameter value, the function must return the validation
// mode and true, or false if no mode is specified.
// Given the input parameter value, the function must return the validation mode
GetValidationMode func(interface{}) (string, bool)

// Whether operation model forces middleware to validate checksum
RequireChecksum aws.RequireChecksum
RequireChecksum bool

// User config to opt-in/out response checksum validation
ResponseChecksumValidation aws.ResponseChecksumValidation
Expand All @@ -146,7 +141,7 @@ type OutputMiddlewareOptions struct {
ValidationAlgorithms []string

// If set the middleware will ignore output multipart checksums. Otherwise
// an checksum format error will be returned by the middleware.
// a checksum format error will be returned by the middleware.
IgnoreMultipartValidation bool

// When set the middleware will log when output does not have checksum or
Expand All @@ -162,7 +157,8 @@ type OutputMiddlewareOptions struct {
// checksum.
func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error {
err := stack.Initialize.Add(&setupOutputContext{
GetValidationMode: options.GetValidationMode,
GetValidationMode: options.GetValidationMode,
ResponseChecksumValidation: options.ResponseChecksumValidation,
}, middleware.Before)
if err != nil {
return err
Expand All @@ -173,8 +169,6 @@ func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOption

m := &validateOutputPayloadChecksum{
Algorithms: algorithms,
RequireChecksum: options.RequireChecksum,
ResponseChecksumValidation: options.ResponseChecksumValidation,
IgnoreMultipartValidation: options.IgnoreMultipartValidation,
LogMultipartValidationSkipped: options.LogMultipartValidationSkipped,
LogValidationSkipped: options.LogValidationSkipped,
Expand Down
7 changes: 1 addition & 6 deletions service/internal/checksum/middleware_add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"reflect"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
Expand Down Expand Up @@ -64,7 +63,7 @@ func TestAddInputMiddleware(t *testing.T) {
return string(AlgorithmCRC32), true
},
EnableTrailingChecksum: true,
RequireChecksum: aws.RequireChecksumTrue,
RequireChecksum: true,
},
expectMiddleware: []string{
"test",
Expand All @@ -88,7 +87,6 @@ func TestAddInputMiddleware(t *testing.T) {
},
},
expectFinalize: &computeInputPayloadChecksum{
RequireChecksum: aws.RequireChecksumTrue,
EnableTrailingChecksum: true,
},
},
Expand Down Expand Up @@ -168,9 +166,6 @@ func TestAddInputMiddleware(t *testing.T) {
var computeInput *computeInputPayloadChecksum
if c.expectFinalize != nil && ok {
computeInput = finalizeMW.(*computeInputPayloadChecksum)
if e, a := c.expectFinalize.RequireChecksum, computeInput.RequireChecksum; e != a {
t.Errorf("expect %v require checksum, got %v", e, a)
}
if e, a := c.expectFinalize.EnableTrailingChecksum, computeInput.EnableTrailingChecksum; e != a {
t.Errorf("expect %v enable trailing checksum, got %v", e, a)
}
Expand Down
Loading

0 comments on commit 72c3ba6

Please sign in to comment.