diff --git a/aead/aesgcm/key.go b/aead/aesgcm/key.go index 8edeebc..5e29709 100644 --- a/aead/aesgcm/key.go +++ b/aead/aesgcm/key.go @@ -101,19 +101,26 @@ type ParametersOpts struct { Variant Variant } -// NewParameters creates a new AES-GCM Parameters object. -func NewParameters(opts ParametersOpts) (*Parameters, error) { +func validateOpts(opts *ParametersOpts) error { if opts.KeySizeInBytes != 16 && opts.KeySizeInBytes != 24 && opts.KeySizeInBytes != 32 { - return nil, fmt.Errorf("aesgcm.Parameters: unsupported key size; want 16, 24, or 32, got: %v", opts.KeySizeInBytes) + return fmt.Errorf("unsupported key size; want 16, 24, or 32, got: %v", opts.KeySizeInBytes) } if opts.IVSizeInBytes <= 0 { - return nil, fmt.Errorf("aesgcm.Parameters: unsupported IV size; want > 0, got: %v", opts.IVSizeInBytes) + return fmt.Errorf("unsupported IV size; want > 0, got: %v", opts.IVSizeInBytes) } if opts.TagSizeInBytes < 12 || opts.TagSizeInBytes > 16 { - return nil, fmt.Errorf("aesgcm.Parameters: unsupported tag size; want >= 12 and <= 16, got: %v", opts.TagSizeInBytes) + return fmt.Errorf("unsupported tag size; want >= 12 and <= 16, got: %v", opts.TagSizeInBytes) } if opts.Variant == VariantUnknown { - return nil, fmt.Errorf("aesgcm.Parameters: unsupported variant: %v", opts.Variant) + return fmt.Errorf("unsupported variant: %v", opts.Variant) + } + return nil +} + +// NewParameters creates a new AES-GCM Parameters object. +func NewParameters(opts ParametersOpts) (*Parameters, error) { + if err := validateOpts(&opts); err != nil { + return nil, fmt.Errorf("aesgcm.NewParameters: %v", err) } return &Parameters{ keySizeInBytes: opts.KeySizeInBytes, @@ -151,6 +158,17 @@ func NewKey(keyBytes secretdata.Bytes, keyID uint32, parameters *Parameters) (*K if parameters == nil { return nil, fmt.Errorf("aesgcm.NewKey: parameters is nil") } + + opts := &ParametersOpts{ + KeySizeInBytes: parameters.KeySizeInBytes(), + IVSizeInBytes: parameters.IVSizeInBytes(), + TagSizeInBytes: parameters.TagSizeInBytes(), + Variant: parameters.Variant(), + } + if err := validateOpts(opts); err != nil { + return nil, fmt.Errorf("aesgcm.NewKey: %v", err) + } + if keyBytes.Len() != int(parameters.KeySizeInBytes()) { return nil, fmt.Errorf("aesgcm.NewKey: key.Len() = %v, want %v", keyBytes.Len(), parameters.KeySizeInBytes()) } diff --git a/aead/aesgcm/key_test.go b/aead/aesgcm/key_test.go index 809522a..bdae1c9 100644 --- a/aead/aesgcm/key_test.go +++ b/aead/aesgcm/key_test.go @@ -91,6 +91,70 @@ func TestNewParametersInvalidVariant(t *testing.T) { } } +func TestNewKeyFailsIfParametersIsNil(t *testing.T) { + keyBytes, err := secretdata.NewBytesFromRand(32) + if err != nil { + t.Fatalf("secretdata.NewBytesFromRand(32) err = %v, want nil", err) + } + if _, err := aesgcm.NewKey(*keyBytes, 123, nil); err == nil { + t.Errorf("aesgcm.NewKey(*keyBytes, 123, nil) err = nil, want error") + } +} + +func TestNewKeyFailsIfKeySizeIsDifferentThanParameters(t *testing.T) { + for _, tc := range []struct { + name string + keyBytes *secretdata.Bytes + params aesgcm.ParametersOpts + }{ + { + name: "key size is 16 but parameters is 32", + keyBytes: secretdata.NewBytesFromData(key128Bits, insecuresecretdataaccess.Token{}), + params: aesgcm.ParametersOpts{ + KeySizeInBytes: 32, + IVSizeInBytes: 12, + TagSizeInBytes: 16, + Variant: aesgcm.VariantTink, + }, + }, + { + name: "key size is 32 but parameters is 16", + keyBytes: secretdata.NewBytesFromData(key256Bits, insecuresecretdataaccess.Token{}), + params: aesgcm.ParametersOpts{ + KeySizeInBytes: 16, + IVSizeInBytes: 12, + TagSizeInBytes: 16, + Variant: aesgcm.VariantTink, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + params, err := aesgcm.NewParameters(tc.params) + if err != nil { + t.Fatalf("aesgcm.NewParameters(%v) err = %v, want nil", tc.params, err) + } + if _, err := aesgcm.NewKey(*tc.keyBytes, 123, params); err == nil { + t.Errorf("aesgcm.NewKey(%v, 123, %v) err = nil, want error", tc.keyBytes, params) + } + }) + } +} + +// TestNewKeyFailsIfInvalidParams tests that NewKey fails if the parameters are invalid. +// +// The only way to create invalid parameters is to create a struct literal with default +// values. +func TestNewKeyFailsIfInvalidParams(t *testing.T) { + keyBytes, err := secretdata.NewBytesFromRand(32) + if err != nil { + t.Fatalf("secretdata.NewBytesFromRand(32) err = %v, want nil", err) + } + params := &aesgcm.Parameters{} + if _, err := aesgcm.NewKey(*keyBytes, 123, params); err == nil { + t.Errorf("aesgcm.NewKey(*keyBytes, 123, nil) err = nil, want error") + } +} + func TestOutputPrefix(t *testing.T) { for _, test := range []struct { name string