Skip to content

Commit

Permalink
Encryption enforcement
Browse files Browse the repository at this point in the history
  • Loading branch information
agnivade committed Apr 14, 2022
1 parent 4e9d6e9 commit d97630e
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 16 deletions.
24 changes: 19 additions & 5 deletions decode_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,7 @@ func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
return nil
}

if err := etreeutils.NSFindIterate(el, SAMLAssertionNamespace, EncryptedAssertionTag, decryptAssertion); err != nil {
return err
} else {
return nil
}
return etreeutils.NSFindIterate(el, SAMLAssertionNamespace, EncryptedAssertionTag, decryptAssertion)
}

func (sp *SAMLServiceProvider) validateElementSignature(el *etree.Element) (*etree.Element, error) {
Expand Down Expand Up @@ -265,6 +261,24 @@ func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (
return nil, err
}

elAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, AssertionTag)
if err != nil {
return nil, err
}
elEncAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, EncryptedAssertionTag)
if err != nil {
return nil, err
}
// We verify that either one of assertion or encrypted assertion elements are present,
// but not both.
if (elAssertion == nil) == (elEncAssertion == nil) {
return nil, fmt.Errorf("found both or no assertion and encrypted assertion elements")
}
// And if a decryptCert is present, then it's only encrypted assertion elements.
if sp.SPKeyStore != nil && elAssertion != nil {
return nil, fmt.Errorf("all assertions are not encrypted")
}

var responseSignatureValidated bool
if !sp.SkipSignatureValidation {
el, err = sp.validateElementSignature(el)
Expand Down
1 change: 0 additions & 1 deletion decode_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ func TestCompressedResponse(t *testing.T) {

sp := SAMLServiceProvider{
AssertionConsumerServiceURL: "https://f1f51ddc.ngrok.io/api/sso/saml2/acs/58cafd0573d4f375b8e70e8e",
SPKeyStore: dsig.TLSCertKeyStore(cert),
IDPCertificateStore: &dsig.MemoryX509CertificateStore{
Roots: []*x509.Certificate{idpCert},
},
Expand Down
15 changes: 13 additions & 2 deletions providertests/onelogin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"fmt"
"testing"

"github.com/mattermost/gosaml2"
saml2 "github.com/mattermost/gosaml2"
)

var oneLoginScenarioErrors = map[int]string{
Expand Down Expand Up @@ -182,6 +182,10 @@ var oneLoginScenarioWarnings = map[int]scenarioWarnings{
},
}

// oneLoginNilKeyStoreIndices is a slice of indices where keyStore is not required, but is passed nevertheless.
// This is to make the tests pass.
var oneLoginNilKeyStoreIndices = []int{1, 3, 4, 11, 12, 13, 14, 15, 21, 22, 25, 26, 31, 33, 34, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 81, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 94, 99, 155, 156, 157, 158, 159}

var oneLoginAtTimes = map[int]string{
25: "2017-08-30T23:00:00Z",
26: "2017-08-30T23:55:00Z",
Expand All @@ -204,10 +208,17 @@ func TestOneLoginCasesLocally(t *testing.T) {
scenarios := []ProviderTestScenario{}
for _, idx := range scenarioIndexes(oneLoginScenarioErrors, oneLoginScenarioWarnings) {
response := LoadRawResponse(fmt.Sprintf("./testdata/onelogin/olgn09_response_%02d.b64", idx))
var nilKeyStore bool
for _, ind := range oneLoginNilKeyStoreIndices {
if idx == ind {
nilKeyStore = true
break
}
}
scenarios = append(scenarios, ProviderTestScenario{
ScenarioName: fmt.Sprintf("Scenario_%02d", idx),
Response: response,
ServiceProvider: spAtTime(sp, getAtTime(idx, oneLoginAtTimes), response),
ServiceProvider: spAtTime(sp, getAtTime(idx, oneLoginAtTimes), response, nilKeyStore),
CheckError: scenarioErrorChecker(idx, oneLoginScenarioErrors),
CheckWarningInfo: scenarioWarningChecker(idx, oneLoginScenarioWarnings),
})
Expand Down
15 changes: 13 additions & 2 deletions providertests/pingfed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"fmt"
"testing"

"github.com/mattermost/gosaml2"
saml2 "github.com/mattermost/gosaml2"
)

var pingFedScenarioErrors = map[int]string{
Expand Down Expand Up @@ -51,6 +51,10 @@ var pingFedScenarioErrors = map[int]string{

var pingFedScenarioWarnings = map[int]scenarioWarnings{}

// pingFedNilKeyStoreIndices is a slice of indices where keyStore is not required, but is passed nevertheless.
// This is to make the tests pass.
var pingFedNilKeyStoreIndices = []int{1, 2, 3}

var pingFedAtTimes = map[int]string{}

func TestPingFedCasesLocally(t *testing.T) {
Expand All @@ -68,10 +72,17 @@ func TestPingFedCasesLocally(t *testing.T) {
scenarios := []ProviderTestScenario{}
for _, idx := range scenarioIndexes(pingFedScenarioErrors, pingFedScenarioWarnings) {
response := LoadRawResponse(fmt.Sprintf("./testdata/pingfed/pfed11_response_%02d.b64", idx))
var nilKeyStore bool
for _, ind := range pingFedNilKeyStoreIndices {
if idx == ind {
nilKeyStore = true
break
}
}
scenarios = append(scenarios, ProviderTestScenario{
ScenarioName: fmt.Sprintf("Scenario_%02d", idx),
Response: response,
ServiceProvider: spAtTime(sp, getAtTime(idx, pingFedAtTimes), response),
ServiceProvider: spAtTime(sp, getAtTime(idx, pingFedAtTimes), response, nilKeyStore),
CheckError: scenarioErrorChecker(idx, pingFedScenarioErrors),
CheckWarningInfo: scenarioWarningChecker(idx, pingFedScenarioWarnings),
})
Expand Down
5 changes: 2 additions & 3 deletions providertests/providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"time"

"github.com/jonboulle/clockwork"
"github.com/mattermost/gosaml2"
"github.com/russellhaering/goxmldsig"
saml2 "github.com/mattermost/gosaml2"
dsig "github.com/russellhaering/goxmldsig"
)

func TestValidateResponses(t *testing.T) {
Expand All @@ -47,7 +47,6 @@ func TestValidateResponses(t *testing.T) {
SignAuthnRequests: false,
AudienceURI: "https://saml.test.nope/session/sso/saml/spentityid/dknhyszjl7",
IDPCertificateStore: LoadCertificateStore("./testdata/adfs_idp_signing_cert.pem"),
SPKeyStore: LoadKeyStore("./testdata/adfs_sp_encryption_cert.pem", "./testdata/adfs_sp_encryption_key.pem"),
SPSigningKeyStore: LoadKeyStore("./testdata/adfs_sp_signing_cert.pem", "./testdata/adfs_sp_signing_key.pem"),
Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2017, 9, 21, 23, 28, 0, 0, time.UTC))),
},
Expand Down
9 changes: 6 additions & 3 deletions providertests/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ import (
"time"

"github.com/jonboulle/clockwork"
"github.com/mattermost/gosaml2"
saml2 "github.com/mattermost/gosaml2"
"github.com/mattermost/gosaml2/types"
"github.com/russellhaering/goxmldsig"
dsig "github.com/russellhaering/goxmldsig"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -136,7 +136,7 @@ func getAtTime(idx int, scenarioAtTimes map[int]string) (atTime time.Time) {
return // zero time
}

func spAtTime(template *saml2.SAMLServiceProvider, atTime time.Time, rawResp string) *saml2.SAMLServiceProvider {
func spAtTime(template *saml2.SAMLServiceProvider, atTime time.Time, rawResp string, nilKeyStore bool) *saml2.SAMLServiceProvider {
resp := &types.Response{}
if rawResp == "" {
panic(fmt.Errorf("empty rawResp"))
Expand All @@ -152,6 +152,9 @@ func spAtTime(template *saml2.SAMLServiceProvider, atTime time.Time, rawResp str

var sp saml2.SAMLServiceProvider
sp = *template // copy most fields template, we only set the clock below
if nilKeyStore {
sp.SPKeyStore = nil
}
if atTime.IsZero() {
// Prefer more official Assertion IssueInstant over Response IssueIntant
// (Assertion will be signed, either individually or as part of Response)
Expand Down
19 changes: 19 additions & 0 deletions saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ func TestSAML(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, authRequestURL)

// XXX: Need to set the keystore to nil to bypass the assertion
// to check if all elements are encrypted.
sp.SPKeyStore = nil

authRequestString, err := sp.BuildAuthRequest()
require.NoError(t, err)
require.NotEmpty(t, authRequestString)
Expand Down Expand Up @@ -287,6 +291,21 @@ func TestInvalidResponseBadBase64(t *testing.T) {
require.Nil(t, response)
}

func TestMixedAssertions(t *testing.T) {
f, err := ioutil.ReadFile("./testdata/mixed_assertions.xml")
if err != nil {
t.Fatalf("could not open test file: %v\n", err)
}

b64Response := base64.StdEncoding.EncodeToString(f)
sp := &SAMLServiceProvider{
SkipSignatureValidation: true,
}
response, err := sp.ValidateEncodedResponse(b64Response)
require.EqualError(t, err, "found both or no assertion and encrypted assertion elements")
require.Nil(t, response)
}

func TestInvalidResponseBadCompression(t *testing.T) {
sp := &SAMLServiceProvider{}

Expand Down
Loading

0 comments on commit d97630e

Please sign in to comment.