diff --git a/examples/bench_pool/client/client.go b/examples/bench_pool/client/client.go index ac9f870..527865e 100644 --- a/examples/bench_pool/client/client.go +++ b/examples/bench_pool/client/client.go @@ -77,29 +77,32 @@ func main() { go func() { ticker := time.NewTicker(time.Second) for i := 0; true; i++ { - select { - case <-ticker.C: - req := &HelloReq{Msg: fmt.Sprintf("[%v] %v", client.Conn.LocalAddr(), i)} - rsp := &HelloRsp{} - err := client.CallAsync(method, req, func(ctx *arpc.Context) { - err := ctx.Bind(rsp) - if err != nil || rsp.Msg != req.Msg { - log.Printf("CallAsync failed: %v", err) - atomic.AddUint64(&failedTotal, 1) - } else { - //log.Printf("Call Response: \"%v\"", rsp.Msg) - atomic.AddUint64(&qpsSec, 1) - atomic.AddUint64(&asyncTimes, 1) - } - }, time.Second*5) - if err != nil { - log.Printf("CallAsync failed: %v", err) + <-ticker.C + req := &HelloReq{Msg: fmt.Sprintf("[%v] %v", client.Conn.LocalAddr(), i)} + rsp := &HelloRsp{} + err := client.CallAsync(method, req, func(ctx *arpc.Context, er error) { + if er != nil { + log.Printf("CallAsync failed: %v", er) + atomic.AddUint64(&failedTotal, 1) + return + } + er = ctx.Bind(rsp) + if er != nil || rsp.Msg != req.Msg { + log.Printf("CallAsync failed: %v", er) atomic.AddUint64(&failedTotal, 1) } else { //log.Printf("Call Response: \"%v\"", rsp.Msg) atomic.AddUint64(&qpsSec, 1) atomic.AddUint64(&asyncTimes, 1) } + }, time.Second*5) + if err != nil { + log.Printf("CallAsync failed: %v", err) + atomic.AddUint64(&failedTotal, 1) + } else { + //log.Printf("Call Response: \"%v\"", rsp.Msg) + atomic.AddUint64(&qpsSec, 1) + atomic.AddUint64(&asyncTimes, 1) } } }() diff --git a/examples/rpc/client/client.go b/examples/rpc/client/client.go index fc9d432..5aaef01 100644 --- a/examples/rpc/client/client.go +++ b/examples/rpc/client/client.go @@ -32,14 +32,17 @@ func main() { log.Printf("Call /echo/async Response: \"%v\"", rsp) } done := make(chan string) - err = client.CallAsync("/echo/async", &req, func(ctx *arpc.Context) { + err = client.CallAsync("/echo/async", &req, func(ctx *arpc.Context, er error) { + if er != nil { + log.Fatalf("Call /echo/async failed: %v", err) + } rsp := "" - err = ctx.Bind(&rsp) - if err != nil { - log.Fatalf("Call /echo/async Bind failed: %v", err) + er = ctx.Bind(&rsp) + if er != nil { + log.Fatalf("Call /echo/async Bind failed: %v", er) } if rsp != req { - log.Fatalf("Call /echo/async failed: %v", err) + log.Fatalf("Call /echo/async failed: %v", er) } done <- rsp }, time.Second*5) diff --git a/extension/protocol/quic/quic.go b/extension/protocol/quic/quic.go index f4c5fd7..ceb19e1 100644 --- a/extension/protocol/quic/quic.go +++ b/extension/protocol/quic/quic.go @@ -10,32 +10,32 @@ import ( "net" "time" - quic "github.com/lucas-clemente/quic-go" + quic "github.com/quic-go/quic-go" ) // Listener wraps quick.Listener to net.Listener type Listener struct { - quic.Listener + *quic.Listener } // Accept waits for and returns the next connection to the listener. func (ln *Listener) Accept() (net.Conn, error) { - session, err := ln.Listener.Accept(context.Background()) + conn, err := ln.Listener.Accept(context.Background()) if err != nil { return nil, err } - stream, err := session.AcceptStream(context.Background()) + stream, err := conn.AcceptStream(context.Background()) if err != nil { return nil, err } - return &Conn{session, stream}, err + return &Conn{conn, stream}, err } // Conn wraps quick.Session to net.Conn type Conn struct { - quic.Session + quic.Connection quic.Stream } @@ -59,7 +59,7 @@ func Dial(addr string, tlsConf *tls.Config, quicConf *quic.Config, timeout time. defer cancel() } - session, err := quic.DialAddr(addr, tlsConf, quicConf) + session, err := quic.DialAddr(ctx, addr, tlsConf, quicConf) if err != nil { return nil, err } diff --git a/extension/protocol/websocket/websocket.go b/extension/protocol/websocket/websocket.go index caa9424..31a9ee4 100644 --- a/extension/protocol/websocket/websocket.go +++ b/extension/protocol/websocket/websocket.go @@ -46,7 +46,6 @@ func (ln *Listener) Handler(w http.ResponseWriter, r *http.Request) { case <-ln.chClose: c.Close() } - } // Close . @@ -64,9 +63,10 @@ func (ln *Listener) Addr() net.Addr { // Accept . func (ln *Listener) Accept() (net.Conn, error) { - c := <-ln.chAccept - if c != nil { + select { + case c := <-ln.chAccept: return c, nil + case <-ln.chClose: } return nil, ErrClosed } @@ -150,8 +150,15 @@ func Listen(addr string, upgrader *websocket.Upgrader) (net.Listener, error) { } // Dial wraps websocket dial -func Dial(url string) (net.Conn, error) { - c, _, err := websocket.DefaultDialer.Dial(url, nil) +func Dial(url string, args ...interface{}) (net.Conn, error) { + dialer := websocket.DefaultDialer + if len(args) > 0 { + d, ok := args[0].(*websocket.Dialer) + if ok { + dialer = d + } + } + c, _, err := dialer.Dial(url, nil) if err != nil { return nil, err }