diff --git a/config/config.go b/config/config.go index 73c1f01..b0f0b9e 100644 --- a/config/config.go +++ b/config/config.go @@ -11,11 +11,11 @@ import ( // Config holds the application configuration type Config struct { - DiceAddr string - ServerPort string - RequestLimit int64 // Field for the request limit - RequestWindow float64 // Field for the time window in float64 - AllowedOrigins []string // Field for the allowed origins + DiceDBAddr string + ServerPort string + RequestLimitPerMin int64 // Field for the request limit + RequestWindowSec float64 // Field for the time window in float64 + AllowedOrigins []string // Field for the allowed origins } // LoadConfig loads the application configuration from environment variables or defaults @@ -26,11 +26,11 @@ func LoadConfig() *Config { } return &Config{ - DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address - ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port - RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit - RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + DiceDBAddr: getEnv("DICEDB_ADDR", "localhost:7379"), // Default DiceDB address + ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port + RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit + RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins } } diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index c811aac..f082106 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -11,10 +11,10 @@ import ( "log/slog" "os" "server/config" - "server/internal/cmds" + "server/pkg/util/cmds" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) const ( @@ -22,7 +22,7 @@ const ( ) type DiceDB struct { - Client *dice.Client + Client *dicedb.Client Ctx context.Context } @@ -36,13 +36,13 @@ func (db *DiceDB) CloseDiceDB() { } func InitDiceClient(configValue *config.Config) (*DiceDB, error) { - diceClient := dice.NewClient(&dice.Options{ - Addr: configValue.DiceAddr, + diceClient := dicedb.NewClient(&dicedb.Options{ + Addr: configValue.DiceDBAddr, DialTimeout: 10 * time.Second, MaxRetries: 10, }) - // Ping the dice client to verify the connection + // Ping the dicedb client to verify the connection err := diceClient.Ping(context.Background()).Err() if err != nil { return nil, err @@ -64,7 +64,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err val, err := db.getKey(command.Args[0]) switch { - case errors.Is(err, dice.Nil): + case errors.Is(err, dicedb.Nil): return nil, errors.New("key does not exist") case err != nil: return nil, fmt.Errorf("get failed %v", err) diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c8da643..c852a26 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -12,19 +12,17 @@ import ( "strings" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Skip rate limiting for non-command endpoints - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -36,7 +34,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float // Fetch the current request count val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dice.Nil) { + if err != nil && !errors.Is(err, dicedb.Nil) { slog.Error("Error fetching request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return @@ -74,25 +72,19 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float } } - // Log the successful request increment slog.Info("Request processed", "count", requestCount+1) - - // Call the next handler next.ServeHTTP(w, r) }) } func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) - - // Set a request context with a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -144,7 +136,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } } - // Log the successful request and pass control to the next handler slog.Info("Request processed", "count", requestCount) next.ServeHTTP(w, r) }) diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 09ac1b7..80871ce 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -8,9 +8,7 @@ import ( func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - // remove slash newPath := strings.TrimSuffix(r.URL.Path, "/") - // if query params exist append them newURL := newPath if r.URL.RawQuery != "" { newURL += "?" + r.URL.RawQuery diff --git a/internal/server/httpServer.go b/internal/server/http.go similarity index 75% rename from internal/server/httpServer.go rename to internal/server/http.go index 476f823..7ec18b2 100644 --- a/internal/server/httpServer.go +++ b/internal/server/http.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "log" "net/http" "server/internal/middleware" @@ -37,7 +36,13 @@ type HTTPErrorResponse struct { } func errorResponse(response string) string { - return fmt.Sprintf("{\"error\": %q}", response) + errorMessage := map[string]string{"error": response} + jsonResponse, err := json.Marshal(errorMessage) + if err != nil { + log.Printf("Error marshaling response: %v", err) + return `{"error": "internal server error"}` + } + return string(jsonResponse) } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -73,64 +78,64 @@ func (s *HTTPServer) Run(ctx context.Context) error { wg.Add(1) go func() { defer wg.Done() - log.Printf("Starting server at %s\n", s.httpServer.Addr) + log.Printf("starting server at %s\n", s.httpServer.Addr) if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("HTTP server error: %v", err) + log.Fatalf("http server error: %v", err) } }() <-ctx.Done() - log.Println("Shutting down server...") + log.Println("shutting down server...") return s.Shutdown() } func (s *HTTPServer) Shutdown() error { if err := s.DiceClient.Client.Close(); err != nil { - log.Printf("Failed to close dice client: %v", err) + log.Printf("failed to close dicedb client: %v", err) } return s.httpServer.Shutdown(context.Background()) } func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Server is running"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"}) } func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest) + http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) return } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) return } respStr, ok := resp.(string) if !ok { - log.Println("Error: response is not a string", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + log.Println("error: response is not a string", "error", err) + http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) return } httpResponse := HTTPResponse{Data: respStr} responseJSON, err := json.Marshal(httpResponse) if err != nil { - log.Println("Error marshaling response to JSON", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + log.Println("error marshaling response to json", "error", err) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } _, err = w.Write(responseJSON) if err != nil { - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } } func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Search results"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "search results"}) } diff --git a/internal/tests/dbmocks/mock_dicedb.go b/internal/tests/dbmocks/dicedb_mock.go similarity index 100% rename from internal/tests/dbmocks/mock_dicedb.go rename to internal/tests/dbmocks/dicedb_mock.go diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go index 421dafb..6de1b8e 100644 --- a/internal/tests/integration/ratelimiter_integration_test.go +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -12,8 +12,8 @@ import ( func TestRateLimiterWithinLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) @@ -25,8 +25,8 @@ func TestRateLimiterWithinLimit(t *testing.T) { func TestRateLimiterExceedsLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) diff --git a/internal/middleware/trailing_slash_test.go b/internal/tests/integration/trailing_slash_test.go similarity index 72% rename from internal/middleware/trailing_slash_test.go rename to internal/tests/integration/trailing_slash_test.go index 203fb1a..972158d 100644 --- a/internal/middleware/trailing_slash_test.go +++ b/internal/tests/integration/trailing_slash_test.go @@ -45,19 +45,19 @@ func TestTrailingSlashMiddleware(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", tt.requestURL, nil) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", test.requestURL, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - if w.Code != tt.expectedCode { - t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + if w.Code != test.expectedCode { + t.Errorf("expected status %d, got %d", test.expectedCode, w.Code) } - if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { - t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + if test.expectedUrl != "" && w.Header().Get("Location") != test.expectedUrl { + t.Errorf("expected location %s, got %s", test.expectedUrl, w.Header().Get("Location")) } }) } diff --git a/internal/tests/stress/ratelimiter_stress_test.go b/internal/tests/stress/ratelimiter_stress_test.go index 7e37054..b95b7c2 100644 --- a/internal/tests/stress/ratelimiter_stress_test.go +++ b/internal/tests/stress/ratelimiter_stress_test.go @@ -14,13 +14,13 @@ import ( func TestRateLimiterUnderStress(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec _, r, rateLimiter := util.SetupRateLimiter(limit, window) var wg sync.WaitGroup - var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit + var numRequests int64 = limit + 10 successCount := int64(0) failCount := int64(0) var mu sync.Mutex @@ -43,5 +43,5 @@ func TestRateLimiterUnderStress(t *testing.T) { }() } wg.Wait() - require.Equal(t, limit, successCount, "Should succeed for exactly limit requests") + require.Equal(t, limit, successCount, "should succeed for exactly limit requests") } diff --git a/main.go b/main.go index 09a082f..6b37cde 100644 --- a/main.go +++ b/main.go @@ -6,21 +6,21 @@ import ( "net/http" "server/config" "server/internal/db" - "server/internal/server" // Import the new package for HTTPServer + "server/internal/server" ) func main() { configValue := config.LoadConfig() diceClient, err := db.InitDiceClient(configValue) if err != nil { - log.Fatalf("Failed to initialize dice client: %v", err) + log.Fatalf("Failed to initialize DiceDB client: %v", err) } // Create mux and register routes mux := http.NewServeMux() - httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimit, configValue.RequestWindow) + httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimitPerMin, configValue.RequestWindowSec) mux.HandleFunc("/health", httpServer.HealthCheck) - mux.HandleFunc("/cli/{cmd}", httpServer.CliHandler) + mux.HandleFunc("/shell/exec/{cmd}", httpServer.CliHandler) mux.HandleFunc("/search", httpServer.SearchHandler) // Graceful shutdown context @@ -29,6 +29,6 @@ func main() { // Run the HTTP Server if err := httpServer.Run(ctx); err != nil { - log.Printf("Server failed: %v\n", err) + log.Printf("server failed: %v\n", err) } } diff --git a/internal/cmds/cmds.go b/pkg/util/cmds/cmds.go similarity index 100% rename from internal/cmds/cmds.go rename to pkg/util/cmds/cmds.go diff --git a/pkg/util/helpers.go b/pkg/util/helpers.go index c86cef0..48980b5 100644 --- a/pkg/util/helpers.go +++ b/pkg/util/helpers.go @@ -8,9 +8,9 @@ import ( "log" "net/http" "net/http/httptest" - "server/internal/cmds" "server/internal/middleware" db "server/internal/tests/dbmocks" + cmds "server/pkg/util/cmds" "strings" ) @@ -35,129 +35,140 @@ const ( JSONIngest string = "JSON.INGEST" ) +var priorityKeys = []string{ + Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, +} + +// ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { - command := strings.TrimPrefix(r.URL.Path, "/cli/") + command := extractCommand(r.URL.Path) if command == "" { return nil, errors.New("invalid command") } - command = strings.ToUpper(command) - var args []string + args, err := extractArgsFromRequest(r, command) + if err != nil { + return nil, err + } + + return &cmds.CommandRequest{ + Cmd: command, + Args: args, + }, nil +} + +// extractCommand retrieves and formats the command from the URL path +func extractCommand(path string) string { + command := strings.TrimPrefix(path, "/cli/") + return strings.ToUpper(command) +} - // Extract query parameters +// extractArgsFromRequest extracts arguments from the request's URL query and body +func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { + var args []string queryParams := r.URL.Query() keyPrefix := queryParams.Get(KeyPrefix) if keyPrefix != "" && command == JSONIngest { args = append(args, keyPrefix) } - // Step 1: Handle JSON body if present + if r.Body != nil { - body, err := io.ReadAll(r.Body) + bodyArgs, err := parseRequestBody(r.Body) if err != nil { return nil, err } + args = append(args, bodyArgs...) + } - if len(body) > 0 { - var jsonBody map[string]interface{} - if err := json.Unmarshal(body, &jsonBody); err != nil { - return nil, err - } + return args, nil +} - if len(jsonBody) == 0 { - return nil, fmt.Errorf("empty JSON object") - } +// parseRequestBody parses the body of the request and extracts arguments from JSON content +func parseRequestBody(body io.ReadCloser) ([]string, error) { + var args []string + bodyContent, err := io.ReadAll(body) + if err != nil { + return nil, err + } - // Define keys to exclude and process their values first - // Update as we support more commands - var priorityKeys = []string{ - Key, - Keys, - Field, - Path, - Value, - Values, - Seconds, - User, - Password, - KeyValues, - QwatchQuery, - Offset, - Member, - Members, - } - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - if key == Keys { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Values { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - // MultiKey operations - if key == KeyValues { - // Handle KeyValues separately - for k, v := range val.(map[string]interface{}) { - args = append(args, k, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Members { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - args = append(args, fmt.Sprintf("%v", val)) - delete(jsonBody, key) - } + if len(bodyContent) == 0 { + return args, nil + } + + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { + return nil, err + } + + if len(jsonBody) == 0 { + return nil, fmt.Errorf("empty JSON object") + } + + args = append(args, extractPriorityArgs(jsonBody)...) + args = append(args, extractRemainingArgs(jsonBody)...) + + return args, nil +} + +// extractPriorityArgs extracts arguments for priority keys from the JSON body +func extractPriorityArgs(jsonBody map[string]interface{}) []string { + var args []string + for _, key := range priorityKeys { + if val, exists := jsonBody[key]; exists { + switch key { + case Keys, Values, Members: + args = append(args, convertListToStrings(val.([]interface{}))...) + case KeyValues: + args = append(args, convertMapToStrings(val.(map[string]interface{}))...) + default: + args = append(args, fmt.Sprintf("%v", val)) } + delete(jsonBody, key) + } + } + return args +} - // Process remaining keys in the JSON body - for key, val := range jsonBody { - switch v := val.(type) { - case string: - // Handle unary operations like 'nx' where value is "true" - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - // Marshal nested JSON structures back into a string - jsonValue, err := json.Marshal(v) - if err != nil { - return nil, err - } - args = append(args, string(jsonValue)) - default: - args = append(args, key) - // Append other types as strings - value := fmt.Sprintf("%v", v) - if !strings.EqualFold(value, True) { - args = append(args, value) - } - } +// extractRemainingArgs processes any remaining non-priority arguments in the JSON body +func extractRemainingArgs(jsonBody map[string]interface{}) []string { + var args []string + for key, val := range jsonBody { + switch v := val.(type) { + case string: + args = append(args, key) + if !strings.EqualFold(v, True) { + args = append(args, v) } + case map[string]interface{}, []interface{}: + jsonValue, _ := json.Marshal(v) + args = append(args, string(jsonValue)) + default: + args = append(args, key, fmt.Sprintf("%v", v)) } } + return args +} - // Step 2: Return the constructed Redis command - return &cmds.CommandRequest{ - Cmd: command, - Args: args, - }, nil +// convertListToStrings converts a list of interface{} to a list of strings +func convertListToStrings(list []interface{}) []string { + var result []string + for _, v := range list { + result = append(result, fmt.Sprintf("%v", v)) + } + return result +} + +// convertMapToStrings converts a map of key-value pairs to a flat list of strings +func convertMapToStrings(m map[string]interface{}) []string { + var result []string + for k, v := range m { + result = append(result, k, fmt.Sprintf("%v", v)) + } + return result } +// JSONResponse sends a JSON response to the client func JSONResponse(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) @@ -166,6 +177,7 @@ func JSONResponse(w http.ResponseWriter, status int, data interface{}) { } } +// MockHandler is a basic mock handler for testing func MockHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) if _, err := w.Write([]byte("OK")); err != nil { @@ -173,10 +185,11 @@ func MockHandler(w http.ResponseWriter, r *http.Request) { } } +// SetupRateLimiter sets up a rate limiter for testing purposes func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { mockClient := db.NewDiceDBMock() - r := httptest.NewRequest("GET", "/cli/somecommand", http.NoBody) + r := httptest.NewRequest("GET", "/shell/exec/get", http.NoBody) w := httptest.NewRecorder() rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window)