From fbe0af17ce1d2b7f08a4aed43d9ffe9df3f4e49d Mon Sep 17 00:00:00 2001 From: lesismal <40462947+lesismal@users.noreply.github.com> Date: Tue, 3 Aug 2021 20:37:00 +0800 Subject: [PATCH] - support both client and server sync write --- client.go | 146 +++++++++++++++------------ context.go | 2 +- examples/bench_nbio/server/server.go | 4 +- examples/nbio/server/server.go | 4 +- handler.go | 14 +++ 5 files changed, 104 insertions(+), 66 deletions(-) diff --git a/client.go b/client.go index 3243580..99ad77b 100644 --- a/client.go +++ b/client.go @@ -50,8 +50,6 @@ type Client struct { Dialer DialerFunc Head Header - IsAsyncWrite bool - running bool reconnecting bool @@ -131,7 +129,7 @@ func (c *Client) Call(method string, req interface{}, rsp interface{}, timeout t c.deleteSession(seq) }() - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { select { case c.chSend <- msg: case <-timer.C: @@ -142,13 +140,18 @@ func (c *Client) Call(method string, req interface{}, rsp interface{}, timeout t return ErrClientStopped } } else { - coders := c.Handler.Coders() - for j := 0; j < len(coders); j++ { - msg = coders[j].Encode(c, msg) - } - if _, err := c.Handler.Send(c.Conn, msg.Buffer); err != nil { - c.Conn.Close() - return err + if !c.reconnecting { + coders := c.Handler.Coders() + for j := 0; j < len(coders); j++ { + msg = coders[j].Encode(c, msg) + } + if _, err := c.Handler.Send(c.Conn, msg.Buffer); err != nil { + c.Conn.Close() + return err + } + } else { + c.dropMessage(msg) + return ErrClientReconnecting } } @@ -176,7 +179,7 @@ func (c *Client) CallWith(ctx context.Context, method string, req interface{}, r c.addSession(seq, sess) defer c.deleteSession(seq) - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { select { case c.chSend <- msg: case <-ctx.Done(): @@ -187,13 +190,18 @@ func (c *Client) CallWith(ctx context.Context, method string, req interface{}, r return ErrClientStopped } } else { - coders := c.Handler.Coders() - for j := 0; j < len(coders); j++ { - msg = coders[j].Encode(c, msg) - } - if _, err := c.Handler.Send(c.Conn, msg.Buffer); err != nil { - c.Conn.Close() - return err + if !c.reconnecting { + coders := c.Handler.Coders() + for j := 0; j < len(coders); j++ { + msg = coders[j].Encode(c, msg) + } + if _, err := c.Handler.Send(c.Conn, msg.Buffer); err != nil { + c.Conn.Close() + return err + } + } else { + c.dropMessage(msg) + return ErrClientReconnecting } } @@ -230,7 +238,7 @@ func (c *Client) CallAsync(method string, req interface{}, handler HandlerFunc, defer timer.Stop() } - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { switch timeout { case TimeZero: err = c.pushMessage(msg, nil) @@ -238,13 +246,18 @@ func (c *Client) CallAsync(method string, req interface{}, handler HandlerFunc, err = c.pushMessage(msg, timer) } } else { - coders := c.Handler.Coders() - for j := 0; j < len(coders); j++ { - msg = coders[j].Encode(c, msg) - } - _, err = c.Handler.Send(c.Conn, msg.Buffer) - if err != nil { - c.Conn.Close() + if !c.reconnecting { + coders := c.Handler.Coders() + for j := 0; j < len(coders); j++ { + msg = coders[j].Encode(c, msg) + } + _, err = c.Handler.Send(c.Conn, msg.Buffer) + if err != nil { + c.Conn.Close() + } + } else { + c.dropMessage(msg) + err = ErrClientReconnecting } } @@ -265,7 +278,7 @@ func (c *Client) Notify(method string, data interface{}, timeout time.Duration, msg := c.newRequestMessage(CmdNotify, method, data, false, true, args...) - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { switch timeout { case TimeZero: err = c.pushMessage(msg, nil) @@ -275,13 +288,18 @@ func (c *Client) Notify(method string, data interface{}, timeout time.Duration, err = c.pushMessage(msg, timer) } } else { - coders := c.Handler.Coders() - for j := 0; j < len(coders); j++ { - msg = coders[j].Encode(c, msg) - } - _, err = c.Handler.Send(c.Conn, msg.Buffer) - if err != nil { - c.Conn.Close() + if !c.reconnecting { + coders := c.Handler.Coders() + for j := 0; j < len(coders); j++ { + msg = coders[j].Encode(c, msg) + } + _, err = c.Handler.Send(c.Conn, msg.Buffer) + if err != nil { + c.Conn.Close() + } + } else { + c.dropMessage(msg) + err = ErrClientReconnecting } } @@ -297,7 +315,7 @@ func (c *Client) NotifyWith(ctx context.Context, method string, data interface{} msg := c.newRequestMessage(CmdNotify, method, data, false, true, args...) - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { select { case c.chSend <- msg: case <-ctx.Done(): @@ -308,13 +326,18 @@ func (c *Client) NotifyWith(ctx context.Context, method string, data interface{} return ErrClientStopped } } else { - coders := c.Handler.Coders() - for j := 0; j < len(coders); j++ { - msg = coders[j].Encode(c, msg) - } - if _, err := c.Handler.Send(c.Conn, msg.Buffer); err != nil { - c.Conn.Close() - return err + if !c.reconnecting { + coders := c.Handler.Coders() + for j := 0; j < len(coders); j++ { + msg = coders[j].Encode(c, msg) + } + if _, err := c.Handler.Send(c.Conn, msg.Buffer); err != nil { + c.Conn.Close() + return err + } + } else { + c.dropMessage(msg) + return ErrClientReconnecting } } @@ -328,16 +351,21 @@ func (c *Client) PushMsg(msg *Message, timeout time.Duration) error { return err } - if !c.IsAsyncWrite { - coders := c.Handler.Coders() - for j := 0; j < len(coders); j++ { - msg = coders[j].Encode(c, msg) - } - _, err := c.Handler.Send(c.Conn, msg.Buffer) - if err != nil { - c.Conn.Close() + if !c.Handler.AsyncWrite() { + if !c.reconnecting { + coders := c.Handler.Coders() + for j := 0; j < len(coders); j++ { + msg = coders[j].Encode(c, msg) + } + _, err := c.Handler.Send(c.Conn, msg.Buffer) + if err != nil { + c.Conn.Close() + } + return err + } else { + c.dropMessage(msg) + return ErrClientReconnecting } - return err } if timeout < 0 { @@ -390,7 +418,7 @@ func (c *Client) Restart() error { c.values = map[interface{}]interface{}{} c.initReader() - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { go util.Safe(c.sendLoop) } go util.Safe(c.recvLoop) @@ -622,7 +650,7 @@ func (c *Client) run() { if !c.running { c.running = true c.initReader() - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { go util.Safe(c.sendLoop) } go util.Safe(c.recvLoop) @@ -635,7 +663,7 @@ func (c *Client) runWebsocket() { if !c.running { c.running = true c.initReader() - if c.IsAsyncWrite { + if c.Handler.AsyncWrite() { go util.Safe(c.sendLoop) } c.Conn.(WebsocketConn).HandleWebsocket(c.recvLoop) @@ -822,19 +850,12 @@ func newClientWithConn(conn net.Conn, codec codec.Codec, handler Handler, onStop } // NewClient creates a Client. -func NewClient(dialer DialerFunc, args ...interface{}) (*Client, error) { +func NewClient(dialer DialerFunc) (*Client, error) { conn, err := dialer() if err != nil { return nil, err } - isAsyncWrite := true - if len(args) > 0 { - if asyncWrite, ok := args[0].(bool); ok { - isAsyncWrite = asyncWrite - } - } - c := &Client{} c.Conn = conn c.Codec = codec.DefaultCodec @@ -845,7 +866,6 @@ func NewClient(dialer DialerFunc, args ...interface{}) (*Client, error) { c.chClose = make(chan util.Empty) c.sessionMap = make(map[uint64]*rpcSession) c.asyncHandlerMap = make(map[uint64]HandlerFunc) - c.IsAsyncWrite = isAsyncWrite c.run() diff --git a/context.go b/context.go index 5eb86d6..2a2feb9 100644 --- a/context.go +++ b/context.go @@ -124,7 +124,7 @@ func (ctx *Context) Value(key interface{}) interface{} { func (ctx *Context) write(v interface{}, isError bool, timeout time.Duration) error { cli := ctx.Client - if cli.IsAsyncWrite { + if !cli.Handler.AsyncWrite() { return ctx.writeDirectly(v, isError) } req := ctx.Message diff --git a/examples/bench_nbio/server/server.go b/examples/bench_nbio/server/server.go index cdca68a..ba112df 100644 --- a/examples/bench_nbio/server/server.go +++ b/examples/bench_nbio/server/server.go @@ -34,7 +34,7 @@ type Session struct { } func onOpen(c *nbio.Conn) { - client := &arpc.Client{Conn: c, Codec: codec.DefaultCodec, IsAsync: true} + client := &arpc.Client{Conn: c, Codec: codec.DefaultCodec, Handler: handler} session := &Session{ Client: client, Buffer: nil, @@ -68,6 +68,8 @@ func onData(c *nbio.Conn, data []byte) { func main() { nlog.SetLogger(log.DefaultLogger) + handler.SetAsyncWrite(false) + // register router handler.Handle("Hello", func(ctx *arpc.Context) { req := &HelloReq{} diff --git a/examples/nbio/server/server.go b/examples/nbio/server/server.go index 574e089..73a6908 100644 --- a/examples/nbio/server/server.go +++ b/examples/nbio/server/server.go @@ -24,7 +24,7 @@ type Session struct { } func onOpen(c *nbio.Conn) { - client := &arpc.Client{Conn: c, Codec: codec.DefaultCodec, IsAsync: true} + client := &arpc.Client{Conn: c, Codec: codec.DefaultCodec, Handler: handler} session := &Session{ Client: client, Buffer: nil, @@ -58,6 +58,8 @@ func onData(c *nbio.Conn, data []byte) { func main() { nlog.SetLogger(log.DefaultLogger) + handler.SetAsyncWrite(false) + // register router handler.Handle("/echo", func(ctx *arpc.Context) { str := "" diff --git a/handler.go b/handler.go index 260bc2e..6e4c208 100644 --- a/handler.go +++ b/handler.go @@ -77,6 +77,11 @@ type Handler interface { // SetBatchSend sets BatchSend flag. SetBatchSend(batch bool) + // AsyncWrite returns AsyncWrite flag. + AsyncWrite() bool + // SetAsyncWrite sets AsyncWrite flag. + SetAsyncWrite(async bool) + // AsyncResponse returns AsyncResponse flag. AsyncResponse() bool // SetAsyncResponse sets AsyncResponse flag. @@ -140,6 +145,7 @@ type handler struct { logtag string batchRecv bool batchSend bool + asyncWrite bool asyncResponse bool recvBufferSize int sendQueueSize int @@ -283,6 +289,14 @@ func (h *handler) SetBatchSend(batch bool) { h.batchSend = batch } +func (h *handler) AsyncWrite() bool { + return h.asyncWrite +} + +func (h *handler) SetAsyncWrite(async bool) { + h.asyncWrite = async +} + func (h *handler) AsyncResponse() bool { return h.asyncResponse }