From 25dd0592957e9d0a5ae75832ef95423e366310d6 Mon Sep 17 00:00:00 2001 From: Rueian Date: Wed, 11 Dec 2024 10:51:06 -0800 Subject: [PATCH] feat: CacheMarshal and CacheUnmarshalView for third party cache libraries --- message.go | 101 ++++++++++++++++++++++++++++++++++++++++++++++++ message_test.go | 46 ++++++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/message.go b/message.go index 832c504c..95b07ec0 100644 --- a/message.go +++ b/message.go @@ -1,6 +1,8 @@ package rueidis import ( + "bytes" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -529,6 +531,105 @@ type RedisMessage struct { ttl [7]byte } +func (m *RedisMessage) cachesize() int { + n := 9 // typ (1) + length (8) TODO can we use VarInt instead of fixed 8 bytes for length? + switch m.typ { + case typeInteger, typeNull, typeBool: + case typeArray, typeMap, typeSet: + for _, val := range m.values { + n += val.cachesize() + } + default: + n += len(m.string) + } + return n +} + +func (m *RedisMessage) serialize(o *bytes.Buffer) { + var buf [8]byte // TODO can we use VarInt instead of fixed 8 bytes for length? + o.WriteByte(m.typ) + switch m.typ { + case typeInteger, typeNull, typeBool: + binary.BigEndian.PutUint64(buf[:], uint64(m.integer)) + o.Write(buf[:]) + case typeArray, typeMap, typeSet: + binary.BigEndian.PutUint64(buf[:], uint64(len(m.values))) + o.Write(buf[:]) + for _, val := range m.values { + val.serialize(o) + } + default: + binary.BigEndian.PutUint64(buf[:], uint64(len(m.string))) + o.Write(buf[:]) + o.WriteString(m.string) + } + return +} + +var ErrCacheUnmarshal = errors.New("cache unmarshal error") + +func (m *RedisMessage) unmarshalView(c int64, buf []byte) (int64, error) { + var err error + if int64(len(buf)) < c+9 { + return 0, ErrCacheUnmarshal + } + m.typ = buf[c] + c += 1 + m.integer = int64(binary.BigEndian.Uint64(buf[c : c+8])) + c += 8 // TODO can we use VarInt instead of fixed 8 bytes for length? + switch m.typ { + case typeInteger, typeNull, typeBool: + case typeArray, typeMap, typeSet: + m.values = make([]RedisMessage, m.integer) + m.integer = 0 + for i := range m.values { + if c, err = m.values[i].unmarshalView(c, buf); err != nil { + break + } + } + default: + if int64(len(buf)) < c+m.integer { + return 0, ErrCacheUnmarshal + } + m.string = BinaryString(buf[c : c+m.integer]) + c += m.integer + m.integer = 0 + } + return c, err +} + +// CacheSize returns the buffer size needed by the CacheMarshal. +func (m *RedisMessage) CacheSize() int { + return m.cachesize() + 7 // 7 for ttl +} + +// CacheMarshal writes serialized RedisMessage to the provided buffer. +// If the provided buffer is nil, CacheMarshal will allocate one. +// Note that output format is not compatible with different client versions. +func (m *RedisMessage) CacheMarshal(buf []byte) []byte { + if buf == nil { + buf = make([]byte, 0, m.CacheSize()) + } + o := bytes.NewBuffer(buf) + o.Write(m.ttl[:7]) + m.serialize(o) + return o.Bytes() +} + +// CacheUnmarshalView construct the RedisMessage from the buffer produced by CacheMarshal. +// Note that the buffer can't be reused after CacheUnmarshalView since it uses unsafe.String on top of the buffer. +func (m *RedisMessage) CacheUnmarshalView(buf []byte) error { + if len(buf) < 7 { + return ErrCacheUnmarshal + } + copy(m.ttl[:7], buf[:7]) + if _, err := m.unmarshalView(7, buf); err != nil { + return err + } + m.attrs = cacheMark + return nil +} + // IsNil check if message is a redis nil response func (m *RedisMessage) IsNil() bool { return m.typ == typeNull diff --git a/message_test.go b/message_test.go index 0348aad6..33ae9baa 100644 --- a/message_test.go +++ b/message_test.go @@ -1996,4 +1996,50 @@ func TestRedisMessage(t *testing.T) { } } }) + + t.Run("CacheMarshal and CacheUnmarshalView", func(t *testing.T) { + m1 := RedisMessage{typ: '_'} + m2 := RedisMessage{typ: '+', string: "random"} + m3 := RedisMessage{typ: '#', integer: 1} + m4 := RedisMessage{typ: ':', integer: -1234} + m5 := RedisMessage{typ: ',', string: "-1.5"} + m6 := RedisMessage{typ: '%', values: nil} + m7 := RedisMessage{typ: '~', values: []RedisMessage{m1, m2, m3, m4, m5, m6}} + m8 := RedisMessage{typ: '*', values: []RedisMessage{m1, m2, m3, m4, m5, m6, m7}} + msgs := []RedisMessage{m1, m2, m3, m4, m5, m6, m7, m8} + now := time.Now() + for i := range msgs { + msgs[i].setExpireAt(now.Add(time.Second * time.Duration(i)).UnixMilli()) + } + for i, m1 := range msgs { + siz := m1.CacheSize() + bs1 := m1.CacheMarshal(nil) + if len(bs1) != siz { + t.Fatal("size not match") + } + bs2 := m1.CacheMarshal(bs1) + if !bytes.Equal(bs2[:siz], bs2[siz:]) { + t.Fatal("byte not match") + } + var m2 RedisMessage + if err := m2.CacheUnmarshalView(bs1); err != nil { + t.Fatal(err) + } + if m1.String() != m2.String() { + t.Fatal("content not match") + } + if !m2.IsCacheHit() { + t.Fatal("should be cache hit") + } + if m2.CachePXAT() != now.Add(time.Second*time.Duration(i)).UnixMilli() { + t.Fatal("should have the same ttl") + } + for l := 0; l < siz; l++ { + var m3 RedisMessage + if err := m3.CacheUnmarshalView(bs2[:l]); err != ErrCacheUnmarshal { + t.Fatal("should fail as expected") + } + } + } + }) }