diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3abecd311..e40df6930 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -21,9 +21,6 @@ env: jobs: docker: runs-on: ubuntu-latest - strategy: - matrix: - network: ["mainnet" , "zen"] permissions: packages: write steps: @@ -35,22 +32,11 @@ jobs: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Prepare environment variables - run: | - if [[ "${{ matrix.network }}" == "zen" ]]; then - echo "BUILD_TAGS=testnet netgo" >> $GITHUB_ENV - echo "DOCKER_METADATA_SUFFIX=-zen" >> $GITHUB_ENV - else - echo "BUILD_TAGS=netgo" >> $GITHUB_ENV - echo "DOCKER_METADATA_SUFFIX=" >> $GITHUB_ENV - fi - uses: docker/metadata-action@v4 name: Generate tags id: meta with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - flavor: | - suffix=${{ env.DOCKER_METADATA_SUFFIX }},onlatest=true tags: | type=ref,event=branch type=sha,prefix= @@ -58,8 +44,6 @@ jobs: - uses: docker/build-push-action@v4 with: context: . - build-args: | - BUILD_TAGS=${{ env.BUILD_TAGS }} file: ./docker/Dockerfile platforms: linux/amd64,linux/arm64 push: true @@ -68,7 +52,6 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - network: ["mainnet" , "zen"] arch: [ amd64, arm64 ] steps: - uses: actions/checkout@v4 @@ -83,15 +66,6 @@ jobs: sudo apt install -y gcc-aarch64-linux-gnu echo "CC=aarch64-linux-gnu-gcc" >> $GITHUB_ENV fi - - name: Set build tag environment variable - run: | - if [[ "${{ matrix.network }}" == "zen" ]]; then - echo "BUILD_TAGS=testnet netgo" >> $GITHUB_ENV - echo "ZIP_OUTPUT_SUFFIX=_zen" >> $GITHUB_ENV - else - echo "BUILD_TAGS=netgo" >> $GITHUB_ENV - echo "ZIP_OUTPUT_SUFFIX=" >> $GITHUB_ENV - fi - name: Build ${{ matrix.arch }} env: CGO_ENABLED: 1 @@ -99,37 +73,18 @@ jobs: GOARCH: ${{ matrix.arch }} run: | mkdir -p release - ZIP_OUTPUT=release/renterd${{ env.ZIP_OUTPUT_SUFFIX }}_${GOOS}_${GOARCH}.zip - go build -tags="$BUILD_TAGS" -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/renterd + ZIP_OUTPUT=release/renterd_${GOOS}_${GOARCH}.zip + go build -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/renterd cp README.md LICENSE bin/ zip -qj $ZIP_OUTPUT bin/* - - name: Get Release Asset - uses: actions/github-script@v7 - id: get_release_asset - env: - ARCH: ${{ matrix.arch }} - NETWORK: ${{ matrix.network }} - with: - result-encoding: string - script: | - const arch = process.env.ARCH, - network = process.env.NETWORK; - - switch (network) { - case 'mainnet': - return `renterd_linux_${arch}`; - default: - return `renterd_${network}_linux_${arch}`; - } - uses: actions/upload-artifact@v4 with: - name: ${{ steps.get_release_asset.outputs.result }} - path: release/ + name: renterd_linux_${{ matrix.arch }} + path: release/* build-mac: runs-on: macos-latest strategy: matrix: - network: ["mainnet" , "zen"] arch: [ amd64, arm64 ] steps: - uses: actions/checkout@v4 @@ -169,15 +124,6 @@ jobs: # generate go generate ./... 
- - name: Set build tag environment variable - run: | - if [[ "${{ matrix.network }}" == "zen" ]]; then - echo "BUILD_TAGS=testnet netgo" >> $GITHUB_ENV - echo "ZIP_OUTPUT_SUFFIX=_zen" >> $GITHUB_ENV - else - echo "BUILD_TAGS=netgo" >> $GITHUB_ENV - echo "ZIP_OUTPUT_SUFFIX=" >> $GITHUB_ENV - fi - name: Build ${{ matrix.arch }} env: APPLE_CERT_ID: ${{ secrets.APPLE_CERT_ID }} APPLE_API_KEY: ${{ secrets.APPLE_API_KEY }} APPLE_API_ISSUER: ${{ secrets.APPLE_API_ISSUER }} CGO_ENABLED: 1 GOOS: darwin GOARCH: ${{ matrix.arch }} run: | mkdir -p release - ZIP_OUTPUT=release/renterd${{ env.ZIP_OUTPUT_SUFFIX }}_${GOOS}_${GOARCH}.zip - go build -tags="$BUILD_TAGS" -trimpath -o bin/ -a -ldflags '-s -w' ./cmd/renterd + ZIP_OUTPUT=release/renterd_${GOOS}_${GOARCH}.zip + go build -trimpath -o bin/ -a -ldflags '-s -w' ./cmd/renterd cp README.md LICENSE bin/ /usr/bin/codesign --deep -f -v --timestamp -o runtime,library -s $APPLE_CERT_ID bin/renterd ditto -ck bin $ZIP_OUTPUT xcrun notarytool submit -k ~/private_keys/AuthKey_$APPLE_API_KEY.p8 -d $APPLE_API_KEY -i $APPLE_API_ISSUER --wait --timeout 10m $ZIP_OUTPUT - - name: Get Release Asset - uses: actions/github-script@v7 - id: get_release_asset - env: - ARCH: ${{ matrix.arch }} - NETWORK: ${{ matrix.network }} - with: - result-encoding: string - script: | - const arch = process.env.ARCH, - network = process.env.NETWORK; - - switch (network) { - case 'mainnet': - return `renterd_darwin_${arch}`; - default: - return `renterd_${network}_darwin_${arch}`; - } - uses: actions/upload-artifact@v4 with: - name: ${{ steps.get_release_asset.outputs.result }} - path: release/ + name: renterd_darwin_${{ matrix.arch }} + path: release/* build-windows: runs-on: windows-latest strategy: matrix: - network: ["mainnet" , "zen"] arch: [ amd64 ] steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: 'stable' - - name: Set build tag environment variable - run: | - if ( "${{ matrix.network }}" -eq "zen" ) { - "BUILD_TAGS=testnet netgo" >> $env:GITHUB_ENV - "ZIP_OUTPUT_SUFFIX=_zen" >> $env:GITHUB_ENV - } else { - "BUILD_TAGS=netgo" >> $env:GITHUB_ENV - "ZIP_OUTPUT_SUFFIX=" >> $env:GITHUB_ENV - } - name: Setup shell: bash run: | @@ -254,32 +172,14 @@ run: | mkdir -p release - ZIP_OUTPUT=release/renterd${{ env.ZIP_OUTPUT_SUFFIX }}_${GOOS}_${GOARCH}.zip - go build -tags="$BUILD_TAGS" -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/renterd + ZIP_OUTPUT=release/renterd_${GOOS}_${GOARCH}.zip + go build -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/renterd azuresigntool sign -kvu "${{ secrets.AZURE_KEY_VAULT_URI }}" -kvi "${{ secrets.AZURE_CLIENT_ID }}" -kvt "${{ secrets.AZURE_TENANT_ID }}" -kvs "${{ secrets.AZURE_CLIENT_SECRET }}" -kvc ${{ secrets.AZURE_CERT_NAME }} -tr http://timestamp.digicert.com -v bin/renterd.exe cp README.md LICENSE bin/ 7z a $ZIP_OUTPUT bin/* - - name: Get Release Asset - uses: actions/github-script@v7 - id: get_release_asset - env: - ARCH: ${{ matrix.arch }} - NETWORK: ${{ matrix.network }} - with: - result-encoding: string - script: | - const arch = process.env.ARCH, - network = process.env.NETWORK; - - switch (network) { - case 'mainnet': - return `renterd_windows_${arch}`; - default: - return `renterd_${network}_windows_${arch}`; - } - uses: actions/upload-artifact@v4 with: - name: ${{ steps.get_release_asset.outputs.result }} - path: release/ + name: renterd_windows_${{ matrix.arch }} + path: release/* combine-release-assets: runs-on: ubuntu-latest diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fa1d690c8..0b735dab1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,7 +31,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest
, macos-latest, windows-latest ] - go-version: [ '1.21', '1.22' ] + go-version: [ '1.22', '1.23' ] steps: - name: Checkout uses: actions/checkout@v4 @@ -49,7 +49,7 @@ jobs: host port: 3800 mysql version: '8' mysql root password: test - - name: Test + - name: Test Stores uses: n8maninger/action-golang-test@v1 with: args: "-race;-short" @@ -67,7 +67,7 @@ jobs: uses: n8maninger/action-golang-test@v1 with: package: "./internal/test/e2e/..." - args: "-failfast;-race;-tags=testing;-timeout=30m" + args: "-failfast;-race;-tags=testing;-timeout=60m" - name: Test Integration - MySQL if: matrix.os == 'ubuntu-latest' uses: n8maninger/action-golang-test@v1 env: RENTERD_DB_URI: 127.0.0.1:3800 RENTERD_DB_USER: root RENTERD_DB_PASSWORD: test with: package: "./internal/test/e2e/..." - args: "-failfast;-race;-tags=testing;-timeout=30m" + args: "-failfast;-race;-tags=testing;-timeout=60m" - name: Build run: go build -o bin/ ./cmd/renterd diff --git a/.github/workflows/ui.yml b/.github/workflows/ui.yml index 78f5e2b6d..b15dd5816 100644 --- a/.github/workflows/ui.yml +++ b/.github/workflows/ui.yml @@ -15,4 +15,4 @@ jobs: with: moduleName: 'renterd' goVersion: '1.21' - token: ${{ secrets.GITHUB_TOKEN }} + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index dc51edf86..bc01922e1 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ overview of all settings configurable through the CLI. | `Directory` | Directory for storing node state | `.` | `--dir` | - | `directory` | | `Seed` | Seed for the node | - | - | `RENTERD_SEED` | `seed` | | `AutoOpenWebUI` | Automatically open the web UI on startup | `true` | `--openui` | - | `autoOpenWebUI` | +| `Network` | Network to run on (mainnet/zen/anagami) | `mainnet` | `--network` | `RENTERD_NETWORK` | `network` | | `ShutdownTimeout` | Timeout for node shutdown | `5m` | `--node.shutdownTimeout` | - | `shutdownTimeout` | | `Log.Level` | Global logger level (debug\|info\|warn\|error). Defaults to 'info' | `info` | `--log.level` | `RENTERD_LOG_LEVEL` | `log.level` | | `Log.File.Enabled` | Enables logging to disk. Defaults to 'true' | `true` | `--log.file.enabled` | `RENTERD_LOG_FILE_ENABLED` | `log.file.enabled` | @@ -408,14 +409,14 @@ ghcr.io/siafoundation/renterd. version: "3.9" services: renterd: - image: ghcr.io/siafoundation/renterd:master-zen + image: ghcr.io/siafoundation/renterd:master environment: - RENTERD_SEED=put your seed here - RENTERD_API_PASSWORD=test ports: - - 9880:9880 - - 9881:9881 - - 7070:7070 + - 9980:9980 + - 9981:9981 + - 8080:8080 volumes: - ./data:/data restart: unless-stopped @@ -427,23 +428,16 @@ services: From within the root of the repo run the following command to build an image of `renterd` tagged `renterd`. -#### Mainnet - ```sh docker build -t renterd:master -f ./docker/Dockerfile . ``` -#### Testnet - -```sh -docker build --build-arg BUILD_TAGS='netgo testnet' -t renterd:master-zen -f ./docker/Dockerfile . -``` - ### Run Container Run `renterd` in the background as a container named `renterd` that exposes its API to the host system and the gateway to the world. + #### Mainnet ```bash docker run -d --name renterd -e RENTERD_API_PASSWORD="<password>" -e RENTERD_SEED="<seed>" -p 127.0.0.1:9980:9980/tcp -p :9981:9981/tcp ghcr.io/siafoundation/renterd:master ``` #### Testnet +To run `renterd` on testnet, use the `RENTERD_NETWORK` environment variable. + ```bash -docker run -d --name renterd-testnet -e RENTERD_API_PASSWORD="<password>" -e RENTERD_SEED="<seed>" -p 127.0.0.1:9880:9880/tcp -p :9881:9881/tcp ghcr.io/siafoundation/renterd:master-zen +docker run -d --name renterd -e RENTERD_API_PASSWORD="<password>" -e RENTERD_NETWORK="<network>" -e RENTERD_SEED="<seed>" -p 127.0.0.1:9980:9980/tcp -p :9981:9981/tcp ghcr.io/siafoundation/renterd:master ``` + +Currently available values for `<network>` are: +- `zen` +- `anagami` +
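Since the network is now chosen at runtime instead of at build time, the binary only needs to map the configured name onto a chain configuration. A minimal sketch of that mapping, assuming the `chain.Mainnet`, `chain.TestnetZen` and `chain.TestnetAnagami` helpers from `go.sia.tech/coreutils/chain` (the actual wiring lives in `cmd/renterd` and may differ):

```go
package node

import (
	"fmt"

	"go.sia.tech/core/consensus"
	"go.sia.tech/core/types"
	"go.sia.tech/coreutils/chain"
)

// networkConfig maps a RENTERD_NETWORK value to a consensus network and its
// genesis block; unknown names are rejected instead of silently defaulting.
func networkConfig(name string) (*consensus.Network, types.Block, error) {
	switch name {
	case "", "mainnet":
		n, genesis := chain.Mainnet()
		return n, genesis, nil
	case "zen":
		n, genesis := chain.TestnetZen()
		return n, genesis, nil
	case "anagami":
		n, genesis := chain.TestnetAnagami()
		return n, genesis, nil
	default:
		return nil, types.Block{}, fmt.Errorf("unknown network %q", name)
	}
}
```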
## Architecture `renterd` distinguishes itself from `siad` through a unique architecture @@ -606,10 +606,10 @@ updated using the settings API: { "hostBlockHeightLeeway": 6, // 6 blocks "maxContractPrice": "15000000000000000000000000", // 15 SC per contract - "maxDownloadPrice": "3000000000000000000000000000", // 3000 SC per 1 TiB + "maxDownloadPrice": "3000000000000000000000000000", // 3000 SC per 1 TB "maxRPCPrice": "1000000000000000000000", // 1mS per RPC - "maxStoragePrice": "631593542824", // 3000 SC per TiB per month - "maxUploadPrice": "3000000000000000000000000000", // 3000 SC per 1 TiB + "maxStoragePrice": "631593542824", // 3000 SC per TB per month + "maxUploadPrice": "3000000000000000000000000000", // 3000 SC per 1 TB "migrationSurchargeMultiplier": 10, // overpay up to 10x for sector migrations on critical slabs "minAccountExpiry": 86400000000000, // 1 day "minMaxEphemeralAccountBalance": "1000000000000000000000000", // 1 SC diff --git a/alerts/alerts_test.go b/alerts/alerts_test.go index ff927ccdc..24b299e1b 100644 --- a/alerts/alerts_test.go +++ b/alerts/alerts_test.go @@ -47,7 +47,7 @@ var _ webhooks.WebhookStore = (*testWebhookStore)(nil) func TestWebhooks(t *testing.T) { store := &testWebhookStore{} - mgr, err := webhooks.NewManager(zap.NewNop().Sugar(), store) + mgr, err := webhooks.NewManager(store, zap.NewNop()) if err != nil { t.Fatal(err) } diff --git a/api/autopilot.go b/api/autopilot.go index 9ca917f6e..e81328d88 100644 --- a/api/autopilot.go +++ b/api/autopilot.go @@ -3,9 +3,10 @@ package api import ( "errors" "fmt" + "sort" "go.sia.tech/core/types" - "go.sia.tech/siad/build" + "go.sia.tech/renterd/internal/utils" ) const ( @@ -131,8 +132,19 @@ type ( func (c AutopilotConfig) Validate() error { if c.Hosts.MaxDowntimeHours > 99*365*24 { return ErrMaxDowntimeHoursTooHigh - } else if c.Hosts.MinProtocolVersion != "" && !build.IsVersion(c.Hosts.MinProtocolVersion) { + } else if c.Hosts.MinProtocolVersion != "" && !utils.IsVersion(c.Hosts.MinProtocolVersion) { return fmt.Errorf("invalid min protocol version '%s'", c.Hosts.MinProtocolVersion) } return nil } + +func (c ContractsConfig) SortContractsForMaintenance(contracts []Contract) { + sort.SliceStable(contracts, func(i, j int) bool { + iInSet := contracts[i].InSet(c.Set) + jInSet := contracts[j].InSet(c.Set) + if iInSet != jInSet { + return iInSet + } + return contracts[i].FileSize() > contracts[j].FileSize() + }) +} diff --git a/api/autopilot_test.go b/api/autopilot_test.go new file mode 100644 index 000000000..148fa92e8 --- /dev/null +++ b/api/autopilot_test.go @@ -0,0 +1,61 @@ +package api + +import ( + "reflect" + "testing" + + "go.sia.tech/core/types" +) + +func TestSortContractsForMaintenance(t *testing.T) { + set := "testset" + cfg := ContractsConfig{ + Set: set, + } + + // empty but in set + c1 := Contract{ + ContractMetadata: ContractMetadata{ + ID: types.FileContractID{1}, + Size: 0, + ContractSets: []string{set}, + }, + } + // some data and in set + c2 := Contract{ + ContractMetadata: ContractMetadata{ + ID: types.FileContractID{2}, + Size: 10, + ContractSets: []string{set}, + }, + } +
// same as c2 - sort should be stable + c3 := Contract{ + ContractMetadata: ContractMetadata{ + ID: types.FileContractID{3}, + Size: 10, + ContractSets: []string{set}, + }, + } + // more data but not in set + c4 := Contract{ + ContractMetadata: ContractMetadata{ + ID: types.FileContractID{4}, + Size: 20, + }, + } + // even more data but not in set + c5 := Contract{ + ContractMetadata: ContractMetadata{ + ID: types.FileContractID{5}, + Size: 30, + }, + } + + contracts := []Contract{c1, c2, c3, c4, c5} + cfg.SortContractsForMaintenance(contracts) + + if !reflect.DeepEqual(contracts, []Contract{c2, c3, c1, c5, c4}) { + t.Fatal("unexpected sort order") + } +} diff --git a/api/bus.go b/api/bus.go index 8652a5f49..453af61ca 100644 --- a/api/bus.go +++ b/api/bus.go @@ -48,6 +48,7 @@ type ( // BusStateResponse is the response type for the /bus/state endpoint. BusStateResponse struct { StartTime TimeRFC3339 `json:"startTime"` + Network string `json:"network"` BuildState } ) diff --git a/api/contract.go b/api/contract.go index 94f8c998a..b7d43b6a7 100644 --- a/api/contract.go +++ b/api/contract.go @@ -32,6 +32,8 @@ var ( ErrContractSetNotFound = errors.New("couldn't find contract set") ) +type ContractState string + type ( // A Contract wraps the contract metadata with the latest contract revision. Contract struct { @@ -99,18 +101,23 @@ type ( // that has been moved to the archive either due to expiring or being renewed. ArchivedContract struct { ID types.FileContractID `json:"id"` + HostIP string `json:"hostIP"` HostKey types.PublicKey `json:"hostKey"` RenewedTo types.FileContractID `json:"renewedTo"` Spending ContractSpending `json:"spending"` - ProofHeight uint64 `json:"proofHeight"` - RevisionHeight uint64 `json:"revisionHeight"` - RevisionNumber uint64 `json:"revisionNumber"` - Size uint64 `json:"size"` - StartHeight uint64 `json:"startHeight"` - State string `json:"state"` - WindowStart uint64 `json:"windowStart"` - WindowEnd uint64 `json:"windowEnd"` + ArchivalReason string `json:"archivalReason"` + ContractPrice types.Currency `json:"contractPrice"` + ProofHeight uint64 `json:"proofHeight"` + RenewedFrom types.FileContractID `json:"renewedFrom"` + RevisionHeight uint64 `json:"revisionHeight"` + RevisionNumber uint64 `json:"revisionNumber"` + Size uint64 `json:"size"` + StartHeight uint64 `json:"startHeight"` + State string `json:"state"` + TotalCost types.Currency `json:"totalCost"` + WindowStart uint64 `json:"windowStart"` + WindowEnd uint64 `json:"windowEnd"` } ) @@ -218,3 +225,13 @@ func (c Contract) RemainingCollateral() types.Currency { } return c.Revision.MissedHostPayout().Sub(c.ContractPrice) } + +// InSet returns whether the contract is in the given set. 
+func (cm ContractMetadata) InSet(set string) bool { + for _, s := range cm.ContractSets { + if s == set { + return true + } + } + return false +} diff --git a/api/events.go b/api/events.go index dbfa68a3f..e9600e53b 100644 --- a/api/events.go +++ b/api/events.go @@ -14,8 +14,10 @@ const ( ModuleConsensus = "consensus" ModuleContract = "contract" ModuleContractSet = "contract_set" + ModuleHost = "host" ModuleSetting = "setting" + EventAdd = "add" EventUpdate = "update" EventDelete = "delete" EventArchive = "archive" @@ -33,6 +35,11 @@ type ( Timestamp time.Time `json:"timestamp"` } + EventContractAdd struct { + Added ContractMetadata `json:"added"` + Timestamp time.Time `json:"timestamp"` + } + EventContractArchive struct { ContractID types.FileContractID `json:"contractID"` Reason string `json:"reason"` @@ -44,6 +51,12 @@ type ( Timestamp time.Time `json:"timestamp"` } + EventHostUpdate struct { + HostKey types.PublicKey `json:"hostKey"` + NetAddr string `json:"netAddr"` + Timestamp time.Time `json:"timestamp"` + } + EventContractSetUpdate struct { Name string `json:"name"` ContractIDs []types.FileContractID `json:"contractIDs"` @@ -72,6 +85,15 @@ var ( } } + WebhookContractAdd = func(url string, headers map[string]string) webhooks.Webhook { + return webhooks.Webhook{ + Event: EventAdd, + Headers: headers, + Module: ModuleContract, + URL: url, + } + } + WebhookContractArchive = func(url string, headers map[string]string) webhooks.Webhook { return webhooks.Webhook{ Event: EventArchive, @@ -99,6 +121,15 @@ var ( } } + WebhookHostUpdate = func(url string, headers map[string]string) webhooks.Webhook { + return webhooks.Webhook{ + Event: EventUpdate, + Headers: headers, + Module: ModuleHost, + URL: url, + } + } + WebhookSettingUpdate = func(url string, headers map[string]string) webhooks.Webhook { return webhooks.Webhook{ Event: EventUpdate, @@ -126,6 +157,12 @@ func ParseEventWebhook(event webhooks.Event) (interface{}, error) { switch event.Module { case ModuleContract: switch event.Event { + case EventAdd: + var e EventContractAdd + if err := json.Unmarshal(bytes, &e); err != nil { + return nil, err + } + return e, nil case EventArchive: var e EventContractArchive if err := json.Unmarshal(bytes, &e); err != nil { @@ -155,6 +192,14 @@ func ParseEventWebhook(event webhooks.Event) (interface{}, error) { } return e, nil } + case ModuleHost: + if event.Event == EventUpdate { + var e EventHostUpdate + if err := json.Unmarshal(bytes, &e); err != nil { + return nil, err + } + return e, nil + } case ModuleSetting: switch event.Event { case EventUpdate: diff --git a/api/host.go b/api/host.go index 0dfe6da81..d932229d6 100644 --- a/api/host.go +++ b/api/host.go @@ -148,18 +148,19 @@ func (opts HostsForScanningOptions) Apply(values url.Values) { type ( Host struct { - KnownSince time.Time `json:"knownSince"` - LastAnnouncement time.Time `json:"lastAnnouncement"` - PublicKey types.PublicKey `json:"publicKey"` - NetAddress string `json:"netAddress"` - PriceTable HostPriceTable `json:"priceTable"` - Settings rhpv2.HostSettings `json:"settings"` - Interactions HostInteractions `json:"interactions"` - Scanned bool `json:"scanned"` - Blocked bool `json:"blocked"` - Checks map[string]HostCheck `json:"checks"` - StoredData uint64 `json:"storedData"` - Subnets []string `json:"subnets"` + KnownSince time.Time `json:"knownSince"` + LastAnnouncement time.Time `json:"lastAnnouncement"` + PublicKey types.PublicKey `json:"publicKey"` + NetAddress string `json:"netAddress"` + PriceTable HostPriceTable `json:"priceTable"` + Settings rhpv2.HostSettings `json:"settings"` + Interactions HostInteractions `json:"interactions"` + Scanned bool `json:"scanned"` + Blocked bool `json:"blocked"` + Checks map[string]HostCheck `json:"checks"` + StoredData uint64 `json:"storedData"` + ResolvedAddresses []string `json:"resolvedAddresses"` + Subnets []string `json:"subnets"` } HostAddress struct { @@ -181,12 +182,13 @@ type ( } HostScan struct { - HostKey types.PublicKey `json:"hostKey"` - PriceTable rhpv3.HostPriceTable - Settings rhpv2.HostSettings - Subnets []string - Success bool - Timestamp time.Time + HostKey types.PublicKey `json:"hostKey"` + PriceTable rhpv3.HostPriceTable `json:"priceTable"` + Settings rhpv2.HostSettings `json:"settings"` + ResolvedAddresses []string `json:"resolvedAddresses"` + Subnets []string `json:"subnets"` + Success bool `json:"success"` + Timestamp time.Time `json:"timestamp"` } HostPriceTable struct { @@ -196,9 +198,9 @@ type ( HostPriceTableUpdate struct { HostKey types.PublicKey `json:"hostKey"` - Success bool - Timestamp time.Time - PriceTable HostPriceTable + Success bool `json:"success"` + Timestamp time.Time `json:"timestamp"` + PriceTable HostPriceTable `json:"priceTable"` } HostCheck struct {
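// The webhook additions above are easiest to see end to end with a small
// receiver. A minimal sketch, assuming webhooks.Event is the JSON envelope
// that renterd POSTs to registered webhook URLs; only the names taken from
// this diff are real, the receiver itself is hypothetical.
package main

import (
	"encoding/json"
	"log"
	"net/http"

	"go.sia.tech/renterd/api"
	"go.sia.tech/renterd/webhooks"
)

func main() {
	http.HandleFunc("/events", func(w http.ResponseWriter, r *http.Request) {
		// decode the generic event envelope first
		var event webhooks.Event
		if err := json.NewDecoder(r.Body).Decode(&event); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// then let ParseEventWebhook turn it into a typed event
		parsed, err := api.ParseEventWebhook(event)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		switch e := parsed.(type) {
		case api.EventContractAdd:
			log.Printf("contract %v added", e.Added.ID)
		case api.EventHostUpdate:
			log.Printf("host %v announced %v", e.HostKey, e.NetAddr)
		default:
			log.Printf("unhandled event %T", e)
		}
	})
	log.Fatal(http.ListenAndServe(":9000", nil))
}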
diff --git a/api/metrics.go b/api/metrics.go index 412435e34..98c6f06b0 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -110,6 +110,7 @@ type ( Confirmed types.Currency `json:"confirmed"` Spendable types.Currency `json:"spendable"` Unconfirmed types.Currency `json:"unconfirmed"` + Immature types.Currency `json:"immature"` } WalletMetricsQueryOpts struct{} diff --git a/api/setting.go b/api/setting.go index ff93550ef..5976b00b2 100644 --- a/api/setting.go +++ b/api/setting.go @@ -32,6 +32,57 @@ var ( // ErrSettingNotFound is returned if a requested setting is not present in the // database. ErrSettingNotFound = errors.New("setting not found") + + // DefaultGougingSettings define the default gouging settings the bus is + // configured with on startup. These values can be adjusted using the + // settings API. + // + DefaultGougingSettings = GougingSettings{ + MaxRPCPrice: types.Siacoins(1).Div64(1000), // 1mS per RPC + MaxContractPrice: types.Siacoins(15), // 15 SC per contract + MaxDownloadPrice: types.Siacoins(3000), // 3000 SC per 1 TB + MaxUploadPrice: types.Siacoins(3000), // 3000 SC per 1 TB + MaxStoragePrice: types.Siacoins(3000).Div64(1e12).Div64(144 * 30), // 3000 SC per TB per month + HostBlockHeightLeeway: 6, // 6 blocks + MinPriceTableValidity: 5 * time.Minute, // 5 minutes + MinAccountExpiry: 24 * time.Hour, // 1 day + MinMaxEphemeralAccountBalance: types.Siacoins(1), // 1 SC + MigrationSurchargeMultiplier: 10, // 10x + } + + // DefaultPricePinSettings define the default price pin settings the bus is + // configured with on startup. These values can be adjusted using the + // settings API. + DefaultPricePinSettings = PricePinSettings{ + Enabled: false, + Currency: "usd", + ForexEndpointURL: "https://api.siascan.com/exchange-rate/siacoin", + Threshold: 0.05, + } + + // DefaultUploadPackingSettings define the default upload packing settings + // the bus is configured with on startup. + DefaultUploadPackingSettings = UploadPackingSettings{ + Enabled: true, + SlabBufferMaxSizeSoft: 1 << 32, // 4 GiB + } + + // DefaultRedundancySettings define the default redundancy settings the bus + // is configured with on startup. These values can be adjusted using the + // settings API. + // + // NOTE: default redundancy settings for testnet are different from mainnet.
+ DefaultRedundancySettings = RedundancySettings{ + MinShards: 10, + TotalShards: 30, + } + + // Same as DefaultRedundancySettings but for running on testnet networks due + // to their reduced number of hosts. + DefaultRedundancySettingsTestnet = RedundancySettings{ + MinShards: 2, + TotalShards: 6, + } ) type ( @@ -49,10 +100,10 @@ type ( // MaxContractPrice is the maximum allowed price to form a contract MaxContractPrice types.Currency `json:"maxContractPrice"` - // MaxDownloadPrice is the maximum allowed price to download 1TiB of data + // MaxDownloadPrice is the maximum allowed price to download 1TB of data MaxDownloadPrice types.Currency `json:"maxDownloadPrice"` - // MaxUploadPrice is the maximum allowed price to upload 1TiB of data + // MaxUploadPrice is the maximum allowed price to upload 1TB of data MaxUploadPrice types.Currency `json:"maxUploadPrice"` // MaxStoragePrice is the maximum allowed price to store 1 byte per block @@ -102,11 +153,11 @@ type ( Threshold float64 `json:"threshold"` // Autopilots contains the pinned settings for every autopilot. - Autopilots map[string]AutopilotPins `json:"autopilots,omitempty"` + Autopilots map[string]AutopilotPins `json:"autopilots"` // GougingSettingsPins contains the pinned settings for the gouging // settings. - GougingSettingsPins GougingSettingsPins `json:"gougingSettingsPins,omitempty"` + GougingSettingsPins GougingSettingsPins `json:"gougingSettingsPins"` } // AutopilotPins contains the available autopilot settings that can be @@ -119,7 +170,6 @@ type ( // pinned. GougingSettingsPins struct { MaxDownload Pin `json:"maxDownload"` - MaxRPCPrice Pin `json:"maxRPCPrice"` MaxStorage Pin `json:"maxStorage"` MaxUpload Pin `json:"maxUpload"` } @@ -155,9 +205,6 @@ func (p Pin) IsPinned() bool { // Validate returns an error if the price pin settings are not considered valid. func (pps PricePinSettings) Validate() error { - if !pps.Enabled { - return nil - } if pps.ForexEndpointURL == "" { return fmt.Errorf("price pin settings must have a forex endpoint URL") } diff --git a/api/slab.go b/api/slab.go index 1a5d3fc79..65d19788d 100644 --- a/api/slab.go +++ b/api/slab.go @@ -74,6 +74,6 @@ type ( } ) -func (s UploadedPackedSlab) Contracts() map[types.PublicKey]map[types.FileContractID]struct{} { +func (s UploadedPackedSlab) Contracts() []types.FileContractID { return object.ContractsFromShards(s.Shards) } diff --git a/api/state.go b/api/state.go index 63faa95ea..ff292e177 100644 --- a/api/state.go +++ b/api/state.go @@ -3,7 +3,6 @@ package api type ( // BuildState contains static information about the build. BuildState struct { - Network string `json:"network"` Version string `json:"version"` Commit string `json:"commit"` OS string `json:"os"` diff --git a/api/wallet.go b/api/wallet.go index 80da310d3..510e7b95b 100644 --- a/api/wallet.go +++ b/api/wallet.go @@ -10,6 +10,26 @@ import ( "go.sia.tech/core/types" ) +type ( + // A SiacoinElement is a SiacoinOutput along with its ID. + SiacoinElement struct { + types.SiacoinOutput + ID types.Hash256 `json:"id"` + MaturityHeight uint64 `json:"maturityHeight"` + } + + // A Transaction is an on-chain transaction relevant to a particular wallet, + // paired with useful metadata. 
+ Transaction struct { + Raw types.Transaction `json:"raw,omitempty"` + Index types.ChainIndex `json:"index"` + ID types.TransactionID `json:"id"` + Inflow types.Currency `json:"inflow"` + Outflow types.Currency `json:"outflow"` + Timestamp time.Time `json:"timestamp"` + } +) + type ( // WalletFundRequest is the request type for the /wallet/fund endpoint. WalletFundRequest struct { @@ -75,6 +95,14 @@ type ( Spendable types.Currency `json:"spendable"` Confirmed types.Currency `json:"confirmed"` Unconfirmed types.Currency `json:"unconfirmed"` + Immature types.Currency `json:"immature"` + } + + WalletSendRequest struct { + Address types.Address `json:"address"` + Amount types.Currency `json:"amount"` + SubtractMinerFee bool `json:"subtractMinerFee"` + UseUnconfirmed bool `json:"useUnconfirmed"` } // WalletSignRequest is the request type for the /wallet/sign endpoint. diff --git a/api/worker.go b/api/worker.go index ae6024b84..894fd0c60 100644 --- a/api/worker.go +++ b/api/worker.go @@ -27,10 +27,6 @@ var ( // be scanned since it is on a private network. ErrHostOnPrivateNetwork = errors.New("host is on a private network") - // ErrHostTooManyAddresses is returned by the worker API when a host has - // more than two addresses of the same type. - ErrHostTooManyAddresses = errors.New("host has more than two addresses, or two of the same type") - // ErrMultiRangeNotSupported is returned by the worker API when a request // tries to download multiple ranges at once. ErrMultiRangeNotSupported = errors.New("multipart ranges are not supported") @@ -302,3 +298,10 @@ func ParseDownloadRange(req *http.Request) (DownloadRange, error) { } return dr, nil } + +func (r RHPScanResponse) Error() error { + if r.ScanError != "" { + return errors.New(r.ScanError) + } + return nil +} diff --git a/autopilot/accounts.go b/autopilot/accounts.go index a081dd09d..a1422d69a 100644 --- a/autopilot/accounts.go +++ b/autopilot/accounts.go @@ -30,7 +30,8 @@ type accounts struct { l *zap.SugaredLogger w *workerPool - refillInterval time.Duration + refillInterval time.Duration + revisionSubmissionBuffer uint64 mu sync.Mutex inProgressRefills map[types.Hash256]struct{} @@ -45,7 +46,7 @@ type ContractStore interface { Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) } -func newAccounts(ap *Autopilot, a AccountStore, c ContractStore, w *workerPool, l *zap.SugaredLogger, refillInterval time.Duration) *accounts { +func newAccounts(ap *Autopilot, a AccountStore, c ContractStore, w *workerPool, l *zap.SugaredLogger, refillInterval time.Duration, revisionSubmissionBuffer uint64) *accounts { return &accounts{ ap: ap, a: a, @@ -53,8 +54,9 @@ func newAccounts(ap *Autopilot, a AccountStore, c ContractStore, w *workerPool, l: l.Named("accounts"), w: w, - refillInterval: refillInterval, - inProgressRefills: make(map[types.Hash256]struct{}), + refillInterval: refillInterval, + revisionSubmissionBuffer: revisionSubmissionBuffer, + inProgressRefills: make(map[types.Hash256]struct{}), } } @@ -110,6 +112,13 @@ func (a *accounts) refillWorkerAccounts(ctx context.Context, w Worker) { return } + // fetch consensus state + cs, err := a.ap.bus.ConsensusState(ctx) + if err != nil { + a.l.Errorw(fmt.Sprintf("failed to fetch consensus state for refill: %v", err)) + return + } + // fetch worker id workerID, err := w.ID(ctx) if err != nil { @@ -126,17 +135,14 @@ func (a *accounts) refillWorkerAccounts(ctx context.Context, w Worker) { return } - // fetch all contract set contracts - contractSetContracts, err := 
a.c.Contracts(ctx, api.ContractsOpts{ContractSet: cfg.Config.Contracts.Set}) - if err != nil { - a.l.Errorw(fmt.Sprintf("failed to fetch contract set contracts: %v", err)) - return - } - - // build a map of contract set contracts + // filter all contract set contracts + var contractSetContracts []api.ContractMetadata inContractSet := make(map[types.FileContractID]struct{}) - for _, contract := range contractSetContracts { - inContractSet[contract.ID] = struct{}{} + for _, c := range contracts { + if c.InSet(cfg.Config.Contracts.Set) { + contractSetContracts = append(contractSetContracts, c) + inContractSet[c.ID] = struct{}{} + } } // refill accounts in separate goroutines @@ -144,9 +150,11 @@ // launch refill if not already in progress if a.markRefillInProgress(workerID, c.HostKey) { go func(contract api.ContractMetadata) { + defer a.markRefillDone(workerID, contract.HostKey) + rCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() - accountID, refilled, rerr := refillWorkerAccount(rCtx, a.a, w, contract) + accountID, refilled, rerr := refillWorkerAccount(rCtx, a.a, w, contract, cs.BlockHeight, a.revisionSubmissionBuffer) if rerr != nil { if rerr.Is(errMaxDriftExceeded) { // register the alert if error is errMaxDriftExceeded @@ -170,8 +178,6 @@ ) } } - - a.markRefillDone(workerID, contract.HostKey) }(c) } } @@ -193,7 +199,7 @@ func (err *refillError) Is(target error) bool { return errors.Is(err.err, target) } -func refillWorkerAccount(ctx context.Context, a AccountStore, w Worker, contract api.ContractMetadata) (accountID rhpv3.Account, refilled bool, rerr *refillError) { +func refillWorkerAccount(ctx context.Context, a AccountStore, w Worker, contract api.ContractMetadata, bh, revisionSubmissionBuffer uint64) (accountID rhpv3.Account, refilled bool, rerr *refillError) { wrapErr := func(err error, keysAndValues ...interface{}) *refillError { if err == nil { return nil @@ -217,6 +223,18 @@ return } + // check if the contract is too close to the proof window to be revised; + // trying to refill the account would result in the host not returning the + // revision and returning an obfuscated error + if (bh + revisionSubmissionBuffer) > contract.WindowStart { + rerr = wrapErr(fmt.Errorf("not refilling account since contract is too close to the proof window to be revised (%v > %v)", bh+revisionSubmissionBuffer, contract.WindowStart), + "accountID", account.ID, + "hostKey", contract.HostKey, + "blockHeight", bh, + ) + return + } + // check if a host is potentially cheating before refilling. // We only check against the max drift if the account's drift is // negative because we don't care if we have more money than
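// A minimal sketch of the proof-window guard introduced in
// refillWorkerAccount above; hypothetical helper, not part of the diff. With
// a revision submission buffer of b blocks, a contract can only be safely
// revised (and therefore refilled) while bh + b <= WindowStart.
func canRefillAccount(bh, b, windowStart uint64) bool {
	return bh+b <= windowStart
}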
diff --git a/autopilot/alerts.go b/autopilot/alerts.go index 7f496798c..1d089c39d 100644 --- a/autopilot/alerts.go +++ b/autopilot/alerts.go @@ -14,6 +14,7 @@ import ( var ( alertAccountRefillID = alerts.RandomAlertID() // constant until restarted + alertHealthRefreshID = alerts.RandomAlertID() // constant until restarted alertLowBalanceID = alerts.RandomAlertID() // constant until restarted alertMigrationID = alerts.RandomAlertID() // constant until restarted alertPruningID = alerts.RandomAlertID() // constant until restarted @@ -166,12 +167,11 @@ func newMigrationFailedAlert(slabKey object.EncryptionKey, health float64, objec func newRefreshHealthFailedAlert(err error) alerts.Alert { return alerts.Alert{ - ID: alerts.RandomAlertID(), + ID: alertHealthRefreshID, Severity: alerts.SeverityCritical, Message: "Health refresh failed", Data: map[string]interface{}{ - "migrationsInterrupted": true, - "error": err.Error(), + "error": err.Error(), }, Timestamp: time.Now(), } diff --git a/autopilot/autopilot.go b/autopilot/autopilot.go index 263744ca5..58fb0a9ec 100644 --- a/autopilot/autopilot.go +++ b/autopilot/autopilot.go @@ -17,10 +17,11 @@ import ( "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" "go.sia.tech/renterd/autopilot/contractor" + "go.sia.tech/renterd/autopilot/scanner" "go.sia.tech/renterd/build" + "go.sia.tech/renterd/config" "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" - "go.sia.tech/renterd/wallet" "go.sia.tech/renterd/webhooks" "go.uber.org/zap" ) @@ -86,7 +87,7 @@ type Bus interface { // wallet Wallet(ctx context.Context) (api.WalletResponse, error) WalletDiscard(ctx context.Context, txn types.Transaction) error - WalletOutputs(ctx context.Context) (resp []wallet.SiacoinElement, err error) + WalletOutputs(ctx context.Context) (resp []api.SiacoinElement, err error) WalletPending(ctx context.Context) (resp []types.Transaction, err error) WalletRedistribute(ctx context.Context, outputs int, amount types.Currency) (ids []types.TransactionID, err error) } @@ -102,7 +103,7 @@ type Autopilot struct { a *accounts c *contractor.Contractor m *migrator - s *scanner + s scanner.Scanner tickerDuration time.Duration wg sync.WaitGroup @@ -123,39 +124,32 @@ type Autopilot struct { } // New initializes an Autopilot.
-func New(id string, bus Bus, workers []Worker, logger *zap.Logger, heartbeat time.Duration, scannerScanInterval time.Duration, scannerBatchSize, scannerNumThreads uint64, migrationHealthCutoff float64, accountsRefillInterval time.Duration, revisionSubmissionBuffer, migratorParallelSlabsPerWorker uint64, revisionBroadcastInterval time.Duration) (*Autopilot, error) { +func New(cfg config.Autopilot, bus Bus, workers []Worker, logger *zap.Logger) (_ *Autopilot, err error) { + logger = logger.Named("autopilot").Named(cfg.ID) shutdownCtx, shutdownCtxCancel := context.WithCancel(context.Background()) - ap := &Autopilot{ - alerts: alerts.WithOrigin(bus, fmt.Sprintf("autopilot.%s", id)), - id: id, + ap := &Autopilot{ + alerts: alerts.WithOrigin(bus, fmt.Sprintf("autopilot.%s", cfg.ID)), + id: cfg.ID, bus: bus, - logger: logger.Sugar().Named("autopilot").Named(id), + logger: logger.Sugar(), workers: newWorkerPool(workers), shutdownCtx: shutdownCtx, shutdownCtxCancel: shutdownCtxCancel, - tickerDuration: heartbeat, + tickerDuration: cfg.Heartbeat, pruningAlertIDs: make(map[types.FileContractID]types.Hash256), } - scanner, err := newScanner( - ap, - scannerBatchSize, - scannerNumThreads, - scannerScanInterval, - scannerTimeoutInterval, - scannerTimeoutMinTimeout, - ) + + ap.s, err = scanner.New(ap.bus, cfg.ScannerBatchSize, cfg.ScannerNumThreads, cfg.ScannerInterval, logger) if err != nil { - return nil, err + return } - ap.s = scanner - ap.c = contractor.New(bus, bus, ap.logger, revisionSubmissionBuffer, revisionBroadcastInterval) - ap.m = newMigrator(ap, migrationHealthCutoff, migratorParallelSlabsPerWorker) - ap.a = newAccounts(ap, ap.bus, ap.bus, ap.workers, ap.logger, accountsRefillInterval) + ap.c = contractor.New(bus, bus, ap.logger, cfg.RevisionSubmissionBuffer, cfg.RevisionBroadcastInterval) + ap.m = newMigrator(ap, cfg.MigrationHealthCutoff, cfg.MigratorParallelSlabsPerWorker) + ap.a = newAccounts(ap, ap.bus, ap.bus, ap.workers, ap.logger, cfg.AccountsRefillInterval, cfg.RevisionSubmissionBuffer) return ap, nil } @@ -217,11 +211,11 @@ func (ap *Autopilot) configHandlerPOST(jc jape.Context) { jc.Encode(res) } -func (ap *Autopilot) Run() error { +func (ap *Autopilot) Run() { ap.startStopMu.Lock() if ap.isRunning() { ap.startStopMu.Unlock() - return errors.New("already running") + return } ap.startTime = time.Now() ap.triggerChan = make(chan bool, 1) @@ -234,7 +228,7 @@ // block until the autopilot is online if online := ap.blockUntilOnline(); !online { ap.logger.Error("autopilot stopped before it was able to come online") - return nil + return } // schedule a trigger when the wallet receives its first deposit @@ -242,7 +236,7 @@ if !errors.Is(err, context.Canceled) { ap.logger.Error(err) } - return nil + return } var forceScan bool @@ -254,10 +248,9 @@ defer ap.logger.Info("autopilot iteration ended") // initiate a host scan - no need to be synced or configured for scanning - ap.s.tryUpdateTimeout() - ap.s.tryPerformHostScan(ap.shutdownCtx, w, forceScan) + ap.s.Scan(ap.shutdownCtx, w, forceScan) // reset forceScan forceScan = false // block until consensus is synced @@ -270,7 +263,7 @@ return } else if blocked { if scanning, _ := ap.s.Status(); !scanning { - ap.s.tryPerformHostScan(ap.shutdownCtx, w, true) + ap.s.Scan(ap.shutdownCtx, w, true) } } @@ -291,8 +284,8 @@ return } - // prune hosts that have been offline for too long - ap.s.PruneHosts(ap.shutdownCtx, autopilot.Config.Hosts) + // update the scanner with the hosts config + ap.s.UpdateHostsConfig(autopilot.Config.Hosts) // Log worker id chosen for this maintenance iteration. workerID, err := w.ID(ap.shutdownCtx) @@ -351,7 +344,7 @@ select { case <-ap.shutdownCtx.Done(): - return nil + return case forceScan = <-ap.triggerChan: ap.logger.Info("autopilot iteration triggered") ap.ticker.Reset(ap.tickerDuration) @@ -359,11 +352,11 @@ case <-tickerFired: } } - return nil } // Shutdown shuts down the autopilot. -func (ap *Autopilot) Shutdown(_ context.Context) error { +func (ap *Autopilot) Shutdown(ctx context.Context) error { ap.startStopMu.Lock() defer ap.startStopMu.Unlock() @@ -372,6 +365,7 @@ ap.shutdownCtxCancel() close(ap.triggerChan) ap.wg.Wait() + ap.s.Shutdown(ctx) ap.startTime = time.Time{} } return nil } @@ -698,8 +692,16 @@ func (ap *Autopilot) configHandlerPUT(jc jape.Context) { autopilot.Config = cfg } - // update the autopilot and interrupt migrations if necessary - if err := jc.Check("failed to update autopilot config", ap.bus.UpdateAutopilot(jc.Request.Context(), autopilot)); err == nil && contractSetChanged { + // update the autopilot + if jc.Check("failed to update autopilot config", ap.bus.UpdateAutopilot(jc.Request.Context(), autopilot)) != nil { + return + } + + // update the scanner with the hosts config + ap.s.UpdateHostsConfig(cfg.Hosts) + + // interrupt migrations if necessary + if contractSetChanged { ap.m.SignalMaintenanceFinished() } } @@ -825,7 +827,6 @@ func (ap *Autopilot) stateHandlerGET(jc jape.Context) { StartTime: api.TimeRFC3339(ap.StartTime()), BuildState: api.BuildState{ - Network: build.NetworkName(), Version: build.Version(), Commit: build.Commit(), OS: runtime.GOOS,
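// The scanner the autopilot now delegates to, reconstructed from its call
// sites in this file as a hedged sketch; the real interface lives in
// autopilot/scanner and may differ (e.g. in Status' second return value).
type Scanner interface {
	Scan(ctx context.Context, w Worker, force bool)
	Status() (scanning bool, lastScanStart time.Time)
	UpdateHostsConfig(cfg api.HostsConfig)
	Shutdown(ctx context.Context) error
}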
diff --git a/autopilot/contract_pruning.go b/autopilot/contract_pruning.go index 2f491249b..7822fb326 100644 --- a/autopilot/contract_pruning.go +++ b/autopilot/contract_pruning.go @@ -10,7 +10,6 @@ import ( "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/utils" - "go.sia.tech/siad/build" "go.uber.org/zap" ) @@ -238,7 +237,7 @@ func humanReadableSize(b int) string { } func shouldSendPruneAlert(err error, version, release string) bool { - oldHost := (build.VersionCmp(version, "1.6.0") < 0 || version == "1.6.0" && release == "") + oldHost := (utils.VersionCmp(version, "1.6.0") < 0 || version == "1.6.0" && release == "") sectorRootsIssue := utils.IsErr(err, errInvalidSectorRootsRange) && oldHost merkleRootIssue := utils.IsErr(err, errInvalidMerkleProof) && oldHost return err != nil && !(sectorRootsIssue || merkleRootIssue || diff --git a/autopilot/contractor/alerts.go b/autopilot/contractor/alerts.go index 3505667ad..9185c88dd 100644 --- a/autopilot/contractor/alerts.go +++ b/autopilot/contractor/alerts.go @@ -21,9 +21,9 @@ var ( alertRenewalFailedID = alerts.RandomAlertID() // constant until restarted ) -func newContractRenewalFailedAlert(contract api.ContractMetadata, interrupted bool, err error) alerts.Alert { +func newContractRenewalFailedAlert(contract api.ContractMetadata, ourFault bool, err error) alerts.Alert { severity := alerts.SeverityWarning - if interrupted { + if ourFault { severity = alerts.SeverityCritical } return alerts.Alert{ ID: alerts.IDForContract(alertRenewalFailedID, contract.ID), Severity: severity, Message: "Contract renewal failed",
Data: map[string]interface{}{ - "error": err.Error(), - "renewalsInterrupted": interrupted, - "contractID": contract.ID.String(), - "hostKey": contract.HostKey.String(), + "error": err.Error(), + "hostError": !ourFault, + "contractID": contract.ID.String(), + "hostKey": contract.HostKey.String(), }, Timestamp: time.Now(), } diff --git a/autopilot/contractor/contract_spending.go b/autopilot/contractor/contract_spending.go index 82e08831c..54985a130 100644 --- a/autopilot/contractor/contract_spending.go +++ b/autopilot/contractor/contract_spending.go @@ -1,11 +1,13 @@ package contractor import ( + "context" + "go.sia.tech/core/types" "go.sia.tech/renterd/api" ) -func (c *Contractor) currentPeriodSpending(contracts []api.Contract, currentPeriod uint64) types.Currency { +func currentPeriodSpending(contracts []api.ContractMetadata, currentPeriod uint64) types.Currency { totalCosts := make(map[types.FileContractID]types.Currency) for _, c := range contracts { totalCosts[c.ID] = c.TotalCost @@ -15,7 +17,7 @@ func (c *Contractor) currentPeriodSpending(contracts []api.Contract, currentPeri var filtered []api.ContractMetadata for _, contract := range contracts { if contract.WindowStart <= currentPeriod { - filtered = append(filtered, contract.ContractMetadata) + filtered = append(filtered, contract) } } @@ -27,14 +29,19 @@ func (c *Contractor) currentPeriodSpending(contracts []api.Contract, currentPeri return totalAllocated } -func (c *Contractor) remainingFunds(contracts []api.Contract, state *MaintenanceState) types.Currency { +func remainingAllowance(ctx context.Context, bus Bus, state *MaintenanceState) (types.Currency, error) { + contracts, err := bus.Contracts(ctx, api.ContractsOpts{}) + if err != nil { + return types.Currency{}, err + } + // find out how much we spent in the current period - spent := c.currentPeriodSpending(contracts, state.Period()) + spent := currentPeriodSpending(contracts, state.Period()) // figure out remaining funds var remaining types.Currency if state.Allowance().Cmp(spent) > 0 { remaining = state.Allowance().Sub(spent) } - return remaining + return remaining, nil } diff --git a/autopilot/contractor/contractor.go b/autopilot/contractor/contractor.go index ca0c5c9df..e82253d43 100644 --- a/autopilot/contractor/contractor.go +++ b/autopilot/contractor/contractor.go @@ -2,6 +2,7 @@ package contractor import ( "context" + "encoding/hex" "errors" "fmt" "math" @@ -13,13 +14,13 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" - cwallet "go.sia.tech/coreutils/wallet" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" "go.sia.tech/renterd/internal/utils" - "go.sia.tech/renterd/wallet" - "go.sia.tech/renterd/worker" "go.uber.org/zap" + "lukechampine.com/frand" ) const ( @@ -37,10 +38,6 @@ const ( // punishing a contract for not being able to refresh failedRefreshForgivenessPeriod = 24 * time.Hour - // leewayPctCandidateHosts is the leeway we apply when fetching candidate - // hosts, we fetch ~10% more than required - leewayPctCandidateHosts = 1.1 - // leewayPctRequiredContracts is the leeway we apply on the amount of // contracts the config dictates we should have, we'll only form new // contracts if the number of contracts dips below 90% of the required @@ -89,6 +86,7 @@ type Bus interface { AncestorContracts(ctx context.Context, id types.FileContractID, minStartHeight uint64) ([]api.ArchivedContract, error) ArchiveContracts(ctx context.Context, 
toArchive map[types.FileContractID]string) error ConsensusState(ctx context.Context) (api.ConsensusState, error) + Contract(ctx context.Context, id types.FileContractID) (api.ContractMetadata, error) Contracts(ctx context.Context, opts api.ContractsOpts) (contracts []api.ContractMetadata, err error) FileContractTax(ctx context.Context, payout types.Currency) (types.Currency, error) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) @@ -107,6 +105,22 @@ type Worker interface { RHPScan(ctx context.Context, hostKey types.PublicKey, hostIP string, timeout time.Duration) (api.RHPScanResponse, error) } +type contractChecker interface { + isUsableContract(cfg api.AutopilotConfig, s rhpv2.HostSettings, pt rhpv3.HostPriceTable, rs api.RedundancySettings, contract api.Contract, inSet bool, bh uint64, f *hostSet) (usable, refresh, renew bool, reasons []string) + pruneContractRefreshFailures(contracts []api.ContractMetadata) + shouldArchive(c api.Contract, bh uint64) error +} + +type contractReviser interface { + formContract(ctx *mCtx, w Worker, host api.Host, minInitialContractFunds, maxInitialContractFunds types.Currency, budget *types.Currency, logger *zap.SugaredLogger) (cm api.ContractMetadata, ourFault bool, err error) + renewContract(ctx *mCtx, w Worker, c api.Contract, h api.Host, budget *types.Currency, logger *zap.SugaredLogger) (cm api.ContractMetadata, ourFault bool, err error) + refreshContract(ctx *mCtx, w Worker, c api.Contract, h api.Host, budget *types.Currency, logger *zap.SugaredLogger) (cm api.ContractMetadata, ourFault bool, err error) +} + +type revisionBroadcaster interface { + broadcastRevisions(ctx context.Context, w Worker, contracts []api.ContractMetadata, logger *zap.SugaredLogger) +} + type ( Contractor struct { alerter alerts.Alerter @@ -126,16 +140,10 @@ type ( scoredHost struct { host api.Host + sb api.HostScoreBreakdown score float64 } - contractInfo struct { - host api.Host - contract api.Contract - usable bool - recoverable bool - } - contractSetAdditions struct { HostKey types.PublicKey `json:"hostKey"` Additions []contractSetAddition `json:"additions"` @@ -156,12 +164,6 @@ type ( Reason string `json:"reasons"` Time api.TimeRFC3339 `json:"time"` } - - renewal struct { - from api.ContractMetadata - to api.ContractMetadata - ci contractInfo - } ) func New(bus Bus, alerter alerts.Alerter, logger *zap.SugaredLogger, revisionSubmissionBuffer uint64, revisionBroadcastInterval time.Duration) *Contractor { @@ -189,292 +191,472 @@ func (c *Contractor) Close() error { return nil } -func canSkipContractMaintenance(ctx context.Context, cfg api.ContractsConfig) (string, bool) { - select { - case <-ctx.Done(): - return "", true - default: - } - - // no maintenance if no hosts are requested - // - // NOTE: this is an important check because we assume Contracts.Amount is - // not zero in several places - if cfg.Amount == 0 { - return "contracts is set to zero, skipping contract maintenance", true - } - - // no maintenance if no allowance was set - if cfg.Allowance.IsZero() { - return "allowance is set to zero, skipping contract maintenance", true - } - - // no maintenance if no period was set - if cfg.Period == 0 { - return "period is set to zero, skipping contract maintenance", true - } - return "", false -} - func (c *Contractor) PerformContractMaintenance(ctx context.Context, w Worker, state *MaintenanceState) (bool, error) { - return c.performContractMaintenance(newMaintenanceCtx(ctx, state), w) + return performContractMaintenance(newMaintenanceCtx(ctx, 
state), c.alerter, c.bus, c.churn, w, c, c, c, c.logger) } -func (c *Contractor) performContractMaintenance(ctx *mCtx, w Worker) (bool, error) { - mCtx := newMaintenanceCtx(ctx, ctx.state) +func (c *Contractor) formContract(ctx *mCtx, w Worker, host api.Host, minInitialContractFunds, maxInitialContractFunds types.Currency, budget *types.Currency, logger *zap.SugaredLogger) (cm api.ContractMetadata, proceed bool, err error) { + logger = logger.With("hk", host.PublicKey, "hostVersion", host.Settings.Version, "hostRelease", host.Settings.Release) - // check if we can skip maintenance - if reason, skip := canSkipContractMaintenance(ctx, ctx.ContractsConfig()); skip { - if reason != "" { - c.logger.Warn(reason) - } - if skip { - return false, nil - } + // convenience variables + hk := host.PublicKey + + // fetch host settings + scan, err := w.RHPScan(ctx, hk, host.NetAddress, 0) + if err != nil { + logger.Infow(err.Error(), "hk", hk) + return api.ContractMetadata{}, true, err } - c.logger.Info("performing contract maintenance") - // fetch current contract set - currentSet, err := c.bus.Contracts(ctx, api.ContractsOpts{ContractSet: ctx.ContractSet()}) - if err != nil && !strings.Contains(err.Error(), api.ErrContractSetNotFound.Error()) { - return false, err + // fetch consensus state + cs, err := c.bus.ConsensusState(ctx) + if err != nil { + return api.ContractMetadata{}, false, err } - hasContractInSet := make(map[types.PublicKey]types.FileContractID) - isInCurrentSet := make(map[types.FileContractID]struct{}) - for _, c := range currentSet { - hasContractInSet[c.HostKey] = c.ID - isInCurrentSet[c.ID] = struct{}{} + + // check our budget + txnFee := ctx.state.Fee.Mul64(estimatedFileContractTransactionSetSize) + renterFunds := initialContractFunding(scan.Settings, txnFee, minInitialContractFunds, maxInitialContractFunds) + if budget.Cmp(renterFunds) < 0 { + logger.Infow("insufficient budget", "budget", budget, "needed", renterFunds) + return api.ContractMetadata{}, false, errors.New("insufficient budget") } - c.logger.Infof("contract set '%s' holds %d contracts", ctx.ContractSet(), len(currentSet)) - // fetch all contracts from the worker - start := time.Now() - resp, err := w.Contracts(ctx, timeoutHostRevision) + // calculate the host collateral + endHeight := ctx.EndHeight() + expectedStorage := renterFundsToExpectedStorage(renterFunds, endHeight-cs.BlockHeight, scan.PriceTable) + hostCollateral := rhpv2.ContractFormationCollateral(ctx.Period(), expectedStorage, scan.Settings) + + // form contract + contract, _, err := w.RHPForm(ctx, endHeight, hk, host.NetAddress, ctx.state.Address, renterFunds, hostCollateral) if err != nil { - return false, err - } - if resp.Errors != nil { - for pk, err := range resp.Errors { - c.logger.With("hostKey", pk).With("error", err).Warn("failed to fetch revision") + // TODO: keep track of consecutive failures and break at some point + logger.Errorw(fmt.Sprintf("contract formation failed, err: %v", err), "hk", hk) + if utils.IsErr(err, wallet.ErrNotEnoughFunds) { + return api.ContractMetadata{}, false, err } + return api.ContractMetadata{}, true, err } - contracts := resp.Contracts - c.logger.Infof("fetched %d contracts from the worker, took %v", len(resp.Contracts), time.Since(start)) - // prune contract refresh failure map - c.pruneContractRefreshFailures(contracts) + // update the budget + *budget = budget.Sub(renterFunds) - // run revision broadcast - c.runRevisionBroadcast(ctx, w, contracts, isInCurrentSet) + // persist contract in store + contractPrice := 
contract.Revision.MissedHostPayout().Sub(hostCollateral) + formedContract, err := c.bus.AddContract(ctx, contract, contractPrice, renterFunds, cs.BlockHeight, api.ContractStatePending) + if err != nil { + logger.Errorw(fmt.Sprintf("contract formation failed, err: %v", err), "hk", hk) + return api.ContractMetadata{}, true, err + } + logger.Infow("formation succeeded", + "fcid", formedContract.ID, + "renterFunds", renterFunds.String(), + "collateral", hostCollateral.String(), + ) + return formedContract, true, nil +} -func (c *Contractor) initialContractFunding(settings rhpv2.HostSettings, txnFee, minFunding, maxFunding types.Currency) types.Currency { +func initialContractFunding(settings rhpv2.HostSettings, txnFee, minFunding, maxFunding types.Currency) types.Currency { if !maxFunding.IsZero() && minFunding.Cmp(maxFunding) > 0 { panic("given min is larger than max") // developer error } funding := settings.ContractPrice.Add(txnFee).Mul64(10) // TODO arbitrary multiplier if !minFunding.IsZero() && funding.Cmp(minFunding) < 0 { return minFunding } if !maxFunding.IsZero() && funding.Cmp(maxFunding) > 0 { return maxFunding } return funding } func (c *Contractor) pruneContractRefreshFailures(contracts []api.ContractMetadata) { contractMap := make(map[types.FileContractID]struct{}) for _, contract := range contracts { contractMap[contract.ID] = struct{}{} } for fcid := range c.firstRefreshFailure { if _, ok := contractMap[fcid]; !ok { delete(c.firstRefreshFailure, fcid) } }
+} + +func (c *Contractor) refreshContract(ctx *mCtx, w Worker, contract api.Contract, host api.Host, budget *types.Currency, logger *zap.SugaredLogger) (cm api.ContractMetadata, proceed bool, err error) { + if contract.Revision == nil { + return api.ContractMetadata{}, true, errors.New("can't refresh contract without a revision") } + logger = logger.With("to_renew", contract.ID, "hk", contract.HostKey, "hostVersion", host.Settings.Version, "hostRelease", host.Settings.Release) + + // convenience variables + settings := host.Settings + pt := host.PriceTable.HostPriceTable + fcid := contract.ID + hk := contract.HostKey + rev := contract.Revision - // fetch candidate hosts - candidates, unusableHosts, err := c.candidateHosts(mCtx, hosts, usedHosts, minValidScore) // avoid 0 score hosts + // fetch consensus state + cs, err := c.bus.ConsensusState(ctx) if err != nil { - return false, err + return api.ContractMetadata{}, false, err } - // min score to pass checks - var minScore float64 - if len(hosts) > 0 { - minScore = c.calculateMinScore(candidates, ctx.WantedContracts()) + // calculate the renter funds + var renterFunds types.Currency + if isOutOfFunds(ctx.AutopilotConfig(), pt, contract) { + renterFunds = c.refreshFundingEstimate(ctx.AutopilotConfig(), contract, host, ctx.state.Fee, logger) } else { - c.logger.Warn("could not calculate min score, no hosts found") + renterFunds = rev.ValidRenterPayout() // don't increase funds } - // run host checks - checks, err := c.runHostChecks(mCtx, hosts, minScore) - if err != nil { - return false, fmt.Errorf("failed to run host checks, err: %v", err) + // check our budget + if budget.Cmp(renterFunds) < 0 { + logger.Warnw("insufficient budget for refresh", "hk", hk, "fcid", fcid, "budget", budget, "needed", renterFunds) + return api.ContractMetadata{}, false, fmt.Errorf("insufficient budget: %s < %s", budget.String(), renterFunds.String()) } - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return false, fmt.Errorf("failed to fetch consensus state, err: %v", err) - } + expectedNewStorage := renterFundsToExpectedStorage(renterFunds, contract.EndHeight()-cs.BlockHeight, pt) + unallocatedCollateral := contract.RemainingCollateral() - // run contract checks - updatedSet, toArchive, toStopUsing, toRefresh, toRenew := c.runContractChecks(mCtx, checks, contracts, isInCurrentSet, cs.BlockHeight) + // a refresh should always result in a contract that has enough collateral + minNewCollateral := minRemainingCollateral(ctx.AutopilotConfig(), ctx.state.RS, renterFunds, settings, pt).Mul64(2) - // update host checks - for hk, check := range checks { - if err := c.bus.UpdateHostCheck(ctx, ctx.ApID(), hk, *check); err != nil { - c.logger.Errorf("failed to update host check for host %v, err: %v", hk, err) - } - } + // maxFundAmount is the remaining funds of the contract to refresh plus the + // budget since the previous contract was in the same period + maxFundAmount := budget.Add(rev.ValidRenterPayout()) - // archive contracts - if len(toArchive) > 0 { - c.logger.Infof("archiving %d contracts: %+v", len(toArchive), toArchive) - if err := c.bus.ArchiveContracts(ctx, toArchive); err != nil { - c.logger.Errorf("failed to archive contracts, err: %v", err) // continue + // renew the contract + resp, err := w.RHPRenew(ctx, contract.ID, contract.EndHeight(), hk, contract.SiamuxAddr, settings.Address, ctx.state.Address, renterFunds, minNewCollateral, maxFundAmount, expectedNewStorage, settings.WindowSize) + if err != nil { + if 
strings.Contains(err.Error(), "new collateral is too low") { + logger.Infow("refresh failed: contract wouldn't have enough collateral after refresh", + "hk", hk, + "fcid", fcid, + "unallocatedCollateral", unallocatedCollateral.String(), + "minNewCollateral", minNewCollateral.String(), + ) + return api.ContractMetadata{}, true, err + } + logger.Errorw("refresh failed", zap.Error(err), "hk", hk, "fcid", fcid) + if utils.IsErr(err, wallet.ErrNotEnoughFunds) && !rhp3.IsErrHost(err) { + return api.ContractMetadata{}, false, err } + return api.ContractMetadata{}, true, err } - // calculate remaining funds - remaining := c.remainingFunds(contracts, mCtx.state) + // update the budget + *budget = budget.Sub(resp.FundAmount) - // calculate 'limit' amount of contracts we want to renew - var limit int - if len(toRenew) > 0 { - // when renewing, prioritise contracts that have already been in the set - // before and out of those prefer the largest ones. - sort.Slice(toRenew, func(i, j int) bool { - _, icsI := isInCurrentSet[toRenew[i].contract.ID] - _, icsJ := isInCurrentSet[toRenew[j].contract.ID] - if icsI && !icsJ { - return true - } else if !icsI && icsJ { - return false - } - return toRenew[i].contract.FileSize() > toRenew[j].contract.FileSize() - }) - for len(updatedSet)+limit < int(ctx.WantedContracts()) && limit < len(toRenew) { - // as long as we're missing contracts, increase the renewal limit - limit++ - } + // persist the contract + refreshedContract, err := c.bus.AddRenewedContract(ctx, resp.Contract, resp.ContractPrice, renterFunds, cs.BlockHeight, contract.ID, api.ContractStatePending) + if err != nil { + logger.Errorw("adding refreshed contract failed", zap.Error(err), "hk", hk, "fcid", fcid) + return api.ContractMetadata{}, false, err } - // run renewals on contracts that are not in updatedSet yet. We only renew - // up to 'limit' of those to avoid having too many contracts in the updated - // set afterwards - var renewed []renewal - if limit > 0 { - var toKeep []api.ContractMetadata - renewed, toKeep = c.runContractRenewals(ctx, w, toRenew, &remaining, limit) - for _, ri := range renewed { - if ri.ci.usable || ri.ci.recoverable { - updatedSet = append(updatedSet, ri.to) - } - contractData[ri.to.ID] = contractData[ri.from.ID] - } - updatedSet = append(updatedSet, toKeep...) 
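// Editor's note (hypothetical numbers, not taken from the change): the shared
// budget above is drawn down by the amount the refresh actually consumed,
// which may differ from the requested renterFunds once fees are included.
// With *budget = 100 SC, renterFunds = 30 SC and resp.FundAmount = 32 SC:
//
//	*budget = budget.Sub(resp.FundAmount) // 100 - 32 = 68 SC remaining
//
// so later refreshes, renewals and formations in the same maintenance run
// operate on the reduced budget.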
+ // add to renewed set + newCollateral := resp.Contract.Revision.MissedHostPayout().Sub(resp.ContractPrice) + logger.Infow("refresh succeeded", + "fcid", refreshedContract.ID, + "renewedFrom", contract.ID, + "renterFunds", renterFunds.String(), + "minNewCollateral", minNewCollateral.String(), + "newCollateral", newCollateral.String(), + ) + return refreshedContract, true, nil +} + +func (c *Contractor) renewContract(ctx *mCtx, w Worker, contract api.Contract, host api.Host, budget *types.Currency, logger *zap.SugaredLogger) (cm api.ContractMetadata, proceed bool, err error) { + if contract.Revision == nil { + return api.ContractMetadata{}, true, errors.New("can't renew contract without a revision") } + logger = logger.With("to_renew", contract.ID, "hk", contract.HostKey, "hostVersion", host.Settings.Version, "hostRelease", host.Settings.Release) + + // convenience variables + settings := host.Settings + pt := host.PriceTable.HostPriceTable + fcid := contract.ID + rev := contract.Revision + hk := contract.HostKey - // run contract refreshes - refreshed, err := c.runContractRefreshes(ctx, w, toRefresh, &remaining) + // fetch consensus state + cs, err := c.bus.ConsensusState(ctx) if err != nil { - c.logger.Errorf("failed to refresh contracts, err: %v", err) // continue - } else { - for _, ri := range refreshed { - if ri.ci.usable || ri.ci.recoverable { - updatedSet = append(updatedSet, ri.to) - } - contractData[ri.to.ID] = contractData[ri.from.ID] - } + return api.ContractMetadata{}, false, err } - // to avoid forming new contracts as soon as we dip below - // 'Contracts.Amount', we define a threshold but only if we have more - // contracts than 'Contracts.Amount' already - threshold := ctx.WantedContracts() - if uint64(len(contracts)) > ctx.WantedContracts() { - threshold = addLeeway(threshold, leewayPctRequiredContracts) - } + // calculate the renter funds for the renewal a.k.a. 
the funds the renter will + // be able to spend + minRenterFunds, _ := initialContractFundingMinMax(ctx.AutopilotConfig()) + renterFunds := renewFundingEstimate(minRenterFunds, contract.TotalCost, contract.RenterFunds(), logger) - // check if we need to form contracts and add them to the contract set - var formed []api.ContractMetadata - if uint64(len(updatedSet)) < threshold && !ctx.state.SkipContractFormations { - // form contracts - formed, err = c.runContractFormations(ctx, w, candidates, usedHosts, unusableHosts, ctx.WantedContracts()-uint64(len(updatedSet)), &remaining) - if err != nil { - c.logger.Errorf("failed to form contracts, err: %v", err) // continue - } else { - for _, fc := range formed { - updatedSet = append(updatedSet, fc) - contractData[fc.ID] = 0 - } - } + // check our budget + if budget.Cmp(renterFunds) < 0 { + logger.Infow("insufficient budget", "budget", budget, "needed", renterFunds) + return api.ContractMetadata{}, false, errors.New("insufficient budget") } - // cap the amount of contracts we want to keep to the configured amount - for _, contract := range updatedSet { - if _, exists := contractData[contract.ID]; !exists { - c.logger.Errorf("contract %v not found in contractData", contract.ID) - } + // sanity check the endheight is not the same on renewals + endHeight := ctx.EndHeight() + if endHeight <= rev.EndHeight() { + logger.Infow("invalid renewal endheight", "oldEndheight", rev.EndHeight(), "newEndHeight", endHeight, "period", ctx.state.Period, "bh", cs.BlockHeight) + return api.ContractMetadata{}, false, fmt.Errorf("renewal endheight should surpass the current contract endheight, %v <= %v", endHeight, rev.EndHeight()) } - if len(updatedSet) > int(ctx.WantedContracts()) { - // sort by contract size - sort.Slice(updatedSet, func(i, j int) bool { - return contractData[updatedSet[i].ID] > contractData[updatedSet[j].ID] - }) - for _, contract := range updatedSet[ctx.WantedContracts():] { - toStopUsing[contract.ID] = "truncated" + + // calculate the expected new storage + expectedNewStorage := renterFundsToExpectedStorage(renterFunds, endHeight-cs.BlockHeight, pt) + + // renew the contract + resp, err := w.RHPRenew(ctx, fcid, endHeight, hk, contract.SiamuxAddr, settings.Address, ctx.state.Address, renterFunds, types.ZeroCurrency, *budget, expectedNewStorage, settings.WindowSize) + if err != nil { + logger.Errorw( + "renewal failed", + zap.Error(err), + "endHeight", endHeight, + "renterFunds", renterFunds, + "expectedNewStorage", expectedNewStorage, + ) + if utils.IsErr(err, wallet.ErrNotEnoughFunds) && !rhp3.IsErrHost(err) { + return api.ContractMetadata{}, false, err } - updatedSet = updatedSet[:ctx.WantedContracts()] + return api.ContractMetadata{}, true, err } - // convert to set of file contract ids - var newSet []types.FileContractID - for _, contract := range updatedSet { - newSet = append(newSet, contract.ID) - } + // update the budget + *budget = budget.Sub(resp.FundAmount) - // update contract set - err = c.bus.SetContractSet(ctx, ctx.ContractSet(), newSet) + // persist the contract + renewedContract, err := c.bus.AddRenewedContract(ctx, resp.Contract, resp.ContractPrice, renterFunds, cs.BlockHeight, fcid, api.ContractStatePending) if err != nil { - return false, err + logger.Errorw(fmt.Sprintf("renewal failed to persist, err: %v", err)) + return api.ContractMetadata{}, false, err } - // return whether the maintenance changed the contract set - return c.computeContractSetChanged(mCtx, currentSet, updatedSet, formed, refreshed, renewed, toStopUsing, 
contractData), nil + newCollateral := resp.Contract.Revision.MissedHostPayout().Sub(resp.ContractPrice) + logger.Infow( + "renewal succeeded", + "fcid", renewedContract.ID, + "renewedFrom", fcid, + "renterFunds", renterFunds.String(), + "newCollateral", newCollateral.String(), + ) + return renewedContract, true, nil } -func (c *Contractor) computeContractSetChanged(ctx *mCtx, oldSet, newSet, formed []api.ContractMetadata, refreshed, renewed []renewal, toStopUsing map[types.FileContractID]string, contractData map[types.FileContractID]uint64) bool { - name := ctx.ContractSet() +// broadcastRevisions broadcasts contract revisions from the current set of +// contracts. Since we are migrating away from all contracts not in the set and +// are not uploading to those contracts anyway, we only worry about contracts in +// the set. +func (c *Contractor) broadcastRevisions(ctx context.Context, w Worker, contracts []api.ContractMetadata, logger *zap.SugaredLogger) { + if c.revisionBroadcastInterval == 0 { + return // not enabled + } - // build set lookups - inOldSet := make(map[types.FileContractID]struct{}) - for _, c := range oldSet { + cs, err := c.bus.ConsensusState(ctx) + if err != nil { + logger.Warnf("revision broadcast failed to fetch blockHeight: %v", err) + return + } + bh := cs.BlockHeight + + successful, failed := 0, 0 + for _, contract := range contracts { + // check whether broadcasting is necessary + timeSinceRevisionHeight := targetBlockTime * time.Duration(bh-contract.RevisionHeight) + timeSinceLastTry := time.Since(c.revisionLastBroadcast[contract.ID]) + if contract.RevisionHeight == math.MaxUint64 || timeSinceRevisionHeight < c.revisionBroadcastInterval || timeSinceLastTry < c.revisionBroadcastInterval/broadcastRevisionRetriesPerInterval { + continue // nothing to do + } + + // remember that we tried to broadcast this contract now + c.revisionLastBroadcast[contract.ID] = time.Now() + + // broadcast revision + ctx, cancel := context.WithTimeout(ctx, timeoutBroadcastRevision) + err := w.RHPBroadcast(ctx, contract.ID) + cancel() + if utils.IsErr(err, errors.New("transaction has a file contract with an outdated revision number")) { + continue // don't log - revision was already broadcasted + } else if err != nil { + logger.Warnw(fmt.Sprintf("failed to broadcast contract revision: %v", err), + "hk", contract.HostKey, + "fcid", contract.ID) + failed++ + delete(c.revisionLastBroadcast, contract.ID) // reset to try again + continue + } + successful++ + } + logger.Infow("revision broadcast completed", + "successful", successful, + "failed", failed) + + // prune revisionLastBroadcast + contractMap := make(map[types.FileContractID]struct{}) + for _, contract := range contracts { + contractMap[contract.ID] = struct{}{} + } + for contractID := range c.revisionLastBroadcast { + if _, ok := contractMap[contractID]; !ok { + delete(c.revisionLastBroadcast, contractID) + } + } +} + +func (c *Contractor) refreshFundingEstimate(cfg api.AutopilotConfig, contract api.Contract, host api.Host, fee types.Currency, logger *zap.SugaredLogger) types.Currency { + // refresh with 1.2x the funds + refreshAmount := contract.TotalCost.Mul64(6).Div64(5) + + // estimate the txn fee + txnFeeEstimate := fee.Mul64(estimatedFileContractTransactionSetSize) + + // check for a sane minimum that is equal to the initial contract funding + // but without an upper cap. 
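// Editor's note (hypothetical numbers): with contract.TotalCost = 100 SC the
// estimate above comes out at 100 * 6 / 5 = 120 SC; if the initial-funding
// minimum computed next is higher, say 150 SC, the estimate is raised to that
// floor:
//
//	refreshAmount := contract.TotalCost.Mul64(6).Div64(5) // 120 SC
//	if refreshAmount.Cmp(minimum) < 0 {
//		refreshAmount = minimum // the 150 SC floor wins
//	}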
+ minInitialContractFunds, _ := initialContractFundingMinMax(cfg)
+ minimum := c.initialContractFunding(host.Settings, txnFeeEstimate, minInitialContractFunds, types.ZeroCurrency)
+ refreshAmountCapped := refreshAmount
+ if refreshAmountCapped.Cmp(minimum) < 0 {
+ refreshAmountCapped = minimum
+ }
+ logger.Infow("refresh estimate",
+ "fcid", contract.ID,
+ "refreshAmount", refreshAmount,
+ "refreshAmountCapped", refreshAmountCapped)
+ return refreshAmountCapped
+}
+
+func (c *Contractor) shouldArchive(contract api.Contract, bh uint64) error {
+ if bh > contract.EndHeight()-c.revisionSubmissionBuffer {
+ return errContractExpired
+ } else if contract.Revision != nil && contract.Revision.RevisionNumber == math.MaxUint64 {
+ return errContractMaxRevisionNumber
+ } else if contract.RevisionNumber == math.MaxUint64 {
+ return errContractMaxRevisionNumber
+ } else if contract.State == api.ContractStatePending && bh-contract.StartHeight > ContractConfirmationDeadline {
+ return errContractNotConfirmed
+ }
+ return nil
+}
+
+func (c *Contractor) shouldForgiveFailedRefresh(fcid types.FileContractID) bool {
+ lastFailure, exists := c.firstRefreshFailure[fcid]
+ if !exists {
+ lastFailure = time.Now()
+ c.firstRefreshFailure[fcid] = lastFailure
+ }
+ return time.Since(lastFailure) < failedRefreshForgivenessPeriod
+}
+
+func addLeeway(n uint64, pct float64) uint64 {
+ if pct < 0 {
+ panic("given leeway percent has to be positive")
+ }
+ return uint64(math.Ceil(float64(n) * pct))
+}
+
+func calculateMinScore(candidates []scoredHost, numContracts uint64, logger *zap.SugaredLogger) float64 {
+ logger = logger.Named("calculateMinScore")
+
+ // return early if there are no hosts
+ if len(candidates) == 0 {
+ logger.Warn("min host score is set to the smallest non-zero float because there are no candidate hosts")
+ return minValidScore
+ }
+
+ // determine the number of random hosts we fetch per iteration when
+ // calculating the min score - it contains a constant factor in case the
+ // number of contracts is very low, and a linear factor so the size scales
+ // with the number of contracts we want to form
+ randSetSize := 2*int(numContracts) + 50
+
+ // do multiple rounds to select the lowest score
+ var lowestScores []float64
+ for r := 0; r < 5; r++ {
+ lowestScore := math.MaxFloat64
+ for _, host := range scoredHosts(candidates).randSelectByScore(randSetSize) {
+ if score := host.score; score < lowestScore && score > 0 {
+ lowestScore = score
+ }
+ }
+ if lowestScore != math.MaxFloat64 {
+ lowestScores = append(lowestScores, lowestScore)
+ }
+ }
+ if len(lowestScores) == 0 {
+ logger.Warn("min host score is set to the smallest non-zero float because the lowest score couldn't be determined")
+ return minValidScore
+ }
+
+ // compute the min score
+ var lowestScore float64
+ lowestScore, err := stats.Float64Data(lowestScores).Median()
+ if err != nil {
+ panic("never fails since len(candidates) > 0 so len(lowestScores) > 0 as well")
+ }
+ minScore := lowestScore / minAllowedScoreLeeway
+
+ // make sure the min score allows for 'numContracts' contracts to be formed
+ sort.Slice(candidates, func(i, j int) bool {
+ return candidates[i].score > candidates[j].score
+ })
+ if len(candidates) < int(numContracts) {
+ return minValidScore
+ } else if cutoff := candidates[numContracts-1].score; minScore > cutoff {
+ minScore = cutoff
+ }
+
+ logger.Infow("finished computing minScore",
+ "candidates", len(candidates),
+ "minScore", minScore,
+ "numContracts", numContracts,
+ "lowestScore", lowestScore)
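// Editor's note (hypothetical numbers): if the five rounds above produced
// lows of 0.8, 1.0, 1.2, 1.0 and 0.9, the median is 1.0 and, assuming
// minAllowedScoreLeeway = 4, the tentative result is
//
//	minScore := lowestScore / minAllowedScoreLeeway // 1.0 / 4 = 0.25
//
// If the numContracts-th best candidate only scores 0.2, the cutoff branch
// above lowers minScore to 0.2 so enough hosts can still clear the bar.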
+ return minScore +} + +func canSkipContractMaintenance(ctx context.Context, cfg api.ContractsConfig) (string, bool) { + select { + case <-ctx.Done(): + return "interrupted", true + default: + } + + // no maintenance if no hosts are requested + // + // NOTE: this is an important check because we assume Contracts.Amount is + // not zero in several places + if cfg.Amount == 0 { + return "contracts is set to zero, skipping contract maintenance", true + } + + // no maintenance if no allowance was set + if cfg.Allowance.IsZero() { + return "allowance is set to zero, skipping contract maintenance", true + } + + // no maintenance if no period was set + if cfg.Period == 0 { + return "period is set to zero, skipping contract maintenance", true + } + return "", false +} + +func computeContractSetChanged(ctx *mCtx, alerter alerts.Alerter, bus Bus, churn *accumulatedChurn, logger *zap.SugaredLogger, oldSet, newSet []api.ContractMetadata, toStopUsing map[types.FileContractID]string) (bool, error) { + name := ctx.ContractSet() + + allContracts, err := bus.Contracts(ctx, api.ContractsOpts{}) + if err != nil { + return false, fmt.Errorf("failed to fetch all contracts: %w", err) + } + contractData := make(map[types.FileContractID]uint64) + for _, c := range allContracts { + contractData[c.ID] = c.Size + } + + // build set lookups + inOldSet := make(map[types.FileContractID]struct{}) + for _, c := range oldSet { inOldSet[c.ID] = struct{}{} } inNewSet := make(map[types.FileContractID]struct{}) @@ -485,9 +667,11 @@ func (c *Contractor) computeContractSetChanged(ctx *mCtx, oldSet, newSet, formed // build renewal lookups renewalsFromTo := make(map[types.FileContractID]types.FileContractID) renewalsToFrom := make(map[types.FileContractID]types.FileContractID) - for _, c := range append(refreshed, renewed...) 
{ - renewalsFromTo[c.from.ID] = c.to.ID - renewalsToFrom[c.to.ID] = c.from.ID + for _, c := range allContracts { + if c.RenewedFrom != (types.FileContractID{}) { + renewalsFromTo[c.RenewedFrom] = c.ID + renewalsToFrom[c.ID] = c.RenewedFrom + } } // log added and removed contracts @@ -515,12 +699,12 @@ func (c *Contractor) computeContractSetChanged(ctx *mCtx, oldSet, newSet, formed Time: now, }) setRemovals[contract.ID] = removals - c.logger.Infof("contract %v was removed from the contract set, size: %v, reason: %v", contract.ID, contractData[contract.ID], reason) + logger.Infof("contract %v was removed from the contract set, size: %v, reason: %v", contract.ID, contractData[contract.ID], reason) } } for _, contract := range newSet { _, existed := inOldSet[contract.ID] - _, renewed := renewalsToFrom[contract.ID] + _, renewed := inOldSet[renewalsToFrom[contract.ID]] if !existed && !renewed { if _, exists := setAdditions[contract.ID]; !exists { setAdditions[contract.ID] = contractSetAdditions{ @@ -533,22 +717,14 @@ func (c *Contractor) computeContractSetChanged(ctx *mCtx, oldSet, newSet, formed Time: now, }) setAdditions[contract.ID] = additions - c.logger.Infof("contract %v was added to the contract set, size: %v", contract.ID, contractData[contract.ID]) - } - } - - // log renewed contracts that did not make it into the contract set - for _, fcid := range renewed { - _, exists := inNewSet[fcid.to.ID] - if !exists { - c.logger.Infof("contract %v was renewed but did not make it into the contract set, size: %v", fcid, contractData[fcid.to.ID]) + logger.Infof("contract %v was added to the contract set, size: %v", contract.ID, contractData[contract.ID]) } } // log a warning if the contract set does not contain enough contracts - logFn := c.logger.Infow + logFn := logger.Infow if len(newSet) < int(ctx.state.RS.TotalShards) { - logFn = c.logger.Warnw + logFn = logger.Warnw } // record churn metrics @@ -568,897 +744,45 @@ func (c *Contractor) computeContractSetChanged(ctx *mCtx, oldSet, newSet, formed Direction: api.ChurnDirRemoved, Reason: removal.Removals[0].Reason, Timestamp: now, - }) - } - if len(metrics) > 0 { - if err := c.bus.RecordContractSetChurnMetric(ctx, metrics...); err != nil { - c.logger.Error("failed to record contract set churn metric:", err) - } - } - - // log the contract set after maintenance - logFn( - "contractset after maintenance", - "formed", len(formed), - "renewed", len(renewed), - "refreshed", len(refreshed), - "contracts", len(newSet), - "added", len(setAdditions), - "removed", len(setRemovals), - ) - hasChanged := len(setAdditions)+len(setRemovals) > 0 - if hasChanged { - if !c.HasAlert(ctx, alertChurnID) { - c.churn.Reset() - } - c.churn.Apply(setAdditions, setRemovals) - c.alerter.RegisterAlert(ctx, c.churn.Alert(name)) - } - return hasChanged -} - -func (c *Contractor) runContractChecks(ctx *mCtx, hostChecks map[types.PublicKey]*api.HostCheck, contracts []api.Contract, inCurrentSet map[types.FileContractID]struct{}, bh uint64) (toKeep []api.ContractMetadata, toArchive, toStopUsing map[types.FileContractID]string, toRefresh, toRenew []contractInfo) { - select { - case <-ctx.Done(): - return - default: - } - c.logger.Info("running contract checks") - - // create new IP filter - ipFilter := c.newIPFilter() - - // calculate 'maxKeepLeeway' which defines the amount of contracts we'll be - // lenient towards when we fail to either fetch a valid price table or the - // contract's revision - maxKeepLeeway := addLeeway(ctx.WantedContracts(), 1-leewayPctRequiredContracts) - 
remainingKeepLeeway := maxKeepLeeway - - var notfound int - defer func() { - c.logger.Infow( - "contracts checks completed", - "contracts", len(contracts), - "notfound", notfound, - "usedKeepLeeway", maxKeepLeeway-remainingKeepLeeway, - "toKeep", len(toKeep), - "toArchive", len(toArchive), - "toRefresh", len(toRefresh), - "toRenew", len(toRenew), - ) - }() - - // return variables - toArchive = make(map[types.FileContractID]string) - toStopUsing = make(map[types.FileContractID]string) - - // when checking the contracts, do so from largest to smallest. That way, we - // prefer larger hosts on redundant networks. - contracts = append([]api.Contract{}, contracts...) - sort.Slice(contracts, func(i, j int) bool { - return contracts[i].FileSize() > contracts[j].FileSize() - }) - - // check all contracts -LOOP: - for _, contract := range contracts { - // break if interrupted - select { - case <-ctx.Done(): - break LOOP - default: - } - - // convenience variables - fcid := contract.ID - hk := contract.HostKey - - // check if contract is ready to be archived. - if bh > contract.EndHeight()-c.revisionSubmissionBuffer { - toArchive[fcid] = errContractExpired.Error() - } else if contract.Revision != nil && contract.Revision.RevisionNumber == math.MaxUint64 { - toArchive[fcid] = errContractMaxRevisionNumber.Error() - } else if contract.RevisionNumber == math.MaxUint64 { - toArchive[fcid] = errContractMaxRevisionNumber.Error() - } else if contract.State == api.ContractStatePending && bh-contract.StartHeight > contractConfirmationDeadline { - toArchive[fcid] = errContractNotConfirmed.Error() - } - if _, archived := toArchive[fcid]; archived { - toStopUsing[fcid] = toArchive[fcid] - continue - } - - // fetch host from hostdb - host, err := c.bus.Host(ctx, hk) - if err != nil { - c.logger.Warn(fmt.Sprintf("missing host, err: %v", err), "hk", hk) - toStopUsing[fcid] = api.ErrUsabilityHostNotFound.Error() - notfound++ - continue - } - - // fetch host checks - check, ok := hostChecks[hk] - if !ok { - // this is only possible due to developer error, if there is no - // check the host would have been missing, so we treat it the same - c.logger.Warnw("missing host check", "hk", hk) - toStopUsing[fcid] = api.ErrUsabilityHostNotFound.Error() - continue - } - - // if the host is blocked we ignore it, it might be unblocked later - if host.Blocked { - c.logger.Infow("unusable host", "hk", hk, "fcid", fcid, "reasons", api.ErrUsabilityHostBlocked.Error()) - toStopUsing[fcid] = api.ErrUsabilityHostBlocked.Error() - continue - } - - // check if the host is still usable - if !check.Usability.IsUsable() { - reasons := check.Usability.UnusableReasons() - toStopUsing[fcid] = strings.Join(reasons, ",") - c.logger.Infow("unusable host", "hk", hk, "fcid", fcid, "reasons", reasons) - continue - } - - // if we were not able to the contract's revision, we can't properly - // perform the checks that follow, however we do want to be lenient if - // this contract is in the current set and we still have leeway left - _, inSet := inCurrentSet[fcid] - if contract.Revision == nil { - if !inSet || remainingKeepLeeway == 0 { - toStopUsing[fcid] = errContractNoRevision.Error() - } else if ctx.ShouldFilterRedundantIPs() && ipFilter.HasRedundantIP(host) { - toStopUsing[fcid] = fmt.Sprintf("%v; %v", api.ErrUsabilityHostRedundantIP, errContractNoRevision) - hostChecks[contract.HostKey].Usability.RedundantIP = true - } else { - toKeep = append(toKeep, contract.ContractMetadata) - remainingKeepLeeway-- // we let it slide - } - continue // can't 
perform contract checks without revision - } - - // decide whether the contract is still good - ci := contractInfo{contract: contract, host: host} - usable, recoverable, refresh, renew, reasons := c.isUsableContract(ctx.AutopilotConfig(), ctx.state.RS, ci, inSet, bh, ipFilter) - ci.usable = usable - ci.recoverable = recoverable - if !usable { - c.logger.Infow( - "unusable contract", - "hk", hk, - "fcid", fcid, - "reasons", reasons, - "refresh", refresh, - "renew", renew, - "recoverable", recoverable, - ) - } - if len(reasons) > 0 { - toStopUsing[fcid] = strings.Join(reasons, ",") - } - - if renew { - toRenew = append(toRenew, ci) - } else if refresh { - toRefresh = append(toRefresh, ci) - } else if usable { - toKeep = append(toKeep, ci.contract.ContractMetadata) - } - } - - return toKeep, toArchive, toStopUsing, toRefresh, toRenew -} - -func (c *Contractor) runHostChecks(ctx *mCtx, hosts []api.Host, minScore float64) (map[types.PublicKey]*api.HostCheck, error) { - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return nil, err - } - - // create gouging checker - gc := worker.NewGougingChecker(ctx.state.GS, cs, ctx.state.Fee, ctx.state.Period(), ctx.RenewWindow()) - - // check all hosts - checks := make(map[types.PublicKey]*api.HostCheck) - for _, h := range hosts { - h.PriceTable.HostBlockHeight = cs.BlockHeight // ignore HostBlockHeight - checks[h.PublicKey] = checkHost(ctx.AutopilotConfig(), ctx.state.RS, gc, h, minScore) - } - return checks, nil -} - -func (c *Contractor) runContractFormations(ctx *mCtx, w Worker, candidates scoredHosts, usedHosts map[types.PublicKey]struct{}, unusableHosts unusableHostsBreakdown, missing uint64, budget *types.Currency) (formed []api.ContractMetadata, _ error) { - select { - case <-c.shutdownCtx.Done(): - return nil, nil - default: - } - - c.logger.Infow( - "run contract formations", - "usedHosts", len(usedHosts), - "required", ctx.WantedContracts(), - "missing", missing, - "budget", budget, - ) - defer func() { - c.logger.Infow( - "contract formations completed", - "formed", len(formed), - "budget", budget, - ) - }() - - // build a new host filter - filter := c.newIPFilter() - for _, h := range candidates { - if _, used := usedHosts[h.host.PublicKey]; used { - _ = filter.HasRedundantIP(h.host) - } - } - - // select candidates - wanted := int(addLeeway(missing, leewayPctCandidateHosts)) - selected := candidates.randSelectByScore(wanted) - - // print warning if we couldn't find enough hosts were found - c.logger.Infof("looking for %d candidate hosts", wanted) - if len(selected) < wanted { - var msg string - if len(selected) == 0 { - msg = "no candidate hosts found" - } else { - msg = fmt.Sprintf("only found %d candidate host(s) out of the %d we wanted", len(selected), wanted) - } - if len(candidates) >= wanted { - c.logger.Warnw(msg, unusableHosts.keysAndValues()...) - } else { - c.logger.Infow(msg, unusableHosts.keysAndValues()...) 
- } - } - - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return nil, err - } - lastStateUpdate := time.Now() - - // prepare a gouging checker - gc := ctx.GougingChecker(cs) - - // calculate min/max contract funds - minInitialContractFunds, maxInitialContractFunds := initialContractFundingMinMax(ctx.AutopilotConfig()) - -LOOP: - for h := 0; missing > 0 && h < len(selected); h++ { - host := selected[h].host - - // break if the autopilot is stopped - select { - case <-ctx.Done(): - break LOOP - default: - } - - // fetch a new price table if necessary - if err := refreshPriceTable(ctx, w, &host); err != nil { - c.logger.Errorf("failed to fetch price table for candidate host %v: %v", host.PublicKey, err) - continue - } - - // fetch a new consensus state if necessary, we have to do this - // frequently to ensure we're not performing gouging checks with old - // consensus state - if time.Since(lastStateUpdate) > time.Minute { - if css, err := c.bus.ConsensusState(ctx); err != nil { - c.logger.Errorf("could not fetch consensus state, err: %v", err) - } else { - cs = css - gc = ctx.GougingChecker(cs) - } - } - - // perform gouging checks on the fly to ensure the host is not gouging its prices - if breakdown := gc.Check(nil, &host.PriceTable.HostPriceTable); breakdown.Gouging() { - c.logger.Errorw("candidate host became unusable", "hk", host.PublicKey, "reasons", breakdown.String()) - continue - } - - // check if we already have a contract with a host on that subnet - if ctx.ShouldFilterRedundantIPs() && filter.HasRedundantIP(host) { - continue - } - - // form the contract - formedContract, proceed, err := c.formContract(ctx, w, host, minInitialContractFunds, maxInitialContractFunds, budget) - if err != nil { - // remove the host from the filter - filter.Remove(host) - } else { - // add contract to contract set - formed = append(formed, formedContract) - missing-- - } - if !proceed { - break - } - } - - return formed, nil -} - -// runRevisionBroadcast broadcasts contract revisions from the current set of -// contracts. Since we are migrating away from all contracts not in the set and -// are not uploading to those contracts anyway, we only worry about contracts in -// the set. 
-func (c *Contractor) runRevisionBroadcast(ctx context.Context, w Worker, allContracts []api.Contract, isInSet map[types.FileContractID]struct{}) { - if c.revisionBroadcastInterval == 0 { - return // not enabled - } - - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - c.logger.Warnf("revision broadcast failed to fetch blockHeight: %v", err) - return - } - bh := cs.BlockHeight - - successful, failed := 0, 0 - for _, contract := range allContracts { - // check whether broadcasting is necessary - timeSinceRevisionHeight := targetBlockTime * time.Duration(bh-contract.RevisionHeight) - timeSinceLastTry := time.Since(c.revisionLastBroadcast[contract.ID]) - _, inSet := isInSet[contract.ID] - if !inSet || contract.RevisionHeight == math.MaxUint64 || timeSinceRevisionHeight < c.revisionBroadcastInterval || timeSinceLastTry < c.revisionBroadcastInterval/broadcastRevisionRetriesPerInterval { - continue // nothing to do - } - - // remember that we tried to broadcast this contract now - c.revisionLastBroadcast[contract.ID] = time.Now() - - // ignore contracts for which we weren't able to obtain a revision - if contract.Revision == nil { - c.logger.Warnw("failed to broadcast contract revision: failed to fetch revision", - "hk", contract.HostKey, - "fcid", contract.ID) - continue - } - - // broadcast revision - ctx, cancel := context.WithTimeout(ctx, timeoutBroadcastRevision) - err := w.RHPBroadcast(ctx, contract.ID) - cancel() - if utils.IsErr(err, errors.New("transaction has a file contract with an outdated revision number")) { - continue // don't log - revision was already broadcasted - } else if err != nil { - c.logger.Warnw(fmt.Sprintf("failed to broadcast contract revision: %v", err), - "hk", contract.HostKey, - "fcid", contract.ID) - failed++ - delete(c.revisionLastBroadcast, contract.ID) // reset to try again - continue - } - successful++ - } - c.logger.Infow("revision broadcast completed", - "successful", successful, - "failed", failed) - - // prune revisionLastBroadcast - contractMap := make(map[types.FileContractID]struct{}) - for _, contract := range allContracts { - contractMap[contract.ID] = struct{}{} - } - for contractID := range c.revisionLastBroadcast { - if _, ok := contractMap[contractID]; !ok { - delete(c.revisionLastBroadcast, contractID) - } - } -} - -func (c *Contractor) runContractRenewals(ctx *mCtx, w Worker, toRenew []contractInfo, budget *types.Currency, limit int) (renewals []renewal, toKeep []api.ContractMetadata) { - c.logger.Infow( - "run contracts renewals", - "torenew", len(toRenew), - "limit", limit, - "budget", budget, - ) - defer func() { - c.logger.Infow( - "contracts renewals completed", - "renewals", len(renewals), - "tokeep", len(toKeep), - "budget", budget, - ) - }() - - var i int - for i = 0; i < len(toRenew); i++ { - // check if interrupted - select { - case <-ctx.Done(): - return - default: - } - - // limit the number of contracts to renew - if len(renewals)+len(toKeep) >= limit { - break - } - - // renew and add if it succeeds or if its usable - contract := toRenew[i].contract.ContractMetadata - renewed, proceed, err := c.renewContract(ctx, w, toRenew[i], budget) - if err != nil { - // don't register an alert for hosts that are out of funds since the - // user can't do anything about it - if !(worker.IsErrHost(err) && utils.IsErr(err, cwallet.ErrNotEnoughFunds)) { - c.alerter.RegisterAlert(ctx, newContractRenewalFailedAlert(contract, !proceed, err)) - } - c.logger.With(zap.Error(err)). - With("fcid", toRenew[i].contract.ID). 
- With("hostKey", toRenew[i].contract.HostKey). - With("proceed", proceed). - Errorw("failed to renew contract") - if toRenew[i].usable { - toKeep = append(toKeep, toRenew[i].contract.ContractMetadata) - } - } else { - c.alerter.DismissAlerts(ctx, alerts.IDForContract(alertRenewalFailedID, contract.ID)) - renewals = append(renewals, renewal{from: contract, to: renewed, ci: toRenew[i]}) - } - - // break if we don't want to proceed - if !proceed { - break - } - } - - // loop through the remaining renewals and add them to the keep list if - // they're usable and we have 'limit' left - for j := i; j < len(toRenew); j++ { - if len(renewals)+len(toKeep) < limit && toRenew[j].usable { - toKeep = append(toKeep, toRenew[j].contract.ContractMetadata) - } - } - - return renewals, toKeep -} - -func (c *Contractor) runContractRefreshes(ctx *mCtx, w Worker, toRefresh []contractInfo, budget *types.Currency) (refreshed []renewal, _ error) { - c.logger.Infow( - "run contracts refreshes", - "torefresh", len(toRefresh), - "budget", budget, - ) - defer func() { - c.logger.Infow( - "contracts refreshes completed", - "refreshed", len(refreshed), - "budget", budget, - ) - }() - - for _, ci := range toRefresh { - // check if interrupted - select { - case <-ctx.Done(): - return - default: - } - - // refresh and add if it succeeds - renewed, proceed, err := c.refreshContract(ctx, w, ci, budget) - if err == nil { - refreshed = append(refreshed, renewal{from: ci.contract.ContractMetadata, to: renewed, ci: ci}) - } - - // break if we don't want to proceed - if !proceed { - break - } - } - - return refreshed, nil -} - -func (c *Contractor) initialContractFunding(settings rhpv2.HostSettings, txnFee, minFunding, maxFunding types.Currency) types.Currency { - if !maxFunding.IsZero() && minFunding.Cmp(maxFunding) > 0 { - panic("given min is larger than max") // developer error - } - - funding := settings.ContractPrice.Add(txnFee).Mul64(10) // TODO arbitrary multiplier - if !minFunding.IsZero() && funding.Cmp(minFunding) < 0 { - return minFunding - } - if !maxFunding.IsZero() && funding.Cmp(maxFunding) > 0 { - return maxFunding - } - return funding -} - -func (c *Contractor) refreshFundingEstimate(cfg api.AutopilotConfig, ci contractInfo, fee types.Currency) types.Currency { - // refresh with 1.2x the funds - refreshAmount := ci.contract.TotalCost.Mul64(6).Div64(5) - - // estimate the txn fee - txnFeeEstimate := fee.Mul64(estimatedFileContractTransactionSetSize) - - // check for a sane minimum that is equal to the initial contract funding - // but without an upper cap. 
- minInitialContractFunds, _ := initialContractFundingMinMax(cfg) - minimum := c.initialContractFunding(ci.host.Settings, txnFeeEstimate, minInitialContractFunds, types.ZeroCurrency) - refreshAmountCapped := refreshAmount - if refreshAmountCapped.Cmp(minimum) < 0 { - refreshAmountCapped = minimum - } - c.logger.Infow("refresh estimate", - "fcid", ci.contract.ID, - "refreshAmount", refreshAmount, - "refreshAmountCapped", refreshAmountCapped) - return refreshAmountCapped -} - -func (c *Contractor) calculateMinScore(candidates []scoredHost, numContracts uint64) float64 { - // return early if there's no hosts - if len(candidates) == 0 { - c.logger.Warn("min host score is set to the smallest non-zero float because there are no candidate hosts") - return minValidScore - } - - // determine the number of random hosts we fetch per iteration when - // calculating the min score - it contains a constant factor in case the - // number of contracts is very low and a linear factor to make sure the - // number is relative to the number of contracts we want to form - randSetSize := 2*int(numContracts) + 50 - - // do multiple rounds to select the lowest score - var lowestScores []float64 - for r := 0; r < 5; r++ { - lowestScore := math.MaxFloat64 - for _, host := range scoredHosts(candidates).randSelectByScore(randSetSize) { - if host.score < lowestScore { - lowestScore = host.score - } - } - lowestScores = append(lowestScores, lowestScore) - } - - // compute the min score - var lowestScore float64 - lowestScore, err := stats.Float64Data(lowestScores).Median() - if err != nil { - panic("never fails since len(candidates) > 0 so len(lowestScores) > 0 as well") - } - minScore := lowestScore / minAllowedScoreLeeway - - // make sure the min score allows for 'numContracts' contracts to be formed - sort.Slice(candidates, func(i, j int) bool { - return candidates[i].score > candidates[j].score - }) - if len(candidates) < int(numContracts) { - return minValidScore - } else if cutoff := candidates[numContracts-1].score; minScore > cutoff { - minScore = cutoff - } - - c.logger.Infow("finished computing minScore", - "candidates", len(candidates), - "minScore", minScore, - "numContracts", numContracts, - "lowestScore", lowestScore) - return minScore -} - -func (c *Contractor) candidateHosts(ctx *mCtx, hosts []api.Host, usedHosts map[types.PublicKey]struct{}, minScore float64) ([]scoredHost, unusableHostsBreakdown, error) { - start := time.Now() - - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return nil, unusableHostsBreakdown{}, err - } - - // create a gouging checker - gc := ctx.GougingChecker(cs) - - // select unused hosts that passed a scan - var unused []api.Host - var blocked, excluded, notcompletedscan int - for _, h := range hosts { - // filter out used hosts - if _, exclude := usedHosts[h.PublicKey]; exclude { - excluded++ - continue - } - // filter out blocked hosts - if h.Blocked { - blocked++ - continue - } - // filter out unscanned hosts - if !h.Scanned { - notcompletedscan++ - continue - } - unused = append(unused, h) - } - - c.logger.Infow(fmt.Sprintf("selected %d (potentially) usable hosts for scoring out of %d", len(unused), len(hosts)), - "excluded", excluded, - "notcompletedscan", notcompletedscan, - "used", len(usedHosts)) - - // score all unused hosts - var unusableHosts unusableHostsBreakdown - var unusable, zeros int - var candidates []scoredHost - for _, h := range unused { - // NOTE: use the price table stored on the host for gouging checks when - // 
looking for candidate hosts, fetching the price table on the fly here - // slows contract maintenance down way too much, we re-evaluate the host - // right before forming the contract to ensure we do not form a contract - // with a host that's gouging its prices. - // - // NOTE: ignore the pricetable's HostBlockHeight by setting it to our - // own blockheight - h.PriceTable.HostBlockHeight = cs.BlockHeight - hc := checkHost(ctx.AutopilotConfig(), ctx.state.RS, gc, h, minScore) - if hc.Usability.IsUsable() { - candidates = append(candidates, scoredHost{h, hc.Score.Score()}) - continue - } - - // keep track of unusable host results - unusableHosts.track(hc.Usability) - if hc.Score.Score() == 0 { - zeros++ - } - unusable++ - } - - c.logger.Infow(fmt.Sprintf("scored %d unused hosts out of %v, took %v", len(candidates), len(unused), time.Since(start)), - "zeroscore", zeros, - "unusable", unusable, - "used", len(usedHosts)) - - return candidates, unusableHosts, nil -} - -func (c *Contractor) renewContract(ctx *mCtx, w Worker, ci contractInfo, budget *types.Currency) (cm api.ContractMetadata, proceed bool, err error) { - if ci.contract.Revision == nil { - return api.ContractMetadata{}, true, errors.New("can't renew contract without a revision") - } - log := c.logger.With("to_renew", ci.contract.ID, "hk", ci.contract.HostKey, "hostVersion", ci.host.Settings.Version, "hostRelease", ci.host.Settings.Release) - - // convenience variables - contract := ci.contract - settings := ci.host.Settings - pt := ci.host.PriceTable.HostPriceTable - fcid := contract.ID - rev := contract.Revision - hk := contract.HostKey - - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return api.ContractMetadata{}, false, err - } - - // calculate the renter funds for the renewal a.k.a. 
the funds the renter will - // be able to spend - minRenterFunds, _ := initialContractFundingMinMax(ctx.AutopilotConfig()) - renterFunds := renewFundingEstimate(minRenterFunds, contract.TotalCost, contract.RenterFunds(), log) - - // check our budget - if budget.Cmp(renterFunds) < 0 { - log.Infow("insufficient budget", "budget", budget, "needed", renterFunds) - return api.ContractMetadata{}, false, errors.New("insufficient budget") - } - - // sanity check the endheight is not the same on renewals - endHeight := ctx.EndHeight() - if endHeight <= rev.EndHeight() { - log.Infow("invalid renewal endheight", "oldEndheight", rev.EndHeight(), "newEndHeight", endHeight, "period", ctx.state.Period, "bh", cs.BlockHeight) - return api.ContractMetadata{}, false, fmt.Errorf("renewal endheight should surpass the current contract endheight, %v <= %v", endHeight, rev.EndHeight()) - } - - // calculate the expected new storage - expectedNewStorage := renterFundsToExpectedStorage(renterFunds, endHeight-cs.BlockHeight, pt) - - // renew the contract - resp, err := w.RHPRenew(ctx, fcid, endHeight, hk, contract.SiamuxAddr, settings.Address, ctx.state.Address, renterFunds, types.ZeroCurrency, *budget, expectedNewStorage, settings.WindowSize) - if err != nil { - log.Errorw( - "renewal failed", - zap.Error(err), - "endHeight", endHeight, - "renterFunds", renterFunds, - "expectedNewStorage", expectedNewStorage, - ) - if utils.IsErr(err, wallet.ErrInsufficientBalance) && !worker.IsErrHost(err) { - return api.ContractMetadata{}, false, err - } - return api.ContractMetadata{}, true, err - } - - // update the budget - *budget = budget.Sub(resp.FundAmount) - - // persist the contract - renewedContract, err := c.bus.AddRenewedContract(ctx, resp.Contract, resp.ContractPrice, renterFunds, cs.BlockHeight, fcid, api.ContractStatePending) - if err != nil { - log.Errorw(fmt.Sprintf("renewal failed to persist, err: %v", err)) - return api.ContractMetadata{}, false, err - } - - newCollateral := resp.Contract.Revision.MissedHostPayout().Sub(resp.ContractPrice) - log.Infow( - "renewal succeeded", - "fcid", renewedContract.ID, - "renewedFrom", fcid, - "renterFunds", renterFunds.String(), - "newCollateral", newCollateral.String(), - ) - return renewedContract, true, nil -} - -func (c *Contractor) refreshContract(ctx *mCtx, w Worker, ci contractInfo, budget *types.Currency) (cm api.ContractMetadata, proceed bool, err error) { - if ci.contract.Revision == nil { - return api.ContractMetadata{}, true, errors.New("can't refresh contract without a revision") - } - log := c.logger.With("to_renew", ci.contract.ID, "hk", ci.contract.HostKey, "hostVersion", ci.host.Settings.Version, "hostRelease", ci.host.Settings.Release) - - // convenience variables - contract := ci.contract - settings := ci.host.Settings - pt := ci.host.PriceTable.HostPriceTable - fcid := contract.ID - rev := contract.Revision - hk := contract.HostKey - - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return api.ContractMetadata{}, false, err - } - - // calculate the renter funds - var renterFunds types.Currency - if isOutOfFunds(ctx.AutopilotConfig(), pt, ci.contract) { - renterFunds = c.refreshFundingEstimate(ctx.AutopilotConfig(), ci, ctx.state.Fee) - } else { - renterFunds = rev.ValidRenterPayout() // don't increase funds - } - - // check our budget - if budget.Cmp(renterFunds) < 0 { - log.Warnw("insufficient budget for refresh", "hk", hk, "fcid", fcid, "budget", budget, "needed", renterFunds) - return api.ContractMetadata{}, false, 
fmt.Errorf("insufficient budget: %s < %s", budget.String(), renterFunds.String()) - } - - expectedNewStorage := renterFundsToExpectedStorage(renterFunds, contract.EndHeight()-cs.BlockHeight, pt) - unallocatedCollateral := contract.RemainingCollateral() - - // a refresh should always result in a contract that has enough collateral - minNewCollateral := minRemainingCollateral(ctx.AutopilotConfig(), ctx.state.RS, renterFunds, settings, pt).Mul64(2) - - // maxFundAmount is the remaining funds of the contract to refresh plus the - // budget since the previous contract was in the same period - maxFundAmount := budget.Add(rev.ValidRenterPayout()) - - // renew the contract - resp, err := w.RHPRenew(ctx, contract.ID, contract.EndHeight(), hk, contract.SiamuxAddr, settings.Address, ctx.state.Address, renterFunds, minNewCollateral, maxFundAmount, expectedNewStorage, settings.WindowSize) - if err != nil { - if strings.Contains(err.Error(), "new collateral is too low") { - log.Infow("refresh failed: contract wouldn't have enough collateral after refresh", - "hk", hk, - "fcid", fcid, - "unallocatedCollateral", unallocatedCollateral.String(), - "minNewCollateral", minNewCollateral.String(), - ) - return api.ContractMetadata{}, true, err - } - log.Errorw("refresh failed", zap.Error(err), "hk", hk, "fcid", fcid) - if utils.IsErr(err, wallet.ErrInsufficientBalance) && !worker.IsErrHost(err) { - return api.ContractMetadata{}, false, err - } - return api.ContractMetadata{}, true, err - } - - // update the budget - *budget = budget.Sub(resp.FundAmount) - - // persist the contract - refreshedContract, err := c.bus.AddRenewedContract(ctx, resp.Contract, resp.ContractPrice, renterFunds, cs.BlockHeight, contract.ID, api.ContractStatePending) - if err != nil { - log.Errorw("adding refreshed contract failed", zap.Error(err), "hk", hk, "fcid", fcid) - return api.ContractMetadata{}, false, err - } - - // add to renewed set - newCollateral := resp.Contract.Revision.MissedHostPayout().Sub(resp.ContractPrice) - log.Infow("refresh succeeded", - "fcid", refreshedContract.ID, - "renewedFrom", contract.ID, - "renterFunds", renterFunds.String(), - "minNewCollateral", minNewCollateral.String(), - "newCollateral", newCollateral.String(), - ) - return refreshedContract, true, nil -} - -func (c *Contractor) formContract(ctx *mCtx, w Worker, host api.Host, minInitialContractFunds, maxInitialContractFunds types.Currency, budget *types.Currency) (cm api.ContractMetadata, proceed bool, err error) { - log := c.logger.With("hk", host.PublicKey, "hostVersion", host.Settings.Version, "hostRelease", host.Settings.Release) - - // convenience variables - hk := host.PublicKey - - // fetch host settings - scan, err := w.RHPScan(ctx, hk, host.NetAddress, 0) - if err != nil { - log.Infow(err.Error(), "hk", hk) - return api.ContractMetadata{}, true, err - } - - // fetch consensus state - cs, err := c.bus.ConsensusState(ctx) - if err != nil { - return api.ContractMetadata{}, false, err - } - - // check our budget - txnFee := ctx.state.Fee.Mul64(estimatedFileContractTransactionSetSize) - renterFunds := initialContractFunding(scan.Settings, txnFee, minInitialContractFunds, maxInitialContractFunds) - if budget.Cmp(renterFunds) < 0 { - log.Infow("insufficient budget", "budget", budget, "needed", renterFunds) - return api.ContractMetadata{}, false, errors.New("insufficient budget") + }) } - - // calculate the host collateral - endHeight := ctx.EndHeight() - expectedStorage := renterFundsToExpectedStorage(renterFunds, endHeight-cs.BlockHeight, 
scan.PriceTable) - hostCollateral := rhpv2.ContractFormationCollateral(ctx.Period(), expectedStorage, scan.Settings) - - // form contract - contract, _, err := w.RHPForm(ctx, endHeight, hk, host.NetAddress, ctx.state.Address, renterFunds, hostCollateral) - if err != nil { - // TODO: keep track of consecutive failures and break at some point - log.Errorw(fmt.Sprintf("contract formation failed, err: %v", err), "hk", hk) - if strings.Contains(err.Error(), wallet.ErrInsufficientBalance.Error()) { - return api.ContractMetadata{}, false, err + if len(metrics) > 0 { + if err := bus.RecordContractSetChurnMetric(ctx, metrics...); err != nil { + logger.Error("failed to record contract set churn metric:", err) } - return api.ContractMetadata{}, true, err } - // update the budget - *budget = budget.Sub(renterFunds) + // log the contract set after maintenance + logFn( + "contractset after maintenance", + "contracts", len(newSet), + "added", len(setAdditions), + "removed", len(setRemovals), + ) - // persist contract in store - contractPrice := contract.Revision.MissedHostPayout().Sub(hostCollateral) - formedContract, err := c.bus.AddContract(ctx, contract, contractPrice, renterFunds, cs.BlockHeight, api.ContractStatePending) - if err != nil { - log.Errorw(fmt.Sprintf("contract formation failed, err: %v", err), "hk", hk) - return api.ContractMetadata{}, true, err + hasAlert := func(id types.Hash256) bool { + ar, err := alerter.Alerts(ctx, alerts.AlertsOpts{Offset: 0, Limit: -1}) + if err != nil { + logger.Errorf("failed to fetch alerts: %v", err) + return false + } + for _, alert := range ar.Alerts { + if alert.ID == id { + return true + } + } + return false } - log.Infow("formation succeeded", - "fcid", formedContract.ID, - "renterFunds", renterFunds.String(), - "collateral", hostCollateral.String(), - ) - return formedContract, true, nil -} - -func addLeeway(n uint64, pct float64) uint64 { - if pct < 0 { - panic("given leeway percent has to be positive") + hasChanged := len(setAdditions)+len(setRemovals) > 0 + if hasChanged { + if !hasAlert(alertChurnID) { + churn.Reset() + } + churn.Apply(setAdditions, setRemovals) + alerter.RegisterAlert(ctx, churn.Alert(name)) } - return uint64(math.Ceil(float64(n) * pct)) + return hasChanged, nil } func initialContractFunding(settings rhpv2.HostSettings, txnFee, minFunding, maxFunding types.Currency) types.Currency { @@ -1559,37 +883,518 @@ func renterFundsToExpectedStorage(renterFunds types.Currency, duration uint64, p return expectedStorage.Big().Uint64() } -func (c *Contractor) HasAlert(ctx context.Context, id types.Hash256) bool { - ar, err := c.alerter.Alerts(ctx, alerts.AlertsOpts{Offset: 0, Limit: -1}) +// performContractChecks performs maintenance on existing contracts, +// renewing/refreshing any that need it and filtering out contracts that should +// no longer be used. The 'ipFilter' is updated to contain all hosts that we +// keep contracts with and the 'dropOutReasons' map is updated with the reasons +// for dropping out of the set. If a contract is refreshed or renewed, the +// 'remainingFunds' are adjusted. 
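// Editor's note: a sketch of how a maintenance pass might chain this with
// performContractFormations below (hypothetical caller, error handling
// elided):
//
//	kept, churn, err := performContractChecks(ctx, alerter, bus, w, cc, cr, ipFilter, logger, &remaining)
//	if err != nil {
//		return err
//	}
//	formed, err := performContractFormations(ctx, bus, w, cr, ipFilter, logger, &remaining, int(ctx.WantedContracts())-len(kept))
//	// the new set is kept + formed; 'churn' records why contracts dropped out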
+func performContractChecks(ctx *mCtx, alerter alerts.Alerter, bus Bus, w Worker, cc contractChecker, cr contractReviser, ipFilter *hostSet, logger *zap.SugaredLogger, remainingFunds *types.Currency) ([]api.ContractMetadata, map[types.FileContractID]string, error) {
+ var filteredContracts []api.ContractMetadata
+ keepContract := func(c api.ContractMetadata, h api.Host) {
+ filteredContracts = append(filteredContracts, c)
+ ipFilter.Add(h)
+ }
+ churnReasons := make(map[types.FileContractID]string)
+
+ // fetch all contracts we already have
+ logger.Info("fetching existing contracts")
+ start := time.Now()
+ resp, err := w.Contracts(ctx, timeoutHostRevision)
 if err != nil {
- c.logger.Errorf("failed to fetch alerts: %v", err)
- return false
+ return nil, nil, err
+ }
+ contracts := resp.Contracts
+ logger.With("elapsed", time.Since(start)).Info("done fetching existing contracts")
+
+ // log the reason for any missing revisions
+ for _, c := range contracts {
+ if c.Revision == nil {
+ logger.With("error", resp.Errors[c.HostKey]).
+ With("hostKey", c.HostKey).
+ With("contractID", c.ID).Debug("failed to fetch contract revision")
+ }
 }
- for _, alert := range ar.Alerts {
- if alert.ID == id {
- return true
+
+ // sort them by whether they are in the current set and their size
+ ctx.SortContractsForMaintenance(contracts)
+
+ // allow for a leeway of 10% of the required contracts for special cases
+ // such as failing to fetch a contract's revision
+ remainingLeeway := addLeeway(ctx.WantedContracts(), 1-leewayPctRequiredContracts)
+
+ // perform checks on contracts one-by-one renewing/refreshing
+ // contracts as necessary and filtering out contracts that should no
+ // longer be used
+ logger.With("contracts", len(contracts)).Info("checking existing contracts")
+ var renewed, refreshed int
+ for _, c := range contracts {
+ inSet := c.InSet(ctx.Set())
+
+ logger := logger.With("contractID", c.ID).
+ With("inSet", inSet).
+ With("hostKey", c.HostKey).
+ With("revisionNumber", c.RevisionNumber).
+ With("size", c.FileSize()).
+ With("state", c.State).
+ With("remainingLeeway", remainingLeeway).
+ With("revisionAvailable", c.Revision != nil).
+ With("filteredContracts", len(filteredContracts)).
+ With("wantedContracts", ctx.WantedContracts())
+
+ logger.Debug("checking contract")
+
+ // skip the contract if we already have enough contracts
+ if uint64(len(filteredContracts)) >= ctx.WantedContracts() {
+ churnReasons[c.ID] = "truncated"
+ logger.Debug("ignoring contract since we have enough contracts")
+ continue
+ }
+
+ // fetch recent consensus state
+ cs, err := bus.ConsensusState(ctx)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to fetch consensus state: %w", err)
 }
+ bh := cs.BlockHeight
+ logger = logger.With("blockHeight", bh)
+
+ // check if contract is ready to be archived.
+ if reason := cc.shouldArchive(c, bh); reason != nil {
+ if err := bus.ArchiveContracts(ctx, map[types.FileContractID]string{
+ c.ID: reason.Error(),
+ }); err != nil {
+ logger.With(zap.Error(err)).Error("failed to archive contract")
+ } else {
+ logger.Debug("successfully archived contract")
+ }
+ churnReasons[c.ID] = reason.Error()
+ continue
+ }
+
+ // fetch host
+ host, err := bus.Host(ctx, c.HostKey)
+ if err != nil {
+ logger.With(zap.Error(err)).Warn("missing host")
+ churnReasons[c.ID] = api.ErrUsabilityHostNotFound.Error()
+ continue
+ }
+
+ // extend logger
+ logger = logger.With("addresses", host.ResolvedAddresses). 
+ With("blocked", host.Blocked) + + // check if host is blocked + if host.Blocked { + logger.Info("host is blocked") + churnReasons[c.ID] = api.ErrUsabilityHostBlocked.Error() + continue + } + + // check if host has a redundant ip + if ctx.ShouldFilterRedundantIPs() && ipFilter.HasRedundantIP(host) { + logger.Info("host has redundant IP") + churnReasons[c.ID] = api.ErrUsabilityHostRedundantIP.Error() + continue + } + + // get check + check, ok := host.Checks[ctx.ApID()] + if !ok { + logger.Warn("missing host check") + churnReasons[c.ID] = api.ErrUsabilityHostNotFound.Error() + continue + } + + // check usability + if !check.Usability.IsUsable() { + reasons := strings.Join(check.Usability.UnusableReasons(), ",") + logger.With("reasons", reasons).Info("unusable host") + churnReasons[c.ID] = reasons + continue + } + + // check if revision is available + if c.Revision == nil { + if inSet && remainingLeeway > 0 { + logger.Debug("keeping contract due to leeway") + keepContract(c.ContractMetadata, host) + remainingLeeway-- + } else { + logger.Debug("ignoring contract without revision") + churnReasons[c.ID] = errContractNoRevision.Error() + } + continue // no more checks without revision + } + + // check if contract is usable + usable, needsRefresh, needsRenew, reasons := cc.isUsableContract(ctx.AutopilotConfig(), host.Settings, host.PriceTable.HostPriceTable, ctx.state.RS, c, inSet, bh, ipFilter) + + // extend logger + logger = logger.With("usable", usable). + With("needsRefresh", needsRefresh). + With("needsRenew", needsRenew). + With("reasons", reasons) + + // remember reason for potential drop of contract + if len(reasons) > 0 { + churnReasons[c.ID] = strings.Join(reasons, ",") + } + + contract := c.ContractMetadata + + // renew/refresh as necessary + var ourFault bool + if needsRenew { + var renewedContract api.ContractMetadata + renewedContract, ourFault, err = cr.renewContract(ctx, w, c, host, remainingFunds, logger) + if err != nil { + logger = logger.With(zap.Error(err)).With("ourFault", ourFault) + + // don't register an alert for hosts that are out of funds since the + // user can't do anything about it + if !(rhp3.IsErrHost(err) && utils.IsErr(err, wallet.ErrNotEnoughFunds)) { + alerter.RegisterAlert(ctx, newContractRenewalFailedAlert(contract, !ourFault, err)) + } + logger.Error("failed to renew contract") + } else { + logger.Info("successfully renewed contract") + alerter.DismissAlerts(ctx, alerts.IDForContract(alertRenewalFailedID, contract.ID)) + contract = renewedContract + usable = true + renewed++ + } + } else if needsRefresh { + var refreshedContract api.ContractMetadata + refreshedContract, ourFault, err = cr.refreshContract(ctx, w, c, host, remainingFunds, logger) + if err != nil { + logger = logger.With(zap.Error(err)).With("ourFault", ourFault) + + // don't register an alert for hosts that are out of funds since the + // user can't do anything about it + if !(rhp3.IsErrHost(err) && utils.IsErr(err, wallet.ErrNotEnoughFunds)) { + alerter.RegisterAlert(ctx, newContractRenewalFailedAlert(contract, !ourFault, err)) + } + logger.Error("failed to refresh contract") + } else { + logger.Info("successfully refreshed contract") + alerter.DismissAlerts(ctx, alerts.IDForContract(alertRenewalFailedID, contract.ID)) + contract = refreshedContract + usable = true + refreshed++ + } + } + + // if the renewal/refresh failing was our fault (e.g. 
we ran out of
+ // funds), we should not drop the contract
+ if !usable && ourFault {
+ logger.Info("keeping contract even though renewal/refresh failed")
+ usable = true
+ }
+
+ // if the contract is not usable we ignore it
+ if !usable {
+ if inSet {
+ logger.Info("contract is not usable, removing from set")
+ } else {
+ logger.Debug("contract is not usable, remains out of set")
+ }
+ continue
+ }
+
+ // we keep the contract, add the host to the filter
+ logger.Debug("contract is usable and is added / stays in set")
+ keepContract(contract, host)
 }
- return false
+ logger.With("refreshed", refreshed).
+ With("renewed", renewed).
+ With("filteredContracts", len(filteredContracts)).
+ Info("checking existing contracts done")
+ return filteredContracts, churnReasons, nil
 }
 
-func (c *Contractor) pruneContractRefreshFailures(contracts []api.Contract) {
- contractMap := make(map[types.FileContractID]struct{})
- for _, contract := range contracts {
- contractMap[contract.ID] = struct{}{}
+// performContractFormations forms up to 'wanted' new contracts with hosts. The
+// 'ipFilter' and 'remainingFunds' are updated with every new contract.
+func performContractFormations(ctx *mCtx, bus Bus, w Worker, cr contractReviser, ipFilter *hostSet, logger *zap.SugaredLogger, remainingFunds *types.Currency, wanted int) ([]api.ContractMetadata, error) {
+ var formedContracts []api.ContractMetadata
+ addContract := func(c api.ContractMetadata, h api.Host) {
+ formedContracts = append(formedContracts, c)
+ wanted--
+ ipFilter.Add(h)
 }
- for fcid := range c.firstRefreshFailure {
- if _, ok := contractMap[fcid]; !ok {
- delete(c.firstRefreshFailure, fcid)
+
+ // early check to avoid fetching all candidates
+ if wanted <= 0 {
+ logger.Info("already have enough contracts, no need to form new ones")
+ return formedContracts, nil // nothing to do
+ }
+ logger.With("wanted", wanted).Info("trying to form more contracts to fill set")
+
+ // get list of hosts that we already have contracts with
+ contracts, err := bus.Contracts(ctx, api.ContractsOpts{})
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch contracts: %w", err)
+ }
+ usedHosts := make(map[types.PublicKey]struct{})
+ for _, c := range contracts {
+ usedHosts[c.HostKey] = struct{}{}
+ }
+ allHosts, err := bus.SearchHosts(ctx, api.SearchHostOptions{
+ Limit: -1,
+ FilterMode: api.HostFilterModeAllowed,
+ UsabilityMode: api.UsabilityFilterModeAll,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch usable hosts: %w", err)
+ }
+
+ // filter them
+ var candidates scoredHosts
+ for _, host := range allHosts {
+ logger := logger.With("hostKey", host.PublicKey)
+ hc, ok := host.Checks[ctx.ApID()]
+ if !ok {
+ logger.Warn("missing host check")
+ continue
+ } else if _, used := usedHosts[host.PublicKey]; used {
+ logger.Debug("host already used")
+ continue
+ } else if score := hc.Score.Score(); score == 0 {
+ logger.Error("host has a score of 0")
+ continue
+ }
+ candidates = append(candidates, newScoredHost(host, hc.Score))
+ }
+ logger = logger.With("candidates", len(candidates))
+
+ // select candidates; since we already have all of them in memory we
+ // draw len(candidates), i.e. a score-weighted shuffle of the full list
+ candidates = candidates.randSelectByScore(len(candidates))
+ if len(candidates) < wanted {
+ logger.Warn("not enough candidates to form new contracts")
+ }
+
+ // calculate min/max contract funds
+ minInitialContractFunds, maxInitialContractFunds := initialContractFundingMinMax(ctx.AutopilotConfig())
+
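Editor's note (not part of the patch): the candidates.randSelectByScore(len(candidates)) call above turns the scored candidate list into a selection order; drawing all of them amounts to a score-weighted shuffle. The scoredHosts implementation is not included in this diff, so the following is only a sketch of how such a selection could work, with illustrative types and names:

package contractor

import "math/rand"

// candidateHost is an illustrative stand-in for the scoredHost type; only
// the field needed for selection is shown.
type candidateHost struct {
	score float64
}

// randSelectByScore draws n candidates without replacement, picking each
// host with probability proportional to its score. Called with
// n == len(hosts) it yields a score-weighted shuffle.
func randSelectByScore(hosts []candidateHost, n int) (out []candidateHost) {
	pool := append([]candidateHost(nil), hosts...)
	for len(out) < n && len(pool) > 0 {
		// total weight of the remaining pool
		var total float64
		for _, h := range pool {
			total += h.score
		}
		// pick a point in [0, total) and find the host it lands on
		r := rand.Float64() * total
		for i := range pool {
			r -= pool[i].score
			if r <= 0 || i == len(pool)-1 {
				out = append(out, pool[i])
				pool = append(pool[:i], pool[i+1:]...)
				break
			}
		}
	}
	return
}

This assumes all scores are positive, which the score == 0 filter above guarantees for the real candidate list.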
+ // form contracts until the new set has the desired size
+ for _, candidate := range candidates {
+ if wanted == 0 {
+ return formedContracts, nil // done
+ }
+
+ // stop if the autopilot is shutting down
+ select {
+ case <-ctx.Done():
+ return nil, context.Cause(ctx)
+ default:
+ }
+
+ // fetch a new price table if necessary
+ if err := refreshPriceTable(ctx, w, &candidate.host); err != nil {
+ logger.Warnf("failed to fetch price table for candidate host %v: %v", candidate.host.PublicKey, err)
+ continue
+ }
+
+ // prepare gouging checker
+ cs, err := bus.ConsensusState(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch consensus state: %w", err)
+ }
+ gc := ctx.GougingChecker(cs)
+
+ // prepare a logger
+ logger := logger.With("hostKey", candidate.host.PublicKey).
+ With("remainingBudget", remainingFunds).
+ With("addresses", candidate.host.ResolvedAddresses)
+
+ // perform gouging checks on the fly to ensure the host is not gouging its prices
+ if breakdown := gc.Check(nil, &candidate.host.PriceTable.HostPriceTable); breakdown.Gouging() {
+ logger.With("reasons", breakdown.String()).Info("candidate is price gouging")
+ continue
+ }
+
+ // check if we already have a contract with a host on that address
+ if ctx.ShouldFilterRedundantIPs() && ipFilter.HasRedundantIP(candidate.host) {
+ logger.Info("host has redundant IP")
+ continue
+ }
+
+ formedContract, proceed, err := cr.formContract(ctx, w, candidate.host, minInitialContractFunds, maxInitialContractFunds, remainingFunds, logger)
+ if err != nil {
+ logger.With(zap.Error(err)).Error("failed to form contract")
+ continue
 }
+ if !proceed {
+ logger.Error("not proceeding with contract formation")
+ break
+ }
+
+ // add new contract and host
+ addContract(formedContract, candidate.host)
 }
+ logger.With("formedContracts", len(formedContracts)).Info("done forming contracts")
+ return formedContracts, nil
 }
 
-func (c *Contractor) shouldForgiveFailedRefresh(fcid types.FileContractID) bool {
- lastFailure, exists := c.firstRefreshFailure[fcid]
- if !exists {
- lastFailure = time.Now()
- c.firstRefreshFailure[fcid] = lastFailure
+// performHostChecks performs scoring and usability checks on all hosts,
+// updating their state in the database.
+func performHostChecks(ctx *mCtx, bus Bus, logger *zap.SugaredLogger) error { + var usabilityBreakdown unusableHostsBreakdown + // fetch all hosts that are not blocked + hosts, err := bus.SearchHosts(ctx, api.SearchHostOptions{Limit: -1, FilterMode: api.HostFilterModeAllowed}) + if err != nil { + return fmt.Errorf("failed to fetch all hosts: %w", err) } - return time.Since(lastFailure) < failedRefreshForgivenessPeriod + + var scoredHosts []scoredHost + for _, host := range hosts { + // score host + sb, err := ctx.HostScore(host) + if err != nil { + logger.With(zap.Error(err)).Info("failed to score host") + continue + } + scoredHosts = append(scoredHosts, newScoredHost(host, sb)) + } + + // compute minimum score for usable hosts + minScore := calculateMinScore(scoredHosts, ctx.WantedContracts(), logger) + + // run host checks using the latest consensus state + cs, err := bus.ConsensusState(ctx) + if err != nil { + return fmt.Errorf("failed to fetch consensus state: %w", err) + } + for _, h := range scoredHosts { + h.host.PriceTable.HostBlockHeight = cs.BlockHeight // ignore HostBlockHeight + hc := checkHost(ctx.GougingChecker(cs), h, minScore) + if err := bus.UpdateHostCheck(ctx, ctx.ApID(), h.host.PublicKey, *hc); err != nil { + return fmt.Errorf("failed to update host check for host %v: %w", h.host.PublicKey, err) + } + usabilityBreakdown.track(hc.Usability) + + if !hc.Usability.IsUsable() { + logger.With("hostKey", h.host.PublicKey). + With("reasons", strings.Join(hc.Usability.UnusableReasons(), ",")). + Debug("host is not usable") + } + } + + logger.Infow("host checks completed", usabilityBreakdown.keysAndValues()...) + return nil +} + +func performPostMaintenanceTasks(ctx *mCtx, bus Bus, w Worker, alerter alerts.Alerter, cc contractChecker, rb revisionBroadcaster, logger *zap.SugaredLogger) error { + // fetch some contract and host info + allContracts, err := bus.Contracts(ctx, api.ContractsOpts{}) + if err != nil { + return fmt.Errorf("failed to fetch all contracts: %w", err) + } + setContracts, err := bus.Contracts(ctx, api.ContractsOpts{ContractSet: ctx.ContractSet()}) + if err != nil { + return fmt.Errorf("failed to fetch contracts: %w", err) + } + allHosts, err := bus.SearchHosts(ctx, api.SearchHostOptions{ + Limit: -1, + FilterMode: api.HostFilterModeAllowed, + UsabilityMode: api.UsabilityFilterModeAll, + }) + if err != nil { + return fmt.Errorf("failed to fetch all hosts: %w", err) + } + usedHosts := make(map[types.PublicKey]struct{}) + for _, c := range allContracts { + usedHosts[c.HostKey] = struct{}{} + } + + // run revision broadcast on contracts in the new set + rb.broadcastRevisions(ctx, w, setContracts, logger) + + // register alerts for used hosts with lost sectors + var toDismiss []types.Hash256 + for _, h := range allHosts { + if _, used := usedHosts[h.PublicKey]; !used { + continue + } else if registerLostSectorsAlert(h.Interactions.LostSectors*rhpv2.SectorSize, h.StoredData) { + alerter.RegisterAlert(ctx, newLostSectorsAlert(h.PublicKey, h.Settings.Version, h.Settings.Release, h.Interactions.LostSectors)) + } else { + toDismiss = append(toDismiss, alerts.IDForHost(alertLostSectorsID, h.PublicKey)) + } + } + if len(toDismiss) > 0 { + alerter.DismissAlerts(ctx, toDismiss...) 
+ } + + // prune refresh failures + cc.pruneContractRefreshFailures(allContracts) + return nil +} + +func performContractMaintenance(ctx *mCtx, alerter alerts.Alerter, bus Bus, churn *accumulatedChurn, w Worker, cc contractChecker, cr contractReviser, rb revisionBroadcaster, logger *zap.SugaredLogger) (bool, error) { + logger = logger.Named("performContractMaintenance"). + Named(hex.EncodeToString(frand.Bytes(16))). // uuid for this iteration + With("contractSet", ctx.ContractSet()) + + // check if we want to run maintenance + if reason, skip := canSkipContractMaintenance(ctx, ctx.ContractsConfig()); skip { + logger.With("reason", reason).Info("skipping contract maintenance") + return false, nil + } + + // compute the remaining budget for this period + remaining, err := remainingAllowance(ctx, bus, ctx.state) + if err != nil { + return false, fmt.Errorf("failed to compute remaining allowance: %w", err) + } + logger = logger.With("remainingAllowance", remaining) + + logger.Infow("performing contract maintenance") + + // STEP 1: perform host checks + if err := performHostChecks(ctx, bus, logger); err != nil { + return false, err + } + + // STEP 2: perform contract maintenance + ipFilter := &hostSet{ + logger: logger.Named("ipFilter"), + subnetToHostKey: make(map[string]string), + } + keptContracts, churnReasons, err := performContractChecks(ctx, alerter, bus, w, cc, cr, ipFilter, logger, &remaining) + if err != nil { + return false, err + } + + // STEP 3: perform contract formation + formedContracts, err := performContractFormations(ctx, bus, w, cr, ipFilter, logger, &remaining, int(ctx.WantedContracts())-len(keptContracts)) + if err != nil { + return false, err + } + + // fetch old set + oldSet, err := bus.Contracts(ctx, api.ContractsOpts{ContractSet: ctx.ContractSet()}) + if err != nil && !utils.IsErr(err, api.ErrContractSetNotFound) { + return false, fmt.Errorf("failed to fetch old contract set: %w", err) + } + + // STEP 4: update contract set + newSet := make([]api.ContractMetadata, 0, len(keptContracts)+len(formedContracts)) + newSet = append(newSet, keptContracts...) + newSet = append(newSet, formedContracts...) 
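Editor's note (not part of the patch): the IDs collected below replace the contract set wholesale, and STEP 6 then compares oldSet against newSet. computeContractSetChanged itself is not shown in this diff; the following is a minimal sketch of the kind of old-vs-new diff such a step presumably computes, using the real types.FileContractID type but an illustrative function name:

package contractor

import "go.sia.tech/core/types"

// setDiff reports which contract IDs joined and which left the set; the
// removed ones are those the churnReasons map collected above explains.
func setDiff(oldIDs, newIDs []types.FileContractID) (added, removed []types.FileContractID) {
	inNew := make(map[types.FileContractID]struct{}, len(newIDs))
	for _, id := range newIDs {
		inNew[id] = struct{}{}
	}
	inOld := make(map[types.FileContractID]struct{}, len(oldIDs))
	for _, id := range oldIDs {
		inOld[id] = struct{}{}
		// in the old set but not the new one: churned out
		if _, ok := inNew[id]; !ok {
			removed = append(removed, id)
		}
	}
	for _, id := range newIDs {
		// in the new set but not the old one: newly added
		if _, ok := inOld[id]; !ok {
			added = append(added, id)
		}
	}
	return
}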
+ var newSetIDs []types.FileContractID + for _, contract := range newSet { + newSetIDs = append(newSetIDs, contract.ID) + } + if err := bus.SetContractSet(ctx, ctx.ContractSet(), newSetIDs); err != nil { + return false, fmt.Errorf("failed to update contract set: %w", err) + } + + // STEP 5: perform minor maintenance such as cleanups and broadcasting + // revisions + if err := performPostMaintenanceTasks(ctx, bus, w, alerter, cc, rb, logger); err != nil { + return false, err + } + + // STEP 6: log changes and register alerts + return computeContractSetChanged(ctx, alerter, bus, churn, logger, oldSet, newSet, churnReasons) } diff --git a/autopilot/contractor/contractor_test.go b/autopilot/contractor/contractor_test.go index 245e19f7b..450433a45 100644 --- a/autopilot/contractor/contractor_test.go +++ b/autopilot/contractor/contractor_test.go @@ -12,29 +12,25 @@ import ( ) func TestCalculateMinScore(t *testing.T) { - c := &Contractor{ - logger: zap.NewNop().Sugar(), - } - var candidates []scoredHost for i := 0; i < 250; i++ { candidates = append(candidates, scoredHost{score: float64(i + 1)}) } // Test with 100 hosts which makes for a random set size of 250 - minScore := c.calculateMinScore(candidates, 100) + minScore := calculateMinScore(candidates, 100, zap.NewNop().Sugar()) if minScore != 0.002 { t.Fatalf("expected minScore to be 0.002 but was %v", minScore) } // Test with 0 hosts - minScore = c.calculateMinScore([]scoredHost{}, 100) + minScore = calculateMinScore([]scoredHost{}, 100, zap.NewNop().Sugar()) if minScore != math.SmallestNonzeroFloat64 { t.Fatalf("expected minScore to be math.SmallestNonzeroFLoat64 but was %v", minScore) } // Test with 300 hosts which is 50 more than we have - minScore = c.calculateMinScore(candidates, 300) + minScore = calculateMinScore(candidates, 300, zap.NewNop().Sugar()) if minScore != math.SmallestNonzeroFloat64 { t.Fatalf("expected minScore to be math.SmallestNonzeroFLoat64 but was %v", minScore) } @@ -115,7 +111,7 @@ func TestShouldForgiveFailedRenewal(t *testing.T) { } // prune map - c.pruneContractRefreshFailures([]api.Contract{}) + c.pruneContractRefreshFailures([]api.ContractMetadata{}) if len(c.firstRefreshFailure) != 0 { t.Fatal("expected no failures") } diff --git a/autopilot/contractor/evaluate.go b/autopilot/contractor/evaluate.go index b40cc3be6..e947009cb 100644 --- a/autopilot/contractor/evaluate.go +++ b/autopilot/contractor/evaluate.go @@ -5,15 +5,15 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/api" - "go.sia.tech/renterd/worker" + "go.sia.tech/renterd/internal/gouging" ) var ErrMissingRequiredFields = errors.New("missing required fields in configuration, both allowance and amount must be set") func countUsableHosts(cfg api.AutopilotConfig, cs api.ConsensusState, fee types.Currency, period uint64, rs api.RedundancySettings, gs api.GougingSettings, hosts []api.Host) (usables uint64) { - gc := worker.NewGougingChecker(gs, cs, fee, period, cfg.Contracts.RenewWindow) + gc := gouging.NewChecker(gs, cs, fee, &period, &cfg.Contracts.RenewWindow) for _, host := range hosts { - hc := checkHost(cfg, rs, gc, host, minValidScore) + hc := checkHost(gc, scoreHost(host, cfg, rs.Redundancy()), minValidScore) if hc.Usability.IsUsable() { usables++ } @@ -31,12 +31,12 @@ func EvaluateConfig(cfg api.AutopilotConfig, cs api.ConsensusState, fee types.Cu } period := cfg.Contracts.Period - gc := worker.NewGougingChecker(gs, cs, fee, period, cfg.Contracts.RenewWindow) + gc := gouging.NewChecker(gs, cs, fee, &period, &cfg.Contracts.RenewWindow) 
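Editor's note (not part of the patch): gouging.NewChecker takes the period and renew window as pointers where the old worker.NewGougingChecker took values, which suggests both are now optional inputs. A hedged sketch of the two call shapes, wrapped in an illustrative helper; whether nil is actually accepted is an assumption, not something this diff shows:

package contractor

import (
	"go.sia.tech/core/types"
	"go.sia.tech/renterd/api"
	"go.sia.tech/renterd/internal/gouging"
)

// newCheckers shows both call shapes side by side (illustrative helper).
func newCheckers(gs api.GougingSettings, cs api.ConsensusState, fee types.Currency, cfg api.AutopilotConfig) (known, unknown gouging.Checker) {
	// a caller that knows the period passes pointers to copies, exactly as
	// EvaluateConfig above and mCtx.GougingChecker later in this patch do
	period, renewWindow := cfg.Contracts.Period, cfg.Contracts.RenewWindow
	known = gouging.NewChecker(gs, cs, fee, &period, &renewWindow)

	// a caller without a configured period could presumably pass nil
	// (assumption; no nil call site appears in this diff)
	unknown = gouging.NewChecker(gs, cs, fee, nil, nil)
	return
}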
resp.Hosts = uint64(len(hosts)) for i, host := range hosts { hosts[i].PriceTable.HostBlockHeight = cs.BlockHeight // ignore block height - hc := checkHost(cfg, rs, gc, host, minValidScore) + hc := checkHost(gc, scoreHost(host, cfg, rs.Redundancy()), minValidScore) if hc.Usability.IsUsable() { resp.Usable++ continue diff --git a/autopilot/contractor/hostfilter.go b/autopilot/contractor/hostfilter.go index 9298ab009..3083976ed 100644 --- a/autopilot/contractor/hostfilter.go +++ b/autopilot/contractor/hostfilter.go @@ -10,10 +10,14 @@ import ( rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" "go.sia.tech/renterd/api" - "go.sia.tech/renterd/worker" + "go.sia.tech/renterd/internal/gouging" ) const ( + // ContractConfirmationDeadline is the number of blocks since its start + // height we wait for a contract to appear on chain. + ContractConfirmationDeadline = 18 + // minContractFundUploadThreshold is the percentage of contract funds // remaining at which the contract gets marked as not good for upload minContractFundUploadThreshold = float64(0.05) // 5% @@ -23,10 +27,6 @@ const ( // acquirable storage below which the contract is considered to be // out-of-collateral. minContractCollateralDenominator = 20 // 5% - - // contractConfirmationDeadline is the number of blocks since its start - // height we wait for a contract to appear on chain. - contractConfirmationDeadline = 18 ) var ( @@ -102,54 +102,38 @@ func (u *unusableHostsBreakdown) keysAndValues() []interface{} { // - recoverable -> can be usable in the contract set if it is refreshed/renewed // - refresh -> should be refreshed // - renew -> should be renewed -func (c *Contractor) isUsableContract(cfg api.AutopilotConfig, rs api.RedundancySettings, ci contractInfo, inSet bool, bh uint64, f *ipFilter) (usable, recoverable, refresh, renew bool, reasons []string) { - contract, s, pt := ci.contract, ci.host.Settings, ci.host.PriceTable.HostPriceTable - +func (c *Contractor) isUsableContract(cfg api.AutopilotConfig, s rhpv2.HostSettings, pt rhpv3.HostPriceTable, rs api.RedundancySettings, contract api.Contract, inSet bool, bh uint64, f *hostSet) (usable, refresh, renew bool, reasons []string) { usable = true if bh > contract.EndHeight() { reasons = append(reasons, errContractExpired.Error()) usable = false - recoverable = false refresh = false renew = false } else if contract.Revision.RevisionNumber == math.MaxUint64 { reasons = append(reasons, errContractMaxRevisionNumber.Error()) usable = false - recoverable = false refresh = false renew = false } else { if isOutOfCollateral(cfg, rs, contract, s, pt) { reasons = append(reasons, errContractOutOfCollateral.Error()) usable = usable && inSet && c.shouldForgiveFailedRefresh(contract.ID) - recoverable = !usable // only needs to be recoverable if !usable refresh = true renew = false } if isOutOfFunds(cfg, pt, contract) { reasons = append(reasons, errContractOutOfFunds.Error()) usable = usable && inSet && c.shouldForgiveFailedRefresh(contract.ID) - recoverable = !usable // only needs to be recoverable if !usable refresh = true renew = false } if shouldRenew, secondHalf := isUpForRenewal(cfg, *contract.Revision, bh); shouldRenew { reasons = append(reasons, fmt.Errorf("%w; second half: %t", errContractUpForRenewal, secondHalf).Error()) usable = usable && !secondHalf // only unusable if in second half of renew window - recoverable = true refresh = false renew = true } } - - // IP check should be last since it modifies the filter - shouldFilter := !cfg.Hosts.AllowRedundantIPs && (usable || 
recoverable) - if shouldFilter && f.HasRedundantIP(ci.host) { - reasons = append(reasons, api.ErrUsabilityHostRedundantIP.Error()) - usable = false - recoverable = false // do not use in the contract set, but keep it around for downloads - renew = false // do not renew, but allow refreshes so the contracts stays funded - } return } @@ -236,14 +220,11 @@ func isUpForRenewal(cfg api.AutopilotConfig, r types.FileContractRevision, block } // checkHost performs a series of checks on the host. -func checkHost(cfg api.AutopilotConfig, rs api.RedundancySettings, gc worker.GougingChecker, h api.Host, minScore float64) *api.HostCheck { - if rs.Validate() != nil { - panic("invalid redundancy settings were supplied - developer error") - } +func checkHost(gc gouging.Checker, sh scoredHost, minScore float64) *api.HostCheck { + h := sh.host // prepare host breakdown fields var gb api.HostGougingBreakdown - var sb api.HostScoreBreakdown var ub api.HostUsabilityBreakdown // blocked status does not influence what host info is calculated @@ -267,27 +248,30 @@ func checkHost(cfg api.AutopilotConfig, rs api.RedundancySettings, gc worker.Gou ub.NotAcceptingContracts = true } - // perform gouging checks + // perform gouging and score checks gb = gc.Check(&h.Settings, &h.PriceTable.HostPriceTable) if gb.Gouging() { ub.Gouging = true - } else if minScore > 0 { - // perform scoring checks - // - // NOTE: only perform these scoring checks if we know the host is - // not gouging, this because the core package does not have overflow - // checks in its cost calculations needed to calculate the period - // cost - sb = hostScore(cfg, h, rs.Redundancy()) - if sb.Score() < minScore { - ub.LowScore = true - } + } else if minScore > 0 && !(sh.score > minScore) { + ub.LowScore = true } } return &api.HostCheck{ Usability: ub, Gouging: gb, - Score: sb, + Score: sh.sb, } } + +func newScoredHost(h api.Host, sb api.HostScoreBreakdown) scoredHost { + return scoredHost{ + host: h, + sb: sb, + score: sb.Score(), + } +} + +func scoreHost(h api.Host, cfg api.AutopilotConfig, expectedRedundancy float64) scoredHost { + return newScoredHost(h, hostScore(cfg, h, expectedRedundancy)) +} diff --git a/autopilot/contractor/hostscore.go b/autopilot/contractor/hostscore.go index 68abf1b21..4535899fd 100644 --- a/autopilot/contractor/hostscore.go +++ b/autopilot/contractor/hostscore.go @@ -9,7 +9,7 @@ import ( rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" "go.sia.tech/renterd/api" - "go.sia.tech/siad/build" + "go.sia.tech/renterd/internal/utils" ) const ( @@ -22,7 +22,7 @@ const ( minValidScore = math.SmallestNonzeroFloat64 ) -func hostScore(cfg api.AutopilotConfig, h api.Host, expectedRedundancy float64) api.HostScoreBreakdown { +func hostScore(cfg api.AutopilotConfig, h api.Host, expectedRedundancy float64) (sb api.HostScoreBreakdown) { cCfg := cfg.Contracts // idealDataPerHost is the amount of data that we would have to put on each // host assuming that our storage requirements were spread evenly across @@ -273,7 +273,7 @@ func versionScore(settings rhpv2.HostSettings, minVersion string) float64 { } weight := 1.0 for _, v := range versions { - if build.VersionCmp(settings.Version, v.version) < 0 { + if utils.VersionCmp(settings.Version, v.version) < 0 { weight *= v.penalty } } diff --git a/autopilot/contractor/hostscore_test.go b/autopilot/contractor/hostscore_test.go index 41347ec72..9b2cdea47 100644 --- a/autopilot/contractor/hostscore_test.go +++ b/autopilot/contractor/hostscore_test.go @@ -19,9 +19,9 @@ var cfg = 
api.AutopilotConfig{
 Period: 144 * 7 * 6,
 RenewWindow: 144 * 7 * 2,
- Download: 1 << 40, // 1 TiB
- Upload: 1 << 40, // 1 TiB
- Storage: 1 << 42, // 4 TiB
+ Download: 1e12, // 1 TB
+ Upload: 1e12, // 1 TB
+ Storage: 4e12, // 4 TB
 Set: api.DefaultAutopilotID,
 },
diff --git a/autopilot/contractor/hostset.go b/autopilot/contractor/hostset.go
new file mode 100644
index 000000000..0ae4ac6c2
--- /dev/null
+++ b/autopilot/contractor/hostset.go
@@ -0,0 +1,72 @@
+package contractor
+
+import (
+ "context"
+ "errors"
+ "time"
+
+ "go.sia.tech/renterd/api"
+ "go.sia.tech/renterd/internal/utils"
+ "go.uber.org/zap"
+)
+
+var (
+ errHostTooManySubnets = errors.New("host has more than two subnets")
+)
+
+type (
+ hostSet struct {
+ subnetToHostKey map[string]string
+
+ logger *zap.SugaredLogger
+ }
+)
+
+func (hs *hostSet) HasRedundantIP(host api.Host) bool {
+ // compat code for hosts that have been scanned before ResolvedAddresses
+ // were introduced
+ if len(host.ResolvedAddresses) == 0 {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ host.ResolvedAddresses, _, _ = utils.ResolveHostIP(ctx, host.NetAddress)
+ }
+
+ subnets, err := utils.AddressesToSubnets(host.ResolvedAddresses)
+ if err != nil {
+ hs.logger.Errorf("failed to parse host %v subnets: %v", host.PublicKey, err)
+ return true
+ }
+ // validate host subnets
+ if len(subnets) == 0 {
+ hs.logger.Errorf("host %v has no subnet, treating its IP %v as redundant", host.PublicKey, host.NetAddress)
+ return true
+ } else if len(subnets) > 2 {
+ hs.logger.Errorf("host %v has more than 2 subnets, treating its IP %v as redundant: %v", host.PublicKey, host.NetAddress, errHostTooManySubnets)
+ return true
+ }
+
+ // check if we know about this subnet
+ var knownHost string
+ for _, subnet := range subnets {
+ if knownHost = hs.subnetToHostKey[subnet]; knownHost != "" {
+ break
+ }
+ }
+
+ // if we know about the subnet, the host is redundant if it's not the same
+ if knownHost != "" {
+ return host.PublicKey.String() != knownHost
+ }
+ return false
+}
+
+func (hs *hostSet) Add(host api.Host) {
+ subnets, err := utils.AddressesToSubnets(host.ResolvedAddresses)
+ if err != nil {
+ hs.logger.Errorf("failed to parse host %v subnets: %v", host.PublicKey, err)
+ return
+ }
+ for _, subnet := range subnets {
+ hs.subnetToHostKey[subnet] = host.PublicKey.String()
+ }
+}
diff --git a/autopilot/contractor/hostset_test.go b/autopilot/contractor/hostset_test.go
new file mode 100644
index 000000000..fbd1bb2e2
--- /dev/null
+++ b/autopilot/contractor/hostset_test.go
@@ -0,0 +1,77 @@
+package contractor
+
+import (
+ "testing"
+
+ "go.sia.tech/core/types"
+ "go.sia.tech/renterd/api"
+ "go.uber.org/zap"
+)
+
+func TestHostSet(t *testing.T) {
+ hs := &hostSet{
+ subnetToHostKey: make(map[string]string),
+ logger: zap.NewNop().Sugar(),
+ }
+
+ // Host with no subnets
+ host1 := api.Host{
+ PublicKey: types.GeneratePrivateKey().PublicKey(),
+ ResolvedAddresses: []string{},
+ }
+ if !hs.HasRedundantIP(host1) {
+ t.Fatalf("Expected host with no subnets to be considered redundant")
+ }
+
+ // Host with more than 2 subnets
+ host2 := api.Host{
+ PublicKey: types.GeneratePrivateKey().PublicKey(),
+ ResolvedAddresses: []string{"192.168.1.1", "10.0.0.1", "172.16.0.1"},
+ }
+ if !hs.HasRedundantIP(host2) {
+ t.Fatalf("Expected host with more than 2 subnets to be considered redundant")
+ }
+
+ // New host with unique subnet
+ host3 := api.Host{
+ PublicKey: types.GeneratePrivateKey().PublicKey(),
+ ResolvedAddresses: []string{"192.168.2.1"},
+ }
+ if 
hs.HasRedundantIP(host3) { + t.Fatal("Expected new host with unique subnet to not be considered redundant") + } + hs.Add(host3) + + // New host with same subnet but different public key + host4 := api.Host{ + PublicKey: types.GeneratePrivateKey().PublicKey(), + ResolvedAddresses: []string{"192.168.2.1"}, + } + if !hs.HasRedundantIP(host4) { + t.Fatal("Expected host with same subnet but different public key to be considered redundant") + } + + // Same host from before + if hs.HasRedundantIP(host3) { + t.Fatal("Expected same host to not be considered redundant") + } + + // Host with two valid subnets + host5 := api.Host{ + PublicKey: types.GeneratePrivateKey().PublicKey(), + ResolvedAddresses: []string{"192.168.3.1", "10.0.0.1"}, + } + if hs.HasRedundantIP(host5) { + t.Fatal("Expected host with two valid subnets to not be considered redundant") + } + hs.Add(host5) + + // New host with one overlapping subnet + host6 := api.Host{ + PublicKey: types.GeneratePrivateKey().PublicKey(), + ResolvedAddresses: []string{"10.0.0.1", "172.16.0.1"}, + } + if !hs.HasRedundantIP(host6) { + t.Fatal("Expected host with one overlapping subnet to be considered redundant") + } +} diff --git a/autopilot/contractor/ipfilter.go b/autopilot/contractor/ipfilter.go deleted file mode 100644 index b29668372..000000000 --- a/autopilot/contractor/ipfilter.go +++ /dev/null @@ -1,66 +0,0 @@ -package contractor - -import ( - "errors" - - "go.sia.tech/renterd/api" - "go.uber.org/zap" -) - -var ( - errHostTooManySubnets = errors.New("host has more than two subnets") -) - -type ( - ipFilter struct { - subnetToHostKey map[string]string - - logger *zap.SugaredLogger - } -) - -func (c *Contractor) newIPFilter() *ipFilter { - return &ipFilter{ - logger: c.logger, - subnetToHostKey: make(map[string]string), - } -} - -func (f *ipFilter) HasRedundantIP(host api.Host) bool { - // validate host subnets - if len(host.Subnets) == 0 { - f.logger.Errorf("host %v has no subnet, treating its IP %v as redundant", host.PublicKey, host.NetAddress) - return true - } else if len(host.Subnets) > 2 { - f.logger.Errorf("host %v has more than 2 subnets, treating its IP %v as redundant", host.PublicKey, errHostTooManySubnets) - return true - } - - // check if we know about this subnet - var knownHost string - for _, subnet := range host.Subnets { - if knownHost = f.subnetToHostKey[subnet]; knownHost != "" { - break - } - } - - // if we know about the subnet, the host is redundant if it's not the same - if knownHost != "" { - return host.PublicKey.String() != knownHost - } - - // otherwise register all the host'ssubnets - for _, subnet := range host.Subnets { - f.subnetToHostKey[subnet] = host.PublicKey.String() - } - - return false -} - -func (f *ipFilter) Remove(h api.Host) { - for k, v := range f.subnetToHostKey { - if v == h.PublicKey.String() { - delete(f.subnetToHostKey, k) - } - } -} diff --git a/autopilot/contractor/state.go b/autopilot/contractor/state.go index 9f06fe168..0b786adf1 100644 --- a/autopilot/contractor/state.go +++ b/autopilot/contractor/state.go @@ -2,11 +2,12 @@ package contractor import ( "context" + "errors" "time" "go.sia.tech/core/types" "go.sia.tech/renterd/api" - "go.sia.tech/renterd/worker" + "go.sia.tech/renterd/internal/gouging" ) type ( @@ -72,8 +73,19 @@ func (ctx *mCtx) Err() error { return ctx.ctx.Err() } -func (ctx *mCtx) GougingChecker(cs api.ConsensusState) worker.GougingChecker { - return worker.NewGougingChecker(ctx.state.GS, cs, ctx.state.Fee, ctx.Period(), ctx.RenewWindow()) +func (ctx *mCtx) GougingChecker(cs 
api.ConsensusState) gouging.Checker { + period, renewWindow := ctx.Period(), ctx.RenewWindow() + return gouging.NewChecker(ctx.state.GS, cs, ctx.state.Fee, &period, &renewWindow) +} + +func (ctx *mCtx) HostScore(h api.Host) (sb api.HostScoreBreakdown, err error) { + // host settings that cause a panic should result in a score of 0 + defer func() { + if r := recover(); r != nil { + err = errors.New("panic while scoring host") + } + }() + return hostScore(ctx.state.AP.Config, h, ctx.state.RS.Redundancy()), nil } func (ctx *mCtx) Period() uint64 { @@ -96,6 +108,14 @@ func (ctx *mCtx) WantedContracts() uint64 { return ctx.state.AP.Config.Contracts.Amount } +func (ctx *mCtx) Set() string { + return ctx.state.ContractsConfig().Set +} + +func (ctx *mCtx) SortContractsForMaintenance(contracts []api.Contract) { + ctx.state.ContractsConfig().SortContractsForMaintenance(contracts) +} + func (state *MaintenanceState) Allowance() types.Currency { return state.AP.Config.Contracts.Allowance } diff --git a/autopilot/migrator.go b/autopilot/migrator.go index 40599dea2..fd935cabb 100644 --- a/autopilot/migrator.go +++ b/autopilot/migrator.go @@ -13,12 +13,15 @@ import ( "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" - "go.sia.tech/renterd/stats" "go.uber.org/zap" ) const ( migratorBatchSize = math.MaxInt // TODO: change once we have a fix for the infinite loop + + // migrationAlertRegisterInterval is the interval at which we update the + // ongoing migrations alert to indicate progress + migrationAlertRegisterInterval = 30 * time.Second ) type ( @@ -27,8 +30,9 @@ type ( logger *zap.SugaredLogger healthCutoff float64 parallelSlabsPerWorker uint64 + signalConsensusNotSynced chan struct{} signalMaintenanceFinished chan struct{} - statsSlabMigrationSpeedMS *stats.DataPoints + statsSlabMigrationSpeedMS *utils.DataPoints mu sync.Mutex migrating bool @@ -67,8 +71,9 @@ func newMigrator(ap *Autopilot, healthCutoff float64, parallelSlabsPerWorker uin logger: ap.logger.Named("migrator"), healthCutoff: healthCutoff, parallelSlabsPerWorker: parallelSlabsPerWorker, + signalConsensusNotSynced: make(chan struct{}, 1), signalMaintenanceFinished: make(chan struct{}, 1), - statsSlabMigrationSpeedMS: stats.New(time.Hour), + statsSlabMigrationSpeedMS: utils.NewDataPoints(time.Hour), } } @@ -146,7 +151,7 @@ func (m *migrator) performMigrations(p *workerPool) { // fetch worker id once id, err := w.ID(ctx) if err != nil { - m.logger.Errorf("failed to fetch worker id: %v", err) + m.logger.Errorf("failed to reach worker, err: %v", err) return } @@ -157,8 +162,14 @@ func (m *migrator) performMigrations(p *workerPool) { m.statsSlabMigrationSpeedMS.Track(float64(time.Since(start).Milliseconds())) if err != nil { m.logger.Errorf("%v: migration %d/%d failed, key: %v, health: %v, overpaid: %v, err: %v", id, j.slabIdx+1, j.batchSize, j.Key, j.Health, res.SurchargeApplied, err) - skipAlert := utils.IsErr(err, api.ErrSlabNotFound) - if !skipAlert { + if utils.IsErr(err, api.ErrConsensusNotSynced) { + // interrupt migrations if consensus is not synced + select { + case m.signalConsensusNotSynced <- struct{}{}: + default: + } + return + } else if !utils.IsErr(err, api.ErrSlabNotFound) { // fetch all object IDs for the slab we failed to migrate var objectIds map[string][]string if res, err := m.objectIDsForSlabKey(ctx, j.Key); err != nil { @@ -198,29 +209,20 @@ func (m *migrator) performMigrations(p *workerPool) { default: } -OUTER: - for { - // fetch currently configured set - autopilot, err := 
m.ap.Config(m.ap.shutdownCtx)
- if err != nil {
- m.logger.Errorf("failed to fetch autopilot config: %w", err)
- return
- }
- set := autopilot.Config.Contracts.Set
- if set == "" {
- m.logger.Error("could not perform migrations, no contract set configured")
- return
- }
-
- // recompute health.
- start := time.Now()
- if err := b.RefreshHealth(m.ap.shutdownCtx); err != nil {
- m.ap.RegisterAlert(m.ap.shutdownCtx, newRefreshHealthFailedAlert(err))
- m.logger.Errorf("failed to recompute cached health before migration: %v", err)
- return
- }
- m.logger.Infof("recomputed slab health in %v", time.Since(start))
+ // fetch currently configured set
+ autopilot, err := m.ap.Config(m.ap.shutdownCtx)
+ if err != nil {
+ m.logger.Errorf("failed to fetch autopilot config: %v", err)
+ return
+ }
+ set := autopilot.Config.Contracts.Set
+ if set == "" {
+ m.logger.Error("could not perform migrations, no contract set configured")
+ return
+ }
+
+ // helper to update 'toMigrate'
+ updateToMigrate := func() {
 // fetch slabs for migration
 toMigrateNew, err := b.SlabsForMigration(m.ap.shutdownCtx, m.healthCutoff, set, migratorBatchSize)
 if err != nil {
@@ -259,25 +261,46 @@ OUTER:
 sort.Slice(newSlabs, func(i, j int) bool {
 return newSlabs[i].Health < newSlabs[j].Health
 })
- migrateNewMap = nil // free map
+ }
 
- // log the updated list of slabs to migrate
- m.logger.Infof("%d slabs to migrate", len(toMigrate))
+ // unregister the migration alert when we're done
+ defer m.ap.alerts.DismissAlerts(m.ap.shutdownCtx, alertMigrationID)
 
- // register an alert to notify users about ongoing migrations.
- if len(toMigrate) > 0 {
- m.ap.RegisterAlert(m.ap.shutdownCtx, newOngoingMigrationsAlert(len(toMigrate), m.slabMigrationEstimate(len(toMigrate))))
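Editor's note (not part of the patch): the signalConsensusNotSynced channel added earlier in this file is buffered with capacity 1, written with a non-blocking send, and drained by the select in the loop below. A minimal self-contained sketch of this signal-coalescing pattern:

package main

import "fmt"

// notify coalesces signals: if one is already pending in the buffered
// channel, the new one is dropped instead of blocking the sender.
func notify(ch chan struct{}) {
	select {
	case ch <- struct{}{}:
	default:
	}
}

// pending reports whether a signal is waiting, without blocking.
func pending(ch chan struct{}) bool {
	select {
	case <-ch:
		return true
	default:
		return false
	}
}

func main() {
	ch := make(chan struct{}, 1)
	notify(ch)
	notify(ch) // coalesced; does not block
	fmt.Println(pending(ch), pending(ch)) // true false
}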
+OUTER:
+ for {
+ // recompute health.
+ start := time.Now()
+ if err := b.RefreshHealth(m.ap.shutdownCtx); err != nil {
+ m.ap.RegisterAlert(m.ap.shutdownCtx, newRefreshHealthFailedAlert(err))
+ m.logger.Errorf("failed to recompute cached health before migration: %v", err)
+ } else {
+ m.ap.DismissAlert(m.ap.shutdownCtx, alertHealthRefreshID)
+ m.logger.Infof("recomputed slab health in %v", time.Since(start))
+ updateToMigrate()
+ }
+
+ // log the updated list of slabs to migrate
+ m.logger.Infof("%d slabs to migrate", len(toMigrate))
+
 // return if there are no slabs to migrate
 if len(toMigrate) == 0 {
 return
 }
 
+ var lastRegister time.Time
 for i, slab := range toMigrate {
+ if time.Since(lastRegister) > migrationAlertRegisterInterval {
+ // register an alert to notify users about ongoing migrations
+ remaining := len(toMigrate) - i
+ m.ap.RegisterAlert(m.ap.shutdownCtx, newOngoingMigrationsAlert(remaining, m.slabMigrationEstimate(remaining)))
+ lastRegister = time.Now()
+ }
 select {
 case <-m.ap.shutdownCtx.Done():
 return
+ case <-m.signalConsensusNotSynced:
+ m.logger.Info("migrations interrupted - consensus is not synced")
+ return
 case <-m.signalMaintenanceFinished:
 m.logger.Info("migrations interrupted - updating slabs for migration")
 continue OUTER
@@ -285,6 +308,7 @@ OUTER:
 }
 }
 
+ // all slabs migrated
 return
 }
}
diff --git a/autopilot/percentile.go b/autopilot/percentile.go
deleted file mode 100644
index eb3526b9a..000000000
--- a/autopilot/percentile.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright (c) 2014-2020 Montana Flynn (https://montanaflynn.com)
-package autopilot
-
-import (
- "errors"
- "math"
- "sort"
-)
-
-var (
- errEmptyInput = errors.New("input must not be empty")
- errOutOfBounds = errors.New("input is outside of range")
-)
-
-func percentile(input []float64, percent float64) (float64, error) {
- // validate input
- if len(input) == 0 {
- return math.NaN(), errEmptyInput
- }
- if percent <= 0 || percent > 100 {
- return math.NaN(), errOutOfBounds
- }
-
- // return early if we only have one
- if len(input) == 1 {
- return input[0], nil
- }
-
- // deep copy the input and sort
- input = append([]float64{}, input...)
- sort.Float64s(input) - - // multiply percent by length of input - index := (percent / 100) * float64(len(input)) - - // check if the index is a whole number, if so return that input - if index == float64(int64(index)) { - i := int(index) - return input[i-1], nil - } - - // if the index is greater than one, return the average of the index and the value prior - if index > 1 { - i := int(index) - avg := (input[i-1] + input[i]) / 2 - return avg, nil - } - - return math.NaN(), errOutOfBounds -} diff --git a/autopilot/scanner.go b/autopilot/scanner.go deleted file mode 100644 index fa317fafa..000000000 --- a/autopilot/scanner.go +++ /dev/null @@ -1,339 +0,0 @@ -package autopilot - -import ( - "context" - "errors" - "strings" - "sync" - "sync/atomic" - "time" - - rhpv2 "go.sia.tech/core/rhp/v2" - "go.sia.tech/core/types" - "go.sia.tech/renterd/api" - "go.sia.tech/renterd/internal/utils" - "go.uber.org/zap" -) - -const ( - scannerTimeoutInterval = 10 * time.Minute - scannerTimeoutMinTimeout = 10 * time.Second - - trackerMinDataPoints = 25 - trackerNumDataPoints = 1000 - trackerTimeoutPercentile = 99 -) - -type ( - scanner struct { - // TODO: use the actual bus and worker interfaces when they've consolidated - // a bit, we currently use inline interfaces to avoid having to update the - // scanner tests with every interface change - bus interface { - SearchHosts(ctx context.Context, opts api.SearchHostOptions) ([]api.Host, error) - HostsForScanning(ctx context.Context, opts api.HostsForScanningOptions) ([]api.HostAddress, error) - RemoveOfflineHosts(ctx context.Context, minRecentScanFailures uint64, maxDowntime time.Duration) (uint64, error) - } - - tracker *tracker - logger *zap.SugaredLogger - ap *Autopilot - wg sync.WaitGroup - - scanBatchSize uint64 - scanThreads uint64 - scanMinInterval time.Duration - - timeoutMinInterval time.Duration - timeoutMinTimeout time.Duration - - mu sync.Mutex - scanning bool - scanningLastStart time.Time - timeout time.Duration - timeoutLastUpdate time.Time - interruptScanChan chan struct{} - } - scanWorker interface { - RHPScan(ctx context.Context, hostKey types.PublicKey, hostIP string, timeout time.Duration) (api.RHPScanResponse, error) - } - - scanReq struct { - hostKey types.PublicKey - hostIP string - } - - scanResp struct { - hostKey types.PublicKey - settings rhpv2.HostSettings - err error - } - - tracker struct { - threshold uint64 - percentile float64 - - mu sync.Mutex - count uint64 - timings []float64 - } -) - -func newTracker(threshold, total uint64, percentile float64) *tracker { - return &tracker{ - threshold: threshold, - percentile: percentile, - timings: make([]float64, total), - } -} - -func (t *tracker) addDataPoint(duration time.Duration) { - if duration == 0 { - return - } - - t.mu.Lock() - defer t.mu.Unlock() - - t.timings[t.count%uint64(len(t.timings))] = float64(duration.Milliseconds()) - - // NOTE: we silently overflow and disregard the threshold being reapplied - // when we overflow entirely, since we only ever increment the count with 1 - // it will never happen - t.count += 1 -} - -func (t *tracker) timeout() time.Duration { - t.mu.Lock() - defer t.mu.Unlock() - if t.count < uint64(t.threshold) { - return 0 - } - - percentile, err := percentile(t.timings, t.percentile) - if err != nil { - return 0 - } - - return time.Duration(percentile) * time.Millisecond -} - -func newScanner(ap *Autopilot, scanBatchSize, scanThreads uint64, scanMinInterval, timeoutMinInterval, timeoutMinTimeout time.Duration) (*scanner, error) { - if scanBatchSize 
== 0 { - return nil, errors.New("scanner batch size has to be greater than zero") - } - if scanThreads == 0 { - return nil, errors.New("scanner threads has to be greater than zero") - } - - return &scanner{ - bus: ap.bus, - tracker: newTracker( - trackerMinDataPoints, - trackerNumDataPoints, - trackerTimeoutPercentile, - ), - logger: ap.logger.Named("scanner"), - ap: ap, - - interruptScanChan: make(chan struct{}), - - scanBatchSize: scanBatchSize, - scanThreads: scanThreads, - scanMinInterval: scanMinInterval, - - timeoutMinInterval: timeoutMinInterval, - timeoutMinTimeout: timeoutMinTimeout, - }, nil -} - -func (s *scanner) Status() (bool, time.Time) { - s.mu.Lock() - defer s.mu.Unlock() - return s.scanning, s.scanningLastStart -} - -func (s *scanner) isInterrupted() bool { - select { - case <-s.interruptScanChan: - return true - default: - return false - } -} - -func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bool) { - if s.ap.isStopped() { - return - } - - scanType := "host scan" - if force { - scanType = "forced scan" - } - - s.mu.Lock() - if force { - close(s.interruptScanChan) - s.mu.Unlock() - - s.logger.Infof("waiting for ongoing scan to complete") - s.wg.Wait() - - s.mu.Lock() - s.interruptScanChan = make(chan struct{}) - } else if s.scanning || !s.isScanRequired() { - s.mu.Unlock() - return - } - s.scanningLastStart = time.Now() - s.scanning = true - s.mu.Unlock() - - s.logger.Infof("%s started", scanType) - - s.wg.Add(1) - go func(st string) { - defer s.wg.Done() - - for resp := range s.launchScanWorkers(ctx, w, s.launchHostScans()) { - if s.isInterrupted() || s.ap.isStopped() { - break - } - if resp.err != nil && !strings.Contains(resp.err.Error(), "connection refused") { - s.logger.Error(resp.err) - } - } - s.mu.Lock() - s.scanning = false - s.logger.Infof("%s finished after %v", st, time.Since(s.scanningLastStart)) - s.mu.Unlock() - }(scanType) -} - -func (s *scanner) PruneHosts(ctx context.Context, cfg api.HostsConfig) { - maxDowntime := time.Duration(cfg.MaxDowntimeHours) * time.Hour - minRecentScanFailures := cfg.MinRecentScanFailures - if maxDowntime > 0 { - s.logger.Debugf("removing hosts that have been offline for more than %v and have failed at least %d scans", maxDowntime, minRecentScanFailures) - removed, err := s.bus.RemoveOfflineHosts(ctx, minRecentScanFailures, maxDowntime) - if err != nil { - s.logger.Errorf("error occurred while removing offline hosts, err: %v", err) - } else if removed > 0 { - s.logger.Infof("removed %v offline hosts", removed) - } - } -} - -func (s *scanner) tryUpdateTimeout() { - s.mu.Lock() - defer s.mu.Unlock() - if !s.isTimeoutUpdateRequired() { - return - } - - updated := s.tracker.timeout() - if updated < s.timeoutMinTimeout { - s.logger.Infof("updated timeout is lower than min timeout, %v<%v", updated, s.timeoutMinTimeout) - updated = s.timeoutMinTimeout - } - - if s.timeout != updated { - s.logger.Infof("updated timeout %v->%v", s.timeout, updated) - s.timeout = updated - } - s.timeoutLastUpdate = time.Now() -} - -func (s *scanner) launchHostScans() chan scanReq { - reqChan := make(chan scanReq, s.scanBatchSize) - - s.ap.wg.Add(1) - go func() { - defer s.ap.wg.Done() - defer close(reqChan) - - var offset int - var exhausted bool - cutoff := time.Now().Add(-s.scanMinInterval) - for !s.ap.isStopped() && !exhausted { - // fetch next batch - hosts, err := s.bus.HostsForScanning(s.ap.shutdownCtx, api.HostsForScanningOptions{ - MaxLastScan: api.TimeRFC3339(cutoff), - Offset: offset, - Limit: 
int(s.scanBatchSize), - }) - if err != nil { - s.logger.Errorf("could not get hosts for scanning, err: %v", err) - break - } - if len(hosts) == 0 { - break - } - if len(hosts) < int(s.scanBatchSize) { - exhausted = true - } - - s.logger.Infof("scanning %d hosts in range %d-%d", len(hosts), offset, offset+int(s.scanBatchSize)) - offset += int(s.scanBatchSize) - - // add batch to scan queue - for _, h := range hosts { - select { - case <-s.ap.shutdownCtx.Done(): - return - case reqChan <- scanReq{ - hostKey: h.PublicKey, - hostIP: h.NetAddress, - }: - } - } - } - }() - - return reqChan -} - -func (s *scanner) launchScanWorkers(ctx context.Context, w scanWorker, reqs chan scanReq) chan scanResp { - respChan := make(chan scanResp, s.scanThreads) - liveThreads := s.scanThreads - - for i := uint64(0); i < s.scanThreads; i++ { - go func() { - for req := range reqs { - if s.ap.isStopped() { - break // shutdown - } - - scan, err := w.RHPScan(ctx, req.hostKey, req.hostIP, s.currentTimeout()) - if err != nil { - break // abort - } else if !utils.IsErr(errors.New(scan.ScanError), utils.ErrIOTimeout) && scan.Ping > 0 { - s.tracker.addDataPoint(time.Duration(scan.Ping)) - } - - respChan <- scanResp{req.hostKey, scan.Settings, err} - } - - if atomic.AddUint64(&liveThreads, ^uint64(0)) == 0 { - close(respChan) - } - }() - } - - return respChan -} - -func (s *scanner) isScanRequired() bool { - return s.scanningLastStart.IsZero() || time.Since(s.scanningLastStart) > s.scanMinInterval/20 // check 20 times per minInterval, so every 30 minutes -} - -func (s *scanner) isTimeoutUpdateRequired() bool { - return s.timeoutLastUpdate.IsZero() || time.Since(s.timeoutLastUpdate) > s.timeoutMinInterval -} - -func (s *scanner) currentTimeout() time.Duration { - s.mu.Lock() - defer s.mu.Unlock() - return s.timeout -} diff --git a/autopilot/scanner/scanner.go b/autopilot/scanner/scanner.go new file mode 100644 index 000000000..6c34274ad --- /dev/null +++ b/autopilot/scanner/scanner.go @@ -0,0 +1,307 @@ +package scanner + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" + "go.uber.org/zap" +) + +const ( + DefaultScanTimeout = 10 * time.Second +) + +type ( + HostStore interface { + HostsForScanning(ctx context.Context, opts api.HostsForScanningOptions) ([]api.HostAddress, error) + RemoveOfflineHosts(ctx context.Context, minRecentScanFailures uint64, maxDowntime time.Duration) (uint64, error) + } + + Scanner interface { + Scan(ctx context.Context, w WorkerRHPScan, force bool) + Shutdown(ctx context.Context) error + Status() (bool, time.Time) + UpdateHostsConfig(cfg api.HostsConfig) + } + + WorkerRHPScan interface { + RHPScan(ctx context.Context, hostKey types.PublicKey, hostIP string, timeout time.Duration) (api.RHPScanResponse, error) + } +) + +type ( + scanner struct { + hs HostStore + + scanBatchSize int + scanThreads int + scanInterval time.Duration + + statsHostPingMS *utils.DataPoints + + shutdownChan chan struct{} + wg sync.WaitGroup + + logger *zap.SugaredLogger + + mu sync.Mutex + hostsCfg *api.HostsConfig + + scanning bool + scanningLastStart time.Time + + interruptChan chan struct{} + } + + scanJob struct { + hostKey types.PublicKey + hostIP string + } +) + +func New(hs HostStore, scanBatchSize, scanThreads uint64, scanMinInterval time.Duration, logger *zap.Logger) (Scanner, error) { + logger = logger.Named("scanner") + if scanBatchSize == 0 { + return nil, errors.New("scanner batch size has to be 
greater than zero") + } + if scanThreads == 0 { + return nil, errors.New("scanner threads has to be greater than zero") + } + return &scanner{ + hs: hs, + + scanBatchSize: int(scanBatchSize), + scanThreads: int(scanThreads), + scanInterval: scanMinInterval, + + statsHostPingMS: utils.NewDataPoints(0), + logger: logger.Sugar(), + + interruptChan: make(chan struct{}), + shutdownChan: make(chan struct{}), + }, nil +} + +func (s *scanner) Scan(ctx context.Context, w WorkerRHPScan, force bool) { + if s.canSkipScan(force) { + s.logger.Debug("host scan skipped") + return + } + + cutoff := time.Now() + if !force { + cutoff = cutoff.Add(-s.scanInterval) + } + + s.logger.Infow("scan started", + "batch", s.scanBatchSize, + "force", force, + "threads", s.scanThreads, + "cutoff", cutoff, + ) + + s.wg.Add(1) + go func() { + defer s.wg.Done() + + hosts := s.fetchHosts(ctx, cutoff) + scanned := s.scanHosts(ctx, w, hosts) + removed := s.removeOfflineHosts(ctx) + + s.mu.Lock() + defer s.mu.Unlock() + s.scanning = false + s.logger.Infow("scan finished", + "force", force, + "duration", time.Since(s.scanningLastStart), + "pingMSAvg", s.statsHostPingMS.Average(), + "pingMSP90", s.statsHostPingMS.P90(), + "removed", removed, + "scanned", scanned) + }() +} + +func (s *scanner) Shutdown(ctx context.Context) error { + defer close(s.shutdownChan) + + waitChan := make(chan struct{}) + go func() { + s.wg.Wait() + close(waitChan) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-waitChan: + } + + return nil +} + +func (s *scanner) Status() (bool, time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + return s.scanning, s.scanningLastStart +} + +func (s *scanner) UpdateHostsConfig(cfg api.HostsConfig) { + s.mu.Lock() + defer s.mu.Unlock() + s.hostsCfg = &cfg +} + +func (s *scanner) fetchHosts(ctx context.Context, cutoff time.Time) chan scanJob { + jobsChan := make(chan scanJob, s.scanBatchSize) + go func() { + defer close(jobsChan) + + var exhausted bool + for offset := 0; !exhausted; offset += s.scanBatchSize { + hosts, err := s.hs.HostsForScanning(ctx, api.HostsForScanningOptions{ + MaxLastScan: api.TimeRFC3339(cutoff), + Offset: offset, + Limit: s.scanBatchSize, + }) + if err != nil { + s.logger.Errorf("could not get hosts for scanning, err: %v", err) + return + } else if len(hosts) < s.scanBatchSize { + exhausted = true + } + + s.logger.Debugf("fetched %d hosts for scanning", len(hosts)) + for _, h := range hosts { + select { + case <-s.interruptChan: + return + case <-s.shutdownChan: + return + case jobsChan <- scanJob{ + hostKey: h.PublicKey, + hostIP: h.NetAddress, + }: + } + } + } + }() + + return jobsChan +} + +func (s *scanner) scanHosts(ctx context.Context, w WorkerRHPScan, hosts chan scanJob) (scanned uint64) { + // define worker + worker := func() { + for h := range hosts { + if s.isShutdown() || s.isInterrupted() { + break // shutdown + } + + scan, err := w.RHPScan(ctx, h.hostKey, h.hostIP, DefaultScanTimeout) + if err != nil { + s.logger.Errorw("worker stopped", zap.Error(err), "hk", h.hostKey) + break // abort + } else if err := scan.Error(); err != nil { + s.logger.Debugw("host scan failed", zap.Error(err), "hk", h.hostKey, "ip", h.hostIP) + } else { + s.statsHostPingMS.Track(float64(time.Duration(scan.Ping).Milliseconds())) + atomic.AddUint64(&scanned, 1) + } + } + } + + // launch all workers + var wg sync.WaitGroup + for t := 0; t < s.scanThreads; t++ { + wg.Add(1) + go func() { + worker() + wg.Done() + }() + } + + // wait until they're done + wg.Wait() + + s.statsHostPingMS.Recompute() 
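Editor's note (not part of the patch): scanHosts above feeds ping times into statsHostPingMS via Track and calls Recompute before the Average/P90 reads in the scan-finished log. The real utils.DataPoints type is not shown in this patch, and its constructor takes a recompute interval (see utils.NewDataPoints(time.Hour) earlier in the diff), so the following is only a windowless sketch compatible with those four calls:

package utils

import (
	"sort"
	"sync"
)

// DataPoints tracks samples and serves simple aggregates. Recompute sorts
// a snapshot so percentile reads are cheap and stable between calls.
type DataPoints struct {
	mu     sync.Mutex
	points []float64
	sorted []float64
}

// Track records one sample.
func (d *DataPoints) Track(v float64) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.points = append(d.points, v)
}

// Recompute refreshes the sorted snapshot used for percentiles.
func (d *DataPoints) Recompute() {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.sorted = append(d.sorted[:0], d.points...)
	sort.Float64s(d.sorted)
}

// Average returns the mean of all samples tracked so far.
func (d *DataPoints) Average() float64 {
	d.mu.Lock()
	defer d.mu.Unlock()
	if len(d.points) == 0 {
		return 0
	}
	var sum float64
	for _, v := range d.points {
		sum += v
	}
	return sum / float64(len(d.points))
}

// P90 returns the 90th percentile of the last snapshot.
func (d *DataPoints) P90() float64 {
	d.mu.Lock()
	defer d.mu.Unlock()
	if len(d.sorted) == 0 {
		return 0
	}
	return d.sorted[int(float64(len(d.sorted))*0.9)]
}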
+ return +} + +func (s *scanner) isInterrupted() bool { + select { + case <-s.interruptChan: + return true + default: + } + return false +} + +func (s *scanner) isShutdown() bool { + select { + case <-s.shutdownChan: + return true + default: + } + return false +} + +func (s *scanner) removeOfflineHosts(ctx context.Context) (removed uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.hostsCfg == nil { + s.logger.Info("no hosts config set, skipping removal of offline hosts") + return + } + + maxDowntime := time.Duration(s.hostsCfg.MaxDowntimeHours) * time.Hour + if maxDowntime == 0 { + s.logger.Info("hosts config has no max downtime set, skipping removal of offline hosts") + return + } + + s.logger.Infow("removing offline hosts", + "maxDowntime", maxDowntime, + "minRecentScanFailures", s.hostsCfg.MinRecentScanFailures) + + var err error + removed, err = s.hs.RemoveOfflineHosts(ctx, s.hostsCfg.MinRecentScanFailures, maxDowntime) + if err != nil { + s.logger.Errorw("removing offline hosts failed", zap.Error(err), "maxDowntime", maxDowntime, "minRecentScanFailures", s.hostsCfg.MinRecentScanFailures) + return + } + + return +} + +func (s *scanner) canSkipScan(force bool) bool { + if s.isShutdown() { + return true + } + + s.mu.Lock() + if force { + close(s.interruptChan) + s.mu.Unlock() + + s.logger.Infof("host scan interrupted, waiting for ongoing scan to complete") + s.wg.Wait() + + s.mu.Lock() + s.interruptChan = make(chan struct{}) + } else if s.scanning || time.Since(s.scanningLastStart) < s.scanInterval { + s.mu.Unlock() + return true + } + s.scanningLastStart = time.Now() + s.scanning = true + s.mu.Unlock() + + return false +} diff --git a/autopilot/scanner/scanner_test.go b/autopilot/scanner/scanner_test.go new file mode 100644 index 000000000..ee847395b --- /dev/null +++ b/autopilot/scanner/scanner_test.go @@ -0,0 +1,162 @@ +package scanner + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/test" + "go.uber.org/zap" +) + +const ( + testBatchSize = 40 + testNumThreads = 3 +) + +type mockHostStore struct { + hosts []api.Host + + mu sync.Mutex + scans []string + removals []string +} + +func (hs *mockHostStore) HostsForScanning(ctx context.Context, opts api.HostsForScanningOptions) ([]api.HostAddress, error) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.scans = append(hs.scans, fmt.Sprintf("%d-%d", opts.Offset, opts.Offset+opts.Limit)) + + start := opts.Offset + if start > len(hs.hosts) { + return nil, nil + } + + end := opts.Offset + opts.Limit + if end > len(hs.hosts) { + end = len(hs.hosts) + } + + var hostAddresses []api.HostAddress + for _, h := range hs.hosts[start:end] { + hostAddresses = append(hostAddresses, api.HostAddress{ + NetAddress: h.NetAddress, + PublicKey: h.PublicKey, + }) + } + return hostAddresses, nil +} + +func (hs *mockHostStore) RemoveOfflineHosts(ctx context.Context, minRecentScanFailures uint64, maxDowntime time.Duration) (uint64, error) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.removals = append(hs.removals, fmt.Sprintf("%d-%d", minRecentScanFailures, maxDowntime)) + return 0, nil +} + +func (hs *mockHostStore) state() ([]string, []string) { + hs.mu.Lock() + defer hs.mu.Unlock() + return hs.scans, hs.removals +} + +type mockWorker struct { + blockChan chan struct{} + + mu sync.Mutex + scanCount int +} + +func (w *mockWorker) RHPScan(ctx context.Context, hostKey types.PublicKey, hostIP string, _ time.Duration) (api.RHPScanResponse, error) { + if w.blockChan != 
nil { + <-w.blockChan + } + + w.mu.Lock() + defer w.mu.Unlock() + w.scanCount++ + + return api.RHPScanResponse{}, nil +} + +func TestScanner(t *testing.T) { + // create mock store + hs := &mockHostStore{hosts: test.NewHosts(100)} + + // create test scanner + s, err := New(hs, testBatchSize, testNumThreads, time.Minute, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer s.Shutdown(context.Background()) + + // assert it's not scanning + scanning, _ := s.Status() + if scanning { + t.Fatal("unexpected") + } + + // initiate a host scan using a worker that blocks + w := &mockWorker{blockChan: make(chan struct{})} + s.Scan(context.Background(), w, false) + + // assert it's scanning + scanning, _ = s.Status() + if !scanning { + t.Fatal("unexpected") + } + + // unblock the worker and sleep + close(w.blockChan) + time.Sleep(time.Second) + + // assert the scan is done + scanning, _ = s.Status() + if scanning { + t.Fatal("unexpected") + } + + // assert we did not remove offline hosts + if _, removals := hs.state(); len(removals) != 0 { + t.Fatalf("unexpected removals, %v != 0", len(removals)) + } + + // assert the scanner made 3 batch reqs + if scans, _ := hs.state(); len(scans) != 3 { + t.Fatalf("unexpected number of requests, %v != 3", len(scans)) + } else if scans[0] != "0-40" || scans[1] != "40-80" || scans[2] != "80-120" { + t.Fatalf("unexpected requests, %v", scans) + } + + // assert we scanned 100 hosts + if w.scanCount != 100 { + t.Fatalf("unexpected number of scans, %v != 100", w.scanCount) + } + + // assert we prevent starting a host scan immediately after a scan was done + s.Scan(context.Background(), w, false) + scanning, _ = s.Status() + if scanning { + t.Fatal("unexpected") + } + + // update the hosts config + s.UpdateHostsConfig(api.HostsConfig{ + MinRecentScanFailures: 10, + MaxDowntimeHours: 1, + }) + + s.Scan(context.Background(), w, true) + time.Sleep(time.Second) + + // assert we removed offline hosts + if _, removals := hs.state(); len(removals) != 1 { + t.Fatalf("unexpected removals, %v != 1", len(removals)) + } else if removals[0] != "10-3600000000000" { + t.Fatalf("unexpected removals, %v", removals) + } +} diff --git a/autopilot/scanner_test.go b/autopilot/scanner_test.go deleted file mode 100644 index 2dc113df2..000000000 --- a/autopilot/scanner_test.go +++ /dev/null @@ -1,161 +0,0 @@ -package autopilot - -import ( - "context" - "fmt" - "sync" - "testing" - "time" - - "go.sia.tech/core/types" - "go.sia.tech/renterd/api" - "go.sia.tech/renterd/internal/test" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type mockBus struct { - hosts []api.Host - reqs []string -} - -func (b *mockBus) SearchHosts(ctx context.Context, opts api.SearchHostOptions) ([]api.Host, error) { - b.reqs = append(b.reqs, fmt.Sprintf("%d-%d", opts.Offset, opts.Offset+opts.Limit)) - - start := opts.Offset - if start > len(b.hosts) { - return nil, nil - } - - end := opts.Offset + opts.Limit - if end > len(b.hosts) { - end = len(b.hosts) - } - - return b.hosts[start:end], nil -} - -func (b *mockBus) HostsForScanning(ctx context.Context, opts api.HostsForScanningOptions) ([]api.HostAddress, error) { - hosts, err := b.SearchHosts(ctx, api.SearchHostOptions{ - Offset: opts.Offset, - Limit: opts.Limit, - }) - if err != nil { - return nil, err - } - var hostAddresses []api.HostAddress - for _, h := range hosts { - hostAddresses = append(hostAddresses, api.HostAddress{ - NetAddress: h.NetAddress, - PublicKey: h.PublicKey, - }) - } - return hostAddresses, nil -} - -func (b *mockBus) 
diff --git a/autopilot/scanner_test.go b/autopilot/scanner_test.go
deleted file mode 100644
index 2dc113df2..000000000
--- a/autopilot/scanner_test.go
+++ /dev/null
@@ -1,161 +0,0 @@
-package autopilot
-
-import (
-	"context"
-	"fmt"
-	"sync"
-	"testing"
-	"time"
-
-	"go.sia.tech/core/types"
-	"go.sia.tech/renterd/api"
-	"go.sia.tech/renterd/internal/test"
-	"go.uber.org/zap"
-	"go.uber.org/zap/zapcore"
-)
-
-type mockBus struct {
-	hosts []api.Host
-	reqs  []string
-}
-
-func (b *mockBus) SearchHosts(ctx context.Context, opts api.SearchHostOptions) ([]api.Host, error) {
-	b.reqs = append(b.reqs, fmt.Sprintf("%d-%d", opts.Offset, opts.Offset+opts.Limit))
-
-	start := opts.Offset
-	if start > len(b.hosts) {
-		return nil, nil
-	}
-
-	end := opts.Offset + opts.Limit
-	if end > len(b.hosts) {
-		end = len(b.hosts)
-	}
-
-	return b.hosts[start:end], nil
-}
-
-func (b *mockBus) HostsForScanning(ctx context.Context, opts api.HostsForScanningOptions) ([]api.HostAddress, error) {
-	hosts, err := b.SearchHosts(ctx, api.SearchHostOptions{
-		Offset: opts.Offset,
-		Limit:  opts.Limit,
-	})
-	if err != nil {
-		return nil, err
-	}
-	var hostAddresses []api.HostAddress
-	for _, h := range hosts {
-		hostAddresses = append(hostAddresses, api.HostAddress{
-			NetAddress: h.NetAddress,
-			PublicKey:  h.PublicKey,
-		})
-	}
-	return hostAddresses, nil
-}
-
-func (b *mockBus) RemoveOfflineHosts(ctx context.Context, minRecentScanFailures uint64, maxDowntime time.Duration) (uint64, error) {
-	return 0, nil
-}
-
-type mockWorker struct {
-	blockChan chan struct{}
-
-	mu        sync.Mutex
-	scanCount int
-}
-
-func (w *mockWorker) RHPScan(ctx context.Context, hostKey types.PublicKey, hostIP string, _ time.Duration) (api.RHPScanResponse, error) {
-	if w.blockChan != nil {
-		<-w.blockChan
-	}
-
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	w.scanCount++
-
-	return api.RHPScanResponse{}, nil
-}
-
-func (w *mockWorker) RHPPriceTable(ctx context.Context, hostKey types.PublicKey, siamuxAddr string) (api.HostPriceTable, error) {
-	return api.HostPriceTable{}, nil
-}
-
-func TestScanner(t *testing.T) {
-	// prepare 100 hosts
-	hosts := test.NewHosts(100)
-
-	// init new scanner
-	b := &mockBus{hosts: hosts}
-	w := &mockWorker{blockChan: make(chan struct{})}
-	s := newTestScanner(b)
-
-	// assert it started a host scan
-	s.tryPerformHostScan(context.Background(), w, false)
-	if !s.isScanning() {
-		t.Fatal("unexpected")
-	}
-
-	// unblock the worker and sleep
-	close(w.blockChan)
-	time.Sleep(time.Second)
-
-	// assert the scan is done
-	if s.isScanning() {
-		t.Fatal("unexpected")
-	}
-
-	// assert the scanner made 3 batch reqs
-	if len(b.reqs) != 3 {
-		t.Fatalf("unexpected number of requests, %v != 3", len(b.reqs))
-	}
-	if b.reqs[0] != "0-40" || b.reqs[1] != "40-80" || b.reqs[2] != "80-120" {
-		t.Fatalf("unexpected requests, %v", b.reqs)
-	}
-
-	// assert we scanned 100 hosts
-	if w.scanCount != 100 {
-		t.Fatalf("unexpected number of scans, %v != 100", w.scanCount)
-	}
-
-	// assert we prevent starting a host scan immediately after a scan was done
-	s.tryPerformHostScan(context.Background(), w, false)
-	if s.isScanning() {
-		t.Fatal("unexpected")
-	}
-
-	// reset the scanner
-	s.scanningLastStart = time.Time{}
-
-	// assert it started a host scan
-	s.tryPerformHostScan(context.Background(), w, false)
-	if !s.isScanning() {
-		t.Fatal("unexpected")
-	}
-}
-
-func (s *scanner) isScanning() bool {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	return s.scanning
-}
-
-func newTestScanner(b *mockBus) *scanner {
-	ap := &Autopilot{}
-	ap.shutdownCtx, ap.shutdownCtxCancel = context.WithCancel(context.Background())
-	return &scanner{
-		ap:     ap,
-		bus:    b,
-		logger: zap.New(zapcore.NewNopCore()).Sugar(),
-		tracker: newTracker(
-			trackerMinDataPoints,
-			trackerNumDataPoints,
-			trackerTimeoutPercentile,
-		),
-
-		interruptScanChan: make(chan struct{}),
-
-		scanBatchSize:   40,
-		scanThreads:     3,
-		scanMinInterval: time.Minute,
-	}
-}
diff --git a/build/env_default.go b/build/env_default.go
deleted file mode 100644
index 3730fd5b2..000000000
--- a/build/env_default.go
+++ /dev/null
@@ -1,58 +0,0 @@
-//go:build !testnet
-
-package build
-
-import (
-	"time"
-
-	"go.sia.tech/core/types"
-	"go.sia.tech/renterd/api"
-)
-
-const (
-	network = "mainnet"
-
-	DefaultAPIAddress     = "localhost:9980"
-	DefaultGatewayAddress = ":9981"
-	DefaultS3Address      = "localhost:8080"
-)
-
-var (
-	// DefaultGougingSettings define the default gouging settings the bus is
-	// configured with on startup. These values can be adjusted using the
-	// settings API.
-	DefaultGougingSettings = api.GougingSettings{
-		MaxRPCPrice:                   types.Siacoins(1).Div64(1000),                       // 1mS per RPC
-		MaxContractPrice:              types.Siacoins(1),                                   // 1 SC per contract
-		MaxDownloadPrice:              types.Siacoins(3000),                                // 3000 SC per 1 TiB
-		MaxUploadPrice:                types.Siacoins(3000),                                // 3000 SC per 1 TiB
-		MaxStoragePrice:               types.Siacoins(3000).Div64(1 << 40).Div64(144 * 30), // 3000 SC per TiB per month
-		HostBlockHeightLeeway:         6,                                                   // 6 blocks
-		MinPriceTableValidity:         5 * time.Minute,                                     // 5 minutes
-		MinAccountExpiry:              24 * time.Hour,                                      // 1 day
-		MinMaxEphemeralAccountBalance: types.Siacoins(1),                                   // 1 SC
-		MigrationSurchargeMultiplier:  10,                                                  // 10x
-	}
-
-	// DefaultPricePinSettings define the default price pin settings the bus is
-	// configured with on startup. These values can be adjusted using the
-	// settings API.
-	DefaultPricePinSettings = api.PricePinSettings{
-		Enabled: false,
-	}
-
-	// DefaultUploadPackingSettings define the default upload packing settings
-	// the bus is configured with on startup.
-	DefaultUploadPackingSettings = api.UploadPackingSettings{
-		Enabled:               true,
-		SlabBufferMaxSizeSoft: 1 << 32, // 4 GiB
-	}
-
-	// DefaultRedundancySettings define the default redundancy settings the bus
-	// is configured with on startup. These values can be adjusted using the
-	// settings API.
-	DefaultRedundancySettings = api.RedundancySettings{
-		MinShards:   10,
-		TotalShards: 30,
-	}
-)
diff --git a/build/env_testnet.go b/build/env_testnet.go
deleted file mode 100644
index 5ccf6f24f..000000000
--- a/build/env_testnet.go
+++ /dev/null
@@ -1,62 +0,0 @@
-//go:build testnet
-
-package build
-
-import (
-	"time"
-
-	"go.sia.tech/core/types"
-	"go.sia.tech/renterd/api"
-)
-
-const (
-	network = "zen"
-
-	DefaultAPIAddress     = "localhost:9880"
-	DefaultGatewayAddress = ":9881"
-	DefaultS3Address      = "localhost:7070"
-)
-
-var (
-	// DefaultGougingSettings define the default gouging settings the bus is
-	// configured with on startup. These values can be adjusted using the
-	// settings API.
-	//
-	// NOTE: default gouging settings for testnet are identical to mainnet.
-	DefaultGougingSettings = api.GougingSettings{
-		MaxRPCPrice:                   types.Siacoins(1).Div64(1000),                       // 1mS per RPC
-		MaxContractPrice:              types.Siacoins(15),                                  // 15 SC per contract
-		MaxDownloadPrice:              types.Siacoins(3000),                                // 3000 SC per 1 TiB
-		MaxUploadPrice:                types.Siacoins(3000),                                // 3000 SC per 1 TiB
-		MaxStoragePrice:               types.Siacoins(3000).Div64(1 << 40).Div64(144 * 30), // 3000 SC per TiB per month
-		HostBlockHeightLeeway:         6,                                                   // 6 blocks
-		MinPriceTableValidity:         5 * time.Minute,                                     // 5 minutes
-		MinAccountExpiry:              24 * time.Hour,                                      // 1 day
-		MinMaxEphemeralAccountBalance: types.Siacoins(1),                                   // 1 SC
-		MigrationSurchargeMultiplier:  10,                                                  // 10x
-	}
-
-	// DefaultPricePinSettings define the default price pin settings the bus is
-	// configured with on startup. These values can be adjusted using the
-	// settings API.
-	DefaultPricePinSettings = api.PricePinSettings{
-		Enabled: false,
-	}
-
-	// DefaultUploadPackingSettings define the default upload packing settings
-	// the bus is configured with on startup.
-	DefaultUploadPackingSettings = api.UploadPackingSettings{
-		Enabled:               true,
-		SlabBufferMaxSizeSoft: 1 << 32, // 4 GiB
-	}
-
-	// DefaultRedundancySettings define the default redundancy settings the bus
-	// is configured with on startup. These values can be adjusted using the
-	// settings API.
-	//
-	// NOTE: default redundancy settings for testnet are different from mainnet.
-	DefaultRedundancySettings = api.RedundancySettings{
-		MinShards:   2,
-		TotalShards: 6,
-	}
-)
diff --git a/build/gen.go b/build/gen.go
index efeeb5c53..26ad6f705 100644
--- a/build/gen.go
+++ b/build/gen.go
@@ -31,6 +31,8 @@ var buildTemplate = template.Must(template.New("").Parse(`// Code generated by g
 // This file was generated by go generate at {{ .RunTime }}.
 package build
 
+//go:generate go run gen.go
+
 import (
 	"time"
 )
diff --git a/build/meta.go b/build/meta.go
index 8e932ba6e..49aeb39eb 100644
--- a/build/meta.go
+++ b/build/meta.go
@@ -1,7 +1,9 @@
 // Code generated by go generate; DO NOT EDIT.
-// This file was generated by go generate at 2023-08-11T11:44:04+02:00.
+// This file was generated by go generate at 2024-08-12T08:54:43-07:00.
 package build
 
+//go:generate go run gen.go
+
 import (
 	"time"
 )
diff --git a/build/network.go b/build/network.go
deleted file mode 100644
index a0a452189..000000000
--- a/build/network.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package build
-
-//go:generate go run gen.go
-
-import (
-	"go.sia.tech/core/consensus"
-	"go.sia.tech/core/types"
-	"go.sia.tech/coreutils/chain"
-)
-
-// Network returns the Sia network consts and genesis block for the current build.
-func Network() (*consensus.Network, types.Block) {
-	switch network {
-	case "mainnet":
-		return chain.Mainnet()
-	case "zen":
-		return chain.TestnetZen()
-	default:
-		panic("unknown network: " + network)
-	}
-}
-
-func NetworkName() string {
-	n, _ := Network()
-	switch n.Name {
-	case "mainnet":
-		return "Mainnet"
-	case "zen":
-		return "Zen Testnet"
-	default:
-		return n.Name
-	}
-}
diff --git a/bus/bus.go b/bus/bus.go
index 312089c9e..431d5abd5 100644
--- a/bus/bus.go
+++ b/bus/bus.go
@@ -1,39 +1,42 @@
 package bus
 
+// TODOs:
+// - add UPNP support
+
 import (
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
-	"io"
-	"math"
+	"math/big"
 	"net/http"
-	"runtime"
-	"sort"
 	"strings"
 	"time"
 
 	"go.sia.tech/core/consensus"
+	"go.sia.tech/core/gateway"
 	rhpv2 "go.sia.tech/core/rhp/v2"
 	rhpv3 "go.sia.tech/core/rhp/v3"
 	"go.sia.tech/core/types"
-	"go.sia.tech/gofakes3"
+	"go.sia.tech/coreutils/chain"
+	"go.sia.tech/coreutils/syncer"
+	"go.sia.tech/coreutils/wallet"
 	"go.sia.tech/jape"
 	"go.sia.tech/renterd/alerts"
 	"go.sia.tech/renterd/api"
-	"go.sia.tech/renterd/build"
 	"go.sia.tech/renterd/bus/client"
 	ibus "go.sia.tech/renterd/internal/bus"
 	"go.sia.tech/renterd/object"
-	"go.sia.tech/renterd/wallet"
+	"go.sia.tech/renterd/stores/sql"
 	"go.sia.tech/renterd/webhooks"
-	"go.sia.tech/siad/modules"
 	"go.uber.org/zap"
 )
 
 const (
-	defaultPinUpdateInterval = 5 * time.Minute
-	defaultPinRateWindow     = 6 * time.Hour
+	defaultWalletRecordMetricInterval = 5 * time.Minute
+	defaultPinUpdateInterval          = 5 * time.Minute
+	defaultPinRateWindow              = 6 * time.Hour
+
+	stdTxnSize = 1200 // bytes
 )
 
 // Client re-exports the client from the client package.
@@ -52,23 +55,49 @@ func NewClient(addr, password string) *Client {
 }
 
 type (
-	// A ChainManager manages blockchain state.
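(Editorial note: with build/network.go and the `testnet` build tag removed above, picking mainnet vs. Zen becomes a runtime decision. A minimal sketch of that pattern under an assumed `-network` CLI flag; the flag name is illustrative and the real wiring lives in cmd/renterd, not in this hunk:)

```go
package main

import (
	"flag"
	"fmt"

	"go.sia.tech/core/consensus"
	"go.sia.tech/core/types"
	"go.sia.tech/coreutils/chain"
)

// networkFlag is hypothetical; this diff only removes the compile-time
// selection, it does not define the flag.
var networkFlag = flag.String("network", "mainnet", "mainnet or zen")

// network mirrors the deleted build.Network(), minus the build tag.
func network() (*consensus.Network, types.Block) {
	switch *networkFlag {
	case "mainnet":
		return chain.Mainnet()
	case "zen":
		return chain.TestnetZen()
	default:
		panic("unknown network: " + *networkFlag)
	}
}

func main() {
	flag.Parse()
	n, _ := network()
	fmt.Println("using network:", n.Name)
}
```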
+ AccountManager interface { + Account(id rhpv3.Account, hostKey types.PublicKey) (api.Account, error) + Accounts() []api.Account + AddAmount(id rhpv3.Account, hk types.PublicKey, amt *big.Int) + LockAccount(ctx context.Context, id rhpv3.Account, hostKey types.PublicKey, exclusive bool, duration time.Duration) (api.Account, uint64) + ResetDrift(id rhpv3.Account) error + SetBalance(id rhpv3.Account, hk types.PublicKey, balance *big.Int) + ScheduleSync(id rhpv3.Account, hk types.PublicKey) error + Shutdown(context.Context) error + UnlockAccount(id rhpv3.Account, lockID uint64) error + } + + AlertManager interface { + alerts.Alerter + RegisterWebhookBroadcaster(b webhooks.Broadcaster) + } + ChainManager interface { - AcceptBlock(types.Block) error - BlockAtHeight(height uint64) (types.Block, bool) - IndexAtHeight(height uint64) (types.ChainIndex, error) - LastBlockTime() time.Time - Subscribe(s modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error - Synced() bool + AddBlocks(blocks []types.Block) error + AddPoolTransactions(txns []types.Transaction) (bool, error) + AddV2PoolTransactions(basis types.ChainIndex, txns []types.V2Transaction) (known bool, err error) + Block(id types.BlockID) (types.Block, bool) + OnReorg(fn func(types.ChainIndex)) (cancel func()) + PoolTransaction(txid types.TransactionID) (types.Transaction, bool) + PoolTransactions() []types.Transaction + V2PoolTransactions() []types.V2Transaction + RecommendedFee() types.Currency + Tip() types.ChainIndex TipState() consensus.State + UnconfirmedParents(txn types.Transaction) []types.Transaction + UpdatesSince(index types.ChainIndex, max int) (rus []chain.RevertUpdate, aus []chain.ApplyUpdate, err error) + V2UnconfirmedParents(txn types.V2Transaction) []types.V2Transaction } - // A Syncer can connect to other peers and synchronize the blockchain. - Syncer interface { - BroadcastTransaction(txn types.Transaction, dependsOn []types.Transaction) - Connect(addr string) error - Peers() []string - SyncerAddress(ctx context.Context) (string, error) + ContractLocker interface { + Acquire(ctx context.Context, priority int, id types.FileContractID, d time.Duration) (uint64, error) + KeepAlive(id types.FileContractID, lockID uint64, d time.Duration) error + Release(id types.FileContractID, lockID uint64) error + } + + ChainSubscriber interface { + ChainIndex(context.Context) (types.ChainIndex, error) + Shutdown(context.Context) error } // A TransactionPool can validate and relay unconfirmed transactions. @@ -76,26 +105,95 @@ type ( AcceptTransactionSet(txns []types.Transaction) error Close() error RecommendedFee() types.Currency - Subscribe(subscriber modules.TransactionPoolSubscriber) Transactions() []types.Transaction UnconfirmedParents(txn types.Transaction) ([]types.Transaction, error) } - // A Wallet can spend and receive siacoins. 
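(Editorial note: the reshaped ChainManager above replaces siad's consensus-set subscription with coreutils-style polling: register a reorg callback, then pull batched updates. A rough consumer sketch against only the methods in the interface; it assumes coreutils' chain.ApplyUpdate exposes the post-apply consensus State, and `process` is a stand-in callback:)

```go
package example

import (
	"go.sia.tech/core/types"
	"go.sia.tech/coreutils/chain"
)

// ChainManager is the slice of the bus interface this sketch needs.
type ChainManager interface {
	OnReorg(fn func(types.ChainIndex)) (cancel func())
	Tip() types.ChainIndex
	UpdatesSince(index types.ChainIndex, max int) ([]chain.RevertUpdate, []chain.ApplyUpdate, error)
}

// FollowChain catches up from 'index' once, then resyncs on every reorg
// notification. Error handling is elided; a real subscriber would log and
// retry instead of silently returning.
func FollowChain(cm ChainManager, index types.ChainIndex, process func([]chain.RevertUpdate, []chain.ApplyUpdate) error) (cancel func()) {
	sync := func() {
		for index != cm.Tip() {
			reverted, applied, err := cm.UpdatesSince(index, 100)
			if err != nil || len(reverted)+len(applied) == 0 {
				return
			}
			if err := process(reverted, applied); err != nil {
				return
			}
			if len(applied) == 0 {
				return // reverts only; bail rather than loop
			}
			// assumption: ApplyUpdate carries the consensus state (and thus
			// chain index) after the block was applied.
			index = applied[len(applied)-1].State.Index
		}
	}
	sync()
	return cm.OnReorg(func(types.ChainIndex) { sync() })
}
```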
+ UploadingSectorsCache interface { + AddSector(uID api.UploadID, fcid types.FileContractID, root types.Hash256) error + FinishUpload(uID api.UploadID) + HandleRenewal(fcid, renewedFrom types.FileContractID) + Pending(fcid types.FileContractID) (size uint64) + Sectors(fcid types.FileContractID) (roots []types.Hash256) + StartUpload(uID api.UploadID) error + } + + PinManager interface { + Shutdown(context.Context) error + TriggerUpdate() + } + + Syncer interface { + Addr() string + BroadcastHeader(h gateway.BlockHeader) + BroadcastV2BlockOutline(bo gateway.V2BlockOutline) + BroadcastTransactionSet([]types.Transaction) + BroadcastV2TransactionSet(index types.ChainIndex, txns []types.V2Transaction) + Connect(ctx context.Context, addr string) (*syncer.Peer, error) + Peers() []*syncer.Peer + } + Wallet interface { Address() types.Address - Balance() (spendable, confirmed, unconfirmed types.Currency, _ error) - FundTransaction(cs consensus.State, txn *types.Transaction, amount types.Currency, useUnconfirmedTxns bool) ([]types.Hash256, error) - Height() uint64 - Redistribute(cs consensus.State, outputs int, amount, feePerByte types.Currency, pool []types.Transaction) ([]types.Transaction, []types.Hash256, error) - ReleaseInputs(txn ...types.Transaction) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - Transactions(before, since time.Time, offset, limit int) ([]wallet.Transaction, error) - UnspentOutputs() ([]wallet.SiacoinElement, error) + Balance() (wallet.Balance, error) + Close() error + FundTransaction(txn *types.Transaction, amount types.Currency, useUnconfirmed bool) ([]types.Hash256, error) + FundV2Transaction(txn *types.V2Transaction, amount types.Currency, useUnconfirmed bool) (consensus.State, []int, error) + Redistribute(outputs int, amount, feePerByte types.Currency) (txns []types.Transaction, toSign []types.Hash256, err error) + RedistributeV2(outputs int, amount, feePerByte types.Currency) (txns []types.V2Transaction, toSign [][]int, err error) + ReleaseInputs(txns []types.Transaction, v2txns []types.V2Transaction) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) + SignV2Inputs(state consensus.State, txn *types.V2Transaction, toSign []int) + SpendableOutputs() ([]types.SiacoinElement, error) + Tip() (types.ChainIndex, error) + UnconfirmedEvents() ([]wallet.Event, error) + UpdateChainState(tx wallet.UpdateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error + Events(offset, limit int) ([]wallet.Event, error) + } + + WebhooksManager interface { + webhooks.Broadcaster + Delete(context.Context, webhooks.Webhook) error + Info() ([]webhooks.Webhook, []webhooks.WebhookQueueInfo) + Register(context.Context, webhooks.Webhook) error + Shutdown(context.Context) error + } + + // Store is a collection of stores used by the bus. + Store interface { + AccountStore + AutopilotStore + ChainStore + HostStore + MetadataStore + MetricsStore + SettingStore + } + + // AccountStore persists information about accounts. Since accounts + // are rapidly updated and can be recovered, they are only loaded upon + // startup and persisted upon shutdown. + AccountStore interface { + Accounts(context.Context) ([]api.Account, error) + SaveAccounts(context.Context, []api.Account) error + SetUncleanShutdown(context.Context) error + } + + // An AutopilotStore stores autopilots. 
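(Editorial note: the Wallet interface above changes shape: FundTransaction no longer takes a consensus.State, and SignTransaction returns nothing, presumably because the single-address wallet owns its key and signing has no failure mode. A sketch of the resulting fund-and-sign flow against a slice of the interface; the helper and its broadcast callback are illustrative, not part of the diff:)

```go
package example

import "go.sia.tech/core/types"

// Wallet is the slice of the new bus Wallet interface used here.
type Wallet interface {
	FundTransaction(txn *types.Transaction, amount types.Currency, useUnconfirmed bool) ([]types.Hash256, error)
	SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields)
	ReleaseInputs(txns []types.Transaction, v2txns []types.V2Transaction)
}

// FundAndSign funds txn, signs the freshly added inputs, and releases them
// again if the caller's follow-up work (e.g. broadcasting) fails.
func FundAndSign(w Wallet, txn *types.Transaction, amount types.Currency, broadcast func() error) error {
	toSign, err := w.FundTransaction(txn, amount, true)
	if err != nil {
		return err
	}
	w.SignTransaction(txn, toSign, types.CoveredFields{WholeTransaction: true})
	if err := broadcast(); err != nil {
		w.ReleaseInputs([]types.Transaction{*txn}, nil)
		return err
	}
	return nil
}
```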
+ AutopilotStore interface { + Autopilot(ctx context.Context, id string) (api.Autopilot, error) + Autopilots(ctx context.Context) ([]api.Autopilot, error) + UpdateAutopilot(ctx context.Context, ap api.Autopilot) error + } + + // A ChainStore stores information about the chain. + ChainStore interface { + ChainIndex(ctx context.Context) (types.ChainIndex, error) + ProcessChainUpdate(ctx context.Context, applyFn func(sql.ChainUpdateTx) error) error } - // A HostDB stores information about hosts. - HostDB interface { + // A HostStore stores information about hosts. + HostStore interface { Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) HostAllowlist(ctx context.Context) ([]types.PublicKey, error) HostBlocklist(ctx context.Context) ([]string, error) @@ -172,30 +270,7 @@ type ( UpdateSlab(ctx context.Context, s object.Slab, contractSet string) error } - // An AutopilotStore stores autopilots. - AutopilotStore interface { - Autopilot(ctx context.Context, id string) (api.Autopilot, error) - Autopilots(ctx context.Context) ([]api.Autopilot, error) - UpdateAutopilot(ctx context.Context, ap api.Autopilot) error - } - - // A SettingStore stores settings. - SettingStore interface { - DeleteSetting(ctx context.Context, key string) error - Setting(ctx context.Context, key string) (string, error) - Settings(ctx context.Context) ([]string, error) - UpdateSetting(ctx context.Context, key, value string) error - } - - // EphemeralAccountStore persists information about accounts. Since accounts - // are rapidly updated and can be recovered, they are only loaded upon - // startup and persisted upon shutdown. - EphemeralAccountStore interface { - Accounts(context.Context) ([]api.Account, error) - SaveAccounts(context.Context, []api.Account) error - SetUncleanShutdown(context.Context) error - } - + // A MetricsStore stores metrics. MetricsStore interface { ContractSetMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractSetMetricsQueryOpts) ([]api.ContractSetMetric, error) @@ -210,37 +285,101 @@ type ( RecordContractSetChurnMetric(ctx context.Context, metrics ...api.ContractSetChurnMetric) error WalletMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.WalletMetricsQueryOpts) ([]api.WalletMetric, error) + RecordWalletMetric(ctx context.Context, metrics ...api.WalletMetric) error + } + + // A SettingStore stores settings. 
+ SettingStore interface { + DeleteSetting(ctx context.Context, key string) error + Setting(ctx context.Context, key string) (string, error) + Settings(ctx context.Context) ([]string, error) + UpdateSetting(ctx context.Context, key, value string) error + } + + WalletMetricsRecorder interface { + Shutdown(context.Context) error } ) -type bus struct { +type Bus struct { startTime time.Time - cm ChainManager - s Syncer - tp TransactionPool + accountsMgr AccountManager + alerts alerts.Alerter + alertMgr AlertManager + pinMgr PinManager + webhooksMgr WebhooksManager + cm ChainManager + cs ChainSubscriber + s Syncer + w Wallet as AutopilotStore - eas EphemeralAccountStore - hdb HostDB + hs HostStore ms MetadataStore - ss SettingStore mtrcs MetricsStore - w Wallet + ss SettingStore - accounts *accounts - contractLocks *contractLocks - uploadingSectors *uploadingSectorsCache + contractLocker ContractLocker + sectors UploadingSectorsCache + walletMetricsRecorder WalletMetricsRecorder - alerts alerts.Alerter - alertMgr *alerts.Manager - pinMgr ibus.PinManager - webhooksMgr *webhooks.Manager - logger *zap.SugaredLogger + logger *zap.SugaredLogger +} + +// New returns a new Bus +func New(ctx context.Context, am AlertManager, wm WebhooksManager, cm ChainManager, s Syncer, w Wallet, store Store, announcementMaxAge time.Duration, l *zap.Logger) (_ *Bus, err error) { + l = l.Named("bus") + + b := &Bus{ + s: s, + cm: cm, + w: w, + hs: store, + as: store, + ms: store, + mtrcs: store, + ss: store, + + alerts: alerts.WithOrigin(am, "bus"), + alertMgr: am, + webhooksMgr: wm, + logger: l.Sugar(), + + startTime: time.Now(), + } + + // init settings + if err := b.initSettings(ctx); err != nil { + return nil, err + } + + // create account manager + b.accountsMgr, err = ibus.NewAccountManager(ctx, store, l) + if err != nil { + return nil, err + } + + // create contract locker + b.contractLocker = ibus.NewContractLocker() + + // create sectors cache + b.sectors = ibus.NewSectorsCache() + + // create pin manager + b.pinMgr = ibus.NewPinManager(b.alerts, wm, store, defaultPinUpdateInterval, defaultPinRateWindow, l) + + // create chain subscriber + b.cs = ibus.NewChainSubscriber(wm, cm, store, w, announcementMaxAge, l) + + // create wallet metrics recorder + b.walletMetricsRecorder = ibus.NewWalletMetricRecorder(store, w, defaultWalletRecordMetricInterval, l) + + return b, nil } // Handler returns an HTTP handler that serves the bus API. -func (b *bus) Handler() http.Handler { +func (b *Bus) Handler() http.Handler { return jape.Mux(map[string]jape.Handler{ "GET /accounts": b.accountsHandlerGET, "POST /account/:id": b.accountHandlerGET, @@ -371,6 +510,7 @@ func (b *bus) Handler() http.Handler { "POST /wallet/prepare/form": b.walletPrepareFormHandler, "POST /wallet/prepare/renew": b.walletPrepareRenewHandler, "POST /wallet/redistribute": b.walletRedistributeHandler, + "POST /wallet/send": b.walletSendSiacoinsHandler, "POST /wallet/sign": b.walletSignHandler, "GET /wallet/transactions": b.walletTransactionsHandler, @@ -381,2205 +521,127 @@ func (b *bus) Handler() http.Handler { }) } -// Setup starts the pin manager. -func (b *bus) Setup(ctx context.Context) error { - return b.pinMgr.Run(ctx) -} - // Shutdown shuts down the bus. 
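(Editorial note: every dependency of the renamed Bus now arrives as an interface, so wiring it up is plain dependency injection. A sketch of a call site against the New signature above; the listen address and the 24h announcementMaxAge are illustrative values, and constructing the concrete implementations is out of scope for this diff:)

```go
package example

import (
	"context"
	"net/http"
	"time"

	"go.sia.tech/renterd/bus"
	"go.uber.org/zap"
)

// ServeBus shows the wiring: tests can inject fakes the same way cmd/renterd
// injects the real chain manager, syncer, wallet and SQL store.
func ServeBus(ctx context.Context, am bus.AlertManager, wm bus.WebhooksManager, cm bus.ChainManager, s bus.Syncer, w bus.Wallet, store bus.Store, l *zap.Logger) error {
	b, err := bus.New(ctx, am, wm, cm, s, w, store, 24*time.Hour, l)
	if err != nil {
		return err
	}
	defer b.Shutdown(context.Background())

	// b.Handler() is a plain http.Handler built by jape.Mux.
	srv := &http.Server{Addr: "localhost:9980", Handler: b.Handler()}
	return srv.ListenAndServe()
}
```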
-func (b *bus) Shutdown(ctx context.Context) error { - b.webhooksMgr.Close() - accounts := b.accounts.ToPersist() - err := b.eas.SaveAccounts(ctx, accounts) - if err != nil { - b.logger.Errorf("failed to save %v accounts: %v", len(accounts), err) - } else { - b.logger.Infof("successfully saved %v accounts", len(accounts)) - } - +func (b *Bus) Shutdown(ctx context.Context) error { return errors.Join( - err, - b.pinMgr.Close(ctx), + b.walletMetricsRecorder.Shutdown(ctx), + b.accountsMgr.Shutdown(ctx), + b.webhooksMgr.Shutdown(ctx), + b.pinMgr.Shutdown(ctx), + b.cs.Shutdown(ctx), ) } -func (b *bus) fetchSetting(ctx context.Context, key string, value interface{}) error { - if val, err := b.ss.Setting(ctx, key); err != nil { - return fmt.Errorf("could not get contract set settings: %w", err) - } else if err := json.Unmarshal([]byte(val), &value); err != nil { - b.logger.Panicf("failed to unmarshal %v settings '%s': %v", key, val, err) - } - return nil -} - -func (b *bus) consensusAcceptBlock(jc jape.Context) { - var block types.Block - if jc.Decode(&block) != nil { - return - } - if jc.Check("failed to accept block", b.cm.AcceptBlock(block)) != nil { - return - } -} - -func (b *bus) syncerAddrHandler(jc jape.Context) { - addr, err := b.s.SyncerAddress(jc.Request.Context()) - if jc.Check("failed to fetch syncer's address", err) != nil { - return - } - jc.Encode(addr) -} - -func (b *bus) syncerPeersHandler(jc jape.Context) { - jc.Encode(b.s.Peers()) -} - -func (b *bus) syncerConnectHandler(jc jape.Context) { - var addr string - if jc.Decode(&addr) == nil { - jc.Check("couldn't connect to peer", b.s.Connect(addr)) - } -} - -func (b *bus) consensusStateHandler(jc jape.Context) { - jc.Encode(b.consensusState()) -} - -func (b *bus) consensusNetworkHandler(jc jape.Context) { - jc.Encode(api.ConsensusNetwork{ - Name: b.cm.TipState().Network.Name, - }) -} - -func (b *bus) txpoolFeeHandler(jc jape.Context) { - fee := b.tp.RecommendedFee() - jc.Encode(fee) -} - -func (b *bus) txpoolTransactionsHandler(jc jape.Context) { - jc.Encode(b.tp.Transactions()) -} - -func (b *bus) txpoolBroadcastHandler(jc jape.Context) { - var txnSet []types.Transaction - if jc.Decode(&txnSet) == nil { - jc.Check("couldn't broadcast transaction set", b.tp.AcceptTransactionSet(txnSet)) - } -} - -func (b *bus) bucketsHandlerGET(jc jape.Context) { - resp, err := b.ms.ListBuckets(jc.Request.Context()) - if jc.Check("couldn't list buckets", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) bucketsHandlerPOST(jc jape.Context) { - var bucket api.BucketCreateRequest - if jc.Decode(&bucket) != nil { - return - } else if bucket.Name == "" { - jc.Error(errors.New("no name provided"), http.StatusBadRequest) - return - } else if jc.Check("failed to create bucket", b.ms.CreateBucket(jc.Request.Context(), bucket.Name, bucket.Policy)) != nil { - return - } -} - -func (b *bus) bucketsHandlerPolicyPUT(jc jape.Context) { - var req api.BucketUpdatePolicyRequest - if jc.Decode(&req) != nil { - return - } else if bucket := jc.PathParam("name"); bucket == "" { - jc.Error(errors.New("no bucket name provided"), http.StatusBadRequest) - return - } else if jc.Check("failed to create bucket", b.ms.UpdateBucketPolicy(jc.Request.Context(), bucket, req.Policy)) != nil { - return - } -} - -func (b *bus) bucketHandlerDELETE(jc jape.Context) { - var name string - if jc.DecodeParam("name", &name) != nil { - return - } else if name == "" { - jc.Error(errors.New("no name provided"), http.StatusBadRequest) - return - } else if jc.Check("failed to 
delete bucket", b.ms.DeleteBucket(jc.Request.Context(), name)) != nil { - return - } -} - -func (b *bus) bucketHandlerGET(jc jape.Context) { - var name string - if jc.DecodeParam("name", &name) != nil { - return - } else if name == "" { - jc.Error(errors.New("parameter 'name' is required"), http.StatusBadRequest) - return - } - bucket, err := b.ms.Bucket(jc.Request.Context(), name) - if errors.Is(err, api.ErrBucketNotFound) { - jc.Error(err, http.StatusNotFound) - return - } else if jc.Check("failed to fetch bucket", err) != nil { - return - } - jc.Encode(bucket) -} - -func (b *bus) walletHandler(jc jape.Context) { - address := b.w.Address() - spendable, confirmed, unconfirmed, err := b.w.Balance() - if jc.Check("couldn't fetch wallet balance", err) != nil { - return - } - jc.Encode(api.WalletResponse{ - ScanHeight: b.w.Height(), - Address: address, - Confirmed: confirmed, - Spendable: spendable, - Unconfirmed: unconfirmed, - }) -} - -func (b *bus) walletTransactionsHandler(jc jape.Context) { - var before, since time.Time - offset := 0 - limit := -1 - if jc.DecodeForm("before", (*api.TimeRFC3339)(&before)) != nil || - jc.DecodeForm("since", (*api.TimeRFC3339)(&since)) != nil || - jc.DecodeForm("offset", &offset) != nil || - jc.DecodeForm("limit", &limit) != nil { - return - } - txns, err := b.w.Transactions(before, since, offset, limit) - if jc.Check("couldn't load transactions", err) == nil { - jc.Encode(txns) - } -} - -func (b *bus) walletOutputsHandler(jc jape.Context) { - utxos, err := b.w.UnspentOutputs() - if jc.Check("couldn't load outputs", err) == nil { - jc.Encode(utxos) - } -} - -func (b *bus) walletFundHandler(jc jape.Context) { - var wfr api.WalletFundRequest - if jc.Decode(&wfr) != nil { - return - } - txn := wfr.Transaction - if len(txn.MinerFees) == 0 { - // if no fees are specified, we add some - fee := b.tp.RecommendedFee().Mul64(b.cm.TipState().TransactionWeight(txn)) - txn.MinerFees = []types.Currency{fee} - } - toSign, err := b.w.FundTransaction(b.cm.TipState(), &txn, wfr.Amount.Add(txn.MinerFees[0]), wfr.UseUnconfirmedTxns) - if jc.Check("couldn't fund transaction", err) != nil { - return - } - parents, err := b.tp.UnconfirmedParents(txn) - if jc.Check("couldn't load transaction dependencies", err) != nil { - b.w.ReleaseInputs(txn) - return - } - jc.Encode(api.WalletFundResponse{ - Transaction: txn, - ToSign: toSign, - DependsOn: parents, - }) -} - -func (b *bus) walletSignHandler(jc jape.Context) { - var wsr api.WalletSignRequest - if jc.Decode(&wsr) != nil { - return - } - err := b.w.SignTransaction(b.cm.TipState(), &wsr.Transaction, wsr.ToSign, wsr.CoveredFields) - if jc.Check("couldn't sign transaction", err) == nil { - jc.Encode(wsr.Transaction) - } -} - -func (b *bus) walletRedistributeHandler(jc jape.Context) { - var wfr api.WalletRedistributeRequest - if jc.Decode(&wfr) != nil { - return - } - if wfr.Outputs == 0 { - jc.Error(errors.New("'outputs' has to be greater than zero"), http.StatusBadRequest) - return - } - - cs := b.cm.TipState() - txns, toSign, err := b.w.Redistribute(cs, wfr.Outputs, wfr.Amount, b.tp.RecommendedFee(), b.tp.Transactions()) - if jc.Check("couldn't redistribute money in the wallet into the desired outputs", err) != nil { - return +// initSettings loads the default settings if the setting is not already set and +// ensures the settings are valid +func (b *Bus) initSettings(ctx context.Context) error { + // testnets have different redundancy settings + defaultRedundancySettings := api.DefaultRedundancySettings + if mn, _ := 
chain.Mainnet(); mn.Name != b.cm.TipState().Network.Name { + defaultRedundancySettings = api.DefaultRedundancySettingsTestnet } - var ids []types.TransactionID - if len(txns) == 0 { - jc.Encode(ids) - return - } - - for i := 0; i < len(txns); i++ { - err = b.w.SignTransaction(cs, &txns[i], toSign, types.CoveredFields{WholeTransaction: true}) - if jc.Check("couldn't sign the transaction", err) != nil { - b.w.ReleaseInputs(txns...) - return + // load default settings if the setting is not already set + for key, value := range map[string]interface{}{ + api.SettingGouging: api.DefaultGougingSettings, + api.SettingPricePinning: api.DefaultPricePinSettings, + api.SettingRedundancy: defaultRedundancySettings, + api.SettingUploadPacking: api.DefaultUploadPackingSettings, + } { + if _, err := b.ss.Setting(ctx, key); errors.Is(err, api.ErrSettingNotFound) { + if bytes, err := json.Marshal(value); err != nil { + panic("failed to marshal default settings") // should never happen + } else if err := b.ss.UpdateSetting(ctx, key, string(bytes)); err != nil { + return err + } } - ids = append(ids, txns[i].ID()) - } - - if jc.Check("couldn't broadcast the transaction", b.tp.AcceptTransactionSet(txns)) != nil { - b.w.ReleaseInputs(txns...) - return - } - - jc.Encode(ids) -} - -func (b *bus) walletDiscardHandler(jc jape.Context) { - var txn types.Transaction - if jc.Decode(&txn) == nil { - b.w.ReleaseInputs(txn) - } -} - -func (b *bus) walletPrepareFormHandler(jc jape.Context) { - var wpfr api.WalletPrepareFormRequest - if jc.Decode(&wpfr) != nil { - return - } - if wpfr.HostKey == (types.PublicKey{}) { - jc.Error(errors.New("no host key provided"), http.StatusBadRequest) - return - } - if wpfr.RenterKey == (types.PublicKey{}) { - jc.Error(errors.New("no renter key provided"), http.StatusBadRequest) - return - } - cs := b.cm.TipState() - - fc := rhpv2.PrepareContractFormation(wpfr.RenterKey, wpfr.HostKey, wpfr.RenterFunds, wpfr.HostCollateral, wpfr.EndHeight, wpfr.HostSettings, wpfr.RenterAddress) - cost := rhpv2.ContractFormationCost(cs, fc, wpfr.HostSettings.ContractPrice) - txn := types.Transaction{ - FileContracts: []types.FileContract{fc}, - } - txn.MinerFees = []types.Currency{b.tp.RecommendedFee().Mul64(cs.TransactionWeight(txn))} - toSign, err := b.w.FundTransaction(cs, &txn, cost.Add(txn.MinerFees[0]), true) - if jc.Check("couldn't fund transaction", err) != nil { - return - } - cf := wallet.ExplicitCoveredFields(txn) - err = b.w.SignTransaction(cs, &txn, toSign, cf) - if jc.Check("couldn't sign transaction", err) != nil { - b.w.ReleaseInputs(txn) - return - } - parents, err := b.tp.UnconfirmedParents(txn) - if jc.Check("couldn't load transaction dependencies", err) != nil { - b.w.ReleaseInputs(txn) - return - } - jc.Encode(append(parents, txn)) -} - -func (b *bus) walletPrepareRenewHandler(jc jape.Context) { - var wprr api.WalletPrepareRenewRequest - if jc.Decode(&wprr) != nil { - return - } - if wprr.RenterKey == nil { - jc.Error(errors.New("no renter key provided"), http.StatusBadRequest) - return - } - cs := b.cm.TipState() - - // Create the final revision from the provided revision. - finalRevision := wprr.Revision - finalRevision.MissedProofOutputs = finalRevision.ValidProofOutputs - finalRevision.Filesize = 0 - finalRevision.FileMerkleRoot = types.Hash256{} - finalRevision.RevisionNumber = math.MaxUint64 - - // Prepare the new contract. 
- fc, basePrice, err := rhpv3.PrepareContractRenewal(wprr.Revision, wprr.HostAddress, wprr.RenterAddress, wprr.RenterFunds, wprr.MinNewCollateral, wprr.PriceTable, wprr.ExpectedNewStorage, wprr.EndHeight) - if jc.Check("couldn't prepare contract renewal", err) != nil { - return - } - - // Create the transaction containing both the final revision and new - // contract. - txn := types.Transaction{ - FileContracts: []types.FileContract{fc}, - FileContractRevisions: []types.FileContractRevision{finalRevision}, - MinerFees: []types.Currency{wprr.PriceTable.TxnFeeMaxRecommended.Mul64(4096)}, - } - - // Compute how much renter funds to put into the new contract. - cost := rhpv3.ContractRenewalCost(cs, wprr.PriceTable, fc, txn.MinerFees[0], basePrice) - - // Make sure we don't exceed the max fund amount. - // TODO: remove the IsZero check for the v2 change - if /*!wprr.MaxFundAmount.IsZero() &&*/ wprr.MaxFundAmount.Cmp(cost) < 0 { - jc.Error(fmt.Errorf("%w: %v > %v", api.ErrMaxFundAmountExceeded, cost, wprr.MaxFundAmount), http.StatusBadRequest) - return } - // Fund the txn. We are not signing it yet since it's not complete. The host - // still needs to complete it and the revision + contract are signed with - // the renter key by the worker. - toSign, err := b.w.FundTransaction(cs, &txn, cost, true) - if jc.Check("couldn't fund transaction", err) != nil { - return - } - - // Add any required parents. - parents, err := b.tp.UnconfirmedParents(txn) - if jc.Check("couldn't load transaction dependencies", err) != nil { - b.w.ReleaseInputs(txn) - return + // check redundancy settings for validity + var rs api.RedundancySettings + if rss, err := b.ss.Setting(ctx, api.SettingRedundancy); err != nil { + return err + } else if err := json.Unmarshal([]byte(rss), &rs); err != nil { + return err + } else if err := rs.Validate(); err != nil { + b.logger.Warn(fmt.Sprintf("invalid redundancy setting found '%v', overwriting the redundancy settings with the default settings", rss)) + bytes, _ := json.Marshal(defaultRedundancySettings) + if err := b.ss.UpdateSetting(ctx, api.SettingRedundancy, string(bytes)); err != nil { + return err + } } - jc.Encode(api.WalletPrepareRenewResponse{ - FundAmount: cost, - ToSign: toSign, - TransactionSet: append(parents, txn), - }) -} -func (b *bus) walletPendingHandler(jc jape.Context) { - isRelevant := func(txn types.Transaction) bool { - addr := b.w.Address() - for _, sci := range txn.SiacoinInputs { - if sci.UnlockConditions.UnlockHash() == addr { - return true + // check gouging settings for validity + var gs api.GougingSettings + if gss, err := b.ss.Setting(ctx, api.SettingGouging); err != nil { + return err + } else if err := json.Unmarshal([]byte(gss), &gs); err != nil { + return err + } else if err := gs.Validate(); err != nil { + // compat: apply default EA gouging settings + gs.MinMaxEphemeralAccountBalance = api.DefaultGougingSettings.MinMaxEphemeralAccountBalance + gs.MinPriceTableValidity = api.DefaultGougingSettings.MinPriceTableValidity + gs.MinAccountExpiry = api.DefaultGougingSettings.MinAccountExpiry + if err := gs.Validate(); err == nil { + b.logger.Info(fmt.Sprintf("updating gouging settings with default EA settings: %+v", gs)) + bytes, _ := json.Marshal(gs) + if err := b.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)); err != nil { + return err } - } - for _, sco := range txn.SiacoinOutputs { - if sco.Address == addr { - return true + } else { + // compat: apply default host block leeway settings + gs.HostBlockHeightLeeway = 
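(Editorial note, referring back to the new initSettings above: defaults are only written when a key is missing, so user overrides survive restarts — this is what replaces the per-build-tag defaults from the deleted build/env_*.go files. A condensed sketch of that bootstrap loop; unlike the original it also surfaces unexpected store errors instead of skipping them:)

```go
package example

import (
	"context"
	"encoding/json"
	"errors"

	"go.sia.tech/renterd/api"
	"go.sia.tech/renterd/bus"
)

// SeedDefaults writes each default only if its key is absent, mirroring the
// loop in initSettings; existing values are never overwritten.
func SeedDefaults(ctx context.Context, ss bus.SettingStore, defaults map[string]interface{}) error {
	for key, value := range defaults {
		_, err := ss.Setting(ctx, key)
		if err == nil {
			continue // already set, leave the user's value alone
		} else if !errors.Is(err, api.ErrSettingNotFound) {
			return err
		}
		b, err := json.Marshal(value)
		if err != nil {
			return err // defaults are static, so this should never happen
		}
		if err := ss.UpdateSetting(ctx, key, string(b)); err != nil {
			return err
		}
	}
	return nil
}
```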
api.DefaultGougingSettings.HostBlockHeightLeeway + if err := gs.Validate(); err == nil { + b.logger.Info(fmt.Sprintf("updating gouging settings with default HostBlockHeightLeeway settings: %v", gs)) + bytes, _ := json.Marshal(gs) + if err := b.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)); err != nil { + return err + } + } else { + b.logger.Warn(fmt.Sprintf("invalid gouging setting found '%v', overwriting the gouging settings with the default settings", gss)) + bytes, _ := json.Marshal(api.DefaultGougingSettings) + if err := b.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)); err != nil { + return err + } } } - return false } - txns := b.tp.Transactions() - relevant := txns[:0] - for _, txn := range txns { - if isRelevant(txn) { - relevant = append(relevant, txn) + // compat: default price pin settings + var pps api.PricePinSettings + if pss, err := b.ss.Setting(ctx, api.SettingPricePinning); err != nil { + return err + } else if err := json.Unmarshal([]byte(pss), &pps); err != nil { + return err + } else if err := pps.Validate(); err != nil { + // overwrite values with defaults + var updates []string + if pps.ForexEndpointURL == "" { + pps.ForexEndpointURL = api.DefaultPricePinSettings.ForexEndpointURL + updates = append(updates, fmt.Sprintf("set PricePinSettings.ForexEndpointURL to %v", pps.ForexEndpointURL)) + } + if pps.Currency == "" { + pps.Currency = api.DefaultPricePinSettings.Currency + updates = append(updates, fmt.Sprintf("set PricePinSettings.Currency to %v", pps.Currency)) + } + if pps.Threshold == 0 { + pps.Threshold = api.DefaultPricePinSettings.Threshold + updates = append(updates, fmt.Sprintf("set PricePinSettings.Threshold to %v", pps.Threshold)) } - } - jc.Encode(relevant) -} - -func (b *bus) hostsHandlerGETDeprecated(jc jape.Context) { - offset := 0 - limit := -1 - if jc.DecodeForm("offset", &offset) != nil || jc.DecodeForm("limit", &limit) != nil { - return - } - - // fetch hosts - hosts, err := b.hdb.SearchHosts(jc.Request.Context(), "", api.HostFilterModeAllowed, api.UsabilityFilterModeAll, "", nil, offset, limit) - if jc.Check(fmt.Sprintf("couldn't fetch hosts %d-%d", offset, offset+limit), err) != nil { - return - } - jc.Encode(hosts) -} - -func (b *bus) searchHostsHandlerPOST(jc jape.Context) { - var req api.SearchHostsRequest - if jc.Decode(&req) != nil { - return - } - - // TODO: on the next major release: - // - properly default search params (currently no defaults are set) - // - properly validate and return 400 (currently validation is done in autopilot and the store) - - hosts, err := b.hdb.SearchHosts(jc.Request.Context(), req.AutopilotID, req.FilterMode, req.UsabilityMode, req.AddressContains, req.KeyIn, req.Offset, req.Limit) - if jc.Check(fmt.Sprintf("couldn't fetch hosts %d-%d", req.Offset, req.Offset+req.Limit), err) != nil { - return - } - jc.Encode(hosts) -} -func (b *bus) hostsRemoveHandlerPOST(jc jape.Context) { - var hrr api.HostsRemoveRequest - if jc.Decode(&hrr) != nil { - return - } - if hrr.MaxDowntimeHours == 0 { - jc.Error(errors.New("maxDowntime must be non-zero"), http.StatusBadRequest) - return - } - if hrr.MinRecentScanFailures == 0 { - jc.Error(errors.New("minRecentScanFailures must be non-zero"), http.StatusBadRequest) - return - } - removed, err := b.hdb.RemoveOfflineHosts(jc.Request.Context(), hrr.MinRecentScanFailures, time.Duration(hrr.MaxDowntimeHours)) - if jc.Check("couldn't remove offline hosts", err) != nil { - return - } - jc.Encode(removed) -} + var updated []byte + if err := pps.Validate(); err == nil { + 
b.logger.Info(fmt.Sprintf("updating price pinning settings with default values: %v", strings.Join(updates, ", "))) + updated, _ = json.Marshal(pps) + } else { + b.logger.Warn(fmt.Sprintf("updated price pinning settings are invalid (%v), they have been overwritten with the default settings", err)) + updated, _ = json.Marshal(api.DefaultPricePinSettings) + } -func (b *bus) hostsScanningHandlerGET(jc jape.Context) { - offset := 0 - limit := -1 - maxLastScan := time.Now() - if jc.DecodeForm("offset", &offset) != nil || jc.DecodeForm("limit", &limit) != nil || jc.DecodeForm("lastScan", (*api.TimeRFC3339)(&maxLastScan)) != nil { - return - } - hosts, err := b.hdb.HostsForScanning(jc.Request.Context(), maxLastScan, offset, limit) - if jc.Check(fmt.Sprintf("couldn't fetch hosts %d-%d", offset, offset+limit), err) != nil { - return + if err := b.ss.UpdateSetting(ctx, api.SettingPricePinning, string(updated)); err != nil { + return err + } } - jc.Encode(hosts) -} -func (b *bus) hostsPubkeyHandlerGET(jc jape.Context) { - var hostKey types.PublicKey - if jc.DecodeParam("hostkey", &hostKey) != nil { - return - } - host, err := b.hdb.Host(jc.Request.Context(), hostKey) - if jc.Check("couldn't load host", err) == nil { - jc.Encode(host) - } -} - -func (b *bus) hostsResetLostSectorsPOST(jc jape.Context) { - var hostKey types.PublicKey - if jc.DecodeParam("hostkey", &hostKey) != nil { - return - } - err := b.hdb.ResetLostSectors(jc.Request.Context(), hostKey) - if jc.Check("couldn't reset lost sectors", err) != nil { - return - } -} - -func (b *bus) hostsScanHandlerPOST(jc jape.Context) { - var req api.HostsScanRequest - if jc.Decode(&req) != nil { - return - } - if jc.Check("failed to record scans", b.hdb.RecordHostScans(jc.Request.Context(), req.Scans)) != nil { - return - } -} - -func (b *bus) hostsPricetableHandlerPOST(jc jape.Context) { - var req api.HostsPriceTablesRequest - if jc.Decode(&req) != nil { - return - } - if jc.Check("failed to record interactions", b.hdb.RecordPriceTables(jc.Request.Context(), req.PriceTableUpdates)) != nil { - return - } -} - -func (b *bus) contractsSpendingHandlerPOST(jc jape.Context) { - var records []api.ContractSpendingRecord - if jc.Decode(&records) != nil { - return - } - if jc.Check("failed to record spending metrics for contract", b.ms.RecordContractSpending(jc.Request.Context(), records)) != nil { - return - } -} - -func (b *bus) hostsAllowlistHandlerGET(jc jape.Context) { - allowlist, err := b.hdb.HostAllowlist(jc.Request.Context()) - if jc.Check("couldn't load allowlist", err) == nil { - jc.Encode(allowlist) - } -} - -func (b *bus) hostsAllowlistHandlerPUT(jc jape.Context) { - ctx := jc.Request.Context() - var req api.UpdateAllowlistRequest - if jc.Decode(&req) == nil { - if len(req.Add)+len(req.Remove) > 0 && req.Clear { - jc.Error(errors.New("cannot add or remove entries while clearing the allowlist"), http.StatusBadRequest) - return - } else if jc.Check("couldn't update allowlist entries", b.hdb.UpdateHostAllowlistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil { - return - } - } -} - -func (b *bus) hostsBlocklistHandlerGET(jc jape.Context) { - blocklist, err := b.hdb.HostBlocklist(jc.Request.Context()) - if jc.Check("couldn't load blocklist", err) == nil { - jc.Encode(blocklist) - } -} - -func (b *bus) hostsBlocklistHandlerPUT(jc jape.Context) { - ctx := jc.Request.Context() - var req api.UpdateBlocklistRequest - if jc.Decode(&req) == nil { - if len(req.Add)+len(req.Remove) > 0 && req.Clear { - jc.Error(errors.New("cannot add or remove entries while 
clearing the blocklist"), http.StatusBadRequest) - return - } else if jc.Check("couldn't update blocklist entries", b.hdb.UpdateHostBlocklistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil { - return - } - } -} - -func (b *bus) contractsHandlerGET(jc jape.Context) { - var cs string - if jc.DecodeForm("contractset", &cs) != nil { - return - } - contracts, err := b.ms.Contracts(jc.Request.Context(), api.ContractsOpts{ - ContractSet: cs, - }) - if jc.Check("couldn't load contracts", err) == nil { - jc.Encode(contracts) - } -} - -func (b *bus) contractsRenewedIDHandlerGET(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - - md, err := b.ms.RenewedContract(jc.Request.Context(), id) - if jc.Check("faild to fetch renewed contract", err) == nil { - jc.Encode(md) - } -} - -func (b *bus) contractsArchiveHandlerPOST(jc jape.Context) { - var toArchive api.ContractsArchiveRequest - if jc.Decode(&toArchive) != nil { - return - } - - if jc.Check("failed to archive contracts", b.ms.ArchiveContracts(jc.Request.Context(), toArchive)) == nil { - for fcid, reason := range toArchive { - b.broadcastAction(webhooks.Event{ - Module: api.ModuleContract, - Event: api.EventArchive, - Payload: api.EventContractArchive{ - ContractID: fcid, - Reason: reason, - Timestamp: time.Now().UTC(), - }, - }) - } - } -} - -func (b *bus) contractsSetsHandlerGET(jc jape.Context) { - sets, err := b.ms.ContractSets(jc.Request.Context()) - if jc.Check("couldn't fetch contract sets", err) == nil { - jc.Encode(sets) - } -} - -func (b *bus) contractsSetHandlerPUT(jc jape.Context) { - var contractIds []types.FileContractID - if set := jc.PathParam("set"); set == "" { - jc.Error(errors.New("path parameter 'set' can not be empty"), http.StatusBadRequest) - return - } else if jc.Decode(&contractIds) != nil { - return - } else if jc.Check("could not add contracts to set", b.ms.SetContractSet(jc.Request.Context(), set, contractIds)) != nil { - return - } else { - b.broadcastAction(webhooks.Event{ - Module: api.ModuleContractSet, - Event: api.EventUpdate, - Payload: api.EventContractSetUpdate{ - Name: set, - ContractIDs: contractIds, - Timestamp: time.Now().UTC(), - }, - }) - } -} - -func (b *bus) contractsSetHandlerDELETE(jc jape.Context) { - if set := jc.PathParam("set"); set != "" { - jc.Check("could not remove contract set", b.ms.RemoveContractSet(jc.Request.Context(), set)) - } -} - -func (b *bus) contractAcquireHandlerPOST(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.ContractAcquireRequest - if jc.Decode(&req) != nil { - return - } - - lockID, err := b.contractLocks.Acquire(jc.Request.Context(), req.Priority, id, time.Duration(req.Duration)) - if jc.Check("failed to acquire contract", err) != nil { - return - } - jc.Encode(api.ContractAcquireResponse{ - LockID: lockID, - }) -} - -func (b *bus) contractKeepaliveHandlerPOST(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.ContractKeepaliveRequest - if jc.Decode(&req) != nil { - return - } - - err := b.contractLocks.KeepAlive(id, req.LockID, time.Duration(req.Duration)) - if jc.Check("failed to extend lock duration", err) != nil { - return - } -} - -func (b *bus) contractsPrunableDataHandlerGET(jc jape.Context) { - sizes, err := b.ms.ContractSizes(jc.Request.Context()) - if jc.Check("failed to fetch contract sizes", err) != nil { - return - } - - // prepare the response - var contracts 
[]api.ContractPrunableData - var totalPrunable, totalSize uint64 - - // build the response - for fcid, size := range sizes { - // adjust the amount of prunable data with the pending uploads, due to - // how we record contract spending a contract's size might already - // include pending sectors - pending := b.uploadingSectors.Pending(fcid) - if pending > size.Prunable { - size.Prunable = 0 - } else { - size.Prunable -= pending - } - - contracts = append(contracts, api.ContractPrunableData{ - ID: fcid, - ContractSize: size, - }) - totalPrunable += size.Prunable - totalSize += size.Size - } - - // sort contracts by the amount of prunable data - sort.Slice(contracts, func(i, j int) bool { - if contracts[i].Prunable == contracts[j].Prunable { - return contracts[i].Size > contracts[j].Size - } - return contracts[i].Prunable > contracts[j].Prunable - }) - - jc.Encode(api.ContractsPrunableDataResponse{ - Contracts: contracts, - TotalPrunable: totalPrunable, - TotalSize: totalSize, - }) -} - -func (b *bus) contractSizeHandlerGET(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - - size, err := b.ms.ContractSize(jc.Request.Context(), id) - if errors.Is(err, api.ErrContractNotFound) { - jc.Error(err, http.StatusNotFound) - return - } else if jc.Check("failed to fetch contract size", err) != nil { - return - } - - // adjust the amount of prunable data with the pending uploads, due to how - // we record contract spending a contract's size might already include - // pending sectors - pending := b.uploadingSectors.Pending(id) - if pending > size.Prunable { - size.Prunable = 0 - } else { - size.Prunable -= pending - } - - jc.Encode(size) -} - -func (b *bus) contractReleaseHandlerPOST(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.ContractReleaseRequest - if jc.Decode(&req) != nil { - return - } - if jc.Check("failed to release contract", b.contractLocks.Release(id, req.LockID)) != nil { - return - } -} - -func (b *bus) contractIDHandlerGET(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - c, err := b.ms.Contract(jc.Request.Context(), id) - if jc.Check("couldn't load contract", err) == nil { - jc.Encode(c) - } -} - -func (b *bus) contractIDHandlerPOST(jc jape.Context) { - var id types.FileContractID - var req api.ContractAddRequest - if jc.DecodeParam("id", &id) != nil || jc.Decode(&req) != nil { - return - } - if req.Contract.ID() != id { - http.Error(jc.ResponseWriter, "contract ID mismatch", http.StatusBadRequest) - return - } - if req.TotalCost.IsZero() { - http.Error(jc.ResponseWriter, "TotalCost can not be zero", http.StatusBadRequest) - return - } - - a, err := b.ms.AddContract(jc.Request.Context(), req.Contract, req.ContractPrice, req.TotalCost, req.StartHeight, req.State) - if jc.Check("couldn't store contract", err) == nil { - jc.Encode(a) - } -} - -func (b *bus) contractIDRenewedHandlerPOST(jc jape.Context) { - var id types.FileContractID - var req api.ContractRenewedRequest - if jc.DecodeParam("id", &id) != nil || jc.Decode(&req) != nil { - return - } - if req.Contract.ID() != id { - http.Error(jc.ResponseWriter, "contract ID mismatch", http.StatusBadRequest) - return - } - if req.TotalCost.IsZero() { - http.Error(jc.ResponseWriter, "TotalCost can not be zero", http.StatusBadRequest) - return - } - if req.State == "" { - req.State = api.ContractStatePending - } - r, err := b.ms.AddRenewedContract(jc.Request.Context(), 
req.Contract, req.ContractPrice, req.TotalCost, req.StartHeight, req.RenewedFrom, req.State) - if jc.Check("couldn't store contract", err) != nil { - return - } - - b.uploadingSectors.HandleRenewal(req.Contract.ID(), req.RenewedFrom) - b.broadcastAction(webhooks.Event{ - Module: api.ModuleContract, - Event: api.EventRenew, - Payload: api.EventContractRenew{ - Renewal: r, - Timestamp: time.Now().UTC(), - }, - }) - - jc.Encode(r) -} - -func (b *bus) contractIDRootsHandlerGET(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - - roots, err := b.ms.ContractRoots(jc.Request.Context(), id) - if jc.Check("couldn't fetch contract sectors", err) == nil { - jc.Encode(api.ContractRootsResponse{ - Roots: roots, - Uploading: b.uploadingSectors.Sectors(id), - }) - } -} - -func (b *bus) contractIDHandlerDELETE(jc jape.Context) { - var id types.FileContractID - if jc.DecodeParam("id", &id) != nil { - return - } - jc.Check("couldn't remove contract", b.ms.ArchiveContract(jc.Request.Context(), id, api.ContractArchivalReasonRemoved)) -} - -func (b *bus) contractsAllHandlerDELETE(jc jape.Context) { - jc.Check("couldn't remove contracts", b.ms.ArchiveAllContracts(jc.Request.Context(), api.ContractArchivalReasonRemoved)) -} - -func (b *bus) searchObjectsHandlerGET(jc jape.Context) { - offset := 0 - limit := -1 - var key string - if jc.DecodeForm("offset", &offset) != nil || jc.DecodeForm("limit", &limit) != nil || jc.DecodeForm("key", &key) != nil { - return - } - bucket := api.DefaultBucketName - if jc.DecodeForm("bucket", &bucket) != nil { - return - } - keys, err := b.ms.SearchObjects(jc.Request.Context(), bucket, key, offset, limit) - if jc.Check("couldn't list objects", err) != nil { - return - } - jc.Encode(keys) -} - -func (b *bus) objectsHandlerGET(jc jape.Context) { - var ignoreDelim bool - if jc.DecodeForm("ignoreDelim", &ignoreDelim) != nil { - return - } - path := jc.PathParam("path") - if strings.HasSuffix(path, "/") && !ignoreDelim { - b.objectEntriesHandlerGET(jc, path) - return - } - bucket := api.DefaultBucketName - if jc.DecodeForm("bucket", &bucket) != nil { - return - } - var onlymetadata bool - if jc.DecodeForm("onlymetadata", &onlymetadata) != nil { - return - } - - var o api.Object - var err error - if onlymetadata { - o, err = b.ms.ObjectMetadata(jc.Request.Context(), bucket, path) - } else { - o, err = b.ms.Object(jc.Request.Context(), bucket, path) - } - if errors.Is(err, api.ErrObjectNotFound) { - jc.Error(err, http.StatusNotFound) - return - } else if jc.Check("couldn't load object", err) != nil { - return - } - jc.Encode(api.ObjectsResponse{Object: &o}) -} - -func (b *bus) objectEntriesHandlerGET(jc jape.Context, path string) { - bucket := api.DefaultBucketName - if jc.DecodeForm("bucket", &bucket) != nil { - return - } - - var prefix string - if jc.DecodeForm("prefix", &prefix) != nil { - return - } - - var sortBy string - if jc.DecodeForm("sortBy", &sortBy) != nil { - return - } - - var sortDir string - if jc.DecodeForm("sortDir", &sortDir) != nil { - return - } - - var marker string - if jc.DecodeForm("marker", &marker) != nil { - return - } - - var offset int - if jc.DecodeForm("offset", &offset) != nil { - return - } - limit := -1 - if jc.DecodeForm("limit", &limit) != nil { - return - } - - // look for object entries - entries, hasMore, err := b.ms.ObjectEntries(jc.Request.Context(), bucket, path, prefix, sortBy, sortDir, marker, offset, limit) - if jc.Check("couldn't list object entries", err) != nil { - return - } - - 
jc.Encode(api.ObjectsResponse{Entries: entries, HasMore: hasMore}) -} - -func (b *bus) objectsHandlerPUT(jc jape.Context) { - var aor api.AddObjectRequest - if jc.Decode(&aor) != nil { - return - } else if aor.Bucket == "" { - aor.Bucket = api.DefaultBucketName - } - jc.Check("couldn't store object", b.ms.UpdateObject(jc.Request.Context(), aor.Bucket, jc.PathParam("path"), aor.ContractSet, aor.ETag, aor.MimeType, aor.Metadata, aor.Object)) -} - -func (b *bus) objectsCopyHandlerPOST(jc jape.Context) { - var orr api.CopyObjectsRequest - if jc.Decode(&orr) != nil { - return - } - om, err := b.ms.CopyObject(jc.Request.Context(), orr.SourceBucket, orr.DestinationBucket, orr.SourcePath, orr.DestinationPath, orr.MimeType, orr.Metadata) - if jc.Check("couldn't copy object", err) != nil { - return - } - - jc.ResponseWriter.Header().Set("Last-Modified", om.ModTime.Std().Format(http.TimeFormat)) - jc.ResponseWriter.Header().Set("ETag", api.FormatETag(om.ETag)) - jc.Encode(om) -} - -func (b *bus) objectsListHandlerPOST(jc jape.Context) { - var req api.ObjectsListRequest - if jc.Decode(&req) != nil { - return - } - if req.Bucket == "" { - req.Bucket = api.DefaultBucketName - } - resp, err := b.ms.ListObjects(jc.Request.Context(), req.Bucket, req.Prefix, req.SortBy, req.SortDir, req.Marker, req.Limit) - if errors.Is(err, api.ErrMarkerNotFound) { - jc.Error(err, http.StatusBadRequest) - return - } else if jc.Check("couldn't list objects", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) objectsRenameHandlerPOST(jc jape.Context) { - var orr api.ObjectsRenameRequest - if jc.Decode(&orr) != nil { - return - } else if orr.Bucket == "" { - orr.Bucket = api.DefaultBucketName - } - if orr.Mode == api.ObjectsRenameModeSingle { - // Single object rename. - if strings.HasSuffix(orr.From, "/") || strings.HasSuffix(orr.To, "/") { - jc.Error(fmt.Errorf("can't rename dirs with mode %v", orr.Mode), http.StatusBadRequest) - return - } - jc.Check("couldn't rename object", b.ms.RenameObject(jc.Request.Context(), orr.Bucket, orr.From, orr.To, orr.Force)) - return - } else if orr.Mode == api.ObjectsRenameModeMulti { - // Multi object rename. - if !strings.HasSuffix(orr.From, "/") || !strings.HasSuffix(orr.To, "/") { - jc.Error(fmt.Errorf("can't rename file with mode %v", orr.Mode), http.StatusBadRequest) - return - } - jc.Check("couldn't rename objects", b.ms.RenameObjects(jc.Request.Context(), orr.Bucket, orr.From, orr.To, orr.Force)) - return - } else { - // Invalid mode. 
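(Editorial note: the rename handler above — removed here as part of the bus refactor — distinguishes two modes: ObjectsRenameModeSingle rejects trailing-slash paths while ObjectsRenameModeMulti requires them on both sides. A hedged client-side sketch of that rule; the helper is illustrative, not part of renterd:)

```go
package example

import (
	"fmt"
	"strings"

	"go.sia.tech/renterd/api"
)

// RenameMode picks the rename mode matching the shape of from/to, mirroring
// the handler's validation: directories (trailing slash) must use the multi
// mode, single objects must not carry a trailing slash.
func RenameMode(from, to string) (string, error) {
	fromDir := strings.HasSuffix(from, "/")
	toDir := strings.HasSuffix(to, "/")
	switch {
	case fromDir && toDir:
		return api.ObjectsRenameModeMulti, nil
	case !fromDir && !toDir:
		return api.ObjectsRenameModeSingle, nil
	default:
		return "", fmt.Errorf("from and to must both be directories or both be objects")
	}
}
```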
- jc.Error(fmt.Errorf("invalid mode: %v", orr.Mode), http.StatusBadRequest) - return - } -} - -func (b *bus) objectsHandlerDELETE(jc jape.Context) { - var batch bool - if jc.DecodeForm("batch", &batch) != nil { - return - } - bucket := api.DefaultBucketName - if jc.DecodeForm("bucket", &bucket) != nil { - return - } - var err error - if batch { - err = b.ms.RemoveObjects(jc.Request.Context(), bucket, jc.PathParam("path")) - } else { - err = b.ms.RemoveObject(jc.Request.Context(), bucket, jc.PathParam("path")) - } - if errors.Is(err, api.ErrObjectNotFound) { - jc.Error(err, http.StatusNotFound) - return - } - jc.Check("couldn't delete object", err) -} - -func (b *bus) slabbuffersHandlerGET(jc jape.Context) { - buffers, err := b.ms.SlabBuffers(jc.Request.Context()) - if jc.Check("couldn't get slab buffers info", err) != nil { - return - } - jc.Encode(buffers) -} - -func (b *bus) objectsStatshandlerGET(jc jape.Context) { - opts := api.ObjectsStatsOpts{} - if jc.DecodeForm("bucket", &opts.Bucket) != nil { - return - } - info, err := b.ms.ObjectsStats(jc.Request.Context(), opts) - if jc.Check("couldn't get objects stats", err) != nil { - return - } - jc.Encode(info) -} - -func (b *bus) packedSlabsHandlerFetchPOST(jc jape.Context) { - var psrg api.PackedSlabsRequestGET - if jc.Decode(&psrg) != nil { - return - } - if psrg.MinShards == 0 || psrg.TotalShards == 0 { - jc.Error(fmt.Errorf("min_shards and total_shards must be non-zero"), http.StatusBadRequest) - return - } - if psrg.LockingDuration == 0 { - jc.Error(fmt.Errorf("locking_duration must be non-zero"), http.StatusBadRequest) - return - } - if psrg.ContractSet == "" { - jc.Error(fmt.Errorf("contract_set must be non-empty"), http.StatusBadRequest) - return - } - slabs, err := b.ms.PackedSlabsForUpload(jc.Request.Context(), time.Duration(psrg.LockingDuration), psrg.MinShards, psrg.TotalShards, psrg.ContractSet, psrg.Limit) - if jc.Check("couldn't get packed slabs", err) != nil { - return - } - jc.Encode(slabs) -} - -func (b *bus) packedSlabsHandlerDonePOST(jc jape.Context) { - var psrp api.PackedSlabsRequestPOST - if jc.Decode(&psrp) != nil { - return - } - jc.Check("failed to mark packed slab(s) as uploaded", b.ms.MarkPackedSlabsUploaded(jc.Request.Context(), psrp.Slabs)) -} - -func (b *bus) sectorsHostRootHandlerDELETE(jc jape.Context) { - var hk types.PublicKey - var root types.Hash256 - if jc.DecodeParam("hk", &hk) != nil { - return - } else if jc.DecodeParam("root", &root) != nil { - return - } - n, err := b.ms.DeleteHostSector(jc.Request.Context(), hk, root) - if jc.Check("failed to mark sector as lost", err) != nil { - return - } else if n > 0 { - b.logger.Infow("successfully marked sector as lost", "hk", hk, "root", root) - } -} - -func (b *bus) slabObjectsHandlerGET(jc jape.Context) { - var key object.EncryptionKey - if jc.DecodeParam("key", &key) != nil { - return - } - bucket := api.DefaultBucketName - if jc.DecodeForm("bucket", &bucket) != nil { - return - } - objects, err := b.ms.ObjectsBySlabKey(jc.Request.Context(), bucket, key) - if jc.Check("failed to retrieve objects by slab", err) != nil { - return - } - jc.Encode(objects) -} - -func (b *bus) slabHandlerGET(jc jape.Context) { - var key object.EncryptionKey - if jc.DecodeParam("key", &key) != nil { - return - } - slab, err := b.ms.Slab(jc.Request.Context(), key) - if errors.Is(err, api.ErrSlabNotFound) { - jc.Error(err, http.StatusNotFound) - return - } else if err != nil { - jc.Error(err, http.StatusInternalServerError) - return - } - jc.Encode(slab) -} - -func (b *bus) 
slabHandlerPUT(jc jape.Context) { - var usr api.UpdateSlabRequest - if jc.Decode(&usr) == nil { - jc.Check("couldn't update slab", b.ms.UpdateSlab(jc.Request.Context(), usr.Slab, usr.ContractSet)) - } -} - -func (b *bus) slabsRefreshHealthHandlerPOST(jc jape.Context) { - jc.Check("failed to recompute health", b.ms.RefreshHealth(jc.Request.Context())) -} - -func (b *bus) slabsMigrationHandlerPOST(jc jape.Context) { - var msr api.MigrationSlabsRequest - if jc.Decode(&msr) == nil { - if slabs, err := b.ms.UnhealthySlabs(jc.Request.Context(), msr.HealthCutoff, msr.ContractSet, msr.Limit); jc.Check("couldn't fetch slabs for migration", err) == nil { - jc.Encode(api.UnhealthySlabsResponse{ - Slabs: slabs, - }) - } - } -} - -func (b *bus) slabsPartialHandlerGET(jc jape.Context) { - jc.Custom(nil, []byte{}) - - var key object.EncryptionKey - if jc.DecodeParam("key", &key) != nil { - return - } - var offset int - if jc.DecodeForm("offset", &offset) != nil { - return - } - var length int - if jc.DecodeForm("length", &length) != nil { - return - } - if length <= 0 || offset < 0 { - jc.Error(fmt.Errorf("length must be positive and offset must be non-negative"), http.StatusBadRequest) - return - } - data, err := b.ms.FetchPartialSlab(jc.Request.Context(), key, uint32(offset), uint32(length)) - if errors.Is(err, api.ErrObjectNotFound) { - jc.Error(err, http.StatusNotFound) - return - } else if err != nil { - jc.Error(err, http.StatusInternalServerError) - return - } - jc.ResponseWriter.Write(data) -} - -func (b *bus) slabsPartialHandlerPOST(jc jape.Context) { - var minShards int - if jc.DecodeForm("minShards", &minShards) != nil { - return - } - var totalShards int - if jc.DecodeForm("totalShards", &totalShards) != nil { - return - } - var contractSet string - if jc.DecodeForm("contractSet", &contractSet) != nil { - return - } - if minShards <= 0 || totalShards <= minShards { - jc.Error(errors.New("minShards must be positive and totalShards must be greater than minShards"), http.StatusBadRequest) - return - } - if totalShards > math.MaxUint8 { - jc.Error(fmt.Errorf("totalShards must be less than or equal to %d", math.MaxUint8), http.StatusBadRequest) - return - } - if contractSet == "" { - jc.Error(errors.New("parameter 'contractSet' is required"), http.StatusBadRequest) - return - } - data, err := io.ReadAll(jc.Request.Body) - if jc.Check("failed to read request body", err) != nil { - return - } - slabs, bufferSize, err := b.ms.AddPartialSlab(jc.Request.Context(), data, uint8(minShards), uint8(totalShards), contractSet) - if jc.Check("failed to add partial slab", err) != nil { - return - } - var pus api.UploadPackingSettings - if err := b.fetchSetting(jc.Request.Context(), api.SettingUploadPacking, &pus); err != nil && !errors.Is(err, api.ErrSettingNotFound) { - jc.Error(fmt.Errorf("could not get upload packing settings: %w", err), http.StatusInternalServerError) - return - } - jc.Encode(api.AddPartialSlabResponse{ - Slabs: slabs, - SlabBufferMaxSizeSoftReached: bufferSize >= pus.SlabBufferMaxSizeSoft, - }) -} - -func (b *bus) settingsHandlerGET(jc jape.Context) { - if settings, err := b.ss.Settings(jc.Request.Context()); jc.Check("couldn't load settings", err) == nil { - jc.Encode(settings) - } -} - -func (b *bus) settingKeyHandlerGET(jc jape.Context) { - key := jc.PathParam("key") - if key == "" { - jc.Error(errors.New("path parameter 'key' can not be empty"), http.StatusBadRequest) - return - } - - setting, err := b.ss.Setting(jc.Request.Context(), jc.PathParam("key")) - if errors.Is(err, 
api.ErrSettingNotFound) { - jc.Error(err, http.StatusNotFound) - return - } - if err != nil { - jc.Error(err, http.StatusInternalServerError) - return - } - - var resp interface{} - err = json.Unmarshal([]byte(setting), &resp) - if err != nil { - jc.Error(fmt.Errorf("couldn't unmarshal the setting, error: %v", err), http.StatusInternalServerError) - return - } - - jc.Encode(resp) -} - -func (b *bus) settingKeyHandlerPUT(jc jape.Context) { - key := jc.PathParam("key") - if key == "" { - jc.Error(errors.New("path parameter 'key' can not be empty"), http.StatusBadRequest) - return - } - - var value interface{} - if jc.Decode(&value) != nil { - return - } - - data, err := json.Marshal(value) - if err != nil { - jc.Error(fmt.Errorf("couldn't marshal the given value, error: %v", err), http.StatusBadRequest) - return - } - - switch key { - case api.SettingGouging: - var gs api.GougingSettings - if err := json.Unmarshal(data, &gs); err != nil { - jc.Error(fmt.Errorf("couldn't update gouging settings, invalid request body, %t", value), http.StatusBadRequest) - return - } else if err := gs.Validate(); err != nil { - jc.Error(fmt.Errorf("couldn't update gouging settings, error: %v", err), http.StatusBadRequest) - return - } - b.pinMgr.TriggerUpdate() - case api.SettingRedundancy: - var rs api.RedundancySettings - if err := json.Unmarshal(data, &rs); err != nil { - jc.Error(fmt.Errorf("couldn't update redundancy settings, invalid request body"), http.StatusBadRequest) - return - } else if err := rs.Validate(); err != nil { - jc.Error(fmt.Errorf("couldn't update redundancy settings, error: %v", err), http.StatusBadRequest) - return - } - case api.SettingS3Authentication: - var s3as api.S3AuthenticationSettings - if err := json.Unmarshal(data, &s3as); err != nil { - jc.Error(fmt.Errorf("couldn't update s3 authentication settings, invalid request body"), http.StatusBadRequest) - return - } else if err := s3as.Validate(); err != nil { - jc.Error(fmt.Errorf("couldn't update s3 authentication settings, error: %v", err), http.StatusBadRequest) - return - } - case api.SettingPricePinning: - var pps api.PricePinSettings - if err := json.Unmarshal(data, &pps); err != nil { - jc.Error(fmt.Errorf("couldn't update price pinning settings, invalid request body"), http.StatusBadRequest) - return - } else if err := pps.Validate(); err != nil { - jc.Error(fmt.Errorf("couldn't update price pinning settings, invalid settings, error: %v", err), http.StatusBadRequest) - return - } else if pps.Enabled { - if _, err := ibus.NewForexClient(pps.ForexEndpointURL).SiacoinExchangeRate(jc.Request.Context(), pps.Currency); err != nil { - jc.Error(fmt.Errorf("couldn't update price pinning settings, forex API unreachable,error: %v", err), http.StatusBadRequest) - return - } - } - b.pinMgr.TriggerUpdate() - } - - if jc.Check("could not update setting", b.ss.UpdateSetting(jc.Request.Context(), key, string(data))) == nil { - b.broadcastAction(webhooks.Event{ - Module: api.ModuleSetting, - Event: api.EventUpdate, - Payload: api.EventSettingUpdate{ - Key: key, - Update: value, - Timestamp: time.Now().UTC(), - }, - }) - } -} - -func (b *bus) settingKeyHandlerDELETE(jc jape.Context) { - key := jc.PathParam("key") - if key == "" { - jc.Error(errors.New("path parameter 'key' can not be empty"), http.StatusBadRequest) - return - } - - if jc.Check("could not delete setting", b.ss.DeleteSetting(jc.Request.Context(), key)) == nil { - b.broadcastAction(webhooks.Event{ - Module: api.ModuleSetting, - Event: api.EventDelete, - Payload: 
api.EventSettingDelete{ - Key: key, - Timestamp: time.Now().UTC(), - }, - }) - } -} - -func (b *bus) contractIDAncestorsHandler(jc jape.Context) { - var fcid types.FileContractID - if jc.DecodeParam("id", &fcid) != nil { - return - } - var minStartHeight uint64 - if jc.DecodeForm("minStartHeight", &minStartHeight) != nil { - return - } - ancestors, err := b.ms.AncestorContracts(jc.Request.Context(), fcid, uint64(minStartHeight)) - if jc.Check("failed to fetch ancestor contracts", err) != nil { - return - } - jc.Encode(ancestors) -} - -func (b *bus) paramsHandlerUploadGET(jc jape.Context) { - gp, err := b.gougingParams(jc.Request.Context()) - if jc.Check("could not get gouging parameters", err) != nil { - return - } - - var contractSet string - var css api.ContractSetSetting - if err := b.fetchSetting(jc.Request.Context(), api.SettingContractSet, &css); err != nil && !errors.Is(err, api.ErrSettingNotFound) { - jc.Error(fmt.Errorf("could not get contract set settings: %w", err), http.StatusInternalServerError) - return - } else if err == nil { - contractSet = css.Default - } - - var uploadPacking bool - var pus api.UploadPackingSettings - if err := b.fetchSetting(jc.Request.Context(), api.SettingUploadPacking, &pus); err != nil && !errors.Is(err, api.ErrSettingNotFound) { - jc.Error(fmt.Errorf("could not get upload packing settings: %w", err), http.StatusInternalServerError) - return - } else if err == nil { - uploadPacking = pus.Enabled - } - - jc.Encode(api.UploadParams{ - ContractSet: contractSet, - CurrentHeight: b.cm.TipState().Index.Height, - GougingParams: gp, - UploadPacking: uploadPacking, - }) -} - -func (b *bus) consensusState() api.ConsensusState { - return api.ConsensusState{ - BlockHeight: b.cm.TipState().Index.Height, - LastBlockTime: api.TimeRFC3339(b.cm.LastBlockTime()), - Synced: b.cm.Synced(), - } -} - -func (b *bus) paramsHandlerGougingGET(jc jape.Context) { - gp, err := b.gougingParams(jc.Request.Context()) - if jc.Check("could not get gouging parameters", err) != nil { - return - } - jc.Encode(gp) -} - -func (b *bus) gougingParams(ctx context.Context) (api.GougingParams, error) { - var gs api.GougingSettings - if gss, err := b.ss.Setting(ctx, api.SettingGouging); err != nil { - return api.GougingParams{}, err - } else if err := json.Unmarshal([]byte(gss), &gs); err != nil { - b.logger.Panicf("failed to unmarshal gouging settings '%s': %v", gss, err) - } - - var rs api.RedundancySettings - if rss, err := b.ss.Setting(ctx, api.SettingRedundancy); err != nil { - return api.GougingParams{}, err - } else if err := json.Unmarshal([]byte(rss), &rs); err != nil { - b.logger.Panicf("failed to unmarshal redundancy settings '%s': %v", rss, err) - } - - cs := b.consensusState() - - return api.GougingParams{ - ConsensusState: cs, - GougingSettings: gs, - RedundancySettings: rs, - TransactionFee: b.tp.RecommendedFee(), - }, nil -} - -func (b *bus) handleGETAlertsDeprecated(jc jape.Context) { - ar, err := b.alertMgr.Alerts(jc.Request.Context(), alerts.AlertsOpts{Offset: 0, Limit: -1}) - if jc.Check("failed to fetch alerts", err) != nil { - return - } - jc.Encode(ar.Alerts) -} - -func (b *bus) handleGETAlerts(jc jape.Context) { - if jc.Request.FormValue("offset") == "" && jc.Request.FormValue("limit") == "" { - b.handleGETAlertsDeprecated(jc) - return - } - offset, limit := 0, -1 - var severity alerts.Severity - if jc.DecodeForm("offset", &offset) != nil { - return - } else if jc.DecodeForm("limit", &limit) != nil { - return - } else if offset < 0 { - jc.Error(errors.New("offset must 
be non-negative"), http.StatusBadRequest) - return - } else if jc.DecodeForm("severity", &severity) != nil { - return - } - ar, err := b.alertMgr.Alerts(jc.Request.Context(), alerts.AlertsOpts{ - Offset: offset, - Limit: limit, - Severity: severity, - }) - if jc.Check("failed to fetch alerts", err) != nil { - return - } - jc.Encode(ar) -} - -func (b *bus) handlePOSTAlertsDismiss(jc jape.Context) { - var ids []types.Hash256 - if jc.Decode(&ids) != nil { - return - } - jc.Check("failed to dismiss alerts", b.alertMgr.DismissAlerts(jc.Request.Context(), ids...)) -} - -func (b *bus) handlePOSTAlertsRegister(jc jape.Context) { - var alert alerts.Alert - if jc.Decode(&alert) != nil { - return - } - jc.Check("failed to register alert", b.alertMgr.RegisterAlert(jc.Request.Context(), alert)) -} - -func (b *bus) accountsHandlerGET(jc jape.Context) { - jc.Encode(b.accounts.Accounts()) -} - -func (b *bus) accountHandlerGET(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.AccountHandlerPOST - if jc.Decode(&req) != nil { - return - } - acc, err := b.accounts.Account(id, req.HostKey) - if jc.Check("failed to fetch account", err) != nil { - return - } - jc.Encode(acc) -} - -func (b *bus) accountsAddHandlerPOST(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.AccountsAddBalanceRequest - if jc.Decode(&req) != nil { - return - } - if id == (rhpv3.Account{}) { - jc.Error(errors.New("account id needs to be set"), http.StatusBadRequest) - return - } - if req.HostKey == (types.PublicKey{}) { - jc.Error(errors.New("host needs to be set"), http.StatusBadRequest) - return - } - b.accounts.AddAmount(id, req.HostKey, req.Amount) -} - -func (b *bus) accountsResetDriftHandlerPOST(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - err := b.accounts.ResetDrift(id) - if errors.Is(err, errAccountsNotFound) { - jc.Error(err, http.StatusNotFound) - return - } - if jc.Check("failed to reset drift", err) != nil { - return - } -} - -func (b *bus) accountsUpdateHandlerPOST(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.AccountsUpdateBalanceRequest - if jc.Decode(&req) != nil { - return - } - if id == (rhpv3.Account{}) { - jc.Error(errors.New("account id needs to be set"), http.StatusBadRequest) - return - } - if req.HostKey == (types.PublicKey{}) { - jc.Error(errors.New("host needs to be set"), http.StatusBadRequest) - return - } - b.accounts.SetBalance(id, req.HostKey, req.Amount) -} - -func (b *bus) accountsRequiresSyncHandlerPOST(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.AccountsRequiresSyncRequest - if jc.Decode(&req) != nil { - return - } - if id == (rhpv3.Account{}) { - jc.Error(errors.New("account id needs to be set"), http.StatusBadRequest) - return - } - if req.HostKey == (types.PublicKey{}) { - jc.Error(errors.New("host needs to be set"), http.StatusBadRequest) - return - } - err := b.accounts.ScheduleSync(id, req.HostKey) - if errors.Is(err, errAccountsNotFound) { - jc.Error(err, http.StatusNotFound) - return - } - if jc.Check("failed to set requiresSync flag on account", err) != nil { - return - } -} - -func (b *bus) accountsLockHandlerPOST(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.AccountsLockHandlerRequest - if jc.Decode(&req) != nil { - return - } - - acc, 
lockID := b.accounts.LockAccount(jc.Request.Context(), id, req.HostKey, req.Exclusive, time.Duration(req.Duration)) - jc.Encode(api.AccountsLockHandlerResponse{ - Account: acc, - LockID: lockID, - }) -} - -func (b *bus) accountsUnlockHandlerPOST(jc jape.Context) { - var id rhpv3.Account - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.AccountsUnlockHandlerRequest - if jc.Decode(&req) != nil { - return - } - - err := b.accounts.UnlockAccount(id, req.LockID) - if jc.Check("failed to unlock account", err) != nil { - return - } -} - -func (b *bus) autopilotsListHandlerGET(jc jape.Context) { - if autopilots, err := b.as.Autopilots(jc.Request.Context()); jc.Check("failed to fetch autopilots", err) == nil { - jc.Encode(autopilots) - } -} - -func (b *bus) autopilotsHandlerGET(jc jape.Context) { - var id string - if jc.DecodeParam("id", &id) != nil { - return - } - ap, err := b.as.Autopilot(jc.Request.Context(), id) - if errors.Is(err, api.ErrAutopilotNotFound) { - jc.Error(err, http.StatusNotFound) - return - } - if jc.Check("couldn't load object", err) != nil { - return - } - - jc.Encode(ap) -} - -func (b *bus) autopilotsHandlerPUT(jc jape.Context) { - var id string - if jc.DecodeParam("id", &id) != nil { - return - } - - var ap api.Autopilot - if jc.Decode(&ap) != nil { - return - } - - if ap.ID != id { - jc.Error(errors.New("id in path and body don't match"), http.StatusBadRequest) - return - } - - if jc.Check("failed to update autopilot", b.as.UpdateAutopilot(jc.Request.Context(), ap)) == nil { - b.pinMgr.TriggerUpdate() - } -} - -func (b *bus) autopilotHostCheckHandlerPUT(jc jape.Context) { - var id string - if jc.DecodeParam("id", &id) != nil { - return - } - var hk types.PublicKey - if jc.DecodeParam("hostkey", &hk) != nil { - return - } - var hc api.HostCheck - if jc.Check("failed to decode host check", jc.Decode(&hc)) != nil { - return - } - - err := b.hdb.UpdateHostCheck(jc.Request.Context(), id, hk, hc) - if errors.Is(err, api.ErrAutopilotNotFound) { - jc.Error(err, http.StatusNotFound) - return - } else if jc.Check("failed to update host", err) != nil { - return - } -} - -func (b *bus) broadcastAction(e webhooks.Event) { - log := b.logger.With("event", e.Event).With("module", e.Module) - err := b.webhooksMgr.BroadcastAction(context.Background(), e) - if err != nil { - log.With(zap.Error(err)).Error("failed to broadcast action") - } else { - log.Debug("successfully broadcast action") - } -} - -func (b *bus) contractTaxHandlerGET(jc jape.Context) { - var payout types.Currency - if jc.DecodeParam("payout", (*api.ParamCurrency)(&payout)) != nil { - return - } - cs := b.cm.TipState() - jc.Encode(cs.FileContractTax(types.FileContract{Payout: payout})) -} - -func (b *bus) stateHandlerGET(jc jape.Context) { - jc.Encode(api.BusStateResponse{ - StartTime: api.TimeRFC3339(b.startTime), - BuildState: api.BuildState{ - Network: build.NetworkName(), - Version: build.Version(), - Commit: build.Commit(), - OS: runtime.GOOS, - BuildTime: api.TimeRFC3339(build.BuildTime()), - }, - }) -} - -func (b *bus) uploadTrackHandlerPOST(jc jape.Context) { - var id api.UploadID - if jc.DecodeParam("id", &id) == nil { - jc.Check("failed to track upload", b.uploadingSectors.StartUpload(id)) - } -} - -func (b *bus) uploadAddSectorHandlerPOST(jc jape.Context) { - var id api.UploadID - if jc.DecodeParam("id", &id) != nil { - return - } - var req api.UploadSectorRequest - if jc.Decode(&req) != nil { - return - } - jc.Check("failed to add sector", b.uploadingSectors.AddSector(id, req.ContractID, 
req.Root)) -} - -func (b *bus) uploadFinishedHandlerDELETE(jc jape.Context) { - var id api.UploadID - if jc.DecodeParam("id", &id) == nil { - b.uploadingSectors.FinishUpload(id) - } -} - -func (b *bus) webhookActionHandlerPost(jc jape.Context) { - var action webhooks.Event - if jc.Check("failed to decode action", jc.Decode(&action)) != nil { - return - } - b.broadcastAction(action) -} - -func (b *bus) webhookHandlerDelete(jc jape.Context) { - var wh webhooks.Webhook - if jc.Decode(&wh) != nil { - return - } - err := b.webhooksMgr.Delete(jc.Request.Context(), wh) - if errors.Is(err, webhooks.ErrWebhookNotFound) { - jc.Error(fmt.Errorf("webhook for URL %v and event %v.%v not found", wh.URL, wh.Module, wh.Event), http.StatusNotFound) - return - } else if jc.Check("failed to delete webhook", err) != nil { - return - } -} - -func (b *bus) webhookHandlerGet(jc jape.Context) { - webhooks, queueInfos := b.webhooksMgr.Info() - jc.Encode(api.WebhookResponse{ - Queues: queueInfos, - Webhooks: webhooks, - }) -} - -func (b *bus) webhookHandlerPost(jc jape.Context) { - var req webhooks.Webhook - if jc.Decode(&req) != nil { - return - } - - err := b.webhooksMgr.Register(jc.Request.Context(), webhooks.Webhook{ - Event: req.Event, - Module: req.Module, - URL: req.URL, - Headers: req.Headers, - }) - if err != nil { - jc.Error(fmt.Errorf("failed to add Webhook: %w", err), http.StatusInternalServerError) - return - } -} - -func (b *bus) metricsHandlerDELETE(jc jape.Context) { - metric := jc.PathParam("key") - if metric == "" { - jc.Error(errors.New("parameter 'metric' is required"), http.StatusBadRequest) - return - } - - var cutoff time.Time - if jc.DecodeForm("cutoff", (*api.TimeRFC3339)(&cutoff)) != nil { - return - } else if cutoff.IsZero() { - jc.Error(errors.New("parameter 'cutoff' is required"), http.StatusBadRequest) - return - } - - err := b.mtrcs.PruneMetrics(jc.Request.Context(), metric, cutoff) - if jc.Check("failed to prune metrics", err) != nil { - return - } -} - -func (b *bus) metricsHandlerPUT(jc jape.Context) { - jc.Custom((*interface{})(nil), nil) - - key := jc.PathParam("key") - switch key { - case api.MetricContractPrune: - // TODO: jape hack - remove once jape can handle decoding multiple different request types - var req api.ContractPruneMetricRequestPUT - if err := json.NewDecoder(jc.Request.Body).Decode(&req); err != nil { - jc.Error(fmt.Errorf("couldn't decode request type (%T): %w", req, err), http.StatusBadRequest) - return - } else if jc.Check("failed to record contract prune metric", b.mtrcs.RecordContractPruneMetric(jc.Request.Context(), req.Metrics...)) != nil { - return - } - case api.MetricContractSetChurn: - // TODO: jape hack - remove once jape can handle decoding multiple different request types - var req api.ContractSetChurnMetricRequestPUT - if err := json.NewDecoder(jc.Request.Body).Decode(&req); err != nil { - jc.Error(fmt.Errorf("couldn't decode request type (%T): %w", req, err), http.StatusBadRequest) - return - } else if jc.Check("failed to record contract churn metric", b.mtrcs.RecordContractSetChurnMetric(jc.Request.Context(), req.Metrics...)) != nil { - return - } - default: - jc.Error(fmt.Errorf("unknown metric key '%s'", key), http.StatusBadRequest) - return - } -} - -func (b *bus) metricsHandlerGET(jc jape.Context) { - // parse mandatory query parameters - var start time.Time - if jc.DecodeForm("start", (*api.TimeRFC3339)(&start)) != nil { - return - } else if start.IsZero() { - jc.Error(errors.New("parameter 'start' is required"), http.StatusBadRequest) - 
return - } - - var n uint64 - if jc.DecodeForm("n", &n) != nil { - return - } else if n == 0 { - if jc.Request.FormValue("n") == "" { - jc.Error(errors.New("parameter 'n' is required"), http.StatusBadRequest) - } else { - jc.Error(errors.New("'n' has to be greater than zero"), http.StatusBadRequest) - } - return - } - - var interval time.Duration - if jc.DecodeForm("interval", (*api.DurationMS)(&interval)) != nil { - return - } else if interval == 0 { - jc.Error(errors.New("parameter 'interval' is required"), http.StatusBadRequest) - return - } - - // parse optional query parameters - var metrics interface{} - var err error - key := jc.PathParam("key") - switch key { - case api.MetricContract: - var opts api.ContractMetricsQueryOpts - if jc.DecodeForm("contractID", &opts.ContractID) != nil { - return - } else if jc.DecodeForm("hostKey", &opts.HostKey) != nil { - return - } - metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) - case api.MetricContractPrune: - var opts api.ContractPruneMetricsQueryOpts - if jc.DecodeForm("contractID", &opts.ContractID) != nil { - return - } else if jc.DecodeForm("hostKey", &opts.HostKey) != nil { - return - } else if jc.DecodeForm("hostVersion", &opts.HostVersion) != nil { - return - } - metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) - case api.MetricContractSet: - var opts api.ContractSetMetricsQueryOpts - if jc.DecodeForm("name", &opts.Name) != nil { - return - } - metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) - case api.MetricContractSetChurn: - var opts api.ContractSetChurnMetricsQueryOpts - if jc.DecodeForm("name", &opts.Name) != nil { - return - } else if jc.DecodeForm("direction", &opts.Direction) != nil { - return - } else if jc.DecodeForm("reason", &opts.Reason) != nil { - return - } - metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) - case api.MetricWallet: - var opts api.WalletMetricsQueryOpts - metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) - default: - jc.Error(fmt.Errorf("unknown metric '%s'", key), http.StatusBadRequest) - return - } - if errors.Is(err, api.ErrMaxIntervalsExceeded) { - jc.Error(err, http.StatusBadRequest) - return - } else if jc.Check(fmt.Sprintf("failed to fetch '%s' metrics", key), err) != nil { - return - } - jc.Encode(metrics) -} - -func (b *bus) metrics(ctx context.Context, key string, start time.Time, n uint64, interval time.Duration, opts interface{}) (interface{}, error) { - switch key { - case api.MetricContract: - return b.mtrcs.ContractMetrics(ctx, start, n, interval, opts.(api.ContractMetricsQueryOpts)) - case api.MetricContractPrune: - return b.mtrcs.ContractPruneMetrics(ctx, start, n, interval, opts.(api.ContractPruneMetricsQueryOpts)) - case api.MetricContractSet: - return b.mtrcs.ContractSetMetrics(ctx, start, n, interval, opts.(api.ContractSetMetricsQueryOpts)) - case api.MetricContractSetChurn: - return b.mtrcs.ContractSetChurnMetrics(ctx, start, n, interval, opts.(api.ContractSetChurnMetricsQueryOpts)) - case api.MetricWallet: - return b.mtrcs.WalletMetrics(ctx, start, n, interval, opts.(api.WalletMetricsQueryOpts)) - } - return nil, fmt.Errorf("unknown metric '%s'", key) -} - -func (b *bus) multipartHandlerCreatePOST(jc jape.Context) { - var req api.MultipartCreateRequest - if jc.Decode(&req) != nil { - return - } - - var key object.EncryptionKey - if req.GenerateKey { - key = object.GenerateEncryptionKey() - } else if req.Key == nil { - key = object.NoOpKey - 
} else { - key = *req.Key - } - - resp, err := b.ms.CreateMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, key, req.MimeType, req.Metadata) - if jc.Check("failed to create multipart upload", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) multipartHandlerAbortPOST(jc jape.Context) { - var req api.MultipartAbortRequest - if jc.Decode(&req) != nil { - return - } - err := b.ms.AbortMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, req.UploadID) - if jc.Check("failed to abort multipart upload", err) != nil { - return - } -} - -func (b *bus) multipartHandlerCompletePOST(jc jape.Context) { - var req api.MultipartCompleteRequest - if jc.Decode(&req) != nil { - return - } - resp, err := b.ms.CompleteMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, req.UploadID, req.Parts, api.CompleteMultipartOptions{ - Metadata: req.Metadata, - }) - if jc.Check("failed to complete multipart upload", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) multipartHandlerUploadPartPUT(jc jape.Context) { - var req api.MultipartAddPartRequest - if jc.Decode(&req) != nil { - return - } - if req.Bucket == "" { - req.Bucket = api.DefaultBucketName - } else if req.ContractSet == "" { - jc.Error(errors.New("contract_set must be non-empty"), http.StatusBadRequest) - return - } else if req.ETag == "" { - jc.Error(errors.New("etag must be non-empty"), http.StatusBadRequest) - return - } else if req.PartNumber <= 0 || req.PartNumber > gofakes3.MaxUploadPartNumber { - jc.Error(fmt.Errorf("part_number must be between 1 and %d", gofakes3.MaxUploadPartNumber), http.StatusBadRequest) - return - } else if req.UploadID == "" { - jc.Error(errors.New("upload_id must be non-empty"), http.StatusBadRequest) - return - } - err := b.ms.AddMultipartPart(jc.Request.Context(), req.Bucket, req.Path, req.ContractSet, req.ETag, req.UploadID, req.PartNumber, req.Slices) - if jc.Check("failed to upload part", err) != nil { - return - } -} - -func (b *bus) multipartHandlerUploadGET(jc jape.Context) { - resp, err := b.ms.MultipartUpload(jc.Request.Context(), jc.PathParam("id")) - if jc.Check("failed to get multipart upload", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) multipartHandlerListUploadsPOST(jc jape.Context) { - var req api.MultipartListUploadsRequest - if jc.Decode(&req) != nil { - return - } - resp, err := b.ms.MultipartUploads(jc.Request.Context(), req.Bucket, req.Prefix, req.PathMarker, req.UploadIDMarker, req.Limit) - if jc.Check("failed to list multipart uploads", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) multipartHandlerListPartsPOST(jc jape.Context) { - var req api.MultipartListPartsRequest - if jc.Decode(&req) != nil { - return - } - resp, err := b.ms.MultipartUploadParts(jc.Request.Context(), req.Bucket, req.Path, req.UploadID, req.PartNumberMarker, int64(req.Limit)) - if jc.Check("failed to list multipart upload parts", err) != nil { - return - } - jc.Encode(resp) -} - -func (b *bus) ProcessConsensusChange(cc modules.ConsensusChange) { - if cc.Synced { - b.broadcastAction(webhooks.Event{ - Module: api.ModuleConsensus, - Event: api.EventUpdate, - Payload: api.EventConsensusUpdate{ - ConsensusState: b.consensusState(), - TransactionFee: b.tp.RecommendedFee(), - Timestamp: time.Now().UTC(), - }, - }) - } -} - -// New returns a new Bus. 
-func New(s Syncer, am *alerts.Manager, whm *webhooks.Manager, cm ChainManager, tp TransactionPool, w Wallet, hdb HostDB, as AutopilotStore, ms MetadataStore, ss SettingStore, eas EphemeralAccountStore, mtrcs MetricsStore, l *zap.Logger) (*bus, error) { - b := &bus{ - s: s, - cm: cm, - tp: tp, - w: w, - hdb: hdb, - as: as, - ms: ms, - mtrcs: mtrcs, - ss: ss, - eas: eas, - contractLocks: newContractLocks(), - uploadingSectors: newUploadingSectorsCache(), - - alerts: alerts.WithOrigin(am, "bus"), - alertMgr: am, - webhooksMgr: whm, - logger: l.Sugar().Named("bus"), - - startTime: time.Now(), - } - - b.pinMgr = ibus.NewPinManager(whm, as, ss, defaultPinUpdateInterval, defaultPinRateWindow, b.logger.Desugar()) - - // ensure we don't hang indefinitely - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - // load default settings if the setting is not already set - for key, value := range map[string]interface{}{ - api.SettingGouging: build.DefaultGougingSettings, - api.SettingPricePinning: build.DefaultPricePinSettings, - api.SettingRedundancy: build.DefaultRedundancySettings, - api.SettingUploadPacking: build.DefaultUploadPackingSettings, - } { - if _, err := b.ss.Setting(ctx, key); errors.Is(err, api.ErrSettingNotFound) { - if bytes, err := json.Marshal(value); err != nil { - panic("failed to marshal default settings") // should never happen - } else if err := b.ss.UpdateSetting(ctx, key, string(bytes)); err != nil { - return nil, err - } - } - } - - // check redundancy settings for validity - var rs api.RedundancySettings - if rss, err := b.ss.Setting(ctx, api.SettingRedundancy); err != nil { - return nil, err - } else if err := json.Unmarshal([]byte(rss), &rs); err != nil { - return nil, err - } else if err := rs.Validate(); err != nil { - l.Warn(fmt.Sprintf("invalid redundancy setting found '%v', overwriting the redundancy settings with the default settings", rss)) - bytes, _ := json.Marshal(build.DefaultRedundancySettings) - if err := b.ss.UpdateSetting(ctx, api.SettingRedundancy, string(bytes)); err != nil { - return nil, err - } - } - - // check gouging settings for validity - var gs api.GougingSettings - if gss, err := b.ss.Setting(ctx, api.SettingGouging); err != nil { - return nil, err - } else if err := json.Unmarshal([]byte(gss), &gs); err != nil { - return nil, err - } else if err := gs.Validate(); err != nil { - // compat: apply default EA gouging settings - gs.MinMaxEphemeralAccountBalance = build.DefaultGougingSettings.MinMaxEphemeralAccountBalance - gs.MinPriceTableValidity = build.DefaultGougingSettings.MinPriceTableValidity - gs.MinAccountExpiry = build.DefaultGougingSettings.MinAccountExpiry - if err := gs.Validate(); err == nil { - l.Info(fmt.Sprintf("updating gouging settings with default EA settings: %+v", gs)) - bytes, _ := json.Marshal(gs) - if err := b.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)); err != nil { - return nil, err - } - } else { - // compat: apply default host block leeway settings - gs.HostBlockHeightLeeway = build.DefaultGougingSettings.HostBlockHeightLeeway - if err := gs.Validate(); err == nil { - l.Info(fmt.Sprintf("updating gouging settings with default HostBlockHeightLeeway settings: %v", gs)) - bytes, _ := json.Marshal(gs) - if err := b.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)); err != nil { - return nil, err - } - } else { - l.Warn(fmt.Sprintf("invalid gouging setting found '%v', overwriting the gouging settings with the default settings", gss)) - bytes, _ := 
json.Marshal(build.DefaultGougingSettings) - if err := b.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)); err != nil { - return nil, err - } - } - } - } - - // load the accounts into memory, they're saved when the bus is stopped - accounts, err := eas.Accounts(ctx) - if err != nil { - return nil, err - } - b.accounts = newAccounts(accounts, b.logger) - - // mark the shutdown as unclean, this will be overwritten when/if the - // accounts are saved on shutdown - if err := eas.SetUncleanShutdown(ctx); err != nil { - return nil, fmt.Errorf("failed to mark account shutdown as unclean: %w", err) - } - - if err := cm.Subscribe(b, modules.ConsensusChangeRecent, nil); err != nil { - return nil, fmt.Errorf("failed to subscribe to consensus changes: %w", err) - } - return b, nil + return nil } diff --git a/bus/client/client_test.go b/bus/client/client_test.go deleted file mode 100644 index 92a7fdcd2..000000000 --- a/bus/client/client_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package client_test - -import ( - "context" - "net" - "net/http" - "path/filepath" - "strings" - "testing" - "time" - - "go.sia.tech/core/types" - "go.sia.tech/jape" - "go.sia.tech/renterd/api" - "go.sia.tech/renterd/build" - "go.sia.tech/renterd/bus/client" - "go.sia.tech/renterd/config" - "go.sia.tech/renterd/internal/node" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -func TestClient(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - c, serveFn, shutdownFn, err := newTestClient(t.TempDir()) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := shutdownFn(ctx); err != nil { - t.Error(err) - } - }() - go serveFn() - - // assert setting 'foo' is not found - if err := c.Setting(ctx, "foo", nil); err == nil || !strings.Contains(err.Error(), api.ErrSettingNotFound.Error()) { - t.Fatal("unexpected err", err) - } - - // update setting 'foo' - if err := c.UpdateSetting(ctx, "foo", "bar"); err != nil { - t.Fatal(err) - } - - // fetch setting 'foo' and assert it matches - var value string - if err := c.Setting(ctx, "foo", &value); err != nil { - t.Fatal("unexpected err", err) - } else if value != "bar" { - t.Fatal("unexpected result", value) - } - - // fetch redundancy settings and assert they're configured to the default values - if rs, err := c.RedundancySettings(ctx); err != nil { - t.Fatal(err) - } else if rs.MinShards != build.DefaultRedundancySettings.MinShards || rs.TotalShards != build.DefaultRedundancySettings.TotalShards { - t.Fatal("unexpected redundancy settings", rs) - } -} - -func newTestClient(dir string) (*client.Client, func() error, func(context.Context) error, error) { - // create listener - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, nil, nil, err - } - - // create client - client := client.New("http://"+l.Addr().String(), "test") - b, _, cleanup, err := node.NewBus(node.BusConfig{ - Bus: config.Bus{ - AnnouncementMaxAgeHours: 24 * 7 * 52, // 1 year - Bootstrap: false, - GatewayAddr: "127.0.0.1:0", - UsedUTXOExpiry: time.Minute, - SlabBufferCompletionThreshold: 0, - }, - DatabaseLog: config.DatabaseLog{ - SlowThreshold: 100 * time.Millisecond, - }, - Miner: node.NewMiner(client), - Logger: zap.NewNop(), - }, filepath.Join(dir, "bus"), types.GeneratePrivateKey(), zap.New(zapcore.NewNopCore())) - if err != nil { - return nil, nil, nil, err - } - - // create server - server := http.Server{Handler: jape.BasicAuth("test")(b)} - - serveFn := func() error { - err := server.Serve(l) - if err != nil && 
!strings.Contains(err.Error(), "Server closed") {
-			return err
-		}
-		return nil
-	}
-
-	shutdownFn := func(ctx context.Context) error {
-		server.Shutdown(ctx)
-		return cleanup(ctx)
-	}
-	return client, serveFn, shutdownFn, nil
-}
diff --git a/bus/client/settings.go b/bus/client/settings.go
index 0e21ca96a..22714cf8b 100644
--- a/bus/client/settings.go
+++ b/bus/client/settings.go
@@ -24,6 +24,12 @@ func (c *Client) GougingSettings(ctx context.Context) (gs api.GougingSettings, e
 	return
 }
 
+// PricePinningSettings returns the price pinning settings.
+func (c *Client) PricePinningSettings(ctx context.Context) (pps api.PricePinSettings, err error) {
+	err = c.Setting(ctx, api.SettingPricePinning, &pps)
+	return
+}
+
 // RedundancySettings returns the redundancy settings.
 func (c *Client) RedundancySettings(ctx context.Context) (rs api.RedundancySettings, err error) {
 	err = c.Setting(ctx, api.SettingRedundancy, &rs)
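For reference, the new typed getter mirrors the existing `GougingSettings` and `RedundancySettings` helpers. A minimal usage sketch; the bus address and password are placeholders:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"go.sia.tech/renterd/bus/client"
)

func main() {
	// Placeholders: point these at a running bus.
	c := client.New("http://localhost:9980/api/bus", "my-api-password")

	// Fetch the price pinning settings through the new typed helper instead
	// of decoding the raw setting by key.
	pps, err := c.PricePinningSettings(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("price pinning enabled:", pps.Enabled)
}
```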
diff --git a/bus/client/wallet.go b/bus/client/wallet.go
index dd419c4ea..9733ed335 100644
--- a/bus/client/wallet.go
+++ b/bus/client/wallet.go
@@ -10,32 +10,17 @@ import (
 	rhpv3 "go.sia.tech/core/rhp/v3"
 	"go.sia.tech/core/types"
 	"go.sia.tech/renterd/api"
-	"go.sia.tech/renterd/wallet"
 )
 
-// SendSiacoins is a helper method that sends siacoins to the given outputs.
-func (c *Client) SendSiacoins(ctx context.Context, scos []types.SiacoinOutput, useUnconfirmedTxns bool) (err error) {
-	var value types.Currency
-	for _, sco := range scos {
-		value = value.Add(sco.Value)
-	}
-	txn := types.Transaction{
-		SiacoinOutputs: scos,
-	}
-	toSign, parents, err := c.WalletFund(ctx, &txn, value, useUnconfirmedTxns)
-	if err != nil {
-		return err
-	}
-	defer func() {
-		if err != nil {
-			_ = c.WalletDiscard(ctx, txn)
-		}
-	}()
-	err = c.WalletSign(ctx, &txn, toSign, types.CoveredFields{WholeTransaction: true})
-	if err != nil {
-		return err
-	}
-	return c.BroadcastTransaction(ctx, append(parents, txn))
+// SendSiacoins is a helper method that sends siacoins to the given address.
+func (c *Client) SendSiacoins(ctx context.Context, addr types.Address, amt types.Currency, useUnconfirmedTxns bool) (txnID types.TransactionID, err error) {
+	err = c.c.WithContext(ctx).POST("/wallet/send", api.WalletSendRequest{
+		Address:          addr,
+		Amount:           amt,
+		SubtractMinerFee: false,
+		UseUnconfirmed:   useUnconfirmedTxns,
+	}, &txnID)
+	return
 }
 
 // Wallet calls the /wallet endpoint on the bus.
@@ -67,7 +52,7 @@ func (c *Client) WalletFund(ctx context.Context, txn *types.Transaction, amount
 }
 
 // WalletOutputs returns the set of unspent outputs controlled by the wallet.
-func (c *Client) WalletOutputs(ctx context.Context) (resp []wallet.SiacoinElement, err error) {
+func (c *Client) WalletOutputs(ctx context.Context) (resp []api.SiacoinElement, err error) {
 	err = c.c.WithContext(ctx).GET("/wallet/outputs", &resp)
 	return
 }
@@ -138,7 +123,7 @@ func (c *Client) WalletSign(ctx context.Context, txn *types.Transaction, toSign
 }
 
 // WalletTransactions returns all transactions relevant to the wallet.
-func (c *Client) WalletTransactions(ctx context.Context, opts ...api.WalletTransactionsOption) (resp []wallet.Transaction, err error) {
+func (c *Client) WalletTransactions(ctx context.Context, opts ...api.WalletTransactionsOption) (resp []api.Transaction, err error) {
 	c.c.Custom("GET", "/wallet/transactions", nil, &resp)
 
 	values := url.Values{}
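The new `SendSiacoins` signature moves transaction construction to the bus: callers pass an address and an amount and get back the ID of the broadcast transaction, instead of funding and signing a transaction client-side. A rough sketch of the new call; the endpoint, password and recipient are placeholders:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"go.sia.tech/core/types"
	"go.sia.tech/renterd/bus/client"
)

func main() {
	// Placeholders: point these at a running bus.
	c := client.New("http://localhost:9980/api/bus", "my-api-password")

	// Placeholder: fill in the recipient's address before running.
	var recipient types.Address

	// The bus builds, funds, signs and broadcasts the transaction; only the
	// transaction ID comes back to the caller.
	txnID, err := c.SendSiacoins(context.Background(), recipient, types.Siacoins(1), false)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("broadcast transaction:", txnID)
}
```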
diff --git a/bus/client/webhooks.go b/bus/client/webhooks.go
index 769d1cf57..833c7c7e1 100644
--- a/bus/client/webhooks.go
+++ b/bus/client/webhooks.go
@@ -13,16 +13,12 @@ func (c *Client) BroadcastAction(ctx context.Context, action webhooks.Event) err
 	return err
 }
 
-// DeleteWebhook deletes the webhook with the given ID.
-func (c *Client) DeleteWebhook(ctx context.Context, url, module, event string) error {
-	return c.c.POST("/webhook/delete", webhooks.Webhook{
-		URL:    url,
-		Module: module,
-		Event:  event,
-	}, nil)
+// UnregisterWebhook unregisters the given webhook.
+func (c *Client) UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error {
+	return c.c.POST("/webhook/delete", webhook, nil)
 }
 
-// RegisterWebhook registers a new webhook for the given URL.
+// RegisterWebhook registers the given webhook.
 func (c *Client) RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error {
 	err := c.c.WithContext(ctx).POST("/webhooks", webhook, nil)
 	return err
diff --git a/bus/routes.go b/bus/routes.go
new file mode 100644
index 000000000..9feb747e4
--- /dev/null
+++ b/bus/routes.go
@@ -0,0 +1,2317 @@
+package bus
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+	"runtime"
+	"sort"
+	"strings"
+	"time"
+
+	rhpv2 "go.sia.tech/core/rhp/v2"
+	rhpv3 "go.sia.tech/core/rhp/v3"
+
+	ibus "go.sia.tech/renterd/internal/bus"
+
+	"go.sia.tech/core/gateway"
+	"go.sia.tech/core/types"
+	"go.sia.tech/coreutils/wallet"
+	"go.sia.tech/gofakes3"
+	"go.sia.tech/jape"
+	"go.sia.tech/renterd/alerts"
+	"go.sia.tech/renterd/api"
+	"go.sia.tech/renterd/build"
+	"go.sia.tech/renterd/internal/utils"
+	"go.sia.tech/renterd/object"
+	"go.sia.tech/renterd/webhooks"
+	"go.uber.org/zap"
+)
+
+func (b *Bus) fetchSetting(ctx context.Context, key string, value interface{}) error {
+	if val, err := b.ss.Setting(ctx, key); err != nil {
+		return fmt.Errorf("could not get setting '%v': %w", key, err)
+	} else if err := json.Unmarshal([]byte(val), &value); err != nil {
+		b.logger.Panicf("failed to unmarshal %v settings '%s': %v", key, val, err)
+	}
+	return nil
+}
+
+func (b *Bus) consensusAcceptBlock(jc jape.Context) {
+	var block types.Block
+	if jc.Decode(&block) != nil {
+		return
+	}
+
+	if jc.Check("failed to accept block", b.cm.AddBlocks([]types.Block{block})) != nil {
+		return
+	}
+
+	if block.V2 == nil {
+		b.s.BroadcastHeader(gateway.BlockHeader{
+			ParentID:   block.ParentID,
+			Nonce:      block.Nonce,
+			Timestamp:  block.Timestamp,
+			MerkleRoot: block.MerkleRoot(),
+		})
+	} else {
+		b.s.BroadcastV2BlockOutline(gateway.OutlineBlock(block, b.cm.PoolTransactions(), b.cm.V2PoolTransactions()))
+	}
+}
+
+func (b *Bus) syncerAddrHandler(jc jape.Context) {
+	jc.Encode(b.s.Addr())
+}
+
+func (b *Bus) syncerPeersHandler(jc jape.Context) {
+	var peers []string
+	for _, p := range b.s.Peers() {
+		peers = append(peers, p.String())
+	}
+	jc.Encode(peers)
+}
+
+func (b *Bus) syncerConnectHandler(jc jape.Context) {
+	var addr string
+	if jc.Decode(&addr) == nil {
+		_, err := b.s.Connect(jc.Request.Context(), addr)
+		jc.Check("couldn't connect to peer", err)
+	}
+}
+
+func (b *Bus) consensusStateHandler(jc jape.Context) {
+	cs, err := b.consensusState(jc.Request.Context())
+	if jc.Check("couldn't fetch consensus state", err) != nil {
+		return
+	}
+	jc.Encode(cs)
+}
+
+func (b *Bus) consensusNetworkHandler(jc jape.Context) {
+	jc.Encode(api.ConsensusNetwork{
+		Name: b.cm.TipState().Network.Name,
+	})
+}
+
+func (b *Bus) txpoolFeeHandler(jc jape.Context) {
+	jc.Encode(b.cm.RecommendedFee())
+}
+
+func (b *Bus) txpoolTransactionsHandler(jc jape.Context) {
+	jc.Encode(b.cm.PoolTransactions())
+}
+
+func (b *Bus) txpoolBroadcastHandler(jc jape.Context) {
+	var txnSet []types.Transaction
+	if jc.Decode(&txnSet) != nil {
+		return
+	}
+
+	_, err := b.cm.AddPoolTransactions(txnSet)
+	if jc.Check("couldn't broadcast transaction set", err) != nil {
+		return
+	}
+
+	b.s.BroadcastTransactionSet(txnSet)
+}
+
+func (b *Bus) bucketsHandlerGET(jc jape.Context) {
+	resp, err := b.ms.ListBuckets(jc.Request.Context())
+	if jc.Check("couldn't list buckets", err) != nil {
+		return
+	}
+	jc.Encode(resp)
+}
+
+func (b *Bus) bucketsHandlerPOST(jc jape.Context) {
+	var bucket api.BucketCreateRequest
+	if jc.Decode(&bucket) != nil {
+		return
+	} else if bucket.Name == "" {
+		jc.Error(errors.New("no name provided"), http.StatusBadRequest)
+		return
+	} else if jc.Check("failed to create bucket", b.ms.CreateBucket(jc.Request.Context(), bucket.Name, bucket.Policy)) != nil {
+		return
+	}
+}
+
+func (b *Bus) bucketsHandlerPolicyPUT(jc jape.Context) {
+	var req api.BucketUpdatePolicyRequest
+	if jc.Decode(&req) != nil {
+		return
+	} else if bucket := jc.PathParam("name"); bucket == "" {
+		jc.Error(errors.New("no bucket name provided"), http.StatusBadRequest)
+		return
+	} else if jc.Check("failed to update bucket policy", b.ms.UpdateBucketPolicy(jc.Request.Context(), bucket, req.Policy)) != nil {
+		return
+	}
+}
+
+func (b *Bus) bucketHandlerDELETE(jc jape.Context) {
+	var name string
+	if jc.DecodeParam("name", &name) != nil {
+		return
+	} else if name == "" {
+		jc.Error(errors.New("no name provided"), http.StatusBadRequest)
+		return
+	} else if jc.Check("failed to delete bucket", b.ms.DeleteBucket(jc.Request.Context(), name)) != nil {
+		return
+	}
+}
+
+func (b *Bus) bucketHandlerGET(jc jape.Context) {
+	var name string
+	if jc.DecodeParam("name", &name) != nil {
+		return
+	} else if name == "" {
+		jc.Error(errors.New("parameter 'name' is required"), http.StatusBadRequest)
+		return
+	}
+	bucket, err := b.ms.Bucket(jc.Request.Context(), name)
+	if errors.Is(err, api.ErrBucketNotFound) {
+		jc.Error(err, http.StatusNotFound)
+		return
+	} else if jc.Check("failed to fetch bucket", err) != nil {
+		return
+	}
+	jc.Encode(bucket)
+}
+
+func (b *Bus) walletHandler(jc jape.Context) {
+	address := b.w.Address()
+	balance, err := b.w.Balance()
+	if jc.Check("couldn't fetch wallet balance", err) != nil {
+		return
+	}
+
+	tip, err := b.w.Tip()
+	if jc.Check("couldn't fetch wallet scan height", err) != nil {
+		return
+	}
+
+	jc.Encode(api.WalletResponse{
+		ScanHeight:  tip.Height,
+		Address:     address,
+		Confirmed:   balance.Confirmed,
+		Spendable:   balance.Spendable,
+		Unconfirmed: balance.Unconfirmed,
+		Immature:    balance.Immature,
+	})
+}
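The handler above assembles an `api.WalletResponse` from the wallet's balance and scan tip. Assuming the bus client's corresponding `Wallet` helper (it calls the `/wallet` endpoint, per the client diff above) decodes into that same type, reading it back looks roughly like:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"go.sia.tech/renterd/bus/client"
)

func main() {
	// Placeholders: point these at a running bus.
	c := client.New("http://localhost:9980/api/bus", "my-api-password")

	// Assumption: c.Wallet returns the api.WalletResponse assembled by
	// walletHandler above.
	w, err := c.Wallet(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("address:    ", w.Address)
	fmt.Println("scan height:", w.ScanHeight)
	fmt.Println("confirmed:  ", w.Confirmed)
	fmt.Println("unconfirmed:", w.Unconfirmed)
	fmt.Println("spendable:  ", w.Spendable)
}
```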
+func (b *Bus) walletTransactionsHandler(jc jape.Context) {
+	offset := 0
+	limit := -1
+	if jc.DecodeForm("offset", &offset) != nil ||
+		jc.DecodeForm("limit", &limit) != nil {
+		return
+	}
+
+	// TODO: deprecate these parameters when moving to v2.0.0
+	var before, since time.Time
+	if jc.DecodeForm("before", (*api.TimeRFC3339)(&before)) != nil ||
+		jc.DecodeForm("since", (*api.TimeRFC3339)(&since)) != nil {
+		return
+	}
+
+	// convertToTransaction converts wallet event data to a Transaction.
+	convertToTransaction := func(kind string, data wallet.EventData) (txn types.Transaction, ok bool) {
+		ok = true
+		switch kind {
+		case wallet.EventTypeMinerPayout,
+			wallet.EventTypeFoundationSubsidy,
+			wallet.EventTypeSiafundClaim:
+			payout, _ := data.(wallet.EventPayout)
+			txn = types.Transaction{SiacoinOutputs: []types.SiacoinOutput{payout.SiacoinElement.SiacoinOutput}}
+		case wallet.EventTypeV1Transaction:
+			v1Txn, _ := data.(wallet.EventV1Transaction)
+			txn = types.Transaction(v1Txn.Transaction)
+		case wallet.EventTypeV1ContractResolution:
+			fce, _ := data.(wallet.EventV1ContractResolution)
+			txn = types.Transaction{
+				FileContracts:  []types.FileContract{fce.Parent.FileContract},
+				SiacoinOutputs: []types.SiacoinOutput{fce.SiacoinElement.SiacoinOutput},
+			}
+		default:
+			ok = false
+		}
+		return
+	}
+
+	// convertToTransactions converts wallet events to API transactions.
+	convertToTransactions := func(events []wallet.Event) []api.Transaction {
+		var transactions []api.Transaction
+		for _, e := range events {
+			if txn, ok := convertToTransaction(e.Type, e.Data); ok {
+				transactions = append(transactions, api.Transaction{
+					Raw:       txn,
+					Index:     e.Index,
+					ID:        types.TransactionID(e.ID),
+					Inflow:    e.SiacoinInflow(),
+					Outflow:   e.SiacoinOutflow(),
+					Timestamp: e.Timestamp,
+				})
+			}
+		}
+		return transactions
+	}
+
+	if before.IsZero() && since.IsZero() {
+		events, err := b.w.Events(offset, limit)
+		if jc.Check("couldn't load transactions", err) == nil {
+			jc.Encode(convertToTransactions(events))
+		}
+		return
+	}
+
+	// TODO: remove this when 'before' and 'since' are deprecated, until then we
+	// fetch all transactions and paginate manually if either is specified
+	events, err := b.w.Events(0, -1)
+	if jc.Check("couldn't load transactions", err) != nil {
+		return
+	}
+	filtered := events[:0]
+	for _, txn := range events {
+		if (before.IsZero() || txn.Timestamp.Before(before)) &&
+			(since.IsZero() || txn.Timestamp.After(since)) {
+			filtered = append(filtered, txn)
+		}
+	}
+	events = filtered
+	// clamp the requested range so out-of-range offsets and limits don't
+	// slice past the end of the filtered events
+	if offset > len(events) {
+		offset = len(events)
+	}
+	end := len(events)
+	if limit > 0 && offset+limit < end {
+		end = offset + limit
+	}
+	jc.Encode(convertToTransactions(events[offset:end]))
+}
+
+func (b *Bus) walletOutputsHandler(jc jape.Context) {
+	utxos, err := b.w.SpendableOutputs()
+	if jc.Check("couldn't load outputs", err) == nil {
+		// convert to siacoin elements
+		elements := make([]api.SiacoinElement, len(utxos))
+		for i, sce := range utxos {
+			elements[i] = api.SiacoinElement{
+				ID: sce.StateElement.ID,
+				SiacoinOutput: types.SiacoinOutput{
+					Value:   sce.SiacoinOutput.Value,
+					Address: sce.SiacoinOutput.Address,
+				},
+				MaturityHeight: sce.MaturityHeight,
+			}
+		}
+		jc.Encode(elements)
+	}
+}
+
+func (b *Bus) walletFundHandler(jc jape.Context) {
+	var wfr api.WalletFundRequest
+	if jc.Decode(&wfr) != nil {
+		return
+	}
+	txn := wfr.Transaction
+
+	if len(txn.MinerFees) == 0 {
+		// if no fees are specified, we add some
+		fee := b.cm.RecommendedFee().Mul64(b.cm.TipState().TransactionWeight(txn))
+		txn.MinerFees = []types.Currency{fee}
+	}
+
+	toSign, err := b.w.FundTransaction(&txn, wfr.Amount.Add(txn.MinerFees[0]), wfr.UseUnconfirmedTxns)
+	if jc.Check("couldn't fund transaction", err) != nil {
+		return
+	}
+
+	jc.Encode(api.WalletFundResponse{
+		Transaction: txn,
+		ToSign:      toSign,
+		DependsOn:   b.cm.UnconfirmedParents(txn),
+	})
+}
+
+func (b *Bus) walletSendSiacoinsHandler(jc jape.Context) {
+	var req api.WalletSendRequest
+	if jc.Decode(&req) != nil {
+
return + } else if req.Address == types.VoidAddress { + jc.Error(errors.New("cannot send to void address"), http.StatusBadRequest) + return + } + + // estimate miner fee + feePerByte := b.cm.RecommendedFee() + minerFee := feePerByte.Mul64(stdTxnSize) + if req.SubtractMinerFee { + var underflow bool + req.Amount, underflow = req.Amount.SubWithUnderflow(minerFee) + if underflow { + jc.Error(fmt.Errorf("amount must be greater than miner fee: %s", minerFee), http.StatusBadRequest) + return + } + } + + state := b.cm.TipState() + // if the current height is below the v2 hardfork height, send a v1 + // transaction + if state.Index.Height < state.Network.HardforkV2.AllowHeight { + // build transaction + txn := types.Transaction{ + MinerFees: []types.Currency{minerFee}, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: req.Address, Value: req.Amount}, + }, + } + toSign, err := b.w.FundTransaction(&txn, req.Amount.Add(minerFee), req.UseUnconfirmed) + if jc.Check("failed to fund transaction", err) != nil { + return + } + b.w.SignTransaction(&txn, toSign, types.CoveredFields{WholeTransaction: true}) + // shouldn't be necessary to get parents since the transaction is + // not using unconfirmed outputs, but good practice + txnset := append(b.cm.UnconfirmedParents(txn), txn) + // verify the transaction and add it to the transaction pool + if _, err := b.cm.AddPoolTransactions(txnset); jc.Check("failed to add transaction set", err) != nil { + b.w.ReleaseInputs([]types.Transaction{txn}, nil) + return + } + // broadcast the transaction + b.s.BroadcastTransactionSet(txnset) + jc.Encode(txn.ID()) + } else { + txn := types.V2Transaction{ + MinerFee: minerFee, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: req.Address, Value: req.Amount}, + }, + } + // fund and sign transaction + state, toSign, err := b.w.FundV2Transaction(&txn, req.Amount.Add(minerFee), req.UseUnconfirmed) + if jc.Check("failed to fund transaction", err) != nil { + return + } + b.w.SignV2Inputs(state, &txn, toSign) + txnset := append(b.cm.V2UnconfirmedParents(txn), txn) + // verify the transaction and add it to the transaction pool + if _, err := b.cm.AddV2PoolTransactions(state.Index, txnset); jc.Check("failed to add v2 transaction set", err) != nil { + b.w.ReleaseInputs(nil, []types.V2Transaction{txn}) + return + } + // broadcast the transaction + b.s.BroadcastV2TransactionSet(state.Index, txnset) + jc.Encode(txn.ID()) + } +} + +func (b *Bus) walletSignHandler(jc jape.Context) { + var wsr api.WalletSignRequest + if jc.Decode(&wsr) != nil { + return + } + b.w.SignTransaction(&wsr.Transaction, wsr.ToSign, wsr.CoveredFields) + jc.Encode(wsr.Transaction) +} + +func (b *Bus) walletRedistributeHandler(jc jape.Context) { + var wfr api.WalletRedistributeRequest + if jc.Decode(&wfr) != nil { + return + } + if wfr.Outputs == 0 { + jc.Error(errors.New("'outputs' has to be greater than zero"), http.StatusBadRequest) + return + } + + var ids []types.TransactionID + if state := b.cm.TipState(); state.Index.Height < state.Network.HardforkV2.AllowHeight { + // v1 redistribution + txns, toSign, err := b.w.Redistribute(wfr.Outputs, wfr.Amount, b.cm.RecommendedFee()) + if jc.Check("couldn't redistribute money in the wallet into the desired outputs", err) != nil { + return + } + + if len(txns) == 0 { + jc.Encode(ids) + return + } + + for i := 0; i < len(txns); i++ { + b.w.SignTransaction(&txns[i], toSign, types.CoveredFields{WholeTransaction: true}) + ids = append(ids, txns[i].ID()) + } + + _, err = b.cm.AddPoolTransactions(txns) + if 
jc.Check("couldn't broadcast the transaction", err) != nil { + b.w.ReleaseInputs(txns, nil) + return + } + } else { + // v2 redistribution + txns, toSign, err := b.w.RedistributeV2(wfr.Outputs, wfr.Amount, b.cm.RecommendedFee()) + if jc.Check("couldn't redistribute money in the wallet into the desired outputs", err) != nil { + return + } + + if len(txns) == 0 { + jc.Encode(ids) + return + } + + for i := 0; i < len(txns); i++ { + b.w.SignV2Inputs(state, &txns[i], toSign[i]) + ids = append(ids, txns[i].ID()) + } + + _, err = b.cm.AddV2PoolTransactions(state.Index, txns) + if jc.Check("couldn't broadcast the transaction", err) != nil { + b.w.ReleaseInputs(nil, txns) + return + } + } + + jc.Encode(ids) +} + +func (b *Bus) walletDiscardHandler(jc jape.Context) { + var txn types.Transaction + if jc.Decode(&txn) == nil { + b.w.ReleaseInputs([]types.Transaction{txn}, nil) + } +} + +func (b *Bus) walletPrepareFormHandler(jc jape.Context) { + var wpfr api.WalletPrepareFormRequest + if jc.Decode(&wpfr) != nil { + return + } + if wpfr.HostKey == (types.PublicKey{}) { + jc.Error(errors.New("no host key provided"), http.StatusBadRequest) + return + } + if wpfr.RenterKey == (types.PublicKey{}) { + jc.Error(errors.New("no renter key provided"), http.StatusBadRequest) + return + } + cs := b.cm.TipState() + + fc := rhpv2.PrepareContractFormation(wpfr.RenterKey, wpfr.HostKey, wpfr.RenterFunds, wpfr.HostCollateral, wpfr.EndHeight, wpfr.HostSettings, wpfr.RenterAddress) + cost := rhpv2.ContractFormationCost(cs, fc, wpfr.HostSettings.ContractPrice) + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, + } + txn.MinerFees = []types.Currency{b.cm.RecommendedFee().Mul64(cs.TransactionWeight(txn))} + toSign, err := b.w.FundTransaction(&txn, cost.Add(txn.MinerFees[0]), true) + if jc.Check("couldn't fund transaction", err) != nil { + return + } + + b.w.SignTransaction(&txn, toSign, wallet.ExplicitCoveredFields(txn)) + + jc.Encode(append(b.cm.UnconfirmedParents(txn), txn)) +} + +func (b *Bus) walletPrepareRenewHandler(jc jape.Context) { + var wprr api.WalletPrepareRenewRequest + if jc.Decode(&wprr) != nil { + return + } + if wprr.RenterKey == nil { + jc.Error(errors.New("no renter key provided"), http.StatusBadRequest) + return + } + cs := b.cm.TipState() + + // Create the final revision from the provided revision. + finalRevision := wprr.Revision + finalRevision.MissedProofOutputs = finalRevision.ValidProofOutputs + finalRevision.Filesize = 0 + finalRevision.FileMerkleRoot = types.Hash256{} + finalRevision.RevisionNumber = math.MaxUint64 + + // Prepare the new contract. + fc, basePrice, err := rhpv3.PrepareContractRenewal(wprr.Revision, wprr.HostAddress, wprr.RenterAddress, wprr.RenterFunds, wprr.MinNewCollateral, wprr.PriceTable, wprr.ExpectedNewStorage, wprr.EndHeight) + if jc.Check("couldn't prepare contract renewal", err) != nil { + return + } + + // Create the transaction containing both the final revision and new + // contract. + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, + FileContractRevisions: []types.FileContractRevision{finalRevision}, + MinerFees: []types.Currency{wprr.PriceTable.TxnFeeMaxRecommended.Mul64(4096)}, + } + + // Compute how much renter funds to put into the new contract. + cost := rhpv3.ContractRenewalCost(cs, wprr.PriceTable, fc, txn.MinerFees[0], basePrice) + + // Make sure we don't exceed the max fund amount. 
+ // TODO: remove the IsZero check for the v2 change + if /*!wprr.MaxFundAmount.IsZero() &&*/ wprr.MaxFundAmount.Cmp(cost) < 0 { + jc.Error(fmt.Errorf("%w: %v > %v", api.ErrMaxFundAmountExceeded, cost, wprr.MaxFundAmount), http.StatusBadRequest) + return + } + + // Fund the txn. We are not signing it yet since it's not complete. The host + // still needs to complete it and the revision + contract are signed with + // the renter key by the worker. + toSign, err := b.w.FundTransaction(&txn, cost, true) + if jc.Check("couldn't fund transaction", err) != nil { + return + } + + jc.Encode(api.WalletPrepareRenewResponse{ + FundAmount: cost, + ToSign: toSign, + TransactionSet: append(b.cm.UnconfirmedParents(txn), txn), + }) +} + +func (b *Bus) walletPendingHandler(jc jape.Context) { + isRelevant := func(txn types.Transaction) bool { + addr := b.w.Address() + for _, sci := range txn.SiacoinInputs { + if sci.UnlockConditions.UnlockHash() == addr { + return true + } + } + for _, sco := range txn.SiacoinOutputs { + if sco.Address == addr { + return true + } + } + return false + } + + txns := b.cm.PoolTransactions() + relevant := txns[:0] + for _, txn := range txns { + if isRelevant(txn) { + relevant = append(relevant, txn) + } + } + jc.Encode(relevant) +} + +func (b *Bus) hostsHandlerGETDeprecated(jc jape.Context) { + offset := 0 + limit := -1 + if jc.DecodeForm("offset", &offset) != nil || jc.DecodeForm("limit", &limit) != nil { + return + } + + // fetch hosts + hosts, err := b.hs.SearchHosts(jc.Request.Context(), "", api.HostFilterModeAllowed, api.UsabilityFilterModeAll, "", nil, offset, limit) + if jc.Check(fmt.Sprintf("couldn't fetch hosts %d-%d", offset, offset+limit), err) != nil { + return + } + jc.Encode(hosts) +} + +func (b *Bus) searchHostsHandlerPOST(jc jape.Context) { + var req api.SearchHostsRequest + if jc.Decode(&req) != nil { + return + } + + // TODO: on the next major release: + // - properly default search params (currently no defaults are set) + // - properly validate and return 400 (currently validation is done in autopilot and the store) + + hosts, err := b.hs.SearchHosts(jc.Request.Context(), req.AutopilotID, req.FilterMode, req.UsabilityMode, req.AddressContains, req.KeyIn, req.Offset, req.Limit) + if jc.Check(fmt.Sprintf("couldn't fetch hosts %d-%d", req.Offset, req.Offset+req.Limit), err) != nil { + return + } + jc.Encode(hosts) +} + +func (b *Bus) hostsRemoveHandlerPOST(jc jape.Context) { + var hrr api.HostsRemoveRequest + if jc.Decode(&hrr) != nil { + return + } + if hrr.MaxDowntimeHours == 0 { + jc.Error(errors.New("maxDowntime must be non-zero"), http.StatusBadRequest) + return + } + if hrr.MinRecentScanFailures == 0 { + jc.Error(errors.New("minRecentScanFailures must be non-zero"), http.StatusBadRequest) + return + } + removed, err := b.hs.RemoveOfflineHosts(jc.Request.Context(), hrr.MinRecentScanFailures, time.Duration(hrr.MaxDowntimeHours)) + if jc.Check("couldn't remove offline hosts", err) != nil { + return + } + jc.Encode(removed) +} + +func (b *Bus) hostsScanningHandlerGET(jc jape.Context) { + offset := 0 + limit := -1 + maxLastScan := time.Now() + if jc.DecodeForm("offset", &offset) != nil || jc.DecodeForm("limit", &limit) != nil || jc.DecodeForm("lastScan", (*api.TimeRFC3339)(&maxLastScan)) != nil { + return + } + hosts, err := b.hs.HostsForScanning(jc.Request.Context(), maxLastScan, offset, limit) + if jc.Check(fmt.Sprintf("couldn't fetch hosts %d-%d", offset, offset+limit), err) != nil { + return + } + jc.Encode(hosts) +} + +func (b *Bus) 
hostsPubkeyHandlerGET(jc jape.Context) { + var hostKey types.PublicKey + if jc.DecodeParam("hostkey", &hostKey) != nil { + return + } + host, err := b.hs.Host(jc.Request.Context(), hostKey) + if jc.Check("couldn't load host", err) == nil { + jc.Encode(host) + } +} + +func (b *Bus) hostsResetLostSectorsPOST(jc jape.Context) { + var hostKey types.PublicKey + if jc.DecodeParam("hostkey", &hostKey) != nil { + return + } + err := b.hs.ResetLostSectors(jc.Request.Context(), hostKey) + if jc.Check("couldn't reset lost sectors", err) != nil { + return + } +} + +func (b *Bus) hostsScanHandlerPOST(jc jape.Context) { + var req api.HostsScanRequest + if jc.Decode(&req) != nil { + return + } + if jc.Check("failed to record scans", b.hs.RecordHostScans(jc.Request.Context(), req.Scans)) != nil { + return + } +} + +func (b *Bus) hostsPricetableHandlerPOST(jc jape.Context) { + var req api.HostsPriceTablesRequest + if jc.Decode(&req) != nil { + return + } + if jc.Check("failed to record interactions", b.hs.RecordPriceTables(jc.Request.Context(), req.PriceTableUpdates)) != nil { + return + } +} + +func (b *Bus) contractsSpendingHandlerPOST(jc jape.Context) { + var records []api.ContractSpendingRecord + if jc.Decode(&records) != nil { + return + } + if jc.Check("failed to record spending metrics for contract", b.ms.RecordContractSpending(jc.Request.Context(), records)) != nil { + return + } +} + +func (b *Bus) hostsAllowlistHandlerGET(jc jape.Context) { + allowlist, err := b.hs.HostAllowlist(jc.Request.Context()) + if jc.Check("couldn't load allowlist", err) == nil { + jc.Encode(allowlist) + } +} + +func (b *Bus) hostsAllowlistHandlerPUT(jc jape.Context) { + ctx := jc.Request.Context() + var req api.UpdateAllowlistRequest + if jc.Decode(&req) == nil { + if len(req.Add)+len(req.Remove) > 0 && req.Clear { + jc.Error(errors.New("cannot add or remove entries while clearing the allowlist"), http.StatusBadRequest) + return + } else if jc.Check("couldn't update allowlist entries", b.hs.UpdateHostAllowlistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil { + return + } + } +} + +func (b *Bus) hostsBlocklistHandlerGET(jc jape.Context) { + blocklist, err := b.hs.HostBlocklist(jc.Request.Context()) + if jc.Check("couldn't load blocklist", err) == nil { + jc.Encode(blocklist) + } +} + +func (b *Bus) hostsBlocklistHandlerPUT(jc jape.Context) { + ctx := jc.Request.Context() + var req api.UpdateBlocklistRequest + if jc.Decode(&req) == nil { + if len(req.Add)+len(req.Remove) > 0 && req.Clear { + jc.Error(errors.New("cannot add or remove entries while clearing the blocklist"), http.StatusBadRequest) + return + } else if jc.Check("couldn't update blocklist entries", b.hs.UpdateHostBlocklistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil { + return + } + } +} + +func (b *Bus) contractsHandlerGET(jc jape.Context) { + var cs string + if jc.DecodeForm("contractset", &cs) != nil { + return + } + contracts, err := b.ms.Contracts(jc.Request.Context(), api.ContractsOpts{ + ContractSet: cs, + }) + if jc.Check("couldn't load contracts", err) == nil { + jc.Encode(contracts) + } +} + +func (b *Bus) contractsRenewedIDHandlerGET(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + + md, err := b.ms.RenewedContract(jc.Request.Context(), id) + if jc.Check("failed to fetch renewed contract", err) == nil { + jc.Encode(md) + } +} + +func (b *Bus) contractsArchiveHandlerPOST(jc jape.Context) { + var toArchive api.ContractsArchiveRequest + if jc.Decode(&toArchive) != nil { + return + } + + 
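// archive the contracts in a single call and, on success, emit one webhook + // event per archived contract so subscribers see each removal and its reason + 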
if jc.Check("failed to archive contracts", b.ms.ArchiveContracts(jc.Request.Context(), toArchive)) == nil { + for fcid, reason := range toArchive { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleContract, + Event: api.EventArchive, + Payload: api.EventContractArchive{ + ContractID: fcid, + Reason: reason, + Timestamp: time.Now().UTC(), + }, + }) + } + } +} + +func (b *Bus) contractsSetsHandlerGET(jc jape.Context) { + sets, err := b.ms.ContractSets(jc.Request.Context()) + if jc.Check("couldn't fetch contract sets", err) == nil { + jc.Encode(sets) + } +} + +func (b *Bus) contractsSetHandlerPUT(jc jape.Context) { + var contractIds []types.FileContractID + if set := jc.PathParam("set"); set == "" { + jc.Error(errors.New("path parameter 'set' can not be empty"), http.StatusBadRequest) + return + } else if jc.Decode(&contractIds) != nil { + return + } else if jc.Check("could not add contracts to set", b.ms.SetContractSet(jc.Request.Context(), set, contractIds)) != nil { + return + } else { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleContractSet, + Event: api.EventUpdate, + Payload: api.EventContractSetUpdate{ + Name: set, + ContractIDs: contractIds, + Timestamp: time.Now().UTC(), + }, + }) + } +} + +func (b *Bus) contractsSetHandlerDELETE(jc jape.Context) { + if set := jc.PathParam("set"); set != "" { + jc.Check("could not remove contract set", b.ms.RemoveContractSet(jc.Request.Context(), set)) + } +} + +func (b *Bus) contractAcquireHandlerPOST(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.ContractAcquireRequest + if jc.Decode(&req) != nil { + return + } + + lockID, err := b.contractLocker.Acquire(jc.Request.Context(), req.Priority, id, time.Duration(req.Duration)) + if jc.Check("failed to acquire contract", err) != nil { + return + } + jc.Encode(api.ContractAcquireResponse{ + LockID: lockID, + }) +} + +func (b *Bus) contractKeepaliveHandlerPOST(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.ContractKeepaliveRequest + if jc.Decode(&req) != nil { + return + } + + err := b.contractLocker.KeepAlive(id, req.LockID, time.Duration(req.Duration)) + if jc.Check("failed to extend lock duration", err) != nil { + return + } +} + +func (b *Bus) contractsPrunableDataHandlerGET(jc jape.Context) { + sizes, err := b.ms.ContractSizes(jc.Request.Context()) + if jc.Check("failed to fetch contract sizes", err) != nil { + return + } + + // prepare the response + var contracts []api.ContractPrunableData + var totalPrunable, totalSize uint64 + + // build the response + for fcid, size := range sizes { + // adjust the amount of prunable data with the pending uploads, due to + // how we record contract spending a contract's size might already + // include pending sectors + pending := b.sectors.Pending(fcid) + if pending > size.Prunable { + size.Prunable = 0 + } else { + size.Prunable -= pending + } + + contracts = append(contracts, api.ContractPrunableData{ + ID: fcid, + ContractSize: size, + }) + totalPrunable += size.Prunable + totalSize += size.Size + } + + // sort contracts by the amount of prunable data + sort.Slice(contracts, func(i, j int) bool { + if contracts[i].Prunable == contracts[j].Prunable { + return contracts[i].Size > contracts[j].Size + } + return contracts[i].Prunable > contracts[j].Prunable + }) + + jc.Encode(api.ContractsPrunableDataResponse{ + Contracts: contracts, + TotalPrunable: totalPrunable, + TotalSize: totalSize, + }) +} + +func (b 
*Bus) contractSizeHandlerGET(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + + size, err := b.ms.ContractSize(jc.Request.Context(), id) + if errors.Is(err, api.ErrContractNotFound) { + jc.Error(err, http.StatusNotFound) + return + } else if jc.Check("failed to fetch contract size", err) != nil { + return + } + + // adjust the amount of prunable data with the pending uploads, due to how + // we record contract spending a contract's size might already include + // pending sectors + pending := b.sectors.Pending(id) + if pending > size.Prunable { + size.Prunable = 0 + } else { + size.Prunable -= pending + } + + jc.Encode(size) +} + +func (b *Bus) contractReleaseHandlerPOST(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.ContractReleaseRequest + if jc.Decode(&req) != nil { + return + } + if jc.Check("failed to release contract", b.contractLocker.Release(id, req.LockID)) != nil { + return + } +} + +func (b *Bus) contractIDHandlerGET(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + c, err := b.ms.Contract(jc.Request.Context(), id) + if jc.Check("couldn't load contract", err) == nil { + jc.Encode(c) + } +} + +func (b *Bus) contractIDHandlerPOST(jc jape.Context) { + var id types.FileContractID + var req api.ContractAddRequest + if jc.DecodeParam("id", &id) != nil || jc.Decode(&req) != nil { + return + } else if req.Contract.ID() != id { + http.Error(jc.ResponseWriter, "contract ID mismatch", http.StatusBadRequest) + return + } else if req.TotalCost.IsZero() { + http.Error(jc.ResponseWriter, "TotalCost can not be zero", http.StatusBadRequest) + return + } + + a, err := b.ms.AddContract(jc.Request.Context(), req.Contract, req.ContractPrice, req.TotalCost, req.StartHeight, req.State) + if jc.Check("couldn't store contract", err) != nil { + return + } + + b.broadcastAction(webhooks.Event{ + Module: api.ModuleContract, + Event: api.EventAdd, + Payload: api.EventContractAdd{ + Added: a, + Timestamp: time.Now().UTC(), + }, + }) + + jc.Encode(a) +} + +func (b *Bus) contractIDRenewedHandlerPOST(jc jape.Context) { + var id types.FileContractID + var req api.ContractRenewedRequest + if jc.DecodeParam("id", &id) != nil || jc.Decode(&req) != nil { + return + } + if req.Contract.ID() != id { + http.Error(jc.ResponseWriter, "contract ID mismatch", http.StatusBadRequest) + return + } + if req.TotalCost.IsZero() { + http.Error(jc.ResponseWriter, "TotalCost can not be zero", http.StatusBadRequest) + return + } + if req.State == "" { + req.State = api.ContractStatePending + } + r, err := b.ms.AddRenewedContract(jc.Request.Context(), req.Contract, req.ContractPrice, req.TotalCost, req.StartHeight, req.RenewedFrom, req.State) + if jc.Check("couldn't store contract", err) != nil { + return + } + + b.sectors.HandleRenewal(req.Contract.ID(), req.RenewedFrom) + b.broadcastAction(webhooks.Event{ + Module: api.ModuleContract, + Event: api.EventRenew, + Payload: api.EventContractRenew{ + Renewal: r, + Timestamp: time.Now().UTC(), + }, + }) + + jc.Encode(r) +} + +func (b *Bus) contractIDRootsHandlerGET(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + + roots, err := b.ms.ContractRoots(jc.Request.Context(), id) + if jc.Check("couldn't fetch contract sectors", err) == nil { + jc.Encode(api.ContractRootsResponse{ + Roots: roots, + Uploading: b.sectors.Sectors(id), + }) + } +} + +func (b *Bus) 
contractIDHandlerDELETE(jc jape.Context) { + var id types.FileContractID + if jc.DecodeParam("id", &id) != nil { + return + } + jc.Check("couldn't remove contract", b.ms.ArchiveContract(jc.Request.Context(), id, api.ContractArchivalReasonRemoved)) +} + +func (b *Bus) contractsAllHandlerDELETE(jc jape.Context) { + jc.Check("couldn't remove contracts", b.ms.ArchiveAllContracts(jc.Request.Context(), api.ContractArchivalReasonRemoved)) +} + +func (b *Bus) searchObjectsHandlerGET(jc jape.Context) { + offset := 0 + limit := -1 + var key string + if jc.DecodeForm("offset", &offset) != nil || jc.DecodeForm("limit", &limit) != nil || jc.DecodeForm("key", &key) != nil { + return + } + bucket := api.DefaultBucketName + if jc.DecodeForm("bucket", &bucket) != nil { + return + } + keys, err := b.ms.SearchObjects(jc.Request.Context(), bucket, key, offset, limit) + if jc.Check("couldn't list objects", err) != nil { + return + } + jc.Encode(keys) +} + +func (b *Bus) objectsHandlerGET(jc jape.Context) { + var ignoreDelim bool + if jc.DecodeForm("ignoreDelim", &ignoreDelim) != nil { + return + } + path := jc.PathParam("path") + if strings.HasSuffix(path, "/") && !ignoreDelim { + b.objectEntriesHandlerGET(jc, path) + return + } + bucket := api.DefaultBucketName + if jc.DecodeForm("bucket", &bucket) != nil { + return + } + var onlymetadata bool + if jc.DecodeForm("onlymetadata", &onlymetadata) != nil { + return + } + + var o api.Object + var err error + if onlymetadata { + o, err = b.ms.ObjectMetadata(jc.Request.Context(), bucket, path) + } else { + o, err = b.ms.Object(jc.Request.Context(), bucket, path) + } + if errors.Is(err, api.ErrObjectNotFound) { + jc.Error(err, http.StatusNotFound) + return + } else if jc.Check("couldn't load object", err) != nil { + return + } + jc.Encode(api.ObjectsResponse{Object: &o}) +} + +func (b *Bus) objectEntriesHandlerGET(jc jape.Context, path string) { + bucket := api.DefaultBucketName + if jc.DecodeForm("bucket", &bucket) != nil { + return + } + + var prefix string + if jc.DecodeForm("prefix", &prefix) != nil { + return + } + + var sortBy string + if jc.DecodeForm("sortBy", &sortBy) != nil { + return + } + + var sortDir string + if jc.DecodeForm("sortDir", &sortDir) != nil { + return + } + + var marker string + if jc.DecodeForm("marker", &marker) != nil { + return + } + + var offset int + if jc.DecodeForm("offset", &offset) != nil { + return + } + limit := -1 + if jc.DecodeForm("limit", &limit) != nil { + return + } + + // look for object entries + entries, hasMore, err := b.ms.ObjectEntries(jc.Request.Context(), bucket, path, prefix, sortBy, sortDir, marker, offset, limit) + if jc.Check("couldn't list object entries", err) != nil { + return + } + + jc.Encode(api.ObjectsResponse{Entries: entries, HasMore: hasMore}) +} + +func (b *Bus) objectsHandlerPUT(jc jape.Context) { + var aor api.AddObjectRequest + if jc.Decode(&aor) != nil { + return + } else if aor.Bucket == "" { + aor.Bucket = api.DefaultBucketName + } + jc.Check("couldn't store object", b.ms.UpdateObject(jc.Request.Context(), aor.Bucket, jc.PathParam("path"), aor.ContractSet, aor.ETag, aor.MimeType, aor.Metadata, aor.Object)) +} + +func (b *Bus) objectsCopyHandlerPOST(jc jape.Context) { + var orr api.CopyObjectsRequest + if jc.Decode(&orr) != nil { + return + } + om, err := b.ms.CopyObject(jc.Request.Context(), orr.SourceBucket, orr.DestinationBucket, orr.SourcePath, orr.DestinationPath, orr.MimeType, orr.Metadata) + if jc.Check("couldn't copy object", err) != nil { + return + } + + 
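// surface the copy's modification time and ETag as response headers in + // addition to the JSON body, similar to S3-style copy responses + 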
jc.ResponseWriter.Header().Set("Last-Modified", om.ModTime.Std().Format(http.TimeFormat)) + jc.ResponseWriter.Header().Set("ETag", api.FormatETag(om.ETag)) + jc.Encode(om) +} + +func (b *Bus) objectsListHandlerPOST(jc jape.Context) { + var req api.ObjectsListRequest + if jc.Decode(&req) != nil { + return + } + if req.Bucket == "" { + req.Bucket = api.DefaultBucketName + } + resp, err := b.ms.ListObjects(jc.Request.Context(), req.Bucket, req.Prefix, req.SortBy, req.SortDir, req.Marker, req.Limit) + if errors.Is(err, api.ErrMarkerNotFound) { + jc.Error(err, http.StatusBadRequest) + return + } else if jc.Check("couldn't list objects", err) != nil { + return + } + jc.Encode(resp) +} + +func (b *Bus) objectsRenameHandlerPOST(jc jape.Context) { + var orr api.ObjectsRenameRequest + if jc.Decode(&orr) != nil { + return + } else if orr.Bucket == "" { + orr.Bucket = api.DefaultBucketName + } + if orr.Mode == api.ObjectsRenameModeSingle { + // Single object rename. + if strings.HasSuffix(orr.From, "/") || strings.HasSuffix(orr.To, "/") { + jc.Error(fmt.Errorf("can't rename dirs with mode %v", orr.Mode), http.StatusBadRequest) + return + } + jc.Check("couldn't rename object", b.ms.RenameObject(jc.Request.Context(), orr.Bucket, orr.From, orr.To, orr.Force)) + return + } else if orr.Mode == api.ObjectsRenameModeMulti { + // Multi object rename. + if !strings.HasSuffix(orr.From, "/") || !strings.HasSuffix(orr.To, "/") { + jc.Error(fmt.Errorf("can't rename file with mode %v", orr.Mode), http.StatusBadRequest) + return + } + jc.Check("couldn't rename objects", b.ms.RenameObjects(jc.Request.Context(), orr.Bucket, orr.From, orr.To, orr.Force)) + return + } else { + // Invalid mode. + jc.Error(fmt.Errorf("invalid mode: %v", orr.Mode), http.StatusBadRequest) + return + } +} + +func (b *Bus) objectsHandlerDELETE(jc jape.Context) { + var batch bool + if jc.DecodeForm("batch", &batch) != nil { + return + } + bucket := api.DefaultBucketName + if jc.DecodeForm("bucket", &bucket) != nil { + return + } + var err error + if batch { + err = b.ms.RemoveObjects(jc.Request.Context(), bucket, jc.PathParam("path")) + } else { + err = b.ms.RemoveObject(jc.Request.Context(), bucket, jc.PathParam("path")) + } + if errors.Is(err, api.ErrObjectNotFound) { + jc.Error(err, http.StatusNotFound) + return + } + jc.Check("couldn't delete object", err) +} + +func (b *Bus) slabbuffersHandlerGET(jc jape.Context) { + buffers, err := b.ms.SlabBuffers(jc.Request.Context()) + if jc.Check("couldn't get slab buffers info", err) != nil { + return + } + jc.Encode(buffers) +} + +func (b *Bus) objectsStatshandlerGET(jc jape.Context) { + opts := api.ObjectsStatsOpts{} + if jc.DecodeForm("bucket", &opts.Bucket) != nil { + return + } + info, err := b.ms.ObjectsStats(jc.Request.Context(), opts) + if jc.Check("couldn't get objects stats", err) != nil { + return + } + jc.Encode(info) +} + +func (b *Bus) packedSlabsHandlerFetchPOST(jc jape.Context) { + var psrg api.PackedSlabsRequestGET + if jc.Decode(&psrg) != nil { + return + } + if psrg.MinShards == 0 || psrg.TotalShards == 0 { + jc.Error(fmt.Errorf("min_shards and total_shards must be non-zero"), http.StatusBadRequest) + return + } + if psrg.LockingDuration == 0 { + jc.Error(fmt.Errorf("locking_duration must be non-zero"), http.StatusBadRequest) + return + } + if psrg.ContractSet == "" { + jc.Error(fmt.Errorf("contract_set must be non-empty"), http.StatusBadRequest) + return + } + slabs, err := b.ms.PackedSlabsForUpload(jc.Request.Context(), time.Duration(psrg.LockingDuration), psrg.MinShards, 
psrg.TotalShards, psrg.ContractSet, psrg.Limit) + if jc.Check("couldn't get packed slabs", err) != nil { + return + } + jc.Encode(slabs) +} + +func (b *Bus) packedSlabsHandlerDonePOST(jc jape.Context) { + var psrp api.PackedSlabsRequestPOST + if jc.Decode(&psrp) != nil { + return + } + jc.Check("failed to mark packed slab(s) as uploaded", b.ms.MarkPackedSlabsUploaded(jc.Request.Context(), psrp.Slabs)) +} + +func (b *Bus) sectorsHostRootHandlerDELETE(jc jape.Context) { + var hk types.PublicKey + var root types.Hash256 + if jc.DecodeParam("hk", &hk) != nil { + return + } else if jc.DecodeParam("root", &root) != nil { + return + } + n, err := b.ms.DeleteHostSector(jc.Request.Context(), hk, root) + if jc.Check("failed to mark sector as lost", err) != nil { + return + } else if n > 0 { + b.logger.Infow("successfully marked sector as lost", "hk", hk, "root", root) + } +} + +func (b *Bus) slabObjectsHandlerGET(jc jape.Context) { + var key object.EncryptionKey + if jc.DecodeParam("key", &key) != nil { + return + } + bucket := api.DefaultBucketName + if jc.DecodeForm("bucket", &bucket) != nil { + return + } + objects, err := b.ms.ObjectsBySlabKey(jc.Request.Context(), bucket, key) + if jc.Check("failed to retrieve objects by slab", err) != nil { + return + } + jc.Encode(objects) +} + +func (b *Bus) slabHandlerGET(jc jape.Context) { + var key object.EncryptionKey + if jc.DecodeParam("key", &key) != nil { + return + } + slab, err := b.ms.Slab(jc.Request.Context(), key) + if errors.Is(err, api.ErrSlabNotFound) { + jc.Error(err, http.StatusNotFound) + return + } else if err != nil { + jc.Error(err, http.StatusInternalServerError) + return + } + jc.Encode(slab) +} + +func (b *Bus) slabHandlerPUT(jc jape.Context) { + var usr api.UpdateSlabRequest + if jc.Decode(&usr) == nil { + jc.Check("couldn't update slab", b.ms.UpdateSlab(jc.Request.Context(), usr.Slab, usr.ContractSet)) + } +} + +func (b *Bus) slabsRefreshHealthHandlerPOST(jc jape.Context) { + jc.Check("failed to recompute health", b.ms.RefreshHealth(jc.Request.Context())) +} + +func (b *Bus) slabsMigrationHandlerPOST(jc jape.Context) { + var msr api.MigrationSlabsRequest + if jc.Decode(&msr) == nil { + if slabs, err := b.ms.UnhealthySlabs(jc.Request.Context(), msr.HealthCutoff, msr.ContractSet, msr.Limit); jc.Check("couldn't fetch slabs for migration", err) == nil { + jc.Encode(api.UnhealthySlabsResponse{ + Slabs: slabs, + }) + } + } +} + +func (b *Bus) slabsPartialHandlerGET(jc jape.Context) { + jc.Custom(nil, []byte{}) + + var key object.EncryptionKey + if jc.DecodeParam("key", &key) != nil { + return + } + var offset int + if jc.DecodeForm("offset", &offset) != nil { + return + } + var length int + if jc.DecodeForm("length", &length) != nil { + return + } + if length <= 0 || offset < 0 { + jc.Error(fmt.Errorf("length must be positive and offset must be non-negative"), http.StatusBadRequest) + return + } + data, err := b.ms.FetchPartialSlab(jc.Request.Context(), key, uint32(offset), uint32(length)) + if errors.Is(err, api.ErrObjectNotFound) { + jc.Error(err, http.StatusNotFound) + return + } else if err != nil { + jc.Error(err, http.StatusInternalServerError) + return + } + jc.ResponseWriter.Write(data) +} + +func (b *Bus) slabsPartialHandlerPOST(jc jape.Context) { + var minShards int + if jc.DecodeForm("minShards", &minShards) != nil { + return + } + var totalShards int + if jc.DecodeForm("totalShards", &totalShards) != nil { + return + } + var contractSet string + if jc.DecodeForm("contractSet", &contractSet) != nil { + return + } + if 
minShards <= 0 || totalShards <= minShards { + jc.Error(errors.New("minShards must be positive and totalShards must be greater than minShards"), http.StatusBadRequest) + return + } + if totalShards > math.MaxUint8 { + jc.Error(fmt.Errorf("totalShards must be less than or equal to %d", math.MaxUint8), http.StatusBadRequest) + return + } + if contractSet == "" { + jc.Error(errors.New("parameter 'contractSet' is required"), http.StatusBadRequest) + return + } + data, err := io.ReadAll(jc.Request.Body) + if jc.Check("failed to read request body", err) != nil { + return + } + slabs, bufferSize, err := b.ms.AddPartialSlab(jc.Request.Context(), data, uint8(minShards), uint8(totalShards), contractSet) + if jc.Check("failed to add partial slab", err) != nil { + return + } + var pus api.UploadPackingSettings + if err := b.fetchSetting(jc.Request.Context(), api.SettingUploadPacking, &pus); err != nil && !errors.Is(err, api.ErrSettingNotFound) { + jc.Error(fmt.Errorf("could not get upload packing settings: %w", err), http.StatusInternalServerError) + return + } + jc.Encode(api.AddPartialSlabResponse{ + Slabs: slabs, + SlabBufferMaxSizeSoftReached: bufferSize >= pus.SlabBufferMaxSizeSoft, + }) +} + +func (b *Bus) settingsHandlerGET(jc jape.Context) { + if settings, err := b.ss.Settings(jc.Request.Context()); jc.Check("couldn't load settings", err) == nil { + jc.Encode(settings) + } +} + +func (b *Bus) settingKeyHandlerGET(jc jape.Context) { + jc.Custom(nil, (any)(nil)) + + key := jc.PathParam("key") + if key == "" { + jc.Error(errors.New("path parameter 'key' cannot be empty"), http.StatusBadRequest) + return + } + + setting, err := b.ss.Setting(jc.Request.Context(), jc.PathParam("key")) + if errors.Is(err, api.ErrSettingNotFound) { + jc.Error(err, http.StatusNotFound) + return + } else if err != nil { + jc.Error(err, http.StatusInternalServerError) + return + } + resp := []byte(setting) + + // populate autopilots of price pinning settings with defaults for better DX + if key == api.SettingPricePinning { + var pps api.PricePinSettings + err = json.Unmarshal([]byte(setting), &pps) + if jc.Check("failed to unmarshal price pinning settings", err) != nil { + return + } else if pps.Autopilots == nil { + pps.Autopilots = make(map[string]api.AutopilotPins) + } + // populate the Autopilots map with the current autopilots + aps, err := b.as.Autopilots(jc.Request.Context()) + if jc.Check("failed to fetch autopilots", err) != nil { + return + } + for _, ap := range aps { + if _, exists := pps.Autopilots[ap.ID]; !exists { + pps.Autopilots[ap.ID] = api.AutopilotPins{} + } + } + // encode the settings back + resp, err = json.Marshal(pps) + if jc.Check("failed to marshal price pinning settings", err) != nil { + return + } + } + jc.ResponseWriter.Header().Set("Content-Type", "application/json") + jc.ResponseWriter.Write(resp) +} + +func (b *Bus) settingKeyHandlerPUT(jc jape.Context) { + key := jc.PathParam("key") + if key == "" { + jc.Error(errors.New("path parameter 'key' cannot be empty"), http.StatusBadRequest) + return + } + + var value interface{} + if jc.Decode(&value) != nil { + return + } + + data, err := json.Marshal(value) + if err != nil { + jc.Error(fmt.Errorf("couldn't marshal the given value, error: %v", err), http.StatusBadRequest) + return + } + + switch key { + case api.SettingGouging: + var gs api.GougingSettings + if err := json.Unmarshal(data, &gs); err != nil { + jc.Error(fmt.Errorf("couldn't update gouging settings, invalid request body, %v", value), http.StatusBadRequest) + return + } else if err := gs.Validate(); err != nil { + jc.Error(fmt.Errorf("couldn't update gouging settings, error: %v", err), http.StatusBadRequest) + return + } + b.pinMgr.TriggerUpdate() + case api.SettingRedundancy: + var rs api.RedundancySettings + if err := json.Unmarshal(data, &rs); err != nil { + jc.Error(fmt.Errorf("couldn't update redundancy settings, invalid request body"), http.StatusBadRequest) + return + } else if err := rs.Validate(); err != nil { + jc.Error(fmt.Errorf("couldn't update redundancy settings, error: %v", err), http.StatusBadRequest) + return + } + case api.SettingS3Authentication: + var s3as api.S3AuthenticationSettings + if err := json.Unmarshal(data, &s3as); err != nil { + jc.Error(fmt.Errorf("couldn't update s3 authentication settings, invalid request body"), http.StatusBadRequest) + return + } else if err := s3as.Validate(); err != nil { + jc.Error(fmt.Errorf("couldn't update s3 authentication settings, error: %v", err), http.StatusBadRequest) + return + } + case api.SettingPricePinning: + var pps api.PricePinSettings + if err := json.Unmarshal(data, &pps); err != nil { + jc.Error(fmt.Errorf("couldn't update price pinning settings, invalid request body"), http.StatusBadRequest) + return + } else if err := pps.Validate(); err != nil { + jc.Error(fmt.Errorf("couldn't update price pinning settings, invalid settings, error: %v", err), http.StatusBadRequest) + return + } else if pps.Enabled { + if _, err := ibus.NewForexClient(pps.ForexEndpointURL).SiacoinExchangeRate(jc.Request.Context(), pps.Currency); err != nil { + jc.Error(fmt.Errorf("couldn't update price pinning settings, forex API unreachable, error: %v", err), http.StatusBadRequest) + return + } + } + b.pinMgr.TriggerUpdate() + } + + if jc.Check("could not update setting", b.ss.UpdateSetting(jc.Request.Context(), key, string(data))) == nil { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleSetting, + Event: api.EventUpdate, + Payload: api.EventSettingUpdate{ + Key: key, + Update: value, + Timestamp: time.Now().UTC(), + }, + }) + } +} + +func (b *Bus) settingKeyHandlerDELETE(jc jape.Context) { + key := jc.PathParam("key") + if key == "" { + jc.Error(errors.New("path parameter 'key' cannot be empty"), http.StatusBadRequest) + return + } + + if jc.Check("could not delete setting", b.ss.DeleteSetting(jc.Request.Context(), key)) == nil { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleSetting, + Event: api.EventDelete, + Payload: api.EventSettingDelete{ + Key: key, + Timestamp: time.Now().UTC(), + }, + }) + } +} + +func (b *Bus) contractIDAncestorsHandler(jc jape.Context) { + var fcid types.FileContractID + if jc.DecodeParam("id", &fcid) != nil { + return + } + var minStartHeight uint64 + if jc.DecodeForm("minStartHeight", &minStartHeight) != nil { + return + } + ancestors, err := b.ms.AncestorContracts(jc.Request.Context(), fcid, uint64(minStartHeight)) + if jc.Check("failed to fetch ancestor contracts", err) != nil { + return + } + jc.Encode(ancestors) +} + +func (b *Bus) paramsHandlerUploadGET(jc jape.Context) { + gp, err := b.gougingParams(jc.Request.Context()) + if jc.Check("could not get gouging parameters", err) != nil { + return + } + + var contractSet string + var css api.ContractSetSetting + if err := b.fetchSetting(jc.Request.Context(), api.SettingContractSet, &css); err != nil && !errors.Is(err, api.ErrSettingNotFound) { + jc.Error(fmt.Errorf("could not get contract set settings: %w", err), http.StatusInternalServerError) + return + } else if err == nil { + contractSet = css.Default + } + + var 
uploadPacking bool + var pus api.UploadPackingSettings + if err := b.fetchSetting(jc.Request.Context(), api.SettingUploadPacking, &pus); err != nil && !errors.Is(err, api.ErrSettingNotFound) { + jc.Error(fmt.Errorf("could not get upload packing settings: %w", err), http.StatusInternalServerError) + return + } else if err == nil { + uploadPacking = pus.Enabled + } + + jc.Encode(api.UploadParams{ + ContractSet: contractSet, + CurrentHeight: b.cm.TipState().Index.Height, + GougingParams: gp, + UploadPacking: uploadPacking, + }) +} + +func (b *Bus) consensusState(ctx context.Context) (api.ConsensusState, error) { + index, err := b.cs.ChainIndex(ctx) + if err != nil { + return api.ConsensusState{}, err + } + + var synced bool + block, found := b.cm.Block(index.ID) + if found { + synced = utils.IsSynced(block) + } + + return api.ConsensusState{ + BlockHeight: index.Height, + LastBlockTime: api.TimeRFC3339(block.Timestamp), + Synced: synced, + }, nil +} + +func (b *Bus) paramsHandlerGougingGET(jc jape.Context) { + gp, err := b.gougingParams(jc.Request.Context()) + if jc.Check("could not get gouging parameters", err) != nil { + return + } + jc.Encode(gp) +} + +func (b *Bus) gougingParams(ctx context.Context) (api.GougingParams, error) { + var gs api.GougingSettings + if gss, err := b.ss.Setting(ctx, api.SettingGouging); err != nil { + return api.GougingParams{}, err + } else if err := json.Unmarshal([]byte(gss), &gs); err != nil { + b.logger.Panicf("failed to unmarshal gouging settings '%s': %v", gss, err) + } + + var rs api.RedundancySettings + if rss, err := b.ss.Setting(ctx, api.SettingRedundancy); err != nil { + return api.GougingParams{}, err + } else if err := json.Unmarshal([]byte(rss), &rs); err != nil { + b.logger.Panicf("failed to unmarshal redundancy settings '%s': %v", rss, err) + } + + cs, err := b.consensusState(ctx) + if err != nil { + return api.GougingParams{}, err + } + + return api.GougingParams{ + ConsensusState: cs, + GougingSettings: gs, + RedundancySettings: rs, + TransactionFee: b.cm.RecommendedFee(), + }, nil +} + +func (b *Bus) handleGETAlertsDeprecated(jc jape.Context) { + ar, err := b.alertMgr.Alerts(jc.Request.Context(), alerts.AlertsOpts{Offset: 0, Limit: -1}) + if jc.Check("failed to fetch alerts", err) != nil { + return + } + jc.Encode(ar.Alerts) +} + +func (b *Bus) handleGETAlerts(jc jape.Context) { + if jc.Request.FormValue("offset") == "" && jc.Request.FormValue("limit") == "" { + b.handleGETAlertsDeprecated(jc) + return + } + offset, limit := 0, -1 + var severity alerts.Severity + if jc.DecodeForm("offset", &offset) != nil { + return + } else if jc.DecodeForm("limit", &limit) != nil { + return + } else if offset < 0 { + jc.Error(errors.New("offset must be non-negative"), http.StatusBadRequest) + return + } else if jc.DecodeForm("severity", &severity) != nil { + return + } + ar, err := b.alertMgr.Alerts(jc.Request.Context(), alerts.AlertsOpts{ + Offset: offset, + Limit: limit, + Severity: severity, + }) + if jc.Check("failed to fetch alerts", err) != nil { + return + } + jc.Encode(ar) +} + +func (b *Bus) handlePOSTAlertsDismiss(jc jape.Context) { + var ids []types.Hash256 + if jc.Decode(&ids) != nil { + return + } + jc.Check("failed to dismiss alerts", b.alertMgr.DismissAlerts(jc.Request.Context(), ids...)) +} + +func (b *Bus) handlePOSTAlertsRegister(jc jape.Context) { + var alert alerts.Alert + if jc.Decode(&alert) != nil { + return + } + jc.Check("failed to register alert", b.alertMgr.RegisterAlert(jc.Request.Context(), alert)) +} + +func (b *Bus) 
accountsHandlerGET(jc jape.Context) { + jc.Encode(b.accountsMgr.Accounts()) +} + +func (b *Bus) accountHandlerGET(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.AccountHandlerPOST + if jc.Decode(&req) != nil { + return + } + acc, err := b.accountsMgr.Account(id, req.HostKey) + if jc.Check("failed to fetch account", err) != nil { + return + } + jc.Encode(acc) +} + +func (b *Bus) accountsAddHandlerPOST(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.AccountsAddBalanceRequest + if jc.Decode(&req) != nil { + return + } + if id == (rhpv3.Account{}) { + jc.Error(errors.New("account id needs to be set"), http.StatusBadRequest) + return + } + if req.HostKey == (types.PublicKey{}) { + jc.Error(errors.New("host needs to be set"), http.StatusBadRequest) + return + } + b.accountsMgr.AddAmount(id, req.HostKey, req.Amount) +} + +func (b *Bus) accountsResetDriftHandlerPOST(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + err := b.accountsMgr.ResetDrift(id) + if errors.Is(err, ibus.ErrAccountNotFound) { + jc.Error(err, http.StatusNotFound) + return + } + if jc.Check("failed to reset drift", err) != nil { + return + } +} + +func (b *Bus) accountsUpdateHandlerPOST(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.AccountsUpdateBalanceRequest + if jc.Decode(&req) != nil { + return + } + if id == (rhpv3.Account{}) { + jc.Error(errors.New("account id needs to be set"), http.StatusBadRequest) + return + } + if req.HostKey == (types.PublicKey{}) { + jc.Error(errors.New("host needs to be set"), http.StatusBadRequest) + return + } + b.accountsMgr.SetBalance(id, req.HostKey, req.Amount) +} + +func (b *Bus) accountsRequiresSyncHandlerPOST(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.AccountsRequiresSyncRequest + if jc.Decode(&req) != nil { + return + } + if id == (rhpv3.Account{}) { + jc.Error(errors.New("account id needs to be set"), http.StatusBadRequest) + return + } + if req.HostKey == (types.PublicKey{}) { + jc.Error(errors.New("host needs to be set"), http.StatusBadRequest) + return + } + err := b.accountsMgr.ScheduleSync(id, req.HostKey) + if errors.Is(err, ibus.ErrAccountNotFound) { + jc.Error(err, http.StatusNotFound) + return + } + if jc.Check("failed to set requiresSync flag on account", err) != nil { + return + } +} + +func (b *Bus) accountsLockHandlerPOST(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.AccountsLockHandlerRequest + if jc.Decode(&req) != nil { + return + } + + acc, lockID := b.accountsMgr.LockAccount(jc.Request.Context(), id, req.HostKey, req.Exclusive, time.Duration(req.Duration)) + jc.Encode(api.AccountsLockHandlerResponse{ + Account: acc, + LockID: lockID, + }) +} + +func (b *Bus) accountsUnlockHandlerPOST(jc jape.Context) { + var id rhpv3.Account + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.AccountsUnlockHandlerRequest + if jc.Decode(&req) != nil { + return + } + + err := b.accountsMgr.UnlockAccount(id, req.LockID) + if jc.Check("failed to unlock account", err) != nil { + return + } +} + +func (b *Bus) autopilotsListHandlerGET(jc jape.Context) { + if autopilots, err := b.as.Autopilots(jc.Request.Context()); jc.Check("failed to fetch autopilots", err) == nil { + jc.Encode(autopilots) + } +} + +func (b *Bus) 
autopilotsHandlerGET(jc jape.Context) { + var id string + if jc.DecodeParam("id", &id) != nil { + return + } + ap, err := b.as.Autopilot(jc.Request.Context(), id) + if errors.Is(err, api.ErrAutopilotNotFound) { + jc.Error(err, http.StatusNotFound) + return + } + if jc.Check("couldn't load autopilot", err) != nil { + return + } + + jc.Encode(ap) +} + +func (b *Bus) autopilotsHandlerPUT(jc jape.Context) { + var id string + if jc.DecodeParam("id", &id) != nil { + return + } + + var ap api.Autopilot + if jc.Decode(&ap) != nil { + return + } + + if ap.ID != id { + jc.Error(errors.New("id in path and body don't match"), http.StatusBadRequest) + return + } + + if jc.Check("failed to update autopilot", b.as.UpdateAutopilot(jc.Request.Context(), ap)) == nil { + b.pinMgr.TriggerUpdate() + } +} + +func (b *Bus) autopilotHostCheckHandlerPUT(jc jape.Context) { + var id string + if jc.DecodeParam("id", &id) != nil { + return + } + var hk types.PublicKey + if jc.DecodeParam("hostkey", &hk) != nil { + return + } + var hc api.HostCheck + if jc.Check("failed to decode host check", jc.Decode(&hc)) != nil { + return + } + + err := b.hs.UpdateHostCheck(jc.Request.Context(), id, hk, hc) + if errors.Is(err, api.ErrAutopilotNotFound) { + jc.Error(err, http.StatusNotFound) + return + } else if jc.Check("failed to update host", err) != nil { + return + } +} + +func (b *Bus) broadcastAction(e webhooks.Event) { + log := b.logger.With("event", e.Event).With("module", e.Module) + err := b.webhooksMgr.BroadcastAction(context.Background(), e) + if err != nil { + log.With(zap.Error(err)).Error("failed to broadcast action") + } else { + log.Debug("successfully broadcast action") + } +} + +func (b *Bus) contractTaxHandlerGET(jc jape.Context) { + var payout types.Currency + if jc.DecodeParam("payout", (*api.ParamCurrency)(&payout)) != nil { + return + } + cs := b.cm.TipState() + jc.Encode(cs.FileContractTax(types.FileContract{Payout: payout})) +} + +func (b *Bus) stateHandlerGET(jc jape.Context) { + jc.Encode(api.BusStateResponse{ + StartTime: api.TimeRFC3339(b.startTime), + BuildState: api.BuildState{ + Version: build.Version(), + Commit: build.Commit(), + OS: runtime.GOOS, + BuildTime: api.TimeRFC3339(build.BuildTime()), + }, + Network: b.cm.TipState().Network.Name, + }) +} + +func (b *Bus) uploadTrackHandlerPOST(jc jape.Context) { + var id api.UploadID + if jc.DecodeParam("id", &id) == nil { + jc.Check("failed to track upload", b.sectors.StartUpload(id)) + } +} + +func (b *Bus) uploadAddSectorHandlerPOST(jc jape.Context) { + var id api.UploadID + if jc.DecodeParam("id", &id) != nil { + return + } + var req api.UploadSectorRequest + if jc.Decode(&req) != nil { + return + } + jc.Check("failed to add sector", b.sectors.AddSector(id, req.ContractID, req.Root)) +} + +func (b *Bus) uploadFinishedHandlerDELETE(jc jape.Context) { + var id api.UploadID + if jc.DecodeParam("id", &id) == nil { + b.sectors.FinishUpload(id) + } +} + +func (b *Bus) webhookActionHandlerPost(jc jape.Context) { + var action webhooks.Event + if jc.Check("failed to decode action", jc.Decode(&action)) != nil { + return + } + b.broadcastAction(action) +} + +func (b *Bus) webhookHandlerDelete(jc jape.Context) { + var wh webhooks.Webhook + if jc.Decode(&wh) != nil { + return + } + err := b.webhooksMgr.Delete(jc.Request.Context(), wh) + if errors.Is(err, webhooks.ErrWebhookNotFound) { + jc.Error(fmt.Errorf("webhook for URL %v and event %v.%v not found", wh.URL, wh.Module, wh.Event), http.StatusNotFound) + return + } else if jc.Check("failed to delete webhook", err) 
!= nil { + return + } +} + +func (b *Bus) webhookHandlerGet(jc jape.Context) { + webhooks, queueInfos := b.webhooksMgr.Info() + jc.Encode(api.WebhookResponse{ + Queues: queueInfos, + Webhooks: webhooks, + }) +} + +func (b *Bus) webhookHandlerPost(jc jape.Context) { + var req webhooks.Webhook + if jc.Decode(&req) != nil { + return + } + + err := b.webhooksMgr.Register(jc.Request.Context(), webhooks.Webhook{ + Event: req.Event, + Module: req.Module, + URL: req.URL, + Headers: req.Headers, + }) + if err != nil { + jc.Error(fmt.Errorf("failed to add Webhook: %w", err), http.StatusInternalServerError) + return + } +} + +func (b *Bus) metricsHandlerDELETE(jc jape.Context) { + metric := jc.PathParam("key") + if metric == "" { + jc.Error(errors.New("parameter 'metric' is required"), http.StatusBadRequest) + return + } + + var cutoff time.Time + if jc.DecodeForm("cutoff", (*api.TimeRFC3339)(&cutoff)) != nil { + return + } else if cutoff.IsZero() { + jc.Error(errors.New("parameter 'cutoff' is required"), http.StatusBadRequest) + return + } + + err := b.mtrcs.PruneMetrics(jc.Request.Context(), metric, cutoff) + if jc.Check("failed to prune metrics", err) != nil { + return + } +} + +func (b *Bus) metricsHandlerPUT(jc jape.Context) { + jc.Custom((*interface{})(nil), nil) + + key := jc.PathParam("key") + switch key { + case api.MetricContractPrune: + // TODO: jape hack - remove once jape can handle decoding multiple different request types + var req api.ContractPruneMetricRequestPUT + if err := json.NewDecoder(jc.Request.Body).Decode(&req); err != nil { + jc.Error(fmt.Errorf("couldn't decode request type (%T): %w", req, err), http.StatusBadRequest) + return + } else if jc.Check("failed to record contract prune metric", b.mtrcs.RecordContractPruneMetric(jc.Request.Context(), req.Metrics...)) != nil { + return + } + case api.MetricContractSetChurn: + // TODO: jape hack - remove once jape can handle decoding multiple different request types + var req api.ContractSetChurnMetricRequestPUT + if err := json.NewDecoder(jc.Request.Body).Decode(&req); err != nil { + jc.Error(fmt.Errorf("couldn't decode request type (%T): %w", req, err), http.StatusBadRequest) + return + } else if jc.Check("failed to record contract churn metric", b.mtrcs.RecordContractSetChurnMetric(jc.Request.Context(), req.Metrics...)) != nil { + return + } + default: + jc.Error(fmt.Errorf("unknown metric key '%s'", key), http.StatusBadRequest) + return + } +} + +func (b *Bus) metricsHandlerGET(jc jape.Context) { + // parse mandatory query parameters + var start time.Time + if jc.DecodeForm("start", (*api.TimeRFC3339)(&start)) != nil { + return + } else if start.IsZero() { + jc.Error(errors.New("parameter 'start' is required"), http.StatusBadRequest) + return + } + + var n uint64 + if jc.DecodeForm("n", &n) != nil { + return + } else if n == 0 { + if jc.Request.FormValue("n") == "" { + jc.Error(errors.New("parameter 'n' is required"), http.StatusBadRequest) + } else { + jc.Error(errors.New("'n' has to be greater than zero"), http.StatusBadRequest) + } + return + } + + var interval time.Duration + if jc.DecodeForm("interval", (*api.DurationMS)(&interval)) != nil { + return + } else if interval == 0 { + jc.Error(errors.New("parameter 'interval' is required"), http.StatusBadRequest) + return + } + + // parse optional query parameters + var metrics interface{} + var err error + key := jc.PathParam("key") + switch key { + case api.MetricContract: + var opts api.ContractMetricsQueryOpts + if jc.DecodeForm("contractID", &opts.ContractID) != nil { + 
return + } else if jc.DecodeForm("hostKey", &opts.HostKey) != nil { + return + } + metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) + case api.MetricContractPrune: + var opts api.ContractPruneMetricsQueryOpts + if jc.DecodeForm("contractID", &opts.ContractID) != nil { + return + } else if jc.DecodeForm("hostKey", &opts.HostKey) != nil { + return + } else if jc.DecodeForm("hostVersion", &opts.HostVersion) != nil { + return + } + metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) + case api.MetricContractSet: + var opts api.ContractSetMetricsQueryOpts + if jc.DecodeForm("name", &opts.Name) != nil { + return + } + metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) + case api.MetricContractSetChurn: + var opts api.ContractSetChurnMetricsQueryOpts + if jc.DecodeForm("name", &opts.Name) != nil { + return + } else if jc.DecodeForm("direction", &opts.Direction) != nil { + return + } else if jc.DecodeForm("reason", &opts.Reason) != nil { + return + } + metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) + case api.MetricWallet: + var opts api.WalletMetricsQueryOpts + metrics, err = b.metrics(jc.Request.Context(), key, start, n, interval, opts) + default: + jc.Error(fmt.Errorf("unknown metric '%s'", key), http.StatusBadRequest) + return + } + if errors.Is(err, api.ErrMaxIntervalsExceeded) { + jc.Error(err, http.StatusBadRequest) + return + } else if jc.Check(fmt.Sprintf("failed to fetch '%s' metrics", key), err) != nil { + return + } + jc.Encode(metrics) +} + +func (b *Bus) metrics(ctx context.Context, key string, start time.Time, n uint64, interval time.Duration, opts interface{}) (interface{}, error) { + switch key { + case api.MetricContract: + return b.mtrcs.ContractMetrics(ctx, start, n, interval, opts.(api.ContractMetricsQueryOpts)) + case api.MetricContractPrune: + return b.mtrcs.ContractPruneMetrics(ctx, start, n, interval, opts.(api.ContractPruneMetricsQueryOpts)) + case api.MetricContractSet: + return b.mtrcs.ContractSetMetrics(ctx, start, n, interval, opts.(api.ContractSetMetricsQueryOpts)) + case api.MetricContractSetChurn: + return b.mtrcs.ContractSetChurnMetrics(ctx, start, n, interval, opts.(api.ContractSetChurnMetricsQueryOpts)) + case api.MetricWallet: + return b.mtrcs.WalletMetrics(ctx, start, n, interval, opts.(api.WalletMetricsQueryOpts)) + } + return nil, fmt.Errorf("unknown metric '%s'", key) +} + +func (b *Bus) multipartHandlerCreatePOST(jc jape.Context) { + var req api.MultipartCreateRequest + if jc.Decode(&req) != nil { + return + } + + var key object.EncryptionKey + if req.GenerateKey { + key = object.GenerateEncryptionKey() + } else if req.Key == nil { + key = object.NoOpKey + } else { + key = *req.Key + } + + resp, err := b.ms.CreateMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, key, req.MimeType, req.Metadata) + if jc.Check("failed to create multipart upload", err) != nil { + return + } + jc.Encode(resp) +} + +func (b *Bus) multipartHandlerAbortPOST(jc jape.Context) { + var req api.MultipartAbortRequest + if jc.Decode(&req) != nil { + return + } + err := b.ms.AbortMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, req.UploadID) + if jc.Check("failed to abort multipart upload", err) != nil { + return + } +} + +func (b *Bus) multipartHandlerCompletePOST(jc jape.Context) { + var req api.MultipartCompleteRequest + if jc.Decode(&req) != nil { + return + } + resp, err := b.ms.CompleteMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, 
req.UploadID, req.Parts, api.CompleteMultipartOptions{ + Metadata: req.Metadata, + }) + if jc.Check("failed to complete multipart upload", err) != nil { + return + } + jc.Encode(resp) +} + +func (b *Bus) multipartHandlerUploadPartPUT(jc jape.Context) { + var req api.MultipartAddPartRequest + if jc.Decode(&req) != nil { + return + } + if req.Bucket == "" { + req.Bucket = api.DefaultBucketName + } + if req.ContractSet == "" { + jc.Error(errors.New("contract_set must be non-empty"), http.StatusBadRequest) + return + } else if req.ETag == "" { + jc.Error(errors.New("etag must be non-empty"), http.StatusBadRequest) + return + } else if req.PartNumber <= 0 || req.PartNumber > gofakes3.MaxUploadPartNumber { + jc.Error(fmt.Errorf("part_number must be between 1 and %d", gofakes3.MaxUploadPartNumber), http.StatusBadRequest) + return + } else if req.UploadID == "" { + jc.Error(errors.New("upload_id must be non-empty"), http.StatusBadRequest) + return + } + err := b.ms.AddMultipartPart(jc.Request.Context(), req.Bucket, req.Path, req.ContractSet, req.ETag, req.UploadID, req.PartNumber, req.Slices) + if jc.Check("failed to upload part", err) != nil { + return + } +} + +func (b *Bus) multipartHandlerUploadGET(jc jape.Context) { + resp, err := b.ms.MultipartUpload(jc.Request.Context(), jc.PathParam("id")) + if jc.Check("failed to get multipart upload", err) != nil { + return + } + jc.Encode(resp) +} + +func (b *Bus) multipartHandlerListUploadsPOST(jc jape.Context) { + var req api.MultipartListUploadsRequest + if jc.Decode(&req) != nil { + return + } + resp, err := b.ms.MultipartUploads(jc.Request.Context(), req.Bucket, req.Prefix, req.PathMarker, req.UploadIDMarker, req.Limit) + if jc.Check("failed to list multipart uploads", err) != nil { + return + } + jc.Encode(resp) +} + +func (b *Bus) multipartHandlerListPartsPOST(jc jape.Context) { + var req api.MultipartListPartsRequest + if jc.Decode(&req) != nil { + return + } + resp, err := b.ms.MultipartUploadParts(jc.Request.Context(), req.Bucket, req.Path, req.UploadID, req.PartNumberMarker, int64(req.Limit)) + if jc.Check("failed to list multipart upload parts", err) != nil { + return + } + jc.Encode(resp) +} diff --git a/bus/uploadingsectors_test.go b/bus/uploadingsectors_test.go deleted file mode 100644 index b1c9b725a..000000000 --- a/bus/uploadingsectors_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package bus - -import ( - "errors" - "testing" - - rhpv2 "go.sia.tech/core/rhp/v2" - "go.sia.tech/core/types" - "go.sia.tech/renterd/api" - "lukechampine.com/frand" -) - -func TestUploadingSectorsCache(t *testing.T) { - c := newUploadingSectorsCache() - - uID1 := newTestUploadID() - uID2 := newTestUploadID() - - fcid1 := types.FileContractID{1} - fcid2 := types.FileContractID{2} - fcid3 := types.FileContractID{3} - - c.StartUpload(uID1) - c.StartUpload(uID2) - - _ = c.AddSector(uID1, fcid1, types.Hash256{1}) - _ = c.AddSector(uID1, fcid2, types.Hash256{2}) - _ = c.AddSector(uID2, fcid2, types.Hash256{3}) - - if roots1 := c.Sectors(fcid1); len(roots1) != 1 || roots1[0] != (types.Hash256{1}) { - t.Fatal("unexpected cached sectors") - } - if roots2 := c.Sectors(fcid2); len(roots2) != 2 { - t.Fatal("unexpected cached sectors", roots2) - } - if roots3 := c.Sectors(fcid3); len(roots3) != 0 { - t.Fatal("unexpected cached sectors") - } - - if o1, exists := c.uploads[uID1]; !exists || o1.started.IsZero() { - t.Fatal("unexpected") - } - if o2, exists := c.uploads[uID2]; !exists || o2.started.IsZero() { - t.Fatal("unexpected") - } - - c.FinishUpload(uID1) - if roots1 := 
c.Sectors(fcid1); len(roots1) != 0 { - t.Fatal("unexpected cached sectors") - } - if roots2 := c.Sectors(fcid2); len(roots2) != 1 || roots2[0] != (types.Hash256{3}) { - t.Fatal("unexpected cached sectors") - } - - c.FinishUpload(uID2) - if roots2 := c.Sectors(fcid1); len(roots2) != 0 { - t.Fatal("unexpected cached sectors") - } - - if err := c.AddSector(uID1, fcid1, types.Hash256{1}); !errors.Is(err, api.ErrUnknownUpload) { - t.Fatal("unexpected error", err) - } - if err := c.StartUpload(uID1); err != nil { - t.Fatal("unexpected error", err) - } - if err := c.StartUpload(uID1); !errors.Is(err, api.ErrUploadAlreadyExists) { - t.Fatal("unexpected error", err) - } - - // reset cache - c = newUploadingSectorsCache() - - // track upload that uploads across two contracts - c.StartUpload(uID1) - c.AddSector(uID1, fcid1, types.Hash256{1}) - c.AddSector(uID1, fcid1, types.Hash256{2}) - c.HandleRenewal(fcid2, fcid1) - c.AddSector(uID1, fcid2, types.Hash256{3}) - c.AddSector(uID1, fcid2, types.Hash256{4}) - - // assert pending sizes for both contracts should be 4 sectors - p1 := c.Pending(fcid1) - p2 := c.Pending(fcid2) - if p1 != p2 || p1 != 4*rhpv2.SectorSize { - t.Fatal("unexpected pending size", p1/rhpv2.SectorSize, p2/rhpv2.SectorSize) - } - - // assert sectors for both contracts contain 4 sectors - s1 := c.Sectors(fcid1) - s2 := c.Sectors(fcid2) - if len(s1) != 4 || len(s2) != 4 { - t.Fatal("unexpected sectors", len(s1), len(s2)) - } - - // finish upload - c.FinishUpload(uID1) - s1 = c.Sectors(fcid1) - s2 = c.Sectors(fcid2) - if len(s1) != 0 || len(s2) != 0 { - t.Fatal("unexpected sectors", len(s1), len(s2)) - } - - // renew the contract - c.HandleRenewal(fcid3, fcid2) - - // trigger pruning - c.StartUpload(uID2) - c.FinishUpload(uID2) - - // assert renewedTo gets pruned - if len(c.renewedTo) != 1 { - t.Fatal("unexpected", len(c.renewedTo)) - } -} - -func newTestUploadID() api.UploadID { - var uID api.UploadID - frand.Read(uID[:]) - return uID -} diff --git a/cmd/renterd/commands.go b/cmd/renterd/commands.go index 96c798abf..819eb9de5 100644 --- a/cmd/renterd/commands.go +++ b/cmd/renterd/commands.go @@ -7,10 +7,11 @@ import ( "go.sia.tech/core/types" "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/build" + "go.sia.tech/renterd/config" "gopkg.in/yaml.v3" ) -func cmdBuildConfig() { +func cmdBuildConfig(cfg *config.Config) { if _, err := os.Stat("renterd.yml"); err == nil { if !promptYesNo("renterd.yml already exists. 
Would you like to overwrite it?") { return @@ -23,10 +24,10 @@ func cmdBuildConfig() { fmt.Println("If you change your wallet seed phrase, your renter will not be able to access Siacoin associated with this wallet.") fmt.Println("Ensure that you have backed up your wallet seed phrase before continuing.") if promptYesNo("Would you like to change your wallet seed phrase?") { - setSeedPhrase() + setSeedPhrase(cfg) } } else { - setSeedPhrase() + setSeedPhrase(cfg) } fmt.Println("") @@ -34,17 +35,17 @@ func cmdBuildConfig() { fmt.Println(wrapANSI("\033[33m", "An admin password is already set.", "\033[0m")) fmt.Println("If you change your admin password, you will need to update any scripts or applications that use the admin API.") if promptYesNo("Would you like to change your admin password?") { - setAPIPassword() + setAPIPassword(cfg) } } else { - setAPIPassword() + setAPIPassword(cfg) } fmt.Println("") - setS3Config() + setS3Config(cfg) fmt.Println("") - setAdvancedConfig() + setAdvancedConfig(cfg) // write the config file configPath := "renterd.yml" @@ -78,9 +79,9 @@ func cmdSeed() { fmt.Println("Address", types.StandardUnlockHash(key.PublicKey())) } -func cmdVersion() { +func cmdVersion(network string) { fmt.Println("renterd", build.Version()) - fmt.Println("Network", build.NetworkName()) + fmt.Println("Network", network) fmt.Println("Commit:", build.Commit()) fmt.Println("Build Date:", build.BuildTime()) } diff --git a/cmd/renterd/config.go b/cmd/renterd/config.go index f4b728c7e..38231458d 100644 --- a/cmd/renterd/config.go +++ b/cmd/renterd/config.go @@ -3,21 +3,402 @@ package main import ( "bufio" "encoding/hex" + "errors" + "flag" "fmt" + "log" "net" "os" "runtime" "strconv" "strings" + "time" + "go.sia.tech/core/consensus" "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" "go.sia.tech/coreutils/wallet" + "go.sia.tech/renterd/api" "go.sia.tech/renterd/config" + "go.sia.tech/renterd/worker/s3" "golang.org/x/term" + "gopkg.in/yaml.v3" "lukechampine.com/frand" ) -var enableANSI = runtime.GOOS != "windows" +// TODO: handle RENTERD_S3_HOST_BUCKET_BASES correctly + +const ( + // accountRefillInterval is the amount of time between refills of ephemeral + // accounts. If we conservatively assume that a good host charges 500 SC / + // TiB, we can pay for about 2.2 GiB with 1 SC. Since we want to refill + // ahead of time at 0.5 SC, that makes 1.1 GiB. Considering a 1 Gbps uplink + // that is shared across 30 uploads, we upload at around 33 Mbps to each + // host. That means uploading 1.1 GiB to drain 0.5 SC takes around 5 + // minutes. That's why we assume 10 seconds to be more than frequent enough + // to refill an account when it's due for another refill. + defaultAccountRefillInterval = 10 * time.Second +) + +var ( + disableStdin bool + enableANSI = runtime.GOOS != "windows" + + hostBasesStr string + keyPairsV4 string + workerRemotePassStr string + workerRemoteAddrsStr string +) + +func defaultConfig() config.Config { + return config.Config{ + Directory: ".", + Seed: os.Getenv("RENTERD_SEED"), + AutoOpenWebUI: true, + Network: "mainnet", + HTTP: config.HTTP{ + Address: "localhost:9980", + Password: os.Getenv("RENTERD_API_PASSWORD"), + }, + ShutdownTimeout: 5 * time.Minute, + Database: config.Database{ + MySQL: config.MySQL{ + User: "renterd", + Database: "renterd", + MetricsDatabase: "renterd_metrics", + }, + }, + Log: config.Log{ + Path: "", // deprecated. included for compatibility. 
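+ // note: an empty Level is defaulted to "info" later by sanitizeConfig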
+ Level: "", + File: config.LogFile{ + Enabled: true, + Format: "json", + Path: os.Getenv("RENTERD_LOG_FILE"), + }, + StdOut: config.StdOut{ + Enabled: true, + Format: "human", + EnableANSI: runtime.GOOS != "windows", + }, + Database: config.DatabaseLog{ + Enabled: true, + IgnoreRecordNotFoundError: true, + SlowThreshold: 100 * time.Millisecond, + }, + }, + Bus: config.Bus{ + AnnouncementMaxAgeHours: 24 * 7 * 52, // 1 year + Bootstrap: true, + GatewayAddr: ":9981", + UsedUTXOExpiry: 24 * time.Hour, + SlabBufferCompletionThreshold: 1 << 12, + }, + Worker: config.Worker{ + Enabled: true, + + ID: "worker", + ContractLockTimeout: 30 * time.Second, + BusFlushInterval: 5 * time.Second, + + DownloadMaxOverdrive: 5, + DownloadOverdriveTimeout: 3 * time.Second, + + DownloadMaxMemory: 1 << 30, // 1 GiB + UploadMaxMemory: 1 << 30, // 1 GiB + UploadMaxOverdrive: 5, + UploadOverdriveTimeout: 3 * time.Second, + }, + Autopilot: config.Autopilot{ + Enabled: true, + + ID: api.DefaultAutopilotID, + RevisionSubmissionBuffer: 150, // 144 + 6 blocks leeway + AccountsRefillInterval: defaultAccountRefillInterval, + Heartbeat: 30 * time.Minute, + MigrationHealthCutoff: 0.75, + RevisionBroadcastInterval: 7 * 24 * time.Hour, + ScannerBatchSize: 100, + ScannerInterval: 4 * time.Hour, + ScannerNumThreads: 10, + MigratorParallelSlabsPerWorker: 1, + }, + S3: config.S3{ + Address: "localhost:8080", + Enabled: true, + DisableAuth: false, + KeypairsV4: nil, + }, + } +} + +// loadConfig creates a default config and overrides it with the contents of the +// YAML file (specified by the RENTERD_CONFIG_FILE environment variable), CLI +// flags, and environment variables, in that order. +func loadConfig() (cfg config.Config, network *consensus.Network, genesis types.Block, err error) { + cfg = defaultConfig() + parseYamlConfig(&cfg) + parseCLIFlags(&cfg) + parseEnvironmentVariables(&cfg) + + // check network + switch cfg.Network { + case "anagami": + network, genesis = chain.TestnetAnagami() + case "mainnet": + network, genesis = chain.Mainnet() + case "zen": + network, genesis = chain.TestnetZen() + default: + err = fmt.Errorf("unknown network '%s'", cfg.Network) + return + } + + return +} + +func sanitizeConfig(cfg *config.Config) error { + // parse remotes + if workerRemoteAddrsStr != "" && workerRemotePassStr != "" { + cfg.Worker.Remotes = cfg.Worker.Remotes[:0] + for _, addr := range strings.Split(workerRemoteAddrsStr, ";") { + cfg.Worker.Remotes = append(cfg.Worker.Remotes, config.RemoteWorker{ + Address: addr, + Password: workerRemotePassStr, + }) + } + } + + // disable worker if remotes are set + if len(cfg.Worker.Remotes) > 0 { + cfg.Worker.Enabled = false + } + + // combine host bucket bases + for _, base := range strings.Split(hostBasesStr, ",") { + if trimmed := strings.TrimSpace(base); trimmed != "" { + cfg.S3.HostBucketBases = append(cfg.S3.HostBucketBases, trimmed) + } + } + + // check that the API password is set + if cfg.HTTP.Password == "" { + if disableStdin { + return errors.New("API password must be set via environment variable or config file when --env flag is set") + } + } + setAPIPassword(cfg) + + // check that the seed is set + if cfg.Seed == "" && (cfg.Worker.Enabled || cfg.Bus.RemoteAddr == "") { // only worker & bus require a seed + if disableStdin { + return errors.New("seed must be set via environment variable or config file when --env flag is set") + } + setSeedPhrase(cfg) + } + + // validate the seed + if cfg.Seed != "" { + var rawSeed [32]byte + if err := wallet.SeedFromPhrase(&rawSeed, cfg.Seed); err != nil { 
+
+func sanitizeConfig(cfg *config.Config) error {
+	// parse remotes
+	if workerRemoteAddrsStr != "" && workerRemotePassStr != "" {
+		cfg.Worker.Remotes = cfg.Worker.Remotes[:0]
+		for _, addr := range strings.Split(workerRemoteAddrsStr, ";") {
+			cfg.Worker.Remotes = append(cfg.Worker.Remotes, config.RemoteWorker{
+				Address:  addr,
+				Password: workerRemotePassStr,
+			})
+		}
+	}
+
+	// disable worker if remotes are set
+	if len(cfg.Worker.Remotes) > 0 {
+		cfg.Worker.Enabled = false
+	}
+
+	// combine host bucket bases
+	for _, base := range strings.Split(hostBasesStr, ",") {
+		if trimmed := strings.TrimSpace(base); trimmed != "" {
+			cfg.S3.HostBucketBases = append(cfg.S3.HostBucketBases, trimmed)
+		}
+	}
+
+	// check that the API password is set
+	if cfg.HTTP.Password == "" {
+		if disableStdin {
+			return errors.New("API password must be set via environment variable or config file when --env flag is set")
+		}
+	}
+	setAPIPassword(cfg)
+
+	// check that the seed is set
+	if cfg.Seed == "" && (cfg.Worker.Enabled || cfg.Bus.RemoteAddr == "") { // only worker & bus require a seed
+		if disableStdin {
+			return errors.New("seed must be set via environment variable or config file when --env flag is set")
+		}
+		setSeedPhrase(cfg)
+	}
+
+	// validate the seed
+	if cfg.Seed != "" {
+		var rawSeed [32]byte
+		if err := wallet.SeedFromPhrase(&rawSeed, cfg.Seed); err != nil {
+			return fmt.Errorf("failed to load wallet: %w", err)
+		}
+	}
+
+	// parse S3 auth keys
+	if cfg.S3.Enabled {
+		if !cfg.S3.DisableAuth && keyPairsV4 != "" {
+			var err error
+			cfg.S3.KeypairsV4, err = s3.Parsev4AuthKeys(strings.Split(keyPairsV4, ";"))
+			if err != nil {
+				return fmt.Errorf("failed to parse keypairs: %w", err)
+			}
+		}
+	}
+
+	// default log levels
+	if cfg.Log.Level == "" {
+		cfg.Log.Level = "info"
+	}
+	if cfg.Log.Database.Level == "" {
+		cfg.Log.Database.Level = cfg.Log.Level
+	}
+
+	return nil
+}
+
+func parseYamlConfig(cfg *config.Config) {
+	configPath := "renterd.yml"
+	if str := os.Getenv("RENTERD_CONFIG_FILE"); str != "" {
+		configPath = str
+	}
+
+	// If the config file doesn't exist, don't try to load it.
+	if _, err := os.Stat(configPath); os.IsNotExist(err) {
+		return
+	}
+
+	f, err := os.Open(configPath)
+	if err != nil {
+		log.Fatal("failed to open config file:", err)
+	}
+	defer f.Close()
+
+	dec := yaml.NewDecoder(f)
+	dec.KnownFields(true)
+
+	if err := dec.Decode(cfg); err != nil {
+		log.Fatal("failed to decode config file:", err)
+	}
+}
+
+func parseCLIFlags(cfg *config.Config) {
+	// deprecated - these go first so that they can be overwritten by the non-deprecated flags
+	flag.StringVar(&cfg.Log.Database.Level, "db.logger.logLevel", cfg.Log.Database.Level, "(deprecated) Logger level (overrides with RENTERD_DB_LOGGER_LOG_LEVEL)")
+	flag.BoolVar(&cfg.Database.Log.IgnoreRecordNotFoundError, "db.logger.ignoreNotFoundError", cfg.Database.Log.IgnoreRecordNotFoundError, "(deprecated) Ignores 'not found' errors in logger (overrides with RENTERD_DB_LOGGER_IGNORE_NOT_FOUND_ERROR)")
+	flag.DurationVar(&cfg.Database.Log.SlowThreshold, "db.logger.slowThreshold", cfg.Database.Log.SlowThreshold, "(deprecated) Threshold for slow queries in logger (overrides with RENTERD_DB_LOGGER_SLOW_THRESHOLD)")
+	flag.StringVar(&cfg.Log.Path, "log-path", cfg.Log.Path, "(deprecated) Path to directory for logs (overrides with RENTERD_LOG_PATH)")
+
+	// node
+	flag.StringVar(&cfg.HTTP.Address, "http", cfg.HTTP.Address, "Address for serving the API")
+	flag.StringVar(&cfg.Directory, "dir", cfg.Directory, "Directory for storing node state")
+	flag.BoolVar(&disableStdin, "env", false, "disable stdin prompts for environment variables (default false)")
+	flag.BoolVar(&cfg.AutoOpenWebUI, "openui", cfg.AutoOpenWebUI, "automatically open the web UI on startup")
+	flag.StringVar(&cfg.Network, "network", cfg.Network, "Network to connect to (mainnet|zen|anagami). Defaults to 'mainnet' (overrides with RENTERD_NETWORK)")
+
+	// logger
+	flag.StringVar(&cfg.Log.Level, "log.level", cfg.Log.Level, "Global logger level (debug|info|warn|error). Defaults to 'info' (overrides with RENTERD_LOG_LEVEL)")
+	flag.BoolVar(&cfg.Log.File.Enabled, "log.file.enabled", cfg.Log.File.Enabled, "Enables logging to disk. Defaults to 'true'. (overrides with RENTERD_LOG_FILE_ENABLED)")
+	flag.StringVar(&cfg.Log.File.Format, "log.file.format", cfg.Log.File.Format, "Format of log file (json|human). Defaults to 'json' (overrides with RENTERD_LOG_FILE_FORMAT)")
+	flag.StringVar(&cfg.Log.File.Path, "log.file.path", cfg.Log.File.Path, "Path of log file. Defaults to 'renterd.log' within the renterd directory. (overrides with RENTERD_LOG_FILE_PATH)")
+	flag.BoolVar(&cfg.Log.StdOut.Enabled, "log.stdout.enabled", cfg.Log.StdOut.Enabled, "Enables logging to stdout. Defaults to 'true'. 
(overrides with RENTERD_LOG_STDOUT_ENABLED)") + flag.StringVar(&cfg.Log.StdOut.Format, "log.stdout.format", cfg.Log.StdOut.Format, "Format of log output (json|human). Defaults to 'human' (overrides with RENTERD_LOG_STDOUT_FORMAT)") + flag.BoolVar(&cfg.Log.StdOut.EnableANSI, "log.stdout.enableANSI", cfg.Log.StdOut.EnableANSI, "Enables ANSI color codes in log output. Defaults to 'true' on non-Windows systems. (overrides with RENTERD_LOG_STDOUT_ENABLE_ANSI)") + flag.BoolVar(&cfg.Log.Database.Enabled, "log.database.enabled", cfg.Log.Database.Enabled, "Enable logging database queries. Defaults to 'true' (overrides with RENTERD_LOG_DATABASE_ENABLED)") + flag.StringVar(&cfg.Log.Database.Level, "log.database.level", cfg.Log.Database.Level, "Logger level for database queries (info|warn|error). Defaults to 'warn' (overrides with RENTERD_LOG_LEVEL and RENTERD_LOG_DATABASE_LEVEL)") + flag.BoolVar(&cfg.Log.Database.IgnoreRecordNotFoundError, "log.database.ignoreRecordNotFoundError", cfg.Log.Database.IgnoreRecordNotFoundError, "Enable ignoring 'not found' errors resulting from database queries. Defaults to 'true' (overrides with RENTERD_LOG_DATABASE_IGNORE_RECORD_NOT_FOUND_ERROR)") + flag.DurationVar(&cfg.Log.Database.SlowThreshold, "log.database.slowThreshold", cfg.Log.Database.SlowThreshold, "Threshold for slow queries in logger. Defaults to 100ms (overrides with RENTERD_LOG_DATABASE_SLOW_THRESHOLD)") + + // db + flag.StringVar(&cfg.Database.MySQL.URI, "db.uri", cfg.Database.MySQL.URI, "Database URI for the bus (overrides with RENTERD_DB_URI)") + flag.StringVar(&cfg.Database.MySQL.User, "db.user", cfg.Database.MySQL.User, "Database username for the bus (overrides with RENTERD_DB_USER)") + flag.StringVar(&cfg.Database.MySQL.Database, "db.name", cfg.Database.MySQL.Database, "Database name for the bus (overrides with RENTERD_DB_NAME)") + flag.StringVar(&cfg.Database.MySQL.MetricsDatabase, "db.metricsName", cfg.Database.MySQL.MetricsDatabase, "Database for metrics (overrides with RENTERD_DB_METRICS_NAME)") + + // bus + flag.Uint64Var(&cfg.Bus.AnnouncementMaxAgeHours, "bus.announcementMaxAgeHours", cfg.Bus.AnnouncementMaxAgeHours, "Max age for announcements") + flag.BoolVar(&cfg.Bus.Bootstrap, "bus.bootstrap", cfg.Bus.Bootstrap, "Bootstraps gateway and consensus modules") + flag.StringVar(&cfg.Bus.GatewayAddr, "bus.gatewayAddr", cfg.Bus.GatewayAddr, "Address for Sia peer connections (overrides with RENTERD_BUS_GATEWAY_ADDR)") + flag.DurationVar(&cfg.Bus.PersistInterval, "bus.persistInterval", cfg.Bus.PersistInterval, "(deprecated) Interval for persisting consensus updates") + flag.DurationVar(&cfg.Bus.UsedUTXOExpiry, "bus.usedUTXOExpiry", cfg.Bus.UsedUTXOExpiry, "Expiry for used UTXOs in transactions") + flag.Int64Var(&cfg.Bus.SlabBufferCompletionThreshold, "bus.slabBufferCompletionThreshold", cfg.Bus.SlabBufferCompletionThreshold, "Threshold for slab buffer upload (overrides with RENTERD_BUS_SLAB_BUFFER_COMPLETION_THRESHOLD)") + + // worker + flag.BoolVar(&cfg.Worker.AllowPrivateIPs, "worker.allowPrivateIPs", cfg.Worker.AllowPrivateIPs, "Allows hosts with private IPs") + flag.DurationVar(&cfg.Worker.BusFlushInterval, "worker.busFlushInterval", cfg.Worker.BusFlushInterval, "Interval for flushing data to bus") + flag.Uint64Var(&cfg.Worker.DownloadMaxMemory, "worker.downloadMaxMemory", cfg.Worker.DownloadMaxMemory, "Max amount of RAM the worker allocates for slabs when downloading (overrides with RENTERD_WORKER_DOWNLOAD_MAX_MEMORY)") + flag.Uint64Var(&cfg.Worker.DownloadMaxOverdrive, 
"worker.downloadMaxOverdrive", cfg.Worker.DownloadMaxOverdrive, "Max overdrive workers for downloads") + flag.StringVar(&cfg.Worker.ID, "worker.id", cfg.Worker.ID, "Unique ID for worker (overrides with RENTERD_WORKER_ID)") + flag.DurationVar(&cfg.Worker.DownloadOverdriveTimeout, "worker.downloadOverdriveTimeout", cfg.Worker.DownloadOverdriveTimeout, "Timeout for overdriving slab downloads") + flag.Uint64Var(&cfg.Worker.UploadMaxMemory, "worker.uploadMaxMemory", cfg.Worker.UploadMaxMemory, "Max amount of RAM the worker allocates for slabs when uploading (overrides with RENTERD_WORKER_UPLOAD_MAX_MEMORY)") + flag.Uint64Var(&cfg.Worker.UploadMaxOverdrive, "worker.uploadMaxOverdrive", cfg.Worker.UploadMaxOverdrive, "Max overdrive workers for uploads") + flag.DurationVar(&cfg.Worker.UploadOverdriveTimeout, "worker.uploadOverdriveTimeout", cfg.Worker.UploadOverdriveTimeout, "Timeout for overdriving slab uploads") + flag.BoolVar(&cfg.Worker.Enabled, "worker.enabled", cfg.Worker.Enabled, "Enables/disables worker (overrides with RENTERD_WORKER_ENABLED)") + flag.BoolVar(&cfg.Worker.AllowUnauthenticatedDownloads, "worker.unauthenticatedDownloads", cfg.Worker.AllowUnauthenticatedDownloads, "Allows unauthenticated downloads (overrides with RENTERD_WORKER_UNAUTHENTICATED_DOWNLOADS)") + flag.StringVar(&cfg.Worker.ExternalAddress, "worker.externalAddress", cfg.Worker.ExternalAddress, "Address of the worker on the network, only necessary when the bus is remote (overrides with RENTERD_WORKER_EXTERNAL_ADDR)") + + // autopilot + flag.DurationVar(&cfg.Autopilot.AccountsRefillInterval, "autopilot.accountRefillInterval", cfg.Autopilot.AccountsRefillInterval, "Interval for refilling workers' account balances") + flag.DurationVar(&cfg.Autopilot.Heartbeat, "autopilot.heartbeat", cfg.Autopilot.Heartbeat, "Interval for autopilot loop execution") + flag.Float64Var(&cfg.Autopilot.MigrationHealthCutoff, "autopilot.migrationHealthCutoff", cfg.Autopilot.MigrationHealthCutoff, "Threshold for migrating slabs based on health") + flag.DurationVar(&cfg.Autopilot.RevisionBroadcastInterval, "autopilot.revisionBroadcastInterval", cfg.Autopilot.RevisionBroadcastInterval, "Interval for broadcasting contract revisions (overrides with RENTERD_AUTOPILOT_REVISION_BROADCAST_INTERVAL)") + flag.Uint64Var(&cfg.Autopilot.ScannerBatchSize, "autopilot.scannerBatchSize", cfg.Autopilot.ScannerBatchSize, "Batch size for host scanning") + flag.DurationVar(&cfg.Autopilot.ScannerInterval, "autopilot.scannerInterval", cfg.Autopilot.ScannerInterval, "Interval for scanning hosts") + flag.Uint64Var(&cfg.Autopilot.ScannerNumThreads, "autopilot.scannerNumThreads", cfg.Autopilot.ScannerNumThreads, "Number of threads for scanning hosts") + flag.Uint64Var(&cfg.Autopilot.MigratorParallelSlabsPerWorker, "autopilot.migratorParallelSlabsPerWorker", cfg.Autopilot.MigratorParallelSlabsPerWorker, "Parallel slab migrations per worker (overrides with RENTERD_MIGRATOR_PARALLEL_SLABS_PER_WORKER)") + flag.BoolVar(&cfg.Autopilot.Enabled, "autopilot.enabled", cfg.Autopilot.Enabled, "Enables/disables autopilot (overrides with RENTERD_AUTOPILOT_ENABLED)") + flag.DurationVar(&cfg.ShutdownTimeout, "node.shutdownTimeout", cfg.ShutdownTimeout, "Timeout for node shutdown") + + // s3 + flag.StringVar(&cfg.S3.Address, "s3.address", cfg.S3.Address, "Address for serving S3 API (overrides with RENTERD_S3_ADDRESS)") + flag.BoolVar(&cfg.S3.DisableAuth, "s3.disableAuth", cfg.S3.DisableAuth, "Disables authentication for S3 API (overrides with RENTERD_S3_DISABLE_AUTH)") + 
flag.BoolVar(&cfg.S3.Enabled, "s3.enabled", cfg.S3.Enabled, "Enables/disables S3 API (requires worker.enabled to be 'true', overrides with RENTERD_S3_ENABLED)") + flag.StringVar(&hostBasesStr, "s3.hostBases", "", "Enables bucket rewriting in the router for specific hosts provided via comma-separated list (overrides with RENTERD_S3_HOST_BUCKET_BASES)") + flag.BoolVar(&cfg.S3.HostBucketEnabled, "s3.hostBucketEnabled", cfg.S3.HostBucketEnabled, "Enables bucket rewriting in the router for all hosts (overrides with RENTERD_S3_HOST_BUCKET_ENABLED)") + + // custom usage + flag.Usage = func() { + log.Print(usageHeader) + flag.PrintDefaults() + log.Print(usageFooter) + } + + flag.Parse() +} + +func parseEnvironmentVariables(cfg *config.Config) { + // define helper function to parse environment variables + parseEnvVar := func(s string, v interface{}) { + if env, ok := os.LookupEnv(s); ok { + if _, err := fmt.Sscan(env, v); err != nil { + log.Fatalf("failed to parse %s: %v", s, err) + } + fmt.Printf("Using %s environment variable\n", s) + } + } + + parseEnvVar("RENTERD_NETWORK", &cfg.Network) + + parseEnvVar("RENTERD_BUS_REMOTE_ADDR", &cfg.Bus.RemoteAddr) + parseEnvVar("RENTERD_BUS_API_PASSWORD", &cfg.Bus.RemotePassword) + parseEnvVar("RENTERD_BUS_GATEWAY_ADDR", &cfg.Bus.GatewayAddr) + parseEnvVar("RENTERD_BUS_SLAB_BUFFER_COMPLETION_THRESHOLD", &cfg.Bus.SlabBufferCompletionThreshold) + + parseEnvVar("RENTERD_DB_URI", &cfg.Database.MySQL.URI) + parseEnvVar("RENTERD_DB_USER", &cfg.Database.MySQL.User) + parseEnvVar("RENTERD_DB_PASSWORD", &cfg.Database.MySQL.Password) + parseEnvVar("RENTERD_DB_NAME", &cfg.Database.MySQL.Database) + parseEnvVar("RENTERD_DB_METRICS_NAME", &cfg.Database.MySQL.MetricsDatabase) + + parseEnvVar("RENTERD_DB_LOGGER_IGNORE_NOT_FOUND_ERROR", &cfg.Database.Log.IgnoreRecordNotFoundError) + parseEnvVar("RENTERD_DB_LOGGER_LOG_LEVEL", &cfg.Log.Level) + parseEnvVar("RENTERD_DB_LOGGER_SLOW_THRESHOLD", &cfg.Database.Log.SlowThreshold) + + parseEnvVar("RENTERD_WORKER_ENABLED", &cfg.Worker.Enabled) + parseEnvVar("RENTERD_WORKER_ID", &cfg.Worker.ID) + parseEnvVar("RENTERD_WORKER_UNAUTHENTICATED_DOWNLOADS", &cfg.Worker.AllowUnauthenticatedDownloads) + parseEnvVar("RENTERD_WORKER_DOWNLOAD_MAX_MEMORY", &cfg.Worker.DownloadMaxMemory) + parseEnvVar("RENTERD_WORKER_UPLOAD_MAX_MEMORY", &cfg.Worker.UploadMaxMemory) + parseEnvVar("RENTERD_WORKER_EXTERNAL_ADDR", &cfg.Worker.ExternalAddress) + + parseEnvVar("RENTERD_AUTOPILOT_ENABLED", &cfg.Autopilot.Enabled) + parseEnvVar("RENTERD_AUTOPILOT_REVISION_BROADCAST_INTERVAL", &cfg.Autopilot.RevisionBroadcastInterval) + parseEnvVar("RENTERD_MIGRATOR_PARALLEL_SLABS_PER_WORKER", &cfg.Autopilot.MigratorParallelSlabsPerWorker) + + parseEnvVar("RENTERD_S3_ADDRESS", &cfg.S3.Address) + parseEnvVar("RENTERD_S3_ENABLED", &cfg.S3.Enabled) + parseEnvVar("RENTERD_S3_DISABLE_AUTH", &cfg.S3.DisableAuth) + parseEnvVar("RENTERD_S3_HOST_BUCKET_ENABLED", &cfg.S3.HostBucketEnabled) + parseEnvVar("RENTERD_S3_HOST_BUCKET_BASES", &cfg.S3.HostBucketBases) + + parseEnvVar("RENTERD_LOG_PATH", &cfg.Log.Path) + parseEnvVar("RENTERD_LOG_LEVEL", &cfg.Log.Level) + parseEnvVar("RENTERD_LOG_FILE_ENABLED", &cfg.Log.File.Enabled) + parseEnvVar("RENTERD_LOG_FILE_FORMAT", &cfg.Log.File.Format) + parseEnvVar("RENTERD_LOG_FILE_PATH", &cfg.Log.File.Path) + parseEnvVar("RENTERD_LOG_STDOUT_ENABLED", &cfg.Log.StdOut.Enabled) + parseEnvVar("RENTERD_LOG_STDOUT_FORMAT", &cfg.Log.StdOut.Format) + parseEnvVar("RENTERD_LOG_STDOUT_ENABLE_ANSI", &cfg.Log.StdOut.EnableANSI) + 
parseEnvVar("RENTERD_LOG_DATABASE_ENABLED", &cfg.Log.Database.Enabled) + parseEnvVar("RENTERD_LOG_DATABASE_LEVEL", &cfg.Log.Database.Level) + parseEnvVar("RENTERD_LOG_DATABASE_IGNORE_RECORD_NOT_FOUND_ERROR", &cfg.Log.Database.IgnoreRecordNotFoundError) + parseEnvVar("RENTERD_LOG_DATABASE_SLOW_THRESHOLD", &cfg.Log.Database.SlowThreshold) + + parseEnvVar("RENTERD_WORKER_REMOTE_ADDRS", &workerRemoteAddrsStr) + parseEnvVar("RENTERD_WORKER_API_PASSWORD", &workerRemotePassStr) + + parseEnvVar("RENTERD_S3_KEYPAIRS_V4", &keyPairsV4) +} // readPasswordInput reads a password from stdin. func readPasswordInput(context string) string { @@ -147,7 +528,7 @@ func setListenAddress(context string, value *string, allowEmpty bool) { // setSeedPhrase prompts the user to enter a seed phrase if one is not already // set via environment variable or config file. -func setSeedPhrase() { +func setSeedPhrase(cfg *config.Config) { // retry until a valid seed phrase is entered for { fmt.Println("") @@ -203,7 +584,7 @@ func setSeedPhrase() { // setAPIPassword prompts the user to enter an API password if one is not // already set via environment variable or config file. -func setAPIPassword() { +func setAPIPassword(cfg *config.Config) { // return early if the password is already set if len(cfg.HTTP.Password) >= 4 { return @@ -224,7 +605,7 @@ func setAPIPassword() { } } -func setAdvancedConfig() { +func setAdvancedConfig(cfg *config.Config) { if !promptYesNo("Would you like to configure advanced settings?") { return } @@ -251,10 +632,10 @@ func setAdvancedConfig() { fmt.Println("The database is used to store the renter's metadata.") fmt.Println("The embedded SQLite database requires no additional configuration and is ideal for testing or demo purposes.") fmt.Println("For production usage, we recommend MySQL, which requires a separate MySQL server.") - setStoreConfig() + setStoreConfig(cfg) } -func setStoreConfig() { +func setStoreConfig(cfg *config.Config) { store := promptQuestion("Which data store would you like to use?", []string{"mysql", "sqlite"}) switch store { case "mysql": @@ -278,7 +659,7 @@ func setStoreConfig() { } } -func setS3Config() { +func setS3Config(cfg *config.Config) { if !promptYesNo("Would you like to configure S3 settings?") { return } diff --git a/cmd/renterd/logger.go b/cmd/renterd/logger.go index 4b21a1925..d107cc4a0 100644 --- a/cmd/renterd/logger.go +++ b/cmd/renterd/logger.go @@ -11,11 +11,11 @@ import ( "go.uber.org/zap/zapcore" ) -func NewLogger(dir string, cfg config.Log) (*zap.Logger, func(context.Context) error, error) { +func NewLogger(dir, filename string, cfg config.Log) (*zap.Logger, func(context.Context) error, error) { // path - path := filepath.Join(dir, "renterd.log") + path := filepath.Join(dir, filename) if cfg.Path != "" { - path = filepath.Join(cfg.Path, "renterd.log") + path = filepath.Join(cfg.Path, filename) } if cfg.File.Path != "" { diff --git a/cmd/renterd/main.go b/cmd/renterd/main.go index d0e75d680..a32ecceed 100644 --- a/cmd/renterd/main.go +++ b/cmd/renterd/main.go @@ -1,53 +1,14 @@ package main import ( - "context" - "encoding/json" - "errors" "flag" - "fmt" "log" - "net" - "net/http" "os" - "os/exec" "os/signal" - "path/filepath" - "runtime" - "strings" "syscall" - "time" - - "go.sia.tech/core/types" - "go.sia.tech/coreutils/wallet" - "go.sia.tech/jape" - "go.sia.tech/renterd/api" - "go.sia.tech/renterd/autopilot" - "go.sia.tech/renterd/build" - "go.sia.tech/renterd/bus" - "go.sia.tech/renterd/config" - "go.sia.tech/renterd/internal/node" - 
"go.sia.tech/renterd/internal/utils" - iworker "go.sia.tech/renterd/internal/worker" - "go.sia.tech/renterd/worker" - "go.sia.tech/renterd/worker/s3" - "go.sia.tech/web/renterd" - "go.uber.org/zap" - "golang.org/x/sys/cpu" - "gopkg.in/yaml.v3" ) const ( - // accountRefillInterval is the amount of time between refills of ephemeral - // accounts. If we conservatively assume that a good host charges 500 SC / - // TiB, we can pay for about 2.2 GiB with 1 SC. Since we want to refill - // ahead of time at 0.5 SC, that makes 1.1 GiB. Considering a 1 Gbps uplink - // that is shared across 30 uploads, we upload at around 33 Mbps to each - // host. That means uploading 1.1 GiB to drain 0.5 SC takes around 5 - // minutes. That's why we assume 10 seconds to be more than frequent enough - // to refill an account when it's due for another refill. - defaultAccountRefillInterval = 10 * time.Second - // usageHeader is the header for the CLI usage text. usageHeader = ` Renterd is the official Sia renter daemon. It provides a REST API for forming @@ -72,711 +33,58 @@ on how to configure and use renterd. ` ) -var ( - cfg = config.Config{ - Directory: ".", - Seed: os.Getenv("RENTERD_SEED"), - AutoOpenWebUI: true, - HTTP: config.HTTP{ - Address: build.DefaultAPIAddress, - Password: os.Getenv("RENTERD_API_PASSWORD"), - }, - ShutdownTimeout: 5 * time.Minute, - Database: config.Database{ - MySQL: config.MySQL{ - User: "renterd", - Database: "renterd", - MetricsDatabase: "renterd_metrics", - }, - }, - Log: config.Log{ - Path: "", // deprecated. included for compatibility. - Level: "", - File: config.LogFile{ - Enabled: true, - Format: "json", - Path: os.Getenv("RENTERD_LOG_FILE"), - }, - StdOut: config.StdOut{ - Enabled: true, - Format: "human", - EnableANSI: runtime.GOOS != "windows", - }, - Database: config.DatabaseLog{ - Enabled: true, - IgnoreRecordNotFoundError: true, - SlowThreshold: 100 * time.Millisecond, - }, - }, - Bus: config.Bus{ - AnnouncementMaxAgeHours: 24 * 7 * 52, // 1 year - Bootstrap: true, - GatewayAddr: build.DefaultGatewayAddress, - PersistInterval: time.Minute, - UsedUTXOExpiry: 24 * time.Hour, - SlabBufferCompletionThreshold: 1 << 12, - }, - Worker: config.Worker{ - Enabled: true, - - ID: "worker", - ContractLockTimeout: 30 * time.Second, - BusFlushInterval: 5 * time.Second, - - DownloadMaxOverdrive: 5, - DownloadOverdriveTimeout: 3 * time.Second, - - DownloadMaxMemory: 1 << 30, // 1 GiB - UploadMaxMemory: 1 << 30, // 1 GiB - UploadMaxOverdrive: 5, - UploadOverdriveTimeout: 3 * time.Second, - }, - Autopilot: config.Autopilot{ - Enabled: true, - RevisionSubmissionBuffer: 144, - AccountsRefillInterval: defaultAccountRefillInterval, - Heartbeat: 30 * time.Minute, - MigrationHealthCutoff: 0.75, - RevisionBroadcastInterval: 7 * 24 * time.Hour, - ScannerBatchSize: 1000, - ScannerInterval: 24 * time.Hour, - ScannerNumThreads: 100, - MigratorParallelSlabsPerWorker: 1, - }, - S3: config.S3{ - Address: build.DefaultS3Address, - Enabled: true, - DisableAuth: false, - KeypairsV4: nil, - }, - } - disableStdin bool -) - -func mustParseWorkers(workers, password string) { - if workers == "" { - return - } - // if the CLI flag/environment variable is set, overwrite the config file - cfg.Worker.Remotes = cfg.Worker.Remotes[:0] - for _, addr := range strings.Split(workers, ";") { - // note: duplicates the old behavior of all workers sharing the same - // password - cfg.Worker.Remotes = append(cfg.Worker.Remotes, config.RemoteWorker{ - Address: addr, - Password: password, - }) - } -} - -// tryLoadConfig loads 
the config file specified by the RENTERD_CONFIG_FILE -// environment variable. If the config file does not exist, it will not be -// loaded. -func tryLoadConfig() { - configPath := "renterd.yml" - if str := os.Getenv("RENTERD_CONFIG_FILE"); str != "" { - configPath = str - } - - // If the config file doesn't exist, don't try to load it. - if _, err := os.Stat(configPath); os.IsNotExist(err) { - return - } - - f, err := os.Open(configPath) - if err != nil { - log.Fatal("failed to open config file:", err) - } - defer f.Close() - - dec := yaml.NewDecoder(f) - dec.KnownFields(true) - - if err := dec.Decode(&cfg); err != nil { - log.Fatal("failed to decode config file:", err) - } -} - -func parseEnvVar(s string, v interface{}) { - if env, ok := os.LookupEnv(s); ok { - if _, err := fmt.Sscan(env, v); err != nil { - log.Fatalf("failed to parse %s: %v", s, err) - } - fmt.Printf("Using %s environment variable\n", s) - } -} - -func listenTCP(logger *zap.Logger, addr string) (net.Listener, error) { - l, err := net.Listen("tcp", addr) - if utils.IsErr(err, errors.New("no such host")) && strings.Contains(addr, "localhost") { - // fall back to 127.0.0.1 if 'localhost' doesn't work - _, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - fallbackAddr := fmt.Sprintf("127.0.0.1:%s", port) - logger.Sugar().Warnf("failed to listen on %s, falling back to %s", addr, fallbackAddr) - return net.Listen("tcp", fallbackAddr) - } else if err != nil { - return nil, err - } - return l, nil -} - func main() { log.SetFlags(0) - // load the YAML config first. CLI flags and environment variables will - // overwrite anything set in the config file. - tryLoadConfig() - - // deprecated - these go first so that they can be overwritten by the non-deprecated flags - flag.StringVar(&cfg.Log.Database.Level, "db.logger.logLevel", cfg.Log.Database.Level, "(deprecated) Logger level (overrides with RENTERD_DB_LOGGER_LOG_LEVEL)") - flag.BoolVar(&cfg.Database.Log.IgnoreRecordNotFoundError, "db.logger.ignoreNotFoundError", cfg.Database.Log.IgnoreRecordNotFoundError, "(deprecated) Ignores 'not found' errors in logger (overrides with RENTERD_DB_LOGGER_IGNORE_NOT_FOUND_ERROR)") - flag.DurationVar(&cfg.Database.Log.SlowThreshold, "db.logger.slowThreshold", cfg.Database.Log.SlowThreshold, "(deprecated) Threshold for slow queries in logger (overrides with RENTERD_DB_LOGGER_SLOW_THRESHOLD)") - flag.StringVar(&cfg.Log.Path, "log-path", cfg.Log.Path, "(deprecated) Path to directory for logs (overrides with RENTERD_LOG_PATH)") - - // node - flag.StringVar(&cfg.HTTP.Address, "http", cfg.HTTP.Address, "Address for serving the API") - flag.StringVar(&cfg.Directory, "dir", cfg.Directory, "Directory for storing node state") - flag.BoolVar(&disableStdin, "env", false, "disable stdin prompts for environment variables (default false)") - flag.BoolVar(&cfg.AutoOpenWebUI, "openui", cfg.AutoOpenWebUI, "automatically open the web UI on startup") - - // logger - flag.StringVar(&cfg.Log.Level, "log.level", cfg.Log.Level, "Global logger level (debug|info|warn|error). Defaults to 'info' (overrides with RENTERD_LOG_LEVEL)") - flag.BoolVar(&cfg.Log.File.Enabled, "log.file.enabled", cfg.Log.File.Enabled, "Enables logging to disk. Defaults to 'true'. (overrides with RENTERD_LOG_FILE_ENABLED)") - flag.StringVar(&cfg.Log.File.Format, "log.file.format", cfg.Log.File.Format, "Format of log file (json|human). 
Defaults to 'json' (overrides with RENTERD_LOG_FILE_FORMAT)") - flag.StringVar(&cfg.Log.File.Path, "log.file.path", cfg.Log.File.Path, "Path of log file. Defaults to 'renterd.log' within the renterd directory. (overrides with RENTERD_LOG_FILE_PATH)") - flag.BoolVar(&cfg.Log.StdOut.Enabled, "log.stdout.enabled", cfg.Log.StdOut.Enabled, "Enables logging to stdout. Defaults to 'true'. (overrides with RENTERD_LOG_STDOUT_ENABLED)") - flag.StringVar(&cfg.Log.StdOut.Format, "log.stdout.format", cfg.Log.StdOut.Format, "Format of log output (json|human). Defaults to 'human' (overrides with RENTERD_LOG_STDOUT_FORMAT)") - flag.BoolVar(&cfg.Log.StdOut.EnableANSI, "log.stdout.enableANSI", cfg.Log.StdOut.EnableANSI, "Enables ANSI color codes in log output. Defaults to 'true' on non-Windows systems. (overrides with RENTERD_LOG_STDOUT_ENABLE_ANSI)") - flag.BoolVar(&cfg.Log.Database.Enabled, "log.database.enabled", cfg.Log.Database.Enabled, "Enable logging database queries. Defaults to 'true' (overrides with RENTERD_LOG_DATABASE_ENABLED)") - flag.StringVar(&cfg.Log.Database.Level, "log.database.level", cfg.Log.Database.Level, "Logger level for database queries (info|warn|error). Defaults to 'warn' (overrides with RENTERD_LOG_LEVEL and RENTERD_LOG_DATABASE_LEVEL)") - flag.BoolVar(&cfg.Log.Database.IgnoreRecordNotFoundError, "log.database.ignoreRecordNotFoundError", cfg.Log.Database.IgnoreRecordNotFoundError, "Enable ignoring 'not found' errors resulting from database queries. Defaults to 'true' (overrides with RENTERD_LOG_DATABASE_IGNORE_RECORD_NOT_FOUND_ERROR)") - flag.DurationVar(&cfg.Log.Database.SlowThreshold, "log.database.slowThreshold", cfg.Log.Database.SlowThreshold, "Threshold for slow queries in logger. Defaults to 100ms (overrides with RENTERD_LOG_DATABASE_SLOW_THRESHOLD)") - - // db - flag.StringVar(&cfg.Database.MySQL.URI, "db.uri", cfg.Database.MySQL.URI, "Database URI for the bus (overrides with RENTERD_DB_URI)") - flag.StringVar(&cfg.Database.MySQL.User, "db.user", cfg.Database.MySQL.User, "Database username for the bus (overrides with RENTERD_DB_USER)") - flag.StringVar(&cfg.Database.MySQL.Database, "db.name", cfg.Database.MySQL.Database, "Database name for the bus (overrides with RENTERD_DB_NAME)") - flag.StringVar(&cfg.Database.MySQL.MetricsDatabase, "db.metricsName", cfg.Database.MySQL.MetricsDatabase, "Database for metrics (overrides with RENTERD_DB_METRICS_NAME)") - - // bus - flag.Uint64Var(&cfg.Bus.AnnouncementMaxAgeHours, "bus.announcementMaxAgeHours", cfg.Bus.AnnouncementMaxAgeHours, "Max age for announcements") - flag.BoolVar(&cfg.Bus.Bootstrap, "bus.bootstrap", cfg.Bus.Bootstrap, "Bootstraps gateway and consensus modules") - flag.StringVar(&cfg.Bus.GatewayAddr, "bus.gatewayAddr", cfg.Bus.GatewayAddr, "Address for Sia peer connections (overrides with RENTERD_BUS_GATEWAY_ADDR)") - flag.DurationVar(&cfg.Bus.PersistInterval, "bus.persistInterval", cfg.Bus.PersistInterval, "Interval for persisting consensus updates") - flag.DurationVar(&cfg.Bus.UsedUTXOExpiry, "bus.usedUTXOExpiry", cfg.Bus.UsedUTXOExpiry, "Expiry for used UTXOs in transactions") - flag.Int64Var(&cfg.Bus.SlabBufferCompletionThreshold, "bus.slabBufferCompletionThreshold", cfg.Bus.SlabBufferCompletionThreshold, "Threshold for slab buffer upload (overrides with RENTERD_BUS_SLAB_BUFFER_COMPLETION_THRESHOLD)") - - // worker - flag.BoolVar(&cfg.Worker.AllowPrivateIPs, "worker.allowPrivateIPs", cfg.Worker.AllowPrivateIPs, "Allows hosts with private IPs") - flag.DurationVar(&cfg.Worker.BusFlushInterval, 
"worker.busFlushInterval", cfg.Worker.BusFlushInterval, "Interval for flushing data to bus") - flag.Uint64Var(&cfg.Worker.DownloadMaxMemory, "worker.downloadMaxMemory", cfg.Worker.DownloadMaxMemory, "Max amount of RAM the worker allocates for slabs when downloading (overrides with RENTERD_WORKER_DOWNLOAD_MAX_MEMORY)") - flag.Uint64Var(&cfg.Worker.DownloadMaxOverdrive, "worker.downloadMaxOverdrive", cfg.Worker.DownloadMaxOverdrive, "Max overdrive workers for downloads") - flag.StringVar(&cfg.Worker.ID, "worker.id", cfg.Worker.ID, "Unique ID for worker (overrides with RENTERD_WORKER_ID)") - flag.DurationVar(&cfg.Worker.DownloadOverdriveTimeout, "worker.downloadOverdriveTimeout", cfg.Worker.DownloadOverdriveTimeout, "Timeout for overdriving slab downloads") - flag.Uint64Var(&cfg.Worker.UploadMaxMemory, "worker.uploadMaxMemory", cfg.Worker.UploadMaxMemory, "Max amount of RAM the worker allocates for slabs when uploading (overrides with RENTERD_WORKER_UPLOAD_MAX_MEMORY)") - flag.Uint64Var(&cfg.Worker.UploadMaxOverdrive, "worker.uploadMaxOverdrive", cfg.Worker.UploadMaxOverdrive, "Max overdrive workers for uploads") - flag.DurationVar(&cfg.Worker.UploadOverdriveTimeout, "worker.uploadOverdriveTimeout", cfg.Worker.UploadOverdriveTimeout, "Timeout for overdriving slab uploads") - flag.BoolVar(&cfg.Worker.Enabled, "worker.enabled", cfg.Worker.Enabled, "Enables/disables worker (overrides with RENTERD_WORKER_ENABLED)") - flag.BoolVar(&cfg.Worker.AllowUnauthenticatedDownloads, "worker.unauthenticatedDownloads", cfg.Worker.AllowUnauthenticatedDownloads, "Allows unauthenticated downloads (overrides with RENTERD_WORKER_UNAUTHENTICATED_DOWNLOADS)") - flag.StringVar(&cfg.Worker.ExternalAddress, "worker.externalAddress", cfg.Worker.ExternalAddress, "Address of the worker on the network, only necessary when the bus is remote (overrides with RENTERD_WORKER_EXTERNAL_ADDR)") - - // autopilot - flag.DurationVar(&cfg.Autopilot.AccountsRefillInterval, "autopilot.accountRefillInterval", cfg.Autopilot.AccountsRefillInterval, "Interval for refilling workers' account balances") - flag.DurationVar(&cfg.Autopilot.Heartbeat, "autopilot.heartbeat", cfg.Autopilot.Heartbeat, "Interval for autopilot loop execution") - flag.Float64Var(&cfg.Autopilot.MigrationHealthCutoff, "autopilot.migrationHealthCutoff", cfg.Autopilot.MigrationHealthCutoff, "Threshold for migrating slabs based on health") - flag.DurationVar(&cfg.Autopilot.RevisionBroadcastInterval, "autopilot.revisionBroadcastInterval", cfg.Autopilot.RevisionBroadcastInterval, "Interval for broadcasting contract revisions (overrides with RENTERD_AUTOPILOT_REVISION_BROADCAST_INTERVAL)") - flag.Uint64Var(&cfg.Autopilot.ScannerBatchSize, "autopilot.scannerBatchSize", cfg.Autopilot.ScannerBatchSize, "Batch size for host scanning") - flag.DurationVar(&cfg.Autopilot.ScannerInterval, "autopilot.scannerInterval", cfg.Autopilot.ScannerInterval, "Interval for scanning hosts") - flag.Uint64Var(&cfg.Autopilot.ScannerNumThreads, "autopilot.scannerNumThreads", cfg.Autopilot.ScannerNumThreads, "Number of threads for scanning hosts") - flag.Uint64Var(&cfg.Autopilot.MigratorParallelSlabsPerWorker, "autopilot.migratorParallelSlabsPerWorker", cfg.Autopilot.MigratorParallelSlabsPerWorker, "Parallel slab migrations per worker (overrides with RENTERD_MIGRATOR_PARALLEL_SLABS_PER_WORKER)") - flag.BoolVar(&cfg.Autopilot.Enabled, "autopilot.enabled", cfg.Autopilot.Enabled, "Enables/disables autopilot (overrides with RENTERD_AUTOPILOT_ENABLED)") - flag.DurationVar(&cfg.ShutdownTimeout, 
"node.shutdownTimeout", cfg.ShutdownTimeout, "Timeout for node shutdown") - - // s3 - var hostBasesStr string - flag.StringVar(&cfg.S3.Address, "s3.address", cfg.S3.Address, "Address for serving S3 API (overrides with RENTERD_S3_ADDRESS)") - flag.BoolVar(&cfg.S3.DisableAuth, "s3.disableAuth", cfg.S3.DisableAuth, "Disables authentication for S3 API (overrides with RENTERD_S3_DISABLE_AUTH)") - flag.BoolVar(&cfg.S3.Enabled, "s3.enabled", cfg.S3.Enabled, "Enables/disables S3 API (requires worker.enabled to be 'true', overrides with RENTERD_S3_ENABLED)") - flag.StringVar(&hostBasesStr, "s3.hostBases", "", "Enables bucket rewriting in the router for specific hosts provided via comma-separated list (overrides with RENTERD_S3_HOST_BUCKET_BASES)") - flag.BoolVar(&cfg.S3.HostBucketEnabled, "s3.hostBucketEnabled", cfg.S3.HostBucketEnabled, "Enables bucket rewriting in the router for all hosts (overrides with RENTERD_S3_HOST_BUCKET_ENABLED)") - - // custom usage - flag.Usage = func() { - log.Print(usageHeader) - flag.PrintDefaults() - log.Print(usageFooter) + // load the config + cfg, network, genesis, err := loadConfig() + if err != nil { + stdoutFatalError("failed to load config: " + err.Error()) } - flag.Parse() - // NOTE: update the usage header when adding new commands if flag.Arg(0) == "version" { - cmdVersion() + cmdVersion(network.Name) return } else if flag.Arg(0) == "seed" { cmdSeed() return } else if flag.Arg(0) == "config" { - cmdBuildConfig() + cmdBuildConfig(&cfg) return } else if flag.Arg(0) != "" { flag.Usage() return } - // Overwrite flags from environment if set. - parseEnvVar("RENTERD_BUS_REMOTE_ADDR", &cfg.Bus.RemoteAddr) - parseEnvVar("RENTERD_BUS_API_PASSWORD", &cfg.Bus.RemotePassword) - parseEnvVar("RENTERD_BUS_GATEWAY_ADDR", &cfg.Bus.GatewayAddr) - parseEnvVar("RENTERD_BUS_SLAB_BUFFER_COMPLETION_THRESHOLD", &cfg.Bus.SlabBufferCompletionThreshold) - - parseEnvVar("RENTERD_DB_URI", &cfg.Database.MySQL.URI) - parseEnvVar("RENTERD_DB_USER", &cfg.Database.MySQL.User) - parseEnvVar("RENTERD_DB_PASSWORD", &cfg.Database.MySQL.Password) - parseEnvVar("RENTERD_DB_NAME", &cfg.Database.MySQL.Database) - parseEnvVar("RENTERD_DB_METRICS_NAME", &cfg.Database.MySQL.MetricsDatabase) - - parseEnvVar("RENTERD_DB_LOGGER_IGNORE_NOT_FOUND_ERROR", &cfg.Database.Log.IgnoreRecordNotFoundError) - parseEnvVar("RENTERD_DB_LOGGER_LOG_LEVEL", &cfg.Log.Level) - parseEnvVar("RENTERD_DB_LOGGER_SLOW_THRESHOLD", &cfg.Database.Log.SlowThreshold) - - parseEnvVar("RENTERD_WORKER_ENABLED", &cfg.Worker.Enabled) - parseEnvVar("RENTERD_WORKER_ID", &cfg.Worker.ID) - parseEnvVar("RENTERD_WORKER_UNAUTHENTICATED_DOWNLOADS", &cfg.Worker.AllowUnauthenticatedDownloads) - parseEnvVar("RENTERD_WORKER_DOWNLOAD_MAX_MEMORY", &cfg.Worker.DownloadMaxMemory) - parseEnvVar("RENTERD_WORKER_UPLOAD_MAX_MEMORY", &cfg.Worker.UploadMaxMemory) - parseEnvVar("RENTERD_WORKER_EXTERNAL_ADDR", &cfg.Worker.ExternalAddress) - - parseEnvVar("RENTERD_AUTOPILOT_ENABLED", &cfg.Autopilot.Enabled) - parseEnvVar("RENTERD_AUTOPILOT_REVISION_BROADCAST_INTERVAL", &cfg.Autopilot.RevisionBroadcastInterval) - parseEnvVar("RENTERD_MIGRATOR_PARALLEL_SLABS_PER_WORKER", &cfg.Autopilot.MigratorParallelSlabsPerWorker) - - parseEnvVar("RENTERD_S3_ADDRESS", &cfg.S3.Address) - parseEnvVar("RENTERD_S3_ENABLED", &cfg.S3.Enabled) - parseEnvVar("RENTERD_S3_DISABLE_AUTH", &cfg.S3.DisableAuth) - parseEnvVar("RENTERD_S3_HOST_BUCKET_ENABLED", &cfg.S3.HostBucketEnabled) - parseEnvVar("RENTERD_S3_HOST_BUCKET_BASES", &cfg.S3.HostBucketBases) - - parseEnvVar("RENTERD_LOG_PATH", 
&cfg.Log.Path) - parseEnvVar("RENTERD_LOG_LEVEL", &cfg.Log.Level) - parseEnvVar("RENTERD_LOG_FILE_ENABLED", &cfg.Log.File.Enabled) - parseEnvVar("RENTERD_LOG_FILE_FORMAT", &cfg.Log.File.Format) - parseEnvVar("RENTERD_LOG_FILE_PATH", &cfg.Log.File.Path) - parseEnvVar("RENTERD_LOG_STDOUT_ENABLED", &cfg.Log.StdOut.Enabled) - parseEnvVar("RENTERD_LOG_STDOUT_FORMAT", &cfg.Log.StdOut.Format) - parseEnvVar("RENTERD_LOG_STDOUT_ENABLE_ANSI", &cfg.Log.StdOut.EnableANSI) - parseEnvVar("RENTERD_LOG_DATABASE_ENABLED", &cfg.Log.Database.Enabled) - parseEnvVar("RENTERD_LOG_DATABASE_LEVEL", &cfg.Log.Database.Level) - parseEnvVar("RENTERD_LOG_DATABASE_IGNORE_RECORD_NOT_FOUND_ERROR", &cfg.Log.Database.IgnoreRecordNotFoundError) - parseEnvVar("RENTERD_LOG_DATABASE_SLOW_THRESHOLD", &cfg.Log.Database.SlowThreshold) - - // parse remotes - var workerRemotePassStr string - var workerRemoteAddrsStr string - parseEnvVar("RENTERD_WORKER_REMOTE_ADDRS", &workerRemoteAddrsStr) - parseEnvVar("RENTERD_WORKER_API_PASSWORD", &workerRemotePassStr) - if workerRemoteAddrsStr != "" && workerRemotePassStr != "" { - mustParseWorkers(workerRemoteAddrsStr, workerRemotePassStr) - } - - // disable worker if remotes are set - if len(cfg.Worker.Remotes) > 0 { - cfg.Worker.Enabled = false + // sanitize the config + if err := sanitizeConfig(&cfg); err != nil { + stdoutFatalError("failed to sanitize config: " + err.Error()) } - // combine host bucket bases - for _, base := range strings.Split(hostBasesStr, ",") { - if trimmed := strings.TrimSpace(base); trimmed != "" { - cfg.S3.HostBucketBases = append(cfg.S3.HostBucketBases, base) - } - } - - // check that the API password is set - if cfg.HTTP.Password == "" { - if disableStdin { - stdoutFatalError("API password must be set via environment variable or config file when --env flag is set") - return - } - } - setAPIPassword() - - // check that the seed is set - if cfg.Seed == "" && (cfg.Worker.Enabled || cfg.Bus.RemoteAddr == "") { // only worker & bus require a seed - if disableStdin { - stdoutFatalError("Seed must be set via environment variable or config file when --env flag is set") - return - } - setSeedPhrase() - } - - // generate private key from seed - var pk types.PrivateKey - if cfg.Seed != "" { - var rawSeed [32]byte - if err := wallet.SeedFromPhrase(&rawSeed, cfg.Seed); err != nil { - log.Fatal("failed to load wallet", zap.Error(err)) - } - pk = wallet.KeyFromSeed(&rawSeed, 0) - } - - // parse S3 auth keys - if cfg.S3.Enabled { - var keyPairsV4 string - parseEnvVar("RENTERD_S3_KEYPAIRS_V4", &keyPairsV4) - if !cfg.S3.DisableAuth && keyPairsV4 != "" { - var err error - cfg.S3.KeypairsV4, err = s3.Parsev4AuthKeys(strings.Split(keyPairsV4, ";")) - if err != nil { - log.Fatalf("failed to parse keypairs: %v", err) - } - } - } - - // create logger - if cfg.Log.Level == "" { - cfg.Log.Level = "info" // default to 'info' if not set - } - logger, closeFn, err := NewLogger(cfg.Directory, cfg.Log) + // create node + node, err := newNode(cfg, network, genesis) if err != nil { - log.Fatalln("failed to create logger:", err) - } - defer closeFn(context.Background()) - - logger.Info("renterd", zap.String("version", build.Version()), zap.String("network", build.NetworkName()), zap.String("commit", build.Commit()), zap.Time("buildDate", build.BuildTime())) - if runtime.GOARCH == "amd64" && !cpu.X86.HasAVX2 { - logger.Warn("renterd is running on a system without AVX2 support, performance may be degraded") - } - - if cfg.Log.Database.Level == "" { - cfg.Log.Database.Level = cfg.Log.Level + 
stdoutFatalError("failed to create node: " + err.Error()) } - network, _ := build.Network() - busCfg := node.BusConfig{ - Bus: cfg.Bus, - Database: cfg.Database, - DatabaseLog: cfg.Log.Database, - Logger: logger, - Network: network, - } - - type shutdownFnEntry struct { - name string - fn func(context.Context) error - } - var shutdownFns []shutdownFnEntry - - if cfg.Bus.RemoteAddr != "" && !cfg.Worker.Enabled && !cfg.Autopilot.Enabled { - logger.Fatal("remote bus, remote worker, and no autopilot -- nothing to do!") - } - if cfg.Worker.Enabled && cfg.Bus.RemoteAddr != "" && cfg.Worker.ExternalAddress == "" { - logger.Fatal("can't enable the worker using a remote bus, without configuring the worker's external address") - } - if cfg.Autopilot.Enabled && !cfg.Worker.Enabled && len(cfg.Worker.Remotes) == 0 { - logger.Fatal("can't enable autopilot without providing either workers to connect to or creating a worker") - } - - // create listener first, so that we know the actual apiAddr if the user - // specifies port :0 - l, err := listenTCP(logger, cfg.HTTP.Address) - if err != nil { - logger.Fatal("failed to create listener: " + err.Error()) - } - - // override the address with the actual one - cfg.HTTP.Address = "http://" + l.Addr().String() - - auth := jape.BasicAuth(cfg.HTTP.Password) - mux := &utils.TreeMux{ - Sub: make(map[string]utils.TreeMux), - } - - // Create the webserver. - srv := &http.Server{Handler: mux} - shutdownFns = append(shutdownFns, shutdownFnEntry{ - name: "HTTP Server", - fn: srv.Shutdown, - }) - - if err := os.MkdirAll(cfg.Directory, 0700); err != nil { - logger.Fatal("failed to create directory: " + err.Error()) - } - - busAddr, busPassword := cfg.Bus.RemoteAddr, cfg.Bus.RemotePassword - setupBusFn := node.NoopFn - if cfg.Bus.RemoteAddr == "" { - b, setupFn, shutdownFn, err := node.NewBus(busCfg, cfg.Directory, pk, logger) - if err != nil { - logger.Fatal("failed to create bus, err: " + err.Error()) - } - setupBusFn = setupFn - shutdownFns = append(shutdownFns, shutdownFnEntry{ - name: "Bus", - fn: shutdownFn, - }) - - mux.Sub["/api/bus"] = utils.TreeMux{Handler: auth(b)} - busAddr = cfg.HTTP.Address + "/api/bus" - busPassword = cfg.HTTP.Password - - // only serve the UI if a bus is created - mux.Handler = renterd.Handler() - } else { - logger.Info("connecting to remote bus at " + busAddr) - } - bc := bus.NewClient(busAddr, busPassword) - - var s3Srv *http.Server - var s3Listener net.Listener - var workers []autopilot.Worker - setupWorkerFn := node.NoopFn - if len(cfg.Worker.Remotes) == 0 { - if cfg.Worker.Enabled { - workerAddr := cfg.HTTP.Address + "/api/worker" - var shutdownFn node.ShutdownFn - w, s3Handler, setupFn, shutdownFn, err := node.NewWorker(cfg.Worker, s3.Opts{ - AuthDisabled: cfg.S3.DisableAuth, - HostBucketBases: cfg.S3.HostBucketBases, - HostBucketEnabled: cfg.S3.HostBucketEnabled, - }, bc, pk, logger) - if err != nil { - logger.Fatal("failed to create worker: " + err.Error()) - } - var workerExternAddr string - if cfg.Bus.RemoteAddr != "" { - workerExternAddr = cfg.Worker.ExternalAddress - } else { - workerExternAddr = workerAddr - } - setupWorkerFn = func(ctx context.Context) error { - return setupFn(ctx, workerExternAddr, cfg.HTTP.Password) - } - shutdownFns = append(shutdownFns, shutdownFnEntry{ - name: "Worker", - fn: shutdownFn, - }) - - mux.Sub["/api/worker"] = utils.TreeMux{Handler: iworker.Auth(cfg.HTTP.Password, cfg.Worker.AllowUnauthenticatedDownloads)(w)} - wc := worker.NewClient(workerAddr, cfg.HTTP.Password) - workers = append(workers, wc) 
- - if cfg.S3.Enabled { - s3Srv = &http.Server{ - Addr: cfg.S3.Address, - Handler: s3Handler, - } - s3Listener, err = listenTCP(logger, cfg.S3.Address) - if err != nil { - logger.Fatal("failed to create listener: " + err.Error()) - } - shutdownFns = append(shutdownFns, shutdownFnEntry{ - name: "S3", - fn: s3Srv.Shutdown, - }) - } - } - } else { - for _, remote := range cfg.Worker.Remotes { - workers = append(workers, worker.NewClient(remote.Address, remote.Password)) - logger.Info("connecting to remote worker at " + remote.Address) - } - } - - autopilotErr := make(chan error, 1) - autopilotDir := filepath.Join(cfg.Directory, api.DefaultAutopilotID) - if cfg.Autopilot.Enabled { - apCfg := node.AutopilotConfig{ - ID: api.DefaultAutopilotID, - Autopilot: cfg.Autopilot, - } - ap, runFn, fn, err := node.NewAutopilot(apCfg, bc, workers, logger) - if err != nil { - logger.Fatal("failed to create autopilot: " + err.Error()) - } - - // NOTE: the autopilot shutdown function needs to be called first. - shutdownFns = append(shutdownFns, shutdownFnEntry{ - name: "Autopilot", - fn: fn, - }) - - go func() { autopilotErr <- runFn() }() - mux.Sub["/api/autopilot"] = utils.TreeMux{Handler: auth(ap)} - } - - // Start server. - go srv.Serve(l) - - // Finish bus setup. - if err := setupBusFn(context.Background()); err != nil { - logger.Fatal("failed to setup bus: " + err.Error()) - } - - // Finish worker setup. - if err := setupWorkerFn(context.Background()); err != nil { - logger.Fatal("failed to setup worker: " + err.Error()) - } - - // Set initial S3 keys. - if cfg.S3.Enabled && !cfg.S3.DisableAuth { - as, err := bc.S3AuthenticationSettings(context.Background()) - if err != nil && !strings.Contains(err.Error(), api.ErrSettingNotFound.Error()) { - logger.Fatal("failed to fetch S3 authentication settings: " + err.Error()) - } else if as.V4Keypairs == nil { - as.V4Keypairs = make(map[string]string) - } - - // S3 key pair validation was broken at one point, we need to remove the - // invalid key pairs here to ensure we don't fail when we update the - // setting below. 
- for k, v := range as.V4Keypairs { - if err := (api.S3AuthenticationSettings{V4Keypairs: map[string]string{k: v}}).Validate(); err != nil { - logger.Sugar().Infof("removing invalid S3 keypair for AccessKeyID %s, reason: %v", k, err) - delete(as.V4Keypairs, k) - } - } - - // merge keys - for k, v := range cfg.S3.KeypairsV4 { - as.V4Keypairs[k] = v - } - // update settings - if err := bc.UpdateSetting(context.Background(), api.SettingS3Authentication, as); err != nil { - logger.Fatal("failed to update S3 authentication settings: " + err.Error()) - } - } - - logger.Info("api: Listening on " + l.Addr().String()) - - if s3Srv != nil { - go s3Srv.Serve(s3Listener) - logger.Info("s3: Listening on " + s3Listener.Addr().String()) - } - - syncerAddress, err := bc.SyncerAddress(context.Background()) + // start node + err = node.Run() if err != nil { - logger.Fatal("failed to fetch syncer address: " + err.Error()) - } - logger.Info("bus: Listening on " + syncerAddress) - - if cfg.Autopilot.Enabled { - if err := runCompatMigrateAutopilotJSONToStore(bc, "autopilot", autopilotDir); err != nil { - logger.Fatal("failed to migrate autopilot JSON: " + err.Error()) - } - } - - if cfg.AutoOpenWebUI { - time.Sleep(time.Millisecond) // give the web server a chance to start - _, port, err := net.SplitHostPort(l.Addr().String()) - if err != nil { - logger.Debug("failed to parse API address", zap.Error(err)) - } else if err := openBrowser(fmt.Sprintf("http://127.0.0.1:%s", port)); err != nil { - logger.Debug("failed to open browser", zap.Error(err)) - } + stdoutFatalError("failed to run node: " + err.Error()) } + // wait for interrupt signal signalCh := make(chan os.Signal, 1) signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - select { - case <-signalCh: - logger.Info("Shutting down...") - case err := <-autopilotErr: - logger.Fatal("Fatal autopilot error: " + err.Error()) - } - - // Give each service a fraction of the total shutdown timeout. One service - // timing out shouldn't prevent the others from attempting a shutdown. 
- timeout := cfg.ShutdownTimeout / time.Duration(len(shutdownFns)) - shutdown := func(fn func(ctx context.Context) error) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return fn(ctx) - } - - // Shut down the autopilot first, then the rest of the services in reverse order and then - exitCode := 0 - for i := len(shutdownFns) - 1; i >= 0; i-- { - if err := shutdown(shutdownFns[i].fn); err != nil { - logger.Sugar().Errorf("Failed to shut down %v: %v", shutdownFns[i].name, err) - exitCode = 1 - } else { - logger.Sugar().Infof("%v shut down successfully", shutdownFns[i].name) - } - } - logger.Info("Shutdown complete") - os.Exit(exitCode) -} + <-signalCh -func openBrowser(url string) error { - switch runtime.GOOS { - case "linux": - return exec.Command("xdg-open", url).Start() - case "windows": - return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - return exec.Command("open", url).Start() - default: - return fmt.Errorf("unsupported platform %q", runtime.GOOS) - } -} - -func runCompatMigrateAutopilotJSONToStore(bc *bus.Client, id, dir string) (err error) { - // check if the file exists - path := filepath.Join(dir, "autopilot.json") - if _, err := os.Stat(path); os.IsNotExist(err) { - return nil - } - - // defer autopilot dir cleanup - defer func() { - if err == nil { - log.Println("migration: removing autopilot directory") - if err = os.RemoveAll(dir); err == nil { - log.Println("migration: done") - } - } - }() - - // read the json config - log.Println("migration: reading autopilot.json") - //nolint:tagliatelle - var cfg struct { - Config api.AutopilotConfig `json:"Config"` - } - if data, err := os.ReadFile(path); err != nil { - return err - } else if err := json.Unmarshal(data, &cfg); err != nil { - return err - } - - // make sure we don't hang - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // check if the autopilot already exists, if so we don't need to migrate - _, err = bc.Autopilot(ctx, api.DefaultAutopilotID) - if err == nil { - log.Printf("migration: autopilot already exists in the bus, the autopilot.json won't be migrated\n old config: %+v\n", cfg.Config) - return nil - } - - // create an autopilot entry - log.Println("migration: persisting autopilot to the bus") - if err := bc.UpdateAutopilot(ctx, api.Autopilot{ - ID: id, - Config: cfg.Config, - }); err != nil { - return err - } - - // remove autopilot folder and config - log.Println("migration: cleaning up autopilot directory") - if err = os.RemoveAll(dir); err == nil { - log.Println("migration: done") + // shut down the node + err = node.Shutdown() + if err != nil { + os.Exit(1) + return } - return nil + os.Exit(0) } diff --git a/cmd/renterd/node.go b/cmd/renterd/node.go new file mode 100644 index 000000000..89dd75ab0 --- /dev/null +++ b/cmd/renterd/node.go @@ -0,0 +1,600 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "go.sia.tech/core/consensus" + "go.sia.tech/core/gateway" + "go.sia.tech/core/types" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/jape" + "go.sia.tech/renterd/alerts" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/autopilot" + "go.sia.tech/renterd/build" + "go.sia.tech/renterd/bus" + "go.sia.tech/renterd/config" + "go.sia.tech/renterd/internal/utils" + "go.sia.tech/renterd/stores" + 
"go.sia.tech/renterd/stores/sql" + "go.sia.tech/renterd/stores/sql/mysql" + "go.sia.tech/renterd/stores/sql/sqlite" + "go.sia.tech/renterd/webhooks" + "go.sia.tech/renterd/worker" + "go.sia.tech/renterd/worker/s3" + "go.sia.tech/web/renterd" + "go.uber.org/zap" + "golang.org/x/crypto/blake2b" + "golang.org/x/sys/cpu" +) + +type ( + node struct { + cfg config.Config + + apiSrv *http.Server + apiListener net.Listener + + s3Srv *http.Server + s3Listener net.Listener + + setupFns []fn + shutdownFns []fn + + bus *bus.Client + logger *zap.SugaredLogger + } + + fn struct { + name string + fn func(context.Context) error + } +) + +func newNode(cfg config.Config, network *consensus.Network, genesis types.Block) (*node, error) { + var setupFns, shutdownFns []fn + + // validate config + if cfg.Bus.RemoteAddr != "" && !cfg.Worker.Enabled && !cfg.Autopilot.Enabled { + return nil, errors.New("remote bus, remote worker, and no autopilot -- nothing to do!") + } + if cfg.Worker.Enabled && cfg.Bus.RemoteAddr != "" && cfg.Worker.ExternalAddress == "" { + return nil, errors.New("can't enable the worker using a remote bus, without configuring the worker's external address") + } + if cfg.Autopilot.Enabled && !cfg.Worker.Enabled && len(cfg.Worker.Remotes) == 0 { + return nil, errors.New("can't enable autopilot without providing either workers to connect to or creating a worker") + } + + // initialise directory + err := os.MkdirAll(cfg.Directory, 0700) + if err != nil { + return nil, fmt.Errorf("failed to create directory: %w", err) + } + + // initialise logger + logger, closeFn, err := NewLogger(cfg.Directory, "renterd.log", cfg.Log) + if err != nil { + return nil, fmt.Errorf("failed to create logger: %w", err) + } + shutdownFns = append(shutdownFns, fn{ + name: "Logger", + fn: closeFn, + }) + + // print network and version + logger.Info("renterd", zap.String("version", build.Version()), zap.String("network", network.Name), zap.String("commit", build.Commit()), zap.Time("buildDate", build.BuildTime())) + if runtime.GOARCH == "amd64" && !cpu.X86.HasAVX2 { + logger.Warn("renterd is running on a system without AVX2 support, performance may be degraded") + } + + // initialise a listener and override the HTTP address, we have to do this + // first so we know the actual api address if the user specifies port :0 + l, err := utils.ListenTCP(cfg.HTTP.Address, logger) + if err != nil { + return nil, fmt.Errorf("failed to create listener: %w", err) + } + cfg.HTTP.Address = "http://" + l.Addr().String() + + // initialise a web server + mux := &utils.TreeMux{Sub: make(map[string]utils.TreeMux)} + srv := &http.Server{Handler: mux} + shutdownFns = append(shutdownFns, fn{ + name: "HTTP Server", + fn: srv.Shutdown, + }) + + // initialise auth handler + auth := jape.BasicAuth(cfg.HTTP.Password) + + // generate private key from seed + var pk types.PrivateKey + if cfg.Seed != "" { + var rawSeed [32]byte + err := wallet.SeedFromPhrase(&rawSeed, cfg.Seed) + if err != nil { + return nil, fmt.Errorf("failed to load wallet: %w", err) + } + pk = wallet.KeyFromSeed(&rawSeed, 0) + } + + // initialise bus + busAddr, busPassword := cfg.Bus.RemoteAddr, cfg.Bus.RemotePassword + if cfg.Bus.RemoteAddr == "" { + // ensure we don't hang indefinitely + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + // create bus + b, shutdownFn, err := newBus(ctx, cfg, pk, network, genesis, logger) + if err != nil { + return nil, err + } + + shutdownFns = append(shutdownFns, fn{ + name: "Bus", + fn: shutdownFn, + }) + + 
+		mux.Sub["/api/bus"] = utils.TreeMux{Handler: auth(b.Handler())}
+		busAddr = cfg.HTTP.Address + "/api/bus"
+		busPassword = cfg.HTTP.Password
+
+		// only serve the UI if a bus is created
+		mux.Handler = renterd.Handler()
+	} else {
+		logger.Info("connecting to remote bus at " + busAddr)
+	}
+	bc := bus.NewClient(busAddr, busPassword)
+
+	// initialise workers
+	var s3Srv *http.Server
+	var s3Listener net.Listener
+	var workers []autopilot.Worker
+	if len(cfg.Worker.Remotes) == 0 {
+		if cfg.Worker.Enabled {
+			workerAddr := cfg.HTTP.Address + "/api/worker"
+			var workerExternAddr string
+			if cfg.Bus.RemoteAddr != "" {
+				workerExternAddr = cfg.Worker.ExternalAddress
+			} else {
+				workerExternAddr = workerAddr
+			}
+
+			workerKey := blake2b.Sum256(append([]byte("worker"), pk...))
+			w, err := worker.New(cfg.Worker, workerKey, bc, logger)
+			if err != nil {
+				logger.Fatal("failed to create worker: " + err.Error())
+			}
+			setupFns = append(setupFns, fn{
+				name: "Worker",
+				fn: func(ctx context.Context) error {
+					return w.Setup(ctx, workerExternAddr, cfg.HTTP.Password)
+				},
+			})
+			shutdownFns = append(shutdownFns, fn{
+				name: "Worker",
+				fn:   w.Shutdown,
+			})
+
+			mux.Sub["/api/worker"] = utils.TreeMux{Handler: utils.Auth(cfg.HTTP.Password, cfg.Worker.AllowUnauthenticatedDownloads)(w.Handler())}
+			wc := worker.NewClient(workerAddr, cfg.HTTP.Password)
+			workers = append(workers, wc)
+
+			if cfg.S3.Enabled {
+				s3Handler, err := s3.New(bc, w, logger, s3.Opts{
+					AuthDisabled:      cfg.S3.DisableAuth,
+					HostBucketBases:   cfg.S3.HostBucketBases,
+					HostBucketEnabled: cfg.S3.HostBucketEnabled,
+				})
+				if err != nil {
+					err = errors.Join(err, w.Shutdown(context.Background()))
+					logger.Fatal("failed to create s3 handler: " + err.Error())
+				}
+
+				s3Srv = &http.Server{
+					Addr:    cfg.S3.Address,
+					Handler: s3Handler,
+				}
+				s3Listener, err = utils.ListenTCP(cfg.S3.Address, logger)
+				if err != nil {
+					logger.Fatal("failed to create listener: " + err.Error())
+				}
+				shutdownFns = append(shutdownFns, fn{
+					name: "S3",
+					fn:   s3Srv.Shutdown,
+				})
+			}
+		}
+	} else {
+		for _, remote := range cfg.Worker.Remotes {
+			workers = append(workers, worker.NewClient(remote.Address, remote.Password))
+			logger.Info("connecting to remote worker at " + remote.Address)
+		}
+	}
+
+	// initialise autopilot
+	if cfg.Autopilot.Enabled {
+		ap, err := autopilot.New(cfg.Autopilot, bc, workers, logger)
+		if err != nil {
+			logger.Fatal("failed to create autopilot: " + err.Error())
+		}
+		setupFns = append(setupFns, fn{
+			name: "Autopilot",
+			fn:   func(_ context.Context) error { go ap.Run(); return nil },
+		})
+		shutdownFns = append(shutdownFns, fn{
+			name: "Autopilot",
+			fn:   ap.Shutdown,
+		})
+
+		mux.Sub["/api/autopilot"] = utils.TreeMux{Handler: auth(ap.Handler())}
+	}
+
+	return &node{
+		apiSrv:      srv,
+		apiListener: l,
+
+		s3Srv:      s3Srv,
+		s3Listener: s3Listener,
+
+		setupFns:    setupFns,
+		shutdownFns: shutdownFns,
+
+		bus: bc,
+		cfg: cfg,
+
+		logger: logger.Sugar(),
+	}, nil
+}
+
+func newBus(ctx context.Context, cfg config.Config, pk types.PrivateKey, network *consensus.Network, genesis types.Block, logger *zap.Logger) (*bus.Bus, func(ctx context.Context) error, error) {
+	// create store
+	alertsMgr := alerts.NewManager()
+	storeCfg, err := buildStoreConfig(alertsMgr, cfg, pk, logger)
+	if err != nil {
+		return nil, nil, err
+	}
+	sqlStore, err := stores.NewSQLStore(storeCfg)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// create webhooks manager
+	wh, err := webhooks.NewManager(sqlStore, logger)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// hook up webhooks <-> alerts
+
+func newBus(ctx context.Context, cfg config.Config, pk types.PrivateKey, network *consensus.Network, genesis types.Block, logger *zap.Logger) (*bus.Bus, func(ctx context.Context) error, error) {
+	// create store
+	alertsMgr := alerts.NewManager()
+	storeCfg, err := buildStoreConfig(alertsMgr, cfg, pk, logger)
+	if err != nil {
+		return nil, nil, err
+	}
+	sqlStore, err := stores.NewSQLStore(storeCfg)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// create webhooks manager
+	wh, err := webhooks.NewManager(sqlStore, logger)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// hookup webhooks <-> alerts
+	alertsMgr.RegisterWebhookBroadcaster(wh)
+
+	// create consensus directory
+	consensusDir := filepath.Join(cfg.Directory, "consensus")
+	if err := os.MkdirAll(consensusDir, 0700); err != nil {
+		return nil, nil, err
+	}
+
+	// migrate consensus database if necessary
+	migrateConsensusDatabase(ctx, sqlStore, consensusDir, logger)
+
+	// reset chain state if blockchain.db does not exist to make sure deleting
+	// it forces a resync
+	chainPath := filepath.Join(consensusDir, "blockchain.db")
+	if _, err := os.Stat(chainPath); os.IsNotExist(err) {
+		if err := sqlStore.ResetChainState(context.Background()); err != nil {
+			return nil, nil, err
+		}
+	}
+
+	// create chain database
+	bdb, err := coreutils.OpenBoltChainDB(chainPath)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to open chain database: %w", err)
+	}
+
+	// create chain manager
+	store, state, err := chain.NewDBStore(bdb, network, genesis)
+	if err != nil {
+		return nil, nil, err
+	}
+	cm := chain.NewManager(store, state)
+
+	// create wallet
+	w, err := wallet.NewSingleAddressWallet(pk, cm, sqlStore, wallet.WithReservationDuration(cfg.Bus.UsedUTXOExpiry))
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// bootstrap the syncer
+	if cfg.Bus.Bootstrap {
+		var peers []string
+		switch network.Name {
+		case "mainnet":
+			peers = syncer.MainnetBootstrapPeers
+		case "zen":
+			peers = syncer.ZenBootstrapPeers
+		case "anagami":
+			peers = syncer.AnagamiBootstrapPeers
+		default:
+			return nil, nil, fmt.Errorf("no available bootstrap peers for unknown network '%s'", network.Name)
+		}
+		for _, addr := range peers {
+			if err := sqlStore.AddPeer(addr); err != nil {
+				return nil, nil, fmt.Errorf("%w: failed to add bootstrap peer '%s'", err, addr)
+			}
+		}
+	}
+
+	// create syncer, peers will reject us if our hostname is empty or
+	// unspecified, so use loopback
+	l, err := net.Listen("tcp", cfg.Bus.GatewayAddr)
+	if err != nil {
+		return nil, nil, err
+	}
+	syncerAddr := l.Addr().String()
+	host, port, _ := net.SplitHostPort(syncerAddr)
+	if ip := net.ParseIP(host); ip == nil || ip.IsUnspecified() {
+		syncerAddr = net.JoinHostPort("127.0.0.1", port)
+	}
+
+	// create header
+	header := gateway.Header{
+		GenesisID:  genesis.ID(),
+		UniqueID:   gateway.GenerateUniqueID(),
+		NetAddress: syncerAddr,
+	}
+
+	// create the syncer
+	s := syncer.New(l, cm, sqlStore, header, syncer.WithLogger(logger.Named("syncer")), syncer.WithSendBlocksTimeout(time.Minute))
+
+	// start syncer
+	errChan := make(chan error, 1)
+	go func() {
+		errChan <- s.Run(context.Background())
+		close(errChan)
+	}()
+
+	// create a helper function to wait for syncer to wind down on shutdown
+	syncerShutdown := func(ctx context.Context) error {
+		select {
+		case err := <-errChan:
+			return err
+		case <-ctx.Done():
+			return context.Cause(ctx)
+		}
+	}
+
+	// create bus
+	announcementMaxAgeHours := time.Duration(cfg.Bus.AnnouncementMaxAgeHours) * time.Hour
+	b, err := bus.New(ctx, alertsMgr, wh, cm, s, w, sqlStore, announcementMaxAgeHours, logger)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to create bus: %w", err)
+	}
+
+	return b, func(ctx context.Context) error {
+		return errors.Join(
+			s.Close(),
+			w.Close(),
+			b.Shutdown(ctx),
+			sqlStore.Close(),
+			bdb.Close(),
+			syncerShutdown(ctx),
+		)
+	}, nil
+}
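newBus runs the syncer on a goroutine and funnels its exit error through a buffered channel, then waits on that channel (or the shutdown context) during teardown. The same pattern in isolation, with plain stand-ins for the renterd types:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func main() {
	run := func(ctx context.Context) error { // stands in for syncer.Run
		<-ctx.Done()
		return nil
	}

	ctx, stop := context.WithCancel(context.Background())
	errChan := make(chan error, 1)
	go func() {
		errChan <- run(ctx) // capture the service's exit error exactly once
		close(errChan)
	}()

	shutdown := func(ctx context.Context) error {
		stop() // ask the service to wind down
		select {
		case err := <-errChan: // service exited; surface its error
			return err
		case <-ctx.Done(): // shutdown budget exhausted first
			return context.Cause(ctx)
		}
	}

	sctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if err := shutdown(sctx); err != nil && !errors.Is(err, context.Canceled) {
		fmt.Println("shutdown error:", err)
	}
}
```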
+
+func (n *node) Run() error {
+	// start server
+	go n.apiSrv.Serve(n.apiListener)
+	n.logger.Info("api: Listening on " + n.apiListener.Addr().String())
+
+	// execute run functions
+	for _, fn := range n.setupFns {
+		if err := fn.fn(context.Background()); err != nil {
+			return fmt.Errorf("failed to run %v: %w", fn.name, err)
+		}
+	}
+
+	// set initial S3 keys
+	if n.cfg.S3.Enabled && !n.cfg.S3.DisableAuth {
+		as, err := n.bus.S3AuthenticationSettings(context.Background())
+		if err != nil && !strings.Contains(err.Error(), api.ErrSettingNotFound.Error()) {
+			return fmt.Errorf("failed to fetch S3 authentication settings: %w", err)
+		} else if as.V4Keypairs == nil {
+			as.V4Keypairs = make(map[string]string)
+		}
+
+		// S3 key pair validation was broken at one point, we need to remove the
+		// invalid key pairs here to ensure we don't fail when we update the
+		// setting below.
+		for k, v := range as.V4Keypairs {
+			if err := (api.S3AuthenticationSettings{V4Keypairs: map[string]string{k: v}}).Validate(); err != nil {
+				n.logger.Infof("removing invalid S3 keypair for AccessKeyID %s, reason: %v", k, err)
+				delete(as.V4Keypairs, k)
+			}
+		}
+
+		// merge keys
+		for k, v := range n.cfg.S3.KeypairsV4 {
+			as.V4Keypairs[k] = v
+		}
+		// update settings
+		if err := n.bus.UpdateSetting(context.Background(), api.SettingS3Authentication, as); err != nil {
+			return fmt.Errorf("failed to update S3 authentication settings: %w", err)
+		}
+	}
+
+	// start S3 server
+	if n.s3Srv != nil {
+		go n.s3Srv.Serve(n.s3Listener)
+		n.logger.Info("s3: Listening on " + n.s3Listener.Addr().String())
+	}
+
+	// fetch the syncer address
+	syncerAddress, err := n.bus.SyncerAddress(context.Background())
+	if err != nil {
+		return fmt.Errorf("failed to fetch syncer address: %w", err)
+	}
+	n.logger.Info("bus: Listening on " + syncerAddress)
+
+	// open the web UI if enabled
+	if n.cfg.AutoOpenWebUI {
+		time.Sleep(time.Millisecond) // give the web server a chance to start
+		_, port, err := net.SplitHostPort(n.apiListener.Addr().String())
+		if err != nil {
+			n.logger.Debug("failed to parse API address", zap.Error(err))
+		} else if err := utils.OpenBrowser(fmt.Sprintf("http://127.0.0.1:%s", port)); err != nil {
+			n.logger.Debug("failed to open browser", zap.Error(err))
+		}
+	}
+	return nil
+}
+
+func (n *node) Shutdown() error {
+	n.logger.Info("Shutting down...")
+
+	// give each service a fraction of the total shutdown timeout. One service
+	// timing out shouldn't prevent the others from attempting a shutdown.
+	timeout := n.cfg.ShutdownTimeout / time.Duration(len(n.shutdownFns))
+	shutdown := func(fn func(ctx context.Context) error) error {
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		defer cancel()
+		return fn(ctx)
+	}
+
+	// shut down the services in reverse order
+	var errs []error
+	for i := len(n.shutdownFns) - 1; i >= 0; i-- {
+		if err := shutdown(n.shutdownFns[i].fn); err != nil {
+			n.logger.Errorf("failed to shut down %v: %v", n.shutdownFns[i].name, err)
+			errs = append(errs, err)
+		} else {
+			n.logger.Infof("%v shut down successfully", n.shutdownFns[i].name)
+		}
+	}
+
+	return errors.Join(errs...)
+}
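Shutdown splits `cfg.ShutdownTimeout` evenly across the registered services and stops them in reverse registration order, so a hung service can only burn its own slice of the budget. A self-contained sketch of that budgeting, with illustrative names rather than the renterd implementation:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type service struct {
	name string
	stop func(context.Context) error
}

// shutdownAll gives each service total/len(svcs) to stop and collects
// failures instead of aborting, mirroring (*node).Shutdown above.
func shutdownAll(total time.Duration, svcs []service) error {
	timeout := total / time.Duration(len(svcs))
	var errs []error
	for i := len(svcs) - 1; i >= 0; i-- { // reverse of startup order
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		err := svcs[i].stop(ctx)
		cancel()
		if err != nil {
			errs = append(errs, fmt.Errorf("%s: %w", svcs[i].name, err))
		}
	}
	return errors.Join(errs)
}

func main() {
	svcs := []service{
		{"Logger", func(context.Context) error { return nil }},
		{"HTTP Server", func(context.Context) error { return nil }},
		{"Bus", func(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }},
	}
	// the hung "Bus" only consumes its 100ms slice of the 300ms budget
	fmt.Println(shutdownAll(300*time.Millisecond, svcs))
}
```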
+
+// TODO: needs a better spot
+func buildStoreConfig(am alerts.Alerter, cfg config.Config, pk types.PrivateKey, logger *zap.Logger) (stores.Config, error) {
+	// create database connections
+	var dbMain sql.Database
+	var dbMetrics sql.MetricsDatabase
+	if cfg.Database.MySQL.URI != "" {
+		// create MySQL connections
+		connMain, err := mysql.Open(
+			cfg.Database.MySQL.User,
+			cfg.Database.MySQL.Password,
+			cfg.Database.MySQL.URI,
+			cfg.Database.MySQL.Database,
+		)
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to open MySQL main database: %w", err)
+		}
+		connMetrics, err := mysql.Open(
+			cfg.Database.MySQL.User,
+			cfg.Database.MySQL.Password,
+			cfg.Database.MySQL.URI,
+			cfg.Database.MySQL.MetricsDatabase,
+		)
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to open MySQL metrics database: %w", err)
+		}
+		dbMain, err = mysql.NewMainDatabase(connMain, logger, cfg.Log.Database.SlowThreshold, cfg.Log.Database.SlowThreshold)
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to create MySQL main database: %w", err)
+		}
+		dbMetrics, err = mysql.NewMetricsDatabase(connMetrics, logger, cfg.Log.Database.SlowThreshold, cfg.Log.Database.SlowThreshold)
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to create MySQL metrics database: %w", err)
+		}
+	} else {
+		// create database directory
+		dbDir := filepath.Join(cfg.Directory, "db")
+		if err := os.MkdirAll(dbDir, 0700); err != nil {
+			return stores.Config{}, err
+		}
+
+		// create SQLite connections
+		db, err := sqlite.Open(filepath.Join(dbDir, "db.sqlite"))
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to open SQLite main database: %w", err)
+		}
+		dbMain, err = sqlite.NewMainDatabase(db, logger, cfg.Log.Database.SlowThreshold, cfg.Log.Database.SlowThreshold)
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to create SQLite main database: %w", err)
+		}
+
+		dbm, err := sqlite.Open(filepath.Join(dbDir, "metrics.sqlite"))
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to open SQLite metrics database: %w", err)
+		}
+		dbMetrics, err = sqlite.NewMetricsDatabase(dbm, logger, cfg.Log.Database.SlowThreshold, cfg.Log.Database.SlowThreshold)
+		if err != nil {
+			return stores.Config{}, fmt.Errorf("failed to create SQLite metrics database: %w", err)
+		}
+	}
+
+	return stores.Config{
+		Alerts:                        alerts.WithOrigin(am, "bus"),
+		DB:                            dbMain,
+		DBMetrics:                     dbMetrics,
+		PartialSlabDir:                filepath.Join(cfg.Directory, "partial_slabs"),
+		Migrate:                       true,
+		SlabBufferCompletionThreshold: cfg.Bus.SlabBufferCompletionThreshold,
+		Logger:                        logger,
+		RetryTransactionIntervals: []time.Duration{
+			200 * time.Millisecond,
+			500 * time.Millisecond,
+			time.Second,
+			3 * time.Second,
+			10 * time.Second,
+			10 * time.Second,
+		},
+		WalletAddress:     types.StandardUnlockHash(pk.PublicKey()),
+		LongQueryDuration: cfg.Log.Database.SlowThreshold,
+		LongTxDuration:    cfg.Log.Database.SlowThreshold,
+	}, nil
+}
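buildStoreConfig selects MySQL when `cfg.Database.MySQL.URI` is set and otherwise falls back to SQLite files under `<dir>/db`. A trimmed sketch of just that branch; the DSN format shown is the standard go-sql-driver/mysql one, whereas renterd wraps the connections in its own `mysql`/`sqlite` sub-packages:

```go
package main

import (
	"fmt"
	"path/filepath"
)

type mysqlConfig struct{ URI, User, Password, Database string }

// chooseDSN mirrors the branch in buildStoreConfig: a non-empty MySQL URI
// selects MySQL, otherwise data lands in <dir>/db/db.sqlite on disk.
func chooseDSN(dir string, my mysqlConfig) (driver, dsn string) {
	if my.URI != "" {
		return "mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", my.User, my.Password, my.URI, my.Database)
	}
	return "sqlite3", filepath.Join(dir, "db", "db.sqlite")
}

func main() {
	fmt.Println(chooseDSN("/data", mysqlConfig{}))
	fmt.Println(chooseDSN("/data", mysqlConfig{
		URI: "db:3306", User: "renterd", Password: "secret", Database: "renterd",
	}))
}
```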
+
+func migrateConsensusDatabase(ctx context.Context, store *stores.SQLStore, consensusDir string, logger *zap.Logger) error {
+	oldConsensus, err := os.Stat(filepath.Join(consensusDir, "consensus.db"))
+	if os.IsNotExist(err) {
+		return nil
+	} else if err != nil {
+		return err
+	}
+
+	logger.Warn("found old consensus.db, indicating a migration is necessary")
+
+	// reset chain state
+	logger.Warn("Resetting chain state...")
+	if err := store.ResetChainState(ctx); err != nil {
+		return err
+	}
+	logger.Warn("Chain state was successfully reset.")
+
+	// remove consensus.db and consensus.log file
+	logger.Warn("Removing consensus database...")
+	_ = os.RemoveAll(filepath.Join(consensusDir, "consensus.log")) // ignore error
+	if err := os.Remove(filepath.Join(consensusDir, "consensus.db")); err != nil {
+		return err
+	}
+
+	logger.Warn(fmt.Sprintf("Old 'consensus.db' was successfully removed, reclaimed %v of disk space.", utils.HumanReadableSize(int(oldConsensus.Size()))))
+	logger.Warn("ATTENTION: consensus will now resync from scratch, this process may take several hours to complete")
+	return nil
+}
diff --git a/config/config.go b/config/config.go
index 5c0f9dc87..99382240b 100644
--- a/config/config.go
+++ b/config/config.go
@@ -11,16 +11,18 @@ type (
 		Seed            string        `yaml:"seed,omitempty"`
 		Directory       string        `yaml:"directory,omitempty"`
 		AutoOpenWebUI   bool          `yaml:"autoOpenWebUI,omitempty"`
+		Network         string        `yaml:"network,omitempty"`
 		ShutdownTimeout time.Duration `yaml:"shutdownTimeout,omitempty"`
 
 		Log Log `yaml:"log,omitempty"`
 
-		HTTP HTTP `yaml:"http,omitempty"`
+		HTTP HTTP `yaml:"http,omitempty"`
+
+		Autopilot Autopilot `yaml:"autopilot,omitempty"`
 
 		Bus    Bus    `yaml:"bus,omitempty"`
 		Worker Worker `yaml:"worker,omitempty"`
 		S3     S3     `yaml:"s3,omitempty"`
-		Autopilot Autopilot `yaml:"autopilot,omitempty"`
 
 		Database Database `yaml:"database,omitempty"`
 	}
@@ -51,9 +53,9 @@ type (
 		GatewayAddr                   string        `yaml:"gatewayAddr,omitempty"`
 		RemoteAddr                    string        `yaml:"remoteAddr,omitempty"`
 		RemotePassword                string        `yaml:"remotePassword,omitempty"`
-		PersistInterval               time.Duration `yaml:"persistInterval,omitempty"`
 		UsedUTXOExpiry                time.Duration `yaml:"usedUtxoExpiry,omitempty"`
 		SlabBufferCompletionThreshold int64         `yaml:"slabBufferCompleionThreshold,omitempty"`
+		PersistInterval               time.Duration `yaml:"persistInterval,omitempty"` // deprecated
 	}
 
 	// LogFile configures the file output of the logger.
@@ -131,6 +133,7 @@ type (
 	// Autopilot contains the configuration for an autopilot.
 	Autopilot struct {
 		Enabled                bool          `yaml:"enabled,omitempty"`
+		ID                     string        `yaml:"id,omitempty"`
 		AccountsRefillInterval time.Duration `yaml:"accountsRefillInterval,omitempty"`
 		Heartbeat              time.Duration `yaml:"heartbeat,omitempty"`
 		MigrationHealthCutoff  float64       `yaml:"migrationHealthCutoff,omitempty"`
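The new `Network` field (mirrored by `RENTERD_NETWORK` in the Dockerfile below) replaces the old testnet build tags: a single binary now picks its chain at runtime. A sketch of the likely lookup, assuming the network constructors exported by `go.sia.tech/coreutils/chain`; the exact wiring inside `cmd/renterd` is not part of this diff:

```go
package main

import (
	"fmt"

	"go.sia.tech/core/consensus"
	"go.sia.tech/core/types"
	"go.sia.tech/coreutils/chain"
)

// networkFromName maps a config/env network name to its consensus parameters
// and genesis block, defaulting to mainnet when unset.
func networkFromName(name string) (*consensus.Network, types.Block, error) {
	switch name {
	case "", "mainnet":
		n, genesis := chain.Mainnet()
		return n, genesis, nil
	case "zen":
		n, genesis := chain.TestnetZen()
		return n, genesis, nil
	default:
		return nil, types.Block{}, fmt.Errorf("unknown network '%s'", name)
	}
}

func main() {
	n, genesis, err := networkFromName("zen")
	if err != nil {
		panic(err)
	}
	fmt.Println(n.Name, genesis.ID())
}
```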
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 55d27ac99..350c6c6b6 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,9 +1,8 @@
 # Helper image to build renterd.
-FROM golang:1.21 AS builder
+FROM golang:1.23 AS builder
 
 # Define arguments for build tags and to skip running go generate.
-ARG BUILD_TAGS='netgo' \
-    BUILD_RUN_GO_GENERATE='true'
+ARG BUILD_RUN_GO_GENERATE='true'
 
 # Set the working directory.
 WORKDIR /renterd
@@ -23,36 +22,32 @@ RUN if [ "$BUILD_RUN_GO_GENERATE" = "true" ] ; then go generate ./... ; fi
 # Build renterd.
 RUN --mount=type=cache,target=/root/go/pkg/mod \
 	--mount=type=cache,target=/root/.cache/go-build \
-	CGO_ENABLED=1 go build -ldflags='-s -w -linkmode external -extldflags "-static"' -tags="${BUILD_TAGS}" ./cmd/renterd
+	CGO_ENABLED=1 go build -ldflags='-s -w -linkmode external -extldflags "-static"' ./cmd/renterd
 
 # Build image that will be used to run renterd.
-FROM alpine:3
+FROM scratch
 LABEL maintainer="The Sia Foundation <info@sia.tech>" \
-      org.opencontainers.image.description.vendor="The Sia Foundation" \
-      org.opencontainers.image.description="A renterd container - next-generation Sia renter" \
-      org.opencontainers.image.source="https://github.com/SiaFoundation/renterd" \
-      org.opencontainers.image.licenses=MIT
+    org.opencontainers.image.description.vendor="The Sia Foundation" \
+    org.opencontainers.image.description="A renterd container - next-generation Sia renter" \
+    org.opencontainers.image.source="https://github.com/SiaFoundation/renterd" \
+    org.opencontainers.image.licenses=MIT
 
 # User to run renterd as. Defaults to root.
 ENV PUID=0
 ENV PGID=0
 
-# Entrypoint env args
-ARG BUILD_TAGS
-ENV BUILD_TAGS=$BUILD_TAGS
-
 # Renterd env args
 ENV RENTERD_API_PASSWORD=
 ENV RENTERD_SEED=
 ENV RENTERD_CONFIG_FILE=/data/renterd.yml
+ENV RENTERD_NETWORK='mainnet'
 
 # Copy binary and prepare data dir.
 COPY --from=builder /renterd/renterd /usr/bin/renterd
+COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
 VOLUME [ "/data" ]
 
 USER ${PUID}:${PGID}
 
-# Copy the script and set it as the entrypoint.
-COPY docker/entrypoint.sh /entrypoint.sh
-RUN chmod +x /entrypoint.sh
-ENTRYPOINT ["/entrypoint.sh", "-dir", "./data"]
+ENTRYPOINT ["renterd", "-env", "-http", ":9980", "-s3.address", ":8080", "-dir", "./data"]
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
deleted file mode 100644
index da8b4d8ce..000000000
--- a/docker/entrypoint.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#!/bin/sh
-
-if [[ "$BUILD_TAGS" == *'testnet'* ]]; then
-	exec renterd -env -http=':9880' -s3.address=':7070' "$@"
-else
-	exec renterd -env -http=':9980' -s3.address=':8080' "$@"
-fi
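A note on the switch to `FROM scratch` above: the statically linked binary runs fine, but a scratch image ships no CA bundle, which is why the Dockerfile now copies `ca-certificates.crt` out of the builder stage. This small probe (the URL is an example only) shows the failure mode when the bundle is missing:

```go
package main

import (
	"crypto/x509"
	"fmt"
	"net/http"
)

func main() {
	// with no /etc/ssl/certs the pool is empty (or loading fails outright)
	pool, err := x509.SystemCertPool()
	if err != nil || pool == nil {
		fmt.Println("no usable system cert pool:", err)
	}
	// ...and any HTTPS request then typically fails with
	// "x509: certificate signed by unknown authority"
	if _, err := http.Get("https://example.com"); err != nil {
		fmt.Println("HTTPS request failed:", err)
	}
}
```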
diff --git a/go.mod b/go.mod
index d5c2d8aed..f8466efe0 100644
--- a/go.mod
+++ b/go.mod
@@ -1,83 +1,57 @@
 module go.sia.tech/renterd
 
-go 1.21.8
-
-toolchain go1.22.3
+go 1.22.5
 
 require (
-	github.com/gabriel-vasile/mimetype v1.4.4
+	github.com/gabriel-vasile/mimetype v1.4.5
+	github.com/go-sql-driver/mysql v1.8.1
 	github.com/google/go-cmp v0.6.0
 	github.com/gotd/contrib v0.20.0
-	github.com/klauspost/reedsolomon v1.12.1
-	github.com/minio/minio-go/v7 v7.0.72
+	github.com/klauspost/reedsolomon v1.12.3
+	github.com/mattn/go-sqlite3 v1.14.22
+	github.com/minio/minio-go/v7 v7.0.75
 	github.com/montanaflynn/stats v0.7.1
 	github.com/shopspring/decimal v1.4.0
-	gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe
-	go.sia.tech/core v0.3.0
-	go.sia.tech/coreutils v0.1.0
+	go.sia.tech/core v0.4.3
+	go.sia.tech/coreutils v0.2.5
 	go.sia.tech/gofakes3 v0.0.4
-	go.sia.tech/hostd v1.1.1-beta.1.0.20240618072747-b3f430b4d272
-	go.sia.tech/jape v0.11.2-0.20240306154058-9832414a5385
+	go.sia.tech/hostd v1.1.3-0.20240807214810-c2d8ed84dc45
+	go.sia.tech/jape v0.12.0
 	go.sia.tech/mux v1.2.0
-	go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca
-	go.sia.tech/web/renterd v0.55.0
+	go.sia.tech/web/renterd v0.60.1
 	go.uber.org/zap v1.27.0
-	golang.org/x/crypto v0.24.0
-	golang.org/x/sys v0.21.0
-	golang.org/x/term v0.21.0
+	golang.org/x/crypto v0.26.0
+	golang.org/x/sys v0.24.0
+	golang.org/x/term v0.23.0
 	gopkg.in/yaml.v3 v3.0.1
-	gorm.io/driver/mysql v1.5.7
-	gorm.io/driver/sqlite v1.5.6
-	gorm.io/gorm v1.25.10
 	lukechampine.com/frand v1.4.2
-	moul.io/zapgorm2 v1.3.0
 )
 
 require (
+	filippo.io/edwards25519 v1.1.0 // indirect
 	github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
-	github.com/aws/aws-sdk-go v1.54.6 // indirect
-	github.com/cloudflare/cloudflare-go v0.97.0 // indirect
-	github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf // indirect
+	github.com/aws/aws-sdk-go v1.55.5 // indirect
+	github.com/cloudflare/cloudflare-go v0.101.0 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
-	github.com/go-sql-driver/mysql v1.7.1 // indirect
+	github.com/go-ini/ini v1.67.0 // indirect
 	github.com/goccy/go-json v0.10.3 // indirect
 	github.com/google/go-querystring v1.1.0 // indirect
 	github.com/google/uuid v1.6.0 // indirect
-	github.com/gorilla/websocket v1.5.2 // indirect
-	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
-	github.com/hashicorp/go-retryablehttp v0.7.7 // indirect
 	github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
-	github.com/jinzhu/inflection v1.0.0 // indirect
-	github.com/jinzhu/now v1.1.5 // indirect
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/julienschmidt/httprouter v1.3.0 // indirect
-	github.com/klauspost/compress v1.17.7 // indirect
+	github.com/klauspost/compress v1.17.9 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.8 // indirect
-	github.com/mattn/go-sqlite3 v1.14.22 // indirect
 	github.com/minio/md5-simd v1.1.2 // indirect
-	github.com/pkg/errors v0.9.1 // indirect
 	github.com/rs/xid v1.5.0 // indirect
 	github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 // indirect
 	github.com/shabbyrobe/gocovmerge v0.0.0-20230507112040-c3350d9342df // indirect
-	gitlab.com/NebulousLabs/bolt v1.4.4 // indirect
-	gitlab.com/NebulousLabs/demotemutex v0.0.0-20151003192217-235395f71c40 // indirect
-	gitlab.com/NebulousLabs/entropy-mnemonics v0.0.0-20181018051301-7532f67e3500 // indirect
-	gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975 // indirect
-	gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40 // indirect
-	gitlab.com/NebulousLabs/go-upnp v0.0.0-20211002182029-11da932010b6 // indirect
-	gitlab.com/NebulousLabs/log v0.0.0-20210609172545-77f6775350e2 // indirect
-	gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4 // indirect
-	gitlab.com/NebulousLabs/monitor v0.0.0-20191205095550-2b0fd3e1012a // indirect
-	gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877 // indirect
-	gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e // indirect
-	gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259 // indirect
-	gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 // indirect
+	go.etcd.io/bbolt v1.3.10 // indirect
 	go.sia.tech/web v0.0.0-20240610131903-5611d44a533e // indirect
 	go.uber.org/multierr v1.11.0 // indirect
-	golang.org/x/net v0.26.0 // indirect
-	golang.org/x/text v0.16.0 // indirect
-	golang.org/x/time v0.5.0 // indirect
+	golang.org/x/net v0.27.0 // indirect
+	golang.org/x/text v0.17.0 // indirect
+	golang.org/x/time v0.6.0 // indirect
 	golang.org/x/tools v0.22.0 // indirect
-	gopkg.in/ini.v1 v1.67.0 // indirect
 	nhooyr.io/websocket v1.8.11 // indirect
 )
diff --git a/go.sum b/go.sum
index 6c9e2f318..42cf991d3 100644
--- a/go.sum
+++ b/go.sum
@@ -1,61 +1,24 @@
-cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
-github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
-github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA=
-github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo=
+filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
 github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
 github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
-github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
-github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
-github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
-github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g=
-github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
-github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
-github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
-github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
-github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
-github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
-github.com/cloudflare/cloudflare-go v0.97.0 h1:feZRGiRF1EbljnNIYdt8014FnOLtC3CCvgkLXu915ks=
-github.com/cloudflare/cloudflare-go v0.97.0/go.mod h1:JXRwuTfHpe5xFg8xytc2w0XC6LcrFsBVMS4WlVaiGg8=
-github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
-github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
-github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
-github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
-github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
-github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
+github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
+github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
+github.com/cloudflare/cloudflare-go v0.101.0 h1:SXWNSEDkbdY84iFIZGyTdWQwDfd98ljv0/4UubpleBQ=
+github.com/cloudflare/cloudflare-go v0.101.0/go.mod h1:xXQHnoXKR48JlWbFS42i2al3nVqimVhcYvKnIdXLw9g=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf h1:K5VXW9LjmJv/xhjvQcNWTdk4WOSyreil6YaubuCPeRY=
-github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf/go.mod h1:bXVurdTuvOiJu7NHALemFe0JMvC2UmwYHW+7fcZaZ2M=
-github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
-github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
-github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
-github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
-github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
-github.com/gabriel-vasile/mimetype v1.4.4 h1:QjV6pZ7/XZ7ryI2KuyeEDE8wnh7fHP9YnQy+R0LnH8I=
-github.com/gabriel-vasile/mimetype v1.4.4/go.mod h1:JwLei5XPtWdGiMFB5Pjle1oEeoSeEuJfJE+TtfvdB/s=
-github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
-github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
-github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
-github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
-github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
-github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
-github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
-github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
+github.com/gabriel-vasile/mimetype v1.4.5 h1:J7wGKdGu33ocBOhGy0z653k/lFKLFDPJMG8Gql0kxn4=
+github.com/gabriel-vasile/mimetype v1.4.5/go.mod h1:ibHel+/kbxn9x2407k1izTA1S81ku1z/DlgOW2QE0M4=
+github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
+github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
+github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
+github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
 github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
 github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
-github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
-github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
-github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -63,322 +26,111 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
 github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
-github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/gorilla/websocket v1.5.2 h1:qoW6V1GT3aZxybsbC6oLnailWnB+qTMVwMreOso9XUw=
-github.com/gorilla/websocket v1.5.2/go.mod h1:0n9H61RBAcf5/38py2MCYbxzPIY9rOkpvvMT24Rqs30=
 github.com/gotd/contrib v0.20.0 h1:1Wc4+HMQiIKYQuGHVwVksIx152HFTP6B5n88dDe0ZYw=
 github.com/gotd/contrib v0.20.0/go.mod h1:P6o8W4niqhDPHLA0U+SA/L7l3BQHYLULpeHfRSePn9o=
-github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
-github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
-github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=
-github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok=
-github.com/hanwen/go-fuse/v2 v2.1.0/go.mod h1:oRyA5eK+pvJyv5otpO/DgccS8y/RvYMaO00GgRLGryc=
-github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
-github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
-github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
-github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
-github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU=
-github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk=
 github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
 github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
-github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
-github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf/go.mod h1:hyb9oH7vZsitZCiBt0ZvifOrB+qc8PS5IiilCIb87rg=
-github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
-github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
-github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
-github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
-github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
-github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
 github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
 github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
 github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
 github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
-github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
-github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
 github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
 github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
-github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8=
-github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
-github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg=
-github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
-github.com/klauspost/cpuid v1.2.2/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
+github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
+github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
 github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
 github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
 github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
-github.com/klauspost/reedsolomon v1.9.3/go.mod h1:CwCi+NUr9pqSVktrkN+Ondf06rkhYZ/pcNv7fu+8Un4=
-github.com/klauspost/reedsolomon v1.12.1 h1:NhWgum1efX1x58daOBGCFWcxtEhOhXKKl1HAPQUp03Q=
-github.com/klauspost/reedsolomon v1.12.1/go.mod h1:nEi5Kjb6QqtbofI6s+cbG/j1da11c96IBYBSnVGtuBs=
-github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/klauspost/reedsolomon v1.12.3 h1:tzUznbfc3OFwJaTebv/QdhnFf2Xvb7gZ24XaHLBPmdc=
+github.com/klauspost/reedsolomon v1.12.3/go.mod h1:3K5rXwABAvzGeR01r6pWZieUALXO/Tq7bFKGIb4m4WI=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k=
-github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
-github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
-github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
-github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
-github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
 github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
 github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
-github.com/minio/minio-go/v7 v7.0.72 h1:ZSbxs2BfJensLyHdVOgHv+pfmvxYraaUy07ER04dWnA=
-github.com/minio/minio-go/v7 v7.0.72/go.mod h1:4yBA8v80xGA30cfM3fz0DKYMXunWl/AV/6tWEs9ryzo=
-github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
-github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
+github.com/minio/minio-go/v7 v7.0.75 h1:0uLrB6u6teY2Jt+cJUVi9cTvDRuBKWSRzSAcznRkwlE=
+github.com/minio/minio-go/v7 v7.0.75/go.mod h1:qydcVzV8Hqtj1VtEocfxbmVFa2siu6HGa+LDEPogjD8=
 github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
 github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
-github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
-github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
-github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
-github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
-github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
-github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
-github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
-github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
-github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
-github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
-github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
-github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
 github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
 github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
 github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
 github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
-github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 h1:GHRpF1pTW19a8tTFrMLUcfWwyC0pnifVo2ClaLq+hP8=
 github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8=
 github.com/shabbyrobe/gocovmerge v0.0.0-20230507112040-c3350d9342df h1:S77Pf5fIGMa7oSwp8SQPp7Hb4ZiI38K3RNBKD2LLeEM=
 github.com/shabbyrobe/gocovmerge v0.0.0-20230507112040-c3350d9342df/go.mod h1:dcuzJZ83w/SqN9k4eQqwKYMgmKWzg/KzJAURBhRL1tc=
 github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
 github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
-github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
-github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
-github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
-github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
-github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
-github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
-github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE=
-github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
-github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
-github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
-github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
-github.com/vbauerster/mpb/v5 v5.0.3/go.mod h1:h3YxU5CSr8rZP4Q3xZPVB3jJLhWPou63lHEdr9ytH4Y=
-github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
-github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
-github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
-gitlab.com/NebulousLabs/bolt v1.4.4 h1:3UhpR2qtHs87dJBE3CIzhw48GYSoUUNByJmic0cbu1w=
-gitlab.com/NebulousLabs/bolt v1.4.4/go.mod h1:ZL02cwhpLNif6aruxvUMqu/Bdy0/lFY21jMFfNAA+O8=
-gitlab.com/NebulousLabs/demotemutex v0.0.0-20151003192217-235395f71c40 h1:IbucNi8u1a1ErgVFVgg8pERhSyzYe5l+o8krDMnNjWA=
-gitlab.com/NebulousLabs/demotemutex v0.0.0-20151003192217-235395f71c40/go.mod h1:HfnnxM8isYA7FUlqS5h34XTeiBhPtcuCquVujKsn9aw=
-gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe h1:vylvMCgxVPYojpQ2p536xDooW/B3znEnw58mCxrlZow=
-gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe/go.mod h1:Gi3CPCauIWmGp7YrnV/mKZ8qkD/N/LrunGNc8QmsVkU=
-gitlab.com/NebulousLabs/entropy-mnemonics v0.0.0-20181018051301-7532f67e3500 h1:BUDZfLl/9IRseYl7/GW1DF+11SYCMJ6P4whCBJhtEhQ=
-gitlab.com/NebulousLabs/entropy-mnemonics v0.0.0-20181018051301-7532f67e3500/go.mod h1:4koft3fRXTETovKPTeX/Aggj+ajCGWCcuuBBc598Pcs=
-gitlab.com/NebulousLabs/errors v0.0.0-20171229012116-7ead97ef90b8/go.mod h1:ZkMZ0dpQyWwlENaeZVBiQRjhMEZvk6VTXquzl3FOFP8=
-gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975 h1:L/ENs/Ar1bFzUeKx6m3XjlmBgIUlykX9dzvp5k9NGxc=
-gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975/go.mod h1:ZkMZ0dpQyWwlENaeZVBiQRjhMEZvk6VTXquzl3FOFP8=
-gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40 h1:dizWJqTWjwyD8KGcMOwgrkqu1JIkofYgKkmDeNE7oAs=
-gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40/go.mod h1:rOnSnoRyxMI3fe/7KIbVcsHRGxe30OONv8dEgo+vCfA=
-gitlab.com/NebulousLabs/go-upnp v0.0.0-20181011194642-3a71999ed0d3/go.mod h1:sleOmkovWsDEQVYXmOJhx69qheoMTmCuPYyiCFCihlg=
-gitlab.com/NebulousLabs/go-upnp v0.0.0-20211002182029-11da932010b6 h1:WKij6HF8ECp9E7K0E44dew9NrRDGiNR5u4EFsXnJUx4=
-gitlab.com/NebulousLabs/go-upnp v0.0.0-20211002182029-11da932010b6/go.mod h1:vhrHTGDh4YR7wK8Z+kRJ+x8SF/6RUM3Vb64Si5FD0L8=
-gitlab.com/NebulousLabs/log v0.0.0-20200529173103-40b250c2d92c/go.mod h1:qOhJbQ7Vzw+F+RCVmpPZ7WAwBIM9PZv4tWKp6Kgd9CY=
-gitlab.com/NebulousLabs/log v0.0.0-20200604091839-0ba4a941cdc2/go.mod h1:qOhJbQ7Vzw+F+RCVmpPZ7WAwBIM9PZv4tWKp6Kgd9CY=
-gitlab.com/NebulousLabs/log v0.0.0-20210609172545-77f6775350e2 h1:ovh05+n1jw7R9KT3qa5kdK4T26fIKyVogws06goZ5+Y=
-gitlab.com/NebulousLabs/log v0.0.0-20210609172545-77f6775350e2/go.mod h1:qOhJbQ7Vzw+F+RCVmpPZ7WAwBIM9PZv4tWKp6Kgd9CY=
-gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4 h1:iuNdBfBg0umjOvrEf9MxGzK+NwAyE2oCZjDqUx9zVFs=
-gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4/go.mod h1:0cjDwhA+Pv9ZQXHED7HUSS3sCvo2zgsoaMgE7MeGBWo=
-gitlab.com/NebulousLabs/monitor v0.0.0-20191205095550-2b0fd3e1012a h1:fs891phmYZrVdaCVPXfHGDMpV5LWPKvnOMjx70EpJkw=
-gitlab.com/NebulousLabs/monitor v0.0.0-20191205095550-2b0fd3e1012a/go.mod h1:QxXtb5hIp2xQkfb+lzBDIqQIGEj22U7AkYCXO3hkhqc=
-gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877 h1:BGJ+na/hpeAV6WR8Pys9bJM2ynEwKmT6+qgF8pn01fM=
-gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877/go.mod h1:KT2SgNX75xjMIQdDi3Rf3tcDWsX/D289R65Ss/7lKBg=
-gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e h1:sMZdmPFduUilFk8Ed1Ya/DP0gVfUbGhLlNtLG2tONYk=
-gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e/go.mod h1:HVrehlTxX2hYjsrL1k0WK43OZ0NGZfGvqzPL+n0/zrM=
-gitlab.com/NebulousLabs/siamux v0.0.0-20200723083235-f2c35a421446/go.mod h1:B0RyynPElUG2Y2CAVIIRriIqR9qht2I+nDisi3gfKn0=
-gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259 h1:94CUlkiIN8Mu+hTYgT7n36SbJ7WR6ZMM91ReaSDxUlQ=
-gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259/go.mod h1:owMSVLlMMCSK6tfhfSshZhrsIFCUNvQEsiGZoWhaXcc=
-gitlab.com/NebulousLabs/threadgroup v0.0.0-20200527092543-afa01960408c/go.mod h1:av52iTyGuPtGU+GMcqfGtZu2vxhIjPgrxvIwVYelEvs=
-gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 h1:owERlKtUEFTPQ897iiqWPOuWBdq7BYqPxDOCgEZnbN4=
-gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213/go.mod h1:vIutAvl7lmJqLVYTCBY5WDdJomP+V74At8LCeEYoH8w=
-gitlab.com/NebulousLabs/writeaheadlog v0.0.0-20200618142844-c59a90f49130/go.mod h1:SxigdS5Q1ui+OMgGAXt1E/Fg3RB6PvKXMov2O3gvIzs=
-go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
 go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0=
 go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ=
-go.sia.tech/core v0.3.0 h1:PDfAQh9z8PYD+oeVS7rS9SEnTMOZzwwFfAH45yktmko=
-go.sia.tech/core v0.3.0/go.mod h1:BMgT/reXtgv6XbDgUYTCPY7wSMbspDRDs7KMi1vL6Iw=
-go.sia.tech/coreutils v0.1.0 h1:WQL7iT+jK1BiMx87bASXrZJZf4N2fbQkIOW8rS7wkh4=
-go.sia.tech/coreutils v0.1.0/go.mod h1:ybaFgewKXrlxFW71LqsyQlxjG6yWL6BSePrbZYnrprU=
+go.sia.tech/core v0.4.3 h1:XEX7v6X8eJh4zyOkSHYi6FsyD+N/OEKw/NIigaaWPAU=
+go.sia.tech/core v0.4.3/go.mod h1:cGfGNcyAq1k4oIOsrNpJV/Z/p+20/IMS6vIaofE8nr8=
+go.sia.tech/coreutils v0.2.5 h1:oMnBGMBRfxhLzTH1ZDBg0Ep0QLE2GE1lND9yfzOzenA=
+go.sia.tech/coreutils v0.2.5/go.mod h1:Pg9eE3xL25couNL/vYrtCWP5uXkVvC+SUcMVh1/E7+I=
 go.sia.tech/gofakes3 v0.0.4 h1:Kvo8j5cVdJRBXvV1KBJ69bocY23twG8ao/HCdwuPMeI=
 go.sia.tech/gofakes3 v0.0.4/go.mod h1:6hh4lETCMbyFFNWp3FRE838geY6vh1Aeas7LtYDpQdc=
-go.sia.tech/hostd v1.1.1-beta.1.0.20240618072747-b3f430b4d272 h1:RJmZ1Y9PoqpHjYHT5nr6Vmo6tTUpB2AIyd8zFge2JAs=
-go.sia.tech/hostd v1.1.1-beta.1.0.20240618072747-b3f430b4d272/go.mod h1:bM0ldLiCPAQenZcczN5I6Iw43iNcCTQqK3aLZlAQ/rc=
-go.sia.tech/jape v0.11.2-0.20240306154058-9832414a5385 h1:Gho1g6pkv56o6Ut9cez/Yu5o4xlA8WNkDbPn6RWXL7g=
-go.sia.tech/jape v0.11.2-0.20240306154058-9832414a5385/go.mod h1:wU+h6Wh5olDjkPXjF0tbZ1GDgoZ6VTi4naFw91yyWC4=
+go.sia.tech/hostd v1.1.3-0.20240807214810-c2d8ed84dc45 h1:yq8n3leZWAeEwbAa3sbqe5mS5LgG5IH23aM8tefSuUo=
+go.sia.tech/hostd v1.1.3-0.20240807214810-c2d8ed84dc45/go.mod h1:MSP0m1OPZGE5hyXEx35HM6MJWsrL0MLKwaKMzW4b8JU=
+go.sia.tech/jape v0.12.0 h1:13fBi7c5X8zxTQ05Cd9ZsIfRJgdvGoZqbEzH861z7BU=
+go.sia.tech/jape v0.12.0/go.mod h1:wU+h6Wh5olDjkPXjF0tbZ1GDgoZ6VTi4naFw91yyWC4=
 go.sia.tech/mux v1.2.0 h1:ofa1Us9mdymBbGMY2XH/lSpY8itFsKIo/Aq8zwe+GHU=
 go.sia.tech/mux v1.2.0/go.mod h1:Yyo6wZelOYTyvrHmJZ6aQfRoer3o4xyKQ4NmQLJrBSo=
-go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca h1:aZMg2AKevn7jKx+wlusWQfwSM5pNU9aGtRZme29q3O4=
-go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca/go.mod h1:h/1afFwpxzff6/gG5i1XdAgPK7dEY6FaibhK7N5F86Y=
 go.sia.tech/web v0.0.0-20240610131903-5611d44a533e h1:oKDz6rUExM4a4o6n/EXDppsEka2y/+/PgFOZmHWQRSI=
 go.sia.tech/web v0.0.0-20240610131903-5611d44a533e/go.mod h1:4nyDlycPKxTlCqvOeRO0wUfXxyzWCEE7+2BRrdNqvWk=
-go.sia.tech/web/renterd v0.55.0 h1:xjHF0TudolsrQbguNR6+J/OPeXVf+ekodVtLB3y/dyU=
-go.sia.tech/web/renterd v0.55.0/go.mod h1:SWwKoAJvLxiHjTXsNPKX3RLiQzJb/vxwcpku3F78MO8=
-go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
-go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
-go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
+go.sia.tech/web/renterd v0.60.1 h1:KJ/DgYKES29HoRd4/XY/G9CzTrHpMANCRCffIYc6Sxg=
+go.sia.tech/web/renterd v0.60.1/go.mod h1:SWwKoAJvLxiHjTXsNPKX3RLiQzJb/vxwcpku3F78MO8=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
-go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
-go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
-go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
-go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
-go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw=
 go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
 go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
-golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20191105034135-c7e5f84aec59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20200109152110-61a87790db17/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
-golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
-golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
-golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
+golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
 golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
 golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
-golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8=
-golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
-golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
-golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
+golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
-golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.0.0-20210421210424-b80969c67360/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
-golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
+golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
+golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
+golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
-golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
-golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
-golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
-golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
+golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
+golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20190829051458-42f498d34c4d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
 golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
-google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
-google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
-google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
-gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
-gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
-gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
-gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
-gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
-gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
-gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
-gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
-gorm.io/gorm v1.23.6/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
-gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
-gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
-gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
-honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw=
 lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s=
-moul.io/zapgorm2 v1.3.0 h1:+CzUTMIcnafd0d/BvBce8T4uPn6DQnpIrz64cyixlkk=
-moul.io/zapgorm2 v1.3.0/go.mod h1:nPVy6U9goFKHR4s+zfSo1xVFaoU7Qgd5DoCdOfzoCqs=
 nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
 nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
diff --git a/hostdb/hostdb.go b/hostdb/hostdb.go
deleted file mode 100644
index 1a957e327..000000000
--- a/hostdb/hostdb.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package hostdb
-
-import (
-	"time"
-
-	"gitlab.com/NebulousLabs/encoding"
-	"go.sia.tech/core/types"
-	"go.sia.tech/siad/crypto"
-	"go.sia.tech/siad/modules"
-)
-
-// Announcement represents a host announcement in a given block.
-type Announcement struct {
-	Index      types.ChainIndex
-	Timestamp  time.Time
-	NetAddress string
-}
-
-type hostAnnouncement struct {
-	modules.HostAnnouncement
-	Signature types.Signature
-}
-
-// ForEachAnnouncement calls fn on each host announcement in a block.
-func ForEachAnnouncement(b types.Block, height uint64, fn func(types.PublicKey, Announcement)) { - for _, txn := range b.Transactions { - for _, arb := range txn.ArbitraryData { - // decode announcement - var ha hostAnnouncement - if err := encoding.Unmarshal(arb, &ha); err != nil { - continue - } else if ha.Specifier != modules.PrefixHostAnnouncement { - continue - } - - // verify signature - var hostKey types.PublicKey - copy(hostKey[:], ha.PublicKey.Key) - annHash := types.Hash256(crypto.HashObject(ha.HostAnnouncement)) - if !hostKey.VerifyHash(annHash, ha.Signature) { - continue - } - - // verify net address - if ha.NetAddress == "" { - continue - } - - fn(hostKey, Announcement{ - Index: types.ChainIndex{ - Height: height, - ID: b.ID(), - }, - Timestamp: b.Timestamp, - NetAddress: string(ha.NetAddress), - }) - } - } -} diff --git a/bus/accounts.go b/internal/bus/accounts.go similarity index 55% rename from bus/accounts.go rename to internal/bus/accounts.go index d072de1c7..666398e57 100644 --- a/bus/accounts.go +++ b/internal/bus/accounts.go @@ -16,45 +16,125 @@ import ( "lukechampine.com/frand" ) -var errAccountsNotFound = errors.New("account doesn't exist") +var ( + ErrAccountNotFound = errors.New("account doesn't exist") +) -type accounts struct { - mu sync.Mutex - byID map[rhpv3.Account]*account - logger *zap.SugaredLogger -} +type ( + AccountStore interface { + Accounts(context.Context) ([]api.Account, error) + SaveAccounts(context.Context, []api.Account) error + SetUncleanShutdown(context.Context) error + } +) -type account struct { - mu sync.Mutex - locks map[uint64]*accountLock - requiresSyncTime time.Time - api.Account +type ( + AccountMgr struct { + s AccountStore + logger *zap.SugaredLogger - rwmu sync.RWMutex -} + mu sync.Mutex + byID map[rhpv3.Account]*account + } -type accountLock struct { - heldByID uint64 - unlock func() - timer *time.Timer -} + account struct { + mu sync.Mutex + locks map[uint64]*accountLock + requiresSyncTime time.Time + api.Account + + rwmu sync.RWMutex + } -func newAccounts(accs []api.Account, logger *zap.SugaredLogger) *accounts { - a := &accounts{ - byID: make(map[rhpv3.Account]*account), - logger: logger.Named("accounts"), + accountLock struct { + heldByID uint64 + unlock func() + timer *time.Timer } - for _, acc := range accs { +) + +// NewAccountManager creates a new account manager. It will load all accounts +// from the given store and mark the shutdown as unclean. When Shutdown is +// called it will save all accounts. +func NewAccountManager(ctx context.Context, s AccountStore, logger *zap.Logger) (*AccountMgr, error) { + logger = logger.Named("accounts") + + // load saved accounts + saved, err := s.Accounts(ctx) + if err != nil { + return nil, err + } + + // wrap with a lock + accounts := make(map[rhpv3.Account]*account, len(saved)) + for _, acc := range saved { account := &account{ Account: acc, locks: map[uint64]*accountLock{}, } - a.byID[account.ID] = account + accounts[account.ID] = account + } + + // mark the shutdown as unclean, this will be overwritten on shutdown + err = s.SetUncleanShutdown(ctx) + if err != nil { + return nil, fmt.Errorf("failed to mark account shutdown as unclean: %w", err) } - return a + + return &AccountMgr{ + s: s, + logger: logger.Sugar(), + + byID: accounts, + }, nil } -func (a *accounts) LockAccount(ctx context.Context, id rhpv3.Account, hostKey types.PublicKey, exclusive bool, duration time.Duration) (api.Account, uint64) { +// Account returns the account with the given id. 
+func (a *AccountMgr) Account(id rhpv3.Account, hostKey types.PublicKey) (api.Account, error) {
+	acc := a.account(id, hostKey)
+	acc.mu.Lock()
+	defer acc.mu.Unlock()
+	return acc.convert(), nil
+}
+
+// Accounts returns all accounts.
+func (a *AccountMgr) Accounts() []api.Account {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	accounts := make([]api.Account, 0, len(a.byID))
+	for _, acc := range a.byID {
+		acc.mu.Lock()
+		accounts = append(accounts, acc.convert())
+		acc.mu.Unlock()
+	}
+	return accounts
+}
+
+// AddAmount applies the provided amount to an account through addition. The
+// input can be either a positive or a negative number, depending on whether a
+// deposit or a withdrawal is recorded. If the account doesn't exist, it is
+// created.
+func (a *AccountMgr) AddAmount(id rhpv3.Account, hk types.PublicKey, amt *big.Int) {
+	acc := a.account(id, hk)
+
+	// Update balance.
+	acc.mu.Lock()
+	balanceBefore := acc.Balance.String()
+	acc.Balance.Add(acc.Balance, amt)
+
+	// Log deposits.
+	if amt.Cmp(big.NewInt(0)) > 0 {
+		a.logger.Infow("account balance was increased",
+			"account", acc.ID,
+			"host", acc.HostKey.String(),
+			"amt", amt.String(),
+			"balanceBefore", balanceBefore,
+			"balanceAfter", acc.Balance.String())
+	}
+	acc.mu.Unlock()
+}
+
+func (a *AccountMgr) LockAccount(ctx context.Context, id rhpv3.Account, hostKey types.PublicKey, exclusive bool, duration time.Duration) (api.Account, uint64) {
 	acc := a.account(id, hostKey)
 
 	// Try to lock the account.
@@ -93,93 +173,72 @@ func (a *accounts) LockAccount(ctx context.Context, id rhpv3.Account, hostKey ty
 	return account, lock.heldByID
 }
 
-func (a *accounts) UnlockAccount(id rhpv3.Account, lockID uint64) error {
+// ResetDrift resets the drift on an account.
+func (a *AccountMgr) ResetDrift(id rhpv3.Account) error {
 	a.mu.Lock()
-	acc, exists := a.byID[id]
+	account, exists := a.byID[id]
 	if !exists {
 		a.mu.Unlock()
-		return errAccountsNotFound
+		return ErrAccountNotFound
 	}
 	a.mu.Unlock()
 
-	// Get lock.
-	acc.mu.Lock()
-	lock, exists := acc.locks[lockID]
-	acc.mu.Unlock()
-	if !exists {
-		return fmt.Errorf("account lock with id %v not found", lockID)
-	}
-
-	// Stop timer.
-	lock.timer.Stop()
-	select {
-	case <-lock.timer.C:
-	default:
-	}
-
-	// Unlock
-	lock.unlock()
-	return nil
-}
+	account.mu.Lock()
+	driftBefore := account.Drift.String()
+	account.mu.Unlock()
 
-// AddAmount applies the provided amount to an account through addition. So the
-// input can be both a positive or negative number depending on whether a
-// withdrawal or deposit is recorded. If the account doesn't exist, it is
-// created.
-func (a *accounts) AddAmount(id rhpv3.Account, hk types.PublicKey, amt *big.Int) {
-	acc := a.account(id, hk)
+	account.resetDrift()
 
-	// Update balance.
-	acc.mu.Lock()
-	balanceBefore := acc.Balance.String()
-	acc.Balance.Add(acc.Balance, amt)
+	a.logger.Infow("account drift was reset",
+		zap.Stringer("account", account.ID),
+		zap.Stringer("host", account.HostKey),
+		zap.String("driftBefore", driftBefore))
 
-	// Log deposits.
-	if amt.Cmp(big.NewInt(0)) > 0 {
-		a.logger.Infow("account balance was increased",
-			"account", acc.ID,
-			"host", acc.HostKey.String(),
-			"amt", amt.String(),
-			"balanceBefore", balanceBefore,
-			"balanceAfter", acc.Balance.String())
-	}
-	acc.mu.Unlock()
+	return nil
 }
 
 // SetBalance sets the balance of a given account to the provided amount. If the
 // account doesn't exist, it is created.
// If an account hasn't been saved successfully upon the last shutdown, no drift // will be added upon the first call to SetBalance. -func (a *accounts) SetBalance(id rhpv3.Account, hk types.PublicKey, balance *big.Int) { +func (a *AccountMgr) SetBalance(id rhpv3.Account, hk types.PublicKey, balance *big.Int) { acc := a.account(id, hk) - // Update balance and drift. acc.mu.Lock() - delta := new(big.Int).Sub(balance, acc.Balance) - balanceBefore := acc.Balance.String() - driftBefore := acc.Drift.String() + defer acc.mu.Unlock() + + // save previous values + prevBalance := new(big.Int).Set(acc.Balance) + prevDrift := new(big.Int).Set(acc.Drift) + + // update balance + acc.Balance.Set(balance) + + // update drift + drift := new(big.Int).Sub(balance, prevBalance) if acc.CleanShutdown { - acc.Drift = acc.Drift.Add(acc.Drift, delta) + acc.Drift = acc.Drift.Add(acc.Drift, drift) } - acc.Balance.Set(balance) + + // reset fields acc.CleanShutdown = true - acc.RequiresSync = false // resetting the balance resets the sync field - balanceAfter := acc.Balance.String() - acc.mu.Unlock() + acc.RequiresSync = false - // Log resets. + // log account changes a.logger.Infow("account balance was reset", - "account", acc.ID, - "host", acc.HostKey.String(), - "balanceBefore", balanceBefore, - "balanceAfter", balanceAfter, - "driftBefore", driftBefore, - "driftAfter", acc.Drift.String(), - "delta", delta.String()) + zap.Stringer("account", acc.ID), + zap.Stringer("host", acc.HostKey), + zap.Stringer("balanceBefore", prevBalance), + zap.Stringer("balanceAfter", balance), + zap.Stringer("driftBefore", prevDrift), + zap.Stringer("driftAfter", acc.Drift), + zap.Bool("firstDrift", acc.Drift.Cmp(big.NewInt(0)) != 0 && prevDrift.Cmp(big.NewInt(0)) == 0), + zap.Bool("cleanshutdown", acc.CleanShutdown), + zap.Stringer("drift", drift)) } // ScheduleSync sets the requiresSync flag of an account. -func (a *accounts) ScheduleSync(id rhpv3.Account, hk types.PublicKey) error { +func (a *AccountMgr) ScheduleSync(id rhpv3.Account, hk types.PublicKey) error { acc := a.account(id, hk) acc.mu.Lock() // Only update the sync flag to 'true' if some time has passed since the @@ -204,83 +263,58 @@ func (a *accounts) ScheduleSync(id rhpv3.Account, hk types.PublicKey) error { account, exists := a.byID[id] defer a.mu.Unlock() if !exists { - return errAccountsNotFound + return ErrAccountNotFound } account.resetDrift() return nil } -func (a *account) convert() api.Account { - return api.Account{ - ID: a.ID, - Balance: new(big.Int).Set(a.Balance), - CleanShutdown: a.CleanShutdown, - Drift: new(big.Int).Set(a.Drift), - HostKey: a.HostKey, - RequiresSync: a.RequiresSync, +func (a *AccountMgr) Shutdown(ctx context.Context) error { + accounts := a.Accounts() + err := a.s.SaveAccounts(ctx, accounts) + if err != nil { + a.logger.Errorf("failed to save %v accounts: %v", len(accounts), err) + return err } -} -// Account returns the account with the given id. -func (a *accounts) Account(id rhpv3.Account, hostKey types.PublicKey) (api.Account, error) { - acc := a.account(id, hostKey) - acc.mu.Lock() - defer acc.mu.Unlock() - return acc.convert(), nil -} - -// Accounts returns all accounts. 
-func (a *accounts) Accounts() []api.Account { - a.mu.Lock() - defer a.mu.Unlock() - accounts := make([]api.Account, 0, len(a.byID)) - for _, acc := range a.byID { - acc.mu.Lock() - accounts = append(accounts, acc.convert()) - acc.mu.Unlock() - } - return accounts + a.logger.Infof("successfully saved %v accounts", len(accounts)) + return nil } -// ResetDrift resets the drift on an account. -func (a *accounts) ResetDrift(id rhpv3.Account) error { +// UnlockAccount unlocks an account with the given lock id. +func (a *AccountMgr) UnlockAccount(id rhpv3.Account, lockID uint64) error { a.mu.Lock() - account, exists := a.byID[id] + acc, exists := a.byID[id] if !exists { a.mu.Unlock() - return errAccountsNotFound + return ErrAccountNotFound } a.mu.Unlock() - account.resetDrift() - return nil -} -// ToPersist returns all known accounts to be persisted by the storage backend. -// Called once on shutdown. -func (a *accounts) ToPersist() []api.Account { - a.mu.Lock() - defer a.mu.Unlock() - accounts := make([]api.Account, 0, len(a.byID)) - for _, acc := range a.byID { - acc.mu.Lock() - accounts = append(accounts, api.Account{ - ID: acc.ID, - Balance: new(big.Int).Set(acc.Balance), - CleanShutdown: acc.CleanShutdown, - Drift: new(big.Int).Set(acc.Drift), - HostKey: acc.HostKey, - RequiresSync: acc.RequiresSync, - }) - acc.mu.Unlock() + // Get lock. + acc.mu.Lock() + lock, exists := acc.locks[lockID] + acc.mu.Unlock() + if !exists { + return fmt.Errorf("account lock with id %v not found", lockID) } - return accounts + + // Stop timer. + lock.timer.Stop() + select { + case <-lock.timer.C: + default: + } + + // Unlock + lock.unlock() + return nil } -func (a *accounts) account(id rhpv3.Account, hk types.PublicKey) *account { +func (a *AccountMgr) account(id rhpv3.Account, hk types.PublicKey) *account { a.mu.Lock() defer a.mu.Unlock() - // Create account if it doesn't exist. 
 	acc, exists := a.byID[id]
 	if !exists {
 		acc = &account{
@@ -290,7 +324,7 @@ func (a *accounts) account(id rhpv3.Account, hk types.PublicKey) *account {
 				HostKey:      hk,
 				Balance:      big.NewInt(0),
 				Drift:        big.NewInt(0),
-				RequiresSync: false,
+				RequiresSync: true, // initial sync
 			},
 			locks: map[uint64]*accountLock{},
 		}
@@ -299,6 +333,17 @@ func (a *accounts) account(id rhpv3.Account, hk types.PublicKey) *account {
 	return acc
 }
 
+func (a *account) convert() api.Account {
+	return api.Account{
+		ID:            a.ID,
+		Balance:       new(big.Int).Set(a.Balance),
+		CleanShutdown: a.CleanShutdown,
+		Drift:         new(big.Int).Set(a.Drift),
+		HostKey:       a.HostKey,
+		RequiresSync:  a.RequiresSync,
+	}
+}
+
 func (a *account) resetDrift() {
 	a.mu.Lock()
 	defer a.mu.Unlock()
diff --git a/bus/accounts_test.go b/internal/bus/accounts_test.go
similarity index 81%
rename from bus/accounts_test.go
rename to internal/bus/accounts_test.go
index 70c813f68..38d062e75 100644
--- a/bus/accounts_test.go
+++ b/internal/bus/accounts_test.go
@@ -7,12 +7,23 @@ import (
 	rhpv3 "go.sia.tech/core/rhp/v3"
 	"go.sia.tech/core/types"
+	"go.sia.tech/renterd/api"
 	"go.uber.org/zap"
 	"lukechampine.com/frand"
 )
 
+type mockAccStore struct{}
+
+func (m *mockAccStore) Accounts(context.Context) ([]api.Account, error) { return nil, nil }
+func (m *mockAccStore) SaveAccounts(context.Context, []api.Account) error { return nil }
+func (m *mockAccStore) SetUncleanShutdown(context.Context) error { return nil }
+
 func TestAccountLocking(t *testing.T) {
-	accounts := newAccounts(nil, zap.NewNop().Sugar())
+	eas := &mockAccStore{}
+	accounts, err := NewAccountManager(context.Background(), eas, zap.NewNop())
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	var accountID rhpv3.Account
 	frand.Read(accountID[:])
diff --git a/internal/bus/alerts.go b/internal/bus/alerts.go
new file mode 100644
index 000000000..ef06d230c
--- /dev/null
+++ b/internal/bus/alerts.go
@@ -0,0 +1,24 @@
+package bus
+
+import (
+	"time"
+
+	"go.sia.tech/renterd/alerts"
+)
+
+var (
+	alertPricePinningID = alerts.RandomAlertID() // constant until restarted
+)
+
+func newPricePinningFailedAlert(err error) alerts.Alert {
+	return alerts.Alert{
+		ID:       alertPricePinningID,
+		Severity: alerts.SeverityWarning,
+		Message:  "Price pinning failed",
+		Data: map[string]any{
+			"error": err.Error(),
+			"hint":  "This might happen when the forex API is temporarily unreachable. This alert will disappear the next time prices are updated successfully.",
+		},
+		Timestamp: time.Now(),
+	}
+}
diff --git a/internal/bus/chainsubscriber.go b/internal/bus/chainsubscriber.go
new file mode 100644
index 000000000..e1200c24b
--- /dev/null
+++ b/internal/bus/chainsubscriber.go
@@ -0,0 +1,566 @@
+package bus
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+	"time"
+
+	"go.sia.tech/core/types"
+	"go.sia.tech/coreutils/chain"
+	"go.sia.tech/coreutils/wallet"
+	"go.sia.tech/renterd/api"
+	"go.sia.tech/renterd/internal/utils"
+	"go.sia.tech/renterd/stores/sql"
+	"go.sia.tech/renterd/webhooks"
+	"go.uber.org/zap"
+)
+
+const (
+	// updatesBatchSize is the maximum number of updates to fetch in a single
+	// call to the chain manager when we request updates since a given index.
+	updatesBatchSize = 100
+
+	// syncUpdateFrequency is the number of blocks between sync progress logs.
+ syncUpdateFrequency = 1e3 * updatesBatchSize +) + +var ( + errClosed = errors.New("subscriber closed") +) + +type ( + ChainManager interface { + OnReorg(fn func(types.ChainIndex)) (cancel func()) + RecommendedFee() types.Currency + Tip() types.ChainIndex + UpdatesSince(index types.ChainIndex, max int) (rus []chain.RevertUpdate, aus []chain.ApplyUpdate, err error) + } + + ChainStore interface { + ChainIndex(ctx context.Context) (types.ChainIndex, error) + ProcessChainUpdate(ctx context.Context, applyFn func(sql.ChainUpdateTx) error) error + } + + WebhookManager interface { + webhooks.Broadcaster + Delete(context.Context, webhooks.Webhook) error + Info() ([]webhooks.Webhook, []webhooks.WebhookQueueInfo) + Register(context.Context, webhooks.Webhook) error + Shutdown(context.Context) error + } + + Wallet interface { + UpdateChainState(tx wallet.UpdateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error + } + + chainSubscriber struct { + cm ChainManager + cs ChainStore + wm WebhookManager + logger *zap.SugaredLogger + + announcementMaxAge time.Duration + wallet Wallet + + shutdownCtx context.Context + shutdownCtxCancel context.CancelCauseFunc + syncSig chan struct{} + wg sync.WaitGroup + + mu sync.Mutex + knownContracts map[types.FileContractID]bool + unsubscribeFn func() + } +) + +type ( + revision struct { + revisionNumber uint64 + fileSize uint64 + } + + contractUpdate struct { + fcid types.FileContractID + prev *revision + curr *revision + resolved bool + valid bool + } +) + +// NewChainSubscriber creates a new chain subscriber that will sync with the +// given chain manager and chain store. The returned subscriber is already +// running and can be stopped by calling Shutdown. +func NewChainSubscriber(whm WebhookManager, cm ChainManager, cs ChainStore, w Wallet, announcementMaxAge time.Duration, logger *zap.Logger) *chainSubscriber { + logger = logger.Named("chainsubscriber") + ctx, cancel := context.WithCancelCause(context.Background()) + subscriber := &chainSubscriber{ + cm: cm, + cs: cs, + wm: whm, + logger: logger.Sugar(), + + announcementMaxAge: announcementMaxAge, + wallet: w, + + shutdownCtx: ctx, + shutdownCtxCancel: cancel, + syncSig: make(chan struct{}, 1), + + knownContracts: make(map[types.FileContractID]bool), + } + + // start the subscriber + subscriber.run() + + // trigger a sync on reorgs + subscriber.unsubscribeFn = cm.OnReorg(func(ci types.ChainIndex) { + select { + case subscriber.syncSig <- struct{}{}: + subscriber.logger.Debugw("reorg triggered", "height", ci.Height, "block_id", ci.ID) + default: + } + }) + + return subscriber +} + +func (s *chainSubscriber) ChainIndex(ctx context.Context) (types.ChainIndex, error) { + return s.cs.ChainIndex(ctx) +} + +func (s *chainSubscriber) Shutdown(ctx context.Context) error { + // cancel shutdown context + s.shutdownCtxCancel(errClosed) + + // unsubscribe from the chain manager + if s.unsubscribeFn != nil { + s.unsubscribeFn() + } + + // wait for sync loop to finish + waitChan := make(chan struct{}) + go func() { + s.wg.Wait() + close(waitChan) + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-waitChan: + } + return nil +} + +func (s *chainSubscriber) applyChainUpdate(tx sql.ChainUpdateTx, cau chain.ApplyUpdate) error { + // apply host updates + b := cau.Block + if time.Since(b.Timestamp) <= s.announcementMaxAge { + hus := make(map[types.PublicKey]chain.HostAnnouncement) + chain.ForEachHostAnnouncement(b, func(hk types.PublicKey, ha chain.HostAnnouncement) { + if ha.NetAddress != "" { + hus[hk] = 
ha + } + }) + for hk, ha := range hus { + if err := tx.UpdateHost(hk, ha, cau.State.Index.Height, b.ID(), b.Timestamp); err != nil { + return fmt.Errorf("failed to update host: %w", err) + } else if utils.IsSynced(b) { + // broadcast host update + s.wm.BroadcastAction(s.shutdownCtx, webhooks.Event{ + Module: api.ModuleHost, + Event: api.EventUpdate, + Payload: api.EventHostUpdate{ + HostKey: hk, + NetAddr: ha.NetAddress, + Timestamp: time.Now().UTC(), + }, + }) + } + } + } + + // v1 contracts + cus := make(map[types.FileContractID]contractUpdate) + cau.ForEachFileContractElement(func(fce types.FileContractElement, _ bool, rev *types.FileContractElement, resolved, valid bool) { + cu, ok := cus[types.FileContractID(fce.ID)] + if !ok { + cus[types.FileContractID(fce.ID)] = v1ContractUpdate(fce, rev, resolved, valid) + } else if fce.FileContract.RevisionNumber > cu.curr.revisionNumber { + cus[types.FileContractID(fce.ID)] = v1ContractUpdate(fce, rev, resolved, valid) + } + }) + for _, cu := range cus { + if err := s.updateContract(tx, cau.State.Index, cu.fcid, cu.prev, cu.curr, cu.resolved, cu.valid); err != nil { + return fmt.Errorf("failed to apply v1 contract update: %w", err) + } + } + + // v2 contracts + cus = make(map[types.FileContractID]contractUpdate) + cau.ForEachV2FileContractElement(func(fce types.V2FileContractElement, _ bool, rev *types.V2FileContractElement, res types.V2FileContractResolutionType) { + cu, ok := cus[types.FileContractID(fce.ID)] + if !ok { + cus[types.FileContractID(fce.ID)] = v2ContractUpdate(fce, rev, res) + } else if fce.V2FileContract.RevisionNumber > cu.curr.revisionNumber { + cus[types.FileContractID(fce.ID)] = v2ContractUpdate(fce, rev, res) + } + }) + for _, cu := range cus { + if err := s.updateContract(tx, cau.State.Index, cu.fcid, cu.prev, cu.curr, cu.resolved, cu.valid); err != nil { + return fmt.Errorf("failed to apply v2 contract update: %w", err) + } + } + return nil +} + +func (s *chainSubscriber) revertChainUpdate(tx sql.ChainUpdateTx, cru chain.RevertUpdate) error { + // NOTE: host updates are not reverted + + // v1 contracts + var cus []contractUpdate + cru.ForEachFileContractElement(func(fce types.FileContractElement, _ bool, rev *types.FileContractElement, resolved, valid bool) { + cus = append(cus, v1ContractUpdate(fce, rev, resolved, valid)) + }) + for _, cu := range cus { + if err := s.updateContract(tx, cru.State.Index, cu.fcid, cu.prev, cu.curr, cu.resolved, cu.valid); err != nil { + return fmt.Errorf("failed to revert v1 contract update: %w", err) + } + } + + // v2 contracts + cus = cus[:0] + cru.ForEachV2FileContractElement(func(fce types.V2FileContractElement, _ bool, rev *types.V2FileContractElement, res types.V2FileContractResolutionType) { + cus = append(cus, v2ContractUpdate(fce, rev, res)) + }) + for _, cu := range cus { + if err := s.updateContract(tx, cru.State.Index, cu.fcid, cu.prev, cu.curr, cu.resolved, cu.valid); err != nil { + return fmt.Errorf("failed to revert v2 contract update: %w", err) + } + } + + return nil +} + +func (s *chainSubscriber) run() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + + for { + select { + case <-s.shutdownCtx.Done(): + return + case <-s.syncSig: + } + + if err := s.sync(); errors.Is(err, errClosed) || errors.Is(err, context.Canceled) { + return + } else if err != nil { + s.logger.Panicf("failed to sync: %v", err) + } + } + }() +} + +func (s *chainSubscriber) sync() error { + start := time.Now() + + // fetch current chain index + index, err := s.cs.ChainIndex(s.shutdownCtx) + if err != 
nil { + return fmt.Errorf("failed to get chain index: %w", err) + } + s.logger.Debugw("sync started", "height", index.Height, "block_id", index.ID) + sheight := index.Height / syncUpdateFrequency + + // fetch updates until we're caught up + var cnt uint64 + for index != s.cm.Tip() && !s.isClosed() { + // fetch updates + istart := time.Now() + crus, caus, err := s.cm.UpdatesSince(index, updatesBatchSize) + if err != nil { + return fmt.Errorf("failed to fetch updates: %w", err) + } + s.logger.Debugw("fetched updates since", "caus", len(caus), "crus", len(crus), "since_height", index.Height, "since_block_id", index.ID, "ms", time.Since(istart).Milliseconds(), "batch_size", updatesBatchSize) + + // process updates + var block types.Block + istart = time.Now() + index, block, err = s.processUpdates(s.shutdownCtx, crus, caus) + if err != nil { + return fmt.Errorf("failed to process updates: %w", err) + } + s.logger.Debugw("processed updates successfully", "new_height", index.Height, "new_block_id", index.ID, "ms", time.Since(istart).Milliseconds()) + cnt++ + + // broadcast consensus update + if utils.IsSynced(block) { + s.wm.BroadcastAction(s.shutdownCtx, webhooks.Event{ + Module: api.ModuleConsensus, + Event: api.EventUpdate, + Payload: api.EventConsensusUpdate{ + ConsensusState: api.ConsensusState{ + BlockHeight: index.Height, + LastBlockTime: api.TimeRFC3339(block.Timestamp), + Synced: true, + }, + TransactionFee: s.cm.RecommendedFee(), + Timestamp: time.Now().UTC(), + }}) + } + } + + s.logger.Debugw("sync completed", "height", index.Height, "block_id", index.ID, "ms", time.Since(start).Milliseconds(), "iterations", cnt) + + // info log sync progress + if index.Height/syncUpdateFrequency != sheight { + s.logger.Infow("sync progress", "height", index.Height, "block_id", index.ID) + } + return nil +} + +func (s *chainSubscriber) processUpdates(ctx context.Context, crus []chain.RevertUpdate, caus []chain.ApplyUpdate) (index types.ChainIndex, tip types.Block, _ error) { + if err := s.cs.ProcessChainUpdate(ctx, func(tx sql.ChainUpdateTx) error { + // process wallet updates + if err := s.wallet.UpdateChainState(tx, crus, caus); err != nil { + return fmt.Errorf("failed to process wallet updates: %w", err) + } + + // process revert updates + for _, cru := range crus { + if err := s.revertChainUpdate(tx, cru); err != nil { + return fmt.Errorf("failed to revert chain update: %w", err) + } + } + + // process apply updates + for _, cau := range caus { + if err := s.applyChainUpdate(tx, cau); err != nil { + return fmt.Errorf("failed to apply chain updates: %w", err) + } + } + + // update chain index + index = caus[len(caus)-1].State.Index + if err := tx.UpdateChainIndex(index); err != nil { + return fmt.Errorf("failed to update chain index: %w", err) + } + + // update failed contracts + if err := tx.UpdateFailedContracts(index.Height); err != nil { + return fmt.Errorf("failed to update failed contracts: %w", err) + } + + tip = caus[len(caus)-1].Block + return nil + }); err != nil { + return types.ChainIndex{}, types.Block{}, err + } + return +} + +func (s *chainSubscriber) updateContract(tx sql.ChainUpdateTx, index types.ChainIndex, fcid types.FileContractID, prev, curr *revision, resolved, valid bool) error { + // sanity check at least one is not nil + if prev == nil && curr == nil { + return errors.New("both prev and curr revisions are nil") // developer error + } + + // ignore unknown contracts + if !s.isKnownContract(fcid) { + return nil + } + + // fetch contract state + state, err := 
tx.ContractState(fcid) + if err != nil && utils.IsErr(err, api.ErrContractNotFound) { + s.updateKnownContracts(fcid, false) // ignore unknown contracts + return nil + } else if err != nil { + return fmt.Errorf("failed to get contract state: %w", err) + } else { + s.updateKnownContracts(fcid, true) // update known contracts + } + + // define a helper function to update the contract state + updateState := func(update api.ContractState) (err error) { + if state != update { + err = tx.UpdateContractState(fcid, update) + if err == nil { + state = update + } + } + return + } + + // handle reverts + if prev != nil { + // update state from 'active' -> 'pending' + if curr == nil { + if err := updateState(api.ContractStatePending); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + } + + // reverted renewal: 'complete' -> 'active' + if curr != nil { + if err := tx.UpdateContract(fcid, index.Height, prev.revisionNumber, prev.fileSize); err != nil { + return fmt.Errorf("failed to revert contract: %w", err) + } + if state == api.ContractStateComplete { + if err := updateState(api.ContractStateActive); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + s.logger.Infow("contract state changed: complete -> active", + "fcid", fcid, + "reason", "final revision reverted") + } + } + + // reverted storage proof: 'complete/failed' -> 'active' + if resolved { + if err := updateState(api.ContractStateActive); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + if valid { + s.logger.Infow("contract state changed: complete -> active", + "fcid", fcid, + "reason", "storage proof reverted") + } else { + s.logger.Infow("contract state changed: failed -> active", + "fcid", fcid, + "reason", "storage proof reverted") + } + } + + return nil + } + + // handle apply + if err := tx.UpdateContract(fcid, index.Height, curr.revisionNumber, curr.fileSize); err != nil { + return fmt.Errorf("failed to update contract %v: %w", fcid, err) + } + + // update state from 'pending' -> 'active' + if state == api.ContractStatePending || state == api.ContractStateUnknown { + if err := updateState(api.ContractStateActive); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + s.logger.Infow("contract state changed: pending -> active", + "fcid", fcid, + "reason", "contract confirmed") + } + + // renewed: 'active' -> 'complete' + if curr.revisionNumber == types.MaxRevisionNumber && curr.fileSize == 0 { + if err := updateState(api.ContractStateComplete); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + s.logger.Infow("contract state changed: active -> complete", + "fcid", fcid, + "reason", "final revision confirmed") + } + + // storage proof: 'active' -> 'complete/failed' + if resolved { + if err := tx.UpdateContractProofHeight(fcid, index.Height); err != nil { + return fmt.Errorf("failed to update contract proof height: %w", err) + } + if valid { + if err := updateState(api.ContractStateComplete); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + s.logger.Infow("contract state changed: active -> complete", + "fcid", fcid, + "reason", "storage proof valid") + } else { + if err := updateState(api.ContractStateFailed); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } + s.logger.Infow("contract state changed: active -> failed", + "fcid", fcid, + "reason", "storage proof missed") + } + } + return nil +} + +func (s 
*chainSubscriber) isClosed() bool { + select { + case <-s.shutdownCtx.Done(): + return true + default: + } + return false +} + +func (s *chainSubscriber) isKnownContract(fcid types.FileContractID) bool { + s.mu.Lock() + defer s.mu.Unlock() + known, ok := s.knownContracts[fcid] + if !ok { + return true // assume known + } + return known +} + +func (s *chainSubscriber) updateKnownContracts(fcid types.FileContractID, known bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.knownContracts[fcid] = known +} + +func v1ContractUpdate(fce types.FileContractElement, rev *types.FileContractElement, resolved, valid bool) contractUpdate { + curr := &revision{ + revisionNumber: fce.FileContract.RevisionNumber, + fileSize: fce.FileContract.Filesize, + } + if rev != nil { + curr.revisionNumber = rev.FileContract.RevisionNumber + curr.fileSize = rev.FileContract.Filesize + } + return contractUpdate{ + fcid: types.FileContractID(fce.ID), + prev: nil, + curr: curr, + resolved: resolved, + valid: valid, + } +} + +func v2ContractUpdate(fce types.V2FileContractElement, rev *types.V2FileContractElement, res types.V2FileContractResolutionType) contractUpdate { + curr := &revision{ + revisionNumber: fce.V2FileContract.RevisionNumber, + fileSize: fce.V2FileContract.Filesize, + } + if rev != nil { + curr.revisionNumber = rev.V2FileContract.RevisionNumber + curr.fileSize = rev.V2FileContract.Filesize + } + + var resolved, valid bool + if res != nil { + resolved = true + switch res.(type) { + case *types.V2FileContractFinalization: + valid = true + case *types.V2FileContractRenewal: + valid = true + case *types.V2StorageProof: + valid = true + case *types.V2FileContractExpiration: + valid = fce.V2FileContract.Filesize == 0 + } + } + + return contractUpdate{ + fcid: types.FileContractID(fce.ID), + prev: nil, + curr: curr, + resolved: resolved, + valid: valid, + } +} diff --git a/bus/contractlocking.go b/internal/bus/contractlocker.go similarity index 89% rename from bus/contractlocking.go rename to internal/bus/contractlocker.go index c94372aa8..dee1d98b0 100644 --- a/bus/contractlocking.go +++ b/internal/bus/contractlocker.go @@ -46,7 +46,7 @@ func (h *lockCandidatePriorityHeap) Pop() interface{} { return x } -type contractLocks struct { +type ContractLocker struct { mu sync.Mutex locks map[types.FileContractID]*contractLock } @@ -65,13 +65,13 @@ type lockCandidate struct { timedOut <-chan struct{} } -func newContractLocks() *contractLocks { - return &contractLocks{ +func NewContractLocker() *ContractLocker { + return &ContractLocker{ locks: make(map[types.FileContractID]*contractLock), } } -func (l *contractLocks) lockForContractID(id types.FileContractID, create bool) *contractLock { +func (l *ContractLocker) lockForContractID(id types.FileContractID, create bool) *contractLock { l.mu.Lock() defer l.mu.Unlock() lock, exists := l.locks[id] @@ -86,7 +86,7 @@ func (l *contractLocks) lockForContractID(id types.FileContractID, create bool) return lock } -func (lock *contractLock) setTimer(l *contractLocks, lockID uint64, id types.FileContractID, d time.Duration) { +func (lock *contractLock) setTimer(l *ContractLocker, lockID uint64, id types.FileContractID, d time.Duration) { lock.wakeupTimer = time.AfterFunc(d, func() { l.Release(id, lockID) }) @@ -112,7 +112,7 @@ func (l *contractLock) stopTimer() { // TODO: Extend this with some sort of priority. e.g. migrations would acquire a // lock with a low priority but contract maintenance would have a very high one // to avoid being starved by low prio tasks. 
-func (l *contractLocks) Acquire(ctx context.Context, priority int, id types.FileContractID, d time.Duration) (uint64, error) { +func (l *ContractLocker) Acquire(ctx context.Context, priority int, id types.FileContractID, d time.Duration) (uint64, error) { lock := l.lockForContractID(id, true) // Prepare a random lockID for ourselves. @@ -156,7 +156,7 @@ func (l *contractLocks) Acquire(ctx context.Context, priority int, id types.File // KeepAlive refreshes the timer on a contract lock for a given contract if the // lockID matches the one on the lock. -func (l *contractLocks) KeepAlive(id types.FileContractID, lockID uint64, d time.Duration) error { +func (l *ContractLocker) KeepAlive(id types.FileContractID, lockID uint64, d time.Duration) error { lock := l.lockForContractID(id, false) if lock == nil { return errors.New("lock not found") @@ -174,7 +174,7 @@ func (l *contractLocks) KeepAlive(id types.FileContractID, lockID uint64, d time } // Release releases the contract lock for a given contract and lock id. -func (l *contractLocks) Release(id types.FileContractID, lockID uint64) error { +func (l *ContractLocker) Release(id types.FileContractID, lockID uint64) error { if lockID == 0 { return errors.New("can't release lock with id 0") } diff --git a/bus/contractlocking_test.go b/internal/bus/contractlocker_test.go similarity index 98% rename from bus/contractlocking_test.go rename to internal/bus/contractlocker_test.go index 120ca9ca2..1fe3b4f64 100644 --- a/bus/contractlocking_test.go +++ b/internal/bus/contractlocker_test.go @@ -13,7 +13,7 @@ import ( // TestContractAcquire is a unit test for contractLocks.Acquire. func TestContractAcquire(t *testing.T) { - locks := newContractLocks() + locks := NewContractLocker() verify := func(fcid types.FileContractID, lockID uint64) { t.Helper() @@ -117,7 +117,7 @@ func TestContractKeepalive(t *testing.T) { t.Parallel() // Create a contractLocks object. - locks := newContractLocks() + locks := NewContractLocker() // Acquire a contract. fcid := types.FileContractID{1} @@ -152,7 +152,7 @@ func TestContractKeepalive(t *testing.T) { // TestContractRelease is a unit test for contractLocks.Release. func TestContractRelease(t *testing.T) { - locks := newContractLocks() + locks := NewContractLocker() verify := func(fcid types.FileContractID, lockID uint64) { t.Helper() diff --git a/internal/bus/pinmanager.go b/internal/bus/pinmanager.go index 21591b21c..c128a8392 100644 --- a/internal/bus/pinmanager.go +++ b/internal/bus/pinmanager.go @@ -11,36 +11,25 @@ import ( "github.com/montanaflynn/stats" "github.com/shopspring/decimal" "go.sia.tech/core/types" + "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" "go.sia.tech/renterd/webhooks" "go.uber.org/zap" ) type ( - // An AutopilotStore stores autopilots. - AutopilotStore interface { + Store interface { Autopilot(ctx context.Context, id string) (api.Autopilot, error) - UpdateAutopilot(ctx context.Context, ap api.Autopilot) error - } - - // PinManager is a service that manages price pinning. - PinManager interface { - Close(context.Context) error - Run(context.Context) error - TriggerUpdate() - } - - // A SettingStore stores settings. 
- SettingStore interface { Setting(ctx context.Context, key string) (string, error) + UpdateAutopilot(ctx context.Context, ap api.Autopilot) error UpdateSetting(ctx context.Context, key, value string) error } ) type ( pinManager struct { - as AutopilotStore - ss SettingStore + a alerts.Alerter + s Store broadcaster webhooks.Broadcaster updateInterval time.Duration @@ -58,13 +47,16 @@ type ( } ) -func NewPinManager(broadcaster webhooks.Broadcaster, as AutopilotStore, ss SettingStore, updateInterval, rateWindow time.Duration, l *zap.Logger) *pinManager { - return &pinManager{ - as: as, - ss: ss, +// NewPinManager returns a new PinManager, responsible for pinning prices to a +// fixed value in an underlying currency. The returned pin manager is already +// running and can be stopped by calling Shutdown. +func NewPinManager(alerts alerts.Alerter, broadcaster webhooks.Broadcaster, s Store, updateInterval, rateWindow time.Duration, l *zap.Logger) *pinManager { + pm := &pinManager{ + a: alerts, + s: s, broadcaster: broadcaster, - logger: l.Sugar().Named("pricemanager"), + logger: l.Named("pricemanager").Sugar(), updateInterval: updateInterval, rateWindow: rateWindow, @@ -72,9 +64,14 @@ func NewPinManager(broadcaster webhooks.Broadcaster, as AutopilotStore, ss Setti triggerChan: make(chan struct{}, 1), closedChan: make(chan struct{}), } + + // start the pin manager + pm.run() + + return pm } -func (pm *pinManager) Close(ctx context.Context) error { +func (pm *pinManager) Shutdown(ctx context.Context) error { close(pm.closedChan) doneChan := make(chan struct{}) @@ -91,43 +88,6 @@ func (pm *pinManager) Close(ctx context.Context) error { } } -func (pm *pinManager) Run(ctx context.Context) error { - // try to update prices - if err := pm.updatePrices(ctx, true); err != nil { - return err - } - - // start the update loop - pm.wg.Add(1) - go func() { - defer pm.wg.Done() - - t := time.NewTicker(pm.updateInterval) - defer t.Stop() - - var forced bool - for { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - err := pm.updatePrices(ctx, forced) - if err != nil { - pm.logger.Warn("failed to update prices", zap.Error(err)) - } - cancel() - - forced = false - select { - case <-pm.closedChan: - return - case <-pm.triggerChan: - forced = true - case <-t.C: - } - } - }() - - return nil -} - func (pm *pinManager) TriggerUpdate() { select { case pm.triggerChan <- struct{}{}: @@ -145,7 +105,7 @@ func (pm *pinManager) averageRate() decimal.Decimal { func (pm *pinManager) pinnedSettings(ctx context.Context) (api.PricePinSettings, error) { var ps api.PricePinSettings - if pss, err := pm.ss.Setting(ctx, api.SettingPricePinning); err != nil { + if pss, err := pm.s.Setting(ctx, api.SettingPricePinning); err != nil { return api.PricePinSettings{}, err } else if err := json.Unmarshal([]byte(pss), &ps); err != nil { pm.logger.Panicf("failed to unmarshal pinned settings '%s': %v", pss, err) @@ -174,7 +134,7 @@ func (pm *pinManager) rateExceedsThreshold(threshold float64) bool { exceeded := delta.GreaterThan(cur.Mul(pct)) // log the result - pm.logger.Debugw("rate exceeds threshold", + pm.logger.Debugw("checking if rate exceeds threshold", "last", cur, "average", avg, "percentage", threshold, @@ -185,10 +145,42 @@ func (pm *pinManager) rateExceedsThreshold(threshold float64) bool { return exceeded } +func (pm *pinManager) run() { + pm.wg.Add(1) + go func() { + defer pm.wg.Done() + + t := time.NewTicker(pm.updateInterval) + defer t.Stop() + + var forced bool + for { + ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Minute) + err := pm.updatePrices(ctx, forced) + if err != nil { + pm.logger.Warn("failed to update prices", zap.Error(err)) + pm.a.RegisterAlert(ctx, newPricePinningFailedAlert(err)) + } else { + pm.a.DismissAlerts(ctx, alertPricePinningID) + } + cancel() + + forced = false + select { + case <-pm.closedChan: + return + case <-pm.triggerChan: + forced = true + case <-t.C: + } + } + }() +} + func (pm *pinManager) updateAutopilotSettings(ctx context.Context, autopilotID string, pins api.AutopilotPins, rate decimal.Decimal) error { var updated bool - ap, err := pm.as.Autopilot(ctx, autopilotID) + ap, err := pm.s.Autopilot(ctx, autopilotID) if err != nil { return err } @@ -225,7 +217,7 @@ func (pm *pinManager) updateAutopilotSettings(ctx context.Context, autopilotID s } // update autopilto - return pm.as.UpdateAutopilot(ctx, ap) + return pm.s.UpdateAutopilot(ctx, ap) } func (pm *pinManager) updateExchangeRates(currency string, rate float64) error { @@ -252,7 +244,7 @@ func (pm *pinManager) updateGougingSettings(ctx context.Context, pins api.Gougin // fetch gouging settings var gs api.GougingSettings - if gss, err := pm.ss.Setting(ctx, api.SettingGouging); err != nil { + if gss, err := pm.s.Setting(ctx, api.SettingGouging); err != nil { return err } else if err := json.Unmarshal([]byte(gss), &gs); err != nil { pm.logger.Panicf("failed to unmarshal gouging settings '%s': %v", gss, err) @@ -271,18 +263,6 @@ func (pm *pinManager) updateGougingSettings(ctx context.Context, pins api.Gougin } } - // update max RPC price - if pins.MaxRPCPrice.IsPinned() { - update, err := convertCurrencyToSC(decimal.NewFromFloat(pins.MaxRPCPrice.Value), rate) - if err != nil { - pm.logger.Warnw("failed to convert max RPC price to currency", zap.Error(err)) - } else if !gs.MaxRPCPrice.Equals(update) { - pm.logger.Infow("updating max RPC price", "old", gs.MaxRPCPrice, "new", update, "rate", rate) - gs.MaxRPCPrice = update - updated = true - } - } - // update max storage price if pins.MaxStorage.IsPinned() { maxStorageCurr, err := convertCurrencyToSC(decimal.NewFromFloat(pins.MaxStorage.Value), rate) @@ -322,7 +302,7 @@ func (pm *pinManager) updateGougingSettings(ctx context.Context, pins api.Gougin // update settings bytes, _ := json.Marshal(gs) - err = pm.ss.UpdateSetting(ctx, api.SettingGouging, string(bytes)) + err = pm.s.UpdateSetting(ctx, api.SettingGouging, string(bytes)) // broadcast event if err == nil { diff --git a/internal/bus/pinmanager_test.go b/internal/bus/pinmanager_test.go index a2af6e137..e5158836d 100644 --- a/internal/bus/pinmanager_test.go +++ b/internal/bus/pinmanager_test.go @@ -14,8 +14,8 @@ import ( "github.com/shopspring/decimal" "go.sia.tech/core/types" "go.sia.tech/hostd/host/settings/pin" + "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" - "go.sia.tech/renterd/build" "go.sia.tech/renterd/webhooks" "go.uber.org/zap" ) @@ -25,6 +25,43 @@ const ( testUpdateInterval = 100 * time.Millisecond ) +type mockAlerter struct { + mu sync.Mutex + alerts []alerts.Alert +} + +func (ma *mockAlerter) Alerts(ctx context.Context, opts alerts.AlertsOpts) (resp alerts.AlertsResponse, err error) { + ma.mu.Lock() + defer ma.mu.Unlock() + return alerts.AlertsResponse{Alerts: ma.alerts}, nil +} + +func (ma *mockAlerter) RegisterAlert(_ context.Context, a alerts.Alert) error { + ma.mu.Lock() + defer ma.mu.Unlock() + for _, alert := range ma.alerts { + if alert.ID == a.ID { + return nil + } + } + ma.alerts = append(ma.alerts, a) + return nil +} + +func (ma 
*mockAlerter) DismissAlerts(_ context.Context, ids ...types.Hash256) error { + ma.mu.Lock() + defer ma.mu.Unlock() + for _, id := range ids { + for i, a := range ma.alerts { + if a.ID == id { + ma.alerts = append(ma.alerts[:i], ma.alerts[i+1:]...) + break + } + } + } + return nil +} + type mockBroadcaster struct { events []webhooks.Event } @@ -37,8 +74,9 @@ func (meb *mockBroadcaster) BroadcastAction(ctx context.Context, e webhooks.Even type mockForexAPI struct { s *httptest.Server - mu sync.Mutex - rate float64 + mu sync.Mutex + rate float64 + unreachable bool } func newTestForexAPI() *mockForexAPI { @@ -46,6 +84,10 @@ func newTestForexAPI() *mockForexAPI { api.s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { api.mu.Lock() defer api.mu.Unlock() + if api.unreachable { + w.WriteHeader(http.StatusInternalServerError) + return + } json.NewEncoder(w).Encode(api.rate) })) return api @@ -55,28 +97,34 @@ func (api *mockForexAPI) Close() { api.s.Close() } -func (api *mockForexAPI) updateRate(rate float64) { +func (api *mockForexAPI) setRate(rate float64) { api.mu.Lock() defer api.mu.Unlock() api.rate = rate } -type mockStore struct { +func (api *mockForexAPI) setUnreachable(unreachable bool) { + api.mu.Lock() + defer api.mu.Unlock() + api.unreachable = unreachable +} + +type mockPinStore struct { mu sync.Mutex settings map[string]string autopilots map[string]api.Autopilot } -func newTestStore() *mockStore { - s := &mockStore{ +func newTestStore() *mockPinStore { + s := &mockPinStore{ autopilots: make(map[string]api.Autopilot), settings: make(map[string]string), } // add default price pin - and gouging settings - b, _ := json.Marshal(build.DefaultPricePinSettings) + b, _ := json.Marshal(api.DefaultPricePinSettings) s.settings[api.SettingPricePinning] = string(b) - b, _ = json.Marshal(build.DefaultGougingSettings) + b, _ = json.Marshal(api.DefaultGougingSettings) s.settings[api.SettingGouging] = string(b) // add default autopilot @@ -92,7 +140,7 @@ func newTestStore() *mockStore { return s } -func (ms *mockStore) gougingSettings() api.GougingSettings { +func (ms *mockPinStore) gougingSettings() api.GougingSettings { val, err := ms.Setting(context.Background(), api.SettingGouging) if err != nil { panic(err) @@ -104,32 +152,32 @@ func (ms *mockStore) gougingSettings() api.GougingSettings { return gs } -func (ms *mockStore) updatPinnedSettings(pps api.PricePinSettings) { +func (ms *mockPinStore) updatPinnedSettings(pps api.PricePinSettings) { b, _ := json.Marshal(pps) ms.UpdateSetting(context.Background(), api.SettingPricePinning, string(b)) time.Sleep(2 * testUpdateInterval) } -func (ms *mockStore) Setting(ctx context.Context, key string) (string, error) { +func (ms *mockPinStore) Setting(ctx context.Context, key string) (string, error) { ms.mu.Lock() defer ms.mu.Unlock() return ms.settings[key], nil } -func (ms *mockStore) UpdateSetting(ctx context.Context, key, value string) error { +func (ms *mockPinStore) UpdateSetting(ctx context.Context, key, value string) error { ms.mu.Lock() defer ms.mu.Unlock() ms.settings[key] = value return nil } -func (ms *mockStore) Autopilot(ctx context.Context, id string) (api.Autopilot, error) { +func (ms *mockPinStore) Autopilot(ctx context.Context, id string) (api.Autopilot, error) { ms.mu.Lock() defer ms.mu.Unlock() return ms.autopilots[id], nil } -func (ms *mockStore) UpdateAutopilot(ctx context.Context, autopilot api.Autopilot) error { +func (ms *mockPinStore) UpdateAutopilot(ctx context.Context, autopilot api.Autopilot) error 
{
 	ms.mu.Lock()
 	defer ms.mu.Unlock()
 	ms.autopilots[autopilot.ID] = autopilot
@@ -140,18 +188,16 @@ func TestPinManager(t *testing.T) {
 	// mock dependencies
 	ms := newTestStore()
 	eb := &mockBroadcaster{}
+	a := &mockAlerter{}
 
 	// mock forex api
 	forex := newTestForexAPI()
 	defer forex.Close()
 
-	// start a pinmanager
-	pm := NewPinManager(eb, ms, ms, testUpdateInterval, time.Minute, zap.NewNop())
-	if err := pm.Run(context.Background()); err != nil {
-		t.Fatal(err)
-	}
+	// create a pinmanager
+	pm := NewPinManager(a, eb, ms, testUpdateInterval, time.Minute, zap.NewNop())
 	defer func() {
-		if err := pm.Close(context.Background()); err != nil {
+		if err := pm.Shutdown(context.Background()); err != nil {
 			t.Fatal(err)
 		}
 	}()
@@ -170,7 +216,7 @@ func TestPinManager(t *testing.T) {
 	}
 
 	// enable price pinning
-	pps := build.DefaultPricePinSettings
+	pps := api.DefaultPricePinSettings
 	pps.Enabled = true
 	pps.Currency = "usd"
 	pps.Threshold = 0.5
@@ -183,12 +229,11 @@ func TestPinManager(t *testing.T) {
 	}
 
 	// update exchange rate and fetch current gouging settings
-	forex.updateRate(2.5)
+	forex.setRate(2.5)
 	gs := ms.gougingSettings()
 
 	// configure all pins but disable them for now
 	pps.GougingSettingsPins.MaxDownload = api.Pin{Value: 3, Pinned: false}
-	pps.GougingSettingsPins.MaxRPCPrice = api.Pin{Value: 3, Pinned: false}
 	pps.GougingSettingsPins.MaxStorage = api.Pin{Value: 3, Pinned: false}
 	pps.GougingSettingsPins.MaxUpload = api.Pin{Value: 3, Pinned: false}
 	ms.updatPinnedSettings(pps)
@@ -214,21 +259,19 @@ func TestPinManager(t *testing.T) {
 
 	// enable the rest of the pins
 	pps.GougingSettingsPins.MaxDownload.Pinned = true
-	pps.GougingSettingsPins.MaxRPCPrice.Pinned = true
 	pps.GougingSettingsPins.MaxStorage.Pinned = true
 	pps.GougingSettingsPins.MaxUpload.Pinned = true
 	ms.updatPinnedSettings(pps)
 
 	// assert they're all updated
 	if gss := ms.gougingSettings(); gss.MaxDownloadPrice.Equals(gs.MaxDownloadPrice) ||
-		gss.MaxRPCPrice.Equals(gs.MaxRPCPrice) ||
 		gss.MaxStoragePrice.Equals(gs.MaxStoragePrice) ||
 		gss.MaxUploadPrice.Equals(gs.MaxUploadPrice) {
 		t.Fatalf("expected gouging settings to be updated, got %v = %v", gss, gs)
 	}
 
 	// increase rate so average isn't catching up to us
-	forex.updateRate(3)
+	forex.setRate(3)
 
 	// fetch autopilot
 	ap, _ := ms.Autopilot(context.Background(), testAutopilotID)
@@ -257,6 +300,26 @@ func TestPinManager(t *testing.T) {
 	if app, _ := ms.Autopilot(context.Background(), testAutopilotID); app.Config.Contracts.Allowance.Equals(ap.Config.Contracts.Allowance) {
 		t.Fatalf("expected autopilot to be updated, got %v = %v", app.Config.Contracts.Allowance, ap.Config.Contracts.Allowance)
 	}
+
+	// make forex API return an error
+	forex.setUnreachable(true)
+
+	// assert alert was registered
+	ms.updatPinnedSettings(pps)
+	res, _ := a.Alerts(context.Background(), alerts.AlertsOpts{})
+	if len(res.Alerts) == 0 {
+		t.Fatalf("expected 1 alert, got %d", len(res.Alerts))
+	}
+
+	// make forex API return a valid response
+	forex.setUnreachable(false)
+
+	// assert alert was dismissed
+	ms.updatPinnedSettings(pps)
+	res, _ = a.Alerts(context.Background(), alerts.AlertsOpts{})
+	if len(res.Alerts) != 0 {
+		t.Fatalf("expected 0 alerts, got %d", len(res.Alerts))
+	}
 }
 
 // TestConvertConvertCurrencyToSC tests the conversion of a currency to Siacoins.
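The lifecycle change exercised in the test above recurs throughout this change: constructors now return components that are already running (the account manager, pin manager, chain subscriber, and wallet metrics recorder), and teardown goes through a single Shutdown(ctx) instead of paired Run/Close calls. Below is a minimal caller-side sketch of that convention using the NewPinManager signature from this diff; the helper name, interval values, and shutdown timeout are illustrative assumptions, not code from this change:

package main

import (
	"context"
	"time"

	"go.sia.tech/renterd/alerts"
	"go.sia.tech/renterd/internal/bus"
	"go.sia.tech/renterd/webhooks"
	"go.uber.org/zap"
)

// startPricePinning is a hypothetical wiring helper; internal/bus is only
// importable from within renterd itself, so real callers live next to the
// other bus plumbing. It shows the contract: the manager runs as soon as the
// constructor returns, and Shutdown stops the loop and waits for it to exit.
func startPricePinning(a alerts.Alerter, b webhooks.Broadcaster, s bus.Store, logger *zap.Logger) func() error {
	// Illustrative intervals: refresh prices every 5 minutes, average
	// exchange rates over a 6 hour window.
	pm := bus.NewPinManager(a, b, s, 5*time.Minute, 6*time.Hour, logger)

	// A settings change can force an immediate refresh instead of waiting
	// for the next tick.
	pm.TriggerUpdate()

	// The returned closure is the teardown hook, mirroring the deferred
	// Shutdown call in the test above.
	return func() error {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		return pm.Shutdown(ctx)
	}
}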
diff --git a/bus/uploadingsectors.go b/internal/bus/sectorscache.go similarity index 55% rename from bus/uploadingsectors.go rename to internal/bus/sectorscache.go index 18c64a7c5..40930ceca 100644 --- a/bus/uploadingsectors.go +++ b/internal/bus/sectorscache.go @@ -18,7 +18,7 @@ const ( ) type ( - uploadingSectorsCache struct { + SectorsCache struct { mu sync.Mutex uploads map[api.UploadID]*ongoingUpload renewedTo map[types.FileContractID]types.FileContractID @@ -30,13 +30,6 @@ type ( } ) -func newUploadingSectorsCache() *uploadingSectorsCache { - return &uploadingSectorsCache{ - uploads: make(map[api.UploadID]*ongoingUpload), - renewedTo: make(map[types.FileContractID]types.FileContractID), - } -} - func (ou *ongoingUpload) addSector(fcid types.FileContractID, root types.Hash256) { ou.contractSectors[fcid] = append(ou.contractSectors[fcid], root) } @@ -48,93 +41,100 @@ func (ou *ongoingUpload) sectors(fcid types.FileContractID) (roots []types.Hash2 return } -func (usc *uploadingSectorsCache) AddSector(uID api.UploadID, fcid types.FileContractID, root types.Hash256) error { - usc.mu.Lock() - defer usc.mu.Unlock() +func NewSectorsCache() *SectorsCache { + return &SectorsCache{ + uploads: make(map[api.UploadID]*ongoingUpload), + renewedTo: make(map[types.FileContractID]types.FileContractID), + } +} + +func (sc *SectorsCache) AddSector(uID api.UploadID, fcid types.FileContractID, root types.Hash256) error { + sc.mu.Lock() + defer sc.mu.Unlock() - ongoing, ok := usc.uploads[uID] + ongoing, ok := sc.uploads[uID] if !ok { return fmt.Errorf("%w; id '%v'", api.ErrUnknownUpload, uID) } - fcid = usc.latestFCID(fcid) + fcid = sc.latestFCID(fcid) ongoing.addSector(fcid, root) return nil } -func (usc *uploadingSectorsCache) FinishUpload(uID api.UploadID) { - usc.mu.Lock() - defer usc.mu.Unlock() - delete(usc.uploads, uID) +func (sc *SectorsCache) FinishUpload(uID api.UploadID) { + sc.mu.Lock() + defer sc.mu.Unlock() + delete(sc.uploads, uID) // prune expired uploads - for uID, ongoing := range usc.uploads { + for uID, ongoing := range sc.uploads { if time.Since(ongoing.started) > cacheExpiry { - delete(usc.uploads, uID) + delete(sc.uploads, uID) } } // prune renewed to map - for old, new := range usc.renewedTo { - if _, exists := usc.renewedTo[new]; exists { - delete(usc.renewedTo, old) + for old, new := range sc.renewedTo { + if _, exists := sc.renewedTo[new]; exists { + delete(sc.renewedTo, old) } } } -func (usc *uploadingSectorsCache) HandleRenewal(fcid, renewedFrom types.FileContractID) { - usc.mu.Lock() - defer usc.mu.Unlock() +func (sc *SectorsCache) HandleRenewal(fcid, renewedFrom types.FileContractID) { + sc.mu.Lock() + defer sc.mu.Unlock() - for _, upload := range usc.uploads { + for _, upload := range sc.uploads { if _, exists := upload.contractSectors[renewedFrom]; exists { upload.contractSectors[fcid] = upload.contractSectors[renewedFrom] upload.contractSectors[renewedFrom] = nil } } - usc.renewedTo[renewedFrom] = fcid + sc.renewedTo[renewedFrom] = fcid } -func (usc *uploadingSectorsCache) Pending(fcid types.FileContractID) (size uint64) { - usc.mu.Lock() - defer usc.mu.Unlock() +func (sc *SectorsCache) Pending(fcid types.FileContractID) (size uint64) { + sc.mu.Lock() + defer sc.mu.Unlock() - fcid = usc.latestFCID(fcid) - for _, ongoing := range usc.uploads { + fcid = sc.latestFCID(fcid) + for _, ongoing := range sc.uploads { size += uint64(len(ongoing.sectors(fcid))) * rhp.SectorSize } return } -func (usc *uploadingSectorsCache) Sectors(fcid types.FileContractID) (roots []types.Hash256) { 
- usc.mu.Lock() - defer usc.mu.Unlock() +func (sc *SectorsCache) Sectors(fcid types.FileContractID) (roots []types.Hash256) { + sc.mu.Lock() + defer sc.mu.Unlock() - fcid = usc.latestFCID(fcid) - for _, ongoing := range usc.uploads { + fcid = sc.latestFCID(fcid) + for _, ongoing := range sc.uploads { roots = append(roots, ongoing.sectors(fcid)...) } return } -func (usc *uploadingSectorsCache) StartUpload(uID api.UploadID) error { - usc.mu.Lock() - defer usc.mu.Unlock() +func (sc *SectorsCache) StartUpload(uID api.UploadID) error { + sc.mu.Lock() + defer sc.mu.Unlock() // check if upload already exists - if _, exists := usc.uploads[uID]; exists { + if _, exists := sc.uploads[uID]; exists { return fmt.Errorf("%w; id '%v'", api.ErrUploadAlreadyExists, uID) } - usc.uploads[uID] = &ongoingUpload{ + sc.uploads[uID] = &ongoingUpload{ started: time.Now(), contractSectors: make(map[types.FileContractID][]types.Hash256), } return nil } -func (usc *uploadingSectorsCache) latestFCID(fcid types.FileContractID) types.FileContractID { - if latest, ok := usc.renewedTo[fcid]; ok { +func (um *SectorsCache) latestFCID(fcid types.FileContractID) types.FileContractID { + if latest, ok := um.renewedTo[fcid]; ok { return latest } return fcid diff --git a/internal/bus/sectorscache_test.go b/internal/bus/sectorscache_test.go new file mode 100644 index 000000000..36c2e4231 --- /dev/null +++ b/internal/bus/sectorscache_test.go @@ -0,0 +1,120 @@ +package bus + +import ( + "errors" + "testing" + + rhpv2 "go.sia.tech/core/rhp/v2" + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "lukechampine.com/frand" +) + +func TestUploadingSectorsCache(t *testing.T) { + sc := NewSectorsCache() + + uID1 := newTestUploadID() + uID2 := newTestUploadID() + + fcid1 := types.FileContractID{1} + fcid2 := types.FileContractID{2} + fcid3 := types.FileContractID{3} + + sc.StartUpload(uID1) + sc.StartUpload(uID2) + + _ = sc.AddSector(uID1, fcid1, types.Hash256{1}) + _ = sc.AddSector(uID1, fcid2, types.Hash256{2}) + _ = sc.AddSector(uID2, fcid2, types.Hash256{3}) + + if roots1 := sc.Sectors(fcid1); len(roots1) != 1 || roots1[0] != (types.Hash256{1}) { + t.Fatal("unexpected cached sectors") + } + if roots2 := sc.Sectors(fcid2); len(roots2) != 2 { + t.Fatal("unexpected cached sectors", roots2) + } + if roots3 := sc.Sectors(fcid3); len(roots3) != 0 { + t.Fatal("unexpected cached sectors") + } + + if o1, exists := sc.uploads[uID1]; !exists || o1.started.IsZero() { + t.Fatal("unexpected") + } + if o2, exists := sc.uploads[uID2]; !exists || o2.started.IsZero() { + t.Fatal("unexpected") + } + + sc.FinishUpload(uID1) + if roots1 := sc.Sectors(fcid1); len(roots1) != 0 { + t.Fatal("unexpected cached sectors") + } + if roots2 := sc.Sectors(fcid2); len(roots2) != 1 || roots2[0] != (types.Hash256{3}) { + t.Fatal("unexpected cached sectors") + } + + sc.FinishUpload(uID2) + if roots2 := sc.Sectors(fcid1); len(roots2) != 0 { + t.Fatal("unexpected cached sectors") + } + + if err := sc.AddSector(uID1, fcid1, types.Hash256{1}); !errors.Is(err, api.ErrUnknownUpload) { + t.Fatal("unexpected error", err) + } + if err := sc.StartUpload(uID1); err != nil { + t.Fatal("unexpected error", err) + } + if err := sc.StartUpload(uID1); !errors.Is(err, api.ErrUploadAlreadyExists) { + t.Fatal("unexpected error", err) + } + + // reset cache + sc = NewSectorsCache() + + // track upload that uploads across two contracts + sc.StartUpload(uID1) + sc.AddSector(uID1, fcid1, types.Hash256{1}) + sc.AddSector(uID1, fcid1, types.Hash256{2}) + sc.HandleRenewal(fcid2, fcid1) + 
+	sc.AddSector(uID1, fcid2, types.Hash256{3})
+	sc.AddSector(uID1, fcid2, types.Hash256{4})
+
+	// assert the pending size for both contracts is 4 sectors
+	p1 := sc.Pending(fcid1)
+	p2 := sc.Pending(fcid2)
+	if p1 != p2 || p1 != 4*rhpv2.SectorSize {
+		t.Fatal("unexpected pending size", p1/rhpv2.SectorSize, p2/rhpv2.SectorSize)
+	}
+
+	// assert both contracts resolve to the same 4 cached sectors
+	s1 := sc.Sectors(fcid1)
+	s2 := sc.Sectors(fcid2)
+	if len(s1) != 4 || len(s2) != 4 {
+		t.Fatal("unexpected sectors", len(s1), len(s2))
+	}
+
+	// finish upload
+	sc.FinishUpload(uID1)
+	s1 = sc.Sectors(fcid1)
+	s2 = sc.Sectors(fcid2)
+	if len(s1) != 0 || len(s2) != 0 {
+		t.Fatal("unexpected sectors", len(s1), len(s2))
+	}
+
+	// renew the contract
+	sc.HandleRenewal(fcid3, fcid2)
+
+	// trigger pruning
+	sc.StartUpload(uID2)
+	sc.FinishUpload(uID2)
+
+	// assert renewedTo gets pruned
+	if len(sc.renewedTo) != 1 {
+		t.Fatal("unexpected", len(sc.renewedTo))
+	}
+}
+
+func newTestUploadID() api.UploadID {
+	var uID api.UploadID
+	frand.Read(uID[:])
+	return uID
+}
diff --git a/internal/bus/walletmetricsrecorder.go b/internal/bus/walletmetricsrecorder.go
new file mode 100644
index 000000000..4d3205043
--- /dev/null
+++ b/internal/bus/walletmetricsrecorder.go
@@ -0,0 +1,104 @@
+package bus
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"go.sia.tech/coreutils/wallet"
+	"go.sia.tech/renterd/api"
+	"go.uber.org/zap"
+)
+
+type (
+	WalletMetricsRecorder struct {
+		store  MetricsStore
+		wallet WalletBalance
+
+		shutdownChan chan struct{}
+		wg           sync.WaitGroup
+
+		logger *zap.SugaredLogger
+	}
+
+	MetricsStore interface {
+		RecordWalletMetric(ctx context.Context, metrics ...api.WalletMetric) error
+	}
+
+	WalletBalance interface {
+		Balance() (wallet.Balance, error)
+	}
+)
+
+// NewWalletMetricRecorder returns a recorder that periodically records wallet
+// metrics. The returned recorder is already running and can be stopped by
+// calling Shutdown.
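+//
+// A minimal usage sketch, assuming a store and wallet that satisfy the
+// interfaces above:
+//
+//	recorder := NewWalletMetricRecorder(store, wallet, 5*time.Minute, logger)
+//	defer recorder.Shutdown(context.Background())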
+func NewWalletMetricRecorder(store MetricsStore, wallet WalletBalance, interval time.Duration, logger *zap.Logger) *WalletMetricsRecorder { + logger = logger.Named("walletmetricsrecorder") + recorder := &WalletMetricsRecorder{ + store: store, + wallet: wallet, + shutdownChan: make(chan struct{}), + logger: logger.Sugar(), + } + recorder.run(interval) + return recorder +} + +func (wmr *WalletMetricsRecorder) run(interval time.Duration) { + wmr.wg.Add(1) + go func() { + defer wmr.wg.Done() + + t := time.NewTicker(interval) + defer t.Stop() + + for { + balance, err := wmr.wallet.Balance() + if err != nil { + wmr.logger.Error("failed to get wallet balance", zap.Error(err)) + } else { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + if err = wmr.store.RecordWalletMetric(ctx, api.WalletMetric{ + Timestamp: api.TimeRFC3339(time.Now().UTC()), + Spendable: balance.Spendable, + Confirmed: balance.Confirmed, + Unconfirmed: balance.Unconfirmed, + Immature: balance.Immature, + }); err != nil { + wmr.logger.Error("failed to record wallet metric", zap.Error(err)) + } else { + wmr.logger.Debugw("successfully recorded wallet metrics", + zap.Stringer("spendable", balance.Spendable), + zap.Stringer("confirmed", balance.Confirmed), + zap.Stringer("unconfirmed", balance.Unconfirmed), + zap.Stringer("immature", balance.Immature)) + } + cancel() + } + + select { + case <-wmr.shutdownChan: + return + case <-t.C: + } + } + }() +} + +func (wmr *WalletMetricsRecorder) Shutdown(ctx context.Context) error { + close(wmr.shutdownChan) + + waitChan := make(chan struct{}) + go func() { + wmr.wg.Wait() + close(waitChan) + }() + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-waitChan: + return nil + } +} diff --git a/internal/gouging/gouging.go b/internal/gouging/gouging.go new file mode 100644 index 000000000..8e729247d --- /dev/null +++ b/internal/gouging/gouging.go @@ -0,0 +1,492 @@ +package gouging + +import ( + "context" + "errors" + "fmt" + "time" + + rhpv2 "go.sia.tech/core/rhp/v2" + rhpv3 "go.sia.tech/core/rhp/v3" + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" +) + +const ( + bytesPerTB = 1e12 + + // maxBaseRPCPriceVsBandwidth is the max ratio for sane pricing between the + // MinBaseRPCPrice and the MinDownloadBandwidthPrice. This ensures that 1 + // million base RPC charges are at most 1% of the cost to download 4TB. This + // ratio should be used by checking that the MinBaseRPCPrice is less than or + // equal to the MinDownloadBandwidthPrice multiplied by this constant + maxBaseRPCPriceVsBandwidth = uint64(40e3) + + // maxSectorAccessPriceVsBandwidth is the max ratio for sane pricing between + // the MinSectorAccessPrice and the MinDownloadBandwidthPrice. This ensures + // that 1 million base accesses are at most 10% of the cost to download 4TB. 
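+	// (At 4MiB per sector, downloading 4TB takes roughly one million sector
+	// accesses, so the bound works out to 0.1 * 4e12 / 1e6 = 4e5 times the
+	// download bandwidth price.)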
+ // This ratio should be used by checking that the MinSectorAccessPrice is + // less than or equal to the MinDownloadBandwidthPrice multiplied by this + // constant + maxSectorAccessPriceVsBandwidth = uint64(400e3) +) + +var ( + errHostSettingsGouging = errors.New("host settings gouging detected") + ErrPriceTableGouging = errors.New("price table gouging detected") +) + +type ( + ConsensusState interface { + ConsensusState(ctx context.Context) (api.ConsensusState, error) + } + + Checker interface { + Check(_ *rhpv2.HostSettings, _ *rhpv3.HostPriceTable) api.HostGougingBreakdown + CheckSettings(rhpv2.HostSettings) api.HostGougingBreakdown + CheckUnusedDefaults(rhpv3.HostPriceTable) error + BlocksUntilBlockHeightGouging(hostHeight uint64) int64 + } + + checker struct { + consensusState api.ConsensusState + settings api.GougingSettings + txFee types.Currency + + period *uint64 + renewWindow *uint64 + } +) + +var _ Checker = checker{} + +func NewChecker(gs api.GougingSettings, cs api.ConsensusState, txnFee types.Currency, period, renewWindow *uint64) Checker { + return checker{ + consensusState: cs, + settings: gs, + txFee: txnFee, + + period: period, + renewWindow: renewWindow, + } +} + +func (gc checker) BlocksUntilBlockHeightGouging(hostHeight uint64) int64 { + blockHeight := gc.consensusState.BlockHeight + leeway := gc.settings.HostBlockHeightLeeway + var minHeight uint64 + if blockHeight >= uint64(leeway) { + minHeight = blockHeight - uint64(leeway) + } + return int64(hostHeight) - int64(minHeight) +} + +func (gc checker) Check(hs *rhpv2.HostSettings, pt *rhpv3.HostPriceTable) api.HostGougingBreakdown { + if hs == nil && pt == nil { + panic("gouging checker needs to be provided with at least host settings or a price table") // developer error + } + + return api.HostGougingBreakdown{ + ContractErr: errsToStr( + checkContractGougingRHPv2(gc.period, gc.renewWindow, hs), + checkContractGougingRHPv3(gc.period, gc.renewWindow, pt), + ), + DownloadErr: errsToStr(checkDownloadGougingRHPv3(gc.settings, pt)), + GougingErr: errsToStr( + checkPriceGougingPT(gc.settings, gc.consensusState, gc.txFee, pt), + checkPriceGougingHS(gc.settings, hs), + ), + PruneErr: errsToStr(checkPruneGougingRHPv2(gc.settings, hs)), + UploadErr: errsToStr(checkUploadGougingRHPv3(gc.settings, pt)), + } +} + +func (gc checker) CheckSettings(hs rhpv2.HostSettings) api.HostGougingBreakdown { + return gc.Check(&hs, nil) +} + +func (gc checker) CheckUnusedDefaults(pt rhpv3.HostPriceTable) error { + return checkUnusedDefaults(pt) +} + +func checkPriceGougingHS(gs api.GougingSettings, hs *rhpv2.HostSettings) error { + // check if we have settings + if hs == nil { + return nil + } + // check base rpc price + if !gs.MaxRPCPrice.IsZero() && hs.BaseRPCPrice.Cmp(gs.MaxRPCPrice) > 0 { + return fmt.Errorf("rpc price exceeds max: %v > %v", hs.BaseRPCPrice, gs.MaxRPCPrice) + } + maxBaseRPCPrice := hs.DownloadBandwidthPrice.Mul64(maxBaseRPCPriceVsBandwidth) + if hs.BaseRPCPrice.Cmp(maxBaseRPCPrice) > 0 { + return fmt.Errorf("rpc price too high, %v > %v", hs.BaseRPCPrice, maxBaseRPCPrice) + } + + // check sector access price + if hs.DownloadBandwidthPrice.IsZero() { + hs.DownloadBandwidthPrice = types.NewCurrency64(1) + } + maxSectorAccessPrice := hs.DownloadBandwidthPrice.Mul64(maxSectorAccessPriceVsBandwidth) + if hs.SectorAccessPrice.Cmp(maxSectorAccessPrice) > 0 { + return fmt.Errorf("sector access price too high, %v > %v", hs.SectorAccessPrice, maxSectorAccessPrice) + } + + // check max storage price + if !gs.MaxStoragePrice.IsZero() && 
hs.StoragePrice.Cmp(gs.MaxStoragePrice) > 0 { + return fmt.Errorf("storage price exceeds max: %v > %v", hs.StoragePrice, gs.MaxStoragePrice) + } + + // check contract price + if !gs.MaxContractPrice.IsZero() && hs.ContractPrice.Cmp(gs.MaxContractPrice) > 0 { + return fmt.Errorf("contract price exceeds max: %v > %v", hs.ContractPrice, gs.MaxContractPrice) + } + + // check max EA balance + if hs.MaxEphemeralAccountBalance.Cmp(gs.MinMaxEphemeralAccountBalance) < 0 { + return fmt.Errorf("'MaxEphemeralAccountBalance' is less than the allowed minimum value, %v < %v", hs.MaxEphemeralAccountBalance, gs.MinMaxEphemeralAccountBalance) + } + + // check EA expiry + if hs.EphemeralAccountExpiry < gs.MinAccountExpiry { + return fmt.Errorf("'EphemeralAccountExpiry' is less than the allowed minimum value, %v < %v", hs.EphemeralAccountExpiry, gs.MinAccountExpiry) + } + + return nil +} + +// TODO: if we ever stop assuming that certain prices in the pricetable are +// always set to 1H we should account for those fields in +// `hostPeriodCostForScore` as well. +func checkPriceGougingPT(gs api.GougingSettings, cs api.ConsensusState, txnFee types.Currency, pt *rhpv3.HostPriceTable) error { + // check if we have a price table + if pt == nil { + return nil + } + + // check unused defaults + if err := checkUnusedDefaults(*pt); err != nil { + return err + } + + // check base rpc price + if !gs.MaxRPCPrice.IsZero() && gs.MaxRPCPrice.Cmp(pt.InitBaseCost) < 0 { + return fmt.Errorf("init base cost exceeds max: %v > %v", pt.InitBaseCost, gs.MaxRPCPrice) + } + + // check contract price + if !gs.MaxContractPrice.IsZero() && pt.ContractPrice.Cmp(gs.MaxContractPrice) > 0 { + return fmt.Errorf("contract price exceeds max: %v > %v", pt.ContractPrice, gs.MaxContractPrice) + } + + // check max storage + if !gs.MaxStoragePrice.IsZero() && pt.WriteStoreCost.Cmp(gs.MaxStoragePrice) > 0 { + return fmt.Errorf("storage price exceeds max: %v > %v", pt.WriteStoreCost, gs.MaxStoragePrice) + } + + // check max collateral + if pt.MaxCollateral.IsZero() { + return errors.New("MaxCollateral of host is 0") + } + + // check LatestRevisionCost - expect sane value + maxRevisionCost, overflow := gs.MaxRPCPrice.AddWithOverflow(gs.MaxDownloadPrice.Div64(bytesPerTB).Mul64(2048)) + if overflow { + maxRevisionCost = types.MaxCurrency + } + if pt.LatestRevisionCost.Cmp(maxRevisionCost) > 0 { + return fmt.Errorf("LatestRevisionCost of %v exceeds maximum cost of %v", pt.LatestRevisionCost, maxRevisionCost) + } + + // check block height - if too much time has passed since the last block + // there is a chance we are not up-to-date anymore. So we only check whether + // the host's height is at least equal to ours. 
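+	// For example, with a block height of 1000 and a leeway of 6 blocks, a
+	// synced renter accepts host heights in [994, 1006], while an unsynced
+	// one only rejects hosts reporting a height below 1000.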
+	if !cs.Synced || time.Since(cs.LastBlockTime.Std()) > time.Hour {
+		if pt.HostBlockHeight < cs.BlockHeight {
+			return fmt.Errorf("consensus not synced and host block height is lower, %v < %v", pt.HostBlockHeight, cs.BlockHeight)
+		}
+	} else {
+		var minHeight uint64
+		if cs.BlockHeight >= uint64(gs.HostBlockHeightLeeway) {
+			minHeight = cs.BlockHeight - uint64(gs.HostBlockHeightLeeway)
+		}
+		maxHeight := cs.BlockHeight + uint64(gs.HostBlockHeightLeeway)
+		if !(minHeight <= pt.HostBlockHeight && pt.HostBlockHeight <= maxHeight) {
+			return fmt.Errorf("consensus is synced and host block height is not within range, %v-%v %v", minHeight, maxHeight, pt.HostBlockHeight)
+		}
+	}
+
+	// check TxnFeeMaxRecommended - expect at most a multiple of our fee
+	if !txnFee.IsZero() && pt.TxnFeeMaxRecommended.Cmp(txnFee.Mul64(5)) > 0 {
+		return fmt.Errorf("TxnFeeMaxRecommended %v exceeds %v", pt.TxnFeeMaxRecommended, txnFee.Mul64(5))
+	}
+
+	// check TxnFeeMinRecommended - expect it to be lower than or equal to the max
+	if pt.TxnFeeMinRecommended.Cmp(pt.TxnFeeMaxRecommended) > 0 {
+		return fmt.Errorf("TxnFeeMinRecommended is greater than TxnFeeMaxRecommended, %v > %v", pt.TxnFeeMinRecommended, pt.TxnFeeMaxRecommended)
+	}
+
+	// check Validity
+	if pt.Validity < gs.MinPriceTableValidity {
+		return fmt.Errorf("'Validity' is less than the allowed minimum value, %v < %v", pt.Validity, gs.MinPriceTableValidity)
+	}
+
+	return nil
+}
+
+func checkContractGougingRHPv2(period, renewWindow *uint64, hs *rhpv2.HostSettings) (err error) {
+	// period and renew window might be nil since we don't always have access to
+	// these settings when performing gouging checks
+	if hs == nil || period == nil || renewWindow == nil {
+		return nil
+	}
+
+	err = checkContractGouging(*period, *renewWindow, hs.MaxDuration, hs.WindowSize)
+	if err != nil {
+		err = fmt.Errorf("%w: %v", errHostSettingsGouging, err)
+	}
+	return
+}
+
+func checkContractGougingRHPv3(period, renewWindow *uint64, pt *rhpv3.HostPriceTable) (err error) {
+	// period and renew window might be nil since we don't always have access to
+	// these settings when performing gouging checks
+	if pt == nil || period == nil || renewWindow == nil {
+		return nil
+	}
+	err = checkContractGouging(*period, *renewWindow, pt.MaxDuration, pt.WindowSize)
+	if err != nil {
+		err = fmt.Errorf("%w: %v", ErrPriceTableGouging, err)
+	}
+	return
+}
+
+func checkContractGouging(period, renewWindow, maxDuration, windowSize uint64) error {
+	// check MaxDuration
+	if period != 0 && period > maxDuration {
+		return fmt.Errorf("MaxDuration %v is lower than the period %v", maxDuration, period)
+	}
+
+	// check WindowSize
+	if renewWindow != 0 && renewWindow < windowSize {
+		return fmt.Errorf("minimum WindowSize %v is greater than the renew window %v", windowSize, renewWindow)
+	}
+
+	return nil
+}
+
+func checkPruneGougingRHPv2(gs api.GougingSettings, hs *rhpv2.HostSettings) error {
+	if hs == nil {
+		return nil
+	}
+	// pruning costs are similar to sector read costs because both include base
+	// costs and download bandwidth costs; to avoid re-adding all RHPv2 cost
+	// calculations we reuse the download gouging checks to cover pruning
+	sectorDownloadPrice, overflow := sectorReadCost(
+		types.NewCurrency64(1), // 1H
+		hs.SectorAccessPrice,
+		hs.BaseRPCPrice,
+		hs.DownloadBandwidthPrice,
+		hs.UploadBandwidthPrice,
+	)
+	if overflow {
+		return fmt.Errorf("%w: overflow detected when computing sector download price", errHostSettingsGouging)
+	}
+	dpptb, overflow := sectorDownloadPrice.Mul64WithOverflow(uint64(bytesPerTB) / rhpv2.SectorSize) // sectors per TB
+	if overflow {
+		return fmt.Errorf("%w: overflow detected when computing download price per TB", errHostSettingsGouging)
+	}
+	if !gs.MaxDownloadPrice.IsZero() && dpptb.Cmp(gs.MaxDownloadPrice) > 0 {
+		return fmt.Errorf("%w: cost per TB exceeds max dl price: %v > %v", errHostSettingsGouging, dpptb, gs.MaxDownloadPrice)
+	}
+	return nil
+}
+
+func checkDownloadGougingRHPv3(gs api.GougingSettings, pt *rhpv3.HostPriceTable) error {
+	if pt == nil {
+		return nil
+	}
+	sectorDownloadPrice, overflow := sectorReadCostRHPv3(*pt)
+	if overflow {
+		return fmt.Errorf("%w: overflow detected when computing sector download price", ErrPriceTableGouging)
+	}
+	dpptb, overflow := sectorDownloadPrice.Mul64WithOverflow(uint64(bytesPerTB) / rhpv2.SectorSize) // sectors per TB
+	if overflow {
+		return fmt.Errorf("%w: overflow detected when computing download price per TB", ErrPriceTableGouging)
+	}
+	if !gs.MaxDownloadPrice.IsZero() && dpptb.Cmp(gs.MaxDownloadPrice) > 0 {
+		return fmt.Errorf("%w: cost per TB exceeds max dl price: %v > %v", ErrPriceTableGouging, dpptb, gs.MaxDownloadPrice)
+	}
+	return nil
+}
+
+func checkUploadGougingRHPv3(gs api.GougingSettings, pt *rhpv3.HostPriceTable) error {
+	if pt == nil {
+		return nil
+	}
+	sectorUploadPrice, overflow := sectorUploadCostRHPv3(*pt)
+	if overflow {
+		return fmt.Errorf("%w: overflow detected when computing sector price", ErrPriceTableGouging)
+	}
+	uploadPrice, overflow := sectorUploadPrice.Mul64WithOverflow(uint64(bytesPerTB) / rhpv2.SectorSize) // sectors per TB
+	if overflow {
+		return fmt.Errorf("%w: overflow detected when computing upload price per TB", ErrPriceTableGouging)
+	}
+	if !gs.MaxUploadPrice.IsZero() && uploadPrice.Cmp(gs.MaxUploadPrice) > 0 {
+		return fmt.Errorf("%w: cost per TB exceeds max ul price: %v > %v", ErrPriceTableGouging, uploadPrice, gs.MaxUploadPrice)
+	}
+	return nil
+}
+
+func checkUnusedDefaults(pt rhpv3.HostPriceTable) error {
+	// check ReadLengthCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.ReadLengthCost) < 0 {
+		return fmt.Errorf("ReadLengthCost of host is %v but should be %v", pt.ReadLengthCost, types.NewCurrency64(1))
+	}
+
+	// check WriteLengthCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.WriteLengthCost) < 0 {
+		return fmt.Errorf("WriteLengthCost of %v exceeds 1H", pt.WriteLengthCost)
+	}
+
+	// check AccountBalanceCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.AccountBalanceCost) < 0 {
+		return fmt.Errorf("AccountBalanceCost of %v exceeds 1H", pt.AccountBalanceCost)
+	}
+
+	// check FundAccountCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.FundAccountCost) < 0 {
+		return fmt.Errorf("FundAccountCost of %v exceeds 1H", pt.FundAccountCost)
+	}
+
+	// check UpdatePriceTableCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.UpdatePriceTableCost) < 0 {
+		return fmt.Errorf("UpdatePriceTableCost of %v exceeds 1H", pt.UpdatePriceTableCost)
+	}
+
+	// check HasSectorBaseCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.HasSectorBaseCost) < 0 {
+		return fmt.Errorf("HasSectorBaseCost of %v exceeds 1H", pt.HasSectorBaseCost)
+	}
+
+	// check MemoryTimeCost - should be 1H as it's unused by hosts
+	if types.NewCurrency64(1).Cmp(pt.MemoryTimeCost) < 0 {
+		return fmt.Errorf("MemoryTimeCost of %v exceeds 1H", pt.MemoryTimeCost)
} + + // check DropSectorsBaseCost - should be 1H as it's unused by hosts + if types.NewCurrency64(1).Cmp(pt.DropSectorsBaseCost) < 0 { + return fmt.Errorf("DropSectorsBaseCost of %v exceeds 1H", pt.DropSectorsBaseCost) + } + + // check DropSectorsUnitCost - should be 1H as it's unused by hosts + if types.NewCurrency64(1).Cmp(pt.DropSectorsUnitCost) < 0 { + return fmt.Errorf("DropSectorsUnitCost of %v exceeds 1H", pt.DropSectorsUnitCost) + } + + // check SwapSectorBaseCost - should be 1H as it's unused by hosts + if types.NewCurrency64(1).Cmp(pt.SwapSectorBaseCost) < 0 { + return fmt.Errorf("SwapSectorBaseCost of %v exceeds 1H", pt.SwapSectorBaseCost) + } + + // check SubscriptionMemoryCost - expect 1H default + if types.NewCurrency64(1).Cmp(pt.SubscriptionMemoryCost) < 0 { + return fmt.Errorf("SubscriptionMemoryCost of %v exceeds 1H", pt.SubscriptionMemoryCost) + } + + // check SubscriptionNotificationCost - expect 1H default + if types.NewCurrency64(1).Cmp(pt.SubscriptionNotificationCost) < 0 { + return fmt.Errorf("SubscriptionNotificationCost of %v exceeds 1H", pt.SubscriptionNotificationCost) + } + + // check RenewContractCost - expect 100nS default + if types.Siacoins(1).Mul64(100).Div64(1e9).Cmp(pt.RenewContractCost) < 0 { + return fmt.Errorf("RenewContractCost of %v exceeds 100nS", pt.RenewContractCost) + } + + // check RevisionBaseCost - expect 0H default + if types.ZeroCurrency.Cmp(pt.RevisionBaseCost) < 0 { + return fmt.Errorf("RevisionBaseCost of %v exceeds 0H", pt.RevisionBaseCost) + } + + return nil +} + +func sectorReadCostRHPv3(pt rhpv3.HostPriceTable) (types.Currency, bool) { + return sectorReadCost( + pt.ReadLengthCost, + pt.ReadBaseCost, + pt.InitBaseCost, + pt.UploadBandwidthCost, + pt.DownloadBandwidthCost, + ) +} + +func sectorReadCost(readLengthCost, readBaseCost, initBaseCost, ulBWCost, dlBWCost types.Currency) (types.Currency, bool) { + // base + base, overflow := readLengthCost.Mul64WithOverflow(rhpv2.SectorSize) + if overflow { + return types.ZeroCurrency, true + } + base, overflow = base.AddWithOverflow(readBaseCost) + if overflow { + return types.ZeroCurrency, true + } + base, overflow = base.AddWithOverflow(initBaseCost) + if overflow { + return types.ZeroCurrency, true + } + // bandwidth + ingress, overflow := ulBWCost.Mul64WithOverflow(32) + if overflow { + return types.ZeroCurrency, true + } + egress, overflow := dlBWCost.Mul64WithOverflow(rhpv2.SectorSize) + if overflow { + return types.ZeroCurrency, true + } + // total + total, overflow := base.AddWithOverflow(ingress) + if overflow { + return types.ZeroCurrency, true + } + total, overflow = total.AddWithOverflow(egress) + if overflow { + return types.ZeroCurrency, true + } + return total, false +} + +func sectorUploadCostRHPv3(pt rhpv3.HostPriceTable) (types.Currency, bool) { + // write + writeCost, overflow := pt.WriteLengthCost.Mul64WithOverflow(rhpv2.SectorSize) + if overflow { + return types.ZeroCurrency, true + } + writeCost, overflow = writeCost.AddWithOverflow(pt.WriteBaseCost) + if overflow { + return types.ZeroCurrency, true + } + writeCost, overflow = writeCost.AddWithOverflow(pt.InitBaseCost) + if overflow { + return types.ZeroCurrency, true + } + // bandwidth + ingress, overflow := pt.UploadBandwidthCost.Mul64WithOverflow(rhpv2.SectorSize) + if overflow { + return types.ZeroCurrency, true + } + // total + total, overflow := writeCost.AddWithOverflow(ingress) + if overflow { + return types.ZeroCurrency, true + } + return total, false +} + +func errsToStr(errs ...error) string { + if err := 
errors.Join(errs...); err != nil { + return err.Error() + } + return "" +} diff --git a/internal/node/chainmanager.go b/internal/node/chainmanager.go deleted file mode 100644 index 6eaf91a53..000000000 --- a/internal/node/chainmanager.go +++ /dev/null @@ -1,173 +0,0 @@ -package node - -import ( - "errors" - "fmt" - "strings" - "sync" - "time" - - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/renterd/bus" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" -) - -const ( - maxSyncTime = time.Hour -) - -var ( - ErrBlockNotFound = errors.New("block not found") - ErrInvalidChangeID = errors.New("invalid change id") -) - -type chainManager struct { - cs modules.ConsensusSet - tp bus.TransactionPool - network *consensus.Network - - close chan struct{} - mu sync.Mutex - lastBlockTime time.Time - tip consensus.State - synced bool -} - -// ProcessConsensusChange implements the modules.ConsensusSetSubscriber interface. -func (m *chainManager) ProcessConsensusChange(cc modules.ConsensusChange) { - m.mu.Lock() - defer m.mu.Unlock() - - b := cc.AppliedBlocks[len(cc.AppliedBlocks)-1] - m.tip = consensus.State{ - Network: m.network, - Index: types.ChainIndex{ - ID: types.BlockID(b.ID()), - Height: uint64(cc.BlockHeight), - }, - } - m.synced = synced(b.Timestamp) - m.lastBlockTime = time.Unix(int64(b.Timestamp), 0) -} - -// Network returns the network name. -func (m *chainManager) Network() string { - switch m.network.Name { - case "zen": - return "Zen Testnet" - case "mainnet": - return "Mainnet" - default: - return m.network.Name - } -} - -// Close closes the chain manager. -func (m *chainManager) Close() error { - select { - case <-m.close: - return nil - default: - } - close(m.close) - return m.cs.Close() -} - -// Synced returns true if the chain manager is synced with the consensus set. -func (m *chainManager) Synced() bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.synced -} - -// BlockAtHeight returns the block at the given height. -func (m *chainManager) BlockAtHeight(height uint64) (types.Block, bool) { - sb, ok := m.cs.BlockAtHeight(stypes.BlockHeight(height)) - var c types.Block - convertToCore(sb, (*types.V1Block)(&c)) - return types.Block(c), ok -} - -func (m *chainManager) LastBlockTime() time.Time { - m.mu.Lock() - defer m.mu.Unlock() - return m.lastBlockTime -} - -// IndexAtHeight return the chain index at the given height. -func (m *chainManager) IndexAtHeight(height uint64) (types.ChainIndex, error) { - block, ok := m.cs.BlockAtHeight(stypes.BlockHeight(height)) - if !ok { - return types.ChainIndex{}, ErrBlockNotFound - } - return types.ChainIndex{ - ID: types.BlockID(block.ID()), - Height: height, - }, nil -} - -// TipState returns the current chain state. -func (m *chainManager) TipState() consensus.State { - m.mu.Lock() - defer m.mu.Unlock() - return m.tip -} - -// AcceptBlock adds b to the consensus set. -func (m *chainManager) AcceptBlock(b types.Block) error { - var sb stypes.Block - convertToSiad(types.V1Block(b), &sb) - return m.cs.AcceptBlock(sb) -} - -// Subscribe subscribes to the consensus set. 
-func (m *chainManager) Subscribe(s modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error { - if err := m.cs.ConsensusSetSubscribe(s, ccID, cancel); err != nil { - if strings.Contains(err.Error(), "consensus subscription has invalid id") { - return ErrInvalidChangeID - } - return err - } - return nil -} - -// PoolTransactions returns all transactions in the transaction pool -func (m *chainManager) PoolTransactions() []types.Transaction { - return m.tp.Transactions() -} - -func synced(timestamp stypes.Timestamp) bool { - return time.Since(time.Unix(int64(timestamp), 0)) <= maxSyncTime -} - -// NewManager creates a new chain manager. -func NewChainManager(cs modules.ConsensusSet, tp bus.TransactionPool, network *consensus.Network) (*chainManager, error) { - height := cs.Height() - block, ok := cs.BlockAtHeight(height) - if !ok { - return nil, fmt.Errorf("failed to get block at height %d", height) - } - - m := &chainManager{ - cs: cs, - tp: tp, - network: network, - tip: consensus.State{ - Network: network, - Index: types.ChainIndex{ - ID: types.BlockID(block.ID()), - Height: uint64(height), - }, - }, - synced: synced(block.Timestamp), - lastBlockTime: time.Unix(int64(block.Timestamp), 0), - close: make(chan struct{}), - } - - if err := cs.ConsensusSetSubscribe(m, modules.ConsensusChangeRecent, m.close); err != nil { - return nil, fmt.Errorf("failed to subscribe to consensus set: %w", err) - } - return m, nil -} diff --git a/internal/node/convert.go b/internal/node/convert.go deleted file mode 100644 index 8fcc01eed..000000000 --- a/internal/node/convert.go +++ /dev/null @@ -1,28 +0,0 @@ -package node - -import ( - "bytes" - - "gitlab.com/NebulousLabs/encoding" - "go.sia.tech/core/types" -) - -func convertToSiad(core types.EncoderTo, siad encoding.SiaUnmarshaler) { - var buf bytes.Buffer - e := types.NewEncoder(&buf) - core.EncodeTo(e) - e.Flush() - if err := siad.UnmarshalSia(&buf); err != nil { - panic(err) - } -} - -func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) { - var buf bytes.Buffer - siad.MarshalSia(&buf) - d := types.NewBufDecoder(buf.Bytes()) - core.DecodeFrom(d) - if d.Err() != nil { - panic(d.Err()) - } -} diff --git a/internal/node/miner.go b/internal/node/miner.go deleted file mode 100644 index 9043196b4..000000000 --- a/internal/node/miner.go +++ /dev/null @@ -1,150 +0,0 @@ -// TODO: remove this file when we can import it from hostd -package node - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "fmt" - "sync" - - "go.sia.tech/core/types" - "go.sia.tech/siad/crypto" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" - "lukechampine.com/frand" -) - -const solveAttempts = 1e4 - -type ( - // Consensus defines a minimal interface needed by the miner to interact - // with the consensus set - Consensus interface { - AcceptBlock(context.Context, types.Block) error - } - - // A Miner is a CPU miner that can mine blocks, sending the reward to a - // specified address. - Miner struct { - consensus Consensus - - mu sync.Mutex - height stypes.BlockHeight - target stypes.Target - currentBlockID stypes.BlockID - txnsets map[modules.TransactionSetID][]stypes.TransactionID - transactions []stypes.Transaction - } -) - -var errFailedToSolve = errors.New("failed to solve block") - -// ProcessConsensusChange implements modules.ConsensusSetSubscriber. 
-func (m *Miner) ProcessConsensusChange(cc modules.ConsensusChange) { - m.mu.Lock() - defer m.mu.Unlock() - m.target = cc.ChildTarget - m.currentBlockID = cc.AppliedBlocks[len(cc.AppliedBlocks)-1].ID() - m.height = cc.BlockHeight -} - -// ReceiveUpdatedUnconfirmedTransactions implements modules.TransactionPoolSubscriber -func (m *Miner) ReceiveUpdatedUnconfirmedTransactions(diff *modules.TransactionPoolDiff) { - m.mu.Lock() - defer m.mu.Unlock() - - reverted := make(map[stypes.TransactionID]bool) - for _, setID := range diff.RevertedTransactions { - for _, txnID := range m.txnsets[setID] { - reverted[txnID] = true - } - } - - filtered := m.transactions[:0] - for _, txn := range m.transactions { - if reverted[txn.ID()] { - continue - } - filtered = append(filtered, txn) - } - - for _, txnset := range diff.AppliedTransactions { - m.txnsets[txnset.ID] = txnset.IDs - filtered = append(filtered, txnset.Transactions...) - } - m.transactions = filtered -} - -// mineBlock attempts to mine a block and add it to the consensus set. -func (m *Miner) mineBlock(addr stypes.UnlockHash) error { - m.mu.Lock() - block := stypes.Block{ - ParentID: m.currentBlockID, - Timestamp: stypes.CurrentTimestamp(), - } - - randBytes := frand.Bytes(stypes.SpecifierLen) - randTxn := stypes.Transaction{ - ArbitraryData: [][]byte{append(modules.PrefixNonSia[:], randBytes...)}, - } - block.Transactions = append([]stypes.Transaction{randTxn}, m.transactions...) - block.MinerPayouts = append(block.MinerPayouts, stypes.SiacoinOutput{ - Value: block.CalculateSubsidy(m.height + 1), - UnlockHash: addr, - }) - target := m.target - m.mu.Unlock() - - merkleRoot := block.MerkleRoot() - header := make([]byte, 80) - copy(header, block.ParentID[:]) - binary.LittleEndian.PutUint64(header[40:48], uint64(block.Timestamp)) - copy(header[48:], merkleRoot[:]) - - var nonce uint64 - var solved bool - for i := 0; i < solveAttempts; i++ { - id := crypto.HashBytes(header) - if bytes.Compare(target[:], id[:]) >= 0 { - block.Nonce = *(*stypes.BlockNonce)(header[32:40]) - solved = true - break - } - binary.LittleEndian.PutUint64(header[32:], nonce) - nonce += stypes.ASICHardforkFactor - } - if !solved { - return errFailedToSolve - } - - var b types.Block - convertToCore(&block, (*types.V1Block)(&b)) - if err := m.consensus.AcceptBlock(context.Background(), types.Block(b)); err != nil { - return fmt.Errorf("failed to get block accepted: %w", err) - } - return nil -} - -// Mine mines n blocks, sending the reward to addr -func (m *Miner) Mine(addr types.Address, n int) error { - var err error - for mined := 1; mined <= n; { - // return the error only if the miner failed to solve the block, - // ignore any consensus related errors - if err = m.mineBlock(stypes.UnlockHash(addr)); errors.Is(err, errFailedToSolve) { - return fmt.Errorf("failed to mine block %v: %w", mined, errFailedToSolve) - } - mined++ - } - return nil -} - -// NewMiner initializes a new CPU miner -func NewMiner(consensus Consensus) *Miner { - return &Miner{ - consensus: consensus, - txnsets: make(map[modules.TransactionSetID][]stypes.TransactionID), - } -} diff --git a/internal/node/node.go b/internal/node/node.go deleted file mode 100644 index f305441cf..000000000 --- a/internal/node/node.go +++ /dev/null @@ -1,286 +0,0 @@ -package node - -import ( - "context" - "errors" - "fmt" - "log" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/renterd/alerts" - "go.sia.tech/renterd/autopilot" - 
"go.sia.tech/renterd/bus" - "go.sia.tech/renterd/config" - "go.sia.tech/renterd/stores" - "go.sia.tech/renterd/stores/sql" - "go.sia.tech/renterd/stores/sql/mysql" - "go.sia.tech/renterd/stores/sql/sqlite" - "go.sia.tech/renterd/wallet" - "go.sia.tech/renterd/webhooks" - "go.sia.tech/renterd/worker" - "go.sia.tech/renterd/worker/s3" - "go.sia.tech/siad/modules" - mconsensus "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" - "go.sia.tech/siad/sync" - "go.uber.org/zap" - "golang.org/x/crypto/blake2b" - "gorm.io/gorm" - "gorm.io/gorm/logger" - "moul.io/zapgorm2" -) - -type Bus interface { - worker.Bus - s3.Bus -} - -type BusConfig struct { - config.Bus - Database config.Database - DatabaseLog config.DatabaseLog - Network *consensus.Network - Logger *zap.Logger - Miner *Miner -} - -type AutopilotConfig struct { - config.Autopilot - ID string -} - -type ( - RunFn = func() error - BusSetupFn = func(context.Context) error - WorkerSetupFn = func(context.Context, string, string) error - ShutdownFn = func(context.Context) error -) - -var NoopFn = func(context.Context) error { return nil } - -func NewBus(cfg BusConfig, dir string, seed types.PrivateKey, l *zap.Logger) (http.Handler, BusSetupFn, ShutdownFn, error) { - gatewayDir := filepath.Join(dir, "gateway") - if err := os.MkdirAll(gatewayDir, 0700); err != nil { - return nil, nil, nil, err - } - g, err := gateway.New(cfg.GatewayAddr, cfg.Bootstrap, gatewayDir) - if err != nil { - return nil, nil, nil, err - } - consensusDir := filepath.Join(dir, "consensus") - if err := os.MkdirAll(consensusDir, 0700); err != nil { - return nil, nil, nil, err - } - cs, errCh := mconsensus.New(g, cfg.Bootstrap, consensusDir) - select { - case err := <-errCh: - if err != nil { - return nil, nil, nil, err - } - default: - go func() { - if err := <-errCh; err != nil { - log.Println("WARNING: consensus initialization returned an error:", err) - } - }() - } - tpoolDir := filepath.Join(dir, "transactionpool") - if err := os.MkdirAll(tpoolDir, 0700); err != nil { - return nil, nil, nil, err - } - tp, err := transactionpool.New(cs, g, tpoolDir) - if err != nil { - return nil, nil, nil, err - } - - // create database connections - var dbConn gorm.Dialector - var dbMetrics sql.MetricsDatabase - if cfg.Database.MySQL.URI != "" { - // create MySQL connections - dbConn = stores.NewMySQLConnection( - cfg.Database.MySQL.User, - cfg.Database.MySQL.Password, - cfg.Database.MySQL.URI, - cfg.Database.MySQL.Database, - ) - dbm, err := mysql.Open( - cfg.Database.MySQL.User, - cfg.Database.MySQL.Password, - cfg.Database.MySQL.URI, - cfg.Database.MySQL.MetricsDatabase, - ) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to open MySQL metrics database: %w", err) - } - dbMetrics, err = mysql.NewMetricsDatabase(dbm, l.Named("metrics").Sugar(), cfg.DatabaseLog.SlowThreshold, cfg.DatabaseLog.SlowThreshold) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to create MySQL metrics database: %w", err) - } - } else { - // create database directory - dbDir := filepath.Join(dir, "db") - if err := os.MkdirAll(dbDir, 0700); err != nil { - return nil, nil, nil, err - } - - // create SQLite connections - dbConn = stores.NewSQLiteConnection(filepath.Join(dbDir, "db.sqlite")) - - dbm, err := sqlite.Open(filepath.Join(dbDir, "metrics.sqlite")) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to open SQLite metrics database: %w", err) - } - dbMetrics, err = sqlite.NewMetricsDatabase(dbm, 
l.Named("metrics").Sugar(), cfg.DatabaseLog.SlowThreshold, cfg.DatabaseLog.SlowThreshold) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to create SQLite metrics database: %w", err) - } - } - - // create database logger - dbLogger := zapgorm2.Logger{ - ZapLogger: cfg.Logger.Named("SQL"), - LogLevel: gormLogLevel(cfg.DatabaseLog), - SlowThreshold: cfg.DatabaseLog.SlowThreshold, - SkipCallerLookup: false, - IgnoreRecordNotFoundError: cfg.DatabaseLog.IgnoreRecordNotFoundError, - Context: nil, - } - - alertsMgr := alerts.NewManager() - walletAddr := wallet.StandardAddress(seed.PublicKey()) - sqlStoreDir := filepath.Join(dir, "partial_slabs") - announcementMaxAge := time.Duration(cfg.AnnouncementMaxAgeHours) * time.Hour - sqlStore, ccid, err := stores.NewSQLStore(stores.Config{ - Conn: dbConn, - Alerts: alerts.WithOrigin(alertsMgr, "bus"), - DBMetrics: dbMetrics, - PartialSlabDir: sqlStoreDir, - Migrate: true, - AnnouncementMaxAge: announcementMaxAge, - PersistInterval: cfg.PersistInterval, - WalletAddress: walletAddr, - SlabBufferCompletionThreshold: cfg.SlabBufferCompletionThreshold, - Logger: l.Sugar(), - GormLogger: dbLogger, - RetryTransactionIntervals: []time.Duration{200 * time.Millisecond, 500 * time.Millisecond, time.Second, 3 * time.Second, 10 * time.Second, 10 * time.Second}, - LongQueryDuration: cfg.DatabaseLog.SlowThreshold, - LongTxDuration: cfg.DatabaseLog.SlowThreshold, - }) - if err != nil { - return nil, nil, nil, err - } - hooksMgr, err := webhooks.NewManager(l.Named("webhooks").Sugar(), sqlStore) - if err != nil { - return nil, nil, nil, err - } - - // Hook up webhooks to alerts. - alertsMgr.RegisterWebhookBroadcaster(hooksMgr) - - cancelSubscribe := make(chan struct{}) - go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - subscribeErr := cs.ConsensusSetSubscribe(sqlStore, ccid, cancelSubscribe) - if errors.Is(subscribeErr, modules.ErrInvalidConsensusChangeID) { - l.Warn("Invalid consensus change ID detected - resyncing consensus") - // Reset the consensus state within the database and rescan. - if err := sqlStore.ResetConsensusSubscription(ctx); err != nil { - l.Fatal(fmt.Sprintf("Failed to reset consensus subscription of SQLStore: %v", err)) - return - } - // Subscribe from the beginning. 
- subscribeErr = cs.ConsensusSetSubscribe(sqlStore, modules.ConsensusChangeBeginning, cancelSubscribe) - } - if subscribeErr != nil && !errors.Is(subscribeErr, sync.ErrStopped) { - l.Fatal(fmt.Sprintf("ConsensusSetSubscribe returned an error: %v", err)) - } - }() - - w := wallet.NewSingleAddressWallet(seed, sqlStore, cfg.UsedUTXOExpiry, zap.NewNop().Sugar()) - tp.TransactionPoolSubscribe(w) - if err := cs.ConsensusSetSubscribe(w, modules.ConsensusChangeRecent, nil); err != nil { - return nil, nil, nil, err - } - - if m := cfg.Miner; m != nil { - if err := cs.ConsensusSetSubscribe(m, ccid, nil); err != nil { - return nil, nil, nil, err - } - tp.TransactionPoolSubscribe(m) - } - - cm, err := NewChainManager(cs, NewTransactionPool(tp), cfg.Network) - if err != nil { - return nil, nil, nil, err - } - - b, err := bus.New(syncer{g, tp}, alertsMgr, hooksMgr, cm, NewTransactionPool(tp), w, sqlStore, sqlStore, sqlStore, sqlStore, sqlStore, sqlStore, l) - if err != nil { - return nil, nil, nil, err - } - - shutdownFn := func(ctx context.Context) error { - close(cancelSubscribe) - return errors.Join( - g.Close(), - cs.Close(), - tp.Close(), - b.Shutdown(ctx), - sqlStore.Close(), - ) - } - return b.Handler(), b.Setup, shutdownFn, nil -} - -func NewWorker(cfg config.Worker, s3Opts s3.Opts, b Bus, seed types.PrivateKey, l *zap.Logger) (http.Handler, http.Handler, WorkerSetupFn, ShutdownFn, error) { - workerKey := blake2b.Sum256(append([]byte("worker"), seed...)) - w, err := worker.New(workerKey, cfg.ID, b, cfg.ContractLockTimeout, cfg.BusFlushInterval, cfg.DownloadOverdriveTimeout, cfg.UploadOverdriveTimeout, cfg.DownloadMaxOverdrive, cfg.UploadMaxOverdrive, cfg.DownloadMaxMemory, cfg.UploadMaxMemory, cfg.AllowPrivateIPs, l) - if err != nil { - return nil, nil, nil, nil, err - } - s3Handler, err := s3.New(b, w, l.Named("s3").Sugar(), s3Opts) - if err != nil { - err = errors.Join(err, w.Shutdown(context.Background())) - return nil, nil, nil, nil, fmt.Errorf("failed to create s3 handler: %w", err) - } - return w.Handler(), s3Handler, w.Setup, w.Shutdown, nil -} - -func NewAutopilot(cfg AutopilotConfig, b autopilot.Bus, workers []autopilot.Worker, l *zap.Logger) (http.Handler, RunFn, ShutdownFn, error) { - ap, err := autopilot.New(cfg.ID, b, workers, l, cfg.Heartbeat, cfg.ScannerInterval, cfg.ScannerBatchSize, cfg.ScannerNumThreads, cfg.MigrationHealthCutoff, cfg.AccountsRefillInterval, cfg.RevisionSubmissionBuffer, cfg.MigratorParallelSlabsPerWorker, cfg.RevisionBroadcastInterval) - if err != nil { - return nil, nil, nil, err - } - return ap.Handler(), ap.Run, ap.Shutdown, nil -} - -func gormLogLevel(cfg config.DatabaseLog) logger.LogLevel { - level := logger.Silent - if cfg.Enabled { - switch strings.ToLower(cfg.Level) { - case "": - level = logger.Warn // default to 'warn' if not set - case "error": - level = logger.Error - case "warn": - level = logger.Warn - case "info": - level = logger.Info - case "debug": - level = logger.Info - default: - log.Fatalf("invalid log level %q, options are: silent, error, warn, info", cfg.Level) - } - } - return level -} diff --git a/internal/node/syncer.go b/internal/node/syncer.go deleted file mode 100644 index 6a4e80c98..000000000 --- a/internal/node/syncer.go +++ /dev/null @@ -1,43 +0,0 @@ -package node - -import ( - "context" - - "go.sia.tech/core/types" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" -) - -type syncer struct { - g modules.Gateway - tp modules.TransactionPool -} - -func (s syncer) Addr() string { - return string(s.g.Address()) -} - 
-func (s syncer) Peers() []string { - var peers []string - for _, p := range s.g.Peers() { - peers = append(peers, string(p.NetAddress)) - } - return peers -} - -func (s syncer) Connect(addr string) error { - return s.g.Connect(modules.NetAddress(addr)) -} - -func (s syncer) BroadcastTransaction(txn types.Transaction, dependsOn []types.Transaction) { - txnSet := make([]stypes.Transaction, len(dependsOn)+1) - for i, txn := range dependsOn { - convertToSiad(txn, &txnSet[i]) - } - convertToSiad(txn, &txnSet[len(txnSet)-1]) - s.tp.Broadcast(txnSet) -} - -func (s syncer) SyncerAddress(ctx context.Context) (string, error) { - return string(s.g.Address()), nil -} diff --git a/internal/node/transactionpool.go b/internal/node/transactionpool.go deleted file mode 100644 index 54bfb3142..000000000 --- a/internal/node/transactionpool.go +++ /dev/null @@ -1,85 +0,0 @@ -package node - -import ( - "errors" - "slices" - - "go.sia.tech/core/types" - "go.sia.tech/renterd/bus" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" -) - -type txpool struct { - tp modules.TransactionPool -} - -func (tp txpool) RecommendedFee() (fee types.Currency) { - _, maxFee := tp.tp.FeeEstimation() - convertToCore(&maxFee, (*types.V1Currency)(&fee)) - return -} - -func (tp txpool) Transactions() []types.Transaction { - stxns := tp.tp.Transactions() - txns := make([]types.Transaction, len(stxns)) - for i := range txns { - convertToCore(&stxns[i], &txns[i]) - } - return txns -} - -func (tp txpool) AcceptTransactionSet(txns []types.Transaction) error { - stxns := make([]stypes.Transaction, len(txns)) - for i := range stxns { - convertToSiad(&txns[i], &stxns[i]) - } - err := tp.tp.AcceptTransactionSet(stxns) - if errors.Is(err, modules.ErrDuplicateTransactionSet) { - err = nil - } - return err -} - -func (tp txpool) UnconfirmedParents(txn types.Transaction) ([]types.Transaction, error) { - return unconfirmedParents(txn, tp.Transactions()), nil -} - -func (tp txpool) Subscribe(subscriber modules.TransactionPoolSubscriber) { - tp.tp.TransactionPoolSubscribe(subscriber) -} - -func (tp txpool) Close() error { - return tp.tp.Close() -} - -func unconfirmedParents(txn types.Transaction, pool []types.Transaction) []types.Transaction { - outputToParent := make(map[types.SiacoinOutputID]*types.Transaction) - for i, txn := range pool { - for j := range txn.SiacoinOutputs { - outputToParent[txn.SiacoinOutputID(j)] = &pool[i] - } - } - var parents []types.Transaction - txnsToCheck := []*types.Transaction{&txn} - seen := make(map[types.TransactionID]bool) - for len(txnsToCheck) > 0 { - nextTxn := txnsToCheck[0] - txnsToCheck = txnsToCheck[1:] - for _, sci := range nextTxn.SiacoinInputs { - if parent, ok := outputToParent[sci.ParentID]; ok { - if txid := parent.ID(); !seen[txid] { - seen[txid] = true - parents = append(parents, *parent) - txnsToCheck = append(txnsToCheck, parent) - } - } - } - } - slices.Reverse(parents) - return parents -} - -func NewTransactionPool(tp modules.TransactionPool) bus.TransactionPool { - return &txpool{tp: tp} -} diff --git a/internal/node/transactionpool_test.go b/internal/node/transactionpool_test.go deleted file mode 100644 index c24e2c190..000000000 --- a/internal/node/transactionpool_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package node - -import ( - "reflect" - "testing" - - "go.sia.tech/core/types" -) - -func TestUnconfirmedParents(t *testing.T) { - grandparent := types.Transaction{ - SiacoinOutputs: []types.SiacoinOutput{{}}, - } - parent := types.Transaction{ - SiacoinInputs: 
[]types.SiacoinInput{ - { - ParentID: grandparent.SiacoinOutputID(0), - }, - }, - SiacoinOutputs: []types.SiacoinOutput{{}}, - } - txn := types.Transaction{ - SiacoinInputs: []types.SiacoinInput{ - { - ParentID: parent.SiacoinOutputID(0), - }, - }, - SiacoinOutputs: []types.SiacoinOutput{{}}, - } - pool := []types.Transaction{grandparent, parent} - - parents := unconfirmedParents(txn, pool) - if len(parents) != 2 { - t.Fatalf("expected 2 parents, got %v", len(parents)) - } else if !reflect.DeepEqual(parents[0], grandparent) { - t.Fatalf("expected grandparent") - } else if !reflect.DeepEqual(parents[1], parent) { - t.Fatalf("expected parent") - } -} diff --git a/worker/rhpv2.go b/internal/rhp/v2/rhp.go similarity index 52% rename from worker/rhpv2.go rename to internal/rhp/v2/rhp.go index c0e01b1b4..9786bf4c8 100644 --- a/worker/rhpv2.go +++ b/internal/rhp/v2/rhp.go @@ -1,4 +1,4 @@ -package worker +package rhp import ( "context" @@ -7,20 +7,25 @@ import ( "errors" "fmt" "math" + "net" "sort" - "strings" "time" rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" - "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/gouging" "go.sia.tech/renterd/internal/utils" - "go.sia.tech/siad/build" - "go.sia.tech/siad/crypto" + "go.uber.org/zap" "lukechampine.com/frand" ) const ( + batchSizeDeleteSectors = uint64(1000) // 4GiB of contract data + batchSizeFetchSectors = uint64(25600) // 100GiB of contract data + + // default lock timeout + defaultLockTimeout = time.Minute + // minMessageSize is the minimum size of an RPC message minMessageSize = 4096 @@ -31,14 +36,14 @@ const ( ) var ( - // ErrInsufficientFunds is returned by various RPCs when the renter is - // unable to provide sufficient payment to the host. - ErrInsufficientFunds = errors.New("insufficient funds") - // ErrInsufficientCollateral is returned by various RPCs when the host is // unable to provide sufficient collateral. ErrInsufficientCollateral = errors.New("insufficient collateral") + // ErrInsufficientFunds is returned by various RPCs when the renter is + // unable to provide sufficient payment to the host. + ErrInsufficientFunds = errors.New("insufficient funds") + // ErrInvalidMerkleProof is returned by various RPCs when the host supplies // an invalid Merkle proof. ErrInvalidMerkleProof = errors.New("host supplied invalid Merkle proof") @@ -62,194 +67,41 @@ var ( ErrNoSectorsToPrune = errors.New("no sectors to prune") ) -// A HostErrorSet is a collection of errors from various hosts. -type HostErrorSet map[types.PublicKey]error - -// NumGouging returns numbers of host that errored out due to price gouging. -func (hes HostErrorSet) NumGouging() (n int) { - for _, he := range hes { - if errors.Is(he, errPriceTableGouging) { - n++ - } +type ( + Dialer interface { + Dial(ctx context.Context, hk types.PublicKey, address string) (net.Conn, error) } - return -} -// Error implements error. 
-func (hes HostErrorSet) Error() string { - if len(hes) == 0 { - return "" - } - - var strs []string - for hk, he := range hes { - strs = append(strs, fmt.Sprintf("%x: %v", hk[:4], he.Error())) - } + PrepareFormFn func(ctx context.Context, renterAddress types.Address, renterKey types.PublicKey, renterFunds, hostCollateral types.Currency, hostKey types.PublicKey, hostSettings rhpv2.HostSettings, endHeight uint64) (txns []types.Transaction, discard func(types.Transaction), err error) +) - // include a leading newline so that the first error isn't printed on the - // same line as the error context - return "\n" + strings.Join(strs, "\n") +type Client struct { + dialer Dialer + logger *zap.SugaredLogger } -func wrapErr(ctx context.Context, fnName string, err *error) { - if *err != nil { - *err = fmt.Errorf("%s: %w", fnName, *err) - if cause := context.Cause(ctx); cause != nil && !utils.IsErr(*err, cause) { - *err = fmt.Errorf("%w; %w", cause, *err) - } +func New(dialer Dialer, logger *zap.Logger) *Client { + return &Client{ + dialer: dialer, + logger: logger.Sugar().Named("rhp2"), } } -func hashRevision(rev types.FileContractRevision) types.Hash256 { - h := types.NewHasher() - rev.EncodeTo(h.E) - return h.Sum() -} - -func updateRevisionOutputs(rev *types.FileContractRevision, cost, collateral types.Currency) (valid, missed []types.Currency, err error) { - // allocate new slices; don't want to risk accidentally sharing memory - rev.ValidProofOutputs = append([]types.SiacoinOutput(nil), rev.ValidProofOutputs...) - rev.MissedProofOutputs = append([]types.SiacoinOutput(nil), rev.MissedProofOutputs...) - - // move valid payout from renter to host - var underflow, overflow bool - rev.ValidProofOutputs[0].Value, underflow = rev.ValidProofOutputs[0].Value.SubWithUnderflow(cost) - rev.ValidProofOutputs[1].Value, overflow = rev.ValidProofOutputs[1].Value.AddWithOverflow(cost) - if underflow || overflow { - err = errors.New("insufficient funds to pay host") - return - } - - // move missed payout from renter to void - rev.MissedProofOutputs[0].Value, underflow = rev.MissedProofOutputs[0].Value.SubWithUnderflow(cost) - rev.MissedProofOutputs[2].Value, overflow = rev.MissedProofOutputs[2].Value.AddWithOverflow(cost) - if underflow || overflow { - err = errors.New("insufficient funds to move missed payout to void") - return - } - - // move collateral from host to void - rev.MissedProofOutputs[1].Value, underflow = rev.MissedProofOutputs[1].Value.SubWithUnderflow(collateral) - rev.MissedProofOutputs[2].Value, overflow = rev.MissedProofOutputs[2].Value.AddWithOverflow(collateral) - if underflow || overflow { - err = errors.New("insufficient collateral") - return - } - - return []types.Currency{rev.ValidProofOutputs[0].Value, rev.ValidProofOutputs[1].Value}, - []types.Currency{rev.MissedProofOutputs[0].Value, rev.MissedProofOutputs[1].Value, rev.MissedProofOutputs[2].Value}, nil -} - -// RPCSettings calls the Settings RPC, returning the host's reported settings. -func RPCSettings(ctx context.Context, t *rhpv2.Transport) (settings rhpv2.HostSettings, err error) { - defer wrapErr(ctx, "Settings", &err) - - var resp rhpv2.RPCSettingsResponse - if err := t.Call(rhpv2.RPCSettingsID, nil, &resp); err != nil { - return rhpv2.HostSettings{}, err - } else if err := json.Unmarshal(resp.Settings, &settings); err != nil { - return rhpv2.HostSettings{}, fmt.Errorf("couldn't unmarshal json: %w", err) - } - - return settings, nil -} - -// RPCFormContract forms a contract with a host. 
-func RPCFormContract(ctx context.Context, t *rhpv2.Transport, renterKey types.PrivateKey, txnSet []types.Transaction) (_ rhpv2.ContractRevision, _ []types.Transaction, err error) { - defer wrapErr(ctx, "FormContract", &err) - - // strip our signatures before sending - parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] - renterContractSignatures := txn.Signatures - txnSet[len(txnSet)-1].Signatures = nil - - // create request - renterPubkey := renterKey.PublicKey() - req := &rhpv2.RPCFormContractRequest{ - Transactions: txnSet, - RenterKey: renterPubkey.UnlockKey(), - } - if err := t.WriteRequest(rhpv2.RPCFormContractID, req); err != nil { - return rhpv2.ContractRevision{}, nil, err - } - - // execute form contract RPC - var resp rhpv2.RPCFormContractAdditions - if err := t.ReadResponse(&resp, 65536); err != nil { - return rhpv2.ContractRevision{}, nil, err - } - - // merge host additions with txn - txn.SiacoinInputs = append(txn.SiacoinInputs, resp.Inputs...) - txn.SiacoinOutputs = append(txn.SiacoinOutputs, resp.Outputs...) - - // create initial (no-op) revision, transaction, and signature - fc := txn.FileContracts[0] - initRevision := types.FileContractRevision{ - ParentID: txn.FileContractID(0), - UnlockConditions: types.UnlockConditions{ - PublicKeys: []types.UnlockKey{ - renterPubkey.UnlockKey(), - t.HostKey().UnlockKey(), - }, - SignaturesRequired: 2, - }, - FileContract: types.FileContract{ - RevisionNumber: 1, - Filesize: fc.Filesize, - FileMerkleRoot: fc.FileMerkleRoot, - WindowStart: fc.WindowStart, - WindowEnd: fc.WindowEnd, - ValidProofOutputs: fc.ValidProofOutputs, - MissedProofOutputs: fc.MissedProofOutputs, - UnlockHash: fc.UnlockHash, - }, - } - revSig := renterKey.SignHash(hashRevision(initRevision)) - renterRevisionSig := types.TransactionSignature{ - ParentID: types.Hash256(initRevision.ParentID), - CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, - PublicKeyIndex: 0, - Signature: revSig[:], - } - - // write our signatures - renterSigs := &rhpv2.RPCFormContractSignatures{ - ContractSignatures: renterContractSignatures, - RevisionSignature: renterRevisionSig, - } - if err := t.WriteResponse(renterSigs); err != nil { - return rhpv2.ContractRevision{}, nil, err - } - - // read the host's signatures and merge them with our own - var hostSigs rhpv2.RPCFormContractSignatures - if err := t.ReadResponse(&hostSigs, minMessageSize); err != nil { - return rhpv2.ContractRevision{}, nil, err - } - - txn.Signatures = make([]types.TransactionSignature, 0, len(renterContractSignatures)+len(hostSigs.ContractSignatures)) - txn.Signatures = append(txn.Signatures, renterContractSignatures...) - txn.Signatures = append(txn.Signatures, hostSigs.ContractSignatures...) - - signedTxnSet := make([]types.Transaction, 0, len(resp.Parents)+len(parents)+1) - signedTxnSet = append(signedTxnSet, resp.Parents...) - signedTxnSet = append(signedTxnSet, parents...) 
-	signedTxnSet = append(signedTxnSet, txn)
-	return rhpv2.ContractRevision{
-		Revision: initRevision,
-		Signatures: [2]types.TransactionSignature{
-			renterRevisionSig,
-			hostSigs.RevisionSignature,
-		},
-	}, signedTxnSet, nil
+func (c *Client) ContractRoots(ctx context.Context, renterKey types.PrivateKey, gougingChecker gouging.Checker, hostIP string, hostKey types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64) (roots []types.Hash256, revision *types.FileContractRevision, cost types.Currency, err error) {
+	err = c.withTransport(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error {
+		return c.withRevisionV2(renterKey, gougingChecker, t, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) {
+			roots, cost, err = c.fetchContractRoots(t, renterKey, &rev, settings)
+			revision = &rev.Revision
+			return
+		})
+	})
+	return
 }
 
-// FetchSignedRevision fetches the latest signed revision for a contract from a host.
-// TODO: stop using rhpv2 and upgrade to newer protocol when possible.
-func (w *worker) FetchSignedRevision(ctx context.Context, hostIP string, hostKey types.PublicKey, renterKey types.PrivateKey, contractID types.FileContractID, timeout time.Duration) (rhpv2.ContractRevision, error) {
+// SignedRevision fetches the latest signed revision for a contract from a host.
+func (c *Client) SignedRevision(ctx context.Context, hostIP string, hostKey types.PublicKey, renterKey types.PrivateKey, contractID types.FileContractID, timeout time.Duration) (rhpv2.ContractRevision, error) {
 	var rev rhpv2.ContractRevision
-	err := w.withTransportV2(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error {
+	err := c.withTransport(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error {
 		req := &rhpv2.RPCLockRequest{
 			ContractID: contractID,
 			Signature:  t.SignChallenge(renterKey),
@@ -291,78 +143,113 @@
 	return rev, err
 }
 
-func (w *worker) PruneContract(ctx context.Context, hostIP string, hostKey types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64) (deleted, remaining uint64, err error) {
-	err = w.withContractLock(ctx, fcid, lockingPriorityPruning, func() error {
-		return w.withTransportV2(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error {
-			return w.withRevisionV2(defaultLockTimeout, t, hostKey, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) {
-				// perform gouging checks
-				gc, err := GougingCheckerFromContext(ctx, false)
-				if err != nil {
-					return err
-				}
-				if breakdown := gc.Check(&settings, nil); breakdown.Gouging() {
-					return fmt.Errorf("failed to prune contract: %v", breakdown)
-				}
+func (c *Client) Settings(ctx context.Context, hostKey types.PublicKey, hostIP string) (settings rhpv2.HostSettings, err error) {
+	err = c.withTransport(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error {
+		var err error
+		if settings, err = rpcSettings(ctx, t); err != nil {
+			return err
+		}
+		// NOTE: we overwrite the NetAddress with the host address here since
+		// we just used it to dial the host, so we know it's valid
+		settings.NetAddress = hostIP
+		return nil
+	})
+	return
+}
 
-				// delete roots
-				got, err := w.fetchContractRoots(t, &rev, settings)
-				if err != nil {
-					return err
-				}
+func (c *Client) FormContract(ctx context.Context, renterAddress types.Address, renterKey types.PrivateKey, hostKey types.PublicKey, hostIP string, renterFunds,
hostCollateral types.Currency, endHeight uint64, gougingChecker gouging.Checker, prepareForm PrepareFormFn) (contract rhpv2.ContractRevision, txnSet []types.Transaction, err error) { + err = c.withTransport(ctx, hostKey, hostIP, func(t *rhpv2.Transport) (err error) { + settings, err := rpcSettings(ctx, t) + if err != nil { + return err + } - // fetch the roots from the bus - want, pending, err := w.bus.ContractRoots(ctx, fcid) - if err != nil { - return err - } - keep := make(map[types.Hash256]struct{}) - for _, root := range append(want, pending...) { - keep[root] = struct{}{} - } + if breakdown := gougingChecker.CheckSettings(settings); breakdown.Gouging() { + return fmt.Errorf("failed to form contract, gouging check failed: %v", breakdown) + } - // collect indices for roots we want to prune - var indices []uint64 - for i, root := range got { - if _, wanted := keep[root]; wanted { - delete(keep, root) // prevent duplicates - continue - } - indices = append(indices, uint64(i)) - } - if len(indices) == 0 { - return fmt.Errorf("%w: database holds %d (%d pending), contract contains %d", ErrNoSectorsToPrune, len(want)+len(pending), len(pending), len(got)) - } + renterTxnSet, discardTxn, err := prepareForm(ctx, renterAddress, renterKey.PublicKey(), renterFunds, hostCollateral, hostKey, settings, endHeight) + if err != nil { + return err + } + + contract, txnSet, err = rpcFormContract(ctx, t, renterKey, renterTxnSet) + if err != nil { + discardTxn(renterTxnSet[len(renterTxnSet)-1]) + return err + } + return + }) + return +} + +func (c *Client) PruneContract(ctx context.Context, renterKey types.PrivateKey, gougingChecker gouging.Checker, hostIP string, hostKey types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64, toKeep []types.Hash256) (revision *types.FileContractRevision, deleted, remaining uint64, cost types.Currency, err error) { + err = c.withTransport(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error { + return c.withRevisionV2(renterKey, gougingChecker, t, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) { + // fetch roots + got, fetchCost, err := c.fetchContractRoots(t, renterKey, &rev, settings) + if err != nil { + return err + } + + // update cost and revision + cost = cost.Add(fetchCost) + revision = &rev.Revision + + keep := make(map[types.Hash256]struct{}) + for _, root := range toKeep { + keep[root] = struct{}{} + } - // delete the roots from the contract - deleted, err = w.deleteContractRoots(t, &rev, settings, indices) - if deleted < uint64(len(indices)) { - remaining = uint64(len(indices)) - deleted + // collect indices for roots we want to prune + var indices []uint64 + for i, root := range got { + if _, wanted := keep[root]; wanted { + delete(keep, root) // prevent duplicates + continue } + indices = append(indices, uint64(i)) + } + if len(indices) == 0 { + return fmt.Errorf("%w: database holds %d, contract contains %d", ErrNoSectorsToPrune, len(toKeep), len(got)) + } - // return sizes instead of number of roots - deleted *= rhpv2.SectorSize - remaining *= rhpv2.SectorSize - return - }) + // delete the roots from the contract + var deleteCost types.Currency + deleted, deleteCost, err = c.deleteContractRoots(t, renterKey, &rev, settings, indices) + if deleted < uint64(len(indices)) { + remaining = uint64(len(indices)) - deleted + } + + // update cost and revision + if deleted > 0 { + cost = cost.Add(deleteCost) + revision = &rev.Revision + } + + // return sizes instead of 
number of roots + deleted *= rhpv2.SectorSize + remaining *= rhpv2.SectorSize + return }) }) return } -func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevision, settings rhpv2.HostSettings, indices []uint64) (deleted uint64, err error) { +func (c *Client) deleteContractRoots(t *rhpv2.Transport, renterKey types.PrivateKey, rev *rhpv2.ContractRevision, settings rhpv2.HostSettings, indices []uint64) (deleted uint64, cost types.Currency, err error) { id := frand.Entropy128() - logger := w.logger. + logger := c.logger. With("id", hex.EncodeToString(id[:])). With("hostKey", rev.HostKey()). With("hostVersion", settings.Version). With("fcid", rev.ID()). With("revisionNumber", rev.Revision.RevisionNumber). Named("deleteContractRoots") - logger.Infow(fmt.Sprintf("deleting %d contract roots (%v)", len(indices), humanReadableSize(len(indices)*rhpv2.SectorSize)), "hk", rev.HostKey(), "fcid", rev.ID()) + logger.Infow(fmt.Sprintf("deleting %d contract roots (%v)", len(indices), utils.HumanReadableSize(len(indices)*rhpv2.SectorSize)), "hk", rev.HostKey(), "fcid", rev.ID()) // return early if len(indices) == 0 { - return 0, nil + return 0, types.ZeroCurrency, nil } // sort in descending order so that we can use 'range' @@ -374,7 +261,7 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi // hosts we use a much smaller batch size to ensure we nibble away at the // problem rather than outright failing or timing out batchSize := int(batchSizeDeleteSectors) - if build.VersionCmp(settings.Version, "1.6.0") < 0 { + if utils.VersionCmp(settings.Version, "1.6.0") < 0 { batchSize = 100 } @@ -391,17 +278,14 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi } } - // derive the renter key - renterKey := w.deriveRenterKey(rev.HostKey()) - // range over the batches and delete the sectors batch per batch for i, batch := range batches { if err = func() error { - var cost types.Currency + var batchCost types.Currency start := time.Now() logger.Infow(fmt.Sprintf("starting batch %d/%d of size %d", i+1, len(batches), len(batch))) defer func() { - logger.Infow(fmt.Sprintf("processing batch %d/%d of size %d took %v", i+1, len(batches), len(batch), time.Since(start)), "cost", cost) + logger.Infow(fmt.Sprintf("processing batch %d/%d of size %d took %v", i+1, len(batches), len(batch), time.Since(start)), "cost", batchCost) }() numSectors := rev.NumSectors() @@ -431,20 +315,20 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi if err != nil { return err } - cost, _ = rpcCost.Total() + batchCost, _ = rpcCost.Total() // NOTE: we currently overpay hosts by quite a large margin (~10x) // to ensure we cover both 1.5.9 and pre v0.2.1 hosts. 
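+			// For example: a batch of 64 delete actions is priced as
+			// (128+64)*rhpv2.LeafSize = 12288 bytes of download bandwidth
+			// on top of the base RPC price.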
// // TODO: remove once host network is updated, or once we include the // host release in the scoring and stop using old hosts - proofSize := (128 + uint64(len(actions))) * crypto.HashSize + proofSize := (128 + uint64(len(actions))) * rhpv2.LeafSize compatCost := settings.BaseRPCPrice.Add(settings.DownloadBandwidthPrice.Mul64(proofSize)) - if cost.Cmp(compatCost) < 0 { - cost = compatCost + if batchCost.Cmp(compatCost) < 0 { + batchCost = compatCost } - if rev.RenterFunds().Cmp(cost) < 0 { + if rev.RenterFunds().Cmp(batchCost) < 0 { return ErrInsufficientFunds } @@ -458,7 +342,7 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi rev.Revision.Filesize -= rhpv2.SectorSize * actions[len(actions)-1].A // update the revision outputs - newValid, newMissed, err := updateRevisionOutputs(&rev.Revision, cost, types.ZeroCurrency) + newRevision, err := updatedRevision(rev.Revision, batchCost, types.ZeroCurrency) if err != nil { return err } @@ -468,9 +352,16 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi Actions: actions, MerkleProof: true, - RevisionNumber: rev.Revision.RevisionNumber, - ValidProofValues: newValid, - MissedProofValues: newMissed, + RevisionNumber: rev.Revision.RevisionNumber, + ValidProofValues: []types.Currency{ + newRevision.ValidProofOutputs[0].Value, + newRevision.ValidProofOutputs[1].Value, + }, + MissedProofValues: []types.Currency{ + newRevision.MissedProofOutputs[0].Value, + newRevision.MissedProofOutputs[1].Value, + newRevision.MissedProofOutputs[2].Value, + }, } // send request and read merkle proof @@ -486,8 +377,8 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi // verify proof proofHashes := merkleResp.OldSubtreeHashes leafHashes := merkleResp.OldLeafHashes - oldRoot, newRoot := types.Hash256(rev.Revision.FileMerkleRoot), merkleResp.NewMerkleRoot - if rev.Revision.Filesize > 0 && !rhpv2.VerifyDiffProof(actions, numSectors, proofHashes, leafHashes, oldRoot, newRoot, nil) { + oldRoot, newRoot := types.Hash256(newRevision.FileMerkleRoot), merkleResp.NewMerkleRoot + if newRevision.Filesize > 0 && !rhpv2.VerifyDiffProof(actions, numSectors, proofHashes, leafHashes, oldRoot, newRoot, nil) { err := fmt.Errorf("couldn't verify delete proof, host %v, version %v; %w", rev.HostKey(), settings.Version, ErrInvalidMerkleProof) logger.Infow(fmt.Sprintf("processing batch %d/%d failed, err %v", i+1, len(batches), err)) t.WriteResponseErr(err) @@ -495,10 +386,10 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi } // update merkle root - copy(rev.Revision.FileMerkleRoot[:], newRoot[:]) + copy(newRevision.FileMerkleRoot[:], newRoot[:]) // build the write response - revisionHash := hashRevision(rev.Revision) + revisionHash := hashRevision(newRevision) renterSig := &rhpv2.RPCWriteResponse{ Signature: renterKey.SignHash(revisionHash), } @@ -521,8 +412,9 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi // update deleted count deleted += uint64(len(batch)) - // record spending - w.contractSpendingRecorder.Record(rev.Revision, api.ContractSpending{Deletions: cost}) + // update revision + rev.Revision = newRevision + cost = cost.Add(batchCost) return nil }(); err != nil { return @@ -531,27 +423,7 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi return } -func (w *worker) FetchContractRoots(ctx context.Context, hostIP string, hostKey types.PublicKey, fcid types.FileContractID, 
lastKnownRevisionNumber uint64) (roots []types.Hash256, err error) { - err = w.withTransportV2(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error { - return w.withRevisionV2(defaultLockTimeout, t, hostKey, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) { - gc, err := GougingCheckerFromContext(ctx, false) - if err != nil { - return err - } - if breakdown := gc.Check(&settings, nil); breakdown.Gouging() { - return fmt.Errorf("failed to list contract roots: %v", breakdown) - } - roots, err = w.fetchContractRoots(t, &rev, settings) - return - }) - }) - return -} - -func (w *worker) fetchContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevision, settings rhpv2.HostSettings) (roots []types.Hash256, _ error) { - // derive the renter key - renterKey := w.deriveRenterKey(rev.HostKey()) - +func (c *Client) fetchContractRoots(t *rhpv2.Transport, renterKey types.PrivateKey, rev *rhpv2.ContractRevision, settings rhpv2.HostSettings) (roots []types.Hash256, cost types.Currency, _ error) { // download the full set of SectorRoots numsectors := rev.NumSectors() for offset := uint64(0); offset < numsectors; { @@ -561,123 +433,96 @@ func (w *worker) fetchContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevis } // calculate the cost - cost, _ := settings.RPCSectorRootsCost(offset, n).Total() + batchCost, _ := settings.RPCSectorRootsCost(offset, n).Total() // TODO: remove once host network is updated - if build.VersionCmp(settings.Version, "1.6.0") < 0 { + if utils.VersionCmp(settings.Version, "1.6.0") < 0 { // calculate the response size proofSize := rhpv2.RangeProofSize(numsectors, offset, offset+n) - responseSize := (proofSize + n) * crypto.HashSize + responseSize := (proofSize + n) * 32 if responseSize < minMessageSize { responseSize = minMessageSize } - cost = settings.BaseRPCPrice.Add(settings.DownloadBandwidthPrice.Mul64(responseSize)) - cost = cost.Mul64(2) // generous leeway + batchCost = settings.BaseRPCPrice.Add(settings.DownloadBandwidthPrice.Mul64(responseSize)) + batchCost = batchCost.Mul64(2) // generous leeway } // check funds - if rev.RenterFunds().Cmp(cost) < 0 { - return nil, ErrInsufficientFunds + if rev.RenterFunds().Cmp(batchCost) < 0 { + return nil, types.ZeroCurrency, ErrInsufficientFunds } // update the revision number if rev.Revision.RevisionNumber == math.MaxUint64 { - return nil, ErrContractFinalized + return nil, types.ZeroCurrency, ErrContractFinalized } rev.Revision.RevisionNumber++ // update the revision outputs - newValid, newMissed, err := updateRevisionOutputs(&rev.Revision, cost, types.ZeroCurrency) + newRevision, err := updatedRevision(rev.Revision, batchCost, types.ZeroCurrency) if err != nil { - return nil, err + return nil, types.ZeroCurrency, err } // build the sector roots request - revisionHash := hashRevision(rev.Revision) + revisionHash := hashRevision(newRevision) req := &rhpv2.RPCSectorRootsRequest{ RootOffset: uint64(offset), NumRoots: uint64(n), - RevisionNumber: rev.Revision.RevisionNumber, - ValidProofValues: newValid, - MissedProofValues: newMissed, - Signature: renterKey.SignHash(revisionHash), + RevisionNumber: rev.Revision.RevisionNumber, + ValidProofValues: []types.Currency{ + newRevision.ValidProofOutputs[0].Value, + newRevision.ValidProofOutputs[1].Value, + }, + MissedProofValues: []types.Currency{ + newRevision.MissedProofOutputs[0].Value, + newRevision.MissedProofOutputs[1].Value, + newRevision.MissedProofOutputs[2].Value, + }, + Signature: 
renterKey.SignHash(revisionHash), } // execute the sector roots RPC var rootsResp rhpv2.RPCSectorRootsResponse if err := t.WriteRequest(rhpv2.RPCSectorRootsID, req); err != nil { - return nil, err + return nil, types.ZeroCurrency, err } else if err := t.ReadResponse(&rootsResp, maxMerkleProofResponseSize); err != nil { - return nil, fmt.Errorf("couldn't read sector roots response: %w", err) + return nil, types.ZeroCurrency, fmt.Errorf("couldn't read sector roots response: %w", err) } // verify the host signature if !rev.HostKey().VerifyHash(revisionHash, rootsResp.Signature) { - return nil, errors.New("host's signature is invalid") + return nil, types.ZeroCurrency, errors.New("host's signature is invalid") } rev.Signatures[0].Signature = req.Signature[:] rev.Signatures[1].Signature = rootsResp.Signature[:] // verify the proof if uint64(len(rootsResp.SectorRoots)) != n { - return nil, fmt.Errorf("couldn't verify contract roots proof, host %v, version %v, err: number of roots does not match range %d != %d (num sectors: %d rev size: %d offset: %d)", rev.HostKey(), settings.Version, len(rootsResp.SectorRoots), n, numsectors, rev.Revision.Filesize, offset) + return nil, types.ZeroCurrency, fmt.Errorf("couldn't verify contract roots proof, host %v, version %v, err: number of roots does not match range %d != %d (num sectors: %d rev size: %d offset: %d)", rev.HostKey(), settings.Version, len(rootsResp.SectorRoots), n, numsectors, rev.Revision.Filesize, offset) } else if !rhpv2.VerifySectorRangeProof(rootsResp.MerkleProof, rootsResp.SectorRoots, offset, offset+n, numsectors, rev.Revision.FileMerkleRoot) { - return nil, fmt.Errorf("couldn't verify contract roots proof, host %v, version %v; %w", rev.HostKey(), settings.Version, ErrInvalidMerkleProof) + return nil, types.ZeroCurrency, fmt.Errorf("couldn't verify contract roots proof, host %v, version %v; %w", rev.HostKey(), settings.Version, ErrInvalidMerkleProof) } // append roots roots = append(roots, rootsResp.SectorRoots...) 
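+		// Advance the offset; the loop ends once all numsectors roots have
+		// been fetched and verified.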
offset += n - // record spending - w.contractSpendingRecorder.Record(rev.Revision, api.ContractSpending{SectorRoots: cost}) + // update revision + rev.Revision = newRevision + cost = cost.Add(batchCost) } return } -func (w *worker) withTransportV2(ctx context.Context, hostKey types.PublicKey, hostIP string, fn func(*rhpv2.Transport) error) (err error) { - conn, err := dial(ctx, hostIP) - if err != nil { - return err - } - done := make(chan struct{}) - go func() { - select { - case <-done: - case <-ctx.Done(): - conn.Close() - } - }() - defer func() { - close(done) - if context.Cause(ctx) != nil { - err = context.Cause(ctx) - } - }() - t, err := rhpv2.NewRenterTransport(conn, hostKey) - if err != nil { - return err - } - defer t.Close() - - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic (withTransportV2): %v", r) - } - }() - return fn(t) -} - -func (w *worker) withRevisionV2(lockTimeout time.Duration, t *rhpv2.Transport, hk types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64, fn func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) error) error { - renterKey := w.deriveRenterKey(hk) - +func (w *Client) withRevisionV2(renterKey types.PrivateKey, gougingChecker gouging.Checker, t *rhpv2.Transport, fcid types.FileContractID, lastKnownRevisionNumber uint64, fn func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) error) error { // execute lock RPC var lockResp rhpv2.RPCLockResponse err := t.Call(rhpv2.RPCLockID, &rhpv2.RPCLockRequest{ ContractID: fcid, Signature: t.SignChallenge(renterKey), - Timeout: uint64(lockTimeout.Milliseconds()), + Timeout: uint64(defaultLockTimeout.Milliseconds()), }, &lockResp) if err != nil { return err @@ -722,19 +567,78 @@ func (w *worker) withRevisionV2(lockTimeout time.Duration, t *rhpv2.Transport, h return fmt.Errorf("couldn't unmarshal json: %w", err) } + // perform gouging checks on settings + if breakdown := gougingChecker.CheckSettings(settings); breakdown.Gouging() { + return fmt.Errorf("failed to prune contract: %v", breakdown) + } + return fn(t, rev, settings) } -func humanReadableSize(b int) string { - const unit = 1024 - if b < unit { - return fmt.Sprintf("%d B", b) +func (c *Client) withTransport(ctx context.Context, hostKey types.PublicKey, hostIP string, fn func(*rhpv2.Transport) error) (err error) { + conn, err := c.dialer.Dial(ctx, hostKey, hostIP) + if err != nil { + return err + } + done := make(chan struct{}) + go func() { + select { + case <-done: + case <-ctx.Done(): + conn.Close() + } + }() + defer func() { + close(done) + if context.Cause(ctx) != nil { + err = context.Cause(ctx) + } + }() + t, err := rhpv2.NewRenterTransport(conn, hostKey) + if err != nil { + return err + } + defer t.Close() + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic (withTransportV2): %v", r) + } + }() + return fn(t) +} + +func hashRevision(rev types.FileContractRevision) types.Hash256 { + h := types.NewHasher() + rev.EncodeTo(h.E) + return h.Sum() +} + +func updatedRevision(rev types.FileContractRevision, cost, collateral types.Currency) (types.FileContractRevision, error) { + // allocate new slices; don't want to risk accidentally sharing memory + rev.ValidProofOutputs = append([]types.SiacoinOutput(nil), rev.ValidProofOutputs...) + rev.MissedProofOutputs = append([]types.SiacoinOutput(nil), rev.MissedProofOutputs...) 
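+	// Payout layout used below: valid proof outputs are [renter, host],
+	// missed proof outputs are [renter, host, void].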
+ + // move valid payout from renter to host + var underflow, overflow bool + rev.ValidProofOutputs[0].Value, underflow = rev.ValidProofOutputs[0].Value.SubWithUnderflow(cost) + rev.ValidProofOutputs[1].Value, overflow = rev.ValidProofOutputs[1].Value.AddWithOverflow(cost) + if underflow || overflow { + return types.FileContractRevision{}, errors.New("insufficient funds to pay host") + } + + // move missed payout from renter to void + rev.MissedProofOutputs[0].Value, underflow = rev.MissedProofOutputs[0].Value.SubWithUnderflow(cost) + rev.MissedProofOutputs[2].Value, overflow = rev.MissedProofOutputs[2].Value.AddWithOverflow(cost) + if underflow || overflow { + return types.FileContractRevision{}, errors.New("insufficient funds to move missed payout to void") } - div, exp := int64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ + + // move collateral from host to void + rev.MissedProofOutputs[1].Value, underflow = rev.MissedProofOutputs[1].Value.SubWithUnderflow(collateral) + rev.MissedProofOutputs[2].Value, overflow = rev.MissedProofOutputs[2].Value.AddWithOverflow(collateral) + if underflow || overflow { + return types.FileContractRevision{}, errors.New("insufficient collateral") } - return fmt.Sprintf("%.1f %ciB", - float64(b)/float64(div), "KMGTPE"[exp]) + return rev, nil } diff --git a/internal/rhp/v2/rpc.go b/internal/rhp/v2/rpc.go new file mode 100644 index 000000000..5bd458742 --- /dev/null +++ b/internal/rhp/v2/rpc.go @@ -0,0 +1,115 @@ +package rhp + +import ( + "context" + "encoding/json" + "fmt" + + rhpv2 "go.sia.tech/core/rhp/v2" + "go.sia.tech/core/types" + "go.sia.tech/renterd/internal/utils" +) + +// rpcFormContract forms a contract with a host. +func rpcFormContract(ctx context.Context, t *rhpv2.Transport, renterKey types.PrivateKey, txnSet []types.Transaction) (_ rhpv2.ContractRevision, _ []types.Transaction, err error) { + defer utils.WrapErr(ctx, "FormContract", &err) + + // strip our signatures before sending + parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] + renterContractSignatures := txn.Signatures + txnSet[len(txnSet)-1].Signatures = nil + + // create request + renterPubkey := renterKey.PublicKey() + req := &rhpv2.RPCFormContractRequest{ + Transactions: txnSet, + RenterKey: renterPubkey.UnlockKey(), + } + if err := t.WriteRequest(rhpv2.RPCFormContractID, req); err != nil { + return rhpv2.ContractRevision{}, nil, err + } + + // execute form contract RPC + var resp rhpv2.RPCFormContractAdditions + if err := t.ReadResponse(&resp, 65536); err != nil { + return rhpv2.ContractRevision{}, nil, err + } + + // merge host additions with txn + txn.SiacoinInputs = append(txn.SiacoinInputs, resp.Inputs...) + txn.SiacoinOutputs = append(txn.SiacoinOutputs, resp.Outputs...) 
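+	// The host funds its collateral by contributing its own siacoin inputs
+	// and outputs to the formation transaction.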
+ + // create initial (no-op) revision, transaction, and signature + fc := txn.FileContracts[0] + initRevision := types.FileContractRevision{ + ParentID: txn.FileContractID(0), + UnlockConditions: types.UnlockConditions{ + PublicKeys: []types.UnlockKey{ + renterPubkey.UnlockKey(), + t.HostKey().UnlockKey(), + }, + SignaturesRequired: 2, + }, + FileContract: types.FileContract{ + RevisionNumber: 1, + Filesize: fc.Filesize, + FileMerkleRoot: fc.FileMerkleRoot, + WindowStart: fc.WindowStart, + WindowEnd: fc.WindowEnd, + ValidProofOutputs: fc.ValidProofOutputs, + MissedProofOutputs: fc.MissedProofOutputs, + UnlockHash: fc.UnlockHash, + }, + } + revSig := renterKey.SignHash(hashRevision(initRevision)) + renterRevisionSig := types.TransactionSignature{ + ParentID: types.Hash256(initRevision.ParentID), + CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, + PublicKeyIndex: 0, + Signature: revSig[:], + } + + // write our signatures + renterSigs := &rhpv2.RPCFormContractSignatures{ + ContractSignatures: renterContractSignatures, + RevisionSignature: renterRevisionSig, + } + if err := t.WriteResponse(renterSigs); err != nil { + return rhpv2.ContractRevision{}, nil, err + } + + // read the host's signatures and merge them with our own + var hostSigs rhpv2.RPCFormContractSignatures + if err := t.ReadResponse(&hostSigs, minMessageSize); err != nil { + return rhpv2.ContractRevision{}, nil, err + } + + txn.Signatures = make([]types.TransactionSignature, 0, len(renterContractSignatures)+len(hostSigs.ContractSignatures)) + txn.Signatures = append(txn.Signatures, renterContractSignatures...) + txn.Signatures = append(txn.Signatures, hostSigs.ContractSignatures...) + + signedTxnSet := make([]types.Transaction, 0, len(resp.Parents)+len(parents)+1) + signedTxnSet = append(signedTxnSet, resp.Parents...) + signedTxnSet = append(signedTxnSet, parents...) + signedTxnSet = append(signedTxnSet, txn) + return rhpv2.ContractRevision{ + Revision: initRevision, + Signatures: [2]types.TransactionSignature{ + renterRevisionSig, + hostSigs.RevisionSignature, + }, + }, signedTxnSet, nil +} + +// rpcSettings calls the Settings RPC, returning the host's reported settings. +func rpcSettings(ctx context.Context, t *rhpv2.Transport) (settings rhpv2.HostSettings, err error) { + defer utils.WrapErr(ctx, "Settings", &err) + + var resp rhpv2.RPCSettingsResponse + if err := t.Call(rhpv2.RPCSettingsID, nil, &resp); err != nil { + return rhpv2.HostSettings{}, err + } else if err := json.Unmarshal(resp.Settings, &settings); err != nil { + return rhpv2.HostSettings{}, fmt.Errorf("couldn't unmarshal json: %w", err) + } + return settings, nil +} diff --git a/internal/rhp/v3/rhp.go b/internal/rhp/v3/rhp.go new file mode 100644 index 000000000..5ae5d9972 --- /dev/null +++ b/internal/rhp/v3/rhp.go @@ -0,0 +1,403 @@ +package rhp + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net" + + rhpv2 "go.sia.tech/core/rhp/v2" + rhpv3 "go.sia.tech/core/rhp/v3" + "go.sia.tech/core/types" + "go.sia.tech/mux/v1" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/gouging" + "go.sia.tech/renterd/internal/utils" + "go.uber.org/zap" +) + +const ( + + // defaultRPCResponseMaxSize is the default maxSize we use whenever we read + // an RPC response. + defaultRPCResponseMaxSize = 100 * 1024 // 100 KiB + + // defaultWithdrawalExpiryBlocks is the number of blocks we add to the + // current blockheight when we define an expiry block height for withdrawal + // messages. 
+	defaultWithdrawalExpiryBlocks = 12
+
+	// maxPriceTableSize defines the maximum size of a price table
+	maxPriceTableSize = 16 * 1024
+
+	// responseLeeway is the amount of leeway given to the maxLen when we read
+	// the response in the ReadSector RPC
+	responseLeeway = 1 << 12 // 4 KiB
+)
+
+var (
+	// ErrFailedToCreatePayment is returned when the client failed to pay using a contract.
+	ErrFailedToCreatePayment = errors.New("failed to create contract payment")
+
+	// ErrDialTransport is returned when the client could not dial the host.
+	ErrDialTransport = errors.New("could not dial transport")
+
+	// ErrBalanceInsufficient occurs when a withdrawal failed because the
+	// account balance was insufficient.
+	ErrBalanceInsufficient = errors.New("ephemeral account balance was insufficient")
+
+	// ErrMaxRevisionReached occurs when trying to revise a contract that has
+	// already reached the highest possible revision number. Usually happens
+	// when trying to use a renewed contract.
+	ErrMaxRevisionReached = errors.New("contract has reached the maximum number of revisions")
+
+	// ErrSectorNotFound is returned by a host when it can't find the requested
+	// sector.
+	ErrSectorNotFound = errors.New("sector not found")
+
+	// errHost is used to wrap rpc errors returned by the host.
+	errHost = errors.New("host responded with error")
+
+	// errTransport is used to wrap rpc errors caused by the transport.
+	errTransport = errors.New("transport error")
+
+	// errBalanceMaxExceeded occurs when a deposit would push the account's
+	// balance over the maximum allowed ephemeral account balance.
+	errBalanceMaxExceeded = errors.New("ephemeral account maximum balance exceeded")
+
+	// errInsufficientFunds is returned by various RPCs when the renter is
+	// unable to provide sufficient payment to the host.
+	errInsufficientFunds = errors.New("insufficient funds")
+
+	// errPriceTableExpired is returned by the host when the price table that
+	// corresponds to the id it was given is already expired and thus no longer
+	// valid.
+	errPriceTableExpired = errors.New("price table requested is expired")
+
+	// errPriceTableNotFound is returned by the host when it cannot find a
+	// price table that corresponds with the id we sent it.
+	errPriceTableNotFound = errors.New("price table not found")
+
+	// errSectorNotFoundOld is returned by older hosts when they cannot find
+	// the requested sector.
+	errSectorNotFoundOld = errors.New("could not find the desired sector")
+
+	// errWithdrawalsInactive occurs when the host is (perhaps temporarily)
+	// unsynced and has disabled its account manager.
+	errWithdrawalsInactive = errors.New("ephemeral account withdrawals are inactive because the host is not synced")
+
+	// errWithdrawalExpired is returned by the host when the withdrawal request
+	// has an expiry block height that is in the past.
+	errWithdrawalExpired = errors.New("withdrawal request expired")
+)
+
+// IsErrHost indicates whether an error was returned by a host as part of an RPC.
+func IsErrHost(err error) bool { + return utils.IsErr(err, errHost) +} + +func IsBalanceInsufficient(err error) bool { return utils.IsErr(err, ErrBalanceInsufficient) } +func IsBalanceMaxExceeded(err error) bool { return utils.IsErr(err, errBalanceMaxExceeded) } +func IsClosedStream(err error) bool { + return utils.IsErr(err, mux.ErrClosedStream) || utils.IsErr(err, net.ErrClosed) +} +func IsInsufficientFunds(err error) bool { return utils.IsErr(err, errInsufficientFunds) } +func IsPriceTableExpired(err error) bool { return utils.IsErr(err, errPriceTableExpired) } +func IsPriceTableGouging(err error) bool { return utils.IsErr(err, gouging.ErrPriceTableGouging) } +func IsPriceTableNotFound(err error) bool { return utils.IsErr(err, errPriceTableNotFound) } +func IsSectorNotFound(err error) bool { + return utils.IsErr(err, ErrSectorNotFound) || utils.IsErr(err, errSectorNotFoundOld) +} +func IsWithdrawalsInactive(err error) bool { return utils.IsErr(err, errWithdrawalsInactive) } +func IsWithdrawalExpired(err error) bool { return utils.IsErr(err, errWithdrawalExpired) } + +type ( + Dialer interface { + Dial(ctx context.Context, hk types.PublicKey, address string) (net.Conn, error) + } +) + +type Client struct { + logger *zap.SugaredLogger + tpool *transportPoolV3 +} + +func New(dialer Dialer, logger *zap.Logger) *Client { + return &Client{ + logger: logger.Sugar().Named("rhp3"), + tpool: newTransportPoolV3(dialer), + } +} + +func (c *Client) AppendSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev *types.FileContractRevision, hk types.PublicKey, siamuxAddr string, accID rhpv3.Account, pt rhpv3.HostPriceTable, rk types.PrivateKey) (types.Currency, error) { + expectedCost, _, _, err := uploadSectorCost(pt, rev.WindowEnd) + if err != nil { + return types.ZeroCurrency, err + } + payment, err := payByContract(rev, expectedCost, accID, rk) + if err != nil { + return types.ZeroCurrency, ErrFailedToCreatePayment + } + + var cost types.Currency + err = c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error { + cost, err = rpcAppendSector(ctx, t, rk, pt, rev, &payment, sectorRoot, sector) + return err + }) + return cost, err +} + +func (c *Client) FundAccount(ctx context.Context, rev *types.FileContractRevision, hk types.PublicKey, siamuxAddr string, amount types.Currency, accID rhpv3.Account, pt rhpv3.HostPriceTable, renterKey types.PrivateKey) error { + ppcr, err := payByContract(rev, amount.Add(types.NewCurrency64(1)), rhpv3.ZeroAccount, renterKey) + if err != nil { + return fmt.Errorf("failed to create payment: %w", err) + } + return c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error { + return rpcFundAccount(ctx, t, &ppcr, accID, pt.UID) + }) +} + +func (c *Client) Renew(ctx context.Context, rrr api.RHPRenewRequest, gougingChecker gouging.Checker, renewer PrepareRenewFunc, signer SignFunc, rev types.FileContractRevision, renterKey types.PrivateKey) (newRev rhpv2.ContractRevision, txnSet []types.Transaction, contractPrice, fundAmount types.Currency, err error) { + err = c.tpool.withTransport(ctx, rrr.HostKey, rrr.SiamuxAddr, func(ctx context.Context, t *transportV3) error { + newRev, txnSet, contractPrice, fundAmount, err = rpcRenew(ctx, rrr, gougingChecker, renewer, signer, t, rev, renterKey) + return err + }) + return +} + +func (c *Client) SyncAccount(ctx context.Context, rev *types.FileContractRevision, hk types.PublicKey, siamuxAddr string, accID rhpv3.Account, pt 
rhpv3.SettingsID, rk types.PrivateKey) (balance types.Currency, _ error) {
+	return balance, c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error {
+		payment, err := payByContract(rev, types.NewCurrency64(1), accID, rk)
+		if err != nil {
+			return err
+		}
+		balance, err = rpcAccountBalance(ctx, t, &payment, accID, pt)
+		return err
+	})
+}
+
+func (c *Client) PriceTable(ctx context.Context, hk types.PublicKey, siamuxAddr string, paymentFn PriceTablePaymentFunc) (pt api.HostPriceTable, err error) {
+	err = c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error {
+		pt, err = rpcPriceTable(ctx, t, paymentFn)
+		return err
+	})
+	return
+}
+
+func (c *Client) PriceTableUnpaid(ctx context.Context, hk types.PublicKey, siamuxAddr string) (pt api.HostPriceTable, err error) {
+	err = c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error {
+		pt, err = rpcPriceTable(ctx, t, func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) { return nil, nil })
+		if err != nil {
+			return fmt.Errorf("failed to fetch host price table: %w", err)
+		}
+		return err
+	})
+	return
+}
+
+func (c *Client) ReadSector(ctx context.Context, offset, length uint32, root types.Hash256, w io.Writer, hk types.PublicKey, siamuxAddr string, accID rhpv3.Account, accKey types.PrivateKey, pt rhpv3.HostPriceTable) (types.Currency, error) {
+	var amount types.Currency
+	err := c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error {
+		cost, err := readSectorCost(pt, uint64(length))
+		if err != nil {
+			return err
+		}
+
+		amount = cost // pessimistic cost estimate in case rpc fails
+		payment := rhpv3.PayByEphemeralAccount(accID, cost, pt.HostBlockHeight+defaultWithdrawalExpiryBlocks, accKey)
+		cost, refund, err := rpcReadSector(ctx, t, w, pt, &payment, offset, length, root)
+		if err != nil {
+			return err
+		}
+
+		amount = cost.Sub(refund)
+		return nil
+	})
+	return amount, err
+}
+
+func (c *Client) Revision(ctx context.Context, fcid types.FileContractID, hk types.PublicKey, siamuxAddr string) (rev types.FileContractRevision, err error) {
+	return rev, c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error {
+		rev, err = rpcLatestRevision(ctx, t, fcid)
+		return err
+	})
+}
+
+// PreparePriceTableAccountPayment prepares a payment function to pay for a
+// price table from the given host using the provided account key.
+//
+// NOTE: This is the preferred way of paying for a price table since it is
+// faster and doesn't require locking a contract.
+func PreparePriceTableAccountPayment(accKey types.PrivateKey) PriceTablePaymentFunc {
+	return func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) {
+		accID := rhpv3.Account(accKey.PublicKey())
+		payment := rhpv3.PayByEphemeralAccount(accID, pt.UpdatePriceTableCost, pt.HostBlockHeight+defaultWithdrawalExpiryBlocks, accKey)
+		return &payment, nil
+	}
+}
+
+// PreparePriceTableContractPayment prepares a payment function to pay for a
+// price table from the given host using the provided revision.
+//
+// NOTE: This way of paying for a price table should only be used if payment by
+// EA is not possible or if we already need a contract revision anyway, e.g.
+// when funding an EA.
+func PreparePriceTableContractPayment(rev *types.FileContractRevision, refundAccID rhpv3.Account, renterKey types.PrivateKey) PriceTablePaymentFunc {
+	return func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) {
+		payment, err := payByContract(rev, pt.UpdatePriceTableCost, refundAccID, renterKey)
+		if err != nil {
+			return nil, err
+		}
+		return &payment, nil
+	}
+}
+
+// padBandwidth pads the bandwidth to the next multiple of 1460 bytes. 1460
+// bytes is the maximum size of a TCP packet when using IPv4.
+// TODO: once hostd becomes the only host implementation we can simplify this.
+func padBandwidth(pt rhpv3.HostPriceTable, rc rhpv3.ResourceCost) rhpv3.ResourceCost {
+	padCost := func(cost, paddingSize types.Currency) types.Currency {
+		if paddingSize.IsZero() {
+			return cost // might happen if bandwidth is free
+		}
+		return cost.Add(paddingSize).Sub(types.NewCurrency64(1)).Div(paddingSize).Mul(paddingSize)
+	}
+	minPacketSize := uint64(1460)
+	minIngress := pt.UploadBandwidthCost.Mul64(minPacketSize)
+	minEgress := pt.DownloadBandwidthCost.Mul64(3*minPacketSize + responseLeeway)
+	rc.Ingress = padCost(rc.Ingress, minIngress)
+	rc.Egress = padCost(rc.Egress, minEgress)
+	return rc
+}
+
+// readSectorCost returns an overestimate for the cost of reading a sector from
+// a host.
+func readSectorCost(pt rhpv3.HostPriceTable, length uint64) (types.Currency, error) {
+	rc := pt.BaseCost()
+	rc = rc.Add(pt.ReadSectorCost(length))
+	rc = padBandwidth(pt, rc)
+	cost, _ := rc.Total()
+
+	// overestimate the cost by 10%
+	cost, overflow := cost.Mul64WithOverflow(11)
+	if overflow {
+		return types.ZeroCurrency, errors.New("overflow occurred while adding leeway to read sector cost")
+	}
+	return cost.Div64(10), nil
+}
+
+// uploadSectorCost returns an overestimate for the cost of uploading a sector
+// to a host.
+func uploadSectorCost(pt rhpv3.HostPriceTable, windowEnd uint64) (cost, collateral, storage types.Currency, _ error) {
+	rc := pt.BaseCost()
+	rc = rc.Add(pt.AppendSectorCost(windowEnd - pt.HostBlockHeight))
+	rc = padBandwidth(pt, rc)
+	cost, collateral = rc.Total()
+
+	// overestimate the cost by 10%
+	cost, overflow := cost.Mul64WithOverflow(11)
+	if overflow {
+		return types.ZeroCurrency, types.ZeroCurrency, types.ZeroCurrency, errors.New("overflow occurred while adding leeway to upload sector cost")
+	}
+	return cost.Div64(10), collateral, rc.Storage, nil
+}
+
+func processPayment(s *streamV3, payment rhpv3.PaymentMethod) error {
+	var paymentType types.Specifier
+	switch payment.(type) {
+	case *rhpv3.PayByContractRequest:
+		paymentType = rhpv3.PaymentTypeContract
+	case *rhpv3.PayByEphemeralAccountRequest:
+		paymentType = rhpv3.PaymentTypeEphemeralAccount
+	default:
+		panic("unhandled payment method")
+	}
+	if err := s.WriteResponse(&paymentType); err != nil {
+		return err
+	} else if err := s.WriteResponse(payment); err != nil {
+		return err
+	}
+	if _, ok := payment.(*rhpv3.PayByContractRequest); ok {
+		var pr rhpv3.PaymentResponse
+		if err := s.ReadResponse(&pr, defaultRPCResponseMaxSize); err != nil {
+			return err
+		}
+		// TODO: return host signature
+	}
+	return nil
+}
+
+func hashRevision(rev types.FileContractRevision) types.Hash256 {
+	h := types.NewHasher()
+	rev.EncodeTo(h.E)
+	return h.Sum()
+}
+
+// initialRevision returns the first revision of a file contract formation
+// transaction.
+func initialRevision(formationTxn types.Transaction, hostPubKey, renterPubKey types.UnlockKey) types.FileContractRevision { + fc := formationTxn.FileContracts[0] + return types.FileContractRevision{ + ParentID: formationTxn.FileContractID(0), + UnlockConditions: types.UnlockConditions{ + PublicKeys: []types.UnlockKey{renterPubKey, hostPubKey}, + SignaturesRequired: 2, + }, + FileContract: types.FileContract{ + Filesize: fc.Filesize, + FileMerkleRoot: fc.FileMerkleRoot, + WindowStart: fc.WindowStart, + WindowEnd: fc.WindowEnd, + ValidProofOutputs: fc.ValidProofOutputs, + MissedProofOutputs: fc.MissedProofOutputs, + UnlockHash: fc.UnlockHash, + RevisionNumber: 1, + }, + } +} + +func payByContract(rev *types.FileContractRevision, amount types.Currency, refundAcct rhpv3.Account, sk types.PrivateKey) (rhpv3.PayByContractRequest, error) { + if rev.RevisionNumber == math.MaxUint64 { + return rhpv3.PayByContractRequest{}, ErrMaxRevisionReached + } + payment, ok := rhpv3.PayByContract(rev, amount, refundAcct, sk) + if !ok { + return rhpv3.PayByContractRequest{}, errInsufficientFunds + } + return payment, nil +} + +func updateRevisionOutputs(rev *types.FileContractRevision, cost, collateral types.Currency) (valid, missed []types.Currency, err error) { + // allocate new slices; don't want to risk accidentally sharing memory + rev.ValidProofOutputs = append([]types.SiacoinOutput(nil), rev.ValidProofOutputs...) + rev.MissedProofOutputs = append([]types.SiacoinOutput(nil), rev.MissedProofOutputs...) + + // move valid payout from renter to host + var underflow, overflow bool + rev.ValidProofOutputs[0].Value, underflow = rev.ValidProofOutputs[0].Value.SubWithUnderflow(cost) + rev.ValidProofOutputs[1].Value, overflow = rev.ValidProofOutputs[1].Value.AddWithOverflow(cost) + if underflow || overflow { + err = errors.New("insufficient funds to pay host") + return + } + + // move missed payout from renter to void + rev.MissedProofOutputs[0].Value, underflow = rev.MissedProofOutputs[0].Value.SubWithUnderflow(cost) + rev.MissedProofOutputs[2].Value, overflow = rev.MissedProofOutputs[2].Value.AddWithOverflow(cost) + if underflow || overflow { + err = errors.New("insufficient funds to move missed payout to void") + return + } + + // move collateral from host to void + rev.MissedProofOutputs[1].Value, underflow = rev.MissedProofOutputs[1].Value.SubWithUnderflow(collateral) + rev.MissedProofOutputs[2].Value, overflow = rev.MissedProofOutputs[2].Value.AddWithOverflow(collateral) + if underflow || overflow { + err = errors.New("insufficient collateral") + return + } + + return []types.Currency{rev.ValidProofOutputs[0].Value, rev.ValidProofOutputs[1].Value}, + []types.Currency{rev.MissedProofOutputs[0].Value, rev.MissedProofOutputs[1].Value, rev.MissedProofOutputs[2].Value}, nil +} diff --git a/internal/rhp/v3/rpc.go b/internal/rhp/v3/rpc.go new file mode 100644 index 000000000..746cac5fe --- /dev/null +++ b/internal/rhp/v3/rpc.go @@ -0,0 +1,500 @@ +package rhp + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "time" + + rhpv2 "go.sia.tech/core/rhp/v2" + rhpv3 "go.sia.tech/core/rhp/v3" + "go.sia.tech/core/types" + "go.sia.tech/mux/v1" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/gouging" + "go.sia.tech/renterd/internal/utils" +) + +type ( + // PriceTablePaymentFunc is a function that can be passed in to RPCPriceTable. + // It is called after the price table is received from the host and supposed to + // create a payment for that table and return it. 
It can also be used to perform + // gouging checks before paying for the table. + PriceTablePaymentFunc func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) + + PrepareRenewFunc func(ctx context.Context, revision types.FileContractRevision, hostAddress, renterAddress types.Address, renterKey types.PrivateKey, renterFunds, minNewCollateral, maxFundAmount types.Currency, pt rhpv3.HostPriceTable, endHeight, windowSize, expectedStorage uint64) (resp api.WalletPrepareRenewResponse, discard func(context.Context, types.Transaction, *error), err error) + + SignFunc func(ctx context.Context, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error +) + +// rpcPriceTable calls the UpdatePriceTable RPC. +func rpcPriceTable(ctx context.Context, t *transportV3, paymentFunc PriceTablePaymentFunc) (_ api.HostPriceTable, err error) { + defer utils.WrapErr(ctx, "PriceTable", &err) + + s, err := t.DialStream(ctx) + if err != nil { + return api.HostPriceTable{}, err + } + defer s.Close() + + var pt rhpv3.HostPriceTable + var ptr rhpv3.RPCUpdatePriceTableResponse + if err := s.WriteRequest(rhpv3.RPCUpdatePriceTableID, nil); err != nil { + return api.HostPriceTable{}, fmt.Errorf("couldn't send RPCUpdatePriceTableID: %w", err) + } else if err := s.ReadResponse(&ptr, maxPriceTableSize); err != nil { + return api.HostPriceTable{}, fmt.Errorf("couldn't read RPCUpdatePriceTableResponse: %w", err) + } else if err := json.Unmarshal(ptr.PriceTableJSON, &pt); err != nil { + return api.HostPriceTable{}, fmt.Errorf("couldn't unmarshal price table: %w", err) + } else if payment, err := paymentFunc(pt); err != nil { + return api.HostPriceTable{}, fmt.Errorf("couldn't create payment: %w", err) + } else if payment == nil { + return api.HostPriceTable{ + HostPriceTable: pt, + Expiry: time.Now(), + }, nil // intended not to pay + } else if err := processPayment(s, payment); err != nil { + return api.HostPriceTable{}, fmt.Errorf("couldn't process payment: %w", err) + } else if err := s.ReadResponse(&rhpv3.RPCPriceTableResponse{}, 0); err != nil { + return api.HostPriceTable{}, fmt.Errorf("couldn't read RPCPriceTableResponse: %w", err) + } else { + return api.HostPriceTable{ + HostPriceTable: pt, + Expiry: time.Now().Add(pt.Validity), + }, nil + } +} + +// rpcAccountBalance calls the AccountBalance RPC. +func rpcAccountBalance(ctx context.Context, t *transportV3, payment rhpv3.PaymentMethod, account rhpv3.Account, settingsID rhpv3.SettingsID) (bal types.Currency, err error) { + defer utils.WrapErr(ctx, "AccountBalance", &err) + s, err := t.DialStream(ctx) + if err != nil { + return types.ZeroCurrency, err + } + defer s.Close() + + req := rhpv3.RPCAccountBalanceRequest{ + Account: account, + } + var resp rhpv3.RPCAccountBalanceResponse + if err := s.WriteRequest(rhpv3.RPCAccountBalanceID, &settingsID); err != nil { + return types.ZeroCurrency, err + } else if err := processPayment(s, payment); err != nil { + return types.ZeroCurrency, err + } else if err := s.WriteResponse(&req); err != nil { + return types.ZeroCurrency, err + } else if err := s.ReadResponse(&resp, 128); err != nil { + return types.ZeroCurrency, err + } + return resp.Balance, nil +} + +// rpcFundAccount calls the FundAccount RPC. 
+func rpcFundAccount(ctx context.Context, t *transportV3, payment rhpv3.PaymentMethod, account rhpv3.Account, settingsID rhpv3.SettingsID) (err error) {
+	defer utils.WrapErr(ctx, "FundAccount", &err)
+	s, err := t.DialStream(ctx)
+	if err != nil {
+		return err
+	}
+	defer s.Close()
+
+	req := rhpv3.RPCFundAccountRequest{
+		Account: account,
+	}
+	var resp rhpv3.RPCFundAccountResponse
+	if err := s.WriteRequest(rhpv3.RPCFundAccountID, &settingsID); err != nil {
+		return err
+	} else if err := s.WriteResponse(&req); err != nil {
+		return err
+	} else if err := processPayment(s, payment); err != nil {
+		return err
+	} else if err := s.ReadResponse(&resp, defaultRPCResponseMaxSize); err != nil {
+		return err
+	}
+	return nil
+}
+
+// rpcLatestRevision calls the LatestRevision RPC.
+func rpcLatestRevision(ctx context.Context, t *transportV3, contractID types.FileContractID) (_ types.FileContractRevision, err error) {
+	defer utils.WrapErr(ctx, "LatestRevision", &err)
+	s, err := t.DialStream(ctx)
+	if err != nil {
+		return types.FileContractRevision{}, err
+	}
+	defer s.Close()
+	req := rhpv3.RPCLatestRevisionRequest{
+		ContractID: contractID,
+	}
+	var resp rhpv3.RPCLatestRevisionResponse
+	if err := s.WriteRequest(rhpv3.RPCLatestRevisionID, &req); err != nil {
+		return types.FileContractRevision{}, err
+	} else if err := s.ReadResponse(&resp, defaultRPCResponseMaxSize); err != nil {
+		return types.FileContractRevision{}, err
+	}
+	return resp.Revision, nil
+}
+
+// rpcReadSector calls the ExecuteProgram RPC with a ReadSector instruction.
+func rpcReadSector(ctx context.Context, t *transportV3, w io.Writer, pt rhpv3.HostPriceTable, payment rhpv3.PaymentMethod, offset, length uint32, merkleRoot types.Hash256) (cost, refund types.Currency, err error) {
+	defer utils.WrapErr(ctx, "ReadSector", &err)
+	s, err := t.DialStream(ctx)
+	if err != nil {
+		return types.ZeroCurrency, types.ZeroCurrency, err
+	}
+	defer s.Close()
+
+	var buf bytes.Buffer
+	e := types.NewEncoder(&buf)
+	e.WriteUint64(uint64(length))
+	e.WriteUint64(uint64(offset))
+	merkleRoot.EncodeTo(e)
+	e.Flush()
+
+	req := rhpv3.RPCExecuteProgramRequest{
+		FileContractID: types.FileContractID{},
+		Program: []rhpv3.Instruction{&rhpv3.InstrReadSector{
+			LengthOffset:     0,
+			OffsetOffset:     8,
+			MerkleRootOffset: 16,
+			ProofRequired:    true,
+		}},
+		ProgramData: buf.Bytes(),
+	}
+
+	var cancellationToken types.Specifier
+	var resp rhpv3.RPCExecuteProgramResponse
+	if err = s.WriteRequest(rhpv3.RPCExecuteProgramID, &pt.UID); err != nil {
+		return
+	} else if err = processPayment(s, payment); err != nil {
+		return
+	} else if err = s.WriteResponse(&req); err != nil {
+		return
+	} else if err = s.ReadResponse(&cancellationToken, 16); err != nil {
+		return
+	} else if err = s.ReadResponse(&resp, rhpv2.SectorSize+responseLeeway); err != nil {
+		return
+	}
+
+	// check response error
+	if err = resp.Error; err != nil {
+		refund = resp.FailureRefund
+		return
+	}
+	cost = resp.TotalCost
+
+	// verify proof
+	proofStart := uint64(offset) / rhpv2.LeafSize
+	proofEnd := uint64(offset+length) / rhpv2.LeafSize
+	verifier := rhpv2.NewRangeProofVerifier(proofStart, proofEnd)
+	_, err = verifier.ReadFrom(bytes.NewReader(resp.Output))
+	if err != nil {
+		err = fmt.Errorf("failed to read proof: %w", err)
+		return
+	} else if !verifier.Verify(resp.Proof, merkleRoot) {
+		err = errors.New("proof
verification failed") + return + } + + _, err = w.Write(resp.Output) + return +} + +func rpcAppendSector(ctx context.Context, t *transportV3, renterKey types.PrivateKey, pt rhpv3.HostPriceTable, rev *types.FileContractRevision, payment rhpv3.PaymentMethod, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte) (cost types.Currency, err error) { + defer utils.WrapErr(ctx, "AppendSector", &err) + + // sanity check revision first + if rev.RevisionNumber == math.MaxUint64 { + return types.ZeroCurrency, ErrMaxRevisionReached + } + + s, err := t.DialStream(ctx) + if err != nil { + return types.ZeroCurrency, err + } + defer s.Close() + + req := rhpv3.RPCExecuteProgramRequest{ + FileContractID: rev.ParentID, + Program: []rhpv3.Instruction{&rhpv3.InstrAppendSector{ + SectorDataOffset: 0, + ProofRequired: true, + }}, + ProgramData: (*sector)[:], + } + + var cancellationToken types.Specifier + var executeResp rhpv3.RPCExecuteProgramResponse + if err = s.WriteRequest(rhpv3.RPCExecuteProgramID, &pt.UID); err != nil { + return + } else if err = processPayment(s, payment); err != nil { + return + } else if err = s.WriteResponse(&req); err != nil { + return + } else if err = s.ReadResponse(&cancellationToken, 16); err != nil { + return + } else if err = s.ReadResponse(&executeResp, defaultRPCResponseMaxSize); err != nil { + return + } + + // compute expected collateral and refund + expectedCost, expectedCollateral, expectedRefund, err := uploadSectorCost(pt, rev.WindowEnd) + if err != nil { + return types.ZeroCurrency, err + } + + // apply leeways. + // TODO: remove once most hosts use hostd. Then we can check for exact values. + expectedCollateral = expectedCollateral.Mul64(9).Div64(10) + expectedCost = expectedCost.Mul64(11).Div64(10) + expectedRefund = expectedRefund.Mul64(9).Div64(10) + + // check if the cost, collateral and refund match our expectation. + if executeResp.TotalCost.Cmp(expectedCost) > 0 { + return types.ZeroCurrency, fmt.Errorf("cost exceeds expectation: %v > %v", executeResp.TotalCost.String(), expectedCost.String()) + } + if executeResp.FailureRefund.Cmp(expectedRefund) < 0 { + return types.ZeroCurrency, fmt.Errorf("insufficient refund: %v < %v", executeResp.FailureRefund.String(), expectedRefund.String()) + } + if executeResp.AdditionalCollateral.Cmp(expectedCollateral) < 0 { + return types.ZeroCurrency, fmt.Errorf("insufficient collateral: %v < %v", executeResp.AdditionalCollateral.String(), expectedCollateral.String()) + } + + // set the cost and refund + cost = executeResp.TotalCost + defer func() { + if err != nil { + cost = types.ZeroCurrency + if executeResp.FailureRefund.Cmp(cost) < 0 { + cost = cost.Sub(executeResp.FailureRefund) + } + } + }() + + // check response error + if err = executeResp.Error; err != nil { + return + } + cost = executeResp.TotalCost + + // include the refund in the collateral + collateral := executeResp.AdditionalCollateral.Add(executeResp.FailureRefund) + + // check proof + if rev.Filesize == 0 { + // For the first upload to a contract we don't get a proof. So we just + // assert that the new contract root matches the root of the sector. + if rev.Filesize == 0 && executeResp.NewMerkleRoot != sectorRoot { + return types.ZeroCurrency, fmt.Errorf("merkle root doesn't match the sector root upon first upload to contract: %v != %v", executeResp.NewMerkleRoot, sectorRoot) + } + } else { + // Otherwise we make sure the proof was transmitted and verify it. 
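+		// A single-append diff proof carries no old leaf hashes, hence the
+		// empty slice; the appended sector root is supplied as the new leaf.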
+ actions := []rhpv2.RPCWriteAction{{Type: rhpv2.RPCWriteActionAppend}} // TODO: change once rhpv3 support is available + if !rhpv2.VerifyDiffProof(actions, rev.Filesize/rhpv2.SectorSize, executeResp.Proof, []types.Hash256{}, rev.FileMerkleRoot, executeResp.NewMerkleRoot, []types.Hash256{sectorRoot}) { + return types.ZeroCurrency, errors.New("proof verification failed") + } + } + + // finalize the program with a new revision. + newRevision := *rev + newValid, newMissed, err := updateRevisionOutputs(&newRevision, types.ZeroCurrency, collateral) + if err != nil { + return types.ZeroCurrency, err + } + newRevision.Filesize += rhpv2.SectorSize + newRevision.RevisionNumber++ + newRevision.FileMerkleRoot = executeResp.NewMerkleRoot + + finalizeReq := rhpv3.RPCFinalizeProgramRequest{ + Signature: renterKey.SignHash(hashRevision(newRevision)), + ValidProofValues: newValid, + MissedProofValues: newMissed, + RevisionNumber: newRevision.RevisionNumber, + } + + var finalizeResp rhpv3.RPCFinalizeProgramResponse + if err = s.WriteResponse(&finalizeReq); err != nil { + return + } else if err = s.ReadResponse(&finalizeResp, 64); err != nil { + return + } + + // read one more time to receive a potential error in case finalising the + // contract fails after receiving the RPCFinalizeProgramResponse. This also + // guarantees that the program is finalised before we return. + // TODO: remove once most hosts use hostd. + errFinalise := s.ReadResponse(&finalizeResp, 64) + if errFinalise != nil && + !errors.Is(errFinalise, io.EOF) && + !errors.Is(errFinalise, mux.ErrClosedConn) && + !errors.Is(errFinalise, mux.ErrClosedStream) && + !errors.Is(errFinalise, mux.ErrPeerClosedStream) && + !errors.Is(errFinalise, mux.ErrPeerClosedConn) { + err = errFinalise + return + } + + *rev = newRevision + return +} + +func rpcRenew(ctx context.Context, rrr api.RHPRenewRequest, gougingChecker gouging.Checker, prepareRenew PrepareRenewFunc, signTxn SignFunc, t *transportV3, rev types.FileContractRevision, renterKey types.PrivateKey) (_ rhpv2.ContractRevision, _ []types.Transaction, _, _ types.Currency, err error) { + defer utils.WrapErr(ctx, "RPCRenew", &err) + + s, err := t.DialStream(ctx) + if err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to dial stream: %w", err) + } + defer s.Close() + + // Send the ptUID. + if err = s.WriteRequest(rhpv3.RPCRenewContractID, &rhpv3.SettingsID{}); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to send ptUID: %w", err) + } + + // Read the temporary one from the host. + var ptResp rhpv3.RPCUpdatePriceTableResponse + if err = s.ReadResponse(&ptResp, defaultRPCResponseMaxSize); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to read RPCUpdatePriceTableResponse: %w", err) + } + var pt rhpv3.HostPriceTable + if err = json.Unmarshal(ptResp.PriceTableJSON, &pt); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to unmarshal price table: %w", err) + } + + // Perform gouging checks. 
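+	// Settings are intentionally nil here; only the temporary price table
+	// the host just provided is checked for gouging.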
+ if breakdown := gougingChecker.Check(nil, &pt); breakdown.Gouging() { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("host gouging during renew: %v", breakdown) + } + + // Prepare the signed transaction that contains the final revision as well + // as the new contract + wprr, discard, err := prepareRenew(ctx, rev, rrr.HostAddress, rrr.RenterAddress, renterKey, rrr.RenterFunds, rrr.MinNewCollateral, rrr.MaxFundAmount, pt, rrr.EndHeight, rrr.WindowSize, rrr.ExpectedNewStorage) + if err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to prepare renew: %w", err) + } + + // Starting from here, we need to make sure to release the txn on error. + defer discard(ctx, wprr.TransactionSet[len(wprr.TransactionSet)-1], &err) + + txnSet := wprr.TransactionSet + parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] + + // Sign only the revision and contract. We can't sign everything because + // then the host can't add its own outputs. + h := types.NewHasher() + txn.FileContracts[0].EncodeTo(h.E) + txn.FileContractRevisions[0].EncodeTo(h.E) + finalRevisionSignature := renterKey.SignHash(h.Sum()) + + // Send the request. + req := rhpv3.RPCRenewContractRequest{ + TransactionSet: txnSet, + RenterKey: rev.UnlockConditions.PublicKeys[0], + FinalRevisionSignature: finalRevisionSignature, + } + if err = s.WriteResponse(&req); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to send RPCRenewContractRequest: %w", err) + } + + // Incorporate the host's additions. + var hostAdditions rhpv3.RPCRenewContractHostAdditions + if err = s.ReadResponse(&hostAdditions, defaultRPCResponseMaxSize); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to read RPCRenewContractHostAdditions: %w", err) + } + parents = append(parents, hostAdditions.Parents...) + txn.SiacoinInputs = append(txn.SiacoinInputs, hostAdditions.SiacoinInputs...) + txn.SiacoinOutputs = append(txn.SiacoinOutputs, hostAdditions.SiacoinOutputs...) + finalRevRenterSig := types.TransactionSignature{ + ParentID: types.Hash256(rev.ParentID), + PublicKeyIndex: 0, // renter key is first + CoveredFields: types.CoveredFields{ + FileContracts: []uint64{0}, + FileContractRevisions: []uint64{0}, + }, + Signature: finalRevisionSignature[:], + } + finalRevHostSig := types.TransactionSignature{ + ParentID: types.Hash256(rev.ParentID), + PublicKeyIndex: 1, + CoveredFields: types.CoveredFields{ + FileContracts: []uint64{0}, + FileContractRevisions: []uint64{0}, + }, + Signature: hostAdditions.FinalRevisionSignature[:], + } + txn.Signatures = []types.TransactionSignature{finalRevRenterSig, finalRevHostSig} + + // Sign the inputs we funded the txn with and cover the whole txn including + // the existing signatures. + cf := types.CoveredFields{ + WholeTransaction: true, + Signatures: []uint64{0, 1}, + } + if err := signTxn(ctx, &txn, wprr.ToSign, cf); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to sign transaction: %w", err) + } + + // Create a new no-op revision and sign it. 
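+	// The no-op revision is revision number 1 of the renewed contract: it
+	// changes nothing, but gives both parties a mutually signed starting
+	// point, mirroring contract formation.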
+ noOpRevision := initialRevision(txn, rev.UnlockConditions.PublicKeys[1], renterKey.PublicKey().UnlockKey()) + h = types.NewHasher() + noOpRevision.EncodeTo(h.E) + renterNoOpSig := renterKey.SignHash(h.Sum()) + renterNoOpRevisionSignature := types.TransactionSignature{ + ParentID: types.Hash256(noOpRevision.ParentID), + PublicKeyIndex: 0, // renter key is first + CoveredFields: types.CoveredFields{ + FileContractRevisions: []uint64{0}, + }, + Signature: renterNoOpSig[:], + } + + // Send the newly added signatures to the host and the signature for the + // initial no-op revision. + rs := rhpv3.RPCRenewSignatures{ + TransactionSignatures: txn.Signatures[2:], + RevisionSignature: renterNoOpRevisionSignature, + } + if err = s.WriteResponse(&rs); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to send RPCRenewSignatures: %w", err) + } + + // Receive the host's signatures. + var hostSigs rhpv3.RPCRenewSignatures + if err = s.ReadResponse(&hostSigs, defaultRPCResponseMaxSize); err != nil { + return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to read RPCRenewSignatures: %w", err) + } + txn.Signatures = append(txn.Signatures, hostSigs.TransactionSignatures...) + + // Add the parents to get the full txnSet. + txnSet = parents + txnSet = append(txnSet, txn) + + return rhpv2.ContractRevision{ + Revision: noOpRevision, + Signatures: [2]types.TransactionSignature{renterNoOpRevisionSignature, hostSigs.RevisionSignature}, + }, txnSet, pt.ContractPrice, wprr.FundAmount, nil +} + +// wrapRPCErr extracts the innermost error, wraps it in either a errHost or +// errTransport and finally wraps it using the provided fnName. +func wrapRPCErr(err *error, fnName string) { + if *err == nil { + return + } + innerErr := *err + for errors.Unwrap(innerErr) != nil { + innerErr = errors.Unwrap(innerErr) + } + if errors.As(*err, new(*rhpv3.RPCError)) { + *err = fmt.Errorf("%w: '%w'", errHost, innerErr) + } else { + *err = fmt.Errorf("%w: '%w'", errTransport, innerErr) + } + *err = fmt.Errorf("%s: %w", fnName, *err) +} diff --git a/worker/rhpv3_test.go b/internal/rhp/v3/rpc_test.go similarity index 98% rename from worker/rhpv3_test.go rename to internal/rhp/v3/rpc_test.go index 83f605807..2a4545793 100644 --- a/worker/rhpv3_test.go +++ b/internal/rhp/v3/rpc_test.go @@ -1,4 +1,4 @@ -package worker +package rhp import ( "errors" diff --git a/internal/rhp/v3/transport.go b/internal/rhp/v3/transport.go new file mode 100644 index 000000000..e08d88d45 --- /dev/null +++ b/internal/rhp/v3/transport.go @@ -0,0 +1,172 @@ +package rhp + +import ( + "context" + "fmt" + "sync" + "time" + + rhpv3 "go.sia.tech/core/rhp/v3" + "go.sia.tech/core/types" +) + +type transportPoolV3 struct { + dialer Dialer + + mu sync.Mutex + pool map[string]*transportV3 +} + +func newTransportPoolV3(dialer Dialer) *transportPoolV3 { + return &transportPoolV3{ + dialer: dialer, + pool: make(map[string]*transportV3), + } +} + +func (p *transportPoolV3) withTransport(ctx context.Context, hostKey types.PublicKey, siamuxAddr string, fn func(context.Context, *transportV3) error) (err error) { + // Create or fetch transport. + p.mu.Lock() + t, found := p.pool[siamuxAddr] + if !found { + t = &transportV3{ + dialer: p.dialer, + hostKey: hostKey, + siamuxAddr: siamuxAddr, + } + p.pool[siamuxAddr] = t + } + t.refCount++ + p.mu.Unlock() + + // Execute function. 
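The closure that follows guards the callback with recover so that a panicking RPC cannot skip the refcount decrement further down. Stripped of its surroundings, it is this stdlib-only pattern:

// safeCall converts a panic in fn into an ordinary error so that deferred
// cleanup after the call still runs; a generic restatement of the guard below.
func safeCall(fn func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic: %v", r)
		}
	}()
	return fn()
}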
+ err = func() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("panic (withTransportV3): %v", r)
+ }
+ }()
+ return fn(ctx, t)
+ }()
+
+ // Decrement refcounter again and clean up pool.
+ p.mu.Lock()
+ t.refCount--
+ if t.refCount == 0 {
+ // Cleanup
+ if t.t != nil {
+ _ = t.t.Close()
+ t.t = nil
+ }
+ delete(p.pool, siamuxAddr)
+ }
+ p.mu.Unlock()
+ return err
+}
+
+// dialTransport dials the host at siamuxAddr and upgrades the connection to
+// an rhpv3.Transport, honouring context cancellation during the handshake.
+func dialTransport(ctx context.Context, dialer Dialer, siamuxAddr string, hostKey types.PublicKey) (*rhpv3.Transport, error) {
+ // Dial host.
+ conn, err := dialer.Dial(ctx, hostKey, siamuxAddr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Upgrade to rhpv3.Transport.
+ var t *rhpv3.Transport
+ done := make(chan struct{})
+ go func() {
+ t, err = rhpv3.NewRenterTransport(conn, hostKey)
+ close(done)
+ }()
+ select {
+ case <-ctx.Done():
+ conn.Close()
+ <-done
+ return nil, context.Cause(ctx)
+ case <-done:
+ return t, err
+ }
+}
+
+// transportV3 is a reference-counted wrapper for rhpv3.Transport.
+type transportV3 struct {
+ dialer Dialer
+ refCount uint64 // locked by pool
+
+ mu sync.Mutex
+ hostKey types.PublicKey
+ siamuxAddr string
+ t *rhpv3.Transport
+}
+
+// DialStream dials a new stream on the transport.
+func (t *transportV3) DialStream(ctx context.Context) (*streamV3, error) {
+ t.mu.Lock()
+ if t.t == nil {
+ start := time.Now()
+ newTransport, err := dialTransport(ctx, t.dialer, t.siamuxAddr, t.hostKey)
+ if err != nil {
+ t.mu.Unlock()
+ return nil, fmt.Errorf("DialStream: %w: %w (%v)", ErrDialTransport, err, time.Since(start))
+ }
+ t.t = newTransport
+ }
+ transport := t.t
+ t.mu.Unlock()
+
+ // Dial a stream on the transport.
+ stream := transport.DialStream()
+
+ // Apply a sane timeout to the stream.
+ if err := stream.SetDeadline(time.Now().Add(5 * time.Minute)); err != nil {
+ _ = stream.Close()
+ return nil, err
+ }
+
+ // Make sure the stream is closed when the context is closed to unblock
+ // any pending reads or writes.
+ doneCtx, doneFn := context.WithCancel(ctx)
+ go func() {
+ select {
+ case <-doneCtx.Done():
+ case <-ctx.Done():
+ _ = stream.Close()
+ }
+ }()
+ return &streamV3{
+ Stream: stream,
+ cancel: doneFn,
+ }, nil
+}
+
+type streamV3 struct {
+ cancel context.CancelFunc
+ *rhpv3.Stream
+}
+
+func (s *streamV3) ReadResponse(resp rhpv3.ProtocolObject, maxLen uint64) (err error) {
+ defer wrapRPCErr(&err, "ReadResponse")
+ return s.Stream.ReadResponse(resp, maxLen)
+}
+
+func (s *streamV3) WriteResponse(resp rhpv3.ProtocolObject) (err error) {
+ defer wrapRPCErr(&err, "WriteResponse")
+ return s.Stream.WriteResponse(resp)
+}
+
+func (s *streamV3) ReadRequest(req rhpv3.ProtocolObject, maxLen uint64) (err error) {
+ defer wrapRPCErr(&err, "ReadRequest")
+ return s.Stream.ReadRequest(req, maxLen)
+}
+
+func (s *streamV3) WriteRequest(rpcID types.Specifier, req rhpv3.ProtocolObject) (err error) {
+ defer wrapRPCErr(&err, "WriteRequest")
+ return s.Stream.WriteRequest(rpcID, req)
+}
+
+// Close closes the stream and cancels the goroutine launched by DialStream.
+func (s *streamV3) Close() error { + s.cancel() + return s.Stream.Close() +} diff --git a/internal/sql/migrations.go b/internal/sql/migrations.go index d29e88fe3..377bf6fc5 100644 --- a/internal/sql/migrations.go +++ b/internal/sql/migrations.go @@ -181,6 +181,30 @@ var ( return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00011_host_subnets", log) }, }, + { + ID: "00012_peer_store", + Migrate: func(tx Tx) error { + return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00012_peer_store", log) + }, + }, + { + ID: "00013_coreutils_wallet", + Migrate: func(tx Tx) error { + return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00013_coreutils_wallet", log) + }, + }, + { + ID: "00014_hosts_resolvedaddresses", + Migrate: func(tx Tx) error { + return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00014_hosts_resolvedaddresses", log) + }, + }, + { + ID: "00015_reset_drift", + Migrate: func(tx Tx) error { + return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00015_reset_drift", log) + }, + }, } } MetricsMigrations = func(ctx context.Context, migrationsFs embed.FS, log *zap.SugaredLogger) []Migration { @@ -196,6 +220,12 @@ var ( return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00001_idx_contracts_fcid_timestamp", log) }, }, + { + ID: "00002_idx_wallet_metrics_immature", + Migrate: func(tx Tx) error { + return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00002_idx_wallet_metrics_immature", log) + }, + }, } } ) diff --git a/internal/sql/sql.go b/internal/sql/sql.go index 23b499213..b677e97fd 100644 --- a/internal/sql/sql.go +++ b/internal/sql/sql.go @@ -20,6 +20,7 @@ const ( factor = 1.8 // factor ^ retryAttempts = backoff time in milliseconds maxBackoff = 15 * time.Second + ConsensusInfoID = 1 DirectoriesRootID = 1 ) diff --git a/internal/test/config.go b/internal/test/config.go index abf6caaac..1b5d926a0 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -52,9 +52,7 @@ var ( MinMaxEphemeralAccountBalance: types.Siacoins(1), // 1SC } - PricePinSettings = api.PricePinSettings{ - Enabled: false, - } + PricePinSettings = api.DefaultPricePinSettings RedundancySettings = api.RedundancySettings{ MinShards: 2, diff --git a/internal/test/e2e/blocklist_test.go b/internal/test/e2e/blocklist_test.go index 64acc2fba..94659b277 100644 --- a/internal/test/e2e/blocklist_test.go +++ b/internal/test/e2e/blocklist_test.go @@ -12,9 +12,11 @@ import ( ) func TestBlocklist(t *testing.T) { - if testing.Short() { - t.SkipNow() - } + t.SkipNow() // TODO: re-enable this test + + // if testing.Short() { + // t.SkipNow() + // } ctx := context.Background() @@ -27,7 +29,8 @@ func TestBlocklist(t *testing.T) { tt := cluster.tt // fetch contracts - contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set}) + opts := api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set} + contracts, err := b.Contracts(ctx, opts) tt.OK(err) if len(contracts) != 3 { t.Fatalf("unexpected number of contracts, %v != 3", len(contracts)) @@ -37,14 +40,15 @@ func TestBlocklist(t *testing.T) { hk1 := contracts[0].HostKey hk2 := contracts[1].HostKey hk3 := contracts[2].HostKey - b.UpdateHostAllowlist(ctx, []types.PublicKey{hk1, hk2}, nil, false) + err = b.UpdateHostAllowlist(ctx, []types.PublicKey{hk1, hk2}, nil, false) + tt.OK(err) // assert h3 is no longer in the contract set - tt.Retry(5, time.Second, func() error { - contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: 
test.AutopilotConfig.Contracts.Set}) + tt.Retry(100, 100*time.Millisecond, func() error { + contracts, err := b.Contracts(ctx, opts) tt.OK(err) if len(contracts) != 2 { - return fmt.Errorf("unexpected number of contracts, %v != 2", len(contracts)) + return fmt.Errorf("unexpected number of contracts in set '%v', %v != 2", opts.ContractSet, len(contracts)) } for _, c := range contracts { if c.HostKey == hk3 { @@ -60,11 +64,11 @@ func TestBlocklist(t *testing.T) { tt.OK(b.UpdateHostBlocklist(ctx, []string{h1.NetAddress}, nil, false)) // assert h1 is no longer in the contract set - tt.Retry(5, time.Second, func() error { + tt.Retry(100, 100*time.Millisecond, func() error { contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set}) tt.OK(err) if len(contracts) != 1 { - return fmt.Errorf("unexpected number of contracts, %v != 1", len(contracts)) + return fmt.Errorf("unexpected number of contracts in set '%v', %v != 1", opts.ContractSet, len(contracts)) } for _, c := range contracts { if c.HostKey == hk1 { @@ -77,11 +81,11 @@ func TestBlocklist(t *testing.T) { // clear the allowlist and blocklist and assert we have 3 contracts again tt.OK(b.UpdateHostAllowlist(ctx, nil, []types.PublicKey{hk1, hk2}, false)) tt.OK(b.UpdateHostBlocklist(ctx, nil, []string{h1.NetAddress}, false)) - tt.Retry(5, time.Second, func() error { - contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set}) + tt.Retry(100, 100*time.Millisecond, func() error { + contracts, err := b.Contracts(ctx, opts) tt.OK(err) if len(contracts) != 3 { - return fmt.Errorf("unexpected number of contracts, %v != 3", len(contracts)) + return fmt.Errorf("unexpected number of contracts in set '%v', %v != 3", opts.ContractSet, len(contracts)) } return nil }) diff --git a/internal/test/e2e/cluster.go b/internal/test/e2e/cluster.go index 6fd9f5673..3e01e8ae7 100644 --- a/internal/test/e2e/cluster.go +++ b/internal/test/e2e/cluster.go @@ -15,32 +15,37 @@ import ( "github.com/minio/minio-go/v7" "go.sia.tech/core/consensus" + "go.sia.tech/core/gateway" "go.sia.tech/core/types" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/jape" + "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" "go.sia.tech/renterd/autopilot" - "go.sia.tech/renterd/build" "go.sia.tech/renterd/bus" "go.sia.tech/renterd/config" - "go.sia.tech/renterd/internal/node" "go.sia.tech/renterd/internal/test" "go.sia.tech/renterd/internal/utils" - iworker "go.sia.tech/renterd/internal/worker" "go.sia.tech/renterd/stores" + "go.sia.tech/renterd/stores/sql" + "go.sia.tech/renterd/stores/sql/mysql" + "go.sia.tech/renterd/stores/sql/sqlite" + "go.sia.tech/renterd/webhooks" "go.sia.tech/renterd/worker/s3" "go.sia.tech/web/renterd" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "gorm.io/gorm" + "golang.org/x/crypto/blake2b" "lukechampine.com/frand" "go.sia.tech/renterd/worker" ) const ( - testBusFlushInterval = 100 * time.Millisecond - testBusPersistInterval = 2 * time.Second - latestHardforkHeight = 50 // foundation hardfork height in testing + testBusFlushInterval = 100 * time.Millisecond ) var ( @@ -64,15 +69,22 @@ type TestCluster struct { autopilotShutdownFns []func(context.Context) error s3ShutdownFns []func(context.Context) error - network *consensus.Network - miner *node.Miner - apID string - dbName string - dir string - logger *zap.Logger - tt test.TT - wk types.PrivateKey - wg sync.WaitGroup + network 
*consensus.Network + genesisBlock types.Block + cm *chain.Manager + apID string + dbName string + dir string + logger *zap.Logger + tt test.TT + wk types.PrivateKey + wg sync.WaitGroup +} + +type dbConfig struct { + Database config.Database + DatabaseLog config.DatabaseLog + RetryTxIntervals []time.Duration } func (tc *TestCluster) ShutdownAutopilot(ctx context.Context) { @@ -161,15 +173,15 @@ type testClusterOptions struct { skipRunningAutopilot bool walletKey *types.PrivateKey - autopilotCfg *node.AutopilotConfig + autopilotCfg *config.Autopilot autopilotSettings *api.AutopilotConfig - busCfg *node.BusConfig + busCfg *config.Bus workerCfg *config.Worker } // newTestLogger creates a console logger used for testing. func newTestLogger() *zap.Logger { - return newTestLoggerCustom(zapcore.DebugLevel) + return newTestLoggerCustom(zapcore.WarnLevel) } // newTestLoggerCustom creates a console logger used for testing and allows @@ -190,8 +202,6 @@ func newTestLoggerCustom(level zapcore.Level) *zap.Logger { // newTestCluster creates a new cluster without hosts with a funded bus. func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { - t.Helper() - // Skip any test that requires a cluster when running short tests. if testing.Short() { t.SkipNow() @@ -215,7 +225,7 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { if opts.walletKey != nil { wk = *opts.walletKey } - busCfg, workerCfg, apCfg := testBusCfg(), testWorkerCfg(), testApCfg() + busCfg, workerCfg, apCfg, dbCfg := testBusCfg(), testWorkerCfg(), testApCfg(), testDBCfg() if opts.busCfg != nil { busCfg = *opts.busCfg } @@ -241,36 +251,28 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { if opts.autopilotSettings != nil { apSettings = *opts.autopilotSettings } - if busCfg.Logger == nil { - busCfg.Logger = logger - } if opts.dbName != "" { - busCfg.Database.MySQL.Database = opts.dbName + dbCfg.Database.MySQL.Database = opts.dbName } // Check if we are testing against an external database. If so, we create a // database with a random name first. - if mysql := config.MySQLConfigFromEnv(); mysql.URI != "" { + if mysqlCfg := config.MySQLConfigFromEnv(); mysqlCfg.URI != "" { // generate a random database name if none are set - if busCfg.Database.MySQL.Database == "" { - busCfg.Database.MySQL.Database = "db" + hex.EncodeToString(frand.Bytes(16)) + if dbCfg.Database.MySQL.Database == "" { + dbCfg.Database.MySQL.Database = "db" + hex.EncodeToString(frand.Bytes(16)) } - if busCfg.Database.MySQL.MetricsDatabase == "" { - busCfg.Database.MySQL.MetricsDatabase = "db" + hex.EncodeToString(frand.Bytes(16)) + if dbCfg.Database.MySQL.MetricsDatabase == "" { + dbCfg.Database.MySQL.MetricsDatabase = "db" + hex.EncodeToString(frand.Bytes(16)) } - tmpDB, err := gorm.Open(stores.NewMySQLConnection(mysql.User, mysql.Password, mysql.URI, "")) - tt.OK(err) - tt.OK(tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", busCfg.Database.MySQL.Database)).Error) - tt.OK(tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", busCfg.Database.MySQL.MetricsDatabase)).Error) - tmpDBB, err := tmpDB.DB() + tmpDB, err := mysql.Open(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, "") tt.OK(err) - tt.OK(tmpDBB.Close()) + tt.OKAll(tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", dbCfg.Database.MySQL.Database))) + tt.OKAll(tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", dbCfg.Database.MySQL.MetricsDatabase))) + tt.OK(tmpDB.Close()) } - // Prepare individual dirs. 
- busDir := filepath.Join(dir, "bus") - // Generate API passwords. busPassword := randomPassword() workerPassword := randomPassword() @@ -309,11 +311,9 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { }) tt.OK(err) - // Create miner. - busCfg.Miner = node.NewMiner(busClient) - // Create bus. - b, bSetupFn, bShutdownFn, err := node.NewBus(busCfg, busDir, wk, logger) + busDir := filepath.Join(dir, "bus") + b, bShutdownFn, cm, err := newTestBus(ctx, busDir, busCfg, dbCfg, wk, logger) tt.OK(err) busAuth := jape.BasicAuth(busPassword) @@ -322,7 +322,7 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { Handler: renterd.Handler(), // ui Sub: map[string]utils.TreeMux{ "/bus": { - Handler: busAuth(b), + Handler: busAuth(b.Handler()), }, }, }, @@ -333,46 +333,47 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { busShutdownFns = append(busShutdownFns, bShutdownFn) // Create worker. - w, s3Handler, wSetupFn, wShutdownFn, err := node.NewWorker(workerCfg, s3.Opts{}, busClient, wk, logger) + workerKey := blake2b.Sum256(append([]byte("worker"), wk...)) + w, err := worker.New(workerCfg, workerKey, busClient, logger) tt.OK(err) - workerServer := http.Server{ - Handler: iworker.Auth(workerPassword, false)(w), - } + workerServer := http.Server{Handler: utils.Auth(workerPassword, false)(w.Handler())} var workerShutdownFns []func(context.Context) error workerShutdownFns = append(workerShutdownFns, workerServer.Shutdown) - workerShutdownFns = append(workerShutdownFns, wShutdownFn) + workerShutdownFns = append(workerShutdownFns, w.Shutdown) // Create S3 API. - s3Server := http.Server{ - Handler: s3Handler, - } + s3Handler, err := s3.New(busClient, w, logger, s3.Opts{}) + tt.OK(err) + s3Server := http.Server{Handler: s3Handler} var s3ShutdownFns []func(context.Context) error s3ShutdownFns = append(s3ShutdownFns, s3Server.Shutdown) // Create autopilot. - ap, aStartFn, aStopFn, err := node.NewAutopilot(apCfg, busClient, []autopilot.Worker{workerClient}, logger) + ap, err := autopilot.New(apCfg, busClient, []autopilot.Worker{workerClient}, logger) tt.OK(err) autopilotAuth := jape.BasicAuth(autopilotPassword) autopilotServer := http.Server{ - Handler: autopilotAuth(ap), + Handler: autopilotAuth(ap.Handler()), } var autopilotShutdownFns []func(context.Context) error autopilotShutdownFns = append(autopilotShutdownFns, autopilotServer.Shutdown) - autopilotShutdownFns = append(autopilotShutdownFns, aStopFn) + autopilotShutdownFns = append(autopilotShutdownFns, ap.Shutdown) + network, genesis := testNetwork() cluster := &TestCluster{ - apID: apCfg.ID, - dir: dir, - dbName: busCfg.Database.MySQL.Database, - logger: logger, - network: busCfg.Network, - miner: busCfg.Miner, - tt: tt, - wk: wk, + apID: apCfg.ID, + dir: dir, + dbName: dbCfg.Database.MySQL.Database, + logger: logger, + network: network, + genesisBlock: genesis, + cm: cm, + tt: tt, + wk: wk, Autopilot: autopilotClient, Bus: busClient, @@ -410,18 +411,13 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { if !opts.skipRunningAutopilot { cluster.wg.Add(1) go func() { - _ = aStartFn() + ap.Run() cluster.wg.Done() }() } - // Finish bus setup. - if err := bSetupFn(ctx); err != nil { - tt.Fatalf("failed to setup bus, err: %v", err) - } - // Finish worker setup. 
- if err := wSetupFn(ctx, workerAddr, workerPassword); err != nil { + if err := w.Setup(ctx, workerAddr, workerPassword); err != nil { tt.Fatalf("failed to setup worker, err: %v", err) } @@ -447,35 +443,32 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { })) tt.OK(busClient.UpdateSetting(ctx, api.SettingUploadPacking, api.UploadPackingSettings{ Enabled: enableUploadPacking, - SlabBufferMaxSizeSoft: build.DefaultUploadPackingSettings.SlabBufferMaxSizeSoft, + SlabBufferMaxSizeSoft: api.DefaultUploadPackingSettings.SlabBufferMaxSizeSoft, })) // Fund the bus. if funding { - cluster.MineBlocks(latestHardforkHeight) - tt.Retry(1000, 100*time.Millisecond, func() error { - resp, err := busClient.ConsensusState(ctx) - if err != nil { + cluster.MineBlocks(network.HardforkFoundation.Height + blocksPerDay) // mine until the first block reward matures + tt.Retry(100, 100*time.Millisecond, func() error { + if cs, err := busClient.ConsensusState(ctx); err != nil { return err + } else if !cs.Synced { + return fmt.Errorf("chain not synced: %v", cs.Synced) } - if !resp.Synced || resp.BlockHeight < latestHardforkHeight { - return fmt.Errorf("chain not synced: %v %v", resp.Synced, resp.BlockHeight < latestHardforkHeight) - } - res, err := cluster.Bus.Wallet(ctx) - if err != nil { + if res, err := cluster.Bus.Wallet(ctx); err != nil { return err + } else if res.Confirmed.IsZero() { + return fmt.Errorf("wallet not funded: %+v", res) + } else { + return nil } - - if res.Confirmed.IsZero() { - tt.Fatal("wallet not funded") - } - return nil }) } if nHosts > 0 { cluster.AddHostsBlocking(nHosts) + cluster.WaitForPeers() cluster.WaitForContracts() cluster.WaitForContractSet(test.ContractSet, nHosts) cluster.WaitForAccounts() @@ -491,6 +484,114 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { return cluster } +func newTestBus(ctx context.Context, dir string, cfg config.Bus, cfgDb dbConfig, pk types.PrivateKey, logger *zap.Logger) (*bus.Bus, func(ctx context.Context) error, *chain.Manager, error) { + // create store + alertsMgr := alerts.NewManager() + storeCfg, err := buildStoreConfig(alertsMgr, dir, cfg.SlabBufferCompletionThreshold, cfgDb, pk, logger) + if err != nil { + return nil, nil, nil, err + } + + sqlStore, err := stores.NewSQLStore(storeCfg) + if err != nil { + return nil, nil, nil, err + } + + // create webhooks manager + wh, err := webhooks.NewManager(sqlStore, logger) + if err != nil { + return nil, nil, nil, err + } + + // hookup webhooks <-> alerts + alertsMgr.RegisterWebhookBroadcaster(wh) + + // create consensus directory + consensusDir := filepath.Join(dir, "consensus") + if err := os.MkdirAll(consensusDir, 0700); err != nil { + return nil, nil, nil, err + } + + // create chain database + chainPath := filepath.Join(consensusDir, "blockchain.db") + bdb, err := coreutils.OpenBoltChainDB(chainPath) + if err != nil { + return nil, nil, nil, err + } + + // create chain manager + network, genesis := testNetwork() + store, state, err := chain.NewDBStore(bdb, network, genesis) + if err != nil { + return nil, nil, nil, err + } + cm := chain.NewManager(store, state) + + // create wallet + w, err := wallet.NewSingleAddressWallet(pk, cm, sqlStore, wallet.WithReservationDuration(cfg.UsedUTXOExpiry)) + if err != nil { + return nil, nil, nil, err + } + + // create syncer, peers will reject us if our hostname is empty or + // unspecified, so use loopback + l, err := net.Listen("tcp", cfg.GatewayAddr) + if err != nil { + return nil, nil, nil, err + } + 
syncerAddr := l.Addr().String() + host, port, _ := net.SplitHostPort(syncerAddr) + if ip := net.ParseIP(host); ip == nil || ip.IsUnspecified() { + syncerAddr = net.JoinHostPort("127.0.0.1", port) + } + + // create header + header := gateway.Header{ + GenesisID: genesis.ID(), + UniqueID: gateway.GenerateUniqueID(), + NetAddress: syncerAddr, + } + + // create the syncer + s := syncer.New(l, cm, sqlStore, header, syncer.WithLogger(logger.Named("syncer")), syncer.WithSendBlocksTimeout(time.Minute)) + + // start syncer + errChan := make(chan error, 1) + go func() { + errChan <- s.Run(context.Background()) + close(errChan) + }() + + // create a helper function to wait for syncer to wind down on shutdown + syncerShutdown := func(ctx context.Context) error { + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return context.Cause(ctx) + } + } + + // create bus + announcementMaxAgeHours := time.Duration(cfg.AnnouncementMaxAgeHours) * time.Hour + b, err := bus.New(ctx, alertsMgr, wh, cm, s, w, sqlStore, announcementMaxAgeHours, logger) + if err != nil { + return nil, nil, nil, err + } + + shutdownFn := func(ctx context.Context) error { + return errors.Join( + s.Close(), + w.Close(), + b.Shutdown(ctx), + sqlStore.Close(), + bdb.Close(), + syncerShutdown(ctx), + ) + } + return b, shutdownFn, cm, nil +} + // addStorageFolderToHosts adds a single storage folder to each host. func addStorageFolderToHost(ctx context.Context, hosts []*Host) error { for _, host := range hosts { @@ -536,68 +637,55 @@ func (c *TestCluster) MineToRenewWindow() { if cs.BlockHeight >= renewWindowStart { c.tt.Fatalf("already in renew window: bh: %v, currentPeriod: %v, periodLength: %v, renewWindow: %v", cs.BlockHeight, ap.CurrentPeriod, ap.Config.Contracts.Period, renewWindowStart) } - c.MineBlocks(int(renewWindowStart - cs.BlockHeight)) - c.Sync() -} - -// sync blocks until the cluster is synced. -func (c *TestCluster) sync(hosts []*Host) { - c.tt.Helper() - c.tt.Retry(100, 100*time.Millisecond, func() error { - synced, err := c.synced(hosts) - if err != nil { - return err - } - if !synced { - return errors.New("cluster was unable to sync in time") - } - return nil - }) + c.MineBlocks(renewWindowStart - cs.BlockHeight) } -// synced returns true if bus and hosts are at the same blockheight. -func (c *TestCluster) synced(hosts []*Host) (bool, error) { - c.tt.Helper() - cs, err := c.Bus.ConsensusState(context.Background()) - if err != nil { - return false, err - } - if !cs.Synced { - return false, nil // can't be synced if bus itself isn't synced - } - for _, h := range hosts { - bh := h.cs.Height() - if cs.BlockHeight != uint64(bh) { - return false, nil - } - } - return true, nil -} - -// MineBlocks uses the bus' miner to mine n blocks. -func (c *TestCluster) MineBlocks(n int) { +// MineBlocks mines n blocks +func (c *TestCluster) MineBlocks(n uint64) { c.tt.Helper() wallet, err := c.Bus.Wallet(context.Background()) c.tt.OK(err) // If we don't have any hosts in the cluster mine all blocks right away. if len(c.hosts) == 0 { - c.tt.OK(c.miner.Mine(wallet.Address, n)) - c.Sync() + c.tt.OK(c.mineBlocks(wallet.Address, n)) + c.sync() return } - // Otherwise mine blocks in batches of 3 to avoid going out of sync with - // hosts by too many blocks. - for mined := 0; mined < n; { + // Otherwise mine blocks in batches of 10 blocks to avoid going out of sync + // with hosts by too many blocks. 
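Concretely, MineBlocks(25) runs three batches of 10, 10 and 5 blocks, calling sync() after each batch, so no host ever trails the bus by more than ten blocks: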
+ for mined := uint64(0); mined < n; { toMine := n - mined if toMine > 10 { toMine = 10 } - c.tt.OK(c.miner.Mine(wallet.Address, toMine)) - c.Sync() + c.tt.OK(c.mineBlocks(wallet.Address, toMine)) mined += toMine + c.sync() } + c.sync() +} + +func (c *TestCluster) sync() { + tip := c.cm.Tip() + c.tt.Retry(300, 100*time.Millisecond, func() error { + cs, err := c.Bus.ConsensusState(context.Background()) + if err != nil { + return err + } else if !cs.Synced { + return errors.New("bus is not synced") + } else if cs.BlockHeight < tip.Height { + return fmt.Errorf("subscriber hasn't caught up, %d < %d", cs.BlockHeight, tip.Height) + } + + for _, h := range c.hosts { + if hh := h.cm.Tip().Height; hh < tip.Height { + return fmt.Errorf("host %v is not synced, %v < %v", h.PublicKey(), hh, cs.BlockHeight) + } + } + return nil + }) } func (c *TestCluster) WaitForAccounts() []api.Account { @@ -620,6 +708,7 @@ func (c *TestCluster) WaitForAccounts() []api.Account { func (c *TestCluster) WaitForContracts() []api.Contract { c.tt.Helper() + // build hosts map hostsMap := make(map[types.PublicKey]struct{}) for _, host := range c.hosts { @@ -680,6 +769,19 @@ func (c *TestCluster) WaitForContractSetContracts(set string, n int) { }) } +func (c *TestCluster) WaitForPeers() { + c.tt.Helper() + c.tt.Retry(300, 100*time.Millisecond, func() error { + peers, err := c.Bus.SyncerPeers(context.Background()) + if err != nil { + return err + } else if len(peers) == 0 { + return errors.New("no peers found") + } + return nil + }) +} + func (c *TestCluster) RemoveHost(host *Host) { c.tt.Helper() c.tt.OK(host.Close()) @@ -696,11 +798,11 @@ func (c *TestCluster) NewHost() *Host { c.tt.Helper() // Create host. hostDir := filepath.Join(c.dir, "hosts", fmt.Sprint(len(c.hosts)+1)) - h, err := NewHost(types.GeneratePrivateKey(), hostDir, c.network, false) + h, err := NewHost(types.GeneratePrivateKey(), hostDir, c.network, c.genesisBlock) c.tt.OK(err) // Connect gateways. - c.tt.OK(c.Bus.SyncerConnect(context.Background(), h.GatewayAddr())) + c.tt.OK(c.Bus.SyncerConnect(context.Background(), h.SyncerAddr())) return h } @@ -710,43 +812,41 @@ func (c *TestCluster) AddHost(h *Host) { c.hosts = append(c.hosts, h) // Fund host from bus. - fundAmt := types.Siacoins(100e3) - var scos []types.SiacoinOutput - for i := 0; i < 10; i++ { - scos = append(scos, types.SiacoinOutput{ - Value: fundAmt, - Address: h.WalletAddress(), - }) + fundAmt := types.Siacoins(5e3) + for i := 0; i < 5; i++ { + c.tt.OKAll(c.Bus.SendSiacoins(context.Background(), h.WalletAddress(), fundAmt, true)) } - c.tt.OK(c.Bus.SendSiacoins(context.Background(), scos, false)) // Mine transaction. c.MineBlocks(1) - // Wait for hosts to sync up with consensus. - hosts := []*Host{h} - c.sync(hosts) + // Wait for host's wallet to be funded + c.tt.Retry(100, 100*time.Millisecond, func() error { + balance, err := h.wallet.Balance() + c.tt.OK(err) + if balance.Confirmed.IsZero() { + return errors.New("host wallet not funded") + } + return nil + }) // Announce hosts. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c.tt.OK(addStorageFolderToHost(ctx, hosts)) - c.tt.OK(announceHosts(hosts)) + c.tt.OK(addStorageFolderToHost(ctx, []*Host{h})) + c.tt.OK(announceHosts([]*Host{h})) - // Mine a few blocks. The host should show up eventually. - c.tt.Retry(10, time.Second, func() error { + // Mine until the host shows up. 
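The polling below relies on the Retry helper from internal/test, which this diff does not show; its assumed semantics are sketched here (call fn up to tries times, sleeping interval between attempts, failing the test with the last error):

// a sketch of the assumed behaviour of tt.Retry, not its actual definition:
func (tt *TT) Retry(tries int, interval time.Duration, fn func() error) {
	tt.Helper()
	var err error
	for i := 0; i < tries; i++ {
		if err = fn(); err == nil {
			return
		}
		time.Sleep(interval)
	}
	tt.Fatal(err)
}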
+ c.tt.Retry(100, 100*time.Millisecond, func() error {
 c.tt.Helper()
- c.MineBlocks(1)
 _, err := c.Bus.Host(context.Background(), h.PublicKey())
 if err != nil {
+ c.MineBlocks(1)
 return err
 }
 return nil
 })
-
- // Wait for host to be synced.
- c.Sync()
 }
 
 // AddHosts adds n hosts to the cluster. These hosts will be funded and announce
@@ -793,12 +893,6 @@ func (c *TestCluster) Shutdown() {
 c.wg.Wait()
 }
 
-// Sync blocks until the whole cluster has reached the same block height.
-func (c *TestCluster) Sync() {
- c.tt.Helper()
- c.sync(c.hosts)
-}
-
 // waitForHostAccounts will fetch the accounts from the worker and wait until
 // they have money in them
 func (c *TestCluster) waitForHostAccounts(hosts map[types.PublicKey]struct{}) {
@@ -866,52 +960,54 @@ func (c *TestCluster) waitForHostContracts(hosts map[types.PublicKey]struct{}) {
 })
 }
 
-// testNetwork returns a custom network for testing which matches the
-// configuration of siad consensus in testing.
-func testNetwork() *consensus.Network {
- n := &consensus.Network{
- InitialCoinbase: types.Siacoins(300000),
- MinimumCoinbase: types.Siacoins(299990),
- InitialTarget: types.BlockID{4: 32},
+func (c *TestCluster) mineBlocks(addr types.Address, n uint64) error {
+ for i := uint64(0); i < n; i++ {
+ if block, found := coreutils.MineBlock(c.cm, addr, 5*time.Second); !found {
+ c.tt.Fatal("failed to mine block")
+ } else if err := c.Bus.AcceptBlock(context.Background(), block); err != nil {
+ return err
+ }
 }
+ return nil
+}
 
- n.HardforkDevAddr.Height = 3
- n.HardforkDevAddr.OldAddress = types.Address{}
- n.HardforkDevAddr.NewAddress = types.Address{}
-
- n.HardforkTax.Height = 10
-
- n.HardforkStorageProof.Height = 10
-
- n.HardforkOak.Height = 20
- n.HardforkOak.FixHeight = 23
- n.HardforkOak.GenesisTimestamp = time.Now().Add(-1e6 * time.Second)
-
- n.HardforkASIC.Height = 5
- n.HardforkASIC.OakTime = 10000 * time.Second
- n.HardforkASIC.OakTarget = types.BlockID{255, 255}
-
- n.HardforkFoundation.Height = 50
- n.HardforkFoundation.PrimaryAddress = types.StandardUnlockHash(types.GeneratePrivateKey().PublicKey())
- n.HardforkFoundation.FailsafeAddress = types.StandardUnlockHash(types.GeneratePrivateKey().PublicKey())
-
- // make it difficult to reach v2 in most tests
+// testNetwork returns a modified version of Zen used for testing
+func testNetwork() (*consensus.Network, types.Block) {
+ // use a modified version of Zen
+ n, genesis := chain.TestnetZen()
+
+ // we have to set the initial target to 128 to ensure blocks we mine match
+ // the PoW testnet in siad testnet consensus
+ n.InitialTarget = types.BlockID{0x80}
+
+ // we have to make minimum coinbase get hit after 10 blocks to ensure we
+ // match the siad test network settings; otherwise the block subsidy is
+ // considered invalid after 10 blocks
+ n.MinimumCoinbase = types.Siacoins(299990)
+ n.HardforkDevAddr.Height = 1
+ n.HardforkTax.Height = 1
+ n.HardforkStorageProof.Height = 1
+ n.HardforkOak.Height = 1
+ n.HardforkASIC.Height = 1
+ n.HardforkFoundation.Height = 1
 n.HardforkV2.AllowHeight = 1000
 n.HardforkV2.RequireHeight = 1020
 
- return n
+ return n, genesis
}
 
-func testBusCfg() node.BusConfig {
- return node.BusConfig{
- Bus: config.Bus{
- AnnouncementMaxAgeHours: 24 * 7 * 52, // 1 year
- Bootstrap: false,
- GatewayAddr: "127.0.0.1:0",
- PersistInterval: testBusPersistInterval,
- UsedUTXOExpiry: time.Minute,
- SlabBufferCompletionThreshold: 0,
- },
+func testBusCfg() config.Bus {
+ return config.Bus{
+ AnnouncementMaxAgeHours: 24 * 7 * 52, // 1 year
+ Bootstrap: false,
+ 
GatewayAddr: "127.0.0.1:0", + UsedUTXOExpiry: time.Minute, + SlabBufferCompletionThreshold: 0, + } +} + +func testDBCfg() dbConfig { + return dbConfig{ Database: config.Database{ MySQL: config.MySQLConfigFromEnv(), }, @@ -920,7 +1016,14 @@ func testBusCfg() node.BusConfig { IgnoreRecordNotFoundError: true, SlowThreshold: 100 * time.Millisecond, }, - Network: testNetwork(), + RetryTxIntervals: []time.Duration{ + 50 * time.Millisecond, + 100 * time.Millisecond, + 200 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 5 * time.Second, + }, } } @@ -938,18 +1041,91 @@ func testWorkerCfg() config.Worker { } } -func testApCfg() node.AutopilotConfig { - return node.AutopilotConfig{ - ID: api.DefaultAutopilotID, - Autopilot: config.Autopilot{ - AccountsRefillInterval: time.Second, - Heartbeat: time.Second, - MigrationHealthCutoff: 0.99, - MigratorParallelSlabsPerWorker: 1, - RevisionSubmissionBuffer: 0, - ScannerInterval: time.Second, - ScannerBatchSize: 10, - ScannerNumThreads: 1, - }, +func testApCfg() config.Autopilot { + return config.Autopilot{ + AccountsRefillInterval: time.Second, + Heartbeat: time.Second, + ID: api.DefaultAutopilotID, + MigrationHealthCutoff: 0.99, + MigratorParallelSlabsPerWorker: 1, + RevisionSubmissionBuffer: 0, + ScannerInterval: time.Second, + ScannerBatchSize: 10, + ScannerNumThreads: 1, + } +} + +func buildStoreConfig(am alerts.Alerter, dir string, slabBufferCompletionThreshold int64, cfg dbConfig, pk types.PrivateKey, logger *zap.Logger) (stores.Config, error) { + // create database connections + var dbMain sql.Database + var dbMetrics sql.MetricsDatabase + if cfg.Database.MySQL.URI != "" { + // create MySQL connections + connMain, err := mysql.Open( + cfg.Database.MySQL.User, + cfg.Database.MySQL.Password, + cfg.Database.MySQL.URI, + cfg.Database.MySQL.Database, + ) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to open MySQL main database: %w", err) + } + connMetrics, err := mysql.Open( + cfg.Database.MySQL.User, + cfg.Database.MySQL.Password, + cfg.Database.MySQL.URI, + cfg.Database.MySQL.MetricsDatabase, + ) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to open MySQL metrics database: %w", err) + } + dbMain, err = mysql.NewMainDatabase(connMain, logger, cfg.DatabaseLog.SlowThreshold, cfg.DatabaseLog.SlowThreshold) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to create MySQL main database: %w", err) + } + dbMetrics, err = mysql.NewMetricsDatabase(connMetrics, logger, cfg.DatabaseLog.SlowThreshold, cfg.DatabaseLog.SlowThreshold) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to create MySQL metrics database: %w", err) + } + } else { + // create database directory + dbDir := filepath.Join(dir, "db") + if err := os.MkdirAll(dbDir, 0700); err != nil { + return stores.Config{}, err + } + + // create SQLite connections + db, err := sqlite.Open(filepath.Join(dbDir, "db.sqlite")) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to open SQLite main database: %w", err) + } + dbMain, err = sqlite.NewMainDatabase(db, logger, cfg.DatabaseLog.SlowThreshold, cfg.DatabaseLog.SlowThreshold) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to create SQLite main database: %w", err) + } + + dbm, err := sqlite.Open(filepath.Join(dbDir, "metrics.sqlite")) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to open SQLite metrics database: %w", err) + } + dbMetrics, err = sqlite.NewMetricsDatabase(dbm, logger, cfg.DatabaseLog.SlowThreshold, 
cfg.DatabaseLog.SlowThreshold) + if err != nil { + return stores.Config{}, fmt.Errorf("failed to create SQLite metrics database: %w", err) + } } + + return stores.Config{ + Alerts: alerts.WithOrigin(am, "bus"), + DB: dbMain, + DBMetrics: dbMetrics, + PartialSlabDir: filepath.Join(dir, "partial_slabs"), + Migrate: true, + SlabBufferCompletionThreshold: slabBufferCompletionThreshold, + Logger: logger, + WalletAddress: types.StandardUnlockHash(pk.PublicKey()), + + RetryTransactionIntervals: cfg.RetryTxIntervals, + LongQueryDuration: cfg.DatabaseLog.SlowThreshold, + LongTxDuration: cfg.DatabaseLog.SlowThreshold, + }, nil } diff --git a/internal/test/e2e/cluster_test.go b/internal/test/e2e/cluster_test.go index 049e54d7e..f9ba9e018 100644 --- a/internal/test/e2e/cluster_test.go +++ b/internal/test/e2e/cluster_test.go @@ -21,12 +21,13 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/autopilot/contractor" "go.sia.tech/renterd/internal/test" "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" - "go.sia.tech/renterd/wallet" "go.uber.org/zap" "lukechampine.com/frand" ) @@ -183,6 +184,26 @@ func TestNewTestCluster(t *testing.T) { t.Fatalf("expected upload packing to be disabled by default, got %v", ups.Enabled) } + // PricePinningSettings should have default values + pps, err := b.PricePinningSettings(context.Background()) + tt.OK(err) + if pps.ForexEndpointURL == "" { + t.Fatal("expected default value for ForexEndpointURL") + } else if pps.Currency == "" { + t.Fatal("expected default value for Currency") + } else if pps.Threshold == 0 { + t.Fatal("expected default value for Threshold") + } + + // Autopilot shouldn't have its prices pinned + if len(pps.Autopilots) != 1 { + t.Fatalf("expected 1 autopilot, got %v", len(pps.Autopilots)) + } else if pin, exists := pps.Autopilots[api.DefaultAutopilotID]; !exists { + t.Fatalf("expected autopilot %v to exist", api.DefaultAutopilotID) + } else if pin.Allowance != (api.Pin{}) { + t.Fatalf("expected autopilot %v to have no pinned allowance, got %v", api.DefaultAutopilotID, pin.Allowance) + } + // See if autopilot is running by triggering the loop. _, err = cluster.Autopilot.Trigger(false) tt.OK(err) @@ -211,13 +232,14 @@ func TestNewTestCluster(t *testing.T) { cluster.MineToRenewWindow() // Wait for the contract to be renewed. + var renewalID types.FileContractID tt.Retry(100, 100*time.Millisecond, func() error { contracts, err := cluster.Bus.Contracts(context.Background(), api.ContractsOpts{}) if err != nil { return err } if len(contracts) != 1 { - return errors.New("no renewed contract") + return fmt.Errorf("unexpected number of contracts %d != 1", len(contracts)) } if contracts[0].RenewedFrom != contract.ID { return fmt.Errorf("contract wasn't renewed %v != %v", contracts[0].RenewedFrom, contract.ID) @@ -231,6 +253,7 @@ func TestNewTestCluster(t *testing.T) { if contracts[0].State != api.ContractStatePending { return fmt.Errorf("contract should be pending but was %v", contracts[0].State) } + renewalID = contracts[0].ID return nil }) @@ -238,8 +261,7 @@ func TestNewTestCluster(t *testing.T) { // revision first. 
cs, err := cluster.Bus.ConsensusState(context.Background()) tt.OK(err) - cluster.MineBlocks(int(contract.WindowStart - cs.BlockHeight - 4)) - cluster.Sync() + cluster.MineBlocks(contract.WindowStart - cs.BlockHeight - 4) if cs.LastBlockTime.IsZero() { t.Fatal("last block time not set") } @@ -247,14 +269,7 @@ func TestNewTestCluster(t *testing.T) { // Now wait for the revision and proof to be caught by the hostdb. var ac api.ArchivedContract tt.Retry(20, time.Second, func() error { - cluster.MineBlocks(1) - - // Fetch renewed contract and make sure we caught the proof and revision. - contracts, err := cluster.Bus.Contracts(context.Background(), api.ContractsOpts{}) - if err != nil { - t.Fatal(err) - } - archivedContracts, err := cluster.Bus.AncestorContracts(context.Background(), contracts[0].ID, 0) + archivedContracts, err := cluster.Bus.AncestorContracts(context.Background(), renewalID, 0) if err != nil { t.Fatal(err) } @@ -263,7 +278,7 @@ func TestNewTestCluster(t *testing.T) { } ac = archivedContracts[0] if ac.RevisionHeight == 0 || ac.RevisionNumber != math.MaxUint64 { - return fmt.Errorf("revision information is wrong: %v %v", ac.RevisionHeight, ac.RevisionNumber) + return fmt.Errorf("revision information is wrong: %v %v %v", ac.RevisionHeight, ac.RevisionNumber, ac.ID) } if ac.ProofHeight != 0 { t.Fatal("proof height should be 0 since the contract was renewed and therefore doesn't require a proof") @@ -271,13 +286,6 @@ func TestNewTestCluster(t *testing.T) { if ac.State != api.ContractStateComplete { return fmt.Errorf("contract should be complete but was %v", ac.State) } - archivedContracts, err = cluster.Bus.AncestorContracts(context.Background(), contracts[0].ID, math.MaxUint32) - if err != nil { - t.Fatal(err) - } - if len(archivedContracts) != 0 { - return fmt.Errorf("should have 0 archived contracts but got %v", len(archivedContracts)) - } return nil }) @@ -722,7 +730,7 @@ func TestUploadDownloadBasic(t *testing.T) { t.Fatal("unexpected", len(data), buffer.Len()) } - // download again, 32 bytes at a time. + // download again, 32 bytes at a time for i := int64(0); i < 4; i++ { offset := i * 32 var buffer bytes.Buffer @@ -734,48 +742,17 @@ func TestUploadDownloadBasic(t *testing.T) { } } - // fetch the contracts. - contracts, err := cluster.Bus.Contracts(context.Background(), api.ContractsOpts{}) - tt.OK(err) - - // broadcast the revision for each contract and assert the revision height - // is 0. - for _, c := range contracts { - if c.RevisionHeight != 0 { - t.Fatal("revision height should be 0") - } - tt.OK(w.RHPBroadcast(context.Background(), c.ID)) - } - - // mine a block to get the revisions mined. - cluster.MineBlocks(1) - - // check the revision height and size were updated. + // check that stored data on hosts was updated tt.Retry(100, 100*time.Millisecond, func() error { - // fetch the contracts. - contracts, err := cluster.Bus.Contracts(context.Background(), api.ContractsOpts{}) - if err != nil { - return err - } - // assert the revision height was updated. 
- for _, c := range contracts { - if c.RevisionHeight == 0 { - return errors.New("revision height should be > 0") - } else if c.Size != rhpv2.SectorSize { - return fmt.Errorf("size should be %v, got %v", rhpv2.SectorSize, c.Size) + hosts, err := cluster.Bus.Hosts(context.Background(), api.GetHostsOptions{}) + tt.OK(err) + for _, host := range hosts { + if host.StoredData != rhpv2.SectorSize { + return fmt.Errorf("stored data should be %v, got %v", rhpv2.SectorSize, host.StoredData) } } return nil }) - - // Check that stored data on hosts was updated - hosts, err := cluster.Bus.Hosts(context.Background(), api.GetHostsOptions{}) - tt.OK(err) - for _, host := range hosts { - if host.StoredData != rhpv2.SectorSize { - t.Fatalf("stored data should be %v, got %v", rhpv2.SectorSize, host.StoredData) - } - } } // TestUploadDownloadExtended is an integration test that verifies objects can @@ -934,7 +911,8 @@ func TestUploadDownloadSpending(t *testing.T) { // create a test cluster cluster := newTestCluster(t, testClusterOptions{ - hosts: test.RedundancySettings.TotalShards, + hosts: test.RedundancySettings.TotalShards, + logger: zap.NewNop(), }) defer cluster.Shutdown() @@ -946,8 +924,8 @@ func TestUploadDownloadSpending(t *testing.T) { tt.Retry(100, testBusFlushInterval, func() error { cms, err := cluster.Bus.Contracts(context.Background(), api.ContractsOpts{}) tt.OK(err) - if len(cms) == 0 { - t.Fatal("no contracts found") + if len(cms) != test.RedundancySettings.TotalShards { + t.Fatalf("unexpected number of contracts %v", len(cms)) } nFunded := 0 @@ -1099,77 +1077,118 @@ func TestUploadDownloadSpending(t *testing.T) { tt.OK(err) } -// TestEphemeralAccounts tests the use of ephemeral accounts. -func TestEphemeralAccounts(t *testing.T) { +func TestContractApplyChainUpdates(t *testing.T) { if testing.Short() { t.SkipNow() } - dir := t.TempDir() - cluster := newTestCluster(t, testClusterOptions{ - dir: dir, - logger: zap.NewNop(), - }) + // create a test cluster without autopilot + cluster := newTestCluster(t, testClusterOptions{skipRunningAutopilot: true}) defer cluster.Shutdown() - tt := cluster.tt - // add host - nodes := cluster.AddHosts(1) - host := nodes[0] + // convenience variables + w := cluster.Worker + b := cluster.Bus + tt := cluster.tt - // make the cost of fetching a revision 0. That allows us to check for exact - // balances when funding the account and avoid NDFs. - settings := host.settings.Settings() - settings.BaseRPCPrice = types.ZeroCurrency - settings.EgressPrice = types.ZeroCurrency - if err := host.settings.UpdateSettings(settings); err != nil { - t.Fatal(err) - } + // add a host + hosts := cluster.AddHosts(1) + h, err := b.Host(context.Background(), hosts[0].PublicKey()) + tt.OK(err) - // Wait for contracts to form. - var contract api.Contract - contracts := cluster.WaitForContracts() - contract = contracts[0] + // manually form a contract with the host + cs, _ := b.ConsensusState(context.Background()) + wallet, _ := b.Wallet(context.Background()) + rev, _, err := w.RHPForm(context.Background(), cs.BlockHeight+test.AutopilotConfig.Contracts.Period+test.AutopilotConfig.Contracts.RenewWindow, h.PublicKey, h.NetAddress, wallet.Address, types.Siacoins(1), types.Siacoins(1)) + tt.OK(err) + contract, err := b.AddContract(context.Background(), rev, rev.Revision.MissedHostPayout().Sub(types.Siacoins(1)), types.Siacoins(1), cs.BlockHeight, api.ContractStatePending) + tt.OK(err) - // Wait for account to appear. 
- accounts := cluster.WaitForAccounts() + // assert revision height is 0 + if contract.RevisionHeight != 0 { + t.Fatalf("expected revision height to be 0, got %v", contract.RevisionHeight) + } - // Shut down the autopilot to prevent it from interfering with the test. - cluster.ShutdownAutopilot(context.Background()) + // broadcast the revision for each contract + fcid := contract.ID + tt.OK(w.RHPBroadcast(context.Background(), fcid)) + cluster.MineBlocks(1) - // Newly created accounts are !cleanShutdown. Simulate a sync to change - // that. - for _, acc := range accounts { - if acc.CleanShutdown { - t.Fatal("new account should indicate an unclean shutdown") - } else if acc.RequiresSync { - t.Fatal("new account should not require a sync") - } - if err := cluster.Bus.SetBalance(context.Background(), acc.ID, acc.HostKey, acc.Balance); err != nil { - t.Fatal(err) + // check the revision height was updated. + tt.Retry(100, 100*time.Millisecond, func() error { + c, err := cluster.Bus.Contract(context.Background(), fcid) + tt.OK(err) + if c.RevisionHeight == 0 { + return fmt.Errorf("contract %v should have been revised", c.ID) } + return nil + }) +} + +// TestEphemeralAccounts tests the use of ephemeral accounts. +func TestEphemeralAccounts(t *testing.T) { + if testing.Short() { + t.SkipNow() } - // Fetch accounts again. + // run without autopilot + opts := clusterOptsDefault + opts.skipRunningAutopilot = true + + // create cluster + cluster := newTestCluster(t, opts) + defer cluster.Shutdown() + + // convenience variables + b := cluster.Bus + w := cluster.Worker + tt := cluster.tt + + tt.OK(b.UpdateSetting(context.Background(), api.SettingRedundancy, api.RedundancySettings{ + MinShards: 1, + TotalShards: 1, + })) + // add a host + hosts := cluster.AddHosts(1) + h, err := b.Host(context.Background(), hosts[0].PublicKey()) + tt.OK(err) + + // scan the host + tt.OKAll(w.RHPScan(context.Background(), h.PublicKey, h.NetAddress, 10*time.Second)) + + // manually form a contract with the host + cs, _ := b.ConsensusState(context.Background()) + wallet, _ := b.Wallet(context.Background()) + rev, _, err := w.RHPForm(context.Background(), cs.BlockHeight+test.AutopilotConfig.Contracts.Period+test.AutopilotConfig.Contracts.RenewWindow, h.PublicKey, h.NetAddress, wallet.Address, types.Siacoins(10), types.Siacoins(1)) + tt.OK(err) + c, err := b.AddContract(context.Background(), rev, rev.Revision.MissedHostPayout().Sub(types.Siacoins(1)), types.Siacoins(1), cs.BlockHeight, api.ContractStatePending) + tt.OK(err) + + tt.OK(b.SetContractSet(context.Background(), test.ContractSet, []types.FileContractID{c.ID})) + + // fund the account + fundAmt := types.Siacoins(1) + tt.OK(w.RHPFund(context.Background(), c.ID, c.HostKey, c.HostIP, c.SiamuxAddr, fundAmt)) + + // fetch accounts accounts, err := cluster.Bus.Accounts(context.Background()) tt.OK(err) + // assert account state acc := accounts[0] - minExpectedBalance := types.Siacoins(1).Sub(types.NewCurrency64(1)) - if acc.Balance.Cmp(minExpectedBalance.Big()) < 0 { - t.Fatalf("wrong balance %v", acc.Balance) - } if acc.ID == (rhpv3.Account{}) { t.Fatal("account id not set") - } - if acc.HostKey != types.PublicKey(host.PublicKey()) { + } else if acc.CleanShutdown { + t.Fatal("account should indicate an unclean shutdown") + } else if !acc.RequiresSync { + t.Fatal("account should require a sync") + } else if acc.HostKey != h.PublicKey { t.Fatal("wrong host") - } - if !acc.CleanShutdown { - t.Fatal("account should indicate a clean shutdown") + } else if 
acc.Balance.Cmp(types.Siacoins(1).Big()) != 0 {
+ t.Fatalf("wrong balance %v", acc.Balance)
 }
 
- // Fetch account from bus directly.
+ // fetch account from bus directly
 busAccounts, err := cluster.Bus.Accounts(context.Background())
 tt.OK(err)
 if len(busAccounts) != 1 {
@@ -1180,12 +1199,11 @@ func TestEphemeralAccounts(t *testing.T) {
 t.Fatal("bus account doesn't match worker account")
 }
 
- // Check that the spending was recorded for the contract. The recorded
+ // check that the spending was recorded for the contract. The recorded
 // spending should be > the fundAmt since it consists of the fundAmt plus
 // fee.
- fundAmt := types.Siacoins(1)
 tt.Retry(10, testBusFlushInterval, func() error {
- cm, err := cluster.Bus.Contract(context.Background(), contract.ID)
+ cm, err := cluster.Bus.Contract(context.Background(), c.ID)
 tt.OK(err)
 
 if cm.Spending.FundAccount.Cmp(fundAmt) <= 0 {
@@ -1194,7 +1212,24 @@ func TestEphemeralAccounts(t *testing.T) {
 return nil
 })
 
- // Update the balance to create some drift.
+ // sync the account
+ tt.OK(w.RHPSync(context.Background(), c.ID, acc.HostKey, c.HostIP, c.SiamuxAddr))
+
+ // fetch the accounts again
+ accounts, err = cluster.Bus.Accounts(context.Background())
+ tt.OK(err)
+
+ // assert account state
+ acc = accounts[0]
+ if !acc.CleanShutdown {
+ t.Fatal("account should indicate a clean shutdown")
+ } else if acc.RequiresSync {
+ t.Fatal("account should not require a sync")
+ } else if acc.Drift.Cmp(new(big.Int)) != 0 {
+ t.Fatalf("account should not have drift %v", acc.Drift)
+ }
+
+ // update the balance to create some drift
 newBalance := fundAmt.Div64(2)
 newDrift := new(big.Int).Sub(newBalance.Big(), fundAmt.Big())
 if err := cluster.Bus.SetBalance(context.Background(), busAcc.ID, acc.HostKey, newBalance.Big()); err != nil {
@@ -1208,11 +1243,11 @@
 t.Fatalf("drift was %v but should be %v", busAcc.Drift, maxNewDrift)
 }
 
- // Reboot cluster.
+ // reboot cluster
 cluster2 := cluster.Reboot(t)
 defer cluster2.Shutdown()
 
- // Check that accounts were loaded from the bus.
+ // check that accounts were loaded from the bus
 accounts2, err := cluster2.Bus.Accounts(context.Background())
 tt.OK(err)
 for _, acc := range accounts2 {
@@ -1225,7 +1260,7 @@
 }
 }
 
- // Reset drift again.
+ // reset drift again
 if err := cluster2.Bus.ResetDrift(context.Background(), acc.ID); err != nil {
 t.Fatal(err)
 }
@@ -1491,8 +1526,7 @@ func TestContractArchival(t *testing.T) {
 
 // create a test cluster
 cluster := newTestCluster(t, testClusterOptions{
- hosts: 1,
- logger: zap.NewNop(),
+ hosts: 1,
 })
 defer cluster.Shutdown()
 tt := cluster.tt
@@ -1511,7 +1545,7 @@
 endHeight := contracts[0].WindowEnd
 cs, err := cluster.Bus.ConsensusState(context.Background())
 tt.OK(err)
- cluster.MineBlocks(int(endHeight - cs.BlockHeight + 1))
+ cluster.MineBlocks(endHeight - cs.BlockHeight + 1)
 
 // check that we have 0 contracts
 tt.Retry(100, 100*time.Millisecond, func() error {
@@ -1520,7 +1554,14 @@
 return err
 }
 if len(contracts) != 0 {
- return fmt.Errorf("expected 0 contracts, got %v", len(contracts))
+ // trigger contract maintenance again; there's an NDF where we use
+ // the keep leeway because we can't fetch the revision, preventing
+ // the contract from being archived
+ _, err := cluster.Autopilot.Trigger(false)
+ tt.OK(err)
+
+ cs, _ := cluster.Bus.ConsensusState(context.Background())
+ return fmt.Errorf("expected 0 contracts, got %v (bh: %v we: %v)", len(contracts), cs.BlockHeight, contracts[0].WindowEnd)
 }
 return nil
 })
@@ -1532,10 +1573,7 @@ func TestUnconfirmedContractArchival(t *testing.T) {
 }
 
 // create a test cluster
- cluster := newTestCluster(t, testClusterOptions{
- logger: zap.NewNop(),
- hosts: 1,
- })
+ cluster := newTestCluster(t, testClusterOptions{hosts: 1})
 defer cluster.Shutdown()
 tt := cluster.tt
@@ -1580,9 +1618,8 @@
 t.Fatalf("expected 2 contracts, got %v", len(contracts))
 }
 
- // mine for 20 blocks to make sure we are beyond the 18 block deadline for
- // contract confirmation
- cluster.MineBlocks(20)
+ // mine enough blocks to ensure we're past the confirmation deadline
+ cluster.MineBlocks(contractor.ContractConfirmationDeadline + 1)
 
 tt.Retry(100, 100*time.Millisecond, func() error {
 contracts, err := cluster.Bus.Contracts(context.Background(), api.ContractsOpts{})
@@ -1629,7 +1666,7 @@ func TestWalletTransactions(t *testing.T) {
 txns, err := b.WalletTransactions(context.Background(), api.WalletTransactionsWithOffset(2))
 tt.OK(err)
 if !reflect.DeepEqual(txns, allTxns[2:]) {
- t.Fatal("transactions don't match")
+ t.Fatal("transactions don't match", cmp.Diff(txns, allTxns[2:]))
 }
 
 // Find the first index that has a different timestamp than the first.
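api.WalletTransactionsWithOffset, used in the hunk above, follows the functional-options style the client API uses for query parameters. Its definition is not part of this diff; a sketch of the assumed shape, using stdlib net/url and strconv:

// a sketch of the assumed option type; the real one lives in the api package.
type WalletTransactionsOption func(url.Values)

func WalletTransactionsWithOffset(offset int) WalletTransactionsOption {
	return func(values url.Values) {
		values.Set("offset", strconv.Itoa(offset))
	}
}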
@@ -2328,13 +2365,8 @@ func TestWalletSendUnconfirmed(t *testing.T) {
 }
 
 // send the full balance back to the wallet
- toSend := wr.Confirmed.Sub(types.Siacoins(1).Div64(100)) // leave some for the fee
- tt.OK(b.SendSiacoins(context.Background(), []types.SiacoinOutput{
- {
- Address: wr.Address,
- Value: toSend,
- },
- }, false))
+ toSend := wr.Confirmed.Sub(types.Siacoins(1)) // leave some for the fee
+ tt.OKAll(b.SendSiacoins(context.Background(), wr.Address, toSend, false))
 
 // the unconfirmed balance should have changed to slightly more than toSend
 // since we paid a fee
@@ -2347,21 +2379,11 @@ func TestWalletSendUnconfirmed(t *testing.T) {
 fmt.Println(wr.Confirmed, wr.Unconfirmed)
 
 // try again - this should fail
- err = b.SendSiacoins(context.Background(), []types.SiacoinOutput{
- {
- Address: wr.Address,
- Value: toSend,
- },
- }, false)
- tt.AssertIs(err, wallet.ErrInsufficientBalance)
+ _, err = b.SendSiacoins(context.Background(), wr.Address, toSend, false)
+ tt.AssertIs(err, wallet.ErrNotEnoughFunds)
 
 // try again - this time using unconfirmed transactions
- tt.OK(b.SendSiacoins(context.Background(), []types.SiacoinOutput{
- {
- Address: wr.Address,
- Value: toSend,
- },
- }, true))
+ tt.OKAll(b.SendSiacoins(context.Background(), wr.Address, toSend, true))
 
 // the unconfirmed balance should be almost the same
 wr, err = b.Wallet(context.Background())
@@ -2388,46 +2410,43 @@ func TestWalletSendUnconfirmed(t *testing.T) {
 }
 
 func TestWalletFormUnconfirmed(t *testing.T) {
- // New cluster with autopilot disabled
+ // create cluster without autopilot
 cfg := clusterOptsDefault
 cfg.skipSettingAutopilot = true
 cluster := newTestCluster(t, cfg)
 defer cluster.Shutdown()
+
+ // convenience variables
 b := cluster.Bus
 tt := cluster.tt
 
- // Add a host.
+ // add a host (non-blocking)
 cluster.AddHosts(1)
 
- // Send the full balance back to the wallet to make sure it's all
- // unconfirmed.
+ // send all money to ourselves, making sure it's unconfirmed
+ feeReserve := types.Siacoins(1)
 wr, err := b.Wallet(context.Background())
 tt.OK(err)
- tt.OK(b.SendSiacoins(context.Background(), []types.SiacoinOutput{
- {
- Address: wr.Address,
- Value: wr.Confirmed.Sub(types.Siacoins(1).Div64(100)), // leave some for the fee
- },
- }, false))
+ tt.OKAll(b.SendSiacoins(context.Background(), wr.Address, wr.Confirmed.Sub(feeReserve), false)) // leave some for the fee
 
- // There should be hardly any money in the wallet.
+ // check wallet only has the reserve in the confirmed balance
 wr, err = b.Wallet(context.Background())
 tt.OK(err)
- if wr.Confirmed.Sub(wr.Unconfirmed).Cmp(types.Siacoins(1).Div64(100)) > 0 {
+ if wr.Confirmed.Sub(wr.Unconfirmed).Cmp(feeReserve) > 0 {
 t.Fatal("wallet should have hardly any confirmed balance")
 }
 
- // There shouldn't be any contracts at this point.
+ // there shouldn't be any contracts yet
 contracts, err := b.Contracts(context.Background(), api.ContractsOpts{})
 tt.OK(err)
 if len(contracts) != 0 {
 t.Fatal("expected 0 contracts", len(contracts))
 }
 
- // Enable autopilot by setting it.
+ // enable the autopilot by configuring it
 cluster.UpdateAutopilotConfig(context.Background(), test.AutopilotConfig)
 
- // Wait for a contract to form.
+ // wait for a contract to form
 contractsFormed := cluster.WaitForContracts()
 if len(contractsFormed) != 1 {
 t.Fatal("expected 1 contract", len(contracts))
@@ -2442,34 +2461,27 @@ func TestBusRecordedMetrics(t *testing.T) {
 })
 defer cluster.Shutdown()
 
- // Get contract set metrics.
- csMetrics, err := cluster.Bus.ContractSetMetrics(context.Background(), startTime, api.MetricMaxIntervals, time.Second, api.ContractSetMetricsQueryOpts{}) - cluster.tt.OK(err) + // fetch contract set metrics + cluster.tt.Retry(100, 100*time.Millisecond, func() error { + csMetrics, err := cluster.Bus.ContractSetMetrics(context.Background(), startTime, api.MetricMaxIntervals, time.Second, api.ContractSetMetricsQueryOpts{}) + cluster.tt.OK(err) - for i := 0; i < len(csMetrics); i++ { - // Remove metrics from before contract was formed. - if csMetrics[i].Contracts > 0 { - csMetrics = csMetrics[i:] - break - } - } - if len(csMetrics) == 0 { - t.Fatal("expected at least 1 metric with contracts") - } - for _, m := range csMetrics { - if m.Contracts != 1 { - t.Fatalf("expected 1 contract, got %v", m.Contracts) + // expect at least 1 metric with contracts + if len(csMetrics) < 1 { + return fmt.Errorf("expected at least 1 metric, got %v", len(csMetrics)) + } else if m := csMetrics[len(csMetrics)-1]; m.Contracts != 1 { + return fmt.Errorf("expected 1 contract, got %v", m.Contracts) } else if m.Name != test.ContractSet { - t.Fatalf("expected contract set %v, got %v", test.ContractSet, m.Name) + return fmt.Errorf("expected contract set %v, got %v", test.ContractSet, m.Name) } else if m.Timestamp.Std().Before(startTime) { - t.Fatalf("expected time to be after start time %v, got %v", startTime, m.Timestamp.Std()) + return fmt.Errorf("expected time to be after start time %v, got %v", startTime, m.Timestamp.Std()) } - } + return nil + }) - // Get churn metrics. Should have 1 for the new contract. + // get churn metrics, should have 1 for the new contract cscMetrics, err := cluster.Bus.ContractSetChurnMetrics(context.Background(), startTime, api.MetricMaxIntervals, time.Second, api.ContractSetChurnMetricsQueryOpts{}) cluster.tt.OK(err) - if len(cscMetrics) != 1 { t.Fatalf("expected 1 metric, got %v", len(cscMetrics)) } else if m := cscMetrics[0]; m.Direction != api.ChurnDirAdded { @@ -2482,7 +2494,7 @@ func TestBusRecordedMetrics(t *testing.T) { t.Fatalf("expected time to be after start time %v, got %v", startTime, m.Timestamp.Std()) } - // Get contract metrics. + // get contract metrics var cMetrics []api.ContractMetric cluster.tt.Retry(100, 100*time.Millisecond, func() error { // Retry fetching metrics since they are buffered. @@ -2520,7 +2532,7 @@ func TestBusRecordedMetrics(t *testing.T) { t.Fatal("expected zero ListSpending") } - // Prune one of the metrics + // prune one of the metrics if err := cluster.Bus.PruneMetrics(context.Background(), api.MetricContract, time.Now()); err != nil { t.Fatal(err) } else if cMetrics, err = cluster.Bus.ContractMetrics(context.Background(), startTime, api.MetricMaxIntervals, time.Second, api.ContractMetricsQueryOpts{}); err != nil { diff --git a/internal/test/e2e/events_test.go b/internal/test/e2e/events_test.go index befa3194a..4972adf1b 100644 --- a/internal/test/e2e/events_test.go +++ b/internal/test/e2e/events_test.go @@ -25,6 +25,7 @@ func TestEvents(t *testing.T) { api.WebhookContractArchive, api.WebhookContractRenew, api.WebhookContractSetUpdate, + api.WebhookHostUpdate, api.WebhookSettingDelete, api.WebhookSettingUpdate, } @@ -49,6 +50,15 @@ func TestEvents(t *testing.T) { return nil } + // ignore host updates with a net address that differs from the update we assert to receive + if event.Module == api.ModuleHost && event.Event == api.EventUpdate { + if parsed, err := api.ParseEventWebhook(event); err != nil { + t.Fatal(err) + } else if parsed.(api.EventHostUpdate).NetAddr != "127.0.0.1:0" { + return nil + } + } + // check if the event is expected if !isKnownEvent(event) { return fmt.Errorf("unexpected event %+v", event) } @@ -120,6 +130,12 @@ func TestEvents(t *testing.T) { // delete setting tt.OK(b.DeleteSetting(context.Background(), api.SettingRedundancy)) + // update host setting + h := cluster.hosts[0] + settings := h.settings.Settings() + settings.NetAddress = "127.0.0.1:0" + tt.OK(h.UpdateSettings(settings)) + // wait until we received the events tt.Retry(10, time.Second, func() error { mu.Lock() @@ -152,6 +168,10 @@ func TestEvents(t *testing.T) { if e.TransactionFee.IsZero() || e.BlockHeight == 0 || e.Timestamp.IsZero() || !e.Synced { t.Fatalf("unexpected event %+v", e) } + case api.EventHostUpdate: + if e.HostKey != h.PublicKey() || e.NetAddr != "127.0.0.1:0" || e.Timestamp.IsZero() { + t.Fatalf("unexpected event %+v", e) + } case api.EventSettingUpdate: if e.Key != api.SettingGouging || e.Timestamp.IsZero() { t.Fatalf("unexpected event %+v", e) diff --git a/internal/test/e2e/gouging_test.go b/internal/test/e2e/gouging_test.go index 8915a2e11..a40fe0024 100644 --- a/internal/test/e2e/gouging_test.go +++ b/internal/test/e2e/gouging_test.go @@ -13,6 +13,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/test" + "go.uber.org/zap/zapcore" "lukechampine.com/frand" ) @@ -31,7 +32,7 @@ func TestGouging(t *testing.T) { tt := cluster.tt // mine enough blocks for the current period to become > period - cluster.MineBlocks(int(cfg.Period) + 1) + cluster.MineBlocks(cfg.Period + 1) // add hosts tt.OKAll(cluster.AddHostsBlocking(int(test.AutopilotConfig.Contracts.Amount))) @@ -72,6 +73,7 @@ func TestGouging(t *testing.T) { if err := b.UpdateSetting(context.Background(), api.SettingGouging, gs); err != nil { t.Fatal(err) } + // fetch current contract set contracts, err := b.Contracts(context.Background(), api.ContractsOpts{ContractSet: cfg.Set}) tt.OK(err) @@ -134,6 +136,60 @@ func TestGouging(t *testing.T) { }) } +// TestAccountFunding is a regression test that verifies we can fund an account +// even if the host is considered gouging, protecting us from being unable to +// download from critical hosts when we migrate away from them.
+func TestAccountFunding(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + + // run without autopilot + opts := clusterOptsDefault + opts.skipRunningAutopilot = true + opts.logger = newTestLoggerCustom(zapcore.ErrorLevel) + + // create a new test cluster + cluster := newTestCluster(t, opts) + defer cluster.Shutdown() + + // convenience variables + b := cluster.Bus + w := cluster.Worker + tt := cluster.tt + + // add a host + hosts := cluster.AddHosts(1) + h, err := b.Host(context.Background(), hosts[0].PublicKey()) + tt.OK(err) + + // scan the host + _, err = w.RHPScan(context.Background(), h.PublicKey, h.NetAddress, 10*time.Second) + tt.OK(err) + + // manually form a contract with the host + cs, _ := b.ConsensusState(context.Background()) + wallet, _ := b.Wallet(context.Background()) + rev, _, err := w.RHPForm(context.Background(), cs.BlockHeight+test.AutopilotConfig.Contracts.Period+test.AutopilotConfig.Contracts.RenewWindow, h.PublicKey, h.NetAddress, wallet.Address, types.Siacoins(1), types.Siacoins(1)) + tt.OK(err) + c, err := b.AddContract(context.Background(), rev, rev.Revision.MissedHostPayout().Sub(types.Siacoins(1)), types.Siacoins(1), cs.BlockHeight, api.ContractStatePending) + tt.OK(err) + + // fund the account + tt.OK(w.RHPFund(context.Background(), c.ID, c.HostKey, c.HostIP, c.SiamuxAddr, types.Siacoins(1).Div64(2))) + + // update host so it's gouging + settings := hosts[0].settings.Settings() + settings.StoragePrice = types.Siacoins(1) + tt.OK(hosts[0].UpdateSettings(settings)) + + // ensure the price table expires so the worker is forced to fetch it + time.Sleep(defaultHostSettings.PriceTableValidity) + + // fund the account again + tt.OK(w.RHPFund(context.Background(), c.ID, c.HostKey, c.HostIP, c.SiamuxAddr, types.Siacoins(1))) +} + func TestHostMinVersion(t *testing.T) { if testing.Short() { t.SkipNow() diff --git a/internal/test/e2e/host.go b/internal/test/e2e/host.go index 284ab65ae..bd10b4af1 100644 --- a/internal/test/e2e/host.go +++ b/internal/test/e2e/host.go @@ -4,66 +4,122 @@ import ( "context" "fmt" "net" + "os" "path/filepath" + "sync" "time" "go.sia.tech/core/consensus" + "go.sia.tech/core/gateway" crhpv2 "go.sia.tech/core/rhp/v2" crhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/host/accounts" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/registry" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/storage" + "go.sia.tech/hostd/index" "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/rhp" rhpv2 "go.sia.tech/hostd/rhp/v2" rhpv3 "go.sia.tech/hostd/rhp/v3" - "go.sia.tech/hostd/wallet" - "go.sia.tech/hostd/webhooks" - "go.sia.tech/renterd/bus" - "go.sia.tech/renterd/internal/node" - "go.sia.tech/siad/modules" - mconsensus "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" "go.uber.org/zap" ) -const blocksPerMonth = 144 * 30 +const ( + blocksPerDay = 144 + blocksPerMonth = blocksPerDay * 30 +) -type stubMetricReporter struct{} +type ephemeralPeerStore struct { + peers map[string]syncer.PeerInfo + bans map[string]time.Time + mu sync.Mutex +} -func (stubMetricReporter) StartSession(conn *rhp.Conn, proto string, version int) (rhp.UID, func()) { - return rhp.UID{}, func() {} +func (eps *ephemeralPeerStore) AddPeer(addr string) error { + eps.mu.Lock() + defer eps.mu.Unlock() + eps.peers[addr] 
= syncer.PeerInfo{Address: addr} + return nil } -func (stubMetricReporter) StartRPC(rhp.UID, types.Specifier) (rhp.UID, func(contracts.Usage, error)) { - return rhp.UID{}, func(contracts.Usage, error) {} + +func (eps *ephemeralPeerStore) Peers() ([]syncer.PeerInfo, error) { + eps.mu.Lock() + defer eps.mu.Unlock() + var peers []syncer.PeerInfo + for _, peer := range eps.peers { + peers = append(peers, peer) + } + return peers, nil } -type stubDataMonitor struct{} +func (eps *ephemeralPeerStore) PeerInfo(addr string) (syncer.PeerInfo, error) { + eps.mu.Lock() + defer eps.mu.Unlock() + peer, ok := eps.peers[addr] + if !ok { + return syncer.PeerInfo{}, syncer.ErrPeerNotFound + } + return peer, nil +} -func (stubDataMonitor) ReadBytes(n int) {} -func (stubDataMonitor) WriteBytes(n int) {} +func (eps *ephemeralPeerStore) UpdatePeerInfo(addr string, fn func(*syncer.PeerInfo)) error { + eps.mu.Lock() + defer eps.mu.Unlock() + peer, ok := eps.peers[addr] + if !ok { + return syncer.ErrPeerNotFound + } + fn(&peer) + eps.peers[addr] = peer + return nil +} + +func (eps *ephemeralPeerStore) Ban(addr string, duration time.Duration, reason string) error { + eps.mu.Lock() + defer eps.mu.Unlock() + eps.bans[addr] = time.Now().Add(duration) + return nil +} + +// Banned returns true, nil if the peer is banned. +func (eps *ephemeralPeerStore) Banned(addr string) (bool, error) { + eps.mu.Lock() + defer eps.mu.Unlock() + t, ok := eps.bans[addr] + return ok && time.Now().Before(t), nil +} + +func newEphemeralPeerStore() syncer.PeerStore { + return &ephemeralPeerStore{ + peers: make(map[string]syncer.PeerInfo), + bans: make(map[string]time.Time), + } +} // A Host is an ephemeral host that can be used for testing. type Host struct { dir string privKey types.PrivateKey - g modules.Gateway - cs modules.ConsensusSet - tp bus.TransactionPool + s *syncer.Syncer + syncerCancel context.CancelFunc + cm *chain.Manager + chainDB *coreutils.BoltChainDB store *sqlite.Store wallet *wallet.SingleAddressWallet settings *settings.ConfigManager storage *storage.VolumeManager + index *index.Manager registry *registry.Manager accounts *accounts.AccountManager - contracts *contracts.ContractManager + contracts *contracts.Manager rhpv2 *rhpv2.SessionHandler rhpv3 *rhpv3.SessionHandler @@ -97,13 +153,14 @@ func (h *Host) Close() error { h.rhpv2.Close() h.rhpv3.Close() h.settings.Close() + h.index.Close() h.wallet.Close() h.contracts.Close() h.storage.Close() h.store.Close() - h.tp.Close() - h.cs.Close() - h.g.Close() + h.syncerCancel() + h.s.Close() + h.chainDB.Close() return nil } @@ -119,7 +176,7 @@ func (h *Host) RHPv3Addr() string { // AddVolume adds a new volume to the host func (h *Host) AddVolume(ctx context.Context, path string, size uint64) error { - result := make(chan error) + result := make(chan error, 1) _, err := h.storage.AddVolume(ctx, path, size, result) if err != nil { return err @@ -148,7 +205,7 @@ func (h *Host) WalletAddress() types.Address { } // Contracts returns the host's contract manager -func (h *Host) Contracts() *contracts.ContractManager { +func (h *Host) Contracts() *contracts.Manager { return h.contracts } @@ -157,30 +214,43 @@ func (h *Host) PublicKey() types.PublicKey { return h.privKey.PublicKey() } -// GatewayAddr returns the address of the host's gateway. -func (h *Host) GatewayAddr() string { - return string(h.g.Address()) +// SyncerAddr returns the address of the host's syncer. 
+func (h *Host) SyncerAddr() string { + return string(h.s.Addr()) } -// NewHost initializes a new test host -func NewHost(privKey types.PrivateKey, dir string, network *consensus.Network, debugLogging bool) (*Host, error) { - g, err := gateway.New("localhost:0", false, filepath.Join(dir, "gateway")) - if err != nil { - return nil, fmt.Errorf("failed to create gateway: %w", err) +// NewHost initializes a new test host. +func NewHost(privKey types.PrivateKey, dir string, network *consensus.Network, genesisBlock types.Block) (*Host, error) { + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("failed to create dir: %w", err) } - cs, errCh := mconsensus.New(g, false, filepath.Join(dir, "consensus")) - if err := <-errCh; err != nil { - return nil, fmt.Errorf("failed to create consensus set: %w", err) - } - tpool, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) + chainDB, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "chain.db")) if err != nil { - return nil, fmt.Errorf("failed to create transaction pool: %w", err) + return nil, fmt.Errorf("failed to create chaindb: %w", err) } - tp := node.NewTransactionPool(tpool) - cm, err := node.NewChainManager(cs, tp, network) + dbStore, tipState, err := chain.NewDBStore(chainDB, network, genesisBlock) if err != nil { return nil, err } + cm := chain.NewManager(dbStore, tipState) + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return nil, fmt.Errorf("failed to create syncer listener: %w", err) + } + s := syncer.New(l, cm, newEphemeralPeerStore(), gateway.Header{ + GenesisID: genesisBlock.ID(), + UniqueID: gateway.GenerateUniqueID(), + NetAddress: l.Addr().String(), + }, syncer.WithPeerDiscoveryInterval(100*time.Millisecond), syncer.WithSyncInterval(100*time.Millisecond)) + syncErrChan := make(chan error, 1) + syncerCtx, syncerCancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + syncerCancel() + } + }() + go func() { syncErrChan <- s.Run(syncerCtx) }() log := zap.NewNop() db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) @@ -188,23 +258,17 @@ func NewHost(privKey types.PrivateKey, dir string, network *consensus.Network, d return nil, fmt.Errorf("failed to create sql store: %w", err) } - wallet, err := wallet.NewSingleAddressWallet(privKey, cm, db, log.Named("wallet")) + wallet, err := wallet.NewSingleAddressWallet(privKey, cm, db) if err != nil { return nil, fmt.Errorf("failed to create wallet: %w", err) } - wr, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - return nil, fmt.Errorf("failed to create webhook reporter: %w", err) - } - - am := alerts.NewManager(wr, log.Named("alerts")) - storage, err := storage.NewVolumeManager(db, am, cm, log.Named("storage"), 0) + storage, err := storage.NewVolumeManager(db) if err != nil { return nil, fmt.Errorf("failed to create storage manager: %w", err) } - contracts, err := contracts.NewManager(db, am, storage, cm, tp, wallet, log.Named("contracts")) + contracts, err := contracts.NewManager(db, storage, cm, s, wallet, contracts.WithRejectAfter(10), contracts.WithRevisionSubmissionBuffer(5)) if err != nil { return nil, fmt.Errorf("failed to create contract manager: %w", err) } @@ -219,29 +283,26 @@ func NewHost(privKey types.PrivateKey, dir string, network *consensus.Network, d return nil, fmt.Errorf("failed to create rhp3 listener: %w", err) } - settings, err := settings.NewConfigManager( - settings.WithHostKey(privKey), - 
settings.WithRHP2Addr(rhp2Listener.Addr().String()), - settings.WithStore(db), - settings.WithChainManager(cm), - settings.WithTransactionPool(tp), - settings.WithWallet(wallet), - settings.WithAlertManager(am), - settings.WithLog(log.Named("settings"))) + settings, err := settings.NewConfigManager(privKey, db, cm, s, wallet) if err != nil { return nil, fmt.Errorf("failed to create settings manager: %w", err) } + idx, err := index.NewManager(db, cm, contracts, wallet, settings, storage, index.WithLog(log.Named("index")), index.WithBatchSize(0)) // off-by-one + if err != nil { + return nil, fmt.Errorf("failed to create index manager: %w", err) + } + registry := registry.NewManager(privKey, db, zap.NewNop()) accounts := accounts.NewManager(db, settings) - rhpv2, err := rhpv2.NewSessionHandler(rhp2Listener, privKey, rhp3Listener.Addr().String(), cm, tp, wallet, contracts, settings, storage, stubDataMonitor{}, stubMetricReporter{}, log.Named("rhpv2")) + rhpv2, err := rhpv2.NewSessionHandler(rhp2Listener, privKey, rhp3Listener.Addr().String(), cm, s, wallet, contracts, settings, storage) if err != nil { return nil, fmt.Errorf("failed to create rhpv2 session handler: %w", err) } go rhpv2.Serve() - rhpv3, err := rhpv3.NewSessionHandler(rhp3Listener, privKey, cm, tp, wallet, accounts, contracts, registry, storage, settings, stubDataMonitor{}, stubMetricReporter{}, log.Named("rhpv3")) + rhpv3, err := rhpv3.NewSessionHandler(rhp3Listener, privKey, cm, s, wallet, accounts, contracts, registry, storage, settings) if err != nil { return nil, fmt.Errorf("failed to create rhpv3 session handler: %w", err) } @@ -251,13 +312,15 @@ func NewHost(privKey types.PrivateKey, dir string, network *consensus.Network, d dir: dir, privKey: privKey, - g: g, - cs: cs, - tp: tp, + s: s, + syncerCancel: syncerCancel, + cm: cm, + chainDB: chainDB, store: db, wallet: wallet, settings: settings, + index: idx, storage: storage, registry: registry, accounts: accounts, diff --git a/internal/test/e2e/interactions_test.go b/internal/test/e2e/interactions_test.go index 021d75cb6..86a634891 100644 --- a/internal/test/e2e/interactions_test.go +++ b/internal/test/e2e/interactions_test.go @@ -51,8 +51,8 @@ func TestInteractions(t *testing.T) { // assert price table gets updated var ptUpdates int tt.Retry(100, 100*time.Millisecond, func() error { - // fetch contracts (this registers host interactions) - tt.OKAll(w.Contracts(context.Background(), time.Minute)) + // fetch pricetable (this registers host interactions) + tt.OKAll(w.RHPPriceTable(context.Background(), h1.PublicKey(), h.Settings.SiamuxAddr(), 0)) // fetch the host h, err := b.Host(context.Background(), h1.PublicKey()) diff --git a/internal/test/e2e/metadata_test.go b/internal/test/e2e/metadata_test.go index 4bb1ea2dd..4dd6c1229 100644 --- a/internal/test/e2e/metadata_test.go +++ b/internal/test/e2e/metadata_test.go @@ -10,7 +10,6 @@ import ( "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/test" - "go.uber.org/zap" ) func TestObjectMetadata(t *testing.T) { @@ -20,8 +19,7 @@ func TestObjectMetadata(t *testing.T) { // create cluster cluster := newTestCluster(t, testClusterOptions{ - hosts: test.RedundancySettings.TotalShards, - logger: zap.NewNop(), + hosts: test.RedundancySettings.TotalShards, }) defer cluster.Shutdown() diff --git a/internal/test/e2e/metrics_test.go b/internal/test/e2e/metrics_test.go index aaa139102..fcec2d6c1 100644 --- a/internal/test/e2e/metrics_test.go +++ b/internal/test/e2e/metrics_test.go @@ -85,7 +85,6 @@ func TestMetrics(t *testing.T) { if 
len(wm) == 0 { return errors.New("no wallet metrics") } - return nil }) } diff --git a/internal/test/e2e/pruning_test.go b/internal/test/e2e/pruning_test.go index 7c1a856f1..8492bf9f1 100644 --- a/internal/test/e2e/pruning_test.go +++ b/internal/test/e2e/pruning_test.go @@ -20,10 +20,11 @@ func TestHostPruning(t *testing.T) { } // create a new test cluster - cluster := newTestCluster(t, clusterOptsDefault) + cluster := newTestCluster(t, testClusterOptions{hosts: 1}) defer cluster.Shutdown() + + // convenience variables b := cluster.Bus - w := cluster.Worker a := cluster.Autopilot tt := cluster.tt @@ -43,44 +44,19 @@ func TestHostPruning(t *testing.T) { tt.OK(b.RecordHostScans(context.Background(), his)) } - // add a host - hosts := cluster.AddHosts(1) - h1 := hosts[0] - - // fetch the host - h, err := b.Host(context.Background(), h1.PublicKey()) - tt.OK(err) - - // scan the host (lastScan needs to be > 0 for downtime to start counting) - tt.OKAll(w.RHPScan(context.Background(), h1.PublicKey(), h.NetAddress, 0)) - - // block the host - tt.OK(b.UpdateHostBlocklist(context.Background(), []string{h1.PublicKey().String()}, nil, false)) + // shut down the worker manually, this will flush any interactions + cluster.ShutdownWorker(context.Background()) // remove it from the cluster manually + h1 := cluster.hosts[0] cluster.RemoveHost(h1) - // shut down the worker manually, this will flush any interactions - cluster.ShutdownWorker(context.Background()) - // record 9 failed interactions, right before the pruning threshold, and // wait for the autopilot loop to finish at least once recordFailedInteractions(9, h1.PublicKey()) - // trigger the autopilot loop twice, failing to trigger it twice shouldn't - // fail the test, this avoids an NDF on windows - remaining := 2 - for i := 1; i < 100; i++ { - triggered, err := a.Trigger(false) - tt.OK(err) - if triggered { - remaining-- - if remaining == 0 { - break - } - } - time.Sleep(50 * time.Millisecond) - } + // trigger the autopilot + tt.OKAll(a.Trigger(true)) // assert the host was not pruned hostss, err := b.Hosts(context.Background(), api.GetHostsOptions{}) @@ -98,6 +74,7 @@ func TestHostPruning(t *testing.T) { hostss, err = b.Hosts(context.Background(), api.GetHostsOptions{}) tt.OK(err) if len(hostss) != 0 { + a.Trigger(true) // trigger autopilot return fmt.Errorf("host was not pruned, %+v", hostss[0].Interactions) } return nil diff --git a/internal/test/e2e/s3_test.go b/internal/test/e2e/s3_test.go index daaefed5e..3f20e22ad 100644 --- a/internal/test/e2e/s3_test.go +++ b/internal/test/e2e/s3_test.go @@ -18,7 +18,6 @@ import ( "go.sia.tech/gofakes3" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/test" - "go.uber.org/zap" "lukechampine.com/frand" ) @@ -195,8 +194,7 @@ func TestS3ObjectMetadata(t *testing.T) { // create cluster opts := testClusterOptions{ - hosts: test.RedundancySettings.TotalShards, - logger: zap.NewNop(), + hosts: test.RedundancySettings.TotalShards, } cluster := newTestCluster(t, opts) defer cluster.Shutdown() diff --git a/internal/test/host.go b/internal/test/host.go index e95d1d3d1..f5128b8fc 100644 --- a/internal/test/host.go +++ b/internal/test/host.go @@ -35,11 +35,12 @@ func NewHost(hk types.PublicKey, pt rhpv3.HostPriceTable, settings rhpv2.HostSet SuccessfulInteractions: 2, FailedInteractions: 0, }, - PublicKey: hk, - PriceTable: api.HostPriceTable{HostPriceTable: pt, Expiry: time.Now().Add(time.Minute)}, - Settings: settings, - Scanned: true, - Subnets: []string{"38.135.51.0/24"}, + PublicKey: hk, + PriceTable: 
api.HostPriceTable{HostPriceTable: pt, Expiry: time.Now().Add(time.Minute)}, + Settings: settings, + Scanned: true, + ResolvedAddresses: []string{"38.135.51.1"}, + Subnets: []string{"38.135.51.0/24"}, } } diff --git a/autopilot/host_test.go b/internal/test/host_test.go similarity index 78% rename from autopilot/host_test.go rename to internal/test/host_test.go index 0a3de2a71..cf6e3fda5 100644 --- a/autopilot/host_test.go +++ b/internal/test/host_test.go @@ -1,14 +1,12 @@ -package autopilot +package test import ( "testing" - - "go.sia.tech/renterd/internal/test" ) func TestHost(t *testing.T) { - hk := test.RandomHostKey() - h := test.NewHost(hk, test.NewHostPriceTable(), test.NewHostSettings()) + hk := RandomHostKey() + h := NewHost(hk, NewHostPriceTable(), NewHostSettings()) // assert host is online if !h.IsOnline() { diff --git a/internal/utils/chain.go b/internal/utils/chain.go new file mode 100644 index 000000000..613c41be3 --- /dev/null +++ b/internal/utils/chain.go @@ -0,0 +1,11 @@ +package utils + +import ( + "time" + + "go.sia.tech/core/types" +) + +func IsSynced(b types.Block) bool { + return time.Since(b.Timestamp) <= 3*time.Hour +} diff --git a/internal/utils/errors.go b/internal/utils/errors.go index 67696b984..22ff0e660 100644 --- a/internal/utils/errors.go +++ b/internal/utils/errors.go @@ -1,7 +1,9 @@ package utils import ( + "context" "errors" + "fmt" "strings" ) @@ -28,3 +30,15 @@ func IsErr(err error, target error) bool { // renterd/hostd use the same error messages return strings.Contains(strings.ToLower(err.Error()), strings.ToLower(target.Error())) } + +// WrapErr can be used in a defer statement to wrap an error with the provided +// function name. If the context contains a cause error, it will also be +// included in the wrapping. +func WrapErr(ctx context.Context, fnName string, err *error) { + if *err != nil { + *err = fmt.Errorf("%s: %w", fnName, *err) + if cause := context.Cause(ctx); cause != nil && !IsErr(*err, cause) { + *err = fmt.Errorf("%w; %w", cause, *err) + } + } +} diff --git a/internal/utils/fmt.go b/internal/utils/fmt.go new file mode 100644 index 000000000..782b7a566 --- /dev/null +++ b/internal/utils/fmt.go @@ -0,0 +1,17 @@ +package utils + +import "fmt" + +func HumanReadableSize(b int) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", + float64(b)/float64(div), "KMGTPE"[exp]) +} diff --git a/internal/utils/net.go b/internal/utils/net.go index a4aabd252..999896c48 100644 --- a/internal/utils/net.go +++ b/internal/utils/net.go @@ -2,11 +2,10 @@ package utils import ( "context" + "errors" "fmt" "net" "sort" - - "go.sia.tech/renterd/api" ) const ( @@ -16,6 +15,10 @@ const ( var ( privateSubnets []*net.IPNet + + // ErrHostTooManyAddresses is returned by the worker API when a host has + // more than two addresses, or two of the same type.
+ ErrHostTooManyAddresses = errors.New("host has more than two addresses, or two of the same type") ) func init() { @@ -33,7 +36,34 @@ func init() { } } -func ResolveHostIP(ctx context.Context, hostIP string) (subnets []string, private bool, _ error) { +func AddressesToSubnets(resolvedAddresses []string) ([]string, error) { + var subnets []string + for _, addr := range resolvedAddresses { + parsed := net.ParseIP(addr) + if parsed == nil { + return nil, fmt.Errorf("failed to parse address: %s", addr) + } + + // figure out the IP range + ipRange := ipv6FilterRange + if parsed.To4() != nil { + ipRange = ipv4FilterRange + } + + // parse the subnet + cidr := fmt.Sprintf("%s/%d", parsed.String(), ipRange) + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("failed to parse cidr: %w", err) + } + + subnets = append(subnets, ipnet.String()) + } + + return subnets, nil +} + +func ResolveHostIP(ctx context.Context, hostIP string) (ips []string, private bool, _ error) { // resolve host address host, _, err := net.SplitHostPort(hostIP) if err != nil { @@ -46,33 +76,20 @@ func ResolveHostIP(ctx context.Context, hostIP string) (subnets []string, privat // filter out hosts associated with more than two addresses or two of the same type if len(addrs) > 2 || (len(addrs) == 2) && (len(addrs[0].IP) == len(addrs[1].IP)) { - return nil, false, api.ErrHostTooManyAddresses + return nil, false, fmt.Errorf("%w: %+v", ErrHostTooManyAddresses, addrs) } - // parse out subnets + // get ips for _, address := range addrs { private = private || isPrivateIP(address.IP) - // figure out the IP range - ipRange := ipv6FilterRange - if address.IP.To4() != nil { - ipRange = ipv4FilterRange - } - - // parse the subnet - cidr := fmt.Sprintf("%s/%d", address.String(), ipRange) - _, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - continue - } - // add it - subnets = append(subnets, ipnet.String()) + ips = append(ips, address.IP.String()) } - // sort the subnets - sort.Slice(subnets, func(i, j int) bool { - return subnets[i] < subnets[j] + // sort the ips + sort.Slice(ips, func(i, j int) bool { + return ips[i] < ips[j] }) return } diff --git a/stats/stats.go b/internal/utils/stats.go similarity index 91% rename from stats/stats.go rename to internal/utils/stats.go index d64bff352..9af14fe7d 100644 --- a/stats/stats.go +++ b/internal/utils/stats.go @@ -1,4 +1,4 @@ -package stats +package utils import ( "math" @@ -29,15 +29,7 @@ type ( Float64Data = stats.Float64Data ) -func Default() *DataPoints { - return New(statsDecayHalfTime) -} - -func NoDecay() *DataPoints { - return New(0) -} - -func New(halfLife time.Duration) *DataPoints { +func NewDataPoints(halfLife time.Duration) *DataPoints { return &DataPoints{ size: 1000, Float64Data: make([]float64, 0), diff --git a/internal/utils/version.go b/internal/utils/version.go new file mode 100644 index 000000000..3e46792f7 --- /dev/null +++ b/internal/utils/version.go @@ -0,0 +1,75 @@ +package utils + +import ( + "strconv" + "strings" +) + +// IsVersion returns whether str is a valid release version with no -rc component. +func IsVersion(str string) bool { + for _, n := range strings.Split(str, ".") { + if _, err := strconv.Atoi(n); err != nil { + return false + } + } + return true +} + +// VersionCmp returns an int indicating the difference between a and b. 
It +// follows the convention of bytes.Compare and big.Cmp: +// +// -1 if a < b +// 0 if a == b +// +1 if a > b +// +// One important quirk is that "1.1.0" is considered newer than "1.1", despite +// being numerically equal. +func VersionCmp(a, b string) int { + va, rca := splitVersion(a) + vb, rcb := splitVersion(b) + + for i := 0; i < min(len(va), len(vb)); i++ { + if va[i] < vb[i] { + return -1 + } else if va[i] > vb[i] { + return 1 + } + } + + switch { + case len(va) < len(vb): // a has fewer digits than b + return -1 + case len(va) > len(vb): // a has more digits than b + return 1 + case rca == rcb: // length is equal and rcs are equal + return 0 + case rca == 0: // a is a full release + return 1 + case rcb == 0: // b is a full release + return -1 + case rca > rcb: + return 1 + case rca < rcb: + return -1 + } + + return 0 +} + +// splitVersion splits a version string into its version and optional rc component. +// Full releases are considered rc 0. +func splitVersion(v string) (version []int, rc int) { + parts := strings.Split(v, "-rc") + for _, s := range strings.Split(parts[0], ".") { + n, _ := strconv.Atoi(s) + version = append(version, n) + } + if len(parts) == 1 { // if we don't have an rc part, we're done + return + } else if parts[1] == "" { // -rc is equivalent to -rc1 since rc0 is a full release + return version, 1 + } + + rc, _ = strconv.Atoi(parts[1]) + return +} diff --git a/internal/utils/web.go b/internal/utils/web.go index 0a490acd6..6f0caa571 100644 --- a/internal/utils/web.go +++ b/internal/utils/web.go @@ -1,9 +1,17 @@ package utils import ( + "errors" + "fmt" + "net" "net/http" _ "net/http/pprof" + "os/exec" + "runtime" "strings" + + "go.sia.tech/jape" + "go.uber.org/zap" ) type TreeMux struct { @@ -30,3 +38,45 @@ func (t TreeMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { } http.NotFound(w, req) } + +func Auth(password string, unauthenticatedDownloads bool) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if unauthenticatedDownloads && req.Method == http.MethodGet && strings.HasPrefix(req.URL.Path, "/objects/") { + h.ServeHTTP(w, req) + } else { + jape.BasicAuth(password)(h).ServeHTTP(w, req) + } + }) + } +} + +func ListenTCP(addr string, logger *zap.Logger) (net.Listener, error) { + l, err := net.Listen("tcp", addr) + if IsErr(err, errors.New("no such host")) && strings.Contains(addr, "localhost") { + // fall back to 127.0.0.1 if 'localhost' doesn't work + _, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + fallbackAddr := fmt.Sprintf("127.0.0.1:%s", port) + logger.Sugar().Warnf("failed to listen on %s, falling back to %s", addr, fallbackAddr) + return net.Listen("tcp", fallbackAddr) + } else if err != nil { + return nil, err + } + return l, nil +} + +func OpenBrowser(url string) error { + switch runtime.GOOS { + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + return exec.Command("open", url).Start() + default: + return fmt.Errorf("unsupported platform %q", runtime.GOOS) + } +} diff --git a/internal/worker/auth.go b/internal/worker/auth.go deleted file mode 100644 index 032d2536c..000000000 --- a/internal/worker/auth.go +++ /dev/null @@ -1,20 +0,0 @@ -package worker - -import ( - "net/http" - "strings" - - "go.sia.tech/jape" -) - -func Auth(password string, unauthenticatedDownloads bool)
func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if unauthenticatedDownloads && req.Method == http.MethodGet && strings.HasPrefix(req.URL.Path, "/objects/") { - h.ServeHTTP(w, req) - } else { - jape.BasicAuth(password)(h).ServeHTTP(w, req) - } - }) - } -} diff --git a/internal/worker/cache.go b/internal/worker/cache.go index e223c82fe..dfc749d2a 100644 --- a/internal/worker/cache.go +++ b/internal/worker/cache.go @@ -75,14 +75,13 @@ type ( Bus interface { Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) GougingParams(ctx context.Context) (api.GougingParams, error) - RegisterWebhook(ctx context.Context, wh webhooks.Webhook) error } WorkerCache interface { DownloadContracts(ctx context.Context) ([]api.ContractMetadata, error) GougingParams(ctx context.Context) (api.GougingParams, error) HandleEvent(event webhooks.Event) error - Initialize(ctx context.Context, workerAPI string, opts ...webhooks.HeaderOption) error + Subscribe(e EventSubscriber) error } ) @@ -92,16 +91,17 @@ type cache struct { cache *memoryCache logger *zap.SugaredLogger - mu sync.Mutex - ready bool + mu sync.Mutex + readyChan chan struct{} } func NewCache(b Bus, logger *zap.Logger) WorkerCache { + logger = logger.Named("workercache") return &cache{ b: b, cache: newMemoryCache(), - logger: logger.Sugar().Named("workercache"), + logger: logger.Sugar(), } } @@ -168,12 +168,18 @@ func (c *cache) HandleEvent(event webhooks.Event) (err error) { case api.EventConsensusUpdate: log = log.With("bh", e.BlockHeight, "ts", e.Timestamp) c.handleConsensusUpdate(e) + case api.EventContractAdd: + log = log.With("fcid", e.Added.ID, "ts", e.Timestamp) + c.handleContractAdd(e) case api.EventContractArchive: log = log.With("fcid", e.ContractID, "ts", e.Timestamp) c.handleContractArchive(e) case api.EventContractRenew: log = log.With("fcid", e.Renewal.ID, "renewedFrom", e.Renewal.RenewedFrom, "ts", e.Timestamp) c.handleContractRenew(e) + case api.EventHostUpdate: + log = log.With("hk", e.HostKey, "ts", e.Timestamp) + c.handleHostUpdate(e) case api.EventSettingUpdate: log = log.With("key", e.Key, "ts", e.Timestamp) err = c.handleSettingUpdate(e) @@ -194,32 +200,27 @@ func (c *cache) HandleEvent(event webhooks.Event) (err error) { return } -func (c *cache) Initialize(ctx context.Context, workerAPI string, webhookOpts ...webhooks.HeaderOption) error { - eventsURL := fmt.Sprintf("%s/events", workerAPI) - headers := make(map[string]string) - for _, opt := range webhookOpts { - opt(headers) +func (c *cache) Subscribe(e EventSubscriber) (err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.readyChan != nil { + return fmt.Errorf("already subscribed") } - for _, wh := range []webhooks.Webhook{ - api.WebhookConsensusUpdate(eventsURL, headers), - api.WebhookContractArchive(eventsURL, headers), - api.WebhookContractRenew(eventsURL, headers), - api.WebhookSettingUpdate(eventsURL, headers), - } { - if err := c.b.RegisterWebhook(ctx, wh); err != nil { - return fmt.Errorf("failed to register webhook '%s', err: %v", wh, err) - } + + c.readyChan, err = e.AddEventHandler(c.logger.Desugar().Name(), c) + if err != nil { + return fmt.Errorf("failed to subscribe the worker cache, error: %v", err) } - c.mu.Lock() - c.ready = true - c.mu.Unlock() return nil } func (c *cache) isReady() bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.ready + select { + case <-c.readyChan: + return true + default: + } + return false } func 
(c *cache) handleConsensusUpdate(event api.EventConsensusUpdate) { @@ -236,6 +237,24 @@ func (c *cache) handleConsensusUpdate(event api.EventConsensusUpdate) { c.cache.Set(cacheKeyGougingParams, gp) } +func (c *cache) handleContractAdd(event api.EventContractAdd) { + // return early if the cache doesn't have contracts + value, found, _ := c.cache.Get(cacheKeyDownloadContracts) + if !found { + return + } + contracts := value.([]api.ContractMetadata) + + // add the contract to the cache + for _, contract := range contracts { + if contract.ID == event.Added.ID { + return + } + } + contracts = append(contracts, event.Added) + c.cache.Set(cacheKeyDownloadContracts, contracts) +} + func (c *cache) handleContractArchive(event api.EventContractArchive) { // return early if the cache doesn't have contracts value, found, _ := c.cache.Get(cacheKeyDownloadContracts) @@ -273,6 +292,24 @@ func (c *cache) handleContractRenew(event api.EventContractRenew) { c.cache.Set(cacheKeyDownloadContracts, contracts) } +func (c *cache) handleHostUpdate(e api.EventHostUpdate) { + // return early if the cache doesn't have contracts + value, found, _ := c.cache.Get(cacheKeyDownloadContracts) + if !found { + return + } + contracts := value.([]api.ContractMetadata) + + // update the host's IP in the cache + for i, contract := range contracts { + if contract.HostKey == e.HostKey { + contracts[i].HostIP = e.NetAddr + } + } + + c.cache.Set(cacheKeyDownloadContracts, contracts) +} + func (c *cache) handleSettingDelete(e api.EventSettingDelete) { if e.Key == api.SettingGouging || e.Key == api.SettingRedundancy { c.cache.Invalidate(cacheKeyGougingParams) diff --git a/internal/worker/cache_test.go b/internal/worker/cache_test.go index e696ed02c..9bc8d682d 100644 --- a/internal/worker/cache_test.go +++ b/internal/worker/cache_test.go @@ -26,7 +26,22 @@ func (m *mockBus) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api. 
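The cache above is one concrete event handler in the worker's new event plumbing: it registers itself with the event subscriber, receives a channel that closes once webhook registration completes, and drops events until then. As context for the test that follows, a minimal sketch of another handler following the same contract (logHandler is a hypothetical example type, not part of this change; it relies on the EventSubscriber and EventHandler interfaces defined in internal/worker/events.go below):

// logHandler is a hypothetical EventHandler that simply logs events.
type logHandler struct {
	ready chan struct{}
}

// HandleEvent ignores events until webhook registration has completed,
// mirroring the cache's isReady check.
func (h *logHandler) HandleEvent(event webhooks.Event) error {
	select {
	case <-h.ready:
	default:
		return nil // not ready yet
	}
	fmt.Printf("event: %s/%s\n", event.Module, event.Event)
	return nil
}

// Subscribe registers the handler and keeps the ready channel around.
func (h *logHandler) Subscribe(e EventSubscriber) (err error) {
	h.ready, err = e.AddEventHandler("loghandler", h)
	return
}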
func (m *mockBus) GougingParams(ctx context.Context) (api.GougingParams, error) { return m.gougingParams, nil } -func (m *mockBus) RegisterWebhook(ctx context.Context, wh webhooks.Webhook) error { + +type mockEventSubscriber struct { + readyChan chan struct{} +} + +func (m *mockEventSubscriber) AddEventHandler(id string, h EventHandler) (chan struct{}, error) { + return m.readyChan, nil +} + +func (m *mockEventSubscriber) ProcessEvent(event webhooks.Event) {} + +func (m *mockEventSubscriber) Register(ctx context.Context, eventURL string, opts ...webhooks.HeaderOption) error { + return nil +} + +func (m *mockEventSubscriber) Shutdown(ctx context.Context) error { return nil } @@ -57,7 +72,13 @@ func TestWorkerCache(t *testing.T) { // create mock bus and cache c, b, mc := newTestCache(zap.New(observedZapCore)) - // assert using cache before it's initialized prints a warning + // create mock event subscriber + m := &mockEventSubscriber{readyChan: make(chan struct{})} + + // subscribe cache to event subscriber + c.Subscribe(m) + + // assert using cache before it's ready prints a warning contracts, err := c.DownloadContracts(context.Background()) if err != nil { t.Fatal(err) @@ -84,10 +105,8 @@ func TestWorkerCache(t *testing.T) { t.Fatal("expected error message to contain 'cache is not ready yet', got", lines[0].Message) } - // initialize the cache - if err := c.Initialize(context.Background(), ""); err != nil { - t.Fatal(err) - } + // close the ready channel + close(m.readyChan) // fetch contracts & gouging params so they're cached _, err = c.DownloadContracts(context.Background()) @@ -149,6 +168,7 @@ func TestWorkerCache(t *testing.T) { {Module: api.ModuleConsensus, Event: api.EventUpdate, Payload: nil}, {Module: api.ModuleContract, Event: api.EventArchive, Payload: nil}, {Module: api.ModuleContract, Event: api.EventRenew, Payload: nil}, + {Module: api.ModuleHost, Event: api.EventUpdate, Payload: nil}, {Module: api.ModuleSetting, Event: api.EventUpdate, Payload: nil}, {Module: api.ModuleSetting, Event: api.EventDelete, Payload: nil}, } { diff --git a/internal/worker/dialer.go b/internal/worker/dialer.go new file mode 100644 index 000000000..56e51ce42 --- /dev/null +++ b/internal/worker/dialer.go @@ -0,0 +1,110 @@ +package worker + +import ( + "context" + "fmt" + "net" + "sync" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.uber.org/zap" +) + +// Cache to store resolved IPs +type hostCache struct { + mu sync.RWMutex + cache map[string]string // hostname -> IP address +} + +func newHostCache() *hostCache { + return &hostCache{ + cache: make(map[string]string), + } +} + +func (hc *hostCache) Get(hostname string) (string, bool) { + hc.mu.RLock() + defer hc.mu.RUnlock() + ip, ok := hc.cache[hostname] + return ip, ok +} + +func (hc *hostCache) Set(hostname, ip string) { + hc.mu.Lock() + defer hc.mu.Unlock() + hc.cache[hostname] = ip +} + +func (hc *hostCache) Delete(hostname string) { + hc.mu.Lock() + defer hc.mu.Unlock() + delete(hc.cache, hostname) +} + +type DialerBus interface { + Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) +} + +// FallbackDialer implements a custom net.Dialer with a fallback mechanism +type FallbackDialer struct { + cache *hostCache + + bus DialerBus + logger *zap.SugaredLogger + dialer net.Dialer +} + +func NewFallbackDialer(bus DialerBus, dialer net.Dialer, logger *zap.Logger) *FallbackDialer { + return &FallbackDialer{ + cache: newHostCache(), + + bus: bus, + logger: logger.Sugar().Named("fallbackdialer"), + dialer: dialer, + } +} + 
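Before the Dial method below, a short sketch of how the dialer is presumably wired up: the bus client satisfies DialerBus, and the worker hands the dialer a plain net.Dialer (busClient, logger, ctx and hostKey are illustrative names, not part of this change):

// construct the dialer once; it keeps its own hostname -> IP cache
fallback := NewFallbackDialer(busClient, net.Dialer{}, logger)

// dial a host; on failure this falls back to the cached IP and, failing
// that, to the resolved addresses stored on the bus
conn, err := fallback.Dial(ctx, hostKey, "host.example.com:9982")
if err != nil {
	return err
}
defer conn.Close()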
+func (d *FallbackDialer) Dial(ctx context.Context, hk types.PublicKey, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("failed to split host and port of host address '%v': %w", address, err) + } + logger := d.logger.With(zap.String("hostKey", hk.String()), zap.String("host", host)) + + // Dial and cache the resolved IP if dial successful + conn, err := d.dialer.DialContext(ctx, "tcp", address) + if err == nil { + // cache only the IP portion of the remote address; caching the + // full 'ip:port' string would break the JoinHostPort call below + if remoteHost, _, splitErr := net.SplitHostPort(conn.RemoteAddr().String()); splitErr == nil { + d.cache.Set(host, remoteHost) + } + return conn, nil + } + + // If resolution fails, check the cache + if cachedIP, ok := d.cache.Get(host); ok { + logger.Debug("Failed to resolve host, using cached IP", zap.Error(err)) + conn, err := d.dialer.DialContext(ctx, "tcp", net.JoinHostPort(cachedIP, port)) + if err == nil { + return conn, nil + } + // Delete the cache if the cached IP doesn't work + d.cache.Delete(host) + } + + // Attempt to resolve using the bus + logger.Debug("Cache not available or cached IP stale, retrieving host resolved addresses from bus") + hostInfo, err := d.bus.Host(ctx, hk) + if err != nil { + return nil, err + } + + for _, addr := range hostInfo.ResolvedAddresses { + conn, err := d.dialer.DialContext(ctx, "tcp", net.JoinHostPort(addr, port)) + if err == nil { + // Update cache on successful dial + d.cache.Set(host, addr) + return conn, nil + } + } + + return nil, fmt.Errorf("failed to dial %s with all methods", address) +} diff --git a/internal/worker/events.go b/internal/worker/events.go new file mode 100644 index 000000000..e0960fd5c --- /dev/null +++ b/internal/worker/events.go @@ -0,0 +1,195 @@ +package worker + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "go.sia.tech/renterd/alerts" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/webhooks" + "go.uber.org/zap" +) + +var ( + alertWebhookRegistrationFailedID = alerts.RandomAlertID() // constant until restarted +) + +type ( + EventSubscriber interface { + AddEventHandler(id string, h EventHandler) (chan struct{}, error) + ProcessEvent(event webhooks.Event) + Register(ctx context.Context, eventURL string, opts ...webhooks.HeaderOption) error + Shutdown(context.Context) error + } + + EventHandler interface { + HandleEvent(event webhooks.Event) error + Subscribe(e EventSubscriber) error + } + + WebhookManager interface { + RegisterWebhook(ctx context.Context, wh webhooks.Webhook) error + UnregisterWebhook(ctx context.Context, wh webhooks.Webhook) error + } +) + +type ( + eventSubscriber struct { + alerts alerts.Alerter + webhooks WebhookManager + logger *zap.SugaredLogger + + registerInterval time.Duration + + mu sync.Mutex + handlers map[string]EventHandler + registered []webhooks.Webhook + registeredChan chan struct{} + } +) + +func NewEventSubscriber(a alerts.Alerter, w WebhookManager, l *zap.Logger, registerInterval time.Duration) EventSubscriber { + return &eventSubscriber{ + alerts: a, + webhooks: w, + logger: l.Sugar().Named("events"), + + registeredChan: make(chan struct{}), + + handlers: make(map[string]EventHandler), + registerInterval: registerInterval, + } +} + +func (e *eventSubscriber) AddEventHandler(id string, h EventHandler) (chan struct{}, error) { + e.mu.Lock() + defer e.mu.Unlock() + _, ok := e.handlers[id] + if ok { + return nil, fmt.Errorf("subscriber with id %v already exists", id) + } + e.handlers[id] = h + + return e.registeredChan, nil +} + +func (e *eventSubscriber) ProcessEvent(event webhooks.Event) { + log := e.logger.With( + zap.String("module", event.Module), +
zap.String("event", event.Event), + ) + + for id, s := range e.handlers { + if err := s.HandleEvent(event); err != nil { + log.Errorw("failed to handle event", + zap.Error(err), + zap.String("subscriber", id), + ) + } else { + log.Debugw("handled event", + zap.String("subscriber", id), + ) + } + } +} + +func (e *eventSubscriber) Register(ctx context.Context, eventsURL string, opts ...webhooks.HeaderOption) error { + select { + case <-e.registeredChan: + return fmt.Errorf("already registered") // developer error + default: + } + + // prepare headers + headers := make(map[string]string) + for _, opt := range opts { + opt(headers) + } + + // prepare webhooks + webhooks := []webhooks.Webhook{ + api.WebhookConsensusUpdate(eventsURL, headers), + api.WebhookContractAdd(eventsURL, headers), + api.WebhookContractArchive(eventsURL, headers), + api.WebhookContractRenew(eventsURL, headers), + api.WebhookHostUpdate(eventsURL, headers), + api.WebhookSettingUpdate(eventsURL, headers), + } + + // try and register the webhooks in a loop + for { + err := e.registerWebhooks(ctx, webhooks) + if err == nil { + e.alerts.DismissAlerts(ctx, alertWebhookRegistrationFailedID) + break + } + + // alert on failure + e.alerts.RegisterAlert(ctx, newWebhookRegistrationFailedAlert(err)) + e.logger.Warnf("failed to register webhooks, retrying in %v", e.registerInterval) + + // sleep for a bit before trying again + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(e.registerInterval): + } + } + + return nil +} + +func (e *eventSubscriber) Shutdown(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + // unregister webhooks + var errs []error + for _, wh := range e.registered { + if err := e.webhooks.UnregisterWebhook(ctx, wh); err != nil { + e.logger.Errorw("failed to unregister webhook", + zap.Error(err), + zap.Stringer("webhook", wh), + ) + errs = append(errs, err) + } + } + + return errors.Join(errs...) 
+} + +func (e *eventSubscriber) registerWebhooks(ctx context.Context, webhooks []webhooks.Webhook) error { + for _, wh := range webhooks { + if err := e.webhooks.RegisterWebhook(ctx, wh); err != nil { + e.logger.Errorw("failed to register webhook", + zap.Error(err), + zap.Stringer("webhook", wh), + ) + return err + } + } + + // save webhooks so we can unregister them on shutdown + e.mu.Lock() + e.registered = webhooks + e.mu.Unlock() + + // signal that we're registered + close(e.registeredChan) + return nil +} + +func newWebhookRegistrationFailedAlert(err error) alerts.Alert { + return alerts.Alert{ + ID: alertWebhookRegistrationFailedID, + Severity: alerts.SeverityCritical, + Message: "Worker failed to register webhooks", + Data: map[string]any{ + "error": err.Error(), + }, + Timestamp: time.Now(), + } +} diff --git a/internal/worker/events_test.go b/internal/worker/events_test.go new file mode 100644 index 000000000..76c89bbfb --- /dev/null +++ b/internal/worker/events_test.go @@ -0,0 +1,219 @@ +package worker + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/jape" + "go.sia.tech/renterd/alerts" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/webhooks" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +const testRegisterInterval = 100 * time.Millisecond + +type mockAlerter struct{} + +func (a *mockAlerter) Alerts(ctx context.Context, opts alerts.AlertsOpts) (alerts.AlertsResponse, error) { + return alerts.AlertsResponse{}, nil +} +func (a *mockAlerter) RegisterAlert(ctx context.Context, alert alerts.Alert) error { return nil } +func (a *mockAlerter) DismissAlerts(ctx context.Context, ids ...types.Hash256) error { return nil } + +type mockEventHandler struct { + id string + readyChan chan struct{} + + mu sync.Mutex + events []webhooks.Event +} + +func (s *mockEventHandler) Events() []webhooks.Event { + s.mu.Lock() + defer s.mu.Unlock() + return s.events +} + +func (s *mockEventHandler) HandleEvent(event webhooks.Event) error { + s.mu.Lock() + defer s.mu.Unlock() + + select { + case <-s.readyChan: + default: + return fmt.Errorf("subscriber not ready") + } + + s.events = append(s.events, event) + return nil +} + +func (s *mockEventHandler) Subscribe(e EventSubscriber) error { + s.readyChan, _ = e.AddEventHandler(s.id, s) + return nil +} + +type mockWebhookManager struct { + blockChan chan struct{} + + mu sync.Mutex + registered []webhooks.Webhook +} + +func (m *mockWebhookManager) RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + <-m.blockChan + + m.mu.Lock() + defer m.mu.Unlock() + m.registered = append(m.registered, webhook) + return nil +} + +func (m *mockWebhookManager) UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + m.mu.Lock() + defer m.mu.Unlock() + + for i, wh := range m.registered { + if wh.String() == webhook.String() { + m.registered = append(m.registered[:i], m.registered[i+1:]...) 
+ return nil + } + } + return nil +} + +func (m *mockWebhookManager) Webhooks() []webhooks.Webhook { + m.mu.Lock() + defer m.mu.Unlock() + return m.registered +} + +func TestEventSubscriber(t *testing.T) { + // observe logs + observedZapCore, observedLogs := observer.New(zap.DebugLevel) + + // create mocks + a := &mockAlerter{} + w := &mockWebhookManager{blockChan: make(chan struct{})} + h := &mockEventHandler{id: t.Name()} + + // create event subscriber + s := NewEventSubscriber(a, w, zap.New(observedZapCore), testRegisterInterval) + + // subscribe the event handler + if err := h.Subscribe(s); err != nil { + t.Fatal(err) + } + + // setup a server + mux := jape.Mux(map[string]jape.Handler{"POST /events": func(jc jape.Context) { + var event webhooks.Event + if jc.Decode(&event) != nil { + return + } else if event.Event == webhooks.WebhookEventPing { + jc.ResponseWriter.WriteHeader(http.StatusOK) + return + } else { + s.ProcessEvent(event) + } + }}) + srv := httptest.NewServer(mux) + defer srv.Close() + + // register the subscriber + eventsURL := fmt.Sprintf("http://%v/events", srv.Listener.Addr().String()) + go func() { + if err := s.Register(context.Background(), eventsURL); err != nil { + t.Error(err) + } + }() + + // send an event before unblocking webhooks registration + err := sendEvent(eventsURL, webhooks.Event{Module: api.ModuleConsensus, Event: api.EventUpdate}) + if err != nil { + t.Fatal(err) + } + logs := observedLogs.TakeAll() + if len(logs) != 1 { + t.Fatal("expected 1 log, got", len(logs)) + } else if entry := logs[0]; entry.Message != "failed to handle event" || entry.ContextMap()["error"] != "subscriber not ready" { + t.Fatal("expected different log entry, got", entry) + } + + // unblock the webhooks registration + close(w.blockChan) + time.Sleep(testRegisterInterval) + + // assert webhook was registered + if webhooks := w.Webhooks(); len(webhooks) != 6 { + t.Fatal("expected 6 webhooks, got", len(webhooks)) + } + + // send the same event again + err = sendEvent(eventsURL, webhooks.Event{Module: api.ModuleConsensus, Event: api.EventUpdate}) + if err != nil { + t.Fatal(err) + } + logs = observedLogs.TakeAll() + if len(logs) != 1 { + t.Fatal("expected 1 log, got", len(logs)) + } else if entry := logs[0]; entry.Message != "handled event" || entry.ContextMap()["subscriber"] != t.Name() { + t.Fatal("expected different log entry, got", entry) + } + + // assert the subscriber handled the event + if events := h.Events(); len(events) != 1 { + t.Fatal("expected 1 event, got", len(events)) + } + + // shutdown event subscriber + err = s.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + + // assert webhook was unregistered + if webhooks := w.Webhooks(); len(webhooks) != 0 { + t.Fatal("expected 0 webhooks, got", len(webhooks)) + } +} + +func sendEvent(url string, event webhooks.Event) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + body, err := json.Marshal(event) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() // always close the response body + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + errStr, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + return fmt.Errorf("webhook returned unexpected status %v: %v",
resp.StatusCode, string(errStr)) + } + return nil +} diff --git a/object/slab.go b/object/slab.go index 770df9ef6..e52e7bd7b 100644 --- a/object/slab.go +++ b/object/slab.go @@ -53,15 +53,16 @@ func NewPartialSlab(ec EncryptionKey, minShards uint8) Slab { // ContractsFromShards is a helper to extract all contracts used by a set of // shards. -func ContractsFromShards(shards []Sector) map[types.PublicKey]map[types.FileContractID]struct{} { - usedContracts := make(map[types.PublicKey]map[types.FileContractID]struct{}) +func ContractsFromShards(shards []Sector) []types.FileContractID { + var usedContracts []types.FileContractID + usedMap := make(map[types.FileContractID]struct{}) for _, shard := range shards { - for h, fcids := range shard.Contracts { + for _, fcids := range shard.Contracts { for _, fcid := range fcids { - if _, exists := usedContracts[h]; !exists { - usedContracts[h] = make(map[types.FileContractID]struct{}) + if _, exists := usedMap[fcid]; !exists { + usedContracts = append(usedContracts, fcid) } - usedContracts[h][fcid] = struct{}{} + usedMap[fcid] = struct{}{} } } } diff --git a/stores/accounts.go b/stores/accounts.go index 523ba2697..183582b8b 100644 --- a/stores/accounts.go +++ b/stores/accounts.go @@ -9,7 +9,7 @@ import ( // Accounts returns all accounts from the db. func (s *SQLStore) Accounts(ctx context.Context) (accounts []api.Account, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { accounts, err = tx.Accounts(ctx) return err }) @@ -21,7 +21,7 @@ func (s *SQLStore) Accounts(ctx context.Context) (accounts []api.Account, err er // sync all accounts after an unclean shutdown and the bus will know not to // apply drift. func (s *SQLStore) SetUncleanShutdown(ctx context.Context) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.SetUncleanShutdown(ctx) }) } @@ -29,7 +29,7 @@ func (s *SQLStore) SetUncleanShutdown(ctx context.Context) error { // SaveAccounts saves the given accounts in the db, overwriting any existing // ones. 
func (s *SQLStore) SaveAccounts(ctx context.Context, accounts []api.Account) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.SaveAccounts(ctx, accounts) }) } diff --git a/stores/autopilot.go b/stores/autopilot.go index 45b899576..9c557c3d7 100644 --- a/stores/autopilot.go +++ b/stores/autopilot.go @@ -9,7 +9,7 @@ import ( ) func (s *SQLStore) Autopilots(ctx context.Context) (aps []api.Autopilot, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { aps, err = tx.Autopilots(ctx) return }) @@ -17,7 +17,7 @@ } func (s *SQLStore) Autopilot(ctx context.Context, id string) (ap api.Autopilot, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { ap, err = tx.Autopilot(ctx, id) return }) @@ -32,7 +32,7 @@ if err := ap.Config.Validate(); err != nil { return err } - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateAutopilot(ctx, ap) }) } diff --git a/stores/autopilot_test.go b/stores/autopilot_test.go index ef94d7d8f..1ed78370c 100644 --- a/stores/autopilot_test.go +++ b/stores/autopilot_test.go @@ -85,4 +85,10 @@ func TestAutopilotStore(t *testing.T) { if updated.Config.Contracts.Amount != 99 { t.Fatal("expected amount to be 99") } + + // update the autopilot with the same config and assert it does not fail + err = ss.UpdateAutopilot(context.Background(), updated) + if err != nil { + t.Fatal(err) + } } diff --git a/stores/chain.go b/stores/chain.go new file mode 100644 index 000000000..90b7de3e0 --- /dev/null +++ b/stores/chain.go @@ -0,0 +1,32 @@ +package stores + +import ( + "context" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/stores/sql" +) + +// ChainIndex returns the last stored chain index. +func (s *SQLStore) ChainIndex(ctx context.Context) (ci types.ChainIndex, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + ci, err = tx.Tip(ctx) + return err + }) + return +} + +// ProcessChainUpdate runs the given callback inside a database transaction, +// passing it a ChainUpdateTx to process a chain update atomically. +func (s *SQLStore) ProcessChainUpdate(ctx context.Context, applyFn func(sql.ChainUpdateTx) error) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.ProcessChainUpdate(ctx, applyFn) + }) +} + +// ResetChainState deletes all chain data in the database. +func (s *SQLStore) ResetChainState(ctx context.Context) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.ResetChainState(ctx) + }) +} diff --git a/stores/chain_test.go b/stores/chain_test.go new file mode 100644 index 000000000..e0bcc8480 --- /dev/null +++ b/stores/chain_test.go @@ -0,0 +1,211 @@ +package stores + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/stores/sql" +) + +// TestProcessChainUpdate tests the ProcessChainUpdate method on the SQL store.
+func TestProcessChainUpdate(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + + // add test host and contract + hks, err := ss.addTestHosts(1) + if err != nil { + t.Fatal(err) + } + fcids, _, err := ss.addTestContracts(hks) + if err != nil { + t.Fatal(err) + } else if len(fcids) != 1 { + t.Fatal("expected one contract", len(fcids)) + } + fcid := fcids[0] + + // assert contract state returns the correct state + var state api.ContractState + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) (err error) { + state, err = tx.ContractState(fcid) + return + }); err != nil { + t.Fatal("unexpected error", err) + } else if state != api.ContractStatePending { + t.Fatalf("unexpected state '%v'", state) + } + + // check current index + if curr, err := ss.ChainIndex(context.Background()); err != nil { + t.Fatal(err) + } else if curr.Height != 0 { + t.Fatalf("unexpected height %v", curr.Height) + } + + // assert update chain index is successful + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.UpdateChainIndex(types.ChainIndex{Height: 1}) + }); err != nil { + t.Fatal("unexpected error", err) + } + + // check updated index + if curr, err := ss.ChainIndex(context.Background()); err != nil { + t.Fatal(err) + } else if curr.Height != 1 { + t.Fatalf("unexpected height %v", curr.Height) + } + + // assert update contract is successful + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + if err := tx.UpdateContract(fcid, 1, 2, 3); err != nil { + return err + } else if err := tx.UpdateContractState(fcid, api.ContractStateActive); err != nil { + return err + } else if err := tx.UpdateContractProofHeight(fcid, 4); err != nil { + return err + } else { + return nil + } + }); err != nil { + t.Fatal("unexpected error", err) + } + + // assert contract was updated successfully + var we uint64 + if c, err := ss.Contract(context.Background(), fcid); err != nil { + t.Fatal("unexpected error", err) + } else if c.RevisionHeight != 1 { + t.Fatal("unexpected revision height", c.RevisionHeight) + } else if c.RevisionNumber != 2 { + t.Fatal("unexpected revision number", c.RevisionNumber) + } else if c.Size != 3 { + t.Fatal("unexpected size", c.Size) + } else if c.State != api.ContractStateActive { + t.Fatal("unexpected state", c.State) + } else { + we = c.WindowEnd + } + + // assert we only update revision height if the rev number doesn't increase + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.UpdateContract(fcid, 2, 2, 4) + }); err != nil { + t.Fatal("unexpected error", err) + } + if c, err := ss.Contract(context.Background(), fcid); err != nil { + t.Fatal("unexpected error", err) + } else if c.RevisionHeight != 2 { + t.Fatal("unexpected revision height", c.RevisionHeight) + } else if c.RevisionNumber != 2 { + t.Fatal("unexpected revision number", c.RevisionNumber) + } else if c.Size != 3 { + t.Fatal("unexpected size", c.Size) + } + + // assert update failed contracts is successful + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.UpdateFailedContracts(we + 1) + }); err != nil { + t.Fatal("unexpected error", err) + } + if c, err := ss.Contract(context.Background(), fcid); err != nil { + t.Fatal("unexpected error", err) + } else if c.State != api.ContractStateFailed { + t.Fatal("unexpected state", c.State) + } + + // renew the contract + _, err = 
ss.addTestRenewedContract(types.FileContractID{2}, fcid, hks[0], 1) + if err != nil { + t.Fatal(err) + } + + // assert we can fetch the state of the archived contract + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) (err error) { + state, err = tx.ContractState(fcid) + return + }); err != nil { + t.Fatal("unexpected error", err) + } else if state != api.ContractStateFailed { + t.Fatalf("unexpected state '%v'", state) + } + + // assert update host is successful + ts := time.Now().Truncate(time.Second).Add(-time.Minute).UTC() + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.UpdateHost(hks[0], chain.HostAnnouncement{NetAddress: "foo"}, 1, types.BlockID{}, ts) + }); err != nil { + t.Fatal("unexpected error", err) + } + if h, err := ss.Host(context.Background(), hks[0]); err != nil { + t.Fatal("unexpected error", err) + } else if h.NetAddress != "foo" { + t.Fatal("unexpected net address", h.NetAddress) + } else if !h.LastAnnouncement.Truncate(time.Second).Equal(ts) { + t.Fatalf("unexpected last announcement %v != %v", h.LastAnnouncement, ts) + } + + // record 2 scans for the host to give it some uptime + err = ss.RecordHostScans(context.Background(), []api.HostScan{ + {HostKey: hks[0], Success: true, Timestamp: time.Now()}, + {HostKey: hks[0], Success: true, Timestamp: time.Now().Add(time.Minute)}, + }) + if err != nil { + t.Fatal(err) + } else if h, err := ss.Host(context.Background(), hks[0]); err != nil { + t.Fatal(err) + } else if h.Interactions.Uptime < time.Minute || h.Interactions.Uptime > time.Minute+time.Second { + t.Fatalf("unexpected uptime %v", h.Interactions.Uptime) + } + + // reannounce the host and make sure the uptime is the same + ts = ts.Add(time.Minute) + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.UpdateHost(hks[0], chain.HostAnnouncement{NetAddress: "fooNew"}, 1, types.BlockID{}, ts) + }); err != nil { + t.Fatal("unexpected error", err) + } + if h, err := ss.Host(context.Background(), hks[0]); err != nil { + t.Fatal("unexpected error", err) + } else if h.Interactions.Uptime < time.Minute || h.Interactions.Uptime > time.Minute+time.Second { + t.Fatalf("unexpected uptime %v", h.Interactions.Uptime) + } else if h.NetAddress != "fooNew" { + t.Fatal("unexpected net address", h.NetAddress) + } else if !h.LastAnnouncement.Equal(ts) { + t.Fatalf("unexpected last announcement %v != %v", h.LastAnnouncement, ts) + } + + // assert passing empty function is successful + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { return nil }); err != nil { + t.Fatal("unexpected error", err) + } + + // assert we rollback on error + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + if err := tx.UpdateChainIndex(types.ChainIndex{Height: 2}); err != nil { + return err + } + return errors.New("some error") + }); err == nil || !strings.Contains(err.Error(), "some error") { + t.Fatal("unexpected error", err) + } + + // check chain index was rolled back + if curr, err := ss.ChainIndex(context.Background()); err != nil { + t.Fatal(err) + } else if curr.Height != 1 { + t.Fatalf("unexpected height %v", curr.Height) + } + + // assert the store is still usable after the rollback + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { return nil }); err != nil { + t.Fatal("unexpected error", err) + } +} diff --git a/stores/hostdb.go b/stores/hostdb.go index 9831db7a4..5111682d1 100644 ---
a/stores/hostdb.go +++ b/stores/hostdb.go @@ -2,308 +2,55 @@ package stores import ( "context" - dsql "database/sql" "errors" "fmt" - "net" - "strings" "time" "go.sia.tech/core/types" "go.sia.tech/renterd/api" - "go.sia.tech/renterd/hostdb" sql "go.sia.tech/renterd/stores/sql" - "go.sia.tech/siad/modules" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -const ( - // announcementBatchSoftLimit is the limit above which - // threadedProcessAnnouncements will stop merging batches of - // announcements and apply them to the db. - announcementBatchSoftLimit = 1000 - - // consensusInfoID defines the primary key of the entry in the consensusInfo - // table. - consensusInfoID = 1 ) var ( ErrNegativeMaxDowntime = errors.New("max downtime can not be negative") ) -type ( - // dbHost defines a api.Interaction as persisted in the DB. Deleting a - // host from the db will cascade the deletion and also delete the - // corresponding announcements and interactions with that host. - // - // NOTE: updating the host entity requires an update to the field map passed - // to 'Update' when recording host interactions - dbHost struct { - Model - - PublicKey publicKey `gorm:"unique;index;NOT NULL;size:32"` - Settings hostSettings - PriceTable hostPriceTable - PriceTableExpiry dsql.NullTime - - TotalScans uint64 - LastScan int64 `gorm:"index"` // unix nano - LastScanSuccess bool - SecondToLastScanSuccess bool - Scanned bool `gorm:"index"` - Uptime time.Duration - Downtime time.Duration - - // RecentDowntime and RecentScanFailures are used to determine whether a - // host is eligible for pruning. - RecentDowntime time.Duration `gorm:"index"` - RecentScanFailures uint64 `gorm:"index"` - - SuccessfulInteractions float64 - FailedInteractions float64 - - LostSectors uint64 - - LastAnnouncement time.Time - NetAddress string `gorm:"index"` - Subnets string - - Allowlist []dbAllowlistEntry `gorm:"many2many:host_allowlist_entry_hosts;constraint:OnDelete:CASCADE"` - Blocklist []dbBlocklistEntry `gorm:"many2many:host_blocklist_entry_hosts;constraint:OnDelete:CASCADE"` - Checks []dbHostCheck `gorm:"foreignKey:DBHostID;constraint:OnDelete:CASCADE"` - } - - // dbHostCheck contains information about a host that is collected and used - // by the autopilot. - dbHostCheck struct { - Model - - DBAutopilotID uint - - DBHostID uint - DBHost dbHost - - // usability - UsabilityBlocked bool - UsabilityOffline bool - UsabilityLowScore bool - UsabilityRedundantIP bool - UsabilityGouging bool - UsabilityNotAcceptingContracts bool - UsabilityNotAnnounced bool - UsabilityNotCompletingScan bool - - // score - ScoreAge float64 - ScoreCollateral float64 - ScoreInteractions float64 - ScoreStorageRemaining float64 - ScoreUptime float64 - ScoreVersion float64 - ScorePrices float64 - - // gouging - GougingContractErr string - GougingDownloadErr string - GougingGougingErr string - GougingPruneErr string - GougingUploadErr string - } - - // dbAllowlistEntry defines a table that stores the host blocklist. - dbAllowlistEntry struct { - Model - Entry publicKey `gorm:"unique;index;NOT NULL;size:32"` - Hosts []dbHost `gorm:"many2many:host_allowlist_entry_hosts;constraint:OnDelete:CASCADE"` - } - - // dbBlocklistEntry defines a table that stores the host blocklist. 
- dbBlocklistEntry struct { - Model - Entry string `gorm:"unique;index;NOT NULL"` - Hosts []dbHost `gorm:"many2many:host_blocklist_entry_hosts;constraint:OnDelete:CASCADE"` - } - - dbConsensusInfo struct { - Model - CCID []byte - Height uint64 - BlockID hash256 - } - - // dbAnnouncement is a table used for storing all announcements. It - // doesn't have any relations to dbHost which means it won't - // automatically prune when a host is deleted. - dbAnnouncement struct { - Model - HostKey publicKey `gorm:"NOT NULL"` - - BlockHeight uint64 - BlockID string - NetAddress string - } - - // announcement describes an announcement for a single host. - announcement struct { - hostKey publicKey - announcement hostdb.Announcement - } -) - -// TableName implements the gorm.Tabler interface. -func (dbAnnouncement) TableName() string { return "host_announcements" } - -// TableName implements the gorm.Tabler interface. -func (dbConsensusInfo) TableName() string { return "consensus_infos" } - -// TableName implements the gorm.Tabler interface. -func (dbHost) TableName() string { return "hosts" } - -// TableName implements the gorm.Tabler interface. -func (dbHostCheck) TableName() string { return "host_checks" } - -// TableName implements the gorm.Tabler interface. -func (dbAllowlistEntry) TableName() string { return "host_allowlist_entries" } - -// TableName implements the gorm.Tabler interface. -func (dbBlocklistEntry) TableName() string { return "host_blocklist_entries" } - -func (h *dbHost) BeforeCreate(tx *gorm.DB) (err error) { - tx.Statement.AddClause(clause.OnConflict{ - Columns: []clause.Column{{Name: "public_key"}}, - DoUpdates: clause.AssignmentColumns([]string{"last_announcement", "net_address"}), - }) - return nil -} - -func (e *dbAllowlistEntry) AfterCreate(tx *gorm.DB) error { - // NOTE: the ID is zero here if we ignore a conflict on create - if e.ID == 0 { - return nil - } - - params := map[string]interface{}{ - "entry_id": e.ID, - "exact_entry": publicKey(e.Entry), - } - - // insert entries into the allowlist - if isSQLite(tx) { - return tx.Exec(`INSERT OR IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) -SELECT @entry_id, id FROM ( -SELECT id -FROM hosts -WHERE public_key = @exact_entry -)`, params).Error - } - - return tx.Exec(`INSERT IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) -SELECT @entry_id, id FROM ( - SELECT id - FROM hosts - WHERE public_key=@exact_entry -) AS _`, params).Error -} - -func (e *dbAllowlistEntry) BeforeCreate(tx *gorm.DB) (err error) { - tx.Statement.AddClause(clause.OnConflict{ - Columns: []clause.Column{{Name: "entry"}}, - DoNothing: true, - }) - return nil -} - -func (e *dbBlocklistEntry) AfterCreate(tx *gorm.DB) error { - // NOTE: the ID is zero here if we ignore a conflict on create - if e.ID == 0 { - return nil - } - - params := map[string]interface{}{ - "entry_id": e.ID, - "exact_entry": e.Entry, - "like_entry": fmt.Sprintf("%%.%s", e.Entry), - } - - // insert entries into the blocklist - if isSQLite(tx) { - return tx.Exec(` -INSERT OR IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) -SELECT @entry_id, id FROM ( - SELECT id - FROM hosts - WHERE net_address == @exact_entry OR - rtrim(rtrim(net_address, replace(net_address, ':', '')),':') == @exact_entry OR - rtrim(rtrim(net_address, replace(net_address, ':', '')),':') LIKE @like_entry -)`, params).Error - } - - return tx.Exec(` -INSERT IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) -SELECT @entry_id, 
id FROM ( - SELECT id - FROM hosts - WHERE net_address=@exact_entry OR - SUBSTRING_INDEX(net_address,':',1)=@exact_entry OR - SUBSTRING_INDEX(net_address,':',1) LIKE @like_entry -) AS _`, params).Error -} - -func (e *dbBlocklistEntry) BeforeCreate(tx *gorm.DB) (err error) { - tx.Statement.AddClause(clause.OnConflict{ - Columns: []clause.Column{{Name: "entry"}}, - DoNothing: true, - }) - return nil -} - -func (e *dbBlocklistEntry) blocks(h dbHost) bool { - values := []string{h.NetAddress} - host, _, err := net.SplitHostPort(h.NetAddress) - if err == nil { - values = append(values, host) - } - - for _, value := range values { - if value == e.Entry || strings.HasSuffix(value, "."+e.Entry) { - return true - } - } - return false -} - // Host returns information about a host. -func (ss *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) { - hosts, err := ss.SearchHosts(ctx, "", api.HostFilterModeAll, api.UsabilityFilterModeAll, "", []types.PublicKey{hostKey}, 0, 1) +func (s *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) { + hosts, err := s.SearchHosts(ctx, "", api.HostFilterModeAll, api.UsabilityFilterModeAll, "", []types.PublicKey{hostKey}, 0, 1) if err != nil { return api.Host{}, err } else if len(hosts) == 0 { - return api.Host{}, api.ErrHostNotFound + return api.Host{}, fmt.Errorf("%w %v", api.ErrHostNotFound, hostKey) } else { return hosts[0], nil } } -func (ss *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) { - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateHostCheck(ctx, autopilotID, hk, hc) }) } // HostsForScanning returns the address of hosts for scanning. -func (ss *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) (hosts []api.HostAddress, err error) { - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) (hosts []api.HostAddress, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { hosts, err = tx.HostsForScanning(ctx, maxLastScan, offset, limit) return err }) return } -func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { +func (s *SQLStore) ResetLostSectors(ctx context.Context, hk types.PublicKey) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.ResetLostSectors(ctx, hk) + }) +} + +func (s *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { var hosts []api.Host - err := ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { hosts, err = tx.SearchHosts(ctx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) return }) @@ -311,16 +58,16 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us } // Hosts returns non-blocked hosts at given offset and limit. 
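As an aside, the removed hooks above encoded the blocklist matching rule twice, once in Go (blocks) and once in SQL; condensed, the rule they implemented looks like the sketch below, which the new sql-package implementation is presumably expected to preserve: an entry matches the full net address, the bare host, or any parent domain of the host.

	func blocked(entry, netAddress string) bool {
		values := []string{netAddress}
		if host, _, err := net.SplitHostPort(netAddress); err == nil {
			values = append(values, host)
		}
		for _, v := range values {
			// exact match on the address or host, or the entry is a parent domain
			if v == entry || strings.HasSuffix(v, "."+entry) {
				return true
			}
		}
		return false
	}

	// blocked("foo.bar", "node.foo.bar:9982") == true
	// blocked("foo.bar", "barfoo.bar:9982") == false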
-func (ss *SQLStore) Hosts(ctx context.Context, offset, limit int) ([]api.Host, error) { - return ss.SearchHosts(ctx, "", api.HostFilterModeAllowed, api.UsabilityFilterModeAll, "", nil, offset, limit) +func (s *SQLStore) Hosts(ctx context.Context, offset, limit int) ([]api.Host, error) { + return s.SearchHosts(ctx, "", api.HostFilterModeAllowed, api.UsabilityFilterModeAll, "", nil, offset, limit) } -func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures uint64, maxDowntime time.Duration) (removed uint64, err error) { +func (s *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures uint64, maxDowntime time.Duration) (removed uint64, err error) { // sanity check 'maxDowntime' if maxDowntime < 0 { return 0, ErrNegativeMaxDowntime } - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { n, err := tx.RemoveOfflineHosts(ctx, minRecentFailures, maxDowntime) removed = uint64(n) return err @@ -328,191 +75,50 @@ func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures ui return } -func (ss *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) (err error) { +func (s *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) (err error) { // nothing to do if len(add)+len(remove) == 0 && !clear { return nil } - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateHostAllowlistEntries(ctx, add, remove, clear) }) } -func (ss *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) (err error) { +func (s *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) (err error) { // nothing to do if len(add)+len(remove) == 0 && !clear { return nil } - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateHostBlocklistEntries(ctx, add, remove, clear) }) } -func (ss *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) { - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { allowlist, err = tx.HostAllowlist(ctx) return err }) return } -func (ss *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) { - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { blocklist, err = tx.HostBlocklist(ctx) return err }) return } -func (ss *SQLStore) RecordHostScans(ctx context.Context, scans []api.HostScan) error { - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) RecordHostScans(ctx context.Context, scans []api.HostScan) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.RecordHostScans(ctx, scans) }) } -func (ss *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []api.HostPriceTableUpdate) error { - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []api.HostPriceTableUpdate) error { + return 
s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.RecordPriceTables(ctx, priceTableUpdate) }) } - -func (ss *SQLStore) processConsensusChangeHostDB(cc modules.ConsensusChange) { - height := uint64(cc.InitialHeight()) - for range cc.RevertedBlocks { - height-- - } - - var newAnnouncements []announcement - for _, sb := range cc.AppliedBlocks { - var b types.Block - convertToCore(sb, (*types.V1Block)(&b)) - - // Process announcements, but only if they are not too old. - if b.Timestamp.After(time.Now().Add(-ss.announcementMaxAge)) { - hostdb.ForEachAnnouncement(types.Block(b), height, func(hostKey types.PublicKey, ha hostdb.Announcement) { - newAnnouncements = append(newAnnouncements, announcement{ - hostKey: publicKey(hostKey), - announcement: ha, - }) - ss.unappliedHostKeys[hostKey] = struct{}{} - }) - } - height++ - } - - ss.unappliedAnnouncements = append(ss.unappliedAnnouncements, newAnnouncements...) -} - -func updateCCID(tx *gorm.DB, newCCID modules.ConsensusChangeID, newTip types.ChainIndex) error { - return tx.Model(&dbConsensusInfo{}).Where(&dbConsensusInfo{ - Model: Model{ - ID: consensusInfoID, - }, - }).Updates(map[string]interface{}{ - "CCID": newCCID[:], - "height": newTip.Height, - "block_id": hash256(newTip.ID), - }).Error -} - -func insertAnnouncements(tx *gorm.DB, as []announcement) error { - var hosts []dbHost - var announcements []dbAnnouncement - for _, a := range as { - hosts = append(hosts, dbHost{ - PublicKey: a.hostKey, - LastAnnouncement: a.announcement.Timestamp.UTC(), - NetAddress: a.announcement.NetAddress, - }) - announcements = append(announcements, dbAnnouncement{ - HostKey: a.hostKey, - BlockHeight: a.announcement.Index.Height, - BlockID: a.announcement.Index.ID.String(), - NetAddress: a.announcement.NetAddress, - }) - } - if err := tx.Create(&announcements).Error; err != nil { - return err - } - return tx.Create(&hosts).Error -} - -func applyRevisionUpdate(db *gorm.DB, fcid types.FileContractID, rev revisionUpdate) error { - return updateActiveAndArchivedContract(db, fcid, map[string]interface{}{ - "revision_height": rev.height, - "revision_number": fmt.Sprint(rev.number), - "size": rev.size, - }) -} - -func updateContractState(db *gorm.DB, fcid types.FileContractID, cs contractState) error { - return updateActiveAndArchivedContract(db, fcid, map[string]interface{}{ - "state": cs, - }) -} - -func markFailedContracts(db *gorm.DB, height uint64) error { - if err := db.Model(&dbContract{}). - Where("state = ? AND ? > window_end", contractStateActive, height). - Update("state", contractStateFailed).Error; err != nil { - return fmt.Errorf("failed to mark failed contracts: %w", err) - } - return nil -} - -func updateProofHeight(db *gorm.DB, fcid types.FileContractID, blockHeight uint64) error { - return updateActiveAndArchivedContract(db, fcid, map[string]interface{}{ - "proof_height": blockHeight, - }) -} - -func updateActiveAndArchivedContract(tx *gorm.DB, fcid types.FileContractID, updates map[string]interface{}) error { - err1 := tx.Model(&dbContract{}). - Where("fcid = ?", fileContractID(fcid)). - Updates(updates).Error - err2 := tx.Model(&dbArchivedContract{}). - Where("fcid = ?", fileContractID(fcid)). - Updates(updates).Error - if err1 != nil || err2 != nil { - return fmt.Errorf("%s; %s", err1, err2) - } - return nil -} - -func updateBlocklist(tx *gorm.DB, hk types.PublicKey, allowlist []dbAllowlistEntry, blocklist []dbBlocklistEntry) error { - // fetch the host - var host dbHost - if err := tx. - Model(&dbHost{}). 
- Where("public_key = ?", publicKey(hk)). - First(&host). - Error; err != nil { - return err - } - - // update host allowlist - var dbAllowlist []dbAllowlistEntry - for _, entry := range allowlist { - if entry.Entry == host.PublicKey { - dbAllowlist = append(dbAllowlist, entry) - } - } - if err := tx.Model(&host).Association("Allowlist").Replace(&dbAllowlist); err != nil { - return err - } - - // update host blocklist - var dbBlocklist []dbBlocklistEntry - for _, entry := range blocklist { - if entry.blocks(host) { - dbBlocklist = append(dbBlocklist, entry) - } - } - return tx.Model(&host).Association("Blocklist").Replace(&dbBlocklist) -} - -func (s *SQLStore) ResetLostSectors(ctx context.Context, hk types.PublicKey) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { - return tx.ResetLostSectors(ctx, hk) - }) -} diff --git a/stores/hostdb_test.go b/stores/hostdb_test.go index d6195d9c9..b4bba0dc1 100644 --- a/stores/hostdb_test.go +++ b/stores/hostdb_test.go @@ -1,7 +1,6 @@ package stores import ( - "bytes" "context" "errors" "fmt" @@ -10,36 +9,19 @@ import ( "time" "github.com/google/go-cmp/cmp" - "gitlab.com/NebulousLabs/encoding" rhpv2 "go.sia.tech/core/rhp/v2" rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" "go.sia.tech/renterd/api" - "go.sia.tech/renterd/hostdb" sql "go.sia.tech/renterd/stores/sql" - "go.sia.tech/siad/crypto" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" - "gorm.io/gorm" ) -func (s *SQLStore) insertTestAnnouncement(hk types.PublicKey, a hostdb.Announcement) error { - return insertAnnouncements(s.db, []announcement{ - { - hostKey: publicKey(hk), - announcement: a, - }, - }) -} - // TestSQLHostDB tests the basic functionality of SQLHostDB using an in-memory // SQLite DB. func TestSQLHostDB(t *testing.T) { ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() - if ss.ccid != modules.ConsensusChangeBeginning { - t.Fatal("wrong ccid", ss.ccid, modules.ConsensusChangeBeginning) - } // Try to fetch a random host. Should fail. ctx := context.Background() @@ -66,25 +48,22 @@ func TestSQLHostDB(t *testing.T) { // Insert an announcement for the host and another one for an unknown // host. - ann := newTestHostDBAnnouncement("address") - err = ss.insertTestAnnouncement(hk, ann) - if err != nil { + if err := ss.announceHost(hk, "address"); err != nil { t.Fatal(err) } - // Read the host and verify that the announcement related fields were - // set. - var h dbHost - tx := ss.db.Where("last_announcement = ? AND net_address = ?", ann.Timestamp, ann.NetAddress).Find(&h) - if tx.Error != nil { - t.Fatal(tx.Error) - } - if types.PublicKey(h.PublicKey) != hk { - t.Fatal("wrong host returned") + // Fetch the host + h, err := ss.Host(ctx, hk) + if err != nil { + t.Fatal(err) + } else if h.NetAddress != "address" { + t.Fatalf("unexpected address: %v", h.NetAddress) + } else if h.LastAnnouncement.IsZero() { + t.Fatal("last announcement not set") } // Same thing again but with hosts. - hosts, err := ss.hosts() + hosts, err := ss.Hosts(ctx, 0, -1) if err != nil { t.Fatal(err) } @@ -111,42 +90,20 @@ func TestSQLHostDB(t *testing.T) { } // Insert another announcement for an unknown host. 
- unknownKey := types.PublicKey{1, 4, 7} - err = ss.insertTestAnnouncement(unknownKey, ann) - if err != nil { + randomHK := types.PublicKey{1, 4, 7} + if err := ss.announceHost(types.PublicKey{1, 4, 7}, "na"); err != nil { t.Fatal(err) } - h3, err := ss.Host(ctx, unknownKey) + h3, err := ss.Host(ctx, randomHK) if err != nil { t.Fatal(err) } - if h3.NetAddress != ann.NetAddress { + if h3.NetAddress != "na" { t.Fatal("wrong net address") } if h3.KnownSince.IsZero() { t.Fatal("known since not set") } - - // Apply a consensus change. - ccid2 := modules.ConsensusChangeID{1, 2, 3} - ss.ProcessConsensusChange(modules.ConsensusChange{ - ID: ccid2, - AppliedBlocks: []stypes.Block{{}}, - AppliedDiffs: []modules.ConsensusChangeDiffs{{}}, - }) - if err := ss.applyUpdates(true); err != nil { - t.Fatal(err) - } - - // Connect to the same DB again. - hdb2 := ss.Reopen() - if hdb2.ccid != ccid2 { - t.Fatal("ccid wasn't updated", hdb2.ccid, ccid2) - } - _, err = hdb2.Host(ctx, hk) - if err != nil { - t.Fatal(err) - } } func (s *SQLStore) addTestScan(hk types.PublicKey, t time.Time, err error, settings rhpv2.HostSettings) error { @@ -359,11 +316,11 @@ func TestSearchHosts(t *testing.T) { } // assert there are currently 3 checks - var cnt int64 - err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error - if err != nil { - t.Fatal(err) - } else if cnt != 3 { + checkCount := func() int64 { + t.Helper() + return ss.Count("host_checks") + } + if cnt := checkCount(); cnt != 3 { t.Fatal("unexpected", cnt) } @@ -436,26 +393,20 @@ func TestSearchHosts(t *testing.T) { } // assert cascade delete on host - err = ss.db.Exec("DELETE FROM hosts WHERE public_key = ?", publicKey(types.PublicKey{1})).Error + _, err = ss.DB().Exec(context.Background(), "DELETE FROM hosts WHERE public_key = ?", sql.PublicKey(types.PublicKey{1})) if err != nil { t.Fatal(err) } - err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error - if err != nil { - t.Fatal(err) - } else if cnt != 2 { + if cnt := checkCount(); cnt != 2 { t.Fatal("unexpected", cnt) } // assert cascade delete on autopilot - err = ss.db.Exec("DELETE FROM autopilots WHERE identifier IN (?,?)", ap1, ap2).Error + _, err = ss.DB().Exec(context.Background(), "DELETE FROM autopilots WHERE identifier IN (?,?)", ap1, ap2) if err != nil { t.Fatal(err) } - err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error - if err != nil { - t.Fatal(err) - } else if cnt != 0 { + if cnt := checkCount(); cnt != 0 { t.Fatal("unexpected", cnt) } } @@ -485,28 +436,30 @@ func TestRecordScan(t *testing.T) { t.Fatal("mismatch") } - // The host shouldn't have any subnets. - if len(host.Subnets) != 0 { + // The host shouldn't have any addresses. + if len(host.ResolvedAddresses) != 0 { + t.Fatal("unexpected", host.ResolvedAddresses, len(host.ResolvedAddresses)) + } else if len(host.Subnets) != 0 { t.Fatal("unexpected", host.Subnets, len(host.Subnets)) } // Fetch the host directly to get the creation time. - h, err := hostByPubKey(ss.db, hk) + h, err := ss.Host(ctx, hk) if err != nil { t.Fatal(err) - } - if h.CreatedAt.IsZero() { - t.Fatal("creation time not set") + } else if h.KnownSince.IsZero() { + t.Fatal("known since not set") } // Record a scan. 
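// (The scan below persists the host's settings, price table, resolved
// addresses and subnets; the assertions that follow check these fields
// along with the uptime and downtime bookkeeping.)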
firstScanTime := time.Now().UTC() + resolvedAddresses := []string{"212.1.96.0", "38.135.51.0"} subnets := []string{"212.1.96.0/24", "38.135.51.0/24"} settings := rhpv2.HostSettings{NetAddress: "host.com"} pt := rhpv3.HostPriceTable{ HostBlockHeight: 123, } - if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, firstScanTime, settings, pt, true, subnets)}); err != nil { + if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, firstScanTime, settings, pt, true, resolvedAddresses, subnets)}); err != nil { t.Fatal(err) } host, err = ss.Host(ctx, hk) @@ -524,9 +477,11 @@ func TestRecordScan(t *testing.T) { t.Fatal(err) } - // The host should have the subnets. - if !reflect.DeepEqual(host.Subnets, subnets) { - t.Fatal("mismatch") + // The host should have the addresses. + if !reflect.DeepEqual(host.ResolvedAddresses, resolvedAddresses) { + t.Fatal("resolved addresses mismatch") + } else if !reflect.DeepEqual(host.Subnets, subnets) { + t.Fatal("subnets mismatch") } // We expect no uptime or downtime from only a single scan. @@ -556,7 +511,7 @@ func TestRecordScan(t *testing.T) { // subnets this time. secondScanTime := firstScanTime.Add(time.Hour) pt.HostBlockHeight = 456 - if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, secondScanTime, settings, pt, true, nil)}); err != nil { + if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, secondScanTime, settings, pt, true, nil, nil)}); err != nil { t.Fatal(err) } host, err = ss.Host(ctx, hk) @@ -585,13 +540,15 @@ func TestRecordScan(t *testing.T) { } // The host should still have the subnets. - if !reflect.DeepEqual(host.Subnets, subnets) { - t.Fatal("mismatch") + if !reflect.DeepEqual(host.ResolvedAddresses, resolvedAddresses) { + t.Fatal("resolved addresses mismatch") + } else if !reflect.DeepEqual(host.Subnets, subnets) { + t.Fatal("subnets mismatch") } // Record another scan 2 hours after the second one. This time it fails. 
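// (Inferred from the assertions in this test: the store credits the gap
// between two consecutive scans to uptime when the newer scan succeeds,
// hence the hour of uptime accrued above, and to downtime when it fails,
// hence the two hours of downtime expected below.)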
thirdScanTime := secondScanTime.Add(2 * time.Hour) - if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, thirdScanTime, settings, pt, false, nil)}); err != nil { + if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, thirdScanTime, settings, pt, false, nil, nil)}); err != nil { t.Fatal(err) } host, err = ss.Host(ctx, hk) @@ -628,12 +585,11 @@ func TestRemoveHosts(t *testing.T) { t.Fatal(err) } - // fetch the host and assert the recent downtime is zero - h, err := hostByPubKey(ss.db, hk) + // fetch the host and assert the downtime is zero + h, err := ss.Host(context.Background(), hk) if err != nil { t.Fatal(err) - } - if h.RecentDowntime != 0 { + } else if h.Interactions.Downtime != 0 { t.Fatal("downtime is not zero") } @@ -650,24 +606,24 @@ pt := rhpv3.HostPriceTable{} t1 := now.Add(-time.Minute * 120) // 2 hours ago t2 := now.Add(-time.Minute * 90) // 1.5 hours ago (30min downtime) - hi1 := newTestScan(hk, t1, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false, nil) - hi2 := newTestScan(hk, t2, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false, nil) + hi1 := newTestScan(hk, t1, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false, nil, nil) + hi2 := newTestScan(hk, t2, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false, nil, nil) // record interactions if err := ss.RecordHostScans(context.Background(), []api.HostScan{hi1, hi2}); err != nil { t.Fatal(err) } - // fetch the host and assert the recent downtime is 30 minutes and he has 2 recent scan failures - h, err = hostByPubKey(ss.db, hk) + // fetch the host and assert the downtime is 30 minutes and it has 2 recent scan failures + h, err = ss.Host(context.Background(), hk) if err != nil { t.Fatal(err) } - if h.RecentDowntime.Minutes() != 30 { - t.Fatal("downtime is not 30 minutes", h.RecentDowntime.Minutes()) + if h.Interactions.Downtime.Minutes() != 30 { + t.Fatal("downtime is not 30 minutes", h.Interactions.Downtime.Minutes()) } - if h.RecentScanFailures != 2 { - t.Fatal("recent scan failures is not 2", h.RecentScanFailures) + if h.Interactions.FailedInteractions != 2 { + t.Fatal("recent scan failures is not 2", h.Interactions.FailedInteractions) } // assert no hosts are removed @@ -681,7 +637,7 @@ // record interactions t3 := now.Add(-time.Minute * 60) // 1 hour ago (60min downtime) - hi3 := newTestScan(hk, t3, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false, nil) + hi3 := newTestScan(hk, t3, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false, nil, nil) if err := ss.RecordHostScans(context.Background(), []api.HostScan{hi3}); err != nil { t.Fatal(err) } @@ -714,91 +670,11 @@ } // assert host is removed from the database - if _, err = hostByPubKey(ss.db, hk); err != gorm.ErrRecordNotFound { + if _, err = ss.Host(context.Background(), hk); !errors.Is(err, api.ErrHostNotFound) { t.Fatal("expected record not found error") } } -// TestInsertAnnouncements is a test for insertAnnouncements. -func TestInsertAnnouncements(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - // Create announcements for 3 hosts.
- ann1 := announcement{ - hostKey: publicKey(types.GeneratePrivateKey().PublicKey()), - announcement: newTestHostDBAnnouncement("foo.bar:1000"), - } - ann2 := announcement{ - hostKey: publicKey(types.GeneratePrivateKey().PublicKey()), - announcement: newTestHostDBAnnouncement("bar.baz:1000"), - } - ann3 := announcement{ - hostKey: publicKey(types.GeneratePrivateKey().PublicKey()), - announcement: newTestHostDBAnnouncement("quz.qux:1000"), - } - - // Insert the first one and check that all fields are set. - if err := insertAnnouncements(ss.db, []announcement{ann1}); err != nil { - t.Fatal(err) - } - var ann dbAnnouncement - if err := ss.db.Find(&ann).Error; err != nil { - t.Fatal(err) - } - ann.Model = Model{} // ignore - expectedAnn := dbAnnouncement{ - HostKey: ann1.hostKey, - BlockHeight: 1, - BlockID: types.BlockID{1}.String(), - NetAddress: "foo.bar:1000", - } - if ann != expectedAnn { - t.Fatal("mismatch") - } - // Insert the first and second one. - if err := insertAnnouncements(ss.db, []announcement{ann1, ann2}); err != nil { - t.Fatal(err) - } - - // Insert the first one twice. The second one again and the third one. - if err := insertAnnouncements(ss.db, []announcement{ann1, ann2, ann1, ann3}); err != nil { - t.Fatal(err) - } - - // There should be 3 hosts in the db. - hosts, err := ss.hosts() - if err != nil { - t.Fatal(err) - } - if len(hosts) != 3 { - t.Fatal("invalid number of hosts") - } - - // There should be 7 announcements total. - var announcements []dbAnnouncement - if err := ss.db.Find(&announcements).Error; err != nil { - t.Fatal(err) - } - if len(announcements) != 7 { - t.Fatal("invalid number of announcements") - } - - // Add an entry to the blocklist to block host 1 - entry1 := "foo.bar" - err = ss.UpdateHostBlocklistEntries(context.Background(), []string{entry1}, nil, false) - if err != nil { - t.Fatal(err) - } - - // Insert multiple announcements for host 1 - this asserts that the UNIQUE - // constraint on the blocklist table isn't triggered when inserting multiple - // announcements for a host that's on the blocklist - if err := insertAnnouncements(ss.db, []announcement{ann1, ann1}); err != nil { - t.Fatal(err) - } -} - func TestSQLHostAllowlist(t *testing.T) { ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() @@ -825,11 +701,7 @@ func TestSQLHostAllowlist(t *testing.T) { numRelations := func() (cnt int64) { t.Helper() - err := ss.db.Table("host_allowlist_entry_hosts").Count(&cnt).Error - if err != nil { - t.Fatal(err) - } - return + return ss.Count("host_allowlist_entry_hosts") } isAllowed := func(hk types.PublicKey) bool { @@ -931,7 +803,7 @@ func TestSQLHostAllowlist(t *testing.T) { } // remove host 1 - if err = ss.db.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk1)}).Delete(&dbHost{}).Error; err != nil { + if err := ss.DeleteHost(hk1); err != nil { t.Fatal(err) } if numHosts() != 0 { @@ -997,20 +869,12 @@ func TestSQLHostBlocklist(t *testing.T) { numAllowlistRelations := func() (cnt int64) { t.Helper() - err := ss.db.Table("host_allowlist_entry_hosts").Count(&cnt).Error - if err != nil { - t.Fatal(err) - } - return + return ss.Count("host_allowlist_entry_hosts") } numBlocklistRelations := func() (cnt int64) { t.Helper() - err := ss.db.Table("host_blocklist_entry_hosts").Count(&cnt).Error - if err != nil { - t.Fatal(err) - } - return + return ss.Count("host_blocklist_entry_hosts") } isBlocked := func(hk types.PublicKey) bool { @@ -1115,7 +979,7 @@ func TestSQLHostBlocklist(t *testing.T) { } // delete host 2 and assert the delete 
cascaded properly - if err = ss.db.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk2)}).Delete(&dbHost{}).Error; err != nil { + if err = ss.DeleteHost(hk2); err != nil { t.Fatal(err) } if numHosts() != 2 { @@ -1252,134 +1116,19 @@ func TestSQLHostBlocklistBasic(t *testing.T) { } } -// TestAnnouncementMaxAge verifies old announcements are ignored. -func TestAnnouncementMaxAge(t *testing.T) { - db := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer db.Close() - - if len(db.unappliedAnnouncements) != 0 { - t.Fatal("expected 0 announcements") - } - - db.processConsensusChangeHostDB( - modules.ConsensusChange{ - ID: modules.ConsensusChangeID{1}, - BlockHeight: 1, - AppliedBlocks: []stypes.Block{ - { - Timestamp: stypes.Timestamp(time.Now().Add(-time.Hour).Add(-time.Minute).Unix()), - Transactions: []stypes.Transaction{newTestTransaction(newTestHostAnnouncement("foo.com:1000"))}, - }, - { - Timestamp: stypes.Timestamp(time.Now().Add(-time.Hour).Add(time.Minute).Unix()), - Transactions: []stypes.Transaction{newTestTransaction(newTestHostAnnouncement("foo.com:1001"))}, - }, - }, - }, - ) - - if len(db.unappliedAnnouncements) != 1 { - t.Fatal("expected 1 announcement") - } else if db.unappliedAnnouncements[0].announcement.NetAddress != "foo.com:1001" { - t.Fatal("unexpected announcement") - } -} - -// addTestHosts adds 'n' hosts to the db and returns their keys. -func (s *SQLStore) addTestHosts(n int) (keys []types.PublicKey, err error) { - cnt, err := s.contractsCount() - if err != nil { - return nil, err - } - - for i := 0; i < n; i++ { - keys = append(keys, types.PublicKey{byte(int(cnt) + i + 1)}) - if err := s.addTestHost(keys[len(keys)-1]); err != nil { - return nil, err - } - } - return -} - -// addTestHost ensures a host with given hostkey exists. -func (s *SQLStore) addTestHost(hk types.PublicKey) error { - return s.addCustomTestHost(hk, "") -} - -// addCustomTestHost ensures a host with given hostkey and net address exists. -func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error { - s.unappliedHostKeys[hk] = struct{}{} - s.unappliedAnnouncements = append(s.unappliedAnnouncements, []announcement{{ - hostKey: publicKey(hk), - announcement: newTestHostDBAnnouncement(na), - }}...) - s.lastSave = time.Now().Add(s.persistInterval * -2) - return s.applyUpdates(false) -} - -// hosts returns all hosts in the db. Only used in testing since preloading all -// interactions for all hosts is expensive in production. -func (db *SQLStore) hosts() ([]dbHost, error) { - var hosts []dbHost - tx := db.db.Find(&hosts) - if tx.Error != nil { - return nil, tx.Error - } - return hosts, nil -} - -func hostByPubKey(tx *gorm.DB, hostKey types.PublicKey) (dbHost, error) { - var h dbHost - err := tx.Where("public_key", publicKey(hostKey)). - Take(&h).Error - return h, err -} - // newTestScan returns a host interaction with given parameters. 
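The helper now takes the resolved addresses alongside the subnets. For reference, the successful scan from TestRecordScan above is built like this (values copied from that test):

	scan := newTestScan(hk, firstScanTime, settings, pt, true,
		[]string{"212.1.96.0", "38.135.51.0"},       // resolved addresses
		[]string{"212.1.96.0/24", "38.135.51.0/24"}) // subnets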
-func newTestScan(hk types.PublicKey, scanTime time.Time, settings rhpv2.HostSettings, pt rhpv3.HostPriceTable, success bool, subnets []string) api.HostScan { +func newTestScan(hk types.PublicKey, scanTime time.Time, settings rhpv2.HostSettings, pt rhpv3.HostPriceTable, success bool, resolvedAddresses, subnets []string) api.HostScan { return api.HostScan{ - HostKey: hk, - PriceTable: pt, - Settings: settings, - Subnets: subnets, - Success: success, - Timestamp: scanTime, - } -} - -func newTestPK() (stypes.SiaPublicKey, types.PrivateKey) { - sk := types.GeneratePrivateKey() - pk := sk.PublicKey() - return stypes.SiaPublicKey{ - Algorithm: stypes.SignatureEd25519, - Key: pk[:], - }, sk -} - -func newTestHostAnnouncement(na modules.NetAddress) (modules.HostAnnouncement, types.PrivateKey) { - spk, sk := newTestPK() - return modules.HostAnnouncement{ - Specifier: modules.PrefixHostAnnouncement, - NetAddress: na, - PublicKey: spk, - }, sk -} - -func newTestHostDBAnnouncement(addr string) hostdb.Announcement { - return hostdb.Announcement{ - Index: types.ChainIndex{Height: 1, ID: types.BlockID{1}}, - Timestamp: time.Now().UTC().Round(time.Second), - NetAddress: addr, + HostKey: hk, + PriceTable: pt, + Settings: settings, + ResolvedAddresses: resolvedAddresses, + Subnets: subnets, + Success: success, + Timestamp: scanTime, } } -func newTestTransaction(ha modules.HostAnnouncement, sk types.PrivateKey) stypes.Transaction { - var buf bytes.Buffer - buf.Write(encoding.Marshal(ha)) - buf.Write(encoding.Marshal(sk.SignHash(types.Hash256(crypto.HashObject(ha))))) - return stypes.Transaction{ArbitraryData: [][]byte{buf.Bytes()}} -} - func newTestHostCheck() api.HostCheck { return api.HostCheck{ @@ -1411,3 +1160,45 @@ func newTestHostCheck() api.HostCheck { }, } } + +// addCustomTestHost ensures a host with given hostkey and net address exists. +func (s *testSQLStore) addCustomTestHost(hk types.PublicKey, na string) error { + if err := s.announceHost(hk, na); err != nil { + return err + } + return nil +} + +// addTestHost ensures a host with given hostkey exists. +func (s *testSQLStore) addTestHost(hk types.PublicKey) error { + return s.addCustomTestHost(hk, "") +} + +// addTestHosts adds 'n' hosts to the db and returns their keys. +func (s *testSQLStore) addTestHosts(n int) (keys []types.PublicKey, err error) { + cnt := s.Count("contracts") + + for i := 0; i < n; i++ { + keys = append(keys, types.PublicKey{byte(int(cnt) + i + 1)}) + if err := s.addTestHost(keys[len(keys)-1]); err != nil { + return nil, err + } + } + return +} + +// announceHost adds a host announcement to the database. 
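Together, these helpers replace the removed gorm-based test plumbing: tests now create hosts by feeding an announcement through the same ProcessChainUpdate/UpdateHost path that real chain updates take, e.g. (placeholder key and address):

	if err := ss.addCustomTestHost(types.PublicKey{1}, "foo.com:1000"); err != nil {
		t.Fatal(err)
	}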
+func (s *testSQLStore) announceHost(hk types.PublicKey, na string) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return tx.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.UpdateHost(hk, chain.HostAnnouncement{ + NetAddress: na, + }, 42, types.BlockID{1, 2, 3}, time.Now().UTC().Round(time.Second)) + }) + }) +} + +func (s *testSQLStore) DeleteHost(hk types.PublicKey) error { + _, err := s.DB().Exec(context.Background(), "DELETE FROM hosts WHERE public_key = ?", sql.PublicKey(hk)) + return err +} diff --git a/stores/metadata.go b/stores/metadata.go index 56b92d1c4..fe26bc29e 100644 --- a/stores/metadata.go +++ b/stores/metadata.go @@ -7,7 +7,6 @@ import ( "math" "strings" "time" - "unicode/utf8" rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" @@ -15,10 +14,7 @@ import ( "go.sia.tech/renterd/api" "go.sia.tech/renterd/object" sql "go.sia.tech/renterd/stores/sql" - "go.sia.tech/siad/modules" "go.uber.org/zap" - "gorm.io/gorm" - "gorm.io/gorm/clause" "lukechampine.com/frand" ) @@ -33,10 +29,6 @@ const ( // 10/30 erasure coding and takes <1s to execute on an SSD in SQLite. refreshHealthBatchSize = 10000 - // sectorInsertionBatchSize is the number of sectors per batch when we - // upsert sectors. - sectorInsertionBatchSize = 500 - // slabPruningBatchSize is the number of slabs per batch when we prune // slabs. We limit this to 100 slabs which is 3000 sectors at default // redundancy. @@ -46,421 +38,15 @@ const ( refreshHealthMaxHealthValidity = 72 * time.Hour ) -const ( - contractStateInvalid contractState = iota - contractStatePending - contractStateActive - contractStateComplete - contractStateFailed -) - var ( pruneSlabsAlertID = frand.Entropy256() pruneDirsAlertID = frand.Entropy256() ) -var ( - objectDeleteBatchSizes = []int64{10, 50, 100, 200, 500, 1000, 5000, 10000, 50000, 100000} -) - -type ( - contractState uint8 - - dbArchivedContract struct { - Model - - ContractCommon - RenewedTo fileContractID `gorm:"index;size:32"` - - Host publicKey `gorm:"index;NOT NULL;size:32"` - Reason string - } - - dbContract struct { - Model - - ContractCommon - - HostID uint `gorm:"index"` - Host dbHost - - ContractSets []dbContractSet `gorm:"many2many:contract_set_contracts;constraint:OnDelete:CASCADE"` - } - - ContractCommon struct { - FCID fileContractID `gorm:"unique;index;NOT NULL;column:fcid;size:32"` - RenewedFrom fileContractID `gorm:"index;size:32"` - - ContractPrice currency - State contractState `gorm:"index;NOT NULL;default:0"` - TotalCost currency - ProofHeight uint64 `gorm:"index;default:0"` - RevisionHeight uint64 `gorm:"index;default:0"` - RevisionNumber string `gorm:"NOT NULL;default:'0'"` // string since db can't store math.MaxUint64 - Size uint64 - StartHeight uint64 `gorm:"index;NOT NULL"` - WindowStart uint64 `gorm:"index;NOT NULL;default:0"` - WindowEnd uint64 `gorm:"index;NOT NULL;default:0"` - - // spending fields - UploadSpending currency - DownloadSpending currency - FundAccountSpending currency - DeleteSpending currency - ListSpending currency - } - - dbContractSet struct { - Model - - Name string `gorm:"unique;index;"` - Contracts []dbContract `gorm:"many2many:contract_set_contracts;constraint:OnDelete:CASCADE"` - } - - dbDirectory struct { - Model - - Name string - DBParentID uint - } - - dbObject struct { - Model - - DBDirectoryID uint - - DBBucketID uint `gorm:"index;uniqueIndex:idx_object_bucket;NOT NULL"` - DBBucket dbBucket - ObjectID string 
`gorm:"index;uniqueIndex:idx_object_bucket"` - - Key secretKey - Slabs []dbSlice // no CASCADE, slices are deleted via trigger - Metadata []dbObjectUserMetadata `gorm:"constraint:OnDelete:CASCADE"` // CASCADE to delete metadata too - Health float64 `gorm:"index;default:1.0; NOT NULL"` - Size int64 - - MimeType string `json:"index"` - Etag string `gorm:"index"` - } - - dbObjectUserMetadata struct { - Model - - DBObjectID *uint `gorm:"index:uniqueIndex:idx_object_user_metadata_key"` - DBMultipartUploadID *uint `gorm:"index:uniqueIndex:idx_object_user_metadata_key"` - Key string `gorm:"index:uniqueIndex:idx_object_user_metadata_key"` - Value string - } - - dbBucket struct { - Model - - Policy api.BucketPolicy `gorm:"serializer:json"` - Name string `gorm:"unique;index;NOT NULL"` - } - - dbSlice struct { - Model - DBObjectID *uint `gorm:"index"` - ObjectIndex uint `gorm:"index:idx_slices_object_index"` - DBMultipartPartID *uint `gorm:"index"` - - // Slice related fields. - DBSlabID uint `gorm:"index"` - Offset uint32 - Length uint32 - } - - dbSlab struct { - Model - DBContractSetID uint `gorm:"index"` - DBContractSet dbContractSet - DBBufferedSlabID uint `gorm:"index;default: NULL"` - - Health float64 `gorm:"index;default:1.0; NOT NULL"` - HealthValidUntil int64 `gorm:"index;default:0; NOT NULL"` // unix timestamp - Key secretKey `gorm:"unique;NOT NULL;size:32"` // json string - MinShards uint8 `gorm:"index"` - TotalShards uint8 `gorm:"index"` - - Slices []dbSlice - Shards []dbSector `gorm:"constraint:OnDelete:CASCADE"` // CASCADE to delete shards too - } - - dbSector struct { - Model - - DBSlabID uint `gorm:"index:idx_sectors_db_slab_id;uniqueIndex:idx_sectors_slab_id_slab_index;NOT NULL"` - SlabIndex int `gorm:"index:idx_sectors_slab_index;uniqueIndex:idx_sectors_slab_id_slab_index;NOT NULL"` - - LatestHost publicKey `gorm:"NOT NULL"` - Root []byte `gorm:"index;unique;NOT NULL;size:32"` - - Contracts []dbContract `gorm:"many2many:contract_sectors;constraint:OnDelete:CASCADE"` - } - - // dbContractSector is a join table between dbContract and dbSector. - dbContractSector struct { - DBSectorID uint `gorm:"primaryKey;index"` - DBContractID uint `gorm:"primaryKey;index"` - } - - // rawObject is used for hydration and is made up of one or many raw sectors. - rawObject []rawObjectSector - - // rawObjectRow contains all necessary information to reconstruct the object. - rawObjectSector struct { - // object - ObjectID uint - ObjectIndex uint64 - ObjectKey []byte - ObjectName string - ObjectSize int64 - ObjectModTime time.Time - ObjectMimeType string - ObjectHealth float64 - ObjectETag string - - // slice - SliceOffset uint32 - SliceLength uint32 - - // slab - SlabBuffered bool - SlabID uint - SlabHealth float64 - SlabKey []byte - SlabMinShards uint8 - - // sector - SectorIndex uint - SectorRoot []byte - LatestHost publicKey - - // contract - FCID fileContractID - HostKey publicKey - } - - // rawObjectMetadata is used for hydrating object metadata. 
- rawObjectMetadata struct { - ETag string - Health float64 - MimeType string - ModTime datetime - ObjectName string - Size int64 - } -) - -func (s *contractState) LoadString(state string) error { - switch strings.ToLower(state) { - case api.ContractStateInvalid: - *s = contractStateInvalid - case api.ContractStatePending: - *s = contractStatePending - case api.ContractStateActive: - *s = contractStateActive - case api.ContractStateComplete: - *s = contractStateComplete - case api.ContractStateFailed: - *s = contractStateFailed - default: - *s = contractStateInvalid - } - return nil -} - -func (s contractState) String() string { - switch s { - case contractStateInvalid: - return api.ContractStateInvalid - case contractStatePending: - return api.ContractStatePending - case contractStateActive: - return api.ContractStateActive - case contractStateComplete: - return api.ContractStateComplete - case contractStateFailed: - return api.ContractStateFailed - default: - return api.ContractStateUnknown - } -} - -func (s dbSlab) HealthValid() bool { - return time.Now().Before(time.Unix(s.HealthValidUntil, 0)) -} - -// TableName implements the gorm.Tabler interface. -func (dbArchivedContract) TableName() string { return "archived_contracts" } - -// TableName implements the gorm.Tabler interface. -func (dbBucket) TableName() string { return "buckets" } - -// TableName implements the gorm.Tabler interface. -func (dbContract) TableName() string { return "contracts" } - -// TableName implements the gorm.Tabler interface. -func (dbContractSector) TableName() string { return "contract_sectors" } - -// TableName implements the gorm.Tabler interface. -func (dbContractSet) TableName() string { return "contract_sets" } - -// TableName implements the gorm.Tabler interface. -func (dbDirectory) TableName() string { return "directories" } - -// TableName implements the gorm.Tabler interface. -func (dbObject) TableName() string { return "objects" } - -// TableName implements the gorm.Tabler interface. -func (dbObjectUserMetadata) TableName() string { return "object_user_metadata" } - -// TableName implements the gorm.Tabler interface. -func (dbSector) TableName() string { return "sectors" } - -// TableName implements the gorm.Tabler interface. -func (dbSlab) TableName() string { return "slabs" } - -// TableName implements the gorm.Tabler interface. -func (dbSlice) TableName() string { return "slices" } - -// convert converts a dbContract to a ContractMetadata. 
-func (c dbContract) convert() api.ContractMetadata { - var revisionNumber uint64 - _, _ = fmt.Sscan(c.RevisionNumber, &revisionNumber) - var contractSets []string - for _, cs := range c.ContractSets { - contractSets = append(contractSets, cs.Name) - } - return api.ContractMetadata{ - ContractPrice: types.Currency(c.ContractPrice), - ID: types.FileContractID(c.FCID), - HostIP: c.Host.NetAddress, - HostKey: types.PublicKey(c.Host.PublicKey), - SiamuxAddr: rhpv2.HostSettings(c.Host.Settings).SiamuxAddr(), - - RenewedFrom: types.FileContractID(c.RenewedFrom), - TotalCost: types.Currency(c.TotalCost), - Spending: api.ContractSpending{ - Uploads: types.Currency(c.UploadSpending), - Downloads: types.Currency(c.DownloadSpending), - FundAccount: types.Currency(c.FundAccountSpending), - Deletions: types.Currency(c.DeleteSpending), - SectorRoots: types.Currency(c.ListSpending), - }, - ProofHeight: c.ProofHeight, - RevisionHeight: c.RevisionHeight, - RevisionNumber: revisionNumber, - ContractSets: contractSets, - Size: c.Size, - StartHeight: c.StartHeight, - State: c.State.String(), - WindowStart: c.WindowStart, - WindowEnd: c.WindowEnd, - } -} - -// convert turns a dbObject into a object.Slab. -func (s dbSlab) convert() (slab object.Slab, err error) { - // unmarshal key - err = slab.Key.UnmarshalBinary(s.Key) - if err != nil { - return - } - - // set health - slab.Health = s.Health - - // set shards - slab.MinShards = s.MinShards - slab.Shards = make([]object.Sector, len(s.Shards)) - - // hydrate shards - for i, shard := range s.Shards { - slab.Shards[i].LatestHost = types.PublicKey(shard.LatestHost) - slab.Shards[i].Root = *(*types.Hash256)(shard.Root) - for _, c := range shard.Contracts { - if slab.Shards[i].Contracts == nil { - slab.Shards[i].Contracts = make(map[types.PublicKey][]types.FileContractID) - } - slab.Shards[i].Contracts[types.PublicKey(c.Host.PublicKey)] = append(slab.Shards[i].Contracts[types.PublicKey(c.Host.PublicKey)], types.FileContractID(c.FCID)) - } - } - - return -} - -func (raw rawObjectMetadata) convert() api.ObjectMetadata { - return newObjectMetadata( - raw.ObjectName, - raw.ETag, - raw.MimeType, - raw.Health, - time.Time(raw.ModTime), - raw.Size, - ) -} - -func (raw rawObject) toSlabSlice() (slice object.SlabSlice, _ error) { - if len(raw) == 0 { - return object.SlabSlice{}, errors.New("no sectors found") - } else if raw[0].SlabBuffered && len(raw) != 1 { - return object.SlabSlice{}, errors.New("buffered slab with multiple sectors") - } - - // unmarshal key - if err := slice.Slab.Key.UnmarshalBinary(raw[0].SlabKey); err != nil { - return object.SlabSlice{}, err - } - - // handle partial slab - if raw[0].SlabBuffered { - slice.Offset = raw[0].SliceOffset - slice.Length = raw[0].SliceLength - slice.Slab.MinShards = raw[0].SlabMinShards - slice.Slab.Health = raw[0].SlabHealth - return - } - - // hydrate all sectors - slabID := raw[0].SlabID - sectors := make([]object.Sector, 0, len(raw)) - secIdx := uint(0) - for _, sector := range raw { - if sector.SlabID != slabID { - return object.SlabSlice{}, errors.New("sectors from different slabs") // developer error - } - latestHost := types.PublicKey(sector.LatestHost) - fcid := types.FileContractID(sector.FCID) - - // next sector - if sector.SectorIndex != secIdx { - sectors = append(sectors, object.Sector{ - Contracts: make(map[types.PublicKey][]types.FileContractID), - LatestHost: latestHost, - Root: *(*types.Hash256)(sector.SectorRoot), - }) - secIdx = sector.SectorIndex - } - - // add host+contract to sector - if fcid != 
(types.FileContractID{}) { - sectors[len(sectors)-1].Contracts[types.PublicKey(sector.HostKey)] = append(sectors[len(sectors)-1].Contracts[types.PublicKey(sector.HostKey)], fcid) - } - } - - // hydrate all fields - slice.Slab.Health = raw[0].SlabHealth - slice.Slab.Shards = sectors - slice.Slab.MinShards = raw[0].SlabMinShards - slice.Offset = raw[0].SliceOffset - slice.Length = raw[0].SliceLength - return slice, nil -} +var objectDeleteBatchSizes = []int64{10, 50, 100, 200, 500, 1000, 5000, 10000, 50000, 100000} func (s *SQLStore) Bucket(ctx context.Context, bucket string) (b api.Bucket, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { b, err = tx.Bucket(ctx, bucket) return }) @@ -468,25 +54,25 @@ func (s *SQLStore) Bucket(ctx context.Context, bucket string) (b api.Bucket, err } func (s *SQLStore) CreateBucket(ctx context.Context, bucket string, policy api.BucketPolicy) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.CreateBucket(ctx, bucket, policy) }) } func (s *SQLStore) UpdateBucketPolicy(ctx context.Context, bucket string, policy api.BucketPolicy) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateBucketPolicy(ctx, bucket, policy) }) } func (s *SQLStore) DeleteBucket(ctx context.Context, bucket string) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.DeleteBucket(ctx, bucket) }) } func (s *SQLStore) ListBuckets(ctx context.Context) (buckets []api.Bucket, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { buckets, err = tx.ListBuckets(ctx) return }) @@ -497,7 +83,7 @@ func (s *SQLStore) ListBuckets(ctx context.Context) (buckets []api.Bucket, err e // reduce locking and make sure all results are consistent, everything is done // within a single transaction. 
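The bucket methods above show the pattern this change applies across the whole file: every gorm call becomes a single closure run through the backend-agnostic transaction interface. A minimal sketch of that shape, with a trimmed-down interface and placeholder types standing in for sql.DatabaseTx and api.Bucket:

```go
package stores

import "context"

// DatabaseTx is a trimmed-down stand-in for the interface in stores/sql;
// the real one carries many more methods.
type DatabaseTx interface {
	Bucket(ctx context.Context, bucket string) (Bucket, error)
}

// Database mirrors the role of s.db: it runs a closure inside a (possibly
// retried) transaction.
type Database interface {
	Transaction(ctx context.Context, fn func(tx DatabaseTx) error) error
}

// Bucket is a placeholder for api.Bucket.
type Bucket struct{ Name string }

// Store mirrors SQLStore.
type Store struct{ db Database }

// Bucket shows the delegation shape used throughout the file after this
// change: assign the named results inside the closure, naked return after.
func (s *Store) Bucket(ctx context.Context, name string) (b Bucket, err error) {
	err = s.db.Transaction(ctx, func(tx DatabaseTx) error {
		b, err = tx.Bucket(ctx, name)
		return err
	})
	return
}
```

Assigning the named results inside the closure keeps each store method a one-liner and leaves retry semantics entirely to Transaction.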
func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) (resp api.ObjectsStatsResponse, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.ObjectsStats(ctx, opts) return }) @@ -507,7 +93,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) func (s *SQLStore) SlabBuffers(ctx context.Context) ([]api.SlabBuffer, error) { var err error var fileNameToContractSet map[string]string - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { fileNameToContractSet, err = tx.SlabBuffers(ctx) return err }) @@ -524,25 +110,21 @@ func (s *SQLStore) SlabBuffers(ctx context.Context) ([]api.SlabBuffer, error) { } func (s *SQLStore) AddContract(ctx context.Context, c rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, state string) (_ api.ContractMetadata, err error) { - var cs contractState - if err := cs.LoadString(state); err != nil { - return api.ContractMetadata{}, err - } - var added dbContract - if err = s.retryTransaction(ctx, func(tx *gorm.DB) error { - added, err = addContract(tx, c, contractPrice, totalCost, startHeight, types.FileContractID{}, cs) + var contract api.ContractMetadata + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + contract, err = tx.InsertContract(ctx, c, contractPrice, totalCost, startHeight, types.FileContractID{}, state) return err - }); err != nil { - return + }) + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to add contract: %w", err) } - s.addKnownContract(types.FileContractID(added.FCID)) - return added.convert(), nil + return contract, nil } func (s *SQLStore) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) { var contracts []api.ContractMetadata - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { contracts, err = tx.Contracts(ctx, opts) return }) @@ -553,55 +135,19 @@ func (s *SQLStore) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api // The old contract specified as 'renewedFrom' will be deleted from the active // contracts and moved to the archive. Both new and old contract will be linked // to each other through the RenewedFrom and RenewedTo fields respectively. -func (s *SQLStore) AddRenewedContract(ctx context.Context, c rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { - var cs contractState - if err := cs.LoadString(state); err != nil { - return api.ContractMetadata{}, err - } - var renewed dbContract - if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { - // Fetch contract we renew from. - oldContract, err := contract(tx, fileContractID(renewedFrom)) - if err != nil { - return err - } - - // Create copy in archive. - err = tx.Create(&dbArchivedContract{ - Host: publicKey(oldContract.Host.PublicKey), - Reason: api.ContractArchivalReasonRenewed, - RenewedTo: fileContractID(c.ID()), - - ContractCommon: oldContract.ContractCommon, - }).Error - if err != nil { - return err - } - - // Overwrite the old contract with the new one. 
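The gorm body being removed here spells out what tx.RenewContract now owns: copy the old row into the archive with a forward link, then replace it in the active table with a row pointing back at it. A conceptual sketch under simplified, assumed types (not the real schema):

```go
package stores

import "errors"

// FileContractID and the row types below are simplified stand-ins; only the
// fields relevant to renewal linkage are shown.
type FileContractID [32]byte

type ContractRow struct {
	FCID        FileContractID
	RenewedFrom FileContractID
	Size        uint64
}

type ArchivedRow struct {
	ContractRow
	Reason    string
	RenewedTo FileContractID
}

// renew archives the old contract with a forward link (RenewedTo) and
// replaces it in the active set with a row that carries the backward link
// (RenewedFrom), preserving the size of the renewed contract.
func renew(active map[FileContractID]ContractRow, archive map[FileContractID]ArchivedRow, oldID, newID FileContractID) error {
	prev, ok := active[oldID]
	if !ok {
		return errors.New("contract not found")
	}
	archive[oldID] = ArchivedRow{ContractRow: prev, Reason: "renewed", RenewedTo: newID}
	delete(active, oldID)
	active[newID] = ContractRow{FCID: newID, RenewedFrom: oldID, Size: prev.Size}
	return nil
}
```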
- newContract := newContract(oldContract.HostID, c.ID(), renewedFrom, contractPrice, totalCost, startHeight, c.Revision.WindowStart, c.Revision.WindowEnd, oldContract.Size, cs) - newContract.Model = oldContract.Model - newContract.CreatedAt = time.Now() - err = tx.Save(&newContract).Error - if err != nil { - return err - } - - // Populate host. - newContract.Host = oldContract.Host - - s.addKnownContract(c.ID()) - renewed = newContract - return nil - }); err != nil { - return api.ContractMetadata{}, err +func (s *SQLStore) AddRenewedContract(ctx context.Context, c rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (renewed api.ContractMetadata, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + renewed, err = tx.RenewContract(ctx, c, contractPrice, totalCost, startHeight, renewedFrom, state) + return err + }) + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to add renewed contract: %w", err) } - - return renewed.convert(), nil + return } func (s *SQLStore) AncestorContracts(ctx context.Context, id types.FileContractID, startHeight uint64) (ancestors []api.ArchivedContract, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { ancestors, err = tx.AncestorContracts(ctx, id, startHeight) return err }) @@ -626,7 +172,7 @@ func (s *SQLStore) ArchiveContracts(ctx context.Context, toArchive map[types.Fil // archive the contract but don't interrupt the process if one contract // fails - if err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + if err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.ArchiveContract(ctx, fcid, reason) }); err != nil { errs = append(errs, fmt.Sprintf("%v: %v", fcid, err)) @@ -651,20 +197,16 @@ func (s *SQLStore) ArchiveAllContracts(ctx context.Context, reason string) error return s.ArchiveContracts(ctx, toArchive) } -func (s *SQLStore) Contract(ctx context.Context, id types.FileContractID) (api.ContractMetadata, error) { - contract, err := s.contract(ctx, fileContractID(id)) - if err != nil { - return api.ContractMetadata{}, err - } - return contract.convert(), nil +func (s *SQLStore) Contract(ctx context.Context, id types.FileContractID) (cm api.ContractMetadata, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + cm, err = tx.Contract(ctx, id) + return err + }) + return } func (s *SQLStore) ContractRoots(ctx context.Context, id types.FileContractID) (roots []types.Hash256, err error) { - if !s.isKnownContract(id) { - return nil, api.ErrContractNotFound - } - - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { roots, err = tx.ContractRoots(ctx, id) return err }) @@ -672,7 +214,7 @@ func (s *SQLStore) ContractRoots(ctx context.Context, id types.FileContractID) ( } func (s *SQLStore) ContractSets(ctx context.Context) (sets []string, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { sets, err = tx.ContractSets(ctx) return err }) @@ -680,7 +222,7 @@ func (s *SQLStore) ContractSets(ctx context.Context) (sets []string, err error) } func (s *SQLStore) ContractSizes(ctx context.Context) (sizes map[types.FileContractID]api.ContractSize, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, 
func(tx sql.DatabaseTx) error { sizes, err = tx.ContractSizes(ctx) return err }) @@ -688,10 +230,7 @@ func (s *SQLStore) ContractSizes(ctx context.Context) (sizes map[types.FileContr } func (s *SQLStore) ContractSize(ctx context.Context, id types.FileContractID) (cs api.ContractSize, err error) { - if !s.isKnownContract(id) { - return api.ContractSize{}, api.ErrContractNotFound - } - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { cs, err = tx.ContractSize(ctx, id) return }) @@ -699,57 +238,40 @@ func (s *SQLStore) ContractSize(ctx context.Context, id types.FileContractID) (c } func (s *SQLStore) SetContractSet(ctx context.Context, name string, contractIds []types.FileContractID) error { - var wantedIds []fileContractID - wanted := make(map[fileContractID]struct{}) + wanted := make(map[types.FileContractID]struct{}) for _, fcid := range contractIds { - wantedIds = append(wantedIds, fileContractID(fcid)) - wanted[fileContractID(fcid)] = struct{}{} + wanted[types.FileContractID(fcid)] = struct{}{} } var diff []types.FileContractID var nContractsAfter int - err := s.retryTransaction(ctx, func(tx *gorm.DB) error { - // fetch contract set - var cs dbContractSet - err := tx. - Where(dbContractSet{Name: name}). - Preload("Contracts"). - FirstOrCreate(&cs). - Error - if err != nil { - return err - } - - // fetch contracts - var dbContracts []dbContract - err = tx. - Model(&dbContract{}). - Where("fcid IN (?)", wantedIds). - Find(&dbContracts). - Error - if err != nil { - return err - } - nContractsAfter = len(dbContracts) - - // add removals to the diff - for _, contract := range cs.Contracts { - if _, ok := wanted[contract.FCID]; !ok { - diff = append(diff, types.FileContractID(contract.FCID)) + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + // build diff + prevContracts, err := tx.Contracts(ctx, api.ContractsOpts{ContractSet: name}) + if err != nil && !errors.Is(err, api.ErrContractSetNotFound) { + return fmt.Errorf("failed to fetch contracts: %w", err) + } + diff = nil // reset + for _, c := range prevContracts { + if _, exists := wanted[c.ID]; !exists { + diff = append(diff, c.ID) // removal + } else { + delete(wanted, c.ID) } - delete(wanted, contract.FCID) } - - // add additions to the diff for fcid := range wanted { - diff = append(diff, types.FileContractID(fcid)) + diff = append(diff, fcid) // addition } - - // update the association - if err := tx.Model(&cs).Association("Contracts").Replace(&dbContracts); err != nil { - return err + // update contract set + if err := tx.SetContractSet(ctx, name, contractIds); err != nil { + return fmt.Errorf("failed to set contract set: %w", err) } - + // fetch contracts after update + afterContracts, err := tx.Contracts(ctx, api.ContractsOpts{ContractSet: name}) + if err != nil { + return fmt.Errorf("failed to fetch contracts after update: %w", err) + } + nContractsAfter = len(afterContracts) return nil }) if err != nil { @@ -775,257 +297,38 @@ func (s *SQLStore) SetContractSet(ctx context.Context, name string, contractIds } func (s *SQLStore) RemoveContractSet(ctx context.Context, name string) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.RemoveContractSet(ctx, name) }) } func (s *SQLStore) RenewedContract(ctx context.Context, renewedFrom types.FileContractID) (cm api.ContractMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx 
sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { cm, err = tx.RenewedContract(ctx, renewedFrom) return err }) return } -func (s *SQLStore) SearchObjects(ctx context.Context, bucket, substring string, offset, limit int) ([]api.ObjectMetadata, error) { - // fetch one more to see if there are more entries - if limit <= -1 { - limit = math.MaxInt - } - - var objects []api.ObjectMetadata - err := s.db. - WithContext(ctx). - Select("o.object_id as Name, o.size as Size, o.health as Health, o.mime_type as MimeType, o.etag as ETag, o.created_at as ModTime"). - Model(&dbObject{}). - Table("objects o"). - Joins("INNER JOIN buckets b ON o.db_bucket_id = b.id"). - Where("INSTR(o.object_id, ?) > 0 AND b.name = ?", substring, bucket). - Order("o.object_id ASC"). - Offset(offset). - Limit(limit). - Scan(&objects).Error - if err != nil { - return nil, err - } - - return objects, nil +func (s *SQLStore) SearchObjects(ctx context.Context, bucket, substring string, offset, limit int) (objects []api.ObjectMetadata, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + objects, err = tx.SearchObjects(ctx, bucket, substring, offset, limit) + return err + }) + return } func (s *SQLStore) ObjectEntries(ctx context.Context, bucket, path, prefix, sortBy, sortDir, marker string, offset, limit int) (metadata []api.ObjectMetadata, hasMore bool, err error) { - // sanity check we are passing a directory - if !strings.HasSuffix(path, "/") { - panic("path must end in /") - } - - // convenience variables - usingMarker := marker != "" - usingOffset := offset > 0 - - // sanity check we are passing sane paging parameters - if usingMarker && usingOffset { - return nil, false, errors.New("fetching entries using a marker and an offset is not supported at the same time") - } - - // sanity check we are passing sane sorting parameters - if err := validateSort(sortBy, sortDir); err != nil { - return nil, false, err - } - - // sanitize sorting parameters - if sortBy == "" { - sortBy = api.ObjectSortByName - } - if sortDir == "" { - sortDir = api.ObjectSortDirAsc - } else { - sortDir = strings.ToLower(sortDir) - } - - // ensure marker is '/' prefixed - if usingMarker && !strings.HasPrefix(marker, "/") { - marker = fmt.Sprintf("/%s", marker) - } - - // ensure limit is out of play - if limit <= -1 { - limit = math.MaxInt - } - - // fetch one more to see if there are more entries - if limit != math.MaxInt { - limit += 1 - } - - // ensure offset is out of play - if usingMarker { - offset = 0 - } - - // fetch id of directory to query - dirID, err := s.dirID(s.db, path) - if errors.Is(err, gorm.ErrRecordNotFound) { - return []api.ObjectMetadata{}, false, nil - } else if err != nil { - return nil, false, err - } - - // fetch bucket id - var dBucket dbBucket - if err := s.db.Select("id"). - Where("name", bucket). - Take(&dBucket).Error; err != nil { - return nil, false, fmt.Errorf("failed to fetch bucket id: %w", err) - } - - // build prefix expression - prefixExpr := "TRUE" - if prefix != "" { - prefixExpr = "SUBSTR(o.object_id, 1, ?) = ?" - } - - lengthFn := "CHAR_LENGTH" - if isSQLite(s.db) { - lengthFn = "LENGTH" - } - - // objectsQuery consists of 2 parts - // 1. fetch all objects in requested directory - // 2. fetch all sub-directories - objectsQuery := fmt.Sprintf(` -SELECT o.etag as ETag, o.created_at as ModTime, o.object_id as ObjectName, o.size as Size, o.health as Health, o.mime_type as MimeType -FROM objects o -WHERE o.object_id != ? AND o.db_directory_id = ? 
AND o.db_bucket_id = (SELECT id FROM buckets b WHERE b.name = ?) AND %s -UNION ALL -SELECT '' as ETag, MAX(o.created_at) as ModTime, d.name as ObjectName, SUM(o.size) as Size, MIN(o.health) as Health, '' as MimeType -FROM objects o -INNER JOIN directories d ON SUBSTR(o.object_id, 1, %s(d.name)) = d.name AND %s -WHERE o.db_bucket_id = (SELECT id FROM buckets b WHERE b.name = ?) -AND o.object_id LIKE ? -AND SUBSTR(o.object_id, 1, ?) = ? -AND d.db_parent_id = ? -GROUP BY d.id -`, prefixExpr, - lengthFn, - prefixExpr) - - // build query params - var objectsQueryParams []interface{} - if prefix != "" { - objectsQueryParams = []interface{}{ - path, // o.object_id != ? - dirID, bucket, // o.db_directory_id = ? AND b.name = ? - utf8.RuneCountInString(path + prefix), path + prefix, - utf8.RuneCountInString(path + prefix), path + prefix, - bucket, // b.name = ? - path + "%", // o.object_id LIKE ? - utf8.RuneCountInString(path), path, // SUBSTR(o.object_id, 1, ?) = ? - dirID, // d.db_parent_id = ? - } - } else { - objectsQueryParams = []interface{}{ - path, // o.object_id != ? - dirID, bucket, // o.db_directory_id = ? AND b.name = ? - bucket, - path + "%", // o.object_id LIKE ? - utf8.RuneCountInString(path), path, // SUBSTR(o.object_id, 1, ?) = ? - dirID, // d.db_parent_id = ? - } - } - - // build marker expr - markerExpr := "1 = 1" - var markerParams []interface{} - if usingMarker { - switch sortBy { - case api.ObjectSortByHealth: - var markerHealth float64 - if err = s.db. - WithContext(ctx). - Raw(fmt.Sprintf(`SELECT Health FROM (SELECT * FROM (%s) h WHERE ObjectName >= ? ORDER BY ObjectName LIMIT 1) as n`, objectsQuery), append(objectsQueryParams, marker)...). - Scan(&markerHealth). - Error; err != nil { - return - } - - if sortDir == api.ObjectSortDirAsc { - markerExpr = "(Health > ? OR (Health = ? AND ObjectName > ?))" - markerParams = []interface{}{markerHealth, markerHealth, marker} - } else { - markerExpr = "(Health = ? AND ObjectName > ?) OR Health < ?" - markerParams = []interface{}{markerHealth, marker, markerHealth} - } - case api.ObjectSortBySize: - var markerSize float64 - if err = s.db. - WithContext(ctx). - Raw(fmt.Sprintf(`SELECT Size FROM (SELECT * FROM (%s) s WHERE ObjectName >= ? ORDER BY ObjectName LIMIT 1) as n`, objectsQuery), append(objectsQueryParams, marker)...). - Scan(&markerSize). - Error; err != nil { - return - } - - if sortDir == api.ObjectSortDirAsc { - markerExpr = "(Size > ? OR (Size = ? AND ObjectName > ?))" - markerParams = []interface{}{markerSize, markerSize, marker} - } else { - markerExpr = "(Size = ? AND ObjectName > ?) OR Size < ?" - markerParams = []interface{}{markerSize, marker, markerSize} - } - case api.ObjectSortByName: - if sortDir == api.ObjectSortDirAsc { - markerExpr = "ObjectName > ?" - } else { - markerExpr = "ObjectName < ?" - } - markerParams = []interface{}{marker} - default: - panic("unhandled sortBy") // developer error - } - } - - // build order clause - if sortBy == api.ObjectSortByName { - sortBy = "ObjectName" - } - orderByClause := fmt.Sprintf("%s %s", sortBy, sortDir) - if sortBy != "ObjectName" { - orderByClause += ", ObjectName" - } - - var rows []rawObjectMetadata - query := fmt.Sprintf(`SELECT * FROM (%s ORDER BY %s) AS n WHERE %s LIMIT ? OFFSET ?`, - objectsQuery, - orderByClause, - markerExpr, - ) - parameters := append(append(objectsQueryParams, markerParams...), limit, offset) - - if err = s.db. - WithContext(ctx). - Raw(query, parameters...). - Scan(&rows). 
- Error; err != nil { - return - } - - // trim last element if we have more - if len(rows) == limit { - hasMore = true - rows = rows[:len(rows)-1] - } - - // convert rows into metadata - for _, row := range rows { - metadata = append(metadata, row.convert()) - } + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + metadata, hasMore, err = tx.ObjectEntries(ctx, bucket, path, prefix, sortBy, sortDir, marker, offset, limit) + return err + }) return } func (s *SQLStore) Object(ctx context.Context, bucket, path string) (obj api.Object, err error) { - err = s.retryTransaction(ctx, func(tx *gorm.DB) error { - obj, err = s.object(tx, bucket, path) + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + obj, err = tx.Object(ctx, bucket, path) return err }) return @@ -1056,16 +359,11 @@ func (s *SQLStore) RecordContractSpending(ctx context.Context, records []api.Con } metrics := make([]api.ContractMetric, 0, len(squashedRecords)) for fcid, newSpending := range squashedRecords { - err := s.retryTransaction(ctx, func(tx *gorm.DB) error { - var contract dbContract - err := tx.Model(&dbContract{}). - Where("fcid = ?", fileContractID(fcid)). - Joins("Host"). - Take(&contract).Error - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil // contract not found, continue with next one + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + contract, err := tx.Contract(ctx, fcid) + if errors.Is(err, api.ErrContractNotFound) { + return nil // contract not found, continue with next one } else if err != nil { - return err + return fmt.Errorf("failed to fetch contract: %w", err) } remainingCollateral := types.ZeroCurrency @@ -1075,95 +373,50 @@ func (s *SQLStore) RecordContractSpending(ctx context.Context, records []api.Con m := api.ContractMetric{ Timestamp: api.TimeNow(), ContractID: fcid, - HostKey: types.PublicKey(contract.Host.PublicKey), + HostKey: contract.HostKey, RemainingCollateral: remainingCollateral, RemainingFunds: latestValues[fcid].validRenterPayout, RevisionNumber: latestValues[fcid].revision, - UploadSpending: types.Currency(contract.UploadSpending).Add(newSpending.Uploads), - DownloadSpending: types.Currency(contract.DownloadSpending).Add(newSpending.Downloads), - FundAccountSpending: types.Currency(contract.FundAccountSpending).Add(newSpending.FundAccount), - DeleteSpending: types.Currency(contract.DeleteSpending).Add(newSpending.Deletions), - ListSpending: types.Currency(contract.ListSpending).Add(newSpending.SectorRoots), + UploadSpending: contract.Spending.Uploads.Add(newSpending.Uploads), + DownloadSpending: contract.Spending.Downloads.Add(newSpending.Downloads), + FundAccountSpending: contract.Spending.FundAccount.Add(newSpending.FundAccount), + DeleteSpending: contract.Spending.Deletions.Add(newSpending.Deletions), + ListSpending: contract.Spending.SectorRoots.Add(newSpending.SectorRoots), } metrics = append(metrics, m) - updates := make(map[string]interface{}) + var updates api.ContractSpending if !newSpending.Uploads.IsZero() { - updates["upload_spending"] = currency(m.UploadSpending) + updates.Uploads = m.UploadSpending } if !newSpending.Downloads.IsZero() { - updates["download_spending"] = currency(m.DownloadSpending) + updates.Downloads = m.DownloadSpending } if !newSpending.FundAccount.IsZero() { - updates["fund_account_spending"] = currency(m.FundAccountSpending) + updates.FundAccount = m.FundAccountSpending } if !newSpending.Deletions.IsZero() { - updates["delete_spending"] = currency(m.DeleteSpending) + updates.Deletions = m.DeleteSpending } if !newSpending.SectorRoots.IsZero() { -
updates["list_spending"] = currency(m.ListSpending) + updates.SectorRoots = m.ListSpending } - updates["revision_number"] = latestValues[fcid].revision - updates["size"] = latestValues[fcid].size - return tx.Model(&contract).Updates(updates).Error + return tx.RecordContractSpending(ctx, fcid, latestValues[fcid].revision, latestValues[fcid].size, updates) }) if err != nil { return err } } - if err := s.RecordContractMetric(ctx, metrics...); err != nil { - s.logger.Errorw("failed to record contract metrics", zap.Error(err)) - } - return nil -} - -func (s *SQLStore) addKnownContract(fcid types.FileContractID) { - s.mu.Lock() - defer s.mu.Unlock() - s.knownContracts[fcid] = struct{}{} -} - -func (s *SQLStore) isKnownContract(fcid types.FileContractID) bool { - s.mu.Lock() - defer s.mu.Unlock() - _, found := s.knownContracts[fcid] - return found -} - -func fetchUsedContracts(tx *gorm.DB, usedContractsByHost map[types.PublicKey]map[types.FileContractID]struct{}) (map[types.FileContractID]dbContract, error) { - // flatten map to get all used contract ids - fcids := make([]fileContractID, 0, len(usedContractsByHost)) - for _, hostFCIDs := range usedContractsByHost { - for fcid := range hostFCIDs { - fcids = append(fcids, fileContractID(fcid)) + if len(metrics) > 0 { + if err := s.RecordContractMetric(ctx, metrics...); err != nil { + s.logger.Errorw("failed to record contract metrics", zap.Error(err)) } } - - // fetch all contracts, take into account renewals - var contracts []dbContract - err := tx.Model(&dbContract{}). - Joins("Host"). - Where("fcid IN (?) OR renewed_from IN (?)", fcids, fcids). - Find(&contracts).Error - if err != nil { - return nil, err - } - - // build map of used contracts - usedContracts := make(map[types.FileContractID]dbContract, len(contracts)) - for _, c := range contracts { - if _, used := usedContractsByHost[types.PublicKey(c.Host.PublicKey)][types.FileContractID(c.FCID)]; used { - usedContracts[types.FileContractID(c.FCID)] = c - } - if _, used := usedContractsByHost[types.PublicKey(c.Host.PublicKey)][types.FileContractID(c.RenewedFrom)]; used { - usedContracts[types.FileContractID(c.RenewedFrom)] = c - } - } - return usedContracts, nil + return nil } func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew string, force bool) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // create new dir dirID, err := tx.MakeDirsForPath(ctx, keyNew) if err != nil { @@ -1181,7 +434,7 @@ func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew stri } func (s *SQLStore) RenameObjects(ctx context.Context, bucket, prefixOld, prefixNew string, force bool) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // create new dir dirID, err := tx.MakeDirsForPath(ctx, prefixNew) if err != nil { @@ -1200,15 +453,11 @@ func (s *SQLStore) FetchPartialSlab(ctx context.Context, ec object.EncryptionKey } func (s *SQLStore) AddPartialSlab(ctx context.Context, data []byte, minShards, totalShards uint8, contractSet string) ([]object.SlabSlice, int64, error) { - var contractSetID uint - if err := s.db.Raw("SELECT id FROM contract_sets WHERE name = ?", contractSet).Scan(&contractSetID).Error; err != nil { - return nil, 0, err - } - return s.slabBufferMgr.AddPartialSlab(ctx, data, minShards, totalShards, contractSetID) + return s.slabBufferMgr.AddPartialSlab(ctx, data, minShards, totalShards, 
contractSet) } func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath, dstPath, mimeType string, metadata api.ObjectUserMetadata) (om api.ObjectMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { if srcBucket != dstBucket || srcPath != dstPath { _, err = tx.DeleteObject(ctx, dstBucket, dstPath) if err != nil { @@ -1222,34 +471,13 @@ func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath } func (s *SQLStore) DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) (deletedSectors int, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { deletedSectors, err = tx.DeleteHostSector(ctx, hk, root) return err }) return } -func (s *SQLStore) dirID(tx *gorm.DB, dirPath string) (uint, error) { - if !strings.HasPrefix(dirPath, "/") { - return 0, fmt.Errorf("path must start with /") - } else if !strings.HasSuffix(dirPath, "/") { - return 0, fmt.Errorf("path must end with /") - } - - if dirPath == "/" { - return 1, nil // root dir returned - } - - var dir dbDirectory - if err := tx.Where("name", dirPath). - Select("id"). - Take(&dir). - Error; err != nil { - return 0, fmt.Errorf("failed to fetch directory: %w", err) - } - return dir.ID, nil -} - func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, eTag, mimeType string, metadata api.ObjectUserMetadata, o object.Object) error { // Sanity check input. for _, s := range o.Slabs { @@ -1263,7 +491,7 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, // UpdateObject is ACID. var prune bool - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // Try to delete. We want to get rid of the object and its slices if it // exists. // @@ -1304,7 +532,7 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, func (s *SQLStore) RemoveObject(ctx context.Context, bucket, path string) error { var prune bool - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { prune, err = tx.DeleteObject(ctx, bucket, path) return }) @@ -1324,7 +552,7 @@ func (s *SQLStore) RemoveObjects(ctx context.Context, bucket, prefix string) err start := time.Now() var done bool var duration time.Duration - if err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + if err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { deleted, err := tx.DeleteObjects(ctx, bucket, prefix, objectDeleteBatchSizes[batchSizeIdx]) if err != nil { return err @@ -1351,30 +579,23 @@ func (s *SQLStore) RemoveObjects(ctx context.Context, bucket, prefix string) err return nil } -func (s *SQLStore) Slab(ctx context.Context, key object.EncryptionKey) (object.Slab, error) { - k, err := key.MarshalBinary() - if err != nil { - return object.Slab{}, err - } - var slab dbSlab - tx := s.db.Where(&dbSlab{Key: k}). - Preload("Shards.Contracts.Host"). 
- Take(&slab) - if errors.Is(tx.Error, gorm.ErrRecordNotFound) { - return object.Slab{}, api.ErrSlabNotFound - } - return slab.convert() +func (s *SQLStore) Slab(ctx context.Context, key object.EncryptionKey) (slab object.Slab, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + slab, err = tx.Slab(ctx, key) + return err + }) + return } -func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet string) error { +func (s *SQLStore) UpdateSlab(ctx context.Context, slab object.Slab, contractSet string) error { // sanity check the shards don't contain an empty root - for _, s := range s.Shards { - if s.Root == (types.Hash256{}) { + for _, shard := range slab.Shards { + if shard.Root == (types.Hash256{}) { return errors.New("shard root can never be the empty root") } } // Sanity check input. - for i, shard := range s.Shards { + for i, shard := range slab.Shards { // Verify that all hosts have a contract. if len(shard.Contracts) == 0 { return fmt.Errorf("missing hosts for slab %d", i) @@ -1382,24 +603,16 @@ func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet s } // Update slab. - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { - return tx.UpdateSlab(ctx, s, contractSet, s.Contracts()) + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.UpdateSlab(ctx, slab, contractSet, slab.Contracts()) }) } func (s *SQLStore) RefreshHealth(ctx context.Context) error { - var nSlabs int64 - if err := s.db.Model(&dbSlab{}).Count(&nSlabs).Error; err != nil { - return err - } - if nSlabs == 0 { - return nil // nothing to do - } - for { // update slabs var rowsAffected int64 - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { rowsAffected, err = tx.UpdateSlabHealth(ctx, refreshHealthBatchSize, refreshHealthMinHealthValidity, refreshHealthMaxHealthValidity) return }) @@ -1421,272 +634,38 @@ func (s *SQLStore) RefreshHealth(ctx context.Context) error { // UnhealthySlabs returns up to 'limit' slabs that do not reach full redundancy // in the given contract set. These slabs need to be migrated to good contracts // so they are restored to full health. -func (s *SQLStore) UnhealthySlabs(ctx context.Context, healthCutoff float64, set string, limit int) ([]api.UnhealthySlab, error) { +func (s *SQLStore) UnhealthySlabs(ctx context.Context, healthCutoff float64, set string, limit int) (slabs []api.UnhealthySlab, err error) { if limit <= -1 { limit = math.MaxInt } - - var rows []struct { - Key []byte - Health float64 - } - - if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { - return tx.Select("slabs.key, slabs.health"). - Joins("INNER JOIN contract_sets cs ON slabs.db_contract_set_id = cs.id"). - Model(&dbSlab{}). - Where("health <= ? AND cs.name = ?", healthCutoff, set). - Order("health ASC"). - Limit(limit). - Find(&rows). - Error - }); err != nil { - return nil, err - } - - slabs := make([]api.UnhealthySlab, len(rows)) - for i, row := range rows { - var key object.EncryptionKey - if err := key.UnmarshalBinary(row.Key); err != nil { - return nil, err - } - slabs[i] = api.UnhealthySlab{ - Key: key, - Health: row.Health, - } - } - return slabs, nil -} - -// object retrieves an object from the store. 
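RefreshHealth above no longer counts slabs up front; it simply re-runs a bounded batch update until a transaction reports fewer affected rows than the batch size. The same idea drives the objectDeleteBatchSizes ladder used by RemoveObjects. A sketch of the loop shape, where updateBatch is an assumed stand-in for a transactional tx.UpdateSlabHealth call:

```go
package stores

import "context"

// refreshInBatches applies updateBatch until it affects fewer rows than the
// batch size, which signals that the final partial batch drained the queue.
func refreshInBatches(ctx context.Context, batchSize int64, updateBatch func(ctx context.Context, limit int64) (int64, error)) error {
	for {
		affected, err := updateBatch(ctx, batchSize)
		if err != nil {
			return err
		}
		if affected < batchSize {
			return nil // done
		}
		// check for cancellation between batches so other work can interleave
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
	}
}
```

Keeping each batch in its own short transaction bounds lock time, at the cost of not being atomic across the whole refresh, which is acceptable here since health is recomputed periodically anyway.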
-func (s *SQLStore) object(tx *gorm.DB, bucket, path string) (api.Object, error) { - // fetch raw object data - raw, err := s.objectRaw(tx, bucket, path) - if errors.Is(err, gorm.ErrRecordNotFound) || (err == nil && len(raw) == 0) { - return api.Object{}, api.ErrObjectNotFound - } else if err != nil { - return api.Object{}, err - } - - // hydrate raw object data - return s.objectHydrate(tx, bucket, path, raw) -} - -// objectHydrate hydrates a raw object and returns an api.Object. -func (s *SQLStore) objectHydrate(tx *gorm.DB, bucket, path string, obj rawObject) (api.Object, error) { - // parse object key - var key object.EncryptionKey - if err := key.UnmarshalBinary(obj[0].ObjectKey); err != nil { - return api.Object{}, err - } - - // filter out slabs without slab ID and buffered slabs - this is expected - // for an empty object or objects that end with a partial slab. - var filtered rawObject - minHealth := math.MaxFloat64 - for _, sector := range obj { - if sector.SlabID != 0 { - filtered = append(filtered, sector) - if sector.SlabHealth < minHealth { - minHealth = sector.SlabHealth - } - } - } - - // hydrate all slabs - slabs := make([]object.SlabSlice, 0, len(filtered)) - if len(filtered) > 0 { - var start int - // create a helper function to add a slab and update the state - addSlab := func(end int) error { - if slab, err := filtered[start:end].toSlabSlice(); err != nil { - return err - } else { - slabs = append(slabs, slab) - start = end - } - return nil - } - - curr := filtered[0] - for j, sector := range filtered { - if sector.ObjectIndex == 0 { - return api.Object{}, api.ErrObjectCorrupted - } else if sector.SectorIndex == 0 && !sector.SlabBuffered { - return api.Object{}, api.ErrObjectCorrupted - } - if sector.ObjectIndex != curr.ObjectIndex { - if err := addSlab(j); err != nil { - return api.Object{}, err - } - curr = sector - } - } - if err := addSlab(len(filtered)); err != nil { - return api.Object{}, err - } - } - - // fetch object metadata - metadata, err := s.objectMetadata(tx, bucket, path) - if err != nil { - return api.Object{}, err - } - - // return object - return api.Object{ - Metadata: metadata, - ObjectMetadata: newObjectMetadata( - obj[0].ObjectName, - obj[0].ObjectETag, - obj[0].ObjectMimeType, - obj[0].ObjectHealth, - obj[0].ObjectModTime, - obj[0].ObjectSize, - ), - Object: &object.Object{ - Key: key, - Slabs: slabs, - }, - }, nil + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + slabs, err = tx.UnhealthySlabs(ctx, healthCutoff, set, limit) + return err + }) + return } // ObjectMetadata returns an object's metadata -func (s *SQLStore) ObjectMetadata(ctx context.Context, bucket, path string) (api.Object, error) { - var resp api.Object - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - var obj dbObject - err := tx.Model(&dbObject{}). - Joins("INNER JOIN buckets b ON objects.db_bucket_id = b.id"). - Where("b.name", bucket). - Where("object_id", path). - Take(&obj). 
- Error - if errors.Is(err, gorm.ErrRecordNotFound) { - return api.ErrObjectNotFound - } else if err != nil { - return err - } - oum, err := s.objectMetadata(tx, bucket, path) - if err != nil { - return err - } - resp = api.Object{ - ObjectMetadata: newObjectMetadata( - obj.ObjectID, - obj.Etag, - obj.MimeType, - obj.Health, - obj.CreatedAt, - obj.Size, - ), - Metadata: oum, - } - return nil +func (s *SQLStore) ObjectMetadata(ctx context.Context, bucket, path string) (obj api.Object, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + obj, err = tx.ObjectMetadata(ctx, bucket, path) + return err }) - return resp, err -} - -func (s *SQLStore) objectMetadata(tx *gorm.DB, bucket, path string) (api.ObjectUserMetadata, error) { - var rows []dbObjectUserMetadata - err := tx. - Model(&dbObjectUserMetadata{}). - Table("object_user_metadata oum"). - Joins("INNER JOIN objects o ON oum.db_object_id = o.id"). - Joins("INNER JOIN buckets b ON o.db_bucket_id = b.id"). - Where("o.object_id = ? AND b.name = ?", path, bucket). - Find(&rows). - Error - if err != nil { - return nil, err - } - metadata := make(api.ObjectUserMetadata) - for _, row := range rows { - metadata[row.Key] = row.Value - } - return metadata, nil -} - -func newObjectMetadata(name, etag, mimeType string, health float64, modTime time.Time, size int64) api.ObjectMetadata { - return api.ObjectMetadata{ - ETag: etag, - Health: health, - ModTime: api.TimeRFC3339(modTime.UTC()), - Name: name, - Size: size, - MimeType: mimeType, - } -} - -func (s *SQLStore) objectRaw(txn *gorm.DB, bucket string, path string) (rows rawObject, err error) { - // NOTE: we LEFT JOIN here because empty objects are valid and need to be - // included in the result set, when we convert the rawObject before - // returning it we'll check for SlabID and/or SectorID being 0 and act - // accordingly - err = txn. - Select("o.id as ObjectID, o.health as ObjectHealth, sli.object_index as ObjectIndex, o.key as ObjectKey, o.object_id as ObjectName, o.size as ObjectSize, o.mime_type as ObjectMimeType, o.created_at as ObjectModTime, o.etag as ObjectETag, sli.object_index, sli.offset as SliceOffset, sli.length as SliceLength, sla.id as SlabID, sla.health as SlabHealth, sla.key as SlabKey, sla.min_shards as SlabMinShards, bs.id IS NOT NULL AS SlabBuffered, sec.slab_index as SectorIndex, sec.root as SectorRoot, sec.latest_host as LatestHost, c.fcid as FCID, h.public_key as HostKey"). - Model(&dbObject{}). - Table("objects o"). - Joins("INNER JOIN buckets b ON o.db_bucket_id = b.id"). - Joins("LEFT JOIN slices sli ON o.id = sli.`db_object_id`"). - Joins("LEFT JOIN slabs sla ON sli.db_slab_id = sla.`id`"). - Joins("LEFT JOIN sectors sec ON sla.id = sec.`db_slab_id`"). - Joins("LEFT JOIN contract_sectors cs ON sec.id = cs.`db_sector_id`"). - Joins("LEFT JOIN contracts c ON cs.`db_contract_id` = c.`id`"). - Joins("LEFT JOIN hosts h ON c.host_id = h.id"). - Joins("LEFT JOIN buffered_slabs bs ON sla.db_buffered_slab_id = bs.`id`"). - Where("o.object_id = ? AND b.name = ?", path, bucket). - Order("sli.object_index ASC"). - Order("sec.slab_index ASC"). - Scan(&rows). - Error return } -// contract retrieves a contract from the store. -func (s *SQLStore) contract(ctx context.Context, id fileContractID) (dbContract, error) { - return contract(s.db.WithContext(ctx), id) -} - // PackedSlabsForUpload returns up to 'limit' packed slabs that are ready for // uploading. They are locked for 'lockingDuration' time before being handed out // again. 
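The hunks below also fix a subtle leak in MarkPackedSlabsUploaded: the old loop overwrote a single fileName on every iteration, so RemoveBuffers only ever deleted the last drained buffer file. Collecting one name per slab fixes that. A sketch of the corrected loop in isolation, with markUploaded and removeBuffers as assumed stand-ins:

```go
package stores

// flushUploaded marks each packed slab as uploaded and removes buffer files
// only after the whole batch has been persisted. Collecting all names, rather
// than overwriting one variable per iteration as the old code did, ensures no
// drained buffer file is left behind on disk.
func flushUploaded(slabs []string, markUploaded func(slab string) (fileName string, err error), removeBuffers func(fileNames ...string)) error {
	fileNames := make([]string, len(slabs))
	for i, slab := range slabs {
		fileName, err := markUploaded(slab)
		if err != nil {
			return err // nothing is removed; the enclosing transaction rolls back
		}
		fileNames[i] = fileName
	}
	removeBuffers(fileNames...)
	return nil
}
```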
func (s *SQLStore) PackedSlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set string, limit int) ([]api.PackedSlab, error) { - var contractSetID uint - if err := s.db.WithContext(ctx).Raw("SELECT id FROM contract_sets WHERE name = ?", set). - Scan(&contractSetID).Error; err != nil { - return nil, err - } - return s.slabBufferMgr.SlabsForUpload(ctx, lockingDuration, minShards, totalShards, contractSetID, limit) + return s.slabBufferMgr.SlabsForUpload(ctx, lockingDuration, minShards, totalShards, set, limit) } func (s *SQLStore) ObjectsBySlabKey(ctx context.Context, bucket string, slabKey object.EncryptionKey) (metadata []api.ObjectMetadata, err error) { - var rows []rawObjectMetadata - key, err := slabKey.MarshalBinary() - if err != nil { - return nil, err - } - - err = s.retryTransaction(ctx, func(tx *gorm.DB) error { - return tx.Raw(` -SELECT DISTINCT obj.object_id as ObjectName, obj.size as Size, obj.mime_type as MimeType, sla.health as Health -FROM slabs sla -INNER JOIN slices sli ON sli.db_slab_id = sla.id -INNER JOIN objects obj ON sli.db_object_id = obj.id -INNER JOIN buckets b ON obj.db_bucket_id = b.id AND b.name = ? -WHERE sla.key = ? - `, bucket, key). - Scan(&rows). - Error + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + metadata, err = tx.ObjectsBySlabKey(ctx, bucket, slabKey) + return err }) - if err != nil { - return nil, err - } - - // convert rows - for _, row := range rows { - metadata = append(metadata, row.convert()) - } return } @@ -1702,14 +681,15 @@ func (s *SQLStore) MarkPackedSlabsUploaded(ctx context.Context, slabs []api.Uplo } } } - var fileName string - err := s.retryTransaction(ctx, func(tx *gorm.DB) error { - for _, slab := range slabs { - var err error - fileName, err = s.markPackedSlabUploaded(tx, slab) + var fileNames []string + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + fileNames = make([]string, len(slabs)) + for i, slab := range slabs { + fileName, err := tx.MarkPackedSlabUploaded(ctx, slab) if err != nil { return err } + fileNames[i] = fileName } return nil }) @@ -1718,144 +698,10 @@ func (s *SQLStore) MarkPackedSlabsUploaded(ctx context.Context, slabs []api.Uplo } // Delete buffer from disk. - s.slabBufferMgr.RemoveBuffers(fileName) + s.slabBufferMgr.RemoveBuffers(fileNames...) return nil } -func (s *SQLStore) markPackedSlabUploaded(tx *gorm.DB, slab api.UploadedPackedSlab) (string, error) { - // collect all used contracts - usedContracts := slab.Contracts() - contracts, err := fetchUsedContracts(tx, usedContracts) - if err != nil { - return "", err - } - - // find the slab - var sla dbSlab - if err := tx.Where("db_buffered_slab_id", slab.BufferID). - Take(&sla).Error; err != nil { - return "", err - } - - // update the slab - if err := tx.Model(&dbSlab{}). - Where("id", sla.ID). - Updates(map[string]interface{}{ - "db_buffered_slab_id": nil, - }).Error; err != nil { - return "", fmt.Errorf("failed to set buffered slab NULL: %w", err) - } - - // delete buffer - var fileName string - if err := tx.Raw("SELECT filename FROM buffered_slabs WHERE id = ?", slab.BufferID). 
- Scan(&fileName).Error; err != nil { - return "", err - } - if err := tx.Exec("DELETE FROM buffered_slabs WHERE id = ?", slab.BufferID).Error; err != nil { - return "", err - } - - // add the shards to the slab - var shards []dbSector - for i := range slab.Shards { - sector := dbSector{ - DBSlabID: sla.ID, - SlabIndex: i + 1, - LatestHost: publicKey(slab.Shards[i].LatestHost), - Root: slab.Shards[i].Root[:], - } - for _, fcids := range slab.Shards[i].Contracts { - for _, fcid := range fcids { - if c, ok := contracts[fcid]; ok { - sector.Contracts = append(sector.Contracts, c) - } - } - } - shards = append(shards, sector) - } - if err := tx.Create(&shards).Error; err != nil { - return "", fmt.Errorf("failed to create shards: %w", err) - } - return fileName, nil -} - -// contract retrieves a contract from the store. -func contract(tx *gorm.DB, id fileContractID) (contract dbContract, err error) { - err = tx. - Where(&dbContract{ContractCommon: ContractCommon{FCID: id}}). - Joins("Host"). - Take(&contract). - Error - - if errors.Is(err, gorm.ErrRecordNotFound) { - err = api.ErrContractNotFound - } - return -} - -// contractsForHost retrieves all contracts for the given host -func contractsForHost(tx *gorm.DB, host dbHost) (contracts []dbContract, err error) { - err = tx. - Where(&dbContract{HostID: host.ID}). - Joins("Host"). - Find(&contracts). - Error - return -} - -func newContract(hostID uint, fcid, renewedFrom types.FileContractID, contractPrice, totalCost types.Currency, startHeight, windowStart, windowEnd, size uint64, state contractState) dbContract { - return dbContract{ - HostID: hostID, - ContractSets: nil, // new contract isn't in a set yet - - ContractCommon: ContractCommon{ - FCID: fileContractID(fcid), - RenewedFrom: fileContractID(renewedFrom), - - ContractPrice: currency(contractPrice), - State: state, - TotalCost: currency(totalCost), - RevisionNumber: "0", - Size: size, - StartHeight: startHeight, - WindowStart: windowStart, - WindowEnd: windowEnd, - - UploadSpending: zeroCurrency, - DownloadSpending: zeroCurrency, - FundAccountSpending: zeroCurrency, - DeleteSpending: zeroCurrency, - ListSpending: zeroCurrency, - }, - } -} - -// addContract adds a contract to the store. -func addContract(tx *gorm.DB, c rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state contractState) (dbContract, error) { - fcid := c.ID() - - // Find host. - var host dbHost - err := tx.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(c.HostKey())}). - Find(&host).Error - if err != nil { - return dbContract{}, err - } - - // Create contract. - contract := newContract(host.ID, fcid, renewedFrom, contractPrice, totalCost, startHeight, c.Revision.WindowStart, c.Revision.WindowEnd, c.Revision.Filesize, state) - - // Insert contract. - err = tx.Create(&contract).Error - if err != nil { - return dbContract{}, err - } - // Populate host. 
- contract.Host = host - return contract, nil -} - func (s *SQLStore) pruneSlabsLoop() { for { select { @@ -1868,7 +714,7 @@ func (s *SQLStore) pruneSlabsLoop() { pruneSuccess := true for { var deleted int64 - err := s.bMain.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { + err := s.db.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { var err error deleted, err = dt.PruneSlabs(s.shutdownCtx, slabPruningBatchSize) return err @@ -1896,7 +742,7 @@ func (s *SQLStore) pruneSlabsLoop() { } // prune dirs - err := s.bMain.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { + err := s.db.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { return dt.PruneEmptydirs(s.shutdownCtx) }) if err != nil { @@ -1935,7 +781,7 @@ func (s *SQLStore) triggerSlabPruning() { func (s *SQLStore) invalidateSlabHealthByFCID(ctx context.Context, fcids []types.FileContractID) error { for { var affected int64 - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { affected, err = tx.InvalidateSlabHealthByFCID(ctx, fcids, refreshHealthBatchSize) return }) @@ -1952,148 +798,9 @@ func (s *SQLStore) invalidateSlabHealthByFCID(ctx context.Context, fcids []types // a delimiter for now (see backend.go) but it would be interesting to have // arbitrary 'delim' support in ListObjects. func (s *SQLStore) ListObjects(ctx context.Context, bucket, prefix, sortBy, sortDir, marker string, limit int) (resp api.ObjectsListResponse, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { resp, err = tx.ListObjects(ctx, bucket, prefix, sortBy, sortDir, marker, limit) return err }) return } - -func (ss *SQLStore) processConsensusChangeContracts(cc modules.ConsensusChange) { - height := uint64(cc.InitialHeight()) - for _, sb := range cc.RevertedBlocks { - var b types.Block - convertToCore(sb, (*types.V1Block)(&b)) - - // revert contracts that got reorged to "pending". - for _, txn := range b.Transactions { - // handle contracts - for i := range txn.FileContracts { - fcid := txn.FileContractID(i) - if ss.isKnownContract(fcid) { - ss.unappliedContractState[fcid] = contractStatePending // revert from 'active' to 'pending' - ss.logger.Infow("contract state changed: active -> pending", - "fcid", fcid, - "reason", "contract reverted") - } - } - // handle contract revision - for _, rev := range txn.FileContractRevisions { - if ss.isKnownContract(rev.ParentID) { - if rev.RevisionNumber == math.MaxUint64 && rev.Filesize == 0 { - ss.unappliedContractState[rev.ParentID] = contractStateActive // revert from 'complete' to 'active' - ss.logger.Infow("contract state changed: complete -> active", - "fcid", rev.ParentID, - "reason", "final revision reverted") - } - } - } - // handle storage proof - for _, sp := range txn.StorageProofs { - if ss.isKnownContract(sp.ParentID) { - ss.unappliedContractState[sp.ParentID] = contractStateActive // revert from 'complete' to 'active' - ss.logger.Infow("contract state changed: complete -> active", - "fcid", sp.ParentID, - "reason", "storage proof reverted") - } - } - } - height-- - } - - for _, sb := range cc.AppliedBlocks { - var b types.Block - convertToCore(sb, (*types.V1Block)(&b)) - - // Update RevisionHeight and RevisionNumber for our contracts. 
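For context, the subscriber code being removed here encodes the contract lifecycle the store tracks on-chain: pending to active on confirmation, active to complete on a final revision or storage proof, with mirror-image transitions on reorg. Condensed into a sketch (event names are informal labels, not identifiers from the codebase):

```go
package stores

// contractEvent enumerates the chain events the removed subscriber reacted to.
type contractEvent int

const (
	contractConfirmed      contractEvent = iota // formation txn applied
	finalRevisionConfirmed                      // revision with max number and size 0
	storageProofConfirmed
	contractReverted      // formation txn reorged out
	finalRevisionReverted // final revision reorged out
	storageProofReverted
)

// nextState condenses the transitions logged by the removed code.
func nextState(cur string, ev contractEvent) string {
	switch ev {
	case contractConfirmed:
		return "active" // pending -> active
	case finalRevisionConfirmed, storageProofConfirmed:
		return "complete" // active -> complete
	case contractReverted:
		return "pending" // active -> pending
	case finalRevisionReverted, storageProofReverted:
		return "active" // complete -> active
	}
	return cur
}
```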
- for _, txn := range b.Transactions { - // handle contracts - for i := range txn.FileContracts { - fcid := txn.FileContractID(i) - if ss.isKnownContract(fcid) { - ss.unappliedContractState[fcid] = contractStateActive // 'pending' -> 'active' - ss.logger.Infow("contract state changed: pending -> active", - "fcid", fcid, - "reason", "contract confirmed") - } - } - // handle contract revision - for _, rev := range txn.FileContractRevisions { - if ss.isKnownContract(rev.ParentID) { - ss.unappliedRevisions[types.FileContractID(rev.ParentID)] = revisionUpdate{ - height: height, - number: rev.RevisionNumber, - size: rev.Filesize, - } - if rev.RevisionNumber == math.MaxUint64 && rev.Filesize == 0 { - ss.unappliedContractState[rev.ParentID] = contractStateComplete // renewed: 'active' -> 'complete' - ss.logger.Infow("contract state changed: active -> complete", - "fcid", rev.ParentID, - "reason", "final revision confirmed") - } - } - } - // handle storage proof - for _, sp := range txn.StorageProofs { - if ss.isKnownContract(sp.ParentID) { - ss.unappliedProofs[sp.ParentID] = height - ss.unappliedContractState[sp.ParentID] = contractStateComplete // storage proof: 'active' -> 'complete' - ss.logger.Infow("contract state changed: active -> complete", - "fcid", sp.ParentID, - "reason", "storage proof confirmed") - } - } - } - height++ - } -} - -func validateSort(sortBy, sortDir string) error { - allowed := func(s string, allowed ...string) bool { - for _, a := range allowed { - if strings.EqualFold(s, a) { - return true - } - } - return false - } - - if !allowed(sortDir, "", api.ObjectSortDirAsc, api.ObjectSortDirDesc) { - return fmt.Errorf("invalid dir '%v', allowed values are '%v' and '%v'; %w", sortDir, api.ObjectSortDirAsc, api.ObjectSortDirDesc, api.ErrInvalidObjectSortParameters) - } - - if !allowed(sortBy, "", api.ObjectSortByHealth, api.ObjectSortByName, api.ObjectSortBySize) { - return fmt.Errorf("invalid sort by '%v', allowed values are '%v', '%v' and '%v'; %w", sortBy, api.ObjectSortByHealth, api.ObjectSortByName, api.ObjectSortBySize, api.ErrInvalidObjectSortParameters) - } - return nil -} - -// upsertSectors creates a sector or updates it if it exists already. The -// resulting ID is set on the input sector. -func upsertSectors(tx *gorm.DB, sectors []dbSector) ([]uint, error) { - if len(sectors) == 0 { - return nil, nil // nothing to do - } - err := tx. - Clauses(clause.OnConflict{ - UpdateAll: true, - Columns: []clause.Column{{Name: "root"}}, - }). - CreateInBatches(§ors, sectorInsertionBatchSize). - Error - if err != nil { - return nil, err - } - - sectorIDs := make([]uint, len(sectors)) - for i := range sectors { - var id uint - if err := tx.Model(dbSector{}). - Where("root", sectors[i].Root). 
- Select("id").Take(&id).Error; err != nil { - return nil, err - } - sectorIDs[i] = id - } - return sectorIDs, nil -} diff --git a/stores/metadata_test.go b/stores/metadata_test.go index 7462f6187..6d0639e5d 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -3,12 +3,12 @@ package stores import ( "bytes" "context" + dsql "database/sql" "encoding/hex" "errors" "fmt" "os" "reflect" - "sort" "strings" "sync" "testing" @@ -23,13 +23,28 @@ import ( "go.sia.tech/renterd/internal/test" "go.sia.tech/renterd/object" sql "go.sia.tech/renterd/stores/sql" - "gorm.io/gorm" - "gorm.io/gorm/schema" "lukechampine.com/frand" ) +func (s *testSQLStore) InsertSlab(slab object.Slab) { + s.t.Helper() + obj := object.Object{ + Key: object.GenerateEncryptionKey(), + Slabs: object.SlabSlices{ + object.SlabSlice{ + Slab: slab, + }, + }, + } + err := s.UpdateObject(context.Background(), api.DefaultBucketName, hex.EncodeToString(frand.Bytes(16)), testContractSet, "", "", api.ObjectUserMetadata{}, obj) + if err != nil { + s.t.Fatal(err) + } +} + func (s *SQLStore) RemoveObjectBlocking(ctx context.Context, bucket, path string) error { ts := time.Now() + time.Sleep(time.Millisecond) if err := s.RemoveObject(ctx, bucket, path); err != nil { return err } @@ -38,6 +53,7 @@ func (s *SQLStore) RemoveObjectBlocking(ctx context.Context, bucket, path string func (s *SQLStore) RemoveObjectsBlocking(ctx context.Context, bucket, prefix string) error { ts := time.Now() + time.Sleep(time.Millisecond) if err := s.RemoveObjects(ctx, bucket, prefix); err != nil { return err } @@ -46,6 +62,7 @@ func (s *SQLStore) RemoveObjectsBlocking(ctx context.Context, bucket, prefix str func (s *SQLStore) RenameObjectBlocking(ctx context.Context, bucket, keyOld, keyNew string, force bool) error { ts := time.Now() + time.Sleep(time.Millisecond) if err := s.RenameObject(ctx, bucket, keyOld, keyNew, force); err != nil { return err } @@ -54,6 +71,7 @@ func (s *SQLStore) RenameObjectBlocking(ctx context.Context, bucket, keyOld, key func (s *SQLStore) RenameObjectsBlocking(ctx context.Context, bucket, prefixOld, prefixNew string, force bool) error { ts := time.Now() + time.Sleep(time.Millisecond) if err := s.RenameObjects(ctx, bucket, prefixOld, prefixNew, force); err != nil { return err } @@ -65,6 +83,7 @@ func (s *SQLStore) UpdateObjectBlocking(ctx context.Context, bucket, path, contr _, err := s.Object(ctx, bucket, path) if err == nil { ts = time.Now() + time.Sleep(time.Millisecond) } if err := s.UpdateObject(ctx, bucket, path, contractSet, eTag, mimeType, metadata, o); err != nil { return err @@ -76,7 +95,7 @@ func (s *SQLStore) waitForPruneLoop(ts time.Time) error { return test.Retry(100, 100*time.Millisecond, func() error { s.mu.Lock() defer s.mu.Unlock() - if s.lastPrunedAt.Before(ts) { + if !s.lastPrunedAt.After(ts) { return errors.New("slabs have not been pruned yet") } return nil @@ -95,15 +114,16 @@ func randomMultisigUC() types.UnlockConditions { return uc } -func updateAllObjectsHealth(tx *gorm.DB) error { - return tx.Exec(` +func updateAllObjectsHealth(db *isql.DB) error { + _, err := db.Exec(context.Background(), ` UPDATE objects SET health = ( SELECT COALESCE(MIN(slabs.health), 1) FROM slabs INNER JOIN slices sli ON sli.db_slab_id = slabs.id WHERE sli.db_object_id = objects.id) -`).Error +`) + return err } // TestObjectBasic tests the hydration of raw objects works when we fetch @@ -159,17 +179,13 @@ func TestObjectBasic(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(*got.Object, want) { - t.Fatal("object 
mismatch", cmp.Diff(got.Object, want)) + t.Fatal("object mismatch", got.Object, want) } - // delete a sector - var sectors []dbSector - if err := ss.db.Find(§ors).Error; err != nil { - t.Fatal(err) - } else if len(sectors) != 2 { - t.Fatal("unexpected number of sectors") - } else if tx := ss.db.Delete(sectors[0]); tx.Error != nil || tx.RowsAffected != 1 { - t.Fatal("unexpected number of sectors deleted", tx.Error, tx.RowsAffected) + // update the sector to have a non-consecutive slab index + _, err = ss.DB().Exec(context.Background(), "UPDATE sectors SET slab_index = 100 WHERE slab_index = 1") + if err != nil { + t.Fatalf("failed to update sector: %v", err) } // fetch the object again and assert we receive an indication it was corrupted @@ -256,10 +272,7 @@ func TestObjectMetadata(t *testing.T) { } // assert metadata CASCADE on object delete - var cnt int64 - if err := ss.db.Model(&dbObjectUserMetadata{}).Count(&cnt).Error; err != nil { - t.Fatal(err) - } else if cnt != 2 { + if cnt := ss.Count("object_user_metadata"); cnt != 2 { t.Fatal("unexpected number of metadata entries", cnt) } @@ -269,9 +282,7 @@ func TestObjectMetadata(t *testing.T) { } // assert records are gone - if err := ss.db.Model(&dbObjectUserMetadata{}).Count(&cnt).Error; err != nil { - t.Fatal(err) - } else if cnt != 0 { + if cnt := ss.Count("object_user_metadata"); cnt != 0 { t.Fatal("unexpected number of metadata entries", cnt) } } @@ -289,8 +300,7 @@ func TestSQLContractStore(t *testing.T) { } // Add an announcement. - err = ss.insertTestAnnouncement(hk, newTestHostDBAnnouncement("address")) - if err != nil { + if err := ss.announceHost(hk, "address"); err != nil { t.Fatal(err) } @@ -385,7 +395,7 @@ func TestSQLContractStore(t *testing.T) { Size: c.Revision.Filesize, } if !reflect.DeepEqual(returned, expected) { - t.Fatal("contract mismatch") + t.Fatal("contract mismatch", cmp.Diff(returned, expected)) } // Look it up again. @@ -456,70 +466,16 @@ func TestSQLContractStore(t *testing.T) { } // Make sure the db was cleaned up properly through the CASCADE delete. - tableCountCheck := func(table interface{}, tblCount int64) error { - var count int64 - if err := ss.db.Model(table).Count(&count).Error; err != nil { - return err - } - if count != tblCount { - return fmt.Errorf("expected %v objects in table %v but got %v", tblCount, table.(schema.Tabler).TableName(), count) - } - return nil - } - if err := tableCountCheck(&dbContract{}, 0); err != nil { - t.Fatal(err) + if count := ss.Count("contracts"); count != 0 { + t.Fatalf("expected %v rows in contracts but got %v", 0, count) } // Check join table count as well. - var count int64 - if err := ss.db.Table("contract_sectors").Count(&count).Error; err != nil { - t.Fatal(err) - } - if count != 0 { + if count := ss.Count("contract_sectors"); count != 0 { t.Fatalf("expected %v objects in contract_sectors but got %v", 0, count) } } -func TestContractsForHost(t *testing.T) { - // create a SQL store - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - // add 2 hosts - hks, err := ss.addTestHosts(2) - if err != nil { - t.Fatal(err) - } - - // add 2 contracts - _, _, err = ss.addTestContracts(hks) - if err != nil { - t.Fatal(err) - } - - // fetch raw hosts - var hosts []dbHost - if err := ss.db. - Model(&dbHost{}). - Find(&hosts). 
- Error; err != nil { - t.Fatal(err) - } - if len(hosts) != 2 { - t.Fatal("unexpected number of hosts") - } - - contracts, _ := contractsForHost(ss.db, hosts[0]) - if len(contracts) != 1 || types.PublicKey(contracts[0].Host.PublicKey).String() != types.PublicKey(hosts[0].PublicKey).String() { - t.Fatal("unexpected", len(contracts), contracts) - } - - contracts, _ = contractsForHost(ss.db, hosts[1]) - if len(contracts) != 1 || types.PublicKey(contracts[0].Host.PublicKey).String() != types.PublicKey(hosts[1].PublicKey).String() { - t.Fatalf("unexpected contracts, %+v", contracts) - } -} - // TestContractRoots tests the ContractRoots function on the store. func TestContractRoots(t *testing.T) { // create a SQL store @@ -580,12 +536,10 @@ func TestRenewedContract(t *testing.T) { hk, hk2 := hks[0], hks[1] // Add announcements. - err = ss.insertTestAnnouncement(hk, newTestHostDBAnnouncement("address")) - if err != nil { + if err := ss.announceHost(hk, "address"); err != nil { t.Fatal(err) } - err = ss.insertTestAnnouncement(hk2, newTestHostDBAnnouncement("address2")) - if err != nil { + if err := ss.announceHost(hk2, "address2"); err != nil { t.Fatal(err) } @@ -700,6 +654,7 @@ func TestRenewedContract(t *testing.T) { ParentID: fcid1Renewed, UnlockConditions: uc, FileContract: types.FileContract{ + Filesize: 2 * rhpv2.SectorSize, MissedProofOutputs: []types.SiacoinOutput{}, ValidProofOutputs: []types.SiacoinOutput{}, }, @@ -759,7 +714,7 @@ func TestRenewedContract(t *testing.T) { HostKey: hk, StartHeight: newContractStartHeight, RenewedFrom: fcid1, - Size: rhpv2.SectorSize, + Size: 2 * rhpv2.SectorSize, State: api.ContractStatePending, Spending: api.ContractSpending{ Uploads: types.ZeroCurrency, @@ -767,6 +722,7 @@ func TestRenewedContract(t *testing.T) { FundAccount: types.ZeroCurrency, }, ContractPrice: types.NewCurrency64(2), + ContractSets: []string{"test"}, TotalCost: newContractTotal, } if !reflect.DeepEqual(newContract, expected) { @@ -774,42 +730,41 @@ func TestRenewedContract(t *testing.T) { } // Archived contract should exist. - var ac dbArchivedContract - err = ss.db.Model(&dbArchivedContract{}). - Where("fcid", fileContractID(fcid1)). - Take(&ac). 
-		Error
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	ac.Model = Model{}
-	expectedContract := dbArchivedContract{
-		Host:      publicKey(c.HostKey()),
-		RenewedTo: fileContractID(fcid1Renewed),
-		Reason:    api.ContractArchivalReasonRenewed,
-
-		ContractCommon: ContractCommon{
-			FCID: fileContractID(fcid1),
-
-			ContractPrice:  currency(oldContractPrice),
-			TotalCost:      currency(oldContractTotal),
-			ProofHeight:    0,
-			RevisionHeight: 0,
-			RevisionNumber: "1",
-			StartHeight:    100,
-			WindowStart:    2,
-			WindowEnd:      3,
-			Size:           rhpv2.SectorSize,
-			State:          contractStatePending,
-
-			UploadSpending:      currency(types.Siacoins(1)),
-			DownloadSpending:    currency(types.Siacoins(2)),
-			FundAccountSpending: currency(types.Siacoins(3)),
-			DeleteSpending:      currency(types.Siacoins(4)),
-			ListSpending:        currency(types.Siacoins(5)),
+	ancestors, err := ss.AncestorContracts(context.Background(), fcid1Renewed, 0)
+	if err != nil {
+		t.Fatal(err)
+	} else if len(ancestors) != 1 {
+		t.Fatalf("expected 1 ancestor but got %v", len(ancestors))
+	}
+	ac := ancestors[0]
+
+	expectedContract := api.ArchivedContract{
+		ID:        fcid1,
+		HostIP:    "address",
+		HostKey:   c.HostKey(),
+		RenewedTo: fcid1Renewed,
+		Spending: api.ContractSpending{
+			Uploads:     types.Siacoins(1),
+			Downloads:   types.Siacoins(2),
+			FundAccount: types.Siacoins(3),
+			Deletions:   types.Siacoins(4),
+			SectorRoots: types.ZeroCurrency, // currently not persisted
 		},
+
+		ArchivalReason: api.ContractArchivalReasonRenewed,
+		ContractPrice:  oldContractPrice,
+		ProofHeight:    0,
+		RenewedFrom:    types.FileContractID{},
+		RevisionHeight: 0,
+		RevisionNumber: 1,
+		Size:           rhpv2.SectorSize,
+		StartHeight:    100,
+		State:          api.ContractStatePending,
+		TotalCost:      oldContractTotal,
+		WindowStart:    2,
+		WindowEnd:      3,
 	}
+
 	if !reflect.DeepEqual(ac, expectedContract) {
 		t.Fatal("mismatch", cmp.Diff(ac, expectedContract))
 	}
@@ -851,9 +806,9 @@ func TestAncestorsContracts(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	// Create a chain of 4 contracts.
-	// Their start heights are 0, 1, 2, 3.
-	fcids := []types.FileContractID{{1}, {2}, {3}, {4}}
+	// Create a chain of 6 contracts.
+	// Their start heights are 0, 1, 2, 3, 4, 5.
+	fcids := []types.FileContractID{{1}, {2}, {3}, {4}, {5}, {6}}
 	if _, err := ss.addTestContract(fcids[0], hk); err != nil {
 		t.Fatal(err)
 	}
@@ -873,22 +828,40 @@ func TestAncestorsContracts(t *testing.T) {
 	if len(contracts) != len(fcids)-2 {
 		t.Fatal("wrong number of contracts returned", len(contracts))
 	}
-	for i := 0; i < len(contracts)-1; i++ {
+	for i := 0; i < len(contracts); i++ {
+		var renewedFrom, renewedTo types.FileContractID
+		if j := len(fcids) - 3 - i; j >= 0 {
+			renewedFrom = fcids[j]
+		}
+		if j := len(fcids) - 1 - i; j >= 0 {
+			renewedTo = fcids[j]
+		}
 		expected := api.ArchivedContract{
-			ID:          fcids[len(fcids)-2-i],
-			HostKey:     hk,
-			RenewedTo:   fcids[len(fcids)-1-i],
-			StartHeight: 2,
-			Size:        4096,
-			State:       api.ContractStatePending,
-			WindowStart: 400,
-			WindowEnd:   500,
+			ArchivalReason: api.ContractArchivalReasonRenewed,
+			ID:             fcids[len(fcids)-2-i],
+			HostKey:        hk,
+			RenewedFrom:    renewedFrom,
+			RenewedTo:      renewedTo,
+			RevisionNumber: 200,
+			StartHeight:    uint64(len(fcids) - 2 - i),
+			Size:           4096,
+			State:          api.ContractStatePending,
+			WindowStart:    400,
+			WindowEnd:      500,
 		}
 		if !reflect.DeepEqual(contracts[i], expected) {
 			t.Log(cmp.Diff(contracts[i], expected))
 			t.Fatal("wrong contract", i, contracts[i])
		}
 	}
+
+	// Fetch the ancestors with startHeight >= 5. That should return 0 contracts.
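+	// (the ancestors have start heights 0 through 4, so none of them qualify)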
+ contracts, err = ss.AncestorContracts(context.Background(), fcids[len(fcids)-1], 5) + if err != nil { + t.Fatal(err) + } else if len(contracts) != 0 { + t.Fatalf("should have 0 contracts but got %v", len(contracts)) + } } func TestArchiveContracts(t *testing.T) { @@ -926,22 +899,30 @@ func TestArchiveContracts(t *testing.T) { } // assert the two others were archived - ffcids := make([]fileContractID, 2) - ffcids[0] = fileContractID(fcids[1]) - ffcids[1] = fileContractID(fcids[2]) - var acs []dbArchivedContract - err = ss.db.Model(&dbArchivedContract{}). - Where("fcid IN (?)", ffcids). - Find(&acs). - Error + ffcids := make([]sql.FileContractID, 2) + ffcids[0] = sql.FileContractID(fcids[1]) + ffcids[1] = sql.FileContractID(fcids[2]) + rows, err := ss.DB().Query(context.Background(), "SELECT reason FROM archived_contracts WHERE fcid IN (?, ?)", + sql.FileContractID(ffcids[0]), sql.FileContractID(ffcids[1])) if err != nil { t.Fatal(err) } - if len(acs) != 2 { - t.Fatal("wrong number of archived contracts", len(acs)) + defer rows.Close() + + var cnt int + for rows.Next() { + var reason string + if err := rows.Scan(&reason); err != nil { + t.Fatal(err) + } else if cnt == 0 && reason != "foo" { + t.Fatal("unexpected reason", reason) + } else if cnt == 1 && reason != "bar" { + t.Fatal("unexpected reason", reason) + } + cnt++ } - if acs[0].Reason != "foo" || acs[1].Reason != "bar" { - t.Fatal("unexpected reason", acs[0].Reason, acs[1].Reason) + if cnt != 2 { + t.Fatal("wrong number of archived contracts", cnt) } } @@ -1051,61 +1032,74 @@ func TestSQLMetadataStore(t *testing.T) { } // Fetch it using get and verify every field. - obj, err := ss.dbObject(objID) + obj, err := ss.Object(context.Background(), api.DefaultBucketName, objID) if err != nil { t.Fatal(err) } - obj1Key, err := obj1.Key.MarshalBinary() - if err != nil { - t.Fatal(err) - } - obj1Slab0Key, err := obj1.Slabs[0].Key.MarshalBinary() - if err != nil { - t.Fatal(err) - } - obj1Slab1Key, err := obj1.Slabs[1].Key.MarshalBinary() - if err != nil { - t.Fatal(err) + // compare timestamp separately + if obj.ModTime.IsZero() { + t.Fatal("unexpected", obj.ModTime) } + obj.ModTime = api.TimeRFC3339{} - // Set the Model fields to zero before comparing. These are set by gorm - // itself and contain a few timestamps which would make the following - // code a lot more verbose. 
- obj.Model = Model{} - for i := range obj.Slabs { - obj.Slabs[i].Model = Model{} - } - - one := uint(1) - expectedObj := dbObject{ - DBDirectoryID: 1, - DBBucketID: ss.DefaultBucketID(), - Health: 1, - ObjectID: objID, - Key: obj1Key, - Size: obj1.TotalSize(), - Slabs: []dbSlice{ - { - DBObjectID: &one, - DBSlabID: 1, - ObjectIndex: 1, - Offset: 10, - Length: 100, - }, - { - DBObjectID: &one, - DBSlabID: 2, - ObjectIndex: 2, - Offset: 20, - Length: 200, + obj1Slab0Key := obj1.Slabs[0].Key + obj1Slab1Key := obj1.Slabs[1].Key + + expectedObj := api.Object{ + ObjectMetadata: api.ObjectMetadata{ + ETag: testETag, + Health: 1, + ModTime: api.TimeRFC3339{}, + Name: objID, + Size: obj1.TotalSize(), + MimeType: testMimeType, + }, + Metadata: testMetadata, + Object: &object.Object{ + Key: obj1.Key, + Slabs: []object.SlabSlice{ + { + Offset: 10, + Length: 100, + Slab: object.Slab{ + Health: 1, + Key: obj1Slab0Key, + MinShards: 1, + Shards: []object.Sector{ + { + LatestHost: hk1, + Root: types.Hash256{1}, + Contracts: map[types.PublicKey][]types.FileContractID{ + hk1: {fcid1}, + }, + }, + }, + }, + }, + { + Offset: 20, + Length: 200, + Slab: object.Slab{ + Health: 1, + Key: obj1Slab1Key, + MinShards: 2, + Shards: []object.Sector{ + { + LatestHost: hk2, + Root: types.Hash256{2}, + Contracts: map[types.PublicKey][]types.FileContractID{ + hk2: {fcid2}, + }, + }, + }, + }, + }, }, }, - MimeType: testMimeType, - Etag: testETag, } if !reflect.DeepEqual(obj, expectedObj) { - t.Fatal("object mismatch", cmp.Diff(obj, expectedObj)) + t.Fatal("object mismatch", cmp.Diff(obj, expectedObj, cmp.AllowUnexported(object.EncryptionKey{}), cmp.Comparer(api.CompareTimeRFC3339))) } // Try to store it again. Should work. @@ -1114,28 +1108,20 @@ func TestSQLMetadataStore(t *testing.T) { } // Fetch it again and verify. - obj, err = ss.dbObject(objID) + obj, err = ss.Object(context.Background(), api.DefaultBucketName, objID) if err != nil { t.Fatal(err) } - // Set the Model fields to zero before comparing. These are set by gorm - // itself and contain a few timestamps which would make the following - // code a lot more verbose. - obj.Model = Model{} - for i := range obj.Slabs { - obj.Slabs[i].Model = Model{} + // compare timestamp separately + if obj.ModTime.IsZero() { + t.Fatal("unexpected", obj.ModTime) } + obj.ModTime = api.TimeRFC3339{} - // The expected object is the same except for some ids which were - // incremented due to the object and slab being overwritten. - two := uint(2) - expectedObj.Slabs[0].DBObjectID = &two - expectedObj.Slabs[0].DBSlabID = 1 - expectedObj.Slabs[1].DBObjectID = &two - expectedObj.Slabs[1].DBSlabID = 2 + // The expected object is the same. if !reflect.DeepEqual(obj, expectedObj) { - t.Fatal("object mismatch", cmp.Diff(obj, expectedObj)) + t.Fatal("object mismatch", cmp.Diff(obj, expectedObj, cmp.AllowUnexported(object.EncryptionKey{}), cmp.Comparer(api.CompareTimeRFC3339))) } // Fetch it and verify again. 
@@ -1147,102 +1133,100 @@ func TestSQLMetadataStore(t *testing.T) { t.Fatal("object mismatch", cmp.Diff(fullObj, obj1)) } - expectedObjSlab1 := dbSlab{ - DBContractSetID: 1, - Health: 1, - Key: obj1Slab0Key, - MinShards: 1, - TotalShards: 1, - Shards: []dbSector{ + expectedObjSlab1 := object.Slab{ + Health: 1, + Key: obj1Slab0Key, + MinShards: 1, + Shards: []object.Sector{ { - DBSlabID: 1, - SlabIndex: 1, - Root: obj1.Slabs[0].Shards[0].Root[:], - LatestHost: publicKey(obj1.Slabs[0].Shards[0].LatestHost), - Contracts: []dbContract{ - { - HostID: 1, - Host: dbHost{ - PublicKey: publicKey(hk1), - }, - - ContractCommon: ContractCommon{ - FCID: fileContractID(fcid1), - - TotalCost: currency(totalCost1), - RevisionNumber: "0", - StartHeight: startHeight1, - WindowStart: 400, - WindowEnd: 500, - Size: 4096, - State: contractStatePending, - - UploadSpending: zeroCurrency, - DownloadSpending: zeroCurrency, - FundAccountSpending: zeroCurrency, - }, - }, + Contracts: map[types.PublicKey][]types.FileContractID{ + hk1: {fcid1}, }, + LatestHost: hk1, + Root: types.Hash256{1}, }, }, } - expectedObjSlab2 := dbSlab{ - DBContractSetID: 1, - Health: 1, - Key: obj1Slab1Key, - MinShards: 2, - TotalShards: 1, - Shards: []dbSector{ + expectedContract1 := api.ContractMetadata{ + ID: fcid1, + HostIP: "", + HostKey: hk1, + SiamuxAddr: "", + ProofHeight: 0, + RevisionHeight: 0, + RevisionNumber: 0, + Size: 4096, + StartHeight: startHeight1, + State: api.ContractStatePending, + WindowStart: 400, + WindowEnd: 500, + ContractPrice: types.ZeroCurrency, + RenewedFrom: types.FileContractID{}, + Spending: api.ContractSpending{ + Uploads: types.ZeroCurrency, + Downloads: types.ZeroCurrency, + FundAccount: types.ZeroCurrency, + }, + TotalCost: totalCost1, + ContractSets: nil, + } + + expectedObjSlab2 := object.Slab{ + Health: 1, + Key: obj1Slab1Key, + MinShards: 2, + Shards: []object.Sector{ { - DBSlabID: 2, - SlabIndex: 1, - Root: obj1.Slabs[1].Shards[0].Root[:], - LatestHost: publicKey(obj1.Slabs[1].Shards[0].LatestHost), - Contracts: []dbContract{ - { - HostID: 2, - Host: dbHost{ - PublicKey: publicKey(hk2), - }, - ContractCommon: ContractCommon{ - FCID: fileContractID(fcid2), - - TotalCost: currency(totalCost2), - RevisionNumber: "0", - StartHeight: startHeight2, - WindowStart: 400, - WindowEnd: 500, - Size: 4096, - State: contractStatePending, - - UploadSpending: zeroCurrency, - DownloadSpending: zeroCurrency, - FundAccountSpending: zeroCurrency, - }, - }, + Contracts: map[types.PublicKey][]types.FileContractID{ + hk2: {fcid2}, }, + LatestHost: hk2, + Root: types.Hash256{2}, }, }, } + expectedContract2 := api.ContractMetadata{ + ID: fcid2, + HostIP: "", + HostKey: hk2, + SiamuxAddr: "", + ProofHeight: 0, + RevisionHeight: 0, + RevisionNumber: 0, + Size: 4096, + StartHeight: startHeight2, + State: api.ContractStatePending, + WindowStart: 400, + WindowEnd: 500, + ContractPrice: types.ZeroCurrency, + RenewedFrom: types.FileContractID{}, + Spending: api.ContractSpending{ + Uploads: types.ZeroCurrency, + Downloads: types.ZeroCurrency, + FundAccount: types.ZeroCurrency, + }, + TotalCost: totalCost2, + ContractSets: nil, + } + // Compare slabs. 
- slab1, err := ss.dbSlab(obj1Slab0Key) + slab1, err := ss.Slab(context.Background(), obj1Slab0Key) + if err != nil { + t.Fatal(err) + } + contract1, err := ss.Contract(context.Background(), fcid1) if err != nil { t.Fatal(err) } - slab2, err := ss.dbSlab(obj1Slab1Key) + slab2, err := ss.Slab(context.Background(), obj1Slab1Key) if err != nil { t.Fatal(err) } - slabs := []*dbSlab{&slab1, &slab2} - for i := range slabs { - slabs[i].Model = Model{} - slabs[i].Shards[0].Model = Model{} - slabs[i].Shards[0].Contracts[0].Model = Model{} - slabs[i].Shards[0].Contracts[0].Host.Model = Model{} - slabs[i].Shards[0].Contracts[0].Host.LastAnnouncement = time.Time{} - slabs[i].HealthValidUntil = 0 + contract2, err := ss.Contract(context.Background(), fcid2) + if err != nil { + t.Fatal(err) } if !reflect.DeepEqual(slab1, expectedObjSlab1) { t.Fatal("mismatch", cmp.Diff(slab1, expectedObjSlab1)) @@ -1250,6 +1234,12 @@ func TestSQLMetadataStore(t *testing.T) { if !reflect.DeepEqual(slab2, expectedObjSlab2) { t.Fatal("mismatch", cmp.Diff(slab2, expectedObjSlab2)) } + if !reflect.DeepEqual(contract1, expectedContract1) { + t.Fatal("mismatch", cmp.Diff(contract1, expectedContract1)) + } + if !reflect.DeepEqual(contract2, expectedContract2) { + t.Fatal("mismatch", cmp.Diff(contract2, expectedContract2)) + } // Remove the first slab of the object. obj1.Slabs = obj1.Slabs[1:] @@ -1266,27 +1256,23 @@ func TestSQLMetadataStore(t *testing.T) { // - 1 element in the slices table for the same reason // - 1 element in the sectors table for the same reason countCheck := func(objCount, sliceCount, slabCount, sectorCount int64) error { - tableCountCheck := func(table interface{}, tblCount int64) error { - var count int64 - if err := ss.db.Model(table).Count(&count).Error; err != nil { - return err - } - if count != tblCount { - return fmt.Errorf("expected %v objects in table %v but got %v", tblCount, table.(schema.Tabler).TableName(), count) + tableCountCheck := func(table string, tblCount int64) error { + if count := ss.Count(table); count != tblCount { + return fmt.Errorf("expected %v objects in table %v but got %v", tblCount, table, count) } return nil } // Check all tables. 
-		if err := tableCountCheck(&dbObject{}, objCount); err != nil {
+		if err := tableCountCheck("objects", objCount); err != nil {
 			return err
 		}
-		if err := tableCountCheck(&dbSlice{}, sliceCount); err != nil {
+		if err := tableCountCheck("slices", sliceCount); err != nil {
 			return err
 		}
-		if err := tableCountCheck(&dbSlab{}, slabCount); err != nil {
+		if err := tableCountCheck("slabs", slabCount); err != nil {
 			return err
 		}
-		if err := tableCountCheck(&dbSector{}, sectorCount); err != nil {
+		if err := tableCountCheck("sectors", sectorCount); err != nil {
 			return err
 		}
 		return nil
@@ -1511,7 +1497,7 @@ func TestObjectEntries(t *testing.T) {
 	}
 
 	// update health of objects to match the overridden health of the slabs
-	if err := updateAllObjectsHealth(ss.db); err != nil {
+	if err := updateAllObjectsHealth(ss.DB()); err != nil {
 		t.Fatal()
 	}
 
@@ -1593,6 +1579,67 @@ func TestObjectEntries(t *testing.T) {
 	}
 }
 
+func TestObjectEntriesExplicitDir(t *testing.T) {
+	ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
+	defer ss.Close()
+
+	objects := []struct {
+		path string
+		size int64
+	}{
+		{"/dir/", 0},     // empty dir - created first
+		{"/dir/file", 1}, // file uploaded to dir
+		{"/dir2/", 2},    // empty dir - remains empty
+	}
+
+	ctx := context.Background()
+	for _, o := range objects {
+		obj := newTestObject(frand.Intn(9) + 1)
+		obj.Slabs = obj.Slabs[:1]
+		obj.Slabs[0].Length = uint32(o.size)
+		_, err := ss.addTestObject(o.path, obj)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// set file health to 0.5
+	if err := ss.overrideSlabHealth("/dir/file", 0.5); err != nil {
+		t.Fatal(err)
+	}
+
+	// update health of objects to match the overridden health of the slabs
+	if err := updateAllObjectsHealth(ss.DB()); err != nil {
+		t.Fatal(err)
+	}
+
+	tests := []struct {
+		path    string
+		prefix  string
+		sortBy  string
+		sortDir string
+		want    []api.ObjectMetadata
+	}{
+		{"/", "", "", "", []api.ObjectMetadata{
+			{Name: "/dir/", Size: 1, Health: 0.5},
+			{ETag: "d34db33f", Name: "/dir2/", Size: 2, Health: 1, MimeType: testMimeType}, // has MimeType and ETag since it's a file
+		}},
+		{"/dir/", "", "", "", []api.ObjectMetadata{{ETag: "d34db33f", Name: "/dir/file", Size: 1, Health: 0.5, MimeType: testMimeType}}},
+	}
+	for _, test := range tests {
+		got, _, err := ss.ObjectEntries(ctx, api.DefaultBucketName, test.path, test.prefix, test.sortBy, test.sortDir, "", 0, -1)
+		if err != nil {
+			t.Fatal(err)
+		}
+		for i := range got {
+			got[i].ModTime = api.TimeRFC3339{} // ignore time for comparison
		}
+		if !reflect.DeepEqual(got, test.want) {
+			t.Fatalf("\nlist: %v\nprefix: %v\ngot: %v\nwant: %v", test.path, test.prefix, got, test.want)
+		}
+	}
+}
+
 // TestSearchObjects is a test for the SearchObjects method.
 func TestSearchObjects(t *testing.T) {
 	ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
@@ -1773,6 +1820,12 @@ func TestUnhealthySlabs(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	// add a partial slab
+	_, _, err = ss.AddPartialSlab(context.Background(), []byte{1, 2, 3}, 1, 3, testContractSet)
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	if err := ss.RefreshHealth(context.Background()); err != nil {
 		t.Fatal(err)
 	}
@@ -1946,7 +1999,7 @@ func TestUnhealthySlabsNoContracts(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if err := ss.db.Table("contract_sectors").Where("TRUE").Delete(&dbContractSector{}).Error; err != nil {
+	if _, err := ss.DB().Exec(context.Background(), "DELETE FROM contract_sectors"); err != nil {
 		t.Fatal(err)
 	}
 
@@ -2084,12 +2137,8 @@ func TestContractSectors(t *testing.T) {
 	}
 
 	// Check the join table. Should be empty.
- var css []dbContractSector - if err := ss.db.Find(&css).Error; err != nil { - t.Fatal(err) - } - if len(css) != 0 { - t.Fatal("table should be empty", len(css)) + if n := ss.Count("contract_sectors"); n != 0 { + t.Fatal("table should be empty", n) } // Add the contract back. @@ -2109,14 +2158,10 @@ func TestContractSectors(t *testing.T) { } // Delete the sector. - if err := ss.db.Delete(&dbSector{Model: Model{ID: 1}}).Error; err != nil { - t.Fatal(err) - } - if err := ss.db.Find(&css).Error; err != nil { + if _, err := ss.DB().Exec(context.Background(), "DELETE FROM sectors WHERE id = ?", 1); err != nil { t.Fatal(err) - } - if len(css) != 0 { - t.Fatal("table should be empty") + } else if n := ss.Count("contract_sectors"); n != 0 { + t.Fatal("table should be empty", n) } } @@ -2167,22 +2212,34 @@ func TestUpdateSlab(t *testing.T) { } // helper to fetch a slab from the database - fetchSlab := func() (slab dbSlab) { + fetchSlab := func() (slab object.Slab) { t.Helper() - if err = ss.db. - Where(&dbSlab{Key: key}). - Preload("Shards.Contracts"). - Take(&slab). - Error; err != nil { + if slab, err = ss.Slab(ctx, obj.Slabs[0].Key); err != nil { t.Fatal(err) } return } - // helper to extract the FCID from a list of contracts - contractIds := func(contracts []dbContract) (ids []fileContractID) { - for _, c := range contracts { - ids = append(ids, fileContractID(c.FCID)) + // helper to fetch contract ids for a sector + contractIds := func(root types.Hash256) (fcids []types.FileContractID) { + t.Helper() + rows, err := ss.DB().Query(context.Background(), ` + SELECT fcid + FROM contracts c + INNER JOIN contract_sectors cs ON c.id = cs.db_contract_id + INNER JOIN sectors s ON s.id = cs.db_sector_id + WHERE s.root = ? + `, sql.Hash256(root)) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var fcid types.FileContractID + if err := rows.Scan((*sql.FileContractID)(&fcid)); err != nil { + t.Fatal(err) + } + fcids = append(fcids, fcid) } return } @@ -2192,9 +2249,9 @@ func TestUpdateSlab(t *testing.T) { // assert both sectors were upload to one contract/host for i := 0; i < 2; i++ { - if cids := contractIds(inserted.Shards[i].Contracts); len(cids) != 1 { + if cids := contractIds(types.Hash256(inserted.Shards[i].Root)); len(cids) != 1 { t.Fatalf("sector %d was uploaded to unexpected amount of contracts, %v!=1", i+1, len(cids)) - } else if inserted.Shards[i].LatestHost != publicKey(hks[i]) { + } else if inserted.Shards[i].LatestHost != hks[i] { t.Fatalf("sector %d was uploaded to unexpected amount of hosts, %v!=1", i+1, len(hks)) } } @@ -2231,30 +2288,27 @@ func TestUpdateSlab(t *testing.T) { updated := fetchSlab() // assert the first sector is still only on one host, also assert it's h1 - if cids := contractIds(updated.Shards[0].Contracts); len(cids) != 1 { + if cids := contractIds(types.Hash256(updated.Shards[0].Root)); len(cids) != 1 { t.Fatalf("sector 1 was uploaded to unexpected amount of contracts, %v!=1", len(cids)) } else if types.FileContractID(cids[0]) != fcid1 { t.Fatal("sector 1 was uploaded to unexpected contract", cids[0]) - } else if updated.Shards[0].LatestHost != publicKey(hks[0]) { - t.Fatal("host key was invalid", updated.Shards[0].LatestHost, publicKey(hks[0])) + } else if updated.Shards[0].LatestHost != hks[0] { + t.Fatal("host key was invalid", updated.Shards[0].LatestHost, sql.PublicKey(hks[0])) } else if hks[0] != hk1 { t.Fatal("sector 1 was uploaded to unexpected host", hks[0]) } // assert the second sector however is uploaded to two hosts, assert 
it's h2 and h3
-	if cids := contractIds(updated.Shards[1].Contracts); len(cids) != 2 {
+	if cids := contractIds(types.Hash256(updated.Shards[1].Root)); len(cids) != 2 {
 		t.Fatalf("sector 1 was uploaded to unexpected amount of contracts, %v!=2", len(cids))
 	} else if types.FileContractID(cids[0]) != fcid2 || types.FileContractID(cids[1]) != fcid3 {
 		t.Fatal("sector 1 was uploaded to unexpected contracts", cids[0], cids[1])
-	} else if updated.Shards[0].LatestHost != publicKey(hks[0]) {
-		t.Fatal("host key was invalid", updated.Shards[0].LatestHost, publicKey(hks[0]))
+	} else if updated.Shards[0].LatestHost != hks[0] {
+		t.Fatal("host key was invalid", updated.Shards[0].LatestHost, hks[0])
 	}
 
 	// assert there's still only one entry in the dbslab table
-	var cnt int64
-	if err := ss.db.Model(&dbSlab{}).Count(&cnt).Error; err != nil {
-		t.Fatal(err)
-	} else if cnt != 1 {
+	if cnt := ss.Count("slabs"); cnt != 1 {
 		t.Fatalf("unexpected number of entries in dbslab, %v != 1", cnt)
 	}
 
@@ -2270,12 +2324,12 @@ func TestUpdateSlab(t *testing.T) {
 		t.Fatal("unexpected number of slabs to migrate", len(toMigrate))
 	}
 
-	if obj, err := ss.dbObject(t.Name()); err != nil {
+	if obj, err := ss.Object(context.Background(), api.DefaultBucketName, t.Name()); err != nil {
 		t.Fatal(err)
 	} else if len(obj.Slabs) != 1 {
 		t.Fatalf("unexpected number of slabs, %v != 1", len(obj.Slabs))
-	} else if obj.Slabs[0].ID != updated.ID {
-		t.Fatalf("unexpected slab, %v != %v", obj.Slabs[0].ID, updated.ID)
+	} else if obj.Slabs[0].Key.String() != updated.Key.String() {
+		t.Fatalf("unexpected slab, %v != %v", obj.Slabs[0].Key, updated.Key)
 	}
 
 	// update the slab to change its contract set.
@@ -2286,14 +2340,11 @@ func TestUpdateSlab(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	var s dbSlab
-	if err := ss.db.Where(&dbSlab{Key: key}).
-		Joins("DBContractSet").
-		Preload("Shards").
-		Take(&s).
-		Error; err != nil {
+	var csID int64
+	if err := ss.DB().QueryRow(context.Background(), "SELECT db_contract_set_id FROM slabs WHERE `key` = ?", key).
+		Scan(&csID); err != nil {
 		t.Fatal(err)
-	} else if s.DBContractSet.Name != "other" {
+	} else if csID != ss.ContractSetID("other") {
 		t.Fatal("contract set was not updated")
 	}
 }
@@ -2338,8 +2389,7 @@ func TestRecordContractSpending(t *testing.T) {
 	}
 
 	// Add an announcement.
-	err = ss.insertTestAnnouncement(hk, newTestHostDBAnnouncement("address"))
-	if err != nil {
+	if err := ss.announceHost(hk, "address"); err != nil {
 		t.Fatal(err)
 	}
 
@@ -2347,9 +2397,10 @@ func TestRecordContractSpending(t *testing.T) {
 	cm, err := ss.addTestContract(fcid, hk)
 	if err != nil {
 		t.Fatal(err)
-	}
-	if cm.Spending != (api.ContractSpending{}) {
+	} else if cm.Spending != (api.ContractSpending{}) {
 		t.Fatal("spending should be all 0")
+	} else if cm.Size != 0 || cm.RevisionNumber != 0 {
+		t.Fatalf("unexpected size or revision number, %v %v", cm.Size, cm.RevisionNumber)
 	}
 
 	// Record some spending.
@@ -2369,6 +2420,8 @@ func TestRecordContractSpending(t *testing.T) {
 		{
 			ContractID:       fcid,
 			ContractSpending: expectedSpending,
+			RevisionNumber:   100,
+			Size:             200,
 		},
 	})
 	if err != nil {
@@ -2377,16 +2430,20 @@ func TestRecordContractSpending(t *testing.T) {
 	cm2, err := ss.Contract(context.Background(), fcid)
 	if err != nil {
 		t.Fatal(err)
-	}
-	if cm2.Spending != expectedSpending {
+	} else if cm2.Spending != expectedSpending {
 		t.Fatal("invalid spending", cm2.Spending, expectedSpending)
+	} else if cm2.Size != 200 || cm2.RevisionNumber != 100 {
+		t.Fatalf("unexpected size or revision number, %v %v", cm2.Size, cm2.RevisionNumber)
 	}
 
-	// Record the same spending again.
+	// Record the same spending again with a revision number that isn't
+	// higher than the current one. This shouldn't update the size.
 	err = ss.RecordContractSpending(context.Background(), []api.ContractSpendingRecord{
 		{
 			ContractID:       fcid,
 			ContractSpending: expectedSpending,
+			RevisionNumber:   100,
+			Size:             200,
 		},
 	})
 	if err != nil {
@@ -2396,9 +2453,10 @@ func TestRecordContractSpending(t *testing.T) {
 	cm3, err := ss.Contract(context.Background(), fcid)
 	if err != nil {
 		t.Fatal(err)
-	}
-	if cm3.Spending != expectedSpending {
+	} else if cm3.Spending != expectedSpending {
 		t.Fatal("invalid spending")
+	} else if cm3.Size != 200 || cm3.RevisionNumber != 100 {
+		t.Fatalf("unexpected size or revision number, %v %v", cm3.Size, cm3.RevisionNumber)
 	}
 }
 
@@ -2431,40 +2489,37 @@ func TestRenameObjects(t *testing.T) {
 	}
 
 	// Try renaming objects that don't exist.
-	if err := ss.RenameObject(ctx, api.DefaultBucketName, "/fileś", "/fileś2", false); !errors.Is(err, api.ErrObjectNotFound) {
+	if err := ss.RenameObjectBlocking(ctx, api.DefaultBucketName, "/fileś", "/fileś2", false); !errors.Is(err, api.ErrObjectNotFound) {
 		t.Fatal(err)
 	}
-	if err := ss.RenameObjects(ctx, api.DefaultBucketName, "/fileś1", "/fileś2", false); !errors.Is(err, api.ErrObjectNotFound) {
+	if err := ss.RenameObjectsBlocking(ctx, api.DefaultBucketName, "/fileś1", "/fileś2", false); !errors.Is(err, api.ErrObjectNotFound) {
 		t.Fatal(err)
 	}
 
 	// Perform some renames.
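+	// (the *Blocking helpers defined above wait for the prune loop to run before returning)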
- if err := ss.RenameObjects(ctx, api.DefaultBucketName, "/fileÅ›/dir/", "/fileÅ›/", false); err != nil { + if err := ss.RenameObjectsBlocking(ctx, api.DefaultBucketName, "/fileÅ›/dir/", "/fileÅ›/", false); err != nil { t.Fatal(err) } - if err := ss.RenameObject(ctx, api.DefaultBucketName, "/foo", "/fileÅ›/foo", false); err != nil { + if err := ss.RenameObjectBlocking(ctx, api.DefaultBucketName, "/foo", "/fileÅ›/foo", false); err != nil { t.Fatal(err) } - if err := ss.RenameObject(ctx, api.DefaultBucketName, "/bar", "/fileÅ›/bar", false); err != nil { + if err := ss.RenameObjectBlocking(ctx, api.DefaultBucketName, "/bar", "/fileÅ›/bar", false); err != nil { t.Fatal(err) } - if err := ss.RenameObject(ctx, api.DefaultBucketName, "/baz", "/fileÅ›/baz", false); err != nil { + if err := ss.RenameObjectBlocking(ctx, api.DefaultBucketName, "/baz", "/fileÅ›/baz", false); err != nil { t.Fatal(err) } - if err := ss.RenameObjects(ctx, api.DefaultBucketName, "/fileÅ›/case", "/fileÅ›/case1", false); err != nil { + if err := ss.RenameObjectsBlocking(ctx, api.DefaultBucketName, "/fileÅ›/case", "/fileÅ›/case1", false); err != nil { t.Fatal(err) } - if err := ss.RenameObjects(ctx, api.DefaultBucketName, "/fileÅ›/CASE", "/fileÅ›/case2", false); err != nil { + if err := ss.RenameObjectsBlocking(ctx, api.DefaultBucketName, "/fileÅ›/CASE", "/fileÅ›/case2", false); err != nil { t.Fatal(err) } - if err := ss.RenameObjects(ctx, api.DefaultBucketName, "/baz2", "/fileÅ›/baz", false); !errors.Is(err, api.ErrObjectExists) { - t.Fatal(err) - } else if err := ss.RenameObjects(ctx, api.DefaultBucketName, "/baz2", "/fileÅ›/baz", true); err != nil { + if err := ss.RenameObjectsBlocking(ctx, api.DefaultBucketName, "/baz2", "/fileÅ›/baz", false); !errors.Is(err, api.ErrObjectExists) { t.Fatal(err) - } - if err := ss.RenameObject(ctx, api.DefaultBucketName, "/baz3", "/fileÅ›/baz", false); !errors.Is(err, api.ErrObjectExists) { + } else if err := ss.RenameObjectsBlocking(ctx, api.DefaultBucketName, "/baz2", "/fileÅ›/baz", true); err != nil { t.Fatal(err) - } else if err := ss.RenameObject(ctx, api.DefaultBucketName, "/baz3", "/fileÅ›/baz", true); err != nil { + } else if err := ss.RenameObjectBlocking(ctx, api.DefaultBucketName, "/baz3", "/fileÅ›/baz", true); err != nil { t.Fatal(err) } @@ -2505,8 +2560,8 @@ func TestRenameObjects(t *testing.T) { // Assert directories are correct expectedDirs := []struct { - id uint - parentID uint + id int64 + parentID int64 name string }{ { @@ -2520,24 +2575,40 @@ func TestRenameObjects(t *testing.T) { name: "/fileÅ›/", }, } - var directories []dbDirectory - test.Retry(100, 100*time.Millisecond, func() error { - if err := ss.db.Find(&directories).Error; err != nil { - return err - } else if len(directories) != len(expectedDirs) { - return fmt.Errorf("unexpected number of directories, %v != %v", len(directories), len(expectedDirs)) - } - return nil - }) - for i, dir := range directories { - if dir.ID != expectedDirs[i].id { + var n int64 + if err := ss.DB().QueryRow(ctx, "SELECT COUNT(*) FROM directories").Scan(&n); err != nil { + t.Fatal(err) + } else if n != int64(len(expectedDirs)) { + t.Fatalf("unexpected number of directories, %v != %v", n, len(expectedDirs)) + } + + type row struct { + ID int64 + ParentID int64 + Name string + } + rows, err := ss.DB().Query(context.Background(), "SELECT id, COALESCE(db_parent_id, 0), name FROM directories ORDER BY id ASC") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + var i int + for rows.Next() { + var dir row + if err := 
rows.Scan(&dir.ID, &dir.ParentID, &dir.Name); err != nil { + t.Fatal(err) + } else if dir.ID != expectedDirs[i].id { t.Fatalf("unexpected directory id, %v != %v", dir.ID, expectedDirs[i].id) - } else if dir.DBParentID != expectedDirs[i].parentID { - t.Fatalf("unexpected directory parent id, %v != %v", dir.DBParentID, expectedDirs[i].parentID) + } else if dir.ParentID != expectedDirs[i].parentID { + t.Fatalf("unexpected directory parent id, %v != %v", dir.ParentID, expectedDirs[i].parentID) } else if dir.Name != expectedDirs[i].name { t.Fatalf("unexpected directory name, %v != %v", dir.Name, expectedDirs[i].name) } + i++ + } + if len(expectedDirs) != i { + t.Fatalf("expected %v dirs, got %v", len(expectedDirs), i) } } @@ -2589,28 +2660,26 @@ func TestObjectsStats(t *testing.T) { // Get all entries in contract_sectors and store them again with a different // contract id. This should cause the uploaded size to double. - var contractSectors []dbContractSector - err = ss.db.Find(&contractSectors).Error - if err != nil { - t.Fatal(err) - } var newContractID types.FileContractID frand.Read(newContractID[:]) - c, err := ss.addTestContract(newContractID, types.PublicKey{}) + hks, err := ss.addTestHosts(1) if err != nil { t.Fatal(err) } - totalUploadedSize += c.Size - newContract, err := ss.contract(context.Background(), fileContractID(newContractID)) + hk := hks[0] + c, err := ss.addTestContract(newContractID, hk) if err != nil { t.Fatal(err) } - for _, contractSector := range contractSectors { - contractSector.DBContractID = newContract.ID - err = ss.db.Create(&contractSector).Error - if err != nil { - t.Fatal(err) - } + totalUploadedSize += c.Size + if _, err := ss.DB().Exec(context.Background(), ` + INSERT INTO contract_sectors (db_contract_id, db_sector_id) + SELECT ( + SELECT id FROM contracts WHERE fcid = ? + ), db_sector_id + FROM contract_sectors + `, sql.FileContractID(newContractID)); err != nil { + t.Fatal(err) } // Check sizes. @@ -2737,15 +2806,23 @@ func TestPartialSlab(t *testing.T) { type bufferedSlab struct { ID uint - DBSlab dbSlab `gorm:"foreignKey:DBBufferedSlabID"` Filename string } - - var buffer bufferedSlab - sk, _ := slabs[0].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { - t.Fatal(err) + fetchBuffer := func(ec object.EncryptionKey) (b bufferedSlab) { + t.Helper() + if err := ss.DB().QueryRow(context.Background(), ` + SELECT bs.id, bs.filename + FROM buffered_slabs bs + INNER JOIN slabs sla ON sla.db_buffered_slab_id = bs.id + WHERE sla.key = ? + `, sql.EncryptionKey(ec)). + Scan(&b.ID, &b.Filename); err != nil && !errors.Is(err, dsql.ErrNoRows) { + t.Fatal(err) + } + return } + + buffer := fetchBuffer(slabs[0].Key) if buffer.Filename == "" { t.Fatal("empty filename") } @@ -2803,11 +2880,6 @@ func TestPartialSlab(t *testing.T) { } else if !bytes.Equal(data, slab2Data) { t.Fatal("wrong data") } - buffer = bufferedSlab{} - sk, _ = slabs[0].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { - t.Fatal(err) - } assertBuffer(buffer1Name, 4194303, false, false) // Create an object again. 
@@ -2844,17 +2916,9 @@ func TestPartialSlab(t *testing.T) { } else if !bytes.Equal(slab3Data, append(data1, data2...)) { t.Fatal("wrong data") } - buffer = bufferedSlab{} - sk, _ = slabs[0].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { - t.Fatal(err) - } assertBuffer(buffer1Name, rhpv2.SectorSize, true, false) - buffer = bufferedSlab{} - sk, _ = slabs[1].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { - t.Fatal(err) - } + + buffer = fetchBuffer(slabs[1].Key) buffer2Name := buffer.Filename assertBuffer(buffer2Name, 1, false, false) @@ -2879,13 +2943,9 @@ func TestPartialSlab(t *testing.T) { assertBuffer(buffer1Name, rhpv2.SectorSize, true, true) assertBuffer(buffer2Name, 1, false, false) - var foo []bufferedSlab - if err := ss.db.Find(&foo).Error; err != nil { - t.Fatal(err) - } - buffer = bufferedSlab{} - if err := ss.db.Take(&buffer, "id = ?", packedSlabs[0].BufferID).Error; err != nil { - t.Fatal(err) + buffer = fetchBuffer(packedSlabs[0].Key) + if buffer.ID != packedSlabs[0].BufferID { + t.Fatalf("wrong buffer id, %v != %v", buffer.ID, packedSlabs[0].BufferID) } // Mark slab as uploaded. @@ -2902,8 +2962,8 @@ func TestPartialSlab(t *testing.T) { t.Fatal(err) } - buffer = bufferedSlab{} - if err := ss.db.Take(&buffer, "id = ?", packedSlabs[0].BufferID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + buffer = fetchBuffer(packedSlabs[0].Key) + if buffer != (bufferedSlab{}) { t.Fatal("shouldn't be able to find buffer", err) } assertBuffer(buffer2Name, 1, false, false) @@ -3105,36 +3165,6 @@ func TestContractSizes(t *testing.T) { if n := prunableData(nil); n != 0 { t.Fatal("expected no prunable data", n) } - - // assert passing a non-existent fcid returns an error - _, err = ss.ContractSize(context.Background(), types.FileContractID{9}) - if err != api.ErrContractNotFound { - t.Fatal(err) - } -} - -// dbObject retrieves a dbObject from the store. -func (s *SQLStore) dbObject(key string) (dbObject, error) { - var obj dbObject - tx := s.db.Where(&dbObject{ObjectID: key}). - Preload("Slabs"). - Take(&obj) - if errors.Is(tx.Error, gorm.ErrRecordNotFound) { - return dbObject{}, api.ErrObjectNotFound - } - return obj, nil -} - -// dbSlab retrieves a dbSlab from the store. -func (s *SQLStore) dbSlab(key []byte) (dbSlab, error) { - var slab dbSlab - tx := s.db.Where(&dbSlab{Key: key}). - Preload("Shards.Contracts.Host"). - Take(&slab) - if errors.Is(tx.Error, gorm.ErrRecordNotFound) { - return dbSlab{}, api.ErrObjectNotFound - } - return slab, nil } func TestObjectsBySlabKey(t *testing.T) { @@ -3395,16 +3425,13 @@ func TestBucketObjects(t *testing.T) { } // See if we can fetch the object by slab. 
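+	// (ObjectsBySlabKey is scoped to a bucket, so b2 should return no objects)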
- var ec object.EncryptionKey - if obj, err := ss.objectRaw(ss.db, b1, "/bar"); err != nil { - t.Fatal(err) - } else if err := ec.UnmarshalBinary(obj[0].SlabKey); err != nil { + if obj, err := ss.Object(context.Background(), b1, "/bar"); err != nil { t.Fatal(err) - } else if objects, err := ss.ObjectsBySlabKey(context.Background(), b1, ec); err != nil { + } else if objects, err := ss.ObjectsBySlabKey(context.Background(), b1, obj.Slabs[0].Key); err != nil { t.Fatal(err) } else if len(objects) != 1 { t.Fatal("expected 1 object", len(objects)) - } else if objects, err := ss.ObjectsBySlabKey(context.Background(), b2, ec); err != nil { + } else if objects, err := ss.ObjectsBySlabKey(context.Background(), b2, obj.Slabs[0].Key); err != nil { t.Fatal(err) } else if len(objects) != 0 { t.Fatal("expected 0 objects", len(objects)) @@ -3527,12 +3554,7 @@ func TestMarkSlabUploadedAfterRenew(t *testing.T) { }) if err != nil { t.Fatal(err) - } - - var count int64 - if err := ss.db.Model(&dbContractSector{}).Count(&count).Error; err != nil { - t.Fatal(err) - } else if count != 1 { + } else if count := ss.Count("contract_sectors"); count != 1 { t.Fatal("expected 1 sector", count) } } @@ -3581,7 +3603,7 @@ func TestListObjects(t *testing.T) { } // update health of objects to match the overridden health of the slabs - if err := updateAllObjectsHealth(ss.db); err != nil { + if err := updateAllObjectsHealth(ss.DB()); err != nil { t.Fatal() } @@ -3658,45 +3680,29 @@ func TestDeleteHostSector(t *testing.T) { hk1, hk2 := hks[0], hks[1] // create 2 contracts with each - _, _, err = ss.addTestContracts([]types.PublicKey{hk1, hk1, hk2, hk2}) + fcids, _, err := ss.addTestContracts([]types.PublicKey{hk1, hk1, hk2, hk2}) if err != nil { t.Fatal(err) } - // get all contracts - var dbContracts []dbContract - if err := ss.db.Model(&dbContract{}).Preload("Host").Find(&dbContracts).Error; err != nil { - t.Fatal(err) - } - // create a healthy slab with one sector that is uploaded to all contracts. - key, _ := object.GenerateEncryptionKey().MarshalBinary() root := types.Hash256{1, 2, 3} - slab := dbSlab{ - DBContractSetID: 1, - Key: key, - Health: 1.0, - HealthValidUntil: time.Now().Add(time.Hour).Unix(), - TotalShards: 1, - Shards: []dbSector{ + ss.InsertSlab(object.Slab{ + Key: object.GenerateEncryptionKey(), + MinShards: 1, + Shards: []object.Sector{ { - Contracts: dbContracts, - Root: root[:], - LatestHost: publicKey(hk1), // hk1 is latest host + Contracts: map[types.PublicKey][]types.FileContractID{ + hk1: fcids, + }, + Root: root, + LatestHost: hk1, }, }, - } - if err := ss.db.Create(&slab).Error; err != nil { - t.Fatal(err) - } + }) // Make sure 4 contractSector entries exist. - var n int64 - if err := ss.db.Model(&dbContractSector{}). - Count(&n). - Error; err != nil { - t.Fatal(err) - } else if n != 4 { + if n := ss.Count("contract_sectors"); n != 4 { t.Fatal("expected 4 contract-sector links", n) } @@ -3708,32 +3714,65 @@ func TestDeleteHostSector(t *testing.T) { } // Make sure 2 contractSector entries exist. - if err := ss.db.Model(&dbContractSector{}). - Count(&n). - Error; err != nil { - t.Fatal(err) - } else if n != 2 { + if n := ss.Count("contract_sectors"); n != 2 { t.Fatal("expected 2 contract-sector links", n) } // Find the slab. It should have an invalid health. 
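+	// (health_valid_until holds a unix timestamp; a value in the past means the health is no longer valid)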
-	var s dbSlab
-	if err := ss.db.Preload("Shards").Take(&s).Error; err != nil {
+	var slabID int64
+	var validUntil int64
+	if err := ss.DB().QueryRow(context.Background(), "SELECT id, health_valid_until FROM slabs").Scan(&slabID, &validUntil); err != nil {
 		t.Fatal(err)
-	} else if s.HealthValid() {
+	} else if time.Now().Before(time.Unix(validUntil, 0)) {
 		t.Fatal("expected health to be invalid")
-	} else if s.Shards[0].LatestHost != publicKey(hk2) {
-		t.Fatal("expected hk2 to be latest host", types.PublicKey(s.Shards[0].LatestHost))
+	}
+
+	sectorContractCnt := func(root types.Hash256) (n int) {
+		t.Helper()
+		err := ss.DB().QueryRow(context.Background(), `
+			SELECT COUNT(*)
+			FROM contract_sectors cs
+			INNER JOIN sectors s ON s.id = cs.db_sector_id
+			WHERE s.root = ?
+		`, (*sql.Hash256)(&root)).Scan(&n)
+		if err != nil {
+			t.Fatal(err)
+		}
+		return
+	}
+
+	// helper to fetch sectors
+	type sector struct {
+		LatestHost types.PublicKey
+		Root       types.Hash256
+		SlabID     int64
+	}
+	fetchSectors := func() (sectors []sector) {
+		t.Helper()
+		rows, err := ss.DB().Query(context.Background(), "SELECT root, latest_host, db_slab_id FROM sectors")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer rows.Close()
+		for rows.Next() {
+			var s sector
+			if err := rows.Scan((*sql.Hash256)(&s.Root), (*sql.PublicKey)(&s.LatestHost), &s.SlabID); err != nil {
+				t.Fatal(err)
+			}
+			sectors = append(sectors, s)
+		}
+		return
 	}
 
 	// Fetch the sector and assert the contracts association.
-	var sectors []dbSector
-	if err := ss.db.Model(&dbSector{}).Preload("Contracts").Find(&sectors).Preload("Contracts").Error; err != nil {
-		t.Fatal(err)
-	} else if len(sectors) != 1 {
+	if sectors := fetchSectors(); len(sectors) != 1 {
 		t.Fatal("expected 1 sector", len(sectors))
-	} else if sector := sectors[0]; len(sector.Contracts) != 2 {
-		t.Fatal("expected 2 contracts", len(sector.Contracts))
+	} else if cnt := sectorContractCnt(sectors[0].Root); cnt != 2 {
+		t.Fatal("expected 2 contracts", cnt)
+	} else if sectors[0].LatestHost != hk2 {
+		t.Fatalf("expected latest host to be hk2, got %v", sectors[0].LatestHost)
+	} else if sectors[0].SlabID != slabID {
+		t.Fatalf("expected slab id to be %v, got %v", slabID, sectors[0].SlabID)
 	}
 
 	hi, err := ss.Host(context.Background(), hk1)
@@ -3770,12 +3809,14 @@ func TestDeleteHostSector(t *testing.T) {
 	}
 
 	// Fetch the sector and check the public key has the default value
-	if err := ss.db.Model(&dbSector{}).Find(&sectors).Error; err != nil {
-		t.Fatal(err)
-	} else if len(sectors) != 1 {
+	if sectors := fetchSectors(); len(sectors) != 1 {
 		t.Fatal("expected 1 sector", len(sectors))
+	} else if cnt := sectorContractCnt(sectors[0].Root); cnt != 0 {
+		t.Fatal("expected 0 contracts", cnt)
 	} else if sector := sectors[0]; sector.LatestHost != [32]byte{} {
 		t.Fatal("expected latest host to be empty", sector.LatestHost)
+	} else if sectors[0].SlabID != slabID {
+		t.Fatalf("expected slab id to be %v, got %v", slabID, sectors[0].SlabID)
 	}
 }
 func newTestShards(hk types.PublicKey, fcid types.FileContractID, root types.Hash256) []object.Sector {
@@ -3868,13 +3909,11 @@ func TestSlabHealthInvalidation(t *testing.T) {
 
 	assertHealthValid := func(slabKey object.EncryptionKey, expected bool) {
 		t.Helper()
-		var slab dbSlab
-		if key, err := slabKey.MarshalBinary(); err != nil {
+		var validUntil int64
+		if err := ss.DB().QueryRow(context.Background(), "SELECT health_valid_until FROM slabs WHERE `key` = ?", sql.EncryptionKey(slabKey)).Scan(&validUntil); err != nil {
 			t.Fatal(err)
-		} else if 
ss.db.Model(&dbSlab{}).Where(&dbSlab{Key: key}).Take(&slab).Error; err != nil { - t.Fatal(err) - } else if slab.HealthValid() != expected { - t.Fatal("unexpected health valid", slab.HealthValid(), slab.HealthValidUntil, time.Now(), time.Unix(slab.HealthValidUntil, 0)) + } else if valid := time.Now().Before(time.Unix(validUntil, 0)); valid != expected { + t.Fatal("unexpected health valid", valid) } } @@ -3993,7 +4032,7 @@ func TestSlabHealthInvalidation(t *testing.T) { // assert the health validity is always updated to a random time in the future that matches the boundaries for i := 0; i < 1e3; i++ { // reset health validity - if tx := ss.db.Exec("UPDATE slabs SET health_valid_until = 0;"); tx.Error != nil { + if _, err := ss.DB().Exec(context.Background(), "UPDATE slabs SET health_valid_until = 0;"); err != nil { t.Fatal(err) } @@ -4003,19 +4042,17 @@ func TestSlabHealthInvalidation(t *testing.T) { t.Fatal(err) } - // fetch slab - var slab dbSlab - if key, err := s1.MarshalBinary(); err != nil { - t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Where(&dbSlab{Key: key}).Take(&slab).Error; err != nil { + // fetch health_valid_until + var validUntil int64 + if err := ss.DB().QueryRow(context.Background(), "SELECT health_valid_until FROM slabs").Scan(&validUntil); err != nil { t.Fatal(err) } // assert it's validity is within expected bounds minValidity := now.Add(refreshHealthMinHealthValidity).Add(-time.Second) // avoid NDF maxValidity := now.Add(refreshHealthMaxHealthValidity).Add(time.Second) // avoid NDF - validUntil := time.Unix(slab.HealthValidUntil, 0) - if !(minValidity.Before(validUntil) && maxValidity.After(validUntil)) { + validUntilUnix := time.Unix(validUntil, 0) + if !(minValidity.Before(validUntilUnix) && maxValidity.After(validUntilUnix)) { t.Fatal("valid until not in boundaries", minValidity, maxValidity, validUntil, now) } } @@ -4144,19 +4181,22 @@ func TestSlabCleanup(t *testing.T) { defer ss.Close() // create contract set - cs := dbContractSet{} - if err := ss.db.Create(&cs).Error; err != nil { + err := ss.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return tx.SetContractSet(context.Background(), testContractSet, nil) + }) + if err != nil { t.Fatal(err) } + csID := ss.ContractSetID(testContractSet) // create buffered slab bsID := uint(1) - if err := ss.db.Exec("INSERT INTO buffered_slabs (filename) VALUES ('foo');").Error; err != nil { + if _, err := ss.DB().Exec(context.Background(), "INSERT INTO buffered_slabs (filename) VALUES ('foo');"); err != nil { t.Fatal(err) } var dirID int64 - err := ss.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + err = ss.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { var err error dirID, err = tx.MakeDirsForPath(context.Background(), "1") return err @@ -4166,172 +4206,90 @@ func TestSlabCleanup(t *testing.T) { } // create objects - obj1 := dbObject{ - DBDirectoryID: uint(dirID), - ObjectID: "1", - DBBucketID: ss.DefaultBucketID(), - Health: 1, - } - if err := ss.db.Create(&obj1).Error; err != nil { + insertObjStmt, err := ss.DB().Prepare(context.Background(), "INSERT INTO objects (db_directory_id, object_id, db_bucket_id, health) VALUES (?, ?, ?, ?);") + if err != nil { t.Fatal(err) } - obj2 := dbObject{ - DBDirectoryID: uint(dirID), - ObjectID: "2", - DBBucketID: ss.DefaultBucketID(), - Health: 1, - } - if err := ss.db.Create(&obj2).Error; err != nil { + defer insertObjStmt.Close() + + var obj1ID, obj2ID int64 + if res, err := 
insertObjStmt.Exec(context.Background(), dirID, "1", ss.DefaultBucketID(), 1); err != nil { + t.Fatal(err) + } else if obj1ID, err = res.LastInsertId(); err != nil { + t.Fatal(err) + } else if res, err := insertObjStmt.Exec(context.Background(), dirID, "2", ss.DefaultBucketID(), 1); err != nil { + t.Fatal(err) + } else if obj2ID, err = res.LastInsertId(); err != nil { t.Fatal(err) } // create a slab - ek, _ := object.GenerateEncryptionKey().MarshalBinary() - slab := dbSlab{ - DBContractSet: cs, - Health: 1, - Key: secretKey(ek), - HealthValidUntil: 100, + var slabID int64 + if res, err := ss.DB().Exec(context.Background(), "INSERT INTO slabs (db_contract_set_id, `key`, health_valid_until) VALUES (?, ?, ?);", csID, sql.EncryptionKey(object.GenerateEncryptionKey()), 100); err != nil { + t.Fatal(err) + } else if slabID, err = res.LastInsertId(); err != nil { + t.Fatal(err) } - if err := ss.db.Create(&slab).Error; err != nil { + + // statement to reference slabs by inserting a slice for an object + insertSlabRefStmt, err := ss.DB().Prepare(context.Background(), "INSERT INTO slices (db_object_id, db_slab_id) VALUES (?, ?);") + if err != nil { t.Fatal(err) } + defer insertSlabRefStmt.Close() // reference the slab - slice1 := dbSlice{ - DBObjectID: &obj1.ID, - DBSlabID: slab.ID, - } - if err := ss.db.Create(&slice1).Error; err != nil { + if _, err := insertSlabRefStmt.Exec(context.Background(), obj1ID, slabID); err != nil { t.Fatal(err) - } - slice2 := dbSlice{ - DBObjectID: &obj2.ID, - DBSlabID: slab.ID, - } - if err := ss.db.Create(&slice2).Error; err != nil { + } else if _, err := insertSlabRefStmt.Exec(context.Background(), obj2ID, slabID); err != nil { t.Fatal(err) } // delete the object - err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, obj1.ObjectID) + err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, "1") if err != nil { t.Fatal(err) } - // check slice count - var slabCntr int64 - if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { - t.Fatal(err) - } else if slabCntr != 1 { + // check slab count + if slabCntr := ss.Count("slabs"); slabCntr != 1 { t.Fatalf("expected 1 slabs, got %v", slabCntr) } // delete second object - err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, obj2.ObjectID) + err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, "2") if err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { - t.Fatal(err) - } else if slabCntr != 0 { + } else if slabCntr := ss.Count("slabs"); slabCntr != 0 { t.Fatalf("expected 0 slabs, got %v", slabCntr) } - // create another object that references a slab with buffer - ek, _ = object.GenerateEncryptionKey().MarshalBinary() - bufferedSlab := dbSlab{ - DBBufferedSlabID: bsID, - DBContractSet: cs, - Health: 1, - Key: ek, - HealthValidUntil: 100, - } - if err := ss.db.Create(&bufferedSlab).Error; err != nil { - t.Fatal(err) - } - obj3 := dbObject{ - DBDirectoryID: uint(dirID), - ObjectID: "3", - DBBucketID: ss.DefaultBucketID(), - Health: 1, - } - if err := ss.db.Create(&obj3).Error; err != nil { + // create another slab referencing the buffered slab + var bufferedSlabID int64 + if res, err := ss.DB().Exec(context.Background(), "INSERT INTO slabs (db_buffered_slab_id, db_contract_set_id, `key`, health_valid_until) VALUES (?, ?, ?, ?);", bsID, csID, sql.EncryptionKey(object.GenerateEncryptionKey()), 100); err != nil { t.Fatal(err) - } - slice := dbSlice{ - DBObjectID: &obj3.ID, - DBSlabID: 
bufferedSlab.ID, - } - if err := ss.db.Create(&slice).Error; err != nil { - t.Fatal(err) - } - if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { + } else if bufferedSlabID, err = res.LastInsertId(); err != nil { t.Fatal(err) - } else if slabCntr != 1 { - t.Fatalf("expected 1 slabs, got %v", slabCntr) } - // delete third object - err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, obj3.ObjectID) - if err != nil { + var obj3ID int64 + if res, err := insertObjStmt.Exec(context.Background(), dirID, "3", ss.DefaultBucketID(), 1); err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { + } else if obj3ID, err = res.LastInsertId(); err != nil { t.Fatal(err) - } else if slabCntr != 1 { - t.Fatalf("expected 1 slabs, got %v", slabCntr) - } -} - -func TestUpsertSectors(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - err := ss.db.Create(&dbSlab{ - DBContractSetID: 1, - Key: []byte{1}, - }).Error - if err != nil { + } else if _, err := insertSlabRefStmt.Exec(context.Background(), obj3ID, bufferedSlabID); err != nil { t.Fatal(err) } - - err = ss.db.Create(&dbSector{ - DBSlabID: 1, - SlabIndex: 2, - Root: []byte{2}, - }).Error - if err != nil { - t.Fatal(err) + if slabCntr := ss.Count("slabs"); slabCntr != 1 { + t.Fatalf("expected 1 slabs, got %v", slabCntr) } - sectors := []dbSector{ - { - DBSlabID: 1, - SlabIndex: 1, - Root: []byte{1}, - }, - { - DBSlabID: 1, - SlabIndex: 2, - Root: []byte{2}, - }, - { - DBSlabID: 1, - SlabIndex: 3, - Root: []byte{3}, - }, - } - sectorIDs, err := upsertSectors(ss.db, sectors) + // delete third object + err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, "3") if err != nil { t.Fatal(err) - } - - for i, id := range sectorIDs { - var sector dbSector - if err := ss.db.Where("id", id).Take(§or).Error; err != nil { - t.Fatal(err) - } else if sector.SlabIndex != i+1 { - t.Fatal("unexpected slab index", sector.SlabIndex) - } + } else if slabCntr := ss.Count("slabs"); slabCntr != 1 { + t.Fatalf("expected 1 slabs, got %v", slabCntr) } } @@ -4389,74 +4347,149 @@ func TestUpdateObjectReuseSlab(t *testing.T) { t.Fatal(err) } + // helper to fetch relevant fields from an object + fetchObj := func(bid int64, oid string) (id, bucketID int64, objectID string, health float64, size int64) { + t.Helper() + err := ss.DB().QueryRow(context.Background(), ` + SELECT id, db_bucket_id, object_id, health, size + FROM objects + WHERE db_bucket_id = ? AND object_id = ? 
+		`, bid, oid).Scan(&id, &bucketID, &objectID, &health, &size)
+		if err != nil {
+			t.Fatal(err)
+		}
+		return
+	}
+
 	// fetch the object
-	var dbObj dbObject
-	if err := ss.db.Where("db_bucket_id", ss.DefaultBucketID()).Take(&dbObj).Error; err != nil {
-		t.Fatal(err)
-	} else if dbObj.ID != 1 {
-		t.Fatal("unexpected id", dbObj.ID)
-	} else if dbObj.DBBucketID != ss.DefaultBucketID() {
-		t.Fatal("bucket id mismatch", dbObj.DBBucketID)
-	} else if dbObj.ObjectID != "1" {
-		t.Fatal("object id mismatch", dbObj.ObjectID)
-	} else if dbObj.Health != 1 {
-		t.Fatal("health mismatch", dbObj.Health)
-	} else if dbObj.Size != obj.TotalSize() {
-		t.Fatal("size mismatch", dbObj.Size)
+	id, bid, oid, health, size := fetchObj(ss.DefaultBucketID(), "1")
+	if id != 1 {
+		t.Fatal("unexpected id", id)
+	} else if bid != ss.DefaultBucketID() {
+		t.Fatal("bucket id mismatch", bid)
+	} else if oid != "1" {
+		t.Fatal("object id mismatch", oid)
+	} else if health != 1 {
+		t.Fatal("health mismatch", health)
+	} else if size != obj.TotalSize() {
+		t.Fatal("size mismatch", size)
+	}
+
+	// helper to fetch object's slices
+	type slice struct {
+		ID          int64
+		ObjectIndex int64
+		Offset      int64
+		Length      int64
+		SlabID      int64
+	}
+	fetchSlicesByObjectID := func(oid int64) (slices []slice) {
+		t.Helper()
+		rows, err := ss.DB().Query(context.Background(), "SELECT id, object_index, offset, length, db_slab_id FROM slices WHERE db_object_id = ?", oid)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer rows.Close()
+		for rows.Next() {
+			var s slice
+			if err := rows.Scan(&s.ID, &s.ObjectIndex, &s.Offset, &s.Length, &s.SlabID); err != nil {
+				t.Fatal(err)
+			}
+			slices = append(slices, s)
+		}
+		return
 	}
 
 	// fetch its slices
-	var dbSlices []dbSlice
-	if err := ss.db.Where("db_object_id", dbObj.ID).Find(&dbSlices).Error; err != nil {
-		t.Fatal(err)
-	} else if len(dbSlices) != 2 {
-		t.Fatal("invalid number of slices", len(dbSlices))
-	}
-	for i, dbSlice := range dbSlices {
-		if dbSlice.ID != uint(i+1) {
-			t.Fatal("unexpected id", dbSlice.ID)
-		} else if dbSlice.ObjectIndex != uint(i+1) {
-			t.Fatal("unexpected object index", dbSlice.ObjectIndex)
-		} else if dbSlice.Offset != 0 || dbSlice.Length != uint32(minShards)*rhpv2.SectorSize {
-			t.Fatal("invalid offset/length", dbSlice.Offset, dbSlice.Length)
+	slices := fetchSlicesByObjectID(id)
+	if len(slices) != 2 {
+		t.Fatal("invalid number of slices", len(slices))
+	}
+
+	// helper to fetch sectors
+	type sector struct {
+		ID         int64
+		SlabID     int64
+		LatestHost types.PublicKey
+		Root       types.Hash256
+	}
+	fetchSectorsBySlabID := func(slabID int64) (sectors []sector) {
+		t.Helper()
+		rows, err := ss.DB().Query(context.Background(), "SELECT id, db_slab_id, root, latest_host FROM sectors WHERE db_slab_id = ?", slabID)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer rows.Close()
+		for rows.Next() {
+			var s sector
+			if err := rows.Scan(&s.ID, &s.SlabID, (*sql.Hash256)(&s.Root), (*sql.PublicKey)(&s.LatestHost)); err != nil {
+				t.Fatal(err)
+			}
+			sectors = append(sectors, s)
+		}
+		return
+	}
+
+	// helper type to fetch a slab
+	type slab struct {
+		ID               int64
+		ContractSetID    int64
+		Health           float64
+		HealthValidUntil int64
+		MinShards        uint8
+		TotalShards      uint8
+		Key              object.EncryptionKey
+	}
+	fetchSlabStmt, err := ss.DB().Prepare(context.Background(), "SELECT id, db_contract_set_id, health, health_valid_until, min_shards, total_shards, `key` FROM slabs WHERE id = ?")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer fetchSlabStmt.Close()
+
+	for i, slice := range slices {
+		if slice.ID != int64(i+1) {
+			
t.Fatal("unexpected id", slice.ID) + } else if slice.ObjectIndex != int64(i+1) { + t.Fatal("unexpected object index", slice.ObjectIndex) + } else if slice.Offset != 0 || slice.Length != int64(minShards)*rhpv2.SectorSize { + t.Fatal("invalid offset/length", slice.Offset, slice.Length) } // fetch the slab - var dbSlab dbSlab - key, _ := obj.Slabs[i].Key.MarshalBinary() - if err := ss.db.Where("id", dbSlice.DBSlabID).Take(&dbSlab).Error; err != nil { + var slab slab + err = fetchSlabStmt.QueryRow(context.Background(), slice.SlabID). + Scan(&slab.ID, &slab.ContractSetID, &slab.Health, &slab.HealthValidUntil, &slab.MinShards, &slab.TotalShards, (*sql.EncryptionKey)(&slab.Key)) + if err != nil { t.Fatal(err) - } else if dbSlab.ID != uint(i+1) { - t.Fatal("unexpected id", dbSlab.ID) - } else if dbSlab.DBContractSetID != 1 { - t.Fatal("invalid contract set id", dbSlab.DBContractSetID) - } else if dbSlab.Health != 1 { - t.Fatal("invalid health", dbSlab.Health) - } else if dbSlab.HealthValidUntil != 0 { - t.Fatal("invalid health validity", dbSlab.HealthValidUntil) - } else if dbSlab.MinShards != uint8(minShards) { - t.Fatal("invalid minShards", dbSlab.MinShards) - } else if dbSlab.TotalShards != uint8(totalShards) { - t.Fatal("invalid totalShards", dbSlab.TotalShards) - } else if !bytes.Equal(dbSlab.Key, key) { + } else if slab.ID != int64(i+1) { + t.Fatal("unexpected id", slab.ID) + } else if slab.ContractSetID != 1 { + t.Fatal("invalid contract set id", slab.ContractSetID) + } else if slab.Health != 1 { + t.Fatal("invalid health", slab.Health) + } else if slab.HealthValidUntil != 0 { + t.Fatal("invalid health validity", slab.HealthValidUntil) + } else if slab.MinShards != uint8(minShards) { + t.Fatal("invalid minShards", slab.MinShards) + } else if slab.TotalShards != uint8(totalShards) { + t.Fatal("invalid totalShards", slab.TotalShards) + } else if slab.Key.String() != obj.Slabs[i].Key.String() { t.Fatal("wrong key") } // fetch the sectors - var dbSectors []dbSector - if err := ss.db.Where("db_slab_id", dbSlab.ID).Find(&dbSectors).Error; err != nil { - t.Fatal(err) - } else if len(dbSectors) != totalShards { - t.Fatal("invalid number of sectors", len(dbSectors)) + sectors := fetchSectorsBySlabID(int64(slab.ID)) + if len(sectors) != totalShards { + t.Fatal("invalid number of sectors", len(sectors)) } - for j, dbSector := range dbSectors { - if dbSector.ID != uint(i*totalShards+j+1) { - t.Fatal("invalid id", dbSector.ID) - } else if dbSector.DBSlabID != dbSlab.ID { - t.Fatal("invalid slab id", dbSector.DBSlabID) - } else if dbSector.LatestHost != publicKey(hks[i*totalShards+j]) { + for j, sector := range sectors { + if sector.ID != int64(i*totalShards+j+1) { + t.Fatal("invalid id", sector.ID) + } else if sector.SlabID != int64(slab.ID) { + t.Fatal("invalid slab id", sector.SlabID) + } else if sector.LatestHost != hks[i*totalShards+j] { t.Fatal("invalid host") - } else if !bytes.Equal(dbSector.Root, obj.Slabs[i].Shards[j].Root[:]) { + } else if sector.Root != obj.Slabs[i].Shards[j].Root { t.Fatal("invalid root") } } @@ -4496,179 +4529,104 @@ func TestUpdateObjectReuseSlab(t *testing.T) { } // fetch the object - var dbObj2 dbObject - if err := ss.db.Where("db_bucket_id", ss.DefaultBucketID()). - Where("object_id", "2"). 
-		Take(&dbObj2).Error; err != nil {
-		t.Fatal(err)
-	} else if dbObj2.ID != 2 {
-		t.Fatal("unexpected id", dbObj2.ID)
-	} else if dbObj.Size != obj2.TotalSize() {
-		t.Fatal("size mismatch", dbObj2.Size)
+	id2, bid2, oid2, health2, size2 := fetchObj(ss.DefaultBucketID(), "2")
+	if id2 != 2 {
+		t.Fatal("unexpected id", id2)
+	} else if bid2 != ss.DefaultBucketID() {
+		t.Fatal("bucket id mismatch", bid2)
+	} else if oid2 != "2" {
+		t.Fatal("object id mismatch", oid2)
+	} else if health2 != 1 {
+		t.Fatal("health mismatch", health2)
+	} else if size2 != obj2.TotalSize() {
+		t.Fatal("size mismatch", size2)
 	}
 
 	// fetch its slices
-	var dbSlices2 []dbSlice
-	if err := ss.db.Where("db_object_id", dbObj2.ID).Find(&dbSlices2).Error; err != nil {
-		t.Fatal(err)
-	} else if len(dbSlices2) != 2 {
-		t.Fatal("invalid number of slices", len(dbSlices))
+	slices2 := fetchSlicesByObjectID(id2)
+	if len(slices2) != 2 {
+		t.Fatal("invalid number of slices", len(slices2))
 	}
 
 	// check the first one
-	dbSlice2 := dbSlices2[0]
-	if dbSlice2.ID != uint(len(dbSlices)+1) {
-		t.Fatal("unexpected id", dbSlice2.ID)
-	} else if dbSlice2.ObjectIndex != uint(1) {
-		t.Fatal("unexpected object index", dbSlice2.ObjectIndex)
-	} else if dbSlice2.Offset != 0 || dbSlice2.Length != uint32(minShards)*rhpv2.SectorSize {
-		t.Fatal("invalid offset/length", dbSlice2.Offset, dbSlice2.Length)
+	slice2 := slices2[0]
+	if slice2.ID != int64(len(slices)+1) {
+		t.Fatal("unexpected id", slice2.ID)
+	} else if slice2.ObjectIndex != 1 {
+		t.Fatal("unexpected object index", slice2.ObjectIndex)
+	} else if slice2.Offset != 0 || slice2.Length != int64(minShards)*rhpv2.SectorSize {
+		t.Fatal("invalid offset/length", slice2.Offset, slice2.Length)
 	}
 
 	// fetch the slab
-	var dbSlab2 dbSlab
-	key, _ := obj2.Slabs[0].Key.MarshalBinary()
-	if err := ss.db.Where("id", dbSlice2.DBSlabID).Take(&dbSlab2).Error; err != nil {
-		t.Fatal(err)
-	} else if dbSlab2.ID != uint(len(dbSlices)+1) {
-		t.Fatal("unexpected id", dbSlab2.ID)
-	} else if dbSlab2.DBContractSetID != 1 {
-		t.Fatal("invalid contract set id", dbSlab2.DBContractSetID)
-	} else if !bytes.Equal(dbSlab2.Key, key) {
+	var slab2 slab
+	err = fetchSlabStmt.QueryRow(context.Background(), slice2.SlabID).
+ Scan(&slab2.ID, &slab2.ContractSetID, &slab2.Health, &slab2.HealthValidUntil, &slab2.MinShards, &slab2.TotalShards, (*sql.EncryptionKey)(&slab2.Key)) + if err != nil { + t.Fatal(err) + } else if slab2.ID != int64(len(slices)+1) { + t.Fatal("unexpected id", slab2.ID) + } else if slab2.ContractSetID != 1 { + t.Fatal("invalid contract set id", slab2.ContractSetID) + } else if slab2.Health != 1 { + t.Fatal("invalid health", slab2.Health) + } else if slab2.HealthValidUntil != 0 { + t.Fatal("invalid health validity", slab2.HealthValidUntil) + } else if slab2.MinShards != uint8(minShards) { + t.Fatal("invalid minShards", slab2.MinShards) + } else if slab2.TotalShards != uint8(totalShards) { + t.Fatal("invalid totalShards", slab2.TotalShards) + } else if slab2.Key.String() != obj2.Slabs[0].Key.String() { t.Fatal("wrong key") } // fetch the sectors - var dbSectors2 []dbSector - if err := ss.db.Where("db_slab_id", dbSlab2.ID).Find(&dbSectors2).Error; err != nil { - t.Fatal(err) - } else if len(dbSectors2) != totalShards { - t.Fatal("invalid number of sectors", len(dbSectors2)) - } - for j, dbSector := range dbSectors2 { - if dbSector.ID != uint((len(obj.Slabs))*totalShards+j+1) { - t.Fatal("invalid id", dbSector.ID) - } else if dbSector.DBSlabID != dbSlab2.ID { - t.Fatal("invalid slab id", dbSector.DBSlabID) - } else if dbSector.LatestHost != publicKey(hks[(len(obj.Slabs))*totalShards+j]) { + sectors2 := fetchSectorsBySlabID(int64(slab2.ID)) + if len(sectors2) != totalShards { + t.Fatal("invalid number of sectors", len(sectors2)) + } + for j, sector := range sectors2 { + if sector.ID != int64((len(obj.Slabs))*totalShards+j+1) { + t.Fatal("invalid id", sector.ID) + } else if sector.SlabID != int64(slab2.ID) { + t.Fatal("invalid slab id", sector.SlabID) + } else if sector.LatestHost != hks[(len(obj.Slabs))*totalShards+j] { t.Fatal("invalid host") - } else if !bytes.Equal(dbSector.Root, obj2.Slabs[0].Shards[j].Root[:]) { + } else if sector.Root != obj2.Slabs[0].Shards[j].Root { t.Fatal("invalid root") } } // the second slab of obj2 should be the same as the first in obj - if dbSlices2[1].DBSlabID != 2 { + if slices2[1].SlabID != 2 { t.Fatal("wrong slab") } - var contractSectors []dbContractSector - if err := ss.db.Find(&contractSectors).Error; err != nil { - t.Fatal(err) - } else if len(contractSectors) != 3*totalShards { - t.Fatal("invalid number of contract sectors", len(contractSectors)) + type contractSector struct { + ContractID int64 + SectorID int64 } - for i, cs := range contractSectors { - if cs.DBContractID != uint(i+1) { - t.Fatal("invalid contract id") - } else if cs.DBSectorID != uint(i+1) { - t.Fatal("invalid sector id") - } + var contractSectors []contractSector + rows, err := ss.DB().Query(context.Background(), "SELECT db_contract_id, db_sector_id FROM contract_sectors") + if err != nil { + t.Fatal(err) } -} - -func TestTypeCurrency(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - // prepare the table - if isSQLite(ss.db) { - if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil { + defer rows.Close() + for rows.Next() { + var cs contractSector + if err := rows.Scan(&cs.ContractID, &cs.SectorID); err != nil { t.Fatal(err) } - } else { - if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil { - t.Fatal(err) - } - } - - // insert currencies in random order - if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", 
-		bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil {
-		t.Fatal(err)
-	}
-
-	// fetch currencies and assert they're sorted
-	var currencies []bCurrency
-	if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(&currencies).Error; err != nil {
-		t.Fatal(err)
-	} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
-		return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
-	}) {
-		t.Fatal("currencies not sorted", currencies)
+		contractSectors = append(contractSectors, cs)
 	}
-
-	// convenience variables
-	c0 := currencies[0]
-	c1 := currencies[1]
-	cM := currencies[2]
-
-	tests := []struct {
-		a   bCurrency
-		b   bCurrency
-		cmp string
-	}{
-		{
-			a:   c0,
-			b:   c1,
-			cmp: "<",
-		},
-		{
-			a:   c1,
-			b:   c0,
-			cmp: ">",
-		},
-		{
-			a:   c0,
-			b:   c1,
-			cmp: "!=",
-		},
-		{
-			a:   c1,
-			b:   c1,
-			cmp: "=",
		},
-		{
-			a:   c0,
-			b:   cM,
-			cmp: "<",
-		},
-		{
-			a:   cM,
-			b:   c0,
-			cmp: ">",
-		},
-		{
-			a:   cM,
-			b:   cM,
-			cmp: "=",
-		},
+	if len(contractSectors) != 3*totalShards {
+		t.Fatal("invalid number of contract sectors", len(contractSectors))
 	}
-	for i, test := range tests {
-		var result bool
-		query := fmt.Sprintf("SELECT ? %s ?", test.cmp)
-		if !isSQLite(ss.db) {
-			query = strings.ReplaceAll(query, "?", "HEX(?)")
-		}
-		if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil {
-			t.Fatal(err)
-		} else if !result {
-			t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
-		} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
-			t.Fatal("invalid result")
-		} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
-			t.Fatal("invalid result")
-		} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
-			t.Fatal("invalid result")
+	for i, cs := range contractSectors {
+		if cs.ContractID != int64(i+1) {
+			t.Fatal("invalid contract id")
+		} else if cs.SectorID != int64(i+1) {
+			t.Fatal("invalid sector id")
 		}
 	}
 }
@@ -4771,86 +4729,6 @@ func TestUpdateObjectParallel(t *testing.T) {
 	wg.Wait()
 }
 
-// TestFetchUsedContracts is a unit test that verifies the functionality of
-// fetchUsedContracts
-func TestFetchUsedContracts(t *testing.T) {
-	// create store
-	ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
-	defer ss.Close()
-
-	// add test host
-	hk1 := types.PublicKey{1}
-	err := ss.addTestHost(hk1)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// add test contract
-	fcid1 := types.FileContractID{1}
-	_, err = ss.addTestContract(fcid1, hk1)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// assert empty map returns no contracts
-	usedContracts := make(map[types.PublicKey]map[types.FileContractID]struct{})
-	contracts, err := fetchUsedContracts(ss.db, usedContracts)
-	if err != nil {
-		t.Fatal(err)
-	} else if len(contracts) != 0 {
-		t.Fatal("expected 0 contracts", len(contracts))
-	}
-
-	// add an entry for fcid1
-	usedContracts[hk1] = make(map[types.FileContractID]struct{})
-	usedContracts[hk1][types.FileContractID{1}] = struct{}{}
-
-	// assert we get the used contract
-	contracts, err = fetchUsedContracts(ss.db, usedContracts)
-	if err != nil {
-		t.Fatal(err)
-	} else if len(contracts) != 1 {
-		t.Fatal("expected 1 contract", len(contracts))
-	} else if _, ok := contracts[fcid1]; !ok {
-		t.Fatal("contract not found")
-	}
-
-	// renew the contract
-	fcid2 := types.FileContractID{2}
-	_, err =
ss.addTestRenewedContract(fcid2, fcid1, hk1, 1) - if err != nil { - t.Fatal(err) - } - - // assert used contracts contains one entry and it points to the renewal - contracts, err = fetchUsedContracts(ss.db, usedContracts) - if err != nil { - t.Fatal(err) - } else if len(contracts) != 1 { - t.Fatal("expected 1 contract", len(contracts)) - } else if contract, ok := contracts[fcid1]; !ok { - t.Fatal("contract not found") - } else if contract.convert().ID != fcid2 { - t.Fatal("contract should point to the renewed contract") - } - - // add an entry for fcid2 - usedContracts[hk1][types.FileContractID{2}] = struct{}{} - - // assert used contracts now contains an entry for both contracts and both - // point to the renewed contract - contracts, err = fetchUsedContracts(ss.db, usedContracts) - if err != nil { - t.Fatal(err) - } else if len(contracts) != 2 { - t.Fatal("expected 2 contracts", len(contracts)) - } else if !reflect.DeepEqual(contracts[types.FileContractID{1}], contracts[types.FileContractID{2}]) { - t.Fatal("contracts should match") - } else if contracts[types.FileContractID{1}].convert().ID != fcid2 { - t.Fatal("contracts should point to the renewed contract") - } -} - func TestDirectories(t *testing.T) { ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() @@ -4866,7 +4744,7 @@ func TestDirectories(t *testing.T) { for _, o := range objects { var dirID int64 - err := ss.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + err := ss.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { var err error dirID, err = tx.MakeDirsForPath(context.Background(), o) return err @@ -4880,8 +4758,8 @@ func TestDirectories(t *testing.T) { expectedDirs := []struct { name string - id uint - parentID uint + id int64 + parentID int64 }{ { name: "/", @@ -4905,24 +4783,37 @@ func TestDirectories(t *testing.T) { }, { name: "/dir/", - id: 2, + id: 5, parentID: 1, }, } - var dbDirs []dbDirectory - if err := ss.db.Find(&dbDirs).Error; err != nil { + type row struct { + ID int64 + ParentID int64 + Name string + } + rows, err := ss.DB().Query(context.Background(), "SELECT id, COALESCE(db_parent_id, 0), name FROM directories ORDER BY id ASC") + if err != nil { t.Fatal(err) - } else if len(dbDirs) != len(expectedDirs) { - t.Fatalf("expected %v dirs, got %v", len(expectedDirs), len(dbDirs)) } - - for i, dbDir := range dbDirs { - if dbDir.ID != uint(i+1) { - t.Fatalf("unexpected id %v", dbDir.ID) - } else if dbDir.Name != expectedDirs[i].name { - t.Fatalf("unexpected name '%v' != '%v'", dbDir.Name, expectedDirs[i].name) + defer rows.Close() + var nDirs int + for i := 0; rows.Next(); i++ { + var dir row + if err := rows.Scan(&dir.ID, &dir.ParentID, &dir.Name); err != nil { + t.Fatal(err) + } else if dir.ID != expectedDirs[i].id { + t.Fatalf("unexpected id %v", dir.ID) + } else if dir.ParentID != expectedDirs[i].parentID { + t.Fatalf("unexpected parent id %v", dir.ParentID) + } else if dir.Name != expectedDirs[i].name { + t.Fatalf("unexpected name '%v' != '%v'", dir.Name, expectedDirs[i].name) } + nDirs++ + } + if len(expectedDirs) != nDirs { + t.Fatalf("expected %v dirs, got %v", len(expectedDirs), nDirs) } now := time.Now() @@ -4931,10 +4822,7 @@ func TestDirectories(t *testing.T) { return ss.waitForPruneLoop(now) }) - var n int64 - if err := ss.db.Model(&dbDirectory{}).Count(&n).Error; err != nil { - t.Fatal(err) - } else if n != 1 { + if n := ss.Count("directories"); n != 1 { t.Fatal("expected 1 dir, got", n) } } diff --git a/stores/metrics.go b/stores/metrics.go 
index 45003dfb9..62dbde8ce 100644 --- a/stores/metrics.go +++ b/stores/metrics.go @@ -9,7 +9,7 @@ import ( ) func (s *SQLStore) ContractMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractMetricsQueryOpts) (metrics []api.ContractMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractMetrics(ctx, start, n, interval, opts) return }) @@ -17,7 +17,7 @@ func (s *SQLStore) ContractMetrics(ctx context.Context, start time.Time, n uint6 } func (s *SQLStore) ContractPruneMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractPruneMetricsQueryOpts) (metrics []api.ContractPruneMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractPruneMetrics(ctx, start, n, interval, opts) return }) @@ -25,7 +25,7 @@ func (s *SQLStore) ContractPruneMetrics(ctx context.Context, start time.Time, n } func (s *SQLStore) ContractSetChurnMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractSetChurnMetricsQueryOpts) (metrics []api.ContractSetChurnMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractSetChurnMetrics(ctx, start, n, interval, opts) return }) @@ -33,7 +33,7 @@ func (s *SQLStore) ContractSetChurnMetrics(ctx context.Context, start time.Time, } func (s *SQLStore) ContractSetMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractSetMetricsQueryOpts) (metrics []api.ContractSetMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractSetMetrics(ctx, start, n, interval, opts) return }) @@ -41,7 +41,7 @@ func (s *SQLStore) ContractSetMetrics(ctx context.Context, start time.Time, n ui } func (s *SQLStore) PerformanceMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.PerformanceMetricsQueryOpts) (metrics []api.PerformanceMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.PerformanceMetrics(ctx, start, n, interval, opts) return }) @@ -49,43 +49,43 @@ func (s *SQLStore) PerformanceMetrics(ctx context.Context, start time.Time, n ui } func (s *SQLStore) RecordContractMetric(ctx context.Context, metrics ...api.ContractMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractMetric(ctx, metrics...) }) } func (s *SQLStore) RecordContractPruneMetric(ctx context.Context, metrics ...api.ContractPruneMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractPruneMetric(ctx, metrics...) 
}) } func (s *SQLStore) RecordContractSetChurnMetric(ctx context.Context, metrics ...api.ContractSetChurnMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractSetChurnMetric(ctx, metrics...) }) } func (s *SQLStore) RecordContractSetMetric(ctx context.Context, metrics ...api.ContractSetMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractSetMetric(ctx, metrics...) }) } func (s *SQLStore) RecordPerformanceMetric(ctx context.Context, metrics ...api.PerformanceMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordPerformanceMetric(ctx, metrics...) }) } func (s *SQLStore) RecordWalletMetric(ctx context.Context, metrics ...api.WalletMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordWalletMetric(ctx, metrics...) }) } func (s *SQLStore) WalletMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.WalletMetricsQueryOpts) (metrics []api.WalletMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.WalletMetrics(ctx, start, n, interval, opts) return }) @@ -93,7 +93,7 @@ func (s *SQLStore) WalletMetrics(ctx context.Context, start time.Time, n uint64, } func (s *SQLStore) PruneMetrics(ctx context.Context, metric string, cutoff time.Time) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.PruneMetrics(ctx, metric, cutoff) }) } diff --git a/stores/metrics_test.go b/stores/metrics_test.go index 5725f7a65..9a9f7b71b 100644 --- a/stores/metrics_test.go +++ b/stores/metrics_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/stores/sql" "lukechampine.com/frand" ) @@ -149,7 +150,7 @@ func TestContractMetrics(t *testing.T) { } for _, m := range metrics { expectedMetric := fcid2Metric[m.ContractID] - expectedMetric.Timestamp = api.TimeRFC3339(normaliseTimestamp(start, interval, unixTimeMS(expectedMetric.Timestamp))) + expectedMetric.Timestamp = api.TimeRFC3339(normaliseTimestamp(start, interval, sql.UnixTimeMS(expectedMetric.Timestamp))) if !cmp.Equal(m, expectedMetric, cmp.Comparer(api.CompareTimeRFC3339)) { t.Fatal("unexpected metric", cmp.Diff(m, expectedMetric, cmp.Comparer(api.CompareTimeRFC3339))) } @@ -181,7 +182,7 @@ func TestContractMetrics(t *testing.T) { } for i, m := range metrics { var expectedMetric api.ContractMetric - expectedMetric.Timestamp = api.TimeRFC3339(normaliseTimestamp(start, time.Millisecond, unixTimeMS(metricsTimeAsc[2*i].Timestamp))) + expectedMetric.Timestamp = api.TimeRFC3339(normaliseTimestamp(start, time.Millisecond, sql.UnixTimeMS(metricsTimeAsc[2*i].Timestamp))) expectedMetric.ContractID = types.FileContractID{} expectedMetric.HostKey = types.PublicKey{} expectedMetric.RemainingCollateral, _ = 
metricsTimeAsc[2*i].RemainingCollateral.AddWithOverflow(metricsTimeAsc[2*i+1].RemainingCollateral) @@ -424,7 +425,7 @@ func TestNormaliseTimestamp(t *testing.T) { } for _, test := range tests { - if result := time.Time(normaliseTimestamp(test.start, test.interval, unixTimeMS(test.ti))); !result.Equal(test.result) { + if result := time.Time(normaliseTimestamp(test.start, test.interval, sql.UnixTimeMS(test.ti))); !result.Equal(test.result) { t.Fatalf("expected %v, got %v", test.result, result) } } @@ -527,6 +528,7 @@ func TestWalletMetrics(t *testing.T) { Confirmed: types.NewCurrency(frand.Uint64n(math.MaxUint64), frand.Uint64n(math.MaxUint64)), Unconfirmed: types.NewCurrency(frand.Uint64n(math.MaxUint64), frand.Uint64n(math.MaxUint64)), Spendable: types.NewCurrency(frand.Uint64n(math.MaxUint64), frand.Uint64n(math.MaxUint64)), + Immature: types.NewCurrency(frand.Uint64n(math.MaxUint64), frand.Uint64n(math.MaxUint64)), } if err := ss.RecordWalletMetric(context.Background(), metric); err != nil { t.Fatal(err) @@ -555,13 +557,13 @@ func TestWalletMetrics(t *testing.T) { } } -func normaliseTimestamp(start time.Time, interval time.Duration, t unixTimeMS) unixTimeMS { +func normaliseTimestamp(start time.Time, interval time.Duration, t sql.UnixTimeMS) sql.UnixTimeMS { startMS := start.UnixMilli() toNormaliseMS := time.Time(t).UnixMilli() intervalMS := interval.Milliseconds() if startMS > toNormaliseMS { - return unixTimeMS(start) + return sql.UnixTimeMS(start) } normalizedMS := (toNormaliseMS-startMS)/intervalMS*intervalMS + start.UnixMilli() - return unixTimeMS(time.UnixMilli(normalizedMS)) + return sql.UnixTimeMS(time.UnixMilli(normalizedMS)) } diff --git a/stores/multipart.go b/stores/multipart.go index 95aad7104..ec987619f 100644 --- a/stores/multipart.go +++ b/stores/multipart.go @@ -12,7 +12,7 @@ import ( func (s *SQLStore) CreateMultipartUpload(ctx context.Context, bucket, path string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (api.MultipartCreateResponse, error) { var uploadID string - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { uploadID, err = tx.InsertMultipartUpload(ctx, bucket, path, ec, mimeType, metadata) return }) @@ -25,13 +25,13 @@ func (s *SQLStore) CreateMultipartUpload(ctx context.Context, bucket, path strin } func (s *SQLStore) AddMultipartPart(ctx context.Context, bucket, path, contractSet, eTag, uploadID string, partNumber int, slices []object.SlabSlice) (err error) { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.AddMultipartPart(ctx, bucket, path, contractSet, eTag, uploadID, partNumber, slices) }) } func (s *SQLStore) MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.MultipartUpload(ctx, uploadID) return }) @@ -39,7 +39,7 @@ func (s *SQLStore) MultipartUpload(ctx context.Context, uploadID string) (resp a } func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMarker, uploadIDMarker string, limit int) (resp api.MultipartListUploadsResponse, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.MultipartUploads(ctx, bucket, prefix, 
keyMarker, uploadIDMarker, limit) return }) @@ -47,7 +47,7 @@ func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMark } func (s *SQLStore) MultipartUploadParts(ctx context.Context, bucket, object string, uploadID string, marker int, limit int64) (resp api.MultipartListPartsResponse, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.MultipartUploadParts(ctx, bucket, object, uploadID, marker, limit) return }) @@ -55,7 +55,7 @@ func (s *SQLStore) MultipartUploadParts(ctx context.Context, bucket, object stri } func (s *SQLStore) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.AbortMultipartUpload(ctx, bucket, path, uploadID) }) if err != nil { @@ -80,7 +80,7 @@ func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path str var eTag string var prune bool - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // Delete potentially existing object. prune, err = tx.DeleteObject(ctx, bucket, path) if err != nil { diff --git a/stores/multipart_test.go b/stores/multipart_test.go index 762ea45a9..52dddb050 100644 --- a/stores/multipart_test.go +++ b/stores/multipart_test.go @@ -71,31 +71,43 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) { }) } + type oum struct { + MultipartUploadID *int64 + ObjectID *int64 + } + fetchUserMD := func() (metadatas []oum) { + rows, err := ss.DB().Query(context.Background(), "SELECT db_multipart_upload_id, db_object_id FROM object_user_metadata") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var md oum + if err := rows.Scan(&md.MultipartUploadID, &md.ObjectID); err != nil { + t.Fatal(err) + } + metadatas = append(metadatas, md) + } + return + } + // Assert metadata was persisted and is linked to the multipart upload - var metadatas []dbObjectUserMetadata - if err := ss.db.Model(&dbObjectUserMetadata{}).Find(&metadatas).Error; err != nil { - t.Fatal(err) - } else if len(metadatas) != len(testMetadata) { + metadatas := fetchUserMD() + if len(metadatas) != len(testMetadata) { t.Fatal("expected metadata to be persisted") } for _, m := range metadatas { - if m.DBMultipartUploadID == nil || m.DBObjectID != nil { + if m.MultipartUploadID == nil || m.ObjectID != nil { t.Fatal("unexpected") } } // Complete the upload. Check that the number of slices stays the same. 
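+	// (ss.Count is a small test helper that returns the row count of the given table)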
- var nSlicesBefore int64 - var nSlicesAfter int64 - if err := ss.db.Model(&dbSlice{}).Count(&nSlicesBefore).Error; err != nil { - t.Fatal(err) - } else if nSlicesBefore == 0 { + if nSlicesBefore := ss.Count("slices"); nSlicesBefore == 0 { t.Fatal("expected some slices") } else if _, err = ss.CompleteMultipartUpload(ctx, api.DefaultBucketName, objName, resp.UploadID, parts, api.CompleteMultipartOptions{}); err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlice{}).Count(&nSlicesAfter).Error; err != nil { - t.Fatal(err) - } else if nSlicesBefore != nSlicesAfter { + } else if nSlicesAfter := ss.Count("slices"); nSlicesAfter != nSlicesBefore { t.Fatalf("expected number of slices to stay the same, but got %v before and %v after", nSlicesBefore, nSlicesAfter) } @@ -115,13 +127,12 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) { } // Assert metadata was converted and the multipart upload id was nullified - if err := ss.db.Model(&dbObjectUserMetadata{}).Find(&metadatas).Error; err != nil { - t.Fatal(err) - } else if len(metadatas) != len(testMetadata) { + metadatas = fetchUserMD() + if len(metadatas) != len(testMetadata) { t.Fatal("expected metadata to be persisted") } for _, m := range metadatas { - if m.DBMultipartUploadID != nil || m.DBObjectID == nil { + if m.MultipartUploadID != nil || m.ObjectID == nil { t.Fatal("unexpected") } } diff --git a/stores/peers.go b/stores/peers.go new file mode 100644 index 000000000..5937a9d68 --- /dev/null +++ b/stores/peers.go @@ -0,0 +1,65 @@ +package stores + +import ( + "context" + "time" + + "go.sia.tech/coreutils/syncer" + "go.sia.tech/renterd/stores/sql" +) + +var ( + _ syncer.PeerStore = (*SQLStore)(nil) +) + +// AddPeer adds a peer to the store. If the peer already exists, nil should be +// returned. +func (s *SQLStore) AddPeer(addr string) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return tx.AddPeer(context.Background(), addr) + }) +} + +// Peers returns the set of known peers. +func (s *SQLStore) Peers() (peers []syncer.PeerInfo, err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { + peers, txErr = tx.Peers(context.Background()) + return + }) + return +} + +// PeerInfo returns the metadata for the specified peer or ErrPeerNotFound +// if the peer wasn't found in the store. +func (s *SQLStore) PeerInfo(addr string) (info syncer.PeerInfo, err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { + info, txErr = tx.PeerInfo(context.Background(), addr) + return + }) + return +} + +// UpdatePeerInfo updates the metadata for the specified peer. If the peer +// is not found, the error should be ErrPeerNotFound. +func (s *SQLStore) UpdatePeerInfo(addr string, fn func(*syncer.PeerInfo)) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return tx.UpdatePeerInfo(context.Background(), addr, fn) + }) +} + +// Ban temporarily bans one or more IPs. The addr should either be a single +// IP with port (e.g. 1.2.3.4:5678) or a CIDR subnet (e.g. 1.2.3.4/16). +func (s *SQLStore) Ban(addr string, duration time.Duration, reason string) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return tx.BanPeer(context.Background(), addr, duration, reason) + }) +} + +// Banned returns true, nil if the peer is banned. 
+func (s *SQLStore) Banned(addr string) (banned bool, err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { + banned, txErr = tx.PeerBanned(context.Background(), addr) + return + }) + return +} diff --git a/stores/peers_test.go b/stores/peers_test.go new file mode 100644 index 000000000..7442962a1 --- /dev/null +++ b/stores/peers_test.go @@ -0,0 +1,165 @@ +package stores + +import ( + "errors" + "testing" + "time" + + "go.sia.tech/coreutils/syncer" +) + +const ( + testPeer = "1.2.3.4:9981" +) + +func TestPeers(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + // assert ErrPeerNotFound before we add it + err := ss.UpdatePeerInfo(testPeer, func(info *syncer.PeerInfo) {}) + if !errors.Is(err, syncer.ErrPeerNotFound) { + t.Fatal("expected peer not found") + } + + // add peer + err = ss.AddPeer(testPeer) + if err != nil { + t.Fatal(err) + } + + // fetch peers + var peer syncer.PeerInfo + peers, err := ss.Peers() + if err != nil { + t.Fatal(err) + } else if len(peers) != 1 { + t.Fatal("expected 1 peer") + } else { + peer = peers[0] + } + + // assert peer info + if peer.Address != testPeer { + t.Fatal("unexpected address") + } else if peer.FirstSeen.IsZero() { + t.Fatal("unexpected first seen") + } else if !peer.LastConnect.IsZero() { + t.Fatal("unexpected last connect") + } else if peer.SyncedBlocks != 0 { + t.Fatal("unexpected synced blocks") + } else if peer.SyncDuration != 0 { + t.Fatal("unexpected sync duration") + } + + // prepare peer update + lastConnect := time.Now().Truncate(time.Millisecond) + syncedBlocks := uint64(15) + syncDuration := 5 * time.Second + + // update peer + err = ss.UpdatePeerInfo(testPeer, func(info *syncer.PeerInfo) { + info.LastConnect = lastConnect + info.SyncedBlocks = syncedBlocks + info.SyncDuration = syncDuration + }) + if err != nil { + t.Fatal(err) + } + + // refetch peer + peers, err = ss.Peers() + if err != nil { + t.Fatal(err) + } else if len(peers) != 1 { + t.Fatal("expected 1 peer") + } else { + peer = peers[0] + } + + // assert peer info + if peer.Address != testPeer { + t.Fatal("unexpected address") + } else if peer.FirstSeen.IsZero() { + t.Fatal("unexpected first seen") + } else if !peer.LastConnect.Equal(lastConnect) { + t.Fatal("unexpected last connect") + } else if peer.SyncedBlocks != syncedBlocks { + t.Fatal("unexpected synced blocks") + } else if peer.SyncDuration != syncDuration { + t.Fatal("unexpected sync duration") + } + + // ban peer + err = ss.Ban(testPeer, time.Hour, "too many hits") + if err != nil { + t.Fatal(err) + } + + // assert the peer was banned + banned, err := ss.Banned(testPeer) + if err != nil { + t.Fatal(err) + } else if !banned { + t.Fatal("expected banned") + } + + // add another banned peer + bannedPeer := "1.2.3.4:9982" + err = ss.AddPeer(bannedPeer) + if err != nil { + t.Fatal(err) + } + + // add another unbanned peer + unbannedPeer := "1.2.3.5:9981" + err = ss.AddPeer(unbannedPeer) + if err != nil { + t.Fatal(err) + } + + // assert we have three peers + peers, err = ss.Peers() + if err != nil { + t.Fatal(err) + } else if len(peers) != 3 { + t.Fatalf("expected 3 peers, got %d", len(peers)) + } + + // assert the peers are properly banned + banned, err = ss.Banned(bannedPeer) + if err != nil { + t.Fatal(err) + } else if !banned { + t.Fatal("expected banned") + } + + banned, err = ss.Banned(unbannedPeer) + if err != nil { + t.Fatal(err) + } else if banned { + t.Fatal("expected unbanned") + } + + // ban by cidr + err = 
ss.Ban("192.168.1.0/30", time.Hour, "too many hits") + if err != nil { + t.Fatal(err) + } + + // assert address within subnet is banned + banned, err = ss.Banned("192.168.1.1") + if err != nil { + t.Fatal(err) + } else if !banned { + t.Fatal("expected banned") + } + + // assert address outside subnet is not banned + banned, err = ss.Banned("192.168.1.4") + if err != nil { + t.Fatal(err) + } else if banned { + t.Fatal("expected unbanned") + } +} diff --git a/stores/settingsdb.go b/stores/settingsdb.go index 08d0d3faf..7a895108c 100644 --- a/stores/settingsdb.go +++ b/stores/settingsdb.go @@ -13,7 +13,7 @@ func (s *SQLStore) DeleteSetting(ctx context.Context, key string) error { defer s.settingsMu.Unlock() // delete from database first - if err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + if err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.DeleteSettings(ctx, key) }); err != nil { return err @@ -36,7 +36,7 @@ func (s *SQLStore) Setting(ctx context.Context, key string) (string, error) { // Check database. var err error - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { value, err = tx.Setting(ctx, key) return err }) @@ -49,7 +49,7 @@ func (s *SQLStore) Setting(ctx context.Context, key string) (string, error) { // Settings implements the bus.SettingStore interface. func (s *SQLStore) Settings(ctx context.Context) (settings []string, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { settings, err = tx.Settings(ctx) return err }) @@ -62,7 +62,7 @@ func (s *SQLStore) UpdateSetting(ctx context.Context, key, value string) error { s.settingsMu.Lock() defer s.settingsMu.Unlock() - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateSetting(ctx, key, value) }) if err != nil { diff --git a/stores/settingsdb_test.go b/stores/settingsdb_test.go index b6708d6cb..cf2582579 100644 --- a/stores/settingsdb_test.go +++ b/stores/settingsdb_test.go @@ -55,7 +55,7 @@ func TestSQLSettingStore(t *testing.T) { if err := ss.DeleteSetting(ctx, "foo"); err != nil { t.Fatal(err) } else if _, err := ss.Setting(ctx, "foo"); !errors.Is(err, api.ErrSettingNotFound) { - t.Fatal("should fail with gorm.ErrRecordNotFound", err) + t.Fatal("should fail with api.ErrSettingNotFound", err) } else if keys, err := ss.Settings(ctx); err != nil { t.Fatal(err) } else if len(keys) != 0 { diff --git a/stores/slabbuffer.go b/stores/slabbuffer.go index 77e2574a3..5e8a542b8 100644 --- a/stores/slabbuffer.go +++ b/stores/slabbuffer.go @@ -53,7 +53,8 @@ type SlabBufferManager struct { buffersByKey map[string]*SlabBuffer } -func newSlabBufferManager(ctx context.Context, a alerts.Alerter, db sql.Database, logger *zap.SugaredLogger, slabBufferCompletionThreshold int64, partialSlabDir string) (*SlabBufferManager, error) { +func newSlabBufferManager(ctx context.Context, a alerts.Alerter, db sql.Database, logger *zap.Logger, slabBufferCompletionThreshold int64, partialSlabDir string) (*SlabBufferManager, error) { + logger = logger.Named("slabbuffers") if slabBufferCompletionThreshold < 0 || slabBufferCompletionThreshold > 1<<22 { return nil, fmt.Errorf("invalid slabBufferCompletionThreshold %v", slabBufferCompletionThreshold) } @@ -69,7 +70,7 @@ func newSlabBufferManager(ctx context.Context, a alerts.Alerter, db sql.Database bufferedSlabCompletionThreshold: 
slabBufferCompletionThreshold, db: db, dir: partialSlabDir, - logger: logger, + logger: logger.Sugar(), completeBuffers: make(map[bufferGroupID][]*SlabBuffer), incompleteBuffers: make(map[bufferGroupID][]*SlabBuffer), @@ -98,7 +99,7 @@ func newSlabBufferManager(ctx context.Context, a alerts.Alerter, db sql.Database }, Timestamp: time.Now(), }) - logger.Errorf("failed to open buffer file %v for slab %v: %v", buffer.Filename, buffer.Key, err) + logger.Sugar().Errorf("failed to open buffer file %v for slab %v: %v", buffer.Filename, buffer.Key, err) continue } @@ -146,15 +147,21 @@ func (mgr *SlabBufferManager) Close() error { return errors.Join(errs...) } -func (mgr *SlabBufferManager) AddPartialSlab(ctx context.Context, data []byte, minShards, totalShards uint8, contractSet uint) ([]object.SlabSlice, int64, error) { - gid := bufferGID(minShards, totalShards, uint32(contractSet)) +func (mgr *SlabBufferManager) AddPartialSlab(ctx context.Context, data []byte, minShards, totalShards uint8, contractSet string) ([]object.SlabSlice, int64, error) { + var set int64 + err := mgr.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + set, err = tx.ContractSetID(ctx, contractSet) + return err + }) + if err != nil { + return nil, 0, err + } + gid := bufferGID(minShards, totalShards, uint32(set)) // Sanity check input. slabSize := bufferedSlabSize(minShards) if minShards == 0 || totalShards == 0 || minShards > totalShards { return nil, 0, fmt.Errorf("invalid shard configuration: minShards=%v, totalShards=%v", minShards, totalShards) - } else if contractSet == 0 { - return nil, 0, fmt.Errorf("contract set must be set") } else if len(data) > slabSize { return nil, 0, fmt.Errorf("data size %v exceeds size of a slab %v", len(data), slabSize) } @@ -170,7 +177,6 @@ func (mgr *SlabBufferManager) AddPartialSlab(ctx context.Context, data []byte, m // the data over too many slabs. var slab object.SlabSlice var slabs []object.SlabSlice - var err error var usedBuffers []*SlabBuffer for _, buffer := range buffers { var used bool @@ -191,7 +197,7 @@ func (mgr *SlabBufferManager) AddPartialSlab(ctx context.Context, data []byte, m if len(data) > 0 { var sb *SlabBuffer err = mgr.db.Transaction(ctx, func(tx sql.DatabaseTx) error { - sb, err = createSlabBuffer(ctx, tx, contractSet, mgr.dir, minShards, totalShards) + sb, err = createSlabBuffer(ctx, tx, uint(set), mgr.dir, minShards, totalShards) return err }) if err != nil { @@ -291,7 +297,16 @@ func (mgr *SlabBufferManager) SlabBuffers() (sbs []api.SlabBuffer) { return sbs } -func (mgr *SlabBufferManager) SlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set uint, limit int) (slabs []api.PackedSlab, _ error) { +func (mgr *SlabBufferManager) SlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, contractSet string, limit int) (slabs []api.PackedSlab, _ error) { + var set int64 + err := mgr.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + set, err = tx.ContractSetID(ctx, contractSet) + return err + }) + if err != nil { + return nil, err + } + // Deep copy complete buffers. We don't want to block the manager while we // perform disk I/O. 
mgr.mu.Lock() diff --git a/stores/slabbuffer_test.go b/stores/slabbuffer_test.go index 4425fcff1..0a0c03192 100644 --- a/stores/slabbuffer_test.go +++ b/stores/slabbuffer_test.go @@ -8,31 +8,36 @@ import ( "lukechampine.com/frand" ) +func (s *testSQLStore) ContractSetID(name string) (id int64) { + if err := s.DB().QueryRow(context.Background(), "SELECT id FROM contract_sets WHERE name = ?", name). + Scan(&id); err != nil { + s.t.Fatal(err) + } + return +} + func TestRecordAppendToCompletedBuffer(t *testing.T) { ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() completionThreshold := int64(1000) - mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.bMain, ss.logger, completionThreshold, t.TempDir()) + mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.db, ss.logger.Desugar(), completionThreshold, t.TempDir()) if err != nil { t.Fatal(err) } defer mgr.Close() // get contract set for its id - var set dbContractSet - if err := ss.db.Where("name", testContractSet).Take(&set).Error; err != nil { - t.Fatal(err) - } + csID := ss.ContractSetID(testContractSet) // compute gid - gid := bufferGID(1, 2, uint32(set.ID)) + gid := bufferGID(1, 2, uint32(csID)) // add a slab that immediately fills a buffer but has 100 bytes left minShards := uint8(1) totalShards := uint8(2) maxSize := bufferedSlabSize(minShards) - _, _, err = mgr.AddPartialSlab(context.Background(), frand.Bytes(maxSize-100), minShards, totalShards, set.ID) + _, _, err = mgr.AddPartialSlab(context.Background(), frand.Bytes(maxSize-100), minShards, totalShards, testContractSet) if err != nil { t.Fatal(err) } else if len(mgr.completeBuffers[gid]) != 1 { @@ -52,7 +57,7 @@ func TestRecordAppendToCompletedBuffer(t *testing.T) { // add a slab that should fit in the buffer but since the first buffer is // complete we ignore it - _, _, err = mgr.AddPartialSlab(context.Background(), frand.Bytes(1), minShards, totalShards, set.ID) + _, _, err = mgr.AddPartialSlab(context.Background(), frand.Bytes(1), minShards, totalShards, testContractSet) if err != nil { t.Fatal(err) } else if len(mgr.completeBuffers[gid]) != 1 { @@ -66,23 +71,20 @@ func TestMarkBufferCompleteTwice(t *testing.T) { ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() - mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.bMain, ss.logger, 0, t.TempDir()) + mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.db, ss.logger.Desugar(), 0, t.TempDir()) if err != nil { t.Fatal(err) } defer mgr.Close() // get contract set for its id - var set dbContractSet - if err := ss.db.Where("name", testContractSet).Take(&set).Error; err != nil { - t.Fatal(err) - } + csID := ss.ContractSetID(testContractSet) // compute gid - gid := bufferGID(1, 2, uint32(set.ID)) + gid := bufferGID(1, 2, uint32(csID)) // create an incomplete buffer - _, _, err = mgr.AddPartialSlab(context.Background(), frand.Bytes(1), 1, 2, set.ID) + _, _, err = mgr.AddPartialSlab(context.Background(), frand.Bytes(1), 1, 2, testContractSet) if err != nil { t.Fatal(err) } diff --git a/stores/sql.go b/stores/sql.go index 8cc86be62..50533768d 100644 --- a/stores/sql.go +++ b/stores/sql.go @@ -2,49 +2,30 @@ package stores import ( "context" - "errors" "fmt" "math" "os" - "strings" "sync" "time" "go.sia.tech/core/types" "go.sia.tech/renterd/alerts" - "go.sia.tech/renterd/api" "go.sia.tech/renterd/stores/sql" - "go.sia.tech/renterd/stores/sql/mysql" - "go.sia.tech/renterd/stores/sql/sqlite" - "go.sia.tech/siad/modules" "go.uber.org/zap" - 
gmysql "gorm.io/driver/mysql" - gsqlite "gorm.io/driver/sqlite" - "gorm.io/gorm" - glogger "gorm.io/gorm/logger" ) type ( - // Model defines the common fields of every table. Same as Model - // but excludes soft deletion since it breaks cascading deletes. - Model struct { - ID uint `gorm:"primarykey"` - CreatedAt time.Time - } - // Config contains all params for creating a SQLStore Config struct { - Conn gorm.Dialector + DB sql.Database DBMetrics sql.MetricsDatabase Alerts alerts.Alerter PartialSlabDir string Migrate bool AnnouncementMaxAge time.Duration - PersistInterval time.Duration WalletAddress types.Address SlabBufferCompletionThreshold int64 - Logger *zap.SugaredLogger - GormLogger glogger.Interface + Logger *zap.Logger RetryTransactionIntervals []time.Duration LongQueryDuration time.Duration LongTxDuration time.Duration @@ -52,206 +33,72 @@ type ( // SQLStore is a helper type for interacting with a SQL-based backend. SQLStore struct { - alerts alerts.Alerter - db *gorm.DB - bMain sql.Database - bMetrics sql.MetricsDatabase - logger *zap.SugaredLogger - - slabBufferMgr *SlabBufferManager + alerts alerts.Alerter + db sql.Database + dbMetrics sql.MetricsDatabase + logger *zap.SugaredLogger - retryTransactionIntervals []time.Duration - - // Persistence buffer - related fields. - lastSave time.Time - persistInterval time.Duration - persistMu sync.Mutex - persistTimer *time.Timer - unappliedAnnouncements []announcement - unappliedContractState map[types.FileContractID]contractState - unappliedHostKeys map[types.PublicKey]struct{} - unappliedRevisions map[types.FileContractID]revisionUpdate - unappliedProofs map[types.FileContractID]uint64 - unappliedOutputChanges []outputChange - unappliedTxnChanges []txnChange + walletAddress types.Address - // HostDB related fields - announcementMaxAge time.Duration + // ObjectDB related fields + slabBufferMgr *SlabBufferManager - // SettingsDB related fields. + // SettingsDB related fields settingsMu sync.Mutex settings map[string]string - // WalletDB related fields. - walletAddress types.Address - - // Consensus related fields. - ccid modules.ConsensusChangeID - chainIndex types.ChainIndex + retryTransactionIntervals []time.Duration shutdownCtx context.Context shutdownCtxCancel context.CancelFunc slabPruneSigChan chan struct{} + wg sync.WaitGroup - wg sync.WaitGroup mu sync.Mutex lastPrunedAt time.Time closed bool - - knownContracts map[types.FileContractID]struct{} - } - - revisionUpdate struct { - height uint64 - number uint64 - size uint64 } ) -// NewEphemeralSQLiteConnection creates a connection to an in-memory SQLite DB. -// NOTE: Use simple names such as a random hex identifier or the filepath.Base -// of a test's name. Certain symbols will break the cfg string and cause a file -// to be created on disk. -// -// mode: set to memory for in-memory database -// cache: set to shared which is required for in-memory databases -// _foreign_keys: enforce foreign_key relations -func NewEphemeralSQLiteConnection(name string) gorm.Dialector { - return gsqlite.Open(fmt.Sprintf("file:%s?mode=memory&cache=shared&_foreign_keys=1", name)) -} - -// NewSQLiteConnection opens a sqlite db at the given path. -// -// _busy_timeout: set to prevent concurrent transactions from failing and -// instead have them block -// _foreign_keys: enforce foreign_key relations -// _journal_mode: set to WAL instead of delete since it's usually the fastest. -// Only downside is that the db won't work on network drives. 
In that case this -// should be made configurable and set to TRUNCATE or any of the other options. -// For reference see https://github.com/mattn/go-sqlite3#connection-string. -func NewSQLiteConnection(path string) gorm.Dialector { - return gsqlite.Open(fmt.Sprintf("file:%s?_busy_timeout=30000&_foreign_keys=1&_journal_mode=WAL&_secure_delete=false&_cache_size=65536", path)) -} - -// NewMetricsSQLiteConnection opens a sqlite db at the given path similarly to -// NewSQLiteConnection but with weaker consistency guarantees since it's -// optimised for recording metrics. -func NewMetricsSQLiteConnection(path string) gorm.Dialector { - return gsqlite.Open(fmt.Sprintf("file:%s?_busy_timeout=30000&_foreign_keys=1&_journal_mode=WAL&_synchronous=NORMAL", path)) -} - -// NewMySQLConnection creates a connection to a MySQL database. -func NewMySQLConnection(user, password, addr, dbName string) gorm.Dialector { - return gmysql.Open(fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&multiStatements=true", user, password, addr, dbName)) -} - // NewSQLStore uses a given Dialector to connect to a SQL database. NOTE: Only // pass migrate=true for the first instance of SQLHostDB if you connect via the // same Dialector multiple times. -func NewSQLStore(cfg Config) (*SQLStore, modules.ConsensusChangeID, error) { - // Sanity check announcement max age. - if cfg.AnnouncementMaxAge == 0 { - return nil, modules.ConsensusChangeID{}, errors.New("announcementMaxAge must be non-zero") - } - +func NewSQLStore(cfg Config) (*SQLStore, error) { if err := os.MkdirAll(cfg.PartialSlabDir, 0700); err != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to create partial slab dir: %v", err) - } - db, err := gorm.Open(cfg.Conn, &gorm.Config{ - Logger: cfg.GormLogger, // custom logger - SkipDefaultTransaction: true, - DisableNestedTransaction: true, - }) - if err != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to open SQL db") + return nil, fmt.Errorf("failed to create partial slab dir '%s': %v", cfg.PartialSlabDir, err) } l := cfg.Logger.Named("sql") - - sqlDB, err := db.DB() - if err != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to fetch db: %v", err) - } - - // Print DB version - var dbMain sql.Database + dbMain := cfg.DB dbMetrics := cfg.DBMetrics - var mainErr error - if cfg.Conn.Name() == "sqlite" { - dbMain, mainErr = sqlite.NewMainDatabase(sqlDB, l, cfg.LongQueryDuration, cfg.LongTxDuration) - } else { - dbMain, mainErr = mysql.NewMainDatabase(sqlDB, l, cfg.LongQueryDuration, cfg.LongTxDuration) - } - if mainErr != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to create main database: %v", mainErr) - } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - dbName, dbVersion, err := dbMain.Version(ctx) + // Print DB version + dbName, dbVersion, err := dbMain.Version(context.Background()) if err != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to fetch db version: %v", err) + return nil, fmt.Errorf("failed to fetch db version: %v", err) } - l.Infof("Using %s version %s", dbName, dbVersion) + l.Sugar().Infof("Using %s version %s", dbName, dbVersion) // Perform migrations. 
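+	// The main and metrics schemas are migrated independently since they are
+	// kept in separate databases.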
if cfg.Migrate { if err := dbMain.Migrate(context.Background()); err != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to perform migrations: %v", err) + return nil, fmt.Errorf("failed to perform migrations: %v", err) } else if err := dbMetrics.Migrate(context.Background()); err != nil { - return nil, modules.ConsensusChangeID{}, fmt.Errorf("failed to perform migrations for metrics db: %v", err) + return nil, fmt.Errorf("failed to perform migrations for metrics db: %v", err) } } - // Get latest consensus change ID or init db. - ci, ccid, err := initConsensusInfo(ctx, dbMain) - if err != nil { - return nil, modules.ConsensusChangeID{}, err - } - - // Fetch contract ids. - var activeFCIDs, archivedFCIDs []fileContractID - if err := db.Model(&dbContract{}). - Select("fcid"). - Find(&activeFCIDs).Error; err != nil { - return nil, modules.ConsensusChangeID{}, err - } - if err := db.Model(&dbArchivedContract{}). - Select("fcid"). - Find(&archivedFCIDs).Error; err != nil { - return nil, modules.ConsensusChangeID{}, err - } - isOurContract := make(map[types.FileContractID]struct{}) - for _, fcid := range append(activeFCIDs, archivedFCIDs...) { - isOurContract[types.FileContractID(fcid)] = struct{}{} - } - shutdownCtx, shutdownCtxCancel := context.WithCancel(context.Background()) ss := &SQLStore{ - alerts: cfg.Alerts, - ccid: ccid, - db: db, - bMain: dbMain, - bMetrics: dbMetrics, - logger: l, - knownContracts: isOurContract, - lastSave: time.Now(), - persistInterval: cfg.PersistInterval, - settings: make(map[string]string), - slabPruneSigChan: make(chan struct{}, 1), - unappliedContractState: make(map[types.FileContractID]contractState), - unappliedHostKeys: make(map[types.PublicKey]struct{}), - unappliedRevisions: make(map[types.FileContractID]revisionUpdate), - unappliedProofs: make(map[types.FileContractID]uint64), - - announcementMaxAge: cfg.AnnouncementMaxAge, + alerts: cfg.Alerts, + db: dbMain, + dbMetrics: dbMetrics, + logger: l.Sugar(), + settings: make(map[string]string), walletAddress: cfg.WalletAddress, - chainIndex: types.ChainIndex{ - Height: ci.Height, - ID: types.BlockID(ci.ID), - }, + slabPruneSigChan: make(chan struct{}, 1), lastPrunedAt: time.Now(), retryTransactionIntervals: cfg.RetryTransactionIntervals, @@ -259,25 +106,14 @@ func NewSQLStore(cfg Config) (*SQLStore, modules.ConsensusChangeID, error) { shutdownCtxCancel: shutdownCtxCancel, } - ss.slabBufferMgr, err = newSlabBufferManager(shutdownCtx, cfg.Alerts, dbMain, l.Named("slabbuffers"), cfg.SlabBufferCompletionThreshold, cfg.PartialSlabDir) + ss.slabBufferMgr, err = newSlabBufferManager(shutdownCtx, cfg.Alerts, dbMain, l, cfg.SlabBufferCompletionThreshold, cfg.PartialSlabDir) if err != nil { - return nil, modules.ConsensusChangeID{}, err + return nil, err } if err := ss.initSlabPruning(); err != nil { - return nil, modules.ConsensusChangeID{}, err - } - return ss, ccid, nil -} - -func isSQLite(db *gorm.DB) bool { - switch db.Dialector.(type) { - case *gsqlite.Dialector: - return true - case *gmysql.Dialector: - return false - default: - panic(fmt.Sprintf("unknown dialector: %t", db.Dialector)) + return nil, err } + return ss, nil } func (s *SQLStore) initSlabPruning() error { @@ -289,7 +125,7 @@ func (s *SQLStore) initSlabPruning() error { }() // prune once to guarantee consistency on startup - return s.bMain.Transaction(s.shutdownCtx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(s.shutdownCtx, func(tx sql.DatabaseTx) error { _, err := tx.PruneSlabs(s.shutdownCtx, math.MaxInt64) return 
err }) @@ -298,18 +134,17 @@ func (s *SQLStore) initSlabPruning() error { // Close closes the underlying database connection of the store. func (s *SQLStore) Close() error { s.shutdownCtxCancel() - s.wg.Wait() - err := s.bMain.Close() + err := s.slabBufferMgr.Close() if err != nil { return err } - err = s.bMetrics.Close() + + err = s.db.Close() if err != nil { return err } - - err = s.slabBufferMgr.Close() + err = s.dbMetrics.Close() if err != nil { return err } @@ -319,219 +154,3 @@ func (s *SQLStore) Close() error { s.mu.Unlock() return nil } - -// ProcessConsensusChange implements consensus.Subscriber. -func (ss *SQLStore) ProcessConsensusChange(cc modules.ConsensusChange) { - ss.persistMu.Lock() - defer ss.persistMu.Unlock() - - ss.processConsensusChangeHostDB(cc) - ss.processConsensusChangeContracts(cc) - ss.processConsensusChangeWallet(cc) - - // Update consensus fields. - ss.ccid = cc.ID - ss.chainIndex = types.ChainIndex{ - Height: uint64(cc.BlockHeight), - ID: types.BlockID(cc.AppliedBlocks[len(cc.AppliedBlocks)-1].ID()), - } - - // Try to apply the updates. - if err := ss.applyUpdates(false); err != nil { - ss.logger.Error(fmt.Sprintf("failed to apply updates, err: %v", err)) - } - - // Force a persist if no block has been received for some time. - if ss.persistTimer != nil { - ss.persistTimer.Stop() - select { - case <-ss.persistTimer.C: - default: - } - } - ss.persistTimer = time.AfterFunc(10*time.Second, func() { - ss.mu.Lock() - if ss.closed { - ss.mu.Unlock() - return - } - ss.mu.Unlock() - - ss.persistMu.Lock() - defer ss.persistMu.Unlock() - if err := ss.applyUpdates(true); err != nil { - ss.logger.Error(fmt.Sprintf("failed to apply updates, err: %v", err)) - } - }) -} - -// applyUpdates applies all unapplied updates to the database. -func (ss *SQLStore) applyUpdates(force bool) error { - // Check if we need to apply changes - persistIntervalPassed := time.Since(ss.lastSave) > ss.persistInterval // enough time has passed since last persist - softLimitReached := len(ss.unappliedAnnouncements) >= announcementBatchSoftLimit // enough announcements have accumulated - unappliedRevisionsOrProofs := len(ss.unappliedRevisions) > 0 || len(ss.unappliedProofs) > 0 // enough revisions/proofs have accumulated - unappliedOutputsOrTxns := len(ss.unappliedOutputChanges) > 0 || len(ss.unappliedTxnChanges) > 0 // enough outputs/txns have accumualted - unappliedContractState := len(ss.unappliedContractState) > 0 // the chain state of a contract changed - if !force && !persistIntervalPassed && !softLimitReached && !unappliedRevisionsOrProofs && !unappliedOutputsOrTxns && !unappliedContractState { - return nil - } - - // Fetch allowlist - var allowlist []dbAllowlistEntry - if err := ss.db. - Model(&dbAllowlistEntry{}). - Find(&allowlist). - Error; err != nil { - ss.logger.Error(fmt.Sprintf("failed to fetch allowlist, err: %v", err)) - } - - // Fetch blocklist - var blocklist []dbBlocklistEntry - if err := ss.db. - Model(&dbBlocklistEntry{}). - Find(&blocklist). 
- Error; err != nil { - ss.logger.Error(fmt.Sprintf("failed to fetch blocklist, err: %v", err)) - } - - err := ss.retryTransaction(context.Background(), func(tx *gorm.DB) (err error) { - if len(ss.unappliedAnnouncements) > 0 { - if err = insertAnnouncements(tx, ss.unappliedAnnouncements); err != nil { - return fmt.Errorf("%w; failed to insert %d announcements", err, len(ss.unappliedAnnouncements)) - } - } - if len(ss.unappliedHostKeys) > 0 && (len(allowlist)+len(blocklist)) > 0 { - for host := range ss.unappliedHostKeys { - if err := updateBlocklist(tx, host, allowlist, blocklist); err != nil { - ss.logger.Error(fmt.Sprintf("failed to update blocklist, err: %v", err)) - } - } - } - for fcid, rev := range ss.unappliedRevisions { - if err := applyRevisionUpdate(tx, types.FileContractID(fcid), rev); err != nil { - return fmt.Errorf("%w; failed to update revision number and height", err) - } - } - for fcid, proofHeight := range ss.unappliedProofs { - if err := updateProofHeight(tx, types.FileContractID(fcid), proofHeight); err != nil { - return fmt.Errorf("%w; failed to update proof height", err) - } - } - for _, oc := range ss.unappliedOutputChanges { - if oc.addition { - err = applyUnappliedOutputAdditions(tx, oc.sco) - } else { - err = applyUnappliedOutputRemovals(tx, oc.oid) - } - if err != nil { - return fmt.Errorf("%w; failed to apply unapplied output change", err) - } - } - for _, tc := range ss.unappliedTxnChanges { - if tc.addition { - err = applyUnappliedTxnAdditions(tx, tc.txn) - } else { - err = applyUnappliedTxnRemovals(tx, tc.txnID) - } - if err != nil { - return fmt.Errorf("%w; failed to apply unapplied txn change", err) - } - } - for fcid, cs := range ss.unappliedContractState { - if err := updateContractState(tx, fcid, cs); err != nil { - return fmt.Errorf("%w; failed to update chain state", err) - } - } - if err := markFailedContracts(tx, ss.chainIndex.Height); err != nil { - return err - } - return updateCCID(tx, ss.ccid, ss.chainIndex) - }) - if err != nil { - return fmt.Errorf("%w; failed to apply updates", err) - } - - ss.unappliedContractState = make(map[types.FileContractID]contractState) - ss.unappliedProofs = make(map[types.FileContractID]uint64) - ss.unappliedRevisions = make(map[types.FileContractID]revisionUpdate) - ss.unappliedHostKeys = make(map[types.PublicKey]struct{}) - ss.unappliedAnnouncements = ss.unappliedAnnouncements[:0] - ss.lastSave = time.Now() - ss.unappliedOutputChanges = nil - ss.unappliedTxnChanges = nil - return nil -} - -func (s *SQLStore) retryTransaction(ctx context.Context, fc func(tx *gorm.DB) error) error { - abortRetry := func(err error) bool { - if err == nil || - errors.Is(err, context.Canceled) || - errors.Is(err, context.DeadlineExceeded) || - errors.Is(err, gorm.ErrRecordNotFound) || - errors.Is(err, api.ErrContractNotFound) || - errors.Is(err, api.ErrObjectNotFound) || - errors.Is(err, api.ErrObjectCorrupted) || - errors.Is(err, api.ErrBucketExists) || - errors.Is(err, api.ErrBucketNotFound) || - errors.Is(err, api.ErrBucketNotEmpty) || - errors.Is(err, api.ErrContractNotFound) || - errors.Is(err, api.ErrMultipartUploadNotFound) || - errors.Is(err, api.ErrObjectExists) || - strings.Contains(err.Error(), "no such table") || - strings.Contains(err.Error(), "Duplicate entry") || - errors.Is(err, api.ErrPartNotFound) || - errors.Is(err, api.ErrSlabNotFound) { - return true - } - return false - } - - var err error - attempts := len(s.retryTransactionIntervals) + 1 - for i := 0; i < attempts; i++ { - // execute the transaction - err = 
s.db.WithContext(ctx).Transaction(fc) - if abortRetry(err) { - return err - } - - // if this was the last attempt, return the error - if i == len(s.retryTransactionIntervals) { - s.logger.Warn(fmt.Sprintf("transaction attempt %d/%d failed, err: %v", i+1, attempts, err)) - return err - } - - // log the failed attempt and sleep before retrying - interval := s.retryTransactionIntervals[i] - s.logger.Warn(fmt.Sprintf("transaction attempt %d/%d failed, retry in %v, err: %v", i+1, attempts, interval, err)) - time.Sleep(interval) - } - return fmt.Errorf("retryTransaction failed: %w", err) -} - -func initConsensusInfo(ctx context.Context, db sql.Database) (ci types.ChainIndex, ccid modules.ConsensusChangeID, err error) { - err = db.Transaction(ctx, func(tx sql.DatabaseTx) error { - ci, ccid, err = tx.InitConsensusInfo(ctx) - return err - }) - return -} - -func (s *SQLStore) ResetConsensusSubscription(ctx context.Context) error { - // reset db - var ci types.ChainIndex - var err error - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { - ci, err = tx.ResetConsensusSubscription(ctx) - return err - }) - if err != nil { - return err - } - // reset in-memory state. - s.persistMu.Lock() - s.chainIndex = ci - s.persistMu.Unlock() - return nil -} diff --git a/stores/sql/chain.go b/stores/sql/chain.go new file mode 100644 index 000000000..c9c991d4e --- /dev/null +++ b/stores/sql/chain.go @@ -0,0 +1,226 @@ +package sql + +import ( + "context" + dsql "database/sql" + "errors" + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/sql" + "go.uber.org/zap" +) + +var contractTables = []string{ + "contracts", + "archived_contracts", +} + +func GetContractState(ctx context.Context, tx sql.Tx, fcid types.FileContractID) (api.ContractState, error) { + var cse ContractState + err := tx. + QueryRow(ctx, + fmt.Sprintf("SELECT state FROM (SELECT state, fcid FROM %s UNION SELECT state, fcid FROM %s) as combined WHERE fcid = ?", + contractTables[0], + contractTables[1]), + FileContractID(fcid), + ). + Scan(&cse) + if errors.Is(err, dsql.ErrNoRows) { + return "", contractNotFoundErr(fcid) + } else if err != nil { + return "", fmt.Errorf("failed to fetch contract state: %w", err) + } + + return api.ContractState(cse.String()), nil + } + +func UpdateChainIndex(ctx context.Context, tx sql.Tx, index types.ChainIndex, l *zap.SugaredLogger) error { + l.Debugw("update chain index", "height", index.Height, "block_id", index.ID) + + if res, err := tx.Exec(ctx, + fmt.Sprintf("UPDATE consensus_infos SET height = ?, block_id = ? WHERE id = %d", sql.ConsensusInfoID), + index.Height, + Hash256(index.ID), + ); err != nil { + return fmt.Errorf("failed to update chain index: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to update chain index: no rows affected") + } + + return nil + } + +func UpdateContract(ctx context.Context, tx sql.Tx, fcid types.FileContractID, revisionHeight, revisionNumber, size uint64, l *zap.SugaredLogger) error { + for _, table := range contractTables { + // fetch the current contract; in SQLite we could use a single query to + // perform the conditional update, however we have to compare the + // revision numbers, which are stored as strings, so we need to fetch the + // current contract info separately + var currRevisionHeight, currSize uint64 + var currRevisionNumber Uint64Str + err := tx.
+ QueryRow(ctx, fmt.Sprintf("SELECT revision_height, revision_number, COALESCE(size, 0) FROM %s WHERE fcid = ?", table), FileContractID(fcid)). + Scan(&currRevisionHeight, &currRevisionNumber, &currSize) + if errors.Is(err, dsql.ErrNoRows) { + continue + } else if err != nil { + return fmt.Errorf("failed to fetch '%s' info for %v: %w", table[:len(table)-1], fcid, err) + } + + // update contract + err = updateContract(ctx, tx, table, fcid, currRevisionHeight, uint64(currRevisionNumber), revisionHeight, revisionNumber, size) + if err != nil { + return fmt.Errorf("failed to update '%s' %v: %w", table[:len(table)-1], fcid, err) + } + + l.Debugw(fmt.Sprintf("update %s, revision number %d -> %d, revision height %d -> %d, size %d -> %d", table[:len(table)-1], currRevisionNumber, revisionNumber, currRevisionHeight, revisionHeight, currSize, size), "fcid", fcid) + return nil + } + + return contractNotFoundErr(fcid) +} + +func UpdateContractProofHeight(ctx context.Context, tx sql.Tx, fcid types.FileContractID, proofHeight uint64, l *zap.SugaredLogger) error { + l.Debugw("update contract proof height", "fcid", fcid, "proof_height", proofHeight) + + for _, table := range contractTables { + ok, err := updateContractProofHeight(ctx, tx, table, fcid, proofHeight) + if err != nil { + return fmt.Errorf("failed to update '%s' %v proof height: %w", table[:len(table)-1], fcid, err) + } else if ok { + break + } + } + + return nil +} + +func UpdateContractState(ctx context.Context, tx sql.Tx, fcid types.FileContractID, state api.ContractState, l *zap.SugaredLogger) error { + l.Debugw("update contract state", "fcid", fcid, "state", state) + + var cs ContractState + if err := cs.LoadString(string(state)); err != nil { + return err + } + + for _, table := range contractTables { + ok, err := updateContractState(ctx, tx, table, fcid, cs) + if err != nil { + return fmt.Errorf("failed to update %s state: %w", table[:len(table)-1], err) + } else if ok { + break + } + } + + return nil +} + +func UpdateFailedContracts(ctx context.Context, tx sql.Tx, blockHeight uint64, l *zap.SugaredLogger) error { + l.Debugw("update failed contracts", "block_height", blockHeight) + + if res, err := tx.Exec(ctx, + "UPDATE contracts SET state = ? WHERE window_end <= ? AND state = ?", + ContractStateFromString(api.ContractStateFailed), + blockHeight, + ContractStateFromString(api.ContractStateActive), + ); err != nil { + return fmt.Errorf("failed to update failed contracts: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n > 0 { + l.Debugw(fmt.Sprintf("marked %d active contracts as failed", n), "window_end", blockHeight) + } + + return nil +} + +func UpdateWalletStateElements(ctx context.Context, tx sql.Tx, elements []types.StateElement) error { + if len(elements) == 0 { + return nil + } + + updateStmt, err := tx.Prepare(ctx, "UPDATE wallet_outputs SET leaf_index = ?, merkle_proof= ? 
WHERE output_id = ?") + if err != nil { + return fmt.Errorf("failed to prepare statement to update state elements: %w", err) + } + defer updateStmt.Close() + + for _, el := range elements { + if _, err := updateStmt.Exec(ctx, el.LeafIndex, MerkleProof{Hashes: el.MerkleProof}, Hash256(el.ID)); err != nil { + return fmt.Errorf("failed to update state element '%v': %w", el.ID, err) + } + } + + return nil +} + +func WalletStateElements(ctx context.Context, tx sql.Tx) ([]types.StateElement, error) { + rows, err := tx.Query(ctx, "SELECT output_id, leaf_index, merkle_proof FROM wallet_outputs") + if err != nil { + return nil, fmt.Errorf("failed to fetch state elements: %w", err) + } + defer rows.Close() + + var elements []types.StateElement + for rows.Next() { + if el, err := scanStateElement(rows); err != nil { + return nil, fmt.Errorf("failed to scan state element: %w", err) + } else { + elements = append(elements, el) + } + } + return elements, nil +} + +func contractNotFoundErr(fcid types.FileContractID) error { + return fmt.Errorf("%w: %v", api.ErrContractNotFound, fcid) +} + +func updateContract(ctx context.Context, tx sql.Tx, table string, fcid types.FileContractID, currRevisionHeight, currRevisionNumber, revisionHeight, revisionNumber, size uint64) (err error) { + if revisionNumber > currRevisionNumber { + _, err = tx.Exec( + ctx, + fmt.Sprintf("UPDATE %s SET revision_height = ?, revision_number = ?, size = ? WHERE fcid = ?", table), + revisionHeight, + fmt.Sprint(revisionNumber), + size, + FileContractID(fcid), + ) + } else if revisionHeight > currRevisionHeight { + _, err = tx.Exec( + ctx, + fmt.Sprintf("UPDATE %s SET revision_height = ? WHERE fcid = ?", table), + revisionHeight, + FileContractID(fcid), + ) + } + return +} + +func updateContractProofHeight(ctx context.Context, tx sql.Tx, table string, fcid types.FileContractID, proofHeight uint64) (bool, error) { + res, err := tx.Exec(ctx, fmt.Sprintf("UPDATE %s SET proof_height = ? WHERE fcid = ?", table), proofHeight, FileContractID(fcid)) + if err != nil { + return false, err + } + n, err := res.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to get rows affected: %w", err) + } + return n == 1, nil +} + +func updateContractState(ctx context.Context, tx sql.Tx, table string, fcid types.FileContractID, cs ContractState) (bool, error) { + res, err := tx.Exec(ctx, fmt.Sprintf("UPDATE %s SET state = ? 
WHERE fcid = ?", table), cs, FileContractID(fcid)) + if err != nil { + return false, err + } + n, err := res.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to get rows affected: %w", err) + } + return n == 1, nil +} diff --git a/stores/sql/consts.go b/stores/sql/consts.go index 340935623..64558343d 100644 --- a/stores/sql/consts.go +++ b/stores/sql/consts.go @@ -16,6 +16,23 @@ const ( contractStateFailed ) +func ContractStateFromString(state string) ContractState { + switch strings.ToLower(state) { + case api.ContractStateInvalid: + return contractStateInvalid + case api.ContractStatePending: + return contractStatePending + case api.ContractStateActive: + return contractStateActive + case api.ContractStateComplete: + return contractStateComplete + case api.ContractStateFailed: + return contractStateFailed + default: + return contractStateInvalid + } +} + func (s *ContractState) LoadString(state string) error { switch strings.ToLower(state) { case api.ContractStateInvalid: diff --git a/stores/sql/database.go b/stores/sql/database.go index 57f94efa4..cc2aab0df 100644 --- a/stores/sql/database.go +++ b/stores/sql/database.go @@ -5,16 +5,31 @@ import ( "io" "time" + rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/api" "go.sia.tech/renterd/object" "go.sia.tech/renterd/webhooks" - "go.sia.tech/siad/modules" ) // The database interfaces define all methods that a SQL database must implement // to be used by the SQLStore. type ( + ChainUpdateTx interface { + ContractState(fcid types.FileContractID) (api.ContractState, error) + UpdateChainIndex(index types.ChainIndex) error + UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error + UpdateContractState(fcid types.FileContractID, state api.ContractState) error + UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error + UpdateFailedContracts(blockHeight uint64) error + UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error + + wallet.UpdateTx + } + Database interface { io.Closer @@ -34,13 +49,16 @@ type ( DatabaseTx interface { // AbortMultipartUpload aborts a multipart upload and deletes it from // the database. - AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error + AbortMultipartUpload(ctx context.Context, bucket, key string, uploadID string) error // Accounts returns all accounts from the db. Accounts(ctx context.Context) ([]api.Account, error) // AddMultipartPart adds a part to an unfinished multipart upload. - AddMultipartPart(ctx context.Context, bucket, path, contractSet, eTag, uploadID string, partNumber int, slices object.SlabSlices) error + AddMultipartPart(ctx context.Context, bucket, key, contractSet, eTag, uploadID string, partNumber int, slices object.SlabSlices) error + + // AddPeer adds a peer to the store. + AddPeer(ctx context.Context, addr string) error // AddWebhook adds a new webhook to the database. If the webhook already // exists, it is updated. @@ -61,6 +79,11 @@ type ( // Autopilots returns all autopilots. Autopilots(ctx context.Context) ([]api.Autopilot, error) + // BanPeer temporarily bans one or more IPs. The addr should either be a + // single IP with port (e.g. 1.2.3.4:5678) or a CIDR subnet (e.g. + // 1.2.3.4/16). 
+ BanPeer(ctx context.Context, addr string, duration time.Duration, reason string) error + // Bucket returns the bucket with the given name. If the bucket doesn't // exist, it returns api.ErrBucketNotFound. Bucket(ctx context.Context, bucket string) (api.Bucket, error) @@ -71,6 +94,10 @@ type ( // duplicates but can contain gaps. CompleteMultipartUpload(ctx context.Context, bucket, key, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (string, error) + // Contract returns the metadata of the contract with the given ID or + // ErrContractNotFound. + Contract(ctx context.Context, id types.FileContractID) (cm api.ContractMetadata, err error) + // ContractRoots returns the roots of the contract with the given ID. ContractRoots(ctx context.Context, fcid types.FileContractID) ([]types.Hash256, error) @@ -78,6 +105,12 @@ type ( // opts argument can be used to filter the result. Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) + // ContractSetID returns the ID of the contract set with the given name. + // NOTE: Our locking strategy requires that the contract set ID is + // unique. So even after a contract set was deleted, the ID must not be + // reused. + ContractSetID(ctx context.Context, contractSet string) (int64, error) + // ContractSets returns the names of all contract sets. ContractSets(ctx context.Context) ([]string, error) @@ -98,12 +131,25 @@ type ( // the bucket already exists, api.ErrBucketExists is returned. CreateBucket(ctx context.Context, bucket string, policy api.BucketPolicy) error + // DeleteBucket deletes a bucket. If the bucket isn't empty, it returns + // api.ErrBucketNotEmpty. If the bucket doesn't exist, it returns + // api.ErrBucketNotFound. + DeleteBucket(ctx context.Context, bucket string) error + // DeleteHostSector deletes all contract sector links that a host has // with the given root incrementing the lost sector count in the // process. If another contract with a different host exists that // contains the root, latest_host is updated to that host. DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) (int, error) + // DeleteObject deletes an object from the database and returns true if + // the requested object was actually deleted. + DeleteObject(ctx context.Context, bucket, key string) (bool, error) + + // DeleteObjects deletes a batch of objects starting with the given + // prefix and returns 'true' if any object was deleted. + DeleteObjects(ctx context.Context, bucket, prefix string, limit int64) (bool, error) + // DeleteSettings deletes the settings with the given key. DeleteSettings(ctx context.Context, key string) error @@ -118,6 +164,9 @@ type ( // that was created. InsertBufferedSlab(ctx context.Context, fileName string, contractSetID int64, ec object.EncryptionKey, minShards, totalShards uint8) (int64, error) + // InsertContract inserts a new contract into the database. + InsertContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) + // InsertMultipartUpload creates a new multipart upload and returns a // unique upload ID. InsertMultipartUpload(ctx context.Context, bucket, path string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (string, error) @@ -126,19 +175,6 @@ type ( // are associated with any of the provided contracts. 
InvalidateSlabHealthByFCID(ctx context.Context, fcids []types.FileContractID, limit int64) (int64, error) - // DeleteBucket deletes a bucket. If the bucket isn't empty, it returns - // api.ErrBucketNotEmpty. If the bucket doesn't exist, it returns - // api.ErrBucketNotFound. - DeleteBucket(ctx context.Context, bucket string) error - - // DeleteObject deletes an object from the database and returns true if - // the requested object was actually deleted. - DeleteObject(ctx context.Context, bucket, key string) (bool, error) - - // DeleteObjects deletes a batch of objects starting with the given - // prefix and returns 'true' if any object was deleted. - DeleteObjects(ctx context.Context, bucket, prefix string, limit int64) (bool, error) - // HostAllowlist returns the list of public keys of hosts on the // allowlist. HostAllowlist(ctx context.Context) ([]types.PublicKey, error) @@ -146,10 +182,6 @@ type ( // HostBlocklist returns the list of host addresses on the blocklist. HostBlocklist(ctx context.Context) ([]string, error) - // InitConsensusInfo initializes the consensus info in the database or - // returns the latest one. - InitConsensusInfo(ctx context.Context) (types.ChainIndex, modules.ConsensusChangeID, error) - // InsertObject inserts a new object into the database. InsertObject(ctx context.Context, bucket, key, contractSet string, dirID int64, o object.Object, mimeType, eTag string, md api.ObjectUserMetadata) error @@ -166,6 +198,11 @@ type ( // MakeDirsForPath creates all directories for a given object's path. MakeDirsForPath(ctx context.Context, path string) (int64, error) + // MarkPackedSlabUploaded marks the packed slab as uploaded in the + // database, causing the provided shards to be associated with the slab. + // The returned string contains the filename of the slab buffer on disk. + MarkPackedSlabUploaded(ctx context.Context, slab api.UploadedPackedSlab) (string, error) + // MultipartUpload returns the multipart upload with the given ID or // api.ErrMultipartUploadNotFound if the upload doesn't exist. MultipartUpload(ctx context.Context, uploadID string) (api.MultipartUpload, error) @@ -177,9 +214,35 @@ type ( // MultipartUploads returns a list of all multipart uploads. MultipartUploads(ctx context.Context, bucket, prefix, keyMarker, uploadIDMarker string, limit int) (api.MultipartListUploadsResponse, error) + // Object returns an object from the database. + Object(ctx context.Context, bucket, key string) (api.Object, error) + + // ObjectEntries queries the database for objects in a given dir. + ObjectEntries(ctx context.Context, bucket, key, prefix, sortBy, sortDir, marker string, offset, limit int) ([]api.ObjectMetadata, bool, error) + + // ObjectMetadata returns an object's metadata. + ObjectMetadata(ctx context.Context, bucket, key string) (api.Object, error) + + // ObjectsBySlabKey returns all objects that contain a reference to the + // slab with the given slabKey. + ObjectsBySlabKey(ctx context.Context, bucket string, slabKey object.EncryptionKey) (metadata []api.ObjectMetadata, err error) + // ObjectsStats returns overall stats about stored objects ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) (api.ObjectsStatsResponse, error) + // PeerBanned returns true if the peer is banned. + PeerBanned(ctx context.Context, addr string) (bool, error) + + // PeerInfo returns the metadata for the specified peer or + // ErrPeerNotFound if the peer wasn't found in the store. 
+ PeerInfo(ctx context.Context, addr string) (syncer.PeerInfo, error) + + // Peers returns the set of known peers. + Peers(ctx context.Context) ([]syncer.PeerInfo, error) + + // ProcessChainUpdate applies the given chain update to the database. + ProcessChainUpdate(ctx context.Context, applyFn func(ChainUpdateTx) error) error + // PruneEmptydirs prunes any directories that are empty. PruneEmptydirs(ctx context.Context) error @@ -187,6 +250,9 @@ type ( // or slab buffer. PruneSlabs(ctx context.Context, limit int64) (int64, error) + // RecordContractSpending records new spending for a contract + RecordContractSpending(ctx context.Context, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error + // RecordHostScans records the results of host scans in the database // such as recording the settings and price table of a host in case of // success and updating the uptime and downtime of a host. @@ -224,13 +290,18 @@ type ( // returned. RenameObjects(ctx context.Context, bucket, prefixOld, prefixNew string, dirID int64, force bool) error + // RenewContract renews the contract in the database. That means the + // contract with the ID of 'renewedFrom' will be moved to the archived + // contracts and the new contract will overwrite the existing one, + // inheriting its sectors. + RenewContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) + // RenewedContract returns the metadata of the contract that was renewed - // fro mthe specified contract or ErrContractNotFound otherwise. + // from the specified contract or ErrContractNotFound otherwise. RenewedContract(ctx context.Context, renewedFrom types.FileContractID) (api.ContractMetadata, error) - // ResetConsenusSubscription resets the consensus subscription in the - // database. - ResetConsensusSubscription(ctx context.Context) (types.ChainIndex, error) + // ResetChainState deletes all chain data in the database. + ResetChainState(ctx context.Context) error // ResetLostSectors resets the lost sector count for the given host. ResetLostSectors(ctx context.Context, hk types.PublicKey) error @@ -242,20 +313,41 @@ type ( // SearchHosts returns a list of hosts that match the provided filters SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) + // SearchObjects returns a list of objects that contain the provided + // substring. + SearchObjects(ctx context.Context, bucket, substring string, offset, limit int) ([]api.ObjectMetadata, error) + // SetUncleanShutdown sets the clean shutdown flag on the accounts to // 'false' and also marks them as requiring a resync. SetUncleanShutdown(ctx context.Context) error + // SetContractSet creates the contract set with the given name and + // associates it with the provided contract IDs. + SetContractSet(ctx context.Context, name string, contractIds []types.FileContractID) error + // Setting returns the setting with the given key from the database. Setting(ctx context.Context, key string) (string, error) // Settings returns all available settings from the database. Settings(ctx context.Context) ([]string, error) + // Slab returns the slab with the given ID or api.ErrSlabNotFound. 
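ProcessChainUpdate above replaces the old consensus.Subscriber flow. A rough sketch of how a chain-manager subscriber might drive it; the `store` and `cau` names are assumptions for illustration, not part of this diff:

```go
// Hedged sketch: apply one coreutils chain.ApplyUpdate atomically.
err := store.ProcessChainUpdate(ctx, func(tx sql.ChainUpdateTx) error {
	// mark contracts whose proof window has passed as failed
	if err := tx.UpdateFailedContracts(cau.State.Index.Height); err != nil {
		return err
	}
	// persist the new tip via the ChainUpdateTx declared above
	return tx.UpdateChainIndex(cau.State.Index)
})
```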
+ Slab(ctx context.Context, key object.EncryptionKey) (object.Slab, error) + // SlabBuffers returns the filenames and associated contract sets of all // slab buffers. SlabBuffers(ctx context.Context) (map[string]string, error) + // Tip returns the sync height. + Tip(ctx context.Context) (types.ChainIndex, error) + + // UnhealthySlabs returns up to 'limit' slabs belonging to the contract + // set 'set' with a health smaller than or equal to 'healthCutoff' + UnhealthySlabs(ctx context.Context, healthCutoff float64, set string, limit int) ([]api.UnhealthySlab, error) + + // UnspentSiacoinElements returns all wallet outputs in the database. + UnspentSiacoinElements(ctx context.Context) ([]types.SiacoinElement, error) + // UpdateAutopilot updates the autopilot with the provided one or // creates a new one if it doesn't exist yet. UpdateAutopilot(ctx context.Context, ap api.Autopilot) error @@ -273,6 +365,9 @@ type ( // UpdateHostCheck updates the host check for the given host. UpdateHostCheck(ctx context.Context, autopilot string, hk types.PublicKey, hc api.HostCheck) error + // UpdatePeerInfo updates the metadata for the specified peer. + UpdatePeerInfo(ctx context.Context, addr string, fn func(*syncer.PeerInfo)) error + // UpdateSetting updates the setting with the given key to the given // value. UpdateSetting(ctx context.Context, key, value string) error @@ -291,6 +386,12 @@ type ( // the health of the updated slabs becomes invalid UpdateSlabHealth(ctx context.Context, limit int64, minValidity, maxValidity time.Duration) (int64, error) + // WalletEvents returns all wallet events in the database. + WalletEvents(ctx context.Context, offset, limit int) ([]wallet.Event, error) + + // WalletEventCount returns the total number of events in the database. + WalletEventCount(ctx context.Context) (uint64, error) + // Webhooks returns all registered webhooks. Webhooks(ctx context.Context) ([]webhooks.Webhook, error) } diff --git a/stores/sql/main.go b/stores/sql/main.go index b5ee9670d..bb03bd86d 100644 --- a/stores/sql/main.go +++ b/stores/sql/main.go @@ -8,6 +8,8 @@ import ( "fmt" "math" "math/big" + "net" + "strconv" "strings" "time" "unicode/utf8" @@ -16,18 +18,52 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/sql" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" "go.sia.tech/renterd/webhooks" - "go.sia.tech/siad/modules" "lukechampine.com/frand" ) -const consensuInfoID = 1 - var ErrNegativeOffset = errors.New("offset can not be negative") +// helper types +type ( + multipartUpload struct { + ID int64 + Key string + Bucket string + BucketID int64 + EC object.EncryptionKey + MimeType string + } + + multipartUploadPart struct { + ID int64 + PartNumber int64 + Etag string + Size int64 + } + + // Tx is an interface that allows for injecting custom methods into helpers + // to avoid duplicating code. + Tx interface { + sql.Tx + + CharLengthExpr() string + + // ScanObjectMetadata scans the object metadata from the given scanner. + // The columns required to scan the metadata are returned by the + // SelectObjectMetadataExpr helper method. Additional fields can be + // selected and scanned by passing them to the method as 'others'. 
+ ScanObjectMetadata(s Scanner, others ...any) (md api.ObjectMetadata, err error) + SelectObjectMetadataExpr() string + } +) + func AbortMultipartUpload(ctx context.Context, tx sql.Tx, bucket, key string, uploadID string) error { res, err := tx.Exec(ctx, ` DELETE @@ -92,8 +128,10 @@ func AncestorContracts(ctx context.Context, tx sql.Tx, fcid types.FileContractID WHERE archived_contracts.renewed_to = ancestors.fcid ) SELECT fcid, host, renewed_to, upload_spending, download_spending, fund_account_spending, delete_spending, - proof_height, revision_height, revision_number, size, start_height, state, window_start, window_end + proof_height, revision_height, revision_number, size, start_height, state, window_start, window_end, + COALESCE(h.net_address, ''), contract_price, renewed_from, total_cost, reason FROM ancestors + LEFT JOIN hosts h ON h.public_key = ancestors.host WHERE start_height >= ? `, FileContractID(fcid), startHeight) if err != nil { @@ -109,7 +147,8 @@ func AncestorContracts(ctx context.Context, tx sql.Tx, fcid types.FileContractID (*Currency)(&c.Spending.Uploads), (*Currency)(&c.Spending.Downloads), (*Currency)(&c.Spending.FundAccount), (*Currency)(&c.Spending.Deletions), &c.ProofHeight, &c.RevisionHeight, &c.RevisionNumber, &c.Size, &c.StartHeight, &state, &c.WindowStart, - &c.WindowEnd) + &c.WindowEnd, &c.HostIP, (*Currency)(&c.ContractPrice), (*FileContractID)(&c.RenewedFrom), + (*Currency)(&c.TotalCost), &c.ArchivalReason) if err != nil { return nil, fmt.Errorf("failed to scan contract: %w", err) } @@ -120,19 +159,7 @@ func AncestorContracts(ctx context.Context, tx sql.Tx, fcid types.FileContractID } func ArchiveContract(ctx context.Context, tx sql.Tx, fcid types.FileContractID, reason string) error { - _, err := tx.Exec(ctx, ` - INSERT INTO archived_contracts (created_at, fcid, renewed_from, contract_price, state, total_cost, - proof_height, revision_height, revision_number, size, start_height, window_start, window_end, - upload_spending, download_spending, fund_account_spending, delete_spending, list_spending, renewed_to, - host, reason) - SELECT ?, fcid, renewed_from, contract_price, state, total_cost, proof_height, revision_height, revision_number, - size, start_height, window_start, window_end, upload_spending, download_spending, fund_account_spending, - delete_spending, list_spending, NULL, h.public_key, ? - FROM contracts c - INNER JOIN hosts h ON h.id = c.host_id - WHERE fcid = ? 
- `, time.Now(), reason, FileContractID(fcid)) - if err != nil { + if err := copyContractToArchive(ctx, tx, fcid, nil, reason); err != nil { return fmt.Errorf("failed to copy contract to archived_contracts: %w", err) } res, err := tx.Exec(ctx, "DELETE FROM contracts WHERE fcid = ?", FileContractID(fcid)) @@ -183,6 +210,16 @@ func Bucket(ctx context.Context, tx sql.Tx, bucket string) (api.Bucket, error) { return b, nil } +func Contract(ctx context.Context, tx sql.Tx, fcid types.FileContractID) (api.ContractMetadata, error) { + contracts, err := QueryContracts(ctx, tx, []string{"c.fcid = ?"}, []any{FileContractID(fcid)}) + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to fetch contract: %w", err) + } else if len(contracts) == 0 { + return api.ContractMetadata{}, api.ErrContractNotFound + } + return contracts[0], nil +} + func ContractRoots(ctx context.Context, tx sql.Tx, fcid types.FileContractID) ([]types.Hash256, error) { rows, err := tx.Query(ctx, ` SELECT s.root @@ -223,6 +260,17 @@ func Contracts(ctx context.Context, tx sql.Tx, opts api.ContractsOpts) ([]api.Co return QueryContracts(ctx, tx, whereExprs, whereArgs) } +func ContractSetID(ctx context.Context, tx sql.Tx, contractSet string) (int64, error) { + var id int64 + err := tx.QueryRow(ctx, "SELECT id FROM contract_sets WHERE name = ?", contractSet).Scan(&id) + if errors.Is(err, dsql.ErrNoRows) { + return 0, api.ErrContractSetNotFound + } else if err != nil { + return 0, fmt.Errorf("failed to fetch contract set id: %w", err) + } + return id, nil +} + func ContractSets(ctx context.Context, tx sql.Tx) ([]string, error) { rows, err := tx.Query(ctx, "SELECT name FROM contract_sets") if err != nil { @@ -241,6 +289,32 @@ func ContractSets(ctx context.Context, tx sql.Tx) ([]string, error) { return sets, nil } +func ContractSize(ctx context.Context, tx sql.Tx, id types.FileContractID) (api.ContractSize, error) { + var contractID, size uint64 + if err := tx.QueryRow(ctx, "SELECT id, size FROM contracts WHERE fcid = ?", FileContractID(id)). + Scan(&contractID, &size); errors.Is(err, dsql.ErrNoRows) { + return api.ContractSize{}, api.ErrContractNotFound + } else if err != nil { + return api.ContractSize{}, err + } + + var nSectors uint64 + if err := tx.QueryRow(ctx, "SELECT COUNT(*) FROM contract_sectors WHERE db_contract_id = ?", contractID). + Scan(&nSectors); err != nil { + return api.ContractSize{}, err + } + sectorsSize := nSectors * rhpv2.SectorSize + + var prunable uint64 + if size > sectorsSize { + prunable = size - sectorsSize + } + return api.ContractSize{ + Size: size, + Prunable: prunable, + }, nil +} + func ContractSizes(ctx context.Context, tx sql.Tx) (map[types.FileContractID]api.ContractSize, error) { // the following query consists of two parts: // 1. 
fetch all contracts that have no sectors and consider their size as @@ -267,6 +341,7 @@ func ContractSizes(ctx context.Context, tx sql.Tx) (map[types.FileContractID]api if err != nil { return nil, fmt.Errorf("failed to fetch contract sizes: %w", err) } + defer rows.Close() sizes := make(map[types.FileContractID]api.ContractSize) for rows.Next() { @@ -368,6 +443,28 @@ func CopyObject(ctx context.Context, tx sql.Tx, srcBucket, dstBucket, srcKey, ds return fetchMetadata(dstObjID) } +func DeleteBucket(ctx context.Context, tx sql.Tx, bucket string) error { + var id int64 + err := tx.QueryRow(ctx, "SELECT id FROM buckets WHERE name = ?", bucket).Scan(&id) + if errors.Is(err, dsql.ErrNoRows) { + return api.ErrBucketNotFound + } else if err != nil { + return fmt.Errorf("failed to fetch bucket id: %w", err) + } + var empty bool + err = tx.QueryRow(ctx, "SELECT NOT EXISTS(SELECT 1 FROM objects WHERE db_bucket_id = ?)", id).Scan(&empty) + if err != nil { + return fmt.Errorf("failed to check if bucket is empty: %w", err) + } else if !empty { + return api.ErrBucketNotEmpty + } + _, err = tx.Exec(ctx, "DELETE FROM buckets WHERE id = ?", id) + if err != nil { + return fmt.Errorf("failed to delete bucket: %w", err) + } + return nil +} + func DeleteHostSector(ctx context.Context, tx sql.Tx, hk types.PublicKey, root types.Hash256) (int, error) { // update the latest_host field of the sector _, err := tx.Exec(ctx, ` @@ -439,6 +536,11 @@ func DeleteHostSector(ctx context.Context, tx sql.Tx, hk types.PublicKey, root t return int(deletedSectors), nil } +func DeleteMetadata(ctx context.Context, tx sql.Tx, objID int64) error { + _, err := tx.Exec(ctx, "DELETE FROM object_user_metadata WHERE db_object_id = ?", objID) + return err +} + func DeleteSettings(ctx context.Context, tx sql.Tx, key string) error { if _, err := tx.Exec(ctx, "DELETE FROM settings WHERE `key` = ?", key); err != nil { return fmt.Errorf("failed to delete setting '%s': %w", key, err) @@ -458,6 +560,65 @@ func DeleteWebhook(ctx context.Context, tx sql.Tx, wh webhooks.Webhook) error { return nil } +func FetchUsedContracts(ctx context.Context, tx sql.Tx, fcids []types.FileContractID) (map[types.FileContractID]UsedContract, error) { + if len(fcids) == 0 { + return make(map[types.FileContractID]UsedContract), nil + } + + // flatten map to get all used contract ids + usedFCIDs := make([]FileContractID, 0, len(fcids)) + for _, fcid := range fcids { + usedFCIDs = append(usedFCIDs, FileContractID(fcid)) + } + + placeholders := make([]string, len(usedFCIDs)) + for i := range usedFCIDs { + placeholders[i] = "?" + } + placeholdersStr := strings.Join(placeholders, ", ") + + args := make([]interface{}, len(usedFCIDs)*2) + for i := range args { + args[i] = usedFCIDs[i%len(usedFCIDs)] + } + + // fetch all contracts, take into account renewals + rows, err := tx.Query(ctx, fmt.Sprintf(`SELECT id, fcid, renewed_from + FROM contracts + WHERE contracts.fcid IN (%s) OR renewed_from IN (%s) + `, placeholdersStr, placeholdersStr), args...) 
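The placeholder expansion just above is the standard pattern for variable-length IN clauses; the same idea as a standalone helper (the function name is ours, purely illustrative):

```go
// inClause builds "?, ?, ..., ?" for n parameters; FetchUsedContracts passes
// the fcid args twice because each id is matched against both IN (...) lists.
func inClause(n int) string {
	return strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
}
```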
+ if err != nil { + return nil, fmt.Errorf("failed to fetch used contracts: %w", err) + } + defer rows.Close() + + var contracts []UsedContract + for rows.Next() { + var c UsedContract + if err := rows.Scan(&c.ID, &c.FCID, &c.RenewedFrom); err != nil { + return nil, fmt.Errorf("failed to scan used contract: %w", err) + } + contracts = append(contracts, c) + } + + fcidMap := make(map[types.FileContractID]struct{}, len(fcids)) + for _, fcid := range fcids { + fcidMap[fcid] = struct{}{} + } + + // build map of used contracts + usedContracts := make(map[types.FileContractID]UsedContract, len(contracts)) + for _, c := range contracts { + if _, used := fcidMap[types.FileContractID(c.FCID)]; used { + usedContracts[types.FileContractID(c.FCID)] = c + } + if _, used := fcidMap[types.FileContractID(c.RenewedFrom)]; used { + usedContracts[types.FileContractID(c.RenewedFrom)] = c + } + } + return usedContracts, nil +} + func HostAllowlist(ctx context.Context, tx sql.Tx) ([]types.PublicKey, error) { rows, err := tx.Query(ctx, "SELECT entry FROM host_allowlist_entries") if err != nil { @@ -531,20 +692,72 @@ func InsertBufferedSlab(ctx context.Context, tx sql.Tx, fileName string, contrac return 0, fmt.Errorf("failed to fetch buffered slab id: %w", err) } - key, err := ec.MarshalBinary() - if err != nil { - return 0, err - } _, err = tx.Exec(ctx, ` INSERT INTO slabs (created_at, db_contract_set_id, db_buffered_slab_id, `+"`key`"+`, min_shards, total_shards) VALUES (?, ?, ?, ?, ?, ?)`, - time.Now(), contractSetID, bufferedSlabID, SecretKey(key), minShards, totalShards) + time.Now(), contractSetID, bufferedSlabID, EncryptionKey(ec), minShards, totalShards) if err != nil { return 0, fmt.Errorf("failed to insert slab: %w", err) } return bufferedSlabID, nil } +func InsertContract(ctx context.Context, tx sql.Tx, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { + var contractState ContractState + if err := contractState.LoadString(state); err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to load contract state: %w", err) + } + var hostID int64 + if err := tx.QueryRow(ctx, "SELECT id FROM hosts WHERE public_key = ?", + PublicKey(rev.HostKey())).Scan(&hostID); err != nil { + return api.ContractMetadata{}, api.ErrHostNotFound + } + + res, err := tx.Exec(ctx, ` + INSERT INTO contracts (created_at, host_id, fcid, renewed_from, contract_price, state, total_cost, proof_height, + revision_height, revision_number, size, start_height, window_start, window_end, upload_spending, download_spending, + fund_account_spending, delete_spending, list_spending) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ `, time.Now(), hostID, FileContractID(rev.ID()), FileContractID(renewedFrom), Currency(contractPrice), + contractState, Currency(totalCost), 0, 0, "0", rev.Revision.Filesize, startHeight, rev.Revision.WindowStart, rev.Revision.WindowEnd, + ZeroCurrency, ZeroCurrency, ZeroCurrency, ZeroCurrency, ZeroCurrency) + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to insert contract: %w", err) + } + cid, err := res.LastInsertId() + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to fetch contract id: %w", err) + } + + contracts, err := QueryContracts(ctx, tx, []string{"c.id = ?"}, []any{cid}) + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to fetch contract: %w", err) + } else if len(contracts) == 0 { + return api.ContractMetadata{}, api.ErrContractNotFound + } + return contracts[0], nil +} + +func InsertMetadata(ctx context.Context, tx sql.Tx, objID, muID *int64, md api.ObjectUserMetadata) error { + if len(md) == 0 { + return nil + } else if (objID == nil) == (muID == nil) { + return errors.New("either objID or muID must be set") + } + insertMetadataStmt, err := tx.Prepare(ctx, "INSERT INTO object_user_metadata (created_at, db_object_id, db_multipart_upload_id, `key`, value) VALUES (?, ?, ?, ?, ?)") + if err != nil { + return fmt.Errorf("failed to prepare statement to insert object metadata: %w", err) + } + defer insertMetadataStmt.Close() + + for k, v := range md { + if _, err := insertMetadataStmt.Exec(ctx, time.Now(), objID, muID, k, v); err != nil { + return fmt.Errorf("failed to insert object metadata: %w", err) + } + } + return nil +} + func InsertMultipartUpload(ctx context.Context, tx sql.Tx, bucket, key string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (string, error) { // fetch bucket id var bucketID int64 @@ -556,12 +769,6 @@ func InsertMultipartUpload(ctx context.Context, tx sql.Tx, bucket, key string, e return "", fmt.Errorf("failed to fetch bucket id: %w", err) } - // marshal key - ecBytes, err := ec.MarshalBinary() - if err != nil { - return "", err - } - // insert multipart upload uploadIDEntropy := frand.Entropy256() uploadID := hex.EncodeToString(uploadIDEntropy[:]) @@ -569,7 +776,7 @@ func InsertMultipartUpload(ctx context.Context, tx sql.Tx, bucket, key string, e res, err := tx.Exec(ctx, ` INSERT INTO multipart_uploads (created_at, `+"`key`"+`, upload_id, object_id, db_bucket_id, mime_type) VALUES (?, ?, ?, ?, ?, ?) 
- `, time.Now(), SecretKey(ecBytes), uploadID, key, bucketID, mimeType) + `, time.Now(), EncryptionKey(ec), uploadID, key, bucketID, mimeType) if err != nil { return "", fmt.Errorf("failed to create multipart upload: %w", err) } else if muID, err = res.LastInsertId(); err != nil { @@ -583,14 +790,14 @@ func InsertMultipartUpload(ctx context.Context, tx sql.Tx, bucket, key string, e return uploadID, nil } -func InsertObject(ctx context.Context, tx sql.Tx, key string, dirID, bucketID, size int64, ec []byte, mimeType, eTag string) (int64, error) { +func InsertObject(ctx context.Context, tx sql.Tx, key string, dirID, bucketID, size int64, ec object.EncryptionKey, mimeType, eTag string) (int64, error) { res, err := tx.Exec(ctx, `INSERT INTO objects (created_at, object_id, db_directory_id, db_bucket_id, `+"`key`"+`, size, mime_type, etag) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, time.Now(), key, dirID, bucketID, - SecretKey(ec), + EncryptionKey(ec), size, mimeType, eTag) @@ -615,11 +822,8 @@ func LoadSlabBuffers(ctx context.Context, db *sql.DB) (bufferedSlabs []LoadedSla for rows.Next() { var bs LoadedSlabBuffer - var sk SecretKey - if err := rows.Scan(&bs.ID, &bs.Filename, &bs.ContractSetID, &sk, &bs.MinShards, &bs.TotalShards); err != nil { + if err := rows.Scan(&bs.ID, &bs.Filename, &bs.ContractSetID, (*EncryptionKey)(&bs.Key), &bs.MinShards, &bs.TotalShards); err != nil { return fmt.Errorf("failed to scan buffered slab: %w", err) - } else if err := bs.Key.UnmarshalBinary(sk[:]); err != nil { - return fmt.Errorf("failed to unmarshal secret key: %w", err) } bufferedSlabs = append(bufferedSlabs, bs) } @@ -676,188 +880,228 @@ func UpdateMetadata(ctx context.Context, tx sql.Tx, objID int64, md api.ObjectUs return nil } -func DeleteMetadata(ctx context.Context, tx sql.Tx, objID int64) error { - _, err := tx.Exec(ctx, "DELETE FROM object_user_metadata WHERE db_object_id = ?", objID) +func PrepareSlabHealth(ctx context.Context, tx sql.Tx, limit int64, now time.Time) error { + _, err := tx.Exec(ctx, "DROP TABLE IF EXISTS slabs_health") + if err != nil { + return fmt.Errorf("failed to drop temporary table: %w", err) + } + _, err = tx.Exec(ctx, ` + CREATE TEMPORARY TABLE slabs_health AS + SELECT slabs.id as id, CASE WHEN (slabs.min_shards = slabs.total_shards) + THEN + CASE WHEN (COUNT(DISTINCT(CASE WHEN cs.name IS NULL THEN NULL ELSE c.host_id END)) < slabs.min_shards) + THEN -1 + ELSE 1 + END + ELSE (CAST(COUNT(DISTINCT(CASE WHEN cs.name IS NULL THEN NULL ELSE c.host_id END)) AS FLOAT) - CAST(slabs.min_shards AS FLOAT)) / Cast(slabs.total_shards - slabs.min_shards AS FLOAT) + END as health + FROM slabs + INNER JOIN sectors s ON s.db_slab_id = slabs.id + LEFT JOIN contract_sectors se ON s.id = se.db_sector_id + LEFT JOIN contracts c ON se.db_contract_id = c.id + LEFT JOIN contract_set_contracts csc ON csc.db_contract_id = c.id AND csc.db_contract_set_id = slabs.db_contract_set_id + LEFT JOIN contract_sets cs ON cs.id = csc.db_contract_set_id + WHERE slabs.health_valid_until <= ? + GROUP BY slabs.id + LIMIT ? 
+ `, now.Unix(), limit) + if err != nil { + return fmt.Errorf("failed to create temporary table: %w", err) + } + if _, err := tx.Exec(ctx, "CREATE INDEX slabs_health_id ON slabs_health (id)"); err != nil { + return fmt.Errorf("failed to create index on temporary table: %w", err) + } return err } -func InsertMetadata(ctx context.Context, tx sql.Tx, objID, muID *int64, md api.ObjectUserMetadata) error { - if len(md) == 0 { - return nil - } else if (objID == nil) == (muID == nil) { - return errors.New("either objID or muID must be set") - } - insertMetadataStmt, err := tx.Prepare(ctx, "INSERT INTO object_user_metadata (created_at, db_object_id, db_multipart_upload_id, `key`, value) VALUES (?, ?, ?, ?, ?)") +func ListBuckets(ctx context.Context, tx sql.Tx) ([]api.Bucket, error) { + rows, err := tx.Query(ctx, "SELECT created_at, name, COALESCE(policy, '{}') FROM buckets") if err != nil { - return fmt.Errorf("failed to prepare statement to insert object metadata: %w", err) + return nil, fmt.Errorf("failed to fetch buckets: %w", err) } - defer insertMetadataStmt.Close() + defer rows.Close() - for k, v := range md { - if _, err := insertMetadataStmt.Exec(ctx, time.Now(), objID, muID, k, v); err != nil { - return fmt.Errorf("failed to insert object metadata: %w", err) + var buckets []api.Bucket + for rows.Next() { + bucket, err := scanBucket(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan bucket: %w", err) } + buckets = append(buckets, bucket) } - return nil + return buckets, nil } -func ContractSize(ctx context.Context, tx sql.Tx, id types.FileContractID) (api.ContractSize, error) { - var contractID, size uint64 - if err := tx.QueryRow(ctx, "SELECT id, size FROM contracts WHERE fcid = ?", FileContractID(id)). - Scan(&contractID, &size); errors.Is(err, dsql.ErrNoRows) { - return api.ContractSize{}, api.ErrContractNotFound - } else if err != nil { - return api.ContractSize{}, err +func whereObjectMarker(marker, sortBy, sortDir string, queryMarker func(dst any, marker, col string) error) (whereExprs []string, whereArgs []any, _ error) { + if marker == "" { + return nil, nil, nil + } else if sortBy == "" || sortDir == "" { + return nil, nil, fmt.Errorf("sortBy and sortDir must be set") } - var nSectors uint64 - if err := tx.QueryRow(ctx, "SELECT COUNT(*) FROM contract_sectors WHERE db_contract_id = ?", contractID). - Scan(&nSectors); err != nil { - return api.ContractSize{}, err + desc := strings.ToLower(sortDir) == api.ObjectSortDirDesc + switch strings.ToLower(sortBy) { + case api.ObjectSortByName: + if desc { + whereExprs = append(whereExprs, "o.object_id < ?") + } else { + whereExprs = append(whereExprs, "o.object_id > ?") + } + whereArgs = append(whereArgs, marker) + case api.ObjectSortByHealth: + var markerHealth float64 + if err := queryMarker(&markerHealth, marker, "health"); err != nil { + return nil, nil, fmt.Errorf("failed to fetch health marker: %w", err) + } else if desc { + whereExprs = append(whereExprs, "((o.health <= ? AND o.object_id > ?) OR o.health < ?)") + whereArgs = append(whereArgs, markerHealth, marker, markerHealth) + } else { + whereExprs = append(whereExprs, "(o.health > ? OR (o.health >= ? AND object_id > ?))") + whereArgs = append(whereArgs, markerHealth, markerHealth, marker) + } + case api.ObjectSortBySize: + var markerSize int64 + if err := queryMarker(&markerSize, marker, "size"); err != nil { + return nil, nil, fmt.Errorf("failed to fetch size marker: %w", err) + } else if desc { + whereExprs = append(whereExprs, "((o.size <= ? AND o.object_id > ?) 
OR o.size < ?)") + whereArgs = append(whereArgs, markerSize, marker, markerSize) + } else { + whereExprs = append(whereExprs, "(o.size > ? OR (o.size >= ? AND object_id > ?))") + whereArgs = append(whereArgs, markerSize, markerSize, marker) + } + default: + return nil, nil, fmt.Errorf("invalid marker: %v", marker) } - sectorsSize := nSectors * rhpv2.SectorSize + return whereExprs, whereArgs, nil +} - var prunable uint64 - if size > sectorsSize { - prunable = size - sectorsSize +func orderByObject(sortBy, sortDir string) (orderByExprs []string, _ error) { + if sortBy == "" || sortDir == "" { + return nil, fmt.Errorf("sortBy and sortDir must be set") } - return api.ContractSize{ - Size: size, - Prunable: prunable, - }, nil -} -func DeleteBucket(ctx context.Context, tx sql.Tx, bucket string) error { - var id int64 - err := tx.QueryRow(ctx, "SELECT id FROM buckets WHERE name = ?", bucket).Scan(&id) - if errors.Is(err, dsql.ErrNoRows) { - return api.ErrBucketNotFound - } else if err != nil { - return fmt.Errorf("failed to fetch bucket id: %w", err) + dir2SQL := map[string]string{ + api.ObjectSortDirAsc: "ASC", + api.ObjectSortDirDesc: "DESC", } - var empty bool - err = tx.QueryRow(ctx, "SELECT NOT EXISTS(SELECT 1 FROM objects WHERE db_bucket_id = ?)", id).Scan(&empty) - if err != nil { - return fmt.Errorf("failed to check if bucket is empty: %w", err) - } else if !empty { - return api.ErrBucketNotEmpty + if _, ok := dir2SQL[strings.ToLower(sortDir)]; !ok { + return nil, fmt.Errorf("invalid sortDir: %v", sortDir) } - _, err = tx.Exec(ctx, "DELETE FROM buckets WHERE id = ?", id) - if err != nil { - return fmt.Errorf("failed to delete bucket: %w", err) + switch strings.ToLower(sortBy) { + case "", api.ObjectSortByName: + orderByExprs = append(orderByExprs, "o.object_id "+dir2SQL[strings.ToLower(sortDir)]) + case api.ObjectSortByHealth: + orderByExprs = append(orderByExprs, "o.health "+dir2SQL[strings.ToLower(sortDir)]) + case api.ObjectSortBySize: + orderByExprs = append(orderByExprs, "o.size "+dir2SQL[strings.ToLower(sortDir)]) + default: + return nil, fmt.Errorf("invalid sortBy: %v", sortBy) } - return nil -} -func FetchUsedContracts(ctx context.Context, tx sql.Tx, fcids []types.FileContractID) (map[types.FileContractID]UsedContract, error) { - if len(fcids) == 0 { - return make(map[types.FileContractID]UsedContract), nil + // always sort by object_id as well if we aren't explicitly + if sortBy != api.ObjectSortByName { + orderByExprs = append(orderByExprs, "o.object_id ASC") } + return orderByExprs, nil +} - // flatten map to get all used contract ids - usedFCIDs := make([]FileContractID, 0, len(fcids)) - for _, fcid := range fcids { - usedFCIDs = append(usedFCIDs, FileContractID(fcid)) +func ListObjects(ctx context.Context, tx Tx, bucket, prefix, sortBy, sortDir, marker string, limit int) (api.ObjectsListResponse, error) { + // fetch one more to see if there are more entries + if limit <= -1 { + limit = math.MaxInt + } else if limit != math.MaxInt { + limit++ } - placeholders := make([]string, len(usedFCIDs)) - for i := range usedFCIDs { - placeholders[i] = "?" 
+ // establish sane defaults for sorting + if sortBy == "" { + sortBy = api.ObjectSortByName } - placeholdersStr := strings.Join(placeholders, ", ") - - args := make([]interface{}, len(usedFCIDs)*2) - for i := range args { - args[i] = usedFCIDs[i%len(usedFCIDs)] + if sortDir == "" { + sortDir = api.ObjectSortDirAsc } - // fetch all contracts, take into account renewals - rows, err := tx.Query(ctx, fmt.Sprintf(`SELECT id, fcid, renewed_from - FROM contracts - WHERE contracts.fcid IN (%s) OR renewed_from IN (%s) - `, placeholdersStr, placeholdersStr), args...) - if err != nil { - return nil, fmt.Errorf("failed to fetch used contracts: %w", err) - } - defer rows.Close() + // filter by bucket + whereExprs := []string{"o.db_bucket_id = (SELECT id FROM buckets b WHERE b.name = ?)"} + whereArgs := []any{bucket} - var contracts []UsedContract - for rows.Next() { - var c UsedContract - if err := rows.Scan(&c.ID, &c.FCID, &c.RenewedFrom); err != nil { - return nil, fmt.Errorf("failed to scan used contract: %w", err) - } - contracts = append(contracts, c) + // apply prefix + if prefix != "" { + whereExprs = append(whereExprs, "o.object_id LIKE ? AND SUBSTR(o.object_id, 1, ?) = ?") + whereArgs = append(whereArgs, prefix+"%", utf8.RuneCountInString(prefix), prefix) } - fcidMap := make(map[types.FileContractID]struct{}, len(fcids)) - for _, fcid := range fcids { - fcidMap[fcid] = struct{}{} + // apply sorting + orderByExprs, err := orderByObject(sortBy, sortDir) + if err != nil { + return api.ObjectsListResponse{}, fmt.Errorf("failed to apply sorting: %w", err) } - // build map of used contracts - usedContracts := make(map[types.FileContractID]UsedContract, len(contracts)) - for _, c := range contracts { - if _, used := fcidMap[types.FileContractID(c.FCID)]; used { - usedContracts[types.FileContractID(c.FCID)] = c - } - if _, used := fcidMap[types.FileContractID(c.RenewedFrom)]; used { - usedContracts[types.FileContractID(c.RenewedFrom)] = c + // apply marker + markerExprs, markerArgs, err := whereObjectMarker(marker, sortBy, sortDir, func(dst any, marker, col string) error { + err := tx.QueryRow(ctx, fmt.Sprintf(` + SELECT o.%s + FROM objects o + INNER JOIN buckets b ON o.db_bucket_id = b.id + WHERE b.name = ? AND o.object_id = ? + `, col), bucket, marker).Scan(dst) + if errors.Is(err, dsql.ErrNoRows) { + return api.ErrMarkerNotFound + } else { + return err } - } - return usedContracts, nil -} - -func PrepareSlabHealth(ctx context.Context, tx sql.Tx, limit int64, now time.Time) error { - _, err := tx.Exec(ctx, "DROP TABLE IF EXISTS slabs_health") - if err != nil { - return fmt.Errorf("failed to drop temporary table: %w", err) - } - _, err = tx.Exec(ctx, ` - CREATE TEMPORARY TABLE slabs_health AS - SELECT slabs.id as id, CASE WHEN (slabs.min_shards = slabs.total_shards) - THEN - CASE WHEN (COUNT(DISTINCT(CASE WHEN cs.name IS NULL THEN NULL ELSE c.host_id END)) < slabs.min_shards) - THEN -1 - ELSE 1 - END - ELSE (CAST(COUNT(DISTINCT(CASE WHEN cs.name IS NULL THEN NULL ELSE c.host_id END)) AS FLOAT) - CAST(slabs.min_shards AS FLOAT)) / Cast(slabs.total_shards - slabs.min_shards AS FLOAT) - END as health - FROM slabs - INNER JOIN sectors s ON s.db_slab_id = slabs.id - LEFT JOIN contract_sectors se ON s.id = se.db_sector_id - LEFT JOIN contracts c ON se.db_contract_id = c.id - LEFT JOIN contract_set_contracts csc ON csc.db_contract_id = c.id AND csc.db_contract_set_id = slabs.db_contract_set_id - LEFT JOIN contract_sets cs ON cs.id = csc.db_contract_set_id - WHERE slabs.health_valid_until <= ? 
- GROUP BY slabs.id - LIMIT ? - `, now.Unix(), limit) + }) if err != nil { - return fmt.Errorf("failed to create temporary table: %w", err) - } - if _, err := tx.Exec(ctx, "CREATE INDEX slabs_health_id ON slabs_health (id)"); err != nil { - return fmt.Errorf("failed to create index on temporary table: %w", err) + return api.ObjectsListResponse{}, fmt.Errorf("failed to get marker exprs: %w", err) } - return err -} + whereExprs = append(whereExprs, markerExprs...) + whereArgs = append(whereArgs, markerArgs...) -func ListBuckets(ctx context.Context, tx sql.Tx) ([]api.Bucket, error) { - rows, err := tx.Query(ctx, "SELECT created_at, name, COALESCE(policy, '{}') FROM buckets") + // apply limit + whereArgs = append(whereArgs, limit) + + // run query + rows, err := tx.Query(ctx, fmt.Sprintf(` + SELECT %s + FROM objects o + WHERE %s + ORDER BY %s + LIMIT ? + `, + tx.SelectObjectMetadataExpr(), + strings.Join(whereExprs, " AND "), + strings.Join(orderByExprs, ", ")), + whereArgs...) if err != nil { - return nil, fmt.Errorf("failed to fetch buckets: %w", err) + return api.ObjectsListResponse{}, fmt.Errorf("failed to fetch objects: %w", err) } defer rows.Close() - var buckets []api.Bucket + var objects []api.ObjectMetadata for rows.Next() { - bucket, err := scanBucket(rows) + om, err := tx.ScanObjectMetadata(rows) if err != nil { - return nil, fmt.Errorf("failed to scan bucket: %w", err) + return api.ObjectsListResponse{}, fmt.Errorf("failed to scan object metadata: %w", err) } - buckets = append(buckets, bucket) + objects = append(objects, om) } - return buckets, nil + + var hasMore bool + var nextMarker string + if len(objects) == limit { + objects = objects[:len(objects)-1] + if len(objects) > 0 { + hasMore = true + nextMarker = objects[len(objects)-1].Name + } + } + + return api.ObjectsListResponse{ + HasMore: hasMore, + NextMarker: nextMarker, + Objects: objects, + }, nil } func MultipartUpload(ctx context.Context, tx sql.Tx, uploadID string) (api.MultipartUpload, error) { @@ -878,7 +1122,7 @@ func MultipartUploadParts(ctx context.Context, tx sql.Tx, bucket, key, uploadID rows, err := tx.Query(ctx, fmt.Sprintf(` SELECT mp.part_number, mp.created_at, mp.etag, mp.size FROM multipart_parts mp - INNER JOIN multipart_uploads mus ON mus.id = mp.db_multipart_upload_id + INNER JOIN multipart_uploads mus ON mus.id = mp.db_multipart_upload_id INNER JOIN buckets b ON b.id = mus.db_bucket_id WHERE mus.object_id = ? AND b.name = ? AND mus.upload_id = ? AND part_number > ? ORDER BY part_number ASC @@ -981,36 +1225,23 @@ func MultipartUploads(ctx context.Context, tx sql.Tx, bucket, prefix, keyMarker, }, nil } -type multipartUpload struct { - ID int64 - Key string - Bucket string - BucketID int64 - EC []byte - MimeType string -} - -type multipartUploadPart struct { - ID int64 - PartNumber int64 - Etag string - Size int64 -} - func MultipartUploadForCompletion(ctx context.Context, tx sql.Tx, bucket, key, uploadID string, parts []api.MultipartCompletedPart) (multipartUpload, []multipartUploadPart, int64, string, error) { // fetch upload + var ec []byte var mpu multipartUpload err := tx.QueryRow(ctx, ` SELECT mu.id, mu.object_id, mu.mime_type, mu.key, b.name, b.id FROM multipart_uploads mu INNER JOIN buckets b ON b.id = mu.db_bucket_id WHERE mu.upload_id = ?`, uploadID). 
- Scan(&mpu.ID, &mpu.Key, &mpu.MimeType, &mpu.EC, &mpu.Bucket, &mpu.BucketID) + Scan(&mpu.ID, &mpu.Key, &mpu.MimeType, &ec, &mpu.Bucket, &mpu.BucketID) if err != nil { return multipartUpload{}, nil, 0, "", fmt.Errorf("failed to fetch upload: %w", err) } else if mpu.Key != key { return multipartUpload{}, nil, 0, "", fmt.Errorf("object id mismatch: %v != %v: %w", mpu.Key, key, api.ErrObjectNotFound) } else if mpu.Bucket != bucket { return multipartUpload{}, nil, 0, "", fmt.Errorf("bucket name mismatch: %v != %v: %w", mpu.Bucket, bucket, api.ErrBucketNotFound) + } else if err := mpu.EC.UnmarshalBinary(ec); err != nil { + return multipartUpload{}, nil, 0, "", fmt.Errorf("failed to unmarshal encryption key: %w", err) } // find relevant parts @@ -1065,61 +1296,325 @@ func MultipartUploadForCompletion(ctx context.Context, tx sql.Tx, bucket, key, u return mpu, neededParts, size, eTag, nil } -func ObjectsStats(ctx context.Context, tx sql.Tx, opts api.ObjectsStatsOpts) (api.ObjectsStatsResponse, error) { - var args []any - var bucketExpr string - var bucketID int64 - if opts.Bucket != "" { - err := tx.QueryRow(ctx, "SELECT id FROM buckets WHERE name = ?", opts.Bucket). - Scan(&bucketID) - if errors.Is(err, dsql.ErrNoRows) { - return api.ObjectsStatsResponse{}, api.ErrBucketNotFound - } else if err != nil { - return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch bucket id: %w", err) +func NormalizePeer(peer string) (string, error) { + host, _, err := net.SplitHostPort(peer) + if err != nil { + host = peer + } + if strings.IndexByte(host, '/') != -1 { + _, subnet, err := net.ParseCIDR(host) + if err != nil { + return "", fmt.Errorf("failed to parse CIDR: %w", err) } - bucketExpr = "WHERE db_bucket_id = ?" - args = append(args, bucketID) + return subnet.String(), nil } - // objects stats - var numObjects, totalObjectsSize uint64 - var minHealth float64 - err := tx.QueryRow(ctx, "SELECT COUNT(*), COALESCE(MIN(health), 1), COALESCE(SUM(size), 0) FROM objects "+bucketExpr, args...). - Scan(&numObjects, &minHealth, &totalObjectsSize) - if err != nil { - return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch objects stats: %w", err) + ip := net.ParseIP(host) + if ip == nil { + return "", errors.New("invalid IP address") } - // multipart upload stats - var unfinishedObjects uint64 - err = tx.QueryRow(ctx, "SELECT COUNT(*) FROM multipart_uploads "+bucketExpr, args...). - Scan(&unfinishedObjects) - if err != nil { - return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch multipart upload stats: %w", err) + var maskLen int + if ip.To4() != nil { + maskLen = 32 + } else { + maskLen = 128 } - // multipart upload part stats - var totalUnfinishedObjectsSize uint64 - err = tx.QueryRow(ctx, "SELECT COALESCE(SUM(size), 0) FROM multipart_parts mp INNER JOIN multipart_uploads mu ON mp.db_multipart_upload_id = mu.id "+bucketExpr, args...). - Scan(&totalUnfinishedObjectsSize) + _, normalized, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), maskLen)) if err != nil { - return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch multipart upload part stats: %w", err) + panic("failed to parse CIDR") } + return normalized.String(), nil +} - // total sectors - var whereExpr string - var whereArgs []any - if opts.Bucket != "" { - whereExpr = ` - AND EXISTS ( - SELECT 1 FROM slices sli - INNER JOIN objects o ON o.id = sli.db_object_id AND o.db_bucket_id = ? 
- WHERE sli.db_slab_id = sla.id - ) - ` - whereArgs = append(whereArgs, bucketID) +func dirID(ctx context.Context, tx sql.Tx, dirPath string) (int64, error) { + if !strings.HasPrefix(dirPath, "/") { + return 0, fmt.Errorf("path must start with /") + } else if !strings.HasSuffix(dirPath, "/") { + return 0, fmt.Errorf("path must end with /") } - var totalSectors uint64 + + if dirPath == "/" { + return 1, nil // root dir returned + } + + var id int64 + if err := tx.QueryRow(ctx, "SELECT id FROM directories WHERE name = ?", dirPath).Scan(&id); err != nil { + return 0, fmt.Errorf("failed to fetch directory: %w", err) + } + return id, nil +} + +func ObjectEntries(ctx context.Context, tx Tx, bucket, path, prefix, sortBy, sortDir, marker string, offset, limit int) ([]api.ObjectMetadata, bool, error) { + // sanity check we are passing a directory + if !strings.HasSuffix(path, "/") { + panic("path must end in /") + } + + // sanity check we are passing sane paging parameters + usingMarker := marker != "" + usingOffset := offset > 0 + if usingMarker && usingOffset { + return nil, false, errors.New("fetching entries using a marker and an offset is not supported at the same time") + } + + // fetch one more to see if there are more entries + if limit <= -1 { + limit = math.MaxInt + } else if limit != math.MaxInt { + limit++ + } + + // establish sane defaults for sorting + if sortBy == "" { + sortBy = api.ObjectSortByName + } + if sortDir == "" { + sortDir = api.ObjectSortDirAsc + } + + // fetch directory id + dirID, err := dirID(ctx, tx, path) + if errors.Is(err, dsql.ErrNoRows) { + return []api.ObjectMetadata{}, false, nil + } else if err != nil { + return nil, false, fmt.Errorf("failed to fetch directory id: %w", err) + } + + args := []any{ + path, + dirID, bucket, + } + + // apply prefix + var prefixExpr string + if prefix != "" { + prefixExpr = "AND SUBSTR(o.object_id, 1, ?) = ?" + args = append(args, + utf8.RuneCountInString(path+prefix), path+prefix, + utf8.RuneCountInString(path+prefix), path+prefix, + ) + } + + args = append(args, + bucket, + path+"%", + utf8.RuneCountInString(path), path, + dirID, + ) + + // apply marker + var whereExpr string + markerExprs, markerArgs, err := whereObjectMarker(marker, sortBy, sortDir, func(dst any, marker, col string) error { + var groupFn string + switch col { + case "size": + groupFn = "SUM" + case "health": + groupFn = "MIN" + default: + return fmt.Errorf("unknown column: %v", col) + } + err := tx.QueryRow(ctx, fmt.Sprintf(` + SELECT o.%s + FROM objects o + INNER JOIN buckets b ON o.db_bucket_id = b.id + WHERE b.name = ? AND o.object_id = ? + UNION ALL + SELECT %s(o.%s) + FROM objects o + INNER JOIN buckets b ON o.db_bucket_id = b.id + INNER JOIN directories d ON SUBSTR(o.object_id, 1, %s(d.name)) = d.name + WHERE b.name = ? AND d.name = ? + GROUP BY d.id + `, col, groupFn, col, tx.CharLengthExpr()), bucket, marker, bucket, marker).Scan(dst) + if errors.Is(err, dsql.ErrNoRows) { + return api.ErrMarkerNotFound + } else { + return err + } + }) + if err != nil { + return nil, false, fmt.Errorf("failed to query marker: %w", err) + } else if len(markerExprs) > 0 { + whereExpr = "WHERE " + strings.Join(markerExprs, " AND ") + } + args = append(args, markerArgs...) + + // apply sorting + orderByExprs, err := orderByObject(sortBy, sortDir) + if err != nil { + return nil, false, fmt.Errorf("failed to apply sorting: %w", err) + } + + // apply offset and limit + args = append(args, limit, offset) + + // objectsQuery consists of 2 parts + // 1. 
fetch all objects in requested directory + // 2. fetch all sub-directories + rows, err := tx.Query(ctx, fmt.Sprintf(` + SELECT %s + FROM ( + SELECT o.object_id, o.size, o.health, o.mime_type, o.created_at, o.etag + FROM objects o + LEFT JOIN directories d ON d.name = o.object_id + WHERE o.object_id != ? AND o.db_directory_id = ? AND o.db_bucket_id = (SELECT id FROM buckets b WHERE b.name = ?) %s + AND d.id IS NULL + UNION ALL + SELECT d.name as object_id, SUM(o.size), MIN(o.health), '' as mime_type, MAX(o.created_at) as created_at, '' as etag + FROM objects o + INNER JOIN directories d ON SUBSTR(o.object_id, 1, %s(d.name)) = d.name %s + WHERE o.db_bucket_id = (SELECT id FROM buckets b WHERE b.name = ?) + AND o.object_id LIKE ? + AND SUBSTR(o.object_id, 1, ?) = ? + AND d.db_parent_id = ? + GROUP BY d.id + ) AS o + %s + ORDER BY %s + LIMIT ? OFFSET ? + `, + tx.SelectObjectMetadataExpr(), + prefixExpr, + tx.CharLengthExpr(), + prefixExpr, + whereExpr, + strings.Join(orderByExprs, ", "), + ), args...) + if err != nil { + return nil, false, fmt.Errorf("failed to fetch objects: %w", err) + } + defer rows.Close() + + var objects []api.ObjectMetadata + for rows.Next() { + om, err := tx.ScanObjectMetadata(rows) + if err != nil { + return nil, false, fmt.Errorf("failed to scan object metadata: %w", err) + } + objects = append(objects, om) + } + + // trim last element if we have more + var hasMore bool + if len(objects) == limit { + hasMore = true + objects = objects[:len(objects)-1] + } + + return objects, hasMore, nil +} + +func ObjectMetadata(ctx context.Context, tx Tx, bucket, key string) (api.Object, error) { + // fetch object id + var objID int64 + if err := tx.QueryRow(ctx, ` + SELECT o.id + FROM objects o + INNER JOIN buckets b ON b.id = o.db_bucket_id + WHERE o.object_id = ? AND b.name = ? + `, key, bucket).Scan(&objID); errors.Is(err, dsql.ErrNoRows) { + return api.Object{}, api.ErrObjectNotFound + } else if err != nil { + return api.Object{}, fmt.Errorf("failed to fetch object id: %w", err) + } + + // fetch metadata + om, err := tx.ScanObjectMetadata(tx.QueryRow(ctx, fmt.Sprintf(` + SELECT %s + FROM objects o + WHERE o.id = ? + `, tx.SelectObjectMetadataExpr()), objID)) + if err != nil { + return api.Object{}, fmt.Errorf("failed to fetch object metadata: %w", err) + } + + // fetch user metadata + rows, err := tx.Query(ctx, ` + SELECT oum.key, oum.value + FROM object_user_metadata oum + WHERE oum.db_object_id = ? + ORDER BY oum.id ASC + `, objID) + if err != nil { + return api.Object{}, fmt.Errorf("failed to fetch user metadata: %w", err) + } + defer rows.Close() + + // build object + metadata := make(api.ObjectUserMetadata) + for rows.Next() { + var key, value string + if err := rows.Scan(&key, &value); err != nil { + return api.Object{}, fmt.Errorf("failed to scan user metadata: %w", err) + } + metadata[key] = value + } + + return api.Object{ + Metadata: metadata, + ObjectMetadata: om, + Object: nil, // only return metadata + }, nil +} + +func ObjectsStats(ctx context.Context, tx sql.Tx, opts api.ObjectsStatsOpts) (api.ObjectsStatsResponse, error) { + var args []any + var bucketExpr string + var bucketID int64 + if opts.Bucket != "" { + err := tx.QueryRow(ctx, "SELECT id FROM buckets WHERE name = ?", opts.Bucket). + Scan(&bucketID) + if errors.Is(err, dsql.ErrNoRows) { + return api.ObjectsStatsResponse{}, api.ErrBucketNotFound + } else if err != nil { + return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch bucket id: %w", err) + } + bucketExpr = "WHERE db_bucket_id = ?" 
+ args = append(args, bucketID) + } + + // objects stats + var numObjects, totalObjectsSize uint64 + var minHealth float64 + err := tx.QueryRow(ctx, "SELECT COUNT(*), COALESCE(MIN(health), 1), COALESCE(SUM(size), 0) FROM objects "+bucketExpr, args...). + Scan(&numObjects, &minHealth, &totalObjectsSize) + if err != nil { + return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch objects stats: %w", err) + } + + // multipart upload stats + var unfinishedObjects uint64 + err = tx.QueryRow(ctx, "SELECT COUNT(*) FROM multipart_uploads "+bucketExpr, args...). + Scan(&unfinishedObjects) + if err != nil { + return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch multipart upload stats: %w", err) + } + + // multipart upload part stats + var totalUnfinishedObjectsSize uint64 + err = tx.QueryRow(ctx, "SELECT COALESCE(SUM(size), 0) FROM multipart_parts mp INNER JOIN multipart_uploads mu ON mp.db_multipart_upload_id = mu.id "+bucketExpr, args...). + Scan(&totalUnfinishedObjectsSize) + if err != nil { + return api.ObjectsStatsResponse{}, fmt.Errorf("failed to fetch multipart upload part stats: %w", err) + } + + // total sectors + var whereExpr string + var whereArgs []any + if opts.Bucket != "" { + whereExpr = ` + AND EXISTS ( + SELECT 1 FROM slices sli + INNER JOIN objects o ON o.id = sli.db_object_id AND o.db_bucket_id = ? + WHERE sli.db_slab_id = sla.id + ) + ` + whereArgs = append(whereArgs, bucketID) + } + var totalSectors uint64 err = tx.QueryRow(ctx, "SELECT COALESCE(SUM(total_shards), 0) FROM slabs sla WHERE db_buffered_slab_id IS NULL "+whereExpr, whereArgs...). Scan(&totalSectors) if err != nil { @@ -1144,6 +1639,78 @@ func ObjectsStats(ctx context.Context, tx sql.Tx, opts api.ObjectsStatsOpts) (ap }, nil } +func PeerBanned(ctx context.Context, tx sql.Tx, addr string) (bool, error) { + // normalize the address to a CIDR + netCIDR, err := NormalizePeer(addr) + if err != nil { + return false, err + } + + // parse the subnet + _, subnet, err := net.ParseCIDR(netCIDR) + if err != nil { + return false, err + } + + // check all subnets from the given subnet to the max subnet length + var maxMaskLen int + if subnet.IP.To4() != nil { + maxMaskLen = 32 + } else { + maxMaskLen = 128 + } + + checkSubnets := make([]any, 0, maxMaskLen) + for i := maxMaskLen; i > 0; i-- { + _, subnet, err := net.ParseCIDR(subnet.IP.String() + "/" + strconv.Itoa(i)) + if err != nil { + return false, err + } + checkSubnets = append(checkSubnets, subnet.String()) + } + + var expiration time.Time + err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT expiration FROM syncer_bans WHERE net_cidr IN (%s) ORDER BY expiration DESC LIMIT 1`, strings.Repeat("?, ", len(checkSubnets)-1)+"?"), checkSubnets...). + Scan((*UnixTimeMS)(&expiration)) + if errors.Is(err, dsql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + + return time.Now().Before(expiration), nil +} + +func PeerInfo(ctx context.Context, tx sql.Tx, addr string) (syncer.PeerInfo, error) { + var peer syncer.PeerInfo + err := tx.QueryRow(ctx, "SELECT address, first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers WHERE address = ?", addr). 
+ Scan(&peer.Address, (*UnixTimeMS)(&peer.FirstSeen), (*UnixTimeMS)(&peer.LastConnect), (*Unsigned64)(&peer.SyncedBlocks), &peer.SyncDuration) + if errors.Is(err, dsql.ErrNoRows) { + return syncer.PeerInfo{}, syncer.ErrPeerNotFound + } else if err != nil { + return syncer.PeerInfo{}, fmt.Errorf("failed to fetch peer: %w", err) + } + return peer, nil +} + +func Peers(ctx context.Context, tx sql.Tx) ([]syncer.PeerInfo, error) { + rows, err := tx.Query(ctx, "SELECT address, first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers") + if err != nil { + return nil, fmt.Errorf("failed to fetch peers: %w", err) + } + defer rows.Close() + + var peers []syncer.PeerInfo + for rows.Next() { + var peer syncer.PeerInfo + if err := rows.Scan(&peer.Address, (*UnixTimeMS)(&peer.FirstSeen), (*UnixTimeMS)(&peer.LastConnect), (*Unsigned64)(&peer.SyncedBlocks), &peer.SyncDuration); err != nil { + return nil, fmt.Errorf("failed to scan peer: %w", err) + } + peers = append(peers, peer) + } + return peers, nil +} + func RecordHostScans(ctx context.Context, tx sql.Tx, scans []api.HostScan) error { if len(scans) == 0 { return nil @@ -1167,7 +1734,7 @@ func RecordHostScans(ctx context.Context, tx sql.Tx, scans []api.HostScan) error price_table_expiry = CASE WHEN ? AND (price_table_expiry IS NULL OR ? > price_table_expiry) THEN ? ELSE price_table_expiry END, successful_interactions = CASE WHEN ? THEN successful_interactions + 1 ELSE successful_interactions END, failed_interactions = CASE WHEN ? THEN failed_interactions + 1 ELSE failed_interactions END, - subnets = CASE WHEN ? THEN ? ELSE subnets END + resolved_addresses = CASE WHEN ? THEN ? ELSE resolved_addresses END WHERE public_key = ? `) if err != nil { @@ -1191,7 +1758,7 @@ func RecordHostScans(ctx context.Context, tx sql.Tx, scans []api.HostScan) error scan.Success, now, now, // price_table_expiry scan.Success, // successful_interactions !scan.Success, // failed_interactions - len(scan.Subnets) > 0, strings.Join(scan.Subnets, ","), + len(scan.ResolvedAddresses) > 0, strings.Join(scan.ResolvedAddresses, ","), PublicKey(scan.HostKey), ) if err != nil { @@ -1284,24 +1851,46 @@ func RemoveOfflineHosts(ctx context.Context, tx sql.Tx, minRecentFailures uint64 return res.RowsAffected() } -func InitConsensusInfo(ctx context.Context, tx sql.Tx) (types.ChainIndex, modules.ConsensusChangeID, error) { - // try fetch existing - var ccid modules.ConsensusChangeID - var ci types.ChainIndex - err := tx.QueryRow(ctx, "SELECT cc_id, height, block_id FROM consensus_infos WHERE id = ?", consensuInfoID). 
- Scan((*CCID)(&ccid), &ci.Height, (*Hash256)(&ci.ID)) - if err != nil && !errors.Is(err, dsql.ErrNoRows) { - return types.ChainIndex{}, modules.ConsensusChangeID{}, fmt.Errorf("failed to fetch consensus info: %w", err) - } else if err == nil { - return ci, ccid, nil +func RenewContract(ctx context.Context, tx sql.Tx, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { + var contractState ContractState + if err := contractState.LoadString(state); err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to load contract state: %w", err) } - // otherwise init - ci = types.ChainIndex{} - if _, err := tx.Exec(ctx, "INSERT INTO consensus_infos (id, created_at, cc_id, height, block_id) VALUES (?, ?, ?, ?, ?)", - consensuInfoID, time.Now(), (CCID)(modules.ConsensusChangeBeginning), ci.Height, (Hash256)(ci.ID)); err != nil { - return types.ChainIndex{}, modules.ConsensusChangeID{}, fmt.Errorf("failed to init consensus infos: %w", err) + // create copy of contract in archived_contracts + if err := copyContractToArchive(ctx, tx, renewedFrom, &rev.Revision.ParentID, api.ContractArchivalReasonRenewed); err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to copy contract to archived_contracts: %w", err) } - return types.ChainIndex{}, modules.ConsensusChangeBeginning, nil + // update existing contract + _, err := tx.Exec(ctx, ` + UPDATE contracts SET + created_at = ?, + fcid = ?, + renewed_from = ?, + contract_price = ?, + state = ?, + total_cost = ?, + proof_height = ?, + revision_height = ?, + revision_number = ?, + size = ?, + start_height = ?, + window_start = ?, + window_end = ?, + upload_spending = ?, + download_spending = ?, + fund_account_spending = ?, + delete_spending = ?, + list_spending = ? + WHERE fcid = ? 
+ `, + time.Now(), FileContractID(rev.ID()), FileContractID(renewedFrom), Currency(contractPrice), contractState, + Currency(totalCost), 0, 0, fmt.Sprint(rev.Revision.RevisionNumber), rev.Revision.Filesize, startHeight, + rev.Revision.WindowStart, rev.Revision.WindowEnd, ZeroCurrency, ZeroCurrency, ZeroCurrency, ZeroCurrency, + ZeroCurrency, FileContractID(renewedFrom)) + if err != nil { + return api.ContractMetadata{}, fmt.Errorf("failed to update contract: %w", err) + } + return Contract(ctx, tx, rev.ID()) } func QueryContracts(ctx context.Context, tx sql.Tx, whereExprs []string, whereArgs []any) ([]api.ContractMetadata, error) { @@ -1367,17 +1956,15 @@ func RenewedContract(ctx context.Context, tx sql.Tx, renewedFrom types.FileContr return contracts[0], nil } -func ResetConsensusSubscription(ctx context.Context, tx sql.Tx) (ci types.ChainIndex, err error) { +func ResetChainState(ctx context.Context, tx sql.Tx) error { if _, err := tx.Exec(ctx, "DELETE FROM consensus_infos"); err != nil { - return types.ChainIndex{}, fmt.Errorf("failed to delete consensus infos: %w", err) - } else if _, err := tx.Exec(ctx, "DELETE FROM siacoin_elements"); err != nil { - return types.ChainIndex{}, fmt.Errorf("failed to delete siacoin elements: %w", err) - } else if _, err := tx.Exec(ctx, "DELETE FROM transactions"); err != nil { - return types.ChainIndex{}, fmt.Errorf("failed to delete transactions: %w", err) - } else if ci, _, err = InitConsensusInfo(ctx, tx); err != nil { - return types.ChainIndex{}, fmt.Errorf("failed to initialize consensus info: %w", err) + return err + } else if _, err := tx.Exec(ctx, "DELETE FROM wallet_events"); err != nil { + return err + } else if _, err := tx.Exec(ctx, "DELETE FROM wallet_outputs"); err != nil { + return err } - return ci, nil + return nil } func ResetLostSectors(ctx context.Context, tx sql.Tx, hk types.PublicKey) error { @@ -1522,7 +2109,7 @@ func SearchHosts(ctx context.Context, tx sql.Tx, autopilot, filterMode, usabilit SELECT h.id, h.created_at, h.last_announcement, h.public_key, h.net_address, h.price_table, h.price_table_expiry, h.settings, h.total_scans, h.last_scan, h.last_scan_success, h.second_to_last_scan_success, h.uptime, h.downtime, h.successful_interactions, h.failed_interactions, COALESCE(h.lost_sectors, 0), - h.scanned, h.subnets, %s + h.scanned, h.resolved_addresses, %s FROM hosts h %s %s @@ -1537,20 +2124,24 @@ func SearchHosts(ctx context.Context, tx sql.Tx, autopilot, filterMode, usabilit var h api.Host var hostID int64 var pte dsql.NullTime - var subnets string + var resolvedAddresses string err := rows.Scan(&hostID, &h.KnownSince, &h.LastAnnouncement, (*PublicKey)(&h.PublicKey), &h.NetAddress, (*PriceTable)(&h.PriceTable.HostPriceTable), &pte, (*HostSettings)(&h.Settings), &h.Interactions.TotalScans, (*UnixTimeNS)(&h.Interactions.LastScan), &h.Interactions.LastScanSuccess, &h.Interactions.SecondToLastScanSuccess, &h.Interactions.Uptime, &h.Interactions.Downtime, &h.Interactions.SuccessfulInteractions, &h.Interactions.FailedInteractions, &h.Interactions.LostSectors, - &h.Scanned, &subnets, &h.Blocked, + &h.Scanned, &resolvedAddresses, &h.Blocked, ) if err != nil { return nil, fmt.Errorf("failed to scan host: %w", err) } - if subnets != "" { - h.Subnets = strings.Split(subnets, ",") + if resolvedAddresses != "" { + h.ResolvedAddresses = strings.Split(resolvedAddresses, ",") + h.Subnets, err = utils.AddressesToSubnets(h.ResolvedAddresses) + if err != nil { + return nil, fmt.Errorf("failed to convert addresses to subnets: %w", err) + } } 
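Note: hosts now persist resolved_addresses as a comma-separated list and derive h.Subnets at read time via utils.AddressesToSubnets. A rough sketch of what such an addresses-to-subnets conversion could look like; the /24 and /64 grouping widths here are assumptions for illustration, not necessarily what utils.AddressesToSubnets actually uses:

package main

import (
	"fmt"
	"net"
)

// addressesToSubnets collapses each resolved IP to a fixed-width subnet:
// /24 for IPv4 and /64 for IPv6 (assumed widths, for illustration only).
func addressesToSubnets(addrs []string) ([]string, error) {
	var subnets []string
	for _, addr := range addrs {
		ip := net.ParseIP(addr)
		if ip == nil {
			return nil, fmt.Errorf("invalid address: %s", addr)
		}
		bits := 64 // assumed IPv6 grouping
		if ip.To4() != nil {
			bits = 24 // assumed IPv4 grouping
		}
		_, subnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), bits))
		if err != nil {
			return nil, err
		}
		subnets = append(subnets, subnet.String())
	}
	return subnets, nil
}

func main() {
	fmt.Println(addressesToSubnets([]string{"94.23.10.5", "2001:db8::1"}))
}

Deriving subnets on read keeps the schema focused on the raw resolved addresses while still letting callers group hosts by subnet.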
 		h.PriceTable.Expiry = pte.Time
 		h.StoredData = storedDataMap[hostID]
@@ -1626,6 +2217,7 @@ func Settings(ctx context.Context, tx sql.Tx) ([]string, error) {
 	if err != nil {
 		return nil, fmt.Errorf("failed to query settings: %w", err)
 	}
+	defer rows.Close()
 	var settings []string
 	for rows.Next() {
 		var setting string
@@ -1645,6 +2237,83 @@ func SetUncleanShutdown(ctx context.Context, tx sql.Tx) error {
 	return err
 }
 
+func Slab(ctx context.Context, tx sql.Tx, key object.EncryptionKey) (object.Slab, error) {
+	// fetch slab
+	var slabID int64
+	slab := object.Slab{Key: key}
+	err := tx.QueryRow(ctx, `
+		SELECT id, health, min_shards
+		FROM slabs sla
+		WHERE sla.key = ?
+	`, EncryptionKey(key)).Scan(&slabID, &slab.Health, &slab.MinShards)
+	if errors.Is(err, dsql.ErrNoRows) {
+		return object.Slab{}, api.ErrSlabNotFound
+	} else if err != nil {
+		return object.Slab{}, fmt.Errorf("failed to fetch slab: %w", err)
+	}
+
+	// fetch sectors
+	rows, err := tx.Query(ctx, `
+		SELECT id, latest_host, root
+		FROM sectors s
+		WHERE s.db_slab_id = ?
+		ORDER BY s.slab_index
+	`, slabID)
+	if err != nil {
+		return object.Slab{}, fmt.Errorf("failed to fetch sectors: %w", err)
+	}
+	defer rows.Close()
+
+	var sectorIDs []int64
+	for rows.Next() {
+		var sectorID int64
+		var sector object.Sector
+		if err := rows.Scan(&sectorID, (*PublicKey)(&sector.LatestHost), (*Hash256)(&sector.Root)); err != nil {
+			return object.Slab{}, fmt.Errorf("failed to scan sector: %w", err)
+		}
+		slab.Shards = append(slab.Shards, sector)
+		sectorIDs = append(sectorIDs, sectorID)
+	}
+
+	// fetch contracts for each sector
+	stmt, err := tx.Prepare(ctx, `
+		SELECT h.public_key, c.fcid
+		FROM contract_sectors cs
+		INNER JOIN contracts c ON c.id = cs.db_contract_id
+		INNER JOIN hosts h ON h.id = c.host_id
+		WHERE cs.db_sector_id = ?
+		ORDER BY c.id
+	`)
+	if err != nil {
+		return object.Slab{}, fmt.Errorf("failed to prepare statement to fetch contracts: %w", err)
+	}
+	defer stmt.Close()
+
+	for i, sectorID := range sectorIDs {
+		rows, err := stmt.Query(ctx, sectorID)
+		if err != nil {
+			return object.Slab{}, fmt.Errorf("failed to fetch contracts: %w", err)
+		}
+		if err := func() error {
+			defer rows.Close()
+
+			slab.Shards[i].Contracts = make(map[types.PublicKey][]types.FileContractID)
+			for rows.Next() {
+				var pk types.PublicKey
+				var fcid types.FileContractID
+				if err := rows.Scan((*PublicKey)(&pk), (*FileContractID)(&fcid)); err != nil {
+					return fmt.Errorf("failed to scan contract: %w", err)
+				}
+				slab.Shards[i].Contracts[pk] = append(slab.Shards[i].Contracts[pk], fcid)
+			}
+			return nil
+		}(); err != nil {
+			return object.Slab{}, err
+		}
+	}
+	return slab, nil
+}
+
 func SlabBuffers(ctx context.Context, tx sql.Tx) (map[string]string, error) {
 	rows, err := tx.Query(ctx, `
 		SELECT buffered_slabs.filename, cs.name
@@ -1669,6 +2338,48 @@ func SlabBuffers(ctx context.Context, tx sql.Tx) (map[string]string, error) {
 	return fileNameToContractSet, nil
 }
 
+func Tip(ctx context.Context, tx sql.Tx) (types.ChainIndex, error) {
+	var id Hash256
+	var height uint64
+	if err := tx.QueryRow(ctx, "SELECT height, block_id FROM consensus_infos WHERE id = ?", sql.ConsensusInfoID).
+		Scan(&height, &id); errors.Is(err, dsql.ErrNoRows) {
+		// init
+		_, err = tx.Exec(ctx, "INSERT INTO consensus_infos (id, height, block_id) VALUES (?, ?, ?)", sql.ConsensusInfoID, 0, Hash256{})
+		return types.ChainIndex{}, err
+	} else if err != nil {
+		return types.ChainIndex{}, err
+	}
+	return types.ChainIndex{
+		ID:     types.BlockID(id),
+		Height: height,
+	}, nil
+}
+
+func UnhealthySlabs(ctx context.Context, tx sql.Tx, healthCutoff float64, set string, limit int) ([]api.UnhealthySlab, error) {
+	rows, err := tx.Query(ctx, `
+		SELECT sla.key, sla.health
+		FROM slabs sla
+		INNER JOIN contract_sets cs ON sla.db_contract_set_id = cs.id
+		WHERE sla.health <= ? AND cs.name = ? AND sla.health_valid_until > ? AND sla.db_buffered_slab_id IS NULL
+		ORDER BY sla.health ASC
+		LIMIT ?
+	`, healthCutoff, set, time.Now().Unix(), limit)
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch unhealthy slabs: %w", err)
+	}
+	defer rows.Close()
+
+	var slabs []api.UnhealthySlab
+	for rows.Next() {
+		var slab api.UnhealthySlab
+		if err := rows.Scan((*EncryptionKey)(&slab.Key), &slab.Health); err != nil {
+			return nil, fmt.Errorf("failed to scan unhealthy slab: %w", err)
+		}
+		slabs = append(slabs, slab)
+	}
+	return slabs, nil
+}
+
 func UpdateBucketPolicy(ctx context.Context, tx sql.Tx, bucket string, bp api.BucketPolicy) error {
 	policy, err := json.Marshal(bp)
 	if err != nil {
@@ -1685,6 +2396,30 @@ func UpdateBucketPolicy(ctx context.Context, tx sql.Tx, bucket string, bp api.Bu
 	return nil
 }
 
+func UpdatePeerInfo(ctx context.Context, tx sql.Tx, addr string, fn func(*syncer.PeerInfo)) error {
+	info, err := PeerInfo(ctx, tx, addr)
+	if err != nil {
+		return err
+	}
+	fn(&info)
+
+	res, err := tx.Exec(ctx, "UPDATE syncer_peers SET last_connect = ?, synced_blocks = ?, sync_duration = ? WHERE address = ?",
+		UnixTimeMS(info.LastConnect),
+		Unsigned64(info.SyncedBlocks),
+		info.SyncDuration,
+		addr,
+	)
+	if err != nil {
+		return fmt.Errorf("failed to update peer info: %w", err)
+	} else if n, err := res.RowsAffected(); err != nil {
+		return fmt.Errorf("failed to check rows affected: %w", err)
+	} else if n == 0 {
+		return syncer.ErrPeerNotFound
+	}
+
+	return nil
+}
+
 func Webhooks(ctx context.Context, tx sql.Tx) ([]webhooks.Webhook, error) {
 	rows, err := tx.Query(ctx, "SELECT module, event, url, headers FROM webhooks")
 	if err != nil {
@@ -1706,7 +2441,70 @@ func Webhooks(ctx context.Context, tx sql.Tx) ([]webhooks.Webhook, error) {
 	return whs, nil
 }
 
-func scanAutopilot(s scanner) (api.Autopilot, error) {
+func UnspentSiacoinElements(ctx context.Context, tx sql.Tx) (elements []types.SiacoinElement, err error) {
+	rows, err := tx.Query(ctx, "SELECT output_id, leaf_index, merkle_proof, address, value, maturity_height FROM wallet_outputs")
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch wallet outputs: %w", err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		element, err := scanSiacoinElement(rows)
+		if err != nil {
+			return nil, fmt.Errorf("failed to scan wallet output: %w", err)
+		}
+		elements = append(elements, element)
+	}
+	return
+}
+
+func WalletEvents(ctx context.Context, tx sql.Tx, offset, limit int) (events []wallet.Event, _ error) {
+	if limit == 0 || limit == -1 {
+		limit = math.MaxInt64
+	}
+
+	rows, err := tx.Query(ctx, "SELECT event_id, block_id, height, inflow, outflow, type, data, maturity_height, timestamp FROM wallet_events ORDER BY timestamp DESC LIMIT ? 
OFFSET ?", limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to fetch wallet events: %w", err) + } + defer rows.Close() + + for rows.Next() { + event, err := scanWalletEvent(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan wallet event: %w", err) + } + events = append(events, event) + } + return +} + +func WalletEventCount(ctx context.Context, tx sql.Tx) (count uint64, err error) { + var n int64 + err = tx.QueryRow(ctx, "SELECT COUNT(*) FROM wallet_events").Scan(&n) + if err != nil { + return 0, fmt.Errorf("failed to count wallet events: %w", err) + } + return uint64(n), nil +} + +func copyContractToArchive(ctx context.Context, tx sql.Tx, fcid types.FileContractID, renewedTo *types.FileContractID, reason string) error { + _, err := tx.Exec(ctx, ` + INSERT INTO archived_contracts (created_at, fcid, renewed_from, contract_price, state, total_cost, + proof_height, revision_height, revision_number, size, start_height, window_start, window_end, + upload_spending, download_spending, fund_account_spending, delete_spending, list_spending, renewed_to, + host, reason) + SELECT ?, fcid, renewed_from, contract_price, state, total_cost, proof_height, revision_height, revision_number, + size, start_height, window_start, window_end, upload_spending, download_spending, fund_account_spending, + delete_spending, list_spending, ?, h.public_key, ? + FROM contracts c + INNER JOIN hosts h ON h.id = c.host_id + WHERE fcid = ? + `, time.Now(), (*FileContractID)(renewedTo), reason, FileContractID(fcid)) + return err +} + +func scanAutopilot(s Scanner) (api.Autopilot, error) { var a api.Autopilot if err := s.Scan(&a.ID, (*AutopilotConfig)(&a.Config), &a.CurrentPeriod); err != nil { return api.Autopilot{}, err @@ -1714,7 +2512,7 @@ func scanAutopilot(s scanner) (api.Autopilot, error) { return a, nil } -func scanBucket(s scanner) (api.Bucket, error) { +func scanBucket(s Scanner) (api.Bucket, error) { var createdAt time.Time var name, policy string err := s.Scan(&createdAt, &name, &policy) @@ -1734,170 +2532,386 @@ func scanBucket(s scanner) (api.Bucket, error) { }, nil } -func scanMultipartUpload(s scanner) (resp api.MultipartUpload, _ error) { - var key SecretKey - err := s.Scan(&resp.Bucket, &key, &resp.Path, &resp.UploadID, &resp.CreatedAt) +func scanMultipartUpload(s Scanner) (resp api.MultipartUpload, _ error) { + err := s.Scan(&resp.Bucket, (*EncryptionKey)(&resp.Key), &resp.Path, &resp.UploadID, &resp.CreatedAt) if errors.Is(err, dsql.ErrNoRows) { return api.MultipartUpload{}, api.ErrMultipartUploadNotFound } else if err != nil { return api.MultipartUpload{}, fmt.Errorf("failed to fetch multipart upload: %w", err) - } else if err := resp.Key.UnmarshalBinary(key); err != nil { - return api.MultipartUpload{}, fmt.Errorf("failed to unmarshal encryption key: %w", err) } return } -func scanObjectMetadata(s scanner) (api.ObjectMetadata, error) { - var md api.ObjectMetadata - if err := s.Scan(&md.Name, &md.Size, &md.Health, &md.MimeType, &md.ModTime, &md.ETag); err != nil { - return api.ObjectMetadata{}, fmt.Errorf("failed to scan object metadata: %w", err) - } - return md, nil +func scanWalletEvent(s Scanner) (wallet.Event, error) { + var blockID, eventID Hash256 + var height, maturityHeight uint64 + var inflow, outflow Currency + var edata []byte + var etype string + var ts UnixTimeNS + if err := s.Scan( + &eventID, + &blockID, + &height, + &inflow, + &outflow, + &etype, + &edata, + &maturityHeight, + &ts, + ); err != nil { + return wallet.Event{}, err + } + + data, err := 
UnmarshalEventData(edata, etype) + if err != nil { + return wallet.Event{}, err + } + return wallet.Event{ + ID: types.Hash256(eventID), + Index: types.ChainIndex{ + ID: types.BlockID(blockID), + Height: height, + }, + Type: etype, + Data: data, + MaturityHeight: maturityHeight, + Timestamp: time.Time(ts), + }, nil } -func ListObjects(ctx context.Context, tx sql.Tx, bucket, prefix, sortBy, sortDir, marker string, limit int) (api.ObjectsListResponse, error) { - // fetch one more to see if there are more entries +func scanSiacoinElement(s Scanner) (el types.SiacoinElement, err error) { + var id Hash256 + var leafIndex, maturityHeight uint64 + var merkleProof MerkleProof + var address Hash256 + var value Currency + err = s.Scan(&id, &leafIndex, &merkleProof, &address, &value, &maturityHeight) + if err != nil { + return types.SiacoinElement{}, err + } + return types.SiacoinElement{ + StateElement: types.StateElement{ + ID: types.Hash256(id), + LeafIndex: leafIndex, + MerkleProof: merkleProof.Hashes, + }, + SiacoinOutput: types.SiacoinOutput{ + Address: types.Address(address), + Value: types.Currency(value), + }, + MaturityHeight: maturityHeight, + }, nil +} + +func scanStateElement(s Scanner) (types.StateElement, error) { + var id Hash256 + var leafIndex uint64 + var merkleProof MerkleProof + if err := s.Scan(&id, &leafIndex, &merkleProof); err != nil { + return types.StateElement{}, err + } + return types.StateElement{ + ID: types.Hash256(id), + LeafIndex: leafIndex, + MerkleProof: merkleProof.Hashes, + }, nil +} + +func SearchObjects(ctx context.Context, tx Tx, bucket, substring string, offset, limit int) ([]api.ObjectMetadata, error) { if limit <= -1 { limit = math.MaxInt - } else { - limit++ } - // establish sane defaults for sorting - if sortBy == "" { - sortBy = api.ObjectSortByName + rows, err := tx.Query(ctx, fmt.Sprintf(` + SELECT %s + FROM objects o + INNER JOIN buckets b ON o.db_bucket_id = b.id + WHERE INSTR(o.object_id, ?) > 0 AND b.name = ? + ORDER BY o.object_id ASC + LIMIT ? OFFSET ? + `, tx.SelectObjectMetadataExpr()), substring, bucket, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to search objects: %w", err) } - if sortDir == "" { - sortDir = api.ObjectSortDirAsc + defer rows.Close() + + var objects []api.ObjectMetadata + for rows.Next() { + om, err := tx.ScanObjectMetadata(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan object metadata: %w", err) + } + objects = append(objects, om) } + return objects, nil +} - // filter by bucket - whereExprs := []string{"o.db_bucket_id = (SELECT id FROM buckets b WHERE b.name = ?)"} - whereArgs := []any{bucket} +func ObjectsBySlabKey(ctx context.Context, tx Tx, bucket string, slabKey object.EncryptionKey) ([]api.ObjectMetadata, error) { + rows, err := tx.Query(ctx, fmt.Sprintf(` + SELECT %s + FROM objects o + INNER JOIN buckets b ON o.db_bucket_id = b.id + WHERE b.name = ? AND EXISTS ( + SELECT 1 + FROM objects o2 + INNER JOIN slices sli ON sli.db_object_id = o2.id + INNER JOIN slabs sla ON sla.id = sli.db_slab_id + WHERE o2.id = o.id AND sla.key = ? + ) + `, tx.SelectObjectMetadataExpr()), bucket, EncryptionKey(slabKey)) + if err != nil { + return nil, fmt.Errorf("failed to query objects: %w", err) + } + defer rows.Close() - // apply prefix - if prefix != "" { - whereExprs = append(whereExprs, "o.object_id LIKE ? AND SUBSTR(o.object_id, 1, ?) 
= ?") - whereArgs = append(whereArgs, prefix+"%", utf8.RuneCountInString(prefix), prefix) + var objects []api.ObjectMetadata + for rows.Next() { + om, err := tx.ScanObjectMetadata(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan object metadata: %w", err) + } + objects = append(objects, om) } + return objects, nil +} - // apply sorting - dir2SQL := map[string]string{ - api.ObjectSortDirAsc: "ASC", - api.ObjectSortDirDesc: "DESC", +func MarkPackedSlabUploaded(ctx context.Context, tx Tx, slab api.UploadedPackedSlab) (string, error) { + // fetch relevant slab info + var slabID, bufferedSlabID int64 + var bufferFileName string + if err := tx.QueryRow(ctx, ` + SELECT sla.id, bs.id, bs.filename + FROM slabs sla + INNER JOIN buffered_slabs bs ON bs.id = sla.db_buffered_slab_id + WHERE sla.db_buffered_slab_id = ? + `, slab.BufferID). + Scan(&slabID, &bufferedSlabID, &bufferFileName); err != nil { + return "", fmt.Errorf("failed to fetch slab id: %w", err) } - if _, ok := dir2SQL[strings.ToLower(sortDir)]; !ok { - return api.ObjectsListResponse{}, fmt.Errorf("invalid sortDir: %v", sortDir) + + // set 'db_buffered_slab_id' to NULL + if _, err := tx.Exec(ctx, "UPDATE slabs SET db_buffered_slab_id = NULL WHERE id = ?", slabID); err != nil { + return "", fmt.Errorf("failed to update slab: %w", err) } - var orderByExprs []string - switch strings.ToLower(sortBy) { - case "", api.ObjectSortByName: - orderByExprs = append(orderByExprs, "o.object_id "+dir2SQL[strings.ToLower(sortDir)]) - case api.ObjectSortByHealth: - orderByExprs = append(orderByExprs, "o.health "+dir2SQL[strings.ToLower(sortDir)]) - case api.ObjectSortBySize: - orderByExprs = append(orderByExprs, "o.size "+dir2SQL[strings.ToLower(sortDir)]) - default: - return api.ObjectsListResponse{}, fmt.Errorf("invalid sortBy: %v", sortBy) + + // delete buffer slab + if _, err := tx.Exec(ctx, "DELETE FROM buffered_slabs WHERE id = ?", bufferedSlabID); err != nil { + return "", fmt.Errorf("failed to delete buffered slab: %w", err) } - // always sort by object_id as well if we aren't explicitly - if sortBy != api.ObjectSortByName { - orderByExprs = append(orderByExprs, "o.object_id ASC") + // fetch used contracts + usedContracts, err := FetchUsedContracts(ctx, tx, slab.Contracts()) + if err != nil { + return "", fmt.Errorf("failed to fetch used contracts: %w", err) } - // apply marker - queryMarker := func(dst any, marker, col string) error { - err := tx.QueryRow(ctx, fmt.Sprintf(` - SELECT o.%s - FROM objects o - INNER JOIN buckets b ON o.db_bucket_id = b.id - WHERE b.name = ? AND o.object_id = ? 
- `, col), bucket, marker).Scan(dst) - if errors.Is(err, dsql.ErrNoRows) { - return api.ErrMarkerNotFound - } else { - return err - } + // stmt to add sector + sectorStmt, err := tx.Prepare(ctx, "INSERT INTO sectors (db_slab_id, slab_index, latest_host, root) VALUES (?, ?, ?, ?)") + if err != nil { + return "", fmt.Errorf("failed to prepare statement to insert sectors: %w", err) } - desc := strings.ToLower(sortDir) == api.ObjectSortDirDesc - if marker != "" { - switch strings.ToLower(sortBy) { - case api.ObjectSortByName: - if desc { - whereExprs = append(whereExprs, "o.object_id < ?") - } else { - whereExprs = append(whereExprs, "o.object_id > ?") - } - whereArgs = append(whereArgs, marker) - case api.ObjectSortByHealth: - var markerHealth float64 - if err := queryMarker(&markerHealth, marker, "health"); err != nil { - return api.ObjectsListResponse{}, fmt.Errorf("failed to fetch health marker: %w", err) - } else if desc { - whereExprs = append(whereExprs, "((o.health <= ? AND o.object_id >?) OR o.health < ?)") - whereArgs = append(whereArgs, markerHealth, marker, markerHealth) - } else { - whereExprs = append(whereExprs, "(o.health > ? OR (o.health >= ? AND object_id > ?))") - whereArgs = append(whereArgs, markerHealth, markerHealth, marker) - } - case api.ObjectSortBySize: - var markerSize int64 - if err := queryMarker(&markerSize, marker, "size"); err != nil { - return api.ObjectsListResponse{}, fmt.Errorf("failed to fetch health marker: %w", err) - } else if desc { - whereExprs = append(whereExprs, "((o.size <= ? AND o.object_id >?) OR o.size < ?)") - whereArgs = append(whereArgs, markerSize, marker, markerSize) - } else { - whereExprs = append(whereExprs, "(o.size > ? OR (o.size >= ? AND object_id > ?))") - whereArgs = append(whereArgs, markerSize, markerSize, marker) + defer sectorStmt.Close() + + // stmt to insert contract_sector + contractSectorStmt, err := tx.Prepare(ctx, "INSERT INTO contract_sectors (db_contract_id, db_sector_id) VALUES (?, ?)") + if err != nil { + return "", fmt.Errorf("failed to prepare statement to insert contract sectors: %w", err) + } + defer contractSectorStmt.Close() + + // insert shards + for i := range slab.Shards { + // insert shard + res, err := sectorStmt.Exec(ctx, slabID, i+1, PublicKey(slab.Shards[i].LatestHost), slab.Shards[i].Root[:]) + if err != nil { + return "", fmt.Errorf("failed to insert sector: %w", err) + } + sectorID, err := res.LastInsertId() + if err != nil { + return "", fmt.Errorf("failed to get sector id: %w", err) + } + + // insert contracts for shard + for _, fcids := range slab.Shards[i].Contracts { + for _, fcid := range fcids { + uc, ok := usedContracts[fcid] + if !ok { + continue + } + // insert contract sector + if _, err := contractSectorStmt.Exec(ctx, uc.ID, sectorID); err != nil { + return "", fmt.Errorf("failed to insert contract sector: %w", err) + } } - default: - return api.ObjectsListResponse{}, fmt.Errorf("invalid marker: %v", marker) } } + return bufferFileName, nil +} - // apply limit - whereArgs = append(whereArgs, limit) +func RecordContractSpending(ctx context.Context, tx Tx, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error { + var updateKeys []string + var updateValues []interface{} - // run query - rows, err := tx.Query(ctx, fmt.Sprintf(` - SELECT o.object_id, o.size, o.health, o.mime_type, o.created_at, o.etag + if !newSpending.Uploads.IsZero() { + updateKeys = append(updateKeys, "upload_spending = ?") + updateValues = append(updateValues, 
Currency(newSpending.Uploads))
+	}
+	if !newSpending.Downloads.IsZero() {
+		updateKeys = append(updateKeys, "download_spending = ?")
+		updateValues = append(updateValues, Currency(newSpending.Downloads))
+	}
+	if !newSpending.FundAccount.IsZero() {
+		updateKeys = append(updateKeys, "fund_account_spending = ?")
+		updateValues = append(updateValues, Currency(newSpending.FundAccount))
+	}
+	if !newSpending.Deletions.IsZero() {
+		updateKeys = append(updateKeys, "delete_spending = ?")
+		updateValues = append(updateValues, Currency(newSpending.Deletions))
+	}
+	if !newSpending.SectorRoots.IsZero() {
+		updateKeys = append(updateKeys, "list_spending = ?")
+		updateValues = append(updateValues, Currency(newSpending.SectorRoots))
+	}
+	updateKeys = append(updateKeys, "revision_number = ?", "size = ?")
+	updateValues = append(updateValues, revisionNumber, size)
+
+	updateValues = append(updateValues, FileContractID(fcid))
+	_, err := tx.Exec(ctx, fmt.Sprintf(`
+		UPDATE contracts
+		SET %s
+		WHERE fcid = ?
+	`, strings.Join(updateKeys, ",")), updateValues...)
+	if err != nil {
+		return fmt.Errorf("failed to record contract spending: %w", err)
+	}
+	return nil
+}
+
+func Object(ctx context.Context, tx Tx, bucket, key string) (api.Object, error) {
+	// fetch object metadata
+	row := tx.QueryRow(ctx, fmt.Sprintf(`
+		SELECT %s, o.id, o.key
+		FROM objects o
+		INNER JOIN buckets b ON o.db_bucket_id = b.id
+		WHERE o.object_id = ? AND b.name = ?
+	`,
+		tx.SelectObjectMetadataExpr()), key, bucket)
+	var objID int64
+	var ec object.EncryptionKey
+	om, err := tx.ScanObjectMetadata(row, &objID, (*EncryptionKey)(&ec))
+	if errors.Is(err, dsql.ErrNoRows) {
+		return api.Object{}, api.ErrObjectNotFound
+	} else if err != nil {
+		return api.Object{}, err
+	}
+
+	// fetch user metadata
+	rows, err := tx.Query(ctx, `
+		SELECT oum.key, oum.value
+		FROM object_user_metadata oum
+		WHERE oum.db_object_id = ?
+	`, objID)
+	if err != nil {
+		return api.Object{}, fmt.Errorf("failed to fetch user metadata: %w", err)
+	}
+	defer rows.Close()
+
+	oum := make(api.ObjectUserMetadata)
+	for rows.Next() {
+		var key, value string
+		if err := rows.Scan(&key, &value); err != nil {
+			return api.Object{}, fmt.Errorf("failed to scan user metadata: %w", err)
+		}
+		oum[key] = value
+	}
+
+	// fetch slab slices
+	rows, err = tx.Query(ctx, `
+		SELECT sla.db_buffered_slab_id IS NOT NULL, sli.object_index, sli.offset, sli.length, sla.health, sla.key, sla.min_shards, COALESCE(sec.slab_index, 0), COALESCE(sec.root, ?), COALESCE(sec.latest_host, ?), COALESCE(c.fcid, ?), COALESCE(h.public_key, ?)
+		FROM slices sli
+		INNER JOIN slabs sla ON sli.db_slab_id = sla.id
+		LEFT JOIN sectors sec ON sec.db_slab_id = sla.id
+		LEFT JOIN contract_sectors csec ON csec.db_sector_id = sec.id
+		LEFT JOIN contracts c ON c.id = csec.db_contract_id
+		LEFT JOIN hosts h ON h.id = c.host_id
+		WHERE sli.db_object_id = ? 
+		ORDER BY sli.object_index ASC, sec.slab_index ASC
+	`, Hash256{}, PublicKey{}, FileContractID{}, PublicKey{}, objID)
+	if err != nil {
+		return api.Object{}, fmt.Errorf("failed to fetch slabs: %w", err)
+	}
+	defer rows.Close()
+
+	slabSlices := object.SlabSlices{}
+	var current *object.SlabSlice
+	var currObjIdx, currSlaIdx int64
+	for rows.Next() {
+		var bufferedSlab bool
+		var objectIndex int64
+		var slabIndex int64
+		var ss object.SlabSlice
+		var sector object.Sector
+		var fcid types.FileContractID
+		var hk types.PublicKey
+		if err := rows.Scan(&bufferedSlab, // whether the slab is buffered
+			&objectIndex, &ss.Offset, &ss.Length, // slice info
+			&ss.Health, (*EncryptionKey)(&ss.Key), &ss.MinShards, // slab info
+			&slabIndex, (*Hash256)(&sector.Root), (*PublicKey)(&sector.LatestHost), // sector info
+			(*FileContractID)(&fcid), // contract info
+			(*PublicKey)(&hk), // host info
+		); err != nil {
+			return api.Object{}, fmt.Errorf("failed to scan slab slice: %w", err)
+		}
+
+		// sanity check object for corruption
+		isFirst := current == nil && objectIndex == 1 && slabIndex == 1
+		isBuffered := bufferedSlab && objectIndex == currObjIdx+1 && slabIndex == 0
+		isNewSlab := isFirst || isBuffered || (current != nil && objectIndex == currObjIdx+1 && slabIndex == 1)
+		isNewShard := isNewSlab || (objectIndex == currObjIdx && slabIndex == currSlaIdx+1)
+		isNewContract := isNewShard || (objectIndex == currObjIdx && slabIndex == currSlaIdx)
+		if !isFirst && !isBuffered && !isNewSlab && !isNewShard && !isNewContract {
+			return api.Object{}, fmt.Errorf("%w: object index %d, slab index %d, current object index %d, current slab index %d", api.ErrObjectCorrupted, objectIndex, slabIndex, currObjIdx, currSlaIdx)
+		}
+
+		// update indices
+		currObjIdx = objectIndex
+		currSlaIdx = slabIndex
+
+		if isNewSlab {
+			if current != nil {
+				slabSlices = append(slabSlices, *current)
+			}
+			current = &ss
+		}
+
+		// if the slab is buffered there are no sectors/contracts to add
+		if bufferedSlab {
+			continue
+		}
+
+		if isNewShard {
+			current.Shards = append(current.Shards, sector)
+		}
+		if isNewContract {
+			if current.Shards[len(current.Shards)-1].Contracts == nil {
+				current.Shards[len(current.Shards)-1].Contracts = make(map[types.PublicKey][]types.FileContractID)
+			}
+			current.Shards[len(current.Shards)-1].Contracts[hk] = append(current.Shards[len(current.Shards)-1].Contracts[hk], fcid)
+		}
+	}
+
+	// add last slab slice
+	if current != nil {
+		slabSlices = append(slabSlices, *current)
+	}
+
+	return api.Object{
+		Metadata:       oum,
+		ObjectMetadata: om,
+		Object: &object.Object{
+			Key:   ec,
+			Slabs: slabSlices,
+		},
+	}, nil
+}
diff --git a/stores/sql/metrics.go b/stores/sql/metrics.go
index 689f98843..6f6e5420f 100644
--- a/stores/sql/metrics.go
+++ b/stores/sql/metrics.go
@@ -206,6 +206,7 @@ func RecordContractMetric(ctx context.Context, tx sql.Tx, metrics ...api.Contrac
 	if err != nil {
 		return fmt.Errorf("failed to prepare statement to delete contract metric: %w", err)
 	}
+	defer deleteStmt.Close()
 
 	for _, metric := range metrics {
 		// delete any existing metric for the same contract that has happened
@@ -365,7 +366,7 @@ func RecordPerformanceMetric(ctx context.Context, tx sql.Tx, metrics ...api.Perf
 }
 
 func RecordWalletMetric(ctx context.Context, tx sql.Tx, metrics ...api.WalletMetric) error {
-	insertStmt, err := tx.Prepare(ctx, "INSERT INTO wallets (created_at, timestamp, confirmed_lo, confirmed_hi, spendable_lo, spendable_hi, unconfirmed_lo, 
unconfirmed_hi) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")
+	insertStmt, err := tx.Prepare(ctx, "INSERT INTO wallets (created_at, timestamp, confirmed_lo, confirmed_hi, spendable_lo, spendable_hi, unconfirmed_lo, unconfirmed_hi, immature_lo, immature_hi) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
 	if err != nil {
 		return fmt.Errorf("failed to prepare statement to insert wallet metric: %w", err)
 	}
@@ -381,6 +382,8 @@ func RecordWalletMetric(ctx context.Context, tx sql.Tx, metrics ...api.WalletMet
 			Unsigned64(metric.Spendable.Hi),
 			Unsigned64(metric.Unconfirmed.Lo),
 			Unsigned64(metric.Unconfirmed.Hi),
+			Unsigned64(metric.Immature.Lo),
+			Unsigned64(metric.Immature.Hi),
 		)
 		if err != nil {
 			return fmt.Errorf("failed to insert wallet metric: %w", err)
@@ -406,6 +409,7 @@ func WalletMetrics(ctx context.Context, tx sql.Tx, start time.Time, n uint64, in
 			(*Unsigned64)(&m.Confirmed.Lo), (*Unsigned64)(&m.Confirmed.Hi),
 			(*Unsigned64)(&m.Spendable.Lo), (*Unsigned64)(&m.Spendable.Hi),
 			(*Unsigned64)(&m.Unconfirmed.Lo), (*Unsigned64)(&m.Unconfirmed.Hi),
+			(*Unsigned64)(&m.Immature.Lo), (*Unsigned64)(&m.Immature.Hi),
 		)
 		if err != nil {
 			err = fmt.Errorf("failed to scan wallet metric: %w", err)
diff --git a/stores/sql/mysql/chain.go b/stores/sql/mysql/chain.go
new file mode 100644
index 000000000..56d5ba340
--- /dev/null
+++ b/stores/sql/mysql/chain.go
@@ -0,0 +1,333 @@
+package mysql
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net"
+	"strings"
+	"time"
+
+	"go.sia.tech/core/types"
+	"go.sia.tech/coreutils/chain"
+	"go.sia.tech/coreutils/wallet"
+	"go.sia.tech/renterd/api"
+	isql "go.sia.tech/renterd/internal/sql"
+	ssql "go.sia.tech/renterd/stores/sql"
+	"go.uber.org/zap"
+)
+
+var (
+	_ ssql.ChainUpdateTx = (*chainUpdateTx)(nil)
+)
+
+type chainUpdateTx struct {
+	ctx context.Context
+	tx  isql.Tx
+	l   *zap.SugaredLogger
+}
+
+func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent []types.SiacoinElement, events []wallet.Event, timestamp time.Time) error {
+	c.l.Debugw("applying index", "height", index.Height, "block_id", index.ID)
+
+	if len(spent) > 0 {
+		// prepare statement to delete spent outputs
+		deleteSpentStmt, err := c.tx.Prepare(c.ctx, "DELETE FROM wallet_outputs WHERE output_id = ?")
+		if err != nil {
+			return fmt.Errorf("failed to prepare statement to delete spent outputs: %w", err)
+		}
+		defer deleteSpentStmt.Close()
+
+		// delete spent outputs
+		for _, e := range spent {
+			c.l.Debugw(fmt.Sprintf("remove output %v", e.ID), "height", index.Height, "block_id", index.ID)
+			if res, err := deleteSpentStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil {
+				return fmt.Errorf("failed to delete spent output: %w", err)
+			} else if n, err := res.RowsAffected(); err != nil {
+				return fmt.Errorf("failed to get rows affected: %w", err)
+			} else if n != 1 {
+				return fmt.Errorf("failed to delete spent output: no rows affected")
+			}
+		}
+	}
+
+	if len(created) > 0 {
+		// prepare statement to insert new outputs
+		insertOutputStmt, err := c.tx.Prepare(c.ctx, "INSERT IGNORE INTO wallet_outputs (created_at, output_id, leaf_index, merkle_proof, value, address, maturity_height) VALUES (?, ?, ?, ?, ?, ?, ?)")
+		if err != nil {
+			return fmt.Errorf("failed to prepare statement to insert new outputs: %w", err)
+		}
+		defer insertOutputStmt.Close()
+
+		// insert new outputs
+		for _, e := range created {
+			c.l.Debugw(fmt.Sprintf("create output %v", e.ID), "height", index.Height, "block_id", index.ID)
+			if _, err := insertOutputStmt.Exec(c.ctx,
+				time.Now().UTC(),
+				ssql.Hash256(e.ID),
e.StateElement.LeafIndex, + ssql.MerkleProof{Hashes: e.StateElement.MerkleProof}, + ssql.Currency(e.SiacoinOutput.Value), + ssql.Hash256(e.SiacoinOutput.Address), + e.MaturityHeight, + ); err != nil { + return fmt.Errorf("failed to insert new output: %w", err) + } + } + } + + if len(events) > 0 { + // prepare statement to insert new events + insertEventStmt, err := c.tx.Prepare(c.ctx, "INSERT IGNORE INTO wallet_events (created_at, event_id, height, block_id, inflow, outflow, type, data, maturity_height, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + if err != nil { + return fmt.Errorf("failed to prepare statement to insert new events: %w", err) + } + defer insertEventStmt.Close() + + // insert new events + for _, e := range events { + c.l.Debugw(fmt.Sprintf("create event %v", e.ID), "height", index.Height, "block_id", index.ID) + data, err := json.Marshal(e.Data) + if err != nil { + c.l.Error(err) + return err + } + if _, err := insertEventStmt.Exec(c.ctx, + time.Now().UTC(), + ssql.Hash256(e.ID), + e.Index.Height, + ssql.Hash256(e.Index.ID), + ssql.Currency(e.SiacoinInflow()), + ssql.Currency(e.SiacoinOutflow()), + e.Type, + data, + e.MaturityHeight, + ssql.UnixTimeNS(e.Timestamp), + ); err != nil { + return fmt.Errorf("failed to insert new event: %w", err) + } + } + } + return nil +} + +func (c chainUpdateTx) ContractState(fcid types.FileContractID) (api.ContractState, error) { + return ssql.GetContractState(c.ctx, c.tx, fcid) +} + +func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspent []types.SiacoinElement, timestamp time.Time) error { + c.l.Debugw("reverting index", "height", index.Height, "block_id", index.ID) + + if len(removed) > 0 { + // prepare statement to delete removed outputs + deleteRemovedStmt, err := c.tx.Prepare(c.ctx, "DELETE FROM wallet_outputs WHERE output_id = ?") + if err != nil { + return fmt.Errorf("failed to prepare statement to delete removed outputs: %w", err) + } + defer deleteRemovedStmt.Close() + + // delete removed outputs + for _, e := range removed { + c.l.Debugw(fmt.Sprintf("remove output %v", e.ID), "height", index.Height, "block_id", index.ID) + if res, err := deleteRemovedStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil { + return fmt.Errorf("failed to delete removed output: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to delete removed output: no rows affected") + } + } + } + + if len(unspent) > 0 { + // prepare statement to insert unspent outputs + insertOutputStmt, err := c.tx.Prepare(c.ctx, "INSERT IGNORE INTO wallet_outputs (created_at, output_id, leaf_index, merkle_proof, value, address, maturity_height) VALUES (?, ?, ?, ?, ?, ?, ?)") + if err != nil { + return fmt.Errorf("failed to prepare statement to insert unspent outputs: %w", err) + } + defer insertOutputStmt.Close() + + // insert unspent outputs + for _, e := range unspent { + c.l.Debugw(fmt.Sprintf("recreate unspent output %v", e.ID), "height", index.Height, "block_id", index.ID) + if _, err := insertOutputStmt.Exec(c.ctx, + time.Now().UTC(), + ssql.Hash256(e.ID), + e.StateElement.LeafIndex, + ssql.MerkleProof{Hashes: e.StateElement.MerkleProof}, + ssql.Currency(e.SiacoinOutput.Value), + ssql.Hash256(e.SiacoinOutput.Address), + e.MaturityHeight, + ); err != nil { + return fmt.Errorf("failed to insert unspent output: %w", err) + } + } + } + + // remove events created at the reverted index + res, err := c.tx.Exec(c.ctx, "DELETE 
FROM wallet_events WHERE height = ? AND block_id = ?", index.Height, ssql.Hash256(index.ID)) + if err != nil { + return fmt.Errorf("failed to delete events: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n > 0 { + c.l.Debugw(fmt.Sprintf("removed %d events", n), "height", index.Height, "block_id", index.ID) + } + return nil +} + +func (c chainUpdateTx) UpdateChainIndex(index types.ChainIndex) error { + return ssql.UpdateChainIndex(c.ctx, c.tx, index, c.l) +} + +func (c chainUpdateTx) UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error { + return ssql.UpdateContract(c.ctx, c.tx, fcid, revisionHeight, revisionNumber, size, c.l) +} + +func (c chainUpdateTx) UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error { + return ssql.UpdateContractProofHeight(c.ctx, c.tx, fcid, proofHeight, c.l) +} + +func (c chainUpdateTx) UpdateContractState(fcid types.FileContractID, state api.ContractState) error { + return ssql.UpdateContractState(c.ctx, c.tx, fcid, state, c.l) +} + +func (c chainUpdateTx) UpdateFailedContracts(blockHeight uint64) error { + return ssql.UpdateFailedContracts(c.ctx, c.tx, blockHeight, c.l) +} + +func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error { // + c.l.Debugw("update host", "hk", hk, "netaddress", ha.NetAddress) + + // create the announcement + if _, err := c.tx.Exec(c.ctx, + "INSERT IGNORE INTO host_announcements (created_at, host_key, block_height, block_id, net_address) VALUES (?, ?, ?, ?, ?)", + time.Now().UTC(), + ssql.PublicKey(hk), + bh, + blockID.String(), + ha.NetAddress, + ); err != nil { + return fmt.Errorf("failed to insert host announcement: %w", err) + } + + // create the host + var hostID int64 + if res, err := c.tx.Exec(c.ctx, ` + INSERT INTO hosts (created_at, public_key, settings, price_table, total_scans, last_scan, last_scan_success, second_to_last_scan_success, scanned, uptime, downtime, recent_downtime, recent_scan_failures, successful_interactions, failed_interactions, lost_sectors, last_announcement, net_address) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON DUPLICATE KEY UPDATE + last_announcement = VALUES(last_announcement), + net_address = VALUES(net_address), + id = last_insert_id(id) + `, + time.Now().UTC(), + ssql.PublicKey(hk), + ssql.HostSettings{}, + ssql.PriceTable{}, + 0, + 0, + false, + false, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ts.UTC(), + ha.NetAddress, + ); err != nil { + return fmt.Errorf("failed to insert host: %w", err) + } else if hostID, err = res.LastInsertId(); err != nil { + return fmt.Errorf("failed to fetch host id: %w", err) + } + + // update allow list + rows, err := c.tx.Query(c.ctx, "SELECT id, entry FROM host_allowlist_entries") + if err != nil { + return fmt.Errorf("failed to fetch allow list: %w", err) + } + defer rows.Close() + for rows.Next() { + var id int64 + var pk ssql.PublicKey + if err := rows.Scan(&id, &pk); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + if hk == types.PublicKey(pk) { + if _, err := c.tx.Exec(c.ctx, + "INSERT IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) VALUES (?,?)", + id, + hostID, + ); err != nil { + return fmt.Errorf("failed to insert host into allowlist: %w", err) + } + } + } + + // update blocklist + values := []string{ha.NetAddress} + host, _, err := net.SplitHostPort(ha.NetAddress) + if err == nil { + values = append(values, host) + } + + rows, err = c.tx.Query(c.ctx, "SELECT id, entry FROM host_blocklist_entries") + if err != nil { + return fmt.Errorf("failed to fetch block list: %w", err) + } + defer rows.Close() + + type row struct { + id int64 + entry string + } + var entries []row + for rows.Next() { + var r row + if err := rows.Scan(&r.id, &r.entry); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + entries = append(entries, r) + } + + for _, row := range entries { + var blocked bool + for _, value := range values { + if value == row.entry || strings.HasSuffix(value, "."+row.entry) { + blocked = true + break + } + } + if blocked { + if _, err := c.tx.Exec(c.ctx, + "INSERT IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) VALUES (?,?)", + row.id, + hostID, + ); err != nil { + return fmt.Errorf("failed to insert host into blocklist: %w", err) + } + } else { + if _, err := c.tx.Exec(c.ctx, + "DELETE FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = ? 
AND db_host_id = ?", + row.id, + hostID, + ); err != nil { + return fmt.Errorf("failed to remove host from blocklist: %w", err) + } + } + } + + return nil +} + +func (c chainUpdateTx) UpdateWalletStateElements(elements []types.StateElement) error { + return ssql.UpdateWalletStateElements(c.ctx, c.tx, elements) +} + +func (c chainUpdateTx) WalletStateElements() ([]types.StateElement, error) { + return ssql.WalletStateElements(c.ctx, c.tx) +} diff --git a/stores/sql/mysql/common.go b/stores/sql/mysql/common.go index fca2749e7..73f6c9dc3 100644 --- a/stores/sql/mysql/common.go +++ b/stores/sql/mysql/common.go @@ -6,6 +6,7 @@ import ( "embed" "fmt" + _ "github.com/go-sql-driver/mysql" "go.sia.tech/renterd/internal/sql" ) diff --git a/stores/sql/mysql/main.go b/stores/sql/mysql/main.go index 88ce3cb83..08ff0010e 100644 --- a/stores/sql/mysql/main.go +++ b/stores/sql/mysql/main.go @@ -11,12 +11,14 @@ import ( "time" "unicode/utf8" + rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/api" "go.sia.tech/renterd/object" ssql "go.sia.tech/renterd/stores/sql" "go.sia.tech/renterd/webhooks" - "go.sia.tech/siad/modules" "lukechampine.com/frand" "go.sia.tech/renterd/internal/sql" @@ -37,11 +39,12 @@ type ( ) // NewMainDatabase creates a new MySQL backend. -func NewMainDatabase(db *dsql.DB, log *zap.SugaredLogger, lqd, ltd time.Duration) (*MainDatabase, error) { - store, err := sql.NewDB(db, log.Desugar(), deadlockMsgs, lqd, ltd) +func NewMainDatabase(db *dsql.DB, log *zap.Logger, lqd, ltd time.Duration) (*MainDatabase, error) { + log = log.Named("main") + store, err := sql.NewDB(db, log, deadlockMsgs, lqd, ltd) return &MainDatabase{ db: store, - log: log, + log: log.Sugar(), }, err } @@ -88,6 +91,10 @@ func (b *MainDatabase) wrapTxn(tx sql.Tx) *MainDatabaseTx { return &MainDatabaseTx{tx, b.log.Named(hex.EncodeToString(frand.Bytes(16)))} } +func (tx *MainDatabaseTx) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { + return ssql.AbortMultipartUpload(ctx, tx, bucket, path, uploadID) +} + func (tx *MainDatabaseTx) Accounts(ctx context.Context) ([]api.Account, error) { return ssql.Accounts(ctx, tx) } @@ -136,8 +143,16 @@ func (tx *MainDatabaseTx) AddMultipartPart(ctx context.Context, bucket, path, co return tx.insertSlabs(ctx, nil, &partID, contractSet, slices) } -func (tx *MainDatabaseTx) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { - return ssql.AbortMultipartUpload(ctx, tx, bucket, path, uploadID) +func (tx *MainDatabaseTx) AddPeer(ctx context.Context, addr string) error { + _, err := tx.Exec(ctx, + "INSERT IGNORE INTO syncer_peers (address, first_seen, last_connect, synced_blocks, sync_duration) VALUES (?, ?, ?, ?, ?)", + addr, + ssql.UnixTimeMS(time.Now()), + ssql.UnixTimeMS(time.Time{}), + 0, + 0, + ) + return err } func (tx *MainDatabaseTx) AddWebhook(ctx context.Context, wh webhooks.Webhook) error { @@ -173,10 +188,30 @@ func (tx *MainDatabaseTx) Autopilots(ctx context.Context) ([]api.Autopilot, erro return ssql.Autopilots(ctx, tx) } +func (tx *MainDatabaseTx) BanPeer(ctx context.Context, addr string, duration time.Duration, reason string) error { + cidr, err := ssql.NormalizePeer(addr) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, + "INSERT INTO syncer_bans (created_at, net_cidr, expiration, reason) VALUES (?, ?, ?, ?) 
ON DUPLICATE KEY UPDATE expiration = VALUES(expiration), reason = VALUES(reason)", + time.Now(), + cidr, + ssql.UnixTimeMS(time.Now().Add(duration)), + reason, + ) + return err +} + func (tx *MainDatabaseTx) Bucket(ctx context.Context, bucket string) (api.Bucket, error) { return ssql.Bucket(ctx, tx, bucket) } +func (tx *MainDatabaseTx) CharLengthExpr() string { + return "CHAR_LENGTH" +} + func (tx *MainDatabaseTx) CompleteMultipartUpload(ctx context.Context, bucket, key, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (string, error) { mpu, neededParts, size, eTag, err := ssql.MultipartUploadForCompletion(ctx, tx, bucket, key, uploadID, parts) if err != nil { @@ -240,6 +275,10 @@ func (tx *MainDatabaseTx) CompleteMultipartUpload(ctx context.Context, bucket, k return eTag, nil } +func (tx *MainDatabaseTx) Contract(ctx context.Context, fcid types.FileContractID) (api.ContractMetadata, error) { + return ssql.Contract(ctx, tx, fcid) +} + func (tx *MainDatabaseTx) ContractRoots(ctx context.Context, fcid types.FileContractID) ([]types.Hash256, error) { return ssql.ContractRoots(ctx, tx, fcid) } @@ -248,6 +287,10 @@ func (tx *MainDatabaseTx) Contracts(ctx context.Context, opts api.ContractsOpts) return ssql.Contracts(ctx, tx, opts) } +func (tx *MainDatabaseTx) ContractSetID(ctx context.Context, contractSet string) (int64, error) { + return ssql.ContractSetID(ctx, tx, contractSet) +} + func (tx *MainDatabaseTx) ContractSets(ctx context.Context) ([]string, error) { return ssql.ContractSets(ctx, tx) } @@ -289,6 +332,10 @@ func (tx *MainDatabaseTx) InsertBufferedSlab(ctx context.Context, fileName strin return ssql.InsertBufferedSlab(ctx, tx, fileName, contractSetID, ec, minShards, totalShards) } +func (tx *MainDatabaseTx) InsertContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { + return ssql.InsertContract(ctx, tx, rev, contractPrice, totalCost, startHeight, renewedFrom, state) +} + func (tx *MainDatabaseTx) InsertMultipartUpload(ctx context.Context, bucket, key string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (string, error) { return ssql.InsertMultipartUpload(ctx, tx, bucket, key, ec, mimeType, metadata) } @@ -360,10 +407,6 @@ func (tx *MainDatabaseTx) HostsForScanning(ctx context.Context, maxLastScan time return ssql.HostsForScanning(ctx, tx, maxLastScan, offset, limit) } -func (tx *MainDatabaseTx) InitConsensusInfo(ctx context.Context) (types.ChainIndex, modules.ConsensusChangeID, error) { - return ssql.InitConsensusInfo(ctx, tx) -} - func (tx *MainDatabaseTx) InsertObject(ctx context.Context, bucket, key, contractSet string, dirID int64, o object.Object, mimeType, eTag string, md api.ObjectUserMetadata) error { // get bucket id var bucketID int64 @@ -375,11 +418,7 @@ func (tx *MainDatabaseTx) InsertObject(ctx context.Context, bucket, key, contrac } // insert object - objKey, err := o.Key.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal object key: %w", err) - } - objID, err := ssql.InsertObject(ctx, tx, key, dirID, bucketID, o.TotalSize(), objKey, mimeType, eTag) + objID, err := ssql.InsertObject(ctx, tx, key, dirID, bucketID, o.TotalSize(), o.Key, mimeType, eTag) if err != nil { return fmt.Errorf("failed to insert object: %w", err) } @@ -471,6 +510,10 @@ func (tx *MainDatabaseTx) MakeDirsForPath(ctx context.Context, path string) (int return dirID, nil } 
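The `ON DUPLICATE KEY UPDATE id = last_insert_id(id)` clause in `UpdateHost` above is the standard MySQL trick for getting a row's id in a single round trip whether the INSERT created the row or collided with an existing one. A minimal, self-contained sketch of the idiom, assuming a hypothetical `names` table with an auto-increment `id` and a unique `name` column:

package main

import (
	"context"
	"database/sql"

	_ "github.com/go-sql-driver/mysql"
)

// upsertNameID inserts into the hypothetical `names` table. On a duplicate
// key, the no-op update routes the existing row's id through
// last_insert_id(id), so res.LastInsertId() reports the row's id in both the
// insert and the conflict case.
func upsertNameID(ctx context.Context, db *sql.DB, name string) (int64, error) {
	res, err := db.ExecContext(ctx,
		"INSERT INTO names (name) VALUES (?) ON DUPLICATE KEY UPDATE id = last_insert_id(id)",
		name,
	)
	if err != nil {
		return 0, err
	}
	return res.LastInsertId()
}

Called as `id, err := upsertNameID(ctx, db, "foo")`, this avoids a separate SELECT to fetch the id after the upsert.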
+func (tx *MainDatabaseTx) MarkPackedSlabUploaded(ctx context.Context, slab api.UploadedPackedSlab) (string, error) { + return ssql.MarkPackedSlabUploaded(ctx, tx, slab) +} + func (tx *MainDatabaseTx) MultipartUpload(ctx context.Context, uploadID string) (api.MultipartUpload, error) { return ssql.MultipartUpload(ctx, tx, uploadID) } @@ -483,10 +526,46 @@ func (tx *MainDatabaseTx) MultipartUploads(ctx context.Context, bucket, prefix, return ssql.MultipartUploads(ctx, tx, bucket, prefix, keyMarker, uploadIDMarker, limit) } +func (tx *MainDatabaseTx) Object(ctx context.Context, bucket, key string) (api.Object, error) { + return ssql.Object(ctx, tx, bucket, key) +} + +func (tx *MainDatabaseTx) ObjectEntries(ctx context.Context, bucket, path, prefix, sortBy, sortDir, marker string, offset, limit int) ([]api.ObjectMetadata, bool, error) { + return ssql.ObjectEntries(ctx, tx, bucket, path, prefix, sortBy, sortDir, marker, offset, limit) +} + +func (tx *MainDatabaseTx) ObjectMetadata(ctx context.Context, bucket, path string) (api.Object, error) { + return ssql.ObjectMetadata(ctx, tx, bucket, path) +} + +func (tx *MainDatabaseTx) ObjectsBySlabKey(ctx context.Context, bucket string, slabKey object.EncryptionKey) (metadata []api.ObjectMetadata, err error) { + return ssql.ObjectsBySlabKey(ctx, tx, bucket, slabKey) +} + func (tx *MainDatabaseTx) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) (api.ObjectsStatsResponse, error) { return ssql.ObjectsStats(ctx, tx, opts) } +func (tx *MainDatabaseTx) PeerBanned(ctx context.Context, addr string) (bool, error) { + return ssql.PeerBanned(ctx, tx, addr) +} + +func (tx *MainDatabaseTx) PeerInfo(ctx context.Context, addr string) (syncer.PeerInfo, error) { + return ssql.PeerInfo(ctx, tx, addr) +} + +func (tx *MainDatabaseTx) Peers(ctx context.Context) ([]syncer.PeerInfo, error) { + return ssql.Peers(ctx, tx) +} + +func (tx *MainDatabaseTx) ProcessChainUpdate(ctx context.Context, fn func(ssql.ChainUpdateTx) error) error { + return fn(&chainUpdateTx{ + ctx: ctx, + tx: tx, + l: tx.log.Named("ProcessChainUpdate"), + }) +} + func (tx *MainDatabaseTx) PruneEmptydirs(ctx context.Context) error { stmt, err := tx.Prepare(ctx, ` DELETE @@ -532,6 +611,10 @@ func (tx *MainDatabaseTx) PruneSlabs(ctx context.Context, limit int64) (int64, e return res.RowsAffected() } +func (tx *MainDatabaseTx) RecordContractSpending(ctx context.Context, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error { + return ssql.RecordContractSpending(ctx, tx, fcid, revisionNumber, size, newSpending) +} + func (tx *MainDatabaseTx) RecordHostScans(ctx context.Context, scans []api.HostScan) error { return ssql.RecordHostScans(ctx, tx, scans) } @@ -617,12 +700,16 @@ func (tx *MainDatabaseTx) RenameObjects(ctx context.Context, bucket, prefixOld, return nil } +func (tx *MainDatabaseTx) RenewContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { + return ssql.RenewContract(ctx, tx, rev, contractPrice, totalCost, startHeight, renewedFrom, state) +} + func (tx *MainDatabaseTx) RenewedContract(ctx context.Context, renwedFrom types.FileContractID) (api.ContractMetadata, error) { return ssql.RenewedContract(ctx, tx, renwedFrom) } -func (tx *MainDatabaseTx) ResetConsensusSubscription(ctx context.Context) (types.ChainIndex, error) { - return ssql.ResetConsensusSubscription(ctx, tx) +func (tx *MainDatabaseTx) 
ResetChainState(ctx context.Context) error { + return ssql.ResetChainState(ctx, tx.Tx) } func (tx *MainDatabaseTx) ResetLostSectors(ctx context.Context, hk types.PublicKey) error { @@ -660,10 +747,78 @@ func (tx MainDatabaseTx) SaveAccounts(ctx context.Context, accounts []api.Accoun return nil } +func (tx *MainDatabaseTx) ScanObjectMetadata(s ssql.Scanner, others ...any) (md api.ObjectMetadata, err error) { + dst := []any{&md.Name, &md.Size, &md.Health, &md.MimeType, &md.ModTime, &md.ETag} + dst = append(dst, others...) + if err := s.Scan(dst...); err != nil { + return api.ObjectMetadata{}, fmt.Errorf("failed to scan object metadata: %w", err) + } + return md, nil +} + func (tx *MainDatabaseTx) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { return ssql.SearchHosts(ctx, tx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) } +func (tx *MainDatabaseTx) SearchObjects(ctx context.Context, bucket, substring string, offset, limit int) ([]api.ObjectMetadata, error) { + return ssql.SearchObjects(ctx, tx, bucket, substring, offset, limit) +} + +func (tx *MainDatabaseTx) SelectObjectMetadataExpr() string { + return "o.object_id, o.size, o.health, o.mime_type, o.created_at, o.etag" +} + +func (tx *MainDatabaseTx) SetContractSet(ctx context.Context, name string, contractIds []types.FileContractID) error { + res, err := tx.Exec(ctx, "INSERT INTO contract_sets (name) VALUES (?) ON DUPLICATE KEY UPDATE id = last_insert_id(id)", name) + if err != nil { + return fmt.Errorf("failed to insert contract set: %w", err) + } + + csID, err := res.LastInsertId() + if err != nil { + return fmt.Errorf("failed to fetch contract set id: %w", err) + } + + // handle empty set + if len(contractIds) == 0 { + _, err := tx.Exec(ctx, "DELETE FROM contract_set_contracts WHERE db_contract_set_id = ?", csID) + return err + } + + // prepare fcid args and query + fcidQuery := strings.Repeat("?, ", len(contractIds)-1) + "?" + fcidArgs := make([]interface{}, len(contractIds)) + for i, fcid := range contractIds { + fcidArgs[i] = ssql.FileContractID(fcid) + } + + // remove unwanted contracts + _, err = tx.Exec(ctx, fmt.Sprintf(` + DELETE csc + FROM contract_set_contracts csc + INNER JOIN contracts c ON c.id = csc.db_contract_id + WHERE c.fcid NOT IN (%s) + `, fcidQuery), fcidArgs...) + if err != nil { + return fmt.Errorf("failed to delete contract set contracts: %w", err) + } + + // add missing contracts + args := []interface{}{csID} + args = append(args, fcidArgs...) + _, err = tx.Exec(ctx, fmt.Sprintf(` + INSERT INTO contract_set_contracts (db_contract_set_id, db_contract_id) + SELECT ?, c.id + FROM contracts c + WHERE c.fcid IN (%s) + ON DUPLICATE KEY UPDATE db_contract_set_id = VALUES(db_contract_set_id) + `, fcidQuery), args...) 
+ if err != nil { + return fmt.Errorf("failed to add contract set contracts: %w", err) + } + return nil +} + func (tx *MainDatabaseTx) Setting(ctx context.Context, key string) (string, error) { return ssql.Setting(ctx, tx, key) } @@ -676,26 +831,35 @@ func (tx *MainDatabaseTx) SetUncleanShutdown(ctx context.Context) error { return ssql.SetUncleanShutdown(ctx, tx) } +func (tx *MainDatabaseTx) Slab(ctx context.Context, key object.EncryptionKey) (object.Slab, error) { + return ssql.Slab(ctx, tx, key) +} + func (tx *MainDatabaseTx) SlabBuffers(ctx context.Context) (map[string]string, error) { return ssql.SlabBuffers(ctx, tx) } +func (tx *MainDatabaseTx) Tip(ctx context.Context) (types.ChainIndex, error) { + return ssql.Tip(ctx, tx.Tx) +} + +func (tx *MainDatabaseTx) UnhealthySlabs(ctx context.Context, healthCutoff float64, set string, limit int) ([]api.UnhealthySlab, error) { + return ssql.UnhealthySlabs(ctx, tx, healthCutoff, set, limit) +} + +func (tx *MainDatabaseTx) UnspentSiacoinElements(ctx context.Context) (elements []types.SiacoinElement, err error) { + return ssql.UnspentSiacoinElements(ctx, tx.Tx) +} + func (tx *MainDatabaseTx) UpdateAutopilot(ctx context.Context, ap api.Autopilot) error { - res, err := tx.Exec(ctx, ` + _, err := tx.Exec(ctx, ` INSERT INTO autopilots (created_at, identifier, config, current_period) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE config = VALUES(config), current_period = VALUES(current_period) `, time.Now(), ap.ID, (*ssql.AutopilotConfig)(&ap.Config), ap.CurrentPeriod) - if err != nil { - return err - } else if n, err := res.RowsAffected(); err != nil { - return err - } else if n != 1 && n != 2 { // 1 if inserted, 2 if updated - return fmt.Errorf("expected 1 row affected, got %v", n) - } - return nil + return err } func (tx *MainDatabaseTx) UpdateBucketPolicy(ctx context.Context, bucket string, bp api.BucketPolicy) error { @@ -839,6 +1003,10 @@ func (tx *MainDatabaseTx) UpdateHostCheck(ctx context.Context, autopilot string, return nil } +func (tx *MainDatabaseTx) UpdatePeerInfo(ctx context.Context, addr string, fn func(*syncer.PeerInfo)) error { + return ssql.UpdatePeerInfo(ctx, tx, addr, fn) +} + func (tx *MainDatabaseTx) UpdateSetting(ctx context.Context, key, value string) error { _, err := tx.Exec(ctx, "INSERT INTO settings (created_at, `key`, value) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE value = VALUES(value)", time.Now(), key, value) @@ -855,12 +1023,6 @@ func (tx *MainDatabaseTx) UpdateSlab(ctx context.Context, s object.Slab, contrac return fmt.Errorf("failed to fetch used contracts: %w", err) } - // extract the slab key - key, err := s.Key.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal slab key: %w", err) - } - // update slab res, err := tx.Exec(ctx, ` UPDATE slabs @@ -868,18 +1030,18 @@ func (tx *MainDatabaseTx) UpdateSlab(ctx context.Context, s object.Slab, contrac health_valid_until = ?, health = ? WHERE `+"`key`"+` = ? - `, contractSet, time.Now().Unix(), 1, ssql.SecretKey(key)) + `, contractSet, time.Now().Unix(), 1, ssql.EncryptionKey(s.Key)) if err != nil { return err } else if n, err := res.RowsAffected(); err != nil { return err } else if n == 0 { - return fmt.Errorf("%w: slab with key '%s' not found: %w", api.ErrSlabNotFound, string(key), err) + return api.ErrSlabNotFound } // fetch slab id and total shards var slabID, totalShards int64 - err = tx.QueryRow(ctx, "SELECT id, total_shards FROM slabs WHERE `key` = ?", ssql.SecretKey(key)). 
+ err = tx.QueryRow(ctx, "SELECT id, total_shards FROM slabs WHERE `key` = ?", ssql.EncryptionKey(s.Key)). Scan(&slabID, &totalShards) if err != nil { return err @@ -997,6 +1159,14 @@ func (tx *MainDatabaseTx) UpdateSlabHealth(ctx context.Context, limit int64, min return res.RowsAffected() } +func (tx *MainDatabaseTx) WalletEvents(ctx context.Context, offset, limit int) ([]wallet.Event, error) { + return ssql.WalletEvents(ctx, tx.Tx, offset, limit) +} + +func (tx *MainDatabaseTx) WalletEventCount(ctx context.Context) (count uint64, err error) { + return ssql.WalletEventCount(ctx, tx.Tx) +} + func (tx *MainDatabaseTx) Webhooks(ctx context.Context) ([]webhooks.Webhook, error) { return ssql.Webhooks(ctx, tx) } @@ -1037,14 +1207,10 @@ func (tx *MainDatabaseTx) insertSlabs(ctx context.Context, objID, partID *int64, slabIDs := make([]int64, len(slices)) for i := range slices { - slabKey, err := slices[i].Key.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal slab key: %w", err) - } res, err := insertSlabStmt.Exec(ctx, time.Now(), contractSetID, - ssql.SecretKey(slabKey), + ssql.EncryptionKey(slices[i].Key), slices[i].MinShards, uint8(len(slices[i].Shards)), ) diff --git a/stores/sql/mysql/metrics.go b/stores/sql/mysql/metrics.go index dd51f228e..e7ef23813 100644 --- a/stores/sql/mysql/metrics.go +++ b/stores/sql/mysql/metrics.go @@ -30,11 +30,12 @@ type ( var _ ssql.MetricsDatabaseTx = (*MetricsDatabaseTx)(nil) // NewMetricsDatabase creates a new MySQL backend. -func NewMetricsDatabase(db *dsql.DB, log *zap.SugaredLogger, lqd, ltd time.Duration) (*MetricsDatabase, error) { - store, err := sql.NewDB(db, log.Desugar(), deadlockMsgs, lqd, ltd) +func NewMetricsDatabase(db *dsql.DB, log *zap.Logger, lqd, ltd time.Duration) (*MetricsDatabase, error) { + log = log.Named("metrics") + store, err := sql.NewDB(db, log, deadlockMsgs, lqd, ltd) return &MetricsDatabase{ db: store, - log: log, + log: log.Sugar(), }, err } diff --git a/stores/sql/mysql/migrations/main/migration_00012_peer_store.sql b/stores/sql/mysql/migrations/main/migration_00012_peer_store.sql new file mode 100644 index 000000000..02223995a --- /dev/null +++ b/stores/sql/mysql/migrations/main/migration_00012_peer_store.sql @@ -0,0 +1,24 @@ +-- dbSyncerPeer +CREATE TABLE `syncer_peers` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `address` varchar(191) NOT NULL, + `first_seen` bigint NOT NULL, + `last_connect` bigint, + `synced_blocks` bigint, + `sync_duration` bigint, + PRIMARY KEY (`id`), + UNIQUE KEY `idx_syncer_peers_address` (`address`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + +-- dbSyncerBan +CREATE TABLE `syncer_bans` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `net_cidr` varchar(191) NOT NULL, + `reason` longtext, + `expiration` bigint NOT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `idx_syncer_bans_net_cidr` (`net_cidr`), + KEY `idx_syncer_bans_expiration` (`expiration`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; diff --git a/stores/sql/mysql/migrations/main/migration_00013_coreutils_wallet.sql b/stores/sql/mysql/migrations/main/migration_00013_coreutils_wallet.sql new file mode 100644 index 000000000..8adf1e717 --- /dev/null +++ b/stores/sql/mysql/migrations/main/migration_00013_coreutils_wallet.sql @@ -0,0 +1,42 @@ +-- drop tables +DROP TABLE IF EXISTS `siacoin_elements`; +DROP TABLE IF EXISTS `transactions`; + +-- drop column +ALTER TABLE `consensus_infos` DROP 
COLUMN `cc_id`; + +-- dbWalletEvent +CREATE TABLE `wallet_events` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `event_id` varbinary(32) NOT NULL, + `height` bigint unsigned DEFAULT NULL, + `block_id` varbinary(32) NOT NULL, + `inflow` longtext, + `outflow` longtext, + `type` varchar(191) NOT NULL, + `data` longblob NOT NULL, + `maturity_height` bigint unsigned DEFAULT NULL, + `timestamp` bigint DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `event_id` (`event_id`), + KEY `idx_wallet_events_maturity_height` (`maturity_height`), + KEY `idx_wallet_events_type` (`type`), + KEY `idx_wallet_events_timestamp` (`timestamp`), + KEY `idx_wallet_events_block_id_height` (`block_id`, `height`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + +-- dbWalletOutput +CREATE TABLE `wallet_outputs` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `output_id` varbinary(32) NOT NULL, + `leaf_index` bigint, + `merkle_proof` longblob NOT NULL, + `value` longtext, + `address` varbinary(32) DEFAULT NULL, + `maturity_height` bigint unsigned DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `output_id` (`output_id`), + KEY `idx_wallet_outputs_maturity_height` (`maturity_height`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; diff --git a/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql b/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql new file mode 100644 index 000000000..74a5cebe6 --- /dev/null +++ b/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql @@ -0,0 +1,2 @@ +ALTER TABLE hosts DROP COLUMN subnets; +ALTER TABLE hosts ADD resolved_addresses varchar(255) NOT NULL DEFAULT ''; diff --git a/stores/sql/mysql/migrations/main/migration_00015_reset_drift.sql b/stores/sql/mysql/migrations/main/migration_00015_reset_drift.sql new file mode 100644 index 000000000..c151d90a3 --- /dev/null +++ b/stores/sql/mysql/migrations/main/migration_00015_reset_drift.sql @@ -0,0 +1 @@ +UPDATE ephemeral_accounts SET drift = "0", clean_shutdown = 0, requires_sync = 1; \ No newline at end of file diff --git a/stores/sql/mysql/migrations/main/schema.sql b/stores/sql/mysql/migrations/main/schema.sql index 145fa9452..51a5c5629 100644 --- a/stores/sql/mysql/migrations/main/schema.sql +++ b/stores/sql/mysql/migrations/main/schema.sql @@ -70,7 +70,6 @@ CREATE TABLE `buffered_slabs` ( CREATE TABLE `consensus_infos` ( `id` bigint unsigned NOT NULL AUTO_INCREMENT, `created_at` datetime(3) DEFAULT NULL, - `cc_id` longblob, `height` bigint unsigned DEFAULT NULL, `block_id` longblob, PRIMARY KEY (`id`) @@ -98,7 +97,7 @@ CREATE TABLE `hosts` ( `lost_sectors` bigint unsigned DEFAULT NULL, `last_announcement` datetime(3) DEFAULT NULL, `net_address` varchar(191) DEFAULT NULL, - `subnets` varchar(255) NOT NULL DEFAULT '', + `resolved_addresses` varchar(255) NOT NULL DEFAULT '', PRIMARY KEY (`id`), UNIQUE KEY `public_key` (`public_key`), KEY `idx_hosts_public_key` (`public_key`), @@ -362,20 +361,6 @@ CREATE TABLE `settings` ( KEY `idx_settings_key` (`key`) ) ENGINE=InnoDB AUTO_INCREMENT=5 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; --- dbSiacoinElement -CREATE TABLE `siacoin_elements` ( - `id` bigint unsigned NOT NULL AUTO_INCREMENT, - `created_at` datetime(3) DEFAULT NULL, - `value` longtext, - `address` varbinary(32) DEFAULT NULL, - `output_id` varbinary(32) NOT NULL, - `maturity_height` bigint unsigned DEFAULT NULL, - PRIMARY KEY 
(`id`), - UNIQUE KEY `output_id` (`output_id`), - KEY `idx_siacoin_elements_output_id` (`output_id`), - KEY `idx_siacoin_elements_maturity_height` (`maturity_height`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; - -- dbSlice CREATE TABLE `slices` ( `id` bigint unsigned NOT NULL AUTO_INCREMENT, @@ -396,23 +381,6 @@ CREATE TABLE `slices` ( CONSTRAINT `fk_slabs_slices` FOREIGN KEY (`db_slab_id`) REFERENCES `slabs` (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; --- dbTransaction -CREATE TABLE `transactions` ( - `id` bigint unsigned NOT NULL AUTO_INCREMENT, - `created_at` datetime(3) DEFAULT NULL, - `raw` longtext, - `height` bigint unsigned DEFAULT NULL, - `block_id` varbinary(32) DEFAULT NULL, - `transaction_id` varbinary(32) NOT NULL, - `inflow` longtext, - `outflow` longtext, - `timestamp` bigint DEFAULT NULL, - PRIMARY KEY (`id`), - UNIQUE KEY `transaction_id` (`transaction_id`), - KEY `idx_transactions_transaction_id` (`transaction_id`), - KEY `idx_transactions_timestamp` (`timestamp`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; - -- dbWebhook CREATE TABLE `webhooks` ( `id` bigint unsigned NOT NULL AUTO_INCREMENT, @@ -492,5 +460,100 @@ CREATE TABLE `host_checks` ( CONSTRAINT `fk_host_checks_host` FOREIGN KEY (`db_host_id`) REFERENCES `hosts` (`id`) ON DELETE CASCADE ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; +-- dbObject trigger to delete from slices +CREATE TRIGGER before_delete_on_objects_delete_slices +BEFORE DELETE +ON objects FOR EACH ROW +DELETE FROM slices +WHERE slices.db_object_id = OLD.id; + +-- dbMultipartUpload trigger to delete from dbMultipartPart +CREATE TRIGGER before_delete_on_multipart_uploads_delete_multipart_parts +BEFORE DELETE +ON multipart_uploads FOR EACH ROW +DELETE FROM multipart_parts +WHERE multipart_parts.db_multipart_upload_id = OLD.id; + +-- dbMultipartPart trigger to delete from slices +CREATE TRIGGER before_delete_on_multipart_parts_delete_slices +BEFORE DELETE +ON multipart_parts FOR EACH ROW +DELETE FROM slices +WHERE slices.db_multipart_part_id = OLD.id; + +-- dbSlices trigger to prune slabs +CREATE TRIGGER after_delete_on_slices_delete_slabs +AFTER DELETE +ON slices FOR EACH ROW +DELETE FROM slabs +WHERE slabs.id = OLD.db_slab_id +AND slabs.db_buffered_slab_id IS NULL +AND NOT EXISTS ( + SELECT 1 + FROM slices + WHERE slices.db_slab_id = OLD.db_slab_id +); + +-- dbSyncerPeer +CREATE TABLE `syncer_peers` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `address` varchar(191) NOT NULL, + `first_seen` bigint NOT NULL, + `last_connect` bigint, + `synced_blocks` bigint, + `sync_duration` bigint, + PRIMARY KEY (`id`), + UNIQUE KEY `idx_syncer_peers_address` (`address`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + +-- dbSyncerBan +CREATE TABLE `syncer_bans` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `net_cidr` varchar(191) NOT NULL, + `reason` longtext, + `expiration` bigint NOT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `idx_syncer_bans_net_cidr` (`net_cidr`), + KEY `idx_syncer_bans_expiration` (`expiration`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + +-- dbWalletEvent +CREATE TABLE `wallet_events` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `event_id` varbinary(32) NOT NULL, + `height` bigint unsigned DEFAULT NULL, + `block_id` varbinary(32) NOT NULL, + `inflow` 
longtext, + `outflow` longtext, + `type` varchar(191) NOT NULL, + `data` longblob NOT NULL, + `maturity_height` bigint unsigned DEFAULT NULL, + `timestamp` bigint DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `event_id` (`event_id`), + KEY `idx_wallet_events_maturity_height` (`maturity_height`), + KEY `idx_wallet_events_type` (`type`), + KEY `idx_wallet_events_timestamp` (`timestamp`), + KEY `idx_wallet_events_block_id_height` (`block_id`, `height`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + +-- dbWalletOutput +CREATE TABLE `wallet_outputs` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `created_at` datetime(3) DEFAULT NULL, + `output_id` varbinary(32) NOT NULL, + `leaf_index` bigint, + `merkle_proof` longblob NOT NULL, + `value` longtext, + `address` varbinary(32) DEFAULT NULL, + `maturity_height` bigint unsigned DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `output_id` (`output_id`), + KEY `idx_wallet_outputs_maturity_height` (`maturity_height`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + -- create default bucket -INSERT INTO buckets (created_at, name) VALUES (CURRENT_TIMESTAMP, 'default'); \ No newline at end of file +INSERT INTO buckets (created_at, name) VALUES (CURRENT_TIMESTAMP, 'default'); diff --git a/stores/sql/mysql/migrations/metrics/migration_00002_idx_wallet_metrics_immature.sql b/stores/sql/mysql/migrations/metrics/migration_00002_idx_wallet_metrics_immature.sql new file mode 100644 index 000000000..edc14f373 --- /dev/null +++ b/stores/sql/mysql/migrations/metrics/migration_00002_idx_wallet_metrics_immature.sql @@ -0,0 +1,2 @@ +ALTER TABLE `wallets` ADD COLUMN `immature_lo` bigint NOT NULL, ADD COLUMN `immature_hi` bigint NOT NULL; +CREATE INDEX `idx_wallets_immature` ON `wallets`(`immature_lo`,`immature_hi`); diff --git a/stores/sql/mysql/migrations/metrics/schema.sql b/stores/sql/mysql/migrations/metrics/schema.sql index da4db5a6e..7c4c27d6c 100644 --- a/stores/sql/mysql/migrations/metrics/schema.sql +++ b/stores/sql/mysql/migrations/metrics/schema.sql @@ -114,9 +114,12 @@ CREATE TABLE `wallets` ( `spendable_hi` bigint NOT NULL, `unconfirmed_lo` bigint NOT NULL, `unconfirmed_hi` bigint NOT NULL, + `immature_lo` bigint NOT NULL, + `immature_hi` bigint NOT NULL, PRIMARY KEY (`id`), KEY `idx_wallets_timestamp` (`timestamp`), KEY `idx_confirmed` (`confirmed_lo`,`confirmed_hi`), KEY `idx_spendable` (`spendable_lo`,`spendable_hi`), - KEY `idx_unconfirmed` (`unconfirmed_lo`,`unconfirmed_hi`) + KEY `idx_unconfirmed` (`unconfirmed_lo`,`unconfirmed_hi`), + KEY `idx_wallets_immature` (`immature_lo`,`immature_hi`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; \ No newline at end of file diff --git a/stores/sql/rows.go b/stores/sql/rows.go index 44260b2ad..20830ee93 100644 --- a/stores/sql/rows.go +++ b/stores/sql/rows.go @@ -6,7 +6,7 @@ import ( "go.sia.tech/renterd/api" ) -type scanner interface { +type Scanner interface { Scan(dest ...any) error } @@ -38,7 +38,7 @@ type ContractRow struct { SiamuxPort string } -func (r *ContractRow) Scan(s scanner) error { +func (r *ContractRow) Scan(s Scanner) error { return s.Scan(&r.FCID, &r.RenewedFrom, &r.ContractPrice, &r.State, &r.TotalCost, &r.ProofHeight, &r.RevisionHeight, &r.RevisionNumber, &r.Size, &r.StartHeight, &r.WindowStart, &r.WindowEnd, &r.UploadSpending, &r.DownloadSpending, &r.FundAccountSpending, &r.DeleteSpending, &r.ListSpending, diff --git a/stores/sql/sqlite/chain.go b/stores/sql/sqlite/chain.go new file mode 100644 index 
000000000..a98b777a4 --- /dev/null +++ b/stores/sql/sqlite/chain.go @@ -0,0 +1,345 @@ +package sqlite + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "strings" + "time" + + dsql "database/sql" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/renterd/api" + isql "go.sia.tech/renterd/internal/sql" + ssql "go.sia.tech/renterd/stores/sql" + "go.uber.org/zap" +) + +var ( + _ ssql.ChainUpdateTx = (*chainUpdateTx)(nil) +) + +type chainUpdateTx struct { + ctx context.Context + tx isql.Tx + l *zap.SugaredLogger +} + +func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent []types.SiacoinElement, events []wallet.Event, timestamp time.Time) error { + c.l.Debugw("applying index", "height", index.Height, "block_id", index.ID) + + if len(spent) > 0 { + // prepare statement to delete spent outputs + deleteSpentStmt, err := c.tx.Prepare(c.ctx, "DELETE FROM wallet_outputs WHERE output_id = ?") + if err != nil { + return fmt.Errorf("failed to prepare statement to delete spent outputs: %w", err) + } + defer deleteSpentStmt.Close() + + // delete spent outputs + for _, e := range spent { + c.l.Debugw(fmt.Sprintf("remove output %v", e.ID), "height", index.Height, "block_id", index.ID) + if res, err := deleteSpentStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil { + return fmt.Errorf("failed to delete spent output: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to delete spent output: no rows affected") + } + } + } + + if len(created) > 0 { + // prepare statement to insert new outputs + insertOutputStmt, err := c.tx.Prepare(c.ctx, "INSERT OR IGNORE INTO wallet_outputs (created_at, output_id, leaf_index, merkle_proof, value, address, maturity_height) VALUES (?, ?, ?, ?, ?, ?, ?)") + if err != nil { + return fmt.Errorf("failed to prepare statement to insert new outputs: %w", err) + } + defer insertOutputStmt.Close() + + // insert new outputs + for _, e := range created { + c.l.Debugw(fmt.Sprintf("create output %v", e.ID), "height", index.Height, "block_id", index.ID) + if _, err := insertOutputStmt.Exec(c.ctx, + time.Now().UTC(), + ssql.Hash256(e.ID), + e.StateElement.LeafIndex, + ssql.MerkleProof{Hashes: e.StateElement.MerkleProof}, + ssql.Currency(e.SiacoinOutput.Value), + ssql.Hash256(e.SiacoinOutput.Address), + e.MaturityHeight, + ); err != nil { + return fmt.Errorf("failed to insert new output: %w", err) + } + } + } + + if len(events) > 0 { + // prepare statement to insert new events + insertEventStmt, err := c.tx.Prepare(c.ctx, `INSERT OR IGNORE INTO wallet_events (created_at, height, block_id, event_id, inflow, outflow, type, data, maturity_height, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if err != nil { + return fmt.Errorf("failed to prepare statement to insert new events: %w", err) + } + defer insertEventStmt.Close() + + // insert new events + for _, e := range events { + c.l.Debugw(fmt.Sprintf("create event %v", e.ID), "height", index.Height, "block_id", index.ID) + data, err := json.Marshal(e.Data) + if err != nil { + c.l.Error(err) + return err + } + if _, err := insertEventStmt.Exec(c.ctx, + time.Now().UTC(), + e.Index.Height, + ssql.Hash256(e.Index.ID), + ssql.Hash256(e.ID), + ssql.Currency(e.SiacoinInflow()), + ssql.Currency(e.SiacoinOutflow()), + e.Type, + data, + e.MaturityHeight, + ssql.UnixTimeNS(e.Timestamp), + ); err != nil { + return fmt.Errorf("failed 
to insert new event: %w", err) + } + } + } + return nil +} + +func (c chainUpdateTx) ContractState(fcid types.FileContractID) (api.ContractState, error) { + return ssql.GetContractState(c.ctx, c.tx, fcid) +} + +func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspent []types.SiacoinElement, timestamp time.Time) error { + c.l.Debugw("reverting index", "height", index.Height, "block_id", index.ID) + + if len(removed) > 0 { + // prepare statement to delete removed outputs + deleteRemovedStmt, err := c.tx.Prepare(c.ctx, "DELETE FROM wallet_outputs WHERE output_id = ?") + if err != nil { + return fmt.Errorf("failed to prepare statement to delete removed outputs: %w", err) + } + defer deleteRemovedStmt.Close() + + // delete removed outputs + for _, e := range removed { + c.l.Debugw(fmt.Sprintf("remove output %v", e.ID), "height", index.Height, "block_id", index.ID) + if res, err := deleteRemovedStmt.Exec(c.ctx, e.ID); err != nil { + return fmt.Errorf("failed to delete removed output: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to delete removed output: no rows affected") + } + } + } + + if len(unspent) > 0 { + // prepare statement to insert unspent outputs + insertOutputStmt, err := c.tx.Prepare(c.ctx, "INSERT OR IGNORE INTO wallet_outputs (created_at, output_id, leaf_index, merkle_proof, value, address, maturity_height) VALUES (?, ?, ?, ?, ?, ?, ?)") + if err != nil { + return fmt.Errorf("failed to prepare statement to insert unspent outputs: %w", err) + } + defer insertOutputStmt.Close() + + // insert unspent outputs + for _, e := range unspent { + c.l.Debugw(fmt.Sprintf("recreate unspent output %v", e.ID), "height", index.Height, "block_id", index.ID) + if _, err := insertOutputStmt.Exec(c.ctx, + time.Now().UTC(), + e.ID, + e.StateElement.LeafIndex, + ssql.MerkleProof{Hashes: e.StateElement.MerkleProof}, + ssql.Currency(e.SiacoinOutput.Value), + ssql.Hash256(e.SiacoinOutput.Address), + e.MaturityHeight, + ); err != nil { + return fmt.Errorf("failed to insert unspent output: %w", err) + } + } + } + + // remove events created at the reverted index + res, err := c.tx.Exec(c.ctx, "DELETE FROM wallet_events WHERE height = ? 
AND block_id = ?", index.Height, ssql.Hash256(index.ID)) + if err != nil { + return fmt.Errorf("failed to delete events: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n > 0 { + c.l.Debugw(fmt.Sprintf("removed %d events", n), "height", index.Height, "block_id", index.ID) + } + return nil +} + +func (c chainUpdateTx) UpdateChainIndex(index types.ChainIndex) error { + return ssql.UpdateChainIndex(c.ctx, c.tx, index, c.l) +} + +func (c chainUpdateTx) UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error { + return ssql.UpdateContract(c.ctx, c.tx, fcid, revisionHeight, revisionNumber, size, c.l) +} + +func (c chainUpdateTx) UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error { + return ssql.UpdateContractProofHeight(c.ctx, c.tx, fcid, proofHeight, c.l) +} + +func (c chainUpdateTx) UpdateContractState(fcid types.FileContractID, state api.ContractState) error { + return ssql.UpdateContractState(c.ctx, c.tx, fcid, state, c.l) +} + +func (c chainUpdateTx) UpdateFailedContracts(blockHeight uint64) error { + return ssql.UpdateFailedContracts(c.ctx, c.tx, blockHeight, c.l) +} + +func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error { // + c.l.Debugw("update host", "hk", hk, "netaddress", ha.NetAddress) + + // create the announcement + if _, err := c.tx.Exec(c.ctx, + "INSERT OR IGNORE INTO host_announcements (created_at,host_key, block_height, block_id, net_address) VALUES (?, ?, ?, ?, ?)", + time.Now().UTC(), + ssql.PublicKey(hk), + bh, + blockID.String(), + ha.NetAddress, + ); err != nil { + return fmt.Errorf("failed to insert host announcement: %w", err) + } + + // create the host + var hostID int64 + if err := c.tx.QueryRow(c.ctx, ` + INSERT INTO hosts (created_at, public_key, settings, price_table, total_scans, last_scan, last_scan_success, second_to_last_scan_success, scanned, uptime, downtime, recent_downtime, recent_scan_failures, successful_interactions, failed_interactions, lost_sectors, last_announcement, net_address) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(public_key) DO UPDATE SET + last_announcement = EXCLUDED.last_announcement, + net_address = EXCLUDED.net_address + RETURNING id`, + time.Now().UTC(), + ssql.PublicKey(hk), + ssql.HostSettings{}, + ssql.PriceTable{}, + 0, + 0, + false, + false, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ts.UTC(), + ha.NetAddress, + ).Scan(&hostID); err != nil { + if errors.Is(err, dsql.ErrNoRows) { + err = c.tx.QueryRow(c.ctx, + "UPDATE hosts SET last_announcement = ?, net_address = ? WHERE public_key = ? 
RETURNING id", + ts.UTC(), + ha.NetAddress, + ssql.PublicKey(hk), + ).Scan(&hostID) + if err != nil { + return fmt.Errorf("failed to fetch host id after conflict: %w", err) + } + } else { + return fmt.Errorf("failed to insert host: %w", err) + } + } + + // update allow list + rows, err := c.tx.Query(c.ctx, "SELECT id, entry FROM host_allowlist_entries") + if err != nil { + return fmt.Errorf("failed to fetch allow list: %w", err) + } + defer rows.Close() + for rows.Next() { + var id int64 + var pk ssql.PublicKey + if err := rows.Scan(&id, &pk); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + if hk == types.PublicKey(pk) { + if _, err := c.tx.Exec(c.ctx, + "INSERT OR IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) VALUES (?,?)", + id, + hostID, + ); err != nil { + return fmt.Errorf("failed to insert host into allowlist: %w", err) + } + } + } + + // update blocklist + values := []string{ha.NetAddress} + host, _, err := net.SplitHostPort(ha.NetAddress) + if err == nil { + values = append(values, host) + } + + rows, err = c.tx.Query(c.ctx, "SELECT id, entry FROM host_blocklist_entries") + if err != nil { + return fmt.Errorf("failed to fetch block list: %w", err) + } + defer rows.Close() + + type row struct { + id int64 + entry string + } + var entries []row + for rows.Next() { + var r row + if err := rows.Scan(&r.id, &r.entry); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + entries = append(entries, r) + } + + for _, row := range entries { + var blocked bool + for _, value := range values { + if value == row.entry || strings.HasSuffix(value, "."+row.entry) { + blocked = true + break + } + } + if blocked { + if _, err := c.tx.Exec(c.ctx, + "INSERT OR IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) VALUES (?,?)", + row.id, + hostID, + ); err != nil { + return fmt.Errorf("failed to insert host into blocklist: %w", err) + } + } else { + if _, err := c.tx.Exec(c.ctx, + "DELETE FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = ? AND db_host_id = ?", + row.id, + hostID, + ); err != nil { + return fmt.Errorf("failed to remove host from blocklist: %w", err) + } + } + } + + return nil +} + +func (c chainUpdateTx) UpdateWalletStateElements(elements []types.StateElement) error { + return ssql.UpdateWalletStateElements(c.ctx, c.tx, elements) +} + +func (c chainUpdateTx) WalletStateElements() ([]types.StateElement, error) { + return ssql.WalletStateElements(c.ctx, c.tx) +} diff --git a/stores/sql/sqlite/common.go b/stores/sql/sqlite/common.go index c45c0eab7..fd46688b8 100644 --- a/stores/sql/sqlite/common.go +++ b/stores/sql/sqlite/common.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + _ "github.com/mattn/go-sqlite3" "go.sia.tech/renterd/internal/sql" "go.uber.org/zap" ) diff --git a/stores/sql/sqlite/main.go b/stores/sql/sqlite/main.go index cf6c415e6..b72ec5e8c 100644 --- a/stores/sql/sqlite/main.go +++ b/stores/sql/sqlite/main.go @@ -11,13 +11,15 @@ import ( "time" "unicode/utf8" + rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/sql" "go.sia.tech/renterd/object" ssql "go.sia.tech/renterd/stores/sql" "go.sia.tech/renterd/webhooks" - "go.sia.tech/siad/modules" "lukechampine.com/frand" "go.uber.org/zap" @@ -36,11 +38,12 @@ type ( ) // NewMainDatabase creates a new SQLite backend. 
-func NewMainDatabase(db *dsql.DB, log *zap.SugaredLogger, lqd, ltd time.Duration) (*MainDatabase, error) { - store, err := sql.NewDB(db, log.Desugar(), deadlockMsgs, lqd, ltd) +func NewMainDatabase(db *dsql.DB, log *zap.Logger, lqd, ltd time.Duration) (*MainDatabase, error) { + log = log.Named("main") + store, err := sql.NewDB(db, log, deadlockMsgs, lqd, ltd) return &MainDatabase{ db: store, - log: log, + log: log.Sugar(), }, err } @@ -91,6 +94,10 @@ func (tx *MainDatabaseTx) Accounts(ctx context.Context) ([]api.Account, error) { return ssql.Accounts(ctx, tx) } +func (tx *MainDatabaseTx) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { + return ssql.AbortMultipartUpload(ctx, tx, bucket, path, uploadID) +} + func (tx *MainDatabaseTx) AddMultipartPart(ctx context.Context, bucket, path, contractSet, eTag, uploadID string, partNumber int, slices object.SlabSlices) error { // fetch contract set var csID int64 @@ -135,8 +142,16 @@ func (tx *MainDatabaseTx) AddMultipartPart(ctx context.Context, bucket, path, co return tx.insertSlabs(ctx, nil, &partID, contractSet, slices) } -func (tx *MainDatabaseTx) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { - return ssql.AbortMultipartUpload(ctx, tx, bucket, path, uploadID) +func (tx *MainDatabaseTx) AddPeer(ctx context.Context, addr string) error { + _, err := tx.Exec(ctx, + "INSERT OR IGNORE INTO syncer_peers (address, first_seen, last_connect, synced_blocks, sync_duration) VALUES (?, ?, ?, ?, ?)", + addr, + ssql.UnixTimeMS(time.Now()), + ssql.UnixTimeMS(time.Time{}), + 0, + 0, + ) + return err } func (tx *MainDatabaseTx) AddWebhook(ctx context.Context, wh webhooks.Webhook) error { @@ -172,10 +187,30 @@ func (tx *MainDatabaseTx) Autopilots(ctx context.Context) ([]api.Autopilot, erro return ssql.Autopilots(ctx, tx) } +func (tx *MainDatabaseTx) BanPeer(ctx context.Context, addr string, duration time.Duration, reason string) error { + cidr, err := ssql.NormalizePeer(addr) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, + "INSERT INTO syncer_bans (created_at, net_cidr, expiration, reason) VALUES (?, ?, ?, ?) 
ON CONFLICT DO UPDATE SET expiration = EXCLUDED.expiration, reason = EXCLUDED.reason", + time.Now(), + cidr, + ssql.UnixTimeMS(time.Now().Add(duration)), + reason, + ) + return err +} + func (tx *MainDatabaseTx) Bucket(ctx context.Context, bucket string) (api.Bucket, error) { return ssql.Bucket(ctx, tx, bucket) } +func (tx *MainDatabaseTx) CharLengthExpr() string { + return "LENGTH" +} + func (tx *MainDatabaseTx) CompleteMultipartUpload(ctx context.Context, bucket, key, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (string, error) { mpu, neededParts, size, eTag, err := ssql.MultipartUploadForCompletion(ctx, tx, bucket, key, uploadID, parts) if err != nil { @@ -244,6 +279,10 @@ func (tx *MainDatabaseTx) CompleteMultipartUpload(ctx context.Context, bucket, k return eTag, nil } +func (tx *MainDatabaseTx) Contract(ctx context.Context, fcid types.FileContractID) (api.ContractMetadata, error) { + return ssql.Contract(ctx, tx, fcid) +} + func (tx *MainDatabaseTx) ContractRoots(ctx context.Context, fcid types.FileContractID) ([]types.Hash256, error) { return ssql.ContractRoots(ctx, tx, fcid) } @@ -252,6 +291,10 @@ func (tx *MainDatabaseTx) Contracts(ctx context.Context, opts api.ContractsOpts) return ssql.Contracts(ctx, tx, opts) } +func (tx *MainDatabaseTx) ContractSetID(ctx context.Context, contractSet string) (int64, error) { + return ssql.ContractSetID(ctx, tx, contractSet) +} + func (tx *MainDatabaseTx) ContractSets(ctx context.Context) ([]string, error) { return ssql.ContractSets(ctx, tx) } @@ -301,6 +344,10 @@ func (tx *MainDatabaseTx) InsertBufferedSlab(ctx context.Context, fileName strin return ssql.InsertBufferedSlab(ctx, tx, fileName, contractSetID, ec, minShards, totalShards) } +func (tx *MainDatabaseTx) InsertContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { + return ssql.InsertContract(ctx, tx, rev, contractPrice, totalCost, startHeight, renewedFrom, state) +} + func (tx *MainDatabaseTx) InsertMultipartUpload(ctx context.Context, bucket, key string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (string, error) { return ssql.InsertMultipartUpload(ctx, tx, bucket, key, ec, mimeType, metadata) } @@ -349,10 +396,6 @@ func (tx *MainDatabaseTx) HostsForScanning(ctx context.Context, maxLastScan time return ssql.HostsForScanning(ctx, tx, maxLastScan, offset, limit) } -func (tx *MainDatabaseTx) InitConsensusInfo(ctx context.Context) (types.ChainIndex, modules.ConsensusChangeID, error) { - return ssql.InitConsensusInfo(ctx, tx) -} - func (tx *MainDatabaseTx) InsertObject(ctx context.Context, bucket, key, contractSet string, dirID int64, o object.Object, mimeType, eTag string, md api.ObjectUserMetadata) error { // get bucket id var bucketID int64 @@ -364,11 +407,7 @@ func (tx *MainDatabaseTx) InsertObject(ctx context.Context, bucket, key, contrac } // insert object - objKey, err := o.Key.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal object key: %w", err) - } - objID, err := ssql.InsertObject(ctx, tx, key, dirID, bucketID, o.TotalSize(), objKey, mimeType, eTag) + objID, err := ssql.InsertObject(ctx, tx, key, dirID, bucketID, o.TotalSize(), o.Key, mimeType, eTag) if err != nil { return fmt.Errorf("failed to insert object: %w", err) } @@ -468,6 +507,10 @@ func (tx *MainDatabaseTx) MakeDirsForPath(ctx context.Context, path string) (int return dirID, nil } 
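On the SQLite side the same fetch-or-create pattern is expressed with `ON CONFLICT ... DO UPDATE ... RETURNING id` (supported since SQLite 3.35), as `SetContractSet` below does. A minimal sketch under the same hypothetical `names` table assumed in the MySQL example earlier:

package main

import (
	"context"
	"database/sql"

	_ "github.com/mattn/go-sqlite3"
)

// upsertNameID is the SQLite counterpart of the MySQL idiom: the no-op
// "DO UPDATE SET id = id" turns a conflicting INSERT into an update, so the
// RETURNING clause always yields exactly one row containing the id.
func upsertNameID(ctx context.Context, db *sql.DB, name string) (int64, error) {
	var id int64
	err := db.QueryRowContext(ctx,
		"INSERT INTO names (name) VALUES (?) ON CONFLICT(name) DO UPDATE SET id = id RETURNING id",
		name,
	).Scan(&id)
	return id, err
}

Because RETURNING produces a row in either case, QueryRowContext plus Scan replaces the Exec/LastInsertId pair used on MySQL.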
+func (tx *MainDatabaseTx) MarkPackedSlabUploaded(ctx context.Context, slab api.UploadedPackedSlab) (string, error) { + return ssql.MarkPackedSlabUploaded(ctx, tx, slab) +} + func (tx *MainDatabaseTx) MultipartUpload(ctx context.Context, uploadID string) (api.MultipartUpload, error) { return ssql.MultipartUpload(ctx, tx, uploadID) } @@ -480,10 +523,46 @@ func (tx *MainDatabaseTx) MultipartUploads(ctx context.Context, bucket, prefix, return ssql.MultipartUploads(ctx, tx, bucket, prefix, keyMarker, uploadIDMarker, limit) } +func (tx *MainDatabaseTx) Object(ctx context.Context, bucket, key string) (api.Object, error) { + return ssql.Object(ctx, tx, bucket, key) +} + +func (tx *MainDatabaseTx) ObjectEntries(ctx context.Context, bucket, path, prefix, sortBy, sortDir, marker string, offset, limit int) ([]api.ObjectMetadata, bool, error) { + return ssql.ObjectEntries(ctx, tx, bucket, path, prefix, sortBy, sortDir, marker, offset, limit) +} + +func (tx *MainDatabaseTx) ObjectMetadata(ctx context.Context, bucket, path string) (api.Object, error) { + return ssql.ObjectMetadata(ctx, tx, bucket, path) +} + +func (tx *MainDatabaseTx) ObjectsBySlabKey(ctx context.Context, bucket string, slabKey object.EncryptionKey) (metadata []api.ObjectMetadata, err error) { + return ssql.ObjectsBySlabKey(ctx, tx, bucket, slabKey) +} + func (tx *MainDatabaseTx) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) (api.ObjectsStatsResponse, error) { return ssql.ObjectsStats(ctx, tx, opts) } +func (tx *MainDatabaseTx) PeerBanned(ctx context.Context, addr string) (bool, error) { + return ssql.PeerBanned(ctx, tx, addr) +} + +func (tx *MainDatabaseTx) PeerInfo(ctx context.Context, addr string) (syncer.PeerInfo, error) { + return ssql.PeerInfo(ctx, tx, addr) +} + +func (tx *MainDatabaseTx) Peers(ctx context.Context) ([]syncer.PeerInfo, error) { + return ssql.Peers(ctx, tx) +} + +func (tx *MainDatabaseTx) ProcessChainUpdate(ctx context.Context, fn func(ssql.ChainUpdateTx) error) (err error) { + return fn(&chainUpdateTx{ + ctx: ctx, + tx: tx, + l: tx.log.Named("ProcessChainUpdate"), + }) +} + func (tx *MainDatabaseTx) PruneEmptydirs(ctx context.Context) error { stmt, err := tx.Prepare(ctx, ` DELETE @@ -529,6 +608,10 @@ func (tx *MainDatabaseTx) PruneSlabs(ctx context.Context, limit int64) (int64, e return res.RowsAffected() } +func (tx *MainDatabaseTx) RecordContractSpending(ctx context.Context, fcid types.FileContractID, revisionNumber, size uint64, newSpending api.ContractSpending) error { + return ssql.RecordContractSpending(ctx, tx, fcid, revisionNumber, size, newSpending) +} + func (tx *MainDatabaseTx) RecordHostScans(ctx context.Context, scans []api.HostScan) error { return ssql.RecordHostScans(ctx, tx, scans) } @@ -615,12 +698,16 @@ func (tx *MainDatabaseTx) RenameObjects(ctx context.Context, bucket, prefixOld, return nil } +func (tx *MainDatabaseTx) RenewContract(ctx context.Context, rev rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (api.ContractMetadata, error) { + return ssql.RenewContract(ctx, tx, rev, contractPrice, totalCost, startHeight, renewedFrom, state) +} + func (tx *MainDatabaseTx) RenewedContract(ctx context.Context, renwedFrom types.FileContractID) (api.ContractMetadata, error) { return ssql.RenewedContract(ctx, tx, renwedFrom) } -func (tx *MainDatabaseTx) ResetConsensusSubscription(ctx context.Context) (types.ChainIndex, error) { - return ssql.ResetConsensusSubscription(ctx, tx) +func (tx *MainDatabaseTx) 
ResetChainState(ctx context.Context) error { + return ssql.ResetChainState(ctx, tx.Tx) } func (tx *MainDatabaseTx) ResetLostSectors(ctx context.Context, hk types.PublicKey) error { @@ -658,10 +745,79 @@ func (tx *MainDatabaseTx) SaveAccounts(ctx context.Context, accounts []api.Accou return nil } +func (tx *MainDatabaseTx) ScanObjectMetadata(s ssql.Scanner, others ...any) (md api.ObjectMetadata, err error) { + var createdAt string + dst := []any{&md.Name, &md.Size, &md.Health, &md.MimeType, &createdAt, &md.ETag} + dst = append(dst, others...) + if err := s.Scan(dst...); err != nil { + return api.ObjectMetadata{}, fmt.Errorf("failed to scan object metadata: %w", err) + } else if *(*time.Time)(&md.ModTime), err = time.Parse(time.DateTime, createdAt); err != nil { + return api.ObjectMetadata{}, fmt.Errorf("failed to parse created at time: %w", err) + } + return md, nil +} + func (tx *MainDatabaseTx) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { return ssql.SearchHosts(ctx, tx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) } +func (tx *MainDatabaseTx) SearchObjects(ctx context.Context, bucket, substring string, offset, limit int) ([]api.ObjectMetadata, error) { + return ssql.SearchObjects(ctx, tx, bucket, substring, offset, limit) +} + +func (tx *MainDatabaseTx) SelectObjectMetadataExpr() string { + return "o.object_id, o.size, o.health, o.mime_type, DATETIME(o.created_at), o.etag" +} + +func (tx *MainDatabaseTx) SetContractSet(ctx context.Context, name string, contractIds []types.FileContractID) error { + var csID int64 + err := tx.QueryRow(ctx, "INSERT INTO contract_sets (name) VALUES (?) ON CONFLICT(name) DO UPDATE SET id = id RETURNING id", name).Scan(&csID) + if err != nil { + return fmt.Errorf("failed to fetch contract set id: %w", err) + } + + // handle empty set + if len(contractIds) == 0 { + _, err := tx.Exec(ctx, "DELETE FROM contract_set_contracts WHERE db_contract_set_id = ?", csID) + return err + } + + // prepare fcid args and query + fcidQuery := strings.Repeat("?, ", len(contractIds)-1) + "?" + fcidArgs := make([]interface{}, len(contractIds)) + for i, fcid := range contractIds { + fcidArgs[i] = ssql.FileContractID(fcid) + } + + // remove unwanted contracts + args := []interface{}{csID} + args = append(args, fcidArgs...) + _, err = tx.Exec(ctx, fmt.Sprintf(` + DELETE FROM contract_set_contracts + WHERE db_contract_set_id = ? AND db_contract_id NOT IN ( + SELECT id + FROM contracts + WHERE contracts.fcid IN (%s) + ) + `, fcidQuery), args...) + if err != nil { + return fmt.Errorf("failed to delete contract set contracts: %w", err) + } + + // add missing contracts + _, err = tx.Exec(ctx, fmt.Sprintf(` + INSERT INTO contract_set_contracts (db_contract_set_id, db_contract_id) + SELECT ?, c.id + FROM contracts c + WHERE c.fcid IN (%s) + ON CONFLICT(db_contract_set_id, db_contract_id) DO NOTHING + `, fcidQuery), args...) 
+ if err != nil { + return fmt.Errorf("failed to add contract set contracts: %w", err) + } + return nil +} + func (tx *MainDatabaseTx) Setting(ctx context.Context, key string) (string, error) { return ssql.Setting(ctx, tx, key) } @@ -674,26 +830,35 @@ func (tx *MainDatabaseTx) SetUncleanShutdown(ctx context.Context) error { return ssql.SetUncleanShutdown(ctx, tx) } +func (tx *MainDatabaseTx) Slab(ctx context.Context, key object.EncryptionKey) (object.Slab, error) { + return ssql.Slab(ctx, tx, key) +} + func (tx *MainDatabaseTx) SlabBuffers(ctx context.Context) (map[string]string, error) { return ssql.SlabBuffers(ctx, tx) } +func (tx *MainDatabaseTx) Tip(ctx context.Context) (types.ChainIndex, error) { + return ssql.Tip(ctx, tx.Tx) +} + +func (tx *MainDatabaseTx) UnhealthySlabs(ctx context.Context, healthCutoff float64, set string, limit int) ([]api.UnhealthySlab, error) { + return ssql.UnhealthySlabs(ctx, tx, healthCutoff, set, limit) +} + +func (tx *MainDatabaseTx) UnspentSiacoinElements(ctx context.Context) (elements []types.SiacoinElement, err error) { + return ssql.UnspentSiacoinElements(ctx, tx.Tx) +} + func (tx *MainDatabaseTx) UpdateAutopilot(ctx context.Context, ap api.Autopilot) error { - res, err := tx.Exec(ctx, ` + _, err := tx.Exec(ctx, ` INSERT INTO autopilots (created_at, identifier, config, current_period) VALUES (?, ?, ?, ?) ON CONFLICT(identifier) DO UPDATE SET config = EXCLUDED.config, current_period = EXCLUDED.current_period `, time.Now(), ap.ID, (*ssql.AutopilotConfig)(&ap.Config), ap.CurrentPeriod) - if err != nil { - return err - } else if n, err := res.RowsAffected(); err != nil { - return err - } else if n != 1 { - return fmt.Errorf("expected 1 row affected, got %v", n) - } - return nil + return err } func (tx *MainDatabaseTx) UpdateBucketPolicy(ctx context.Context, bucket string, policy api.BucketPolicy) error { @@ -836,6 +1001,10 @@ func (tx *MainDatabaseTx) UpdateHostCheck(ctx context.Context, autopilot string, return nil } +func (tx *MainDatabaseTx) UpdatePeerInfo(ctx context.Context, addr string, fn func(*syncer.PeerInfo)) error { + return ssql.UpdatePeerInfo(ctx, tx, addr, fn) +} + func (tx *MainDatabaseTx) UpdateSetting(ctx context.Context, key, value string) error { _, err := tx.Exec(ctx, "INSERT INTO settings (created_at, `key`, value) VALUES (?, ?, ?) ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value", time.Now(), key, value) @@ -852,12 +1021,6 @@ func (tx *MainDatabaseTx) UpdateSlab(ctx context.Context, s object.Slab, contrac return fmt.Errorf("failed to fetch used contracts: %w", err) } - // extract the slab key - key, err := s.Key.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal slab key: %w", err) - } - // update slab var slabID, totalShards int64 err = tx.QueryRow(ctx, ` @@ -867,10 +1030,10 @@ func (tx *MainDatabaseTx) UpdateSlab(ctx context.Context, s object.Slab, contrac health = ? WHERE key = ? RETURNING id, total_shards - `, contractSet, time.Now().Unix(), 1, ssql.SecretKey(key)). + `, contractSet, time.Now().Unix(), 1, ssql.EncryptionKey(s.Key)). 
Scan(&slabID, &totalShards) if errors.Is(err, dsql.ErrNoRows) { - return fmt.Errorf("%w: slab with key '%s' not found: %w", api.ErrSlabNotFound, string(key), err) + return api.ErrSlabNotFound } else if err != nil { return err } @@ -983,6 +1146,14 @@ func (tx *MainDatabaseTx) UpdateSlabHealth(ctx context.Context, limit int64, min return res.RowsAffected() } +func (tx *MainDatabaseTx) WalletEvents(ctx context.Context, offset, limit int) ([]wallet.Event, error) { + return ssql.WalletEvents(ctx, tx.Tx, offset, limit) +} + +func (tx *MainDatabaseTx) WalletEventCount(ctx context.Context) (count uint64, err error) { + return ssql.WalletEventCount(ctx, tx.Tx) +} + func (tx *MainDatabaseTx) Webhooks(ctx context.Context) ([]webhooks.Webhook, error) { return ssql.Webhooks(ctx, tx) } @@ -1023,19 +1194,15 @@ func (tx *MainDatabaseTx) insertSlabs(ctx context.Context, objID, partID *int64, slabIDs := make([]int64, len(slices)) for i := range slices { - slabKey, err := slices[i].Key.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal slab key: %w", err) - } err = insertSlabStmt.QueryRow(ctx, time.Now(), contractSetID, - ssql.SecretKey(slabKey), + ssql.EncryptionKey(slices[i].Key), slices[i].MinShards, uint8(len(slices[i].Shards)), ).Scan(&slabIDs[i]) if errors.Is(err, dsql.ErrNoRows) { - if err := querySlabIDStmt.QueryRow(ctx, ssql.SecretKey(slabKey)).Scan(&slabIDs[i]); err != nil { + if err := querySlabIDStmt.QueryRow(ctx, ssql.EncryptionKey(slices[i].Key)).Scan(&slabIDs[i]); err != nil { return fmt.Errorf("failed to fetch slab id: %w", err) } } else if err != nil { diff --git a/stores/sql/sqlite/metrics.go b/stores/sql/sqlite/metrics.go index c52417fee..df912d7c7 100644 --- a/stores/sql/sqlite/metrics.go +++ b/stores/sql/sqlite/metrics.go @@ -29,11 +29,12 @@ type ( var _ ssql.MetricsDatabaseTx = (*MetricsDatabaseTx)(nil) -// NewSQLiteDatabase creates a new SQLite backend. +// NewMetricsDatabase creates a new SQLite-backed metrics database.
-func NewMetricsDatabase(db *dsql.DB, log *zap.SugaredLogger, lqd, ltd time.Duration) (*MetricsDatabase, error) { - store, err := sql.NewDB(db, log.Desugar(), deadlockMsgs, lqd, ltd) +func NewMetricsDatabase(db *dsql.DB, log *zap.Logger, lqd, ltd time.Duration) (*MetricsDatabase, error) { + log = log.Named("metrics") + store, err := sql.NewDB(db, log, deadlockMsgs, lqd, ltd) return &MetricsDatabase{ db: store, - log: log, + log: log.Sugar(), }, err } diff --git a/stores/sql/sqlite/migrations/main/migration_00012_peer_store.sql b/stores/sql/sqlite/migrations/main/migration_00012_peer_store.sql new file mode 100644 index 000000000..a1e168de8 --- /dev/null +++ b/stores/sql/sqlite/migrations/main/migration_00012_peer_store.sql @@ -0,0 +1,7 @@ +-- dbSyncerPeer +CREATE TABLE `syncer_peers` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`address` text NOT NULL,`first_seen` BIGINT NOT NULL,`last_connect` BIGINT,`synced_blocks` BIGINT,`sync_duration` BIGINT); +CREATE UNIQUE INDEX `idx_syncer_peers_address` ON `syncer_peers`(`address`); + +-- dbSyncerBan +CREATE TABLE `syncer_bans` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`net_cidr` text NOT NULL,`reason` text,`expiration` BIGINT NOT NULL); +CREATE UNIQUE INDEX `idx_syncer_bans_net_cidr` ON `syncer_bans`(`net_cidr`); \ No newline at end of file diff --git a/stores/sql/sqlite/migrations/main/migration_00013_coreutils_wallet.sql b/stores/sql/sqlite/migrations/main/migration_00013_coreutils_wallet.sql new file mode 100644 index 000000000..cbf14dcc0 --- /dev/null +++ b/stores/sql/sqlite/migrations/main/migration_00013_coreutils_wallet.sql @@ -0,0 +1,19 @@ +-- drop tables +DROP TABLE IF EXISTS `siacoin_elements`; +DROP TABLE IF EXISTS `transactions`; + +-- drop column +ALTER TABLE `consensus_infos` DROP COLUMN `cc_id`; + +-- dbWalletEvent +CREATE TABLE `wallet_events` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`event_id` blob NOT NULL,`height` integer, `block_id` blob,`inflow` text,`outflow` text,`type` text NOT NULL,`data` longblob NOT NULL,`maturity_height` integer,`timestamp` integer); +CREATE UNIQUE INDEX `idx_wallet_events_event_id` ON `wallet_events`(`event_id`); +CREATE INDEX `idx_wallet_events_maturity_height` ON `wallet_events`(`maturity_height`); +CREATE INDEX `idx_wallet_events_type` ON `wallet_events`(`type`); +CREATE INDEX `idx_wallet_events_timestamp` ON `wallet_events`(`timestamp`); +CREATE INDEX `idx_wallet_events_block_id_height` ON `wallet_events`(`block_id`,`height`); + +-- dbWalletOutput +CREATE TABLE `wallet_outputs` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`output_id` blob NOT NULL,`leaf_index` integer,`merkle_proof` longblob NOT NULL,`value` text,`address` blob,`maturity_height` integer); +CREATE UNIQUE INDEX `idx_wallet_outputs_output_id` ON `wallet_outputs`(`output_id`); +CREATE INDEX `idx_wallet_outputs_maturity_height` ON `wallet_outputs`(`maturity_height`); \ No newline at end of file diff --git a/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql b/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql new file mode 100644 index 000000000..9800c2b7b --- /dev/null +++ b/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql @@ -0,0 +1,2 @@ +ALTER TABLE hosts DROP COLUMN subnets; +ALTER TABLE hosts ADD resolved_addresses TEXT NOT NULL DEFAULT ''; diff --git a/stores/sql/sqlite/migrations/main/migration_00015_reset_drift.sql 
b/stores/sql/sqlite/migrations/main/migration_00015_reset_drift.sql new file mode 100644 index 000000000..c151d90a3 --- /dev/null +++ b/stores/sql/sqlite/migrations/main/migration_00015_reset_drift.sql @@ -0,0 +1 @@ +UPDATE ephemeral_accounts SET drift = "0", clean_shutdown = 0, requires_sync = 1; \ No newline at end of file diff --git a/stores/sql/sqlite/migrations/main/schema.sql b/stores/sql/sqlite/migrations/main/schema.sql index eadbc425c..647e6cfdd 100644 --- a/stores/sql/sqlite/migrations/main/schema.sql +++ b/stores/sql/sqlite/migrations/main/schema.sql @@ -12,7 +12,7 @@ CREATE INDEX `idx_archived_contracts_state` ON `archived_contracts`(`state`); CREATE INDEX `idx_archived_contracts_renewed_from` ON `archived_contracts`(`renewed_from`); -- dbHost -CREATE TABLE `hosts` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`public_key` blob NOT NULL UNIQUE,`settings` text,`price_table` text,`price_table_expiry` datetime,`total_scans` integer,`last_scan` integer,`last_scan_success` numeric,`second_to_last_scan_success` numeric,`scanned` numeric,`uptime` integer,`downtime` integer,`recent_downtime` integer,`recent_scan_failures` integer,`successful_interactions` real,`failed_interactions` real,`lost_sectors` integer,`last_announcement` datetime,`net_address` text,`subnets` text NOT NULL DEFAULT ''); +CREATE TABLE `hosts` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`public_key` blob NOT NULL UNIQUE,`settings` text,`price_table` text,`price_table_expiry` datetime,`total_scans` integer,`last_scan` integer,`last_scan_success` numeric,`second_to_last_scan_success` numeric,`scanned` numeric,`uptime` integer,`downtime` integer,`recent_downtime` integer,`recent_scan_failures` integer,`successful_interactions` real,`failed_interactions` real,`lost_sectors` integer,`last_announcement` datetime,`net_address` text,`resolved_addresses` text NOT NULL DEFAULT ''); CREATE INDEX `idx_hosts_recent_scan_failures` ON `hosts`(`recent_scan_failures`); CREATE INDEX `idx_hosts_recent_downtime` ON `hosts`(`recent_downtime`); CREATE INDEX `idx_hosts_scanned` ON `hosts`(`scanned`); @@ -107,7 +107,7 @@ CREATE INDEX `idx_slices_db_multipart_part_id` ON `slices`(`db_multipart_part_id CREATE TABLE `host_announcements` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`host_key` blob NOT NULL,`block_height` integer,`block_id` text,`net_address` text); -- dbConsensusInfo -CREATE TABLE `consensus_infos` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`cc_id` blob,`height` integer,`block_id` blob); +CREATE TABLE `consensus_infos` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`height` integer,`block_id` blob); -- dbBlocklistEntry CREATE TABLE `host_blocklist_entries` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`entry` text NOT NULL UNIQUE); @@ -125,16 +125,6 @@ CREATE INDEX `idx_host_allowlist_entries_entry` ON `host_allowlist_entries`(`ent CREATE TABLE `host_allowlist_entry_hosts` (`db_allowlist_entry_id` integer,`db_host_id` integer,PRIMARY KEY (`db_allowlist_entry_id`,`db_host_id`),CONSTRAINT `fk_host_allowlist_entry_hosts_db_allowlist_entry` FOREIGN KEY (`db_allowlist_entry_id`) REFERENCES `host_allowlist_entries`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_host_allowlist_entry_hosts_db_host` FOREIGN KEY (`db_host_id`) REFERENCES `hosts`(`id`) ON DELETE CASCADE); CREATE INDEX `idx_host_allowlist_entry_hosts_db_host_id` ON `host_allowlist_entry_hosts`(`db_host_id`); --- dbSiacoinElement -CREATE TABLE `siacoin_elements` (`id` integer 
PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`value` text,`address` blob,`output_id` blob NOT NULL UNIQUE,`maturity_height` integer); -CREATE INDEX `idx_siacoin_elements_maturity_height` ON `siacoin_elements`(`maturity_height`); -CREATE INDEX `idx_siacoin_elements_output_id` ON `siacoin_elements`(`output_id`); - --- dbTransaction -CREATE TABLE `transactions` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`raw` text,`height` integer,`block_id` blob,`transaction_id` blob NOT NULL UNIQUE,`inflow` text,`outflow` text,`timestamp` integer); -CREATE INDEX `idx_transactions_timestamp` ON `transactions`(`timestamp`); -CREATE INDEX `idx_transactions_transaction_id` ON `transactions`(`transaction_id`); - -- dbSetting CREATE TABLE `settings` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`key` text NOT NULL UNIQUE,`value` text NOT NULL); CREATE INDEX `idx_settings_key` ON `settings`(`key`); @@ -173,5 +163,65 @@ CREATE INDEX `idx_host_checks_score_uptime` ON `host_checks` (`score_uptime`); CREATE INDEX `idx_host_checks_score_version` ON `host_checks` (`score_version`); CREATE INDEX `idx_host_checks_score_prices` ON `host_checks` (`score_prices`); +-- dbObject trigger to delete from slices +CREATE TRIGGER before_delete_on_objects_delete_slices +BEFORE DELETE ON objects +BEGIN + DELETE FROM slices + WHERE slices.db_object_id = OLD.id; +END; + +-- dbMultipartUpload trigger to delete from dbMultipartPart +CREATE TRIGGER before_delete_on_multipart_uploads_delete_multipart_parts +BEFORE DELETE ON multipart_uploads +BEGIN + DELETE FROM multipart_parts + WHERE multipart_parts.db_multipart_upload_id = OLD.id; +END; + +-- dbMultipartPart trigger to delete from slices +CREATE TRIGGER before_delete_on_multipart_parts_delete_slices +BEFORE DELETE ON multipart_parts +BEGIN + DELETE FROM slices + WHERE slices.db_multipart_part_id = OLD.id; +END; + +-- dbSlices trigger to prune slabs +CREATE TRIGGER after_delete_on_slices_delete_slabs +AFTER DELETE ON slices +BEGIN + DELETE FROM slabs + WHERE slabs.id = OLD.db_slab_id + AND slabs.db_buffered_slab_id IS NULL + AND NOT EXISTS ( + SELECT 1 + FROM slices + WHERE slices.db_slab_id = OLD.db_slab_id + ); +END; + +-- dbSyncerPeer +CREATE TABLE `syncer_peers` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`address` text NOT NULL,`first_seen` BIGINT NOT NULL,`last_connect` BIGINT,`synced_blocks` BIGINT,`sync_duration` BIGINT); +CREATE UNIQUE INDEX `idx_syncer_peers_address` ON `syncer_peers`(`address`); + +-- dbSyncerBan +CREATE TABLE `syncer_bans` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`net_cidr` text NOT NULL,`reason` text,`expiration` BIGINT NOT NULL); +CREATE UNIQUE INDEX `idx_syncer_bans_net_cidr` ON `syncer_bans`(`net_cidr`); +CREATE INDEX `idx_syncer_bans_expiration` ON `syncer_bans`(`expiration`); + +-- dbWalletEvent +CREATE TABLE `wallet_events` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`event_id` blob NOT NULL,`height` integer, `block_id` blob,`inflow` text,`outflow` text,`type` text NOT NULL,`data` longblob NOT NULL,`maturity_height` integer,`timestamp` integer); +CREATE UNIQUE INDEX `idx_wallet_events_event_id` ON `wallet_events`(`event_id`); +CREATE INDEX `idx_wallet_events_maturity_height` ON `wallet_events`(`maturity_height`); +CREATE INDEX `idx_wallet_events_type` ON `wallet_events`(`type`); +CREATE INDEX `idx_wallet_events_timestamp` ON `wallet_events`(`timestamp`); +CREATE INDEX `idx_wallet_events_block_id_height` ON `wallet_events`(`block_id`,`height`); + +-- 
dbWalletOutput +CREATE TABLE `wallet_outputs` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`output_id` blob NOT NULL,`leaf_index` integer,`merkle_proof` longblob NOT NULL,`value` text,`address` blob,`maturity_height` integer); +CREATE UNIQUE INDEX `idx_wallet_outputs_output_id` ON `wallet_outputs`(`output_id`); +CREATE INDEX `idx_wallet_outputs_maturity_height` ON `wallet_outputs`(`maturity_height`); + -- create default bucket INSERT INTO buckets (created_at, name) VALUES (CURRENT_TIMESTAMP, 'default'); diff --git a/stores/sql/sqlite/migrations/metrics/migration_00002_idx_wallet_metrics_immature.sql b/stores/sql/sqlite/migrations/metrics/migration_00002_idx_wallet_metrics_immature.sql new file mode 100644 index 000000000..ee2b3e109 --- /dev/null +++ b/stores/sql/sqlite/migrations/metrics/migration_00002_idx_wallet_metrics_immature.sql @@ -0,0 +1,3 @@ +ALTER TABLE `wallets` ADD COLUMN `immature_lo` BIGINT DEFAULT 0; +ALTER TABLE `wallets` ADD COLUMN `immature_hi` BIGINT DEFAULT 0; +CREATE INDEX `idx_wallets_immature` ON `wallets`(`immature_lo`,`immature_hi`); \ No newline at end of file diff --git a/stores/sql/sqlite/migrations/metrics/schema.sql b/stores/sql/sqlite/migrations/metrics/schema.sql index 63dae7d65..dfb8e3cf1 100644 --- a/stores/sql/sqlite/migrations/metrics/schema.sql +++ b/stores/sql/sqlite/migrations/metrics/schema.sql @@ -46,8 +46,9 @@ CREATE INDEX `idx_performance_action` ON `performance`(`action`); CREATE INDEX `idx_performance_timestamp` ON `performance`(`timestamp`); -- dbWalletMetric -CREATE TABLE `wallets` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`timestamp` BIGINT NOT NULL,`confirmed_lo` BIGINT NOT NULL,`confirmed_hi` BIGINT NOT NULL,`spendable_lo` BIGINT NOT NULL,`spendable_hi` BIGINT NOT NULL,`unconfirmed_lo` BIGINT NOT NULL,`unconfirmed_hi` BIGINT NOT NULL); +CREATE TABLE `wallets` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`timestamp` BIGINT NOT NULL,`confirmed_lo` BIGINT NOT NULL,`confirmed_hi` BIGINT NOT NULL,`spendable_lo` BIGINT NOT NULL,`spendable_hi` BIGINT NOT NULL,`unconfirmed_lo` BIGINT NOT NULL,`unconfirmed_hi` BIGINT NOT NULL,`immature_lo` BIGINT NOT NULL,`immature_hi` BIGINT NOT NULL); CREATE INDEX `idx_unconfirmed` ON `wallets`(`unconfirmed_lo`,`unconfirmed_hi`); CREATE INDEX `idx_spendable` ON `wallets`(`spendable_lo`,`spendable_hi`); CREATE INDEX `idx_confirmed` ON `wallets`(`confirmed_lo`,`confirmed_hi`); +CREATE INDEX `idx_wallets_immature` ON `wallets`(`immature_lo`,`immature_hi`); CREATE INDEX `idx_wallets_timestamp` ON `wallets`(`timestamp`); diff --git a/stores/sql/types.go b/stores/sql/types.go index 00242be2f..10cf76e42 100644 --- a/stores/sql/types.go +++ b/stores/sql/types.go @@ -3,6 +3,7 @@ package sql import ( "database/sql" "database/sql/driver" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -14,26 +15,33 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/api" - "go.sia.tech/siad/modules" + "go.sia.tech/renterd/object" ) const ( - secretKeySize = 32 + proofHashSize = 32 +) + +var ( + ZeroCurrency = Currency(types.ZeroCurrency) ) type ( AutopilotConfig api.AutopilotConfig + BCurrency types.Currency BigInt big.Int - CCID modules.ConsensusChangeID + BusSetting string Currency types.Currency FileContractID types.FileContractID Hash256 types.Hash256 - BusSetting string + MerkleProof struct{ Hashes []types.Hash256 } HostSettings rhpv2.HostSettings PriceTable 
rhpv3.HostPriceTable PublicKey types.PublicKey - SecretKey []byte + EncryptionKey object.EncryptionKey + Uint64Str uint64 UnixTimeMS time.Time UnixTimeNS time.Time Unsigned64 uint64 @@ -46,16 +54,17 @@ type scannerValuer interface { var ( _ scannerValuer = (*AutopilotConfig)(nil) + _ scannerValuer = (*BCurrency)(nil) _ scannerValuer = (*BigInt)(nil) _ scannerValuer = (*BusSetting)(nil) - _ scannerValuer = (*CCID)(nil) _ scannerValuer = (*Currency)(nil) _ scannerValuer = (*FileContractID)(nil) _ scannerValuer = (*Hash256)(nil) + _ scannerValuer = (*MerkleProof)(nil) _ scannerValuer = (*HostSettings)(nil) _ scannerValuer = (*PriceTable)(nil) _ scannerValuer = (*PublicKey)(nil) - _ scannerValuer = (*SecretKey)(nil) + _ scannerValuer = (*EncryptionKey)(nil) _ scannerValuer = (*UnixTimeMS)(nil) _ scannerValuer = (*UnixTimeNS)(nil) _ scannerValuer = (*Unsigned64)(nil) @@ -80,6 +89,28 @@ func (cfg AutopilotConfig) Value() (driver.Value, error) { return json.Marshal(cfg) } +// Scan implements the sql.Scanner interface. +func (sc *BCurrency) Scan(src any) error { + buf, ok := src.([]byte) + if !ok { + return fmt.Errorf("cannot scan %T to Currency", src) + } else if len(buf) != 16 { + return fmt.Errorf("cannot scan %d bytes to Currency", len(buf)) + } + + sc.Hi = binary.BigEndian.Uint64(buf[:8]) + sc.Lo = binary.BigEndian.Uint64(buf[8:]) + return nil +} + +// Value implements the driver.Valuer interface. +func (sc BCurrency) Value() (driver.Value, error) { + buf := make([]byte, 16) + binary.BigEndian.PutUint64(buf[:8], sc.Hi) + binary.BigEndian.PutUint64(buf[8:], sc.Lo) + return buf, nil +} + // Scan scan value into BigInt, implements sql.Scanner interface. func (b *BigInt) Scan(value interface{}) error { var s string @@ -102,22 +133,6 @@ func (b BigInt) Value() (driver.Value, error) { return (*big.Int)(&b).String(), nil } -// Scan scan value into CCID, implements sql.Scanner interface. -func (c *CCID) Scan(value interface{}) error { - switch value := value.(type) { - case []byte: - copy(c[:], value) - default: - return fmt.Errorf("failed to unmarshal CCID value: %v %t", value, value) - } - return nil -} - -// Value returns a publicKey value, implements driver.Valuer interface. -func (c CCID) Value() (driver.Value, error) { - return c[:], nil -} - // Scan scan value into Currency, implements sql.Scanner interface. func (c *Currency) Scan(value interface{}) error { var s string @@ -224,27 +239,54 @@ func (pk PublicKey) Value() (driver.Value, error) { return pk[:], nil } +// Scan scans value into a MerkleProof, implements sql.Scanner interface. +func (mp *MerkleProof) Scan(value interface{}) error { + b, ok := value.([]byte) + if !ok { + return errors.New(fmt.Sprint("failed to unmarshal MerkleProof value:", value)) + } else if len(b)%proofHashSize != 0 { + return fmt.Errorf("failed to unmarshal MerkleProof value due to invalid number of bytes %v: %v", len(b), value) + } + + mp.Hashes = make([]types.Hash256, len(b)/proofHashSize) + for i := range mp.Hashes { + copy(mp.Hashes[i][:], b[i*proofHashSize:]) + } + return nil +} + +// Value returns a MerkleProof value, implements driver.Valuer interface. +func (mp MerkleProof) Value() (driver.Value, error) { + b := make([]byte, len(mp.Hashes)*proofHashSize) + for i, h := range mp.Hashes { + copy(b[i*proofHashSize:], h[:]) + } + return b, nil +} +
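The MerkleProof scanner/valuer above packs a proof into len(Hashes)*32 contiguous bytes and splits the blob back into 32-byte hashes on the way out. A quick round-trip sketch, assuming only that MerkleProof stays exported from go.sia.tech/renterd/stores/sql as in this diff (illustration only, not part of the change):

package main

import (
	"fmt"
	"log"

	"go.sia.tech/core/types"
	ssql "go.sia.tech/renterd/stores/sql"
)

func main() {
	mp := ssql.MerkleProof{Hashes: []types.Hash256{{1}, {2}, {3}}}

	// Value flattens the three hashes into a single 96-byte blob
	v, err := mp.Value()
	if err != nil {
		log.Fatal(err)
	}

	// Scan slices the blob back into its 32-byte hashes
	var out ssql.MerkleProof
	if err := out.Scan(v); err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(out.Hashes), out.Hashes[2] == mp.Hashes[2]) // 3 true
}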
// String implements fmt.Stringer to prevent the key from getting leaked in // logs. -func (k SecretKey) String() string { +func (k EncryptionKey) String() string { return "*****" } // Scan scans value into key, implements sql.Scanner interface. -func (k *SecretKey) Scan(value interface{}) error { +func (k *EncryptionKey) Scan(value interface{}) error { bytes, ok := value.([]byte) if !ok { - return errors.New(fmt.Sprint("failed to unmarshal secretKey value:", value)) - } else if len(bytes) != secretKeySize { - return fmt.Errorf("failed to unmarshal secretKey value due to invalid number of bytes %v != %v: %v", len(bytes), secretKeySize, value) + return errors.New(fmt.Sprint("failed to unmarshal EncryptionKey value:", value)) + } + var ec object.EncryptionKey + if err := ec.UnmarshalBinary(bytes); err != nil { + return fmt.Errorf("failed to unmarshal EncryptionKey value: %w", err) } - *k = append(SecretKey{}, SecretKey(bytes)...) + *k = EncryptionKey(ec) return nil } -// Value returns an key value, implements driver.Valuer interface. -func (k SecretKey) Value() (driver.Value, error) { - return []byte(k), nil +// Value returns a key value, implements driver.Valuer interface. +func (k EncryptionKey) Value() (driver.Value, error) { + return object.EncryptionKey(k).MarshalBinary() } // String implements fmt.Stringer to prevent "s3authentication" settings from @@ -330,6 +372,61 @@ func (u UnixTimeNS) Value() (driver.Value, error) { return time.Time(u).UnixNano(), nil } +// Scan scan value into Uint64, implements sql.Scanner interface. +func (u *Uint64Str) Scan(value interface{}) error { + var s string + switch value := value.(type) { + case string: + s = value + case []byte: + s = string(value) + default: + return fmt.Errorf("failed to unmarshal Uint64 value: %v %T", value, value) + } + var val uint64 + _, err := fmt.Sscan(s, &val) + if err != nil { + return fmt.Errorf("failed to scan Uint64 value: %v", err) + } + *u = Uint64Str(val) + return nil +} + +// Value returns a Uint64 value, implements driver.Valuer interface. +func (u Uint64Str) Value() (driver.Value, error) { + return fmt.Sprint(u), nil +} + +func UnmarshalEventData(b []byte, t string) (dst wallet.EventData, err error) { + switch t { + case wallet.EventTypeMinerPayout, + wallet.EventTypeSiafundClaim, + wallet.EventTypeFoundationSubsidy: + var e wallet.EventPayout + err = json.Unmarshal(b, &e) + dst = e + case wallet.EventTypeV1ContractResolution: + var e wallet.EventV1ContractResolution + err = json.Unmarshal(b, &e) + dst = e + case wallet.EventTypeV2ContractResolution: + var e wallet.EventV2ContractResolution + err = json.Unmarshal(b, &e) + dst = e + case wallet.EventTypeV1Transaction: + var e wallet.EventV1Transaction + err = json.Unmarshal(b, &e) + dst = e + case wallet.EventTypeV2Transaction: + var e wallet.EventV2Transaction + err = json.Unmarshal(b, &e) + dst = e + default: + return nil, fmt.Errorf("unknown event type %v", t) + } + return +} + // Scan scan value into Unsigned64, implements sql.Scanner interface.
func (u *Unsigned64) Scan(value interface{}) error { var n int64 diff --git a/stores/sql_test.go b/stores/sql_test.go index 228165370..0846254cb 100644 --- a/stores/sql_test.go +++ b/stores/sql_test.go @@ -1,20 +1,15 @@ package stores import ( - "bytes" "context" dsql "database/sql" "encoding/hex" "errors" "fmt" - "os" "path/filepath" - "reflect" - "strings" "testing" "time" - "github.com/google/go-cmp/cmp" "go.sia.tech/core/types" "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" @@ -24,21 +19,14 @@ import ( sql "go.sia.tech/renterd/stores/sql" "go.sia.tech/renterd/stores/sql/mysql" "go.sia.tech/renterd/stores/sql/sqlite" - "go.sia.tech/siad/modules" "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "go.uber.org/zap/zaptest/observer" - "gorm.io/gorm" - "gorm.io/gorm/logger" "lukechampine.com/frand" - "moul.io/zapgorm2" ) const ( - testPersistInterval = time.Second - testContractSet = "test" - testMimeType = "application/octet-stream" - testETag = "d34db33f" + testContractSet = "test" + testMimeType = "application/octet-stream" + testETag = "d34db33f" ) var ( @@ -69,11 +57,9 @@ func randomDBName() string { return "db" + hex.EncodeToString(frand.Bytes(16)) } -func (cfg *testSQLStoreConfig) dbConnections() (gorm.Dialector, sql.MetricsDatabase, error) { - var connMain gorm.Dialector - var dbm *dsql.DB +func (cfg *testSQLStoreConfig) dbConnections() (sql.Database, sql.MetricsDatabase, error) { + var dbMain sql.Database var dbMetrics sql.MetricsDatabase - var err error if mysqlCfg := config.MySQLConfigFromEnv(); mysqlCfg.URI != "" { // create MySQL connections if URI is set @@ -90,42 +76,72 @@ func (cfg *testSQLStoreConfig) dbConnections() (gorm.Dialector, sql.MetricsDatab mysqlCfg.MetricsDatabase = cfg.dbMetricsName } - // use a tmp connection to precreate the two databases - if tmpDB, err := gorm.Open(NewMySQLConnection(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, "")); err != nil { + // precreate the two databases + if tmpDB, err := mysql.Open(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, ""); err != nil { return nil, nil, err - } else if err := tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", mysqlCfg.Database)).Error; err != nil { + } else if _, err := tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", mysqlCfg.Database)); err != nil { return nil, nil, err - } else if err := tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", mysqlCfg.MetricsDatabase)).Error; err != nil { + } else if _, err := tmpDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", mysqlCfg.MetricsDatabase)); err != nil { + return nil, nil, err + } else if err := tmpDB.Close(); err != nil { return nil, nil, err } - connMain = NewMySQLConnection(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, mysqlCfg.Database) - dbm, err = mysql.Open(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, mysqlCfg.MetricsDatabase) + // create MySQL conns + connMain, err := mysql.Open(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, mysqlCfg.Database) + if err != nil { + return nil, nil, fmt.Errorf("failed to open MySQL main database: %w", err) + } + connMetrics, err := mysql.Open(mysqlCfg.User, mysqlCfg.Password, mysqlCfg.URI, mysqlCfg.MetricsDatabase) if err != nil { return nil, nil, fmt.Errorf("failed to open MySQL metrics database: %w", err) } - dbMetrics, err = mysql.NewMetricsDatabase(dbm, zap.NewNop().Sugar(), 100*time.Millisecond, 100*time.Millisecond) + dbMain, err = mysql.NewMainDatabase(connMain, zap.NewNop(), 100*time.Millisecond, 100*time.Millisecond) + if err != nil { + return nil, nil, 
fmt.Errorf("failed to create MySQL main database: %w", err) + } + dbMetrics, err = mysql.NewMetricsDatabase(connMetrics, zap.NewNop(), 100*time.Millisecond, 100*time.Millisecond) + if err != nil { + return nil, nil, fmt.Errorf("failed to create MySQL metrics database: %w", err) + } } else if cfg.persistent { // create SQL connections if we want a persistent store - connMain = NewSQLiteConnection(filepath.Join(cfg.dir, "db.sqlite")) - dbm, err = sqlite.Open(filepath.Join(cfg.dir, "metrics.sqlite")) + connMain, err := sqlite.Open(filepath.Join(cfg.dir, "db.sqlite")) + if err != nil { + return nil, nil, fmt.Errorf("failed to open SQLite main database: %w", err) + } + connMetrics, err := sqlite.Open(filepath.Join(cfg.dir, "metrics.sqlite")) if err != nil { return nil, nil, fmt.Errorf("failed to open SQLite metrics database: %w", err) } - dbMetrics, err = sqlite.NewMetricsDatabase(dbm, zap.NewNop().Sugar(), 100*time.Millisecond, 100*time.Millisecond) + dbMain, err = sqlite.NewMainDatabase(connMain, zap.NewNop(), 100*time.Millisecond, 100*time.Millisecond) + if err != nil { + return nil, nil, fmt.Errorf("failed to create SQLite main database: %w", err) + } + dbMetrics, err = sqlite.NewMetricsDatabase(connMetrics, zap.NewNop(), 100*time.Millisecond, 100*time.Millisecond) + if err != nil { + return nil, nil, fmt.Errorf("failed to create SQLite metrics database: %w", err) + } } else { // otherwise return ephemeral connections - connMain = NewEphemeralSQLiteConnection(cfg.dbName) - dbm, err = sqlite.OpenEphemeral(cfg.dbMetricsName) + connMain, err := sqlite.OpenEphemeral(cfg.dbName) if err != nil { - return nil, nil, fmt.Errorf("failed to open ephemeral SQLite metrics database: %w", err) + return nil, nil, fmt.Errorf("failed to open ephemeral SQLite main database: %w", err) } + connMetrics, err := sqlite.OpenEphemeral(cfg.dbMetricsName) + if err != nil { + return nil, nil, fmt.Errorf("failed to open ephemeral SQLite metrics database: %w", err) + } + dbMain, err = sqlite.NewMainDatabase(connMain, zap.NewNop(), 100*time.Millisecond, 100*time.Millisecond) + if err != nil { + return nil, nil, fmt.Errorf("failed to create ephemeral SQLite main database: %w", err) + } + dbMetrics, err = sqlite.NewMetricsDatabase(connMetrics, zap.NewNop(), 100*time.Millisecond, 100*time.Millisecond) + if err != nil { + return nil, nil, fmt.Errorf("failed to create ephemeral SQLite metrics database: %w", err) + } } - return connMain, dbMetrics, nil + return dbMain, dbMetrics, nil } // newTestSQLStore creates a new SQLStore for testing.
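One detail worth pausing on before the test changes continue: the BCurrency type added in stores/sql/types.go encodes a 128-bit currency as a 16-byte big-endian blob precisely so that the database's byte-wise BLOB comparison agrees with numeric comparison, which TestTypeCurrency below verifies through real queries. A standalone sketch of that invariant (illustration only; the encode helper simply mirrors BCurrency.Value):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"

	"go.sia.tech/core/types"
)

// encode mirrors BCurrency.Value: high 64 bits first, then low 64 bits,
// both big-endian, so lexicographic byte order equals numeric order.
func encode(c types.Currency) []byte {
	buf := make([]byte, 16)
	binary.BigEndian.PutUint64(buf[:8], c.Hi)
	binary.BigEndian.PutUint64(buf[8:], c.Lo)
	return buf
}

func main() {
	pairs := [][2]types.Currency{
		{types.ZeroCurrency, types.NewCurrency64(1)},
		{types.NewCurrency64(1), types.MaxCurrency},
		{types.ZeroCurrency, types.MaxCurrency},
		{types.MaxCurrency, types.MaxCurrency},
	}
	for _, p := range pairs {
		// bytes.Compare on the encodings agrees with Currency.Cmp, which is
		// what lets the store compare and ORDER BY currency blobs in SQL
		fmt.Println(bytes.Compare(encode(p[0]), encode(p[1])) == p[0].Cmp(p[1])) // true
	}
}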
@@ -146,27 +162,22 @@ func newTestSQLStore(t *testing.T, cfg testSQLStoreConfig) *testSQLStore { } // create db connections - conn, dbMetrics, err := cfg.dbConnections() + dbMain, dbMetrics, err := cfg.dbConnections() if err != nil { t.Fatal("failed to create db connections", err) } - walletAddrs := types.Address(frand.Entropy256()) alerts := alerts.WithOrigin(alerts.NewManager(), "test") - sqlStore, _, err := NewSQLStore(Config{ - Conn: conn, + sqlStore, err := NewSQLStore(Config{ Alerts: alerts, + DB: dbMain, DBMetrics: dbMetrics, PartialSlabDir: cfg.dir, Migrate: !cfg.skipMigrate, - AnnouncementMaxAge: time.Hour, - PersistInterval: time.Second, - WalletAddress: walletAddrs, SlabBufferCompletionThreshold: 0, - Logger: zap.NewNop().Sugar(), + Logger: zap.NewNop(), LongQueryDuration: 100 * time.Millisecond, LongTxDuration: 100 * time.Millisecond, - GormLogger: newTestLogger(), RetryTransactionIntervals: []time.Duration{50 * time.Millisecond, 100 * time.Millisecond, 200 * time.Millisecond}, }) if err != nil { @@ -187,7 +198,7 @@ func newTestSQLStore(t *testing.T, cfg testSQLStoreConfig) *testSQLStore { } func (s *testSQLStore) DB() *isql.DB { - switch db := s.bMain.(type) { + switch db := s.db.(type) { case *sqlite.MainDatabase: return db.DB() case *mysql.MainDatabase: @@ -198,8 +209,32 @@ func (s *testSQLStore) DB() *isql.DB { panic("unreachable") } +func (s *testSQLStore) ExecDBSpecific(sqliteQuery, mysqlQuery string) (dsql.Result, error) { + switch db := s.db.(type) { + case *sqlite.MainDatabase: + return db.DB().Exec(context.Background(), sqliteQuery) + case *mysql.MainDatabase: + return db.DB().Exec(context.Background(), mysqlQuery) + default: + s.t.Fatal("unknown db type", db) + } + panic("unreachable") +} + +func (s *testSQLStore) QueryRowDBSpecific(sqliteQuery, mysqlQuery string, sqliteArgs, mysqlArgs []any) *isql.LoggedRow { + switch db := s.db.(type) { + case *sqlite.MainDatabase: + return db.DB().QueryRow(context.Background(), sqliteQuery, sqliteArgs...) + case *mysql.MainDatabase: + return db.DB().QueryRow(context.Background(), mysqlQuery, mysqlArgs...) + default: + s.t.Fatal("unknown db type", db) + } + panic("unreachable") +} + func (s *testSQLStore) DBMetrics() *isql.DB { - switch db := s.bMetrics.(type) { + switch db := s.dbMetrics.(type) { case *sqlite.MetricsDatabase: return db.DB() case *mysql.MetricsDatabase: @@ -217,16 +252,12 @@ func (s *testSQLStore) Close() error { return nil } -func (s *testSQLStore) DefaultBucketID() uint { - var b dbBucket - if err := s.db. - Model(&dbBucket{}). - Where("name = ?", api.DefaultBucketName). - Take(&b). - Error; err != nil { +func (s *testSQLStore) DefaultBucketID() (id int64) { + if err := s.DB().QueryRow(context.Background(), "SELECT id FROM buckets WHERE name = ?", api.DefaultBucketName). + Scan(&id); err != nil { s.t.Fatal(err) } - return b.ID + return } func (s *testSQLStore) Reopen() *testSQLStore { @@ -251,22 +282,6 @@ func (s *testSQLStore) Retry(tries int, durationBetweenAttempts time.Duration, f } } -// newTestLogger creates a console logger used for testing. 
-func newTestLogger() logger.Interface { - config := zap.NewProductionEncoderConfig() - config.EncodeTime = zapcore.RFC3339TimeEncoder - config.EncodeLevel = zapcore.CapitalColorLevelEncoder - config.StacktraceKey = "" - consoleEncoder := zapcore.NewConsoleEncoder(config) - - l := zap.New( - zapcore.NewCore(consoleEncoder, zapcore.AddSync(os.Stdout), zapcore.DebugLevel), - zap.AddCaller(), - zap.AddStacktrace(zapcore.ErrorLevel), - ) - return zapgorm2.New(l) -} - func (s *testSQLStore) addTestObject(path string, o object.Object) (api.Object, error) { if err := s.UpdateObjectBlocking(context.Background(), api.DefaultBucketName, path, testContractSet, testETag, testMimeType, testMetadata, o); err != nil { return api.Object{}, err @@ -277,11 +292,16 @@ func (s *testSQLStore) addTestObject(path string, o object.Object) (api.Object, } } -func (s *SQLStore) addTestContracts(keys []types.PublicKey) (fcids []types.FileContractID, contracts []api.ContractMetadata, err error) { - cnt, err := s.contractsCount() - if err != nil { - return nil, nil, err +func (s *testSQLStore) Count(table string) (n int64) { + if err := s.DB().QueryRow(context.Background(), fmt.Sprintf("SELECT COUNT(*) FROM %s", table)). + Scan(&n); err != nil { + s.t.Fatal(err) } + return +} + +func (s *testSQLStore) addTestContracts(keys []types.PublicKey) (fcids []types.FileContractID, contracts []api.ContractMetadata, err error) { + cnt := s.Count("contracts") for i, key := range keys { fcids = append(fcids, types.FileContractID{byte(int(cnt) + i + 1)}) contract, err := s.addTestContract(fcids[len(fcids)-1], key) @@ -303,16 +323,8 @@ func (s *SQLStore) addTestRenewedContract(fcid, renewedFrom types.FileContractID return s.AddRenewedContract(context.Background(), rev, types.ZeroCurrency, types.ZeroCurrency, startHeight, renewedFrom, api.ContractStatePending) } -func (s *SQLStore) contractsCount() (cnt int64, err error) { - err = s.db. - Model(&dbContract{}). - Count(&cnt). - Error - return -} - -func (s *SQLStore) overrideSlabHealth(objectID string, health float64) (err error) { - err = s.db.Exec(fmt.Sprintf(` +func (s *testSQLStore) overrideSlabHealth(objectID string, health float64) (err error) { + _, err = s.DB().Exec(context.Background(), fmt.Sprintf(` UPDATE slabs SET health = %v WHERE id IN ( SELECT * FROM ( SELECT sla.id @@ -321,221 +333,6 @@ func (s *SQLStore) overrideSlabHealth(objectID string, health float64) (err erro INNER JOIN slabs sla ON sli.db_slab_id = sla.id WHERE o.object_id = "%s" ) AS sub - )`, health, objectID)).Error + )`, health, objectID)) return } - -// TestConsensusReset is a unit test for ResetConsensusSubscription. -func TestConsensusReset(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - if ss.ccid != modules.ConsensusChangeBeginning { - t.Fatal("wrong ccid", ss.ccid, modules.ConsensusChangeBeginning) - } - - // Manually insert into the consenus_infos, the transactions and siacoin_elements tables. - ccid2 := modules.ConsensusChangeID{1} - ss.db.Create(&dbConsensusInfo{ - CCID: ccid2[:], - }) - ss.db.Create(&dbSiacoinElement{ - OutputID: hash256{2}, - }) - ss.db.Create(&dbTransaction{ - TransactionID: hash256{3}, - }) - - // Reset the consensus. - if err := ss.ResetConsensusSubscription(context.Background()); err != nil { - t.Fatal(err) - } - - // Reopen the SQLStore. - ss = ss.Reopen() - defer ss.Close() - - // Check tables. 
- var count int64 - if err := ss.db.Model(&dbConsensusInfo{}).Count(&count).Error; err != nil || count != 1 { - t.Fatal("table should have 1 entry", err, count) - } else if err = ss.db.Model(&dbTransaction{}).Count(&count).Error; err != nil || count > 0 { - t.Fatal("table not empty", err) - } else if err = ss.db.Model(&dbSiacoinElement{}).Count(&count).Error; err != nil || count > 0 { - t.Fatal("table not empty", err) - } - - // Check consensus info. - var ci dbConsensusInfo - if err := ss.db.Take(&ci).Error; err != nil { - t.Fatal(err) - } else if !bytes.Equal(ci.CCID, modules.ConsensusChangeBeginning[:]) { - t.Fatal("wrong ccid", ci.CCID, modules.ConsensusChangeBeginning) - } else if ci.Height != 0 { - t.Fatal("wrong height", ci.Height, 0) - } - - // Check SQLStore. - if ss.chainIndex.Height != 0 { - t.Fatal("wrong height", ss.chainIndex.Height, 0) - } else if ss.chainIndex.ID != (types.BlockID{}) { - t.Fatal("wrong id", ss.chainIndex.ID, types.BlockID{}) - } -} - -type sqliteQueryPlan struct { - Detail string `json:"detail"` -} - -func (p sqliteQueryPlan) usesIndex() bool { - d := strings.ToLower(p.Detail) - return strings.Contains(d, "using index") || strings.Contains(d, "using covering index") -} - -//nolint:tagliatelle -type mysqlQueryPlan struct { - Extra string `json:"Extra"` - PossibleKeys string `json:"possible_keys"` -} - -func (p mysqlQueryPlan) usesIndex() bool { - d := strings.ToLower(p.Extra) - return strings.Contains(d, "using index") || strings.Contains(p.PossibleKeys, "idx_") -} - -func TestQueryPlan(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - queries := []string{ - // allow_list - "SELECT * FROM host_allowlist_entry_hosts WHERE db_host_id = 1", - "SELECT * FROM host_allowlist_entry_hosts WHERE db_allowlist_entry_id = 1", - - // block_list - "SELECT * FROM host_blocklist_entry_hosts WHERE db_host_id = 1", - "SELECT * FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = 1", - - // contract_sectors - "SELECT * FROM contract_sectors WHERE db_contract_id = 1", - "SELECT * FROM contract_sectors WHERE db_sector_id = 1", - "SELECT COUNT(DISTINCT db_sector_id) FROM contract_sectors", - - // contract_set_contracts - "SELECT * FROM contract_set_contracts WHERE db_contract_id = 1", - "SELECT * FROM contract_set_contracts WHERE db_contract_set_id = 1", - - // slabs - "SELECT * FROM slabs WHERE health_valid_until > 0", - "SELECT * FROM slabs WHERE health > 0", - "SELECT * FROM slabs WHERE db_buffered_slab_id = 1", - - // objects - "SELECT * FROM objects WHERE db_bucket_id = 1", - "SELECT * FROM objects WHERE etag = ''", - } - - for _, query := range queries { - if isSQLite(ss.db) { - var explain sqliteQueryPlan - if err := ss.db.Raw(fmt.Sprintf("EXPLAIN QUERY PLAN %s;", query)).Scan(&explain).Error; err != nil { - t.Fatal(err) - } else if !explain.usesIndex() { - t.Fatalf("query '%s' should use an index, instead the plan was %+v", query, explain) - } - } else { - var explain mysqlQueryPlan - if err := ss.db.Raw(fmt.Sprintf("EXPLAIN %s;", query)).Scan(&explain).Error; err != nil { - t.Fatal(err) - } else if !explain.usesIndex() { - t.Fatalf("query '%s' should use an index, instead the plan was %+v", query, explain) - } - } - } -} - -func TestApplyUpdatesErr(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - before := ss.lastSave - - // drop consensus_infos table to cause update to fail - if err := ss.db.Exec("DROP TABLE consensus_infos").Error; err != nil { - t.Fatal(err) - } - - // call 
applyUpdates with 'force' set to true - if err := ss.applyUpdates(true); err == nil { - t.Fatal("expected error") - } - - // save shouldn't have happened - if ss.lastSave != before { - t.Fatal("lastSave should not have changed") - } -} - -func TestRetryTransaction(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - // create custom logger to capture logs - observedZapCore, observedLogs := observer.New(zap.InfoLevel) - ss.logger = zap.New(observedZapCore).Sugar() - - // collectLogs returns all logs - collectLogs := func() (logs []string) { - t.Helper() - for _, entry := range observedLogs.All() { - logs = append(logs, entry.Message) - } - return - } - - // disable retries and retry a transaction that fails - ss.retryTransactionIntervals = nil - ss.retryTransaction(context.Background(), func(tx *gorm.DB) error { return errors.New("database locked") }) - - // assert transaction is attempted once and not retried - got := collectLogs() - want := []string{"transaction attempt 1/1 failed, err: database locked"} - if !reflect.DeepEqual(got, want) { - t.Fatal("unexpected logs", cmp.Diff(got, want)) - } - - // enable retries and retry the same transaction - ss.retryTransactionIntervals = []time.Duration{ - 5 * time.Millisecond, - 10 * time.Millisecond, - 15 * time.Millisecond, - } - ss.retryTransaction(context.Background(), func(tx *gorm.DB) error { return errors.New("database locked") }) - - // assert transaction is retried 4 times in total - got = collectLogs() - want = append(want, - "transaction attempt 1/4 failed, retry in 5ms, err: database locked", - "transaction attempt 2/4 failed, retry in 10ms, err: database locked", - "transaction attempt 3/4 failed, retry in 15ms, err: database locked", - "transaction attempt 4/4 failed, err: database locked", - ) - if !reflect.DeepEqual(got, want) { - t.Fatal("unexpected logs", cmp.Diff(got, want)) - } - - // retry transaction with cancelled context - ctx, cancel := context.WithCancel(context.Background()) - cancel() - ss.retryTransaction(ctx, func(tx *gorm.DB) error { return nil }) - if len(observedLogs.All()) != len(want) { - t.Fatal("expected no logs") - } - - ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) - defer cancel() - time.Sleep(time.Millisecond) - ss.retryTransaction(ctx, func(tx *gorm.DB) error { return nil }) - if len(observedLogs.All()) != len(want) { - t.Fatal("expected no logs") - } -} diff --git a/stores/types.go b/stores/types.go deleted file mode 100644 index 7020f49e6..000000000 --- a/stores/types.go +++ /dev/null @@ -1,342 +0,0 @@ -package stores - -import ( - "database/sql/driver" - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "strconv" - "strings" - "time" - - rhpv2 "go.sia.tech/core/rhp/v2" - rhpv3 "go.sia.tech/core/rhp/v3" - "go.sia.tech/core/types" -) - -const ( - secretKeySize = 32 -) - -var zeroCurrency = currency(types.ZeroCurrency) - -type ( - unixTimeMS time.Time - datetime time.Time - currency types.Currency - bCurrency types.Currency - fileContractID types.FileContractID - hash256 types.Hash256 - publicKey types.PublicKey - hostSettings rhpv2.HostSettings - hostPriceTable rhpv3.HostPriceTable - secretKey []byte - setting string -) - -// GormDataType implements gorm.GormDataTypeInterface. -func (setting) GormDataType() string { - return "string" -} - -// String implements fmt.Stringer to prevent "s3authentication" settings from -// getting leaked. 
-func (s setting) String() string { - if strings.Contains(string(s), "v4Keypairs") { - return "*****" - } - return string(s) -} - -// Scan scans value into the setting -func (s *setting) Scan(value interface{}) error { - switch value := value.(type) { - case string: - *s = setting(value) - case []byte: - *s = setting(value) - default: - return fmt.Errorf("failed to unmarshal setting value from type %t", value) - } - return nil -} - -// Value returns a setting value, implements driver.Valuer interface. -func (s setting) Value() (driver.Value, error) { - return string(s), nil -} - -// GormDataType implements gorm.GormDataTypeInterface. -func (secretKey) GormDataType() string { - return "bytes" -} - -// String implements fmt.Stringer to prevent the key from getting leaked in -// logs. -func (k secretKey) String() string { - return "*****" -} - -// Scan scans value into key, implements sql.Scanner interface. -func (k *secretKey) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal secretKey value:", value)) - } else if len(bytes) != secretKeySize { - return fmt.Errorf("failed to unmarshal secretKey value due to invalid number of bytes %v != %v: %v", len(bytes), secretKeySize, value) - } - *k = append(secretKey{}, secretKey(bytes)...) - return nil -} - -// Value returns an key value, implements driver.Valuer interface. -func (k secretKey) Value() (driver.Value, error) { - return []byte(k), nil -} - -// GormDataType implements gorm.GormDataTypeInterface. -func (hash256) GormDataType() string { - return "bytes" -} - -// Scan scan value into address, implements sql.Scanner interface. -func (h *hash256) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal hash256 value:", value)) - } - if len(bytes) != len(hash256{}) { - return fmt.Errorf("failed to unmarshal hash256 value due to invalid number of bytes %v != %v: %v", len(bytes), len(fileContractID{}), value) - } - *h = *(*hash256)(bytes) - return nil -} - -// Value returns an addr value, implements driver.Valuer interface. -func (h hash256) Value() (driver.Value, error) { - return h[:], nil -} - -// GormDataType implements gorm.GormDataTypeInterface. -func (fileContractID) GormDataType() string { - return "bytes" -} - -// Scan scan value into fileContractID, implements sql.Scanner interface. -func (fcid *fileContractID) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal fcid value:", value)) - } - if len(bytes) != len(fileContractID{}) { - return fmt.Errorf("failed to unmarshal fcid value due to invalid number of bytes %v != %v: %v", len(bytes), len(fileContractID{}), value) - } - *fcid = *(*fileContractID)(bytes) - return nil -} - -// Value returns a fileContractID value, implements driver.Valuer interface. -func (fcid fileContractID) Value() (driver.Value, error) { - return fcid[:], nil -} - -func (currency) GormDataType() string { - return "string" -} - -// Scan scan value into currency, implements sql.Scanner interface. 
-func (c *currency) Scan(value interface{}) error { - var s string - switch value := value.(type) { - case string: - s = value - case []byte: - s = string(value) - default: - return fmt.Errorf("failed to unmarshal currency value: %v %t", value, value) - } - curr, err := types.ParseCurrency(s) - if err != nil { - return err - } - *c = currency(curr) - return nil -} - -// Value returns a publicKey value, implements driver.Valuer interface. -func (c currency) Value() (driver.Value, error) { - return types.Currency(c).ExactString(), nil -} - -func (publicKey) GormDataType() string { - return "bytes" -} - -// Scan scan value into publicKey, implements sql.Scanner interface. -func (pk *publicKey) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal publicKey value:", value)) - } - if len(bytes) != len(types.PublicKey{}) { - return fmt.Errorf("failed to unmarshal publicKey value due invalid number of bytes %v != %v: %v", len(bytes), len(publicKey{}), value) - } - *pk = *(*publicKey)(bytes) - return nil -} - -// Value returns a publicKey value, implements driver.Valuer interface. -func (pk publicKey) Value() (driver.Value, error) { - return pk[:], nil -} - -func (hostSettings) GormDataType() string { - return "string" -} - -// Scan scan value into hostSettings, implements sql.Scanner interface. -func (hs *hostSettings) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal hostSettings value:", value)) - } - return json.Unmarshal(bytes, hs) -} - -// Value returns a hostSettings value, implements driver.Valuer interface. -func (hs hostSettings) Value() (driver.Value, error) { - return json.Marshal(hs) -} - -func (hs hostPriceTable) GormDataType() string { - return "string" -} - -// Scan scan value into hostPriceTable, implements sql.Scanner interface. -func (hpt *hostPriceTable) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal hostPriceTable value:", value)) - } - return json.Unmarshal(bytes, hpt) -} - -// Value returns a hostPriceTable value, implements driver.Valuer interface. -func (hs hostPriceTable) Value() (driver.Value, error) { - return json.Marshal(hs) -} - -// SQLiteTimestampFormats were taken from github.com/mattn/go-sqlite3 and are -// used when parsing a string to a date -var SQLiteTimestampFormats = []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02T15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999", - "2006-01-02 15:04:05", - "2006-01-02T15:04:05", - "2006-01-02 15:04", - "2006-01-02T15:04", - "2006-01-02", -} - -// GormDataType implements gorm.GormDataTypeInterface. -func (datetime) GormDataType() string { - return "string" -} - -// Scan scan value into datetime, implements sql.Scanner interface. 
-func (dt *datetime) Scan(value interface{}) error { - var s string - switch value := value.(type) { - case string: - s = value - case []byte: - s = string(value) - case time.Time: - *dt = datetime(value) - return nil - default: - return fmt.Errorf("failed to unmarshal time.Time value: %v %T", value, value) - } - - var ok bool - var t time.Time - s = strings.TrimSuffix(s, "Z") - for _, format := range SQLiteTimestampFormats { - if timeVal, err := time.ParseInLocation(format, s, time.UTC); err == nil { - ok = true - t = timeVal - break - } - } - if !ok { - return fmt.Errorf("failed to parse datetime value: %v", s) - } - - *dt = datetime(t) - return nil -} - -// Value returns a datetime value, implements driver.Valuer interface. -func (dt datetime) Value() (driver.Value, error) { - return (time.Time)(dt).Format(SQLiteTimestampFormats[0]), nil -} - -// GormDataType implements gorm.GormDataTypeInterface. -func (unixTimeMS) GormDataType() string { - return "BIGINT" -} - -// Scan scan value into unixTimeMS, implements sql.Scanner interface. -func (u *unixTimeMS) Scan(value interface{}) error { - var msec int64 - var err error - switch value := value.(type) { - case int64: - msec = value - case []uint8: - msec, err = strconv.ParseInt(string(value), 10, 64) - if err != nil { - return fmt.Errorf("failed to unmarshal unixTimeMS value: %v %T", value, value) - } - default: - return fmt.Errorf("failed to unmarshal unixTimeMS value: %v %T", value, value) - } - - *u = unixTimeMS(time.UnixMilli(msec)) - return nil -} - -// Value returns a int64 value representing a unix timestamp in milliseconds, -// implements driver.Valuer interface. -func (u unixTimeMS) Value() (driver.Value, error) { - return time.Time(u).UnixMilli(), nil -} - -func (bCurrency) GormDataType() string { - return "bytes" -} - -// Scan implements the sql.Scanner interface. -func (sc *bCurrency) Scan(src any) error { - buf, ok := src.([]byte) - if !ok { - return fmt.Errorf("cannot scan %T to Currency", src) - } else if len(buf) != 16 { - return fmt.Errorf("cannot scan %d bytes to Currency", len(buf)) - } - - sc.Hi = binary.BigEndian.Uint64(buf[:8]) - sc.Lo = binary.BigEndian.Uint64(buf[8:]) - return nil -} - -// Value implements the driver.Valuer interface. 
-func (sc bCurrency) Value() (driver.Value, error) { - buf := make([]byte, 16) - binary.BigEndian.PutUint64(buf[:8], sc.Hi) - binary.BigEndian.PutUint64(buf[8:], sc.Lo) - return buf, nil -} diff --git a/stores/types_test.go b/stores/types_test.go index c985dd012..7518a17d0 100644 --- a/stores/types_test.go +++ b/stores/types_test.go @@ -1,14 +1,161 @@ package stores -import "testing" +import ( + "context" + "fmt" + "sort" + "testing" -func TestTypeSetting(t *testing.T) { - s1 := setting("some setting") - s2 := setting("v4Keypairs") + "go.sia.tech/core/types" + "go.sia.tech/renterd/stores/sql" +) - if s1.String() != "some setting" { - t.Fatal("unexpected string") - } else if s2.String() != "*****" { - t.Fatal("unexpected string") +func TestTypeCurrency(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + // prepare the table + if _, err := ss.ExecDBSpecific( + "CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);", // sqlite + "CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);", // mysql + ); err != nil { + t.Fatal(err) + } + + // insert currencies in random order + if _, err := ss.DB().Exec(context.Background(), "INSERT INTO currencies (c) VALUES (?),(?),(?);", sql.BCurrency(types.MaxCurrency), sql.BCurrency(types.NewCurrency64(1)), sql.BCurrency(types.ZeroCurrency)); err != nil { + t.Fatal(err) + } + + // fetch currencies and assert they're sorted + var currencies []sql.BCurrency + rows, err := ss.DB().Query(context.Background(), "SELECT c FROM currencies ORDER BY c ASC;") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var c sql.BCurrency + if err := rows.Scan(&c); err != nil { + t.Fatal(err) + } + currencies = append(currencies, c) + } + if !sort.SliceIsSorted(currencies, func(i, j int) bool { + return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0 + }) { + t.Fatal("currencies not sorted", currencies) + } + + // convenience variables + c0 := currencies[0] + c1 := currencies[1] + cM := currencies[2] + + tests := []struct { + a sql.BCurrency + b sql.BCurrency + cmp string + }{ + { + a: c0, + b: c1, + cmp: "<", + }, + { + a: c1, + b: c0, + cmp: ">", + }, + { + a: c0, + b: c1, + cmp: "!=", + }, + { + a: c1, + b: c1, + cmp: "=", + }, + { + a: c0, + b: cM, + cmp: "<", + }, + { + a: cM, + b: c0, + cmp: ">", + }, + { + a: cM, + b: cM, + cmp: "=", + }, + } + for i, test := range tests { + var result bool + if err := ss.QueryRowDBSpecific( + fmt.Sprintf("SELECT ? %s ?", test.cmp), + fmt.Sprintf("SELECT HEX(?) 
%s HEX(?)", test.cmp),
+			[]any{test.a, test.b},
+			[]any{test.a, test.b},
+		).Scan(&result); err != nil {
+			t.Fatal(err)
+		} else if !result {
+			t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
+		} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
+			t.Fatal("invalid result")
+		} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
+			t.Fatal("invalid result")
+		} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
+			t.Fatal("invalid result")
+		}
+	}
+}
+
+func TestTypeMerkleProof(t *testing.T) {
+	ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
+	defer ss.Close()
+
+	// prepare the table
+	if _, err := ss.ExecDBSpecific(
+		"CREATE TABLE merkle_proofs (id INTEGER PRIMARY KEY AUTOINCREMENT,merkle_proof BLOB);", // sqlite
+		"CREATE TABLE merkle_proofs (id INT AUTO_INCREMENT PRIMARY KEY, merkle_proof BLOB);",   // mysql
+	); err != nil {
+		t.Fatal(err)
+	}
+
+	// insert merkle proof
+	mp1 := sql.MerkleProof{Hashes: []types.Hash256{{3}, {1}, {2}}}
+	mp2 := sql.MerkleProof{Hashes: []types.Hash256{{4}}}
+	if _, err := ss.DB().Exec(context.Background(), "INSERT INTO merkle_proofs (merkle_proof) VALUES (?), (?);", mp1, mp2); err != nil {
+		t.Fatal(err)
+	}
+
+	// fetch first proof
+	var first sql.MerkleProof
+	if err := ss.DB().QueryRow(context.Background(), "SELECT merkle_proof FROM merkle_proofs").Scan(&first); err != nil {
+		t.Fatal(err)
+	} else if first.Hashes[0] != (types.Hash256{3}) || first.Hashes[1] != (types.Hash256{1}) || first.Hashes[2] != (types.Hash256{2}) {
+		t.Fatalf("unexpected proof %+v", first)
+	}
+
+	// fetch both proofs
+	var both []sql.MerkleProof
+	rows, err := ss.DB().Query(context.Background(), "SELECT merkle_proof FROM merkle_proofs")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer rows.Close()
+	for rows.Next() {
+		var mp sql.MerkleProof
+		if err := rows.Scan(&mp); err != nil {
+			t.Fatal(err)
+		}
+		both = append(both, mp)
+	}
+	if len(both) != 2 {
+		t.Fatalf("unexpected number of proofs: %d", len(both))
+	} else if both[1].Hashes[0] != (types.Hash256{4}) {
+		t.Fatalf("unexpected proof %+v", both)
 	}
 }
diff --git a/stores/wallet.go b/stores/wallet.go
index d9bf51c39..ff8053358 100644
--- a/stores/wallet.go
+++ b/stores/wallet.go
@@ -1,339 +1,50 @@
 package stores
 
 import (
-	"bytes"
-	"math"
-	"time"
+	"context"
 
-	"gitlab.com/NebulousLabs/encoding"
 	"go.sia.tech/core/types"
-	"go.sia.tech/renterd/wallet"
-	"go.sia.tech/siad/modules"
-	"gorm.io/gorm"
+	"go.sia.tech/coreutils/wallet"
+	"go.sia.tech/renterd/stores/sql"
 )
 
-type (
-	dbSiacoinElement struct {
-		Model
-		Value          currency
-		Address        hash256 `gorm:"size:32"`
-		OutputID       hash256 `gorm:"unique;index;NOT NULL;size:32"`
-		MaturityHeight uint64  `gorm:"index"`
-	}
-
-	dbTransaction struct {
-		Model
-		Raw           types.Transaction `gorm:"serializer:json"`
-		Height        uint64
-		BlockID       hash256 `gorm:"size:32"`
-		TransactionID hash256 `gorm:"unique;index;NOT NULL;size:32"`
-		Inflow        currency
-		Outflow       currency
-		Timestamp     int64 `gorm:"index:idx_transactions_timestamp"`
-	}
-
-	outputChange struct {
-		addition bool
-		oid      hash256
-		sco      dbSiacoinElement
-	}
-
-	txnChange struct {
-		addition bool
-		txnID    hash256
-		txn      dbTransaction
-	}
+var (
+	_ wallet.SingleAddressStore = (*SQLStore)(nil)
 )
 
-// TableName implements the gorm.Tabler interface.
-func (dbSiacoinElement) TableName() string { return "siacoin_elements" }
-
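The new `var _ wallet.SingleAddressStore = (*SQLStore)(nil)` declaration above is the usual Go compile-time interface check: it costs nothing at runtime, but the build fails as soon as `SQLStore` stops satisfying the interface. A self-contained sketch of the idiom, with `Store` and `sqlStore` as hypothetical stand-ins:

```go
package main

// Store stands in for an interface like wallet.SingleAddressStore.
type Store interface {
	WalletEventCount() (uint64, error)
}

type sqlStore struct{}

func (s *sqlStore) WalletEventCount() (uint64, error) { return 0, nil }

// Compile-time assertion: if *sqlStore ever stops implementing Store, this
// line becomes a build error instead of a runtime surprise.
var _ Store = (*sqlStore)(nil)

func main() {}
```

-// TableName implements the gorm.Tabler interface.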
-func (dbTransaction) TableName() string { return "transactions" } - -func (s *SQLStore) Height() uint64 { - s.persistMu.Lock() - height := s.chainIndex.Height - s.persistMu.Unlock() - return height -} - -// UnspentSiacoinElements implements wallet.SingleAddressStore. -func (s *SQLStore) UnspentSiacoinElements(matured bool) ([]wallet.SiacoinElement, error) { - s.persistMu.Lock() - height := s.chainIndex.Height - s.persistMu.Unlock() - - tx := s.db - var elems []dbSiacoinElement - if matured { - tx = tx.Where("maturity_height <= ?", height) - } - if err := tx.Find(&elems).Error; err != nil { - return nil, err - } - utxo := make([]wallet.SiacoinElement, len(elems)) - for i := range elems { - utxo[i] = wallet.SiacoinElement{ - ID: types.Hash256(elems[i].OutputID), - MaturityHeight: elems[i].MaturityHeight, - SiacoinOutput: types.SiacoinOutput{ - Address: types.Address(elems[i].Address), - Value: types.Currency(elems[i].Value), - }, - } - } - return utxo, nil -} - -// Transactions implements wallet.SingleAddressStore. -func (s *SQLStore) Transactions(before, since time.Time, offset, limit int) ([]wallet.Transaction, error) { - beforeX := int64(math.MaxInt64) - sinceX := int64(0) - if !before.IsZero() { - beforeX = before.Unix() - } - if !since.IsZero() { - sinceX = since.Unix() - } - if limit == 0 || limit == -1 { - limit = math.MaxInt64 - } - - var dbTxns []dbTransaction - err := s.db.Raw("SELECT * FROM transactions WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp DESC LIMIT ? OFFSET ?", - sinceX, beforeX, limit, offset).Scan(&dbTxns). - Error - if err != nil { - return nil, err - } - - txns := make([]wallet.Transaction, len(dbTxns)) - for i := range dbTxns { - txns[i] = wallet.Transaction{ - Raw: dbTxns[i].Raw, - Index: types.ChainIndex{ - Height: dbTxns[i].Height, - ID: types.BlockID(dbTxns[i].BlockID), - }, - ID: types.TransactionID(dbTxns[i].TransactionID), - Inflow: types.Currency(dbTxns[i].Inflow), - Outflow: types.Currency(dbTxns[i].Outflow), - Timestamp: time.Unix(dbTxns[i].Timestamp, 0), - } - } - return txns, nil -} - -// ProcessConsensusChange implements chain.Subscriber. -func (s *SQLStore) processConsensusChangeWallet(cc modules.ConsensusChange) { - // Add/Remove siacoin outputs. - for _, diff := range cc.SiacoinOutputDiffs { - var sco types.SiacoinOutput - convertToCore(diff.SiacoinOutput, (*types.V1SiacoinOutput)(&sco)) - if sco.Address != s.walletAddress { - continue - } - if diff.Direction == modules.DiffApply { - // add new outputs - s.unappliedOutputChanges = append(s.unappliedOutputChanges, outputChange{ - addition: true, - oid: hash256(diff.ID), - sco: dbSiacoinElement{ - Address: hash256(sco.Address), - Value: currency(sco.Value), - OutputID: hash256(diff.ID), - MaturityHeight: uint64(cc.BlockHeight), // immediately spendable - }, - }) - } else { - // remove reverted outputs - s.unappliedOutputChanges = append(s.unappliedOutputChanges, outputChange{ - addition: false, - oid: hash256(diff.ID), - }) - } - } - - // Create a 'fake' transaction for every matured siacoin output. - for _, diff := range cc.AppliedDiffs { - for _, dsco := range diff.DelayedSiacoinOutputDiffs { - // if a delayed output is reverted in an applied diff, the - // output has matured -- add a payout transaction. 
- if dsco.Direction != modules.DiffRevert { - continue - } else if types.Address(dsco.SiacoinOutput.UnlockHash) != s.walletAddress { - continue - } - var sco types.SiacoinOutput - convertToCore(dsco.SiacoinOutput, (*types.V1SiacoinOutput)(&sco)) - s.unappliedTxnChanges = append(s.unappliedTxnChanges, txnChange{ - addition: true, - txnID: hash256(dsco.ID), // use output id as txn id - txn: dbTransaction{ - Height: uint64(dsco.MaturityHeight), - Inflow: currency(sco.Value), // transaction inflow is value of matured output - TransactionID: hash256(dsco.ID), // use output as txn id - Timestamp: int64(cc.AppliedBlocks[dsco.MaturityHeight-cc.InitialHeight()-1].Timestamp), // use timestamp of block that caused output to mature - }, - }) - } - } - - // Revert transactions from reverted blocks. - for _, block := range cc.RevertedBlocks { - for _, stxn := range block.Transactions { - var txn types.Transaction - convertToCore(stxn, &txn) - if transactionIsRelevant(txn, s.walletAddress) { - // remove reverted txns - s.unappliedTxnChanges = append(s.unappliedTxnChanges, txnChange{ - addition: false, - txnID: hash256(txn.ID()), - }) - } - } - } - - // Revert 'fake' transactions. - for _, diff := range cc.RevertedDiffs { - for _, dsco := range diff.DelayedSiacoinOutputDiffs { - if dsco.Direction == modules.DiffApply { - s.unappliedTxnChanges = append(s.unappliedTxnChanges, txnChange{ - addition: false, - txnID: hash256(dsco.ID), - }) - } - } - } - - spentOutputs := make(map[types.SiacoinOutputID]types.SiacoinOutput) - for i, block := range cc.AppliedBlocks { - appliedDiff := cc.AppliedDiffs[i] - for _, diff := range appliedDiff.SiacoinOutputDiffs { - if diff.Direction == modules.DiffRevert { - var so types.SiacoinOutput - convertToCore(diff.SiacoinOutput, (*types.V1SiacoinOutput)(&so)) - spentOutputs[types.SiacoinOutputID(diff.ID)] = so - } - } - - for _, stxn := range block.Transactions { - var txn types.Transaction - convertToCore(stxn, &txn) - if transactionIsRelevant(txn, s.walletAddress) { - var inflow, outflow types.Currency - for _, out := range txn.SiacoinOutputs { - if out.Address == s.walletAddress { - inflow = inflow.Add(out.Value) - } - } - for _, in := range txn.SiacoinInputs { - if in.UnlockConditions.UnlockHash() == s.walletAddress { - so, ok := spentOutputs[in.ParentID] - if !ok { - panic("spent output not found") - } - outflow = outflow.Add(so.Value) - } - } - - // add confirmed txns - s.unappliedTxnChanges = append(s.unappliedTxnChanges, txnChange{ - addition: true, - txnID: hash256(txn.ID()), - txn: dbTransaction{ - Raw: txn, - Height: uint64(cc.InitialHeight()) + uint64(i) + 1, - BlockID: hash256(block.ID()), - Inflow: currency(inflow), - Outflow: currency(outflow), - TransactionID: hash256(txn.ID()), - Timestamp: int64(block.Timestamp), - }, - }) - } - } - } -} - -func transactionIsRelevant(txn types.Transaction, addr types.Address) bool { - for i := range txn.SiacoinInputs { - if txn.SiacoinInputs[i].UnlockConditions.UnlockHash() == addr { - return true - } - } - for i := range txn.SiacoinOutputs { - if txn.SiacoinOutputs[i].Address == addr { - return true - } - } - for i := range txn.SiafundInputs { - if txn.SiafundInputs[i].UnlockConditions.UnlockHash() == addr { - return true - } - if txn.SiafundInputs[i].ClaimAddress == addr { - return true - } - } - for i := range txn.SiafundOutputs { - if txn.SiafundOutputs[i].Address == addr { - return true - } - } - for i := range txn.FileContracts { - for _, sco := range txn.FileContracts[i].ValidProofOutputs { - if sco.Address == addr { - 
return true
-			}
-		}
-		for _, sco := range txn.FileContracts[i].MissedProofOutputs {
-			if sco.Address == addr {
-				return true
-			}
-		}
-	}
-	for i := range txn.FileContractRevisions {
-		for _, sco := range txn.FileContractRevisions[i].ValidProofOutputs {
-			if sco.Address == addr {
-				return true
-			}
-		}
-		for _, sco := range txn.FileContractRevisions[i].MissedProofOutputs {
-			if sco.Address == addr {
-				return true
-			}
-		}
-	}
-	return false
-}
-
-func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) {
-	var buf bytes.Buffer
-	siad.MarshalSia(&buf)
-	d := types.NewBufDecoder(buf.Bytes())
-	core.DecodeFrom(d)
-	if d.Err() != nil {
-		panic(d.Err())
-	}
-}
-
-func applyUnappliedOutputAdditions(tx *gorm.DB, sco dbSiacoinElement) error {
-	return tx.Create(&sco).Error
-}
-
-func applyUnappliedOutputRemovals(tx *gorm.DB, oid hash256) error {
-	return tx.Where("output_id", oid).
-		Delete(&dbSiacoinElement{}).
-		Error
-}
-
-func applyUnappliedTxnAdditions(tx *gorm.DB, txn dbTransaction) error {
-	return tx.Create(&txn).Error
-}
-
-func applyUnappliedTxnRemovals(tx *gorm.DB, txnID hash256) error {
-	return tx.Where("transaction_id", txnID).
-		Delete(&dbTransaction{}).
-		Error
+// Tip returns the chain index of the last wallet change.
+func (s *SQLStore) Tip() (ci types.ChainIndex, err error) {
+	err = s.db.Transaction(s.shutdownCtx, func(tx sql.DatabaseTx) error {
+		ci, err = tx.Tip(s.shutdownCtx)
+		return err
+	})
+	return
+}
+
+// UnspentSiacoinElements returns a list of all unspent siacoin outputs.
+func (s *SQLStore) UnspentSiacoinElements() (elements []types.SiacoinElement, err error) {
+	err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) {
+		elements, err = tx.UnspentSiacoinElements(context.Background())
+		return
+	})
+	return
+}
+
+// WalletEvents returns a paginated list of events, ordered by maturity height,
+// descending. If no more events are available, (nil, nil) is returned.
+func (s *SQLStore) WalletEvents(offset, limit int) (events []wallet.Event, err error) {
+	err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) {
+		events, err = tx.WalletEvents(context.Background(), offset, limit)
+		return
+	})
+	return
+}
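Each of the new store methods above follows the same shape: a named return value is captured by the closure handed to `db.Transaction`, so the body runs inside a single transaction while the result is still returned to the caller. A self-contained sketch of the pattern (all names below are hypothetical stand-ins, not renterd APIs):

```go
package main

import (
	"context"
	"fmt"
)

// tx stands in for sql.DatabaseTx.
type tx interface {
	EventCount(ctx context.Context) (uint64, error)
}

type memTx struct{ n uint64 }

func (t memTx) EventCount(context.Context) (uint64, error) { return t.n, nil }

type store struct{ t memTx }

// Transaction mimics s.db.Transaction: a real implementation would open a
// database transaction, run fn, then commit or roll back based on its error.
func (s store) Transaction(ctx context.Context, fn func(tx) error) error {
	return fn(s.t)
}

// EventCount shows the wrapper shape: the named return `count` is assigned
// inside the closure and handed back once the transaction completes.
func (s store) EventCount() (count uint64, err error) {
	err = s.Transaction(context.Background(), func(t tx) (err error) {
		count, err = t.EventCount(context.Background())
		return
	})
	return
}

func main() {
	fmt.Println(store{t: memTx{n: 42}}.EventCount()) // prints: 42 <nil>
}
```

+
+// WalletEventCount returns the number of events relevant to the wallet.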
+func (s *SQLStore) WalletEventCount() (count uint64, err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { + count, err = tx.WalletEventCount(context.Background()) + return + }) + return } diff --git a/stores/webhooks.go b/stores/webhooks.go index 02516c419..e7e3782ea 100644 --- a/stores/webhooks.go +++ b/stores/webhooks.go @@ -8,19 +8,19 @@ import ( ) func (s *SQLStore) AddWebhook(ctx context.Context, wh webhooks.Webhook) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.AddWebhook(ctx, wh) }) } func (s *SQLStore) DeleteWebhook(ctx context.Context, wh webhooks.Webhook) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.DeleteWebhook(ctx, wh) }) } func (s *SQLStore) Webhooks(ctx context.Context) (whs []webhooks.Webhook, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { whs, err = tx.Webhooks(ctx) return err }) diff --git a/wallet/wallet.go b/wallet/wallet.go deleted file mode 100644 index 6c641ed42..000000000 --- a/wallet/wallet.go +++ /dev/null @@ -1,650 +0,0 @@ -package wallet - -import ( - "bytes" - "context" - "errors" - "fmt" - "sort" - "sync" - "time" - - "gitlab.com/NebulousLabs/encoding" - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/renterd/api" - "go.sia.tech/siad/modules" - "go.uber.org/zap" -) - -const ( - // BytesPerInput is the encoded size of a SiacoinInput and corresponding - // TransactionSignature, assuming standard UnlockConditions. - BytesPerInput = 241 - - // redistributeBatchSize is the number of outputs to redistribute per txn to - // avoid creating a txn that is too large. - redistributeBatchSize = 10 - - // transactionDefragThreshold is the number of utxos at which the wallet - // will attempt to defrag itself by including small utxos in transactions. - transactionDefragThreshold = 30 - // maxInputsForDefrag is the maximum number of inputs a transaction can - // have before the wallet will stop adding inputs - maxInputsForDefrag = 30 - // maxDefragUTXOs is the maximum number of utxos that will be added to a - // transaction when defragging - maxDefragUTXOs = 10 -) - -// ErrInsufficientBalance is returned when there aren't enough unused outputs to -// cover the requested amount. -var ErrInsufficientBalance = errors.New("insufficient balance") - -// StandardUnlockConditions returns the standard unlock conditions for a single -// Ed25519 key. -func StandardUnlockConditions(pk types.PublicKey) types.UnlockConditions { - return types.UnlockConditions{ - PublicKeys: []types.UnlockKey{{ - Algorithm: types.SpecifierEd25519, - Key: pk[:], - }}, - SignaturesRequired: 1, - } -} - -// StandardAddress returns the standard address for an Ed25519 key. -func StandardAddress(pk types.PublicKey) types.Address { - return StandardUnlockConditions(pk).UnlockHash() -} - -// StandardTransactionSignature returns the standard signature object for a -// siacoin or siafund input. -func StandardTransactionSignature(id types.Hash256) types.TransactionSignature { - return types.TransactionSignature{ - ParentID: id, - CoveredFields: types.CoveredFields{WholeTransaction: true}, - PublicKeyIndex: 0, - } -} - -// ExplicitCoveredFields returns a CoveredFields that covers all elements -// present in txn. 
-func ExplicitCoveredFields(txn types.Transaction) (cf types.CoveredFields) { - for i := range txn.SiacoinInputs { - cf.SiacoinInputs = append(cf.SiacoinInputs, uint64(i)) - } - for i := range txn.SiacoinOutputs { - cf.SiacoinOutputs = append(cf.SiacoinOutputs, uint64(i)) - } - for i := range txn.FileContracts { - cf.FileContracts = append(cf.FileContracts, uint64(i)) - } - for i := range txn.FileContractRevisions { - cf.FileContractRevisions = append(cf.FileContractRevisions, uint64(i)) - } - for i := range txn.StorageProofs { - cf.StorageProofs = append(cf.StorageProofs, uint64(i)) - } - for i := range txn.SiafundInputs { - cf.SiafundInputs = append(cf.SiafundInputs, uint64(i)) - } - for i := range txn.SiafundOutputs { - cf.SiafundOutputs = append(cf.SiafundOutputs, uint64(i)) - } - for i := range txn.MinerFees { - cf.MinerFees = append(cf.MinerFees, uint64(i)) - } - for i := range txn.ArbitraryData { - cf.ArbitraryData = append(cf.ArbitraryData, uint64(i)) - } - for i := range txn.Signatures { - cf.Signatures = append(cf.Signatures, uint64(i)) - } - return -} - -// A SiacoinElement is a SiacoinOutput along with its ID. -type SiacoinElement struct { - types.SiacoinOutput - ID types.Hash256 `json:"id"` - MaturityHeight uint64 `json:"maturityHeight"` -} - -// A Transaction is an on-chain transaction relevant to a particular wallet, -// paired with useful metadata. -type Transaction struct { - Raw types.Transaction `json:"raw,omitempty"` - Index types.ChainIndex `json:"index"` - ID types.TransactionID `json:"id"` - Inflow types.Currency `json:"inflow"` - Outflow types.Currency `json:"outflow"` - Timestamp time.Time `json:"timestamp"` -} - -// A SingleAddressStore stores the state of a single-address wallet. -// Implementations are assumed to be thread safe. -type SingleAddressStore interface { - Height() uint64 - UnspentSiacoinElements(matured bool) ([]SiacoinElement, error) - Transactions(before, since time.Time, offset, limit int) ([]Transaction, error) - RecordWalletMetric(ctx context.Context, metrics ...api.WalletMetric) error -} - -// A TransactionPool contains transactions that have not yet been included in a -// block. -type TransactionPool interface { - ContainsElement(id types.Hash256) bool -} - -// A SingleAddressWallet is a hot wallet that manages the outputs controlled by -// a single address. -type SingleAddressWallet struct { - log *zap.SugaredLogger - priv types.PrivateKey - addr types.Address - store SingleAddressStore - usedUTXOExpiry time.Duration - - // for building transactions - mu sync.Mutex - lastUsed map[types.Hash256]time.Time - // tpoolTxns maps a transaction set ID to the transactions in that set - tpoolTxns map[types.Hash256][]Transaction - // tpoolUtxos maps a siacoin output ID to its corresponding siacoin - // element. It is used to track siacoin outputs that are currently in - // the transaction pool. - tpoolUtxos map[types.SiacoinOutputID]SiacoinElement - // tpoolSpent is a set of siacoin output IDs that are currently in the - // transaction pool. - tpoolSpent map[types.SiacoinOutputID]bool -} - -// PrivateKey returns the private key of the wallet. -func (w *SingleAddressWallet) PrivateKey() types.PrivateKey { - return w.priv -} - -// Address returns the address of the wallet. -func (w *SingleAddressWallet) Address() types.Address { - return w.addr -} - -// Balance returns the balance of the wallet. 
-func (w *SingleAddressWallet) Balance() (spendable, confirmed, unconfirmed types.Currency, _ error) { - sces, err := w.store.UnspentSiacoinElements(true) - if err != nil { - return types.Currency{}, types.Currency{}, types.Currency{}, err - } - w.mu.Lock() - defer w.mu.Unlock() - for _, sce := range sces { - if !w.isOutputUsed(sce.ID) { - spendable = spendable.Add(sce.Value) - } - confirmed = confirmed.Add(sce.Value) - } - for _, sco := range w.tpoolUtxos { - if !w.isOutputUsed(sco.ID) { - unconfirmed = unconfirmed.Add(sco.Value) - } - } - return -} - -func (w *SingleAddressWallet) Height() uint64 { - return w.store.Height() -} - -// UnspentOutputs returns the set of unspent Siacoin outputs controlled by the -// wallet. -func (w *SingleAddressWallet) UnspentOutputs() ([]SiacoinElement, error) { - sces, err := w.store.UnspentSiacoinElements(false) - if err != nil { - return nil, err - } - w.mu.Lock() - defer w.mu.Unlock() - filtered := sces[:0] - for _, sce := range sces { - if !w.isOutputUsed(sce.ID) { - filtered = append(filtered, sce) - } - } - return filtered, nil -} - -// Transactions returns up to max transactions relevant to the wallet that have -// a timestamp later than since. -func (w *SingleAddressWallet) Transactions(before, since time.Time, offset, limit int) ([]Transaction, error) { - return w.store.Transactions(before, since, offset, limit) -} - -// FundTransaction adds siacoin inputs worth at least the requested amount to -// the provided transaction. A change output is also added, if necessary. The -// inputs will not be available to future calls to FundTransaction unless -// ReleaseInputs is called or enough time has passed. -func (w *SingleAddressWallet) FundTransaction(cs consensus.State, txn *types.Transaction, amount types.Currency, useUnconfirmedTxns bool) ([]types.Hash256, error) { - if amount.IsZero() { - return nil, nil - } - w.mu.Lock() - defer w.mu.Unlock() - - // fetch all unspent siacoin elements - utxos, err := w.store.UnspentSiacoinElements(false) - if err != nil { - return nil, err - } - - // desc sort - sort.Slice(utxos, func(i, j int) bool { - return utxos[i].Value.Cmp(utxos[j].Value) > 0 - }) - - // add all unconfirmed outputs to the end of the slice as a last resort - if useUnconfirmedTxns { - var tpoolUtxos []SiacoinElement - for _, sco := range w.tpoolUtxos { - tpoolUtxos = append(tpoolUtxos, sco) - } - // desc sort - sort.Slice(tpoolUtxos, func(i, j int) bool { - return tpoolUtxos[i].Value.Cmp(tpoolUtxos[j].Value) > 0 - }) - utxos = append(utxos, tpoolUtxos...) 
- } - - // remove locked and spent outputs - usableUTXOs := utxos[:0] - for _, sce := range utxos { - if w.isOutputUsed(sce.ID) { - continue - } - usableUTXOs = append(usableUTXOs, sce) - } - - // fund the transaction using the largest utxos first - var selected []SiacoinElement - var inputSum types.Currency - for i, sce := range usableUTXOs { - if inputSum.Cmp(amount) >= 0 { - usableUTXOs = usableUTXOs[i:] - break - } - selected = append(selected, sce) - inputSum = inputSum.Add(sce.Value) - } - - // if the transaction can't be funded, return an error - if inputSum.Cmp(amount) < 0 { - return nil, fmt.Errorf("%w: inputSum: %v, amount: %v", ErrInsufficientBalance, inputSum.String(), amount.String()) - } - - // check if remaining utxos should be defragged - txnInputs := len(txn.SiacoinInputs) + len(selected) - if len(usableUTXOs) > transactionDefragThreshold && txnInputs < maxInputsForDefrag { - // add the smallest utxos to the transaction - defraggable := usableUTXOs - if len(defraggable) > maxDefragUTXOs { - defraggable = defraggable[len(defraggable)-maxDefragUTXOs:] - } - for i := len(defraggable) - 1; i >= 0; i-- { - if txnInputs >= maxInputsForDefrag { - break - } - - sce := defraggable[i] - selected = append(selected, sce) - inputSum = inputSum.Add(sce.Value) - txnInputs++ - } - } - - // add a change output if necessary - if inputSum.Cmp(amount) > 0 { - txn.SiacoinOutputs = append(txn.SiacoinOutputs, types.SiacoinOutput{ - Value: inputSum.Sub(amount), - Address: w.addr, - }) - } - - toSign := make([]types.Hash256, len(selected)) - for i, sce := range selected { - txn.SiacoinInputs = append(txn.SiacoinInputs, types.SiacoinInput{ - ParentID: types.SiacoinOutputID(sce.ID), - UnlockConditions: types.StandardUnlockConditions(w.priv.PublicKey()), - }) - toSign[i] = types.Hash256(sce.ID) - w.lastUsed[sce.ID] = time.Now() - } - - return toSign, nil -} - -// ReleaseInputs is a helper function that releases the inputs of txn for use in -// other transactions. It should only be called on transactions that are invalid -// or will never be broadcast. -func (w *SingleAddressWallet) ReleaseInputs(txns ...types.Transaction) { - w.mu.Lock() - defer w.mu.Unlock() - w.releaseInputs(txns...) -} - -func (w *SingleAddressWallet) releaseInputs(txns ...types.Transaction) { - for _, txn := range txns { - for _, in := range txn.SiacoinInputs { - delete(w.lastUsed, types.Hash256(in.ParentID)) - } - } -} - -// SignTransaction adds a signature to each of the specified inputs. -func (w *SingleAddressWallet) SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error { - for _, id := range toSign { - ts := types.TransactionSignature{ - ParentID: id, - CoveredFields: cf, - PublicKeyIndex: 0, - } - var h types.Hash256 - if cf.WholeTransaction { - h = cs.WholeSigHash(*txn, ts.ParentID, ts.PublicKeyIndex, ts.Timelock, cf.Signatures) - } else { - h = cs.PartialSigHash(*txn, cf) - } - sig := w.priv.SignHash(h) - ts.Signature = sig[:] - txn.Signatures = append(txn.Signatures, ts) - } - return nil -} - -// Redistribute returns a transaction that redistributes money in the wallet by -// selecting a minimal set of inputs to cover the creation of the requested -// outputs. It also returns a list of output IDs that need to be signed. 
-func (w *SingleAddressWallet) Redistribute(cs consensus.State, outputs int, amount, feePerByte types.Currency, pool []types.Transaction) ([]types.Transaction, []types.Hash256, error) { - w.mu.Lock() - defer w.mu.Unlock() - - // build map of inputs currently in the tx pool - inPool := make(map[types.Hash256]bool) - for _, ptxn := range pool { - for _, in := range ptxn.SiacoinInputs { - inPool[types.Hash256(in.ParentID)] = true - } - } - - // fetch unspent transaction outputs - utxos, err := w.store.UnspentSiacoinElements(false) - if err != nil { - return nil, nil, err - } - - // check whether a redistribution is necessary, adjust number of desired - // outputs accordingly - for _, sce := range utxos { - inUse := w.isOutputUsed(sce.ID) || inPool[sce.ID] - matured := cs.Index.Height >= sce.MaturityHeight - sameValue := sce.Value.Equals(amount) - if !inUse && matured && sameValue { - outputs-- - } - } - if outputs <= 0 { - return nil, nil, nil - } - - // desc sort - sort.Slice(utxos, func(i, j int) bool { - return utxos[i].Value.Cmp(utxos[j].Value) > 0 - }) - - // prepare all outputs - var txns []types.Transaction - var toSign []types.Hash256 - - for outputs > 0 { - var txn types.Transaction - for i := 0; i < outputs && i < redistributeBatchSize; i++ { - txn.SiacoinOutputs = append(txn.SiacoinOutputs, types.SiacoinOutput{ - Value: amount, - Address: w.Address(), - }) - } - outputs -= len(txn.SiacoinOutputs) - - // estimate the fees - outputFees := feePerByte.Mul64(uint64(len(encoding.Marshal(txn.SiacoinOutputs)))) - feePerInput := feePerByte.Mul64(BytesPerInput) - - // collect outputs that cover the total amount - var inputs []SiacoinElement - want := amount.Mul64(uint64(len(txn.SiacoinOutputs))) - var amtInUse, amtSameValue, amtNotMatured types.Currency - for _, sce := range utxos { - inUse := w.isOutputUsed(sce.ID) || inPool[sce.ID] - matured := cs.Index.Height >= sce.MaturityHeight - sameValue := sce.Value.Equals(amount) - if inUse { - amtInUse = amtInUse.Add(sce.Value) - continue - } else if sameValue { - amtSameValue = amtSameValue.Add(sce.Value) - continue - } else if !matured { - amtNotMatured = amtNotMatured.Add(sce.Value) - continue - } - - inputs = append(inputs, sce) - fee := feePerInput.Mul64(uint64(len(inputs))).Add(outputFees) - if SumOutputs(inputs).Cmp(want.Add(fee)) > 0 { - break - } - } - - // not enough outputs found - fee := feePerInput.Mul64(uint64(len(inputs))).Add(outputFees) - if sumOut := SumOutputs(inputs); sumOut.Cmp(want.Add(fee)) < 0 { - // in case of an error we need to free all inputs - w.releaseInputs(txns...) 
- return nil, nil, fmt.Errorf("%w: inputs %v < needed %v + txnFee %v (usable: %v, inUse: %v, sameValue: %v, notMatured: %v)", - ErrInsufficientBalance, sumOut.String(), want.String(), fee.String(), sumOut.String(), amtInUse.String(), amtSameValue.String(), amtNotMatured.String()) - } - - // set the miner fee - txn.MinerFees = []types.Currency{fee} - - // add the change output - change := SumOutputs(inputs).Sub(want.Add(fee)) - if !change.IsZero() { - txn.SiacoinOutputs = append(txn.SiacoinOutputs, types.SiacoinOutput{ - Value: change, - Address: w.addr, - }) - } - - // add the inputs - for _, sce := range inputs { - txn.SiacoinInputs = append(txn.SiacoinInputs, types.SiacoinInput{ - ParentID: types.SiacoinOutputID(sce.ID), - UnlockConditions: StandardUnlockConditions(w.priv.PublicKey()), - }) - toSign = append(toSign, sce.ID) - w.lastUsed[sce.ID] = time.Now() - } - - txns = append(txns, txn) - } - - return txns, toSign, nil -} - -func (w *SingleAddressWallet) isOutputUsed(id types.Hash256) bool { - inPool := w.tpoolSpent[types.SiacoinOutputID(id)] - lastUsed := w.lastUsed[id] - if w.usedUTXOExpiry == 0 { - return !lastUsed.IsZero() || inPool - } - return time.Since(lastUsed) <= w.usedUTXOExpiry || inPool -} - -// ProcessConsensusChange implements modules.ConsensusSetSubscriber. -func (w *SingleAddressWallet) ProcessConsensusChange(cc modules.ConsensusChange) { - // only record when we are synced - if !cc.Synced { - return - } - - // apply sane timeout - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - // fetch balance - spendable, confirmed, unconfirmed, err := w.Balance() - if err != nil { - w.log.Errorf("failed to fetch wallet balance, err: %v", err) - return - } - - // record wallet metric - if err := w.store.RecordWalletMetric(ctx, api.WalletMetric{ - Timestamp: api.TimeNow(), - Confirmed: confirmed, - Unconfirmed: unconfirmed, - Spendable: spendable, - }); err != nil { - w.log.Errorf("failed to record wallet metric, err: %v", err) - return - } -} - -// ReceiveUpdatedUnconfirmedTransactions implements modules.TransactionPoolSubscriber. 
-func (w *SingleAddressWallet) ReceiveUpdatedUnconfirmedTransactions(diff *modules.TransactionPoolDiff) { - siacoinOutputs := make(map[types.SiacoinOutputID]SiacoinElement) - utxos, err := w.store.UnspentSiacoinElements(false) - if err != nil { - return - } - for _, output := range utxos { - siacoinOutputs[types.SiacoinOutputID(output.ID)] = output - } - - w.mu.Lock() - defer w.mu.Unlock() - - for id, output := range w.tpoolUtxos { - siacoinOutputs[id] = output - } - - for _, txnsetID := range diff.RevertedTransactions { - txns, ok := w.tpoolTxns[types.Hash256(txnsetID)] - if !ok { - continue - } - for _, txn := range txns { - for _, sci := range txn.Raw.SiacoinInputs { - delete(w.tpoolSpent, sci.ParentID) - } - for i := range txn.Raw.SiacoinOutputs { - delete(w.tpoolUtxos, txn.Raw.SiacoinOutputID(i)) - } - } - delete(w.tpoolTxns, types.Hash256(txnsetID)) - } - - currentHeight := w.store.Height() - - for _, txnset := range diff.AppliedTransactions { - var relevantTxns []Transaction - - txnLoop: - for _, stxn := range txnset.Transactions { - var relevant bool - var txn types.Transaction - convertToCore(stxn, &txn) - processed := Transaction{ - ID: txn.ID(), - Index: types.ChainIndex{ - Height: currentHeight + 1, - }, - Raw: txn, - Timestamp: time.Now(), - } - for _, sci := range txn.SiacoinInputs { - if sci.UnlockConditions.UnlockHash() != w.addr { - continue - } - relevant = true - w.tpoolSpent[sci.ParentID] = true - - output, ok := siacoinOutputs[sci.ParentID] - if !ok { - // note: happens during deep reorgs. Possibly a race - // condition in siad. Log and skip. - w.log.Info("tpool transaction unknown utxo", zap.Stringer("outputID", sci.ParentID), zap.Stringer("txnID", txn.ID())) - continue txnLoop - } - processed.Outflow = processed.Outflow.Add(output.Value) - } - - for i, sco := range txn.SiacoinOutputs { - if sco.Address != w.addr { - continue - } - relevant = true - outputID := txn.SiacoinOutputID(i) - processed.Inflow = processed.Inflow.Add(sco.Value) - sce := SiacoinElement{ - ID: types.Hash256(outputID), - SiacoinOutput: sco, - } - siacoinOutputs[outputID] = sce - w.tpoolUtxos[outputID] = sce - } - - if relevant { - relevantTxns = append(relevantTxns, processed) - } - } - - if len(relevantTxns) != 0 { - w.tpoolTxns[types.Hash256(txnset.ID)] = relevantTxns - } - } -} - -// SumOutputs returns the total value of the supplied outputs. -func SumOutputs(outputs []SiacoinElement) (sum types.Currency) { - for _, o := range outputs { - sum = sum.Add(o.Value) - } - return -} - -// NewSingleAddressWallet returns a new SingleAddressWallet using the provided private key and store. -func NewSingleAddressWallet(priv types.PrivateKey, store SingleAddressStore, usedUTXOExpiry time.Duration, log *zap.SugaredLogger) *SingleAddressWallet { - return &SingleAddressWallet{ - priv: priv, - addr: StandardAddress(priv.PublicKey()), - store: store, - lastUsed: make(map[types.Hash256]time.Time), - usedUTXOExpiry: usedUTXOExpiry, - tpoolTxns: make(map[types.Hash256][]Transaction), - tpoolUtxos: make(map[types.SiacoinOutputID]SiacoinElement), - tpoolSpent: make(map[types.SiacoinOutputID]bool), - log: log.Named("wallet"), - } -} - -// convertToCore converts a siad type to an equivalent core type. 
-func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) { - var buf bytes.Buffer - siad.MarshalSia(&buf) - d := types.NewBufDecoder(buf.Bytes()) - core.DecodeFrom(d) - if d.Err() != nil { - panic(d.Err()) - } -} diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go deleted file mode 100644 index 0538d50af..000000000 --- a/wallet/wallet_test.go +++ /dev/null @@ -1,195 +0,0 @@ -package wallet - -import ( - "context" - "strings" - "testing" - "time" - - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/renterd/api" - "go.uber.org/zap" - "lukechampine.com/frand" -) - -// mockStore implements wallet.SingleAddressStore and allows to manipulate the -// wallet's utxos -type mockStore struct { - utxos []SiacoinElement -} - -func (s *mockStore) Balance() (types.Currency, error) { return types.ZeroCurrency, nil } -func (s *mockStore) Height() uint64 { return 0 } -func (s *mockStore) UnspentSiacoinElements(bool) ([]SiacoinElement, error) { - return s.utxos, nil -} -func (s *mockStore) Transactions(before, since time.Time, offset, limit int) ([]Transaction, error) { - return nil, nil -} -func (s *mockStore) RecordWalletMetric(ctx context.Context, metrics ...api.WalletMetric) error { - return nil -} - -var cs = consensus.State{ - Index: types.ChainIndex{ - Height: 1, - ID: types.BlockID{}, - }, -} - -// TestWalletRedistribute is a small unit test that covers the functionality of -// the 'Redistribute' method on the wallet. -func TestWalletRedistribute(t *testing.T) { - oneSC := types.Siacoins(1) - - // create a wallet with one output - priv := types.GeneratePrivateKey() - pub := priv.PublicKey() - utxo := SiacoinElement{ - types.SiacoinOutput{ - Value: oneSC.Mul64(20), - Address: StandardAddress(pub), - }, - randomOutputID(), - 0, - } - s := &mockStore{utxos: []SiacoinElement{utxo}} - w := NewSingleAddressWallet(priv, s, 0, zap.NewNop().Sugar()) - - numOutputsWithValue := func(v types.Currency) (c uint64) { - utxos, _ := w.UnspentOutputs() - for _, utxo := range utxos { - if utxo.Value.Equals(v) { - c++ - } - } - return - } - - applyTxn := func(txn types.Transaction) { - for _, input := range txn.SiacoinInputs { - for i, utxo := range s.utxos { - if input.ParentID == types.SiacoinOutputID(utxo.ID) { - s.utxos[i] = s.utxos[len(s.utxos)-1] - s.utxos = s.utxos[:len(s.utxos)-1] - } - } - } - for _, output := range txn.SiacoinOutputs { - s.utxos = append(s.utxos, SiacoinElement{output, randomOutputID(), 0}) - } - } - - // assert number of outputs - if utxos, err := w.UnspentOutputs(); err != nil { - t.Fatal(err) - } else if len(utxos) != 1 { - t.Fatalf("unexpected number of outputs, %v != 1", len(utxos)) - } - - // split into 3 outputs of 6SC each - amount := oneSC.Mul64(6) - if txns, _, err := w.Redistribute(cs, 3, amount, types.NewCurrency64(1), nil); err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("unexpected number of txns, %v != 1", len(txns)) - } else { - applyTxn(txns[0]) - } - - // assert number of outputs - if utxos, err := w.UnspentOutputs(); err != nil { - t.Fatal(err) - } else if len(s.utxos) != 4 { - t.Fatalf("unexpected number of outputs, %v != 4", len(utxos)) - } - - // assert number of outputs that hold 6SC - if cnt := numOutputsWithValue(amount); cnt != 3 { - t.Fatalf("unexpected number of 6SC outputs, %v != 3", cnt) - } - - // split into 3 outputs of 7SC each, expect this to fail - _, _, err := w.Redistribute(cs, 3, oneSC.Mul64(7), types.NewCurrency64(1), nil) - if err == nil || !strings.Contains(err.Error(), "insufficient 
balance") { - t.Fatalf("unexpected err: '%v'", err) - } - - // split into 2 outputs of 9SC - amount = oneSC.Mul64(9) - if txns, _, err := w.Redistribute(cs, 2, amount, types.NewCurrency64(1), nil); err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("unexpected number of txns, %v != 1", len(txns)) - } else { - applyTxn(txns[0]) - } - - // assert number of outputs - if utxos, err := w.UnspentOutputs(); err != nil { - t.Fatal(err) - } else if len(s.utxos) != 3 { - t.Fatalf("unexpected number of outputs, %v != 3", len(utxos)) - } - - // assert number of outputs that hold 9SC - if cnt := numOutputsWithValue(amount); cnt != 2 { - t.Fatalf("unexpected number of 9SC outputs, %v != 2", cnt) - } - - // split into 5 outputs of 3SC - amount = oneSC.Mul64(3) - if txns, _, err := w.Redistribute(cs, 5, amount, types.NewCurrency64(1), nil); err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("unexpected number of txns, %v != 1", len(txns)) - } else { - applyTxn(txns[0]) - } - - // assert number of outputs that hold 3SC - if cnt := numOutputsWithValue(amount); cnt != 5 { - t.Fatalf("unexpected number of 3SC outputs, %v != 5", cnt) - } - - // split into 4 outputs of 3SC - this should be a no-op - if _, _, err := w.Redistribute(cs, 4, amount, types.NewCurrency64(1), nil); err != nil { - t.Fatal(err) - } - - // split into 6 outputs of 3SC - if txns, _, err := w.Redistribute(cs, 6, amount, types.NewCurrency64(1), nil); err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("unexpected number of txns, %v != 1", len(txns)) - } else { - applyTxn(txns[0]) - } - - // assert number of outputs that hold 3SC - if cnt := numOutputsWithValue(amount); cnt != 6 { - t.Fatalf("unexpected number of 3SC outputs, %v != 6", cnt) - } - - // split into 2 times the redistributeBatchSize - amount = oneSC.Div64(10) - if txns, _, err := w.Redistribute(cs, 2*redistributeBatchSize, amount, types.NewCurrency64(1), nil); err != nil { - t.Fatal(err) - } else if len(txns) != 2 { - t.Fatalf("unexpected number of txns, %v != 2", len(txns)) - } else { - applyTxn(txns[0]) - applyTxn(txns[1]) - } - - // assert number of outputs that hold 0.1SC - if cnt := numOutputsWithValue(amount); cnt != 2*redistributeBatchSize { - t.Fatalf("unexpected number of 0.1SC outputs, %v != 20", cnt) - } -} - -func randomOutputID() (t types.Hash256) { - frand.Read(t[:]) - return -} diff --git a/webhooks/webhooks.go b/webhooks/webhooks.go index 20bf94381..0f1eb636f 100644 --- a/webhooks/webhooks.go +++ b/webhooks/webhooks.go @@ -13,7 +13,6 @@ import ( "time" "go.uber.org/zap" - "gorm.io/gorm" ) var ErrWebhookNotFound = errors.New("Webhook not found") @@ -128,18 +127,10 @@ func (m *Manager) BroadcastAction(_ context.Context, event Event) error { return nil } -func (m *Manager) Close() error { - m.shutdownCtxCancel() - m.wg.Wait() - return nil -} - func (m *Manager) Delete(ctx context.Context, wh Webhook) error { m.mu.Lock() defer m.mu.Unlock() - if err := m.store.DeleteWebhook(ctx, wh); errors.Is(err, gorm.ErrRecordNotFound) { - return ErrWebhookNotFound - } else if err != nil { + if err := m.store.DeleteWebhook(ctx, wh); err != nil { return err } delete(m.webhooks, wh.String()) @@ -191,6 +182,23 @@ func (m *Manager) Register(ctx context.Context, wh Webhook) error { return nil } +func (m *Manager) Shutdown(ctx context.Context) error { + m.shutdownCtxCancel() + + waitChan := make(chan struct{}) + go func() { + m.wg.Wait() + close(waitChan) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-waitChan: + } + 
return nil
+}
+
 func (a Event) String() string {
 	return a.Module + "." + a.Event
 }
@@ -225,10 +233,10 @@ func (w Webhook) String() string {
 	return fmt.Sprintf("%v.%v.%v", w.URL, w.Module, w.Event)
 }
 
-func NewManager(logger *zap.SugaredLogger, store WebhookStore) (*Manager, error) {
+func NewManager(store WebhookStore, logger *zap.Logger) (*Manager, error) {
 	shutdownCtx, shutdownCtxCancel := context.WithCancel(context.Background())
 	m := &Manager{
-		logger: logger.Named("webhooks"),
+		logger: logger.Named("webhooks").Sugar(),
 		store:  store,
 
 		shutdownCtx: shutdownCtx,
diff --git a/worker/accounts.go b/worker/accounts.go
new file mode 100644
index 000000000..76a18d37e
--- /dev/null
+++ b/worker/accounts.go
@@ -0,0 +1,167 @@
+package worker
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math/big"
+	"time"
+
+	rhpv3 "go.sia.tech/core/rhp/v3"
+	"go.sia.tech/core/types"
+	"go.sia.tech/renterd/api"
+	rhp3 "go.sia.tech/renterd/internal/rhp/v3"
+)
+
+const (
+	// accountLockingDuration is the time for which an account lock remains
+	// reserved on the bus after locking it.
+	accountLockingDuration = 30 * time.Second
+)
+
+type (
+	// accounts stores the balance and other metrics of accounts that the
+	// worker maintains with a host.
+	accounts struct {
+		as  AccountStore
+		key types.PrivateKey
+	}
+
+	// account contains information regarding a specific account of the
+	// worker.
+	account struct {
+		as   AccountStore
+		id   rhpv3.Account
+		key  types.PrivateKey
+		host types.PublicKey
+	}
+)
+
+// ForHost returns an account to use for a given host. If the account
+// doesn't exist, a new one is created.
+func (a *accounts) ForHost(hk types.PublicKey) *account {
+	accountID := rhpv3.Account(a.deriveAccountKey(hk).PublicKey())
+	return &account{
+		as:   a.as,
+		id:   accountID,
+		key:  a.key,
+		host: hk,
+	}
+}
+
+// deriveAccountKey derives an account key for a given host and worker.
+// Each worker has its own account for a given host. That makes concurrency
+// around keeping track of an account's balance and refilling it a lot easier
+// in a multi-worker setup.
+func (a *accounts) deriveAccountKey(hostKey types.PublicKey) types.PrivateKey {
+	index := byte(0) // not used yet but can be used to derive more than 1 account per host
+
+	// Append the host for which to create it and the index to the
+	// corresponding sub-key.
+	subKey := a.key
+	data := make([]byte, 0, len(subKey)+len(hostKey)+1)
+	data = append(data, subKey[:]...)
+	data = append(data, hostKey[:]...)
+	data = append(data, index)
+
+	seed := types.HashBytes(data)
+	pk := types.NewPrivateKeyFromSeed(seed[:])
+	for i := range seed {
+		seed[i] = 0
+	}
+	return pk
+}
+
+// Balance returns the account balance.
+func (a *account) Balance(ctx context.Context) (balance types.Currency, err error) {
+	err = withAccountLock(ctx, a.as, a.id, a.host, false, func(account api.Account) error {
+		balance = types.NewCurrency(account.Balance.Uint64(), new(big.Int).Rsh(account.Balance, 64).Uint64())
+		return nil
+	})
+	return
+}
+
+// WithDeposit increases the balance of an account by the amount returned by
+// amtFn if amtFn doesn't return an error.
+func (a *account) WithDeposit(ctx context.Context, amtFn func() (types.Currency, error)) error {
+	return withAccountLock(ctx, a.as, a.id, a.host, false, func(_ api.Account) error {
+		amt, err := amtFn()
+		if err != nil {
+			return err
+		}
+		return a.as.AddBalance(ctx, a.id, a.host, amt.Big())
+	})
+}
+
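For orientation, a hedged sketch of how a caller is expected to drive `WithWithdrawal` (defined further below): the callback returns the amount actually spent, and that amount is withdrawn even when the call fails, since hosts do not refund partially executed operations. The `fundedRPC` and `payHost` names and the cost are hypothetical:

```go
// fundedRPC is an illustrative sketch only (not part of this diff); it would
// sit next to the account helpers above. payHost stands in for an arbitrary
// RPC that costs money on the host side.
func fundedRPC(ctx context.Context, acc *account, payHost func(context.Context) error) error {
	return acc.WithWithdrawal(ctx, func() (types.Currency, error) {
		cost := types.Siacoins(1).Div64(1000) // assumed price of the RPC
		if err := payHost(ctx); err != nil {
			// report the cost anyway: the host kept whatever was spent, so
			// the account balance has to reflect it even on failure
			return cost, err
		}
		return cost, nil
	})
}
```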
+// WithSync syncs an account's balance with the bus. To do so, the account is
+// locked while the balance is fetched through balanceFn.
+func (a *account) WithSync(ctx context.Context, balanceFn func() (types.Currency, error)) error {
+	return withAccountLock(ctx, a.as, a.id, a.host, true, func(_ api.Account) error {
+		balance, err := balanceFn()
+		if err != nil {
+			return err
+		}
+		return a.as.SetBalance(ctx, a.id, a.host, balance.Big())
+	})
+}
+
+// WithWithdrawal decreases the balance of an account by the amount returned by
+// amtFn. The amount is still withdrawn if amtFn returns an error since some
+// costs are non-refundable.
+func (a *account) WithWithdrawal(ctx context.Context, amtFn func() (types.Currency, error)) error {
+	return withAccountLock(ctx, a.as, a.id, a.host, false, func(account api.Account) error {
+		// return early if the account needs to sync
+		if account.RequiresSync {
+			return fmt.Errorf("%w; account requires resync", rhp3.ErrBalanceInsufficient)
+		}
+
+		// return early if our account is not funded
+		if account.Balance.Cmp(big.NewInt(0)) <= 0 {
+			return rhp3.ErrBalanceInsufficient
+		}
+
+		// execute amtFn
+		amt, err := amtFn()
+
+		// in case of an insufficient balance, we schedule a sync
+		if rhp3.IsBalanceInsufficient(err) {
+			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+			err = errors.Join(err, a.as.ScheduleSync(ctx, a.id, a.host))
+			cancel()
+		}
+
+		// if an amount was returned, we withdraw it
+		if !amt.IsZero() {
+			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+			err = errors.Join(err, a.as.AddBalance(ctx, a.id, a.host, new(big.Int).Neg(amt.Big())))
+			cancel()
+		}
+
+		return err
+	})
+}
+
+func (w *Worker) initAccounts(as AccountStore) {
+	if w.accounts != nil {
+		panic("accounts already initialized") // developer error
+	}
+	w.accounts = &accounts{
+		as:  as,
+		key: w.deriveSubKey("accountkey"),
+	}
+}
+
+func withAccountLock(ctx context.Context, as AccountStore, id rhpv3.Account, hk types.PublicKey, exclusive bool, fn func(a api.Account) error) error {
+	acc, lockID, err := as.LockAccount(ctx, id, hk, exclusive, accountLockingDuration)
+	if err != nil {
+		return err
+	}
+	err = fn(acc)
+
+	// unlock account
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	_ = as.UnlockAccount(ctx, acc.ID, lockID) // ignore error
+	cancel()
+
+	return err
+}
diff --git a/worker/client/client.go b/worker/client/client.go
index c1ab8a70e..9abac4d0e 100644
--- a/worker/client/client.go
+++ b/worker/client/client.go
@@ -269,8 +269,8 @@ func (c *Client) UploadStats() (resp api.UploadStatsResponse, err error) {
 	return
 }
 
-// RegisterEvent register an event.
-func (c *Client) RegisterEvent(ctx context.Context, e webhooks.Event) (err error) {
+// NotifyEvent notifies the worker of an event.
+func (c *Client) NotifyEvent(ctx context.Context, e webhooks.Event) (err error) { err = c.c.WithContext(ctx).POST("/events", e, nil) return } diff --git a/worker/contract_lock.go b/worker/contract_lock.go index f5115d37f..4f9d147ba 100644 --- a/worker/contract_lock.go +++ b/worker/contract_lock.go @@ -51,7 +51,7 @@ func newContractLock(ctx context.Context, fcid types.FileContractID, lockID uint return cl } -func (w *worker) acquireContractLock(ctx context.Context, fcid types.FileContractID, priority int) (_ *contractLock, err error) { +func (w *Worker) acquireContractLock(ctx context.Context, fcid types.FileContractID, priority int) (_ *contractLock, err error) { lockID, err := w.bus.AcquireContract(ctx, fcid, priority, w.contractLockingDuration) if err != nil { return nil, err @@ -59,7 +59,7 @@ func (w *worker) acquireContractLock(ctx context.Context, fcid types.FileContrac return newContractLock(w.shutdownCtx, fcid, lockID, w.contractLockingDuration, w.bus, w.logger), nil } -func (w *worker) withContractLock(ctx context.Context, fcid types.FileContractID, priority int, fn func() error) error { +func (w *Worker) withContractLock(ctx context.Context, fcid types.FileContractID, priority int, fn func() error) error { contractLock, err := w.acquireContractLock(ctx, fcid, priority) if err != nil { return err diff --git a/worker/download.go b/worker/download.go index 2e69f7375..e1c54771d 100644 --- a/worker/download.go +++ b/worker/download.go @@ -13,9 +13,9 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" - "go.sia.tech/renterd/stats" "go.uber.org/zap" ) @@ -39,8 +39,8 @@ type ( maxOverdrive uint64 overdriveTimeout time.Duration - statsOverdrivePct *stats.DataPoints - statsSlabDownloadSpeedBytesPerMS *stats.DataPoints + statsOverdrivePct *utils.DataPoints + statsSlabDownloadSpeedBytesPerMS *utils.DataPoints shutdownCtx context.Context @@ -127,27 +127,26 @@ type ( } ) -func (w *worker) initDownloadManager(maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.SugaredLogger) { +func (w *Worker) initDownloadManager(maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.Logger) { if w.downloadManager != nil { panic("download manager already initialized") // developer error } - - mm := newMemoryManager(logger.Named("memorymanager"), maxMemory) - w.downloadManager = newDownloadManager(w.shutdownCtx, w, mm, w.bus, maxOverdrive, overdriveTimeout, logger) + w.downloadManager = newDownloadManager(w.shutdownCtx, w, w.bus, maxMemory, maxOverdrive, overdriveTimeout, logger) } -func newDownloadManager(ctx context.Context, hm HostManager, mm MemoryManager, os ObjectStore, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.SugaredLogger) *downloadManager { +func newDownloadManager(ctx context.Context, hm HostManager, os ObjectStore, maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.Logger) *downloadManager { + logger = logger.Named("downloadmanager") return &downloadManager{ hm: hm, - mm: mm, + mm: newMemoryManager(maxMemory, logger), os: os, - logger: logger, + logger: logger.Sugar(), maxOverdrive: maxOverdrive, overdriveTimeout: overdriveTimeout, - statsOverdrivePct: stats.NoDecay(), - statsSlabDownloadSpeedBytesPerMS: stats.NoDecay(), + statsOverdrivePct: utils.NewDataPoints(0), + statsSlabDownloadSpeedBytesPerMS: utils.NewDataPoints(0), shutdownCtx: ctx, @@ -759,11 
+758,11 @@ loop: } // handle lost sectors - if isSectorNotFound(resp.err) { + if rhp3.IsSectorNotFound(resp.err) { if err := s.mgr.os.DeleteHostSector(ctx, resp.req.host.PublicKey(), resp.req.root); err != nil { s.mgr.logger.Errorw("failed to mark sector as lost", "hk", resp.req.host.PublicKey(), "root", resp.req.root, zap.Error(err)) } - } else if isPriceTableGouging(resp.err) && s.overpay && !resp.req.overpay { + } else if rhp3.IsPriceTableGouging(resp.err) && s.overpay && !resp.req.overpay { resp.req.overpay = true // ensures we don't retry the same request over and over again gouging = append(gouging, resp.req) } diff --git a/worker/downloader.go b/worker/downloader.go index 46dac61e3..9720237a1 100644 --- a/worker/downloader.go +++ b/worker/downloader.go @@ -10,7 +10,8 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" - "go.sia.tech/renterd/stats" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" + "go.sia.tech/renterd/internal/utils" ) const ( @@ -26,8 +27,8 @@ type ( downloader struct { host Host - statsDownloadSpeedBytesPerMS *stats.DataPoints // keep track of this separately for stats (no decay is applied) - statsSectorDownloadEstimateInMS *stats.DataPoints + statsDownloadSpeedBytesPerMS *utils.DataPoints // keep track of this separately for stats (no decay is applied) + statsSectorDownloadEstimateInMS *utils.DataPoints signalWorkChan chan struct{} shutdownCtx context.Context @@ -44,8 +45,8 @@ func newDownloader(ctx context.Context, host Host) *downloader { return &downloader{ host: host, - statsSectorDownloadEstimateInMS: stats.Default(), - statsDownloadSpeedBytesPerMS: stats.NoDecay(), + statsSectorDownloadEstimateInMS: utils.NewDataPoints(10 * time.Minute), + statsDownloadSpeedBytesPerMS: utils.NewDataPoints(0), signalWorkChan: make(chan struct{}, 1), shutdownCtx: ctx, @@ -295,10 +296,10 @@ func (d *downloader) trackFailure(err error) { return } - if isBalanceInsufficient(err) || - isPriceTableExpired(err) || - isPriceTableNotFound(err) || - isSectorNotFound(err) { + if rhp3.IsBalanceInsufficient(err) || + rhp3.IsPriceTableExpired(err) || + rhp3.IsPriceTableNotFound(err) || + rhp3.IsSectorNotFound(err) { return // host is not to blame for these errors } diff --git a/worker/gouging.go b/worker/gouging.go index 9566e73cc..eccc5785e 100644 --- a/worker/gouging.go +++ b/worker/gouging.go @@ -2,509 +2,44 @@ package worker import ( "context" - "errors" "fmt" - "time" - rhpv2 "go.sia.tech/core/rhp/v2" - rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/gouging" ) const ( keyGougingChecker contextKey = "GougingChecker" - - // maxBaseRPCPriceVsBandwidth is the max ratio for sane pricing between the - // MinBaseRPCPrice and the MinDownloadBandwidthPrice. This ensures that 1 - // million base RPC charges are at most 1% of the cost to download 4TB. This - // ratio should be used by checking that the MinBaseRPCPrice is less than or - // equal to the MinDownloadBandwidthPrice multiplied by this constant - maxBaseRPCPriceVsBandwidth = uint64(40e3) - - // maxSectorAccessPriceVsBandwidth is the max ratio for sane pricing between - // the MinSectorAccessPrice and the MinDownloadBandwidthPrice. This ensures - // that 1 million base accesses are at most 10% of the cost to download 4TB. 
- // This ratio should be used by checking that the MinSectorAccessPrice is - // less than or equal to the MinDownloadBandwidthPrice multiplied by this - // constant - maxSectorAccessPriceVsBandwidth = uint64(400e3) ) -var ( - errHostSettingsGouging = errors.New("host settings gouging detected") - errPriceTableGouging = errors.New("price table gouging detected") -) - -type ( - GougingChecker interface { - Check(_ *rhpv2.HostSettings, _ *rhpv3.HostPriceTable) api.HostGougingBreakdown - BlocksUntilBlockHeightGouging(hostHeight uint64) int64 - } - - gougingChecker struct { - consensusState api.ConsensusState - settings api.GougingSettings - txFee types.Currency - - period *uint64 - renewWindow *uint64 - } - - contextKey string -) +type contextKey string -var _ GougingChecker = gougingChecker{} - -func GougingCheckerFromContext(ctx context.Context, criticalMigration bool) (GougingChecker, error) { - gc, ok := ctx.Value(keyGougingChecker).(func(bool) (GougingChecker, error)) +func GougingCheckerFromContext(ctx context.Context, criticalMigration bool) (gouging.Checker, error) { + gc, ok := ctx.Value(keyGougingChecker).(func(bool) (gouging.Checker, error)) if !ok { panic("no gouging checker attached to the context") // developer error } return gc(criticalMigration) } -func WithGougingChecker(ctx context.Context, cs ConsensusState, gp api.GougingParams) context.Context { - return context.WithValue(ctx, keyGougingChecker, func(criticalMigration bool) (GougingChecker, error) { - consensusState, err := cs.ConsensusState(ctx) +func WithGougingChecker(ctx context.Context, cs gouging.ConsensusState, gp api.GougingParams) context.Context { + return context.WithValue(ctx, keyGougingChecker, func(criticalMigration bool) (gouging.Checker, error) { + cs, err := cs.ConsensusState(ctx) if err != nil { - return gougingChecker{}, fmt.Errorf("failed to get consensus state: %w", err) - } - - // adjust the max download price if we are dealing with a critical - // migration that might be failing due to gouging checks - settings := gp.GougingSettings - if criticalMigration && gp.GougingSettings.MigrationSurchargeMultiplier > 0 { - if adjustedMaxDownloadPrice, overflow := gp.GougingSettings.MaxDownloadPrice.Mul64WithOverflow(gp.GougingSettings.MigrationSurchargeMultiplier); !overflow { - settings.MaxDownloadPrice = adjustedMaxDownloadPrice - } + return nil, fmt.Errorf("failed to get consensus state: %w", err) } - - return gougingChecker{ - consensusState: consensusState, - settings: settings, - txFee: gp.TransactionFee, - - // NOTE: - // - // period and renew window are nil here and that's fine, gouging - // checkers in the workers don't have easy access to these settings and - // thus ignore them when perform gouging checks, the autopilot however - // does have those and will pass them when performing gouging checks - period: nil, - renewWindow: nil, - }, nil + return newGougingChecker(gp.GougingSettings, cs, gp.TransactionFee, criticalMigration), nil }) } -func NewGougingChecker(gs api.GougingSettings, cs api.ConsensusState, txnFee types.Currency, period, renewWindow uint64) GougingChecker { - return gougingChecker{ - consensusState: cs, - settings: gs, - txFee: txnFee, - - period: &period, - renewWindow: &renewWindow, - } -} - -func (gc gougingChecker) BlocksUntilBlockHeightGouging(hostHeight uint64) int64 { - blockHeight := gc.consensusState.BlockHeight - leeway := gc.settings.HostBlockHeightLeeway - var minHeight uint64 - if blockHeight >= uint64(leeway) { - minHeight = blockHeight - uint64(leeway) - } - return 
int64(hostHeight) - int64(minHeight) -} - -func (gc gougingChecker) Check(hs *rhpv2.HostSettings, pt *rhpv3.HostPriceTable) api.HostGougingBreakdown { - if hs == nil && pt == nil { - panic("gouging checker needs to be provided with at least host settings or a price table") // developer error - } - - return api.HostGougingBreakdown{ - ContractErr: errsToStr( - checkContractGougingRHPv2(gc.period, gc.renewWindow, hs), - checkContractGougingRHPv3(gc.period, gc.renewWindow, pt), - ), - DownloadErr: errsToStr(checkDownloadGougingRHPv3(gc.settings, pt)), - GougingErr: errsToStr( - checkPriceGougingPT(gc.settings, gc.consensusState, gc.txFee, pt), - checkPriceGougingHS(gc.settings, hs), - ), - PruneErr: errsToStr(checkPruneGougingRHPv2(gc.settings, hs)), - UploadErr: errsToStr(checkUploadGougingRHPv3(gc.settings, pt)), - } -} - -func checkPriceGougingHS(gs api.GougingSettings, hs *rhpv2.HostSettings) error { - // check if we have settings - if hs == nil { - return nil - } - // check base rpc price - if !gs.MaxRPCPrice.IsZero() && hs.BaseRPCPrice.Cmp(gs.MaxRPCPrice) > 0 { - return fmt.Errorf("rpc price exceeds max: %v > %v", hs.BaseRPCPrice, gs.MaxRPCPrice) - } - maxBaseRPCPrice := hs.DownloadBandwidthPrice.Mul64(maxBaseRPCPriceVsBandwidth) - if hs.BaseRPCPrice.Cmp(maxBaseRPCPrice) > 0 { - return fmt.Errorf("rpc price too high, %v > %v", hs.BaseRPCPrice, maxBaseRPCPrice) - } - - // check sector access price - if hs.DownloadBandwidthPrice.IsZero() { - hs.DownloadBandwidthPrice = types.NewCurrency64(1) - } - maxSectorAccessPrice := hs.DownloadBandwidthPrice.Mul64(maxSectorAccessPriceVsBandwidth) - if hs.SectorAccessPrice.Cmp(maxSectorAccessPrice) > 0 { - return fmt.Errorf("sector access price too high, %v > %v", hs.SectorAccessPrice, maxSectorAccessPrice) - } - - // check max storage price - if !gs.MaxStoragePrice.IsZero() && hs.StoragePrice.Cmp(gs.MaxStoragePrice) > 0 { - return fmt.Errorf("storage price exceeds max: %v > %v", hs.StoragePrice, gs.MaxStoragePrice) - } - - // check contract price - if !gs.MaxContractPrice.IsZero() && hs.ContractPrice.Cmp(gs.MaxContractPrice) > 0 { - return fmt.Errorf("contract price exceeds max: %v > %v", hs.ContractPrice, gs.MaxContractPrice) - } - - // check max EA balance - if hs.MaxEphemeralAccountBalance.Cmp(gs.MinMaxEphemeralAccountBalance) < 0 { - return fmt.Errorf("'MaxEphemeralAccountBalance' is less than the allowed minimum value, %v < %v", hs.MaxEphemeralAccountBalance, gs.MinMaxEphemeralAccountBalance) - } - - // check EA expiry - if hs.EphemeralAccountExpiry < gs.MinAccountExpiry { - return fmt.Errorf("'EphemeralAccountExpiry' is less than the allowed minimum value, %v < %v", hs.EphemeralAccountExpiry, gs.MinAccountExpiry) - } - - return nil -} - -// TODO: if we ever stop assuming that certain prices in the pricetable are -// always set to 1H we should account for those fields in -// `hostPeriodCostForScore` as well. 
-func checkPriceGougingPT(gs api.GougingSettings, cs api.ConsensusState, txnFee types.Currency, pt *rhpv3.HostPriceTable) error { - // check if we have a price table - if pt == nil { - return nil - } - // check base rpc price - if !gs.MaxRPCPrice.IsZero() && gs.MaxRPCPrice.Cmp(pt.InitBaseCost) < 0 { - return fmt.Errorf("init base cost exceeds max: %v > %v", pt.InitBaseCost, gs.MaxRPCPrice) - } - - // check contract price - if !gs.MaxContractPrice.IsZero() && pt.ContractPrice.Cmp(gs.MaxContractPrice) > 0 { - return fmt.Errorf("contract price exceeds max: %v > %v", pt.ContractPrice, gs.MaxContractPrice) - } - - // check max storage - if !gs.MaxStoragePrice.IsZero() && pt.WriteStoreCost.Cmp(gs.MaxStoragePrice) > 0 { - return fmt.Errorf("storage price exceeds max: %v > %v", pt.WriteStoreCost, gs.MaxStoragePrice) - } - - // check max collateral - if pt.MaxCollateral.IsZero() { - return errors.New("MaxCollateral of host is 0") - } - // check ReadLengthCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.ReadLengthCost) < 0 { - return fmt.Errorf("ReadLengthCost of host is %v but should be %v", pt.ReadLengthCost, types.NewCurrency64(1)) - } - - // check WriteLengthCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.WriteLengthCost) < 0 { - return fmt.Errorf("WriteLengthCost of %v exceeds 1H", pt.WriteLengthCost) - } - - // check AccountBalanceCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.AccountBalanceCost) < 0 { - return fmt.Errorf("AccountBalanceCost of %v exceeds 1H", pt.AccountBalanceCost) - } - - // check FundAccountCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.FundAccountCost) < 0 { - return fmt.Errorf("FundAccountCost of %v exceeds 1H", pt.FundAccountCost) - } - - // check UpdatePriceTableCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.UpdatePriceTableCost) < 0 { - return fmt.Errorf("UpdatePriceTableCost of %v exceeds 1H", pt.UpdatePriceTableCost) - } - - // check HasSectorBaseCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.HasSectorBaseCost) < 0 { - return fmt.Errorf("HasSectorBaseCost of %v exceeds 1H", pt.HasSectorBaseCost) - } - - // check MemoryTimeCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.MemoryTimeCost) < 0 { - return fmt.Errorf("MemoryTimeCost of %v exceeds 1H", pt.MemoryTimeCost) - } - - // check DropSectorsBaseCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.DropSectorsBaseCost) < 0 { - return fmt.Errorf("DropSectorsBaseCost of %v exceeds 1H", pt.DropSectorsBaseCost) - } - - // check DropSectorsUnitCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.DropSectorsUnitCost) < 0 { - return fmt.Errorf("DropSectorsUnitCost of %v exceeds 1H", pt.DropSectorsUnitCost) - } - - // check SwapSectorBaseCost - should be 1H as it's unused by hosts - if types.NewCurrency64(1).Cmp(pt.SwapSectorBaseCost) < 0 { - return fmt.Errorf("SwapSectorBaseCost of %v exceeds 1H", pt.SwapSectorBaseCost) - } - - // check SubscriptionMemoryCost - expect 1H default - if types.NewCurrency64(1).Cmp(pt.SubscriptionMemoryCost) < 0 { - return fmt.Errorf("SubscriptionMemoryCost of %v exceeds 1H", pt.SubscriptionMemoryCost) - } - - // check SubscriptionNotificationCost - expect 1H default - if types.NewCurrency64(1).Cmp(pt.SubscriptionNotificationCost) < 0 { - return fmt.Errorf("SubscriptionNotificationCost of %v exceeds 1H", 
pt.SubscriptionNotificationCost) - } - - // check LatestRevisionCost - expect sane value - maxRevisionCost, overflow := gs.MaxRPCPrice.AddWithOverflow(gs.MaxDownloadPrice.Div64(1 << 40).Mul64(2048)) - if overflow { - maxRevisionCost = types.MaxCurrency - } - if pt.LatestRevisionCost.Cmp(maxRevisionCost) > 0 { - return fmt.Errorf("LatestRevisionCost of %v exceeds maximum cost of %v", pt.LatestRevisionCost, maxRevisionCost) - } - - // check RenewContractCost - expect 100nS default - if types.Siacoins(1).Mul64(100).Div64(1e9).Cmp(pt.RenewContractCost) < 0 { - return fmt.Errorf("RenewContractCost of %v exceeds 100nS", pt.RenewContractCost) - } - - // check RevisionBaseCost - expect 0H default - if types.ZeroCurrency.Cmp(pt.RevisionBaseCost) < 0 { - return fmt.Errorf("RevisionBaseCost of %v exceeds 0H", pt.RevisionBaseCost) - } - - // check block height - if too much time has passed since the last block - // there is a chance we are not up-to-date anymore. So we only check whether - // the host's height is at least equal to ours. - if !cs.Synced || time.Since(cs.LastBlockTime.Std()) > time.Hour { - if pt.HostBlockHeight < cs.BlockHeight { - return fmt.Errorf("consensus not synced and host block height is lower, %v < %v", pt.HostBlockHeight, cs.BlockHeight) +func newGougingChecker(settings api.GougingSettings, cs api.ConsensusState, txnFee types.Currency, criticalMigration bool) gouging.Checker { + // adjust the max download price if we are dealing with a critical + // migration that might be failing due to gouging checks + if criticalMigration && settings.MigrationSurchargeMultiplier > 0 { + if adjustedMaxDownloadPrice, overflow := settings.MaxDownloadPrice.Mul64WithOverflow(settings.MigrationSurchargeMultiplier); !overflow { + settings.MaxDownloadPrice = adjustedMaxDownloadPrice } - } else { - var minHeight uint64 - if cs.BlockHeight >= uint64(gs.HostBlockHeightLeeway) { - minHeight = cs.BlockHeight - uint64(gs.HostBlockHeightLeeway) - } - maxHeight := cs.BlockHeight + uint64(gs.HostBlockHeightLeeway) - if !(minHeight <= pt.HostBlockHeight && pt.HostBlockHeight <= maxHeight) { - return fmt.Errorf("consensus is synced and host block height is not within range, %v-%v %v", minHeight, maxHeight, pt.HostBlockHeight) - } - } - - // check TxnFeeMaxRecommended - expect at most a multiple of our fee - if !txnFee.IsZero() && pt.TxnFeeMaxRecommended.Cmp(txnFee.Mul64(5)) > 0 { - return fmt.Errorf("TxnFeeMaxRecommended %v exceeds %v", pt.TxnFeeMaxRecommended, txnFee.Mul64(5)) - } - - // check TxnFeeMinRecommended - expect it to be lower or equal than the max - if pt.TxnFeeMinRecommended.Cmp(pt.TxnFeeMaxRecommended) > 0 { - return fmt.Errorf("TxnFeeMinRecommended is greater than TxnFeeMaxRecommended, %v > %v", pt.TxnFeeMinRecommended, pt.TxnFeeMaxRecommended) - } - - // check Validity - if pt.Validity < gs.MinPriceTableValidity { - return fmt.Errorf("'Validity' is less than the allowed minimum value, %v < %v", pt.Validity, gs.MinPriceTableValidity) - } - - return nil -} - -func checkContractGougingRHPv2(period, renewWindow *uint64, hs *rhpv2.HostSettings) (err error) { - // period and renew window might be nil since we don't always have access to - // these settings when performing gouging checks - if hs == nil || period == nil || renewWindow == nil { - return nil - } - - err = checkContractGouging(*period, *renewWindow, hs.MaxDuration, hs.WindowSize) - if err != nil { - err = fmt.Errorf("%w: %v", errHostSettingsGouging, err) - } - return -} - -func checkContractGougingRHPv3(period, renewWindow *uint64, pt 
*rhpv3.HostPriceTable) (err error) { - // period and renew window might be nil since we don't always have access to - // these settings when performing gouging checks - if pt == nil || period == nil || renewWindow == nil { - return nil - } - err = checkContractGouging(*period, *renewWindow, pt.MaxDuration, pt.WindowSize) - if err != nil { - err = fmt.Errorf("%w: %v", errPriceTableGouging, err) - } - return -} - -func checkContractGouging(period, renewWindow, maxDuration, windowSize uint64) error { - // check MaxDuration - if period != 0 && period > maxDuration { - return fmt.Errorf("MaxDuration %v is lower than the period %v", maxDuration, period) - } - - // check WindowSize - if renewWindow != 0 && renewWindow < windowSize { - return fmt.Errorf("minimum WindowSize %v is greater than the renew window %v", windowSize, renewWindow) - } - - return nil -} - -func checkPruneGougingRHPv2(gs api.GougingSettings, hs *rhpv2.HostSettings) error { - if hs == nil { - return nil - } - // pruning costs are similar to sector read costs in a way because they - // include base costs and download bandwidth costs, to avoid re-adding all - // RHPv2 cost calculations we reuse download gouging checks to cover pruning - sectorDownloadPrice, overflow := sectorReadCost( - types.NewCurrency64(1), // 1H - hs.SectorAccessPrice, - hs.BaseRPCPrice, - hs.DownloadBandwidthPrice, - hs.UploadBandwidthPrice, - ) - if overflow { - return fmt.Errorf("%w: overflow detected when computing sector download price", errHostSettingsGouging) - } - dpptb, overflow := sectorDownloadPrice.Mul64WithOverflow(1 << 40 / rhpv2.SectorSize) // sectors per TiB - if overflow { - return fmt.Errorf("%w: overflow detected when computing download price per TiB", errHostSettingsGouging) - } - if !gs.MaxDownloadPrice.IsZero() && dpptb.Cmp(gs.MaxDownloadPrice) > 0 { - return fmt.Errorf("%w: cost per TiB exceeds max dl price: %v > %v", errHostSettingsGouging, dpptb, gs.MaxDownloadPrice) - } - return nil -} - -func checkDownloadGougingRHPv3(gs api.GougingSettings, pt *rhpv3.HostPriceTable) error { - if pt == nil { - return nil - } - sectorDownloadPrice, overflow := sectorReadCostRHPv3(*pt) - if overflow { - return fmt.Errorf("%w: overflow detected when computing sector download price", errPriceTableGouging) - } - dpptb, overflow := sectorDownloadPrice.Mul64WithOverflow(1 << 40 / rhpv2.SectorSize) // sectors per TiB - if overflow { - return fmt.Errorf("%w: overflow detected when computing download price per TiB", errPriceTableGouging) - } - if !gs.MaxDownloadPrice.IsZero() && dpptb.Cmp(gs.MaxDownloadPrice) > 0 { - return fmt.Errorf("%w: cost per TiB exceeds max dl price: %v > %v", errPriceTableGouging, dpptb, gs.MaxDownloadPrice) - } - return nil -} - -func checkUploadGougingRHPv3(gs api.GougingSettings, pt *rhpv3.HostPriceTable) error { - if pt == nil { - return nil - } - sectorUploadPricePerMonth, overflow := sectorUploadCostRHPv3(*pt) - if overflow { - return fmt.Errorf("%w: overflow detected when computing sector price", errPriceTableGouging) - } - uploadPrice, overflow := sectorUploadPricePerMonth.Mul64WithOverflow(1 << 40 / rhpv2.SectorSize) // sectors per TiB - if overflow { - return fmt.Errorf("%w: overflow detected when computing upload price per TiB", errPriceTableGouging) - } - if !gs.MaxUploadPrice.IsZero() && uploadPrice.Cmp(gs.MaxUploadPrice) > 0 { - return fmt.Errorf("%w: cost per TiB exceeds max ul price: %v > %v", errPriceTableGouging, uploadPrice, gs.MaxUploadPrice) - } - return nil -} - -func sectorReadCostRHPv3(pt rhpv3.HostPriceTable) 
(types.Currency, bool) { - return sectorReadCost( - pt.ReadLengthCost, - pt.ReadBaseCost, - pt.InitBaseCost, - pt.UploadBandwidthCost, - pt.DownloadBandwidthCost, - ) -} - -func sectorReadCost(readLengthCost, readBaseCost, initBaseCost, ulBWCost, dlBWCost types.Currency) (types.Currency, bool) { - // base - base, overflow := readLengthCost.Mul64WithOverflow(rhpv2.SectorSize) - if overflow { - return types.ZeroCurrency, true - } - base, overflow = base.AddWithOverflow(readBaseCost) - if overflow { - return types.ZeroCurrency, true - } - base, overflow = base.AddWithOverflow(initBaseCost) - if overflow { - return types.ZeroCurrency, true - } - // bandwidth - ingress, overflow := ulBWCost.Mul64WithOverflow(32) - if overflow { - return types.ZeroCurrency, true - } - egress, overflow := dlBWCost.Mul64WithOverflow(rhpv2.SectorSize) - if overflow { - return types.ZeroCurrency, true - } - // total - total, overflow := base.AddWithOverflow(ingress) - if overflow { - return types.ZeroCurrency, true - } - total, overflow = total.AddWithOverflow(egress) - if overflow { - return types.ZeroCurrency, true - } - return total, false -} - -func sectorUploadCostRHPv3(pt rhpv3.HostPriceTable) (types.Currency, bool) { - // write - writeCost, overflow := pt.WriteLengthCost.Mul64WithOverflow(rhpv2.SectorSize) - if overflow { - return types.ZeroCurrency, true - } - writeCost, overflow = writeCost.AddWithOverflow(pt.WriteBaseCost) - if overflow { - return types.ZeroCurrency, true - } - writeCost, overflow = writeCost.AddWithOverflow(pt.InitBaseCost) - if overflow { - return types.ZeroCurrency, true - } - // bandwidth - ingress, overflow := pt.UploadBandwidthCost.Mul64WithOverflow(rhpv2.SectorSize) - if overflow { - return types.ZeroCurrency, true - } - // total - total, overflow := writeCost.AddWithOverflow(ingress) - if overflow { - return types.ZeroCurrency, true - } - return total, false -} - -func errsToStr(errs ...error) string { - if err := errors.Join(errs...); err != nil { - return err.Error() } - return "" + return gouging.NewChecker(settings, cs, txnFee, nil, nil) } diff --git a/worker/host.go b/worker/host.go index c4c84be21..9c65bd0c2 100644 --- a/worker/host.go +++ b/worker/host.go @@ -5,20 +5,17 @@ import ( "errors" "fmt" "io" - "math" "time" rhpv2 "go.sia.tech/core/rhp/v2" rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/gouging" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" "go.uber.org/zap" ) -var ( - errFailedToCreatePayment = errors.New("failed to create payment") -) - type ( Host interface { PublicKey() types.PublicKey @@ -26,7 +23,8 @@ type ( DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) error UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) error - FetchPriceTable(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, err error) + PriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) + PriceTableUnpaid(ctx context.Context) (hpt api.HostPriceTable, err error) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (types.FileContractRevision, error) FundAccount(ctx context.Context, balance types.Currency, rev *types.FileContractRevision) error @@ -49,21 +47,22 @@ type ( siamuxAddr string acc *account + client *rhp3.Client bus Bus contractSpendingRecorder ContractSpendingRecorder logger *zap.SugaredLogger - 
transportPool *transportPoolV3 priceTables *priceTables } ) var ( _ Host = (*host)(nil) - _ HostManager = (*worker)(nil) + _ HostManager = (*Worker)(nil) ) -func (w *worker) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr string) Host { +func (w *Worker) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr string) Host { return &host{ + client: w.rhp3Client, hk: hk, acc: w.accounts.ForHost(hk), bus: w.bus, @@ -73,7 +72,6 @@ func (w *worker) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr siamuxAddr: siamuxAddr, renterKey: w.deriveRenterKey(hk), accountKey: w.accounts.deriveAccountKey(hk), - transportPool: w.transportPoolV3, priceTables: w.priceTables, } } @@ -81,240 +79,232 @@ func (w *worker) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr func (h *host) PublicKey() types.PublicKey { return h.hk } func (h *host) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) (err error) { - pt, err := h.priceTables.fetch(ctx, h.hk, nil) - if err != nil { - return err - } - hpt := pt.HostPriceTable - - // check for download gouging specifically - gc, err := GougingCheckerFromContext(ctx, overpay) - if err != nil { - return err - } - if breakdown := gc.Check(nil, &hpt); breakdown.DownloadErr != "" { - return fmt.Errorf("%w: %v", errPriceTableGouging, breakdown.DownloadErr) - } - - // return errBalanceInsufficient if balance insufficient - defer func() { - if isBalanceInsufficient(err) { - err = fmt.Errorf("%w %v, err: %v", errBalanceInsufficient, h.hk, err) + var amount types.Currency + return h.acc.WithWithdrawal(ctx, func() (types.Currency, error) { + pt, uptc, err := h.priceTables.fetch(ctx, h.hk, nil) + if err != nil { + return types.ZeroCurrency, err } - }() - - return h.acc.WithWithdrawal(ctx, func() (amount types.Currency, err error) { - err = h.transportPool.withTransportV3(ctx, h.hk, h.siamuxAddr, func(ctx context.Context, t *transportV3) error { - cost, err := readSectorCost(hpt, uint64(length)) - if err != nil { - return err - } + hpt := pt.HostPriceTable + amount = uptc - payment := rhpv3.PayByEphemeralAccount(h.acc.id, cost, pt.HostBlockHeight+defaultWithdrawalExpiryBlocks, h.accountKey) - cost, refund, err := RPCReadSector(ctx, t, w, hpt, &payment, offset, length, root) - if err != nil { - return err - } + // check for download gouging specifically + gc, err := GougingCheckerFromContext(ctx, overpay) + if err != nil { + return amount, err + } + if breakdown := gc.Check(nil, &hpt); breakdown.DownloadErr != "" { + return amount, fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, breakdown.DownloadErr) + } - amount = cost.Sub(refund) - return nil - }) - return + cost, err := h.client.ReadSector(ctx, offset, length, root, w, h.hk, h.siamuxAddr, h.acc.id, h.accountKey, hpt) + if err != nil { + return amount, err + } + return amount.Add(cost), nil }) } -func (h *host) UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) (err error) { +func (h *host) UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) error { // fetch price table - pt, err := h.priceTable(ctx, nil) - if err != nil { - return err - } - - // prepare payment - // - // TODO: change to account payments once we have the means to check for an - // insufficient balance error - expectedCost, _, _, err := uploadSectorCost(pt, rev.WindowEnd) - if err != nil { + var pt rhpv3.HostPriceTable + 
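+	// fetch the price table inside the withdrawal so that the update cost,
+	// returned by the callback below, is withdrawn from the tracked
+	// ephemeral account balance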
if err := h.acc.WithWithdrawal(ctx, func() (amount types.Currency, err error) { + pt, amount, err = h.priceTable(ctx, nil) + return + }); err != nil { return err } - if rev.RevisionNumber == math.MaxUint64 { - return fmt.Errorf("revision number has reached max, fcid %v", rev.ParentID) - } - payment, ok := rhpv3.PayByContract(&rev, expectedCost, h.acc.id, h.renterKey) - if !ok { - return errFailedToCreatePayment - } - - var cost types.Currency - err = h.transportPool.withTransportV3(ctx, h.hk, h.siamuxAddr, func(ctx context.Context, t *transportV3) error { - cost, err = RPCAppendSector(ctx, t, h.renterKey, pt, &rev, &payment, sectorRoot, sector) - return err - }) + // upload + cost, err := h.client.AppendSector(ctx, sectorRoot, sector, &rev, h.hk, h.siamuxAddr, h.acc.id, pt, h.renterKey) if err != nil { - return err + return fmt.Errorf("failed to upload sector: %w", err) } - // record spending h.contractSpendingRecorder.Record(rev, api.ContractSpending{Uploads: cost}) return nil } func (h *host) RenewContract(ctx context.Context, rrr api.RHPRenewRequest) (_ rhpv2.ContractRevision, _ []types.Transaction, _, _ types.Currency, err error) { - // Try to get a valid pricetable. - ptCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - var pt *rhpv3.HostPriceTable - hpt, err := h.priceTables.fetch(ptCtx, h.hk, nil) - if err == nil { - pt = &hpt.HostPriceTable - } else { - h.logger.Infof("unable to fetch price table for renew: %v", err) + gc, err := h.gougingChecker(ctx, false) + if err != nil { + return rhpv2.ContractRevision{}, nil, types.ZeroCurrency, types.ZeroCurrency, err + } + revision, err := h.client.Revision(ctx, h.fcid, h.hk, h.siamuxAddr) + if err != nil { + return rhpv2.ContractRevision{}, nil, types.ZeroCurrency, types.ZeroCurrency, err } - var contractPrice types.Currency - var rev rhpv2.ContractRevision - var txnSet []types.Transaction - var renewErr error - var fundAmount types.Currency - err = h.transportPool.withTransportV3(ctx, h.hk, h.siamuxAddr, func(ctx context.Context, t *transportV3) (err error) { - // NOTE: to avoid an edge case where the contract is drained and can - // therefore not be used to pay for the revision, we simply don't pay - // for it. - _, err = RPCLatestRevision(ctx, t, h.fcid, func(revision *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error) { - // Renew contract. 
-			rev, txnSet, contractPrice, fundAmount, renewErr = RPCRenew(ctx, rrr, h.bus, t, pt, *revision, h.renterKey, h.logger)
-			return rhpv3.HostPriceTable{}, nil, nil
-		})
-		return err
-	})
+	// helper to discard txn on error
+	discardTxn := func(ctx context.Context, txn types.Transaction, err *error) {
+		if *err == nil {
+			return
+		}
+
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		if dErr := h.bus.WalletDiscard(ctx, txn); dErr != nil {
+			h.logger.Errorf("%v, failed to discard txn: %v", *err, dErr)
+		}
+		cancel()
+	}
+
+	// helper to sign txn
+	signTxn := func(ctx context.Context, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error {
+		return h.bus.WalletSign(ctx, txn, toSign, cf)
+	}
+
+	// helper to prepare the contract renewal
+	prepareRenew := func(ctx context.Context, revision types.FileContractRevision, hostAddress, renterAddress types.Address, renterKey types.PrivateKey, renterFunds, minNewCollateral, maxFundAmount types.Currency, pt rhpv3.HostPriceTable, endHeight, windowSize, expectedStorage uint64) (api.WalletPrepareRenewResponse, func(context.Context, types.Transaction, *error), error) {
+		resp, err := h.bus.WalletPrepareRenew(ctx, revision, hostAddress, renterAddress, renterKey, renterFunds, minNewCollateral, maxFundAmount, pt, endHeight, windowSize, expectedStorage)
+		if err != nil {
+			return api.WalletPrepareRenewResponse{}, nil, err
+		}
+		return resp, discardTxn, nil
+	}
+
+	// renew contract
+	rev, txnSet, contractPrice, fundAmount, err := h.client.Renew(ctx, rrr, gc, prepareRenew, signTxn, revision, h.renterKey)
 	if err != nil {
 		return rhpv2.ContractRevision{}, nil, contractPrice, fundAmount, err
 	}
-	return rev, txnSet, contractPrice, fundAmount, renewErr
+	return rev, txnSet, contractPrice, fundAmount, err
 }
 
-func (h *host) FetchPriceTable(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, err error) {
+func (h *host) PriceTableUnpaid(ctx context.Context) (api.HostPriceTable, error) {
+	return h.client.PriceTableUnpaid(ctx, h.hk, h.siamuxAddr)
+}
+
+func (h *host) PriceTable(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, cost types.Currency, err error) {
 	// fetchPT is a helper function that performs the RPC given a payment function
-	fetchPT := func(paymentFn PriceTablePaymentFunc) (hpt api.HostPriceTable, err error) {
-		err = h.transportPool.withTransportV3(ctx, h.hk, h.siamuxAddr, func(ctx context.Context, t *transportV3) (err error) {
-			hpt, err = RPCPriceTable(ctx, t, paymentFn)
-			return
-		})
-		return
+	fetchPT := func(paymentFn rhp3.PriceTablePaymentFunc) (api.HostPriceTable, error) {
+		return h.client.PriceTable(ctx, h.hk, h.siamuxAddr, paymentFn)
 	}
 
-	// pay by contract if a revision is given
+	// fetch the price table, paying by contract if a revision is given and by
+	// ephemeral account otherwise
 	if rev != nil {
-		return fetchPT(h.preparePriceTableContractPayment(rev))
+		hpt, err = fetchPT(rhp3.PreparePriceTableContractPayment(rev, h.acc.id, h.renterKey))
+	} else {
+		hpt, err = fetchPT(rhp3.PreparePriceTableAccountPayment(h.accountKey))
 	}
 
-	// pay by account
-	return fetchPT(h.preparePriceTableAccountPayment())
+	// set the cost
+	if err == nil {
+		cost = hpt.UpdatePriceTableCost
+	}
+	return
 }
 
-func (h *host) FundAccount(ctx context.Context, balance types.Currency, rev *types.FileContractRevision) error {
+// FetchRevision tries to fetch a contract revision from the host.
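+// A fetchTimeout of zero leaves the caller's context deadline in place. The
+// account/contract/no-payment fallbacks that used to live in this package were
+// removed and presumably live in the internal rhp3 client now, so a call
+// reduces to (hypothetical timeout value):
+//
+//	rev, err := h.FetchRevision(ctx, time.Minute)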
+func (h *host) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (types.FileContractRevision, error) {
+	if fetchTimeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, fetchTimeout)
+		defer cancel()
+	}
+	// fetch the revision through the rhp3 client
+	return h.client.Revision(ctx, h.fcid, h.hk, h.siamuxAddr)
+}
+
+func (h *host) FundAccount(ctx context.Context, desired types.Currency, rev *types.FileContractRevision) error {
+	log := h.logger.With(
+		zap.Stringer("host", h.hk),
+		zap.Stringer("account", h.acc.id),
+	)
+
+	// ensure the contract holds more than 2H so there's room for the funding
+	// cost plus a deposit
+	if types.NewCurrency64(2).Cmp(rev.ValidRenterPayout()) >= 0 {
+		return fmt.Errorf("insufficient funds to fund account: %v <= %v", rev.ValidRenterPayout(), types.NewCurrency64(2))
+	}
+
 	// fetch current balance
-	curr, err := h.acc.Balance(ctx)
+	balance, err := h.acc.Balance(ctx)
 	if err != nil {
 		return err
 	}
 
 	// return early if we have the desired balance
-	if curr.Cmp(balance) >= 0 {
+	if balance.Cmp(desired) >= 0 {
 		return nil
 	}
-	deposit := balance.Sub(curr)
 
+	// calculate the deposit amount
+	deposit := desired.Sub(balance)
 	return h.acc.WithDeposit(ctx, func() (types.Currency, error) {
-		if err := h.transportPool.withTransportV3(ctx, h.hk, h.siamuxAddr, func(ctx context.Context, t *transportV3) error {
-			// fetch pricetable
-			pt, err := h.priceTable(ctx, rev)
-			if err != nil {
-				return err
-			}
-
-			// check whether we have money left in the contract
-			cost := types.NewCurrency64(1)
-			if cost.Cmp(rev.ValidRenterPayout()) >= 0 {
-				return fmt.Errorf("insufficient funds to fund account: %v <= %v", rev.ValidRenterPayout(), cost)
-			}
-			availableFunds := rev.ValidRenterPayout().Sub(cost)
+		// fetch pricetable directly to bypass the gouging check
+		pt, _, err := h.priceTables.fetch(ctx, h.hk, rev)
+		if err != nil {
+			return types.ZeroCurrency, err
+		}
 
-			// cap the deposit amount by the money that's left in the contract
-			if deposit.Cmp(availableFunds) > 0 {
-				deposit = availableFunds
-			}
+		// cap the deposit by what's left in the contract
+		cost := types.NewCurrency64(1)
+		availableFunds := rev.ValidRenterPayout().Sub(cost)
+		if deposit.Cmp(availableFunds) > 0 {
+			deposit = availableFunds
+		}
 
-			// create the payment
-			amount := deposit.Add(cost)
-			payment, err := payByContract(rev, amount, rhpv3.Account{}, h.renterKey) // no account needed for funding
-			if err != nil {
-				return err
+		// fund the account
+		if err := h.client.FundAccount(ctx, rev, h.hk, h.siamuxAddr, deposit, h.acc.id, pt.HostPriceTable, h.renterKey); err != nil {
+			if rhp3.IsBalanceMaxExceeded(err) {
+				err = errors.Join(err, h.acc.as.ScheduleSync(ctx, h.acc.id, h.hk))
 			}
+			return types.ZeroCurrency, fmt.Errorf("failed to fund account with %v; %w", deposit, err)
+		}
 
-			// fund the account
-			if err := RPCFundAccount(ctx, t, &payment, h.acc.id, pt.UID); err != nil {
-				return fmt.Errorf("failed to fund account with %v (excluding cost %v);%w", deposit, cost, err)
-			}
+		// record the spend
+		h.contractSpendingRecorder.Record(*rev, api.ContractSpending{FundAccount: deposit.Add(cost)})
 
-			// record the spend
-			h.contractSpendingRecorder.Record(*rev, api.ContractSpending{FundAccount: amount})
-			return nil
-		}); err != nil {
-			return types.ZeroCurrency, err
-		}
+		// log the pre-funding balance alongside the deposit amount
+		log.Debugw("fund account succeeded",
+			"balance", balance.ExactString(),
+			"deposit", deposit.ExactString(),
+		)
 		return deposit, nil
 	})
 }
 
 func (h *host) SyncAccount(ctx context.Context, rev *types.FileContractRevision) error {
-	// fetch pricetable
-	pt, err := h.priceTable(ctx, rev)
+	// fetch pricetable directly to bypass the gouging check
+	pt, _, err := h.priceTables.fetch(ctx, h.hk, rev)
 	if err != nil {
 		return err
 	}
 
+	// only check the price table fields that are expected to remain at their
+	// unused default values
+	gc, err := GougingCheckerFromContext(ctx, false)
+	if err != nil {
+		return err
+	} else if err := gc.CheckUnusedDefaults(pt.HostPriceTable); err != nil {
+		return fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, err)
+	}
+
 	return h.acc.WithSync(ctx, func() (types.Currency, error) {
-		var balance types.Currency
-		err := h.transportPool.withTransportV3(ctx, h.hk, h.siamuxAddr, func(ctx context.Context, t *transportV3) error {
-			payment, err := payByContract(rev, types.NewCurrency64(1), h.acc.id, h.renterKey)
-			if err != nil {
-				return err
-			}
-			balance, err = RPCAccountBalance(ctx, t, &payment, h.acc.id, pt.UID)
-			return err
-		})
-		return balance, err
+		return h.client.SyncAccount(ctx, rev, h.hk, h.siamuxAddr, h.acc.id, pt.UID, h.renterKey)
 	})
 }
 
-// preparePriceTableAccountPayment prepare a payment function to pay for a price
-// table from the given host using the provided revision.
-//
-// NOTE: This is the preferred way of paying for a price table since it is
-// faster and doesn't require locking a contract.
-func (h *host) preparePriceTableAccountPayment() PriceTablePaymentFunc {
-	return func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) {
-		account := rhpv3.Account(h.accountKey.PublicKey())
-		payment := rhpv3.PayByEphemeralAccount(account, pt.UpdatePriceTableCost, pt.HostBlockHeight+defaultWithdrawalExpiryBlocks, h.accountKey)
-		return &payment, nil
+func (h *host) gougingChecker(ctx context.Context, criticalMigration bool) (gouging.Checker, error) {
+	gp, err := h.bus.GougingParams(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get gouging params: %w", err)
 	}
+	return newGougingChecker(gp.GougingSettings, gp.ConsensusState, gp.TransactionFee, criticalMigration), nil
 }
 
-// preparePriceTableContractPayment prepare a payment function to pay for a
-// price table from the given host using the provided revision.
-//
-// NOTE: This way of paying for a price table should only be used if payment by
-// EA is not possible or if we already need a contract revision anyway. e.g.
-// funding an EA.
-func (h *host) preparePriceTableContractPayment(rev *types.FileContractRevision) PriceTablePaymentFunc {
-	return func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) {
-		refundAccount := rhpv3.Account(h.accountKey.PublicKey())
-		payment, err := payByContract(rev, pt.UpdatePriceTableCost, refundAccount, h.renterKey)
-		if err != nil {
-			return nil, err
-		}
-		return &payment, nil
+// priceTable fetches a price table from the host. If a revision is provided, it
+// will be used to pay for the price table. The returned price table is
+// guaranteed to be safe to use.
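+// "Safe" means it passed a gouging check. The returned cost is what was paid
+// for the update (zero when a still-valid cached table was reused). Callers
+// that must not fail that check, e.g. FundAccount and SyncAccount above,
+// bypass this helper and call priceTables.fetch directly.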
+func (h *host) priceTable(ctx context.Context, rev *types.FileContractRevision) (rhpv3.HostPriceTable, types.Currency, error) {
+	pt, cost, err := h.priceTables.fetch(ctx, h.hk, rev)
+	if err != nil {
+		return rhpv3.HostPriceTable{}, types.ZeroCurrency, err
+	}
+	gc, err := GougingCheckerFromContext(ctx, false)
+	if err != nil {
+		return rhpv3.HostPriceTable{}, cost, err
+	}
+	if breakdown := gc.Check(nil, &pt.HostPriceTable); breakdown.Gouging() {
+		return rhpv3.HostPriceTable{}, cost, fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, breakdown)
+	}
+	return pt.HostPriceTable, cost, nil
 }
diff --git a/worker/host_test.go b/worker/host_test.go
index e329e4b90..8bbecaeff 100644
--- a/worker/host_test.go
+++ b/worker/host_test.go
@@ -13,6 +13,7 @@ import (
 	rhpv3 "go.sia.tech/core/rhp/v3"
 	"go.sia.tech/core/types"
 	"go.sia.tech/renterd/api"
+	rhp3 "go.sia.tech/renterd/internal/rhp/v3"
 	"go.sia.tech/renterd/internal/test"
 	"lukechampine.com/frand"
 )
@@ -82,7 +83,7 @@ func (h *testHost) PublicKey() types.PublicKey {
 func (h *testHost) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) error {
 	sector, exist := h.Sector(root)
 	if !exist {
-		return errSectorNotFound
+		return rhp3.ErrSectorNotFound
 	}
 	if offset+length > rhpv2.SectorSize {
 		return errSectorOutOfBounds
@@ -110,7 +111,11 @@ func (h *testHost) FetchRevision(ctx context.Context, fetchTimeout time.Duration
 	return rev, nil
 }
 
-func (h *testHost) FetchPriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, error) {
+func (h *testHost) PriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) {
+	return h.hptFn(), types.ZeroCurrency, nil
+}
+
+func (h *testHost) PriceTableUnpaid(ctx context.Context) (api.HostPriceTable, error) {
 	return h.hptFn(), nil
 }
diff --git a/worker/interactions.go b/worker/interactions.go
deleted file mode 100644
index 34e47953a..000000000
--- a/worker/interactions.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package worker
-
-import (
-	"go.sia.tech/renterd/api"
-)
-
-type (
-	HostInteractionRecorder interface {
-		RecordHostScan(...api.HostScan)
-		RecordPriceTableUpdate(...api.HostPriceTableUpdate)
-	}
-)
-
-func isSuccessfulInteraction(err error) bool {
-	// No error always means success.
-	if err == nil {
-		return true
-	}
-	// List of errors that are considered successful interactions.
-	if isInsufficientFunds(err) {
-		return true
-	}
-	if isBalanceInsufficient(err) {
-		return true
-	}
-	return false
-}
diff --git a/worker/memory.go b/worker/memory.go
index 1dbd680ec..d6f459bc8 100644
--- a/worker/memory.go
+++ b/worker/memory.go
@@ -41,7 +41,14 @@ type (
 
 var _ MemoryManager = (*memoryManager)(nil)
 
-func newMemoryManager(logger *zap.SugaredLogger, maxMemory uint64) MemoryManager {
+func newMemoryManager(maxMemory uint64, logger *zap.Logger) MemoryManager {
+	return newMemoryManagerCustom(maxMemory, logger.Named("memorymanager").Sugar())
+}
+
+// newMemoryManagerCustom is an internal constructor that doesn't name the
+// logger passed in; this avoids chaining the logger name every time a limit
+// memory manager is created.
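+//
+// A rough usage sketch, assuming the names introduced in this diff:
+//
+//	mm := newMemoryManager(1<<30, zap.NewNop()) // logger named "memorymanager" once
+//	lm, _ := mm.Limit(1 << 28)                  // child reuses the parent's named logger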
+func newMemoryManagerCustom(maxMemory uint64, logger *zap.SugaredLogger) MemoryManager { mm := &memoryManager{ logger: logger, totalAvailable: maxMemory, @@ -57,7 +64,7 @@ func (mm *memoryManager) Limit(amt uint64) (MemoryManager, error) { } return &limitMemoryManager{ parent: mm, - child: newMemoryManager(mm.logger, amt), + child: newMemoryManagerCustom(amt, mm.logger), }, nil } diff --git a/worker/migrations.go b/worker/migrations.go index 075642dd5..d2d1c6474 100644 --- a/worker/migrations.go +++ b/worker/migrations.go @@ -10,7 +10,7 @@ import ( "go.sia.tech/renterd/object" ) -func (w *worker) migrate(ctx context.Context, s object.Slab, contractSet string, dlContracts, ulContracts []api.ContractMetadata, bh uint64) (int, bool, error) { +func (w *Worker) migrate(ctx context.Context, s object.Slab, contractSet string, dlContracts, ulContracts []api.ContractMetadata, bh uint64) (int, bool, error) { // make a map of good hosts goodHosts := make(map[types.PublicKey]map[types.FileContractID]bool) for _, c := range ulContracts { diff --git a/worker/mocks_test.go b/worker/mocks_test.go index 192f4c169..f982437a7 100644 --- a/worker/mocks_test.go +++ b/worker/mocks_test.go @@ -15,6 +15,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/gouging" "go.sia.tech/renterd/object" "go.sia.tech/renterd/webhooks" ) @@ -61,7 +62,7 @@ func (*alerterMock) Alerts(_ context.Context, opts alerts.AlertsOpts) (resp aler func (*alerterMock) RegisterAlert(context.Context, alerts.Alert) error { return nil } func (*alerterMock) DismissAlerts(context.Context, ...types.Hash256) error { return nil } -var _ ConsensusState = (*chainMock)(nil) +var _ gouging.ConsensusState = (*chainMock)(nil) type chainMock struct { cs api.ConsensusState @@ -83,6 +84,7 @@ type busMock struct { *objectStoreMock *settingStoreMock *syncerMock + *s3Mock *walletMock *webhookBroadcasterMock *webhookStoreMock @@ -634,6 +636,56 @@ func (os *objectStoreMock) forEachObject(fn func(bucket, path string, o object.O } } +type s3Mock struct{} + +func (*s3Mock) CreateBucket(context.Context, string, api.CreateBucketOptions) error { + return nil +} + +func (*s3Mock) DeleteBucket(context.Context, string) error { + return nil +} + +func (*s3Mock) ListBuckets(context.Context) (buckets []api.Bucket, err error) { + return nil, nil +} + +func (*s3Mock) CopyObject(context.Context, string, string, string, string, api.CopyObjectOptions) (om api.ObjectMetadata, err error) { + return api.ObjectMetadata{}, nil +} + +func (*s3Mock) ListObjects(context.Context, string, api.ListObjectOptions) (resp api.ObjectsListResponse, err error) { + return api.ObjectsListResponse{}, nil +} + +func (*s3Mock) AbortMultipartUpload(context.Context, string, string, string) (err error) { + return nil +} + +func (*s3Mock) CompleteMultipartUpload(context.Context, string, string, string, []api.MultipartCompletedPart, api.CompleteMultipartOptions) (_ api.MultipartCompleteResponse, err error) { + return api.MultipartCompleteResponse{}, nil +} + +func (*s3Mock) CreateMultipartUpload(context.Context, string, string, api.CreateMultipartOptions) (api.MultipartCreateResponse, error) { + return api.MultipartCreateResponse{}, nil +} + +func (*s3Mock) MultipartUploads(ctx context.Context, bucket, prefix, keyMarker, uploadIDMarker string, maxUploads int) (resp api.MultipartListUploadsResponse, _ error) { + return api.MultipartListUploadsResponse{}, nil +} + +func (*s3Mock) MultipartUploadParts(ctx context.Context, bucket, object 
string, uploadID string, marker int, limit int64) (resp api.MultipartListPartsResponse, _ error) { + return api.MultipartListPartsResponse{}, nil +} + +func (*s3Mock) S3AuthenticationSettings(context.Context) (as api.S3AuthenticationSettings, err error) { + return api.S3AuthenticationSettings{}, nil +} + +func (*s3Mock) UpdateSetting(context.Context, string, interface{}) error { + return nil +} + var _ SettingStore = (*settingStoreMock)(nil) type settingStoreMock struct{} @@ -697,3 +749,7 @@ type webhookStoreMock struct{} func (*webhookStoreMock) RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { return nil } + +func (*webhookStoreMock) UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + return nil +} diff --git a/worker/net.go b/worker/net.go deleted file mode 100644 index 8401062ab..000000000 --- a/worker/net.go +++ /dev/null @@ -1,11 +0,0 @@ -package worker - -import ( - "context" - "net" -) - -func dial(ctx context.Context, hostIP string) (net.Conn, error) { - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", hostIP) - return conn, err -} diff --git a/worker/pricetables.go b/worker/pricetables.go index 592037146..fc3901f67 100644 --- a/worker/pricetables.go +++ b/worker/pricetables.go @@ -11,6 +11,7 @@ import ( rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" "lukechampine.com/frand" ) @@ -58,7 +59,7 @@ type ( } ) -func (w *worker) initPriceTables() { +func (w *Worker) initPriceTables() { if w.priceTables != nil { panic("priceTables already initialized") // developer error } @@ -75,7 +76,7 @@ func newPriceTables(hm HostManager, hs HostStore) *priceTables { } // fetch returns a price table for the given host -func (pts *priceTables) fetch(ctx context.Context, hk types.PublicKey, rev *types.FileContractRevision) (api.HostPriceTable, error) { +func (pts *priceTables) fetch(ctx context.Context, hk types.PublicKey, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) { pts.mu.Lock() pt, exists := pts.priceTables[hk] if !exists { @@ -105,7 +106,7 @@ func (pt *priceTable) ongoingUpdate() (bool, *priceTableUpdate) { return ongoing, pt.update } -func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, err error) { +func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, cost types.Currency, err error) { // grab the current price table p.mu.Lock() hpt = p.hpt @@ -115,7 +116,7 @@ func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) // current price table is considered to gouge on the block height gc, err := GougingCheckerFromContext(ctx, false) if err != nil { - return api.HostPriceTable{}, err + return api.HostPriceTable{}, types.ZeroCurrency, err } // figure out whether we should update the price table, if not we can return @@ -137,10 +138,10 @@ func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) } else if ongoing { select { case <-ctx.Done(): - return api.HostPriceTable{}, fmt.Errorf("%w; %w", errPriceTableUpdateTimedOut, context.Cause(ctx)) + return api.HostPriceTable{}, types.ZeroCurrency, fmt.Errorf("%w; %w", errPriceTableUpdateTimedOut, context.Cause(ctx)) case <-update.done: } - return update.hpt, update.err + return update.hpt, types.ZeroCurrency, update.err } // this thread is updating the price table @@ -166,28 +167,43 @@ func (p *priceTable) fetch(ctx context.Context, rev 
*types.FileContractRevision) // sanity check the host has been scanned before fetching the price table if !host.Scanned { - return api.HostPriceTable{}, fmt.Errorf("host %v was not scanned", p.hk) + return api.HostPriceTable{}, types.ZeroCurrency, fmt.Errorf("host %v was not scanned", p.hk) } // otherwise fetch it h := p.hm.Host(p.hk, types.FileContractID{}, host.Settings.SiamuxAddr()) - hpt, err = h.FetchPriceTable(ctx, rev) + hpt, cost, err = h.PriceTable(ctx, rev) // record it in the background - go func(hpt api.HostPriceTable, success bool) { - p.hs.RecordPriceTables(context.Background(), []api.HostPriceTableUpdate{ - { - HostKey: p.hk, - Success: success, - Timestamp: time.Now(), - PriceTable: hpt, - }, - }) - }(hpt, isSuccessfulInteraction(err)) + if shouldRecordPriceTable(err) { + go func(hpt api.HostPriceTable, success bool) { + p.hs.RecordPriceTables(context.Background(), []api.HostPriceTableUpdate{ + { + HostKey: p.hk, + Success: success, + Timestamp: time.Now(), + PriceTable: hpt, + }, + }) + }(hpt, err == nil) + } // handle error after recording if err != nil { - return api.HostPriceTable{}, fmt.Errorf("failed to update pricetable, err %v", err) + return api.HostPriceTable{}, types.ZeroCurrency, fmt.Errorf("failed to update pricetable, err %v", err) } return } + +func shouldRecordPriceTable(err error) bool { + // List of errors that are considered 'successful' failures. Meaning that + // the host was reachable but we were unable to obtain a price table due to + // reasons out of the host's control. + if rhp3.IsInsufficientFunds(err) { + return false + } + if rhp3.IsBalanceInsufficient(err) { + return false + } + return true +} diff --git a/worker/pricetables_test.go b/worker/pricetables_test.go index 22c021ccb..ed918fa1a 100644 --- a/worker/pricetables_test.go +++ b/worker/pricetables_test.go @@ -58,14 +58,14 @@ func TestPriceTables(t *testing.T) { // update ctx, cancel := context.WithCancel(gCtx) cancel() - _, err := pts.fetch(ctx, h.hk, nil) + _, _, err := pts.fetch(ctx, h.hk, nil) if !errors.Is(err, errPriceTableUpdateTimedOut) { t.Fatal("expected errPriceTableUpdateTimedOut, got", err) } - // unblock and assert we receive a valid price table + // unblock and assert we paid for the price table close(fetchPTBlockChan) - update, err := pts.fetch(gCtx, h.hk, nil) + update, _, err := pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -75,7 +75,7 @@ func TestPriceTables(t *testing.T) { // refresh the price table on the host, update again, assert we receive the // same price table as it hasn't expired yet h.hi.PriceTable = newTestHostPriceTable() - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -86,7 +86,7 @@ func TestPriceTables(t *testing.T) { pts.priceTables[h.hk].hpt.Expiry = time.Now() // fetch it again and assert we updated the price table - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != h.hi.PriceTable.UID { @@ -97,7 +97,7 @@ func TestPriceTables(t *testing.T) { // the price table since it's not expired validPT = h.hi.PriceTable h.hi.PriceTable = newTestHostPriceTable() - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -110,7 +110,7 @@ func TestPriceTables(t *testing.T) { cm.cs.BlockHeight = validPT.HostBlockHeight + 
uint64(blockHeightLeeway) - priceTableBlockHeightLeeway // fetch it again and assert we updated the price table - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != h.hi.PriceTable.UID { diff --git a/worker/rhpv3.go b/worker/rhpv3.go deleted file mode 100644 index 4d7518d2e..000000000 --- a/worker/rhpv3.go +++ /dev/null @@ -1,1134 +0,0 @@ -package worker - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "math" - "math/big" - "net" - "sync" - "time" - - rhpv2 "go.sia.tech/core/rhp/v2" - rhpv3 "go.sia.tech/core/rhp/v3" - "go.sia.tech/core/types" - "go.sia.tech/mux/v1" - "go.sia.tech/renterd/api" - "go.sia.tech/renterd/internal/utils" - "go.sia.tech/siad/crypto" - "go.uber.org/zap" -) - -const ( - // accountLockingDuration is the time for which an account lock remains - // reserved on the bus after locking it. - accountLockingDuration = 30 * time.Second - - // defaultRPCResponseMaxSize is the default maxSize we use whenever we read - // an RPC response. - defaultRPCResponseMaxSize = 100 * 1024 // 100 KiB - - // defaultWithdrawalExpiryBlocks is the number of blocks we add to the - // current blockheight when we define an expiry block height for withdrawal - // messages. - defaultWithdrawalExpiryBlocks = 12 - - // maxPriceTableSize defines the maximum size of a price table - maxPriceTableSize = 16 * 1024 - - // responseLeeway is the amount of leeway given to the maxLen when we read - // the response in the ReadSector RPC - responseLeeway = 1 << 12 // 4 KiB -) - -var ( - // errHost is used to wrap rpc errors returned by the host. - errHost = errors.New("host responded with error") - - // errTransport is used to wrap rpc errors caused by the transport. - errTransport = errors.New("transport error") - - // errDialTransport is returned when the worker could not dial the host. - errDialTransport = errors.New("could not dial transport") - - // errBalanceInsufficient occurs when a withdrawal failed because the - // account balance was insufficient. - errBalanceInsufficient = errors.New("ephemeral account balance was insufficient") - - // errBalanceMaxExceeded occurs when a deposit would push the account's - // balance over the maximum allowed ephemeral account balance. - errBalanceMaxExceeded = errors.New("ephemeral account maximum balance exceeded") - - // errMaxRevisionReached occurs when trying to revise a contract that has - // already reached the highest possible revision number. Usually happens - // when trying to use a renewed contract. - errMaxRevisionReached = errors.New("contract has reached the maximum number of revisions") - - // errPriceTableExpired is returned by the host when the price table that - // corresponds to the id it was given is already expired and thus no longer - // valid. - errPriceTableExpired = errors.New("price table requested is expired") - - // errPriceTableNotFound is returned by the host when it can not find a - // price table that corresponds with the id we sent it. - errPriceTableNotFound = errors.New("price table not found") - - // errSectorNotFound is returned by the host when it can not find the - // requested sector. - errSectorNotFoundOld = errors.New("could not find the desired sector") - errSectorNotFound = errors.New("sector not found") - - // errWithdrawalsInactive occurs when the host is (perhaps temporarily) - // unsynced and has disabled its account manager. 
- errWithdrawalsInactive = errors.New("ephemeral account withdrawals are inactive because the host is not synced") - - // errWithdrawalExpired is returned by the host when the withdrawal request - // has an expiry block height that is in the past. - errWithdrawalExpired = errors.New("withdrawal request expired") -) - -// IsErrHost indicates whether an error was returned by a host as part of an RPC. -func IsErrHost(err error) bool { - return utils.IsErr(err, errHost) -} - -func isBalanceInsufficient(err error) bool { return utils.IsErr(err, errBalanceInsufficient) } -func isBalanceMaxExceeded(err error) bool { return utils.IsErr(err, errBalanceMaxExceeded) } -func isClosedStream(err error) bool { - return utils.IsErr(err, mux.ErrClosedStream) || utils.IsErr(err, net.ErrClosed) -} -func isInsufficientFunds(err error) bool { return utils.IsErr(err, ErrInsufficientFunds) } -func isPriceTableExpired(err error) bool { return utils.IsErr(err, errPriceTableExpired) } -func isPriceTableGouging(err error) bool { return utils.IsErr(err, errPriceTableGouging) } -func isPriceTableNotFound(err error) bool { return utils.IsErr(err, errPriceTableNotFound) } -func isSectorNotFound(err error) bool { - return utils.IsErr(err, errSectorNotFound) || utils.IsErr(err, errSectorNotFoundOld) -} -func isWithdrawalsInactive(err error) bool { return utils.IsErr(err, errWithdrawalsInactive) } -func isWithdrawalExpired(err error) bool { return utils.IsErr(err, errWithdrawalExpired) } - -// wrapRPCErr extracts the innermost error, wraps it in either a errHost or -// errTransport and finally wraps it using the provided fnName. -func wrapRPCErr(err *error, fnName string) { - if *err == nil { - return - } - innerErr := *err - for errors.Unwrap(innerErr) != nil { - innerErr = errors.Unwrap(innerErr) - } - if errors.As(*err, new(*rhpv3.RPCError)) { - *err = fmt.Errorf("%w: '%w'", errHost, innerErr) - } else { - *err = fmt.Errorf("%w: '%w'", errTransport, innerErr) - } - *err = fmt.Errorf("%s: %w", fnName, *err) -} - -// transportV3 is a reference-counted wrapper for rhpv3.Transport. -type transportV3 struct { - refCount uint64 // locked by pool - - mu sync.Mutex - hostKey types.PublicKey - siamuxAddr string - t *rhpv3.Transport -} - -type streamV3 struct { - cancel context.CancelFunc - *rhpv3.Stream -} - -func (s *streamV3) ReadResponse(resp rhpv3.ProtocolObject, maxLen uint64) (err error) { - defer wrapRPCErr(&err, "ReadResponse") - return s.Stream.ReadResponse(resp, maxLen) -} - -func (s *streamV3) WriteResponse(resp rhpv3.ProtocolObject) (err error) { - defer wrapRPCErr(&err, "WriteResponse") - return s.Stream.WriteResponse(resp) -} - -func (s *streamV3) ReadRequest(req rhpv3.ProtocolObject, maxLen uint64) (err error) { - defer wrapRPCErr(&err, "ReadRequest") - return s.Stream.ReadRequest(req, maxLen) -} - -func (s *streamV3) WriteRequest(rpcID types.Specifier, req rhpv3.ProtocolObject) (err error) { - defer wrapRPCErr(&err, "WriteRequest") - return s.Stream.WriteRequest(rpcID, req) -} - -// Close closes the stream and cancels the goroutine launched by DialStream. -func (s *streamV3) Close() error { - s.cancel() - return s.Stream.Close() -} - -// DialStream dials a new stream on the transport. 
-func (t *transportV3) DialStream(ctx context.Context) (*streamV3, error) { - t.mu.Lock() - if t.t == nil { - start := time.Now() - newTransport, err := dialTransport(ctx, t.siamuxAddr, t.hostKey) - if err != nil { - t.mu.Unlock() - return nil, fmt.Errorf("DialStream: %w: %w (%v)", errDialTransport, err, time.Since(start)) - } - t.t = newTransport - } - transport := t.t - t.mu.Unlock() - - // Close the stream when the context is closed to unblock any reads or - // writes. - stream := transport.DialStream() - - // Apply a sane timeout to the stream. - if err := stream.SetDeadline(time.Now().Add(5 * time.Minute)); err != nil { - _ = stream.Close() - return nil, err - } - - // Make sure the stream is closed when the context is closed. - doneCtx, doneFn := context.WithCancel(ctx) - go func() { - select { - case <-doneCtx.Done(): - case <-ctx.Done(): - _ = stream.Close() - } - }() - return &streamV3{ - Stream: stream, - cancel: doneFn, - }, nil -} - -// transportPoolV3 is a pool of rhpv3.Transports which allows for reusing them. -type transportPoolV3 struct { - mu sync.Mutex - pool map[string]*transportV3 -} - -func newTransportPoolV3() *transportPoolV3 { - return &transportPoolV3{ - pool: make(map[string]*transportV3), - } -} - -func dialTransport(ctx context.Context, siamuxAddr string, hostKey types.PublicKey) (*rhpv3.Transport, error) { - // Dial host. - conn, err := dial(ctx, siamuxAddr) - if err != nil { - return nil, err - } - - // Upgrade to rhpv3.Transport. - var t *rhpv3.Transport - done := make(chan struct{}) - go func() { - t, err = rhpv3.NewRenterTransport(conn, hostKey) - close(done) - }() - select { - case <-ctx.Done(): - conn.Close() - <-done - return nil, context.Cause(ctx) - case <-done: - return t, err - } -} - -func (p *transportPoolV3) withTransportV3(ctx context.Context, hostKey types.PublicKey, siamuxAddr string, fn func(context.Context, *transportV3) error) (err error) { - // Create or fetch transport. - p.mu.Lock() - t, found := p.pool[siamuxAddr] - if !found { - t = &transportV3{ - hostKey: hostKey, - siamuxAddr: siamuxAddr, - } - p.pool[siamuxAddr] = t - } - t.refCount++ - p.mu.Unlock() - - // Execute function. - err = func() (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic (withTransportV3): %v", r) - } - }() - return fn(ctx, t) - }() - - // Decrement refcounter again and clean up pool. - p.mu.Lock() - t.refCount-- - if t.refCount == 0 { - // Cleanup - if t.t != nil { - _ = t.t.Close() - t.t = nil - } - delete(p.pool, siamuxAddr) - } - p.mu.Unlock() - return err -} - -// FetchRevision tries to fetch a contract revision from the host. -func (h *host) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (types.FileContractRevision, error) { - timeoutCtx := func() (context.Context, context.CancelFunc) { - if fetchTimeout > 0 { - return context.WithTimeout(ctx, fetchTimeout) - } - return ctx, func() {} - } - - // Try to fetch the revision with an account first. 
- ctx, cancel := timeoutCtx() - defer cancel() - rev, err := h.fetchRevisionWithAccount(ctx, h.hk, h.siamuxAddr, h.fcid) - if err != nil && !(isBalanceInsufficient(err) || isWithdrawalsInactive(err) || isWithdrawalExpired(err) || isClosedStream(err)) { // TODO: checking for a closed stream here can be removed once the withdrawal timeout on the host side is removed - return types.FileContractRevision{}, fmt.Errorf("unable to fetch revision with account: %v", err) - } else if err == nil { - return rev, nil - } - - // Fall back to using the contract to pay for the revision. - ctx, cancel = timeoutCtx() - defer cancel() - rev, err = h.fetchRevisionWithContract(ctx, h.hk, h.siamuxAddr, h.fcid) - if err != nil && !isInsufficientFunds(err) { - return types.FileContractRevision{}, fmt.Errorf("unable to fetch revision with contract: %v", err) - } else if err == nil { - return rev, nil - } - - // If we don't have enough money in the contract, try again without paying. - ctx, cancel = timeoutCtx() - defer cancel() - rev, err = h.fetchRevisionNoPayment(ctx, h.hk, h.siamuxAddr, h.fcid) - if err != nil { - return types.FileContractRevision{}, fmt.Errorf("unable to fetch revision without payment: %v", err) - } - return rev, nil -} - -func (h *host) fetchRevisionWithAccount(ctx context.Context, hostKey types.PublicKey, siamuxAddr string, fcid types.FileContractID) (rev types.FileContractRevision, err error) { - err = h.acc.WithWithdrawal(ctx, func() (types.Currency, error) { - var cost types.Currency - return cost, h.transportPool.withTransportV3(ctx, hostKey, siamuxAddr, func(ctx context.Context, t *transportV3) (err error) { - rev, err = RPCLatestRevision(ctx, t, fcid, func(rev *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error) { - pt, err := h.priceTable(ctx, nil) - if err != nil { - return rhpv3.HostPriceTable{}, nil, fmt.Errorf("failed to fetch pricetable, err: %w", err) - } - cost = pt.LatestRevisionCost.Add(pt.UpdatePriceTableCost) // add cost of fetching the pricetable since we might need a new one and it's better to stay pessimistic - payment := rhpv3.PayByEphemeralAccount(h.acc.id, cost, pt.HostBlockHeight+defaultWithdrawalExpiryBlocks, h.accountKey) - return pt, &payment, nil - }) - if err != nil { - return err - } - return nil - }) - }) - return rev, err -} - -// FetchRevisionWithContract fetches the latest revision of a contract and uses -// a contract to pay for it. -func (h *host) fetchRevisionWithContract(ctx context.Context, hostKey types.PublicKey, siamuxAddr string, contractID types.FileContractID) (rev types.FileContractRevision, err error) { - err = h.transportPool.withTransportV3(ctx, hostKey, siamuxAddr, func(ctx context.Context, t *transportV3) (err error) { - rev, err = RPCLatestRevision(ctx, t, contractID, func(rev *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error) { - // Fetch pt. - pt, err := h.priceTable(ctx, rev) - if err != nil { - return rhpv3.HostPriceTable{}, nil, fmt.Errorf("failed to fetch pricetable, err: %v", err) - } - // Pay for the revision. 
- payment, err := payByContract(rev, pt.LatestRevisionCost, h.acc.id, h.renterKey) - if err != nil { - return rhpv3.HostPriceTable{}, nil, err - } - return pt, &payment, nil - }) - return err - }) - return rev, err -} - -func (h *host) fetchRevisionNoPayment(ctx context.Context, hostKey types.PublicKey, siamuxAddr string, contractID types.FileContractID) (rev types.FileContractRevision, err error) { - err = h.transportPool.withTransportV3(ctx, hostKey, siamuxAddr, func(ctx context.Context, t *transportV3) (err error) { - _, err = RPCLatestRevision(ctx, t, contractID, func(r *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error) { - rev = *r - return rhpv3.HostPriceTable{}, nil, nil - }) - return err - }) - return rev, err -} - -type ( - // accounts stores the balance and other metrics of accounts that the - // worker maintains with a host. - accounts struct { - as AccountStore - key types.PrivateKey - } - - // account contains information regarding a specific account of the - // worker. - account struct { - as AccountStore - id rhpv3.Account - key types.PrivateKey - host types.PublicKey - } -) - -func (w *worker) initAccounts(as AccountStore) { - if w.accounts != nil { - panic("accounts already initialized") // developer error - } - w.accounts = &accounts{ - as: as, - key: w.deriveSubKey("accountkey"), - } -} - -func (w *worker) initTransportPool() { - if w.transportPoolV3 != nil { - panic("transport pool already initialized") // developer error - } - w.transportPoolV3 = newTransportPoolV3() -} - -// ForHost returns an account to use for a given host. If the account -// doesn't exist, a new one is created. -func (a *accounts) ForHost(hk types.PublicKey) *account { - accountID := rhpv3.Account(a.deriveAccountKey(hk).PublicKey()) - return &account{ - as: a.as, - id: accountID, - key: a.key, - host: hk, - } -} - -func withAccountLock(ctx context.Context, as AccountStore, id rhpv3.Account, hk types.PublicKey, exclusive bool, fn func(a api.Account) error) error { - acc, lockID, err := as.LockAccount(ctx, id, hk, exclusive, accountLockingDuration) - if err != nil { - return err - } - err = fn(acc) - - // unlock account - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - _ = as.UnlockAccount(ctx, acc.ID, lockID) // ignore error - cancel() - - return err -} - -// Balance returns the account balance. -func (a *account) Balance(ctx context.Context) (balance types.Currency, err error) { - err = withAccountLock(ctx, a.as, a.id, a.host, false, func(account api.Account) error { - balance = types.NewCurrency(account.Balance.Uint64(), new(big.Int).Rsh(account.Balance, 64).Uint64()) - return nil - }) - return -} - -// WithDeposit increases the balance of an account by the amount returned by -// amtFn if amtFn doesn't return an error. -func (a *account) WithDeposit(ctx context.Context, amtFn func() (types.Currency, error)) error { - return withAccountLock(ctx, a.as, a.id, a.host, false, func(_ api.Account) error { - amt, err := amtFn() - if err != nil { - return err - } - return a.as.AddBalance(ctx, a.id, a.host, amt.Big()) - }) -} - -// WithSync syncs an accounts balance with the bus. To do so, the account is -// locked while the balance is fetched through balanceFn. 
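WithSync below, like Balance and WithDeposit above, runs its body under withAccountLock. The part worth copying is the unlock: it runs on a fresh background context so the lock is released even when the caller's ctx was cancelled while fn was executing. A sketch, with Store as a hypothetical two-method slice of the AccountStore interface:

    package account

    import (
    	"context"
    	"time"
    )

    // Store is a hypothetical stand-in for the locking part of AccountStore.
    type Store interface {
    	Lock(ctx context.Context, id string) (lockID uint64, err error)
    	Unlock(ctx context.Context, id string, lockID uint64) error
    }

    // withLock mirrors withAccountLock: the unlock deliberately uses a fresh
    // background context so the lock is released even if the caller's ctx is
    // already done.
    func withLock(ctx context.Context, s Store, id string, fn func() error) error {
    	lockID, err := s.Lock(ctx, id)
    	if err != nil {
    		return err
    	}
    	err = fn()

    	unlockCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
    	_ = s.Unlock(unlockCtx, id, lockID) // best effort, error ignored
    	cancel()

    	return err
    }
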
-func (a *account) WithSync(ctx context.Context, balanceFn func() (types.Currency, error)) error { - return withAccountLock(ctx, a.as, a.id, a.host, true, func(_ api.Account) error { - balance, err := balanceFn() - if err != nil { - return err - } - return a.as.SetBalance(ctx, a.id, a.host, balance.Big()) - }) -} - -// WithWithdrawal decreases the balance of an account by the amount returned by -// amtFn. The amount is still withdrawn if amtFn returns an error since some -// costs are non-refundable. -func (a *account) WithWithdrawal(ctx context.Context, amtFn func() (types.Currency, error)) error { - return withAccountLock(ctx, a.as, a.id, a.host, false, func(account api.Account) error { - // return early if the account needs to sync - if account.RequiresSync { - return fmt.Errorf("%w; account requires resync", errBalanceInsufficient) - } - - // return early if our account is not funded - if account.Balance.Cmp(big.NewInt(0)) <= 0 { - return errBalanceInsufficient - } - - // execute amtFn - amt, err := amtFn() - - // in case of an insufficient balance, we schedule a sync - if isBalanceInsufficient(err) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - err = errors.Join(err, a.as.ScheduleSync(ctx, a.id, a.host)) - cancel() - } - - // if an amount was returned, we withdraw it - if !amt.IsZero() { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - err = errors.Join(err, a.as.AddBalance(ctx, a.id, a.host, new(big.Int).Neg(amt.Big()))) - cancel() - } - - return err - }) -} - -// deriveAccountKey derives an account plus key for a given host and worker. -// Each worker has its own account for a given host. That makes concurrency -// around keeping track of an accounts balance and refilling it a lot easier in -// a multi-worker setup. -func (a *accounts) deriveAccountKey(hostKey types.PublicKey) types.PrivateKey { - index := byte(0) // not used yet but can be used to derive more than 1 account per host - - // Append the host for which to create it and the index to the - // corresponding sub-key. - subKey := a.key - data := make([]byte, 0, len(subKey)+len(hostKey)+1) - data = append(data, subKey[:]...) - data = append(data, hostKey[:]...) - data = append(data, index) - - seed := types.HashBytes(data) - pk := types.NewPrivateKeyFromSeed(seed[:]) - for i := range seed { - seed[i] = 0 - } - return pk -} - -// priceTable fetches a price table from the host. If a revision is provided, it -// will be used to pay for the price table. The returned price table is -// guaranteed to be safe to use. -func (h *host) priceTable(ctx context.Context, rev *types.FileContractRevision) (rhpv3.HostPriceTable, error) { - pt, err := h.priceTables.fetch(ctx, h.hk, rev) - if err != nil { - return rhpv3.HostPriceTable{}, err - } - gc, err := GougingCheckerFromContext(ctx, false) - if err != nil { - return rhpv3.HostPriceTable{}, err - } - if breakdown := gc.Check(nil, &pt.HostPriceTable); breakdown.Gouging() { - return rhpv3.HostPriceTable{}, fmt.Errorf("%w: %v", errPriceTableGouging, breakdown) - } - return pt.HostPriceTable, nil -} - -// padBandwitdh pads the bandwidth to the next multiple of 1460 bytes. 1460 -// bytes is the maximum size of a TCP packet when using IPv4. -// TODO: once hostd becomes the only host implementation we can simplify this. 
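In integer terms, the padCost helper inside padBandwidth below computes ceil(cost/p)·p, where p is the price of one minimum-size packet. The same trick with plain uint64s, plus a few sanity values (assuming a 1460-byte packet priced at 10 H per byte, so p = 14600):

    // ceilTo rounds cost up to the next multiple of p using the same
    // add-(p-1)-then-truncate trick padCost applies to types.Currency.
    func ceilTo(cost, p uint64) uint64 {
    	if p == 0 {
    		return cost // free bandwidth, nothing to pad
    	}
    	return (cost + p - 1) / p * p
    }

    // ceilTo(0, 14600)     == 0
    // ceilTo(1, 14600)     == 14600
    // ceilTo(14600, 14600) == 14600
    // ceilTo(14601, 14600) == 29200
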
-func padBandwidth(pt rhpv3.HostPriceTable, rc rhpv3.ResourceCost) rhpv3.ResourceCost { - padCost := func(cost, paddingSize types.Currency) types.Currency { - if paddingSize.IsZero() { - return cost // might happen if bandwidth is free - } - return cost.Add(paddingSize).Sub(types.NewCurrency64(1)).Div(paddingSize).Mul(paddingSize) - } - minPacketSize := uint64(1460) - minIngress := pt.UploadBandwidthCost.Mul64(minPacketSize) - minEgress := pt.DownloadBandwidthCost.Mul64(3*minPacketSize + responseLeeway) - rc.Ingress = padCost(rc.Ingress, minIngress) - rc.Egress = padCost(rc.Egress, minEgress) - return rc -} - -// readSectorCost returns an overestimate for the cost of reading a sector from a host -func readSectorCost(pt rhpv3.HostPriceTable, length uint64) (types.Currency, error) { - rc := pt.BaseCost() - rc = rc.Add(pt.ReadSectorCost(length)) - rc = padBandwidth(pt, rc) - cost, _ := rc.Total() - - // overestimate the cost by 10% - cost, overflow := cost.Mul64WithOverflow(11) - if overflow { - return types.ZeroCurrency, errors.New("overflow occurred while adding leeway to read sector cost") - } - return cost.Div64(10), nil -} - -// uploadSectorCost returns an overestimate for the cost of uploading a sector -// to a host -func uploadSectorCost(pt rhpv3.HostPriceTable, windowEnd uint64) (cost, collateral, storage types.Currency, _ error) { - rc := pt.BaseCost() - rc = rc.Add(pt.AppendSectorCost(windowEnd - pt.HostBlockHeight)) - rc = padBandwidth(pt, rc) - cost, collateral = rc.Total() - - // overestimate the cost by 10% - cost, overflow := cost.Mul64WithOverflow(11) - if overflow { - return types.ZeroCurrency, types.ZeroCurrency, types.ZeroCurrency, errors.New("overflow occurred while adding leeway to read sector cost") - } - return cost.Div64(10), collateral, rc.Storage, nil -} - -func processPayment(s *streamV3, payment rhpv3.PaymentMethod) error { - var paymentType types.Specifier - switch payment.(type) { - case *rhpv3.PayByContractRequest: - paymentType = rhpv3.PaymentTypeContract - case *rhpv3.PayByEphemeralAccountRequest: - paymentType = rhpv3.PaymentTypeEphemeralAccount - default: - panic("unhandled payment method") - } - if err := s.WriteResponse(&paymentType); err != nil { - return err - } else if err := s.WriteResponse(payment); err != nil { - return err - } - if _, ok := payment.(*rhpv3.PayByContractRequest); ok { - var pr rhpv3.PaymentResponse - if err := s.ReadResponse(&pr, defaultRPCResponseMaxSize); err != nil { - return err - } - // TODO: return host signature - } - return nil -} - -// PriceTablePaymentFunc is a function that can be passed in to RPCPriceTable. -// It is called after the price table is received from the host and supposed to -// create a payment for that table and return it. It can also be used to perform -// gouging checks before paying for the table. -type PriceTablePaymentFunc func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) - -// RPCPriceTable calls the UpdatePriceTable RPC. 
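One useful consequence of this contract: a paymentFunc that returns nil turns RPCPriceTable (below) into a free probe, which is exactly how the deleted rhpPriceTableHandler later in this diff used it. As a call site, assuming the imports of this file:

    // probePriceTable fetches a host's current price table without paying.
    // The returned table carries Expiry = time.Now(), i.e. it is meant for
    // inspection and gouging checks only, never for paid RPCs.
    func probePriceTable(ctx context.Context, t *transportV3) (api.HostPriceTable, error) {
    	return RPCPriceTable(ctx, t, func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) {
    		return nil, nil // decide here, after seeing pt, whether to pay
    	})
    }
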
-func RPCPriceTable(ctx context.Context, t *transportV3, paymentFunc PriceTablePaymentFunc) (_ api.HostPriceTable, err error) { - defer wrapErr(ctx, "PriceTable", &err) - - s, err := t.DialStream(ctx) - if err != nil { - return api.HostPriceTable{}, err - } - defer s.Close() - - var pt rhpv3.HostPriceTable - var ptr rhpv3.RPCUpdatePriceTableResponse - if err := s.WriteRequest(rhpv3.RPCUpdatePriceTableID, nil); err != nil { - return api.HostPriceTable{}, fmt.Errorf("couldn't send RPCUpdatePriceTableID: %w", err) - } else if err := s.ReadResponse(&ptr, maxPriceTableSize); err != nil { - return api.HostPriceTable{}, fmt.Errorf("couldn't read RPCUpdatePriceTableResponse: %w", err) - } else if err := json.Unmarshal(ptr.PriceTableJSON, &pt); err != nil { - return api.HostPriceTable{}, fmt.Errorf("couldn't unmarshal price table: %w", err) - } else if payment, err := paymentFunc(pt); err != nil { - return api.HostPriceTable{}, fmt.Errorf("couldn't create payment: %w", err) - } else if payment == nil { - return api.HostPriceTable{ - HostPriceTable: pt, - Expiry: time.Now(), - }, nil // intended not to pay - } else if err := processPayment(s, payment); err != nil { - return api.HostPriceTable{}, fmt.Errorf("couldn't process payment: %w", err) - } else if err := s.ReadResponse(&rhpv3.RPCPriceTableResponse{}, 0); err != nil { - return api.HostPriceTable{}, fmt.Errorf("couldn't read RPCPriceTableResponse: %w", err) - } else { - return api.HostPriceTable{ - HostPriceTable: pt, - Expiry: time.Now().Add(pt.Validity), - }, nil - } -} - -// RPCAccountBalance calls the AccountBalance RPC. -func RPCAccountBalance(ctx context.Context, t *transportV3, payment rhpv3.PaymentMethod, account rhpv3.Account, settingsID rhpv3.SettingsID) (bal types.Currency, err error) { - defer wrapErr(ctx, "AccountBalance", &err) - s, err := t.DialStream(ctx) - if err != nil { - return types.ZeroCurrency, err - } - defer s.Close() - - req := rhpv3.RPCAccountBalanceRequest{ - Account: account, - } - var resp rhpv3.RPCAccountBalanceResponse - if err := s.WriteRequest(rhpv3.RPCAccountBalanceID, &settingsID); err != nil { - return types.ZeroCurrency, err - } else if err := processPayment(s, payment); err != nil { - return types.ZeroCurrency, err - } else if err := s.WriteResponse(&req); err != nil { - return types.ZeroCurrency, err - } else if err := s.ReadResponse(&resp, 128); err != nil { - return types.ZeroCurrency, err - } - return resp.Balance, nil -} - -// RPCFundAccount calls the FundAccount RPC. -func RPCFundAccount(ctx context.Context, t *transportV3, payment rhpv3.PaymentMethod, account rhpv3.Account, settingsID rhpv3.SettingsID) (err error) { - defer wrapErr(ctx, "FundAccount", &err) - s, err := t.DialStream(ctx) - if err != nil { - return err - } - defer s.Close() - - req := rhpv3.RPCFundAccountRequest{ - Account: account, - } - var resp rhpv3.RPCFundAccountResponse - if err := s.WriteRequest(rhpv3.RPCFundAccountID, &settingsID); err != nil { - return err - } else if err := s.WriteResponse(&req); err != nil { - return err - } else if err := processPayment(s, payment); err != nil { - return err - } else if err := s.ReadResponse(&resp, defaultRPCResponseMaxSize); err != nil { - return err - } - return nil -} - -// RPCLatestRevision calls the LatestRevision RPC. The paymentFunc allows for -// fetching a pricetable using the fetched revision to pay for it. If -// paymentFunc returns 'nil' as payment, the host is not paid. 
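A subtlety of the nil-payment path in RPCLatestRevision (below): when paymentFunc returns a nil payment, the function itself returns a zero revision, so an unpaid caller has to capture the revision inside the callback — which is what fetchRevisionNoPayment above does. Spelled out:

    // latestRevisionUnpaid fetches a revision without paying for it. The
    // revision must be captured inside paymentFunc because the unpaid path
    // of RPCLatestRevision returns types.FileContractRevision{}.
    func latestRevisionUnpaid(ctx context.Context, t *transportV3, fcid types.FileContractID) (rev types.FileContractRevision, err error) {
    	_, err = RPCLatestRevision(ctx, t, fcid, func(r *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error) {
    		rev = *r
    		return rhpv3.HostPriceTable{}, nil, nil
    	})
    	return rev, err
    }
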
-func RPCLatestRevision(ctx context.Context, t *transportV3, contractID types.FileContractID, paymentFunc func(rev *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error)) (_ types.FileContractRevision, err error) { - defer wrapErr(ctx, "LatestRevision", &err) - s, err := t.DialStream(ctx) - if err != nil { - return types.FileContractRevision{}, err - } - defer s.Close() - req := rhpv3.RPCLatestRevisionRequest{ - ContractID: contractID, - } - var resp rhpv3.RPCLatestRevisionResponse - if err := s.WriteRequest(rhpv3.RPCLatestRevisionID, &req); err != nil { - return types.FileContractRevision{}, err - } else if err := s.ReadResponse(&resp, defaultRPCResponseMaxSize); err != nil { - return types.FileContractRevision{}, err - } else if pt, payment, err := paymentFunc(&resp.Revision); err != nil || payment == nil { - return types.FileContractRevision{}, err - } else if err := s.WriteResponse(&pt.UID); err != nil { - return types.FileContractRevision{}, err - } else if err := processPayment(s, payment); err != nil { - return types.FileContractRevision{}, err - } - return resp.Revision, nil -} - -// RPCReadSector calls the ExecuteProgram RPC with a ReadSector instruction. -func RPCReadSector(ctx context.Context, t *transportV3, w io.Writer, pt rhpv3.HostPriceTable, payment rhpv3.PaymentMethod, offset, length uint32, merkleRoot types.Hash256) (cost, refund types.Currency, err error) { - defer wrapErr(ctx, "ReadSector", &err) - s, err := t.DialStream(ctx) - if err != nil { - return types.ZeroCurrency, types.ZeroCurrency, err - } - defer s.Close() - - var buf bytes.Buffer - e := types.NewEncoder(&buf) - e.WriteUint64(uint64(length)) - e.WriteUint64(uint64(offset)) - merkleRoot.EncodeTo(e) - e.Flush() - - req := rhpv3.RPCExecuteProgramRequest{ - FileContractID: types.FileContractID{}, - Program: []rhpv3.Instruction{&rhpv3.InstrReadSector{ - LengthOffset: 0, - OffsetOffset: 8, - MerkleRootOffset: 16, - ProofRequired: true, - }}, - ProgramData: buf.Bytes(), - } - - var cancellationToken types.Specifier - var resp rhpv3.RPCExecuteProgramResponse - if err = s.WriteRequest(rhpv3.RPCExecuteProgramID, &pt.UID); err != nil { - return - } else if err = processPayment(s, payment); err != nil { - return - } else if err = s.WriteResponse(&req); err != nil { - return - } else if err = s.ReadResponse(&cancellationToken, 16); err != nil { - return - } else if err = s.ReadResponse(&resp, rhpv2.SectorSize+responseLeeway); err != nil { - return - } - - // check response error - if err = resp.Error; err != nil { - refund = resp.FailureRefund - return - } - cost = resp.TotalCost - - // build proof - proof := make([]crypto.Hash, len(resp.Proof)) - for i, h := range resp.Proof { - proof[i] = crypto.Hash(h) - } - - // verify proof - proofStart := int(offset) / crypto.SegmentSize - proofEnd := int(offset+length) / crypto.SegmentSize - if !crypto.VerifyRangeProof(resp.Output, proof, proofStart, proofEnd, crypto.Hash(merkleRoot)) { - err = errors.New("proof verification failed") - return - } - - _, err = w.Write(resp.Output) - return -} - -func RPCAppendSector(ctx context.Context, t *transportV3, renterKey types.PrivateKey, pt rhpv3.HostPriceTable, rev *types.FileContractRevision, payment rhpv3.PaymentMethod, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte) (cost types.Currency, err error) { - defer wrapErr(ctx, "AppendSector", &err) - - // sanity check revision first - if rev.RevisionNumber == math.MaxUint64 { - return types.ZeroCurrency, errMaxRevisionReached - } - - s, err := 
t.DialStream(ctx) - if err != nil { - return types.ZeroCurrency, err - } - defer s.Close() - - req := rhpv3.RPCExecuteProgramRequest{ - FileContractID: rev.ParentID, - Program: []rhpv3.Instruction{&rhpv3.InstrAppendSector{ - SectorDataOffset: 0, - ProofRequired: true, - }}, - ProgramData: (*sector)[:], - } - - var cancellationToken types.Specifier - var executeResp rhpv3.RPCExecuteProgramResponse - if err = s.WriteRequest(rhpv3.RPCExecuteProgramID, &pt.UID); err != nil { - return - } else if err = processPayment(s, payment); err != nil { - return - } else if err = s.WriteResponse(&req); err != nil { - return - } else if err = s.ReadResponse(&cancellationToken, 16); err != nil { - return - } else if err = s.ReadResponse(&executeResp, defaultRPCResponseMaxSize); err != nil { - return - } - - // compute expected collateral and refund - expectedCost, expectedCollateral, expectedRefund, err := uploadSectorCost(pt, rev.WindowEnd) - if err != nil { - return types.ZeroCurrency, err - } - - // apply leeways. - // TODO: remove once most hosts use hostd. Then we can check for exact values. - expectedCollateral = expectedCollateral.Mul64(9).Div64(10) - expectedCost = expectedCost.Mul64(11).Div64(10) - expectedRefund = expectedRefund.Mul64(9).Div64(10) - - // check if the cost, collateral and refund match our expectation. - if executeResp.TotalCost.Cmp(expectedCost) > 0 { - return types.ZeroCurrency, fmt.Errorf("cost exceeds expectation: %v > %v", executeResp.TotalCost.String(), expectedCost.String()) - } - if executeResp.FailureRefund.Cmp(expectedRefund) < 0 { - return types.ZeroCurrency, fmt.Errorf("insufficient refund: %v < %v", executeResp.FailureRefund.String(), expectedRefund.String()) - } - if executeResp.AdditionalCollateral.Cmp(expectedCollateral) < 0 { - return types.ZeroCurrency, fmt.Errorf("insufficient collateral: %v < %v", executeResp.AdditionalCollateral.String(), expectedCollateral.String()) - } - - // set the cost and refund - cost = executeResp.TotalCost - defer func() { - if err != nil { - cost = types.ZeroCurrency - if executeResp.FailureRefund.Cmp(cost) < 0 { - cost = cost.Sub(executeResp.FailureRefund) - } - } - }() - - // check response error - if err = executeResp.Error; err != nil { - return - } - cost = executeResp.TotalCost - - // include the refund in the collateral - collateral := executeResp.AdditionalCollateral.Add(executeResp.FailureRefund) - - // check proof - if rev.Filesize == 0 { - // For the first upload to a contract we don't get a proof. So we just - // assert that the new contract root matches the root of the sector. - if rev.Filesize == 0 && executeResp.NewMerkleRoot != sectorRoot { - return types.ZeroCurrency, fmt.Errorf("merkle root doesn't match the sector root upon first upload to contract: %v != %v", executeResp.NewMerkleRoot, sectorRoot) - } - } else { - // Otherwise we make sure the proof was transmitted and verify it. - actions := []rhpv2.RPCWriteAction{{Type: rhpv2.RPCWriteActionAppend}} // TODO: change once rhpv3 support is available - if !rhpv2.VerifyDiffProof(actions, rev.Filesize/rhpv2.SectorSize, executeResp.Proof, []types.Hash256{}, rev.FileMerkleRoot, executeResp.NewMerkleRoot, []types.Hash256{sectorRoot}) { - return types.ZeroCurrency, errors.New("proof verification failed") - } - } - - // finalize the program with a new revision. 
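The three comparisons a few lines up are easy to misread because the leeway is folded into the expected values first. With plain uint64s standing in for types.Currency (so ignoring the overflow handling the real code gets from Currency arithmetic):

    import "fmt"

    // checkAppendResponse widens the local estimates by the fixed margins,
    // then requires the host's figures to stay on the right side of each
    // bound: cost may only overshoot by 10%, refund and collateral may only
    // undershoot by 10%.
    func checkAppendResponse(gotCost, gotRefund, gotCollateral, expCost, expRefund, expCollateral uint64) error {
    	expCost = expCost * 11 / 10            // tolerate a 10% overcharge
    	expRefund = expRefund * 9 / 10         // tolerate a 10% smaller refund
    	expCollateral = expCollateral * 9 / 10 // tolerate 10% less collateral
    	switch {
    	case gotCost > expCost:
    		return fmt.Errorf("cost exceeds expectation: %d > %d", gotCost, expCost)
    	case gotRefund < expRefund:
    		return fmt.Errorf("insufficient refund: %d < %d", gotRefund, expRefund)
    	case gotCollateral < expCollateral:
    		return fmt.Errorf("insufficient collateral: %d < %d", gotCollateral, expCollateral)
    	}
    	return nil
    }
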
- newRevision := *rev - newValid, newMissed, err := updateRevisionOutputs(&newRevision, types.ZeroCurrency, collateral) - if err != nil { - return types.ZeroCurrency, err - } - newRevision.Filesize += rhpv2.SectorSize - newRevision.RevisionNumber++ - newRevision.FileMerkleRoot = executeResp.NewMerkleRoot - - finalizeReq := rhpv3.RPCFinalizeProgramRequest{ - Signature: renterKey.SignHash(hashRevision(newRevision)), - ValidProofValues: newValid, - MissedProofValues: newMissed, - RevisionNumber: newRevision.RevisionNumber, - } - - var finalizeResp rhpv3.RPCFinalizeProgramResponse - if err = s.WriteResponse(&finalizeReq); err != nil { - return - } else if err = s.ReadResponse(&finalizeResp, 64); err != nil { - return - } - - // read one more time to receive a potential error in case finalising the - // contract fails after receiving the RPCFinalizeProgramResponse. This also - // guarantees that the program is finalised before we return. - // TODO: remove once most hosts use hostd. - errFinalise := s.ReadResponse(&finalizeResp, 64) - if errFinalise != nil && - !errors.Is(errFinalise, io.EOF) && - !errors.Is(errFinalise, mux.ErrClosedConn) && - !errors.Is(errFinalise, mux.ErrClosedStream) && - !errors.Is(errFinalise, mux.ErrPeerClosedStream) && - !errors.Is(errFinalise, mux.ErrPeerClosedConn) { - err = errFinalise - return - } - - *rev = newRevision - return -} - -func RPCRenew(ctx context.Context, rrr api.RHPRenewRequest, bus Bus, t *transportV3, pt *rhpv3.HostPriceTable, rev types.FileContractRevision, renterKey types.PrivateKey, l *zap.SugaredLogger) (_ rhpv2.ContractRevision, _ []types.Transaction, _, _ types.Currency, err error) { - defer wrapErr(ctx, "RPCRenew", &err) - - s, err := t.DialStream(ctx) - if err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to dial stream: %w", err) - } - defer s.Close() - - // Send the ptUID. - var ptUID rhpv3.SettingsID - if pt != nil { - ptUID = pt.UID - } - if err = s.WriteRequest(rhpv3.RPCRenewContractID, &ptUID); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to send ptUID: %w", err) - } - - // If we didn't have a valid pricetable, read the temporary one from the - // host. - if ptUID == (rhpv3.SettingsID{}) { - var ptResp rhpv3.RPCUpdatePriceTableResponse - if err = s.ReadResponse(&ptResp, defaultRPCResponseMaxSize); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to read RPCUpdatePriceTableResponse: %w", err) - } - pt = new(rhpv3.HostPriceTable) - if err = json.Unmarshal(ptResp.PriceTableJSON, pt); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to unmarshal price table: %w", err) - } - } - - // Perform gouging checks. 
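The price table being gouging-checked next was negotiated a few lines up: RPCRenew sends the zero SettingsID to mean "no valid table", and the host answers with a temporary table scoped to this renewal. That handshake, isolated as a sketch over the same stream and types used in this file (assumes "encoding/json" is imported):

    // negotiatePriceTable sends our table's UID, or the zero UID if we have
    // none; in the latter case it reads back a temporary table that is only
    // valid for the renewal in progress.
    func negotiatePriceTable(s *streamV3, pt *rhpv3.HostPriceTable) (*rhpv3.HostPriceTable, error) {
    	var ptUID rhpv3.SettingsID
    	if pt != nil {
    		ptUID = pt.UID
    	}
    	if err := s.WriteRequest(rhpv3.RPCRenewContractID, &ptUID); err != nil {
    		return nil, err
    	}
    	if ptUID != (rhpv3.SettingsID{}) {
    		return pt, nil // the host will use the table we referenced
    	}
    	var resp rhpv3.RPCUpdatePriceTableResponse
    	if err := s.ReadResponse(&resp, defaultRPCResponseMaxSize); err != nil {
    		return nil, err
    	}
    	pt = new(rhpv3.HostPriceTable)
    	return pt, json.Unmarshal(resp.PriceTableJSON, pt)
    }
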
- gc, err := GougingCheckerFromContext(ctx, false) - if err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to get gouging checker: %w", err) - } - if breakdown := gc.Check(nil, pt); breakdown.Gouging() { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("host gouging during renew: %v", breakdown) - } - - // Prepare the signed transaction that contains the final revision as well - // as the new contract - wprr, err := bus.WalletPrepareRenew(ctx, rev, rrr.HostAddress, rrr.RenterAddress, renterKey, rrr.RenterFunds, rrr.MinNewCollateral, rrr.MaxFundAmount, *pt, rrr.EndHeight, rrr.WindowSize, rrr.ExpectedNewStorage) - if err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to prepare renew: %w", err) - } - - // Starting from here, we need to make sure to release the txn on error. - defer discardTxnOnErr(ctx, bus, l, wprr.TransactionSet[len(wprr.TransactionSet)-1], "RPCRenew", &err) - - txnSet := wprr.TransactionSet - parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] - - // Sign only the revision and contract. We can't sign everything because - // then the host can't add its own outputs. - h := types.NewHasher() - txn.FileContracts[0].EncodeTo(h.E) - txn.FileContractRevisions[0].EncodeTo(h.E) - finalRevisionSignature := renterKey.SignHash(h.Sum()) - - // Send the request. - req := rhpv3.RPCRenewContractRequest{ - TransactionSet: txnSet, - RenterKey: rev.UnlockConditions.PublicKeys[0], - FinalRevisionSignature: finalRevisionSignature, - } - if err = s.WriteResponse(&req); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to send RPCRenewContractRequest: %w", err) - } - - // Incorporate the host's additions. - var hostAdditions rhpv3.RPCRenewContractHostAdditions - if err = s.ReadResponse(&hostAdditions, defaultRPCResponseMaxSize); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to read RPCRenewContractHostAdditions: %w", err) - } - parents = append(parents, hostAdditions.Parents...) - txn.SiacoinInputs = append(txn.SiacoinInputs, hostAdditions.SiacoinInputs...) - txn.SiacoinOutputs = append(txn.SiacoinOutputs, hostAdditions.SiacoinOutputs...) - finalRevRenterSig := types.TransactionSignature{ - ParentID: types.Hash256(rev.ParentID), - PublicKeyIndex: 0, // renter key is first - CoveredFields: types.CoveredFields{ - FileContracts: []uint64{0}, - FileContractRevisions: []uint64{0}, - }, - Signature: finalRevisionSignature[:], - } - finalRevHostSig := types.TransactionSignature{ - ParentID: types.Hash256(rev.ParentID), - PublicKeyIndex: 1, - CoveredFields: types.CoveredFields{ - FileContracts: []uint64{0}, - FileContractRevisions: []uint64{0}, - }, - Signature: hostAdditions.FinalRevisionSignature[:], - } - txn.Signatures = []types.TransactionSignature{finalRevRenterSig, finalRevHostSig} - - // Sign the inputs we funded the txn with and cover the whole txn including - // the existing signatures. - cf := types.CoveredFields{ - WholeTransaction: true, - Signatures: []uint64{0, 1}, - } - if err := bus.WalletSign(ctx, &txn, wprr.ToSign, cf); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to sign transaction: %w", err) - } - - // Create a new no-op revision and sign it. 
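The no-op revision signed next, like the final-revision signature earlier in this function, hashes only a narrow slice of the transaction: the renter cannot sign the whole thing yet, because the host still appends its own inputs, outputs, and signatures. The final-revision step as a standalone helper:

    // signContractAndRevision covers only the new contract and the final
    // revision, leaving the rest of the transaction open for the host's
    // additions; a whole-transaction signature follows via WalletSign once
    // those additions are incorporated.
    func signContractAndRevision(txn types.Transaction, renterKey types.PrivateKey) types.Signature {
    	h := types.NewHasher()
    	txn.FileContracts[0].EncodeTo(h.E)
    	txn.FileContractRevisions[0].EncodeTo(h.E)
    	return renterKey.SignHash(h.Sum())
    }
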
- noOpRevision := initialRevision(txn, rev.UnlockConditions.PublicKeys[1], renterKey.PublicKey().UnlockKey()) - h = types.NewHasher() - noOpRevision.EncodeTo(h.E) - renterNoOpSig := renterKey.SignHash(h.Sum()) - renterNoOpRevisionSignature := types.TransactionSignature{ - ParentID: types.Hash256(noOpRevision.ParentID), - PublicKeyIndex: 0, // renter key is first - CoveredFields: types.CoveredFields{ - FileContractRevisions: []uint64{0}, - }, - Signature: renterNoOpSig[:], - } - - // Send the newly added signatures to the host and the signature for the - // initial no-op revision. - rs := rhpv3.RPCRenewSignatures{ - TransactionSignatures: txn.Signatures[2:], - RevisionSignature: renterNoOpRevisionSignature, - } - if err = s.WriteResponse(&rs); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to send RPCRenewSignatures: %w", err) - } - - // Receive the host's signatures. - var hostSigs rhpv3.RPCRenewSignatures - if err = s.ReadResponse(&hostSigs, defaultRPCResponseMaxSize); err != nil { - return rhpv2.ContractRevision{}, nil, types.Currency{}, types.Currency{}, fmt.Errorf("failed to read RPCRenewSignatures: %w", err) - } - txn.Signatures = append(txn.Signatures, hostSigs.TransactionSignatures...) - - // Add the parents to get the full txnSet. - txnSet = parents - txnSet = append(txnSet, txn) - - return rhpv2.ContractRevision{ - Revision: noOpRevision, - Signatures: [2]types.TransactionSignature{renterNoOpRevisionSignature, hostSigs.RevisionSignature}, - }, txnSet, pt.ContractPrice, wprr.FundAmount, nil -} - -// initialRevision returns the first revision of a file contract formation -// transaction. -func initialRevision(formationTxn types.Transaction, hostPubKey, renterPubKey types.UnlockKey) types.FileContractRevision { - fc := formationTxn.FileContracts[0] - return types.FileContractRevision{ - ParentID: formationTxn.FileContractID(0), - UnlockConditions: types.UnlockConditions{ - PublicKeys: []types.UnlockKey{renterPubKey, hostPubKey}, - SignaturesRequired: 2, - }, - FileContract: types.FileContract{ - Filesize: fc.Filesize, - FileMerkleRoot: fc.FileMerkleRoot, - WindowStart: fc.WindowStart, - WindowEnd: fc.WindowEnd, - ValidProofOutputs: fc.ValidProofOutputs, - MissedProofOutputs: fc.MissedProofOutputs, - UnlockHash: fc.UnlockHash, - RevisionNumber: 1, - }, - } -} - -func payByContract(rev *types.FileContractRevision, amount types.Currency, refundAcct rhpv3.Account, sk types.PrivateKey) (rhpv3.PayByContractRequest, error) { - if rev.RevisionNumber == math.MaxUint64 { - return rhpv3.PayByContractRequest{}, errMaxRevisionReached - } - payment, ok := rhpv3.PayByContract(rev, amount, refundAcct, sk) - if !ok { - return rhpv3.PayByContractRequest{}, ErrInsufficientFunds - } - return payment, nil -} diff --git a/worker/s3/s3.go b/worker/s3/s3.go index cc6021652..d5cbb71a3 100644 --- a/worker/s3/s3.go +++ b/worker/s3/s3.go @@ -69,12 +69,12 @@ func (l *gofakes3Logger) Print(level gofakes3.LogLevel, v ...interface{}) { } } -func New(b Bus, w Worker, logger *zap.SugaredLogger, opts Opts) (http.Handler, error) { - namedLogger := logger.Named("s3") +func New(b Bus, w Worker, logger *zap.Logger, opts Opts) (http.Handler, error) { + logger = logger.Named("s3") s3Backend := &s3{ b: b, w: w, - logger: namedLogger, + logger: logger.Sugar(), } backend := gofakes3.Backend(s3Backend) if !opts.AuthDisabled { @@ -84,9 +84,7 @@ func New(b Bus, w Worker, logger *zap.SugaredLogger, opts Opts) (http.Handler, e backend, 
gofakes3.WithHostBucket(opts.HostBucketEnabled), gofakes3.WithHostBucketBase(opts.HostBucketBases...), - gofakes3.WithLogger(&gofakes3Logger{ - l: namedLogger, - }), + gofakes3.WithLogger(&gofakes3Logger{l: logger.Sugar()}), gofakes3.WithRequestID(rand.Uint64()), gofakes3.WithoutVersioning(), ) diff --git a/worker/spending.go b/worker/spending.go index 87d2ec17d..6b9914c58 100644 --- a/worker/spending.go +++ b/worker/spending.go @@ -35,7 +35,7 @@ var ( _ ContractSpendingRecorder = (*contractSpendingRecorder)(nil) ) -func (w *worker) initContractSpendingRecorder(flushInterval time.Duration) { +func (w *Worker) initContractSpendingRecorder(flushInterval time.Duration) { if w.contractSpendingRecorder != nil { panic("ContractSpendingRecorder already initialized") // developer error } diff --git a/worker/upload.go b/worker/upload.go index 4b7595bc0..6ced77f9a 100644 --- a/worker/upload.go +++ b/worker/upload.go @@ -17,8 +17,8 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" - "go.sia.tech/renterd/stats" "go.uber.org/zap" ) @@ -51,8 +51,8 @@ type ( maxOverdrive uint64 overdriveTimeout time.Duration - statsOverdrivePct *stats.DataPoints - statsSlabUploadSpeedBytesPerMS *stats.DataPoints + statsOverdrivePct *utils.DataPoints + statsSlabUploadSpeedBytesPerMS *utils.DataPoints shutdownCtx context.Context @@ -146,18 +146,17 @@ type ( } ) -func (w *worker) initUploadManager(maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.SugaredLogger) { +func (w *Worker) initUploadManager(maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.Logger) { if w.uploadManager != nil { panic("upload manager already initialized") // developer error } - mm := newMemoryManager(logger.Named("memorymanager"), maxMemory) - w.uploadManager = newUploadManager(w.shutdownCtx, w, mm, w.bus, w.bus, w.bus, maxOverdrive, overdriveTimeout, w.contractLockingDuration, logger) + w.uploadManager = newUploadManager(w.shutdownCtx, w, w.bus, w.bus, w.bus, maxMemory, maxOverdrive, overdriveTimeout, w.contractLockingDuration, logger) } -func (w *worker) upload(ctx context.Context, bucket, path string, r io.Reader, contracts []api.ContractMetadata, opts ...UploadOption) (_ string, err error) { +func (w *Worker) upload(ctx context.Context, bucket, path string, rs api.RedundancySettings, r io.Reader, contracts []api.ContractMetadata, opts ...UploadOption) (_ string, err error) { // apply the options - up := defaultParameters(bucket, path) + up := defaultParameters(bucket, path, rs) for _, opt := range opts { opt(&up) } @@ -212,7 +211,7 @@ func (w *worker) upload(ctx context.Context, bucket, path string, r io.Reader, c return eTag, nil } -func (w *worker) threadedUploadPackedSlabs(rs api.RedundancySettings, contractSet string, lockPriority int) { +func (w *Worker) threadedUploadPackedSlabs(rs api.RedundancySettings, contractSet string, lockPriority int) { key := fmt.Sprintf("%d-%d_%s", rs.MinShards, rs.TotalShards, contractSet) w.uploadsMu.Lock() if _, ok := w.uploadingPackedSlabs[key]; ok { @@ -278,7 +277,7 @@ func (w *worker) threadedUploadPackedSlabs(rs api.RedundancySettings, contractSe wg.Wait() } -func (w *worker) tryUploadPackedSlab(ctx context.Context, mem Memory, ps api.PackedSlab, rs api.RedundancySettings, contractSet string, lockPriority int) error { +func (w *Worker) tryUploadPackedSlab(ctx context.Context, mem Memory, ps api.PackedSlab, rs api.RedundancySettings, contractSet 
string, lockPriority int) error { // fetch contracts contracts, err := w.bus.Contracts(ctx, api.ContractsOpts{ContractSet: contractSet}) if err != nil { @@ -303,22 +302,23 @@ func (w *worker) tryUploadPackedSlab(ctx context.Context, mem Memory, ps api.Pac return nil } -func newUploadManager(ctx context.Context, hm HostManager, mm MemoryManager, os ObjectStore, cl ContractLocker, cs ContractStore, maxOverdrive uint64, overdriveTimeout time.Duration, contractLockDuration time.Duration, logger *zap.SugaredLogger) *uploadManager { +func newUploadManager(ctx context.Context, hm HostManager, os ObjectStore, cl ContractLocker, cs ContractStore, maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, contractLockDuration time.Duration, logger *zap.Logger) *uploadManager { + logger = logger.Named("uploadmanager") return &uploadManager{ hm: hm, - mm: mm, + mm: newMemoryManager(maxMemory, logger), os: os, cl: cl, cs: cs, - logger: logger, + logger: logger.Sugar(), contractLockDuration: contractLockDuration, maxOverdrive: maxOverdrive, overdriveTimeout: overdriveTimeout, - statsOverdrivePct: stats.NoDecay(), - statsSlabUploadSpeedBytesPerMS: stats.NoDecay(), + statsOverdrivePct: utils.NewDataPoints(0), + statsSlabUploadSpeedBytesPerMS: utils.NewDataPoints(0), shutdownCtx: ctx, @@ -341,8 +341,8 @@ func (mgr *uploadManager) newUploader(os ObjectStore, cl ContractLocker, cs Cont signalNewUpload: make(chan struct{}, 1), // stats - statsSectorUploadEstimateInMS: stats.Default(), - statsSectorUploadSpeedBytesPerMS: stats.NoDecay(), + statsSectorUploadEstimateInMS: utils.NewDataPoints(10 * time.Minute), + statsSectorUploadSpeedBytesPerMS: utils.NewDataPoints(0), // covered by mutex host: hm.Host(c.HostKey, c.ID, c.SiamuxAddr), diff --git a/worker/upload_params.go b/worker/upload_params.go index ae8baa8d0..109488bb9 100644 --- a/worker/upload_params.go +++ b/worker/upload_params.go @@ -2,7 +2,6 @@ package worker import ( "go.sia.tech/renterd/api" - "go.sia.tech/renterd/build" "go.sia.tech/renterd/object" ) @@ -26,7 +25,7 @@ type uploadParameters struct { metadata api.ObjectUserMetadata } -func defaultParameters(bucket, path string) uploadParameters { +func defaultParameters(bucket, path string, rs api.RedundancySettings) uploadParameters { return uploadParameters{ bucket: bucket, path: path, @@ -34,7 +33,7 @@ func defaultParameters(bucket, path string) uploadParameters { ec: object.GenerateEncryptionKey(), // random key encryptionOffset: 0, // from the beginning - rs: build.DefaultRedundancySettings, + rs: rs, } } @@ -82,12 +81,6 @@ func WithPartNumber(partNumber int) UploadOption { } } -func WithRedundancySettings(rs api.RedundancySettings) UploadOption { - return func(up *uploadParameters) { - up.rs = rs - } -} - func WithUploadID(uploadID string) UploadOption { return func(up *uploadParameters) { up.uploadID = uploadID diff --git a/worker/upload_test.go b/worker/upload_test.go index 320c71736..5bad0941a 100644 --- a/worker/upload_test.go +++ b/worker/upload_test.go @@ -222,7 +222,7 @@ func TestUploadPackedSlab(t *testing.T) { uploadBytes := func(n int) { t.Helper() params.path = fmt.Sprintf("%s_%d", t.Name(), c) - _, err := w.upload(context.Background(), params.bucket, params.path, bytes.NewReader(frand.Bytes(n)), w.Contracts(), opts...) + _, err := w.upload(context.Background(), params.bucket, params.path, testRedundancySettings, bytes.NewReader(frand.Bytes(n)), w.Contracts(), opts...) 
if err != nil { t.Fatal(err) } @@ -599,7 +599,7 @@ func TestUploadRegression(t *testing.T) { // upload data ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, err := w.upload(ctx, params.bucket, params.path, bytes.NewReader(data), w.Contracts(), testOpts()...) + _, err := w.upload(ctx, params.bucket, params.path, testRedundancySettings, bytes.NewReader(data), w.Contracts(), testOpts()...) if !errors.Is(err, errUploadInterrupted) { t.Fatal(err) } @@ -608,7 +608,7 @@ func TestUploadRegression(t *testing.T) { unblock() // upload data - _, err = w.upload(context.Background(), params.bucket, params.path, bytes.NewReader(data), w.Contracts(), testOpts()...) + _, err = w.upload(context.Background(), params.bucket, params.path, testRedundancySettings, bytes.NewReader(data), w.Contracts(), testOpts()...) if err != nil { t.Fatal(err) } @@ -676,6 +676,5 @@ func testParameters(path string) uploadParameters { func testOpts() []UploadOption { return []UploadOption{ WithContractSet(testContractSet), - WithRedundancySettings(testRedundancySettings), } } diff --git a/worker/uploader.go b/worker/uploader.go index 4a9d9aa9a..410760b1a 100644 --- a/worker/uploader.go +++ b/worker/uploader.go @@ -11,8 +11,8 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" "go.sia.tech/renterd/internal/utils" - "go.sia.tech/renterd/stats" "go.uber.org/zap" ) @@ -50,8 +50,8 @@ type ( consecutiveFailures uint64 lastRecompute time.Time - statsSectorUploadEstimateInMS *stats.DataPoints - statsSectorUploadSpeedBytesPerMS *stats.DataPoints + statsSectorUploadEstimateInMS *utils.DataPoints + statsSectorUploadSpeedBytesPerMS *utils.DataPoints } ) @@ -121,7 +121,7 @@ outer: start := time.Now() duration, err := u.execute(req) elapsed := time.Since(start) - if errors.Is(err, errMaxRevisionReached) { + if errors.Is(err, rhp3.ErrMaxRevisionReached) { if u.tryRefresh(req.sector.ctx) { u.enqueue(req) continue outer @@ -154,7 +154,7 @@ outer: func handleSectorUpload(uploadErr error, uploadDuration, totalDuration time.Duration, overdrive bool) (success bool, failure bool, uploadEstimateMS float64, uploadSpeedBytesPerMS float64) { // no-op cases - if utils.IsErr(uploadErr, errMaxRevisionReached) { + if utils.IsErr(uploadErr, rhp3.ErrMaxRevisionReached) { return false, false, 0, 0 } else if utils.IsErr(uploadErr, context.Canceled) { return false, false, 0, 0 @@ -172,7 +172,7 @@ func handleSectorUpload(uploadErr error, uploadDuration, totalDuration time.Dura // upload failed because we weren't able to create a payment, in this case // we want to punish the host but only to ensure we stop using it, meaning // we don't increment consecutive failures - if utils.IsErr(uploadErr, errFailedToCreatePayment) { + if utils.IsErr(uploadErr, rhp3.ErrFailedToCreatePayment) { return false, false, float64(time.Hour.Milliseconds()), 0 } @@ -180,7 +180,7 @@ func handleSectorUpload(uploadErr error, uploadDuration, totalDuration time.Dura // this case we want to punish the host for being too slow but only when we // weren't overdriving or when it took too long to dial if utils.IsErr(uploadErr, errSectorUploadFinished) { - slowDial := utils.IsErr(uploadErr, errDialTransport) && totalDuration > time.Second + slowDial := utils.IsErr(uploadErr, rhp3.ErrDialTransport) && totalDuration > time.Second slowLock := utils.IsErr(uploadErr, errAcquireContractFailed) && totalDuration > time.Second slowFetchRev := utils.IsErr(uploadErr, 
errFetchRevisionFailed) && totalDuration > time.Second if !overdrive || slowDial || slowLock || slowFetchRev { @@ -289,7 +289,7 @@ func (u *uploader) execute(req *sectorUploadReq) (_ time.Duration, err error) { if err != nil { return 0, fmt.Errorf("%w; %w", errFetchRevisionFailed, err) } else if rev.RevisionNumber == math.MaxUint64 { - return 0, errMaxRevisionReached + return 0, rhp3.ErrMaxRevisionReached } // update the bus diff --git a/worker/uploader_test.go b/worker/uploader_test.go index 46df5f584..172d4250b 100644 --- a/worker/uploader_test.go +++ b/worker/uploader_test.go @@ -8,6 +8,7 @@ import ( "time" rhpv2 "go.sia.tech/core/rhp/v2" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" ) func TestUploaderStopped(t *testing.T) { @@ -43,7 +44,7 @@ func TestHandleSectorUpload(t *testing.T) { regular := false errHostError := errors.New("some host error") - errSectorUploadFinishedAndDial := fmt.Errorf("%w;%w", errDialTransport, errSectorUploadFinished) + errSectorUploadFinishedAndDial := fmt.Errorf("%w;%w", rhp3.ErrDialTransport, errSectorUploadFinished) cases := []struct { // input @@ -63,8 +64,8 @@ func TestHandleSectorUpload(t *testing.T) { {nil, ms, ms, overdrive, true, false, 1, ss}, // renewed contract case - {errMaxRevisionReached, 0, ms, regular, false, false, 0, 0}, - {errMaxRevisionReached, 0, ms, overdrive, false, false, 0, 0}, + {rhp3.ErrMaxRevisionReached, 0, ms, regular, false, false, 0, 0}, + {rhp3.ErrMaxRevisionReached, 0, ms, overdrive, false, false, 0, 0}, // context canceled case {context.Canceled, 0, ms, regular, false, false, 0, 0}, @@ -77,8 +78,8 @@ func TestHandleSectorUpload(t *testing.T) { {errSectorUploadFinishedAndDial, ms, 1001 * ms, overdrive, false, true, 10010, 0}, // payment failure case - {errFailedToCreatePayment, 0, ms, regular, false, false, 3600000, 0}, - {errFailedToCreatePayment, 0, ms, overdrive, false, false, 3600000, 0}, + {rhp3.ErrFailedToCreatePayment, 0, ms, regular, false, false, 3600000, 0}, + {rhp3.ErrFailedToCreatePayment, 0, ms, overdrive, false, false, 3600000, 0}, // host failure {errHostError, ms, ms, regular, false, true, 3600000, 0}, diff --git a/worker/worker.go b/worker/worker.go index 03ce9c874..7073e0c63 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -8,6 +8,7 @@ import ( "io" "math" "math/big" + "net" "net/http" "os" "runtime" @@ -24,21 +25,21 @@ import ( "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" "go.sia.tech/renterd/build" + "go.sia.tech/renterd/config" + "go.sia.tech/renterd/internal/gouging" + rhp2 "go.sia.tech/renterd/internal/rhp/v2" + rhp3 "go.sia.tech/renterd/internal/rhp/v3" "go.sia.tech/renterd/internal/utils" iworker "go.sia.tech/renterd/internal/worker" "go.sia.tech/renterd/object" "go.sia.tech/renterd/webhooks" "go.sia.tech/renterd/worker/client" - "go.sia.tech/siad/modules" + "go.sia.tech/renterd/worker/s3" "go.uber.org/zap" "golang.org/x/crypto/blake2b" ) const ( - batchSizeDeleteSectors = uint64(1000) // 4GiB of contract data - batchSizeFetchSectors = uint64(25600) // 100GiB of contract data - - defaultLockTimeout = time.Minute defaultRevisionFetchTimeout = 30 * time.Second lockingPriorityActiveContractRevision = 100 @@ -69,8 +70,10 @@ func NewClient(address, password string) *Client { type ( Bus interface { + s3.Bus + alerts.Alerter - ConsensusState + gouging.ConsensusState webhooks.Broadcaster AccountStore @@ -93,7 +96,6 @@ type ( LockAccount(ctx context.Context, id rhpv3.Account, hostKey types.PublicKey, exclusive bool, duration time.Duration) (api.Account, uint64, error) UnlockAccount(ctx 
context.Context, id rhpv3.Account, lockID uint64) error - ResetDrift(ctx context.Context, id rhpv3.Account) error SetBalance(ctx context.Context, id rhpv3.Account, hk types.PublicKey, amt *big.Int) error ScheduleSync(ctx context.Context, id rhpv3.Account, hk types.PublicKey) error } @@ -158,17 +160,14 @@ type ( WebhookStore interface { RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error - } - - ConsensusState interface { - ConsensusState(ctx context.Context) (api.ConsensusState, error) + UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error } ) // deriveSubKey can be used to derive a sub-masterkey from the worker's // masterkey to use for a specific purpose. Such as deriving more keys for // ephemeral accounts. -func (w *worker) deriveSubKey(purpose string) types.PrivateKey { +func (w *Worker) deriveSubKey(purpose string) types.PrivateKey { seed := blake2b.Sum256(append(w.masterKey[:], []byte(purpose)...)) pk := types.NewPrivateKeyFromSeed(seed[:]) for i := range seed { @@ -189,7 +188,7 @@ func (w *worker) deriveSubKey(purpose string) types.PrivateKey { // // TODO: instead of deriving a renter key use a randomly generated salt so we're // not limited to one key per host -func (w *worker) deriveRenterKey(hostKey types.PublicKey) types.PrivateKey { +func (w *Worker) deriveRenterKey(hostKey types.PublicKey) types.PrivateKey { seed := blake2b.Sum256(append(w.deriveSubKey("renterkey"), hostKey[:]...)) pk := types.NewPrivateKeyFromSeed(seed[:]) for i := range seed { @@ -200,22 +199,26 @@ func (w *worker) deriveRenterKey(hostKey types.PublicKey) types.PrivateKey { // A worker talks to Sia hosts to perform contract and storage operations within // a renterd system. -type worker struct { +type Worker struct { alerts alerts.Alerter + rhp2Client *rhp2.Client + rhp3Client *rhp3.Client + allowPrivateIPs bool id string bus Bus masterKey [32]byte startTime time.Time + eventSubscriber iworker.EventSubscriber downloadManager *downloadManager uploadManager *uploadManager - accounts *accounts - cache iworker.WorkerCache - priceTables *priceTables - transportPoolV3 *transportPoolV3 + accounts *accounts + dialer *iworker.FallbackDialer + cache iworker.WorkerCache + priceTables *priceTables uploadsMu sync.Mutex uploadingPackedSlabs map[string]struct{} @@ -229,7 +232,7 @@ type worker struct { logger *zap.SugaredLogger } -func (w *worker) isStopped() bool { +func (w *Worker) isStopped() bool { select { case <-w.shutdownCtx.Done(): return true @@ -238,7 +241,7 @@ func (w *worker) isStopped() bool { return false } -func (w *worker) withRevision(ctx context.Context, fetchTimeout time.Duration, fcid types.FileContractID, hk types.PublicKey, siamuxAddr string, lockPriority int, fn func(rev types.FileContractRevision) error) error { +func (w *Worker) withRevision(ctx context.Context, fetchTimeout time.Duration, fcid types.FileContractID, hk types.PublicKey, siamuxAddr string, lockPriority int, fn func(rev types.FileContractRevision) error) error { return w.withContractLock(ctx, fcid, lockPriority, func() error { h := w.Host(hk, fcid, siamuxAddr) rev, err := h.FetchRevision(ctx, fetchTimeout) @@ -249,7 +252,7 @@ func (w *worker) withRevision(ctx context.Context, fetchTimeout time.Duration, f }) } -func (w *worker) registerAlert(a alerts.Alert) { +func (w *Worker) registerAlert(a alerts.Alert) { ctx, cancel := context.WithTimeout(w.shutdownCtx, time.Minute) if err := w.alerts.RegisterAlert(ctx, a); err != nil { w.logger.Errorf("failed to register alert, err: %v", err) @@ -257,7 +260,7 @@ 
func (w *worker) registerAlert(a alerts.Alert) { cancel() } -func (w *worker) rhpScanHandler(jc jape.Context) { +func (w *Worker) rhpScanHandler(jc jape.Context) { ctx := jc.Request.Context() // decode the request @@ -291,7 +294,7 @@ func (w *worker) rhpScanHandler(jc jape.Context) { }) } -func (w *worker) fetchContracts(ctx context.Context, metadatas []api.ContractMetadata, timeout time.Duration) (contracts []api.Contract, errs HostErrorSet) { +func (w *Worker) fetchContracts(ctx context.Context, metadatas []api.ContractMetadata, timeout time.Duration) (contracts []api.Contract, errs HostErrorSet) { errs = make(HostErrorSet) // create requests channel @@ -343,7 +346,7 @@ func (w *worker) fetchContracts(ctx context.Context, metadatas []api.ContractMet return } -func (w *worker) rhpPriceTableHandler(jc jape.Context) { +func (w *Worker) rhpPriceTableHandler(jc jape.Context) { // decode the request var rptr api.RHPPriceTableRequest if jc.Decode(&rptr) != nil { @@ -355,14 +358,16 @@ func (w *worker) rhpPriceTableHandler(jc jape.Context) { var err error var hpt api.HostPriceTable defer func() { - w.bus.RecordPriceTables(jc.Request.Context(), []api.HostPriceTableUpdate{ - { - HostKey: rptr.HostKey, - Success: isSuccessfulInteraction(err), - Timestamp: time.Now(), - PriceTable: hpt, - }, - }) + if shouldRecordPriceTable(err) { + w.bus.RecordPriceTables(jc.Request.Context(), []api.HostPriceTableUpdate{ + { + HostKey: rptr.HostKey, + Success: err == nil, + Timestamp: time.Now(), + PriceTable: hpt, + }, + }) + } }() // apply timeout @@ -373,22 +378,14 @@ func (w *worker) rhpPriceTableHandler(jc jape.Context) { defer cancel() } - err = w.transportPoolV3.withTransportV3(ctx, rptr.HostKey, rptr.SiamuxAddr, func(ctx context.Context, t *transportV3) error { - hpt, err = RPCPriceTable(ctx, t, func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) { return nil, nil }) - return err - }) - + hpt, err = w.Host(rptr.HostKey, types.FileContractID{}, rptr.SiamuxAddr).PriceTableUnpaid(ctx) if jc.Check("could not get price table", err) != nil { return } jc.Encode(hpt) } -func (w *worker) discardTxnOnErr(txn types.Transaction, errContext string, err *error) { - discardTxnOnErr(w.shutdownCtx, w.bus, w.logger, txn, errContext, err) -} - -func (w *worker) rhpFormHandler(jc jape.Context) { +func (w *Worker) rhpFormHandler(jc jape.Context) { ctx := jc.Request.Context() // decode the request @@ -411,42 +408,20 @@ func (w *worker) rhpFormHandler(jc jape.Context) { if jc.Check("could not get gouging parameters", err) != nil { return } + gc := newGougingChecker(gp.GougingSettings, gp.ConsensusState, gp.TransactionFee, false) hostIP, hostKey, renterFunds := rfr.HostIP, rfr.HostKey, rfr.RenterFunds renterAddress, endHeight, hostCollateral := rfr.RenterAddress, rfr.EndHeight, rfr.HostCollateral renterKey := w.deriveRenterKey(hostKey) - var contract rhpv2.ContractRevision - var txnSet []types.Transaction - ctx = WithGougingChecker(ctx, w.bus, gp) - err = w.withTransportV2(ctx, rfr.HostKey, hostIP, func(t *rhpv2.Transport) (err error) { - hostSettings, err := RPCSettings(ctx, t) - if err != nil { - return err - } - // NOTE: we overwrite the NetAddress with the host address here since we - // just used it to dial the host we know it's valid - hostSettings.NetAddress = hostIP - - gc, err := GougingCheckerFromContext(ctx, false) + contract, txnSet, err := w.rhp2Client.FormContract(ctx, renterAddress, renterKey, hostKey, hostIP, renterFunds, hostCollateral, endHeight, gc, func(ctx context.Context, renterAddress types.Address, 
renterKey types.PublicKey, renterFunds, hostCollateral types.Currency, hostKey types.PublicKey, hostSettings rhpv2.HostSettings, endHeight uint64) (txns []types.Transaction, discard func(types.Transaction), err error) { + txns, err = w.bus.WalletPrepareForm(ctx, renterAddress, renterKey, renterFunds, hostCollateral, hostKey, hostSettings, endHeight) if err != nil { - return err - } - if breakdown := gc.Check(&hostSettings, nil); breakdown.Gouging() { - return fmt.Errorf("failed to form contract, gouging check failed: %v", breakdown) + return nil, nil, err } - - renterTxnSet, err := w.bus.WalletPrepareForm(ctx, renterAddress, renterKey.PublicKey(), renterFunds, hostCollateral, hostKey, hostSettings, endHeight) - if err != nil { - return err - } - defer w.discardTxnOnErr(renterTxnSet[len(renterTxnSet)-1], "rhpFormHandler", &err) - - contract, txnSet, err = RPCFormContract(ctx, t, renterKey, renterTxnSet) - if err != nil { - return err - } - return + return txns, func(txn types.Transaction) { + _ = w.bus.WalletDiscard(ctx, txn) + }, nil }) if jc.Check("couldn't form contract", err) != nil { return @@ -454,7 +429,7 @@ func (w *worker) rhpFormHandler(jc jape.Context) { // broadcast the transaction set err = w.bus.BroadcastTransaction(ctx, txnSet) - if err != nil && !isErrDuplicateTransactionSet(err) { + if err != nil { w.logger.Errorf("failed to broadcast formation txn set: %v", err) } @@ -465,7 +440,7 @@ func (w *worker) rhpFormHandler(jc jape.Context) { }) } -func (w *worker) rhpBroadcastHandler(jc jape.Context) { +func (w *Worker) rhpBroadcastHandler(jc jape.Context) { ctx := jc.Request.Context() // decode the fcid @@ -488,10 +463,11 @@ func (w *worker) rhpBroadcastHandler(jc jape.Context) { } rk := w.deriveRenterKey(c.HostKey) - rev, err := w.FetchSignedRevision(ctx, c.HostIP, c.HostKey, rk, fcid, time.Minute) + rev, err := w.rhp2Client.SignedRevision(ctx, c.HostIP, c.HostKey, rk, fcid, time.Minute) if jc.Check("could not fetch revision", err) != nil { return } + // Create txn with revision. 
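The FormContract callback above bundles preparation with cleanup: it returns the funded transaction set together with a discard hook the rhp2 client can call to release the wallet inputs if a later protocol step fails. The pattern in isolation (names illustrative, not the client's actual API):

    type prepareFn func(ctx context.Context) (txns []types.Transaction, discard func(types.Transaction), err error)

    // formWithCleanup prepares and funds a transaction set, then guarantees
    // the funding is released again if the remaining steps return an error.
    func formWithCleanup(ctx context.Context, prepare prepareFn) (err error) {
    	txns, discard, err := prepare(ctx)
    	if err != nil {
    		return err
    	}
    	defer func() {
    		if err != nil {
    			discard(txns[len(txns)-1]) // hand the wallet its inputs back
    		}
    	}()
    	// ... run the formation RPC using txns ...
    	return nil
    }
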
 	txn := types.Transaction{
 		FileContractRevisions: []types.FileContractRevision{rev.Revision},
@@ -521,7 +497,7 @@ func (w *worker) rhpBroadcastHandler(jc jape.Context) {
 	}
 }
 
-func (w *worker) rhpPruneContractHandlerPOST(jc jape.Context) {
+func (w *Worker) rhpPruneContractHandlerPOST(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// decode fcid
@@ -566,13 +542,24 @@ func (w *worker) rhpPruneContractHandlerPOST(jc jape.Context) {
 	if jc.Check("could not fetch gouging parameters", err) != nil {
 		return
 	}
-
-	// attach gouging checker
-	ctx = WithGougingChecker(ctx, w.bus, gp)
+	gc := newGougingChecker(gp.GougingSettings, gp.ConsensusState, gp.TransactionFee, false)
 
 	// prune the contract
-	pruned, remaining, err := w.PruneContract(ctx, contract.HostIP, contract.HostKey, fcid, contract.RevisionNumber)
-	if err != nil && !errors.Is(err, ErrNoSectorsToPrune) && pruned == 0 {
+	var pruned, remaining uint64
+	var rev *types.FileContractRevision
+	var cost types.Currency
+	err = w.withContractLock(ctx, contract.ID, lockingPriorityPruning, func() error {
+		stored, pending, err := w.bus.ContractRoots(ctx, contract.ID)
+		if err != nil {
+			return fmt.Errorf("failed to fetch contract roots; %w", err)
+		}
+		rev, pruned, remaining, cost, err = w.rhp2Client.PruneContract(ctx, w.deriveRenterKey(contract.HostKey), gc, contract.HostIP, contract.HostKey, fcid, contract.RevisionNumber, append(stored, pending...))
+		return err
+	})
+	if rev != nil {
+		w.contractSpendingRecorder.Record(*rev, api.ContractSpending{Deletions: cost})
+	}
+	if err != nil && !errors.Is(err, rhp2.ErrNoSectorsToPrune) && pruned == 0 {
 		err = fmt.Errorf("failed to prune contract %v; %w", fcid, err)
 		jc.Error(err, http.StatusInternalServerError)
 		return
@@ -588,7 +575,7 @@ func (w *worker) rhpPruneContractHandlerPOST(jc jape.Context) {
 	jc.Encode(res)
 }
 
-func (w *worker) rhpContractRootsHandlerGET(jc jape.Context) {
+func (w *Worker) rhpContractRootsHandlerGET(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// decode fcid
@@ -611,18 +598,19 @@ func (w *worker) rhpContractRootsHandlerGET(jc jape.Context) {
 	if jc.Check("couldn't fetch gouging parameters from bus", err) != nil {
 		return
 	}
-
-	// attach gouging checker to the context
-	ctx = WithGougingChecker(ctx, w.bus, gp)
+	gc := newGougingChecker(gp.GougingSettings, gp.ConsensusState, gp.TransactionFee, false)
 
 	// fetch the roots from the host
-	roots, err := w.FetchContractRoots(ctx, c.HostIP, c.HostKey, id, c.RevisionNumber)
-	if jc.Check("couldn't fetch contract roots from host", err) == nil {
-		jc.Encode(roots)
+	roots, rev, cost, err := w.rhp2Client.ContractRoots(ctx, w.deriveRenterKey(c.HostKey), gc, c.HostIP, c.HostKey, id, c.RevisionNumber)
+	if jc.Check("couldn't fetch contract roots from host", err) != nil {
+		return
+	} else if rev != nil {
+		w.contractSpendingRecorder.Record(*rev, api.ContractSpending{SectorRoots: cost})
 	}
+	jc.Encode(roots)
 }
 
-func (w *worker) rhpRenewHandler(jc jape.Context) {
+func (w *Worker) rhpRenewHandler(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// decode request
@@ -648,7 +636,7 @@ func (w *worker) rhpRenewHandler(jc jape.Context) {
 	var renewed rhpv2.ContractRevision
 	var txnSet []types.Transaction
 	var contractPrice, fundAmount types.Currency
-	if jc.Check("couldn't renew contract", w.withRevision(ctx, defaultRevisionFetchTimeout, rrr.ContractID, rrr.HostKey, rrr.SiamuxAddr, lockingPriorityRenew, func(_ types.FileContractRevision) (err error) {
+	if jc.Check("couldn't renew contract", w.withContractLock(ctx, rrr.ContractID, lockingPriorityRenew, func() (err error) {
 		h := w.Host(rrr.HostKey, rrr.ContractID, rrr.SiamuxAddr)
 		renewed, txnSet, contractPrice, fundAmount, err = h.RenewContract(ctx, rrr)
 		return err
@@ -658,7 +646,7 @@ func (w *worker) rhpRenewHandler(jc jape.Context) {
 
 	// broadcast the transaction set
 	err = w.bus.BroadcastTransaction(ctx, txnSet)
-	if err != nil && !isErrDuplicateTransactionSet(err) {
+	if err != nil {
 		w.logger.Errorf("failed to broadcast renewal txn set: %v", err)
 	}
 
@@ -672,7 +660,7 @@ func (w *worker) rhpRenewHandler(jc jape.Context) {
 	})
 }
 
-func (w *worker) rhpFundHandler(jc jape.Context) {
+func (w *Worker) rhpFundHandler(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// decode request
@@ -689,28 +677,12 @@ func (w *worker) rhpFundHandler(jc jape.Context) {
 	ctx = WithGougingChecker(ctx, w.bus, gp)
 
 	// fund the account
-	jc.Check("couldn't fund account", w.withRevision(ctx, defaultRevisionFetchTimeout, rfr.ContractID, rfr.HostKey, rfr.SiamuxAddr, lockingPriorityFunding, func(rev types.FileContractRevision) (err error) {
-		h := w.Host(rfr.HostKey, rev.ParentID, rfr.SiamuxAddr)
-		err = h.FundAccount(ctx, rfr.Balance, &rev)
-		if isBalanceMaxExceeded(err) {
-			// sync the account
-			err = h.SyncAccount(ctx, &rev)
-			if err != nil {
-				w.logger.Infof(fmt.Sprintf("failed to sync account: %v", err), "host", rfr.HostKey)
-				return
-			}
-
-			// try funding the account again
-			err = h.FundAccount(ctx, rfr.Balance, &rev)
-			if err != nil {
-				w.logger.Errorw(fmt.Sprintf("failed to fund account after syncing: %v", err), "host", rfr.HostKey, "balance", rfr.Balance)
-			}
-		}
-		return
+	jc.Check("couldn't fund account", w.withRevision(ctx, defaultRevisionFetchTimeout, rfr.ContractID, rfr.HostKey, rfr.SiamuxAddr, lockingPriorityFunding, func(rev types.FileContractRevision) error {
+		return w.Host(rfr.HostKey, rev.ParentID, rfr.SiamuxAddr).FundAccount(ctx, rfr.Balance, &rev)
 	}))
 }
 
-func (w *worker) rhpSyncHandler(jc jape.Context) {
+func (w *Worker) rhpSyncHandler(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// decode the request
@@ -733,7 +705,7 @@ func (w *worker) rhpSyncHandler(jc jape.Context) {
 	}))
 }
 
-func (w *worker) slabMigrateHandler(jc jape.Context) {
+func (w *Worker) slabMigrateHandler(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// decode the slab
@@ -784,10 +756,12 @@ func (w *worker) slabMigrateHandler(jc jape.Context) {
 		return
 	}
 
-	// fetch upload contracts
-	ulContracts, err := w.bus.Contracts(ctx, api.ContractsOpts{ContractSet: up.ContractSet})
-	if jc.Check("couldn't fetch contracts from bus", err) != nil {
-		return
+	// filter upload contracts
+	var ulContracts []api.ContractMetadata
+	for _, c := range dlContracts {
+		if c.InSet(up.ContractSet) {
+			ulContracts = append(ulContracts, c)
+		}
 	}
 
 	// migrate the slab
@@ -807,7 +781,7 @@ func (w *worker) slabMigrateHandler(jc jape.Context) {
 	})
 }
 
-func (w *worker) downloadsStatsHandlerGET(jc jape.Context) {
+func (w *Worker) downloadsStatsHandlerGET(jc jape.Context) {
 	stats := w.downloadManager.Stats()
 
 	// prepare downloaders stats
@@ -837,7 +811,7 @@ func (w *worker) downloadsStatsHandlerGET(jc jape.Context) {
 	})
 }
 
-func (w *worker) uploadsStatsHandlerGET(jc jape.Context) {
+func (w *Worker) uploadsStatsHandlerGET(jc jape.Context) {
 	stats := w.uploadManager.Stats()
 
 	// prepare upload stats
@@ -862,7 +836,7 @@ func (w *worker) uploadsStatsHandlerGET(jc jape.Context) {
 	})
 }
 
-func (w *worker) objectsHandlerHEAD(jc jape.Context) {
+func (w *Worker) objectsHandlerHEAD(jc jape.Context) {
 	// parse bucket
 	bucket := api.DefaultBucketName
 	if jc.DecodeForm("bucket", &bucket) != nil {
@@ -921,7 +895,7 @@ func (w *worker) objectsHandlerHEAD(jc jape.Context) {
 	serveContent(jc.ResponseWriter, jc.Request, path, bytes.NewReader(nil), *hor)
 }
 
-func (w *worker) objectsHandlerGET(jc jape.Context) {
+func (w *Worker) objectsHandlerGET(jc jape.Context) {
 	jc.Custom(nil, []api.ObjectMetadata{})
 
 	ctx := jc.Request.Context()
@@ -1014,7 +988,7 @@ func (w *worker) objectsHandlerGET(jc jape.Context) {
 	serveContent(jc.ResponseWriter, jc.Request, path, gor.Content, gor.HeadObjectResponse)
 }
 
-func (w *worker) objectsHandlerPUT(jc jape.Context) {
+func (w *Worker) objectsHandlerPUT(jc jape.Context) {
 	jc.Custom((*[]byte)(nil), nil)
 
 	ctx := jc.Request.Context()
@@ -1085,7 +1059,7 @@ func (w *worker) objectsHandlerPUT(jc jape.Context) {
 	jc.ResponseWriter.Header().Set("ETag", api.FormatETag(resp.ETag))
 }
 
-func (w *worker) multipartUploadHandlerPUT(jc jape.Context) {
+func (w *Worker) multipartUploadHandlerPUT(jc jape.Context) {
 	jc.Custom((*[]byte)(nil), nil)
 
 	ctx := jc.Request.Context()
@@ -1173,7 +1147,7 @@ func (w *worker) multipartUploadHandlerPUT(jc jape.Context) {
 	jc.ResponseWriter.Header().Set("ETag", api.FormatETag(resp.ETag))
 }
 
-func (w *worker) objectsHandlerDELETE(jc jape.Context) {
+func (w *Worker) objectsHandlerDELETE(jc jape.Context) {
 	var batch bool
 	if jc.DecodeForm("batch", &batch) != nil {
 		return
@@ -1190,7 +1164,7 @@ func (w *worker) objectsHandlerDELETE(jc jape.Context) {
 	jc.Check("couldn't delete object", err)
 }
 
-func (w *worker) rhpContractsHandlerGET(jc jape.Context) {
+func (w *Worker) rhpContractsHandlerGET(jc jape.Context) {
 	ctx := jc.Request.Context()
 
 	// fetch contracts
@@ -1226,37 +1200,18 @@ func (w *worker) rhpContractsHandlerGET(jc jape.Context) {
 	jc.Encode(resp)
 }
 
-func (w *worker) idHandlerGET(jc jape.Context) {
+func (w *Worker) idHandlerGET(jc jape.Context) {
 	jc.Encode(w.id)
 }
 
-func (w *worker) eventsHandler(jc jape.Context) {
-	var event webhooks.Event
-	if jc.Decode(&event) != nil {
-		return
-	} else if event.Event == webhooks.WebhookEventPing {
-		jc.ResponseWriter.WriteHeader(http.StatusOK)
-		return
-	}
-
-	err := w.cache.HandleEvent(event)
-	if errors.Is(err, api.ErrUnknownEvent) {
-		jc.ResponseWriter.WriteHeader(http.StatusAccepted)
-		return
-	} else if err != nil {
-		jc.Error(err, http.StatusBadRequest)
-		return
-	}
-}
-
-func (w *worker) memoryGET(jc jape.Context) {
+func (w *Worker) memoryGET(jc jape.Context) {
 	jc.Encode(api.MemoryResponse{
 		Download: w.downloadManager.mm.Status(),
 		Upload:   w.uploadManager.mm.Status(),
 	})
 }
 
-func (w *worker) accountHandlerGET(jc jape.Context) {
+func (w *Worker) accountHandlerGET(jc jape.Context) {
 	var hostKey types.PublicKey
 	if jc.DecodeParam("hostkey", &hostKey) != nil {
 		return
@@ -1265,12 +1220,22 @@ func (w *worker) accountHandlerGET(jc jape.Context) {
 	jc.Encode(account)
 }
 
-func (w *worker) stateHandlerGET(jc jape.Context) {
+func (w *Worker) eventsHandlerPOST(jc jape.Context) {
+	var event webhooks.Event
+	if jc.Decode(&event) != nil {
+		return
+	} else if event.Event == webhooks.WebhookEventPing {
+		jc.ResponseWriter.WriteHeader(http.StatusOK)
+	} else {
+		w.eventSubscriber.ProcessEvent(event)
+	}
+}
+
+func (w *Worker) stateHandlerGET(jc jape.Context) {
 	jc.Encode(api.WorkerStateResponse{
 		ID:        w.id,
 		StartTime: api.TimeRFC3339(w.startTime),
 		BuildState: api.BuildState{
-			Network: build.NetworkName(),
 			Version: build.Version(),
 			Commit:  build.Commit(),
 			OS:      runtime.GOOS,
@@ -1280,38 +1245,45 @@ func (w *worker) stateHandlerGET(jc jape.Context) {
 }
 
 // New returns an HTTP handler that serves the worker API.
-func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlushInterval, downloadOverdriveTimeout, uploadOverdriveTimeout time.Duration, downloadMaxOverdrive, uploadMaxOverdrive, downloadMaxMemory, uploadMaxMemory uint64, allowPrivateIPs bool, l *zap.Logger) (*worker, error) {
-	if contractLockingDuration == 0 {
+func New(cfg config.Worker, masterKey [32]byte, b Bus, l *zap.Logger) (*Worker, error) {
+	l = l.Named("worker").Named(cfg.ID)
+
+	if cfg.ContractLockTimeout == 0 {
 		return nil, errors.New("contract lock duration must be positive")
 	}
-	if busFlushInterval == 0 {
+	if cfg.BusFlushInterval == 0 {
 		return nil, errors.New("bus flush interval must be positive")
 	}
-	if downloadOverdriveTimeout == 0 {
+	if cfg.DownloadOverdriveTimeout == 0 {
 		return nil, errors.New("download overdrive timeout must be positive")
 	}
-	if uploadOverdriveTimeout == 0 {
+	if cfg.UploadOverdriveTimeout == 0 {
 		return nil, errors.New("upload overdrive timeout must be positive")
 	}
-	if downloadMaxMemory == 0 {
+	if cfg.DownloadMaxMemory == 0 {
 		return nil, errors.New("downloadMaxMemory cannot be 0")
 	}
-	if uploadMaxMemory == 0 {
+	if cfg.UploadMaxMemory == 0 {
 		return nil, errors.New("uploadMaxMemory cannot be 0")
 	}
-	l = l.Named("worker").Named(id)
-	cache := iworker.NewCache(b, l)
+	a := alerts.WithOrigin(b, fmt.Sprintf("worker.%s", cfg.ID))
 	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
-	w := &worker{
-		alerts:                  alerts.WithOrigin(b, fmt.Sprintf("worker.%s", id)),
-		allowPrivateIPs:         allowPrivateIPs,
-		contractLockingDuration: contractLockingDuration,
-		cache:                   cache,
-		id:                      id,
+
+	dialer := iworker.NewFallbackDialer(b, net.Dialer{}, l)
+	w := &Worker{
+		alerts:                  a,
+		allowPrivateIPs:         cfg.AllowPrivateIPs,
+		contractLockingDuration: cfg.ContractLockTimeout,
+		cache:                   iworker.NewCache(b, l),
+		dialer:                  dialer,
+		eventSubscriber:         iworker.NewEventSubscriber(a, b, l, 10*time.Second),
+		id:                      cfg.ID,
 		bus:                     b,
 		masterKey:               masterKey,
 		logger:                  l.Sugar(),
+		rhp2Client:              rhp2.New(dialer, l),
+		rhp3Client:              rhp3.New(dialer, l),
 		startTime:               time.Now(),
 		uploadingPackedSlabs:    make(map[string]struct{}),
 		shutdownCtx:             shutdownCtx,
@@ -1320,22 +1292,21 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush
 
 	w.initAccounts(b)
 	w.initPriceTables()
-	w.initTransportPool()
 
-	w.initDownloadManager(downloadMaxMemory, downloadMaxOverdrive, downloadOverdriveTimeout, l.Named("downloadmanager").Sugar())
-	w.initUploadManager(uploadMaxMemory, uploadMaxOverdrive, uploadOverdriveTimeout, l.Named("uploadmanager").Sugar())
+	w.initDownloadManager(cfg.DownloadMaxMemory, cfg.DownloadMaxOverdrive, cfg.DownloadOverdriveTimeout, l)
+	w.initUploadManager(cfg.UploadMaxMemory, cfg.UploadMaxOverdrive, cfg.UploadOverdriveTimeout, l)
 
-	w.initContractSpendingRecorder(busFlushInterval)
+	w.initContractSpendingRecorder(cfg.BusFlushInterval)
 	return w, nil
 }
 
 // Handler returns an HTTP handler that serves the worker API.
-func (w *worker) Handler() http.Handler {
+func (w *Worker) Handler() http.Handler {
 	return jape.Mux(map[string]jape.Handler{
 		"GET    /account/:hostkey": w.accountHandlerGET,
 		"GET    /id":               w.idHandlerGET,
-		"POST   /events":           w.eventsHandler,
+		"POST   /events":           w.eventsHandlerPOST,
 		"GET    /memory":           w.memoryGET,
@@ -1365,14 +1336,21 @@ func (w *worker) Handler() http.Handler {
 	})
 }
 
-// Setup initializes the worker cache.
-func (w *worker) Setup(ctx context.Context, apiURL, apiPassword string) error {
-	webhookOpts := []webhooks.HeaderOption{webhooks.WithBasicAuth("", apiPassword)}
-	return w.cache.Initialize(ctx, apiURL, webhookOpts...)
+// Setup registers event webhooks that enable the worker cache.
+func (w *Worker) Setup(ctx context.Context, apiURL, apiPassword string) error {
+	go func() {
+		eventsURL := fmt.Sprintf("%s/events", apiURL)
+		webhookOpts := []webhooks.HeaderOption{webhooks.WithBasicAuth("", apiPassword)}
+		if err := w.eventSubscriber.Register(w.shutdownCtx, eventsURL, webhookOpts...); err != nil {
+			w.logger.Errorw("failed to register webhooks", zap.Error(err))
+		}
+	}()
+
+	return w.cache.Subscribe(w.eventSubscriber)
 }
 
 // Shutdown shuts down the worker.
-func (w *worker) Shutdown(ctx context.Context) error {
+func (w *Worker) Shutdown(ctx context.Context) error {
 	// cancel shutdown context
 	w.shutdownCtxCancel()
 
@@ -1382,10 +1360,12 @@ func (w *worker) Shutdown(ctx context.Context) error {
 
 	// stop recorders
 	w.contractSpendingRecorder.Stop(ctx)
-	return nil
+
+	// shutdown the subscriber
+	return w.eventSubscriber.Shutdown(ctx)
 }
 
-func (w *worker) scanHost(ctx context.Context, timeout time.Duration, hostKey types.PublicKey, hostIP string) (rhpv2.HostSettings, rhpv3.HostPriceTable, time.Duration, error) {
+func (w *Worker) scanHost(ctx context.Context, timeout time.Duration, hostKey types.PublicKey, hostIP string) (rhpv2.HostSettings, rhpv3.HostPriceTable, time.Duration, error) {
 	logger := w.logger.With("host", hostKey).With("hostIP", hostIP).With("timeout", timeout)
 
 	// prepare a helper to create a context for scanning
@@ -1400,50 +1380,28 @@ func (w *worker) scanHost(ctx context.Context, timeout time.Duration, hostKey ty
 	scan := func() (rhpv2.HostSettings, rhpv3.HostPriceTable, time.Duration, error) {
 		// fetch the host settings
 		start := time.Now()
-		var settings rhpv2.HostSettings
-		{
-			scanCtx, cancel := timeoutCtx()
-			defer cancel()
-			err := w.withTransportV2(scanCtx, hostKey, hostIP, func(t *rhpv2.Transport) error {
-				var err error
-				if settings, err = RPCSettings(scanCtx, t); err != nil {
-					return fmt.Errorf("failed to fetch host settings: %w", err)
-				}
-				// NOTE: we overwrite the NetAddress with the host address here
-				// since we just used it to dial the host we know it's valid
-				settings.NetAddress = hostIP
-				return nil
-			})
-			if err != nil {
-				return settings, rhpv3.HostPriceTable{}, time.Since(start), err
-			}
+		scanCtx, cancel := timeoutCtx()
+		settings, err := w.rhp2Client.Settings(scanCtx, hostKey, hostIP)
+		cancel()
+		if err != nil {
+			return settings, rhpv3.HostPriceTable{}, time.Since(start), err
 		}
 
 		// fetch the host pricetable
-		var pt rhpv3.HostPriceTable
-		{
-			scanCtx, cancel := timeoutCtx()
-			defer cancel()
-			err := w.transportPoolV3.withTransportV3(scanCtx, hostKey, settings.SiamuxAddr(), func(ctx context.Context, t *transportV3) error {
-				if hpt, err := RPCPriceTable(ctx, t, func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, error) { return nil, nil }); err != nil {
-					return fmt.Errorf("failed to fetch host price table: %w", err)
-				} else {
-					pt = hpt.HostPriceTable
-					return nil
-				}
-			})
-			if err != nil {
-				return settings, rhpv3.HostPriceTable{}, time.Since(start), err
-			}
+		scanCtx, cancel = timeoutCtx()
+		pt, err := w.rhp3Client.PriceTableUnpaid(ctx, hostKey, settings.SiamuxAddr())
+		cancel()
+		if err != nil {
+			return settings, rhpv3.HostPriceTable{}, time.Since(start), err
 		}
-		return settings, pt, time.Since(start), nil
+		return settings, pt.HostPriceTable, time.Since(start), nil
 	}
 
 	// resolve host ip, don't scan if the host is on a private network or if it
 	// resolves to more than two addresses of the same type, if it fails for
 	// another reason the host scan won't have subnets
-	subnets, private, err := utils.ResolveHostIP(ctx, hostIP)
-	if errors.Is(err, api.ErrHostTooManyAddresses) {
+	resolvedAddresses, private, err := utils.ResolveHostIP(ctx, hostIP)
+	if errors.Is(err, utils.ErrHostTooManyAddresses) {
 		return rhpv2.HostSettings{}, rhpv3.HostPriceTable{}, 0, err
 	} else if private && !w.allowPrivateIPs {
 		return rhpv2.HostSettings{}, rhpv3.HostPriceTable{}, 0, api.ErrHostOnPrivateNetwork
@@ -1484,32 +1442,26 @@ func (w *worker) scanHost(ctx context.Context, timeout time.Duration, hostKey ty
 	// Otherwise scans that time out won't be recorded.
 	scanErr := w.bus.RecordHostScans(ctx, []api.HostScan{
 		{
-			HostKey:    hostKey,
-			PriceTable: pt,
-			Subnets:    subnets,
-			Success:    isSuccessfulInteraction(err),
-			Settings:   settings,
-			Timestamp:  time.Now(),
+			HostKey:           hostKey,
+			PriceTable:        pt,
+			ResolvedAddresses: resolvedAddresses,
+
+			// NOTE: A scan is considered successful if both fetching the price
+			// table and the settings succeeded. Right now scanning can't fail
+			// due to a reason that is our fault unless we are offline. If that
+			// changes, we should adjust this code to account for that.
+			Success:   err == nil,
+			Settings:  settings,
+			Timestamp: time.Now(),
 		},
 	})
 	if scanErr != nil {
 		logger.Errorw("failed to record host scan", zap.Error(scanErr))
 	}
+	logger.With(zap.Error(err)).Debugw("scanned host", "success", err == nil)
 	return settings, pt, duration, err
 }
 
-func discardTxnOnErr(ctx context.Context, bus Bus, l *zap.SugaredLogger, txn types.Transaction, errContext string, err *error) {
-	if *err == nil {
-		return
-	}
-
-	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
-	if dErr := bus.WalletDiscard(ctx, txn); dErr != nil {
-		l.Errorf("%v: %s, failed to discard txn: %v", *err, errContext, dErr)
-	}
-	cancel()
-}
-
 func isErrHostUnreachable(err error) bool {
 	return utils.IsErr(err, os.ErrDeadlineExceeded) ||
 		utils.IsErr(err, context.DeadlineExceeded) ||
@@ -1521,11 +1473,7 @@ func isErrHostUnreachable(err error) bool {
 		utils.IsErr(err, errors.New("cannot assign requested address"))
 }
 
-func isErrDuplicateTransactionSet(err error) bool {
-	return utils.IsErr(err, modules.ErrDuplicateTransactionSet)
-}
-
-func (w *worker) headObject(ctx context.Context, bucket, path string, onlyMetadata bool, opts api.HeadObjectOptions) (*api.HeadObjectResponse, api.ObjectsResponse, error) {
+func (w *Worker) headObject(ctx context.Context, bucket, path string, onlyMetadata bool, opts api.HeadObjectOptions) (*api.HeadObjectResponse, api.ObjectsResponse, error) {
 	// fetch object
 	res, err := w.bus.Object(ctx, bucket, path, api.GetObjectOptions{
 		IgnoreDelim:  opts.IgnoreDelim,
@@ -1560,7 +1508,7 @@ func (w *worker) headObject(ctx context.Context, bucket, path string, onlyMetada
 	}, res, nil
 }
 
-func (w *worker) GetObject(ctx context.Context, bucket, path string, opts api.DownloadObjectOptions) (*api.GetObjectResponse, error) {
+func (w *Worker) GetObject(ctx context.Context, bucket, path string, opts api.DownloadObjectOptions) (*api.GetObjectResponse, error) {
 	// head object
 	hor, res, err := w.headObject(ctx, bucket, path, false, api.HeadObjectOptions{
 		IgnoreDelim: opts.IgnoreDelim,
@@ -1626,12 +1574,12 @@ func (w *worker) GetObject(ctx context.Context, bucket, path string, opts api.Do
 	}, nil
 }
 
-func (w *worker) HeadObject(ctx context.Context, bucket, path string, opts api.HeadObjectOptions) (*api.HeadObjectResponse, error) {
+func (w *Worker) HeadObject(ctx context.Context, bucket, path string, opts api.HeadObjectOptions) (*api.HeadObjectResponse, error) {
 	res, _, err := w.headObject(ctx, bucket, path, true, opts)
 	return res, err
 }
 
-func (w *worker) UploadObject(ctx context.Context, r io.Reader, bucket, path string, opts api.UploadObjectOptions) (*api.UploadObjectResponse, error) {
+func (w *Worker) UploadObject(ctx context.Context, r io.Reader, bucket, path string, opts api.UploadObjectOptions) (*api.UploadObjectResponse, error) {
 	// prepare upload params
 	up, err := w.prepareUploadParams(ctx, bucket, opts.ContractSet, opts.MinShards, opts.TotalShards)
 	if err != nil {
@@ -1648,12 +1596,11 @@ func (w *worker) UploadObject(ctx context.Context, r io.Reader, bucket, path str
 	}
 
 	// upload
-	eTag, err := w.upload(ctx, bucket, path, r, contracts,
+	eTag, err := w.upload(ctx, bucket, path, up.RedundancySettings, r, contracts,
 		WithBlockHeight(up.CurrentHeight),
 		WithContractSet(up.ContractSet),
 		WithMimeType(opts.MimeType),
 		WithPacking(up.UploadPacking),
-		WithRedundancySettings(up.RedundancySettings),
 		WithObjectUserMetadata(opts.Metadata),
 	)
 	if err != nil {
@@ -1668,7 +1615,7 @@ func (w *worker) UploadObject(ctx context.Context, r io.Reader, bucket, path str
 	}, nil
 }
 
-func (w *worker) UploadMultipartUploadPart(ctx context.Context, r io.Reader, bucket, path, uploadID string, partNumber int, opts api.UploadMultipartUploadPartOptions) (*api.UploadMultipartUploadPartResponse, error) {
+func (w *Worker) UploadMultipartUploadPart(ctx context.Context, r io.Reader, bucket, path, uploadID string, partNumber int, opts api.UploadMultipartUploadPartOptions) (*api.UploadMultipartUploadPartResponse, error) {
 	// prepare upload params
 	up, err := w.prepareUploadParams(ctx, bucket, opts.ContractSet, opts.MinShards, opts.TotalShards)
 	if err != nil {
@@ -1689,7 +1636,6 @@ func (w *worker) UploadMultipartUploadPart(ctx context.Context, r io.Reader, buc
 		WithBlockHeight(up.CurrentHeight),
 		WithContractSet(up.ContractSet),
 		WithPacking(up.UploadPacking),
-		WithRedundancySettings(up.RedundancySettings),
 		WithCustomKey(upload.Key),
 		WithPartNumber(partNumber),
 		WithUploadID(uploadID),
@@ -1711,7 +1657,7 @@ func (w *worker) UploadMultipartUploadPart(ctx context.Context, r io.Reader, buc
 	}
 
 	// upload
-	eTag, err := w.upload(ctx, bucket, path, r, contracts, uploadOpts...)
+	eTag, err := w.upload(ctx, bucket, path, up.RedundancySettings, r, contracts, uploadOpts...)
 	if err != nil {
 		w.logger.With(zap.Error(err)).With("path", path).With("bucket", bucket).Error("failed to upload object")
 		if !errors.Is(err, ErrShuttingDown) && !errors.Is(err, errUploadInterrupted) && !errors.Is(err, context.Canceled) {
@@ -1724,7 +1670,7 @@ func (w *worker) UploadMultipartUploadPart(ctx context.Context, r io.Reader, buc
 	}, nil
 }
 
-func (w *worker) prepareUploadParams(ctx context.Context, bucket string, contractSet string, minShards, totalShards int) (api.UploadParams, error) {
+func (w *Worker) prepareUploadParams(ctx context.Context, bucket string, contractSet string, minShards, totalShards int) (api.UploadParams, error) {
 	// return early if the bucket does not exist
 	_, err := w.bus.Bucket(ctx, bucket)
 	if err != nil {
@@ -1759,3 +1705,32 @@ func (w *worker) prepareUploadParams(ctx context.Context, bucket string, contrac
 	}
 	return up, nil
 }
+
+// A HostErrorSet is a collection of errors from various hosts.
+type HostErrorSet map[types.PublicKey]error
+
+// NumGouging returns the number of hosts that errored out due to price gouging.
+func (hes HostErrorSet) NumGouging() (n int) {
+	for _, he := range hes {
+		if errors.Is(he, gouging.ErrPriceTableGouging) {
+			n++
+		}
+	}
+	return
+}
+
+// Error implements error.
+func (hes HostErrorSet) Error() string {
+	if len(hes) == 0 {
+		return ""
+	}
+
+	var strs []string
+	for hk, he := range hes {
+		strs = append(strs, fmt.Sprintf("%x: %v", hk[:4], he.Error()))
+	}
+
+	// include a leading newline so that the first error isn't printed on the
+	// same line as the error context
+	return "\n" + strings.Join(strs, "\n")
+}
diff --git a/worker/worker_test.go b/worker/worker_test.go
index 706fae14e..f0822f03f 100644
--- a/worker/worker_test.go
+++ b/worker/worker_test.go
@@ -8,6 +8,7 @@ import (
 	rhpv2 "go.sia.tech/core/rhp/v2"
 	"go.sia.tech/core/types"
 	"go.sia.tech/renterd/api"
+	"go.sia.tech/renterd/config"
 	"go.sia.tech/renterd/internal/test"
 	"go.uber.org/zap"
 	"golang.org/x/crypto/blake2b"
@@ -17,7 +18,7 @@ import (
 type (
 	testWorker struct {
 		tt test.TT
-		*worker
+		*Worker
 
 		cs *contractStoreMock
 		os *objectStoreMock
@@ -42,7 +43,7 @@ func newTestWorker(t test.TestingCommon) *testWorker {
 	ulmm := newMemoryManagerMock()
 
 	// create worker
-	w, err := New(blake2b.Sum256([]byte("testwork")), "test", b, time.Second, time.Second, time.Second, time.Second, 0, 0, 1, 1, false, zap.NewNop())
+	w, err := New(newTestWorkerCfg(), blake2b.Sum256([]byte("testwork")), b, zap.NewNop())
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -129,6 +130,18 @@ func (w *testWorker) RenewContract(hk types.PublicKey) *contractMock {
 	return renewal
 }
 
+func newTestWorkerCfg() config.Worker {
+	return config.Worker{
+		ID:                       "test",
+		ContractLockTimeout:      time.Second,
+		BusFlushInterval:         time.Second,
+		DownloadOverdriveTimeout: time.Second,
+		UploadOverdriveTimeout:   time.Second,
+		DownloadMaxMemory:        1 << 12, // 4 KiB
+		UploadMaxMemory:          1 << 12, // 4 KiB
+	}
+}
+
 func newTestSector() (*[rhpv2.SectorSize]byte, types.Hash256) {
 	var sector [rhpv2.SectorSize]byte
 	frand.Read(sector[:])
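Usage note: after this change a worker is constructed from a config.Worker value instead of the old long positional argument list. A minimal sketch of the new call site follows; the config values and the seed are illustrative only (not defaults), and `b` stands for whatever worker.Bus implementation the caller already holds.

package example

import (
	"time"

	"go.sia.tech/renterd/config"
	"go.sia.tech/renterd/worker"
	"go.uber.org/zap"
	"golang.org/x/crypto/blake2b"
)

func newWorker(b worker.Bus) (*worker.Worker, error) {
	// per the validation in New, every timeout/interval and both memory
	// limits must be non-zero or construction fails
	cfg := config.Worker{
		ID:                       "worker",
		ContractLockTimeout:      30 * time.Second,
		BusFlushInterval:         5 * time.Second,
		DownloadOverdriveTimeout: 3 * time.Second,
		UploadOverdriveTimeout:   3 * time.Second,
		DownloadMaxMemory:        1 << 30, // 1 GiB
		UploadMaxMemory:          1 << 30, // 1 GiB
	}
	// the master key is derived from a seed here purely for illustration
	masterKey := blake2b.Sum256([]byte("example seed"))
	return worker.New(cfg, masterKey, b, zap.NewNop())
}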
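A small usage sketch for the HostErrorSet type added at the end of worker.go; the host keys and error values are made up for the example.

package example

import (
	"errors"
	"fmt"

	"go.sia.tech/core/types"
	"go.sia.tech/renterd/worker"
)

func describeFailures() error {
	// two hypothetical hosts that failed during an upload
	var hk1, hk2 types.PublicKey
	hk1[0], hk2[0] = 1, 2

	hes := worker.HostErrorSet{
		hk1: errors.New("connection refused"),
		hk2: errors.New("price table expired"),
	}

	// HostErrorSet implements error, so it can be wrapped like any other
	// error; its Error method prints one "abbreviated key: error" pair per
	// line, starting on a fresh line after the error context
	return fmt.Errorf("upload failed: %w", hes)
}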