diff --git a/config/config.go b/config/config.go index 73c1f01..b0f0b9e 100644 --- a/config/config.go +++ b/config/config.go @@ -11,11 +11,11 @@ import ( // Config holds the application configuration type Config struct { - DiceAddr string - ServerPort string - RequestLimit int64 // Field for the request limit - RequestWindow float64 // Field for the time window in float64 - AllowedOrigins []string // Field for the allowed origins + DiceDBAddr string + ServerPort string + RequestLimitPerMin int64 // Field for the request limit + RequestWindowSec float64 // Field for the time window in float64 + AllowedOrigins []string // Field for the allowed origins } // LoadConfig loads the application configuration from environment variables or defaults @@ -26,11 +26,11 @@ func LoadConfig() *Config { } return &Config{ - DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address - ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port - RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit - RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + DiceDBAddr: getEnv("DICEDB_ADDR", "localhost:7379"), // Default DiceDB address + ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port + RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit + RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins } } diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index c811aac..1701cf6 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -11,10 +11,10 @@ import ( "log/slog" "os" "server/config" - "server/internal/cmds" + "server/util/cmds" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) const ( @@ -22,7 +22,7 @@ const ( ) type DiceDB struct { - Client *dice.Client + Client *dicedb.Client Ctx context.Context } @@ -36,13 +36,13 @@ func (db *DiceDB) CloseDiceDB() { } func InitDiceClient(configValue *config.Config) (*DiceDB, error) { - diceClient := dice.NewClient(&dice.Options{ - Addr: configValue.DiceAddr, + diceClient := dicedb.NewClient(&dicedb.Options{ + Addr: configValue.DiceDBAddr, DialTimeout: 10 * time.Second, MaxRetries: 10, }) - // Ping the dice client to verify the connection + // Ping the dicedb client to verify the connection err := diceClient.Ping(context.Background()).Err() if err != nil { return nil, err @@ -64,7 +64,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err val, err := db.getKey(command.Args[0]) switch { - case errors.Is(err, dice.Nil): + case errors.Is(err, dicedb.Nil): return nil, errors.New("key does not exist") case err != nil: return nil, fmt.Errorf("get failed %v", err) diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c8da643..c852a26 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -12,19 +12,17 @@ import ( "strings" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Skip rate limiting for non-command endpoints - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -36,7 +34,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float // Fetch the current request count val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dice.Nil) { + if err != nil && !errors.Is(err, dicedb.Nil) { slog.Error("Error fetching request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return @@ -74,25 +72,19 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float } } - // Log the successful request increment slog.Info("Request processed", "count", requestCount+1) - - // Call the next handler next.ServeHTTP(w, r) }) } func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) - - // Set a request context with a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -144,7 +136,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } } - // Log the successful request and pass control to the next handler slog.Info("Request processed", "count", requestCount) next.ServeHTTP(w, r) }) diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 09ac1b7..80871ce 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -8,9 +8,7 @@ import ( func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - // remove slash newPath := strings.TrimSuffix(r.URL.Path, "/") - // if query params exist append them newURL := newPath if r.URL.RawQuery != "" { newURL += "?" + r.URL.RawQuery diff --git a/internal/server/httpServer.go b/internal/server/http.go similarity index 74% rename from internal/server/httpServer.go rename to internal/server/http.go index 4bdddec..b95eb63 100644 --- a/internal/server/httpServer.go +++ b/internal/server/http.go @@ -4,8 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" - "log" + "log/slog" "net/http" "strings" "sync" @@ -13,7 +12,7 @@ import ( "server/internal/middleware" "server/internal/db" - util "server/pkg/util" + util "server/util" ) type HTTPServer struct { @@ -37,7 +36,13 @@ type HTTPErrorResponse struct { } func errorResponse(response string) string { - return fmt.Sprintf("{\"error\": %q}", response) + errorMessage := map[string]string{"error": response} + jsonResponse, err := json.Marshal(errorMessage) + if err != nil { + slog.Error("Error marshaling response: %v", slog.Any("err", err)) + return `{"error": "internal server error"}` + } + return string(jsonResponse) } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -73,33 +78,33 @@ func (s *HTTPServer) Run(ctx context.Context) error { wg.Add(1) go func() { defer wg.Done() - log.Printf("Starting server at %s\n", s.httpServer.Addr) + slog.Info("starting server at", slog.String("addr", s.httpServer.Addr)) if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("HTTP server error: %v", err) + slog.Error("http server error: %v", slog.Any("err", err)) } }() <-ctx.Done() - log.Println("Shutting down server...") + slog.Info("shutting down server...") return s.Shutdown() } func (s *HTTPServer) Shutdown() error { if err := s.DiceClient.Client.Close(); err != nil { - log.Printf("Failed to close dice client: %v", err) + slog.Error("failed to close dicedb client: %v", slog.Any("err", err)) } return s.httpServer.Shutdown(context.Background()) } func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Server is running"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"}) } func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest) + http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) return } @@ -112,32 +117,32 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) return } respStr, ok := resp.(string) if !ok { - log.Println("Error: response is not a string", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + slog.Error("error: response is not a string", "error", slog.Any("err", err)) + http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) return } httpResponse := HTTPResponse{Data: respStr} responseJSON, err := json.Marshal(httpResponse) if err != nil { - log.Println("Error marshaling response to JSON", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + slog.Error("error marshaling response to json", "error", slog.Any("err", err)) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } _, err = w.Write(responseJSON) if err != nil { - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } } func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Search results"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "search results"}) } diff --git a/internal/tests/dbmocks/mock_dicedb.go b/internal/tests/dbmocks/dicedb_mock.go similarity index 100% rename from internal/tests/dbmocks/mock_dicedb.go rename to internal/tests/dbmocks/dicedb_mock.go diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go index 421dafb..7faca79 100644 --- a/internal/tests/integration/ratelimiter_integration_test.go +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -4,7 +4,7 @@ import ( "net/http" "net/http/httptest" config "server/config" - util "server/pkg/util" + util "server/util" "testing" "github.com/stretchr/testify/require" @@ -12,8 +12,8 @@ import ( func TestRateLimiterWithinLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) @@ -25,8 +25,8 @@ func TestRateLimiterWithinLimit(t *testing.T) { func TestRateLimiterExceedsLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) diff --git a/internal/middleware/trailing_slash_test.go b/internal/tests/integration/trailing_slash_test.go similarity index 72% rename from internal/middleware/trailing_slash_test.go rename to internal/tests/integration/trailing_slash_test.go index 203fb1a..972158d 100644 --- a/internal/middleware/trailing_slash_test.go +++ b/internal/tests/integration/trailing_slash_test.go @@ -45,19 +45,19 @@ func TestTrailingSlashMiddleware(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", tt.requestURL, nil) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", test.requestURL, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - if w.Code != tt.expectedCode { - t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + if w.Code != test.expectedCode { + t.Errorf("expected status %d, got %d", test.expectedCode, w.Code) } - if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { - t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + if test.expectedUrl != "" && w.Header().Get("Location") != test.expectedUrl { + t.Errorf("expected location %s, got %s", test.expectedUrl, w.Header().Get("Location")) } }) } diff --git a/internal/tests/stress/ratelimiter_stress_test.go b/internal/tests/stress/ratelimiter_stress_test.go index 7e37054..75c14f1 100644 --- a/internal/tests/stress/ratelimiter_stress_test.go +++ b/internal/tests/stress/ratelimiter_stress_test.go @@ -4,7 +4,7 @@ import ( "net/http" "net/http/httptest" "server/config" - util "server/pkg/util" + util "server/util" "sync" "testing" "time" @@ -14,13 +14,13 @@ import ( func TestRateLimiterUnderStress(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec _, r, rateLimiter := util.SetupRateLimiter(limit, window) var wg sync.WaitGroup - var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit + var numRequests int64 = limit successCount := int64(0) failCount := int64(0) var mu sync.Mutex @@ -43,5 +43,5 @@ func TestRateLimiterUnderStress(t *testing.T) { }() } wg.Wait() - require.Equal(t, limit, successCount, "Should succeed for exactly limit requests") + require.Equal(t, limit, successCount, "should succeed for exactly limit requests") } diff --git a/main.go b/main.go index 09a082f..b01de76 100644 --- a/main.go +++ b/main.go @@ -2,25 +2,25 @@ package main import ( "context" - "log" + "log/slog" "net/http" "server/config" "server/internal/db" - "server/internal/server" // Import the new package for HTTPServer + "server/internal/server" ) func main() { configValue := config.LoadConfig() diceClient, err := db.InitDiceClient(configValue) if err != nil { - log.Fatalf("Failed to initialize dice client: %v", err) + slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err)) } // Create mux and register routes mux := http.NewServeMux() - httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimit, configValue.RequestWindow) + httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimitPerMin, configValue.RequestWindowSec) mux.HandleFunc("/health", httpServer.HealthCheck) - mux.HandleFunc("/cli/{cmd}", httpServer.CliHandler) + mux.HandleFunc("/shell/exec/{cmd}", httpServer.CliHandler) mux.HandleFunc("/search", httpServer.SearchHandler) // Graceful shutdown context @@ -29,6 +29,6 @@ func main() { // Run the HTTP Server if err := httpServer.Run(ctx); err != nil { - log.Printf("Server failed: %v\n", err) + slog.Error("server failed: %v\n", slog.Any("err", err)) } } diff --git a/pkg/util/helpers.go b/pkg/util/helpers.go deleted file mode 100644 index c86cef0..0000000 --- a/pkg/util/helpers.go +++ /dev/null @@ -1,185 +0,0 @@ -package helpers - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "server/internal/cmds" - "server/internal/middleware" - db "server/internal/tests/dbmocks" - "strings" -) - -const ( - Key = "key" - Keys = "keys" - KeyPrefix = "key_prefix" - Field = "field" - Path = "path" - Value = "value" - Values = "values" - User = "user" - Password = "password" - Seconds = "seconds" - KeyValues = "key_values" - True = "true" - QwatchQuery = "query" - Offset = "offset" - Member = "member" - Members = "members" - - JSONIngest string = "JSON.INGEST" -) - -func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { - command := strings.TrimPrefix(r.URL.Path, "/cli/") - if command == "" { - return nil, errors.New("invalid command") - } - - command = strings.ToUpper(command) - var args []string - - // Extract query parameters - queryParams := r.URL.Query() - keyPrefix := queryParams.Get(KeyPrefix) - - if keyPrefix != "" && command == JSONIngest { - args = append(args, keyPrefix) - } - // Step 1: Handle JSON body if present - if r.Body != nil { - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, err - } - - if len(body) > 0 { - var jsonBody map[string]interface{} - if err := json.Unmarshal(body, &jsonBody); err != nil { - return nil, err - } - - if len(jsonBody) == 0 { - return nil, fmt.Errorf("empty JSON object") - } - - // Define keys to exclude and process their values first - // Update as we support more commands - var priorityKeys = []string{ - Key, - Keys, - Field, - Path, - Value, - Values, - Seconds, - User, - Password, - KeyValues, - QwatchQuery, - Offset, - Member, - Members, - } - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - if key == Keys { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Values { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - // MultiKey operations - if key == KeyValues { - // Handle KeyValues separately - for k, v := range val.(map[string]interface{}) { - args = append(args, k, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Members { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - args = append(args, fmt.Sprintf("%v", val)) - delete(jsonBody, key) - } - } - - // Process remaining keys in the JSON body - for key, val := range jsonBody { - switch v := val.(type) { - case string: - // Handle unary operations like 'nx' where value is "true" - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - // Marshal nested JSON structures back into a string - jsonValue, err := json.Marshal(v) - if err != nil { - return nil, err - } - args = append(args, string(jsonValue)) - default: - args = append(args, key) - // Append other types as strings - value := fmt.Sprintf("%v", v) - if !strings.EqualFold(value, True) { - args = append(args, value) - } - } - } - } - } - - // Step 2: Return the constructed Redis command - return &cmds.CommandRequest{ - Cmd: command, - Args: args, - }, nil -} - -func JSONResponse(w http.ResponseWriter, status int, data interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func MockHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("OK")); err != nil { - log.Fatalf("Failed to write response: %v", err) - } -} - -func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { - mockClient := db.NewDiceDBMock() - - r := httptest.NewRequest("GET", "/cli/somecommand", http.NoBody) - w := httptest.NewRecorder() - - rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) - - return w, r, rateLimiter -} diff --git a/internal/cmds/cmds.go b/util/cmds/cmds.go similarity index 100% rename from internal/cmds/cmds.go rename to util/cmds/cmds.go diff --git a/util/helpers.go b/util/helpers.go new file mode 100644 index 0000000..20b3c2b --- /dev/null +++ b/util/helpers.go @@ -0,0 +1,191 @@ +package utils + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "server/internal/middleware" + db "server/internal/tests/dbmocks" + cmds "server/util/cmds" + "strings" +) + +const ( + Key = "key" + Keys = "keys" + KeyPrefix = "key_prefix" + Field = "field" + Path = "path" + Value = "value" + Values = "values" + User = "user" + Password = "password" + Seconds = "seconds" + KeyValues = "key_values" + True = "true" + QwatchQuery = "query" + Offset = "offset" + Member = "member" + Members = "members" + + JSONIngest string = "JSON.INGEST" +) + +var priorityKeys = []string{ + Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, +} + +// ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands +func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { + command := extractCommand(r.URL.Path) + if command == "" { + return nil, errors.New("invalid command") + } + + args, err := extractArgsFromRequest(r, command) + if err != nil { + return nil, err + } + + return &cmds.CommandRequest{ + Cmd: command, + Args: args, + }, nil +} + +func extractCommand(path string) string { + command := strings.TrimPrefix(path, "/shell/exec/") + return strings.ToUpper(command) +} + +func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { + var args []string + queryParams := r.URL.Query() + keyPrefix := queryParams.Get(KeyPrefix) + + if keyPrefix != "" && command == JSONIngest { + args = append(args, keyPrefix) + } + + if r.Body != nil { + bodyArgs, err := parseRequestBody(r.Body) + if err != nil { + return nil, err + } + args = append(args, bodyArgs...) + } + + return args, nil +} + +func parseRequestBody(body io.ReadCloser) ([]string, error) { + var args []string + bodyContent, err := io.ReadAll(body) + if err != nil { + return nil, err + } + + if len(bodyContent) == 0 { + return args, nil + } + + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { + return nil, err + } + + if len(jsonBody) == 0 { + return nil, fmt.Errorf("empty JSON object") + } + + args = append(args, extractPriorityArgs(jsonBody)...) + args = append(args, extractRemainingArgs(jsonBody)...) + + return args, nil +} + +func extractPriorityArgs(jsonBody map[string]interface{}) []string { + var args []string + for _, key := range priorityKeys { + if val, exists := jsonBody[key]; exists { + switch key { + case Keys, Values, Members: + args = append(args, convertListToStrings(val.([]interface{}))...) + case KeyValues: + args = append(args, convertMapToStrings(val.(map[string]interface{}))...) + default: + args = append(args, fmt.Sprintf("%v", val)) + } + delete(jsonBody, key) + } + } + return args +} + +func extractRemainingArgs(jsonBody map[string]interface{}) []string { + var args []string + for key, val := range jsonBody { + switch v := val.(type) { + case string: + args = append(args, key) + if !strings.EqualFold(v, True) { + args = append(args, v) + } + case map[string]interface{}, []interface{}: + jsonValue, _ := json.Marshal(v) + args = append(args, string(jsonValue)) + default: + args = append(args, key, fmt.Sprintf("%v", v)) + } + } + return args +} + +func convertListToStrings(list []interface{}) []string { + var result []string + for _, v := range list { + result = append(result, fmt.Sprintf("%v", v)) + } + return result +} + +func convertMapToStrings(m map[string]interface{}) []string { + var result []string + for k, v := range m { + result = append(result, k, fmt.Sprintf("%v", v)) + } + return result +} + +// JSONResponse sends a JSON response to the client +func JSONResponse(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// MockHandler is a basic mock handler for testing +func MockHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("OK")); err != nil { + slog.Error("Failed to write response: %v", slog.Any("err", err)) + } +} + +// SetupRateLimiter sets up a rate limiter for testing purposes +func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { + mockClient := db.NewDiceDBMock() + + r := httptest.NewRequest("GET", "/shell/exec/get", http.NoBody) + w := httptest.NewRecorder() + + rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) + + return w, r, rateLimiter +}