diff --git a/pkg/query_context/context.go b/pkg/query_context/context.go index 9fa3fd70c..8a7c32770 100644 --- a/pkg/query_context/context.go +++ b/pkg/query_context/context.go @@ -34,12 +34,20 @@ import ( // Context MUST be created using NewContext. // All Context funcs are not safe for concurrent use. type Context struct { + // id for this Context. Not for the dns query. This id is mainly for logging. + id uint32 startTime time.Time // when was this Context created q *dns.Msg - queryMeta QueryMeta - // id for this Context. Not for the dns query. This id is mainly for logging. - id uint32 + // EDNS0 options from client. + // Note: Q() is the query that is going to forward. Initially it has no + // option. All options from client are moved to QueryOpt. + // Plugins that responsible for handling edns0 option should + // check QueryOpt and add options into Q() on demand. + // This field should be read-only. + QueryOpt []dns.EDNS0 + // Server Info. This field should be read-only. + ServerMeta ServerMeta // Response. Might be nil. r *dns.Msg @@ -51,17 +59,16 @@ type Context struct { var contextUid atomic.Uint32 -type QueryMeta = server.QueryMeta +type ServerMeta = server.QueryMeta // NewContext creates a new query Context. // q is the query dns msg. It cannot be nil, or NewContext will panic. -func NewContext(q *dns.Msg, qm QueryMeta) *Context { +func NewContext(q *dns.Msg) *Context { if q == nil { panic("handler: query msg is nil") } ctx := &Context{ q: q, - queryMeta: qm, id: contextUid.Add(1), startTime: time.Now(), } @@ -74,11 +81,6 @@ func (ctx *Context) Q() *dns.Msg { return ctx.q } -// QueryMeta returns the meta data of the query. -func (ctx *Context) QueryMeta() QueryMeta { - return ctx.queryMeta -} - // R returns the response. It might be nil. func (ctx *Context) R() *dns.Msg { return ctx.r @@ -175,7 +177,7 @@ func (ctx *Context) DeleteMark(m uint32) { func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error { encoder.AddUint32("uqid", ctx.id) - if clientAddr := ctx.queryMeta.ClientAddr; clientAddr.IsValid() { + if clientAddr := ctx.ServerMeta.ClientAddr; clientAddr.IsValid() { zap.Stringer("client", clientAddr).AddTo(encoder) } diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go index cd93ea9b2..38c62fc2c 100644 --- a/pkg/server_handler/entry_handler.go +++ b/pkg/server_handler/entry_handler.go @@ -36,11 +36,21 @@ import ( const ( defaultQueryTimeout = time.Second * 5 - edns0Size = 1200 + edns0Size = 1220 ) var ( nopLogger = mlog.Nop() + + // options that can forward to upstream + queryForwardEDNS0Option = map[uint16]struct{}{ + dns.EDNS0SUBNET: {}, + } + + // options that useless for downstream + respRemoveEDNS0Option = map[uint16]struct{}{ + dns.EDNS0PADDING: {}, + } ) type EntryHandlerOpts struct { @@ -76,7 +86,7 @@ func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler { // ServeDNS implements server.Handler. // If entry returns an error, a SERVFAIL response will be returned. // If entry returns without a response, a REFUSED response will be returned. -func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte { +func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte { ddl := time.Now().Add(h.opts.QueryTimeout) ctx, cancel := context.WithDeadline(ctx, ddl) defer cancel() @@ -84,63 +94,80 @@ func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.Quer // Get udp size before exec plugins. It may be changed by plugins. queryUdpSize := getUDPSize(q) - // Enable edns0. We can handle this. + // Enable edns0. We can handle this. // This also helps to avoid udp->tcp fallback. - ends0Upgraded := false - if opt := q.IsEdns0(); opt == nil { - q.SetEdns0(edns0Size, false) - ends0Upgraded = true - } else { - opt.SetUDPSize(edns0Size) - } + queryOpt, queryEDNS0Upgraded := enableEdns0(q) + queryHasEDNS0 := !queryEDNS0Upgraded + + // Save all query opts for plugins. + allQueryOpts := queryOpt.Option + queryOpt.Option = filterQueryOptions2Forward(queryOpt) // exec entry - qCtx := query_context.NewContext(q, qInfo) + qCtx := query_context.NewContext(q) + qCtx.ServerMeta = serverMeta + qCtx.QueryOpt = allQueryOpts err := h.opts.Entry.Exec(ctx, qCtx) - respMsg := qCtx.R() + resp := qCtx.R() + if resp == nil { + resp = new(dns.Msg) + resp.SetReply(qCtx.Q()) + resp.Rcode = dns.RcodeRefused + } + + // May be nil + var respOpt *dns.OPT + if queryHasEDNS0 { + // RFC 3225 3 + // The DO bit of the query MUST be copied in the response. + // ... + // The absence of DNSSEC data in response to a query with the DO bit set + // MUST NOT be taken to mean no security information is available for + // that zone as the response may be forged or a non-forged response of + // an altered (DO bit cleared) query. + var newOpt bool + respOpt, newOpt = enableEdns0(resp) + if queryOpt.Do() { + setDo(respOpt, true) + } + if !newOpt { + filterUselessRespOpt(respOpt) + } + } else { + // Remove edns0 from resp if client didn't send it, as RFC 2671 required. + dnsutils.RemoveEDNS0(resp) + } + if err != nil { - respMsg = new(dns.Msg) - respMsg.SetReply(q) - respMsg.Rcode = dns.RcodeServerFailure - optForEde := tryPopOptForEde(q, respMsg) + resp.Rcode = dns.RcodeServerFailure switch v := err.(type) { case *edns0ede.EdeError: h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Object("ede", v)) - if optForEde != nil { - optForEde.Option = append(optForEde.Option, (*dns.EDNS0_EDE)(v)) + if respOpt != nil { + respOpt.Option = append(respOpt.Option, (*dns.EDNS0_EDE)(v)) } case *edns0ede.EdeErrors: h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Array("edes", v)) - if optForEde != nil { + if respOpt != nil { for _, ede := range ([]*dns.EDNS0_EDE)(*v) { - optForEde.Option = append(optForEde.Option, ede) + respOpt.Option = append(respOpt.Option, ede) } } default: h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err)) } - } else { - if respMsg == nil { - respMsg = new(dns.Msg) - respMsg.SetReply(qCtx.Q()) - respMsg.Rcode = dns.RcodeRefused - } } - respMsg.RecursionAvailable = true - // Client may not support edns0. - // Remove edns0 from resp, as RFC 2671 required. - if ends0Upgraded { - dnsutils.RemoveEDNS0(respMsg) - } + // We assume that our server is a forwarder. + resp.RecursionAvailable = true - if qInfo.FromUDP { - respMsg.Truncate(queryUdpSize) + if serverMeta.FromUDP { + resp.Truncate(queryUdpSize) } - payload, err := packMsgPayload(respMsg) + payload, err := packMsgPayload(resp) if err != nil { - h.opts.Logger.Error("internal err: failed to pack resp msg", zap.Error(err)) + h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(err)) return nil } return payload @@ -157,19 +184,45 @@ func getUDPSize(m *dns.Msg) int { return int(s) } -func tryPopOptForEde(q, resp *dns.Msg) *dns.OPT { - qOpt := q.IsEdns0() - if qOpt == nil { - return nil // query does not support edns0 +// returns a copy +func filterQueryOptions2Forward(opt *dns.OPT) []dns.EDNS0 { + var remainOpt []dns.EDNS0 + for _, op := range opt.Option { + if _, ok := queryForwardEDNS0Option[op.Option()]; ok { + remainOpt = append(remainOpt, op) + } } + return remainOpt +} - opt := resp.IsEdns0() +// modifies opt directly +func filterUselessRespOpt(opt *dns.OPT) { + remainOpt := opt.Option[0:0] + for _, op := range opt.Option { + if _, remove := respRemoveEDNS0Option[op.Option()]; !remove { + remainOpt = append(remainOpt, op) + } + } + opt.Option = remainOpt +} + +func enableEdns0(m *dns.Msg) (*dns.OPT, bool) { + opt := m.IsEdns0() + edns0Upgraded := false if opt == nil { opt = new(dns.OPT) opt.Hdr.Name = "." opt.Hdr.Rrtype = dns.TypeOPT - opt.SetUDPSize(qOpt.UDPSize()) - resp.Extra = append(resp.Extra, opt) + m.Extra = append(m.Extra, opt) + edns0Upgraded = true + } + opt.SetUDPSize(edns0Size) + return opt, edns0Upgraded +} + +func setDo(opt *dns.OPT, do bool) { + const doBit = 1 << 15 // DNSSEC OK + if do { + opt.Hdr.Ttl |= doBit } - return opt } diff --git a/plugin/executable/dual_selector/dual_selector_test.go b/plugin/executable/dual_selector/dual_selector_test.go index 524e73918..95196cb52 100644 --- a/plugin/executable/dual_selector/dual_selector_test.go +++ b/plugin/executable/dual_selector/dual_selector_test.go @@ -159,7 +159,7 @@ func TestSelector_Exec(t *testing.T) { q := new(dns.Msg) q.SetQuestion("example.", tt.qtype) - qCtx := query_context.NewContext(q, query_context.QueryMeta{}) + qCtx := query_context.NewContext(q) cw := sequence.NewChainWalker([]*sequence.ChainNode{{E: tt.next}}, nil) if err := s.Exec(context.Background(), qCtx, cw); (err != nil) != tt.wantErr { t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) diff --git a/plugin/executable/ipset/ipset_test.go b/plugin/executable/ipset/ipset_test.go index c5ad5088f..b2cfd292c 100644 --- a/plugin/executable/ipset/ipset_test.go +++ b/plugin/executable/ipset/ipset_test.go @@ -86,7 +86,7 @@ func Test_ipset(t *testing.T) { r.Answer = append(r.Answer, &dns.A{A: net.ParseIP("127.0.0.2")}) r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::1")}) r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::2")}) - qCtx := query_context.NewContext(q, query_context.QueryMeta{}) + qCtx := query_context.NewContext(q) qCtx.SetResponse(r) if err := p.Exec(context.Background(), qCtx); err != nil { t.Fatal(err) diff --git a/plugin/executable/rate_limiter/rate_limiter.go b/plugin/executable/rate_limiter/rate_limiter.go index 241f94750..a4ce902db 100644 --- a/plugin/executable/rate_limiter/rate_limiter.go +++ b/plugin/executable/rate_limiter/rate_limiter.go @@ -68,7 +68,7 @@ func New(args Args) *RateLimiter { } func (s *RateLimiter) Exec(ctx context.Context, qCtx *query_context.Context) error { - clientAddr := qCtx.QueryMeta().ClientAddr + clientAddr := qCtx.ServerMeta.ClientAddr if clientAddr.IsValid() { if !s.l.Allow(clientAddr) { qCtx.SetResponse(refuse(qCtx.Q())) diff --git a/plugin/executable/sequence/sequence_test.go b/plugin/executable/sequence/sequence_test.go index 16b1360ca..8c749601d 100644 --- a/plugin/executable/sequence/sequence_test.go +++ b/plugin/executable/sequence/sequence_test.go @@ -187,7 +187,7 @@ func Test_sequence_Exec(t *testing.T) { if err != nil { t.Fatal(err) } - qCtx := query_context.NewContext(new(dns.Msg), query_context.QueryMeta{}) + qCtx := query_context.NewContext(new(dns.Msg)) if err := s.Exec(context.Background(), qCtx); (err != nil) != tt.wantErr { t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/plugin/matcher/client_ip/client_ip_matcher.go b/plugin/matcher/client_ip/client_ip_matcher.go index b308b5df0..0e41355f2 100644 --- a/plugin/matcher/client_ip/client_ip_matcher.go +++ b/plugin/matcher/client_ip/client_ip_matcher.go @@ -39,7 +39,7 @@ func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { } func matchClientAddr(qCtx *query_context.Context, m netlist.Matcher) (bool, error) { - addr := qCtx.QueryMeta().ClientAddr + addr := qCtx.ServerMeta.ClientAddr if !addr.IsValid() { return false, nil } diff --git a/plugin/matcher/string_exp/string_exp.go b/plugin/matcher/string_exp/string_exp.go index 692f4e362..2534328da 100644 --- a/plugin/matcher/string_exp/string_exp.go +++ b/plugin/matcher/string_exp/string_exp.go @@ -176,9 +176,9 @@ func (op *opRegExp) MatchStr(s string) bool { } func getUrlPath(qCtx *query_context.Context) string { - return qCtx.QueryMeta().UrlPath + return qCtx.ServerMeta.UrlPath } func getServerName(qCtx *query_context.Context) string { - return qCtx.QueryMeta().ServerName + return qCtx.ServerMeta.ServerName } diff --git a/plugin/matcher/string_exp/string_exp_test.go b/plugin/matcher/string_exp/string_exp_test.go index 914019107..e810c133d 100644 --- a/plugin/matcher/string_exp/string_exp_test.go +++ b/plugin/matcher/string_exp/string_exp_test.go @@ -32,7 +32,8 @@ import ( func TestMatcher_Match(t *testing.T) { r := require.New(t) q := new(dns.Msg) - qc := query_context.NewContext(q, query_context.QueryMeta{UrlPath: "/dns-query", ServerName: "a.b.c"}) + qc := query_context.NewContext(q) + qc.ServerMeta = query_context.ServerMeta{UrlPath: "/dns-query", ServerName: "a.b.c"} os.Setenv("STRING_EXP_TEST", "abc") doTest := func(arg string, want bool) { diff --git a/plugin/padding_auth/padding_auth.go b/plugin/padding_auth/padding_auth.go index 004970080..4d885f4f8 100644 --- a/plugin/padding_auth/padding_auth.go +++ b/plugin/padding_auth/padding_auth.go @@ -49,16 +49,9 @@ type paddingAuth struct { } func (m *paddingAuth) Match(_ context.Context, qCtx *query_context.Context) (bool, error) { - q := qCtx.Q() - - opt := q.IsEdns0() - if opt == nil { - return false, nil - } - for i, optRR := range opt.Option { + for _, optRR := range qCtx.QueryOpt { if padding, ok := optRR.(*dns.EDNS0_PADDING); ok { if bytes.Equal(padding.Padding, m.key) { - opt.Option = append(opt.Option[:i], opt.Option[i+1:]...) return true, nil } }