Skip to content

Commit

Permalink
Merge pull request #2345 from OffchainLabs/gligneul/split-db-snapshot
Browse files Browse the repository at this point in the history
[NIT-2505] Download DB snapshot that is split between multiple files
  • Loading branch information
gligneul authored May 28, 2024
2 parents 879c14a + c5f7c87 commit a400236
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cmd/conf/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var InitConfigDefault = InitConfig{

func InitConfigAddOptions(prefix string, f *pflag.FlagSet) {
f.Bool(prefix+".force", InitConfigDefault.Force, "if true: in case database exists init code will be reexecuted and genesis block compared to database")
f.String(prefix+".url", InitConfigDefault.Url, "url to download initializtion data - will poll if download fails")
f.String(prefix+".url", InitConfigDefault.Url, "url to download initialization data - will poll if download fails")
f.String(prefix+".download-path", InitConfigDefault.DownloadPath, "path to save temp downloaded file")
f.Duration(prefix+".download-poll", InitConfigDefault.DownloadPoll, "how long to wait between polling attempts")
f.Bool(prefix+".dev-init", InitConfigDefault.DevInit, "init with dev data (1 account with balance) instead of file import")
Expand Down
120 changes: 117 additions & 3 deletions cmd/nitro/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ package main

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"os"
"runtime"
"strings"
Expand Down Expand Up @@ -40,6 +44,8 @@ import (
"github.com/offchainlabs/nitro/util/arbmath"
)

var notFoundError = errors.New("file not found")

func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, error) {
if initConfig.Url == "" {
return "", nil
Expand All @@ -66,18 +72,30 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err
}
return initFile, nil
}
grabclient := grab.NewClient()
log.Info("Downloading initial database", "url", initConfig.Url)
fmt.Println()
path, err := downloadFile(ctx, initConfig, initConfig.Url)
if errors.Is(err, notFoundError) {
return downloadInitInParts(ctx, initConfig)
}
return path, err
}

func downloadFile(ctx context.Context, initConfig *conf.InitConfig, url string) (string, error) {
checksum, err := fetchChecksum(ctx, url+".sha256")
if err != nil {
return "", fmt.Errorf("error fetching checksum: %w", err)
}
grabclient := grab.NewClient()
printTicker := time.NewTicker(time.Second)
defer printTicker.Stop()
attempt := 0
for {
attempt++
req, err := grab.NewRequest(initConfig.DownloadPath, initConfig.Url)
req, err := grab.NewRequest(initConfig.DownloadPath, url)
if err != nil {
panic(err)
}
req.SetChecksum(sha256.New(), checksum, false)
resp := grabclient.Do(req.WithContext(ctx))
firstPrintTime := time.Now().Add(time.Second * 2)
updateLoop:
Expand All @@ -102,6 +120,9 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err
}
case <-resp.Done:
if err := resp.Err(); err != nil {
if resp.HTTPResponse.StatusCode == http.StatusNotFound {
return "", fmt.Errorf("file not found but checksum exists")
}
fmt.Printf("\n attempt %d failed: %v\n", attempt, err)
break updateLoop
}
Expand All @@ -121,6 +142,99 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err
}
}

// fetchChecksum downloads the SHA-256 checksum published at url and returns
// it as raw bytes.
//
// The response body must start with the hex-encoded digest. Both a bare
// digest and the "<digest>  <filename>" line layout produced by the sha256sum
// tool are accepted: only the first whitespace-separated field is decoded.
//
// It returns notFoundError when the server answers 404, so callers can tell a
// missing checksum apart from other failures. Any other non-200 status, a
// transport failure, malformed hex, or a digest whose length differs from
// sha256.Size yields a descriptive error. The request is bound to ctx.
func fetchChecksum(ctx context.Context, url string) ([]byte, error) {
	// A checksum file is a single short line; cap the read so a misbehaving
	// server cannot make us buffer an arbitrarily large body.
	const maxChecksumBodySize = 4096

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making GET request: %w", err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		// Proceed to read the digest below.
	case http.StatusNotFound:
		return nil, notFoundError
	default:
		return nil, fmt.Errorf("unexpected status code: %v", resp.Status)
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxChecksumBodySize))
	if err != nil {
		return nil, fmt.Errorf("error reading response body: %w", err)
	}

	// An empty body leaves checksumStr empty; that decodes to zero bytes and
	// is rejected by the length check below.
	var checksumStr string
	if fields := strings.Fields(string(body)); len(fields) > 0 {
		checksumStr = fields[0]
	}
	checksum, err := hex.DecodeString(checksumStr)
	if err != nil {
		return nil, fmt.Errorf("error decoding checksum: %w", err)
	}
	if len(checksum) != sha256.Size {
		return nil, fmt.Errorf("invalid checksum length")
	}
	return checksum, nil
}

// downloadInitInParts downloads a database snapshot that is published as a
// sequence of files named "<initConfig.Url>.part0", "<initConfig.Url>.part1",
// and so on, then joins them into one archive and returns the archive's path.
//
// initConfig.DownloadPath must be an existing directory. Parts are fetched in
// increasing order until downloadFile reports notFoundError for a part, which
// marks the end of the sequence. Any other download error aborts the whole
// operation. The downloaded part files are removed before returning, on both
// success and failure.
func downloadInitInParts(ctx context.Context, initConfig *conf.InitConfig) (string, error) {
	log.Info("File not found; trying to download database in parts")
	fileInfo, err := os.Stat(initConfig.DownloadPath)
	if err != nil {
		// Keep the underlying cause (missing path, permissions, ...) instead
		// of collapsing it into the generic "must be a directory" message.
		return "", fmt.Errorf("download path must be a directory: %v: %w", initConfig.DownloadPath, err)
	}
	if !fileInfo.IsDir() {
		return "", fmt.Errorf("download path must be a directory: %v", initConfig.DownloadPath)
	}

	var parts []string
	defer func() {
		// Remove all temporary part files; failing to do so is not fatal.
		for _, partFile := range parts {
			if err := os.Remove(partFile); err != nil {
				log.Warn("Failed to remove temporary file", "file", partFile, "err", err)
			}
		}
	}()

	for part := 0; ; part++ {
		url := fmt.Sprintf("%s.part%d", initConfig.Url, part)
		log.Info("Downloading database part", "url", url)
		partFile, err := downloadFile(ctx, initConfig, url)
		if errors.Is(err, notFoundError) {
			log.Info("Part not found; concatenating archive into single file", "numParts", len(parts))
			break
		}
		if err != nil {
			return "", err
		}
		parts = append(parts, partFile)
	}
	return joinArchive(parts)
}

// joinArchive concatenates the given part files, in order, into a single
// archive file and returns the archive's path.
//
// The archive path is the first part's path with its ".part0" suffix removed.
// An error is returned when parts is empty, when the first part does not carry
// that suffix, or when creating, writing or closing the archive or reading a
// part fails. The part files themselves are left in place.
func joinArchive(parts []string) (string, error) {
	if len(parts) == 0 {
		return "", fmt.Errorf("no database parts found")
	}
	// Without the suffix the archive path would equal the first part's path,
	// and creating the archive would truncate that part before it is copied.
	if !strings.HasSuffix(parts[0], ".part0") {
		return "", fmt.Errorf("first part file %s does not have the .part0 suffix", parts[0])
	}
	archivePath := strings.TrimSuffix(parts[0], ".part0")
	archive, err := os.Create(archivePath)
	if err != nil {
		return "", fmt.Errorf("failed to create archive: %w", err)
	}
	for _, part := range parts {
		if err := appendPartToArchive(archive, part); err != nil {
			// The copy error is the one worth reporting; a close failure on
			// an already-broken archive adds nothing.
			_ = archive.Close()
			return "", err
		}
		log.Info("Joined database part into archive", "part", part)
	}
	// Close explicitly: a failure to flush the archive to disk must not be
	// reported as success.
	if err := archive.Close(); err != nil {
		return "", fmt.Errorf("failed to close archive: %w", err)
	}
	log.Info("Successfully joined parts into archive", "archive", archivePath)
	return archivePath, nil
}

// appendPartToArchive streams the contents of the file at part into archive.
// The part file is closed before returning, so joining many parts never holds
// more than one part open at a time.
func appendPartToArchive(archive io.Writer, part string) error {
	partFile, err := os.Open(part)
	if err != nil {
		return fmt.Errorf("failed to open part file %s: %w", part, err)
	}
	defer partFile.Close()
	if _, err := io.Copy(archive, partFile); err != nil {
		return fmt.Errorf("failed to copy part file %s: %w", part, err)
	}
	return nil
}

func validateBlockChain(blockChain *core.BlockChain, chainConfig *params.ChainConfig) error {
statedb, err := blockChain.State()
if err != nil {
Expand Down
142 changes: 142 additions & 0 deletions cmd/nitro/init_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2021-2022, Offchain Labs, Inc.
// For license information, see https://github.com/nitro/blob/master/LICENSE

package main

import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"os"
"testing"
"time"

"github.com/offchainlabs/nitro/cmd/conf"
"github.com/offchainlabs/nitro/util/testhelpers"
)

// TestDownloadInit publishes one archive of random bytes, together with its
// ".sha256" companion file, through a local HTTP file server and checks that
// downloadInit retrieves a file with exactly the same contents.
func TestDownloadInit(t *testing.T) {
	const (
		archiveName = "random_data.tar.gz"
		dataSize    = 1024 * 1024
		filePerm    = 0600
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Generate the payload and its hex-encoded SHA-256 digest.
	serverDir := t.TempDir()
	data := testhelpers.RandomSlice(dataSize)
	digest := sha256.Sum256(data)
	hexDigest := hex.EncodeToString(digest[:])

	// Publish the archive and its checksum side by side in the served directory.
	archivePath := serverDir + "/" + archiveName
	err := os.WriteFile(archivePath, data, filePerm)
	Require(t, err, "failed to write archive")
	err = os.WriteFile(archivePath+".sha256", []byte(hexDigest), filePerm)
	Require(t, err, "failed to write checksum")

	// Serve the directory over HTTP on a local address.
	addr := startFileServer(t, ctx, serverDir)

	// Point the init configuration at the served archive and download it.
	initConfig := conf.InitConfigDefault
	initConfig.Url = fmt.Sprintf("http://%s/%s", addr, archiveName)
	initConfig.DownloadPath = t.TempDir()
	receivedArchive, err := downloadInit(ctx, &initConfig)
	Require(t, err, "failed to download")

	// The downloaded bytes must match what was published.
	receivedData, err := os.ReadFile(receivedArchive)
	Require(t, err, "failed to read received archive")
	if bytes.Equal(receivedData, data) {
		return
	}
	t.Error("downloaded archive is different from generated one")
}

// TestDownloadInitInParts publishes an archive only as ".partN" files, each
// with its own ".sha256" companion, and checks that downloadInit reassembles
// the original bytes and leaves a single file in the download directory.
func TestDownloadInitInParts(t *testing.T) {
	const (
		archiveName = "random_data.tar.gz"
		numParts    = 3
		partSize    = 1024 * 1024
		dataSize    = numParts * partSize
		filePerm    = 0600
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Cut one random payload into equally sized parts and publish each part
	// alongside its hex-encoded SHA-256 digest.
	serverDir := t.TempDir()
	data := testhelpers.RandomSlice(dataSize)
	for i, offset := 0, 0; i < numParts; i, offset = i+1, offset+partSize {
		chunk := data[offset : offset+partSize]
		digest := sha256.Sum256(chunk)
		hexDigest := hex.EncodeToString(digest[:])

		partPath := fmt.Sprintf("%s/%s.part%d", serverDir, archiveName, i)
		err := os.WriteFile(partPath, chunk, filePerm)
		Require(t, err, "failed to write part")
		err = os.WriteFile(partPath+".sha256", []byte(hexDigest), filePerm)
		Require(t, err, "failed to write checksum")
	}

	// Serve the directory over HTTP on a local address.
	addr := startFileServer(t, ctx, serverDir)

	// The URL names the unsplit archive, which is not published.
	initConfig := conf.InitConfigDefault
	initConfig.Url = fmt.Sprintf("http://%s/%s", addr, archiveName)
	initConfig.DownloadPath = t.TempDir()
	receivedArchive, err := downloadInit(ctx, &initConfig)
	Require(t, err, "failed to download")

	// The joined archive must equal the original payload.
	receivedData, err := os.ReadFile(receivedArchive)
	Require(t, err, "failed to read received archive")
	if !bytes.Equal(receivedData, data) {
		t.Error("downloaded archive is different from generated one")
	}

	// Only the joined archive may remain in the download directory.
	entries, err := os.ReadDir(initConfig.DownloadPath)
	Require(t, err, "failed to read temp dir")
	if remaining := len(entries); remaining != 1 {
		t.Error("download function did not delete temp files")
	}
}

// startFileServer serves the contents of dir over HTTP on an ephemeral
// 127.0.0.1 port and returns the "host:port" address it listens on.
//
// The server is stopped when ctx is cancelled or when the test finishes,
// whichever happens first. Shutdown and serve errors are reported from a
// t.Cleanup callback, so nothing touches t from a background goroutine or
// after the test has completed.
func startFileServer(t *testing.T, ctx context.Context, dir string) string {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	Require(t, err, "failed to listen")
	addr := ln.Addr().String()
	server := &http.Server{
		Addr:              addr,
		Handler:           http.FileServer(http.Dir(dir)),
		ReadHeaderTimeout: time.Second,
	}

	// Buffered so the serving goroutine can always deliver its result and exit.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.Serve(ln)
	}()

	// Honor ctx cancellation without leaking this goroutine when the test
	// ends first.
	cleanedUp := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			// Best-effort early stop; errors are reported by the cleanup below.
			_ = server.Close()
		case <-cleanedUp:
		}
	}()

	t.Cleanup(func() {
		defer close(cleanedUp)
		// Use a fresh context: the caller's ctx is typically already
		// cancelled here, which would make Shutdown give up immediately.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(shutdownCtx); err != nil {
			t.Errorf("failed to shutdown server: %v", err)
		}
		if err := <-serveErr; err != nil && !errors.Is(err, http.ErrServerClosed) {
			t.Errorf("failed to serve files: %v", err)
		}
	})
	return addr
}

0 comments on commit a400236

Please sign in to comment.