From c9eaa28a8e1d298503d29d8e75cb8c502c352c3a Mon Sep 17 00:00:00 2001 From: Prashant Shubham Date: Sat, 5 Oct 2024 23:22:39 +0530 Subject: [PATCH] Adding support for generic command execution (#26) --- config/config.go | 2 +- internal/db/commands.go | 16 ---- internal/db/dicedb.go | 68 +++++++---------- internal/middleware/cors.go | 13 +++- internal/middleware/ratelimiter.go | 10 ++- internal/server/http.go | 3 +- main.go | 2 + util/cmds/cmds.go | 4 +- util/helpers.go | 114 +++-------------------------- 9 files changed, 62 insertions(+), 170 deletions(-) delete mode 100644 internal/db/commands.go diff --git a/config/config.go b/config/config.go index b0f0b9e..433c5f9 100644 --- a/config/config.go +++ b/config/config.go @@ -30,7 +30,7 @@ func LoadConfig() *Config { ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:3000"}), // Default allowed origins } } diff --git a/internal/db/commands.go b/internal/db/commands.go deleted file mode 100644 index f78b616..0000000 --- a/internal/db/commands.go +++ /dev/null @@ -1,16 +0,0 @@ -package db - -func (db *DiceDB) getKey(key string) (string, error) { - val, err := db.Client.Get(db.Ctx, key).Result() - return val, err -} - -func (db *DiceDB) setKey(key, value string) error { - err := db.Client.Set(db.Ctx, key, value, 0).Err() - return err -} - -func (db *DiceDB) deleteKeys(keys []string) error { - err := db.Client.Del(db.Ctx, keys...).Err() - return err -} diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index 1701cf6..71d4901 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -17,10 +17,6 @@ import ( dicedb "github.com/dicedb/go-dice" ) -const ( - RespOK = "OK" -) - type DiceDB struct { Client *dicedb.Client Ctx context.Context @@ -56,47 +52,35 @@ func InitDiceClient(configValue *config.Config) (*DiceDB, error) { // ExecuteCommand executes a command based on the input func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, error) { - switch command.Cmd { - case "GET": - if len(command.Args) != 1 { - return nil, errors.New("invalid args") - } - - val, err := db.getKey(command.Args[0]) - switch { - case errors.Is(err, dicedb.Nil): - return nil, errors.New("key does not exist") - case err != nil: - return nil, fmt.Errorf("get failed %v", err) - } - - return val, nil - - case "SET": - if len(command.Args) < 2 { - return nil, errors.New("key is required") - } - - err := db.setKey(command.Args[0], command.Args[1]) - if err != nil { - return nil, errors.New("failed to set key") - } - - return RespOK, nil - - case "DEL": - if len(command.Args) == 0 { - return nil, errors.New("at least one key is required") - } + args := make([]interface{}, 0, len(command.Args)+1) + args = append(args, command.Cmd) + for _, arg := range command.Args { + args = append(args, arg) + } - err := db.deleteKeys(command.Args) - if err != nil { - return nil, errors.New("failed to delete keys") - } + res, err := db.Client.Do(db.Ctx, args...).Result() + if errors.Is(err, dicedb.Nil) { + return nil, errors.New("(nil)") + } - return RespOK, nil + if err != nil { + return nil, fmt.Errorf("(error) %v", err) + } + // Print the result based on its type + switch v := res.(type) { + case string: + return v, nil + case []byte: + return string(v), nil + case []interface{}: + case int64: + return fmt.Sprintf("%v", v), nil + case nil: + return "(nil)", nil default: - return nil, errors.New("unknown command") + return fmt.Sprintf("%v", v), nil } + + return nil, fmt.Errorf("(error) invalid result type") } diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 63bc5d9..3e75120 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -5,30 +5,37 @@ import ( "server/config" ) -func enableCors(w http.ResponseWriter, r *http.Request) { +// Updated enableCors function to return a boolean indicating if OPTIONS was handled +func handleCors(w http.ResponseWriter, r *http.Request) bool { configValue := config.LoadConfig() allAllowedOrigins := configValue.AllowedOrigins origin := r.Header.Get("Origin") allowed := false + for _, allowedOrigin := range allAllowedOrigins { if origin == allowedOrigin || allowedOrigin == "*" || origin == "" { allowed = true break } } + if !allowed { http.Error(w, "CORS: origin not allowed", http.StatusForbidden) - return + return true } w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length") + // If the request is an OPTIONS request, handle it and stop further processing if r.Method == http.MethodOptions { w.Header().Set("Access-Control-Max-Age", "86400") w.WriteHeader(http.StatusOK) - return + return true } + + // Continue processing other requests w.Header().Set("Content-Type", "application/json") + return false } diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c852a26..c2b43e3 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -18,7 +18,10 @@ import ( // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -79,7 +82,10 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/server/http.go b/internal/server/http.go index 9b12ec3..00d6a3e 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -42,6 +42,7 @@ func errorResponse(response string) string { slog.Error("Error marshaling response: %v", slog.Any("err", err)) return `{"error": "internal server error"}` } + return string(jsonResponse) } @@ -110,7 +111,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } diff --git a/main.go b/main.go index b01de76..99c98e3 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "net/http" + "os" "server/config" "server/internal/db" "server/internal/server" @@ -14,6 +15,7 @@ func main() { diceClient, err := db.InitDiceClient(configValue) if err != nil { slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err)) + os.Exit(1) } // Create mux and register routes diff --git a/util/cmds/cmds.go b/util/cmds/cmds.go index bb7a275..1e68d08 100644 --- a/util/cmds/cmds.go +++ b/util/cmds/cmds.go @@ -1,6 +1,6 @@ package cmds type CommandRequest struct { - Cmd string `json:"cmd"` - Args []string + Cmd string `json:"cmd"` + Args []string `json:"args"` } diff --git a/util/helpers.go b/util/helpers.go index 20b3c2b..20a605a 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -10,35 +10,10 @@ import ( "net/http/httptest" "server/internal/middleware" db "server/internal/tests/dbmocks" - cmds "server/util/cmds" + "server/util/cmds" "strings" ) -const ( - Key = "key" - Keys = "keys" - KeyPrefix = "key_prefix" - Field = "field" - Path = "path" - Value = "value" - Values = "values" - User = "user" - Password = "password" - Seconds = "seconds" - KeyValues = "key_values" - True = "true" - QwatchQuery = "query" - Offset = "offset" - Member = "member" - Members = "members" - - JSONIngest string = "JSON.INGEST" -) - -var priorityKeys = []string{ - Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, -} - // ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { command := extractCommand(r.URL.Path) @@ -46,7 +21,7 @@ func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { return nil, errors.New("invalid command") } - args, err := extractArgsFromRequest(r, command) + args, err := newExtractor(r) if err != nil { return nil, err } @@ -62,29 +37,9 @@ func extractCommand(path string) string { return strings.ToUpper(command) } -func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { +func newExtractor(r *http.Request) ([]string, error) { var args []string - queryParams := r.URL.Query() - keyPrefix := queryParams.Get(KeyPrefix) - - if keyPrefix != "" && command == JSONIngest { - args = append(args, keyPrefix) - } - - if r.Body != nil { - bodyArgs, err := parseRequestBody(r.Body) - if err != nil { - return nil, err - } - args = append(args, bodyArgs...) - } - - return args, nil -} - -func parseRequestBody(body io.ReadCloser) ([]string, error) { - var args []string - bodyContent, err := io.ReadAll(body) + bodyContent, err := io.ReadAll(r.Body) if err != nil { return nil, err } @@ -93,7 +48,7 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) { return args, nil } - var jsonBody map[string]interface{} + var jsonBody []interface{} if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { return nil, err } @@ -102,63 +57,16 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) { return nil, fmt.Errorf("empty JSON object") } - args = append(args, extractPriorityArgs(jsonBody)...) - args = append(args, extractRemainingArgs(jsonBody)...) - - return args, nil -} - -func extractPriorityArgs(jsonBody map[string]interface{}) []string { - var args []string - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - switch key { - case Keys, Values, Members: - args = append(args, convertListToStrings(val.([]interface{}))...) - case KeyValues: - args = append(args, convertMapToStrings(val.(map[string]interface{}))...) - default: - args = append(args, fmt.Sprintf("%v", val)) - } - delete(jsonBody, key) - } - } - return args -} - -func extractRemainingArgs(jsonBody map[string]interface{}) []string { - var args []string - for key, val := range jsonBody { - switch v := val.(type) { - case string: - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - jsonValue, _ := json.Marshal(v) - args = append(args, string(jsonValue)) - default: - args = append(args, key, fmt.Sprintf("%v", v)) + for _, val := range jsonBody { + s, ok := val.(string) + if !ok { + return nil, fmt.Errorf("invalid input") } - } - return args -} -func convertListToStrings(list []interface{}) []string { - var result []string - for _, v := range list { - result = append(result, fmt.Sprintf("%v", v)) + args = append(args, s) } - return result -} -func convertMapToStrings(m map[string]interface{}) []string { - var result []string - for k, v := range m { - result = append(result, k, fmt.Sprintf("%v", v)) - } - return result + return args, nil } // JSONResponse sends a JSON response to the client