From 12a0fef93fc4e78747b99d399b0c457e50c710b9 Mon Sep 17 00:00:00 2001 From: rishav vajpayee <46602331+rishavvajpayee@users.noreply.github.com> Date: Fri, 4 Oct 2024 20:52:16 +0530 Subject: [PATCH 1/6] #733 & #894 : integration and stress tests for ratelimiter || added CORS to server (#13) --- .gitignore | 1 + config/config.go | 59 ++++++++++-- go.mod | 10 +- go.sum | 18 +++- internal/middleware/cors.go | 34 +++++++ internal/middleware/ratelimiter.go | 95 +++++++++++++++---- internal/server/httpServer.go | 15 +-- internal/tests/dbmocks/mock_dicedb.go | 59 ++++++++++++ .../ratelimiter_integration_test.go | 42 ++++++++ .../tests/stress/ratelimiter_stress_test.go | 47 +++++++++ pkg/util/helpers.go | 22 +++++ 11 files changed, 362 insertions(+), 40 deletions(-) create mode 100644 internal/middleware/cors.go create mode 100644 internal/tests/dbmocks/mock_dicedb.go create mode 100644 internal/tests/integration/ratelimiter_integration_test.go create mode 100644 internal/tests/stress/ratelimiter_stress_test.go diff --git a/.gitignore b/.gitignore index 814f196..b1ca7cd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ .vscode/ .env /playground-mono +__debug_bin* \ No newline at end of file diff --git a/config/config.go b/config/config.go index a9991de..73c1f01 100644 --- a/config/config.go +++ b/config/config.go @@ -1,25 +1,36 @@ package config import ( + "fmt" "os" "strconv" + "strings" + + "github.com/joho/godotenv" ) // Config holds the application configuration type Config struct { - DiceAddr string - ServerPort string - RequestLimit int // Field for the request limit - RequestWindow int // Field for the time window in seconds + DiceAddr string + ServerPort string + RequestLimit int64 // Field for the request limit + RequestWindow float64 // Field for the time window in float64 + AllowedOrigins []string // Field for the allowed origins } // LoadConfig loads the application configuration from environment variables or defaults func LoadConfig() *Config { + err := godotenv.Load() + if err != nil { + fmt.Println("Warning: .env file not found, falling back to system environment variables.") + } + return &Config{ - DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address - ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port - RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit - RequestWindow: getEnvInt("REQUEST_WINDOW", 60), // Default request window in seconds + DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address + ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port + RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit + RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64 + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins } } @@ -32,11 +43,39 @@ func getEnv(key, fallback string) string { } // getEnvInt retrieves an environment variable as an integer or returns a default value -func getEnvInt(key string, fallback int) int { +func getEnvInt(key string, fallback int) int64 { if value, exists := os.LookupEnv(key); exists { if intValue, err := strconv.Atoi(value); err == nil { - return intValue + return int64(intValue) + } + } + return int64(fallback) +} + +// added for miliseconds request window controls +func getEnvFloat64(key string, fallback float64) float64 { + if value, exists := os.LookupEnv(key); exists { + if floatValue, err := strconv.ParseFloat(value, 64); err == nil { + return floatValue } } return fallback } + +func getEnvArray(key string, fallback []string) []string { + if value, exists := os.LookupEnv(key); exists { + if arrayValue := splitString(value); len(arrayValue) > 0 { + return arrayValue + } + } + return fallback +} + +// splitString splits a string by comma and returns a slice of strings +func splitString(s string) []string { + var array []string + for _, v := range strings.Split(s, ",") { + array = append(array, strings.TrimSpace(v)) + } + return array +} diff --git a/go.mod b/go.mod index 26906be..5502661 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,16 @@ module server go 1.22.5 +require ( + github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 + github.com/joho/godotenv v1.5.1 + github.com/stretchr/testify v1.9.0 +) + require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index af5b6c5..36368eb 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,22 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 h1:Cqyj9WCtoobN6++bFbDSe27q94SPwJD9Z0wmu+SDRuk= github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831/go.mod h1:8+VZrr14c2LW8fW4tWZ8Bv3P2lfvlg+PpsSn5cWWuiQ= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 0000000..63bc5d9 --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "net/http" + "server/config" +) + +func enableCors(w http.ResponseWriter, r *http.Request) { + configValue := config.LoadConfig() + allAllowedOrigins := configValue.AllowedOrigins + origin := r.Header.Get("Origin") + allowed := false + for _, allowedOrigin := range allAllowedOrigins { + if origin == allowedOrigin || allowedOrigin == "*" || origin == "" { + allowed = true + break + } + } + if !allowed { + http.Error(w, "CORS: origin not allowed", http.StatusForbidden) + return + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length") + + if r.Method == http.MethodOptions { + w.Header().Set("Access-Control-Max-Age", "86400") + w.WriteHeader(http.StatusOK) + return + } + w.Header().Set("Content-Type", "application/json") +} diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index 293dd9c..c8da643 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "server/internal/db" + mock "server/internal/tests/dbmocks" "strconv" "strings" "time" @@ -14,28 +15,14 @@ import ( dice "github.com/dicedb/go-dice" ) -// TODO: Look at this later -func enableCors(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") -} - // RateLimiter middleware to limit requests based on a specified limit and duration -func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.Handler { +func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Enable CORS for requests + enableCors(w, r) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Set CORS headers - enableCors(w) - - // Handle OPTIONS preflight request - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - // Skip rate limiting for non-command endpoints if !strings.Contains(r.URL.Path, "/cli/") { next.ServeHTTP(w, r) @@ -56,9 +43,9 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H } // Initialize request count - requestCount := 0 + requestCount := int64(0) if val != "" { - requestCount, err = strconv.Atoi(val) + requestCount, err = strconv.ParseInt(val, 10, 64) if err != nil { slog.Error("Error converting request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -67,7 +54,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H } // Check if the request count exceeds the limit - if requestCount >= limit { + if requestCount > limit { slog.Warn("Request limit exceeded", "count", requestCount) http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) return @@ -94,3 +81,71 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H next.ServeHTTP(w, r) }) } + +func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Enable CORS for requests + enableCors(w, r) + + // Set a request context with a timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Only apply rate limiting for specific paths (e.g., "/cli/") + if !strings.Contains(r.URL.Path, "/cli/") { + next.ServeHTTP(w, r) + return + } + + // Generate the rate limiting key based on the current window + currentWindow := time.Now().Unix() / int64(window) + key := fmt.Sprintf("request_count:%d", currentWindow) + slog.Info("Created rate limiter key", slog.Any("key", key)) + + // Get the current request count for this window from the mock DB + val, err := client.Get(ctx, key) + if err != nil { + slog.Error("Error fetching request count", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 + if val != "" { + requestCount, err = strconv.ParseInt(val, 10, 64) + if err != nil { + slog.Error("Error converting request count", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + } + + // Check if the request limit has been exceeded + if requestCount >= limit { + slog.Warn("Request limit exceeded", "count", requestCount) + http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) + return + } + + // Increment the request count in the mock DB + requestCount, err = client.Incr(ctx, key) + if err != nil { + slog.Error("Error incrementing request count", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Set expiration for the key if it's the first request in the window + if requestCount == 1 { + err = client.Expire(ctx, key, time.Duration(window)*time.Second) + if err != nil { + slog.Error("Error setting expiry for request count", "error", err) + } + } + + // Log the successful request and pass control to the next handler + slog.Info("Request processed", "count", requestCount) + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 68e4c86..2e8e1a8 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -47,7 +47,7 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { cim.rateLimiter(w, r, cim.mux) } -func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit, window int) *HTTPServer { +func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit int64, window float64) *HTTPServer { handlerMux := &HandlerMux{ mux: mux, rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) { @@ -97,26 +97,27 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, "Error parsing HTTP request", http.StatusBadRequest) + http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest) return } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) + http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest) return } - if _, ok := resp.(string); !ok { - log.Println("Error marshaling response", "error", err) + respStr, ok := resp.(string) + if !ok { + log.Println("Error: response is not a string", "error", err) http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) return } - httpResponse := HTTPResponse{Data: resp.(string)} + httpResponse := HTTPResponse{Data: respStr} responseJSON, err := json.Marshal(httpResponse) if err != nil { - log.Println("Error marshaling response", "error", err) + log.Println("Error marshaling response to JSON", "error", err) http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) return } diff --git a/internal/tests/dbmocks/mock_dicedb.go b/internal/tests/dbmocks/mock_dicedb.go new file mode 100644 index 0000000..97deee2 --- /dev/null +++ b/internal/tests/dbmocks/mock_dicedb.go @@ -0,0 +1,59 @@ +package db + +import ( + "context" + "fmt" + "sync" + "time" +) + +type DiceDBMock struct { + data map[string]string + mutex sync.Mutex +} + +func NewDiceDBMock() *DiceDBMock { + return &DiceDBMock{ + data: make(map[string]string), + } +} + +func (db *DiceDBMock) Get(ctx context.Context, key string) (string, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + + val, exists := db.data[key] + if !exists { + return "", nil + } + return val, nil +} + +func (db *DiceDBMock) Set(ctx context.Context, key, value string, expiration time.Duration) error { + db.mutex.Lock() + defer db.mutex.Unlock() + + db.data[key] = value + return nil +} + +func (db *DiceDBMock) Incr(ctx context.Context, key string) (int64, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + + val, exists := db.data[key] + var count int64 + if exists { + if _, err := fmt.Sscanf(val, "%d", &count); err != nil { + return 0, fmt.Errorf("error parsing value for key %s: %w", key, err) + } + } + + count++ + db.data[key] = fmt.Sprintf("%d", count) + return count, nil +} + +func (db *DiceDBMock) Expire(ctx context.Context, key string, expiration time.Duration) error { + return nil +} diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go new file mode 100644 index 0000000..421dafb --- /dev/null +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -0,0 +1,42 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + config "server/config" + util "server/pkg/util" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRateLimiterWithinLimit(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimit + window := configValue.RequestWindow + + w, r, rateLimiter := util.SetupRateLimiter(limit, window) + + for i := int64(0); i < limit; i++ { + rateLimiter.ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + } +} + +func TestRateLimiterExceedsLimit(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimit + window := configValue.RequestWindow + + w, r, rateLimiter := util.SetupRateLimiter(limit, window) + + for i := int64(0); i < limit; i++ { + rateLimiter.ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + } + + w = httptest.NewRecorder() + rateLimiter.ServeHTTP(w, r) + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Contains(t, w.Body.String(), "429 - Too Many Requests") +} diff --git a/internal/tests/stress/ratelimiter_stress_test.go b/internal/tests/stress/ratelimiter_stress_test.go new file mode 100644 index 0000000..7e37054 --- /dev/null +++ b/internal/tests/stress/ratelimiter_stress_test.go @@ -0,0 +1,47 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/config" + util "server/pkg/util" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimiterUnderStress(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimit + window := configValue.RequestWindow + + _, r, rateLimiter := util.SetupRateLimiter(limit, window) + + var wg sync.WaitGroup + var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit + successCount := int64(0) + failCount := int64(0) + var mu sync.Mutex + + for i := int64(0); i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + + time.Sleep(10 * time.Millisecond) + rateLimiter.ServeHTTP(rec, r) + mu.Lock() + if rec.Code == http.StatusOK { + successCount++ + } else if rec.Code == http.StatusTooManyRequests { + failCount++ + } + mu.Unlock() + }() + } + wg.Wait() + require.Equal(t, limit, successCount, "Should succeed for exactly limit requests") +} diff --git a/pkg/util/helpers.go b/pkg/util/helpers.go index ce826f8..c86cef0 100644 --- a/pkg/util/helpers.go +++ b/pkg/util/helpers.go @@ -5,8 +5,12 @@ import ( "errors" "fmt" "io" + "log" "net/http" + "net/http/httptest" "server/internal/cmds" + "server/internal/middleware" + db "server/internal/tests/dbmocks" "strings" ) @@ -161,3 +165,21 @@ func JSONResponse(w http.ResponseWriter, status int, data interface{}) { http.Error(w, err.Error(), http.StatusInternalServerError) } } + +func MockHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("OK")); err != nil { + log.Fatalf("Failed to write response: %v", err) + } +} + +func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { + mockClient := db.NewDiceDBMock() + + r := httptest.NewRequest("GET", "/cli/somecommand", http.NoBody) + w := httptest.NewRecorder() + + rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) + + return w, r, rateLimiter +} From 911438f3ee1fd2e4fe705e507f593fd9bd4216ed Mon Sep 17 00:00:00 2001 From: Tarun Kantiwal <48859385+tarun-29@users.noreply.github.com> Date: Fri, 4 Oct 2024 22:02:58 +0530 Subject: [PATCH 2/6] Add trailing slash middleware to prevent unexpected API crash (#16) --- internal/middleware/trailing_slash_test.go | 64 ++++++++++++++++++++++ internal/middleware/trailingslash.go | 23 ++++++++ internal/server/httpServer.go | 8 ++- 3 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 internal/middleware/trailing_slash_test.go create mode 100644 internal/middleware/trailingslash.go diff --git a/internal/middleware/trailing_slash_test.go b/internal/middleware/trailing_slash_test.go new file mode 100644 index 0000000..203fb1a --- /dev/null +++ b/internal/middleware/trailing_slash_test.go @@ -0,0 +1,64 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/internal/middleware" + "testing" +) + +func TestTrailingSlashMiddleware(t *testing.T) { + + handler := middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + requestURL string + expectedCode int + expectedUrl string + }{ + { + name: "url with trailing slash", + requestURL: "/example/", + expectedCode: http.StatusMovedPermanently, + expectedUrl: "/example", + }, + { + name: "url without trailing slash", + requestURL: "/example", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "root url with trailing slash", + requestURL: "/", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "URL with Query Parameters", + requestURL: "/example?query=1", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.requestURL, nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != tt.expectedCode { + t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + } + + if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { + t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + } + }) + } +} diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go new file mode 100644 index 0000000..09ac1b7 --- /dev/null +++ b/internal/middleware/trailingslash.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net/http" + "strings" +) + +func TrailingSlashMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { + // remove slash + newPath := strings.TrimSuffix(r.URL.Path, "/") + // if query params exist append them + newURL := newPath + if r.URL.RawQuery != "" { + newURL += "?" + r.URL.RawQuery + } + http.Redirect(w, r, newURL, http.StatusMovedPermanently) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 2e8e1a8..476f823 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -42,9 +42,11 @@ func errorResponse(response string) string { func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Convert the path to lowercase before passing to the underlying mux. - r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter - cim.rateLimiter(w, r, cim.mux) + middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.ToLower(r.URL.Path) + // Apply rate limiter + cim.rateLimiter(w, r, cim.mux) + })).ServeHTTP(w, r) } func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit int64, window float64) *HTTPServer { From b42d332ba7880b85be60822cdd9389f9cd246bf2 Mon Sep 17 00:00:00 2001 From: rishav vajpayee <46602331+rishavvajpayee@users.noreply.github.com> Date: Sat, 5 Oct 2024 00:39:57 +0530 Subject: [PATCH 3/6] #21: Refactored repo for consistency (#24) --- config/config.go | 20 +- internal/db/dicedb.go | 14 +- internal/middleware/ratelimiter.go | 17 +- internal/middleware/trailingslash.go | 2 - internal/server/{httpServer.go => http.go} | 39 ++-- .../{mock_dicedb.go => dicedb_mock.go} | 0 .../ratelimiter_integration_test.go | 10 +- .../integration}/trailing_slash_test.go | 14 +- .../tests/stress/ratelimiter_stress_test.go | 10 +- main.go | 12 +- pkg/util/helpers.go | 185 ----------------- {internal => util}/cmds/cmds.go | 0 util/helpers.go | 191 ++++++++++++++++++ 13 files changed, 257 insertions(+), 257 deletions(-) rename internal/server/{httpServer.go => http.go} (72%) rename internal/tests/dbmocks/{mock_dicedb.go => dicedb_mock.go} (100%) rename internal/{middleware => tests/integration}/trailing_slash_test.go (72%) delete mode 100644 pkg/util/helpers.go rename {internal => util}/cmds/cmds.go (100%) create mode 100644 util/helpers.go diff --git a/config/config.go b/config/config.go index 73c1f01..b0f0b9e 100644 --- a/config/config.go +++ b/config/config.go @@ -11,11 +11,11 @@ import ( // Config holds the application configuration type Config struct { - DiceAddr string - ServerPort string - RequestLimit int64 // Field for the request limit - RequestWindow float64 // Field for the time window in float64 - AllowedOrigins []string // Field for the allowed origins + DiceDBAddr string + ServerPort string + RequestLimitPerMin int64 // Field for the request limit + RequestWindowSec float64 // Field for the time window in float64 + AllowedOrigins []string // Field for the allowed origins } // LoadConfig loads the application configuration from environment variables or defaults @@ -26,11 +26,11 @@ func LoadConfig() *Config { } return &Config{ - DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address - ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port - RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit - RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + DiceDBAddr: getEnv("DICEDB_ADDR", "localhost:7379"), // Default DiceDB address + ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port + RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit + RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins } } diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index c811aac..1701cf6 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -11,10 +11,10 @@ import ( "log/slog" "os" "server/config" - "server/internal/cmds" + "server/util/cmds" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) const ( @@ -22,7 +22,7 @@ const ( ) type DiceDB struct { - Client *dice.Client + Client *dicedb.Client Ctx context.Context } @@ -36,13 +36,13 @@ func (db *DiceDB) CloseDiceDB() { } func InitDiceClient(configValue *config.Config) (*DiceDB, error) { - diceClient := dice.NewClient(&dice.Options{ - Addr: configValue.DiceAddr, + diceClient := dicedb.NewClient(&dicedb.Options{ + Addr: configValue.DiceDBAddr, DialTimeout: 10 * time.Second, MaxRetries: 10, }) - // Ping the dice client to verify the connection + // Ping the dicedb client to verify the connection err := diceClient.Ping(context.Background()).Err() if err != nil { return nil, err @@ -64,7 +64,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err val, err := db.getKey(command.Args[0]) switch { - case errors.Is(err, dice.Nil): + case errors.Is(err, dicedb.Nil): return nil, errors.New("key does not exist") case err != nil: return nil, fmt.Errorf("get failed %v", err) diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c8da643..c852a26 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -12,19 +12,17 @@ import ( "strings" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Skip rate limiting for non-command endpoints - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -36,7 +34,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float // Fetch the current request count val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dice.Nil) { + if err != nil && !errors.Is(err, dicedb.Nil) { slog.Error("Error fetching request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return @@ -74,25 +72,19 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float } } - // Log the successful request increment slog.Info("Request processed", "count", requestCount+1) - - // Call the next handler next.ServeHTTP(w, r) }) } func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) - - // Set a request context with a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -144,7 +136,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } } - // Log the successful request and pass control to the next handler slog.Info("Request processed", "count", requestCount) next.ServeHTTP(w, r) }) diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 09ac1b7..80871ce 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -8,9 +8,7 @@ import ( func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - // remove slash newPath := strings.TrimSuffix(r.URL.Path, "/") - // if query params exist append them newURL := newPath if r.URL.RawQuery != "" { newURL += "?" + r.URL.RawQuery diff --git a/internal/server/httpServer.go b/internal/server/http.go similarity index 72% rename from internal/server/httpServer.go rename to internal/server/http.go index 476f823..9b12ec3 100644 --- a/internal/server/httpServer.go +++ b/internal/server/http.go @@ -4,8 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" - "log" + "log/slog" "net/http" "server/internal/middleware" "strings" @@ -13,7 +12,7 @@ import ( "time" "server/internal/db" - util "server/pkg/util" + util "server/util" ) type HTTPServer struct { @@ -37,7 +36,13 @@ type HTTPErrorResponse struct { } func errorResponse(response string) string { - return fmt.Sprintf("{\"error\": %q}", response) + errorMessage := map[string]string{"error": response} + jsonResponse, err := json.Marshal(errorMessage) + if err != nil { + slog.Error("Error marshaling response: %v", slog.Any("err", err)) + return `{"error": "internal server error"}` + } + return string(jsonResponse) } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -73,64 +78,64 @@ func (s *HTTPServer) Run(ctx context.Context) error { wg.Add(1) go func() { defer wg.Done() - log.Printf("Starting server at %s\n", s.httpServer.Addr) + slog.Info("starting server at", slog.String("addr", s.httpServer.Addr)) if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("HTTP server error: %v", err) + slog.Error("http server error: %v", slog.Any("err", err)) } }() <-ctx.Done() - log.Println("Shutting down server...") + slog.Info("shutting down server...") return s.Shutdown() } func (s *HTTPServer) Shutdown() error { if err := s.DiceClient.Client.Close(); err != nil { - log.Printf("Failed to close dice client: %v", err) + slog.Error("failed to close dicedb client: %v", slog.Any("err", err)) } return s.httpServer.Shutdown(context.Background()) } func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Server is running"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"}) } func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest) + http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) return } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) return } respStr, ok := resp.(string) if !ok { - log.Println("Error: response is not a string", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + slog.Error("error: response is not a string", "error", slog.Any("err", err)) + http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) return } httpResponse := HTTPResponse{Data: respStr} responseJSON, err := json.Marshal(httpResponse) if err != nil { - log.Println("Error marshaling response to JSON", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + slog.Error("error marshaling response to json", "error", slog.Any("err", err)) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } _, err = w.Write(responseJSON) if err != nil { - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } } func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Search results"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "search results"}) } diff --git a/internal/tests/dbmocks/mock_dicedb.go b/internal/tests/dbmocks/dicedb_mock.go similarity index 100% rename from internal/tests/dbmocks/mock_dicedb.go rename to internal/tests/dbmocks/dicedb_mock.go diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go index 421dafb..7faca79 100644 --- a/internal/tests/integration/ratelimiter_integration_test.go +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -4,7 +4,7 @@ import ( "net/http" "net/http/httptest" config "server/config" - util "server/pkg/util" + util "server/util" "testing" "github.com/stretchr/testify/require" @@ -12,8 +12,8 @@ import ( func TestRateLimiterWithinLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) @@ -25,8 +25,8 @@ func TestRateLimiterWithinLimit(t *testing.T) { func TestRateLimiterExceedsLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) diff --git a/internal/middleware/trailing_slash_test.go b/internal/tests/integration/trailing_slash_test.go similarity index 72% rename from internal/middleware/trailing_slash_test.go rename to internal/tests/integration/trailing_slash_test.go index 203fb1a..972158d 100644 --- a/internal/middleware/trailing_slash_test.go +++ b/internal/tests/integration/trailing_slash_test.go @@ -45,19 +45,19 @@ func TestTrailingSlashMiddleware(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", tt.requestURL, nil) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", test.requestURL, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - if w.Code != tt.expectedCode { - t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + if w.Code != test.expectedCode { + t.Errorf("expected status %d, got %d", test.expectedCode, w.Code) } - if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { - t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + if test.expectedUrl != "" && w.Header().Get("Location") != test.expectedUrl { + t.Errorf("expected location %s, got %s", test.expectedUrl, w.Header().Get("Location")) } }) } diff --git a/internal/tests/stress/ratelimiter_stress_test.go b/internal/tests/stress/ratelimiter_stress_test.go index 7e37054..75c14f1 100644 --- a/internal/tests/stress/ratelimiter_stress_test.go +++ b/internal/tests/stress/ratelimiter_stress_test.go @@ -4,7 +4,7 @@ import ( "net/http" "net/http/httptest" "server/config" - util "server/pkg/util" + util "server/util" "sync" "testing" "time" @@ -14,13 +14,13 @@ import ( func TestRateLimiterUnderStress(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec _, r, rateLimiter := util.SetupRateLimiter(limit, window) var wg sync.WaitGroup - var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit + var numRequests int64 = limit successCount := int64(0) failCount := int64(0) var mu sync.Mutex @@ -43,5 +43,5 @@ func TestRateLimiterUnderStress(t *testing.T) { }() } wg.Wait() - require.Equal(t, limit, successCount, "Should succeed for exactly limit requests") + require.Equal(t, limit, successCount, "should succeed for exactly limit requests") } diff --git a/main.go b/main.go index 09a082f..b01de76 100644 --- a/main.go +++ b/main.go @@ -2,25 +2,25 @@ package main import ( "context" - "log" + "log/slog" "net/http" "server/config" "server/internal/db" - "server/internal/server" // Import the new package for HTTPServer + "server/internal/server" ) func main() { configValue := config.LoadConfig() diceClient, err := db.InitDiceClient(configValue) if err != nil { - log.Fatalf("Failed to initialize dice client: %v", err) + slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err)) } // Create mux and register routes mux := http.NewServeMux() - httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimit, configValue.RequestWindow) + httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimitPerMin, configValue.RequestWindowSec) mux.HandleFunc("/health", httpServer.HealthCheck) - mux.HandleFunc("/cli/{cmd}", httpServer.CliHandler) + mux.HandleFunc("/shell/exec/{cmd}", httpServer.CliHandler) mux.HandleFunc("/search", httpServer.SearchHandler) // Graceful shutdown context @@ -29,6 +29,6 @@ func main() { // Run the HTTP Server if err := httpServer.Run(ctx); err != nil { - log.Printf("Server failed: %v\n", err) + slog.Error("server failed: %v\n", slog.Any("err", err)) } } diff --git a/pkg/util/helpers.go b/pkg/util/helpers.go deleted file mode 100644 index c86cef0..0000000 --- a/pkg/util/helpers.go +++ /dev/null @@ -1,185 +0,0 @@ -package helpers - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "server/internal/cmds" - "server/internal/middleware" - db "server/internal/tests/dbmocks" - "strings" -) - -const ( - Key = "key" - Keys = "keys" - KeyPrefix = "key_prefix" - Field = "field" - Path = "path" - Value = "value" - Values = "values" - User = "user" - Password = "password" - Seconds = "seconds" - KeyValues = "key_values" - True = "true" - QwatchQuery = "query" - Offset = "offset" - Member = "member" - Members = "members" - - JSONIngest string = "JSON.INGEST" -) - -func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { - command := strings.TrimPrefix(r.URL.Path, "/cli/") - if command == "" { - return nil, errors.New("invalid command") - } - - command = strings.ToUpper(command) - var args []string - - // Extract query parameters - queryParams := r.URL.Query() - keyPrefix := queryParams.Get(KeyPrefix) - - if keyPrefix != "" && command == JSONIngest { - args = append(args, keyPrefix) - } - // Step 1: Handle JSON body if present - if r.Body != nil { - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, err - } - - if len(body) > 0 { - var jsonBody map[string]interface{} - if err := json.Unmarshal(body, &jsonBody); err != nil { - return nil, err - } - - if len(jsonBody) == 0 { - return nil, fmt.Errorf("empty JSON object") - } - - // Define keys to exclude and process their values first - // Update as we support more commands - var priorityKeys = []string{ - Key, - Keys, - Field, - Path, - Value, - Values, - Seconds, - User, - Password, - KeyValues, - QwatchQuery, - Offset, - Member, - Members, - } - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - if key == Keys { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Values { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - // MultiKey operations - if key == KeyValues { - // Handle KeyValues separately - for k, v := range val.(map[string]interface{}) { - args = append(args, k, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Members { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - args = append(args, fmt.Sprintf("%v", val)) - delete(jsonBody, key) - } - } - - // Process remaining keys in the JSON body - for key, val := range jsonBody { - switch v := val.(type) { - case string: - // Handle unary operations like 'nx' where value is "true" - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - // Marshal nested JSON structures back into a string - jsonValue, err := json.Marshal(v) - if err != nil { - return nil, err - } - args = append(args, string(jsonValue)) - default: - args = append(args, key) - // Append other types as strings - value := fmt.Sprintf("%v", v) - if !strings.EqualFold(value, True) { - args = append(args, value) - } - } - } - } - } - - // Step 2: Return the constructed Redis command - return &cmds.CommandRequest{ - Cmd: command, - Args: args, - }, nil -} - -func JSONResponse(w http.ResponseWriter, status int, data interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func MockHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("OK")); err != nil { - log.Fatalf("Failed to write response: %v", err) - } -} - -func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { - mockClient := db.NewDiceDBMock() - - r := httptest.NewRequest("GET", "/cli/somecommand", http.NoBody) - w := httptest.NewRecorder() - - rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) - - return w, r, rateLimiter -} diff --git a/internal/cmds/cmds.go b/util/cmds/cmds.go similarity index 100% rename from internal/cmds/cmds.go rename to util/cmds/cmds.go diff --git a/util/helpers.go b/util/helpers.go new file mode 100644 index 0000000..20b3c2b --- /dev/null +++ b/util/helpers.go @@ -0,0 +1,191 @@ +package utils + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "server/internal/middleware" + db "server/internal/tests/dbmocks" + cmds "server/util/cmds" + "strings" +) + +const ( + Key = "key" + Keys = "keys" + KeyPrefix = "key_prefix" + Field = "field" + Path = "path" + Value = "value" + Values = "values" + User = "user" + Password = "password" + Seconds = "seconds" + KeyValues = "key_values" + True = "true" + QwatchQuery = "query" + Offset = "offset" + Member = "member" + Members = "members" + + JSONIngest string = "JSON.INGEST" +) + +var priorityKeys = []string{ + Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, +} + +// ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands +func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { + command := extractCommand(r.URL.Path) + if command == "" { + return nil, errors.New("invalid command") + } + + args, err := extractArgsFromRequest(r, command) + if err != nil { + return nil, err + } + + return &cmds.CommandRequest{ + Cmd: command, + Args: args, + }, nil +} + +func extractCommand(path string) string { + command := strings.TrimPrefix(path, "/shell/exec/") + return strings.ToUpper(command) +} + +func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { + var args []string + queryParams := r.URL.Query() + keyPrefix := queryParams.Get(KeyPrefix) + + if keyPrefix != "" && command == JSONIngest { + args = append(args, keyPrefix) + } + + if r.Body != nil { + bodyArgs, err := parseRequestBody(r.Body) + if err != nil { + return nil, err + } + args = append(args, bodyArgs...) + } + + return args, nil +} + +func parseRequestBody(body io.ReadCloser) ([]string, error) { + var args []string + bodyContent, err := io.ReadAll(body) + if err != nil { + return nil, err + } + + if len(bodyContent) == 0 { + return args, nil + } + + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { + return nil, err + } + + if len(jsonBody) == 0 { + return nil, fmt.Errorf("empty JSON object") + } + + args = append(args, extractPriorityArgs(jsonBody)...) + args = append(args, extractRemainingArgs(jsonBody)...) + + return args, nil +} + +func extractPriorityArgs(jsonBody map[string]interface{}) []string { + var args []string + for _, key := range priorityKeys { + if val, exists := jsonBody[key]; exists { + switch key { + case Keys, Values, Members: + args = append(args, convertListToStrings(val.([]interface{}))...) + case KeyValues: + args = append(args, convertMapToStrings(val.(map[string]interface{}))...) + default: + args = append(args, fmt.Sprintf("%v", val)) + } + delete(jsonBody, key) + } + } + return args +} + +func extractRemainingArgs(jsonBody map[string]interface{}) []string { + var args []string + for key, val := range jsonBody { + switch v := val.(type) { + case string: + args = append(args, key) + if !strings.EqualFold(v, True) { + args = append(args, v) + } + case map[string]interface{}, []interface{}: + jsonValue, _ := json.Marshal(v) + args = append(args, string(jsonValue)) + default: + args = append(args, key, fmt.Sprintf("%v", v)) + } + } + return args +} + +func convertListToStrings(list []interface{}) []string { + var result []string + for _, v := range list { + result = append(result, fmt.Sprintf("%v", v)) + } + return result +} + +func convertMapToStrings(m map[string]interface{}) []string { + var result []string + for k, v := range m { + result = append(result, k, fmt.Sprintf("%v", v)) + } + return result +} + +// JSONResponse sends a JSON response to the client +func JSONResponse(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// MockHandler is a basic mock handler for testing +func MockHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("OK")); err != nil { + slog.Error("Failed to write response: %v", slog.Any("err", err)) + } +} + +// SetupRateLimiter sets up a rate limiter for testing purposes +func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { + mockClient := db.NewDiceDBMock() + + r := httptest.NewRequest("GET", "/shell/exec/get", http.NoBody) + w := httptest.NewRecorder() + + rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) + + return w, r, rateLimiter +} From c9eaa28a8e1d298503d29d8e75cb8c502c352c3a Mon Sep 17 00:00:00 2001 From: Prashant Shubham Date: Sat, 5 Oct 2024 23:22:39 +0530 Subject: [PATCH 4/6] Adding support for generic command execution (#26) --- config/config.go | 2 +- internal/db/commands.go | 16 ---- internal/db/dicedb.go | 68 +++++++---------- internal/middleware/cors.go | 13 +++- internal/middleware/ratelimiter.go | 10 ++- internal/server/http.go | 3 +- main.go | 2 + util/cmds/cmds.go | 4 +- util/helpers.go | 114 +++-------------------------- 9 files changed, 62 insertions(+), 170 deletions(-) delete mode 100644 internal/db/commands.go diff --git a/config/config.go b/config/config.go index b0f0b9e..433c5f9 100644 --- a/config/config.go +++ b/config/config.go @@ -30,7 +30,7 @@ func LoadConfig() *Config { ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:3000"}), // Default allowed origins } } diff --git a/internal/db/commands.go b/internal/db/commands.go deleted file mode 100644 index f78b616..0000000 --- a/internal/db/commands.go +++ /dev/null @@ -1,16 +0,0 @@ -package db - -func (db *DiceDB) getKey(key string) (string, error) { - val, err := db.Client.Get(db.Ctx, key).Result() - return val, err -} - -func (db *DiceDB) setKey(key, value string) error { - err := db.Client.Set(db.Ctx, key, value, 0).Err() - return err -} - -func (db *DiceDB) deleteKeys(keys []string) error { - err := db.Client.Del(db.Ctx, keys...).Err() - return err -} diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index 1701cf6..71d4901 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -17,10 +17,6 @@ import ( dicedb "github.com/dicedb/go-dice" ) -const ( - RespOK = "OK" -) - type DiceDB struct { Client *dicedb.Client Ctx context.Context @@ -56,47 +52,35 @@ func InitDiceClient(configValue *config.Config) (*DiceDB, error) { // ExecuteCommand executes a command based on the input func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, error) { - switch command.Cmd { - case "GET": - if len(command.Args) != 1 { - return nil, errors.New("invalid args") - } - - val, err := db.getKey(command.Args[0]) - switch { - case errors.Is(err, dicedb.Nil): - return nil, errors.New("key does not exist") - case err != nil: - return nil, fmt.Errorf("get failed %v", err) - } - - return val, nil - - case "SET": - if len(command.Args) < 2 { - return nil, errors.New("key is required") - } - - err := db.setKey(command.Args[0], command.Args[1]) - if err != nil { - return nil, errors.New("failed to set key") - } - - return RespOK, nil - - case "DEL": - if len(command.Args) == 0 { - return nil, errors.New("at least one key is required") - } + args := make([]interface{}, 0, len(command.Args)+1) + args = append(args, command.Cmd) + for _, arg := range command.Args { + args = append(args, arg) + } - err := db.deleteKeys(command.Args) - if err != nil { - return nil, errors.New("failed to delete keys") - } + res, err := db.Client.Do(db.Ctx, args...).Result() + if errors.Is(err, dicedb.Nil) { + return nil, errors.New("(nil)") + } - return RespOK, nil + if err != nil { + return nil, fmt.Errorf("(error) %v", err) + } + // Print the result based on its type + switch v := res.(type) { + case string: + return v, nil + case []byte: + return string(v), nil + case []interface{}: + case int64: + return fmt.Sprintf("%v", v), nil + case nil: + return "(nil)", nil default: - return nil, errors.New("unknown command") + return fmt.Sprintf("%v", v), nil } + + return nil, fmt.Errorf("(error) invalid result type") } diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 63bc5d9..3e75120 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -5,30 +5,37 @@ import ( "server/config" ) -func enableCors(w http.ResponseWriter, r *http.Request) { +// Updated enableCors function to return a boolean indicating if OPTIONS was handled +func handleCors(w http.ResponseWriter, r *http.Request) bool { configValue := config.LoadConfig() allAllowedOrigins := configValue.AllowedOrigins origin := r.Header.Get("Origin") allowed := false + for _, allowedOrigin := range allAllowedOrigins { if origin == allowedOrigin || allowedOrigin == "*" || origin == "" { allowed = true break } } + if !allowed { http.Error(w, "CORS: origin not allowed", http.StatusForbidden) - return + return true } w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length") + // If the request is an OPTIONS request, handle it and stop further processing if r.Method == http.MethodOptions { w.Header().Set("Access-Control-Max-Age", "86400") w.WriteHeader(http.StatusOK) - return + return true } + + // Continue processing other requests w.Header().Set("Content-Type", "application/json") + return false } diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c852a26..c2b43e3 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -18,7 +18,10 @@ import ( // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -79,7 +82,10 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/server/http.go b/internal/server/http.go index 9b12ec3..00d6a3e 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -42,6 +42,7 @@ func errorResponse(response string) string { slog.Error("Error marshaling response: %v", slog.Any("err", err)) return `{"error": "internal server error"}` } + return string(jsonResponse) } @@ -110,7 +111,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } diff --git a/main.go b/main.go index b01de76..99c98e3 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "net/http" + "os" "server/config" "server/internal/db" "server/internal/server" @@ -14,6 +15,7 @@ func main() { diceClient, err := db.InitDiceClient(configValue) if err != nil { slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err)) + os.Exit(1) } // Create mux and register routes diff --git a/util/cmds/cmds.go b/util/cmds/cmds.go index bb7a275..1e68d08 100644 --- a/util/cmds/cmds.go +++ b/util/cmds/cmds.go @@ -1,6 +1,6 @@ package cmds type CommandRequest struct { - Cmd string `json:"cmd"` - Args []string + Cmd string `json:"cmd"` + Args []string `json:"args"` } diff --git a/util/helpers.go b/util/helpers.go index 20b3c2b..20a605a 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -10,35 +10,10 @@ import ( "net/http/httptest" "server/internal/middleware" db "server/internal/tests/dbmocks" - cmds "server/util/cmds" + "server/util/cmds" "strings" ) -const ( - Key = "key" - Keys = "keys" - KeyPrefix = "key_prefix" - Field = "field" - Path = "path" - Value = "value" - Values = "values" - User = "user" - Password = "password" - Seconds = "seconds" - KeyValues = "key_values" - True = "true" - QwatchQuery = "query" - Offset = "offset" - Member = "member" - Members = "members" - - JSONIngest string = "JSON.INGEST" -) - -var priorityKeys = []string{ - Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, -} - // ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { command := extractCommand(r.URL.Path) @@ -46,7 +21,7 @@ func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { return nil, errors.New("invalid command") } - args, err := extractArgsFromRequest(r, command) + args, err := newExtractor(r) if err != nil { return nil, err } @@ -62,29 +37,9 @@ func extractCommand(path string) string { return strings.ToUpper(command) } -func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { +func newExtractor(r *http.Request) ([]string, error) { var args []string - queryParams := r.URL.Query() - keyPrefix := queryParams.Get(KeyPrefix) - - if keyPrefix != "" && command == JSONIngest { - args = append(args, keyPrefix) - } - - if r.Body != nil { - bodyArgs, err := parseRequestBody(r.Body) - if err != nil { - return nil, err - } - args = append(args, bodyArgs...) - } - - return args, nil -} - -func parseRequestBody(body io.ReadCloser) ([]string, error) { - var args []string - bodyContent, err := io.ReadAll(body) + bodyContent, err := io.ReadAll(r.Body) if err != nil { return nil, err } @@ -93,7 +48,7 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) { return args, nil } - var jsonBody map[string]interface{} + var jsonBody []interface{} if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { return nil, err } @@ -102,63 +57,16 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) { return nil, fmt.Errorf("empty JSON object") } - args = append(args, extractPriorityArgs(jsonBody)...) - args = append(args, extractRemainingArgs(jsonBody)...) - - return args, nil -} - -func extractPriorityArgs(jsonBody map[string]interface{}) []string { - var args []string - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - switch key { - case Keys, Values, Members: - args = append(args, convertListToStrings(val.([]interface{}))...) - case KeyValues: - args = append(args, convertMapToStrings(val.(map[string]interface{}))...) - default: - args = append(args, fmt.Sprintf("%v", val)) - } - delete(jsonBody, key) - } - } - return args -} - -func extractRemainingArgs(jsonBody map[string]interface{}) []string { - var args []string - for key, val := range jsonBody { - switch v := val.(type) { - case string: - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - jsonValue, _ := json.Marshal(v) - args = append(args, string(jsonValue)) - default: - args = append(args, key, fmt.Sprintf("%v", v)) + for _, val := range jsonBody { + s, ok := val.(string) + if !ok { + return nil, fmt.Errorf("invalid input") } - } - return args -} -func convertListToStrings(list []interface{}) []string { - var result []string - for _, v := range list { - result = append(result, fmt.Sprintf("%v", v)) + args = append(args, s) } - return result -} -func convertMapToStrings(m map[string]interface{}) []string { - var result []string - for k, v := range m { - result = append(result, k, fmt.Sprintf("%v", v)) - } - return result + return args, nil } // JSONResponse sends a JSON response to the client From f255c5fb0f8f1f6441bd63a5a9ccbde87795467f Mon Sep 17 00:00:00 2001 From: Prashant Shubham Date: Sat, 5 Oct 2024 23:34:06 +0530 Subject: [PATCH 5/6] Adding support for generic command execution (#27) --- internal/db/dicedb.go | 33 ++++++++++++++++++++++++++++++--- internal/server/http.go | 1 + 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index 71d4901..7da8f9f 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -12,11 +12,14 @@ import ( "os" "server/config" "server/util/cmds" + "strings" "time" dicedb "github.com/dicedb/go-dice" ) +const RespNil = "(nil)" + type DiceDB struct { Client *dicedb.Client Ctx context.Context @@ -60,7 +63,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err res, err := db.Client.Do(db.Ctx, args...).Result() if errors.Is(err, dicedb.Nil) { - return nil, errors.New("(nil)") + return RespNil, nil } if err != nil { @@ -74,13 +77,37 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err case []byte: return string(v), nil case []interface{}: + return renderListResponse(v) case int64: return fmt.Sprintf("%v", v), nil case nil: - return "(nil)", nil + return RespNil, nil default: return fmt.Sprintf("%v", v), nil } +} + +func renderListResponse(items []interface{}) (string, error) { + if len(items)%2 != 0 { + return "", fmt.Errorf("(error) invalid result format") + } + + var builder strings.Builder + for i := 0; i < len(items); i += 2 { + field, ok1 := items[i].(string) + value, ok2 := items[i+1].(string) + + // Check if both field and value are valid strings + if !ok1 || !ok2 { + return "", fmt.Errorf("(error) invalid result type") + } + + // Append the formatted field and value + _, err := fmt.Fprintf(&builder, "%d) \"%s\"\n%d) \"%s\"\n", i+1, field, i+2, value) + if err != nil { + return "", err + } + } - return nil, fmt.Errorf("(error) invalid result type") + return builder.String(), nil } diff --git a/internal/server/http.go b/internal/server/http.go index 00d6a3e..c665277 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -111,6 +111,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { + slog.Error("error: failure in executing command", "error", slog.Any("err", err)) http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } From d373641ef7055cca4e8004495136e5a8888b4be7 Mon Sep 17 00:00:00 2001 From: Yash Budhia <142312760+yashbudhia@users.noreply.github.com> Date: Sun, 6 Oct 2024 20:20:04 +0530 Subject: [PATCH 6/6] Disable list of commands from playground repositories #897 - Commands Blacklisted (#23) --- internal/server/http.go | 10 +++------- util/helpers.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index c665277..876f959 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,12 +6,12 @@ import ( "errors" "log/slog" "net/http" - "server/internal/middleware" "strings" "sync" "time" "server/internal/db" + "server/internal/middleware" util "server/util" ) @@ -20,8 +20,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -// HandlerMux wraps ServeMux and forces REST paths to lowercase -// and attaches a rate limiter with the handler type HandlerMux struct { mux *http.ServeMux rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) @@ -47,10 +45,8 @@ func errorResponse(response string) string { } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Convert the path to lowercase before passing to the underlying mux. middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) } @@ -105,7 +101,7 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) + http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } @@ -119,7 +115,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { respStr, ok := resp.(string) if !ok { slog.Error("error: response is not a string", "error", slog.Any("err", err)) - http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } diff --git a/util/helpers.go b/util/helpers.go index 20a605a..3e3aea8 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -14,6 +14,37 @@ import ( "strings" ) +// Map of blocklisted commands +var blocklistedCommands = map[string]bool{ + "FLUSHALL": true, + "FLUSHDB": true, + "DUMP": true, + "ABORT": true, + "AUTH": true, + "CONFIG": true, + "SAVE": true, + "BGSAVE": true, + "BGREWRITEAOF": true, + "RESTORE": true, + "MULTI": true, + "EXEC": true, + "DISCARD": true, + "QWATCH": true, + "QUNWATCH": true, + "LATENCY": true, + "CLIENT": true, + "SLEEP": true, + "PERSIST": true, +} + +// BlockListedCommand checks if a command is blocklisted +func BlockListedCommand(cmd string) error { + if _, exists := blocklistedCommands[strings.ToUpper(cmd)]; exists { + return errors.New("ERR unknown command '" + cmd + "'") + } + return nil +} + // ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { command := extractCommand(r.URL.Path) @@ -21,6 +52,11 @@ func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { return nil, errors.New("invalid command") } + // Check if the command is blocklisted + if err := BlockListedCommand(command); err != nil { + return nil, err + } + args, err := newExtractor(r) if err != nil { return nil, err