Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: CacheMarshal and CacheUnmarshalView for third party cache libraries #693

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package rueidis

import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -529,6 +531,104 @@ type RedisMessage struct {
ttl [7]byte
}

func (m *RedisMessage) cachesize() int {
n := 9 // typ (1) + length (8) TODO: can we use VarInt instead of fixed 8 bytes for length?
switch m.typ {
case typeInteger, typeNull, typeBool:
case typeArray, typeMap, typeSet:
for _, val := range m.values {
n += val.cachesize()
}
default:
n += len(m.string)
}
return n
}

func (m *RedisMessage) serialize(o *bytes.Buffer) {
var buf [8]byte // TODO: can we use VarInt instead of fixed 8 bytes for length?
o.WriteByte(m.typ)
switch m.typ {
case typeInteger, typeNull, typeBool:
binary.BigEndian.PutUint64(buf[:], uint64(m.integer))
o.Write(buf[:])
case typeArray, typeMap, typeSet:
binary.BigEndian.PutUint64(buf[:], uint64(len(m.values)))
o.Write(buf[:])
for _, val := range m.values {
val.serialize(o)
}
default:
binary.BigEndian.PutUint64(buf[:], uint64(len(m.string)))
o.Write(buf[:])
o.WriteString(m.string)
}
}

var ErrCacheUnmarshal = errors.New("cache unmarshal error")

func (m *RedisMessage) unmarshalView(c int64, buf []byte) (int64, error) {
var err error
if int64(len(buf)) < c+9 {
return 0, ErrCacheUnmarshal
}
m.typ = buf[c]
c += 1
m.integer = int64(binary.BigEndian.Uint64(buf[c : c+8]))
c += 8 // TODO: can we use VarInt instead of fixed 8 bytes for length?
switch m.typ {
case typeInteger, typeNull, typeBool:
case typeArray, typeMap, typeSet:
m.values = make([]RedisMessage, m.integer)
m.integer = 0
for i := range m.values {
if c, err = m.values[i].unmarshalView(c, buf); err != nil {
break
}
}
default:
if int64(len(buf)) < c+m.integer {
return 0, ErrCacheUnmarshal
}
m.string = BinaryString(buf[c : c+m.integer])
c += m.integer
m.integer = 0
}
return c, err
}

// CacheSize returns the buffer size needed by the CacheMarshal.
func (m *RedisMessage) CacheSize() int {
return m.cachesize() + 7 // 7 for ttl
}

// CacheMarshal writes serialized RedisMessage to the provided buffer.
// If the provided buffer is nil, CacheMarshal will allocate one.
// Note that output format is not compatible with different client versions.
func (m *RedisMessage) CacheMarshal(buf []byte) []byte {
if buf == nil {
buf = make([]byte, 0, m.CacheSize())
}
o := bytes.NewBuffer(buf)
o.Write(m.ttl[:7])
m.serialize(o)
return o.Bytes()
}

// CacheUnmarshalView construct the RedisMessage from the buffer produced by CacheMarshal.
// Note that the buffer can't be reused after CacheUnmarshalView since it uses unsafe.String on top of the buffer.
func (m *RedisMessage) CacheUnmarshalView(buf []byte) error {
if len(buf) < 7 {
return ErrCacheUnmarshal
}
copy(m.ttl[:7], buf[:7])
if _, err := m.unmarshalView(7, buf); err != nil {
return err
}
m.attrs = cacheMark
return nil
}

// IsNil check if message is a redis nil response
func (m *RedisMessage) IsNil() bool {
return m.typ == typeNull
Expand Down
46 changes: 46 additions & 0 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1996,4 +1996,50 @@ func TestRedisMessage(t *testing.T) {
}
}
})

t.Run("CacheMarshal and CacheUnmarshalView", func(t *testing.T) {
m1 := RedisMessage{typ: '_'}
m2 := RedisMessage{typ: '+', string: "random"}
m3 := RedisMessage{typ: '#', integer: 1}
m4 := RedisMessage{typ: ':', integer: -1234}
m5 := RedisMessage{typ: ',', string: "-1.5"}
m6 := RedisMessage{typ: '%', values: nil}
m7 := RedisMessage{typ: '~', values: []RedisMessage{m1, m2, m3, m4, m5, m6}}
m8 := RedisMessage{typ: '*', values: []RedisMessage{m1, m2, m3, m4, m5, m6, m7}}
msgs := []RedisMessage{m1, m2, m3, m4, m5, m6, m7, m8}
now := time.Now()
for i := range msgs {
msgs[i].setExpireAt(now.Add(time.Second * time.Duration(i)).UnixMilli())
}
for i, m1 := range msgs {
siz := m1.CacheSize()
bs1 := m1.CacheMarshal(nil)
if len(bs1) != siz {
t.Fatal("size not match")
}
bs2 := m1.CacheMarshal(bs1)
if !bytes.Equal(bs2[:siz], bs2[siz:]) {
t.Fatal("byte not match")
}
var m2 RedisMessage
if err := m2.CacheUnmarshalView(bs1); err != nil {
t.Fatal(err)
}
if m1.String() != m2.String() {
t.Fatal("content not match")
}
if !m2.IsCacheHit() {
t.Fatal("should be cache hit")
}
if m2.CachePXAT() != now.Add(time.Second*time.Duration(i)).UnixMilli() {
t.Fatal("should have the same ttl")
}
for l := 0; l < siz; l++ {
var m3 RedisMessage
if err := m3.CacheUnmarshalView(bs2[:l]); err != ErrCacheUnmarshal {
t.Fatal("should fail as expected")
}
}
}
})
}
Loading