diff --git a/context.go b/context.go index 9600a47..0f46bab 100644 --- a/context.go +++ b/context.go @@ -25,8 +25,9 @@ type Context struct { Client *Client Message *Message - index int - handlers []HandlerFunc + index int + handlers []HandlerFunc + responseErr interface{} } func (ctx *Context) Release() { @@ -35,6 +36,10 @@ func (ctx *Context) Release() { contextPool.Put(ctx) } +func (ctx *Context) ResponseError() interface{} { + return ctx.responseErr +} + // Get returns value for key. func (ctx *Context) Get(key interface{}) (interface{}, bool) { if len(ctx.Message.values) == 0 { @@ -153,6 +158,10 @@ func (ctx *Context) write(v interface{}, isError bool, timeout time.Duration) er if _, ok := v.(error); ok { isError = true } + if isError { + ctx.responseErr = v + } + rsp := newMessage(CmdResponse, req.method(), v, isError, req.IsAsync(), req.Seq(), cli.Handler, cli.Codec, ctx.Message.values) return cli.PushMsg(rsp, timeout) } diff --git a/context_test.go b/context_test.go index 293e5b1..a822acb 100644 --- a/context_test.go +++ b/context_test.go @@ -5,6 +5,7 @@ package arpc import ( + "errors" "testing" "github.com/lesismal/arpc/codec" @@ -110,3 +111,15 @@ func TestContext_Value(t *testing.T) { t.Fatalf("Context.Value() value != 'value', have %v", value) } } + +func TestContext_ResponseError(t *testing.T) { + ctx := &Context{ + Client: &Client{Handler: DefaultHandler}, + Message: newMessage(CmdRequest, "test", nil, false, false, 0, DefaultHandler, codec.DefaultCodec, nil), + } + err := errors.New("test err") + ctx.Error(err) + if ctx.ResponseError() != err { + t.Fatalf("Context.ResponseError() != 'test err'") + } +}