Skip to content

Commit

Permalink
Extract primitives from the primitive set on AEAD factory creation
Browse files Browse the repository at this point in the history
This reduces the amount of allocations performed during `Encrypt/Decrypt`

PiperOrigin-RevId: 696823276
Change-Id: Iaf9043de2bb8f022efea51e838eb8493d34d0660
  • Loading branch information
morambro authored and copybara-github committed Nov 15, 2024
1 parent 4a5d7f1 commit 97c1d8a
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions aead/aead_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,26 @@ func NewWithConfig(handle *keyset.Handle, config keyset.Config) (tink.AEAD, erro
// wrappedAead is an AEAD implementation that uses the underlying primitive set for encryption
// and decryption.
type wrappedAead struct {
ps *primitiveset.PrimitiveSet
primary aeadAndKeyID
primitives map[string][]aeadAndKeyID

encLogger monitoring.Logger
decLogger monitoring.Logger
}

type aeadAndKeyID struct {
primitive tink.AEAD
keyID uint32
}

func (a *aeadAndKeyID) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
return a.primitive.Encrypt(plaintext, associatedData)
}

func (a *aeadAndKeyID) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
return a.primitive.Decrypt(ciphertext, associatedData)
}

// aeadPrimitiveAdapter is an adapter that turns a non-full [tink.AEAD]
// primitive into a full [tink.AEAD] primitive.
type fullAEADPrimitiveAdapter struct {
Expand All @@ -74,42 +89,50 @@ func (a *fullAEADPrimitiveAdapter) Decrypt(ciphertext, associatedData []byte) ([
return a.primitive.Decrypt(ciphertext[len(a.prefix):], associatedData)
}

// getFullPrimitive returns a full [tink.AEAD] from the given
// extractFullAEAD returns a full aeadAndKeyID primitive from the given
// [primitiveset.Entry].
func getFullPrimitive(entry *primitiveset.Entry) (tink.AEAD, error) {
func extractFullAEAD(entry *primitiveset.Entry) (*aeadAndKeyID, error) {
if entry.FullPrimitive != nil {
a, ok := (entry.FullPrimitive).(tink.AEAD)
if !ok {
return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
}
return a, nil
return &aeadAndKeyID{primitive: a, keyID: entry.KeyID}, nil
}
a, ok := (entry.Primitive).(tink.AEAD)
if !ok {
return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
}
return &fullAEADPrimitiveAdapter{primitive: a, prefix: []byte(entry.Prefix)}, nil
return &aeadAndKeyID{
primitive: &fullAEADPrimitiveAdapter{primitive: a, prefix: []byte(entry.Prefix)},
keyID: entry.KeyID,
}, nil
}

func newWrappedAead(ps *primitiveset.PrimitiveSet) (*wrappedAead, error) {
if _, err := getFullPrimitive(ps.Primary); err != nil {
primary, err := extractFullAEAD(ps.Primary)
if err != nil {
return nil, err
}
primitives := make(map[string][]aeadAndKeyID)
for _, entries := range ps.Entries {
for _, entry := range entries {
if _, err := getFullPrimitive(entry); err != nil {
p, err := extractFullAEAD(entry)
if err != nil {
return nil, err
}
primitives[entry.Prefix] = append(primitives[entry.Prefix], *p)
}
}
encLogger, decLogger, err := createLoggers(ps)
if err != nil {
return nil, err
}
return &wrappedAead{
ps: ps,
encLogger: encLogger,
decLogger: decLogger,
primary: *primary,
primitives: primitives,
encLogger: encLogger,
decLogger: decLogger,
}, nil
}

Expand Down Expand Up @@ -144,60 +167,47 @@ func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring
// Encrypt encrypts the given plaintext with the given associatedData.
// It returns the concatenation of the primary's identifier and the ciphertext.
func (a *wrappedAead) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
primary := a.ps.Primary
p, err := getFullPrimitive(primary)
if err != nil {
return nil, err
}
ct, err := p.Encrypt(plaintext, associatedData)
ct, err := a.primary.Encrypt(plaintext, associatedData)
if err != nil {
a.encLogger.LogFailure()
return nil, err
}
a.encLogger.Log(primary.KeyID, len(plaintext))
a.encLogger.Log(a.primary.keyID, len(plaintext))
return ct, nil
}

// Decrypt decrypts the given ciphertext and authenticates it with the given
// associatedData. It returns the corresponding plaintext if the
// ciphertext is authenticated.
func (a *wrappedAead) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
// try non-raw keys
// Try non-raw keys.
prefixSize := cryptofmt.NonRawPrefixSize
if len(ciphertext) > prefixSize {
prefix := ciphertext[:prefixSize]
entries, err := a.ps.EntriesForPrefix(string(prefix))
if err == nil {
for _, entry := range entries {
p, err := getFullPrimitive(entry)
if err != nil {
return nil, err
}
pt, err := p.Decrypt(ciphertext, associatedData)
primitivesForPrefix, ok := a.primitives[string(prefix)]
if ok {
for _, primitive := range primitivesForPrefix {
pt, err := primitive.Decrypt(ciphertext, associatedData)
if err == nil {
numBytes := len(ciphertext[prefixSize:])
a.decLogger.Log(entry.KeyID, numBytes)
a.decLogger.Log(primitive.keyID, numBytes)
return pt, nil
}
}
}
}
// try raw keys
entries, err := a.ps.RawEntries()
if err == nil {
for _, entry := range entries {
p, err := getFullPrimitive(entry)
if err != nil {
return nil, err
}
pt, err := p.Decrypt(ciphertext, associatedData)
// Try raw keys.
rawPrimitives, ok := a.primitives[cryptofmt.RawPrefix]
if ok {
for _, primitive := range rawPrimitives {
pt, err := primitive.Decrypt(ciphertext, associatedData)
if err == nil {
a.decLogger.Log(entry.KeyID, len(ciphertext))
a.decLogger.Log(primitive.keyID, len(ciphertext))
return pt, nil
}
}
}
// nothing worked
// Nothing worked.
a.decLogger.LogFailure()
return nil, fmt.Errorf("aead_factory: decryption failed")
}

0 comments on commit 97c1d8a

Please sign in to comment.