Skip to content

Commit

Permalink
Add ability to select cert based on ch rand bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
mingyech authored and Sean-Der committed Jun 9, 2024
1 parent eddca22 commit 61b3466
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 0 deletions.
5 changes: 5 additions & 0 deletions certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"crypto/x509"
"fmt"
"strings"

"github.com/pion/dtls/v2/pkg/protocol/handshake"
)

// ClientHelloInfo contains information from a ClientHello message in order to
Expand All @@ -22,6 +24,9 @@ type ClientHelloInfo struct {
// CipherSuites lists the CipherSuites supported by the client (e.g.
// TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
CipherSuites []CipherSuiteID

// RandomBytes stores the client hello random bytes
RandomBytes [handshake.RandomBytesLength]byte
}

// CertificateRequestInfo contains information from a server's
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ type Config struct {
// https://datatracker.ietf.org/doc/html/rfc9146#section-4
PaddingLengthGenerator func(uint) uint

// HelloRandomBytesGenerator generates custom client hello random bytes.
HelloRandomBytesGenerator func() [handshake.RandomBytesLength]byte

// Handshake hooks: hooks can be used for testing invalid messages,
// mimicking other implementations or randomizing fields, which is valuable
// for applications that need censorship-resistance by making
Expand Down
1 change: 1 addition & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
localGetClientCertificate: config.GetClientCertificate,
insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
connectionIDGenerator: config.ConnectionIDGenerator,
helloRandomBytesGenerator: config.HelloRandomBytesGenerator,
clientHelloMessageHook: config.ClientHelloMessageHook,
serverHelloMessageHook: config.ServerHelloMessageHook,
certificateRequestMessageHook: config.CertificateRequestMessageHook,
Expand Down
73 changes: 73 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3363,3 +3363,76 @@ func TestApplicationDataQueueLimited(t *testing.T) {
ca.Close() // nolint
<-done
}

func TestHelloRandom(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ca, cb := dpipe.Pipe()
certificate, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}
gotHello := make(chan struct{})

chRandom := [handshake.RandomBytesLength]byte{}
_, err = rand.Read(chRandom[:])
if err != nil {
t.Fatal(err)
}

go func() {
server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) {
if len(chi.CipherSuites) == 0 {
return &certificate, nil
}

if !bytes.Equal(chi.RandomBytes[:], chRandom[:]) {
t.Error("client hello random differs")
}

return &certificate, nil
},
LoggerFactory: logging.NewDefaultLoggerFactory(),
}, false)
if sErr != nil {
t.Error(sErr)
return
}
buf := make([]byte, 1024)
if _, sErr = server.Read(buf); sErr != nil {
t.Error(sErr)
}
gotHello <- struct{}{}
if sErr = server.Close(); sErr != nil { //nolint:contextcheck
t.Error(sErr)
}
}()

client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{
LoggerFactory: logging.NewDefaultLoggerFactory(),
HelloRandomBytesGenerator: func() [handshake.RandomBytesLength]byte {
return chRandom
},
InsecureSkipVerify: true,
}, false)
if err != nil {
t.Fatal(err)
}
if _, err = client.Write([]byte("hello")); err != nil {
t.Error(err)
}
select {
case <-gotHello:
// OK
case <-time.After(time.Second * 5):
t.Error("timeout")
}

if err = client.Close(); err != nil {
t.Error(err)
}
}
4 changes: 4 additions & 0 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha
return nil, nil, err
}

if cfg.helloRandomBytesGenerator != nil {
state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator()
}

extensions := []extension.Extension{
&extension.SupportedSignatureAlgorithms{
SignatureHashAlgorithms: cfg.localSignatureSchemes,
Expand Down
1 change: 1 addition & 0 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
certificate, err := cfg.getCertificate(&ClientHelloInfo{
ServerName: state.serverName,
CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()},
RandomBytes: state.remoteRandom.RandomBytes,
})
if err != nil {
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
Expand Down
1 change: 1 addition & 0 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ type handshakeConfig struct {
ellipticCurves []elliptic.Curve
insecureSkipHelloVerify bool
connectionIDGenerator func() []byte
helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte

onFlightState func(flightVal, handshakeState)
log logging.LeveledLogger
Expand Down
5 changes: 5 additions & 0 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,8 @@ func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {

return 0
}

// RemoteRandomBytes returns the remote client hello random bytes
func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte {
return s.remoteRandom.RandomBytes
}

0 comments on commit 61b3466

Please sign in to comment.