diff --git a/twingate/internal/provider/resource/helper.go b/twingate/internal/provider/resource/helper.go index 744172f9..0f1b1014 100644 --- a/twingate/internal/provider/resource/helper.go +++ b/twingate/internal/provider/resource/helper.go @@ -6,6 +6,7 @@ import ( "github.com/Twingate/terraform-provider-twingate/twingate/internal/model" tfattr "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/types" + "strings" "github.com/Twingate/terraform-provider-twingate/twingate/internal/utils" "github.com/hashicorp/terraform-plugin-framework/diag" @@ -29,7 +30,9 @@ func setIntersection(a, b []string) []string { } func setIntersectionGroupAccess(inputA, inputB []model.AccessGroup) []model.AccessGroup { - var setA, setB map[string]model.AccessGroup + setA := map[string]model.AccessGroup{} + setB := map[string]model.AccessGroup{} + for _, access := range inputA { setA[access.GroupID] = access } @@ -75,7 +78,9 @@ func setDifference(inputA, inputB []string) []string { } func setDifferenceGroupAccess(inputA, inputB []model.AccessGroup) []model.AccessGroup { - var setA, setB map[string]model.AccessGroup + setA := map[string]model.AccessGroup{} + setB := map[string]model.AccessGroup{} + for _, access := range inputA { setA[access.GroupID] = access } @@ -86,15 +91,27 @@ func setDifferenceGroupAccess(inputA, inputB []model.AccessGroup) []model.Access result := make([]model.AccessGroup, 0, len(setA)) - for key := range setA { - if val, exist := setB[key]; !exist { - result = append(result, val) + for key, valA := range setA { + if valB, exist := setB[key]; !exist || !equalOptionalStrings(valA.SecurityPolicyID, valB.SecurityPolicyID) { + result = append(result, valA) } } return result } +func equalOptionalStrings(str1, str2 *string) bool { + if str1 == nil && str2 == nil { + return true + } + + if str1 == nil && str2 != nil || str1 != nil && str2 == nil { + return false + } + + return strings.EqualFold(*str1, *str2) +} + func setDifferenceGroups(inputA, inputB []model.AccessGroup) []string { groupsA := utils.Map(inputA, func(item model.AccessGroup) string { return item.GroupID diff --git a/twingate/internal/provider/resource/resource.go b/twingate/internal/provider/resource/resource.go index 98025ac1..ce13791c 100644 --- a/twingate/internal/provider/resource/resource.go +++ b/twingate/internal/provider/resource/resource.go @@ -283,7 +283,8 @@ func groupAccessBlock() schema.SetNestedBlock { Validators: []validator.String{ stringvalidator.AlsoRequires(path.MatchRelative().AtParent().AtName(attr.GroupID)), }, - Default: stringdefault.StaticString(""), + //Default: stringdefault.StaticString(""), + PlanModifiers: []planmodifier.String{PolicyForGroupAccess()}, }, }, }, @@ -1190,10 +1191,18 @@ func convertGroupsAccessToTerraform(ctx context.Context, groupAccess []model.Acc var objects []types.Object for _, access := range groupAccess { + var securityPolicy basetypes.StringValue + //if access.SecurityPolicyID == nil { + // securityPolicy = types.StringValue("") + //} else { + securityPolicy = types.StringPointerValue(access.SecurityPolicyID) + //} + attributes := map[string]tfattr.Value{ attr.GroupID: types.StringValue(access.GroupID), //attr.SecurityPolicyID: types.StringNull(), - attr.SecurityPolicyID: types.StringPointerValue(access.SecurityPolicyID), + //attr.SecurityPolicyID: types.StringPointerValue(access.SecurityPolicyID), + attr.SecurityPolicyID: securityPolicy, } obj, diags := types.ObjectValue(accessGroupAttributeTypes(), attributes) @@ -1312,3 +1321,52 @@ func accessServiceAccountAttributeTypes() map[string]tfattr.Type { attr.ServiceAccountID: types.StringType, } } + +func PolicyForGroupAccess() planmodifier.String { + return policyForGroupAccess{} +} + +type policyForGroupAccess struct{} + +func (m policyForGroupAccess) Description(_ context.Context) string { + return "" +} + +func (m policyForGroupAccess) MarkdownDescription(_ context.Context) string { + return "" +} + +func (m policyForGroupAccess) PlanModifyString(ctx context.Context, req planmodifier.StringRequest, resp *planmodifier.StringResponse) { + if req.StateValue.IsNull() && req.ConfigValue.IsNull() { + resp.PlanValue = types.StringNull() + + return + } + + // Do nothing if there is no state value. + if req.StateValue.IsNull() { + return + } + + // Do nothing if there is an unknown configuration value, otherwise interpolation gets messed up. + if req.ConfigValue.IsUnknown() { + return + } + + // Do nothing if there is a known planned value. + if req.ConfigValue.ValueString() != "" { + return + } + + if !req.StateValue.IsUnknown() && req.ConfigValue.IsNull() { + resp.PlanValue = types.StringNull() + + return + } + + //if req.StateValue.ValueString() == "" && req.PlanValue.ValueString() == DefaultSecurityPolicyID { + // resp.PlanValue = types.StringValue("") + //} else if req.StateValue.ValueString() == DefaultSecurityPolicyID && req.PlanValue.ValueString() == "" { + // resp.PlanValue = types.StringValue(DefaultSecurityPolicyID) + //} +} diff --git a/twingate/internal/test/acctests/helper.go b/twingate/internal/test/acctests/helper.go index a6347f3c..633f10f4 100644 --- a/twingate/internal/test/acctests/helper.go +++ b/twingate/internal/test/acctests/helper.go @@ -310,6 +310,68 @@ func DeactivateTwingateResource(resourceName string) sdk.TestCheckFunc { } } +func CheckTwingateResourceSecurityPolicyOnGroupAccess(resourceName string, expectedSecurityPolicy string) sdk.TestCheckFunc { + return func(s *terraform.State) error { + resourceState, ok := s.RootModule().Resources[resourceName] + + if !ok { + return fmt.Errorf("%w: %s", ErrResourceNotFound, resourceName) + } + + if resourceState.Primary.ID == "" { + return ErrResourceIDNotSet + } + + res, err := providerClient.ReadResource(context.Background(), resourceState.Primary.ID) + if err != nil { + return fmt.Errorf("failed to read resource: %w", err) + } + + if len(res.GroupsAccess) == 0 { + return errors.New("expected at least one group in GroupAccess") + } + + if res.GroupsAccess[0].SecurityPolicyID == nil { + return errors.New("expected non nil security policy in GroupAccess") + } + + if *res.GroupsAccess[0].SecurityPolicyID != expectedSecurityPolicy { + return fmt.Errorf("expected security policy %v, got %v", expectedSecurityPolicy, *res.GroupsAccess[0].SecurityPolicyID) //nolint:goerr113 + } + + return nil + } +} + +func CheckTwingateResourceSecurityPolicyIsNullOnGroupAccess(resourceName string) sdk.TestCheckFunc { + return func(s *terraform.State) error { + resourceState, ok := s.RootModule().Resources[resourceName] + + if !ok { + return fmt.Errorf("%w: %s", ErrResourceNotFound, resourceName) + } + + if resourceState.Primary.ID == "" { + return ErrResourceIDNotSet + } + + res, err := providerClient.ReadResource(context.Background(), resourceState.Primary.ID) + if err != nil { + return fmt.Errorf("failed to read resource: %w", err) + } + + if len(res.GroupsAccess) == 0 { + return errors.New("expected at least one group in GroupAccess") + } + + if res.GroupsAccess[0].SecurityPolicyID != nil { + return errors.New("expected nil security policy in GroupAccess, got non nil") + } + + return nil + } +} + func CheckTwingateResourceActiveState(resourceName string, expectedActiveState bool) sdk.TestCheckFunc { return func(s *terraform.State) error { resourceState, ok := s.RootModule().Resources[resourceName] diff --git a/twingate/internal/test/acctests/resource/resource_test.go b/twingate/internal/test/acctests/resource/resource_test.go index 96245303..a13297dd 100644 --- a/twingate/internal/test/acctests/resource/resource_test.go +++ b/twingate/internal/test/acctests/resource/resource_test.go @@ -144,6 +144,16 @@ func TestAccTwingateResourceCreateWithProtocolsAndGroups(t *testing.T) { sdk.TestCheckResourceAttr(theResource, firstTCPPort, "80"), ), }, + { + Config: createResourceWithProtocolsAndGroups2(remoteNetworkName, groupName1, groupName2, resourceName), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceExists(theResource), + sdk.TestCheckResourceAttr(theResource, attr.Address, "new-acc-test.com"), + sdk.TestCheckResourceAttr(theResource, accessGroupIdsLen, "1"), + sdk.TestCheckResourceAttr(theResource, tcpPolicy, model.PolicyRestricted), + sdk.TestCheckResourceAttr(theResource, firstTCPPort, "80"), + ), + }, }, }) } @@ -189,6 +199,47 @@ func createResourceWithProtocolsAndGroups(networkName, groupName1, groupName2, r `, networkName, groupName1, groupName2, resourceName, model.PolicyRestricted, model.PolicyAllowAll) } +func createResourceWithProtocolsAndGroups2(networkName, groupName1, groupName2, resourceName string) string { + return fmt.Sprintf(` + resource "twingate_remote_network" "test2" { + name = "%s" + } + + resource "twingate_group" "g21" { + name = "%s" + } + + resource "twingate_group" "g22" { + name = "%s" + } + + resource "twingate_resource" "test2" { + name = "%s" + address = "new-acc-test.com" + remote_network_id = twingate_remote_network.test2.id + + protocols = { + allow_icmp = true + tcp = { + policy = "%s" + ports = ["80", "82-83"] + } + udp = { + policy = "%s" + } + } + + dynamic "access_group" { + for_each = [twingate_group.g21.id] + content { + group_id = access_group.value + security_policy_id = null + } + } + } + `, networkName, groupName1, groupName2, resourceName, model.PolicyRestricted, model.PolicyAllowAll) +} + func TestAccTwingateResourceFullCreationFlow(t *testing.T) { const theResource = "twingate_resource.test3" remoteNetworkName := test.RandomName() @@ -3111,3 +3162,146 @@ func createResource(networkName, resourceName string) string { } `, networkName, resourceName) } + +func TestAccTwingateResourceUpdateSecurityPolicyOnGroupAccess(t *testing.T) { + t.Parallel() + + resourceName := test.RandomResourceName() + theResource := acctests.TerraformResource(resourceName) + remoteNetworkName := test.RandomName() + groupName := test.RandomGroupName() + + defaultPolicy, testPolicy := preparePolicies(t) + + sdk.Test(t, sdk.TestCase{ + ProtoV6ProviderFactories: acctests.ProviderFactories, + PreCheck: func() { acctests.PreCheck(t) }, + CheckDestroy: acctests.CheckTwingateResourceDestroy, + Steps: []sdk.TestStep{ + { + Config: createResourceWithSecurityPolicyOnGroupAccess(remoteNetworkName, resourceName, testPolicy, groupName), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceExists(theResource), + acctests.CheckTwingateResourceSecurityPolicyOnGroupAccess(theResource, testPolicy), + ), + }, + { + Config: createResourceWithSecurityPolicyOnGroupAccess(remoteNetworkName, resourceName, defaultPolicy, groupName), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceExists(theResource), + acctests.CheckTwingateResourceSecurityPolicyOnGroupAccess(theResource, defaultPolicy), + ), + }, + { + Config: createResourceWithoutSecurityPolicyOnGroupAccess(remoteNetworkName, resourceName, groupName), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceSecurityPolicyOnGroupAccess(theResource, defaultPolicy), + ), + }, + }, + }) +} + +func createResourceWithSecurityPolicyOnGroupAccess(remoteNetwork, resource, policyID, groupName string) string { + return fmt.Sprintf(` + resource "twingate_group" "g21" { + name = "%[4]s" + } + + resource "twingate_remote_network" "%[1]s" { + name = "%[1]s" + } + + resource "twingate_resource" "%[2]s" { + name = "%[2]s" + address = "acc-test-address.com" + remote_network_id = twingate_remote_network.%[1]s.id + + access_group { + group_id = twingate_group.g21.id + security_policy_id = "%[3]s" + } + } + `, remoteNetwork, resource, policyID, groupName) +} + +func createResourceWithoutSecurityPolicyOnGroupAccess(remoteNetwork, resource, groupName string) string { + return fmt.Sprintf(` + resource "twingate_group" "g21" { + name = "%[3]s" + } + + resource "twingate_remote_network" "%[1]s" { + name = "%[1]s" + } + + resource "twingate_resource" "%[2]s" { + name = "%[2]s" + address = "acc-test-address.com" + remote_network_id = twingate_remote_network.%[1]s.id + + access_group { + group_id = twingate_group.g21.id + } + } + `, remoteNetwork, resource, groupName) +} + +func createResourceWithNullSecurityPolicyOnGroupAccess(remoteNetwork, resource, groupName string) string { + return fmt.Sprintf(` + resource "twingate_group" "g21" { + name = "%[3]s" + } + + resource "twingate_remote_network" "%[1]s" { + name = "%[1]s" + } + + resource "twingate_resource" "%[2]s" { + name = "%[2]s" + address = "acc-test-address.com" + remote_network_id = twingate_remote_network.%[1]s.id + + access_group { + group_id = twingate_group.g21.id + security_policy_id = null + } + } + `, remoteNetwork, resource, groupName) +} + +func TestAccTwingateResourceUnsetSecurityPolicyOnGroupAccess(t *testing.T) { + t.Parallel() + + resourceName := test.RandomResourceName() + theResource := acctests.TerraformResource(resourceName) + remoteNetworkName := test.RandomName() + groupName := test.RandomGroupName() + + defaultPolicy, testPolicy := preparePolicies(t) + _ = defaultPolicy + + sdk.Test(t, sdk.TestCase{ + ProtoV6ProviderFactories: acctests.ProviderFactories, + PreCheck: func() { acctests.PreCheck(t) }, + CheckDestroy: acctests.CheckTwingateResourceDestroy, + Steps: []sdk.TestStep{ + { + Config: createResourceWithSecurityPolicyOnGroupAccess(remoteNetworkName, resourceName, testPolicy, groupName), + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceExists(theResource), + acctests.CheckTwingateResourceSecurityPolicyOnGroupAccess(theResource, testPolicy), + ), + }, + { + Config: createResourceWithNullSecurityPolicyOnGroupAccess(remoteNetworkName, resourceName, groupName), + //// no changes + //PlanOnly: true, + + Check: acctests.ComposeTestCheckFunc( + acctests.CheckTwingateResourceSecurityPolicyIsNullOnGroupAccess(theResource), + ), + }, + }, + }) +}