Skip to content

Commit

Permalink
squashfs: add LRU cache for blocks #205
Browse files Browse the repository at this point in the history
  • Loading branch information
ncw committed Dec 31, 2023
1 parent 5bdd6ea commit ce98d27
Show file tree
Hide file tree
Showing 2 changed files with 374 additions and 0 deletions.
138 changes: 138 additions & 0 deletions filesystem/squashfs/lru.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package squashfs

import (
"sync"
)

// lru is a small least recently used cache of data blocks.
//
// Every cached block is indexed by its position in the file through the
// cache map and is also threaded onto a circular doubly linked list
// anchored at root. The block after root is the most recently used and
// the block before root is the least recently used.
type lru struct {
	mu        sync.Mutex          // taken by get and setMaxBlocks around cache and list updates
	cache     map[int64]*lruBlock // cached blocks keyed on their position in the file
	maxBlocks int                 // maximum number of blocks to keep in the cache
	root      lruBlock            // sentinel node anchoring the circular LRU list
}

// lruBlock is one entry in the lru cache: a data block together with
// the links that place it in the LRU list.
type lruBlock struct {
	mu   sync.Mutex // held while the block is being fetched
	data []byte     // the block contents - nil until a fetch has succeeded
	prev *lruBlock  // neighbour towards the head of the LRU list
	next *lruBlock  // neighbour towards the tail of the LRU list
	pos  int64      // position the block was read off disk
	size uint16     // compressed size on disk
}

// newLRU returns an empty LRU cache configured to hold a maximum of
// maxBlocks blocks.
func newLRU(maxBlocks int) *lru {
	l := new(lru)
	l.cache = make(map[int64]*lruBlock)
	l.maxBlocks = maxBlocks
	// The root node only anchors the list; it is given position -1,
	// presumably because no real block is ever read from there.
	l.root.pos = -1
	// An empty circular list is a root node that points at itself.
	l.root.next = &l.root
	l.root.prev = &l.root
	return l
}

// unlink detaches block from the LRU list by joining its two
// neighbours together, then clears the block's own links.
//
// The block must currently be linked into the list.
func (l *lru) unlink(block *lruBlock) {
	before, after := block.prev, block.next
	before.next = after
	after.prev = before
	block.prev, block.next = nil, nil
}

// pop removes the block at the tail of the list (the least recently
// used one) and returns it.
//
// It panics if the list is empty.
func (l *lru) pop() *lruBlock {
	tail := l.root.prev
	if tail != &l.root {
		l.unlink(tail)
		return tail
	}
	// The root pointing at itself means there is nothing to remove.
	panic("internal error: list empty")
}

// push links block in at the head of the list, directly after the
// root node, making it the most recently used block.
func (l *lru) push(block *lruBlock) {
	head := &l.root
	first := head.next
	block.prev, block.next = head, first
	first.prev = block
	head.next = block
}

// trim evicts blocks from the tail of the list until the cache holds
// no more than maxBlocks blocks.
func (l *lru) trim(maxBlocks int) {
	for {
		n := len(l.cache)
		if n == 0 || n <= maxBlocks {
			return
		}
		// Drop the least recently used block from both the list and
		// the index.
		victim := l.pop()
		delete(l.cache, victim.pos)
	}
}

// add puts block into the cache as the most recently used entry,
// evicting older blocks first to make room for it.
func (l *lru) add(block *lruBlock) {
	// Leave space for the one block about to be inserted.
	room := l.maxBlocks - 1
	l.trim(room)
	l.push(block)
	l.cache[block.pos] = block
}

// fetchFn is the callback lru.get uses to load a block that is not
// already held in the cache.
//
// It returns the block data, a size and any error. The size is kept
// alongside the data in the lruBlock, whose size field is described as
// the compressed size on disk.
//
// NOTE(review): the original comment said "data should be a subslice
// of buf", but no buf is visible in this file - confirm against callers.
type fetchFn func() (data []byte, size uint16, err error)

// get returns the block at pos, reading through the cache.
//
// If l is nil there is no cache and fetch() is simply called.
//
// On a cache hit the block is moved to the front of the LRU list. On a
// miss an empty placeholder block (data == nil) is added to the cache
// and fetch() is called to fill it in. The block's own lock is held
// for the duration of fetch(), so concurrent callers asking for the
// same block wait for the first fetch and then share its result.
//
// The cache-wide lock is never held while waiting for a block or while
// fetch() runs, so fetches of different blocks can proceed in parallel.
//
// If fetch() returns an error it is passed back to the caller and the
// placeholder stays in the cache with no data, so the next call for
// pos will try the fetch again. A fetch that returns nil data without
// an error is likewise not remembered.
func (l *lru) get(pos int64, fetch fetchFn) (data []byte, size uint16, err error) {
	if l == nil {
		return fetch()
	}

	l.mu.Lock()
	block, found := l.cache[pos]
	if found {
		// Cache hit: move the block to the start of the list.
		l.unlink(block)
		l.push(block)
	} else {
		// Cache miss: add an empty block with data == nil which is
		// filled in below.
		block = &lruBlock{
			pos: pos,
		}
		l.add(block)
	}
	// Release the cache lock BEFORE taking the block lock. If another
	// goroutine is in the middle of fetching this block, waiting for it
	// with l.mu held would stall every other user of the cache until
	// that fetch completed. Whichever goroutine gets the block lock
	// first does the fetch; the others find data already set.
	l.mu.Unlock()

	block.mu.Lock()
	defer block.mu.Unlock()

	if block.data != nil {
		return block.data, block.size, nil
	}

	// Fetch the block
	data, size, err = fetch()
	if err != nil {
		return nil, 0, err
	}
	block.data, block.size = data, size
	return data, size, nil
}

// setMaxBlocks changes the maximum number of blocks the cache may hold.
//
// If the cache currently holds more than maxBlocks blocks, the least
// recently used ones are evicted until it fits.
func (l *lru) setMaxBlocks(maxBlocks int) {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.maxBlocks = maxBlocks
	l.trim(maxBlocks)
}
236 changes: 236 additions & 0 deletions filesystem/squashfs/lru_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
package squashfs

import (
"errors"
"strings"
"testing"
)

// TestLRU exercises the LRU cache: the raw list primitives (push, pop,
// unlink), eviction through add, trim and setMaxBlocks, and finally the
// read-through behaviour of get.
//
// The subtests share the cache l and run in order; each one relies on
// the state the previous one left behind.
//
//nolint:gocyclo // we really do not care about the cyclomatic complexity of a test function. Maybe someday we will improve it.
func TestLRU(t *testing.T) {
	const maxBlocks = 10
	l := newLRU(maxBlocks)

	// assertEmpty checks whether the LRU list is empty, i.e. the root
	// node points at itself in both directions.
	assertEmpty := func(want bool) {
		t.Helper()
		got := l.root.prev == &l.root && l.root.next == &l.root
		if want != got {
			t.Errorf("Wanted empty %v but got %v", want, got)
		}
	}

	// assertClear checks whether block is detached from any list.
	assertClear := func(block *lruBlock, want bool) {
		t.Helper()
		got := block.next == nil && block.prev == nil
		if want != got {
			t.Errorf("Wanted block clear %v but block clear %v", want, got)
		}
	}

	assertNoError := func(err error) {
		t.Helper()
		if err != nil {
			t.Errorf("Expected no error but got: %v", err)
		}
	}

	// assertCacheBlocks checks the number of blocks in the cache map.
	assertCacheBlocks := func(want int) {
		t.Helper()
		if len(l.cache) != want {
			t.Errorf("Expected len(l.cache)=%d but got %d", want, len(l.cache))
		}
	}

	t.Run("Simple", func(t *testing.T) {
		assertEmpty(true)
		block := &lruBlock{
			pos: 1,
		}
		assertClear(block, true)
		l.push(block)
		assertClear(block, false)
		assertEmpty(false)
		block2 := l.pop()
		if block.pos != block2.pos {
			t.Errorf("Wanted block %d but got %d", block.pos, block2.pos)
		}
		assertClear(block, true)
		assertClear(block2, true)
		assertEmpty(true)
	})

	t.Run("Unlink", func(t *testing.T) {
		assertEmpty(true)
		block := &lruBlock{
			pos: 1,
		}
		assertClear(block, true)
		l.push(block)
		assertClear(block, false)
		assertEmpty(false)
		l.unlink(block)
		assertEmpty(true)
		assertClear(block, true)
	})

	// Check that we push blocks on and off in FIFO order
	t.Run("FIFO", func(t *testing.T) {
		assertEmpty(true)
		for i := int64(1); i <= 10; i++ {
			block := &lruBlock{
				pos: i,
			}
			l.push(block)
		}
		assertEmpty(false)
		for i := int64(1); i <= 10; i++ {
			block := l.pop()
			if block.pos != i {
				t.Errorf("Wanted block %d but got %d", i, block.pos)
			}
		}
		assertEmpty(true)
	})

	// Popping from an empty list must panic with a "list empty" message
	t.Run("Empty", func(t *testing.T) {
		defer func() {
			r, ok := recover().(string)
			if !ok || !strings.Contains(r, "list empty") {
				t.Errorf("Panic string doesn't contain list empty: %q", r)
			}
		}()
		assertEmpty(true)
		l.pop()
		t.Errorf("Expected exception to be thrown")
	})

	t.Run("Add", func(t *testing.T) {
		assertEmpty(true)
		for i := 1; i <= 2*maxBlocks; i++ {
			block := &lruBlock{
				pos: int64(i),
			}
			l.add(block)
			wantItems := i
			if i >= maxBlocks {
				wantItems = maxBlocks
			}
			gotItems := len(l.cache)
			if wantItems != gotItems {
				t.Errorf("Expected %d items but got %d", wantItems, gotItems)
			}
		}
		assertEmpty(false)
		// Check the blocks are correct in the cache
		for i := maxBlocks + 1; i <= 2*maxBlocks; i++ {
			block, found := l.cache[int64(i)]
			if !found {
				t.Errorf("Didn't find block at %d", i)
			} else if block.pos != int64(i) {
				t.Errorf("Expected block.pos=%d but got %d", i, block.pos)
			}
		}
		// Check the blocks are correct in the list, walking from the
		// tail (oldest) towards the head (newest)
		block := l.root.prev
		for i := maxBlocks + 1; i <= 2*maxBlocks; i++ {
			if block.pos != int64(i) {
				t.Errorf("Expected block.pos=%d but got %d", i, block.pos)
			}
			block = block.prev
		}

		t.Run("Trim", func(t *testing.T) {
			assertCacheBlocks(maxBlocks)
			l.trim(maxBlocks - 1)
			assertCacheBlocks(maxBlocks - 1)
			// Trimming again to the same size must be a no-op
			l.trim(maxBlocks - 1)
			assertCacheBlocks(maxBlocks - 1)

			t.Run("SetMaxBlocks", func(t *testing.T) {
				assertCacheBlocks(maxBlocks - 1)
				l.setMaxBlocks(maxBlocks - 2)
				assertCacheBlocks(maxBlocks - 2)
				if l.maxBlocks != maxBlocks-2 {
					t.Errorf("Expected maxBlocks %d but got %d", maxBlocks-2, l.maxBlocks)
				}
				// Raising the limit must not change what is cached
				l.setMaxBlocks(maxBlocks)
				assertCacheBlocks(maxBlocks - 2)
				if l.maxBlocks != maxBlocks {
					t.Errorf("Expected maxBlocks %d but got %d", maxBlocks, l.maxBlocks)
				}
			})
		})
	})

	// Check blocks are as expected in the cache and LRU list
	//
	// expectedPos is given in list order, most recently used first.
	checkCache := func(expectedPos ...int64) {
		t.Helper()
		// Check the blocks are correct in the cache
		for _, pos := range expectedPos {
			block, found := l.cache[pos]
			if !found {
				t.Errorf("Didn't find block at %d", pos)
			} else if block.pos != pos {
				t.Errorf("Expected block.pos=%d but got %d", pos, block.pos)
			}
		}
		// Check the blocks are correct in the list
		block := l.root.next
		for _, pos := range expectedPos {
			if block.pos != pos {
				t.Errorf("Expected block.pos=%d but got %d", pos, block.pos)
			}
			block = block.next
		}
	}

	// Start again with a fresh cache for the read-through tests
	l = newLRU(10)
	t.Run("Get", func(t *testing.T) {
		// Fill the cache
		for i := 1; i <= 2*maxBlocks; i++ {
			pos := int64(i)
			_, _, err := l.get(pos, func() (data []byte, size uint16, err error) {
				buf := []byte{byte(pos)}
				return buf, uint16(i), nil
			})
			assertNoError(err)
		}
		checkCache(20, 19, 18, 17, 16, 15, 14, 13, 12, 11)

		// Test cache HIT - the fetch function must not be called
		data, size, err := l.get(int64(14), func() (data []byte, size uint16, err error) {
			return nil, 0, errors.New("cached block not found")
		})
		assertNoError(err)
		if data[0] != 14 {
			t.Errorf("Expected magic %d but got %d", 14, data[0])
		}
		if size != 14 {
			t.Errorf("Expected size %d but got %d", 14, size)
		}
		checkCache(14, 20, 19, 18, 17, 16, 15, 13, 12, 11)

		// Test cache MISS
		data, size, err = l.get(int64(1), func() (data []byte, size uint16, err error) {
			buf := []byte{1}
			return buf, uint16(1), nil
		})
		assertNoError(err)
		if data[0] != byte(1) {
			t.Errorf("Expected magic %d but got %d", byte(1), data[0])
		}
		if size != uint16(1) {
			t.Errorf("Expected size %d but got %d", 1, size)
		}
		checkCache(1, 14, 20, 19, 18, 17, 16, 15, 13, 12)

		// Test cache fetch ERROR - the error is returned to the caller
		// and the (empty) block still takes its place in the list
		testErr := errors.New("test error")
		_, _, err = l.get(int64(2), func() (data []byte, size uint16, err error) {
			return nil, 0, testErr
		})
		if !errors.Is(err, testErr) {
			t.Errorf("Want error %v but got %v", testErr, err)
		}
		checkCache(2, 1, 14, 20, 19, 18, 17, 16, 15, 13)
	})
}

0 comments on commit ce98d27

Please sign in to comment.