Skip to content

Commit

Permalink
Merge pull request #58 from lesismal/stream
Browse files Browse the repository at this point in the history
Stream
  • Loading branch information
lesismal authored Mar 6, 2024
2 parents 78d6a26 + 34c6d63 commit 3a871ae
Show file tree
Hide file tree
Showing 10 changed files with 1,249 additions and 39 deletions.
130 changes: 106 additions & 24 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type Client struct {
mux sync.Mutex
sessionMap map[uint64]*rpcSession
asyncHandlerMap map[uint64]*asyncHandler
streamLocalMap map[uint64]*Stream
streamRemoteMap map[uint64]*Stream

chSend chan *Message
chClose chan util.Empty
Expand Down Expand Up @@ -180,11 +182,11 @@ func (c *Client) Call(method string, req interface{}, rsp interface{}, timeout t
msg = coders[j].Encode(c, msg)
}
_, err := c.Handler.Send(c.Conn, msg.Buffer)
c.Handler.OnMessageDone(c, msg)
if err != nil {
c.Conn.Close()
return err
}
c.Handler.OnMessageDone(c, msg)
return err
} else {
c.dropMessage(msg)
return ErrClientReconnecting
Expand Down Expand Up @@ -242,11 +244,11 @@ func (c *Client) CallContext(ctx context.Context, method string, req interface{}
msg = coders[j].Encode(c, msg)
}
_, err := c.Handler.Send(c.Conn, msg.Buffer)
c.Handler.OnMessageDone(c, msg)
if err != nil {
c.Conn.Close()
return err
}
c.Handler.OnMessageDone(c, msg)
return err
} else {
c.dropMessage(msg)
return ErrClientReconnecting
Expand Down Expand Up @@ -483,10 +485,13 @@ func (c *Client) Restart() error {
c.chClose = make(chan util.Empty)
c.sessionMap = make(map[uint64]*rpcSession)
c.asyncHandlerMap = make(map[uint64]*asyncHandler)
c.streamLocalMap = make(map[uint64]*Stream)
c.streamRemoteMap = make(map[uint64]*Stream)
c.values = map[interface{}]interface{}{}

c.initReader()
if c.Handler.AsyncWrite() {
c.chSend = make(chan *Message, c.Handler.SendQueueSize())
go util.Safe(c.sendLoop)
}
go util.Safe(c.recvLoop)
Expand Down Expand Up @@ -521,6 +526,7 @@ func (c *Client) closeAndClean() {
if c.onStop != nil {
c.onStop(c)
}

c.Handler.OnDisconnected(c)
}

Expand Down Expand Up @@ -737,6 +743,69 @@ func (c *Client) clearAsyncHandler() {
}
}

func (c *Client) addStream(id uint64, local bool, stream *Stream) {
c.mux.Lock()
if c.running {
var streamMap map[uint64]*Stream
if local {
streamMap = c.streamLocalMap
} else {
streamMap = c.streamRemoteMap
}
streamMap[id] = stream
}
c.mux.Unlock()
}

func (c *Client) deleteStream(id uint64, local bool) {
c.mux.Lock()
if c.running {
var streamMap map[uint64]*Stream
if local {
streamMap = c.streamLocalMap
} else {
streamMap = c.streamRemoteMap
}
delete(streamMap, id)
}
c.mux.Unlock()
}

func (c *Client) getStreamAndPushMsg(id uint64, local, done bool) (stream *Stream, ok bool) {
c.mux.Lock()
if c.running {
var streamMap map[uint64]*Stream
if local {
streamMap = c.streamLocalMap
} else {
streamMap = c.streamRemoteMap
}
stream, ok = streamMap[id]
if ok && done {
delete(streamMap, id)
}
}
c.mux.Unlock()
return stream, ok
}

func (c *Client) clearStream() {
c.mux.Lock()
streamLocalMap := c.streamLocalMap
streamRemoteMap := c.streamRemoteMap
c.streamLocalMap = make(map[uint64]*Stream)
c.streamRemoteMap = make(map[uint64]*Stream)
c.mux.Unlock()
for _, stream := range streamLocalMap {
stream.CloseSend()
stream.CloseRecv()
}
for _, stream := range streamRemoteMap {
stream.CloseSend()
stream.CloseRecv()
}
}

func (c *Client) run() {
c.mux.Lock()
defer c.mux.Unlock()
Expand Down Expand Up @@ -797,6 +866,7 @@ func (c *Client) recvLoop() {
c.Conn.Close()
c.clearSession()
c.clearAsyncHandler()
c.clearStream()

// if c.running {
// log.Info("%v\t%v\tReconnect Start", c.Handler.LogTag(), addr)
Expand Down Expand Up @@ -934,16 +1004,22 @@ func (c *Client) batchSendLoop() {
func newClientWithConn(conn net.Conn, codec codec.Codec, handler Handler, onStop func(*Client)) *Client {
log.Info("%v\t%v\tConnected", handler.LogTag(), conn.RemoteAddr())

c := &Client{}
c.Conn = conn
c.Codec = codec
c.Handler = handler
c.Head = make([]byte, 4)
c.chSend = make(chan *Message, c.Handler.SendQueueSize())
c.chClose = make(chan util.Empty)
c.sessionMap = make(map[uint64]*rpcSession)
c.asyncHandlerMap = make(map[uint64]*asyncHandler)
c.onStop = onStop
c := &Client{
seq: 1,
Conn: conn,
Codec: codec,
Handler: handler,
Head: make([]byte, 4),
chClose: make(chan util.Empty),
sessionMap: make(map[uint64]*rpcSession),
asyncHandlerMap: make(map[uint64]*asyncHandler),
streamLocalMap: make(map[uint64]*Stream),
streamRemoteMap: make(map[uint64]*Stream),
onStop: onStop,
}
if c.Handler.AsyncWrite() {
c.chSend = make(chan *Message, handler.SendQueueSize())
}

c.run()

Expand All @@ -967,16 +1043,22 @@ func NewClient(dialer DialerFunc, args ...interface{}) (*Client, error) {
handler = DefaultHandler.Clone()
}

c := &Client{}
c.Conn = conn
c.Codec = codec.DefaultCodec
c.Handler = handler
c.Dialer = dialer
c.Head = make([]byte, 4)
c.chSend = make(chan *Message, c.Handler.SendQueueSize())
c.chClose = make(chan util.Empty)
c.sessionMap = make(map[uint64]*rpcSession)
c.asyncHandlerMap = make(map[uint64]*asyncHandler)
c := &Client{
seq: 1,
Conn: conn,
Codec: codec.DefaultCodec,
Handler: handler,
Dialer: dialer,
Head: make([]byte, 4),
chClose: make(chan util.Empty),
sessionMap: make(map[uint64]*rpcSession),
asyncHandlerMap: make(map[uint64]*asyncHandler),
streamLocalMap: make(map[uint64]*Stream),
streamRemoteMap: make(map[uint64]*Stream),
}
if c.Handler.AsyncWrite() {
c.chSend = make(chan *Message, handler.SendQueueSize())
}

c.run()

Expand Down
6 changes: 6 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ var (
ErrContextResponseToNotify = errors.New("should not response to a context with notify message")
)

// stream errors
var (
// ErrStreamClosedSend represents an error of stream closed send.
ErrStreamClosedSend = errors.New("stream has closed send")
)

// general errors
var (
// ErrTimeout represents an error of timeout.
Expand Down
111 changes: 111 additions & 0 deletions examples/stream/client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package main

import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/lesismal/arpc"
)

var (
addr = "localhost:8888"

method = "Hello"
)

// HelloReq .
type HelloReq struct {
Msg string
}

// HelloRsp .
type HelloRsp struct {
Msg string
}

func dialer() (net.Conn, error) {
return net.DialTimeout("tcp", addr, time.Second*3)
}

func main() {
arpc.EnablePool(true)

client, err := arpc.NewClient(dialer)
if err != nil {
log.Println("NewClient failed:", err)
return
}
defer client.Stop()

wg := &sync.WaitGroup{}
wg.Add(1)
client.Handler.HandleStream("/stream_server_to_client", func(stream *arpc.Stream) {
defer wg.Done()
defer stream.CloseRecv()
for {
str := ""
err := stream.Recv(&str)
if err == io.EOF {
stream.CloseSend()
log.Printf("[client] [stream id: %v] stream_server_to_client closed", stream.Id())
break
}
if err != nil {
panic(err)
}
log.Printf("[client] [stream id: %v] stream_server_to_client: %v", stream.Id(), str)
err = stream.Send(&str)
if err != nil {
panic(err)
}
}
})

data := make([]byte, 10)
rand.Read(data)
req := &HelloReq{Msg: base64.RawStdEncoding.EncodeToString(data)}
rsp := &HelloRsp{}
err = client.Call(method, req, rsp, time.Second*5)
if err != nil {
log.Printf("Call failed: %v", err)
} else if rsp.Msg != req.Msg {
log.Fatal("Call failed: not equal")
}

wg.Wait()
time.Sleep(time.Second)

stream := client.NewStream("/stream_client_to_server")
defer stream.CloseRecv()
go func() {
for i := 0; i < 3; i++ {
err := stream.Send(fmt.Sprintf("stream data %v", i))
if err != nil {
panic(err)
}
}
err = stream.SendAndClose(fmt.Sprintf("stream data %v", 3))
if err != nil {
panic(err)
}
}()

for {
str := ""
err = stream.Recv(&str)
if err == io.EOF {
log.Printf("[client] [stream id: %v] stream_client_to_server closed", stream.Id())
break
}
if err != nil {
panic(err)
}
log.Printf("[client] [stream id: %v] stream_client_to_server: %v", stream.Id(), str)
}
}
Loading

0 comments on commit 3a871ae

Please sign in to comment.