diff --git a/aead/aead_factory.go b/aead/aead_factory.go index ad20488..785524e 100644 --- a/aead/aead_factory.go +++ b/aead/aead_factory.go @@ -50,11 +50,26 @@ func NewWithConfig(handle *keyset.Handle, config keyset.Config) (tink.AEAD, erro // wrappedAead is an AEAD implementation that uses the underlying primitive set for encryption // and decryption. type wrappedAead struct { - ps *primitiveset.PrimitiveSet + primary aeadAndKeyID + primitives map[string][]aeadAndKeyID + encLogger monitoring.Logger decLogger monitoring.Logger } +type aeadAndKeyID struct { + primitive tink.AEAD + keyID uint32 +} + +func (a *aeadAndKeyID) Encrypt(plaintext, associatedData []byte) ([]byte, error) { + return a.primitive.Encrypt(plaintext, associatedData) +} + +func (a *aeadAndKeyID) Decrypt(ciphertext, associatedData []byte) ([]byte, error) { + return a.primitive.Decrypt(ciphertext, associatedData) +} + // aeadPrimitiveAdapter is an adapter that turns a non-full [tink.AEAD] // primitive into a full [tink.AEAD] primitive. type fullAEADPrimitiveAdapter struct { @@ -74,32 +89,39 @@ func (a *fullAEADPrimitiveAdapter) Decrypt(ciphertext, associatedData []byte) ([ return a.primitive.Decrypt(ciphertext[len(a.prefix):], associatedData) } -// getFullPrimitive returns a full [tink.AEAD] from the given +// extractFullAEAD returns a full aeadAndKeyID primitive from the given // [primitiveset.Entry]. -func getFullPrimitive(entry *primitiveset.Entry) (tink.AEAD, error) { +func extractFullAEAD(entry *primitiveset.Entry) (*aeadAndKeyID, error) { if entry.FullPrimitive != nil { a, ok := (entry.FullPrimitive).(tink.AEAD) if !ok { return nil, fmt.Errorf("aead_factory: not an AEAD primitive") } - return a, nil + return &aeadAndKeyID{primitive: a, keyID: entry.KeyID}, nil } a, ok := (entry.Primitive).(tink.AEAD) if !ok { return nil, fmt.Errorf("aead_factory: not an AEAD primitive") } - return &fullAEADPrimitiveAdapter{primitive: a, prefix: []byte(entry.Prefix)}, nil + return &aeadAndKeyID{ + primitive: &fullAEADPrimitiveAdapter{primitive: a, prefix: []byte(entry.Prefix)}, + keyID: entry.KeyID, + }, nil } func newWrappedAead(ps *primitiveset.PrimitiveSet) (*wrappedAead, error) { - if _, err := getFullPrimitive(ps.Primary); err != nil { + primary, err := extractFullAEAD(ps.Primary) + if err != nil { return nil, err } + primitives := make(map[string][]aeadAndKeyID) for _, entries := range ps.Entries { for _, entry := range entries { - if _, err := getFullPrimitive(entry); err != nil { + p, err := extractFullAEAD(entry) + if err != nil { return nil, err } + primitives[entry.Prefix] = append(primitives[entry.Prefix], *p) } } encLogger, decLogger, err := createLoggers(ps) @@ -107,9 +129,10 @@ func newWrappedAead(ps *primitiveset.PrimitiveSet) (*wrappedAead, error) { return nil, err } return &wrappedAead{ - ps: ps, - encLogger: encLogger, - decLogger: decLogger, + primary: *primary, + primitives: primitives, + encLogger: encLogger, + decLogger: decLogger, }, nil } @@ -144,17 +167,12 @@ func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring // Encrypt encrypts the given plaintext with the given associatedData. // It returns the concatenation of the primary's identifier and the ciphertext. func (a *wrappedAead) Encrypt(plaintext, associatedData []byte) ([]byte, error) { - primary := a.ps.Primary - p, err := getFullPrimitive(primary) - if err != nil { - return nil, err - } - ct, err := p.Encrypt(plaintext, associatedData) + ct, err := a.primary.Encrypt(plaintext, associatedData) if err != nil { a.encLogger.LogFailure() return nil, err } - a.encLogger.Log(primary.KeyID, len(plaintext)) + a.encLogger.Log(a.primary.keyID, len(plaintext)) return ct, nil } @@ -162,42 +180,34 @@ func (a *wrappedAead) Encrypt(plaintext, associatedData []byte) ([]byte, error) // associatedData. It returns the corresponding plaintext if the // ciphertext is authenticated. func (a *wrappedAead) Decrypt(ciphertext, associatedData []byte) ([]byte, error) { - // try non-raw keys + // Try non-raw keys. prefixSize := cryptofmt.NonRawPrefixSize if len(ciphertext) > prefixSize { prefix := ciphertext[:prefixSize] - entries, err := a.ps.EntriesForPrefix(string(prefix)) - if err == nil { - for _, entry := range entries { - p, err := getFullPrimitive(entry) - if err != nil { - return nil, err - } - pt, err := p.Decrypt(ciphertext, associatedData) + primitivesForPrefix, ok := a.primitives[string(prefix)] + if ok { + for _, primitive := range primitivesForPrefix { + pt, err := primitive.Decrypt(ciphertext, associatedData) if err == nil { numBytes := len(ciphertext[prefixSize:]) - a.decLogger.Log(entry.KeyID, numBytes) + a.decLogger.Log(primitive.keyID, numBytes) return pt, nil } } } } - // try raw keys - entries, err := a.ps.RawEntries() - if err == nil { - for _, entry := range entries { - p, err := getFullPrimitive(entry) - if err != nil { - return nil, err - } - pt, err := p.Decrypt(ciphertext, associatedData) + // Try raw keys. + rawPrimitives, ok := a.primitives[cryptofmt.RawPrefix] + if ok { + for _, primitive := range rawPrimitives { + pt, err := primitive.Decrypt(ciphertext, associatedData) if err == nil { - a.decLogger.Log(entry.KeyID, len(ciphertext)) + a.decLogger.Log(primitive.keyID, len(ciphertext)) return pt, nil } } } - // nothing worked + // Nothing worked. a.decLogger.LogFailure() return nil, fmt.Errorf("aead_factory: decryption failed") }