Skip to content

Commit

Permalink
Add option for custom panic handler (#468)
Browse files Browse the repository at this point in the history
Add option for custom panic handler
  • Loading branch information
JohnStarich authored Sep 16, 2021
1 parent 5457f60 commit 446a2dd
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func (r *helloWorldResolver) Hello(ctx context.Context) (string, error) {
- `Tracer(tracer trace.Tracer)` is used to trace queries and fields. It defaults to `trace.OpenTracingTracer`.
- `ValidationTracer(tracer trace.ValidationTracer)` is used to trace validation errors. It defaults to `trace.NoopValidationTracer`.
- `Logger(logger log.Logger)` is used to log panics during query execution. It defaults to `exec.DefaultLogger`.
- `PanicHandler(panicHandler errors.PanicHandler)` is used to transform panics into errors during query execution. It defaults to `errors.DefaultPanicHandler`.
- `DisableIntrospection()` disables introspection queries.

### Custom Errors
Expand Down
18 changes: 18 additions & 0 deletions errors/panic_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package errors

import (
"context"
)

// PanicHandler is the interface used to create custom panic errors that occur during query execution
type PanicHandler interface {
MakePanicError(ctx context.Context, value interface{}) *QueryError
}

// DefaultPanicHandler is the default PanicHandler
type DefaultPanicHandler struct{}

// MakePanicError creates a new QueryError from a panic that occurred during execution
func (h *DefaultPanicHandler) MakePanicError(ctx context.Context, value interface{}) *QueryError {
return Errorf("panic occurred: %v", value)
}
24 changes: 24 additions & 0 deletions errors/panic_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package errors

import (
"context"
"testing"
)

func TestDefaultPanicHandler(t *testing.T) {
handler := &DefaultPanicHandler{}
qErr := handler.MakePanicError(context.Background(), "foo")
if qErr == nil {
t.Fatal("Panic error must not be nil")
}
const (
expectedMessage = "panic occurred: foo"
expectedError = "graphql: " + expectedMessage
)
if qErr.Error() != expectedError {
t.Errorf("Unexpected panic error message: %q != %q", qErr.Error(), expectedError)
}
if qErr.Message != expectedMessage {
t.Errorf("Unexpected panic QueryError.Message: %q != %q", qErr.Message, expectedMessage)
}
}
17 changes: 14 additions & 3 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) (
maxParallelism: 10,
tracer: trace.OpenTracingTracer{},
logger: &log.DefaultLogger{},
panicHandler: &errors.DefaultPanicHandler{},
}
for _, opt := range opts {
opt(s)
Expand Down Expand Up @@ -78,6 +79,7 @@ type Schema struct {
tracer trace.Tracer
validationTracer trace.ValidationTracerContext
logger log.Logger
panicHandler errors.PanicHandler
useStringDescriptions bool
disableIntrospection bool
subscribeResolverTimeout time.Duration
Expand Down Expand Up @@ -143,6 +145,14 @@ func Logger(logger log.Logger) SchemaOpt {
}
}

// PanicHandler is used to customize the panic errors during query execution.
// It defaults to errors.DefaultPanicHandler.
func PanicHandler(panicHandler errors.PanicHandler) SchemaOpt {
return func(s *Schema) {
s.panicHandler = panicHandler
}
}

// DisableIntrospection disables introspection queries.
func DisableIntrospection() SchemaOpt {
return func(s *Schema) {
Expand Down Expand Up @@ -244,9 +254,10 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
Schema: s.schema,
DisableIntrospection: s.disableIntrospection,
},
Limiter: make(chan struct{}, s.maxParallelism),
Tracer: s.tracer,
Logger: s.logger,
Limiter: make(chan struct{}, s.maxParallelism),
Tracer: s.tracer,
Logger: s.logger,
PanicHandler: s.panicHandler,
}
varTypes := make(map[string]*introspection.Type)
for _, v := range op.Vars {
Expand Down
9 changes: 3 additions & 6 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,21 @@ type Request struct {
Limiter chan struct{}
Tracer trace.Tracer
Logger log.Logger
PanicHandler errors.PanicHandler
SubscribeResolverTimeout time.Duration
}

func (r *Request) handlePanic(ctx context.Context) {
if value := recover(); value != nil {
r.Logger.LogPanic(ctx, value)
r.AddError(makePanicError(value))
r.AddError(r.PanicHandler.MakePanicError(ctx, value))
}
}

type extensionser interface {
Extensions() map[string]interface{}
}

func makePanicError(value interface{}) *errors.QueryError {
return errors.Errorf("panic occurred: %v", value)
}

func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *types.OperationDefinition) ([]byte, []*errors.QueryError) {
var out bytes.Buffer
func() {
Expand Down Expand Up @@ -188,7 +185,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
defer func() {
if panicValue := recover(); panicValue != nil {
r.Logger.LogPanic(ctx, panicValue)
err = makePanicError(panicValue)
err = r.PanicHandler.MakePanicError(ctx, panicValue)
err.Path = path.toSlice()
}
}()
Expand Down
1 change: 1 addition & 0 deletions subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (s *Schema) subscribe(ctx context.Context, queryString string, operationNam
Limiter: make(chan struct{}, s.maxParallelism),
Tracer: s.tracer,
Logger: s.logger,
PanicHandler: s.panicHandler,
SubscribeResolverTimeout: s.subscribeResolverTimeout,
}
varTypes := make(map[string]*introspection.Type)
Expand Down

0 comments on commit 446a2dd

Please sign in to comment.