From 7c368464c9e47cc5188335671ed84832e2f989af Mon Sep 17 00:00:00 2001 From: Kwok-kuen Cheung Date: Fri, 8 Jul 2016 11:27:09 +0800 Subject: [PATCH] Find default APNS topic from APNS certificate --- push/apns.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/push/apns.go b/push/apns.go index a511edc1d..838fb40c6 100644 --- a/push/apns.go +++ b/push/apns.go @@ -16,6 +16,8 @@ package push import ( "crypto/tls" + "crypto/x509" + "encoding/asn1" "encoding/json" "errors" "fmt" @@ -49,6 +51,7 @@ type APNSPusher struct { conn skydb.Conn service pushService failed chan failedNotification + topic string } type failedNotification struct { @@ -56,6 +59,45 @@ type failedNotification struct { err push.Error } +func parseDefaultCertificateLeaf(certificate *tls.Certificate) error { + if certificate.Leaf != nil { + return nil + } + + for _, cert := range certificate.Certificate { + x509Cert, err := x509.ParseCertificate(cert) + if err != nil { + return err + } + certificate.Leaf = x509Cert + return nil + } + return errors.New("push/apns: provided APNS certificate does not contain leaf") +} + +// findDefaultAPNSTopic returns the APNS topic in the TLS certificate. +func findDefaultAPNSTopic(certificate tls.Certificate) (string, error) { + if certificate.Leaf == nil { + err := parseDefaultCertificateLeaf(&certificate) + if err != nil { + return "", err + } + } + + uidObjectIdentifier := asn1.ObjectIdentifier([]int{0, 9, 2342, 19200300, 100, 1, 1}) + for _, attr := range certificate.Leaf.Subject.Names { + if uidObjectIdentifier.Equal(attr.Type) { + switch value := attr.Value.(type) { + case string: + return value, nil + } + break + } + } + + return "", errors.New("push/apns: cannot find UID in APNS certificate subject name") +} + // NewAPNSPusher returns a new APNSPusher from content of certificate // and private key as string func NewAPNSPusher(connOpener func() (skydb.Conn, error), gwType GatewayType, cert string, key string) (*APNSPusher, error) { @@ -64,6 +106,11 @@ func NewAPNSPusher(connOpener func() (skydb.Conn, error), gwType GatewayType, ce return nil, err } + topic, err := findDefaultAPNSTopic(certificate) + if err != nil { + return nil, err + } + client, err := push.NewClient(certificate) if err != nil { return nil, err @@ -82,6 +129,7 @@ func NewAPNSPusher(connOpener func() (skydb.Conn, error), gwType GatewayType, ce return &APNSPusher{ connOpener: connOpener, service: service, + topic: topic, }, nil } @@ -163,6 +211,7 @@ func (pusher *APNSPusher) Send(m Mapper, device skydb.Device) error { logger := log.WithFields(log.Fields{ "deviceToken": device.Token, "deviceID": device.ID, + "apnsTopic": pusher.topic, }) if m == nil { @@ -180,8 +229,12 @@ func (pusher *APNSPusher) Send(m Mapper, device skydb.Device) error { return err } + headers := push.Headers{ + Topic: pusher.topic, + } + // push the notification: - apnsid, err := pusher.service.Push(device.Token, nil, serializedPayload) + apnsid, err := pusher.service.Push(device.Token, &headers, serializedPayload) if err != nil { if pushError, ok := err.(*push.Error); ok && pushError != nil { // We recognize the error, and that error comes from APNS