Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#21: Refactored repo for consistency #24

Merged
merged 3 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (

// Config holds the application configuration
type Config struct {
DiceAddr string
ServerPort string
RequestLimit int64 // Field for the request limit
RequestWindow float64 // Field for the time window in float64
AllowedOrigins []string // Field for the allowed origins
DiceDBAddr string
ServerPort string
RequestLimitPerMin int64 // Field for the request limit
RequestWindowSec float64 // Field for the time window in float64
AllowedOrigins []string // Field for the allowed origins
}

// LoadConfig loads the application configuration from environment variables or defaults
Expand All @@ -26,11 +26,11 @@ func LoadConfig() *Config {
}

return &Config{
DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit
RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins
DiceDBAddr: getEnv("DICEDB_ADDR", "localhost:7379"), // Default DiceDB address
ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port
RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit
RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64
AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins
}
}

Expand Down
14 changes: 7 additions & 7 deletions internal/db/dicedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ import (
"log/slog"
"os"
"server/config"
"server/internal/cmds"
"server/util/cmds"
"time"

dice "github.com/dicedb/go-dice"
dicedb "github.com/dicedb/go-dice"
)

const (
RespOK = "OK"
)

type DiceDB struct {
Client *dice.Client
Client *dicedb.Client
Ctx context.Context
}

Expand All @@ -36,13 +36,13 @@ func (db *DiceDB) CloseDiceDB() {
}

func InitDiceClient(configValue *config.Config) (*DiceDB, error) {
diceClient := dice.NewClient(&dice.Options{
Addr: configValue.DiceAddr,
diceClient := dicedb.NewClient(&dicedb.Options{
Addr: configValue.DiceDBAddr,
DialTimeout: 10 * time.Second,
MaxRetries: 10,
})

// Ping the dice client to verify the connection
// Ping the dicedb client to verify the connection
err := diceClient.Ping(context.Background()).Err()
if err != nil {
return nil, err
Expand All @@ -64,7 +64,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err

val, err := db.getKey(command.Args[0])
switch {
case errors.Is(err, dice.Nil):
case errors.Is(err, dicedb.Nil):
return nil, errors.New("key does not exist")
case err != nil:
return nil, fmt.Errorf("get failed %v", err)
Expand Down
17 changes: 4 additions & 13 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@ import (
"strings"
"time"

dice "github.com/dicedb/go-dice"
dicedb "github.com/dicedb/go-dice"
)

// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Enable CORS for requests
enableCors(w, r)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Skip rate limiting for non-command endpoints
if !strings.Contains(r.URL.Path, "/cli/") {
if !strings.Contains(r.URL.Path, "/shell/exec/") {
next.ServeHTTP(w, r)
return
}
Expand All @@ -36,7 +34,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float

// Fetch the current request count
val, err := client.Client.Get(ctx, key).Result()
if err != nil && !errors.Is(err, dice.Nil) {
if err != nil && !errors.Is(err, dicedb.Nil) {
slog.Error("Error fetching request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -74,25 +72,19 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float
}
}

// Log the successful request increment
slog.Info("Request processed", "count", requestCount+1)

// Call the next handler
next.ServeHTTP(w, r)
})
}

func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Enable CORS for requests
enableCors(w, r)

// Set a request context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Only apply rate limiting for specific paths (e.g., "/cli/")
if !strings.Contains(r.URL.Path, "/cli/") {
if !strings.Contains(r.URL.Path, "/shell/exec/") {
next.ServeHTTP(w, r)
return
}
Expand Down Expand Up @@ -144,7 +136,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
}
}

// Log the successful request and pass control to the next handler
slog.Info("Request processed", "count", requestCount)
next.ServeHTTP(w, r)
})
Expand Down
2 changes: 0 additions & 2 deletions internal/middleware/trailingslash.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
func TrailingSlashMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") {
// remove slash
newPath := strings.TrimSuffix(r.URL.Path, "/")
// if query params exist append them
newURL := newPath
if r.URL.RawQuery != "" {
newURL += "?" + r.URL.RawQuery
Expand Down
39 changes: 22 additions & 17 deletions internal/server/httpServer.go → internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
"net/http"
"server/internal/middleware"
"strings"
"sync"
"time"

"server/internal/db"
util "server/pkg/util"
util "server/util"
)

type HTTPServer struct {
Expand All @@ -37,7 +36,13 @@ type HTTPErrorResponse struct {
}

func errorResponse(response string) string {
return fmt.Sprintf("{\"error\": %q}", response)
errorMessage := map[string]string{"error": response}
jsonResponse, err := json.Marshal(errorMessage)
if err != nil {
slog.Error("Error marshaling response: %v", slog.Any("err", err))
return `{"error": "internal server error"}`
}
return string(jsonResponse)
}

func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -73,64 +78,64 @@ func (s *HTTPServer) Run(ctx context.Context) error {
wg.Add(1)
go func() {
defer wg.Done()
log.Printf("Starting server at %s\n", s.httpServer.Addr)
slog.Info("starting server at", slog.String("addr", s.httpServer.Addr))
if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("HTTP server error: %v", err)
slog.Error("http server error: %v", slog.Any("err", err))
}
}()

<-ctx.Done()
log.Println("Shutting down server...")
slog.Info("shutting down server...")
return s.Shutdown()
}

func (s *HTTPServer) Shutdown() error {
if err := s.DiceClient.Client.Close(); err != nil {
log.Printf("Failed to close dice client: %v", err)
slog.Error("failed to close dicedb client: %v", slog.Any("err", err))
}

return s.httpServer.Shutdown(context.Background())
}

func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) {
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Server is running"})
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"})
}

func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) {
diceCmd, err := util.ParseHTTPRequest(r)
if err != nil {
http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest)
http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest)
return
}

resp, err := s.DiceClient.ExecuteCommand(diceCmd)
if err != nil {
http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest)
http.Error(w, errorResponse("error executing command"), http.StatusBadRequest)
return
}

respStr, ok := resp.(string)
if !ok {
log.Println("Error: response is not a string", "error", err)
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
slog.Error("error: response is not a string", "error", slog.Any("err", err))
http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError)
return
}

httpResponse := HTTPResponse{Data: respStr}
responseJSON, err := json.Marshal(httpResponse)
if err != nil {
log.Println("Error marshaling response to JSON", "error", err)
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
slog.Error("error marshaling response to json", "error", slog.Any("err", err))
http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError)
return
}

_, err = w.Write(responseJSON)
if err != nil {
http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError)
http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError)
return
}
}

func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) {
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Search results"})
util.JSONResponse(w, http.StatusOK, map[string]string{"message": "search results"})
}
10 changes: 5 additions & 5 deletions internal/tests/integration/ratelimiter_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ import (
"net/http"
"net/http/httptest"
config "server/config"
util "server/pkg/util"
util "server/util"
"testing"

"github.com/stretchr/testify/require"
)

func TestRateLimiterWithinLimit(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimit
window := configValue.RequestWindow
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

w, r, rateLimiter := util.SetupRateLimiter(limit, window)

Expand All @@ -25,8 +25,8 @@ func TestRateLimiterWithinLimit(t *testing.T) {

func TestRateLimiterExceedsLimit(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimit
window := configValue.RequestWindow
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

w, r, rateLimiter := util.SetupRateLimiter(limit, window)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ func TestTrailingSlashMiddleware(t *testing.T) {
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tt.requestURL, nil)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", test.requestURL, nil)
w := httptest.NewRecorder()

handler.ServeHTTP(w, req)

if w.Code != tt.expectedCode {
t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code)
if w.Code != test.expectedCode {
t.Errorf("expected status %d, got %d", test.expectedCode, w.Code)
}

if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl {
t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location"))
if test.expectedUrl != "" && w.Header().Get("Location") != test.expectedUrl {
t.Errorf("expected location %s, got %s", test.expectedUrl, w.Header().Get("Location"))
}
})
}
Expand Down
10 changes: 5 additions & 5 deletions internal/tests/stress/ratelimiter_stress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"net/http"
"net/http/httptest"
"server/config"
util "server/pkg/util"
util "server/util"
"sync"
"testing"
"time"
Expand All @@ -14,13 +14,13 @@ import (

func TestRateLimiterUnderStress(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimit
window := configValue.RequestWindow
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

_, r, rateLimiter := util.SetupRateLimiter(limit, window)

var wg sync.WaitGroup
var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit
var numRequests int64 = limit
successCount := int64(0)
failCount := int64(0)
var mu sync.Mutex
Expand All @@ -43,5 +43,5 @@ func TestRateLimiterUnderStress(t *testing.T) {
}()
}
wg.Wait()
require.Equal(t, limit, successCount, "Should succeed for exactly limit requests")
require.Equal(t, limit, successCount, "should succeed for exactly limit requests")
}
12 changes: 6 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@ package main

import (
"context"
"log"
"log/slog"
"net/http"
"server/config"
"server/internal/db"
"server/internal/server" // Import the new package for HTTPServer
"server/internal/server"
)

func main() {
configValue := config.LoadConfig()
diceClient, err := db.InitDiceClient(configValue)
if err != nil {
log.Fatalf("Failed to initialize dice client: %v", err)
slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err))
}

// Create mux and register routes
mux := http.NewServeMux()
httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimit, configValue.RequestWindow)
httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimitPerMin, configValue.RequestWindowSec)
mux.HandleFunc("/health", httpServer.HealthCheck)
mux.HandleFunc("/cli/{cmd}", httpServer.CliHandler)
mux.HandleFunc("/shell/exec/{cmd}", httpServer.CliHandler)
mux.HandleFunc("/search", httpServer.SearchHandler)

// Graceful shutdown context
Expand All @@ -29,6 +29,6 @@ func main() {

// Run the HTTP Server
if err := httpServer.Run(ctx); err != nil {
log.Printf("Server failed: %v\n", err)
slog.Error("server failed: %v\n", slog.Any("err", err))
}
}
Loading