Skip to content

Commit

Permalink
Amend keyset handle to accomodate the new Config classes coming.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590928087
Change-Id: I0548fb241a28e391ef5a97dbb76728957e8e407d
  • Loading branch information
LizaTretyakova authored and copybara-github committed Dec 14, 2023
1 parent 1d24f7d commit ca28af4
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 8 deletions.
4 changes: 4 additions & 0 deletions keyset/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ go_library(
"//core/primitiveset",
"//core/registry",
"//internal",
"//internal/internalapi",
"//internal/registryconfig",
"//proto/tink_go_proto",
"//subtle/random",
"//tink",
Expand All @@ -49,7 +51,9 @@ go_test(
deps = [
":keyset",
"//aead",
"//core/registry",
"//insecurecleartextkeyset",
"//internal/internalapi",
"//mac",
"//proto/common_go_proto",
"//proto/tink_go_proto",
Expand Down
71 changes: 63 additions & 8 deletions keyset/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (

"github.com/tink-crypto/tink-go/v2/core/primitiveset"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/internal/internalapi"
"github.com/tink-crypto/tink-go/v2/internal/registryconfig"
"github.com/tink-crypto/tink-go/v2/tink"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)
Expand Down Expand Up @@ -174,14 +176,48 @@ func (h *Handle) WriteWithNoSecrets(w Writer) error {
return w.Write(h.ks)
}

// Config defines methods in the config.Config concrete type that are used by keyset.Handle.
// The config.Config concrete type is not used directly due to circular dependencies.
type Config interface {
PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalapi.Token) (any, error)
}
type primitiveOptions struct {
config Config
}

// PrimitivesOption is used to configure Primitives(...).
type PrimitivesOption func(*primitiveOptions) error

// WithConfig sets the configuration used to create primitives via Primitives().
// If this option is omitted, default to using the global registry.
func WithConfig(c Config) PrimitivesOption {
return func(args *primitiveOptions) error {
if args.config != nil {
return fmt.Errorf("configuration has already been set")
}
args.config = c
return nil
}
}

// Primitives creates a set of primitives corresponding to the keys with
// status=ENABLED in the keyset of the given keyset handle, assuming all the
// corresponding key managers are present (keys with status!=ENABLED are skipped).
// status=ENABLED in the keyset of the given keyset handle. It uses the
// key managers that are present in the global Registry or in the Config,
// should it be provided. It assumes that all the needed key managers are
// present. Keys with status!=ENABLED are skipped.
//
// An example usage where a custom config is provided:
//
// ps, err := h.Primitives(WithConfig(config.V0()))
//
// The returned set is usually later "wrapped" into a class that implements
// the corresponding Primitive-interface.
func (h *Handle) Primitives() (*primitiveset.PrimitiveSet, error) {
return h.PrimitivesWithKeyManager(nil)
func (h *Handle) Primitives(opts ...PrimitivesOption) (*primitiveset.PrimitiveSet, error) {
p, err := h.primitives(nil, opts...)
if err != nil {
return nil, fmt.Errorf("handle.Primitives: %v", err)
}
return p, nil
}

// PrimitivesWithKeyManager creates a set of primitives corresponding to
Expand All @@ -197,8 +233,27 @@ func (h *Handle) Primitives() (*primitiveset.PrimitiveSet, error) {
// The returned set is usually later "wrapped" into a class that implements
// the corresponding Primitive-interface.
func (h *Handle) PrimitivesWithKeyManager(km registry.KeyManager) (*primitiveset.PrimitiveSet, error) {
p, err := h.primitives(km)
if err != nil {
return nil, fmt.Errorf("handle.PrimitivesWithKeyManager: %v", err)
}
return p, nil
}

func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (*primitiveset.PrimitiveSet, error) {
args := new(primitiveOptions)
for _, opt := range opts {
if err := opt(args); err != nil {
return nil, fmt.Errorf("failed to process primitiveOptions: %v", err)
}
}
config := args.config
if config == nil {
config = &registryconfig.RegistryConfig{}
}

if err := Validate(h.ks); err != nil {
return nil, fmt.Errorf("registry.PrimitivesWithKeyManager: invalid keyset: %s", err)
return nil, fmt.Errorf("invalid keyset: %v", err)
}
primitiveSet := primitiveset.New()
primitiveSet.Annotations = h.annotations
Expand All @@ -211,14 +266,14 @@ func (h *Handle) PrimitivesWithKeyManager(km registry.KeyManager) (*primitiveset
if km != nil && km.DoesSupport(key.KeyData.TypeUrl) {
primitive, err = km.Primitive(key.KeyData.Value)
} else {
primitive, err = registry.PrimitiveFromKeyData(key.KeyData)
primitive, err = config.PrimitiveFromKeyData(key.KeyData, internalapi.Token{})
}
if err != nil {
return nil, fmt.Errorf("registry.PrimitivesWithKeyManager: cannot get primitive from key: %s", err)
return nil, fmt.Errorf("cannot get primitive from key: %v", err)
}
entry, err := primitiveSet.Add(primitive, key)
if err != nil {
return nil, fmt.Errorf("registry.PrimitivesWithKeyManager: cannot add primitive: %s", err)
return nil, fmt.Errorf("cannot add primitive: %v", err)
}
if key.KeyId == h.ks.PrimaryKeyId {
primitiveSet.Primary = entry
Expand Down
120 changes: 120 additions & 0 deletions keyset/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ import (

"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/aead"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/internal/internalapi"
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/mac"
"github.com/tink-crypto/tink-go/v2/signature"
"github.com/tink-crypto/tink-go/v2/testkeyset"
"github.com/tink-crypto/tink-go/v2/testutil"
"github.com/tink-crypto/tink-go/v2/tink"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)

Expand Down Expand Up @@ -324,3 +327,120 @@ func TestKeysetInfo(t *testing.T) {
t.Errorf("Expected primary key id: %d, but got: %d", info.KeyInfo[0].KeyId, info.PrimaryKeyId)
}
}

func TestPrimitivesWithRegistry(t *testing.T) {
template := mac.HMACSHA256Tag128KeyTemplate()
template.OutputPrefixType = tinkpb.OutputPrefixType_RAW
handle, err := keyset.NewHandle(template)
if err != nil {
t.Fatalf("keyset.NewHandle(%v) err = %v, want nil", template, err)
}
handleMAC, err := mac.New(handle)
if err != nil {
t.Fatalf("mac.New(%v) err = %v, want nil", handle, err)
}

ks := testkeyset.KeysetMaterial(handle)
if len(ks.Key) != 1 {
t.Fatalf("len(ks.Key) = %d, want 1", len(ks.Key))
}
keyDataPrimitive, err := registry.PrimitiveFromKeyData(ks.Key[0].KeyData)
if err != nil {
t.Fatalf("registry.PrimitiveFromKeyData(%v) err = %v, want nil", ks.Key[0].KeyData, err)
}
keyDataMAC, ok := keyDataPrimitive.(tink.MAC)
if !ok {
t.Fatal("registry.PrimitiveFromKeyData(keyData) is not of type tink.MAC")
}

plaintext := []byte("plaintext")
handleMACTag, err := handleMAC.ComputeMAC(plaintext)
if err != nil {
t.Fatalf("handleMAC.ComputeMAC(%v) err = %v, want nil", plaintext, err)
}
if err = keyDataMAC.VerifyMAC(handleMACTag, plaintext); err != nil {
t.Errorf("keyDataMAC.VerifyMAC(%v, %v) err = %v, want nil", handleMACTag, plaintext, err)
}
keyDataMACTag, err := keyDataMAC.ComputeMAC(plaintext)
if err != nil {
t.Fatalf("keyDataMAC.ComputeMAC(%v) err = %v, want nil", plaintext, err)
}
if err = handleMAC.VerifyMAC(keyDataMACTag, plaintext); err != nil {
t.Errorf("handleMAC.VerifyMAC(%v, %v) err = %v, want nil", keyDataMACTag, plaintext, err)
}
}

type testConfig struct{}

func (c *testConfig) PrimitiveFromKeyData(_ *tinkpb.KeyData, _ internalapi.Token) (any, error) {
return testPrimitive{}, nil
}

func TestPrimitivesWithConfig(t *testing.T) {
template := mac.HMACSHA256Tag128KeyTemplate()
template.OutputPrefixType = tinkpb.OutputPrefixType_RAW
handle, err := keyset.NewHandle(template)
if err != nil {
t.Fatalf("keyset.NewHandle(%v) = %v, want nil", template, err)
}
primitives, err := handle.Primitives(keyset.WithConfig(&testConfig{}))
if err != nil {
t.Fatalf("handle.Primitives(keyset.WithConfig(&testConfig{})) err = %v, want nil", err)
}
if len(primitives.EntriesInKeysetOrder) != 1 {
t.Fatalf("len(handle.Primitives()) = %d, want 1", len(primitives.EntriesInKeysetOrder))
}
if _, ok := (primitives.Primary.Primitive).(testPrimitive); !ok {
t.Errorf("handle.Primitives().Primary = %v, want instance of `testPrimitive`", primitives.Primary.Primitive)
}
}

func TestPrimitivesWithMultipleConfigs(t *testing.T) {
template := mac.HMACSHA256Tag128KeyTemplate()
template.OutputPrefixType = tinkpb.OutputPrefixType_RAW
handle, err := keyset.NewHandle(template)
if err != nil {
t.Fatalf("keyset.NewHandle(%v) = %v, want nil", template, err)
}
_, err = handle.Primitives(keyset.WithConfig(&testConfig{}), keyset.WithConfig(&testConfig{}))
if err == nil { // if NO error
t.Error("handle.Primitives(keyset.WithConfig(&testConfig{}), keyset.WithConfig(&testConfig{})) err = nil, want error")
}
}

type testKeyManager struct{}

type testPrimitive struct{}

func (km *testKeyManager) Primitive(_ []byte) (any, error) { return testPrimitive{}, nil }
func (km *testKeyManager) NewKey(_ []byte) (proto.Message, error) { return nil, nil }
func (km *testKeyManager) TypeURL() string { return mac.HMACSHA256Tag128KeyTemplate().TypeUrl }
func (km *testKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil }
func (km *testKeyManager) DoesSupport(typeURL string) bool {
return typeURL == mac.HMACSHA256Tag128KeyTemplate().TypeUrl
}

func TestPrimitivesWithKeyManager(t *testing.T) {
template := mac.HMACSHA256Tag128KeyTemplate()
handle, err := keyset.NewHandle(template)
if err != nil {
t.Fatalf("keyset.NewHandle(%v) = %v, want nil", template, err)
}

// Verify that without providing a custom key manager we get a usual MAC.
if _, err = mac.New(handle); err != nil {
t.Fatalf("mac.New(%v) err = %v, want nil", handle, err)
}

// Verify that with the custom key manager provided we get the custom primitive.
primitives, err := handle.PrimitivesWithKeyManager(&testKeyManager{})
if err != nil {
t.Fatalf("handle.PrimitivesWithKeyManager(testKeyManager) err = %v, want nil", err)
}
if len(primitives.EntriesInKeysetOrder) != 1 {
t.Fatalf("len(handle.PrimitivesWithKeyManager()) = %d, want 1", len(primitives.EntriesInKeysetOrder))
}
if _, ok := (primitives.Primary.Primitive).(testPrimitive); !ok {
t.Errorf("handle.PrimitivesWithKeyManager().Primary = %v, want instance of `testPrimitive`", primitives.Primary.Primitive)
}
}

0 comments on commit ca28af4

Please sign in to comment.