diff --git a/stores/metadata.go b/stores/metadata.go index ad4c7d496..56202ccf3 100644 --- a/stores/metadata.go +++ b/stores/metadata.go @@ -1469,13 +1469,16 @@ func (s *SQLStore) isKnownContract(fcid types.FileContractID) bool { return found } -func fetchUsedContracts(tx *gorm.DB, usedContracts map[types.PublicKey]map[types.FileContractID]struct{}) (map[types.FileContractID]dbContract, error) { - fcids := make([]fileContractID, 0, len(usedContracts)) - for _, hostFCIDs := range usedContracts { +func fetchUsedContracts(tx *gorm.DB, usedContractsByHost map[types.PublicKey]map[types.FileContractID]struct{}) (map[types.FileContractID]dbContract, error) { + // flatten map to get all used contract ids + fcids := make([]fileContractID, 0, len(usedContractsByHost)) + for _, hostFCIDs := range usedContractsByHost { for fcid := range hostFCIDs { fcids = append(fcids, fileContractID(fcid)) } } + + // fetch all contracts, take into account renewals var contracts []dbContract err := tx.Model(&dbContract{}). Joins("Host"). @@ -1484,17 +1487,19 @@ func fetchUsedContracts(tx *gorm.DB, usedContracts map[types.PublicKey]map[types if err != nil { return nil, err } - fetchedContracts := make(map[types.FileContractID]dbContract, len(contracts)) + + // build map of used contracts + usedContracts := make(map[types.FileContractID]dbContract, len(contracts)) for _, c := range contracts { - // If a contract has been renewed, we add the renewed contract to the - // map using the old contract's id. - if _, renewed := usedContracts[types.PublicKey(c.Host.PublicKey)][types.FileContractID(c.RenewedFrom)]; renewed { - fetchedContracts[types.FileContractID(c.RenewedFrom)] = c - } else { - fetchedContracts[types.FileContractID(c.FCID)] = c + if _, used := usedContractsByHost[types.PublicKey(c.Host.PublicKey)][types.FileContractID(c.FCID)]; used { + usedContracts[types.FileContractID(c.FCID)] = c + } + if _, used := usedContractsByHost[types.PublicKey(c.Host.PublicKey)][types.FileContractID(c.RenewedFrom)]; used { + usedContracts[types.FileContractID(c.RenewedFrom)] = c } } - return fetchedContracts, nil + + return usedContracts, nil } func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew string, force bool) error { diff --git a/stores/metadata_test.go b/stores/metadata_test.go index bdd955808..abda57b95 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -4664,3 +4664,83 @@ func TestUpdateObjectParallel(t *testing.T) { close(c) wg.Wait() } + +// TestFetchUsedContracts is a unit test that verifies the functionality of +// fetchUsedContracts +func TestFetchUsedContracts(t *testing.T) { + // create store + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + // add test host + hk1 := types.PublicKey{1} + err := ss.addTestHost(hk1) + if err != nil { + t.Fatal(err) + } + + // add test contract + fcid1 := types.FileContractID{1} + _, err = ss.addTestContract(fcid1, hk1) + if err != nil { + t.Fatal(err) + } + + // assert empty map returns no contracts + usedContracts := make(map[types.PublicKey]map[types.FileContractID]struct{}) + contracts, err := fetchUsedContracts(ss.db, usedContracts) + if err != nil { + t.Fatal(err) + } else if len(contracts) != 0 { + t.Fatal("expected 0 contracts", len(contracts)) + } + + // add an entry for fcid1 + usedContracts[hk1] = make(map[types.FileContractID]struct{}) + usedContracts[hk1][types.FileContractID{1}] = struct{}{} + + // assert we get the used contract + contracts, err = fetchUsedContracts(ss.db, usedContracts) + if err != nil { + t.Fatal(err) + } else if len(contracts) != 1 { + t.Fatal("expected 1 contract", len(contracts)) + } else if _, ok := contracts[fcid1]; !ok { + t.Fatal("contract not found") + } + + // renew the contract + fcid2 := types.FileContractID{2} + _, err = ss.addTestRenewedContract(fcid2, fcid1, hk1, 1) + if err != nil { + t.Fatal(err) + } + + // assert used contracts contains one entry and it points to the renewal + contracts, err = fetchUsedContracts(ss.db, usedContracts) + if err != nil { + t.Fatal(err) + } else if len(contracts) != 1 { + t.Fatal("expected 1 contract", len(contracts)) + } else if contract, ok := contracts[fcid1]; !ok { + t.Fatal("contract not found") + } else if contract.convert().ID != fcid2 { + t.Fatal("contract should point to the renewed contract") + } + + // add an entry for fcid2 + usedContracts[hk1][types.FileContractID{2}] = struct{}{} + + // assert used contracts now contains an entry for both contracts and both + // point to the renewed contract + contracts, err = fetchUsedContracts(ss.db, usedContracts) + if err != nil { + t.Fatal(err) + } else if len(contracts) != 2 { + t.Fatal("expected 2 contracts", len(contracts)) + } else if !reflect.DeepEqual(contracts[types.FileContractID{1}], contracts[types.FileContractID{2}]) { + t.Fatal("contracts should match") + } else if contracts[types.FileContractID{1}].convert().ID != fcid2 { + t.Fatal("contracts should point to the renewed contract") + } +}