From 7ba7ccaaba4bccdebcb497a9ce4d24e3dffff9c4 Mon Sep 17 00:00:00 2001 From: Eugene K Date: Tue, 12 Sep 2023 16:45:02 -0400 Subject: [PATCH] allow client to connect without ALPN if listener has a single handler --- tls/listener.go | 52 ++++++++++++++++++++++++++++---------------- tls/listener_test.go | 25 +++++++++++++++++++++ 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/tls/listener.go b/tls/listener.go index a9a04ab..7e767a1 100644 --- a/tls/listener.go +++ b/tls/listener.go @@ -258,32 +258,46 @@ func (self *sharedListener) getConfig(info *tls.ClientHelloInfo) (*tls.Config, e protos := info.SupportedProtos log.Debug("client requesting protocols = ", protos) - if protos == nil { - protos = append(protos, noProtocol) - } - ctx := info.Context() - handler := ctx.Value(handlerKey).(**protocolHandler) + handlerOut := ctx.Value(handlerKey).(**protocolHandler) self.mtx.RLock() defer self.mtx.RUnlock() - for _, proto := range protos { - acc, found := self.handlers[proto] - if found { - log.Debugf("found handler for proto[%s]", proto) - *handler = acc - cfg := acc.tls - if cfg.GetConfigForClient != nil { - c, _ := cfg.GetConfigForClient(info) - if c != nil { - cfg = c - } + var handler *protocolHandler + var proto string + if protos == nil && len(self.handlers) == 1 { + log.Debugf("using single protocol as default") + for p, h := range self.handlers { + proto, handler = p, h + } + } else { + if protos == nil { + protos = append(protos, noProtocol) + } + + for _, p := range protos { + h, found := self.handlers[p] + if found { + log.Debugf("found handler for proto[%s]", proto) + handler = h + proto = p + } + } + } + + if handler != nil { + *handlerOut = handler + cfg := handler.tls + if cfg.GetConfigForClient != nil { + c, _ := cfg.GetConfigForClient(info) + if c != nil { + cfg = c } - cfg = cfg.Clone() - cfg.NextProtos = []string{proto} - return cfg, nil } + cfg = cfg.Clone() + cfg.NextProtos = []string{proto} + return cfg, nil } return nil, fmt.Errorf("not handler for requested protocols %+v", protos) diff --git a/tls/listener_test.go b/tls/listener_test.go index d44bf10..4a34f02 100644 --- a/tls/listener_test.go +++ b/tls/listener_test.go @@ -362,3 +362,28 @@ func TestListenTLS(t *testing.T) { req.NoError(httpListener.Close()) req.NoError(fooListener.Close()) } + +func TestListenSingleProto(t *testing.T) { + req := require.New(t) + + ident := &identity.TokenId{ + Identity: serverId, + Token: "test", + Data: nil, + } + + testAddress := "localhost:14444" + + if _, ok := sharedListeners.Load(testAddress); ok { + t.Error("should be empty") + } + + fooListener, err := Listen(testAddress, "fooListener", ident, makeGreeter("foo"), "foo") + req.NoError(err) + + req.NoError(checkClient(testAddress, "foo", "foo", t), "should find handler") + req.NoError(checkClient(testAddress, "", "foo", t), "should find handler") + req.Error(checkClient(testAddress, "bar", "bar", t), "should have no handler") + + req.NoError(fooListener.Close()) +}