diff --git a/go.mod b/go.mod index b03c710d0f7..f33267f6ef0 100644 --- a/go.mod +++ b/go.mod @@ -169,6 +169,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 + github.com/thejerf/suture/v4 v4.0.5 github.com/u-root/u-root v0.14.0 github.com/ulikunitz/xz v0.5.12 github.com/vmware/vmw-guestinfo v0.0.0-20220317130741-510905f0efa3 diff --git a/go.sum b/go.sum index 598dd5e1bae..ebf292ea9f3 100644 --- a/go.sum +++ b/go.sum @@ -724,6 +724,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= +github.com/thejerf/suture/v4 v4.0.5 h1:F1E/4FZwXWqvlWDKEUo6/ndLtxGAUzMmNqkrMknZbAA= +github.com/thejerf/suture/v4 v4.0.5/go.mod h1:gu9Y4dXNUWFrByqRt30Rm9/UZ0wzRSt9AJS6xu/ZGxU= github.com/u-root/u-root v0.14.0 h1:Ka4T10EEML7dQ5XDvO9c3MBN8z4nuSnGjcd1jmU2ivg= github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1WMluqE= github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= diff --git a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go index 33ec6ac6e38..54fca917d57 100644 --- a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go +++ b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go @@ -5,24 +5,21 @@ package network import ( + "cmp" "context" - "errors" "fmt" - "net" + "iter" "net/netip" - "slices" - "strings" "sync" - "time" "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/cosi-project/runtime/pkg/controller" "github.com/cosi-project/runtime/pkg/safe" "github.com/cosi-project/runtime/pkg/state" - dnssrv "github.com/miekg/dns" + "github.com/hashicorp/go-multierror" "github.com/siderolabs/gen/optional" - "github.com/siderolabs/gen/pair" "github.com/siderolabs/gen/xiter" + "github.com/thejerf/suture/v4" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -36,13 +33,9 @@ type DNSResolveCacheController struct { State state.State Logger *zap.Logger - mx sync.Mutex - handler *dns.Handler - nodeHandler *dns.NodeHandler - rootHandler dnssrv.Handler - runners map[runnerConfig]pair.Pair[func(), <-chan struct{}] - reconcile chan struct{} - originalCtx context.Context //nolint:containedctx + mx sync.Mutex + manager *dns.Manager + reconcile chan struct{} } // Name implements controller.Controller interface. @@ -74,15 +67,21 @@ func (ctrl *DNSResolveCacheController) Outputs() []controller.Output { } // Run implements controller.Controller interface. 
-// -//nolint:gocyclo,cyclop -func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Runtime, logger *zap.Logger) error { +func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Runtime, _ *zap.Logger) error { ctrl.init(ctx) ctrl.mx.Lock() defer ctrl.mx.Unlock() - defer ctrl.stopRunners(ctx, false) + defer func() { + for err := range ctrl.manager.ClearAll(ctx.Err() == nil) { + ctrl.Logger.Error("error stopping dns runner", zap.Error(err)) + } + + if ctx.Err() != nil { + ctrl.Logger.Info("manager finished", zap.Error(<-ctrl.manager.Done())) + } + }() for { select { @@ -90,263 +89,128 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run return nil case <-r.EventCh(): case <-ctrl.reconcile: - for cfg, stop := range ctrl.runners { - select { - default: - continue - case <-stop.F2: - } - - stop.F1() - delete(ctrl.runners, cfg) - } - } - - cfg, err := safe.ReaderGetByID[*network.HostDNSConfig](ctx, r, network.HostDNSConfigID) - if err != nil { - if state.IsNotFoundError(err) { - continue - } - - return fmt.Errorf("error getting host dns config: %w", err) } - r.StartTrackingOutputs() - - if !cfg.TypedSpec().Enabled { - ctrl.stopRunners(ctx, true) - - if err = safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { - return fmt.Errorf("error cleaning up dns status on disable: %w", err) - } - - continue - } - - ctrl.nodeHandler.SetEnabled(cfg.TypedSpec().ResolveMemberNames) - - touchedRunners := make(map[runnerConfig]struct{}, len(ctrl.runners)) - - for _, addr := range cfg.TypedSpec().ListenAddresses { - for _, netwk := range []string{"udp", "tcp"} { - runnerCfg := runnerConfig{net: netwk, addr: addr} - - if _, ok := ctrl.runners[runnerCfg]; !ok { - runner, rErr := newDNSRunner(runnerCfg, ctrl.rootHandler, ctrl.Logger, cfg.TypedSpec().ServiceHostDNSAddress.IsValid()) - if rErr != nil { - return fmt.Errorf("error creating dns runner: %w", rErr) - } - - ctrl.runners[runnerCfg] = pair.MakePair(runner.Start(ctrl.handleDone(ctx, logger))) - } - - if err = ctrl.writeDNSStatus(ctx, r, runnerCfg); err != nil { - return fmt.Errorf("error writing dns status: %w", err) - } - - touchedRunners[runnerCfg] = struct{}{} - } - } - - for runnerCfg, stop := range ctrl.runners { - if _, ok := touchedRunners[runnerCfg]; !ok { - stop.F1() - delete(ctrl.runners, runnerCfg) - - continue - } - } - - upstreams, err := safe.ReaderListAll[*network.DNSUpstream](ctx, r) - if err != nil { - return fmt.Errorf("error getting resolver status: %w", err) - } - - prxs := xiter.Map( - upstreams.All(), - // We are using iterator here to preserve finalizer on - func(upstream *network.DNSUpstream) *proxy.Proxy { - return upstream.TypedSpec().Value.Conn.Proxy().(*proxy.Proxy) - }) - - if ctrl.handler.SetProxy(prxs) { - ctrl.Logger.Info("updated dns server nameservers", zap.Array("addrs", addrsArr(upstreams))) - } - - if err = safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { - return fmt.Errorf("error cleaning up dns status: %w", err) + if err := ctrl.run(ctx, r); err != nil { + return err } } } -func (ctrl *DNSResolveCacheController) writeDNSStatus(ctx context.Context, r controller.Runtime, config runnerConfig) error { - return safe.WriterModify(ctx, r, network.NewDNSResolveCache(fmt.Sprintf("%s-%s", config.net, config.addr)), func(drc *network.DNSResolveCache) error { - drc.TypedSpec().Status = "running" +func (ctrl *DNSResolveCacheController) run(ctx context.Context, r controller.Runtime) (resErr error) { + r.StartTrackingOutputs() 
- return nil - }) -} - -func (ctrl *DNSResolveCacheController) init(ctx context.Context) { - if ctrl.runners != nil { - if ctrl.originalCtx != ctx { - // This should not happen, but if it does, it's a bug. - panic("DNSResolveCacheController is called with a different context") + defer func() { + if err := safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { + resErr = cmp.Or(resErr, fmt.Errorf("error cleaning up dns resolve cache: %w", err)) } + }() - return - } + cfg, err := safe.ReaderGetByID[*network.HostDNSConfig](ctx, r, network.HostDNSConfigID) - ctrl.originalCtx = ctx - ctrl.handler = dns.NewHandler(ctrl.Logger) - ctrl.nodeHandler = dns.NewNodeHandler(ctrl.handler, &stateMapper{state: ctrl.State}, ctrl.Logger) - ctrl.rootHandler = dns.NewCache(ctrl.nodeHandler, ctrl.Logger) - ctrl.runners = map[runnerConfig]pair.Pair[func(), <-chan struct{}]{} - ctrl.reconcile = make(chan struct{}, 1) - - // Ensure we stop all runners when the context is canceled, no matter where we are currently. - // For example if we are in Controller runtime sleeping after error and ctx is canceled, we should stop all runners - // but, we will never call Run method again, so we need to ensure this happens regardless of the current state. - context.AfterFunc(ctx, func() { - ctrl.mx.Lock() - defer ctrl.mx.Unlock() - - ctrl.stopRunners(ctx, true) - }) -} - -func (ctrl *DNSResolveCacheController) stopRunners(ctx context.Context, ignoreCtx bool) { - if !ignoreCtx && ctx.Err() == nil { - // context not yet canceled, preserve runners, cache and handler - return - } - - for _, stop := range ctrl.runners { - stop.F1() + switch { + case state.IsNotFoundError(err): + return nil + case err != nil: + return fmt.Errorf("error getting host dns config: %w", err) } - clear(ctrl.runners) - - ctrl.handler.Stop() -} + ctrl.manager.AllowNodeResolving(cfg.TypedSpec().ResolveMemberNames) -func (ctrl *DNSResolveCacheController) handleDone(ctx context.Context, logger *zap.Logger) func(err error) { - return func(err error) { - if ctx.Err() != nil { - if err != nil && !errors.Is(err, net.ErrClosed) { - logger.Error("controller is closing, but error running dns server", zap.Error(err)) - } + if !cfg.TypedSpec().Enabled { + return foldErrors(ctrl.manager.ClearAll(false)) + } - return - } + pairs := allAddressPairs(cfg.TypedSpec().ListenAddresses) + forwardKubeDNSToHost := cfg.TypedSpec().ServiceHostDNSAddress.IsValid() - if err != nil { - logger.Error("error running dns server", zap.Error(err)) + for runCfg, runErr := range ctrl.manager.RunAll(pairs, forwardKubeDNSToHost) { + if runErr != nil { + return fmt.Errorf("error updating dns runner %v: %w", runCfg, runErr) } - select { - case ctrl.reconcile <- struct{}{}: - default: + if err = ctrl.writeDNSStatus(ctx, r, runCfg); err != nil { + return fmt.Errorf("error writing dns status: %w", err) } } -} -type runnerConfig struct { - net string - addr netip.AddrPort -} - -func newDNSRunner(cfg runnerConfig, rootHandler dnssrv.Handler, logger *zap.Logger, forwardEnabled bool) (*dns.Server, error) { - if cfg.addr.Addr().Is6() { - cfg.net += "6" + upstreams, err := safe.ReaderListAll[*network.DNSUpstream](ctx, r) + if err != nil { + return fmt.Errorf("error getting resolver status: %w", err) } - logger = logger.With(zap.String("net", cfg.net), zap.Stringer("addr", cfg.addr)) + prxs := xiter.Map( + upstreams.All(), + // We are using iterator here to preserve finalizer on + func(upstream *network.DNSUpstream) *proxy.Proxy { + return upstream.TypedSpec().Value.Conn.Proxy().(*proxy.Proxy) 
+ }) - var serverOpts dns.ServerOptions - - controlFn, ctrlErr := dns.MakeControl(cfg.net, forwardEnabled) - if ctrlErr != nil { - return nil, fmt.Errorf("error creating %q control function: %w", cfg.net, ctrlErr) + if ctrl.manager.SetUpstreams(prxs) { + ctrl.Logger.Info("updated dns server nameservers", zap.Array("addrs", addrsArr(upstreams))) } - switch cfg.net { - case "udp", "udp6": - packetConn, err := dns.NewUDPPacketConn(cfg.net, cfg.addr.String(), controlFn) - if err != nil { - return nil, fmt.Errorf("error creating %q packet conn: %w", cfg.net, err) - } - - serverOpts = dns.ServerOptions{ - PacketConn: packetConn, - Handler: rootHandler, - Logger: logger, - } + return nil +} - case "tcp", "tcp6": - listener, err := dns.NewTCPListener(cfg.net, cfg.addr.String(), controlFn) - if err != nil { - return nil, fmt.Errorf("error creating %q listener: %w", cfg.net, err) - } +func foldErrors(it iter.Seq[error]) error { + var multiErr *multierror.Error - serverOpts = dns.ServerOptions{ - Listener: listener, - Handler: rootHandler, - ReadTimeout: 3 * time.Second, - WriteTimeout: 5 * time.Second, - IdleTimeout: func() time.Duration { return 10 * time.Second }, - MaxTCPQueries: -1, - Logger: logger, - } + for err := range it { + multiErr = multierror.Append(multiErr, err) } - return dns.NewServer(serverOpts), nil + return multiErr.ErrorOrNil() } -type stateMapper struct { - state state.State -} - -func (s *stateMapper) ResolveAddr(ctx context.Context, qType uint16, name string) []netip.Addr { - name = strings.TrimRight(name, ".") +func (ctrl *DNSResolveCacheController) writeDNSStatus(ctx context.Context, r controller.Runtime, config dns.AddressPair) error { + res := network.NewDNSResolveCache(fmt.Sprintf("%s-%s", config.Network, config.Addr)) - list, err := safe.ReaderListAll[*cluster.Member](ctx, s.state) - if err != nil { - return nil - } + return safe.WriterModify(ctx, r, res, func(drc *network.DNSResolveCache) error { + drc.TypedSpec().Status = "running" - elem, ok := list.Find(func(res *cluster.Member) bool { - return fqdnMatch(name, res.TypedSpec().Hostname) || fqdnMatch(name, res.Metadata().ID()) - }) - if !ok { return nil - } - - result := slices.DeleteFunc(slices.Clone(elem.TypedSpec().Addresses), func(addr netip.Addr) bool { - return !((qType == dnssrv.TypeA && addr.Is4()) || (qType == dnssrv.TypeAAAA && addr.Is6())) }) +} - if len(result) == 0 { - return nil +func (ctrl *DNSResolveCacheController) init(ctx context.Context) { + if ctrl.manager == nil { + ctrl.manager = dns.NewManager(&memberReader{st: ctrl.State}, ctrl.eventHook, ctrl.Logger) + + // Ensure we stop all runners when the context is canceled, no matter where we are currently. + // For example if we are in Controller runtime sleeping after error and ctx is canceled, we should stop all runners + // but, we will never call Run method again, so we need to ensure this happens regardless of the current state. 
+ context.AfterFunc(ctx, func() { + ctrl.mx.Lock() + defer ctrl.mx.Unlock() + + for err := range ctrl.manager.ClearAll(false) { + ctrl.Logger.Error("error ctx stopping dns runner", zap.Error(err)) + } + }) } - return result + ctrl.manager.ServeBackground(ctx) } -func fqdnMatch(what, where string) bool { - what = strings.TrimRight(what, ".") - where = strings.TrimRight(where, ".") +func (ctrl *DNSResolveCacheController) eventHook(event suture.Event) { + ctrl.Logger.Info("dns-resolve-cache-runners event", zap.String("event", event.String())) - if what == where { - return true + select { + case ctrl.reconcile <- struct{}{}: + default: } +} + +type memberReader struct{ st state.State } - first, _, found := strings.Cut(where, ".") - if !found { - return false +func (m *memberReader) ReadMembers(ctx context.Context) (iter.Seq[*cluster.Member], error) { + list, err := safe.ReaderListAll[*cluster.Member](ctx, m.st) + if err != nil { + return nil, err } - return what == first + return list.All(), nil } type addrsArr safe.List[*network.DNSUpstream] @@ -360,3 +224,18 @@ func (a addrsArr) MarshalLogArray(encoder zapcore.ArrayEncoder) error { return nil } + +func allAddressPairs(addresses []netip.AddrPort) iter.Seq[dns.AddressPair] { + return func(yield func(dns.AddressPair) bool) { + for _, addr := range addresses { + for _, netwk := range []string{"udp", "tcp"} { + if !yield(dns.AddressPair{ + Network: netwk, + Addr: addr, + }) { + return + } + } + } + } +} diff --git a/internal/app/machined/pkg/controllers/network/operator_spec_test.go b/internal/app/machined/pkg/controllers/network/operator_spec_test.go index 9aeab1fe850..3190c8661bb 100644 --- a/internal/app/machined/pkg/controllers/network/operator_spec_test.go +++ b/internal/app/machined/pkg/controllers/network/operator_spec_test.go @@ -311,7 +311,8 @@ func (suite *OperatorSpecSuite) TestScheduling() { retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( func() error { return suite.assertRunning( - []string{"dhcp4/eth0", "vip/eth0"}, func(op *mockOperator) error { + []string{"dhcp4/eth0", "vip/eth0"}, + func(op *mockOperator) error { switch op.spec.Operator { //nolint:exhaustive case network.OperatorDHCP4: suite.Assert().EqualValues(1024, op.spec.DHCP4.RouteMetric) @@ -339,7 +340,8 @@ func (suite *OperatorSpecSuite) TestScheduling() { retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( func() error { return suite.assertRunning( - []string{"dhcp4/eth0", "vip/eth0"}, func(op *mockOperator) error { + []string{"dhcp4/eth0", "vip/eth0"}, + func(op *mockOperator) error { switch op.spec.Operator { //nolint:exhaustive case network.OperatorDHCP4: suite.Assert().EqualValues(1024, op.spec.DHCP4.RouteMetric) diff --git a/internal/app/machined/pkg/xcontext/xcontext.go b/internal/app/machined/pkg/xcontext/xcontext.go new file mode 100644 index 00000000000..32dca45de85 --- /dev/null +++ b/internal/app/machined/pkg/xcontext/xcontext.go @@ -0,0 +1,28 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package xcontext provides a small utils for context package +package xcontext + +import "context" + +// AfterFuncSync is like [context.AfterFunc] but it blocks until the function is executed. 
+func AfterFuncSync(ctx context.Context, fn func()) func() bool { + stopChan := make(chan struct{}) + + stop := context.AfterFunc(ctx, func() { + defer close(stopChan) + + fn() + }) + + return func() bool { + result := stop() + if !result { + <-stopChan + } + + return result + } +} diff --git a/internal/pkg/dns/dns.go b/internal/pkg/dns/dns.go index 6a8be932fcc..5eaa7619f4d 100644 --- a/internal/pkg/dns/dns.go +++ b/internal/pkg/dns/dns.go @@ -9,12 +9,10 @@ import ( "context" "errors" "fmt" - "io" "iter" "net" "net/netip" "slices" - "strings" "sync" "sync/atomic" "syscall" @@ -207,7 +205,7 @@ func NewNodeHandler(next plugin.Handler, hostMapper HostMapper, logger *zap.Logg // HostMapper is a name to node mapper. type HostMapper interface { - ResolveAddr(ctx context.Context, qType uint16, name string) []netip.Addr + ResolveAddr(ctx context.Context, qType uint16, name string) (iter.Seq[netip.Addr], bool) } // NodeHandler try to resolve dns request to a node. If required node is not found, it will move to the next handler. @@ -238,14 +236,19 @@ func (h *NodeHandler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg req := request.Request{W: wrt, Req: msg} // Check if the request is for a node. - result := h.mapper.ResolveAddr(ctx, req.QType(), req.Name()) - if len(result) == 0 { + result, ok := h.mapper.ResolveAddr(ctx, req.QType(), req.Name()) + if !ok { + return h.next.ServeDNS(ctx, wrt, msg) + } + + answers := mapAnswers(result, req.Name()) + if len(answers) == 0 { return h.next.ServeDNS(ctx, wrt, msg) } resp := new(dns.Msg).SetReply(req.Req) resp.Authoritative = true - resp.Answer = mapAnswers(result, req.Name()) + resp.Answer = answers err := wrt.WriteMsg(resp) if err != nil { @@ -256,10 +259,10 @@ func (h *NodeHandler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg return dns.RcodeSuccess, nil } -func mapAnswers(addrs []netip.Addr, name string) []dns.RR { +func mapAnswers(addrs iter.Seq[netip.Addr], name string) []dns.RR { var result []dns.RR - for _, addr := range addrs { + for addr := range addrs { switch { case addr.Is4(): result = append(result, &dns.A{ @@ -295,89 +298,6 @@ func (h *NodeHandler) SetEnabled(enabled bool) { h.enabled.Store(enabled) } -// ServerOptions is a Server options. -type ServerOptions struct { - Listener net.Listener - PacketConn net.PacketConn - Handler dns.Handler - ReadTimeout time.Duration - WriteTimeout time.Duration - IdleTimeout func() time.Duration - MaxTCPQueries int - Logger *zap.Logger -} - -// NewServer creates a new Server. -func NewServer(opts ServerOptions) *Server { - return &Server{ - srv: &dns.Server{ - Listener: opts.Listener, - PacketConn: opts.PacketConn, - Handler: opts.Handler, - UDPSize: dns.DefaultMsgSize, // 4096 since default is [dns.MinMsgSize] = 512 bytes, which is too small. - ReadTimeout: opts.ReadTimeout, - WriteTimeout: opts.WriteTimeout, - IdleTimeout: opts.IdleTimeout, - MaxTCPQueries: opts.MaxTCPQueries, - }, - logger: opts.Logger, - } -} - -// Server is a dns server. -type Server struct { - srv *dns.Server - logger *zap.Logger -} - -// Start starts the dns server. Returns a function to stop the server. 
-func (s *Server) Start(onDone func(err error)) (stop func(), stopped <-chan struct{}) { - done := make(chan struct{}) - - fn := sync.OnceFunc(func() { - for { - err := s.srv.Shutdown() - if err != nil { - if strings.Contains(err.Error(), "server not started") { - // There a possible scenario where `go func()` not yet reached `ActivateAndServe` and yielded CPU - // time to another goroutine and then this closure reached `Shutdown`. In that case - // `ActivateAndServe` will actually start after `Shutdown` and this closure will block forever - // because `go func()` will never exit and close `done` channel. - continue - } - - s.logger.Error("error shutting down dns server", zap.Error(err)) - } - - break - } - - closer := io.Closer(s.srv.Listener) - if closer == nil { - closer = s.srv.PacketConn - } - - if closer != nil { - err := closer.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - s.logger.Error("error closing dns server listener", zap.Error(err)) - } else { - s.logger.Debug("dns server listener closed") - } - } - - <-done - }) - - go func() { - defer close(done) - - onDone(s.srv.ActivateAndServe()) - }() - - return fn, done -} - // NewTCPListener creates a new TCP listener. func NewTCPListener(network, addr string, control ControlFn) (net.Listener, error) { network, ok := networkNames[network] diff --git a/internal/pkg/dns/dns_test.go b/internal/pkg/dns/dns_test.go index e15ac7992e2..e8de0508c54 100644 --- a/internal/pkg/dns/dns_test.go +++ b/internal/pkg/dns/dns_test.go @@ -6,22 +6,26 @@ package dns_test import ( "context" + "iter" "net" "net/netip" + "runtime" "slices" "testing" "time" "github.com/coredns/coredns/plugin/pkg/proxy" dnssrv "github.com/miekg/dns" - "github.com/siderolabs/gen/ensure" + "github.com/siderolabs/gen/maps" "github.com/siderolabs/gen/xiter" "github.com/siderolabs/gen/xslices" "github.com/siderolabs/gen/xtesting/check" "github.com/stretchr/testify/require" + "github.com/thejerf/suture/v4" "go.uber.org/zap/zaptest" "github.com/siderolabs/talos/internal/pkg/dns" + "github.com/siderolabs/talos/pkg/machinery/resources/cluster" ) func TestDNS(t *testing.T) { @@ -78,7 +82,7 @@ func TestDNS(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - stop := newServer(t, test.nameservers...) + stop := newManager(t, test.nameservers...) 
t.Cleanup(stop) time.Sleep(10 * time.Millisecond) @@ -96,7 +100,7 @@ func TestDNS(t *testing.T) { } func TestDNSEmptyDestinations(t *testing.T) { - stop := newServer(t) + stop := newManager(t) defer stop() time.Sleep(10 * time.Millisecond) @@ -112,11 +116,50 @@ func TestDNSEmptyDestinations(t *testing.T) { stop() } -func newServer(t *testing.T, nameservers ...string) func() { - l := zaptest.NewLogger(t) +func TestGC_NOGC(t *testing.T) { + tests := map[string]bool{ + "ClearAll": false, + "No ClearAll": true, + } + + for name, f := range tests { + t.Run(name, func(t *testing.T) { + m := dns.NewManager(&testReader{}, func(e suture.Event) { t.Log("dns-runners event:", e) }, zaptest.NewLogger(t)) + + m.ServeBackground(context.Background()) + m.ServeBackground(context.Background()) + require.Panics(t, func() { m.ServeBackground(context.TODO()) }) + + for _, err := range m.RunAll(xiter.Values(slices.All([]dns.AddressPair{ + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.53:10700")}, + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.53:10701")}, + })), false) { + require.NoError(t, err) + } + + for err := range m.ClearAll(f) { + require.NoError(t, err) + } + + m = nil + + for range 100 { + runtime.GC() + } + }) + } +} + +func newManager(t *testing.T, nameservers ...string) func() { + m := dns.NewManager(&testReader{}, func(e suture.Event) { t.Log("dns-runners event:", e) }, zaptest.NewLogger(t)) + + m.AllowNodeResolving(true) - handler := dns.NewHandler(l) - t.Cleanup(handler.Stop) + t.Cleanup(func() { + for err := range m.ClearAll(false) { + t.Logf("error stopping dns runner: %v", err) + } + }) pxs := xslices.Map(nameservers, func(ns string) *proxy.Proxy { p := proxy.NewProxy(ns, net.JoinHostPort(ns, "53"), "dns") @@ -127,30 +170,34 @@ func newServer(t *testing.T, nameservers ...string) func() { return p }) - handler.SetProxy(xiter.Values(slices.All(pxs))) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) - pc, err := dns.NewUDPPacketConn("udp", "127.0.0.53:10700", ensure.Value(dns.MakeControl("udp", false))) - require.NoError(t, err) + m.SetUpstreams(xiter.Values(slices.All(pxs))) - nodeHandler := dns.NewNodeHandler(handler, &testResolver{}, l) + m.ServeBackground(ctx) + m.ServeBackground(ctx) - nodeHandler.SetEnabled(true) + for _, err := range m.RunAll(xiter.Values(slices.All([]dns.AddressPair{ + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.53:10700")}, + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.53:10701")}, + {Network: "tcp", Addr: netip.MustParseAddrPort("127.0.0.53:10700")}, + })), false) { + require.NoError(t, err) + } - srv := dns.NewServer(dns.ServerOptions{ - PacketConn: pc, - Handler: dns.NewCache(nodeHandler, l), - Logger: l, - }) + for _, err := range m.RunAll(xiter.Values(slices.All([]dns.AddressPair{ + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.53:10700")}, + {Network: "tcp", Addr: netip.MustParseAddrPort("127.0.0.53:10700")}, + })), false) { + require.NoError(t, err) + } - stop, _ := srv.Start(func(err error) { - if err != nil { - t.Errorf("error running dns server: %v", err) + return func() { + for err := range m.ClearAll(false) { + t.Logf("error stopping dns runner: %v", err) } - - t.Logf("dns server stopped") - }) - - return stop + } } func createQuery(name string) *dnssrv.Msg { @@ -169,19 +216,21 @@ func createQuery(name string) *dnssrv.Msg { } } -type testResolver struct{} +type testReader struct{} -func (*testResolver) ResolveAddr(_ context.Context, qType uint16, name string) []netip.Addr { - 
if qType != dnssrv.TypeA {
-		return nil
+func (r *testReader) ReadMembers(context.Context) (iter.Seq[*cluster.Member], error) {
+	namesToAddresses := map[string][]netip.Addr{
+		"talos-default-controlplane-1": {netip.MustParseAddr("172.20.0.2")},
+		"talos-default-worker-1":       {netip.MustParseAddr("172.20.0.3")},
 	}
 
-	switch name {
-	case "talos-default-controlplane-1.":
-		return []netip.Addr{netip.MustParseAddr("172.20.0.2")}
-	case "talos-default-worker-1.":
-		return []netip.Addr{netip.MustParseAddr("172.20.0.3")}
-	default:
-		return nil
-	}
+	result := maps.ToSlice(namesToAddresses, func(k string, v []netip.Addr) *cluster.Member {
+		result := cluster.NewMember(cluster.NamespaceName, k)
+
+		result.TypedSpec().Addresses = v
+
+		return result
+	})
+
+	return xiter.Values(slices.All(result)), nil
 }
diff --git a/internal/pkg/dns/manager.go b/internal/pkg/dns/manager.go
new file mode 100644
index 00000000000..a615836009a
--- /dev/null
+++ b/internal/pkg/dns/manager.go
@@ -0,0 +1,292 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package dns
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"iter"
+	"net/netip"
+	"runtime"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/proxy"
+	dnssrv "github.com/miekg/dns"
+	"github.com/siderolabs/gen/xiter"
+	"github.com/thejerf/suture/v4"
+	"go.uber.org/zap"
+
+	"github.com/siderolabs/talos/pkg/machinery/resources/cluster"
+)
+
+// ErrCreatingRunner is an error that occurs when creating a runner.
+var ErrCreatingRunner = errors.New("error creating runner")
+
+// Manager manages DNS runners.
+type Manager struct {
+	originalCtx  context.Context //nolint:containedctx
+	handler      *Handler
+	nodeHandler  *NodeHandler
+	rootHandler  *Cache
+	s            *suture.Supervisor
+	supervisorCh <-chan error
+	logger       *zap.Logger
+	runners      map[AddressPair]suture.ServiceToken
+}
+
+// NewManager creates a new manager.
+func NewManager(mr MemberReader, hook suture.EventHook, logger *zap.Logger) *Manager {
+	handler := NewHandler(logger)
+	nodeHandler := NewNodeHandler(handler, &addrResolver{mr: mr}, logger)
+	rootHandler := NewCache(nodeHandler, logger)
+
+	m := &Manager{
+		handler:     handler,
+		nodeHandler: nodeHandler,
+		rootHandler: rootHandler,
+		s:           suture.New("dns-resolve-cache-runners", suture.Spec{EventHook: hook}),
+		logger:      logger,
+		runners:     map[AddressPair]suture.ServiceToken{},
+	}
+
+	// If we lose the reference to the manager, ensure the finalizer is called and all upstreams are collected.
+	runtime.SetFinalizer(m, (*Manager).finalize)
+
+	return m
+}
+
+// ServeBackground starts the manager in the background. It panics if the manager is not initialized or if it is
+// called again with a different context; calling it again with the same context is a no-op.
+func (m *Manager) ServeBackground(ctx context.Context) {
+	switch {
+	case m.originalCtx == nil:
+		m.originalCtx = ctx
+	case m.originalCtx != ctx:
+		panic("Manager.ServeBackground is called with a different context")
+	case m.originalCtx == ctx:
+		return
+	}
+
+	m.supervisorCh = m.s.ServeBackground(ctx)
+}
+
+// AddressPair represents a network and address with port.
+type AddressPair struct {
+	Network string
+	Addr    netip.AddrPort
+}
+
+// String returns a string representation of the address pair.
+func (a AddressPair) String() string { return "Network: " + a.Network + ", Addr: " + a.Addr.String() }
+
+// RunAll updates and runs the runners managed by the manager.
It returns an iterator which yields the address pairs for
+// all running and attempted to run configurations. It's mandatory to range over the iterator to ensure all runners are updated.
+func (m *Manager) RunAll(pairs iter.Seq[AddressPair], forwardEnabled bool) iter.Seq2[AddressPair, error] {
+	return func(yield func(AddressPair, error) bool) {
+		preserve := make(map[AddressPair]struct{}, len(m.runners))
+
+		for cfg := range pairs {
+			if _, ok := m.runners[cfg]; !ok {
+				opts, err := newDNSRunnerOpts(cfg, m.rootHandler, forwardEnabled)
+				if err != nil {
+					err = fmt.Errorf("%w: %w", ErrCreatingRunner, err)
+				}
+
+				if !yield(cfg, err) {
+					return
+				}
+
+				m.runners[cfg] = m.s.Add(NewRunner(opts, m.logger))
+			} else if !yield(cfg, nil) {
+				return
+			}
+
+			preserve[cfg] = struct{}{}
+		}
+
+		for runnerCfg, token := range m.runners {
+			if _, ok := preserve[runnerCfg]; !ok {
+				err := m.s.RemoveAndWait(token, 0)
+				if err != nil {
+					m.logger.Warn("error removing runner", zap.Stringer("cfg", runnerCfg), zap.Error(err))
+				}
+
+				delete(m.runners, runnerCfg)
+			}
+		}
+	}
+}
+
+// AllowNodeResolving enables or disables the node resolving feature.
+func (m *Manager) AllowNodeResolving(enabled bool) { m.nodeHandler.SetEnabled(enabled) }
+
+// SetUpstreams sets the upstreams for the DNS handler. It returns true if the upstreams were updated, false otherwise.
+func (m *Manager) SetUpstreams(prxs iter.Seq[*proxy.Proxy]) bool { return m.handler.SetProxy(prxs) }
+
+// ClearAll stops and removes all runners. It returns an iterator which yields any errors that occurred during the
+// removal process. It's mandatory to range over the iterator to ensure all runners are stopped.
+func (m *Manager) ClearAll(dry bool) iter.Seq[error] {
+	if dry {
+		return xiter.Empty
+	}
+
+	return xiter.Filter(
+		xiter.Values(m.clearAll()),
+		func(e error) bool { return e != nil },
+	)
+}
+
+func (m *Manager) clearAll() iter.Seq2[AddressPair, error] {
+	return func(yield func(AddressPair, error) bool) {
+		if len(m.runners) == 0 {
+			return
+		}
+
+		defer m.handler.Stop()
+
+		removeAndWait := m.s.RemoveAndWait
+		if m.originalCtx.Err() != nil {
+			// ctx canceled, no reason to remove runners from Supervisor since they are already dropped
+			removeAndWait = func(id suture.ServiceToken, timeout time.Duration) error { return nil }
+		}
+
+		for runData, token := range m.runners {
+			err := removeAndWait(token, 0)
+			if err != nil {
+				err = fmt.Errorf("error removing runner: %w", err)
+			}
+
+			if !yield(runData, err) {
+				return
+			}
+
+			delete(m.runners, runData)
+		}
+	}
+}
+
+func (m *Manager) finalize() {
+	for data, err := range m.clearAll() {
+		if err != nil {
+			m.logger.Error("error stopping dns runner", zap.Error(err))
+		}
+
+		m.logger.Info(
+			"dns runner stopped from finalizer!",
+			zap.String("address", data.Addr.String()),
+			zap.String("network", data.Network),
+		)
+	}
+}
+
+// Done reports whether the supervisor has finished execution.
+func (m *Manager) Done() <-chan error { + return m.supervisorCh +} + +type addrResolver struct { + mr MemberReader +} + +func (s *addrResolver) ResolveAddr(ctx context.Context, qType uint16, name string) (iter.Seq[netip.Addr], bool) { + name = strings.TrimRight(name, ".") + + items, err := s.mr.ReadMembers(ctx) + if err != nil { + return nil, false + } + + found, ok := findInIter(items, func(res *cluster.Member) bool { + return fqdnMatch(name, res.TypedSpec().Hostname) || fqdnMatch(name, res.Metadata().ID()) + }) + if !ok { + return nil, false + } + + return xiter.Filter( + xiter.Values(slices.All(found.TypedSpec().Addresses)), + func(addr netip.Addr) bool { + return (qType == dnssrv.TypeA && addr.Is4()) || (qType == dnssrv.TypeAAAA && addr.Is6()) + }, + ), true +} + +func fqdnMatch(what, where string) bool { + what = strings.TrimRight(what, ".") + where = strings.TrimRight(where, ".") + + if what == where { + return true + } + + first, _, found := strings.Cut(where, ".") + if !found { + return false + } + + return what == first +} + +// MemberReader is an interface to read members. +type MemberReader interface { + ReadMembers(ctx context.Context) (iter.Seq[*cluster.Member], error) +} + +func findInIter[T any](it iter.Seq[T], pred func(T) bool) (T, bool) { + for v := range it { + if pred(v) { + return v, true + } + } + + return *new(T), false +} + +func newDNSRunnerOpts(cfg AddressPair, rootHandler dnssrv.Handler, forwardEnabled bool) (RunnerOptions, error) { + if cfg.Addr.Addr().Is6() { + cfg.Network += "6" + } + + var serverOpts RunnerOptions + + controlFn, ctrlErr := MakeControl(cfg.Network, forwardEnabled) + if ctrlErr != nil { + return serverOpts, fmt.Errorf("error creating %q control function: %w", cfg.Network, ctrlErr) + } + + switch cfg.Network { + case "udp", "udp6": + packetConn, err := NewUDPPacketConn(cfg.Network, cfg.Addr.String(), controlFn) + if err != nil { + return serverOpts, fmt.Errorf("error creating %q packet conn: %w", cfg.Network, err) + } + + serverOpts = RunnerOptions{ + PacketConn: packetConn, + Handler: rootHandler, + } + + case "tcp", "tcp6": + listener, err := NewTCPListener(cfg.Network, cfg.Addr.String(), controlFn) + if err != nil { + return serverOpts, fmt.Errorf("error creating %q listener: %w", cfg.Network, err) + } + + serverOpts = RunnerOptions{ + Listener: listener, + Handler: rootHandler, + ReadTimeout: 3 * time.Second, + WriteTimeout: 5 * time.Second, + IdleTimeout: func() time.Duration { return 10 * time.Second }, + MaxTCPQueries: -1, + } + } + + return serverOpts, nil +} diff --git a/internal/pkg/dns/runnner.go b/internal/pkg/dns/runnner.go new file mode 100644 index 00000000000..f0725de97d7 --- /dev/null +++ b/internal/pkg/dns/runnner.go @@ -0,0 +1,108 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package dns + +import ( + "context" + "errors" + "io" + "net" + "strings" + "time" + + "github.com/miekg/dns" + "go.uber.org/zap" + + "github.com/siderolabs/talos/internal/app/machined/pkg/xcontext" +) + +// RunnerOptions is a [Runner] options. +type RunnerOptions struct { + Listener net.Listener + PacketConn net.PacketConn + Handler dns.Handler + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout func() time.Duration + MaxTCPQueries int +} + +// NewRunner creates a new [Runner]. 
+func NewRunner(opts RunnerOptions, l *zap.Logger) *Runner {
+	return &Runner{
+		srv: &dns.Server{
+			Listener:      opts.Listener,
+			PacketConn:    opts.PacketConn,
+			Handler:       opts.Handler,
+			UDPSize:       dns.DefaultMsgSize, // 4096 since default is [dns.MinMsgSize] = 512 bytes, which is too small.
+			ReadTimeout:   opts.ReadTimeout,
+			WriteTimeout:  opts.WriteTimeout,
+			IdleTimeout:   opts.IdleTimeout,
+			MaxTCPQueries: opts.MaxTCPQueries,
+		},
+		logger: l,
+	}
+}
+
+// Runner is a DNS server runner.
+type Runner struct {
+	srv    *dns.Server
+	logger *zap.Logger
+}
+
+// Serve starts the DNS server.
+func (r *Runner) Serve(ctx context.Context) error {
+	detach := xcontext.AfterFuncSync(ctx, r.close)
+	defer func() {
+		if !detach() {
+			return
+		}
+
+		r.close()
+	}()
+
+	return r.srv.ActivateAndServe()
+}
+
+func (r *Runner) close() {
+	l := r.logger
+
+	if r.srv.Listener != nil {
+		l = l.With(zap.String("net", "tcp"), zap.String("local_addr", r.srv.Listener.Addr().String()))
+	} else if r.srv.PacketConn != nil {
+		l = l.With(zap.String("net", "udp"), zap.String("local_addr", r.srv.PacketConn.LocalAddr().String()))
+	}
+
+	for {
+		err := r.srv.Shutdown()
+		if err != nil {
+			if strings.Contains(err.Error(), "server not started") {
+				// There is a possible scenario where the goroutine calling `ActivateAndServe` has not yet reached it
+				// and yielded CPU time to another goroutine, while this closure already reached `Shutdown`. In that
+				// case `dns.Server.ActivateAndServe` will actually start after `Shutdown`, so we retry the shutdown
+				// until the server reports that it has started.
+				continue
+			}
+
+			l.Error("error shutting down dns server", zap.Error(err))
+		}
+
+		closer := io.Closer(r.srv.Listener)
+		if closer == nil {
+			closer = r.srv.PacketConn
+		}
+
+		if closer != nil {
+			err = closer.Close()
+			if err != nil && !errors.Is(err, net.ErrClosed) {
+				l.Error("error closing dns server listener", zap.Error(err))
+			} else {
+				l.Debug("dns server listener closed")
+			}
+		}
+
+		break
+	}
+}
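
Usage sketch (not part of the diff, illustrative only): the new dns.Manager introduced above is wired together by constructing it with a MemberReader and a suture event hook, starting its supervisor once via ServeBackground, reconciling listeners by fully ranging over RunAll, and draining ClearAll on shutdown. The staticMembers stub, the logger choice, and the listen address below are assumptions for the example; in Talos the MemberReader is backed by COSI cluster.Member resources (see memberReader in dns_resolve_cache.go).

package main

import (
	"context"
	"iter"
	"net/netip"

	"github.com/thejerf/suture/v4"
	"go.uber.org/zap"

	"github.com/siderolabs/talos/internal/pkg/dns"
	"github.com/siderolabs/talos/pkg/machinery/resources/cluster"
)

// staticMembers is a stand-in MemberReader that yields no members.
type staticMembers struct{}

func (staticMembers) ReadMembers(context.Context) (iter.Seq[*cluster.Member], error) {
	return func(yield func(*cluster.Member) bool) {}, nil
}

func main() {
	logger := zap.NewExample()

	m := dns.NewManager(staticMembers{}, func(e suture.Event) {
		logger.Info("dns-runners event", zap.String("event", e.String()))
	}, logger)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Start the supervisor once; calling ServeBackground again with the same ctx is a no-op.
	m.ServeBackground(ctx)

	pairs := func(yield func(dns.AddressPair) bool) {
		yield(dns.AddressPair{Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.53:10700")})
	}

	// RunAll must be ranged over fully so every listener is created (or pruned when no longer requested).
	for cfg, err := range m.RunAll(pairs, false) {
		if err != nil {
			logger.Error("dns runner", zap.Stringer("cfg", cfg), zap.Error(err))
		}
	}

	// ClearAll(false) stops and removes all runners; its iterator must be drained as well.
	for err := range m.ClearAll(false) {
		logger.Error("error stopping dns runner", zap.Error(err))
	}
}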
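
A second sketch (also not part of the diff) for the xcontext.AfterFuncSync helper that Runner.Serve relies on. Unlike the stop function returned by context.AfterFunc, the function returned here blocks until the callback has finished whenever it could not prevent the callback from running; this is what lets Runner.Serve be sure that close has fully completed before returning, or call close itself if the context was never canceled. The sleep duration and printed messages are illustrative.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/siderolabs/talos/internal/app/machined/pkg/xcontext"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	detach := xcontext.AfterFuncSync(ctx, func() {
		time.Sleep(50 * time.Millisecond) // simulate a slow shutdown hook

		fmt.Println("hook finished")
	})

	// Cancellation starts the hook in its own goroutine, exactly like context.AfterFunc.
	cancel()

	// detach reports false because it could not stop the hook, but it only returns
	// after the hook has completed; that is the "Sync" part of the helper.
	if !detach() {
		fmt.Println("hook already ran to completion")
	}
}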