Skip to content

Commit

Permalink
Add recoverer on transformer (#180)
Browse files Browse the repository at this point in the history
* Add recoverer on transformer

* Add test on transformer server's recoverer

* Remove redundancy of setting http status code
  • Loading branch information
ariefrahmansyah authored Sep 10, 2021
1 parent 463b440 commit 760a925
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
24 changes: 21 additions & 3 deletions api/pkg/transformer/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/pprof"
"os"
"os/signal"
"runtime/debug"
"strings"
"syscall"
"time"
Expand All @@ -31,8 +32,10 @@ import (

const MerlinLogIdHeader = "X-Merlin-Log-Id"

var shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM}
var onlyOneSignalHandler = make(chan struct{})
var (
shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM}
onlyOneSignalHandler = make(chan struct{})
)

var hystrixCommandName = "model_predict"

Expand Down Expand Up @@ -247,9 +250,11 @@ func (s *Server) predict(ctx context.Context, r *http.Request, request []byte) (

// Run serves the HTTP endpoints.
func (s *Server) Run() {
// use default mux
s.router.Use(recoveryHandler)

health := healthcheck.NewHandler()
s.router.Handle("/", health)

s.router.Handle("/metrics", promhttp.Handler())
s.router.PathPrefix("/debug/pprof/profile").HandlerFunc(pprof.Profile)
s.router.PathPrefix("/debug/pprof/trace").HandlerFunc(pprof.Trace)
Expand Down Expand Up @@ -295,6 +300,19 @@ func (s *Server) Run() {
}
}

func recoveryHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
debug.PrintStack()
response.NewError(http.StatusInternalServerError, fmt.Errorf("panic: %v", err)).Write(w)
}
}()

next.ServeHTTP(w, r)
})
}

// setupSignalHandler registered for SIGTERM and SIGINT. A stop channel is returned
// which is closed on one of these signals. If a second signal is caught, the program
// is terminated with exit code 1.
Expand Down
41 changes: 41 additions & 0 deletions api/pkg/transformer/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ import (
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

feastSdk "github.com/feast-dev/feast/sdk/go"
"github.com/feast-dev/feast/sdk/go/protos/feast/serving"
feastTypes "github.com/feast-dev/feast/sdk/go/protos/feast/types"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/zap"
Expand Down Expand Up @@ -764,3 +767,41 @@ func respBody(t *testing.T, response *http.Response) string {

return string(respBody)
}

func Test_recoveryHandler(t *testing.T) {
router := mux.NewRouter()
logger, _ := zap.NewDevelopment()

ts := httptest.NewServer(nil)
defer ts.Close()

port := fmt.Sprint(ts.Listener.Addr().(*net.TCPAddr).Port)
modelName := "test-panic"

s := &Server{
router: router,
logger: logger,
options: &Options{
Port: port,
ModelName: modelName,
},
PreprocessHandler: func(ctx context.Context, rawRequest []byte, rawRequestHeaders map[string]string) ([]byte, error) {
panic("panic at preprocess")
return nil, nil
},
}
go s.Run()

// Give some time for the server to run.
time.Sleep(1 * time.Second)

resp, err := http.Post(fmt.Sprintf("http://localhost:%s/v1/models/%s:predict", port, modelName), "", strings.NewReader("{}"))
assert.Nil(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

respBody, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()

assert.Nil(t, err)
assert.Equal(t, `{"code":500,"message":"panic: panic at preprocess"}`, string(respBody))
}

0 comments on commit 760a925

Please sign in to comment.