Skip to content

Commit

Permalink
fix: close sse connections on server shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Sep 7, 2023
1 parent 2b01a06 commit ca55eb6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 25 deletions.
6 changes: 3 additions & 3 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func (r *Runner) Run() error {
wsServer.Address(), r.config.SSE.Path,
)

sseHandler, err := r.defaultSSEHandler(appNode, r.config)
sseHandler, err := r.defaultSSEHandler(appNode, wsServer.ShutdownCtx(), r.config)

if err != nil {
return errorx.Decorate(err, "!!! Failed to initialize SSE handler !!!")
Expand Down Expand Up @@ -416,9 +416,9 @@ func (r *Runner) defaultWebSocketHandler(n *node.Node, c *config.Config) (http.H
}), nil
}

func (r *Runner) defaultSSEHandler(n *node.Node, c *config.Config) (http.Handler, error) {
func (r *Runner) defaultSSEHandler(n *node.Node, ctx context.Context, c *config.Config) (http.Handler, error) {
extractor := server.DefaultHeadersExtractor{Headers: c.Headers, Cookies: c.Cookies}
handler := sse.SSEHandler(n, &extractor, &c.SSE)
handler := sse.SSEHandler(n, ctx, &extractor, &c.SSE)

return handler, nil
}
Expand Down
31 changes: 23 additions & 8 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ type HTTPServer struct {
mu sync.Mutex
log *log.Entry

shutdownCtx context.Context
shutdownFn context.CancelFunc

mux *chi.Mux
}

Expand Down Expand Up @@ -75,15 +78,19 @@ func NewServer(host string, port string, ssl *SSLConfig, maxConn int) (*HTTPServ
server.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cer}, MinVersion: tls.VersionTLS12}
}

shutdownCtx, shutdownFn := context.WithCancel(context.Background())

return &HTTPServer{
server: server,
addr: addr,
mux: router,
secured: secured,
shutdown: false,
started: false,
maxConn: maxConn,
log: log.WithField("context", "http"),
server: server,
addr: addr,
mux: router,
secured: secured,
shutdown: false,
started: false,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
maxConn: maxConn,
log: log.WithField("context", "http"),
}, nil
}

Expand Down Expand Up @@ -150,9 +157,17 @@ func (s *HTTPServer) Shutdown(ctx context.Context) error {
s.shutdown = true
s.mu.Unlock()

s.shutdownFn()

return s.server.Shutdown(ctx)
}

// ShutdownCtx returns context for graceful shutdown.
// It must be used by HTTP handlers to termniates long-running requests (SSE, long-polling).
func (s *HTTPServer) ShutdownCtx() context.Context {
return s.shutdownCtx
}

// Stopped return true iff server has been stopped by user
func (s *HTTPServer) Stopped() bool {
s.mu.Lock()
Expand Down
33 changes: 23 additions & 10 deletions sse/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sse

import (
"context"
"net/http"
"strings"

Expand All @@ -13,7 +14,7 @@ import (
)

// SSEHandler generates a new http handler for SSE connections
func SSEHandler(n *node.Node, headersExtractor server.HeadersExtractor, config *Config) http.Handler {
func SSEHandler(n *node.Node, shutdownCtx context.Context, headersExtractor server.HeadersExtractor, config *Config) http.Handler {
var allowedHosts []string

if config.AllowedOrigins == "" {
Expand Down Expand Up @@ -118,15 +119,27 @@ func SSEHandler(n *node.Node, headersExtractor server.HeadersExtractor, config *
conn.Established()
sessionCtx.Debugf("session established")

// TODO: Handle server shutdown. Currently, server is waiting for SSE connections to be closed
select {
case <-r.Context().Done():
sessionCtx.Debugf("request terminated")
session.DisconnectNow("Closed", ws.CloseNormalClosure)
return
case <-conn.Context().Done():
sessionCtx.Debugf("session completed")
return
shutdownReceived := false

for {
select {
case <-shutdownCtx.Done():
if !shutdownReceived {
shutdownReceived = true
sessionCtx.Debugf("server shutdown")
session.DisconnectWithMessage(
&common.DisconnectMessage{Type: "disconnect", Reason: common.SERVER_RESTART_REASON, Reconnect: true},
common.SERVER_RESTART_REASON,
)
}
case <-r.Context().Done():
sessionCtx.Debugf("request terminated")
session.DisconnectNow("Closed", ws.CloseNormalClosure)
return
case <-conn.Context().Done():
sessionCtx.Debugf("session completed")
return
}
}
})
}
18 changes: 14 additions & 4 deletions sse/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestSSEHandler(t *testing.T) {

headersExtractor := &server.DefaultHeadersExtractor{}

handler := SSEHandler(appNode, headersExtractor, &conf)
handler := SSEHandler(appNode, context.Background(), headersExtractor, &conf)

controller.
On("Shutdown").
Expand Down Expand Up @@ -104,7 +104,7 @@ func TestSSEHandler(t *testing.T) {
corsConf := NewConfig()
corsConf.AllowedOrigins = "*.example.com"

corsHandler := SSEHandler(appNode, headersExtractor, &corsConf)
corsHandler := SSEHandler(appNode, context.Background(), headersExtractor, &corsConf)

corsHandler.ServeHTTP(w, req)

Expand Down Expand Up @@ -246,7 +246,7 @@ func TestSSEHandler(t *testing.T) {
assert.Empty(t, w.Body.String())
})

t.Run("POST request without commands", func(t *testing.T) {
t.Run("POST request without commands + server shutdown", func(t *testing.T) {
controller.
On("Authenticate", "sid-post-no-op", mock.Anything).
Return(&common.ConnectResult{
Expand All @@ -266,12 +266,22 @@ func TestSSEHandler(t *testing.T) {
w := httptest.NewRecorder()
sw := newStreamingWriter(w)

go handler.ServeHTTP(sw, req)
shutdownCtx, shutdownFn := context.WithCancel(context.Background())

shutdownHandler := SSEHandler(appNode, shutdownCtx, headersExtractor, &conf)

go shutdownHandler.ServeHTTP(sw, req)

msg, err := sw.ReadEvent(ctx)
require.NoError(t, err)
assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)

shutdownFn()

msg, err = sw.ReadEvent(ctx)
require.NoError(t, err)
assert.Equal(t, "event: disconnect\n"+`data: {"type":"disconnect","reason":"server_restart","reconnect":true}`, msg)

require.Equal(t, http.StatusOK, w.Code)
})

Expand Down

0 comments on commit ca55eb6

Please sign in to comment.