Skip to content

Commit

Permalink
feat(abciclient)!: limit concurrent gRPC connections (#775)
Browse files Browse the repository at this point in the history
* feat(abciclient): limit concurrent gRPC connections

* build(deps): bump bufbuild to fix build errors

* chore: fix lint warnings

* fix(privval): fix error handling

* test(privval): fix grpc tests

* feat(abciclient): grpc conn rate limit per service method

* refactor(config)!: add [abci] section with grpc-concurrency, proxy-app and transport

* chore(confix): config update plan

* refactor(config)!: rename proxy-app to address

* fix(config): empty transport and address are allowed

* fix(confix): adjust plan
  • Loading branch information
lklimek authored Apr 22, 2024
1 parent 6dff7b3 commit 8383d57
Show file tree
Hide file tree
Showing 25 changed files with 655 additions and 307 deletions.
8 changes: 6 additions & 2 deletions abci/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
sync "github.com/sasha-s/go-deadlock"

"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/libs/service"
)
Expand Down Expand Up @@ -36,12 +37,15 @@ type Client interface {

// NewClient returns a new ABCI client of the specified transport type.
// It returns an error if the transport is not "socket" or "grpc"
func NewClient(logger log.Logger, addr, transport string, mustConnect bool) (Client, error) {
func NewClient(logger log.Logger, cfg config.AbciConfig, mustConnect bool) (Client, error) {
transport := cfg.Transport
addr := cfg.Address

switch transport {
case "socket":
return NewSocketClient(logger, addr, mustConnect), nil
case "grpc":
return NewGRPCClient(logger, addr, mustConnect), nil
return NewGRPCClient(logger, addr, cfg.GrpcConcurrency, mustConnect), nil
case "routed":
return NewRoutedClientWithAddr(logger, addr, mustConnect)
default:
Expand Down
99 changes: 97 additions & 2 deletions abci/client/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"strings"
"time"

sync "github.com/sasha-s/go-deadlock"
Expand All @@ -30,36 +31,130 @@ type grpcClient struct {
mtx sync.Mutex
addr string
err error

// map between method name (in grpc format, for example `/tendermint.abci.ABCIApplication/Echo`)
// and a channel that will be used to limit the number of concurrent requests for that method.
//
// If the value is nil, no limit is enforced.
//
// Not thread-safe, only modify this before starting the client.
concurrency map[string]chan struct{}
}

var _ Client = (*grpcClient)(nil)

// NewGRPCClient creates a gRPC client, which will connect to addr upon the
// start. Note Client#Start returns an error if connection is unsuccessful and
// mustConnect is true.
func NewGRPCClient(logger log.Logger, addr string, mustConnect bool) Client {
func NewGRPCClient(logger log.Logger, addr string, concurrency map[string]uint16, mustConnect bool) Client {
cli := &grpcClient{
logger: logger,
addr: addr,
mustConnect: mustConnect,
concurrency: make(map[string]chan struct{}, 20),
}
cli.BaseService = *service.NewBaseService(logger, "grpcClient", cli)
cli.SetMaxConcurrentStreams(concurrency)

return cli
}

func methodID(method string) string {
if pos := strings.LastIndex(method, "/"); pos > 0 {
method = method[pos+1:]
}

return strings.ToLower(method)
}

// SetMaxConcurrentStreams sets the maximum number of concurrent streams to be
// allowed on this client.
//
// Not thread-safe, only use this before starting the client.
//
// If limit is 0, no limit is enforced.
func (cli *grpcClient) SetMaxConcurrentStreamsForMethod(method string, n uint16) {
if cli.IsRunning() {
panic("cannot set max concurrent streams after starting the client")
}
var ch chan struct{}
if n != 0 {
ch = make(chan struct{}, n)
}

cli.mtx.Lock()
cli.concurrency[methodID(method)] = ch
cli.mtx.Unlock()
}

// SetMaxConcurrentStreams sets the maximum number of concurrent streams to be
// allowed on this client.
// # Arguments
//
// * `methods` - A map between method name (in grpc format, for example `/tendermint.abci.ABCIApplication/Echo`)
// and the maximum number of concurrent streams to be allowed for that method.
//
// Special method name "*" can be used to set the default limit for methods not explicitly listed.
//
// If the value is 0, no limit is enforced.
//
// Not thread-safe, only use this before starting the client.
func (cli *grpcClient) SetMaxConcurrentStreams(methods map[string]uint16) {
for method, n := range methods {
cli.SetMaxConcurrentStreamsForMethod(method, n)
}
}

func dialerFunc(_ctx context.Context, addr string) (net.Conn, error) {
return tmnet.Connect(addr)
}

// rateLimit blocks until the client is allowed to send a request.
// It returns a function that should be called after the request is done.
//
// method should be the method name in grpc format, for example `/tendermint.abci.ABCIApplication/Echo`.
// Special method name "*" can be used to define the default limit.
// If no limit is set for the method, the default limit is used.
func (cli *grpcClient) rateLimit(method string) context.CancelFunc {
ch := cli.concurrency[methodID(method)]
// handle default
if ch == nil {
ch = cli.concurrency["*"]
}
if ch == nil {
return func() {}
}

cli.logger.Trace("grpcClient rateLimit", "addr", cli.addr)
ch <- struct{}{}
return func() { <-ch }
}

func (cli *grpcClient) unaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
done := cli.rateLimit(method)
defer done()

return invoker(ctx, method, req, reply, cc, opts...)
}

func (cli *grpcClient) streamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
done := cli.rateLimit(method)
defer done()

return streamer(ctx, desc, cc, method, opts...)
}

func (cli *grpcClient) OnStart(ctx context.Context) error {
timer := time.NewTimer(0)
defer timer.Stop()

RETRY_LOOP:
for {
conn, err := grpc.Dial(cli.addr,
conn, err := grpc.NewClient(cli.addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(dialerFunc),
grpc.WithChainUnaryInterceptor(cli.unaryClientInterceptor),
grpc.WithChainStreamInterceptor(cli.streamClientInterceptor),
)
if err != nil {
if cli.mustConnect {
Expand Down
140 changes: 95 additions & 45 deletions abci/client/grpc_client_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package abciclient_test
package abciclient

import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/assert"

abciclient "github.com/dashpay/tenderdash/abci/client"
abciserver "github.com/dashpay/tenderdash/abci/server"
"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/libs/log"
Expand All @@ -18,47 +19,87 @@ import (

// TestGRPCClientServerParallel tests that gRPC client and server can handle multiple parallel requests
func TestGRPCClientServerParallel(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

logger := log.NewNopLogger()
app := &mockApplication{t: t}

socket := t.TempDir() + "/grpc_test"
client, _, err := makeGRPCClientServer(ctx, t, logger, app, socket)
if err != nil {
t.Fatal(err)
const (
timeout = 1 * time.Second
tick = 10 * time.Millisecond
)

type testCase struct {
threads int
infoConcurrency uint16
defautConcurrency uint16
}

// we'll use that mutex to ensure threads don't finish before we check status
app.mtx.Lock()

const threads = 5
// started will be marked as done as soon as app.Info() handler executes on the server
app.started.Add(threads)
// done will be used to wait for all threads to finish
var done sync.WaitGroup
done.Add(threads)

for i := 0; i < threads; i++ {
thread := uint64(i)
go func() {
_, _ = client.Info(ctx, &types.RequestInfo{BlockVersion: thread})
done.Done()
}()
testCases := []testCase{
{threads: 1, infoConcurrency: 1},
{threads: 2, infoConcurrency: 1},
{threads: 2, infoConcurrency: 2},
{threads: 5, infoConcurrency: 0},
{threads: 5, infoConcurrency: 0, defautConcurrency: 2},
{threads: 5, infoConcurrency: 1},
{threads: 5, infoConcurrency: 2},
{threads: 5, infoConcurrency: 5},
}

// wait for threads to execute
// note it doesn't mean threads are really done, as they are waiting on the mtx
// so if all `started` are marked as done, it means all threads have started
// in parallel
app.started.Wait()

// unlock the mutex so that threads can finish their execution
app.mtx.Unlock()
logger := log.NewNopLogger()

// wait for all threads to really finish
done.Wait()
for _, tc := range testCases {
t.Run(fmt.Sprintf("t_%d-i_%d,d_%d", tc.threads, tc.infoConcurrency, tc.defautConcurrency), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

app := &mockApplication{t: t, concurrencyLimit: int32(tc.infoConcurrency)}

socket := t.TempDir() + "/grpc_test"
limits := map[string]uint16{
"/tendermint.abci.ABCIApplication/Info": tc.infoConcurrency,
"*": tc.defautConcurrency,
}

client, _, err := makeGRPCClientServer(ctx, t, logger, app, socket, limits)
if err != nil {
t.Fatal(err)
}

// we'll use that mutex to ensure threads don't finish before we check status
app.mtx.Lock()

// done will be used to wait for all threads to finish
var done sync.WaitGroup

for i := 0; i < tc.threads; i++ {
done.Add(1)
thread := uint64(i)
go func() {
// we use BlockVersion for logging purposes, so we put thread id there
_, _ = client.Info(ctx, &types.RequestInfo{BlockVersion: thread})
done.Done()
}()
}

expectThreads := int32(tc.infoConcurrency)
if expectThreads == 0 {
expectThreads = int32(tc.defautConcurrency)
}
if expectThreads == 0 {
expectThreads = int32(tc.threads)
}

// wait for all threads to start execution
assert.Eventually(t, func() bool {
return app.running.Load() == expectThreads
}, timeout, tick, "not all threads started in time")

// ensure no other threads will start
time.Sleep(2 * tick)

// unlock the mutex so that threads can finish their execution
app.mtx.Unlock()

// wait for all threads to really finish
done.Wait()
})
}
}

func makeGRPCClientServer(
Expand All @@ -67,7 +108,8 @@ func makeGRPCClientServer(
logger log.Logger,
app types.Application,
name string,
) (abciclient.Client, service.Service, error) {
concurrency map[string]uint16,
) (Client, service.Service, error) {
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
t.Cleanup(leaktest.Check(t))
Expand All @@ -82,7 +124,7 @@ func makeGRPCClientServer(
return nil, nil, err
}

client := abciclient.NewGRPCClient(logger.With("module", "abci-client"), socket, true)
client := NewGRPCClient(logger.With("module", "abci-client"), socket, concurrency, true)

if err := client.Start(ctx); err != nil {
cancel()
Expand All @@ -96,19 +138,27 @@ func makeGRPCClientServer(
type mockApplication struct {
types.BaseApplication
mtx sync.Mutex
// we'll use that to ensure all threads have started
started sync.WaitGroup

running atomic.Int32
// concurrencyLimit of concurrent requests
concurrencyLimit int32

t *testing.T
}

func (m *mockApplication) Info(_ctx context.Context, req *types.RequestInfo) (res *types.ResponseInfo, err error) {
m.t.Logf("Info %d called", req.BlockVersion)
// mark wg as done to signal that we have executed
m.started.Done()
// we will wait here until all threads mark wg as done
running := m.running.Add(1)
defer m.running.Add(-1)

if m.concurrencyLimit > 0 {
assert.LessOrEqual(m.t, running, m.concurrencyLimit, "too many requests running in parallel")
}

// we will wait here until all expected threads are running
m.mtx.Lock()
defer m.mtx.Unlock()
m.t.Logf("Info %d finished", req.BlockVersion)

return &types.ResponseInfo{}, nil
}
4 changes: 3 additions & 1 deletion abci/client/routed_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/go-multierror"

"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/libs/service"
)
Expand Down Expand Up @@ -71,7 +72,8 @@ func NewRoutedClientWithAddr(logger log.Logger, addr string, mustConnect bool) (
// Create a new client if it doesn't exist
clientName := fmt.Sprintf("%s:%s", transport, address)
if _, ok := clients[clientName]; !ok {
c, err := NewClient(logger, address, transport, mustConnect)
cfg := config.AbciConfig{Address: address, Transport: transport}
c, err := NewClient(logger, cfg, mustConnect)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion abci/cmd/abci-cli/abci-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/dashpay/tenderdash/abci/server"
servertest "github.com/dashpay/tenderdash/abci/tests/server"
"github.com/dashpay/tenderdash/abci/types"
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/proto/tendermint/crypto"
tmproto "github.com/dashpay/tenderdash/proto/tendermint/types"
Expand Down Expand Up @@ -64,7 +65,8 @@ func RootCmmand(logger log.Logger) *cobra.Command {

if client == nil {
var err error
client, err = abciclient.NewClient(logger.With("module", "abci-client"), flagAddress, flagAbci, false)
cfg := config.AbciConfig{Address: flagAddress, Transport: flagAbci}
client, err = abciclient.NewClient(logger.With("module", "abci-client"), cfg, false)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 8383d57

Please sign in to comment.