fix offset when reading and writing fat32 files to avoid overruns (#260)
Signed-off-by: Avi Deitcher <[email protected]>
deitch authored Sep 26, 2024
1 parent 142b4fd commit 6ad9db3
Showing 2 changed files with 106 additions and 13 deletions.
111 changes: 102 additions & 9 deletions filesystem/fat32/fat32_test.go
@@ -7,13 +7,17 @@ package fat32_test

import (
"bytes"
"crypto/rand"
"fmt"
"io"
"math/rand/v2"
"os"
"path"
"path/filepath"
"strings"
"testing"

"github.com/diskfs/go-diskfs"
"github.com/diskfs/go-diskfs/disk"
"github.com/diskfs/go-diskfs/filesystem"
"github.com/diskfs/go-diskfs/filesystem/fat32"
"github.com/diskfs/go-diskfs/testhelper"
@@ -245,6 +249,8 @@ func TestFat32Read(t *testing.T) {
}
//nolint:thelper // this is not a helper function
runTest := func(t *testing.T, pre, post int64) {
seed := [32]byte{}
chacha := rand.NewChaCha8(seed)
for _, t2 := range tests {
tt := t2
t.Run(fmt.Sprintf("blocksize %d filesize %d bytechange %d", tt.filesize, tt.blocksize, tt.bytechange), func(t *testing.T) {
@@ -258,7 +264,7 @@ func TestFat32Read(t *testing.T) {
corrupted := ""
if tt.bytechange >= 0 {
b := make([]byte, 1)
_, _ = rand.Read(b)
_, _ = chacha.Read(b)
_, _ = f.WriteAt(b, tt.bytechange+pre)
corrupted = fmt.Sprintf("corrupted %d", tt.bytechange+pre)
}
@@ -629,14 +635,16 @@ func TestFat32OpenFile(t *testing.T) {
bWrite := make([]byte, size)
header := fmt.Sprintf("OpenFile(%s, %s)", path, getOpenMode(mode))
readWriter, err := fs.OpenFile(path, mode)
seed := [32]byte{}
chacha := rand.NewChaCha8(seed)
switch {
case err != nil:
t.Errorf("%s: unexpected error: %v", header, err)
case readWriter == nil:
t.Errorf("%s: Unexpected nil output", header)
default:
// write and then read
_, _ = rand.Read(bWrite)
_, _ = chacha.Read(bWrite)
written, writeErr := readWriter.Write(bWrite)
_, _ = readWriter.Seek(0, 0)
bRead, readErr := io.ReadAll(readWriter)
@@ -682,21 +690,23 @@ func TestFat32OpenFile(t *testing.T) {
if err != nil {
t.Fatalf("error reading fat32 filesystem from %s: %v", f.Name(), err)
}
path := "/abcdefghi"
p := "/abcdefghi"
mode := os.O_RDWR | os.O_CREATE
// each cluster is 512 bytes, so use 10 clusters and a bit of another
size := 10*512 + 22
bWrite := make([]byte, size)
header := fmt.Sprintf("OpenFile(%s, %s)", path, getOpenMode(mode))
readWriter, err := fs.OpenFile(path, mode)
header := fmt.Sprintf("OpenFile(%s, %s)", p, getOpenMode(mode))
readWriter, err := fs.OpenFile(p, mode)
seed := [32]byte{}
chacha := rand.NewChaCha8(seed)
switch {
case err != nil:
t.Fatalf("%s: unexpected error: %v", header, err)
case readWriter == nil:
t.Fatalf("%s: Unexpected nil output", header)
default:
// write and then read
_, _ = rand.Read(bWrite)
_, _ = chacha.Read(bWrite)
written, writeErr := readWriter.Write(bWrite)
_, _ = readWriter.Seek(0, 0)

@@ -713,7 +723,7 @@ func TestFat32OpenFile(t *testing.T) {
}
// and open to truncate
mode = os.O_RDWR | os.O_TRUNC
readWriter, err = fs.OpenFile(path, mode)
readWriter, err = fs.OpenFile(p, mode)
if err != nil {
t.Fatalf("could not reopen file: %v", err)
}
@@ -765,7 +775,9 @@ func TestFat32OpenFile(t *testing.T) {
// success
}

_, _ = rand.Read(bWrite)
seed := [32]byte{}
chacha := rand.NewChaCha8(seed)
_, _ = chacha.Read(bWrite)
writeSizes := []int{512, 1024, 256}
low := 0
for i := 0; low < len(bWrite); i++ {
@@ -1058,3 +1070,84 @@ func TestOpenFileCaseInsensitive(t *testing.T) {
}
}
}

func testMkFile(fs filesystem.FileSystem, p string, size int) error {
rw, err := fs.OpenFile(p, os.O_CREATE|os.O_RDWR)
if err != nil {
return err
}
smallFile := make([]byte, size)
_, err = rw.Write(smallFile)
if err != nil {
return err
}
return nil
}

func TestCreateFileTree(t *testing.T) {
filename := "fat32_test"
tmpDir := t.TempDir()
tmpImgPath := filepath.Join(tmpDir, filename)

// 6GB to test large disk
size := int64(6 * 1024 * 1024 * 1024)
d, err := diskfs.Create(tmpImgPath, size, diskfs.Raw, diskfs.SectorSizeDefault)
if err != nil {
t.Fatalf("error creating disk: %v", err)
}

spec := disk.FilesystemSpec{
Partition: 0,
FSType: filesystem.TypeFat32,
}
fs, err := d.CreateFilesystem(spec)
if err != nil {
t.Fatalf("error creating filesystem: %v", err)
}

if err := fs.Mkdir("/A"); err != nil {
t.Errorf("Error making dir /A in root: %v", err)
}
if err := fs.Mkdir("/b"); err != nil {
t.Errorf("Error making dir /b in root: %v", err)
}
if err := testMkFile(fs, "/rootfile", 11); err != nil {
t.Errorf("Error making microfile in root: %v", err)
}
for i := 0; i < 100; i++ {
dir := fmt.Sprintf("/b/sub%d", i)
if err := fs.Mkdir(dir); err != nil {
t.Errorf("Error making directory %s: %v", dir, err)
}
blobdir := path.Join(dir, "blob")
if err := fs.Mkdir(blobdir); err != nil {
t.Errorf("Error making directory %s: %v", blobdir, err)
}
file := path.Join(blobdir, "microfile")
if err := testMkFile(fs, file, 11); err != nil {
t.Errorf("Error making microfile %s: %v", file, err)
}
file = path.Join(blobdir, "randfile")
size := rand.IntN(73) // #nosec G404
if err := testMkFile(fs, file, size); err != nil {
t.Errorf("Error making random file %s: %v", file, err)
}
file = path.Join(blobdir, "smallfile")
if err := testMkFile(fs, file, 5*1024*1024); err != nil {
t.Errorf("Error making small file %s: %v", file, err)
}
}
file := "/b/sub49/blob/gigfile1"
gb := 1024 * 1024 * 1024
if err := testMkFile(fs, file, gb); err != nil {
t.Errorf("Error making gigfile1 %s: %v", file, err)
}
file = "/b/sub50/blob/gigfile1"
if err := testMkFile(fs, file, gb); err != nil {
t.Errorf("Error making gigfile1 %s: %v", file, err)
}
file = "/b/sub51/blob/gigfile1"
if err := testMkFile(fs, file, gb); err != nil {
t.Errorf("Error making gigfile1 %s: %v", file, err)
}
}
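
Note on the test changes above: crypto/rand is swapped for a math/rand/v2 ChaCha8 stream seeded with a fixed 32-byte value, so the bytes written to the test image are reproducible from run to run. A minimal standalone sketch of that pattern, assuming Go 1.23+ (the buffer size and seed are arbitrary illustration values, unrelated to the go-diskfs API):

package main

import (
	"fmt"
	"math/rand/v2"
)

func main() {
	// A fixed 32-byte seed makes the stream deterministic,
	// unlike crypto/rand, which yields different bytes every run.
	seed := [32]byte{}
	chacha := rand.NewChaCha8(seed)

	buf := make([]byte, 8)
	_, _ = chacha.Read(buf) // *ChaCha8 implements io.Reader
	fmt.Printf("%x\n", buf) // same output for the same seed
}

Deterministic data makes a corrupted-byte or mismatched-read failure reproducible, which is why the tests thread a seeded generator through instead of calling crypto/rand directly.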
8 changes: 4 additions & 4 deletions filesystem/fat32/file.go
@@ -78,8 +78,8 @@ func (fl *File) Read(b []byte) (int, error) {
if toRead > left {
toRead = left
}
offset := uint32(start) + (clusters[i]-2)*uint32(bytesPerCluster)
_, _ = file.ReadAt(b[totalRead:totalRead+toRead], int64(offset)+fs.start)
offset := int64(start) + int64(clusters[i]-2)*int64(bytesPerCluster)
_, _ = file.ReadAt(b[totalRead:totalRead+toRead], offset+fs.start)
totalRead += toRead
if totalRead >= maxRead {
break
@@ -161,8 +161,8 @@ func (fl *File) Write(p []byte) (int, error) {
if toWrite > left {
toWrite = left
}
offset := uint32(start) + (clusters[i]-2)*uint32(bytesPerCluster)
_, err := file.WriteAt(p[totalWritten:totalWritten+toWrite], int64(offset)+fs.start)
offset := int64(start) + int64(clusters[i]-2)*int64(bytesPerCluster)
_, err := file.WriteAt(p[totalWritten:totalWritten+toWrite], offset+fs.start)
if err != nil {
return totalWritten, fmt.Errorf("unable to write to file: %v", err)
}
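
The file.go change above is the core of the overrun fix: the byte offset of a cluster used to be computed entirely in uint32, so start + (clusters[i]-2)*bytesPerCluster wraps around once the target lies past the 4 GiB mark and the read or write lands at the wrong place on the device. Converting to int64 before multiplying preserves the full offset. A small self-contained sketch of the difference, using made-up illustration values (the cluster size, cluster number, and data-region start below are not taken from the driver):

package main

import "fmt"

func main() {
	// Made-up values: 4 KiB clusters, data region 1 MiB into the partition.
	const bytesPerCluster = 4096
	const start = 1 << 20

	// A cluster whose data sits past the 4 GiB boundary on a 6 GB image.
	var cluster uint32 = 1_300_000

	// Pre-fix arithmetic: the uint32 product wraps at 2^32, so the
	// offset points roughly 4 GiB too early on the device.
	wrapped := uint32(start) + (cluster-2)*uint32(bytesPerCluster)

	// Post-fix arithmetic: widen to int64 first, so nothing wraps.
	exact := int64(start) + int64(cluster-2)*int64(bytesPerCluster)

	fmt.Println("uint32 offset:", wrapped) // ~0.96 GiB: wrong location
	fmt.Println("int64  offset:", exact)   // ~4.96 GiB: correct location
}

The new 6 GB image and the gigabyte-sized files in TestCreateFileTree appear intended to push cluster offsets past that 32-bit boundary, matching the "6GB to test large disk" comment in the test.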
