Skip to content

Commit

Permalink
Add monitoring to JWT Signer and Verify factory.
Browse files Browse the repository at this point in the history
If a MonitoringClientFactory is registered, then we create the Jwt Signer/Verifier wrapper with the corresponding loggers to report events from the Sign API.

PiperOrigin-RevId: 642250007
Change-Id: Ie50080a8c04288e42f421c6a4189a18cba98e3ce
  • Loading branch information
fernandolobato authored and copybara-github committed Jun 11, 2024
1 parent 36ca961 commit 3b9f7c1
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 5 deletions.
39 changes: 36 additions & 3 deletions jwt/jwt_signer_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"fmt"

"github.com/tink-crypto/tink-go/v2/core/primitiveset"
"github.com/tink-crypto/tink-go/v2/internal/internalregistry"
"github.com/tink-crypto/tink-go/v2/internal/monitoringutil"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/monitoring"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)

Expand All @@ -36,11 +39,28 @@ func NewSigner(handle *keyset.Handle) (Signer, error) {

// wrappedSigner is a JWT Signer implementation that uses the underlying primitive set for JWT Sign.
type wrappedSigner struct {
ps *primitiveset.PrimitiveSet
ps *primitiveset.PrimitiveSet
logger monitoring.Logger
}

var _ Signer = (*wrappedSigner)(nil)

func createSignerLogger(ps *primitiveset.PrimitiveSet) (monitoring.Logger, error) {
// only keysets which contain annotations are monitored.
if len(ps.Annotations) == 0 {
return &monitoringutil.DoNothingLogger{}, nil
}
keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps)
if err != nil {
return nil, err
}
return internalregistry.GetMonitoringClient().NewLogger(&monitoring.Context{
KeysetInfo: keysetInfo,
Primitive: "jwtsign",
APIFunction: "sign",
})
}

func newWrappedSigner(ps *primitiveset.PrimitiveSet) (*wrappedSigner, error) {
if _, ok := (ps.Primary.Primitive).(*signerWithKID); !ok {
return nil, fmt.Errorf("jwt_signer_factory: not a JWT Signer primitive")
Expand All @@ -55,7 +75,14 @@ func newWrappedSigner(ps *primitiveset.PrimitiveSet) (*wrappedSigner, error) {
}
}
}
return &wrappedSigner{ps: ps}, nil
logger, err := createSignerLogger(ps)
if err != nil {
return nil, err
}
return &wrappedSigner{
ps: ps,
logger: logger,
}, nil
}

func (w *wrappedSigner) SignAndEncode(rawJWT *RawJWT) (string, error) {
Expand All @@ -64,5 +91,11 @@ func (w *wrappedSigner) SignAndEncode(rawJWT *RawJWT) (string, error) {
if !ok {
return "", fmt.Errorf("jwt_signer_factory: not a JWT Signer primitive")
}
return p.SignAndEncodeWithKID(rawJWT, keyID(primary.KeyID, primary.PrefixType))
token, err := p.SignAndEncodeWithKID(rawJWT, keyID(primary.KeyID, primary.PrefixType))
if err != nil {
w.logger.LogFailure()
return "", err
}
w.logger.Log(primary.KeyID, 1)
return token, nil
}
263 changes: 263 additions & 0 deletions jwt/jwt_signer_verifier_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
package jwt_test

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/insecurecleartextkeyset"
"github.com/tink-crypto/tink-go/v2/internal/internalregistry"
"github.com/tink-crypto/tink-go/v2/jwt"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/monitoring"
"github.com/tink-crypto/tink-go/v2/signature"
"github.com/tink-crypto/tink-go/v2/testing/fakemonitoring"
"github.com/tink-crypto/tink-go/v2/testkeyset"
"github.com/tink-crypto/tink-go/v2/testutil"
jepb "github.com/tink-crypto/tink-go/v2/proto/jwt_ecdsa_go_proto"
Expand Down Expand Up @@ -374,3 +380,260 @@ func TestFactorySignVerifyWithKIDSuccess(t *testing.T) {
})
}
}

func TestFactorySignVerifyWithoutAnnotationsEmitsNoMonitoring(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
privHandle, err := keyset.NewHandle(jwt.ES256Template())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
pubHandle, err := privHandle.Public()
if err != nil {
t.Fatalf("privHandle.Public() err = %v, want nil", err)
}
signer, err := jwt.NewSigner(privHandle)
if err != nil {
t.Fatalf("jwt.NewSigner() err = %v, want nil", err)
}
verifier, err := jwt.NewVerifier(pubHandle)
if err != nil {
t.Fatalf("jwt.NewVerifier() err = %v, want nil", err)
}
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
validator, err := jwt.NewValidator(&jwt.ValidatorOpts{AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
compact, err := signer.SignAndEncode(rawJWT)
if err != nil {
t.Fatalf("signer.SignAndEncode() err = %v, want nil", err)
}
if _, err := verifier.VerifyAndDecode(compact, validator); err != nil {
t.Fatalf("verifier.VerifyAndDecode() err = %v, want nil", err)
}
if len(client.Events()) != 0 {
t.Errorf("len(client.Events()) = %d, want 0", len(client.Events()))
}
if len(client.Failures()) != 0 {
t.Errorf("len(client.Failures()) = %d, want 0", len(client.Failures()))
}
}

func TestFactorySignWithAnnotationsEmitsMonitoringSuccess(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
handle, err := keyset.NewHandle(jwt.ES256Template())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
buff := &bytes.Buffer{}
if err := insecurecleartextkeyset.Write(handle, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
annotations := map[string]string{"foo": "bar"}
privHandle, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
// Verify annotations aren't propagated.
pubHandle, err := privHandle.Public()
if err != nil {
t.Fatalf("privHandle.Public() err = %v, want nil", err)
}
signer, err := jwt.NewSigner(privHandle)
if err != nil {
t.Fatalf("jwt.NewSigner() err = %v, want nil", err)
}
verifier, err := jwt.NewVerifier(pubHandle)
if err != nil {
t.Fatalf("jwt.NewVerifier() err = %v, want nil", err)
}
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
validator, err := jwt.NewValidator(&jwt.ValidatorOpts{AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
compact, err := signer.SignAndEncode(rawJWT)
if err != nil {
t.Fatalf("signer.SignAndEncode() err = %v, want nil", err)
}
if _, err := verifier.VerifyAndDecode(compact, validator); err != nil {
t.Fatalf("verifier.VerifyAndDecode() err = %v, want nil", err)
}
// verify error emits no monitoring.
if _, err := verifier.VerifyAndDecode("invalid", validator); err == nil {
t.Fatalf("verifier.VerifyAndDecode() err = %v, want nil", err)
}
if len(client.Failures()) != 0 {
t.Errorf("len(client.Failures()) = %d, want 0", len(client.Failures()))
}
if len(client.Events()) != 1 {
t.Errorf("len(client.Events()) = %d, want 1", len(client.Events()))
}
got := client.Events()
wantSignKeysetInfo := &monitoring.KeysetInfo{
Annotations: annotations,
PrimaryKeyID: privHandle.KeysetInfo().GetPrimaryKeyId(),
Entries: []*monitoring.Entry{
{
KeyID: privHandle.KeysetInfo().GetPrimaryKeyId(),
Status: monitoring.Enabled,
KeyType: "tink.JwtEcdsaPrivateKey",
KeyPrefix: "TINK",
},
},
}
want := []*fakemonitoring.LogEvent{
{
Context: monitoring.NewContext("jwtsign", "sign", wantSignKeysetInfo),
KeyID: privHandle.KeysetInfo().GetPrimaryKeyId(),
NumBytes: 1,
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%v", diff)
}
}

func TestFactoryVerifyWithAnnotationsEmitsMonitoringSuccess(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
privHandle, err := keyset.NewHandle(jwt.ES256Template())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
signer, err := jwt.NewSigner(privHandle)
if err != nil {
t.Fatalf("jwt.NewSigner() err = %v, want nil", err)
}

pubHandle, err := privHandle.Public()
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
buff := &bytes.Buffer{}
if err := insecurecleartextkeyset.Write(pubHandle, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
annotations := map[string]string{"foo": "bar"}
pubHandle, err = insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
verifier, err := jwt.NewVerifier(pubHandle)
if err != nil {
t.Fatalf("jwt.NewVerifier() err = %v, want nil", err)
}
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
validator, err := jwt.NewValidator(&jwt.ValidatorOpts{AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
compact, err := signer.SignAndEncode(rawJWT)
if err != nil {
t.Fatalf("signer.SignAndEncode() err = %v, want nil", err)
}
if _, err := verifier.VerifyAndDecode(compact, validator); err != nil {
t.Fatalf("verifier.VerifyAndDecode() err = %v, want nil", err)
}
if len(client.Failures()) != 0 {
t.Errorf("len(client.Failures()) = %d, want 0", len(client.Failures()))
}
if len(client.Events()) != 1 {
t.Errorf("len(client.Events()) = %d, want 1", len(client.Events()))
}
got := client.Events()
wantSignKeysetInfo := &monitoring.KeysetInfo{
Annotations: annotations,
PrimaryKeyID: privHandle.KeysetInfo().GetPrimaryKeyId(),
Entries: []*monitoring.Entry{
{
KeyID: privHandle.KeysetInfo().GetPrimaryKeyId(),
Status: monitoring.Enabled,
KeyType: "tink.JwtEcdsaPublicKey",
KeyPrefix: "TINK",
},
},
}
want := []*fakemonitoring.LogEvent{
{
Context: monitoring.NewContext("jwtverify", "verify", wantSignKeysetInfo),
KeyID: privHandle.KeysetInfo().GetPrimaryKeyId(),
NumBytes: 1,
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%v", diff)
}
}

func TestFactorySignAndVerifyWithAnnotationsEmitsMonitoringOnError(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
kid := "intrusive_kid"
_, privHandle, pubHandle := createKeyAndKeyHandles(t, &kid, tinkpb.OutputPrefixType_TINK)
buff := &bytes.Buffer{}
if err := insecurecleartextkeyset.Write(privHandle, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
annotations := map[string]string{"foo": "bar"}
privHandle, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
signer, err := jwt.NewSigner(privHandle)
if err != nil {
t.Fatalf("jwt.NewSigner() err = %v, want nil", err)
}
buff.Reset()
if err := insecurecleartextkeyset.Write(pubHandle, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
pubHandle, err = insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
verifier, err := jwt.NewVerifier(pubHandle)
if err != nil {
t.Fatalf("jwt.NewVerifier() err = %v, want nil", err)
}
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
validator, err := jwt.NewValidator(&jwt.ValidatorOpts{AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
if _, err := signer.SignAndEncode(rawJWT); err == nil {
t.Fatalf("signer.SignAndEncode() err = nil, want error")
}
if _, err := verifier.VerifyAndDecode("invalid_token", validator); err == nil {
t.Fatalf("verifier.VerifyAndDecode() err = nil want error")
}
if len(client.Failures()) != 2 {
t.Errorf("len(client.Failures()) = %d, want 2", len(client.Failures()))
}
}
Loading

0 comments on commit 3b9f7c1

Please sign in to comment.