Skip to content

Commit

Permalink
Merge pull request #49 from lesismal/dev
Browse files Browse the repository at this point in the history
jsclient: support keepalive
fix CallAsync timer leak
add CallContext/NotifyContext
  • Loading branch information
lesismal authored May 19, 2023
2 parents 2e5350c + 65b0041 commit 802970d
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 239 deletions.
96 changes: 64 additions & 32 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/lesismal/arpc/codec"
"github.com/lesismal/arpc/log"
Expand Down Expand Up @@ -58,7 +59,7 @@ type Client struct {

mux sync.Mutex
sessionMap map[uint64]*rpcSession
asyncHandlerMap map[uint64]HandlerFunc
asyncHandlerMap map[uint64]*asyncHandler

chSend chan *Message
chClose chan util.Empty
Expand Down Expand Up @@ -206,6 +207,12 @@ func (c *Client) Call(method string, req interface{}, rsp interface{}, timeout t
// CallWith uses context to make rpc call.
// CallWith blocks to wait for a response from the server until it times out.
func (c *Client) CallWith(ctx context.Context, method string, req interface{}, rsp interface{}, args ...interface{}) error {
return c.CallContext(ctx, method, req, rsp, args...)
}

// CallContext uses context to make rpc call.
// CallContext blocks to wait for a response from the server until it times out.
func (c *Client) CallContext(ctx context.Context, method string, req interface{}, rsp interface{}, args ...interface{}) error {
if err := c.checkStateAndMethod(method); err != nil {
return err
}
Expand Down Expand Up @@ -262,32 +269,34 @@ func (c *Client) CallWith(ctx context.Context, method string, req interface{}, r
// CallAsync makes an asynchronous rpc call with timeout.
// CallAsync will not block waiting for the server's response,
// But the handler will be called if the response arrives before the timeout.
func (c *Client) CallAsync(method string, req interface{}, handler HandlerFunc, timeout time.Duration, args ...interface{}) error {
func (c *Client) CallAsync(method string, req interface{}, handler AsyncHandlerFunc, timeout time.Duration, args ...interface{}) error {
err := c.checkCallAsyncArgs(method, handler, timeout)
if err != nil {
return err
}

var timer *time.Timer

msg := c.newRequestMessage(CmdRequest, method, req, false, true, args...)
seq := msg.Seq()
if handler != nil {
c.addAsyncHandler(seq, handler)
timer = time.AfterFunc(timeout, func() { c.deleteAsyncHandler(seq) })
defer timer.Stop()
} else if timeout > 0 {
timer = time.NewTimer(timeout)
defer timer.Stop()
}

if c.Handler.AsyncWrite() {
switch timeout {
case TimeZero:
err = c.pushMessage(msg, nil)
default:
err = c.pushMessage(msg, timer)
chTimer := make(chan time.Time, 1)
timerC := *(*<-chan time.Time)(unsafe.Pointer(&chTimer))
timer := &time.Timer{C: timerC}
timerCallback := time.AfterFunc(timeout, func() {
ah, ok := c.getAndDeleteAsyncHandler(seq)
if ok {
if ah.timer != nil {
ah.timer.Stop()
}
ah.handler(nil, ErrTimeout)
putAsyncHandler(ah)
}
chTimer <- time.Now()
})
ah := getAsyncHandler(timerCallback, handler)
c.addAsyncHandler(seq, ah)

if c.Handler.AsyncWrite() {
err = c.pushMessage(msg, timer)
} else {
if !c.reconnecting {
coders := c.Handler.Coders()
Expand All @@ -307,6 +316,7 @@ func (c *Client) CallAsync(method string, req interface{}, handler HandlerFunc,

if err != nil && handler != nil {
c.deleteAsyncHandler(seq)
timerCallback.Stop()
}

return err
Expand Down Expand Up @@ -354,6 +364,12 @@ func (c *Client) Notify(method string, data interface{}, timeout time.Duration,
// NotifyWith use context to make rpc notify.
// A notify does not need a response from the server.
func (c *Client) NotifyWith(ctx context.Context, method string, data interface{}, args ...interface{}) error {
return c.NotifyContext(ctx, method, data, args...)
}

// NotifyContext use context to make rpc notify.
// A notify does not need a response from the server.
func (c *Client) NotifyContext(ctx context.Context, method string, data interface{}, args ...interface{}) error {
if err := c.checkStateAndMethod(method); err != nil {
return err
}
Expand Down Expand Up @@ -466,7 +482,7 @@ func (c *Client) Restart() error {
c.chSend = make(chan *Message, c.Handler.SendQueueSize())
c.chClose = make(chan util.Empty)
c.sessionMap = make(map[uint64]*rpcSession)
c.asyncHandlerMap = make(map[uint64]HandlerFunc)
c.asyncHandlerMap = make(map[uint64]*asyncHandler)
c.values = map[interface{}]interface{}{}

c.initReader()
Expand Down Expand Up @@ -532,16 +548,18 @@ func (c *Client) checkCallArgs(method string, timeout time.Duration) error {
return nil
}

func (c *Client) checkCallAsyncArgs(method string, handler HandlerFunc, timeout time.Duration) error {
func (c *Client) checkCallAsyncArgs(method string, handler AsyncHandlerFunc, timeout time.Duration) error {
if err := c.checkStateAndMethod(method); err != nil {
return err
}

if timeout == 0 {
return ErrClientInvalidTimeoutZero
}
if timeout < 0 {
return ErrClientInvalidTimeoutLessThanZero
}
if timeout == 0 && handler != nil {
return ErrClientInvalidTimeoutZeroWithNonNilCallback
if handler == nil {
return ErrClientInvalidAsyncHandler
}
return nil
}
Expand Down Expand Up @@ -672,37 +690,51 @@ func (c *Client) dropMessage(msg *Message) {
}
}

func (c *Client) addAsyncHandler(seq uint64, h HandlerFunc) {
func (c *Client) addAsyncHandler(seq uint64, ah *asyncHandler) {
c.mux.Lock()
if c.running {
c.asyncHandlerMap[seq] = h
c.asyncHandlerMap[seq] = ah
}
c.mux.Unlock()
}

func (c *Client) deleteAsyncHandler(seq uint64) {
c.mux.Lock()
delete(c.asyncHandlerMap, seq)
ah, ok := c.asyncHandlerMap[seq]
if ok {
delete(c.asyncHandlerMap, seq)
putAsyncHandler(ah)
}
c.mux.Unlock()
}

func (c *Client) getAndDeleteAsyncHandler(seq uint64) (HandlerFunc, bool) {
func (c *Client) getAndDeleteAsyncHandler(seq uint64) (*asyncHandler, bool) {
c.mux.Lock()
handler, ok := c.asyncHandlerMap[seq]
ah, ok := c.asyncHandlerMap[seq]
if ok {
delete(c.asyncHandlerMap, seq)
c.mux.Unlock()
} else {
c.mux.Unlock()
}

return handler, ok
return ah, ok
}

func (c *Client) clearAsyncHandler() {
c.mux.Lock()
c.asyncHandlerMap = make(map[uint64]HandlerFunc)
handlers := c.asyncHandlerMap
c.asyncHandlerMap = make(map[uint64]*asyncHandler)
c.mux.Unlock()
for _, ah := range handlers {
if ah.timer != nil {
ah.timer.Stop()
}
c.Handler.AsyncExecute(func() {
ah.handler(nil, ErrClientReconnecting)
putAsyncHandler(ah)
})
}
}

func (c *Client) run() {
Expand Down Expand Up @@ -923,7 +955,7 @@ func newClientWithConn(conn net.Conn, codec codec.Codec, handler Handler, onStop
c.chSend = make(chan *Message, c.Handler.SendQueueSize())
c.chClose = make(chan util.Empty)
c.sessionMap = make(map[uint64]*rpcSession)
c.asyncHandlerMap = make(map[uint64]HandlerFunc)
c.asyncHandlerMap = make(map[uint64]*asyncHandler)
c.onStop = onStop

if _, ok := conn.(WebsocketConn); !ok {
Expand Down Expand Up @@ -961,7 +993,7 @@ func NewClient(dialer DialerFunc, args ...interface{}) (*Client, error) {
c.chSend = make(chan *Message, c.Handler.SendQueueSize())
c.chClose = make(chan util.Empty)
c.sessionMap = make(map[uint64]*rpcSession)
c.asyncHandlerMap = make(map[uint64]HandlerFunc)
c.asyncHandlerMap = make(map[uint64]*asyncHandler)

c.run()

Expand Down
84 changes: 49 additions & 35 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ var (
testServer *Server
testClientServerAddr = "localhost:11000"

methodCallString = "/callstring"
methodCallBytes = "/callbytes"
methodCallStruct = "/callstruct"
methodCallWith = "/callwith"
methodCallAsync = "/callasync"
methodNotify = "/notify"
methodNotifyWith = "/notifywith"
methodCallError = "/callerror"
methodCallNotFound = "/notfound"
methodCallTimeout = "/timeout"
methodInvalidLong = `1234567890
methodCallString = "/callstring"
methodCallBytes = "/callbytes"
methodCallStruct = "/callstruct"
methodCallWith = "/callwith"
methodCallAsync = "/callasync"
methodCallAsyncTimeout = "/callasynctimeout"
methodNotify = "/notify"
methodNotifyWith = "/notifywith"
methodCallError = "/callerror"
methodCallNotFound = "/notfound"
methodCallTimeout = "/timeout"
methodInvalidLong = `1234567890
1234567890
1234567890
1234567890
Expand Down Expand Up @@ -132,6 +133,10 @@ func initServer() {
testServer.Handler.Handle(methodCallAsync, func(ctx *Context) {
ctx.Write(ctx.Message.Data())
}, true)
testServer.Handler.Handle(methodCallAsyncTimeout, func(ctx *Context) {
time.Sleep(time.Second / 5)
ctx.Write(ctx.Message.Data())
}, true)
testServer.Handler.Handle(methodNotify, func(ctx *Context) {
ctx.Bind(nil)
}, false)
Expand Down Expand Up @@ -503,9 +508,9 @@ func TestClient_CallAsync(t *testing.T) {
testClientCallAsyncDisconnected(c, t)
}

func getAsyncHandler() (func(*Context), chan struct{}) {
func makeAsyncHandler() (func(*Context, error), chan struct{}) {
done := make(chan struct{}, 1)
asyncHandler := func(*Context) {
asyncHandler := func(*Context, error) {
done <- struct{}{}
}
return asyncHandler, done
Expand All @@ -516,7 +521,7 @@ func testClientCallAsyncMethodString(c *Client, t *testing.T) {
err error
req = "hello"
)
asyncHandler, done := getAsyncHandler()
asyncHandler, done := makeAsyncHandler()
if err = c.CallAsync(methodCallAsync, req, asyncHandler, time.Second); err != nil {
t.Fatalf("Client.CallAsync() error = %v", err)
}
Expand All @@ -525,10 +530,7 @@ func testClientCallAsyncMethodString(c *Client, t *testing.T) {
t.Fatalf("Client.CallAsync() error = %v", err)
}
<-done
if err = c.CallAsync(methodCallAsync, &req, nil, time.Second); err != nil {
t.Fatalf("Client.CallAsync() error = %v", err)
}
if err = c.CallAsync(methodCallAsync, req, nil, 0); err != nil {
if err = c.CallAsync(methodCallAsync, &req, func(*Context, error) {}, time.Second); err != nil {
t.Fatalf("Client.CallAsync() error = %v", err)
}
}
Expand All @@ -538,7 +540,7 @@ func testClientCallAsyncMethodBytes(c *Client, t *testing.T) {
err error
req = []byte{1}
)
asyncHandler, done := getAsyncHandler()
asyncHandler, done := makeAsyncHandler()
if err = c.CallAsync(methodCallAsync, req, asyncHandler, time.Second); err != nil {
t.Fatalf("Client.CallAsync() error = %v", err)
}
Expand All @@ -547,7 +549,7 @@ func testClientCallAsyncMethodBytes(c *Client, t *testing.T) {
t.Fatalf("Client.CallAsync() error = %v", err)
}
<-done
if err = c.CallAsync(methodCallAsync, &req, nil, time.Second); err != nil {
if err = c.CallAsync(methodCallAsync, &req, func(*Context, error) {}, time.Second); err != nil {
t.Fatalf("Client.CallAsync() error = %v", err)
}
}
Expand All @@ -557,7 +559,7 @@ func testClientCallAsyncMethodStruct(c *Client, t *testing.T) {
err error
req = MessageTest{A: 3, B: "4"}
)
asyncHandler, done := getAsyncHandler()
asyncHandler, done := makeAsyncHandler()
if err = c.CallAsync(methodCallAsync, &req, asyncHandler, time.Second); err != nil {
t.Fatalf("Client.CallAsync() error = %v", err)
}
Expand All @@ -570,37 +572,49 @@ func testClientCallAsyncMethodStruct(c *Client, t *testing.T) {

func testClientCallAsyncError(c *Client, t *testing.T) {
var err error
asyncHandler, _ := getAsyncHandler()
asyncHandler, _ := makeAsyncHandler()
if err = c.CallAsync(methodCallAsync, "", asyncHandler, -1); err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrClientInvalidTimeoutLessThanZero.Error())
} else if err.Error() != ErrClientInvalidTimeoutLessThanZero.Error() {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), ErrClientInvalidTimeoutLessThanZero.Error())
}
if err = c.CallAsync(methodCallAsync, "", nil, -1); err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrClientInvalidTimeoutLessThanZero.Error())
} else if err.Error() != ErrClientInvalidTimeoutLessThanZero.Error() {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), ErrClientInvalidTimeoutLessThanZero.Error())
}

asyncHandler, _ = getAsyncHandler()
asyncHandler, _ = makeAsyncHandler()
if err = c.CallAsync(methodCallAsync, "", asyncHandler, 0); err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrClientInvalidTimeoutZeroWithNonNilCallback.Error())
} else if err.Error() != ErrClientInvalidTimeoutZeroWithNonNilCallback.Error() {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), ErrClientInvalidTimeoutZeroWithNonNilCallback.Error())
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrClientInvalidTimeoutZero.Error())
} else if err.Error() != ErrClientInvalidTimeoutZero.Error() {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), ErrClientInvalidTimeoutZero.Error())
}

invalidMethodErrString := fmt.Sprintf("invalid method length: %v, should <= %v", len(methodInvalidLong), MaxMethodLen)
if err = c.CallAsync(methodInvalidLong, "", nil, time.Second); err == nil {
if err = c.CallAsync(methodInvalidLong, "", func(*Context, error) {}, time.Second); err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", invalidMethodErrString)
} else if err.Error() != invalidMethodErrString {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), invalidMethodErrString)
}
done := make(chan error)
if err = c.CallAsync(methodCallAsyncTimeout, "", func(ctx *Context, err error) {
done <- err
}, time.Second/10); err != nil {
t.Fatalf("Client.CallAsync() error is %v, want %v", err, nil)
}
err = <-done
if err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrTimeout)
} else if err != ErrTimeout {
t.Fatalf("Client.CallAsync() error is %v, want %v", err, ErrTimeout)
}
time.Sleep(time.Second / 5)
select {
case err = <-done:
t.Fatalf("Client.CallAsync() callback twice: %v", err)
default:
}
}

func testClientCallAsyncDisconnected(c *Client, t *testing.T) {
var err error
c.Stop()
if err = c.CallAsync(methodCallAsync, "", nil, time.Second); err == nil {
if err = c.CallAsync(methodCallAsync, "", func(*Context, error) {}, time.Second); err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrClientStopped)
} else if err.Error() != ErrClientStopped.Error() {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), ErrClientStopped.Error())
Expand All @@ -609,7 +623,7 @@ func testClientCallAsyncDisconnected(c *Client, t *testing.T) {
c.Restart()
testServer.Stop()
time.Sleep(time.Second / 10)
if err = c.CallAsync(methodCallAsync, "", nil, time.Second); err == nil {
if err = c.CallAsync(methodCallAsync, "", func(*Context, error) {}, time.Second); err == nil {
t.Fatalf("Client.CallAsync() error is nil, want %v", ErrClientReconnecting)
} else if err.Error() != ErrClientReconnecting.Error() {
t.Fatalf("Client.CallAsync() error, returns '%v', want '%v'", err.Error(), ErrClientReconnecting.Error())
Expand Down
3 changes: 3 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ var (

// ErrClientInvalidPoolDialers represents an error of empty dialer array.
ErrClientInvalidPoolDialers = errors.New("invalid dialers: empty array")

// ErrClientInvalidAsyncHandler represents an error of invalid(nil) async handler.
ErrClientInvalidAsyncHandler = errors.New("invalid async handler: should not be nil")
)

// message error
Expand Down
Loading

0 comments on commit 802970d

Please sign in to comment.