Skip to content

Commit

Permalink
Add sanity check to UpdateSlab to prevent migrations from changing th…
Browse files Browse the repository at this point in the history
…e number or order of shards in a slab (#705)

* stores: add sanity checks to UpdateSlab to prevent changing order or length of shards for a slab

* stores: fix TestUpdateSlab
  • Loading branch information
ChrisSchinnerl authored Nov 3, 2023
1 parent 9c0bb3c commit f5021f7
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 20 deletions.
33 changes: 21 additions & 12 deletions stores/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ const (
refreshHealthBatchSize = 10000
)

var (
errInvalidNumberOfShards = errors.New("slab has invalid number of shards")
errShardRootChanged = errors.New("shard root changed")
)

type (
dbArchivedContract struct {
Model
Expand Down Expand Up @@ -1564,23 +1569,36 @@ func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet s
return err
}

// make sure the number of shards doesn't change.
// NOTE: check both the slice as well as the TotalShards field to be
// safe.
if len(s.Shards) != int(slab.TotalShards) {
return fmt.Errorf("%w: expected %v shards (TotalShards) but got %v", errInvalidNumberOfShards, slab.TotalShards, len(s.Shards))
} else if len(s.Shards) != len(slab.Shards) {
return fmt.Errorf("%w: expected %v shards (Shards) but got %v", errInvalidNumberOfShards, len(slab.Shards), len(s.Shards))
}

// make sure the roots stay the same.
for i, shard := range s.Shards {
if shard.Root != types.Hash256(slab.Shards[i].Root) {
return fmt.Errorf("%w: shard %v has changed root from %v to %v", errShardRootChanged, i, slab.Shards[i].Root, shard.Root)
}
}

// update fields
if err := tx.Model(&slab).
Where(&slab).
Updates(map[string]interface{}{
"db_contract_set_id": cs.ID,
"health_valid": false,
"health": 1,
"total_shards": len(s.Shards),
}).
Error; err != nil {
return err
}

// loop updated shards
toKeep := make(map[types.Hash256]struct{})
for _, shard := range s.Shards {
toKeep[shard.Root] = struct{}{}
// ensure the sector exists
var sector dbSector
if err := tx.
Expand All @@ -1605,15 +1623,6 @@ func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet s
}
}
}
for _, shard := range slab.Shards {
root := *(*types.Hash256)(shard.Root)
if _, found := toKeep[root]; found {
continue
}
if err := tx.Delete(shard).Error; err != nil {
return fmt.Errorf("failed to delete shard: %w", err)
}
}
return nil
})
}
Expand Down
81 changes: 73 additions & 8 deletions stores/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2116,8 +2116,8 @@ func TestContractSectors(t *testing.T) {
}
}

// TestPutSlab verifies the functionality of PutSlab.
func TestPutSlab(t *testing.T) {
// TestUpdateSlab verifies the functionality of UpdateSlab.
func TestUpdateSlab(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

Expand Down Expand Up @@ -2289,11 +2289,10 @@ func TestPutSlab(t *testing.T) {
t.Fatalf("unexpected slab, %v != %v", obj.Slabs[0].ID, updated.ID)
}

// update the slab to change its contract set and total shards.
// update the slab to change its contract set.
if err := ss.SetContractSet(ctx, "other", nil); err != nil {
t.Fatal(err)
}
slab.Shards = nil // remove all shards
err = ss.UpdateSlab(ctx, slab, "other", map[types.PublicKey]types.FileContractID{
hk1: fcid1,
hk3: fcid3,
Expand All @@ -2311,10 +2310,6 @@ func TestPutSlab(t *testing.T) {
t.Fatal(err)
} else if s.DBContractSet.Name != "other" {
t.Fatal("contract set was not updated")
} else if s.TotalShards != 0 {
t.Fatal("total shards was not updated")
} else if len(s.Shards) != 0 {
t.Fatal("shards were not deleted")
}
}

Expand Down Expand Up @@ -3671,3 +3666,73 @@ func TestDeleteHostSector(t *testing.T) {
t.Fatal("expected hk2 to be latest host", types.PublicKey(s.Shards[0].LatestHost))
}
}

func TestUpdateSlabSanityChecks(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)

// create hosts and contracts.
hks, err := ss.addTestHosts(5)
if err != nil {
t.Fatal(err)
}
_, contracts, err := ss.addTestContracts(hks)
if err != nil {
t.Fatal(err)
}
usedContracts := make(map[types.PublicKey]types.FileContractID)
for _, c := range contracts {
usedContracts[c.HostKey] = c.ID
}

// prepare a slab.
var shards []object.Sector
for i := 0; i < 5; i++ {
shards = append(shards, object.Sector{
Host: hks[i],
Root: types.Hash256{byte(i + 1)},
})
}
slab := object.Slab{
Key: object.GenerateEncryptionKey(),
Shards: shards,
}

// set slab.
err = ss.UpdateObject(context.Background(), api.DefaultBucketName, "foo", testContractSet, testETag, testMimeType, object.Object{
Key: object.GenerateEncryptionKey(),
Slabs: []object.SlabSlice{{Slab: slab}},
}, usedContracts)
if err != nil {
t.Fatal(err)
}

// verify slab.
rSlab, err := ss.Slab(context.Background(), slab.Key)
if err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(slab, rSlab) {
t.Fatal("unexpected slab", cmp.Diff(slab, rSlab))
}

// change the length to fail the update.
if err := ss.UpdateSlab(context.Background(), object.Slab{
Key: slab.Key,
Shards: shards[:len(shards)-1],
}, testContractSet, usedContracts); !errors.Is(err, errInvalidNumberOfShards) {
t.Fatal(err)
}

// reverse the order of the shards to fail the update.
reversedShards := append([]object.Sector{}, shards...)
for i := 0; i < len(reversedShards)/2; i++ {
j := len(reversedShards) - i - 1
reversedShards[i], reversedShards[j] = reversedShards[j], reversedShards[i]
}
reversedSlab := object.Slab{
Key: slab.Key,
Shards: reversedShards,
}
if err := ss.UpdateSlab(context.Background(), reversedSlab, testContractSet, usedContracts); !errors.Is(err, errShardRootChanged) {
t.Fatal(err)
}
}
2 changes: 2 additions & 0 deletions stores/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ func (s *SQLStore) retryTransaction(fc func(tx *gorm.DB) error, opts ...*sql.TxO
abortRetry := func(err error) bool {
if err == nil ||
errors.Is(err, gorm.ErrRecordNotFound) ||
errors.Is(err, errInvalidNumberOfShards) ||
errors.Is(err, errShardRootChanged) ||
errors.Is(err, api.ErrContractNotFound) ||
errors.Is(err, api.ErrObjectNotFound) ||
errors.Is(err, api.ErrObjectCorrupted) ||
Expand Down

0 comments on commit f5021f7

Please sign in to comment.