From 7c5cd3aa9ac3c6a1f6c7181734ee018778344a51 Mon Sep 17 00:00:00 2001 From: Juliano Martinez Date: Sun, 20 Oct 2024 00:36:04 +0200 Subject: [PATCH] Refactoring and cleanup (#15) * rename from docker to development * refator a bit the acl manager and try to reduce the memory footprint * update commands * update readme * update gitignore * remove old reference * update from localhost to vault to match the containername * update tests * run tests in parallel * adds missing tests * adds extra tests * more tests * tidying a bit the code to simplify testing syncACLS * try to simplify the test a bit before adding new cases * adds missing tests --- .gitignore | 3 + README.md | 121 +- cmd/root.go | 10 +- cmd/run.go | 26 +- cmd/runOnce.go | 19 +- configs/{docker => development}/Dockerfile | 0 configs/{docker => development}/Makefile | 0 .../docker-compose.yaml | 0 .../redis/redis0001.conf | 0 .../redis/redis0002.conf | 0 .../redis/redis0003.conf | 0 .../{docker => development}/vault/setup.sh | 2 +- pkg/aclmanager/aclmanager.go | 319 ++-- pkg/aclmanager/aclmanager_test.go | 1372 +++++++++-------- 14 files changed, 1086 insertions(+), 786 deletions(-) rename configs/{docker => development}/Dockerfile (100%) rename configs/{docker => development}/Makefile (100%) rename configs/{docker => development}/docker-compose.yaml (100%) rename configs/{docker => development}/redis/redis0001.conf (100%) rename configs/{docker => development}/redis/redis0002.conf (100%) rename configs/{docker => development}/redis/redis0003.conf (100%) rename configs/{docker => development}/vault/setup.sh (96%) diff --git a/.gitignore b/.gitignore index f6d9e49..ee1f5a5 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,8 @@ # vendor/ .idea/ +# .DS_Store +.DS_Store + # ignore local buid bedel diff --git a/README.md b/README.md index e23d412..03a1f9e 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,34 @@ # Bedel -[![Go](https://github.com/ncode/port53/actions/workflows/go.yml/badge.svg)](https://github.com/ncode/port53/actions/workflows/go.yml) +[![Go](https://github.com/ncode/bedel/actions/workflows/go.yml/badge.svg)](https://github.com/ncode/bedel/actions/workflows/go.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/ncode/bedel)](https://goreportcard.com/report/github.com/ncode/bedel) [![codecov](https://codecov.io/gh/ncode/bedel/graph/badge.svg?token=N98KAO33K5)](https://codecov.io/gh/ncode/bedel) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -`bedel` is a tool designed to address a specific challenge with Redis: synchronizing users generated outside of the configuration file, such as through the Vault database backend. This utility ensures that Redis acls are up-to-date and consistent across all nodes. More info [here](https://github.com/redis/redis/issues/7988). +`bedel` is a utility designed to synchronize ACLs across multiple nodes in Redis and Redis-compatible databases like Valkey. It specifically addresses the challenge of managing users created outside the traditional configuration file, such as those generated through the [Vault database backend](https://www.vaultproject.io/docs/secrets/databases/redis). By keeping ACLs consistent across all nodes, Bedel ensures seamless user management and enhanced security in distributed environments. For more information on the underlying issue with Redis, see [Redis Issue #7988](https://github.com/redis/redis/issues/7988). + +## Table of Contents + +- [Features](#features) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Installation](#installation) + - [Running the Tests](#running-the-tests) +- [Development Setup](#development-setup) + - [Configuration](#configuration) +- [Usage](#usage) +- [Contributing](#contributing) +- [License](#license) ## Features -- Synchronize Redis users created outside the traditional config file. -- Integration with Vault database backend for user management. -- Automated and consistent user synchronization. -- Easy to deploy and integrate within existing Redis setups. + +- Automated User Synchronization: Automatically synchronizes Redis users and ACLs across all nodes to maintain consistency. +- Vault Integration: Seamlessly integrates with HashiCorp Vault's database backend for dynamic user management. +- Configurable Sync Intervals: Allows customization of synchronization intervals to suit your deployment needs. +- Lightweight and Efficient: Designed to have minimal impact on performance, even with thousands of users. +- Easy Deployment: Simple to deploy with Docker Compose or as a standalone binary. +- Robust Logging: Provides detailed logs for monitoring and troubleshooting. ## Getting Started @@ -20,9 +36,15 @@ These instructions will guide you through getting a copy of `bedel` up and runni ### Prerequisites -- Redis server setup. + +For users: +- Redis server setup - Access to Vault database backend (if using Vault for user generation). -- Go environment for development. + +For developers: +- Go 1.21 or higher +- Docker and Docker Compose (for development and testing). +- Git (for cloning the repository). ### Installing @@ -35,6 +57,11 @@ $ cd bedel $ go build ``` +2. Go install: +```bash +4 go install github.com/ncode/bedel/cmd/bedel@latest +``` + ### Running the Tests To run the automated tests for this system, use the following command: @@ -42,4 +69,80 @@ To run the automated tests for this system, use the following command: ```bash $ go test ./... ``` - + +## Development Setup + +Bedel comes with a development environment setup using Docker Compose. This setup includes: + +- Three Redis instances (redis0001, redis0002, redis0003) +- Three Bedel instances (bedel_redis0001, bedel_redis0002, bedel_redis0003) +- A Vault instance for managing secrets + +To start the development environment: + +1. Ensure you have Docker and Docker Compose installed. +2. Navigate to the project root directory. +3. Run the following command: + +```bash +$ cd config/development +$ make +``` + +This will start all the services defined in the `docker-compose.yaml` file. + +### Configuration + +The `docker-compose.yaml` file contains the configuration for all services. Here are some key points: + +- Redis instances are configured with custom configuration files located in the `./redis` directory. +- Bedel instances are configured to connect to their respective Redis instances. +- The Vault instance is set up with a root token "root" and listens on port 8200. + +## Usage + +Bedel can be run in two modes: + +### 1. Run Once Mode + +Performs a single synchronization of ACLs from the primary Redis node to the replica. + +```bash +$ bedel runOnce -a -p -u +``` + +### 2. Continuous Loop: + +Continuously synchronizes ACLs at a defined interval. + +```bash +$ bedel run -a -p -u --sync-interval +``` + +For more options and commands, run: +```bash +$ bedel --help +``` + +### Configuration file + +Bedel can also read configurations from a YAML file (default: $HOME/.bedel.yaml). Command-line options override configurations in the file. + +Example Configuration File (~/.bedel.yaml): +```yaml +address: localhost:6379 +password: mypassword +username: default +syncInterval: 10s +logLevel: INFO +aclfile: false +``` + + +## Contributing + +Contributions are welcome! + +## License + +This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. diff --git a/cmd/root.go b/cmd/root.go index eee7e41..73aab05 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,6 +20,7 @@ import ( "log/slog" "os" "path" + "time" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -63,7 +64,7 @@ func Execute() { func init() { cobra.OnInitialize(initConfig) rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.bedel.yaml)") - rootCmd.PersistentFlags().StringP("address", "a", "", "address of the slave to manage instance eg: 127.0.0.1:6379") + rootCmd.PersistentFlags().StringP("address", "a", "", "address of the instance to manage, e.g., 127.0.0.1:6379") viper.BindPFlag("address", rootCmd.PersistentFlags().Lookup("address")) rootCmd.PersistentFlags().StringP("password", "p", "", "password to manage acls") viper.BindPFlag("password", rootCmd.PersistentFlags().Lookup("password")) @@ -73,6 +74,8 @@ func init() { viper.BindPFlag("logLevel", rootCmd.PersistentFlags().Lookup("logLevel")) rootCmd.PersistentFlags().Bool("aclfile", false, "defined if we should use the aclfile to sync acls") viper.BindPFlag("aclfile", rootCmd.PersistentFlags().Lookup("aclfile")) + rootCmd.PersistentFlags().Duration("syncInterval", 10*time.Second, "interval between sync operations") + viper.BindPFlag("syncInterval", rootCmd.PersistentFlags().Lookup("syncInterval")) } // initConfig reads in config file and ENV variables if set. @@ -91,7 +94,6 @@ func initConfig() { viper.SetConfigName(".bedel") } - viper.SetDefault("syncInterval", 10) viper.SetDefault("username", "default") viper.AutomaticEnv() @@ -100,12 +102,12 @@ func initConfig() { } if !viper.IsSet("address") { - fmt.Fprintln(os.Stderr, "address is required") + logger.Error("Address is required") os.Exit(1) } if !viper.IsSet("password") { - fmt.Fprintln(os.Stderr, "password is required") + logger.Error("password is required") os.Exit(1) } diff --git a/cmd/run.go b/cmd/run.go index 87a3489..60c384d 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -16,6 +16,11 @@ limitations under the License. package cmd import ( + "context" + "os" + "os/signal" + "syscall" + "github.com/ncode/bedel/pkg/aclmanager" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -27,11 +32,26 @@ var runCmd = &cobra.Command{ Short: "Run the acl manager in mood loop, it will sync the follower with the primary", Run: func(cmd *cobra.Command, args []string) { mgr := aclmanager.New(viper.GetString("address"), viper.GetString("username"), viper.GetString("password"), viper.GetBool("aclfile")) - ctx := cmd.Context() - err := mgr.Loop(ctx) + defer mgr.Close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + cancel() + }() + + syncInterval := viper.GetDuration("syncInterval") + logger.Info("Starting ACL manager loop") + err := mgr.Loop(ctx, syncInterval) if err != nil { - panic(err) + logger.Error("Error running ACL manager loop", "error", err) + os.Exit(1) } + logger.Info("ACL manager loop terminated") }, } diff --git a/cmd/runOnce.go b/cmd/runOnce.go index 6eb869a..9e22fde 100644 --- a/cmd/runOnce.go +++ b/cmd/runOnce.go @@ -16,12 +16,12 @@ limitations under the License. package cmd import ( - "github.com/ncode/bedel/pkg/aclmanager" - "github.com/spf13/viper" "log/slog" "os" + "github.com/ncode/bedel/pkg/aclmanager" "github.com/spf13/cobra" + "github.com/spf13/viper" ) // runOnceCmd represents the runOnce command @@ -30,22 +30,23 @@ var runOnceCmd = &cobra.Command{ Short: "Run the acl manager once, it will sync the follower with the primary", Run: func(cmd *cobra.Command, args []string) { ctx := cmd.Context() - aclManager := aclmanager.New(viper.GetString("address"), viper.GetString("username"), viper.GetString("password"), viper.GetBool("aclfile")) - function, err := aclManager.CurrentFunction(ctx) + mgr := aclmanager.New(viper.GetString("address"), viper.GetString("username"), viper.GetString("password"), viper.GetBool("aclfile")) + defer mgr.Close() + function, err := mgr.CurrentFunction(ctx) if err != nil { - slog.Warn("unable to check if it's a Primary", "message", err) + slog.Error("Unable to check if node is primary", "error", err) os.Exit(1) } if function == aclmanager.Follower { - primary, err := aclManager.Primary(ctx) + primary, err := mgr.Primary(ctx) if err != nil { - slog.Warn("unable to find Primary", "message", err) + slog.Error("Unable to find Primary", "message", err) os.Exit(1) } var added, deleted []string - added, deleted, err = aclManager.SyncAcls(ctx, primary) + added, deleted, err = mgr.SyncAcls(ctx, primary) if err != nil { - slog.Warn("unable to sync acls from Primary", "message", err) + slog.Error("Unable to sync acls from Primary", "message", err) os.Exit(1) } slog.Info("Synced acls from Primary", "added", added, "deleted", deleted) diff --git a/configs/docker/Dockerfile b/configs/development/Dockerfile similarity index 100% rename from configs/docker/Dockerfile rename to configs/development/Dockerfile diff --git a/configs/docker/Makefile b/configs/development/Makefile similarity index 100% rename from configs/docker/Makefile rename to configs/development/Makefile diff --git a/configs/docker/docker-compose.yaml b/configs/development/docker-compose.yaml similarity index 100% rename from configs/docker/docker-compose.yaml rename to configs/development/docker-compose.yaml diff --git a/configs/docker/redis/redis0001.conf b/configs/development/redis/redis0001.conf similarity index 100% rename from configs/docker/redis/redis0001.conf rename to configs/development/redis/redis0001.conf diff --git a/configs/docker/redis/redis0002.conf b/configs/development/redis/redis0002.conf similarity index 100% rename from configs/docker/redis/redis0002.conf rename to configs/development/redis/redis0002.conf diff --git a/configs/docker/redis/redis0003.conf b/configs/development/redis/redis0003.conf similarity index 100% rename from configs/docker/redis/redis0003.conf rename to configs/development/redis/redis0003.conf diff --git a/configs/docker/vault/setup.sh b/configs/development/vault/setup.sh similarity index 96% rename from configs/docker/vault/setup.sh rename to configs/development/vault/setup.sh index d080ada..acf2eef 100755 --- a/configs/docker/vault/setup.sh +++ b/configs/development/vault/setup.sh @@ -5,7 +5,7 @@ set -x setup_vault(){ # Wait for Vault server to be up echo "Waiting for Vault to start..." - while ! nc -z localhost 8200; do + while ! nc -z vault 8200; do sleep 1 done diff --git a/pkg/aclmanager/aclmanager.go b/pkg/aclmanager/aclmanager.go index b21c3f3..ea973b2 100644 --- a/pkg/aclmanager/aclmanager.go +++ b/pkg/aclmanager/aclmanager.go @@ -1,32 +1,25 @@ package aclmanager import ( - "bufio" "context" + "crypto/sha256" "fmt" - "github.com/redis/go-redis/v9" - "github.com/spf13/viper" "log/slog" - "regexp" "strings" + "sync" "sync/atomic" "time" + + "github.com/redis/go-redis/v9" ) +// Constants for node roles const ( Primary = iota Follower Unknown ) -var ( - followerRegex = regexp.MustCompile(`slave\d+:ip=(?P.+),port=(?P\d+)`) - primaryHostRegex = regexp.MustCompile(`master_host:(?P.+)`) - primaryPortRegex = regexp.MustCompile(`master_port:(?P\d+)`) - role = regexp.MustCompile(`role:master`) - filterUser = regexp.MustCompile(`^user\s+`) -) - // AclManager contains the struct for managing the state of ACLs type AclManager struct { Addr string @@ -35,7 +28,9 @@ type AclManager struct { RedisClient *redis.Client primary atomic.Bool nodes map[string]int + mu sync.Mutex // Mutex to protect nodes map aclFile bool + batchSize int } // New creates a new AclManager @@ -53,61 +48,79 @@ func New(addr string, username string, password string, aclfile bool) *AclManage RedisClient: redisClient, nodes: make(map[string]int), aclFile: aclfile, + batchSize: 100, } } -// findNodes returns a list of nodes in the cluster based on the redis info replication command +// SetBatchSize sets the batch size for syncing ACLs +func (a *AclManager) SetBatchSize(size int) { + a.mu.Lock() + defer a.mu.Unlock() + a.batchSize = size +} + +// findNodes returns a list of nodes in the cluster based on the redis ROLE command func (a *AclManager) findNodes(ctx context.Context) error { slog.Debug("Entering findNodes") defer slog.Debug("Exiting findNodes") - replicationInfo, err := a.RedisClient.Info(ctx, "replication").Result() + roleInfo, err := a.RedisClient.Do(ctx, "ROLE").Result() if err != nil { - slog.Error("Failed to get replication info", "error", err) - return fmt.Errorf("findNodes: failed to get replication info: %w", err) + slog.Error("Failed to get role info", "error", err) + return fmt.Errorf("findNodes: failed to get role info: %w", err) } - a.primary.Store(role.MatchString(replicationInfo)) - - var masterHost, masterPort string - nodes := make([]string, 0) - scanner := bufio.NewScanner(strings.NewReader(replicationInfo)) - for scanner.Scan() { - line := scanner.Text() + a.mu.Lock() + defer a.mu.Unlock() + // Clear the nodes map + a.nodes = make(map[string]int) - slog.Debug("Parsing line for masterHost", "line", line) - if matches := primaryHostRegex.FindStringSubmatch(line); matches != nil { - masterHost = matches[1] + switch info := roleInfo.(type) { + case []interface{}: + if len(info) == 0 { + slog.Error("ROLE command returned empty result") + return fmt.Errorf("findNodes: ROLE command returned empty result") } - slog.Debug("Parsing line for masterPort", "line", line) - if matches := primaryPortRegex.FindStringSubmatch(line); matches != nil { - masterPort = matches[1] - node := fmt.Sprintf("%s:%s", masterHost, masterPort) - nodes = append(nodes, node) - a.nodes[node] = Primary - } - - slog.Debug("Parsing line for follower", "line", line) - if matches := followerRegex.FindStringSubmatch(line); matches != nil { - ip := matches[followerRegex.SubexpIndex("ip")] - port := matches[followerRegex.SubexpIndex("port")] - node := fmt.Sprintf("%s:%s", ip, port) - nodes = append(nodes, node) - a.nodes[node] = Follower + roleType, ok := info[0].(string) + if !ok { + slog.Error("Unexpected type for role", "type", fmt.Sprintf("%T", info[0])) + return fmt.Errorf("findNodes: unexpected type for role: %T", info[0]) } - } - if err := scanner.Err(); err != nil { - slog.Error("Scanner error", "error", err) - return fmt.Errorf("findNodes: scanner error: %w", err) - } - - for _, node := range nodes { - if _, ok := a.nodes[node]; !ok { - delete(a.nodes, node) - slog.Debug("Deleted node", "node", node) + switch roleType { + case "master": + a.primary.Store(true) + // Parse connected slaves + if len(info) >= 3 { + slaves, ok := info[2].([]interface{}) + if ok { + for _, slaveInfo := range slaves { + if slaveArr, ok := slaveInfo.([]interface{}); ok && len(slaveArr) >= 2 { + ip, _ := slaveArr[0].(string) + port, _ := slaveArr[1].(int64) + node := fmt.Sprintf("%s:%d", ip, port) + a.nodes[node] = Follower + } + } + } + } + case "slave": + a.primary.Store(false) + // Get master info + if len(info) >= 3 { + masterHost, _ := info[1].(string) + masterPort, _ := info[2].(int64) + node := fmt.Sprintf("%s:%d", masterHost, masterPort) + a.nodes[node] = Primary + } + default: + slog.Error("Unknown role type", "role", roleType) + return fmt.Errorf("findNodes: unknown role type: %s", roleType) } + default: + slog.Error("Unexpected type for roleInfo", "type", fmt.Sprintf("%T", roleInfo)) + return fmt.Errorf("findNodes: unexpected type for roleInfo: %T", roleInfo) } return nil @@ -131,6 +144,7 @@ func (a *AclManager) CurrentFunction(ctx context.Context) (int, error) { return Follower, nil } +// Primary returns an AclManager connected to the primary node func (a *AclManager) Primary(ctx context.Context) (*AclManager, error) { slog.Debug("Entering Primary") defer slog.Debug("Exiting Primary") @@ -141,6 +155,8 @@ func (a *AclManager) Primary(ctx context.Context) (*AclManager, error) { return nil, fmt.Errorf("Primary: %w", err) } + a.mu.Lock() + defer a.mu.Unlock() for address, function := range a.nodes { if function == Primary { slog.Info("Found Primary node", "address", address) @@ -154,10 +170,13 @@ func (a *AclManager) Primary(ctx context.Context) (*AclManager, error) { // Close closes the redis client func (a *AclManager) Close() error { slog.Debug("Closing Redis client") + if a.RedisClient == nil { + return fmt.Errorf("Redis client is nil") + } return a.RedisClient.Close() } -// listAcls returns a list of acls in the cluster based on the redis acl list command +// listAcls returns a list of ACLs in the cluster based on the redis ACL LIST command func listAcls(ctx context.Context, client *redis.Client) ([]string, error) { slog.Debug("Entering listAcls") defer slog.Debug("Exiting listAcls") @@ -170,31 +189,26 @@ func listAcls(ctx context.Context, client *redis.Client) ([]string, error) { aclList, ok := result.([]interface{}) if !ok { - err := fmt.Errorf("unexpected result format: %v", result) + err := fmt.Errorf("unexpected result format: %T", result) slog.Error("Unexpected result format", "result", result) return nil, fmt.Errorf("listAcls: %w", err) } - if len(aclList) == 0 { - slog.Info("No ACLs found") - return nil, nil // Return nil if no ACLs are found - } - - acls := make([]string, len(aclList)) - for i, acl := range aclList { + acls := make([]string, 0, len(aclList)) + for _, acl := range aclList { aclStr, ok := acl.(string) if !ok { - err := fmt.Errorf("unexpected type for ACL: %v", acl) + err := fmt.Errorf("unexpected type for ACL: %T", acl) slog.Error("Unexpected type for ACL", "acl", acl) return nil, fmt.Errorf("listAcls: %w", err) } - acls[i] = aclStr + acls = append(acls, aclStr) } slog.Info("Listed ACLs", "count", len(acls)) return acls, nil } -// saveAclFile calls the redis command ACL SAVE to save the acls to the aclFile +// saveAclFile calls the redis command ACL SAVE to save the ACLs to the aclFile func saveAclFile(ctx context.Context, client *redis.Client) error { slog.Debug("Entering saveAclFile") defer slog.Debug("Exiting saveAclFile") @@ -207,7 +221,7 @@ func saveAclFile(ctx context.Context, client *redis.Client) error { return nil } -// loadAclFile calls the redis command ACL LOAD to load the acls from the aclFile +// loadAclFile calls the redis command ACL LOAD to load the ACLs from the aclFile func loadAclFile(ctx context.Context, client *redis.Client) error { slog.Debug("Entering loadAclFile") defer slog.Debug("Exiting loadAclFile") @@ -220,7 +234,56 @@ func loadAclFile(ctx context.Context, client *redis.Client) error { return nil } -// SyncAcls connects to master node and syncs the acls to the current node +// hashString computes a SHA-256 hash of the input string +func hashString(s string) string { + h := sha256.New() + h.Write([]byte(s)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// listAndMapAcls is an auxiliary function to list ACLs and create a map of username to hash and ACL string +func listAndMapAcls(ctx context.Context, client *redis.Client) (map[string]string, map[string]string, error) { + slog.Debug("Entering listAndMapAcls") + defer slog.Debug("Exiting listAndMapAcls") + + result, err := client.Do(ctx, "ACL", "LIST").Result() + if err != nil { + slog.Error("Failed to list ACLs", "error", err) + return nil, nil, fmt.Errorf("listAndMapAcls: error listing ACLs: %w", err) + } + + aclList, ok := result.([]interface{}) + if !ok { + err := fmt.Errorf("unexpected result format: %T", result) + slog.Error("Unexpected result format", "result", result) + return nil, nil, fmt.Errorf("listAndMapAcls: %w", err) + } + + aclHashMap := make(map[string]string) + aclStrMap := make(map[string]string) + for _, acl := range aclList { + aclStr, ok := acl.(string) + if !ok { + err := fmt.Errorf("unexpected type for ACL: %T", acl) + slog.Error("Unexpected type for ACL", "acl", acl) + return nil, nil, fmt.Errorf("listAndMapAcls: %w", err) + } + fields := strings.Fields(aclStr) + if len(fields) < 2 { + slog.Warn("Invalid ACL format", "acl", aclStr) + continue + } + username := fields[1] + hash := hashString(aclStr) + aclHashMap[username] = hash + aclStrMap[username] = aclStr + } + + slog.Info("Listed and mapped ACLs", "count", len(aclHashMap)) + return aclHashMap, aclStrMap, nil +} + +// SyncAcls connects to the primary node and syncs the ACLs to the current node func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]string, []string, error) { slog.Debug("Entering SyncAcls") defer slog.Debug("Exiting SyncAcls") @@ -231,97 +294,119 @@ func (a *AclManager) SyncAcls(ctx context.Context, primary *AclManager) ([]strin return nil, nil, err } - sourceAcls, err := listAcls(ctx, primary.RedisClient) + // Get source ACLs + sourceAclHashMap, sourceAclStrMap, err := listAndMapAcls(ctx, primary.RedisClient) if err != nil { - slog.Error("Failed to list source ACLs", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error listing source acls: %w", err) - } - - if a.aclFile { - if err = saveAclFile(ctx, primary.RedisClient); err != nil { - slog.Error("Failed to save primary ACLs to aclFile", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error saving primary acls to aclFile: %w", err) - } + return nil, nil, fmt.Errorf("SyncAcls: error listing source ACLs: %w", err) } - destinationAcls, err := listAcls(ctx, a.RedisClient) + // Get destination ACLs + destinationAclHashMap, _, err := listAndMapAcls(ctx, a.RedisClient) if err != nil { - slog.Error("Failed to list current ACLs", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error listing current acls: %w", err) - } - - toUpdate := make(map[string]string, len(sourceAcls)) - for _, acl := range sourceAcls { - username := strings.Fields(acl)[1] - toUpdate[username] = acl + return nil, nil, fmt.Errorf("SyncAcls: error listing destination ACLs: %w", err) } var updated, deleted []string - for _, acl := range destinationAcls { - username := strings.Fields(acl)[1] - if currentAcl, found := toUpdate[username]; found { - if currentAcl == acl { - delete(toUpdate, username) - slog.Debug("ACL already in sync", "username", username) + // Batch commands + cmds := make([]redis.Cmder, 0, a.batchSize) + pipe := a.RedisClient.Pipeline() + + // Delete ACLs that are not in the source + for username := range destinationAclHashMap { + if _, found := sourceAclHashMap[username]; !found && username != "default" { + slog.Debug("Deleting ACL", "username", username) + cmd := pipe.Do(ctx, "ACL", "DELUSER", username) + cmds = append(cmds, cmd) + deleted = append(deleted, username) + if len(cmds) >= a.batchSize { + // Execute pipeline + if _, err = pipe.Exec(ctx); err != nil { + slog.Error("Failed to execute pipeline", "error", err) + return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err) + } + // Reset pipeline and cmds + pipe = a.RedisClient.Pipeline() + cmds = cmds[:0] } - continue } + } - slog.Debug("Deleting ACL", "username", username) - if err := a.RedisClient.Do(ctx, "ACL", "DELUSER", username).Err(); err != nil { - slog.Error("Failed to delete ACL", "username", username, "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error deleting acl: %w", err) + // Add or update ACLs from the source + for username, sourceHash := range sourceAclHashMap { + destHash, found := destinationAclHashMap[username] + if !found || destHash != sourceHash { + aclStr := sourceAclStrMap[username] + if aclStr == "" { + slog.Error("ACL string not found for user", "username", username) + continue + } + + args := []interface{}{"ACL", "SETUSER"} + fields := strings.Fields(aclStr) + // Skip the "user" keyword + for _, field := range fields[1:] { + args = append(args, field) + } + + cmd := pipe.Do(ctx, args...) + cmds = append(cmds, cmd) + updated = append(updated, username) + + if len(cmds) >= a.batchSize { + // Execute pipeline + if _, err = pipe.Exec(ctx); err != nil { + slog.Error("Failed to execute pipeline", "error", err) + return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err) + } + // Reset pipeline and cmds + pipe = a.RedisClient.Pipeline() + cmds = cmds[:0] + } } - deleted = append(deleted, username) } - for username, acl := range toUpdate { - slog.Debug("Syncing ACL", "username", username, "line", acl) - command := strings.Split(filterUser.ReplaceAllString(acl, "ACL SETUSER "), " ") - commandInterface := make([]interface{}, len(command)) - for i, s := range command { - commandInterface[i] = s + // Execute any remaining commands + if len(cmds) > 0 { + if _, err = pipe.Exec(ctx); err != nil { + slog.Error("Failed to execute pipeline", "error", err) + return nil, nil, fmt.Errorf("SyncAcls: error executing pipeline: %w", err) } - if err := a.RedisClient.Do(ctx, commandInterface...).Err(); err != nil { - slog.Error("Failed to set ACL", "username", username, "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error setting acl: %w", err) - } - updated = append(updated, username) } + // If aclFile is enabled, save and load the ACL file if a.aclFile { if err = saveAclFile(ctx, a.RedisClient); err != nil { slog.Error("Failed to save ACLs to aclFile", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error saving acls to aclFile: %w", err) + return nil, nil, fmt.Errorf("SyncAcls: error saving ACLs to aclFile: %w", err) } if err = loadAclFile(ctx, a.RedisClient); err != nil { slog.Error("Failed to load synced ACLs from aclFile", "error", err) - return nil, nil, fmt.Errorf("SyncAcls: error loading synced acls from aclFile: %w", err) + return nil, nil, fmt.Errorf("SyncAcls: error loading ACLs from aclFile: %w", err) } } - slog.Info("Synced ACLs", "added", updated, "deleted", deleted) + slog.Info("Synced ACLs", "updated", updated, "deleted", deleted) return updated, deleted, nil } -// Loop loops through the sync interval and syncs the acls -func (a *AclManager) Loop(ctx context.Context) error { +// Loop periodically syncs the ACLs at the given interval +func (a *AclManager) Loop(ctx context.Context, syncInterval time.Duration) error { slog.Debug("Entering Loop") defer slog.Debug("Exiting Loop") - ticker := time.NewTicker(viper.GetDuration("syncInterval") * time.Second) + ticker := time.NewTicker(syncInterval) defer ticker.Stop() for { select { case <-ctx.Done(): slog.Info("Context done, exiting Loop") - return nil + return ctx.Err() case <-ticker.C: function, err := a.CurrentFunction(ctx) if err != nil { - slog.Warn("Unable to check if it's a Primary", "error", err) + slog.Warn("Unable to determine node function", "error", err) continue } if function == Follower { @@ -330,6 +415,10 @@ func (a *AclManager) Loop(ctx context.Context) error { slog.Warn("Unable to find Primary", "error", err) continue } + if primary == nil { + slog.Warn("Primary node is nil") + continue + } added, deleted, err := a.SyncAcls(ctx, primary) if err != nil { slog.Warn("Unable to sync ACLs from Primary", "error", err) diff --git a/pkg/aclmanager/aclmanager_test.go b/pkg/aclmanager/aclmanager_test.go index 0b9058d..5806d4f 100644 --- a/pkg/aclmanager/aclmanager_test.go +++ b/pkg/aclmanager/aclmanager_test.go @@ -3,10 +3,8 @@ package aclmanager import ( "context" "fmt" - "github.com/spf13/viper" - "reflect" - "slices" "strings" + "sync" "testing" "time" @@ -14,88 +12,89 @@ import ( "github.com/stretchr/testify/assert" ) +// Sample outputs for ROLE command var ( - primaryOutput = ` -# Replication -role:master -connected_slaves:1 -slave0:ip=172.21.0.3,port=6379,state=online,offset=322,lag=0 -master_replid:1da7151855972ec8517bcae3d2c11454ff942d72 -master_replid2:0000000000000000000000000000000000000000 -master_repl_offset:322 -second_repl_offset:-1 -repl_backlog_active:1 -repl_backlog_size:1048576 -repl_backlog_first_byte_offset:1 -repl_backlog_histlen:322` - - followerOutput = ` -# Replication -role:slave -master_host:172.21.0.2 -master_port:6379 -master_link_status:up -master_last_io_seconds_ago:10 -master_sync_in_progress:0 -slave_repl_offset:434 -slave_priority:100 -slave_read_only:1 -connected_slaves:0 -master_replid:7d4b067fa70ad532ff7feff7bd7ff3cf27429b08 -master_replid2:0000000000000000000000000000000000000000 -master_repl_offset:434 -second_repl_offset:-1 -repl_backlog_active:1 -repl_backlog_size:1048576 -repl_backlog_first_byte_offset:1 -repl_backlog_histlen:434` + primaryRoleOutput = []interface{}{ + "master", + int64(0), + []interface{}{ + []interface{}{"172.21.0.3", int64(6379)}, + }, + } + + followerRoleOutput = []interface{}{ + "slave", + "172.21.0.2", + int64(6379), + "connected", + int64(1), + } ) func TestFindNodes(t *testing.T) { - // Sample master and slave output for testing - + t.Parallel() tests := []struct { - name string - mockResp string - want map[string]int - wantErr bool - nodes map[string]int + name string + mockRoleResp interface{} + expectedNodes map[string]int + wantErr bool + expectedErrMsg string }{ { - name: "parse master output", - mockResp: primaryOutput, - want: map[string]int{ + name: "parse primary role output", + mockRoleResp: primaryRoleOutput, + expectedNodes: map[string]int{ "172.21.0.3:6379": Follower, }, wantErr: false, }, { - name: "parse Follower output", - mockResp: followerOutput, - want: map[string]int{ + name: "parse follower role output", + mockRoleResp: followerRoleOutput, + expectedNodes: map[string]int{ "172.21.0.2:6379": Primary, }, wantErr: false, }, { - name: "error on replicationInfo", - mockResp: followerOutput, - want: nil, - wantErr: true, + name: "ROLE command returns empty result", + mockRoleResp: []interface{}{}, + expectedNodes: nil, + wantErr: true, + expectedErrMsg: "findNodes: ROLE command returned empty result", }, { - name: "ensure old nodes are removed", - mockResp: primaryOutput, - want: map[string]int{ - "172.21.0.3:6379": Follower, + name: "ROLE command first element not a string", + mockRoleResp: []interface{}{ + int64(12345), // Non-string type explicitly set to int64 + "some other data", }, - wantErr: false, - nodes: map[string]int{ - "192.168.0.1:6379": Follower, - "192.168.0.2:6379": Follower, - "192.168.0.3:6379": Follower, - "192.168.0.4:6379": Follower, + expectedNodes: nil, + wantErr: true, + expectedErrMsg: "findNodes: unexpected type for role: int64", + }, + { + name: "error on ROLE command", + mockRoleResp: nil, // Simulate Redis error + expectedNodes: nil, + wantErr: true, + expectedErrMsg: "findNodes: ROLE command failed", + }, + { + name: "unknown role type", + mockRoleResp: []interface{}{ + "sentinel", }, + expectedNodes: nil, + wantErr: true, + expectedErrMsg: "findNodes: unknown role type: sentinel", + }, + { + name: "unexpected type for roleInfo", + mockRoleResp: "invalid_type", // Not a slice + expectedNodes: nil, + wantErr: true, + expectedErrMsg: "findNodes: unexpected type for roleInfo: string", }, } @@ -103,85 +102,100 @@ func TestFindNodes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { redisClient, mock := redismock.NewClientMock() - // Mocking the response for the Info function - if tt.wantErr { - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) + // Setup the expected ROLE command response + if tt.wantErr && tt.mockRoleResp == nil { + // Simulate an error from Redis + mock.ExpectDo("ROLE").SetErr(fmt.Errorf("findNodes: ROLE command failed")) } else { - mock.ExpectInfo("replication").SetVal(tt.mockResp) + // Simulate a successful ROLE command with the provided response + mock.ExpectDo("ROLE").SetVal(tt.mockRoleResp) + } + + // Initialize AclManager with the mocked Redis client + aclManager := AclManager{ + RedisClient: redisClient, + nodes: make(map[string]int), + mu: sync.Mutex{}, } - aclManager := AclManager{RedisClient: redisClient, nodes: make(map[string]int)} ctx := context.Background() + // Execute the findNodes function err := aclManager.findNodes(ctx) - if (err != nil) != tt.wantErr { - t.Errorf("findNodes() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.name == "ensure old nodes are removed" { - for address, _ := range aclManager.nodes { - if _, ok := tt.nodes[address]; ok { - t.Errorf("findNodes() address %v shound not be found", address) - } - } - } - for address, function := range aclManager.nodes { - if wantFunction, ok := tt.want[address]; ok { - if wantFunction != function { - t.Errorf("findNodes() wanted function %v not found", function) - } - return + + // Assert whether an error was expected + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) } - t.Errorf("findNodes() wanted address %v not found", address) + } else { + assert.NoError(t, err) + aclManager.mu.Lock() + defer aclManager.mu.Unlock() + assert.Equal(t, tt.expectedNodes, aclManager.nodes) } + + // Ensure all expectations were met + assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestListAcls(t *testing.T) { + t.Parallel() tests := []struct { - name string - mockResp []interface{} - want []string - wantErr bool + name string + mockResp interface{} + expectedAcls []string + wantErr bool + expectedErrMsg string }{ { - name: "parse valid ACL list", + name: "valid ACL list", mockResp: []interface{}{ "user default on nopass ~* &* +@all", "user alice on >password ~keys:* -@all +get +set +del", }, - want: []string{ + expectedAcls: []string{ "user default on nopass ~* &* +@all", "user alice on >password ~keys:* -@all +get +set +del", }, wantErr: false, }, { - name: "empty ACL list", - mockResp: []interface{}{}, - want: []string(nil), - wantErr: false, - }, - { - name: "nil response from Redis", - mockResp: nil, - want: nil, - wantErr: false, + name: "empty ACL list", + mockResp: []interface{}{}, + expectedAcls: []string{}, + wantErr: false, }, { - name: "error from Redis client", - mockResp: nil, - want: nil, - wantErr: false, + name: "error from Redis client", + mockResp: nil, + wantErr: true, + expectedErrMsg: "error", }, { - name: "non-string elements in ACL list", + name: "aclList contains non-string element", mockResp: []interface{}{ "user default on nopass ~* &* +@all", - 123, // Invalid element + map[string]interface{}{ + "unexpected": "data", + }, }, - want: nil, - wantErr: true, + wantErr: true, + expectedErrMsg: "unexpected type for ACL: map[string]interface {}", + }, + { + name: "result is not []interface{}", + mockResp: "invalid_type", + wantErr: true, + expectedErrMsg: "unexpected result format: string", + }, + { + name: "nil response", + mockResp: nil, + wantErr: true, + expectedErrMsg: "error", }, } @@ -189,8 +203,14 @@ func TestListAcls(t *testing.T) { t.Run(tt.name, func(t *testing.T) { redisClient, mock := redismock.NewClientMock() - // Mocking the response for the ACL LIST command - mock.ExpectDo("ACL", "LIST").SetVal(tt.mockResp) + if tt.wantErr && tt.mockResp == nil && tt.expectedErrMsg == "error" { + // Simulate an error from Redis + mock.ExpectDo("ACL", "LIST").SetErr(fmt.Errorf("error")) + } else { + // Simulate a successful ACL LIST command with the provided response + mock.ExpectDo("ACL", "LIST").SetVal(tt.mockResp) + } + acls, err := listAcls(context.Background(), redisClient) if (err != nil) != tt.wantErr { @@ -198,115 +218,327 @@ func TestListAcls(t *testing.T) { return } - assert.Equal(t, tt.want, acls) + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedAcls, acls) + } + + // Ensure all expectations were met + assert.NoError(t, mock.ExpectationsWereMet()) }) } } -func TestListAcls_Error(t *testing.T) { - redisClient, mock := redismock.NewClientMock() +func TestListAndMapAcls(t *testing.T) { + t.Parallel() + tests := []struct { + name string + mockResp interface{} + expectedHashMap map[string]string + expectedStrMap map[string]string + wantErr bool + expectedErrMsg string + }{ + { + name: "valid ACL list", + mockResp: []interface{}{ + "user default on nopass ~* &* +@all", + "user alice on >password ~keys:* -@all +get +set +del", + }, + expectedHashMap: map[string]string{ + "default": hashString("user default on nopass ~* &* +@all"), + "alice": hashString("user alice on >password ~keys:* -@all +get +set +del"), + }, + expectedStrMap: map[string]string{ + "default": "user default on nopass ~* &* +@all", + "alice": "user alice on >password ~keys:* -@all +get +set +del", + }, + wantErr: false, + }, + { + name: "empty ACL list", + mockResp: []interface{}{}, + expectedHashMap: map[string]string{}, + expectedStrMap: map[string]string{}, + wantErr: false, + }, + { + name: "error from Redis client", + mockResp: nil, + wantErr: true, + expectedErrMsg: "error listing ACLs", + }, + { + name: "invalid ACL format", + mockResp: []interface{}{ + "invalid_acl", + "user alice on >password ~keys:* -@all +get +set +del", + }, + expectedHashMap: map[string]string{ + "alice": hashString("user alice on >password ~keys:* -@all +get +set +del"), + }, + expectedStrMap: map[string]string{ + "alice": "user alice on >password ~keys:* -@all +get +set +del", + }, + wantErr: false, + }, + { + name: "result is not []interface{}", + mockResp: "invalid_type", + wantErr: true, + expectedErrMsg: "unexpected result format", + }, + } - // Mocking the response for the ACL LIST command - mock.ExpectDo("ACL", "LIST").SetVal([]string{"user acl1", "user acl2"}) - _, err := listAcls(context.Background(), redisClient) - assert.Error(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redisClient, mock := redismock.NewClientMock() + + if tt.wantErr && tt.mockResp == nil { + mock.ExpectDo("ACL", "LIST").SetErr(fmt.Errorf("error")) + } else { + mock.ExpectDo("ACL", "LIST").SetVal(tt.mockResp) + } + + aclHashMap, aclStrMap, err := listAndMapAcls(context.Background(), redisClient) + + if (err != nil) != tt.wantErr { + t.Errorf("listAndMapAcls() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedHashMap, aclHashMap) + assert.Equal(t, tt.expectedStrMap, aclStrMap) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } } func TestSyncAcls(t *testing.T) { + t.Parallel() tests := []struct { - name string - sourceAcls []interface{} - destinationAcls []interface{} - expectedDeleted []string - expectedUpdated []string - listAclsError error - redisDoError error - saveAclError error - loadAclError error - wantSourceErr bool - wantDestinationErr bool - aclFile bool + name string + sourceAcls interface{} + destinationAcls interface{} + expectedDeleted []string + expectedUpdated []string + sourceListAclsErr error + destListAclsErr error + redisDoError error + saveAclError error + loadAclError error + wantErr bool + expectedErrMsg string + aclFile bool }{ { - name: "ACLs synced with deletions", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1", "user acl3"}, + name: "ACLs synced with deletions", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl2 on >password2 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl3 on >password3 ~* +@all", + }, expectedDeleted: []string{"acl3"}, expectedUpdated: []string{"acl2"}, + aclFile: false, + wantErr: false, }, { - name: "ACLs synced with differences", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1 something something", "user acl3"}, - expectedDeleted: []string{"acl3"}, + name: "ACLs synced with differences", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl2 on >password2 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password_different ~* +@all", + }, + expectedDeleted: []string{}, expectedUpdated: []string{"acl1", "acl2"}, + aclFile: false, + wantErr: false, }, { - name: "ACLs synced with Error om SETUSER", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1", "user acl3"}, - redisDoError: fmt.Errorf("DELUSER"), - wantDestinationErr: true, + name: "Error listing source ACLs", + sourceListAclsErr: fmt.Errorf("error listing source ACLs"), + wantErr: true, + expectedErrMsg: "SyncAcls: error listing source ACLs", + aclFile: false, }, { - name: "ACLs synced with Error on SETUSER", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1"}, - redisDoError: fmt.Errorf("SETUSER"), - wantSourceErr: false, - wantDestinationErr: true, + name: "Error listing destination ACLs", + sourceAcls: []interface{}{}, // Set to empty slice to prevent panic + destListAclsErr: fmt.Errorf("error listing destination ACLs"), + wantErr: true, + expectedErrMsg: "SyncAcls: error listing destination ACLs", + aclFile: false, }, { - name: "No ACLs to delete", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1", "user acl2"}, - expectedDeleted: nil, - wantSourceErr: false, + name: "Error deleting ACL", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl3 on >password3 ~* +@all", + }, + expectedDeleted: []string{"acl3"}, + expectedUpdated: []string{}, + redisDoError: fmt.Errorf("error deleting ACL"), // Simulate error only on deletion + wantErr: true, + expectedErrMsg: "SyncAcls: error executing pipeline", + aclFile: false, }, { - name: "Error listing source ACLs", - listAclsError: fmt.Errorf("error listing source ACLs"), - wantSourceErr: true, + name: "Error setting ACL", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl2 on >password2 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + expectedDeleted: []string{}, + expectedUpdated: []string{"acl2"}, + redisDoError: fmt.Errorf("error setting ACL"), + wantErr: true, + expectedErrMsg: "SyncAcls: error executing pipeline", + aclFile: false, }, { - name: "Error listing destination ACLs", - listAclsError: fmt.Errorf("error listing destination ACLs"), - wantDestinationErr: true, + name: "ACLs synced with ACL file enabled", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl2 on >password2 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl3 on >password3 ~* +@all", + }, + expectedDeleted: []string{"acl3"}, + expectedUpdated: []string{"acl2"}, + aclFile: true, + wantErr: false, }, { - name: "Invalid aclManagerPrimary", - listAclsError: fmt.Errorf("error listing destination ACLs"), + name: "Error saving ACL file", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl3 on >password3 ~* +@all", + }, + expectedDeleted: []string{"acl3"}, + expectedUpdated: []string{}, + aclFile: true, + saveAclError: fmt.Errorf("error saving ACL file"), + wantErr: true, + expectedErrMsg: "SyncAcls: error saving ACLs to aclFile", }, { - name: "ACLs synced with deletions, aclFile", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1", "user acl3"}, + name: "Error loading ACL file", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl2 on >password2 ~* +@all", // Added acl2 + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + "user acl3 on >password3 ~* +@all", + }, expectedDeleted: []string{"acl3"}, expectedUpdated: []string{"acl2"}, aclFile: true, + loadAclError: fmt.Errorf("error loading ACL file"), + wantErr: true, + expectedErrMsg: "SyncAcls: error loading ACLs from aclFile", + }, + { + name: "No ACLs to update", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + expectedDeleted: []string{}, + expectedUpdated: []string{}, + aclFile: false, + wantErr: false, }, { - name: "Error on save ACL file on primary, aclFile", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - saveAclError: fmt.Errorf("failed to save ACL on primary"), - aclFile: true, - wantSourceErr: true, + name: "No primary found", + wantErr: true, + expectedErrMsg: "no primary found", + aclFile: false, }, { - name: "Error on save ACL file on destination, aclFile", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1", "user acl3"}, - saveAclError: fmt.Errorf("failed to save ACL on destination"), - aclFile: true, - wantDestinationErr: true, + name: "Error: element in sourceAclList not string", + sourceAcls: []interface{}{ + 12345, // Invalid element + }, + destinationAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + wantErr: true, + expectedErrMsg: "unexpected type for ACL: int", + aclFile: false, }, { - name: "Error on load ACL file, aclFile", - sourceAcls: []interface{}{"user acl1", "user acl2"}, - destinationAcls: []interface{}{"user acl1", "user acl3"}, - loadAclError: fmt.Errorf("failed to load ACL"), - aclFile: true, - wantDestinationErr: true, + name: "Error: element in destinationAclList not string", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + destinationAcls: []interface{}{ + 12345, // Invalid element + }, + wantErr: true, + expectedErrMsg: "unexpected type for ACL: int", + aclFile: false, + }, + { + name: "Invalid ACL in sourceAcls", + sourceAcls: []interface{}{ + "invalid_acl", // Invalid ACL string + "user acl1 on >password1 ~* +@all", // Valid ACL + }, + destinationAcls: []interface{}{ + "user acl2 on >password2 ~* +@all", + }, + expectedDeleted: []string{"acl2"}, + expectedUpdated: []string{"acl1"}, + aclFile: false, + wantErr: false, + }, + { + name: "Invalid ACL in destinationAcls", + sourceAcls: []interface{}{ + "user acl1 on >password1 ~* +@all", + }, + destinationAcls: []interface{}{ + "invalid_acl", // Invalid ACL string + "user acl2 on >password2 ~* +@all", + }, + expectedDeleted: []string{"acl2"}, // acl2 should be deleted + expectedUpdated: []string{"acl1"}, + aclFile: false, + wantErr: false, }, } @@ -315,226 +547,229 @@ func TestSyncAcls(t *testing.T) { primaryClient, sourceMock := redismock.NewClientMock() followerClient, destMock := redismock.NewClientMock() - aclManagerPrimary := &AclManager{RedisClient: primaryClient, nodes: make(map[string]int)} - aclManagerFollower := &AclManager{RedisClient: followerClient, nodes: make(map[string]int)} + aclManagerPrimary := &AclManager{ + RedisClient: primaryClient, + nodes: make(map[string]int), + aclFile: tt.aclFile, + } + aclManagerFollower := &AclManager{ + RedisClient: followerClient, + nodes: make(map[string]int), + aclFile: tt.aclFile, + } - if tt.name == "Invalid aclManagerPrimary" { - aclManagerPrimary = nil - _, _, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) + // Handle "No primary found" separately + if tt.name == "No primary found" { + // No ACL LIST commands should be called + updated, deleted, err := aclManagerFollower.SyncAcls(context.Background(), nil) assert.Error(t, err) - assert.Equal(t, "no primary found", err.Error()) + assert.Nil(t, updated) + assert.Nil(t, deleted) + assert.Contains(t, err.Error(), tt.expectedErrMsg) return } - if tt.listAclsError != nil && tt.wantSourceErr { - sourceMock.ExpectDo("ACL", "LIST").SetErr(tt.listAclsError) - } else { - sourceMock.ExpectDo("ACL", "LIST").SetVal(tt.sourceAcls) - } + if tt.name == "Error: element in sourceAclList not string" { + // Set up source ACL LIST expectation + expectACLList(sourceMock, tt.sourceAcls, tt.sourceListAclsErr) - if tt.listAclsError != nil && tt.wantDestinationErr { - destMock.ExpectDo("ACL", "LIST").SetErr(tt.listAclsError) - } else { - destMock.ExpectDo("ACL", "LIST").SetVal(tt.destinationAcls) - if tt.expectedDeleted != nil { - for _, username := range tt.expectedDeleted { - if tt.wantDestinationErr && tt.redisDoError.Error() == "DELUSER" { - destMock.ExpectDo("ACL", "DELUSER", username).SetErr(tt.redisDoError) - continue - } - destMock.ExpectDo("ACL", "DELUSER", username).SetVal("OK") - } - } - if tt.expectedUpdated != nil { - for _, username := range tt.expectedUpdated { - if tt.wantDestinationErr && tt.redisDoError.Error() == "SETUSER" { - destMock.ExpectDo("ACL", "SETUSER", username).SetErr(tt.redisDoError) - continue - } - destMock.ExpectDo("ACL", "SETUSER", username).SetVal("OK") - } + // Run SyncAcls and assert the expected error + updated, deleted, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) + assert.Error(t, err) + assert.Nil(t, updated) + assert.Nil(t, deleted) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) } + + // Ensure all expectations were met + assert.NoError(t, sourceMock.ExpectationsWereMet()) + return } - if tt.listAclsError != nil && tt.wantDestinationErr { - destMock.ExpectDo("ACL", "LIST").SetErr(tt.listAclsError) - } else { - destMock.ExpectDo("ACL", "LIST").SetVal(tt.destinationAcls) + // Setup source ACL LIST expectation + expectACLList(sourceMock, tt.sourceAcls, tt.sourceListAclsErr) + + // Setup destination ACL LIST expectation only if no source ACL error + if tt.sourceListAclsErr == nil { + expectACLList(destMock, tt.destinationAcls, tt.destListAclsErr) } - if tt.aclFile { - if tt.saveAclError != nil { - sourceMock.ExpectDo("ACL", "SAVE").SetErr(tt.saveAclError) - if !tt.wantSourceErr { - destMock.ExpectDo("ACL", "SAVE").SetErr(tt.saveAclError) + // If there is an error during ACL listing, we expect SyncAcls to return early + if tt.sourceListAclsErr != nil || tt.destListAclsErr != nil { + // Run SyncAcls + updated, deleted, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, updated) + assert.Nil(t, deleted) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) } } else { - sourceMock.ExpectDo("ACL", "SAVE").SetVal("OK") - destMock.ExpectDo("ACL", "SAVE").SetVal("OK") + assert.NoError(t, err) + assert.ElementsMatch(t, tt.expectedUpdated, updated) + assert.ElementsMatch(t, tt.expectedDeleted, deleted) } - if tt.loadAclError != nil && !tt.wantSourceErr { - destMock.ExpectDo("ACL", "LOAD").SetErr(tt.loadAclError) - } else { - destMock.ExpectDo("ACL", "LOAD").SetVal("OK") - } + // Ensure all expectations were met + assert.NoError(t, sourceMock.ExpectationsWereMet()) + assert.NoError(t, destMock.ExpectationsWereMet()) + return } - added, deleted, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) - if err != nil { - if tt.wantSourceErr { - if tt.listAclsError != nil && !strings.Contains(err.Error(), tt.listAclsError.Error()) { - t.Errorf("mirrorAcls() error = %v, wantErr %v", err, tt.listAclsError) - } - if tt.redisDoError != nil && !strings.Contains(err.Error(), tt.redisDoError.Error()) { - t.Errorf("mirrorAcls() error = %v, wantErr %v", err, tt.redisDoError) - } - } - if tt.wantDestinationErr { - if tt.listAclsError != nil && !strings.Contains(err.Error(), tt.listAclsError.Error()) { - t.Errorf("mirrorAcls() error = %v, wantErr %v", err, tt.listAclsError) + // Setup expected ACL DELUSER and ACL SETUSER commands + for _, username := range tt.expectedDeleted { + expectACLDelUser(destMock, username, tt.redisDoError) + } + + for _, username := range tt.expectedUpdated { + // Find the ACL string for the username + var aclStr string + for _, acl := range tt.sourceAcls.([]interface{}) { + aclString, ok := acl.(string) + if !ok { + continue } - if tt.redisDoError != nil && !strings.Contains(err.Error(), tt.redisDoError.Error()) { - t.Errorf("mirrorAcls() error = %v, wantErr %v", err, tt.redisDoError) + if strings.Contains(aclString, "user "+username+" ") { + aclStr = aclString + break } } - if !tt.wantSourceErr && !tt.wantDestinationErr { - t.Errorf("mirrorAcls() error = %v, wantErr %v", err, tt.wantSourceErr) + + if aclStr == "" { + t.Fatalf("No ACL string found for user '%s'", username) } - } - slices.Sort(added) - slices.Sort(tt.expectedUpdated) - slices.Sort(deleted) - slices.Sort(tt.expectedDeleted) - if !reflect.DeepEqual(deleted, tt.expectedDeleted) { - t.Errorf("mirrorAcls() deleted = %v, expectedDeleted %v", deleted, tt.expectedDeleted) - } - if !reflect.DeepEqual(added, tt.expectedUpdated) { - t.Errorf("mirrorAcls() updated = %v, expectedUpdated %v", deleted, tt.expectedUpdated) - } - }) - } -} -func TestCurrentFunction(t *testing.T) { - tests := []struct { - name string - mockResp string - want int - wantErr bool - RedisExpectInfoError error - }{ - { - name: "parse Primary output", - mockResp: primaryOutput, - want: Primary, - wantErr: false, - }, - { - name: "parse Follower output", - mockResp: followerOutput, - want: Follower, - wantErr: false, - }, - { - name: "parse primary error", - mockResp: primaryOutput, - want: Unknown, - wantErr: true, - RedisExpectInfoError: fmt.Errorf("error"), - }, - } + expectACLSetUser(destMock, aclStr, tt.redisDoError) + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - redisClient, mock := redismock.NewClientMock() + // Setup ACL SAVE and LOAD on destMock if aclFile is enabled + if tt.aclFile { + if tt.saveAclError != nil { + expectACLSave(destMock, tt.saveAclError) + } else { + expectACLSave(destMock, nil) + if tt.loadAclError != nil { + expectACLLoad(destMock, tt.loadAclError) + } else { + expectACLLoad(destMock, nil) + } + } + } - // Mocking the response for the Info function + // Run SyncAcls + updated, deleted, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) if tt.wantErr { - mock.ExpectInfo("replication").SetErr(tt.RedisExpectInfoError) - } else { - mock.ExpectInfo("replication").SetVal(tt.mockResp) - } - aclManager := AclManager{RedisClient: redisClient, nodes: make(map[string]int)} - ctx := context.Background() - nodes, err := aclManager.CurrentFunction(ctx) - if (err != nil) != tt.wantErr { - if !strings.Contains(err.Error(), tt.RedisExpectInfoError.Error()) { - t.Errorf("findNodes() error = %v, wantErr %v", err, tt.wantErr) - return + assert.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) } + } else { + assert.NoError(t, err) + assert.ElementsMatch(t, tt.expectedUpdated, updated) + assert.ElementsMatch(t, tt.expectedDeleted, deleted) } - assert.Equal(t, tt.want, nodes) + // Ensure all expectations were met + assert.NoError(t, sourceMock.ExpectationsWereMet()) + assert.NoError(t, destMock.ExpectationsWereMet()) }) } } -func TestNewAclManager(t *testing.T) { - tests := []struct { - name string - want *AclManager - }{ - { - name: "create AclManager", - want: &AclManager{ - Addr: "localhost:6379", - Password: "password", - Username: "username", - }, - }, +// expectACLList sets up the expectation for the ACL LIST command. +// If err is not nil, it simulates an error; otherwise, it returns the provided ACL list. +func expectACLList(mock redismock.ClientMock, acls interface{}, err error) { + if err != nil { + mock.ExpectDo("ACL", "LIST").SetErr(err) + } else { + if acls == nil { + acls = []interface{}{} + } + mock.ExpectDo("ACL", "LIST").SetVal(acls) } +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := New(tt.want.Addr, tt.want.Username, tt.want.Password, false) - assert.Equal(t, tt.want.Addr, got.Addr) - assert.Equal(t, tt.want.Username, got.Username) - assert.Equal(t, tt.want.Password, got.Password) - assert.NotNil(t, got.RedisClient) - }) +// expectACLDelUser sets up the expectation for the ACL DELUSER command. +// If err is not nil, it simulates an error; otherwise, it returns "OK". +func expectACLDelUser(mock redismock.ClientMock, username string, err error) { + if err != nil { + mock.ExpectDo("ACL", "DELUSER", username).SetErr(err) + } else { + mock.ExpectDo("ACL", "DELUSER", username).SetVal("OK") } } -func TestCurrentFunction_Error(t *testing.T) { - redisClient, mock := redismock.NewClientMock() +// expectACLSetUser sets up the expectation for the ACL SETUSER command. +// If err is not nil, it simulates an error; otherwise, it returns "OK". +func expectACLSetUser(mock redismock.ClientMock, aclStr string, err error) { + fields := strings.Fields(aclStr) + args := []interface{}{"ACL", "SETUSER"} + for _, field := range fields[1:] { // Skip the "user" keyword + args = append(args, field) + } - // Mocking the response for the Info function - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - aclManager := AclManager{RedisClient: redisClient} - ctx := context.Background() + if err != nil { + mock.ExpectDo(args...).SetErr(err) + } else { + mock.ExpectDo(args...).SetVal("OK") + } +} - _, err := aclManager.CurrentFunction(ctx) - assert.Error(t, err) +// expectACLSave sets up the expectation for the ACL SAVE command. +// If err is not nil, it simulates an error; otherwise, it returns "OK". +func expectACLSave(mock redismock.ClientMock, err error) { + if err != nil { + mock.ExpectDo("ACL", "SAVE").SetErr(err) + } else { + mock.ExpectDo("ACL", "SAVE").SetVal("OK") + } +} + +// expectACLLoad sets up the expectation for the ACL LOAD command. +// If err is not nil, it simulates an error; otherwise, it returns "OK". +func expectACLLoad(mock redismock.ClientMock, err error) { + if err != nil { + mock.ExpectDo("ACL", "LOAD").SetErr(err) + } else { + mock.ExpectDo("ACL", "LOAD").SetVal("OK") + } } -func TestAclManager_Primary(t *testing.T) { +func TestCurrentFunction(t *testing.T) { + t.Parallel() tests := []struct { - name string - mockResp string - want string - wantErr bool + name string + mockRoleResp interface{} + expectedFunc int + wantErr bool }{ { - name: "parse master output", - mockResp: primaryOutput, - wantErr: false, + name: "Primary node", + mockRoleResp: primaryRoleOutput, + expectedFunc: Primary, + wantErr: false, }, { - name: "parse Follower output", - mockResp: followerOutput, - want: "172.21.0.2:6379", - wantErr: false, + name: "Follower node", + mockRoleResp: followerRoleOutput, + expectedFunc: Follower, + wantErr: false, }, { - name: "error on replicationInfo", - mockResp: followerOutput, - wantErr: true, + name: "Error on ROLE command", + mockRoleResp: nil, + wantErr: true, + expectedFunc: Unknown, }, { - name: "username and password", - mockResp: followerOutput, - wantErr: false, + name: "Unknown role type", + mockRoleResp: []interface{}{ + "sentinel", + }, + wantErr: true, + expectedFunc: Unknown, }, } @@ -542,164 +777,134 @@ func TestAclManager_Primary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { redisClient, mock := redismock.NewClientMock() - // Mocking the response for the Info function if tt.wantErr { - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) + mock.ExpectDo("ROLE").SetErr(fmt.Errorf("error")) } else { - mock.ExpectInfo("replication").SetVal(tt.mockResp) + mock.ExpectDo("ROLE").SetVal(tt.mockRoleResp) } - aclManager := AclManager{RedisClient: redisClient, Username: "username", Password: "password", nodes: make(map[string]int)} - ctx := context.Background() - primary, err := aclManager.Primary(ctx) - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, primary) - return + aclManager := AclManager{ + RedisClient: redisClient, + nodes: make(map[string]int), + mu: sync.Mutex{}, } + ctx := context.Background() - if tt.name == "username and password" { - assert.Equal(t, aclManager.Username, primary.Username) - assert.Equal(t, aclManager.Password, primary.Password) + function, err := aclManager.CurrentFunction(ctx) + if (err != nil) != tt.wantErr { + t.Errorf("CurrentFunction() error = %v, wantErr %v", err, tt.wantErr) return } - assert.NoError(t, err) - if tt.want == "" { - assert.Nil(t, primary) - return - } - assert.NotNil(t, primary) - assert.Equal(t, tt.want, primary.Addr) + assert.Equal(t, tt.expectedFunc, function) }) } } -func TestAclManager_Loop(t *testing.T) { - viper.Set("syncInterval", 4) +func TestPrimary(t *testing.T) { tests := []struct { - name string - aclManager *AclManager - wantErr bool - expectError error + name string + mockRoleResp interface{} + expectedAddr string + wantErr bool }{ { - name: "Primary node", - aclManager: &AclManager{ - Addr: "localhost:6379", - Password: "password", - Username: "username", - aclFile: false, - nodes: make(map[string]int), - }, - wantErr: false, + name: "Primary node returns nil", + mockRoleResp: primaryRoleOutput, + expectedAddr: "", + wantErr: false, }, { - name: "Primary node with error", - aclManager: &AclManager{ - Addr: "localhost:6379", - Password: "password", - Username: "username", - aclFile: false, - nodes: make(map[string]int), - }, - wantErr: true, - expectError: fmt.Errorf("unable to find Primary"), + name: "Follower node returns primary address", + mockRoleResp: followerRoleOutput, + expectedAddr: "172.21.0.2:6379", + wantErr: false, }, { - name: "follower node with error", - aclManager: &AclManager{ - Addr: "localhost:6379", - Password: "password", - Username: "username", - aclFile: false, - nodes: make(map[string]int), - }, - wantErr: true, - expectError: fmt.Errorf("unable to check if it's a Primary"), - }, - { - name: "follower node", - aclManager: &AclManager{ - Addr: "localhost:6379", - Password: "password", - Username: "username", - aclFile: false, - nodes: make(map[string]int), - }, - wantErr: false, - expectError: nil, + name: "Error on ROLE command", + mockRoleResp: nil, + wantErr: true, + expectedAddr: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { redisClient, mock := redismock.NewClientMock() - tt.aclManager.RedisClient = redisClient if tt.wantErr { - if tt.name == "Primary node with error" { - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - } else { - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetErr(fmt.Errorf("error")) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectInfo("replication").SetVal(followerOutput) - } + mock.ExpectDo("ROLE").SetErr(fmt.Errorf("error")) } else { - if tt.name == "Primary node" { - mock.ExpectInfo("replication").SetVal(primaryOutput) - mock.ExpectInfo("replication").SetVal(primaryOutput) - mock.ExpectInfo("replication").SetVal(primaryOutput) - mock.ExpectInfo("replication").SetVal(primaryOutput) - mock.ExpectInfo("replication").SetVal(primaryOutput) - mock.ExpectInfo("replication").SetVal(primaryOutput) - mock.ExpectInfo("replication").SetVal(primaryOutput) - } + mock.ExpectDo("ROLE").SetVal(tt.mockRoleResp) } - // Set up a cancellable context to control the loop - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Run Loop in a separate goroutine - done := make(chan error, 1) - go func() { - done <- tt.aclManager.Loop(ctx) - }() - - time.Sleep(time.Second * 10) - - // Cancel the context to stop the loop - cancel() + aclManager := AclManager{ + RedisClient: redisClient, + nodes: make(map[string]int), + mu: sync.Mutex{}, + } + ctx := context.Background() - // Check for errors - err := <-done - if err != nil { - if !tt.wantErr { - t.Errorf("Expected no error, got: %v", err) - } + primary, err := aclManager.Primary(ctx) + if (err != nil) != tt.wantErr { + t.Errorf("Primary() error = %v, wantErr %v", err, tt.wantErr) + return + } - if !strings.Contains(err.Error(), tt.expectError.Error()) { - t.Errorf("Expected error: %v, got: %v", tt.expectError, err) - } + if tt.expectedAddr == "" { + assert.Nil(t, primary) + } else { + assert.NotNil(t, primary) + assert.Equal(t, tt.expectedAddr, primary.Addr) } }) } } +func TestLoop(t *testing.T) { + t.Parallel() + redisClient, mock := redismock.NewClientMock() + aclManager := &AclManager{ + RedisClient: redisClient, + nodes: make(map[string]int), + mu: sync.Mutex{}, + } + + // Mock the ROLE command to return follower output + mock.ExpectDo("ROLE").SetVal(followerRoleOutput) + mock.ExpectDo("ROLE").SetVal(followerRoleOutput) + // Mock listing ACLs + mock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ + "user default on nopass ~* &* +@all", + }) + + // Simulate SyncAcls + mock.ExpectDo("ROLE").SetVal(followerRoleOutput) + mock.ExpectDo("ROLE").SetVal(primaryRoleOutput) + mock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ + "user default on nopass ~* &* +@all", + }) + mock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ + "user default on nopass ~* &* +@all", + }) + + // Set up a cancellable context to control the loop + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Run Loop in a separate goroutine + go func() { + err := aclManager.Loop(ctx, 2*time.Second) + if err != nil && !strings.Contains(err.Error(), "context canceled") { + t.Errorf("Loop() error = %v", err) + } + }() + + // Let it run for a short time + time.Sleep(5 * time.Second) + cancel() +} + func TestClose(t *testing.T) { redisClient, _ := redismock.NewClientMock() aclManager := AclManager{RedisClient: redisClient} @@ -707,9 +912,10 @@ func TestClose(t *testing.T) { assert.NoError(t, err) } -func TestClosePanic(t *testing.T) { +func TestClose_NilClient(t *testing.T) { aclManager := AclManager{RedisClient: nil} - assert.Panics(t, func() { aclManager.Close() }) + err := aclManager.Close() + assert.Error(t, err) } func TestSaveAclFile(t *testing.T) { @@ -746,19 +952,12 @@ func TestSaveAclFile(t *testing.T) { if err != nil && tt.wantErr && !strings.HasSuffix(err.Error(), tt.err.Error()) { t.Errorf("saveAclFile() got unexpected error = %v, want %v", err, tt.err) } - - assertExpectations(t, mock) }) } } -func assertExpectations(t *testing.T, mock redismock.ClientMock) { - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unmet expectations: %s", err) - } -} - func TestLoadAclFile(t *testing.T) { + t.Parallel() tests := []struct { name string wantErr bool @@ -792,176 +991,59 @@ func TestLoadAclFile(t *testing.T) { if err != nil && tt.wantErr && !strings.HasSuffix(err.Error(), tt.err.Error()) { t.Errorf("loadAclFile() got unexpected error = %v, want %v", err, tt.err) } - - assertExpectations(t, mock) }) } } -func TestFindNodes_LargeCluster(t *testing.T) { - mockResp := generateLargeClusterOutput(1000) // Generates a mock output for 1000 nodes - redisClient, mock := redismock.NewClientMock() - mock.ExpectInfo("replication").SetVal(mockResp) - - aclManager := AclManager{RedisClient: redisClient, nodes: make(map[string]int)} - ctx := context.Background() - - err := aclManager.findNodes(ctx) - assert.NoError(t, err) - assert.Equal(t, 1000, len(aclManager.nodes)) -} - -func TestLoop_ShortInterval(t *testing.T) { - viper.Set("syncInterval", 1) // Set a very short sync interval for testing - redisClient, mock := redismock.NewClientMock() - - aclManager := &AclManager{ - Addr: "localhost:6379", - Password: "password", - Username: "username", - RedisClient: redisClient, - nodes: make(map[string]int), +func TestHashString(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty string", + input: "", + expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }, + { + name: "non-empty string", + input: "hello world", + expected: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + }, } - mock.ExpectInfo("replication").SetVal(followerOutput) - mock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user default on nopass ~* &* +@all", - }) - - // Set up a cancellable context to control the loop - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Run Loop in a separate goroutine - done := make(chan error, 1) - go func() { - done <- aclManager.Loop(ctx) - }() - - time.Sleep(time.Second * 5) // Run the loop for a few seconds - - // Cancel the context to stop the loop - cancel() - - // Check for errors - err := <-done - assert.NoError(t, err) -} - -func generateLargeClusterOutput(nodeCount int) string { - var sb strings.Builder - sb.WriteString("# Replication\nrole:master\nconnected_slaves:" + fmt.Sprint(nodeCount) + "\n") - for i := 0; i < nodeCount; i++ { - sb.WriteString(fmt.Sprintf("slave%d:ip=172.21.0.%d,port=6379,state=online,offset=322,lag=0\n", i, i+3)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash := hashString(tt.input) + assert.Equal(t, tt.expected, hash) + }) } - return sb.String() -} - -func TestSyncAcls_ACLFileEnabled(t *testing.T) { - primaryClient, primaryMock := redismock.NewClientMock() - followerClient, followerMock := redismock.NewClientMock() - - aclManagerPrimary := &AclManager{RedisClient: primaryClient, nodes: make(map[string]int), aclFile: true} - aclManagerFollower := &AclManager{RedisClient: followerClient, nodes: make(map[string]int), aclFile: true} - - primaryMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - followerMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl3", - }) - followerMock.ExpectDo("ACL", "DELUSER", "acl3").SetVal("OK") - followerMock.ExpectDo("ACL", "SETUSER", "acl2").SetVal("OK") - - primaryMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "LOAD").SetVal("OK") - - updated, deleted, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"acl2"}, updated) - assert.ElementsMatch(t, []string{"acl3"}, deleted) -} - -func TestSyncAcls_SaveACLFileError(t *testing.T) { - primaryClient, primaryMock := redismock.NewClientMock() - followerClient, followerMock := redismock.NewClientMock() - - aclManagerPrimary := &AclManager{RedisClient: primaryClient, nodes: make(map[string]int), aclFile: true} - aclManagerFollower := &AclManager{RedisClient: followerClient, nodes: make(map[string]int), aclFile: true} - - primaryMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - followerMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl3", - }) - - primaryMock.ExpectDo("ACL", "SAVE").SetErr(fmt.Errorf("failed to save ACL on primary")) - - _, _, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to save ACL on primary") -} - -func TestSyncAcls_LoadACLFileError(t *testing.T) { - primaryClient, primaryMock := redismock.NewClientMock() - followerClient, followerMock := redismock.NewClientMock() - - aclManagerPrimary := &AclManager{RedisClient: primaryClient, nodes: make(map[string]int), aclFile: true} - aclManagerFollower := &AclManager{RedisClient: followerClient, nodes: make(map[string]int), aclFile: true} - - primaryMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - followerMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - - primaryMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "LOAD").SetErr(fmt.Errorf("failed to load ACL on follower")) - - _, _, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to load ACL on follower") } -func TestSyncAcls_ACLFileSync(t *testing.T) { - primaryClient, primaryMock := redismock.NewClientMock() - followerClient, followerMock := redismock.NewClientMock() - - aclManagerPrimary := &AclManager{RedisClient: primaryClient, nodes: make(map[string]int), aclFile: true} - aclManagerFollower := &AclManager{RedisClient: followerClient, nodes: make(map[string]int), aclFile: true} - - primaryMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - followerMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", "user acl3", - }) - followerMock.ExpectDo("ACL", "DELUSER", "acl3").SetVal("OK") - - primaryMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "LOAD").SetVal("OK") - - _, _, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) - assert.NoError(t, err) - - primaryMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - followerMock.ExpectDo("ACL", "LIST").SetVal([]interface{}{ - "user acl1", "user acl2", - }) - - primaryMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "SAVE").SetVal("OK") - followerMock.ExpectDo("ACL", "LOAD").SetVal("OK") +func TestSetBatchSize(t *testing.T) { + tests := []struct { + name string + batchSize int + expectedSize int + }{ + { + name: "default batch size", + batchSize: 0, + expectedSize: 0, + }, + { + name: "custom batch size", + batchSize: 10, + expectedSize: 10, + }, + } - updated, deleted, err := aclManagerFollower.SyncAcls(context.Background(), aclManagerPrimary) - assert.NoError(t, err) - assert.Empty(t, updated) - assert.Empty(t, deleted) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + aclManager := &AclManager{} + aclManager.SetBatchSize(tt.batchSize) + assert.Equal(t, tt.expectedSize, aclManager.batchSize) + }) + } }