diff --git a/action/actions.go b/action/actions.go index 199fa99..058f103 100644 --- a/action/actions.go +++ b/action/actions.go @@ -1,6 +1,5 @@ package action -// TODO: Drop this type and use plain []Action. type Actions []Action func (actions *Actions) SetVar(scope Scope, name string, value interface{}) { diff --git a/action/marshal.go b/action/marshal.go index 4519166..8739306 100644 --- a/action/marshal.go +++ b/action/marshal.go @@ -2,12 +2,12 @@ package action import ( "fmt" + "github.com/negasus/haproxy-spoe-go/typeddata" "github.com/negasus/haproxy-spoe-go/varint" ) func (action *Action) Marshal(buf []byte) ([]byte, error) { - var nb byte switch action.Type { diff --git a/client/client.go b/client/client.go index 079a766..26ae327 100644 --- a/client/client.go +++ b/client/client.go @@ -4,24 +4,24 @@ import ( "bufio" "bytes" "fmt" - "github.com/negasus/haproxy-spoe-go/frame" - _ "github.com/negasus/haproxy-spoe-go/request" "io" "net" + + "github.com/negasus/haproxy-spoe-go/frame" ) -/// Client is a simple client for spop protocol, this should only be used for testing purpose +// Client is a simple client for spop protocol, this should only be used for testing purpose type Client struct { conn net.Conn reader io.Reader } -/// NewClient create a new Client for an established connection +// NewClient create a new Client for an established connection func NewClient(conn net.Conn) Client { return Client{conn: conn, reader: bufio.NewReader(conn)} } -/// Init initialize the client by sending the HaproxyHello frame +// Init initialize the client by sending the HaproxyHello frame func (c *Client) Init() error { f := frame.AcquireFrame() defer frame.ReleaseFrame(f) @@ -70,7 +70,7 @@ func (c *Client) send(f *frame.Frame) error { return nil } -/// Notify send an empty Notify frame +// Notify send an empty Notify frame func (c *Client) Notify() error { f := frame.AcquireFrame() defer frame.ReleaseFrame(f) @@ -89,7 +89,7 @@ func (c *Client) Notify() error { return nil } -/// Stop the client by sending HaproxyDisconnect frame +// Stop the client by sending HaproxyDisconnect frame func (c *Client) Stop() error { f := frame.AcquireFrame() defer frame.ReleaseFrame(f) diff --git a/frame/encode.go b/frame/encode.go index b8ec84d..6eb7312 100644 --- a/frame/encode.go +++ b/frame/encode.go @@ -10,7 +10,6 @@ import ( ) func (f *Frame) Encode(dest io.Writer) (n int, err error) { - buf := bytes.Buffer{} buf.WriteByte(byte(f.Type)) @@ -45,7 +44,7 @@ func (f *Frame) Encode(dest io.Writer) (n int, err error) { } case TypeNotify: if len(*f.Messages) > 0 { - err = fmt.Errorf("Encoding Notify frame with Message isn't handled yet") + err = fmt.Errorf("encoding Notify frame with Message isn't handled yet") return } diff --git a/frame/encode_test.go b/frame/encode_test.go index f7f9483..1dd4dba 100644 --- a/frame/encode_test.go +++ b/frame/encode_test.go @@ -3,8 +3,6 @@ package frame import ( "bytes" "encoding/binary" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "testing" ) @@ -17,15 +15,26 @@ func TestFrame_Write(t *testing.T) { f.KV.Add("key2", "val2") buf := &bytes.Buffer{} frameSize, err := f.Encode(buf) - require.Nil(t, err) + if err != nil { + t.Fatalf("expect err is nil, got %v", err) + } bufBytes := buf.Bytes() encodedFrameSize := int(binary.BigEndian.Uint32(bufBytes[0:4])) - assert.Equal(t, frameSize-4, encodedFrameSize, "frame size") - assert.Equal(t, "key1", string(bufBytes[13:17])) - assert.Equal(t, "val1", string(bufBytes[19:23])) - assert.Equal(t, "key2", string(bufBytes[24:28])) - assert.Equal(t, "val2", string(bufBytes[30:34])) - + if frameSize-4 != encodedFrameSize { + t.Fatal("wrong frame size") + } + if string(bufBytes[13:17]) != "key1" { + t.Fatal("expect key1") + } + if string(bufBytes[19:23]) != "val1" { + t.Fatal("expect val1") + } + if string(bufBytes[24:28]) != "key2" { + t.Fatal("expect key2") + } + if string(bufBytes[30:34]) != "val2" { + t.Fatal("expect val1") + } } func BenchmarkFrame_Encode(b *testing.B) { diff --git a/frame/frame.go b/frame/frame.go index 389b560..a6eea10 100644 --- a/frame/frame.go +++ b/frame/frame.go @@ -35,7 +35,7 @@ func ReleaseFrame(frame *Frame) { framePool.Put(frame) } -//Frame describe frame struct +// Frame describe frame struct type Frame struct { Len uint32 Type Type @@ -80,12 +80,12 @@ func (f *Frame) Reset() { f.KV.Reset() } -//IsFin returns true, if frame has flag 'FIN' +// IsFin returns true, if frame has flag 'FIN' func (f *Frame) IsFin() bool { return f.Flags&0x01 > 0 } -//IsAbort returns true, if frame has flag 'ABORT' +// IsAbort returns true, if frame has flag 'ABORT' func (f *Frame) IsAbort() bool { return f.Flags&0x02 > 0 } diff --git a/frame/read.go b/frame/read.go index a319808..eff3125 100644 --- a/frame/read.go +++ b/frame/read.go @@ -3,8 +3,9 @@ package frame import ( "encoding/binary" "fmt" - "github.com/negasus/haproxy-spoe-go/varint" "io" + + "github.com/negasus/haproxy-spoe-go/varint" ) func (f *Frame) Read(src io.Reader) error { @@ -23,7 +24,7 @@ func (f *Frame) Read(src io.Reader) error { f.Type = Type(f.tmp[4]) // Drop packet that doesn't have defined frame type early, before allocating any buffers - // that way spurious connectons (say someone calling curl on port) won't cause it to + // that way spurious connections (say someone calling curl on port) won't cause it to // allocate gigabytes of RAM switch f.Type { case TypeHaproxyHello, TypeHaproxyDisconnect, TypeNotify, TypeAgentHello, TypeAgentDisconnect, TypeAgentAck: diff --git a/frame/read_test.go b/frame/read_test.go index 3471f1b..e12b837 100644 --- a/frame/read_test.go +++ b/frame/read_test.go @@ -2,8 +2,6 @@ package frame import ( "bytes" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "io" "testing" ) @@ -25,17 +23,33 @@ func TestFrame_Read(t *testing.T) { r := bytes.NewBuffer(testFrame) f := NewFrame() err := f.Read(r) - require.Nil(t, err) - assert.Equal(t, 1, int(f.FrameID), "FrameID") - assert.Equal(t, 542, int(f.StreamID), "StreamID") - assert.Equal(t, TypeNotify, f.Type) + if err != nil { + t.Fatal() + } + if int(f.FrameID) != 1 { + t.Fatal("wrong FrameID") + } + if int(f.StreamID) != 542 { + t.Fatal("wrong StreamID") + } + if f.Type != TypeNotify { + t.Fatal("wrong type") + } messages := *f.Messages - require.Len(t, messages, 1) + if len(messages) != 1 { + t.Fatal("wrong messages len") + } host, found := messages[0].KV.Get("host") - require.True(t, found, "key host") + if !found { + t.Fatal("host not found") + } hostString, ok := host.(string) - require.True(t, ok) - assert.Equal(t, "domain.example.com", hostString) + if !ok { + t.Fatal("error convert host to string") + } + if hostString != "domain.example.com" { + t.Fatal("wrong hostString") + } } func BenchmarkFrame_Read(b *testing.B) { diff --git a/go.mod b/go.mod index 6f19403..d8e9367 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,3 @@ module github.com/negasus/haproxy-spoe-go go 1.19 - -require ( - github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.8.0 -) - -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/go.sum b/go.sum index de8e9cb..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/message/decode.go b/message/decode.go index 07474ac..3e86ad8 100644 --- a/message/decode.go +++ b/message/decode.go @@ -5,7 +5,6 @@ import ( ) func (m *Messages) Decode(buf []byte) error { - for { if len(buf) == 0 { break diff --git a/message/decode_test.go b/message/decode_test.go index 28c197a..58e55ef 100644 --- a/message/decode_test.go +++ b/message/decode_test.go @@ -1,8 +1,7 @@ package message import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "bytes" "testing" ) @@ -16,29 +15,52 @@ func TestDecode(t *testing.T) { } err := mess.Decode(buf) - require.NoError(t, err) - require.Equal(t, 2, mess.Len()) + if err != nil { + t.Fatal("unexpected error") + } + if mess.Len() != 2 { + t.Fatalf("mess.Len must be 2, got %d", mess.Len()) + } // First message m, err := mess.GetByIndex(0) - require.NoError(t, err) - assert.Equal(t, "Foo", m.Name) + if err != nil { + t.Fatal("unexpected error") + } + if m.Name != "Foo" { + t.Fatalf("m.Name must be Foo, got %s", m.Name) + } v, ok := m.KV.Get("Bar") - require.True(t, ok) - require.Equal(t, []byte{0x10, 0x20, 0x30}, v) + if !ok { + t.Fatal("ok is not true") + } + if !bytes.Equal([]byte{0x10, 0x20, 0x30}, v.([]byte)) { + t.Fatal("invalid result") + } v, ok = m.KV.Get("Vals") - require.True(t, ok) - require.Equal(t, "Baz", v) + if !ok { + t.Fatal("ok is not true") + } + if v != "Baz" { + t.Fatalf("v must be Baz, got %s", v) + } // Second message m, err = mess.GetByIndex(1) - require.NoError(t, err) - assert.Equal(t, "UI", m.Name) + if err != nil { + t.Fatal("unexpected error") + } + if m.Name != "UI" { + t.Fatalf("m.Name must be UI, got %s", m.Name) + } v, ok = m.KV.Get("Fee") - require.True(t, ok) - require.Equal(t, int32(10), v) - + if !ok { + t.Fatal("ok is not true") + } + if v != int32(10) { + t.Fatalf("v must be int32(10), got %d", v) + } } diff --git a/message/message.go b/message/message.go index 279c045..55bda9c 100644 --- a/message/message.go +++ b/message/message.go @@ -1,8 +1,9 @@ package message import ( - "github.com/negasus/haproxy-spoe-go/payload/kv" "sync" + + "github.com/negasus/haproxy-spoe-go/payload/kv" ) var messagePool = sync.Pool{ diff --git a/message/messages.go b/message/messages.go index 9d8bda3..c629949 100644 --- a/message/messages.go +++ b/message/messages.go @@ -1,6 +1,8 @@ package message -import "errors" +import ( + "errors" +) var ( ErrMessageNotFound = errors.New("message not found") diff --git a/payload/kv/kv.go b/payload/kv/kv.go index ed9a5a6..c2b72ca 100644 --- a/payload/kv/kv.go +++ b/payload/kv/kv.go @@ -2,9 +2,10 @@ package kv import ( "fmt" + "sync" + "github.com/negasus/haproxy-spoe-go/typeddata" "github.com/negasus/haproxy-spoe-go/varint" - "sync" ) var kvPool = sync.Pool{ diff --git a/typeddata/typeddata.go b/typeddata/typeddata.go index bb7d5c7..63e7bc1 100644 --- a/typeddata/typeddata.go +++ b/typeddata/typeddata.go @@ -1,13 +1,12 @@ package typeddata import ( + "errors" "fmt" "net" "reflect" "github.com/negasus/haproxy-spoe-go/varint" - - "github.com/pkg/errors" ) const ( diff --git a/typeddata/typeddata_test.go b/typeddata/typeddata_test.go index 38b8665..cb22dc5 100644 --- a/typeddata/typeddata_test.go +++ b/typeddata/typeddata_test.go @@ -1,47 +1,84 @@ package typeddata import ( + "bytes" "testing" - - "github.com/stretchr/testify/assert" ) func TestEncode_Nil(t *testing.T) { buf, n, err := Encode(nil, make([]byte, 0)) - assert.Nil(t, err) - assert.Equal(t, 1, n) - assert.Equal(t, 1, len(buf)) - assert.Equal(t, byte(0x00), buf[0]) + if err != nil { + t.Fatal("unexpected error") + } + if n != 1 { + t.Fatalf("n must be 1, got %d", n) + } + if len(buf) != 1 { + t.Fatalf("buf len must be 1, got %d", len(buf)) + } + if buf[0] != 0x00 { + t.Fatalf("invalid buf value") + } } func TestEncode_Bool(t *testing.T) { buf, n, err := Encode(false, make([]byte, 0)) - assert.Nil(t, err) - assert.Equal(t, 1, n) - assert.Equal(t, 1, len(buf)) - assert.Equal(t, byte(0x10), buf[0]) + if err != nil { + t.Fatal("unexpected error") + } + if n != 1 { + t.Fatalf("n must be 1, got %d", n) + } + if len(buf) != 1 { + t.Fatalf("buf len must be 1, got %d", len(buf)) + } + if buf[0] != 0x10 { + t.Fatalf("invalid buf value") + } buf, n, err = Encode(true, make([]byte, 0)) - assert.Nil(t, err) - assert.Equal(t, 1, n) - assert.Equal(t, 1, len(buf)) - assert.Equal(t, byte(0x11), buf[0]) + if err != nil { + t.Fatal("unexpected error") + } + if n != 1 { + t.Fatalf("n must be 1, got %d", n) + } + if len(buf) != 1 { + t.Fatalf("buf len must be 1, got %d", len(buf)) + } + if buf[0] != 0x11 { + t.Fatalf("invalid buf value") + } } func TestEncode_Int32(t *testing.T) { buf, n, err := Encode(int32(100500), make([]byte, 0)) - assert.Nil(t, err) - assert.Equal(t, 4, n) - assert.Equal(t, 4, len(buf)) - assert.Equal(t, []byte{0x02, 0xF4, 0xFA, 0x2F}, buf) + if err != nil { + t.Fatal("unexpected error") + } + if n != 4 { + t.Fatalf("n must be 4, got %d", n) + } + if len(buf) != 4 { + t.Fatalf("buf len must be 4, got %d", len(buf)) + } + if !bytes.Equal(buf, []byte{0x02, 0xF4, 0xFA, 0x2F}) { + t.Fatalf("invalid buf value") + } } func TestEncode_Binary(t *testing.T) { buf, n, err := Encode([]byte{0x10, 0x20, 0x30}, make([]byte, 0)) - assert.Nil(t, err) - assert.Equal(t, 5, n) - assert.Equal(t, 5, len(buf)) - assert.Equal(t, []byte{0x09, 0x03, 0x10, 0x20, 0x30}, buf) + if err != nil { + t.Fatal("unexpected error") + } + if n != 5 { + t.Fatalf("n must be 4, got %d", n) + } + if len(buf) != 5 { + t.Fatalf("buf len must be 4, got %d", len(buf)) + } + if !bytes.Equal(buf, []byte{0x09, 0x03, 0x10, 0x20, 0x30}) { + t.Fatalf("invalid buf value") + } } - -// todo: tests diff --git a/varint/varint_test.go b/varint/varint_test.go index 6820fb1..414d8d3 100644 --- a/varint/varint_test.go +++ b/varint/varint_test.go @@ -1,7 +1,6 @@ package varint import ( - "github.com/stretchr/testify/assert" "testing" ) @@ -11,29 +10,59 @@ func TestPutUVarint(t *testing.T) { buf := make([]byte, 10) n = PutUvarint(buf, 239) - assert.Equal(t, 1, n) - assert.Equal(t, byte(0xEF), buf[0]) + if n != 1 { + t.Fatalf("n must be 1, got %d", n) + } + if buf[0] != 0xEF { + t.Fatal("wrong buf value") + } n = PutUvarint(buf, 240) - assert.Equal(t, 2, n) - assert.Equal(t, byte(0xF0), buf[0]) - assert.Equal(t, byte(0x00), buf[1]) + if n != 2 { + t.Fatalf("n must be 2, got %d", n) + } + if buf[0] != 0xF0 { + t.Fatal("wrong buf value") + } + if buf[1] != 0x00 { + t.Fatal("wrong buf value") + } n = PutUvarint(buf, 256) - assert.Equal(t, 2, n) - assert.Equal(t, byte(0xF0), buf[0]) - assert.Equal(t, byte(0x01), buf[1]) + if n != 2 { + t.Fatalf("n must be 2, got %d", n) + } + if buf[0] != 0xF0 { + t.Fatal("wrong buf value") + } + if buf[1] != 0x01 { + t.Fatal("wrong buf value") + } n = PutUvarint(buf, 2287) - assert.Equal(t, 2, n) - assert.Equal(t, byte(0xFF), buf[0]) - assert.Equal(t, byte(0x7F), buf[1]) + if n != 2 { + t.Fatalf("n must be 2, got %d", n) + } + if buf[0] != 0xFF { + t.Fatal("wrong buf value") + } + if buf[1] != 0x7F { + t.Fatal("wrong buf value") + } n = PutUvarint(buf, 2289) - assert.Equal(t, 3, n) - assert.Equal(t, byte(0xF1), buf[0]) - assert.Equal(t, byte(0x80), buf[1]) - assert.Equal(t, byte(0x00), buf[2]) + if n != 3 { + t.Fatalf("n must be 3, got %d", n) + } + if buf[0] != 0xF1 { + t.Fatal("wrong buf value") + } + if buf[1] != 0x80 { + t.Fatal("wrong buf value") + } + if buf[2] != 0x00 { + t.Fatal("wrong buf value") + } } func TestGetUVarint(t *testing.T) { @@ -41,28 +70,52 @@ func TestGetUVarint(t *testing.T) { var c int n, c = Uvarint([]byte{0xF0}) - assert.Equal(t, uint64(0), n) - assert.Equal(t, -1, c) + if n != uint64(0) { + t.Fatalf("n must be 0, got %d", n) + } + if c != -1 { + t.Fatalf("c must be -1, got %d", c) + } n, c = Uvarint([]byte{0xEF}) - assert.Equal(t, uint64(239), n) - assert.Equal(t, 1, c) + if n != uint64(239) { + t.Fatalf("n must be 239, got %d", n) + } + if c != 1 { + t.Fatalf("c must be 1, got %d", c) + } n, c = Uvarint([]byte{0xF1, 0x00}) - assert.Equal(t, uint64(241), n) - assert.Equal(t, 2, c) + if n != uint64(241) { + t.Fatalf("n must be 241, got %d", n) + } + if c != 2 { + t.Fatalf("c must be 2, got %d", c) + } n, c = Uvarint([]byte{0xF0, 0x01}) - assert.Equal(t, uint64(256), n) - assert.Equal(t, 2, c) + if n != uint64(256) { + t.Fatalf("n must be 256, got %d", n) + } + if c != 2 { + t.Fatalf("c must be 2, got %d", c) + } n, c = Uvarint([]byte{0xFF, 0x7F}) - assert.Equal(t, uint64(2287), n) - assert.Equal(t, 2, c) + if n != uint64(2287) { + t.Fatalf("n must be 2287, got %d", n) + } + if c != 2 { + t.Fatalf("c must be 2, got %d", c) + } n, c = Uvarint([]byte{0xF1, 0x80, 0x00}) - assert.Equal(t, uint64(2289), n) - assert.Equal(t, 3, c) + if n != uint64(2289) { + t.Fatalf("n must be 2289, got %d", n) + } + if c != 3 { + t.Fatalf("c must be 3, got %d", c) + } } func TestLoop(t *testing.T) { @@ -71,6 +124,8 @@ func TestLoop(t *testing.T) { for i := 0; i < 1e6; i++ { PutUvarint(buf, uint64(i)) n, _ := Uvarint(buf) - assert.Equal(t, uint64(i), n) + if n != uint64(i) { + t.Fatal("unexpected n value") + } } } diff --git a/worker/frame_agenthello.go b/worker/frame_agenthello.go index 12958e5..4fc6b75 100644 --- a/worker/frame_agenthello.go +++ b/worker/frame_agenthello.go @@ -3,6 +3,7 @@ package worker import ( "bytes" "fmt" + "github.com/negasus/haproxy-spoe-go/frame" ) diff --git a/worker/frame_disconnect.go b/worker/frame_disconnect.go index d1b8b64..0f36a98 100644 --- a/worker/frame_disconnect.go +++ b/worker/frame_disconnect.go @@ -3,6 +3,7 @@ package worker import ( "bytes" "fmt" + "github.com/negasus/haproxy-spoe-go/frame" ) diff --git a/worker/frame_notify.go b/worker/frame_notify.go index 9032650..08bc9c5 100644 --- a/worker/frame_notify.go +++ b/worker/frame_notify.go @@ -9,7 +9,6 @@ import ( ) func (w *worker) processNotifyFrame(f *frame.Frame) { - defer frame.ReleaseFrame(f) req := request.AcquireRequest() diff --git a/worker/worker_test.go b/worker/worker_test.go index 3fd5e82..1e9bc45 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -8,44 +8,50 @@ import ( "github.com/negasus/haproxy-spoe-go/client" "github.com/negasus/haproxy-spoe-go/logger" "github.com/negasus/haproxy-spoe-go/request" - "github.com/stretchr/testify/assert" - _ "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) type MockedHandler struct { - m mock.Mock + handleFunc func(r *request.Request) + finishFunc func() } func (h *MockedHandler) Handle(r *request.Request) { - h.m.MethodCalled("handle", r) + h.handleFunc(r) } func (h *MockedHandler) Finish() { - h.m.MethodCalled("Finished") + h.finishFunc() } func TestWorker(t *testing.T) { clientConn, server := net.Pipe() spoe := client.NewClient(clientConn) - var m MockedHandler - m.m.On("handle", mock.Anything) - m.m.On("Finished") + m := MockedHandler{ + handleFunc: func(r *request.Request) { + + }, + finishFunc: func() { + + }, + } go func() { Handle(server, m.Handle, logger.NewNop()) m.Finish() }() - assert.NoError(t, spoe.Init()) - assert.NoError(t, spoe.Notify()) - assert.NoError(t, spoe.Stop()) + if spoe.Init() != nil { + t.Fatal("unexpected error on Init") + } + if spoe.Notify() != nil { + t.Fatal("unexpected error on Notify") + } + if spoe.Stop() != nil { + t.Fatal("unexpected error on Stop") + } - // Lets wait a bit to have everything finished + // Let's wait a bit to have everything finished <-time.After(time.Millisecond * 100) clientConn.Close() - - m.m.AssertExpectations(t) - } /* @@ -57,8 +63,14 @@ func TestWorkerConcurrent(t *testing.T) { clientConn2, server2 := net.Pipe() spoe := client.NewClient(clientConn) spoe2 := client.NewClient(clientConn2) - var m MockedHandler - m.m.On("handle", mock.Anything) + m := MockedHandler{ + handleFunc: func(r *request.Request) { + + }, + finishFunc: func() { + + }, + } go func() { Handle(server, m.Handle, logger.NewNop()) @@ -68,7 +80,9 @@ func TestWorkerConcurrent(t *testing.T) { }() duration := time.Second loop := func(s client.Client) { - assert.NoError(t, s.Init()) + if s.Init() != nil { + t.Fatal("unexpected error on Init") + } for { select { case <-time.After(duration): @@ -81,11 +95,8 @@ func TestWorkerConcurrent(t *testing.T) { go loop(spoe) go loop(spoe2) - // Lets wait a bit to have everything finished + // Let's wait a bit to have everything finished <-time.After(duration) - - m.m.AssertExpectations(t) - } /* @@ -95,9 +106,14 @@ func TestWorkerConcurrent(t *testing.T) { func BenchmarkWorker(b *testing.B) { clientConn, server := net.Pipe() spoe := client.NewClient(clientConn) - var m MockedHandler - m.m.On("handle", mock.Anything) - m.m.On("Finished") + m := MockedHandler{ + handleFunc: func(r *request.Request) { + + }, + finishFunc: func() { + + }, + } go func() { Handle(server, m.Handle, logger.NewNop()) @@ -110,8 +126,7 @@ func BenchmarkWorker(b *testing.B) { } spoe.Stop() - // Lets wait a bit to have everything finished + // Let's wait a bit to have everything finished <-time.After(time.Millisecond * 100) clientConn.Close() - }