Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Firewall: select an incoming or outgoing chain when creating a rule #64

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 53 additions & 129 deletions firewall.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
//go:build windows && amd64
// +build windows,amd64

package winapi

import (
"fmt"
"log"
"runtime"

ole "github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"github.com/scjalliance/comshim"
)

// Firewall related API constants.
Expand Down Expand Up @@ -105,12 +102,12 @@ func (r *FWRule) InProfiles() FWProfiles {
// NET_FW_PROFILE2_CURRENT // adds rule to currently used FW Profile(-s)
// NET_FW_PROFILE2_ALL // adds rule to all profiles
// NET_FW_PROFILE2_DOMAIN|NET_FW_PROFILE2_PRIVATE // rule in Private and Domain profile
func FirewallRuleAdd(name, description, group, ports string, protocol, profile int32) (bool, error) {
func FirewallRuleAdd(name, description, group, ports, remotePorts string, protocol, direction, profile int32) (bool, error) {

if ports == "" {
return false, fmt.Errorf("empty FW Rule ports, it is mandatory")
}
return firewallRuleAdd(name, description, group, "", "", ports, "", "", "", "", protocol, 0, NET_FW_ACTION_ALLOW, profile, true, false)
// if ports == "" {
// return false, fmt.Errorf("empty FW Rule ports, it is mandatory")
// }
return firewallRuleAdd(name, description, group, "", "", ports, remotePorts, "", "", "", protocol, direction, profile, true, false)
}

// FirewallRuleAddApplication creates Inbound rule for given application.
Expand All @@ -131,16 +128,16 @@ func FirewallRuleAdd(name, description, group, ports string, protocol, profile i
// NET_FW_PROFILE2_CURRENT // adds rule to currently used FW Profile
// NET_FW_PROFILE2_ALL // adds rule to all profiles
// NET_FW_PROFILE2_DOMAIN|NET_FW_PROFILE2_PRIVATE // rule in Private and Domain profile
func FirewallRuleAddApplication(name, description, group, appPath string, profile int32) (bool, error) {
func FirewallRuleAddApplication(name, description, group, appPath, remotePorts string, protocol, direction, profile int32) (bool, error) {
if appPath == "" {
return false, fmt.Errorf("empty FW Rule appPath, it is mandatory")
}
return firewallRuleAdd(name, description, group, appPath, "", "", "", "", "", "", 0, 0, NET_FW_ACTION_ALLOW, profile, true, false)
return firewallRuleAdd(name, description, group, appPath, "", "", remotePorts, "", "", "", protocol, direction, profile, true, false)
}

// FirewallRuleCreate is deprecated, use FirewallRuleAddApplication instead.
func FirewallRuleCreate(name, description, group, appPath, port string, protocol int32) (bool, error) {
return firewallRuleAdd(name, description, group, appPath, "", port, "", "", "", "", protocol, 0, NET_FW_ACTION_ALLOW, NET_FW_PROFILE2_CURRENT, true, false)
return firewallRuleAdd(name, description, group, appPath, "", port, "", "", "", "", protocol, 0, NET_FW_PROFILE2_CURRENT, true, false)
}

// FirewallPingEnable creates Inbound ICMPv4 rule which allows to answer echo requests.
Expand All @@ -159,7 +156,7 @@ func FirewallRuleCreate(name, description, group, appPath, port string, protocol
// NET_FW_PROFILE2_ALL // adds rule to all profiles
// NET_FW_PROFILE2_DOMAIN|NET_FW_PROFILE2_PRIVATE // rule in Private and Domain profile
func FirewallPingEnable(name, description, group, remoteAddresses string, profile int32) (bool, error) {
return firewallRuleAdd(name, description, group, "", "", "", "", "", remoteAddresses, "8:*", NET_FW_IP_PROTOCOL_ICMPv4, 0, NET_FW_ACTION_ALLOW, profile, true, false)
return firewallRuleAdd(name, description, group, "", "", "", "", "", remoteAddresses, "8:*", NET_FW_IP_PROTOCOL_ICMPv4, 0, profile, true, false)
}

// FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters.
Expand All @@ -170,7 +167,7 @@ func FirewallPingEnable(name, description, group, remoteAddresses string, profil
func FirewallRuleAddAdvanced(rule FWRule) (bool, error) {
return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName,
rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes,
rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
rule.Protocol, rule.Direction, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
}

// FirewallRuleDelete allows you to delete existing rule by name.
Expand Down Expand Up @@ -243,7 +240,7 @@ func FirewallRuleGet(name string) (FWRule, error) {
if err != nil {
return rule, err
}
defer firewallRulesEnumRelease(ur, ep, enum)
defer firewallRulesEnumRealease(ur, ep)

for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
if err != nil {
Expand All @@ -270,7 +267,7 @@ func FirewallRuleGet(name string) (FWRule, error) {

// FirewallRulesGet returns all rules defined in firewall.
func FirewallRulesGet() ([]FWRule, error) {
rules := make([]FWRule, 0, 1024)
rules := make([]FWRule, 1000)

u, fwPolicy, err := firewallAPIInit()
if err != nil {
Expand All @@ -282,7 +279,7 @@ func FirewallRulesGet() ([]FWRule, error) {
if err != nil {
return rules, err
}
defer firewallRulesEnumRelease(ur, ep, enum)
defer firewallRulesEnumRealease(ur, ep)

for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
if err != nil {
Expand Down Expand Up @@ -314,108 +311,94 @@ func firewallRuleParams(itemRaw ole.VARIANT) (FWRule, error) {
item := itemRaw.ToIDispatch()
defer item.Release()

var err error
rule.Name, err = getStringProperty(item, "Name")
name, err := oleutil.GetProperty(item, "Name")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Name) of Rule")
}
rule.Description, err = getStringProperty(item, "Description")
rule.Name = name.ToString()
description, err := oleutil.GetProperty(item, "Description")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Description) of Rule %q", rule.Name)
}
rule.ApplicationName, err = getStringProperty(item, "ApplicationName")
rule.Description = description.ToString()
applicationApplicationName, err := oleutil.GetProperty(item, "ApplicationName")
if err != nil {
return rule, fmt.Errorf("failed to get Property (ApplicationName) of Rule %q", rule.Name)
}
rule.ServiceName, err = getStringProperty(item, "ServiceName")
rule.ApplicationName = applicationApplicationName.ToString()
serviceName, err := oleutil.GetProperty(item, "ServiceName")
if err != nil {
return rule, fmt.Errorf("failed to get Property (ServiceName) of Rule %q", rule.Name)
}
rule.LocalPorts, err = getStringProperty(item, "LocalPorts")
rule.ServiceName = serviceName.ToString()
localPorts, err := oleutil.GetProperty(item, "LocalPorts")
if err != nil {
return rule, fmt.Errorf("failed to get Property (LocalPorts) of Rule %q", rule.Name)
}

rule.RemotePorts, err = getStringProperty(item, "RemotePorts")
rule.LocalPorts = localPorts.ToString()
remotePorts, err := oleutil.GetProperty(item, "RemotePorts")
if err != nil {
return rule, fmt.Errorf("failed to get Property (RemotePorts) of Rule %q", rule.Name)
}
rule.LocalAddresses, err = getStringProperty(item, "LocalAddresses")
rule.RemotePorts = remotePorts.ToString()
localAddresses, err := oleutil.GetProperty(item, "LocalAddresses")
if err != nil {
return rule, fmt.Errorf("failed to get Property (LocalAddresses) of Rule %q", rule.Name)
}
rule.RemoteAddresses, err = getStringProperty(item, "RemoteAddresses")
rule.LocalAddresses = localAddresses.ToString()
remoteAddresses, err := oleutil.GetProperty(item, "RemoteAddresses")
if err != nil {
return rule, fmt.Errorf("failed to get Property (RemoteAddresses) of Rule %q", rule.Name)
}
rule.ICMPTypesAndCodes, err = getStringProperty(item, "ICMPTypesAndCodes")
rule.RemoteAddresses = remoteAddresses.ToString()
icmpTypesAndCodes, err := oleutil.GetProperty(item, "ICMPTypesAndCodes")
if err != nil {
return rule, fmt.Errorf("failed to get Property (ICMPTypesAndCodes) of Rule %q", rule.Name)
}
rule.Grouping, err = getStringProperty(item, "Grouping")
rule.ICMPTypesAndCodes = icmpTypesAndCodes.ToString()
grouping, err := oleutil.GetProperty(item, "Grouping")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Grouping) of Rule %q", rule.Name)
}
rule.InterfaceTypes, err = getStringProperty(item, "InterfaceTypes")
rule.Grouping = grouping.ToString()
interfaceTypes, err := oleutil.GetProperty(item, "InterfaceTypes")
if err != nil {
return rule, fmt.Errorf("failed to get Property (InterfaceTypes) of Rule %q", rule.Name)
}
rule.Protocol, err = getInt32Property(item, "Protocol")
rule.InterfaceTypes = interfaceTypes.ToString()
protocol, err := oleutil.GetProperty(item, "Protocol")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Protocol) of Rule %q", rule.Name)
}
rule.Direction, err = getInt32Property(item, "Direction")
rule.Protocol = protocol.Value().(int32)
direction, err := oleutil.GetProperty(item, "Direction")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Direction) of Rule %q", rule.Name)
}
rule.Action, err = getInt32Property(item, "Action")
rule.Direction = direction.Value().(int32)
action, err := oleutil.GetProperty(item, "Action")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Action) of Rule %q", rule.Name)
}
rule.Enabled, err = getBoolProperty(item, "Enabled")
rule.Action = action.Value().(int32)
enabled, err := oleutil.GetProperty(item, "Enabled")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Enabled) of Rule %q", rule.Name)
}
rule.EdgeTraversal, err = getBoolProperty(item, "EdgeTraversal")
rule.Enabled = enabled.Value().(bool)
edgeTraversal, err := oleutil.GetProperty(item, "EdgeTraversal")
if err != nil {
return rule, fmt.Errorf("failed to get Property (EdgeTraversal) of Rule %q", rule.Name)
}
rule.Profiles, err = getInt32Property(item, "Profiles")
rule.EdgeTraversal = edgeTraversal.Value().(bool)
profiles, err := oleutil.GetProperty(item, "Profiles")
if err != nil {
return rule, fmt.Errorf("failed to get Property (Profiles) of Rule %q", rule.Name)
}
rule.Profiles = profiles.Value().(int32)

return rule, nil
}
func getInt32Property(dispatch *ole.IDispatch, property string) (int32, error) {
val, err := oleutil.GetProperty(dispatch, property)
if err != nil {
log.Printf("failed to get dispatch property: %s \n", err.Error())
return 0, err
}
defer val.Clear()
return val.Value().(int32), nil
}

func getStringProperty(dispatch *ole.IDispatch, property string) (string, error) {
val, err := oleutil.GetProperty(dispatch, property)
if err != nil {
log.Printf("failed to get dispatch property: %s \n", err.Error())
return "", err
}
defer val.Clear()
return val.ToString(), nil
}

func getBoolProperty(dispatch *ole.IDispatch, property string) (bool, error) {
val, err := oleutil.GetProperty(dispatch, property)
if err != nil {
log.Printf("failed to get dispatch property: %s \n", err.Error())
return false, err
}
defer val.Clear()
return val.Value().(bool), nil
}

// FirewallGroupEnable allows to enable predefined firewall group. It is better
// to not use names as "File and Printer Sharing" because they are localized and
Expand Down Expand Up @@ -537,64 +520,6 @@ func FirewallDisable(profile int32) (bool, error) {
return true, nil
}

// FirewallGetDefaultOutboundAction checks if outgoing connections without matching rules are allowed or blocked.
// Returns either NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK.
func FirewallGetDefaultOutboundAction(profile int32) (int32, error) {
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return 0, err
}
defer firewallAPIRelease(u, fwPolicy)

action, err := oleutil.GetProperty(fwPolicy, "DefaultOutboundAction", profile)
if err != nil {
return 0, err
}
return action.Value().(int32), nil
}

// FirewallGetDefaultInboundAction checks if incoming connections without matching rules are allowed or blocked.
// Returns either NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK.
func FirewallGetDefaultInboundAction(profile int32) (int32, error) {
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return 0, err
}
defer firewallAPIRelease(u, fwPolicy)

action, err := oleutil.GetProperty(fwPolicy, "DefaultInboundAction", profile)
if err != nil {
return 0, err
}
return action.Value().(int32), nil
}

// FirewallSetDefaultOutboundAction sets the default policy for outgoing connections.
// action must be NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK.
func FirewallSetDefaultOutboundAction(profile, action int32) error {
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return err
}
defer firewallAPIRelease(u, fwPolicy)

_, err = oleutil.PutProperty(fwPolicy, "DefaultOutboundAction", profile, action)
return err
}

// FirewallSetDefaultInboundAction sets the default policy for incoming connections.
//// action must be NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK.
func FirewallSetDefaultInboundAction(profile, action int32) error {
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return err
}
defer firewallAPIRelease(u, fwPolicy)

_, err = oleutil.PutProperty(fwPolicy, "DefaultInboundAction", profile, action)
return err
}

// FirewallCurrentProfiles return which profiles are currently active.
// Every active interface can have it's own profile. F.e.: Public for Wifi,
// Domain for VPN, and Private for LAN. All at the same time.
Expand Down Expand Up @@ -636,7 +561,7 @@ func firewallParseProfiles(v int32) FWProfiles {
}

// firewallRuleAdd is universal function to add all kinds of rules.
func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) {
func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, profile int32, enabled, edgeTraversal bool) (bool, error) {

if name == "" {
return false, fmt.Errorf("empty FW Rule name, name is mandatory")
Expand Down Expand Up @@ -742,7 +667,7 @@ func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remo
if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil {
return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil {
if _, err := oleutil.PutProperty(fwRule, "Action", NET_FW_ACTION_ALLOW); err != nil {
return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err)
}
if edgeTraversal {
Expand Down Expand Up @@ -803,7 +728,7 @@ func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {

// firewallRulesEnum takes fwPolicy object and returns all objects which needs freeing and enum itself,
// which is used to enumerate rules. do not forget to:
// defer firewallRulesEnumRelease(ur, ep)
// defer firewallRulesEnumRealease(ur, ep)
func firewallRulesEnum(fwPolicy *ole.IDispatch) (*ole.VARIANT, *ole.VARIANT, *ole.IEnumVARIANT, error) {
unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules")
if err != nil {
Expand Down Expand Up @@ -832,17 +757,16 @@ func firewallRulesEnum(fwPolicy *ole.IDispatch) (*ole.VARIANT, *ole.VARIANT, *ol
}

// firewallRuleEnumRelease will free memory used by firewallRulesEnum.
func firewallRulesEnumRelease(unknownRules, enumProperty *ole.VARIANT, enum *ole.IEnumVARIANT) {
func firewallRulesEnumRealease(unknownRules, enumProperty *ole.VARIANT) {
enumProperty.Clear()
unknownRules.Clear()
enum.Release()
}

// firewallAPIInit initialize common fw api.
// then:
// dispatch firewallAPIRelease(u, fwp)
func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
comshim.Add(1)
ole.CoInitializeEx(0, ole.COINIT_APARTMENTTHREADED|ole.COINIT_SPEED_OVER_MEMORY)

unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
if err != nil {
Expand All @@ -863,5 +787,5 @@ func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
fwp.Release()
u.Release()
comshim.Done()
ole.CoUninitialize()
}