diff --git a/reporthandling/attacktrack/v1alpha1/attacktrackmethods.go b/reporthandling/attacktrack/v1alpha1/attacktrackmethods.go index e70bace9..8a25f6b3 100644 --- a/reporthandling/attacktrack/v1alpha1/attacktrackmethods.go +++ b/reporthandling/attacktrack/v1alpha1/attacktrackmethods.go @@ -57,6 +57,25 @@ func (at *AttackTrack) Iterator() IAttackTrackIterator { } } +// GetSubstepsWithVulnerabilities returns a list of substeps names that check for vulnerabilities +func (at *AttackTrack) GetSubstepsWithVulnerabilities() []string { + var substepNames []string + + var traverse func(step AttackTrackStep) + traverse = func(step AttackTrackStep) { + if step.DoesCheckVulnerabilities() { + substepNames = append(substepNames, step.Name) + } + for _, substep := range step.SubSteps { + traverse(substep) + } + } + + traverse(at.Spec.Data) + + return substepNames +} + func (iter *AttackTrackIterator) HasNext() bool { return !iter.stack.IsEmpty() } diff --git a/reporthandling/attacktrack/v1alpha1/attacktrackmethods_test.go b/reporthandling/attacktrack/v1alpha1/attacktrackmethods_test.go index d1ab22f0..069dd3b1 100644 --- a/reporthandling/attacktrack/v1alpha1/attacktrackmethods_test.go +++ b/reporthandling/attacktrack/v1alpha1/attacktrackmethods_test.go @@ -685,3 +685,50 @@ func TestFilterNodesWithControls(t *testing.T) { }) } } + +func TestGetSubstepsWithVulnerabilities(t *testing.T) { + // Create an AttackTrack object with substeps having different values for ChecksVulnerabilities + attackTrack := AttackTrack{ + ApiVersion: "v1", + Kind: "AttackTrack", + Metadata: map[string]interface{}{}, + Spec: AttackTrackSpecification{ + Version: "1.0", + Description: "Example attack track", + Data: AttackTrackStep{ + Name: "Step 1", + Description: "First step", + ChecksVulnerabilities: true, + SubSteps: []AttackTrackStep{ + { + Name: "Substep 1.1", + Description: "Substep 1.1 description", + ChecksVulnerabilities: true, + }, + { + Name: "Substep 1.2", + Description: "Substep 1.2 description", + ChecksVulnerabilities: false, + }, + }, + }, + }, + } + + // Call the method being tested + substepNames := attackTrack.GetSubstepsWithVulnerabilities() + + // Define the expected substep names with ChecksVulnerabilities set to true + expectedSubstepNames := []string{"Step 1", "Substep 1.1"} + + // Check if the returned substep names match the expected substep names + if len(substepNames) != len(expectedSubstepNames) { + t.Errorf("Unexpected number of substep names. Expected: %d, Got: %d", len(expectedSubstepNames), len(substepNames)) + } + + for i, name := range substepNames { + if name != expectedSubstepNames[i] { + t.Errorf("Mismatched substep name. Expected: %s, Got: %s", expectedSubstepNames[i], name) + } + } +} diff --git a/reporthandling/attacktrack/v1alpha1/attacktrackmocks.go b/reporthandling/attacktrack/v1alpha1/attacktrackmocks.go index 7dd94392..fed1a37d 100644 --- a/reporthandling/attacktrack/v1alpha1/attacktrackmocks.go +++ b/reporthandling/attacktrack/v1alpha1/attacktrackmocks.go @@ -136,6 +136,26 @@ func (at AttackTrackMock) Iterator() IAttackTrackIterator { } } +// GetSubstepsWithVulnerabilities returns a list of substeps names that check for vulnerabilities +func (at AttackTrackMock) GetSubstepsWithVulnerabilities() []string { + var substepNames []string + + var traverse func(step AttackTrackStep) + traverse = func(step AttackTrackStep) { + if step.DoesCheckVulnerabilities() { + substepNames = append(substepNames, step.Name) + } + for _, substep := range step.SubSteps { + traverse(substep) + } + } + + t := at.Spec.Data.(*AttackTrackStep) + traverse(*t) + + return substepNames +} + type MockAttackTrackSpecification struct { Version string `json:"version,omitempty"` Description string `json:"description,omitempty"` diff --git a/reporthandling/attacktrack/v1alpha1/interface.go b/reporthandling/attacktrack/v1alpha1/interface.go index e1ddd3a6..ac4f8608 100644 --- a/reporthandling/attacktrack/v1alpha1/interface.go +++ b/reporthandling/attacktrack/v1alpha1/interface.go @@ -9,6 +9,7 @@ type IAttackTrack interface { GetData() IAttackTrackStep Iterator() IAttackTrackIterator IsValid() bool + GetSubstepsWithVulnerabilities() []string } // A step in an attack track