diff --git a/certificate.go b/certificate.go index 519fc875f..e8dd2b83f 100644 --- a/certificate.go +++ b/certificate.go @@ -9,6 +9,8 @@ import ( "crypto/x509" "fmt" "strings" + + "github.com/pion/dtls/v2/pkg/protocol/handshake" ) // ClientHelloInfo contains information from a ClientHello message in order to @@ -22,6 +24,9 @@ type ClientHelloInfo struct { // CipherSuites lists the CipherSuites supported by the client (e.g. // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). CipherSuites []CipherSuiteID + + // RandomBytes stores the client hello random bytes + RandomBytes [handshake.RandomBytesLength]byte } // CertificateRequestInfo contains information from a server's diff --git a/config.go b/config.go index d765ecd91..3cb9ab07f 100644 --- a/config.go +++ b/config.go @@ -198,6 +198,9 @@ type Config struct { // https://datatracker.ietf.org/doc/html/rfc9146#section-4 PaddingLengthGenerator func(uint) uint + // HelloRandomBytesGenerator generates custom client hello random bytes. + HelloRandomBytesGenerator func() [handshake.RandomBytesLength]byte + // Handshake hooks: hooks can be used for testing invalid messages, // mimicking other implementations or randomizing fields, which is valuable // for applications that need censorship-resistance by making diff --git a/conn.go b/conn.go index e65163cf7..1c670cc2c 100644 --- a/conn.go +++ b/conn.go @@ -212,6 +212,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo localGetClientCertificate: config.GetClientCertificate, insecureSkipHelloVerify: config.InsecureSkipVerifyHello, connectionIDGenerator: config.ConnectionIDGenerator, + helloRandomBytesGenerator: config.HelloRandomBytesGenerator, clientHelloMessageHook: config.ClientHelloMessageHook, serverHelloMessageHook: config.ServerHelloMessageHook, certificateRequestMessageHook: config.CertificateRequestMessageHook, diff --git a/conn_test.go b/conn_test.go index 17eb932f0..960118327 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3363,3 +3363,76 @@ func TestApplicationDataQueueLimited(t *testing.T) { ca.Close() // nolint <-done } + +func TestHelloRandom(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ca, cb := dpipe.Pipe() + certificate, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + gotHello := make(chan struct{}) + + chRandom := [handshake.RandomBytesLength]byte{} + _, err = rand.Read(chRandom[:]) + if err != nil { + t.Fatal(err) + } + + go func() { + server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { + if len(chi.CipherSuites) == 0 { + return &certificate, nil + } + + if !bytes.Equal(chi.RandomBytes[:], chRandom[:]) { + t.Error("client hello random differs") + } + + return &certificate, nil + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + }, false) + if sErr != nil { + t.Error(sErr) + return + } + buf := make([]byte, 1024) + if _, sErr = server.Read(buf); sErr != nil { + t.Error(sErr) + } + gotHello <- struct{}{} + if sErr = server.Close(); sErr != nil { //nolint:contextcheck + t.Error(sErr) + } + }() + + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + LoggerFactory: logging.NewDefaultLoggerFactory(), + HelloRandomBytesGenerator: func() [handshake.RandomBytesLength]byte { + return chRandom + }, + InsecureSkipVerify: true, + }, false) + if err != nil { + t.Fatal(err) + } + if _, err = client.Write([]byte("hello")); err != nil { + t.Error(err) + } + select { + case <-gotHello: + // OK + case <-time.After(time.Second * 5): + t.Error("timeout") + } + + if err = client.Close(); err != nil { + t.Error(err) + } +} diff --git a/flight1handler.go b/flight1handler.go index 08a8f3921..6448fef79 100644 --- a/flight1handler.go +++ b/flight1handler.go @@ -57,6 +57,10 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha return nil, nil, err } + if cfg.helloRandomBytesGenerator != nil { + state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator() + } + extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, diff --git a/flight4handler.go b/flight4handler.go index 86f21464b..840a24f15 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -300,6 +300,7 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha certificate, err := cfg.getCertificate(&ClientHelloInfo{ ServerName: state.serverName, CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()}, + RandomBytes: state.remoteRandom.RandomBytes, }) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err diff --git a/handshaker.go b/handshaker.go index 317cb4e4e..09c6b4e3b 100644 --- a/handshaker.go +++ b/handshaker.go @@ -114,6 +114,7 @@ type handshakeConfig struct { ellipticCurves []elliptic.Curve insecureSkipHelloVerify bool connectionIDGenerator func() []byte + helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte onFlightState func(flightVal, handshakeState) log logging.LeveledLogger diff --git a/state.go b/state.go index 35a5b64fb..bc93bdf0c 100644 --- a/state.go +++ b/state.go @@ -258,3 +258,8 @@ func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile { return 0 } + +// RemoteRandomBytes returns the remote client hello random bytes +func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte { + return s.remoteRandom.RandomBytes +}