Skip to content

Commit

Permalink
Adding support for generic command execution (DiceDB#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucifercr07 authored Oct 5, 2024
1 parent b42d332 commit c9eaa28
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 170 deletions.
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func LoadConfig() *Config {
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit
RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:3000"}), // Default allowed origins
}
}

Expand Down
16 changes: 0 additions & 16 deletions internal/db/commands.go

This file was deleted.

68 changes: 26 additions & 42 deletions internal/db/dicedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ import (
dicedb "github.com/dicedb/go-dice"
)

const (
RespOK = "OK"
)

type DiceDB struct {
Client *dicedb.Client
Ctx context.Context
Expand Down Expand Up @@ -56,47 +52,35 @@ func InitDiceClient(configValue *config.Config) (*DiceDB, error) {

// ExecuteCommand executes a command based on the input
func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, error) {
switch command.Cmd {
case "GET":
if len(command.Args) != 1 {
return nil, errors.New("invalid args")
}

val, err := db.getKey(command.Args[0])
switch {
case errors.Is(err, dicedb.Nil):
return nil, errors.New("key does not exist")
case err != nil:
return nil, fmt.Errorf("get failed %v", err)
}

return val, nil

case "SET":
if len(command.Args) < 2 {
return nil, errors.New("key is required")
}

err := db.setKey(command.Args[0], command.Args[1])
if err != nil {
return nil, errors.New("failed to set key")
}

return RespOK, nil

case "DEL":
if len(command.Args) == 0 {
return nil, errors.New("at least one key is required")
}
args := make([]interface{}, 0, len(command.Args)+1)
args = append(args, command.Cmd)
for _, arg := range command.Args {
args = append(args, arg)
}

err := db.deleteKeys(command.Args)
if err != nil {
return nil, errors.New("failed to delete keys")
}
res, err := db.Client.Do(db.Ctx, args...).Result()
if errors.Is(err, dicedb.Nil) {
return nil, errors.New("(nil)")
}

return RespOK, nil
if err != nil {
return nil, fmt.Errorf("(error) %v", err)
}

// Print the result based on its type
switch v := res.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
case []interface{}:
case int64:
return fmt.Sprintf("%v", v), nil
case nil:
return "(nil)", nil
default:
return nil, errors.New("unknown command")
return fmt.Sprintf("%v", v), nil
}

return nil, fmt.Errorf("(error) invalid result type")
}
13 changes: 10 additions & 3 deletions internal/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@ import (
"server/config"
)

func enableCors(w http.ResponseWriter, r *http.Request) {
// Updated enableCors function to return a boolean indicating if OPTIONS was handled
func handleCors(w http.ResponseWriter, r *http.Request) bool {
configValue := config.LoadConfig()
allAllowedOrigins := configValue.AllowedOrigins
origin := r.Header.Get("Origin")
allowed := false

for _, allowedOrigin := range allAllowedOrigins {
if origin == allowedOrigin || allowedOrigin == "*" || origin == "" {
allowed = true
break
}
}

if !allowed {
http.Error(w, "CORS: origin not allowed", http.StatusForbidden)
return
return true
}

w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length")

// If the request is an OPTIONS request, handle it and stop further processing
if r.Method == http.MethodOptions {
w.Header().Set("Access-Control-Max-Age", "86400")
w.WriteHeader(http.StatusOK)
return
return true
}

// Continue processing other requests
w.Header().Set("Content-Type", "application/json")
return false
}
10 changes: 8 additions & 2 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
enableCors(w, r)
if handleCors(w, r) {
return
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

Expand Down Expand Up @@ -79,7 +82,10 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float

func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
enableCors(w, r)
if handleCors(w, r) {
return
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

Expand Down
3 changes: 2 additions & 1 deletion internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func errorResponse(response string) string {
slog.Error("Error marshaling response: %v", slog.Any("err", err))
return `{"error": "internal server error"}`
}

return string(jsonResponse)
}

Expand Down Expand Up @@ -110,7 +111,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {

resp, err := s.DiceClient.ExecuteCommand(diceCmd)
if err != nil {
http.Error(w, errorResponse("error executing command"), http.StatusBadRequest)
http.Error(w, errorResponse(err.Error()), http.StatusBadRequest)
return
}

Expand Down
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"log/slog"
"net/http"
"os"
"server/config"
"server/internal/db"
"server/internal/server"
Expand All @@ -14,6 +15,7 @@ func main() {
diceClient, err := db.InitDiceClient(configValue)
if err != nil {
slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err))
os.Exit(1)
}

// Create mux and register routes
Expand Down
4 changes: 2 additions & 2 deletions util/cmds/cmds.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package cmds

type CommandRequest struct {
Cmd string `json:"cmd"`
Args []string
Cmd string `json:"cmd"`
Args []string `json:"args"`
}
114 changes: 11 additions & 103 deletions util/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,43 +10,18 @@ import (
"net/http/httptest"
"server/internal/middleware"
db "server/internal/tests/dbmocks"
cmds "server/util/cmds"
"server/util/cmds"
"strings"
)

const (
Key = "key"
Keys = "keys"
KeyPrefix = "key_prefix"
Field = "field"
Path = "path"
Value = "value"
Values = "values"
User = "user"
Password = "password"
Seconds = "seconds"
KeyValues = "key_values"
True = "true"
QwatchQuery = "query"
Offset = "offset"
Member = "member"
Members = "members"

JSONIngest string = "JSON.INGEST"
)

var priorityKeys = []string{
Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members,
}

// ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands
func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) {
command := extractCommand(r.URL.Path)
if command == "" {
return nil, errors.New("invalid command")
}

args, err := extractArgsFromRequest(r, command)
args, err := newExtractor(r)
if err != nil {
return nil, err
}
Expand All @@ -62,29 +37,9 @@ func extractCommand(path string) string {
return strings.ToUpper(command)
}

func extractArgsFromRequest(r *http.Request, command string) ([]string, error) {
func newExtractor(r *http.Request) ([]string, error) {
var args []string
queryParams := r.URL.Query()
keyPrefix := queryParams.Get(KeyPrefix)

if keyPrefix != "" && command == JSONIngest {
args = append(args, keyPrefix)
}

if r.Body != nil {
bodyArgs, err := parseRequestBody(r.Body)
if err != nil {
return nil, err
}
args = append(args, bodyArgs...)
}

return args, nil
}

func parseRequestBody(body io.ReadCloser) ([]string, error) {
var args []string
bodyContent, err := io.ReadAll(body)
bodyContent, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
Expand All @@ -93,7 +48,7 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) {
return args, nil
}

var jsonBody map[string]interface{}
var jsonBody []interface{}
if err := json.Unmarshal(bodyContent, &jsonBody); err != nil {
return nil, err
}
Expand All @@ -102,63 +57,16 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) {
return nil, fmt.Errorf("empty JSON object")
}

args = append(args, extractPriorityArgs(jsonBody)...)
args = append(args, extractRemainingArgs(jsonBody)...)

return args, nil
}

func extractPriorityArgs(jsonBody map[string]interface{}) []string {
var args []string
for _, key := range priorityKeys {
if val, exists := jsonBody[key]; exists {
switch key {
case Keys, Values, Members:
args = append(args, convertListToStrings(val.([]interface{}))...)
case KeyValues:
args = append(args, convertMapToStrings(val.(map[string]interface{}))...)
default:
args = append(args, fmt.Sprintf("%v", val))
}
delete(jsonBody, key)
}
}
return args
}

func extractRemainingArgs(jsonBody map[string]interface{}) []string {
var args []string
for key, val := range jsonBody {
switch v := val.(type) {
case string:
args = append(args, key)
if !strings.EqualFold(v, True) {
args = append(args, v)
}
case map[string]interface{}, []interface{}:
jsonValue, _ := json.Marshal(v)
args = append(args, string(jsonValue))
default:
args = append(args, key, fmt.Sprintf("%v", v))
for _, val := range jsonBody {
s, ok := val.(string)
if !ok {
return nil, fmt.Errorf("invalid input")
}
}
return args
}

func convertListToStrings(list []interface{}) []string {
var result []string
for _, v := range list {
result = append(result, fmt.Sprintf("%v", v))
args = append(args, s)
}
return result
}

func convertMapToStrings(m map[string]interface{}) []string {
var result []string
for k, v := range m {
result = append(result, k, fmt.Sprintf("%v", v))
}
return result
return args, nil
}

// JSONResponse sends a JSON response to the client
Expand Down

0 comments on commit c9eaa28

Please sign in to comment.