Skip to content

Commit

Permalink
Add internal APIs to serialize a key to tinkpb.Keyset_Key
Browse files Browse the repository at this point in the history
This is needed to allow adding new APIs to `keyset` to construct keysets from existing keys.

PiperOrigin-RevId: 644704833
Change-Id: I3761333c859288f9e4022b38adba8a36cc755e0d
  • Loading branch information
morambro authored and copybara-github committed Jun 19, 2024
1 parent 3b9f7c1 commit a9ef7a6
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 6 deletions.
1 change: 1 addition & 0 deletions internal/protoserialization/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ go_test(
deps = [
":protoserialization",
"//proto/tink_go_proto",
"@org_golang_google_protobuf//proto",
],
)
69 changes: 64 additions & 5 deletions internal/protoserialization/protoserialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ package protoserialization

import (
"fmt"
"reflect"
"sync"

tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)

var (
keyParsersMu sync.RWMutex
keyParsers = make(map[string]KeyParser) // TypeURL -> KeyParser
keyParsersMu sync.RWMutex
keyParsers = make(map[string]KeyParser) // TypeURL -> KeyParser
keySerializersMu sync.RWMutex
keySerializers = make(map[reflect.Type]KeySerializer) // KeyType -> KeySerializer
)

// FallbackProtoKey is a key that wraps a proto keyset key.
Expand Down Expand Up @@ -55,7 +58,37 @@ type KeyParser interface {
ParseKey(keysetKey *tinkpb.Keyset_Key) (any, error)
}

// RegisterKeyParser registers the given key parser to the global registry.
// KeySerializer is an interface for serializing a key into a proto keyset key.
type KeySerializer interface {
// SerializeKey serializes the given key into a proto keyset key.
SerializeKey(key any) (*tinkpb.Keyset_Key, error)
}

// RegisterKeySerializer registers the given key serializer for keys of type K.
//
// It doesn't allow replacing existing serializers.
func RegisterKeySerializer[K any](keySerializer KeySerializer) error {
keySerializersMu.Lock()
defer keySerializersMu.Unlock()
keyType := reflect.TypeOf((*K)(nil)).Elem()
if _, found := keySerializers[keyType]; found {
return fmt.Errorf("serialization.RegisterKeySerializer: type %v already registered", keyType)
}
keySerializers[keyType] = keySerializer
return nil
}

// SerializeKey serializes the given key into a proto keyset key.
func SerializeKey(key any) (*tinkpb.Keyset_Key, error) {
keyType := reflect.TypeOf(key)
serializer, ok := keySerializers[keyType]
if !ok {
return nil, fmt.Errorf("serialization.SerializeKey: no serializer for type %v", keyType)
}
return serializer.SerializeKey(key)
}

// RegisterKeyParser registers the given key parser.
//
// It doesn't allow replacing existing parsers.
func RegisterKeyParser(keyTypeURL string, keyParser KeyParser) error {
Expand All @@ -79,11 +112,37 @@ func ParseKey(keysetKey *tinkpb.Keyset_Key) (any, error) {
return parser.ParseKey(keysetKey)
}

// ClearKeyParsers clears the global parsers registry.
type fallbackProtoKeySerializer struct{}

func (s *fallbackProtoKeySerializer) SerializeKey(key any) (*tinkpb.Keyset_Key, error) {
fallbackKey, ok := key.(*FallbackProtoKey)
if !ok {
return nil, fmt.Errorf("serialization.fallbackProtoKeySerializer.SerializeKey: key is not a FallbackProtoKey")
}
return fallbackKey.protoKeysetKey, nil
}

// ClearKeyParsers clears the global key parsers registry.
//
// This function is intended to be used in tests only.
func ClearKeyParsers() {
keyParsersMu.Lock()
defer keyParsersMu.Unlock()
keyParsers = make(map[string]KeyParser)
clear(keyParsers)
}

// ReinitializeKeySerializers clears the global key serializers registry and registers
// fallbackProtoKeySerializer.
//
// This function is intended to be used in tests only.
func ReinitializeKeySerializers() {
keySerializersMu.Lock()
defer keySerializersMu.Unlock()
clear(keySerializers)
// Always register the fallback serializer.
keySerializers[reflect.TypeOf((*FallbackProtoKey)(nil))] = &fallbackProtoKeySerializer{}
}

func init() {
RegisterKeySerializer[*FallbackProtoKey](&fallbackProtoKeySerializer{})
}
107 changes: 106 additions & 1 deletion internal/protoserialization/protoserialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package protoserialization_test
import (
"bytes"
"errors"
"fmt"
"testing"

"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/internal/protoserialization"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)
Expand All @@ -27,7 +29,8 @@ var (
testKeyURL = "test-key-url"
testKeyURL2 = "test-key-url-2"

ErrKeyParsing = errors.New("key parsing failed")
ErrKeyParsing = errors.New("key parsing failed")
ErrKeySerialization = errors.New("key serialization failed")
)

type testKey struct {
Expand All @@ -42,6 +45,24 @@ func (p *testParser) ParseKey(keysetKey *tinkpb.Keyset_Key) (any, error) {
}, nil
}

var _ protoserialization.KeyParser = (*testParser)(nil)

type testSerializer struct{}

func (s *testSerializer) SerializeKey(key any) (*tinkpb.Keyset_Key, error) {
actualKey, ok := key.(*testKey)
if !ok {
return nil, fmt.Errorf("type mismatch: got %T, want *testKey", key)
}
return &tinkpb.Keyset_Key{
KeyData: &tinkpb.KeyData{
Value: actualKey.keyBytes,
},
}, nil
}

var _ protoserialization.KeySerializer = (*testSerializer)(nil)

func TestRegisterKeyParserFailsIfAlreadyRegistered(t *testing.T) {
defer protoserialization.ClearKeyParsers()
err := protoserialization.RegisterKeyParser(testKeyURL, &testParser{})
Expand Down Expand Up @@ -148,3 +169,87 @@ func TestParseKeyFailsIfParserFails(t *testing.T) {
t.Errorf("protoserialization.ParseKey(%s) err = %v, want %v", testKeyURL, err, ErrKeyParsing)
}
}

func TestRegisterKeySerializerAndSerializeKey(t *testing.T) {
defer protoserialization.ReinitializeKeySerializers()
err := protoserialization.RegisterKeySerializer[*testKey](&testSerializer{})
if err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) err = %v, want nil", err)
}

key := &testKey{
keyBytes: []byte("123"),
}
gotKeysetKey, err := protoserialization.SerializeKey(key)
if err != nil {
t.Fatalf("protoserialization.SerializeKey(%v) err = %v, want nil", key, err)
}
if !bytes.Equal(gotKeysetKey.GetKeyData().GetValue(), key.keyBytes) {
t.Errorf("bytes.Equal(%v, %v) = false, want true", gotKeysetKey.GetKeyData().GetValue(), key.keyBytes)
}
}

func TestRegisterKeySerializerFailsIfAlreadyRegistered(t *testing.T) {
defer protoserialization.ReinitializeKeySerializers()
err := protoserialization.RegisterKeySerializer[*testKey](&testSerializer{})
if err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) err = %v, want nil", err)
}
if protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) == nil {
t.Errorf("protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) err = nil, want error")
}
}

func TestSerializeKeyFailsIfNoSerializersRegistered(t *testing.T) {
defer protoserialization.ReinitializeKeySerializers()
key := &testKey{
keyBytes: []byte("123"),
}
if _, err := protoserialization.SerializeKey(key); err == nil {
t.Errorf("protoserialization.SerializeKey(%v) err = nil, want error", key)
}
}

func TestSerializeKeyWithFallbackKey(t *testing.T) {
defer protoserialization.ReinitializeKeySerializers()
wantProtoKey := &tinkpb.Keyset_Key{
KeyData: &tinkpb.KeyData{
TypeUrl: testKeyURL,
Value: []byte("123"),
},
}
key := protoserialization.NewFallbackProtoKey(wantProtoKey)
gotProtoKey, err := protoserialization.SerializeKey(key)
if err != nil {
t.Fatalf("protoserialization.SerializeKey(%v) err = %v, want nil", key, err)
}
if !proto.Equal(gotProtoKey, wantProtoKey) {
t.Errorf("proto.Equal(%v, %v) = false, want true", gotProtoKey, wantProtoKey)
}
}

type alwaysFailingSerializer struct{}

func (s *alwaysFailingSerializer) SerializeKey(key any) (*tinkpb.Keyset_Key, error) {
return nil, ErrKeySerialization
}

var _ protoserialization.KeySerializer = (*alwaysFailingSerializer)(nil)

func TestSerializeKeyFailsIfSerializeFails(t *testing.T) {
defer protoserialization.ReinitializeKeySerializers()
err := protoserialization.RegisterKeySerializer[*testKey](&alwaysFailingSerializer{})
if err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer[*testKey](&alwaysFailingSerializer{}) err = %v, want nil", err)
}
key := &testKey{
keyBytes: []byte("123"),
}
_, err = protoserialization.SerializeKey(key)
if err == nil {
t.Errorf("protoserialization.SerializeKey(%v) err = nil, want error", key)
}
if !errors.Is(err, ErrKeySerialization) {
t.Errorf("protoserialization.SerializeKey(%v) err = %v, want %v", key, err, ErrKeyParsing)
}
}

0 comments on commit a9ef7a6

Please sign in to comment.