Skip to content

Commit

Permalink
#733 & #894 : integration and stress tests for ratelimiter || added C…
Browse files Browse the repository at this point in the history
…ORS to server (#13)
  • Loading branch information
rishavvajpayee authored Oct 4, 2024
1 parent 7894fd2 commit 12a0fef
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 40 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
.vscode/
.env
/playground-mono
__debug_bin*
59 changes: 49 additions & 10 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,36 @@
package config

import (
"fmt"
"os"
"strconv"
"strings"

"github.com/joho/godotenv"
)

// Config holds the application configuration
type Config struct {
DiceAddr string
ServerPort string
RequestLimit int // Field for the request limit
RequestWindow int // Field for the time window in seconds
DiceAddr string
ServerPort string
RequestLimit int64 // Field for the request limit
RequestWindow float64 // Field for the time window in float64
AllowedOrigins []string // Field for the allowed origins
}

// LoadConfig loads the application configuration from environment variables or defaults
func LoadConfig() *Config {
err := godotenv.Load()
if err != nil {
fmt.Println("Warning: .env file not found, falling back to system environment variables.")
}

return &Config{
DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit
RequestWindow: getEnvInt("REQUEST_WINDOW", 60), // Default request window in seconds
DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit
RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins
}
}

Expand All @@ -32,11 +43,39 @@ func getEnv(key, fallback string) string {
}

// getEnvInt retrieves an environment variable as an integer or returns a default value
func getEnvInt(key string, fallback int) int {
func getEnvInt(key string, fallback int) int64 {
if value, exists := os.LookupEnv(key); exists {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
return int64(intValue)
}
}
return int64(fallback)
}

// added for miliseconds request window controls
func getEnvFloat64(key string, fallback float64) float64 {
if value, exists := os.LookupEnv(key); exists {
if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
return floatValue
}
}
return fallback
}

func getEnvArray(key string, fallback []string) []string {
if value, exists := os.LookupEnv(key); exists {
if arrayValue := splitString(value); len(arrayValue) > 0 {
return arrayValue
}
}
return fallback
}

// splitString splits a string by comma and returns a slice of strings
func splitString(s string) []string {
var array []string
for _, v := range strings.Split(s, ",") {
array = append(array, strings.TrimSpace(v))
}
return array
}
10 changes: 9 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@ module server

go 1.22.5

require (
github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831
github.com/joho/godotenv v1.5.1
github.com/stretchr/testify v1.9.0
)

require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
18 changes: 16 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 h1:Cqyj9WCtoobN6++bFbDSe27q94SPwJD9Z0wmu+SDRuk=
github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831/go.mod h1:8+VZrr14c2LW8fW4tWZ8Bv3P2lfvlg+PpsSn5cWWuiQ=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
34 changes: 34 additions & 0 deletions internal/middleware/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package middleware

import (
"net/http"
"server/config"
)

func enableCors(w http.ResponseWriter, r *http.Request) {
configValue := config.LoadConfig()
allAllowedOrigins := configValue.AllowedOrigins
origin := r.Header.Get("Origin")
allowed := false
for _, allowedOrigin := range allAllowedOrigins {
if origin == allowedOrigin || allowedOrigin == "*" || origin == "" {
allowed = true
break
}
}
if !allowed {
http.Error(w, "CORS: origin not allowed", http.StatusForbidden)
return
}

w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length")

if r.Method == http.MethodOptions {
w.Header().Set("Access-Control-Max-Age", "86400")
w.WriteHeader(http.StatusOK)
return
}
w.Header().Set("Content-Type", "application/json")
}
95 changes: 75 additions & 20 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,22 @@ import (
"log/slog"
"net/http"
"server/internal/db"
mock "server/internal/tests/dbmocks"
"strconv"
"strings"
"time"

dice "github.com/dicedb/go-dice"
)

// TODO: Look at this later
func enableCors(w http.ResponseWriter) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
}

// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.Handler {
func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Enable CORS for requests
enableCors(w, r)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Set CORS headers
enableCors(w)

// Handle OPTIONS preflight request
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}

// Skip rate limiting for non-command endpoints
if !strings.Contains(r.URL.Path, "/cli/") {
next.ServeHTTP(w, r)
Expand All @@ -56,9 +43,9 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H
}

// Initialize request count
requestCount := 0
requestCount := int64(0)
if val != "" {
requestCount, err = strconv.Atoi(val)
requestCount, err = strconv.ParseInt(val, 10, 64)
if err != nil {
slog.Error("Error converting request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
Expand All @@ -67,7 +54,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H
}

// Check if the request count exceeds the limit
if requestCount >= limit {
if requestCount > limit {
slog.Warn("Request limit exceeded", "count", requestCount)
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
return
Expand All @@ -94,3 +81,71 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit, window int) http.H
next.ServeHTTP(w, r)
})
}

func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Enable CORS for requests
enableCors(w, r)

// Set a request context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Only apply rate limiting for specific paths (e.g., "/cli/")
if !strings.Contains(r.URL.Path, "/cli/") {
next.ServeHTTP(w, r)
return
}

// Generate the rate limiting key based on the current window
currentWindow := time.Now().Unix() / int64(window)
key := fmt.Sprintf("request_count:%d", currentWindow)
slog.Info("Created rate limiter key", slog.Any("key", key))

// Get the current request count for this window from the mock DB
val, err := client.Get(ctx, key)
if err != nil {
slog.Error("Error fetching request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// Parse the current request count or initialize to 0
var requestCount int64 = 0
if val != "" {
requestCount, err = strconv.ParseInt(val, 10, 64)
if err != nil {
slog.Error("Error converting request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}

// Check if the request limit has been exceeded
if requestCount >= limit {
slog.Warn("Request limit exceeded", "count", requestCount)
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}

// Increment the request count in the mock DB
requestCount, err = client.Incr(ctx, key)
if err != nil {
slog.Error("Error incrementing request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// Set expiration for the key if it's the first request in the window
if requestCount == 1 {
err = client.Expire(ctx, key, time.Duration(window)*time.Second)
if err != nil {
slog.Error("Error setting expiry for request count", "error", err)
}
}

// Log the successful request and pass control to the next handler
slog.Info("Request processed", "count", requestCount)
next.ServeHTTP(w, r)
})
}
15 changes: 8 additions & 7 deletions internal/server/httpServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cim.rateLimiter(w, r, cim.mux)
}

func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit, window int) *HTTPServer {
func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit int64, window float64) *HTTPServer {
handlerMux := &HandlerMux{
mux: mux,
rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) {
Expand Down Expand Up @@ -97,26 +97,27 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) {
func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {
diceCmd, err := util.ParseHTTPRequest(r)
if err != nil {
http.Error(w, "Error parsing HTTP request", http.StatusBadRequest)
http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest)
return
}

resp, err := s.DiceClient.ExecuteCommand(diceCmd)
if err != nil {
http.Error(w, errorResponse(err.Error()), http.StatusBadRequest)
http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest)
return
}

if _, ok := resp.(string); !ok {
log.Println("Error marshaling response", "error", err)
respStr, ok := resp.(string)
if !ok {
log.Println("Error: response is not a string", "error", err)
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
return
}

httpResponse := HTTPResponse{Data: resp.(string)}
httpResponse := HTTPResponse{Data: respStr}
responseJSON, err := json.Marshal(httpResponse)
if err != nil {
log.Println("Error marshaling response", "error", err)
log.Println("Error marshaling response to JSON", "error", err)
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
return
}
Expand Down
Loading

0 comments on commit 12a0fef

Please sign in to comment.