diff --git a/protocol/entities.go b/protocol/entities.go index 814e608e..c1495337 100644 --- a/protocol/entities.go +++ b/protocol/entities.go @@ -33,7 +33,7 @@ type CredentialEntity struct { type RelyingPartyEntity struct { CredentialEntity // A unique identifier for the Relying Party entity, which sets the RP ID. - ID string `json:"id"` + ID string `json:"id,omitempty"` } // The UserEntity represents the PublicKeyCredentialUserEntity IDL and is used to supply additional user account diff --git a/webauthn/types.go b/webauthn/types.go index bb93f31a..482a5b21 100644 --- a/webauthn/types.go +++ b/webauthn/types.go @@ -27,6 +27,7 @@ type WebAuthn struct { // Config represents the WebAuthn configuration. type Config struct { // RPID configures the Relying Party Server ID. This should generally be the origin without a scheme and port. + // If absent the browser will automatically determine this using standard conventions. RPID string // RPDisplayName configures the display name for the Relying Party Server. This can be any string. @@ -101,20 +102,28 @@ func (config *Config) validate() error { return fmt.Errorf(errFmtFieldEmpty, "RPDisplayName") } - if len(config.RPID) == 0 { - return fmt.Errorf(errFmtFieldEmpty, "RPID") - } - var err error - if _, err = url.Parse(config.RPID); err != nil { - return fmt.Errorf(errFmtFieldNotValidURI, "RPID", err) + var uri *url.URL + + if len(config.RPID) != 0 { + if uri, err = url.Parse(config.RPID); err != nil { + return fmt.Errorf(errFmtFieldNotValidURI, "RPID", err) + } + + if uri.IsAbs() { + return fmt.Errorf("field '%s' is an absolute URI but it must not be an absolute URI", "RPID") + } } - if config.RPIcon != "" { - if _, err = url.Parse(config.RPIcon); err != nil { + if len(config.RPIcon) != 0 { + if uri, err = url.Parse(config.RPIcon); err != nil { return fmt.Errorf(errFmtFieldNotValidURI, "RPIcon", err) } + + if !uri.IsAbs() { + return fmt.Errorf("field '%s' is not an absolute URI but it must be an absolute URI", "RPIcon") + } } defaultTimeoutConfig := defaultTimeout @@ -141,7 +150,7 @@ func (config *Config) validate() error { config.Timeouts.Registration.TimeoutUVD = defaultTimeoutUVDConfig } - if len(config.RPOrigin) > 0 { + if len(config.RPOrigin) != 0 { if len(config.RPOrigins) != 0 { return fmt.Errorf("deprecated field 'RPOrigin' can't be defined at the same tme as the replacement field 'RPOrigins'") } diff --git a/webauthn/types_test.go b/webauthn/types_test.go new file mode 100644 index 00000000..e0d41fdb --- /dev/null +++ b/webauthn/types_test.go @@ -0,0 +1,182 @@ +package webauthn + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigValidateErr(t *testing.T) { + testCases := []struct { + name string + have *Config + err string + check func(t *testing.T, have *Config) + }{ + { + "ShouldNotErrorOnStandardConfig", + &Config{ + RPID: "example.com", + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + }, + "", + nil, + }, + { + "ShouldErrorOnAbsoluteRPID", + &Config{ + RPID: "https://example.com", + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + }, + "field 'RPID' is an absolute URI but it must not be an absolute URI", + nil, + }, + { + "ShouldSkipValidation", + &Config{ + validated: true, + }, + "", + nil, + }, + { + "ShouldErrorOnBadRPIcon", + &Config{ + RPID: "example.com", + RPIcon: "exa$##$#@$@#%^@#mple.com", + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + }, + "field 'RPIcon' is not a valid URI: parse \"exa$##$#@$@#%^@#mple.com\": invalid URL escape \"%^@\"", + nil, + }, + { + "ShouldErrorOnBadRPIconAbsolute", + &Config{ + RPID: "example.com", + RPIcon: "example.com", + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + }, + "field 'RPIcon' is not an absolute URI but it must be an absolute URI", + nil, + }, + { + "ShouldSetFallbackRPOriginAndNotErr", + &Config{ + RPID: "example.com", + RPDisplayName: "example", + RPOrigin: "https://example.com", + RPOrigins: []string{}, + }, + "", + func(t *testing.T, have *Config) { + require.Len(t, have.RPOrigins, 1) + assert.Equal(t, "https://example.com", have.RPOrigins[0]) + }, + }, + { + "ShouldNotErrorOnConfigWithoutRPID", + &Config{ + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + }, + "", + nil, + }, + { + "ShouldErrorOnNoDisplayName", + &Config{ + RPID: "example.com", + RPOrigins: []string{ + "https://example.com", + }, + }, + "the field 'RPDisplayName' must be configured but it is empty", + nil, + }, + { + "ShouldErrorOnNoOrigins", + &Config{ + RPID: "example.com", + RPDisplayName: "example", + RPOrigins: []string{}, + }, + "must provide at least one value to the 'RPOrigins' field", + nil, + }, + { + "ShouldErrorOnInvalidRPID", + &Config{ + RPID: "exa$##$#@$@#%^@#mple.com", + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + }, + "field 'RPID' is not a valid URI: parse \"exa$##$#@$@#%^@#mple.com\": invalid URL escape \"%^@\"", + nil, + }, + { + "ShouldErrorOnDeprecatedAndNewRPOrigins", + &Config{ + RPID: "example.com", + RPDisplayName: "example", + RPOrigin: "https://example.com", + RPOrigins: []string{ + "https://example.com", + }, + }, + "deprecated field 'RPOrigin' can't be defined at the same tme as the replacement field 'RPOrigins'", + nil, + }, + { + "ShouldSetDefaultTimeoutValues", + &Config{ + RPID: "example.com", + RPDisplayName: "example", + RPOrigins: []string{ + "https://example.com", + }, + Timeout: int(time.Second.Milliseconds()), + }, + "", + func(t *testing.T, have *Config) { + assert.Equal(t, time.Second, have.Timeouts.Login.Timeout) + assert.Equal(t, time.Second, have.Timeouts.Login.TimeoutUVD) + assert.Equal(t, time.Second, have.Timeouts.Registration.Timeout) + assert.Equal(t, time.Second, have.Timeouts.Registration.TimeoutUVD) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.have.validate() + + if len(tc.err) == 0 { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.err) + } + + if tc.check != nil { + tc.check(t, tc.have) + } + }) + } +}