Skip to content

Commit

Permalink
refactor extension functions to support x509 and google_x509
Browse files Browse the repository at this point in the history
Signed-off-by: linus-sun <[email protected]>
  • Loading branch information
linus-sun committed Nov 13, 2024
1 parent 74be5d5 commit 425f8bd
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 166 deletions.
67 changes: 0 additions & 67 deletions pkg/ct/identity.go

This file was deleted.

56 changes: 0 additions & 56 deletions pkg/ct/identity_test.go

This file was deleted.

9 changes: 2 additions & 7 deletions pkg/ct/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
ctclient "github.com/google/certificate-transparency-go/client"
"github.com/sigstore/rekor-monitor/pkg/fulcio/extensions"
"github.com/sigstore/rekor-monitor/pkg/identity"

"github.com/google/certificate-transparency-go/asn1"
)

func GetCTLogEntries(logClient *ctclient.LogClient, startIndex int, endIndex int) ([]ct.LogEntry, error) {
Expand All @@ -35,7 +33,7 @@ func GetCTLogEntries(logClient *ctclient.LogClient, startIndex int, endIndex int
return entries, nil
}

func ScanEntryCertSubject(logEntry ct.LogEntry, monitoredSubjects []string) ([]*identity.LogEntry, error) {
func ScanEntrySubject(logEntry ct.LogEntry, monitoredSubjects []string) ([]*identity.LogEntry, error) {
subject := logEntry.X509Cert.Subject.String()
matchedEntries := []*identity.LogEntry{}
for _, monitoredSub := range monitoredSubjects {
Expand All @@ -59,10 +57,7 @@ func ScanEntryOIDExtensions(logEntry ct.LogEntry, monitoredOIDMatchers []extensi
matchedEntries := []*identity.LogEntry{}
cert := logEntry.X509Cert
for _, monitoredOID := range monitoredOIDMatchers {
// must cast encoding/asn1 objectIdentifier to google/certificate-transparency-go fork of asn1.ObjectIdentifier
oidIntArray := []int(monitoredOID.ObjectIdentifier)
matchingOID := asn1.ObjectIdentifier(oidIntArray)
match, _, extValue, err := OIDMatchesPolicy(cert, matchingOID, monitoredOID.ExtensionValues)
match, _, extValue, err := identity.OIDMatchesPolicy(cert, monitoredOID.ObjectIdentifier, monitoredOID.ExtensionValues)
if err != nil {
return nil, fmt.Errorf("error with policy matching at index %d: %w", logEntry.Index, err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/ct/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const (
organizationName = "test-org"
)

func TestScanEntryCertSubject(t *testing.T) {
func TestScanEntrySubject(t *testing.T) {
testCases := map[string]struct {
inputEntry ct.LogEntry
inputSubjects []string
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestScanEntryCertSubject(t *testing.T) {
}

for _, tc := range testCases {
logEntries, err := ScanEntryCertSubject(tc.inputEntry, tc.inputSubjects)
logEntries, err := ScanEntrySubject(tc.inputEntry, tc.inputSubjects)
if err != nil {
t.Errorf("received error scanning entry for subjects: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/ct/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"testing"

google_asn1 "github.com/google/certificate-transparency-go/asn1"
"github.com/google/certificate-transparency-go/x509"
google_x509 "github.com/google/certificate-transparency-go/x509"
google_pkix "github.com/google/certificate-transparency-go/x509/pkix"
)

Expand Down Expand Up @@ -56,12 +56,12 @@ func serveRspAt(t *testing.T, path, rsp string) *httptest.Server {
})
}

func mockCertificateWithExtension(oid google_asn1.ObjectIdentifier, value string) (*x509.Certificate, error) {
func mockCertificateWithExtension(oid google_asn1.ObjectIdentifier, value string) (*google_x509.Certificate, error) {
extValue, err := google_asn1.Marshal(value)
if err != nil {
return nil, err
}
cert := &x509.Certificate{
cert := &google_x509.Certificate{
Extensions: []google_pkix.Extension{
{
Id: oid,
Expand Down
113 changes: 82 additions & 31 deletions pkg/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ import (
"crypto/x509"
"encoding/asn1"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"

"github.com/sigstore/rekor-monitor/pkg/fulcio/extensions"
"github.com/sigstore/sigstore/pkg/cryptoutils"

google_asn1 "github.com/google/certificate-transparency-go/asn1"
google_x509 "github.com/google/certificate-transparency-go/x509"
)

var (
Expand Down Expand Up @@ -193,52 +197,99 @@ func MonitoredValuesExist(mvs MonitoredValues) bool {

// getExtension gets a certificate extension by OID where the extension value is an
// ASN.1-encoded string
func getExtension(cert *x509.Certificate, oid asn1.ObjectIdentifier) (string, error) {
for _, ext := range cert.Extensions {
if !ext.Id.Equal(oid) {
continue
}
var extValue string
rest, err := asn1.Unmarshal(ext.Value, &extValue)
if err != nil {
return "", fmt.Errorf("%w", err)
func getExtension[Certificate *x509.Certificate | *google_x509.Certificate](certificate Certificate, oid asn1.ObjectIdentifier) (string, error) {
switch cert := any(certificate).(type) {
case *x509.Certificate:
for _, ext := range cert.Extensions {
if !ext.Id.Equal(oid) {
continue
}
var extValue string
rest, err := asn1.Unmarshal(ext.Value, &extValue)
if err != nil {
return "", fmt.Errorf("%w", err)
}
if len(rest) != 0 {
return "", fmt.Errorf("unmarshalling extension had rest for oid %v", oid)
}
return extValue, nil
}
if len(rest) != 0 {
return "", fmt.Errorf("unmarshalling extension had rest for oid %v", oid)
return "", nil
case *google_x509.Certificate:
for _, ext := range cert.Extensions {
if !ext.Id.Equal((google_asn1.ObjectIdentifier)(oid)) {
continue
}
var extValue string
rest, err := asn1.Unmarshal(ext.Value, &extValue)
if err != nil {
return "", fmt.Errorf("%w", err)
}
if len(rest) != 0 {
return "", fmt.Errorf("unmarshalling extension had rest for oid %v", oid)
}
return extValue, nil
}
return extValue, nil
return "", nil
}
return "", nil
return "", errors.New("certificate was neither x509 nor google_x509")
}

// getDeprecatedExtension gets a certificate extension by OID where the extension value is a raw string
func getDeprecatedExtension(cert *x509.Certificate, oid asn1.ObjectIdentifier) (string, error) {
for _, ext := range cert.Extensions {
if ext.Id.Equal(oid) {
return string(ext.Value), nil
func getDeprecatedExtension[Certificate *x509.Certificate | *google_x509.Certificate](certificate Certificate, oid asn1.ObjectIdentifier) (string, error) {
switch cert := any(certificate).(type) {
case *x509.Certificate:
for _, ext := range cert.Extensions {
if ext.Id.Equal(oid) {
return string(ext.Value), nil
}
}
return "", nil
case *google_x509.Certificate:
for _, ext := range cert.Extensions {
if ext.Id.Equal((google_asn1.ObjectIdentifier)(oid)) {
return string(ext.Value), nil
}
}
return "", nil
}
return "", nil
return "", errors.New("certificate was neither x509 nor google_x509")
}

// OIDMatchesPolicy returns if a certificate contains both a given OID field and a matching value associated with that field
// if true, it returns the OID extension and extension value that were matched on
func OIDMatchesPolicy(cert *x509.Certificate, oid asn1.ObjectIdentifier, extensionValues []string) (bool, asn1.ObjectIdentifier, string, error) {
extValue, err := getExtension(cert, oid)
if err != nil {
return false, nil, "", fmt.Errorf("error getting extension value: %w", err)
}
if extValue == "" {
func OIDMatchesPolicy[Certificate *x509.Certificate | *google_x509.Certificate](certificate Certificate, oid asn1.ObjectIdentifier, extensionValues []string) (bool, asn1.ObjectIdentifier, string, error) {
switch cert := any(certificate).(type) {
case *x509.Certificate:
extValue, err := getExtension(cert, oid)
if err != nil {
return false, nil, "", fmt.Errorf("error getting extension value: %w", err)
}
if extValue == "" {
return false, nil, "", nil
}
for _, extensionValue := range extensionValues {
if extValue == extensionValue {
return true, oid, extValue, nil
}
}
return false, nil, "", nil
}

for _, extensionValue := range extensionValues {
if extValue == extensionValue {
return true, oid, extValue, nil
case *google_x509.Certificate:
extValue, err := getExtension(cert, oid)
if err != nil {
return false, nil, "", fmt.Errorf("error getting extension value: %w", err)
}
if extValue == "" {
return false, nil, "", nil
}
for _, extensionValue := range extensionValues {
if extValue == extensionValue {
return true, oid, extValue, nil
}
}
return false, nil, "", nil
}

return false, nil, "", nil
return false, nil, "", errors.New("certificate was neither x509 nor google_x509")
}

// CertMatchesPolicy returns true if a certificate contains a given subject and optionally a given issuer
Expand Down
32 changes: 32 additions & 0 deletions pkg/identity/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ import (
"strings"
"testing"

google_asn1 "github.com/google/certificate-transparency-go/asn1"
google_x509 "github.com/google/certificate-transparency-go/x509"
google_pkix "github.com/google/certificate-transparency-go/x509/pkix"
"github.com/sigstore/rekor-monitor/pkg/fulcio/extensions"
)

Expand Down Expand Up @@ -382,6 +385,35 @@ func TestOIDMatchesValue(t *testing.T) {
}
}

// Test when OID is present and matches value
func TestGoogleOIDMatchesValue(t *testing.T) {
oid := asn1.ObjectIdentifier{2, 5, 29, 17}
extValueString := "test cert value"
extensionValues := []string{extValueString}
marshalledExtValue, err := google_asn1.Marshal(extValueString)
if err != nil {
t.Errorf("error marshalling extension value: %v", err)
}
cert := &google_x509.Certificate{
Extensions: []google_pkix.Extension{
{
Id: google_asn1.ObjectIdentifier{2, 5, 29, 17},
Value: marshalledExtValue,
},
},
}
matches, matchedOID, extValue, err := OIDMatchesPolicy(cert, oid, extensionValues)
if !matches || err != nil {
t.Errorf("Expected true, got %v, error %v", matches, err)
}
if matchedOID.String() != oid.String() {
t.Errorf("Expected oid to equal 2.5.29.17, got %s", matchedOID.String())
}
if extValue != extValueString {
t.Errorf("Expected string to equal 'test cert value', got %s", extValue)
}
}

// Test when cert is present but the value does not match
func TestCertDoesNotMatch(t *testing.T) {
emailAddr := "[email protected]"
Expand Down

0 comments on commit 425f8bd

Please sign in to comment.