From 24d44ae05c1d5d9c273352046076cab7a34737ea Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 19 Dec 2024 03:07:13 +0000 Subject: [PATCH 01/21] fix assertion failure during rds Signed-off-by: Nigel Brittain --- source/extensions/common/aws/BUILD | 1 + .../common/aws/credentials_provider_impl.cc | 34 +++++++++++++------ .../common/aws/credentials_provider_impl.h | 2 ++ 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/source/extensions/common/aws/BUILD b/source/extensions/common/aws/BUILD index ddd52931367a..01cac36a32b6 100644 --- a/source/extensions/common/aws/BUILD +++ b/source/extensions/common/aws/BUILD @@ -122,6 +122,7 @@ envoy_cc_library( "//source/common/common:thread_lib", "//source/common/http:utility_lib", "//source/common/init:target_lib", + "//source/common/init:manager_lib", "//source/common/json:json_loader_lib", "//source/common/runtime:runtime_features_lib", "//source/common/tracing:http_tracer_lib", diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 82728d8d39f9..4b7583b821f7 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -148,18 +148,23 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase( ALL_METADATACREDENTIALSPROVIDER_STATS(POOL_COUNTER(*scope_), POOL_GAUGE(*scope_))}); stats_->metadata_refresh_state_.set(uint64_t(refresh_state_)); - init_target_ = std::make_unique(debug_name_, [this]() -> void { - tls_slot_ = - ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); - tls_slot_->set( - [&](Event::Dispatcher&) { return std::make_shared(*this); }); + // If credential provider is being created during Envoy initialization, use init manager to delay cluster creation + // If we are here during normal processing, such as xDS update, then create clusters and initialize TLS immediately + if(context_->initManager().state() == Envoy::Init::Manager::State::Initialized) + { + initializeTlsAndCluster(); + } + else + { + init_target_ = std::make_unique(debug_name_, [this]() -> void { - createCluster(true); + initializeTlsAndCluster(); - init_target_->ready(); - init_target_.reset(); - }); - context_->initManager().add(*init_target_); + init_target_->ready(); + init_target_.reset(); + }); + context_->initManager().add(*init_target_); + } } }; @@ -171,6 +176,15 @@ MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::~ThreadLocalCreden } } +void MetadataCredentialsProviderBase::initializeTlsAndCluster() { + tls_slot_ = + ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); + + tls_slot_->set( + [&](Event::Dispatcher&) { return std::make_shared(*this); }); + createCluster(true); +} + void MetadataCredentialsProviderBase::createCluster(bool new_timer) { auto cluster = Utility::createInternalClusterStatic(cluster_name_, cluster_type_, uri_); diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index ed5c8ce386e1..b550f593150d 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -16,6 +16,7 @@ #include "source/common/common/logger.h" #include "source/common/common/thread.h" #include "source/common/init/target_impl.h" +#include "envoy/init/manager.h" #include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" #include "source/extensions/common/aws/credentials_provider.h" @@ -145,6 +146,7 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { private: void createCluster(bool new_timer); + void initializeTlsAndCluster(); protected: struct LoadClusterEntryHandleImpl From e918f7681e1220162e2d6498232bd62399169f82 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 19 Dec 2024 05:18:08 +0000 Subject: [PATCH 02/21] test cases Signed-off-by: Nigel Brittain --- source/extensions/common/aws/BUILD | 2 +- .../common/aws/credentials_provider_impl.cc | 27 +++++++++---------- .../common/aws/credentials_provider_impl.h | 2 +- .../aws/credentials_provider_impl_test.cc | 11 ++++++++ 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/source/extensions/common/aws/BUILD b/source/extensions/common/aws/BUILD index 01cac36a32b6..fc08f5407312 100644 --- a/source/extensions/common/aws/BUILD +++ b/source/extensions/common/aws/BUILD @@ -121,8 +121,8 @@ envoy_cc_library( "//source/common/common:logger_lib", "//source/common/common:thread_lib", "//source/common/http:utility_lib", - "//source/common/init:target_lib", "//source/common/init:manager_lib", + "//source/common/init:target_lib", "//source/common/json:json_loader_lib", "//source/common/runtime:runtime_features_lib", "//source/common/tracing:http_tracer_lib", diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 4b7583b821f7..3275d2e764fb 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -148,16 +148,13 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase( ALL_METADATACREDENTIALSPROVIDER_STATS(POOL_COUNTER(*scope_), POOL_GAUGE(*scope_))}); stats_->metadata_refresh_state_.set(uint64_t(refresh_state_)); - // If credential provider is being created during Envoy initialization, use init manager to delay cluster creation - // If we are here during normal processing, such as xDS update, then create clusters and initialize TLS immediately - if(context_->initManager().state() == Envoy::Init::Manager::State::Initialized) - { - initializeTlsAndCluster(); - } - else - { + // If credential provider is being created during Envoy initialization, use init manager to + // delay cluster creation If we are here during normal processing, such as xDS update, then + // create clusters and initialize TLS immediately + if (context_->initManager().state() == Envoy::Init::Manager::State::Initialized) { + initializeTlsAndCluster(); + } else { init_target_ = std::make_unique(debug_name_, [this]() -> void { - initializeTlsAndCluster(); init_target_->ready(); @@ -177,12 +174,12 @@ MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::~ThreadLocalCreden } void MetadataCredentialsProviderBase::initializeTlsAndCluster() { - tls_slot_ = - ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); - - tls_slot_->set( - [&](Event::Dispatcher&) { return std::make_shared(*this); }); - createCluster(true); + tls_slot_ = + ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); + + tls_slot_->set( + [&](Event::Dispatcher&) { return std::make_shared(*this); }); + createCluster(true); } void MetadataCredentialsProviderBase::createCluster(bool new_timer) { diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index b550f593150d..0e9e089f99bb 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -10,13 +10,13 @@ #include "envoy/event/timer.h" #include "envoy/extensions/common/aws/v3/credential_provider.pb.h" #include "envoy/http/message.h" +#include "envoy/init/manager.h" #include "envoy/server/factory_context.h" #include "source/common/common/lock_guard.h" #include "source/common/common/logger.h" #include "source/common/common/thread.h" #include "source/common/init/target_impl.h" -#include "envoy/init/manager.h" #include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" #include "source/extensions/common/aws/credentials_provider.h" diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index a25fcd165b6e..c432a0c7ef11 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -1468,6 +1468,17 @@ class ContainerCredentialsProviderTest : public testing::Test { NiceMock init_watcher_; }; +TEST_F(ContainerCredentialsProviderTest, CreationAfterInitCompleted) { + // Handle the case where we've already completed init. This validates that clusters create + // successfully but init manager is not used + NiceMock initManager; + ON_CALL(context_, initManager()).WillByDefault(ReturnRef(initManager)); + ON_CALL(initManager, state()).WillByDefault(Return(Init::Manager::State::Initialized)); + EXPECT_CALL(cluster_manager_, addOrUpdateCluster(_, _, _)); + EXPECT_CALL(initManager, add(_)).Times(0); + setupProvider(); +} + TEST_F(ContainerCredentialsProviderTest, FailedFetchingDocument) { // Setup timer. From 1e3fa55577ff9d332c4a1687e36054a989e3806e Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 19 Dec 2024 06:43:07 +0000 Subject: [PATCH 03/21] fix test leak Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider_impl_test.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index c432a0c7ef11..e59eef7785e9 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -1466,17 +1466,18 @@ class ContainerCredentialsProviderTest : public testing::Test { MetadataFetcher::MetadataReceiver::RefreshState refresh_state_; Init::TargetHandlePtr init_target_; NiceMock init_watcher_; + NiceMock init_manager_; }; TEST_F(ContainerCredentialsProviderTest, CreationAfterInitCompleted) { // Handle the case where we've already completed init. This validates that clusters create // successfully but init manager is not used - NiceMock initManager; - ON_CALL(context_, initManager()).WillByDefault(ReturnRef(initManager)); - ON_CALL(initManager, state()).WillByDefault(Return(Init::Manager::State::Initialized)); + ON_CALL(context_, initManager()).WillByDefault(ReturnRef(init_manager_)); + ON_CALL(init_manager_, state()).WillByDefault(Return(Init::Manager::State::Initialized)); EXPECT_CALL(cluster_manager_, addOrUpdateCluster(_, _, _)); - EXPECT_CALL(initManager, add(_)).Times(0); + EXPECT_CALL(init_manager_, add(_)).Times(0); setupProvider(); + delete (raw_metadata_fetcher_); } TEST_F(ContainerCredentialsProviderTest, FailedFetchingDocument) { From a33b14226fa8f138bbc5edb571ccfbf3d496ebb4 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 21 Dec 2024 04:43:36 +0000 Subject: [PATCH 04/21] singleton webidentity Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider_impl.cc | 27 ++++++++++++++++--- .../common/aws/credentials_provider_impl.h | 12 +++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 3275d2e764fb..0796a2092607 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -149,7 +149,7 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase( stats_->metadata_refresh_state_.set(uint64_t(refresh_state_)); // If credential provider is being created during Envoy initialization, use init manager to - // delay cluster creation If we are here during normal processing, such as xDS update, then + // delay cluster creation. If we are here during normal processing, such as xDS update, then // create clusters and initialize TLS immediately if (context_->initManager().state() == Envoy::Init::Manager::State::Initialized) { initializeTlsAndCluster(); @@ -977,16 +977,16 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( if (!web_token_path.empty() && !role_arn.empty()) { const auto session_name = sessionName(api); const auto sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; - const auto region_uuid = absl::StrCat(region, "_", context->api().randomGenerator().uuid()); + // const auto region_uuid = absl::StrCat(region, "_", context->api().randomGenerator().uuid()); - const auto cluster_name = stsClusterName(region_uuid); + const auto cluster_name = stsClusterName(region); ENVOY_LOG( debug, "Using web identity credentials provider with STS endpoint: {} and session name: {}", sts_endpoint, session_name); add(factories.createWebIdentityCredentialsProvider( - api, context, fetch_metadata_using_curl, MetadataFetcher::create, cluster_name, + api, context, singleton_manager, fetch_metadata_using_curl, MetadataFetcher::create, cluster_name, web_token_path, "", sts_endpoint, role_arn, session_name, refresh_state, initialization_timer)); } @@ -1036,6 +1036,7 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( // extensions SINGLETON_MANAGER_REGISTRATION(container_credentials_provider); SINGLETON_MANAGER_REGISTRATION(instance_profile_credentials_provider); +SINGLETON_MANAGER_REGISTRATION(web_identity_credentials_provider); CredentialsProviderSharedPtr DefaultCredentialsProviderChain::createContainerCredentialsProvider( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, @@ -1071,6 +1072,24 @@ DefaultCredentialsProviderChain::createInstanceProfileCredentialsProvider( }); } + CredentialsProviderSharedPtr + DefaultCredentialsProviderChain::createWebIdentityCredentialsProvider( + Api::Api& api, ServerFactoryContextOptRef context,Singleton::Manager& singleton_manager, + const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view cluster_name, + absl::string_view token_file_path, absl::string_view token, absl::string_view sts_endpoint, + absl::string_view role_arn, absl::string_view role_session_name, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer) const { + return singleton_manager.getTyped( + SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), + [&context, &api, cluster_name, fetch_metadata_using_curl,create_metadata_fetcher_cb,token_file_path, token, sts_endpoint, role_arn, role_session_name, refresh_state, initialization_timer] { + return std::make_shared(api, context, fetch_metadata_using_curl, create_metadata_fetcher_cb, token_file_path, token, + sts_endpoint, role_arn, role_session_name, refresh_state, initialization_timer, + cluster_name); + }); + } + absl::StatusOr createCredentialsProviderFromConfig( Server::Configuration::ServerFactoryContext& context, absl::string_view region, const envoy::extensions::common::aws::v3::AwsCredentialProvider& config) { diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index 0e9e089f99bb..73978005181c 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -319,6 +319,7 @@ class ContainerCredentialsProvider : public MetadataCredentialsProviderBase, * OpenID) */ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase, + public Envoy::Singleton::Instance, public MetadataFetcher::MetadataReceiver { public: // token and token_file_path are mutually exclusive. If token is not empty, token_file_path is @@ -382,7 +383,7 @@ class CredentialsProviderChainFactories { createCredentialsFileCredentialsProvider(Api::Api& api) const PURE; virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Api::Api& api, ServerFactoryContextOptRef context, + Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view cluster_name, absl::string_view token_file_path, absl::string_view token, absl::string_view sts_endpoint, @@ -456,18 +457,13 @@ class DefaultCredentialsProviderChain : public CredentialsProviderChain, std::chrono::seconds initialization_timer, absl::string_view cluster_name) const override; CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Api::Api& api, ServerFactoryContextOptRef context, + Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view cluster_name, absl::string_view token_file_path, absl::string_view token, absl::string_view sts_endpoint, absl::string_view role_arn, absl::string_view role_session_name, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer) const override { - return std::make_shared( - api, context, fetch_metadata_using_curl, create_metadata_fetcher_cb, token_file_path, token, - sts_endpoint, role_arn, role_session_name, refresh_state, initialization_timer, - cluster_name); - } + std::chrono::seconds initialization_timer) const override; }; /** From 6ae5b8a272f224badb7d794e0f0cc9dcf8684a98 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sun, 15 Dec 2024 10:36:12 +0000 Subject: [PATCH 05/21] credential_provider Signed-off-by: Nigel Brittain --- api/envoy/extensions/common/aws/v3/BUILD | 5 +- .../common/aws/v3/credential_provider.proto | 48 +- changelogs/current.yaml | 3 +- ...ing-filter-credential-provider-config.yaml | 70 +++ .../http_filters/_include/aws_credentials.rst | 22 +- .../aws_request_signing_filter.rst | 12 + source/extensions/common/aws/BUILD | 3 + .../common/aws/credentials_provider_impl.cc | 228 ++++++--- .../common/aws/credentials_provider_impl.h | 203 +++++--- .../common/aws/region_provider_impl.cc | 48 +- .../common/aws/region_provider_impl.h | 23 +- source/extensions/common/aws/utility.cc | 28 +- source/extensions/common/aws/utility.h | 15 +- .../filters/http/aws_lambda/config.cc | 4 +- .../http/aws_request_signing/config.cc | 166 +++++-- .../aws/credentials_provider_impl_test.cc | 465 ++++++++++++++---- .../common/aws/region_provider_impl_test.cc | 55 ++- test/extensions/common/aws/utility_test.cc | 4 +- .../http/aws_request_signing/config_test.cc | 92 +++- 19 files changed, 1153 insertions(+), 341 deletions(-) create mode 100644 docs/root/configuration/http/http_filters/_include/aws-request-signing-filter-credential-provider-config.yaml diff --git a/api/envoy/extensions/common/aws/v3/BUILD b/api/envoy/extensions/common/aws/v3/BUILD index 29ebf0741406..09a37ad16b83 100644 --- a/api/envoy/extensions/common/aws/v3/BUILD +++ b/api/envoy/extensions/common/aws/v3/BUILD @@ -5,5 +5,8 @@ load("@envoy_api//bazel:api_build_system.bzl", "api_proto_package") licenses(["notice"]) # Apache 2 api_proto_package( - deps = ["@com_github_cncf_xds//udpa/annotations:pkg"], + deps = [ + "//envoy/config/core/v3:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", + ], ) diff --git a/api/envoy/extensions/common/aws/v3/credential_provider.proto b/api/envoy/extensions/common/aws/v3/credential_provider.proto index b623a40a437d..c05e34cbd30a 100644 --- a/api/envoy/extensions/common/aws/v3/credential_provider.proto +++ b/api/envoy/extensions/common/aws/v3/credential_provider.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.common.aws.v3; +import "envoy/config/core/v3/base.proto"; + import "udpa/annotations/sensitive.proto"; import "udpa/annotations/status.proto"; import "validate/validate.proto"; @@ -14,18 +16,26 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: AWS common configuration] -// Configuration for AWS credential provider. Normally, this is optional and the credentials are +// Configuration for AWS credential provider. This is optional and the credentials are normally // retrieved from the environment or AWS configuration files by following the default credential -// provider chain. This is to support cases where the credentials need to be explicitly provided -// by the control plane. +// provider chain. However, this configuration can be used to override the default behavior. message AwsCredentialProvider { // The option to use `AssumeRoleWithWebIdentity `_. - // If inline_credential is set, this is ignored. - AssumeRoleWithWebIdentityCredentialProvider assume_role_with_web_identity = 1; + AssumeRoleWithWebIdentityCredentialProvider assume_role_with_web_identity_provider = 1; - // The option to use an inline credential. - // If this is set, it takes precedence over assume_role_with_web_identity. + // The option to use an inline credential. If inline credential is provided, no chain will be created and only the inline credential will be used. InlineCredentialProvider inline_credential = 2; + + // The option to specify parameters for credential retrieval from an envoy data source, such as a file in AWS credential format. + CredentialsFileCredentialProvider credentials_file_provider = 3; + + // Create a custom credential provider chain instead of the default credential provider chain. + // If set to TRUE, the credential provider chain that is created contains only those set in this credential provider message. + // If set to FALSE, the settings provided here will act as modifiers to the default credential provider chain. + // Defaults to FALSE. + // + // This has no effect if inline_credential is provided. + bool custom_credential_provider_chain = 4; } // Configuration to use an inline AWS credential. This is an equivalent to setting the well-known @@ -43,12 +53,26 @@ message InlineCredentialProvider { } // Configuration to use `AssumeRoleWithWebIdentity `_ -// to get AWS credentials. +// to retrieve AWS credentials. message AssumeRoleWithWebIdentityCredentialProvider { + // Data source for a web identity token that is provided by the identity provider to assume the role. + // When using this data source, even if a ``watched_directory`` is provided, the token file will only be re-read when the credentials + // returned from AssumeRoleWithWebIdentity expire. + config.core.v3.DataSource web_identity_token_data_source = 1 + [(udpa.annotations.sensitive) = true]; + // The ARN of the role to assume. - string role_arn = 1 [(validate.rules).string = {min_len: 1}]; + string role_arn = 2 [(validate.rules).string = {min_len: 1}]; - // The web identity token that is provided by the identity provider to assume the role. - string web_identity_token = 2 - [(validate.rules).string = {min_len: 1}, (udpa.annotations.sensitive) = true]; + // Optional role session name to use in AssumeRoleWithWebIdentity API call. + string role_session_name = 3; +} + +message CredentialsFileCredentialProvider { + // Data source from which to retrieve AWS credentials + // When using this data source, if a ``watched_directory`` is provided, the credential file will be re-read when a file move is detected. + config.core.v3.DataSource credentials_data_source = 1 [(udpa.annotations.sensitive) = true]; + + // The profile within the credentials_file data source + string profile = 2; } diff --git a/changelogs/current.yaml b/changelogs/current.yaml index 030fddfa2054..b3a547b3a295 100644 --- a/changelogs/current.yaml +++ b/changelogs/current.yaml @@ -209,7 +209,8 @@ new_features: change: | Added an optional field :ref:`credential_provider ` - to the AWS request signing filter to explicitly specify a source for AWS credentials. + to the AWS request signing filter to explicitly specify a source for AWS credentials. Credential file and AssumeRoleWithWebIdentity + behaviour can also be overridden with this field. - area: tls change: | Added support for P-384 and P-521 curves for TLS server certificates. diff --git a/docs/root/configuration/http/http_filters/_include/aws-request-signing-filter-credential-provider-config.yaml b/docs/root/configuration/http/http_filters/_include/aws-request-signing-filter-credential-provider-config.yaml new file mode 100644 index 000000000000..0969a29e13a0 --- /dev/null +++ b/docs/root/configuration/http/http_filters/_include/aws-request-signing-filter-credential-provider-config.yaml @@ -0,0 +1,70 @@ +static_resources: + listeners: + - address: + socket_address: + address: 0.0.0.0 + port_value: 10000 + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + '@type': type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + stat_prefix: ingress_http + http_filters: + - name: envoy.filters.http.router + typed_config: + '@type': type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + route_config: + name: local_route + virtual_hosts: + - domains: + - '*' + name: local_service + routes: + - match: {prefix: "/"} + route: {cluster: default_service} + clusters: + - name: default_service + load_assignment: + cluster_name: default_service + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 10001 + typed_extension_protocol_options: + envoy.extensions.upstreams.http.v3.HttpProtocolOptions: + "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions + upstream_http_protocol_options: + auto_sni: true + auto_san_validation: true + auto_config: + http2_protocol_options: {} + http_filters: + - name: envoy.filters.http.aws_request_signing + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.aws_request_signing.v3.AwsRequestSigning + credential_provider: + custom_credential_provider_chain: true + credentials_file_provider: + credentials_data_source: + filename: /tmp/a + watched_directory: + path: /tmp + service_name: vpc-lattice-svcs + region: '*' + signing_algorithm: AWS_SIGV4A + use_unsigned_payload: true + match_excluded_headers: + - prefix: x-envoy + - prefix: x-forwarded + - exact: x-amzn-trace-id + - name: envoy.filters.http.upstream_codec + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.upstream_codec.v3.UpstreamCodec + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext diff --git a/docs/root/configuration/http/http_filters/_include/aws_credentials.rst b/docs/root/configuration/http/http_filters/_include/aws_credentials.rst index 0d945684a009..edf5138df5c0 100644 --- a/docs/root/configuration/http/http_filters/_include/aws_credentials.rst +++ b/docs/root/configuration/http/http_filters/_include/aws_credentials.rst @@ -5,13 +5,24 @@ The filter uses a number of different credentials providers to obtain an AWS acc By default, it moves through the credentials providers in the order described below, stopping when one of them returns an access key ID and a secret access key (the session token is optional). -1. Environment variables. The environment variables ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN`` are used. +1. :ref:`inline_credentials ` field. + If this field is configured, no other credentials providers will be used. -2. The AWS credentials file. The environment variables ``AWS_SHARED_CREDENTIALS_FILE`` and ``AWS_PROFILE`` are respected if they are set, else +2. :ref:`credential_provider ` field. + By using this field, the filter allows override of the default environment variables, credential parameters and file locations. + Currently this supports both AWS credentials file locations and content, and AssumeRoleWithWebIdentity token files. + If the :ref:`credential_provider ` field is provided, + it can be used either to modify the default credentials provider chain, or when :ref:`custom_credential_provider_chain ` + is set to ``true``, to create a custom credentials provider chain containing only the specified credentials provider settings. Examples of using these fields + are provided in :ref:`configuration examples `. + +3. Environment variables. The environment variables ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN`` are used. + +4. The AWS credentials file. The environment variables ``AWS_SHARED_CREDENTIALS_FILE`` and ``AWS_PROFILE`` are respected if they are set, else the file ``~/.aws/credentials`` and profile ``default`` are used. The fields ``aws_access_key_id``, ``aws_secret_access_key``, and ``aws_session_token`` defined for the profile in the credentials file are used. These credentials are cached for 1 hour. -3. From `AssumeRoleWithWebIdentity `_ API call +5. From `AssumeRoleWithWebIdentity `_ API call towards AWS Security Token Service using ``WebIdentityToken`` read from a file pointed by ``AWS_WEB_IDENTITY_TOKEN_FILE`` environment variable and role arn read from ``AWS_ROLE_ARN`` environment variable. The credentials are extracted from the fields ``AccessKeyId``, ``SecretAccessKey``, and ``SessionToken`` are used, and credentials are cached for 1 hour or until they expire (according to the field @@ -30,7 +41,7 @@ secret access key (the session token is optional). If you require the use of SigV4A signing and you are using an alternate partition, such as cn or GovCloud, you can ensure correct generation of the STS endpoint by setting the first region in your SigV4A region set to the correct region (such as ``cn-northwest-1`` with no wildcard) -4. Either EC2 instance metadata, ECS task metadata or EKS Pod Identity. +6. Either EC2 instance metadata, ECS task metadata or EKS Pod Identity. For EC2 instance metadata, the fields ``AccessKeyId``, ``SecretAccessKey``, and ``Token`` are used, and credentials are cached for 1 hour. For ECS task metadata, the fields ``AccessKeyId``, ``SecretAccessKey``, and ``Token`` are used, and credentials are cached for 1 hour or until they expire (according to the field ``Expiration``). @@ -46,9 +57,6 @@ secret access key (the session token is optional). The static internal cluster will still be added even if initially ``envoy.reloadable_features.use_http_client_to_fetch_aws_credentials`` is not set so that subsequently if the reloadable feature is set to ``true`` the cluster config is available to fetch the credentials. -Alternatively, each AWS filter (either AWS Request Signing or AWS Lambda) has its own optional configuration to specify the source of the credentials. For example, AWS Request Signing filter -has :ref:`credential_provider ` field. - Statistics ---------- diff --git a/docs/root/configuration/http/http_filters/aws_request_signing_filter.rst b/docs/root/configuration/http/http_filters/aws_request_signing_filter.rst index 7bf0e5ba447c..c25b8a1fc0da 100644 --- a/docs/root/configuration/http/http_filters/aws_request_signing_filter.rst +++ b/docs/root/configuration/http/http_filters/aws_request_signing_filter.rst @@ -59,6 +59,8 @@ the following HTTP header modifications will be made by this extension: Example configuration --------------------- +.. _config_http_filters_aws_request_signing_examples: + Example filter configuration: .. literalinclude:: _include/aws-request-signing-filter.yaml @@ -86,6 +88,16 @@ An example of configuring this filter to use ``AWS_SIGV4A`` signing with a wildc :linenos: :caption: :download:`aws-request-signing-filter-sigv4a.yaml <_include/aws-request-signing-filter-sigv4a.yaml>` +An example of using the credential provider configuration to modify the default behaviour of the credential provider chain. In this scenario, we use +the ``custom_credential_provider_chain`` option to disable the default credential provider chain and use specific settings for the credential file +credentials provider. These settings include a ``watched_directory``, which configures the filter to reload the credentials file when it changes. + +.. literalinclude:: _include/aws-request-signing-filter-credential-provider-config.yaml + :language: yaml + :lines: 46-56 + :lineno-start: 46 + :linenos: + :caption: :download:`aws-request-signing-filter-credential-provider-config.yaml <_include/aws-request-signing-filter-credential-provider-config.yaml>` Configuration as an upstream HTTP filter ---------------------------------------- diff --git a/source/extensions/common/aws/BUILD b/source/extensions/common/aws/BUILD index ddd52931367a..e16539526327 100644 --- a/source/extensions/common/aws/BUILD +++ b/source/extensions/common/aws/BUILD @@ -120,12 +120,14 @@ envoy_cc_library( "//envoy/api:api_interface", "//source/common/common:logger_lib", "//source/common/common:thread_lib", + "//source/common/config:datasource_lib", "//source/common/http:utility_lib", "//source/common/init:target_lib", "//source/common/json:json_loader_lib", "//source/common/runtime:runtime_features_lib", "//source/common/tracing:http_tracer_lib", "@com_google_absl//absl/time", + "@envoy_api//envoy/config/core/v3:pkg_cc_proto", "@envoy_api//envoy/extensions/common/aws/v3:pkg_cc_proto", ], ) @@ -176,5 +178,6 @@ envoy_cc_library( ":region_provider_interface", ":utility_lib", "//source/common/common:logger_lib", + "@envoy_api//envoy/extensions/common/aws/v3:pkg_cc_proto", ], ) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 82728d8d39f9..55ca36324bae 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -324,36 +324,79 @@ void MetadataCredentialsProviderBase::setCredentialsToAllThreads( } } +CredentialsFileCredentialsProvider::CredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config) + : context_(context), profile_("") { + + if (credential_file_config.has_credentials_data_source()) { + auto provider_or_error_ = Config::DataSource::DataSourceProvider::create( + credential_file_config.credentials_data_source(), context.mainThreadDispatcher(), + context.threadLocal(), context.api(), false, 4096); + if (provider_or_error_.ok()) { + credential_file_data_source_provider_ = std::move(provider_or_error_.value()); + if (credential_file_config.credentials_data_source().has_watched_directory()) { + has_watched_directory_ = true; + } + } else { + ENVOY_LOG_MISC(info, "Invalid credential file data source"); + credential_file_data_source_provider_.reset(); + } + } + if (!credential_file_config.profile().empty()) { + profile_ = credential_file_config.profile(); + } +} + bool CredentialsFileCredentialsProvider::needsRefresh() { - return api_.timeSource().systemTime() - last_updated_ > REFRESH_INTERVAL; + return has_watched_directory_ + ? true + : context_.api().timeSource().systemTime() - last_updated_ > REFRESH_INTERVAL; + // return context_.api().timeSource().systemTime() - last_updated_ > REFRESH_INTERVAL; } void CredentialsFileCredentialsProvider::refresh() { + auto profile = profile_.empty() ? Utility::getCredentialProfileName() : profile_; + ENVOY_LOG(debug, "Getting AWS credentials from the credentials file"); - auto credentials_file = Utility::getCredentialFilePath(); - auto profile = profile_.empty() ? Utility::getCredentialProfileName() : profile_; + std::string credential_file_data, credential_file_path; - ENVOY_LOG(debug, "Credentials file path = {}, profile name = {}", credentials_file, profile); + // Use data source if provided, otherwise read from default AWS credential file path + if (credential_file_data_source_provider_.has_value()) { + credential_file_data = credential_file_data_source_provider_.value()->data(); + credential_file_path = ""; + } else { + credential_file_path = Utility::getCredentialFilePath(); + auto credential_file = context_.api().fileSystem().fileReadToEnd(credential_file_path); + if (credential_file.ok()) { + credential_file_data = credential_file.value(); + } else { + ENVOY_LOG(debug, "Unable to read from credential file {}", credential_file_path); + // Update last_updated_ now so that even if this function returns before successfully + // extracting credentials, this function won't be called again until after the + // REFRESH_INTERVAL. This prevents envoy from attempting and failing to read the credentials + // file on every request if there are errors extracting credentials from it (e.g. if the + // credentials file doesn't exist). + last_updated_ = context_.api().timeSource().systemTime(); + return; + } + } + ENVOY_LOG(debug, "Credentials file path = {}, profile name = {}", credential_file_path, profile); - extractCredentials(credentials_file, profile); + extractCredentials(credential_file_data.data(), profile); } -void CredentialsFileCredentialsProvider::extractCredentials(const std::string& credentials_file, - const std::string& profile) { - // Update last_updated_ now so that even if this function returns before successfully - // extracting credentials, this function won't be called again until after the REFRESH_INTERVAL. - // This prevents envoy from attempting and failing to read the credentials file on every request - // if there are errors extracting credentials from it (e.g. if the credentials file doesn't - // exist). - last_updated_ = api_.timeSource().systemTime(); +void CredentialsFileCredentialsProvider::extractCredentials(absl::string_view credentials_string, + absl::string_view profile) { std::string access_key_id, secret_access_key, session_token; absl::flat_hash_map elements = { {AWS_ACCESS_KEY_ID, ""}, {AWS_SECRET_ACCESS_KEY, ""}, {AWS_SESSION_TOKEN, ""}}; absl::flat_hash_map::iterator it; - Utility::resolveProfileElements(credentials_file, profile, elements); + Utility::resolveProfileElementsFromString(credentials_string.data(), profile.data(), elements); // if profile file fails to load, or these elements are not found in the profile, their values // will remain blank when retrieving them from the hash map access_key_id = elements.find(AWS_ACCESS_KEY_ID)->second; @@ -364,14 +407,14 @@ void CredentialsFileCredentialsProvider::extractCredentials(const std::string& c // Return empty credentials if we're unable to retrieve from profile cached_credentials_ = Credentials(); } else { - ENVOY_LOG(debug, "Found following AWS credentials for profile '{}' in {}: {}={}, {}={}, {}={}", - profile, credentials_file, AWS_ACCESS_KEY_ID, access_key_id, AWS_SECRET_ACCESS_KEY, + ENVOY_LOG(debug, "Found following AWS credentials for profile '{}': {}={}, {}={}, {}={}", + profile, AWS_ACCESS_KEY_ID, access_key_id, AWS_SECRET_ACCESS_KEY, secret_access_key.empty() ? "" : "*****", AWS_SESSION_TOKEN, session_token.empty() ? "" : "*****"); cached_credentials_ = Credentials(access_key_id, secret_access_key, session_token); } - last_updated_ = api_.timeSource().systemTime(); + last_updated_ = context_.api().timeSource().systemTime(); } InstanceProfileCredentialsProvider::InstanceProfileCredentialsProvider( @@ -736,21 +779,33 @@ void ContainerCredentialsProvider::onMetadataError(Failure reason) { } WebIdentityCredentialsProvider::WebIdentityCredentialsProvider( - Api::Api& api, ServerFactoryContextOptRef context, - const CurlMetadataFetcher& fetch_metadata_using_curl, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view token_file_path, - absl::string_view token, absl::string_view sts_endpoint, absl::string_view role_arn, - absl::string_view role_session_name, + Server::Configuration::ServerFactoryContext& context, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer, absl::string_view cluster_name = {}) + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name = {}) : MetadataCredentialsProviderBase( - api, context, fetch_metadata_using_curl, create_metadata_fetcher_cb, cluster_name, + context.api(), context, nullptr, create_metadata_fetcher_cb, cluster_name, envoy::config::cluster::v3::Cluster::LOGICAL_DNS /*cluster_type*/, sts_endpoint, refresh_state, initialization_timer), - token_file_path_(token_file_path), token_(token), sts_endpoint_(sts_endpoint), - role_arn_(role_arn), role_session_name_(role_session_name) {} + sts_endpoint_(sts_endpoint), role_arn_(web_identity_config.role_arn()), + role_session_name_(web_identity_config.role_session_name()) { + + auto provider_or_error_ = Config::DataSource::DataSourceProvider::create( + web_identity_config.web_identity_token_data_source(), context.mainThreadDispatcher(), + context.threadLocal(), context.api(), false, 4096); + if (provider_or_error_.ok()) { + web_identity_data_source_provider_ = std::move(provider_or_error_.value()); + } else { + ENVOY_LOG_MISC(info, "Invalid web identity data source"); + web_identity_data_source_provider_.reset(); + } +} bool WebIdentityCredentialsProvider::needsRefresh() { + const auto now = api_.timeSource().systemTime(); auto expired = (now - last_updated_ > REFRESH_INTERVAL); @@ -762,19 +817,18 @@ bool WebIdentityCredentialsProvider::needsRefresh() { } void WebIdentityCredentialsProvider::refresh() { - ENVOY_LOG(debug, "Getting AWS web identity credentials from STS: {}", sts_endpoint_); - std::string identity_token = token_; - if (identity_token.empty()) { - const auto web_token_file_or_error = api_.fileSystem().fileReadToEnd(token_file_path_); - if (!web_token_file_or_error.ok()) { - ENVOY_LOG(debug, "Unable to read AWS web identity credentials from {}", token_file_path_); - cached_credentials_ = Credentials(); - return; - } - identity_token = web_token_file_or_error.value(); + absl::string_view web_identity_data; + + // If we're unable to read from the configured data source, exit early. + if (!web_identity_data_source_provider_.has_value()) { + return; } + ENVOY_LOG(debug, "Getting AWS web identity credentials from STS: {}", sts_endpoint_); + + web_identity_data = web_identity_data_source_provider_.value()->data(); + Http::RequestMessageImpl message; message.headers().setScheme(Http::Headers::get().SchemeValues.Https); message.headers().setMethod(Http::Headers::get().MethodValues.Get); @@ -787,7 +841,7 @@ void WebIdentityCredentialsProvider::refresh() { "&WebIdentityToken={}", Envoy::Http::Utility::PercentEncoding::encode(role_session_name_), Envoy::Http::Utility::PercentEncoding::encode(role_arn_), - Envoy::Http::Utility::PercentEncoding::encode(identity_token))); + Envoy::Http::Utility::PercentEncoding::encode(web_identity_data))); // Use the Accept header to ensure that AssumeRoleWithWebIdentityResponse is returned as JSON. message.headers().setReference(Http::CustomHeaders::get().Accept, Http::Headers::get().ContentTypeValues.Json); @@ -941,30 +995,77 @@ std::string stsClusterName(absl::string_view region) { return absl::StrCat(STS_TOKEN_CLUSTER, "-", region); } +CustomCredentialsProviderChain::CustomCredentialsProviderChain( + Server::Configuration::ServerFactoryContext& context, absl::string_view region, + const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config, + const CredentialsProviderChainFactories& factories) { + + // Custom chain currently only supports file based and web identity credentials + if (credential_provider_config.has_assume_role_with_web_identity_provider()) { + auto web_identity = credential_provider_config.assume_role_with_web_identity_provider(); + const std::string sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; + const auto region_uuid = absl::StrCat(region, "_", context.api().randomGenerator().uuid()); + const std::string cluster_name = stsClusterName(region_uuid); + std::string role_session_name = web_identity.role_session_name(); + if (role_session_name.empty()) { + web_identity.set_role_session_name(sessionName(context.api())); + } + const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; + const auto initialization_timer = std::chrono::seconds(2); + add(factories.createWebIdentityCredentialsProvider( + context, MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, + web_identity, cluster_name)); + } + + if (credential_provider_config.has_credentials_file_provider()) { + add(factories.createCredentialsFileCredentialsProvider( + context, credential_provider_config.credentials_file_provider())); + } +} + DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, absl::string_view region, const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, + const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config, const CredentialsProviderChainFactories& factories) { ENVOY_LOG(debug, "Using environment credentials provider"); add(factories.createEnvironmentCredentialsProvider()); - ENVOY_LOG(debug, "Using credentials file credentials provider"); - add(factories.createCredentialsFileCredentialsProvider(api)); - // Initial state for an async credential receiver auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; // Initial amount of time for async credential receivers to wait for an initial refresh to succeed auto initialization_timer = std::chrono::seconds(2); - // WebIdentityCredentialsProvider can be used only if `context` is supplied which is required to - // use http async http client to make http calls to fetch the credentials. if (context) { - const auto web_token_path = absl::NullSafeStringView(std::getenv(AWS_WEB_IDENTITY_TOKEN_FILE)); - const auto role_arn = absl::NullSafeStringView(std::getenv(AWS_ROLE_ARN)); - if (!web_token_path.empty() && !role_arn.empty()) { - const auto session_name = sessionName(api); + + ENVOY_LOG(debug, "Using credentials file credentials provider"); + add(factories.createCredentialsFileCredentialsProvider( + context.value(), credential_provider_config.credentials_file_provider())); + + auto web_identity = credential_provider_config.assume_role_with_web_identity_provider(); + + // Configure defaults if nothing is set in the config + if (!web_identity.has_web_identity_token_data_source()) { + web_identity.mutable_web_identity_token_data_source()->set_filename( + absl::NullSafeStringView(std::getenv(AWS_WEB_IDENTITY_TOKEN_FILE))); + } + + if (web_identity.role_arn().empty()) { + web_identity.set_role_arn(absl::NullSafeStringView(std::getenv(AWS_ROLE_ARN))); + } + + if (web_identity.role_session_name().empty()) { + web_identity.set_role_session_name(sessionName(api)); + } + + if ((!web_identity.web_identity_token_data_source().filename().empty() || + !web_identity.web_identity_token_data_source().inline_bytes().empty() || + !web_identity.web_identity_token_data_source().inline_string().empty() || + !web_identity.web_identity_token_data_source().environment_variable().empty()) && + !web_identity.role_arn().empty()) { + const auto sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; const auto region_uuid = absl::StrCat(region, "_", context->api().randomGenerator().uuid()); @@ -973,11 +1074,10 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( ENVOY_LOG( debug, "Using web identity credentials provider with STS endpoint: {} and session name: {}", - sts_endpoint, session_name); + sts_endpoint, web_identity.role_session_name()); add(factories.createWebIdentityCredentialsProvider( - api, context, fetch_metadata_using_curl, MetadataFetcher::create, cluster_name, - web_token_path, "", sts_endpoint, role_arn, session_name, refresh_state, - initialization_timer)); + context.value(), MetadataFetcher::create, sts_endpoint, refresh_state, + initialization_timer, web_identity, cluster_name)); } } @@ -1060,34 +1160,6 @@ DefaultCredentialsProviderChain::createInstanceProfileCredentialsProvider( }); } -absl::StatusOr createCredentialsProviderFromConfig( - Server::Configuration::ServerFactoryContext& context, absl::string_view region, - const envoy::extensions::common::aws::v3::AwsCredentialProvider& config) { - // The precedence order is: inline_credential > assume_role_with_web_identity. - if (config.has_inline_credential()) { - const auto& inline_credential = config.inline_credential(); - return std::make_shared(inline_credential.access_key_id(), - inline_credential.secret_access_key(), - inline_credential.session_token()); - } else if (config.has_assume_role_with_web_identity()) { - const auto& web_identity = config.assume_role_with_web_identity(); - const std::string& role_arn = web_identity.role_arn(); - const std::string& token = web_identity.web_identity_token(); - const std::string sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; - const auto region_uuid = absl::StrCat(region, "_", context.api().randomGenerator().uuid()); - const std::string cluster_name = stsClusterName(region_uuid); - const std::string role_session_name = sessionName(context.api()); - const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; - // This "two seconds" is a bit arbitrary, but matches the other places in the codebase. - const auto initialization_timer = std::chrono::seconds(2); - return std::make_shared( - context.api(), context, nullptr, MetadataFetcher::create, "", token, sts_endpoint, role_arn, - role_session_name, refresh_state, initialization_timer, cluster_name); - } else { - return absl::InvalidArgumentError("No AWS credential provider specified"); - } -} - } // namespace Aws } // namespace Common } // namespace Extensions diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index ed5c8ce386e1..db7825b3b2f7 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -7,6 +7,7 @@ #include "envoy/api/api.h" #include "envoy/common/optref.h" +#include "envoy/config/core/v3/base.pb.h" #include "envoy/event/timer.h" #include "envoy/extensions/common/aws/v3/credential_provider.pb.h" #include "envoy/http/message.h" @@ -15,6 +16,7 @@ #include "source/common/common/lock_guard.h" #include "source/common/common/logger.h" #include "source/common/common/thread.h" +#include "source/common/config/datasource.h" #include "source/common/init/target_impl.h" #include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" @@ -92,18 +94,20 @@ class CachedCredentialsProviderBase : public CredentialsProvider, */ class CredentialsFileCredentialsProvider : public CachedCredentialsProviderBase { public: - CredentialsFileCredentialsProvider(Api::Api& api) : CredentialsFileCredentialsProvider(api, "") {} - - CredentialsFileCredentialsProvider(Api::Api& api, const std::string& profile) - : api_(api), profile_(profile) {} + CredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config = {}); private: - Api::Api& api_; - const std::string profile_; + Server::Configuration::ServerFactoryContext& context_; + std::string profile_; + absl::optional credential_file_data_source_provider_; + bool has_watched_directory_ = false; bool needsRefresh() override; void refresh() override; - void extractCredentials(const std::string& credentials_file, const std::string& profile); + void extractCredentials(absl::string_view credentials_string, absl::string_view profile); }; class LoadClusterEntryHandle { @@ -321,29 +325,24 @@ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase, public: // token and token_file_path are mutually exclusive. If token is not empty, token_file_path is // not used, and vice versa. - WebIdentityCredentialsProvider(Api::Api& api, ServerFactoryContextOptRef context, - const CurlMetadataFetcher& fetch_metadata_using_curl, - CreateMetadataFetcherCb create_metadata_fetcher_cb, - absl::string_view token_file_path, absl::string_view token, - absl::string_view sts_endpoint, absl::string_view role_arn, - absl::string_view role_session_name, - MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer, - absl::string_view cluster_name); + WebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name); // Following functions are for MetadataFetcher::MetadataReceiver interface void onMetadataSuccess(const std::string&& body) override; void onMetadataError(Failure reason) override; - const std::string& tokenForTesting() const { return token_; } const std::string& roleArnForTesting() const { return role_arn_; } private: - // token_ and token_file_path_ are mutually exclusive. If token_ is set, token_file_path_ is not - // used. - const std::string token_file_path_; - const std::string token_; const std::string sts_endpoint_; + absl::optional web_identity_data_source_provider_; const std::string role_arn_; const std::string role_session_name_; @@ -376,17 +375,19 @@ class CredentialsProviderChainFactories { virtual CredentialsProviderSharedPtr createEnvironmentCredentialsProvider() const PURE; - virtual CredentialsProviderSharedPtr - createCredentialsFileCredentialsProvider(Api::Api& api) const PURE; + virtual CredentialsProviderSharedPtr createCredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config = {}) const PURE; virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Api::Api& api, ServerFactoryContextOptRef context, - const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view cluster_name, - absl::string_view token_file_path, absl::string_view token, absl::string_view sts_endpoint, - absl::string_view role_arn, absl::string_view role_session_name, + Server::Configuration::ServerFactoryContext& context, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer) const PURE; + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const PURE; virtual CredentialsProviderSharedPtr createContainerCredentialsProvider( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, @@ -405,6 +406,89 @@ class CredentialsProviderChainFactories { std::chrono::seconds initialization_timer, absl::string_view cluster_name) const PURE; }; +class CustomCredentialsProviderChain : public CredentialsProviderChain, + public CredentialsProviderChainFactories { +public: + CustomCredentialsProviderChain( + Server::Configuration::ServerFactoryContext& context, absl::string_view region, + const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config, + const CredentialsProviderChainFactories& factories); + + CustomCredentialsProviderChain( + Server::Configuration::ServerFactoryContext& context, absl::string_view region, + const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config) + : CustomCredentialsProviderChain(context, region, credential_provider_config, *this) {} + + CredentialsProviderSharedPtr createCredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config = {} + + ) const override { + + return std::make_shared(context, credential_file_config); + }; + + CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const override { + return std::make_shared( + context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name); + }; + + CredentialsProviderSharedPtr createEnvironmentCredentialsProvider() const override { + return nullptr; + } + + CredentialsProviderSharedPtr createContainerCredentialsProvider( + ABSL_ATTRIBUTE_UNUSED Api::Api& api, ABSL_ATTRIBUTE_UNUSED ServerFactoryContextOptRef context, + ABSL_ATTRIBUTE_UNUSED Singleton::Manager& singleton_manager, + ABSL_ATTRIBUTE_UNUSED const MetadataCredentialsProviderBase::CurlMetadataFetcher& + fetch_metadata_using_curl, + ABSL_ATTRIBUTE_UNUSED CreateMetadataFetcherCb create_metadata_fetcher_cb, + ABSL_ATTRIBUTE_UNUSED absl::string_view cluster_name, + ABSL_ATTRIBUTE_UNUSED absl::string_view credential_uri, + ABSL_ATTRIBUTE_UNUSED MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + ABSL_ATTRIBUTE_UNUSED std::chrono::seconds initialization_timer, + ABSL_ATTRIBUTE_UNUSED absl::string_view authorization_token = {}) const override { + return nullptr; + } + + CredentialsProviderSharedPtr createInstanceProfileCredentialsProvider( + ABSL_ATTRIBUTE_UNUSED Api::Api& api, ABSL_ATTRIBUTE_UNUSED ServerFactoryContextOptRef context, + ABSL_ATTRIBUTE_UNUSED Singleton::Manager& singleton_manager, + ABSL_ATTRIBUTE_UNUSED const MetadataCredentialsProviderBase::CurlMetadataFetcher& + fetch_metadata_using_curl, + ABSL_ATTRIBUTE_UNUSED CreateMetadataFetcherCb create_metadata_fetcher_cb, + ABSL_ATTRIBUTE_UNUSED MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + ABSL_ATTRIBUTE_UNUSED std::chrono::seconds initialization_timer, + ABSL_ATTRIBUTE_UNUSED absl::string_view cluster_name) const override { + return nullptr; + } +}; + +/** + * Credential provider based on an inline credential. + */ +class InlineCredentialProvider : public CredentialsProvider { +public: + explicit InlineCredentialProvider(absl::string_view access_key_id, + absl::string_view secret_access_key, + absl::string_view session_token) + : credentials_(access_key_id, secret_access_key, session_token) {} + + Credentials getCredentials() override { return credentials_; } + +private: + const Credentials credentials_; +}; + /** * Default AWS credentials provider chain. * @@ -417,14 +501,18 @@ class DefaultCredentialsProviderChain : public CredentialsProviderChain, DefaultCredentialsProviderChain( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, absl::string_view region, - const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl) + const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, + const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config = + {}) : DefaultCredentialsProviderChain(api, context, singleton_manager, region, - fetch_metadata_using_curl, *this) {} + fetch_metadata_using_curl, credential_provider_config, + *this) {} DefaultCredentialsProviderChain( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, absl::string_view region, const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, + const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config, const CredentialsProviderChainFactories& factories); private: @@ -432,10 +520,14 @@ class DefaultCredentialsProviderChain : public CredentialsProviderChain, return std::make_shared(); } - CredentialsProviderSharedPtr - createCredentialsFileCredentialsProvider(Api::Api& api) const override { - return std::make_shared(api); - } + CredentialsProviderSharedPtr createCredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config + + ) const override { + return std::make_shared(context, credential_file_config); + }; CredentialsProviderSharedPtr createContainerCredentialsProvider( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, @@ -454,44 +546,19 @@ class DefaultCredentialsProviderChain : public CredentialsProviderChain, std::chrono::seconds initialization_timer, absl::string_view cluster_name) const override; CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Api::Api& api, ServerFactoryContextOptRef context, - const MetadataCredentialsProviderBase::CurlMetadataFetcher& fetch_metadata_using_curl, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view cluster_name, - absl::string_view token_file_path, absl::string_view token, absl::string_view sts_endpoint, - absl::string_view role_arn, absl::string_view role_session_name, + Server::Configuration::ServerFactoryContext& context, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer) const override { + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const override { return std::make_shared( - api, context, fetch_metadata_using_curl, create_metadata_fetcher_cb, token_file_path, token, - sts_endpoint, role_arn, role_session_name, refresh_state, initialization_timer, - cluster_name); + context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name); } }; -/** - * Credential provider based on an inline credential. - */ -class InlineCredentialProvider : public CredentialsProvider { -public: - explicit InlineCredentialProvider(absl::string_view access_key_id, - absl::string_view secret_access_key, - absl::string_view session_token) - : credentials_(access_key_id, secret_access_key, session_token) {} - - Credentials getCredentials() override { return credentials_; } - -private: - const Credentials credentials_; -}; - -/** - * Create an AWS credentials provider from the proto configuration instead of using the default - * credentials provider chain. - */ -absl::StatusOr createCredentialsProviderFromConfig( - Server::Configuration::ServerFactoryContext& context, absl::string_view region, - const envoy::extensions::common::aws::v3::AwsCredentialProvider& config); - using InstanceProfileCredentialsProviderPtr = std::shared_ptr; using ContainerCredentialsProviderPtr = std::shared_ptr; using WebIdentityCredentialsProviderPtr = std::shared_ptr; diff --git a/source/extensions/common/aws/region_provider_impl.cc b/source/extensions/common/aws/region_provider_impl.cc index ca7255e00843..ee78d1705132 100644 --- a/source/extensions/common/aws/region_provider_impl.cc +++ b/source/extensions/common/aws/region_provider_impl.cc @@ -40,14 +40,31 @@ absl::optional EnvironmentRegionProvider::getRegionSet() { return regionSet; } +AWSCredentialsFileRegionProvider::AWSCredentialsFileRegionProvider( + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config) { + if (credential_file_config.has_credentials_data_source() && + credential_file_config.credentials_data_source().has_filename()) { + credential_file_path_ = credential_file_config.credentials_data_source().filename(); + } + if (!credential_file_config.profile().empty()) { + profile_ = credential_file_config.profile(); + } +} + absl::optional AWSCredentialsFileRegionProvider::getRegion() { absl::flat_hash_map elements = {{REGION, ""}}; absl::flat_hash_map::iterator it; - // Search for the region in the credentials file + absl::string_view file_path; + file_path = credential_file_path_.has_value() ? credential_file_path_.value() + : Utility::getCredentialFilePath(); - if (!Utility::resolveProfileElements(Utility::getCredentialFilePath(), - Utility::getCredentialProfileName(), elements)) { + absl::string_view profile; + profile = profile_.has_value() ? profile_.value() : Utility::getCredentialProfileName(); + + // Search for the region in the credentials file + if (!Utility::resolveProfileElementsFromFile(file_path.data(), profile.data(), elements)) { return absl::nullopt; } it = elements.find(REGION); @@ -66,8 +83,15 @@ absl::optional AWSCredentialsFileRegionProvider::getRegionSet() { // Search for the region in the credentials file - if (!Utility::resolveProfileElements(Utility::getCredentialFilePath(), - Utility::getCredentialProfileName(), elements)) { + absl::string_view file_path; + file_path = credential_file_path_.has_value() ? credential_file_path_.value() + : Utility::getCredentialFilePath(); + + absl::string_view profile; + profile = profile_.has_value() ? profile_.value() : Utility::getCredentialProfileName(); + + // Search for the region in the credentials file + if (!Utility::resolveProfileElementsFromFile(file_path.data(), profile.data(), elements)) { return absl::nullopt; } it = elements.find(SIGV4A_SIGNING_REGION_SET); @@ -86,8 +110,8 @@ absl::optional AWSConfigFileRegionProvider::getRegion() { // Search for the region in the config file - if (!Utility::resolveProfileElements(Utility::getConfigFilePath(), - Utility::getConfigProfileName(), elements)) { + if (!Utility::resolveProfileElementsFromFile(Utility::getConfigFilePath(), + Utility::getConfigProfileName(), elements)) { return absl::nullopt; } @@ -106,8 +130,8 @@ absl::optional AWSConfigFileRegionProvider::getRegionSet() { // Search for the region in the config file - if (!Utility::resolveProfileElements(Utility::getConfigFilePath(), - Utility::getConfigProfileName(), elements)) { + if (!Utility::resolveProfileElementsFromFile(Utility::getConfigFilePath(), + Utility::getConfigProfileName(), elements)) { return absl::nullopt; } @@ -133,10 +157,12 @@ absl::optional AWSConfigFileRegionProvider::getRegionSet() { // Credentials and profile format can be found here: // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html // -RegionProviderChain::RegionProviderChain() { +RegionProviderChain::RegionProviderChain( + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config) { // TODO(nbaws): Verify that bypassing virtual dispatch here was intentional add(RegionProviderChain::createEnvironmentRegionProvider()); - add(RegionProviderChain::createAWSCredentialsFileRegionProvider()); + add(RegionProviderChain::createAWSCredentialsFileRegionProvider(credential_file_config)); add(RegionProviderChain::createAWSConfigFileRegionProvider()); } diff --git a/source/extensions/common/aws/region_provider_impl.h b/source/extensions/common/aws/region_provider_impl.h index b62deaa0a199..dae381f1cfb4 100644 --- a/source/extensions/common/aws/region_provider_impl.h +++ b/source/extensions/common/aws/region_provider_impl.h @@ -1,5 +1,7 @@ #pragma once +#include "envoy/extensions/common/aws/v3/credential_provider.pb.h" + #include "source/common/common/logger.h" #include "source/extensions/common/aws/region_provider.h" @@ -23,11 +25,17 @@ class EnvironmentRegionProvider : public RegionProvider, public Logger::Loggable class AWSCredentialsFileRegionProvider : public RegionProvider, public Logger::Loggable { public: - AWSCredentialsFileRegionProvider() = default; + AWSCredentialsFileRegionProvider( + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config); absl::optional getRegion() override; absl::optional getRegionSet() override; + +private: + absl::optional credential_file_path_; + absl::optional profile_; }; class AWSConfigFileRegionProvider : public RegionProvider, @@ -45,7 +53,9 @@ class RegionProviderChainFactories { virtual ~RegionProviderChainFactories() = default; virtual RegionProviderSharedPtr createEnvironmentRegionProvider() const PURE; - virtual RegionProviderSharedPtr createAWSCredentialsFileRegionProvider() const PURE; + virtual RegionProviderSharedPtr createAWSCredentialsFileRegionProvider( + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config) const PURE; virtual RegionProviderSharedPtr createAWSConfigFileRegionProvider() const PURE; }; @@ -57,7 +67,8 @@ class RegionProviderChain : public RegionProvider, public RegionProviderChainFactories, public Logger::Loggable { public: - RegionProviderChain(); + RegionProviderChain(const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config = {}); ~RegionProviderChain() override = default; @@ -72,8 +83,10 @@ class RegionProviderChain : public RegionProvider, RegionProviderSharedPtr createEnvironmentRegionProvider() const override { return std::make_shared(); } - RegionProviderSharedPtr createAWSCredentialsFileRegionProvider() const override { - return std::make_shared(); + RegionProviderSharedPtr createAWSCredentialsFileRegionProvider( + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config) const override { + return std::make_shared(credential_file_config); } RegionProviderSharedPtr createAWSConfigFileRegionProvider() const override { return std::make_shared(); diff --git a/source/extensions/common/aws/utility.cc b/source/extensions/common/aws/utility.cc index 13ebbd36c934..792074daf9d5 100644 --- a/source/extensions/common/aws/utility.cc +++ b/source/extensions/common/aws/utility.cc @@ -451,19 +451,39 @@ std::string Utility::getEnvironmentVariableOrDefault(const std::string& variable return (value != nullptr) && (value[0] != '\0') ? value : default_value; } -bool Utility::resolveProfileElements(const std::string& profile_file, - const std::string& profile_name, - absl::flat_hash_map& elements) { +bool Utility::resolveProfileElementsFromString( + const std::string& string_data, const std::string& profile_name, + absl::flat_hash_map& elements) { + // std::istringstream a(string_data); + std::unique_ptr stream; + + stream = std::make_unique(std::istringstream{string_data}); + return resolveProfileElementsFromStream(*stream, profile_name, elements); +} + +bool Utility::resolveProfileElementsFromFile( + const std::string& profile_file, const std::string& profile_name, + absl::flat_hash_map& elements) { std::ifstream file(profile_file); if (!file.is_open()) { ENVOY_LOG_MISC(debug, "Error opening credentials file {}", profile_file); return false; } + std::unique_ptr stream; + stream = std::make_unique(std::move(file)); + return resolveProfileElementsFromStream(*stream, profile_name, elements); +} + +bool Utility::resolveProfileElementsFromStream( + std::istream& stream, const std::string& profile_name, + absl::flat_hash_map& elements) { + const auto profile_start = absl::StrFormat("[%s]", profile_name); bool found_profile = false; std::string line; - while (std::getline(file, line)) { + + while (std::getline(stream, line)) { line = std::string(StringUtil::trim(line)); if (line.empty()) { continue; diff --git a/source/extensions/common/aws/utility.h b/source/extensions/common/aws/utility.h index 00e000f46a64..68685d24d53e 100644 --- a/source/extensions/common/aws/utility.h +++ b/source/extensions/common/aws/utility.h @@ -140,9 +140,18 @@ class Utility { * @return true if profile file could be read and searched. * @return false if profile file could not be read. */ - static bool resolveProfileElements(const std::string& profile_file, - const std::string& profile_name, - absl::flat_hash_map& elements); + + static bool + resolveProfileElementsFromString(const std::string& string_data, const std::string& profile_name, + absl::flat_hash_map& elements); + + static bool + resolveProfileElementsFromFile(const std::string& profile_file, const std::string& profile_name, + absl::flat_hash_map& elements); + + static bool + resolveProfileElementsFromStream(std::istream& stream, const std::string& profile_name, + absl::flat_hash_map& elements); /** * @brief Return the path of AWS credential file, following environment variable expansions diff --git a/source/extensions/filters/http/aws_lambda/config.cc b/source/extensions/filters/http/aws_lambda/config.cc index e23ea02e6643..90a0db569b0b 100644 --- a/source/extensions/filters/http/aws_lambda/config.cc +++ b/source/extensions/filters/http/aws_lambda/config.cc @@ -55,8 +55,10 @@ AwsLambdaFilterFactory::getCredentialsProvider( "credentials profile is set to \"{}\" in config, default credentials providers chain " "will be ignored and only credentials file provider will be used", proto_config.credentials_profile()); + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config; + credential_file_config.set_profile(proto_config.credentials_profile()); return std::make_shared( - server_context.api(), proto_config.credentials_profile()); + server_context, credential_file_config); } return std::make_shared( server_context.api(), makeOptRef(server_context), server_context.singletonManager(), region, diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index 2d25b4376529..eb2906c541a8 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -36,11 +36,22 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( const AwsRequestSigningProtoConfig& config, const std::string& stats_prefix, DualInfo dual_info, Server::Configuration::ServerFactoryContext& server_context) { - std::string region; - region = config.region(); + std::string region = config.region(); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + + // If we have an overriding credential provider configuration, read it here as it may contain + // references to the region + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; + if (config.has_credential_provider()) { + if (config.credential_provider().has_credentials_file_provider()) { + credential_file_config = config.credential_provider().credentials_file_provider(); + } + } if (region.empty()) { - auto region_provider = std::make_shared(); + auto region_provider = + std::make_shared(credential_file_config); absl::optional regionOpt; if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { regionOpt = region_provider->getRegionSet(); @@ -55,27 +66,61 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( region = regionOpt.value(); } - bool query_string = config.has_query_string(); + absl::StatusOr + credentials_provider = + absl::InvalidArgumentError("No credentials provider settings configured."); - uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( - config.query_string(), expiration_time, - Extensions::Common::Aws::SignatureQueryParameterValues::DefaultExpiration); + const bool has_credential_provider_settings = + config.has_credential_provider() && + (config.credential_provider().has_assume_role_with_web_identity_provider() || + config.credential_provider().has_credentials_file_provider()); - absl::StatusOr + if (config.has_credential_provider()) { + if (config.credential_provider().has_inline_credential()) { + // If inline credential provider is set, use it instead of the default or custom credentials + // chain + const auto& inline_credential = config.credential_provider().inline_credential(); + credentials_provider = std::make_shared( + inline_credential.access_key_id(), inline_credential.secret_access_key(), + inline_credential.session_token()); + } else if (config.credential_provider().custom_credential_provider_chain()) { + // Custom credential provider chain + if (has_credential_provider_settings) { + credentials_provider = + std::make_shared( + server_context, region, config.credential_provider()); + } + } else { + // Override default credential provider chain settings with any provided settings + if (has_credential_provider_settings) { + credential_provider_config = config.credential_provider(); + } credentials_provider = - config.has_credential_provider() - ? Extensions::Common::Aws::createCredentialsProviderFromConfig( - server_context, region, config.credential_provider()) - : std::make_shared( - server_context.api(), makeOptRef(server_context), - server_context.singletonManager(), region, nullptr); + std::make_shared( + server_context.api(), makeOptRef(server_context), server_context.singletonManager(), + region, nullptr, credential_provider_config); + } + } else { + // No credential provider settings provided, so make the default credentials provider chain + credentials_provider = + std::make_shared( + server_context.api(), makeOptRef(server_context), server_context.singletonManager(), + region, nullptr, credential_provider_config); + } + if (!credentials_provider.ok()) { - return credentials_provider.status(); + return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); } const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( config.match_excluded_headers().begin(), config.match_excluded_headers().end()); + const bool query_string = config.has_query_string(); + + const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( + config.query_string(), expiration_time, + Extensions::Common::Aws::SignatureQueryParameterValues::DefaultExpiration); + std::unique_ptr signer; if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { @@ -103,13 +148,28 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( }; } +// TODO: @nbaws remove duplication from above + absl::StatusOr AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( const AwsRequestSigningProtoPerRouteConfig& per_route_config, - Server::Configuration::ServerFactoryContext& context, ProtobufMessage::ValidationVisitor&) { - std::string region; + Server::Configuration::ServerFactoryContext& server_context, + ProtobufMessage::ValidationVisitor&) { - region = per_route_config.aws_request_signing().region(); + std::string region = per_route_config.aws_request_signing().region(); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + + // If we have an overriding credential provider configuration, read it here + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; + if (per_route_config.aws_request_signing().has_credential_provider()) { + if (per_route_config.aws_request_signing() + .credential_provider() + .has_credentials_file_provider()) { + credential_file_config = + per_route_config.aws_request_signing().credential_provider().credentials_file_provider(); + } + } if (region.empty()) { auto region_provider = std::make_shared(); @@ -128,18 +188,56 @@ AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( region = regionOpt.value(); } - bool query_string = per_route_config.aws_request_signing().has_query_string(); - uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( - per_route_config.aws_request_signing().query_string(), expiration_time, 5); - absl::StatusOr credentials_provider = - per_route_config.aws_request_signing().has_credential_provider() - ? Extensions::Common::Aws::createCredentialsProviderFromConfig( - context, region, per_route_config.aws_request_signing().credential_provider()) - : std::make_shared( - context.api(), makeOptRef(context), context.singletonManager(), region, - nullptr); + absl::InvalidArgumentError("No credentials provider settings configured."); + + bool has_credential_provider_settings = + per_route_config.aws_request_signing().has_credential_provider() && + (per_route_config.aws_request_signing() + .credential_provider() + .has_assume_role_with_web_identity_provider() || + per_route_config.aws_request_signing() + .credential_provider() + .has_credentials_file_provider()); + + if (per_route_config.aws_request_signing().has_credential_provider()) { + if (per_route_config.aws_request_signing().credential_provider().has_inline_credential()) { + const auto& inline_credential = + per_route_config.aws_request_signing().credential_provider().inline_credential(); + credentials_provider = std::make_shared( + inline_credential.access_key_id(), inline_credential.secret_access_key(), + inline_credential.session_token()); + } + + if (per_route_config.aws_request_signing() + .credential_provider() + .custom_credential_provider_chain()) { + // Custom credential provider chain + if (has_credential_provider_settings) { + credentials_provider = + std::make_shared( + server_context, region, + per_route_config.aws_request_signing().credential_provider()); + } + } else { + // Override default credential provider chain settings with any provided settings + if (has_credential_provider_settings) { + credential_provider_config = per_route_config.aws_request_signing().credential_provider(); + } + credentials_provider = + std::make_shared( + server_context.api(), makeOptRef(server_context), server_context.singletonManager(), + region, nullptr, credential_provider_config); + } + } else { + // No credential provider settings provided, so make the default credentials provider chain + credentials_provider = + std::make_shared( + server_context.api(), makeOptRef(server_context), server_context.singletonManager(), + region, nullptr, credential_provider_config); + } + if (!credentials_provider.ok()) { return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); } @@ -147,13 +245,19 @@ AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( per_route_config.aws_request_signing().match_excluded_headers().begin(), per_route_config.aws_request_signing().match_excluded_headers().end()); + + const bool query_string = per_route_config.aws_request_signing().has_query_string(); + + const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( + per_route_config.aws_request_signing().query_string(), expiration_time, 5); + std::unique_ptr signer; if (per_route_config.aws_request_signing().signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { signer = std::make_unique( per_route_config.aws_request_signing().service_name(), region, credentials_provider.value(), - context, matcher_config, query_string, expiration_time); + server_context, matcher_config, query_string, expiration_time); } else { // Verify that we have not specified a region set when using sigv4 algorithm if (isARegionSet(region)) { @@ -163,11 +267,11 @@ AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( } signer = std::make_unique( per_route_config.aws_request_signing().service_name(), region, credentials_provider.value(), - context, matcher_config, query_string, expiration_time); + server_context, matcher_config, query_string, expiration_time); } return std::make_shared( - std::move(signer), per_route_config.stat_prefix(), context.scope(), + std::move(signer), per_route_config.stat_prefix(), server_context.scope(), per_route_config.aws_request_signing().host_rewrite(), per_route_config.aws_request_signing().use_unsigned_payload()); } diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index a25fcd165b6e..2de705e53387 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -29,7 +29,7 @@ using testing::InSequence; using testing::NiceMock; using testing::Ref; using testing::Return; - +using testing::WithArg; namespace Envoy { namespace Extensions { namespace Common { @@ -175,13 +175,15 @@ TEST_F(EvironmentCredentialsProviderTest, NoSessionToken) { class CredentialsFileCredentialsProviderTest : public testing::Test { public: CredentialsFileCredentialsProviderTest() - : api_(Api::createApiForTest(time_system_)), provider_(*api_) {} + : api_(Api::createApiForTest(time_system_)), provider_(context_) {} ~CredentialsFileCredentialsProviderTest() override { TestEnvironment::unsetEnvVar("AWS_SHARED_CREDENTIALS_FILE"); TestEnvironment::unsetEnvVar("AWS_PROFILE"); } + void SetUp() override { EXPECT_CALL(context_, api()).WillRepeatedly(testing::ReturnRef(*api_)); } + void setUpTest(std::string file_contents, std::string profile) { auto file_path = TestEnvironment::writeStringToFileForTest(CREDENTIALS_FILE, file_contents); TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", file_path, 1); @@ -189,7 +191,11 @@ class CredentialsFileCredentialsProviderTest : public testing::Test { } Event::SimulatedTimeSystem time_system_; + NiceMock context_; + Api::ApiPtr api_; + // Event::DispatcherPtr dispatcher_; + // NiceMock tls_; CredentialsFileCredentialsProvider provider_; }; @@ -197,8 +203,37 @@ TEST_F(CredentialsFileCredentialsProviderTest, CustomProfileFromConfigShouldBeHo auto file_path = TestEnvironment::writeStringToFileForTest(CREDENTIALS_FILE, CREDENTIALS_FILE_CONTENTS); TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", file_path, 1); + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider config = {}; + config.set_profile("profile4"); + auto provider = CredentialsFileCredentialsProvider(context_, config); + const auto credentials = provider.getCredentials(); + EXPECT_EQ("profile4_access_key", credentials.accessKeyId().value()); + EXPECT_EQ("profile4_secret", credentials.secretAccessKey().value()); + EXPECT_EQ("profile4_token", credentials.sessionToken().value()); +} + +TEST_F(CredentialsFileCredentialsProviderTest, CustomFilePathFromConfig) { + auto file_path = + TestEnvironment::writeStringToFileForTest(CREDENTIALS_FILE, CREDENTIALS_FILE_CONTENTS); + + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider config = {}; + config.mutable_credentials_data_source()->set_filename(file_path); + auto provider = CredentialsFileCredentialsProvider(context_, config); + const auto credentials = provider.getCredentials(); + EXPECT_EQ("default_access_key", credentials.accessKeyId().value()); + EXPECT_EQ("default_secret", credentials.secretAccessKey().value()); + EXPECT_EQ("default_token", credentials.sessionToken().value()); +} + +TEST_F(CredentialsFileCredentialsProviderTest, CustomFilePathAndProfileFromConfig) { + auto file_path = + TestEnvironment::writeStringToFileForTest(CREDENTIALS_FILE, CREDENTIALS_FILE_CONTENTS); + + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider config = {}; + config.mutable_credentials_data_source()->set_filename(file_path); + config.set_profile("profile4"); - auto provider = CredentialsFileCredentialsProvider(*api_, "profile4"); + auto provider = CredentialsFileCredentialsProvider(context_, config); const auto credentials = provider.getCredentials(); EXPECT_EQ("profile4_access_key", credentials.accessKeyId().value()); EXPECT_EQ("profile4_secret", credentials.secretAccessKey().value()); @@ -210,7 +245,10 @@ TEST_F(CredentialsFileCredentialsProviderTest, UnexistingCustomProfileFomConfig) TestEnvironment::writeStringToFileForTest(CREDENTIALS_FILE, CREDENTIALS_FILE_CONTENTS); TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", file_path, 1); - auto provider = CredentialsFileCredentialsProvider(*api_, "unexistening_profile"); + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider config = {}; + config.set_profile("unexistening_profile"); + + auto provider = CredentialsFileCredentialsProvider(context_, config); const auto credentials = provider.getCredentials(); EXPECT_FALSE(credentials.accessKeyId().has_value()); EXPECT_FALSE(credentials.secretAccessKey().has_value()); @@ -1945,17 +1983,26 @@ class WebIdentityCredentialsProviderTest : public testing::Test { std::chrono::seconds initialization_timer = std::chrono::seconds(2)) { ON_CALL(context_, clusterManager()).WillByDefault(ReturnRef(cluster_manager_)); std::string token_file_path; + envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider cred_provider = + {}; + if (token_.empty()) { token_file_path = TestEnvironment::writeStringToFileForTest("web_token_file", "web_token"); + cred_provider.mutable_web_identity_token_data_source()->set_inline_string("web_token"); + } else { + cred_provider.mutable_web_identity_token_data_source()->set_inline_string(token_); } + cred_provider.set_role_arn("aws:iam::123456789012:role/arn"); + cred_provider.set_role_session_name("role-session-name"); + provider_ = std::make_shared( - *api_, context_, nullptr, + context_, [this](Upstream::ClusterManager&, absl::string_view) { metadata_fetcher_.reset(raw_metadata_fetcher_); return std::move(metadata_fetcher_); }, - token_file_path, token_, "sts.region.amazonaws.com:443", "aws:iam::123456789012:role/arn", - "role-session-name", refresh_state, initialization_timer, "credentials_provider_cluster"); + "sts.region.amazonaws.com:443", refresh_state, initialization_timer, cred_provider, + "credentials_provider_cluster"); } void @@ -1975,18 +2022,27 @@ class WebIdentityCredentialsProviderTest : public testing::Test { MetadataFetcher::MetadataReceiver::RefreshState::Ready, std::chrono::seconds initialization_timer = std::chrono::seconds(2)) { std::string token_file_path; + envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider cred_provider = + {}; + if (token_.empty()) { token_file_path = TestEnvironment::writeStringToFileForTest("web_token_file", "web_token"); + cred_provider.mutable_web_identity_token_data_source()->set_inline_string("web_token"); + } else { + cred_provider.mutable_web_identity_token_data_source()->set_inline_string(token_); } + cred_provider.set_role_arn("aws:iam::123456789012:role/arn"); + cred_provider.set_role_session_name("role-session-name"); + ON_CALL(context_, clusterManager()).WillByDefault(ReturnRef(cluster_manager_)); provider_ = std::make_shared( - *api_, context_, nullptr, + context_, [this](Upstream::ClusterManager&, absl::string_view) { metadata_fetcher_.reset(raw_metadata_fetcher_); return std::move(metadata_fetcher_); }, - token_file_path, token_, "sts.region.amazonaws.com:443", "aws:iam::123456789012:role/arn", - "role-session-name", refresh_state, initialization_timer, "credentials_provider_cluster"); + "sts.region.amazonaws.com:443", refresh_state, initialization_timer, cred_provider, + "credentials_provider_cluster"); } void expectDocument(const uint64_t status_code, const std::string&& document) { @@ -2435,6 +2491,45 @@ TEST_F(WebIdentityCredentialsProviderTest, LibcurlEnabled) { metadata_fetcher_.reset(raw_metadata_fetcher_); } +class MockCredentialsProviderChainFactories : public CredentialsProviderChainFactories { +public: + MOCK_METHOD(CredentialsProviderSharedPtr, createEnvironmentCredentialsProvider, (), (const)); + MOCK_METHOD( + CredentialsProviderSharedPtr, mockCreateCredentialsFileCredentialsProvider, + (Server::Configuration::ServerFactoryContext&, + (const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& config)), + (const)); + + CredentialsProviderSharedPtr createCredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& config) + const override { + return mockCreateCredentialsFileCredentialsProvider(context, config); + } + + MOCK_METHOD( + CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, + (Server::Configuration::ServerFactoryContext&, CreateMetadataFetcherCb, absl::string_view, + MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider&, + absl::string_view), + (const)); + + MOCK_METHOD(CredentialsProviderSharedPtr, createContainerCredentialsProvider, + (Api::Api&, ServerFactoryContextOptRef, Singleton::Manager&, + const MetadataCredentialsProviderBase::CurlMetadataFetcher&, CreateMetadataFetcherCb, + absl::string_view, absl::string_view, + MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, + absl::string_view), + (const)); + MOCK_METHOD(CredentialsProviderSharedPtr, createInstanceProfileCredentialsProvider, + (Api::Api&, ServerFactoryContextOptRef, Singleton::Manager&, + const MetadataCredentialsProviderBase::CurlMetadataFetcher&, CreateMetadataFetcherCb, + MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, + absl::string_view), + (const)); +}; + class DefaultCredentialsProviderChainTest : public testing::Test { public: DefaultCredentialsProviderChainTest() : api_(Api::createApiForTest(time_system_)) { @@ -2455,33 +2550,6 @@ class DefaultCredentialsProviderChainTest : public testing::Test { TestEnvironment::unsetEnvVar("AWS_ROLE_SESSION_NAME"); } - class MockCredentialsProviderChainFactories : public CredentialsProviderChainFactories { - public: - MOCK_METHOD(CredentialsProviderSharedPtr, createEnvironmentCredentialsProvider, (), (const)); - MOCK_METHOD(CredentialsProviderSharedPtr, createCredentialsFileCredentialsProvider, (Api::Api&), - (const)); - MOCK_METHOD(CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, - (Api::Api&, ServerFactoryContextOptRef, - const MetadataCredentialsProviderBase::CurlMetadataFetcher&, - CreateMetadataFetcherCb, absl::string_view, absl::string_view, absl::string_view, - absl::string_view, absl::string_view, absl::string_view, - MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds), - (const)); - MOCK_METHOD(CredentialsProviderSharedPtr, createContainerCredentialsProvider, - (Api::Api&, ServerFactoryContextOptRef, Singleton::Manager&, - const MetadataCredentialsProviderBase::CurlMetadataFetcher&, - CreateMetadataFetcherCb, absl::string_view, absl::string_view, - MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, - absl::string_view), - (const)); - MOCK_METHOD(CredentialsProviderSharedPtr, createInstanceProfileCredentialsProvider, - (Api::Api&, ServerFactoryContextOptRef, Singleton::Manager&, - const MetadataCredentialsProviderBase::CurlMetadataFetcher&, - CreateMetadataFetcherCb, MetadataFetcher::MetadataReceiver::RefreshState, - std::chrono::seconds, absl::string_view), - (const)); - }; - TestScopedRuntime scoped_runtime_; Event::SimulatedTimeSystem time_system_; Api::ApiPtr api_; @@ -2491,100 +2559,265 @@ class DefaultCredentialsProviderChainTest : public testing::Test { }; TEST_F(DefaultCredentialsProviderChainTest, NoEnvironmentVars) { - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, MetadataDisabled) { TestEnvironment::setEnvVar("AWS_EC2_METADATA_DISABLED", "true", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)) .Times(0); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, MetadataNotDisabled) { TestEnvironment::setEnvVar("AWS_EC2_METADATA_DISABLED", "false", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, RelativeUri) { TestEnvironment::setEnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/path/to/creds", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createContainerCredentialsProvider(Ref(*api_), _, _, _, _, _, "169.254.170.2:80/path/to/creds", _, _, "")); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, FullUriNoAuthorizationToken) { TestEnvironment::setEnvVar("AWS_CONTAINER_CREDENTIALS_FULL_URI", "http://host/path/to/creds", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createContainerCredentialsProvider( Ref(*api_), _, _, _, _, _, "http://host/path/to/creds", _, _, "")); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, FullUriWithAuthorizationToken) { TestEnvironment::setEnvVar("AWS_CONTAINER_CREDENTIALS_FULL_URI", "http://host/path/to/creds", 1); TestEnvironment::setEnvVar("AWS_CONTAINER_AUTHORIZATION_TOKEN", "auth_token", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createContainerCredentialsProvider(Ref(*api_), _, _, _, _, _, "http://host/path/to/creds", _, _, "auth_token")); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentityRoleArn) { TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentitySessionName) { TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); TestEnvironment::setEnvVar("AWS_ROLE_ARN", "aws:iam::123456789012:role/arn", 1); time_system_.setSystemTime(std::chrono::milliseconds(1234567890)); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); - EXPECT_CALL(factories_, - createWebIdentityCredentialsProvider( - Ref(*api_), _, _, _, _, "/path/to/web_token", _, "sts.region.amazonaws.com:443", - "aws:iam::123456789012:role/arn", "1234567890000000", _, _)); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); } TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithSessionName) { TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); TestEnvironment::setEnvVar("AWS_ROLE_ARN", "aws:iam::123456789012:role/arn", 1); TestEnvironment::setEnvVar("AWS_ROLE_SESSION_NAME", "role-session-name", 1); - EXPECT_CALL(factories_, createCredentialsFileCredentialsProvider(Ref(*api_))); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", + DummyMetadataFetcher(), credential_provider_config, + factories_); +} + +TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentityWithBlankConfig) { + TestEnvironment::unsetEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"); + TestEnvironment::unsetEnvVar("AWS_ROLE_ARN"); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, - createWebIdentityCredentialsProvider( - Ref(*api_), _, _, _, _, "/path/to/web_token", _, "sts.region.amazonaws.com:443", - "aws:iam::123456789012:role/arn", "role-session-name", _, _)); + createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) + .Times(0); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", + DummyMetadataFetcher(), credential_provider_config, + factories_); +} +// These tests validate override of default credential provider with custom credential provider +// settings + +TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomSessionName) { + TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); + TestEnvironment::setEnvVar("AWS_ROLE_ARN", "aws:iam::123456789012:role/arn", 1); + TestEnvironment::setEnvVar("AWS_ROLE_SESSION_NAME", "role-session-name", 1); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); + EXPECT_CALL(factories_, + createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + + std::string role_session_name; + + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) + .WillOnce(Invoke(WithArg<5>( + [&role_session_name]( + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + provider) -> CredentialsProviderSharedPtr { + role_session_name = provider.role_session_name(); + return nullptr; + }))); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + credential_provider_config.mutable_assume_role_with_web_identity_provider() + ->set_role_session_name("custom-role-session-name"); + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", - DummyMetadataFetcher(), factories_); + DummyMetadataFetcher(), credential_provider_config, + factories_); + EXPECT_EQ(role_session_name, "custom-role-session-name"); +} + +TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomRoleArn) { + TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); + TestEnvironment::setEnvVar("AWS_ROLE_ARN", "aws:iam::123456789012:role/arn", 1); + TestEnvironment::setEnvVar("AWS_ROLE_SESSION_NAME", "role-session-name", 1); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); + EXPECT_CALL(factories_, + createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + + std::string role_arn; + + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) + .WillOnce(Invoke(WithArg<5>( + [&role_arn]( + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + provider) -> CredentialsProviderSharedPtr { + role_arn = provider.role_arn(); + return nullptr; + }))); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + credential_provider_config.mutable_assume_role_with_web_identity_provider()->set_role_arn( + "custom-role-arn"); + + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", + DummyMetadataFetcher(), credential_provider_config, + factories_); + EXPECT_EQ(role_arn, "custom-role-arn"); +} + +TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomDataSource) { + TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); + TestEnvironment::setEnvVar("AWS_ROLE_ARN", "aws:iam::123456789012:role/arn", 1); + TestEnvironment::setEnvVar("AWS_ROLE_SESSION_NAME", "role-session-name", 1); + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); + EXPECT_CALL(factories_, + createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + + std::string inline_string; + + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) + .WillOnce(Invoke(WithArg<5>( + [&inline_string]( + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + provider) -> CredentialsProviderSharedPtr { + inline_string = provider.web_identity_token_data_source().inline_string(); + return nullptr; + }))); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + credential_provider_config.mutable_assume_role_with_web_identity_provider() + ->mutable_web_identity_token_data_source() + ->set_inline_string("custom_token_string"); + + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", + DummyMetadataFetcher(), credential_provider_config, + factories_); + EXPECT_EQ(inline_string, "custom_token_string"); +} + +TEST_F(DefaultCredentialsProviderChainTest, CredentialsFileWithCustomDataSource) { + TestEnvironment::setEnvVar("AWS_WEB_IDENTITY_TOKEN_FILE", "/path/to/web_token", 1); + TestEnvironment::setEnvVar("AWS_ROLE_ARN", "aws:iam::123456789012:role/arn", 1); + TestEnvironment::setEnvVar("AWS_ROLE_SESSION_NAME", "role-session-name", 1); + + std::string inline_string; + + EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)) + .WillOnce(Invoke(WithArg<1>( + [&inline_string]( + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& provider) + -> CredentialsProviderSharedPtr { + inline_string = provider.credentials_data_source().inline_string(); + return nullptr; + }))); + + EXPECT_CALL(factories_, + createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); + + EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( + Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + credential_provider_config.mutable_credentials_file_provider() + ->mutable_credentials_data_source() + ->set_inline_string("custom_inline_string"); + + DefaultCredentialsProviderChain chain(*api_, context_, context_.singletonManager(), "region", + DummyMetadataFetcher(), credential_provider_config, + factories_); + EXPECT_EQ(inline_string, "custom_inline_string"); } TEST(CredentialsProviderChainTest, getCredentials_noCredentials) { @@ -2636,6 +2869,72 @@ TEST(CredentialsProviderChainTest, getCredentials_secondProviderReturns) { EXPECT_EQ(creds, ret_creds); } +class CustomCredentialsProviderChainTest : public testing::Test {}; + +TEST_F(CustomCredentialsProviderChainTest, CreateFileCredentialProviderOnly) { + NiceMock factories; + NiceMock server_context; + auto region = "ap-southeast-2"; + auto file_path = TestEnvironment::writeStringToFileForTest("credentials", "hello"); + + envoy::extensions::common::aws::v3::AwsCredentialProvider cred_provider = {}; + cred_provider.mutable_credentials_file_provider() + ->mutable_credentials_data_source() + ->set_filename(file_path); + + EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)); + EXPECT_CALL(factories, + createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _)) + .Times(0); + + auto chain = std::make_shared( + server_context, region, cred_provider, factories); +} + +TEST_F(CustomCredentialsProviderChainTest, CreateWebIdentityCredentialProviderOnly) { + NiceMock factories; + NiceMock server_context; + auto region = "ap-southeast-2"; + auto file_path = TestEnvironment::writeStringToFileForTest("credentials", "hello"); + + envoy::extensions::common::aws::v3::AwsCredentialProvider cred_provider = {}; + cred_provider.mutable_assume_role_with_web_identity_provider()->set_role_arn("arn://1234"); + cred_provider.mutable_assume_role_with_web_identity_provider() + ->mutable_web_identity_token_data_source() + ->set_filename(file_path); + + EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)) + .Times(0); + EXPECT_CALL(factories, + createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _)); + + auto chain = std::make_shared( + server_context, region, cred_provider, factories); +} + +TEST_F(CustomCredentialsProviderChainTest, CreateFileAndWebProviders) { + NiceMock factories; + NiceMock server_context; + auto region = "ap-southeast-2"; + auto file_path = TestEnvironment::writeStringToFileForTest("credentials", "hello"); + + envoy::extensions::common::aws::v3::AwsCredentialProvider cred_provider = {}; + cred_provider.mutable_credentials_file_provider() + ->mutable_credentials_data_source() + ->set_filename(file_path); + cred_provider.mutable_assume_role_with_web_identity_provider()->set_role_arn("arn://1234"); + cred_provider.mutable_assume_role_with_web_identity_provider() + ->mutable_web_identity_token_data_source() + ->set_filename(file_path); + + EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)); + EXPECT_CALL(factories, + createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _)); + + auto chain = std::make_shared( + server_context, region, cred_provider, factories); +} + TEST(CreateCredentialsProviderFromConfig, InlineCredential) { NiceMock context; envoy::extensions::common::aws::v3::InlineCredentialProvider inline_credential; @@ -2646,49 +2945,15 @@ TEST(CreateCredentialsProviderFromConfig, InlineCredential) { envoy::extensions::common::aws::v3::AwsCredentialProvider base; base.mutable_inline_credential()->CopyFrom(inline_credential); - absl::StatusOr provider = - createCredentialsProviderFromConfig(context, "test-region", base); - EXPECT_TRUE(provider.ok()); - EXPECT_NE(nullptr, provider.value()); - const Credentials creds = provider.value()->getCredentials(); + auto provider = std::make_shared( + inline_credential.access_key_id(), inline_credential.secret_access_key(), + inline_credential.session_token()); + const Credentials creds = provider->getCredentials(); EXPECT_EQ("TestAccessKey", creds.accessKeyId().value()); EXPECT_EQ("TestSecret", creds.secretAccessKey().value()); EXPECT_EQ("TestSessionToken", creds.sessionToken().value()); } -TEST(CreateCredentialsProviderFromConfig, AssumeRoleWithWebIdentity) { - NiceMock context; - envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider - assume_role_provider; - assume_role_provider.set_role_arn("arn:aws:iam::123456789012:role/role-name"); - assume_role_provider.set_web_identity_token("this-is-a-token"); - - envoy::extensions::common::aws::v3::AwsCredentialProvider base; - base.mutable_assume_role_with_web_identity()->CopyFrom(assume_role_provider); - - absl::StatusOr provider = - createCredentialsProviderFromConfig(context, "test-region", base); - EXPECT_TRUE(provider.ok()); - EXPECT_NE(nullptr, provider.value()); - - const auto* web_identity_provider = - dynamic_cast(provider.value().get()); - EXPECT_NE(nullptr, web_identity_provider); - - const std::string& token = web_identity_provider->tokenForTesting(); - const std::string& role_arn = web_identity_provider->roleArnForTesting(); - EXPECT_EQ("this-is-a-token", token); - EXPECT_EQ("arn:aws:iam::123456789012:role/role-name", role_arn); -} - -TEST(CreateCredentialsProviderFromConfig, InvalidEnum) { - NiceMock context; - envoy::extensions::common::aws::v3::AwsCredentialProvider base; - absl::StatusOr result = - createCredentialsProviderFromConfig(context, "foo", base); - EXPECT_FALSE(result.ok()); -} - } // namespace Aws } // namespace Common } // namespace Extensions diff --git a/test/extensions/common/aws/region_provider_impl_test.cc b/test/extensions/common/aws/region_provider_impl_test.cc index 1ce717d2d8cd..82cf24380e06 100644 --- a/test/extensions/common/aws/region_provider_impl_test.cc +++ b/test/extensions/common/aws/region_provider_impl_test.cc @@ -39,8 +39,6 @@ class EnvironmentRegionProviderTest : public testing::Test { class AWSCredentialsFileRegionProviderTest : public testing::Test { public: void SetUp() override { setupEnvironment(); } - - AWSCredentialsFileRegionProvider provider_; }; class AWSConfigFileRegionProviderTest : public testing::Test { @@ -203,6 +201,7 @@ TEST_F(AWSConfigFileRegionProviderTest, NoRegionSet) { EXPECT_EQ(false, provider_.getRegionSet().has_value()); } + TEST_F(AWSCredentialsFileRegionProviderTest, CustomCredentialsFile) { auto temp = TestEnvironment::temporaryDirectory(); TestEnvironment::setEnvVar("HOME", temp, 1); @@ -213,8 +212,23 @@ TEST_F(AWSCredentialsFileRegionProviderTest, CustomCredentialsFile) { credentials_file, CREDENTIALS_FILE_CONTENTS, true, false); TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", credentials_file, 1); + auto provider = AWSCredentialsFileRegionProvider({}); + EXPECT_EQ("credentialsdefaultregion", provider.getRegion().value()); +} + +TEST_F(AWSCredentialsFileRegionProviderTest, CustomCredentialsFileViaCredentialProviderConfig) { + auto temp = TestEnvironment::temporaryDirectory(); + TestEnvironment::setEnvVar("HOME", temp, 1); + std::filesystem::create_directory(temp + "/.aws"); + std::string credentials_file(temp + "/.aws/customfile"); - EXPECT_EQ("credentialsdefaultregion", provider_.getRegion().value()); + auto file_path = TestEnvironment::writeStringToFileForTest( + credentials_file, CREDENTIALS_FILE_CONTENTS, true, false); + + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config; + credential_file_config.mutable_credentials_data_source()->set_filename(credentials_file); + auto provider = AWSCredentialsFileRegionProvider(credential_file_config); + EXPECT_EQ("credentialsdefaultregion", provider.getRegion().value()); } TEST_F(AWSCredentialsFileRegionProviderTest, CustomCredentialsFileRegionSet) { @@ -227,8 +241,26 @@ TEST_F(AWSCredentialsFileRegionProviderTest, CustomCredentialsFileRegionSet) { credentials_file, CREDENTIALS_FILE_CONTENTS_REGION_SET, true, false); TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", credentials_file, 1); + auto provider = AWSCredentialsFileRegionProvider({}); - EXPECT_EQ("*", provider_.getRegionSet().value()); + EXPECT_EQ("*", provider.getRegionSet().value()); +} + +TEST_F(AWSCredentialsFileRegionProviderTest, + CustomCredentialsFileRegionSetViaCredentialProviderConfig) { + auto temp = TestEnvironment::temporaryDirectory(); + TestEnvironment::setEnvVar("HOME", temp, 1); + std::filesystem::create_directory(temp + "/.aws"); + std::string credentials_file(temp + "/.aws/customfile"); + + auto file_path = TestEnvironment::writeStringToFileForTest( + credentials_file, CREDENTIALS_FILE_CONTENTS_REGION_SET, true, false); + + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config; + credential_file_config.mutable_credentials_data_source()->set_filename(credentials_file); + auto provider = AWSCredentialsFileRegionProvider(credential_file_config); + + EXPECT_EQ("*", provider.getRegionSet().value()); } TEST_F(AWSCredentialsFileRegionProviderTest, CustomProfileSharedCredentialsFile) { @@ -242,8 +274,8 @@ TEST_F(AWSCredentialsFileRegionProviderTest, CustomProfileSharedCredentialsFile) TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", credentials_file, 1); TestEnvironment::setEnvVar("AWS_PROFILE", "profile1", 1); - - EXPECT_EQ("profile1region", provider_.getRegion().value()); + auto provider = AWSCredentialsFileRegionProvider({}); + EXPECT_EQ("profile1region", provider.getRegion().value()); } TEST_F(AWSCredentialsFileRegionProviderTest, CustomProfileSharedCredentialsFileRegionSet) { @@ -257,8 +289,9 @@ TEST_F(AWSCredentialsFileRegionProviderTest, CustomProfileSharedCredentialsFileR TestEnvironment::setEnvVar("AWS_SHARED_CREDENTIALS_FILE", credentials_file, 1); TestEnvironment::setEnvVar("AWS_PROFILE", "profile1", 1); + auto provider = AWSCredentialsFileRegionProvider({}); - EXPECT_EQ("us-east-1,us-east-2", provider_.getRegionSet().value()); + EXPECT_EQ("us-east-1,us-east-2", provider.getRegionSet().value()); } TEST_F(AWSCredentialsFileRegionProviderTest, NoRegion) { @@ -269,8 +302,8 @@ TEST_F(AWSCredentialsFileRegionProviderTest, NoRegion) { auto file_path = TestEnvironment::writeStringToFileForTest( credentials_file, CREDENTIALS_FILE_NO_REGION, true, false); - - EXPECT_EQ(false, provider_.getRegion().has_value()); + auto provider = AWSCredentialsFileRegionProvider({}); + EXPECT_EQ(false, provider.getRegion().has_value()); } TEST_F(AWSCredentialsFileRegionProviderTest, NoRegionSet) { @@ -281,8 +314,8 @@ TEST_F(AWSCredentialsFileRegionProviderTest, NoRegionSet) { auto file_path = TestEnvironment::writeStringToFileForTest( credentials_file, CREDENTIALS_FILE_NO_REGION, true, false); - - EXPECT_EQ(false, provider_.getRegionSet().has_value()); + auto provider = AWSCredentialsFileRegionProvider({}); + EXPECT_EQ(false, provider.getRegionSet().has_value()); } TEST_F(RegionProviderChainTest, EnvironmentBeforeCredentialsFile) { diff --git a/test/extensions/common/aws/utility_test.cc b/test/extensions/common/aws/utility_test.cc index 6cabe0ebd79e..0dac68522db5 100644 --- a/test/extensions/common/aws/utility_test.cc +++ b/test/extensions/common/aws/utility_test.cc @@ -66,10 +66,10 @@ TEST(UtilityTest, TestProfileResolver) { auto file_path = TestEnvironment::writeStringToFileForTest( credential_file, CREDENTIALS_FILE_CONTENTS, true, false); - Utility::resolveProfileElements(file_path, "default", elements); + Utility::resolveProfileElementsFromFile(file_path, "default", elements); it = elements.find("AWS_ACCESS_KEY_ID"); EXPECT_EQ(it->second, "default_access_key"); - Utility::resolveProfileElements(file_path, "profile4", elements); + Utility::resolveProfileElementsFromFile(file_path, "profile4", elements); it = elements.find("AWS_ACCESS_KEY_ID"); EXPECT_EQ(it->second, "profile4_access_key"); } diff --git a/test/extensions/filters/http/aws_request_signing/config_test.cc b/test/extensions/filters/http/aws_request_signing/config_test.cc index f007730ebcd9..0e79aad0830c 100644 --- a/test/extensions/filters/http/aws_request_signing/config_test.cc +++ b/test/extensions/filters/http/aws_request_signing/config_test.cc @@ -96,8 +96,9 @@ TEST(AwsRequestSigningFilterConfigTest, CredentialProvider_assume_role_web_ident service_name: s3 region: us-west-2 credential_provider: - assume_role_with_web_identity: - web_identity_token: this-is-token + assume_role_with_web_identity_provider: + web_identity_token_data_source: + inline_string: this-is-token role_arn: arn:aws:iam::123456789012:role/role-name )EOF"; @@ -107,9 +108,9 @@ region: us-west-2 AwsRequestSigningProtoConfig expected_config; expected_config.set_service_name("s3"); expected_config.set_region("us-west-2"); - auto credential_provider = - expected_config.mutable_credential_provider()->mutable_assume_role_with_web_identity(); - credential_provider->set_web_identity_token("this-is-token"); + auto credential_provider = expected_config.mutable_credential_provider() + ->mutable_assume_role_with_web_identity_provider(); + credential_provider->mutable_web_identity_token_data_source()->set_inline_string("this-is-token"); credential_provider->set_role_arn("arn:aws:iam::123456789012:role/role-name"); Protobuf::util::MessageDifferencer differencer; @@ -127,11 +128,90 @@ region: us-west-2 cb(filter_callbacks); } +TEST(AwsRequestSigningFilterConfigTest, CredentialProvider_credential_file) { + const std::string yaml = R"EOF( +service_name: s3 +region: us-west-2 +credential_provider: + credentials_file_provider: + profile: profile1 + credentials_data_source: + filename: this-is-filename + )EOF"; + + AwsRequestSigningProtoConfig proto_config; + TestUtility::loadFromYamlAndValidate(yaml, proto_config); + + AwsRequestSigningProtoConfig expected_config; + expected_config.set_service_name("s3"); + expected_config.set_region("us-west-2"); + auto credential_provider = + expected_config.mutable_credential_provider()->mutable_credentials_file_provider(); + credential_provider->mutable_credentials_data_source()->set_filename("this-is-filename"); + credential_provider->set_profile("profile1"); + + Protobuf::util::MessageDifferencer differencer; + differencer.set_message_field_comparison(Protobuf::util::MessageDifferencer::EQUAL); + differencer.set_repeated_field_comparison(Protobuf::util::MessageDifferencer::AS_SET); + EXPECT_TRUE(differencer.Compare(expected_config, proto_config)); + + testing::NiceMock context; + AwsRequestSigningFilterFactory factory; + + Http::FilterFactoryCb cb = + factory.createFilterFactoryFromProto(proto_config, "stats", context).value(); + Http::MockFilterChainFactoryCallbacks filter_callbacks; + EXPECT_CALL(filter_callbacks, addStreamDecoderFilter(_)); + cb(filter_callbacks); +} + +TEST(AwsRequestSigningFilterConfigTest, CredentialProvider_credential_file_watched_dir) { + const std::string yaml = R"EOF( +service_name: s3 +region: us-west-2 +credential_provider: + credentials_file_provider: + profile: profile5 + credentials_data_source: + filename: this-is-filename + watched_directory: + path: /tmp + )EOF"; + + AwsRequestSigningProtoConfig proto_config; + TestUtility::loadFromYamlAndValidate(yaml, proto_config); + + AwsRequestSigningProtoConfig expected_config; + expected_config.set_service_name("s3"); + expected_config.set_region("us-west-2"); + auto credential_provider = + expected_config.mutable_credential_provider()->mutable_credentials_file_provider(); + credential_provider->mutable_credentials_data_source()->set_filename("this-is-filename"); + credential_provider->mutable_credentials_data_source()->mutable_watched_directory()->set_path( + "/tmp"); + credential_provider->set_profile("profile5"); + + Protobuf::util::MessageDifferencer differencer; + differencer.set_message_field_comparison(Protobuf::util::MessageDifferencer::EQUAL); + differencer.set_repeated_field_comparison(Protobuf::util::MessageDifferencer::AS_SET); + EXPECT_TRUE(differencer.Compare(expected_config, proto_config)); + + testing::NiceMock context; + AwsRequestSigningFilterFactory factory; + + Http::FilterFactoryCb cb = + factory.createFilterFactoryFromProto(proto_config, "stats", context).value(); + Http::MockFilterChainFactoryCallbacks filter_callbacks; + EXPECT_CALL(filter_callbacks, addStreamDecoderFilter(_)); + cb(filter_callbacks); +} + TEST(AwsRequestSigningFilterConfigTest, CredentialProvider_invalid) { const std::string yaml = R"EOF( service_name: s3 region: us-west-2 -credential_provider: {} +credential_provider: + custom_credential_provider_chain: true )EOF"; AwsRequestSigningProtoConfig proto_config; From f31f22f1aa36936a5936a268c53f204142fc0811 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 28 Dec 2024 01:15:37 +0000 Subject: [PATCH 06/21] string_view scope Signed-off-by: Nigel Brittain --- source/extensions/common/aws/region_provider_impl.cc | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/source/extensions/common/aws/region_provider_impl.cc b/source/extensions/common/aws/region_provider_impl.cc index ee78d1705132..40622802608b 100644 --- a/source/extensions/common/aws/region_provider_impl.cc +++ b/source/extensions/common/aws/region_provider_impl.cc @@ -56,15 +56,13 @@ absl::optional AWSCredentialsFileRegionProvider::getRegion() { absl::flat_hash_map elements = {{REGION, ""}}; absl::flat_hash_map::iterator it; - absl::string_view file_path; + std::string file_path, profile; file_path = credential_file_path_.has_value() ? credential_file_path_.value() : Utility::getCredentialFilePath(); - - absl::string_view profile; profile = profile_.has_value() ? profile_.value() : Utility::getCredentialProfileName(); // Search for the region in the credentials file - if (!Utility::resolveProfileElementsFromFile(file_path.data(), profile.data(), elements)) { + if (!Utility::resolveProfileElementsFromFile(file_path, profile, elements)) { return absl::nullopt; } it = elements.find(REGION); @@ -83,15 +81,14 @@ absl::optional AWSCredentialsFileRegionProvider::getRegionSet() { // Search for the region in the credentials file - absl::string_view file_path; + std::string file_path, profile; file_path = credential_file_path_.has_value() ? credential_file_path_.value() : Utility::getCredentialFilePath(); - absl::string_view profile; profile = profile_.has_value() ? profile_.value() : Utility::getCredentialProfileName(); // Search for the region in the credentials file - if (!Utility::resolveProfileElementsFromFile(file_path.data(), profile.data(), elements)) { + if (!Utility::resolveProfileElementsFromFile(file_path, profile, elements)) { return absl::nullopt; } it = elements.find(SIGV4A_SIGNING_REGION_SET); From 84c52b5d1d5bce88f460f437c6cd70c50d5fe37b Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 28 Dec 2024 08:51:10 +0000 Subject: [PATCH 07/21] cleanup Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider_impl.cc | 2 +- .../common/aws/credentials_provider_impl.h | 56 +++--- .../http/aws_request_signing/config.cc | 175 ++++-------------- .../filters/http/aws_request_signing/config.h | 5 + 4 files changed, 65 insertions(+), 173 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 55ca36324bae..80cd2b7748df 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -998,7 +998,7 @@ std::string stsClusterName(absl::string_view region) { CustomCredentialsProviderChain::CustomCredentialsProviderChain( Server::Configuration::ServerFactoryContext& context, absl::string_view region, const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config, - const CredentialsProviderChainFactories& factories) { + const CustomCredentialsProviderChainFactories& factories) { // Custom chain currently only supports file based and web identity credentials if (credential_provider_config.has_assume_role_with_web_identity_provider()) { diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index db7825b3b2f7..e7a10293bf4c 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -338,8 +338,6 @@ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase, void onMetadataSuccess(const std::string&& body) override; void onMetadataError(Failure reason) override; - const std::string& roleArnForTesting() const { return role_arn_; } - private: const std::string sts_endpoint_; absl::optional web_identity_data_source_provider_; @@ -406,13 +404,33 @@ class CredentialsProviderChainFactories { std::chrono::seconds initialization_timer, absl::string_view cluster_name) const PURE; }; +class CustomCredentialsProviderChainFactories { +public: + virtual ~CustomCredentialsProviderChainFactories() = default; + + virtual CredentialsProviderSharedPtr createCredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& + credential_file_config = {}) const PURE; + + virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const PURE; +}; + +// TODO(nbaws) Add additional providers to the custom chain. class CustomCredentialsProviderChain : public CredentialsProviderChain, - public CredentialsProviderChainFactories { + public CustomCredentialsProviderChainFactories { public: CustomCredentialsProviderChain( Server::Configuration::ServerFactoryContext& context, absl::string_view region, const envoy::extensions::common::aws::v3::AwsCredentialProvider& credential_provider_config, - const CredentialsProviderChainFactories& factories); + const CustomCredentialsProviderChainFactories& factories); CustomCredentialsProviderChain( Server::Configuration::ServerFactoryContext& context, absl::string_view region, @@ -441,36 +459,6 @@ class CustomCredentialsProviderChain : public CredentialsProviderChain, context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, web_identity_config, cluster_name); }; - - CredentialsProviderSharedPtr createEnvironmentCredentialsProvider() const override { - return nullptr; - } - - CredentialsProviderSharedPtr createContainerCredentialsProvider( - ABSL_ATTRIBUTE_UNUSED Api::Api& api, ABSL_ATTRIBUTE_UNUSED ServerFactoryContextOptRef context, - ABSL_ATTRIBUTE_UNUSED Singleton::Manager& singleton_manager, - ABSL_ATTRIBUTE_UNUSED const MetadataCredentialsProviderBase::CurlMetadataFetcher& - fetch_metadata_using_curl, - ABSL_ATTRIBUTE_UNUSED CreateMetadataFetcherCb create_metadata_fetcher_cb, - ABSL_ATTRIBUTE_UNUSED absl::string_view cluster_name, - ABSL_ATTRIBUTE_UNUSED absl::string_view credential_uri, - ABSL_ATTRIBUTE_UNUSED MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - ABSL_ATTRIBUTE_UNUSED std::chrono::seconds initialization_timer, - ABSL_ATTRIBUTE_UNUSED absl::string_view authorization_token = {}) const override { - return nullptr; - } - - CredentialsProviderSharedPtr createInstanceProfileCredentialsProvider( - ABSL_ATTRIBUTE_UNUSED Api::Api& api, ABSL_ATTRIBUTE_UNUSED ServerFactoryContextOptRef context, - ABSL_ATTRIBUTE_UNUSED Singleton::Manager& singleton_manager, - ABSL_ATTRIBUTE_UNUSED const MetadataCredentialsProviderBase::CurlMetadataFetcher& - fetch_metadata_using_curl, - ABSL_ATTRIBUTE_UNUSED CreateMetadataFetcherCb create_metadata_fetcher_cb, - ABSL_ATTRIBUTE_UNUSED MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - ABSL_ATTRIBUTE_UNUSED std::chrono::seconds initialization_timer, - ABSL_ATTRIBUTE_UNUSED absl::string_view cluster_name) const override { - return nullptr; - } }; /** diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index eb2906c541a8..8e1fcc0ac2cc 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -36,6 +36,41 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( const AwsRequestSigningProtoConfig& config, const std::string& stats_prefix, DualInfo dual_info, Server::Configuration::ServerFactoryContext& server_context) { + auto signer = createSigner(config, server_context); + if (!signer.ok()) { + return absl::InvalidArgumentError(std::string(signer.status().message())); + } + auto filter_config = + std::make_shared(std::move(signer.value()), stats_prefix, dual_info.scope, + config.host_rewrite(), config.use_unsigned_payload()); + return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { + auto filter = std::make_shared(filter_config); + callbacks.addStreamDecoderFilter(filter); + }; +} + +absl::StatusOr +AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( + const AwsRequestSigningProtoPerRouteConfig& per_route_config, + Server::Configuration::ServerFactoryContext& server_context, + ProtobufMessage::ValidationVisitor&) { + + auto signer = createSigner(per_route_config.aws_request_signing(), server_context); + if (!signer.ok()) { + return absl::InvalidArgumentError(std::string(signer.status().message())); + } + + return std::make_shared( + std::move(signer.value()), per_route_config.stat_prefix(), server_context.scope(), + per_route_config.aws_request_signing().host_rewrite(), + per_route_config.aws_request_signing().use_unsigned_payload()); +} + +absl::StatusOr +AwsRequestSigningFilterFactory::createSigner( + const AwsRequestSigningProtoConfig& config, + Server::Configuration::ServerFactoryContext& server_context) { + std::string region = config.region(); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -124,7 +159,7 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( std::unique_ptr signer; if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { - signer = std::make_unique( + return std::make_unique( config.service_name(), region, credentials_provider.value(), server_context, matcher_config, query_string, expiration_time); } else { @@ -134,146 +169,10 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( "SigV4 region string cannot contain wildcards or commas. Region sets " "can be specified when using signing_algorithm: AWS_SIGV4A."); } - signer = std::make_unique( + return std::make_unique( config.service_name(), region, credentials_provider.value(), server_context, matcher_config, query_string, expiration_time); } - - auto filter_config = - std::make_shared(std::move(signer), stats_prefix, dual_info.scope, - config.host_rewrite(), config.use_unsigned_payload()); - return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { - auto filter = std::make_shared(filter_config); - callbacks.addStreamDecoderFilter(filter); - }; -} - -// TODO: @nbaws remove duplication from above - -absl::StatusOr -AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( - const AwsRequestSigningProtoPerRouteConfig& per_route_config, - Server::Configuration::ServerFactoryContext& server_context, - ProtobufMessage::ValidationVisitor&) { - - std::string region = per_route_config.aws_request_signing().region(); - - envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; - - // If we have an overriding credential provider configuration, read it here - envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; - if (per_route_config.aws_request_signing().has_credential_provider()) { - if (per_route_config.aws_request_signing() - .credential_provider() - .has_credentials_file_provider()) { - credential_file_config = - per_route_config.aws_request_signing().credential_provider().credentials_file_provider(); - } - } - - if (region.empty()) { - auto region_provider = std::make_shared(); - absl::optional regionOpt; - if (per_route_config.aws_request_signing().signing_algorithm() == - AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { - regionOpt = region_provider->getRegionSet(); - } else { - regionOpt = region_provider->getRegion(); - } - if (!regionOpt.has_value()) { - return absl::InvalidArgumentError( - "AWS region is not set in xDS configuration and failed to retrieve from " - "environment variable or AWS profile/config files."); - } - region = regionOpt.value(); - } - - absl::StatusOr - credentials_provider = - absl::InvalidArgumentError("No credentials provider settings configured."); - - bool has_credential_provider_settings = - per_route_config.aws_request_signing().has_credential_provider() && - (per_route_config.aws_request_signing() - .credential_provider() - .has_assume_role_with_web_identity_provider() || - per_route_config.aws_request_signing() - .credential_provider() - .has_credentials_file_provider()); - - if (per_route_config.aws_request_signing().has_credential_provider()) { - if (per_route_config.aws_request_signing().credential_provider().has_inline_credential()) { - const auto& inline_credential = - per_route_config.aws_request_signing().credential_provider().inline_credential(); - credentials_provider = std::make_shared( - inline_credential.access_key_id(), inline_credential.secret_access_key(), - inline_credential.session_token()); - } - - if (per_route_config.aws_request_signing() - .credential_provider() - .custom_credential_provider_chain()) { - // Custom credential provider chain - if (has_credential_provider_settings) { - credentials_provider = - std::make_shared( - server_context, region, - per_route_config.aws_request_signing().credential_provider()); - } - } else { - // Override default credential provider chain settings with any provided settings - if (has_credential_provider_settings) { - credential_provider_config = per_route_config.aws_request_signing().credential_provider(); - } - credentials_provider = - std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); - } - } else { - // No credential provider settings provided, so make the default credentials provider chain - credentials_provider = - std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); - } - - if (!credentials_provider.ok()) { - return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); - } - - const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( - per_route_config.aws_request_signing().match_excluded_headers().begin(), - per_route_config.aws_request_signing().match_excluded_headers().end()); - - const bool query_string = per_route_config.aws_request_signing().has_query_string(); - - const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( - per_route_config.aws_request_signing().query_string(), expiration_time, 5); - - std::unique_ptr signer; - - if (per_route_config.aws_request_signing().signing_algorithm() == - AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { - signer = std::make_unique( - per_route_config.aws_request_signing().service_name(), region, credentials_provider.value(), - server_context, matcher_config, query_string, expiration_time); - } else { - // Verify that we have not specified a region set when using sigv4 algorithm - if (isARegionSet(region)) { - return absl::InvalidArgumentError( - "SigV4 region string cannot contain wildcards or commas. Region sets " - "can be specified when using signing_algorithm: AWS_SIGV4A."); - } - signer = std::make_unique( - per_route_config.aws_request_signing().service_name(), region, credentials_provider.value(), - server_context, matcher_config, query_string, expiration_time); - } - - return std::make_shared( - std::move(signer), per_route_config.stat_prefix(), server_context.scope(), - per_route_config.aws_request_signing().host_rewrite(), - per_route_config.aws_request_signing().use_unsigned_payload()); } /** diff --git a/source/extensions/filters/http/aws_request_signing/config.h b/source/extensions/filters/http/aws_request_signing/config.h index 62805bc468a6..48c3d40f6dd5 100644 --- a/source/extensions/filters/http/aws_request_signing/config.h +++ b/source/extensions/filters/http/aws_request_signing/config.h @@ -3,6 +3,7 @@ #include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.h" #include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.validate.h" +#include "source/extensions/common/aws/signer.h" #include "source/extensions/filters/http/common/factory_base.h" namespace Envoy { @@ -35,6 +36,10 @@ class AwsRequestSigningFilterFactory createRouteSpecificFilterConfigTyped(const AwsRequestSigningProtoPerRouteConfig& per_route_config, Server::Configuration::ServerFactoryContext& context, ProtobufMessage::ValidationVisitor&) override; + + absl::StatusOr + createSigner(const AwsRequestSigningProtoConfig& config, + Server::Configuration::ServerFactoryContext& server_context); }; using UpstreamAwsRequestSigningFilterFactory = AwsRequestSigningFilterFactory; From 8a0056db289bd6190422b7c80b04291207fd4dc2 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 28 Dec 2024 09:19:17 +0000 Subject: [PATCH 08/21] additional mocks Signed-off-by: Nigel Brittain --- .../aws/credentials_provider_impl_test.cc | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index 2de705e53387..8283fc5a63ab 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -2530,6 +2530,30 @@ class MockCredentialsProviderChainFactories : public CredentialsProviderChainFac (const)); }; +class MockCustomCredentialsProviderChainFactories : public CustomCredentialsProviderChainFactories { +public: + MOCK_METHOD( + CredentialsProviderSharedPtr, mockCreateCredentialsFileCredentialsProvider, + (Server::Configuration::ServerFactoryContext&, + (const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& config)), + (const)); + + CredentialsProviderSharedPtr createCredentialsFileCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, + const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& config) + const override { + return mockCreateCredentialsFileCredentialsProvider(context, config); + } + + MOCK_METHOD( + CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, + (Server::Configuration::ServerFactoryContext&, CreateMetadataFetcherCb, absl::string_view, + MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider&, + absl::string_view), + (const)); +}; + class DefaultCredentialsProviderChainTest : public testing::Test { public: DefaultCredentialsProviderChainTest() : api_(Api::createApiForTest(time_system_)) { @@ -2872,7 +2896,7 @@ TEST(CredentialsProviderChainTest, getCredentials_secondProviderReturns) { class CustomCredentialsProviderChainTest : public testing::Test {}; TEST_F(CustomCredentialsProviderChainTest, CreateFileCredentialProviderOnly) { - NiceMock factories; + NiceMock factories; NiceMock server_context; auto region = "ap-southeast-2"; auto file_path = TestEnvironment::writeStringToFileForTest("credentials", "hello"); @@ -2892,7 +2916,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateFileCredentialProviderOnly) { } TEST_F(CustomCredentialsProviderChainTest, CreateWebIdentityCredentialProviderOnly) { - NiceMock factories; + NiceMock factories; NiceMock server_context; auto region = "ap-southeast-2"; auto file_path = TestEnvironment::writeStringToFileForTest("credentials", "hello"); @@ -2913,7 +2937,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateWebIdentityCredentialProviderOn } TEST_F(CustomCredentialsProviderChainTest, CreateFileAndWebProviders) { - NiceMock factories; + NiceMock factories; NiceMock server_context; auto region = "ap-southeast-2"; auto file_path = TestEnvironment::writeStringToFileForTest("credentials", "hello"); From bbb30e40ff40c91b55522c5fecb6e9b144ec2359 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 28 Dec 2024 23:11:45 +0000 Subject: [PATCH 09/21] address feedback Signed-off-by: Nigel Brittain --- api/envoy/extensions/common/aws/v3/credential_provider.proto | 3 ++- source/extensions/common/aws/credentials_provider_impl.cc | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/envoy/extensions/common/aws/v3/credential_provider.proto b/api/envoy/extensions/common/aws/v3/credential_provider.proto index c05e34cbd30a..722e9b32867d 100644 --- a/api/envoy/extensions/common/aws/v3/credential_provider.proto +++ b/api/envoy/extensions/common/aws/v3/credential_provider.proto @@ -71,8 +71,9 @@ message AssumeRoleWithWebIdentityCredentialProvider { message CredentialsFileCredentialProvider { // Data source from which to retrieve AWS credentials // When using this data source, if a ``watched_directory`` is provided, the credential file will be re-read when a file move is detected. + // See :ref:`watched_directory ` for more information about the ``watched_directory`` field. config.core.v3.DataSource credentials_data_source = 1 [(udpa.annotations.sensitive) = true]; - // The profile within the credentials_file data source + // The profile within the credentials_file data source. If not provided, the default profile will be used. string profile = 2; } diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 80cd2b7748df..c72f23a4e432 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -353,7 +353,6 @@ bool CredentialsFileCredentialsProvider::needsRefresh() { return has_watched_directory_ ? true : context_.api().timeSource().systemTime() - last_updated_ > REFRESH_INTERVAL; - // return context_.api().timeSource().systemTime() - last_updated_ > REFRESH_INTERVAL; } void CredentialsFileCredentialsProvider::refresh() { From fa470efd8f70b71b09332685cbe3a04e5a602fe9 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 19 Dec 2024 03:07:13 +0000 Subject: [PATCH 10/21] fix assertion failure during rds Signed-off-by: Nigel Brittain --- source/extensions/common/aws/BUILD | 1 + .../common/aws/credentials_provider_impl.cc | 34 +++++++++++++------ .../common/aws/credentials_provider_impl.h | 2 ++ 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/source/extensions/common/aws/BUILD b/source/extensions/common/aws/BUILD index e16539526327..1609982f543f 100644 --- a/source/extensions/common/aws/BUILD +++ b/source/extensions/common/aws/BUILD @@ -123,6 +123,7 @@ envoy_cc_library( "//source/common/config:datasource_lib", "//source/common/http:utility_lib", "//source/common/init:target_lib", + "//source/common/init:manager_lib", "//source/common/json:json_loader_lib", "//source/common/runtime:runtime_features_lib", "//source/common/tracing:http_tracer_lib", diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index c72f23a4e432..93cb86abac62 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -148,18 +148,23 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase( ALL_METADATACREDENTIALSPROVIDER_STATS(POOL_COUNTER(*scope_), POOL_GAUGE(*scope_))}); stats_->metadata_refresh_state_.set(uint64_t(refresh_state_)); - init_target_ = std::make_unique(debug_name_, [this]() -> void { - tls_slot_ = - ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); - tls_slot_->set( - [&](Event::Dispatcher&) { return std::make_shared(*this); }); + // If credential provider is being created during Envoy initialization, use init manager to delay cluster creation + // If we are here during normal processing, such as xDS update, then create clusters and initialize TLS immediately + if(context_->initManager().state() == Envoy::Init::Manager::State::Initialized) + { + initializeTlsAndCluster(); + } + else + { + init_target_ = std::make_unique(debug_name_, [this]() -> void { - createCluster(true); + initializeTlsAndCluster(); - init_target_->ready(); - init_target_.reset(); - }); - context_->initManager().add(*init_target_); + init_target_->ready(); + init_target_.reset(); + }); + context_->initManager().add(*init_target_); + } } }; @@ -171,6 +176,15 @@ MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::~ThreadLocalCreden } } +void MetadataCredentialsProviderBase::initializeTlsAndCluster() { + tls_slot_ = + ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); + + tls_slot_->set( + [&](Event::Dispatcher&) { return std::make_shared(*this); }); + createCluster(true); +} + void MetadataCredentialsProviderBase::createCluster(bool new_timer) { auto cluster = Utility::createInternalClusterStatic(cluster_name_, cluster_type_, uri_); diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index e7a10293bf4c..c2d5cdf9184c 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -18,6 +18,7 @@ #include "source/common/common/thread.h" #include "source/common/config/datasource.h" #include "source/common/init/target_impl.h" +#include "envoy/init/manager.h" #include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" #include "source/extensions/common/aws/credentials_provider.h" @@ -149,6 +150,7 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { private: void createCluster(bool new_timer); + void initializeTlsAndCluster(); protected: struct LoadClusterEntryHandleImpl From af5ef9885167cdff9de3ebe8a14b7e64823e9a7f Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 19 Dec 2024 05:18:08 +0000 Subject: [PATCH 11/21] test cases Signed-off-by: Nigel Brittain --- source/extensions/common/aws/BUILD | 2 +- .../common/aws/credentials_provider_impl.cc | 27 +++++++++---------- .../common/aws/credentials_provider_impl.h | 2 +- .../aws/credentials_provider_impl_test.cc | 11 ++++++++ 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/source/extensions/common/aws/BUILD b/source/extensions/common/aws/BUILD index 1609982f543f..5b878c27d4c2 100644 --- a/source/extensions/common/aws/BUILD +++ b/source/extensions/common/aws/BUILD @@ -122,8 +122,8 @@ envoy_cc_library( "//source/common/common:thread_lib", "//source/common/config:datasource_lib", "//source/common/http:utility_lib", - "//source/common/init:target_lib", "//source/common/init:manager_lib", + "//source/common/init:target_lib", "//source/common/json:json_loader_lib", "//source/common/runtime:runtime_features_lib", "//source/common/tracing:http_tracer_lib", diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 93cb86abac62..0728c9620e7d 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -148,16 +148,13 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase( ALL_METADATACREDENTIALSPROVIDER_STATS(POOL_COUNTER(*scope_), POOL_GAUGE(*scope_))}); stats_->metadata_refresh_state_.set(uint64_t(refresh_state_)); - // If credential provider is being created during Envoy initialization, use init manager to delay cluster creation - // If we are here during normal processing, such as xDS update, then create clusters and initialize TLS immediately - if(context_->initManager().state() == Envoy::Init::Manager::State::Initialized) - { - initializeTlsAndCluster(); - } - else - { + // If credential provider is being created during Envoy initialization, use init manager to + // delay cluster creation If we are here during normal processing, such as xDS update, then + // create clusters and initialize TLS immediately + if (context_->initManager().state() == Envoy::Init::Manager::State::Initialized) { + initializeTlsAndCluster(); + } else { init_target_ = std::make_unique(debug_name_, [this]() -> void { - initializeTlsAndCluster(); init_target_->ready(); @@ -177,12 +174,12 @@ MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::~ThreadLocalCreden } void MetadataCredentialsProviderBase::initializeTlsAndCluster() { - tls_slot_ = - ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); - - tls_slot_->set( - [&](Event::Dispatcher&) { return std::make_shared(*this); }); - createCluster(true); + tls_slot_ = + ThreadLocal::TypedSlot::makeUnique(context_->threadLocal()); + + tls_slot_->set( + [&](Event::Dispatcher&) { return std::make_shared(*this); }); + createCluster(true); } void MetadataCredentialsProviderBase::createCluster(bool new_timer) { diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index c2d5cdf9184c..ce12a3750bdb 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -11,6 +11,7 @@ #include "envoy/event/timer.h" #include "envoy/extensions/common/aws/v3/credential_provider.pb.h" #include "envoy/http/message.h" +#include "envoy/init/manager.h" #include "envoy/server/factory_context.h" #include "source/common/common/lock_guard.h" @@ -18,7 +19,6 @@ #include "source/common/common/thread.h" #include "source/common/config/datasource.h" #include "source/common/init/target_impl.h" -#include "envoy/init/manager.h" #include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" #include "source/extensions/common/aws/credentials_provider.h" diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index 8283fc5a63ab..189edc0db7f0 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -1506,6 +1506,17 @@ class ContainerCredentialsProviderTest : public testing::Test { NiceMock init_watcher_; }; +TEST_F(ContainerCredentialsProviderTest, CreationAfterInitCompleted) { + // Handle the case where we've already completed init. This validates that clusters create + // successfully but init manager is not used + NiceMock initManager; + ON_CALL(context_, initManager()).WillByDefault(ReturnRef(initManager)); + ON_CALL(initManager, state()).WillByDefault(Return(Init::Manager::State::Initialized)); + EXPECT_CALL(cluster_manager_, addOrUpdateCluster(_, _, _)); + EXPECT_CALL(initManager, add(_)).Times(0); + setupProvider(); +} + TEST_F(ContainerCredentialsProviderTest, FailedFetchingDocument) { // Setup timer. From 9d91a1b5e11916a51707819e8ae7eef54cd5b78d Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 19 Dec 2024 06:43:07 +0000 Subject: [PATCH 12/21] fix test leak Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider_impl_test.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index 189edc0db7f0..1f722d38a11a 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -1504,17 +1504,18 @@ class ContainerCredentialsProviderTest : public testing::Test { MetadataFetcher::MetadataReceiver::RefreshState refresh_state_; Init::TargetHandlePtr init_target_; NiceMock init_watcher_; + NiceMock init_manager_; }; TEST_F(ContainerCredentialsProviderTest, CreationAfterInitCompleted) { // Handle the case where we've already completed init. This validates that clusters create // successfully but init manager is not used - NiceMock initManager; - ON_CALL(context_, initManager()).WillByDefault(ReturnRef(initManager)); - ON_CALL(initManager, state()).WillByDefault(Return(Init::Manager::State::Initialized)); + ON_CALL(context_, initManager()).WillByDefault(ReturnRef(init_manager_)); + ON_CALL(init_manager_, state()).WillByDefault(Return(Init::Manager::State::Initialized)); EXPECT_CALL(cluster_manager_, addOrUpdateCluster(_, _, _)); - EXPECT_CALL(initManager, add(_)).Times(0); + EXPECT_CALL(init_manager_, add(_)).Times(0); setupProvider(); + delete (raw_metadata_fetcher_); } TEST_F(ContainerCredentialsProviderTest, FailedFetchingDocument) { From f66cfc7a9ad414e16259e16a70a96bd32e32f056 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 21 Dec 2024 04:43:36 +0000 Subject: [PATCH 13/21] singleton webidentity Signed-off-by: Nigel Brittain --- source/extensions/common/aws/credentials_provider_impl.cc | 7 ++++--- source/extensions/common/aws/credentials_provider_impl.h | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 0728c9620e7d..de9688bc4ff8 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -149,7 +149,7 @@ MetadataCredentialsProviderBase::MetadataCredentialsProviderBase( stats_->metadata_refresh_state_.set(uint64_t(refresh_state_)); // If credential provider is being created during Envoy initialization, use init manager to - // delay cluster creation If we are here during normal processing, such as xDS update, then + // delay cluster creation. If we are here during normal processing, such as xDS update, then // create clusters and initialize TLS immediately if (context_->initManager().state() == Envoy::Init::Manager::State::Initialized) { initializeTlsAndCluster(); @@ -1077,9 +1077,9 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( !web_identity.role_arn().empty()) { const auto sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; - const auto region_uuid = absl::StrCat(region, "_", context->api().randomGenerator().uuid()); + // const auto region_uuid = absl::StrCat(region, "_", context->api().randomGenerator().uuid()); - const auto cluster_name = stsClusterName(region_uuid); + const auto cluster_name = stsClusterName(region); ENVOY_LOG( debug, @@ -1135,6 +1135,7 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( // extensions SINGLETON_MANAGER_REGISTRATION(container_credentials_provider); SINGLETON_MANAGER_REGISTRATION(instance_profile_credentials_provider); +SINGLETON_MANAGER_REGISTRATION(web_identity_credentials_provider); CredentialsProviderSharedPtr DefaultCredentialsProviderChain::createContainerCredentialsProvider( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index ce12a3750bdb..4c4a7c96649c 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -323,6 +323,7 @@ class ContainerCredentialsProvider : public MetadataCredentialsProviderBase, * OpenID) */ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase, + public Envoy::Singleton::Instance, public MetadataFetcher::MetadataReceiver { public: // token and token_file_path are mutually exclusive. If token is not empty, token_file_path is From 0f2bcb2f2079693162a112ef3f5ae21d696e7467 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Wed, 1 Jan 2025 09:51:22 +0000 Subject: [PATCH 14/21] update for async handling Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider.h | 11 +++ .../common/aws/credentials_provider_impl.cc | 81 ++++++++++++++++++- .../common/aws/credentials_provider_impl.h | 27 +++---- source/extensions/common/aws/signer.h | 9 ++- .../extensions/common/aws/signer_base_impl.cc | 28 +++---- .../extensions/common/aws/signer_base_impl.h | 12 ++- .../extensions/common/aws/sigv4_signer_impl.h | 3 +- .../common/aws/sigv4a_signer_impl.h | 3 +- .../http/aws_lambda/aws_lambda_filter.cc | 6 +- .../http/aws_lambda/aws_lambda_filter.h | 4 + .../filters/http/aws_lambda/config.cc | 8 +- .../aws_request_signing_filter.cc | 34 ++++++-- .../aws_request_signing_filter.h | 14 +++- .../http/aws_request_signing/config.cc | 77 +++++++++++++----- .../filters/http/aws_request_signing/config.h | 16 ++++ .../grpc_credentials/aws_iam/config.cc | 6 +- .../grpc_credentials/aws_iam/config.h | 5 +- test/extensions/common/aws/mocks.h | 8 +- .../common/aws/sigv4_signer_corpus_test.cc | 8 +- .../common/aws/sigv4_signer_impl_test.cc | 38 ++++----- .../common/aws/sigv4a_signer_corpus_test.cc | 7 +- .../common/aws/sigv4a_signer_impl_test.cc | 55 ++++++++----- 22 files changed, 317 insertions(+), 143 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider.h b/source/extensions/common/aws/credentials_provider.h index dc06c0c77988..3c4e439a9fcd 100644 --- a/source/extensions/common/aws/credentials_provider.h +++ b/source/extensions/common/aws/credentials_provider.h @@ -58,6 +58,9 @@ class Credentials { */ class CredentialsProvider { public: + + using CredentialsPendingCallback = std::function; + virtual ~CredentialsProvider() = default; /** @@ -66,6 +69,14 @@ class CredentialsProvider { * @return AWS credentials */ virtual Credentials getCredentials() PURE; + + /** + * Check if credentials are pending, which supports async credential fetching. + * + * @return bool true if credentials are pending, false otherwise + */ + virtual bool credentialsPending(ABSL_ATTRIBUTE_UNUSED CredentialsPendingCallback&& cb) { return false; } + }; using CredentialsConstSharedPtr = std::shared_ptr; diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index de9688bc4ff8..8d3153619579 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include "envoy/common/exception.h" @@ -265,6 +266,14 @@ void MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::onClusterRemo } }; + bool MetadataCredentialsProviderBase::credentialsPending(CredentialsPendingCallback&& cb) { + if(cb) + { + ENVOY_LOG_MISC(debug,"Adding credentials pending callback to queue"); + } + return credentials_pending_; + } + // Async provider uses its own refresh mechanism. Calling refreshIfNeeded() here is not thread safe. Credentials MetadataCredentialsProviderBase::getCredentials() { if (context_) { @@ -322,6 +331,9 @@ void MetadataCredentialsProviderBase::handleFetchDone() { cache_duration_timer_->enableTimer(cache_duration_); } } + + // We are now no longer waiting for credentials + credentials_pending_.exchange(false); } } @@ -484,6 +496,10 @@ void InstanceProfileCredentialsProvider::refresh() { }; continue_on_async_fetch_failure_ = true; continue_on_async_fetch_failure_reason_ = "Token fetch failed, falling back to IMDSv1"; + + // mark credentials as pending while async completes + credentials_pending_.exchange(true); + metadata_fetcher_->fetch(token_req_message, Tracing::NullSpan::instance(), *this); } } @@ -516,6 +532,10 @@ void InstanceProfileCredentialsProvider::fetchInstanceRole(const std::string&& t on_async_fetch_cb_ = [this, token_string = std::move(token_string)](const std::string&& arg) { return this->fetchCredentialFromInstanceRoleAsync(std::move(arg), std::move(token_string)); }; + + // mark credentials as pending while async completes + credentials_pending_.exchange(true); + metadata_fetcher_->fetch(message, Tracing::NullSpan::instance(), *this); } } @@ -572,6 +592,10 @@ void InstanceProfileCredentialsProvider::fetchCredentialFromInstanceRole( on_async_fetch_cb_ = [this](const std::string&& arg) { return this->extractCredentialsAsync(std::move(arg)); }; + + // mark credentials as pending while async completes + credentials_pending_.exchange(true); + metadata_fetcher_->fetch(message, Tracing::NullSpan::instance(), *this); } } @@ -719,6 +743,10 @@ void ContainerCredentialsProvider::refresh() { on_async_fetch_cb_ = [this](const std::string&& arg) { return this->extractCredentials(std::move(arg)); }; + + // mark credentials as pending while async completes + credentials_pending_.exchange(true); + metadata_fetcher_->fetch(message, Tracing::NullSpan::instance(), *this); } } @@ -869,6 +897,10 @@ void WebIdentityCredentialsProvider::refresh() { on_async_fetch_cb_ = [this](const std::string&& arg) { return this->extractCredentials(std::move(arg)); }; + + // mark credentials as pending while async completes + credentials_pending_.exchange(true); + metadata_fetcher_->fetch(message, Tracing::NullSpan::instance(), *this); } @@ -962,6 +994,18 @@ void WebIdentityCredentialsProvider::onMetadataError(Failure reason) { handleFetchDone(); } +bool CredentialsProviderChain::credentialsPending(CredentialsPendingCallback&& cb) { + for (auto& provider : providers_) { + if(provider->credentialsPending(std::move(cb))) + { + ENVOY_LOG_MISC(debug,"Credentials are pending"); + return true; + } + } + ENVOY_LOG_MISC(debug,"Credentials are not pending"); + return false; +} + Credentials CredentialsProviderChain::getCredentials() { for (auto& provider : providers_) { const auto credentials = provider->getCredentials(); @@ -1023,7 +1067,7 @@ CustomCredentialsProviderChain::CustomCredentialsProviderChain( const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; const auto initialization_timer = std::chrono::seconds(2); add(factories.createWebIdentityCredentialsProvider( - context, MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, + context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, web_identity, cluster_name)); } @@ -1086,7 +1130,7 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( "Using web identity credentials provider with STS endpoint: {} and session name: {}", sts_endpoint, web_identity.role_session_name()); add(factories.createWebIdentityCredentialsProvider( - context.value(), MetadataFetcher::create, sts_endpoint, refresh_state, + context.value(), context->singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, web_identity, cluster_name)); } } @@ -1169,8 +1213,39 @@ DefaultCredentialsProviderChain::createInstanceProfileCredentialsProvider( api, context, fetch_metadata_using_curl, create_metadata_fetcher_cb, refresh_state, initialization_timer, cluster_name); }); -} + } + CredentialsProviderSharedPtr DefaultCredentialsProviderChain::createWebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const { + return singleton_manager.getTyped(SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), + [&context, create_metadata_fetcher_cb,sts_endpoint,refresh_state, initialization_timer, web_identity_config, cluster_name]{ + return std::make_shared( + context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name); + }); + }; + CredentialsProviderSharedPtr CustomCredentialsProviderChain::createWebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const { + return singleton_manager.getTyped(SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), + [&context, create_metadata_fetcher_cb,sts_endpoint,refresh_state, initialization_timer, web_identity_config, cluster_name]{ + return std::make_shared( + context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name); + }); + }; + } // namespace Aws } // namespace Common } // namespace Extensions diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index c307a8336b5a..f06d43c8f167 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -145,7 +145,8 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { std::chrono::seconds initialization_timer); Credentials getCredentials() override; - + bool credentialsPending(CredentialsPendingCallback&& cb) override; + // Get the Metadata credentials cache duration. static std::chrono::seconds getCacheDuration(); @@ -247,6 +248,8 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { std::shared_ptr stats_; // Atomic flag for cluster recreate std::atomic is_creating_ = false; + // Are credentials pending? + std::atomic credentials_pending_ = true; }; /** @@ -324,7 +327,6 @@ class ContainerCredentialsProvider : public MetadataCredentialsProviderBase, * OpenID) */ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase, - public Envoy::Singleton::Instance, public Envoy::Singleton::Instance, public MetadataFetcher::MetadataReceiver { public: @@ -367,6 +369,7 @@ class CredentialsProviderChain : public CredentialsProvider, } Credentials getCredentials() override; + bool credentialsPending(CredentialsPendingCallback&& cb) override; protected: std::list providers_; @@ -384,7 +387,7 @@ class CredentialsProviderChainFactories { credential_file_config = {}) const PURE; virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, @@ -419,7 +422,7 @@ class CustomCredentialsProviderChainFactories { credential_file_config = {}) const PURE; virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, @@ -453,17 +456,13 @@ class CustomCredentialsProviderChain : public CredentialsProviderChain, }; CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& web_identity_config, - absl::string_view cluster_name) const override { - return std::make_shared( - context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name); - }; + absl::string_view cluster_name) const override; }; /** @@ -539,17 +538,13 @@ class DefaultCredentialsProviderChain : public CredentialsProviderChain, std::chrono::seconds initialization_timer, absl::string_view cluster_name) const override; CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& web_identity_config, - absl::string_view cluster_name) const override { - return std::make_shared( - context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name); - } + absl::string_view cluster_name) const override; }; using InstanceProfileCredentialsProviderPtr = std::shared_ptr; diff --git a/source/extensions/common/aws/signer.h b/source/extensions/common/aws/signer.h index 623ddfce27c5..61a03567582b 100644 --- a/source/extensions/common/aws/signer.h +++ b/source/extensions/common/aws/signer.h @@ -1,5 +1,6 @@ #pragma once +#include "credentials_provider.h" #include "envoy/common/pure.h" #include "envoy/http/message.h" @@ -19,7 +20,7 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status sign(Http::RequestMessage& message, bool sign_body, + virtual absl::Status sign(Http::RequestMessage& message, const Credentials credentials, bool sign_body, const absl::string_view override_region = "") PURE; /** @@ -28,7 +29,7 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, + virtual absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region = "") PURE; /** @@ -37,7 +38,7 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, + virtual absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region = "") PURE; /** @@ -47,7 +48,7 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status sign(Http::RequestHeaderMap& headers, const std::string& content_hash, + virtual absl::Status sign(Http::RequestHeaderMap& headers, const Credentials credentials, const std::string& content_hash, const absl::string_view override_region = "") PURE; }; diff --git a/source/extensions/common/aws/signer_base_impl.cc b/source/extensions/common/aws/signer_base_impl.cc index 6b370965f73c..f27c80ac3e76 100644 --- a/source/extensions/common/aws/signer_base_impl.cc +++ b/source/extensions/common/aws/signer_base_impl.cc @@ -23,25 +23,25 @@ namespace Extensions { namespace Common { namespace Aws { -absl::Status SignerBaseImpl::sign(Http::RequestMessage& message, bool sign_body, +absl::Status SignerBaseImpl::sign(Http::RequestMessage& message,const Credentials credentials, bool sign_body, const absl::string_view override_region) { const auto content_hash = createContentHash(message, sign_body); auto& headers = message.headers(); - return sign(headers, content_hash, override_region); + return sign(headers, credentials, content_hash, override_region); } -absl::Status SignerBaseImpl::signEmptyPayload(Http::RequestHeaderMap& headers, +absl::Status SignerBaseImpl::signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region) { headers.setReference(SignatureHeaders::get().ContentSha256, SignatureConstants::HashedEmptyString); - return sign(headers, std::string(SignatureConstants::HashedEmptyString), override_region); + return sign(headers, credentials, std::string(SignatureConstants::HashedEmptyString), override_region); } -absl::Status SignerBaseImpl::signUnsignedPayload(Http::RequestHeaderMap& headers, +absl::Status SignerBaseImpl::signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region) { headers.setReference(SignatureHeaders::get().ContentSha256, SignatureConstants::UnsignedPayload); - return sign(headers, std::string(SignatureConstants::UnsignedPayload), override_region); + return sign(headers, credentials, std::string(SignatureConstants::UnsignedPayload), override_region); } // Region support utilities for sigv4a @@ -54,20 +54,20 @@ void SignerBaseImpl::addRegionQueryParam( std::string SignerBaseImpl::getRegion() const { return region_; } -absl::Status SignerBaseImpl::sign(Http::RequestHeaderMap& headers, const std::string& content_hash, +absl::Status SignerBaseImpl::sign(Http::RequestHeaderMap& headers, const Credentials credentials, const std::string& content_hash, const absl::string_view override_region) { if (!query_string_ && !content_hash.empty()) { headers.setReferenceKey(SignatureHeaders::get().ContentSha256, content_hash); } - const auto& credentials = credentials_provider_->getCredentials(); - if (!credentials.accessKeyId() || !credentials.secretAccessKey()) { - // Empty or "anonymous" credentials are a valid use-case for non-production environments. - // This behavior matches what the AWS SDK would do. - ENVOY_LOG_MISC(debug, "Sign exiting early - no credentials found"); - return absl::OkStatus(); - } + // const auto& credentials = credentials_provider_->getCredentials(); + // if (!credentials.accessKeyId() || !credentials.secretAccessKey()) { + // // Empty or "anonymous" credentials are a valid use-case for non-production environments. + // // This behavior matches what the AWS SDK would do. + // ENVOY_LOG_MISC(debug, "Sign exiting early - no credentials found"); + // return absl::OkStatus(); + // } if (headers.Method() == nullptr) { return absl::Status{absl::StatusCode::kInvalidArgument, "Message is missing :method header"}; diff --git a/source/extensions/common/aws/signer_base_impl.h b/source/extensions/common/aws/signer_base_impl.h index 241b22b62a9b..1d1a8d00fbe5 100644 --- a/source/extensions/common/aws/signer_base_impl.h +++ b/source/extensions/common/aws/signer_base_impl.h @@ -63,14 +63,13 @@ using AwsSigningHeaderExclusionVector = std::vector { public: SignerBaseImpl(absl::string_view service_name, absl::string_view region, - const CredentialsProviderSharedPtr& credentials_provider, Server::Configuration::CommonFactoryContext& context, const AwsSigningHeaderExclusionVector& matcher_config, const bool query_string = false, const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration) : service_name_(service_name), region_(region), excluded_header_matchers_(defaultMatchers(context)), - credentials_provider_(credentials_provider), query_string_(query_string), + query_string_(query_string), expiration_time_(expiration_time), time_source_(context.timeSource()), long_date_formatter_(std::string(SignatureConstants::LongDateFormat)), short_date_formatter_(std::string(SignatureConstants::ShortDateFormat)) { @@ -81,13 +80,13 @@ class SignerBaseImpl : public Signer, public Logger::Loggable { } } - absl::Status sign(Http::RequestMessage& message, bool sign_body = false, + absl::Status sign(Http::RequestMessage& message, const Credentials credentials, bool sign_body = false, const absl::string_view override_region = "") override; - absl::Status sign(Http::RequestHeaderMap& headers, const std::string& content_hash, + absl::Status sign(Http::RequestHeaderMap& headers, const Credentials credentials, const std::string& content_hash, const absl::string_view override_region = "") override; - absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, + absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region = "") override; - absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, + absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region = "") override; protected: @@ -154,7 +153,6 @@ class SignerBaseImpl : public Signer, public Logger::Loggable { Http::Headers::get().ForwardedFor.get(), Http::Headers::get().ForwardedProto.get(), "x-amzn-trace-id"}; std::vector excluded_header_matchers_; - CredentialsProviderSharedPtr credentials_provider_; const bool query_string_; const uint16_t expiration_time_; TimeSource& time_source_; diff --git a/source/extensions/common/aws/sigv4_signer_impl.h b/source/extensions/common/aws/sigv4_signer_impl.h index 27d8406bf978..c3877568811d 100644 --- a/source/extensions/common/aws/sigv4_signer_impl.h +++ b/source/extensions/common/aws/sigv4_signer_impl.h @@ -45,12 +45,11 @@ class SigV4SignerImpl : public SignerBaseImpl { public: SigV4SignerImpl(absl::string_view service_name, absl::string_view region, - const CredentialsProviderSharedPtr& credentials_provider, Server::Configuration::CommonFactoryContext& context, const AwsSigningHeaderExclusionVector& matcher_config, const bool query_string = false, const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration) - : SignerBaseImpl(service_name, region, credentials_provider, context, matcher_config, + : SignerBaseImpl(service_name, region, context, matcher_config, query_string, expiration_time) {} private: diff --git a/source/extensions/common/aws/sigv4a_signer_impl.h b/source/extensions/common/aws/sigv4a_signer_impl.h index 4a33a7b0e5ef..5c20d7965b31 100644 --- a/source/extensions/common/aws/sigv4a_signer_impl.h +++ b/source/extensions/common/aws/sigv4a_signer_impl.h @@ -57,11 +57,10 @@ class SigV4ASignerImpl : public SignerBaseImpl { public: SigV4ASignerImpl( absl::string_view service_name, absl::string_view region, - const CredentialsProviderSharedPtr& credentials_provider, Server::Configuration::CommonFactoryContext& context, const AwsSigningHeaderExclusionVector& matcher_config, const bool query_string = false, const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration) - : SignerBaseImpl(service_name, region, credentials_provider, context, matcher_config, + : SignerBaseImpl(service_name, region, context, matcher_config, query_string, expiration_time) {} private: diff --git a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc index 96c0aab98907..63e84a746c9b 100644 --- a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc +++ b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc @@ -147,7 +147,7 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, if (settings.payloadPassthrough()) { setLambdaHeaders(headers, settings.arn(), settings.invocationMode(), settings.hostRewrite()); - auto status = settings.signer().signEmptyPayload(headers, settings.arn().region()); + auto status = settings.signer().signEmptyPayload(headers, settings.credentialsProvider()->getCredentials(),settings.arn().region()); if (!status.ok()) { ENVOY_LOG(debug, "signing failed: {}", status.message()); } @@ -165,7 +165,7 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, auto& hashing_util = Envoy::Common::Crypto::UtilitySingleton::get(); const auto hash = Hex::encode(hashing_util.getSha256Digest(json_buf)); - auto status = settings.signer().sign(headers, hash, settings.arn().region()); + auto status = settings.signer().sign(headers, settings.credentialsProvider()->getCredentials(), hash, settings.arn().region()); if (!status.ok()) { ENVOY_LOG(debug, "signing failed: {}", status.message()); } @@ -229,7 +229,7 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea settings.hostRewrite()); const auto hash = Hex::encode(hashing_util.getSha256Digest(decoding_buffer)); - auto status = settings.signer().sign(*request_headers_, hash, settings.arn().region()); + auto status = settings.signer().sign(*request_headers_,settings.credentialsProvider()->getCredentials(), hash, settings.arn().region()); if (!status.ok()) { ENVOY_LOG(debug, "signing failed: {}", status.message()); } diff --git a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h index 221e3b7bf69a..93dfc7e49c80 100644 --- a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h +++ b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h @@ -87,6 +87,7 @@ class FilterSettings : public Router::RouteSpecificFilterConfig { virtual InvocationMode invocationMode() const PURE; virtual const std::string& hostRewrite() const PURE; virtual Extensions::Common::Aws::Signer& signer() PURE; + virtual Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() PURE; }; class FilterSettingsImpl : public FilterSettings { @@ -101,6 +102,7 @@ class FilterSettingsImpl : public FilterSettings { InvocationMode invocationMode() const override { return invocation_mode_; } const std::string& hostRewrite() const override { return host_rewrite_; } Extensions::Common::Aws::Signer& signer() override { return *signer_; } + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() override { return credentials_provider_;} private: Arn arn_; @@ -108,6 +110,8 @@ class FilterSettingsImpl : public FilterSettings { bool payload_passthrough_; const std::string host_rewrite_; Extensions::Common::Aws::SignerPtr signer_; + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider_; + }; using FilterSettingsSharedPtr = std::shared_ptr; diff --git a/source/extensions/filters/http/aws_lambda/config.cc b/source/extensions/filters/http/aws_lambda/config.cc index 90a0db569b0b..425da61ae330 100644 --- a/source/extensions/filters/http/aws_lambda/config.cc +++ b/source/extensions/filters/http/aws_lambda/config.cc @@ -76,10 +76,8 @@ absl::StatusOr AwsLambdaFilterFactory::createFilterFactor } const std::string region = arn->region(); - auto credentials_provider = getCredentialsProvider(proto_config, server_context, region); - auto signer = std::make_unique( - service_name, region, std::move(credentials_provider), server_context, + service_name, region, server_context, // TODO: extend API to allow specifying header exclusion. ref: // https://github.com/envoyproxy/envoy/pull/18998 Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}); @@ -107,11 +105,9 @@ AwsLambdaFilterFactory::createRouteSpecificFilterConfigTyped( fmt::format("aws_lambda_filter: Invalid ARN: {}", per_route_config.invoke_config().arn())); } const std::string region = arn->region(); - auto credentials_provider = - getCredentialsProvider(per_route_config.invoke_config(), server_context, region); auto signer = std::make_unique( - service_name, region, std::move(credentials_provider), server_context, + service_name, region, server_context, // TODO: extend API to allow specifying header exclusion. ref: // https://github.com/envoyproxy/envoy/pull/18998 Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}); diff --git a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc index 21c12f1f8066..f221882efac5 100644 --- a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc +++ b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc @@ -11,16 +11,19 @@ namespace Extensions { namespace HttpFilters { namespace AwsRequestSigningFilter { -FilterConfigImpl::FilterConfigImpl(Extensions::Common::Aws::SignerPtr&& signer, +FilterConfigImpl::FilterConfigImpl(Extensions::Common::Aws::SignerPtr&& signer, +Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider, const std::string& stats_prefix, Stats::Scope& scope, const std::string& host_rewrite, bool use_unsigned_payload) - : signer_(std::move(signer)), stats_(Filter::generateStats(stats_prefix, scope)), + : signer_(std::move(signer)), credentials_provider_(credentials_provider), stats_(Filter::generateStats(stats_prefix, scope)), host_rewrite_(host_rewrite), use_unsigned_payload_{use_unsigned_payload} {} Filter::Filter(const std::shared_ptr& config) : config_(config) {} Extensions::Common::Aws::Signer& FilterConfigImpl::signer() { return *signer_; } +Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr FilterConfigImpl::credentialsProvider() { return credentials_provider_; } + FilterStats& FilterConfigImpl::stats() { return stats_; } const std::string& FilterConfigImpl::hostRewrite() const { return host_rewrite_; } @@ -31,8 +34,9 @@ FilterStats Filter::generateStats(const std::string& prefix, Stats::Scope& scope return {ALL_AWS_REQUEST_SIGNING_FILTER_STATS(POOL_COUNTER_PREFIX(scope, final_prefix))}; } -Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) { - auto& config = getConfig(); +Http::FilterHeadersStatus Filter::onCredentialNoLongerPending(FilterConfig& config, Http::RequestHeaderMap& headers, bool end_stream, Envoy::Extensions::Common::Aws::Credentials credentials) +{ + ENVOY_LOG(debug, "aws request signing onCredentialNoLongerPending, {}",credentials.accessKeyId().value()); const auto& host_rewrite = config.hostRewrite(); const bool use_unsigned_payload = config.useUnsignedPayload(); @@ -51,9 +55,9 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, ENVOY_LOG(debug, "aws request signing from decodeHeaders use_unsigned_payload: {}", use_unsigned_payload); if (use_unsigned_payload) { - status = config.signer().signUnsignedPayload(headers); + status = config.signer().signUnsignedPayload(headers, config.credentialsProvider()->getCredentials()); } else { - status = config.signer().signEmptyPayload(headers); + status = config.signer().signEmptyPayload(headers, config.credentialsProvider()->getCredentials()); } if (status.ok()) { config.stats().signing_added_.inc(); @@ -65,6 +69,22 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, return Http::FilterHeadersStatus::Continue; } +Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) { + auto& config = getConfig(); + + if(config.credentialsProvider()->credentialsPending( + [this, &dispatcher = decoder_callbacks_->dispatcher(), &end_stream, &headers, &config](Envoy::Extensions::Common::Aws::Credentials credentials) { + dispatcher.post([this, &config, &headers, end_stream, credentials]() { + this->onCredentialNoLongerPending(config, headers, end_stream, credentials); + }); + } + )) + { + return Http::FilterHeadersStatus::StopIteration; + } + return onCredentialNoLongerPending(config, headers, end_stream, config.credentialsProvider()->getCredentials()); +} + Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_stream) { auto& config = getConfig(); @@ -85,7 +105,7 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea ENVOY_LOG(debug, "aws request signing from decodeData"); ASSERT(request_headers_ != nullptr); - auto status = config.signer().sign(*request_headers_, hash); + auto status = config.signer().sign(*request_headers_, config.credentialsProvider()->getCredentials(), hash); if (status.ok()) { config.stats().signing_added_.inc(); config.stats().payload_signing_added_.inc(); diff --git a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h index 85f85f15e2fd..b67c36b240c8 100644 --- a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h +++ b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h @@ -6,6 +6,7 @@ #include "envoy/stats/stats_macros.h" #include "source/extensions/common/aws/signer.h" +#include "source/extensions/common/aws/credentials_provider.h" #include "source/extensions/filters/http/common/pass_through_filter.h" namespace Envoy { @@ -43,6 +44,11 @@ class FilterConfig : public Router::RouteSpecificFilterConfig { */ virtual Extensions::Common::Aws::Signer& signer() PURE; + /** + * @return the config's credentials provider. + */ + virtual Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() PURE; + /** * @return the filter stats. */ @@ -66,16 +72,20 @@ using FilterConfigSharedPtr = std::shared_ptr; */ class FilterConfigImpl : public FilterConfig { public: - FilterConfigImpl(Extensions::Common::Aws::SignerPtr&& signer, const std::string& stats_prefix, + FilterConfigImpl(Extensions::Common::Aws::SignerPtr&& signer, Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider, + const std::string& stats_prefix, Stats::Scope& scope, const std::string& host_rewrite, bool use_unsigned_payload); Extensions::Common::Aws::Signer& signer() override; + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() override; + FilterStats& stats() override; const std::string& hostRewrite() const override; bool useUnsignedPayload() const override; private: Extensions::Common::Aws::SignerPtr signer_; + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider_; FilterStats stats_; std::string host_rewrite_; const bool use_unsigned_payload_; @@ -96,6 +106,8 @@ class Filter : public Http::PassThroughDecoderFilter, Logger::Loggable config_; Http::RequestHeaderMap* request_headers_{}; diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index 8e1fcc0ac2cc..09b7600e8aa8 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -3,17 +3,7 @@ #include #include -#include "envoy/common/optref.h" -#include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.h" -#include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.validate.h" -#include "envoy/registry/registry.h" - -#include "source/extensions/common/aws/credentials_provider_impl.h" -#include "source/extensions/common/aws/region_provider_impl.h" -#include "source/extensions/common/aws/sigv4_signer_impl.h" -#include "source/extensions/common/aws/sigv4a_signer_impl.h" -#include "source/extensions/common/aws/utility.h" -#include "source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h" + namespace Envoy { namespace Extensions { @@ -36,12 +26,16 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( const AwsRequestSigningProtoConfig& config, const std::string& stats_prefix, DualInfo dual_info, Server::Configuration::ServerFactoryContext& server_context) { + auto credentials_provider = createCredentialsProvider(config, server_context); + if (!credentials_provider.ok()) { + return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); + } auto signer = createSigner(config, server_context); if (!signer.ok()) { return absl::InvalidArgumentError(std::string(signer.status().message())); } auto filter_config = - std::make_shared(std::move(signer.value()), stats_prefix, dual_info.scope, + std::make_shared(std::move(signer.value()), credentials_provider.value(), stats_prefix, dual_info.scope, config.host_rewrite(), config.use_unsigned_payload()); return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { auto filter = std::make_shared(filter_config); @@ -55,22 +49,26 @@ AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( Server::Configuration::ServerFactoryContext& server_context, ProtobufMessage::ValidationVisitor&) { + auto credentials_provider = createCredentialsProvider(per_route_config.aws_request_signing(), server_context); + if (!credentials_provider.ok()) { + return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); + } + auto signer = createSigner(per_route_config.aws_request_signing(), server_context); if (!signer.ok()) { return absl::InvalidArgumentError(std::string(signer.status().message())); } return std::make_shared( - std::move(signer.value()), per_route_config.stat_prefix(), server_context.scope(), + std::move(signer.value()), credentials_provider.value(), per_route_config.stat_prefix(), server_context.scope(), per_route_config.aws_request_signing().host_rewrite(), per_route_config.aws_request_signing().use_unsigned_payload()); } -absl::StatusOr -AwsRequestSigningFilterFactory::createSigner( +absl::StatusOr +AwsRequestSigningFilterFactory::createCredentialsProvider( const AwsRequestSigningProtoConfig& config, Server::Configuration::ServerFactoryContext& server_context) { - std::string region = config.region(); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -121,7 +119,7 @@ AwsRequestSigningFilterFactory::createSigner( } else if (config.credential_provider().custom_credential_provider_chain()) { // Custom credential provider chain if (has_credential_provider_settings) { - credentials_provider = + return std::make_shared( server_context, region, config.credential_provider()); } @@ -130,21 +128,56 @@ AwsRequestSigningFilterFactory::createSigner( if (has_credential_provider_settings) { credential_provider_config = config.credential_provider(); } - credentials_provider = + return std::make_shared( server_context.api(), makeOptRef(server_context), server_context.singletonManager(), region, nullptr, credential_provider_config); } } else { // No credential provider settings provided, so make the default credentials provider chain - credentials_provider = + return std::make_shared( server_context.api(), makeOptRef(server_context), server_context.singletonManager(), region, nullptr, credential_provider_config); } - if (!credentials_provider.ok()) { return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); + +} + +absl::StatusOr +AwsRequestSigningFilterFactory::createSigner( + const AwsRequestSigningProtoConfig& config, + Server::Configuration::ServerFactoryContext& server_context) { + + std::string region = config.region(); + + envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + + // If we have an overriding credential provider configuration, read it here as it may contain + // references to the region + envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; + if (config.has_credential_provider()) { + if (config.credential_provider().has_credentials_file_provider()) { + credential_file_config = config.credential_provider().credentials_file_provider(); + } + } + + if (region.empty()) { + auto region_provider = + std::make_shared(credential_file_config); + absl::optional regionOpt; + if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { + regionOpt = region_provider->getRegionSet(); + } else { + regionOpt = region_provider->getRegion(); + } + if (!regionOpt.has_value()) { + return absl::InvalidArgumentError( + "AWS region is not set in xDS configuration and failed to retrieve from " + "environment variable or AWS profile/config files."); + } + region = regionOpt.value(); } const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( @@ -160,7 +193,7 @@ AwsRequestSigningFilterFactory::createSigner( if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { return std::make_unique( - config.service_name(), region, credentials_provider.value(), server_context, matcher_config, + config.service_name(), region, server_context, matcher_config, query_string, expiration_time); } else { // Verify that we have not specified a region set when using sigv4 algorithm @@ -170,7 +203,7 @@ AwsRequestSigningFilterFactory::createSigner( "can be specified when using signing_algorithm: AWS_SIGV4A."); } return std::make_unique( - config.service_name(), region, credentials_provider.value(), server_context, matcher_config, + config.service_name(), region, server_context, matcher_config, query_string, expiration_time); } } diff --git a/source/extensions/filters/http/aws_request_signing/config.h b/source/extensions/filters/http/aws_request_signing/config.h index 48c3d40f6dd5..7d4f960c3d14 100644 --- a/source/extensions/filters/http/aws_request_signing/config.h +++ b/source/extensions/filters/http/aws_request_signing/config.h @@ -5,6 +5,16 @@ #include "source/extensions/common/aws/signer.h" #include "source/extensions/filters/http/common/factory_base.h" +#include "source/extensions/common/aws/credentials_provider_impl.h" +#include "source/extensions/common/aws/region_provider_impl.h" +#include "source/extensions/common/aws/sigv4_signer_impl.h" +#include "source/extensions/common/aws/sigv4a_signer_impl.h" +#include "source/extensions/common/aws/utility.h" +#include "source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h" +#include "envoy/common/optref.h" +#include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.h" +#include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.validate.h" +#include "envoy/registry/registry.h" namespace Envoy { namespace Extensions { @@ -40,6 +50,12 @@ class AwsRequestSigningFilterFactory absl::StatusOr createSigner(const AwsRequestSigningProtoConfig& config, Server::Configuration::ServerFactoryContext& server_context); + + absl::StatusOr +createCredentialsProvider( + const AwsRequestSigningProtoConfig& config, + Server::Configuration::ServerFactoryContext& server_context); + }; using UpstreamAwsRequestSigningFilterFactory = AwsRequestSigningFilterFactory; diff --git a/source/extensions/grpc_credentials/aws_iam/config.cc b/source/extensions/grpc_credentials/aws_iam/config.cc index 9a2ab86eeb46..9aed6655107e 100644 --- a/source/extensions/grpc_credentials/aws_iam/config.cc +++ b/source/extensions/grpc_credentials/aws_iam/config.cc @@ -71,12 +71,12 @@ std::shared_ptr AwsIamGrpcCredentialsFactory::getChann context.api(), absl::nullopt /*Empty factory context*/, context.singletonManager(), region, Common::Aws::Utility::fetchMetadataWithCurl); auto signer = std::make_unique( - config.service_name(), region, credentials_provider, context, + config.service_name(), region, context, // TODO: extend API to allow specifying header exclusion. ref: // https://github.com/envoyproxy/envoy/pull/18998 Common::Aws::AwsSigningHeaderExclusionVector{}); std::shared_ptr new_call_creds = grpc::MetadataCredentialsFromPlugin( - std::make_unique(std::move(signer))); + std::make_unique(std::move(signer), credentials_provider)); if (call_creds == nullptr) { call_creds = new_call_creds; } else { @@ -106,7 +106,7 @@ AwsIamHeaderAuthenticator::GetMetadata(grpc::string_ref service_url, grpc::strin auto message = buildMessageToSign(absl::string_view(service_url.data(), service_url.length()), absl::string_view(method_name.data(), method_name.length())); - auto status = signer_->sign(message, false); + auto status = signer_->sign(message, credentials_provider_->getCredentials(), false); if (!status.ok()) { return {grpc::StatusCode::INTERNAL, std::string{status.message()}}; } diff --git a/source/extensions/grpc_credentials/aws_iam/config.h b/source/extensions/grpc_credentials/aws_iam/config.h index 16efbcead7c6..5daf9c786257 100644 --- a/source/extensions/grpc_credentials/aws_iam/config.h +++ b/source/extensions/grpc_credentials/aws_iam/config.h @@ -34,7 +34,8 @@ class AwsIamGrpcCredentialsFactory : public Grpc::GoogleGrpcCredentialsFactory { */ class AwsIamHeaderAuthenticator : public grpc::MetadataCredentialsPlugin { public: - AwsIamHeaderAuthenticator(Common::Aws::SignerPtr signer) : signer_(std::move(signer)) {} + AwsIamHeaderAuthenticator(Common::Aws::SignerPtr signer, + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider) : signer_(std::move(signer)), credentials_provider_(credentials_provider) {} grpc::Status GetMetadata(grpc::string_ref, grpc::string_ref, const grpc::AuthContext&, std::multimap* metadata) override; @@ -49,6 +50,8 @@ class AwsIamHeaderAuthenticator : public grpc::MetadataCredentialsPlugin { std::multimap& metadata); const Common::Aws::SignerPtr signer_; + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider_; + }; } // namespace AwsIam diff --git a/test/extensions/common/aws/mocks.h b/test/extensions/common/aws/mocks.h index bb556d751074..0f699d5e89b8 100644 --- a/test/extensions/common/aws/mocks.h +++ b/test/extensions/common/aws/mocks.h @@ -44,10 +44,10 @@ class MockSigner : public Signer { MockSigner(); ~MockSigner() override; - MOCK_METHOD(absl::Status, sign, (Http::RequestMessage&, bool, absl::string_view)); - MOCK_METHOD(absl::Status, sign, (Http::RequestHeaderMap&, const std::string&, absl::string_view)); - MOCK_METHOD(absl::Status, signEmptyPayload, (Http::RequestHeaderMap&, absl::string_view)); - MOCK_METHOD(absl::Status, signUnsignedPayload, (Http::RequestHeaderMap&, absl::string_view)); + MOCK_METHOD(absl::Status, sign, (Http::RequestMessage&,const Credentials, bool, absl::string_view)); + MOCK_METHOD(absl::Status, sign, (Http::RequestHeaderMap&, const Credentials, const std::string&, absl::string_view)); + MOCK_METHOD(absl::Status, signEmptyPayload, (Http::RequestHeaderMap&, const Credentials, absl::string_view)); + MOCK_METHOD(absl::Status, signUnsignedPayload, (Http::RequestHeaderMap&, const Credentials, absl::string_view)); }; class MockFetchMetadata { diff --git a/test/extensions/common/aws/sigv4_signer_corpus_test.cc b/test/extensions/common/aws/sigv4_signer_corpus_test.cc index 42dd291275cc..885b9692f6ae 100644 --- a/test/extensions/common/aws/sigv4_signer_corpus_test.cc +++ b/test/extensions/common/aws/sigv4_signer_corpus_test.cc @@ -252,10 +252,8 @@ TEST_P(SigV4SignerCorpusTest, SigV4SignerCorpusHeaderSigning) { setDate(); addBodySigningIfRequired(); - auto* credentials_provider_ = new NiceMock(); - SigV4SignerImpl headersigner_( - service_, region_, CredentialsProviderSharedPtr{credentials_provider_}, context_, + service_, region_, context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, expiration_); auto signer_friend = SigV4SignerImplFriend(&headersigner_); @@ -308,12 +306,10 @@ TEST_P(SigV4SignerCorpusTest, SigV4SignerCorpusQueryStringSigning) { setDate(); addBodySigningIfRequired(); - auto* credentials_provider_ = new NiceMock(); - const auto calculated_canonical_headers = Utility::canonicalizeHeaders(message_.headers(), {}); SigV4SignerImpl querysigner_( - service_, region_, CredentialsProviderSharedPtr{credentials_provider_}, context_, + service_, region_, context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, expiration_); auto signer_friend = SigV4SignerImplFriend(&querysigner_); diff --git a/test/extensions/common/aws/sigv4_signer_impl_test.cc b/test/extensions/common/aws/sigv4_signer_impl_test.cc index 51a3f079b2a4..8e09af4cfa61 100644 --- a/test/extensions/common/aws/sigv4_signer_impl_test.cc +++ b/test/extensions/common/aws/sigv4_signer_impl_test.cc @@ -23,7 +23,7 @@ class SigV4SignerImplTest : public testing::Test { SigV4SignerImplTest() : credentials_provider_(new NiceMock()), message_(new Http::RequestMessageImpl()), - signer_("service", "region", CredentialsProviderSharedPtr{credentials_provider_}, context_, + signer_("service", "region", context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}), credentials_("akid", "secret"), token_credentials_("akid", "secret", "token") { // 20180102T030405Z @@ -53,12 +53,12 @@ class SigV4SignerImplTest : public testing::Test { headers.addCopy(Http::LowerCaseString("host"), "www.example.com"); SigV4SignerImpl signer(service_name, "region", - CredentialsProviderSharedPtr{credentials_provider}, context_, + context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, 5); if (use_unsigned_payload) { - status = signer.signUnsignedPayload(headers, override_region); + status = signer.signUnsignedPayload(headers, credentials_provider->getCredentials(), override_region); } else { - status = signer.signEmptyPayload(headers, override_region); + status = signer.signEmptyPayload(headers, credentials_provider->getCredentials(), override_region); } EXPECT_TRUE(status.ok()); @@ -84,10 +84,10 @@ class SigV4SignerImplTest : public testing::Test { } SigV4SignerImpl signer(service_name, "region", - CredentialsProviderSharedPtr{credentials_provider}, context_, + context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, 5); - auto status = signer.signUnsignedPayload(extra_headers, override_region); + auto status = signer.signUnsignedPayload(extra_headers, credentials_provider->getCredentials(), override_region); EXPECT_TRUE(status.ok()); auto query_parameters = Http::Utility::QueryParamsMulti::parseQueryString( extra_headers.Path()->value().getStringView()); @@ -107,7 +107,7 @@ class SigV4SignerImplTest : public testing::Test { // No authorization header should be present when the credentials are empty TEST_F(SigV4SignerImplTest, AnonymousCredentials) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(Credentials())); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); } @@ -115,7 +115,7 @@ TEST_F(SigV4SignerImplTest, AnonymousCredentials) { // HTTP :method header is required TEST_F(SigV4SignerImplTest, MissingMethod) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_EQ(status.message(), "Message is missing :method header"); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); } @@ -124,7 +124,7 @@ TEST_F(SigV4SignerImplTest, MissingMethod) { TEST_F(SigV4SignerImplTest, MissingPath) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); addMethod("GET"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_EQ(status.message(), "Message is missing :path header"); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); } @@ -139,7 +139,7 @@ TEST_F(SigV4SignerImplTest, DontDuplicateHeaders) { addHeader("x-amz-security-token", "existing_value_2"); addHeader("x-amz-date", "existing_value_3"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_EQ(message_->headers().get(Http::CustomHeaders::get().Authorization).size(), 1); ENVOY_LOG_MISC(info, "authorization {}", message_->headers() @@ -176,7 +176,7 @@ TEST_F(SigV4SignerImplTest, SignDateHeader) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); addMethod("GET"); addPath("/"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_FALSE(message_->headers().get(SigV4SignatureHeaders::get().ContentSha256).empty()); EXPECT_EQ("20180102T030400Z", @@ -198,7 +198,7 @@ TEST_F(SigV4SignerImplTest, SignSecurityTokenHeader) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(token_credentials_)); addMethod("GET"); addPath("/"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_EQ("token", message_->headers() .get(SigV4SignatureHeaders::get().SecurityToken)[0] @@ -220,7 +220,7 @@ TEST_F(SigV4SignerImplTest, SignEmptyContentHeader) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); addMethod("GET"); addPath("/"); - auto status = signer_.sign(*message_, true); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials(),true); EXPECT_TRUE(status.ok()); EXPECT_EQ(SigV4SignatureConstants::HashedEmptyString, message_->headers() @@ -244,7 +244,7 @@ TEST_F(SigV4SignerImplTest, SignContentHeader) { addMethod("POST"); addPath("/"); setBody("test1234"); - auto status = signer_.sign(*message_, true); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials(), true); EXPECT_TRUE(status.ok()); EXPECT_EQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", message_->headers() @@ -268,7 +268,7 @@ TEST_F(SigV4SignerImplTest, SignContentHeaderOverrideRegion) { addMethod("POST"); addPath("/"); setBody("test1234"); - auto status = signer_.sign(*message_, true, "region1"); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials(), true, "region1"); EXPECT_TRUE(status.ok()); EXPECT_EQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", message_->headers() @@ -295,7 +295,7 @@ TEST_F(SigV4SignerImplTest, SignExtraHeaders) { addHeader("a", "a_value"); addHeader("b", "b_value"); addHeader("c", "c_value"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_EQ("AWS4-HMAC-SHA256 Credential=akid/20180102/region/service/aws4_request, " "SignedHeaders=a;b;c;x-amz-content-sha256;x-amz-date, " @@ -314,7 +314,7 @@ TEST_F(SigV4SignerImplTest, SignHostHeader) { addMethod("GET"); addPath("/"); addHeader("host", "www.example.com"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_EQ("AWS4-HMAC-SHA256 Credential=akid/20180102/region/service/aws4_request, " "SignedHeaders=host;x-amz-content-sha256;x-amz-date, " @@ -339,10 +339,10 @@ TEST_F(SigV4SignerImplTest, QueryStringDefault5s) { headers.addCopy(Http::LowerCaseString("host"), "example.service.zz"); headers.addCopy("testheader", "value1"); SigV4SignerImpl querysigner("service", "region", - CredentialsProviderSharedPtr{credentials_provider}, context_, + context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true); - auto status = querysigner.signUnsignedPayload(headers); + auto status = querysigner.signUnsignedPayload(headers, credentials_provider_->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_TRUE(absl::StrContains(headers.getPathValue(), "X-Amz-Expires=5&")); } diff --git a/test/extensions/common/aws/sigv4a_signer_corpus_test.cc b/test/extensions/common/aws/sigv4a_signer_corpus_test.cc index f383b9a69922..84cb476ab2de 100644 --- a/test/extensions/common/aws/sigv4a_signer_corpus_test.cc +++ b/test/extensions/common/aws/sigv4a_signer_corpus_test.cc @@ -275,10 +275,8 @@ TEST_P(SigV4ASignerCorpusTest, SigV4ASignerCorpusHeaderSigning) { setDate(); addBodySigningIfRequired(); - auto* credentials_provider_ = new NiceMock(); - SigV4ASignerImpl headersigner_( - service_, region_, CredentialsProviderSharedPtr{credentials_provider_}, context_, + service_, region_, context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, expiration_); auto signer_friend = SigV4ASignerImplFriend(&headersigner_); @@ -337,10 +335,9 @@ TEST_P(SigV4ASignerCorpusTest, SigV4ASignerCorpusQueryStringSigning) { addBodySigningIfRequired(); const auto calculated_canonical_headers = Utility::canonicalizeHeaders(message_.headers(), {}); - auto* credentials_provider_ = new NiceMock(); SigV4ASignerImpl querysigner_( - service_, region_, CredentialsProviderSharedPtr{credentials_provider_}, context_, + service_, region_, context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, expiration_); auto signer_friend = SigV4ASignerImplFriend(&querysigner_); diff --git a/test/extensions/common/aws/sigv4a_signer_impl_test.cc b/test/extensions/common/aws/sigv4a_signer_impl_test.cc index ab012fb247ee..c98974cbf36c 100644 --- a/test/extensions/common/aws/sigv4a_signer_impl_test.cc +++ b/test/extensions/common/aws/sigv4a_signer_impl_test.cc @@ -56,7 +56,6 @@ class SigV4ASignerImplTest : public testing::Test { } return SigV4ASignerImpl{"service", "region", - getTestCredentialsProvider(), context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, query_string, @@ -79,16 +78,17 @@ class SigV4ASignerImplTest : public testing::Test { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); // Sign the message using our signing algorithm auto signer_ = getTestSigner(query_string, expiration_time); + auto credentials_provider = getTestCredentialsProvider(); switch (signing_type) { case EmptyPayload: - status = signer_.signEmptyPayload(message->headers(), override_region); + status = signer_.signEmptyPayload(message->headers(), credentials_provider->getCredentials(), override_region); break; case NormalSign: - status = signer_.sign(*message, sign_body, override_region); + status = signer_.sign(*message, credentials_provider->getCredentials(), sign_body, override_region); break; case UnsignedPayload: - status = signer_.signUnsignedPayload(message->headers(), override_region); + status = signer_.signUnsignedPayload(message->headers(), credentials_provider->getCredentials(), override_region); break; } EXPECT_TRUE(status.ok()); @@ -150,7 +150,9 @@ TEST_F(SigV4ASignerImplTest, AnonymousCredentials) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(Credentials())); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); } @@ -159,7 +161,8 @@ TEST_F(SigV4ASignerImplTest, AnonymousCredentials) { TEST_F(SigV4ASignerImplTest, MissingMethod) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_EQ(status.message(), "Message is missing :method header"); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); } @@ -169,7 +172,8 @@ TEST_F(SigV4ASignerImplTest, MissingPath) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); addMethod("GET"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_EQ(status.message(), "Message is missing :path header"); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); } @@ -180,12 +184,13 @@ TEST_F(SigV4ASignerImplTest, DontDuplicateHeaders) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); + auto credentials_provider = getTestCredentialsProvider(); addHeader("authorization", "existing_value"); addHeader("x-amz-security-token", "existing_value_2"); addHeader("x-amz-date", "existing_value_3"); addHeader("x-amz-region-set", "existing_value_4"); - auto status = signer_.sign(*message_); + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_EQ(message_->headers().get(Http::CustomHeaders::get().Authorization).size(), 1); EXPECT_FALSE(absl::StrContains( @@ -218,7 +223,9 @@ TEST_F(SigV4ASignerImplTest, QueryStringDoesntModifyAuthorization) { addPath("/"); addHeader("Authorization", "testValue"); auto signer_ = getTestSigner(true); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_EQ(message_->headers().get(Http::CustomHeaders::get().Authorization)[0]->value(), "testValue"); @@ -230,7 +237,9 @@ TEST_F(SigV4ASignerImplTest, SignDateHeader) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_FALSE(message_->headers().get(SigV4ASignatureHeaders::get().ContentSha256).empty()); EXPECT_EQ( @@ -249,7 +258,8 @@ TEST_F(SigV4ASignerImplTest, SignSecurityTokenHeader) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_EQ("token", message_->headers() .get(SigV4ASignatureHeaders::get().SecurityToken)[0] @@ -269,7 +279,8 @@ TEST_F(SigV4ASignerImplTest, SignEmptyContentHeader) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_, true); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials(), true); EXPECT_TRUE(status.ok()); EXPECT_EQ(SigV4ASignatureConstants::HashedEmptyString, message_->headers() @@ -290,7 +301,8 @@ TEST_F(SigV4ASignerImplTest, SignContentHeader) { addPath("/"); setBody("test1234"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_, true); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials(),true); EXPECT_TRUE(status.ok()); EXPECT_EQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", message_->headers() @@ -311,7 +323,8 @@ TEST_F(SigV4ASignerImplTest, SignContentHeaderOverrideRegion) { addPath("/"); setBody("test1234"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_, true, "region1"); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials(), true, "region1"); EXPECT_TRUE(status.ok()); EXPECT_EQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", message_->headers() @@ -334,7 +347,9 @@ TEST_F(SigV4ASignerImplTest, SignExtraHeaders) { addHeader("b", "b_value"); addHeader("c", "c_value"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_THAT( message_->headers().get(Http::CustomHeaders::get().Authorization)[0]->value().getStringView(), @@ -350,7 +365,9 @@ TEST_F(SigV4ASignerImplTest, SignHostHeader) { addPath("/"); addHeader("host", "www.example.com"); auto signer_ = getTestSigner(false); - auto status = signer_.sign(*message_); + auto credentials_provider = getTestCredentialsProvider(); + + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_THAT( message_->headers().get(Http::CustomHeaders::get().Authorization)[0]->value().getStringView(), @@ -531,10 +548,12 @@ TEST_F(SigV4ASignerImplTest, QueryStringDefault5s) { headers.setPath("/example/path"); headers.addCopy(Http::LowerCaseString("host"), "example.service.zz"); headers.addCopy("testheader", "value1"); - SigV4ASignerImpl querysigner("service", "region", getTestCredentialsProvider(), context_, + auto credentials_provider = getTestCredentialsProvider(); + + SigV4ASignerImpl querysigner("service", "region", context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true); - auto status = querysigner.signUnsignedPayload(headers); + auto status = querysigner.signUnsignedPayload(headers, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_TRUE(absl::StrContains(headers.getPathValue(), "X-Amz-Expires=5&")); } From 167124a8a9d127693bd258437e52e20d3c568eee Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Wed, 1 Jan 2025 12:03:17 +0000 Subject: [PATCH 15/21] first async cut Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider_impl.cc | 22 +++++++++++--- .../common/aws/credentials_provider_impl.h | 2 ++ .../aws_request_signing_filter.cc | 9 ++++-- .../aws/credentials_provider_impl_test.cc | 30 +++++++++---------- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 8d3153619579..14bace9c3acf 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -270,9 +270,10 @@ void MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::onClusterRemo if(cb) { ENVOY_LOG_MISC(debug,"Adding credentials pending callback to queue"); + credential_pending_callbacks_.push_back(std::move(cb)); } return credentials_pending_; - } + } // Async provider uses its own refresh mechanism. Calling refreshIfNeeded() here is not thread safe. Credentials MetadataCredentialsProviderBase::getCredentials() { @@ -331,20 +332,33 @@ void MetadataCredentialsProviderBase::handleFetchDone() { cache_duration_timer_->enableTimer(cache_duration_); } } - - // We are now no longer waiting for credentials - credentials_pending_.exchange(false); } } void MetadataCredentialsProviderBase::setCredentialsToAllThreads( CredentialsConstUniquePtr&& creds) { + + ENVOY_LOG_MISC(debug, "Setting credentials to all threads"); + + // Call all of our callbacks to unblock pending requests + for(const auto& cb: credential_pending_callbacks_) + { + cb(Credentials(creds->accessKeyId().has_value()?creds->accessKeyId().value():"", + creds->secretAccessKey().has_value()?creds->accessKeyId().value():"", + creds->sessionToken().has_value()?creds->accessKeyId().value():"")); + } + credential_pending_callbacks_.clear(); + CredentialsConstSharedPtr shared_credentials = std::move(creds); if (tls_slot_) { tls_slot_->runOnAllThreads([shared_credentials](OptRef obj) { obj->credentials_ = shared_credentials; }); } + + // We are now no longer waiting for credentials + credentials_pending_.exchange(false); + } CredentialsFileCredentialsProvider::CredentialsFileCredentialsProvider( diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index f06d43c8f167..fcdeded4f826 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -250,6 +250,8 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { std::atomic is_creating_ = false; // Are credentials pending? std::atomic credentials_pending_ = true; + // Callbacks list for pending credentials + std::vector credential_pending_callbacks_ = {}; }; /** diff --git a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc index f221882efac5..ab5a3ad2a35f 100644 --- a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc +++ b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc @@ -71,6 +71,7 @@ Http::FilterHeadersStatus Filter::onCredentialNoLongerPending(FilterConfig& conf Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) { auto& config = getConfig(); + ENVOY_LOG_MISC(debug, "******* HERE"); if(config.credentialsProvider()->credentialsPending( [this, &dispatcher = decoder_callbacks_->dispatcher(), &end_stream, &headers, &config](Envoy::Extensions::Common::Aws::Credentials credentials) { @@ -78,9 +79,13 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, this->onCredentialNoLongerPending(config, headers, end_stream, credentials); }); } - )) + )) { - return Http::FilterHeadersStatus::StopIteration; + ENVOY_LOG_MISC(debug, "Credentials are pending"); + return Http::FilterHeadersStatus::StopAllIterationAndBuffer; + } else + { + ENVOY_LOG_MISC(debug, "Credentials are not pending"); } return onCredentialNoLongerPending(config, headers, end_stream, config.credentialsProvider()->getCredentials()); } diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index 1f722d38a11a..1c6426cef58f 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -2521,7 +2521,7 @@ class MockCredentialsProviderChainFactories : public CredentialsProviderChainFac MOCK_METHOD( CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, - (Server::Configuration::ServerFactoryContext&, CreateMetadataFetcherCb, absl::string_view, + (Server::Configuration::ServerFactoryContext&, Singleton::Manager&, CreateMetadataFetcherCb, absl::string_view, MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider&, absl::string_view), @@ -2559,7 +2559,7 @@ class MockCustomCredentialsProviderChainFactories : public CustomCredentialsProv MOCK_METHOD( CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, - (Server::Configuration::ServerFactoryContext&, CreateMetadataFetcherCb, absl::string_view, + (Server::Configuration::ServerFactoryContext&, Singleton::Manager&,CreateMetadataFetcherCb, absl::string_view, MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider&, absl::string_view), @@ -2686,7 +2686,7 @@ TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentitySessionName) { time_system_.setSystemTime(std::chrono::milliseconds(1234567890)); EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)); + Ref(context_), _,_, "sts.region.amazonaws.com:443", _, _, _, _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -2704,7 +2704,7 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithSessionName) { EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)); + Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -2720,7 +2720,7 @@ TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentityWithBlankConfig) { EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) + Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)) .Times(0); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -2743,8 +2743,8 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomSessionName) { std::string role_session_name; EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) - .WillOnce(Invoke(WithArg<5>( + Ref(context_), _,_, "sts.region.amazonaws.com:443", _, _, _, _)) + .WillOnce(Invoke(WithArg<6>( [&role_session_name]( const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& provider) -> CredentialsProviderSharedPtr { @@ -2773,8 +2773,8 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomRoleArn) { std::string role_arn; EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) - .WillOnce(Invoke(WithArg<5>( + Ref(context_), _,_, "sts.region.amazonaws.com:443", _, _, _, _)) + .WillOnce(Invoke(WithArg<6>( [&role_arn]( const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& provider) -> CredentialsProviderSharedPtr { @@ -2803,8 +2803,8 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomDataSource) { std::string inline_string; EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)) - .WillOnce(Invoke(WithArg<5>( + Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)) + .WillOnce(Invoke(WithArg<6>( [&inline_string]( const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& provider) -> CredentialsProviderSharedPtr { @@ -2843,7 +2843,7 @@ TEST_F(DefaultCredentialsProviderChainTest, CredentialsFileWithCustomDataSource) createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _, "sts.region.amazonaws.com:443", _, _, _, _)); + Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; credential_provider_config.mutable_credentials_file_provider() @@ -2920,7 +2920,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateFileCredentialProviderOnly) { EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)); EXPECT_CALL(factories, - createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _)) + createWebIdentityCredentialsProvider(Ref(server_context),_, _, _, _, _, _, _)) .Times(0); auto chain = std::make_shared( @@ -2942,7 +2942,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateWebIdentityCredentialProviderOn EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)) .Times(0); EXPECT_CALL(factories, - createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _)); + createWebIdentityCredentialsProvider(Ref(server_context), _,_, _, _, _, _, _)); auto chain = std::make_shared( server_context, region, cred_provider, factories); @@ -2965,7 +2965,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateFileAndWebProviders) { EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)); EXPECT_CALL(factories, - createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _)); + createWebIdentityCredentialsProvider(Ref(server_context), _,_, _, _, _, _, _)); auto chain = std::make_shared( server_context, region, cred_provider, factories); From 353ca783c57cc40d5b228f823064b5bbf3b67884 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Wed, 1 Jan 2025 23:55:57 +0000 Subject: [PATCH 16/21] add async handling Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider.h | 12 +- .../common/aws/credentials_provider_impl.cc | 110 +++++++++--------- .../common/aws/credentials_provider_impl.h | 3 +- source/extensions/common/aws/signer.h | 16 ++- .../extensions/common/aws/signer_base_impl.cc | 32 ++--- .../extensions/common/aws/signer_base_impl.h | 14 +-- .../extensions/common/aws/sigv4_signer_impl.h | 4 +- .../common/aws/sigv4a_signer_impl.h | 4 +- .../http/aws_lambda/aws_lambda_filter.cc | 10 +- .../http/aws_lambda/aws_lambda_filter.h | 5 +- .../filters/http/aws_request_signing/BUILD | 1 + .../aws_request_signing_filter.cc | 97 ++++++++++----- .../aws_request_signing_filter.h | 24 ++-- .../http/aws_request_signing/config.cc | 45 ++++--- .../filters/http/aws_request_signing/config.h | 16 +-- .../grpc_credentials/aws_iam/config.h | 7 +- .../aws/credentials_provider_impl_test.cc | 28 ++--- test/extensions/common/aws/mocks.h | 12 +- .../common/aws/sigv4_signer_corpus_test.cc | 12 +- .../common/aws/sigv4_signer_impl_test.cc | 20 ++-- .../common/aws/sigv4a_signer_corpus_test.cc | 12 +- .../common/aws/sigv4a_signer_impl_test.cc | 49 ++++---- 22 files changed, 299 insertions(+), 234 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider.h b/source/extensions/common/aws/credentials_provider.h index 3c4e439a9fcd..262f6ed30042 100644 --- a/source/extensions/common/aws/credentials_provider.h +++ b/source/extensions/common/aws/credentials_provider.h @@ -58,7 +58,6 @@ class Credentials { */ class CredentialsProvider { public: - using CredentialsPendingCallback = std::function; virtual ~CredentialsProvider() = default; @@ -71,12 +70,13 @@ class CredentialsProvider { virtual Credentials getCredentials() PURE; /** - * Check if credentials are pending, which supports async credential fetching. - * - * @return bool true if credentials are pending, false otherwise + * Check if credentials are pending, which supports async credential fetching. + * + * @return bool true if credentials are pending, false otherwise */ - virtual bool credentialsPending(ABSL_ATTRIBUTE_UNUSED CredentialsPendingCallback&& cb) { return false; } - + virtual bool credentialsPending(ABSL_ATTRIBUTE_UNUSED CredentialsPendingCallback&& cb) { + return false; + } }; using CredentialsConstSharedPtr = std::shared_ptr; diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 14bace9c3acf..00c0ddc09633 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -266,14 +266,13 @@ void MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::onClusterRemo } }; - bool MetadataCredentialsProviderBase::credentialsPending(CredentialsPendingCallback&& cb) { - if(cb) - { - ENVOY_LOG_MISC(debug,"Adding credentials pending callback to queue"); - credential_pending_callbacks_.push_back(std::move(cb)); - } - return credentials_pending_; +bool MetadataCredentialsProviderBase::credentialsPending(CredentialsPendingCallback&& cb) { + if (cb) { + ENVOY_LOG_MISC(debug, "Adding credentials pending callback to queue"); + credential_pending_callbacks_.push_back(std::move(cb)); } + return credentials_pending_; +} // Async provider uses its own refresh mechanism. Calling refreshIfNeeded() here is not thread safe. Credentials MetadataCredentialsProviderBase::getCredentials() { @@ -341,11 +340,15 @@ void MetadataCredentialsProviderBase::setCredentialsToAllThreads( ENVOY_LOG_MISC(debug, "Setting credentials to all threads"); // Call all of our callbacks to unblock pending requests - for(const auto& cb: credential_pending_callbacks_) - { - cb(Credentials(creds->accessKeyId().has_value()?creds->accessKeyId().value():"", - creds->secretAccessKey().has_value()?creds->accessKeyId().value():"", - creds->sessionToken().has_value()?creds->accessKeyId().value():"")); + for (const auto& cb : credential_pending_callbacks_) { + ENVOY_LOG_MISC(debug, "*** Calling pending callback {} {} {}", + creds->accessKeyId().has_value() ? creds->accessKeyId().value() : "", + creds->secretAccessKey().has_value() ? creds->secretAccessKey().value() : "", + creds->sessionToken().has_value() ? creds->sessionToken().value() : ""); + + cb(Credentials(creds->accessKeyId().has_value() ? creds->accessKeyId().value() : "", + creds->secretAccessKey().has_value() ? creds->secretAccessKey().value() : "", + creds->sessionToken().has_value() ? creds->sessionToken().value() : "")); } credential_pending_callbacks_.clear(); @@ -358,7 +361,6 @@ void MetadataCredentialsProviderBase::setCredentialsToAllThreads( // We are now no longer waiting for credentials credentials_pending_.exchange(false); - } CredentialsFileCredentialsProvider::CredentialsFileCredentialsProvider( @@ -1010,13 +1012,12 @@ void WebIdentityCredentialsProvider::onMetadataError(Failure reason) { bool CredentialsProviderChain::credentialsPending(CredentialsPendingCallback&& cb) { for (auto& provider : providers_) { - if(provider->credentialsPending(std::move(cb))) - { - ENVOY_LOG_MISC(debug,"Credentials are pending"); + if (provider->credentialsPending(std::move(cb))) { + ENVOY_LOG_MISC(debug, "Credentials are pending"); return true; } } - ENVOY_LOG_MISC(debug,"Credentials are not pending"); + ENVOY_LOG_MISC(debug, "Credentials are not pending"); return false; } @@ -1081,8 +1082,8 @@ CustomCredentialsProviderChain::CustomCredentialsProviderChain( const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; const auto initialization_timer = std::chrono::seconds(2); add(factories.createWebIdentityCredentialsProvider( - context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, - web_identity, cluster_name)); + context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, + initialization_timer, web_identity, cluster_name)); } if (credential_provider_config.has_credentials_file_provider()) { @@ -1135,7 +1136,8 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( !web_identity.role_arn().empty()) { const auto sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; - // const auto region_uuid = absl::StrCat(region, "_", context->api().randomGenerator().uuid()); + // const auto region_uuid = absl::StrCat(region, "_", + // context->api().randomGenerator().uuid()); const auto cluster_name = stsClusterName(region); @@ -1144,8 +1146,8 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( "Using web identity credentials provider with STS endpoint: {} and session name: {}", sts_endpoint, web_identity.role_session_name()); add(factories.createWebIdentityCredentialsProvider( - context.value(), context->singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, - initialization_timer, web_identity, cluster_name)); + context.value(), context->singletonManager(), MetadataFetcher::create, sts_endpoint, + refresh_state, initialization_timer, web_identity, cluster_name)); } } @@ -1227,39 +1229,43 @@ DefaultCredentialsProviderChain::createInstanceProfileCredentialsProvider( api, context, fetch_metadata_using_curl, create_metadata_fetcher_cb, refresh_state, initialization_timer, cluster_name); }); - } - CredentialsProviderSharedPtr DefaultCredentialsProviderChain::createWebIdentityCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, - MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer, - const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const { - return singleton_manager.getTyped(SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), - [&context, create_metadata_fetcher_cb,sts_endpoint,refresh_state, initialization_timer, web_identity_config, cluster_name]{ - return std::make_shared( - context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name); +} +CredentialsProviderSharedPtr DefaultCredentialsProviderChain::createWebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const { + return singleton_manager.getTyped( + SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), + [&context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name] { + return std::make_shared( + context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name); }); - }; +}; - CredentialsProviderSharedPtr CustomCredentialsProviderChain::createWebIdentityCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, - MetadataFetcher::MetadataReceiver::RefreshState refresh_state, - std::chrono::seconds initialization_timer, - const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const { - return singleton_manager.getTyped(SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), - [&context, create_metadata_fetcher_cb,sts_endpoint,refresh_state, initialization_timer, web_identity_config, cluster_name]{ - return std::make_shared( - context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name); +CredentialsProviderSharedPtr CustomCredentialsProviderChain::createWebIdentityCredentialsProvider( + Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + MetadataFetcher::MetadataReceiver::RefreshState refresh_state, + std::chrono::seconds initialization_timer, + const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + web_identity_config, + absl::string_view cluster_name) const { + return singleton_manager.getTyped( + SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), + [&context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name] { + return std::make_shared( + context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, + web_identity_config, cluster_name); }); - }; - +}; + } // namespace Aws } // namespace Common } // namespace Extensions diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index fcdeded4f826..eb2df62b2805 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -12,7 +12,6 @@ #include "envoy/extensions/common/aws/v3/credential_provider.pb.h" #include "envoy/http/message.h" #include "envoy/init/manager.h" -#include "envoy/init/manager.h" #include "envoy/server/factory_context.h" #include "source/common/common/lock_guard.h" @@ -146,7 +145,7 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { Credentials getCredentials() override; bool credentialsPending(CredentialsPendingCallback&& cb) override; - + // Get the Metadata credentials cache duration. static std::chrono::seconds getCacheDuration(); diff --git a/source/extensions/common/aws/signer.h b/source/extensions/common/aws/signer.h index 61a03567582b..17d50c7acd91 100644 --- a/source/extensions/common/aws/signer.h +++ b/source/extensions/common/aws/signer.h @@ -1,9 +1,10 @@ #pragma once -#include "credentials_provider.h" #include "envoy/common/pure.h" #include "envoy/http/message.h" +#include "credentials_provider.h" + namespace Envoy { namespace Extensions { namespace Common { @@ -20,8 +21,8 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status sign(Http::RequestMessage& message, const Credentials credentials, bool sign_body, - const absl::string_view override_region = "") PURE; + virtual absl::Status sign(Http::RequestMessage& message, const Credentials credentials, + bool sign_body, const absl::string_view override_region = "") PURE; /** * Sign an AWS request without a payload (empty string used as content hash). @@ -29,7 +30,8 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, + virtual absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, + const Credentials credentials, const absl::string_view override_region = "") PURE; /** @@ -38,7 +40,8 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, + virtual absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, + const Credentials credentials, const absl::string_view override_region = "") PURE; /** @@ -48,7 +51,8 @@ class Signer { * @param override_region override the default region that has to be used to sign the request * @throws EnvoyException if the request cannot be signed. */ - virtual absl::Status sign(Http::RequestHeaderMap& headers, const Credentials credentials, const std::string& content_hash, + virtual absl::Status sign(Http::RequestHeaderMap& headers, const Credentials credentials, + const std::string& content_hash, const absl::string_view override_region = "") PURE; }; diff --git a/source/extensions/common/aws/signer_base_impl.cc b/source/extensions/common/aws/signer_base_impl.cc index f27c80ac3e76..95ab0043ff7d 100644 --- a/source/extensions/common/aws/signer_base_impl.cc +++ b/source/extensions/common/aws/signer_base_impl.cc @@ -23,25 +23,29 @@ namespace Extensions { namespace Common { namespace Aws { -absl::Status SignerBaseImpl::sign(Http::RequestMessage& message,const Credentials credentials, bool sign_body, - const absl::string_view override_region) { +absl::Status SignerBaseImpl::sign(Http::RequestMessage& message, const Credentials credentials, + bool sign_body, const absl::string_view override_region) { const auto content_hash = createContentHash(message, sign_body); auto& headers = message.headers(); return sign(headers, credentials, content_hash, override_region); } -absl::Status SignerBaseImpl::signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, +absl::Status SignerBaseImpl::signEmptyPayload(Http::RequestHeaderMap& headers, + const Credentials credentials, const absl::string_view override_region) { headers.setReference(SignatureHeaders::get().ContentSha256, SignatureConstants::HashedEmptyString); - return sign(headers, credentials, std::string(SignatureConstants::HashedEmptyString), override_region); + return sign(headers, credentials, std::string(SignatureConstants::HashedEmptyString), + override_region); } -absl::Status SignerBaseImpl::signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, +absl::Status SignerBaseImpl::signUnsignedPayload(Http::RequestHeaderMap& headers, + const Credentials credentials, const absl::string_view override_region) { headers.setReference(SignatureHeaders::get().ContentSha256, SignatureConstants::UnsignedPayload); - return sign(headers, credentials, std::string(SignatureConstants::UnsignedPayload), override_region); + return sign(headers, credentials, std::string(SignatureConstants::UnsignedPayload), + override_region); } // Region support utilities for sigv4a @@ -54,20 +58,20 @@ void SignerBaseImpl::addRegionQueryParam( std::string SignerBaseImpl::getRegion() const { return region_; } -absl::Status SignerBaseImpl::sign(Http::RequestHeaderMap& headers, const Credentials credentials, const std::string& content_hash, +absl::Status SignerBaseImpl::sign(Http::RequestHeaderMap& headers, const Credentials credentials, + const std::string& content_hash, const absl::string_view override_region) { if (!query_string_ && !content_hash.empty()) { headers.setReferenceKey(SignatureHeaders::get().ContentSha256, content_hash); } - // const auto& credentials = credentials_provider_->getCredentials(); - // if (!credentials.accessKeyId() || !credentials.secretAccessKey()) { - // // Empty or "anonymous" credentials are a valid use-case for non-production environments. - // // This behavior matches what the AWS SDK would do. - // ENVOY_LOG_MISC(debug, "Sign exiting early - no credentials found"); - // return absl::OkStatus(); - // } + if (!credentials.accessKeyId() || !credentials.secretAccessKey()) { + // Empty or "anonymous" credentials are a valid use-case for non-production environments. + // This behavior matches what the AWS SDK would do. + ENVOY_LOG_MISC(debug, "Sign exiting early - no credentials found"); + return absl::OkStatus(); + } if (headers.Method() == nullptr) { return absl::Status{absl::StatusCode::kInvalidArgument, "Message is missing :method header"}; diff --git a/source/extensions/common/aws/signer_base_impl.h b/source/extensions/common/aws/signer_base_impl.h index 1d1a8d00fbe5..3b33b748d687 100644 --- a/source/extensions/common/aws/signer_base_impl.h +++ b/source/extensions/common/aws/signer_base_impl.h @@ -68,8 +68,7 @@ class SignerBaseImpl : public Signer, public Logger::Loggable { const bool query_string = false, const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration) : service_name_(service_name), region_(region), - excluded_header_matchers_(defaultMatchers(context)), - query_string_(query_string), + excluded_header_matchers_(defaultMatchers(context)), query_string_(query_string), expiration_time_(expiration_time), time_source_(context.timeSource()), long_date_formatter_(std::string(SignatureConstants::LongDateFormat)), short_date_formatter_(std::string(SignatureConstants::ShortDateFormat)) { @@ -80,13 +79,14 @@ class SignerBaseImpl : public Signer, public Logger::Loggable { } } - absl::Status sign(Http::RequestMessage& message, const Credentials credentials, bool sign_body = false, + absl::Status sign(Http::RequestMessage& message, const Credentials credentials, + bool sign_body = false, const absl::string_view override_region = "") override; + absl::Status sign(Http::RequestHeaderMap& headers, const Credentials credentials, + const std::string& content_hash, const absl::string_view override_region = "") override; - absl::Status sign(Http::RequestHeaderMap& headers, const Credentials credentials, const std::string& content_hash, - const absl::string_view override_region = "") override; - absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, + absl::Status signEmptyPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region = "") override; - absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, + absl::Status signUnsignedPayload(Http::RequestHeaderMap& headers, const Credentials credentials, const absl::string_view override_region = "") override; protected: diff --git a/source/extensions/common/aws/sigv4_signer_impl.h b/source/extensions/common/aws/sigv4_signer_impl.h index c3877568811d..f9b62987fd16 100644 --- a/source/extensions/common/aws/sigv4_signer_impl.h +++ b/source/extensions/common/aws/sigv4_signer_impl.h @@ -49,8 +49,8 @@ class SigV4SignerImpl : public SignerBaseImpl { const AwsSigningHeaderExclusionVector& matcher_config, const bool query_string = false, const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration) - : SignerBaseImpl(service_name, region, context, matcher_config, - query_string, expiration_time) {} + : SignerBaseImpl(service_name, region, context, matcher_config, query_string, + expiration_time) {} private: std::string createCredentialScope(const absl::string_view short_date, diff --git a/source/extensions/common/aws/sigv4a_signer_impl.h b/source/extensions/common/aws/sigv4a_signer_impl.h index 5c20d7965b31..dfd3c38bf57a 100644 --- a/source/extensions/common/aws/sigv4a_signer_impl.h +++ b/source/extensions/common/aws/sigv4a_signer_impl.h @@ -60,8 +60,8 @@ class SigV4ASignerImpl : public SignerBaseImpl { Server::Configuration::CommonFactoryContext& context, const AwsSigningHeaderExclusionVector& matcher_config, const bool query_string = false, const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration) - : SignerBaseImpl(service_name, region, context, matcher_config, - query_string, expiration_time) {} + : SignerBaseImpl(service_name, region, context, matcher_config, query_string, + expiration_time) {} private: void addRegionHeader(Http::RequestHeaderMap& headers, diff --git a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc index 63e84a746c9b..203073f164e8 100644 --- a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc +++ b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc @@ -147,7 +147,8 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, if (settings.payloadPassthrough()) { setLambdaHeaders(headers, settings.arn(), settings.invocationMode(), settings.hostRewrite()); - auto status = settings.signer().signEmptyPayload(headers, settings.credentialsProvider()->getCredentials(),settings.arn().region()); + auto status = settings.signer().signEmptyPayload( + headers, settings.credentialsProvider()->getCredentials(), settings.arn().region()); if (!status.ok()) { ENVOY_LOG(debug, "signing failed: {}", status.message()); } @@ -165,7 +166,8 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, auto& hashing_util = Envoy::Common::Crypto::UtilitySingleton::get(); const auto hash = Hex::encode(hashing_util.getSha256Digest(json_buf)); - auto status = settings.signer().sign(headers, settings.credentialsProvider()->getCredentials(), hash, settings.arn().region()); + auto status = settings.signer().sign(headers, settings.credentialsProvider()->getCredentials(), + hash, settings.arn().region()); if (!status.ok()) { ENVOY_LOG(debug, "signing failed: {}", status.message()); } @@ -229,7 +231,9 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea settings.hostRewrite()); const auto hash = Hex::encode(hashing_util.getSha256Digest(decoding_buffer)); - auto status = settings.signer().sign(*request_headers_,settings.credentialsProvider()->getCredentials(), hash, settings.arn().region()); + auto status = + settings.signer().sign(*request_headers_, settings.credentialsProvider()->getCredentials(), + hash, settings.arn().region()); if (!status.ok()) { ENVOY_LOG(debug, "signing failed: {}", status.message()); } diff --git a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h index 93dfc7e49c80..9a9e26ff3169 100644 --- a/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h +++ b/source/extensions/filters/http/aws_lambda/aws_lambda_filter.h @@ -102,7 +102,9 @@ class FilterSettingsImpl : public FilterSettings { InvocationMode invocationMode() const override { return invocation_mode_; } const std::string& hostRewrite() const override { return host_rewrite_; } Extensions::Common::Aws::Signer& signer() override { return *signer_; } - Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() override { return credentials_provider_;} + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() override { + return credentials_provider_; + } private: Arn arn_; @@ -111,7 +113,6 @@ class FilterSettingsImpl : public FilterSettings { const std::string host_rewrite_; Extensions::Common::Aws::SignerPtr signer_; Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider_; - }; using FilterSettingsSharedPtr = std::shared_ptr; diff --git a/source/extensions/filters/http/aws_request_signing/BUILD b/source/extensions/filters/http/aws_request_signing/BUILD index d1cd1e34a1c8..1d2f55d89699 100644 --- a/source/extensions/filters/http/aws_request_signing/BUILD +++ b/source/extensions/filters/http/aws_request_signing/BUILD @@ -18,6 +18,7 @@ envoy_cc_library( hdrs = ["aws_request_signing_filter.h"], deps = [ "//envoy/http:filter_interface", + "//source/common/common:cancel_wrapper_lib", "//source/extensions/common/aws:credentials_provider_impl_lib", "//source/extensions/common/aws:region_provider_impl_lib", "//source/extensions/common/aws:sigv4_signer_impl_lib", diff --git a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc index ab5a3ad2a35f..5a24f85501e2 100644 --- a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc +++ b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc @@ -1,5 +1,7 @@ #include "source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h" +#include + #include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.h" #include "source/common/common/hex.h" @@ -11,18 +13,23 @@ namespace Extensions { namespace HttpFilters { namespace AwsRequestSigningFilter { -FilterConfigImpl::FilterConfigImpl(Extensions::Common::Aws::SignerPtr&& signer, -Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider, - const std::string& stats_prefix, Stats::Scope& scope, - const std::string& host_rewrite, bool use_unsigned_payload) - : signer_(std::move(signer)), credentials_provider_(credentials_provider), stats_(Filter::generateStats(stats_prefix, scope)), +FilterConfigImpl::FilterConfigImpl( + Extensions::Common::Aws::SignerPtr&& signer, + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider, + const std::string& stats_prefix, Stats::Scope& scope, const std::string& host_rewrite, + bool use_unsigned_payload) + : signer_(std::move(signer)), credentials_provider_(credentials_provider), + stats_(Filter::generateStats(stats_prefix, scope)), host_rewrite_(host_rewrite), use_unsigned_payload_{use_unsigned_payload} {} Filter::Filter(const std::shared_ptr& config) : config_(config) {} Extensions::Common::Aws::Signer& FilterConfigImpl::signer() { return *signer_; } -Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr FilterConfigImpl::credentialsProvider() { return credentials_provider_; } +Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr +FilterConfigImpl::credentialsProvider() { + return credentials_provider_; +} FilterStats& FilterConfigImpl::stats() { return stats_; } @@ -34,9 +41,11 @@ FilterStats Filter::generateStats(const std::string& prefix, Stats::Scope& scope return {ALL_AWS_REQUEST_SIGNING_FILTER_STATS(POOL_COUNTER_PREFIX(scope, final_prefix))}; } -Http::FilterHeadersStatus Filter::onCredentialNoLongerPending(FilterConfig& config, Http::RequestHeaderMap& headers, bool end_stream, Envoy::Extensions::Common::Aws::Credentials credentials) -{ - ENVOY_LOG(debug, "aws request signing onCredentialNoLongerPending, {}",credentials.accessKeyId().value()); +Http::FilterHeadersStatus +Filter::decodeHeadersCredentialsAvailable(Envoy::Extensions::Common::Aws::Credentials credentials) { + ENVOY_LOG(debug, "aws request signing decodeHeadersCredentialsAvailable, {}", + credentials.accessKeyId().value()); + auto& config = getConfig(); const auto& host_rewrite = config.hostRewrite(); const bool use_unsigned_payload = config.useUnsignedPayload(); @@ -44,20 +53,15 @@ Http::FilterHeadersStatus Filter::onCredentialNoLongerPending(FilterConfig& conf absl::Status status; if (!host_rewrite.empty()) { - headers.setHost(host_rewrite); - } - - if (!use_unsigned_payload && !end_stream) { - request_headers_ = &headers; - return Http::FilterHeadersStatus::StopIteration; + request_headers_->setHost(host_rewrite); } ENVOY_LOG(debug, "aws request signing from decodeHeaders use_unsigned_payload: {}", use_unsigned_payload); if (use_unsigned_payload) { - status = config.signer().signUnsignedPayload(headers, config.credentialsProvider()->getCredentials()); + status = config.signer().signUnsignedPayload(*request_headers_, credentials); } else { - status = config.signer().signEmptyPayload(headers, config.credentialsProvider()->getCredentials()); + status = config.signer().signEmptyPayload(*request_headers_, credentials); } if (status.ok()) { config.stats().signing_added_.inc(); @@ -71,23 +75,33 @@ Http::FilterHeadersStatus Filter::onCredentialNoLongerPending(FilterConfig& conf Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) { auto& config = getConfig(); - ENVOY_LOG_MISC(debug, "******* HERE"); - if(config.credentialsProvider()->credentialsPending( - [this, &dispatcher = decoder_callbacks_->dispatcher(), &end_stream, &headers, &config](Envoy::Extensions::Common::Aws::Credentials credentials) { - dispatcher.post([this, &config, &headers, end_stream, credentials]() { - this->onCredentialNoLongerPending(config, headers, end_stream, credentials); - }); + if (!config.useUnsignedPayload() && !end_stream) { + return Http::FilterHeadersStatus::StopIteration; } - )) - { + + request_headers_ = &headers; + + // If we are pending credentials, send the decodeHeadersCredentialsAvailable callback for when + // they become available, and stop iteration. + if (config.credentialsProvider()->credentialsPending( + + Envoy::CancelWrapper::cancelWrapped( + [this, &dispatcher = decoder_callbacks_->dispatcher()]( + Envoy::Extensions::Common::Aws::Credentials credentials) { + dispatcher.post([this, credentials]() { + this->decodeHeadersCredentialsAvailable(credentials); + }); + }, + &cancel_callback_) + + )) { ENVOY_LOG_MISC(debug, "Credentials are pending"); return Http::FilterHeadersStatus::StopAllIterationAndBuffer; - } else - { + } else { ENVOY_LOG_MISC(debug, "Credentials are not pending"); } - return onCredentialNoLongerPending(config, headers, end_stream, config.credentialsProvider()->getCredentials()); + return decodeHeadersCredentialsAvailable(config.credentialsProvider()->getCredentials()); } Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_stream) { @@ -103,14 +117,39 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea decoder_callbacks_->addDecodedData(data, false); + // If we are pending credentials, send the decodeDataCredentialsAvailable callback for when they + // become available, and stop iteration. + if (config.credentialsProvider()->credentialsPending(Envoy::CancelWrapper::cancelWrapped( + + [this, &dispatcher = decoder_callbacks_->dispatcher()]( + Envoy::Extensions::Common::Aws::Credentials credentials) { + dispatcher.post( + [this, credentials]() { this->decodeDataCredentialsAvailable(credentials); }); + }, + &cancel_callback_))) { + ENVOY_LOG_MISC(debug, "Credentials are pending"); + return Http::FilterDataStatus::StopIterationAndBuffer; + } else { + ENVOY_LOG_MISC(debug, "Credentials are not pending"); + } + return decodeDataCredentialsAvailable(config.credentialsProvider()->getCredentials()); +} + +Http::FilterDataStatus +Filter::decodeDataCredentialsAvailable(Envoy::Extensions::Common::Aws::Credentials credentials) { + + ENVOY_LOG(debug, "aws request signing decodeHeadersCredentialsAvailable, {}", + credentials.accessKeyId().value()); + const Buffer::Instance& decoding_buffer = *decoder_callbacks_->decodingBuffer(); + auto& config = getConfig(); auto& hashing_util = Envoy::Common::Crypto::UtilitySingleton::get(); const std::string hash = Hex::encode(hashing_util.getSha256Digest(decoding_buffer)); ENVOY_LOG(debug, "aws request signing from decodeData"); ASSERT(request_headers_ != nullptr); - auto status = config.signer().sign(*request_headers_, config.credentialsProvider()->getCredentials(), hash); + auto status = config.signer().sign(*request_headers_, credentials, hash); if (status.ok()) { config.stats().signing_added_.inc(); config.stats().payload_signing_added_.inc(); diff --git a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h index b67c36b240c8..9eb3bd148fee 100644 --- a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h +++ b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h @@ -5,8 +5,9 @@ #include "envoy/stats/scope.h" #include "envoy/stats/stats_macros.h" -#include "source/extensions/common/aws/signer.h" +#include "source/common/common/cancel_wrapper.h" #include "source/extensions/common/aws/credentials_provider.h" +#include "source/extensions/common/aws/signer.h" #include "source/extensions/filters/http/common/pass_through_filter.h" namespace Envoy { @@ -72,9 +73,11 @@ using FilterConfigSharedPtr = std::shared_ptr; */ class FilterConfigImpl : public FilterConfig { public: - FilterConfigImpl(Extensions::Common::Aws::SignerPtr&& signer, Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider, - const std::string& stats_prefix, - Stats::Scope& scope, const std::string& host_rewrite, bool use_unsigned_payload); + FilterConfigImpl( + Extensions::Common::Aws::SignerPtr&& signer, + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider, + const std::string& stats_prefix, Stats::Scope& scope, const std::string& host_rewrite, + bool use_unsigned_payload); Extensions::Common::Aws::Signer& signer() override; Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentialsProvider() override; @@ -97,7 +100,11 @@ class FilterConfigImpl : public FilterConfig { class Filter : public Http::PassThroughDecoderFilter, Logger::Loggable { public: Filter(const std::shared_ptr& config); - + ~Filter() override { + if (cancel_callback_) { + cancel_callback_(); + } + } static FilterStats generateStats(const std::string& prefix, Stats::Scope& scope); Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, @@ -106,8 +113,11 @@ class Filter : public Http::PassThroughDecoderFilter, Logger::Loggable config_; Http::RequestHeaderMap* request_headers_{}; diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index 09b7600e8aa8..58c18f201017 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -3,8 +3,6 @@ #include #include - - namespace Envoy { namespace Extensions { namespace HttpFilters { @@ -34,9 +32,9 @@ AwsRequestSigningFilterFactory::createFilterFactoryFromProtoTyped( if (!signer.ok()) { return absl::InvalidArgumentError(std::string(signer.status().message())); } - auto filter_config = - std::make_shared(std::move(signer.value()), credentials_provider.value(), stats_prefix, dual_info.scope, - config.host_rewrite(), config.use_unsigned_payload()); + auto filter_config = std::make_shared( + std::move(signer.value()), credentials_provider.value(), stats_prefix, dual_info.scope, + config.host_rewrite(), config.use_unsigned_payload()); return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { auto filter = std::make_shared(filter_config); callbacks.addStreamDecoderFilter(filter); @@ -49,7 +47,8 @@ AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( Server::Configuration::ServerFactoryContext& server_context, ProtobufMessage::ValidationVisitor&) { - auto credentials_provider = createCredentialsProvider(per_route_config.aws_request_signing(), server_context); + auto credentials_provider = + createCredentialsProvider(per_route_config.aws_request_signing(), server_context); if (!credentials_provider.ok()) { return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); } @@ -60,8 +59,8 @@ AwsRequestSigningFilterFactory::createRouteSpecificFilterConfigTyped( } return std::make_shared( - std::move(signer.value()), credentials_provider.value(), per_route_config.stat_prefix(), server_context.scope(), - per_route_config.aws_request_signing().host_rewrite(), + std::move(signer.value()), credentials_provider.value(), per_route_config.stat_prefix(), + server_context.scope(), per_route_config.aws_request_signing().host_rewrite(), per_route_config.aws_request_signing().use_unsigned_payload()); } @@ -119,30 +118,26 @@ AwsRequestSigningFilterFactory::createCredentialsProvider( } else if (config.credential_provider().custom_credential_provider_chain()) { // Custom credential provider chain if (has_credential_provider_settings) { - return - std::make_shared( - server_context, region, config.credential_provider()); + return std::make_shared( + server_context, region, config.credential_provider()); } } else { // Override default credential provider chain settings with any provided settings if (has_credential_provider_settings) { credential_provider_config = config.credential_provider(); } - return - std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); + return std::make_shared( + server_context.api(), makeOptRef(server_context), server_context.singletonManager(), + region, nullptr, credential_provider_config); } } else { // No credential provider settings provided, so make the default credentials provider chain - return - std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); + return std::make_shared( + server_context.api(), makeOptRef(server_context), server_context.singletonManager(), region, + nullptr, credential_provider_config); } - return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); - + return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); } absl::StatusOr @@ -193,8 +188,8 @@ AwsRequestSigningFilterFactory::createSigner( if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { return std::make_unique( - config.service_name(), region, server_context, matcher_config, - query_string, expiration_time); + config.service_name(), region, server_context, matcher_config, query_string, + expiration_time); } else { // Verify that we have not specified a region set when using sigv4 algorithm if (isARegionSet(region)) { @@ -203,8 +198,8 @@ AwsRequestSigningFilterFactory::createSigner( "can be specified when using signing_algorithm: AWS_SIGV4A."); } return std::make_unique( - config.service_name(), region, server_context, matcher_config, - query_string, expiration_time); + config.service_name(), region, server_context, matcher_config, query_string, + expiration_time); } } diff --git a/source/extensions/filters/http/aws_request_signing/config.h b/source/extensions/filters/http/aws_request_signing/config.h index 7d4f960c3d14..c4977830962e 100644 --- a/source/extensions/filters/http/aws_request_signing/config.h +++ b/source/extensions/filters/http/aws_request_signing/config.h @@ -1,20 +1,18 @@ #pragma once +#include "envoy/common/optref.h" #include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.h" #include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.validate.h" +#include "envoy/registry/registry.h" -#include "source/extensions/common/aws/signer.h" -#include "source/extensions/filters/http/common/factory_base.h" #include "source/extensions/common/aws/credentials_provider_impl.h" #include "source/extensions/common/aws/region_provider_impl.h" +#include "source/extensions/common/aws/signer.h" #include "source/extensions/common/aws/sigv4_signer_impl.h" #include "source/extensions/common/aws/sigv4a_signer_impl.h" #include "source/extensions/common/aws/utility.h" #include "source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.h" -#include "envoy/common/optref.h" -#include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.h" -#include "envoy/extensions/filters/http/aws_request_signing/v3/aws_request_signing.pb.validate.h" -#include "envoy/registry/registry.h" +#include "source/extensions/filters/http/common/factory_base.h" namespace Envoy { namespace Extensions { @@ -52,10 +50,8 @@ class AwsRequestSigningFilterFactory Server::Configuration::ServerFactoryContext& server_context); absl::StatusOr -createCredentialsProvider( - const AwsRequestSigningProtoConfig& config, - Server::Configuration::ServerFactoryContext& server_context); - + createCredentialsProvider(const AwsRequestSigningProtoConfig& config, + Server::Configuration::ServerFactoryContext& server_context); }; using UpstreamAwsRequestSigningFilterFactory = AwsRequestSigningFilterFactory; diff --git a/source/extensions/grpc_credentials/aws_iam/config.h b/source/extensions/grpc_credentials/aws_iam/config.h index 5daf9c786257..f4cafe101713 100644 --- a/source/extensions/grpc_credentials/aws_iam/config.h +++ b/source/extensions/grpc_credentials/aws_iam/config.h @@ -34,8 +34,10 @@ class AwsIamGrpcCredentialsFactory : public Grpc::GoogleGrpcCredentialsFactory { */ class AwsIamHeaderAuthenticator : public grpc::MetadataCredentialsPlugin { public: - AwsIamHeaderAuthenticator(Common::Aws::SignerPtr signer, - Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider) : signer_(std::move(signer)), credentials_provider_(credentials_provider) {} + AwsIamHeaderAuthenticator( + Common::Aws::SignerPtr signer, + Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider) + : signer_(std::move(signer)), credentials_provider_(credentials_provider) {} grpc::Status GetMetadata(grpc::string_ref, grpc::string_ref, const grpc::AuthContext&, std::multimap* metadata) override; @@ -51,7 +53,6 @@ class AwsIamHeaderAuthenticator : public grpc::MetadataCredentialsPlugin { const Common::Aws::SignerPtr signer_; Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr credentials_provider_; - }; } // namespace AwsIam diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index 1c6426cef58f..df974a898e27 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -2521,8 +2521,8 @@ class MockCredentialsProviderChainFactories : public CredentialsProviderChainFac MOCK_METHOD( CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, - (Server::Configuration::ServerFactoryContext&, Singleton::Manager&, CreateMetadataFetcherCb, absl::string_view, - MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, + (Server::Configuration::ServerFactoryContext&, Singleton::Manager&, CreateMetadataFetcherCb, + absl::string_view, MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider&, absl::string_view), (const)); @@ -2559,8 +2559,8 @@ class MockCustomCredentialsProviderChainFactories : public CustomCredentialsProv MOCK_METHOD( CredentialsProviderSharedPtr, createWebIdentityCredentialsProvider, - (Server::Configuration::ServerFactoryContext&, Singleton::Manager&,CreateMetadataFetcherCb, absl::string_view, - MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, + (Server::Configuration::ServerFactoryContext&, Singleton::Manager&, CreateMetadataFetcherCb, + absl::string_view, MetadataFetcher::MetadataReceiver::RefreshState, std::chrono::seconds, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider&, absl::string_view), (const)); @@ -2686,7 +2686,7 @@ TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentitySessionName) { time_system_.setSystemTime(std::chrono::milliseconds(1234567890)); EXPECT_CALL(factories_, mockCreateCredentialsFileCredentialsProvider(Ref(context_), _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _,_, "sts.region.amazonaws.com:443", _, _, _, _)); + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)); EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -2704,7 +2704,7 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithSessionName) { EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)); + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -2720,7 +2720,7 @@ TEST_F(DefaultCredentialsProviderChainTest, NoWebIdentityWithBlankConfig) { EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)) + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)) .Times(0); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; @@ -2743,7 +2743,7 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomSessionName) { std::string role_session_name; EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _,_, "sts.region.amazonaws.com:443", _, _, _, _)) + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)) .WillOnce(Invoke(WithArg<6>( [&role_session_name]( const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& @@ -2773,7 +2773,7 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomRoleArn) { std::string role_arn; EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_), _,_, "sts.region.amazonaws.com:443", _, _, _, _)) + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)) .WillOnce(Invoke(WithArg<6>( [&role_arn]( const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& @@ -2803,7 +2803,7 @@ TEST_F(DefaultCredentialsProviderChainTest, WebIdentityWithCustomDataSource) { std::string inline_string; EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)) + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)) .WillOnce(Invoke(WithArg<6>( [&inline_string]( const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& @@ -2843,7 +2843,7 @@ TEST_F(DefaultCredentialsProviderChainTest, CredentialsFileWithCustomDataSource) createInstanceProfileCredentialsProvider(Ref(*api_), _, _, _, _, _, _, _)); EXPECT_CALL(factories_, createWebIdentityCredentialsProvider( - Ref(context_),_, _, "sts.region.amazonaws.com:443", _, _, _, _)); + Ref(context_), _, _, "sts.region.amazonaws.com:443", _, _, _, _)); envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; credential_provider_config.mutable_credentials_file_provider() @@ -2920,7 +2920,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateFileCredentialProviderOnly) { EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)); EXPECT_CALL(factories, - createWebIdentityCredentialsProvider(Ref(server_context),_, _, _, _, _, _, _)) + createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _, _)) .Times(0); auto chain = std::make_shared( @@ -2942,7 +2942,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateWebIdentityCredentialProviderOn EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)) .Times(0); EXPECT_CALL(factories, - createWebIdentityCredentialsProvider(Ref(server_context), _,_, _, _, _, _, _)); + createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _, _)); auto chain = std::make_shared( server_context, region, cred_provider, factories); @@ -2965,7 +2965,7 @@ TEST_F(CustomCredentialsProviderChainTest, CreateFileAndWebProviders) { EXPECT_CALL(factories, mockCreateCredentialsFileCredentialsProvider(Ref(server_context), _)); EXPECT_CALL(factories, - createWebIdentityCredentialsProvider(Ref(server_context), _,_, _, _, _, _, _)); + createWebIdentityCredentialsProvider(Ref(server_context), _, _, _, _, _, _, _)); auto chain = std::make_shared( server_context, region, cred_provider, factories); diff --git a/test/extensions/common/aws/mocks.h b/test/extensions/common/aws/mocks.h index 0f699d5e89b8..aa4618367444 100644 --- a/test/extensions/common/aws/mocks.h +++ b/test/extensions/common/aws/mocks.h @@ -44,10 +44,14 @@ class MockSigner : public Signer { MockSigner(); ~MockSigner() override; - MOCK_METHOD(absl::Status, sign, (Http::RequestMessage&,const Credentials, bool, absl::string_view)); - MOCK_METHOD(absl::Status, sign, (Http::RequestHeaderMap&, const Credentials, const std::string&, absl::string_view)); - MOCK_METHOD(absl::Status, signEmptyPayload, (Http::RequestHeaderMap&, const Credentials, absl::string_view)); - MOCK_METHOD(absl::Status, signUnsignedPayload, (Http::RequestHeaderMap&, const Credentials, absl::string_view)); + MOCK_METHOD(absl::Status, sign, + (Http::RequestMessage&, const Credentials, bool, absl::string_view)); + MOCK_METHOD(absl::Status, sign, + (Http::RequestHeaderMap&, const Credentials, const std::string&, absl::string_view)); + MOCK_METHOD(absl::Status, signEmptyPayload, + (Http::RequestHeaderMap&, const Credentials, absl::string_view)); + MOCK_METHOD(absl::Status, signUnsignedPayload, + (Http::RequestHeaderMap&, const Credentials, absl::string_view)); }; class MockFetchMetadata { diff --git a/test/extensions/common/aws/sigv4_signer_corpus_test.cc b/test/extensions/common/aws/sigv4_signer_corpus_test.cc index 885b9692f6ae..66d20c35f133 100644 --- a/test/extensions/common/aws/sigv4_signer_corpus_test.cc +++ b/test/extensions/common/aws/sigv4_signer_corpus_test.cc @@ -252,9 +252,9 @@ TEST_P(SigV4SignerCorpusTest, SigV4SignerCorpusHeaderSigning) { setDate(); addBodySigningIfRequired(); - SigV4SignerImpl headersigner_( - service_, region_, context_, - Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, expiration_); + SigV4SignerImpl headersigner_(service_, region_, context_, + Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, + expiration_); auto signer_friend = SigV4SignerImplFriend(&headersigner_); @@ -308,9 +308,9 @@ TEST_P(SigV4SignerCorpusTest, SigV4SignerCorpusQueryStringSigning) { const auto calculated_canonical_headers = Utility::canonicalizeHeaders(message_.headers(), {}); - SigV4SignerImpl querysigner_( - service_, region_, context_, - Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, expiration_); + SigV4SignerImpl querysigner_(service_, region_, context_, + Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, + expiration_); auto signer_friend = SigV4SignerImplFriend(&querysigner_); diff --git a/test/extensions/common/aws/sigv4_signer_impl_test.cc b/test/extensions/common/aws/sigv4_signer_impl_test.cc index 8e09af4cfa61..6fe46790ae0b 100644 --- a/test/extensions/common/aws/sigv4_signer_impl_test.cc +++ b/test/extensions/common/aws/sigv4_signer_impl_test.cc @@ -52,13 +52,14 @@ class SigV4SignerImplTest : public testing::Test { headers.setPath("/"); headers.addCopy(Http::LowerCaseString("host"), "www.example.com"); - SigV4SignerImpl signer(service_name, "region", - context_, + SigV4SignerImpl signer(service_name, "region", context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, 5); if (use_unsigned_payload) { - status = signer.signUnsignedPayload(headers, credentials_provider->getCredentials(), override_region); + status = signer.signUnsignedPayload(headers, credentials_provider->getCredentials(), + override_region); } else { - status = signer.signEmptyPayload(headers, credentials_provider->getCredentials(), override_region); + status = + signer.signEmptyPayload(headers, credentials_provider->getCredentials(), override_region); } EXPECT_TRUE(status.ok()); @@ -83,11 +84,11 @@ class SigV4SignerImplTest : public testing::Test { EXPECT_CALL(*credentials_provider, getCredentials()).WillOnce(Return(credentials_)); } - SigV4SignerImpl signer(service_name, "region", - context_, + SigV4SignerImpl signer(service_name, "region", context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, 5); - auto status = signer.signUnsignedPayload(extra_headers, credentials_provider->getCredentials(), override_region); + auto status = signer.signUnsignedPayload(extra_headers, credentials_provider->getCredentials(), + override_region); EXPECT_TRUE(status.ok()); auto query_parameters = Http::Utility::QueryParamsMulti::parseQueryString( extra_headers.Path()->value().getStringView()); @@ -220,7 +221,7 @@ TEST_F(SigV4SignerImplTest, SignEmptyContentHeader) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); addMethod("GET"); addPath("/"); - auto status = signer_.sign(*message_, credentials_provider_->getCredentials(),true); + auto status = signer_.sign(*message_, credentials_provider_->getCredentials(), true); EXPECT_TRUE(status.ok()); EXPECT_EQ(SigV4SignatureConstants::HashedEmptyString, message_->headers() @@ -338,8 +339,7 @@ TEST_F(SigV4SignerImplTest, QueryStringDefault5s) { headers.setPath("/example/path"); headers.addCopy(Http::LowerCaseString("host"), "example.service.zz"); headers.addCopy("testheader", "value1"); - SigV4SignerImpl querysigner("service", "region", - context_, + SigV4SignerImpl querysigner("service", "region", context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true); auto status = querysigner.signUnsignedPayload(headers, credentials_provider_->getCredentials()); diff --git a/test/extensions/common/aws/sigv4a_signer_corpus_test.cc b/test/extensions/common/aws/sigv4a_signer_corpus_test.cc index 84cb476ab2de..2d4ee5d95057 100644 --- a/test/extensions/common/aws/sigv4a_signer_corpus_test.cc +++ b/test/extensions/common/aws/sigv4a_signer_corpus_test.cc @@ -275,9 +275,9 @@ TEST_P(SigV4ASignerCorpusTest, SigV4ASignerCorpusHeaderSigning) { setDate(); addBodySigningIfRequired(); - SigV4ASignerImpl headersigner_( - service_, region_, context_, - Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, expiration_); + SigV4ASignerImpl headersigner_(service_, region_, context_, + Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, + expiration_); auto signer_friend = SigV4ASignerImplFriend(&headersigner_); @@ -336,9 +336,9 @@ TEST_P(SigV4ASignerCorpusTest, SigV4ASignerCorpusQueryStringSigning) { const auto calculated_canonical_headers = Utility::canonicalizeHeaders(message_.headers(), {}); - SigV4ASignerImpl querysigner_( - service_, region_, context_, - Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, expiration_); + SigV4ASignerImpl querysigner_(service_, region_, context_, + Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, + expiration_); auto signer_friend = SigV4ASignerImplFriend(&querysigner_); diff --git a/test/extensions/common/aws/sigv4a_signer_impl_test.cc b/test/extensions/common/aws/sigv4a_signer_impl_test.cc index c98974cbf36c..01023b887fbd 100644 --- a/test/extensions/common/aws/sigv4a_signer_impl_test.cc +++ b/test/extensions/common/aws/sigv4a_signer_impl_test.cc @@ -54,12 +54,10 @@ class SigV4ASignerImplTest : public testing::Test { // Default expiration time is 5 seconds expiration_time = 5; } - return SigV4ASignerImpl{"service", - "region", - context_, - Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, - query_string, - expiration_time}; + return SigV4ASignerImpl{ + "service", "region", + context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, + query_string, expiration_time}; } void ecdsaVerifyCanonicalRequest(std::string canonical_request, SigningType signing_type, @@ -82,13 +80,16 @@ class SigV4ASignerImplTest : public testing::Test { switch (signing_type) { case EmptyPayload: - status = signer_.signEmptyPayload(message->headers(), credentials_provider->getCredentials(), override_region); + status = signer_.signEmptyPayload(message->headers(), credentials_provider->getCredentials(), + override_region); break; case NormalSign: - status = signer_.sign(*message, credentials_provider->getCredentials(), sign_body, override_region); + status = signer_.sign(*message, credentials_provider->getCredentials(), sign_body, + override_region); break; case UnsignedPayload: - status = signer_.signUnsignedPayload(message->headers(), credentials_provider->getCredentials(), override_region); + status = signer_.signUnsignedPayload(message->headers(), + credentials_provider->getCredentials(), override_region); break; } EXPECT_TRUE(status.ok()); @@ -161,7 +162,7 @@ TEST_F(SigV4ASignerImplTest, AnonymousCredentials) { TEST_F(SigV4ASignerImplTest, MissingMethod) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_EQ(status.message(), "Message is missing :method header"); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); @@ -172,7 +173,7 @@ TEST_F(SigV4ASignerImplTest, MissingPath) { EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); addMethod("GET"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_EQ(status.message(), "Message is missing :path header"); EXPECT_TRUE(message_->headers().get(Http::CustomHeaders::get().Authorization).empty()); @@ -184,7 +185,7 @@ TEST_F(SigV4ASignerImplTest, DontDuplicateHeaders) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); addHeader("authorization", "existing_value"); addHeader("x-amz-security-token", "existing_value_2"); addHeader("x-amz-date", "existing_value_3"); @@ -223,7 +224,7 @@ TEST_F(SigV4ASignerImplTest, QueryStringDoesntModifyAuthorization) { addPath("/"); addHeader("Authorization", "testValue"); auto signer_ = getTestSigner(true); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); @@ -237,7 +238,7 @@ TEST_F(SigV4ASignerImplTest, SignDateHeader) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); @@ -258,8 +259,8 @@ TEST_F(SigV4ASignerImplTest, SignSecurityTokenHeader) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); - auto status = signer_.sign(*message_, credentials_provider->getCredentials()); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); EXPECT_EQ("token", message_->headers() .get(SigV4ASignatureHeaders::get().SecurityToken)[0] @@ -279,7 +280,7 @@ TEST_F(SigV4ASignerImplTest, SignEmptyContentHeader) { addMethod("GET"); addPath("/"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials(), true); EXPECT_TRUE(status.ok()); EXPECT_EQ(SigV4ASignatureConstants::HashedEmptyString, @@ -301,8 +302,8 @@ TEST_F(SigV4ASignerImplTest, SignContentHeader) { addPath("/"); setBody("test1234"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); - auto status = signer_.sign(*message_, credentials_provider->getCredentials(),true); + auto credentials_provider = getTestCredentialsProvider(); + auto status = signer_.sign(*message_, credentials_provider->getCredentials(), true); EXPECT_TRUE(status.ok()); EXPECT_EQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", message_->headers() @@ -323,7 +324,7 @@ TEST_F(SigV4ASignerImplTest, SignContentHeaderOverrideRegion) { addPath("/"); setBody("test1234"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials(), true, "region1"); EXPECT_TRUE(status.ok()); EXPECT_EQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", @@ -347,7 +348,7 @@ TEST_F(SigV4ASignerImplTest, SignExtraHeaders) { addHeader("b", "b_value"); addHeader("c", "c_value"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); @@ -365,7 +366,7 @@ TEST_F(SigV4ASignerImplTest, SignHostHeader) { addPath("/"); addHeader("host", "www.example.com"); auto signer_ = getTestSigner(false); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); auto status = signer_.sign(*message_, credentials_provider->getCredentials()); EXPECT_TRUE(status.ok()); @@ -548,9 +549,9 @@ TEST_F(SigV4ASignerImplTest, QueryStringDefault5s) { headers.setPath("/example/path"); headers.addCopy(Http::LowerCaseString("host"), "example.service.zz"); headers.addCopy("testheader", "value1"); - auto credentials_provider = getTestCredentialsProvider(); + auto credentials_provider = getTestCredentialsProvider(); - SigV4ASignerImpl querysigner("service", "region", context_, + SigV4ASignerImpl querysigner("service", "region", context_, Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true); auto status = querysigner.signUnsignedPayload(headers, credentials_provider->getCredentials()); From a1a8006bb10d4c9649801946a0d205123d1e3d42 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 2 Jan 2025 00:01:40 +0000 Subject: [PATCH 17/21] address feedback Signed-off-by: Nigel Brittain --- source/extensions/common/aws/region_provider_impl.cc | 2 -- source/extensions/common/aws/utility.cc | 1 - 2 files changed, 3 deletions(-) diff --git a/source/extensions/common/aws/region_provider_impl.cc b/source/extensions/common/aws/region_provider_impl.cc index 40622802608b..55f4b1ae2244 100644 --- a/source/extensions/common/aws/region_provider_impl.cc +++ b/source/extensions/common/aws/region_provider_impl.cc @@ -79,8 +79,6 @@ absl::optional AWSCredentialsFileRegionProvider::getRegionSet() { absl::flat_hash_map elements = {{SIGV4A_SIGNING_REGION_SET, ""}}; absl::flat_hash_map::iterator it; - // Search for the region in the credentials file - std::string file_path, profile; file_path = credential_file_path_.has_value() ? credential_file_path_.value() : Utility::getCredentialFilePath(); diff --git a/source/extensions/common/aws/utility.cc b/source/extensions/common/aws/utility.cc index 792074daf9d5..b4642ad9cdfb 100644 --- a/source/extensions/common/aws/utility.cc +++ b/source/extensions/common/aws/utility.cc @@ -454,7 +454,6 @@ std::string Utility::getEnvironmentVariableOrDefault(const std::string& variable bool Utility::resolveProfileElementsFromString( const std::string& string_data, const std::string& profile_name, absl::flat_hash_map& elements) { - // std::istringstream a(string_data); std::unique_ptr stream; stream = std::make_unique(std::istringstream{string_data}); From 2b31631cd0c38cdf4208811d68deab71c6ea043f Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 2 Jan 2025 03:50:36 +0000 Subject: [PATCH 18/21] changes Signed-off-by: Nigel Brittain --- .../common/aws/credentials_provider_impl.cc | 81 +--------- .../http/aws_request_signing/config.cc | 138 +++++++++--------- 2 files changed, 70 insertions(+), 149 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 131ad564411c..edcbbf43d0b2 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -388,56 +388,6 @@ CredentialsFileCredentialsProvider::CredentialsFileCredentialsProvider( } } -CredentialsFileCredentialsProvider::CredentialsFileCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, - const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& - credential_file_config) - : context_(context), profile_("") { - - if (credential_file_config.has_credentials_data_source()) { - auto provider_or_error_ = Config::DataSource::DataSourceProvider::create( - credential_file_config.credentials_data_source(), context.mainThreadDispatcher(), - context.threadLocal(), context.api(), false, 4096); - if (provider_or_error_.ok()) { - credential_file_data_source_provider_ = std::move(provider_or_error_.value()); - if (credential_file_config.credentials_data_source().has_watched_directory()) { - has_watched_directory_ = true; - } - } else { - ENVOY_LOG_MISC(info, "Invalid credential file data source"); - credential_file_data_source_provider_.reset(); - } - } - if (!credential_file_config.profile().empty()) { - profile_ = credential_file_config.profile(); - } -} - -CredentialsFileCredentialsProvider::CredentialsFileCredentialsProvider( - Server::Configuration::ServerFactoryContext& context, - const envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider& - credential_file_config) - : context_(context), profile_("") { - - if (credential_file_config.has_credentials_data_source()) { - auto provider_or_error_ = Config::DataSource::DataSourceProvider::create( - credential_file_config.credentials_data_source(), context.mainThreadDispatcher(), - context.threadLocal(), context.api(), false, 4096); - if (provider_or_error_.ok()) { - credential_file_data_source_provider_ = std::move(provider_or_error_.value()); - if (credential_file_config.credentials_data_source().has_watched_directory()) { - has_watched_directory_ = true; - } - } else { - ENVOY_LOG_MISC(info, "Invalid credential file data source"); - credential_file_data_source_provider_.reset(); - } - } - if (!credential_file_config.profile().empty()) { - profile_ = credential_file_config.profile(); - } -} - bool CredentialsFileCredentialsProvider::needsRefresh() { return has_watched_directory_ ? true @@ -476,35 +426,6 @@ void CredentialsFileCredentialsProvider::refresh() { extractCredentials(credential_file_data.data(), profile); } -void CredentialsFileCredentialsProvider::extractCredentials(absl::string_view credentials_string, - absl::string_view profile) { - std::string credential_file_data, credential_file_path; - - // Use data source if provided, otherwise read from default AWS credential file path - if (credential_file_data_source_provider_.has_value()) { - credential_file_data = credential_file_data_source_provider_.value()->data(); - credential_file_path = ""; - } else { - credential_file_path = Utility::getCredentialFilePath(); - auto credential_file = context_.api().fileSystem().fileReadToEnd(credential_file_path); - if (credential_file.ok()) { - credential_file_data = credential_file.value(); - } else { - ENVOY_LOG(debug, "Unable to read from credential file {}", credential_file_path); - // Update last_updated_ now so that even if this function returns before successfully - // extracting credentials, this function won't be called again until after the - // REFRESH_INTERVAL. This prevents envoy from attempting and failing to read the credentials - // file on every request if there are errors extracting credentials from it (e.g. if the - // credentials file doesn't exist). - last_updated_ = context_.api().timeSource().systemTime(); - return; - } - } - ENVOY_LOG(debug, "Credentials file path = {}, profile name = {}", credential_file_path, profile); - - extractCredentials(credential_file_data.data(), profile); -} - void CredentialsFileCredentialsProvider::extractCredentials(absl::string_view credentials_string, absl::string_view profile) { @@ -1161,7 +1082,7 @@ CustomCredentialsProviderChain::CustomCredentialsProviderChain( const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; const auto initialization_timer = std::chrono::seconds(2); add(factories.createWebIdentityCredentialsProvider( - context, MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, + context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, web_identity, cluster_name)); } diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index cff75e4e7177..b6211cf17a47 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -217,75 +217,75 @@ AwsRequestSigningFilterFactory::createSigner( return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); } -absl::StatusOr -AwsRequestSigningFilterFactory::createSigner( - const AwsRequestSigningProtoConfig& config, - Server::Configuration::ServerFactoryContext& server_context) { - - std::string region = config.region(); - - envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; - - // If we have an overriding credential provider configuration, read it here as it may contain - // references to the region - envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; - if (config.has_credential_provider()) { - if (config.credential_provider().has_credentials_file_provider()) { - credential_file_config = config.credential_provider().credentials_file_provider(); - } - } - - if (region.empty()) { - auto region_provider = - std::make_shared(credential_file_config); - absl::optional regionOpt; - if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { - regionOpt = region_provider->getRegionSet(); - } else { - // Override default credential provider chain settings with any provided settings - if (has_credential_provider_settings) { - credential_provider_config = config.credential_provider(); - } - credentials_provider = - std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); - } - } else { - // No credential provider settings provided, so make the default credentials provider chain - credentials_provider = - std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); - } - - const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( - config.match_excluded_headers().begin(), config.match_excluded_headers().end()); - - const bool query_string = config.has_query_string(); - - const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( - config.query_string(), expiration_time, - Extensions::Common::Aws::SignatureQueryParameterValues::DefaultExpiration); - - std::unique_ptr signer; - - if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { - return std::make_unique( - config.service_name(), region, server_context, matcher_config, query_string, - expiration_time); - } else { - // Verify that we have not specified a region set when using sigv4 algorithm - if (isARegionSet(region)) { - return absl::InvalidArgumentError( - "SigV4 region string cannot contain wildcards or commas. Region sets " - "can be specified when using signing_algorithm: AWS_SIGV4A."); - } - return std::make_unique( - config.service_name(), region, server_context, matcher_config, query_string, - expiration_time); - } -} +// absl::StatusOr +// AwsRequestSigningFilterFactory::createSigner( +// const AwsRequestSigningProtoConfig& config, +// Server::Configuration::ServerFactoryContext& server_context) { + +// std::string region = config.region(); + +// envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; + +// // If we have an overriding credential provider configuration, read it here as it may contain +// // references to the region +// envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; +// if (config.has_credential_provider()) { +// if (config.credential_provider().has_credentials_file_provider()) { +// credential_file_config = config.credential_provider().credentials_file_provider(); +// } +// } + +// if (region.empty()) { +// auto region_provider = +// std::make_shared(credential_file_config); +// absl::optional regionOpt; +// if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { +// regionOpt = region_provider->getRegionSet(); +// } else { +// // Override default credential provider chain settings with any provided settings +// if (has_credential_provider_settings) { +// credential_provider_config = config.credential_provider(); +// } +// credentials_provider = +// std::make_shared( +// server_context.api(), makeOptRef(server_context), server_context.singletonManager(), +// region, nullptr, credential_provider_config); +// } +// } else { +// // No credential provider settings provided, so make the default credentials provider chain +// credentials_provider = +// std::make_shared( +// server_context.api(), makeOptRef(server_context), server_context.singletonManager(), +// region, nullptr, credential_provider_config); +// } + +// const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( +// config.match_excluded_headers().begin(), config.match_excluded_headers().end()); + +// const bool query_string = config.has_query_string(); + +// const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( +// config.query_string(), expiration_time, +// Extensions::Common::Aws::SignatureQueryParameterValues::DefaultExpiration); + +// std::unique_ptr signer; + +// if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { +// return std::make_unique( +// config.service_name(), region, server_context, matcher_config, query_string, +// expiration_time); +// } else { +// // Verify that we have not specified a region set when using sigv4 algorithm +// if (isARegionSet(region)) { +// return absl::InvalidArgumentError( +// "SigV4 region string cannot contain wildcards or commas. Region sets " +// "can be specified when using signing_algorithm: AWS_SIGV4A."); +// } +// return std::make_unique( +// config.service_name(), region, server_context, matcher_config, query_string, +// expiration_time); +// } +// } /** * Static registration for the AWS request signing filter. @see RegisterFactory. From 9978344cc95cd2dc6389b18c406ee96dd7dd4143 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 2 Jan 2025 04:24:53 +0000 Subject: [PATCH 19/21] format Signed-off-by: Nigel Brittain --- source/extensions/common/aws/credentials_provider_impl.cc | 4 ++-- source/extensions/common/aws/credentials_provider_impl.h | 1 - .../extensions/filters/http/aws_request_signing/config.cc | 8 ++++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index edcbbf43d0b2..00c0ddc09633 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -1082,8 +1082,8 @@ CustomCredentialsProviderChain::CustomCredentialsProviderChain( const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; const auto initialization_timer = std::chrono::seconds(2); add(factories.createWebIdentityCredentialsProvider( - context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, initialization_timer, - web_identity, cluster_name)); + context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, + initialization_timer, web_identity, cluster_name)); } if (credential_provider_config.has_credentials_file_provider()) { diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index cbafd415c466..eb2df62b2805 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -8,7 +8,6 @@ #include "envoy/api/api.h" #include "envoy/common/optref.h" #include "envoy/config/core/v3/base.pb.h" -#include "envoy/config/core/v3/base.pb.h" #include "envoy/event/timer.h" #include "envoy/extensions/common/aws/v3/credential_provider.pb.h" #include "envoy/http/message.h" diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index b6211cf17a47..c1e8a7360cad 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -228,8 +228,8 @@ AwsRequestSigningFilterFactory::createSigner( // // If we have an overriding credential provider configuration, read it here as it may contain // // references to the region -// envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = {}; -// if (config.has_credential_provider()) { +// envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = +// {}; if (config.has_credential_provider()) { // if (config.credential_provider().has_credentials_file_provider()) { // credential_file_config = config.credential_provider().credentials_file_provider(); // } @@ -248,8 +248,8 @@ AwsRequestSigningFilterFactory::createSigner( // } // credentials_provider = // std::make_shared( -// server_context.api(), makeOptRef(server_context), server_context.singletonManager(), -// region, nullptr, credential_provider_config); +// server_context.api(), makeOptRef(server_context), +// server_context.singletonManager(), region, nullptr, credential_provider_config); // } // } else { // // No credential provider settings provided, so make the default credentials provider chain From 27cfb75a4aae7ea764d600ff286e4a5d92053827 Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Thu, 2 Jan 2025 04:40:47 +0000 Subject: [PATCH 20/21] merge Signed-off-by: Nigel Brittain --- .../http/aws_request_signing/config.cc | 128 +++--------------- 1 file changed, 22 insertions(+), 106 deletions(-) diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index c1e8a7360cad..58c18f201017 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -175,118 +175,34 @@ AwsRequestSigningFilterFactory::createSigner( region = regionOpt.value(); } - absl::StatusOr - credentials_provider = - absl::InvalidArgumentError("No credentials provider settings configured."); + const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( + config.match_excluded_headers().begin(), config.match_excluded_headers().end()); - const bool has_credential_provider_settings = - config.has_credential_provider() && - (config.credential_provider().has_assume_role_with_web_identity_provider() || - config.credential_provider().has_credentials_file_provider()); + const bool query_string = config.has_query_string(); - if (config.has_credential_provider()) { - if (config.credential_provider().has_inline_credential()) { - // If inline credential provider is set, use it instead of the default or custom credentials - // chain - const auto& inline_credential = config.credential_provider().inline_credential(); - credentials_provider = std::make_shared( - inline_credential.access_key_id(), inline_credential.secret_access_key(), - inline_credential.session_token()); - } else if (config.credential_provider().custom_credential_provider_chain()) { - // Custom credential provider chain - if (has_credential_provider_settings) { - return std::make_shared( - server_context, region, config.credential_provider()); - } - } else { - // Override default credential provider chain settings with any provided settings - if (has_credential_provider_settings) { - credential_provider_config = config.credential_provider(); - } - return std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), - region, nullptr, credential_provider_config); - } + const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( + config.query_string(), expiration_time, + Extensions::Common::Aws::SignatureQueryParameterValues::DefaultExpiration); + + std::unique_ptr signer; + + if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { + return std::make_unique( + config.service_name(), region, server_context, matcher_config, query_string, + expiration_time); } else { - // No credential provider settings provided, so make the default credentials provider chain - return std::make_shared( - server_context.api(), makeOptRef(server_context), server_context.singletonManager(), region, - nullptr, credential_provider_config); + // Verify that we have not specified a region set when using sigv4 algorithm + if (isARegionSet(region)) { + return absl::InvalidArgumentError( + "SigV4 region string cannot contain wildcards or commas. Region sets " + "can be specified when using signing_algorithm: AWS_SIGV4A."); + } + return std::make_unique( + config.service_name(), region, server_context, matcher_config, query_string, + expiration_time); } - - return absl::InvalidArgumentError(std::string(credentials_provider.status().message())); } -// absl::StatusOr -// AwsRequestSigningFilterFactory::createSigner( -// const AwsRequestSigningProtoConfig& config, -// Server::Configuration::ServerFactoryContext& server_context) { - -// std::string region = config.region(); - -// envoy::extensions::common::aws::v3::AwsCredentialProvider credential_provider_config = {}; - -// // If we have an overriding credential provider configuration, read it here as it may contain -// // references to the region -// envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config = -// {}; if (config.has_credential_provider()) { -// if (config.credential_provider().has_credentials_file_provider()) { -// credential_file_config = config.credential_provider().credentials_file_provider(); -// } -// } - -// if (region.empty()) { -// auto region_provider = -// std::make_shared(credential_file_config); -// absl::optional regionOpt; -// if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { -// regionOpt = region_provider->getRegionSet(); -// } else { -// // Override default credential provider chain settings with any provided settings -// if (has_credential_provider_settings) { -// credential_provider_config = config.credential_provider(); -// } -// credentials_provider = -// std::make_shared( -// server_context.api(), makeOptRef(server_context), -// server_context.singletonManager(), region, nullptr, credential_provider_config); -// } -// } else { -// // No credential provider settings provided, so make the default credentials provider chain -// credentials_provider = -// std::make_shared( -// server_context.api(), makeOptRef(server_context), server_context.singletonManager(), -// region, nullptr, credential_provider_config); -// } - -// const auto matcher_config = Extensions::Common::Aws::AwsSigningHeaderExclusionVector( -// config.match_excluded_headers().begin(), config.match_excluded_headers().end()); - -// const bool query_string = config.has_query_string(); - -// const uint16_t expiration_time = PROTOBUF_GET_SECONDS_OR_DEFAULT( -// config.query_string(), expiration_time, -// Extensions::Common::Aws::SignatureQueryParameterValues::DefaultExpiration); - -// std::unique_ptr signer; - -// if (config.signing_algorithm() == AwsRequestSigning_SigningAlgorithm_AWS_SIGV4A) { -// return std::make_unique( -// config.service_name(), region, server_context, matcher_config, query_string, -// expiration_time); -// } else { -// // Verify that we have not specified a region set when using sigv4 algorithm -// if (isARegionSet(region)) { -// return absl::InvalidArgumentError( -// "SigV4 region string cannot contain wildcards or commas. Region sets " -// "can be specified when using signing_algorithm: AWS_SIGV4A."); -// } -// return std::make_unique( -// config.service_name(), region, server_context, matcher_config, query_string, -// expiration_time); -// } -// } - /** * Static registration for the AWS request signing filter. @see RegisterFactory. */ From bf787e7902052ce9d81ce1bed825f6c27194e0da Mon Sep 17 00:00:00 2001 From: Nigel Brittain Date: Sat, 4 Jan 2025 00:25:53 +0000 Subject: [PATCH 21/21] refactor clusters --- .../common/aws/credentials_provider.h | 15 +++- .../common/aws/credentials_provider_impl.cc | 82 ++++++++++--------- .../common/aws/credentials_provider_impl.h | 45 +++++----- .../aws_request_signing_filter.cc | 32 +++++--- 4 files changed, 104 insertions(+), 70 deletions(-) diff --git a/source/extensions/common/aws/credentials_provider.h b/source/extensions/common/aws/credentials_provider.h index 262f6ed30042..40fda22af77e 100644 --- a/source/extensions/common/aws/credentials_provider.h +++ b/source/extensions/common/aws/credentials_provider.h @@ -8,6 +8,16 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" +namespace Envoy { +namespace Extensions { +namespace HttpFilters { +namespace AwsRequestSigningFilter { +class FilterConfig; +} +} // namespace HttpFilters +} // namespace Extensions +} // namespace Envoy + namespace Envoy { namespace Extensions { namespace Common { @@ -74,7 +84,10 @@ class CredentialsProvider { * * @return bool true if credentials are pending, false otherwise */ - virtual bool credentialsPending(ABSL_ATTRIBUTE_UNUSED CredentialsPendingCallback&& cb) { + virtual bool credentialsPending( + ABSL_ATTRIBUTE_UNUSED Envoy::Extensions::HttpFilters::AwsRequestSigningFilter::FilterConfig& + config, + ABSL_ATTRIBUTE_UNUSED CredentialsPendingCallback&& cb) { return false; } }; diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 00c0ddc09633..170c11d0e5c5 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -180,12 +180,12 @@ void MetadataCredentialsProviderBase::initializeTlsAndCluster() { tls_slot_->set( [&](Event::Dispatcher&) { return std::make_shared(*this); }); - createCluster(true); + createCluster(true, cluster_name_, cluster_type_, uri_); } -void MetadataCredentialsProviderBase::createCluster(bool new_timer) { +void MetadataCredentialsProviderBase::createCluster(bool new_timer, std::string cluster_name, const envoy::config::cluster::v3::Cluster::DiscoveryType cluster_type, std::string uri ) { - auto cluster = Utility::createInternalClusterStatic(cluster_name_, cluster_type_, uri_); + auto cluster = Utility::createInternalClusterStatic(cluster_name, cluster_type, uri); // Async credential refresh timer. Only create this if it is the first time we're creating a // cluster if (new_timer) { @@ -196,7 +196,7 @@ void MetadataCredentialsProviderBase::createCluster(bool new_timer) { // Store the timer in pending cluster list for use in onClusterAddOrUpdate cluster_load_handle_ = std::make_unique( - (*tls_slot_)->pending_clusters_, cluster_name_, cache_duration_timer_); + (*tls_slot_)->pending_clusters_, cluster_name, cache_duration_timer_); const auto cluster_type_str = envoy::config::cluster::v3::Cluster::DiscoveryType_descriptor() ->FindValueByNumber(cluster.type()) @@ -207,7 +207,7 @@ void MetadataCredentialsProviderBase::createCluster(bool new_timer) { ENVOY_LOG_MISC(info, "Added a {} internal cluster [name: {}, address:{}] to fetch aws " "credentials", - cluster_type_str, cluster_name_, host_port); + cluster_type_str, cluster_name, host_port); } THROW_IF_NOT_OK(context_->clusterManager().addOrUpdateCluster(cluster, "").status()); @@ -260,15 +260,16 @@ void MetadataCredentialsProviderBase::ThreadLocalCredentialsCache::onClusterRemo if (!already_creating_) { parent_.stats_->clusters_removed_by_cds_.inc(); // Recreate our cluster if it has been deleted via CDS - parent_.context_->mainThreadDispatcher().post([this]() { parent_.createCluster(false); }); + parent_.context_->mainThreadDispatcher().post([this]() { parent_.createCluster(false, parent_.cluster_name_,parent_.cluster_type_, parent_.uri_); }); ENVOY_LOG_MISC(debug, "Re-adding async credential cluster {}", parent_.cluster_name_); } } }; -bool MetadataCredentialsProviderBase::credentialsPending(CredentialsPendingCallback&& cb) { +bool MetadataCredentialsProviderBase::credentialsPending(ABSL_ATTRIBUTE_UNUSED Envoy::Extensions::HttpFilters::AwsRequestSigningFilter::FilterConfig& config, CredentialsPendingCallback&& cb) { if (cb) { ENVOY_LOG_MISC(debug, "Adding credentials pending callback to queue"); + Thread::LockGuard guard(mu_); credential_pending_callbacks_.push_back(std::move(cb)); } return credentials_pending_; @@ -834,17 +835,16 @@ void ContainerCredentialsProvider::onMetadataError(Failure reason) { WebIdentityCredentialsProvider::WebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name = {}) + web_identity_config) : MetadataCredentialsProviderBase( - context.api(), context, nullptr, create_metadata_fetcher_cb, cluster_name, - envoy::config::cluster::v3::Cluster::LOGICAL_DNS /*cluster_type*/, sts_endpoint, + context.api(), context, nullptr, create_metadata_fetcher_cb, Utility::getSTSEndpoint(region), + envoy::config::cluster::v3::Cluster::LOGICAL_DNS /*cluster_type*/, Utility::getSTSEndpoint(region), refresh_state, initialization_timer), - sts_endpoint_(sts_endpoint), role_arn_(web_identity_config.role_arn()), + sts_endpoint_(Utility::getSTSEndpoint(region)), role_arn_(web_identity_config.role_arn()), role_session_name_(web_identity_config.role_session_name()) { auto provider_or_error_ = Config::DataSource::DataSourceProvider::create( @@ -858,6 +858,15 @@ WebIdentityCredentialsProvider::WebIdentityCredentialsProvider( } } +// CredentialsProviderSharedPtr WebIdentityCredentialsProvider::createInstance(const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& +// web_identity_config, +// absl::string_view region) +// { +// const std::string sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; +// auto role_arn = web_identity_config.role_arn(); + +// } + bool WebIdentityCredentialsProvider::needsRefresh() { const auto now = api_.timeSource().systemTime(); @@ -1010,9 +1019,9 @@ void WebIdentityCredentialsProvider::onMetadataError(Failure reason) { handleFetchDone(); } -bool CredentialsProviderChain::credentialsPending(CredentialsPendingCallback&& cb) { +bool CredentialsProviderChain::credentialsPending(Envoy::Extensions::HttpFilters::AwsRequestSigningFilter::FilterConfig& config, CredentialsPendingCallback&& cb) { for (auto& provider : providers_) { - if (provider->credentialsPending(std::move(cb))) { + if (provider->credentialsPending(config, std::move(cb))) { ENVOY_LOG_MISC(debug, "Credentials are pending"); return true; } @@ -1082,8 +1091,8 @@ CustomCredentialsProviderChain::CustomCredentialsProviderChain( const auto refresh_state = MetadataFetcher::MetadataReceiver::RefreshState::FirstRefresh; const auto initialization_timer = std::chrono::seconds(2); add(factories.createWebIdentityCredentialsProvider( - context, context.singletonManager(), MetadataFetcher::create, sts_endpoint, refresh_state, - initialization_timer, web_identity, cluster_name)); + context, context.singletonManager(), MetadataFetcher::create, region, refresh_state, + initialization_timer, web_identity)); } if (credential_provider_config.has_credentials_file_provider()) { @@ -1135,19 +1144,19 @@ DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( !web_identity.web_identity_token_data_source().environment_variable().empty()) && !web_identity.role_arn().empty()) { - const auto sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; + // const auto sts_endpoint = Utility::getSTSEndpoint(region) + ":443"; // const auto region_uuid = absl::StrCat(region, "_", // context->api().randomGenerator().uuid()); - const auto cluster_name = stsClusterName(region); + // const auto cluster_name = stsClusterName(region); ENVOY_LOG( debug, - "Using web identity credentials provider with STS endpoint: {} and session name: {}", - sts_endpoint, web_identity.role_session_name()); + "Using web identity credentials provider with region {} and session name: {}", + region, web_identity.role_session_name()); add(factories.createWebIdentityCredentialsProvider( - context.value(), context->singletonManager(), MetadataFetcher::create, sts_endpoint, - refresh_state, initialization_timer, web_identity, cluster_name)); + context.value(), context->singletonManager(), MetadataFetcher::create, region, + refresh_state, initialization_timer, web_identity)); } } @@ -1232,37 +1241,36 @@ DefaultCredentialsProviderChain::createInstanceProfileCredentialsProvider( } CredentialsProviderSharedPtr DefaultCredentialsProviderChain::createWebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const { + web_identity_config) const { return singleton_manager.getTyped( SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), - [&context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name] { + [&context, create_metadata_fetcher_cb, region, refresh_state, initialization_timer, + web_identity_config] { return std::make_shared( - context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name); + context, create_metadata_fetcher_cb, region, refresh_state, initialization_timer, + web_identity_config); }); }; CredentialsProviderSharedPtr CustomCredentialsProviderChain::createWebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const { + web_identity_config) const { + ENVOY_LOG_MISC(debug, "**************** Instantiating web identity with region {} role arn {} session name{}", region, web_identity_config.role_arn(), web_identity_config.role_session_name()); return singleton_manager.getTyped( SINGLETON_MANAGER_REGISTERED_NAME(web_identity_credentials_provider), - [&context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name] { + [&context, create_metadata_fetcher_cb, region, refresh_state, initialization_timer, + web_identity_config] { return std::make_shared( - context, create_metadata_fetcher_cb, sts_endpoint, refresh_state, initialization_timer, - web_identity_config, cluster_name); + context, create_metadata_fetcher_cb, region, refresh_state, initialization_timer, + web_identity_config); }); }; diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index eb2df62b2805..a0f8a8253ac8 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -144,13 +144,15 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { std::chrono::seconds initialization_timer); Credentials getCredentials() override; - bool credentialsPending(CredentialsPendingCallback&& cb) override; - + bool credentialsPending( + ABSL_ATTRIBUTE_UNUSED Envoy::Extensions::HttpFilters::AwsRequestSigningFilter::FilterConfig& + config, + ABSL_ATTRIBUTE_UNUSED CredentialsPendingCallback&& cb) override; // Get the Metadata credentials cache duration. static std::chrono::seconds getCacheDuration(); private: - void createCluster(bool new_timer); +void createCluster(bool new_timer, std::string cluster_name, const envoy::config::cluster::v3::Cluster::DiscoveryType cluster_type, std::string uri ); void initializeTlsAndCluster(); protected: @@ -251,6 +253,7 @@ class MetadataCredentialsProviderBase : public CachedCredentialsProviderBase { std::atomic credentials_pending_ = true; // Callbacks list for pending credentials std::vector credential_pending_callbacks_ = {}; + Thread::MutexBasicLockable mu_; }; /** @@ -335,16 +338,19 @@ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase, // not used, and vice versa. WebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name); + web_identity_config); // Following functions are for MetadataFetcher::MetadataReceiver interface void onMetadataSuccess(const std::string&& body) override; void onMetadataError(Failure reason) override; + + // CredentialsProviderSharedPtr createInstance(const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& + // web_identity_config, + // absl::string_view sts_endpoint); private: const std::string sts_endpoint_; @@ -370,7 +376,10 @@ class CredentialsProviderChain : public CredentialsProvider, } Credentials getCredentials() override; - bool credentialsPending(CredentialsPendingCallback&& cb) override; + bool credentialsPending( + Envoy::Extensions::HttpFilters::AwsRequestSigningFilter::FilterConfig& + config, + CredentialsPendingCallback&& cb) override; protected: std::list providers_; @@ -389,12 +398,11 @@ class CredentialsProviderChainFactories { virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const PURE; + web_identity_config) const PURE; virtual CredentialsProviderSharedPtr createContainerCredentialsProvider( Api::Api& api, ServerFactoryContextOptRef context, Singleton::Manager& singleton_manager, @@ -424,12 +432,12 @@ class CustomCredentialsProviderChainFactories { virtual CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const PURE; + web_identity_config + ) const PURE; }; // TODO(nbaws) Add additional providers to the custom chain. @@ -458,12 +466,12 @@ class CustomCredentialsProviderChain : public CredentialsProviderChain, CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const override; + web_identity_config + ) const override; }; /** @@ -540,12 +548,11 @@ class DefaultCredentialsProviderChain : public CredentialsProviderChain, CredentialsProviderSharedPtr createWebIdentityCredentialsProvider( Server::Configuration::ServerFactoryContext& context, Singleton::Manager& singleton_manager, - CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view sts_endpoint, + CreateMetadataFetcherCb create_metadata_fetcher_cb, absl::string_view region, MetadataFetcher::MetadataReceiver::RefreshState refresh_state, std::chrono::seconds initialization_timer, const envoy::extensions::common::aws::v3::AssumeRoleWithWebIdentityCredentialProvider& - web_identity_config, - absl::string_view cluster_name) const override; + web_identity_config) const override; }; using InstanceProfileCredentialsProviderPtr = std::shared_ptr; diff --git a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc index 5a24f85501e2..3e90404a5045 100644 --- a/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc +++ b/source/extensions/filters/http/aws_request_signing/aws_request_signing_filter.cc @@ -84,19 +84,25 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, // If we are pending credentials, send the decodeHeadersCredentialsAvailable callback for when // they become available, and stop iteration. - if (config.credentialsProvider()->credentialsPending( - - Envoy::CancelWrapper::cancelWrapped( - [this, &dispatcher = decoder_callbacks_->dispatcher()]( - Envoy::Extensions::Common::Aws::Credentials credentials) { - dispatcher.post([this, credentials]() { - this->decodeHeadersCredentialsAvailable(credentials); - }); - }, - &cancel_callback_) + auto completion_cb = Envoy::CancelWrapper::cancelWrapped( + [this] (Envoy::Extensions::Common::Aws::Credentials credentials) { + decodeHeadersCredentialsAvailable(credentials); +}, &cancel_callback_); - )) { - ENVOY_LOG_MISC(debug, "Credentials are pending"); + if (config.credentialsProvider()->credentialsPending( + config, + [&dispatcher = decoder_callbacks_->dispatcher(), completion_cb = std::move(completion_cb)](Envoy::Extensions::Common::Aws::Credentials credentials) mutable + { + dispatcher.post( + [creds = std::move(credentials), cb = std::move(completion_cb)]() mutable + { + cb(creds); + } + ); + } + )) + { + ENVOY_LOG_MISC(debug, "Credentials are pending"); return Http::FilterHeadersStatus::StopAllIterationAndBuffer; } else { ENVOY_LOG_MISC(debug, "Credentials are not pending"); @@ -119,7 +125,7 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea // If we are pending credentials, send the decodeDataCredentialsAvailable callback for when they // become available, and stop iteration. - if (config.credentialsProvider()->credentialsPending(Envoy::CancelWrapper::cancelWrapped( + if (config.credentialsProvider()->credentialsPending(config, Envoy::CancelWrapper::cancelWrapped( [this, &dispatcher = decoder_callbacks_->dispatcher()]( Envoy::Extensions::Common::Aws::Credentials credentials) {