diff --git a/jwt/BUILD.bazel b/jwt/BUILD.bazel index 645c645..848872b 100644 --- a/jwt/BUILD.bazel +++ b/jwt/BUILD.bazel @@ -36,10 +36,13 @@ go_library( deps = [ "//core/primitiveset", "//core/registry", + "//internal/internalregistry", + "//internal/monitoringutil", "//internal/signature", "//internal/tinkerror", "//keyset", "//mac/subtle", + "//monitoring", "//proto/jwt_ecdsa_go_proto", "//proto/jwt_hmac_go_proto", "//proto/jwt_rsa_ssa_pkcs1_go_proto", @@ -82,8 +85,10 @@ go_test( deps = [ "//core/registry", "//insecurecleartextkeyset", + "//internal/internalregistry", "//keyset", "//mac/subtle", + "//monitoring", "//proto/ecdsa_go_proto", "//proto/jwt_ecdsa_go_proto", "//proto/jwt_hmac_go_proto", @@ -93,6 +98,7 @@ go_test( "//signature", "//signature/subtle", "//subtle/random", + "//testing/fakemonitoring", "//testkeyset", "//testutil", "//tink", diff --git a/jwt/jwt_mac_factory.go b/jwt/jwt_mac_factory.go index 6345e6d..9760531 100644 --- a/jwt/jwt_mac_factory.go +++ b/jwt/jwt_mac_factory.go @@ -18,7 +18,10 @@ import ( "fmt" "github.com/tink-crypto/tink-go/v2/core/primitiveset" + "github.com/tink-crypto/tink-go/v2/internal/internalregistry" + "github.com/tink-crypto/tink-go/v2/internal/monitoringutil" "github.com/tink-crypto/tink-go/v2/keyset" + "github.com/tink-crypto/tink-go/v2/monitoring" tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) @@ -36,7 +39,9 @@ func NewMAC(handle *keyset.Handle) (MAC, error) { // wrappedJWTMAC is a JWTMAC implementation that uses the underlying primitive set for JWT MAC. type wrappedJWTMAC struct { - ps *primitiveset.PrimitiveSet + ps *primitiveset.PrimitiveSet + computeLogger monitoring.Logger + verifyLogger monitoring.Logger } var _ MAC = (*wrappedJWTMAC)(nil) @@ -55,7 +60,39 @@ func newWrappedJWTMAC(ps *primitiveset.PrimitiveSet) (*wrappedJWTMAC, error) { } } } - return &wrappedJWTMAC{ps: ps}, nil + computeLogger, verifyLogger, err := createLoggers(ps) + if err != nil { + return nil, err + } + return &wrappedJWTMAC{ps: ps, computeLogger: computeLogger, verifyLogger: verifyLogger}, nil +} + +func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring.Logger, error) { + if len(ps.Annotations) == 0 { + return &monitoringutil.DoNothingLogger{}, &monitoringutil.DoNothingLogger{}, nil + } + client := internalregistry.GetMonitoringClient() + keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps) + if err != nil { + return nil, nil, err + } + computeLogger, err := client.NewLogger(&monitoring.Context{ + Primitive: "jwtmac", + APIFunction: "compute", + KeysetInfo: keysetInfo, + }) + if err != nil { + return nil, nil, err + } + verifyLogger, err := client.NewLogger(&monitoring.Context{ + Primitive: "jwtmac", + APIFunction: "verify", + KeysetInfo: keysetInfo, + }) + if err != nil { + return nil, nil, err + } + return computeLogger, verifyLogger, nil } func (w *wrappedJWTMAC) ComputeMACAndEncode(token *RawJWT) (string, error) { @@ -64,7 +101,13 @@ func (w *wrappedJWTMAC) ComputeMACAndEncode(token *RawJWT) (string, error) { if !ok { return "", fmt.Errorf("jwt_mac_factory: not a JWT MAC primitive") } - return p.ComputeMACAndEncodeWithKID(token, keyID(primary.KeyID, primary.PrefixType)) + signedToken, err := p.ComputeMACAndEncodeWithKID(token, keyID(primary.KeyID, primary.PrefixType)) + if err != nil { + w.computeLogger.LogFailure() + return "", err + } + w.computeLogger.Log(primary.KeyID, 1) + return signedToken, nil } func (w *wrappedJWTMAC) VerifyMACAndDecode(compact string, validator *Validator) (*VerifiedJWT, error) { @@ -77,6 +120,7 @@ func (w *wrappedJWTMAC) VerifyMACAndDecode(compact string, validator *Validator) } verifiedJWT, err := p.VerifyMACAndDecodeWithKID(compact, validator, keyID(e.KeyID, e.PrefixType)) if err == nil { + w.verifyLogger.Log(e.KeyID, 1) return verifiedJWT, nil } if err != errJwtVerification { @@ -85,6 +129,7 @@ func (w *wrappedJWTMAC) VerifyMACAndDecode(compact string, validator *Validator) } } } + w.verifyLogger.LogFailure() if interestingErr != nil { return nil, interestingErr } diff --git a/jwt/jwt_mac_factory_test.go b/jwt/jwt_mac_factory_test.go index 0eeb8bc..fc145d8 100644 --- a/jwt/jwt_mac_factory_test.go +++ b/jwt/jwt_mac_factory_test.go @@ -15,14 +15,20 @@ package jwt_test import ( + "bytes" "fmt" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/proto" + "github.com/tink-crypto/tink-go/v2/insecurecleartextkeyset" + "github.com/tink-crypto/tink-go/v2/internal/internalregistry" "github.com/tink-crypto/tink-go/v2/jwt" "github.com/tink-crypto/tink-go/v2/keyset" + "github.com/tink-crypto/tink-go/v2/monitoring" "github.com/tink-crypto/tink-go/v2/signature" "github.com/tink-crypto/tink-go/v2/subtle/random" + "github.com/tink-crypto/tink-go/v2/testing/fakemonitoring" "github.com/tink-crypto/tink-go/v2/testkeyset" "github.com/tink-crypto/tink-go/v2/testutil" @@ -277,3 +283,204 @@ func TestVerifyMACAndDecodeReturnsValidationError(t *testing.T) { t.Errorf("p.VerifyMACAndDecode() err = %q, want %q", err.Error(), wantErr) } } + +func TestComputeAndVerifyWithoutAnnotationsEmitsNoMonitoring(t *testing.T) { + defer internalregistry.ClearMonitoringClient() + client := fakemonitoring.NewClient("fake-client") + if err := internalregistry.RegisterMonitoringClient(client); err != nil { + t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err) + } + kh, err := keyset.NewHandle(jwt.HS256Template()) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + p, err := jwt.NewMAC(kh) + if err != nil { + t.Fatalf("jwt.NewMAC() err = %v, want nil", err) + } + audience := "audience" + rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{Audience: &audience, WithoutExpiration: true}) + if err != nil { + t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err) + } + token, err := p.ComputeMACAndEncode(rawJWT) + if err != nil { + t.Errorf("p.ComputeMACAndEncode() err = %v, want nil", err) + } + validator, err := jwt.NewValidator( + &jwt.ValidatorOpts{ExpectedAudience: &audience, AllowMissingExpiration: true}) + if err != nil { + t.Fatalf("jwt.NewValidator() err = %v, want nil", err) + } + if _, err = p.VerifyMACAndDecode(token, validator); err != nil { + t.Errorf("p.VerifyMACAndDecode() err = %v, want error", err) + } + if len(client.Failures()) != 0 { + t.Errorf("len(client.Failures()) = %d, want = 0", len(client.Failures())) + } + if len(client.Events()) != 0 { + t.Errorf("len(client.Events()) = %d, want = 0", len(client.Events())) + } +} + +func TestComputeAndVerifyWithAnnotationsEmitsMonitoring(t *testing.T) { + defer internalregistry.ClearMonitoringClient() + client := fakemonitoring.NewClient("fake-client") + if err := internalregistry.RegisterMonitoringClient(client); err != nil { + t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err) + } + kh, err := keyset.NewHandle(jwt.HS256Template()) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + // Annotations are only supported through the `insecurecleartextkeyset` API. + buff := &bytes.Buffer{} + if err := insecurecleartextkeyset.Write(kh, keyset.NewBinaryWriter(buff)); err != nil { + t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err) + } + annotations := map[string]string{"foo": "bar"} + mh, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations)) + if err != nil { + t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err) + } + p, err := jwt.NewMAC(mh) + if err != nil { + t.Fatalf("jwt.NewMAC() err = %v, want nil", err) + } + audience := "audience" + rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{Audience: &audience, WithoutExpiration: true}) + if err != nil { + t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err) + } + token, err := p.ComputeMACAndEncode(rawJWT) + if err != nil { + t.Errorf("p.ComputeMACAndEncode() err = %v, want nil", err) + } + validator, err := jwt.NewValidator( + &jwt.ValidatorOpts{ExpectedAudience: &audience, AllowMissingExpiration: true}) + if err != nil { + t.Fatalf("jwt.NewValidator() err = %v, want nil", err) + } + if _, err = p.VerifyMACAndDecode(token, validator); err != nil { + t.Errorf("p.VerifyMACAndDecode() err = %v, want error", err) + } + failures := client.Failures() + if len(failures) != 0 { + t.Errorf("len(client.Failures()) = %d, want = 0", len(failures)) + } + got := client.Events() + wantKeysetInfo := monitoring.NewKeysetInfo( + annotations, + kh.KeysetInfo().GetPrimaryKeyId(), + []*monitoring.Entry{ + { + KeyID: kh.KeysetInfo().GetPrimaryKeyId(), + Status: monitoring.Enabled, + KeyType: "tink.JwtHmacKey", + KeyPrefix: "TINK", + }, + }, + ) + want := []*fakemonitoring.LogEvent{ + { + KeyID: mh.KeysetInfo().GetPrimaryKeyId(), + NumBytes: 1, + Context: monitoring.NewContext("jwtmac", "compute", wantKeysetInfo), + }, + { + KeyID: mh.KeysetInfo().GetPrimaryKeyId(), + NumBytes: 1, + Context: monitoring.NewContext("jwtmac", "verify", wantKeysetInfo), + }, + } + if cmp.Diff(got, want) != "" { + t.Errorf("%v", cmp.Diff(got, want)) + } +} + +func TestComputeFailureEmitsMonitoring(t *testing.T) { + defer internalregistry.ClearMonitoringClient() + client := &fakemonitoring.Client{Name: ""} + if err := internalregistry.RegisterMonitoringClient(client); err != nil { + t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err) + } + keyData, err := newKeyData(newJWTHMACKey(jwtmacpb.JwtHmacAlgorithm_HS256, &jwtmacpb.JwtHmacKey_CustomKid{Value: "custom-kid"})) + if err != nil { + t.Fatalf("creating NewKeyData: %v", err) + } + primaryKey := testutil.NewKey(keyData, tinkpb.KeyStatusType_ENABLED, 42, tinkpb.OutputPrefixType_TINK) + kh, err := testkeyset.NewHandle(testutil.NewKeyset(primaryKey.KeyId, []*tinkpb.Keyset_Key{primaryKey})) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + // Annotations are only supported through the `insecurecleartextkeyset` API. + buff := &bytes.Buffer{} + if err := insecurecleartextkeyset.Write(kh, keyset.NewBinaryWriter(buff)); err != nil { + t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err) + } + annotations := map[string]string{"foo": "bar"} + mh, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations)) + if err != nil { + t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err) + } + p, err := jwt.NewMAC(mh) + if err != nil { + t.Fatalf("jwt.NewMAC() err = %v, want nil", err) + } + rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{WithoutExpiration: true}) + if err != nil { + t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err) + } + if _, err := p.ComputeMACAndEncode(rawJWT); err == nil { + t.Errorf("p.ComputeMACAndEncode() err = nil, want error") + } + failures := client.Failures() + if len(failures) != 1 { + t.Errorf("len(client.Failures()) = %d, want = 1", len(failures)) + } + if len(client.Events()) != 0 { + t.Errorf("len(client.Events()) = %d, want = 0", len(client.Events())) + } +} + +func TestVerifyFailureEmitsMonitoring(t *testing.T) { + defer internalregistry.ClearMonitoringClient() + client := fakemonitoring.NewClient("fake-client") + if err := internalregistry.RegisterMonitoringClient(client); err != nil { + t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err) + } + kh, err := keyset.NewHandle(jwt.HS256Template()) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + // Annotations are only supported through the `insecurecleartextkeyset` API. + buff := &bytes.Buffer{} + if err := insecurecleartextkeyset.Write(kh, keyset.NewBinaryWriter(buff)); err != nil { + t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err) + } + annotations := map[string]string{"foo": "bar"} + mh, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations)) + if err != nil { + t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err) + } + p, err := jwt.NewMAC(mh) + if err != nil { + t.Fatalf("jwt.NewMAC() err = %v, want nil", err) + } + audience := "audience" + validator, err := jwt.NewValidator( + &jwt.ValidatorOpts{ExpectedAudience: &audience, AllowMissingExpiration: true}) + if err != nil { + t.Fatalf("jwt.NewValidator() err = %v, want nil", err) + } + if _, err := p.VerifyMACAndDecode("", validator); err == nil { + t.Errorf("p.VerifyMACAndDecode() err = nil, want error") + } + failures := client.Failures() + if len(failures) != 1 { + t.Errorf("len(client.Failures()) = %d, want = 1", len(failures)) + } + if len(client.Events()) != 0 { + t.Errorf("len(client.Events()) = %d, want = 0", len(client.Events())) + } +}