diff --git a/.github/workflows/component-tests.yaml b/.github/workflows/component-tests.yaml index c9149da1..2c741b92 100644 --- a/.github/workflows/component-tests.yaml +++ b/.github/workflows/component-tests.yaml @@ -43,15 +43,15 @@ jobs: Test_01_BasicAlertTest, Test_02_AllAlertsFromMaliciousApp, Test_03_BasicLoadActivities, - Test_04_MemoryLeak, + # Test_04_MemoryLeak, Test_05_MemoryLeak_10K_Alerts, Test_06_KillProcessInTheMiddle, Test_07_RuleBindingApplyTest, Test_08_ApplicationProfilePatching, Test_10_MalwareDetectionTest, Test_11_EndpointTest, - # Test_10_DemoTest - # Test_11_DuplicationTest + Test_12_MergingProfilesTest, + Test_13_MergingNetworkNeighborhoodTest, ] steps: - name: Checkout code diff --git a/clamav/init.sh b/clamav/init.sh index 3559497a..960645ef 100755 --- a/clamav/init.sh +++ b/clamav/init.sh @@ -63,7 +63,8 @@ else if [ -S "/tmp/clamd.sock" ]; then unlink "/tmp/clamd.sock" fi - clamd --foreground & + # Run clamd in the foreground but redirecting output to stdout and stderr to /dev/null + clamd --foreground > /dev/null 2>&1 & while [ ! -S "/run/clamav/clamd.sock" ] && [ ! -S "/tmp/clamd.sock" ]; do if [ "${_timeout:=0}" -gt "${CLAMD_STARTUP_TIMEOUT:=1800}" ]; then echo diff --git a/go.mod b/go.mod index c241777f..6d714b5c 100644 --- a/go.mod +++ b/go.mod @@ -262,3 +262,5 @@ require ( replace github.com/inspektor-gadget/inspektor-gadget => /home/afek/Projects/Armo/poc/inspektor-gadget replace github.com/vishvananda/netns => github.com/inspektor-gadget/netns v0.0.5-0.20230524185006-155d84c555d6 + +replace github.com/goradd/maps => github.com/matthyx/maps v0.0.0-20241029072232-2f5d83d608a7 diff --git a/go.sum b/go.sum index a351bb89..2a08498e 100644 --- a/go.sum +++ b/go.sum @@ -448,8 +448,6 @@ github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/Q github.com/gopacket/gopacket v1.2.0 h1:eXbzFad7f73P1n2EJHQlsKuvIMJjVXK5tXoSca78I3A= github.com/gopacket/gopacket v1.2.0/go.mod h1:BrAKEy5EOGQ76LSqh7DMAr7z0NNPdczWm2GxCG7+I8M= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/goradd/maps v0.1.5 h1:Ut7BPJgNy5BYbleI3LswVJJquiM8X5uN0ZuZBHSdRUI= -github.com/goradd/maps v0.1.5/go.mod h1:E5X1CHMgfVm1qFTHgXpgVLVylO5wtlhZdB93dRGjnc0= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= @@ -566,6 +564,8 @@ github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3v github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/matthyx/maps v0.0.0-20241029072232-2f5d83d608a7 h1:LAAFb3ra/vxiZcDY1zrbS29oqnB+N9MknuQZC1ju2+A= +github.com/matthyx/maps v0.0.0-20241029072232-2f5d83d608a7/go.mod h1:E5X1CHMgfVm1qFTHgXpgVLVylO5wtlhZdB93dRGjnc0= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= diff --git a/main.go b/main.go index a26edb0a..e7863961 100644 --- a/main.go +++ 
b/main.go @@ -32,6 +32,8 @@ import ( "github.com/kubescape/node-agent/pkg/objectcache/k8scache" "github.com/kubescape/node-agent/pkg/objectcache/networkneighborhoodcache" objectcachev1 "github.com/kubescape/node-agent/pkg/objectcache/v1" + "github.com/kubescape/node-agent/pkg/processmanager" + processmanagerv1 "github.com/kubescape/node-agent/pkg/processmanager/v1" "github.com/kubescape/node-agent/pkg/relevancymanager" relevancymanagerv1 "github.com/kubescape/node-agent/pkg/relevancymanager/v1" rulebinding "github.com/kubescape/node-agent/pkg/rulebindingmanager" @@ -193,26 +195,27 @@ func main() { var networkManagerClient networkmanager.NetworkManagerClient var dnsManagerClient dnsmanager.DNSManagerClient var dnsResolver dnsmanager.DNSResolver - if cfg.EnableNetworkTracing { + if cfg.EnableNetworkTracing || cfg.EnableRuntimeDetection { dnsManager := dnsmanager.CreateDNSManager() dnsManagerClient = dnsManager // NOTE: dnsResolver is set for threat detection. dnsResolver = dnsManager networkManagerClient = networkmanagerv2.CreateNetworkManager(ctx, cfg, clusterData.ClusterName, k8sClient, storageClient, dnsManager, preRunningContainersIDs, k8sObjectCache) } else { - if cfg.EnableRuntimeDetection { - logger.L().Ctx(ctx).Fatal("Network tracing is disabled, but runtime detection is enabled. Network tracing is required for runtime detection.") - } dnsManagerClient = dnsmanager.CreateDNSManagerMock() dnsResolver = dnsmanager.CreateDNSManagerMock() networkManagerClient = networkmanager.CreateNetworkManagerMock() } var ruleManager rulemanager.RuleManagerClient + var processManager processmanager.ProcessManagerClient var objCache objectcache.ObjectCache var ruleBindingNotify chan rulebinding.RuleBindingNotify if cfg.EnableRuntimeDetection { + // create the process manager + processManager = processmanagerv1.CreateProcessManager(ctx) + // create ruleBinding cache ruleBindingCache := rulebindingcachev1.NewCache(nodeName, k8sClient) dWatcher.AddAdaptor(ruleBindingCache) @@ -235,7 +238,7 @@ func main() { exporter := exporters.InitExporters(cfg.Exporters, clusterData.ClusterName, nodeName) // create runtimeDetection managers - ruleManager, err = rulemanagerv1.CreateRuleManager(ctx, cfg, k8sClient, ruleBindingCache, objCache, exporter, prometheusExporter, nodeName, clusterData.ClusterName) + ruleManager, err = rulemanagerv1.CreateRuleManager(ctx, cfg, k8sClient, ruleBindingCache, objCache, exporter, prometheusExporter, nodeName, clusterData.ClusterName, processManager, nil) if err != nil { logger.L().Ctx(ctx).Fatal("error creating RuleManager", helpers.Error(err)) } @@ -244,6 +247,7 @@ func main() { ruleManager = rulemanager.CreateRuleManagerMock() objCache = objectcache.NewObjectCacheMock() ruleBindingNotify = make(chan rulebinding.RuleBindingNotify, 1) + processManager = processmanager.CreateProcessManagerMock() } // Create the node profile manager @@ -269,7 +273,7 @@ func main() { } // Create the container handler - mainHandler, err := containerwatcher.CreateIGContainerWatcher(cfg, applicationProfileManager, k8sClient, relevancyManager, networkManagerClient, dnsManagerClient, prometheusExporter, ruleManager, malwareManager, preRunningContainersIDs, &ruleBindingNotify, containerRuntime, nil, nil) + mainHandler, err := containerwatcher.CreateIGContainerWatcher(cfg, applicationProfileManager, k8sClient, relevancyManager, networkManagerClient, dnsManagerClient, prometheusExporter, ruleManager, malwareManager, preRunningContainersIDs, &ruleBindingNotify, containerRuntime, nil, nil, processManager) if err != 
nil { logger.L().Ctx(ctx).Fatal("error creating the container watcher", helpers.Error(err)) } diff --git a/pkg/applicationprofilemanager/v1/applicationprofile_manager.go b/pkg/applicationprofilemanager/v1/applicationprofile_manager.go index f18dddce..5e3883d6 100644 --- a/pkg/applicationprofilemanager/v1/applicationprofile_manager.go +++ b/pkg/applicationprofilemanager/v1/applicationprofile_manager.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "regexp" "runtime" + "strings" "time" "github.com/armosec/utils-k8s-go/wlid" @@ -26,6 +28,7 @@ import ( "github.com/kubescape/node-agent/pkg/storage" "github.com/kubescape/node-agent/pkg/utils" "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" + "github.com/kubescape/storage/pkg/registry/file/dynamicpathdetector" storageUtils "github.com/kubescape/storage/pkg/utils" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -34,6 +37,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const OpenDynamicThreshold = 50 + +var procRegex = regexp.MustCompile(`^/proc/\d+`) + type ApplicationProfileManager struct { cfg config.Config clusterName string @@ -538,7 +545,34 @@ func (am *ApplicationProfileManager) saveProfile(ctx context.Context, watchedCon return true }) // record saved opens - toSaveOpens.Range(utils.SetInMap(am.savedOpens.Get(watchedContainer.K8sContainerID))) + savedOpens := am.savedOpens.Get(watchedContainer.K8sContainerID) + toSaveOpens.Range(utils.SetInMap(savedOpens)) + // use a dynamic path detector to compress opens + analyzer := dynamicpathdetector.NewPathAnalyzer(OpenDynamicThreshold) + keys := savedOpens.Keys() + // first pass to learn the opens + for _, path := range keys { + _, _ = dynamicpathdetector.AnalyzeOpen(path, analyzer) + } + // second pass to compress the opens + for _, path := range keys { + result, err := dynamicpathdetector.AnalyzeOpen(path, analyzer) + if err != nil { + continue + } + if result != path { + // path becomes compressed + // we avoid a lock by using Pop to remove path and retrieve its flags + pathFlags := savedOpens.Pop(path) + if savedOpens.Has(result) { + // merge flags + savedOpens.Get(result).Append(pathFlags.ToSlice()...) + } else { + // create new entry + savedOpens.Set(result, pathFlags) + } + } + } logger.L().Debug("ApplicationProfileManager - saved application profile", helpers.Int("capabilities", len(capabilities)), helpers.Int("endpoints", toSaveEndpoints.Len()), @@ -675,6 +709,11 @@ func (am *ApplicationProfileManager) ReportFileOpen(k8sContainerID, path string, if err := am.waitForContainer(k8sContainerID); err != nil { return } + // deduplicate /proc/1234/* into /proc/.../* (quite a common case) + // we perform it here instead of waiting for compression + if strings.HasPrefix(path, "/proc/") { + path = procRegex.ReplaceAllString(path, "/proc/"+dynamicpathdetector.DynamicIdentifier) + } // check if we already have this open savedOpens := am.savedOpens.Get(k8sContainerID) if savedOpens.Has(path) && savedOpens.Get(path).Contains(flags...) 
{ diff --git a/pkg/applicationprofilemanager/v1/applicationprofile_manager_test.go b/pkg/applicationprofilemanager/v1/applicationprofile_manager_test.go index 6e04c332..1bf56cd9 100644 --- a/pkg/applicationprofilemanager/v1/applicationprofile_manager_test.go +++ b/pkg/applicationprofilemanager/v1/applicationprofile_manager_test.go @@ -6,10 +6,12 @@ import ( "net/http" "net/url" "sort" + "strings" "testing" "time" mapset "github.com/deckarep/golang-set/v2" + "github.com/goradd/maps" containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" "github.com/inspektor-gadget/inspektor-gadget/pkg/types" "github.com/kubescape/node-agent/pkg/config" @@ -19,6 +21,7 @@ import ( "github.com/kubescape/node-agent/pkg/seccompmanager" "github.com/kubescape/node-agent/pkg/storage" "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" + "github.com/kubescape/storage/pkg/registry/file/dynamicpathdetector" "github.com/stretchr/testify/assert" ) @@ -274,3 +277,21 @@ func sortHTTPEndpoints(endpoints []v1beta1.HTTPEndpoint) { return string(endpoints[i].Headers) < string(endpoints[j].Headers) }) } + +func BenchmarkReportFileOpen(b *testing.B) { + savedOpens := maps.SafeMap[string, mapset.Set[string]]{} + savedOpens.Set("/proc/"+dynamicpathdetector.DynamicIdentifier+"/foo/bar", mapset.NewSet("O_LARGEFILE", "O_RDONLY")) + paths := []string{"/proc/12345/foo/bar", "/bin/ls", "/etc/passwd"} + flags := []string{"O_CLOEXEC", "O_RDONLY"} + for i := 0; i < b.N; i++ { + for _, path := range paths { + if strings.HasPrefix(path, "/proc/") { + path = procRegex.ReplaceAllString(path, "/proc/"+dynamicpathdetector.DynamicIdentifier) + } + if savedOpens.Has(path) && savedOpens.Get(path).Contains(flags...) { + continue + } + } + } + b.ReportAllocs() +} diff --git a/pkg/containerwatcher/v1/container_watcher.go b/pkg/containerwatcher/v1/container_watcher.go index d59370b8..8d6d30b7 100644 --- a/pkg/containerwatcher/v1/container_watcher.go +++ b/pkg/containerwatcher/v1/container_watcher.go @@ -41,6 +41,8 @@ import ( tracersshtype "github.com/kubescape/node-agent/pkg/ebpf/gadgets/ssh/types" tracersymlink "github.com/kubescape/node-agent/pkg/ebpf/gadgets/symlink/tracer" tracersymlinktype "github.com/kubescape/node-agent/pkg/ebpf/gadgets/symlink/types" + "github.com/kubescape/node-agent/pkg/processmanager" + "github.com/kubescape/node-agent/pkg/malwaremanager" "github.com/kubescape/node-agent/pkg/metricsmanager" "github.com/kubescape/node-agent/pkg/networkmanager" @@ -153,11 +155,13 @@ type IGContainerWatcher struct { ruleBindingPodNotify *chan rulebinding.RuleBindingNotify // container runtime runtime *containerutilsTypes.RuntimeConfig + // process manager + processManager processmanager.ProcessManagerClient } var _ containerwatcher.ContainerWatcher = (*IGContainerWatcher)(nil) -func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager applicationprofilemanager.ApplicationProfileManagerClient, k8sClient *k8sinterface.KubernetesApi, relevancyManager relevancymanager.RelevancyManagerClient, networkManagerClient networkmanager.NetworkManagerClient, dnsManagerClient dnsmanager.DNSManagerClient, metrics metricsmanager.MetricsManager, ruleManager rulemanager.RuleManagerClient, malwareManager malwaremanager.MalwareManagerClient, preRunningContainers mapset.Set[string], ruleBindingPodNotify *chan rulebinding.RuleBindingNotify, runtime *containerutilsTypes.RuntimeConfig, thirdPartyEventReceivers *maps.SafeMap[utils.EventType, mapset.Set[containerwatcher.EventReceiver]], 
thirdPartyEnricher containerwatcher.ThirdPartyEnricher) (*IGContainerWatcher, error) { +func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager applicationprofilemanager.ApplicationProfileManagerClient, k8sClient *k8sinterface.KubernetesApi, relevancyManager relevancymanager.RelevancyManagerClient, networkManagerClient networkmanager.NetworkManagerClient, dnsManagerClient dnsmanager.DNSManagerClient, metrics metricsmanager.MetricsManager, ruleManager rulemanager.RuleManagerClient, malwareManager malwaremanager.MalwareManagerClient, preRunningContainers mapset.Set[string], ruleBindingPodNotify *chan rulebinding.RuleBindingNotify, runtime *containerutilsTypes.RuntimeConfig, thirdPartyEventReceivers *maps.SafeMap[utils.EventType, mapset.Set[containerwatcher.EventReceiver]], thirdPartyEnricher containerwatcher.ThirdPartyEnricher, processManager processmanager.ProcessManagerClient) (*IGContainerWatcher, error) { // Use container collection to get notified for new containers containerCollection := &containercollection.ContainerCollection{} // Create a tracer collection instance @@ -207,6 +211,7 @@ func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager appli ruleManager.ReportEvent(utils.ExecveEventType, &event) malwareManager.ReportEvent(utils.ExecveEventType, &event) metrics.ReportEvent(utils.ExecveEventType) + processManager.ReportEvent(utils.ExecveEventType, &event) applicationProfileManager.ReportFileExec(k8sContainerID, path, event.Args) relevancyManager.ReportFileExec(event.Runtime.ContainerID, k8sContainerID, path) @@ -453,6 +458,7 @@ func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager appli thirdPartyTracers: mapset.NewSet[containerwatcher.CustomTracer](), thirdPartyContainerReceivers: mapset.NewSet[containerwatcher.ContainerReceiver](), thirdPartyEnricher: thirdPartyEnricher, + processManager: processManager, }, nil } @@ -498,11 +504,16 @@ func (ch *IGContainerWatcher) UnregisterContainerReceiver(receiver containerwatc func (ch *IGContainerWatcher) Start(ctx context.Context) error { if !ch.running { - if err := ch.startContainerCollection(ctx); err != nil { return fmt.Errorf("setting up container collection: %w", err) } + // We want to populate the initial processes before starting the tracers but after retrieving the shims. 
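+		// If population fails, tear down the container collection again so the tracers never start without process context.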
+ if err := ch.processManager.PopulateInitialProcesses(); err != nil { + ch.stopContainerCollection() + return fmt.Errorf("populating initial processes: %w", err) + } + if err := ch.startTracers(); err != nil { ch.stopContainerCollection() return fmt.Errorf("starting app behavior tracing: %w", err) diff --git a/pkg/containerwatcher/v1/container_watcher_private.go b/pkg/containerwatcher/v1/container_watcher_private.go index 5e89bf4e..70a17478 100644 --- a/pkg/containerwatcher/v1/container_watcher_private.go +++ b/pkg/containerwatcher/v1/container_watcher_private.go @@ -34,15 +34,14 @@ func (ch *IGContainerWatcher) containerCallback(notif containercollection.PubSub k8sContainerID := utils.CreateK8sContainerID(notif.Container.K8s.Namespace, notif.Container.K8s.PodName, notif.Container.K8s.ContainerName) - if !ch.preRunningContainersIDs.Contains(notif.Container.Runtime.ContainerID) { - // container is not in preRunningContainersIDs, it is a new container - ch.timeBasedContainers.Add(notif.Container.Runtime.ContainerID) - } - switch notif.Type { case containercollection.EventTypeAddContainer: logger.L().Info("start monitor on container", helpers.String("container ID", notif.Container.Runtime.ContainerID), helpers.String("k8s workload", k8sContainerID)) - + if ch.running { + ch.timeBasedContainers.Add(notif.Container.Runtime.ContainerID) + } else { + ch.preRunningContainersIDs.Add(notif.Container.Runtime.ContainerID) + } // Check if Pod has a label of max sniffing time sniffingTime := utils.AddJitter(ch.cfg.MaxSniffingTime, ch.cfg.MaxJitterPercentage) if podLabelMaxSniffingTime, ok := notif.Container.K8s.PodLabels[MaxSniffingTimeLabel]; ok { @@ -87,6 +86,7 @@ func (ch *IGContainerWatcher) startContainerCollection(ctx context.Context) erro ch.networkManager.ContainerCallback, ch.malwareManager.ContainerCallback, ch.ruleManager.ContainerCallback, + ch.processManager.ContainerCallback, } for receiver := range ch.thirdPartyContainerReceivers.Iter() { @@ -130,7 +130,7 @@ func (ch *IGContainerWatcher) startContainerCollection(ctx context.Context) erro return nil } -func (ch *IGContainerWatcher) startRunningContainers() error { +func (ch *IGContainerWatcher) startRunningContainers() { k8sClient, err := containercollection.NewK8sClient(ch.nodeName) if err != nil { logger.L().Fatal("creating IG Kubernetes client", helpers.Error(err)) @@ -139,7 +139,6 @@ func (ch *IGContainerWatcher) startRunningContainers() error { for n := range *ch.ruleBindingPodNotify { ch.addRunningContainers(k8sClient, &n) } - return nil } func (ch *IGContainerWatcher) addRunningContainers(k8sClient IGK8sClient, notf *rulebindingmanager.RuleBindingNotify) { diff --git a/pkg/objectcache/applicationprofilecache/applicationprofilecache.go b/pkg/objectcache/applicationprofilecache/applicationprofilecache.go index 143dadb2..eae254ef 100644 --- a/pkg/objectcache/applicationprofilecache/applicationprofilecache.go +++ b/pkg/objectcache/applicationprofilecache/applicationprofilecache.go @@ -3,6 +3,7 @@ package applicationprofilecache import ( "context" "fmt" + "strings" "time" mapset "github.com/deckarep/golang-set/v2" @@ -47,30 +48,102 @@ func newApplicationProfileState(ap *v1beta1.ApplicationProfile) applicationProfi } type ApplicationProfileCacheImpl struct { - containerToSlug maps.SafeMap[string, string] // cache the containerID to slug mapping, this will enable a quick lookup of the application profile - slugToAppProfile maps.SafeMap[string, *v1beta1.ApplicationProfile] // cache the application profile - slugToContainers 
maps.SafeMap[string, mapset.Set[string]] // cache the containerIDs that belong to the application profile, this will enable removing from cache AP without pods - slugToState maps.SafeMap[string, applicationProfileState] // cache the containerID to slug mapping, this will enable a quick lookup of the application profile - storageClient versioned.SpdxV1beta1Interface - allProfiles mapset.Set[string] // cache all the application profiles that are ready. this will enable removing from cache AP without pods that are running on the same node - nodeName string - maxDelaySeconds int // maximum delay in seconds before getting the full object from the storage + containerToSlug maps.SafeMap[string, string] // cache the containerID to slug mapping, this will enable a quick lookup of the application profile + slugToAppProfile maps.SafeMap[string, *v1beta1.ApplicationProfile] // cache the application profile + slugToContainers maps.SafeMap[string, mapset.Set[string]] // cache the containerIDs that belong to the application profile, this will enable removing from cache AP without pods + slugToState maps.SafeMap[string, applicationProfileState] // cache the containerID to slug mapping, this will enable a quick lookup of the application profile + storageClient versioned.SpdxV1beta1Interface + allProfiles mapset.Set[string] // cache all the application profiles that are ready. this will enable removing from cache AP without pods that are running on the same node + nodeName string + maxDelaySeconds int // maximum delay in seconds before getting the full object from the storage + userManagedProfiles maps.SafeMap[string, *v1beta1.ApplicationProfile] } func NewApplicationProfileCache(nodeName string, storageClient versioned.SpdxV1beta1Interface, maxDelaySeconds int) *ApplicationProfileCacheImpl { return &ApplicationProfileCacheImpl{ - nodeName: nodeName, - maxDelaySeconds: maxDelaySeconds, - storageClient: storageClient, - containerToSlug: maps.SafeMap[string, string]{}, - slugToContainers: maps.SafeMap[string, mapset.Set[string]]{}, - allProfiles: mapset.NewSet[string](), + nodeName: nodeName, + maxDelaySeconds: maxDelaySeconds, + storageClient: storageClient, + containerToSlug: maps.SafeMap[string, string]{}, + slugToAppProfile: maps.SafeMap[string, *v1beta1.ApplicationProfile]{}, + slugToContainers: maps.SafeMap[string, mapset.Set[string]]{}, + slugToState: maps.SafeMap[string, applicationProfileState]{}, + allProfiles: mapset.NewSet[string](), + userManagedProfiles: maps.SafeMap[string, *v1beta1.ApplicationProfile]{}, } - } // ------------------ objectcache.ApplicationProfileCache methods ----------------------- +func (ap *ApplicationProfileCacheImpl) handleUserManagedProfile(appProfile *v1beta1.ApplicationProfile) { + baseProfileName := strings.TrimPrefix(appProfile.GetName(), "ug-") + baseProfileUniqueName := objectcache.UniqueName(appProfile.GetNamespace(), baseProfileName) + + // Get the full user managed profile from the storage + userManagedProfile, err := ap.getApplicationProfile(appProfile.GetNamespace(), appProfile.GetName()) + if err != nil { + logger.L().Error("failed to get full application profile", helpers.Error(err)) + return + } + + // Store the user-managed profile temporarily + ap.userManagedProfiles.Set(baseProfileUniqueName, userManagedProfile) + + // If we have the base profile cached, fetch a fresh copy and merge. + // If the base profile is not cached yet, the merge will be attempted when it's added. 
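+	// (addFullApplicationProfile checks userManagedProfiles and performs the pending merge at that point.)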
+ if ap.slugToAppProfile.Has(baseProfileUniqueName) { + // Fetch fresh base profile from cluster + freshBaseProfile, err := ap.getApplicationProfile(appProfile.GetNamespace(), baseProfileName) + if err != nil { + logger.L().Error("failed to get fresh base profile for merging", + helpers.String("name", baseProfileName), + helpers.String("namespace", appProfile.GetNamespace()), + helpers.Error(err)) + return + } + + mergedProfile := ap.performMerge(freshBaseProfile, userManagedProfile) + ap.slugToAppProfile.Set(baseProfileUniqueName, mergedProfile) + + // Clean up the user-managed profile after successful merge + ap.userManagedProfiles.Delete(baseProfileUniqueName) + + logger.L().Debug("merged user-managed profile with fresh base profile", + helpers.String("name", baseProfileName), + helpers.String("namespace", appProfile.GetNamespace())) + } +} + +func (ap *ApplicationProfileCacheImpl) addApplicationProfile(obj runtime.Object) { + appProfile := obj.(*v1beta1.ApplicationProfile) + apName := objectcache.MetaUniqueName(appProfile) + + if isUserManagedProfile(appProfile) { + ap.handleUserManagedProfile(appProfile) + return + } + + // Original behavior for normal profiles + apState := newApplicationProfileState(appProfile) + ap.slugToState.Set(apName, apState) + + if apState.status != helpersv1.Completed { + if ap.slugToAppProfile.Has(apName) { + ap.slugToAppProfile.Delete(apName) + ap.allProfiles.Remove(apName) + } + return + } + + ap.allProfiles.Add(apName) + + if ap.slugToContainers.Has(apName) { + time.AfterFunc(utils.RandomDuration(ap.maxDelaySeconds, time.Second), func() { + ap.addFullApplicationProfile(appProfile, apName) + }) + } +} + func (ap *ApplicationProfileCacheImpl) GetApplicationProfile(containerID string) *v1beta1.ApplicationProfile { if s := ap.containerToSlug.Get(containerID); s != "" { return ap.slugToAppProfile.Get(s) @@ -110,7 +183,7 @@ func (ap *ApplicationProfileCacheImpl) AddHandler(ctx context.Context, obj runti if pod, ok := obj.(*corev1.Pod); ok { ap.addPod(pod) } else if appProfile, ok := obj.(*v1beta1.ApplicationProfile); ok { - ap.addApplicationProfile(ctx, appProfile) + ap.addApplicationProfile(appProfile) } } @@ -118,7 +191,7 @@ func (ap *ApplicationProfileCacheImpl) ModifyHandler(ctx context.Context, obj ru if pod, ok := obj.(*corev1.Pod); ok { ap.addPod(pod) } else if appProfile, ok := obj.(*v1beta1.ApplicationProfile); ok { - ap.addApplicationProfile(ctx, appProfile) + ap.addApplicationProfile(appProfile) } } @@ -213,36 +286,6 @@ func (ap *ApplicationProfileCacheImpl) removeContainer(containerID string) { } // ------------------ watch application profile methods ----------------------- -func (ap *ApplicationProfileCacheImpl) addApplicationProfile(_ context.Context, obj runtime.Object) { - appProfile := obj.(*v1beta1.ApplicationProfile) - apName := objectcache.MetaUniqueName(appProfile) - - apState := newApplicationProfileState(appProfile) - ap.slugToState.Set(apName, apState) - - // the cache holds only completed application profiles. - // check if the application profile is completed - // if status was completed and now is not (e.g. 
mode changed from complete to partial), remove from cache - if apState.status != helpersv1.Completed { - if ap.slugToAppProfile.Has(apName) { - ap.slugToAppProfile.Delete(apName) - ap.allProfiles.Remove(apName) - } - return - } - - // add to the cache - ap.allProfiles.Add(apName) - - if ap.slugToContainers.Has(apName) { - // get the full application profile from the storage - // the watch only returns the metadata - // avoid thundering herd problem by adding a random delay - time.AfterFunc(utils.RandomDuration(ap.maxDelaySeconds, time.Second), func() { - ap.addFullApplicationProfile(appProfile, apName) - }) - } -} func (ap *ApplicationProfileCacheImpl) addFullApplicationProfile(appProfile *v1beta1.ApplicationProfile, apName string) { fullAP, err := ap.getApplicationProfile(appProfile.GetNamespace(), appProfile.GetName()) @@ -250,6 +293,16 @@ func (ap *ApplicationProfileCacheImpl) addFullApplicationProfile(appProfile *v1b logger.L().Error("failed to get full application profile", helpers.Error(err)) return } + + // Check if there's a pending user-managed profile to merge + if ap.userManagedProfiles.Has(apName) { + userManagedProfile := ap.userManagedProfiles.Get(apName) + fullAP = ap.performMerge(fullAP, userManagedProfile) + // Clean up the user-managed profile after successful merge + ap.userManagedProfiles.Delete(apName) + logger.L().Debug("merged pending user-managed profile", helpers.String("name", apName)) + } + ap.slugToAppProfile.Set(apName, fullAP) for _, i := range ap.slugToContainers.Get(apName).ToSlice() { ap.containerToSlug.Set(i, apName) @@ -257,13 +310,74 @@ func (ap *ApplicationProfileCacheImpl) addFullApplicationProfile(appProfile *v1b logger.L().Debug("added pod to application profile cache", helpers.String("name", apName)) } +func (ap *ApplicationProfileCacheImpl) performMerge(normalProfile, userManagedProfile *v1beta1.ApplicationProfile) *v1beta1.ApplicationProfile { + mergedProfile := normalProfile.DeepCopy() + + // Merge spec + mergedProfile.Spec.Containers = ap.mergeContainers(mergedProfile.Spec.Containers, userManagedProfile.Spec.Containers) + mergedProfile.Spec.InitContainers = ap.mergeContainers(mergedProfile.Spec.InitContainers, userManagedProfile.Spec.InitContainers) + mergedProfile.Spec.EphemeralContainers = ap.mergeContainers(mergedProfile.Spec.EphemeralContainers, userManagedProfile.Spec.EphemeralContainers) + + return mergedProfile +} + +func (ap *ApplicationProfileCacheImpl) mergeContainers(normalContainers, userManagedContainers []v1beta1.ApplicationProfileContainer) []v1beta1.ApplicationProfileContainer { + if len(userManagedContainers) != len(normalContainers) { + // If the number of containers don't match, we can't merge + logger.L().Error("failed to merge user-managed profile with base profile", + helpers.Int("normalContainers len", len(normalContainers)), + helpers.Int("userManagedContainers len", len(userManagedContainers)), + helpers.String("reason", "number of containers don't match")) + return normalContainers + } + + // Assuming the normalContainers are already in the correct Pod order + // We'll merge user containers at their corresponding positions + for i := range normalContainers { + for _, userContainer := range userManagedContainers { + if normalContainers[i].Name == userContainer.Name { + ap.mergeContainer(&normalContainers[i], &userContainer) + break + } + } + } + return normalContainers +} + +func (ap *ApplicationProfileCacheImpl) mergeContainer(normalContainer, userContainer *v1beta1.ApplicationProfileContainer) { + 
normalContainer.Capabilities = append(normalContainer.Capabilities, userContainer.Capabilities...) + normalContainer.Execs = append(normalContainer.Execs, userContainer.Execs...) + normalContainer.Opens = append(normalContainer.Opens, userContainer.Opens...) + normalContainer.Syscalls = append(normalContainer.Syscalls, userContainer.Syscalls...) + normalContainer.Endpoints = append(normalContainer.Endpoints, userContainer.Endpoints...) +} + func (ap *ApplicationProfileCacheImpl) deleteApplicationProfile(obj runtime.Object) { - apName := objectcache.MetaUniqueName(obj.(metav1.Object)) - ap.slugToAppProfile.Delete(apName) - ap.slugToState.Delete(apName) - ap.allProfiles.Remove(apName) + appProfile := obj.(*v1beta1.ApplicationProfile) + apName := objectcache.MetaUniqueName(appProfile) - logger.L().Info("deleted application profile from cache", helpers.String("uniqueSlug", apName)) + if isUserManagedProfile(appProfile) { + // For user-managed profiles, we need to use the base name for cleanup + baseProfileName := strings.TrimPrefix(appProfile.GetName(), "ug-") + baseProfileUniqueName := objectcache.UniqueName(appProfile.GetNamespace(), baseProfileName) + ap.userManagedProfiles.Delete(baseProfileUniqueName) + + logger.L().Debug("deleted user-managed profile from cache", + helpers.String("profileName", appProfile.GetName()), + helpers.String("baseProfile", baseProfileName)) + } else { + // For normal profiles, clean up all related data + ap.slugToAppProfile.Delete(apName) + ap.slugToState.Delete(apName) + ap.allProfiles.Remove(apName) + + // Log the deletion of normal profile + logger.L().Debug("deleted application profile from cache", + helpers.String("uniqueSlug", apName)) + } + + // Clean up any orphaned user-managed profiles + ap.cleanupOrphanedUserManagedProfiles() } func (ap *ApplicationProfileCacheImpl) getApplicationProfile(namespace, name string) (*v1beta1.ApplicationProfile, error) { @@ -301,3 +415,27 @@ func getSlug(p *corev1.Pod) (string, error) { } return slug, nil } + +// Add cleanup method for any orphaned user-managed profiles +func (ap *ApplicationProfileCacheImpl) cleanupOrphanedUserManagedProfiles() { + // This could be called periodically or during certain operations + ap.userManagedProfiles.Range(func(key string, value *v1beta1.ApplicationProfile) bool { + if ap.slugToAppProfile.Has(key) { + // If base profile exists but merge didn't happen for some reason, + // attempt merge again and cleanup + if baseProfile := ap.slugToAppProfile.Get(key); baseProfile != nil { + mergedProfile := ap.performMerge(baseProfile, value) + ap.slugToAppProfile.Set(key, mergedProfile) + ap.userManagedProfiles.Delete(key) + logger.L().Debug("cleaned up orphaned user-managed profile", helpers.String("name", key)) + } + } + return true + }) +} + +func isUserManagedProfile(appProfile *v1beta1.ApplicationProfile) bool { + return appProfile.Annotations != nil && + appProfile.Annotations["kubescape.io/managed-by"] == "User" && + strings.HasPrefix(appProfile.GetName(), "ug-") +} diff --git a/pkg/objectcache/applicationprofilecache/applicationprofilecache_test.go b/pkg/objectcache/applicationprofilecache/applicationprofilecache_test.go index eab2c8f4..3c1797e5 100644 --- a/pkg/objectcache/applicationprofilecache/applicationprofilecache_test.go +++ b/pkg/objectcache/applicationprofilecache/applicationprofilecache_test.go @@ -186,10 +186,10 @@ func Test_addApplicationProfile(t *testing.T) { ap.addPod(tt.preCreatedPods[i]) } for i := range tt.preCreatedAP { - ap.addApplicationProfile(context.Background(), 
tt.preCreatedAP[i]) + ap.addApplicationProfile(tt.preCreatedAP[i]) } - ap.addApplicationProfile(context.Background(), tt.obj) + ap.addApplicationProfile(tt.obj) time.Sleep(1 * time.Second) // add is async // test if the application profile is added to the cache @@ -513,9 +513,9 @@ func Test_addApplicationProfile_existing(t *testing.T) { ap.slugToContainers.Set(tt.pods[i].slug, mapset.NewSet(tt.pods[i].podName)) } - ap.addApplicationProfile(context.Background(), tt.obj1) + ap.addApplicationProfile(tt.obj1) time.Sleep(1 * time.Second) // add is async - ap.addApplicationProfile(context.Background(), tt.obj2) + ap.addApplicationProfile(tt.obj2) // test if the application profile is added to the cache if tt.storeInCache { @@ -706,7 +706,7 @@ func Test_addPod(t *testing.T) { ap := NewApplicationProfileCache("", storageClient, 0) - ap.addApplicationProfile(context.Background(), tt.preCreatedAP) + ap.addApplicationProfile(tt.preCreatedAP) time.Sleep(1 * time.Second) // add is async tt.obj.(metav1.Object).SetNamespace(namespace) @@ -734,3 +734,202 @@ func Test_addPod(t *testing.T) { }) } } + +func Test_MergeApplicationProfiles(t *testing.T) { + tests := []struct { + name string + normalProfile *v1beta1.ApplicationProfile + userProfile *v1beta1.ApplicationProfile + expectedMerged *v1beta1.ApplicationProfile + }{ + { + name: "merge profiles with overlapping containers", + normalProfile: &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-profile", + Namespace: "default", + }, + Spec: v1beta1.ApplicationProfileSpec{ + Containers: []v1beta1.ApplicationProfileContainer{ + { + Name: "container1", + Capabilities: []string{ + "NET_ADMIN", + }, + Syscalls: []string{ + "open", + }, + Opens: []v1beta1.OpenCalls{ + { + Path: "/etc/config", + }, + }, + }, + }, + }, + }, + userProfile: &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ug-test-profile", // Added ug- prefix + Namespace: "default", + Annotations: map[string]string{ + "kubescape.io/managed-by": "User", + }, + }, + Spec: v1beta1.ApplicationProfileSpec{ + Containers: []v1beta1.ApplicationProfileContainer{ + { + Name: "container1", + Capabilities: []string{ + "SYS_ADMIN", + }, + Syscalls: []string{ + "write", + }, + Opens: []v1beta1.OpenCalls{ + { + Path: "/etc/secret", + }, + }, + }, + }, + }, + }, + expectedMerged: &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-profile", // Keeps original name without ug- prefix + Namespace: "default", + }, + Spec: v1beta1.ApplicationProfileSpec{ + Containers: []v1beta1.ApplicationProfileContainer{ + { + Name: "container1", + Capabilities: []string{ + "NET_ADMIN", + "SYS_ADMIN", + }, + Syscalls: []string{ + "open", + "write", + }, + Opens: []v1beta1.OpenCalls{ + { + Path: "/etc/config", + }, + { + Path: "/etc/secret", + }, + }, + }, + }, + }, + }, + }, + { + name: "merge profiles with init containers", + normalProfile: &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-profile", + Namespace: "default", + }, + Spec: v1beta1.ApplicationProfileSpec{ + InitContainers: []v1beta1.ApplicationProfileContainer{ + { + Name: "init1", + Execs: []v1beta1.ExecCalls{ + { + Path: "mount", + }, + }, + }, + }, + }, + }, + userProfile: &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ug-test-profile", // Added ug- prefix + Namespace: "default", + Annotations: map[string]string{ + "kubescape.io/managed-by": "User", + }, + }, + Spec: v1beta1.ApplicationProfileSpec{ + InitContainers: 
[]v1beta1.ApplicationProfileContainer{ + { + Name: "init1", + Execs: []v1beta1.ExecCalls{ + { + Path: "chmod", + }, + }, + Syscalls: []string{ + "chmod", + }, + }, + }, + }, + }, + expectedMerged: &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-profile", // Keeps original name without ug- prefix + Namespace: "default", + }, + Spec: v1beta1.ApplicationProfileSpec{ + InitContainers: []v1beta1.ApplicationProfileContainer{ + { + Name: "init1", + Execs: []v1beta1.ExecCalls{ + { + Path: "mount", + }, + { + Path: "chmod", + }, + }, + Syscalls: []string{ + "chmod", + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := NewApplicationProfileCache("test-node", nil, 0) + merged := cache.performMerge(tt.normalProfile, tt.userProfile) + + // Verify object metadata + assert.Equal(t, tt.expectedMerged.Name, merged.Name) + assert.Equal(t, tt.expectedMerged.Namespace, merged.Namespace) + + // Verify user-managed annotation is removed + _, hasAnnotation := merged.Annotations["kubescape.io/managed-by"] + assert.False(t, hasAnnotation) + + // Verify containers + assert.Equal(t, len(tt.expectedMerged.Spec.Containers), len(merged.Spec.Containers)) + for i, container := range tt.expectedMerged.Spec.Containers { + assert.Equal(t, container.Name, merged.Spec.Containers[i].Name) + assert.ElementsMatch(t, container.Capabilities, merged.Spec.Containers[i].Capabilities) + assert.ElementsMatch(t, container.Syscalls, merged.Spec.Containers[i].Syscalls) + assert.ElementsMatch(t, container.Opens, merged.Spec.Containers[i].Opens) + assert.ElementsMatch(t, container.Execs, merged.Spec.Containers[i].Execs) + assert.ElementsMatch(t, container.Endpoints, merged.Spec.Containers[i].Endpoints) + } + + // Verify init containers + assert.Equal(t, len(tt.expectedMerged.Spec.InitContainers), len(merged.Spec.InitContainers)) + for i, container := range tt.expectedMerged.Spec.InitContainers { + assert.Equal(t, container.Name, merged.Spec.InitContainers[i].Name) + assert.ElementsMatch(t, container.Capabilities, merged.Spec.InitContainers[i].Capabilities) + assert.ElementsMatch(t, container.Syscalls, merged.Spec.InitContainers[i].Syscalls) + assert.ElementsMatch(t, container.Opens, merged.Spec.InitContainers[i].Opens) + assert.ElementsMatch(t, container.Execs, merged.Spec.InitContainers[i].Execs) + assert.ElementsMatch(t, container.Endpoints, merged.Spec.InitContainers[i].Endpoints) + } + }) + } +} diff --git a/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache.go b/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache.go index 6509ef61..93fc1083 100644 --- a/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache.go +++ b/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache.go @@ -3,6 +3,7 @@ package networkneighborhoodcache import ( "context" "fmt" + "strings" "time" mapset "github.com/deckarep/golang-set/v2" @@ -47,30 +48,99 @@ func newNetworkNeighborhoodState(nn *v1beta1.NetworkNeighborhood) networkNeighbo } type NetworkNeighborhoodCacheImpl struct { - containerToSlug maps.SafeMap[string, string] // cache the containerID to slug mapping, this will enable a quick lookup of the network neighborhood - slugToNetworkNeighborhood maps.SafeMap[string, *v1beta1.NetworkNeighborhood] // cache the network neighborhood - slugToContainers maps.SafeMap[string, mapset.Set[string]] // cache the containerIDs that belong to the network neighborhood, this will enable removing from cache NN without pods - slugToState 
maps.SafeMap[string, networkNeighborhoodState] // cache the containerID to slug mapping, this will enable a quick lookup of the network neighborhood - storageClient versioned.SpdxV1beta1Interface - allNetworkNeighborhoods mapset.Set[string] // cache all the NN that are ready. this will enable removing from cache NN without pods that are running on the same node - nodeName string - maxDelaySeconds int // maximum delay in seconds before getting the full object from the storage + containerToSlug maps.SafeMap[string, string] // cache the containerID to slug mapping, this will enable a quick lookup of the network neighborhood + slugToNetworkNeighborhood maps.SafeMap[string, *v1beta1.NetworkNeighborhood] // cache the network neighborhood + slugToContainers maps.SafeMap[string, mapset.Set[string]] // cache the containerIDs that belong to the network neighborhood, this will enable removing from cache NN without pods + slugToState maps.SafeMap[string, networkNeighborhoodState] // cache the containerID to slug mapping, this will enable a quick lookup of the network neighborhood + storageClient versioned.SpdxV1beta1Interface + allNetworkNeighborhoods mapset.Set[string] // cache all the NN that are ready. this will enable removing from cache NN without pods that are running on the same node + nodeName string + maxDelaySeconds int // maximum delay in seconds before getting the full object from the storage + userManagedNetworkNeighborhood maps.SafeMap[string, *v1beta1.NetworkNeighborhood] } func NewNetworkNeighborhoodCache(nodeName string, storageClient versioned.SpdxV1beta1Interface, maxDelaySeconds int) *NetworkNeighborhoodCacheImpl { return &NetworkNeighborhoodCacheImpl{ - nodeName: nodeName, - maxDelaySeconds: maxDelaySeconds, - storageClient: storageClient, - containerToSlug: maps.SafeMap[string, string]{}, - slugToContainers: maps.SafeMap[string, mapset.Set[string]]{}, - allNetworkNeighborhoods: mapset.NewSet[string](), + nodeName: nodeName, + maxDelaySeconds: maxDelaySeconds, + storageClient: storageClient, + containerToSlug: maps.SafeMap[string, string]{}, + slugToContainers: maps.SafeMap[string, mapset.Set[string]]{}, + allNetworkNeighborhoods: mapset.NewSet[string](), + userManagedNetworkNeighborhood: maps.SafeMap[string, *v1beta1.NetworkNeighborhood]{}, } - } // ------------------ objectcache.NetworkNeighborhoodCache methods ----------------------- +func (nn *NetworkNeighborhoodCacheImpl) handleUserManagedNN(netNeighborhood *v1beta1.NetworkNeighborhood) { + baseNNName := strings.TrimPrefix(netNeighborhood.GetName(), "ug-") + baseNNUniqueName := objectcache.UniqueName(netNeighborhood.GetNamespace(), baseNNName) + + // Get the full user managed network neighborhood from the storage + userManagedNN, err := nn.getNetworkNeighborhood(netNeighborhood.GetNamespace(), netNeighborhood.GetName()) + if err != nil { + logger.L().Error("failed to get full network neighborhood", helpers.Error(err)) + return + } + + // Store the user-managed network neighborhood temporarily + nn.userManagedNetworkNeighborhood.Set(baseNNUniqueName, userManagedNN) + + // If we have the base network neighborhood cached, fetch a fresh copy and merge. + // If the base network neighborhood is not cached yet, the merge will be attempted when it's added. 
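+	// (addFullNetworkNeighborhood checks userManagedNetworkNeighborhood and performs the pending merge at that point.)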
+ if nn.slugToNetworkNeighborhood.Has(baseNNUniqueName) { + // Fetch fresh base network neighborhood from cluster + freshBaseNN, err := nn.getNetworkNeighborhood(netNeighborhood.GetNamespace(), baseNNName) + if err != nil { + logger.L().Error("failed to get fresh base network neighborhood for merging", + helpers.String("name", baseNNName), + helpers.String("namespace", netNeighborhood.GetNamespace()), + helpers.Error(err)) + return + } + + mergedNN := nn.performMerge(freshBaseNN, userManagedNN) + nn.slugToNetworkNeighborhood.Set(baseNNUniqueName, mergedNN) + + // Clean up the user-managed network neighborhood after successful merge + nn.userManagedNetworkNeighborhood.Delete(baseNNUniqueName) + + logger.L().Debug("merged user-managed network neighborhood with fresh base network neighborhood", + helpers.String("name", baseNNName), + helpers.String("namespace", netNeighborhood.GetNamespace())) + } +} + +func (nn *NetworkNeighborhoodCacheImpl) addNetworkNeighborhood(_ context.Context, obj runtime.Object) { + netNeighborhood := obj.(*v1beta1.NetworkNeighborhood) + nnName := objectcache.MetaUniqueName(netNeighborhood) + + if isUserManagedNN(netNeighborhood) { + nn.handleUserManagedNN(netNeighborhood) + return + } + + nnState := newNetworkNeighborhoodState(netNeighborhood) + nn.slugToState.Set(nnName, nnState) + + if nnState.status != helpersv1.Completed { + if nn.slugToNetworkNeighborhood.Has(nnName) { + nn.slugToNetworkNeighborhood.Delete(nnName) + nn.allNetworkNeighborhoods.Remove(nnName) + } + return + } + + nn.allNetworkNeighborhoods.Add(nnName) + + if nn.slugToContainers.Has(nnName) { + time.AfterFunc(utils.RandomDuration(nn.maxDelaySeconds, time.Second), func() { + nn.addFullNetworkNeighborhood(netNeighborhood, nnName) + }) + } +} + func (nn *NetworkNeighborhoodCacheImpl) GetNetworkNeighborhood(containerID string) *v1beta1.NetworkNeighborhood { if s := nn.containerToSlug.Get(containerID); s != "" { return nn.slugToNetworkNeighborhood.Get(s) @@ -213,36 +283,6 @@ func (nn *NetworkNeighborhoodCacheImpl) removeContainer(containerID string) { } // ------------------ watch network neighborhood methods ----------------------- -func (nn *NetworkNeighborhoodCacheImpl) addNetworkNeighborhood(_ context.Context, obj runtime.Object) { - netNeighborhood := obj.(*v1beta1.NetworkNeighborhood) - nnName := objectcache.MetaUniqueName(netNeighborhood) - - nnState := newNetworkNeighborhoodState(netNeighborhood) - nn.slugToState.Set(nnName, nnState) - - // the cache holds only completed network neighborhoods. - // check if the network neighborhood is completed - // if status was completed and now is not (e.g. 
mode changed from complete to partial), remove from cache - if nnState.status != helpersv1.Completed { - if nn.slugToNetworkNeighborhood.Has(nnName) { - nn.slugToNetworkNeighborhood.Delete(nnName) - nn.allNetworkNeighborhoods.Remove(nnName) - } - return - } - - // add to the cache - nn.allNetworkNeighborhoods.Add(nnName) - - if nn.slugToContainers.Has(nnName) { - // get the full network neighborhood from the storage - // the watch only returns the metadata - // avoid thundering herd problem by adding a random delay - time.AfterFunc(utils.RandomDuration(nn.maxDelaySeconds, time.Second), func() { - nn.addFullNetworkNeighborhood(netNeighborhood, nnName) - }) - } -} func (nn *NetworkNeighborhoodCacheImpl) addFullNetworkNeighborhood(netNeighborhood *v1beta1.NetworkNeighborhood, nnName string) { fullNN, err := nn.getNetworkNeighborhood(netNeighborhood.GetNamespace(), netNeighborhood.GetName()) @@ -250,6 +290,16 @@ func (nn *NetworkNeighborhoodCacheImpl) addFullNetworkNeighborhood(netNeighborho logger.L().Error("failed to get full network neighborhood", helpers.Error(err)) return } + + // Check if there's a pending user-managed network neighborhood to merge + if nn.userManagedNetworkNeighborhood.Has(nnName) { + userManagedNN := nn.userManagedNetworkNeighborhood.Get(nnName) + fullNN = nn.performMerge(fullNN, userManagedNN) + // Clean up the user-managed network neighborhood after successful merge + nn.userManagedNetworkNeighborhood.Delete(nnName) + logger.L().Debug("merged pending user-managed network neighborhood", helpers.String("name", nnName)) + } + nn.slugToNetworkNeighborhood.Set(nnName, fullNN) for _, i := range nn.slugToContainers.Get(nnName).ToSlice() { nn.containerToSlug.Set(i, nnName) @@ -257,13 +307,199 @@ func (nn *NetworkNeighborhoodCacheImpl) addFullNetworkNeighborhood(netNeighborho logger.L().Debug("added pod to network neighborhood cache", helpers.String("name", nnName)) } +func (nn *NetworkNeighborhoodCacheImpl) performMerge(normalNN, userManagedNN *v1beta1.NetworkNeighborhood) *v1beta1.NetworkNeighborhood { + mergedNN := normalNN.DeepCopy() + + // Merge spec containers + mergedNN.Spec.Containers = nn.mergeContainers(mergedNN.Spec.Containers, userManagedNN.Spec.Containers) + mergedNN.Spec.InitContainers = nn.mergeContainers(mergedNN.Spec.InitContainers, userManagedNN.Spec.InitContainers) + mergedNN.Spec.EphemeralContainers = nn.mergeContainers(mergedNN.Spec.EphemeralContainers, userManagedNN.Spec.EphemeralContainers) + + // Merge LabelSelector + if userManagedNN.Spec.LabelSelector.MatchLabels != nil { + if mergedNN.Spec.LabelSelector.MatchLabels == nil { + mergedNN.Spec.LabelSelector.MatchLabels = make(map[string]string) + } + for k, v := range userManagedNN.Spec.LabelSelector.MatchLabels { + mergedNN.Spec.LabelSelector.MatchLabels[k] = v + } + } + mergedNN.Spec.LabelSelector.MatchExpressions = append( + mergedNN.Spec.LabelSelector.MatchExpressions, + userManagedNN.Spec.LabelSelector.MatchExpressions..., + ) + + return mergedNN +} + +func (nn *NetworkNeighborhoodCacheImpl) mergeContainers(normalContainers, userManagedContainers []v1beta1.NetworkNeighborhoodContainer) []v1beta1.NetworkNeighborhoodContainer { + if len(userManagedContainers) != len(normalContainers) { + // If the number of containers don't match, we can't merge + logger.L().Error("failed to merge user-managed profile with base profile", + helpers.Int("normalContainers len", len(normalContainers)), + helpers.Int("userManagedContainers len", len(userManagedContainers)), + helpers.String("reason", "number of containers 
don't match")) + return normalContainers + } + + // Assuming the normalContainers are already in the correct Pod order + // We'll merge user containers at their corresponding positions + for i := range normalContainers { + for _, userContainer := range userManagedContainers { + if normalContainers[i].Name == userContainer.Name { + nn.mergeContainer(&normalContainers[i], &userContainer) + break + } + } + } + return normalContainers +} + +func (nn *NetworkNeighborhoodCacheImpl) mergeContainer(normalContainer, userContainer *v1beta1.NetworkNeighborhoodContainer) { + // Merge ingress rules + normalContainer.Ingress = nn.mergeNetworkNeighbors(normalContainer.Ingress, userContainer.Ingress) + + // Merge egress rules + normalContainer.Egress = nn.mergeNetworkNeighbors(normalContainer.Egress, userContainer.Egress) +} + +func (nn *NetworkNeighborhoodCacheImpl) mergeNetworkNeighbors(normalNeighbors, userNeighbors []v1beta1.NetworkNeighbor) []v1beta1.NetworkNeighbor { + // Use map to track existing neighbors by identifier + neighborMap := make(map[string]int) + for i, neighbor := range normalNeighbors { + neighborMap[neighbor.Identifier] = i + } + + // Merge or append user neighbors + for _, userNeighbor := range userNeighbors { + if idx, exists := neighborMap[userNeighbor.Identifier]; exists { + // Merge existing neighbor + normalNeighbors[idx] = nn.mergeNetworkNeighbor(normalNeighbors[idx], userNeighbor) + } else { + // Append new neighbor + normalNeighbors = append(normalNeighbors, userNeighbor) + } + } + + return normalNeighbors +} + +func (nn *NetworkNeighborhoodCacheImpl) mergeNetworkNeighbor(normal, user v1beta1.NetworkNeighbor) v1beta1.NetworkNeighbor { + merged := normal.DeepCopy() + + // Merge DNS names (removing duplicates) + dnsNamesSet := make(map[string]struct{}) + for _, dns := range normal.DNSNames { + dnsNamesSet[dns] = struct{}{} + } + for _, dns := range user.DNSNames { + dnsNamesSet[dns] = struct{}{} + } + merged.DNSNames = make([]string, 0, len(dnsNamesSet)) + for dns := range dnsNamesSet { + merged.DNSNames = append(merged.DNSNames, dns) + } + + // Merge ports based on patchMergeKey (name) + merged.Ports = nn.mergeNetworkPorts(merged.Ports, user.Ports) + + // Merge pod selector if provided + if user.PodSelector != nil { + if merged.PodSelector == nil { + merged.PodSelector = &metav1.LabelSelector{} + } + if user.PodSelector.MatchLabels != nil { + if merged.PodSelector.MatchLabels == nil { + merged.PodSelector.MatchLabels = make(map[string]string) + } + for k, v := range user.PodSelector.MatchLabels { + merged.PodSelector.MatchLabels[k] = v + } + } + merged.PodSelector.MatchExpressions = append( + merged.PodSelector.MatchExpressions, + user.PodSelector.MatchExpressions..., + ) + } + + // Merge namespace selector if provided + if user.NamespaceSelector != nil { + if merged.NamespaceSelector == nil { + merged.NamespaceSelector = &metav1.LabelSelector{} + } + if user.NamespaceSelector.MatchLabels != nil { + if merged.NamespaceSelector.MatchLabels == nil { + merged.NamespaceSelector.MatchLabels = make(map[string]string) + } + for k, v := range user.NamespaceSelector.MatchLabels { + merged.NamespaceSelector.MatchLabels[k] = v + } + } + merged.NamespaceSelector.MatchExpressions = append( + merged.NamespaceSelector.MatchExpressions, + user.NamespaceSelector.MatchExpressions..., + ) + } + + // Take the user's IP address if provided + if user.IPAddress != "" { + merged.IPAddress = user.IPAddress + } + + // Take the user's type if provided + if user.Type != "" { + merged.Type = user.Type + 
} + + return *merged +} + +func (nn *NetworkNeighborhoodCacheImpl) mergeNetworkPorts(normalPorts, userPorts []v1beta1.NetworkPort) []v1beta1.NetworkPort { + // Use map to track existing ports by name (patchMergeKey) + portMap := make(map[string]int) + for i, port := range normalPorts { + portMap[port.Name] = i + } + + // Merge or append user ports + for _, userPort := range userPorts { + if idx, exists := portMap[userPort.Name]; exists { + // Update existing port + normalPorts[idx] = userPort + } else { + // Append new port + normalPorts = append(normalPorts, userPort) + } + } + + return normalPorts +} + func (nn *NetworkNeighborhoodCacheImpl) deleteNetworkNeighborhood(obj runtime.Object) { - nnName := objectcache.MetaUniqueName(obj.(metav1.Object)) - nn.slugToNetworkNeighborhood.Delete(nnName) - nn.slugToState.Delete(nnName) - nn.allNetworkNeighborhoods.Remove(nnName) + netNeighborhood := obj.(*v1beta1.NetworkNeighborhood) + nnName := objectcache.MetaUniqueName(netNeighborhood) - logger.L().Info("deleted network neighborhood from cache", helpers.String("uniqueSlug", nnName)) + if isUserManagedNN(netNeighborhood) { + // For user-managed network neighborhoods, we need to use the base name for cleanup + baseNNName := strings.TrimPrefix(netNeighborhood.GetName(), "ug-") + baseNNUniqueName := objectcache.UniqueName(netNeighborhood.GetNamespace(), baseNNName) + nn.userManagedNetworkNeighborhood.Delete(baseNNUniqueName) + + logger.L().Debug("deleted user-managed network neighborhood from cache", + helpers.String("nnName", netNeighborhood.GetName()), + helpers.String("baseNN", baseNNName)) + } else { + // For normal network neighborhoods, clean up all related data + nn.slugToNetworkNeighborhood.Delete(nnName) + nn.slugToState.Delete(nnName) + nn.allNetworkNeighborhoods.Remove(nnName) + + logger.L().Debug("deleted network neighborhood from cache", + helpers.String("uniqueSlug", nnName)) + } + + // Clean up any orphaned user-managed network neighborhoods + nn.cleanupOrphanedUserManagedNNs() } func (nn *NetworkNeighborhoodCacheImpl) getNetworkNeighborhood(namespace, name string) (*v1beta1.NetworkNeighborhood, error) { @@ -301,3 +537,26 @@ func getSlug(p *corev1.Pod) (string, error) { } return slug, nil } + +// Add cleanup method for orphaned user-managed network neighborhoods +func (nn *NetworkNeighborhoodCacheImpl) cleanupOrphanedUserManagedNNs() { + nn.userManagedNetworkNeighborhood.Range(func(key string, value *v1beta1.NetworkNeighborhood) bool { + if nn.slugToNetworkNeighborhood.Has(key) { + // If base network neighborhood exists but merge didn't happen for some reason, + // attempt merge again and cleanup + if baseNN := nn.slugToNetworkNeighborhood.Get(key); baseNN != nil { + mergedNN := nn.performMerge(baseNN, value) + nn.slugToNetworkNeighborhood.Set(key, mergedNN) + nn.userManagedNetworkNeighborhood.Delete(key) + logger.L().Debug("cleaned up orphaned user-managed network neighborhood", helpers.String("name", key)) + } + } + return true + }) +} + +func isUserManagedNN(nn *v1beta1.NetworkNeighborhood) bool { + return nn.Annotations != nil && + nn.Annotations["kubescape.io/managed-by"] == "User" && + strings.HasPrefix(nn.GetName(), "ug-") +} diff --git a/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache_test.go b/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache_test.go index acb96444..49f415be 100644 --- a/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache_test.go +++ b/pkg/objectcache/networkneighborhoodcache/networkneighborhoodcache_test.go @@ 
-734,3 +734,243 @@ func Test_addPod(t *testing.T) { }) } } + +func Test_performMerge(t *testing.T) { + tests := []struct { + name string + baseNN *v1beta1.NetworkNeighborhood + userNN *v1beta1.NetworkNeighborhood + expectedResult *v1beta1.NetworkNeighborhood + validateResults func(*testing.T, *v1beta1.NetworkNeighborhood) + }{ + { + name: "merge basic network neighbors", + baseNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Ingress: []v1beta1.NetworkNeighbor{ + { + Identifier: "ingress1", + Type: "http", + DNSNames: []string{"example.com"}, + Ports: []v1beta1.NetworkPort{ + {Name: "http", Protocol: "TCP", Port: ptr(int32(80))}, + }, + }, + }, + }, + }, + }, + }, + userNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Ingress: []v1beta1.NetworkNeighbor{ + { + Identifier: "ingress2", + Type: "https", + DNSNames: []string{"secure.example.com"}, + Ports: []v1beta1.NetworkPort{ + {Name: "https", Protocol: "TCP", Port: ptr(int32(443))}, + }, + }, + }, + }, + }, + }, + }, + validateResults: func(t *testing.T, result *v1beta1.NetworkNeighborhood) { + assert.Len(t, result.Spec.Containers, 1) + assert.Len(t, result.Spec.Containers[0].Ingress, 2) + + // Verify both ingress rules are present + ingressIdentifiers := []string{ + result.Spec.Containers[0].Ingress[0].Identifier, + result.Spec.Containers[0].Ingress[1].Identifier, + } + assert.Contains(t, ingressIdentifiers, "ingress1") + assert.Contains(t, ingressIdentifiers, "ingress2") + }, + }, + { + name: "merge label selectors", + baseNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + LabelSelector: metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": "base", + }, + }, + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "egress1", + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "role": "db", + }, + }, + }, + }, + }, + }, + }, + }, + userNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + LabelSelector: metav1.LabelSelector{ + MatchLabels: map[string]string{ + "env": "prod", + }, + }, + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "egress1", + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "version": "v1", + }, + }, + }, + }, + }, + }, + }, + }, + validateResults: func(t *testing.T, result *v1beta1.NetworkNeighborhood) { + // Verify merged label selectors + assert.Equal(t, "base", result.Spec.LabelSelector.MatchLabels["app"]) + assert.Equal(t, "prod", result.Spec.LabelSelector.MatchLabels["env"]) + + // Verify merged pod selector in egress rule + container := result.Spec.Containers[0] + podSelector := container.Egress[0].PodSelector + assert.Equal(t, "db", podSelector.MatchLabels["role"]) + assert.Equal(t, "v1", podSelector.MatchLabels["version"]) + }, + }, + { + name: "merge network ports", + baseNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "egress1", + Ports: []v1beta1.NetworkPort{ + {Name: "http", Protocol: "TCP", Port: ptr(int32(80))}, + }, + }, + }, + }, + }, + }, + }, + userNN: 
&v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "egress1", + Ports: []v1beta1.NetworkPort{ + {Name: "http", Protocol: "TCP", Port: ptr(int32(8080))}, // Override existing port + {Name: "https", Protocol: "TCP", Port: ptr(int32(443))}, // Add new port + }, + }, + }, + }, + }, + }, + }, + validateResults: func(t *testing.T, result *v1beta1.NetworkNeighborhood) { + container := result.Spec.Containers[0] + ports := container.Egress[0].Ports + + // Verify ports are properly merged + assert.Len(t, ports, 2) + + // Find HTTP port - should be updated to 8080 + for _, port := range ports { + if port.Name == "http" { + assert.Equal(t, int32(8080), *port.Port) + } + if port.Name == "https" { + assert.Equal(t, int32(443), *port.Port) + } + } + }, + }, + { + name: "merge DNS names", + baseNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "egress1", + DNSNames: []string{"example.com", "api.example.com"}, + }, + }, + }, + }, + }, + }, + userNN: &v1beta1.NetworkNeighborhood{ + Spec: v1beta1.NetworkNeighborhoodSpec{ + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "container1", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "egress1", + DNSNames: []string{"api.example.com", "admin.example.com"}, + }, + }, + }, + }, + }, + }, + validateResults: func(t *testing.T, result *v1beta1.NetworkNeighborhood) { + container := result.Spec.Containers[0] + dnsNames := container.Egress[0].DNSNames + + // Verify DNS names are properly merged and deduplicated + assert.Len(t, dnsNames, 3) + assert.Contains(t, dnsNames, "example.com") + assert.Contains(t, dnsNames, "api.example.com") + assert.Contains(t, dnsNames, "admin.example.com") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nn := NewNetworkNeighborhoodCache("test-node", nil, 0) + result := nn.performMerge(tt.baseNN, tt.userNN) + + if tt.validateResults != nil { + tt.validateResults(t, result) + } + }) + } +} + +// Helper function to create pointer to int32 +func ptr(i int32) *int32 { + return &i +} diff --git a/pkg/processmanager/process_manager_interface.go b/pkg/processmanager/process_manager_interface.go new file mode 100644 index 00000000..0f36f367 --- /dev/null +++ b/pkg/processmanager/process_manager_interface.go @@ -0,0 +1,20 @@ +package processmanager + +import ( + apitypes "github.com/armosec/armoapi-go/armotypes" + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + "github.com/kubescape/node-agent/pkg/utils" +) + +// ProcessManagerClient is the interface for the process manager client. +// It provides methods to get process tree for a container or a PID. +// The manager is responsible for maintaining the process tree for all containers. +type ProcessManagerClient interface { + GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) + // PopulateInitialProcesses is called to populate the initial process tree (parsed from /proc) for all containers. + PopulateInitialProcesses() error + + // ReportEvent will be called to report new exec events to the process manager. 
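+	// Only utils.ExecveEventType events are expected to change the tree; implementations
+	// may ignore other event types (the v1 implementation below does exactly that).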
+ ReportEvent(eventType utils.EventType, event utils.K8sEvent) + ContainerCallback(notif containercollection.PubSubEvent) +} diff --git a/pkg/processmanager/process_manager_mock.go b/pkg/processmanager/process_manager_mock.go new file mode 100644 index 00000000..68fcdd14 --- /dev/null +++ b/pkg/processmanager/process_manager_mock.go @@ -0,0 +1,32 @@ +package processmanager + +import ( + apitypes "github.com/armosec/armoapi-go/armotypes" + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + "github.com/kubescape/node-agent/pkg/utils" +) + +type ProcessManagerMock struct { +} + +var _ ProcessManagerClient = (*ProcessManagerMock)(nil) + +func CreateProcessManagerMock() *ProcessManagerMock { + return &ProcessManagerMock{} +} + +func (p *ProcessManagerMock) GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) { + return apitypes.Process{}, nil +} + +func (p *ProcessManagerMock) PopulateInitialProcesses() error { + return nil +} + +func (p *ProcessManagerMock) ReportEvent(eventType utils.EventType, event utils.K8sEvent) { + // no-op +} + +func (p *ProcessManagerMock) ContainerCallback(notif containercollection.PubSubEvent) { + // no-op +} diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go new file mode 100644 index 00000000..74a95e5c --- /dev/null +++ b/pkg/processmanager/v1/process_manager.go @@ -0,0 +1,419 @@ +package processmanager + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/goradd/maps" + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + "github.com/prometheus/procfs" + + apitypes "github.com/armosec/armoapi-go/armotypes" + tracerexectype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/exec/types" + "github.com/kubescape/go-logger" + "github.com/kubescape/go-logger/helpers" + "github.com/kubescape/node-agent/pkg/utils" +) + +const ( + cleanupInterval = 1 * time.Minute + maxTreeDepth = 50 +) + +type ProcessManager struct { + containerIdToShimPid maps.SafeMap[string, uint32] + processTree maps.SafeMap[uint32, apitypes.Process] + // For testing purposes we allow to override the function that gets process info from /proc. + getProcessFromProc func(pid int) (apitypes.Process, error) +} + +func CreateProcessManager(ctx context.Context) *ProcessManager { + pm := &ProcessManager{ + getProcessFromProc: getProcessFromProc, + } + go pm.startCleanupRoutine(ctx) + return pm +} + +// PopulateInitialProcesses scans the /proc filesystem to build the initial process tree +// for all registered container shim processes. It establishes parent-child relationships +// between processes and adds them to the process tree if they are descendants of a shim. 
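+// Illustrative startup sequence (a sketch only; existingContainers is a placeholder for
+// however the already-running containers are enumerated, e.g. by the container watcher):
+//
+//	pm := CreateProcessManager(ctx)
+//	for _, c := range existingContainers {
+//		pm.ContainerCallback(containercollection.PubSubEvent{
+//			Type:      containercollection.EventTypeAddContainer,
+//			Container: c,
+//		})
+//	}
+//	if err := pm.PopulateInitialProcesses(); err != nil {
+//		logger.L().Warning("failed to populate initial processes", helpers.Error(err))
+//	}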
+func (p *ProcessManager) PopulateInitialProcesses() error { + if len(p.containerIdToShimPid.Keys()) == 0 { + return nil + } + + fs, err := procfs.NewFS("/proc") + if err != nil { + return fmt.Errorf("failed to open procfs: %w", err) + } + + procs, err := fs.AllProcs() + if err != nil { + return fmt.Errorf("failed to read all procs: %w", err) + } + + tempProcesses := make(map[uint32]apitypes.Process, len(procs)) + shimPIDs := make(map[uint32]struct{}) + + p.containerIdToShimPid.Range(func(_ string, shimPID uint32) bool { + shimPIDs[shimPID] = struct{}{} + return true + }) + + // First collect all processes + for _, proc := range procs { + if process, err := p.getProcessFromProc(proc.PID); err == nil { + tempProcesses[process.PID] = process + } + } + + // Then build relationships and add to tree + for pid, process := range tempProcesses { + if p.isDescendantOfShim(pid, process.PPID, shimPIDs, tempProcesses) { + if parent, exists := tempProcesses[process.PPID]; exists { + parent.Children = append(parent.Children, process) + tempProcesses[process.PPID] = parent + } + p.processTree.Set(pid, process) + } + } + + return nil +} + +// isDescendantOfShim checks if a process with the given PID is a descendant of any +// registered shim process. It traverses the process tree upwards until it either finds +// a shim process or reaches the maximum tree depth to prevent infinite loops. +func (p *ProcessManager) isDescendantOfShim(pid uint32, ppid uint32, shimPIDs map[uint32]struct{}, processes map[uint32]apitypes.Process) bool { + visited := make(map[uint32]bool) + currentPID := pid + for depth := 0; depth < maxTreeDepth; depth++ { + if currentPID == 0 || visited[currentPID] { + return false + } + visited[currentPID] = true + + if _, isShim := shimPIDs[ppid]; isShim { + return true + } + + process, exists := processes[ppid] + if !exists { + return false + } + currentPID = ppid + ppid = process.PPID + } + return false +} + +// ContainerCallback handles container lifecycle events (creation and removal). +// For new containers, it identifies the container's shim process and adds it to the tracking system. +// For removed containers, it cleans up the associated processes from the process tree. +func (p *ProcessManager) ContainerCallback(notif containercollection.PubSubEvent) { + containerID := notif.Container.Runtime.BasicRuntimeMetadata.ContainerID + + switch notif.Type { + case containercollection.EventTypeAddContainer: + containerPID := uint32(notif.Container.Pid) + if process, err := p.getProcessFromProc(int(containerPID)); err == nil { + shimPID := process.PPID + p.containerIdToShimPid.Set(containerID, shimPID) + p.addProcess(process) + } else { + logger.L().Warning("Failed to get container process info", + helpers.String("containerID", containerID), + helpers.Error(err)) + } + + case containercollection.EventTypeRemoveContainer: + if shimPID, exists := p.containerIdToShimPid.Load(containerID); exists { + p.removeProcessesUnderShim(shimPID) + p.containerIdToShimPid.Delete(containerID) + } + } +} + +// removeProcessesUnderShim removes all processes that are descendants of the specified +// shim process PID from the process tree. This is typically called when a container +// is being removed. 
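+// Example: with shim PID 999 and tracked processes 1000 (PPID 999), 1001 (PPID 1000) and an
+// exec'd shell 1002 (PPID 999), removeProcessesUnderShim(999) removes 1000, 1001 and 1002,
+// while processes that do not descend from PID 999 are left untouched.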
+func (p *ProcessManager) removeProcessesUnderShim(shimPID uint32) { + var pidsToRemove []uint32 + + p.processTree.Range(func(pid uint32, process apitypes.Process) bool { + currentPID := pid + visited := make(map[uint32]bool) + + for currentPID != 0 && !visited[currentPID] { + visited[currentPID] = true + if proc, exists := p.processTree.Load(currentPID); exists { + if proc.PPID == shimPID { + pidsToRemove = append(pidsToRemove, pid) + break + } + currentPID = proc.PPID + } else { + break + } + } + return true + }) + + // Remove in reverse order to handle parent-child relationships + for i := len(pidsToRemove) - 1; i >= 0; i-- { + p.removeProcess(pidsToRemove[i]) + } +} + +// addProcess adds or updates a process in the process tree and maintains the +// parent-child relationships between processes. If the process already exists +// with a different parent, it updates the relationships accordingly. +func (p *ProcessManager) addProcess(process apitypes.Process) { + // First, check if the process already exists and has a different parent + if existingProc, exists := p.processTree.Load(process.PID); exists && existingProc.PPID != process.PPID { + // Remove from old parent's children list + if oldParent, exists := p.processTree.Load(existingProc.PPID); exists { + newChildren := make([]apitypes.Process, 0, len(oldParent.Children)) + for _, child := range oldParent.Children { + if child.PID != process.PID { + newChildren = append(newChildren, child) + } + } + oldParent.Children = newChildren + p.processTree.Set(oldParent.PID, oldParent) + } + } + + // Update the process in the tree + p.processTree.Set(process.PID, process) + + // Update new parent's children list + if parent, exists := p.processTree.Load(process.PPID); exists { + newChildren := make([]apitypes.Process, 0, len(parent.Children)+1) + hasProcess := false + for _, child := range parent.Children { + if child.PID == process.PID { + hasProcess = true + newChildren = append(newChildren, process) + } else { + newChildren = append(newChildren, child) + } + } + if !hasProcess { + newChildren = append(newChildren, process) + } + parent.Children = newChildren + p.processTree.Set(parent.PID, parent) + } +} + +// removeProcess removes a process from the process tree and updates the parent-child +// relationships. Children of the removed process are reassigned to their grandparent +// to maintain the process hierarchy. +func (p *ProcessManager) removeProcess(pid uint32) { + if process, exists := p.processTree.Load(pid); exists { + if parent, exists := p.processTree.Load(process.PPID); exists { + newChildren := make([]apitypes.Process, 0, len(parent.Children)) + for _, child := range parent.Children { + if child.PID != pid { + newChildren = append(newChildren, child) + } + } + parent.Children = newChildren + p.processTree.Set(parent.PID, parent) + } + + for _, child := range process.Children { + if childProcess, exists := p.processTree.Load(child.PID); exists { + childProcess.PPID = process.PPID + p.addProcess(childProcess) + } + } + + p.processTree.Delete(pid) + } +} + +// GetProcessTreeForPID retrieves the process tree for a specific PID within a container. +// It returns the process and all its ancestors up to the container's shim process. +// If the process is not in the tree, it attempts to fetch it from /proc. 
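+// Illustrative usage (the container ID and PID here are hypothetical; in practice they come
+// from the triggering event, as in the rule manager changes further down in this change):
+//
+//	tree, err := pm.GetProcessTreeForPID("test-container-1", 1337)
+//	if err != nil {
+//		// unknown container ID, or the PID could not be resolved from /proc
+//	}
+//	// On success, tree is rooted at the oldest tracked ancestor below the shim and the
+//	// requested PID sits at the bottom of the nested Children chain.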
+func (p *ProcessManager) GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) { + if !p.containerIdToShimPid.Has(containerID) { + return apitypes.Process{}, fmt.Errorf("container ID %s not found", containerID) + } + + targetPID := uint32(pid) + if !p.processTree.Has(targetPID) { + process, err := p.getProcessFromProc(pid) + if err != nil { + return apitypes.Process{}, fmt.Errorf("process %d not found: %v", pid, err) + } + p.addProcess(process) + } + + result := p.processTree.Get(targetPID) + currentPID := result.PPID + seen := make(map[uint32]bool) + + for currentPID != p.containerIdToShimPid.Get(containerID) && currentPID != 0 { + if seen[currentPID] { + break + } + seen[currentPID] = true + + if p.processTree.Has(currentPID) { + parent := p.processTree.Get(currentPID) + parentCopy := parent + parentCopy.Children = []apitypes.Process{result} + result = parentCopy + currentPID = parent.PPID + } else { + break + } + } + + return result, nil +} + +// ReportEvent handles process execution events from the system. +// It specifically processes execve events to track new process creations +// and updates the process tree accordingly. +func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sEvent) { + if eventType != utils.ExecveEventType { + return + } + + execEvent, ok := event.(*tracerexectype.Event) + if !ok { + return + } + + process := apitypes.Process{ + PID: uint32(execEvent.Pid), + PPID: uint32(execEvent.Ppid), + Comm: execEvent.Comm, + Uid: &execEvent.Uid, + Gid: &execEvent.Gid, + Hardlink: execEvent.ExePath, + UpperLayer: &execEvent.UpperLayer, + Path: execEvent.ExePath, + Cwd: execEvent.Cwd, + Pcomm: execEvent.Pcomm, + Cmdline: strings.Join(execEvent.Args, " "), + Children: []apitypes.Process{}, + } + + p.addProcess(process) +} + +// startCleanupRoutine starts a goroutine that periodically runs the cleanup +// function to remove dead processes from the process tree. It continues until +// the context is cancelled. +// TODO: Register eBPF tracer to get process exit events and remove dead processes immediately. +func (p *ProcessManager) startCleanupRoutine(ctx context.Context) { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.cleanup() + case <-ctx.Done(): + return + } + } +} + +// cleanup removes dead processes from the process tree by checking if each +// process in the tree is still alive in the system. +func (p *ProcessManager) cleanup() { + deadPids := make(map[uint32]bool) + p.processTree.Range(func(pid uint32, _ apitypes.Process) bool { + if !isProcessAlive(int(pid)) { + deadPids[pid] = true + } + return true + }) + + for pid := range deadPids { + logger.L().Debug("Removing dead process", helpers.Int("pid", int(pid))) + p.removeProcess(pid) + } +} + +// getProcessFromProc retrieves process information from the /proc filesystem +// for a given PID. It collects various process attributes such as command line, +// working directory, and user/group IDs. 
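+// Note: status.UIDs[1] and status.GIDs[1] correspond to the second column of the Uid:/Gid:
+// rows in /proc/<pid>/status, i.e. the effective UID and GID of the process.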
+func getProcessFromProc(pid int) (apitypes.Process, error) { + proc, err := procfs.NewProc(pid) + if err != nil { + return apitypes.Process{}, fmt.Errorf("failed to get process info: %v", err) + } + + stat, err := utils.GetProcessStat(pid) + if err != nil { + return apitypes.Process{}, fmt.Errorf("failed to get process stat: %v", err) + } + + var uid, gid uint32 + if status, err := proc.NewStatus(); err == nil { + if len(status.UIDs) > 1 { + uid = uint32(status.UIDs[1]) + } + if len(status.GIDs) > 1 { + gid = uint32(status.GIDs[1]) + } + } + + cmdline, _ := proc.CmdLine() + if len(cmdline) == 0 { + cmdline = []string{stat.Comm} + } + + cwd, _ := proc.Cwd() + path, _ := proc.Executable() + pcomm := func() string { + if stat.PPID <= 0 { + return "" + } + + parentProc, err := procfs.NewProc(stat.PPID) + if err != nil { + return "" + } + + parentStat, err := parentProc.Stat() + if err != nil { + return "" + } + + return parentStat.Comm + }() + + return apitypes.Process{ + PID: uint32(pid), + PPID: uint32(stat.PPID), + Comm: stat.Comm, + Pcomm: pcomm, + Uid: &uid, + Gid: &gid, + Cmdline: strings.Join(cmdline, " "), + Cwd: cwd, + Path: path, + Children: []apitypes.Process{}, + }, nil +} + +// isProcessAlive checks if a process with the given PID is still running +// by attempting to read its information from the /proc filesystem. +func isProcessAlive(pid int) bool { + proc, err := procfs.NewProc(pid) + if err != nil { + return false + } + _, err = proc.Stat() + return err == nil +} diff --git a/pkg/processmanager/v1/process_manager_test.go b/pkg/processmanager/v1/process_manager_test.go new file mode 100644 index 00000000..6405763f --- /dev/null +++ b/pkg/processmanager/v1/process_manager_test.go @@ -0,0 +1,1046 @@ +package processmanager + +import ( + "context" + "fmt" + "sync" + "testing" + + apitypes "github.com/armosec/armoapi-go/armotypes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + tracerexectype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/exec/types" + "github.com/inspektor-gadget/inspektor-gadget/pkg/types" + "github.com/kubescape/node-agent/pkg/utils" +) + +// Helper function type definition +type mockProcessAdder func(pid int, ppid uint32, comm string) + +// Updated setup function with correct return types +func setupTestProcessManager(t *testing.T) (*ProcessManager, mockProcessAdder) { + ctx, cancel := context.WithCancel(context.Background()) + pm := CreateProcessManager(ctx) + + // Create process mock map + mockProcesses := make(map[int]apitypes.Process) + + // Store original function + originalGetProcessFromProc := pm.getProcessFromProc + + // Replace with mock version + pm.getProcessFromProc = func(pid int) (apitypes.Process, error) { + if proc, exists := mockProcesses[pid]; exists { + return proc, nil + } + return apitypes.Process{}, fmt.Errorf("mock process not found: %d", pid) + } + + // Set up cleanup + t.Cleanup(func() { + cancel() + pm.getProcessFromProc = originalGetProcessFromProc + }) + + // Return the process manager and the mock process adder function + return pm, func(pid int, ppid uint32, comm string) { + uid := uint32(1000) + gid := uint32(1000) + mockProcesses[pid] = apitypes.Process{ + PID: uint32(pid), + PPID: ppid, + Comm: comm, + Cmdline: comm, + Uid: &uid, + Gid: &gid, + } + } +} + +func TestProcessManagerBasics(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := 
"test-container-1" + shimPID := uint32(999) + containerPID := uint32(1000) + + // Add mock container process with shim as parent + addMockProcess(int(containerPID), shimPID, "container-main") + + // Register container + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + // Verify shim was recorded + assert.True(t, pm.containerIdToShimPid.Has(containerID)) + assert.Equal(t, shimPID, pm.containerIdToShimPid.Get(containerID)) + + // Verify container process was added + containerProc, exists := pm.processTree.Load(containerPID) + assert.True(t, exists) + assert.Equal(t, shimPID, containerProc.PPID) +} + +func TestProcessTracking(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container-1" + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") + + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + testCases := []struct { + name string + event tracerexectype.Event + verify func(t *testing.T, pm *ProcessManager) + }{ + { + name: "Container child process", + event: tracerexectype.Event{ + Pid: 1001, + Ppid: containerPID, + Comm: "nginx", + Args: []string{"nginx", "-g", "daemon off;"}, + }, + verify: func(t *testing.T, pm *ProcessManager) { + proc, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Equal(t, containerPID, proc.PPID) + assert.Equal(t, "nginx", proc.Comm) + }, + }, + { + name: "Exec process (direct child of shim)", + event: tracerexectype.Event{ + Pid: 1002, + Ppid: shimPID, + Comm: "bash", + Args: []string{"bash"}, + }, + verify: func(t *testing.T, pm *ProcessManager) { + proc, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, shimPID, proc.PPID) + assert.Equal(t, "bash", proc.Comm) + }, + }, + { + name: "Nested process", + event: tracerexectype.Event{ + Pid: 1003, + Ppid: 1001, + Comm: "nginx-worker", + Args: []string{"nginx", "worker process"}, + }, + verify: func(t *testing.T, pm *ProcessManager) { + proc, exists := pm.processTree.Load(1003) + require.True(t, exists) + assert.Equal(t, uint32(1001), proc.PPID) + + parent, exists := pm.processTree.Load(1001) + require.True(t, exists) + hasChild := false + for _, child := range parent.Children { + if child.PID == 1003 { + hasChild = true + break + } + } + assert.True(t, hasChild) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pm.ReportEvent(utils.ExecveEventType, &tc.event) + tc.verify(t, pm) + }) + } +} + +func TestProcessRemoval(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container-1" + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") + + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ 
+ ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + // Create a process tree + processes := []struct { + pid uint32 + ppid uint32 + comm string + }{ + {1001, containerPID, "parent"}, + {1002, 1001, "child1"}, + {1003, 1002, "grandchild1"}, + {1004, 1002, "grandchild2"}, + } + + // Add processes + for _, proc := range processes { + event := &tracerexectype.Event{ + Pid: proc.pid, + Ppid: proc.ppid, + Comm: proc.comm, + } + pm.ReportEvent(utils.ExecveEventType, event) + } + + // Verify initial structure + for _, proc := range processes { + assert.True(t, pm.processTree.Has(proc.pid)) + } + + // Remove middle process and verify tree reorganization + pm.removeProcess(1002) + + // Verify process was removed + assert.False(t, pm.processTree.Has(1002)) + + // Verify children were reassigned to parent + parent, exists := pm.processTree.Load(1001) + require.True(t, exists) + + // Should now have both grandchildren + childPIDs := make(map[uint32]bool) + for _, child := range parent.Children { + childPIDs[child.PID] = true + } + assert.True(t, childPIDs[1003]) + assert.True(t, childPIDs[1004]) + + // Verify grandchildren's PPID was updated + for _, pid := range []uint32{1003, 1004} { + proc, exists := pm.processTree.Load(pid) + require.True(t, exists) + assert.Equal(t, uint32(1001), proc.PPID) + } +} + +func TestContainerRemoval(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container-1" + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") + + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + // Create various processes under the container + processes := []struct { + pid uint32 + ppid uint32 + comm string + }{ + {containerPID, shimPID, "container-main"}, + {1001, containerPID, "app"}, + {1002, 1001, "worker"}, + {1003, shimPID, "exec"}, // direct child of shim + } + + for _, proc := range processes { + event := &tracerexectype.Event{ + Pid: proc.pid, + Ppid: proc.ppid, + Comm: proc.comm, + } + pm.ReportEvent(utils.ExecveEventType, event) + } + + // Remove container + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeRemoveContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + // Verify all processes were removed + for _, proc := range processes { + assert.False(t, pm.processTree.Has(proc.pid)) + } + + // Verify container was removed from mapping + assert.False(t, pm.containerIdToShimPid.Has(containerID)) +} + +func TestMultipleContainers(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containers := []struct { + id string + shimPID uint32 + containerPID uint32 + }{ + {"container-1", 999, 1000}, + {"container-2", 1998, 2000}, + } + + // Add containers + for _, c := range containers { + addMockProcess(int(c.containerPID), c.shimPID, fmt.Sprintf("container-%s", c.id)) + + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + 
BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: c.id, + }, + }, + Pid: c.containerPID, + }, + }) + + // Add some processes to each container + event1 := &tracerexectype.Event{ + Pid: c.containerPID + 1, + Ppid: c.containerPID, + Comm: "process-1", + } + event2 := &tracerexectype.Event{ + Pid: c.containerPID + 2, + Ppid: c.shimPID, + Comm: "exec-process", + } + + pm.ReportEvent(utils.ExecveEventType, event1) + pm.ReportEvent(utils.ExecveEventType, event2) + } + + // Verify each container's processes + for _, c := range containers { + // Check container process + proc, exists := pm.processTree.Load(c.containerPID) + require.True(t, exists) + assert.Equal(t, c.shimPID, proc.PPID) + + // Check child process + childProc, exists := pm.processTree.Load(c.containerPID + 1) + require.True(t, exists) + assert.Equal(t, c.containerPID, childProc.PPID) + + // Check exec process + execProc, exists := pm.processTree.Load(c.containerPID + 2) + require.True(t, exists) + assert.Equal(t, c.shimPID, execProc.PPID) + } + + // Remove first container + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeRemoveContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containers[0].id, + }, + }, + Pid: containers[0].containerPID, + }, + }) + + // Verify first container's processes are gone + assert.False(t, pm.processTree.Has(containers[0].containerPID)) + assert.False(t, pm.processTree.Has(containers[0].containerPID+1)) + assert.False(t, pm.processTree.Has(containers[0].containerPID+2)) + + // Verify second container's processes remain + assert.True(t, pm.processTree.Has(containers[1].containerPID)) + assert.True(t, pm.processTree.Has(containers[1].containerPID+1)) + assert.True(t, pm.processTree.Has(containers[1].containerPID+2)) +} + +func TestErrorCases(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + t.Run("get non-existent process tree", func(t *testing.T) { + _, err := pm.GetProcessTreeForPID("non-existent", 1000) + assert.Error(t, err) + }) + + t.Run("process with non-existent parent", func(t *testing.T) { + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") + + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + // Add process with non-existent parent + event := &tracerexectype.Event{ + Pid: 2000, + Ppid: 1500, // Non-existent PPID + Comm: "orphan", + } + pm.ReportEvent(utils.ExecveEventType, event) + + // Process should still be added + assert.True(t, pm.processTree.Has(2000)) + }) +} + +func TestRaceConditions(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + // Setup container + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: 
containerPID, + }, + }) + + processCount := 100 + var mu sync.Mutex + processStates := make(map[uint32]struct { + added bool + removed bool + }) + + // Pre-populate process states + for i := 0; i < processCount; i++ { + pid := uint32(2000 + i) + processStates[pid] = struct { + added bool + removed bool + }{false, false} + } + + // Channel to signal between goroutines + removeDone := make(chan bool) + addDone := make(chan bool) + + // Goroutine to remove processes (run first) + go func() { + for i := 0; i < processCount; i++ { + if i%2 == 0 { + pid := uint32(2000 + i) + mu.Lock() + if state, exists := processStates[pid]; exists { + state.removed = true + processStates[pid] = state + } + mu.Unlock() + pm.removeProcess(pid) + } + } + removeDone <- true + }() + + // Wait for removals to complete before starting additions + <-removeDone + + // Goroutine to add processes + go func() { + for i := 0; i < processCount; i++ { + pid := uint32(2000 + i) + // Only add if not marked for removal + mu.Lock() + state := processStates[pid] + if !state.removed { + event := &tracerexectype.Event{ + Pid: pid, + Ppid: shimPID, + Comm: fmt.Sprintf("process-%d", i), + } + state.added = true + processStates[pid] = state + mu.Unlock() + pm.ReportEvent(utils.ExecveEventType, event) + } else { + mu.Unlock() + } + } + addDone <- true + }() + + // Wait for additions to complete + <-addDone + + // Verify final state + remainingCount := 0 + pm.processTree.Range(func(pid uint32, process apitypes.Process) bool { + if pid >= 2000 && pid < 2000+uint32(processCount) { + mu.Lock() + state := processStates[pid] + mu.Unlock() + + if state.removed { + t.Errorf("Process %d exists but was marked for removal", pid) + } + if !state.added { + t.Errorf("Process %d exists but was not marked as added", pid) + } + remainingCount++ + } + return true + }) + + // Verify all processes marked as removed are actually gone + mu.Lock() + for pid, state := range processStates { + if state.removed { + if pm.processTree.Has(pid) { + t.Errorf("Process %d was marked for removal but still exists", pid) + } + } else if state.added { + if !pm.processTree.Has(pid) { + t.Errorf("Process %d was marked as added but doesn't exist", pid) + } + } + } + mu.Unlock() + + // We expect exactly half of the processes to remain (odd-numbered ones) + expectedCount := processCount / 2 + assert.Equal(t, expectedCount, remainingCount, + "Expected exactly %d processes, got %d", expectedCount, remainingCount) + + // Verify all remaining processes have correct parent + pm.processTree.Range(func(pid uint32, process apitypes.Process) bool { + if pid >= 2000 && pid < 2000+uint32(processCount) { + assert.Equal(t, shimPID, process.PPID, + "Process %d should have shim as parent", pid) + } + return true + }) +} + +func TestDuplicateProcessHandling(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + // Setup container + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + t.Run("update process with same parent", func(t *testing.T) { + // First add a parent process + parentEvent := &tracerexectype.Event{ + Pid: 1001, + Ppid: containerPID, + Comm: 
"parent-process", + Args: []string{"parent-process", "--initial"}, + } + pm.ReportEvent(utils.ExecveEventType, parentEvent) + + // Add child process + childEvent := &tracerexectype.Event{ + Pid: 1002, + Ppid: 1001, + Comm: "child-process", + Args: []string{"child-process", "--initial"}, + } + pm.ReportEvent(utils.ExecveEventType, childEvent) + + // Verify initial state + parent, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Equal(t, "parent-process", parent.Comm) + assert.Equal(t, "parent-process --initial", parent.Cmdline) + assert.Len(t, parent.Children, 1) + assert.Equal(t, uint32(1002), parent.Children[0].PID) + + // Add same child process again with different arguments + updatedChildEvent := &tracerexectype.Event{ + Pid: 1002, + Ppid: 1001, + Comm: "child-process", + Args: []string{"child-process", "--updated"}, + } + pm.ReportEvent(utils.ExecveEventType, updatedChildEvent) + + // Verify the process was updated + updatedChild, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, "child-process --updated", updatedChild.Cmdline) + + // Verify parent's children list was updated + updatedParent, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Len(t, updatedParent.Children, 1) + assert.Equal(t, "child-process --updated", updatedParent.Children[0].Cmdline) + }) + + t.Run("update process with different parent", func(t *testing.T) { + // Move process to different parent + differentParentEvent := &tracerexectype.Event{ + Pid: 1002, + Ppid: containerPID, + Comm: "child-process", + Args: []string{"child-process", "--new-parent"}, + } + pm.ReportEvent(utils.ExecveEventType, differentParentEvent) + + // Verify process was updated with new parent + movedChild, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, containerPID, movedChild.PPID) + assert.Equal(t, "child-process --new-parent", movedChild.Cmdline) + + // Verify old parent no longer has the child + oldParent, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Empty(t, oldParent.Children, "Old parent should have no children") + + // Verify new parent has the child + containerProcess, exists := pm.processTree.Load(containerPID) + require.True(t, exists) + hasChild := false + for _, child := range containerProcess.Children { + if child.PID == 1002 { + hasChild = true + assert.Equal(t, "child-process --new-parent", child.Cmdline) + } + } + assert.True(t, hasChild, "New parent should have the child") + }) +} + +func TestProcessReparenting(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + // Setup container + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + t.Run("reparent to nearest living ancestor", func(t *testing.T) { + // Create a chain of processes: + // shim -> grandparent -> parent -> child + + // Create grandparent process + grandparentPID := uint32(2000) + grandparentEvent := &tracerexectype.Event{ + Pid: grandparentPID, + Ppid: shimPID, + Comm: "grandparent", + Args: []string{"grandparent"}, + } + pm.ReportEvent(utils.ExecveEventType, grandparentEvent) + + // Create 
parent process + parentPID := uint32(2001) + parentEvent := &tracerexectype.Event{ + Pid: parentPID, + Ppid: grandparentPID, + Comm: "parent", + Args: []string{"parent"}, + } + pm.ReportEvent(utils.ExecveEventType, parentEvent) + + // Create child process + childPID := uint32(2002) + childEvent := &tracerexectype.Event{ + Pid: childPID, + Ppid: parentPID, + Comm: "child", + Args: []string{"child"}, + } + pm.ReportEvent(utils.ExecveEventType, childEvent) + + // Verify initial hierarchy + child, exists := pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, parentPID, child.PPID) + + parent, exists := pm.processTree.Load(parentPID) + require.True(t, exists) + assert.Equal(t, grandparentPID, parent.PPID) + + // When parent dies, child should be reparented to grandparent + pm.removeProcess(parentPID) + + // Verify child was reparented to grandparent + child, exists = pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, grandparentPID, child.PPID, "Child should be reparented to grandparent") + + // Verify grandparent has the child in its children list + grandparent, exists := pm.processTree.Load(grandparentPID) + require.True(t, exists) + hasChild := false + for _, c := range grandparent.Children { + if c.PID == childPID { + hasChild = true + break + } + } + assert.True(t, hasChild, "Grandparent should have the reparented child") + + // Now if grandparent dies too, child should be reparented to shim + pm.removeProcess(grandparentPID) + + child, exists = pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, shimPID, child.PPID, "Child should be reparented to shim when grandparent dies") + }) + + t.Run("reparent multiple children", func(t *testing.T) { + // Create a parent with multiple children + parentPID := uint32(3000) + parentEvent := &tracerexectype.Event{ + Pid: parentPID, + Ppid: shimPID, + Comm: "parent", + Args: []string{"parent"}, + } + pm.ReportEvent(utils.ExecveEventType, parentEvent) + + // Create several children + childPIDs := []uint32{3001, 3002, 3003} + for _, pid := range childPIDs { + childEvent := &tracerexectype.Event{ + Pid: pid, + Ppid: parentPID, + Comm: fmt.Sprintf("child-%d", pid), + Args: []string{"child"}, + } + pm.ReportEvent(utils.ExecveEventType, childEvent) + } + + // Create a subprocess under one of the children + grandchildPID := uint32(3004) + grandchildEvent := &tracerexectype.Event{ + Pid: grandchildPID, + Ppid: childPIDs[0], + Comm: "grandchild", + Args: []string{"grandchild"}, + } + pm.ReportEvent(utils.ExecveEventType, grandchildEvent) + + // When parent dies, all direct children should be reparented to shim + pm.removeProcess(parentPID) + + // Verify all children were reparented to shim + for _, childPID := range childPIDs { + child, exists := pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, shimPID, child.PPID, "Child should be reparented to shim") + } + + // When first child dies, its grandchild should be reparented to shim too + pm.removeProcess(childPIDs[0]) + + grandchild, exists := pm.processTree.Load(grandchildPID) + require.True(t, exists) + assert.Equal(t, shimPID, grandchild.PPID, "Grandchild should be reparented to shim") + }) +} + +func TestRemoveProcessesUnderShim(t *testing.T) { + tests := []struct { + name string + initialTree map[uint32]apitypes.Process + shimPID uint32 + expectedTree map[uint32]apitypes.Process + description string + }{ + { + name: "simple_process_tree", + initialTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: 
"shim", Children: []apitypes.Process{}}, // shim process + 200: {PID: 200, PPID: 100, Comm: "parent", Children: []apitypes.Process{}}, // direct child of shim + 201: {PID: 201, PPID: 200, Comm: "child1", Children: []apitypes.Process{}}, // child of parent + 202: {PID: 202, PPID: 200, Comm: "child2", Children: []apitypes.Process{}}, // another child of parent + }, + shimPID: 100, + expectedTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // only shim remains + }, + description: "Should remove all processes under shim including children of children", + }, + { + name: "empty_tree", + initialTree: map[uint32]apitypes.Process{}, + shimPID: 100, + expectedTree: map[uint32]apitypes.Process{}, + description: "Should handle empty process tree gracefully", + }, + { + name: "orphaned_processes", + initialTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // shim process + 200: {PID: 200, PPID: 100, Comm: "parent", Children: []apitypes.Process{}}, // direct child of shim + 201: {PID: 201, PPID: 999, Comm: "orphan", Children: []apitypes.Process{}}, // orphaned process (parent doesn't exist) + }, + shimPID: 100, + expectedTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // shim remains + 201: {PID: 201, PPID: 999, Comm: "orphan", Children: []apitypes.Process{}}, // orphan unaffected + }, + description: "Should handle orphaned processes correctly", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create process manager with test data + pm := &ProcessManager{} + + // Populate initial process tree + for pid, process := range tc.initialTree { + pm.processTree.Set(pid, process) + } + + // Call the function under test + pm.removeProcessesUnderShim(tc.shimPID) + + // Verify results + assert.Equal(t, len(tc.expectedTree), len(pm.processTree.Keys()), + "Process tree size mismatch after removal") + + // Check each expected process + for pid, expectedProcess := range tc.expectedTree { + actualProcess, exists := pm.processTree.Load(pid) + assert.True(t, exists, "Expected process %d not found in tree", pid) + assert.Equal(t, expectedProcess, actualProcess, + "Process %d details don't match expected values", pid) + } + + // Verify no unexpected processes remain + pm.processTree.Range(func(pid uint32, process apitypes.Process) bool { + _, shouldExist := tc.expectedTree[pid] + assert.True(t, shouldExist, + "Unexpected process %d found in tree", pid) + return true + }) + }) + } +} + +func TestIsDescendantOfShim(t *testing.T) { + tests := []struct { + name string + processes map[uint32]apitypes.Process + shimPIDs map[uint32]struct{} + pid uint32 + ppid uint32 + expected bool + description string + }{ + { + name: "direct_child_of_shim", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 100, Comm: "child"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 200, + ppid: 100, + expected: true, + description: "Process is a direct child of shim", + }, + { + name: "indirect_descendant", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 100, Comm: "parent"}, + 300: {PID: 300, PPID: 200, Comm: "child"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 300, + ppid: 200, + expected: true, + description: "Process is an indirect descendant of shim", + }, + { + name: "not_a_descendant", + 
processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 2, Comm: "unrelated"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 200, + ppid: 2, + expected: false, + description: "Process is not a descendant of any shim", + }, + { + name: "circular_reference", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 300, Comm: "circular1"}, + 300: {PID: 300, PPID: 200, Comm: "circular2"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 200, + ppid: 300, + expected: false, + description: "Process is part of a circular reference", + }, + { + name: "process_chain_exceeds_max_depth", + processes: func() map[uint32]apitypes.Process { + // Create a chain where the target process is maxTreeDepth + 1 steps away from any shim + procs := map[uint32]apitypes.Process{ + 1: {PID: 1, PPID: 0, Comm: "init"}, // init process + 2: {PID: 2, PPID: 1, Comm: "shim"}, // shim process + } + // Create a chain starting far from the shim + currentPPID := uint32(100) // Start with a different base to avoid conflicts + targetPID := uint32(100 + maxTreeDepth + 1) + + // Build the chain backwards from target to base + for pid := targetPID; pid > currentPPID; pid-- { + procs[pid] = apitypes.Process{ + PID: pid, + PPID: pid - 1, + Comm: fmt.Sprintf("process-%d", pid), + } + } + // Add the base process that's not connected to shim + procs[currentPPID] = apitypes.Process{ + PID: currentPPID, + PPID: currentPPID - 1, + Comm: fmt.Sprintf("process-%d", currentPPID), + } + return procs + }(), + shimPIDs: map[uint32]struct{}{ + 2: {}, // Shim PID + }, + pid: uint32(100 + maxTreeDepth + 1), // Target process at the end of chain + ppid: uint32(100 + maxTreeDepth), // Its immediate parent + expected: false, + description: "Process chain exceeds maximum allowed depth", + }, + { + name: "multiple_shims", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim1"}, + 101: {PID: 101, PPID: 1, Comm: "shim2"}, + 200: {PID: 200, PPID: 100, Comm: "child1"}, + 201: {PID: 201, PPID: 101, Comm: "child2"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + 101: {}, + }, + pid: 200, + ppid: 100, + expected: true, + description: "Multiple shims in the system", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + pm := &ProcessManager{} + result := pm.isDescendantOfShim(tc.pid, tc.ppid, tc.shimPIDs, tc.processes) + assert.Equal(t, tc.expected, result, tc.description) + }) + } +} diff --git a/pkg/ruleengine/v1/r0004_unexpected_capability_used.go b/pkg/ruleengine/v1/r0004_unexpected_capability_used.go index 4f3b2571..c4d8e906 100644 --- a/pkg/ruleengine/v1/r0004_unexpected_capability_used.go +++ b/pkg/ruleengine/v1/r0004_unexpected_capability_used.go @@ -3,6 +3,7 @@ package ruleengine import ( "fmt" + "github.com/goradd/maps" "github.com/kubescape/node-agent/pkg/objectcache" "github.com/kubescape/node-agent/pkg/ruleengine" "github.com/kubescape/node-agent/pkg/utils" @@ -34,6 +35,7 @@ var _ ruleengine.RuleEvaluator = (*R0004UnexpectedCapabilityUsed)(nil) type R0004UnexpectedCapabilityUsed struct { BaseRule + alertedCapabilities maps.SafeMap[string, bool] } func CreateRuleR0004UnexpectedCapabilityUsed() *R0004UnexpectedCapabilityUsed { @@ -76,6 +78,10 @@ func (rule *R0004UnexpectedCapabilityUsed) ProcessEvent(eventType utils.EventTyp return nil } + if rule.alertedCapabilities.Has(capEvent.CapName) { + return nil + } + for _, capability := range 
appProfileCapabilitiesList.Capabilities { if capEvent.CapName == capability { return nil @@ -112,6 +118,8 @@ func (rule *R0004UnexpectedCapabilityUsed) ProcessEvent(eventType utils.EventTyp RuleID: rule.ID(), } + rule.alertedCapabilities.Set(capEvent.CapName, true) + return &ruleFailure } diff --git a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go index b99e43b3..6e74dad3 100644 --- a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go +++ b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go @@ -2,6 +2,7 @@ package ruleengine import ( "fmt" + "path/filepath" "strings" events "github.com/kubescape/node-agent/pkg/ebpf/events" @@ -99,15 +100,7 @@ func (rule *R0010UnexpectedSensitiveFileAccess) ProcessEvent(eventType utils.Eve return nil } - isSensitive := false - for _, path := range rule.additionalPaths { - if strings.HasPrefix(openEvent.FullPath, path) { - isSensitive = true - break - } - } - - if !isSensitive { + if !isSensitivePath(openEvent.FullPath, rule.additionalPaths) { return nil } @@ -157,3 +150,30 @@ func (rule *R0010UnexpectedSensitiveFileAccess) Requirements() ruleengine.RuleSp EventTypes: R0010UnexpectedSensitiveFileAccessRuleDescriptor.Requirements.RequiredEventTypes(), } } + +// isSensitivePath checks if a given path matches or is within any sensitive paths +func isSensitivePath(fullPath string, paths []string) bool { + // Clean the path to handle "..", "//", etc. + fullPath = filepath.Clean(fullPath) + + for _, sensitivePath := range paths { + sensitivePath = filepath.Clean(sensitivePath) + + // Check if the path exactly matches + if fullPath == sensitivePath { + return true + } + + // Check if the path is a directory that contains sensitive files + if strings.HasPrefix(sensitivePath, fullPath+"/") { + return true + } + + // Check if the path is within a sensitive directory + if strings.HasPrefix(fullPath, sensitivePath+"/") { + return true + } + } + + return false +} diff --git a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go index e35abe05..e45d7f2b 100644 --- a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go +++ b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go @@ -3,25 +3,14 @@ package ruleengine import ( "testing" - "github.com/kubescape/node-agent/pkg/objectcache" - "github.com/kubescape/node-agent/pkg/utils" - - "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" - traceropentype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/open/types" eventtypes "github.com/inspektor-gadget/inspektor-gadget/pkg/types" + "github.com/kubescape/node-agent/pkg/utils" + "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" ) -func TestR0010UnexpectedSensitiveFileAccess(t *testing.T) { - // Create a new rule - r := CreateRuleR0010UnexpectedSensitiveFileAccess() - // Assert r is not nil - if r == nil { - t.Errorf("Expected r to not be nil") - } - - // Create a file access event - e := &traceropentype.Event{ +func createTestEvent(path string, flags []string) *traceropentype.Event { + return &traceropentype.Event{ Event: eventtypes.Event{ CommonData: eventtypes.CommonData{ K8s: eventtypes.K8sMetadata{ @@ -31,103 +20,127 @@ func TestR0010UnexpectedSensitiveFileAccess(t *testing.T) { }, }, }, - Path: "/test", - FullPath: "/test", - Flags: []string{"O_RDONLY"}, - } - - // Test with nil appProfileAccess - ruleResult := 
r.ProcessEvent(utils.OpenEventType, e, &objectcache.ObjectCacheMock{}) - if ruleResult != nil { - t.Errorf("Expected ruleResult to not be nil since no appProfile") + Path: path, + FullPath: path, + Flags: flags, } +} - // Test with whitelisted file - objCache := RuleObjectCacheMock{} - profile := objCache.ApplicationProfileCache().GetApplicationProfile("test") - if profile == nil { - profile = &v1beta1.ApplicationProfile{ - Spec: v1beta1.ApplicationProfileSpec{ - Containers: []v1beta1.ApplicationProfileContainer{ - { - Name: "test", - Opens: []v1beta1.OpenCalls{ - { - Path: "/test", - Flags: []string{"O_RDONLY"}, - }, - }, - }, - }, - }, +func createTestProfile(containerName string, paths []string, flags []string) *v1beta1.ApplicationProfile { + opens := make([]v1beta1.OpenCalls, len(paths)) + for i, path := range paths { + opens[i] = v1beta1.OpenCalls{ + Path: path, + Flags: flags, } - objCache.SetApplicationProfile(profile) - } - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and not sensitive") - } - - // Test with non whitelisted file, but not sensitive - e.FullPath = "/var/test1" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is not whitelisted and not sensitive") - } - - // Test with sensitive file that is whitelisted - e.FullPath = "/etc/shadow" - profile.Spec.Containers[0].Opens[0].Path = "/etc/shadow" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and sensitive") - } - - // Test with sensitive file, but not whitelisted - e.FullPath = "/etc/shadow" - profile.Spec.Containers[0].Opens[0].Path = "/test" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult == nil { - t.Errorf("Expected ruleResult to not be nil since file is not whitelisted and sensitive") - } - - // Test with sensitive file that originates from additionalPaths parameter - e.FullPath = "/etc/blabla" - profile.Spec.Containers[0].Opens[0].Path = "/test" - additionalPaths := []interface{}{"/etc/blabla"} - r.SetParameters(map[string]interface{}{"additionalPaths": additionalPaths}) - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult == nil { - t.Errorf("Expected ruleResult to not be nil since file is not whitelisted and sensitive") } - e.FullPath = "/tmp/blabla" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and not sensitive") - } - - profile = &v1beta1.ApplicationProfile{ + return &v1beta1.ApplicationProfile{ Spec: v1beta1.ApplicationProfileSpec{ Containers: []v1beta1.ApplicationProfileContainer{ { - Name: "test", - Opens: []v1beta1.OpenCalls{ - { - Path: "/etc/\u22ef", - Flags: []string{"O_RDONLY"}, - }, - }, + Name: containerName, + Opens: opens, }, }, }, } - objCache.SetApplicationProfile(profile) +} - e.FullPath = "/etc/blabla" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and not sensitive") +func TestR0010UnexpectedSensitiveFileAccess(t *testing.T) { + tests := []struct { + name string + event *traceropentype.Event + profile *v1beta1.ApplicationProfile + additionalPaths []interface{} + expectAlert bool + description string + }{ + { 
+ name: "No application profile", + event: createTestEvent("/test", []string{"O_RDONLY"}), + profile: nil, + expectAlert: false, + description: "Should not alert when no application profile is present", + }, + { + name: "Whitelisted non-sensitive file", + event: createTestEvent("/test", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert for whitelisted non-sensitive file", + }, + { + name: "Non-whitelisted non-sensitive file", + event: createTestEvent("/var/test1", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert for non-whitelisted non-sensitive file", + }, + { + name: "Whitelisted sensitive file", + event: createTestEvent("/etc/shadow", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/etc/shadow"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert for whitelisted sensitive file", + }, + { + name: "Non-whitelisted sensitive file", + event: createTestEvent("/etc/shadow", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: true, + description: "Should alert for non-whitelisted sensitive file", + }, + { + name: "Additional sensitive path", + event: createTestEvent("/etc/custom-sensitive", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + additionalPaths: []interface{}{"/etc/custom-sensitive"}, + expectAlert: true, + description: "Should alert for non-whitelisted file in additional sensitive paths", + }, + { + name: "Wildcard path match", + event: createTestEvent("/etc/blabla", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/etc/\u22ef"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert when path matches wildcard pattern", + }, + { + name: "Path traversal attempt", + event: createTestEvent("/etc/shadow/../passwd", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: true, + description: "Should alert for path traversal attempts", + }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := CreateRuleR0010UnexpectedSensitiveFileAccess() + if rule == nil { + t.Fatal("Expected rule to not be nil") + } + + objCache := &RuleObjectCacheMock{} + if tt.profile != nil { + objCache.SetApplicationProfile(tt.profile) + } + + if tt.additionalPaths != nil { + rule.SetParameters(map[string]interface{}{ + "additionalPaths": tt.additionalPaths, + }) + } + + result := rule.ProcessEvent(utils.OpenEventType, tt.event, objCache) + + if tt.expectAlert && result == nil { + t.Errorf("%s: expected alert but got none", tt.description) + } + if !tt.expectAlert && result != nil { + t.Errorf("%s: expected no alert but got one", tt.description) + } + }) + } } diff --git a/pkg/rulemanager/v1/rule_manager.go b/pkg/rulemanager/v1/rule_manager.go index d0c0a762..d06dda14 100644 --- a/pkg/rulemanager/v1/rule_manager.go +++ b/pkg/rulemanager/v1/rule_manager.go @@ -10,6 +10,7 @@ import ( "github.com/kubescape/node-agent/pkg/config" "github.com/kubescape/node-agent/pkg/exporters" "github.com/kubescape/node-agent/pkg/k8sclient" + "github.com/kubescape/node-agent/pkg/processmanager" "github.com/kubescape/node-agent/pkg/ruleengine" "github.com/kubescape/node-agent/pkg/rulemanager" 
"github.com/kubescape/node-agent/pkg/utils" @@ -65,11 +66,12 @@ type RuleManager struct { containerIdToShimPid maps.SafeMap[string, uint32] containerIdToPid maps.SafeMap[string, uint32] enricher ruleenginetypes.Enricher + processManager processmanager.ProcessManagerClient } var _ rulemanager.RuleManagerClient = (*RuleManager)(nil) -func CreateRuleManager(ctx context.Context, cfg config.Config, k8sClient k8sclient.K8sClientInterface, ruleBindingCache bindingcache.RuleBindingCache, objectCache objectcache.ObjectCache, exporter exporters.Exporter, metrics metricsmanager.MetricsManager, nodeName string, clusterName string, enricher ruleenginetypes.Enricher) (*RuleManager, error) { +func CreateRuleManager(ctx context.Context, cfg config.Config, k8sClient k8sclient.K8sClientInterface, ruleBindingCache bindingcache.RuleBindingCache, objectCache objectcache.ObjectCache, exporter exporters.Exporter, metrics metricsmanager.MetricsManager, nodeName string, clusterName string, processManager processmanager.ProcessManagerClient, enricher ruleenginetypes.Enricher) (*RuleManager, error) { return &RuleManager{ cfg: cfg, ctx: ctx, @@ -83,6 +85,7 @@ func CreateRuleManager(ctx context.Context, cfg config.Config, k8sClient k8sclie nodeName: nodeName, clusterName: clusterName, enricher: enricher, + processManager: processManager, }, nil } @@ -356,8 +359,13 @@ func (rm *RuleManager) processEvent(eventType utils.EventType, event utils.K8sEv } } func (rm *RuleManager) enrichRuleFailure(ruleFailure ruleengine.RuleFailure) ruleengine.RuleFailure { - path, err := utils.GetPathFromPid(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID) - hostPath := "" + var err error + var path string + var hostPath string + if ruleFailure.GetRuntimeProcessDetails().ProcessTree.Path == "" { + path, err = utils.GetPathFromPid(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID) + } + if err != nil { if ruleFailure.GetRuntimeProcessDetails().ProcessTree.Path != "" { hostPath = filepath.Join("/proc", fmt.Sprintf("/%d/root/%s", rm.containerIdToPid.Get(ruleFailure.GetTriggerEvent().Runtime.ContainerID), ruleFailure.GetRuntimeProcessDetails().ProcessTree.Path)) @@ -395,52 +403,30 @@ func (rm *RuleManager) enrichRuleFailure(ruleFailure ruleengine.RuleFailure) rul ruleFailure.SetBaseRuntimeAlert(baseRuntimeAlert) runtimeProcessDetails := ruleFailure.GetRuntimeProcessDetails() - if runtimeProcessDetails.ProcessTree.Cmdline == "" { - commandLine, err := utils.GetCmdlineByPid(int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) - if err != nil { - runtimeProcessDetails.ProcessTree.Cmdline = "" - } else { - runtimeProcessDetails.ProcessTree.Cmdline = *commandLine - } - } - - if runtimeProcessDetails.ProcessTree.PPID == 0 { - parent, err := utils.GetProcessStat(int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) - if err != nil { - runtimeProcessDetails.ProcessTree.PPID = 0 - } else { - runtimeProcessDetails.ProcessTree.PPID = uint32(parent.PPID) - } - - if runtimeProcessDetails.ProcessTree.Pcomm == "" { - if err == nil { - runtimeProcessDetails.ProcessTree.Pcomm = parent.Comm - } else { - runtimeProcessDetails.ProcessTree.Pcomm = "" - } - } - } - - if runtimeProcessDetails.ProcessTree.PID == 0 { - runtimeProcessDetails.ProcessTree.PID = ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID - } - if runtimeProcessDetails.ProcessTree.Comm == "" { - comm, err := utils.GetCommFromPid(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID) + err = backoff.Retry(func() error { + tree, err := 
rm.processManager.GetProcessTreeForPID( + ruleFailure.GetRuntimeProcessDetails().ContainerID, + int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID), + ) if err != nil { - comm = "" + return err } - runtimeProcessDetails.ProcessTree.Comm = comm - } - - if runtimeProcessDetails.ProcessTree.Path == "" && path != "" { - runtimeProcessDetails.ProcessTree.Path = path - } - - if rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { - shimPid := rm.containerIdToShimPid.Get(ruleFailure.GetRuntimeProcessDetails().ContainerID) - tree, err := utils.CreateProcessTree(&runtimeProcessDetails.ProcessTree, shimPid) - if err == nil { + runtimeProcessDetails.ProcessTree = tree + return nil + }, backoff.NewExponentialBackOff( + backoff.WithInitialInterval(50*time.Millisecond), + backoff.WithMaxInterval(200*time.Millisecond), + backoff.WithMaxElapsedTime(500*time.Millisecond), + )) + + if err != nil && rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { + logger.L().Debug("RuleManager - failed to get process tree, trying to get process tree from shim", + helpers.Error(err), + helpers.String("container ID", ruleFailure.GetRuntimeProcessDetails().ContainerID)) + + if tree, err := utils.CreateProcessTree(&runtimeProcessDetails.ProcessTree, + rm.containerIdToShimPid.Get(ruleFailure.GetRuntimeProcessDetails().ContainerID)); err == nil { runtimeProcessDetails.ProcessTree = *tree } } diff --git a/tests/component_test.go b/tests/component_test.go index bf2cc4f6..f23e5796 100644 --- a/tests/component_test.go +++ b/tests/component_test.go @@ -19,6 +19,7 @@ import ( spdxv1beta1client "github.com/kubescape/storage/pkg/generated/clientset/versioned/typed/softwarecomposition/v1beta1" "github.com/kubescape/storage/pkg/registry/file/dynamicpathdetector" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -720,158 +721,407 @@ func getEndpoint(endpoints []v1beta1.HTTPEndpoint, endpoint v1beta1.HTTPEndpoint } -// func Test_10_DemoTest(t *testing.T) { -// start := time.Now() -// defer tearDownTest(t, start) - -// //testutils.IncreaseNodeAgentSniffingTime("2m") -// wl, err := testutils.NewTestWorkload("default", path.Join(utils.CurrentDir(), "resources/ping-app-role.yaml")) -// if err != nil { -// t.Errorf("Error creating role: %v", err) -// } - -// wl, err = testutils.NewTestWorkload("default", path.Join(utils.CurrentDir(), "resources/ping-app-role-binding.yaml")) -// if err != nil { -// t.Errorf("Error creating role binding: %v", err) -// } - -// wl, err = testutils.NewTestWorkload("default", path.Join(utils.CurrentDir(), "resources/ping-app-service.yaml")) -// if err != nil { -// t.Errorf("Error creating service: %v", err) -// } - -// wl, err = testutils.NewTestWorkload("default", path.Join(utils.CurrentDir(), "resources/ping-app.yaml")) -// if err != nil { -// t.Errorf("Error creating workload: %v", err) -// } -// assert.NoError(t, wl.WaitForReady(80)) -// _, _, err = wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4"}, "") -// err = wl.WaitForApplicationProfileCompletion(80) -// if err != nil { -// t.Errorf("Error waiting for application profile to be completed: %v", err) -// } -// // err = wl.WaitForNetworkNeighborhoodCompletion(80) -// // if err != nil { -// // t.Errorf("Error waiting for network neighborhood to be completed: %v", err) -// // } - -// // Do a ls command using command injection in the ping 
command -// _, _, err = wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4;ls"}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// // Do a cat command using command injection in the ping command -// _, _, err = wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4;cat /run/secrets/kubernetes.io/serviceaccount/token"}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// // Do an uname command using command injection in the ping command -// _, _, err = wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4;uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g'"}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// // Download kubectl -// _, _, err = wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4;curl -LO \"https://dl.k8s.io/release/$(curl -L -s https://dl.k8s.io/release/stable.txt)/bin/linux/amd64/kubectl\""}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// // Sleep for 10 seconds to wait for the kubectl download -// time.Sleep(10 * time.Second) - -// // Make kubectl executable -// _, _, err = wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4;chmod +x kubectl"}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// // Get the pods in the cluster -// output, _, err := wl.ExecIntoPod([]string{"sh", "-c", "ping 1.1.1.1 -c 4;./kubectl --server https://kubernetes.default --insecure-skip-tls-verify --token $(cat /run/secrets/kubernetes.io/serviceaccount/token) get pods"}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// // Check that the output contains the pod-ping-app pod -// assert.Contains(t, output, "ping-app", "Expected output to contain 'ping-app'") - -// // Get the alerts and check that the alerts are generated -// alerts, err := testutils.GetAlerts(wl.Namespace) -// if err != nil { -// t.Errorf("Error getting alerts: %v", err) -// } - -// // Validate that all alerts are signaled -// expectedAlerts := map[string]bool{ -// "Unexpected process launched": false, -// "Unexpected file access": false, -// "Kubernetes Client Executed": false, -// // "Exec from malicious source": false, -// "Exec Binary Not In Base Image": false, -// "Unexpected Service Account Token Access": false, -// // "Unexpected domain request": false, -// } - -// for _, alert := range alerts { -// ruleName, ruleOk := alert.Labels["rule_name"] -// if ruleOk { -// if _, exists := expectedAlerts[ruleName]; exists { -// expectedAlerts[ruleName] = true -// } -// } -// } - -// for ruleName, signaled := range expectedAlerts { -// if !signaled { -// t.Errorf("Expected alert '%s' was not signaled", ruleName) -// } -// } -// } - -// func Test_11_DuplicationTest(t *testing.T) { -// start := time.Now() -// defer tearDownTest(t, start) - -// ns := testutils.NewRandomNamespace() -// // wl, err := testutils.NewTestWorkload(ns.Name, path.Join(utils.CurrentDir(), "resources/deployment-multiple-containers.yaml")) -// wl, err := testutils.NewTestWorkload(ns.Name, path.Join(utils.CurrentDir(), "resources/ping-app.yaml")) -// if err != nil { -// t.Errorf("Error creating workload: %v", err) -// } -// assert.NoError(t, wl.WaitForReady(80)) - -// err = wl.WaitForApplicationProfileCompletion(80) -// if err != nil { -// t.Errorf("Error waiting for application profile to be completed: %v", err) -// } - -// // process launched from nginx 
container -// _, _, err = wl.ExecIntoPod([]string{"ls", "-a"}, "ping-app") -// if err != nil { -// t.Errorf("Error executing remote command: %v", err) -// } - -// time.Sleep(20 * time.Second) - -// alerts, err := testutils.GetAlerts(wl.Namespace) -// if err != nil { -// t.Errorf("Error getting alerts: %v", err) -// } - -// // Validate that unexpected process launched alert is signaled only once -// count := 0 -// for _, alert := range alerts { -// ruleName, ruleOk := alert.Labels["rule_name"] -// if ruleOk { -// if ruleName == "Unexpected process launched" { -// count++ -// } -// } -// } - -// testutils.AssertContains(t, alerts, "Unexpected process launched", "ls", "ping-app") - -// assert.Equal(t, 1, count, "Expected 1 alert of type 'Unexpected process launched' but got %d", count) -// } +func Test_12_MergingProfilesTest(t *testing.T) { + start := time.Now() + defer tearDownTest(t, start) + + // PHASE 1: Setup workload and initial profile + ns := testutils.NewRandomNamespace() + wl, err := testutils.NewTestWorkload(ns.Name, path.Join(utils.CurrentDir(), "resources/deployment-multiple-containers.yaml")) + require.NoError(t, err, "Failed to create workload") + require.NoError(t, wl.WaitForReady(80), "Workload failed to be ready") + require.NoError(t, wl.WaitForApplicationProfile(80, "ready"), "Application profile not ready") + + // Generate initial profile data + _, _, err = wl.ExecIntoPod([]string{"ls", "-l"}, "nginx") + require.NoError(t, err, "Failed to exec into nginx container") + _, _, err = wl.ExecIntoPod([]string{"wget", "ebpf.io", "-T", "2", "-t", "1"}, "server") + require.NoError(t, err, "Failed to exec into server container") + + require.NoError(t, wl.WaitForApplicationProfileCompletion(80), "Profile failed to complete") + time.Sleep(10 * time.Second) // Allow profile processing + + // Log initial profile state + initialProfile, err := wl.GetApplicationProfile() + require.NoError(t, err, "Failed to get initial profile") + initialProfileJSON, _ := json.Marshal(initialProfile) + t.Logf("Initial application profile:\n%s", string(initialProfileJSON)) + + // PHASE 2: Verify initial alerts + t.Log("Testing initial alert generation...") + wl.ExecIntoPod([]string{"ls", "-l"}, "nginx") // Expected: no alert + wl.ExecIntoPod([]string{"ls", "-l"}, "server") // Expected: alert + time.Sleep(30 * time.Second) // Wait for alert generation + + initialAlerts, err := testutils.GetAlerts(wl.Namespace) + require.NoError(t, err, "Failed to get initial alerts") + + // Record initial alert count + initialAlertCount := 0 + for _, alert := range initialAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected process launched" { + initialAlertCount++ + } + } + + testutils.AssertContains(t, initialAlerts, "Unexpected process launched", "ls", "server") + testutils.AssertNotContains(t, initialAlerts, "Unexpected process launched", "ls", "nginx") + + // PHASE 3: Apply user-managed profile + t.Log("Applying user-managed profile...") + // Create the user-managed profile + userProfile := &v1beta1.ApplicationProfile{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("ug-%s", initialProfile.Name), + Namespace: initialProfile.Namespace, + Annotations: map[string]string{ + "kubescape.io/managed-by": "User", + }, + }, + Spec: v1beta1.ApplicationProfileSpec{ + Architectures: []string{"amd64"}, + Containers: []v1beta1.ApplicationProfileContainer{ + { + Name: "nginx", + Execs: []v1beta1.ExecCalls{ + { + Path: "/usr/bin/ls", + Args: []string{"/usr/bin/ls", "-l"}, + }, + }, + 
SeccompProfile: v1beta1.SingleSeccompProfile{ + Spec: v1beta1.SingleSeccompProfileSpec{ + DefaultAction: "", + }, + }, + }, + { + Name: "server", + Execs: []v1beta1.ExecCalls{ + { + Path: "/bin/ls", + Args: []string{"/bin/ls", "-l"}, + }, + { + Path: "/bin/grpc_health_probe", + Args: []string{"-addr=:9555"}, + }, + }, + SeccompProfile: v1beta1.SingleSeccompProfile{ + Spec: v1beta1.SingleSeccompProfileSpec{ + DefaultAction: "", + }, + }, + }, + }, + }, + } + + // Log the profile we're about to create + userProfileJSON, err := json.MarshalIndent(userProfile, "", " ") + require.NoError(t, err, "Failed to marshal user profile") + t.Logf("Creating user profile:\n%s", string(userProfileJSON)) + + // Get k8s client + k8sClient := k8sinterface.NewKubernetesApi() + + // Create the user-managed profile + storageClient := spdxv1beta1client.NewForConfigOrDie(k8sClient.K8SConfig) + _, err = storageClient.ApplicationProfiles(ns.Name).Create(context.Background(), userProfile, metav1.CreateOptions{}) + require.NoError(t, err, "Failed to create user profile") + + // PHASE 4: Verify merged profile behavior + t.Log("Verifying merged profile behavior...") + time.Sleep(15 * time.Second) // Allow merge to complete + + // Test merged profile behavior + wl.ExecIntoPod([]string{"ls", "-l"}, "nginx") // Expected: no alert + wl.ExecIntoPod([]string{"ls", "-l"}, "server") // Expected: no alert (user profile should suppress alert) + time.Sleep(10 * time.Second) // Wait for potential alerts + + // Verify alert counts + finalAlerts, err := testutils.GetAlerts(wl.Namespace) + require.NoError(t, err, "Failed to get final alerts") + + // Only count new alerts (after the initial count) + newAlertCount := 0 + for _, alert := range finalAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected process launched" { + newAlertCount++ + } + } + + t.Logf("Alert counts - Initial: %d, Final: %d", initialAlertCount, newAlertCount) + + if newAlertCount > initialAlertCount { + t.Logf("Full alert details:") + for _, alert := range finalAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected process launched" { + t.Logf("Alert: %+v", alert) + } + } + t.Errorf("New alerts were generated after merge (Initial: %d, Final: %d)", initialAlertCount, newAlertCount) + } + + // PHASE 5: Check PATCH (removing the ls command from the user profile of the server container and triggering an alert) + t.Log("Patching user profile to remove ls command from server container...") + patchOperations := []utils.PatchOperation{ + {Op: "remove", Path: "/spec/containers/1/execs/0"}, + } + + patch, err := json.Marshal(patchOperations) + require.NoError(t, err, "Failed to marshal patch operations") + + _, err = storageClient.ApplicationProfiles(ns.Name).Patch(context.Background(), userProfile.Name, types.JSONPatchType, patch, metav1.PatchOptions{}) + require.NoError(t, err, "Failed to patch user profile") + + // Verify patched profile behavior + time.Sleep(15 * time.Second) // Allow merge to complete + + // Log the profile that was patched + patchedProfile, err := wl.GetApplicationProfile() + require.NoError(t, err, "Failed to get patched profile") + t.Logf("Patched application profile:\n%v", patchedProfile) + + // Test patched profile behavior + wl.ExecIntoPod([]string{"ls", "-l"}, "nginx") // Expected: no alert + wl.ExecIntoPod([]string{"ls", "-l"}, "server") // Expected: alert (ls command removed from user profile) + time.Sleep(10 * time.Second) // Wait for potential alerts + + // Verify alert counts + 
finalAlerts, err = testutils.GetAlerts(wl.Namespace) + require.NoError(t, err, "Failed to get final alerts") + + // Only count new alerts (after the initial count) + newAlertCount = 0 + for _, alert := range finalAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected process launched" { + newAlertCount++ + } + } + + t.Logf("Alert counts - Initial: %d, Final: %d", initialAlertCount, newAlertCount) + + if newAlertCount <= initialAlertCount { + t.Logf("Full alert details:") + for _, alert := range finalAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected process launched" { + t.Logf("Alert: %+v", alert) + } + } + t.Errorf("New alerts were not generated after patch (Initial: %d, Final: %d)", initialAlertCount, newAlertCount) + } +} + +func Test_13_MergingNetworkNeighborhoodTest(t *testing.T) { + start := time.Now() + defer tearDownTest(t, start) + + // PHASE 1: Setup workload and initial network neighborhood + ns := testutils.NewRandomNamespace() + wl, err := testutils.NewTestWorkload(ns.Name, path.Join(utils.CurrentDir(), "resources/deployment-multiple-containers.yaml")) + require.NoError(t, err, "Failed to create workload") + require.NoError(t, wl.WaitForReady(80), "Workload failed to be ready") + require.NoError(t, wl.WaitForNetworkNeighborhood(80, "ready"), "Network neighborhood not ready") + + // Generate initial network data + _, _, err = wl.ExecIntoPod([]string{"wget", "ebpf.io", "-T", "2", "-t", "1"}, "server") + require.NoError(t, err, "Failed to exec wget in server container") + _, _, err = wl.ExecIntoPod([]string{"curl", "kubernetes.io", "-m", "2"}, "nginx") + require.NoError(t, err, "Failed to exec curl in nginx container") + + require.NoError(t, wl.WaitForNetworkNeighborhoodCompletion(80), "Network neighborhood failed to complete") + time.Sleep(10 * time.Second) // Allow network neighborhood processing + + // Log initial network neighborhood state + initialNN, err := wl.GetNetworkNeighborhood() + require.NoError(t, err, "Failed to get initial network neighborhood") + initialNNJSON, _ := json.Marshal(initialNN) + t.Logf("Initial network neighborhood:\n%s", string(initialNNJSON)) + + // PHASE 2: Verify initial alerts + t.Log("Testing initial alert generation...") + _, _, err = wl.ExecIntoPod([]string{"wget", "ebpf.io", "-T", "2", "-t", "1"}, "server") // Expected: no alert (original rule) + _, _, err = wl.ExecIntoPod([]string{"wget", "httpforever.com", "-T", "2", "-t", "1"}, "server") // Expected: alert (not allowed) + _, _, err = wl.ExecIntoPod([]string{"wget", "httpforever.com", "-T", "2", "-t", "1"}, "server") // Expected: alert (not allowed) + _, _, err = wl.ExecIntoPod([]string{"wget", "httpforever.com", "-T", "2", "-t", "1"}, "server") // Expected: alert (not allowed) + _, _, err = wl.ExecIntoPod([]string{"curl", "kubernetes.io", "-m", "2"}, "nginx") // Expected: no alert (original rule) + _, _, err = wl.ExecIntoPod([]string{"curl", "github.com", "-m", "2"}, "nginx") // Expected: alert (not allowed) + time.Sleep(30 * time.Second) // Wait for alert generation + + initialAlerts, err := testutils.GetAlerts(wl.Namespace) + require.NoError(t, err, "Failed to get initial alerts") + + // Record initial alert count + initialAlertCount := 0 + for _, alert := range initialAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected domain request" && alert.Labels["container_name"] == "server" { + initialAlertCount++ + } + } + + // Verify initial alerts + testutils.AssertContains(t, initialAlerts, 
"Unexpected domain request", "wget", "server") + testutils.AssertContains(t, initialAlerts, "Unexpected domain request", "curl", "nginx") + + // PHASE 3: Apply user-managed network neighborhood + t.Log("Applying user-managed network neighborhood...") + userNN := &v1beta1.NetworkNeighborhood{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("ug-%s", initialNN.Name), + Namespace: initialNN.Namespace, + Annotations: map[string]string{ + "kubescape.io/managed-by": "User", + }, + }, + Spec: v1beta1.NetworkNeighborhoodSpec{ + LabelSelector: metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": "multiple-containers-app", + }, + }, + Containers: []v1beta1.NetworkNeighborhoodContainer{ + { + Name: "nginx", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "nginx-github", + Type: "external", + DNSNames: []string{"github.com."}, + Ports: []v1beta1.NetworkPort{ + { + Name: "TCP-80", + Protocol: "TCP", + Port: ptr(int32(80)), + }, + { + Name: "TCP-443", + Protocol: "TCP", + Port: ptr(int32(443)), + }, + }, + }, + }, + }, + { + Name: "server", + Egress: []v1beta1.NetworkNeighbor{ + { + Identifier: "server-example", + Type: "external", + DNSNames: []string{"info.cern.ch."}, + Ports: []v1beta1.NetworkPort{ + { + Name: "TCP-80", + Protocol: "TCP", + Port: ptr(int32(80)), + }, + { + Name: "TCP-443", + Protocol: "TCP", + Port: ptr(int32(443)), + }, + }, + }, + }, + }, + }, + }, + } + + // Create user-managed network neighborhood + k8sClient := k8sinterface.NewKubernetesApi() + storageClient := spdxv1beta1client.NewForConfigOrDie(k8sClient.K8SConfig) + _, err = storageClient.NetworkNeighborhoods(ns.Name).Create(context.Background(), userNN, metav1.CreateOptions{}) + require.NoError(t, err, "Failed to create user network neighborhood") + + // PHASE 4: Verify merged behavior (no new alerts) + t.Log("Verifying merged network neighborhood behavior...") + time.Sleep(25 * time.Second) // Allow merge to complete + + _, _, err = wl.ExecIntoPod([]string{"wget", "ebpf.io", "-T", "2", "-t", "1"}, "server") // Expected: no alert (original) + // Try multiple times to ensure alert is removed + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: no alert (user added) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: no alert (user added) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: no alert (user added) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: no alert (user added) + _, _, err = wl.ExecIntoPod([]string{"curl", "kubernetes.io", "-m", "2"}, "nginx") // Expected: no alert (original) + _, _, err = wl.ExecIntoPod([]string{"curl", "github.com", "-m", "2"}, "nginx") // Expected: no alert (user added) + time.Sleep(30 * time.Second) // Wait for potential alerts + + mergedAlerts, err := testutils.GetAlerts(wl.Namespace) + require.NoError(t, err, "Failed to get alerts after merge") + + // Count new alerts after merge + newAlertCount := 0 + for _, alert := range mergedAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected domain request" && alert.Labels["container_name"] == "server" { + newAlertCount++ + } + } + + t.Logf("Alert counts - Initial: %d, After merge: %d", initialAlertCount, newAlertCount) + + if newAlertCount > initialAlertCount { + t.Logf("Full alert details:") + for _, alert := range mergedAlerts { + if ruleName, ok := 
alert.Labels["rule_name"]; ok && ruleName == "Unexpected domain request" && alert.Labels["container_name"] == "server" { + t.Logf("Alert: %+v", alert) + } + } + t.Errorf("New alerts were generated after merge (Initial: %d, After merge: %d)", initialAlertCount, newAlertCount) + } + + // PHASE 5: Remove permission via patch and verify alerts return + t.Log("Patching user network neighborhood to remove info.cern.ch from server container...") + patchOperations := []utils.PatchOperation{ + {Op: "remove", Path: "/spec/containers/1/egress/0"}, + } + + patch, err := json.Marshal(patchOperations) + require.NoError(t, err, "Failed to marshal patch operations") + + _, err = storageClient.NetworkNeighborhoods(ns.Name).Patch(context.Background(), userNN.Name, types.JSONPatchType, patch, metav1.PatchOptions{}) + require.NoError(t, err, "Failed to patch user network neighborhood") + + time.Sleep(20 * time.Second) // Allow merge to complete + + // Test alerts after patch + _, _, err = wl.ExecIntoPod([]string{"wget", "ebpf.io", "-T", "2", "-t", "1"}, "server") // Expected: no alert + // Try multiple times to ensure alert is removed + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: alert (removed) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: alert (removed) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: alert (removed) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: alert (removed) + _, _, err = wl.ExecIntoPod([]string{"wget", "info.cern.ch", "-T", "2", "-t", "1"}, "server") // Expected: alert (removed) + _, _, err = wl.ExecIntoPod([]string{"curl", "kubernetes.io", "-m", "2"}, "nginx") // Expected: no alert + _, _, err = wl.ExecIntoPod([]string{"curl", "github.com", "-m", "2"}, "nginx") // Expected: no alert + time.Sleep(30 * time.Second) // Wait for alerts + + finalAlerts, err := testutils.GetAlerts(wl.Namespace) + require.NoError(t, err, "Failed to get final alerts") + + // Count final alerts + finalAlertCount := 0 + for _, alert := range finalAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected domain request" && alert.Labels["container_name"] == "server" { + finalAlertCount++ + } + } + + t.Logf("Alert counts - Initial: %d, Final: %d", initialAlertCount, finalAlertCount) + + if finalAlertCount <= initialAlertCount { + t.Logf("Full alert details:") + for _, alert := range finalAlerts { + if ruleName, ok := alert.Labels["rule_name"]; ok && ruleName == "Unexpected domain request" && alert.Labels["container_name"] == "server" { + t.Logf("Alert: %+v", alert) + } + } + t.Errorf("New alerts were not generated after patch (Initial: %d, Final: %d)", initialAlertCount, finalAlertCount) + } +} + +func ptr(i int32) *int32 { + return &i +} diff --git a/tests/resources/user-profile.yaml b/tests/resources/user-profile.yaml new file mode 100644 index 00000000..97a116f6 --- /dev/null +++ b/tests/resources/user-profile.yaml @@ -0,0 +1,47 @@ +apiVersion: spdx.softwarecomposition.kubescape.io/v1beta1 +kind: ApplicationProfile +metadata: + name: {name} + namespace: {namespace} + resourceVersion: "1" # Start with "1" for new resources + annotations: + kubescape.io/managed-by: User +spec: + architectures: ["amd64"] + containers: + - name: nginx + imageID: "" + imageTag: "" + capabilities: [] + opens: [] + syscalls: [] + endpoints: [] + execs: + - path: 
/usr/bin/ls + args: + - /usr/bin/ls + - -l + seccompProfile: + spec: + defaultAction: "" + - name: server + imageID: "" + imageTag: "" + capabilities: [] + opens: [] + syscalls: [] + endpoints: [] + execs: + - path: /bin/ls + args: + - /bin/ls + - -l + - path: /bin/grpc_health_probe + args: + - "-addr=:9555" + seccompProfile: + spec: + defaultAction: "" + initContainers: [] + ephemeralContainers: [] +status: {} \ No newline at end of file
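
For reference, Test_12 and Test_13 above exercise the user-managed profile merge by sending RFC 6902 JSON Patch documents through the storage client. The standalone sketch below shows roughly what those marshalled bodies look like; PatchOp is a hypothetical stand-in for utils.PatchOperation, whose JSON tags are assumed here to be the lowercase "op"/"path" keys the API expects.

// Minimal sketch, not part of the patch: the JSON Patch bodies built in
// Test_12/Test_13. PatchOp is a hypothetical stand-in for utils.PatchOperation.
package main

import (
	"encoding/json"
	"fmt"
)

type PatchOp struct {
	Op   string `json:"op"`
	Path string `json:"path"`
}

func main() {
	// Test_12: drop the first exec entry of the second container ("server")
	// from the user-managed ApplicationProfile.
	profilePatch := []PatchOp{{Op: "remove", Path: "/spec/containers/1/execs/0"}}

	// Test_13: drop the first egress entry of the second container from the
	// user-managed NetworkNeighborhood.
	nnPatch := []PatchOp{{Op: "remove", Path: "/spec/containers/1/egress/0"}}

	for _, patch := range [][]PatchOp{profilePatch, nnPatch} {
		body, err := json.Marshal(patch)
		if err != nil {
			panic(err)
		}
		// Prints, e.g., [{"op":"remove","path":"/spec/containers/1/egress/0"}]
		fmt.Println(string(body))
	}
}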
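
Both new component tests count alerts by filtering on the rule_name label (and, in Test_13, on container_name as well). A small sketch of that predicate follows; Alert is a hypothetical stand-in for the element type returned by testutils.GetAlerts, of which only a Labels map is assumed.

// Sketch of the alert-counting predicate repeated in Test_12/Test_13; not part
// of the patch. Alert is a hypothetical stand-in for the testutils alert type.
package main

import "fmt"

type Alert struct {
	Labels map[string]string
}

// countAlerts counts alerts whose rule_name matches ruleName and, when
// containerName is non-empty, whose container_name matches as well.
func countAlerts(alerts []Alert, ruleName, containerName string) int {
	count := 0
	for _, alert := range alerts {
		if name, ok := alert.Labels["rule_name"]; !ok || name != ruleName {
			continue
		}
		if containerName != "" && alert.Labels["container_name"] != containerName {
			continue
		}
		count++
	}
	return count
}

func main() {
	alerts := []Alert{
		{Labels: map[string]string{"rule_name": "Unexpected domain request", "container_name": "server"}},
		{Labels: map[string]string{"rule_name": "Unexpected process launched", "container_name": "nginx"}},
	}
	fmt.Println(countAlerts(alerts, "Unexpected domain request", "server")) // 1
}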
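
The new tests/resources/user-profile.yaml fixture carries {name} and {namespace} placeholders but is not referenced by the tests above, which construct the user-managed profile in Go. Presumably a test would substitute the placeholders before applying the manifest; the following is a minimal sketch of that step under that assumption, with illustrative argument values (the helper name and the sample name/namespace are not taken from the patch).

// Minimal sketch, assuming the fixture is rendered by simple placeholder
// substitution before being applied; not part of the patch.
package main

import (
	"fmt"
	"os"
	"strings"
)

// renderUserProfile reads the fixture and replaces the {name} and {namespace}
// placeholders with concrete values.
func renderUserProfile(path, name, namespace string) (string, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return "", err
	}
	out := strings.ReplaceAll(string(raw), "{name}", name)
	out = strings.ReplaceAll(out, "{namespace}", namespace)
	return out, nil
}

func main() {
	manifest, err := renderUserProfile("tests/resources/user-profile.yaml", "ug-example-profile", "default")
	if err != nil {
		fmt.Println("render failed:", err)
		return
	}
	fmt.Println(manifest) // ready to apply via the storage client or kubectl
}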