Skip to content

Commit

Permalink
Change internal representation of a Keyset to []*Entry
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 651419511
Change-Id: I27c4397975be9d96cedd99851a39623e1bf58c77
  • Loading branch information
morambro authored and copybara-github committed Jul 11, 2024
1 parent 84d8dda commit e8bdf5c
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 79 deletions.
199 changes: 123 additions & 76 deletions keyset/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ import (

var errInvalidKeyset = fmt.Errorf("keyset.Handle: invalid keyset")

// Handle provides access to a Keyset protobuf, to limit the exposure of actual protocol
// buffers that hold sensitive key material.
// Handle provides access to a keyset to limit the exposure of the internal
// keyset representation, which may hold sensitive key material.
type Handle struct {
ks *tinkpb.Keyset // must be non-nil
entries []*Entry
isKsValidated bool
annotations map[string]string
keysetHasSecrets bool // Whether the keyset contains secret key material.
primaryKeyEntry *Entry
}

// KeyStatus is the key status.
Expand Down Expand Up @@ -108,15 +109,82 @@ func keyStatusFromProto(status tinkpb.KeyStatusType) (KeyStatus, error) {
case tinkpb.KeyStatusType_DESTROYED:
return Destroyed, nil
default:
return Unknown, fmt.Errorf("unknown key status: %v", status)
return Unknown, nil
}
}

func keyStatusToProto(status KeyStatus) (tinkpb.KeyStatusType, error) {
switch status {
case Enabled:
return tinkpb.KeyStatusType_ENABLED, nil
case Disabled:
return tinkpb.KeyStatusType_DISABLED, nil
case Destroyed:
return tinkpb.KeyStatusType_DESTROYED, nil
default:
return tinkpb.KeyStatusType_UNKNOWN_STATUS, nil
}
}

func entriesToProtoKeyset(entries []*Entry) (*tinkpb.Keyset, error) {
if entries == nil {
return nil, fmt.Errorf("entriesToProtoKeyset called with nil")
}
if len(entries) == 0 {
return nil, fmt.Errorf("entries is empty")
}
protoKeyset := &tinkpb.Keyset{}
for _, entry := range entries {
protoKey, err := protoserialization.SerializeKey(entry.Key())
if err != nil {
return nil, err
}
if protoKey == nil {
return nil, fmt.Errorf("key is nil")
}
protoKey.Status, err = keyStatusToProto(entry.KeyStatus())
if err != nil {
return nil, err
}
protoKey.KeyId = entry.KeyID()
protoKeyset.Key = append(protoKeyset.Key, protoKey)
if entry.IsPrimary() {
protoKeyset.PrimaryKeyId = entry.KeyID()
}
}
return protoKeyset, nil
}

func newWithOptions(ks *tinkpb.Keyset, opts ...Option) (*Handle, error) {
if ks == nil {
return nil, errors.New("keyset.Handle: nil keyset")
return nil, errors.New("keyset.Handle: keyset is nil")
}
entries := make([]*Entry, len(ks.Key))
var primaryKeyEntry *Entry = nil
for i, protoKey := range ks.Key {
key, err := protoserialization.ParseKey(protoKey)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
keyStatus, err := keyStatusFromProto(protoKey.GetStatus())
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
entries[i] = &Entry{
key: key,
isPrimary: protoKey.GetKeyId() == ks.GetPrimaryKeyId(),
keyID: protoKey.GetKeyId(),
status: keyStatus,
}
if protoKey.GetKeyId() == ks.GetPrimaryKeyId() {
primaryKeyEntry = entries[i]
}
}
h := &Handle{
entries: entries,
keysetHasSecrets: hasSecrets(ks),
primaryKeyEntry: primaryKeyEntry,
}
h := &Handle{ks: ks, keysetHasSecrets: hasSecrets(ks)}
if err := applyOptions(h, opts...); err != nil {
return nil, err
}
Expand Down Expand Up @@ -167,28 +235,32 @@ func ReadWithAssociatedData(reader Reader, masterKey tink.AEAD, associatedData [
if err != nil {
return nil, err
}
ks, err := decrypt(encryptedKeyset, masterKey, associatedData)
protoKeyset, err := decrypt(encryptedKeyset, masterKey, associatedData)
if err != nil {
return nil, err
}
return newWithOptions(ks)
return newWithOptions(protoKeyset)
}

// ReadWithNoSecrets tries to create a keyset.Handle from a keyset obtained via reader.
func ReadWithNoSecrets(reader Reader) (*Handle, error) {
ks, err := reader.Read()
protoKeyset, err := reader.Read()
if err != nil {
return nil, err
}
return NewHandleWithNoSecrets(ks)
return NewHandleWithNoSecrets(protoKeyset)
}

func (h *Handle) validateKeyset() error {
if h.isKsValidated {
return nil
}
if err := Validate(h.ks); err != nil {
return fmt.Errorf("keyset.Handle: invalid keyset: %v", err)
protoKeyset, err := entriesToProtoKeyset(h.entries)
if err != nil {
return err
}
if err := Validate(protoKeyset); err != nil {
return fmt.Errorf("invalid keyset: %v", err)
}
h.isKsValidated = true
return nil
Expand All @@ -197,92 +269,58 @@ func (h *Handle) validateKeyset() error {
// Primary returns the primary key of the keyset.
func (h *Handle) Primary() (*Entry, error) {
if err := h.validateKeyset(); err != nil {
return nil, err
}
for _, key := range h.ks.GetKey() {
if key.GetKeyId() == h.ks.GetPrimaryKeyId() {
keyStatus, err := keyStatusFromProto(key.GetStatus())
if err != nil {
return nil, fmt.Errorf("keyset.Handle: invalid key status: %v", key.GetStatus())
}
keyObject, err := protoserialization.ParseKey(key)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
return &Entry{
key: keyObject,
isPrimary: true,
keyID: key.GetKeyId(),
status: keyStatus,
}, nil
}
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
// Should never reach this point.
return nil, fmt.Errorf("keyset.Handle: no primary key found")
// If validation succeeded, then the primary key must exist.
return h.primaryKeyEntry, nil
}

// Entry returns the key at index i from the keyset.
// i must be within the range [0, Handle.Len()).
func (h *Handle) Entry(i int) (*Entry, error) {
if err := h.validateKeyset(); err != nil {
return nil, err
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
if i < 0 || i >= h.Len() {
return nil, fmt.Errorf("keyset.Handle: index %d out of range", i)
}

key := h.ks.GetKey()[i]
keyStatus, err := keyStatusFromProto(key.GetStatus())
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
keyObject, err := protoserialization.ParseKey(key)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
return &Entry{
key: keyObject,
isPrimary: key.GetKeyId() == h.ks.GetPrimaryKeyId(),
keyID: key.GetKeyId(),
status: keyStatus,
}, nil
return h.entries[i], nil
}

// Public returns a Handle of the public keys if the managed keyset contains private keys.
func (h *Handle) Public() (*Handle, error) {
privKeys := h.ks.GetKey()
if len(privKeys) == 0 {
return nil, errors.New("keyset.Handle: invalid keyset")
protoKeyset, err := entriesToProtoKeyset(h.entries)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %v", err)
}
pubKeys := make([]*tinkpb.Keyset_Key, len(privKeys))

for i := 0; i < len(privKeys); i++ {
if privKeys[i] == nil || privKeys[i].KeyData == nil {
publicKeys := make([]*tinkpb.Keyset_Key, h.Len())
for i, privKey := range protoKeyset.Key {
if privKey == nil || privKey.KeyData == nil {
return nil, errInvalidKeyset
}
privKeyData := privKeys[i].KeyData
privKeyData := privKey.KeyData
pubKeyData, err := publicKeyData(privKeyData)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: %s", err)
}
pubKeys[i] = &tinkpb.Keyset_Key{
publicKeys[i] = &tinkpb.Keyset_Key{
KeyData: pubKeyData,
Status: privKeys[i].Status,
KeyId: privKeys[i].KeyId,
OutputPrefixType: privKeys[i].OutputPrefixType,
Status: privKey.Status,
KeyId: privKey.KeyId,
OutputPrefixType: privKey.OutputPrefixType,
}
}
ks := &tinkpb.Keyset{
PrimaryKeyId: h.ks.PrimaryKeyId,
Key: pubKeys,
publicProtoKeyset := &tinkpb.Keyset{
PrimaryKeyId: protoKeyset.PrimaryKeyId,
Key: publicKeys,
}
return newWithOptions(ks)
return newWithOptions(publicProtoKeyset)
}

// String returns a string representation of the managed keyset.
// The result does not contain any sensitive key material.
func (h *Handle) String() string {
c, err := prototext.MarshalOptions{}.Marshal(getKeysetInfo(h.ks))
c, err := prototext.MarshalOptions{}.Marshal(h.KeysetInfo())
if err != nil {
return ""
}
Expand All @@ -291,13 +329,13 @@ func (h *Handle) String() string {

// Len returns the number of keys in the keyset.
func (h *Handle) Len() int {
return len(h.ks.GetKey())
return len(h.entries)
}

// KeysetInfo returns KeysetInfo representation of the managed keyset.
// The result does not contain any sensitive key material.
func (h *Handle) KeysetInfo() *tinkpb.KeysetInfo {
return getKeysetInfo(h.ks)
return getKeysetInfo(keysetMaterial(h))
}

// Write encrypts and writes the enclosing keyset.
Expand All @@ -307,10 +345,11 @@ func (h *Handle) Write(writer Writer, masterKey tink.AEAD) error {

// WriteWithAssociatedData encrypts and writes the enclosing keyset using the provided associated data.
func (h *Handle) WriteWithAssociatedData(writer Writer, masterKey tink.AEAD, associatedData []byte) error {
if h.ks == nil {
return errors.New("keyset.Handle: invalid keyset")
protoKeyset, err := entriesToProtoKeyset(h.entries)
if err != nil {
return err
}
encrypted, err := encrypt(h.ks, masterKey, associatedData)
encrypted, err := encrypt(protoKeyset, masterKey, associatedData)
if err != nil {
return err
}
Expand All @@ -323,7 +362,11 @@ func (h *Handle) WriteWithNoSecrets(w Writer) error {
if h.keysetHasSecrets {
return errors.New("keyset.Handle: exporting unencrypted secret key material is forbidden")
}
return w.Write(h.ks)
protoKeyset, err := entriesToProtoKeyset(h.entries)
if err != nil {
return err
}
return w.Write(protoKeyset)
}

// Config defines methods in the config.Config concrete type that are used by keyset.Handle.
Expand Down Expand Up @@ -402,12 +445,16 @@ func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (*
config = &registryconfig.RegistryConfig{}
}

if err := Validate(h.ks); err != nil {
if err := h.validateKeyset(); err != nil {
return nil, fmt.Errorf("invalid keyset: %v", err)
}
primitiveSet := primitiveset.New()
primitiveSet.Annotations = h.annotations
for _, key := range h.ks.Key {
protoKeyset, err := entriesToProtoKeyset(h.entries)
if err != nil {
return nil, err
}
for _, key := range protoKeyset.Key {
if key.Status != tinkpb.KeyStatusType_ENABLED {
continue
}
Expand All @@ -425,7 +472,7 @@ func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (*
if err != nil {
return nil, fmt.Errorf("cannot add primitive: %v", err)
}
if key.KeyId == h.ks.PrimaryKeyId {
if key.KeyId == protoKeyset.PrimaryKeyId {
primitiveSet.Primary = entry
}
}
Expand Down
Loading

0 comments on commit e8bdf5c

Please sign in to comment.