From e501cb8cd361c7c76b674fd71c91b864229c345b Mon Sep 17 00:00:00 2001 From: Volodymyr Manilo <35466116+vmanilo@users.noreply.github.com> Date: Mon, 29 Apr 2024 23:03:17 +0200 Subject: [PATCH 1/2] fix schema upgrade (#517) --- .../internal/provider/resource/resource-v1.go | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/twingate/internal/provider/resource/resource-v1.go b/twingate/internal/provider/resource/resource-v1.go index 4f646165..fed0373e 100644 --- a/twingate/internal/provider/resource/resource-v1.go +++ b/twingate/internal/provider/resource/resource-v1.go @@ -233,8 +233,9 @@ func convertAccessGroupsToTerraform(ctx context.Context, groups []string) (types for _, g := range groups { attributes := map[string]tfattr.Value{ - attr.GroupID: types.StringValue(g), - attr.SecurityPolicyID: types.StringNull(), + attr.GroupID: types.StringValue(g), + attr.SecurityPolicyID: types.StringNull(), + attr.UsageBasedAutolockDurationDays: types.Int64Null(), } obj, diags := types.ObjectValue(accessGroupAttributeTypes(), attributes) @@ -244,7 +245,7 @@ func convertAccessGroupsToTerraform(ctx context.Context, groups []string) (types } if diagnostics.HasError() { - return makeObjectsSetNull(ctx, accessAttributeTypes()), diagnostics + return makeObjectsSetNull(ctx, accessGroupAttributeTypes()), diagnostics } return makeObjectsSet(ctx, objects...) @@ -264,26 +265,15 @@ func convertAccessServiceAccountsToTerraform(ctx context.Context, serviceAccount attr.ServiceAccountID: types.StringValue(account), } - obj, diags := types.ObjectValue(accessGroupAttributeTypes(), attributes) + obj, diags := types.ObjectValue(accessServiceAccountAttributeTypes(), attributes) diagnostics.Append(diags...) objects = append(objects, obj) } if diagnostics.HasError() { - return makeObjectsSetNull(ctx, accessAttributeTypes()), diagnostics + return makeObjectsSetNull(ctx, accessServiceAccountAttributeTypes()), diagnostics } return makeObjectsSet(ctx, objects...) } - -func accessAttributeTypes() map[string]tfattr.Type { - return map[string]tfattr.Type{ - attr.GroupIDs: types.SetType{ - ElemType: types.StringType, - }, - attr.ServiceAccountIDs: types.SetType{ - ElemType: types.StringType, - }, - } -} From 524292756aab20c74965132ddb10978e1ae1afd9 Mon Sep 17 00:00:00 2001 From: Volodymyr Manilo <35466116+vmanilo@users.noreply.github.com> Date: Mon, 29 Apr 2024 23:06:32 +0200 Subject: [PATCH 2/2] Feature: reduce requests for resource read operation (#516) * wip * wip * fix linter issues * set page limit 100 * enable tests * disable cache for unit tests * revert ci changes --------- Co-authored-by: bertekintw <101608051+bertekintw@users.noreply.github.com> --- twingate/internal/client/cache.go | 181 ++++++++++++++++++ twingate/internal/client/client.go | 8 +- .../internal/client/query/resources-read.go | 24 +++ twingate/internal/client/resource.go | 75 +++++++- .../test/acctests/resource/resource_test.go | 81 ++++++++ 5 files changed, 367 insertions(+), 2 deletions(-) create mode 100644 twingate/internal/client/cache.go diff --git a/twingate/internal/client/cache.go b/twingate/internal/client/cache.go new file mode 100644 index 00000000..34a57cda --- /dev/null +++ b/twingate/internal/client/cache.go @@ -0,0 +1,181 @@ +package client + +import ( + "context" + "sync" + "time" + + "github.com/Twingate/terraform-provider-twingate/v3/twingate/internal/model" +) + +var closedChan chan struct{} //nolint:gochecknoglobals + +const ( + minBulkSize = 10 + requestsBufferSize = 1000 + collectTime = 70 * time.Millisecond + shortWaitTime = 5 * time.Millisecond +) + +var cache = &clientCache{ //nolint:gochecknoglobals + resources: map[string]*model.Resource{}, + requestedResources: make(chan string, requestsBufferSize), +} + +func init() { //nolint:gochecknoinits + closedChan = make(chan struct{}) + close(closedChan) + + go cache.run() +} + +type clientCache struct { + lock sync.RWMutex + resources map[string]*model.Resource + + requestDone bool + requestedResources chan string + + client *Client +} + +func (c *clientCache) done() <-chan struct{} { + c.lock.RLock() + defer c.lock.RUnlock() + + if !c.requestDone { + return nil + } + + return closedChan +} + +func (c *clientCache) run() { //nolint + var collectTimer *time.Timer + + resourcesToRequest := make(map[string]bool) + + for { + select { + case id := <-c.requestedResources: + resourcesToRequest[id] = true + + c.lock.RLock() + isDone := c.requestDone + c.lock.RUnlock() + + if isDone { + c.lock.Lock() + c.requestDone = false + c.lock.Unlock() + } + + if collectTimer == nil { + collectTimer = time.NewTimer(collectTime) + + continue + } else { + select { + case <-collectTimer.C: + collectTimer = nil + + c.fetchResources(resourcesToRequest) + resourcesToRequest = make(map[string]bool) + + default: // no op + } + } + + default: + if collectTimer != nil { + select { + case <-collectTimer.C: + collectTimer = nil + + c.fetchResources(resourcesToRequest) + resourcesToRequest = make(map[string]bool) + + default: // no op + } + } + + time.Sleep(shortWaitTime) + } + } +} + +func (c *clientCache) fetchResources(resourcesToRequest map[string]bool) { + if len(resourcesToRequest) >= minBulkSize && c.client != nil { + resources, err := c.client.ReadFullResources(context.Background()) + if err == nil { + c.setResources(resources) + } + } + + // notify + c.lock.Lock() + c.requestDone = true + c.lock.Unlock() +} + +func (c *clientCache) getResource(resourceID string) (*model.Resource, bool) { + c.lock.RLock() + + if c.client == nil { + c.lock.RUnlock() + + return nil, false + } + + c.lock.RUnlock() + + c.lock.RLock() + res, exists := c.resources[resourceID] + c.lock.RUnlock() + + if exists { + return res, exists + } + + c.requestedResources <- resourceID + // wait for fetching +LOOP: + for { + select { + case <-c.done(): + break LOOP + + default: + time.Sleep(shortWaitTime) + } + } + + c.lock.RLock() + res, exists = c.resources[resourceID] + c.lock.RUnlock() + + return res, exists +} + +func (c *clientCache) setResource(resource *model.Resource) { + c.lock.Lock() + defer c.lock.Unlock() + + c.resources[resource.ID] = resource +} + +func (c *clientCache) setResources(resources []*model.Resource) { + c.lock.Lock() + defer c.lock.Unlock() + + for _, resource := range resources { + c.resources[resource.ID] = resource + } +} + +func (c *clientCache) invalidateResource(id string) { + c.lock.Lock() + defer c.lock.Unlock() + + delete(c.resources, id) +} diff --git a/twingate/internal/client/client.go b/twingate/internal/client/client.go index d29480f1..edac3261 100644 --- a/twingate/internal/client/client.go +++ b/twingate/internal/client/client.go @@ -28,7 +28,8 @@ const ( headerAgent = "User-Agent" headerCorrelationID = "X-Correlation-Id" - defaultPageLimit = 50 + defaultPageLimit = 50 + extendedPageLimit = 100 ) var ( @@ -137,6 +138,7 @@ func NewClient(url string, apiToken string, network string, httpTimeout time.Dur sURL := newServerURL(network, url) retryableClient := retryablehttp.NewClient() + retryableClient.Logger = nil retryableClient.CheckRetry = customRetryPolicy retryableClient.RetryMax = httpRetryMax retryableClient.RequestLogHook = func(logger retryablehttp.Logger, req *http.Request, retryNumber int) { @@ -161,6 +163,10 @@ func NewClient(url string, apiToken string, network string, httpTimeout time.Dur log.Printf("[INFO] Using Server URL %s", sURL.newGraphqlServerURL()) + if version != "test" { + cache.client = &client + } + return &client } diff --git a/twingate/internal/client/query/resources-read.go b/twingate/internal/client/query/resources-read.go index d62ef1f3..3dd2ad4f 100644 --- a/twingate/internal/client/query/resources-read.go +++ b/twingate/internal/client/query/resources-read.go @@ -28,3 +28,27 @@ func (r Resources) ToModel() []*model.Resource { return edge.Node.ToModel() }) } + +// --- + +type ReadFullResources struct { + FullResources `graphql:"resources(after: $resourcesEndCursor, first: $pageLimit)"` +} + +func (r ReadFullResources) IsEmpty() bool { + return len(r.Edges) == 0 +} + +type FullResources struct { + PaginatedResource[*FullResourceEdge] +} + +type FullResourceEdge struct { + Node *gqlResource +} + +func (r ReadFullResources) ToModel() []*model.Resource { + return utils.Map[*FullResourceEdge, *model.Resource](r.Edges, func(edge *FullResourceEdge) *model.Resource { + return edge.Node.ToModel() + }) +} diff --git a/twingate/internal/client/resource.go b/twingate/internal/client/resource.go index 6cdc0ebf..f984982e 100644 --- a/twingate/internal/client/resource.go +++ b/twingate/internal/client/resource.go @@ -106,6 +106,10 @@ func (client *Client) ReadResource(ctx context.Context, resourceID string) (*mod return nil, opr.apiError(ErrGraphqlIDIsEmpty) } + if res, ok := cache.getResource(resourceID); ok { + return res, nil + } + variables := newVars( gqlID(resourceID), cursor(query.CursorAccess), @@ -121,7 +125,11 @@ func (client *Client) ReadResource(ctx context.Context, resourceID string) (*mod return nil, err //nolint } - return response.Resource.ToModel(), nil + res := response.Resource.ToModel() + + cache.setResource(res) + + return res, nil } func (client *Client) readResourceAccessAfter(ctx context.Context, variables map[string]interface{}, cursor string) (*query.PaginatedResource[*query.AccessEdge], error) { @@ -172,9 +180,66 @@ func (client *Client) readResourcesAfter(ctx context.Context, variables map[stri return &response.PaginatedResource, nil } +func (client *Client) ReadFullResources(ctx context.Context) ([]*model.Resource, error) { + opr := resourceResource.read() + + variables := newVars( + cursor(query.CursorAccess), + cursor(query.CursorResources), + pageLimit(extendedPageLimit), + ) + + response := query.ReadFullResources{} + if err := client.query(ctx, &response, variables, opr.withCustomName("readFullResources"), attr{id: "All"}); err != nil && !errors.Is(err, ErrGraphqlResultIsEmpty) { + return nil, err + } + + if err := response.FetchPages(ctx, client.readFullResourcesAfter, variables); err != nil { + return nil, err //nolint + } + + for i := range response.Edges { + if err := response.Edges[i].Node.Access.FetchPages(ctx, client.readExtendedResourceAccessAfter, newVars(gqlID(response.Edges[i].Node.ID))); err != nil { + return nil, err //nolint:wrapcheck + } + } + + return response.ToModel(), nil +} + +func (client *Client) readFullResourcesAfter(ctx context.Context, variables map[string]interface{}, cursor string) (*query.PaginatedResource[*query.FullResourceEdge], error) { + opr := resourceResource.read() + + variables[query.CursorResources] = cursor + + response := query.ReadFullResources{} + if err := client.query(ctx, &response, variables, opr); err != nil { + return nil, err + } + + return &response.PaginatedResource, nil +} + +func (client *Client) readExtendedResourceAccessAfter(ctx context.Context, variables map[string]interface{}, cursor string) (*query.PaginatedResource[*query.AccessEdge], error) { + opr := resourceResource.read() + + resourceID := string(variables["id"].(graphql.ID)) + variables[query.CursorAccess] = cursor + pageLimit(extendedPageLimit)(variables) + + response := query.ReadResourceAccess{} + if err := client.query(ctx, &response, variables, opr, attr{id: resourceID}); err != nil { + return nil, err + } + + return &response.Resource.Access.PaginatedResource, nil +} + func (client *Client) UpdateResource(ctx context.Context, input *model.Resource) (*model.Resource, error) { opr := resourceResource.update() + cache.invalidateResource(input.ID) + variables := newVars( gqlID(input.ID), gqlID(input.RemoteNetworkID, "remoteNetworkId"), @@ -224,6 +289,8 @@ func (client *Client) DeleteResource(ctx context.Context, resourceID string) err return opr.apiError(ErrGraphqlIDIsEmpty) } + cache.invalidateResource(resourceID) + response := query.DeleteResource{} return client.mutate(ctx, &response, newVars(gqlID(resourceID)), opr, attr{id: resourceID}) @@ -232,6 +299,8 @@ func (client *Client) DeleteResource(ctx context.Context, resourceID string) err func (client *Client) UpdateResourceActiveState(ctx context.Context, resource *model.Resource) error { opr := resourceResource.update() + cache.invalidateResource(resource.ID) + variables := newVars( gqlID(resource.ID), gqlVar(resource.IsActive, "isActive"), @@ -287,6 +356,8 @@ func (client *Client) RemoveResourceAccess(ctx context.Context, resourceID strin return opr.apiError(ErrGraphqlIDIsEmpty) } + cache.invalidateResource(resourceID) + variables := newVars( gqlID(resourceID), gqlIDs(principalIDs, "principalIds"), @@ -314,6 +385,8 @@ func (client *Client) AddResourceAccess(ctx context.Context, resourceID string, return opr.apiError(ErrGraphqlIDIsEmpty) } + cache.invalidateResource(resourceID) + variables := newVars( gqlID(resourceID), gqlNullable(access, "access"), diff --git a/twingate/internal/test/acctests/resource/resource_test.go b/twingate/internal/test/acctests/resource/resource_test.go index 5b54adbc..f4b90d85 100644 --- a/twingate/internal/test/acctests/resource/resource_test.go +++ b/twingate/internal/test/acctests/resource/resource_test.go @@ -3352,3 +3352,84 @@ func createResourceWithUsageBasedOnGroupAccess(remoteNetwork, resource, groupNam } `, remoteNetwork, resource, groupName, daysDuration) } + +func TestAccTwingateWithMultipleResource(t *testing.T) { + t.Parallel() + + resourceName := test.RandomResourceName() + remoteNetworkName := test.RandomName() + groupName := test.RandomGroupName() + + theResource1 := acctests.TerraformResource(resourceName + "-1") + theResource2 := acctests.TerraformResource(resourceName + "-2") + + sdk.Test(t, sdk.TestCase{ + ProtoV6ProviderFactories: acctests.ProviderFactories, + PreCheck: func() { acctests.PreCheck(t) }, + CheckDestroy: acctests.CheckTwingateResourceDestroy, + Steps: []sdk.TestStep{ + { + Config: createMultipleResourcesN(remoteNetworkName, resourceName, groupName, 10), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceExists(theResource1), + acctests.CheckTwingateResourceExists(theResource2), + sdk.TestCheckResourceAttr(theResource1, accessGroupIdsLen, "2"), + sdk.TestCheckResourceAttr(theResource2, accessGroupIdsLen, "2"), + ), + }, + { + Config: createMultipleResourcesN(remoteNetworkName, resourceName, groupName, 10), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceExists(theResource1), + acctests.CheckTwingateResourceExists(theResource2), + sdk.TestCheckResourceAttr(theResource1, accessGroupIdsLen, "2"), + sdk.TestCheckResourceAttr(theResource2, accessGroupIdsLen, "2"), + ), + }, + }, + }) +} + +func createMultipleResourcesN(remoteNetwork, resource, groupName string, n int) string { + return fmt.Sprintf(` + resource "twingate_group" "%[2]s-group-1" { + name = "%[3]s-1" + } + + resource "twingate_group" "%[2]s-group-2" { + name = "%[3]s-2" + } + + resource "twingate_remote_network" "%[1]s" { + name = "%[1]s" + } + + `+genMultipleResource(n), + remoteNetwork, resource, groupName) +} + +func genMultipleResource(n int) string { + res := make([]string, 0, n) + for i := 0; i < n; i++ { + res = append(res, fmtResource(i+1)) + } + + return strings.Join(res, "\n\n") +} + +func fmtResource(index int) string { + return fmt.Sprintf(` + resource "twingate_resource" "%%[2]s-%[1]v" { + name = "%%[2]s-%[1]v" + address = "acc-test-address-%[1]v.com" + remote_network_id = twingate_remote_network.%%[1]s.id + + dynamic "access_group" { + for_each = [twingate_group.%%[2]s-group-1.id, twingate_group.%%[2]s-group-2.id] + content { + group_id = access_group.value + } + } + } + `, index) +}