diff --git a/internal/protoserialization/BUILD.bazel b/internal/protoserialization/BUILD.bazel index 205c376..4baf73b 100644 --- a/internal/protoserialization/BUILD.bazel +++ b/internal/protoserialization/BUILD.bazel @@ -20,5 +20,6 @@ go_test( deps = [ ":protoserialization", "//proto/tink_go_proto", + "@org_golang_google_protobuf//proto", ], ) diff --git a/internal/protoserialization/protoserialization.go b/internal/protoserialization/protoserialization.go index f9ad621..59d3fbf 100644 --- a/internal/protoserialization/protoserialization.go +++ b/internal/protoserialization/protoserialization.go @@ -19,14 +19,17 @@ package protoserialization import ( "fmt" + "reflect" "sync" tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) var ( - keyParsersMu sync.RWMutex - keyParsers = make(map[string]KeyParser) // TypeURL -> KeyParser + keyParsersMu sync.RWMutex + keyParsers = make(map[string]KeyParser) // TypeURL -> KeyParser + keySerializersMu sync.RWMutex + keySerializers = make(map[reflect.Type]KeySerializer) // KeyType -> KeySerializer ) // FallbackProtoKey is a key that wraps a proto keyset key. @@ -55,7 +58,37 @@ type KeyParser interface { ParseKey(keysetKey *tinkpb.Keyset_Key) (any, error) } -// RegisterKeyParser registers the given key parser to the global registry. +// KeySerializer is an interface for serializing a key into a proto keyset key. +type KeySerializer interface { + // SerializeKey serializes the given key into a proto keyset key. + SerializeKey(key any) (*tinkpb.Keyset_Key, error) +} + +// RegisterKeySerializer registers the given key serializer for keys of type K. +// +// It doesn't allow replacing existing serializers. +func RegisterKeySerializer[K any](keySerializer KeySerializer) error { + keySerializersMu.Lock() + defer keySerializersMu.Unlock() + keyType := reflect.TypeOf((*K)(nil)).Elem() + if _, found := keySerializers[keyType]; found { + return fmt.Errorf("serialization.RegisterKeySerializer: type %v already registered", keyType) + } + keySerializers[keyType] = keySerializer + return nil +} + +// SerializeKey serializes the given key into a proto keyset key. +func SerializeKey(key any) (*tinkpb.Keyset_Key, error) { + keyType := reflect.TypeOf(key) + serializer, ok := keySerializers[keyType] + if !ok { + return nil, fmt.Errorf("serialization.SerializeKey: no serializer for type %v", keyType) + } + return serializer.SerializeKey(key) +} + +// RegisterKeyParser registers the given key parser. // // It doesn't allow replacing existing parsers. func RegisterKeyParser(keyTypeURL string, keyParser KeyParser) error { @@ -79,11 +112,37 @@ func ParseKey(keysetKey *tinkpb.Keyset_Key) (any, error) { return parser.ParseKey(keysetKey) } -// ClearKeyParsers clears the global parsers registry. +type fallbackProtoKeySerializer struct{} + +func (s *fallbackProtoKeySerializer) SerializeKey(key any) (*tinkpb.Keyset_Key, error) { + fallbackKey, ok := key.(*FallbackProtoKey) + if !ok { + return nil, fmt.Errorf("serialization.fallbackProtoKeySerializer.SerializeKey: key is not a FallbackProtoKey") + } + return fallbackKey.protoKeysetKey, nil +} + +// ClearKeyParsers clears the global key parsers registry. // // This function is intended to be used in tests only. func ClearKeyParsers() { keyParsersMu.Lock() defer keyParsersMu.Unlock() - keyParsers = make(map[string]KeyParser) + clear(keyParsers) +} + +// ReinitializeKeySerializers clears the global key serializers registry and registers +// fallbackProtoKeySerializer. +// +// This function is intended to be used in tests only. +func ReinitializeKeySerializers() { + keySerializersMu.Lock() + defer keySerializersMu.Unlock() + clear(keySerializers) + // Always register the fallback serializer. + keySerializers[reflect.TypeOf((*FallbackProtoKey)(nil))] = &fallbackProtoKeySerializer{} +} + +func init() { + RegisterKeySerializer[*FallbackProtoKey](&fallbackProtoKeySerializer{}) } diff --git a/internal/protoserialization/protoserialization_test.go b/internal/protoserialization/protoserialization_test.go index 04afc62..f142457 100644 --- a/internal/protoserialization/protoserialization_test.go +++ b/internal/protoserialization/protoserialization_test.go @@ -17,8 +17,10 @@ package protoserialization_test import ( "bytes" "errors" + "fmt" "testing" + "google.golang.org/protobuf/proto" "github.com/tink-crypto/tink-go/v2/internal/protoserialization" tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) @@ -27,7 +29,8 @@ var ( testKeyURL = "test-key-url" testKeyURL2 = "test-key-url-2" - ErrKeyParsing = errors.New("key parsing failed") + ErrKeyParsing = errors.New("key parsing failed") + ErrKeySerialization = errors.New("key serialization failed") ) type testKey struct { @@ -42,6 +45,24 @@ func (p *testParser) ParseKey(keysetKey *tinkpb.Keyset_Key) (any, error) { }, nil } +var _ protoserialization.KeyParser = (*testParser)(nil) + +type testSerializer struct{} + +func (s *testSerializer) SerializeKey(key any) (*tinkpb.Keyset_Key, error) { + actualKey, ok := key.(*testKey) + if !ok { + return nil, fmt.Errorf("type mismatch: got %T, want *testKey", key) + } + return &tinkpb.Keyset_Key{ + KeyData: &tinkpb.KeyData{ + Value: actualKey.keyBytes, + }, + }, nil +} + +var _ protoserialization.KeySerializer = (*testSerializer)(nil) + func TestRegisterKeyParserFailsIfAlreadyRegistered(t *testing.T) { defer protoserialization.ClearKeyParsers() err := protoserialization.RegisterKeyParser(testKeyURL, &testParser{}) @@ -148,3 +169,87 @@ func TestParseKeyFailsIfParserFails(t *testing.T) { t.Errorf("protoserialization.ParseKey(%s) err = %v, want %v", testKeyURL, err, ErrKeyParsing) } } + +func TestRegisterKeySerializerAndSerializeKey(t *testing.T) { + defer protoserialization.ReinitializeKeySerializers() + err := protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) + if err != nil { + t.Fatalf("protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) err = %v, want nil", err) + } + + key := &testKey{ + keyBytes: []byte("123"), + } + gotKeysetKey, err := protoserialization.SerializeKey(key) + if err != nil { + t.Fatalf("protoserialization.SerializeKey(%v) err = %v, want nil", key, err) + } + if !bytes.Equal(gotKeysetKey.GetKeyData().GetValue(), key.keyBytes) { + t.Errorf("bytes.Equal(%v, %v) = false, want true", gotKeysetKey.GetKeyData().GetValue(), key.keyBytes) + } +} + +func TestRegisterKeySerializerFailsIfAlreadyRegistered(t *testing.T) { + defer protoserialization.ReinitializeKeySerializers() + err := protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) + if err != nil { + t.Fatalf("protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) err = %v, want nil", err) + } + if protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) == nil { + t.Errorf("protoserialization.RegisterKeySerializer[*testKey](&testSerializer{}) err = nil, want error") + } +} + +func TestSerializeKeyFailsIfNoSerializersRegistered(t *testing.T) { + defer protoserialization.ReinitializeKeySerializers() + key := &testKey{ + keyBytes: []byte("123"), + } + if _, err := protoserialization.SerializeKey(key); err == nil { + t.Errorf("protoserialization.SerializeKey(%v) err = nil, want error", key) + } +} + +func TestSerializeKeyWithFallbackKey(t *testing.T) { + defer protoserialization.ReinitializeKeySerializers() + wantProtoKey := &tinkpb.Keyset_Key{ + KeyData: &tinkpb.KeyData{ + TypeUrl: testKeyURL, + Value: []byte("123"), + }, + } + key := protoserialization.NewFallbackProtoKey(wantProtoKey) + gotProtoKey, err := protoserialization.SerializeKey(key) + if err != nil { + t.Fatalf("protoserialization.SerializeKey(%v) err = %v, want nil", key, err) + } + if !proto.Equal(gotProtoKey, wantProtoKey) { + t.Errorf("proto.Equal(%v, %v) = false, want true", gotProtoKey, wantProtoKey) + } +} + +type alwaysFailingSerializer struct{} + +func (s *alwaysFailingSerializer) SerializeKey(key any) (*tinkpb.Keyset_Key, error) { + return nil, ErrKeySerialization +} + +var _ protoserialization.KeySerializer = (*alwaysFailingSerializer)(nil) + +func TestSerializeKeyFailsIfSerializeFails(t *testing.T) { + defer protoserialization.ReinitializeKeySerializers() + err := protoserialization.RegisterKeySerializer[*testKey](&alwaysFailingSerializer{}) + if err != nil { + t.Fatalf("protoserialization.RegisterKeySerializer[*testKey](&alwaysFailingSerializer{}) err = %v, want nil", err) + } + key := &testKey{ + keyBytes: []byte("123"), + } + _, err = protoserialization.SerializeKey(key) + if err == nil { + t.Errorf("protoserialization.SerializeKey(%v) err = nil, want error", key) + } + if !errors.Is(err, ErrKeySerialization) { + t.Errorf("protoserialization.SerializeKey(%v) err = %v, want %v", key, err, ErrKeyParsing) + } +}