diff --git a/internal/registryconfig/registry_config.go b/internal/registryconfig/registry_config.go index 5991f6d..aa2152b 100644 --- a/internal/registryconfig/registry_config.go +++ b/internal/registryconfig/registry_config.go @@ -18,27 +18,68 @@ package registryconfig import ( + "fmt" + "reflect" + "sync" + "github.com/tink-crypto/tink-go/v2/core/registry" "github.com/tink-crypto/tink-go/v2/internal/internalapi" "github.com/tink-crypto/tink-go/v2/internal/protoserialization" "github.com/tink-crypto/tink-go/v2/key" ) +var ( + primitiveConstructorsMu sync.RWMutex + primitiveConstructors = make(map[reflect.Type]primitiveConstructor) +) + +type primitiveConstructor func(key key.Key) (any, error) + // RegistryConfig is an internal way for the keyset handle to access the // old global Registry through the new Configuration interface. type RegistryConfig struct{} -// PrimitiveFromKey creates a primitive from a Key using the Registry. +// PrimitiveFromKey constructs a primitive from a [key.Key] using the registry. func (c *RegistryConfig) PrimitiveFromKey(key key.Key, _ internalapi.Token) (any, error) { - keySerialization, err := protoserialization.SerializeKey(key) - if err != nil { - return nil, err + if key == nil { + return nil, fmt.Errorf("key is nil") + } + constructor, found := primitiveConstructors[reflect.TypeOf(key)] + if !found { + // Fallback to using the key manager. + keySerialization, err := protoserialization.SerializeKey(key) + if err != nil { + return nil, err + } + return registry.PrimitiveFromKeyData(keySerialization.KeyData()) } - return registry.PrimitiveFromKeyData(keySerialization.KeyData()) + return constructor(key) } -// RegisterKeyManager registers a provided KeyManager by forwarding it directly -// to the Registry. +// RegisterKeyManager registers a provided [registry.KeyManager] by forwarding +// it directly to the Registry. func (c *RegistryConfig) RegisterKeyManager(km registry.KeyManager, _ internalapi.Token) error { return registry.RegisterKeyManager(km) } + +// RegisterPrimitiveConstructor registers a function that constructs primitives +// from a given [key.Key] to the global registry. +func RegisterPrimitiveConstructor[K key.Key](constructor primitiveConstructor) error { + keyType := reflect.TypeFor[K]() + primitiveConstructorsMu.Lock() + defer primitiveConstructorsMu.Unlock() + if existingCreator, found := primitiveConstructors[keyType]; found && reflect.ValueOf(existingCreator).Pointer() != reflect.ValueOf(constructor).Pointer() { + return fmt.Errorf("a different constructor already registered for %v", keyType) + } + primitiveConstructors[keyType] = constructor + return nil +} + +// ClearPrimitiveConstructors clears the registry of primitive constructors. +// +// This function is intended to be used in tests only. +func ClearPrimitiveConstructors() { + primitiveConstructorsMu.Lock() + defer primitiveConstructorsMu.Unlock() + primitiveConstructors = make(map[reflect.Type]primitiveConstructor) +} diff --git a/internal/registryconfig/registry_config_test.go b/internal/registryconfig/registry_config_test.go index 648969c..156287c 100644 --- a/internal/registryconfig/registry_config_test.go +++ b/internal/registryconfig/registry_config_test.go @@ -15,6 +15,7 @@ package registryconfig_test import ( + "fmt" "testing" "google.golang.org/protobuf/proto" @@ -96,28 +97,141 @@ func TestPrimitiveFromKeyErrors(t *testing.T) { } } -type testPrimitive struct{} -type testKeyManager struct{} +type stubPrimitive struct{} +type stubKeyManager struct{} -func (km *testKeyManager) Primitive(_ []byte) (any, error) { return &testPrimitive{}, nil } -func (km *testKeyManager) NewKey(_ []byte) (proto.Message, error) { return nil, nil } -func (km *testKeyManager) DoesSupport(typeURL string) bool { return typeURL == "testKeyManager" } -func (km *testKeyManager) TypeURL() string { return "testKeyManager" } -func (km *testKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil } +func (km *stubKeyManager) Primitive(_ []byte) (any, error) { return &stubPrimitive{}, nil } +func (km *stubKeyManager) NewKey(_ []byte) (proto.Message, error) { return nil, nil } +func (km *stubKeyManager) DoesSupport(typeURL string) bool { return typeURL == "stubKeyManager" } +func (km *stubKeyManager) TypeURL() string { return "stubKeyManager" } +func (km *stubKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil } func TestRegisterKeyManager(t *testing.T) { registryConfig := ®istryconfig.RegistryConfig{} - if err := registryConfig.RegisterKeyManager(new(testKeyManager), internalapi.Token{}); err != nil { + if err := registryConfig.RegisterKeyManager(new(stubKeyManager), internalapi.Token{}); err != nil { t.Fatalf("registryConfig.RegisterKeyManager() err = %v, want nil", err) } - if _, err := registry.GetKeyManager("testKeyManager"); err != nil { - t.Fatalf("registry.GetKeyManager(\"testKeyManager\") err = %v, want nil", err) + if _, err := registry.GetKeyManager("stubKeyManager"); err != nil { + t.Fatalf("registry.GetKeyManager(\"stubKeyManager\") err = %v, want nil", err) } - primitive, err := registry.Primitive(new(testKeyManager).TypeURL(), []byte{0, 1, 2, 3}) + primitive, err := registry.Primitive(new(stubKeyManager).TypeURL(), []byte{0, 1, 2, 3}) if err != nil { t.Fatalf("registry.Primitive() err = %v, want nil", err) } - if _, ok := primitive.(*testPrimitive); !ok { - t.Error("primitive is not of type *testPrimitive") + if _, ok := primitive.(*stubPrimitive); !ok { + t.Error("primitive is not of type *stubPrimitive") + } +} + +type stubKey struct{} + +func (k *stubKey) Parameters() key.Parameters { return nil } +func (k *stubKey) Equals(other key.Key) bool { return true } +func (k *stubKey) IDRequirement() (uint32, bool) { return 123, true } + +// stubPrimitiveConstructor creates a stubPrimitive from a stubKey. +func stubPrimitiveConstructor(k key.Key) (any, error) { + _, ok := k.(*stubKey) + if !ok { + return nil, fmt.Errorf("key is of type %T; needed *stubKey", k) + } + return &stubPrimitive{}, nil +} + +// anotherStubPrimitiveConstructor creates a stubPrimitive from a stubKey. +func anotherStubPrimitiveConstructor(k key.Key) (any, error) { + return stubPrimitiveConstructor(k) +} + +func alwaysFailingStubPrimitiveConstructor(k key.Key) (any, error) { + return nil, fmt.Errorf("I always fail :(") +} + +func TestRegisterPrimitiveConstructor(t *testing.T) { + defer registryconfig.ClearPrimitiveConstructors() + if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil { + t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err) + } + rc := ®istryconfig.RegistryConfig{} + primitive, err := rc.PrimitiveFromKey(new(stubKey), internalapi.Token{}) + if err != nil { + t.Fatalf("rc.PrimitiveFromKey() err = %v, want nil", err) + } + if _, ok := primitive.(*stubPrimitive); !ok { + t.Error("primitive is not of type *stubPrimitive") + } +} + +func stubPrimitiveConstructorFromFallbackProtoKey(k key.Key) (any, error) { + _, ok := k.(*protoserialization.FallbackProtoKey) + if !ok { + return nil, fmt.Errorf("key is of type %T; needed *protoserialization.FallbackProtoKey", k) + } + return &stubPrimitive{}, nil +} + +func TestRegisterPrimitiveConstructorUsesCreatorFirst(t *testing.T) { + defer registryconfig.ClearPrimitiveConstructors() + keyset, err := keyset.NewHandle(mac.HMACSHA256Tag256KeyTemplate()) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + entry, err := keyset.Entry(0) + if err != nil { + t.Fatalf("keyset.Entry() err = %v, want nil", err) + } + + registryConfig := ®istryconfig.RegistryConfig{} + p, err := registryConfig.PrimitiveFromKey(entry.Key(), internalapi.Token{}) + if err != nil { + t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err) + } + if _, ok := p.(*subtle.HMAC); !ok { + t.Error("p is not of type *subtle.HMAC") + } + + rc := ®istryconfig.RegistryConfig{} + // We now register a constructor for protoserialization.FallbackProtoKey that + // returns a stubPrimitive instead of a HMAC. + if err := registryconfig.RegisterPrimitiveConstructor[*protoserialization.FallbackProtoKey](stubPrimitiveConstructorFromFallbackProtoKey); err != nil { + t.Errorf("registryconfig.RegisterPrimitiveConstructor[*protoserialization.FallbackProtoKey](stubPrimitiveConstructorFromFallbackProtoKey) err = %v, want nil", err) + } + p, err = rc.PrimitiveFromKey(entry.Key(), internalapi.Token{}) + if err != nil { + t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err) + } + if _, ok := p.(*stubPrimitive); !ok { + t.Error("p is not of type *stubPrimitive") + } +} + +func TestPrimitiveFromKeyFailsIfCreatorFails(t *testing.T) { + defer registryconfig.ClearPrimitiveConstructors() + if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](alwaysFailingStubPrimitiveConstructor); err != nil { + t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](alwaysFailingStubPrimitiveConstructor) err = %v, want nil", err) + } + rc := ®istryconfig.RegistryConfig{} + if _, err := rc.PrimitiveFromKey(new(stubKey), internalapi.Token{}); err == nil { + t.Errorf("rc.PrimitiveFromKey() err = nil, want error") + } +} + +func TestRegisterPrimitiveConstructorSucceedsIfDoubleRegister(t *testing.T) { + defer registryconfig.ClearPrimitiveConstructors() + if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil { + t.Fatalf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err) + } + if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil { + t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err) + } +} + +func TestRegisterPrimitiveConstructorFailsIfRegisterAnotherCreatorForSameKeyType(t *testing.T) { + defer registryconfig.ClearPrimitiveConstructors() + if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor); err != nil { + t.Fatalf("registryconfig.RegisterPrimitiveConstructor[*stubKey](stubPrimitiveConstructor) err = %v, want nil", err) + } + if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](anotherStubPrimitiveConstructor); err == nil { + t.Errorf("registryconfig.RegisterPrimitiveConstructor[*stubKey](anotherStubPrimitiveConstructor) err = nil, want error") } } diff --git a/keyset/handle_test.go b/keyset/handle_test.go index 0ce30ce..7385de7 100644 --- a/keyset/handle_test.go +++ b/keyset/handle_test.go @@ -752,6 +752,10 @@ func TestPrimitivesWithRegistry(t *testing.T) { type testConfig struct{} +func (c *testConfig) PrimitiveFromKeyData(_ *tinkpb.KeyData, _ internalapi.Token) (any, error) { + return testPrimitive{}, nil +} + func (c *testConfig) PrimitiveFromKey(_ key.Key, _ internalapi.Token) (any, error) { return testPrimitive{}, nil }