From 12a0fef93fc4e78747b99d399b0c457e50c710b9 Mon Sep 17 00:00:00 2001 From: rishav vajpayee <46602331+rishavvajpayee@users.noreply.github.com> Date: Fri, 4 Oct 2024 20:52:16 +0530 Subject: [PATCH] #733 & #894 : integration and stress tests for ratelimiter || added CORS to server (#13) --- .gitignore | 1 + config/config.go | 59 ++++++++++-- go.mod | 10 +- go.sum | 18 +++- internal/middleware/cors.go | 34 +++++++ internal/middleware/ratelimiter.go | 95 +++++++++++++++---- internal/server/httpServer.go | 15 +-- internal/tests/dbmocks/mock_dicedb.go | 59 ++++++++++++ .../ratelimiter_integration_test.go | 42 ++++++++ .../tests/stress/ratelimiter_stress_test.go | 47 +++++++++ pkg/util/helpers.go | 22 +++++ 11 files changed, 362 insertions(+), 40 deletions(-) create mode 100644 internal/middleware/cors.go create mode 100644 internal/tests/dbmocks/mock_dicedb.go create mode 100644 internal/tests/integration/ratelimiter_integration_test.go create mode 100644 internal/tests/stress/ratelimiter_stress_test.go diff --git a/.gitignore b/.gitignore index 814f196..b1ca7cd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ .vscode/ .env /playground-mono +__debug_bin* \ No newline at end of file diff --git a/config/config.go b/config/config.go index a9991de..73c1f01 100644 --- a/config/config.go +++ b/config/config.go @@ -1,25 +1,36 @@ package config import ( + "fmt" "os" "strconv" + "strings" + + "github.com/joho/godotenv" ) // Config holds the application configuration type Config struct { - DiceAddr string - ServerPort string - RequestLimit int // Field for the request limit - RequestWindow int // Field for the time window in seconds + DiceAddr string + ServerPort string + RequestLimit int64 // Field for the request limit + RequestWindow float64 // Field for the time window in float64 + AllowedOrigins []string // Field for the allowed origins } // LoadConfig loads the application configuration from environment variables or defaults func LoadConfig() *Config { + err := godotenv.Load() + if err != nil { + fmt.Println("Warning: .env file not found, falling back to system environment variables.") + } + return &Config{ - DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address - ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port - RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit - RequestWindow: getEnvInt("REQUEST_WINDOW", 60), // Default request window in seconds + DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address + ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port + RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit + RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64 + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins } } @@ -32,11 +43,39 @@ func getEnv(key, fallback string) string { } // getEnvInt retrieves an environment variable as an integer or returns a default value -func getEnvInt(key string, fallback int) int { +func getEnvInt(key string, fallback int) int64 { if value, exists := os.LookupEnv(key); exists { if intValue, err := strconv.Atoi(value); err == nil { - return intValue + return int64(intValue) + } + } + return int64(fallback) +} + +// added for miliseconds request window controls +func getEnvFloat64(key string, fallback float64) float64 { + if value, exists := os.LookupEnv(key); exists { + if floatValue, err := strconv.ParseFloat(value, 64); err == nil { + return floatValue } } return fallback } + +func getEnvArray(key string, fallback []string) []string { + if value, exists := os.LookupEnv(key); exists { + if arrayValue := splitString(value); len(arrayValue) > 0 { + return arrayValue + } + } + return fallback +} + +// splitString splits a string by comma and returns a slice of strings +func splitString(s string) []string { + var array []string + for _, v := range strings.Split(s, ",") { + array = append(array, strings.TrimSpace(v)) + } + return array +} diff --git a/go.mod b/go.mod index 26906be..5502661 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,16 @@ module server go 1.22.5 +require ( + github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 + github.com/joho/godotenv v1.5.1 + github.com/stretchr/testify v1.9.0 +) + require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index af5b6c5..36368eb 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,22 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 h1:Cqyj9WCtoobN6++bFbDSe27q94SPwJD9Z0wmu+SDRuk= github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831/go.mod h1:8+VZrr14c2LW8fW4tWZ8Bv3P2lfvlg+PpsSn5cWWuiQ= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 0000000..63bc5d9 --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "net/http" + "server/config" +) + +func enableCors(w http.ResponseWriter, r *http.Request) { + configValue := config.LoadConfig() + allAllowedOrigins := configValue.AllowedOrigins + origin := r.Header.Get("Origin") + allowed := false + for _, allowedOrigin := range allAllowedOrigins { + if origin == allowedOrigin || allowedOrigin == "*" || origin == "" { + allowed = true + break + } + } + if !allowed { + http.Error(w, "CORS: origin not allowed", http.StatusForbidden) + return + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length") + + if r.Method == http.MethodOptions { + w.Header().Set("Access-Control-Max-Age", "86400") + w.WriteHeader(http.StatusOK) + return + } + w.Header().Set("Content-Type", "application/json") +} diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index 293dd9c..c8da643 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "server/internal/db" + mock "server/internal/tests/dbmocks" "strconv" "strings" "time" @@ -14,28 +15,14 @@ import ( dice "github.com/dicedb/go-dice" ) -// TODO: Look at this later -func enableCors(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") -} - // RateLimiter middleware to limit requests based on a specified limit and duration -func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.Handler { +func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Enable CORS for requests + enableCors(w, r) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Set CORS headers - enableCors(w) - - // Handle OPTIONS preflight request - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - // Skip rate limiting for non-command endpoints if !strings.Contains(r.URL.Path, "/cli/") { next.ServeHTTP(w, r) @@ -56,9 +43,9 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H } // Initialize request count - requestCount := 0 + requestCount := int64(0) if val != "" { - requestCount, err = strconv.Atoi(val) + requestCount, err = strconv.ParseInt(val, 10, 64) if err != nil { slog.Error("Error converting request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -67,7 +54,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H } // Check if the request count exceeds the limit - if requestCount >= limit { + if requestCount > limit { slog.Warn("Request limit exceeded", "count", requestCount) http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) return @@ -94,3 +81,71 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H next.ServeHTTP(w, r) }) } + +func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Enable CORS for requests + enableCors(w, r) + + // Set a request context with a timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Only apply rate limiting for specific paths (e.g., "/cli/") + if !strings.Contains(r.URL.Path, "/cli/") { + next.ServeHTTP(w, r) + return + } + + // Generate the rate limiting key based on the current window + currentWindow := time.Now().Unix() / int64(window) + key := fmt.Sprintf("request_count:%d", currentWindow) + slog.Info("Created rate limiter key", slog.Any("key", key)) + + // Get the current request count for this window from the mock DB + val, err := client.Get(ctx, key) + if err != nil { + slog.Error("Error fetching request count", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 + if val != "" { + requestCount, err = strconv.ParseInt(val, 10, 64) + if err != nil { + slog.Error("Error converting request count", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + } + + // Check if the request limit has been exceeded + if requestCount >= limit { + slog.Warn("Request limit exceeded", "count", requestCount) + http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) + return + } + + // Increment the request count in the mock DB + requestCount, err = client.Incr(ctx, key) + if err != nil { + slog.Error("Error incrementing request count", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Set expiration for the key if it's the first request in the window + if requestCount == 1 { + err = client.Expire(ctx, key, time.Duration(window)*time.Second) + if err != nil { + slog.Error("Error setting expiry for request count", "error", err) + } + } + + // Log the successful request and pass control to the next handler + slog.Info("Request processed", "count", requestCount) + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 68e4c86..2e8e1a8 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -47,7 +47,7 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { cim.rateLimiter(w, r, cim.mux) } -func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit, window int) *HTTPServer { +func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit int64, window float64) *HTTPServer { handlerMux := &HandlerMux{ mux: mux, rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) { @@ -97,26 +97,27 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, "Error parsing HTTP request", http.StatusBadRequest) + http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest) return } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) + http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest) return } - if _, ok := resp.(string); !ok { - log.Println("Error marshaling response", "error", err) + respStr, ok := resp.(string) + if !ok { + log.Println("Error: response is not a string", "error", err) http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) return } - httpResponse := HTTPResponse{Data: resp.(string)} + httpResponse := HTTPResponse{Data: respStr} responseJSON, err := json.Marshal(httpResponse) if err != nil { - log.Println("Error marshaling response", "error", err) + log.Println("Error marshaling response to JSON", "error", err) http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) return } diff --git a/internal/tests/dbmocks/mock_dicedb.go b/internal/tests/dbmocks/mock_dicedb.go new file mode 100644 index 0000000..97deee2 --- /dev/null +++ b/internal/tests/dbmocks/mock_dicedb.go @@ -0,0 +1,59 @@ +package db + +import ( + "context" + "fmt" + "sync" + "time" +) + +type DiceDBMock struct { + data map[string]string + mutex sync.Mutex +} + +func NewDiceDBMock() *DiceDBMock { + return &DiceDBMock{ + data: make(map[string]string), + } +} + +func (db *DiceDBMock) Get(ctx context.Context, key string) (string, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + + val, exists := db.data[key] + if !exists { + return "", nil + } + return val, nil +} + +func (db *DiceDBMock) Set(ctx context.Context, key, value string, expiration time.Duration) error { + db.mutex.Lock() + defer db.mutex.Unlock() + + db.data[key] = value + return nil +} + +func (db *DiceDBMock) Incr(ctx context.Context, key string) (int64, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + + val, exists := db.data[key] + var count int64 + if exists { + if _, err := fmt.Sscanf(val, "%d", &count); err != nil { + return 0, fmt.Errorf("error parsing value for key %s: %w", key, err) + } + } + + count++ + db.data[key] = fmt.Sprintf("%d", count) + return count, nil +} + +func (db *DiceDBMock) Expire(ctx context.Context, key string, expiration time.Duration) error { + return nil +} diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go new file mode 100644 index 0000000..421dafb --- /dev/null +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -0,0 +1,42 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + config "server/config" + util "server/pkg/util" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRateLimiterWithinLimit(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimit + window := configValue.RequestWindow + + w, r, rateLimiter := util.SetupRateLimiter(limit, window) + + for i := int64(0); i < limit; i++ { + rateLimiter.ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + } +} + +func TestRateLimiterExceedsLimit(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimit + window := configValue.RequestWindow + + w, r, rateLimiter := util.SetupRateLimiter(limit, window) + + for i := int64(0); i < limit; i++ { + rateLimiter.ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + } + + w = httptest.NewRecorder() + rateLimiter.ServeHTTP(w, r) + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Contains(t, w.Body.String(), "429 - Too Many Requests") +} diff --git a/internal/tests/stress/ratelimiter_stress_test.go b/internal/tests/stress/ratelimiter_stress_test.go new file mode 100644 index 0000000..7e37054 --- /dev/null +++ b/internal/tests/stress/ratelimiter_stress_test.go @@ -0,0 +1,47 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/config" + util "server/pkg/util" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimiterUnderStress(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimit + window := configValue.RequestWindow + + _, r, rateLimiter := util.SetupRateLimiter(limit, window) + + var wg sync.WaitGroup + var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit + successCount := int64(0) + failCount := int64(0) + var mu sync.Mutex + + for i := int64(0); i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + + time.Sleep(10 * time.Millisecond) + rateLimiter.ServeHTTP(rec, r) + mu.Lock() + if rec.Code == http.StatusOK { + successCount++ + } else if rec.Code == http.StatusTooManyRequests { + failCount++ + } + mu.Unlock() + }() + } + wg.Wait() + require.Equal(t, limit, successCount, "Should succeed for exactly limit requests") +} diff --git a/pkg/util/helpers.go b/pkg/util/helpers.go index ce826f8..c86cef0 100644 --- a/pkg/util/helpers.go +++ b/pkg/util/helpers.go @@ -5,8 +5,12 @@ import ( "errors" "fmt" "io" + "log" "net/http" + "net/http/httptest" "server/internal/cmds" + "server/internal/middleware" + db "server/internal/tests/dbmocks" "strings" ) @@ -161,3 +165,21 @@ func JSONResponse(w http.ResponseWriter, status int, data interface{}) { http.Error(w, err.Error(), http.StatusInternalServerError) } } + +func MockHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("OK")); err != nil { + log.Fatalf("Failed to write response: %v", err) + } +} + +func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { + mockClient := db.NewDiceDBMock() + + r := httptest.NewRequest("GET", "/cli/somecommand", http.NoBody) + w := httptest.NewRecorder() + + rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) + + return w, r, rateLimiter +}