Skip to content

Commit

Permalink
DEV: ALPN TLS extension code addition
Browse files Browse the repository at this point in the history
  • Loading branch information
Mgrdich committed Dec 6, 2024
1 parent e53b7c6 commit 4fbcb8d
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 3 deletions.
6 changes: 5 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ func main() {
//}
err := pkg.ListenAndServer("127.0.0.1:8080",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello World"))
_, err := w.Write([]byte("Hello World"))

if err != nil {
return
}
}))

if err != nil {
Expand Down
131 changes: 129 additions & 2 deletions pkg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,17 @@ type MiniServer struct {
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger

// TLSNextProto optionally specifies a function to take over
// ownership of the provided TLS connection when an ALPN
// protocol upgrade has occurred. The map key is the protocol
// name negotiated. The Handler argument should be used to
// handle HTTP requests and will initialize the Request's TLS
// and RemoteAddr if not already set. The connection is
// automatically closed when the function returns.
// If TLSNextProto is not nil, HTTP/2 support is not enabled
// automatically.
TLSNextProto map[string]func(*MiniServer, *tls.Conn, http.Handler)

disableKeepAlives atomic.Bool

mu sync.Mutex
Expand Down Expand Up @@ -912,6 +923,31 @@ func (s *MiniServer) readHeaderTimeout() time.Duration {
return s.ReadTimeout
}

// tlsHandshakeTimeout returns the time limit permitted for the TLS
// handshake, or zero for unlimited.
//
// It returns the minimum of any positive ReadHeaderTimeout,
// ReadTimeout, or WriteTimeout.
func (s *MiniServer) tlsHandshakeTimeout() time.Duration {
var ret time.Duration

for _, v := range [...]time.Duration{
s.ReadHeaderTimeout,
s.ReadTimeout,
s.WriteTimeout,
} {
if v <= 0 {
continue
}

if ret == 0 || v < ret {
ret = v
}
}

return ret
}

// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
Expand Down Expand Up @@ -1424,8 +1460,66 @@ func (c *conn) serve(ctx context.Context) {
}()

if tlsConn, ok := c.rwc.(*tls.Conn); ok {
// TODO tls and h2 support
fmt.Println(tlsConn)
tlsTO := c.server.tlsHandshakeTimeout()

if tlsTO > 0 {
dl := time.Now().Add(tlsTO)
//nolint:errcheck
c.rwc.SetReadDeadline(dl)
//nolint:errcheck
c.rwc.SetWriteDeadline(dl)
}

if err := tlsConn.HandshakeContext(ctx); err != nil {
// If the handshake failed due to the client not speaking
// TLS, assume they're speaking plaintext HTTP and write a
// 400 response on the TLS conn's underlying net.Conn.
var re tls.RecordHeaderError
if errors.As(err, &re) && re.Conn != nil && tlsRecordHeaderLooksLikeHTTP(re.RecordHeader) {
//nolint:errcheck
io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n")
//nolint:errcheck
re.Conn.Close()

return
}

c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)

return
}

// Restore Conn-level deadlines.
if tlsTO > 0 {
//nolint:errcheck
c.rwc.SetReadDeadline(time.Time{})
//nolint:errcheck
c.rwc.SetWriteDeadline(time.Time{})
}

// Restore Conn-level deadlines.
if tlsTO > 0 {
//nolint:errcheck
c.rwc.SetReadDeadline(time.Time{})
//nolint:errcheck
c.rwc.SetWriteDeadline(time.Time{})
}

c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()

if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
if fn := c.server.TLSNextProto[proto]; fn != nil {
h := initALPNRequest{ctx, tlsConn, serverHandler{c.server}}
// Mark freshly created HTTP/2 as active and prevent any server state hooks
// from being run on these connections. This prevents closeIdleConns from
// closing such connections.
c.setState(StateActive)
fn(c.server, tlsConn, h)
}

return
}
}

// HTTP/1.x from here on
Expand Down Expand Up @@ -2484,3 +2578,36 @@ func (globalOptionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
io.Copy(io.Discard, mb)
}
}

// initALPNRequest is an HTTP handler that initializes certain
// uninitialized fields in its *Request. Such partially-initialized
// Requests come from ALPN protocol handlers.
type initALPNRequest struct {
//nolint:containedctx
ctx context.Context
c *tls.Conn
h serverHandler
}

// BaseContext is an exported but unadvertised http.Handler method
// recognized by x/net/http2 to pass down a context; the TLSNextProto
// API predates context support so we shoehorn through the only
// interface we have available.
func (h initALPNRequest) BaseContext() context.Context { return h.ctx }

func (h initALPNRequest) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.TLS == nil {
req.TLS = &tls.ConnectionState{}
*req.TLS = h.c.ConnectionState()
}

if req.Body == nil {
req.Body = http.NoBody
}

if req.RemoteAddr == "" {
req.RemoteAddr = h.c.RemoteAddr().String()
}

h.h.ServeHTTP(rw, req)
}
23 changes: 23 additions & 0 deletions pkg/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,26 @@ func ValidMethod(method string) bool {
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}

// tlsRecordHeaderLooksLikeHTTP reports whether a TLS record header
// looks like it might've been a misdirected plaintext HTTP request.
func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
switch string(hdr[:]) {
case "GET /", "HEAD ", "POST ", "PUT /", "OPTIO":
return true
}

return false
}

// validNextProto reports whether the proto is a valid ALPN protocol name.
// Everything is valid except the empty string and built-in protocol types,
// so that those can't be overridden with alternate implementations.
func validNextProto(proto string) bool {
switch proto {
case "", "http/1.1", "http/1.0":
return false
}

return true
}

0 comments on commit 4fbcb8d

Please sign in to comment.