diff --git a/pkg/auditserver/server.go b/pkg/auditserver/server.go index 611431a..bb6dafe 100644 --- a/pkg/auditserver/server.go +++ b/pkg/auditserver/server.go @@ -236,6 +236,7 @@ func New(logger *slog.Logger) (*AuditServer, error) { fwd, err = forwarder.NewUDPForwarder(rgConfig.Forwarding.Address) if err != nil { logger.Error("Failed to create UDP forwarder", "error", err) + return nil, fmt.Errorf("failed to create UDP forwarder: %w", err) } } diff --git a/pkg/auditserver/server_test.go b/pkg/auditserver/server_test.go index efa75f6..5d0d302 100644 --- a/pkg/auditserver/server_test.go +++ b/pkg/auditserver/server_test.go @@ -412,6 +412,97 @@ func TestNew_WithInvalidRuleGroups(t *testing.T) { assert.Contains(t, logOutput, "ERROR") } +func TestNew_WithValidForwarder(t *testing.T) { + // Set up a mock configuration with a valid forwarder address + viper.Reset() + viper.Set("rule_groups", []map[string]interface{}{ + { + "name": "test_group", + "rules": []string{ + "Request.Operation == 'read'", + }, + "log_file": map[string]interface{}{ + "file_path": "/tmp/test.log", + "max_size": 10, + }, + "forwarding": map[string]interface{}{ + "enabled": true, + "address": "127.0.0.1:9000", + }, + }, + }) + + server, err := New(nil) + + assert.NoError(t, err) + assert.NotNil(t, server) + assert.Len(t, server.ruleGroups, 1) + assert.NotNil(t, server.ruleGroups[0].Forwarder) +} + +func TestNew_WithInvalidForwarder(t *testing.T) { + // Set up a mock configuration with an invalid forwarder address + viper.Reset() + viper.Set("rule_groups", []map[string]interface{}{ + { + "name": "test_group", + "rules": []string{ + "Request.Operation == 'read'", + }, + "log_file": map[string]interface{}{ + "file_path": "/tmp/test.log", + "max_size": 10, + }, + "forwarding": map[string]interface{}{ + "enabled": true, + "address": "invalid:address:9000", + }, + }, + }) + + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + server, err := New(logger) + + assert.Error(t, err) + assert.Nil(t, server) + assert.Contains(t, err.Error(), "failed to create UDP forwarder") + + // Verify error was logged + logOutput := buf.String() + assert.Contains(t, logOutput, "Failed to create UDP forwarder") + assert.Contains(t, logOutput, "ERROR") +} + +func TestNew_WithDisabledForwarder(t *testing.T) { + // Set up a mock configuration with forwarding disabled + viper.Reset() + viper.Set("rule_groups", []map[string]interface{}{ + { + "name": "test_group", + "rules": []string{ + "Request.Operation == 'read'", + }, + "log_file": map[string]interface{}{ + "file_path": "/tmp/test.log", + "max_size": 10, + }, + "forwarding": map[string]interface{}{ + "enabled": false, + "address": "127.0.0.1:9000", + }, + }, + }) + + server, err := New(nil) + + assert.NoError(t, err) + assert.NotNil(t, server) + assert.Len(t, server.ruleGroups, 1) + assert.Nil(t, server.ruleGroups[0].Forwarder) +} + func TestRuleGroup_shouldLog(t *testing.T) { // Define a sample audit log auditLog := &AuditLog{