diff --git a/graphql.go b/graphql.go index 2113214..706092d 100644 --- a/graphql.go +++ b/graphql.go @@ -333,6 +333,7 @@ type Error struct { Column int `json:"column"` } `json:"locations"` Path []interface{} `json:"path"` + err error } // Error implements error interface. @@ -340,6 +341,11 @@ func (e Error) Error() string { return fmt.Sprintf("Message: %s, Locations: %+v, Extensions: %+v, Path: %+v", e.Message, e.Locations, e.Extensions, e.Path) } +// Unwrap implement the unwrap interface. +func (e Error) Unwrap() error { + return e.err +} + // Error implements error interface. func (e Errors) Error() string { b := strings.Builder{} @@ -349,6 +355,15 @@ func (e Errors) Error() string { return b.String() } +// Unwrap implements the error unwrap interface. +func (e Errors) Unwrap() []error { + var errs []error + for _, err := range e { + errs = append(errs, err.err) + } + return errs +} + func (e Error) getInternalExtension() map[string]interface{} { if e.Extensions == nil { return make(map[string]interface{}) @@ -367,6 +382,7 @@ func newError(code string, err error) Error { Extensions: map[string]interface{}{ "code": code, }, + err: err, } } diff --git a/graphql_test.go b/graphql_test.go index 326e738..224e926 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -253,6 +253,63 @@ func TestClient_Query_errorStatusCode(t *testing.T) { } } +func TestClient_Query_requestError(t *testing.T) { + want := errors.New("bad error") + client := graphql.NewClient("/graphql", &http.Client{Transport: errorRoundTripper{err: want}}) + + var q struct { + User struct { + Name string + } + } + err := client.Query(context.Background(), &q, nil) + if err == nil { + t.Fatal("got error: nil, want: non-nil") + } + if got, want := err.Error(), `Message: Post "/graphql": bad error, Locations: [], Extensions: map[code:request_error], Path: []`; got != want { + t.Errorf("got error: %v, want: %v", got, want) + } + if q.User.Name != "" { + t.Errorf("got non-empty q.User.Name: %v", q.User.Name) + } + if got := err; !errors.Is(got, want) { + t.Errorf("got error: %v, want: %v", got, want) + } + + gqlErr := err.(graphql.Errors) + if got, want := gqlErr[0].Extensions["code"], graphql.ErrRequestError; got != want { + t.Errorf("got error: %v, want: %v", got, want) + } + if _, ok := gqlErr[0].Extensions["internal"]; ok { + t.Errorf("expected empty internal error") + } + if got := gqlErr[0]; !errors.Is(err, want) { + t.Errorf("got error: %v, want %v", got, want) + } + + // test internal error data + client = client.WithDebug(true) + err = client.Query(context.Background(), &q, nil) + if err == nil { + t.Fatal("got error: nil, want: non-nil") + } + if !errors.As(err, &graphql.Errors{}) { + t.Errorf("the error type should be graphql.Errors") + } + gqlErr = err.(graphql.Errors) + if got, want := gqlErr[0].Message, `Post "/graphql": bad error`; got != want { + t.Errorf("got error: %v, want: %v", got, want) + } + if got, want := gqlErr[0].Extensions["code"], graphql.ErrRequestError; got != want { + t.Errorf("got error: %v, want: %v", got, want) + } + interErr := gqlErr[0].Extensions["internal"].(map[string]interface{}) + + if got, want := interErr["request"].(map[string]interface{})["body"], "{\"query\":\"{user{name}}\"}\n"; got != want { + t.Errorf("got error: %v, want: %v", got, want) + } +} + // Test that an empty (but non-nil) variables map is // handled no differently than a nil variables map. func TestClient_Query_emptyVariables(t *testing.T) { @@ -425,6 +482,16 @@ func (l localRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return w.Result(), nil } +// errorRoundTripper is an http.RoundTripper that always returns the supplied +// error. +type errorRoundTripper struct { + err error +} + +func (e errorRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { + return nil, e.err +} + func mustRead(r io.Reader) string { b, err := io.ReadAll(r) if err != nil {