diff --git a/config.go b/config.go index 3cb9ab07f..d203c0872 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "crypto/x509" "io" + "net" "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" @@ -217,6 +218,12 @@ type Config struct { // CertificateRequestMessageHook, if not nil, is called when a Certificate Request // message is sent from a server. The returned handshake message replaces the original message. CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + // Whenever a connection attempt is made, the server or application can call this callback function. + // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, + // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. + // If the callback function returns an error, the connection attempt will be aborted. + OnConnectionAttempt func(net.Addr) error } func defaultConnectContextMaker() (context.Context, func()) { diff --git a/conn.go b/conn.go index 1c670cc2c..d49b8740e 100644 --- a/conn.go +++ b/conn.go @@ -325,6 +325,11 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, if config == nil { return nil, errNoConfigProvided } + if config.OnConnectionAttempt != nil { + if err := config.OnConnectionAttempt(rAddr); err != nil { + return nil, err + } + } dconn, err := createConn(conn, rAddr, config, false) if err != nil { return nil, err diff --git a/conn_test.go b/conn_test.go index 960118327..d52d7eb85 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3436,3 +3436,44 @@ func TestHelloRandom(t *testing.T) { t.Error(err) } } + +func TestOnConnectionAttempt(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*20) + defer cancel() + + var clientOnConnectionAttempt, serverOnConnectionAttempt atomic.Int32 + + ca, cb := dpipe.Pipe() + clientErr := make(chan error, 1) + go func() { + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + OnConnectionAttempt: func(in net.Addr) error { + clientOnConnectionAttempt.Store(1) + return nil + }, + }, true) + clientErr <- err + }() + + expectedErr := &FatalError{} + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + OnConnectionAttempt: func(in net.Addr) error { + serverOnConnectionAttempt.Store(1) + return expectedErr + }, + }, true); !errors.Is(err, expectedErr) { + t.Fatal(err) + } + + if err := <-clientErr; err == nil { + t.Fatal(err) + } + + if v := serverOnConnectionAttempt.Load(); v != 1 { + t.Fatal("OnConnectionAttempt did not fire for server") + } + + if v := clientOnConnectionAttempt.Load(); v != 0 { + t.Fatal("OnConnectionAttempt fired for client") + } +} diff --git a/examples/listen/verify-brute-force-protection/main.go b/examples/listen/verify-brute-force-protection/main.go new file mode 100644 index 000000000..b5fb82c42 --- /dev/null +++ b/examples/listen/verify-brute-force-protection/main.go @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package main implements an example DTLS server which verifies client certificates. +// It also implements a basic Brute Force Attack protection. +package main + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "sync" + "time" + + "github.com/pion/dtls/v2" + "github.com/pion/dtls/v2/examples/util" +) + +func main() { + // Prepare the IP to connect to + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} + + // Create parent context to cleanup handshaking connections on exit. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // + // Everything below is the pion-DTLS API! Thanks for using it ❤️. + // + + // ************ Variables used to implement a basic Brute Force Attack protection ************* + var attempts = make(map[string]int) // Map of attempts for each IP address. + var attemptsMutex sync.Mutex // Mutex for the map of attempts. + var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes. + + certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", + "examples/certificates/server.pub.pem") + util.Check(err) + + rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") + util.Check(err) + certPool := x509.NewCertPool() + cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) + util.Check(err) + certPool.AddCert(cert) + + // Prepare the configuration of the DTLS connection + config := &dtls.Config{ + Certificates: []tls.Certificate{certificate}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + ClientAuth: dtls.RequireAndVerifyClientCert, + ClientCAs: certPool, + // Create timeout context for accepted connection. + ConnectContextMaker: func() (context.Context, func()) { + return context.WithTimeout(ctx, 30*time.Second) + }, + // This function will be called on each connection attempt. + OnConnectionAttempt: func(addr net.Addr) error { + // *************** Brute Force Attack protection *************** + // Check if the IP address is in the map, and if the IP address has exceeded the limit + attemptsMutex.Lock() + defer attemptsMutex.Unlock() + // Here I implement a time cleaner for the map of attempts, every 5 minutes I will + // decrement by 1 the number of attempts for each IP address. + if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { + attemptsCleaner = time.Now() + for k, v := range attempts { + if v > 0 { + attempts[k]-- + } + if attempts[k] == 0 { + delete(attempts, k) + } + } + } + // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) + attemptIP := addr.(*net.UDPAddr).IP.String() + if attempts[attemptIP] > 10 { + return fmt.Errorf("too many attempts from this IP address") + } + // Here I increment the number of attempts for this IP address (Brute Force Attack protection) + attempts[attemptIP]++ + // *************** END Brute Force Attack protection END *************** + return nil + }, + } + + // Connect to a DTLS server + listener, err := dtls.Listen("udp", addr, config) + util.Check(err) + defer func() { + util.Check(listener.Close()) + }() + + fmt.Println("Listening") + + // Simulate a chat session + hub := util.NewHub() + + go func() { + for { + // Wait for a connection. + conn, err := listener.Accept() + util.Check(err) + // defer conn.Close() // TODO: graceful shutdown + + // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` + // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose + // functions like `ConnectionState` etc. + + // *************** Brute Force Attack protection *************** + // Here I decrease the number of attempts for this IP address + attemptsMutex.Lock() + attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() + attempts[attemptIP]-- + // If the number of attempts for this IP address is 0, I delete the IP address from the map + if attempts[attemptIP] == 0 { + delete(attempts, attemptIP) + } + attemptsMutex.Unlock() + // *************** END Brute Force Attack protection END *************** + + // Register the connection with the chat hub + hub.Register(conn) + } + }() + + // Start chatting + hub.Chat() +}