diff --git a/filesystem/fat32/common_test.go b/filesystem/fat32/common_test.go index 8d5698f7..a89923a8 100644 --- a/filesystem/fat32/common_test.go +++ b/filesystem/fat32/common_test.go @@ -467,7 +467,7 @@ func testReadFilesystemData() (info *testFSInfo, err error) { rootDirCluster: 2, // root is at cluster 2 size: sizeInBytes, maxCluster: numClusters, - clusters: make(map[uint32]uint32), + clusters: make([]uint32, numClusters+1), } case inClusters && len(clusterLineMatch) > 4: start, err := strconv.Atoi(clusterLineMatch[1]) diff --git a/filesystem/fat32/fat32.go b/filesystem/fat32/fat32.go index 7c6cfb1f..e990190c 100644 --- a/filesystem/fat32/fat32.go +++ b/filesystem/fat32/fat32.go @@ -228,17 +228,16 @@ func Create(f util.File, size, start, blocksize int64, volumeLabel string) (*Fil fatSecondaryStart := uint64(fatPrimaryStart) + uint64(fatSize) maxCluster := fatSize / 4 rootDirCluster := uint32(2) + clusters := make([]uint32, maxCluster+1) + clusters[rootDirCluster] = eocMarker fat := table{ fatID: fatID, eocMarker: eocMarker, unusedMarker: unusedMarker, size: fatSize, rootDirCluster: rootDirCluster, - clusters: map[uint32]uint32{ - // when we start, there is just one directory with a single cluster - rootDirCluster: eocMarker, - }, - maxCluster: maxCluster, + clusters: clusters, + maxCluster: maxCluster, } // where does our data start? @@ -694,10 +693,9 @@ func (fs *FileSystem) getClusterList(firstCluster uint32) ([]uint32, error) { // first, get the chain of clusters complete := false cluster := firstCluster - clusters := fs.table.clusters // do we even have a valid cluster? - if _, ok := clusters[cluster]; !ok { + if cluster > fs.table.maxCluster || fs.table.clusters[cluster] == 0 { return nil, fmt.Errorf("invalid start cluster: %d", cluster) } @@ -706,11 +704,13 @@ func (fs *FileSystem) getClusterList(firstCluster uint32) ([]uint32, error) { // save the current cluster clusterList = append(clusterList, cluster) // get the next cluster - newCluster := clusters[cluster] + newCluster := fs.table.clusters[cluster] // if it is EOC, we are done switch { case fs.table.isEoc(newCluster): complete = true + case newCluster > fs.table.maxCluster: + return nil, fmt.Errorf("invalid cluster chain at %d", newCluster) case cluster < 2: return nil, fmt.Errorf("invalid cluster chain at %d", cluster) } @@ -924,6 +924,10 @@ func (fs *FileSystem) readDirWithMkdir(p string, doMake bool) (*Directory, []*di // returns the indexes of clusters to be used in order. If the new size is smaller than // the original size, will shrink the chain. func (fs *FileSystem) allocateSpace(size uint64, previous uint32) ([]uint32, error) { + if previous > fs.table.maxCluster { + return nil, fmt.Errorf("invalid cluster chain at %d", previous) + } + var ( clusters []uint32 err error @@ -961,12 +965,11 @@ func (fs *FileSystem) allocateSpace(size uint64, previous uint32) ([]uint32, err } // get a list of allocated clusters, so we can know which ones are unallocated and therefore allocatable - allClusters := fs.table.clusters maxCluster := fs.table.maxCluster if extraClusterCount > 0 { for i := uint32(2); i < maxCluster && len(allocated) < extraClusterCount; i++ { - if _, ok := allClusters[i]; !ok { + if fs.table.clusters[i] == 0 { // these become the same at this point allocated = append(allocated, i) } @@ -982,12 +985,12 @@ func (fs *FileSystem) allocateSpace(size uint64, previous uint32) ([]uint32, err // extend the chain and fill them in if previous > 0 { - allClusters[previous] = allocated[0] + fs.table.clusters[previous] = allocated[0] } for i := 0; i < lastAlloc; i++ { - allClusters[allocated[i]] = allocated[i+1] + fs.table.clusters[allocated[i]] = allocated[i+1] } - allClusters[allocated[lastAlloc]] = fs.table.eocMarker + fs.table.clusters[allocated[lastAlloc]] = fs.table.eocMarker // update the FSIS lastAllocatedCluster = allocated[len(allocated)-1] @@ -1003,13 +1006,21 @@ func (fs *FileSystem) allocateSpace(size uint64, previous uint32) ([]uint32, err } deallocated = clusters[lastAlloc+1:] + if uint32(lastAlloc) > fs.table.maxCluster || clusters[lastAlloc] > fs.table.maxCluster { + return nil, fmt.Errorf("invalid cluster chain at %d", lastAlloc) + } + // mark last allocated one as EOC - allClusters[clusters[lastAlloc]] = fs.table.eocMarker + fs.table.clusters[clusters[lastAlloc]] = fs.table.eocMarker // unmark all of the unused ones lastAllocatedCluster = fs.fsis.lastAllocatedCluster for _, cl := range deallocated { - allClusters[cl] = fs.table.unusedMarker + if cl > fs.table.maxCluster { + return nil, fmt.Errorf("invalid cluster chain at %d", cl) + } + + fs.table.clusters[cl] = fs.table.unusedMarker if cl == lastAllocatedCluster { lastAllocatedCluster-- } diff --git a/filesystem/fat32/fat32_internal_test.go b/filesystem/fat32/fat32_internal_test.go index 919141e8..d2b8242b 100644 --- a/filesystem/fat32/fat32_internal_test.go +++ b/filesystem/fat32/fat32_internal_test.go @@ -17,6 +17,15 @@ import ( in that case, the dataStart is relative to partition, not to disk, so need to read the offset correctly */ +func clustersFromMap(m map[uint32]uint32, maxCluster uint32) []uint32 { + clusters := make([]uint32, maxCluster+1) + for k, v := range m { + clusters[k] = v + } + + return clusters +} + func getValidFat32FSFull() *FileSystem { fs := getValidFat32FSSmall() fs.table = *getValidFat32Table() @@ -25,11 +34,12 @@ func getValidFat32FSFull() *FileSystem { func getValidFat32FSSmall() *FileSystem { eoc := uint32(0xffffffff) + maxCluster := uint32(128) fs := &FileSystem{ table: table{ rootDirCluster: 2, size: 512, - maxCluster: 128, + maxCluster: maxCluster, eocMarker: eoc, /* map: @@ -40,8 +50,9 @@ func getValidFat32FSSmall() *FileSystem { 11 15 16-broken + 17-broken */ - clusters: map[uint32]uint32{ + clusters: clustersFromMap(map[uint32]uint32{ 2: eoc, 3: 4, 4: 5, @@ -53,8 +64,9 @@ func getValidFat32FSSmall() *FileSystem { 9: 11, 11: eoc, 15: eoc, - 16: 0, - }, + 16: 1, + 17: 999, + }, maxCluster), }, bytesPerCluster: 512, dataStart: 178176, @@ -80,6 +92,7 @@ func getValidFat32FSSmall() *FileSystem { } return fs } + func TestFat32GetClusterList(t *testing.T) { fs := getValidFat32FSSmall() @@ -178,6 +191,7 @@ func TestFat32AllocateSpace(t *testing.T) { 11 15 16-broken + 17-broken // recall that 512 bytes per cluster here */ tests := []struct { @@ -189,9 +203,11 @@ func TestFat32AllocateSpace(t *testing.T) { {500, 2, []uint32{2}, nil}, {600, 2, []uint32{2, 12}, nil}, {2000, 2, []uint32{2, 12, 13, 14}, nil}, - {2000, 0, []uint32{12, 13, 14, 17}, nil}, + {2000, 0, []uint32{12, 13, 14, 18}, nil}, {200000000000, 0, nil, fmt.Errorf("no space left on device")}, {200000000000, 2, nil, fmt.Errorf("no space left on device")}, + {2000, 17, nil, fmt.Errorf("unable to get cluster list: invalid cluster chain at 999")}, + {2000, 999, nil, fmt.Errorf("invalid cluster chain at 999")}, } for _, tt := range tests { // reset for each test diff --git a/filesystem/fat32/table.go b/filesystem/fat32/table.go index 4210e9f0..600c4711 100644 --- a/filesystem/fat32/table.go +++ b/filesystem/fat32/table.go @@ -2,7 +2,7 @@ package fat32 import ( "encoding/binary" - "reflect" + "slices" ) // table a FAT32 table @@ -10,7 +10,7 @@ type table struct { fatID uint32 eocMarker uint32 unusedMarker uint32 - clusters map[uint32]uint32 + clusters []uint32 rootDirCluster uint32 size uint32 maxCluster uint32 @@ -28,7 +28,7 @@ func (t *table) equal(a *table) bool { t.rootDirCluster == a.rootDirCluster && t.size == a.size && t.maxCluster == a.maxCluster && - reflect.DeepEqual(t.clusters, a.clusters) + slices.Equal(a.clusters, t.clusters) } /* @@ -37,12 +37,14 @@ func (t *table) equal(a *table) bool { */ func tableFromBytes(b []byte) *table { + maxCluster := uint32(len(b) / 4) + t := table{ fatID: binary.LittleEndian.Uint32(b[0:4]), eocMarker: binary.LittleEndian.Uint32(b[4:8]), size: uint32(len(b)), - clusters: map[uint32]uint32{}, - maxCluster: uint32(len(b) / 4), + clusters: make([]uint32, maxCluster+1), + maxCluster: maxCluster, rootDirCluster: 2, // always 2 for FAT32 } // just need to map the clusters in @@ -71,10 +73,7 @@ func (t *table) bytes() []byte { for i := uint32(2); i < numClusters; i++ { bStart := i * 4 bEnd := bStart + 4 - val := uint32(0) - if cluster, ok := t.clusters[i]; ok { - val = cluster - } + val := t.clusters[i] binary.LittleEndian.PutUint32(b[bStart:bEnd], val) } diff --git a/filesystem/fat32/table_internal_test.go b/filesystem/fat32/table_internal_test.go index 119f05ff..aac38eda 100644 --- a/filesystem/fat32/table_internal_test.go +++ b/filesystem/fat32/table_internal_test.go @@ -3,6 +3,7 @@ package fat32 import ( "bytes" "os" + "slices" "testing" "github.com/diskfs/go-diskfs/util" @@ -16,13 +17,10 @@ const ( func getValidFat32Table() *table { // make a duplicate, in case someone modifies what we return - var t = &table{} + t := &table{} *t = *fsInfo.table // and because the clusters are copied by reference - t.clusters = make(map[uint32]uint32) - for k, v := range fsInfo.table.clusters { - t.clusters[k] = v - } + t.clusters = slices.Clone(t.clusters) return t }