diff --git a/arbnode/batch_poster.go b/arbnode/batch_poster.go index ee00cdc618..f1dcc91884 100644 --- a/arbnode/batch_poster.go +++ b/arbnode/batch_poster.go @@ -216,8 +216,9 @@ var DefaultBatchPosterConfig = BatchPosterConfig{ Enable: false, DisableDapFallbackStoreDataOnChain: false, // This default is overridden for L3 chains in applyChainParameters in cmd/nitro/nitro.go - MaxSize: 100000, - Max4844BatchSize: blobs.BlobEncodableData*(params.MaxBlobGasPerBlock/params.BlobTxBlobGasPerBlob) - 2000, + MaxSize: 100000, + // Try to fill 3 blobs per batch + Max4844BatchSize: blobs.BlobEncodableData*(params.MaxBlobGasPerBlock/params.BlobTxBlobGasPerBlob)/2 - 2000, PollInterval: time.Second * 10, ErrorDelay: time.Second * 10, MaxDelay: time.Hour, diff --git a/cmd/conf/init.go b/cmd/conf/init.go index 8a6c5096fb..2697a11111 100644 --- a/cmd/conf/init.go +++ b/cmd/conf/init.go @@ -45,7 +45,7 @@ var InitConfigDefault = InitConfig{ func InitConfigAddOptions(prefix string, f *pflag.FlagSet) { f.Bool(prefix+".force", InitConfigDefault.Force, "if true: in case database exists init code will be reexecuted and genesis block compared to database") - f.String(prefix+".url", InitConfigDefault.Url, "url to download initializtion data - will poll if download fails") + f.String(prefix+".url", InitConfigDefault.Url, "url to download initialization data - will poll if download fails") f.String(prefix+".download-path", InitConfigDefault.DownloadPath, "path to save temp downloaded file") f.Duration(prefix+".download-poll", InitConfigDefault.DownloadPoll, "how long to wait between polling attempts") f.Bool(prefix+".dev-init", InitConfigDefault.DevInit, "init with dev data (1 account with balance) instead of file import") diff --git a/cmd/daserver/daserver.go b/cmd/daserver/daserver.go index 8036487d26..1a3fd435b8 100644 --- a/cmd/daserver/daserver.go +++ b/cmd/daserver/daserver.go @@ -31,10 +31,11 @@ import ( ) type DAServerConfig struct { - EnableRPC bool `koanf:"enable-rpc"` - RPCAddr string `koanf:"rpc-addr"` - RPCPort uint64 `koanf:"rpc-port"` - RPCServerTimeouts genericconf.HTTPServerTimeoutConfig `koanf:"rpc-server-timeouts"` + EnableRPC bool `koanf:"enable-rpc"` + RPCAddr string `koanf:"rpc-addr"` + RPCPort uint64 `koanf:"rpc-port"` + RPCServerTimeouts genericconf.HTTPServerTimeoutConfig `koanf:"rpc-server-timeouts"` + RPCServerBodyLimit int `koanf:"rpc-server-body-limit"` EnableREST bool `koanf:"enable-rest"` RESTAddr string `koanf:"rest-addr"` @@ -58,6 +59,7 @@ var DefaultDAServerConfig = DAServerConfig{ RPCAddr: "localhost", RPCPort: 9876, RPCServerTimeouts: genericconf.HTTPServerTimeoutConfigDefault, + RPCServerBodyLimit: genericconf.HTTPServerBodyLimitDefault, EnableREST: false, RESTAddr: "localhost", RESTPort: 9877, @@ -88,6 +90,7 @@ func parseDAServer(args []string) (*DAServerConfig, error) { f.Bool("enable-rpc", DefaultDAServerConfig.EnableRPC, "enable the HTTP-RPC server listening on rpc-addr and rpc-port") f.String("rpc-addr", DefaultDAServerConfig.RPCAddr, "HTTP-RPC server listening interface") f.Uint64("rpc-port", DefaultDAServerConfig.RPCPort, "HTTP-RPC server listening port") + f.Int("rpc-server-body-limit", DefaultDAServerConfig.RPCServerBodyLimit, "HTTP-RPC server maximum request body size in bytes; the default (0) uses geth's 5MB limit") genericconf.HTTPServerTimeoutConfigAddOptions("rpc-server-timeouts", f) f.Bool("enable-rest", DefaultDAServerConfig.EnableREST, "enable the REST server listening on rest-addr and rest-port") @@ -250,7 +253,7 @@ func startup() error { if serverConfig.EnableRPC { log.Info("Starting HTTP-RPC server", "addr", serverConfig.RPCAddr, "port", serverConfig.RPCPort, "revision", vcsRevision, "vcs.time", vcsTime) - rpcServer, err = das.StartDASRPCServer(ctx, serverConfig.RPCAddr, serverConfig.RPCPort, serverConfig.RPCServerTimeouts, daReader, daWriter, daHealthChecker) + rpcServer, err = das.StartDASRPCServer(ctx, serverConfig.RPCAddr, serverConfig.RPCPort, serverConfig.RPCServerTimeouts, serverConfig.RPCServerBodyLimit, daReader, daWriter, daHealthChecker) if err != nil { return err } diff --git a/cmd/genericconf/server.go b/cmd/genericconf/server.go index 7550791d6d..9b8acd5f71 100644 --- a/cmd/genericconf/server.go +++ b/cmd/genericconf/server.go @@ -48,6 +48,8 @@ var HTTPServerTimeoutConfigDefault = HTTPServerTimeoutConfig{ IdleTimeout: 120 * time.Second, } +const HTTPServerBodyLimitDefault = 0 // Use default from go-ethereum + func (c HTTPConfig) Apply(stackConf *node.Config) { stackConf.HTTPHost = c.Addr stackConf.HTTPPort = c.Port diff --git a/cmd/nitro/init.go b/cmd/nitro/init.go index 6305c41155..8581a919c8 100644 --- a/cmd/nitro/init.go +++ b/cmd/nitro/init.go @@ -5,10 +5,14 @@ package main import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" + "io" "math/big" + "net/http" "os" "runtime" "strings" @@ -40,6 +44,8 @@ import ( "github.com/offchainlabs/nitro/util/arbmath" ) +var notFoundError = errors.New("file not found") + func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, error) { if initConfig.Url == "" { return "", nil @@ -66,18 +72,30 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err } return initFile, nil } - grabclient := grab.NewClient() log.Info("Downloading initial database", "url", initConfig.Url) - fmt.Println() + path, err := downloadFile(ctx, initConfig, initConfig.Url) + if errors.Is(err, notFoundError) { + return downloadInitInParts(ctx, initConfig) + } + return path, err +} + +func downloadFile(ctx context.Context, initConfig *conf.InitConfig, url string) (string, error) { + checksum, err := fetchChecksum(ctx, url+".sha256") + if err != nil { + return "", fmt.Errorf("error fetching checksum: %w", err) + } + grabclient := grab.NewClient() printTicker := time.NewTicker(time.Second) defer printTicker.Stop() attempt := 0 for { attempt++ - req, err := grab.NewRequest(initConfig.DownloadPath, initConfig.Url) + req, err := grab.NewRequest(initConfig.DownloadPath, url) if err != nil { panic(err) } + req.SetChecksum(sha256.New(), checksum, false) resp := grabclient.Do(req.WithContext(ctx)) firstPrintTime := time.Now().Add(time.Second * 2) updateLoop: @@ -102,6 +120,9 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err } case <-resp.Done: if err := resp.Err(); err != nil { + if resp.HTTPResponse.StatusCode == http.StatusNotFound { + return "", fmt.Errorf("file not found but checksum exists") + } fmt.Printf("\n attempt %d failed: %v\n", attempt, err) break updateLoop } @@ -121,6 +142,99 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err } } +// fetchChecksum performs a GET request to the specified URL using the provided context +// and returns the checksum as a []byte +func fetchChecksum(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making GET request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return nil, notFoundError + } else if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %v", resp.Status) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + checksumStr := strings.TrimSpace(string(body)) + checksum, err := hex.DecodeString(checksumStr) + if err != nil { + return nil, fmt.Errorf("error decoding checksum: %w", err) + } + if len(checksum) != sha256.Size { + return nil, fmt.Errorf("invalid checksum length") + } + return checksum, nil +} + +func downloadInitInParts(ctx context.Context, initConfig *conf.InitConfig) (string, error) { + log.Info("File not found; trying to download database in parts") + fileInfo, err := os.Stat(initConfig.DownloadPath) + if err != nil || !fileInfo.IsDir() { + return "", fmt.Errorf("download path must be a directory: %v", initConfig.DownloadPath) + } + part := 0 + parts := []string{} + defer func() { + // remove all temporary files. + for _, part := range parts { + err := os.Remove(part) + if err != nil { + log.Warn("Failed to remove temporary file", "file", part) + } + } + }() + for { + url := fmt.Sprintf("%s.part%d", initConfig.Url, part) + log.Info("Downloading database part", "url", url) + partFile, err := downloadFile(ctx, initConfig, url) + if errors.Is(err, notFoundError) { + log.Info("Part not found; concatenating archive into single file", "numParts", len(parts)) + break + } else if err != nil { + return "", err + } + parts = append(parts, partFile) + part++ + } + return joinArchive(parts) +} + +// joinArchive joins the archive parts into a single file and return its path. +func joinArchive(parts []string) (string, error) { + if len(parts) == 0 { + return "", fmt.Errorf("no database parts found") + } + archivePath := strings.TrimSuffix(parts[0], ".part0") + archive, err := os.Create(archivePath) + if err != nil { + return "", fmt.Errorf("failed to create archive: %w", err) + } + defer archive.Close() + for _, part := range parts { + partFile, err := os.Open(part) + if err != nil { + return "", fmt.Errorf("failed to open part file %s: %w", part, err) + } + defer partFile.Close() + _, err = io.Copy(archive, partFile) + if err != nil { + return "", fmt.Errorf("failed to copy part file %s: %w", part, err) + } + log.Info("Joined database part into archive", "part", part) + } + log.Info("Successfully joined parts into archive", "archive", archivePath) + return archivePath, nil +} + func validateBlockChain(blockChain *core.BlockChain, chainConfig *params.ChainConfig) error { statedb, err := blockChain.State() if err != nil { diff --git a/cmd/nitro/init_test.go b/cmd/nitro/init_test.go new file mode 100644 index 0000000000..a621a6669d --- /dev/null +++ b/cmd/nitro/init_test.go @@ -0,0 +1,142 @@ +// Copyright 2021-2022, Offchain Labs, Inc. +// For license information, see https://github.com/nitro/blob/master/LICENSE + +package main + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "net" + "net/http" + "os" + "testing" + "time" + + "github.com/offchainlabs/nitro/cmd/conf" + "github.com/offchainlabs/nitro/util/testhelpers" +) + +func TestDownloadInit(t *testing.T) { + const ( + archiveName = "random_data.tar.gz" + dataSize = 1024 * 1024 + filePerm = 0600 + ) + + // Create archive with random data + serverDir := t.TempDir() + data := testhelpers.RandomSlice(dataSize) + checksumBytes := sha256.Sum256(data) + checksum := hex.EncodeToString(checksumBytes[:]) + + // Write archive file + archiveFile := fmt.Sprintf("%s/%s", serverDir, archiveName) + err := os.WriteFile(archiveFile, data, filePerm) + Require(t, err, "failed to write archive") + + // Write checksum file + checksumFile := archiveFile + ".sha256" + err = os.WriteFile(checksumFile, []byte(checksum), filePerm) + Require(t, err, "failed to write checksum") + + // Start HTTP server + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + addr := startFileServer(t, ctx, serverDir) + + // Download file + initConfig := conf.InitConfigDefault + initConfig.Url = fmt.Sprintf("http://%s/%s", addr, archiveName) + initConfig.DownloadPath = t.TempDir() + receivedArchive, err := downloadInit(ctx, &initConfig) + Require(t, err, "failed to download") + + // Check archive contents + receivedData, err := os.ReadFile(receivedArchive) + Require(t, err, "failed to read received archive") + if !bytes.Equal(receivedData, data) { + t.Error("downloaded archive is different from generated one") + } +} + +func TestDownloadInitInParts(t *testing.T) { + const ( + archiveName = "random_data.tar.gz" + numParts = 3 + partSize = 1024 * 1024 + dataSize = numParts * partSize + filePerm = 0600 + ) + + // Create parts with random data + serverDir := t.TempDir() + data := testhelpers.RandomSlice(dataSize) + for i := 0; i < numParts; i++ { + // Create part and checksum + partData := data[partSize*i : partSize*(i+1)] + checksumBytes := sha256.Sum256(partData) + checksum := hex.EncodeToString(checksumBytes[:]) + // Write part file + partFile := fmt.Sprintf("%s/%s.part%d", serverDir, archiveName, i) + err := os.WriteFile(partFile, partData, filePerm) + Require(t, err, "failed to write part") + // Write checksum file + checksumFile := partFile + ".sha256" + err = os.WriteFile(checksumFile, []byte(checksum), filePerm) + Require(t, err, "failed to write checksum") + } + + // Start HTTP server + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + addr := startFileServer(t, ctx, serverDir) + + // Download file + initConfig := conf.InitConfigDefault + initConfig.Url = fmt.Sprintf("http://%s/%s", addr, archiveName) + initConfig.DownloadPath = t.TempDir() + receivedArchive, err := downloadInit(ctx, &initConfig) + Require(t, err, "failed to download") + + // check database contents + receivedData, err := os.ReadFile(receivedArchive) + Require(t, err, "failed to read received archive") + if !bytes.Equal(receivedData, data) { + t.Error("downloaded archive is different from generated one") + } + + // Check if the function deleted the temporary files + entries, err := os.ReadDir(initConfig.DownloadPath) + Require(t, err, "failed to read temp dir") + if len(entries) != 1 { + t.Error("download function did not delete temp files") + } +} + +func startFileServer(t *testing.T, ctx context.Context, dir string) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + Require(t, err, "failed to listen") + addr := ln.Addr().String() + server := &http.Server{ + Addr: addr, + Handler: http.FileServer(http.Dir(dir)), + ReadHeaderTimeout: time.Second, + } + go func() { + err := server.Serve(ln) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Error("failed to shutdown server") + } + }() + go func() { + <-ctx.Done() + err := server.Shutdown(ctx) + Require(t, err, "failed to shutdown server") + }() + return addr +} diff --git a/das/dasRpcServer.go b/das/dasRpcServer.go index 2f1fc1fd42..03f755b90e 100644 --- a/das/dasRpcServer.go +++ b/das/dasRpcServer.go @@ -36,19 +36,22 @@ type DASRPCServer struct { daHealthChecker DataAvailabilityServiceHealthChecker } -func StartDASRPCServer(ctx context.Context, addr string, portNum uint64, rpcServerTimeouts genericconf.HTTPServerTimeoutConfig, daReader DataAvailabilityServiceReader, daWriter DataAvailabilityServiceWriter, daHealthChecker DataAvailabilityServiceHealthChecker) (*http.Server, error) { +func StartDASRPCServer(ctx context.Context, addr string, portNum uint64, rpcServerTimeouts genericconf.HTTPServerTimeoutConfig, rpcServerBodyLimit int, daReader DataAvailabilityServiceReader, daWriter DataAvailabilityServiceWriter, daHealthChecker DataAvailabilityServiceHealthChecker) (*http.Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, portNum)) if err != nil { return nil, err } - return StartDASRPCServerOnListener(ctx, listener, rpcServerTimeouts, daReader, daWriter, daHealthChecker) + return StartDASRPCServerOnListener(ctx, listener, rpcServerTimeouts, rpcServerBodyLimit, daReader, daWriter, daHealthChecker) } -func StartDASRPCServerOnListener(ctx context.Context, listener net.Listener, rpcServerTimeouts genericconf.HTTPServerTimeoutConfig, daReader DataAvailabilityServiceReader, daWriter DataAvailabilityServiceWriter, daHealthChecker DataAvailabilityServiceHealthChecker) (*http.Server, error) { +func StartDASRPCServerOnListener(ctx context.Context, listener net.Listener, rpcServerTimeouts genericconf.HTTPServerTimeoutConfig, rpcServerBodyLimit int, daReader DataAvailabilityServiceReader, daWriter DataAvailabilityServiceWriter, daHealthChecker DataAvailabilityServiceHealthChecker) (*http.Server, error) { if daWriter == nil { return nil, errors.New("No writer backend was configured for DAS RPC server. Has the BLS signing key been set up (--data-availability.key.key-dir or --data-availability.key.priv-key options)?") } rpcServer := rpc.NewServer() + if rpcServerBodyLimit > 0 { + rpcServer.SetHTTPBodyLimit(rpcServerBodyLimit) + } err := rpcServer.RegisterName("das", &DASRPCServer{ daReader: daReader, daWriter: daWriter, diff --git a/das/rpc_test.go b/das/rpc_test.go index 044ba597be..658592cc0b 100644 --- a/das/rpc_test.go +++ b/das/rpc_test.go @@ -55,7 +55,7 @@ func TestRPC(t *testing.T) { testhelpers.RequireImpl(t, err) localDas, err := NewSignAfterStoreDASWriterWithSeqInboxCaller(privKey, nil, storageService, "") testhelpers.RequireImpl(t, err) - dasServer, err := StartDASRPCServerOnListener(ctx, lis, genericconf.HTTPServerTimeoutConfigDefault, storageService, localDas, storageService) + dasServer, err := StartDASRPCServerOnListener(ctx, lis, genericconf.HTTPServerTimeoutConfigDefault, genericconf.HTTPServerBodyLimitDefault, storageService, localDas, storageService) defer func() { if err := dasServer.Shutdown(ctx); err != nil { panic(err) diff --git a/pubsub/common.go b/pubsub/common.go new file mode 100644 index 0000000000..9f05304e46 --- /dev/null +++ b/pubsub/common.go @@ -0,0 +1,29 @@ +package pubsub + +import ( + "context" + + "github.com/ethereum/go-ethereum/log" + "github.com/go-redis/redis/v8" +) + +// CreateStream tries to create stream with given name, if it already exists +// does not return an error. +func CreateStream(ctx context.Context, streamName string, client redis.UniversalClient) error { + _, err := client.XGroupCreateMkStream(ctx, streamName, streamName, "$").Result() + if err != nil && !StreamExists(ctx, streamName, client) { + return err + } + return nil +} + +// StreamExists returns whether there are any consumer group for specified +// redis stream. +func StreamExists(ctx context.Context, streamName string, client redis.UniversalClient) bool { + got, err := client.Do(ctx, "XINFO", "STREAM", streamName).Result() + if err != nil { + log.Error("Reading redis streams", "error", err) + return false + } + return got != nil +} diff --git a/pubsub/consumer.go b/pubsub/consumer.go index c9590de8e6..df3695606d 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -86,6 +86,14 @@ func heartBeatKey(id string) string { return fmt.Sprintf("consumer:%s:heartbeat", id) } +func (c *Consumer[Request, Response]) RedisClient() redis.UniversalClient { + return c.client +} + +func (c *Consumer[Request, Response]) StreamName() string { + return c.redisStream +} + func (c *Consumer[Request, Response]) heartBeatKey() string { return heartBeatKey(c.id) } diff --git a/staker/block_validator.go b/staker/block_validator.go index 5a511920f2..0fea05469f 100644 --- a/staker/block_validator.go +++ b/staker/block_validator.go @@ -1089,7 +1089,7 @@ func (v *BlockValidator) Initialize(ctx context.Context) error { } // First spawner is always RedisValidationClient if RedisStreams are enabled. if v.redisValidator != nil { - err := v.redisValidator.Initialize(moduleRoots) + err := v.redisValidator.Initialize(ctx, moduleRoots) if err != nil { return err } diff --git a/system_tests/block_validator_test.go b/system_tests/block_validator_test.go index debd6d4c7c..54046edf15 100644 --- a/system_tests/block_validator_test.go +++ b/system_tests/block_validator_test.go @@ -74,6 +74,8 @@ func testBlockValidatorSimple(t *testing.T, dasModeString string, workloadLoops redisURL = redisutil.CreateTestRedis(ctx, t) validatorConfig.BlockValidator.RedisValidationClientConfig = redis.TestValidationClientConfig validatorConfig.BlockValidator.RedisValidationClientConfig.RedisURL = redisURL + } else { + validatorConfig.BlockValidator.RedisValidationClientConfig = redis.ValidationClientConfig{} } AddDefaultValNode(t, ctx, validatorConfig, !arbitrator, redisURL) diff --git a/system_tests/common_test.go b/system_tests/common_test.go index 4e82f9f870..5e1aca3ec2 100644 --- a/system_tests/common_test.go +++ b/system_tests/common_test.go @@ -1128,7 +1128,7 @@ func setupConfigWithDAS( Require(t, err) restLis, err := net.Listen("tcp", "localhost:0") Require(t, err) - _, err = das.StartDASRPCServerOnListener(ctx, rpcLis, genericconf.HTTPServerTimeoutConfigDefault, daReader, daWriter, daHealthChecker) + _, err = das.StartDASRPCServerOnListener(ctx, rpcLis, genericconf.HTTPServerTimeoutConfigDefault, genericconf.HTTPServerBodyLimitDefault, daReader, daWriter, daHealthChecker) Require(t, err) _, err = das.NewRestfulDasServerOnListener(restLis, genericconf.HTTPServerTimeoutConfigDefault, daReader, daHealthChecker) Require(t, err) diff --git a/system_tests/das_test.go b/system_tests/das_test.go index 11d887315a..d15c5e6ba6 100644 --- a/system_tests/das_test.go +++ b/system_tests/das_test.go @@ -75,7 +75,7 @@ func startLocalDASServer( Require(t, err) rpcLis, err := net.Listen("tcp", "localhost:0") Require(t, err) - rpcServer, err := das.StartDASRPCServerOnListener(ctx, rpcLis, genericconf.HTTPServerTimeoutConfigDefault, storageService, daWriter, storageService) + rpcServer, err := das.StartDASRPCServerOnListener(ctx, rpcLis, genericconf.HTTPServerTimeoutConfigDefault, genericconf.HTTPServerBodyLimitDefault, storageService, daWriter, storageService) Require(t, err) restLis, err := net.Listen("tcp", "localhost:0") Require(t, err) @@ -284,7 +284,7 @@ func TestDASComplexConfigAndRestMirror(t *testing.T) { defer lifecycleManager.StopAndWaitUntil(time.Second) rpcLis, err := net.Listen("tcp", "localhost:0") Require(t, err) - _, err = das.StartDASRPCServerOnListener(ctx, rpcLis, genericconf.HTTPServerTimeoutConfigDefault, daReader, daWriter, daHealthChecker) + _, err = das.StartDASRPCServerOnListener(ctx, rpcLis, genericconf.HTTPServerTimeoutConfigDefault, genericconf.HTTPServerBodyLimitDefault, daReader, daWriter, daHealthChecker) Require(t, err) restLis, err := net.Listen("tcp", "localhost:0") Require(t, err) diff --git a/validator/client/redis/producer.go b/validator/client/redis/producer.go index 1055d93968..41ae100954 100644 --- a/validator/client/redis/producer.go +++ b/validator/client/redis/producer.go @@ -23,6 +23,7 @@ type ValidationClientConfig struct { Room int32 `koanf:"room"` RedisURL string `koanf:"redis-url"` ProducerConfig pubsub.ProducerConfig `koanf:"producer-config"` + CreateStreams bool `koanf:"create-streams"` } func (c ValidationClientConfig) Enabled() bool { @@ -34,6 +35,7 @@ var DefaultValidationClientConfig = ValidationClientConfig{ Room: 2, RedisURL: "", ProducerConfig: pubsub.DefaultProducerConfig, + CreateStreams: true, } var TestValidationClientConfig = ValidationClientConfig{ @@ -41,12 +43,14 @@ var TestValidationClientConfig = ValidationClientConfig{ Room: 2, RedisURL: "", ProducerConfig: pubsub.TestProducerConfig, + CreateStreams: false, } func ValidationClientConfigAddOptions(prefix string, f *pflag.FlagSet) { f.String(prefix+".name", DefaultValidationClientConfig.Name, "validation client name") f.Int32(prefix+".room", DefaultValidationClientConfig.Room, "validation client room") pubsub.ProducerAddConfigAddOptions(prefix+".producer-config", f) + f.Bool(prefix+".create-streams", DefaultValidationClientConfig.CreateStreams, "create redis streams if it does not exist") } // ValidationClient implements validation client through redis streams. @@ -59,6 +63,7 @@ type ValidationClient struct { producerConfig pubsub.ProducerConfig redisClient redis.UniversalClient moduleRoots []common.Hash + createStreams bool } func NewValidationClient(cfg *ValidationClientConfig) (*ValidationClient, error) { @@ -75,11 +80,17 @@ func NewValidationClient(cfg *ValidationClientConfig) (*ValidationClient, error) producers: make(map[common.Hash]*pubsub.Producer[*validator.ValidationInput, validator.GoGlobalState]), producerConfig: cfg.ProducerConfig, redisClient: redisClient, + createStreams: cfg.CreateStreams, }, nil } -func (c *ValidationClient) Initialize(moduleRoots []common.Hash) error { +func (c *ValidationClient) Initialize(ctx context.Context, moduleRoots []common.Hash) error { for _, mr := range moduleRoots { + if c.createStreams { + if err := pubsub.CreateStream(ctx, server_api.RedisStreamForRoot(mr), c.redisClient); err != nil { + return fmt.Errorf("creating redis stream: %w", err) + } + } if _, exists := c.producers[mr]; exists { log.Warn("Producer already existsw for module root", "hash", mr) continue diff --git a/validator/server_common/machine_locator.go b/validator/server_common/machine_locator.go index 28093c30f0..71f6af60b6 100644 --- a/validator/server_common/machine_locator.go +++ b/validator/server_common/machine_locator.go @@ -58,7 +58,7 @@ func NewMachineLocator(rootPath string) (*MachineLocator, error) { for _, dir := range dirs { fInfo, err := os.Stat(dir) if err != nil { - log.Warn("Getting file info", "error", err) + log.Warn("Getting file info", "dir", dir, "error", err) continue } if !fInfo.IsDir() { diff --git a/validator/valnode/redis/consumer.go b/validator/valnode/redis/consumer.go index 1cadaf7c9a..3569e78b5c 100644 --- a/validator/valnode/redis/consumer.go +++ b/validator/valnode/redis/consumer.go @@ -22,7 +22,8 @@ type ValidationServer struct { spawner validator.ValidationSpawner // consumers stores moduleRoot to consumer mapping. - consumers map[common.Hash]*pubsub.Consumer[*validator.ValidationInput, validator.GoGlobalState] + consumers map[common.Hash]*pubsub.Consumer[*validator.ValidationInput, validator.GoGlobalState] + streamTimeout time.Duration } func NewValidationServer(cfg *ValidationServerConfig, spawner validator.ValidationSpawner) (*ValidationServer, error) { @@ -43,39 +44,83 @@ func NewValidationServer(cfg *ValidationServerConfig, spawner validator.Validati consumers[mr] = c } return &ValidationServer{ - consumers: consumers, - spawner: spawner, + consumers: consumers, + spawner: spawner, + streamTimeout: cfg.StreamTimeout, }, nil } func (s *ValidationServer) Start(ctx_in context.Context) { s.StopWaiter.Start(ctx_in, s) + // Channel that all consumers use to indicate their readiness. + readyStreams := make(chan struct{}, len(s.consumers)) for moduleRoot, c := range s.consumers { c := c + moduleRoot := moduleRoot c.Start(ctx_in) - s.StopWaiter.CallIteratively(func(ctx context.Context) time.Duration { - req, err := c.Consume(ctx) - if err != nil { - log.Error("Consuming request", "error", err) - return 0 + // Channel for single consumer, once readiness is indicated in this, + // consumer will start consuming iteratively. + ready := make(chan struct{}, 1) + s.StopWaiter.LaunchThread(func(ctx context.Context) { + for { + if pubsub.StreamExists(ctx, c.StreamName(), c.RedisClient()) { + ready <- struct{}{} + readyStreams <- struct{}{} + return + } + select { + case <-ctx.Done(): + log.Info("Context done", "error", ctx.Err().Error()) + return + case <-time.After(time.Millisecond * 100): + } } - if req == nil { - // There's nothing in the queue. - return time.Second - } - valRun := s.spawner.Launch(req.Value, moduleRoot) - res, err := valRun.Await(ctx) - if err != nil { - log.Error("Error validating", "request value", req.Value, "error", err) - return 0 - } - if err := c.SetResult(ctx, req.ID, res); err != nil { - log.Error("Error setting result for request", "id", req.ID, "result", res, "error", err) - return 0 + }) + s.StopWaiter.LaunchThread(func(ctx context.Context) { + select { + case <-ctx.Done(): + log.Info("Context done", "error", ctx.Err().Error()) + return + case <-ready: // Wait until the stream exists and start consuming iteratively. } - return time.Second + s.StopWaiter.CallIteratively(func(ctx context.Context) time.Duration { + req, err := c.Consume(ctx) + if err != nil { + log.Error("Consuming request", "error", err) + return 0 + } + if req == nil { + // There's nothing in the queue. + return time.Second + } + valRun := s.spawner.Launch(req.Value, moduleRoot) + res, err := valRun.Await(ctx) + if err != nil { + log.Error("Error validating", "request value", req.Value, "error", err) + return 0 + } + if err := c.SetResult(ctx, req.ID, res); err != nil { + log.Error("Error setting result for request", "id", req.ID, "result", res, "error", err) + return 0 + } + return time.Second + }) }) } + s.StopWaiter.LaunchThread(func(ctx context.Context) { + for { + select { + case <-readyStreams: + log.Trace("At least one stream is ready") + return // Don't block Start if at least one of the stream is ready. + case <-time.After(s.streamTimeout): + log.Error("Waiting for redis streams timed out") + case <-ctx.Done(): + log.Info(("Context expired, failed to start")) + return + } + } + }) } type ValidationServerConfig struct { @@ -83,23 +128,28 @@ type ValidationServerConfig struct { ConsumerConfig pubsub.ConsumerConfig `koanf:"consumer-config"` // Supported wasm module roots. ModuleRoots []string `koanf:"module-roots"` + // Timeout on polling for existence of each redis stream. + StreamTimeout time.Duration `koanf:"stream-timeout"` } var DefaultValidationServerConfig = ValidationServerConfig{ RedisURL: "", ConsumerConfig: pubsub.DefaultConsumerConfig, ModuleRoots: []string{}, + StreamTimeout: 10 * time.Minute, } var TestValidationServerConfig = ValidationServerConfig{ RedisURL: "", ConsumerConfig: pubsub.TestConsumerConfig, ModuleRoots: []string{}, + StreamTimeout: time.Minute, } func ValidationServerConfigAddOptions(prefix string, f *pflag.FlagSet) { pubsub.ConsumerConfigAddOptions(prefix+".consumer-config", f) f.StringSlice(prefix+".module-roots", nil, "Supported module root hashes") + f.Duration(prefix+".stream-timeout", DefaultValidationServerConfig.StreamTimeout, "Timeout on polling for existence of redis streams") } func (cfg *ValidationServerConfig) Enabled() bool {