Skip to content

Commit

Permalink
Ensure we actually set a cache duration
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanratcliffe committed Oct 4, 2023
1 parent 650af18 commit f5dba3e
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 43 deletions.
21 changes: 10 additions & 11 deletions sources/always_get_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ type AlwaysGetSource[ListInput InputType, ListOutput OutputType, GetInput InputT
cacheInitMu sync.Mutex // Mutex to ensure cache is only initialised once
}

// DefaultCacheDuration Returns the default cache duration for this source
func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) DefaultCacheDuration() time.Duration {
func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) cacheDuration() time.Duration {
if s.CacheDuration == 0 {
return 10 * time.Minute
return DefaultCacheDuration
}

return s.CacheDuration
Expand Down Expand Up @@ -170,11 +169,11 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc
if err != nil {
// TODO: How can we handle NOTFOUND?
qErr := WrapAWSError(err)
s.cache.StoreError(qErr, s.CacheDuration, ck)
s.cache.StoreError(qErr, s.cacheDuration(), ck)
return nil, qErr
}

s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
return item, nil
}

Expand Down Expand Up @@ -206,12 +205,12 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc
items, err := s.listInternal(ctx, scope, s.ListInput)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

for _, item := range items {
s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
}

return items, nil
Expand Down Expand Up @@ -343,12 +342,12 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc

if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

for _, item := range items {
s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
}

return items, nil
Expand All @@ -369,12 +368,12 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc

if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

for _, item := range items {
s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
}
return items, nil
}
Expand Down
32 changes: 18 additions & 14 deletions sources/describe_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/overmindtech/sdpcache"
)

const DefaultCacheDuration = 1 * time.Hour

// DescribeOnlySource Generates a source for AWS APIs that only use a `Describe`
// function for both List and Get operations. EC2 is a good example of this,
// where running Describe with no params returns everything, but params can be
Expand Down Expand Up @@ -63,10 +65,12 @@ type DescribeOnlySource[Input InputType, Output OutputType, ClientStruct ClientS
Client ClientStruct
}

// DefaultCacheDuration Returns the default cache duration for this source
func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) DefaultCacheDuration() time.Duration {
// Returns the duration that items should be cached for. This will use the
// `CacheDuration` for this source if set, otherwise it will use the default
// duration of 1 hour
func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) cacheDuration() time.Duration {
if s.CacheDuration == 0 {
return 10 * time.Minute
return DefaultCacheDuration
}

return s.CacheDuration
Expand Down Expand Up @@ -169,22 +173,22 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Get(ctx conte
input, err = s.InputMapperGet(scope, query)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

// Call the API using the object
output, err = s.DescribeFunc(ctx, s.Client, input)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

items, err = s.OutputMapper(scope, input, output)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

Expand All @@ -203,19 +207,19 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Get(ctx conte
ErrorType: sdp.QueryError_OTHER,
ErrorString: fmt.Sprintf("Request returned > 1 item for a GET request. Items: %v", strings.Join(itemNames, ", ")),
}
s.cache.StoreError(qErr, s.CacheDuration, ck)
s.cache.StoreError(qErr, s.cacheDuration(), ck)

return nil, qErr
case numItems == 0:
qErr := &sdp.QueryError{
ErrorType: sdp.QueryError_NOTFOUND,
ErrorString: fmt.Sprintf("%v %v not found", s.Type(), query),
}
s.cache.StoreError(qErr, s.CacheDuration, ck)
s.cache.StoreError(qErr, s.cacheDuration(), ck)
return nil, qErr
}

s.cache.StoreItem(items[0], s.CacheDuration, ck)
s.cache.StoreItem(items[0], s.cacheDuration(), ck)
return items[0], nil
}

Expand Down Expand Up @@ -254,19 +258,19 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) List(ctx cont
input, err := s.InputMapperList(scope)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

items, err = s.describe(ctx, input, scope)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

for _, item := range items {
s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
}

return items, nil
Expand Down Expand Up @@ -325,12 +329,12 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchCustom(
items, err := s.describe(ctx, input, scope)
if err != nil {
err = WrapAWSError(err)
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, err
}

for _, item := range items {
s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
}

return items, nil
Expand Down
25 changes: 12 additions & 13 deletions sources/get_list_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ type GetListSource[AWSItem AWSItemType, ClientStruct ClientStructType, Options O
ItemMapper func(scope string, awsItem AWSItem) (*sdp.Item, error)
}

func (s *GetListSource[AWSItem, ClientStruct, Options]) cacheDuration() time.Duration {
if s.CacheDuration == 0 {
return DefaultCacheDuration
}

return s.CacheDuration
}

func (s *GetListSource[AWSItem, ClientStruct, Options]) ensureCache() {
s.cacheInitMu.Lock()
defer s.cacheInitMu.Unlock()
Expand Down Expand Up @@ -84,15 +92,6 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) Name() string {
return fmt.Sprintf("%v-source", s.ItemType)
}

// DefaultCacheDuration Returns the default cache duration for this source
func (s *GetListSource[AWSItem, ClientStruct, Options]) DefaultCacheDuration() time.Duration {
if s.CacheDuration == 0 {
return 10 * time.Minute
}

return s.CacheDuration
}

// List of scopes that this source is capable of find items for. This will be
// in the format {accountID}.{region}
func (s *GetListSource[AWSItem, ClientStruct, Options]) Scopes() []string {
Expand Down Expand Up @@ -147,17 +146,17 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) Get(ctx context.Context,

awsItem, err := s.GetFunc(ctx, s.Client, scope, query)
if err != nil {
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, WrapAWSError(err)
}

item, err := s.ItemMapper(scope, awsItem)
if err != nil {
s.cache.StoreError(err, s.CacheDuration, ck)
s.cache.StoreError(err, s.cacheDuration(), ck)
return nil, WrapAWSError(err)
}

s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)

return item, nil
}
Expand Down Expand Up @@ -199,7 +198,7 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) List(ctx context.Context
}

items = append(items, item)
s.cache.StoreItem(item, s.CacheDuration, ck)
s.cache.StoreItem(item, s.cacheDuration(), ck)
}

return items, nil
Expand Down
2 changes: 1 addition & 1 deletion sources/iam/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func NewGroupSource(config aws.Config, accountID string, region string, limit *s
return &sources.GetListSource[*types.Group, *iam.Client, *iam.Options]{
ItemType: "iam-group",
Client: iam.NewFromConfig(config),
CacheDuration: 1 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
AccountID: accountID,
GetFunc: func(ctx context.Context, client *iam.Client, scope, query string) (*types.Group, error) {
limit.Wait(ctx) // Wait for rate limiting
Expand Down
2 changes: 1 addition & 1 deletion sources/iam/instance_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func NewInstanceProfileSource(config aws.Config, accountID string, region string
return &sources.GetListSource[*types.InstanceProfile, *iam.Client, *iam.Options]{
ItemType: "iam-instance-profile",
Client: iam.NewFromConfig(config),
CacheDuration: 1 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
AccountID: accountID,
GetFunc: func(ctx context.Context, client *iam.Client, scope, query string) (*types.InstanceProfile, error) {
limit.Wait(ctx) // Wait for rate limiting
Expand Down
2 changes: 1 addition & 1 deletion sources/iam/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func NewPolicySource(config aws.Config, accountID string, _ string, limit *sourc
return &sources.GetListSource[*PolicyDetails, IAMClient, *iam.Options]{
ItemType: "iam-policy",
Client: iam.NewFromConfig(config),
CacheDuration: 1 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
AccountID: accountID,
Region: "", // IAM policies aren't tied to a region

Expand Down
2 changes: 1 addition & 1 deletion sources/iam/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func NewRoleSource(config aws.Config, accountID string, region string, limit *so
return &sources.GetListSource[*RoleDetails, IAMClient, *iam.Options]{
ItemType: "iam-role",
Client: iam.NewFromConfig(config),
CacheDuration: 1 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
AccountID: accountID,
GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*RoleDetails, error) {
return roleGetFunc(ctx, client, scope, query, limit)
Expand Down
2 changes: 1 addition & 1 deletion sources/iam/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func NewUserSource(config aws.Config, accountID string, region string, limit *so
ItemType: "iam-user",
Client: iam.NewFromConfig(config),
AccountID: accountID,
CacheDuration: 1 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time
Region: region,
GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*UserDetails, error) {
return userGetFunc(ctx, client, scope, query, limit)
Expand Down

0 comments on commit f5dba3e

Please sign in to comment.