From 71c932d29c5b81bc7f91afb4b0631851f4dc72ae Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Thu, 25 Jul 2024 18:25:02 +0200 Subject: [PATCH] Improve performance by using MULTI and MGET commands (#13) * Run ICNRBY + EXPIRE in a single atomic transaction * Use MGET to fetch both counters at once --- httprateredis.go | 60 ++++++++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/httprateredis.go b/httprateredis.go index f253cd3..7cf3657 100644 --- a/httprateredis.go +++ b/httprateredis.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "time" "github.com/go-chi/httprate" @@ -102,20 +103,20 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i hkey := c.limitCounterKey(key, currentWindow) - cmd := conn.Do(ctx, "INCRBY", hkey, amount) - if cmd == nil { - return fmt.Errorf("httprateredis: redis incr failed") - } - if err := cmd.Err(); err != nil { - return err + pipe := conn.TxPipeline() + incrCmd := pipe.IncrBy(ctx, hkey, int64(amount)) + expireCmd := pipe.Expire(ctx, hkey, c.windowLength*3) + _, err := pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("httprateredis: redis transaction failed: %w", err) } - cmd = conn.Do(ctx, "EXPIRE", hkey, c.windowLength.Seconds()*3) - if cmd == nil { - return fmt.Errorf("httprateredis: redis expire failed") + if err := incrCmd.Err(); err != nil { + return fmt.Errorf("httprateredis: redis incr failed: %w", err) } - if err := cmd.Err(); err != nil { - return err + + if err := expireCmd.Err(); err != nil { + return fmt.Errorf("httprateredis: redis expire failed: %w", err) } return nil @@ -125,32 +126,27 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time) ctx := context.Background() conn := c.client - cmd := conn.Do(ctx, "GET", c.limitCounterKey(key, currentWindow)) - if cmd == nil { - return 0, 0, fmt.Errorf("httprateredis: redis get curr failed") - } - if err := cmd.Err(); err != nil && err != redis.Nil { - return 0, 0, fmt.Errorf("httprateredis: redis get curr failed: %w", err) - } + currKey := c.limitCounterKey(key, currentWindow) + prevKey := c.limitCounterKey(key, previousWindow) - curr, err := cmd.Int() - if err != nil && err != redis.Nil { - return 0, 0, fmt.Errorf("httprateredis: redis int curr value: %w", err) + values, err := conn.MGet(ctx, currKey, prevKey).Result() + if err != nil { + return 0, 0, fmt.Errorf("httprateredis: redis mget failed: %w", err) + } else if len(values) != 2 { + return 0, 0, fmt.Errorf("httprateredis: redis mget returned wrong number of keys: %v, expected 2", len(values)) } - cmd = conn.Do(ctx, "GET", c.limitCounterKey(key, previousWindow)) - if cmd == nil { - return 0, 0, fmt.Errorf("httprateredis: redis get prev failed") - } + var curr, prev int - if err := cmd.Err(); err != nil && err != redis.Nil { - return 0, 0, fmt.Errorf("httprateredis: redis get prev failed: %w", err) + // MGET always returns slice with nil or "string" values, even if the values + // were created with the INCR command. Ignore error if we can't parse the number. + if values[0] != nil { + v, _ := values[0].(string) + curr, _ = strconv.Atoi(v) } - - var prev int - prev, err = cmd.Int() - if err != nil && err != redis.Nil { - return 0, 0, fmt.Errorf("httprateredis: redis int prev value: %w", err) + if values[1] != nil { + v, _ := values[1].(string) + prev, _ = strconv.Atoi(v) } return curr, prev, nil