Skip to content

Commit

Permalink
Adding cache for dns
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Schendel <[email protected]>
  • Loading branch information
amitschendel committed Nov 7, 2024
1 parent 8c43514 commit 174309a
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 27 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func main() {
var dnsManagerClient dnsmanager.DNSManagerClient
var dnsResolver dnsmanager.DNSResolver
if cfg.EnableNetworkTracing || cfg.EnableRuntimeDetection {
dnsManager := dnsmanager.CreateDNSManager()
dnsManager := dnsmanager.CreateDNSManager(ctx)
dnsManagerClient = dnsManager
// NOTE: dnsResolver is set for threat detection.
dnsResolver = dnsManager
Expand Down
123 changes: 113 additions & 10 deletions pkg/dnsmanager/dns_manager.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,140 @@
package dnsmanager

import (
"context"
"net"
"sync"
"time"

"github.com/goradd/maps"
tracerdnstype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/dns/types"
)

// DNSManager is used to manage DNS events and save IP resolutions. It exposes an API to resolve IP address to domain name.
// DNSManager is used to manage DNS events and save IP resolutions.
type DNSManager struct {
addressToDomainMap maps.SafeMap[string, string] // this map is used to resolve IP address to domain name
addressToDomainMap maps.SafeMap[string, string]
lookupCache *sync.Map // Cache for DNS lookups
failureCache *sync.Map // Cache for failed lookups to prevent repeated attempts
cleanupTicker *time.Ticker // Ticker for periodic cache cleanup
cancel context.CancelFunc // Cancel function for cleanup goroutine
}

type cacheEntry struct {
addresses []string
timestamp time.Time
}

const (
defaultPositiveTTL = 1 * time.Minute // Default TTL for successful lookups
defaultNegativeTTL = 5 * time.Second // Default TTL for failed lookups
cleanupInterval = 5 * time.Minute // How often to run cache cleanup
)

var _ DNSManagerClient = (*DNSManager)(nil)
var _ DNSResolver = (*DNSManager)(nil)

func CreateDNSManager() *DNSManager {
return &DNSManager{}
func CreateDNSManager(ctx context.Context) *DNSManager {
ctx, cancel := context.WithCancel(ctx)
dm := &DNSManager{
lookupCache: &sync.Map{},
failureCache: &sync.Map{},
cleanupTicker: time.NewTicker(cleanupInterval),
cancel: cancel,
}

// Start the cleanup goroutine
go dm.cacheCleaner(ctx)

return dm
}

func (dm *DNSManager) ReportDNSEvent(dnsEvent tracerdnstype.Event) {

// If we have addresses in the event, use them directly
if len(dnsEvent.Addresses) > 0 {
for _, address := range dnsEvent.Addresses {
dm.addressToDomainMap.Set(address, dnsEvent.DNSName)
}
} else {
addresses, err := net.LookupIP(dnsEvent.DNSName)
if err != nil {

// Update the cache with these known good addresses
dm.lookupCache.Store(dnsEvent.DNSName, cacheEntry{
addresses: dnsEvent.Addresses,
timestamp: time.Now(),
})
return
}

// Check if we've recently failed to look up this domain
if failedTime, failed := dm.failureCache.Load(dnsEvent.DNSName); failed {
if time.Since(failedTime.(time.Time)) < defaultNegativeTTL {
return
}
for _, address := range addresses {
dm.addressToDomainMap.Set(address.String(), dnsEvent.DNSName)
// Failed entry has expired, remove it
dm.failureCache.Delete(dnsEvent.DNSName)
}

// Check if we have a cached result
if cached, ok := dm.lookupCache.Load(dnsEvent.DNSName); ok {
entry := cached.(cacheEntry)
if time.Since(entry.timestamp) < defaultPositiveTTL {
// Use cached addresses
for _, addr := range entry.addresses {
dm.addressToDomainMap.Set(addr, dnsEvent.DNSName)
}
return
}
}

// Only perform lookup if we don't have cached results
addresses, err := net.LookupIP(dnsEvent.DNSName)
if err != nil {
// Cache the failure
dm.failureCache.Store(dnsEvent.DNSName, time.Now())
return
}

// Convert addresses to strings and store them
addrStrings := make([]string, 0, len(addresses))
for _, addr := range addresses {
addrStr := addr.String()
addrStrings = append(addrStrings, addrStr)
dm.addressToDomainMap.Set(addrStr, dnsEvent.DNSName)
}

// Cache the successful lookup
dm.lookupCache.Store(dnsEvent.DNSName, cacheEntry{
addresses: addrStrings,
timestamp: time.Now(),
})
}

// cacheCleaner runs periodically to clean up expired entries from both caches
func (dm *DNSManager) cacheCleaner(ctx context.Context) {
for {
select {
case <-dm.cleanupTicker.C:
now := time.Now()

// Clean up positive cache
dm.lookupCache.Range(func(key, value interface{}) bool {
entry := value.(cacheEntry)
if now.Sub(entry.timestamp) > defaultPositiveTTL {
dm.lookupCache.Delete(key)
}
return true
})

// Clean up negative cache
dm.failureCache.Range(func(key, value interface{}) bool {
failedTime := value.(time.Time)
if now.Sub(failedTime) > defaultNegativeTTL {
dm.failureCache.Delete(key)
}
return true
})

case <-ctx.Done():
dm.cleanupTicker.Stop()
return
}
}
}
Expand Down
75 changes: 59 additions & 16 deletions pkg/dnsmanager/dns_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package dnsmanager

import (
"net"
"sync"
"testing"
"time"

"github.com/goradd/maps"
tracerdnstype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/dns/types"
)

Expand All @@ -13,6 +16,7 @@ func TestResolveIPAddress(t *testing.T) {
dnsEvent tracerdnstype.Event
ipAddr string
want string
wantOk bool
}{
{
name: "ip found",
Expand All @@ -24,7 +28,8 @@ func TestResolveIPAddress(t *testing.T) {
"67.225.146.248",
},
},
want: "test.com",
want: "test.com",
wantOk: true,
},
{
name: "ip not found",
Expand All @@ -36,57 +41,95 @@ func TestResolveIPAddress(t *testing.T) {
"54.23.332.4",
},
},
want: "",
want: "",
wantOk: false,
},
{
name: "no address",
ipAddr: "67.225.146.248",
dnsEvent: tracerdnstype.Event{
DNSName: "test.com",
NumAnswers: 0, // will not resolve
NumAnswers: 0,
},
want: "",
want: "",
wantOk: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dm := &DNSManager{}
// Create a properly initialized DNSManager
dm := &DNSManager{
addressToDomainMap: maps.SafeMap[string, string]{},
lookupCache: &sync.Map{},
failureCache: &sync.Map{},
cleanupTicker: time.NewTicker(cleanupInterval),
}

dm.ReportDNSEvent(tt.dnsEvent)
got, _ := dm.ResolveIPAddress(tt.ipAddr)
if got != tt.want {
t.Errorf("ResolveIPAddress() got = %v, want %v", got, tt.want)
got, ok := dm.ResolveIPAddress(tt.ipAddr)
if got != tt.want || ok != tt.wantOk {
t.Errorf("ResolveIPAddress() got = %v, ok = %v, want = %v, wantOk = %v", got, ok, tt.want, tt.wantOk)
}

// Cleanup
dm.cleanupTicker.Stop()
})
}
}

func TestResolveIPAddressFallback(t *testing.T) {
// Skip the test if running in CI or without network access
if testing.Short() {
t.Skip("Skipping test that requires network access")
}

tests := []struct {
name string
dnsEvent tracerdnstype.Event
want string
wantOk bool
}{

{
name: "dns resolution fallback",
dnsEvent: tracerdnstype.Event{
DNSName: "test.com",
DNSName: "example.com", // Using example.com as it's guaranteed to exist
NumAnswers: 1,
},
want: "test.com",
want: "example.com",
wantOk: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addresses, _ := net.LookupIP(tt.dnsEvent.DNSName)
dm := &DNSManager{}
// Create a properly initialized DNSManager
dm := &DNSManager{
addressToDomainMap: maps.SafeMap[string, string]{},
lookupCache: &sync.Map{},
failureCache: &sync.Map{},
cleanupTicker: time.NewTicker(cleanupInterval),
}

// Perform the actual DNS lookup
addresses, err := net.LookupIP(tt.dnsEvent.DNSName)
if err != nil {
t.Skipf("DNS lookup failed: %v", err)
return
}
if len(addresses) == 0 {
t.Skip("No addresses returned from DNS lookup")
return
}

dm.ReportDNSEvent(tt.dnsEvent)
got, _ := dm.ResolveIPAddress(addresses[0].String())
if got != tt.want {
t.Errorf("ResolveIPAddress() got = %v, want %v", got, tt.want)
got, ok := dm.ResolveIPAddress(addresses[0].String())
if got != tt.want || ok != tt.wantOk {
t.Errorf("ResolveIPAddress() got = %v, ok = %v, want = %v, wantOk = %v", got, ok, tt.want, tt.wantOk)
}

// Cleanup
dm.cleanupTicker.Stop()
})
}
}

0 comments on commit 174309a

Please sign in to comment.