From 2414e77de30d5995688b138d2c86873fb02c32a8 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 10 Dec 2014 23:44:53 -0500 Subject: [PATCH 01/22] Start performing a lot more error checking in Send() --- client.go | 72 ++++++++++++++++++++++++++++++++----------------- notification.go | 8 ++++++ 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index de7ab1a..b9a31f1 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "io" "log" + "sync" "time" ) @@ -27,20 +28,28 @@ func (b *buffer) Add(v interface{}) *list.Element { return e } +type serializedNotif struct { + id uint32 + b []byte +} + type Client struct { Conn *Conn FailedNotifs chan NotificationResult - notifs chan Notification - id uint32 + notifs chan serializedNotif + + id uint32 + idm sync.Mutex } func newClientWithConn(gw string, conn Conn) Client { c := Client{ Conn: &conn, FailedNotifs: make(chan NotificationResult), - id: uint32(1), - notifs: make(chan Notification), + notifs: make(chan serializedNotif), + id: 1, + idm: sync.Mutex{}, } go c.runLoop() @@ -73,10 +82,37 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err } func (c *Client) Send(n Notification) error { - c.notifs <- n + // Set identifier if not specified + if n.Identifier == 0 { + n.Identifier = c.nextID() + } else if c.id < n.Identifier { + c.setID(n.Identifier) + } + + b, err := n.ToBinary() + if err != nil { + return err + } + + c.notifs <- serializedNotif{b: b, id: n.Identifier} return nil } +func (c *Client) setID(n uint32) { + c.idm.Lock() + defer c.idm.Unlock() + + c.id = n +} + +func (c *Client) nextID() uint32 { + c.idm.Lock() + defer c.idm.Unlock() + + c.id++ + return c.id +} + func (c *Client) reportFailedPush(v interface{}, err *Error) { failedNotif, ok := v.(Notification) if !ok || v == nil { @@ -93,7 +129,7 @@ func (c *Client) requeue(cursor *list.Element) { // If `cursor` is not nil, this means there are notifications that // need to be delivered (or redelivered) for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(Notification); ok { + if n, ok := cursor.Value.(serializedNotif); ok { go func() { c.notifs <- n }() } } @@ -103,11 +139,11 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { cursor := buffer.Back() for cursor != nil { - // Get notification - n, _ := cursor.Value.(Notification) + // Get serialized notification + n, _ := cursor.Value.(serializedNotif) // If the notification, move cursor after the trouble notification - if n.Identifier == err.Identifier { + if n.id == err.Identifier { go c.reportFailedPush(cursor.Value, err) next := cursor.Next() @@ -143,7 +179,7 @@ func (c *Client) runLoop() { // Connection open, listen for notifs and errors for { var err error - var n Notification + var n serializedNotif // Check for notifications or errors. There is a chance we'll send notifications // if we already have an error since `select` will "pseudorandomly" choose a @@ -169,21 +205,7 @@ func (c *Client) runLoop() { // Add to list cursor = sent.Add(n) - // Set identifier if not specified - if n.Identifier == 0 { - n.Identifier = c.id - c.id++ - } else if c.id < n.Identifier { - c.id = n.Identifier + 1 - } - - b, err := n.ToBinary() - if err != nil { - // TODO - continue - } - - _, err = c.Conn.Write(b) + _, err = c.Conn.Write(n.b) if err == io.EOF { log.Println("EOF trying to write notification") diff --git a/notification.go b/notification.go index a82557e..ad967c0 100644 --- a/notification.go +++ b/notification.go @@ -15,6 +15,10 @@ const ( PriorityPowerConserve = 5 ) +const ( + validDeviceTokenLength = 64 +) + const ( commandID = 2 @@ -133,6 +137,10 @@ func NewPayload() *Payload { func (n Notification) ToBinary() ([]byte, error) { b := []byte{} + if len(n.DeviceToken) != validDeviceTokenLength { + return b, errors.New(ErrInvalidToken) + } + binTok, err := hex.DecodeString(n.DeviceToken) if err != nil { return b, fmt.Errorf("convert token to hex error: %s", err) From 992f8f22c0083d5c71a710a7dd9dc0d927ef6990 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 10 Dec 2014 23:58:05 -0500 Subject: [PATCH 02/22] Pass the original notification through if we want it with NotificationResult --- client.go | 14 +++++--------- notification.go | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index b9a31f1..fdc62ae 100644 --- a/client.go +++ b/client.go @@ -31,6 +31,7 @@ func (b *buffer) Add(v interface{}) *list.Element { type serializedNotif struct { id uint32 b []byte + n *Notification } type Client struct { @@ -94,7 +95,7 @@ func (c *Client) Send(n Notification) error { return err } - c.notifs <- serializedNotif{b: b, id: n.Identifier} + c.notifs <- serializedNotif{b: b, id: n.Identifier, n: &n} return nil } @@ -113,14 +114,9 @@ func (c *Client) nextID() uint32 { return c.id } -func (c *Client) reportFailedPush(v interface{}, err *Error) { - failedNotif, ok := v.(Notification) - if !ok || v == nil { - return - } - +func (c *Client) reportFailedPush(s serializedNotif, err *Error) { select { - case c.FailedNotifs <- NotificationResult{Notif: failedNotif, Err: *err}: + case c.FailedNotifs <- NotificationResult{Notif: s.n, Err: *err}: default: } } @@ -144,7 +140,7 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { // If the notification, move cursor after the trouble notification if n.id == err.Identifier { - go c.reportFailedPush(cursor.Value, err) + go c.reportFailedPush(n, err) next := cursor.Next() diff --git a/notification.go b/notification.go index ad967c0..3fe046b 100644 --- a/notification.go +++ b/notification.go @@ -37,7 +37,7 @@ const ( ) type NotificationResult struct { - Notif Notification + Notif *Notification Err Error } From ff654749ca72e276cd03472b3ddd11cc5976a869 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 11 Dec 2014 19:07:49 -0500 Subject: [PATCH 03/22] Undo NotificationResult api change --- client.go | 2 +- notification.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index fdc62ae..2d27b5d 100644 --- a/client.go +++ b/client.go @@ -116,7 +116,7 @@ func (c *Client) nextID() uint32 { func (c *Client) reportFailedPush(s serializedNotif, err *Error) { select { - case c.FailedNotifs <- NotificationResult{Notif: s.n, Err: *err}: + case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: *err}: default: } } diff --git a/notification.go b/notification.go index 3fe046b..ad967c0 100644 --- a/notification.go +++ b/notification.go @@ -37,7 +37,7 @@ const ( ) type NotificationResult struct { - Notif *Notification + Notif Notification Err Error } From 5fefc9a9d53ceee17464af8ff6121ff612e6ec35 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Mon, 15 Dec 2014 21:06:02 -0500 Subject: [PATCH 04/22] WIP --- client.go | 62 ++++++++++++++++++++++++++++++++++++++----------------- error.go | 5 +++++ 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 2d27b5d..69a135f 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,6 @@ import ( "io" "log" "sync" - "time" ) type buffer struct { @@ -28,7 +27,7 @@ func (b *buffer) Add(v interface{}) *list.Element { return e } -type serializedNotif struct { +type serialized struct { id uint32 b []byte n *Notification @@ -38,23 +37,26 @@ type Client struct { Conn *Conn FailedNotifs chan NotificationResult - notifs chan serializedNotif + notifs chan serialized id uint32 idm sync.Mutex + + connected bool + connm sync.Mutex } func newClientWithConn(gw string, conn Conn) Client { c := Client{ Conn: &conn, FailedNotifs: make(chan NotificationResult), - notifs: make(chan serializedNotif), + notifs: make(chan serialized), id: 1, idm: sync.Mutex{}, + connected: false, + connm: sync.Mutex{}, } - go c.runLoop() - return c } @@ -82,7 +84,21 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return newClientWithConn(gw, conn), nil } +func (c *Client) Connect() error { + err := c.Conn.Connect() + if err != nil { + return err + } + + go c.runLoop() + return nil +} + func (c *Client) Send(n Notification) error { + if !c.connected { + return ErrDisconnected + } + // Set identifier if not specified if n.Identifier == 0 { n.Identifier = c.nextID() @@ -95,7 +111,7 @@ func (c *Client) Send(n Notification) error { return err } - c.notifs <- serializedNotif{b: b, id: n.Identifier, n: &n} + c.notifs <- serialized{b: b, id: n.Identifier, n: &n} return nil } @@ -114,7 +130,21 @@ func (c *Client) nextID() uint32 { return c.id } -func (c *Client) reportFailedPush(s serializedNotif, err *Error) { +func (c *Client) connected() { + c.connm.Lock() + defer c.connm.Unlock() + + c.connected = true +} + +func (c *Client) disconnected() { + c.connm.Lock() + defer c.connm.Unlock() + + c.connected = false +} + +func (c *Client) reportFailedPush(s serialized, err *Error) { select { case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: *err}: default: @@ -125,7 +155,7 @@ func (c *Client) requeue(cursor *list.Element) { // If `cursor` is not nil, this means there are notifications that // need to be delivered (or redelivered) for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(serializedNotif); ok { + if n, ok := cursor.Value.(serialized); ok { go func() { c.notifs <- n }() } } @@ -136,7 +166,7 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { for cursor != nil { // Get serialized notification - n, _ := cursor.Value.(serializedNotif) + n, _ := cursor.Value.(serialized) // If the notification, move cursor after the trouble notification if n.id == err.Identifier { @@ -160,13 +190,6 @@ func (c *Client) runLoop() { // APNS connection for { - err := c.Conn.Connect() - if err != nil { - // TODO Probably want to exponentially backoff... - time.Sleep(1 * time.Second) - continue - } - // Start reading errors from APNS errs := readErrs(c.Conn) @@ -175,7 +198,7 @@ func (c *Client) runLoop() { // Connection open, listen for notifs and errors for { var err error - var n serializedNotif + var n serialized // Check for notifications or errors. There is a chance we'll send notifications // if we already have an error since `select` will "pseudorandomly" choose a @@ -205,7 +228,8 @@ func (c *Client) runLoop() { if err == io.EOF { log.Println("EOF trying to write notification") - break + c.connected = false + return } if err != nil { diff --git a/error.go b/error.go index 5425868..3ff4b86 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,11 @@ package apns import ( "bytes" "encoding/binary" + "errors" +) + +const ( + ErrDisconnected = errors.New("disconnected from gateway") ) const ( From 3c1ae0e6c2d1a49994ad264a4365a9b5fa2a52d0 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 16 Dec 2014 00:37:16 -0500 Subject: [PATCH 05/22] Start simplifying the client internals --- client.go | 201 +++++++++++++++++++----------------------------------- error.go | 2 +- 2 files changed, 71 insertions(+), 132 deletions(-) diff --git a/client.go b/client.go index 69a135f..65ad2c4 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "container/list" "crypto/tls" "io" - "log" "sync" ) @@ -39,6 +38,9 @@ type Client struct { notifs chan serialized + buffer *buffer + cursor *list.Element + id uint32 idm sync.Mutex @@ -51,7 +53,9 @@ func newClientWithConn(gw string, conn Conn) Client { Conn: &conn, FailedNotifs: make(chan NotificationResult), notifs: make(chan serialized), - id: 1, + buffer: newBuffer(50), + cursor: nil, + id: 0, idm: sync.Mutex{}, connected: false, connm: sync.Mutex{}, @@ -85,12 +89,20 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err } func (c *Client) Connect() error { - err := c.Conn.Connect() - if err != nil { + if err := c.Conn.Connect(); err != nil { + return err + } + + // On connect, requeue any notifications that were + // sent after the error & disconnect. + // http://redth.codes/the-problem-with-apples-push-notification-ser/ + if err := c.requeue(); err != nil { return err } - go c.runLoop() + // Kick off asynchronous error reading + go c.readErrors() + return nil } @@ -100,162 +112,89 @@ func (c *Client) Send(n Notification) error { } // Set identifier if not specified - if n.Identifier == 0 { - n.Identifier = c.nextID() - } else if c.id < n.Identifier { - c.setID(n.Identifier) - } + n.Identifier = c.determineIdentifier(n.Identifier) b, err := n.ToBinary() if err != nil { return err } - c.notifs <- serialized{b: b, id: n.Identifier, n: &n} - return nil -} + // Add to list + c.cursor = c.buffer.Add(n) -func (c *Client) setID(n uint32) { - c.idm.Lock() - defer c.idm.Unlock() + _, err = c.Conn.Write(b) + if err == io.EOF { + c.connected = false + return err + } + + if err != nil { + return err + } - c.id = n + c.cursor = c.cursor.Next() + return nil } -func (c *Client) nextID() uint32 { +func (c *Client) determineIdentifier(n uint32) uint32 { c.idm.Lock() defer c.idm.Unlock() - c.id++ - return c.id -} - -func (c *Client) connected() { - c.connm.Lock() - defer c.connm.Unlock() - - c.connected = true -} - -func (c *Client) disconnected() { - c.connm.Lock() - defer c.connm.Unlock() - - c.connected = false -} - -func (c *Client) reportFailedPush(s serialized, err *Error) { - select { - case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: *err}: - default: + // If the id passed in is 0, that means it wasn't + // set so get the next ID. Otherwise, set it to that + // identifier. + if n == 0 { + c.id++ + } else { + c.id = n } + + return c.id } -func (c *Client) requeue(cursor *list.Element) { +func (c *Client) requeue() error { // If `cursor` is not nil, this means there are notifications that // need to be delivered (or redelivered) - for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(serialized); ok { - go func() { c.notifs <- n }() - } - } -} - -func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { - cursor := buffer.Back() - - for cursor != nil { - // Get serialized notification - n, _ := cursor.Value.(serialized) - - // If the notification, move cursor after the trouble notification - if n.id == err.Identifier { - go c.reportFailedPush(n, err) - - next := cursor.Next() - - buffer.Remove(cursor) - return next + for ; c.cursor != nil; c.cursor = c.cursor.Next() { + if s, ok := c.cursor.Value.(serialized); ok { + if err := c.Send(*s.n); err != nil { + return err + } } - - cursor = cursor.Prev() } - return cursor + return nil } -func (c *Client) runLoop() { - sent := newBuffer(50) - cursor := sent.Front() +func (c *Client) readErrors() { + p := make([]byte, 6, 6) - // APNS connection - for { - // Start reading errors from APNS - errs := readErrs(c.Conn) + _, err := c.Conn.Read(p) + // TODO(bw) not sure what to do here. It's unclear what errors + // come out of this and how we handle it. + if err != nil { + return + } - c.requeue(cursor) + e := NewError(p) + cursor := c.buffer.Back() - // Connection open, listen for notifs and errors - for { - var err error - var n serialized + for cursor != nil { + // Get serialized notification + s, _ := cursor.Value.(serialized) - // Check for notifications or errors. There is a chance we'll send notifications - // if we already have an error since `select` will "pseudorandomly" choose a - // ready channels. It turns out to be fine because the connection will already - // be closed and it'll requeue. We could check before we get to this select - // block, but it doesn't seem worth the extra code and complexity. + // If the notification, move cursor after the trouble notification + if s.id == e.Identifier { + // Try to write - skip if no one is reading on the other side select { - case err = <-errs: - case n = <-c.notifs: - } - - // If there is an error we understand, find the notification that failed, - // move the cursor right after it. - if nErr, ok := err.(*Error); ok { - cursor = c.handleError(nErr, sent) - break + case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: e}: + default: } - if err != nil { - break - } - - // Add to list - cursor = sent.Add(n) - - _, err = c.Conn.Write(n.b) - - if err == io.EOF { - log.Println("EOF trying to write notification") - c.connected = false - return - } - - if err != nil { - log.Println("err writing to apns", err.Error()) - break - } - - cursor = cursor.Next() + c.cursor = cursor.Next() + c.buffer.Remove(cursor) } - } -} - -func readErrs(c *Conn) chan error { - errs := make(chan error) - - go func() { - p := make([]byte, 6, 6) - _, err := c.Read(p) - if err != nil { - errs <- err - return - } - - e := NewError(p) - errs <- &e - }() - return errs + cursor = cursor.Prev() + } } diff --git a/error.go b/error.go index 3ff4b86..3371bea 100644 --- a/error.go +++ b/error.go @@ -6,7 +6,7 @@ import ( "errors" ) -const ( +var ( ErrDisconnected = errors.New("disconnected from gateway") ) From bbb3e83990b1849aacc4d107301187f3e0300d53 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 16 Dec 2014 00:38:05 -0500 Subject: [PATCH 06/22] Update example --- example/example.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/example/example.go b/example/example.go index 1b670ac..2d9e5c0 100644 --- a/example/example.go +++ b/example/example.go @@ -10,10 +10,14 @@ import ( func main() { c, err := apns.NewClientWithFiles(apns.ProductionGateway, "apns.crt", "apns.key") if err != nil { - log.Fatal("Could not create client", err.Error()) + log.Fatal("Could not create client: ", err.Error()) } - i := 0 + if err := c.Connect(); err != nil { + log.Fatal("Could not create connect: ", err.Error()) + } + + i := 1 for { fmt.Print("Enter ' ': ") From a9286ca26762177b8d9b7ba688c0631bbf91959c Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 23 Dec 2014 17:46:37 -0500 Subject: [PATCH 07/22] Add connection resource locking --- client.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 65ad2c4..d8c2441 100644 --- a/client.go +++ b/client.go @@ -106,6 +106,17 @@ func (c *Client) Connect() error { return nil } +func (c *Client) disconnect() error { + c.connm.Lock() + defer c.connm.Unlock() + + if c.Conn == nil { + return nil + } + + return c.Conn.Close() +} + func (c *Client) Send(n Notification) error { if !c.connected { return ErrDisconnected @@ -122,7 +133,14 @@ func (c *Client) Send(n Notification) error { // Add to list c.cursor = c.buffer.Add(n) - _, err = c.Conn.Write(b) + return c.send(b) +} + +func (c *Client) send(b []byte) error { + c.connm.Lock() + defer c.connm.Unlock() + + _, err := c.Conn.Write(b) if err == io.EOF { c.connected = false return err @@ -179,6 +197,8 @@ func (c *Client) readErrors() { e := NewError(p) cursor := c.buffer.Back() + c.disconnect() + for cursor != nil { // Get serialized notification s, _ := cursor.Value.(serialized) From 92b160330747ae2e93ef1c8814c4fb79b31885d0 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 23 Dec 2014 17:55:52 -0500 Subject: [PATCH 08/22] Update notification test and remove old client test file --- client_test.go | 384 ------------------------------------------- notification_test.go | 15 +- 2 files changed, 14 insertions(+), 385 deletions(-) delete mode 100644 client_test.go diff --git a/client_test.go b/client_test.go deleted file mode 100644 index c9dfd47..0000000 --- a/client_test.go +++ /dev/null @@ -1,384 +0,0 @@ -package apns_test - -import ( - "bytes" - "encoding/binary" - "io/ioutil" - "os" - "time" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/timehop/apns" -) - -var _ = Describe("Client", func() { - Describe(".NewConn", func() { - Context("bad cert/key pair", func() { - It("should error out", func() { - _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") - Expect(err).NotTo(BeNil()) - }) - }) - - Context("valid cert/key pair", func() { - It("should create a valid client", func() { - c, err := apns.NewClient(apns.ProductionGateway, DummyCert, DummyKey) - Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) - }) - }) - }) - - Describe(".NewConnWithFiles", func() { - Context("missing cert/key pair", func() { - It("should error out", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") - Expect(err).NotTo(BeNil()) - }) - }) - - Context("valid cert/key pair", func() { - var certFile, keyFile *os.File - - BeforeEach(func() { - certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) - certFile.Close() - - keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) - keyFile.Close() - }) - - AfterEach(func() { - if certFile != nil { - os.Remove(certFile.Name()) - } - - if keyFile != nil { - os.Remove(keyFile.Name()) - } - }) - - It("should create a valid client", func() { - c, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) - Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) - }) - }) - }) - - Describe("#Send", func() { - Context("simple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) - }) - }) - }) - - Context("simple write with buffer", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - for i := 0; i < 54; i++ { - Expect(c.Send(apns.Notification{})).To(BeNil()) - } - - close(mockDone) - close(d) - }) - }) - }) - - Context("multiple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - Expect(c.Send(apns.Notification{})).To(BeNil()) - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) - }) - }) - }) - - Context("bad push", func() { - n := apns.Notification{Identifier: 9, ID: "some_rando"} - nb, _ := n.ToBinary() - nbcb := make([]byte, len(nb)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(9)) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: nbcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(nb)) - }}, - - // Bad push results in a close - serverAction{action: writeAction, data: errPayload.Bytes()}, - serverAction{action: closeAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - go func() { - n := <-c.FailedNotifs - - Expect(n.Notif.Identifier).To(Equal(uint32(9))) - Expect(n.Notif.ID).To(Equal("some_rando")) - - close(mockDone) - close(d) - }() - - Expect(c.Send(n)).To(BeNil()) - }) - }) - }) - - Context("closed, reconnect", func() { - done := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n1b, _ := n1.ToBinary() - n1bcb := make([]byte, len(n1b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - as := [][]serverAction{ - []serverAction{ - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - done <- true - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - <-done - time.Sleep(5 * time.Millisecond) - - // Good - Expect(c.Send(n1)).To(BeNil()) - }) - }) - }) - - Context("good, close, good, requeue of last good", func() { - closed := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - closed <- true - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - // Good - Expect(c.Send(n1)).To(BeNil()) - - <-closed - time.Sleep(5 * time.Millisecond) - - // Good - Expect(c.Send(n2)).To(BeNil()) - }) - }) - }) - - Context("good, bad, good, requeue of last good", func() { - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - n3 := apns.Notification{Identifier: 3} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - n3b, _ := n3.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - n3bcb := make([]byte, len(n3b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Read bad notification - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - }}, - - // Read second good notification - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - }}, - - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - // Good - Expect(c.Send(n1)).To(BeNil()) - - // Bad - Expect(c.Send(n2)).To(BeNil()) - - // Good - Expect(c.Send(n3)).To(BeNil()) - }) - }) - }) - }) -}) diff --git a/notification_test.go b/notification_test.go index cea990a..e76e253 100644 --- a/notification_test.go +++ b/notification_test.go @@ -193,7 +193,20 @@ var _ = Describe("Notifications", func() { Describe("#ToBinary", func() { Context("invalid token format", func() { n := apns.NewNotification() - n.DeviceToken = "totally not a valid token" + n.DeviceToken = "totally not a valid token length" + + It("should return an error", func() { + _, err := n.ToBinary() + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(Equal(apns.ErrInvalidToken)) + }) + + // Expect(err.Error()).To(ContainSubstring("convert token to hex error")) + }) + + Context("non-convertable token", func() { + n := apns.NewNotification() + n.DeviceToken = "123456789012345678901234567890zz123456789012345678901234567890zz" It("should return an error", func() { _, err := n.ToBinary() From 4b752d30335ebd3e93068d1f0aca6fec5d3ff09a Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 27 Jan 2015 00:32:45 -0500 Subject: [PATCH 09/22] WIP Convert Conn into an interface and make feedback test --- apns_suite_test.go | 38 +++++ client.go | 4 +- conn.go | 66 +++++---- conn_test.go | 347 +++++++++------------------------------------ feedback.go | 12 +- feedback_test.go | 177 ++++++++++++----------- 6 files changed, 243 insertions(+), 401 deletions(-) diff --git a/apns_suite_test.go b/apns_suite_test.go index b0bcca4..858a5b9 100644 --- a/apns_suite_test.go +++ b/apns_suite_test.go @@ -5,8 +5,46 @@ import ( . "github.com/onsi/gomega" "testing" + "time" ) +type mockConn struct { + connect func() error + read func([]byte) (int, error) + readWithTimeout func([]byte, time.Time) (int, error) +} + +func (m *mockConn) Connect() error { + if m.connect != nil { + return m.connect() + } + + return nil +} + +func (m *mockConn) Read(b []byte) (int, error) { + if m.read != nil { + return m.read(b) + } + return 0, nil +} + +func (m *mockConn) Write([]byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Close() error { + return nil +} + +func (m *mockConn) ReadWithTimeout(b []byte, t time.Time) (int, error) { + if m.readWithTimeout != nil { + return m.readWithTimeout(b, t) + } + + return 0, nil +} + func TestApns(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Apns Suite") diff --git a/client.go b/client.go index d8c2441..e4fac1d 100644 --- a/client.go +++ b/client.go @@ -33,7 +33,7 @@ type serialized struct { } type Client struct { - Conn *Conn + Conn Conn FailedNotifs chan NotificationResult notifs chan serialized @@ -50,7 +50,7 @@ type Client struct { func newClientWithConn(gw string, conn Conn) Client { c := Client{ - Conn: &conn, + Conn: conn, FailedNotifs: make(chan NotificationResult), notifs: make(chan serialized), buffer: newBuffer(50), diff --git a/conn.go b/conn.go index d3aa712..1bb8bab 100644 --- a/conn.go +++ b/conn.go @@ -2,8 +2,10 @@ package apns import ( "crypto/tls" + "io" "net" "strings" + "time" ) const ( @@ -15,9 +17,16 @@ const ( ) // Conn is a wrapper for the actual TLS connections made to Apple -type Conn struct { - NetConn net.Conn - Conf *tls.Config +type Conn interface { + io.ReadWriteCloser + + Connect() error + ReadWithTimeout(p []byte, deadline time.Time) (int, error) +} + +type conn struct { + netConn net.Conn + tls *tls.Config gateway string connected bool @@ -25,19 +34,20 @@ type Conn struct { func NewConnWithCert(gw string, cert tls.Certificate) Conn { gatewayParts := strings.Split(gw, ":") - conf := tls.Config{ - Certificates: []tls.Certificate{cert}, - ServerName: gatewayParts[0], + tls := tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: gatewayParts[0], + InsecureSkipVerify: true, } - return Conn{gateway: gw, Conf: &conf} + return &conn{gateway: gw, tls: &tls} } // NewConnWithFiles creates a new Conn from certificate and key in the specified files func NewConn(gw string, crt string, key string) (Conn, error) { cert, err := tls.X509KeyPair([]byte(crt), []byte(key)) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil @@ -47,49 +57,51 @@ func NewConn(gw string, crt string, key string) (Conn, error) { func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil } // Connect actually creates the TLS connection -func (c *Conn) Connect() error { +func (c *conn) Connect() error { // Make sure the existing connection is closed - if c.NetConn != nil { - c.NetConn.Close() - } - - conn, err := net.Dial("tcp", c.gateway) - if err != nil { - return err + if c.netConn != nil { + c.netConn.Close() } - tlsConn := tls.Client(conn, c.Conf) - err = tlsConn.Handshake() + tlsConn, err := tls.Dial("tcp", c.gateway, c.tls) if err != nil { return err } - c.NetConn = tlsConn + c.netConn = tlsConn return nil } -func (c *Conn) Close() error { - if c.NetConn != nil { - return c.NetConn.Close() +func (c *conn) Close() error { + if c.netConn != nil { + return c.netConn.Close() } return nil } // Read reads data from the connection -func (c *Conn) Read(p []byte) (int, error) { - i, err := c.NetConn.Read(p) +func (c *conn) Read(p []byte) (int, error) { + i, err := c.netConn.Read(p) + return i, err +} + +// ReadWithTimeout reads data from the connection and returns an error +// after duration +func (c *conn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { + c.netConn.SetReadDeadline(deadline) + i, err := c.netConn.Read(p) return i, err } // Write writes data from the connection -func (c *Conn) Write(p []byte) (int, error) { - return c.NetConn.Write(p) +func (c *conn) Write(p []byte) (int, error) { + return c.netConn.Write(p) } diff --git a/conn_test.go b/conn_test.go index e910e6c..4bee29c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,232 +1,15 @@ package apns_test import ( - "bytes" - "crypto/tls" - "fmt" - "io" "io/ioutil" - "log" "net" "os" - "strings" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) -var DummyCert = `-----BEGIN CERTIFICATE----- -MIIC9TCCAd+gAwIBAgIQf3bEgFWUb+q6eK5ySkV/gjALBgkqhkiG9w0BAQUwEjEQ -MA4GA1UEChMHQWNtZSBDbzAeFw0xNDA2MzAwNDI5MDhaFw0xNTA2MzAwNDI5MDha -MBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK -AoIBAQDhAgWrrFZBtCfVEPg1tSIr9fuSUoeundb556IUr9uOmOHaYK7r3/I43acw -bVIfaenFxwUUf8YakQzTjOa5qSfK/Eylyw2ezBJtNUEqcHw0f+y66+jJbZa4clPa -tL6ezaMS/syXPpvNU8+16jdVdTJzqdBdSGAZMOCeumUWDNdlfBmHPVq1JMy0uGmO -XDoZK2Ir0/3LUfjk9R2wdm1VLrJAml7F0L0FhBHHXgHOSFM2ixjGflffaiuTCxhW -1z1NTo9XjWUQh2iM9Udf+xVnJLGLZ0EMFr2qihuK604Fp4SlNHEF+UWUn+j0PYo+ -LbzM9oKJcdVD0XI36vrn3rGPHO9vAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIAoDAT -BgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxv -Y2FsaG9zdDALBgkqhkiG9w0BAQUDggEBAGJ/3I4KKlbEwLAC5ut4ZZ9V8WF4sHkI -Lj7e4vx2pPi6hf9miV1ff01NrpfUna7flwL9yD7Ybl7jRRIB4rIcKk+U5djGsT3H -ScGkbIMKrr08drWw1g4JU6PBH7xTfzGxNRERrnmrbJV0jCo9Tt8i53IpPtp6Z2Q1 -8ydtPhU+Bpe2YoNr1w1fSV1JHXqjKV8RlGkCNSi4ozPOO8RbAYnBT3d9XSGoX//q -RGJUf3wC/rCxJkN63Moxuy3vxV2TmiqccHOrXJSJ8P/4PpPV/xuBk5k4HS1Nfmew -d9WHHn6bMJE9arVvWAiu9teCadVffuS2cl2cicN4XB6Ui0aDqhG2Exw= ------END CERTIFICATE-----` - -var DummyKey = `-----BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEA4QIFq6xWQbQn1RD4NbUiK/X7klKHrp3W+eeiFK/bjpjh2mCu -69/yON2nMG1SH2npxccFFH/GGpEM04zmuaknyvxMpcsNnswSbTVBKnB8NH/suuvo -yW2WuHJT2rS+ns2jEv7Mlz6bzVPPteo3VXUyc6nQXUhgGTDgnrplFgzXZXwZhz1a -tSTMtLhpjlw6GStiK9P9y1H45PUdsHZtVS6yQJpexdC9BYQRx14BzkhTNosYxn5X -32orkwsYVtc9TU6PV41lEIdojPVHX/sVZySxi2dBDBa9qoobiutOBaeEpTRxBflF -lJ/o9D2KPi28zPaCiXHVQ9FyN+r6596xjxzvbwIDAQABAoIBAFzW+cIA5MJNdFX8 -n32BlGzxHPEd7nAFHmuUwJKqkPwAZsg1NleK2qXOByr7IHRnvhZl7Nmtcu8JRHKR -Y63ddtbRTUrnQmJwL3YyEAZTzVvYILRrnGxoNFU8jw7hnvllPdEbow0QvzZ0S3Lz -BgvTxJJm0dt7fnNGcJftrsHvYHy1dptaR4hPv0xV5G7RPrbTl94llKfi745tp5Wd -xGpnjcBXoAnzCVRij1tHfSYubRJ2MJV0kzG3oVdRV2P/zWaout8BlhLCURv4sRUX -7FfCNa/z+G6AlROjCKJUP9YIUbxBEa/aP8YlSiyLRi1jFbMWcnKWQUdqS19m73Ap -a1LJFPECgYEA+Ve5DegcrWnUb2HsHD38HlmEg6S+/jg2P4TsuLZBtvO4/vzRx/qq -pwuuMm2CsvXr4nVmMEsMlSzYdsnaXIlWqyVDCOwIWR5VYT2GDWqQLaIXPlFaISzN -27tHd64KUtR1fMJUwQVK/MUORUbpYoAnSIil2SlYkWUhF024fNP8CxcCgYEA5wP4 -HLiqU2rqe7vSAF/8fHwPleTzuCfMCVZm0aegUzQQQtklZoVE/BBwEGHdXflq1veq -pHeC8bNR4BF6ZgeSWgbLVF3msquy47QeNElHA2muJd3qmNWz4LXo1Pxb8KXcnXri -QZ+r3Y8obWTFQYq7gGQGPLXGTV3bhLGIyrT4lWkCgYAgZ2MYSJL5gmhmNT6fCPsr -4oxTI2Ti2uFJ7fdppd3ybcgb8zU8HPpyjRUNXqf+o/EM1B78pbQz6skS3vau0fZe -dZA5p5sKIeQMqBc0xSWJmKgWpDHnX9A8/yCxj/+tdgjytrqW/x4YrW9GV4nbEDaK -uZ98EmB9PLxJMAOKzW3S7wKBgQDD4PCy4b3CR2iVC9dva/P5VXQdo+knX884p6M8 -58YgZofXNqnouN2aYRG0QlbiBMcbiRqOo6tK58JnnEpNUuQ8I4Cqg4hGPSHMwv/N -U8i70xLPltABUUpZIcVPOr92WBytBvHrtMiUb3tW7lf3T/vWTHmhZnvDQ+8LH0Ge -pz4T6QKBgQCoBJKOd781IQmT6i5hHSYJlsP6ymaaaQniJPVpnci/jf8+2QtponQY -scgnaBLBasLQ6GfKSRtcyidEi9wwxpVj0tw2p567jeNcIveD0TOYFf0RHEfrs+D4 -VdRgai/v2NbFZLDnzeGVuYypXu6R78isJfHtz/a0aEave8yB3CRiDw== ------END RSA PRIVATE KEY-----` - -// To be able to run in parallel -var mockPort = 50000 - -// Mock Addr -type mockAddr struct { -} - -func (m mockAddr) Network() string { - return "localhost:56789" -} - -func (m mockAddr) String() string { - return "localhost:56789" -} - -// Mock TLS connection -type mockTLSNetConn struct { - bb *bytes.Buffer - err error -} - -func (t mockTLSNetConn) Read(p []byte) (int, error) { - r := bytes.NewReader(t.bb.Bytes()) - return r.Read(p) -} - -func (t mockTLSNetConn) Write(p []byte) (int, error) { - return t.bb.Write(p) -} - -func (t mockTLSNetConn) Close() error { - return t.err -} - -func (m mockTLSNetConn) LocalAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) RemoteAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) SetDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type serverAction struct { - action string - data []byte - cb func(s serverAction) -} - -const ( - readAction = "read" - writeAction = "write" - closeAction = "close" -) - -type mockTLSServer struct { - Port int - Server net.Listener - ConnectionActionGroups [][]serverAction -} - -func (m *mockTLSServer) portStr() string { - if m.Port == 0 { - mockPort = mockPort + 1 - m.Port = mockPort - } - - return fmt.Sprint(m.Port) -} - -func (m *mockTLSServer) Address() string { - return "localhost:" + m.portStr() -} - -func (m *mockTLSServer) start() { - cert, err := tls.X509KeyPair([]byte(DummyCert), []byte(DummyKey)) - if err != nil { - log.Panic(err) - } - - config := tls.Config{Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAnyClientCert} - - m.Server, err = tls.Listen("tcp", "localhost:"+m.portStr(), &config) - go func() { - for i := 0; i < len(m.ConnectionActionGroups); i++ { - g := m.ConnectionActionGroups[i] - - // Wait for a connection. - conn, err := m.Server.Accept() - if err != nil { - if strings.Contains(err.Error(), "use of closed network connection") { - return - } else { - log.Fatal(err) - } - } - // Handle the connection in a new goroutine. - // The loop then returns to accepting, so that - // multiple connections may be served concurrently. - go func(c net.Conn) { - for j := 0; j < len(g); j++ { - a := g[j] - switch a.action { - case readAction: - c.Read(a.data) - case writeAction: - c.Write(a.data) - case closeAction: - c.Close() - - if a.cb != nil { - a.cb(a) - } - return - } - - if a.cb != nil { - a.cb(a) - } - } - }(conn) - } - - // No more connection action groups - }() -} - -func (m *mockTLSServer) stop() { - if m.Server != nil { - m.Server.Close() - } -} - -var withMockServer = func(as [][]serverAction, cb func(s *mockTLSServer)) { - d := make(chan interface{}) - withMockServerAsync(as, d, func(s *mockTLSServer) { - cb(s) - close(d) - }) -} - -var withMockServerAsync = func(as [][]serverAction, d chan interface{}, cb func(s *mockTLSServer)) { - s := &mockTLSServer{} - s.ConnectionActionGroups = as - - s.start() - - cb(s) - - <-d - s.stop() -} - // Tests var _ = Describe("Conn", func() { Describe(".NewConn", func() { @@ -239,7 +22,7 @@ var _ = Describe("Conn", func() { Context("valid key/cert pair", func() { It("should not return an error", func() { - _, err := apns.NewConn(apns.SandboxGateway, DummyCert, DummyKey) + _, err := apns.NewConn(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -259,11 +42,11 @@ var _ = Describe("Conn", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -295,65 +78,71 @@ var _ = Describe("Conn", func() { }) Context("server up", func() { - as := [][]serverAction{[]serverAction{serverAction{action: readAction, data: []byte{}}}} - Context("with untrusted certs", func() { It("should return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - err := conn.Connect() - Expect(err).NotTo(BeNil()) + s := tcptest.NewTLSServer(func(c net.Conn) {}) + defer s.Close() - close(d) - }) + conn, err := apns.NewConn(s.Addr, "not trusted", "not even a little") + Expect(err).NotTo(BeNil()) + + err = conn.Connect() + Expect(err).NotTo(BeNil()) + + close(d) }) }) Context("trusting the certs", func() { It("should not return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + s := tcptest.NewUnstartedServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) - err := conn.Connect() - Expect(err).To(BeNil()) + s.StartTLS() + defer s.Close() - close(d) - }) + conn, err := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = conn.Connect() + Expect(err).To(BeNil()) + + close(d) }) }) Context("with existing connection", func() { It("should not return an error", func(d Done) { - as = [][]serverAction{ - []serverAction{serverAction{action: readAction, data: []byte{}}}, - []serverAction{serverAction{action: readAction, data: []byte{}}}, - } + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) + defer s.Close() - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + conn.Connect() - err := conn.Connect() - Expect(err).To(BeNil()) + err := conn.Connect() + Expect(err).To(BeNil()) - close(d) - }) + close(d) }) }) }) }) Describe("#Read", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte("hello!"))} - - pp := make([]byte, 6) - bytes.NewReader(rwc.bb.Bytes()).Read(pp) + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte("hello!")) + }) + defer s.Close() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() It("should read out 'hello!'", func() { p := make([]byte, 6) @@ -364,47 +153,47 @@ var _ = Describe("Conn", func() { }) Describe("#Write", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} + It("should read out 'hello!'", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + + b := make([]byte, 6) + c.Read(b) + + Expect(string(b)).To(Equal("hello!")) + close(d) + }) + defer s.Close() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - It("should write out 'world!'", func() { - conn.Write([]byte("world!")) - Expect(rwc.bb.String()).To(Equal("world!")) + conn.Write([]byte("hello!")) }) }) Describe("#Close", func() { Context("with connection", func() { Context("no error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} - - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc - It("should return no error", func() { - Expect(rwc.Close()).To(BeNil()) - }) - }) - - Context("with error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} - - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) + defer s.Close() - rwc.err = io.EOF - It("should return that error", func() { - Expect(rwc.Close()).To(Equal(io.EOF)) + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() + Expect(conn.Close()).To(BeNil()) }) }) }) Context("without connection", func() { - c, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) It("should not return an error", func() { - Expect(c.Close()).To(BeNil()) + conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(conn.Close()).To(BeNil()) }) }) }) diff --git a/feedback.go b/feedback.go index 488bf1b..a0093b1 100644 --- a/feedback.go +++ b/feedback.go @@ -9,7 +9,7 @@ import ( ) type Feedback struct { - Conn *Conn + Conn Conn } type FeedbackTuple struct { @@ -40,7 +40,7 @@ func feedbackTupleFromBytes(b []byte) FeedbackTuple { func NewFeedbackWithCert(gw string, cert tls.Certificate) Feedback { conn := NewConnWithCert(gw, cert) - return Feedback{Conn: &conn} + return Feedback{Conn: conn} } func NewFeedback(gw string, cert string, key string) (Feedback, error) { @@ -49,7 +49,7 @@ func NewFeedback(gw string, cert string, key string) (Feedback, error) { return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, error) { @@ -58,7 +58,7 @@ func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } // Receive returns a read only channel for APNs feedback. The returned channel @@ -80,9 +80,7 @@ func (f Feedback) receive(fc chan FeedbackTuple) { for { b := make([]byte, 38) - f.Conn.NetConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - - _, err := f.Conn.Read(b) + _, err := f.Conn.ReadWithTimeout(b, time.Now().Add(100*time.Millisecond)) if err != nil { close(fc) return diff --git a/feedback_test.go b/feedback_test.go index 29978b4..8cd909b 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,12 +4,16 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "fmt" + "io" "io/ioutil" + "net" "os" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) var _ = Describe("Feedback", func() { @@ -23,7 +27,7 @@ var _ = Describe("Feedback", func() { Context("valid cert/key pair", func() { It("should create a valid client", func() { - _, err := apns.NewFeedback(apns.ProductionGateway, DummyCert, DummyKey) + _, err := apns.NewFeedback(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -42,11 +46,11 @@ var _ = Describe("Feedback", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -70,11 +74,13 @@ var _ = Describe("Feedback", func() { Describe("#Receive", func() { Context("could not connect", func() { It("should not receive anything", func() { - s := &mockTLSServer{} - - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true + m := mockConn{ + connect: func() error { + return io.EOF + }, + } + f := apns.Feedback{Conn: &m} c := f.Receive() r := 0 @@ -87,89 +93,88 @@ var _ = Describe("Feedback", func() { }) Context("times out", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - It("should not receive anything", func() { - c := f.Receive() - - r := 0 - for _ = range c { - r += 1 - } - - Expect(r).To(Equal(0)) - }) + It("should not receive anything", func() { + m := mockConn{ + readWithTimeout: func(b []byte, t time.Time) (int, error) { + return 0, net.UnknownNetworkError("") + }, + } + + f := apns.Feedback{Conn: &m} + c := f.Receive() + + r := 0 + for _ = range c { + r += 1 + } + + Expect(r).To(Equal(0)) }) }) + }) - Context("with feedback", func() { - f1 := bytes.NewBuffer([]byte{}) - f2 := bytes.NewBuffer([]byte{}) - f3 := bytes.NewBuffer([]byte{}) - - // The final token strings - t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" - t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" - t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" - - bt1, _ := hex.DecodeString(t1) - bt2, _ := hex.DecodeString(t2) - bt3, _ := hex.DecodeString(t3) - - binary.Write(f1, binary.BigEndian, uint32(1404358249)) - binary.Write(f1, binary.BigEndian, uint16(len(bt1))) - binary.Write(f1, binary.BigEndian, bt1) - - binary.Write(f2, binary.BigEndian, uint32(1404352249)) - binary.Write(f2, binary.BigEndian, uint16(len(bt2))) - binary.Write(f2, binary.BigEndian, bt2) - - binary.Write(f3, binary.BigEndian, uint32(1394352249)) - binary.Write(f3, binary.BigEndian, uint16(len(bt3))) - binary.Write(f3, binary.BigEndian, bt3) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: writeAction, data: f1.Bytes()}, - serverAction{action: writeAction, data: f2.Bytes()}, - serverAction{action: writeAction, data: f3.Bytes()}, - }, - } - - It("should receive feedback", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - c := f.Receive() - - r1 := <-c - Expect(r1.Timestamp).To(Equal(time.Unix(1404358249, 0))) - Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) - Expect(r1.DeviceToken).To(Equal(t1)) - - r2 := <-c - Expect(r2.Timestamp).To(Equal(time.Unix(1404352249, 0))) - Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) - Expect(r2.DeviceToken).To(Equal(t2)) - - r3 := <-c - Expect(r3.Timestamp).To(Equal(time.Unix(1394352249, 0))) - Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) - Expect(r3.DeviceToken).To(Equal(t3)) - - <-c - close(d) - }) + Context("with feedback", func() { + f1 := bytes.NewBuffer([]byte{}) + f2 := bytes.NewBuffer([]byte{}) + f3 := bytes.NewBuffer([]byte{}) + + // The final token strings + t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" + t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" + t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" + + bt1, _ := hex.DecodeString(t1) + bt2, _ := hex.DecodeString(t2) + bt3, _ := hex.DecodeString(t3) + + binary.Write(f1, binary.BigEndian, uint32(1404358249)) + binary.Write(f1, binary.BigEndian, uint16(len(bt1))) + binary.Write(f1, binary.BigEndian, bt1) + + binary.Write(f2, binary.BigEndian, uint32(1404352249)) + binary.Write(f2, binary.BigEndian, uint16(len(bt2))) + binary.Write(f2, binary.BigEndian, bt2) + + binary.Write(f3, binary.BigEndian, uint32(1394352249)) + fmt.Println("f3 bytes", f3) + + binary.Write(f3, binary.BigEndian, uint16(len(bt3))) + binary.Write(f3, binary.BigEndian, bt3) + + It("should receive feedback", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write(f1.Bytes()) + c.Write(f2.Bytes()) + c.Write(f3.Bytes()) + + // TODO(bw) figure out why we need this + c.Write([]byte{0}) + c.Close() }) + defer s.Close() + + f, err := apns.NewFeedback(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c := f.Receive() + + r1 := <-c + Expect(r1.Timestamp.Unix()).To(Equal(int64(1404358249))) + Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) + Expect(r1.DeviceToken).To(Equal(t1)) + + r2 := <-c + Expect(r2.Timestamp.Unix()).To(Equal(int64(1404352249))) + Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) + Expect(r2.DeviceToken).To(Equal(t2)) + + r3 := <-c + Expect(r3.Timestamp.Unix()).To(Equal(int64(1394352249))) + Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) + Expect(r3.DeviceToken).To(Equal(t3)) + + <-c + close(d) }) }) }) From e4d4232254b9cda874a894e4ec1d31fd7dcc8ae5 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 28 Jan 2015 19:24:44 -0500 Subject: [PATCH 10/22] Lengthen timeouts for real tcp conn negotiation stuff --- conn.go | 6 ++-- conn_test.go | 85 ++++++++++++++++++++++++------------------------ feedback_test.go | 8 ++--- 3 files changed, 47 insertions(+), 52 deletions(-) diff --git a/conn.go b/conn.go index 1bb8bab..b14f8b8 100644 --- a/conn.go +++ b/conn.go @@ -89,16 +89,14 @@ func (c *conn) Close() error { // Read reads data from the connection func (c *conn) Read(p []byte) (int, error) { - i, err := c.netConn.Read(p) - return i, err + return c.netConn.Read(p) } // ReadWithTimeout reads data from the connection and returns an error // after duration func (c *conn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { c.netConn.SetReadDeadline(deadline) - i, err := c.netConn.Read(p) - return i, err + return c.netConn.Read(p) } // Write writes data from the connection diff --git a/conn_test.go b/conn_test.go index 4bee29c..44ebf23 100644 --- a/conn_test.go +++ b/conn_test.go @@ -129,72 +129,71 @@ var _ = Describe("Conn", func() { Expect(err).To(BeNil()) close(d) - }) + }, 10) }) }) }) Describe("#Read", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - defer c.Close() - c.Write([]byte("hello!")) - }) - defer s.Close() + It("should read out 'hello!'", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte("hello!")) + }) + defer s.Close() - conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - It("should read out 'hello!'", func() { p := make([]byte, 6) conn.Read(p) Expect(p).To(Equal([]byte("hello!"))) }) }) +}) - Describe("#Write", func() { - It("should read out 'hello!'", func(d Done) { - s := tcptest.NewTLSServer(func(c net.Conn) { - defer c.Close() - c.Write([]byte{}) // Connect - - b := make([]byte, 6) - c.Read(b) - - Expect(string(b)).To(Equal("hello!")) - close(d) - }) - defer s.Close() +var _ = Describe("#Write", func() { + It("should read out 'hello!'", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect - conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + b := make([]byte, 6) + c.Read(b) - conn.Write([]byte("hello!")) + Expect(string(b)).To(Equal("hello!")) + close(d) }) - }) - Describe("#Close", func() { - Context("with connection", func() { - Context("no error", func() { - It("should return no error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - defer c.Close() - c.Write([]byte{}) // Connect - }) - defer s.Close() + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() - Expect(conn.Close()).To(BeNil()) + conn.Write([]byte("hello!")) + }, 10) +}) + +var _ = Describe("#Close", func() { + Context("with connection", func() { + Context("no error", func() { + It("should return no error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect }) - }) - }) + defer s.Close() - Context("without connection", func() { - It("should not return an error", func() { - conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() Expect(conn.Close()).To(BeNil()) }) }) }) + + Context("without connection", func() { + It("should not return an error", func() { + conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(conn.Close()).To(BeNil()) + }) + }) }) diff --git a/feedback_test.go b/feedback_test.go index 8cd909b..96dadb4 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "fmt" "io" "io/ioutil" "net" @@ -136,8 +135,6 @@ var _ = Describe("Feedback", func() { binary.Write(f2, binary.BigEndian, bt2) binary.Write(f3, binary.BigEndian, uint32(1394352249)) - fmt.Println("f3 bytes", f3) - binary.Write(f3, binary.BigEndian, uint16(len(bt3))) binary.Write(f3, binary.BigEndian, bt3) @@ -147,8 +144,9 @@ var _ = Describe("Feedback", func() { c.Write(f2.Bytes()) c.Write(f3.Bytes()) - // TODO(bw) figure out why we need this + // TODO(bw) this doesn't seem right c.Write([]byte{0}) + c.Close() }) defer s.Close() @@ -175,6 +173,6 @@ var _ = Describe("Feedback", func() { <-c close(d) - }) + }, 10) }) }) From 29e6c72cd18c9f1ddc8f1201e110327d83ae741a Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 28 Jan 2015 19:27:47 -0500 Subject: [PATCH 11/22] Add race detection --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 18c443a..687544f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ before_script: - go get github.com/onsi/gomega - go get code.google.com/p/go.tools/cmd/cover - go install github.com/onsi/ginkgo/ginkgo -script: ginkgo -r --skipMeasurements --cover --trace +script: ginkgo -r --skipMeasurements --cover --trace --race env: global: - PATH=$HOME/gopath/bin:$PATH From 2935eb1b272f0c4f6e578be3eabda30c8eb987d5 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 28 Jan 2015 19:32:04 -0500 Subject: [PATCH 12/22] Extend timeout --- conn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_test.go b/conn_test.go index 44ebf23..d388a6d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -110,7 +110,7 @@ var _ = Describe("Conn", func() { Expect(err).To(BeNil()) close(d) - }) + }, 10) }) Context("with existing connection", func() { From fc87338edfc17476415f46e788be15b828e20b8a Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 12 Feb 2015 22:13:31 -0500 Subject: [PATCH 13/22] Start filling out client test --- client.go | 9 ++++ client_test.go | 125 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 client_test.go diff --git a/client.go b/client.go index e4fac1d..c52f623 100644 --- a/client.go +++ b/client.go @@ -93,6 +93,8 @@ func (c *Client) Connect() error { return err } + c.setConnected(true) + // On connect, requeue any notifications that were // sent after the error & disconnect. // http://redth.codes/the-problem-with-apples-push-notification-ser/ @@ -218,3 +220,10 @@ func (c *Client) readErrors() { cursor = cursor.Prev() } } + +func (c *Client) setConnected(connected bool) { + c.connm.Lock() + defer c.connm.Unlock() + + c.connected = true +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..618dd2b --- /dev/null +++ b/client_test.go @@ -0,0 +1,125 @@ +package apns_test + +import ( + "io/ioutil" + "net" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/timehop/apns" + "github.com/timehop/tcptest" +) + +var _ = Describe("Client", func() { + Describe(".NewClient", func() { + Context("bad cert/key pair", func() { + It("should error out", func() { + _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") + Expect(err).NotTo(BeNil()) + }) + }) + + Context("valid cert/key pair", func() { + It("should create a valid client", func() { + _, err := apns.NewClient(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + }) + }) + }) + + Describe(".NewClientWithFiles", func() { + Context("missing cert/key pair", func() { + It("should error out", func() { + _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") + Expect(err).NotTo(BeNil()) + }) + }) + + Context("valid cert/key pair", func() { + var certFile, keyFile *os.File + + BeforeEach(func() { + certFile, _ = ioutil.TempFile("", "cert.pem") + certFile.Write([]byte(tcptest.LocalhostCert)) + certFile.Close() + + keyFile, _ = ioutil.TempFile("", "key.pem") + keyFile.Write([]byte(tcptest.LocalhostKey)) + keyFile.Close() + }) + + AfterEach(func() { + if certFile != nil { + os.Remove(certFile.Name()) + } + + if keyFile != nil { + os.Remove(keyFile.Name()) + } + }) + + It("should create a valid client", func() { + _, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) + Expect(err).To(BeNil()) + }) + }) + }) + + Describe("Connect", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + }) + }) + + Describe("Send", func() { + Context("valid push", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + }) + }) + + Context("invalid notification", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{DeviceToken: "lol"}) + Expect(err).NotTo(BeNil()) + }) + }) + }) +}) From e4a178743fdb55baeb5c0c192e91d41d622b2627 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 12 Feb 2015 22:18:52 -0500 Subject: [PATCH 14/22] Remove extraneous function --- client.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/client.go b/client.go index c52f623..8be71c1 100644 --- a/client.go +++ b/client.go @@ -93,7 +93,7 @@ func (c *Client) Connect() error { return err } - c.setConnected(true) + c.connected = true // On connect, requeue any notifications that were // sent after the error & disconnect. @@ -220,10 +220,3 @@ func (c *Client) readErrors() { cursor = cursor.Prev() } } - -func (c *Client) setConnected(connected bool) { - c.connm.Lock() - defer c.connm.Unlock() - - c.connected = true -} From 7dd714ec78ee4759758eb2a82fbc210337d999cf Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 12 Feb 2015 22:46:05 -0500 Subject: [PATCH 15/22] Adding more client tests --- client.go | 13 +++++------ client_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 8be71c1..8f1c157 100644 --- a/client.go +++ b/client.go @@ -36,8 +36,6 @@ type Client struct { Conn Conn FailedNotifs chan NotificationResult - notifs chan serialized - buffer *buffer cursor *list.Element @@ -52,7 +50,6 @@ func newClientWithConn(gw string, conn Conn) Client { c := Client{ Conn: conn, FailedNotifs: make(chan NotificationResult), - notifs: make(chan serialized), buffer: newBuffer(50), cursor: nil, id: 0, @@ -197,19 +194,19 @@ func (c *Client) readErrors() { } e := NewError(p) - cursor := c.buffer.Back() - c.disconnect() + cursor := c.buffer.Back() + for cursor != nil { // Get serialized notification - s, _ := cursor.Value.(serialized) + n, _ := cursor.Value.(Notification) // If the notification, move cursor after the trouble notification - if s.id == e.Identifier { + if n.Identifier == e.Identifier { // Try to write - skip if no one is reading on the other side select { - case c.FailedNotifs <- NotificationResult{Notif: *s.n, Err: e}: + case c.FailedNotifs <- NotificationResult{Notif: n, Err: e}: default: } diff --git a/client_test.go b/client_test.go index 618dd2b..5a6b7e5 100644 --- a/client_test.go +++ b/client_test.go @@ -83,6 +83,62 @@ var _ = Describe("Client", func() { }) }) + Describe("Reading Errors", func() { + Context("send a notification and get an error", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte("123456")) + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + nr := <-c.FailedNotifs + Expect(nr.Err).NotTo(BeNil()) + Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) + }) + }) + + Context("send a multiple notifications and get an error", func() { + It("should not return an error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write([]byte("123456")) + c.Write([]byte{0}) + c.Close() + }) + defer s.Close() + + c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Connect() + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 159059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + err = c.Send(apns.Notification{Identifier: 259059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + nr := <-c.FailedNotifs + Expect(nr.Err).NotTo(BeNil()) + Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) + }) + }) + }) + Describe("Send", func() { Context("valid push", func() { It("should not return an error", func() { @@ -104,7 +160,7 @@ var _ = Describe("Client", func() { }) Context("invalid notification", func() { - It("should not return an error", func() { + It("should return an error", func() { s := tcptest.NewTLSServer(func(c net.Conn) { c.Write([]byte{0}) c.Close() From 571f1853ab6514ce5bdf76c7d597ed30d4d3d9ff Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 17 Feb 2015 00:01:48 -0500 Subject: [PATCH 16/22] Introduce session concept --- client.go | 213 ++++++++++-------------------------------------- client_test.go | 175 +++++++++++++++------------------------ session.go | 216 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 326 insertions(+), 278 deletions(-) create mode 100644 session.go diff --git a/client.go b/client.go index 8f1c157..b9159bb 100644 --- a/client.go +++ b/client.go @@ -1,70 +1,33 @@ package apns import ( - "container/list" "crypto/tls" - "io" "sync" + "time" ) -type buffer struct { - size int - *list.List -} - -func newBuffer(size int) *buffer { - return &buffer{size, list.New()} -} - -func (b *buffer) Add(v interface{}) *list.Element { - e := b.PushBack(v) - - if b.Len() > b.size { - b.Remove(b.Front()) - } - - return e -} - -type serialized struct { - id uint32 - b []byte - n *Notification -} - type Client struct { - Conn Conn - FailedNotifs chan NotificationResult + conn Conn - buffer *buffer - cursor *list.Element - - id uint32 - idm sync.Mutex - - connected bool - connm sync.Mutex + sess Session + sessm sync.Mutex } -func newClientWithConn(gw string, conn Conn) Client { - c := Client{ - Conn: conn, - FailedNotifs: make(chan NotificationResult), - buffer: newBuffer(50), - cursor: nil, - id: 0, - idm: sync.Mutex{}, - connected: false, - connm: sync.Mutex{}, +func newClientWithConn(conn Conn) (Client, error) { + c := Client{conn: conn} + + sess := newSession(conn) + err := sess.Connect() + if err != nil { + return c, err } - return c + return Client{conn, sess, sync.Mutex{}}, nil } -func NewClientWithCert(gw string, cert tls.Certificate) Client { +func NewClientWithCert(gw string, cert tls.Certificate) (Client, error) { conn := NewConnWithCert(gw, cert) - - return newClientWithConn(gw, conn) + return newClientWithConn(conn) } func NewClient(gw string, cert string, key string) (Client, error) { @@ -73,7 +36,7 @@ func NewClient(gw string, cert string, key string) (Client, error) { return Client{}, err } - return newClientWithConn(gw, conn), nil + return newClientWithConn(conn) } func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { @@ -82,138 +45,48 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return Client{}, err } - return newClientWithConn(gw, conn), nil -} - -func (c *Client) Connect() error { - if err := c.Conn.Connect(); err != nil { - return err - } - - c.connected = true - - // On connect, requeue any notifications that were - // sent after the error & disconnect. - // http://redth.codes/the-problem-with-apples-push-notification-ser/ - if err := c.requeue(); err != nil { - return err - } - - // Kick off asynchronous error reading - go c.readErrors() - - return nil -} - -func (c *Client) disconnect() error { - c.connm.Lock() - defer c.connm.Unlock() - - if c.Conn == nil { - return nil - } - - return c.Conn.Close() + return newClientWithConn(conn) } func (c *Client) Send(n Notification) error { - if !c.connected { - return ErrDisconnected + if c.sess.Disconnected() { + c.reconnectAndRequeue() } - // Set identifier if not specified - n.Identifier = c.determineIdentifier(n.Identifier) - - b, err := n.ToBinary() - if err != nil { - return err - } - - // Add to list - c.cursor = c.buffer.Add(n) - - return c.send(b) + return c.sess.Send(n) } -func (c *Client) send(b []byte) error { - c.connm.Lock() - defer c.connm.Unlock() - - _, err := c.Conn.Write(b) - if err == io.EOF { - c.connected = false - return err - } +func (c *Client) reconnectAndRequeue() { + c.sessm.Lock() + defer c.sessm.Unlock() - if err != nil { - return err - } + // Pull off undelivered notifications + notifs := c.sess.RequeueableNotifications() - c.cursor = c.cursor.Next() - return nil -} + // Reconnect + c.sess = nil -func (c *Client) determineIdentifier(n uint32) uint32 { - c.idm.Lock() - defer c.idm.Unlock() - - // If the id passed in is 0, that means it wasn't - // set so get the next ID. Otherwise, set it to that - // identifier. - if n == 0 { - c.id++ - } else { - c.id = n - } - - return c.id -} + for c.sess == nil { + sess := newSession(c.conn) -func (c *Client) requeue() error { - // If `cursor` is not nil, this means there are notifications that - // need to be delivered (or redelivered) - for ; c.cursor != nil; c.cursor = c.cursor.Next() { - if s, ok := c.cursor.Value.(serialized); ok { - if err := c.Send(*s.n); err != nil { - return err - } + err := sess.Connect() + if err != nil { + // TODO retry policy + // TODO connect error channel + // Keep trying to connect + time.Sleep(1 * time.Second) + continue } - } - - return nil -} - -func (c *Client) readErrors() { - p := make([]byte, 6, 6) - _, err := c.Conn.Read(p) - // TODO(bw) not sure what to do here. It's unclear what errors - // come out of this and how we handle it. - if err != nil { - return + c.sess = sess } - e := NewError(p) - c.disconnect() - - cursor := c.buffer.Back() - - for cursor != nil { - // Get serialized notification - n, _ := cursor.Value.(Notification) - - // If the notification, move cursor after the trouble notification - if n.Identifier == e.Identifier { - // Try to write - skip if no one is reading on the other side - select { - case c.FailedNotifs <- NotificationResult{Notif: n, Err: e}: - default: - } - - c.cursor = cursor.Next() - c.buffer.Remove(cursor) - } - - cursor = cursor.Prev() + for _, n := range notifs { + // TODO handle error from sending + c.sess.Send(n) } } + +var newSession = func(c Conn) Session { + return NewSession(c) +} diff --git a/client_test.go b/client_test.go index 5a6b7e5..5fccc45 100644 --- a/client_test.go +++ b/client_test.go @@ -1,38 +1,80 @@ -package apns_test +package apns import ( + "errors" "io/ioutil" - "net" "os" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "github.com/timehop/apns" "github.com/timehop/tcptest" ) +type mockSession struct { + sendErr error +} + +func (m mockSession) Send(n Notification) error { + return m.sendErr +} + +func (m mockSession) Connect() error { + return nil +} + +func (m mockSession) RequeueableNotifications() []Notification { + return []Notification{} +} + +func (m mockSession) Disconnect() { +} + +func (m mockSession) Disconnected() bool { + return false +} + +type badConnMockSession struct { + mockSession +} + +func (_ badConnMockSession) Connect() error { + return errors.New("whatev") +} + var _ = Describe("Client", func() { + BeforeEach(func() { + newSession = func(_ Conn) Session { return mockSession{} } + }) + Describe(".NewClient", func() { Context("bad cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClient(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) Context("valid cert/key pair", func() { It("should create a valid client", func() { - _, err := apns.NewClient(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) + + Context("bad connection", func() { + It("should error out", func() { + newSession = func(_ Conn) Session { return badConnMockSession{} } + + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).NotTo(BeNil()) + }) + }) }) Describe(".NewClientWithFiles", func() { Context("missing cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClientWithFiles(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) @@ -61,121 +103,38 @@ var _ = Describe("Client", func() { }) It("should create a valid client", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) + _, err := NewClientWithFiles(ProductionGateway, certFile.Name(), keyFile.Name()) Expect(err).To(BeNil()) }) }) }) - Describe("Connect", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte{0}) - c.Close() - }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - }) - }) - - Describe("Reading Errors", func() { - Context("send a notification and get an error", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte("123456")) - c.Write([]byte{0}) - c.Close() + Describe("Send", func() { + Context("connected", func() { + Context("valid push", func() { + It("should not return an error", func() { + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) - - nr := <-c.FailedNotifs - Expect(nr.Err).NotTo(BeNil()) - Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) }) - }) - - Context("send a multiple notifications and get an error", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte("123456")) - c.Write([]byte{0}) - c.Close() - }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - err = c.Send(apns.Notification{Identifier: 859059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) + Context("invalid notification", func() { + It("should return an error", func() { + newSession = func(_ Conn) Session { return mockSession{sendErr: errors.New("")} } - err = c.Send(apns.Notification{Identifier: 159059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{Identifier: 259059510, DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) - - nr := <-c.FailedNotifs - Expect(nr.Err).NotTo(BeNil()) - Expect(nr.Notif.Identifier).To(Equal(uint32(859059510))) - }) - }) - }) + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - Describe("Send", func() { - Context("valid push", func() { - It("should not return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte{0}) - c.Close() + err = c.Send(Notification{DeviceToken: "lol"}) + Expect(err).NotTo(BeNil()) }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) - Expect(err).To(BeNil()) }) }) - Context("invalid notification", func() { - It("should return an error", func() { - s := tcptest.NewTLSServer(func(c net.Conn) { - c.Write([]byte{0}) - c.Close() - }) - defer s.Close() - - c, err := apns.NewClient(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - Expect(err).To(BeNil()) - - err = c.Connect() - Expect(err).To(BeNil()) - - err = c.Send(apns.Notification{DeviceToken: "lol"}) - Expect(err).NotTo(BeNil()) - }) + Context("disconnected", func() { }) }) }) diff --git a/session.go b/session.go new file mode 100644 index 0000000..ec9fd67 --- /dev/null +++ b/session.go @@ -0,0 +1,216 @@ +package apns + +import ( + "container/list" + "errors" + "io" + "sync" +) + +type SessionError struct { + Notification Notification + Err Error +} + +func (s SessionError) Error() string { + return s.Err.Error() +} + +type Session interface { + Send(n Notification) error + Connect() error + RequeueableNotifications() []Notification + Disconnect() + Disconnected() bool +} + +type buffer struct { + size int + m sync.Mutex + *list.List +} + +func newBuffer(size int) *buffer { + return &buffer{size, sync.Mutex{}, list.New()} +} + +func (b *buffer) Add(v interface{}) *list.Element { + b.m.Lock() + defer b.m.Unlock() + + e := b.PushBack(v) + + if b.Len() > b.size { + b.Remove(b.Front()) + } + + return e +} + +type sessionState int + +const ( + sessionStateNew sessionState = iota + sessionStateConnected sessionState = iota + sessionStateDisconnected sessionState = iota +) + +type session struct { + b *buffer + + conn Conn + connm sync.Mutex + + st sessionState + stm sync.Mutex + + id uint32 + idm sync.Mutex + + err SessionError +} + +func NewSession(conn Conn) Session { + return &session{ + st: sessionStateNew, + stm: sync.Mutex{}, + conn: conn, + connm: sync.Mutex{}, + idm: sync.Mutex{}, + b: newBuffer(50), + } +} + +func (s *session) Connect() error { + if s.st != sessionStateNew { + return errors.New("can't connect unless the session is new") + } + + go s.readErrors() + return nil +} + +func (s *session) Disconnected() bool { + return s.st == sessionStateDisconnected +} + +func (s *session) Send(n Notification) error { + // If disconnected, error out + if s.st != sessionStateConnected { + return errors.New("not connected") + } + + // Set identifier if not specified + n.Identifier = s.determineIdentifier(n.Identifier) + + // Serialize + b, err := n.ToBinary() + if err != nil { + return err + } + + // Add to buffer + s.b.Add(n) + + // Send synchronously + return s.send(b) +} + +func (s *session) send(b []byte) error { + s.connm.Lock() + defer s.connm.Unlock() + + _, err := s.conn.Write(b) + if err == io.EOF { + s.Disconnect() + return err + } + + if err != nil { + return err + } + + return nil +} + +func (s *session) Disconnect() { + // Disconnect + s.transitionState(sessionStateDisconnected) +} + +func (s *session) RequeueableNotifications() []Notification { + notifs := []Notification{} + + // If still connected, return nothing + if s.st != sessionStateDisconnected { + return notifs + } + + // Walk back to last known good notification and return the slice + var e *list.Element + for e = s.b.Front(); e != nil; e = e.Next() { + if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Notification.Identifier { + break + } + } + + // Start right after errored ID and get the rest of the list + for e = e.Next(); e != nil; e = e.Next() { + n, ok := e.Value.(Notification) + if !ok { + continue + } + + notifs = append(notifs, n) + } + + return notifs +} + +func (s *session) transitionState(st sessionState) { + s.stm.Lock() + defer s.stm.Unlock() + + s.st = st +} + +func (s *session) determineIdentifier(n uint32) uint32 { + s.idm.Lock() + defer s.idm.Unlock() + + // If the id passed in is 0, that means it wasn't + // set so get the next ID. Otherwise, set it to that + // identifier. + if n == 0 { + s.id++ + } else { + s.id = n + } + + return s.id +} + +func (s *session) readErrors() { + p := make([]byte, 6, 6) + + _, err := s.conn.Read(p) + // TODO(bw) not sure what to do here. It's unclear what errors + // come out of this and how we handle it. + if err != nil { + return + } + + s.Disconnect() + + e := NewError(p) + + for cursor := s.b.Back(); cursor != nil; cursor = cursor.Prev() { + // Get serialized notification + n, _ := cursor.Value.(Notification) + + // If the notification, move cursor after the trouble notification + if n.Identifier == e.Identifier { + s.err = SessionError{n, e} + } + } +} From 1af119239be67d864ff5ea7f912b1552cc20687b Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 17 Feb 2015 00:05:47 -0500 Subject: [PATCH 17/22] Fix example --- example/example.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/example/example.go b/example/example.go index 2d9e5c0..637d0af 100644 --- a/example/example.go +++ b/example/example.go @@ -13,10 +13,6 @@ func main() { log.Fatal("Could not create client: ", err.Error()) } - if err := c.Connect(); err != nil { - log.Fatal("Could not create connect: ", err.Error()) - } - i := 1 for { fmt.Print("Enter ' ': ") From cc9e4f3dab1096afd907e880742265804efdbd5f Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Tue, 17 Feb 2015 18:18:05 -0500 Subject: [PATCH 18/22] Add more tests for client --- client_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 12 deletions(-) diff --git a/client_test.go b/client_test.go index 5fccc45..7dbee14 100644 --- a/client_test.go +++ b/client_test.go @@ -11,30 +11,41 @@ import ( ) type mockSession struct { - sendErr error + sendCB func(n Notification) error + requeueNotifs []Notification + disconnectedState bool } -func (m mockSession) Send(n Notification) error { - return m.sendErr +func (m *mockSession) Send(n Notification) error { + if m.sendCB == nil { + return nil + } + + return m.sendCB(n) } -func (m mockSession) Connect() error { +func (m *mockSession) Connect() error { return nil } -func (m mockSession) RequeueableNotifications() []Notification { - return []Notification{} +func (m *mockSession) RequeueableNotifications() []Notification { + if len(m.requeueNotifs) == 0 { + return []Notification{} + } + + return m.requeueNotifs } -func (m mockSession) Disconnect() { +func (m *mockSession) Disconnect() { + m.disconnectedState = true } -func (m mockSession) Disconnected() bool { - return false +func (m *mockSession) Disconnected() bool { + return m.disconnectedState } type badConnMockSession struct { - mockSession + *mockSession } func (_ badConnMockSession) Connect() error { @@ -43,7 +54,7 @@ func (_ badConnMockSession) Connect() error { var _ = Describe("Client", func() { BeforeEach(func() { - newSession = func(_ Conn) Session { return mockSession{} } + newSession = func(_ Conn) Session { return &mockSession{} } }) Describe(".NewClient", func() { @@ -123,7 +134,13 @@ var _ = Describe("Client", func() { Context("invalid notification", func() { It("should return an error", func() { - newSession = func(_ Conn) Session { return mockSession{sendErr: errors.New("")} } + newSession = func(_ Conn) Session { + return &mockSession{ + sendCB: func(_ Notification) error { + return errors.New("") + }, + } + } c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) @@ -135,6 +152,54 @@ var _ = Describe("Client", func() { }) Context("disconnected", func() { + It("should reconnect", func() { + newSessCount := 0 + newSession = func(_ Conn) Session { + newSessCount += 1 + return &mockSession{} + } + + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c.sess.Disconnect() + + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + Expect(newSessCount).To(Equal(2)) + }) + }) + + It("should reconnect and requeue", func() { + newSessCount := 0 + sendCount := 0 + + newSession = func(_ Conn) Session { + newSessCount += 1 + return &mockSession{ + requeueNotifs: []Notification{ + Notification{}, + Notification{}, + Notification{}, + }, + sendCB: func(_ Notification) error { + sendCount += 1 + return nil + }, + } + } + + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c.sess.Disconnect() + + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) + + Expect(newSessCount).To(Equal(2)) + Expect(sendCount).To(Equal(4)) }) }) }) From d06b28b72987a8be4f49a7fd3e04cdaac75855ff Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 25 Feb 2015 17:13:43 -0500 Subject: [PATCH 19/22] Beginnings of a session test --- session_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 session_test.go diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..67a598f --- /dev/null +++ b/session_test.go @@ -0,0 +1,94 @@ +package apns + +import ( + "time" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockConn struct{} + +func (m mockConn) Read(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Write(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) Connect() error { + return nil +} + +func (m mockConn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { + return 0, nil +} + +var _ = Describe("Session", func() { + Describe("NewSession", func() { + It("creates a session", func() { + s := NewSession(mockConn{}) + Expect(s).NotTo(BeNil()) + }) + }) + + Describe("Connect", func() { + Context("new state", func() { + It("should not return an error", func() { + s := NewSession(mockConn{}) + + err := s.Connect() + Expect(err).To(BeNil()) + }) + }) + + Context("not new state", func() { + It("should return an error", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + err := s.Connect() + Expect(err).NotTo(BeNil()) + }) + }) + }) + + Describe("Disconnected", func() { + Context("not connected", func() { + It("should not be true", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + Expect(s.Disconnected()).To(BeTrue()) + }) + }) + + Context("connected", func() { + It("should be false", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.Connect() + + Expect(s.Disconnected()).To(BeFalse()) + }) + }) + }) + + Describe("Send", func() { + }) + + Describe("Disconnect", func() { + }) + + Describe("RequeueableNotifications", func() { + }) +}) From d01f235d05c648f1dd84c495fb3c29fae25bcaa9 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Thu, 26 Feb 2015 16:30:00 -0500 Subject: [PATCH 20/22] Synchronize around the session state --- session.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index ec9fd67..c7dd832 100644 --- a/session.go +++ b/session.go @@ -82,7 +82,7 @@ func NewSession(conn Conn) Session { } func (s *session) Connect() error { - if s.st != sessionStateNew { + if s.isNew() { return errors.New("can't connect unless the session is new") } @@ -90,13 +90,30 @@ func (s *session) Connect() error { return nil } +func (s *session) isNew() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st != sessionStateNew +} + func (s *session) Disconnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + return s.st == sessionStateDisconnected } +func (s *session) Connnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st == sessionStateConnected +} + func (s *session) Send(n Notification) error { // If disconnected, error out - if s.st != sessionStateConnected { + if s.Connnected() { return errors.New("not connected") } From 70205b7a9b6840ebbdc2969b91d25793ae6848b8 Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 4 Mar 2015 15:49:23 -0500 Subject: [PATCH 21/22] How do you even logic --- session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/session.go b/session.go index c7dd832..dc7e416 100644 --- a/session.go +++ b/session.go @@ -113,7 +113,7 @@ func (s *session) Connnected() bool { func (s *session) Send(n Notification) error { // If disconnected, error out - if s.Connnected() { + if !s.Connnected() { return errors.New("not connected") } From d245d825421023d5c288c7a4ba7f109182d2cbee Mon Sep 17 00:00:00 2001 From: Benny Wong Date: Wed, 4 Mar 2015 15:51:18 -0500 Subject: [PATCH 22/22] Clean up session states --- session.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/session.go b/session.go index dc7e416..36e4576 100644 --- a/session.go +++ b/session.go @@ -50,9 +50,9 @@ func (b *buffer) Add(v interface{}) *list.Element { type sessionState int const ( - sessionStateNew sessionState = iota - sessionStateConnected sessionState = iota - sessionStateDisconnected sessionState = iota + sessionStateNew sessionState = 1 << iota + sessionStateConnected + sessionStateDisconnected ) type session struct { @@ -143,11 +143,7 @@ func (s *session) send(b []byte) error { return err } - if err != nil { - return err - } - - return nil + return err } func (s *session) Disconnect() {