Skip to content

Commit

Permalink
Disable list of commands from playground repositories #897 - Command…
Browse files Browse the repository at this point in the history
…s Blacklisted (DiceDB#23)
  • Loading branch information
yashbudhia authored Oct 6, 2024
1 parent f255c5f commit d373641
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
10 changes: 3 additions & 7 deletions internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (
"errors"
"log/slog"
"net/http"
"server/internal/middleware"
"strings"
"sync"
"time"

"server/internal/db"
"server/internal/middleware"
util "server/util"
)

Expand All @@ -20,8 +20,6 @@ type HTTPServer struct {
DiceClient *db.DiceDB
}

// HandlerMux wraps ServeMux and forces REST paths to lowercase
// and attaches a rate limiter with the handler
type HandlerMux struct {
mux *http.ServeMux
rateLimiter func(http.ResponseWriter, *http.Request, http.Handler)
Expand All @@ -47,10 +45,8 @@ func errorResponse(response string) string {
}

func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Convert the path to lowercase before passing to the underlying mux.
middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.ToLower(r.URL.Path)
// Apply rate limiter
cim.rateLimiter(w, r, cim.mux)
})).ServeHTTP(w, r)
}
Expand Down Expand Up @@ -105,7 +101,7 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) {
func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {
diceCmd, err := util.ParseHTTPRequest(r)
if err != nil {
http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest)
http.Error(w, errorResponse(err.Error()), http.StatusBadRequest)
return
}

Expand All @@ -119,7 +115,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {
respStr, ok := resp.(string)
if !ok {
slog.Error("error: response is not a string", "error", slog.Any("err", err))
http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError)
http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError)
return
}

Expand Down
36 changes: 36 additions & 0 deletions util/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,49 @@ import (
"strings"
)

// Map of blocklisted commands
var blocklistedCommands = map[string]bool{
"FLUSHALL": true,
"FLUSHDB": true,
"DUMP": true,
"ABORT": true,
"AUTH": true,
"CONFIG": true,
"SAVE": true,
"BGSAVE": true,
"BGREWRITEAOF": true,
"RESTORE": true,
"MULTI": true,
"EXEC": true,
"DISCARD": true,
"QWATCH": true,
"QUNWATCH": true,
"LATENCY": true,
"CLIENT": true,
"SLEEP": true,
"PERSIST": true,
}

// BlockListedCommand checks if a command is blocklisted
func BlockListedCommand(cmd string) error {
if _, exists := blocklistedCommands[strings.ToUpper(cmd)]; exists {
return errors.New("ERR unknown command '" + cmd + "'")
}
return nil
}

// ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands
func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) {
command := extractCommand(r.URL.Path)
if command == "" {
return nil, errors.New("invalid command")
}

// Check if the command is blocklisted
if err := BlockListedCommand(command); err != nil {
return nil, err
}

args, err := newExtractor(r)
if err != nil {
return nil, err
Expand Down

0 comments on commit d373641

Please sign in to comment.