diff --git a/main.go b/main.go index 34e0318b..6deafa64 100644 --- a/main.go +++ b/main.go @@ -198,7 +198,7 @@ func main() { var dnsManagerClient dnsmanager.DNSManagerClient var dnsResolver dnsmanager.DNSResolver if cfg.EnableNetworkTracing || cfg.EnableRuntimeDetection { - dnsManager := dnsmanager.CreateDNSManager() + dnsManager := dnsmanager.CreateDNSManager(ctx) dnsManagerClient = dnsManager // NOTE: dnsResolver is set for threat detection. dnsResolver = dnsManager diff --git a/pkg/dnsmanager/dns_manager.go b/pkg/dnsmanager/dns_manager.go index 3a7d2f49..a8ef3e8a 100644 --- a/pkg/dnsmanager/dns_manager.go +++ b/pkg/dnsmanager/dns_manager.go @@ -1,37 +1,140 @@ package dnsmanager import ( + "context" "net" + "sync" + "time" "github.com/goradd/maps" tracerdnstype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/dns/types" ) -// DNSManager is used to manage DNS events and save IP resolutions. It exposes an API to resolve IP address to domain name. +// DNSManager is used to manage DNS events and save IP resolutions. type DNSManager struct { - addressToDomainMap maps.SafeMap[string, string] // this map is used to resolve IP address to domain name + addressToDomainMap maps.SafeMap[string, string] + lookupCache *sync.Map // Cache for DNS lookups + failureCache *sync.Map // Cache for failed lookups to prevent repeated attempts + cleanupTicker *time.Ticker // Ticker for periodic cache cleanup + cancel context.CancelFunc // Cancel function for cleanup goroutine } +type cacheEntry struct { + addresses []string + timestamp time.Time +} + +const ( + defaultPositiveTTL = 1 * time.Minute // Default TTL for successful lookups + defaultNegativeTTL = 5 * time.Second // Default TTL for failed lookups + cleanupInterval = 5 * time.Minute // How often to run cache cleanup +) + var _ DNSManagerClient = (*DNSManager)(nil) var _ DNSResolver = (*DNSManager)(nil) -func CreateDNSManager() *DNSManager { - return &DNSManager{} +func CreateDNSManager(ctx context.Context) *DNSManager { + ctx, cancel := context.WithCancel(ctx) + dm := &DNSManager{ + lookupCache: &sync.Map{}, + failureCache: &sync.Map{}, + cleanupTicker: time.NewTicker(cleanupInterval), + cancel: cancel, + } + + // Start the cleanup goroutine + go dm.cacheCleaner(ctx) + + return dm } func (dm *DNSManager) ReportDNSEvent(dnsEvent tracerdnstype.Event) { - + // If we have addresses in the event, use them directly if len(dnsEvent.Addresses) > 0 { for _, address := range dnsEvent.Addresses { dm.addressToDomainMap.Set(address, dnsEvent.DNSName) } - } else { - addresses, err := net.LookupIP(dnsEvent.DNSName) - if err != nil { + + // Update the cache with these known good addresses + dm.lookupCache.Store(dnsEvent.DNSName, cacheEntry{ + addresses: dnsEvent.Addresses, + timestamp: time.Now(), + }) + return + } + + // Check if we've recently failed to look up this domain + if failedTime, failed := dm.failureCache.Load(dnsEvent.DNSName); failed { + if time.Since(failedTime.(time.Time)) < defaultNegativeTTL { return } - for _, address := range addresses { - dm.addressToDomainMap.Set(address.String(), dnsEvent.DNSName) + // Failed entry has expired, remove it + dm.failureCache.Delete(dnsEvent.DNSName) + } + + // Check if we have a cached result + if cached, ok := dm.lookupCache.Load(dnsEvent.DNSName); ok { + entry := cached.(cacheEntry) + if time.Since(entry.timestamp) < defaultPositiveTTL { + // Use cached addresses + for _, addr := range entry.addresses { + dm.addressToDomainMap.Set(addr, dnsEvent.DNSName) + } + return + } + } + + // Only perform lookup if we don't have cached results + addresses, err := net.LookupIP(dnsEvent.DNSName) + if err != nil { + // Cache the failure + dm.failureCache.Store(dnsEvent.DNSName, time.Now()) + return + } + + // Convert addresses to strings and store them + addrStrings := make([]string, 0, len(addresses)) + for _, addr := range addresses { + addrStr := addr.String() + addrStrings = append(addrStrings, addrStr) + dm.addressToDomainMap.Set(addrStr, dnsEvent.DNSName) + } + + // Cache the successful lookup + dm.lookupCache.Store(dnsEvent.DNSName, cacheEntry{ + addresses: addrStrings, + timestamp: time.Now(), + }) +} + +// cacheCleaner runs periodically to clean up expired entries from both caches +func (dm *DNSManager) cacheCleaner(ctx context.Context) { + for { + select { + case <-dm.cleanupTicker.C: + now := time.Now() + + // Clean up positive cache + dm.lookupCache.Range(func(key, value interface{}) bool { + entry := value.(cacheEntry) + if now.Sub(entry.timestamp) > defaultPositiveTTL { + dm.lookupCache.Delete(key) + } + return true + }) + + // Clean up negative cache + dm.failureCache.Range(func(key, value interface{}) bool { + failedTime := value.(time.Time) + if now.Sub(failedTime) > defaultNegativeTTL { + dm.failureCache.Delete(key) + } + return true + }) + + case <-ctx.Done(): + dm.cleanupTicker.Stop() + return } } } diff --git a/pkg/dnsmanager/dns_manager_test.go b/pkg/dnsmanager/dns_manager_test.go index bc8edebc..30a9a218 100644 --- a/pkg/dnsmanager/dns_manager_test.go +++ b/pkg/dnsmanager/dns_manager_test.go @@ -2,8 +2,11 @@ package dnsmanager import ( "net" + "sync" "testing" + "time" + "github.com/goradd/maps" tracerdnstype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/dns/types" ) @@ -13,6 +16,7 @@ func TestResolveIPAddress(t *testing.T) { dnsEvent tracerdnstype.Event ipAddr string want string + wantOk bool }{ { name: "ip found", @@ -24,7 +28,8 @@ func TestResolveIPAddress(t *testing.T) { "67.225.146.248", }, }, - want: "test.com", + want: "test.com", + wantOk: true, }, { name: "ip not found", @@ -36,57 +41,95 @@ func TestResolveIPAddress(t *testing.T) { "54.23.332.4", }, }, - want: "", + want: "", + wantOk: false, }, { name: "no address", ipAddr: "67.225.146.248", dnsEvent: tracerdnstype.Event{ DNSName: "test.com", - NumAnswers: 0, // will not resolve + NumAnswers: 0, }, - want: "", + want: "", + wantOk: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - dm := &DNSManager{} + // Create a properly initialized DNSManager + dm := &DNSManager{ + addressToDomainMap: maps.SafeMap[string, string]{}, + lookupCache: &sync.Map{}, + failureCache: &sync.Map{}, + cleanupTicker: time.NewTicker(cleanupInterval), + } + dm.ReportDNSEvent(tt.dnsEvent) - got, _ := dm.ResolveIPAddress(tt.ipAddr) - if got != tt.want { - t.Errorf("ResolveIPAddress() got = %v, want %v", got, tt.want) + got, ok := dm.ResolveIPAddress(tt.ipAddr) + if got != tt.want || ok != tt.wantOk { + t.Errorf("ResolveIPAddress() got = %v, ok = %v, want = %v, wantOk = %v", got, ok, tt.want, tt.wantOk) } + + // Cleanup + dm.cleanupTicker.Stop() }) } } func TestResolveIPAddressFallback(t *testing.T) { + // Skip the test if running in CI or without network access + if testing.Short() { + t.Skip("Skipping test that requires network access") + } + tests := []struct { name string dnsEvent tracerdnstype.Event want string + wantOk bool }{ - { name: "dns resolution fallback", dnsEvent: tracerdnstype.Event{ - DNSName: "test.com", + DNSName: "example.com", // Using example.com as it's guaranteed to exist NumAnswers: 1, }, - want: "test.com", + want: "example.com", + wantOk: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - addresses, _ := net.LookupIP(tt.dnsEvent.DNSName) - dm := &DNSManager{} + // Create a properly initialized DNSManager + dm := &DNSManager{ + addressToDomainMap: maps.SafeMap[string, string]{}, + lookupCache: &sync.Map{}, + failureCache: &sync.Map{}, + cleanupTicker: time.NewTicker(cleanupInterval), + } + + // Perform the actual DNS lookup + addresses, err := net.LookupIP(tt.dnsEvent.DNSName) + if err != nil { + t.Skipf("DNS lookup failed: %v", err) + return + } + if len(addresses) == 0 { + t.Skip("No addresses returned from DNS lookup") + return + } + dm.ReportDNSEvent(tt.dnsEvent) - got, _ := dm.ResolveIPAddress(addresses[0].String()) - if got != tt.want { - t.Errorf("ResolveIPAddress() got = %v, want %v", got, tt.want) + got, ok := dm.ResolveIPAddress(addresses[0].String()) + if got != tt.want || ok != tt.wantOk { + t.Errorf("ResolveIPAddress() got = %v, ok = %v, want = %v, wantOk = %v", got, ok, tt.want, tt.wantOk) } + + // Cleanup + dm.cleanupTicker.Stop() }) } }