Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: added security_policy_id to resource definition #425

Merged
merged 10 commits into from
Nov 28, 2023
7 changes: 7 additions & 0 deletions docs/resources/resource.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@ resource "twingate_service_account" "github_actions_prod" {
name = "Github Actions PROD"
}

data "twingate_security_policy" "test_policy" {
name = "Test Policy"
}

resource "twingate_resource" "resource" {
name = "network"
address = "internal.int"
remote_network_id = twingate_remote_network.aws_network.id

security_policy_id = data.twingate_security_policy.test_policy.id

protocols {
allow_icmp = true
tcp {
Expand Down Expand Up @@ -70,6 +76,7 @@ resource "twingate_resource" "resource" {
- `is_browser_shortcut_enabled` (Boolean) Controls whether an "Open in Browser" shortcut will be shown for this Resource in the Twingate Client.
- `is_visible` (Boolean) Controls whether this Resource will be visible in the main Resource list in the Twingate Client.
- `protocols` (Block List, Max: 1) Restrict access to certain protocols and ports. By default or when this argument is not defined, there is no restriction, and all protocols and ports are allowed. (see [below for nested schema](#nestedblock--protocols))
- `security_policy_id` (String) The ID of a `twingate_security_policy` to set as this Resource's Security Policy. Default is `Default Policy`

### Read-Only

Expand Down
6 changes: 6 additions & 0 deletions examples/resources/twingate_resource/resource.tf
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@ resource "twingate_service_account" "github_actions_prod" {
name = "Github Actions PROD"
}

data "twingate_security_policy" "test_policy" {
name = "Test Policy"
}

resource "twingate_resource" "resource" {
name = "network"
address = "internal.int"
remote_network_id = twingate_remote_network.aws_network.id

security_policy_id = data.twingate_security_policy.test_policy.id

protocols {
allow_icmp = true
tcp {
Expand Down
2 changes: 1 addition & 1 deletion twingate/internal/client/query/resource-create.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package query

type CreateResource struct {
ResourceEntityResponse `graphql:"resourceCreate(name: $name, address: $address, remoteNetworkId: $remoteNetworkId, groupIds: $groupIds, protocols: $protocols, isVisible: $isVisible, isBrowserShortcutEnabled: $isBrowserShortcutEnabled, alias: $alias)"`
ResourceEntityResponse `graphql:"resourceCreate(name: $name, address: $address, remoteNetworkId: $remoteNetworkId, groupIds: $groupIds, protocols: $protocols, isVisible: $isVisible, isBrowserShortcutEnabled: $isBrowserShortcutEnabled, alias: $alias, securityPolicyId: $securityPolicyId)"`
}

func (q CreateResource) IsEmpty() bool {
Expand Down
7 changes: 7 additions & 0 deletions twingate/internal/client/query/resource-read.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type ResourceNode struct {
IsVisible bool
IsBrowserShortcutEnabled bool
Alias string
SecurityPolicy *gqlSecurityPolicy
}

type Protocols struct {
Expand Down Expand Up @@ -90,6 +91,11 @@ func (r gqlResource) ToModel() *model.Resource {
}

func (r ResourceNode) ToModel() *model.Resource {
var securityPolicy string
if r.SecurityPolicy != nil {
securityPolicy = string(r.SecurityPolicy.ID)
}

return &model.Resource{
ID: string(r.ID),
Name: r.Name,
Expand All @@ -100,6 +106,7 @@ func (r ResourceNode) ToModel() *model.Resource {
IsVisible: &r.IsVisible,
IsBrowserShortcutEnabled: &r.IsBrowserShortcutEnabled,
Alias: optionalString(r.Alias),
SecurityPolicyID: optionalString(securityPolicy),
}
}

Expand Down
2 changes: 1 addition & 1 deletion twingate/internal/client/query/resource-update.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package query

type UpdateResource struct {
ResourceEntityResponse `graphql:"resourceUpdate(id: $id, name: $name, address: $address, remoteNetworkId: $remoteNetworkId, protocols: $protocols, isVisible: $isVisible, isBrowserShortcutEnabled: $isBrowserShortcutEnabled, alias: $alias)"`
ResourceEntityResponse `graphql:"resourceUpdate(id: $id, name: $name, address: $address, remoteNetworkId: $remoteNetworkId, protocols: $protocols, isVisible: $isVisible, isBrowserShortcutEnabled: $isBrowserShortcutEnabled, alias: $alias, securityPolicyId: $securityPolicyId)"`
}

func (q UpdateResource) IsEmpty() bool {
Expand Down
10 changes: 10 additions & 0 deletions twingate/internal/client/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (client *Client) CreateResource(ctx context.Context, input *model.Resource)
gqlNullable(input.IsVisible, "isVisible"),
gqlNullable(input.IsBrowserShortcutEnabled, "isBrowserShortcutEnabled"),
gqlNullable(input.Alias, "alias"),
gqlNullableID(input.SecurityPolicyID, "securityPolicyId"),
cursor(query.CursorAccess),
pageLimit(client.pageLimit),
)
Expand All @@ -92,6 +93,10 @@ func (client *Client) CreateResource(ctx context.Context, input *model.Resource)
resource.IsBrowserShortcutEnabled = nil
}

if input.SecurityPolicyID == nil {
resource.SecurityPolicyID = nil
}

return resource, nil
}

Expand Down Expand Up @@ -180,6 +185,7 @@ func (client *Client) UpdateResource(ctx context.Context, input *model.Resource)
gqlNullable(input.IsVisible, "isVisible"),
gqlNullable(input.IsBrowserShortcutEnabled, "isBrowserShortcutEnabled"),
gqlNullable(input.Alias, "alias"),
gqlNullableID(input.SecurityPolicyID, "securityPolicyId"),
cursor(query.CursorAccess),
pageLimit(client.pageLimit),
)
Expand All @@ -204,6 +210,10 @@ func (client *Client) UpdateResource(ctx context.Context, input *model.Resource)
resource.IsBrowserShortcutEnabled = nil
}

if input.SecurityPolicyID == nil {
resource.SecurityPolicyID = nil
}

return resource, nil
}

Expand Down
5 changes: 5 additions & 0 deletions twingate/internal/client/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,18 @@ func getValue(val any) any {
}
}

//nolint:unparam
func gqlNullableID(val interface{}, name string) gqlVarOption {
return func(values map[string]interface{}) map[string]interface{} {
var (
gqlValue interface{}
defaultID *graphql.ID
)

if value, ok := val.(*string); ok && value != nil {
val = *value
}

if isZeroValue(val) {
gqlValue = defaultID
} else {
Expand Down
1 change: 1 addition & 0 deletions twingate/internal/model/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Resource struct {
IsVisible *bool
IsBrowserShortcutEnabled *bool
Alias *string
SecurityPolicyID *string
}

func (r Resource) AccessToTerraform() []interface{} {
Expand Down
85 changes: 72 additions & 13 deletions twingate/internal/provider/resource/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
)

const DefaultSecurityPolicyName = "Default Policy"

var (
DefaultSecurityPolicyID string //nolint:gochecknoglobals
ErrPortsWithPolicyAllowAll = errors.New(model.PolicyAllowAll + " policy does not allow specifying ports.")
ErrPortsWithPolicyDenyAll = errors.New(model.PolicyDenyAll + " policy does not allow specifying ports.")
ErrPolicyRestrictedWithoutPorts = errors.New(model.PolicyRestricted + " policy requires specifying ports.")
Expand Down Expand Up @@ -136,6 +139,13 @@ func Resource() *schema.Resource { //nolint:funlen
Description: "Restrict access to certain groups or service accounts",
Elem: accessSchema,
},
attr.SecurityPolicyID: {
Type: schema.TypeString,
Optional: true,
Description: "The ID of a `twingate_security_policy` to set as this Resource's Security Policy. Default is `Default Policy`",
DiffSuppressOnRefresh: true,
DiffSuppressFunc: defaultPolicyNotChanged,
},
// computed
attr.IsVisible: {
Type: schema.TypeBool,
Expand Down Expand Up @@ -222,7 +232,13 @@ func resourceUpdate(ctx context.Context, resourceData *schema.ResourceData, meta
attr.IsVisible,
attr.IsBrowserShortcutEnabled,
attr.Alias,
attr.SecurityPolicyID,
) {
diagErr := setDefaultSecurityPolicy(ctx, resource, client)
if diagErr.HasError() {
return diagErr
}

resource, err = client.UpdateResource(ctx, resource)
} else {
resource, err = client.ReadResource(ctx, resource.ID)
Expand All @@ -236,12 +252,43 @@ func resourceUpdate(ctx context.Context, resourceData *schema.ResourceData, meta
return resourceResourceReadHelper(ctx, client, resourceData, resource, err)
}

func setDefaultSecurityPolicy(ctx context.Context, resource *model.Resource, client *client.Client) diag.Diagnostics {
if DefaultSecurityPolicyID == "" {
policy, _ := client.ReadSecurityPolicy(ctx, "", DefaultSecurityPolicyName)
if policy != nil {
DefaultSecurityPolicyID = policy.ID
}
}

if DefaultSecurityPolicyID == "" {
return diag.Errorf("default policy not set")
}

remoteResource, err := client.ReadResource(ctx, resource.ID)
if err != nil {
return diag.FromErr(err)
}

if remoteResource.SecurityPolicyID != nil && (resource.SecurityPolicyID == nil || *resource.SecurityPolicyID == "") &&
*remoteResource.SecurityPolicyID != DefaultSecurityPolicyID {
resource.SecurityPolicyID = &DefaultSecurityPolicyID
}

return nil
}

func resourceRead(ctx context.Context, resourceData *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*client.Client)

securityPolicyID := resourceData.Get(attr.SecurityPolicyID)

resource, err := client.ReadResource(ctx, resourceData.Id())
if resource != nil {
resource.IsAuthoritative = convertAuthoritativeFlagLegacy(resourceData)

if securityPolicyID == "" {
resource.SecurityPolicyID = nil
}
}

return resourceResourceReadHelper(ctx, client, resourceData, resource, err)
Expand Down Expand Up @@ -348,13 +395,12 @@ func readDiagnostics(resourceData *schema.ResourceData, resource *model.Resource
}
}

var alias interface{}
if resource.Alias != nil {
alias = *resource.Alias
if err := resourceData.Set(attr.Alias, resource.Alias); err != nil {
return ErrAttributeSet(err, attr.Alias)
}

if err := resourceData.Set(attr.Alias, alias); err != nil {
return ErrAttributeSet(err, attr.Alias)
if err := resourceData.Set(attr.SecurityPolicyID, resource.SecurityPolicyID); err != nil {
return ErrAttributeSet(err, attr.SecurityPolicyID)
}

return nil
Expand Down Expand Up @@ -440,6 +486,10 @@ func protocolsNotChanged(attribute, oldValue, newValue string, data *schema.Reso
return false
}

func defaultPolicyNotChanged(attribute, oldValue, newValue string, data *schema.ResourceData) bool {
return oldValue == DefaultSecurityPolicyID && (newValue == "" || newValue == DefaultSecurityPolicyID)
}

func getChangedAccessIDs(ctx context.Context, resourceData *schema.ResourceData, resource *model.Resource, client *client.Client) ([]string, []string, error) {
remote, err := client.ReadResource(ctx, resource.ID)
if err != nil {
Expand Down Expand Up @@ -483,14 +533,15 @@ func convertResource(data *schema.ResourceData) (*model.Resource, error) {

groups, serviceAccounts := convertAccess(data)
res := &model.Resource{
Name: data.Get(attr.Name).(string),
RemoteNetworkID: data.Get(attr.RemoteNetworkID).(string),
Address: data.Get(attr.Address).(string),
Protocols: protocols,
Groups: groups,
ServiceAccounts: serviceAccounts,
IsAuthoritative: convertAuthoritativeFlagLegacy(data),
Alias: getOptionalString(data, attr.Alias),
Name: data.Get(attr.Name).(string),
RemoteNetworkID: data.Get(attr.RemoteNetworkID).(string),
Address: data.Get(attr.Address).(string),
Protocols: protocols,
Groups: groups,
ServiceAccounts: serviceAccounts,
IsAuthoritative: convertAuthoritativeFlagLegacy(data),
Alias: getOptionalString(data, attr.Alias),
SecurityPolicyID: getOptionalString(data, attr.SecurityPolicyID),
}

isVisible, ok := data.GetOkExists(attr.IsVisible) //nolint
Expand Down Expand Up @@ -524,9 +575,17 @@ func isAttrKnown(data *schema.ResourceData, attr string) bool {
}

func getOptionalString(data *schema.ResourceData, attr string) *string {
if data == nil {
return nil
}

var result *string

cfg := data.GetRawConfig()
if cfg.IsNull() {
return nil
}

val := cfg.GetAttr(attr)

if !val.IsNull() {
Expand Down
43 changes: 43 additions & 0 deletions twingate/internal/test/acctests/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,49 @@ func CheckResourceServiceAccountsLen(resourceName string, expectedServiceAccount
}
}

func CheckResourceSecurityPolicy(resourceName string, expectedSecurityPolicyID string) sdk.TestCheckFunc {
return func(state *terraform.State) error {
resourceID, err := getResourceID(state, resourceName)
if err != nil {
return err
}

resource, err := providerClient.ReadResource(context.Background(), resourceID)
if err != nil {
return fmt.Errorf("resource with ID %s failed to read: %w", resourceID, err)
}

if resource.SecurityPolicyID != nil && *resource.SecurityPolicyID != expectedSecurityPolicyID {
return fmt.Errorf("expected security_policy_id %s, got %s", expectedSecurityPolicyID, *resource.SecurityPolicyID) //nolint
}

return nil
}
}

func UpdateResourceSecurityPolicy(resourceName, securityPolicyID string) sdk.TestCheckFunc {
return func(state *terraform.State) error {
resourceID, err := getResourceID(state, resourceName)
if err != nil {
return err
}

resource, err := providerClient.ReadResource(context.Background(), resourceID)
if err != nil {
return fmt.Errorf("resource with ID %s failed to read: %w", resourceID, err)
}

resource.SecurityPolicyID = &securityPolicyID

_, err = providerClient.UpdateResource(context.Background(), resource)
if err != nil {
return fmt.Errorf("resource with ID %s failed to update security_policy: %w", resourceID, err)
}

return nil
}
}

func AddGroupUser(groupResource, groupName, terraformUserID string) sdk.TestCheckFunc {
return func(state *terraform.State) error {
userID, err := getResourceID(state, getResourceNameFromID(terraformUserID))
Expand Down
Loading