Skip to content

Commit

Permalink
lb: Perform round-robin load balancing in server list
Browse files Browse the repository at this point in the history
The client will query the servers in a round-robin fashion with the
order of the server list it receives. We order the server list
differently for each client so that one-off calls do not hit always
the same server.

In order to increment the offset for each client we use an atomic
operation the increments a counter and resets its value when it wraps
around. The respective code is in the internal/atomic package.
  • Loading branch information
joa committed Mar 26, 2019
1 parent 05c7d86 commit c1d61e9
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 21 deletions.
4 changes: 1 addition & 3 deletions broadcast.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package main

import (
"context"
)
import "context"

type Listener chan<- ServerList

Expand Down
22 changes: 22 additions & 0 deletions internal/atomic/atomic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package atomic

import "sync/atomic"

// IncWrapInt64 atomically increments a 64-bit signed integer, wrapping around zero.
//
// Specifically if p points to a value of math.MaxInt64 the result of calling
// IncWrapInt64 will be 0.
func IncWrapInt64(p *int64) int64 {
for {
o := atomic.LoadInt64(p)
n := o + 1

if n < 0 {
n = 0
}

if atomic.CompareAndSwapInt64(p, o, n) {
return n
}
}
}
46 changes: 46 additions & 0 deletions internal/atomic/atomic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package atomic

import (
"math"
"sync"
"testing"
)

func TestIncWrapInt64(t *testing.T) {
const n = 1000
const m = n - 1

var x int64

var wg sync.WaitGroup
wg.Add(m)

for i := 0; i < m; i++ {
go func() {
IncWrapInt64(&x)
wg.Done()
}()
}

wg.Wait()

if y := IncWrapInt64(&x); y != n {
t.Errorf("expected %d, got %d", n, y)
}

x = math.MaxInt64 - m
wg.Add(m)

for i := 0; i < m; i++ {
go func() {
IncWrapInt64(&x)
wg.Done()
}()
}

wg.Wait()

if y := IncWrapInt64(&x); y != 0 {
t.Errorf("expected %d, got %d", 0, y)
}
}
29 changes: 23 additions & 6 deletions lb.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package main

import (
"github.com/joa/jawlb/internal/atomic"
grpclb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type lb struct {
b *broadcast
b *broadcast
rr int64
}

func (l *lb) BalanceLoad(req grpclb.LoadBalancer_BalanceLoadServer) error {
Expand All @@ -33,16 +35,14 @@ func (l *lb) BalanceLoad(req grpclb.LoadBalancer_BalanceLoadServer) error {
l.b.addListener(ch)
defer l.b.remListener(ch)

offset := int(atomic.IncWrapInt64(&l.rr))

for {
select {
case <-req.Context().Done():
return nil
case msg := <-ch:
var servers []*grpclb.Server

for _, server := range msg {
servers = append(servers, &grpclb.Server{IpAddress: server.IP, Port: server.Port})
}
servers := convertServerList(msg, offset)

err := req.Send(&grpclb.LoadBalanceResponse{
LoadBalanceResponseType: &grpclb.LoadBalanceResponse_ServerList{
Expand All @@ -58,3 +58,20 @@ func (l *lb) BalanceLoad(req grpclb.LoadBalancer_BalanceLoadServer) error {
}
}
}

func convertServerList(l ServerList, offset int) []*grpclb.Server {
var servers []*grpclb.Server

n := len(l)

for i := 0; i < n; i++ {
server := l[(i+offset)%n]
servers = append(servers, convertServer(server))
}

return servers
}

func convertServer(s Server) *grpclb.Server {
return &grpclb.Server{IpAddress: s.IP, Port: s.Port}
}
36 changes: 36 additions & 0 deletions lb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

import (
"net"
"testing"
)

func TestConvertServerList(t *testing.T) {
l := ServerList{Server{net.IP{}, 1}, Server{net.IP{}, 2}, Server{net.IP{}, 3}}

table := []struct {
offset int
want []int32
}{
{0, []int32{1, 2, 3}},
{1, []int32{2, 3, 1}},
{2, []int32{3, 1, 2}},
{3, []int32{1, 2, 3}},
{4, []int32{2, 3, 1}},
}

for _, tt := range table {
res := convertServerList(l, tt.offset)

if len(res) != len(tt.want) {
t.Errorf("expected len %d, got %d", len(tt.want), len(res))
continue
}

for i := range res {
if res[i].Port != tt.want[i] {
t.Errorf("expected %d at %d, got %d", tt.want[i], i, res[i].Port)
}
}
}
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func startServer(bc *broadcast) *grpc.Server {
}

srv := grpc.NewServer()
grpclb.RegisterLoadBalancerServer(srv, &lb{bc})
grpclb.RegisterLoadBalancerServer(srv, &lb{bc, 0})

go func() {
if err := srv.Serve(conn); err != nil {
Expand Down
14 changes: 14 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

import "net"

type ServerList []Server

type Server struct {
IP net.IP
Port int32
}

func (s Server) Equal(x Server) bool {
return s.Port == x.Port && s.IP.Equal(x.IP)
}
11 changes: 0 additions & 11 deletions watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,6 @@ import (
"k8s.io/client-go/tools/clientcmd"
)

type Server struct {
IP net.IP
Port int32
}

func (s Server) Equal(x Server) bool {
return s.Port == x.Port && s.IP.Equal(x.IP)
}

type ServerList []Server

func watchService(ctx context.Context) (_ <-chan ServerList, err error) {
icc, err := getConfig()

Expand Down

0 comments on commit c1d61e9

Please sign in to comment.