Skip to content

Commit

Permalink
better edns0 handling
Browse files Browse the repository at this point in the history
only forward client edns0 options that plugins can handle

server_handler: remove padding

server_handler: set DO bit according to rfc
  • Loading branch information
IrineSistiana committed Sep 24, 2023
1 parent cd31bc2 commit b753642
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 72 deletions.
26 changes: 14 additions & 12 deletions pkg/query_context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ import (
// Context MUST be created using NewContext.
// All Context funcs are not safe for concurrent use.
type Context struct {
// id for this Context. Not for the dns query. This id is mainly for logging.
id uint32
startTime time.Time // when was this Context created
q *dns.Msg
queryMeta QueryMeta

// id for this Context. Not for the dns query. This id is mainly for logging.
id uint32
// EDNS0 options from client.
// Note: Q() is the query that is going to forward. Initially it has no
// option. All options from client are moved to QueryOpt.
// Plugins that responsible for handling edns0 option should
// check QueryOpt and add options into Q() on demand.
// This field should be read-only.
QueryOpt []dns.EDNS0
// Server Info. This field should be read-only.
ServerMeta ServerMeta

// Response. Might be nil.
r *dns.Msg
Expand All @@ -51,17 +59,16 @@ type Context struct {

var contextUid atomic.Uint32

type QueryMeta = server.QueryMeta
type ServerMeta = server.QueryMeta

// NewContext creates a new query Context.
// q is the query dns msg. It cannot be nil, or NewContext will panic.
func NewContext(q *dns.Msg, qm QueryMeta) *Context {
func NewContext(q *dns.Msg) *Context {
if q == nil {
panic("handler: query msg is nil")
}
ctx := &Context{
q: q,
queryMeta: qm,
id: contextUid.Add(1),
startTime: time.Now(),
}
Expand All @@ -74,11 +81,6 @@ func (ctx *Context) Q() *dns.Msg {
return ctx.q
}

// QueryMeta returns the meta data of the query.
func (ctx *Context) QueryMeta() QueryMeta {
return ctx.queryMeta
}

// R returns the response. It might be nil.
func (ctx *Context) R() *dns.Msg {
return ctx.r
Expand Down Expand Up @@ -175,7 +177,7 @@ func (ctx *Context) DeleteMark(m uint32) {
func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
encoder.AddUint32("uqid", ctx.id)

if clientAddr := ctx.queryMeta.ClientAddr; clientAddr.IsValid() {
if clientAddr := ctx.ServerMeta.ClientAddr; clientAddr.IsValid() {
zap.Stringer("client", clientAddr).AddTo(encoder)
}

Expand Down
141 changes: 97 additions & 44 deletions pkg/server_handler/entry_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,21 @@ import (

const (
defaultQueryTimeout = time.Second * 5
edns0Size = 1200
edns0Size = 1220
)

var (
nopLogger = mlog.Nop()

// options that can forward to upstream
queryForwardEDNS0Option = map[uint16]struct{}{
dns.EDNS0SUBNET: {},
}

// options that useless for downstream
respRemoveEDNS0Option = map[uint16]struct{}{
dns.EDNS0PADDING: {},
}
)

type EntryHandlerOpts struct {
Expand Down Expand Up @@ -76,71 +86,88 @@ func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler {
// ServeDNS implements server.Handler.
// If entry returns an error, a SERVFAIL response will be returned.
// If entry returns without a response, a REFUSED response will be returned.
func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
ddl := time.Now().Add(h.opts.QueryTimeout)
ctx, cancel := context.WithDeadline(ctx, ddl)
defer cancel()

// Get udp size before exec plugins. It may be changed by plugins.
queryUdpSize := getUDPSize(q)

// Enable edns0. We can handle this.
// Enable edns0. We can handle this.
// This also helps to avoid udp->tcp fallback.
ends0Upgraded := false
if opt := q.IsEdns0(); opt == nil {
q.SetEdns0(edns0Size, false)
ends0Upgraded = true
} else {
opt.SetUDPSize(edns0Size)
}
queryOpt, queryEDNS0Upgraded := enableEdns0(q)
queryHasEDNS0 := !queryEDNS0Upgraded

// Save all query opts for plugins.
allQueryOpts := queryOpt.Option
queryOpt.Option = filterQueryOptions2Forward(queryOpt)

// exec entry
qCtx := query_context.NewContext(q, qInfo)
qCtx := query_context.NewContext(q)
qCtx.ServerMeta = serverMeta
qCtx.QueryOpt = allQueryOpts
err := h.opts.Entry.Exec(ctx, qCtx)
respMsg := qCtx.R()
resp := qCtx.R()
if resp == nil {
resp = new(dns.Msg)
resp.SetReply(qCtx.Q())
resp.Rcode = dns.RcodeRefused
}

// May be nil
var respOpt *dns.OPT
if queryHasEDNS0 {
// RFC 3225 3
// The DO bit of the query MUST be copied in the response.
// ...
// The absence of DNSSEC data in response to a query with the DO bit set
// MUST NOT be taken to mean no security information is available for
// that zone as the response may be forged or a non-forged response of
// an altered (DO bit cleared) query.
var newOpt bool
respOpt, newOpt = enableEdns0(resp)
if queryOpt.Do() {
setDo(respOpt, true)
}
if !newOpt {
filterUselessRespOpt(respOpt)
}
} else {
// Remove edns0 from resp if client didn't send it, as RFC 2671 required.
dnsutils.RemoveEDNS0(resp)
}

if err != nil {
respMsg = new(dns.Msg)
respMsg.SetReply(q)
respMsg.Rcode = dns.RcodeServerFailure
optForEde := tryPopOptForEde(q, respMsg)
resp.Rcode = dns.RcodeServerFailure
switch v := err.(type) {
case *edns0ede.EdeError:
h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Object("ede", v))
if optForEde != nil {
optForEde.Option = append(optForEde.Option, (*dns.EDNS0_EDE)(v))
if respOpt != nil {
respOpt.Option = append(respOpt.Option, (*dns.EDNS0_EDE)(v))
}
case *edns0ede.EdeErrors:
h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Array("edes", v))
if optForEde != nil {
if respOpt != nil {
for _, ede := range ([]*dns.EDNS0_EDE)(*v) {
optForEde.Option = append(optForEde.Option, ede)
respOpt.Option = append(respOpt.Option, ede)
}
}
default:
h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err))
}
} else {
if respMsg == nil {
respMsg = new(dns.Msg)
respMsg.SetReply(qCtx.Q())
respMsg.Rcode = dns.RcodeRefused
}
}
respMsg.RecursionAvailable = true

// Client may not support edns0.
// Remove edns0 from resp, as RFC 2671 required.
if ends0Upgraded {
dnsutils.RemoveEDNS0(respMsg)
}
// We assume that our server is a forwarder.
resp.RecursionAvailable = true

if qInfo.FromUDP {
respMsg.Truncate(queryUdpSize)
if serverMeta.FromUDP {
resp.Truncate(queryUdpSize)
}

payload, err := packMsgPayload(respMsg)
payload, err := packMsgPayload(resp)
if err != nil {
h.opts.Logger.Error("internal err: failed to pack resp msg", zap.Error(err))
h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(err))
return nil
}
return payload
Expand All @@ -157,19 +184,45 @@ func getUDPSize(m *dns.Msg) int {
return int(s)
}

func tryPopOptForEde(q, resp *dns.Msg) *dns.OPT {
qOpt := q.IsEdns0()
if qOpt == nil {
return nil // query does not support edns0
// returns a copy
func filterQueryOptions2Forward(opt *dns.OPT) []dns.EDNS0 {
var remainOpt []dns.EDNS0
for _, op := range opt.Option {
if _, ok := queryForwardEDNS0Option[op.Option()]; ok {
remainOpt = append(remainOpt, op)
}
}
return remainOpt
}

opt := resp.IsEdns0()
// modifies opt directly
func filterUselessRespOpt(opt *dns.OPT) {
remainOpt := opt.Option[0:0]
for _, op := range opt.Option {
if _, remove := respRemoveEDNS0Option[op.Option()]; !remove {
remainOpt = append(remainOpt, op)
}
}
opt.Option = remainOpt
}

func enableEdns0(m *dns.Msg) (*dns.OPT, bool) {
opt := m.IsEdns0()
edns0Upgraded := false
if opt == nil {
opt = new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.SetUDPSize(qOpt.UDPSize())
resp.Extra = append(resp.Extra, opt)
m.Extra = append(m.Extra, opt)
edns0Upgraded = true
}
opt.SetUDPSize(edns0Size)
return opt, edns0Upgraded
}

func setDo(opt *dns.OPT, do bool) {
const doBit = 1 << 15 // DNSSEC OK
if do {
opt.Hdr.Ttl |= doBit
}
return opt
}
2 changes: 1 addition & 1 deletion plugin/executable/dual_selector/dual_selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestSelector_Exec(t *testing.T) {

q := new(dns.Msg)
q.SetQuestion("example.", tt.qtype)
qCtx := query_context.NewContext(q, query_context.QueryMeta{})
qCtx := query_context.NewContext(q)
cw := sequence.NewChainWalker([]*sequence.ChainNode{{E: tt.next}}, nil)
if err := s.Exec(context.Background(), qCtx, cw); (err != nil) != tt.wantErr {
t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
2 changes: 1 addition & 1 deletion plugin/executable/ipset/ipset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func Test_ipset(t *testing.T) {
r.Answer = append(r.Answer, &dns.A{A: net.ParseIP("127.0.0.2")})
r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::1")})
r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::2")})
qCtx := query_context.NewContext(q, query_context.QueryMeta{})
qCtx := query_context.NewContext(q)
qCtx.SetResponse(r)
if err := p.Exec(context.Background(), qCtx); err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion plugin/executable/rate_limiter/rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func New(args Args) *RateLimiter {
}

func (s *RateLimiter) Exec(ctx context.Context, qCtx *query_context.Context) error {
clientAddr := qCtx.QueryMeta().ClientAddr
clientAddr := qCtx.ServerMeta.ClientAddr
if clientAddr.IsValid() {
if !s.l.Allow(clientAddr) {
qCtx.SetResponse(refuse(qCtx.Q()))
Expand Down
2 changes: 1 addition & 1 deletion plugin/executable/sequence/sequence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func Test_sequence_Exec(t *testing.T) {
if err != nil {
t.Fatal(err)
}
qCtx := query_context.NewContext(new(dns.Msg), query_context.QueryMeta{})
qCtx := query_context.NewContext(new(dns.Msg))
if err := s.Exec(context.Background(), qCtx); (err != nil) != tt.wantErr {
t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr)
}
Expand Down
2 changes: 1 addition & 1 deletion plugin/matcher/client_ip/client_ip_matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) {
}

func matchClientAddr(qCtx *query_context.Context, m netlist.Matcher) (bool, error) {
addr := qCtx.QueryMeta().ClientAddr
addr := qCtx.ServerMeta.ClientAddr
if !addr.IsValid() {
return false, nil
}
Expand Down
4 changes: 2 additions & 2 deletions plugin/matcher/string_exp/string_exp.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ func (op *opRegExp) MatchStr(s string) bool {
}

func getUrlPath(qCtx *query_context.Context) string {
return qCtx.QueryMeta().UrlPath
return qCtx.ServerMeta.UrlPath
}

func getServerName(qCtx *query_context.Context) string {
return qCtx.QueryMeta().ServerName
return qCtx.ServerMeta.ServerName
}
3 changes: 2 additions & 1 deletion plugin/matcher/string_exp/string_exp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import (
func TestMatcher_Match(t *testing.T) {
r := require.New(t)
q := new(dns.Msg)
qc := query_context.NewContext(q, query_context.QueryMeta{UrlPath: "/dns-query", ServerName: "a.b.c"})
qc := query_context.NewContext(q)
qc.ServerMeta = query_context.ServerMeta{UrlPath: "/dns-query", ServerName: "a.b.c"}
os.Setenv("STRING_EXP_TEST", "abc")

doTest := func(arg string, want bool) {
Expand Down
9 changes: 1 addition & 8 deletions plugin/padding_auth/padding_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,9 @@ type paddingAuth struct {
}

func (m *paddingAuth) Match(_ context.Context, qCtx *query_context.Context) (bool, error) {
q := qCtx.Q()

opt := q.IsEdns0()
if opt == nil {
return false, nil
}
for i, optRR := range opt.Option {
for _, optRR := range qCtx.QueryOpt {
if padding, ok := optRR.(*dns.EDNS0_PADDING); ok {
if bytes.Equal(padding.Padding, m.key) {
opt.Option = append(opt.Option[:i], opt.Option[i+1:]...)
return true, nil
}
}
Expand Down

0 comments on commit b753642

Please sign in to comment.