-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
squashfs: add LRU cache for blocks #205
- Loading branch information
Showing
2 changed files
with
374 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
package squashfs | ||
|
||
import ( | ||
"sync" | ||
) | ||
|
||
// A simple least recently used cache | ||
type lru struct { | ||
mu sync.Mutex | ||
cache map[int64]*lruBlock // cache keyed on block position in file | ||
maxBlocks int // max number of blocks in cache | ||
root lruBlock // root block in LRU circular list | ||
} | ||
|
||
// A data block to store in the lru cache | ||
type lruBlock struct { | ||
mu sync.Mutex // lock while fetching | ||
data []byte // data block - nil while being fetched | ||
prev *lruBlock // prev block in LRU list | ||
next *lruBlock // next block in LRU list | ||
pos int64 // position it was read off disk | ||
size uint16 // compressed size on disk | ||
} | ||
|
||
// Create a new LRU cache of a maximum of maxBlocks blocks of size | ||
func newLRU(maxBlocks int) *lru { | ||
l := &lru{ | ||
cache: make(map[int64]*lruBlock), | ||
maxBlocks: maxBlocks, | ||
root: lruBlock{ | ||
pos: -1, | ||
}, | ||
} | ||
l.root.prev = &l.root // circularly link the root node | ||
l.root.next = &l.root | ||
return l | ||
} | ||
|
||
// Unlink the block from the list | ||
func (l *lru) unlink(block *lruBlock) { | ||
block.prev.next = block.next | ||
block.next.prev = block.prev | ||
block.prev = nil | ||
block.next = nil | ||
} | ||
|
||
// Pop a block from the end of the list | ||
func (l *lru) pop() *lruBlock { | ||
block := l.root.prev | ||
if block == &l.root { | ||
panic("internal error: list empty") | ||
} | ||
l.unlink(block) | ||
return block | ||
} | ||
|
||
// Add a block to the start of the list | ||
func (l *lru) push(block *lruBlock) { | ||
oldHead := l.root.next | ||
l.root.next = block | ||
block.prev = &l.root | ||
block.next = oldHead | ||
oldHead.prev = block | ||
} | ||
|
||
// ensure there are no more than n blocks in the cache | ||
func (l *lru) trim(maxBlocks int) { | ||
for len(l.cache) > maxBlocks && len(l.cache) > 0 { | ||
// Remove a block from the cache | ||
block := l.pop() | ||
delete(l.cache, block.pos) | ||
} | ||
} | ||
|
||
// add block to the cache, pruning the cache as appropriate | ||
func (l *lru) add(block *lruBlock) { | ||
l.trim(l.maxBlocks - 1) | ||
l.cache[block.pos] = block | ||
l.push(block) | ||
} | ||
|
||
// Fetch data returning size used from input and error | ||
// | ||
// data should be a subslice of buf | ||
type fetchFn func() (data []byte, size uint16, err error) | ||
|
||
// Get the block at pos from the cache. | ||
// | ||
// If it isn't found in the cache then fetch() is called to get it. | ||
// | ||
// This does read through caching and takes care not to block parallel | ||
// calls to the fetch() function. | ||
func (l *lru) get(pos int64, fetch fetchFn) (data []byte, size uint16, err error) { | ||
if l == nil { | ||
return fetch() | ||
} | ||
l.mu.Lock() | ||
block, found := l.cache[pos] | ||
if !found { | ||
// Add an empty block with data == nil | ||
block = &lruBlock{ | ||
pos: pos, | ||
} | ||
// Add it to the cache and the tail of the list | ||
l.add(block) | ||
} else { | ||
// Remove the block from the list | ||
l.unlink(block) | ||
// Add it back to the start | ||
l.push(block) | ||
} | ||
block.mu.Lock() // transfer the lock to the block | ||
l.mu.Unlock() | ||
defer block.mu.Unlock() | ||
|
||
if block.data != nil { | ||
return block.data, block.size, nil | ||
} | ||
|
||
// Fetch the block | ||
data, size, err = fetch() | ||
if err != nil { | ||
return nil, 0, err | ||
} | ||
block.data = data | ||
block.size = size | ||
return data, size, nil | ||
} | ||
|
||
// Sets the number of blocks to be used in the cache | ||
// | ||
// It makes sure that there are no more than maxBlocks in the cache. | ||
func (l *lru) setMaxBlocks(maxBlocks int) { | ||
l.mu.Lock() | ||
defer l.mu.Unlock() | ||
l.maxBlocks = maxBlocks | ||
l.trim(l.maxBlocks) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
package squashfs | ||
|
||
import ( | ||
"errors" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
//nolint:gocyclo // we really do not care about the cyclomatic complexity of a test function. Maybe someday we will improve it. | ||
func TestLRU(t *testing.T) { | ||
const maxBlocks = 10 | ||
l := newLRU(maxBlocks) | ||
|
||
assertEmpty := func(want bool) { | ||
t.Helper() | ||
got := l.root.prev == &l.root && l.root.next == &l.root | ||
if want != got { | ||
t.Errorf("Wanted empty %v but got %v", want, got) | ||
} | ||
} | ||
|
||
assertClear := func(block *lruBlock, want bool) { | ||
t.Helper() | ||
got := block.next == nil && block.prev == nil | ||
if want != got { | ||
t.Errorf("Wanted block clear %v but block clear %v", want, got) | ||
} | ||
} | ||
|
||
assertNoError := func(err error) { | ||
t.Helper() | ||
if err != nil { | ||
t.Errorf("Expected no error but got: %v", err) | ||
} | ||
} | ||
|
||
assertCacheBlocks := func(want int) { | ||
if len(l.cache) != want { | ||
t.Errorf("Expected len(l.cache)=%d but got %d", want, len(l.cache)) | ||
} | ||
} | ||
|
||
t.Run("Simple", func(t *testing.T) { | ||
assertEmpty(true) | ||
block := &lruBlock{ | ||
pos: 1, | ||
} | ||
assertClear(block, true) | ||
l.push(block) | ||
assertClear(block, false) | ||
assertEmpty(false) | ||
block2 := l.pop() | ||
if block.pos != block2.pos { | ||
t.Errorf("Wanted block %d but got %d", block.pos, block2.pos) | ||
} | ||
assertClear(block, true) | ||
assertClear(block2, true) | ||
assertEmpty(true) | ||
}) | ||
|
||
t.Run("Unlink", func(t *testing.T) { | ||
assertEmpty(true) | ||
block := &lruBlock{ | ||
pos: 1, | ||
} | ||
assertClear(block, true) | ||
l.push(block) | ||
assertClear(block, false) | ||
assertEmpty(false) | ||
l.unlink(block) | ||
assertEmpty(true) | ||
assertClear(block, true) | ||
}) | ||
|
||
// Check that we push blocks on and off in FIFO order | ||
t.Run("FIFO", func(t *testing.T) { | ||
assertEmpty(true) | ||
for i := int64(1); i <= 10; i++ { | ||
block := &lruBlock{ | ||
pos: i, | ||
} | ||
l.push(block) | ||
} | ||
assertEmpty(false) | ||
for i := int64(1); i <= 10; i++ { | ||
block := l.pop() | ||
if block.pos != i { | ||
t.Errorf("Wanted block %d but got %d", i, block.pos) | ||
} | ||
} | ||
assertEmpty(true) | ||
}) | ||
|
||
t.Run("Empty", func(t *testing.T) { | ||
defer func() { | ||
r, ok := recover().(string) | ||
if !ok || !strings.Contains(r, "list empty") { | ||
t.Errorf("Panic string doesn't contain list empty: %q", r) | ||
} | ||
}() | ||
assertEmpty(true) | ||
l.pop() | ||
t.Errorf("Expected exception to be thrown") | ||
}) | ||
|
||
t.Run("Add", func(t *testing.T) { | ||
assertEmpty(true) | ||
for i := 1; i <= 2*maxBlocks; i++ { | ||
block := &lruBlock{ | ||
pos: int64(i), | ||
} | ||
l.add(block) | ||
wantItems := i | ||
if i >= maxBlocks { | ||
wantItems = maxBlocks | ||
} | ||
gotItems := len(l.cache) | ||
if wantItems != gotItems { | ||
t.Errorf("Expected %d items but got %d", wantItems, gotItems) | ||
} | ||
} | ||
assertEmpty(false) | ||
// Check the blocks are correct in the cache | ||
for i := maxBlocks + 1; i <= 2*maxBlocks; i++ { | ||
block, found := l.cache[int64(i)] | ||
if !found { | ||
t.Errorf("Didn't find block at %d", i) | ||
} else if block.pos != int64(i) { | ||
t.Errorf("Expected block.pos=%d but got %d", i, block.pos) | ||
} | ||
} | ||
// Check the blocks are correct in the list | ||
block := l.root.prev | ||
for i := maxBlocks + 1; i <= 2*maxBlocks; i++ { | ||
if block.pos != int64(i) { | ||
t.Errorf("Expected block.pos=%d but got %d", i, block.pos) | ||
} | ||
block = block.prev | ||
} | ||
|
||
t.Run("Trim", func(t *testing.T) { | ||
assertCacheBlocks(maxBlocks) | ||
l.trim(maxBlocks - 1) | ||
assertCacheBlocks(maxBlocks - 1) | ||
l.trim(maxBlocks - 1) | ||
assertCacheBlocks(maxBlocks - 1) | ||
|
||
t.Run("SetMaxBlocks", func(t *testing.T) { | ||
assertCacheBlocks(maxBlocks - 1) | ||
l.setMaxBlocks(maxBlocks - 2) | ||
assertCacheBlocks(maxBlocks - 2) | ||
if l.maxBlocks != maxBlocks-2 { | ||
t.Errorf("Expected maxBlocks %d but got %d", maxBlocks-2, l.maxBlocks) | ||
} | ||
l.setMaxBlocks(maxBlocks) | ||
assertCacheBlocks(maxBlocks - 2) | ||
if l.maxBlocks != maxBlocks { | ||
t.Errorf("Expected maxBlocks %d but got %d", maxBlocks, l.maxBlocks) | ||
} | ||
}) | ||
}) | ||
}) | ||
|
||
// Check blocks are as expected in the cache and LRU list | ||
checkCache := func(expectedPos ...int64) { | ||
t.Helper() | ||
// Check the blocks are correct in the cache | ||
for _, pos := range expectedPos { | ||
block, found := l.cache[pos] | ||
if !found { | ||
t.Errorf("Didn't find block at %d", pos) | ||
} else if block.pos != pos { | ||
t.Errorf("Expected block.pos=%d but got %d", pos, block.pos) | ||
} | ||
} | ||
// Check the blocks are correct in the list | ||
block := l.root.next | ||
for _, pos := range expectedPos { | ||
if block.pos != pos { | ||
t.Errorf("Expected block.pos=%d but got %d", pos, block.pos) | ||
} | ||
block = block.next | ||
} | ||
} | ||
|
||
l = newLRU(10) | ||
t.Run("Get", func(t *testing.T) { | ||
// Fill the cache | ||
for i := 1; i <= 2*maxBlocks; i++ { | ||
pos := int64(i) | ||
_, _, err := l.get(pos, func() (data []byte, size uint16, err error) { | ||
buf := []byte{byte(pos)} | ||
return buf, uint16(i), nil | ||
}) | ||
assertNoError(err) | ||
} | ||
checkCache(20, 19, 18, 17, 16, 15, 14, 13, 12, 11) | ||
|
||
// Test cache HIT | ||
data, size, err := l.get(int64(14), func() (data []byte, size uint16, err error) { | ||
return nil, 0, errors.New("cached block not found") | ||
}) | ||
assertNoError(err) | ||
if data[0] != 14 { | ||
t.Errorf("Expected magic %d but got %d", 14, data[0]) | ||
} | ||
if size != 14 { | ||
t.Errorf("Expected size %d but got %d", 14, size) | ||
} | ||
checkCache(14, 20, 19, 18, 17, 16, 15, 13, 12, 11) | ||
|
||
// Test cache MISS | ||
data, size, err = l.get(int64(1), func() (data []byte, size uint16, err error) { | ||
buf := []byte{1} | ||
return buf, uint16(1), nil | ||
}) | ||
assertNoError(err) | ||
if data[0] != byte(1) { | ||
t.Errorf("Expected magic %d but got %d", byte(1), data[0]) | ||
} | ||
if size != uint16(1) { | ||
t.Errorf("Expected size %d but got %d", 1, size) | ||
} | ||
checkCache(1, 14, 20, 19, 18, 17, 16, 15, 13, 12) | ||
|
||
// Test cache fetch ERROR | ||
testErr := errors.New("test error") | ||
_, _, err = l.get(int64(2), func() (data []byte, size uint16, err error) { | ||
return nil, 0, testErr | ||
}) | ||
if err != testErr { | ||
t.Errorf("Want error %q but got %q", testErr, err) | ||
} | ||
checkCache(2, 1, 14, 20, 19, 18, 17, 16, 15, 13) | ||
}) | ||
} |