From 6ddeb7caed5de2611384502ec970cfaee08b4e8c Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Thu, 15 Aug 2024 08:47:35 +0200 Subject: [PATCH] Implement cthelper object expr (#268) --- expr/ct.go | 56 +++++++++++++++++++++++++++++++++++++++++++++ expr/expr.go | 2 ++ nftables_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ obj.go | 2 +- 4 files changed, 118 insertions(+), 1 deletion(-) diff --git a/expr/ct.go b/expr/ct.go index afec90b..7ba113a 100644 --- a/expr/ct.go +++ b/expr/ct.go @@ -138,3 +138,59 @@ func (e *Ct) unmarshal(fam byte, data []byte) error { } return ad.Err() } + +type CtHelper struct { + Name string + L3Proto uint16 + L4Proto uint8 +} + +func (c *CtHelper) marshal(fam byte) ([]byte, error) { + exprData, err := c.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("cthelper\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (c *CtHelper) marshalData(fam byte) ([]byte, error) { + exprData := []netlink.Attribute{ + {Type: unix.NFTA_CT_HELPER_NAME, Data: []byte(c.Name)}, + } + + if c.L3Proto != 0 { + exprData = append(exprData, netlink.Attribute{ + Type: unix.NFTA_CT_HELPER_L3PROTO, Data: binaryutil.BigEndian.PutUint16(c.L3Proto), + }) + } + if c.L4Proto != 0 { + exprData = append(exprData, netlink.Attribute{ + Type: unix.NFTA_CT_HELPER_L4PROTO, Data: []byte{c.L4Proto}, + }) + } + + return netlink.MarshalAttributes(exprData) +} + +func (c *CtHelper) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CT_HELPER_NAME: + c.Name = ad.String() + case unix.NFTA_CT_HELPER_L3PROTO: + c.L3Proto = ad.Uint16() + case unix.NFTA_CT_HELPER_L4PROTO: + c.L4Proto = ad.Uint8() + } + } + return ad.Err() +} diff --git a/expr/expr.go b/expr/expr.go index dd898ec..a8a9197 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -197,6 +197,8 @@ func exprFromName(name string) Any { e = &Masq{} case "hash": e = &Hash{} + case "cthelper": + e = &CtHelper{} } return e } diff --git a/nftables_test.go b/nftables_test.go index cc16504..357cfd3 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1383,6 +1383,65 @@ func TestCt(t *testing.T) { } } +func TestCtHelper(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + cthelp1 := conn.AddObj(&nftables.NamedObj{ + Table: table, + Name: "ftp-standard", + Type: nftables.ObjTypeCtHelper, + Obj: &expr.CtHelper{ + Name: "ftp", + L4Proto: unix.IPPROTO_TCP, + L3Proto: unix.NFPROTO_IPV4, + }, + }) + + if err := conn.Flush(); err != nil { + t.Fatalf(err.Error()) + } + + obj1, err := conn.GetObject(cthelp1) + if err != nil { + t.Errorf("c.GetObject(cthelp1) failed: %v failed", err) + } + + helper, ok := obj1.(*nftables.NamedObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1) + } + + if got, want := helper.Name, "ftp-standard"; got != want { + t.Fatalf("unexpected counter name: got %s, want %s", got, want) + } + + if _, err = conn.ResetObject(cthelp1); err != nil { + t.Errorf("c.ResetObjects(cthelp1) failed: %v failed", err) + } + + obj1, err = conn.GetObject(cthelp1) + if err != nil { + t.Errorf("c.GetObject(cthelp1) failed: %v failed", err) + } + + help := obj1.(*nftables.NamedObj).Obj.(*expr.CtHelper) + if got, want := help.L4Proto, uint8(unix.IPPROTO_TCP); got != want { + t.Errorf("unexpected l4proto number: got %d, want %d", got, want) + } + + if got, want := help.L3Proto, uint16(unix.NFPROTO_IPV4); got != want { + t.Errorf("unexpected l3proto number: got %d, want %d", got, want) + } +} + func TestCtSet(t *testing.T) { want := [][]byte{ // batch begin diff --git a/obj.go b/obj.go index 84ce00f..421a87c 100644 --- a/obj.go +++ b/obj.go @@ -51,7 +51,7 @@ var objByObjTypeMagic = map[ObjType]string{ ObjTypeQuota: "quota", ObjTypeLimit: "limit", ObjTypeConnLimit: "connlimit", - ObjTypeCtHelper: "cthelper", // not implemented in expr + ObjTypeCtHelper: "cthelper", ObjTypeTunnel: "tunnel", // not implemented in expr ObjTypeCtTimeout: "cttimeout", // not implemented in expr ObjTypeSecMark: "secmark", // not implemented in expr