Skip to content

Commit

Permalink
Clean up usage of hash.Hash interface.
Browse files Browse the repository at this point in the history
Check for nil value of the exported HMAC.HashFunc field.

PiperOrigin-RevId: 611642993
Change-Id: I47cdeb5069ec0d43666a8ed3d77e73b912a62964
  • Loading branch information
chuckx authored and copybara-github committed Mar 1, 2024
1 parent 328cc24 commit f3a178b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
7 changes: 4 additions & 3 deletions mac/subtle/hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ func ValidateHMACParams(hash string, keySize uint32, tagSize uint32) error {

// ComputeMAC computes message authentication code (MAC) for the given data.
func (h *HMAC) ComputeMAC(data []byte) ([]byte, error) {
mac := hmac.New(h.HashFunc, h.Key)
if _, err := mac.Write(data); err != nil {
return nil, err
if h.HashFunc == nil {
return nil, fmt.Errorf("hmac: invalid hash algorithm")
}
mac := hmac.New(h.HashFunc, h.Key)
mac.Write(data)
tag := mac.Sum(nil)
return tag[:h.TagSize], nil
}
Expand Down
14 changes: 14 additions & 0 deletions mac/subtle/hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ func TestNewHMACWithInvalidInput(t *testing.T) {
}
}

func TestHMACWithNilHashFunc(t *testing.T) {
cipher, err := subtle.NewHMAC("SHA256", random.GetRandomBytes(32), 32)
if err != nil {
t.Fatalf("subtle.NewHMAC() err = %v", err)
}

// Modify exported field.
cipher.HashFunc = nil

if _, err := cipher.ComputeMAC([]byte{}); err == nil {
t.Errorf("cipher.ComputerMAC() err = nil, want not nil")
}
}

func TestHMAComputeVerifyWithNilInput(t *testing.T) {
cipher, err := subtle.NewHMAC("SHA256", random.GetRandomBytes(16), 32)
if err != nil {
Expand Down
7 changes: 1 addition & 6 deletions subtle/subtle.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,7 @@ func ComputeHash(hashFunc func() hash.Hash, data []byte) ([]byte, error) {
return nil, errNilHashFunc
}
h := hashFunc()

_, err := h.Write(data)
if err != nil {
return nil, err
}

h.Write(data)
return h.Sum(nil), nil
}

Expand Down

0 comments on commit f3a178b

Please sign in to comment.