Skip to content

Commit

Permalink
DiceDB#21: Refactored repo for consistency (DiceDB#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
rishavvajpayee authored and yashbudhia committed Oct 6, 2024
1 parent d83c716 commit 11565ac
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 257 deletions.
20 changes: 10 additions & 10 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (

// Config holds the application configuration
type Config struct {
DiceAddr string
ServerPort string
RequestLimit int64 // Field for the request limit
RequestWindow float64 // Field for the time window in float64
AllowedOrigins []string // Field for the allowed origins
DiceDBAddr string
ServerPort string
RequestLimitPerMin int64 // Field for the request limit
RequestWindowSec float64 // Field for the time window in float64
AllowedOrigins []string // Field for the allowed origins
}

// LoadConfig loads the application configuration from environment variables or defaults
Expand All @@ -26,11 +26,11 @@ func LoadConfig() *Config {
}

return &Config{
DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit
RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins
DiceDBAddr: getEnv("DICEDB_ADDR", "localhost:7379"), // Default DiceDB address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit
RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins
}
}

Expand Down
14 changes: 7 additions & 7 deletions internal/db/dicedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ import (
"log/slog"
"os"
"server/config"
"server/internal/cmds"
"server/util/cmds"
"time"

dice "github.com/dicedb/go-dice"
dicedb "github.com/dicedb/go-dice"
)

const (
RespOK = "OK"
)

type DiceDB struct {
Client *dice.Client
Client *dicedb.Client
Ctx context.Context
}

Expand All @@ -36,13 +36,13 @@ func (db *DiceDB) CloseDiceDB() {
}

func InitDiceClient(configValue *config.Config) (*DiceDB, error) {
diceClient := dice.NewClient(&dice.Options{
Addr: configValue.DiceAddr,
diceClient := dicedb.NewClient(&dicedb.Options{
Addr: configValue.DiceDBAddr,
DialTimeout: 10 * time.Second,
MaxRetries: 10,
})

// Ping the dice client to verify the connection
// Ping the dicedb client to verify the connection
err := diceClient.Ping(context.Background()).Err()
if err != nil {
return nil, err
Expand All @@ -64,7 +64,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err

val, err := db.getKey(command.Args[0])
switch {
case errors.Is(err, dice.Nil):
case errors.Is(err, dicedb.Nil):
return nil, errors.New("key does not exist")
case err != nil:
return nil, fmt.Errorf("get failed %v", err)
Expand Down
17 changes: 4 additions & 13 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@ import (
"strings"
"time"

dice "github.com/dicedb/go-dice"
dicedb "github.com/dicedb/go-dice"
)

// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Enable CORS for requests
enableCors(w, r)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Skip rate limiting for non-command endpoints
if !strings.Contains(r.URL.Path, "/cli/") {
if !strings.Contains(r.URL.Path, "/shell/exec/") {
next.ServeHTTP(w, r)
return
}
Expand All @@ -36,7 +34,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float

// Fetch the current request count
val, err := client.Client.Get(ctx, key).Result()
if err != nil && !errors.Is(err, dice.Nil) {
if err != nil && !errors.Is(err, dicedb.Nil) {
slog.Error("Error fetching request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -74,25 +72,19 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float
}
}

// Log the successful request increment
slog.Info("Request processed", "count", requestCount+1)

// Call the next handler
next.ServeHTTP(w, r)
})
}

func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Enable CORS for requests
enableCors(w, r)

// Set a request context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Only apply rate limiting for specific paths (e.g., "/cli/")
if !strings.Contains(r.URL.Path, "/cli/") {
if !strings.Contains(r.URL.Path, "/shell/exec/") {
next.ServeHTTP(w, r)
return
}
Expand Down Expand Up @@ -144,7 +136,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
}
}

// Log the successful request and pass control to the next handler
slog.Info("Request processed", "count", requestCount)
next.ServeHTTP(w, r)
})
Expand Down
2 changes: 0 additions & 2 deletions internal/middleware/trailingslash.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
func TrailingSlashMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") {
// remove slash
newPath := strings.TrimSuffix(r.URL.Path, "/")
// if query params exist append them
newURL := newPath
if r.URL.RawQuery != "" {
newURL += "?" + r.URL.RawQuery
Expand Down
39 changes: 22 additions & 17 deletions internal/server/httpServer.go → internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
"net/http"
"strings"
"sync"
"time"

"server/internal/middleware"
"server/internal/db"
util "server/pkg/util"
util "server/util"
)

type HTTPServer struct {
Expand All @@ -37,7 +36,13 @@ type HTTPErrorResponse struct {
}

func errorResponse(response string) string {
return fmt.Sprintf("{\"error\": %q}", response)
errorMessage := map[string]string{"error": response}
jsonResponse, err := json.Marshal(errorMessage)
if err != nil {
slog.Error("Error marshaling response: %v", slog.Any("err", err))
return `{"error": "internal server error"}`
}
return string(jsonResponse)
}

func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -73,33 +78,33 @@ func (s *HTTPServer) Run(ctx context.Context) error {
wg.Add(1)
go func() {
defer wg.Done()
log.Printf("Starting server at %s\n", s.httpServer.Addr)
slog.Info("starting server at", slog.String("addr", s.httpServer.Addr))
if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("HTTP server error: %v", err)
slog.Error("http server error: %v", slog.Any("err", err))
}
}()

<-ctx.Done()
log.Println("Shutting down server...")
slog.Info("shutting down server...")
return s.Shutdown()
}

func (s *HTTPServer) Shutdown() error {
if err := s.DiceClient.Client.Close(); err != nil {
log.Printf("Failed to close dice client: %v", err)
slog.Error("failed to close dicedb client: %v", slog.Any("err", err))
}

return s.httpServer.Shutdown(context.Background())
}

func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) {
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Server is running"})
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"})
}

func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {
diceCmd, err := util.ParseHTTPRequest(r)
if err != nil {
http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest)
http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest)
return
}

Expand All @@ -112,32 +117,32 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {

resp, err := s.DiceClient.ExecuteCommand(diceCmd)
if err != nil {
http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest)
http.Error(w, errorResponse("error executing command"), http.StatusBadRequest)
return
}

respStr, ok := resp.(string)
if !ok {
log.Println("Error: response is not a string", "error", err)
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
slog.Error("error: response is not a string", "error", slog.Any("err", err))
http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError)
return
}

httpResponse := HTTPResponse{Data: respStr}
responseJSON, err := json.Marshal(httpResponse)
if err != nil {
log.Println("Error marshaling response to JSON", "error", err)
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
slog.Error("error marshaling response to json", "error", slog.Any("err", err))
http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError)
return
}

_, err = w.Write(responseJSON)
if err != nil {
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError)
return
}
}

func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) {
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Search results"})
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "search results"})
}
File renamed without changes.
10 changes: 5 additions & 5 deletions internal/tests/integration/ratelimiter_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ import (
"net/http"
"net/http/httptest"
config "server/config"
util "server/pkg/util"
util "server/util"
"testing"

"github.com/stretchr/testify/require"
)

func TestRateLimiterWithinLimit(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimit
window := configValue.RequestWindow
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

w, r, rateLimiter := util.SetupRateLimiter(limit, window)

Expand All @@ -25,8 +25,8 @@ func TestRateLimiterWithinLimit(t *testing.T) {

func TestRateLimiterExceedsLimit(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimit
window := configValue.RequestWindow
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

w, r, rateLimiter := util.SetupRateLimiter(limit, window)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ func TestTrailingSlashMiddleware(t *testing.T) {
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tt.requestURL, nil)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", test.requestURL, nil)
w := httptest.NewRecorder()

handler.ServeHTTP(w, req)

if w.Code != tt.expectedCode {
t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code)
if w.Code != test.expectedCode {
t.Errorf("expected status %d, got %d", test.expectedCode, w.Code)
}

if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl {
t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location"))
if test.expectedUrl != "" && w.Header().Get("Location") != test.expectedUrl {
t.Errorf("expected location %s, got %s", test.expectedUrl, w.Header().Get("Location"))
}
})
}
Expand Down
10 changes: 5 additions & 5 deletions internal/tests/stress/ratelimiter_stress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"net/http"
"net/http/httptest"
"server/config"
util "server/pkg/util"
util "server/util"
"sync"
"testing"
"time"
Expand All @@ -14,13 +14,13 @@ import (

func TestRateLimiterUnderStress(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimit
window := configValue.RequestWindow
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

_, r, rateLimiter := util.SetupRateLimiter(limit, window)

var wg sync.WaitGroup
var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit
var numRequests int64 = limit
successCount := int64(0)
failCount := int64(0)
var mu sync.Mutex
Expand All @@ -43,5 +43,5 @@ func TestRateLimiterUnderStress(t *testing.T) {
}()
}
wg.Wait()
require.Equal(t, limit, successCount, "Should succeed for exactly limit requests")
require.Equal(t, limit, successCount, "should succeed for exactly limit requests")
}
Loading

0 comments on commit 11565ac

Please sign in to comment.