Skip to content

Commit

Permalink
Add OnConnectionAttempt to Config
Browse files Browse the repository at this point in the history
Convenience callback that can be fired for Servers with the net.Addr of
the remote
  • Loading branch information
tonisole authored and Sean-Der committed Jul 5, 2024
1 parent 48d6748 commit a6d9640
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 0 deletions.
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/pion/dtls/v2/pkg/crypto/elliptic"
Expand Down Expand Up @@ -221,6 +222,12 @@ type Config struct {
// CertificateRequestMessageHook, if not nil, is called when a Certificate Request
// message is sent from a server. The returned handshake message replaces the original message.
CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message

// OnConnectionAttempt is fired Whenever a connection attempt is made, the server or application can call this callback function.
// The callback function can then implement logic to handle the connection attempt, such as logging the attempt,
// checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks.
// If the callback function returns an error, the connection attempt will be aborted.
OnConnectionAttempt func(net.Addr) error
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
5 changes: 5 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
if config == nil {
return nil, errNoConfigProvided
}
if config.OnConnectionAttempt != nil {
if err := config.OnConnectionAttempt(rAddr); err != nil {
return nil, err
}
}
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
Expand Down
47 changes: 47 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3435,3 +3435,50 @@ func TestHelloRandom(t *testing.T) {
t.Error(err)
}
}

func TestOnConnectionAttempt(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*20)
defer cancel()

var clientOnConnectionAttempt, serverOnConnectionAttempt atomic.Int32

ca, cb := dpipe.Pipe()
clientErr := make(chan error, 1)
go func() {
_, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{
OnConnectionAttempt: func(in net.Addr) error {
clientOnConnectionAttempt.Store(1)
if in == nil {
t.Fatal("net.Addr is nil") //nolint: govet
}
return nil
},
}, true)
clientErr <- err
}()

expectedErr := &FatalError{}
if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
OnConnectionAttempt: func(in net.Addr) error {
serverOnConnectionAttempt.Store(1)
if in == nil {
t.Fatal("net.Addr is nil") //nolint: govet
}
return expectedErr
},
}, true); !errors.Is(err, expectedErr) {
t.Fatal(err)
}

if err := <-clientErr; err == nil {
t.Fatal(err)
}

if v := serverOnConnectionAttempt.Load(); v != 1 {
t.Fatal("OnConnectionAttempt did not fire for server")
}

if v := clientOnConnectionAttempt.Load(); v != 0 {
t.Fatal("OnConnectionAttempt fired for client")
}
}
132 changes: 132 additions & 0 deletions examples/listen/verify-brute-force-protection/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package main implements an example DTLS server which verifies client certificates.
// It also implements a basic Brute Force Attack protection.
package main

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"sync"
"time"

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/examples/util"
)

func main() {
// Prepare the IP to connect to
addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}

// Create parent context to cleanup handshaking connections on exit.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

//
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//

// ************ Variables used to implement a basic Brute Force Attack protection *************
var attempts = make(map[string]int) // Map of attempts for each IP address.
var attemptsMutex sync.Mutex // Mutex for the map of attempts.
var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes.

certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem",
"examples/certificates/server.pub.pem")
util.Check(err)

rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem")
util.Check(err)
certPool := x509.NewCertPool()
cert, err := x509.ParseCertificate(rootCertificate.Certificate[0])
util.Check(err)
certPool.AddCert(cert)

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
Certificates: []tls.Certificate{certificate},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
ClientAuth: dtls.RequireAndVerifyClientCert,
ClientCAs: certPool,
// Create timeout context for accepted connection.
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
// This function will be called on each connection attempt.
OnConnectionAttempt: func(addr net.Addr) error {
// *************** Brute Force Attack protection ***************
// Check if the IP address is in the map, and if the IP address has exceeded the limit
attemptsMutex.Lock()
defer attemptsMutex.Unlock()
// Here I implement a time cleaner for the map of attempts, every 5 minutes I will
// decrement by 1 the number of attempts for each IP address.
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) {
attemptsCleaner = time.Now()
for k, v := range attempts {
if v > 0 {
attempts[k]--
}
if attempts[k] == 0 {
delete(attempts, k)
}
}
}
// Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection)
attemptIP := addr.(*net.UDPAddr).IP.String()
if attempts[attemptIP] > 10 {
return fmt.Errorf("too many attempts from this IP address")
}
// Here I increment the number of attempts for this IP address (Brute Force Attack protection)
attempts[attemptIP]++
// *************** END Brute Force Attack protection END ***************
return nil
},
}

// Connect to a DTLS server
listener, err := dtls.Listen("udp", addr, config)
util.Check(err)
defer func() {
util.Check(listener.Close())
}()

fmt.Println("Listening")

// Simulate a chat session
hub := util.NewHub()

go func() {
for {
// Wait for a connection.
conn, err := listener.Accept()
util.Check(err)
// defer conn.Close() // TODO: graceful shutdown

// `conn` is of type `net.Conn` but may be casted to `dtls.Conn`
// using `dtlsConn := conn.(*dtls.Conn)` in order to to expose
// functions like `ConnectionState` etc.

// *************** Brute Force Attack protection ***************
// Here I decrease the number of attempts for this IP address
attemptsMutex.Lock()
attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String()
attempts[attemptIP]--
// If the number of attempts for this IP address is 0, I delete the IP address from the map
if attempts[attemptIP] == 0 {
delete(attempts, attemptIP)
}
attemptsMutex.Unlock()
// *************** END Brute Force Attack protection END ***************

// Register the connection with the chat hub
hub.Register(conn)
}
}()

// Start chatting
hub.Chat()
}

0 comments on commit a6d9640

Please sign in to comment.