From 1453cc2660c26d4a538d65f343334d8621909f70 Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 14:32:37 +0000 Subject: [PATCH 1/7] Moved `AlwaysGetSource` to streaming --- adapterhelpers/always_get_source.go | 186 +++++++------- adapterhelpers/always_get_source_test.go | 296 +++++++++++++++++------ adapterhelpers/util.go | 32 ++- go.mod | 2 +- go.sum | 4 +- 5 files changed, 345 insertions(+), 175 deletions(-) diff --git a/adapterhelpers/always_get_source.go b/adapterhelpers/always_get_source.go index 3381015b..a36ea7ac 100644 --- a/adapterhelpers/always_get_source.go +++ b/adapterhelpers/always_get_source.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" "github.com/overmindtech/sdpcache" "github.com/sourcegraph/conc/pool" @@ -196,158 +197,157 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru // List Lists all available items. This is done by running the ListFunc, then // passing these results to GetFunc in order to get the details -func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) List(ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) ListStream(ctx context.Context, scope string, ignoreCache bool, stream *discovery.QueryResultStream) { if scope != s.Scopes()[0] { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), - } + }) + return + } + + if err := s.Validate(); err != nil { + stream.SendError(WrapAWSError(err)) + return } // Check to see if we have supplied the required functions if s.DisableList { // In this case we can't run list, so just return empty - return []*sdp.Item{}, nil + return } s.ensureCache() cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) if qErr != nil { - return nil, qErr + stream.SendError(qErr) + return } if cacheHit { - return cachedItems, nil - } - - items, err := s.listInternal(ctx, scope, s.ListInput) - if err != nil { - err := WrapAWSError(err) - if !CanRetry(err) { - s.cache.StoreError(err, s.cacheDuration(), ck) + for _, item := range cachedItems { + stream.SendItem(item) } - return nil, err - } - - for _, item := range items { - s.cache.StoreItem(item, s.cacheDuration(), ck) + return } - return items, nil + s.listInternal(ctx, scope, s.ListInput, ck, stream) } -// listInternal Accepts a ListInput and runs the List logic against it -func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) listInternal(ctx context.Context, scope string, input ListInput) ([]*sdp.Item, error) { - var output ListOutput - var err error - - if err = s.Validate(); err != nil { - return nil, WrapAWSError(err) - } - - p := pool.NewWithResults[*sdp.Item]().WithErrors().WithContext(ctx).WithMaxGoroutines(s.MaxParallel.Value()) - +func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) listInternal(ctx context.Context, scope string, input ListInput, ck sdpcache.CacheKey, stream *discovery.QueryResultStream) { paginator := s.ListFuncPaginatorBuilder(s.Client, input) var newGetInputs []GetInput for paginator.HasMorePages() { - output, err = paginator.NextPage(ctx) + p := pool.New().WithContext(ctx).WithMaxGoroutines(s.MaxParallel.Value()) + + output, err := paginator.NextPage(ctx) if err != nil { - return nil, err + err := WrapAWSError(err) + if !CanRetry(err) { + s.cache.StoreError(err, s.cacheDuration(), ck) + } + stream.SendError(err) + return } newGetInputs, err = s.ListFuncOutputMapper(output, input) if err != nil { - return nil, err + err := WrapAWSError(err) + if !CanRetry(err) { + s.cache.StoreError(err, s.cacheDuration(), ck) + } + stream.SendError(err) + return } for _, input := range newGetInputs { - p.Go(func(ctx context.Context) (*sdp.Item, error) { - return s.GetFunc(ctx, s.Client, scope, input) + p.Go(func(ctx context.Context) error { + item, err := s.GetFunc(ctx, s.Client, scope, input) + + if err != nil { + stream.SendError(WrapAWSError(err)) + } + + if item != nil { + s.cache.StoreItem(item, s.cacheDuration(), ck) + stream.SendItem(item) + } + + return nil }) } - } - - // We are deciding to throw the errors away from the Get requests, this - // probably isn't the best idea, but we don't want to fail the whole list - // because a Get failed. We might want to revisit this logic in the future - items, _ := p.Wait() - return items, nil + // Wait for this page to be processed before moving on to the next one + _ = p.Wait() + } } // Search Searches for AWS resources by ARN -func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Search(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchStream(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { if scope != s.Scopes()[0] { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), - } + }) + return } - var items []*sdp.Item - var err error + if err := s.Validate(); err != nil { + stream.SendError(WrapAWSError(err)) + return + } if s.SearchInputMapper == nil && s.SearchGetInputMapper == nil { - items, err = s.SearchARN(ctx, scope, query, ignoreCache) + s.SearchARN(ctx, scope, query, ignoreCache, stream) } else { // If we should always look for ARNs first, do that if s.AlwaysSearchARNs { - if _, err = ParseARN(query); err == nil { - items, err = s.SearchARN(ctx, scope, query, ignoreCache) + if _, err := ParseARN(query); err == nil { + s.SearchARN(ctx, scope, query, ignoreCache, stream) } else { - items, err = s.SearchCustom(ctx, scope, query, ignoreCache) + s.SearchCustom(ctx, scope, query, ignoreCache, stream) } } else { - items, err = s.SearchCustom(ctx, scope, query, ignoreCache) + s.SearchCustom(ctx, scope, query, ignoreCache, stream) } } - - if err != nil { - return nil, err - } - - return items, nil } // SearchCustom Searches using custom mapping logic. The SearchInputMapper is // used to create an input for ListFunc, at which point the usual logic is used -func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchCustom(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchCustom(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { s.ensureCache() cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) if qErr != nil { - return nil, qErr + stream.SendError(qErr) + return } if cacheHit { - return cachedItems, nil + for _, item := range cachedItems { + stream.SendItem(item) + } + return } - var items []*sdp.Item - if s.SearchInputMapper != nil { input, err := s.SearchInputMapper(scope, query) if err != nil { // Don't bother caching this error since it costs nearly nothing - return nil, WrapAWSError(err) + stream.SendError(WrapAWSError(err)) + return } - items, err = s.listInternal(ctx, scope, input) - - if err != nil { - err := WrapAWSError(err) - if !CanRetry(err) { - s.cache.StoreError(err, s.cacheDuration(), ck) - } - return nil, err - } + s.listInternal(ctx, scope, input, ck, stream) } else if s.SearchGetInputMapper != nil { input, err := s.SearchGetInputMapper(scope, query) if err != nil { // Don't cache this as it costs nearly nothing - return nil, WrapAWSError(err) + stream.SendError(WrapAWSError(err)) + return } item, err := s.GetFunc(ctx, s.Client, scope, input) @@ -357,51 +357,57 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru if !CanRetry(err) { s.cache.StoreError(err, s.cacheDuration(), ck) } - return nil, err + stream.SendError(err) + return } - items = []*sdp.Item{item} + if item != nil { + s.cache.StoreItem(item, s.cacheDuration(), ck) + stream.SendItem(item) + } } else { - return nil, errors.New("SearchCustom called without SearchInputMapper or SearchGetInputMapper") - } - - for _, item := range items { - s.cache.StoreItem(item, s.cacheDuration(), ck) + stream.SendError(errors.New("SearchCustom called without SearchInputMapper or SearchGetInputMapper")) + return } - return items, nil } -func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchARN(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchARN(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { // Parse the ARN a, err := ParseARN(query) if err != nil { - return nil, WrapAWSError(err) + stream.SendError(WrapAWSError(err)) + return } if a.ContainsWildcard() { // We can't handle wildcards by default so bail out - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOTFOUND, ErrorString: fmt.Sprintf("wildcards are not supported by adapter %v", s.Name()), Scope: scope, - } + }) + return } if arnScope := FormatScope(a.AccountID, a.Region); arnScope != scope { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, ErrorString: fmt.Sprintf("ARN scope %v does not match request scope %v", arnScope, scope), Scope: scope, - } + }) + return } item, err := s.Get(ctx, scope, a.ResourceID(), ignoreCache) if err != nil { - return nil, WrapAWSError(err) + stream.SendError(WrapAWSError(err)) + return } - return []*sdp.Item{item}, nil + if item != nil { + stream.SendItem(item) + } } // Weight Returns the priority weighting of items returned by this sourcs. diff --git a/adapterhelpers/always_get_source_test.go b/adapterhelpers/always_get_source_test.go index f63986f5..b408e85a 100644 --- a/adapterhelpers/always_get_source_test.go +++ b/adapterhelpers/always_get_source_test.go @@ -6,6 +6,7 @@ import ( "fmt" "testing" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" "google.golang.org/protobuf/types/known/structpb" ) @@ -141,11 +142,22 @@ func TestAlwaysGetSourceList(t *testing.T) { return "" }, } + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - items, err := lgs.List(context.Background(), "foo.bar", false) + lgs.ListStream(context.Background(), "foo.bar", false, stream) + stream.Close() - if err != nil { - t.Error(err) + if len(errs) != 0 { + t.Errorf("expected no errors, got %v", len(errs)) } if len(items) != 6 { @@ -178,16 +190,24 @@ func TestAlwaysGetSourceList(t *testing.T) { return "" }, } + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) - _, err := lgs.List(context.Background(), "foo.bar", false) + lgs.ListStream(context.Background(), "foo.bar", false, stream) + stream.Close() - if err == nil { - t.Fatal("expected error but got nil") + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %v", len(errs)) } qErr := &sdp.QueryError{} - if !errors.As(err, &qErr) { - t.Errorf("expected error to be a QueryError, got %v", err) + if !errors.As(errs[0], &qErr) { + t.Errorf("expected error to be a QueryError, got %v", errs[0]) } else { if qErr.GetErrorString() != "output mapper error" { t.Errorf("expected 'output mapper error', got '%v'", qErr.GetErrorString()) @@ -214,18 +234,28 @@ func TestAlwaysGetSourceList(t *testing.T) { return []string{"", ""}, nil }, GetFunc: func(ctx context.Context, client struct{}, scope, input string) (*sdp.Item, error) { - return &sdp.Item{}, errors.New("get func error") + return nil, errors.New("get func error") }, GetInputMapper: func(scope, query string) string { return "" }, } + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - items, err := lgs.List(context.Background(), "foo.bar", false) + lgs.ListStream(context.Background(), "foo.bar", false, stream) + stream.Close() - // If GetFunc fails it doesn't cause an error - if err != nil { - t.Error(err) + if len(errs) != 6 { + t.Fatalf("expected 6 error, got %v", len(errs)) } if len(items) != 0 { @@ -266,26 +296,62 @@ func TestAlwaysGetSourceSearch(t *testing.T) { } t.Run("bad ARN", func(t *testing.T) { - _, err := lgs.Search(context.Background(), "foo.bar", "query", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + lgs.SearchStream(context.Background(), "foo.bar", "query", false, stream) + stream.Close() - if err == nil { + if len(errs) == 0 { t.Error("expected error because the ARN was bad") } }) t.Run("good ARN but bad scope", func(t *testing.T) { - _, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:region:account:type/id", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err == nil { + lgs.SearchStream(context.Background(), "foo.bar", "arn:aws:service:region:account:type/id", false, stream) + stream.Close() + + if len(errs) == 0 { t.Error("expected error because the ARN had a bad scope") } }) t.Run("good ARN", func(t *testing.T) { - _, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + lgs.SearchStream(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id", false, stream) + stream.Close() + + if len(errs) != 0 { + t.Error(errs[0]) } }) }) @@ -325,10 +391,22 @@ func TestAlwaysGetSourceSearch(t *testing.T) { } t.Run("ARN", func(t *testing.T) { - items, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + lgs.SearchStream(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id", false, stream) + stream.Close() + + if len(errs) != 0 { + t.Error(errs[0]) } if len(items) != 1 { @@ -337,18 +415,29 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }) t.Run("other search", func(t *testing.T) { - items, err := lgs.Search(context.Background(), "foo.bar", "id", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + lgs.SearchStream(context.Background(), "foo.bar", "id", false, stream) + stream.Close() + + if len(errs) != 6 { + t.Errorf("expected 6 error, got %v", len(errs)) } if len(items) != 0 { - t.Errorf("expected 0 item, got %v", len(items)) + t.Errorf("expected 0 items, got %v", len(items)) } }) }) - t.Run("with custom search logic", func(t *testing.T) { var searchMapperCalled bool @@ -380,10 +469,22 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }, } - _, err := lgs.Search(context.Background(), "foo.bar", "bar", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + lgs.SearchStream(context.Background(), "foo.bar", "bar", false, stream) + stream.Close() + + if len(errs) != 0 { + t.Error(errs[0]) } if !searchMapperCalled { @@ -425,10 +526,22 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }, } - items, err := ags.Search(context.Background(), "foo.bar", "id", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + ags.SearchStream(context.Background(), "foo.bar", "id", false, stream) + stream.Close() + + if len(errs) != 0 { + t.Error(errs[0]) } if len(items) != 1 { @@ -518,81 +631,108 @@ func TestAlwaysGetSourceCaching(t *testing.T) { }) t.Run("list", func(t *testing.T) { - // list - first, err := s.List(ctx, "foo.eu-west-2", false) - if err != nil { - t.Fatal(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + // First query + s.ListStream(ctx, "foo.eu-west-2", false, stream) + // Second time we're expecting caching + s.ListStream(ctx, "foo.eu-west-2", false, stream) + // Third time we're expecting no caching since we asked it to ignore + s.ListStream(ctx, "foo.eu-west-2", true, stream) + stream.Close() + + if len(errs) != 0 { + for _, err := range errs { + t.Error(err) + } + t.Fatal("expected no errors") } - firstGen, err := first[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if len(items) != 3 { + t.Errorf("expected 3 items, got %v", len(items)) } - // list again - withCache, err := s.List(ctx, "foo.eu-west-2", false) + firstGen, err := items[0].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withCacheGen, err := withCache[0].GetAttributes().Get("generation") + withCache, err := items[1].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - - if firstGen != withCacheGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) - } - - // list ignore cache - withoutCache, err := s.List(ctx, "foo.eu-west-2", true) + withoutCache, err := items[2].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withoutCacheGen, err := withoutCache[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if firstGen != withCache { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCache) } - if withoutCacheGen == firstGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + if withoutCache == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCache) } }) t.Run("search", func(t *testing.T) { - // search - first, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) - if err != nil { - t.Fatal(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + // First query + s.SearchStream(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false, stream) + // Second time we're expecting caching + s.SearchStream(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false, stream) + // Third time we're expecting no caching since we asked it to ignore + s.SearchStream(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true, stream) + stream.Close() + + if len(errs) != 0 { + for _, err := range errs { + t.Error(err) + } + t.Fatal("expected no errors") } - firstGen, err := first[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if len(items) != 3 { + t.Errorf("expected 3 items, got %v", len(items)) } - // search again - withCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + firstGen, err := items[0].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withCacheGen, err := withCache[0].GetAttributes().Get("generation") + withCache, err := items[1].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - - if firstGen != withCacheGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) - } - - // search ignore cache - withoutCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true) + withoutCache, err := items[2].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withoutCacheGen, err := withoutCache[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if firstGen != withCache { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCache) } - if withoutCacheGen == firstGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + + if withoutCache == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCache) } }) } diff --git a/adapterhelpers/util.go b/adapterhelpers/util.go index 0d338b06..d490ae1b 100644 --- a/adapterhelpers/util.go +++ b/adapterhelpers/util.go @@ -294,16 +294,40 @@ func (e E2ETest) Run(t *testing.T) { t.Skip("list tests deliberately skipped") } + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + ctx, cancel := context.WithTimeout(context.Background(), e.Timeout) defer cancel() - items, err := e.Adapter.List(ctx, scope, false) - if err != nil { - t.Error(err) + if streamingAdapter, ok := e.Adapter.(discovery.StreamingAdapter); ok { + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + streamingAdapter.ListStream(context.Background(), scope, false, stream) + } else if listableAdapter, ok := e.Adapter.(discovery.ListableAdapter); ok { + var err error + items, err = listableAdapter.List(ctx, scope, false) + + if err != nil { + errs = append(errs, err) + } + } else { + t.Skip("adapter is not listable or streamable") } allNames := make(map[string]bool) + for _, err := range errs { + t.Error(err) + } + for _, item := range items { if _, exists := allNames[item.UniqueAttributeValue()]; exists { t.Errorf("duplicate item found: %v", item.UniqueAttributeValue()) @@ -311,7 +335,7 @@ func (e E2ETest) Run(t *testing.T) { allNames[item.UniqueAttributeValue()] = true } - if err = item.Validate(); err != nil { + if err := item.Validate(); err != nil { t.Error(err) } diff --git a/go.mod b/go.mod index 0f003106..212ae4ad 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/aws/smithy-go v1.22.1 github.com/getsentry/sentry-go v0.30.0 github.com/micahhausler/aws-iam-policy v0.4.2 - github.com/overmindtech/discovery v0.32.2 + github.com/overmindtech/discovery v0.33.0 github.com/overmindtech/sdp-go v0.102.0 github.com/overmindtech/sdpcache v1.6.4 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index 9fb2dbf4..71d2b0a1 100644 --- a/go.sum +++ b/go.sum @@ -156,8 +156,8 @@ github.com/nats-io/nkeys v0.4.8 h1:+wee30071y3vCZAYRsnrmIPaOe47A/SkK/UBDPdIV70= github.com/nats-io/nkeys v0.4.8/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/overmindtech/discovery v0.32.2 h1:e2rU6d2s7zTS31qVI6dTL+/tg2RCg4bgr056mOImP2k= -github.com/overmindtech/discovery v0.32.2/go.mod h1:/mtLqJh6RKdh+V7RzFGhTOSTZTxfpmciBigvyRUSGOQ= +github.com/overmindtech/discovery v0.33.0 h1:2IpIFEhpmxqs8gzen4ltxEbWg26dCh6Ey/OKtZjv1WY= +github.com/overmindtech/discovery v0.33.0/go.mod h1:/mtLqJh6RKdh+V7RzFGhTOSTZTxfpmciBigvyRUSGOQ= github.com/overmindtech/sdp-go v0.102.0 h1:fIJo893+nhr5Wn9HmOM6afXUpUTYuQOJMC6O7Vs9YFY= github.com/overmindtech/sdp-go v0.102.0/go.mod h1:byGP2BXstnX3KeFLNyEbSLE75hv01EZhyU6QiKuiYb8= github.com/overmindtech/sdpcache v1.6.4 h1:MJoYBDqDE3s8FrRzZ0RPgFiH39HWI/Mv2ImH1NdLT8k= From 60587409bac00c02d55cc6331e4ef0eb8d20a539 Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 18:42:03 +0000 Subject: [PATCH 2/7] Moved `DescribeOnlySource` to streaming --- adapterhelpers/describe_source.go | 164 ++++++++------- adapterhelpers/describe_source_test.go | 274 +++++++++++++++++-------- adapters/ecs-capacity-provider_test.go | 19 +- adapters/elbv2-rule_test.go | 39 +++- adapters/ssm-parameter_test.go | 40 +++- 5 files changed, 363 insertions(+), 173 deletions(-) diff --git a/adapterhelpers/describe_source.go b/adapterhelpers/describe_source.go index f3719517..7c43086b 100644 --- a/adapterhelpers/describe_source.go +++ b/adapterhelpers/describe_source.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" "github.com/overmindtech/sdpcache" ) @@ -265,142 +266,130 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Get(ctx cont } // List Lists all items in a given scope -func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) List(ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) ListStream(ctx context.Context, scope string, ignoreCache bool, stream *discovery.QueryResultStream) { if scope != s.Scopes()[0] { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), - } + }) + return } if s.InputMapperList == nil { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOTFOUND, ErrorString: fmt.Sprintf("list is not supported for %v resources", s.ItemType), - } + }) + return } err := s.Validate() if err != nil { - return nil, WrapAWSError(err) + stream.SendError(WrapAWSError(err)) + return } s.ensureCache() cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) if qErr != nil { - return nil, qErr + stream.SendError(qErr) + return } if cacheHit { - return cachedItems, nil + for _, item := range cachedItems { + stream.SendItem(item) + } + return } - var items []*sdp.Item - input, err := s.InputMapperList(scope) if err != nil { err = s.processError(err, ck) - return nil, err + stream.SendError(err) + return } - items, err = s.describe(ctx, input, scope) - if err != nil { - err = s.processError(err, ck) - return nil, err - } - - for _, item := range items { - s.cache.StoreItem(item, s.cacheDuration(), ck) - } - - return items, nil + s.describe(ctx, nil, input, scope, ck, stream) } // Search Searches for AWS resources by ARN -func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Search(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) SearchStream(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { if scope != s.Scopes()[0] { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), - } + }) + return } if s.InputMapperSearch == nil { - return s.searchARN(ctx, scope, query, ignoreCache) + s.searchARN(ctx, scope, query, ignoreCache, stream) } else { - return s.searchCustom(ctx, scope, query, ignoreCache) + s.searchCustom(ctx, scope, query, ignoreCache, stream) } } -func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) searchARN(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) searchARN(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { // Parse the ARN a, err := ParseARN(query) if err != nil { - return nil, WrapAWSError(err) + stream.SendError(WrapAWSError(err)) + return } if a.ContainsWildcard() { // We can't handle wildcards by default so bail out - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOTFOUND, ErrorString: fmt.Sprintf("wildcards are not supported by adapter %v", s.Name()), Scope: scope, - } + }) + return } if arnScope := FormatScope(a.AccountID, a.Region); arnScope != scope { - return nil, &sdp.QueryError{ + stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, ErrorString: fmt.Sprintf("ARN scope %v does not match request scope %v", arnScope, scope), Scope: scope, - } + }) + return } // this already uses the cache, so needs no extra handling item, err := s.Get(ctx, scope, a.ResourceID(), ignoreCache) if err != nil { - return nil, WrapAWSError(err) + stream.SendError(err) + return } - return []*sdp.Item{item}, nil + stream.SendItem(item) } // searchCustom Runs custom search logic using the `InputMapperSearch` function -func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) searchCustom(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { +func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) searchCustom(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { // We need to cache here since this is the only place it'll be called s.ensureCache() cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) if qErr != nil { - return nil, qErr + stream.SendError(qErr) + return } if cacheHit { - return cachedItems, nil + for _, item := range cachedItems { + stream.SendItem(item) + } + return } input, err := s.InputMapperSearch(ctx, s.Client, scope, query) if err != nil { - return nil, WrapAWSError(err) - } - - items, err := s.describe(ctx, input, scope) - if err != nil { - err = s.processError(err, ck) - return nil, err - } - - if s.PostSearchFilter != nil { - items, err = s.PostSearchFilter(ctx, query, items) - if err != nil { - err = s.processError(err, ck) - return nil, err - } + stream.SendError(WrapAWSError(err)) + return } - for _, item := range items { - s.cache.StoreItem(item, s.cacheDuration(), ck) - } - - return items, nil + s.describe(ctx, &query, input, scope, ck, stream) } // Processes an error returned by the AWS API so that it can be handled by @@ -422,43 +411,64 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) processError } // describe Runs describe on the given input, intelligently choosing whether to -// run the paginated or unpaginated query -func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) describe(ctx context.Context, input Input, scope string) ([]*sdp.Item, error) { - var output Output - var err error - var newItems []*sdp.Item - - items := make([]*sdp.Item, 0) - +// run the paginated or unpaginated query. This handles caching, error handling, +// and post-search filtering if the query param is passed +func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) describe(ctx context.Context, query *string, input Input, scope string, ck sdpcache.CacheKey, stream *discovery.QueryResultStream) { if s.Paginated() { paginator := s.PaginatorBuilder(s.Client, input) for paginator.HasMorePages() { - output, err = paginator.NextPage(ctx) + output, err := paginator.NextPage(ctx) if err != nil { - return nil, err + stream.SendError(s.processError(err, ck)) + return } - newItems, err = s.OutputMapper(ctx, s.Client, scope, input, output) + items, err := s.OutputMapper(ctx, s.Client, scope, input, output) if err != nil { - return nil, err + stream.SendError(s.processError(err, ck)) + return + } + + if query != nil && s.PostSearchFilter != nil { + items, err = s.PostSearchFilter(ctx, *query, items) + if err != nil { + stream.SendError(s.processError(err, ck)) + return + } } - items = append(items, newItems...) + for _, item := range items { + s.cache.StoreItem(item, s.cacheDuration(), ck) + stream.SendItem(item) + } } } else { - output, err = s.DescribeFunc(ctx, s.Client, input) + output, err := s.DescribeFunc(ctx, s.Client, input) if err != nil { - return nil, err + stream.SendError(s.processError(err, ck)) + return } - items, err = s.OutputMapper(ctx, s.Client, scope, input, output) + items, err := s.OutputMapper(ctx, s.Client, scope, input, output) if err != nil { - return nil, err + stream.SendError(s.processError(err, ck)) + return } - } - return items, nil + if query != nil && s.PostSearchFilter != nil { + items, err = s.PostSearchFilter(ctx, *query, items) + if err != nil { + stream.SendError(s.processError(err, ck)) + return + } + } + + for _, item := range items { + s.cache.StoreItem(item, s.cacheDuration(), ck) + stream.SendItem(item) + } + } } // Weight Returns the priority weighting of items returned by this adapter. diff --git a/adapterhelpers/describe_source_test.go b/adapterhelpers/describe_source_test.go index cd486153..f5e73e76 100644 --- a/adapterhelpers/describe_source_test.go +++ b/adapterhelpers/describe_source_test.go @@ -8,6 +8,7 @@ import ( "regexp" "testing" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/structpb" @@ -254,11 +255,22 @@ func TestSearchARN(t *testing.T) { return "fancy", nil }, } + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - items, err := s.Search(context.Background(), "account-id.region", "arn:partition:service:region:account-id:resource-type:resource-id", false) + s.SearchStream(context.Background(), "account-id.region", "arn:partition:service:region:account-id:resource-type:resource-id", false, stream) + stream.Close() - if err != nil { - t.Error(err) + if len(errs) > 0 { + t.Error(errs) } if len(items) != 1 { @@ -298,11 +310,22 @@ func TestSearchCustom(t *testing.T) { return input, nil }, } + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - items, err := s.Search(context.Background(), "account-id.region", "foo", false) + s.SearchStream(context.Background(), "account-id.region", "foo", false, stream) + stream.Close() - if err != nil { - t.Fatal(err) + if len(errs) > 0 { + t.Error(errs) } if len(items) != 1 { @@ -317,11 +340,22 @@ func TestSearchCustom(t *testing.T) { s.PostSearchFilter = func(ctx context.Context, query string, items []*sdp.Item) ([]*sdp.Item, error) { return nil, nil } + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - items, err := s.Search(context.Background(), "account-id.region", "bar", false) + s.SearchStream(context.Background(), "account-id.region", "bar", false, stream) + stream.Close() - if err != nil { - t.Fatal(err) + if len(errs) > 0 { + t.Error(errs) } if len(items) != 0 { @@ -353,10 +387,18 @@ func TestNoInputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2", false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err == nil { - t.Error("expected error but got nil") + if len(errs) == 0 { + t.Error("expected error but got none") } }) } @@ -385,10 +427,18 @@ func TestNoOutputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2", false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err == nil { - t.Error("expected error but got nil") + if len(errs) == 0 { + t.Error("expected error but got none") } }) } @@ -419,10 +469,18 @@ func TestNoDescribeFunc(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2", false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err == nil { - t.Error("expected error but got nil") + if len(errs) == 0 { + t.Error("expected error but got none") } }) } @@ -462,14 +520,22 @@ func TestFailingInputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2", false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err == nil { - t.Error("expected error but got nil") + if len(errs) == 0 { + t.Error("expected error but got none") } - if !fooBar.MatchString(err.Error()) { - t.Errorf("expected error string '%v' to contain foobar", err.Error()) + if !fooBar.MatchString(errs[0].Error()) { + t.Errorf("expected error string '%v' to contain foobar", errs[0].Error()) } }) } @@ -507,14 +573,22 @@ func TestFailingOutputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2", false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err == nil { - t.Error("expected error but got nil") + if len(errs) == 0 { + t.Error("expected error but got none") } - if !fooBar.MatchString(err.Error()) { - t.Errorf("expected error string '%v' to contain foobar", err.Error()) + if !fooBar.MatchString(errs[0].Error()) { + t.Errorf("expected error string '%v' to contain foobar", errs[0].Error()) } }) } @@ -554,14 +628,22 @@ func TestFailingDescribeFunc(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2", false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err == nil { - t.Error("expected error but got nil") + if len(errs) == 0 { + t.Error("expected error but got none") } - if !fooBar.MatchString(err.Error()) { - t.Errorf("expected error string '%v' to contain foobar", err.Error()) + if !fooBar.MatchString(errs[0].Error()) { + t.Errorf("expected error string '%v' to contain foobar", errs[0].Error()) } }) } @@ -624,10 +706,21 @@ func TestPaginated(t *testing.T) { }) t.Run("paginating a List query", func(t *testing.T) { - items, err := s.List(context.Background(), "foo.eu-west-2", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + s.ListStream(context.Background(), "foo.eu-west-2", false, stream) + stream.Close() - if err != nil { - t.Error(err) + if len(errs) > 0 { + t.Error(errs) } if len(items) != 3 { @@ -722,81 +815,102 @@ func TestDescribeOnlySourceCaching(t *testing.T) { }) t.Run("list", func(t *testing.T) { - // list - first, err := s.List(ctx, "foo.eu-west-2", false) - if err != nil { - t.Fatal(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + // Fist list + s.ListStream(ctx, "foo.eu-west-2", false, stream) + // List again, expect caching + s.ListStream(ctx, "foo.eu-west-2", false, stream) + // List again, ignore cache + s.ListStream(ctx, "foo.eu-west-2", true, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } - firstGen, err := first[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if len(items) != 3 { + t.Fatalf("expected 3 items, got %v", len(items)) } - // list again - withCache, err := s.List(ctx, "foo.eu-west-2", false) + firstGen, err := items[0].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withCacheGen, err := withCache[0].GetAttributes().Get("generation") + withCache, err := items[1].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - - if firstGen != withCacheGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) - } - - // list ignore cache - withoutCache, err := s.List(ctx, "foo.eu-west-2", true) + withoutCache, err := items[2].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withoutCacheGen, err := withoutCache[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if firstGen != withCache { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCache) } - if withoutCacheGen == firstGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + if withoutCache == firstGen { + t.Errorf("without cache: expected generation %v, got %v", firstGen, withoutCache) } }) t.Run("search", func(t *testing.T) { - // search - first, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) - if err != nil { - t.Fatal(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + // First time + s.SearchStream(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false, stream) + // Search again, expect caching + s.SearchStream(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false, stream) + // Search again, ignore cache + s.SearchStream(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } - firstGen, err := first[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if len(items) != 3 { + t.Fatalf("expected 3 items, got %v", len(items)) } - // search again - withCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + firstGen, err := items[0].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withCacheGen, err := withCache[0].GetAttributes().Get("generation") + withCache, err := items[1].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - - if firstGen != withCacheGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) - } - - // search ignore cache - withoutCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true) + withoutCache, err := items[2].GetAttributes().Get("generation") if err != nil { t.Fatal(err) } - withoutCacheGen, err := withoutCache[0].GetAttributes().Get("generation") - if err != nil { - t.Fatal(err) + + if firstGen != withCache { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCache) } - if withoutCacheGen == firstGen { - t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + + if withoutCache == firstGen { + t.Errorf("without cache: expected generation %v, got %v", firstGen, withoutCache) } }) } diff --git a/adapters/ecs-capacity-provider_test.go b/adapters/ecs-capacity-provider_test.go index fc00fe07..1f504479 100644 --- a/adapters/ecs-capacity-provider_test.go +++ b/adapters/ecs-capacity-provider_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) @@ -126,10 +127,22 @@ func TestCapacityProviderOutputMapper(t *testing.T) { func TestCapacityProviderAdapter(t *testing.T) { adapter := NewECSCapacityProviderAdapter(&ecsTestClient{}, "", "") - items, err := adapter.List(context.Background(), "", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + adapter.ListStream(context.Background(), "", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } if len(items) != 3 { diff --git a/adapters/elbv2-rule_test.go b/adapters/elbv2-rule_test.go index da7bb8aa..4e4bfb9b 100644 --- a/adapters/elbv2-rule_test.go +++ b/adapters/elbv2-rule_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) @@ -102,9 +103,22 @@ func TestNewELBv2RuleAdapter(t *testing.T) { listenerSource := NewELBv2ListenerAdapter(client, account, region) ruleSource := NewELBv2RuleAdapter(client, account, region) - lbs, err := lbSource.List(context.Background(), lbSource.Scopes()[0], false) - if err != nil { - t.Fatal(err) + lbs := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + lbs = append(lbs, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + lbSource.ListStream(context.Background(), lbSource.Scopes()[0], false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } if len(lbs) == 0 { @@ -116,9 +130,22 @@ func TestNewELBv2RuleAdapter(t *testing.T) { t.Fatal(err) } - listeners, err := listenerSource.Search(context.Background(), listenerSource.Scopes()[0], fmt.Sprint(lbARN), false) - if err != nil { - t.Fatal(err) + listeners := make([]*sdp.Item, 0) + errs = make([]error, 0) + stream = discovery.NewQueryResultStream( + func(item *sdp.Item) { + listeners = append(listeners, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + listenerSource.SearchStream(context.Background(), listenerSource.Scopes()[0], fmt.Sprint(lbARN), false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } if len(listeners) == 0 { diff --git a/adapters/ssm-parameter_test.go b/adapters/ssm-parameter_test.go index 2f3fbdab..b36bfe92 100644 --- a/adapters/ssm-parameter_test.go +++ b/adapters/ssm-parameter_test.go @@ -9,6 +9,8 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/aws/aws-sdk-go/aws" "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/discovery" + "github.com/overmindtech/sdp-go" ) type mockSSMClient struct { @@ -85,27 +87,51 @@ func TestSSMParameterAdapter(t *testing.T) { }) t.Run("List", func(t *testing.T) { - items, err := adapter.List(context.Background(), "123456789.us-east-1", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Fatal(err) + adapter.ListStream(context.Background(), "123456789.us-east-1", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } if len(items) != 1 { t.Errorf("expected 1 item, got %d", len(items)) } - err = items[0].Validate() + err := items[0].Validate() if err != nil { t.Error(err) } }) t.Run("Search", func(t *testing.T) { - items, err := adapter.Search(context.Background(), "123456789.us-east-1", "arn:aws:ssm:us-east-1:1234567890:parameter/prod/*/service/example-service", false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Fatal(err) + adapter.SearchStream(context.Background(), "123456789.us-east-1", "arn:aws:ssm:us-east-1:1234567890:parameter/prod/*/service/example-service", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } if len(items) != 0 { From bda50d209750329d63304518f0a8c3f8cb67318d Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 22:32:16 +0000 Subject: [PATCH 3/7] Created new GetListAdapter that supports streaming --- adapterhelpers/get_list_adapter_v2.go | 361 ++++++++++++++ adapterhelpers/get_list_adapter_v2_test.go | 444 ++++++++++++++++++ adapterhelpers/sources.go | 9 +- adapters/iam-group.go | 21 +- adapters/iam-group_test.go | 2 +- adapters/iam-instance-profile.go | 21 +- adapters/iam-instance-profile_test.go | 2 +- adapters/iam-policy.go | 132 ++---- adapters/iam-policy_test.go | 108 +++-- adapters/iam-role.go | 100 ++-- adapters/iam-role_test.go | 77 ++- adapters/iam-user.go | 82 ++-- adapters/iam-user_test.go | 36 +- adapters/integration/ec2/instance_test.go | 49 +- adapters/integration/kms/kms_test.go | 59 ++- .../networkmanager/networkmanager_test.go | 63 ++- adapters/integration/ssm/main_test.go | 22 +- 17 files changed, 1296 insertions(+), 292 deletions(-) create mode 100644 adapterhelpers/get_list_adapter_v2.go create mode 100644 adapterhelpers/get_list_adapter_v2_test.go diff --git a/adapterhelpers/get_list_adapter_v2.go b/adapterhelpers/get_list_adapter_v2.go new file mode 100644 index 00000000..b45156fd --- /dev/null +++ b/adapterhelpers/get_list_adapter_v2.go @@ -0,0 +1,361 @@ +package adapterhelpers + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/overmindtech/discovery" + "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdpcache" +) + +// GetListAdapterV2 A adapter for AWS APIs where the Get and List functions both +// return the full item, such as many of the IAM APIs. This version supports +// paginated APIs and streaming results. +type GetListAdapterV2[ListInput InputType, ListOutput OutputType, AWSItem AWSItemType, ClientStruct ClientStructType, Options OptionsType] struct { + ItemType string // The type of items that will be returned + Client ClientStruct // The AWS API client + AccountID string // The AWS account ID + Region string // The AWS region this is related to + SupportGlobalResources bool // If true, this will also support resources in the "aws" scope which are global + AdapterMetadata *sdp.AdapterMetadata + + CacheDuration time.Duration // How long to cache items for + cache *sdpcache.Cache // The sdpcache of this adapter + cacheInitMu sync.Mutex // Mutex to ensure cache is only initialised once + + // Disables List(), meaning all calls will return empty results. This does + // not affect Search() + DisableList bool + + // GetFunc Gets the details of a specific item, returns the AWS + // representation of that item, and an error + GetFunc func(ctx context.Context, client ClientStruct, scope string, query string) (AWSItem, error) + + // A function that returns the input object that will be passed to + // ListFunc for a LIST request + InputMapperList func(scope string) (ListInput, error) + + // ListFunc Lists all items that it can find this should be used only if the + // API does not have a paginator, otherwise use ListFuncPaginatorBuilder + ListFunc func(ctx context.Context, client ClientStruct, input ListInput) (ListOutput, error) + + // A function that returns a paginator for this API. If this is nil, we will + // assume that the API is not paginated e.g. + // https://aws.github.io/aws-sdk-go-v2/docs/making-requests/#using-paginators + // + // If this is set then ListFunc will be ignored + ListFuncPaginatorBuilder func(client ClientStruct, params ListInput) Paginator[ListOutput, Options] + + // Extracts the list of items from the output of the ListFunc, these will be + // passed to the ItemMapper for conversion to SDP items + ListExtractor func(ctx context.Context, output ListOutput, client ClientStruct) ([]AWSItem, error) + + // NOTE + // + // This does not yet support custom searching, this will be added in a + // future version + + // ItemMapper Maps an AWS representation of an item to the SDP version, the + // query will be nil if the method was LIST + ItemMapper func(query *string, scope string, awsItem AWSItem) (*sdp.Item, error) + + // ListTagsFunc Optional function that will be used to list tags for a + // resource + ListTagsFunc func(context.Context, AWSItem, ClientStruct) (map[string]string, error) +} + +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) cacheDuration() time.Duration { + if s.CacheDuration == 0 { + return DefaultCacheDuration + } + + return s.CacheDuration +} + +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) ensureCache() { + s.cacheInitMu.Lock() + defer s.cacheInitMu.Unlock() + + if s.cache == nil { + s.cache = sdpcache.NewCache() + } +} + +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Cache() *sdpcache.Cache { + s.ensureCache() + return s.cache +} + +// Validate Checks that the adapter has been set up correctly +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Validate() error { + if s.GetFunc == nil { + return errors.New("GetFunc is nil") + } + + if !s.DisableList { + if s.ListFunc == nil && s.ListFuncPaginatorBuilder == nil { + return errors.New("ListFunc and ListFuncPaginatorBuilder are nil") + } + + if s.ListExtractor == nil { + return errors.New("ListExtractor is nil") + } + + if s.InputMapperList == nil { + return errors.New("InputMapperList is nil") + } + } + + if s.ItemMapper == nil { + return errors.New("ItemMapper is nil") + } + + return nil +} + +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Type() string { + return s.ItemType +} + +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Name() string { + return fmt.Sprintf("%v-adapter", s.ItemType) +} + +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Metadata() *sdp.AdapterMetadata { + return s.AdapterMetadata +} + +// List of scopes that this adapter is capable of find items for. This will be +// in the format {accountID}.{region} +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Scopes() []string { + scopes := make([]string, 0) + + scopes = append(scopes, FormatScope(s.AccountID, s.Region)) + + if s.SupportGlobalResources { + scopes = append(scopes, "aws") + } + + return scopes +} + +// hasScope Returns whether or not this adapter has the given scope +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) hasScope(scope string) bool { + if scope == "aws" && s.SupportGlobalResources { + // There is a special global "account" that is used for global resources + // called "aws" + return true + } + + for _, s := range s.Scopes() { + if s == scope { + return true + } + } + + return false +} + +// Get retrieves an item from the adapter based on the provided scope, query, and +// cache settings. It uses the defined `GetFunc`, `ItemMapper`, and +// `ListTagsFunc` to retrieve and map the item. +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Get(ctx context.Context, scope string, query string, ignoreCache bool) (*sdp.Item, error) { + if !s.hasScope(scope) { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOSCOPE, + ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), + } + } + + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + if len(cachedItems) == 0 { + return nil, nil + } else { + return cachedItems[0], nil + } + } + + awsItem, err := s.GetFunc(ctx, s.Client, scope, query) + if err != nil { + err := WrapAWSError(err) + if !CanRetry(err) { + s.cache.StoreError(err, s.cacheDuration(), ck) + } + return nil, err + } + + item, err := s.ItemMapper(&query, scope, awsItem) + if err != nil { + // Don't cache this as wrapping is very cheap and better to just try + // again than store in memory + return nil, WrapAWSError(err) + } + + if s.ListTagsFunc != nil { + item.Tags, err = s.ListTagsFunc(ctx, awsItem, s.Client) + if err != nil { + item.Tags = HandleTagsError(ctx, err) + } + } + + s.cache.StoreItem(item, s.cacheDuration(), ck) + + return item, nil +} + +// List Lists all available items. This is done by running the ListFunc, then +// passing these results to GetFunc in order to get the details +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) ListStream(ctx context.Context, scope string, ignoreCache bool, stream *discovery.QueryResultStream) { + if !s.hasScope(scope) { + stream.SendError(&sdp.QueryError{ + ErrorType: sdp.QueryError_NOSCOPE, + ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), + }) + return + } + + if s.DisableList { + return + } + + if err := s.Validate(); err != nil { + stream.SendError(&sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: err.Error(), + }) + return + } + + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + if qErr != nil { + stream.SendError(qErr) + return + } + if cacheHit { + for _, item := range cachedItems { + stream.SendItem(item) + } + return + } + + listInput, err := s.InputMapperList(scope) + if err != nil { + stream.SendError(WrapAWSError(err)) + return + } + + // Define the function to send the outputs + sendOutputs := func(out ListOutput) { + // Extract the items in the correct format + awsItems, err := s.ListExtractor(ctx, out, s.Client) + if err != nil { + stream.SendError(WrapAWSError(err)) + return + } + + // Map the items to SDP items, send on the stream, and save to the + // cache + for _, awsItem := range awsItems { + item, err := s.ItemMapper(nil, scope, awsItem) + if err != nil { + stream.SendError(WrapAWSError(err)) + continue + } + + if s.ListTagsFunc != nil { + item.Tags, err = s.ListTagsFunc(ctx, awsItem, s.Client) + if err != nil { + item.Tags = HandleTagsError(ctx, err) + } + } + + stream.SendItem(item) + s.cache.StoreItem(item, s.cacheDuration(), ck) + } + } + + // See if this is paginated or not and use the appropriate method + if s.ListFuncPaginatorBuilder != nil { + paginator := s.ListFuncPaginatorBuilder(s.Client, listInput) + + for paginator.HasMorePages() { + out, err := paginator.NextPage(ctx) + if err != nil { + stream.SendError(WrapAWSError(err)) + return + } + + sendOutputs(out) + } + } else if s.ListFunc != nil { + out, err := s.ListFunc(ctx, s.Client, listInput) + if err != nil { + stream.SendError(WrapAWSError(err)) + return + } + + sendOutputs(out) + } +} + +// Search Searches for AWS resources, this can be implemented either as a +// generic ARN search that tries to extract the globally unique name from the +// ARN and pass this to a Get request, or a custom search function that can be +// used to search for items in a different, adapter-specific way +func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) SearchStream(ctx context.Context, scope string, query string, ignoreCache bool, stream *discovery.QueryResultStream) { + if !s.hasScope(scope) { + stream.SendError(&sdp.QueryError{ + ErrorType: sdp.QueryError_NOSCOPE, + ErrorString: fmt.Sprintf("requested scope %v does not match adapter scope %v", scope, s.Scopes()[0]), + }) + return + } + + // Parse the ARN + a, err := ParseARN(query) + if err != nil { + stream.SendError(WrapAWSError(err)) + return + } + + if a.ContainsWildcard() { + stream.SendError(&sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: fmt.Sprintf("wildcards are not supported by adapter %v", s.Name()), + Scope: scope, + }) + return + } + + if arnScope := FormatScope(a.AccountID, a.Region); !s.hasScope(arnScope) { + stream.SendError(&sdp.QueryError{ + ErrorType: sdp.QueryError_NOSCOPE, + ErrorString: fmt.Sprintf("ARN scope %v does not match request scope %v", arnScope, scope), + Scope: scope, + }) + return + } + + // Since this gits the Get method, and this method implements caching, we + // don't need to implement it here + item, err := s.Get(ctx, scope, a.ResourceID(), ignoreCache) + + if err != nil { + stream.SendError(err) + return + } + + if item != nil { + stream.SendItem(item) + } +} diff --git a/adapterhelpers/get_list_adapter_v2_test.go b/adapterhelpers/get_list_adapter_v2_test.go new file mode 100644 index 00000000..8356e3c0 --- /dev/null +++ b/adapterhelpers/get_list_adapter_v2_test.go @@ -0,0 +1,444 @@ +package adapterhelpers + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/overmindtech/discovery" + "github.com/overmindtech/sdp-go" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestGetListAdapterV2Type(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "foo", + } + + if s.Type() != "foo" { + t.Errorf("expected type to be foo got %v", s.Type()) + } +} + +func TestGetListAdapterV2Name(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "foo", + } + + if s.Name() != "foo-adapter" { + t.Errorf("expected type to be foo-adapter got %v", s.Name()) + } +} + +func TestGetListAdapterV2Scopes(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + AccountID: "foo", + Region: "bar", + } + + if s.Scopes()[0] != "foo.bar" { + t.Errorf("expected scope to be foo.bar, got %v", s.Scopes()[0]) + } +} + +func TestGetListAdapterV2Get(t *testing.T) { + t.Run("with no errors", func(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "person", + Region: "eu-west-2", + AccountID: "12345", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", nil + }, + ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { + return &sdp.Item{}, nil + }, + ListTagsFunc: func(ctx context.Context, s1 string, s2 struct{}) (map[string]string, error) { + return map[string]string{ + "foo": "bar", + }, nil + }, + } + + item, err := s.Get(context.Background(), "12345.eu-west-2", "", false) + if err != nil { + t.Error(err) + } + + if item.GetTags()["foo"] != "bar" { + t.Errorf("expected tag foo to be bar, got %v", item.GetTags()["foo"]) + } + }) + + t.Run("with an error in the GetFunc", func(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "person", + Region: "eu-west-2", + AccountID: "12345", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", errors.New("get func error") + }, + ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { + return &sdp.Item{}, nil + }, + } + + if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { + t.Error("expected error got nil") + } + }) + + t.Run("with an error in the mapper", func(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "person", + Region: "eu-west-2", + AccountID: "12345", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", nil + }, + ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { + return &sdp.Item{}, errors.New("mapper error") + }, + } + + if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { + t.Error("expected error got nil") + } + }) +} + +func TestGetListAdapterV2ListStream(t *testing.T) { + t.Run("with no errors", func(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "person", + Region: "eu-west-2", + AccountID: "12345", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", nil + }, + ListFunc: func(ctx context.Context, client struct{}, input string) ([]string, error) { + return []string{"one", "two"}, nil + }, + ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { + return &sdp.Item{}, nil + }, + ListExtractor: func(ctx context.Context, output []string, client struct{}) ([]string, error) { + return output, nil + }, + ListTagsFunc: func(ctx context.Context, s1 string, s2 struct{}) (map[string]string, error) { + return map[string]string{ + "foo": "bar", + }, nil + }, + InputMapperList: func(scope string) (string, error) { + return "input", nil + }, + } + + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + s.ListStream(context.Background(), "12345.eu-west-2", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) + } + + if len(items) != 2 { + t.Errorf("expected 2 items, got %v", len(items)) + } + }) + + t.Run("with an error in the ListFunc", func(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "person", + Region: "eu-west-2", + AccountID: "12345", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", nil + }, + ListFunc: func(ctx context.Context, client struct{}, scope string) ([]string, error) { + return []string{"", ""}, errors.New("list func error") + }, + ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { + return &sdp.Item{}, nil + }, + } + + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + + s.ListStream(context.Background(), "12345.eu-west-2", false, stream) + stream.Close() + + if len(errs) == 0 { + t.Error("expected errors got none") + } + }) + + t.Run("with an error in the mapper", func(t *testing.T) { + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "person", + Region: "eu-west-2", + AccountID: "12345", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", nil + }, + ListExtractor: func(ctx context.Context, output []string, client struct{}) ([]string, error) { + return output, nil + }, + ListFunc: func(ctx context.Context, client struct{}, scope string) ([]string, error) { + return []string{"", ""}, nil + }, + ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { + return &sdp.Item{}, errors.New("mapper error") + }, + InputMapperList: func(scope string) (string, error) { + return "input", nil + }, + } + + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + s.ListStream(context.Background(), "12345.eu-west-2", false, stream) + stream.Close() + + if len(errs) != 2 { + t.Errorf("expected 2 errors got %v", len(errs)) + } + + if len(items) != 0 { + t.Errorf("expected no items, got %v", len(items)) + } + }) +} + +// MockPaginator is a mock implementation of the Paginator interface +type MockPaginator struct { + pages [][]string + pageIdx int + hasPages bool +} + +func (p *MockPaginator) HasMorePages() bool { + return p.hasPages && p.pageIdx < len(p.pages) +} + +func (p *MockPaginator) NextPage(ctx context.Context, opts ...func(struct{})) ([]string, error) { + if !p.HasMorePages() { + return nil, errors.New("no more pages available") + } + page := p.pages[p.pageIdx] + p.pageIdx++ + return page, nil +} + +func TestListFuncPaginatorBuilder(t *testing.T) { + adapter := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "test-item", + AccountID: "foo", + Region: "eu-west-2", + Client: struct{}{}, + InputMapperList: func(scope string) (string, error) { + return "test-input", nil + }, + ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[[]string, struct{}] { + return &MockPaginator{ + pages: [][]string{ + {"item1", "item2"}, + {"item3", "item4"}, + }, + hasPages: true, + } + }, + ListExtractor: func(ctx context.Context, output []string, client struct{}) ([]string, error) { + return output, nil + }, + ItemMapper: func(query *string, scope string, awsItem string) (*sdp.Item, error) { + attrs, _ := sdp.ToAttributes(map[string]interface{}{ + "id": awsItem, + }) + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "id", + Attributes: attrs, + Scope: scope, + }, nil + }, + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + return "", nil + }, + } + + ctx := context.Background() + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(ctx, "foo.eu-west-2", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) + } + + if len(items) != 4 { + t.Errorf("expected 4 items, got %v", len(items)) + } + +} + +func TestGetListAdapterV2Caching(t *testing.T) { + ctx := context.Background() + generation := 0 + s := GetListAdapterV2[string, []string, string, struct{}, struct{}]{ + ItemType: "test-type", + Region: "eu-west-2", + AccountID: "foo", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + generation += 1 + return fmt.Sprintf("%v", generation), nil + }, + ListFunc: func(ctx context.Context, client struct{}, scope string) ([]string, error) { + generation += 1 + return []string{fmt.Sprintf("%v", generation)}, nil + }, + ListExtractor: func(ctx context.Context, output []string, client struct{}) ([]string, error) { + return output, nil + }, + InputMapperList: func(scope string) (string, error) { + return "input", nil + }, + ItemMapper: func(query *string, scope string, output string) (*sdp.Item, error) { + return &sdp.Item{ + Scope: "foo.eu-west-2", + Type: "test-type", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{ + AttrStruct: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "name": structpb.NewStringValue("test-item"), + "generation": structpb.NewStringValue(output), + }, + }, + }, + }, nil + }, + } + + t.Run("get", func(t *testing.T) { + // get + first, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first.GetAttributes().Get("generation") + if err != nil { + t.Fatal(err) + } + + // get again + withCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache.GetAttributes().Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // get ignore cache + withoutCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache.GetAttributes().Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("list", func(t *testing.T) { + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + // First call + s.ListStream(ctx, "foo.eu-west-2", false, stream) + // Second call with caching + s.ListStream(ctx, "foo.eu-west-2", false, stream) + // Third call without caching + s.ListStream(ctx, "foo.eu-west-2", true, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) + } + + firstGen, err := items[0].GetAttributes().Get("generation") + if err != nil { + t.Fatal(err) + } + withCacheGen, err := items[1].GetAttributes().Get("generation") + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := items[2].GetAttributes().Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) +} diff --git a/adapterhelpers/sources.go b/adapterhelpers/sources.go index 2c0fcf10..585baa4a 100644 --- a/adapterhelpers/sources.go +++ b/adapterhelpers/sources.go @@ -12,13 +12,14 @@ const DefaultMaxResultsPerPage = 100 // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/eks@v1.26.0#Client type ClientStructType any -// InputType is the type of data that will be sent to the DesribeFunc. This is -// typically a struct ending with the word Input such as: +// InputType is the type of data that will be sent to the a List/Describe +// function. This is typically a struct ending with the word Input such as: // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/eks@v1.26.0#DescribeClusterInput type InputType any -// OutputType is the type of output to expect from the DescribeFunc, this is -// usually named the same as the input type, but with `Output` on the end e.g. +// OutputType is the type of output to expect from the List/Describe function, +// this is usually named the same as the input type, but with `Output` on the +// end e.g. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/eks@v1.26.0#DescribeClusterOutput type OutputType any diff --git a/adapters/iam-group.go b/adapters/iam-group.go index 0a543209..ee39ca1a 100644 --- a/adapters/iam-group.go +++ b/adapters/iam-group.go @@ -39,7 +39,7 @@ func groupListFunc(ctx context.Context, client *iam.Client, _ string) ([]*types. return zones, nil } -func groupItemMapper(_, scope string, awsItem *types.Group) (*sdp.Item, error) { +func groupItemMapper(_ *string, scope string, awsItem *types.Group) (*sdp.Item, error) { attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem) if err != nil { @@ -56,18 +56,29 @@ func groupItemMapper(_, scope string, awsItem *types.Group) (*sdp.Item, error) { return &item, nil } -func NewIAMGroupAdapter(client *iam.Client, accountID string, region string) *adapterhelpers.GetListAdapter[*types.Group, *iam.Client, *iam.Options] { - return &adapterhelpers.GetListAdapter[*types.Group, *iam.Client, *iam.Options]{ +func NewIAMGroupAdapter(client *iam.Client, accountID string, region string) *adapterhelpers.GetListAdapterV2[*iam.ListGroupsInput, *iam.ListGroupsOutput, *types.Group, *iam.Client, *iam.Options] { + return &adapterhelpers.GetListAdapterV2[*iam.ListGroupsInput, *iam.ListGroupsOutput, *types.Group, *iam.Client, *iam.Options]{ ItemType: "iam-group", Client: client, CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time AccountID: accountID, + Region: region, AdapterMetadata: iamGroupAdapterMetadata, GetFunc: func(ctx context.Context, client *iam.Client, scope, query string) (*types.Group, error) { return groupGetFunc(ctx, client, scope, query) }, - ListFunc: func(ctx context.Context, client *iam.Client, scope string) ([]*types.Group, error) { - return groupListFunc(ctx, client, scope) + InputMapperList: func(scope string) (*iam.ListGroupsInput, error) { + return &iam.ListGroupsInput{}, nil + }, + ListFuncPaginatorBuilder: func(client *iam.Client, params *iam.ListGroupsInput) adapterhelpers.Paginator[*iam.ListGroupsOutput, *iam.Options] { + return iam.NewListGroupsPaginator(client, params) + }, + ListExtractor: func(_ context.Context, output *iam.ListGroupsOutput, _ *iam.Client) ([]*types.Group, error) { + groups := make([]*types.Group, 0, len(output.Groups)) + for i := range output.Groups { + groups = append(groups, &output.Groups[i]) + } + return groups, nil }, ItemMapper: groupItemMapper, } diff --git a/adapters/iam-group_test.go b/adapters/iam-group_test.go index 1ddad533..1cb3c722 100644 --- a/adapters/iam-group_test.go +++ b/adapters/iam-group_test.go @@ -19,7 +19,7 @@ func TestGroupItemMapper(t *testing.T) { CreateDate: adapterhelpers.PtrTime(time.Now()), } - item, err := groupItemMapper("", "foo", &zone) + item, err := groupItemMapper(nil, "foo", &zone) if err != nil { t.Error(err) diff --git a/adapters/iam-instance-profile.go b/adapters/iam-instance-profile.go index 477d5afc..f28a8b8c 100644 --- a/adapters/iam-instance-profile.go +++ b/adapters/iam-instance-profile.go @@ -39,7 +39,7 @@ func instanceProfileListFunc(ctx context.Context, client *iam.Client, _ string) return zones, nil } -func instanceProfileItemMapper(_, scope string, awsItem *types.InstanceProfile) (*sdp.Item, error) { +func instanceProfileItemMapper(_ *string, scope string, awsItem *types.InstanceProfile) (*sdp.Item, error) { attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem) if err != nil { @@ -118,18 +118,29 @@ func instanceProfileListTagsFunc(ctx context.Context, ip *types.InstanceProfile, return tags } -func NewIAMInstanceProfileAdapter(client *iam.Client, accountID string, region string) *adapterhelpers.GetListAdapter[*types.InstanceProfile, *iam.Client, *iam.Options] { - return &adapterhelpers.GetListAdapter[*types.InstanceProfile, *iam.Client, *iam.Options]{ +func NewIAMInstanceProfileAdapter(client *iam.Client, accountID string, region string) *adapterhelpers.GetListAdapterV2[*iam.ListInstanceProfilesInput, *iam.ListInstanceProfilesOutput, *types.InstanceProfile, *iam.Client, *iam.Options] { + return &adapterhelpers.GetListAdapterV2[*iam.ListInstanceProfilesInput, *iam.ListInstanceProfilesOutput, *types.InstanceProfile, *iam.Client, *iam.Options]{ ItemType: "iam-instance-profile", Client: client, CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time AccountID: accountID, + Region: region, AdapterMetadata: instanceProfileAdapterMetadata, GetFunc: func(ctx context.Context, client *iam.Client, scope, query string) (*types.InstanceProfile, error) { return instanceProfileGetFunc(ctx, client, scope, query) }, - ListFunc: func(ctx context.Context, client *iam.Client, scope string) ([]*types.InstanceProfile, error) { - return instanceProfileListFunc(ctx, client, scope) + InputMapperList: func(scope string) (*iam.ListInstanceProfilesInput, error) { + return &iam.ListInstanceProfilesInput{}, nil + }, + ListFuncPaginatorBuilder: func(client *iam.Client, params *iam.ListInstanceProfilesInput) adapterhelpers.Paginator[*iam.ListInstanceProfilesOutput, *iam.Options] { + return iam.NewListInstanceProfilesPaginator(client, params) + }, + ListExtractor: func(_ context.Context, output *iam.ListInstanceProfilesOutput, _ *iam.Client) ([]*types.InstanceProfile, error) { + profiles := make([]*types.InstanceProfile, 0, len(output.InstanceProfiles)) + for i := range output.InstanceProfiles { + profiles = append(profiles, &output.InstanceProfiles[i]) + } + return profiles, nil }, ListTagsFunc: func(ctx context.Context, ip *types.InstanceProfile, c *iam.Client) (map[string]string, error) { return instanceProfileListTagsFunc(ctx, ip, c), nil diff --git a/adapters/iam-instance-profile_test.go b/adapters/iam-instance-profile_test.go index be6dc0ca..6958e651 100644 --- a/adapters/iam-instance-profile_test.go +++ b/adapters/iam-instance-profile_test.go @@ -39,7 +39,7 @@ func TestInstanceProfileItemMapper(t *testing.T) { }, } - item, err := instanceProfileItemMapper("", "foo", &profile) + item, err := instanceProfileItemMapper(nil, "foo", &profile) if err != nil { t.Error(err) diff --git a/adapters/iam-policy.go b/adapters/iam-policy.go index 47f0ca75..2a5dcb26 100644 --- a/adapters/iam-policy.go +++ b/adapters/iam-policy.go @@ -16,7 +16,6 @@ import ( "github.com/overmindtech/sdp-go" log "github.com/sirupsen/logrus" "github.com/sourcegraph/conc/iter" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) @@ -132,74 +131,7 @@ func addPolicyEntities(ctx context.Context, client IAMClient, details *PolicyDet return nil } -// PolicyListFunc Lists all attached policies. There is no way to list -// unattached policies since I don't think it will be very valuable, there are -// hundreds by default and if you aren't using them they aren't very interesting -func policyListFunc(ctx context.Context, client IAMClient, scope string) ([]*PolicyDetails, error) { - var span trace.Span - if log.GetLevel() == log.TraceLevel { - // Only create new spans on trace level logging - ctx, span = tracer.Start(ctx, "policyListFunc") - defer span.End() - } else { - span = trace.SpanFromContext(ctx) - } - - policies := make([]types.Policy, 0) - - var iamScope types.PolicyScopeType - - if scope == "aws" { - iamScope = types.PolicyScopeTypeAws - } else { - iamScope = types.PolicyScopeTypeLocal - } - - paginator := iam.NewListPoliciesPaginator(client, &iam.ListPoliciesInput{ - OnlyAttached: true, - Scope: iamScope, - }) - - for paginator.HasMorePages() { - out, err := paginator.NextPage(ctx) - - if err != nil { - return nil, err - } - - policies = append(policies, out.Policies...) - } - - span.SetAttributes( - attribute.Int("ovm.aws.numPolicies", len(policies)), - ) - - policyDetails, err := iter.MapErr[types.Policy, *PolicyDetails](policies, func(p *types.Policy) (*PolicyDetails, error) { - details := PolicyDetails{ - Policy: p, - } - - err := addPolicyEntities(ctx, client, &details) - if err != nil { - return &details, err - } - - err = addPolicyDocument(ctx, client, &details) - if err != nil { - return &details, err - } - - return &details, nil - }) - - if err != nil { - return nil, err - } - - return policyDetails, nil -} - -func policyItemMapper(_, scope string, awsItem *PolicyDetails) (*sdp.Item, error) { +func policyItemMapper(_ *string, scope string, awsItem *PolicyDetails) (*sdp.Item, error) { finalAttributes := struct { *types.Policy Document *policy.Policy @@ -316,32 +248,60 @@ func policyListTagsFunc(ctx context.Context, p *PolicyDetails, client IAMClient) return tags, nil } +func policyListExtractor(ctx context.Context, output *iam.ListPoliciesOutput, client IAMClient) ([]*PolicyDetails, error) { + return iter.MapErr[types.Policy, *PolicyDetails](output.Policies, func(p *types.Policy) (*PolicyDetails, error) { + details := PolicyDetails{ + Policy: p, + } + + err := addPolicyEntities(ctx, client, &details) + if err != nil { + return &details, err + } + + err = addPolicyDocument(ctx, client, &details) + if err != nil { + return &details, err + } + + return &details, nil + }) +} + // NewPolicyAdapter Note that this policy adapter only support polices that are // user-created due to the fact that the AWS-created ones are basically "global" // in scope. In order to get this to work I'd have to change the way the adapter // is implemented so that it was mart enough to handle different scopes. This // has been added to the backlog: // https://github.com/overmindtech/aws-adapter/issues/68 -func NewIAMPolicyAdapter(client *iam.Client, accountID string, _ string) *adapterhelpers.GetListAdapter[*PolicyDetails, IAMClient, *iam.Options] { - return &adapterhelpers.GetListAdapter[*PolicyDetails, IAMClient, *iam.Options]{ - ItemType: "iam-policy", - Client: client, - CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time - AccountID: accountID, - Region: "", // IAM policies aren't tied to a region - AdapterMetadata: policyAdapterMetadata, - // Some IAM policies are global, this means that their ARN doesn't - // contain an account name and instead just says "aws". Enabling this - // setting means these also work +func NewIAMPolicyAdapter(client IAMClient, accountID string, _ string) *adapterhelpers.GetListAdapterV2[*iam.ListPoliciesInput, *iam.ListPoliciesOutput, *PolicyDetails, IAMClient, *iam.Options] { + return &adapterhelpers.GetListAdapterV2[*iam.ListPoliciesInput, *iam.ListPoliciesOutput, *PolicyDetails, IAMClient, *iam.Options]{ + ItemType: "iam-policy", + Client: client, + AccountID: accountID, + Region: "", // IAM policies aren't tied to a region + CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time + AdapterMetadata: policyAdapterMetadata, SupportGlobalResources: true, - GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*PolicyDetails, error) { - return policyGetFunc(ctx, client, scope, query) + InputMapperList: func(scope string) (*iam.ListPoliciesInput, error) { + var iamScope types.PolicyScopeType + if scope == "aws" { + iamScope = types.PolicyScopeTypeAws + } else { + iamScope = types.PolicyScopeTypeLocal + } + return &iam.ListPoliciesInput{ + OnlyAttached: true, + Scope: iamScope, + }, nil }, - ListFunc: func(ctx context.Context, client IAMClient, scope string) ([]*PolicyDetails, error) { - return policyListFunc(ctx, client, scope) + ListFuncPaginatorBuilder: func(client IAMClient, params *iam.ListPoliciesInput) adapterhelpers.Paginator[*iam.ListPoliciesOutput, *iam.Options] { + return iam.NewListPoliciesPaginator(client, params) }, - ListTagsFunc: policyListTagsFunc, - ItemMapper: policyItemMapper, + ListExtractor: policyListExtractor, + GetFunc: policyGetFunc, + ItemMapper: policyItemMapper, + ListTagsFunc: policyListTagsFunc, } } diff --git a/adapters/iam-policy_test.go b/adapters/iam-policy_test.go index bddb3bc9..4f655ce4 100644 --- a/adapters/iam-policy_test.go +++ b/adapters/iam-policy_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) @@ -206,18 +207,6 @@ func TestPolicyGetFunc(t *testing.T) { } } -func TestPolicyListFunc(t *testing.T) { - policies, err := policyListFunc(context.Background(), &TestIAMClient{}, "foo") - - if err != nil { - t.Error(err) - } - - if len(policies) != 2 { - t.Errorf("expected 2 policies, got %v", len(policies)) - } -} - func TestPolicyListTagsFunc(t *testing.T) { tags, err := policyListTagsFunc(context.Background(), &PolicyDetails{ Policy: &types.Policy{ @@ -271,7 +260,7 @@ func TestPolicyItemMapper(t *testing.T) { if err != nil { t.Fatal(err) } - item, err := policyItemMapper("", "foo", details) + item, err := policyItemMapper(nil, "foo", details) if err != nil { t.Error(err) @@ -373,10 +362,22 @@ func TestNewIAMPolicyAdapter(t *testing.T) { ctx, span := tracer.Start(context.Background(), t.Name()) defer span.End() - items, err := adapter.List(ctx, adapterhelpers.FormatScope(account, ""), false) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + adapter.ListStream(ctx, adapterhelpers.FormatScope(account, ""), false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } for _, item := range items { @@ -405,10 +406,19 @@ func TestNewIAMPolicyAdapter(t *testing.T) { arn, _ := items[0].GetAttributes().Get("Arn") - _, err := adapter.Search(ctx, adapterhelpers.FormatScope(account, ""), arn.(string), false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + adapter.SearchStream(ctx, adapterhelpers.FormatScope(account, ""), arn.(string), false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } }) @@ -420,9 +430,18 @@ func TestNewIAMPolicyAdapter(t *testing.T) { arn, _ := items[0].GetAttributes().Get("Arn") - _, err := adapter.Search(ctx, "aws", arn.(string), false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) - if err == nil { + adapter.SearchStream(ctx, "aws", arn.(string), false, stream) + stream.Close() + + if len(errs) == 0 { t.Error("expected error, got nil") } }) @@ -432,9 +451,22 @@ func TestNewIAMPolicyAdapter(t *testing.T) { ctx, span := tracer.Start(context.Background(), t.Name()) defer span.End() - items, err := adapter.List(ctx, "aws", false) - if err != nil { - t.Error(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(ctx, "aws", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } if len(items) == 0 { @@ -467,9 +499,18 @@ func TestNewIAMPolicyAdapter(t *testing.T) { arn, _ := items[0].GetAttributes().Get("Arn") - _, err := adapter.Search(ctx, adapterhelpers.FormatScope(account, ""), arn.(string), false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.SearchStream(ctx, adapterhelpers.FormatScope(account, ""), arn.(string), false, stream) + stream.Close() - if err == nil { + if len(errs) == 0 { t.Error("expected error, got nil") } }) @@ -482,10 +523,19 @@ func TestNewIAMPolicyAdapter(t *testing.T) { arn, _ := items[0].GetAttributes().Get("Arn") - _, err := adapter.Search(ctx, "aws", arn.(string), false) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) {}, + func(err error) { + errs = append(errs, err) + }, + ) - if err != nil { - t.Error(err) + adapter.SearchStream(ctx, "aws", arn.(string), false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } }) }) diff --git a/adapters/iam-role.go b/adapters/iam-role.go index e929d972..b7e99e07 100644 --- a/adapters/iam-role.go +++ b/adapters/iam-role.go @@ -9,7 +9,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/micahhausler/aws-iam-policy/policy" - "go.opentelemetry.io/otel/attribute" "github.com/overmindtech/aws-source/adapterhelpers" "github.com/overmindtech/sdp-go" @@ -153,52 +152,7 @@ func getAttachedPolicies(ctx context.Context, client IAMClient, roleName string) return attachedPolicies, nil } -func roleListFunc(ctx context.Context, client IAMClient, _ string) ([]*RoleDetails, error) { - paginator := iam.NewListRolesPaginator(client, &iam.ListRolesInput{}) - roles := make([]*RoleDetails, 0) - ctx, span := tracer.Start(ctx, "roleListFunc") - defer span.End() - - mapper := iter.Mapper[types.Role, *RoleDetails]{ - MaxGoroutines: 100, - } - - for paginator.HasMorePages() { - out, err := paginator.NextPage(ctx) - - if err != nil { - return nil, err - } - - newRoles, err := mapper.MapErr(out.Roles, func(role *types.Role) (*RoleDetails, error) { - details := RoleDetails{ - Role: role, - } - - err := enrichRole(ctx, client, &details) - - if err != nil { - return nil, err - } - - return &details, nil - }) - - if err != nil { - return nil, err - } - - roles = append(roles, newRoles...) - } - - span.SetAttributes( - attribute.Int("ovm.aws.numRoles", len(roles)), - ) - - return roles, nil -} - -func roleItemMapper(_, scope string, awsItem *RoleDetails) (*sdp.Item, error) { +func roleItemMapper(_ *string, scope string, awsItem *RoleDetails) (*sdp.Item, error) { enrichedRole := struct { *types.Role EmbeddedPolicies []embeddedPolicy @@ -291,21 +245,51 @@ func roleListTagsFunc(ctx context.Context, r *RoleDetails, client IAMClient) (ma return tags, nil } -func NewIAMRoleAdapter(client *iam.Client, accountID string, region string) *adapterhelpers.GetListAdapter[*RoleDetails, IAMClient, *iam.Options] { - return &adapterhelpers.GetListAdapter[*RoleDetails, IAMClient, *iam.Options]{ - ItemType: "iam-role", - Client: client, - CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time - AccountID: accountID, - AdapterMetadata: roleAdapterMetadata, +func NewIAMRoleAdapter(client IAMClient, accountID string, region string) *adapterhelpers.GetListAdapterV2[*iam.ListRolesInput, *iam.ListRolesOutput, *RoleDetails, IAMClient, *iam.Options] { + return &adapterhelpers.GetListAdapterV2[*iam.ListRolesInput, *iam.ListRolesOutput, *RoleDetails, IAMClient, *iam.Options]{ + ItemType: "iam-role", + Client: client, + CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time + AccountID: accountID, + Region: region, GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*RoleDetails, error) { return roleGetFunc(ctx, client, scope, query) }, - ListFunc: func(ctx context.Context, client IAMClient, scope string) ([]*RoleDetails, error) { - return roleListFunc(ctx, client, scope) + InputMapperList: func(scope string) (*iam.ListRolesInput, error) { + return &iam.ListRolesInput{}, nil + }, + ListFuncPaginatorBuilder: func(client IAMClient, input *iam.ListRolesInput) adapterhelpers.Paginator[*iam.ListRolesOutput, *iam.Options] { + return iam.NewListRolesPaginator(client, input) + }, + ListExtractor: func(ctx context.Context, output *iam.ListRolesOutput, client IAMClient) ([]*RoleDetails, error) { + roles := make([]*RoleDetails, 0) + mapper := iter.Mapper[types.Role, *RoleDetails]{ + MaxGoroutines: 100, + } + + newRoles, err := mapper.MapErr(output.Roles, func(role *types.Role) (*RoleDetails, error) { + details := RoleDetails{ + Role: role, + } + + err := enrichRole(ctx, client, &details) + if err != nil { + return nil, err + } + + return &details, nil + }) + + if err != nil { + return nil, err + } + + roles = append(roles, newRoles...) + return roles, nil }, - ListTagsFunc: roleListTagsFunc, - ItemMapper: roleItemMapper, + ItemMapper: roleItemMapper, + ListTagsFunc: roleListTagsFunc, + AdapterMetadata: roleAdapterMetadata, } } diff --git a/adapters/iam-role_test.go b/adapters/iam-role_test.go index f30e5100..d00e54e4 100644 --- a/adapters/iam-role_test.go +++ b/adapters/iam-role_test.go @@ -13,19 +13,31 @@ import ( "github.com/micahhausler/aws-iam-policy/policy" "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) func (t *TestIAMClient) GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) { return &iam.GetRoleOutput{ Role: &types.Role{ - Path: adapterhelpers.PtrString("/service-role/"), - RoleName: adapterhelpers.PtrString("AWSControlTowerConfigAggregatorRoleForOrganizations"), - RoleId: adapterhelpers.PtrString("AROA3VLV2U27YSTBFCGCJ"), - Arn: adapterhelpers.PtrString("arn:aws:iam::801795385023:role/service-role/AWSControlTowerConfigAggregatorRoleForOrganizations"), - CreateDate: adapterhelpers.PtrTime(time.Now()), - AssumeRolePolicyDocument: adapterhelpers.PtrString("FOO"), - MaxSessionDuration: adapterhelpers.PtrInt32(3600), + Path: adapterhelpers.PtrString("/service-role/"), + RoleName: adapterhelpers.PtrString("AWSControlTowerConfigAggregatorRoleForOrganizations"), + RoleId: adapterhelpers.PtrString("AROA3VLV2U27YSTBFCGCJ"), + Arn: adapterhelpers.PtrString("arn:aws:iam::801795385023:role/service-role/AWSControlTowerConfigAggregatorRoleForOrganizations"), + CreateDate: adapterhelpers.PtrTime(time.Now()), + AssumeRolePolicyDocument: adapterhelpers.PtrString(`{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": "ec2.amazonaws.com" + }, + "Action": "sts:AssumeRole" + } + ] +}`), + MaxSessionDuration: adapterhelpers.PtrInt32(3600), }, }, nil } @@ -43,13 +55,24 @@ func (t *TestIAMClient) ListRoles(context.Context, *iam.ListRolesInput, ...func( return &iam.ListRolesOutput{ Roles: []types.Role{ { - Path: adapterhelpers.PtrString("/service-role/"), - RoleName: adapterhelpers.PtrString("AWSControlTowerConfigAggregatorRoleForOrganizations"), - RoleId: adapterhelpers.PtrString("AROA3VLV2U27YSTBFCGCJ"), - Arn: adapterhelpers.PtrString("arn:aws:iam::801795385023:role/service-role/AWSControlTowerConfigAggregatorRoleForOrganizations"), - CreateDate: adapterhelpers.PtrTime(time.Now()), - AssumeRolePolicyDocument: adapterhelpers.PtrString("FOO"), - MaxSessionDuration: adapterhelpers.PtrInt32(3600), + Path: adapterhelpers.PtrString("/service-role/"), + RoleName: adapterhelpers.PtrString("AWSControlTowerConfigAggregatorRoleForOrganizations"), + RoleId: adapterhelpers.PtrString("AROA3VLV2U27YSTBFCGCJ"), + Arn: adapterhelpers.PtrString("arn:aws:iam::801795385023:role/service-role/AWSControlTowerConfigAggregatorRoleForOrganizations"), + CreateDate: adapterhelpers.PtrTime(time.Now()), + AssumeRolePolicyDocument: adapterhelpers.PtrString(`{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": "ec2.amazonaws.com" + }, + "Action": "sts:AssumeRole" + } + ] +}`), + MaxSessionDuration: adapterhelpers.PtrInt32(3600), }, }, }, nil @@ -120,14 +143,28 @@ func TestRoleGetFunc(t *testing.T) { } func TestRoleListFunc(t *testing.T) { - roles, err := roleListFunc(context.Background(), &TestIAMClient{}, "foo") + adapter := NewIAMRoleAdapter(&TestIAMClient{}, "foo", "bar") - if err != nil { - t.Error(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(context.Background(), "foo.bar", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } - if len(roles) != 1 { - t.Errorf("expected 1 role, got %b", len(roles)) + if len(items) != 1 { + t.Errorf("expected 1 role, got %b", len(items)) } } @@ -199,7 +236,7 @@ func TestRoleItemMapper(t *testing.T) { }, } - item, err := roleItemMapper("", "foo", &role) + item, err := roleItemMapper(nil, "foo", &role) if err != nil { t.Fatal(err) diff --git a/adapters/iam-user.go b/adapters/iam-user.go index 841ba7e6..65de98f5 100644 --- a/adapters/iam-user.go +++ b/adapters/iam-user.go @@ -77,42 +77,7 @@ func getUserGroups(ctx context.Context, client IAMClient, userName *string) ([]t return groups, nil } -func userListFunc(ctx context.Context, client IAMClient, _ string) ([]*UserDetails, error) { - var out *iam.ListUsersOutput - var err error - users := make([]types.User, 0) - - paginator := iam.NewListUsersPaginator(client, &iam.ListUsersInput{}) - - for paginator.HasMorePages() { - out, err = paginator.NextPage(ctx) - - if err != nil { - return nil, err - } - - users = append(users, out.Users...) - } - - userDetails := make([]*UserDetails, 0, len(users)) - - for i := range users { - details := UserDetails{ - User: &users[i], - } - - err := enrichUser(ctx, client, &details) - if err != nil { - return nil, fmt.Errorf("failed to enrich user %s: %w", *details.User.UserName, err) - } - - userDetails = append(userDetails, &details) - } - - return userDetails, nil -} - -func userItemMapper(_, scope string, awsItem *UserDetails) (*sdp.Item, error) { +func userItemMapper(_ *string, scope string, awsItem *UserDetails) (*sdp.Item, error) { attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem.User) if err != nil { @@ -170,22 +135,43 @@ func userListTagsFunc(ctx context.Context, u *UserDetails, client IAMClient) (ma return tags, nil } -func NewIAMUserAdapter(client *iam.Client, accountID string, region string) *adapterhelpers.GetListAdapter[*UserDetails, IAMClient, *iam.Options] { - return &adapterhelpers.GetListAdapter[*UserDetails, IAMClient, *iam.Options]{ - ItemType: "iam-user", - Client: client, - AccountID: accountID, - CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time - Region: region, - AdapterMetadata: iamUserAdapterMetadata, +func NewIAMUserAdapter(client IAMClient, accountID string, region string) *adapterhelpers.GetListAdapterV2[*iam.ListUsersInput, *iam.ListUsersOutput, *UserDetails, IAMClient, *iam.Options] { + return &adapterhelpers.GetListAdapterV2[*iam.ListUsersInput, *iam.ListUsersOutput, *UserDetails, IAMClient, *iam.Options]{ + ItemType: "iam-user", + Client: client, + CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time + AccountID: accountID, + Region: region, GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*UserDetails, error) { return userGetFunc(ctx, client, scope, query) }, - ListFunc: func(ctx context.Context, client IAMClient, scope string) ([]*UserDetails, error) { - return userListFunc(ctx, client, scope) + InputMapperList: func(scope string) (*iam.ListUsersInput, error) { + return &iam.ListUsersInput{}, nil + }, + ListFuncPaginatorBuilder: func(client IAMClient, input *iam.ListUsersInput) adapterhelpers.Paginator[*iam.ListUsersOutput, *iam.Options] { + return iam.NewListUsersPaginator(client, input) }, - ListTagsFunc: userListTagsFunc, - ItemMapper: userItemMapper, + ListExtractor: func(ctx context.Context, output *iam.ListUsersOutput, client IAMClient) ([]*UserDetails, error) { + userDetails := make([]*UserDetails, 0, len(output.Users)) + + for i := range output.Users { + details := UserDetails{ + User: &output.Users[i], + } + + err := enrichUser(ctx, client, &details) + if err != nil { + return nil, fmt.Errorf("failed to enrich user %s: %w", *details.User.UserName, err) + } + + userDetails = append(userDetails, &details) + } + + return userDetails, nil + }, + ItemMapper: userItemMapper, + ListTagsFunc: userListTagsFunc, + AdapterMetadata: iamUserAdapterMetadata, } } diff --git a/adapters/iam-user_test.go b/adapters/iam-user_test.go index 3682d110..e24ad977 100644 --- a/adapters/iam-user_test.go +++ b/adapters/iam-user_test.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/overmindtech/aws-source/adapterhelpers" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) @@ -142,19 +143,36 @@ func TestUserGetFunc(t *testing.T) { } func TestUserListFunc(t *testing.T) { - users, err := userListFunc(context.Background(), &TestIAMClient{}, "foo") + adapter := NewIAMUserAdapter(&TestIAMClient{}, "foo", "bar") - if err != nil { - t.Error(err) + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(context.Background(), "foo.bar", false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } - if len(users) != 3 { - t.Errorf("expected 3 users, got %v", len(users)) + if len(items) != 3 { + t.Errorf("expected 3 items, got %v", len(items)) } - for _, user := range users { - if len(user.UserGroups) != 3 { - t.Errorf("expected 3 groups, got %v", len(user.UserGroups)) + for _, item := range items { + if err := item.Validate(); err != nil { + t.Error(err) + } + if len(item.LinkedItemQueries) != 3 { + t.Errorf("expected 3 linked item queries, got %v", len(item.LinkedItemQueries)) } } } @@ -195,7 +213,7 @@ func TestUserItemMapper(t *testing.T) { }, } - item, err := userItemMapper("", "foo", &details) + item, err := userItemMapper(nil, "foo", &details) if err != nil { t.Error(err) diff --git a/adapters/integration/ec2/instance_test.go b/adapters/integration/ec2/instance_test.go index f7a9f466..2df3213c 100644 --- a/adapters/integration/ec2/instance_test.go +++ b/adapters/integration/ec2/instance_test.go @@ -8,9 +8,54 @@ import ( "github.com/overmindtech/aws-source/adapterhelpers" "github.com/overmindtech/aws-source/adapters" "github.com/overmindtech/aws-source/adapters/integration" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) +func searchSync(adapter discovery.StreamingAdapter, ctx context.Context, scope, query string, ignoreCache bool) ([]*sdp.Item, error) { + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.SearchStream(ctx, scope, query, ignoreCache, stream) + stream.Close() + + if len(errs) > 0 { + return nil, fmt.Errorf("failed to search: %v", errs) + } + + return items, nil +} + +func listSync(adapter discovery.StreamingAdapter, ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(ctx, scope, ignoreCache, stream) + stream.Close() + + if len(errs) > 0 { + return nil, fmt.Errorf("failed to List: %v", errs) + } + + return items, nil +} + func EC2(t *testing.T) { ctx := context.Background() @@ -39,7 +84,7 @@ func EC2(t *testing.T) { scope := adapterhelpers.FormatScope(accountID, testAWSConfig.Region) // List instances - sdpListInstances, err := instanceAdapter.List(context.Background(), scope, true) + sdpListInstances, err := listSync(instanceAdapter, context.Background(), scope, true) if err != nil { t.Fatalf("failed to list EC2 instances: %v", err) } @@ -77,7 +122,7 @@ func EC2(t *testing.T) { // Search instances instanceARN := fmt.Sprintf("arn:aws:ec2:%s:%s:instance/%s", testAWSConfig.Region, accountID, instanceID) - sdpSearchInstances, err := instanceAdapter.Search(context.Background(), scope, instanceARN, true) + sdpSearchInstances, err := searchSync(instanceAdapter, context.Background(), scope, instanceARN, true) if err != nil { t.Fatalf("failed to search EC2 instances: %v", err) } diff --git a/adapters/integration/kms/kms_test.go b/adapters/integration/kms/kms_test.go index 09e55b6e..66c1db46 100644 --- a/adapters/integration/kms/kms_test.go +++ b/adapters/integration/kms/kms_test.go @@ -9,9 +9,54 @@ import ( "github.com/overmindtech/aws-source/adapterhelpers" "github.com/overmindtech/aws-source/adapters" "github.com/overmindtech/aws-source/adapters/integration" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) +func searchSync(adapter discovery.StreamingAdapter, ctx context.Context, scope, query string, ignoreCache bool) ([]*sdp.Item, error) { + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.SearchStream(ctx, scope, query, ignoreCache, stream) + stream.Close() + + if len(errs) > 0 { + return nil, fmt.Errorf("failed to search: %v", errs) + } + + return items, nil +} + +func listSync(adapter discovery.StreamingAdapter, ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(ctx, scope, ignoreCache, stream) + stream.Close() + + if len(errs) > 0 { + return nil, fmt.Errorf("failed to List: %v", errs) + } + + return items, nil +} + func KMS(t *testing.T) { ctx := context.Background() @@ -61,7 +106,7 @@ func KMS(t *testing.T) { scope := adapterhelpers.FormatScope(accountID, testAWSConfig.Region) // List keys - sdpListKeys, err := keySource.List(context.Background(), scope, true) + sdpListKeys, err := listSync(keySource, context.Background(), scope, true) if err != nil { t.Fatalf("failed to list KMS keys: %v", err) } @@ -94,7 +139,7 @@ func KMS(t *testing.T) { // Search keys keyARN := fmt.Sprintf("arn:aws:kms:%s:%s:key/%s", testAWSConfig.Region, accountID, keyID) - sdpSearchKeys, err := keySource.Search(context.Background(), scope, keyARN, true) + sdpSearchKeys, err := searchSync(keySource, context.Background(), scope, keyARN, true) if err != nil { t.Fatalf("failed to search KMS keys: %v", err) } @@ -113,7 +158,7 @@ func KMS(t *testing.T) { } // List aliases - sdpListAliases, err := aliasSource.List(context.Background(), scope, true) + sdpListAliases, err := listSync(aliasSource, context.Background(), scope, true) if err != nil { t.Fatalf("failed to list KMS aliases: %v", err) } @@ -157,7 +202,7 @@ func KMS(t *testing.T) { } // Search aliases - sdpSearchAliases, err := aliasSource.Search(context.Background(), scope, keyID, true) + sdpSearchAliases, err := searchSync(aliasSource, context.Background(), scope, keyID, true) if err != nil { t.Fatalf("failed to search KMS aliases: %v", err) } @@ -176,7 +221,7 @@ func KMS(t *testing.T) { } // List grants is not supported - sdpListGrants, err := grantSource.List(context.Background(), scope, true) + sdpListGrants, err := listSync(grantSource, context.Background(), scope, true) if err == nil { t.Fatal("expected error but got nil") } @@ -186,7 +231,7 @@ func KMS(t *testing.T) { } // Search grants - sdpSearchGrants, err := grantSource.Search(context.Background(), scope, keyID, true) + sdpSearchGrants, err := searchSync(grantSource, context.Background(), scope, keyID, true) if err != nil { t.Fatalf("failed to search KMS grants: %v", err) } @@ -227,7 +272,7 @@ func KMS(t *testing.T) { } // Search key policy by key ID - sdpSearchKeyPolicies, err := keyPolicySource.Search(context.Background(), scope, keyID, true) + sdpSearchKeyPolicies, err := searchSync(keyPolicySource, context.Background(), scope, keyID, true) if err != nil { t.Fatalf("failed to search KMS key policies: %v", err) } diff --git a/adapters/integration/networkmanager/networkmanager_test.go b/adapters/integration/networkmanager/networkmanager_test.go index 3106aecd..7b024a26 100644 --- a/adapters/integration/networkmanager/networkmanager_test.go +++ b/adapters/integration/networkmanager/networkmanager_test.go @@ -9,9 +9,32 @@ import ( "github.com/overmindtech/aws-source/adapterhelpers" "github.com/overmindtech/aws-source/adapters" "github.com/overmindtech/aws-source/adapters/integration" + "github.com/overmindtech/discovery" "github.com/overmindtech/sdp-go" ) +func searchSync(adapter discovery.StreamingAdapter, ctx context.Context, scope, query string, ignoreCache bool) ([]*sdp.Item, error) { + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.SearchStream(ctx, scope, query, ignoreCache, stream) + stream.Close() + + if len(errs) > 0 { + return nil, fmt.Errorf("failed to search: %v", errs) + } + + return items, nil +} + func NetworkManager(t *testing.T) { ctx := context.Background() @@ -63,10 +86,22 @@ func NetworkManager(t *testing.T) { globalScope := adapterhelpers.FormatScope(accountID, "") t.Run("Global Network", func(t *testing.T) { - // List global networks - globalNetworks, err := globalNetworkSource.List(ctx, globalScope, true) - if err != nil { - t.Fatalf("failed to list NetworkManager global networks: %v", err) + globalNetworks := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + globalNetworks = append(globalNetworks, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + globalNetworkSource.ListStream(ctx, globalScope, false, stream) + stream.Close() + + if len(errs) > 0 { + t.Fatalf("failed to list NetworkManager global networks: %v", errs) } if len(globalNetworks) == 0 { @@ -105,7 +140,7 @@ func NetworkManager(t *testing.T) { t.Fatalf("expected global scope %s, got %s", globalScope, globalNetwork.GetScope()) } - globalNetworks, err = globalNetworkSource.Search(ctx, globalScope, globalNetworkARN.(string), true) + globalNetworks, err = searchSync(globalNetworkSource, ctx, globalScope, globalNetworkARN.(string), true) if err != nil { t.Fatalf("failed to search NetworkManager global networks: %v", err) } @@ -125,7 +160,7 @@ func NetworkManager(t *testing.T) { t.Run("Site", func(t *testing.T) { // Search sites by the global network ID that they are created on - sites, err := siteSource.Search(ctx, globalScope, globalNetworkID, true) + sites, err := searchSync(siteSource, ctx, globalScope, globalNetworkID, true) if err != nil { t.Fatalf("failed to search for site: %v", err) } @@ -161,7 +196,7 @@ func NetworkManager(t *testing.T) { t.Run("Link", func(t *testing.T) { // Search links by the global network ID that they are created on - links, err := linkSource.Search(ctx, globalScope, globalNetworkID, true) + links, err := searchSync(linkSource, ctx, globalScope, globalNetworkID, true) if err != nil { t.Fatalf("failed to search for link: %v", err) } @@ -198,7 +233,7 @@ func NetworkManager(t *testing.T) { // Search devices by the global network ID and site ID // query format = globalNetworkID|siteID queryDevice := fmt.Sprintf("%s|%s", globalNetworkID, siteID) - devices, err := deviceSource.Search(ctx, globalScope, queryDevice, true) + devices, err := searchSync(deviceSource, ctx, globalScope, queryDevice, true) if err != nil { t.Fatalf("failed to search for device: %v", err) } @@ -233,7 +268,7 @@ func NetworkManager(t *testing.T) { deviceOneID := strings.Split(deviceOneCompositeID, "|")[1] // Search devices by the global network ID - devicesByGlobalNetwork, err := deviceSource.Search(ctx, globalScope, globalNetworkID, true) + devicesByGlobalNetwork, err := searchSync(deviceSource, ctx, globalScope, globalNetworkID, true) if err != nil { t.Fatalf("failed to search for device by global network: %v", err) } @@ -243,7 +278,7 @@ func NetworkManager(t *testing.T) { t.Run("Link Association", func(t *testing.T) { // Search link associations by the global network ID, link ID queryLALink := fmt.Sprintf("%s|link|%s", globalNetworkID, linkID) - linkAssociations, err := linkAssociationSource.Search(ctx, globalScope, queryLALink, true) + linkAssociations, err := searchSync(linkAssociationSource, ctx, globalScope, queryLALink, true) if err != nil { t.Fatalf("failed to search for link association: %v", err) } @@ -276,7 +311,7 @@ func NetworkManager(t *testing.T) { } // Search link associations by the global network ID - searchLinkAssociationsByGlobalNetwork, err := linkAssociationSource.Search(ctx, globalScope, globalNetworkID, true) + searchLinkAssociationsByGlobalNetwork, err := searchSync(linkAssociationSource, ctx, globalScope, globalNetworkID, true) if err != nil { t.Fatalf("failed to search for link association by global network: %v", err) } @@ -285,7 +320,7 @@ func NetworkManager(t *testing.T) { // Search link associations by the global network ID and device ID queryLADevice := fmt.Sprintf("%s|device|%s", globalNetworkID, deviceOneID) - linkAssociationsByDevice, err := linkAssociationSource.Search(ctx, globalScope, queryLADevice, true) + linkAssociationsByDevice, err := searchSync(linkAssociationSource, ctx, globalScope, queryLADevice, true) if err != nil { t.Fatalf("failed to search for link association by device: %v", err) } @@ -295,7 +330,7 @@ func NetworkManager(t *testing.T) { t.Run("Connection", func(t *testing.T) { // Search connections by the global network ID - connections, err := connectionSource.Search(ctx, globalScope, globalNetworkID, true) + connections, err := searchSync(connectionSource, ctx, globalScope, globalNetworkID, true) if err != nil { t.Fatalf("failed to search for connection: %v", err) } @@ -329,7 +364,7 @@ func NetworkManager(t *testing.T) { // Search connections by global network ID and device ID queryCon := fmt.Sprintf("%s|%s", globalNetworkID, deviceOneID) - connectionsByDevice, err := connectionSource.Search(ctx, globalScope, queryCon, true) + connectionsByDevice, err := searchSync(connectionSource, ctx, globalScope, queryCon, true) if err != nil { t.Fatalf("failed to search for connection by device: %v", err) } diff --git a/adapters/integration/ssm/main_test.go b/adapters/integration/ssm/main_test.go index c942ee9e..314516d0 100644 --- a/adapters/integration/ssm/main_test.go +++ b/adapters/integration/ssm/main_test.go @@ -15,6 +15,8 @@ import ( "github.com/overmindtech/aws-source/adapters" "github.com/overmindtech/aws-source/adapters/integration" "github.com/overmindtech/aws-source/tracing" + "github.com/overmindtech/discovery" + "github.com/overmindtech/sdp-go" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" @@ -132,9 +134,23 @@ func TestIntegrationSSM(t *testing.T) { ctx, span := tracer.Start(ctx, "SSM.List") defer span.End() start := time.Now() - items, err := adapter.List(ctx, scope, true) - if err != nil { - t.Errorf("Failed to list SSM parameters: %v", err) + + items := make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + adapter.ListStream(ctx, scope, false, stream) + stream.Close() + + if len(errs) > 0 { + t.Error(errs) } timeTaken := time.Since(start) From 767f4d4a0d621489fea43f0290bb462d17bfc645 Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 22:48:28 +0000 Subject: [PATCH 4/7] Fixed linting issues --- adapters/iam-group.go | 16 ---------------- adapters/iam-instance-profile.go | 16 ---------------- adapters/iam-user_test.go | 4 ++-- proc/proc.go | 10 ++++++++-- 4 files changed, 10 insertions(+), 36 deletions(-) diff --git a/adapters/iam-group.go b/adapters/iam-group.go index ee39ca1a..a3039764 100644 --- a/adapters/iam-group.go +++ b/adapters/iam-group.go @@ -23,22 +23,6 @@ func groupGetFunc(ctx context.Context, client *iam.Client, _, query string) (*ty return out.Group, nil } -func groupListFunc(ctx context.Context, client *iam.Client, _ string) ([]*types.Group, error) { - out, err := client.ListGroups(ctx, &iam.ListGroupsInput{}) - - if err != nil { - return nil, err - } - - zones := make([]*types.Group, 0, len(out.Groups)) - - for i := range out.Groups { - zones = append(zones, &out.Groups[i]) - } - - return zones, nil -} - func groupItemMapper(_ *string, scope string, awsItem *types.Group) (*sdp.Item, error) { attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem) diff --git a/adapters/iam-instance-profile.go b/adapters/iam-instance-profile.go index f28a8b8c..270bfbfe 100644 --- a/adapters/iam-instance-profile.go +++ b/adapters/iam-instance-profile.go @@ -23,22 +23,6 @@ func instanceProfileGetFunc(ctx context.Context, client *iam.Client, _, query st return out.InstanceProfile, nil } -func instanceProfileListFunc(ctx context.Context, client *iam.Client, _ string) ([]*types.InstanceProfile, error) { - out, err := client.ListInstanceProfiles(ctx, &iam.ListInstanceProfilesInput{}) - - if err != nil { - return nil, err - } - - zones := make([]*types.InstanceProfile, 0, len(out.InstanceProfiles)) - - for i := range out.InstanceProfiles { - zones = append(zones, &out.InstanceProfiles[i]) - } - - return zones, nil -} - func instanceProfileItemMapper(_ *string, scope string, awsItem *types.InstanceProfile) (*sdp.Item, error) { attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem) diff --git a/adapters/iam-user_test.go b/adapters/iam-user_test.go index e24ad977..875639e8 100644 --- a/adapters/iam-user_test.go +++ b/adapters/iam-user_test.go @@ -171,8 +171,8 @@ func TestUserListFunc(t *testing.T) { if err := item.Validate(); err != nil { t.Error(err) } - if len(item.LinkedItemQueries) != 3 { - t.Errorf("expected 3 linked item queries, got %v", len(item.LinkedItemQueries)) + if len(item.GetLinkedItemQueries()) != 3 { + t.Errorf("expected 3 linked item queries, got %v", len(item.GetLinkedItemQueries())) } } } diff --git a/proc/proc.go b/proc/proc.go index c046f284..667227e5 100644 --- a/proc/proc.go +++ b/proc/proc.go @@ -487,14 +487,17 @@ func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig, adapters.NewSSMParameterAdapter(ssmClient, *callerID.Account, cfg.Region), } - e.AddAdapters(configuredAdapters...) + err = e.AddAdapters(configuredAdapters...) + if err != nil { + return err + } // Add "global" sources (those that aren't tied to a region, like // cloudfront). but only do this once for the first region. For // these APIs it doesn't matter which region we call them from, we // get global results if globalDone.CompareAndSwap(false, true) { - e.AddAdapters( + err = e.AddAdapters( // Cloudfront adapters.NewCloudfrontCachePolicyAdapter(cloudfrontClient, *callerID.Account), adapters.NewCloudfrontContinuousDeploymentPolicyAdapter(cloudfrontClient, *callerID.Account), @@ -518,6 +521,9 @@ func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig, adapters.NewNetworkManagerLinkAssociationAdapter(networkmanagerClient, *callerID.Account), adapters.NewNetworkManagerConnectionAdapter(networkmanagerClient, *callerID.Account), ) + if err != nil { + return err + } } return nil }) From 6a2c81a36c77511495baf964d35fbec498ba99f8 Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 23:07:02 +0000 Subject: [PATCH 5/7] Fixed generic tests --- adapterhelpers/util.go | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/adapterhelpers/util.go b/adapterhelpers/util.go index d490ae1b..ebc53623 100644 --- a/adapterhelpers/util.go +++ b/adapterhelpers/util.go @@ -257,18 +257,36 @@ func (e E2ETest) Run(t *testing.T) { t.Run(fmt.Sprintf("Adapter: %v", e.Adapter.Name()), func(t *testing.T) { if e.GoodSearchQuery != nil { - var searchSrc discovery.SearchableAdapter - var ok bool - - if searchSrc, ok = e.Adapter.(discovery.SearchableAdapter); !ok { - t.Errorf("adapter is not searchable") - } - t.Run(fmt.Sprintf("Good search query: %v", e.GoodSearchQuery), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), e.Timeout) defer cancel() - items, err := searchSrc.Search(ctx, scope, *e.GoodSearchQuery, false) + var items []*sdp.Item + var err error + if searchSrc, ok := e.Adapter.(discovery.SearchableAdapter); ok { + items, err = searchSrc.Search(ctx, scope, *e.GoodSearchQuery, false) + } else if streamSrc, ok := e.Adapter.(discovery.StreamingAdapter); ok { + items = make([]*sdp.Item, 0) + errs := make([]error, 0) + stream := discovery.NewQueryResultStream( + func(item *sdp.Item) { + items = append(items, item) + }, + func(err error) { + errs = append(errs, err) + }, + ) + + streamSrc.SearchStream(context.Background(), scope, *e.GoodSearchQuery, false, stream) + stream.Close() + + if len(errs) > 0 { + err = errs[0] + } + } else { + t.Skip("adapter is not searchable or streamable") + } + if err != nil { t.Error(err) } @@ -311,6 +329,7 @@ func (e E2ETest) Run(t *testing.T) { ) streamingAdapter.ListStream(context.Background(), scope, false, stream) + stream.Close() } else if listableAdapter, ok := e.Adapter.(discovery.ListableAdapter); ok { var err error items, err = listableAdapter.List(ctx, scope, false) From 990f4f41441e59a6d0ec5056450e910f2a601aa9 Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 23:26:32 +0000 Subject: [PATCH 6/7] RTevert to previous behavoiur --- adapterhelpers/always_get_source.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/adapterhelpers/always_get_source.go b/adapterhelpers/always_get_source.go index a36ea7ac..5cb8111d 100644 --- a/adapterhelpers/always_get_source.go +++ b/adapterhelpers/always_get_source.go @@ -264,11 +264,8 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru for _, input := range newGetInputs { p.Go(func(ctx context.Context) error { - item, err := s.GetFunc(ctx, s.Client, scope, input) - - if err != nil { - stream.SendError(WrapAWSError(err)) - } + // Ignore the error here as we don't want to stop the whole process + item, _ := s.GetFunc(ctx, s.Client, scope, input) if item != nil { s.cache.StoreItem(item, s.cacheDuration(), ck) From f6d277fa3ee127c9f70b02a799ebed5496d46ff3 Mon Sep 17 00:00:00 2001 From: Dylan Ratcliffe Date: Sun, 8 Dec 2024 23:30:04 +0000 Subject: [PATCH 7/7] Fixed test expectations --- adapterhelpers/always_get_source_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/adapterhelpers/always_get_source_test.go b/adapterhelpers/always_get_source_test.go index b408e85a..67280038 100644 --- a/adapterhelpers/always_get_source_test.go +++ b/adapterhelpers/always_get_source_test.go @@ -254,8 +254,8 @@ func TestAlwaysGetSourceList(t *testing.T) { lgs.ListStream(context.Background(), "foo.bar", false, stream) stream.Close() - if len(errs) != 6 { - t.Fatalf("expected 6 error, got %v", len(errs)) + if len(errs) != 0 { + t.Fatalf("expected 0 error, got %v", len(errs)) } if len(items) != 0 { @@ -429,8 +429,8 @@ func TestAlwaysGetSourceSearch(t *testing.T) { lgs.SearchStream(context.Background(), "foo.bar", "id", false, stream) stream.Close() - if len(errs) != 6 { - t.Errorf("expected 6 error, got %v", len(errs)) + if len(errs) != 0 { + t.Errorf("expected 0 error, got %v", len(errs)) } if len(items) != 0 {