Skip to content

Commit

Permalink
Add global registry for primitive constructor functions in `internal/…
Browse files Browse the repository at this point in the history
…registryconfig`

We can register functions that generate a primitive from keys. This is defined in `registryconfig` to avoid adding new public exported APIs to `registry`.
For all other keys, we fallback to using the key managers.

PiperOrigin-RevId: 678675026
Change-Id: I8439e5e83723130b2c69d54efab08aa301a20dab
  • Loading branch information
morambro authored and copybara-github committed Sep 25, 2024
1 parent 439ee81 commit 0aadc94
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 20 deletions.
55 changes: 48 additions & 7 deletions internal/registryconfig/registry_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,68 @@
package registryconfig

import (
"fmt"
"reflect"
"sync"

"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/internal/internalapi"
"github.com/tink-crypto/tink-go/v2/internal/protoserialization"
"github.com/tink-crypto/tink-go/v2/key"
)

var (
primitiveConstructorsMu sync.RWMutex
primitiveConstructors = make(map[reflect.Type]primitiveConstructor)
)

type primitiveConstructor func(key key.Key) (any, error)

// RegistryConfig is an internal way for the keyset handle to access the
// old global Registry through the new Configuration interface.
type RegistryConfig struct{}

// PrimitiveFromKey creates a primitive from a Key using the Registry.
// PrimitiveFromKey constructs a primitive from a [key.Key] using the registry.
func (c *RegistryConfig) PrimitiveFromKey(key key.Key, _ internalapi.Token) (any, error) {
keySerialization, err := protoserialization.SerializeKey(key)
if err != nil {
return nil, err
if key == nil {
return nil, fmt.Errorf("key is nil")
}
constructor, found := primitiveConstructors[reflect.TypeOf(key)]
if !found {
// Fallback to using the key manager.
keySerialization, err := protoserialization.SerializeKey(key)
if err != nil {
return nil, err
}
return registry.PrimitiveFromKeyData(keySerialization.KeyData())
}
return registry.PrimitiveFromKeyData(keySerialization.KeyData())
return constructor(key)
}

// RegisterKeyManager registers a provided KeyManager by forwarding it directly
// to the Registry.
// RegisterKeyManager registers a provided [registry.KeyManager] by forwarding
// it directly to the Registry.
func (c *RegistryConfig) RegisterKeyManager(km registry.KeyManager, _ internalapi.Token) error {
return registry.RegisterKeyManager(km)
}

// RegisterPrimitiveConstructor registers a function that constructs primitives
// from a given [key.Key] to the global registry.
func RegisterPrimitiveConstructor[K key.Key](constructor primitiveConstructor) error {
keyType := reflect.TypeFor[K]()
primitiveConstructorsMu.Lock()
defer primitiveConstructorsMu.Unlock()
if existingCreator, found := primitiveConstructors[keyType]; found && reflect.ValueOf(existingCreator).Pointer() != reflect.ValueOf(constructor).Pointer() {
return fmt.Errorf("a different constructor already registered for %v", keyType)
}
primitiveConstructors[keyType] = constructor
return nil
}

// ClearPrimitiveConstructors clears the registry of primitive constructors.
//
// This function is intended to be used in tests only.
func ClearPrimitiveConstructors() {
primitiveConstructorsMu.Lock()
defer primitiveConstructorsMu.Unlock()
primitiveConstructors = make(map[reflect.Type]primitiveConstructor)
}
140 changes: 127 additions & 13 deletions internal/registryconfig/registry_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package registryconfig_test

import (
"fmt"
"testing"

"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -96,28 +97,141 @@ func TestPrimitiveFromKeyErrors(t *testing.T) {
}
}

type testPrimitive struct{}
type testKeyManager struct{}
type stubPrimitive struct{}
type stubKeyManager struct{}

func (km *testKeyManager) Primitive(_ []byte) (any, error) { return &testPrimitive{}, nil }
func (km *testKeyManager) NewKey(_ []byte) (proto.Message, error) { return nil, nil }
func (km *testKeyManager) DoesSupport(typeURL string) bool { return typeURL == "testKeyManager" }
func (km *testKeyManager) TypeURL() string { return "testKeyManager" }
func (km *testKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil }
func (km *stubKeyManager) Primitive(_ []byte) (any, error) { return &stubPrimitive{}, nil }
func (km *stubKeyManager) NewKey(_ []byte) (proto.Message, error) { return nil, nil }
func (km *stubKeyManager) DoesSupport(typeURL string) bool { return typeURL == "stubKeyManager" }
func (km *stubKeyManager) TypeURL() string { return "stubKeyManager" }
func (km *stubKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil }

func TestRegisterKeyManager(t *testing.T) {
registryConfig := &registryconfig.RegistryConfig{}
if err := registryConfig.RegisterKeyManager(new(testKeyManager), internalapi.Token{}); err != nil {
if err := registryConfig.RegisterKeyManager(new(stubKeyManager), internalapi.Token{}); err != nil {
t.Fatalf("registryConfig.RegisterKeyManager() err = %v, want nil", err)
}
if _, err := registry.GetKeyManager("testKeyManager"); err != nil {
t.Fatalf("registry.GetKeyManager(\"testKeyManager\") err = %v, want nil", err)
if _, err := registry.GetKeyManager("stubKeyManager"); err != nil {
t.Fatalf("registry.GetKeyManager(\"stubKeyManager\") err = %v, want nil", err)
}
primitive, err := registry.Primitive(new(testKeyManager).TypeURL(), []byte{0, 1, 2, 3})
primitive, err := registry.Primitive(new(stubKeyManager).TypeURL(), []byte{0, 1, 2, 3})
if err != nil {
t.Fatalf("registry.Primitive() err = %v, want nil", err)
}
if _, ok := primitive.(*testPrimitive); !ok {
t.Error("primitive is not of type *testPrimitive")
if _, ok := primitive.(*stubPrimitive); !ok {
t.Error("primitive is not of type *stubPrimitive")
}
}

type stubKey struct{}

func (k *stubKey) Parameters() key.Parameters { return nil }
func (k *stubKey) Equals(other key.Key) bool { return true }
func (k *stubKey) IDRequirement() (uint32, bool) { return 123, true }

// stubPrimitiveConstructor creates a stubPrimitive from a stubKey.
func stubPrimitiveConstructor(k key.Key) (any, error) {
_, ok := k.(*stubKey)
if !ok {
return nil, fmt.Errorf("key is of type %T; needed *stubKey", k)
}
return &stubPrimitive{}, nil
}

// anotherStubPrimitiveConstructor creates a stubPrimitive from a stubKey.
func anotherStubPrimitiveConstructor(k key.Key) (any, error) {
return stubPrimitiveConstructor(k)
}

func alwaysFailingStubPrimitiveConstructor(k key.Key) (any, error) {
return nil, fmt.Errorf("I always fail :(")
}

func TestRegisterPrimitiveConstructor(t *testing.T) {
defer registryconfig.ClearPrimitiveConstructors()
if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil {
t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err)
}
rc := &registryconfig.RegistryConfig{}
primitive, err := rc.PrimitiveFromKey(new(stubKey), internalapi.Token{})
if err != nil {
t.Fatalf("rc.PrimitiveFromKey() err = %v, want nil", err)
}
if _, ok := primitive.(*stubPrimitive); !ok {
t.Error("primitive is not of type *stubPrimitive")
}
}

func stubPrimitiveConstructorFromFallbackProtoKey(k key.Key) (any, error) {
_, ok := k.(*protoserialization.FallbackProtoKey)
if !ok {
return nil, fmt.Errorf("key is of type %T; needed *protoserialization.FallbackProtoKey", k)
}
return &stubPrimitive{}, nil
}

func TestRegisterPrimitiveConstructorUsesCreatorFirst(t *testing.T) {
defer registryconfig.ClearPrimitiveConstructors()
keyset, err := keyset.NewHandle(mac.HMACSHA256Tag256KeyTemplate())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
entry, err := keyset.Entry(0)
if err != nil {
t.Fatalf("keyset.Entry() err = %v, want nil", err)
}

registryConfig := &registryconfig.RegistryConfig{}
p, err := registryConfig.PrimitiveFromKey(entry.Key(), internalapi.Token{})
if err != nil {
t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err)
}
if _, ok := p.(*subtle.HMAC); !ok {
t.Error("p is not of type *subtle.HMAC")
}

rc := &registryconfig.RegistryConfig{}
// We now register a constructor for protoserialization.FallbackProtoKey that
// returns a stubPrimitive instead of a HMAC.
if err := registryconfig.RegisterPrimitiveConstructor[*protoserialization.FallbackProtoKey](stubPrimitiveConstructorFromFallbackProtoKey); err != nil {
t.Errorf("registryconfig.RegisterPrimitiveConstructor[*protoserialization.FallbackProtoKey](stubPrimitiveConstructorFromFallbackProtoKey) err = %v, want nil", err)
}
p, err = rc.PrimitiveFromKey(entry.Key(), internalapi.Token{})
if err != nil {
t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err)
}
if _, ok := p.(*stubPrimitive); !ok {
t.Error("p is not of type *stubPrimitive")
}
}

func TestPrimitiveFromKeyFailsIfCreatorFails(t *testing.T) {
defer registryconfig.ClearPrimitiveConstructors()
if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](alwaysFailingStubPrimitiveConstructor); err != nil {
t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](alwaysFailingStubPrimitiveConstructor) err = %v, want nil", err)
}
rc := &registryconfig.RegistryConfig{}
if _, err := rc.PrimitiveFromKey(new(stubKey), internalapi.Token{}); err == nil {
t.Errorf("rc.PrimitiveFromKey() err = nil, want error")
}
}

func TestRegisterPrimitiveConstructorSucceedsIfDoubleRegister(t *testing.T) {
defer registryconfig.ClearPrimitiveConstructors()
if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil {
t.Fatalf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err)
}
if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil {
t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err)
}
}

func TestRegisterPrimitiveConstructorFailsIfRegisterAnotherCreatorForSameKeyType(t *testing.T) {
defer registryconfig.ClearPrimitiveConstructors()
if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil {
t.Fatalf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err)
}
if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](anotherStubPrimitiveConstructor); err == nil {
t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](anotherStubPrimitiveConstructor) err = nil, want error")
}
}
4 changes: 4 additions & 0 deletions keyset/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,10 @@ func TestPrimitivesWithRegistry(t *testing.T) {

type testConfig struct{}

func (c *testConfig) PrimitiveFromKeyData(_ *tinkpb.KeyData, _ internalapi.Token) (any, error) {
return testPrimitive{}, nil
}

func (c *testConfig) PrimitiveFromKey(_ key.Key, _ internalapi.Token) (any, error) {
return testPrimitive{}, nil
}
Expand Down

0 comments on commit 0aadc94

Please sign in to comment.