From 9c621d1c18663498ed554b9ec305e6913e22b72c Mon Sep 17 00:00:00 2001 From: Luis Madrigal <599908+Madrigal@users.noreply.github.com> Date: Wed, 28 Aug 2024 13:30:10 -0400 Subject: [PATCH] Add a `PresignPostObject` to the `PresignClient` for `s3.PutObject` operation (#2758) * Add a `PresignPostObject` to the `PresignClient` for `s3.PutObject` operation --- .../f80f134492ef472493f6e5090b404bc2.json | 8 + .../integrationtest/s3/presign_post_test.go | 187 +++++++ .../internal/integrationtest/s3/sample.txt | 1 + .../s3shared/integ_test_setup.go | 2 +- service/s3/presign_post.go | 433 ++++++++++++++++ service/s3/presign_post_test.go | 489 ++++++++++++++++++ 6 files changed, 1119 insertions(+), 1 deletion(-) create mode 100644 .changelog/f80f134492ef472493f6e5090b404bc2.json create mode 100644 service/internal/integrationtest/s3/presign_post_test.go create mode 100644 service/internal/integrationtest/s3/sample.txt create mode 100644 service/s3/presign_post.go create mode 100644 service/s3/presign_post_test.go diff --git a/.changelog/f80f134492ef472493f6e5090b404bc2.json b/.changelog/f80f134492ef472493f6e5090b404bc2.json new file mode 100644 index 00000000000..328dbd245d3 --- /dev/null +++ b/.changelog/f80f134492ef472493f6e5090b404bc2.json @@ -0,0 +1,8 @@ +{ + "id": "f80f1344-92ef-4724-93f6-e5090b404bc2", + "type": "feature", + "description": "Add presignPost for s3 PutObject", + "modules": [ + "service/s3" + ] +} \ No newline at end of file diff --git a/service/internal/integrationtest/s3/presign_post_test.go b/service/internal/integrationtest/s3/presign_post_test.go new file mode 100644 index 00000000000..18fc7e38e0b --- /dev/null +++ b/service/internal/integrationtest/s3/presign_post_test.go @@ -0,0 +1,187 @@ +//go:build integration +// +build integration + +package s3 + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + "os" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/internal/integrationtest" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func TestInteg_PresigPost(t *testing.T) { + + const filePath = "sample.txt" + + cases := map[string]struct { + params s3.PutObjectInput + conditions []interface{} + expectedStatusCode int + }{ + "standard": { + params: s3.PutObjectInput{}, + }, + "extra conditions, fail upload": { + params: s3.PutObjectInput{}, + conditions: []interface{}{ + []interface{}{ + // any number larger than the small sample + "content-length-range", + 100000, + 200000, + }, + }, + expectedStatusCode: http.StatusBadRequest, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + + cfg, err := integrationtest.LoadConfigWithDefaultRegion("us-west-2") + if err != nil { + t.Fatalf("failed to load config, %v", err) + } + + client := s3.NewFromConfig(cfg) + + // construct a put object + presignerClient := s3.NewPresignClient(client) + + params := c.params + if params.Key == nil { + params.Key = aws.String(integrationtest.UniqueID()) + } + params.Bucket = &setupMetadata.Buckets.Source.Name + var presignRequest *s3.PresignedPostRequest + if c.conditions != nil { + presignRequest, err = presignerClient.PresignPostObject(ctx, ¶ms, func(opts *s3.PresignPostOptions) { + opts.Conditions = c.conditions + }) + + } else { + presignRequest, err = presignerClient.PresignPostObject(ctx, ¶ms) + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + resp, err := sendMultipartRequest(presignRequest.URL, presignRequest.Values, filePath) + if err != nil { + t.Fatalf("expect no error while sending HTTP request using presigned url, got %v", err) + } + + defer resp.Body.Close() + if c.expectedStatusCode != 0 { + if resp.StatusCode != c.expectedStatusCode { + t.Fatalf("expect status code %v, got %v", c.expectedStatusCode, resp.StatusCode) + } + // don't check the rest of the tests if there's a custom status code + return + } else { + // expected result is 204 on POST requests + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("failed to put S3 object, %d:%s", resp.StatusCode, resp.Status) + } + } + + // construct a get object + getObjectInput := &s3.GetObjectInput{ + Bucket: params.Bucket, + Key: params.Key, + } + + // This could be a regular GetObject call, but since we already have a presigner client available + getRequest, err := presignerClient.PresignGetObject(ctx, getObjectInput) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + resp, err = sendHTTPRequest(getRequest, nil) + if err != nil { + t.Errorf("expect no error while sending HTTP request using presigned url, got %v", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("failed to get S3 object, %d:%s", resp.StatusCode, resp.Status) + } + + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("expect no error reading local file %v, got %v", filePath, err) + } + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("expect no error reading response %v, got %v", resp.Body, err) + } + if !bytes.Equal(content, respBytes) { + t.Fatalf("expect response body %v, got %v", content, resp.Body) + } + }) + } +} + +func sendMultipartRequest(url string, fields map[string]string, filePath string) (*http.Response, error) { + // Create a buffer to hold the multipart data + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + // Add form fields + for key, val := range fields { + err := writer.WriteField(key, val) + if err != nil { + return nil, err + } + } + + // Add the file + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() + + // Always has to be named like this + fileField := "file" + part, err := writer.CreateFormFile(fileField, filePath) + if err != nil { + return nil, err + } + _, err = io.Copy(part, file) + if err != nil { + return nil, err + } + + // Close the writer to finalize the multipart message + err = writer.Close() + if err != nil { + return nil, err + } + + // Create a new HTTP request + req, err := http.NewRequest("POST", url, &requestBody) + if err != nil { + return nil, err + } + + // Set the Content-Type header + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // Send the request + client := &http.Client{} + return client.Do(req) +} diff --git a/service/internal/integrationtest/s3/sample.txt b/service/internal/integrationtest/s3/sample.txt new file mode 100644 index 00000000000..534490de16f --- /dev/null +++ b/service/internal/integrationtest/s3/sample.txt @@ -0,0 +1 @@ +Lorem ipsum et dolor \ No newline at end of file diff --git a/service/internal/integrationtest/s3shared/integ_test_setup.go b/service/internal/integrationtest/s3shared/integ_test_setup.go index 34c9e1fd9cb..a6defd54df2 100644 --- a/service/internal/integrationtest/s3shared/integ_test_setup.go +++ b/service/internal/integrationtest/s3shared/integ_test_setup.go @@ -71,7 +71,7 @@ pt: goto pt } // fail if not succeed after 10 attempts - return fmt.Errorf("failed to determine if a bucket %s exists and you have permission to access it", bucketName) + return fmt.Errorf("failed to determine if a bucket %s exists and you have permission to access it %v", bucketName, err) } return nil diff --git a/service/s3/presign_post.go b/service/s3/presign_post.go new file mode 100644 index 00000000000..6bdbcde6687 --- /dev/null +++ b/service/s3/presign_post.go @@ -0,0 +1,433 @@ +package s3 + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net/url" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/retry" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" + "github.com/aws/aws-sdk-go-v2/internal/sdk" + acceptencodingcust "github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding" + presignedurlcust "github.com/aws/aws-sdk-go-v2/service/internal/presigned-url" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const ( + algorithmHeader = "X-Amz-Algorithm" + credentialHeader = "X-Amz-Credential" + dateHeader = "X-Amz-Date" + tokenHeader = "X-Amz-Security-Token" + signatureHeader = "X-Amz-Signature" + + algorithm = "AWS4-HMAC-SHA256" + aws4Request = "aws4_request" + bucketHeader = "bucket" + defaultExpiresIn = 15 * time.Minute + shortDateLayout = "20060102" +) + +// PresignPostObject is a special kind of [presigned request] used to send a request using +// form data, likely from an HTML form on a browser. +// Unlike other presigned operations, the return values of this function are not meant to be used directly +// to make an HTTP request but rather to be used as inputs to a form. See [the docs] for more information +// on how to use these values +// +// [presigned request] https://docs.aws.amazon.com/AmazonS3/latest/userguide/ShareObjectPreSignedURL.html +// [the docs] https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPOST.html +func (c *PresignClient) PresignPostObject(ctx context.Context, params *PutObjectInput, optFns ...func(*PresignPostOptions)) (*PresignedPostRequest, error) { + if params == nil { + params = &PutObjectInput{} + } + clientOptions := c.options.copy() + options := PresignPostOptions{ + Expires: clientOptions.Expires, + PostPresigner: &postSignAdapter{}, + } + for _, fn := range optFns { + fn(&options) + } + clientOptFns := append(clientOptions.ClientOptions, withNopHTTPClientAPIOption) + cvt := presignPostConverter(options) + result, _, err := c.client.invokeOperation(ctx, "$type:L", params, clientOptFns, + c.client.addOperationPutObjectMiddlewares, + cvt.ConvertToPresignMiddleware, + func(stack *middleware.Stack, options Options) error { + return awshttp.RemoveContentTypeHeader(stack) + }, + ) + if err != nil { + return nil, err + } + + out := result.(*PresignedPostRequest) + return out, nil +} + +// PresignedPostRequest represents a presigned request to be sent using HTTP verb POST and FormData +type PresignedPostRequest struct { + // Represents the Base URL to make a request to + URL string + // Values is a key-value map of values to be sent as FormData + // these values are not encoded + Values map[string]string +} + +// postSignAdapter adapter to implement the presignPost interface +type postSignAdapter struct{} + +// PresignPost creates a special kind of [presigned request] +// to be used with HTTP verb POST. +// It differs from PUT request mostly on +// 1. It accepts a new set of parameters, `Conditions[]`, that are used to create a policy doc to limit where an object can be posted to +// 2. The return value needs to have more processing since it's meant to be sent via a form and not stand on its own +// 3. There's no body to be signed, since that will be attached when the actual request is made +// 4. The signature is made based on the policy document, not the whole request +// More information can be found at https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPOST.html +// +// [presigned request] https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html +func (s *postSignAdapter) PresignPost( + credentials aws.Credentials, + bucket string, key string, + region string, service string, signingTime time.Time, conditions []interface{}, expirationTime time.Time, optFns ...func(*v4.SignerOptions), +) (fields map[string]string, err error) { + credentialScope := buildCredentialScope(signingTime, region, service) + credentialStr := credentials.AccessKeyID + "/" + credentialScope + + policyDoc, err := createPolicyDocument(expirationTime, signingTime, bucket, key, credentialStr, &credentials.SessionToken, conditions) + if err != nil { + return nil, err + } + + signature := buildSignature(policyDoc, credentials.SecretAccessKey, service, region, signingTime) + + fields = getPostSignRequiredFields(signingTime, credentialStr, credentials) + fields[signatureHeader] = signature + fields["key"] = key + fields["policy"] = policyDoc + + return fields, nil +} + +func getPostSignRequiredFields(t time.Time, credentialStr string, awsCredentials aws.Credentials) map[string]string { + fields := map[string]string{ + algorithmHeader: algorithm, + dateHeader: t.UTC().Format("20060102T150405Z"), + credentialHeader: credentialStr, + } + + sessionToken := awsCredentials.SessionToken + if len(sessionToken) > 0 { + fields[tokenHeader] = sessionToken + } + + return fields +} + +// PresignPost defines the interface to presign a POST request +type PresignPost interface { + PresignPost( + credentials aws.Credentials, + bucket string, key string, + region string, service string, signingTime time.Time, conditions []interface{}, expirationTime time.Time, + optFns ...func(*v4.SignerOptions), + ) (fields map[string]string, err error) +} + +// PresignPostOptions represent the options to be passed to a PresignPost sign request +type PresignPostOptions struct { + + // ClientOptions are list of functional options to mutate client options used by + // the presign client. + ClientOptions []func(*Options) + + // PostPresigner to use. One will be created if none is provided + PostPresigner PresignPost + + // Expires sets the expiration duration for the generated presign url. This should + // be the duration in seconds the presigned URL should be considered valid for. If + // not set or set to zero, presign url would default to expire after 900 seconds. + Expires time.Duration + + // Conditions a list of extra conditions to pass to the policy document + // Available conditions can be found [here] + // + // [here]https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-HTTPPOSTConstructPolicy.html#sigv4-PolicyConditions + Conditions []interface{} +} + +type presignPostConverter PresignPostOptions + +// presignPostRequestMiddlewareOptions is the options for the presignPostRequestMiddleware middleware. +type presignPostRequestMiddlewareOptions struct { + CredentialsProvider aws.CredentialsProvider + Presigner PresignPost + LogSigning bool + ExpiresIn time.Duration + Conditions []interface{} +} + +type presignPostRequestMiddleware struct { + credentialsProvider aws.CredentialsProvider + presigner PresignPost + logSigning bool + expiresIn time.Duration + conditions []interface{} +} + +// newPresignPostRequestMiddleware returns a new presignPostRequestMiddleware +// initialized with the presigner. +func newPresignPostRequestMiddleware(options presignPostRequestMiddlewareOptions) *presignPostRequestMiddleware { + return &presignPostRequestMiddleware{ + credentialsProvider: options.CredentialsProvider, + presigner: options.Presigner, + logSigning: options.LogSigning, + expiresIn: options.ExpiresIn, + conditions: options.Conditions, + } +} + +// ID provides the middleware ID. +func (*presignPostRequestMiddleware) ID() string { return "PresignPostRequestMiddleware" } + +// HandleFinalize will take the provided input and create a presigned url for +// the http request using the SigV4 presign authentication scheme. +// +// Since the signed request is not a valid HTTP request +func (s *presignPostRequestMiddleware) HandleFinalize( + ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unexpected request middleware type %T", in.Request) + } + + input := getOperationInput(ctx) + asS3Put, ok := input.(*PutObjectInput) + if !ok { + return out, metadata, fmt.Errorf("expected PutObjectInput") + } + bucketName, ok := asS3Put.bucket() + if !ok { + return out, metadata, fmt.Errorf("requested input bucketName not found on request") + } + uploadKey := asS3Put.Key + if uploadKey == nil { + return out, metadata, fmt.Errorf("PutObject input does not have a key input") + } + + httpReq := req.Build(ctx) + u := httpReq.URL.String() + + signingName := awsmiddleware.GetSigningName(ctx) + signingRegion := awsmiddleware.GetSigningRegion(ctx) + + credentials, err := s.credentialsProvider.Retrieve(ctx) + if err != nil { + return out, metadata, &v4.SigningError{ + Err: fmt.Errorf("failed to retrieve credentials: %w", err), + } + } + skew := internalcontext.GetAttemptSkewContext(ctx) + signingTime := sdk.NowTime().Add(skew) + expirationTime := signingTime.Add(s.expiresIn).UTC() + + fields, err := s.presigner.PresignPost( + credentials, + bucketName, + *uploadKey, + signingRegion, + signingName, + signingTime, + s.conditions, + expirationTime, + func(o *v4.SignerOptions) { + o.Logger = middleware.GetLogger(ctx) + o.LogSigning = s.logSigning + }) + if err != nil { + return out, metadata, &v4.SigningError{ + Err: fmt.Errorf("failed to sign http request, %w", err), + } + } + + // Other middlewares may set default values on the URL on the path or as query params. Remove them + baseURL := toBaseURL(u) + + out.Result = &PresignedPostRequest{ + URL: baseURL, + Values: fields, + } + + return out, metadata, nil +} + +func toBaseURL(fullURL string) string { + a, _ := url.Parse(fullURL) + return a.Scheme + "://" + a.Host +} + +// Adapted from existing PresignConverter middleware +func (c presignPostConverter) ConvertToPresignMiddleware(stack *middleware.Stack, options Options) (err error) { + stack.Build.Remove("UserAgent") + stack.Finalize.Remove((*acceptencodingcust.DisableGzip)(nil).ID()) + stack.Finalize.Remove((*retry.Attempt)(nil).ID()) + stack.Finalize.Remove((*retry.MetricsHeader)(nil).ID()) + stack.Deserialize.Clear() + + if err := stack.Finalize.Insert(&presignContextPolyfillMiddleware{}, "Signing", middleware.Before); err != nil { + return err + } + + // if no expiration is set, set one + expiresIn := c.Expires + if expiresIn == 0 { + expiresIn = defaultExpiresIn + } + + pmw := newPresignPostRequestMiddleware(presignPostRequestMiddlewareOptions{ + CredentialsProvider: options.Credentials, + Presigner: c.PostPresigner, + LogSigning: options.ClientLogMode.IsSigning(), + ExpiresIn: expiresIn, + Conditions: c.Conditions, + }) + if _, err := stack.Finalize.Swap("Signing", pmw); err != nil { + return err + } + if err = smithyhttp.AddNoPayloadDefaultContentTypeRemover(stack); err != nil { + return err + } + err = presignedurlcust.AddAsIsPresigningMiddleware(stack) + if err != nil { + return err + } + return nil +} + +func createPolicyDocument(expirationTime time.Time, signingTime time.Time, bucket string, key string, credentialString string, securityToken *string, extraConditions []interface{}) (string, error) { + initialConditions := []interface{}{ + map[string]string{ + algorithmHeader: algorithm, + }, + map[string]string{ + bucketHeader: bucket, + }, + map[string]string{ + credentialHeader: credentialString, + }, + map[string]string{ + dateHeader: signingTime.UTC().Format("20060102T150405Z"), + }, + } + + var conditions []interface{} + for _, v := range initialConditions { + conditions = append(conditions, v) + } + + if securityToken != nil && *securityToken != "" { + conditions = append(conditions, map[string]string{ + tokenHeader: *securityToken, + }) + } + + // append user-defined conditions at the end + conditions = append(conditions, extraConditions...) + + // The policy allows you to set a "key" value to specify what's the name of the + // key to add. Customers can add one by specifying one in their conditions, + // so we're checking if one has already been set. + // If none is found, restrict this to just the key name passed on the request + // This can be disabled by adding a condition that explicitly allows + // everything + if !isAlreadyCheckingForKey(conditions) { + conditions = append(conditions, map[string]string{"key": key}) + } + + policyDoc := map[string]interface{}{ + "conditions": conditions, + "expiration": expirationTime.Format(time.RFC3339), + } + + jsonBytes, err := json.Marshal(policyDoc) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(jsonBytes), nil +} + +func isAlreadyCheckingForKey(conditions []interface{}) bool { + // Need to check for two conditions: + // 1. A condition of the form ["starts-with", "$key", "mykey"] + // 2. A condition of the form {"key": "mykey"} + for _, c := range conditions { + slice, ok := c.([]interface{}) + if ok && len(slice) > 1 { + if slice[0] == "starts-with" && slice[1] == "$key" { + return true + } + } + m, ok := c.(map[string]interface{}) + if ok && len(m) > 0 { + for k := range m { + if k == "key" { + return true + } + } + } + // Repeat this but for map[string]string due to type constrains + ms, ok := c.(map[string]string) + if ok && len(ms) > 0 { + for k := range ms { + if k == "key" { + return true + } + } + } + } + return false +} + +// these methods have been copied from v4 implementation since they are not exported for public use +func hmacsha256(key []byte, data []byte) []byte { + hash := hmac.New(sha256.New, key) + hash.Write(data) + return hash.Sum(nil) +} + +func buildSignature(strToSign, secret, service, region string, t time.Time) string { + key := deriveKey(secret, service, region, t) + return hex.EncodeToString(hmacsha256(key, []byte(strToSign))) +} + +func deriveKey(secret, service, region string, t time.Time) []byte { + hmacDate := hmacsha256([]byte("AWS4"+secret), []byte(t.UTC().Format(shortDateLayout))) + hmacRegion := hmacsha256(hmacDate, []byte(region)) + hmacService := hmacsha256(hmacRegion, []byte(service)) + return hmacsha256(hmacService, []byte(aws4Request)) +} + +func buildCredentialScope(signingTime time.Time, region, service string) string { + return strings.Join([]string{ + signingTime.UTC().Format(shortDateLayout), + region, + service, + aws4Request, + }, "/") +} diff --git a/service/s3/presign_post_test.go b/service/s3/presign_post_test.go new file mode 100644 index 00000000000..baf50a2aaf8 --- /dev/null +++ b/service/s3/presign_post_test.go @@ -0,0 +1,489 @@ +package s3 + +import ( + "context" + "encoding/base64" + "encoding/json" + "reflect" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" + "github.com/aws/aws-sdk-go-v2/internal/sdk" +) + +func TestPresignPutObject(t *testing.T) { + fixedTime := time.Date(2022, time.February, 1, 0, 0, 0, 0, time.UTC) + defer mockTime(fixedTime)() + + cases := map[string]struct { + input PutObjectInput + options []func(*PresignPostOptions) + expectedExpires time.Time + expectedURL string + region string + }{ + "sample": { + input: PutObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }, + }, + "expires override": { + input: PutObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }, + expectedExpires: fixedTime.Add(5 * time.Minute), + options: []func(o *PresignPostOptions){ + func(o *PresignPostOptions) { + o.Expires = 5 * time.Minute + }, + }, + }, + "body is ignored": { + input: PutObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + // This will be ignored + Body: strings.NewReader("hello-world"), + }, + }, + "different region": { + input: PutObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }, + region: "eu-central-1", + expectedURL: "https://bucket.s3.eu-central-1.amazonaws.com", + }, + "mrap endpoint is changed": { + input: PutObjectInput{ + Bucket: aws.String("arn:aws:s3::123456789012:accesspoint:mfzwi23gnjvgw.mrap"), + Key: aws.String("mockkey"), + }, + expectedURL: "https://mfzwi23gnjvgw.mrap.accesspoint.s3-global.amazonaws.com", + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + region := "us-west-2" + if tc.region != "" { + region = tc.region + } + cfg := aws.Config{ + Region: region, + Credentials: unit.StubCredentialsProvider{}, + Retryer: func() aws.Retryer { + return aws.NopRetryer{} + }, + } + + presignClient := NewPresignClient(NewFromConfig(cfg)) + postObject, err := presignClient.PresignPostObject(ctx, &tc.input, tc.options...) + if err != nil { + t.Error(err) + } + if postObject == nil { + t.Error("expected non-nil postObject") + } + if tc.expectedURL != "" { + if tc.expectedURL != postObject.URL { + t.Errorf("expected URL %q; got %q", tc.expectedURL, postObject.URL) + } + } else { + if "https://bucket.s3.us-west-2.amazonaws.com" != postObject.URL { + t.Error("expected URL to contain 'https://amazon.com', was: ", postObject.URL) + } + } + + if len(postObject.Values) < 1 { + t.Error("expected non-empty values") + } + policy, ok := postObject.Values["policy"] + if !ok { + t.Error("expected non-empty policy on postObject") + } + decoded, err := base64.StdEncoding.DecodeString(policy) + if err != nil { + t.Error("expected base64 encoded policy, got error", err, "policy", policy) + } + var policyJSON map[string]interface{} + err = json.Unmarshal(decoded, &policyJSON) + if err != nil { + t.Error("expected valid JSON for policy, got error", err, "with policy", policy) + } + actualExpires, ok := policyJSON["expiration"] + if !ok { + t.Error("expected non-empty expiration on policy JSON policy", policyJSON) + } + + if !time.Time.IsZero(tc.expectedExpires) { + isEqual, err := isTimeEqual(actualExpires.(string), tc.expectedExpires) + if err != nil { + t.Error("Error parsing expires", actualExpires, err) + } + if !isEqual { + t.Error("expected expiration to be", tc.expectedExpires, "got", actualExpires) + } + } else { + // Check the default is set. Go serializes JSON values as RFC3339 + expectedExpires := fixedTime.Add(15 * time.Minute).Format(time.RFC3339) + if actualExpires != expectedExpires { + t.Error("expected expiration to be", expectedExpires, "got", actualExpires) + } + } + }) + } +} + +// Test that comes straight from the docs https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-post-example.html +// Unfortunately it can't be verified with the exact same values +// since the sample in the docs lowercases all headers `x-amzn-header` +// while the SDK does not `X-Amzn-Header`, so the signature and policy are different. +// However, the values have been manually inspected to match the desired output +func TestSampleFromPublicDocs(t *testing.T) { + accessKeyID := "AKIAIOSFODNN7EXAMPLE" + secretAccessKey := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + bucket := "sigv4examplebucket" + key := "user/user1" + testTime := time.Date(2015, time.December, 29, 0, 0, 0, 0, time.UTC) + defer mockTime(testTime)() + expiresIn := 36 * time.Hour + staticCredentials := staticCredentialsProvider{Key: accessKeyID, Secret: secretAccessKey} + ctx := context.Background() + + cfg := aws.Config{ + Region: "us-east-1", + Credentials: staticCredentials, + Retryer: func() aws.Retryer { + return aws.NopRetryer{} + }, + } + + presignClient := NewPresignClient(NewFromConfig(cfg)) + input := PutObjectInput{Bucket: aws.String(bucket), Key: aws.String(key)} + conditions := []interface{}{ + []interface{}{"starts-with", "$key", "user/user1/"}, + map[string]string{"acl": "public-read"}, + map[string]string{"success_action_redirect": "http://sigv4examplebucket.s3.amazonaws.com/successful_upload.html"}, + []interface{}{"starts-with", "$Content-Type", "image/"}, + map[string]string{"x-amz-meta-uuid": "14365123651274"}, + []interface{}{"starts-with", "$x-amz-meta-tag", ""}, + } + opts := func(o *PresignPostOptions) { + o.Expires = expiresIn + o.Conditions = conditions + } + postObject, err := presignClient.PresignPostObject(ctx, &input, opts) + if err != nil { + t.Error(err) + } + if postObject == nil { + t.Error("expected non-nil postObject") + } + values := postObject.Values + signature, ok := values["X-Amz-Signature"] + if !ok { + t.Error("expected non-empty signature on postObject", values) + } + // Signature and policy are VERY sensitive to any change in output or order. If these tests fail, + // it can be due to a change in order for the policy or a change in capitalization + if signature != "41eb7f468113e77dca133475d38815dbe1f92b073964f4a0575f036e9c02d28a" { + t.Error("expected signature to equal to be precomputed", signature, "got", values) + } + policy, ok := values["policy"] + if !ok { + t.Error("expected non-empty policy on values", values) + } + expectedPolicy := "eyJjb25kaXRpb25zIjpbeyJYLUFtei1BbGdvcml0aG0iOiJBV1M0LUhNQUMtU0hBMjU2In0seyJidWN" + + "rZXQiOiJzaWd2NGV4YW1wbGVidWNrZXQifSx7IlgtQW16LUNyZWRlbnRpYWwiOiJBS0lBSU9TRk9ETk" + + "43RVhBTVBMRS8yMDE1MTIyOS91cy1lYXN0LTEvczMvYXdzNF9yZXF1ZXN0In0seyJYLUFtei1EYXRlI" + + "joiMjAxNTEyMjlUMDAwMDAwWiJ9LFsic3RhcnRzLXdpdGgiLCIka2V5IiwidXNlci91c2VyMS8iXSx7" + + "ImFjbCI6InB1YmxpYy1yZWFkIn0seyJzdWNjZXNzX2FjdGlvbl9yZWRpcmVjdCI6Imh0dHA6Ly9zaWd" + + "2NGV4YW1wbGVidWNrZXQuczMuYW1hem9uYXdzLmNvbS9zdWNjZXNzZnVsX3VwbG9hZC5odG1sIn0sWy" + + "JzdGFydHMtd2l0aCIsIiRDb250ZW50LVR5cGUiLCJpbWFnZS8iXSx7IngtYW16LW1ldGEtdXVpZCI6I" + + "jE0MzY1MTIzNjUxMjc0In0sWyJzdGFydHMtd2l0aCIsIiR4LWFtei1tZXRhLXRhZyIsIiJdXSwiZXhw" + + "aXJhdGlvbiI6IjIwMTUtMTItMzBUMTI6MDA6MDBaIn0=" + if policy != expectedPolicy { + t.Error("expected policy to equal", expectedPolicy, "got", policy) + } +} + +func TestBuildPresignPostRequest(t *testing.T) { + cases := map[string]struct { + credentials aws.Credentials + extraConditions []interface{} + isKeyConditionSet bool + }{ + "credentials without access token": { + credentials: credentialsNoToken, + extraConditions: []interface{}{}, + }, + "credentials with access token": { + credentials: credentialsWithToken, + extraConditions: []interface{}{}, + }, + "no extra conditions": { + credentials: credentialsWithToken, + extraConditions: []interface{}{}, + }, + "extra conditions": { + credentials: credentialsWithToken, + extraConditions: []interface{}{ + map[string]string{"acl": "public-read"}, + []string{"starts-with", "$Content-Type", "image/"}, + }, + }, + "extra conditions collision": { + credentials: credentialsWithToken, + extraConditions: []interface{}{ + map[string]string{"bucket": "otherBucket"}, + }, + }, + "a key condition is set, no extra one is generated": { + credentials: credentialsNoToken, + extraConditions: []interface{}{ + []interface{}{"starts-with", "$key", "user/user1/"}, + }, + isKeyConditionSet: true, + }, + } + requiredFields := []string{ + "X-Amz-Algorithm", + "X-Amz-Credential", + "X-Amz-Date", + "X-Amz-Signature", + "key", + "policy", + } + + requiredConditions := []string{"X-Amz-Algorithm", "bucket", "X-Amz-Credential", "X-Amz-Date"} + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + target := postSignAdapter{} + aBucketKey := "someKey" + bucket := "someBucket" + signingTime := sdk.NowTime() + expiration := signingTime.Add(time.Hour) + fields, err := target.PresignPost(tc.credentials, bucket, aBucketKey, "region", "service", signingTime, tc.extraConditions, expiration) + if err != nil { + t.Errorf("PresignPostHTTP returned unexepected error: %s", err.Error()) + } + if len(fields) == 0 { + t.Errorf("PresignPostHTTP returned no fields") + } + + for _, field := range requiredFields { + _, ok := fields[field] + if !ok { + t.Errorf("Fields response did not contain required key %s. Res %v", field, fields) + } + } + + if tc.credentials.SessionToken != "" { + _, ok := fields["X-Amz-Security-Token"] + if !ok { + t.Errorf("Credentials are using a session token, but is not set on the fields response") + } + } + + actualKey := fields["key"] + if actualKey != aBucketKey { + t.Errorf("PresignPostHTTP did not contain expected \"key\" %s. Has %s", aBucketKey, actualKey) + } + policy := fields["policy"] + decoded, err := base64.StdEncoding.DecodeString(policy) + if err != nil { + t.Errorf("Decoding policy document %s failed with error %v", policy, err) + } + var doc map[string]interface{} + err = json.Unmarshal(decoded, &doc) + if err != nil { + t.Errorf("Policy document %s failed to parse to JSON with error %v", policy, err) + } + _, ok := doc["conditions"] + if !ok { + t.Errorf("Conditions field not present in policy document %s", policy) + } + exp, ok := doc["expiration"] + if !ok { + t.Errorf("Expiration field not present in policy document %s", policy) + } + docExpiration, ok := exp.(string) + if !ok { + t.Errorf("Expiration field is not a time as expected, is %v", doc["expiration"]) + } + isEqual, err := isTimeEqual(docExpiration, expiration) + if err != nil { + t.Errorf("PresignPost did not parse expiration time %s. Error %v", docExpiration, err) + } + if !isEqual { + t.Errorf("Expected policy expiration to be %v. Got %v", expiration, docExpiration) + } + conditions := doc["conditions"].([]interface{}) + if len(conditions) == 0 { + t.Errorf("Policy document didn't contain any conditions") + } + for _, required := range requiredConditions { + val := findInSlice(conditions, required) + if val == nil { + t.Errorf("Policy document didn't contain required conditions %s. Has %v", required, conditions) + } + } + actualBucket := findInSlice(conditions, "bucket") + if !reflect.DeepEqual(bucket, actualBucket) { + t.Errorf("Expected bucket to be %v, was %v", bucket, actualBucket) + } + actualDate := findInSlice(conditions, "X-Amz-Date") + signingTimeStr := signingTime.UTC().Format("20060102T150405Z") + if signingTimeStr != actualDate { + t.Errorf("Expected date to be %v, was %v", signingTimeStr, actualDate) + } + if len(tc.extraConditions) > 0 { + for _, ec := range tc.extraConditions { + if !isPresent(ec, conditions) { + t.Errorf("Expected item %v not found on conditions %v", ec, conditions) + } + } + } + if !tc.isKeyConditionSet { + // check the default is set + conditionKey := findInSlice(conditions, "key") + if conditionKey == nil { + t.Errorf("Expected Condition 'key' to be set on policy conditions, none found. Conditions %v", conditions) + } + actualVal, ok := conditionKey.(string) + if !ok { + t.Errorf("Expected condition key to be a string, was %v", conditionKey) + } + if actualVal != aBucketKey { + t.Errorf("Expected bucket key to be %v, was %v", aBucketKey, actualVal) + } + } else { + // check the key condition is not set + conditionKey := findInSlice(conditions, "key") + if conditionKey != nil { + t.Errorf("Expected condition key to be nil since %v was set, was %v", tc.isKeyConditionSet, conditionKey) + } + + } + }) + } +} + +func mockTime(t time.Time) func() { + sdk.NowTime = func() time.Time { return t } + return func() { sdk.NowTime = time.Now } +} + +type staticCredentialsProvider struct { + Key string + Secret string +} + +func (p staticCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{AccessKeyID: p.Key, SecretAccessKey: p.Secret}, nil +} + +var credentialsNoToken = aws.Credentials{AccessKeyID: "AKID", SecretAccessKey: "SECRET"} +var credentialsWithToken = aws.Credentials{AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "SESSION"} + +func isPresent(needle interface{}, haystack []interface{}) bool { + needleValue := reflect.ValueOf(needle) + for _, item := range haystack { + itemValue := reflect.ValueOf(item) + + // special checks for slices and maps, since interface{} are not typecasted + // by reflect.DeepEquals + isSlice := itemValue.Kind() == reflect.Slice && needleValue.Kind() == reflect.Slice + if isSlice && areSlicesEqual(needleValue, itemValue) { + return true + } + isMap := itemValue.Kind() == reflect.Map && needleValue.Kind() == reflect.Map + if isMap && areMapsEqual(needleValue, itemValue) { + return true + } + + // else do a regular deep equal check + if reflect.DeepEqual(item, needle) { + return true + } + } + return false +} + +func areSlicesEqual(a reflect.Value, b reflect.Value) bool { + if a.Len() != b.Len() { + return false + } + + for i := 0; i < a.Len(); i++ { + aValue := a.Index(i).Interface() + bValue := b.Index(i).Interface() + + if !reflect.DeepEqual(aValue, bValue) { + return false + } + } + + return true +} + +func areMapsEqual(aVal reflect.Value, bVal reflect.Value) bool { + // Check if 'a' is a map + if aVal.Kind() != reflect.Map { + return false + } + + // Check if both maps have the same number of keys + if aVal.Len() != bVal.Len() { + return false + } + + // Iterate over the keys and values in the first map + for _, key := range aVal.MapKeys() { + aValue := aVal.MapIndex(key) + if !aValue.IsValid() { + return false + } + bValue := bVal.MapIndex(key) + if !bValue.IsValid() { + return false + } + + // Compare values using reflect.DeepEqual + if !reflect.DeepEqual(aValue.Interface(), bValue.Interface()) { + return false + } + } + return true +} + +// filters items in slice that have a map[string]interface{} and returns +// the first items map that has the key from "key" +func findInSlice(slice []interface{}, key string) interface{} { + for _, item := range slice { + // filter only the values with keys. Ignore stuff like arrays + if v, ok := item.(map[string]interface{}); ok { + // once in the maps, check if they have the desired key + if _, ok := v[key]; ok { + return v[key] + } + } + } + return nil +} + +func isTimeEqual(t1s string, t2 time.Time) (bool, error) { + t1, err := time.Parse(time.RFC3339, t1s) + if err != nil { + return false, err + } + areEqual := t1.Format(time.RFC3339) == t2.Format(time.RFC3339) + return areEqual, nil +}