From ce98d276dcc2cc0a34c0b0d464699a6866a0638b Mon Sep 17 00:00:00 2001
From: Nick Craig-Wood
Date: Thu, 28 Dec 2023 18:12:59 +0000
Subject: [PATCH] squashfs: add LRU cache for blocks #205

---
 filesystem/squashfs/lru.go      | 169 ++++++++++++++++++
 filesystem/squashfs/lru_test.go | 293 ++++++++++++++++++++++++++++++++
 2 files changed, 462 insertions(+)
 create mode 100644 filesystem/squashfs/lru.go
 create mode 100644 filesystem/squashfs/lru_test.go

diff --git a/filesystem/squashfs/lru.go b/filesystem/squashfs/lru.go
new file mode 100644
index 00000000..9185d44c
--- /dev/null
+++ b/filesystem/squashfs/lru.go
@@ -0,0 +1,169 @@
+package squashfs
+
+import (
+	"sync"
+)
+
+// A simple least recently used cache
+//
+// Blocks are indexed by position in cache and are also threaded onto a
+// circular doubly linked list through root: root.next is the most
+// recently used block and root.prev the least recently used one.
+type lru struct {
+	mu        sync.Mutex          // held by get and setMaxBlocks while using the fields below
+	cache     map[int64]*lruBlock // cache keyed on block position in file
+	maxBlocks int                 // max number of blocks in cache
+	root      lruBlock            // root block in LRU circular list
+}
+
+// A data block to store in the lru cache
+type lruBlock struct {
+	mu   sync.Mutex // lock while fetching
+	data []byte     // data block - nil while being fetched
+	prev *lruBlock  // prev block in LRU list
+	next *lruBlock  // next block in LRU list
+	pos  int64      // position it was read off disk
+	size uint16     // compressed size on disk
+}
+
+// Create a new LRU cache holding a maximum of maxBlocks blocks
+//
+// The list starts empty, which is represented by the root node being
+// linked to itself.
+func newLRU(maxBlocks int) *lru {
+	l := &lru{
+		cache:     make(map[int64]*lruBlock),
+		maxBlocks: maxBlocks,
+		root: lruBlock{
+			pos: -1,
+		},
+	}
+	l.root.prev = &l.root // circularly link the root node
+	l.root.next = &l.root
+	return l
+}
+
+// Unlink the block from the list
+//
+// The block's own links are cleared so it can be pushed again.
+func (l *lru) unlink(block *lruBlock) {
+	block.prev.next = block.next
+	block.next.prev = block.prev
+	block.prev = nil
+	block.next = nil
+}
+
+// Pop a block from the end of the list
+//
+// This is the least recently used block. It panics if the list is
+// empty.
+func (l *lru) pop() *lruBlock {
+	block := l.root.prev
+	if block == &l.root {
+		panic("internal error: list empty")
+	}
+	l.unlink(block)
+	return block
+}
+
+// Add a block to the start of the list
+func (l *lru) push(block *lruBlock) {
+	oldHead := l.root.next
+	l.root.next = block
+	block.prev = &l.root
+	block.next = oldHead
+	oldHead.prev = block
+}
+
+// Ensure there are no more than maxBlocks blocks in the cache
+//
+// Blocks are dropped from the least recently used end of the list.
+// This does no locking of its own.
+func (l *lru) trim(maxBlocks int) {
+	for len(l.cache) > maxBlocks && len(l.cache) > 0 {
+		// Remove a block from the cache
+		block := l.pop()
+		delete(l.cache, block.pos)
+	}
+}
+
+// Add block to the cache, pruning the cache as appropriate
+//
+// Room is made before the block is inserted, so the block just added
+// is always kept even if maxBlocks is zero or negative.
+func (l *lru) add(block *lruBlock) {
+	l.trim(l.maxBlocks - 1)
+	l.cache[block.pos] = block
+	l.push(block)
+}
+
+// Fetch data returning size used from input and error
+//
+// data should be a subslice of buf
+type fetchFn func() (data []byte, size uint16, err error)
+
+// Get the block at pos from the cache.
+//
+// If it isn't found in the cache then fetch() is called to get it.
+//
+// This does read through caching and takes care not to block parallel
+// calls to the fetch() function: the cache lock is never held while
+// fetching or while waiting for someone else's fetch of the same
+// block to finish.
+//
+// If fetch() returns an error the empty block stays in the cache and
+// the next get() for pos calls fetch() again.
+//
+// If l is nil then fetch() is called directly with no caching.
+func (l *lru) get(pos int64, fetch fetchFn) (data []byte, size uint16, err error) {
+	if l == nil {
+		return fetch()
+	}
+	l.mu.Lock()
+	block, found := l.cache[pos]
+	if found {
+		// Move the block to the start of the list
+		l.unlink(block)
+		l.push(block)
+		// Drop the cache lock before waiting for the block. If the
+		// block is still being fetched, holding it here would stall
+		// every other get() until that fetch finished.
+		l.mu.Unlock()
+		block.mu.Lock()
+	} else {
+		// Add an empty block with data == nil
+		block = &lruBlock{
+			pos: pos,
+		}
+		// Lock the block before it becomes visible so other callers
+		// wait for our fetch. Nobody else can hold this lock yet.
+		block.mu.Lock()
+		// Add it to the cache and the start of the list
+		l.add(block)
+		l.mu.Unlock()
+	}
+	defer block.mu.Unlock()
+
+	if block.data != nil {
+		return block.data, block.size, nil
+	}
+
+	// Fetch the block
+	data, size, err = fetch()
+	if err != nil {
+		return nil, 0, err
+	}
+	block.data = data
+	block.size = size
+	return data, size, nil
+}
+
+// Sets the number of blocks to be used in the cache
+//
+// It makes sure that there are no more than maxBlocks in the cache.
+func (l *lru) setMaxBlocks(maxBlocks int) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	l.maxBlocks = maxBlocks
+	l.trim(l.maxBlocks)
+}
diff --git a/filesystem/squashfs/lru_test.go b/filesystem/squashfs/lru_test.go
new file mode 100644
index 00000000..1b62e021
--- /dev/null
+++ b/filesystem/squashfs/lru_test.go
@@ -0,0 +1,293 @@
+package squashfs
+
+import (
+	"errors"
+	"strings"
+	"testing"
+	"time"
+)
+
+//nolint:gocyclo // we really do not care about the cyclomatic complexity of a test function. Maybe someday we will improve it.
+func TestLRU(t *testing.T) {
+	const maxBlocks = 10
+	l := newLRU(maxBlocks)
+
+	assertEmpty := func(want bool) {
+		t.Helper()
+		got := l.root.prev == &l.root && l.root.next == &l.root
+		if want != got {
+			t.Errorf("Wanted empty %v but got %v", want, got)
+		}
+	}
+
+	assertClear := func(block *lruBlock, want bool) {
+		t.Helper()
+		got := block.next == nil && block.prev == nil
+		if want != got {
+			t.Errorf("Wanted block clear %v but block clear %v", want, got)
+		}
+	}
+
+	assertNoError := func(err error) {
+		t.Helper()
+		if err != nil {
+			t.Errorf("Expected no error but got: %v", err)
+		}
+	}
+
+	assertCacheBlocks := func(want int) {
+		t.Helper()
+		if len(l.cache) != want {
+			t.Errorf("Expected len(l.cache)=%d but got %d", want, len(l.cache))
+		}
+	}
+
+	t.Run("Simple", func(t *testing.T) {
+		assertEmpty(true)
+		block := &lruBlock{
+			pos: 1,
+		}
+		assertClear(block, true)
+		l.push(block)
+		assertClear(block, false)
+		assertEmpty(false)
+		block2 := l.pop()
+		if block.pos != block2.pos {
+			t.Errorf("Wanted block %d but got %d", block.pos, block2.pos)
+		}
+		assertClear(block, true)
+		assertClear(block2, true)
+		assertEmpty(true)
+	})
+
+	t.Run("Unlink", func(t *testing.T) {
+		assertEmpty(true)
+		block := &lruBlock{
+			pos: 1,
+		}
+		assertClear(block, true)
+		l.push(block)
+		assertClear(block, false)
+		assertEmpty(false)
+		l.unlink(block)
+		assertEmpty(true)
+		assertClear(block, true)
+	})
+
+	// Check that we push blocks on and off in FIFO order
+	t.Run("FIFO", func(t *testing.T) {
+		assertEmpty(true)
+		for i := int64(1); i <= 10; i++ {
+			block := &lruBlock{
+				pos: i,
+			}
+			l.push(block)
+		}
+		assertEmpty(false)
+		for i := int64(1); i <= 10; i++ {
+			block := l.pop()
+			if block.pos != i {
+				t.Errorf("Wanted block %d but got %d", i, block.pos)
+			}
+		}
+		assertEmpty(true)
+	})
+
+	t.Run("Empty", func(t *testing.T) {
+		defer func() {
+			r, ok := recover().(string)
+			if !ok || !strings.Contains(r, "list empty") {
+				t.Errorf("Panic string doesn't contain list empty: %q", r)
+			}
+		}()
+		assertEmpty(true)
+		l.pop()
+		t.Errorf("Expected exception to be thrown")
+	})
+
+	t.Run("Add", func(t *testing.T) {
+		assertEmpty(true)
+		for i := 1; i <= 2*maxBlocks; i++ {
+			block := &lruBlock{
+				pos: int64(i),
+			}
+			l.add(block)
+			wantItems := i
+			if i >= maxBlocks {
+				wantItems = maxBlocks
+			}
+			gotItems := len(l.cache)
+			if wantItems != gotItems {
+				t.Errorf("Expected %d items but got %d", wantItems, gotItems)
+			}
+		}
+		assertEmpty(false)
+		// Check the blocks are correct in the cache
+		for i := maxBlocks + 1; i <= 2*maxBlocks; i++ {
+			block, found := l.cache[int64(i)]
+			if !found {
+				t.Errorf("Didn't find block at %d", i)
+			} else if block.pos != int64(i) {
+				t.Errorf("Expected block.pos=%d but got %d", i, block.pos)
+			}
+		}
+		// Check the blocks are correct in the list
+		block := l.root.prev
+		for i := maxBlocks + 1; i <= 2*maxBlocks; i++ {
+			if block.pos != int64(i) {
+				t.Errorf("Expected block.pos=%d but got %d", i, block.pos)
+			}
+			block = block.prev
+		}
+
+		t.Run("Trim", func(t *testing.T) {
+			assertCacheBlocks(maxBlocks)
+			l.trim(maxBlocks - 1)
+			assertCacheBlocks(maxBlocks - 1)
+			l.trim(maxBlocks - 1)
+			assertCacheBlocks(maxBlocks - 1)
+
+			t.Run("SetMaxBlocks", func(t *testing.T) {
+				assertCacheBlocks(maxBlocks - 1)
+				l.setMaxBlocks(maxBlocks - 2)
+				assertCacheBlocks(maxBlocks - 2)
+				if l.maxBlocks != maxBlocks-2 {
+					t.Errorf("Expected maxBlocks %d but got %d", maxBlocks-2, l.maxBlocks)
+				}
+				l.setMaxBlocks(maxBlocks)
+				assertCacheBlocks(maxBlocks - 2)
+				if l.maxBlocks != maxBlocks {
+					t.Errorf("Expected maxBlocks %d but got %d", maxBlocks, l.maxBlocks)
+				}
+			})
+		})
+	})
+
+	// Check blocks are as expected in the cache and LRU list
+	checkCache := func(expectedPos ...int64) {
+		t.Helper()
+		// Check the blocks are correct in the cache
+		for _, pos := range expectedPos {
+			block, found := l.cache[pos]
+			if !found {
+				t.Errorf("Didn't find block at %d", pos)
+			} else if block.pos != pos {
+				t.Errorf("Expected block.pos=%d but got %d", pos, block.pos)
+			}
+		}
+		// Check the blocks are correct in the list
+		block := l.root.next
+		for _, pos := range expectedPos {
+			if block.pos != pos {
+				t.Errorf("Expected block.pos=%d but got %d", pos, block.pos)
+			}
+			block = block.next
+		}
+	}
+
+	l = newLRU(10)
+	t.Run("Get", func(t *testing.T) {
+		// Fill the cache
+		for i := 1; i <= 2*maxBlocks; i++ {
+			pos := int64(i)
+			_, _, err := l.get(pos, func() (data []byte, size uint16, err error) {
+				buf := []byte{byte(pos)}
+				return buf, uint16(i), nil
+			})
+			assertNoError(err)
+		}
+		checkCache(20, 19, 18, 17, 16, 15, 14, 13, 12, 11)
+
+		// Test cache HIT
+		data, size, err := l.get(int64(14), func() (data []byte, size uint16, err error) {
+			return nil, 0, errors.New("cached block not found")
+		})
+		assertNoError(err)
+		if data[0] != 14 {
+			t.Errorf("Expected magic %d but got %d", 14, data[0])
+		}
+		if size != 14 {
+			t.Errorf("Expected size %d but got %d", 14, size)
+		}
+		checkCache(14, 20, 19, 18, 17, 16, 15, 13, 12, 11)
+
+		// Test cache MISS
+		data, size, err = l.get(int64(1), func() (data []byte, size uint16, err error) {
+			buf := []byte{1}
+			return buf, uint16(1), nil
+		})
+		assertNoError(err)
+		if data[0] != byte(1) {
+			t.Errorf("Expected magic %d but got %d", byte(1), data[0])
+		}
+		if size != uint16(1) {
+			t.Errorf("Expected size %d but got %d", 1, size)
+		}
+		checkCache(1, 14, 20, 19, 18, 17, 16, 15, 13, 12)
+
+		// Test cache fetch ERROR
+		testErr := errors.New("test error")
+		_, _, err = l.get(int64(2), func() (data []byte, size uint16, err error) {
+			return nil, 0, testErr
+		})
+		if err != testErr {
+			t.Errorf("Want error %q but got %q", testErr, err)
+		}
+		checkCache(2, 1, 14, 20, 19, 18, 17, 16, 15, 13)
+	})
+}
+
+// Check that waiting for a block which is being fetched doesn't stop
+// other blocks from being fetched in the meantime.
+func TestLRUGetParallel(t *testing.T) {
+	l := newLRU(10)
+	started := make(chan struct{})
+	release := make(chan struct{})
+
+	// Start a fetch of block 1 which stalls until release is closed
+	first := make(chan error, 1)
+	go func() {
+		_, _, err := l.get(1, func() (data []byte, size uint16, err error) {
+			close(started)
+			<-release
+			return []byte{1}, 1, nil
+		})
+		first <- err
+	}()
+	<-started
+
+	// Queue a second reader of block 1 behind the stalled fetch
+	second := make(chan error, 1)
+	go func() {
+		_, _, err := l.get(1, func() (data []byte, size uint16, err error) {
+			return nil, 0, errors.New("block 1 fetched twice")
+		})
+		second <- err
+	}()
+	// Give the second reader time to start waiting
+	time.Sleep(10 * time.Millisecond)
+
+	// A fetch of a different block must still complete
+	other := make(chan error, 1)
+	go func() {
+		_, _, err := l.get(2, func() (data []byte, size uint16, err error) {
+			return []byte{2}, 2, nil
+		})
+		other <- err
+	}()
+	select {
+	case err := <-other:
+		if err != nil {
+			t.Errorf("Expected no error but got: %v", err)
+		}
+	case <-time.After(5 * time.Second):
+		t.Errorf("get of block 2 was blocked by readers of block 1")
+	}
+
+	close(release)
+	for _, done := range []chan error{first, second} {
+		if err := <-done; err != nil {
+			t.Errorf("Expected no error but got: %v", err)
+		}
+	}
+}