diff --git a/adapters/apigateway-method.go b/adapters/apigateway-method.go new file mode 100644 index 00000000..ce9022a3 --- /dev/null +++ b/adapters/apigateway-method.go @@ -0,0 +1,145 @@ +package adapters + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/apigateway" + "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/sdp-go" +) + +type apigatewayClient interface { + GetMethod(ctx context.Context, params *apigateway.GetMethodInput, optFns ...func(*apigateway.Options)) (*apigateway.GetMethodOutput, error) +} + +func apiGatewayMethodGetFunc(ctx context.Context, client apigatewayClient, scope string, input *apigateway.GetMethodInput) (*sdp.Item, error) { + output, err := client.GetMethod(ctx, input) + if err != nil { + return nil, err + } + + attributes, err := adapterhelpers.ToAttributesWithExclude(output, "tags") + if err != nil { + return nil, err + } + + // We create a custom ID of {rest-api-id}/{resource-id}/{http-method} e.g. + // rest-api-id/resource-id/GET + methodID := fmt.Sprintf( + "%s/%s/%s", + *input.RestApiId, + *input.ResourceId, + *input.HttpMethod, + ) + err = attributes.Set("MethodID", methodID) + if err != nil { + return nil, err + } + + item := &sdp.Item{ + Type: "apigateway-method", + UniqueAttribute: "MethodID", + Attributes: attributes, + Scope: scope, + } + + if output.MethodIntegration != nil { + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: "apigateway-integration", + Method: sdp.QueryMethod_GET, + Query: methodID, + Scope: scope, + }, + BlastPropagation: &sdp.BlastPropagation{ + // They are tightly coupled + In: true, + Out: true, + }, + }) + } + + if output.AuthorizerId != nil { + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: "apigateway-authorizer", + Method: sdp.QueryMethod_GET, + Query: fmt.Sprintf("%s/%s", *input.RestApiId, *output.AuthorizerId), + Scope: scope, + }, + BlastPropagation: &sdp.BlastPropagation{ + // Deleting authorizer will affect the method + In: true, + // Deleting method won't affect the authorizer + Out: false, + }, + }) + } + + if output.RequestValidatorId != nil { + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: "apigateway-request-validator", + Method: sdp.QueryMethod_GET, + Query: fmt.Sprintf("%s/%s", *input.RestApiId, *output.RequestValidatorId), + Scope: scope, + }, + BlastPropagation: &sdp.BlastPropagation{ + // Deleting request validator will affect the method + In: true, + // Deleting method won't affect the request validator + Out: false, + }, + }) + } + + return item, nil +} + +func NewAPIGatewayMethodAdapter(client apigatewayClient, accountID string, region string) *adapterhelpers.AlwaysGetAdapter[*apigateway.GetMethodInput, *apigateway.GetMethodOutput, *apigateway.GetMethodInput, *apigateway.GetMethodOutput, apigatewayClient, *apigateway.Options] { + return &adapterhelpers.AlwaysGetAdapter[*apigateway.GetMethodInput, *apigateway.GetMethodOutput, *apigateway.GetMethodInput, *apigateway.GetMethodOutput, apigatewayClient, *apigateway.Options]{ + ItemType: "apigateway-method", + Client: client, + AccountID: accountID, + Region: region, + AdapterMetadata: apiGatewayMethodAdapterMetadata, + GetFunc: apiGatewayMethodGetFunc, + GetInputMapper: func(scope, query string) *apigateway.GetMethodInput { + // We are using a custom id of {rest-api-id}/{resource-id}/{http-method} e.g. + // rest-api-id/resource-id/GET + f := strings.Split(query, "/") + if len(f) != 3 { + slog.Error( + "query must be in the format of: the rest-api-id/resource-id/http-method", + "found", + query, + ) + + return nil + } + + return &apigateway.GetMethodInput{ + RestApiId: &f[0], + ResourceId: &f[1], + HttpMethod: &f[2], + } + }, + DisableList: true, + } +} + +var apiGatewayMethodAdapterMetadata = Metadata.Register(&sdp.AdapterMetadata{ + Type: "apigateway-method", + DescriptiveName: "API Gateway Method", + Category: sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + GetDescription: "Get a Method by rest-api id, resource id and http-method", + Search: true, + SearchDescription: "Search Methods by ARN", + }, + PotentialLinks: []string{"apigateway-method-response"}, +}) diff --git a/adapters/apigateway-method_test.go b/adapters/apigateway-method_test.go new file mode 100644 index 00000000..84241cab --- /dev/null +++ b/adapters/apigateway-method_test.go @@ -0,0 +1,114 @@ +package adapters + +import ( + "context" + "fmt" + "github.com/overmindtech/sdp-go" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/apigateway" + "github.com/aws/aws-sdk-go-v2/service/apigateway/types" + "github.com/overmindtech/aws-source/adapterhelpers" +) + +type mockAPIGatewayClient struct{} + +func (m *mockAPIGatewayClient) GetMethod(ctx context.Context, params *apigateway.GetMethodInput, optFns ...func(*apigateway.Options)) (*apigateway.GetMethodOutput, error) { + return &apigateway.GetMethodOutput{ + ApiKeyRequired: aws.Bool(false), + HttpMethod: aws.String("GET"), + AuthorizationType: aws.String("NONE"), + AuthorizerId: aws.String("authorizer-id"), + RequestParameters: map[string]bool{}, + RequestValidatorId: aws.String("request-validator-id"), + MethodResponses: map[string]types.MethodResponse{ + "200": { + ResponseModels: map[string]string{ + "application/json": "Empty", + }, + StatusCode: aws.String("200"), + }, + }, + MethodIntegration: &types.Integration{ + IntegrationResponses: map[string]types.IntegrationResponse{ + "200": { + ResponseTemplates: map[string]string{ + "application/json": "", + }, + StatusCode: aws.String("200"), + }, + }, + CacheKeyParameters: []string{}, + Uri: aws.String("arn:aws:apigateway:us-west-2:lambda:path/2015-03-31/functions/arn:aws:lambda:us-west-2:123412341234:function:My_Function/invocations"), + HttpMethod: aws.String("POST"), + CacheNamespace: aws.String("y9h6rt"), + Type: "AWS", + }, + }, nil + +} + +func TestApiGatewayGetFunc(t *testing.T) { + ctx := context.Background() + cli := mockAPIGatewayClient{} + + input := &apigateway.GetMethodInput{ + RestApiId: aws.String("rest-api-id"), + ResourceId: aws.String("resource-id"), + HttpMethod: aws.String("GET"), + } + + item, err := apiGatewayMethodGetFunc(ctx, &cli, "scope", input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err = item.Validate(); err != nil { + t.Fatal(err) + } + + methodID := fmt.Sprintf("%s/%s/%s", *input.RestApiId, *input.ResourceId, *input.HttpMethod) + authorizerID := fmt.Sprintf("%s/%s", *input.RestApiId, "authorizer-id") + validatorID := fmt.Sprintf("%s/%s", *input.RestApiId, "request-validator-id") + + tests := adapterhelpers.QueryTests{ + { + ExpectedType: "apigateway-integration", + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: methodID, + ExpectedScope: "scope", + }, + { + ExpectedType: "apigateway-authorizer", + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: authorizerID, + ExpectedScope: "scope", + }, + { + ExpectedType: "apigateway-request-validator", + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: validatorID, + ExpectedScope: "scope", + }, + } + + tests.Execute(t, item) +} + +func TestNewAPIGatewayMethodAdapter(t *testing.T) { + config, account, region := adapterhelpers.GetAutoConfig(t) + + client := apigateway.NewFromConfig(config) + + adapter := NewAPIGatewayMethodAdapter(client, account, region) + + test := adapterhelpers.E2ETest{ + Adapter: adapter, + Timeout: 10 * time.Second, + SkipList: true, + } + + test.Run(t) +} diff --git a/adapters/integration/apigateway/apigateway_test.go b/adapters/integration/apigateway/apigateway_test.go index e14df306..442c85b1 100644 --- a/adapters/integration/apigateway/apigateway_test.go +++ b/adapters/integration/apigateway/apigateway_test.go @@ -2,6 +2,7 @@ package apigateway import ( "context" + "fmt" "testing" "github.com/overmindtech/aws-source/adapterhelpers" @@ -42,6 +43,13 @@ func APIGateway(t *testing.T) { t.Fatalf("failed to validate APIGateway resource adapter: %v", err) } + methodSource := adapters.NewAPIGatewayMethodAdapter(testClient, accountID, testAWSConfig.Region) + + err = methodSource.Validate() + if err != nil { + t.Fatalf("failed to validate APIGateway method adapter: %v", err) + } + scope := adapterhelpers.FormatScope(accountID, testAWSConfig.Region) // List restApis @@ -155,4 +163,20 @@ func APIGateway(t *testing.T) { if resourceUniqueAttrFromSearch != resourceUniqueAttrFromGet { t.Fatalf("expected resource ID %s, got %s", resourceUniqueAttrFromSearch, resourceUniqueAttrFromGet) } + + // Get method + methodID := fmt.Sprintf("%s/GET", resourceUniqueAttrFromGet) // resourceUniqueAttribute contains the restApiID + method, err := methodSource.Get(ctx, scope, methodID, true) + if err != nil { + t.Fatalf("failed to get APIGateway method: %v", err) + } + + uniqueMethodAttr, err := method.GetAttributes().Get(method.GetUniqueAttribute()) + if err != nil { + t.Fatalf("failed to get unique method attribute: %v", err) + } + + if uniqueMethodAttr != methodID { + t.Fatalf("expected method ID %s, got %s", methodID, uniqueMethodAttr) + } } diff --git a/adapters/integration/apigateway/create.go b/adapters/integration/apigateway/create.go index 1772bbe5..976a959b 100644 --- a/adapters/integration/apigateway/create.go +++ b/adapters/integration/apigateway/create.go @@ -4,6 +4,7 @@ import ( "context" "errors" "log/slog" + "strings" "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/overmindtech/aws-source/adapterhelpers" @@ -57,7 +58,7 @@ func createResource(ctx context.Context, logger *slog.Logger, client *apigateway result, err := client.CreateResource(ctx, &apigateway.CreateResourceInput{ RestApiId: restAPIID, ParentId: parentID, - PathPart: adapterhelpers.PtrString(path), + PathPart: adapterhelpers.PtrString(cleanPath(path)), }) if err != nil { return nil, err @@ -65,3 +66,41 @@ func createResource(ctx context.Context, logger *slog.Logger, client *apigateway return result.Id, nil } + +func cleanPath(path string) string { + p, ok := strings.CutPrefix(path, "/") + if !ok { + return path + } + + return p +} + +func createMethod(ctx context.Context, logger *slog.Logger, client *apigateway.Client, restAPIID, resourceID *string, method string) error { + // check if a method with the same name already exists + err := findMethod(ctx, client, restAPIID, resourceID, method) + if err != nil { + if errors.As(err, new(integration.NotFoundError)) { + logger.InfoContext(ctx, "Creating method") + } else { + return err + } + } + + if err == nil { + logger.InfoContext(ctx, "Method already exists") + return nil + } + + _, err = client.PutMethod(ctx, &apigateway.PutMethodInput{ + RestApiId: restAPIID, + ResourceId: resourceID, + HttpMethod: adapterhelpers.PtrString(method), + AuthorizationType: adapterhelpers.PtrString("NONE"), + }) + if err != nil { + return err + } + + return nil +} diff --git a/adapters/integration/apigateway/find.go b/adapters/integration/apigateway/find.go index fac36f3e..42b52abe 100644 --- a/adapters/integration/apigateway/find.go +++ b/adapters/integration/apigateway/find.go @@ -2,8 +2,9 @@ package apigateway import ( "context" - + "errors" "github.com/aws/aws-sdk-go-v2/service/apigateway" + "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/aws-source/adapters/integration" ) @@ -38,3 +39,26 @@ func findResource(ctx context.Context, client *apigateway.Client, restAPIID *str return nil, integration.NewNotFoundError(integration.ResourceName(integration.APIGateway, resourceSrc, path)) } + +func findMethod(ctx context.Context, client *apigateway.Client, restAPIID, resourceID *string, method string) error { + _, err := client.GetMethod(ctx, &apigateway.GetMethodInput{ + RestApiId: restAPIID, + ResourceId: resourceID, + HttpMethod: &method, + }) + + if err != nil { + var notFoundErr *types.NotFoundException + if errors.As(err, ¬FoundErr) { + return integration.NewNotFoundError(integration.ResourceName( + integration.APIGateway, + methodSrc, + method, + )) + } + + return err + } + + return nil +} diff --git a/adapters/integration/apigateway/setup.go b/adapters/integration/apigateway/setup.go index 562b6405..c8285643 100644 --- a/adapters/integration/apigateway/setup.go +++ b/adapters/integration/apigateway/setup.go @@ -11,6 +11,7 @@ import ( const ( restAPISrc = "rest-api" resourceSrc = "resource" + methodSrc = "method" ) func setup(ctx context.Context, logger *slog.Logger, client *apigateway.Client) error { @@ -29,7 +30,13 @@ func setup(ctx context.Context, logger *slog.Logger, client *apigateway.Client) } // Create resource - _, err = createResource(ctx, logger, client, restApiID, rootResourceID, "test") + testResourceID, err := createResource(ctx, logger, client, restApiID, rootResourceID, "/test") + if err != nil { + return err + } + + // Create method + err = createMethod(ctx, logger, client, restApiID, testResourceID, "GET") if err != nil { return err } diff --git a/proc/proc.go b/proc/proc.go index 668839ad..98642c2a 100644 --- a/proc/proc.go +++ b/proc/proc.go @@ -479,6 +479,7 @@ func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig, adapters.NewAPIGatewayRestApiAdapter(apigatewayClient, *callerID.Account, cfg.Region), adapters.NewAPIGatewayResourceAdapter(apigatewayClient, *callerID.Account, cfg.Region), adapters.NewAPIGatewayDomainNameAdapter(apigatewayClient, *callerID.Account, cfg.Region), + adapters.NewAPIGatewayMethodAdapter(apigatewayClient, *callerID.Account, cfg.Region), // SSM adapters.NewSSMParameterAdapter(ssmClient, *callerID.Account, cfg.Region),