From 7757ddcfc23f9083ee50b75ce55fbe03f3a3cc5e Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 22 Oct 2024 11:57:20 -0700 Subject: [PATCH] all: support rhp4 --- api/api.go | 27 +--- api/options.go | 14 -- api/prometheus.go | 21 --- api/rhpsessions.go | 45 ------- api/types.go | 4 - api/volumes.go | 4 +- cmd/hostd/main.go | 5 + cmd/hostd/run.go | 161 ++++++++++++++++++----- config/config.go | 33 ++++- go.mod | 3 +- go.sum | 2 - host/accounts/accounts.go | 3 + host/contracts/accounts.go | 36 +++++ host/contracts/contracts.go | 12 +- host/contracts/contracts_test.go | 17 +-- host/contracts/integrity.go | 2 +- host/contracts/lock.go | 141 ++++++++++++++++++++ host/contracts/manager.go | 133 +++++++------------ host/contracts/manager_test.go | 81 ++++++------ host/contracts/persist.go | 17 ++- host/contracts/update.go | 17 +-- host/settings/announce.go | 32 +++-- host/settings/announce_test.go | 42 +++--- host/settings/options.go | 9 ++ host/settings/pin/pin_test.go | 17 ++- host/settings/settings.go | 88 +++++++++---- host/settings/settings_test.go | 8 +- host/settings/update.go | 10 +- host/storage/persist.go | 14 +- host/storage/storage.go | 85 +++++++++++- host/storage/storage_test.go | 23 ++-- internal/testutil/testutil.go | 2 +- persist/sqlite/accounts.go | 121 +++++++++++++++++ persist/sqlite/contracts.go | 219 +++++++++++++++++-------------- persist/sqlite/contracts_test.go | 6 +- persist/sqlite/init.sql | 9 ++ persist/sqlite/sectors.go | 16 +++ persist/sqlite/sectors_test.go | 54 ++++++++ persist/sqlite/volumes.go | 106 +++++++++++++++ persist/sqlite/volumes_test.go | 158 ++++------------------ rhp/conn.go | 82 ------------ rhp/listener.go | 116 ++++++++++++++++ rhp/reporter.go | 204 ---------------------------- rhp/siamux.go | 47 +++++++ rhp/v2/options.go | 41 ------ rhp/v2/rhp.go | 50 ++----- rhp/v2/rpc.go | 10 +- rhp/v2/rpc_test.go | 8 +- rhp/v2/session.go | 11 +- rhp/v3/execute.go | 22 ++-- rhp/v3/options.go | 41 ------ rhp/v3/rhp.go | 56 +++----- rhp/v3/rpc_test.go | 4 +- rhp/v3/websockets.go | 63 --------- 54 files changed, 1382 insertions(+), 1170 deletions(-) delete mode 100644 api/rhpsessions.go create mode 100644 host/contracts/accounts.go create mode 100644 host/contracts/lock.go create mode 100644 persist/sqlite/sectors_test.go delete mode 100644 rhp/conn.go create mode 100644 rhp/listener.go delete mode 100644 rhp/reporter.go create mode 100644 rhp/siamux.go delete mode 100644 rhp/v2/options.go delete mode 100644 rhp/v3/options.go delete mode 100644 rhp/v3/websockets.go diff --git a/api/api.go b/api/api.go index bd5125c4..9f8f98b4 100644 --- a/api/api.go +++ b/api/api.go @@ -20,7 +20,6 @@ import ( "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/settings/pin" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/rhp" "go.sia.tech/hostd/webhooks" "go.sia.tech/jape" "go.uber.org/zap" @@ -86,7 +85,7 @@ type ( SetReadOnly(id int64, readOnly bool) error RemoveSector(root types.Hash256) error ResizeCache(size uint32) - Read(types.Hash256) (*[rhp2.SectorSize]byte, error) + ReadSector(types.Hash256) (*[rhp2.SectorSize]byte, error) // SectorReferences returns the references to a sector SectorReferences(root types.Hash256) (storage.SectorReference, error) @@ -152,14 +151,6 @@ type ( BroadcastToWebhook(id int64, event, scope string, data interface{}) error } - // A RHPSessionReporter reports on RHP session lifecycle events - RHPSessionReporter interface { - Subscribe(rhp.SessionSubscriber) - Unsubscribe(rhp.SessionSubscriber) - - Active() []rhp.Session - } - // An api provides an HTTP API for the host api struct { hostKey types.PublicKey @@ -168,7 +159,6 @@ type ( log *zap.Logger alerts Alerts webhooks Webhooks - sessions RHPSessionReporter sqlite3Store SQLite3Store @@ -201,22 +191,12 @@ func (a *api) requiresExplorer(h jape.Handler) jape.Handler { } } -// NewServer initializes the API -// syncer -// chain -// accounts -// contracts -// volumes -// wallet -// metrics -// settings -// index +// NewServer initializes the API server with the given options func NewServer(name string, hostKey types.PublicKey, cm ChainManager, s Syncer, am AccountManager, c ContractManager, vm VolumeManager, wm Wallet, mm MetricManager, sm Settings, im Index, opts ...ServerOption) http.Handler { a := &api{ hostKey: hostKey, name: name, - sessions: noopSessionReporter{}, alerts: noopAlerts{}, webhooks: noopWebhooks{}, log: zap.NewNop(), @@ -291,9 +271,6 @@ func NewServer(name string, hostKey types.PublicKey, cm ChainManager, s Syncer, "DELETE /volumes/:id": a.handleDeleteVolume, "DELETE /volumes/:id/cancel": a.handleDELETEVolumeCancelOp, "PUT /volumes/:id/resize": a.handlePUTVolumeResize, - // session endpoints - "GET /sessions": a.handleGETSessions, - "GET /sessions/subscribe": a.handleGETSessionsSubscribe, // tpool endpoints "GET /tpool/fee": a.handleGETTPoolFee, // wallet endpoints diff --git a/api/options.go b/api/options.go index c1b4b8cd..160849d6 100644 --- a/api/options.go +++ b/api/options.go @@ -4,7 +4,6 @@ import ( "go.sia.tech/core/types" "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/explorer" - "go.sia.tech/hostd/rhp" "go.sia.tech/hostd/webhooks" "go.uber.org/zap" ) @@ -51,13 +50,6 @@ func WithExplorer(explorer *explorer.Explorer) ServerOption { } } -// WithRHPSessionReporter sets the RHP session reporter for the API server. -func WithRHPSessionReporter(rsr RHPSessionReporter) ServerOption { - return func(a *api) { - a.sessions = rsr - } -} - // WithLogger sets the logger for the API server. func WithLogger(log *zap.Logger) ServerOption { return func(a *api) { @@ -81,9 +73,3 @@ type noopAlerts struct{} func (noopAlerts) Active() []alerts.Alert { return nil } func (noopAlerts) Dismiss(...types.Hash256) {} - -type noopSessionReporter struct{} - -func (noopSessionReporter) Subscribe(rhp.SessionSubscriber) {} -func (noopSessionReporter) Unsubscribe(rhp.SessionSubscriber) {} -func (noopSessionReporter) Active() []rhp.Session { return nil } diff --git a/api/prometheus.go b/api/prometheus.go index 065ee119..afd88568 100644 --- a/api/prometheus.go +++ b/api/prometheus.go @@ -462,24 +462,3 @@ func (w WalletPendingResp) PrometheusMetric() (metrics []prometheus.Metric) { } return } - -// PrometheusMetric returns Prometheus samples for the hosts sessions -func (s SessionResp) PrometheusMetric() (metrics []prometheus.Metric) { - for _, session := range s { - metrics = append(metrics, prometheus.Metric{ - Name: "hostd_session_ingress", - Labels: map[string]any{ - "peer": session.PeerAddress, - }, - Value: float64(session.Ingress), - }) - metrics = append(metrics, prometheus.Metric{ - Name: "hostd_session_egress", - Labels: map[string]any{ - "peer": session.PeerAddress, - }, - Value: float64(session.Egress), - }) - } - return -} diff --git a/api/rhpsessions.go b/api/rhpsessions.go deleted file mode 100644 index b1a47d75..00000000 --- a/api/rhpsessions.go +++ /dev/null @@ -1,45 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - - "go.sia.tech/hostd/rhp" - "go.sia.tech/jape" - "go.uber.org/zap" - "nhooyr.io/websocket" -) - -type rhpSessionSubscriber struct { - conn *websocket.Conn -} - -func (rs *rhpSessionSubscriber) ReceiveSessionEvent(event rhp.SessionEvent) { - buf, err := json.Marshal(event) - if err != nil { - return - } - rs.conn.Write(context.Background(), websocket.MessageText, buf) -} - -func (a *api) handleGETSessions(c jape.Context) { - a.writeResponse(c, SessionResp(a.sessions.Active())) -} - -func (a *api) handleGETSessionsSubscribe(c jape.Context) { - wsc, err := websocket.Accept(c.ResponseWriter, c.Request, &websocket.AcceptOptions{ - OriginPatterns: []string{"*"}, - }) - if err != nil { - a.log.Warn("failed to accept websocket connection", zap.Error(err)) - return - } - defer wsc.Close(websocket.StatusNormalClosure, "") - - // subscribe the websocket conn - sub := &rhpSessionSubscriber{ - conn: wsc, - } - a.sessions.Subscribe(sub) - a.sessions.Unsubscribe(sub) -} diff --git a/api/types.go b/api/types.go index 1da9f07a..11c8f9d6 100644 --- a/api/types.go +++ b/api/types.go @@ -14,7 +14,6 @@ import ( "go.sia.tech/hostd/host/metrics" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/rhp" ) // JSON keys for host setting fields @@ -193,9 +192,6 @@ type ( // WalletPendingResp is the response body for the [GET] /wallet/pending endpoint WalletPendingResp []wallet.Event - - // SessionResp is the response body for the [GET] /sessions endpoint - SessionResp []rhp.Session ) // MarshalJSON implements json.Marshaler diff --git a/api/volumes.go b/api/volumes.go index 9f77f805..491843de 100644 --- a/api/volumes.go +++ b/api/volumes.go @@ -224,10 +224,10 @@ func (a *api) handleGETVerifySector(jc jape.Context) { } // try to read the sector data and verify the root - data, err := a.volumes.Read(root) + sector, err := a.volumes.ReadSector(root) if err != nil { resp.Error = err.Error() - } else if calc := rhp2.SectorRoot(data); calc != root { + } else if calc := rhp2.SectorRoot(sector); calc != root { resp.Error = fmt.Sprintf("sector is corrupt: expected root %q, got %q", root, calc) } jc.Encode(resp) diff --git a/cmd/hostd/main.go b/cmd/hostd/main.go index 50ba60fe..4379a70e 100644 --- a/cmd/hostd/main.go +++ b/cmd/hostd/main.go @@ -57,6 +57,11 @@ var ( RHP3: config.RHP3{ TCPAddress: ":9983", }, + RHP4: config.RHP4{ + ListenAddresses: []config.RHP4ListenAddress{ + {Protocol: "tcp", Address: ":9984"}, + }, + }, Log: config.Log{ Path: os.Getenv(logFileEnvVar), // deprecated. included for compatibility. Level: "info", diff --git a/cmd/hostd/run.go b/cmd/hostd/run.go index 543d1f27..23e069c7 100644 --- a/cmd/hostd/run.go +++ b/cmd/hostd/run.go @@ -16,6 +16,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/coreutils" "go.sia.tech/coreutils/chain" + rhp4 "go.sia.tech/coreutils/rhp/v4" "go.sia.tech/coreutils/syncer" "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/alerts" @@ -96,6 +97,71 @@ func deleteSiadData(dir string) error { return nil } +func parseAnnounceAddresses(listen []config.RHP4ListenAddress, announce []config.RHP4AnnounceAddress) ([]chain.NetAddress, error) { + // build a map of the ports that each supported protocol is listening on + protocolPorts := make(map[chain.Protocol]map[uint16]bool) + for _, addr := range listen { + switch addr.Protocol { + case "tcp", "tcp4", "tcp6": + hostname, port, err := net.SplitHostPort(addr.Address) + if err != nil { + return nil, fmt.Errorf("failed to parse listen address %q: %w", addr.Address, err) + } else if ip := net.ParseIP(hostname); hostname != "" && ip == nil { + return nil, fmt.Errorf("rhp4 listen address %q should be an IP address", addr.Address) + } + ports, ok := protocolPorts[rhp4.ProtocolTCPSiaMux] + if !ok { + ports = make(map[uint16]bool) + } + n, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse port %q: %w", port, err) + } + ports[uint16(n)] = true + protocolPorts[rhp4.ProtocolTCPSiaMux] = ports + default: + return nil, fmt.Errorf("unsupported protocol %q: %s", addr.Address, addr.Protocol) + } + } + + var addrs []chain.NetAddress + for _, addr := range announce { + switch addr.Protocol { + case rhp4.ProtocolTCPSiaMux: + default: + return nil, fmt.Errorf("unsupported protocol %q", addr.Protocol) + } + + hostname, port, err := net.SplitHostPort(addr.Address) + if err != nil { + return nil, fmt.Errorf("failed to parse announce address %q: %w", addr.Address, err) + } + ip := net.ParseIP(hostname) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsGlobalUnicast()) { + return nil, fmt.Errorf("invalid announce address %q: must be a public IP address", addr.Address) + } + + n, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse port %q: %w", port, err) + } + + // ensure the port is being listened on + ports, ok := protocolPorts[addr.Protocol] + if !ok { + return nil, fmt.Errorf("no listen address for protocol %q", addr.Protocol) + } else if !ports[uint16(n)] { + return nil, fmt.Errorf("no listen address for protocol %q port %d", addr.Protocol, n) + } + + addrs = append(addrs, chain.NetAddress{ + Protocol: addr.Protocol, + Address: net.JoinHostPort(addr.Address, port), + }) + } + return addrs, nil +} + // startLocalhostListener https://github.com/SiaFoundation/hostd/issues/202 func startLocalhostListener(listenAddr string, log *zap.Logger) (l net.Listener, err error) { addr, port, err := net.SplitHostPort(listenAddr) @@ -186,18 +252,6 @@ func runRootCmd(ctx context.Context, cfg config.Config, walletKey types.PrivateK } defer syncerListener.Close() - rhp2Listener, err := net.Listen("tcp", cfg.RHP2.Address) - if err != nil { - return fmt.Errorf("failed to listen on rhp2 addr: %w", err) - } - defer rhp2Listener.Close() - - rhp3Listener, err := net.Listen("tcp", cfg.RHP3.TCPAddress) - if err != nil { - return fmt.Errorf("failed to listen on rhp3 addr: %w", err) - } - defer rhp3Listener.Close() - syncerAddr := syncerListener.Addr().String() if cfg.Syncer.EnableUPnP { _, portStr, _ := net.SplitHostPort(cfg.Syncer.Address) @@ -234,30 +288,22 @@ func runRootCmd(ctx context.Context, cfg config.Config, walletKey types.PrivateK GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: syncerAddr, - }, syncer.WithLogger(log.Named("syncer"))) + }, syncer.WithLogger(log.Named("syncer")), syncer.WithMaxInboundPeers(64), syncer.WithMaxOutboundPeers(16)) go s.Run(ctx) defer s.Close() - wm, err := wallet.NewSingleAddressWallet(walletKey, cm, store, wallet.WithLogger(log.Named("wallet")), wallet.WithReservationDuration(3*time.Hour)) - if err != nil { - return fmt.Errorf("failed to create wallet: %w", err) - } - defer wm.Close() - wr, err := webhooks.NewManager(store, log.Named("webhooks")) if err != nil { return fmt.Errorf("failed to create webhook reporter: %w", err) } defer wr.Close() - sr := rhp.NewSessionReporter() - am := alerts.NewManager(alerts.WithEventReporter(wr), alerts.WithLog(log.Named("alerts"))) - cfm, err := settings.NewConfigManager(hostKey, store, cm, s, wm, settings.WithAlertManager(am), settings.WithLog(log.Named("settings"))) + wm, err := wallet.NewSingleAddressWallet(walletKey, cm, store, wallet.WithLogger(log.Named("wallet")), wallet.WithReservationDuration(3*time.Hour)) if err != nil { - return fmt.Errorf("failed to create settings manager: %w", err) + return fmt.Errorf("failed to create wallet: %w", err) } - defer cfm.Close() + defer wm.Close() vm, err := storage.NewVolumeManager(store, storage.WithLogger(log.Named("volumes")), storage.WithAlerter(am)) if err != nil { @@ -265,21 +311,47 @@ func runRootCmd(ctx context.Context, cfg config.Config, walletKey types.PrivateK } defer vm.Close() - contractManager, err := contracts.NewManager(store, vm, cm, s, wm, contracts.WithLog(log.Named("contracts")), contracts.WithAlerter(am)) + announceAddresses, err := parseAnnounceAddresses(cfg.RHP4.ListenAddresses, cfg.RHP4.AnnounceAddresses) + if err != nil { + return fmt.Errorf("failed to parse announce addresses: %w", err) + } + + sm, err := settings.NewConfigManager(hostKey, store, cm, s, wm, vm, + settings.WithRHP4AnnounceAddresses(announceAddresses), + settings.WithAlertManager(am), + settings.WithLog(log.Named("settings"))) + if err != nil { + return fmt.Errorf("failed to create settings manager: %w", err) + } + defer sm.Close() + + contracts, err := contracts.NewManager(store, vm, cm, s, wm, contracts.WithLog(log.Named("contracts")), contracts.WithAlerter(am)) if err != nil { return fmt.Errorf("failed to create contracts manager: %w", err) } - defer contractManager.Close() + defer contracts.Close() - index, err := index.NewManager(store, cm, contractManager, wm, cfm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(cfg.Consensus.IndexBatchSize)) + index, err := index.NewManager(store, cm, contracts, wm, sm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(cfg.Consensus.IndexBatchSize)) if err != nil { return fmt.Errorf("failed to create index manager: %w", err) } defer index.Close() dr := rhp.NewDataRecorder(store, log.Named("data")) + rl, wl := sm.RHPBandwidthLimiters() + rhp2Listener, err := rhp.Listen("tcp", cfg.RHP2.Address, rhp.WithDataMonitor(dr), rhp.WithReadLimit(rl), rhp.WithWriteLimit(wl)) + if err != nil { + return fmt.Errorf("failed to listen on rhp2 addr: %w", err) + } + defer rhp2Listener.Close() - rhp2, err := rhp2.NewSessionHandler(rhp2Listener, hostKey, rhp3Listener.Addr().String(), cm, s, wm, contractManager, cfm, vm, rhp2.WithDataMonitor(dr), rhp2.WithLog(log.Named("rhp2"))) + rhp3Listener, err := rhp.Listen("tcp", cfg.RHP3.TCPAddress, rhp.WithDataMonitor(dr), rhp.WithReadLimit(rl), rhp.WithWriteLimit(wl)) + if err != nil { + return fmt.Errorf("failed to listen on rhp3 addr: %w", err) + } + defer rhp3Listener.Close() + + rhp2, err := rhp2.NewSessionHandler(rhp2Listener, hostKey, rhp3Listener.Addr().String(), cm, s, wm, contracts, sm, vm, log.Named("rhp2")) if err != nil { return fmt.Errorf("failed to create rhp2 session handler: %w", err) } @@ -287,24 +359,47 @@ func runRootCmd(ctx context.Context, cfg config.Config, walletKey types.PrivateK defer rhp2.Close() registry := registry.NewManager(hostKey, store, log.Named("registry")) - accounts := accounts.NewManager(store, cfm) - rhp3, err := rhp3.NewSessionHandler(rhp3Listener, hostKey, cm, s, wm, accounts, contractManager, registry, vm, cfm, rhp3.WithDataMonitor(dr), rhp3.WithSessionReporter(sr), rhp3.WithLog(log.Named("rhp3"))) + accounts := accounts.NewManager(store, sm) + rhp3, err := rhp3.NewSessionHandler(rhp3Listener, hostKey, cm, s, wm, accounts, contracts, registry, vm, sm, log.Named("rhp3")) if err != nil { return fmt.Errorf("failed to create rhp3 session handler: %w", err) } go rhp3.Serve() defer rhp3.Close() + rhp4 := rhp4.NewServer(hostKey, cm, s, contracts, wm, sm, vm, rhp4.WithPriceTableValidity(30*time.Minute), rhp4.WithContractProofWindowBuffer(72)) + + var stopListenerFuncs []func() error + defer func() { + for _, f := range stopListenerFuncs { + if err := f(); err != nil { + log.Error("failed to stop listener", zap.Error(err)) + } + } + }() + for _, addr := range cfg.RHP4.ListenAddresses { + switch addr.Protocol { + case "tcp", "tcp4", "tcp6": + l, err := rhp.Listen(addr.Protocol, addr.Address, rhp.WithDataMonitor(dr), rhp.WithReadLimit(rl), rhp.WithWriteLimit(wl)) + if err != nil { + return fmt.Errorf("failed to listen on rhp4 addr: %w", err) + } + stopListenerFuncs = append(stopListenerFuncs, l.Close) + go rhp.ServeRHP4SiaMux(l, rhp4, log.Named("rhp4")) + default: + return fmt.Errorf("unsupported protocol: %s", addr.Protocol) + } + } + apiOpts := []api.ServerOption{ api.WithAlerts(am), api.WithLogger(log.Named("api")), - api.WithRHPSessionReporter(sr), api.WithWebhooks(wr), api.WithSQLite3Store(store), } if !cfg.Explorer.Disable { ex := explorer.New(cfg.Explorer.URL) - pm, err := pin.NewManager(store, cfm, ex, pin.WithLogger(log.Named("pin"))) + pm, err := pin.NewManager(store, sm, ex, pin.WithLogger(log.Named("pin"))) if err != nil { return fmt.Errorf("failed to create pin manager: %w", err) } @@ -314,7 +409,7 @@ func runRootCmd(ctx context.Context, cfg config.Config, walletKey types.PrivateK web := http.Server{ Handler: webRouter{ - api: jape.BasicAuth(cfg.HTTP.Password)(api.NewServer(cfg.Name, hostKey.PublicKey(), cm, s, accounts, contractManager, vm, wm, store, cfm, index, apiOpts...)), + api: jape.BasicAuth(cfg.HTTP.Password)(api.NewServer(cfg.Name, hostKey.PublicKey(), cm, s, accounts, contracts, vm, wm, store, sm, index, apiOpts...)), ui: hostd.Handler(), }, ReadTimeout: 30 * time.Second, diff --git a/config/config.go b/config/config.go index 38576cc8..a48b0bca 100644 --- a/config/config.go +++ b/config/config.go @@ -1,5 +1,9 @@ package config +import ( + "go.sia.tech/coreutils/chain" +) + type ( // HTTP contains the configuration for the HTTP server. HTTP struct { @@ -21,22 +25,40 @@ type ( IndexBatchSize int `yaml:"indexBatchSize,omitempty"` } - // RHP2 contains the configuration for the RHP2 server. - RHP2 struct { - Address string `yaml:"address,omitempty"` - } - // ExplorerData contains the configuration for using an external explorer. ExplorerData struct { Disable bool `yaml:"disable,omitempty"` URL string `yaml:"url,omitempty"` } + // RHP2 contains the configuration for the RHP2 server. + RHP2 struct { + Address string `yaml:"address,omitempty"` + } + // RHP3 contains the configuration for the RHP3 server. RHP3 struct { TCPAddress string `yaml:"tcp,omitempty"` } + // RHP4ListenAddress contains the configuration for an RHP4 listen address. + RHP4ListenAddress struct { + Protocol string `yaml:"protocol,omitempty"` + Address string `yaml:"address,omitempty"` + } + + // RHP4AnnounceAddress contains the configuration for an RHP4 announce address. + RHP4AnnounceAddress struct { + Protocol chain.Protocol `yaml:"protocol,omitempty"` + Address string `yaml:"address,omitempty"` + } + + // RHP4 contains the configuration for the RHP4 server. + RHP4 struct { + ListenAddresses []RHP4ListenAddress `yaml:"listenAddresses,omitempty"` + AnnounceAddresses []RHP4AnnounceAddress `yaml:"announceAddresses,omitempty"` + } + // LogFile configures the file output of the logger. LogFile struct { Enabled bool `yaml:"enabled,omitempty"` @@ -77,6 +99,7 @@ type ( Explorer ExplorerData `yaml:"explorer,omitempty"` RHP2 RHP2 `yaml:"rhp2,omitempty"` RHP3 RHP3 `yaml:"rhp3,omitempty"` + RHP4 RHP4 `yaml:"rhp4,omitempty"` Log Log `yaml:"log,omitempty"` } ) diff --git a/go.mod b/go.mod index c50df739..a639dafb 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( go.sia.tech/core v0.6.1 go.sia.tech/coreutils v0.6.0 go.sia.tech/jape v0.12.1 + go.sia.tech/mux v1.3.0 go.sia.tech/web/hostd v0.49.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 @@ -23,7 +24,6 @@ require ( lukechampine.com/flagg v1.1.1 lukechampine.com/frand v1.5.1 lukechampine.com/upnp v0.3.0 - nhooyr.io/websocket v1.8.17 ) require ( @@ -34,7 +34,6 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect go.etcd.io/bbolt v1.3.11 // indirect - go.sia.tech/mux v1.3.0 // indirect go.sia.tech/web v0.0.0-20240610131903-5611d44a533e // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.28.0 // indirect diff --git a/go.sum b/go.sum index f35bb48e..c9cae0c3 100644 --- a/go.sum +++ b/go.sum @@ -91,5 +91,3 @@ lukechampine.com/frand v1.5.1 h1:fg0eRtdmGFIxhP5zQJzM1lFDbD6CUfu/f+7WgAZd5/w= lukechampine.com/frand v1.5.1/go.mod h1:4VstaWc2plN4Mjr10chUD46RAVGWhpkZ5Nja8+Azp0Q= lukechampine.com/upnp v0.3.0 h1:UVCD6eD6fmJmwak6DVE3vGN+L46Fk8edTcC6XYCb6C4= lukechampine.com/upnp v0.3.0/go.mod h1:sOuF+fGSDKjpUm6QI0mfb82ScRrhj8bsqsD78O5nK1k= -nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= -nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/host/accounts/accounts.go b/host/accounts/accounts.go index d44688b6..213d05ae 100644 --- a/host/accounts/accounts.go +++ b/host/accounts/accounts.go @@ -1,3 +1,6 @@ +// Package accounts implements RHP3 ephemeral account management for the host. +// +// Deprecated: will be merged with the contractor and removed after the hardfork package accounts import ( diff --git a/host/contracts/accounts.go b/host/contracts/accounts.go new file mode 100644 index 00000000..514e051c --- /dev/null +++ b/host/contracts/accounts.go @@ -0,0 +1,36 @@ +package contracts + +import ( + proto4 "go.sia.tech/core/rhp/v4" + "go.sia.tech/core/types" +) + +// AccountBalance returns the balance of an account. +func (cm *Manager) AccountBalance(account proto4.Account) (types.Currency, error) { + return cm.store.RHP4AccountBalance(account) +} + +// CreditAccountsWithContract atomically revises a contract and credits the accounts +// returning the new balance of each account. +func (cm *Manager) CreditAccountsWithContract(deposits []proto4.AccountDeposit, contractID types.FileContractID, revision types.V2FileContract, usage proto4.Usage) ([]types.Currency, error) { + return cm.store.RHP4CreditAccounts(deposits, contractID, revision, V2Usage{ + RPCRevenue: usage.RPC, + StorageRevenue: usage.Storage, + IngressRevenue: usage.Ingress, + EgressRevenue: usage.Egress, + AccountFunding: usage.AccountFunding, + RiskedCollateral: usage.RiskedCollateral, + }) +} + +// DebitAccount debits an account. +func (cm *Manager) DebitAccount(account proto4.Account, usage proto4.Usage) error { + return cm.store.RHP4DebitAccount(account, V2Usage{ + RPCRevenue: usage.RPC, + StorageRevenue: usage.Storage, + IngressRevenue: usage.Ingress, + EgressRevenue: usage.Egress, + AccountFunding: usage.AccountFunding, + RiskedCollateral: usage.RiskedCollateral, + }) +} diff --git a/host/contracts/contracts.go b/host/contracts/contracts.go index 1c0da26b..e5f2af31 100644 --- a/host/contracts/contracts.go +++ b/host/contracts/contracts.go @@ -144,13 +144,6 @@ type ( RenewedFrom types.FileContractID `json:"renewedFrom"` } - // A V2FormationTransactionSet contains the formation transaction set for a - // v2 contract. - V2FormationTransactionSet struct { - TransactionSet []types.V2Transaction - Basis types.ChainIndex - } - // A Contract contains metadata on the current state of a file contract. Contract struct { SignedRevision @@ -260,6 +253,11 @@ var ( ErrContractExists = errors.New("contract already exists") ) +// RenterCost returns the total cost of the usage to the renter. +func (u V2Usage) RenterCost() types.Currency { + return u.RPCRevenue.Add(u.StorageRevenue).Add(u.EgressRevenue).Add(u.IngressRevenue).Add(u.AccountFunding) +} + // Add returns u + b func (a Usage) Add(b Usage) (c Usage) { return Usage{ diff --git a/host/contracts/contracts_test.go b/host/contracts/contracts_test.go index 7a60a939..bec22bd4 100644 --- a/host/contracts/contracts_test.go +++ b/host/contracts/contracts_test.go @@ -96,22 +96,12 @@ func TestContractUpdater(t *testing.T) { } defer updater.Close() - var releaseFuncs []func() error - defer func() { - for _, release := range releaseFuncs { - if err := release(); err != nil { - t.Fatal(err) - } - } - }() - for i := 0; i < test.append; i++ { root := frand.Entropy256() - release, err := node.Store.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + err := node.Store.StoreTempSector(root, 10, func(loc storage.SectorLocation) error { return nil }) if err != nil { t.Fatal(err) } - releaseFuncs = append(releaseFuncs, release) updater.AppendSector(root) roots = append(roots, root) } @@ -135,11 +125,6 @@ func TestContractUpdater(t *testing.T) { } else if err := updater.Close(); err != nil { t.Fatal(err) } - for _, release := range releaseFuncs { - if err := release(); err != nil { - t.Fatal(err) - } - } // check that the sector roots are correct in the database allRoots, err := node.Store.SectorRoots() diff --git a/host/contracts/integrity.go b/host/contracts/integrity.go index aee4e633..76f7d0d5 100644 --- a/host/contracts/integrity.go +++ b/host/contracts/integrity.go @@ -116,7 +116,7 @@ func (cm *Manager) CheckIntegrity(ctx context.Context, contractID types.FileCont default: } // read each sector from disk and verify its Merkle root - sector, err := cm.storage.Read(root) + sector, err := cm.storage.ReadSector(root) if err != nil { // sector read failed log.Error("missing sector", zap.String("root", root.String()), zap.Error(err)) missing++ diff --git a/host/contracts/lock.go b/host/contracts/lock.go new file mode 100644 index 00000000..c6594f6a --- /dev/null +++ b/host/contracts/lock.go @@ -0,0 +1,141 @@ +package contracts + +import ( + "context" + "fmt" + "sync" + + "go.sia.tech/core/types" + rhp4 "go.sia.tech/coreutils/rhp/v4" +) + +type ( + lock struct { + ch chan struct{} + n int + } + + locker struct { + mu sync.Mutex + locks map[types.FileContractID]*lock + } +) + +func newLocker() *locker { + l := &locker{ + locks: make(map[types.FileContractID]*lock), + } + return l +} + +// Unlock releases a lock on the given contract ID. If the lock is not held, the +// function will panic. +func (lr *locker) Unlock(id types.FileContractID) { + lr.mu.Lock() + defer lr.mu.Unlock() + l, ok := lr.locks[id] + if !ok { + panic("unlocking unheld lock") // developer error + } + l.n-- + if l.n == 0 { + delete(lr.locks, id) + } else { + l.ch <- struct{}{} + } +} + +// Lock acquires a lock on the given contract ID. If the lock is already held, the +// function will block until the lock is released or the context is canceled. +// If the context is canceled, the function will return an error. +func (lr *locker) Lock(ctx context.Context, id types.FileContractID) error { + lr.mu.Lock() + l, ok := lr.locks[id] + if !ok { + // immediately acquire the lock + defer lr.mu.Unlock() + l = &lock{ + ch: make(chan struct{}, 1), + n: 1, + } + lr.locks[id] = l + return nil + } + l.n++ + lr.mu.Unlock() // unlock before waiting to avoid deadlock + select { + case <-ctx.Done(): + lr.mu.Lock() + l.n-- + if l.n == 0 { + delete(lr.locks, id) + } + lr.mu.Unlock() + return ctx.Err() + case <-l.ch: + return nil + } +} + +// Lock locks a contract for modification. +// +// Deprecated: Use LockV2Contract instead. +func (cm *Manager) Lock(ctx context.Context, id types.FileContractID) (SignedRevision, error) { + ctx, cancel, err := cm.tg.AddContext(ctx) + if err != nil { + return SignedRevision{}, err + } + defer cancel() + + if err := cm.locks.Lock(ctx, id); err != nil { + return SignedRevision{}, err + } + + contract, err := cm.store.Contract(id) + if err != nil { + cm.locks.Unlock(id) + return SignedRevision{}, fmt.Errorf("failed to get contract: %w", err) + } else if err := cm.isGoodForModification(contract); err != nil { + cm.locks.Unlock(id) + return SignedRevision{}, fmt.Errorf("contract is not good for modification: %w", err) + } + return contract.SignedRevision, nil +} + +// Unlock unlocks a locked contract. +// +// Deprecated: Use LockV2Contract instead. +func (cm *Manager) Unlock(id types.FileContractID) { + cm.locks.Unlock(id) +} + +// LockV2Contract locks a contract for modification. The returned unlock function +// must be called to release the lock. +func (cm *Manager) LockV2Contract(id types.FileContractID) (rev rhp4.RevisionState, unlock func(), _ error) { + done, err := cm.tg.Add() + if err != nil { + return rhp4.RevisionState{}, nil, err + } + defer done() + + // blocking is fine because locks are held for a short time + if err := cm.locks.Lock(context.Background(), id); err != nil { + return rhp4.RevisionState{}, nil, err + } + + contract, err := cm.store.V2Contract(id) + if err != nil { + cm.locks.Unlock(id) + return rhp4.RevisionState{}, nil, fmt.Errorf("failed to get contract: %w", err) + } + + var once sync.Once + return rhp4.RevisionState{ + Revision: contract.V2FileContract, + Roots: cm.getSectorRoots(id), + }, func() { + once.Do(func() { + cm.locks.Unlock(id) + }) + }, nil +} diff --git a/host/contracts/manager.go b/host/contracts/manager.go index 1053ff90..090180f1 100644 --- a/host/contracts/manager.go +++ b/host/contracts/manager.go @@ -1,7 +1,6 @@ package contracts import ( - "context" "errors" "fmt" "math" @@ -10,7 +9,9 @@ import ( "go.sia.tech/core/consensus" rhp2 "go.sia.tech/core/rhp/v2" + proto4 "go.sia.tech/core/rhp/v4" "go.sia.tech/core/types" + rhp4 "go.sia.tech/coreutils/rhp/v4" "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/internal/threadgroup" "go.uber.org/zap" @@ -26,6 +27,7 @@ type ( UnconfirmedParents(txn types.Transaction) []types.Transaction AddPoolTransactions([]types.Transaction) (known bool, err error) AddV2PoolTransactions(types.ChainIndex, []types.V2Transaction) (known bool, err error) + UpdateV2TransactionSet(txns []types.V2Transaction, from, to types.ChainIndex) ([]types.V2Transaction, error) RecommendedFee() types.Currency } @@ -49,8 +51,8 @@ type ( // A StorageManager stores and retrieves sectors. StorageManager interface { - // Read reads a sector from the store - Read(root types.Hash256) (*[rhp2.SectorSize]byte, error) + // ReadSector reads a sector from the store + ReadSector(root types.Hash256) (*[rhp2.SectorSize]byte, error) } // Alerts registers and dismisses global alerts. @@ -59,11 +61,6 @@ type ( Dismiss(...types.Hash256) } - locker struct { - c chan struct{} - waiters int - } - // A Manager manages contracts' lifecycle Manager struct { rejectBuffer uint64 @@ -78,12 +75,10 @@ type ( chain ChainManager syncer Syncer wallet Wallet + locks *locker // contracts must be locked while they are being modified - mu sync.Mutex // guards the following fields - // caches the sector roots of all contracts to avoid long reads from - // the store - sectorRoots map[types.FileContractID][]types.Hash256 - locks map[types.FileContractID]*locker // contracts must be locked while they are being modified + mu sync.Mutex + sectorRoots map[types.FileContractID][]types.Hash256 // caches the sector roots of all contracts to avoid long reads from } ) @@ -106,68 +101,6 @@ func (cm *Manager) setSectorRoots(id types.FileContractID, roots []types.Hash256 cm.sectorRoots[id] = append([]types.Hash256(nil), roots...) } -// Lock locks a contract for modification. -func (cm *Manager) Lock(ctx context.Context, id types.FileContractID) (SignedRevision, error) { - ctx, cancel, err := cm.tg.AddContext(ctx) - if err != nil { - return SignedRevision{}, err - } - defer cancel() - - cm.mu.Lock() - contract, err := cm.store.Contract(id) - if err != nil { - cm.mu.Unlock() - return SignedRevision{}, fmt.Errorf("failed to get contract: %w", err) - } else if err := cm.isGoodForModification(contract); err != nil { - cm.mu.Unlock() - return SignedRevision{}, fmt.Errorf("contract is not good for modification: %w", err) - } - - // if the contract isn't already locked, create a new lock - if _, exists := cm.locks[id]; !exists { - cm.locks[id] = &locker{ - c: make(chan struct{}, 1), - waiters: 0, - } - cm.mu.Unlock() - return contract.SignedRevision, nil - } - cm.locks[id].waiters++ - c := cm.locks[id].c - // mutex must be unlocked before waiting on the channel to prevent deadlock. - cm.mu.Unlock() - select { - case <-c: - cm.mu.Lock() - defer cm.mu.Unlock() - contract, err := cm.store.Contract(id) - if err != nil { - return SignedRevision{}, fmt.Errorf("failed to get contract: %w", err) - } else if err := cm.isGoodForModification(contract); err != nil { - return SignedRevision{}, fmt.Errorf("contract is not good for modification: %w", err) - } - return contract.SignedRevision, nil - case <-ctx.Done(): - return SignedRevision{}, ctx.Err() - } -} - -// Unlock unlocks a locked contract. -func (cm *Manager) Unlock(id types.FileContractID) { - cm.mu.Lock() - defer cm.mu.Unlock() - lock, exists := cm.locks[id] - if !exists { - return - } else if lock.waiters <= 0 { - delete(cm.locks, id) - return - } - lock.waiters-- - lock.c <- struct{}{} -} - // Contracts returns a paginated list of contracts matching the filter and the // total number of contracts matching the filter. func (cm *Manager) Contracts(filter ContractFilter) ([]Contract, int, error) { @@ -184,8 +117,9 @@ func (cm *Manager) V2Contract(id types.FileContractID) (V2Contract, error) { return cm.store.V2Contract(id) } -// V2ContractElement returns the latest v2 state element with the given ID. -func (cm *Manager) V2ContractElement(id types.FileContractID) (types.V2FileContractElement, error) { +// V2FileContractElement returns the chain index and file contract element for the +// given contract ID. +func (cm *Manager) V2FileContractElement(id types.FileContractID) (types.ChainIndex, types.V2FileContractElement, error) { return cm.store.V2ContractElement(id) } @@ -236,7 +170,7 @@ func (cm *Manager) RenewContract(renewal SignedRevision, existing SignedRevision } // ReviseV2Contract atomically updates a contract and its associated sector roots. -func (cm *Manager) ReviseV2Contract(contractID types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage Usage) error { +func (cm *Manager) ReviseV2Contract(contractID types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage proto4.Usage) error { done, err := cm.tg.Add() if err != nil { return err @@ -279,7 +213,15 @@ func (cm *Manager) ReviseV2Contract(contractID types.FileContractID, revision ty } // revise the contract in the store - if err := cm.store.ReviseV2Contract(contractID, revision, roots, usage); err != nil { + err = cm.store.ReviseV2Contract(contractID, revision, roots, V2Usage{ + RPCRevenue: usage.RPC, + StorageRevenue: usage.Storage, + IngressRevenue: usage.Ingress, + EgressRevenue: usage.Egress, + AccountFunding: usage.AccountFunding, + RiskedCollateral: usage.RiskedCollateral, + }) + if err != nil { return err } // update the sector roots cache @@ -290,14 +232,14 @@ func (cm *Manager) ReviseV2Contract(contractID types.FileContractID, revision ty // AddV2Contract stores the provided contract, should error if the contract // already exists. -func (cm *Manager) AddV2Contract(formation V2FormationTransactionSet, usage V2Usage) error { +func (cm *Manager) AddV2Contract(formation rhp4.TransactionSet, usage proto4.Usage) error { done, err := cm.tg.Add() if err != nil { return err } defer done() - formationSet := formation.TransactionSet + formationSet := formation.Transactions if len(formationSet) == 0 { return errors.New("no formation transactions provided") } else if len(formationSet[len(formationSet)-1].FileContracts) != 1 { @@ -314,7 +256,14 @@ func (cm *Manager) AddV2Contract(formation V2FormationTransactionSet, usage V2Us ID: contractID, Status: V2ContractStatusPending, NegotiationHeight: cm.chain.Tip().Height, - Usage: usage, + Usage: V2Usage{ + RPCRevenue: usage.RPC, + StorageRevenue: usage.Storage, + IngressRevenue: usage.Ingress, + EgressRevenue: usage.Egress, + AccountFunding: usage.AccountFunding, + RiskedCollateral: usage.RiskedCollateral, + }, } if err := cm.store.AddV2Contract(contract, formation); err != nil { @@ -326,14 +275,14 @@ func (cm *Manager) AddV2Contract(formation V2FormationTransactionSet, usage V2Us // RenewV2Contract renews a contract. It is expected that the existing // contract will be cleared. -func (cm *Manager) RenewV2Contract(renewal V2FormationTransactionSet, usage V2Usage) error { +func (cm *Manager) RenewV2Contract(renewal rhp4.TransactionSet, usage proto4.Usage) error { done, err := cm.tg.Add() if err != nil { return err } defer done() - renewalSet := renewal.TransactionSet + renewalSet := renewal.Transactions if len(renewalSet) == 0 { return errors.New("no renewal transactions provided") } else if len(renewalSet[len(renewalSet)-1].FileContractResolutions) != 1 { @@ -356,7 +305,7 @@ func (cm *Manager) RenewV2Contract(renewal V2FormationTransactionSet, usage V2Us // sanity checks if finalRevision.RevisionNumber != types.MaxRevisionNumber { - return errors.New("existing contract must be cleared") + return errors.New("final revision must have max revision number") } else if fc.Filesize != existing.Filesize { return errors.New("renewal contract must have same file size as existing contract") } else if fc.Capacity != existing.Capacity { @@ -378,10 +327,17 @@ func (cm *Manager) RenewV2Contract(renewal V2FormationTransactionSet, usage V2Us Status: V2ContractStatusPending, NegotiationHeight: cm.chain.Tip().Height, RenewedFrom: existingID, - Usage: usage, + Usage: V2Usage{ + RPCRevenue: usage.RPC, + StorageRevenue: usage.Storage, + IngressRevenue: usage.Ingress, + EgressRevenue: usage.Egress, + AccountFunding: usage.AccountFunding, + RiskedCollateral: usage.RiskedCollateral, + }, } - if err := cm.store.RenewV2Contract(contract, renewal, existingID, finalRevision); err != nil { + if err := cm.store.RenewV2Contract(contract, renewal, existingID, finalRevision, existingRoots); err != nil { return err } cm.setSectorRoots(contract.ID, existingRoots) @@ -447,8 +403,7 @@ func NewManager(store ContractStore, storage StorageManager, chain ChainManager, alerts: alerts.NewNop(), tg: threadgroup.New(), log: zap.NewNop(), - - locks: make(map[types.FileContractID]*locker), + locks: newLocker(), } for _, opt := range opts { diff --git a/host/contracts/manager_test.go b/host/contracts/manager_test.go index 9f91ebc3..0c3ac045 100644 --- a/host/contracts/manager_test.go +++ b/host/contracts/manager_test.go @@ -10,9 +10,10 @@ import ( "time" rhp2 "go.sia.tech/core/rhp/v2" - rhp4 "go.sia.tech/core/rhp/v4" + proto4 "go.sia.tech/core/rhp/v4" "go.sia.tech/core/types" "go.sia.tech/coreutils/chain" + rhp4 "go.sia.tech/coreutils/rhp/v4" "go.sia.tech/coreutils/syncer" "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/host/contracts" @@ -67,19 +68,19 @@ func formV2Contract(t *testing.T, cm *chain.Manager, c *contracts.Manager, w *wa t.Fatal("failed to fund transaction:", err) } w.SignV2Inputs(&txn, toSign) - formationSet := contracts.V2FormationTransactionSet{ - TransactionSet: []types.V2Transaction{txn}, - Basis: basis, + formationSet := rhp4.TransactionSet{ + Transactions: []types.V2Transaction{txn}, + Basis: basis, } if broadcast { - if _, err := cm.AddV2PoolTransactions(formationSet.Basis, formationSet.TransactionSet); err != nil { + if _, err := cm.AddV2PoolTransactions(formationSet.Basis, formationSet.Transactions); err != nil { t.Fatal("failed to add formation set to pool:", err) } - s.BroadcastV2TransactionSet(formationSet.Basis, formationSet.TransactionSet) + s.BroadcastV2TransactionSet(formationSet.Basis, formationSet.Transactions) } - if err := c.AddV2Contract(formationSet, contracts.V2Usage{}); err != nil { + if err := c.AddV2Contract(formationSet, proto4.Usage{}); err != nil { t.Fatal("failed to add contract:", err) } return txn.V2FileContractID(txn.ID(), 0), fc @@ -853,9 +854,9 @@ func TestV2ContractLifecycle(t *testing.T) { } defer release() - fc.Filesize = rhp4.SectorSize - fc.Capacity = rhp4.SectorSize - fc.FileMerkleRoot = rhp4.MetaRoot(roots) + fc.Filesize = proto4.SectorSize + fc.Capacity = proto4.SectorSize + fc.FileMerkleRoot = proto4.MetaRoot(roots) fc.RevisionNumber++ // transfer some funds from the renter to the host cost, collateral := types.Siacoins(1), types.Siacoins(2) @@ -866,8 +867,8 @@ func TestV2ContractLifecycle(t *testing.T) { fc.HostSignature = hostKey.SignHash(sigHash) fc.RenterSignature = renterKey.SignHash(sigHash) - err = node.Contracts.ReviseV2Contract(contractID, fc, roots, contracts.Usage{ - StorageRevenue: cost, + err = node.Contracts.ReviseV2Contract(contractID, fc, roots, proto4.Usage{ + Storage: cost, RiskedCollateral: collateral, }) if err != nil { @@ -902,20 +903,19 @@ func TestV2ContractLifecycle(t *testing.T) { assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) // add a root to the contract - var sector [rhp2.SectorSize]byte + var sector [proto4.SectorSize]byte frand.Read(sector[:256]) root := frand.Entropy256() // random root roots := []types.Hash256{root} - release, err := node.Volumes.Write(root, §or) + err := node.Volumes.StoreSector(root, §or, 10) if err != nil { t.Fatal(err) } - defer release() - fc.Filesize = rhp4.SectorSize - fc.Capacity = rhp4.SectorSize - fc.FileMerkleRoot = rhp4.MetaRoot(roots) + fc.Filesize = proto4.SectorSize + fc.Capacity = proto4.SectorSize + fc.FileMerkleRoot = proto4.MetaRoot(roots) fc.RevisionNumber++ // transfer some funds from the renter to the host cost, collateral := types.Siacoins(1), types.Siacoins(2) @@ -926,14 +926,12 @@ func TestV2ContractLifecycle(t *testing.T) { fc.HostSignature = hostKey.SignHash(sigHash) fc.RenterSignature = renterKey.SignHash(sigHash) - err = node.Contracts.ReviseV2Contract(contractID, fc, roots, contracts.Usage{ - StorageRevenue: cost, + err = node.Contracts.ReviseV2Contract(contractID, fc, roots, proto4.Usage{ + Storage: cost, RiskedCollateral: collateral, }) if err != nil { t.Fatal(err) - } else if err := release(); err != nil { - t.Fatal(err) } // metrics should not have been updated, contract is still pending assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) @@ -962,7 +960,7 @@ func TestV2ContractLifecycle(t *testing.T) { assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) // add a root to the contract - var sector [rhp2.SectorSize]byte + var sector [proto4.SectorSize]byte frand.Read(sector[:]) root := rhp2.SectorRoot(§or) roots := []types.Hash256{root} @@ -973,9 +971,9 @@ func TestV2ContractLifecycle(t *testing.T) { } defer release() - fc.Filesize = rhp4.SectorSize - fc.Capacity = rhp4.SectorSize - fc.FileMerkleRoot = rhp4.MetaRoot(roots) + fc.Filesize = proto4.SectorSize + fc.Capacity = proto4.SectorSize + fc.FileMerkleRoot = proto4.MetaRoot(roots) fc.RevisionNumber++ // transfer some funds from the renter to the host cost, collateral := types.Siacoins(1), types.Siacoins(2) @@ -986,8 +984,8 @@ func TestV2ContractLifecycle(t *testing.T) { fc.HostSignature = hostKey.SignHash(sigHash) fc.RenterSignature = renterKey.SignHash(sigHash) - err = node.Contracts.ReviseV2Contract(contractID, fc, roots, contracts.Usage{ - StorageRevenue: cost, + err = node.Contracts.ReviseV2Contract(contractID, fc, roots, proto4.Usage{ + Storage: cost, RiskedCollateral: collateral, }) if err != nil { @@ -1013,7 +1011,6 @@ func TestV2ContractLifecycle(t *testing.T) { final.RevisionNumber = types.MaxRevisionNumber final.HostSignature = types.Signature{} final.RenterSignature = types.Signature{} - final.RevisionNumber = types.MaxRevisionNumber additionalCollateral := types.Siacoins(2) renewal := types.V2FileContractRenewal{ @@ -1042,7 +1039,7 @@ func TestV2ContractLifecycle(t *testing.T) { renewal.HostSignature = hostKey.SignHash(renewalSigHash) renewal.RenterSignature = renterKey.SignHash(renewalSigHash) - fce, err := com.V2ContractElement(contractID) + _, fce, err := com.V2FileContractElement(contractID) if err != nil { t.Fatal(err) } @@ -1073,16 +1070,16 @@ func TestV2ContractLifecycle(t *testing.T) { }, } node.Wallet.SignV2Inputs(&renewalTxn, []int{0}) - renewalTxnSet := contracts.V2FormationTransactionSet{ - Basis: basis, - TransactionSet: []types.V2Transaction{setupTxn, renewalTxn}, + renewalTxnSet := rhp4.TransactionSet{ + Basis: basis, + Transactions: []types.V2Transaction{setupTxn, renewalTxn}, } - if _, err := cm.AddV2PoolTransactions(renewalTxnSet.Basis, renewalTxnSet.TransactionSet); err != nil { + if _, err := cm.AddV2PoolTransactions(renewalTxnSet.Basis, renewalTxnSet.Transactions); err != nil { t.Fatal("failed to add renewal to pool:", err) } - node.Syncer.BroadcastV2TransactionSet(renewalTxnSet.Basis, renewalTxnSet.TransactionSet) + node.Syncer.BroadcastV2TransactionSet(renewalTxnSet.Basis, renewalTxnSet.Transactions) - err = com.RenewV2Contract(renewalTxnSet, contracts.V2Usage{ + err = com.RenewV2Contract(renewalTxnSet, proto4.Usage{ RiskedCollateral: renewal.NewContract.TotalCollateral.Sub(renewal.NewContract.MissedHostValue), }) if err != nil { @@ -1159,14 +1156,14 @@ func TestV2ContractLifecycle(t *testing.T) { t.Fatal("failed to fund transaction:", err) } w.SignV2Inputs(&txn, toSign) - formationSet := contracts.V2FormationTransactionSet{ - TransactionSet: []types.V2Transaction{txn}, - Basis: basis, + formationSet := rhp4.TransactionSet{ + Transactions: []types.V2Transaction{txn}, + Basis: basis, } contractID := txn.V2FileContractID(txn.ID(), 0) // corrupt the formation set to trigger a rejection - formationSet.TransactionSet[len(formationSet.TransactionSet)-1].SiacoinInputs[0].SatisfiedPolicy.Signatures[0] = types.Signature{} - if err := c.AddV2Contract(formationSet, contracts.V2Usage{}); err != nil { + formationSet.Transactions[len(formationSet.Transactions)-1].SiacoinInputs[0].SatisfiedPolicy.Signatures[0] = types.Signature{} + if err := c.AddV2Contract(formationSet, proto4.Usage{}); err != nil { t.Fatal("failed to add contract:", err) } @@ -1242,7 +1239,7 @@ func TestSectorRoots(t *testing.T) { for i := 0; i < sectors; i++ { root, err := func() (types.Hash256, error) { root := frand.Entropy256() - release, err := node.Store.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + release, err := node.Store.StoreSector(root, func(storage.SectorLocation, bool) error { return nil }) if err != nil { return types.Hash256{}, fmt.Errorf("failed to store sector: %w", err) } diff --git a/host/contracts/persist.go b/host/contracts/persist.go index ba73ae18..5e48fcad 100644 --- a/host/contracts/persist.go +++ b/host/contracts/persist.go @@ -1,7 +1,9 @@ package contracts import ( + proto4 "go.sia.tech/core/rhp/v4" "go.sia.tech/core/types" + rhp4 "go.sia.tech/coreutils/rhp/v4" ) type ( @@ -38,21 +40,28 @@ type ( ExpireContractSectors(height uint64) error // V2ContractElement returns the latest v2 state element with the given ID. - V2ContractElement(types.FileContractID) (types.V2FileContractElement, error) + V2ContractElement(types.FileContractID) (types.ChainIndex, types.V2FileContractElement, error) // V2Contract returns the v2 contract with the given ID. V2Contract(types.FileContractID) (V2Contract, error) // AddV2Contract stores the provided contract, should error if the contract // already exists in the store. - AddV2Contract(V2Contract, V2FormationTransactionSet) error + AddV2Contract(V2Contract, rhp4.TransactionSet) error // RenewV2Contract renews a contract. It is expected that the existing // contract will be cleared. - RenewV2Contract(renewal V2Contract, renewalSet V2FormationTransactionSet, renewedID types.FileContractID, finalRevision types.V2FileContract) error + RenewV2Contract(renewal V2Contract, renewalSet rhp4.TransactionSet, renewedID types.FileContractID, clearing types.V2FileContract, roots []types.Hash256) error // ReviseV2Contract atomically updates a contract and its associated // sector roots. - ReviseV2Contract(id types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage Usage) error + ReviseV2Contract(types.FileContractID, types.V2FileContract, []types.Hash256, V2Usage) error // ExpireV2ContractSectors removes sector roots for any v2 contracts that are // rejected or past their proof window. ExpireV2ContractSectors(height uint64) error + + // RHP4AccountBalance returns the balance of an account. + RHP4AccountBalance(proto4.Account) (types.Currency, error) + // RHP4CreditAccounts atomically revises a contract and credits the accounts + RHP4CreditAccounts([]proto4.AccountDeposit, types.FileContractID, types.V2FileContract, V2Usage) (balances []types.Currency, err error) + // RHP4DebitAccount debits an account. + RHP4DebitAccount(proto4.Account, V2Usage) error } ) diff --git a/host/contracts/update.go b/host/contracts/update.go index 53246eb0..e8295220 100644 --- a/host/contracts/update.go +++ b/host/contracts/update.go @@ -7,6 +7,7 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/coreutils/chain" + rhp4 "go.sia.tech/coreutils/rhp/v4" "go.sia.tech/coreutils/wallet" "go.uber.org/zap" ) @@ -33,7 +34,7 @@ type ( BroadcastProof []SignedRevision // V2 actions - RebroadcastV2Formation []V2FormationTransactionSet + RebroadcastV2Formation []rhp4.TransactionSet BroadcastV2Revision []types.V2FileContractRevision BroadcastV2Proof []types.V2FileContractElement BroadcastV2Expiration []types.V2FileContractElement @@ -115,7 +116,7 @@ func (cm *Manager) buildStorageProof(revision types.FileContractRevision, index } sectorRoot := roots[sectorIndex] - sector, err := cm.storage.Read(sectorRoot) + sector, err := cm.storage.ReadSector(sectorRoot) if err != nil { log.Error("failed to read sector data", zap.Error(err), zap.Stringer("sectorRoot", sectorRoot)) return types.StorageProof{}, fmt.Errorf("failed to read sector data") @@ -158,7 +159,7 @@ func (cm *Manager) buildV2StorageProof(cs consensus.State, fce types.V2FileContr } sectorRoot := roots[sectorIndex] - sector, err := cm.storage.Read(sectorRoot) + sector, err := cm.storage.ReadSector(sectorRoot) if err != nil { log.Error("failed to read sector data", zap.Error(err), zap.Stringer("sectorRoot", sectorRoot)) return types.V2StorageProof{}, fmt.Errorf("failed to read sector data") @@ -295,10 +296,10 @@ func (cm *Manager) ProcessActions(index types.ChainIndex) error { } for _, formationSet := range actions.RebroadcastV2Formation { - if len(formationSet.TransactionSet) == 0 { + if len(formationSet.Transactions) == 0 { continue } - formationTxn := formationSet.TransactionSet[len(formationSet.TransactionSet)-1] + formationTxn := formationSet.Transactions[len(formationSet.Transactions)-1] if len(formationTxn.FileContracts) == 0 { continue } @@ -306,12 +307,12 @@ func (cm *Manager) ProcessActions(index types.ChainIndex) error { contractID := formationTxn.V2FileContractID(formationTxn.ID(), 0) log := log.Named("v2 formation").With(zap.Stringer("basis", formationSet.Basis), zap.Stringer("contractID", contractID)) - if _, err := cm.chain.AddV2PoolTransactions(formationSet.Basis, formationSet.TransactionSet); err != nil { + if _, err := cm.chain.AddV2PoolTransactions(formationSet.Basis, formationSet.Transactions); err != nil { log.Error("failed to add formation transaction to pool", zap.Error(err)) continue } - cm.syncer.BroadcastV2TransactionSet(formationSet.Basis, formationSet.TransactionSet) - log.Debug("broadcast transaction", zap.String("transactionID", formationSet.TransactionSet[len(formationSet.TransactionSet)-1].ID().String())) + cm.syncer.BroadcastV2TransactionSet(formationSet.Basis, formationSet.Transactions) + log.Debug("broadcast transaction", zap.String("transactionID", formationSet.Transactions[len(formationSet.Transactions)-1].ID().String())) } for _, fcr := range actions.BroadcastV2Revision { diff --git a/host/settings/announce.go b/host/settings/announce.go index 5d1a7cfb..a9187979 100644 --- a/host/settings/announce.go +++ b/host/settings/announce.go @@ -8,7 +8,6 @@ import ( "go.sia.tech/core/types" "go.sia.tech/coreutils/chain" - rhp4 "go.sia.tech/coreutils/rhp/v4" "go.uber.org/zap" ) @@ -22,19 +21,19 @@ type ( // Announce announces the host to the network func (m *ConfigManager) Announce() error { - // get the current settings - settings := m.Settings() - - if m.validateNetAddress { - if err := validateNetAddress(settings.NetAddress); err != nil { - return fmt.Errorf("failed to validate net address %q: %w", settings.NetAddress, err) - } - } - minerFee := m.chain.RecommendedFee().Mul64(announcementTxnSize) cs := m.chain.TipState() if cs.Index.Height < cs.Network.HardforkV2.AllowHeight { + // get the current settings + settings := m.Settings() + + if m.validateNetAddress { + if err := validateNetAddress(settings.NetAddress); err != nil { + return fmt.Errorf("failed to validate net address %q: %w", settings.NetAddress, err) + } + } + // create a transaction with an announcement txn := types.Transaction{ ArbitraryData: [][]byte{ @@ -63,9 +62,7 @@ func (m *ConfigManager) Announce() error { // create a v2 transaction with an announcement txn := types.V2Transaction{ Attestations: []types.Attestation{ - chain.V2HostAnnouncement{ - {Protocol: rhp4.ProtocolTCPSiaMux, Address: settings.NetAddress}, // TODO: this isn't correct - }.ToAttestation(cs, m.hostKey), + chain.V2HostAnnouncement(m.rhp4AnnounceAddresses).ToAttestation(cs, m.hostKey), }, MinerFee: minerFee, } @@ -83,11 +80,18 @@ func (m *ConfigManager) Announce() error { return fmt.Errorf("failed to add transaction to pool: %w", err) } m.syncer.BroadcastV2TransactionSet(cs.Index, txnset) - m.log.Debug("broadcast v2 announcement", zap.String("transactionID", txn.ID().String()), zap.String("netaddress", settings.NetAddress), zap.String("cost", minerFee.ExactString())) + addresses := make([]string, 0, len(m.rhp4AnnounceAddresses)) + for _, addr := range m.rhp4AnnounceAddresses { + addresses = append(addresses, fmt.Sprintf("%s/%s", addr.Protocol, addr.Address)) // TODO: implement Stringer? + } + m.log.Debug("broadcast v2 announcement", zap.String("transactionID", txn.ID().String()), zap.Strings("addresses", addresses), zap.String("cost", minerFee.ExactString())) } return nil } +// validateNetAddress validates a net address. +// +// Deprecated: remove after hardfork func validateNetAddress(netaddress string) error { host, port, err := net.SplitHostPort(netaddress) if err != nil { diff --git a/host/settings/announce_test.go b/host/settings/announce_test.go index e6545997..103fa577 100644 --- a/host/settings/announce_test.go +++ b/host/settings/announce_test.go @@ -48,7 +48,11 @@ func TestAutoAnnounce(t *testing.T) { } defer storage.Close() - sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50)) + v2AnnounceAddresses := []chain.NetAddress{ + {Protocol: rhp4.ProtocolTCPSiaMux, Address: "foo.bar:1234"}, + {Protocol: rhp4.ProtocolTCPSiaMux, Address: "foo.bar:1236"}, + } + sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, storage, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithRHP4AnnounceAddresses(v2AnnounceAddresses)) if err != nil { t.Fatal(err) } @@ -82,7 +86,7 @@ func TestAutoAnnounce(t *testing.T) { } } - assertV2Announcement := func(t *testing.T, expectedAddr string, height uint64) { + assertV2Announcement := func(t *testing.T, addresses []chain.NetAddress, height uint64) { t.Helper() index, ok := node.Chain.BestIndex(height) @@ -96,7 +100,7 @@ func TestAutoAnnounce(t *testing.T) { } h := types.NewHasher() - types.EncodeSlice(h.E, chain.V2HostAnnouncement{{Protocol: rhp4.ProtocolTCPSiaMux, Address: expectedAddr}}) + types.EncodeSlice(h.E, addresses) if err := h.E.Flush(); err != nil { t.Fatal(err) } @@ -144,11 +148,11 @@ func TestAutoAnnounce(t *testing.T) { // v2 attestation. n := node.Chain.TipState().Network mineAndSync(t, n.HardforkV2.AllowHeight-node.Chain.Tip().Height+1) - assertV2Announcement(t, "baz.qux:5678", n.HardforkV2.AllowHeight+1) + assertV2Announcement(t, v2AnnounceAddresses, n.HardforkV2.AllowHeight+1) // mine a few more blocks to ensure the host doesn't re-announce mineAndSync(t, 10) - assertV2Announcement(t, "baz.qux:5678", n.HardforkV2.AllowHeight+1) + assertV2Announcement(t, v2AnnounceAddresses, n.HardforkV2.AllowHeight+1) } func TestAutoAnnounceV2(t *testing.T) { @@ -186,7 +190,11 @@ func TestAutoAnnounceV2(t *testing.T) { } defer storage.Close() - sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50)) + v2AnnounceAddresses := []chain.NetAddress{ + {Protocol: rhp4.ProtocolTCPSiaMux, Address: "foo.bar:1234"}, + {Protocol: rhp4.ProtocolTCPSiaMux, Address: "foo.bar:1236"}, + } + sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, storage, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithRHP4AnnounceAddresses(v2AnnounceAddresses)) if err != nil { t.Fatal(err) } @@ -212,7 +220,7 @@ func TestAutoAnnounceV2(t *testing.T) { } } - assertV2Announcement := func(t *testing.T, expectedAddr string, height uint64) { + assertV2Announcement := func(t *testing.T, addresses []chain.NetAddress, height uint64) { t.Helper() index, ok := node.Chain.BestIndex(height) @@ -226,7 +234,7 @@ func TestAutoAnnounceV2(t *testing.T) { } h := types.NewHasher() - types.EncodeSlice(h.E, chain.V2HostAnnouncement{{Protocol: rhp4.ProtocolTCPSiaMux, Address: expectedAddr}}) + types.EncodeSlice(h.E, chain.V2HostAnnouncement(addresses)) if err := h.E.Flush(); err != nil { t.Fatal(err) } @@ -239,24 +247,24 @@ func TestAutoAnnounceV2(t *testing.T) { } } - settings := settings.DefaultSettings - settings.NetAddress = "foo.bar:1234" - sm.UpdateSettings(settings) - // fund the wallet and trigger the first auto-announce mineAndSync(t, network.MaturityDelay+1+1) - assertV2Announcement(t, "foo.bar:1234", network.MaturityDelay+1+1) // first maturity height + funds available + confirmation + assertV2Announcement(t, v2AnnounceAddresses, network.MaturityDelay+1+1) // first maturity height + funds available + confirmation // mine until the next announcement and confirm it lastHeight := node.Chain.Tip().Height mineAndSync(t, 51) - assertV2Announcement(t, "foo.bar:1234", lastHeight+50+1) // first confirm + interval + confirmation + assertV2Announcement(t, v2AnnounceAddresses, lastHeight+50+1) // first confirm + interval + confirmation // change the address - settings.NetAddress = "baz.qux:5678" - sm.UpdateSettings(settings) + v2AnnounceAddresses[1].Address = "baz.qux:5678" + sm, err = settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, storage, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithRHP4AnnounceAddresses(v2AnnounceAddresses)) + if err != nil { + t.Fatal(err) + } + defer sm.Close() // trigger and confirm the new announcement lastHeight = node.Chain.Tip().Height mineAndSync(t, 2) - assertV2Announcement(t, "baz.qux:5678", lastHeight+2) + assertV2Announcement(t, v2AnnounceAddresses, lastHeight+2) } diff --git a/host/settings/options.go b/host/settings/options.go index d6199cde..da82b3fb 100644 --- a/host/settings/options.go +++ b/host/settings/options.go @@ -1,6 +1,7 @@ package settings import ( + "go.sia.tech/coreutils/chain" "go.uber.org/zap" ) @@ -46,3 +47,11 @@ func WithInitialSettings(settings Settings) Option { c.initialSettings = settings } } + +// WithRHP4AnnounceAddresses sets the addresses to announce on the blockchain +// for RHP4. +func WithRHP4AnnounceAddresses(addresses []chain.NetAddress) Option { + return func(c *ConfigManager) { + c.rhp4AnnounceAddresses = addresses + } +} diff --git a/host/settings/pin/pin_test.go b/host/settings/pin/pin_test.go index 554217cc..131ba0e3 100644 --- a/host/settings/pin/pin_test.go +++ b/host/settings/pin/pin_test.go @@ -13,6 +13,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/settings/pin" + "go.sia.tech/hostd/host/storage" "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" ) @@ -119,7 +120,13 @@ func TestPinnedFields(t *testing.T) { currency: "usd", } - sm, err := settings.NewConfigManager(types.GeneratePrivateKey(), node.Store, node.Chain, node.Syncer, nil) + storage, err := storage.NewVolumeManager(node.Store) + if err != nil { + t.Fatal("failed to create storage manager:", err) + } + defer storage.Close() + + sm, err := settings.NewConfigManager(types.GeneratePrivateKey(), node.Store, node.Chain, node.Syncer, nil, storage) if err != nil { t.Fatal(err) } @@ -218,7 +225,13 @@ func TestAutomaticUpdate(t *testing.T) { currency: "usd", } - sm, err := settings.NewConfigManager(types.GeneratePrivateKey(), node.Store, node.Chain, node.Syncer, nil) + storage, err := storage.NewVolumeManager(node.Store) + if err != nil { + t.Fatal("failed to create storage manager:", err) + } + defer storage.Close() + + sm, err := settings.NewConfigManager(types.GeneratePrivateKey(), node.Store, node.Chain, node.Syncer, nil, storage) if err != nil { t.Fatal(err) } diff --git a/host/settings/settings.go b/host/settings/settings.go index 0cfc4cfe..776161c9 100644 --- a/host/settings/settings.go +++ b/host/settings/settings.go @@ -2,17 +2,20 @@ package settings import ( "crypto/ed25519" - "crypto/tls" "errors" "fmt" + "math" "net" "strings" "sync" "time" "go.sia.tech/core/consensus" + proto4 "go.sia.tech/core/rhp/v4" "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" "go.sia.tech/hostd/alerts" + "go.sia.tech/hostd/build" "go.sia.tech/hostd/internal/threadgroup" "go.uber.org/zap" "golang.org/x/time/rate" @@ -38,6 +41,8 @@ type ( // UpdateSettings updates the host's settings. UpdateSettings(s Settings) error + // LastAnnouncement returns the last announcement that was made by the + // host LastAnnouncement() (Announcement, error) // LastV2AnnouncementHash returns the hash of the last v2 announcement. LastV2AnnouncementHash() (types.Hash256, types.ChainIndex, error) @@ -67,6 +72,11 @@ type ( BroadcastV2TransactionSet(types.ChainIndex, []types.V2Transaction) } + // Storage provides information about the host's storage capacity + Storage interface { + Usage() (used, total uint64, _ error) + } + // A Wallet manages Siacoins and funds transactions Wallet interface { Address() types.Address @@ -87,7 +97,7 @@ type ( Settings struct { // Host settings AcceptingContracts bool `json:"acceptingContracts"` - NetAddress string `json:"netAddress"` + NetAddress string `json:"netAddress"` // TODO: remove after hardfork MaxContractDuration uint64 `json:"maxContractDuration"` WindowSize uint64 `json:"windowSize"` @@ -126,18 +136,20 @@ type ( // A ConfigManager manages the host's current configuration ConfigManager struct { - hostKey types.PrivateKey - announceInterval uint64 - validateNetAddress bool - initialSettings Settings + hostKey types.PrivateKey + announceInterval uint64 + validateNetAddress bool // TODO: remove after hardfork + rhp4AnnounceAddresses []chain.NetAddress + initialSettings Settings store Store a Alerts log *zap.Logger - chain ChainManager - syncer Syncer - wallet Wallet + chain ChainManager + syncer Syncer + wallet Wallet + storage Storage mu sync.Mutex // guards the following fields settings Settings // in-memory cache of the host's settings @@ -150,8 +162,6 @@ type ( lastIPv4 net.IP lastIPv6 net.IP - rhp3WSTLS *tls.Config - tg *threadgroup.ThreadGroup } ) @@ -253,34 +263,66 @@ func (m *ConfigManager) Settings() Settings { return m.settings } -// BandwidthLimiters returns the rate limiters for all traffic -func (m *ConfigManager) BandwidthLimiters() (ingress, egress *rate.Limiter) { +// RHPBandwidthLimiters returns the rate limiters for all traffic +func (m *ConfigManager) RHPBandwidthLimiters() (ingress, egress *rate.Limiter) { return m.ingressLimit, m.egressLimit } +// RHP4Settings returns the host's settings in the RHP4 format. The settings +// are not signed. +func (m *ConfigManager) RHP4Settings() proto4.HostSettings { + m.mu.Lock() + settings := m.settings + m.mu.Unlock() + + used, total, err := m.storage.Usage() + if err != nil { + m.log.Error("failed to get storage usage", zap.Error(err)) + } + + hs := proto4.HostSettings{ + Release: "hostd " + build.Version(), + WalletAddress: m.wallet.Address(), + AcceptingContracts: settings.AcceptingContracts, + MaxCollateral: settings.MaxCollateral, + MaxContractDuration: settings.MaxContractDuration, + MaxSectorDuration: 3 * 144, + MaxSectorBatchSize: 25600, // 100 GiB + RemainingStorage: total - used, + TotalStorage: total, + Prices: proto4.HostPrices{ + ContractPrice: settings.ContractPrice, + StoragePrice: settings.StoragePrice, + Collateral: settings.StoragePrice.Mul64(uint64(settings.CollateralMultiplier * 1000)).Div64(1000), + IngressPrice: settings.IngressPrice, + EgressPrice: settings.EgressPrice, + FreeSectorPrice: types.Siacoins(1).Div64((1 << 40) / proto4.SectorSize), // 1 SC / TB + }, + } + return hs +} + // NewConfigManager initializes a new config manager -func NewConfigManager(hostKey types.PrivateKey, store Store, cm ChainManager, s Syncer, wm Wallet, opts ...Option) (*ConfigManager, error) { +func NewConfigManager(hostKey types.PrivateKey, store Store, cm ChainManager, s Syncer, wm Wallet, sm Storage, opts ...Option) (*ConfigManager, error) { m := &ConfigManager{ announceInterval: 144 * 90, // 90 days validateNetAddress: true, hostKey: hostKey, initialSettings: DefaultSettings, - store: store, - chain: cm, - syncer: s, - wallet: wm, + store: store, + chain: cm, + syncer: s, + wallet: wm, + storage: sm, log: zap.NewNop(), a: alerts.NewNop(), tg: threadgroup.New(), // initialize the rate limiters - ingressLimit: rate.NewLimiter(rate.Inf, defaultBurstSize), - egressLimit: rate.NewLimiter(rate.Inf, defaultBurstSize), - - // rhp3 WebSocket TLS - rhp3WSTLS: &tls.Config{}, + ingressLimit: rate.NewLimiter(rate.Inf, math.MaxInt), + egressLimit: rate.NewLimiter(rate.Inf, math.MaxInt), } for _, opt := range opts { diff --git a/host/settings/settings_test.go b/host/settings/settings_test.go index 5faa789c..fb54b52a 100644 --- a/host/settings/settings_test.go +++ b/host/settings/settings_test.go @@ -41,7 +41,13 @@ func TestSettings(t *testing.T) { } defer contracts.Close() - sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithValidateNetAddress(false)) + storage, err := storage.NewVolumeManager(node.Store) + if err != nil { + t.Fatal("failed to create storage manager:", err) + } + defer storage.Close() + + sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, storage, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithValidateNetAddress(false)) if err != nil { t.Fatal(err) } diff --git a/host/settings/update.go b/host/settings/update.go index 923ce7d7..33e304fe 100644 --- a/host/settings/update.go +++ b/host/settings/update.go @@ -5,15 +5,17 @@ import ( "go.sia.tech/core/types" "go.sia.tech/coreutils/chain" - rhp4 "go.sia.tech/coreutils/rhp/v4" "go.uber.org/zap" ) // An UpdateStateTx is a transaction that can update the host's announcement // state. type UpdateStateTx interface { + // LastAnnouncement returns the last v1 announcement. LastAnnouncement() (Announcement, error) + // RevertLastAnnouncement reverts the last v1 announcement. RevertLastAnnouncement() error + // SetLastAnnouncement sets the last v1 announcement. SetLastAnnouncement(Announcement) error // LastV2AnnouncementHash returns the hash of the last v2 announcement. @@ -136,11 +138,11 @@ func (m *ConfigManager) ProcessActions(index types.ChainIndex) error { nextHeight := announcement.Index.Height + m.announceInterval netaddress := m.Settings().NetAddress if err := validateNetAddress(netaddress); err != nil && m.validateNetAddress { - m.log.Debug("failed to validate net address", zap.Error(err)) + m.log.Warn("invalid net address", zap.String("address", netaddress), zap.Error(err)) return nil } shouldAnnounce = index.Height >= nextHeight || announcement.Address != netaddress - } else { + } else if len(m.rhp4AnnounceAddresses) > 0 { announceHash, announceIndex, err := m.store.LastV2AnnouncementHash() if err != nil { return fmt.Errorf("failed to get last v2 announcement: %w", err) @@ -148,7 +150,7 @@ func (m *ConfigManager) ProcessActions(index types.ChainIndex) error { nextHeight := announceIndex.Height + m.announceInterval h := types.NewHasher() - types.EncodeSlice(h.E, chain.V2HostAnnouncement{{Protocol: rhp4.ProtocolTCPSiaMux, Address: m.Settings().NetAddress}}) + types.EncodeSlice(h.E, m.rhp4AnnounceAddresses) if err := h.E.Flush(); err != nil { return fmt.Errorf("failed to hash v2 announcement: %w", err) } diff --git a/host/storage/persist.go b/host/storage/persist.go index b31342a5..07153a83 100644 --- a/host/storage/persist.go +++ b/host/storage/persist.go @@ -12,7 +12,6 @@ type ( // needs to be migrated If the function returns an error, the sector should // be skipped and migration should continue. MigrateFunc func(location SectorLocation) error - // A VolumeStore stores and retrieves information about storage volumes. VolumeStore interface { // StorageUsage returns the number of used and total bytes in all volumes @@ -47,6 +46,15 @@ type ( // location and synced to disk during migrateFn. If migrateFn returns an // error, migration will continue, but that sector is not migrated. MigrateSectors(ctx context.Context, volumeID int64, min uint64, migrateFn MigrateFunc) (migrated, failed int, err error) + + // StoreTempSector calls fn with an empty location in a writable volume. + // + // The sector must be written to disk within fn. If fn returns an error, + // the metadata is rolled back. If no space is available, ErrNotEnoughStorage + // is returned. If the sector is already stored, fn is skipped and nil + // is returned. + StoreTempSector(root types.Hash256, expiration uint64, fn func(loc SectorLocation) error) error + // StoreSector calls fn with an empty location in a writable volume. If // the sector root already exists, fn is called with the existing // location and exists is true. Unless exists is true, The sector must @@ -56,10 +64,14 @@ type ( // // The sector should be referenced by either a contract or temp store // before release is called to prevent Prune() from removing it. + // + // Deprecated: use StoreTempSector instead StoreSector(root types.Hash256, fn func(loc SectorLocation, exists bool) error) (release func() error, err error) // RemoveSector removes the metadata of a sector and returns its // location in the volume. RemoveSector(root types.Hash256) error + // HasSector returns true if the sector is stored in the volume store. + HasSector(root types.Hash256) (bool, error) // SectorLocation returns the location of a sector or an error if the // sector is not found. The location is locked until release is // called. diff --git a/host/storage/storage.go b/host/storage/storage.go index 5f939fbd..2d777294 100644 --- a/host/storage/storage.go +++ b/host/storage/storage.go @@ -11,6 +11,7 @@ import ( lru "github.com/hashicorp/golang-lru/v2" rhp2 "go.sia.tech/core/rhp/v2" + rhp4 "go.sia.tech/core/rhp/v4" "go.sia.tech/core/types" "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/internal/threadgroup" @@ -158,7 +159,7 @@ func (vm *VolumeManager) loadVolumes() error { // immediately synced after the sector is written. func (vm *VolumeManager) migrateSector(loc SectorLocation) error { // read the sector from the old location - sector, err := vm.Read(loc.Root) + sector, err := vm.ReadSector(loc.Root) if err != nil { return fmt.Errorf("failed to read sector: %w", err) } @@ -359,6 +360,11 @@ func (vm *VolumeManager) SectorReferences(root types.Hash256) (SectorReference, return vm.vs.SectorReferences(root) } +// HasSector returns true if the host is storing the sector. +func (vm *VolumeManager) HasSector(root types.Hash256) (bool, error) { + return vm.vs.HasSector(root) +} + // Usage returns the total and used storage space, in sectors, from the storage manager. func (vm *VolumeManager) Usage() (usedSectors uint64, totalSectors uint64, err error) { done, err := vm.tg.Add() @@ -782,8 +788,8 @@ func (vm *VolumeManager) CacheStats() (hits, misses uint64) { return atomic.LoadUint64(&vm.cacheHits), atomic.LoadUint64(&vm.cacheMisses) } -// Read reads the sector with the given root -func (vm *VolumeManager) Read(root types.Hash256) (*[rhp2.SectorSize]byte, error) { +// ReadSector reads the sector with the given root +func (vm *VolumeManager) ReadSector(root types.Hash256) (*[rhp2.SectorSize]byte, error) { done, err := vm.tg.Add() if err != nil { return nil, err @@ -831,12 +837,72 @@ func (vm *VolumeManager) Read(root types.Hash256) (*[rhp2.SectorSize]byte, error } // Add sector to cache - vm.cache.Add(root, sector) + go vm.cache.Add(root, sector) vm.recorder.AddCacheMiss() atomic.AddUint64(&vm.cacheMisses, 1) return sector, nil } +// StoreSector stores a sector in the volume manager. The sector is written to +// the first available volume. If no volumes are available, an error is +// returned. +// +// The sector will be stored until the expiration height is reached. If the +// sector is not referenced by a contract before the expiration height, it will +// be pruned. +func (vm *VolumeManager) StoreSector(root types.Hash256, sector *[rhp4.SectorSize]byte, expiration uint64) error { + done, err := vm.tg.Add() + if err != nil { + return err + } + defer done() + + err = vm.vs.StoreTempSector(root, expiration, func(loc SectorLocation) error { + start := time.Now() + + vm.mu.Lock() + vol, ok := vm.volumes[loc.Volume] + vm.mu.Unlock() + if !ok { + return fmt.Errorf("volume %v not found", loc.Volume) + } + + // write the sector to the volume + if err := vol.WriteSector(sector, loc.Index); err != nil { + stats := vol.Stats() + vm.alerts.Register(alerts.Alert{ + ID: vol.alertID("write"), + Severity: alerts.SeverityError, + Message: "Failed to write sector", + Data: map[string]interface{}{ + "volume": vol.Location(), + "failedReads": stats.FailedReads, + "failedWrites": stats.FailedWrites, + "sector": root, + "error": err.Error(), + }, + Timestamp: time.Now(), + }) + return err + } + vm.log.Debug("wrote sector", zap.String("root", root.String()), zap.Int64("volume", loc.Volume), zap.Uint64("index", loc.Index), zap.Duration("elapsed", time.Since(start))) + + // mark the volume as changed + vm.mu.Lock() + vm.changedVolumes[loc.Volume] = true + vm.mu.Unlock() + // Add newly written sector to cache + vm.cache.Add(root, sector) + return nil + }) + if err != nil { + return err + } + // Add newly written sector to cache + vm.cache.Add(root, sector) + return nil +} + // Sync syncs the data files of changed volumes. func (vm *VolumeManager) Sync() error { done, err := vm.tg.Add() @@ -871,12 +937,15 @@ func (vm *VolumeManager) Sync() error { // Write writes a sector to a volume. release should only be called after the // contract roots have been committed to prevent the sector from being deleted. +// +// Deprecated: use StoreSector func (vm *VolumeManager) Write(root types.Hash256, data *[rhp2.SectorSize]byte) (func() error, error) { done, err := vm.tg.Add() if err != nil { return nil, err } defer done() + release, err := vm.vs.StoreSector(root, func(loc SectorLocation, exists bool) error { if exists { return nil @@ -889,7 +958,6 @@ func (vm *VolumeManager) Write(root types.Hash256, data *[rhp2.SectorSize]byte) if !ok { return fmt.Errorf("volume %v not found", loc.Volume) } - // write the sector to the volume if err := vol.WriteSector(data, loc.Index); err != nil { stats := vol.Stats() @@ -919,11 +987,18 @@ func (vm *VolumeManager) Write(root types.Hash256, data *[rhp2.SectorSize]byte) vm.mu.Unlock() return nil }) + if err != nil { + return nil, err + } + // Add newly written sector to cache + vm.cache.Add(root, data) return release, err } // AddTemporarySectors adds sectors to the temporary store. The sectors are not // referenced by a contract and will be removed at the expiration height. +// +// Deprecated: use StoreSector func (vm *VolumeManager) AddTemporarySectors(sectors []TempSector) error { if len(sectors) == 0 { return nil diff --git a/host/storage/storage_test.go b/host/storage/storage_test.go index d4f69b71..1523bc9a 100644 --- a/host/storage/storage_test.go +++ b/host/storage/storage_test.go @@ -1,6 +1,7 @@ package storage_test import ( + "bytes" "context" "errors" "fmt" @@ -93,10 +94,10 @@ func TestVolumeLoad(t *testing.T) { } // check that the sector is still there - sector2, err := vm.Read(root) + sector2, err := vm.ReadSector(root) if err != nil { t.Fatal(err) - } else if *sector2 != sector { + } else if !bytes.Equal(sector[:], sector2[:]) { t.Fatal("sector was corrupted") } @@ -205,7 +206,7 @@ func TestRemoveVolume(t *testing.T) { checkRoots := func(roots []types.Hash256) error { for _, root := range roots { - sector, err := vm.Read(root) + sector, err := vm.ReadSector(root) if err != nil { return fmt.Errorf("failed to read sector: %w", err) } else if rhp2.SectorRoot(sector) != root { @@ -670,7 +671,7 @@ func TestVolumeConcurrency(t *testing.T) { // read the sectors back for _, root := range roots { - sector, err := vm.Read(root) + sector, err := vm.ReadSector(root) if err != nil { t.Fatal(err) } else if rhp2.SectorRoot(sector) != root { @@ -685,7 +686,7 @@ func TestVolumeConcurrency(t *testing.T) { // read the sectors back for _, root := range roots { - sector, err := vm.Read(root) + sector, err := vm.ReadSector(root) if err != nil { t.Fatal(err) } else if rhp2.SectorRoot(sector) != root { @@ -1044,7 +1045,7 @@ func TestVolumeManagerReadWrite(t *testing.T) { // read the sectors back frand.Shuffle(len(roots), func(i, j int) { roots[i], roots[j] = roots[j], roots[i] }) for _, root := range roots { - sector, err := vm.Read(root) + sector, err := vm.ReadSector(root) if err != nil { t.Fatal(err) } @@ -1133,9 +1134,11 @@ func TestSectorCache(t *testing.T) { } } + time.Sleep(10 * time.Second) // sectors are added to the cache in a goroutine + // read the last 5 sectors all sectors should be cached for i, root := range roots[5:] { - _, err := vm.Read(root) + _, err := vm.ReadSector(root) if err != nil { t.Fatal(err) } @@ -1150,7 +1153,7 @@ func TestSectorCache(t *testing.T) { // read the first 5 sectors all sectors should be missed for i, root := range roots[:5] { - _, err := vm.Read(root) + _, err := vm.ReadSector(root) if err != nil { t.Fatal(err) } @@ -1165,7 +1168,7 @@ func TestSectorCache(t *testing.T) { // read the first 5 sectors again all sectors should be cached for i, root := range roots[:5] { - _, err := vm.Read(root) + _, err := vm.ReadSector(root) if err != nil { t.Fatal(err) } @@ -1312,7 +1315,7 @@ func BenchmarkVolumeManagerRead(b *testing.B) { b.SetBytes(rhp2.SectorSize) // read the sectors back for _, root := range written { - if _, err := vm.Read(root); err != nil { + if _, err := vm.ReadSector(root); err != nil { b.Fatal(err) } } diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 83e7f747..f894d0de 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -188,7 +188,7 @@ func NewHostNode(t *testing.T, pk types.PrivateKey, network *consensus.Network, initialSettings.AcceptingContracts = true initialSettings.NetAddress = "127.0.0.1:9981" initialSettings.WindowSize = 10 - sm, err := settings.NewConfigManager(pk, cn.Store, cn.Chain, cn.Syncer, wm, settings.WithAnnounceInterval(10), settings.WithValidateNetAddress(false), settings.WithInitialSettings(initialSettings)) + sm, err := settings.NewConfigManager(pk, cn.Store, cn.Chain, cn.Syncer, wm, vm, settings.WithAnnounceInterval(10), settings.WithValidateNetAddress(false), settings.WithInitialSettings(initialSettings)) if err != nil { t.Fatal(err) } diff --git a/persist/sqlite/accounts.go b/persist/sqlite/accounts.go index 0a3e354c..652d9c1a 100644 --- a/persist/sqlite/accounts.go +++ b/persist/sqlite/accounts.go @@ -7,6 +7,7 @@ import ( "time" rhp3 "go.sia.tech/core/rhp/v3" + proto4 "go.sia.tech/core/rhp/v4" "go.sia.tech/core/types" "go.sia.tech/hostd/host/accounts" "go.sia.tech/hostd/host/contracts" @@ -40,7 +41,127 @@ func incrementContractAccountFunding(tx *txn, accountID, contractID int64, amoun return nil } +// RHP4AccountBalance returns the balance of the account with the given ID. +func (s *Store) RHP4AccountBalance(account proto4.Account) (balance types.Currency, err error) { + err = s.transaction(func(tx *txn) error { + return tx.QueryRow(`SELECT balance FROM accounts WHERE account_id=$1`, encode(account)).Scan(decode(&balance)) + }) + return +} + +// RHP4DebitAccount debits the account with the given ID. +func (s *Store) RHP4DebitAccount(account proto4.Account, usage contracts.V2Usage) error { + return s.transaction(func(tx *txn) error { + var dbID int64 + var balance types.Currency + err := tx.QueryRow(`SELECT id, balance FROM accounts WHERE account_id=$1`, encode(account)).Scan(&dbID, decode(&balance)) + if err != nil { + return fmt.Errorf("failed to query balance: %w", err) + } + + total := usage.RenterCost() + balance, underflow := balance.SubWithUnderflow(total) + if underflow { + return fmt.Errorf("insufficient balance") + } + + _, err = tx.Exec(`UPDATE accounts SET balance=$1 WHERE id=$2`, encode(balance), dbID) + if err != nil { + return fmt.Errorf("failed to update balance: %w", err) + } + + if err := incrementCurrencyStat(tx, metricAccountBalance, total, true, time.Now()); err != nil { + return fmt.Errorf("failed to increment balance metric: %w", err) + } + // TODO: allocate usage to funding contracts + return nil + }) +} + +// RHP4CreditAccounts credits the accounts with the given deposits and revises +// the contract. +func (s *Store) RHP4CreditAccounts(deposits []proto4.AccountDeposit, contractID types.FileContractID, revision types.V2FileContract, usage contracts.V2Usage) (balances []types.Currency, err error) { + err = s.transaction(func(tx *txn) error { + getBalanceStmt, err := tx.Prepare(`SELECT balance FROM accounts WHERE account_id=$1`) + if err != nil { + return fmt.Errorf("failed to prepare get balance statement: %w", err) + } + defer getBalanceStmt.Close() + + updateBalanceStmt, err := tx.Prepare(`INSERT INTO accounts (account_id, balance, expiration_timestamp) VALUES ($1, $2, $3) ON CONFLICT (account_id) DO UPDATE SET balance=EXCLUDED.balance, expiration_timestamp=EXCLUDED.expiration_timestamp RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare update balance statement: %w", err) + } + defer updateBalanceStmt.Close() + + getFundingAmountStmt, err := tx.Prepare(`SELECT amount FROM contract_v2_account_funding WHERE contract_id=$1 AND account_id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare get funding amount statement: %w", err) + } + defer getFundingAmountStmt.Close() + + updateFundingAmountStmt, err := tx.Prepare(`INSERT INTO contract_v2_account_funding (contract_id, account_id, amount) VALUES ($1, $2, $3) ON CONFLICT (contract_id, account_id) DO UPDATE SET amount=EXCLUDED.amount`) + if err != nil { + return fmt.Errorf("failed to prepare update funding amount statement: %w", err) + } + defer updateFundingAmountStmt.Close() + + var contractDBID int64 + err = tx.QueryRow(`SELECT id FROM contracts_v2 WHERE contract_id=$1`, encode(contractID)).Scan(&contractDBID) + if err != nil { + return fmt.Errorf("failed to get contract ID: %w", err) + } + + var totalDeposits types.Currency + var createdAccounts int + for _, deposit := range deposits { + var balance types.Currency + err := getBalanceStmt.QueryRow(encode(deposit.Account)).Scan(decode(&balance)) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to get balance: %w", err) + } else if errors.Is(err, sql.ErrNoRows) { + createdAccounts++ + } + + balance = balance.Add(deposit.Amount) + + var accountDBID int64 + err = updateBalanceStmt.QueryRow(encode(deposit.Account), encode(balance), encode(time.Now().Add(90*24*time.Hour))).Scan(&accountDBID) + if err != nil { + return fmt.Errorf("failed to update balance: %w", err) + } + balances = append(balances, balance) + + var fundAmount types.Currency + if err := getFundingAmountStmt.QueryRow(contractDBID, accountDBID).Scan(decode(&fundAmount)); err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to get funding amount: %w", err) + } + fundAmount = fundAmount.Add(deposit.Amount) + if _, err := updateFundingAmountStmt.Exec(contractDBID, accountDBID, encode(fundAmount)); err != nil { + return fmt.Errorf("failed to update funding amount: %w", err) + } + totalDeposits = totalDeposits.Add(deposit.Amount) + } + + _, err = reviseV2Contract(tx, contractID, revision, usage) + if err != nil { + return fmt.Errorf("failed to revise contract: %w", err) + } + + if err := incrementCurrencyStat(tx, metricAccountBalance, totalDeposits, false, time.Now()); err != nil { + return fmt.Errorf("failed to increment balance metric: %w", err) + } else if err := incrementNumericStat(tx, metricActiveAccounts, createdAccounts, time.Now()); err != nil { + return fmt.Errorf("failed to increment active accounts metric: %w", err) + } + + return nil + }) + return +} + // CreditAccountWithContract adds the specified amount to the account with the given ID. +// +// Deprecated: use CreditAccountsWithV2Contract instead. func (s *Store) CreditAccountWithContract(fund accounts.FundAccountWithContract) error { return s.transaction(func(tx *txn) error { // get current balance diff --git a/persist/sqlite/contracts.go b/persist/sqlite/contracts.go index 90d3b785..4f9944c1 100644 --- a/persist/sqlite/contracts.go +++ b/persist/sqlite/contracts.go @@ -9,6 +9,7 @@ import ( "time" "go.sia.tech/core/types" + rhp4 "go.sia.tech/coreutils/rhp/v4" "go.sia.tech/hostd/host/contracts" "go.uber.org/zap" ) @@ -126,14 +127,15 @@ func (s *Store) Contract(id types.FileContractID) (contract contracts.Contract, } // V2ContractElement returns the latest v2 state element with the given ID. -func (s *Store) V2ContractElement(contractID types.FileContractID) (ele types.V2FileContractElement, err error) { +func (s *Store) V2ContractElement(contractID types.FileContractID) (basis types.ChainIndex, ele types.V2FileContractElement, err error) { err = s.transaction(func(tx *txn) error { - const query = `SELECT cs.raw_contract, cs.leaf_index, cs.merkle_proof + const query = `SELECT cs.raw_contract, cs.leaf_index, cs.merkle_proof, g.last_scanned_index AS basis FROM contracts_v2 c INNER JOIN contract_v2_state_elements cs ON (c.id = cs.contract_id) +CROSS JOIN global_settings g WHERE c.contract_id=?` - err := tx.QueryRow(query, encode(contractID)).Scan(decode(&ele.V2FileContract), decode(&ele.StateElement.LeafIndex), decode(&ele.StateElement.MerkleProof)) + err := tx.QueryRow(query, encode(contractID)).Scan(decode(&ele.V2FileContract), decode(&ele.StateElement.LeafIndex), decode(&ele.StateElement.MerkleProof), decode(&basis)) if errors.Is(err, sql.ErrNoRows) { return contracts.ErrNotFound } @@ -161,7 +163,7 @@ WHERE c.contract_id=$1;` } // AddV2Contract adds a new contract to the database. -func (s *Store) AddV2Contract(contract contracts.V2Contract, formationSet contracts.V2FormationTransactionSet) error { +func (s *Store) AddV2Contract(contract contracts.V2Contract, formationSet rhp4.TransactionSet) error { return s.transaction(func(tx *txn) error { _, err := insertV2Contract(tx, contract, formationSet) return err @@ -172,7 +174,7 @@ func (s *Store) AddV2Contract(contract contracts.V2Contract, formationSet contra // contract's renewed_from field. The old contract's sector roots are // copied to the new contract. The status of the old contract should continue // to be active until the renewal is confirmed -func (s *Store) RenewV2Contract(renewal contracts.V2Contract, renewalSet contracts.V2FormationTransactionSet, renewedID types.FileContractID, clearing types.V2FileContract) error { +func (s *Store) RenewV2Contract(renewal contracts.V2Contract, renewalSet rhp4.TransactionSet, renewedID types.FileContractID, clearing types.V2FileContract, roots []types.Hash256) error { return s.transaction(func(tx *txn) error { // add the new contract renewedDBID, err := insertV2Contract(tx, renewal, renewalSet) @@ -238,9 +240,9 @@ func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contrac }) } -func incrementV2ContractUsage(tx *txn, dbID int64, usage contracts.Usage) error { +func incrementV2ContractUsage(tx *txn, dbID int64, usage contracts.V2Usage) error { const query = `SELECT rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral FROM contracts_v2 WHERE id=$1;` - var existing contracts.Usage + var existing contracts.V2Usage err := tx.QueryRow(query, dbID).Scan( decode(&existing.RPCRevenue), decode(&existing.StorageRevenue), @@ -295,111 +297,136 @@ func cleanupDanglingRoots(tx *txn, contractID int64, length int64) (deleted []in return deleted, nil } -// ReviseV2Contract atomically updates a contract's revision and sectors -func (s *Store) ReviseV2Contract(id types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage contracts.Usage) error { - return s.transaction(func(tx *txn) error { +func updateV2ContractUsage(tx *txn, contractDBID int64, usage contracts.V2Usage) error { + if err := incrementV2ContractUsage(tx, contractDBID, usage); err != nil { + return fmt.Errorf("failed to update contract usage: %w", err) + } + + var status contracts.V2ContractStatus + err := tx.QueryRow(`SELECT contract_status FROM contracts_v2 WHERE id=$1`, contractDBID).Scan(&status) + if err != nil { + return fmt.Errorf("failed to get contract status: %w", err) + } + + // only increment metrics if the contract is active. + // If the contract is pending or some variant of successful, the metrics + // will already be handled. + if status == contracts.V2ContractStatusActive { incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) if err != nil { return fmt.Errorf("failed to prepare increment currency stat statement: %w", err) } defer done() - const updateQuery = `UPDATE contracts_v2 SET raw_revision=?, revision_number=? WHERE contract_id=? RETURNING id, contract_status` - - var contractDBID int64 - var status contracts.V2ContractStatus - err = tx.QueryRow(updateQuery, encode(revision), encode(revision.RevisionNumber), encode(id)).Scan(&contractDBID, &status) - if err != nil { - return fmt.Errorf("failed to update contract: %w", err) - } else if err := incrementV2ContractUsage(tx, contractDBID, usage); err != nil { - return fmt.Errorf("failed to update contract usage: %w", err) + if err := updateV2PotentialRevenueMetrics(usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue: %w", err) + } else if err := updateCollateralMetrics(types.ZeroCurrency, usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) } + } + return nil +} - // only increment metrics if the contract is active. - // If the contract is pending or some variant of successful, the metrics - // will already be handled. - if status == contracts.V2ContractStatusActive { - if err := updatePotentialRevenueMetrics(usage, false, incrementCurrencyStat); err != nil { - return fmt.Errorf("failed to update potential revenue: %w", err) - } else if err := updateCollateralMetrics(types.ZeroCurrency, usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { - return fmt.Errorf("failed to update collateral metrics: %w", err) - } +func reviseV2Contract(tx *txn, id types.FileContractID, revision types.V2FileContract, usage contracts.V2Usage) (int64, error) { + const updateQuery = `UPDATE contracts_v2 SET raw_revision=?, revision_number=? WHERE contract_id=? RETURNING id` + + var contractDBID int64 + err := tx.QueryRow(updateQuery, encode(revision), encode(revision.RevisionNumber), encode(id)).Scan(&contractDBID) + if err != nil { + return 0, fmt.Errorf("failed to update contract: %w", err) + } else if err := updateV2ContractUsage(tx, contractDBID, usage); err != nil { + return 0, fmt.Errorf("failed to update contract usage: %w", err) + } + return contractDBID, nil +} + +func updateV2ContractSectors(tx *txn, contractDBID int64, roots []types.Hash256, log *zap.Logger) error { + selectOldSectorStmt, err := tx.Prepare(`SELECT sector_id FROM contract_v2_sector_roots WHERE contract_id=? AND root_index=?`) + if err != nil { + return fmt.Errorf("failed to prepare select old sector statement: %w", err) + } + defer selectOldSectorStmt.Close() + + selectRootIDStmt, err := tx.Prepare(`SELECT id FROM stored_sectors WHERE sector_root=?`) + if err != nil { + return fmt.Errorf("failed to prepare select root ID statement: %w", err) + } + defer selectRootIDStmt.Close() + + updateRootStmt, err := tx.Prepare(`INSERT INTO contract_v2_sector_roots (contract_id, sector_id, root_index) VALUES (?, ?, ?) ON CONFLICT (contract_id, root_index) DO UPDATE SET sector_id=excluded.sector_id`) + if err != nil { + return fmt.Errorf("failed to prepare update root statement: %w", err) + } + defer updateRootStmt.Close() + + var appended int + var deleted []int64 + seen := make(map[int64]bool) + for i, root := range roots { + // TODO: benchmark this against an exceptionally large contract. + // This is less efficient than the v1 implementation, but it leaves + // less room for update edge-cases now that all sectors are loaded + // into memory. + var newSectorID int64 + if err := selectRootIDStmt.QueryRow(encode(root)).Scan(&newSectorID); err != nil { + return fmt.Errorf("failed to get sector ID: %w", err) } - selectOldSectorStmt, err := tx.Prepare(`SELECT sector_id FROM contract_v2_sector_roots WHERE contract_id=? AND root_index=?`) - if err != nil { - return fmt.Errorf("failed to prepare select old sector statement: %w", err) + var oldSectorID int64 + err := selectOldSectorStmt.QueryRow(contractDBID, i).Scan(&oldSectorID) + if errors.Is(err, sql.ErrNoRows) { + // new sector + appended++ + } else if err != nil { + // db error + return fmt.Errorf("failed to get sector ID: %w", err) + } else if newSectorID == oldSectorID { + // no change + continue + } else if !seen[oldSectorID] { + // updated root + deleted = append(deleted, oldSectorID) // mark for pruning + seen[oldSectorID] = true } - defer selectOldSectorStmt.Close() - selectRootIDStmt, err := tx.Prepare(`SELECT id FROM stored_sectors WHERE sector_root=?`) - if err != nil { - return fmt.Errorf("failed to prepare select root ID statement: %w", err) + if _, err := updateRootStmt.Exec(contractDBID, newSectorID, i); err != nil { + return fmt.Errorf("failed to update sector root: %w", err) } - defer selectRootIDStmt.Close() + } - updateRootStmt, err := tx.Prepare(`INSERT INTO contract_v2_sector_roots (contract_id, sector_id, root_index) VALUES (?, ?, ?) ON CONFLICT (contract_id, root_index) DO UPDATE SET sector_id=excluded.sector_id`) - if err != nil { - return fmt.Errorf("failed to prepare update root statement: %w", err) + cleaned, err := cleanupDanglingRoots(tx, contractDBID, int64(len(roots))) + if err != nil { + return fmt.Errorf("failed to cleanup dangling roots: %w", err) + } + for _, sectorID := range cleaned { + if seen[sectorID] { + continue } - defer updateRootStmt.Close() - - var appended int - var deleted []int64 - seen := make(map[int64]bool) - for i, root := range roots { - // TODO: benchmark this against an exceptionally large contract. - // This is less efficient than the v1 implementation, but it leaves - // less room for update edge-cases now that all sectors are loaded - // into memory. - var newSectorID int64 - if err := selectRootIDStmt.QueryRow(encode(root)).Scan(&newSectorID); err != nil { - return fmt.Errorf("failed to get sector ID: %w", err) - } + deleted = append(deleted, sectorID) + } - var oldSectorID int64 - err := selectOldSectorStmt.QueryRow(contractDBID, i).Scan(&oldSectorID) - if errors.Is(err, sql.ErrNoRows) { - // new sector - appended++ - } else if err != nil { - // db error - return fmt.Errorf("failed to get sector ID: %w", err) - } else if newSectorID == oldSectorID { - // no change - continue - } else if !seen[oldSectorID] { - // updated root - deleted = append(deleted, oldSectorID) // mark for pruning - seen[oldSectorID] = true - } + delta := appended - len(deleted) + if err := incrementNumericStat(tx, metricContractSectors, delta, time.Now()); err != nil { + return fmt.Errorf("failed to update contract sectors: %w", err) + } - if _, err := updateRootStmt.Exec(contractDBID, newSectorID, i); err != nil { - return fmt.Errorf("failed to update sector root: %w", err) - } - } + if pruned, err := pruneSectors(tx, deleted); err != nil { + return fmt.Errorf("failed to prune sectors: %w", err) + } else if len(pruned) > 0 { + log.Debug("pruned sectors", zap.Int("count", len(pruned)), zap.Stringers("sectors", pruned)) + } + return nil +} - cleaned, err := cleanupDanglingRoots(tx, contractDBID, int64(len(roots))) +// ReviseV2Contract atomically updates a contract's revision and sectors +func (s *Store) ReviseV2Contract(id types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage contracts.V2Usage) error { + return s.transaction(func(tx *txn) error { + contractDBID, err := reviseV2Contract(tx, id, revision, usage) if err != nil { - return fmt.Errorf("failed to cleanup dangling roots: %w", err) - } - for _, sectorID := range cleaned { - if seen[sectorID] { - continue - } - deleted = append(deleted, sectorID) - } - - delta := appended - len(deleted) - if err := incrementNumericStat(tx, metricContractSectors, delta, time.Now()); err != nil { + return fmt.Errorf("failed to revise contract: %w", err) + } else if err := updateV2ContractSectors(tx, contractDBID, roots, s.log.Named("ReviseV2Contract").With(zap.Stringer("contract", id))); err != nil { return fmt.Errorf("failed to update contract sectors: %w", err) } - - if pruned, err := pruneSectors(tx, deleted); err != nil { - return fmt.Errorf("failed to prune sectors: %w", err) - } else if len(pruned) > 0 { - s.log.Debug("pruned sectors", zap.Int("count", len(pruned)), zap.Stringers("sectors", pruned)) - } return nil }) } @@ -943,7 +970,7 @@ func proofContracts(tx *txn, index types.ChainIndex) (revisions []contracts.Sign return } -func rebroadcastV2Contracts(tx *txn) (rebroadcast []contracts.V2FormationTransactionSet, err error) { +func rebroadcastV2Contracts(tx *txn) (rebroadcast []rhp4.TransactionSet, err error) { rows, err := tx.Query(`SELECT formation_txn_set, formation_txn_set_basis FROM contracts_v2 WHERE confirmation_index IS NULL AND contract_status <> ?`, contracts.ContractStatusRejected) if err != nil { return nil, err @@ -951,13 +978,13 @@ func rebroadcastV2Contracts(tx *txn) (rebroadcast []contracts.V2FormationTransac defer rows.Close() for rows.Next() { - var formationSet contracts.V2FormationTransactionSet + var formationSet rhp4.TransactionSet var buf []byte if err := rows.Scan(&buf, decode(&formationSet.Basis)); err != nil { return nil, fmt.Errorf("failed to scan contract id: %w", err) } dec := types.NewBufDecoder(buf) - types.DecodeSlice(dec, &formationSet.TransactionSet) + types.DecodeSlice(dec, &formationSet.Transactions) if err := dec.Err(); err != nil { return nil, fmt.Errorf("failed to decode formation txn set: %w", err) } @@ -1129,7 +1156,7 @@ raw_revision, host_sig, renter_sig, confirmed_revision_number, contract_status, return } -func insertV2Contract(tx *txn, contract contracts.V2Contract, formationSet contracts.V2FormationTransactionSet) (dbID int64, err error) { +func insertV2Contract(tx *txn, contract contracts.V2Contract, formationSet rhp4.TransactionSet) (dbID int64, err error) { const query = `INSERT INTO contracts_v2 (contract_id, renter_id, locked_collateral, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral, revision_number, negotiation_height, window_start, window_end, formation_txn_set, formation_txn_set_basis, raw_revision, contract_status) VALUES @@ -1153,7 +1180,7 @@ formation_txn_set_basis, raw_revision, contract_status) VALUES contract.NegotiationHeight, // stored as int64 for queries, should never overflow contract.V2FileContract.ProofHeight, // stored as int64 for queries, should never overflow contract.ExpirationHeight, // stored as int64 for queries, should never overflow - encodeSlice(formationSet.TransactionSet), + encodeSlice(formationSet.Transactions), encode(formationSet.Basis), encode(contract.V2FileContract), contracts.V2ContractStatusPending, diff --git a/persist/sqlite/contracts_test.go b/persist/sqlite/contracts_test.go index 24a677c2..b6670ac3 100644 --- a/persist/sqlite/contracts_test.go +++ b/persist/sqlite/contracts_test.go @@ -291,7 +291,7 @@ func TestReviseContract(t *testing.T) { case contracts.SectorActionAppend: // add a random sector root root := frand.Entropy256() - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + release, err := db.StoreSector(root, func(_ storage.SectorLocation, _ bool) error { return nil }) if err != nil { t.Fatal(err) } @@ -301,7 +301,7 @@ func TestReviseContract(t *testing.T) { case contracts.SectorActionUpdate: // replace with a random sector root root := frand.Entropy256() - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + release, err := db.StoreSector(root, func(_ storage.SectorLocation, _ bool) error { return nil }) if err != nil { t.Fatal(err) } @@ -480,7 +480,7 @@ func BenchmarkTrimSectors(b *testing.B) { roots = append(roots, root) appendActions = append(appendActions, contracts.SectorChange{Action: contracts.SectorActionAppend, Root: root}) - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + release, err := db.StoreSector(root, func(_ storage.SectorLocation, _ bool) error { return nil }) if err != nil { b.Fatal(err) } diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index fdaec9c6..542ef6ce 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -30,6 +30,7 @@ CREATE TABLE stored_sectors ( CREATE INDEX stored_sectors_sector_root ON stored_sectors(sector_root); CREATE INDEX stored_sectors_last_access ON stored_sectors(last_access_timestamp); +-- TODO: remove after hardfork CREATE TABLE locked_sectors ( -- should be cleared at startup. currently persisted for simplicity, but may be moved to memory id INTEGER PRIMARY KEY, sector_id INTEGER NOT NULL REFERENCES stored_sectors(id) @@ -220,6 +221,14 @@ CREATE TABLE contract_account_funding ( UNIQUE (contract_id, account_id) ); +CREATE TABLE contract_v2_account_funding ( + id INTEGER PRIMARY KEY, + contract_id INTEGER NOT NULL REFERENCES contracts_v2(id), + account_id INTEGER NOT NULL REFERENCES accounts(id), + amount BLOB NOT NULL, + UNIQUE (contract_id, account_id) +); + CREATE TABLE host_stats ( date_created INTEGER NOT NULL, stat TEXT NOT NULL, diff --git a/persist/sqlite/sectors.go b/persist/sqlite/sectors.go index 50c6b6e0..8212b22a 100644 --- a/persist/sqlite/sectors.go +++ b/persist/sqlite/sectors.go @@ -80,6 +80,22 @@ func (s *Store) RemoveSector(root types.Hash256) (err error) { }) } +// HasSector returns true if the sector is stored on the host or false +// otherwise. +func (s *Store) HasSector(root types.Hash256) (exists bool, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT 1 FROM stored_sectors ss + INNER JOIN volume_sectors vs ON ss.id=vs.sector_id + WHERE ss.sector_root=$1;` + err = tx.QueryRow(query, encode(root)).Scan(&exists) + if errors.Is(err, sql.ErrNoRows) { + return nil + } + return err + }) + return +} + // SectorLocation returns the location of a sector or an error if the // sector is not found. The sector is locked until release is // called. diff --git a/persist/sqlite/sectors_test.go b/persist/sqlite/sectors_test.go new file mode 100644 index 00000000..fddeee00 --- /dev/null +++ b/persist/sqlite/sectors_test.go @@ -0,0 +1,54 @@ +package sqlite + +import ( + "path/filepath" + "testing" + + "go.sia.tech/core/types" + "go.sia.tech/hostd/host/storage" + "go.uber.org/zap/zaptest" + "lukechampine.com/frand" +) + +func BenchmarkSectorLocation(b *testing.B) { + log := zaptest.NewLogger(b) + db, err := OpenDatabase(filepath.Join(b.TempDir(), "test.db"), log) + if err != nil { + b.Fatal(err) + } + defer db.Close() + + volumeID, err := db.AddVolume("test", false) + if err != nil { + b.Fatal(err) + } + + // grow the volume to b.N sectors + if err := db.GrowVolume(volumeID, uint64(b.N)); err != nil { + b.Fatal(err) + } else if err := db.SetReadOnly(volumeID, false); err != nil { + b.Fatal(err) + } else if err := db.SetAvailable(volumeID, true); err != nil { + b.Fatal(err) + } + + roots := make([]types.Hash256, b.N) + for i := 0; i < b.N; i++ { + roots[i] = frand.Entropy256() + _, err := db.StoreSector(roots[i], func(storage.SectorLocation, bool) error { return nil }) + if err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + b.ReportAllocs() + b.ReportMetric(float64(b.N), "sectors") + + for i := 0; i < b.N; i++ { + _, _, err := db.SectorLocation(roots[i]) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/persist/sqlite/volumes.go b/persist/sqlite/volumes.go index e36597fa..b276107c 100644 --- a/persist/sqlite/volumes.go +++ b/persist/sqlite/volumes.go @@ -273,6 +273,8 @@ WHERE v.id=$1` // // The sector should be referenced by either a contract or temp store // before release is called to prevent it from being pruned +// +// Deprecated: use StoreTempSector instead func (s *Store) StoreSector(root types.Hash256, fn func(loc storage.SectorLocation, exists bool) error) (func() error, error) { var sectorLockID int64 var locationLocks []int64 @@ -353,6 +355,110 @@ func (s *Store) StoreSector(root types.Hash256, fn func(loc storage.SectorLocati return unlock, nil } +// StoreTempSector calls fn with an empty location in a writable volume. +// +// The sector must be written to disk within fn. If fn returns an error, +// the metadata is rolled back. If no space is available, ErrNotEnoughStorage +// is returned. If the sector is already stored, fn is skipped and nil +// is returned. +func (s *Store) StoreTempSector(root types.Hash256, expiration uint64, fn func(loc storage.SectorLocation) error) error { + var locationLocks []int64 + var location storage.SectorLocation + var exists bool + + // this weird manual locking and two-stage transaction is required to ensure + // atomicity with the disk without locking the whole database while + // waiting on IO that may be slow. In a database with saner locking, this + // could be a single transaction. + log := s.log.Named("StoreSector").With(zap.Stringer("root", root)) + err := s.transaction(func(tx *txn) error { + var err error + sectorID, err := insertSectorDBID(tx, root) + if err != nil { + return fmt.Errorf("failed to get sector id: %w", err) + } + + // check if the sector is already stored on disk + location, err = sectorLocation(tx, sectorID, root) + exists = err == nil + if exists { + // skip if the sector is already stored + return nil + } else if err != nil && !errors.Is(err, storage.ErrSectorNotFound) { + return fmt.Errorf("failed to check existing sector location: %w", err) + } + + location, err = emptyLocation(tx) + if err != nil { + return fmt.Errorf("failed to get empty location: %w", err) + } + + // lock the location + locationLocks, err = lockLocations(tx, []storage.SectorLocation{location}) + if err != nil { + return fmt.Errorf("failed to lock sector location: %w", err) + } + return nil + }) + if err != nil { + return err + } + defer func() { + err := s.transaction(func(tx *txn) error { + return unlockLocations(tx, locationLocks) + }) + if err != nil { + log.Warn("failed to unlock sector location", zap.Error(err)) + } + }() + + if !exists { + // only call fn if the sector is not already stored + if err := fn(location); err != nil { + return fmt.Errorf("failed to store sector: %w", err) + } + } + + // commit the sector + err = s.transaction(func(tx *txn) error { + sectorID, err := insertSectorDBID(tx, root) + if err != nil { + return fmt.Errorf("failed to get sector id: %w", err) + } + + _, err = tx.Exec(`INSERT INTO temp_storage_sector_roots (sector_id, expiration_height) VALUES ($1, $2)`, sectorID, expiration) + if err != nil { + return fmt.Errorf("failed to commit temp sector: %w", err) + } + + if err := incrementNumericStat(tx, metricTempSectors, 1, time.Now()); err != nil { + return fmt.Errorf("failed to update temp sector metric: %w", err) + } + + if !exists { + // skip volume updates if the sector already exists + if err := incrementVolumeUsage(tx, location.Volume, 1); err != nil { + return fmt.Errorf("failed to update volume metadata: %w", err) + } + + res, err := tx.Exec(`UPDATE volume_sectors SET sector_id=$1 WHERE id=$2`, sectorID, location.ID) + if err != nil { + return fmt.Errorf("failed to commit sector location: %w", err) + } else if rows, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to check rows affected: %w", err) + } else if rows == 0 { + return storage.ErrSectorNotFound + } + } + + return nil + }) + if err != nil { + return fmt.Errorf("failed to commit sector: %w", err) + } + return nil +} + // MigrateSectors migrates each occupied sector of a volume starting at // startIndex. migrateFn will be called for each sector that needs to be migrated. // The sector data should be copied to the new location and synced diff --git a/persist/sqlite/volumes_test.go b/persist/sqlite/volumes_test.go index dc308075..4a8e7fc4 100644 --- a/persist/sqlite/volumes_test.go +++ b/persist/sqlite/volumes_test.go @@ -43,12 +43,9 @@ func TestVolumeSetReadOnly(t *testing.T) { t.Fatal(err) } - // try to add a sector to the volume - release, err := db.StoreSector(frand.Entropy256(), func(loc storage.SectorLocation, exists bool) error { return nil }) + err = db.StoreTempSector(frand.Entropy256(), 1, func(loc storage.SectorLocation) error { return nil }) if err != nil { t.Fatal(err) - } else if err := release(); err != nil { // immediately release the sector so it can be used again - t.Fatal(err) } // set the volume to read-only @@ -56,9 +53,9 @@ func TestVolumeSetReadOnly(t *testing.T) { t.Fatal(err) } - // try to add another sector to the volume, should fail with + // add another sector to the volume, should fail with // ErrNotEnoughStorage - _, err = db.StoreSector(frand.Entropy256(), func(loc storage.SectorLocation, exists bool) error { return nil }) + err = db.StoreTempSector(frand.Entropy256(), 1, func(loc storage.SectorLocation) error { return nil }) if !errors.Is(err, storage.ErrNotEnoughStorage) { t.Fatalf("expected ErrNotEnoughStorage, got %v", err) } @@ -82,7 +79,7 @@ func TestAddSector(t *testing.T) { root := frand.Entropy256() // try to store a sector in the empty volume, should return // ErrNotEnoughStorage - _, err = db.StoreSector(root, func(storage.SectorLocation, bool) error { return nil }) + err = db.StoreTempSector(root, 1, func(storage.SectorLocation) error { return nil }) if !errors.Is(err, storage.ErrNotEnoughStorage) { t.Fatalf("expected ErrNotEnoughStorage, got %v", err) } @@ -92,7 +89,7 @@ func TestAddSector(t *testing.T) { t.Fatal(err) } // store the sector - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { + err = db.StoreTempSector(root, 1, func(loc storage.SectorLocation) error { // check that the sector was stored in the expected location if loc.Volume != volumeID { t.Fatalf("expected volume ID %v, got %v", volumeID, loc.Volume) @@ -104,11 +101,6 @@ func TestAddSector(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - if err := release(); err != nil { - t.Fatal("failed to release sector1:", err) - } - }() // check the location was added loc, release, err := db.SectorLocation(root) @@ -134,22 +126,13 @@ func TestAddSector(t *testing.T) { t.Fatalf("expected 1 used sector, got %v", volumes[0].UsedSectors) } - // store the sector again, exists should be true - release, err = db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { - switch { - case !exists: - t.Fatal("sector does not exist") - case loc.Volume != volumeID: - t.Fatalf("expected volume ID %v, got %v", volumeID, loc.Volume) - case loc.Index != 0: - t.Fatalf("expected sector index 0, got %v", loc.Index) - } + // store the sector again, should succeed + err = db.StoreTempSector(root, 1, func(loc storage.SectorLocation) error { + t.Fatal("write func called for existing sector") return nil }) if err != nil { t.Fatal(err) - } else if err := release(); err != nil { - t.Fatal(err) } volumes, err = db.Volumes() @@ -164,7 +147,7 @@ func TestAddSector(t *testing.T) { // try to store another sector in the volume, should return // ErrNotEnoughStorage - _, err = db.StoreSector(frand.Entropy256(), func(storage.SectorLocation, bool) error { return nil }) + err = db.StoreTempSector(frand.Entropy256(), 1, func(storage.SectorLocation) error { return nil }) if !errors.Is(err, storage.ErrNotEnoughStorage) { t.Fatalf("expected ErrNotEnoughStorage, got %v", err) } @@ -336,23 +319,18 @@ func TestShrinkVolume(t *testing.T) { t.Fatalf("expected %v total sectors, got %v", initialSectors/2, m.Storage.TotalSectors) } - // add a few sectors - var releaseFns []func() error for i := 0; i < 5; i++ { - release, err := db.StoreSector(frand.Entropy256(), func(loc storage.SectorLocation, exists bool) error { + err := db.StoreTempSector(frand.Entropy256(), 1, func(loc storage.SectorLocation) error { if loc.Volume != volume.ID { t.Fatalf("expected volume ID %v, got %v", volume.ID, loc.Volume) } else if loc.Index != uint64(i) { t.Fatalf("expected sector index %v, got %v", i, loc.Index) - } else if exists { - t.Fatal("sector exists") } return nil }) if err != nil { t.Fatal(err) } - releaseFns = append(releaseFns, release) } // check that the volume cannot be shrunk below the used sectors @@ -368,12 +346,6 @@ func TestShrinkVolume(t *testing.T) { } else if m.Storage.TotalSectors != 5 { t.Fatalf("expected %v total sectors, got %v", 5, m.Storage.TotalSectors) } - - for _, fn := range releaseFns { - if err := fn(); err != nil { - t.Fatal(err) - } - } } func TestRemoveVolume(t *testing.T) { @@ -404,26 +376,17 @@ func TestRemoveVolume(t *testing.T) { // add a few sectors for i := 0; i < 5; i++ { sectorRoot := frand.Entropy256() - release, err := db.StoreSector(sectorRoot, func(loc storage.SectorLocation, exists bool) error { + err = db.StoreTempSector(sectorRoot, uint64(i+1), func(loc storage.SectorLocation) error { if loc.Volume != volume.ID { t.Fatalf("expected volume ID %v, got %v", volume.ID, loc.Volume) } else if loc.Index != uint64(i) { t.Fatalf("expected sector index 0, got %v", loc.Index) - } else if exists { - t.Fatal("sector exists") } return nil }) if err != nil { t.Fatal(err) } - - err = db.AddTemporarySectors([]storage.TempSector{{Root: sectorRoot, Expiration: uint64(i)}}) - if err != nil { - t.Fatal(err) - } else if err := release(); err != nil { - t.Fatal(err) - } } // check that the metrics were updated correctly @@ -479,25 +442,17 @@ func TestMigrateSectors(t *testing.T) { for i := range roots { root := frand.Entropy256() roots[i] = root - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { + err = db.StoreTempSector(root, uint64(i+1), func(loc storage.SectorLocation) error { if loc.Volume != volume.ID { t.Fatalf("expected volume ID %v, got %v", volume.ID, loc.Volume) } else if loc.Index != uint64(i) { t.Fatalf("expected sector index %v, got %v", i, loc.Index) - } else if exists { - t.Fatal("sector already exists") } return nil }) if err != nil { t.Fatal(err) } - - if err := db.AddTemporarySectors([]storage.TempSector{{Root: root, Expiration: uint64(i)}}); err != nil { - t.Fatal(err) - } else if err := release(); err != nil { - t.Fatal(err) - } } // remove the first half of the sectors @@ -603,16 +558,13 @@ func TestPrune(t *testing.T) { // store enough sectors to fill the volume roots := make([]types.Hash256, 0, sectors) - releaseFns := make([]func() error, 0, sectors) for i := 0; i < sectors; i++ { root := frand.Entropy256() - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { + err := db.StoreTempSector(root, 1, func(loc storage.SectorLocation) error { if loc.Volume != volume.ID { t.Fatalf("expected volume ID %v, got %v", volume.ID, loc.Volume) } else if loc.Index != uint64(i) { t.Fatalf("expected sector index %v, got %v", i, loc.Index) - } else if exists { - t.Fatal("sector already exists") } return nil }) @@ -620,7 +572,6 @@ func TestPrune(t *testing.T) { t.Fatal(err) } roots = append(roots, root) - releaseFns = append(releaseFns, release) } renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) @@ -648,7 +599,7 @@ func TestPrune(t *testing.T) { if err := db.AddContract(c, []types.Transaction{}, types.MaxCurrency, contracts.Usage{}, 100); err != nil { t.Fatal(err) } - contractSectors, tempSectors, lockedSectors, deletedSectors := roots[:20], roots[20:40], roots[40:60], roots[60:] + contractSectors := roots[:50] // append the contract sectors to the contract var changes []contracts.SectorChange for _, root := range contractSectors { @@ -662,46 +613,7 @@ func TestPrune(t *testing.T) { t.Fatal(err) } - // add the temporary sectors - var tempRoots []storage.TempSector - for _, root := range tempSectors { - tempRoots = append(tempRoots, storage.TempSector{ - Root: root, - Expiration: 100, - }) - } - if err := db.AddTemporarySectors(tempRoots); err != nil { - t.Fatal(err) - } - - // lock the remaining sectors - var locks []int64 - for _, root := range lockedSectors { - err := db.transaction(func(tx *txn) error { - sectorID, err := sectorDBID(tx, root) - if err != nil { - return err - } - lockID, err := lockSector(tx, sectorID) - if err != nil { - return err - } - locks = append(locks, lockID) - return nil - }) - if err != nil { - t.Fatal(err) - } - } - - // remove the initial locks - for _, fn := range releaseFns { - if err := fn(); err != nil { - t.Fatal(err) - } - } - - checkConsistency := func(contract, temp, locked, deleted []types.Hash256) error { + checkConsistency := func(contract, temp, deleted []types.Hash256, expectedSectors uint64) error { for _, root := range contract { if _, release, err := db.SectorLocation(root); err != nil { return fmt.Errorf("sector %v not found: %w", root, err) @@ -718,14 +630,6 @@ func TestPrune(t *testing.T) { } } - for _, root := range locked { - if _, release, err := db.SectorLocation(root); err != nil { - return fmt.Errorf("sector %v not found: %w", root, err) - } else if err := release(); err != nil { - return fmt.Errorf("failed to release sector %v: %w", root, err) - } - } - for i, root := range deleted { if _, _, err := db.SectorLocation(root); !errors.Is(err, storage.ErrSectorNotFound) { return fmt.Errorf("expected ErrSectorNotFound for sector %d %q, got %v", i, root, err) @@ -733,7 +637,6 @@ func TestPrune(t *testing.T) { } // check the volume usage - expectedSectors := uint64(len(contract) + len(temp) + len(locked)) used, _, err := db.StorageUsage() if err != nil { return fmt.Errorf("failed to get storage usage: %w", err) @@ -755,28 +658,17 @@ func TestPrune(t *testing.T) { return nil } - if err := checkConsistency(contractSectors, tempSectors, lockedSectors, deletedSectors); err != nil { - t.Fatal(err) - } - - // unlock locked sectors - err = db.transaction(func(tx *txn) error { - return unlockSector(tx, log.Named("unlockSector"), locks...) - }) - if err != nil { - t.Fatal(err) - } - - if err := checkConsistency(contractSectors, tempSectors, nil, roots[60:]); err != nil { + // all sectors start out as temporary + if err := checkConsistency(contractSectors, roots, []types.Hash256{}, 100); err != nil { t.Fatal(err) } // expire the temp sectors - if err := db.ExpireTempSectors(100); err != nil { + if err := db.ExpireTempSectors(1); err != nil { t.Fatal(err) } - if err := checkConsistency(contractSectors, nil, nil, roots[40:]); err != nil { + if err := checkConsistency(contractSectors, nil, roots[50:], 50); err != nil { t.Fatal(err) } @@ -789,7 +681,7 @@ func TestPrune(t *testing.T) { } contractSectors = contractSectors[:len(contractSectors)/2] - if err := checkConsistency(contractSectors, nil, nil, roots[50:]); err != nil { + if err := checkConsistency(contractSectors[:25], nil, roots[50:], 25); err != nil { t.Fatal(err) } @@ -798,7 +690,7 @@ func TestPrune(t *testing.T) { t.Fatal(err) } - if err := checkConsistency(nil, nil, nil, roots); err != nil { + if err := checkConsistency(nil, nil, roots, 0); err != nil { t.Fatal(err) } } @@ -870,11 +762,9 @@ func BenchmarkVolumeMigrate(b *testing.B) { roots := make([]types.Hash256, b.N) for i := range roots { roots[i] = frand.Entropy256() - release, err := db.StoreSector(roots[i], func(loc storage.SectorLocation, exists bool) error { return nil }) + err := db.StoreTempSector(roots[i], 10, func(loc storage.SectorLocation) error { return nil }) if err != nil { b.Fatalf("failed to store sector %v: %v", i, err) - } else if err := release(); err != nil { - b.Fatal(err) } } @@ -888,7 +778,7 @@ func BenchmarkVolumeMigrate(b *testing.B) { b.ReportMetric(float64(b.N), "sectors") // migrate all sectors from the first volume to the second - migrated, failed, err := db.MigrateSectors(context.Background(), volume1.ID, 0, func(loc storage.SectorLocation) error { + migrated, failed, err := db.MigrateSectors(context.Background(), volume1.ID, 0, func(_ storage.SectorLocation) error { return nil }) if err != nil { @@ -918,7 +808,7 @@ func BenchmarkStoreSector(b *testing.B) { b.ReportMetric(float64(b.N), "sectors") for i := 0; i < b.N; i++ { - _, err := db.StoreSector(frand.Entropy256(), func(loc storage.SectorLocation, exists bool) error { return nil }) + err := db.StoreTempSector(frand.Entropy256(), 1, func(loc storage.SectorLocation) error { return nil }) if err != nil { b.Fatal(err) } diff --git a/rhp/conn.go b/rhp/conn.go deleted file mode 100644 index a4385c3a..00000000 --- a/rhp/conn.go +++ /dev/null @@ -1,82 +0,0 @@ -package rhp - -import ( - "context" - "net" - "sync/atomic" - - "golang.org/x/time/rate" -) - -type ( - // A DataMonitor records the amount of data read and written across all connections. - DataMonitor interface { - ReadBytes(n int) - WriteBytes(n int) - } - - // A Conn wraps a net.Conn to track the amount of data read and written and - // limit bandwidth usage. - Conn struct { - net.Conn - r, w uint64 - monitor DataMonitor - rl, wl *rate.Limiter - } - - // A noOpMonitor is a DataMonitor that does nothing. - noOpMonitor struct{} -) - -// ReadBytes implements DataMonitor -func (noOpMonitor) ReadBytes(n int) {} - -// WriteBytes implements DataMonitor -func (noOpMonitor) WriteBytes(n int) {} - -// Usage returns the amount of data read and written by the connection. -func (c *Conn) Usage() (read, written uint64) { - read = atomic.LoadUint64(&c.r) - written = atomic.LoadUint64(&c.w) - return -} - -// Read implements io.Reader -func (c *Conn) Read(b []byte) (int, error) { - n, err := c.Conn.Read(b) - atomic.AddUint64(&c.r, uint64(n)) - c.monitor.ReadBytes(n) - if err := c.rl.WaitN(context.Background(), n); err != nil { - return n, err - } - return n, err -} - -// Write implements io.Writer -func (c *Conn) Write(b []byte) (int, error) { - n, err := c.Conn.Write(b) - atomic.AddUint64(&c.w, uint64(n)) - c.monitor.WriteBytes(n) - if err := c.wl.WaitN(context.Background(), n); err != nil { - return n, err - } - return n, err -} - -// NewConn initializes a new RPC conn wrapper. -func NewConn(c net.Conn, m DataMonitor, rl, wl *rate.Limiter) *Conn { - if c, ok := c.(*Conn); ok { - return c - } - return &Conn{ - Conn: c, - monitor: m, - rl: rl, - wl: wl, - } -} - -// NewNoOpMonitor initializes a new NoOpMonitor. -func NewNoOpMonitor() DataMonitor { - return noOpMonitor{} -} diff --git a/rhp/listener.go b/rhp/listener.go new file mode 100644 index 00000000..b2784a02 --- /dev/null +++ b/rhp/listener.go @@ -0,0 +1,116 @@ +package rhp + +import ( + "net" + + "golang.org/x/time/rate" +) + +type noOpMonitor struct{} + +func (noOpMonitor) ReadBytes(n int) {} +func (noOpMonitor) WriteBytes(n int) {} + +type ( + // An Option configures a listener. + Option func(*rhpListener) + + // DataMonitor records the amount of data read and written across + // all connections. + DataMonitor interface { + ReadBytes(int) + WriteBytes(int) + } + + rhpConn struct { + net.Conn + + monitor DataMonitor + } + + rhpListener struct { + l net.Listener + + readLimiter *rate.Limiter + writeLimiter *rate.Limiter + monitor DataMonitor + } +) + +var _ net.Listener = &rhpListener{} +var _ net.Conn = &rhpConn{} + +// WithReadLimit sets the read rate limit for the listener. +func WithReadLimit(r *rate.Limiter) Option { + return func(l *rhpListener) { + l.readLimiter = r + } +} + +// WithWriteLimit sets the write rate limit for the listener. +func WithWriteLimit(w *rate.Limiter) Option { + return func(l *rhpListener) { + l.writeLimiter = w + } +} + +// WithDataMonitor sets the data monitor for the listener. +func WithDataMonitor(m DataMonitor) Option { + return func(l *rhpListener) { + l.monitor = m + } +} + +// Read reads data from the connection. Read can be made to time out and return +// an error after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *rhpConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + c.monitor.ReadBytes(n) + return n, err +} + +// Write writes data to the connection. Write can be made to time out and return +// an error after a fixed time limit; see SetDeadline and SetWriteDeadline. +func (c *rhpConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + c.monitor.WriteBytes(n) + return n, err +} + +func (l *rhpListener) Accept() (net.Conn, error) { + c, err := l.l.Accept() + if err != nil { + return nil, err + } + return &rhpConn{ + Conn: c, + monitor: l.monitor, + }, nil +} + +func (l *rhpListener) Close() error { + return l.l.Close() +} + +func (l *rhpListener) Addr() net.Addr { + return l.l.Addr() +} + +// Listen returns a new listener with optional rate limiting and monitoring. +func Listen(network, address string, opts ...Option) (net.Listener, error) { + l, err := net.Listen(network, address) + if err != nil { + return nil, err + } + + rhp := &rhpListener{ + l: l, + readLimiter: rate.NewLimiter(rate.Inf, 0), + writeLimiter: rate.NewLimiter(rate.Inf, 0), + monitor: noOpMonitor{}, + } + for _, opt := range opts { + opt(rhp) + } + return rhp, nil +} diff --git a/rhp/reporter.go b/rhp/reporter.go deleted file mode 100644 index f9b84788..00000000 --- a/rhp/reporter.go +++ /dev/null @@ -1,204 +0,0 @@ -package rhp - -import ( - "encoding/hex" - "sync" - "time" - - "go.sia.tech/core/types" - "go.sia.tech/hostd/host/contracts" - "lukechampine.com/frand" -) - -// SessionEventType is the type of a session event. -const ( - SessionEventTypeStart = "sessionStart" - SessionEventTypeEnd = "sessionEnd" - SessionEventTypeRPCStart = "rpcStart" - SessionEventTypeRPCEnd = "rpcEnd" -) - -// SessionProtocol is the protocol used by a session. -const ( - SessionProtocolTCP = "tcp" - SessionProtocolWS = "websocket" -) - -type ( - // UID is a unique identifier for a session or RPC. - UID [8]byte - - // A Session is an open connection between a host and a renter. - Session struct { - conn *Conn - - ID UID `json:"id"` - Protocol string `json:"protocol"` - RHPVersion int `json:"rhpVersion"` - PeerAddress string `json:"peerAddress"` - Ingress uint64 `json:"ingress"` - Egress uint64 `json:"egress"` - Usage contracts.Usage `json:"usage"` - - Timestamp time.Time `json:"timestamp"` - } - - // An RPC is an RPC call made by a renter to a host. - RPC struct { - ID UID `json:"id"` - SessionID UID `json:"sessionID"` - RPC types.Specifier `json:"rpc"` - Usage contracts.Usage `json:"usage"` - Error error `json:"error,omitempty"` - Elapsed time.Duration `json:"timestamp"` - } - - // A SessionSubscriber receives session events. - SessionSubscriber interface { - ReceiveSessionEvent(SessionEvent) - } - - // A SessionReporter manages open sessions and reports session events to - // subscribers. - SessionReporter struct { - mu sync.Mutex - sessions map[UID]Session - subscribers map[SessionSubscriber]struct{} - } - - // A SessionEvent is an event that occurs during a session. - SessionEvent struct { - Type string `json:"type"` - Session Session `json:"session"` - RPC any `json:"rpc,omitempty"` - } -) - -// String returns the hex-encoded string representation of the UID. -func (u UID) String() string { - return hex.EncodeToString(u[:]) -} - -func (sr *SessionReporter) updateSubscribers(sessionID UID, eventType string, rpc any) { - sess, ok := sr.sessions[sessionID] - if !ok { - return - } - - sess.Ingress, sess.Egress = sess.conn.Usage() - sr.sessions[sessionID] = sess - - for sub := range sr.subscribers { - sub.ReceiveSessionEvent(SessionEvent{ - Type: eventType, - Session: sess, - RPC: rpc, - }) - } -} - -// Subscribe subscribes to session events. -func (sr *SessionReporter) Subscribe(sub SessionSubscriber) { - sr.mu.Lock() - defer sr.mu.Unlock() - - sr.subscribers[sub] = struct{}{} -} - -// Unsubscribe unsubscribes from session events. -func (sr *SessionReporter) Unsubscribe(sub SessionSubscriber) { - sr.mu.Lock() - defer sr.mu.Unlock() - - delete(sr.subscribers, sub) -} - -// StartSession starts a new session and returns a function that should be -// called when the session ends. -func (sr *SessionReporter) StartSession(conn *Conn, proto string, version int) (sessionID UID, end func()) { - sr.mu.Lock() - defer sr.mu.Unlock() - - copy(sessionID[:], frand.Bytes(8)) - sr.sessions[sessionID] = Session{ - conn: conn, - - ID: sessionID, - RHPVersion: version, - Protocol: proto, - PeerAddress: conn.RemoteAddr().String(), - Timestamp: time.Now(), - } - sr.updateSubscribers(sessionID, SessionEventTypeStart, nil) - return sessionID, func() { - sr.mu.Lock() - defer sr.mu.Unlock() - - sr.updateSubscribers(sessionID, SessionEventTypeEnd, nil) - delete(sr.sessions, sessionID) - } -} - -// StartRPC starts a new RPC and returns a function that should be called when -// the RPC ends. -func (sr *SessionReporter) StartRPC(sessionID UID, rpc types.Specifier) (rpcID UID, end func(contracts.Usage, error)) { - sr.mu.Lock() - defer sr.mu.Unlock() - - copy(rpcID[:], frand.Bytes(8)) - _, ok := sr.sessions[sessionID] - if !ok { - return rpcID, func(contracts.Usage, error) {} - } - - event := RPC{ - ID: rpcID, - SessionID: sessionID, - RPC: rpc, - } - rpcStart := time.Now() - sr.updateSubscribers(sessionID, SessionEventTypeRPCStart, event) - return rpcID, func(usage contracts.Usage, err error) { - // update event - event.Error = err - event.Elapsed = time.Since(rpcStart) - event.Usage = usage - - sr.mu.Lock() - defer sr.mu.Unlock() - - sess, ok := sr.sessions[sessionID] - if !ok { - return - } - - // update session - sess.Usage = sess.Usage.Add(usage) - sr.sessions[sessionID] = sess - // update subscribers - sr.updateSubscribers(sessionID, SessionEventTypeRPCEnd, event) - } -} - -// Active returns a snapshot of the currently active sessions. -func (sr *SessionReporter) Active() []Session { - sr.mu.Lock() - defer sr.mu.Unlock() - - sessions := make([]Session, 0, len(sr.sessions)) - for _, sess := range sr.sessions { - // update session usage - sess.Ingress, sess.Egress = sess.conn.Usage() - sr.sessions[sess.ID] = sess - // append to slice - sessions = append(sessions, sess) - } - return sessions -} - -// NewSessionReporter returns a new SessionReporter. -func NewSessionReporter() *SessionReporter { - return &SessionReporter{ - sessions: make(map[UID]Session), - } -} diff --git a/rhp/siamux.go b/rhp/siamux.go new file mode 100644 index 00000000..0ef15314 --- /dev/null +++ b/rhp/siamux.go @@ -0,0 +1,47 @@ +package rhp + +import ( + "crypto/ed25519" + "net" + + rhp4 "go.sia.tech/coreutils/rhp/v4" + "go.sia.tech/mux/v2" + "go.uber.org/zap" +) + +// A muxTransport is a rhp4.Transport that wraps a mux.Mux. +type muxTransport struct { + m *mux.Mux +} + +// Close implements the rhp4.Transport interface. +func (mt *muxTransport) Close() error { + return mt.m.Close() +} + +// AcceptStream implements the rhp4.Transport interface. +func (mt *muxTransport) AcceptStream() (net.Conn, error) { + return mt.m.AcceptStream() +} + +// ServeRHP4SiaMux serves RHP4 connections on l using the provided server and logger. +func ServeRHP4SiaMux(l net.Listener, s *rhp4.Server, log *zap.Logger) { + for { + conn, err := l.Accept() + if err != nil { + log.Error("failed to accept connection", zap.Error(err)) + return + } + log := log.With(zap.String("peerAddress", conn.RemoteAddr().String())) + go func() { + defer conn.Close() + + m, err := mux.Accept(conn, ed25519.PrivateKey(s.HostKey())) + if err != nil { + log.Debug("failed to accept mux connection", zap.Error(err)) + } else if err := s.Serve(&muxTransport{m}, log); err != nil { + log.Debug("failed to serve connection", zap.Error(err)) + } + }() + } +} diff --git a/rhp/v2/options.go b/rhp/v2/options.go deleted file mode 100644 index c7f9a6de..00000000 --- a/rhp/v2/options.go +++ /dev/null @@ -1,41 +0,0 @@ -package rhp - -import ( - "go.sia.tech/core/types" - "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/rhp" - "go.uber.org/zap" -) - -// A SessionHandlerOption is a functional option for session handlers. -type SessionHandlerOption func(*SessionHandler) - -// WithLog sets the logger for the session handler. -func WithLog(l *zap.Logger) SessionHandlerOption { - return func(s *SessionHandler) { - s.log = l - } -} - -// WithSessionReporter sets the session reporter for the session handler. -func WithSessionReporter(r SessionReporter) SessionHandlerOption { - return func(s *SessionHandler) { - s.sessions = r - } -} - -// WithDataMonitor sets the data monitor for the session handler. -func WithDataMonitor(m rhp.DataMonitor) SessionHandlerOption { - return func(s *SessionHandler) { - s.monitor = m - } -} - -type noopSessionReporter struct{} - -func (noopSessionReporter) StartSession(conn *rhp.Conn, proto string, version int) (sessionID rhp.UID, end func()) { - return rhp.UID{}, func() {} -} -func (noopSessionReporter) StartRPC(sessionID rhp.UID, rpc types.Specifier) (rpcID rhp.UID, end func(contracts.Usage, error)) { - return rhp.UID{}, func(contracts.Usage, error) {} -} diff --git a/rhp/v2/rhp.go b/rhp/v2/rhp.go index 3db917e3..8ec0c122 100644 --- a/rhp/v2/rhp.go +++ b/rhp/v2/rhp.go @@ -2,6 +2,7 @@ package rhp import ( "context" + "encoding/hex" "errors" "fmt" "io" @@ -17,7 +18,7 @@ import ( "go.sia.tech/hostd/internal/threadgroup" "go.sia.tech/hostd/rhp" "go.uber.org/zap" - "golang.org/x/time/rate" + "lukechampine.com/frand" ) const ( @@ -57,8 +58,8 @@ type ( // called after the contract roots have been committed to prevent the // sector from being deleted. Write(root types.Hash256, data *[rhp2.SectorSize]byte) (release func() error, _ error) - // Read reads the sector with the given root from the manager. - Read(root types.Hash256) (*[rhp2.SectorSize]byte, error) + // ReadSector reads the sector with the given root from the manager. + ReadSector(root types.Hash256) (*[rhp2.SectorSize]byte, error) // Sync syncs the data files of changed volumes. Sync() error } @@ -89,13 +90,6 @@ type ( // A SettingsReporter reports the host's current configuration. SettingsReporter interface { Settings() settings.Settings - BandwidthLimiters() (ingress, egress *rate.Limiter) - } - - // SessionReporter reports session metrics - SessionReporter interface { - StartSession(conn *rhp.Conn, proto string, version int) (sessionID rhp.UID, end func()) - StartRPC(sessionID rhp.UID, rpc types.Specifier) (rpcID rhp.UID, end func(contracts.Usage, error)) } // A SessionHandler handles the host side of the renter-host protocol and @@ -113,7 +107,6 @@ type ( wallet Wallet contracts ContractManager - sessions SessionReporter settings SettingsReporter storage StorageManager log *zap.Logger @@ -155,12 +148,10 @@ func (sh *SessionHandler) rpcLoop(sess *session, log *zap.Logger) error { return err } start := time.Now() - rpcID, end := sh.sessions.StartRPC(sess.id, id) - log = log.Named(id.String()).With(zap.Stringer("rpcID", rpcID)) + + log = log.Named(id.String()).With(zap.String("rpcID", hex.EncodeToString(frand.Bytes(4)))) log.Debug("RPC start") - usage, err := rpcFn(sess, log) - end(usage, err) - if err != nil { + if _, err := rpcFn(sess, log); err != nil { log.Warn("RPC error", zap.Error(err), zap.Duration("elapsed", time.Since(start))) return fmt.Errorf("RPC %q error: %w", id, err) } @@ -170,22 +161,13 @@ func (sh *SessionHandler) rpcLoop(sess *session, log *zap.Logger) error { // upgrade performs the RHP2 handshake and begins handling RPCs func (sh *SessionHandler) upgrade(conn net.Conn) error { - // wrap the conn with the bandwidth limiters - ingressLimiter, egressLimiter := sh.settings.BandwidthLimiters() - rhpConn := rhp.NewConn(conn, sh.monitor, ingressLimiter, egressLimiter) - - t, err := rhp2.NewHostTransport(rhpConn, sh.privateKey) + t, err := rhp2.NewHostTransport(conn, sh.privateKey) if err != nil { return err } - sessionID, end := sh.sessions.StartSession(rhpConn, rhp.SessionProtocolTCP, 2) - defer end() - sess := &session{ - id: sessionID, - conn: rhpConn, - t: t, + t: t, } defer t.Close() @@ -195,7 +177,7 @@ func (sh *SessionHandler) upgrade(conn net.Conn) error { } }() - log := sh.log.With(zap.Stringer("sessionID", sessionID), zap.Stringer("peerAddr", conn.RemoteAddr())) + log := sh.log.With(zap.String("sessionID", hex.EncodeToString(frand.Bytes(4))), zap.Stringer("peerAddr", conn.RemoteAddr())) for { if err := sh.rpcLoop(sess, log); err != nil { @@ -287,7 +269,7 @@ func (sh *SessionHandler) LocalAddr() string { } // NewSessionHandler creates a new RHP2 SessionHandler -func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, rhp3Addr string, cm ChainManager, s Syncer, wallet Wallet, contracts ContractManager, settings SettingsReporter, storage StorageManager, opts ...SessionHandlerOption) (*SessionHandler, error) { +func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, rhp3Addr string, cm ChainManager, s Syncer, wallet Wallet, contracts ContractManager, settings SettingsReporter, storage StorageManager, log *zap.Logger) (*SessionHandler, error) { _, rhp3Port, err := net.SplitHostPort(rhp3Addr) if err != nil { return nil, fmt.Errorf("failed to parse rhp3 addr: %w", err) @@ -307,14 +289,8 @@ func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, rhp3Addr string settings: settings, storage: storage, - log: zap.NewNop(), - monitor: rhp.NewNoOpMonitor(), - sessions: noopSessionReporter{}, - - tg: threadgroup.New(), - } - for _, opt := range opts { - opt(sh) + log: log, + tg: threadgroup.New(), } return sh, nil } diff --git a/rhp/v2/rpc.go b/rhp/v2/rpc.go index bb057acd..a356bfd6 100644 --- a/rhp/v2/rpc.go +++ b/rhp/v2/rpc.go @@ -661,11 +661,13 @@ func (sh *SessionHandler) rpcWrite(s *session, log *zap.Logger) (contracts.Usage return contracts.Usage{}, err } - sector, err := sh.storage.Read(root) + sector, err := sh.storage.ReadSector(root) if err != nil { s.t.WriteResponseErr(ErrHostInternalError) return contracts.Usage{}, fmt.Errorf("failed to read sector %v: %w", root, err) } + var updated [rhp2.SectorSize]byte + copy(updated[:], sector[:]) i, offset := action.A, action.B if offset > rhp2.SectorSize { @@ -678,7 +680,7 @@ func (sh *SessionHandler) rpcWrite(s *session, log *zap.Logger) (contracts.Usage return contracts.Usage{}, err } - copy(sector[offset:], action.Data) + copy(updated[offset:], action.Data) newRoot := rhp2.SectorRoot(sector) if err := contractUpdater.UpdateSector(newRoot, i); err != nil { @@ -686,7 +688,7 @@ func (sh *SessionHandler) rpcWrite(s *session, log *zap.Logger) (contracts.Usage s.t.WriteResponseErr(err) return contracts.Usage{}, err } - release, err := sh.storage.Write(root, sector) + release, err := sh.storage.Write(root, &updated) if err != nil { err := fmt.Errorf("append action: failed to write sector: %w", err) s.t.WriteResponseErr(err) @@ -859,7 +861,7 @@ func (sh *SessionHandler) rpcRead(s *session, log *zap.Logger) (contracts.Usage, // enter response loop for i, sec := range req.Sections { - sector, err := sh.storage.Read(sec.MerkleRoot) + sector, err := sh.storage.ReadSector(sec.MerkleRoot) if err != nil { err := fmt.Errorf("failed to get sector: %w", err) s.t.WriteResponseErr(err) diff --git a/rhp/v2/rpc_test.go b/rhp/v2/rpc_test.go index 10f01932..f4df5d09 100644 --- a/rhp/v2/rpc_test.go +++ b/rhp/v2/rpc_test.go @@ -59,7 +59,7 @@ func TestSettings(t *testing.T) { } defer l.Close() - sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, log) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestUploadDownload(t *testing.T) { } defer l.Close() - sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, log) if err != nil { t.Fatal(err) } @@ -221,7 +221,7 @@ func TestRenew(t *testing.T) { } defer l.Close() - sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, log) if err != nil { t.Fatal(err) } @@ -429,7 +429,7 @@ func TestRPCV2(t *testing.T) { } defer l.Close() - sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, log) if err != nil { t.Fatal(err) } diff --git a/rhp/v2/session.go b/rhp/v2/session.go index 9ab785cd..e1d8d854 100644 --- a/rhp/v2/session.go +++ b/rhp/v2/session.go @@ -6,7 +6,6 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/rhp" ) // minMessageSize is the minimum size of an RPC message. If an encoded message @@ -16,25 +15,23 @@ const minMessageSize = 4096 // A session is an ongoing exchange of RPCs via the renter-host protocol. type session struct { - id rhp.UID - conn *rhp.Conn - t *rhp2.Transport + t *rhp2.Transport contract contracts.SignedRevision } func (s *session) readRequest(req rhp2.ProtocolObject, maxSize uint64, timeout time.Duration) error { - s.conn.SetReadDeadline(time.Now().Add(timeout)) + s.t.SetReadDeadline(time.Now().Add(timeout)) return s.t.ReadRequest(req, maxSize) } func (s *session) readResponse(req rhp2.ProtocolObject, maxSize uint64, timeout time.Duration) error { - s.conn.SetReadDeadline(time.Now().Add(timeout)) + s.t.SetReadDeadline(time.Now().Add(timeout)) return s.t.ReadResponse(req, maxSize) } func (s *session) writeResponse(resp rhp2.ProtocolObject, timeout time.Duration) error { - s.conn.SetWriteDeadline(time.Now().Add(timeout)) + s.t.SetWriteDeadline(time.Now().Add(timeout)) return s.t.WriteResponse(resp) } diff --git a/rhp/v3/execute.go b/rhp/v3/execute.go index 45185c80..a89fbf1a 100644 --- a/rhp/v3/execute.go +++ b/rhp/v3/execute.go @@ -94,15 +94,12 @@ func (pe *programExecutor) executeAppendSector(instr *rhp3.InstrAppendSector, lo if err != nil { return nil, nil, fmt.Errorf("failed to read sector: %w", err) } - rootCalcStart := time.Now() root := rhp2.SectorRoot(sector) - log.Debug("calculated sector root", zap.Duration("duration", time.Since(rootCalcStart))) // pay for execution cost := pe.priceTable.AppendSectorCost(pe.remainingDuration) if err := pe.payForExecution(cost, costToAccountUsage(cost)); err != nil { return nil, nil, fmt.Errorf("failed to pay for instruction: %w", err) } - release, err := pe.storage.Write(root, sector) if errors.Is(err, storage.ErrNotEnoughStorage) { return nil, nil, err @@ -111,15 +108,11 @@ func (pe *programExecutor) executeAppendSector(instr *rhp3.InstrAppendSector, lo } pe.releaseFuncs = append(pe.releaseFuncs, release) pe.updater.AppendSector(root) - if !instr.ProofRequired { return nil, nil, nil } - - proofStart := time.Now() roots := pe.updater.SectorRoots() proof, _ := rhp2.BuildDiffProof([]rhp2.RPCWriteAction{{Type: rhp2.RPCWriteActionAppend}}, roots[:len(roots)-1]) // TODO: add rhp3 proof methods - log.Debug("built proof", zap.Duration("duration", time.Since(proofStart))) return nil, proof, nil } @@ -227,7 +220,7 @@ func (pe *programExecutor) executeReadOffset(instr *rhp3.InstrReadOffset, log *z return nil, nil, fmt.Errorf("failed to get root: %w", err) } - sector, err := pe.storage.Read(root) + sector, err := pe.storage.ReadSector(root) if err != nil { return nil, nil, fmt.Errorf("failed to read sector: %w", err) } @@ -276,7 +269,7 @@ func (pe *programExecutor) executeReadSector(instr *rhp3.InstrReadSector, log *z } // read the sector - sector, err := pe.storage.Read(root) + sector, err := pe.storage.ReadSector(root) if errors.Is(err, storage.ErrSectorNotFound) { log.Debug("failed to read sector", zap.String("root", root.String()), zap.Error(err)) return nil, nil, storage.ErrSectorNotFound @@ -365,20 +358,23 @@ func (pe *programExecutor) executeUpdateSector(instr *rhp3.InstrUpdateSector, _ return nil, nil, fmt.Errorf("failed to get root: %w", err) } - sector, err := pe.storage.Read(oldRoot) + sector, err := pe.storage.ReadSector(oldRoot) if err != nil { return nil, nil, fmt.Errorf("failed to read sector: %w", err) } + var updated [rhp2.SectorSize]byte + copy(updated[:], sector[:]) + // validate and apply the patch if relOffset+length > rhp2.SectorSize { return nil, nil, fmt.Errorf("update offset %v length %v is out of bounds", relOffset, length) } - copy(sector[relOffset:], patch) + copy(updated[relOffset:], patch) // store the new sector - newRoot := rhp2.SectorRoot((*[rhp2.SectorSize]byte)(sector)) - release, err := pe.storage.Write(newRoot, sector) + newRoot := rhp2.SectorRoot(&updated) + release, err := pe.storage.Write(newRoot, &updated) if err != nil { return nil, nil, fmt.Errorf("failed to write sector: %w", err) } diff --git a/rhp/v3/options.go b/rhp/v3/options.go deleted file mode 100644 index 4e2b5f0b..00000000 --- a/rhp/v3/options.go +++ /dev/null @@ -1,41 +0,0 @@ -package rhp - -import ( - "go.sia.tech/core/types" - "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/rhp" - "go.uber.org/zap" -) - -// SessionHandlerOption is a functional option for session handlers. -type SessionHandlerOption func(*SessionHandler) - -// WithLog sets the logger for the session handler. -func WithLog(l *zap.Logger) SessionHandlerOption { - return func(s *SessionHandler) { - s.log = l - } -} - -// WithSessionReporter sets the session reporter for the session handler. -func WithSessionReporter(r SessionReporter) SessionHandlerOption { - return func(s *SessionHandler) { - s.sessions = r - } -} - -// WithDataMonitor sets the data monitor for the session handler. -func WithDataMonitor(m rhp.DataMonitor) SessionHandlerOption { - return func(s *SessionHandler) { - s.monitor = m - } -} - -type noopSessionReporter struct{} - -func (noopSessionReporter) StartSession(conn *rhp.Conn, proto string, version int) (sessionID rhp.UID, end func()) { - return rhp.UID{}, func() {} -} -func (noopSessionReporter) StartRPC(sessionID rhp.UID, rpc types.Specifier) (rpcID rhp.UID, end func(contracts.Usage, error)) { - return rhp.UID{}, func(contracts.Usage, error) {} -} diff --git a/rhp/v3/rhp.go b/rhp/v3/rhp.go index bb775ac6..eff38807 100644 --- a/rhp/v3/rhp.go +++ b/rhp/v3/rhp.go @@ -2,6 +2,7 @@ package rhp import ( "context" + "encoding/hex" "errors" "fmt" "net" @@ -16,9 +17,8 @@ import ( "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/storage" "go.sia.tech/hostd/internal/threadgroup" - "go.sia.tech/hostd/rhp" "go.uber.org/zap" - "golang.org/x/time/rate" + "lukechampine.com/frand" ) type ( @@ -62,8 +62,8 @@ type ( // called after the contract roots have been committed to prevent the // sector from being deleted. Write(root types.Hash256, data *[rhp2.SectorSize]byte) (release func() error, _ error) - // Read reads the sector with the given root from the manager. - Read(root types.Hash256) (*[rhp2.SectorSize]byte, error) + // ReadSector reads the sector with the given root from the manager. + ReadSector(root types.Hash256) (*[rhp2.SectorSize]byte, error) // Sync syncs the data files of changed volumes. Sync() error @@ -107,13 +107,6 @@ type ( // A SettingsReporter reports the host's current configuration. SettingsReporter interface { Settings() settings.Settings - BandwidthLimiters() (ingress, egress *rate.Limiter) - } - - // SessionReporter reports session metrics - SessionReporter interface { - StartSession(conn *rhp.Conn, proto string, version int) (sessionID rhp.UID, end func()) - StartRPC(sessionID rhp.UID, rpc types.Specifier) (rpcID rhp.UID, end func(contracts.Usage, error)) } // A SessionHandler handles the host side of the renter-host protocol and @@ -133,10 +126,8 @@ type ( syncer Syncer wallet Wallet - log *zap.Logger - sessions SessionReporter - monitor rhp.DataMonitor - tg *threadgroup.ThreadGroup + log *zap.Logger + tg *threadgroup.ThreadGroup priceTables *priceTableManager } @@ -175,7 +166,7 @@ var ( ) // handleHostStream handles streams routed to the "host" subscriber -func (sh *SessionHandler) handleHostStream(s *rhp3.Stream, sessionID rhp.UID, log *zap.Logger) { +func (sh *SessionHandler) handleHostStream(s *rhp3.Stream, log *zap.Logger) { defer s.Close() // close the stream when the RPC has completed done, err := sh.tg.Add() // add the RPC to the threadgroup @@ -207,11 +198,8 @@ func (sh *SessionHandler) handleHostStream(s *rhp3.Stream, sessionID rhp.UID, lo rpcStart := time.Now() s.SetDeadline(time.Now().Add(time.Minute)) // set the initial deadline, may be overwritten by the handler - rpcID, end := sh.sessions.StartRPC(sessionID, rpc) - log = log.Named(rpc.String()).With(zap.Stringer("rpcID", rpcID)) - usage, err := rpcFn(s, log) - end(usage, err) - if err != nil { + log = log.Named(rpc.String()).With(zap.String("rpcID", hex.EncodeToString(frand.Bytes(4)))) + if _, err := rpcFn(s, log); err != nil { log.Warn("RPC failed", zap.Error(err), zap.Duration("elapsed", time.Since(rpcStart))) return } @@ -242,19 +230,10 @@ func (sh *SessionHandler) Serve() error { go func() { defer conn.Close() - // wrap the conn with the bandwidth limiters - ingress, egress := sh.settings.BandwidthLimiters() - rhpConn := rhp.NewConn(conn, sh.monitor, ingress, egress) - defer rhpConn.Close() - - // initiate the session - sessionID, end := sh.sessions.StartSession(rhpConn, rhp.SessionProtocolTCP, 3) - defer end() - - log := sh.log.With(zap.Stringer("sessionID", sessionID), zap.String("peerAddress", conn.RemoteAddr().String())) + log := sh.log.With(zap.String("sessionID", hex.EncodeToString(frand.Bytes(4))), zap.String("peerAddress", conn.RemoteAddr().String())) // upgrade the connection to RHP3 - t, err := rhp3.NewHostTransport(rhpConn, sh.privateKey) + t, err := rhp3.NewHostTransport(conn, sh.privateKey) if err != nil { log.Debug("failed to upgrade conn", zap.Error(err)) return @@ -277,7 +256,7 @@ func (sh *SessionHandler) Serve() error { return } - go sh.handleHostStream(stream, sessionID, log) + go sh.handleHostStream(stream, log) } }() } @@ -289,7 +268,7 @@ func (sh *SessionHandler) LocalAddr() string { } // NewSessionHandler creates a new SessionHandler -func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, chain ChainManager, syncer Syncer, wallet Wallet, accounts AccountManager, contracts ContractManager, registry RegistryManager, storage StorageManager, settings SettingsReporter, opts ...SessionHandlerOption) (*SessionHandler, error) { +func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, chain ChainManager, syncer Syncer, wallet Wallet, accounts AccountManager, contracts ContractManager, registry RegistryManager, storage StorageManager, settings SettingsReporter, log *zap.Logger) (*SessionHandler, error) { sh := &SessionHandler{ privateKey: hostKey, @@ -305,15 +284,10 @@ func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, chain ChainMana settings: settings, storage: storage, - log: zap.NewNop(), - monitor: rhp.NewNoOpMonitor(), - sessions: noopSessionReporter{}, - tg: threadgroup.New(), + log: log, + tg: threadgroup.New(), priceTables: newPriceTableManager(), } - for _, opt := range opts { - opt(sh) - } return sh, nil } diff --git a/rhp/v3/rpc_test.go b/rhp/v3/rpc_test.go index ae4ef18e..0d170bd0 100644 --- a/rhp/v3/rpc_test.go +++ b/rhp/v3/rpc_test.go @@ -105,14 +105,14 @@ func setupRHP3Host(t *testing.T, node *testutil.HostNode, hostKey types.PrivateK t.Fatal(err) } - sh2, err := rhp2.NewSessionHandler(rhp2Listener, hostKey, rhp3Listener.Addr().String(), node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log.Named("rhp2"))) + sh2, err := rhp2.NewSessionHandler(rhp2Listener, hostKey, rhp3Listener.Addr().String(), node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, log.Named("rhp2")) if err != nil { t.Fatal(err) } t.Cleanup(func() { sh2.Close() }) go sh2.Serve() - sh3, err := rhp3.NewSessionHandler(rhp3Listener, hostKey, node.Chain, node.Syncer, node.Wallet, node.Accounts, node.Contracts, node.Registry, node.Volumes, node.Settings, rhp3.WithLog(log.Named("rhp3"))) + sh3, err := rhp3.NewSessionHandler(rhp3Listener, hostKey, node.Chain, node.Syncer, node.Wallet, node.Accounts, node.Contracts, node.Registry, node.Volumes, node.Settings, log.Named("rhp3")) if err != nil { t.Fatal(err) } diff --git a/rhp/v3/websockets.go b/rhp/v3/websockets.go deleted file mode 100644 index 9ea4dbe8..00000000 --- a/rhp/v3/websockets.go +++ /dev/null @@ -1,63 +0,0 @@ -package rhp - -import ( - "context" - "net/http" - - rhp3 "go.sia.tech/core/rhp/v3" - "go.sia.tech/hostd/rhp" - "go.uber.org/zap" - "nhooyr.io/websocket" -) - -// handleWebSockets handles websocket connections to the host. -func (sh *SessionHandler) handleWebSockets(w http.ResponseWriter, r *http.Request) { - log := sh.log.Named("websockets").With(zap.String("peerAddr", r.RemoteAddr)) - wsConn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: []string{"*"}, - }) - if err != nil { - log.Warn("failed to accept websocket connection", zap.Error(err)) - return - } - defer wsConn.Close(websocket.StatusNormalClosure, "") - - // wrap the websocket connection - conn := websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary) - defer conn.Close() - - // wrap the connection with a rate limiter - ingress, egress := sh.settings.BandwidthLimiters() - rhpConn := rhp.NewConn(conn, sh.monitor, ingress, egress) - defer rhpConn.Close() - - // initiate the session - sessionID, end := sh.sessions.StartSession(rhpConn, rhp.SessionProtocolWS, 3) - defer end() - - log = log.With(zap.String("sessionID", sessionID.String())) - - // upgrade the connection - t, err := rhp3.NewHostTransport(rhpConn, sh.privateKey) - if err != nil { - sh.log.Debug("failed to upgrade conn", zap.Error(err), zap.String("remoteAddress", conn.RemoteAddr().String())) - return - } - defer t.Close() - - for { - stream, err := t.AcceptStream() - if err != nil { - log.Debug("failed to accept stream", zap.Error(err)) - return - } - - go sh.handleHostStream(stream, sessionID, log) - } -} - -// WebSocketHandler returns an http.Handler that upgrades the connection to a -// WebSocket and then passes the connection to the RHP3 host transport. -func (sh *SessionHandler) WebSocketHandler() http.HandlerFunc { - return sh.handleWebSockets -}