diff --git a/firewall.go b/firewall.go index 9cdebeb..25e1469 100755 --- a/firewall.go +++ b/firewall.go @@ -1,16 +1,13 @@ -//go:build windows && amd64 // +build windows,amd64 package winapi import ( "fmt" - "log" "runtime" ole "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" - "github.com/scjalliance/comshim" ) // Firewall related API constants. @@ -105,12 +102,12 @@ func (r *FWRule) InProfiles() FWProfiles { // NET_FW_PROFILE2_CURRENT // adds rule to currently used FW Profile(-s) // NET_FW_PROFILE2_ALL // adds rule to all profiles // NET_FW_PROFILE2_DOMAIN|NET_FW_PROFILE2_PRIVATE // rule in Private and Domain profile -func FirewallRuleAdd(name, description, group, ports string, protocol, profile int32) (bool, error) { +func FirewallRuleAdd(name, description, group, ports, remotePorts string, protocol, direction, profile int32) (bool, error) { - if ports == "" { - return false, fmt.Errorf("empty FW Rule ports, it is mandatory") - } - return firewallRuleAdd(name, description, group, "", "", ports, "", "", "", "", protocol, 0, NET_FW_ACTION_ALLOW, profile, true, false) +// if ports == "" { +// return false, fmt.Errorf("empty FW Rule ports, it is mandatory") +// } + return firewallRuleAdd(name, description, group, "", "", ports, remotePorts, "", "", "", protocol, direction, profile, true, false) } // FirewallRuleAddApplication creates Inbound rule for given application. @@ -131,16 +128,16 @@ func FirewallRuleAdd(name, description, group, ports string, protocol, profile i // NET_FW_PROFILE2_CURRENT // adds rule to currently used FW Profile // NET_FW_PROFILE2_ALL // adds rule to all profiles // NET_FW_PROFILE2_DOMAIN|NET_FW_PROFILE2_PRIVATE // rule in Private and Domain profile -func FirewallRuleAddApplication(name, description, group, appPath string, profile int32) (bool, error) { +func FirewallRuleAddApplication(name, description, group, appPath, remotePorts string, protocol, direction, profile int32) (bool, error) { if appPath == "" { return false, fmt.Errorf("empty FW Rule appPath, it is mandatory") } - return firewallRuleAdd(name, description, group, appPath, "", "", "", "", "", "", 0, 0, NET_FW_ACTION_ALLOW, profile, true, false) + return firewallRuleAdd(name, description, group, appPath, "", "", remotePorts, "", "", "", protocol, direction, profile, true, false) } // FirewallRuleCreate is deprecated, use FirewallRuleAddApplication instead. func FirewallRuleCreate(name, description, group, appPath, port string, protocol int32) (bool, error) { - return firewallRuleAdd(name, description, group, appPath, "", port, "", "", "", "", protocol, 0, NET_FW_ACTION_ALLOW, NET_FW_PROFILE2_CURRENT, true, false) + return firewallRuleAdd(name, description, group, appPath, "", port, "", "", "", "", protocol, 0, NET_FW_PROFILE2_CURRENT, true, false) } // FirewallPingEnable creates Inbound ICMPv4 rule which allows to answer echo requests. @@ -159,7 +156,7 @@ func FirewallRuleCreate(name, description, group, appPath, port string, protocol // NET_FW_PROFILE2_ALL // adds rule to all profiles // NET_FW_PROFILE2_DOMAIN|NET_FW_PROFILE2_PRIVATE // rule in Private and Domain profile func FirewallPingEnable(name, description, group, remoteAddresses string, profile int32) (bool, error) { - return firewallRuleAdd(name, description, group, "", "", "", "", "", remoteAddresses, "8:*", NET_FW_IP_PROTOCOL_ICMPv4, 0, NET_FW_ACTION_ALLOW, profile, true, false) + return firewallRuleAdd(name, description, group, "", "", "", "", "", remoteAddresses, "8:*", NET_FW_IP_PROTOCOL_ICMPv4, 0, profile, true, false) } // FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters. @@ -170,7 +167,7 @@ func FirewallPingEnable(name, description, group, remoteAddresses string, profil func FirewallRuleAddAdvanced(rule FWRule) (bool, error) { return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName, rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes, - rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal) + rule.Protocol, rule.Direction, rule.Profiles, rule.Enabled, rule.EdgeTraversal) } // FirewallRuleDelete allows you to delete existing rule by name. @@ -243,7 +240,7 @@ func FirewallRuleGet(name string) (FWRule, error) { if err != nil { return rule, err } - defer firewallRulesEnumRelease(ur, ep, enum) + defer firewallRulesEnumRealease(ur, ep) for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) { if err != nil { @@ -270,7 +267,7 @@ func FirewallRuleGet(name string) (FWRule, error) { // FirewallRulesGet returns all rules defined in firewall. func FirewallRulesGet() ([]FWRule, error) { - rules := make([]FWRule, 0, 1024) + rules := make([]FWRule, 1000) u, fwPolicy, err := firewallAPIInit() if err != nil { @@ -282,7 +279,7 @@ func FirewallRulesGet() ([]FWRule, error) { if err != nil { return rules, err } - defer firewallRulesEnumRelease(ur, ep, enum) + defer firewallRulesEnumRealease(ur, ep) for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) { if err != nil { @@ -314,108 +311,94 @@ func firewallRuleParams(itemRaw ole.VARIANT) (FWRule, error) { item := itemRaw.ToIDispatch() defer item.Release() - var err error - rule.Name, err = getStringProperty(item, "Name") + name, err := oleutil.GetProperty(item, "Name") if err != nil { return rule, fmt.Errorf("failed to get Property (Name) of Rule") } - rule.Description, err = getStringProperty(item, "Description") + rule.Name = name.ToString() + description, err := oleutil.GetProperty(item, "Description") if err != nil { return rule, fmt.Errorf("failed to get Property (Description) of Rule %q", rule.Name) } - rule.ApplicationName, err = getStringProperty(item, "ApplicationName") + rule.Description = description.ToString() + applicationApplicationName, err := oleutil.GetProperty(item, "ApplicationName") if err != nil { return rule, fmt.Errorf("failed to get Property (ApplicationName) of Rule %q", rule.Name) } - rule.ServiceName, err = getStringProperty(item, "ServiceName") + rule.ApplicationName = applicationApplicationName.ToString() + serviceName, err := oleutil.GetProperty(item, "ServiceName") if err != nil { return rule, fmt.Errorf("failed to get Property (ServiceName) of Rule %q", rule.Name) } - rule.LocalPorts, err = getStringProperty(item, "LocalPorts") + rule.ServiceName = serviceName.ToString() + localPorts, err := oleutil.GetProperty(item, "LocalPorts") if err != nil { return rule, fmt.Errorf("failed to get Property (LocalPorts) of Rule %q", rule.Name) } - - rule.RemotePorts, err = getStringProperty(item, "RemotePorts") + rule.LocalPorts = localPorts.ToString() + remotePorts, err := oleutil.GetProperty(item, "RemotePorts") if err != nil { return rule, fmt.Errorf("failed to get Property (RemotePorts) of Rule %q", rule.Name) } - rule.LocalAddresses, err = getStringProperty(item, "LocalAddresses") + rule.RemotePorts = remotePorts.ToString() + localAddresses, err := oleutil.GetProperty(item, "LocalAddresses") if err != nil { return rule, fmt.Errorf("failed to get Property (LocalAddresses) of Rule %q", rule.Name) } - rule.RemoteAddresses, err = getStringProperty(item, "RemoteAddresses") + rule.LocalAddresses = localAddresses.ToString() + remoteAddresses, err := oleutil.GetProperty(item, "RemoteAddresses") if err != nil { return rule, fmt.Errorf("failed to get Property (RemoteAddresses) of Rule %q", rule.Name) } - rule.ICMPTypesAndCodes, err = getStringProperty(item, "ICMPTypesAndCodes") + rule.RemoteAddresses = remoteAddresses.ToString() + icmpTypesAndCodes, err := oleutil.GetProperty(item, "ICMPTypesAndCodes") if err != nil { return rule, fmt.Errorf("failed to get Property (ICMPTypesAndCodes) of Rule %q", rule.Name) } - rule.Grouping, err = getStringProperty(item, "Grouping") + rule.ICMPTypesAndCodes = icmpTypesAndCodes.ToString() + grouping, err := oleutil.GetProperty(item, "Grouping") if err != nil { return rule, fmt.Errorf("failed to get Property (Grouping) of Rule %q", rule.Name) } - rule.InterfaceTypes, err = getStringProperty(item, "InterfaceTypes") + rule.Grouping = grouping.ToString() + interfaceTypes, err := oleutil.GetProperty(item, "InterfaceTypes") if err != nil { return rule, fmt.Errorf("failed to get Property (InterfaceTypes) of Rule %q", rule.Name) } - rule.Protocol, err = getInt32Property(item, "Protocol") + rule.InterfaceTypes = interfaceTypes.ToString() + protocol, err := oleutil.GetProperty(item, "Protocol") if err != nil { return rule, fmt.Errorf("failed to get Property (Protocol) of Rule %q", rule.Name) } - rule.Direction, err = getInt32Property(item, "Direction") + rule.Protocol = protocol.Value().(int32) + direction, err := oleutil.GetProperty(item, "Direction") if err != nil { return rule, fmt.Errorf("failed to get Property (Direction) of Rule %q", rule.Name) } - rule.Action, err = getInt32Property(item, "Action") + rule.Direction = direction.Value().(int32) + action, err := oleutil.GetProperty(item, "Action") if err != nil { return rule, fmt.Errorf("failed to get Property (Action) of Rule %q", rule.Name) } - rule.Enabled, err = getBoolProperty(item, "Enabled") + rule.Action = action.Value().(int32) + enabled, err := oleutil.GetProperty(item, "Enabled") if err != nil { return rule, fmt.Errorf("failed to get Property (Enabled) of Rule %q", rule.Name) } - rule.EdgeTraversal, err = getBoolProperty(item, "EdgeTraversal") + rule.Enabled = enabled.Value().(bool) + edgeTraversal, err := oleutil.GetProperty(item, "EdgeTraversal") if err != nil { return rule, fmt.Errorf("failed to get Property (EdgeTraversal) of Rule %q", rule.Name) } - rule.Profiles, err = getInt32Property(item, "Profiles") + rule.EdgeTraversal = edgeTraversal.Value().(bool) + profiles, err := oleutil.GetProperty(item, "Profiles") if err != nil { return rule, fmt.Errorf("failed to get Property (Profiles) of Rule %q", rule.Name) } + rule.Profiles = profiles.Value().(int32) return rule, nil } -func getInt32Property(dispatch *ole.IDispatch, property string) (int32, error) { - val, err := oleutil.GetProperty(dispatch, property) - if err != nil { - log.Printf("failed to get dispatch property: %s \n", err.Error()) - return 0, err - } - defer val.Clear() - return val.Value().(int32), nil -} - -func getStringProperty(dispatch *ole.IDispatch, property string) (string, error) { - val, err := oleutil.GetProperty(dispatch, property) - if err != nil { - log.Printf("failed to get dispatch property: %s \n", err.Error()) - return "", err - } - defer val.Clear() - return val.ToString(), nil -} - -func getBoolProperty(dispatch *ole.IDispatch, property string) (bool, error) { - val, err := oleutil.GetProperty(dispatch, property) - if err != nil { - log.Printf("failed to get dispatch property: %s \n", err.Error()) - return false, err - } - defer val.Clear() - return val.Value().(bool), nil -} // FirewallGroupEnable allows to enable predefined firewall group. It is better // to not use names as "File and Printer Sharing" because they are localized and @@ -537,64 +520,6 @@ func FirewallDisable(profile int32) (bool, error) { return true, nil } -// FirewallGetDefaultOutboundAction checks if outgoing connections without matching rules are allowed or blocked. -// Returns either NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK. -func FirewallGetDefaultOutboundAction(profile int32) (int32, error) { - u, fwPolicy, err := firewallAPIInit() - if err != nil { - return 0, err - } - defer firewallAPIRelease(u, fwPolicy) - - action, err := oleutil.GetProperty(fwPolicy, "DefaultOutboundAction", profile) - if err != nil { - return 0, err - } - return action.Value().(int32), nil -} - -// FirewallGetDefaultInboundAction checks if incoming connections without matching rules are allowed or blocked. -// Returns either NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK. -func FirewallGetDefaultInboundAction(profile int32) (int32, error) { - u, fwPolicy, err := firewallAPIInit() - if err != nil { - return 0, err - } - defer firewallAPIRelease(u, fwPolicy) - - action, err := oleutil.GetProperty(fwPolicy, "DefaultInboundAction", profile) - if err != nil { - return 0, err - } - return action.Value().(int32), nil -} - -// FirewallSetDefaultOutboundAction sets the default policy for outgoing connections. -// action must be NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK. -func FirewallSetDefaultOutboundAction(profile, action int32) error { - u, fwPolicy, err := firewallAPIInit() - if err != nil { - return err - } - defer firewallAPIRelease(u, fwPolicy) - - _, err = oleutil.PutProperty(fwPolicy, "DefaultOutboundAction", profile, action) - return err -} - -// FirewallSetDefaultInboundAction sets the default policy for incoming connections. -//// action must be NET_FW_ACTION_ALLOW or NET_FW_ACTION_BLOCK. -func FirewallSetDefaultInboundAction(profile, action int32) error { - u, fwPolicy, err := firewallAPIInit() - if err != nil { - return err - } - defer firewallAPIRelease(u, fwPolicy) - - _, err = oleutil.PutProperty(fwPolicy, "DefaultInboundAction", profile, action) - return err -} - // FirewallCurrentProfiles return which profiles are currently active. // Every active interface can have it's own profile. F.e.: Public for Wifi, // Domain for VPN, and Private for LAN. All at the same time. @@ -636,7 +561,7 @@ func firewallParseProfiles(v int32) FWProfiles { } // firewallRuleAdd is universal function to add all kinds of rules. -func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) { +func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, profile int32, enabled, edgeTraversal bool) (bool, error) { if name == "" { return false, fmt.Errorf("empty FW Rule name, name is mandatory") @@ -742,7 +667,7 @@ func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remo if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil { return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err) } - if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil { + if _, err := oleutil.PutProperty(fwRule, "Action", NET_FW_ACTION_ALLOW); err != nil { return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err) } if edgeTraversal { @@ -803,7 +728,7 @@ func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) { // firewallRulesEnum takes fwPolicy object and returns all objects which needs freeing and enum itself, // which is used to enumerate rules. do not forget to: -// defer firewallRulesEnumRelease(ur, ep) +// defer firewallRulesEnumRealease(ur, ep) func firewallRulesEnum(fwPolicy *ole.IDispatch) (*ole.VARIANT, *ole.VARIANT, *ole.IEnumVARIANT, error) { unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules") if err != nil { @@ -832,17 +757,16 @@ func firewallRulesEnum(fwPolicy *ole.IDispatch) (*ole.VARIANT, *ole.VARIANT, *ol } // firewallRuleEnumRelease will free memory used by firewallRulesEnum. -func firewallRulesEnumRelease(unknownRules, enumProperty *ole.VARIANT, enum *ole.IEnumVARIANT) { +func firewallRulesEnumRealease(unknownRules, enumProperty *ole.VARIANT) { enumProperty.Clear() unknownRules.Clear() - enum.Release() } // firewallAPIInit initialize common fw api. // then: // dispatch firewallAPIRelease(u, fwp) func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) { - comshim.Add(1) + ole.CoInitializeEx(0, ole.COINIT_APARTMENTTHREADED|ole.COINIT_SPEED_OVER_MEMORY) unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2") if err != nil { @@ -863,5 +787,5 @@ func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) { func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) { fwp.Release() u.Release() - comshim.Done() + ole.CoUninitialize() }