Skip to content

Commit

Permalink
Add crypto monitoring to JWT MAC
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 639762747
Change-Id: I0d45fd1995650d75ec1e31b5efd3bb688e66fa86
  • Loading branch information
fernandolobato authored and copybara-github committed Jun 3, 2024
1 parent 7b0db1d commit 20ba8b8
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 3 deletions.
6 changes: 6 additions & 0 deletions jwt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ go_library(
deps = [
"//core/primitiveset",
"//core/registry",
"//internal/internalregistry",
"//internal/monitoringutil",
"//internal/signature",
"//internal/tinkerror",
"//keyset",
"//mac/subtle",
"//monitoring",
"//proto/jwt_ecdsa_go_proto",
"//proto/jwt_hmac_go_proto",
"//proto/jwt_rsa_ssa_pkcs1_go_proto",
Expand Down Expand Up @@ -82,8 +85,10 @@ go_test(
deps = [
"//core/registry",
"//insecurecleartextkeyset",
"//internal/internalregistry",
"//keyset",
"//mac/subtle",
"//monitoring",
"//proto/ecdsa_go_proto",
"//proto/jwt_ecdsa_go_proto",
"//proto/jwt_hmac_go_proto",
Expand All @@ -93,6 +98,7 @@ go_test(
"//signature",
"//signature/subtle",
"//subtle/random",
"//testing/fakemonitoring",
"//testkeyset",
"//testutil",
"//tink",
Expand Down
51 changes: 48 additions & 3 deletions jwt/jwt_mac_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"fmt"

"github.com/tink-crypto/tink-go/v2/core/primitiveset"
"github.com/tink-crypto/tink-go/v2/internal/internalregistry"
"github.com/tink-crypto/tink-go/v2/internal/monitoringutil"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/monitoring"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)

Expand All @@ -36,7 +39,9 @@ func NewMAC(handle *keyset.Handle) (MAC, error) {

// wrappedJWTMAC is a JWTMAC implementation that uses the underlying primitive set for JWT MAC.
type wrappedJWTMAC struct {
ps *primitiveset.PrimitiveSet
ps *primitiveset.PrimitiveSet
computeLogger monitoring.Logger
verifyLogger monitoring.Logger
}

var _ MAC = (*wrappedJWTMAC)(nil)
Expand All @@ -55,7 +60,39 @@ func newWrappedJWTMAC(ps *primitiveset.PrimitiveSet) (*wrappedJWTMAC, error) {
}
}
}
return &wrappedJWTMAC{ps: ps}, nil
computeLogger, verifyLogger, err := createLoggers(ps)
if err != nil {
return nil, err
}
return &wrappedJWTMAC{ps: ps, computeLogger: computeLogger, verifyLogger: verifyLogger}, nil
}

func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring.Logger, error) {
if len(ps.Annotations) == 0 {
return &monitoringutil.DoNothingLogger{}, &monitoringutil.DoNothingLogger{}, nil
}
client := internalregistry.GetMonitoringClient()
keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps)
if err != nil {
return nil, nil, err
}
computeLogger, err := client.NewLogger(&monitoring.Context{
Primitive: "jwtmac",
APIFunction: "compute",
KeysetInfo: keysetInfo,
})
if err != nil {
return nil, nil, err
}
verifyLogger, err := client.NewLogger(&monitoring.Context{
Primitive: "jwtmac",
APIFunction: "verify",
KeysetInfo: keysetInfo,
})
if err != nil {
return nil, nil, err
}
return computeLogger, verifyLogger, nil
}

func (w *wrappedJWTMAC) ComputeMACAndEncode(token *RawJWT) (string, error) {
Expand All @@ -64,7 +101,13 @@ func (w *wrappedJWTMAC) ComputeMACAndEncode(token *RawJWT) (string, error) {
if !ok {
return "", fmt.Errorf("jwt_mac_factory: not a JWT MAC primitive")
}
return p.ComputeMACAndEncodeWithKID(token, keyID(primary.KeyID, primary.PrefixType))
signedToken, err := p.ComputeMACAndEncodeWithKID(token, keyID(primary.KeyID, primary.PrefixType))
if err != nil {
w.computeLogger.LogFailure()
return "", err
}
w.computeLogger.Log(primary.KeyID, 1)
return signedToken, nil
}

func (w *wrappedJWTMAC) VerifyMACAndDecode(compact string, validator *Validator) (*VerifiedJWT, error) {
Expand All @@ -77,6 +120,7 @@ func (w *wrappedJWTMAC) VerifyMACAndDecode(compact string, validator *Validator)
}
verifiedJWT, err := p.VerifyMACAndDecodeWithKID(compact, validator, keyID(e.KeyID, e.PrefixType))
if err == nil {
w.verifyLogger.Log(e.KeyID, 1)
return verifiedJWT, nil
}
if err != errJwtVerification {
Expand All @@ -85,6 +129,7 @@ func (w *wrappedJWTMAC) VerifyMACAndDecode(compact string, validator *Validator)
}
}
}
w.verifyLogger.LogFailure()
if interestingErr != nil {
return nil, interestingErr
}
Expand Down
207 changes: 207 additions & 0 deletions jwt/jwt_mac_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
package jwt_test

import (
"bytes"
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/insecurecleartextkeyset"
"github.com/tink-crypto/tink-go/v2/internal/internalregistry"
"github.com/tink-crypto/tink-go/v2/jwt"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/monitoring"
"github.com/tink-crypto/tink-go/v2/signature"
"github.com/tink-crypto/tink-go/v2/subtle/random"
"github.com/tink-crypto/tink-go/v2/testing/fakemonitoring"
"github.com/tink-crypto/tink-go/v2/testkeyset"
"github.com/tink-crypto/tink-go/v2/testutil"

Expand Down Expand Up @@ -277,3 +283,204 @@ func TestVerifyMACAndDecodeReturnsValidationError(t *testing.T) {
t.Errorf("p.VerifyMACAndDecode() err = %q, want %q", err.Error(), wantErr)
}
}

func TestComputeAndVerifyWithoutAnnotationsEmitsNoMonitoring(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
kh, err := keyset.NewHandle(jwt.HS256Template())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
p, err := jwt.NewMAC(kh)
if err != nil {
t.Fatalf("jwt.NewMAC() err = %v, want nil", err)
}
audience := "audience"
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{Audience: &audience, WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
token, err := p.ComputeMACAndEncode(rawJWT)
if err != nil {
t.Errorf("p.ComputeMACAndEncode() err = %v, want nil", err)
}
validator, err := jwt.NewValidator(
&jwt.ValidatorOpts{ExpectedAudience: &audience, AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
if _, err = p.VerifyMACAndDecode(token, validator); err != nil {
t.Errorf("p.VerifyMACAndDecode() err = %v, want error", err)
}
if len(client.Failures()) != 0 {
t.Errorf("len(client.Failures()) = %d, want = 0", len(client.Failures()))
}
if len(client.Events()) != 0 {
t.Errorf("len(client.Events()) = %d, want = 0", len(client.Events()))
}
}

func TestComputeAndVerifyWithAnnotationsEmitsMonitoring(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
kh, err := keyset.NewHandle(jwt.HS256Template())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
// Annotations are only supported through the `insecurecleartextkeyset` API.
buff := &bytes.Buffer{}
if err := insecurecleartextkeyset.Write(kh, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
annotations := map[string]string{"foo": "bar"}
mh, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
p, err := jwt.NewMAC(mh)
if err != nil {
t.Fatalf("jwt.NewMAC() err = %v, want nil", err)
}
audience := "audience"
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{Audience: &audience, WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
token, err := p.ComputeMACAndEncode(rawJWT)
if err != nil {
t.Errorf("p.ComputeMACAndEncode() err = %v, want nil", err)
}
validator, err := jwt.NewValidator(
&jwt.ValidatorOpts{ExpectedAudience: &audience, AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
if _, err = p.VerifyMACAndDecode(token, validator); err != nil {
t.Errorf("p.VerifyMACAndDecode() err = %v, want error", err)
}
failures := client.Failures()
if len(failures) != 0 {
t.Errorf("len(client.Failures()) = %d, want = 0", len(failures))
}
got := client.Events()
wantKeysetInfo := monitoring.NewKeysetInfo(
annotations,
kh.KeysetInfo().GetPrimaryKeyId(),
[]*monitoring.Entry{
{
KeyID: kh.KeysetInfo().GetPrimaryKeyId(),
Status: monitoring.Enabled,
KeyType: "tink.JwtHmacKey",
KeyPrefix: "TINK",
},
},
)
want := []*fakemonitoring.LogEvent{
{
KeyID: mh.KeysetInfo().GetPrimaryKeyId(),
NumBytes: 1,
Context: monitoring.NewContext("jwtmac", "compute", wantKeysetInfo),
},
{
KeyID: mh.KeysetInfo().GetPrimaryKeyId(),
NumBytes: 1,
Context: monitoring.NewContext("jwtmac", "verify", wantKeysetInfo),
},
}
if cmp.Diff(got, want) != "" {
t.Errorf("%v", cmp.Diff(got, want))
}
}

func TestComputeFailureEmitsMonitoring(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := &fakemonitoring.Client{Name: ""}
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
keyData, err := newKeyData(newJWTHMACKey(jwtmacpb.JwtHmacAlgorithm_HS256, &jwtmacpb.JwtHmacKey_CustomKid{Value: "custom-kid"}))
if err != nil {
t.Fatalf("creating NewKeyData: %v", err)
}
primaryKey := testutil.NewKey(keyData, tinkpb.KeyStatusType_ENABLED, 42, tinkpb.OutputPrefixType_TINK)
kh, err := testkeyset.NewHandle(testutil.NewKeyset(primaryKey.KeyId, []*tinkpb.Keyset_Key{primaryKey}))
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
// Annotations are only supported through the `insecurecleartextkeyset` API.
buff := &bytes.Buffer{}
if err := insecurecleartextkeyset.Write(kh, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
annotations := map[string]string{"foo": "bar"}
mh, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
p, err := jwt.NewMAC(mh)
if err != nil {
t.Fatalf("jwt.NewMAC() err = %v, want nil", err)
}
rawJWT, err := jwt.NewRawJWT(&jwt.RawJWTOptions{WithoutExpiration: true})
if err != nil {
t.Fatalf("jwt.NewRawJWT() err = %v, want nil", err)
}
if _, err := p.ComputeMACAndEncode(rawJWT); err == nil {
t.Errorf("p.ComputeMACAndEncode() err = nil, want error")
}
failures := client.Failures()
if len(failures) != 1 {
t.Errorf("len(client.Failures()) = %d, want = 1", len(failures))
}
if len(client.Events()) != 0 {
t.Errorf("len(client.Events()) = %d, want = 0", len(client.Events()))
}
}

func TestVerifyFailureEmitsMonitoring(t *testing.T) {
defer internalregistry.ClearMonitoringClient()
client := fakemonitoring.NewClient("fake-client")
if err := internalregistry.RegisterMonitoringClient(client); err != nil {
t.Fatalf("internalregistry.RegisterMonitoringClient() err = %v, want nil", err)
}
kh, err := keyset.NewHandle(jwt.HS256Template())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
// Annotations are only supported through the `insecurecleartextkeyset` API.
buff := &bytes.Buffer{}
if err := insecurecleartextkeyset.Write(kh, keyset.NewBinaryWriter(buff)); err != nil {
t.Fatalf("insecurecleartextkeyset.Write() err = %v, want nil", err)
}
annotations := map[string]string{"foo": "bar"}
mh, err := insecurecleartextkeyset.Read(keyset.NewBinaryReader(buff), keyset.WithAnnotations(annotations))
if err != nil {
t.Fatalf("insecurecleartextkeyset.Read() err = %v, want nil", err)
}
p, err := jwt.NewMAC(mh)
if err != nil {
t.Fatalf("jwt.NewMAC() err = %v, want nil", err)
}
audience := "audience"
validator, err := jwt.NewValidator(
&jwt.ValidatorOpts{ExpectedAudience: &audience, AllowMissingExpiration: true})
if err != nil {
t.Fatalf("jwt.NewValidator() err = %v, want nil", err)
}
if _, err := p.VerifyMACAndDecode("", validator); err == nil {
t.Errorf("p.VerifyMACAndDecode() err = nil, want error")
}
failures := client.Failures()
if len(failures) != 1 {
t.Errorf("len(client.Failures()) = %d, want = 1", len(failures))
}
if len(client.Events()) != 0 {
t.Errorf("len(client.Events()) = %d, want = 0", len(client.Events()))
}
}

0 comments on commit 20ba8b8

Please sign in to comment.