Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TarsGo async rpc call #387

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tars/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func (c *AdapterProxy) doKeepAlive() {
IRequestId: c.servantProxy.genRequestID(),
SServantName: c.servantProxy.name,
SFuncName: "tars_ping",
ITimeout: int32(c.servantProxy.timeout),
ITimeout: int32(c.servantProxy.asyncTimeout),
}
msg := &Message{Req: &req, Ser: c.servantProxy}
msg.Init()
Expand Down
1 change: 1 addition & 0 deletions tars/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ func (a *application) initConfig() {
a.cltCfg.Stat = cMap["stat"]
a.cltCfg.Property = cMap["property"]
a.cltCfg.ModuleName = cMap["modulename"]
a.cltCfg.SyncInvokeTimeout = c.GetIntWithDef("/tars/application/client<sync-invoke-timeout>", SyncInvokeTimeout)
a.cltCfg.AsyncInvokeTimeout = c.GetIntWithDef("/tars/application/client<async-invoke-timeout>", AsyncInvokeTimeout)
a.cltCfg.RefreshEndpointInterval = c.GetIntWithDef("/tars/application/client<refresh-endpoint-interval>", refreshEndpointInterval)
a.cltCfg.ReportInterval = c.GetIntWithDef("/tars/application/client<report-interval>", reportInterval)
Expand Down
4 changes: 3 additions & 1 deletion tars/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ type clientConfig struct {
ReportInterval int
CheckStatusInterval int
KeepAliveInterval int
AsyncInvokeTimeout int
// add client timeout
SyncInvokeTimeout int
AsyncInvokeTimeout int
ClientQueueLen int
ClientIdleTimeout time.Duration
ClientReadTimeout time.Duration
Expand Down Expand Up @@ -152,6 +153,7 @@ func newClientConfig() *clientConfig {
ReportInterval: reportInterval,
CheckStatusInterval: checkStatusInterval,
KeepAliveInterval: keepAliveInterval,
SyncInvokeTimeout: SyncInvokeTimeout,
AsyncInvokeTimeout: AsyncInvokeTimeout,
ClientQueueLen: ClientQueueLen,
ClientIdleTimeout: tools.ParseTimeOut(ClientIdleTimeout),
Expand Down
62 changes: 62 additions & 0 deletions tars/message.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package tars

import (
"context"
"time"

"github.com/TarsCloud/TarsGo/tars/model"
"github.com/TarsCloud/TarsGo/tars/protocol/res/basef"
"github.com/TarsCloud/TarsGo/tars/protocol/res/requestf"
"github.com/TarsCloud/TarsGo/tars/selector"
"github.com/TarsCloud/TarsGo/tars/util/current"
"github.com/TarsCloud/TarsGo/tars/util/tools"
)

// HashType is the hash type
Expand All @@ -31,6 +36,9 @@ type Message struct {
hashCode uint32
hashType HashType
isHash bool
Async bool
Callback model.Callback
RespCh chan *requestf.ResponsePacket
}

// Init define the beginTime
Expand Down Expand Up @@ -66,3 +74,57 @@ func (m *Message) HashType() selector.HashType {
func (m *Message) IsHash() bool {
return m.isHash
}

func newMessage(ctx context.Context, cType byte,
sFuncName string,
buf []byte,
status map[string]string,
reqContext map[string]string,
resp *requestf.ResponsePacket,
s *ServantProxy) *Message {

// 将ctx中的dyeing信息传入到request中
var msgType int32
if dyeingKey, ok := current.GetDyeingKey(ctx); ok {
TLOG.Debug("dyeing debug: find dyeing key:", dyeingKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusDyedKey] = dyeingKey
msgType |= basef.TARSMESSAGETYPEDYED
}

// 将ctx中的trace信息传入到request中
if trace, ok := current.GetTarsTrace(ctx); ok && trace.Call() {
traceKey := trace.GetTraceFullKey(false)
TLOG.Debug("trace debug: find trace key:", traceKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusTraceKey] = traceKey
msgType |= basef.TARSMESSAGETYPETRACE
}

req := requestf.RequestPacket{
IVersion: s.version,
CPacketType: int8(cType),
IMessageType: msgType,
IRequestId: s.genRequestID(),
SServantName: s.name,
SFuncName: sFuncName,
ITimeout: int32(s.syncTimeout),
SBuffer: tools.ByteToInt8(buf),
Context: reqContext,
Status: status,
}
msg := &Message{Req: &req, Ser: s, Resp: resp}
msg.Init()

if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok {
msg.isHash = isHash
msg.hashType = HashType(hashType)
msg.hashCode = hashCode
}

return msg
}
13 changes: 13 additions & 0 deletions tars/model/servant.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/TarsCloud/TarsGo/tars/protocol/res/requestf"
)

type Callback interface {
Dispatch(context.Context, *requestf.RequestPacket, *requestf.ResponsePacket, error) (int32, error)
}

// Servant is interface for call the remote server.
type Servant interface {
Name() string
Expand All @@ -17,6 +21,15 @@ type Servant interface {
status map[string]string,
context map[string]string,
resp *requestf.ResponsePacket) error

TarsInvokeAsync(ctx context.Context, cType byte,
sFuncName string,
buf []byte,
status map[string]string,
context map[string]string,
resp *requestf.ResponsePacket,
callback Callback) error

TarsSetTimeout(t int)
TarsSetProtocol(Protocol)
Endpoints() []*endpoint.Endpoint
Expand Down
160 changes: 93 additions & 67 deletions tars/servant.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/TarsCloud/TarsGo/tars/util/current"
"github.com/TarsCloud/TarsGo/tars/util/endpoint"
"github.com/TarsCloud/TarsGo/tars/util/rtimer"
"github.com/TarsCloud/TarsGo/tars/util/tools"
)

var (
Expand All @@ -31,13 +30,14 @@ const (

// ServantProxy tars servant proxy instance
type ServantProxy struct {
name string
comm *Communicator
manager EndpointManager
timeout int
version int16
proto model.Protocol
queueLen int32
name string
comm *Communicator
manager EndpointManager
syncTimeout int
asyncTimeout int
version int16
proto model.Protocol
queueLen int32

pushCallback func([]byte)
}
Expand All @@ -49,10 +49,11 @@ func NewServantProxy(comm *Communicator, objName string, opts ...EndpointManager

func newServantProxy(comm *Communicator, objName string, opts ...EndpointManagerOption) *ServantProxy {
s := &ServantProxy{
comm: comm,
proto: &protocol.TarsProtocol{},
timeout: comm.Client.AsyncInvokeTimeout,
version: basef.TARSVERSION,
comm: comm,
proto: &protocol.TarsProtocol{},
syncTimeout: comm.Client.AsyncInvokeTimeout,
asyncTimeout: comm.Client.AsyncInvokeTimeout,
version: basef.TARSVERSION,
}
pos := strings.Index(objName, "@")
if pos > 0 {
Expand All @@ -77,7 +78,7 @@ func (s *ServantProxy) Name() string {

// TarsSetTimeout sets the timeout for client calling the server , which is in ms.
func (s *ServantProxy) TarsSetTimeout(t int) {
s.timeout = t
s.syncTimeout = t
}

// TarsSetVersion set tars version
Expand Down Expand Up @@ -122,53 +123,42 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte,
resp *requestf.ResponsePacket) error {
defer CheckPanic()

// 将ctx中的dyeing信息传入到request中
var msgType int32
if dyeingKey, ok := current.GetDyeingKey(ctx); ok {
TLOG.Debug("dyeing debug: find dyeing key:", dyeingKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusDyedKey] = dyeingKey
msgType |= basef.TARSMESSAGETYPEDYED
msg := newMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s)
timeout := time.Duration(s.syncTimeout) * time.Millisecond
if err := s.invokeFilters(ctx, msg, timeout); err != nil {
return err
}
*resp = *msg.Resp
return nil
}

// 将ctx中的trace信息传入到request中
if trace, ok := current.GetTarsTrace(ctx); ok && trace.Call() {
traceKey := trace.GetTraceFullKey(false)
TLOG.Debug("trace debug: find trace key:", traceKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusTraceKey] = traceKey
msgType |= basef.TARSMESSAGETYPETRACE
}
// TarsInvokeAsync is used for client invoking server.
func (s *ServantProxy) TarsInvokeAsync(ctx context.Context, cType byte,
sFuncName string,
buf []byte,
status map[string]string,
reqContext map[string]string,
resp *requestf.ResponsePacket,
callback model.Callback) error {
defer CheckPanic()

req := requestf.RequestPacket{
IVersion: s.version,
CPacketType: int8(cType),
IRequestId: s.genRequestID(),
SServantName: s.name,
SFuncName: sFuncName,
SBuffer: tools.ByteToInt8(buf),
ITimeout: int32(s.timeout),
Context: reqContext,
Status: status,
IMessageType: msgType,
}
msg := &Message{Req: &req, Ser: s, Resp: resp}
msg.Init()

timeout := time.Duration(s.timeout) * time.Millisecond
if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok {
msg.isHash = isHash
msg.hashType = HashType(hashType)
msg.hashCode = hashCode
msg := newMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s)
msg.Req.ITimeout = int32(s.asyncTimeout)
if callback == nil {
msg.Req.CPacketType = basef.TARSONEWAY
} else {
msg.Async = true
msg.Callback = callback
}

timeout := time.Duration(s.asyncTimeout) * time.Millisecond
return s.invokeFilters(ctx, msg, timeout)
}

func (s *ServantProxy) invokeFilters(ctx context.Context, msg *Message, timeout time.Duration) error {
if ok, to, isTimeout := current.GetClientTimeout(ctx); ok && isTimeout {
timeout = time.Duration(to) * time.Millisecond
req.ITimeout = int32(to)
msg.Req.ITimeout = int32(to)
}

var err error
Expand Down Expand Up @@ -196,27 +186,32 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte,
}
}
}
s.manager.postInvoke()
// no async rpc call
if !msg.Async {
s.manager.postInvoke()
msg.End()
s.reportStat(msg, err)
}

return err
}

func (s *ServantProxy) reportStat(msg *Message, err error) {
if err != nil {
msg.End()
TLOG.Errorf("Invoke error: %s, %s, %v, cost:%d", s.name, sFuncName, err.Error(), msg.Cost())
TLOG.Errorf("Invoke error: %s, %s, %v, cost:%d", s.name, msg.Req.SFuncName, err.Error(), msg.Cost())
if msg.Resp == nil {
ReportStat(msg, StatSuccess, StatSuccess, StatFailed)
} else if msg.Status == basef.TARSINVOKETIMEOUT {
ReportStat(msg, StatSuccess, StatFailed, StatSuccess)
} else {
ReportStat(msg, StatSuccess, StatSuccess, StatFailed)
}
return err
return
}
msg.End()
*resp = *msg.Resp
ReportStat(msg, StatFailed, StatSuccess, StatSuccess)
return err
}

func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.Duration) error {
func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.Duration) (err error) {
adp, needCheck := s.manager.SelectAdapterProxy(msg)
if adp == nil {
return errors.New("no adapter Proxy selected:" + msg.Req.SServantName)
Expand All @@ -237,29 +232,60 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.
}

atomic.AddInt32(&s.queueLen, 1)
readCh := make(chan *requestf.ResponsePacket)
adp.resp.Store(msg.Req.IRequestId, readCh)
defer func() {
msg.RespCh = make(chan *requestf.ResponsePacket)
adp.resp.Store(msg.Req.IRequestId, msg.RespCh)
var releaseFunc = func() {
CheckPanic()
atomic.AddInt32(&s.queueLen, -1)
adp.resp.Delete(msg.Req.IRequestId)
}
defer func() {
if !msg.Async || err != nil {
releaseFunc()
}
}()
if err := adp.Send(msg.Req); err != nil {

if err = adp.Send(msg.Req); err != nil {
adp.failAdd()
return err
}

if msg.Req.CPacketType == basef.TARSONEWAY {
adp.successAdd()
return nil
}

// async call rpc
if msg.Async {
go func() {
defer releaseFunc()
err := s.waitResp(msg, timeout, needCheck)
s.manager.postInvoke()
msg.End()
s.reportStat(msg, err)
if msg.Status != basef.TARSINVOKETIMEOUT {
current.SetResponseContext(ctx, msg.Resp.Context)
current.SetResponseStatus(ctx, msg.Resp.Status)
}
if _, err := msg.Callback.Dispatch(ctx, msg.Req, msg.Resp, err); err != nil {
TLOG.Errorf("Callback error: %s, %s, %+v", s.name, msg.Req.SFuncName, err)
}
}()
return nil
}

return s.waitResp(msg, timeout, needCheck)
}

func (s *ServantProxy) waitResp(msg *Message, timeout time.Duration, needCheck bool) error {
adp := msg.Adp
select {
case <-rtimer.After(timeout):
msg.Status = basef.TARSINVOKETIMEOUT
adp.failAdd()
msg.End()
return fmt.Errorf("request timeout, begin time:%d, cost:%d, obj:%s, func:%s, addr:(%s:%d), reqid:%d",
msg.BeginTime, msg.Cost(), msg.Req.SServantName, msg.Req.SFuncName, adp.point.Host, adp.point.Port, msg.Req.IRequestId)
case msg.Resp = <-readCh:
case msg.Resp = <-msg.RespCh:
if needCheck {
go func() {
adp.reset()
Expand Down
Loading