diff --git a/config/config.go b/config/config.go index b0f0b9e..433c5f9 100644 --- a/config/config.go +++ b/config/config.go @@ -30,7 +30,7 @@ func LoadConfig() *Config { ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:3000"}), // Default allowed origins } } diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 63bc5d9..3e75120 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -5,30 +5,37 @@ import ( "server/config" ) -func enableCors(w http.ResponseWriter, r *http.Request) { +// Updated enableCors function to return a boolean indicating if OPTIONS was handled +func handleCors(w http.ResponseWriter, r *http.Request) bool { configValue := config.LoadConfig() allAllowedOrigins := configValue.AllowedOrigins origin := r.Header.Get("Origin") allowed := false + for _, allowedOrigin := range allAllowedOrigins { if origin == allowedOrigin || allowedOrigin == "*" || origin == "" { allowed = true break } } + if !allowed { http.Error(w, "CORS: origin not allowed", http.StatusForbidden) - return + return true } w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length") + // If the request is an OPTIONS request, handle it and stop further processing if r.Method == http.MethodOptions { w.Header().Set("Access-Control-Max-Age", "86400") w.WriteHeader(http.StatusOK) - return + return true } + + // Continue processing other requests w.Header().Set("Content-Type", "application/json") + return false } diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c852a26..c2b43e3 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -18,7 +18,10 @@ import ( // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -79,7 +82,10 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()