diff --git a/.ci/build_and_test.sh b/.ci/build_and_test.sh index 6578544c7c8b..8ff0f0e179ad 100755 --- a/.ci/build_and_test.sh +++ b/.ci/build_and_test.sh @@ -4,4 +4,10 @@ set -ev cd $TRAVIS_BUILD_DIR ./scripts/build.sh +# Check to see if the build script creates any unstaged changes to prevent +# regressions where the go.mod/go.sum files get out of date. +if [[ -n $(git status -s) ]]; then + echo "Build script created unstaged changes in the repository" + exit 1 +fi ./scripts/build_test.sh diff --git a/.ci/run_e2e_tests.sh b/.ci/run_e2e_tests.sh index cc4042d77b5b..c0b2106abe36 100755 --- a/.ci/run_e2e_tests.sh +++ b/.ci/run_e2e_tests.sh @@ -18,7 +18,7 @@ echo "Using Avalanche Image: $AVALANCHE_IMAGE" DOCKER_REPO="avaplatform" BYZANTINE_IMAGE="$DOCKER_REPO/avalanche-byzantine:v0.2.0-rc.1" -TEST_SUITE_IMAGE="$DOCKER_REPO/avalanche-testing:v0.11.0-rc.1" +TEST_SUITE_IMAGE="$DOCKER_REPO/avalanche-testing:v0.11.0-rc.2" # Kurtosis Environment Parameters KURTOSIS_CORE_CHANNEL="1.0.3" diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000000..3e23d7e4be83 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +.ci +.git +.github +.gitignore +.golangci.yml +.idea +.vscode + +LICENSE +*.md diff --git a/.travis.yml b/.travis.yml index 0133ca2da742..901ae1bc5caf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,13 +5,10 @@ go: env: global: - - CODECOV_TOKEN="8c18c993-fc6e-4706-998b-01ddc7987804" - AVALANCHE_HOME=$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/ - E2E_TEST_HOME=$GOPATH/src/github.com/ava-labs/avalanche-testing/ - COMMIT=${TRAVIS_COMMIT::8} - DOCKERHUB_REPO=avaplatform/avalanchego - - secure: 
Ozsv2nUqUVSdiaEovsffCBaGIaJdSGEq6zSNr1af74+zhYu1G5Dw3s0u5Uq42NTFygGVAg2ODh1/PSviAK3P7Dzgi3yMtUBD8kAAISJW3lKr/JavBOIsUnekhZYniAS77vUHwpOi6vQxgjhy/ymYxXTSRyHPys7DwZhZcCMiR6Bk/O7w8JbYo2m31mCaJZWpt9m4SCXVr+lK1prYuCOAME5SwKq3eVHfUGKn7w8f7kLUDv6XPLlAjzOQHKq4AD9DwDQX2wubAOc87a4BCti9suaXNyzRtS3AUQXjZkHy8BHyWPnMOIwOT6sVAYEm65fcOPxlawnkbs8ny7xnJMqj9R0tyq7XmnGoaALeXxOcV1B55TFjyo0P48NTugFrdqtQ+LSOvvVQJV/QBoe+sZwUTus+LP5lWl86EnQPGxjll+vXO2Mces+F48eoj9dfPBbBLRLEaxk54l6+H6JHvAG2QRtRG3beh7XbdFYPnt+LEuYdW3kyCRD24JhCrglJlebCnqAKR7GfAICf7ca5+WJj4Fiyyh/tUt4Ss0E53Mvz6NXFNpJVTEcFc3RFjIp/louK1Y2Uxbyr+LT9hw2bvo6Obz9sl/YZs10rCWZnC0zF2WE01tADD8YShrA70349hmtE2FSJLCg2LTVgmTFbS47Rn4QFcmo5AGHDNl31iwj9gKk= - - secure: Si2xR7IOINZRHtH6DbgxhpOH/oX1XL+DvlutmQVS5ZbjXFmvtqB8CT9uDUAKj38R+sgEmAwOENLJcl7SWZfXDE4EYqlcafh/pcBMjm8O6atWrrwmpZSEiHNba2t+yXBN7z5M/KV777FY43SbOEA+/5Bytcluk3Mxjjl5iFXEWai0RH9jmk7lsHmIyPYsAG9/SRwYgr0uZ3T5872HBD7td7+umiTyWshdT3dOXilYWWflc4eniv/ifp/H6A5k9uXrE50KdtaeDcAl5eAp9mItNd8nLenmaNzDkq6IBTUwy+gmEHctq8YbjmtQhdWtdIXDxPFknpBqKsg7oXgstJt11UVqhDcsnX6Trj1GO8InUykMLRDxWMwlCZfZdAuUvhbrHmbFOWF1ANL5Rl7RzUXSov02WAPvrze/8ZFq2O2f28CVkcWCZy/Vei9EhAwQUyJOug6R/1cSOkcqpovc5yf65dLnpUMb+f4fbMk5Z/YPijJ8VZSFU9ul73re6xcWz9PbxZWqsN2Ubqm6EKRB02gxsLbdl56lHcMl6uawrTUFDoie+alkiqlxP8Ey7Phw43os/lxHq9lQN0bZ2Mkq4LMVc4noa5TQNTNxS/hrAIly7IzZVV0VfiLeHJVT/2BWCUhNqw/jL2lpGTPbH0NqG/wIRAAW9BYDlQKc17vBM9Xvs8g= jobs: include: diff --git a/Dockerfile b/Dockerfile index 2caa08a205bb..43f361f2fbb8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,17 @@ # syntax=docker/dockerfile:experimental +ARG AVALANCHEGO_COMMIT FROM golang:1.15.5-buster +ARG AVALANCHEGO_COMMIT + RUN mkdir -p /go/src/github.com/ava-labs WORKDIR $GOPATH/src/github.com/ava-labs/ COPY . 
avalanchego WORKDIR $GOPATH/src/github.com/ava-labs/avalanchego +ENV AVALANCHEGO_COMMIT=$AVALANCHEGO_COMMIT RUN ./scripts/build.sh RUN ln -sv $GOPATH/src/github.com/ava-labs/avalanchego/ /avalanchego diff --git a/api/admin/service.go b/api/admin/service.go index 4645f74f8bbb..66a89c8ecdf6 100644 --- a/api/admin/service.go +++ b/api/admin/service.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/rpc/v2" "github.com/ava-labs/avalanchego/api" + "github.com/ava-labs/avalanchego/api/server" "github.com/ava-labs/avalanchego/chains" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/engine/common" @@ -35,11 +36,11 @@ type Admin struct { log logging.Logger performance *Performance chainManager chains.Manager - httpServer *api.Server + httpServer *server.Server } // NewService returns a new admin API service -func NewService(log logging.Logger, chainManager chains.Manager, httpServer *api.Server) (*common.HTTPHandler, error) { +func NewService(log logging.Logger, chainManager chains.Manager, httpServer *server.Server) (*common.HTTPHandler, error) { newServer := rpc.NewServer() codec := cjson.NewCodec() newServer.RegisterCodec(codec, "application/json") diff --git a/api/auth/auth.go b/api/auth/auth.go index 2913bc5bf918..1e3fd94c7375 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -3,7 +3,6 @@ package auth import ( "crypto/rand" "encoding/base64" - "encoding/json" "errors" "fmt" "net/http" @@ -13,93 +12,126 @@ import ( "time" jwt "github.com/dgrijalva/jwt-go" - rpc "github.com/gorilla/rpc/v2/json2" + "github.com/gorilla/rpc/v2" + + "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/password" "github.com/ava-labs/avalanchego/utils/timer" + + cjson "github.com/ava-labs/avalanchego/utils/json" ) const ( - // Endpoint is the base of the auth URL - Endpoint = "auth" - headerKey = "Authorization" headerValStart = "Bearer " // number of bytes to use when generating a new random token ID tokenIDByteLen = 20 
+ + // defaultTokenLifespan is how long a token lives before it expires + defaultTokenLifespan = time.Hour * 12 + + maxEndpoints = 128 ) var ( - // TokenLifespan is how long a token lives before it expires - TokenLifespan = time.Hour * 12 - - // ErrNoToken is returned by GetToken if no token is provided - ErrNoToken = errors.New("auth token not provided") - ErrAuthHeaderNotParsable = fmt.Errorf( + errNoToken = errors.New("auth token not provided") + errAuthHeaderNotParsable = fmt.Errorf( "couldn't parse auth token. Header \"%s\" should be \"%sTOKEN.GOES.HERE\"", headerKey, headerValStart, ) - ErrInvalidSigningMethod = fmt.Errorf("auth token didn't specify the HS256 signing method correctly") - ErrTokenRevoked = errors.New("the provided auth token was revoked") - ErrTokenInsufficientPermission = errors.New("the provided auth token does not allow access to this endpoint") + errInvalidSigningMethod = fmt.Errorf("auth token didn't specify the HS256 signing method correctly") + errTokenRevoked = errors.New("the provided auth token was revoked") + errTokenInsufficientPermission = errors.New("the provided auth token does not allow access to this endpoint") errWrongPassword = errors.New("incorrect password") errSamePassword = errors.New("new password can't be same as old password") + errNoPassword = errors.New("no password") + + errNoEndpoints = errors.New("must name at least one endpoint") + errTooManyEndpoints = fmt.Errorf("can only name at most %d endpoints", maxEndpoints) ) -// Auth handles HTTP API authorization for this node -type Auth struct { - lock sync.RWMutex // Prevent race condition when accessing password - enabled bool // True iff API calls need auth token - password password.Hash // Hash of the password. Can be changed via API call. - clock timer.Clock // Tells the time. 
Can be faked for testing - revoked map[string]struct{} // Set of token IDs that have been revoked +type Auth interface { + // Create and return a new token that allows access to each API endpoint for + // [duration] such that the API's path ends with an element of [endpoints]. + // If one of the elements of [endpoints] is "*", all APIs are accessible. + NewToken(pw string, duration time.Duration, endpoints []string) (string, error) + + // Revokes [token]; it will not be accepted as authorization for future API + // calls. If the token is invalid, this is a no-op. If a token is revoked + // and then the password is changed, and then changed back to the current + // password, the token will be un-revoked. Therefore, passwords shouldn't be + // re-used before previously revoked tokens have expired. + RevokeToken(token, pw string) error + + // Authenticates [token] for access to [url]. + AuthenticateToken(token, url string) error + + // Change the password required to create and revoke tokens. + // [oldPW] is the current password. + // [newPW] is the new password. It can't be the empty string and it can't be + // unreasonably long. + // Changing the password makes tokens issued under a previous password + // invalid. + ChangePassword(oldPW, newPW string) error + + // Create the API endpoint for this auth handler. + CreateHandler() (http.Handler, error) + + // WrapHandler wraps an http.Handler. Before passing a request to the + // provided handler, the auth token is authenticated. + WrapHandler(h http.Handler) http.Handler } -func New(enabled bool, password string) (*Auth, error) { - auth := &Auth{ - enabled: enabled, - revoked: make(map[string]struct{}), - } - return auth, auth.password.Set(password) +type auth struct { + // Used to mock time. + clock timer.Clock + + log logging.Logger + endpoint string + + lock sync.RWMutex + // Can be changed via API call. 
+ password password.Hash + // Set of token IDs that have been revoked + revoked map[string]struct{} } -func NewFromHash(enabled bool, password password.Hash) *Auth { - return &Auth{ - enabled: enabled, - password: password, +func New(log logging.Logger, endpoint, pw string) (Auth, error) { + a := &auth{ + log: log, + endpoint: endpoint, revoked: make(map[string]struct{}), } + return a, a.password.Set(pw) } -// Custom claim type used for API access token -type endpointClaims struct { - jwt.StandardClaims - - // Each element is an endpoint that the token allows access to - // If endpoints has an element "*", allows access to all API endpoints - // In this case, "*" should be the only element of [endpoints] - Endpoints []string `json:"endpoints,omitempty"` +func NewFromHash(log logging.Logger, endpoint string, pw password.Hash) Auth { + return &auth{ + log: log, + endpoint: endpoint, + password: pw, + revoked: make(map[string]struct{}), + } } -// getTokenKey returns the key to use when making and parsing tokens -func (auth *Auth) getTokenKey(t *jwt.Token) (interface{}, error) { - if t.Method != jwt.SigningMethodHS256 { - return nil, ErrInvalidSigningMethod +func (a *auth) NewToken(pw string, duration time.Duration, endpoints []string) (string, error) { + if pw == "" { + return "", errNoPassword + } + if l := len(endpoints); l == 0 { + return "", errNoEndpoints + } else if l > maxEndpoints { + return "", errTooManyEndpoints } - return auth.password.Password[:], nil -} -// Create and return a new token that allows access to each API endpoint such -// that the API's path ends with an element of [endpoints] -// If one of the elements of [endpoints] is "*", allows access to all APIs -func (auth *Auth) newToken(password string, endpoints []string) (string, error) { - auth.lock.RLock() - defer auth.lock.RUnlock() + a.lock.RLock() + defer a.lock.RUnlock() - if !auth.password.Check(password) { + if !a.password.Check(pw) { return "", errWrongPassword } @@ -119,7 +151,7 @@ func 
(auth *Auth) newToken(password string, endpoints []string) (string, error) claims := endpointClaims{ StandardClaims: jwt.StandardClaims{ - ExpiresAt: auth.clock.Time().Add(TokenLifespan).Unix(), + ExpiresAt: a.clock.Time().Add(duration).Unix(), Id: id, }, } @@ -129,25 +161,26 @@ func (auth *Auth) newToken(password string, endpoints []string) (string, error) claims.Endpoints = endpoints } token := jwt.NewWithClaims(jwt.SigningMethodHS256, &claims) - return token.SignedString(auth.password.Password[:]) // Sign the token and return its string repr. + return token.SignedString(a.password.Password[:]) // Sign the token and return its string repr. } -// Revokes the token whose string repr. is [tokenStr]; it will not be accepted as authorization for future API calls. -// If the token is invalid, this is a no-op. -// Only currently valid tokens can be revoked -// If a token is revoked and then the password is changed, and then changed back to the current password, -// the token will be un-revoked. Don't re-use passwords before at least TokenLifespan has elapsed. 
-// Returns an error if the wrong password is given -func (auth *Auth) revokeToken(tokenStr string, password string) error { - auth.lock.Lock() - defer auth.lock.Unlock() - - if !auth.password.Check(password) { +func (a *auth) RevokeToken(tokenStr, pw string) error { + if tokenStr == "" { + return errNoToken + } + if pw == "" { + return errNoPassword + } + + a.lock.Lock() + defer a.lock.Unlock() + + if !a.password.Check(pw) { return errWrongPassword } // See if token is well-formed and signature is right - token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, auth.getTokenKey) + token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, a.getTokenKey) if err != nil { return err } @@ -161,16 +194,15 @@ func (auth *Auth) revokeToken(tokenStr string, password string) error { if !ok { return fmt.Errorf("expected auth token's claims to be type endpointClaims but is %T", token.Claims) } - auth.revoked[claims.Id] = struct{}{} + a.revoked[claims.Id] = struct{}{} return nil } -// Authenticates [tokenStr] for access to [url]. 
-func (auth *Auth) authenticateToken(tokenStr, url string) error { - auth.lock.RLock() - defer auth.lock.RUnlock() +func (a *auth) AuthenticateToken(tokenStr, url string) error { + a.lock.RLock() + defer a.lock.RUnlock() - token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, auth.getTokenKey) + token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, a.getTokenKey) if err != nil { // Probably because signature wrong return err } @@ -183,9 +215,9 @@ func (auth *Auth) authenticateToken(tokenStr, url string) error { return fmt.Errorf("expected auth token's claims to be type endpointClaims but is %T", token.Claims) } - _, revoked := auth.revoked[claims.Id] + _, revoked := a.revoked[claims.Id] if revoked { - return ErrTokenRevoked + return errTokenRevoked } for _, endpoint := range claims.Endpoints { @@ -193,47 +225,48 @@ func (auth *Auth) authenticateToken(tokenStr, url string) error { return nil } } - return ErrTokenInsufficientPermission + return errTokenInsufficientPermission } -// Change the password required to create and revoke tokens. -// [oldPassword] is the current password. -// [newPassword] is the new password. It can't be the empty string and it can't -// be unreasonably long. -// Changing the password makes tokens issued under a previous password invalid. 
-func (auth *Auth) changePassword(oldPassword, newPassword string) error { - if oldPassword == newPassword { +func (a *auth) ChangePassword(oldPW, newPW string) error { + if oldPW == newPW { return errSamePassword } - auth.lock.Lock() - defer auth.lock.Unlock() + a.lock.Lock() + defer a.lock.Unlock() - if !auth.password.Check(oldPassword) { + if !a.password.Check(oldPW) { return errWrongPassword } - if err := password.IsValid(newPassword, password.OK); err != nil { + if err := password.IsValid(newPW, password.OK); err != nil { return err } - if err := auth.password.Set(newPassword); err != nil { + if err := a.password.Set(newPW); err != nil { return err } // All the revoked tokens are now invalid; no need to mark specifically as // revoked. - auth.revoked = make(map[string]struct{}) + a.revoked = make(map[string]struct{}) return nil } -// WrapHandler wraps a handler. Before passing a request to the handler, check that -// an auth token was provided (if necessary) and that it is valid/unexpired. -func (auth *Auth) WrapHandler(h http.Handler) http.Handler { - if !auth.enabled { // Auth tokens aren't in use. Do nothing. 
- return h - } +func (a *auth) CreateHandler() (http.Handler, error) { + server := rpc.NewServer() + codec := cjson.NewCodec() + server.RegisterCodec(codec, "application/json") + server.RegisterCodec(codec, "application/json;charset=UTF-8") + return server, server.RegisterService( + &service{auth: a}, + "auth", + ) +} + +func (a *auth) WrapHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Don't require auth token to hit auth endpoint - if path.Base(r.URL.Path) == Endpoint { + if path.Base(r.URL.Path) == a.endpoint { h.ServeHTTP(w, r) return } @@ -241,19 +274,19 @@ func (auth *Auth) WrapHandler(h http.Handler) http.Handler { // Should be "Bearer AUTH.TOKEN.HERE" rawHeader := r.Header.Get(headerKey) if rawHeader == "" { - writeUnauthorizedResponse(w, ErrNoToken) + writeUnauthorizedResponse(w, errNoToken) return } if !strings.HasPrefix(rawHeader, headerValStart) { // Error is intentionally dropped here as there is nothing left to // do with it. - writeUnauthorizedResponse(w, ErrAuthHeaderNotParsable) + writeUnauthorizedResponse(w, errAuthHeaderNotParsable) return } // Returns actual auth token. Slice guaranteed to not go OOB tokenStr := rawHeader[len(headerValStart):] - if err := auth.authenticateToken(tokenStr, r.URL.Path); err != nil { + if err := a.AuthenticateToken(tokenStr, r.URL.Path); err != nil { writeUnauthorizedResponse(w, err) return } @@ -262,27 +295,10 @@ func (auth *Auth) WrapHandler(h http.Handler) http.Handler { }) } -// Write a JSON-RPC formatted response saying that the API call is unauthorized. -// The response has header http.StatusUnauthorized. -// Errors while marshalling or writing are ignored. 
-func writeUnauthorizedResponse(w http.ResponseWriter, err error) { - body := struct { - Version string `json:"jsonrpc"` - Err struct { - Code rpc.ErrorCode `json:"code"` - Message string `json:"message"` - } `json:"error"` - ID uint8 `json:"id"` - }{} - - body.Version = rpc.Version - body.Err.Code = rpc.E_INVALID_REQ - body.Err.Message = err.Error() - body.ID = 1 - - encoded, _ := json.Marshal(body) - - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write(encoded) +// getTokenKey returns the key to use when making and parsing tokens +func (a *auth) getTokenKey(t *jwt.Token) (interface{}, error) { + if t.Method != jwt.SigningMethodHS256 { + return nil, errInvalidSigningMethod + } + return a.password.Password[:], nil } diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index 363b196c5986..5bd81145e9fb 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -15,13 +15,14 @@ import ( jwt "github.com/dgrijalva/jwt-go" + "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/password" ) var ( testPassword = "password!@#$%$#@!" 
hashedPassword = password.Hash{} - unAuthorizedResponseRegex = "^{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"(.*)\"},\"id\":1}$" + unAuthorizedResponseRegex = "^{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"(.*)\"},\"id\":1}" ) func init() { @@ -36,24 +37,24 @@ var ( ) func TestNewTokenWrongPassword(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) - _, err := auth.newToken("", []string{"endpoint1, endpoint2"}) + _, err := auth.NewToken("", defaultTokenLifespan, []string{"endpoint1, endpoint2"}) assert.Error(t, err, "should have failed because password is wrong") - _, err = auth.newToken("notThePassword", []string{"endpoint1, endpoint2"}) + _, err = auth.NewToken("notThePassword", defaultTokenLifespan, []string{"endpoint1, endpoint2"}) assert.Error(t, err, "should have failed because password is wrong") } func TestNewTokenHappyPath(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword).(*auth) now := time.Now() auth.clock.Set(now) // Make a token endpoints := []string{"endpoint1", "endpoint2", "endpoint3"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) // Parse the token @@ -68,16 +69,16 @@ func TestNewTokenHappyPath(t *testing.T) { assert.True(t, ok, "expected auth token's claims to be type endpointClaims but is different type") assert.ElementsMatch(t, endpoints, claims.Endpoints, "token has wrong endpoint claims") - shouldExpireAt := now.Add(TokenLifespan).Unix() + shouldExpireAt := now.Add(defaultTokenLifespan).Unix() assert.Equal(t, shouldExpireAt, claims.ExpiresAt, "token expiration time is wrong") } func TestTokenHasWrongSig(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword).(*auth) // Make a token 
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) // Try to parse the token using the wrong password @@ -98,52 +99,52 @@ func TestTokenHasWrongSig(t *testing.T) { } func TestChangePassword(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword).(*auth) password2 := "fejhkefjhefjhefhje" // #nosec G101 var err error - err = auth.changePassword("", password2) + err = auth.ChangePassword("", password2) assert.Error(t, err, "should have failed because old password is wrong") - err = auth.changePassword("notThePassword", password2) + err = auth.ChangePassword("notThePassword", password2) assert.Error(t, err, "should have failed because old password is wrong") - err = auth.changePassword(testPassword, "") + err = auth.ChangePassword(testPassword, "") assert.Error(t, err, "should have failed because new password is empty") - err = auth.changePassword(testPassword, password2) + err = auth.ChangePassword(testPassword, password2) assert.NoError(t, err, "should have succeeded") assert.True(t, auth.password.Check(password2), "password should have been changed") password3 := "ufwhwohwfohawfhwdwd" // #nosec G101 - err = auth.changePassword(testPassword, password3) + err = auth.ChangePassword(testPassword, password3) assert.Error(t, err, "should have failed because old password is wrong") - err = auth.changePassword(password2, password3) + err = auth.ChangePassword(password2, password3) assert.NoError(t, err, "should have succeeded") } func TestRevokeToken(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword).(*auth) // Make a token endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := 
auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) - err = auth.revokeToken(tokenStr, testPassword) + err = auth.RevokeToken(tokenStr, testPassword) assert.NoError(t, err, "should have succeeded") assert.Len(t, auth.revoked, 1, "revoked token list is incorrect") } func TestWrapHandlerHappyPath(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) // Make a token endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) wrappedHandler := auth.WrapHandler(dummyHandler) @@ -158,14 +159,14 @@ func TestWrapHandlerHappyPath(t *testing.T) { } func TestWrapHandlerRevokedToken(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) // Make a token endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) - err = auth.revokeToken(tokenStr, testPassword) + err = auth.RevokeToken(tokenStr, testPassword) assert.NoError(t, err) wrappedHandler := auth.WrapHandler(dummyHandler) @@ -176,19 +177,19 @@ func TestWrapHandlerRevokedToken(t *testing.T) { rr := httptest.NewRecorder() wrappedHandler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Body.String(), ErrTokenRevoked.Error()) + assert.Contains(t, rr.Body.String(), errTokenRevoked.Error()) assert.Regexp(t, unAuthorizedResponseRegex, rr.Body.String()) } } func TestWrapHandlerExpiredToken(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword).(*auth) - auth.clock.Set(time.Now().Add(-2 * TokenLifespan)) + 
auth.clock.Set(time.Now().Add(-2 * defaultTokenLifespan)) // Make a token that expired well in the past endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) wrappedHandler := auth.WrapHandler(dummyHandler) @@ -205,7 +206,7 @@ func TestWrapHandlerExpiredToken(t *testing.T) { } func TestWrapHandlerNoAuthToken(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} wrappedHandler := auth.WrapHandler(dummyHandler) @@ -214,17 +215,17 @@ func TestWrapHandlerNoAuthToken(t *testing.T) { rr := httptest.NewRecorder() wrappedHandler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Body.String(), ErrNoToken.Error()) + assert.Contains(t, rr.Body.String(), errNoToken.Error()) assert.Regexp(t, unAuthorizedResponseRegex, rr.Body.String()) } } func TestWrapHandlerUnauthorizedEndpoint(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) // Make a token endpoints := []string{"/ext/info"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) unauthorizedEndpoints := []string{"/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"} @@ -236,21 +237,21 @@ func TestWrapHandlerUnauthorizedEndpoint(t *testing.T) { rr := httptest.NewRecorder() wrappedHandler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Body.String(), ErrTokenInsufficientPermission.Error()) + assert.Contains(t, rr.Body.String(), errTokenInsufficientPermission.Error()) assert.Regexp(t, unAuthorizedResponseRegex, rr.Body.String()) } } func 
TestWrapHandlerAuthEndpoint(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) // Make a token endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) wrappedHandler := auth.WrapHandler(dummyHandler) - req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", fmt.Sprintf("/ext/%s", Endpoint)), strings.NewReader("")) + req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1:9650/ext/auth", strings.NewReader("")) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) rr := httptest.NewRecorder() wrappedHandler.ServeHTTP(rr, req) @@ -258,11 +259,11 @@ func TestWrapHandlerAuthEndpoint(t *testing.T) { } func TestWrapHandlerAccessAll(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) // Make a token that allows access to all endpoints endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info"} - tokenStr, err := auth.newToken(testPassword, []string{"*"}) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, []string{"*"}) assert.NoError(t, err) wrappedHandler := auth.WrapHandler(dummyHandler) @@ -275,36 +276,22 @@ func TestWrapHandlerAccessAll(t *testing.T) { } } -func TestWrapHandlerAuthDisabled(t *testing.T) { - auth := NewFromHash(false, hashedPassword) - - endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info", "/ext/auth"} - - wrappedHandler := auth.WrapHandler(dummyHandler) - for _, endpoint := range endpoints { - req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) - rr := httptest.NewRecorder() - wrappedHandler.ServeHTTP(rr, req) - 
assert.Equal(t, http.StatusOK, rr.Code) - } -} - func TestWriteUnauthorizedResponse(t *testing.T) { rr := httptest.NewRecorder() writeUnauthorizedResponse(rr, errors.New("example err")) assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Equal(t, "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"example err\"},\"id\":1}", rr.Body.String()) + assert.Equal(t, "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"example err\"},\"id\":1}\n", rr.Body.String()) } func TestWrapHandlerMutatedRevokedToken(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword) // Make a token endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} - tokenStr, err := auth.newToken(testPassword, endpoints) + tokenStr, err := auth.NewToken(testPassword, defaultTokenLifespan, endpoints) assert.NoError(t, err) - err = auth.revokeToken(tokenStr, testPassword) + err = auth.RevokeToken(tokenStr, testPassword) assert.NoError(t, err) wrappedHandler := auth.WrapHandler(dummyHandler) @@ -315,13 +302,13 @@ func TestWrapHandlerMutatedRevokedToken(t *testing.T) { rr := httptest.NewRecorder() wrappedHandler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Body.String(), ErrTokenRevoked.Error()) + assert.Contains(t, rr.Body.String(), errTokenRevoked.Error()) assert.Regexp(t, unAuthorizedResponseRegex, rr.Body.String()) } } func TestWrapHandlerInvalidSigningMethod(t *testing.T) { - auth := NewFromHash(true, hashedPassword) + auth := NewFromHash(logging.NoLog{}, "auth", hashedPassword).(*auth) // Make a token endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} @@ -333,7 +320,7 @@ func TestWrapHandlerInvalidSigningMethod(t *testing.T) { claims := endpointClaims{ StandardClaims: jwt.StandardClaims{ - ExpiresAt: auth.clock.Time().Add(TokenLifespan).Unix(), + ExpiresAt: auth.clock.Time().Add(defaultTokenLifespan).Unix(), Id: id, }, Endpoints: endpoints, 
@@ -352,7 +339,7 @@ func TestWrapHandlerInvalidSigningMethod(t *testing.T) { rr := httptest.NewRecorder() wrappedHandler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Body.String(), ErrInvalidSigningMethod.Error()) + assert.Contains(t, rr.Body.String(), errInvalidSigningMethod.Error()) assert.Regexp(t, unAuthorizedResponseRegex, rr.Body.String()) } } diff --git a/api/auth/claims.go b/api/auth/claims.go new file mode 100644 index 000000000000..b580b8337d8a --- /dev/null +++ b/api/auth/claims.go @@ -0,0 +1,15 @@ +package auth + +import ( + jwt "github.com/dgrijalva/jwt-go" +) + +// Custom claim type used for API access token +type endpointClaims struct { + jwt.StandardClaims + + // Each element is an endpoint that the token allows access to + // If endpoints has an element "*", allows access to all API endpoints + // In this case, "*" should be the only element of [endpoints] + Endpoints []string `json:"endpoints,omitempty"` +} diff --git a/api/auth/response.go b/api/auth/response.go new file mode 100644 index 000000000000..af33cd62c818 --- /dev/null +++ b/api/auth/response.go @@ -0,0 +1,37 @@ +package auth + +import ( + "encoding/json" + "net/http" + + rpc "github.com/gorilla/rpc/v2/json2" +) + +type responseErr struct { + Code rpc.ErrorCode `json:"code"` + Message string `json:"message"` +} + +type responseBody struct { + Version string `json:"jsonrpc"` + Err responseErr `json:"error"` + ID uint8 `json:"id"` +} + +// Write a JSON-RPC formatted response saying that the API call is unauthorized. +// The response has header http.StatusUnauthorized. +// Errors while writing are ignored. +func writeUnauthorizedResponse(w http.ResponseWriter, err error) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + // There isn't anything to do with the returned error, so it is dropped. 
+ _ = json.NewEncoder(w).Encode(responseBody{ + Version: rpc.Version, + Err: responseErr{ + Code: rpc.E_INVALID_REQ, + Message: err.Error(), + }, + ID: 1, + }) +} diff --git a/api/auth/service.go b/api/auth/service.go index a56bada83adc..68079751c1bd 100644 --- a/api/auth/service.go +++ b/api/auth/service.go @@ -1,114 +1,62 @@ package auth import ( - "errors" - "fmt" "net/http" - "github.com/gorilla/rpc/v2" - - "github.com/ava-labs/avalanchego/snow/engine/common" - "github.com/ava-labs/avalanchego/utils/logging" - - cjson "github.com/ava-labs/avalanchego/utils/json" -) - -const ( - maxEndpoints = 128 -) - -var ( - errNoPassword = errors.New("argument 'password' not given") - errNoToken = errors.New("argument 'token' not given") + "github.com/ava-labs/avalanchego/api" ) -// Service ... -type Service struct { - *Auth // has to be a reference to the same Auth inside the API server - log logging.Logger +// service that serves the Auth API functionality. +type service struct { + auth *auth } -// NewService returns a new auth API service -func NewService(log logging.Logger, auth *Auth) *common.HTTPHandler { - newServer := rpc.NewServer() - codec := cjson.NewCodec() - newServer.RegisterCodec(codec, "application/json") - newServer.RegisterCodec(codec, "application/json;charset=UTF-8") - log.AssertNoError(newServer.RegisterService(&Service{Auth: auth, log: log}, "auth")) - return &common.HTTPHandler{Handler: newServer} -} - -// Success ... -type Success struct { - Success bool `json:"success"` -} - -// Password ... type Password struct { Password string `json:"password"` // The authorization password } -// NewTokenArgs ... type NewTokenArgs struct { Password - // Endpoints that may be accessed with this token - // e.g. 
if endpoints is ["/ext/bc/X", "/ext/admin"] then the token holder - // can hit the X-Chain API and the admin API - // If [Endpoints] contains an element "*" then the token - // allows access to all API endpoints - // [Endpoints] must have between 1 and [maxEndpoints] elements + // Endpoints that may be accessed with this token e.g. if endpoints is + // ["/ext/bc/X", "/ext/admin"] then the token holder can hit the X-Chain API + // and the admin API. If [Endpoints] contains an element "*" then the token + // allows access to all API endpoints. [Endpoints] must have between 1 and + // [maxEndpoints] elements Endpoints []string `json:"endpoints"` } -// Token ... type Token struct { Token string `json:"token"` // The new token. Expires in [TokenLifespan]. } -// NewToken returns a new token -func (s *Service) NewToken(_ *http.Request, args *NewTokenArgs, reply *Token) error { - s.log.Info("Auth: NewToken called") - if args.Password.Password == "" { - return errNoPassword - } - if l := len(args.Endpoints); l < 1 || l > maxEndpoints { - return fmt.Errorf("argument 'endpoints' must have between %d and %d elements, but has %d", - 1, maxEndpoints, l) - } - token, err := s.newToken(args.Password.Password, args.Endpoints) - reply.Token = token +func (s *service) NewToken(_ *http.Request, args *NewTokenArgs, reply *Token) error { + s.auth.log.Info("Auth: NewToken called") + + var err error + reply.Token, err = s.auth.NewToken(args.Password.Password, defaultTokenLifespan, args.Endpoints) return err } -// RevokeTokenArgs ... 
type RevokeTokenArgs struct { Password Token } -// RevokeToken revokes a token -func (s *Service) RevokeToken(_ *http.Request, args *RevokeTokenArgs, reply *Success) error { - s.log.Info("Auth: RevokeToken called") - if args.Password.Password == "" { - return errNoPassword - } else if args.Token.Token == "" { - return errNoToken - } +func (s *service) RevokeToken(_ *http.Request, args *RevokeTokenArgs, reply *api.SuccessResponse) error { + s.auth.log.Info("Auth: RevokeToken called") + reply.Success = true - return s.revokeToken(args.Token.Token, args.Password.Password) + return s.auth.RevokeToken(args.Token.Token, args.Password.Password) } -// ChangePasswordArgs ... type ChangePasswordArgs struct { OldPassword string `json:"oldPassword"` // Current authorization password NewPassword string `json:"newPassword"` // New authorization password } -// ChangePassword changes the password required to create and revoke tokens -// Changing the password makes tokens issued under a previous password invalid -func (s *Service) ChangePassword(_ *http.Request, args *ChangePasswordArgs, reply *Success) error { - s.log.Info("Auth: ChangePassword called") +func (s *service) ChangePassword(_ *http.Request, args *ChangePasswordArgs, reply *api.SuccessResponse) error { + s.auth.log.Info("Auth: ChangePassword called") reply.Success = true - return s.changePassword(args.OldPassword, args.NewPassword) + return s.auth.ChangePassword(args.OldPassword, args.NewPassword) } diff --git a/api/ipcs/service.go b/api/ipcs/service.go index 933fc3a8156c..bf5c8a1356b2 100644 --- a/api/ipcs/service.go +++ b/api/ipcs/service.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/rpc/v2" "github.com/ava-labs/avalanchego/api" + "github.com/ava-labs/avalanchego/api/server" "github.com/ava-labs/avalanchego/chains" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/ipcs" @@ -20,14 +21,14 @@ import ( // IPCServer maintains the IPCs type IPCServer struct { - httpServer *api.Server + httpServer 
*server.Server chainManager chains.Manager log logging.Logger ipcs *ipcs.ChainIPCs } // NewService returns a new IPCs API service -func NewService(log logging.Logger, chainManager chains.Manager, httpServer *api.Server, ipcs *ipcs.ChainIPCs) (*common.HTTPHandler, error) { +func NewService(log logging.Logger, chainManager chains.Manager, httpServer *server.Server, ipcs *ipcs.ChainIPCs) (*common.HTTPHandler, error) { ipcServer := &IPCServer{ log: log, chainManager: chainManager, diff --git a/api/keystore/blockchain_keystore.go b/api/keystore/blockchain_keystore.go index 1f1b263f28c4..6cbc16c7dccf 100644 --- a/api/keystore/blockchain_keystore.go +++ b/api/keystore/blockchain_keystore.go @@ -5,16 +5,36 @@ package keystore import ( "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/encdb" "github.com/ava-labs/avalanchego/ids" ) -// BlockchainKeystore ... -type BlockchainKeystore struct { +var ( + _ BlockchainKeystore = &blockchainKeystore{} +) + +type BlockchainKeystore interface { + // Get a database that is able to read and write unencrypted values from the + // underlying database. + GetDatabase(username, password string) (*encdb.Database, error) + + // Get the underlying database that is able to read and write encrypted + // values. This Database will not perform any encrypting or decrypting of + // values and is not recommended to be used when implementing a VM. + GetRawDatabase(username, password string) (database.Database, error) +} + +type blockchainKeystore struct { blockchainID ids.ID - ks *Keystore + ks *keystore } -// GetDatabase ... 
-func (bks *BlockchainKeystore) GetDatabase(username, password string) (database.Database, error) { +func (bks *blockchainKeystore) GetDatabase(username, password string) (*encdb.Database, error) { + bks.ks.log.Info("Keystore: GetDatabase called with %s from %s", username, bks.blockchainID) return bks.ks.GetDatabase(bks.blockchainID, username, password) } + +func (bks *blockchainKeystore) GetRawDatabase(username, password string) (database.Database, error) { + bks.ks.log.Info("Keystore: GetRawDatabase called with %s from %s", username, bks.blockchainID) + return bks.ks.GetRawDatabase(bks.blockchainID, username, password) +} diff --git a/api/keystore/codec.go b/api/keystore/codec.go new file mode 100644 index 000000000000..027dc548bf57 --- /dev/null +++ b/api/keystore/codec.go @@ -0,0 +1,29 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package keystore + +import ( + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/codec/linearcodec" + "github.com/ava-labs/avalanchego/codec/reflectcodec" +) + +const ( + maxPackerSize = 1 << 30 // max size, in bytes, of something being marshalled by Marshal() + maxSliceLength = 1 << 18 + + codecVersion = 0 +) + +var ( + c codec.Manager +) + +func init() { + lc := linearcodec.New(reflectcodec.DefaultTagName, maxSliceLength) + c = codec.NewManager(maxPackerSize) + if err := c.RegisterCodec(codecVersion, lc); err != nil { + panic(err) + } +} diff --git a/vms/rpcchainvm/gkeystore/gkeystoreproto/gkeystore.pb.go b/api/keystore/gkeystore/gkeystoreproto/gkeystore.pb.go similarity index 100% rename from vms/rpcchainvm/gkeystore/gkeystoreproto/gkeystore.pb.go rename to api/keystore/gkeystore/gkeystoreproto/gkeystore.pb.go diff --git a/vms/rpcchainvm/gkeystore/gkeystoreproto/gkeystore.proto b/api/keystore/gkeystore/gkeystoreproto/gkeystore.proto similarity index 100% rename from vms/rpcchainvm/gkeystore/gkeystoreproto/gkeystore.proto rename to 
api/keystore/gkeystore/gkeystoreproto/gkeystore.proto diff --git a/vms/rpcchainvm/gkeystore/keystore_client.go b/api/keystore/gkeystore/keystore_client.go similarity index 67% rename from vms/rpcchainvm/gkeystore/keystore_client.go rename to api/keystore/gkeystore/keystore_client.go index 1f0c028cc061..f0d50fe8206e 100644 --- a/vms/rpcchainvm/gkeystore/keystore_client.go +++ b/api/keystore/gkeystore/keystore_client.go @@ -8,15 +8,16 @@ import ( "github.com/hashicorp/go-plugin" + "github.com/ava-labs/avalanchego/api/keystore" + "github.com/ava-labs/avalanchego/api/keystore/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/encdb" "github.com/ava-labs/avalanchego/database/rpcdb" "github.com/ava-labs/avalanchego/database/rpcdb/rpcdbproto" - "github.com/ava-labs/avalanchego/snow" - "github.com/ava-labs/avalanchego/vms/rpcchainvm/gkeystore/gkeystoreproto" ) var ( - _ snow.Keystore = &Client{} + _ keystore.BlockchainKeystore = &Client{} ) // Client is a snow.Keystore that talks over RPC. 
@@ -33,7 +34,15 @@ func NewClient(client gkeystoreproto.KeystoreClient, broker *plugin.GRPCBroker) } } -func (c *Client) GetDatabase(username, password string) (database.Database, error) { +func (c *Client) GetDatabase(username, password string) (*encdb.Database, error) { + bcDB, err := c.GetRawDatabase(username, password) + if err != nil { + return nil, err + } + return encdb.New([]byte(password), bcDB) +} + +func (c *Client) GetRawDatabase(username, password string) (database.Database, error) { resp, err := c.client.GetDatabase(context.Background(), &gkeystoreproto.GetDatabaseRequest{ Username: username, Password: password, diff --git a/vms/rpcchainvm/gkeystore/keystore_server.go b/api/keystore/gkeystore/keystore_server.go similarity index 82% rename from vms/rpcchainvm/gkeystore/keystore_server.go rename to api/keystore/gkeystore/keystore_server.go index 239e7a7b2c1e..789fdf98cfb8 100644 --- a/vms/rpcchainvm/gkeystore/keystore_server.go +++ b/api/keystore/gkeystore/keystore_server.go @@ -10,11 +10,11 @@ import ( "github.com/hashicorp/go-plugin" + "github.com/ava-labs/avalanchego/api/keystore" + "github.com/ava-labs/avalanchego/api/keystore/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/rpcdb" "github.com/ava-labs/avalanchego/database/rpcdb/rpcdbproto" - "github.com/ava-labs/avalanchego/snow" - "github.com/ava-labs/avalanchego/vms/rpcchainvm/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/vms/rpcchainvm/grpcutils" ) @@ -24,12 +24,12 @@ var ( // Server is a snow.Keystore that is managed over RPC. 
type Server struct { - ks snow.Keystore + ks keystore.BlockchainKeystore broker *plugin.GRPCBroker } // NewServer returns a keystore connected to a remote keystore -func NewServer(ks snow.Keystore, broker *plugin.GRPCBroker) *Server { +func NewServer(ks keystore.BlockchainKeystore, broker *plugin.GRPCBroker) *Server { return &Server{ ks: ks, broker: broker, @@ -40,7 +40,7 @@ func (s *Server) GetDatabase( _ context.Context, req *gkeystoreproto.GetDatabaseRequest, ) (*gkeystoreproto.GetDatabaseResponse, error) { - db, err := s.ks.GetDatabase(req.Username, req.Password) + db, err := s.ks.GetRawDatabase(req.Username, req.Password) if err != nil { return nil, err } diff --git a/api/keystore/keystore.go b/api/keystore/keystore.go new file mode 100644 index 000000000000..e211a138195c --- /dev/null +++ b/api/keystore/keystore.go @@ -0,0 +1,385 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package keystore + +import ( + "errors" + "fmt" + "net/http" + "sync" + + "github.com/gorilla/rpc/v2" + + "github.com/ava-labs/avalanchego/chains/atomic" + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/encdb" + "github.com/ava-labs/avalanchego/database/prefixdb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/password" + + jsoncodec "github.com/ava-labs/avalanchego/utils/json" +) + +const ( + // maxUserLen is the maximum allowed length of a username + maxUserLen = 1024 +) + +var ( + errEmptyUsername = errors.New("empty username") + errUserMaxLength = fmt.Errorf("username exceeds maximum length of %d chars", maxUserLen) + + _ Keystore = &keystore{} +) + +type Keystore interface { + // Create the API endpoint for this keystore. + CreateHandler() (http.Handler, error) + + // NewBlockchainKeyStore returns this keystore limiting the functionality to + // a single blockchain database. 
+ NewBlockchainKeyStore(blockchainID ids.ID) BlockchainKeystore + + // Get a database that is able to read and write unencrypted values from the + // underlying database. + GetDatabase(bID ids.ID, username, password string) (*encdb.Database, error) + + // Get the underlying database that is able to read and write encrypted + // values. This Database will not perform any encrypting or decrypting of + // values and is not recommended to be used when implementing a VM. + GetRawDatabase(bID ids.ID, username, password string) (database.Database, error) + + // CreateUser attempts to register this username and password as a new user + // of the keystore. + CreateUser(username, pw string) error + + // DeleteUser attempts to remove the provided username and all of its data + // from the keystore. + DeleteUser(username, pw string) error + + // ListUsers returns all the users that currently exist in this keystore. + ListUsers() ([]string, error) + + // ImportUser imports a serialized encoding of a user's information complete + // with encrypted database values. The password is integrity checked. + ImportUser(username, pw string, user []byte) error + + // ExportUser exports a serialized encoding of a user's information complete + // with encrypted database values. + ExportUser(username, pw string) ([]byte, error) + + // Get the password that is used by [username]. If [username] doesn't exist, + // no error is returned and a nil password hash is returned. 
+ getPassword(username string) (*password.Hash, error) +} + +type kvPair struct { + Key []byte `serialize:"true"` + Value []byte `serialize:"true"` +} + +// user describes the full content of a user +type user struct { + password.Hash `serialize:"true"` + Data []kvPair `serialize:"true"` +} + +// keystore implements keystore management logic +type keystore struct { + lock sync.Mutex + log logging.Logger + + // Key: username + // Value: The hash of that user's password + usernameToPassword map[string]*password.Hash + + // Used to persist users and their data + userDB database.Database + bcDB database.Database + // BaseDB + // / \ + // UserDB BlockchainDB + // / | \ + // Usr Usr Usr + // / | \ + // BID BID BID +} + +func New(log logging.Logger, db database.Database) Keystore { + return &keystore{ + log: log, + usernameToPassword: make(map[string]*password.Hash), + userDB: prefixdb.New([]byte("users"), db), + bcDB: prefixdb.New([]byte("bcs"), db), + } +} + +func (ks *keystore) CreateHandler() (http.Handler, error) { + newServer := rpc.NewServer() + codec := jsoncodec.NewCodec() + newServer.RegisterCodec(codec, "application/json") + newServer.RegisterCodec(codec, "application/json;charset=UTF-8") + if err := newServer.RegisterService(&service{ks: ks}, "keystore"); err != nil { + return nil, err + } + return newServer, nil +} + +func (ks *keystore) NewBlockchainKeyStore(blockchainID ids.ID) BlockchainKeystore { + return &blockchainKeystore{ + blockchainID: blockchainID, + ks: ks, + } +} + +func (ks *keystore) GetDatabase(bID ids.ID, username, password string) (*encdb.Database, error) { + bcDB, err := ks.GetRawDatabase(bID, username, password) + if err != nil { + return nil, err + } + return encdb.New([]byte(password), bcDB) +} + +func (ks *keystore) GetRawDatabase(bID ids.ID, username, pw string) (database.Database, error) { + if username == "" { + return nil, errEmptyUsername + } + + ks.lock.Lock() + defer ks.lock.Unlock() + + passwordHash, err := 
ks.getPassword(username) + if err != nil { + return nil, err + } + if passwordHash == nil || !passwordHash.Check(pw) { + return nil, fmt.Errorf("incorrect password for user %q", username) + } + + userDB := prefixdb.New([]byte(username), ks.bcDB) + bcDB := prefixdb.NewNested(bID[:], userDB) + return bcDB, nil +} + +func (ks *keystore) CreateUser(username, pw string) error { + if username == "" { + return errEmptyUsername + } + if len(username) > maxUserLen { + return errUserMaxLength + } + + ks.lock.Lock() + defer ks.lock.Unlock() + + passwordHash, err := ks.getPassword(username) + if err != nil { + return err + } + if passwordHash != nil { + return fmt.Errorf("user already exists: %s", username) + } + + if err := password.IsValid(pw, password.OK); err != nil { + return err + } + + passwordHash = &password.Hash{} + if err := passwordHash.Set(pw); err != nil { + return err + } + + passwordBytes, err := c.Marshal(codecVersion, passwordHash) + if err != nil { + return err + } + + if err := ks.userDB.Put([]byte(username), passwordBytes); err != nil { + return err + } + ks.usernameToPassword[username] = passwordHash + + return nil +} + +func (ks *keystore) DeleteUser(username, pw string) error { + if username == "" { + return errEmptyUsername + } + if len(username) > maxUserLen { + return errUserMaxLength + } + + ks.lock.Lock() + defer ks.lock.Unlock() + + // check if user exists and valid user. 
+ passwordHash, err := ks.getPassword(username) + switch { + case err != nil: + return err + case passwordHash == nil: + return fmt.Errorf("user doesn't exist: %s", username) + case !passwordHash.Check(pw): + return fmt.Errorf("incorrect password for user %q", username) + } + + userNameBytes := []byte(username) + userBatch := ks.userDB.NewBatch() + if err := userBatch.Delete(userNameBytes); err != nil { + return err + } + + userDataDB := prefixdb.New(userNameBytes, ks.bcDB) + dataBatch := userDataDB.NewBatch() + + it := userDataDB.NewIterator() + defer it.Release() + + for it.Next() { + if err = dataBatch.Delete(it.Key()); err != nil { + return err + } + } + + if err = it.Error(); err != nil { + return err + } + + if err := atomic.WriteAll(dataBatch, userBatch); err != nil { + return err + } + + // delete from users map. + delete(ks.usernameToPassword, username) + return nil +} + +func (ks *keystore) ListUsers() ([]string, error) { + users := []string{} + + ks.lock.Lock() + defer ks.lock.Unlock() + + it := ks.userDB.NewIterator() + defer it.Release() + for it.Next() { + users = append(users, string(it.Key())) + } + return users, it.Error() +} + +func (ks *keystore) ImportUser(username, pw string, userBytes []byte) error { + if username == "" { + return errEmptyUsername + } + if len(username) > maxUserLen { + return errUserMaxLength + } + + ks.lock.Lock() + defer ks.lock.Unlock() + + passwordHash, err := ks.getPassword(username) + if err != nil { + return err + } + if passwordHash != nil { + return fmt.Errorf("user already exists: %s", username) + } + + userData := user{} + if _, err := c.Unmarshal(userBytes, &userData); err != nil { + return err + } + if !userData.Hash.Check(pw) { + return fmt.Errorf("incorrect password for user %q", username) + } + + usrBytes, err := c.Marshal(codecVersion, &userData.Hash) + if err != nil { + return err + } + + userBatch := ks.userDB.NewBatch() + if err := userBatch.Put([]byte(username), usrBytes); err != nil { + return err + } + 
+ userDataDB := prefixdb.New([]byte(username), ks.bcDB) + dataBatch := userDataDB.NewBatch() + for _, kvp := range userData.Data { + if err := dataBatch.Put(kvp.Key, kvp.Value); err != nil { + return fmt.Errorf("error on database put: %w", err) + } + } + + if err := atomic.WriteAll(dataBatch, userBatch); err != nil { + return err + } + ks.usernameToPassword[username] = &userData.Hash + return nil +} + +func (ks *keystore) ExportUser(username, pw string) ([]byte, error) { + if username == "" { + return nil, errEmptyUsername + } + if len(username) > maxUserLen { + return nil, errUserMaxLength + } + + ks.lock.Lock() + defer ks.lock.Unlock() + + passwordHash, err := ks.getPassword(username) + if err != nil { + return nil, err + } + if passwordHash == nil || !passwordHash.Check(pw) { + return nil, fmt.Errorf("incorrect password for user %q", username) + } + + userDB := prefixdb.New([]byte(username), ks.bcDB) + + userData := user{Hash: *passwordHash} + it := userDB.NewIterator() + defer it.Release() + for it.Next() { + userData.Data = append(userData.Data, kvPair{ + Key: it.Key(), + Value: it.Value(), + }) + } + if err := it.Error(); err != nil { + return nil, err + } + + // Return the byte representation of the user + return c.Marshal(codecVersion, &userData) +} + +func (ks *keystore) getPassword(username string) (*password.Hash, error) { + // If the user is already in memory, return it + passwordHash, exists := ks.usernameToPassword[username] + if exists { + return passwordHash, nil + } + + // The user is not in memory; try the database + userBytes, err := ks.userDB.Get([]byte(username)) + if err == database.ErrNotFound { + // The user doesn't exist + return nil, nil + } + if err != nil { + // An unexpected database error occurred + return nil, err + } + + passwordHash = &password.Hash{} + _, err = c.Unmarshal(userBytes, passwordHash) + return passwordHash, err +} diff --git a/api/keystore/service.go b/api/keystore/service.go index cfe0737676c6..69148c97b009 100644 --- 
a/api/keystore/service.go +++ b/api/keystore/service.go @@ -4,226 +4,43 @@ package keystore import ( - "errors" "fmt" "net/http" - "sync" - - "github.com/gorilla/rpc/v2" "github.com/ava-labs/avalanchego/api" - "github.com/ava-labs/avalanchego/chains/atomic" - "github.com/ava-labs/avalanchego/codec" - "github.com/ava-labs/avalanchego/codec/linearcodec" - "github.com/ava-labs/avalanchego/codec/reflectcodec" - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/database/encdb" - "github.com/ava-labs/avalanchego/database/memdb" - "github.com/ava-labs/avalanchego/database/prefixdb" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/formatting" - "github.com/ava-labs/avalanchego/utils/logging" - "github.com/ava-labs/avalanchego/utils/password" - - jsoncodec "github.com/ava-labs/avalanchego/utils/json" -) - -const ( - // maxUserLen is the maximum allowed length of a username - maxUserLen = 1024 - - maxPackerSize = 1 << 30 // max size, in bytes, of something being marshalled by Marshal() - maxSliceLength = 1 << 18 - - codecVersion = 0 -) - -var ( - errEmptyUsername = errors.New("empty username") - errUserMaxLength = fmt.Errorf("username exceeds maximum length of %d chars", maxUserLen) ) -// KeyValuePair ... 
-type KeyValuePair struct { - Key []byte `serialize:"true"` - Value []byte `serialize:"true"` -} - -// UserDB describes the full content of a user -type UserDB struct { - password.Hash `serialize:"true"` - Data []KeyValuePair `serialize:"true"` +type service struct { + ks *keystore } -// Keystore is the RPC interface for keystore management -type Keystore struct { - lock sync.Mutex - log logging.Logger - codec codec.Manager +func (s *service) CreateUser(_ *http.Request, args *api.UserPass, reply *api.SuccessResponse) error { + s.ks.log.Info("Keystore: CreateUser called with %.*s", maxUserLen, args.Username) - // Key: username - // Value: The user with that name - users map[string]*password.Hash - - // Used to persist users and their data - userDB database.Database - bcDB database.Database - // BaseDB - // / \ - // UserDB BlockchainDB - // / | \ - // Usr Usr Usr - // / | \ - // BID BID BID -} - -// Initialize the keystore -func (ks *Keystore) Initialize(log logging.Logger, db database.Database) error { - c := linearcodec.New(reflectcodec.DefaultTagName, maxSliceLength) - manager := codec.NewManager(maxPackerSize) - if err := manager.RegisterCodec(codecVersion, c); err != nil { - return err - } - - ks.log = log - ks.codec = manager - ks.users = make(map[string]*password.Hash) - ks.userDB = prefixdb.New([]byte("users"), db) - ks.bcDB = prefixdb.New([]byte("bcs"), db) - return nil -} - -// CreateHandler returns a new service object that can send requests to thisAPI. 
-func (ks *Keystore) CreateHandler() (*common.HTTPHandler, error) { - newServer := rpc.NewServer() - codec := jsoncodec.NewCodec() - newServer.RegisterCodec(codec, "application/json") - newServer.RegisterCodec(codec, "application/json;charset=UTF-8") - if err := newServer.RegisterService(ks, "keystore"); err != nil { - return nil, err - } - return &common.HTTPHandler{LockOptions: common.NoLock, Handler: newServer}, nil -} - -// Get the user whose name is [username] -func (ks *Keystore) getUser(username string) (*password.Hash, error) { - // If the user is already in memory, return it - user, exists := ks.users[username] - if exists { - return user, nil - } - // The user is not in memory; try the database - userBytes, err := ks.userDB.Get([]byte(username)) - if err != nil { // Most likely bc user doesn't exist in database - return nil, err - } - - user = &password.Hash{} - _, err = ks.codec.Unmarshal(userBytes, user) - return user, err + reply.Success = true + return s.ks.CreateUser(args.Username, args.Password) } -// CreateUser creates an empty user with the provided username and password -func (ks *Keystore) CreateUser(_ *http.Request, args *api.UserPass, reply *api.SuccessResponse) error { - ks.log.Info("Keystore: CreateUser called with %.*s", maxUserLen, args.Username) - - ks.lock.Lock() - defer ks.lock.Unlock() - - if err := ks.AddUser(args.Username, args.Password); err != nil { - return err - } +func (s *service) DeleteUser(_ *http.Request, args *api.UserPass, reply *api.SuccessResponse) error { + s.ks.log.Info("Keystore: DeleteUser called with %s", args.Username) reply.Success = true - return nil + return s.ks.DeleteUser(args.Username, args.Password) } -// ListUsersReply is the reply from ListUsers type ListUsersReply struct { Users []string `json:"users"` } -// ListUsers lists all the registered usernames -func (ks *Keystore) ListUsers(_ *http.Request, args *struct{}, reply *ListUsersReply) error { - ks.log.Info("Keystore: ListUsers called") - - reply.Users 
= []string{} - - ks.lock.Lock() - defer ks.lock.Unlock() +func (s *service) ListUsers(_ *http.Request, args *struct{}, reply *ListUsersReply) error { + s.ks.log.Info("Keystore: ListUsers called") - it := ks.userDB.NewIterator() - defer it.Release() - for it.Next() { - reply.Users = append(reply.Users, string(it.Key())) - } - return it.Error() + var err error + reply.Users, err = s.ks.ListUsers() + return err } -// ExportUserArgs ... -type ExportUserArgs struct { - // The username and password - api.UserPass - // The encoding for the exported user ("hex" or "cb58") - Encoding formatting.Encoding `json:"encoding"` -} - -// ExportUserReply is the reply from ExportUser -type ExportUserReply struct { - // String representation of the user - User string `json:"user"` - // The encoding for the exported user ("hex" or "cb58") - Encoding formatting.Encoding `json:"encoding"` -} - -// ExportUser exports a serialized encoding of a user's information complete with encrypted database values -func (ks *Keystore) ExportUser(_ *http.Request, args *ExportUserArgs, reply *ExportUserReply) error { - ks.log.Info("Keystore: ExportUser called for %s", args.Username) - - ks.lock.Lock() - defer ks.lock.Unlock() - - user, err := ks.getUser(args.Username) - if err != nil { - return err - } - if !user.Check(args.Password) { - return fmt.Errorf("incorrect password for user %q", args.Username) - } - - userDB := prefixdb.New([]byte(args.Username), ks.bcDB) - - userData := UserDB{Hash: *user} - - it := userDB.NewIterator() - defer it.Release() - for it.Next() { - userData.Data = append(userData.Data, KeyValuePair{ - Key: it.Key(), - Value: it.Value(), - }) - } - if err := it.Error(); err != nil { - return err - } - - // Get byte representation of user - b, err := ks.codec.Marshal(codecVersion, &userData) - if err != nil { - return err - } - - // Encode the user from bytes to string - reply.User, err = formatting.Encode(args.Encoding, b) - if err != nil { - return fmt.Errorf("couldn't encode user 
to string: %w", err) - } - reply.Encoding = args.Encoding - return nil -} - -// ImportUserArgs are arguments for ImportUser type ImportUserArgs struct { // The username and password of the user being imported api.UserPass @@ -233,183 +50,46 @@ type ImportUserArgs struct { Encoding formatting.Encoding `json:"encoding"` } -// ImportUser imports a serialized encoding of a user's information complete with encrypted database values, -// integrity checks the password, and adds it to the database -func (ks *Keystore) ImportUser(r *http.Request, args *ImportUserArgs, reply *api.SuccessResponse) error { - ks.log.Info("Keystore: ImportUser called for %s", args.Username) - - if args.Username == "" { - return errEmptyUsername - } - - ks.lock.Lock() - defer ks.lock.Unlock() +func (s *service) ImportUser(r *http.Request, args *ImportUserArgs, reply *api.SuccessResponse) error { + s.ks.log.Info("Keystore: ImportUser called for %s", args.Username) // Decode the user from string to bytes - userBytes, err := formatting.Decode(args.Encoding, args.User) + user, err := formatting.Decode(args.Encoding, args.User) if err != nil { return fmt.Errorf("couldn't decode 'user' to bytes: %w", err) } - if usr, err := ks.getUser(args.Username); err == nil || usr != nil { - return fmt.Errorf("user already exists: %s", args.Username) - } - - userData := UserDB{} - if _, err := ks.codec.Unmarshal(userBytes, &userData); err != nil { - return err - } - if !userData.Hash.Check(args.Password) { - return fmt.Errorf("incorrect password for user %q", args.Username) - } - - usrBytes, err := ks.codec.Marshal(codecVersion, &userData.Hash) - if err != nil { - return err - } - - userBatch := ks.userDB.NewBatch() - if err := userBatch.Put([]byte(args.Username), usrBytes); err != nil { - return err - } - - userDataDB := prefixdb.New([]byte(args.Username), ks.bcDB) - dataBatch := userDataDB.NewBatch() - for _, kvp := range userData.Data { - if err := dataBatch.Put(kvp.Key, kvp.Value); err != nil { - return 
fmt.Errorf("error on database put: %w", err) - } - } - - if err := atomic.WriteAll(dataBatch, userBatch); err != nil { - return err - } - - ks.users[args.Username] = &userData.Hash - reply.Success = true - return nil + return s.ks.ImportUser(args.Username, args.Password, user) } -// DeleteUser deletes user with the provided username and password. -func (ks *Keystore) DeleteUser(_ *http.Request, args *api.UserPass, reply *api.SuccessResponse) error { - ks.log.Info("Keystore: DeleteUser called with %s", args.Username) - - if args.Username == "" { - return errEmptyUsername - } - - ks.lock.Lock() - defer ks.lock.Unlock() - - // check if user exists and valid user. - usr, err := ks.getUser(args.Username) - switch { - case err != nil || usr == nil: - return fmt.Errorf("user doesn't exist: %s", args.Username) - case !usr.Check(args.Password): - return fmt.Errorf("incorrect password for user %q", args.Username) - } - - userNameBytes := []byte(args.Username) - userBatch := ks.userDB.NewBatch() - if err := userBatch.Delete(userNameBytes); err != nil { - return err - } - - userDataDB := prefixdb.New(userNameBytes, ks.bcDB) - dataBatch := userDataDB.NewBatch() - - it := userDataDB.NewIterator() - defer it.Release() - - for it.Next() { - if err = dataBatch.Delete(it.Key()); err != nil { - return err - } - } - - if err = it.Error(); err != nil { - return err - } - - if err := atomic.WriteAll(dataBatch, userBatch); err != nil { - return err - } - - // delete from users map. - delete(ks.users, args.Username) - - reply.Success = true - return nil +type ExportUserArgs struct { + // The username and password + api.UserPass + // The encoding for the exported user ("hex" or "cb58") + Encoding formatting.Encoding `json:"encoding"` } -// NewBlockchainKeyStore ... 
-func (ks *Keystore) NewBlockchainKeyStore(blockchainID ids.ID) *BlockchainKeystore { - return &BlockchainKeystore{ - blockchainID: blockchainID, - ks: ks, - } +type ExportUserReply struct { + // String representation of the user + User string `json:"user"` + // The encoding for the exported user ("hex" or "cb58") + Encoding formatting.Encoding `json:"encoding"` } -// GetDatabase ... -func (ks *Keystore) GetDatabase(bID ids.ID, username, password string) (database.Database, error) { - ks.log.Info("Keystore: GetDatabase called with %s from %s", username, bID) - - ks.lock.Lock() - defer ks.lock.Unlock() +func (s *service) ExportUser(_ *http.Request, args *ExportUserArgs, reply *ExportUserReply) error { + s.ks.log.Info("Keystore: ExportUser called for %s", args.Username) - usr, err := ks.getUser(username) + userBytes, err := s.ks.ExportUser(args.Username, args.Password) if err != nil { - return nil, err - } - if !usr.Check(password) { - return nil, fmt.Errorf("incorrect password for user %q", username) - } - - userDB := prefixdb.New([]byte(username), ks.bcDB) - bcDB := prefixdb.NewNested(bID[:], userDB) - return encdb.New([]byte(password), bcDB) -} - -// AddUser attempts to register this username and password as a new user of the -// keystore. 
-func (ks *Keystore) AddUser(username, pword string) error { - if username == "" { - return errEmptyUsername - } - if len(username) > maxUserLen { - return errUserMaxLength - } - - if user, err := ks.getUser(username); err == nil || user != nil { - return fmt.Errorf("user already exists: %s", username) - } - - if err := password.IsValid(pword, password.OK); err != nil { return err } - user := &password.Hash{} - if err := user.Set(pword); err != nil { - return err - } - - userBytes, err := ks.codec.Marshal(codecVersion, user) + // Encode the user from bytes to string + reply.User, err = formatting.Encode(args.Encoding, userBytes) if err != nil { - return err - } - - if err := ks.userDB.Put([]byte(username), userBytes); err != nil { - return err + return fmt.Errorf("couldn't encode user to string: %w", err) } - ks.users[username] = user - + reply.Encoding = args.Encoding return nil } - -// CreateTestKeystore returns a new keystore that can be utilized for testing -func CreateTestKeystore() (*Keystore, error) { - ks := &Keystore{} - return ks, ks.Initialize(logging.NoLog{}, memdb.New()) -} diff --git a/api/keystore/service_test.go b/api/keystore/service_test.go index 69bcfde2ac52..93bba2c6e549 100644 --- a/api/keystore/service_test.go +++ b/api/keystore/service_test.go @@ -11,8 +11,10 @@ import ( "testing" "github.com/ava-labs/avalanchego/api" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/formatting" + "github.com/ava-labs/avalanchego/utils/logging" ) var ( @@ -22,13 +24,11 @@ var ( ) func TestServiceListNoUsers(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} reply := ListUsersReply{} - if err := ks.ListUsers(nil, nil, &reply); err != nil { + if err := s.ListUsers(nil, nil, &reply); err != nil { t.Fatal(err) } if len(reply.Users) != 0 { @@ -37,14 +37,12 @@ func 
TestServiceListNoUsers(t *testing.T) { } func TestServiceCreateUser(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} { reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{ + if err := s.CreateUser(nil, &api.UserPass{ Username: "bob", Password: strongPassword, }, &reply); err != nil { @@ -57,7 +55,7 @@ func TestServiceCreateUser(t *testing.T) { { reply := ListUsersReply{} - if err := ks.ListUsers(nil, nil, &reply); err != nil { + if err := s.ListUsers(nil, nil, &reply); err != nil { t.Fatal(err) } if len(reply.Users) != 1 { @@ -79,38 +77,36 @@ func genStr(n int) string { // TestServiceCreateUserArgsCheck generates excessively long usernames or // passwords to assure the sanity checks on string length are not exceeded func TestServiceCreateUserArgsCheck(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} { reply := api.SuccessResponse{} - err := ks.CreateUser(nil, &api.UserPass{ + err := s.CreateUser(nil, &api.UserPass{ Username: genStr(maxUserLen + 1), Password: strongPassword, }, &reply) - if reply.Success || err != errUserMaxLength { + if err != errUserMaxLength { t.Fatal("User was created when it should have been rejected due to too long a Username, err =", err) } } { reply := api.SuccessResponse{} - err := ks.CreateUser(nil, &api.UserPass{ + err := s.CreateUser(nil, &api.UserPass{ Username: "shortuser", Password: genStr(maxUserLen + 1), }, &reply) - if reply.Success || err == nil { + if err == nil { t.Fatal("User was created when it should have been rejected due to too long a Password, err =", err) } } { reply := ListUsersReply{} - if err := ks.ListUsers(nil, nil, &reply); err != nil { + if err := s.ListUsers(nil, nil, &reply); err != nil { t.Fatal(err) } @@ -123,14 +119,12 @@ func TestServiceCreateUserArgsCheck(t 
*testing.T) { // TestServiceCreateUserWeakPassword tests creating a new user with a weak // password to ensure the password strength check is working func TestServiceCreateUserWeakPassword(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} { reply := api.SuccessResponse{} - err := ks.CreateUser(nil, &api.UserPass{ + err := s.CreateUser(nil, &api.UserPass{ Username: "bob", Password: "weak", }, &reply) @@ -138,22 +132,16 @@ func TestServiceCreateUserWeakPassword(t *testing.T) { if err == nil { t.Error("Expected error when testing weak password") } - - if reply.Success { - t.Fatal("User was created when it should have been rejected due to weak password") - } } } func TestServiceCreateDuplicate(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} { reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{ + if err := s.CreateUser(nil, &api.UserPass{ Username: "bob", Password: strongPassword, }, &reply); err != nil { @@ -166,7 +154,7 @@ func TestServiceCreateDuplicate(t *testing.T) { { reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{ + if err := s.CreateUser(nil, &api.UserPass{ Username: "bob", Password: strongPassword, }, &reply); err == nil { @@ -176,13 +164,11 @@ func TestServiceCreateDuplicate(t *testing.T) { } func TestServiceCreateUserNoName(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{ + if err := s.CreateUser(nil, &api.UserPass{ Password: strongPassword, }, &reply); err == nil { t.Fatalf("Shouldn't have allowed empty username") @@ -190,14 +176,12 @@ func TestServiceCreateUserNoName(t *testing.T) { } func 
TestServiceUseBlockchainDB(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} { reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{ + if err := s.CreateUser(nil, &api.UserPass{ Username: "bob", Password: strongPassword, }, &reply); err != nil { @@ -234,14 +218,12 @@ func TestServiceUseBlockchainDB(t *testing.T) { func TestServiceExportImport(t *testing.T) { encodings := []formatting.Encoding{formatting.Hex, formatting.CB58} for _, encoding := range encodings { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} { reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{ + if err := s.CreateUser(nil, &api.UserPass{ Username: "bob", Password: strongPassword, }, &reply); err != nil { @@ -270,18 +252,16 @@ func TestServiceExportImport(t *testing.T) { Encoding: encoding, } exportReply := ExportUserReply{} - if err := ks.ExportUser(nil, &exportArgs, &exportReply); err != nil { + if err := s.ExportUser(nil, &exportArgs, &exportReply); err != nil { t.Fatal(err) } - newKS, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + newKS := New(logging.NoLog{}, memdb.New()).(*keystore) + newS := service{ks: newKS} { reply := api.SuccessResponse{} - if err := newKS.ImportUser(nil, &ImportUserArgs{ + if err := newS.ImportUser(nil, &ImportUserArgs{ UserPass: api.UserPass{ Username: "bob", Password: "", @@ -294,7 +274,7 @@ func TestServiceExportImport(t *testing.T) { { reply := api.SuccessResponse{} - if err := newKS.ImportUser(nil, &ImportUserArgs{ + if err := newS.ImportUser(nil, &ImportUserArgs{ UserPass: api.UserPass{ Username: "", Password: "strongPassword", @@ -307,7 +287,7 @@ func TestServiceExportImport(t *testing.T) { { reply := api.SuccessResponse{} - if err := newKS.ImportUser(nil, &ImportUserArgs{ + if err := 
newS.ImportUser(nil, &ImportUserArgs{ UserPass: api.UserPass{ Username: "bob", Password: strongPassword, @@ -341,7 +321,7 @@ func TestServiceDeleteUser(t *testing.T) { password := "passwTest@fake01ord" tests := []struct { desc string - setup func(ks *Keystore) error + setup func(ks *keystore) error request *api.UserPass want *api.SuccessResponse wantError bool @@ -359,17 +339,19 @@ func TestServiceDeleteUser(t *testing.T) { wantError: true, }, { desc: "user exists and valid password case", - setup: func(ks *Keystore) error { - return ks.CreateUser(nil, &api.UserPass{Username: testUser, Password: password}, &api.SuccessResponse{}) + setup: func(ks *keystore) error { + s := service{ks: ks} + return s.CreateUser(nil, &api.UserPass{Username: testUser, Password: password}, &api.SuccessResponse{}) }, request: &api.UserPass{Username: testUser, Password: password}, want: &api.SuccessResponse{Success: true}, }, { desc: "delete a user, imported from import api case", - setup: func(ks *Keystore) error { + setup: func(ks *keystore) error { + s := service{ks: ks} reply := api.SuccessResponse{} - if err := ks.CreateUser(nil, &api.UserPass{Username: testUser, Password: password}, &reply); err != nil { + if err := s.CreateUser(nil, &api.UserPass{Username: testUser, Password: password}, &reply); err != nil { return err } @@ -390,10 +372,8 @@ func TestServiceDeleteUser(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - ks, err := CreateTestKeystore() - if err != nil { - t.Fatal(err) - } + ks := New(logging.NoLog{}, memdb.New()).(*keystore) + s := service{ks: ks} if tt.setup != nil { if err := tt.setup(ks); err != nil { @@ -401,7 +381,7 @@ func TestServiceDeleteUser(t *testing.T) { } } got := &api.SuccessResponse{} - err = ks.DeleteUser(nil, tt.request, got) + err := s.DeleteUser(nil, tt.request, got) if (err != nil) != tt.wantError { t.Fatalf("DeleteUser() failed: error %v, wantError %v", err, tt.wantError) } @@ -411,12 +391,12 @@ func 
TestServiceDeleteUser(t *testing.T) { } if err == nil && got.Success { // delete is successful - if _, ok := ks.users[testUser]; ok { + if _, ok := ks.usernameToPassword[testUser]; ok { t.Fatalf("DeleteUser() failed: expected the user %s should be delete from users map", testUser) } // deleted user details should be available to create user again. - if err = ks.CreateUser(nil, &api.UserPass{Username: testUser, Password: password}, &api.SuccessResponse{}); err != nil { + if err = s.CreateUser(nil, &api.UserPass{Username: testUser, Password: password}, &api.SuccessResponse{}); err != nil { t.Fatalf("failed to create user: %v", err) } } diff --git a/utils/json/pubsub_server.go b/api/pubsub/pubsub.go similarity index 89% rename from utils/json/pubsub_server.go rename to api/pubsub/pubsub.go index 22431861bf11..51405fa6149e 100644 --- a/utils/json/pubsub_server.go +++ b/api/pubsub/pubsub.go @@ -1,7 +1,7 @@ // (c) 2019-2020, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package json +package pubsub import ( "errors" @@ -47,8 +47,8 @@ var ( errDuplicateChannel = errors.New("duplicate channel") ) -// PubSubServer maintains the set of active clients and sends messages to the clients. -type PubSubServer struct { +// Server maintains the set of active clients and sends messages to the clients. +type Server struct { ctx *snow.Context lock sync.Mutex @@ -56,16 +56,16 @@ type PubSubServer struct { channels map[string]map[*Connection]struct{} } -// NewPubSubServer ... -func NewPubSubServer(ctx *snow.Context) *PubSubServer { - return &PubSubServer{ +// NewServer ... 
+func NewServer(ctx *snow.Context) *Server { + return &Server{ ctx: ctx, conns: make(map[*Connection]map[string]struct{}), channels: make(map[string]map[*Connection]struct{}), } } -func (s *PubSubServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { wsConn, err := upgrader.Upgrade(w, r, nil) if err != nil { s.ctx.Log.Debug("Failed to upgrade %s", err) @@ -76,7 +76,7 @@ func (s *PubSubServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Publish ... -func (s *PubSubServer) Publish(channel string, msg interface{}) { +func (s *Server) Publish(channel string, msg interface{}) { s.lock.Lock() defer s.lock.Unlock() @@ -101,7 +101,7 @@ func (s *PubSubServer) Publish(channel string, msg interface{}) { } // Register ... -func (s *PubSubServer) Register(channel string) error { +func (s *Server) Register(channel string) error { s.lock.Lock() defer s.lock.Unlock() @@ -113,7 +113,7 @@ func (s *PubSubServer) Register(channel string) error { return nil } -func (s *PubSubServer) addConnection(conn *Connection) { +func (s *Server) addConnection(conn *Connection) { s.lock.Lock() defer s.lock.Unlock() s.conns[conn] = make(map[string]struct{}) @@ -122,7 +122,7 @@ func (s *PubSubServer) addConnection(conn *Connection) { go conn.readPump() } -func (s *PubSubServer) removeConnection(conn *Connection) { +func (s *Server) removeConnection(conn *Connection) { s.lock.Lock() defer s.lock.Unlock() @@ -137,7 +137,7 @@ func (s *PubSubServer) removeConnection(conn *Connection) { } } -func (s *PubSubServer) addChannel(conn *Connection, channel string) { +func (s *Server) addChannel(conn *Connection, channel string) { s.lock.Lock() defer s.lock.Unlock() @@ -155,7 +155,7 @@ func (s *PubSubServer) addChannel(conn *Connection, channel string) { conns[conn] = struct{}{} } -func (s *PubSubServer) removeChannel(conn *Connection, channel string) { +func (s *Server) removeChannel(conn *Connection, channel string) { 
s.lock.Lock() defer s.lock.Unlock() @@ -185,7 +185,7 @@ type subscribe struct { // Connection is a representation of the websocket connection. type Connection struct { - s *PubSubServer + s *Server // The websocket connection. conn *websocket.Conn diff --git a/api/middleware_handler.go b/api/server/middleware_handler.go similarity index 96% rename from api/middleware_handler.go rename to api/server/middleware_handler.go index fba05635a0c2..5d1ece740d5c 100644 --- a/api/middleware_handler.go +++ b/api/server/middleware_handler.go @@ -1,7 +1,7 @@ // (c) 2019-2020, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package api +package server import ( "net/http" diff --git a/api/router.go b/api/server/router.go similarity index 99% rename from api/router.go rename to api/server/router.go index 5fda13b946f6..5bab97347f0c 100644 --- a/api/router.go +++ b/api/server/router.go @@ -1,7 +1,7 @@ // (c) 2019-2020, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package api +package server import ( "errors" diff --git a/api/router_test.go b/api/server/router_test.go similarity index 99% rename from api/router_test.go rename to api/server/router_test.go index e8decd66a5ea..a62703e837a2 100644 --- a/api/router_test.go +++ b/api/server/router_test.go @@ -1,7 +1,7 @@ // (c) 2019-2020, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package api +package server import ( "net/http" diff --git a/api/server.go b/api/server/server.go similarity index 83% rename from api/server.go rename to api/server/server.go index 7dccf9c2730c..9c3729076776 100644 --- a/api/server.go +++ b/api/server/server.go @@ -1,7 +1,7 @@ // (c) 2019-2020, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. 
-package api +package server import ( "context" @@ -18,7 +18,6 @@ import ( "github.com/rs/cors" - "github.com/ava-labs/avalanchego/api/auth" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/logging" @@ -33,6 +32,10 @@ var ( errUnknownLockOption = errors.New("invalid lock options") ) +type RouteAdder interface { + AddRoute(handler *common.HTTPHandler, lock *sync.RWMutex, base, endpoint string, loggingWriter io.Writer) error +} + // Server maintains the HTTP router type Server struct { // log this server writes to @@ -45,9 +48,6 @@ type Server struct { handler http.Handler // Listens for HTTP traffic on this address listenAddress string - // Handles authorization. Must be non-nil after initialization, even if - // token authorization is off. - auth *auth.Auth // http server srv *http.Server @@ -59,37 +59,24 @@ func (s *Server) Initialize( factory logging.Factory, host string, port uint16, - authEnabled bool, - authPassword string, allowedOrigins []string, -) error { + wrappers ...Wrapper, +) { s.log = log s.factory = factory s.listenAddress = fmt.Sprintf("%s:%d", host, port) s.router = newRouter() - a, err := auth.New(authEnabled, authPassword) - if err != nil { - return err - } - s.auth = a - s.log.Info("API created with allowed origins: %v", allowedOrigins) corsWrapper := cors.New(cors.Options{ - AllowedOrigins: allowedOrigins, + AllowedOrigins: allowedOrigins, + AllowCredentials: true, }) - corsHandler := corsWrapper.Handler(s.router) - s.handler = s.auth.WrapHandler(corsHandler) + s.handler = corsWrapper.Handler(s.router) - if !authEnabled { - return nil + for _, wrapper := range wrappers { + s.handler = wrapper.WrapHandler(s.handler) } - - // only create auth service if token authorization is required - s.log.Info("API authorization is enabled. 
Auth tokens must be passed in the header of API requests, except requests to the auth service.") - authService := auth.NewService(s.log, s.auth) - return s.AddRoute(authService, &sync.RWMutex{}, auth.Endpoint, "", s.log) - } // Dispatch starts the API server @@ -113,25 +100,34 @@ func (s *Server) DispatchTLS(certFile, keyFile string) error { return http.ServeTLS(listener, s.handler, certFile, keyFile) } -// RegisterChain registers the API endpoints associated with this chain That is, -// add pairs to server so that http calls can be made to the vm -func (s *Server) RegisterChain(chainName string, ctx *snow.Context, vmIntf interface{}) { - vm, ok := vmIntf.(common.VM) - if !ok { - return - } +// RegisterChain registers the API endpoints associated with this chain. That is, +// add pairs to server so that API calls can be made to the VM. +// This method runs in a goroutine to avoid a deadlock in the event that the caller +// holds the engine's context lock. Namely, this could happen when the P-Chain is +// creating a new chain and holds the P-Chain's lock when this function is held, +// and at the same time the server's lock is held due to an API call and is trying +// to grab the P-Chain's lock. 
+func (s *Server) RegisterChain(chainName string, ctx *snow.Context, engine common.Engine) { + go s.registerChain(chainName, ctx, engine) +} + +func (s *Server) registerChain(chainName string, ctx *snow.Context, engine common.Engine) { + var ( + handlers map[string]*common.HTTPHandler + err error + ) ctx.Lock.Lock() - handlers, err := vm.CreateHandlers() + handlers, err = engine.GetVM().CreateHandlers() ctx.Lock.Unlock() if err != nil { - s.log.Error("Failed to create %s handlers: %s", chainName, err) + s.log.Error("failed to create %s handlers: %s", chainName, err) return } httpLogger, err := s.factory.MakeChain(chainName, "http") if err != nil { - s.log.Error("Failed to create new http logger: %s", err) + s.log.Error("failed to create new http logger: %s", err) return } @@ -140,7 +136,7 @@ func (s *Server) RegisterChain(chainName string, ctx *snow.Context, vmIntf inter defaultEndpoint := "bc/" + ctx.ChainID.String() // Register each endpoint - for extension, service := range handlers { + for extension, handler := range handlers { // Validate that the route being added is valid // e.g. "/foo" and "" are ok but "\n" is not _, err := url.ParseRequestURI(extension) @@ -148,7 +144,7 @@ func (s *Server) RegisterChain(chainName string, ctx *snow.Context, vmIntf inter s.log.Error("could not add route to chain's API handler because route is malformed: %s", err) continue } - if err := s.AddChainRoute(service, ctx, defaultEndpoint, extension, httpLogger); err != nil { + if err := s.AddChainRoute(handler, ctx, defaultEndpoint, extension, httpLogger); err != nil { s.log.Error("error adding route: %s", err) } } diff --git a/api/server_test.go b/api/server/server_test.go similarity index 93% rename from api/server_test.go rename to api/server/server_test.go index 0f0f70a062ea..5582dca47aab 100644 --- a/api/server_test.go +++ b/api/server/server_test.go @@ -1,7 +1,7 @@ // (c) 2019-2020, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. 
-package api +package server import ( "bytes" @@ -30,18 +30,13 @@ func (s *Service) Call(_ *http.Request, args *Args, reply *Reply) error { func TestCall(t *testing.T) { s := Server{} - err := s.Initialize( + s.Initialize( logging.NoLog{}, logging.NoFactory{}, "localhost", 8080, - false, - "", []string{"*"}, ) - if err != nil { - t.Fatal(err) - } serv := &Service{} newServer := rpc.NewServer() @@ -51,7 +46,7 @@ func TestCall(t *testing.T) { t.Fatal(err) } - err = s.AddRoute( + err := s.AddRoute( &common.HTTPHandler{Handler: newServer}, new(sync.RWMutex), "vm/lol", diff --git a/api/server/wrapper.go b/api/server/wrapper.go new file mode 100644 index 000000000000..bef6b91f07e1 --- /dev/null +++ b/api/server/wrapper.go @@ -0,0 +1,13 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package server + +import ( + "net/http" +) + +type Wrapper interface { + // WrapHandler wraps an http.Handler. + WrapHandler(h http.Handler) http.Handler +} diff --git a/chains/atomic/shared_memory.go b/chains/atomic/shared_memory.go index 699ef0ec9584..d80bcc7b928e 100644 --- a/chains/atomic/shared_memory.go +++ b/chains/atomic/shared_memory.go @@ -317,29 +317,37 @@ func (s *state) getKeys(traits [][]byte, startTrait, startKey []byte, limit int) } lastTrait = trait - lastKey = startKey + var err error + lastKey, err = s.appendTraitKeys(&keys, &tracked, &limit, trait, startKey) + if err != nil { + return nil, nil, nil, err + } - traitDB := prefixdb.New(trait, s.indexDB) - iter := traitDB.NewIteratorWithStart(startKey) - for iter.Next() { - if limit == 0 { - iter.Release() - return keys, lastTrait, lastKey, nil - } - - key := iter.Key() - lastKey = key - - id := hashing.ComputeHash256Array(key) - if tracked.Contains(id) { - continue - } - - tracked.Add(id) - keys = append(keys, key) - limit-- + if limit == 0 { + break } - iter.Release() } return keys, lastTrait, lastKey, nil } + +func (s *state) appendTraitKeys(keys *[][]byte, tracked 
*ids.Set, limit *int, trait, startKey []byte) ([]byte, error) { + lastKey := startKey + + traitDB := prefixdb.New(trait, s.indexDB) + iter := traitDB.NewIteratorWithStart(startKey) + defer iter.Release() + for iter.Next() && *limit > 0 { + key := iter.Key() + lastKey = key + + id := hashing.ComputeHash256Array(key) + if tracked.Contains(id) { + continue + } + + tracked.Add(id) + *keys = append(*keys, key) + *limit-- + } + return lastKey, iter.Error() +} diff --git a/chains/manager.go b/chains/manager.go index 670434e2f260..3364b4d22a54 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -9,9 +9,9 @@ import ( "sync" "time" - "github.com/ava-labs/avalanchego/api" "github.com/ava-labs/avalanchego/api/health" "github.com/ava-labs/avalanchego/api/keystore" + "github.com/ava-labs/avalanchego/api/server" "github.com/ava-labs/avalanchego/chains/atomic" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/meterdb" @@ -65,7 +65,7 @@ type Manager interface { ForceCreateChain(ChainParameters) // Add a registrant [r]. Every time a chain is - // created, [r].RegisterChain([new chain]) is called + // created, [r].RegisterChain([new chain]) is called. 
AddRegistrant(Registrant) // Given an alias, return the ID of the chain associated with that alias @@ -130,8 +130,8 @@ type ManagerConfig struct { Validators validators.Manager // Validators validating on this chain NodeID ids.ShortID // The ID of this node NetworkID uint32 // ID of the network this node is connected to - Server *api.Server // Handles HTTP API calls - Keystore *keystore.Keystore + Server *server.Server // Handles HTTP API calls + Keystore keystore.Keystore AtomicMemory *atomic.Memory AVAXAssetID ids.ID XChainID ids.ID @@ -149,7 +149,8 @@ type manager struct { ids.Aliaser ManagerConfig - registrants []Registrant // Those notified when a chain is created + // Those notified when a chain is created + registrants []Registrant unblocked bool blockedChains []ChainParameters @@ -240,7 +241,20 @@ func (m *manager) ForceCreateChain(chainParams ChainParameters) { m.Log.AssertNoError(m.Alias(chainParams.ID, chainParams.ID.String())) // Notify those that registered to be notified when a new chain is created - m.notifyRegistrants(chain.Name, chain.Ctx, chain.VM) + m.notifyRegistrants(chain.Name, chain.Ctx, chain.Engine) + + // Tell the chain to start processing messages. 
+ // If the X or P Chain panics, do not attempt to recover + if m.CriticalChains.Contains(chainParams.ID) { + go chain.Ctx.Log.RecoverAndPanic(chain.Handler.Dispatch) + } else { + go chain.Ctx.Log.RecoverAndExit(chain.Handler.Dispatch, func() { + chain.Ctx.Log.Error("Chain with ID: %s was shutdown due to a panic", chainParams.ID) + }) + } + + // Allows messages to be routed to the new chain + m.ManagerConfig.Router.AddChain(chain.Handler) } // Create a chain @@ -382,17 +396,6 @@ func (m *manager) buildChain(chainParams ChainParameters, sb Subnet) (*chain, er return nil, err } - // Allows messages to be routed to the new chain - m.ManagerConfig.Router.AddChain(chain.Handler) - - // If the X or P Chain panics, do not attempt to recover - if m.CriticalChains.Contains(chainParams.ID) { - go ctx.Log.RecoverAndPanic(chain.Handler.Dispatch) - } else { - go ctx.Log.RecoverAndExit(chain.Handler.Dispatch, func() { - ctx.Log.Error("Chain with ID: %s was shutdown due to a panic", chainParams.ID) - }) - } return chain, nil } @@ -514,7 +517,7 @@ func (m *manager) createAvalancheChain( // Asynchronously passes messages from the network to the consensus engine handler := &router.Handler{} - handler.Initialize( + err = handler.Initialize( engine, validators, msgChan, @@ -533,7 +536,7 @@ func (m *manager) createAvalancheChain( Handler: handler, VM: vm, Ctx: ctx, - }, nil + }, err } // Create a linear chain using the Snowman consensus engine @@ -623,7 +626,7 @@ func (m *manager) createSnowmanChain( // Asynchronously passes messages from the network to the consensus engine handler := &router.Handler{} - handler.Initialize( + err = handler.Initialize( engine, validators, msgChan, @@ -635,6 +638,9 @@ func (m *manager) createSnowmanChain( consensusParams.Metrics, delay, ) + if err != nil { + return nil, fmt.Errorf("couldn't initialize message handler: %s", err) + } // Register health checks chainAlias, err := m.PrimaryAlias(ctx.ChainID) @@ -693,9 +699,9 @@ func (m *manager) 
LookupVM(alias string) (ids.ID, error) { return m.VMManager.Lo // Notify registrants [those who want to know about the creation of chains] // that the specified chain has been created -func (m *manager) notifyRegistrants(name string, ctx *snow.Context, vm interface{}) { +func (m *manager) notifyRegistrants(name string, ctx *snow.Context, engine common.Engine) { for _, registrant := range m.registrants { - go registrant.RegisterChain(name, ctx, vm) + registrant.RegisterChain(name, ctx, engine) } } diff --git a/chains/registrant.go b/chains/registrant.go index 7ede8a4b770e..9ccfc3fbded5 100644 --- a/chains/registrant.go +++ b/chains/registrant.go @@ -5,9 +5,13 @@ package chains import ( "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/snow/engine/common" ) // Registrant can register the existence of a chain type Registrant interface { - RegisterChain(name string, ctx *snow.Context, vm interface{}) + // Called when the chain described by [ctx] and [engine] is created + // This function is called before the chain starts processing messages + // [engine] should be an avalanche.Engine or snowman.Engine + RegisterChain(name string, ctx *snow.Context, engine common.Engine) } diff --git a/database/rpcdb/db_client.go b/database/rpcdb/db_client.go index f2a38e0f852a..b33c8cd1755b 100644 --- a/database/rpcdb/db_client.go +++ b/database/rpcdb/db_client.go @@ -261,8 +261,12 @@ func (it *iterator) Value() []byte { return it.value } // Release frees any resources held by the iterator func (it *iterator) Release() { - _, err := it.db.client.IteratorRelease(context.Background(), &rpcdbproto.IteratorReleaseRequest{ + resp, err := it.db.client.IteratorRelease(context.Background(), &rpcdbproto.IteratorReleaseRequest{ Id: it.id, }) - it.errs.Add(err) + if err != nil { + it.errs.Add(err) + } else { + it.errs.Add(errCodeToError[resp.Err]) + } } diff --git a/database/rpcdb/db_server.go b/database/rpcdb/db_server.go index 4a30eea5a7d6..2cbea8bd5670 100644 --- 
a/database/rpcdb/db_server.go +++ b/database/rpcdb/db_server.go @@ -162,9 +162,12 @@ func (db *DatabaseServer) IteratorRelease(_ context.Context, req *rpcdbproto.Ite defer db.lock.Unlock() it, exists := db.iterators[req.Id] - if exists { - delete(db.iterators, req.Id) - it.Release() + if !exists { + return &rpcdbproto.IteratorReleaseResponse{Err: 0}, nil } - return &rpcdbproto.IteratorReleaseResponse{}, nil + + delete(db.iterators, req.Id) + err := it.Error() + it.Release() + return &rpcdbproto.IteratorReleaseResponse{Err: errorToErrCode[err]}, errorToRPCError(err) } diff --git a/database/rpcdb/rpcdbproto/rpcdb.pb.go b/database/rpcdb/rpcdbproto/rpcdb.pb.go index 785765bd7694..73ff7b7f05a7 100644 --- a/database/rpcdb/rpcdbproto/rpcdb.pb.go +++ b/database/rpcdb/rpcdbproto/rpcdb.pb.go @@ -1017,6 +1017,7 @@ func (m *IteratorReleaseRequest) GetId() uint64 { } type IteratorReleaseResponse struct { + Err uint32 `protobuf:"varint,1,opt,name=err,proto3" json:"err,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1047,6 +1048,13 @@ func (m *IteratorReleaseResponse) XXX_DiscardUnknown() { var xxx_messageInfo_IteratorReleaseResponse proto.InternalMessageInfo +func (m *IteratorReleaseResponse) GetErr() uint32 { + if m != nil { + return m.Err + } + return 0 +} + func init() { proto.RegisterType((*HasRequest)(nil), "rpcdbproto.HasRequest") proto.RegisterType((*HasResponse)(nil), "rpcdbproto.HasResponse") @@ -1079,49 +1087,49 @@ func init() { proto.RegisterFile("rpcdb.proto", fileDescriptor_af52f4b90339c3f4) var fileDescriptor_af52f4b90339c3f4 = []byte{ // 677 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x54, 0xdd, 0x4f, 0x13, 0x4f, - 0x14, 0x4d, 0x3f, 0xf8, 0x3a, 0x2d, 0xe5, 0xf7, 0x1b, 0x2b, 0x94, 0x55, 0xbe, 0x16, 0x21, 0xc5, - 0x07, 0x22, 0x1f, 0x62, 0x4c, 0x48, 0x8c, 0x80, 0x01, 0x63, 0x42, 0x70, 0x21, 0x21, 0x31, 0xbe, - 0x0c, 0x74, 0x08, 
0x1b, 0x0b, 0xbb, 0xce, 0x4e, 0x15, 0xdf, 0x7d, 0xf1, 0xbf, 0x36, 0x33, 0x9d, - 0xdd, 0x99, 0x69, 0x77, 0xab, 0xbe, 0xcd, 0x9d, 0x7b, 0xce, 0x99, 0xbb, 0x77, 0xef, 0xb9, 0xa8, - 0xf1, 0xf8, 0xba, 0x73, 0xb5, 0x19, 0xf3, 0x48, 0x44, 0x04, 0x2a, 0x50, 0x67, 0x7f, 0x11, 0x38, - 0xa1, 0x49, 0xc0, 0xbe, 0xf6, 0x58, 0x22, 0xc8, 0x7f, 0xa8, 0x7c, 0x61, 0x3f, 0x5a, 0xa5, 0xe5, - 0x52, 0xbb, 0x1e, 0xc8, 0xa3, 0xbf, 0x85, 0x9a, 0xca, 0x27, 0x71, 0x74, 0x9f, 0x30, 0x09, 0xb8, - 0xa5, 0x89, 0x02, 0x4c, 0x06, 0xf2, 0x28, 0x6f, 0x18, 0xe7, 0xad, 0xf2, 0x72, 0xa9, 0x3d, 0x1d, - 0xc8, 0xa3, 0x94, 0x3c, 0x66, 0xa2, 0x58, 0xf2, 0x25, 0x6a, 0x2a, 0xaf, 0x25, 0x9b, 0x18, 0xfb, - 0x46, 0xbb, 0x3d, 0xa6, 0x21, 0xfd, 0x20, 0x47, 0x76, 0x17, 0x38, 0xeb, 0x15, 0xcb, 0x1a, 0x9d, - 0xb2, 0xa5, 0xe3, 0x2f, 0xa1, 0xa6, 0x58, 0xa6, 0x7e, 0x29, 0x5b, 0x32, 0xb2, 0x2b, 0x98, 0x3e, - 0x62, 0x5d, 0x26, 0x58, 0x71, 0xc1, 0x3e, 0x1a, 0x29, 0xa4, 0x50, 0x66, 0x03, 0xb5, 0x73, 0x41, - 0xb3, 0xf2, 0x3c, 0x4c, 0xc6, 0x3c, 0x8a, 0x19, 0x17, 0x7d, 0xa5, 0xa9, 0x20, 0x8b, 0xfd, 0x5d, - 0xd4, 0xfb, 0x50, 0x2d, 0x46, 0x50, 0x4d, 0x04, 0x15, 0x1a, 0xa7, 0xce, 0x39, 0x9f, 0xbf, 0x8f, - 0xc6, 0x61, 0x74, 0x17, 0xd3, 0xeb, 0xec, 0x8d, 0x26, 0xc6, 0x12, 0x41, 0xb9, 0x48, 0x1b, 0xa7, - 0x02, 0x79, 0xdb, 0x0d, 0xef, 0x42, 0x91, 0xb6, 0x41, 0x05, 0xfe, 0x2a, 0x66, 0x32, 0x76, 0xe1, - 0x37, 0x34, 0x50, 0x3f, 0xec, 0x46, 0x49, 0xda, 0x09, 0xd9, 0x1a, 0x1d, 0x17, 0x52, 0x04, 0xfe, - 0xbf, 0xe4, 0xa1, 0x60, 0x07, 0x54, 0x5c, 0xdf, 0xa6, 0x85, 0x3d, 0x47, 0x35, 0xee, 0x09, 0x39, - 0x25, 0x95, 0x76, 0x6d, 0x7b, 0x76, 0xd3, 0x8c, 0xdb, 0xa6, 0xf9, 0x83, 0x81, 0xc2, 0x90, 0x1d, - 0x4c, 0x74, 0x54, 0x6f, 0x93, 0x56, 0x59, 0xc1, 0xe7, 0x6d, 0xb8, 0xf3, 0x67, 0x82, 0x14, 0xe9, - 0xaf, 0x83, 0xd8, 0xaf, 0x16, 0x56, 0xd7, 0x04, 0x39, 0x65, 0xdf, 0xdf, 0x0b, 0xc6, 0xa9, 0x88, - 0x78, 0xfa, 0x59, 0x17, 0x78, 0x66, 0xdd, 0x5e, 0x86, 0xe2, 0xf6, 0x5c, 0x76, 0xee, 0xed, 0x7d, - 0xe7, 0x8c, 0xb3, 0x9b, 0xf0, 0x61, 0x74, 0x7f, 0x67, 
0x31, 0x1e, 0x2b, 0x98, 0x6e, 0xb0, 0x8e, - 0xfc, 0x57, 0x58, 0xfb, 0x83, 0xaa, 0x2e, 0xb3, 0x81, 0x72, 0xd8, 0x51, 0x9a, 0xd5, 0xa0, 0x1c, - 0x76, 0xfc, 0x35, 0x3c, 0x4a, 0x59, 0xa7, 0xec, 0x21, 0xfb, 0xbb, 0x83, 0xb0, 0xcf, 0x68, 0xba, - 0x30, 0x2d, 0xf7, 0x14, 0x53, 0x37, 0x51, 0xef, 0xbe, 0x23, 0x2f, 0xb5, 0x2f, 0xcd, 0x45, 0x3a, - 0xcc, 0xe5, 0x1c, 0x9b, 0x54, 0x6c, 0x9b, 0xac, 0x1b, 0xf5, 0x77, 0x9c, 0x67, 0xbd, 0x1a, 0xaa, - 0x62, 0x03, 0x8f, 0x07, 0x70, 0x85, 0xcd, 0x6f, 0x63, 0xd6, 0x74, 0xbe, 0xcb, 0x68, 0x36, 0x57, - 0x43, 0xa2, 0xf3, 0x98, 0x1b, 0x42, 0xf6, 0x65, 0xb7, 0x7f, 0x4d, 0x60, 0xf2, 0x88, 0x0a, 0x7a, - 0x45, 0x13, 0x46, 0xf6, 0x50, 0x39, 0xa1, 0x09, 0x71, 0x06, 0xca, 0x2c, 0x2f, 0x6f, 0x6e, 0xe8, - 0x5e, 0xd7, 0xb6, 0x87, 0xca, 0x31, 0x13, 0x2e, 0xcf, 0x6c, 0x28, 0x97, 0x67, 0x6f, 0xa6, 0x3d, - 0x54, 0xce, 0x7a, 0x03, 0x3c, 0x33, 0xc0, 0x2e, 0xcf, 0x5e, 0x32, 0x6f, 0x30, 0xde, 0x1f, 0x5c, - 0x52, 0x3c, 0xcc, 0x9e, 0x97, 0x97, 0xd2, 0x02, 0xaf, 0x51, 0x95, 0x1b, 0x82, 0x38, 0x2f, 0x58, - 0xeb, 0xc5, 0x6b, 0x0d, 0x27, 0x34, 0xf5, 0x00, 0x13, 0xda, 0xe8, 0xc4, 0x79, 0xc1, 0xdd, 0x1d, - 0xde, 0x93, 0xdc, 0x9c, 0xd6, 0xd8, 0xc7, 0x98, 0xf2, 0x3d, 0x71, 0x9e, 0xb1, 0x57, 0x83, 0x37, - 0x9f, 0x93, 0xd1, 0xec, 0x0f, 0x80, 0x31, 0x27, 0x59, 0xb0, 0x81, 0x43, 0xab, 0xc2, 0x5b, 0x2c, - 0x4a, 0x6b, 0xb1, 0x9f, 0x25, 0x2c, 0x8c, 0xb4, 0x15, 0x79, 0x61, 0x2b, 0xfc, 0x8d, 0xaf, 0xbd, - 0xad, 0x7f, 0x60, 0xe8, 0x32, 0x3e, 0xa2, 0x6e, 0x9b, 0x8f, 0x2c, 0xd9, 0x12, 0x39, 0xee, 0xf5, - 0x96, 0x8b, 0x01, 0x5a, 0xf2, 0x02, 0xd3, 0x8e, 0x93, 0x48, 0x2e, 0xc5, 0x36, 0xa3, 0xb7, 0x32, - 0x02, 0xa1, 0x55, 0x3f, 0x61, 0x66, 0xc0, 0x4a, 0xc4, 0xcf, 0x63, 0xb9, 0x8e, 0xf4, 0x56, 0x47, - 0x62, 0xfa, 0xda, 0x57, 0xe3, 0x2a, 0xbd, 0xf3, 0x3b, 0x00, 0x00, 0xff, 0xff, 0x78, 0xfe, 0x0b, - 0x6b, 0x4c, 0x08, 0x00, 0x00, + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x94, 0xed, 0x4f, 0xd4, 0x40, + 0x10, 0xc6, 0x73, 0x2f, 0xbc, 0x3d, 0x77, 0x1c, 0xba, 0x9e, 
0x70, 0x54, 0x79, 0x2b, 0x42, 0x0e, + 0x4d, 0x88, 0xbc, 0x88, 0x31, 0x21, 0x31, 0x02, 0x06, 0x8c, 0x09, 0xc1, 0x42, 0x42, 0x62, 0xfc, + 0xb2, 0x70, 0x4b, 0x68, 0x3c, 0xae, 0x75, 0xbb, 0xa7, 0xf8, 0xdd, 0x2f, 0xfe, 0xd7, 0x66, 0xf7, + 0xb6, 0xed, 0xf6, 0xda, 0x3d, 0xf5, 0xdb, 0xce, 0xce, 0x33, 0xbf, 0x9d, 0x4e, 0x67, 0x06, 0x35, + 0x1e, 0x5e, 0x77, 0xae, 0x36, 0x43, 0x1e, 0x88, 0x80, 0x40, 0x19, 0xea, 0xec, 0x2e, 0x02, 0x27, + 0x34, 0xf2, 0xd8, 0xb7, 0x3e, 0x8b, 0x04, 0x79, 0x80, 0xca, 0x57, 0xf6, 0xb3, 0x55, 0x5a, 0x2e, + 0xb5, 0xeb, 0x9e, 0x3c, 0xba, 0x5b, 0xa8, 0x29, 0x7f, 0x14, 0x06, 0xbd, 0x88, 0x49, 0xc1, 0x2d, + 0x8d, 0x94, 0x60, 0xd2, 0x93, 0x47, 0x79, 0xc3, 0x38, 0x6f, 0x95, 0x97, 0x4b, 0xed, 0x69, 0x4f, + 0x1e, 0x25, 0xf2, 0x98, 0x09, 0x3b, 0xf2, 0x15, 0x6a, 0xca, 0xaf, 0x91, 0x4d, 0x8c, 0x7d, 0xa7, + 0xdd, 0x3e, 0xd3, 0x92, 0x81, 0x51, 0x80, 0xdd, 0x05, 0xce, 0xfa, 0x76, 0x6c, 0xca, 0x29, 0x1b, + 0x1c, 0x77, 0x09, 0x35, 0x15, 0x95, 0xe6, 0x2f, 0xb1, 0xa5, 0x14, 0xbb, 0x82, 0xe9, 0x23, 0xd6, + 0x65, 0x82, 0xd9, 0x13, 0x76, 0xd1, 0x88, 0x25, 0x56, 0xcc, 0x06, 0x6a, 0xe7, 0x82, 0x26, 0xe9, + 0x39, 0x98, 0x0c, 0x79, 0x10, 0x32, 0x2e, 0x06, 0xa4, 0x29, 0x2f, 0xb1, 0xdd, 0x5d, 0xd4, 0x07, + 0x52, 0x0d, 0x23, 0xa8, 0x46, 0x82, 0x0a, 0xad, 0x53, 0xe7, 0x82, 0xcf, 0xdf, 0x47, 0xe3, 0x30, + 0xb8, 0x0b, 0xe9, 0x75, 0xf2, 0x46, 0x13, 0x63, 0x91, 0xa0, 0x5c, 0xc4, 0x85, 0x53, 0x86, 0xbc, + 0xed, 0xfa, 0x77, 0xbe, 0x88, 0xcb, 0xa0, 0x0c, 0x77, 0x15, 0x33, 0x49, 0xb4, 0xf5, 0x1b, 0x1a, + 0xa8, 0x1f, 0x76, 0x83, 0x28, 0xae, 0x84, 0x2c, 0x8d, 0xb6, 0xad, 0x21, 0x02, 0x0f, 0x2f, 0xb9, + 0x2f, 0xd8, 0x01, 0x15, 0xd7, 0xb7, 0x71, 0x62, 0xcf, 0x51, 0x0d, 0xfb, 0x42, 0x76, 0x49, 0xa5, + 0x5d, 0xdb, 0x9e, 0xdd, 0x4c, 0xdb, 0x6d, 0x33, 0xfd, 0x83, 0x9e, 0xd2, 0x90, 0x1d, 0x4c, 0x74, + 0x54, 0x6d, 0xa3, 0x56, 0x59, 0xc9, 0xe7, 0x4d, 0x79, 0xe6, 0xcf, 0x78, 0xb1, 0xd2, 0x5d, 0x07, + 0x31, 0x5f, 0xb5, 0x66, 0xd7, 0x04, 0x39, 0x65, 0x3f, 0x3e, 0x08, 0xc6, 0xa9, 0x08, 0x78, 0xfc, + 
0x59, 0x17, 0x78, 0x66, 0xdc, 0x5e, 0xfa, 0xe2, 0xf6, 0x5c, 0x56, 0xee, 0x5d, 0xaf, 0x73, 0xc6, + 0xd9, 0x8d, 0x7f, 0x3f, 0xba, 0xbe, 0xb3, 0x18, 0x0f, 0x95, 0x4c, 0x17, 0x58, 0x5b, 0xee, 0x6b, + 0xac, 0xfd, 0x85, 0xaa, 0xd3, 0x6c, 0xa0, 0xec, 0x77, 0x14, 0xb3, 0xea, 0x95, 0xfd, 0x8e, 0xbb, + 0x86, 0x47, 0x71, 0xd4, 0x29, 0xbb, 0x4f, 0xfe, 0xee, 0xb0, 0xec, 0x0b, 0x9a, 0x59, 0x99, 0xc6, + 0x3d, 0xc5, 0xd4, 0x4d, 0xd0, 0xef, 0x75, 0xe4, 0xa5, 0x9e, 0xcb, 0xf4, 0x22, 0x6e, 0xe6, 0x72, + 0xc1, 0x98, 0x54, 0xcc, 0x31, 0x59, 0x4f, 0xe9, 0xef, 0x39, 0x4f, 0x6a, 0x95, 0xcb, 0x62, 0x03, + 0x8f, 0x87, 0x74, 0xd6, 0xe2, 0xb7, 0x31, 0x9b, 0x56, 0xbe, 0xcb, 0x68, 0xd2, 0x57, 0x39, 0xe8, + 0x0b, 0xcc, 0xe5, 0x94, 0x36, 0xec, 0xf6, 0xef, 0x09, 0x4c, 0x1e, 0x51, 0x41, 0xaf, 0x68, 0xc4, + 0xc8, 0x1e, 0x2a, 0x27, 0x34, 0x22, 0x99, 0x16, 0x4b, 0xd7, 0x99, 0x33, 0x97, 0xbb, 0xd7, 0xd8, + 0x3d, 0x54, 0x8e, 0x99, 0xc8, 0xc6, 0xa5, 0x3b, 0x2b, 0x1b, 0x67, 0xee, 0xaa, 0x3d, 0x54, 0xce, + 0xfa, 0x43, 0x71, 0x69, 0x4b, 0x67, 0xe3, 0xcc, 0xb5, 0xf3, 0x16, 0xe3, 0x83, 0x56, 0x26, 0xf6, + 0xf6, 0x76, 0x9c, 0x22, 0x97, 0x06, 0xbc, 0x41, 0x55, 0xee, 0x0c, 0x92, 0x79, 0xc1, 0x58, 0x38, + 0x4e, 0x2b, 0xef, 0xd0, 0xa1, 0x07, 0x98, 0xd0, 0xa3, 0x4f, 0x32, 0x2f, 0x64, 0xb7, 0x89, 0xf3, + 0xa4, 0xd0, 0xa7, 0x19, 0xfb, 0x18, 0x53, 0x9b, 0x80, 0x64, 0x9e, 0x31, 0x97, 0x85, 0x33, 0x5f, + 0xe0, 0xd1, 0xd1, 0x1f, 0x81, 0x74, 0x5c, 0xc9, 0x82, 0x29, 0xcc, 0x2d, 0x0f, 0x67, 0xd1, 0xe6, + 0xd6, 0xb0, 0x5f, 0x25, 0x2c, 0x8c, 0x1c, 0x34, 0xf2, 0xd2, 0x24, 0xfc, 0xcb, 0xa4, 0x3b, 0x5b, + 0xff, 0x11, 0xa1, 0xd3, 0xf8, 0x84, 0xba, 0x39, 0x8e, 0x64, 0xc9, 0x44, 0x14, 0xcc, 0xb3, 0xb3, + 0x6c, 0x17, 0x68, 0xe4, 0x05, 0xa6, 0x33, 0xb3, 0x45, 0x0a, 0x43, 0xcc, 0xf1, 0x74, 0x56, 0x46, + 0x28, 0x34, 0xf5, 0x33, 0x66, 0x86, 0x86, 0x8b, 0xb8, 0x45, 0x51, 0xd9, 0x19, 0x75, 0x56, 0x47, + 0x6a, 0x06, 0xec, 0xab, 0x71, 0xe5, 0xde, 0xf9, 0x13, 0x00, 0x00, 0xff, 0xff, 0xc1, 0x4c, 0xf4, + 0x67, 0x5e, 0x08, 0x00, 0x00, } // 
Reference imports to suppress errors if they are not otherwise used. diff --git a/database/rpcdb/rpcdbproto/rpcdb.proto b/database/rpcdb/rpcdbproto/rpcdb.proto index 656b0abdaa29..0a4fd499f852 100644 --- a/database/rpcdb/rpcdbproto/rpcdb.proto +++ b/database/rpcdb/rpcdbproto/rpcdb.proto @@ -102,7 +102,9 @@ message IteratorReleaseRequest { uint64 id = 1; } -message IteratorReleaseResponse {} +message IteratorReleaseResponse { + uint32 err = 1; +} service Database { rpc Has(HasRequest) returns (HasResponse); diff --git a/database/test_database.go b/database/test_database.go index ba953481b634..daf00e44f162 100644 --- a/database/test_database.go +++ b/database/test_database.go @@ -26,6 +26,8 @@ var ( TestIteratorStartPrefix, TestIteratorMemorySafety, TestIteratorClosed, + TestIteratorError, + TestIteratorErrorAfterRelease, TestStatNoPanic, TestCompactNoPanic, TestMemorySafetyDatabase, @@ -805,6 +807,75 @@ func TestIteratorClosed(t *testing.T, db Database) { } } +// TestIteratorError tests to make sure that an iterator still works after the +// database is closed. 
+func TestIteratorError(t *testing.T, db Database) { + key := []byte("hello1") + value := []byte("world1") + + if err := db.Put(key, value); err != nil { + t.Fatalf("Unexpected error on batch.Put: %s", err) + } + + iterator := db.NewIterator() + if iterator == nil { + t.Fatalf("db.NewIterator returned nil") + } + defer iterator.Release() + + if err := db.Close(); err != nil { + t.Fatalf("Unexpected error on db.Close: %s", err) + } + + if !iterator.Next() { + t.Fatalf("iterator.Next Returned: %v ; Expected: %v", false, true) + } + if itKey := iterator.Key(); !bytes.Equal(itKey, key) { + t.Fatalf("iterator.Key Returned: 0x%x ; Expected: 0x%x", itKey, key) + } + if itValue := iterator.Value(); !bytes.Equal(itValue, value) { + t.Fatalf("iterator.Value Returned: 0x%x ; Expected: 0x%x", itValue, value) + } + if err := iterator.Error(); err != nil { + t.Fatalf("Expected no error on iterator.Error but got %s", err) + } +} + +// TestIteratorErrorAfterRelease tests to make sure that an iterator that was +// released still reports the error correctly. +func TestIteratorErrorAfterRelease(t *testing.T, db Database) { + key := []byte("hello1") + value := []byte("world1") + + if err := db.Put(key, value); err != nil { + t.Fatalf("Unexpected error on batch.Put: %s", err) + } + + if err := db.Close(); err != nil { + t.Fatalf("Unexpected error on db.Close: %s", err) + } + + iterator := db.NewIterator() + if iterator == nil { + t.Fatalf("db.NewIterator returned nil") + } + + iterator.Release() + + if iterator.Next() { + t.Fatalf("iterator.Next Returned: %v ; Expected: %v", false, true) + } + if key := iterator.Key(); key != nil { + t.Fatalf("iterator.Key Returned: 0x%x ; Expected: nil", key) + } + if value := iterator.Value(); value != nil { + t.Fatalf("iterator.Value Returned: 0x%x ; Expected: nil", value) + } + if err := iterator.Error(); err != ErrClosed { + t.Fatalf("Expected %s on iterator.Error", ErrClosed) + } +} + // TestStatNoPanic tests to make sure that Stat never panics. 
func TestStatNoPanic(t *testing.T, db Database) { key1 := []byte("hello1") diff --git a/genesis/genesis_fuji.go b/genesis/genesis_fuji.go index 4179bec14c7f..15fb40d8afaa 100644 --- a/genesis/genesis_fuji.go +++ b/genesis/genesis_fuji.go @@ -210,6 +210,5 @@ var ( StakeMintingPeriod: 365 * 24 * time.Hour, EpochFirstTransition: time.Unix(1607626800, 0), EpochDuration: 6 * time.Hour, - ApricotPhase0Time: time.Date(2020, 12, 5, 5, 00, 0, 0, time.UTC), } ) diff --git a/genesis/genesis_local.go b/genesis/genesis_local.go index 0e56da46698c..6c6a27ab6cc2 100644 --- a/genesis/genesis_local.go +++ b/genesis/genesis_local.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/avalanchego/utils/units" ) -// PrivateKey-vmRQiZeXEXYMyJhEiqdC2z5JhuDbxL8ix9UVvjgMu2Er1NepE => X-local1g65uqn6t77p656w64023nh8nd9updzmxyymev2 +// PrivateKey-vmRQiZeXEXYMyJhEiqdC2z5JhuDbxL8ix9UVvjgMu2Er1NepE => P-local1g65uqn6t77p656w64023nh8nd9updzmxyymev2 // PrivateKey-ewoqjP7PxY4yr3iLTpLisriqt94hdyDFNgchSxGGztUrTXtNN => X-local18jma8ppw3nhx5r4ap8clazz0dps7rv5u00z96u var ( @@ -104,6 +104,5 @@ var ( StakeMintingPeriod: 365 * 24 * time.Hour, EpochFirstTransition: time.Unix(1607626800, 0), EpochDuration: 5 * time.Minute, - ApricotPhase0Time: time.Date(2020, 12, 5, 5, 00, 0, 0, time.UTC), } ) diff --git a/genesis/genesis_mainnet.go b/genesis/genesis_mainnet.go index cdeaced3a96e..7cb88659a7f2 100644 --- a/genesis/genesis_mainnet.go +++ b/genesis/genesis_mainnet.go @@ -175783,6 +175783,5 @@ var ( StakeMintingPeriod: 365 * 24 * time.Hour, EpochFirstTransition: time.Unix(1607626800, 0), EpochDuration: 6 * time.Hour, - ApricotPhase0Time: time.Date(2020, 12, 8, 3, 00, 0, 0, time.UTC), } ) diff --git a/genesis/params.go b/genesis/params.go index d60d9f3f121e..d443ebcbec18 100644 --- a/genesis/params.go +++ b/genesis/params.go @@ -40,8 +40,6 @@ type Params struct { EpochFirstTransition time.Time // EpochDuration is the amount of time that an epoch runs for. 
EpochDuration time.Duration - // Time that Apricot phase 0 rules go into effect - ApricotPhase0Time time.Time } // GetParams ... diff --git a/indexer/client.go b/indexer/client.go new file mode 100644 index 000000000000..eb55483b41c9 --- /dev/null +++ b/indexer/client.go @@ -0,0 +1,50 @@ +package indexer + +import ( + "time" + + "github.com/ava-labs/avalanchego/utils/rpc" +) + +type Client struct { + rpc.EndpointRequester +} + +// NewClient creates a client that can interact with an index via HTTP API calls. +// [host] is the host to make API calls to (e.g. http://1.2.3.4:9650). +// [endpoint] is the path to the index endpoint (e.g. /ext/index/C/block or /ext/index/X/tx). +func NewClient(host, endpoint string, requestTimeout time.Duration) *Client { + return &Client{ + EndpointRequester: rpc.NewEndpointRequester(host, endpoint, "index", requestTimeout), + } +} + +func (c *Client) GetContainerRange(args *GetContainerRangeArgs) ([]FormattedContainer, error) { + var response GetContainerRangeResponse + err := c.SendRequest("getContainerRange", args, &response) + return response.Containers, err +} + +func (c *Client) GetContainerByIndex(args *GetContainer) (FormattedContainer, error) { + var response FormattedContainer + err := c.SendRequest("getContainerByIndex", args, &response) + return response, err +} + +func (c *Client) GetLastAccepted(args *GetLastAcceptedArgs) (FormattedContainer, error) { + var response FormattedContainer + err := c.SendRequest("getLastAccepted", args, &response) + return response, err +} + +func (c *Client) GetIndex(args *GetIndexArgs) (GetIndexResponse, error) { + var response GetIndexResponse + err := c.SendRequest("getIndex", args, &response) + return response, err +} + +func (c *Client) IsAccepted(args *GetIndexArgs) (bool, error) { + var response bool + err := c.SendRequest("isAccepted", args, &response) + return response, err +} diff --git a/indexer/client_test.go b/indexer/client_test.go new file mode 100644 index 
000000000000..3f1c2d1ea309 --- /dev/null +++ b/indexer/client_test.go @@ -0,0 +1,70 @@ +package indexer + +import ( + "testing" + "time" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/formatting" + "github.com/stretchr/testify/assert" +) + +type mockClient struct { + f func(reply interface{}) error +} + +func (mc *mockClient) SendRequest(_ string, _ interface{}, reply interface{}) error { + return mc.f(reply) +} + +func TestIndexClient(t *testing.T) { + assert := assert.New(t) + client := NewClient("http://localhost:9650", "/ext/index/C/block", time.Minute) + + // Test GetIndex + client.EndpointRequester = &mockClient{ + f: func(reply interface{}) error { + *(reply.(*GetIndexResponse)) = GetIndexResponse{Index: 5} + return nil + }, + } + index, err := client.GetIndex(&GetIndexArgs{ContainerID: ids.Empty, Encoding: formatting.Hex}) + assert.NoError(err) + assert.EqualValues(5, index.Index) + + // Test GetLastAccepted + id := ids.GenerateTestID() + client.EndpointRequester = &mockClient{ + f: func(reply interface{}) error { + *(reply.(*FormattedContainer)) = FormattedContainer{ID: id} + return nil + }, + } + container, err := client.GetLastAccepted(&GetLastAcceptedArgs{Encoding: formatting.Hex}) + assert.NoError(err) + assert.EqualValues(id, container.ID) + + // Test GetContainerRange + id = ids.GenerateTestID() + client.EndpointRequester = &mockClient{ + f: func(reply interface{}) error { + *(reply.(*GetContainerRangeResponse)) = GetContainerRangeResponse{Containers: []FormattedContainer{{ID: id}}} + return nil + }, + } + containers, err := client.GetContainerRange(&GetContainerRangeArgs{StartIndex: 1, NumToFetch: 10, Encoding: formatting.Hex}) + assert.NoError(err) + assert.Len(containers, 1) + assert.EqualValues(id, containers[0].ID) + + // Test IsAccepted + client.EndpointRequester = &mockClient{ + f: func(reply interface{}) error { + *(reply.(*bool)) = true + return nil + }, + } + isAccepted, err := 
client.IsAccepted(&GetIndexArgs{ContainerID: ids.Empty, Encoding: formatting.Hex}) + assert.NoError(err) + assert.True(isAccepted) +} diff --git a/indexer/container.go b/indexer/container.go new file mode 100644 index 000000000000..f6fe3a7c0a83 --- /dev/null +++ b/indexer/container.go @@ -0,0 +1,14 @@ +package indexer + +import "github.com/ava-labs/avalanchego/ids" + +// Container is something that gets accepted +// (a block, transaction or vertex) +type Container struct { + // ID of this container + ID ids.ID `serialize:"true"` + // Byte representation of this container + Bytes []byte `serialize:"true"` + // Unix time, in nanoseconds, at which this container was accepted by this node + Timestamp int64 `serialize:"true"` +} diff --git a/indexer/index.go b/indexer/index.go new file mode 100644 index 000000000000..ccac58223799 --- /dev/null +++ b/indexer/index.go @@ -0,0 +1,300 @@ +package indexer + +import ( + "errors" + "fmt" + "io" + "sync" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/prefixdb" + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/math" + "github.com/ava-labs/avalanchego/utils/timer" + "github.com/ava-labs/avalanchego/utils/wrappers" +) + +const ( + // Maximum number of containers IDs that can be fetched at a time + // in a call to GetContainerRange + MaxFetchedByRange = 1024 +) + +var ( + // Maps to the byte representation of the next accepted index + nextAcceptedIndexKey []byte = []byte{0x00} + indexToContainerPrefix []byte = []byte{0x01} + containerToIDPrefix []byte = []byte{0x02} + errNoneAccepted = errors.New("no containers have been accepted") + errNumToFetchZero = fmt.Errorf("numToFetch must be in [1,%d]", MaxFetchedByRange) + + _ Index = &index{} +) + +// Index indexes containers in 
their order of acceptance +// Index implements triggers.Acceptor +// Index is thread-safe. +// Index assumes that Accept is called before the container is committed to the +// database of the VM that the container exists in. +type Index interface { + Accept(ctx *snow.Context, containerID ids.ID, container []byte) error + GetContainerByIndex(index uint64) (Container, error) + GetContainerRange(startIndex uint64, numToFetch uint64) ([]Container, error) + GetLastAccepted() (Container, error) + GetIndex(containerID ids.ID) (uint64, error) + GetContainerByID(containerID ids.ID) (Container, error) + io.Closer +} + +// indexer indexes all accepted transactions by the order in which they were accepted +type index struct { + codec codec.Manager + clock timer.Clock + lock sync.RWMutex + // The index of the next accepted transaction + nextAcceptedIndex uint64 + // When [baseDB] is committed, writes to [baseDB] + vDB *versiondb.Database + baseDB database.Database + // Both [indexToContainer] and [containerToIndex] have [vDB] underneath + // Index --> Container + indexToContainer database.Database + // Container ID --> Index + containerToIndex database.Database + log logging.Logger +} + +// Returns a new, thread-safe Index. +// Closes [baseDB] on close. +func newIndex( + baseDB database.Database, + log logging.Logger, + codec codec.Manager, + clock timer.Clock, +) (Index, error) { + vDB := versiondb.New(baseDB) + indexToContainer := prefixdb.New(indexToContainerPrefix, vDB) + containerToIndex := prefixdb.New(containerToIDPrefix, vDB) + + i := &index{ + clock: clock, + codec: codec, + baseDB: baseDB, + vDB: vDB, + indexToContainer: indexToContainer, + containerToIndex: containerToIndex, + log: log, + } + + // Get next accepted index from db + nextAcceptedIndexBytes, err := i.vDB.Get(nextAcceptedIndexKey) + if err == database.ErrNotFound { + // Couldn't find it in the database. Must not have accepted any containers in previous runs. 
+ i.log.Info("next accepted index %d", i.nextAcceptedIndex) + return i, nil + } + if err != nil { + return nil, fmt.Errorf("couldn't get next accepted index from database: %w", err) + } + i.nextAcceptedIndex, err = wrappers.UnpackLong(nextAcceptedIndexBytes) + if err != nil { + return nil, fmt.Errorf("couldn't parse next accepted index from bytes: %w", err) + } + i.log.Info("next accepted index %d", i.nextAcceptedIndex) + return i, nil +} + +// Close this index +func (i *index) Close() error { + errs := wrappers.Errs{} + errs.Add( + i.indexToContainer.Close(), + i.containerToIndex.Close(), + i.vDB.Close(), + i.baseDB.Close(), + ) + return errs.Err +} + +// Index that the given transaction is accepted +// Returned error should be treated as fatal; the VM should not commit [containerID] +// or any new containers as accepted. +func (i *index) Accept(ctx *snow.Context, containerID ids.ID, containerBytes []byte) error { + i.lock.Lock() + defer i.lock.Unlock() + + // It may be the case that in a previous run of this node, this index committed [containerID] + // as accepted and then the node shut down before the VM committed [containerID] as accepted. + // In that case, when the node restarts Accept will be called with the same container. + // Make sure we don't index the same container twice in that event. 
+ _, err := i.containerToIndex.Get(containerID[:]) + if err == nil { + ctx.Log.Debug("not indexing already accepted container %s", containerID) + return nil + } + if err != database.ErrNotFound { + return fmt.Errorf("couldn't get whether %s is accepted: %w", containerID, err) + } + + ctx.Log.Debug("indexing %d --> container %s", i.nextAcceptedIndex, containerID) + // Persist index --> Container + nextAcceptedIndexBytes := wrappers.PackLong(i.nextAcceptedIndex) + bytes, err := i.codec.Marshal(codecVersion, Container{ + ID: containerID, + Bytes: containerBytes, + Timestamp: i.clock.Time().UnixNano(), + }) + if err != nil { + return fmt.Errorf("couldn't serialize container %s: %w", containerID, err) + } + if err := i.indexToContainer.Put(nextAcceptedIndexBytes, bytes); err != nil { + return fmt.Errorf("couldn't put accepted container %s into index: %w", containerID, err) + } + + // Persist container ID --> index + if err := i.containerToIndex.Put(containerID[:], nextAcceptedIndexBytes); err != nil { + return fmt.Errorf("couldn't map container %s to index: %w", containerID, err) + } + + // Persist next accepted index + i.nextAcceptedIndex++ + nextAcceptedIndexBytes = wrappers.PackLong(i.nextAcceptedIndex) + if err := i.vDB.Put(nextAcceptedIndexKey, nextAcceptedIndexBytes); err != nil { + return fmt.Errorf("couldn't put accepted container %s into index: %w", containerID, err) + } + + // Atomically commit [i.vDB], [i.indexToContainer], [i.containerToIndex] to [i.baseDB] + return i.vDB.Commit() +} + +// Returns the ID of the [index]th accepted container and the container itself. +// For example, if [index] == 0, returns the first accepted container. +// If [index] == 1, returns the second accepted container, etc. +// Returns an error if there is no container at the given index. 
+func (i *index) GetContainerByIndex(index uint64) (Container, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + return i.getContainerByIndex(index) +} + +// Assumes [i.lock] is held +func (i *index) getContainerByIndex(index uint64) (Container, error) { + lastAcceptedIndex, ok := i.lastAcceptedIndex() + if !ok || index > lastAcceptedIndex { + return Container{}, fmt.Errorf("no container at index %d", index) + } + indexBytes := wrappers.PackLong(index) + return i.getContainerByIndexBytes(indexBytes) +} + +// [indexBytes] is the byte representation of the index to fetch. +// Assumes [i.lock] is held +func (i *index) getContainerByIndexBytes(indexBytes []byte) (Container, error) { + containerBytes, err := i.indexToContainer.Get(indexBytes) + if err != nil { + i.log.Error("couldn't read container from database: %w", err) + return Container{}, fmt.Errorf("couldn't read from database: %w", err) + } + var container Container + if _, err = i.codec.Unmarshal(containerBytes, &container); err != nil { + return Container{}, fmt.Errorf("couldn't unmarshal container: %w", err) + } + return container, nil +} + +// GetContainerRange returns the IDs of containers at indices +// [startIndex], [startIndex+1], ..., [startIndex+numToFetch-1]. +// [startIndex] should be <= i.lastAcceptedIndex(). 
+// [numToFetch] should be in [0, MaxFetchedByRange] +func (i *index) GetContainerRange(startIndex, numToFetch uint64) ([]Container, error) { + // Check arguments for validity + if numToFetch == 0 { + return nil, errNumToFetchZero + } else if numToFetch > MaxFetchedByRange { + return nil, fmt.Errorf("requested %d but maximum page size is %d", numToFetch, MaxFetchedByRange) + } + + i.lock.RLock() + defer i.lock.RUnlock() + + lastAcceptedIndex, ok := i.lastAcceptedIndex() + if !ok { + return nil, errNoneAccepted + } else if startIndex > lastAcceptedIndex { + return nil, fmt.Errorf("start index (%d) > last accepted index (%d)", startIndex, lastAcceptedIndex) + } + + // Calculate the last index we will fetch + lastIndex := math.Min64(startIndex+numToFetch-1, lastAcceptedIndex) + // [lastIndex] is always >= [startIndex] so this is safe. + // [numToFetch] is limited to [MaxFetchedByRange] so [containers] is bounded in size. + containers := make([]Container, int(lastIndex)-int(startIndex)+1) + + n := 0 + var err error + for j := startIndex; j <= lastIndex; j++ { + containers[n], err = i.getContainerByIndex(j) + if err != nil { + return nil, fmt.Errorf("couldn't get container at index %d: %w", j, err) + } + n++ + } + return containers, nil +} + +// Returns database.ErrNotFound if the container is not indexed as accepted +func (i *index) GetIndex(containerID ids.ID) (uint64, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + indexBytes, err := i.containerToIndex.Get(containerID[:]) + if err != nil { + return 0, err + } + index, err := wrappers.UnpackLong(indexBytes) + if err != nil { + // Should never happen + i.log.Error("couldn't unpack index: %w", err) + return 0, err + } + return index, nil +} + +func (i *index) GetContainerByID(containerID ids.ID) (Container, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + // Read index from database + indexBytes, err := i.containerToIndex.Get(containerID[:]) + if err != nil { + return Container{}, err + } + return 
i.getContainerByIndexBytes(indexBytes) +} + +// GetLastAccepted returns the last accepted container. +// Returns an error if no containers have been accepted. +func (i *index) GetLastAccepted() (Container, error) { + i.lock.RLock() + defer i.lock.RUnlock() + + lastAcceptedIndex, exists := i.lastAcceptedIndex() + if !exists { + return Container{}, errNoneAccepted + } + return i.getContainerByIndex(lastAcceptedIndex) +} + +// Assumes i.lock is held +// Returns: +// 1) The index of the most recently accepted transaction, +// or 0 if no transactions have been accepted +// 2) Whether at least 1 transaction has been accepted +func (i *index) lastAcceptedIndex() (uint64, bool) { + return i.nextAcceptedIndex - 1, i.nextAcceptedIndex != 0 +} diff --git a/indexer/index_test.go b/indexer/index_test.go new file mode 100644 index 000000000000..338507274c75 --- /dev/null +++ b/indexer/index_test.go @@ -0,0 +1,173 @@ +package indexer + +import ( + "testing" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/codec/linearcodec" + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/timer" + "github.com/stretchr/testify/assert" +) + +func TestIndex(t *testing.T) { + // Setup + pageSize := uint64(64) + assert := assert.New(t) + codec := codec.NewDefaultManager() + err := codec.RegisterCodec(codecVersion, linearcodec.NewDefault()) + assert.NoError(err) + baseDB := memdb.New() + db := versiondb.New(baseDB) + ctx := snow.DefaultContextTest() + + indexIntf, err := newIndex(db, logging.NoLog{}, codec, timer.Clock{}) + assert.NoError(err) + idx := indexIntf.(*index) + + // Populate "containers" with random IDs/bytes + containers := map[ids.ID][]byte{} + for i := uint64(0); i < 2*pageSize; i++ { + 
containers[ids.GenerateTestID()] = utils.RandomBytes(32) + } + + // Accept each container and after each, make assertions + i := uint64(0) + for containerID, containerBytes := range containers { + err = idx.Accept(ctx, containerID, containerBytes) + assert.NoError(err) + + lastAcceptedIndex, ok := idx.lastAcceptedIndex() + assert.True(ok) + assert.EqualValues(i, lastAcceptedIndex) + assert.EqualValues(i+1, idx.nextAcceptedIndex) + + gotContainer, err := idx.GetContainerByID(containerID) + assert.NoError(err) + assert.Equal(containerBytes, gotContainer.Bytes) + + gotIndex, err := idx.GetIndex(containerID) + assert.NoError(err) + assert.EqualValues(i, gotIndex) + + gotContainer, err = idx.GetContainerByIndex(i) + assert.NoError(err) + assert.Equal(containerBytes, gotContainer.Bytes) + + gotContainer, err = idx.GetLastAccepted() + assert.NoError(err) + assert.Equal(containerBytes, gotContainer.Bytes) + + containers, err := idx.GetContainerRange(i, 1) + assert.NoError(err) + assert.Len(containers, 1) + assert.Equal(containerBytes, containers[0].Bytes) + + containers, err = idx.GetContainerRange(i, 2) + assert.NoError(err) + assert.Len(containers, 1) + assert.Equal(containerBytes, containers[0].Bytes) + + i++ + } + + // Create a new index with the same database and ensure contents still there + assert.NoError(db.Commit()) + assert.NoError(idx.Close()) + db = versiondb.New(baseDB) + indexIntf, err = newIndex(db, logging.NoLog{}, codec, timer.Clock{}) + assert.NoError(err) + idx = indexIntf.(*index) + + // Get all of the containers + containersList, err := idx.GetContainerRange(0, pageSize) + assert.NoError(err) + assert.Len(containersList, int(pageSize)) + containersList2, err := idx.GetContainerRange(pageSize, pageSize) + assert.NoError(err) + assert.Len(containersList2, int(pageSize)) + containersList = append(containersList, containersList2...) 
+ + // Ensure that the data is correct + lastTimestamp := int64(0) + sawContainers := ids.Set{} + for _, container := range containersList { + assert.False(sawContainers.Contains(container.ID)) // Should only see this container once + assert.Contains(containers, container.ID) + assert.EqualValues(containers[container.ID], container.Bytes) + // Timestamps should be non-decreasing + assert.True(container.Timestamp >= lastTimestamp) + lastTimestamp = container.Timestamp + sawContainers.Add(container.ID) + } +} + +func TestIndexGetContainerByRangeMaxPageSize(t *testing.T) { + // Setup + assert := assert.New(t) + codec := codec.NewDefaultManager() + err := codec.RegisterCodec(codecVersion, linearcodec.NewDefault()) + assert.NoError(err) + db := memdb.New() + ctx := snow.DefaultContextTest() + indexIntf, err := newIndex(db, logging.NoLog{}, codec, timer.Clock{}) + assert.NoError(err) + idx := indexIntf.(*index) + + // Insert [MaxFetchedByRange] + 1 containers + for i := uint64(0); i < MaxFetchedByRange+1; i++ { + err = idx.Accept(ctx, ids.GenerateTestID(), utils.RandomBytes(32)) + assert.NoError(err) + } + + // Page size too large + _, err = idx.GetContainerRange(0, MaxFetchedByRange+1) + assert.Error(err) + + // Make sure data is right + containers, err := idx.GetContainerRange(0, MaxFetchedByRange) + assert.NoError(err) + assert.Len(containers, MaxFetchedByRange) + + containers2, err := idx.GetContainerRange(1, MaxFetchedByRange) + assert.NoError(err) + assert.Len(containers2, MaxFetchedByRange) + + assert.Equal(containers[1], containers2[0]) + assert.Equal(containers[MaxFetchedByRange-1], containers2[MaxFetchedByRange-2]) + + // Should have last 2 elements + containers, err = idx.GetContainerRange(MaxFetchedByRange-1, MaxFetchedByRange) + assert.NoError(err) + assert.Len(containers, 2) + assert.EqualValues(containers[1], containers2[MaxFetchedByRange-1]) + assert.EqualValues(containers[0], containers2[MaxFetchedByRange-2]) +} + +func TestDontIndexSameContainerTwice(t 
*testing.T) { + // Setup + assert := assert.New(t) + codec := codec.NewDefaultManager() + err := codec.RegisterCodec(codecVersion, linearcodec.NewDefault()) + assert.NoError(err) + db := memdb.New() + ctx := snow.DefaultContextTest() + idx, err := newIndex(db, logging.NoLog{}, codec, timer.Clock{}) + assert.NoError(err) + + // Accept the same container twice + containerID := ids.GenerateTestID() + assert.NoError(idx.Accept(ctx, containerID, []byte{1, 2, 3})) + assert.NoError(idx.Accept(ctx, containerID, []byte{4, 5, 6})) + _, err = idx.GetContainerByIndex(1) + assert.Error(err, "should not have accepted same container twice") + gotContainer, err := idx.GetContainerByID(containerID) + assert.NoError(err) + assert.EqualValues(gotContainer.Bytes, []byte{1, 2, 3}, "should not have accepted same container twice") + +} diff --git a/indexer/indexer.go b/indexer/indexer.go new file mode 100644 index 000000000000..d389a6de18c0 --- /dev/null +++ b/indexer/indexer.go @@ -0,0 +1,385 @@ +package indexer + +import ( + "fmt" + "io" + "math" + "sync" + + "github.com/ava-labs/avalanchego/api/server" + "github.com/ava-labs/avalanchego/network" + "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/hashing" + "github.com/ava-labs/avalanchego/utils/json" + "github.com/ava-labs/avalanchego/utils/timer" + "github.com/ava-labs/avalanchego/utils/wrappers" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/codec/linearcodec" + "github.com/ava-labs/avalanchego/codec/reflectcodec" + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/prefixdb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/avalanche" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/engine/snowman" + "github.com/ava-labs/avalanchego/snow/triggers" + "github.com/ava-labs/avalanchego/utils/logging" + 
"github.com/gorilla/rpc/v2" +) + +const ( + indexNamePrefix = "index-" + codecVersion = uint16(0) + // Max size, in bytes, of something serialized by this indexer + // Assumes no containers are larger than math.MaxUint32 + // wrappers.IntLen accounts for the size of the container bytes + // wrappers.LongLen accounts for the timestamp of the container + // hashing.HashLen accounts for the container ID + // wrappers.ShortLen accounts for the codec version + codecMaxSize = int(network.DefaultMaxMessageSize) + wrappers.IntLen + wrappers.LongLen + hashing.HashLen + wrappers.ShortLen +) + +var ( + txPrefix = byte(0x01) + vtxPrefix = byte(0x02) + blockPrefix = byte(0x03) + isIncompletePrefix = byte(0x04) + previouslyIndexedPrefix = byte(0x05) + hasRunKey = []byte{0x07} +) + +var ( + _ Indexer = &indexer{} +) + +// Config for an indexer +type Config struct { + DB database.Database + Log logging.Logger + IndexingEnabled bool + AllowIncompleteIndex bool + DecisionDispatcher, ConsensusDispatcher *triggers.EventDispatcher + APIServer server.RouteAdder + ShutdownF func() +} + +// Indexer causes accepted containers for a given chain +// to be indexed by their ID and by the order in which +// they were accepted by this node. +// Indexer is threadsafe. +type Indexer interface { + RegisterChain(name string, ctx *snow.Context, engine common.Engine) + // Close will do nothing and return nil after the first call + io.Closer +} + +// NewIndexer returns a new Indexer and registers a new endpoint on the given API server. 
+func NewIndexer(config Config) (Indexer, error) { + indexer := &indexer{ + codec: codec.NewManager(codecMaxSize), + log: config.Log, + db: config.DB, + allowIncompleteIndex: config.AllowIncompleteIndex, + indexingEnabled: config.IndexingEnabled, + consensusDispatcher: config.ConsensusDispatcher, + decisionDispatcher: config.DecisionDispatcher, + txIndices: map[ids.ID]Index{}, + vtxIndices: map[ids.ID]Index{}, + blockIndices: map[ids.ID]Index{}, + routeAdder: config.APIServer, + shutdownF: config.ShutdownF, + } + if err := indexer.codec.RegisterCodec( + codecVersion, + linearcodec.New(reflectcodec.DefaultTagName, math.MaxUint32), + ); err != nil { + return nil, fmt.Errorf("couldn't register codec: %s", err) + } + hasRun, err := indexer.hasRun() + if err != nil { + return nil, err + } + indexer.hasRunBefore = hasRun + return indexer, indexer.markHasRun() +} + +// indexer implements Indexer +type indexer struct { + codec codec.Manager + clock timer.Clock + lock sync.RWMutex + log logging.Logger + db database.Database + closed bool + + // Called in a goroutine on shutdown + shutdownF func() + + // true if this is not the first run using this database + hasRunBefore bool + + // Used to add API endpoint for new indices + routeAdder server.RouteAdder + + // If true, allow running in such a way that could allow the creation + // of an index which could be missing accepted containers. 
+ allowIncompleteIndex bool + + // If false, don't create index for a chain when RegisterChain is called + indexingEnabled bool + + // Chain ID --> index of blocks of that chain (if applicable) + blockIndices map[ids.ID]Index + // Chain ID --> index of vertices of that chain (if applicable) + vtxIndices map[ids.ID]Index + // Chain ID --> index of txs of that chain (if applicable) + txIndices map[ids.ID]Index + + // Notifies of newly accepted blocks and vertices + consensusDispatcher *triggers.EventDispatcher + // Notifies of newly accepted transactions + decisionDispatcher *triggers.EventDispatcher +} + +// Assumes [engine]'s context lock is not held +func (i *indexer) RegisterChain(name string, ctx *snow.Context, engine common.Engine) { + i.lock.Lock() + defer i.lock.Unlock() + + if i.closed { + i.log.Debug("not registering chain %s because indexer is closed", name) + return + } else if ctx.SubnetID != constants.PrimaryNetworkID { + i.log.Debug("not registering chain %s because it's not in primary network", name) + return + } + + chainID := ctx.ChainID + if i.blockIndices[chainID] != nil || i.txIndices[chainID] != nil || i.vtxIndices[chainID] != nil { + i.log.Warn("chain %s is already being indexed", chainID) + return + } + + // If the index is incomplete, make sure that's OK. Otherwise, cause node to die. 
+ isIncomplete, err := i.isIncomplete(chainID) + if err != nil { + i.log.Error("couldn't get whether chain %s is incomplete: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + + // See if this chain was indexed in a previous run + previouslyIndexed, err := i.previouslyIndexed(chainID) + if err != nil { + i.log.Error("couldn't get whether chain %s was previously indexed: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + + if !i.indexingEnabled { // Indexing is disabled + if previouslyIndexed && !i.allowIncompleteIndex { + // We indexed this chain in a previous run but not in this run. + // This would create an incomplete index, which is not allowed, so exit. + i.log.Fatal("running would cause index %s would become incomplete but incomplete indices are disabled", name) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + + // Creating an incomplete index is allowed. Mark index as incomplete. + err := i.markIncomplete(chainID) + if err == nil { + return + } + i.log.Fatal("couldn't mark chain %s as incomplete: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + + if !i.allowIncompleteIndex && isIncomplete && (previouslyIndexed || i.hasRunBefore) { + i.log.Fatal("index %s is incomplete but incomplete indices are disabled. 
Shutting down", name) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + + // Mark that in this run, this chain was indexed + if err := i.markPreviouslyIndexed(chainID); err != nil { + i.log.Error("couldn't mark chain %s as indexed: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + + switch engine.(type) { + case snowman.Engine: + index, err := i.registerChainHelper(chainID, blockPrefix, name, "block", i.consensusDispatcher) + if err != nil { + i.log.Fatal("couldn't create block index for %s: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + i.blockIndices[chainID] = index + case avalanche.Engine: + vtxIndex, err := i.registerChainHelper(chainID, vtxPrefix, name, "vtx", i.consensusDispatcher) + if err != nil { + i.log.Fatal("couldn't create vertex index for %s: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + i.vtxIndices[chainID] = vtxIndex + + txIndex, err := i.registerChainHelper(chainID, txPrefix, name, "tx", i.decisionDispatcher) + if err != nil { + i.log.Fatal("couldn't create tx index for %s: %s", name, err) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + i.txIndices[chainID] = txIndex + default: + i.log.Error("got unexpected engine type %T", engine) + if err := i.close(); err != nil { + i.log.Error("error while closing indexer: %s", err) + } + return + } + +} + +func (i *indexer) registerChainHelper( + chainID ids.ID, + prefixEnd byte, + name, endpoint string, + dispatcher *triggers.EventDispatcher, +) (Index, error) { + prefix := make([]byte, hashing.HashLen+wrappers.ByteLen) + copy(prefix, chainID[:]) + prefix[hashing.HashLen] = prefixEnd + indexDB := prefixdb.New(prefix, i.db) + index, err := newIndex(indexDB, 
i.log, i.codec, i.clock) + if err != nil { + _ = indexDB.Close() + return nil, err + } + + // Register index to learn about new accepted vertices + if err := dispatcher.RegisterChain(chainID, fmt.Sprintf("%s%s", indexNamePrefix, chainID), index, true); err != nil { + _ = index.Close() + return nil, err + } + + // Create an API endpoint for this index + apiServer := rpc.NewServer() + codec := json.NewCodec() + apiServer.RegisterCodec(codec, "application/json") + apiServer.RegisterCodec(codec, "application/json;charset=UTF-8") + if err := apiServer.RegisterService(&service{Index: index}, "index"); err != nil { + _ = index.Close() + return nil, err + } + handler := &common.HTTPHandler{LockOptions: common.NoLock, Handler: apiServer} + if err := i.routeAdder.AddRoute(handler, &sync.RWMutex{}, "index/"+name, "/"+endpoint, i.log); err != nil { + _ = index.Close() + return nil, err + } + return index, nil +} + +// Close this indexer. Stops indexing all chains. +// Closes [i.db]. Assumes Close is only called after +// the node is done making decisions. +// Calling Close after it has been called does nothing. 
+func (i *indexer) Close() error { + i.lock.Lock() + defer i.lock.Unlock() + + return i.close() +} + +func (i *indexer) close() error { + if i.closed { + return nil + } + i.closed = true + + errs := &wrappers.Errs{} + for chainID, txIndex := range i.txIndices { + errs.Add( + txIndex.Close(), + i.decisionDispatcher.DeregisterChain(chainID, fmt.Sprintf("%s%s", indexNamePrefix, chainID)), + ) + } + for chainID, vtxIndex := range i.vtxIndices { + errs.Add( + vtxIndex.Close(), + i.consensusDispatcher.DeregisterChain(chainID, fmt.Sprintf("%s%s", indexNamePrefix, chainID)), + ) + } + for chainID, blockIndex := range i.blockIndices { + errs.Add( + blockIndex.Close(), + i.consensusDispatcher.DeregisterChain(chainID, fmt.Sprintf("%s%s", indexNamePrefix, chainID)), + ) + } + errs.Add(i.db.Close()) + + go i.shutdownF() + return errs.Err +} + +func (i *indexer) markIncomplete(chainID ids.ID) error { + key := make([]byte, hashing.HashLen+wrappers.ByteLen) + copy(key, chainID[:]) + key[hashing.HashLen] = isIncompletePrefix + return i.db.Put(key, nil) +} + +// Returns true if this chain is incomplete +func (i *indexer) isIncomplete(chainID ids.ID) (bool, error) { + key := make([]byte, hashing.HashLen+wrappers.ByteLen) + copy(key, chainID[:]) + key[hashing.HashLen] = isIncompletePrefix + return i.db.Has(key) +} + +func (i *indexer) markPreviouslyIndexed(chainID ids.ID) error { + key := make([]byte, hashing.HashLen+wrappers.ByteLen) + copy(key, chainID[:]) + key[hashing.HashLen] = previouslyIndexedPrefix + return i.db.Put(key, nil) +} + +// Returns true if this chain is incomplete +func (i *indexer) previouslyIndexed(chainID ids.ID) (bool, error) { + key := make([]byte, hashing.HashLen+wrappers.ByteLen) + copy(key, chainID[:]) + key[hashing.HashLen] = previouslyIndexedPrefix + return i.db.Has(key) +} + +// Mark that the node has run at least once +func (i *indexer) markHasRun() error { + return i.db.Put(hasRunKey, nil) +} + +// Returns true if the node has run before +func (i 
*indexer) hasRun() (bool, error) { + return i.db.Has(hasRunKey) +} diff --git a/indexer/indexer_test.go b/indexer/indexer_test.go new file mode 100644 index 000000000000..c19ca3c633e7 --- /dev/null +++ b/indexer/indexer_test.go @@ -0,0 +1,544 @@ +package indexer + +import ( + "io" + "sync" + "testing" + "time" + + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/snow/choices" + "github.com/ava-labs/avalanchego/snow/consensus/avalanche" + "github.com/ava-labs/avalanchego/snow/consensus/snowman" + "github.com/ava-labs/avalanchego/snow/consensus/snowstorm" + "github.com/ava-labs/avalanchego/snow/engine/avalanche/mocks" + avvtxmocks "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex/mocks" + "github.com/ava-labs/avalanchego/snow/engine/common" + smblockmocks "github.com/ava-labs/avalanchego/snow/engine/snowman/block/mocks" + smengmocks "github.com/ava-labs/avalanchego/snow/engine/snowman/mocks" + "github.com/ava-labs/avalanchego/snow/triggers" + "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/stretchr/testify/assert" +) + +type apiServerMock struct { + timesCalled int + bases []string + endpoints []string +} + +func (a *apiServerMock) AddRoute(_ *common.HTTPHandler, _ *sync.RWMutex, base, endpoint string, _ io.Writer) error { + a.timesCalled++ + a.bases = append(a.bases, base) + a.endpoints = append(a.endpoints, endpoint) + return nil +} + +// Test that newIndexer sets fields correctly +func TestNewIndexer(t *testing.T) { + assert := assert.New(t) + ed := &triggers.EventDispatcher{} + ed.Initialize(logging.NoLog{}) + config := Config{ + IndexingEnabled: true, + AllowIncompleteIndex: true, + Log: logging.NoLog{}, + DB: memdb.New(), + ConsensusDispatcher: ed, + DecisionDispatcher: ed, + APIServer: &apiServerMock{}, + ShutdownF: func() 
{}, + } + + idxrIntf, err := NewIndexer(config) + assert.NoError(err) + idxr, ok := idxrIntf.(*indexer) + assert.True(ok) + assert.NotNil(idxr.codec) + assert.NotNil(idxr.log) + assert.NotNil(idxr.db) + assert.False(idxr.closed) + assert.NotNil(idxr.routeAdder) + assert.True(idxr.indexingEnabled) + assert.True(idxr.allowIncompleteIndex) + assert.NotNil(idxr.blockIndices) + assert.Len(idxr.blockIndices, 0) + assert.NotNil(idxr.txIndices) + assert.Len(idxr.txIndices, 0) + assert.NotNil(idxr.vtxIndices) + assert.Len(idxr.vtxIndices, 0) + assert.NotNil(idxr.consensusDispatcher) + assert.NotNil(idxr.decisionDispatcher) + assert.NotNil(idxr.shutdownF) + assert.False(idxr.hasRunBefore) +} + +// Test that [hasRunBefore] is set correctly and that Shutdown is called on close +func TestMarkHasRunAndShutdown(t *testing.T) { + assert := assert.New(t) + cd := &triggers.EventDispatcher{} + cd.Initialize(logging.NoLog{}) + dd := &triggers.EventDispatcher{} + dd.Initialize(logging.NoLog{}) + baseDB := memdb.New() + db := versiondb.New(baseDB) + shutdown := &sync.WaitGroup{} + shutdown.Add(1) + config := Config{ + IndexingEnabled: true, + Log: logging.NoLog{}, + DB: db, + ConsensusDispatcher: cd, + DecisionDispatcher: dd, + APIServer: &apiServerMock{}, + ShutdownF: func() { shutdown.Done() }, + } + + idxrIntf, err := NewIndexer(config) + assert.NoError(err) + assert.False(idxrIntf.(*indexer).hasRunBefore) + assert.NoError(db.Commit()) + assert.NoError(idxrIntf.Close()) + shutdown.Wait() + shutdown.Add(1) + + config.DB = versiondb.New(baseDB) + idxrIntf, err = NewIndexer(config) + assert.NoError(err) + idxr, ok := idxrIntf.(*indexer) + assert.True(ok) + assert.True(idxr.hasRunBefore) + assert.NoError(idxr.Close()) + shutdown.Wait() +} + +// Test registering a linear chain and a DAG chain and accepting +// some vertices +func TestIndexer(t *testing.T) { + assert := assert.New(t) + cd := &triggers.EventDispatcher{} + cd.Initialize(logging.NoLog{}) + dd := &triggers.EventDispatcher{} + 
dd.Initialize(logging.NoLog{}) + baseDB := memdb.New() + db := versiondb.New(baseDB) + config := Config{ + IndexingEnabled: true, + AllowIncompleteIndex: false, + Log: logging.NoLog{}, + DB: db, + ConsensusDispatcher: cd, + DecisionDispatcher: dd, + APIServer: &apiServerMock{}, + ShutdownF: func() {}, + } + + // Create indexer + idxrIntf, err := NewIndexer(config) + assert.NoError(err) + idxr, ok := idxrIntf.(*indexer) + assert.True(ok) + now := time.Now() + idxr.clock.Set(now) + + // Assert state is right + chain1Ctx := snow.DefaultContextTest() + chain1Ctx.ChainID = ids.GenerateTestID() + isIncomplete, err := idxr.isIncomplete(chain1Ctx.ChainID) + assert.NoError(err) + assert.False(isIncomplete) + previouslyIndexed, err := idxr.previouslyIndexed(chain1Ctx.ChainID) + assert.NoError(err) + assert.False(previouslyIndexed) + + // Register this chain, creating a new index + chainVM := &smblockmocks.ChainVM{} + chainEngine := &smengmocks.Engine{} + chainEngine.On("GetVM").Return(chainVM) + + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + isIncomplete, err = idxr.isIncomplete(chain1Ctx.ChainID) + assert.NoError(err) + assert.False(isIncomplete) + previouslyIndexed, err = idxr.previouslyIndexed(chain1Ctx.ChainID) + assert.NoError(err) + assert.True(previouslyIndexed) + server := config.APIServer.(*apiServerMock) + assert.EqualValues(1, server.timesCalled) + assert.EqualValues("index/chain1", server.bases[0]) + assert.EqualValues("/block", server.endpoints[0]) + assert.Len(idxr.blockIndices, 1) + assert.Len(idxr.txIndices, 0) + assert.Len(idxr.vtxIndices, 0) + + // Accept a container + blkID, blkBytes := ids.GenerateTestID(), utils.RandomBytes(32) + expectedContainer := Container{ + ID: blkID, + Bytes: blkBytes, + Timestamp: now.UnixNano(), + } + // Mocked VM knows about this block now + chainVM.On("GetBlock", blkID).Return( + &snowman.TestBlock{ + TestDecidable: choices.TestDecidable{ + StatusV: choices.Accepted, + IDV: blkID, + }, + BytesV: blkBytes, + }, nil, + 
).Twice() + + assert.NoError(cd.Accept(chain1Ctx, blkID, blkBytes)) + + blkIdx := idxr.blockIndices[chain1Ctx.ChainID] + assert.NotNil(blkIdx) + + // Verify GetLastAccepted is right + gotLastAccepted, err := blkIdx.GetLastAccepted() + assert.NoError(err) + assert.Equal(expectedContainer, gotLastAccepted) + + // Verify GetContainerByID is right + container, err := blkIdx.GetContainerByID(blkID) + assert.NoError(err) + assert.Equal(expectedContainer, container) + + // Verify GetIndex is right + index, err := blkIdx.GetIndex(blkID) + assert.NoError(err) + assert.EqualValues(0, index) + + // Verify GetContainerByIndex is right + container, err = blkIdx.GetContainerByIndex(0) + assert.NoError(err) + assert.Equal(expectedContainer, container) + + // Verify GetContainerRange is right + containers, err := blkIdx.GetContainerRange(0, 1) + assert.NoError(err) + assert.Len(containers, 1) + assert.Equal(expectedContainer, containers[0]) + + // Close the indexer + assert.NoError(db.Commit()) + assert.NoError(idxr.Close()) + assert.True(idxr.closed) + // Calling Close again should be fine + assert.NoError(idxr.Close()) + server.timesCalled = 0 + + // Re-open the indexer + config.DB = versiondb.New(baseDB) + idxrIntf, err = NewIndexer(config) + assert.NoError(err) + idxr, ok = idxrIntf.(*indexer) + now = time.Now() + idxr.clock.Set(now) + assert.True(ok) + assert.Len(idxr.blockIndices, 0) + assert.Len(idxr.txIndices, 0) + assert.Len(idxr.vtxIndices, 0) + assert.True(idxr.hasRunBefore) + previouslyIndexed, err = idxr.previouslyIndexed(chain1Ctx.ChainID) + assert.NoError(err) + assert.True(previouslyIndexed) + hasRun, err := idxr.hasRun() + assert.NoError(err) + assert.True(hasRun) + isIncomplete, err = idxr.isIncomplete(chain1Ctx.ChainID) + assert.NoError(err) + assert.False(isIncomplete) + + // Register the same chain as before + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + blkIdx = idxr.blockIndices[chain1Ctx.ChainID] + assert.NotNil(blkIdx) + container, err = 
blkIdx.GetLastAccepted() + assert.NoError(err) + assert.Equal(blkID, container.ID) + + // Register a DAG chain + chain2Ctx := snow.DefaultContextTest() + chain2Ctx.ChainID = ids.GenerateTestID() + isIncomplete, err = idxr.isIncomplete(chain2Ctx.ChainID) + assert.NoError(err) + assert.False(isIncomplete) + previouslyIndexed, err = idxr.previouslyIndexed(chain2Ctx.ChainID) + assert.NoError(err) + assert.False(previouslyIndexed) + dagVM := &avvtxmocks.DAGVM{} + dagEngine := &mocks.Engine{} + dagEngine.On("GetVM").Return(dagVM).Once() + idxr.RegisterChain("chain2", chain2Ctx, dagEngine) + assert.NoError(err) + server = config.APIServer.(*apiServerMock) + assert.EqualValues(3, server.timesCalled) // block index, vtx index, tx index + assert.Contains(server.bases, "index/chain2") + assert.Contains(server.endpoints, "/vtx") + assert.Contains(server.endpoints, "/tx") + assert.Len(idxr.blockIndices, 1) + assert.Len(idxr.txIndices, 1) + assert.Len(idxr.vtxIndices, 1) + + // Accept a vertex + vtxID, vtxBytes := ids.GenerateTestID(), utils.RandomBytes(32) + expectedVtx := Container{ + ID: vtxID, + Bytes: blkBytes, + Timestamp: now.UnixNano(), + } + // Mocked VM knows about this block now + dagEngine.On("GetVtx", vtxID).Return( + &avalanche.TestVertex{ + TestDecidable: choices.TestDecidable{ + StatusV: choices.Accepted, + IDV: vtxID, + }, + BytesV: vtxBytes, + }, nil, + ).Once() + + assert.NoError(cd.Accept(chain2Ctx, vtxID, blkBytes)) + + vtxIdx := idxr.vtxIndices[chain2Ctx.ChainID] + assert.NotNil(vtxIdx) + + // Verify GetLastAccepted is right + gotLastAccepted, err = vtxIdx.GetLastAccepted() + assert.NoError(err) + assert.Equal(expectedVtx, gotLastAccepted) + + // Verify GetContainerByID is right + vtx, err := vtxIdx.GetContainerByID(vtxID) + assert.NoError(err) + assert.Equal(expectedVtx, vtx) + + // Verify GetIndex is right + index, err = vtxIdx.GetIndex(vtxID) + assert.NoError(err) + assert.EqualValues(0, index) + + // Verify GetContainerByIndex is right + vtx, err = 
vtxIdx.GetContainerByIndex(0) + assert.NoError(err) + assert.Equal(expectedVtx, vtx) + + // Verify GetContainerRange is right + vtxs, err := vtxIdx.GetContainerRange(0, 1) + assert.NoError(err) + assert.Len(vtxs, 1) + assert.Equal(expectedVtx, vtxs[0]) + + // Accept a tx + txID, txBytes := ids.GenerateTestID(), utils.RandomBytes(32) + expectedTx := Container{ + ID: txID, + Bytes: blkBytes, + Timestamp: now.UnixNano(), + } + // Mocked VM knows about this tx now + dagVM.On("GetTx", txID).Return( + &snowstorm.TestTx{ + TestDecidable: choices.TestDecidable{ + IDV: txID, + StatusV: choices.Accepted, + }, + BytesV: txBytes, + }, nil, + ).Once() + + assert.NoError(dd.Accept(chain2Ctx, txID, blkBytes)) + + txIdx := idxr.txIndices[chain2Ctx.ChainID] + assert.NotNil(txIdx) + + // Verify GetLastAccepted is right + gotLastAccepted, err = txIdx.GetLastAccepted() + assert.NoError(err) + assert.Equal(expectedTx, gotLastAccepted) + + // Verify GetContainerByID is right + tx, err := txIdx.GetContainerByID(txID) + assert.NoError(err) + assert.Equal(expectedTx, tx) + + // Verify GetIndex is right + index, err = txIdx.GetIndex(txID) + assert.NoError(err) + assert.EqualValues(0, index) + + // Verify GetContainerByIndex is right + tx, err = txIdx.GetContainerByIndex(0) + assert.NoError(err) + assert.Equal(expectedTx, tx) + + // Verify GetContainerRange is right + txs, err := txIdx.GetContainerRange(0, 1) + assert.NoError(err) + assert.Len(txs, 1) + assert.Equal(expectedTx, txs[0]) + + // Accepting a vertex shouldn't have caused anything to + // happen on the block/tx index. Similar for tx. 
+ lastAcceptedTx, err := txIdx.GetLastAccepted() + assert.NoError(err) + assert.EqualValues(txID, lastAcceptedTx.ID) + lastAcceptedVtx, err := vtxIdx.GetLastAccepted() + assert.NoError(err) + assert.EqualValues(vtxID, lastAcceptedVtx.ID) + lastAcceptedBlk, err := blkIdx.GetLastAccepted() + assert.NoError(err) + assert.EqualValues(blkID, lastAcceptedBlk.ID) + + // Close the indexer again + assert.NoError(config.DB.(*versiondb.Database).Commit()) + assert.NoError(idxr.Close()) + + // Re-open one more time and re-register chains + config.DB = versiondb.New(baseDB) + idxrIntf, err = NewIndexer(config) + assert.NoError(err) + idxr, ok = idxrIntf.(*indexer) + assert.True(ok) + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + idxr.RegisterChain("chain2", chain2Ctx, dagEngine) + + // Verify state + lastAcceptedTx, err = idxr.txIndices[chain2Ctx.ChainID].GetLastAccepted() + assert.NoError(err) + assert.EqualValues(txID, lastAcceptedTx.ID) + lastAcceptedVtx, err = idxr.vtxIndices[chain2Ctx.ChainID].GetLastAccepted() + assert.NoError(err) + assert.EqualValues(vtxID, lastAcceptedVtx.ID) + lastAcceptedBlk, err = idxr.blockIndices[chain1Ctx.ChainID].GetLastAccepted() + assert.NoError(err) + assert.EqualValues(blkID, lastAcceptedBlk.ID) +} + +// Make sure the indexer doesn't allow incomplete indices unless explicitly allowed +func TestIncompleteIndex(t *testing.T) { + // Create an indexer with indexing disabled + assert := assert.New(t) + cd := &triggers.EventDispatcher{} + cd.Initialize(logging.NoLog{}) + dd := &triggers.EventDispatcher{} + dd.Initialize(logging.NoLog{}) + baseDB := memdb.New() + config := Config{ + IndexingEnabled: false, + AllowIncompleteIndex: false, + Log: logging.NoLog{}, + DB: versiondb.New(baseDB), + ConsensusDispatcher: cd, + DecisionDispatcher: dd, + APIServer: &apiServerMock{}, + ShutdownF: func() {}, + } + idxrIntf, err := NewIndexer(config) + assert.NoError(err) + idxr, ok := idxrIntf.(*indexer) + assert.True(ok) + 
assert.False(idxr.indexingEnabled) + + // Register a chain + chain1Ctx := snow.DefaultContextTest() + chain1Ctx.ChainID = ids.GenerateTestID() + isIncomplete, err := idxr.isIncomplete(chain1Ctx.ChainID) + assert.NoError(err) + assert.False(isIncomplete) + previouslyIndexed, err := idxr.previouslyIndexed(chain1Ctx.ChainID) + assert.NoError(err) + assert.False(previouslyIndexed) + chainEngine := &smengmocks.Engine{} + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + isIncomplete, err = idxr.isIncomplete(chain1Ctx.ChainID) + assert.NoError(err) + assert.True(isIncomplete) + assert.Len(idxr.blockIndices, 0) + + // Close and re-open the indexer, this time with indexing enabled + assert.NoError(config.DB.(*versiondb.Database).Commit()) + assert.NoError(idxr.Close()) + config.IndexingEnabled = true + config.DB = versiondb.New(baseDB) + idxrIntf, err = NewIndexer(config) + assert.NoError(err) + idxr, ok = idxrIntf.(*indexer) + assert.True(ok) + assert.True(idxr.indexingEnabled) + + // Register the chain again. Should die due to incomplete index. + assert.NoError(config.DB.(*versiondb.Database).Commit()) + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + assert.True(idxr.closed) + + // Close and re-open the indexer, this time with indexing enabled + // and incomplete index allowed. + assert.NoError(idxr.Close()) + config.AllowIncompleteIndex = true + config.DB = versiondb.New(baseDB) + idxrIntf, err = NewIndexer(config) + assert.NoError(err) + idxr, ok = idxrIntf.(*indexer) + assert.True(ok) + assert.True(idxr.allowIncompleteIndex) + + // Register the chain again. Should be OK + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + assert.False(idxr.closed) + + // Close the indexer and re-open with indexing disabled and + // incomplete index not allowed. 
+ assert.NoError(idxr.Close()) + config.AllowIncompleteIndex = false + config.IndexingEnabled = false + config.DB = versiondb.New(baseDB) + idxrIntf, err = NewIndexer(config) + assert.NoError(err) + idxr, ok = idxrIntf.(*indexer) + assert.True(ok) +} + +// Ensure we only index chains in the primary network +func TestIgnoreNonDefaultChains(t *testing.T) { + assert := assert.New(t) + cd := &triggers.EventDispatcher{} + cd.Initialize(logging.NoLog{}) + dd := &triggers.EventDispatcher{} + dd.Initialize(logging.NoLog{}) + baseDB := memdb.New() + db := versiondb.New(baseDB) + config := Config{ + IndexingEnabled: true, + AllowIncompleteIndex: false, + Log: logging.NoLog{}, + DB: db, + ConsensusDispatcher: cd, + DecisionDispatcher: dd, + APIServer: &apiServerMock{}, + ShutdownF: func() {}, + } + + // Create indexer + idxrIntf, err := NewIndexer(config) + assert.NoError(err) + idxr, ok := idxrIntf.(*indexer) + assert.True(ok) + + // Assert state is right + chain1Ctx := snow.DefaultContextTest() + chain1Ctx.ChainID = ids.GenerateTestID() + chain1Ctx.SubnetID = ids.GenerateTestID() + + // RegisterChain should return without adding an index for this chain + chainVM := &smblockmocks.ChainVM{} + chainEngine := &smengmocks.Engine{} + chainEngine.On("GetVM").Return(chainVM) + idxr.RegisterChain("chain1", chain1Ctx, chainEngine) + assert.Len(idxr.blockIndices, 0) +} diff --git a/indexer/service.go b/indexer/service.go new file mode 100644 index 000000000000..19f1d0e90417 --- /dev/null +++ b/indexer/service.go @@ -0,0 +1,154 @@ +package indexer + +import ( + "fmt" + "net/http" + "time" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/formatting" + "github.com/ava-labs/avalanchego/utils/json" +) + +type service struct { + Index +} + +type FormattedContainer struct { + ID ids.ID `json:"id"` + Bytes string `json:"bytes"` + Timestamp time.Time `json:"timestamp"` + Encoding formatting.Encoding 
`json:"encoding"` + Index json.Uint64 `json:"index"` +} + +func newFormattedContainer(c Container, index uint64, enc formatting.Encoding) (FormattedContainer, error) { + fc := FormattedContainer{ + Encoding: enc, + ID: c.ID, + Index: json.Uint64(index), + } + bytesStr, err := formatting.Encode(enc, c.Bytes) + if err != nil { + return fc, err + } + fc.Bytes = bytesStr + fc.Timestamp = time.Unix(0, c.Timestamp) + return fc, nil +} + +type GetLastAcceptedArgs struct { + Encoding formatting.Encoding `json:"encoding"` +} + +func (s *service) GetLastAccepted(_ *http.Request, args *GetLastAcceptedArgs, reply *FormattedContainer) error { + container, err := s.Index.GetLastAccepted() + if err != nil { + return err + } + index, err := s.Index.GetIndex(container.ID) + if err != nil { + return fmt.Errorf("couldn't get index: %s", err) + } + *reply, err = newFormattedContainer(container, index, args.Encoding) + return err +} + +type GetContainer struct { + Index json.Uint64 `json:"index"` + Encoding formatting.Encoding `json:"encoding"` +} + +func (s *service) GetContainerByIndex(_ *http.Request, args *GetContainer, reply *FormattedContainer) error { + container, err := s.Index.GetContainerByIndex(uint64(args.Index)) + if err != nil { + return err + } + index, err := s.Index.GetIndex(container.ID) + if err != nil { + return fmt.Errorf("couldn't get index: %s", err) + } + *reply, err = newFormattedContainer(container, index, args.Encoding) + return err +} + +type GetContainerRangeArgs struct { + StartIndex json.Uint64 `json:"startIndex"` + NumToFetch json.Uint64 `json:"numToFetch"` + Encoding formatting.Encoding `json:"encoding"` +} + +type GetContainerRangeResponse struct { + Containers []FormattedContainer `json:"containers"` +} + +// GetContainerRange returns the transactions at index [startIndex], [startIndex+1], ... , [startIndex+n-1] +// If [n] == 0, returns an empty response (i.e. null). 
+// If [startIndex] > the last accepted index, returns an error (unless the above apply.) +// If [n] > [MaxFetchedByRange], returns an error. +// If we run out of transactions, returns the ones fetched before running out. +func (s *service) GetContainerRange(r *http.Request, args *GetContainerRangeArgs, reply *GetContainerRangeResponse) error { + containers, err := s.Index.GetContainerRange(uint64(args.StartIndex), uint64(args.NumToFetch)) + if err != nil { + return err + } + + reply.Containers = make([]FormattedContainer, len(containers)) + for i, container := range containers { + index, err := s.Index.GetIndex(container.ID) + if err != nil { + return fmt.Errorf("couldn't get index: %s", err) + } + reply.Containers[i], err = newFormattedContainer(container, index, args.Encoding) + if err != nil { + return err + } + } + return nil +} + +type GetIndexArgs struct { + ContainerID ids.ID `json:"containerID"` + Encoding formatting.Encoding `json:"encoding"` +} + +type GetIndexResponse struct { + Index json.Uint64 `json:"index"` +} + +func (s *service) GetIndex(r *http.Request, args *GetIndexArgs, reply *GetIndexResponse) error { + index, err := s.Index.GetIndex(args.ContainerID) + reply.Index = json.Uint64(index) + return err +} + +type IsAcceptedResponse struct { + IsAccepted bool `json:"isAccepted"` +} + +func (s *service) IsAccepted(r *http.Request, args *GetIndexArgs, reply *IsAcceptedResponse) error { + _, err := s.Index.GetIndex(args.ContainerID) + if err == nil { + reply.IsAccepted = true + return nil + } + if err == database.ErrNotFound { + reply.IsAccepted = false + return nil + } + return err +} + +func (s *service) GetContainerByID(r *http.Request, args *GetIndexArgs, reply *FormattedContainer) error { + container, err := s.Index.GetContainerByID(args.ContainerID) + if err != nil { + return err + } + index, err := s.Index.GetIndex(container.ID) + if err != nil { + return fmt.Errorf("couldn't get index: %s", err) + } + *reply, err = 
newFormattedContainer(container, index, args.Encoding) + return err +} diff --git a/ipcs/eventsocket.go b/ipcs/eventsocket.go index 5545cbe3067a..66dc62fe7398 100644 --- a/ipcs/eventsocket.go +++ b/ipcs/eventsocket.go @@ -119,7 +119,7 @@ func newEventIPCSocket(ctx context, chainID ids.ID, name string, events *trigger return nil, err } - if err := events.RegisterChain(chainID, ipcName, eis); err != nil { + if err := events.RegisterChain(chainID, ipcName, eis, false); err != nil { if err := eis.stop(); err != nil { return nil, err } diff --git a/main/keys.go b/main/keys.go index 8d44f6390bd7..d87243269408 100644 --- a/main/keys.go +++ b/main/keys.go @@ -58,7 +58,7 @@ const ( networkHealthMaxTimeSinceMsgSentKey = "network-health-max-time-since-msg-sent" networkHealthMaxPortionSendQueueFillKey = "network-health-max-portion-send-queue-full" networkHealthMaxSendFailRateKey = "network-health-max-send-fail-rate" - networkHealthMaxTimeSinceNoReqsKey = "network-health-max-time-since-no-requests" + networkHealthMaxOutstandingDurationKey = "network-health-max-outstanding-request-duration" sendQueueSizeKey = "send-queue-size" benchlistFailThresholdKey = "benchlist-fail-threshold" benchlistPeerSummaryEnabledKey = "benchlist-peer-summary-enabled" @@ -99,6 +99,8 @@ const ( disconnectedCheckFreqKey = "disconnected-check-frequency" disconnectedRestartTimeoutKey = "disconnected-restart-timeout" restartOnDisconnectedKey = "restart-on-disconnected" + indexEnabledKey = "index-enabled" + indexAllowIncompleteKey = "index-allow-incomplete" routerHealthMaxDropRateKey = "router-health-max-drop-rate" routerHealthMaxOutstandingRequestsKey = "router-health-max-outstanding-requests" healthCheckFreqKey = "health-check-frequency" diff --git a/main/params.go b/main/params.go index 8fed27e2bc9d..472f05458fc7 100644 --- a/main/params.go +++ b/main/params.go @@ -164,7 +164,7 @@ func avalancheFlagSet() *flag.FlagSet { fs.String(httpsCertFileKey, "", "TLS certificate file for the HTTPs server") 
fs.String(httpAllowedOrigins, "*", "Origins to allow on the HTTP port. Defaults to * which allows all origins. Example: https://*.avax.network https://*.avax-test.network") fs.Bool(apiAuthRequiredKey, false, "Require authorization token to call HTTP APIs") - fs.String(apiAuthPasswordFileKey, "", "Password file used to initially create/validate API authorization tokens. Can be changed via API call.") + fs.String(apiAuthPasswordFileKey, "", "Password file used to initially create/validate API authorization tokens. Leading and trailing whitespace is removed from the password. Can be changed via API call.") // Enable/Disable APIs fs.Bool(adminAPIEnabledKey, false, "If true, this node exposes the Admin API") fs.Bool(infoAPIEnabledKey, true, "If true, this node exposes the Info API") @@ -187,7 +187,7 @@ func avalancheFlagSet() *flag.FlagSet { // Router Health fs.Float64(routerHealthMaxDropRateKey, 1, "Node reports unhealthy if the router drops more than this portion of messages.") fs.Uint(routerHealthMaxOutstandingRequestsKey, 1024, "Node reports unhealthy if there are more than this many outstanding consensus requests (Get, PullQuery, etc.) over all chains") - fs.Duration(networkHealthMaxTimeSinceNoReqsKey, 5*time.Minute, "Node reports unhealthy if there is at least 1 outstanding request continuously for this duration") + fs.Duration(networkHealthMaxOutstandingDurationKey, 5*time.Minute, "Node reports unhealthy if there has been a request outstanding for this duration") // Staking fs.Uint(stakingPortKey, 9651, "Port of the consensus server") @@ -237,6 +237,11 @@ func avalancheFlagSet() *flag.FlagSet { fs.String(ipcsChainIDsKey, "", "Comma separated list of chain ids to add to the IPC engine. 
Example: 11111111111111111111111111111111LpoYY,4R5p2RXDGLqaifZE4hHWH9owe34pfoBULn1DrQTWivjg8o4aH") fs.String(ipcsPathKey, defaultString, "The directory (Unix) or named pipe name prefix (Windows) for IPC sockets") + // Indexer + // TODO handle the below line better + fs.Bool(indexEnabledKey, false, "If true, index all accepted containers and transactions and expose them via an API") + fs.Bool(indexAllowIncompleteKey, false, "If true, allow running the node in such a way that could cause an index to miss transactions. Ignored if index is disabled.") + return fs } @@ -513,7 +518,7 @@ func setNodeConfig(v *viper.Viper) error { if err != nil { return fmt.Errorf("api-auth-password-file %q failed to be read with: %w", passwordFile, err) } - Config.APIAuthPassword = string(pwBytes) + Config.APIAuthPassword = strings.TrimSpace(string(pwBytes)) if !password.SufficientlyStrong(Config.APIAuthPassword, password.OK) { return errors.New("api-auth-password is not strong enough") } @@ -526,6 +531,7 @@ func setNodeConfig(v *viper.Viper) error { Config.MetricsAPIEnabled = v.GetBool(metricsAPIEnabledKey) Config.HealthAPIEnabled = v.GetBool(healthAPIEnabledKey) Config.IPCAPIEnabled = v.GetBool(ipcAPIEnabledKey) + Config.IndexAPIEnabled = v.GetBool(indexEnabledKey) // Throughput: Config.ThroughputServerEnabled = v.GetBool(xputServerEnabledKey) @@ -541,14 +547,14 @@ func setNodeConfig(v *viper.Viper) error { Config.ConsensusRouter = &router.ChainRouter{} Config.RouterHealthConfig.MaxDropRate = v.GetFloat64(routerHealthMaxDropRateKey) Config.RouterHealthConfig.MaxOutstandingRequests = int(v.GetUint(routerHealthMaxOutstandingRequestsKey)) - Config.RouterHealthConfig.MaxTimeSinceNoOutstandingRequests = v.GetDuration(networkHealthMaxTimeSinceNoReqsKey) + Config.RouterHealthConfig.MaxOutstandingDuration = v.GetDuration(networkHealthMaxOutstandingDurationKey) Config.RouterHealthConfig.MaxRunTimeRequests = v.GetDuration(networkMaximumTimeoutKey) Config.RouterHealthConfig.MaxDropRateHalflife = 
healthCheckAveragerHalflife switch { case Config.RouterHealthConfig.MaxDropRate < 0 || Config.RouterHealthConfig.MaxDropRate > 1: return fmt.Errorf("%s must be in [0,1]", routerHealthMaxDropRateKey) - case Config.RouterHealthConfig.MaxTimeSinceNoOutstandingRequests <= 0: - return fmt.Errorf("%s must be positive", networkHealthMaxTimeSinceNoReqsKey) + case Config.RouterHealthConfig.MaxOutstandingDuration <= 0: + return fmt.Errorf("%s must be positive", networkHealthMaxOutstandingDurationKey) } // IPCs @@ -714,6 +720,9 @@ func setNodeConfig(v *viper.Viper) error { } Config.CorethConfig = corethConfigString + // Indexer + Config.IndexAllowIncomplete = v.GetBool(indexAllowIncompleteKey) + // Bootstrap Configs Config.RetryBootstrap = v.GetBool(retryBootstrap) Config.RetryBootstrapMaxAttempts = v.GetInt(retryBootstrapMaxAttempts) diff --git a/network/network.go b/network/network.go index bfc3ff449bad..62c029ef4e59 100644 --- a/network/network.go +++ b/network/network.go @@ -59,9 +59,6 @@ var ( errNetworkLayerUnhealthy = errors.New("network layer is unhealthy") ) -// Network Upgrade -var minimumUnmaskedVersion = version.NewDefaultVersion(constants.PlatformName, 1, 1, 0) - func init() { rand.Seed(time.Now().UnixNano()) } // Network defines the functionality of the networking library. 
@@ -118,7 +115,7 @@ type network struct { id ids.ShortID ip utils.DynamicIPDesc networkID uint32 - msgVersion version.Version + versionCompatibility version.Compatibility parser version.Parser listener net.Listener dialer Dialer @@ -149,7 +146,6 @@ type network struct { connMeterMaxConns int connMeter ConnMeter b Builder - apricotPhase0Time time.Time // stateLock should never be held when grabbing a peer senderLock stateLock sync.RWMutex @@ -215,7 +211,7 @@ func NewDefaultNetwork( id ids.ShortID, ip utils.DynamicIPDesc, networkID uint32, - version version.Version, + versionCompatibility version.Compatibility, parser version.Parser, listener net.Listener, dialer Dialer, @@ -230,7 +226,6 @@ func NewDefaultNetwork( restartOnDisconnected bool, disconnectedCheckFreq time.Duration, disconnectedRestartTimeout time.Duration, - apricotPhase0Time time.Time, sendQueueSize uint32, healthConfig HealthConfig, benchlistManager benchlist.Manager, @@ -242,7 +237,7 @@ func NewDefaultNetwork( id, ip, networkID, - version, + versionCompatibility, parser, listener, dialer, @@ -275,7 +270,6 @@ func NewDefaultNetwork( restartOnDisconnected, disconnectedCheckFreq, disconnectedRestartTimeout, - apricotPhase0Time, healthConfig, benchlistManager, peerAliasTimeout, @@ -289,7 +283,7 @@ func NewNetwork( id ids.ShortID, ip utils.DynamicIPDesc, networkID uint32, - version version.Version, + versionCompatibility version.Compatibility, parser version.Parser, listener net.Listener, dialer Dialer, @@ -322,26 +316,25 @@ func NewNetwork( restartOnDisconnected bool, disconnectedCheckFreq time.Duration, disconnectedRestartTimeout time.Duration, - apricotPhase0Time time.Time, healthConfig HealthConfig, benchlistManager benchlist.Manager, peerAliasTimeout time.Duration, ) Network { // #nosec G404 netw := &network{ - log: log, - id: id, - ip: ip, - networkID: networkID, - msgVersion: version, - parser: parser, - listener: listener, - dialer: dialer, - serverUpgrader: serverUpgrader, - clientUpgrader: 
clientUpgrader, - vdrs: vdrs, - beacons: beacons, - router: router, + log: log, + id: id, + ip: ip, + networkID: networkID, + versionCompatibility: versionCompatibility, + parser: parser, + listener: listener, + dialer: dialer, + serverUpgrader: serverUpgrader, + clientUpgrader: clientUpgrader, + vdrs: vdrs, + beacons: beacons, + router: router, // This field just makes sure we don't connect to ourselves when TLS is // disabled. So, cryptographically secure random number generation isn't // used here. @@ -377,7 +370,6 @@ func NewNetwork( disconnectedCheckFreq: disconnectedCheckFreq, connectedMeter: timer.TimedMeter{Duration: disconnectedRestartTimeout}, restarter: restarter, - apricotPhase0Time: apricotPhase0Time, healthConfig: healthConfig, benchlistManager: benchlistManager, } @@ -406,7 +398,7 @@ func (n *network) GetAcceptedFrontier(validatorIDs ids.ShortSet, chainID ids.ID, for _, peerElement := range n.getPeers(validatorIDs) { peer := peerElement.peer vID := peerElement.id - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send GetAcceptedFrontier(%s, %s, %d)", vID, chainID, @@ -440,7 +432,7 @@ func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requ } peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send AcceptedFrontier(%s, %s, %d, %s)", validatorID, chainID, @@ -475,7 +467,7 @@ func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, request for _, peerElement := range n.getPeers(validatorIDs) { peer := peerElement.peer vID := peerElement.id - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { 
n.log.Debug("failed to send GetAccepted(%s, %s, %d, %s)", vID, chainID, @@ -510,7 +502,7 @@ func (n *network) Accepted(validatorID ids.ShortID, chainID ids.ID, requestID ui } peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send Accepted(%s, %s, %d, %s)", validatorID, chainID, @@ -538,7 +530,7 @@ func (n *network) GetAncestors(validatorID ids.ShortID, chainID ids.ID, requestI } peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send GetAncestors(%s, %s, %d, %s)", validatorID, chainID, @@ -567,7 +559,7 @@ func (n *network) MultiPut(validatorID ids.ShortID, chainID ids.ID, requestID ui } peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send MultiPut(%s, %s, %d, %d)", validatorID, chainID, @@ -591,7 +583,7 @@ func (n *network) Get(validatorID ids.ShortID, chainID ids.ID, requestID uint32, n.log.AssertNoError(err) peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send Get(%s, %s, %d, %s)", validatorID, chainID, @@ -625,7 +617,7 @@ func (n *network) Put(validatorID ids.ShortID, chainID ids.ID, requestID uint32, } peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send Put(%s, %s, %d, %s)", validatorID, chainID, @@ -663,7 +655,7 @@ 
func (n *network) PushQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID for _, peerElement := range n.getPeers(validatorIDs) { peer := peerElement.peer vID := peerElement.id - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send PushQuery(%s, %s, %d, %s)", vID, chainID, @@ -694,7 +686,7 @@ func (n *network) PullQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID for _, peerElement := range n.getPeers(validatorIDs) { peer := peerElement.peer vID := peerElement.id - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send PullQuery(%s, %s, %d, %s)", vID, chainID, @@ -729,7 +721,7 @@ func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint3 } peer := n.getPeer(validatorID) - if peer == nil || !peer.connected.GetValue() || !peer.Send(msg) { + if peer == nil || !peer.connected.GetValue() || !peer.compatible.GetValue() || !peer.Send(msg) { n.log.Debug("failed to send Chits(%s, %s, %d, %s)", validatorID, chainID, @@ -799,7 +791,7 @@ func (n *network) upgradeIncoming(remoteAddr string) (bool, error) { func (n *network) Dispatch() error { go n.gossip() // Periodically gossip peers go func() { - duration := time.Until(n.apricotPhase0Time) + duration := time.Until(n.versionCompatibility.MaskTime()) time.Sleep(duration) n.stateLock.Lock() @@ -894,6 +886,7 @@ func (n *network) Peers(nodeIDs []ids.ShortID) []PeerID { PublicIP: peer.getIP().String(), ID: peer.id.PrefixedString(constants.NodeIDPrefix), Version: peer.versionStr.GetValue().(string), + Up: peer.compatible.GetValue(), LastSent: time.Unix(atomic.LoadInt64(&peer.lastSent), 0), LastReceived: time.Unix(atomic.LoadInt64(&peer.lastReceived), 0), Benched: n.benchlistManager.GetBenched(peer.id), @@ -910,6 +903,7 @@ 
func (n *network) Peers(nodeIDs []ids.ShortID) []PeerID { PublicIP: peer.getIP().String(), ID: peer.id.PrefixedString(constants.NodeIDPrefix), Version: peer.versionStr.GetValue().(string), + Up: peer.compatible.GetValue(), LastSent: time.Unix(atomic.LoadInt64(&peer.lastSent), 0), LastReceived: time.Unix(atomic.LoadInt64(&peer.lastReceived), 0), Benched: n.benchlistManager.GetBenched(peer.id), @@ -1057,7 +1051,7 @@ func (n *network) gossip() { !ip.IsZero() && n.vdrs.Contains(peer.id) { peerVersion := peer.versionStruct.GetValue().(version.Version) - if !peerVersion.Before(minimumUnmaskedVersion) || time.Since(n.apricotPhase0Time) < 0 { + if n.versionCompatibility.Unmaskable(peerVersion) == nil { ips = append(ips, ip) } } @@ -1300,7 +1294,7 @@ func (n *network) validatorIPs() []utils.IPDesc { ip := peer.getIP() if peer.connected.GetValue() && !ip.IsZero() && n.vdrs.Contains(peer.id) { peerVersion := peer.versionStruct.GetValue().(version.Version) - if !peerVersion.Before(minimumUnmaskedVersion) || time.Since(n.apricotPhase0Time) < 0 { + if n.versionCompatibility.Unmaskable(peerVersion) == nil { ips = append(ips, ip) } } @@ -1320,7 +1314,7 @@ func (n *network) connected(p *peer) { peerVersion := p.versionStruct.GetValue().(version.Version) if n.hasMasked { - if peerVersion.Before(minimumUnmaskedVersion) { + if n.versionCompatibility.Unmaskable(peerVersion) != nil { if err := n.vdrs.MaskValidator(p.id); err != nil { n.log.Error("failed to mask validator %s due to %s", p.id, err) } @@ -1331,7 +1325,7 @@ func (n *network) connected(p *peer) { } n.log.Verbo("The new staking set is:\n%s", n.vdrs) } else { - if peerVersion.Before(minimumUnmaskedVersion) { + if n.versionCompatibility.WontMask(peerVersion) != nil { n.maskedValidators.Add(p.id) } else { n.maskedValidators.Remove(p.id) @@ -1349,7 +1343,12 @@ func (n *network) connected(p *peer) { n.connectedIPs[str] = struct{}{} } - n.router.Connected(p.id) + compatible := n.versionCompatibility.Compatible(peerVersion) == nil + 
p.compatible.SetValue(compatible) + + if compatible { + n.router.Connected(p.id) + } } // should only be called after the peer is marked as connected. @@ -1376,7 +1375,8 @@ func (n *network) disconnected(p *peer) { n.track(ip) } - if p.connected.GetValue() { + if p.compatible.GetValue() { + p.compatible.SetValue(false) n.router.Disconnected(p.id) } } diff --git a/network/network_test.go b/network/network_test.go index 74d908329dcc..3615834f4c97 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -254,13 +254,23 @@ func TestNewDefaultNetwork(t *testing.T) { vdrs := validators.NewSet() handler := &testHandler{} + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net := NewDefaultNetwork( prometheus.NewRegistry(), log, id, ip, networkID, - appVersion, + versionManager, versionParser, listener, caller, @@ -275,7 +285,6 @@ func TestNewDefaultNetwork(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -371,13 +380,23 @@ func TestEstablishConnection(t *testing.T) { }, } + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, caller0, @@ -392,7 +411,6 @@ func TestEstablishConnection(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -406,7 +424,7 @@ func TestEstablishConnection(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -421,7 +439,6 @@ func TestEstablishConnection(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -529,13 +546,23 @@ func 
TestDoubleTrack(t *testing.T) { }, } + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, caller0, @@ -550,7 +577,6 @@ func TestDoubleTrack(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -564,7 +590,7 @@ func TestDoubleTrack(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -579,7 +605,6 @@ func TestDoubleTrack(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -688,13 +713,23 @@ func TestDoubleClose(t *testing.T) { }, } + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, caller0, @@ -709,7 +744,6 @@ func TestDoubleClose(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -723,7 +757,7 @@ func TestDoubleClose(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -738,7 +772,6 @@ func TestDoubleClose(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -852,13 +885,23 @@ func TestTrackConnected(t *testing.T) { }, } + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, 
caller0, @@ -873,7 +916,6 @@ func TestTrackConnected(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -887,7 +929,7 @@ func TestTrackConnected(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -902,7 +944,6 @@ func TestTrackConnected(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -990,13 +1031,23 @@ func TestTrackConnectedRace(t *testing.T) { vdrs := validators.NewSet() handler := &testHandler{} + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, caller0, @@ -1011,7 +1062,6 @@ func TestTrackConnectedRace(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1025,7 +1075,7 @@ func TestTrackConnectedRace(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -1040,7 +1090,6 @@ func TestTrackConnectedRace(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1256,13 +1305,23 @@ func TestPeerAliasesTicker(t *testing.T) { assert.Fail(t, "caller 0 unauthorized close", local.String(), remote.String()) } + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, caller0, @@ -1277,7 +1336,6 @@ func TestPeerAliasesTicker(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, 
benchlist.NewManager(&benchlist.Config{}), @@ -1291,7 +1349,7 @@ func TestPeerAliasesTicker(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -1306,7 +1364,6 @@ func TestPeerAliasesTicker(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1320,7 +1377,7 @@ func TestPeerAliasesTicker(t *testing.T) { id1, ip2, networkID, - appVersion, + versionManager, versionParser, listener2, caller2, @@ -1335,7 +1392,6 @@ func TestPeerAliasesTicker(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1349,7 +1405,7 @@ func TestPeerAliasesTicker(t *testing.T) { id2, ip2, networkID, - appVersion, + versionManager, versionParser, listener3, caller3, @@ -1364,7 +1420,6 @@ func TestPeerAliasesTicker(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1662,13 +1717,23 @@ func TestPeerAliasesDisconnect(t *testing.T) { assert.Fail(t, "caller 0 unauthorized close", local.String(), remote.String()) } + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + net0 := NewDefaultNetwork( prometheus.NewRegistry(), log, id0, ip0, networkID, - appVersion, + versionManager, versionParser, listener0, caller0, @@ -1683,7 +1748,6 @@ func TestPeerAliasesDisconnect(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1697,7 +1761,7 @@ func TestPeerAliasesDisconnect(t *testing.T) { id1, ip1, networkID, - appVersion, + versionManager, versionParser, listener1, caller1, @@ -1712,7 +1776,6 @@ func TestPeerAliasesDisconnect(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1726,7 +1789,7 
@@ func TestPeerAliasesDisconnect(t *testing.T) { id1, ip2, networkID, - appVersion, + versionManager, versionParser, listener2, caller2, @@ -1741,7 +1804,6 @@ func TestPeerAliasesDisconnect(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), @@ -1755,7 +1817,7 @@ func TestPeerAliasesDisconnect(t *testing.T) { id2, ip2, networkID, - appVersion, + versionManager, versionParser, listener3, caller3, @@ -1770,7 +1832,6 @@ func TestPeerAliasesDisconnect(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), diff --git a/network/peer.go b/network/peer.go index e2f3c82b8de6..3c290cde85d4 100644 --- a/network/peer.go +++ b/network/peer.go @@ -44,6 +44,10 @@ type peer struct { // has been returned. is only modified on the connection's reader routine. connected utils.AtomicBool + // if the peer is connected and the peer is able to follow the protocol. is + // only modified on the connection's reader routine. 
+ compatible utils.AtomicBool + // only close the peer once once sync.Once @@ -389,8 +393,18 @@ func (p *peer) handle(msg Msg) { } peerVersion := p.versionStruct.GetValue().(version.Version) - if peerVersion.Before(minimumUnmaskedVersion) && time.Until(p.net.apricotPhase0Time) < 0 { + if p.net.versionCompatibility.Compatible(peerVersion) != nil { p.net.log.Verbo("dropping message from un-upgraded validator %s", p.id) + + if p.compatible.GetValue() { + p.net.stateLock.Lock() + defer p.net.stateLock.Unlock() + + if p.compatible.GetValue() { + p.net.router.Disconnected(p.id) + p.compatible.SetValue(false) + } + } return } @@ -482,16 +496,16 @@ func (p *peer) Version() { p.net.nodeID, p.net.clock.Unix(), p.net.ip.IP(), - p.net.msgVersion.String(), + p.net.versionCompatibility.Version().String(), ) p.net.stateLock.RUnlock() p.net.log.AssertNoError(err) if p.Send(msg) { - p.net.version.numSent.Inc() - p.net.version.sentBytes.Add(float64(len(msg.Bytes()))) + p.net.metrics.version.numSent.Inc() + p.net.metrics.version.sentBytes.Add(float64(len(msg.Bytes()))) p.net.sendFailRateCalculator.Observe(0, p.net.clock.Time()) } else { - p.net.version.numFailed.Inc() + p.net.metrics.version.numFailed.Inc() p.net.sendFailRateCalculator.Observe(1, p.net.clock.Time()) } } @@ -616,7 +630,7 @@ func (p *peer) version(msg Msg) { return } - if p.net.msgVersion.Before(peerVersion) { + if p.net.versionCompatibility.Version().Before(peerVersion) { if p.net.beacons.Contains(p.id) { p.net.log.Info("beacon %s attempting to connect with newer version %s. 
You may want to update your client", p.id, @@ -628,7 +642,7 @@ func (p *peer) version(msg Msg) { } } - if err := p.net.msgVersion.Compatible(peerVersion); err != nil { + if err := p.net.versionCompatibility.Connectable(peerVersion); err != nil { p.net.log.Debug("peer version not compatible due to %s", err) if !p.net.beacons.Contains(p.id) { diff --git a/network/peer_id.go b/network/peer_id.go index 1d2174b896d3..96afa041fa88 100644 --- a/network/peer_id.go +++ b/network/peer_id.go @@ -15,6 +15,7 @@ type PeerID struct { PublicIP string `json:"publicIP"` ID string `json:"nodeID"` Version string `json:"version"` + Up bool `json:"up"` LastSent time.Time `json:"lastSent"` LastReceived time.Time `json:"lastReceived"` Benched []ids.ID `json:"benched"` diff --git a/network/peer_test.go b/network/peer_test.go index 04892d846bb3..741a0e6a20bd 100644 --- a/network/peer_test.go +++ b/network/peer_test.go @@ -67,13 +67,23 @@ func TestPeer_Close(t *testing.T) { vdrs := validators.NewSet() handler := &testHandler{} + versionManager := version.NewCompatibility( + appVersion, + appVersion, + time.Now(), + appVersion, + appVersion, + time.Now(), + appVersion, + ) + netwrk := NewDefaultNetwork( prometheus.NewRegistry(), log, id, ip, networkID, - appVersion, + versionManager, versionParser, listener, caller, @@ -88,7 +98,6 @@ func TestPeer_Close(t *testing.T) { false, 0, 0, - time.Now(), defaultSendQueueSize, HealthConfig{}, benchlist.NewManager(&benchlist.Config{}), diff --git a/node/config.go b/node/config.go index a03c6cfd778c..35a33d6d5da0 100644 --- a/node/config.go +++ b/node/config.go @@ -93,6 +93,7 @@ type Config struct { KeystoreAPIEnabled bool MetricsAPIEnabled bool HealthAPIEnabled bool + IndexAPIEnabled bool // Logging configuration LoggingConfig logging.Config @@ -138,6 +139,8 @@ type Config struct { // Coreth CorethConfig string + IndexAllowIncomplete bool + // Should Bootstrap be retried RetryBootstrap bool diff --git a/node/node.go b/node/node.go index 
1f0a7490d307..4233115b11df 100644 --- a/node/node.go +++ b/node/node.go @@ -17,12 +17,13 @@ import ( "syscall" "time" - "github.com/ava-labs/avalanchego/api" "github.com/ava-labs/avalanchego/api/admin" + "github.com/ava-labs/avalanchego/api/auth" "github.com/ava-labs/avalanchego/api/health" "github.com/ava-labs/avalanchego/api/info" "github.com/ava-labs/avalanchego/api/keystore" "github.com/ava-labs/avalanchego/api/metrics" + "github.com/ava-labs/avalanchego/api/server" "github.com/ava-labs/avalanchego/chains" "github.com/ava-labs/avalanchego/chains/atomic" "github.com/ava-labs/avalanchego/database" @@ -30,8 +31,10 @@ import ( "github.com/ava-labs/avalanchego/database/prefixdb" "github.com/ava-labs/avalanchego/genesis" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/indexer" "github.com/ava-labs/avalanchego/ipcs" "github.com/ava-labs/avalanchego/network" + "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/networking/benchlist" "github.com/ava-labs/avalanchego/snow/networking/router" "github.com/ava-labs/avalanchego/snow/networking/timeout" @@ -64,15 +67,10 @@ const ( ) var ( + genesisHashKey = []byte("genesisID") + indexerDBPrefix = []byte{0x00} errPrimarySubnetNotBootstrapped = errors.New("primary subnet has not finished bootstrapping") -) - -var ( - genesisHashKey = []byte("genesisID") - // Version is the version of this code - Version = version.NewDefaultVersion(constants.PlatformName, 1, 3, 0) - versionParser = version.NewDefaultParser() beaconConnectionTimeout = 1 * time.Minute ) @@ -89,8 +87,11 @@ type Node struct { // Storage for this node DB database.Database + // Indexes blocks, transactions and blocks + indexer indexer.Indexer + // Handles calls to Keystore API - keystoreServer keystore.Keystore + keystore keystore.Keystore // Manages shared memory sharedMemory atomic.Memory @@ -123,7 +124,7 @@ type Node struct { vdrs validators.Manager // Handles HTTP API calls - APIServer api.Server + 
APIServer server.Server // This node's configuration Config *Config @@ -234,14 +235,24 @@ func (n *Node) initNetworking() error { } } + versionManager := version.NewCompatibility( + Version, + MinimumCompatibleVersion, + GetApricotPhase1Time(n.Config.NetworkID), + PrevMinimumCompatibleVersion, + MinimumUnmaskedVersion, + GetApricotPhase0Time(n.Config.NetworkID), + PrevMinimumUnmaskedVersion, + ) + n.Net = network.NewDefaultNetwork( n.Config.ConsensusParams.Metrics, n.Log, n.ID, n.Config.StakingIP, n.Config.NetworkID, - Version, - versionParser, + versionManager, + VersionParser, listener, dialer, serverUpgrader, @@ -255,7 +266,6 @@ func (n *Node) initNetworking() error { n.Config.RestartOnDisconnected, n.Config.DisconnectedCheckFreq, n.Config.DisconnectedRestartTimeout, - n.Config.ApricotPhase0Time, n.Config.SendQueueSize, n.Config.NetworkHealthConfig, n.benchlistManager, @@ -475,10 +485,35 @@ func (n *Node) initIPCs() error { return err } +// Initialize [n.indexer]. +// Should only be called after [n.DB], [n.DecisionDispatcher], [n.ConsensusDispatcher], +// [n.Log], [n.APIServer], [n.chainManager] are initialized +func (n *Node) initIndexer() error { + txIndexerDB := prefixdb.New(indexerDBPrefix, n.DB) + var err error + n.indexer, err = indexer.NewIndexer(indexer.Config{ + IndexingEnabled: n.Config.IndexAPIEnabled, + AllowIncompleteIndex: n.Config.IndexAllowIncomplete, + DB: txIndexerDB, + Log: n.Log, + DecisionDispatcher: n.DecisionDispatcher, + ConsensusDispatcher: n.ConsensusDispatcher, + APIServer: &n.APIServer, + ShutdownF: n.Shutdown, + }) + if err != nil { + return fmt.Errorf("couldn't create index for txs: %w", err) + } + + // Chain manager will notify indexer when a chain is created + n.chainManager.AddRegistrant(n.indexer) + + return nil +} + // Initializes the Platform chain. -// Its genesis data specifies the other chains that should -// be created. 
-func (n *Node) initChains(genesisBytes []byte, avaxAssetID ids.ID) error { +// Its genesis data specifies the other chains that should be created. +func (n *Node) initChains(genesisBytes []byte) { n.Log.Info("initializing chains") // Create the Platform Chain @@ -489,23 +524,48 @@ func (n *Node) initChains(genesisBytes []byte, avaxAssetID ids.ID) error { VMAlias: platformvm.ID.String(), CustomBeacons: n.beacons, }) - - return nil } // initAPIServer initializes the server that handles HTTP calls func (n *Node) initAPIServer() error { n.Log.Info("initializing API server") - return n.APIServer.Initialize( + if !n.Config.APIRequireAuthToken { + n.APIServer.Initialize( + n.Log, + n.LogFactory, + n.Config.HTTPHost, + n.Config.HTTPPort, + n.Config.APIAllowedOrigins, + ) + return nil + } + + a, err := auth.New(n.Log, "auth", n.Config.APIAuthPassword) + if err != nil { + return err + } + + n.APIServer.Initialize( n.Log, n.LogFactory, n.Config.HTTPHost, n.Config.HTTPPort, - n.Config.APIRequireAuthToken, - n.Config.APIAuthPassword, n.Config.APIAllowedOrigins, + a, ) + + // only create auth service if token authorization is required + n.Log.Info("API authorization is enabled. 
Auth tokens must be passed in the header of API requests, except requests to the auth service.") + authService, err := a.CreateHandler() + if err != nil { + return err + } + handler := &common.HTTPHandler{ + LockOptions: common.NoLock, + Handler: authService, + } + return n.APIServer.AddRoute(handler, &sync.RWMutex{}, "auth", "", n.Log) } // Create the vmManager, chainManager and register the following VMs: @@ -579,7 +639,7 @@ func (n *Node) initChainManager(avaxAssetID ids.ID) error { NodeID: n.ID, NetworkID: n.Config.NetworkID, Server: &n.APIServer, - Keystore: &n.keystoreServer, + Keystore: n.keystore, AtomicMemory: &n.sharedMemory, AVAXAssetID: avaxAssetID, XChainID: xChainID, @@ -617,7 +677,7 @@ func (n *Node) initChainManager(avaxAssetID ids.ID) error { MinStakeDuration: n.Config.MinStakeDuration, MaxStakeDuration: n.Config.MaxStakeDuration, StakeMintingPeriod: n.Config.StakeMintingPeriod, - ApricotPhase0Time: n.Config.ApricotPhase0Time, + ApricotPhase0Time: GetApricotPhase0Time(n.Config.NetworkID), }), n.vmManager.RegisterVMFactory(avm.ID, &avm.Factory{ CreationFee: n.Config.CreationTxFee, @@ -653,10 +713,8 @@ func (n *Node) initSharedMemory() error { func (n *Node) initKeystoreAPI() error { n.Log.Info("initializing keystore") keystoreDB := prefixdb.New([]byte("keystore"), n.DB) - if err := n.keystoreServer.Initialize(n.Log, keystoreDB); err != nil { - return err - } - keystoreHandler, err := n.keystoreServer.CreateHandler() + n.keystore = keystore.New(n.Log, keystoreDB) + keystoreHandler, err := n.keystore.CreateHandler() if err != nil { return err } @@ -665,7 +723,11 @@ func (n *Node) initKeystoreAPI() error { return nil } n.Log.Info("initializing keystore API") - return n.APIServer.AddRoute(keystoreHandler, &sync.RWMutex{}, "keystore", "", n.HTTPLog) + handler := &common.HTTPHandler{ + LockOptions: common.NoLock, + Handler: keystoreHandler, + } + return n.APIServer.AddRoute(handler, &sync.RWMutex{}, "keystore", "", n.HTTPLog) } // initMetricsAPI 
initializes the Metrics API @@ -908,9 +970,11 @@ func (n *Node) Initialize( if err := n.initAliases(n.Config.GenesisBytes); err != nil { // Set up aliases return fmt.Errorf("couldn't initialize aliases: %w", err) } - if err := n.initChains(n.Config.GenesisBytes, n.Config.AvaxAssetID); err != nil { // Start the Platform chain - return fmt.Errorf("couldn't initialize chains: %w", err) + if err := n.initIndexer(); err != nil { + return fmt.Errorf("couldn't initialize indexer: %w", err) } + // Start the Platform chain + n.initChains(n.Config.GenesisBytes) return nil } @@ -938,6 +1002,9 @@ func (n *Node) shutdown() { if err := n.APIServer.Shutdown(); err != nil { n.Log.Debug("error during API shutdown: %s", err) } + if err := n.indexer.Close(); err != nil { + n.Log.Debug("error closing tx indexer: %w", err) + } utils.ClearSignals(n.nodeCloser) n.doneShuttingDown.Done() n.Log.Info("finished node shutdown") diff --git a/node/version.go b/node/version.go new file mode 100644 index 000000000000..80912fcbd532 --- /dev/null +++ b/node/version.go @@ -0,0 +1,46 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package node + +import ( + "time" + + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/version" +) + +var ( + Version = version.NewDefaultVersion(constants.PlatformName, 1, 3, 2) + MinimumCompatibleVersion = version.NewDefaultVersion(constants.PlatformName, 1, 3, 0) + PrevMinimumCompatibleVersion = version.NewDefaultVersion(constants.PlatformName, 1, 2, 0) + MinimumUnmaskedVersion = version.NewDefaultVersion(constants.PlatformName, 1, 1, 0) + PrevMinimumUnmaskedVersion = version.NewDefaultVersion(constants.PlatformName, 1, 0, 0) + VersionParser = version.NewDefaultParser() + + ApricotPhase0Times = map[uint32]time.Time{ + constants.MainnetID: time.Date(2020, time.December, 8, 3, 0, 0, 0, time.UTC), + constants.FujiID: time.Date(2020, time.December, 5, 5, 0, 0, 0, time.UTC), + } + ApricotPhase0DefaultTime = time.Date(2020, time.December, 5, 5, 0, 0, 0, time.UTC) + + ApricotPhase1Times = map[uint32]time.Time{ + constants.MainnetID: time.Date(2021, time.March, 31, 14, 0, 0, 0, time.UTC), + constants.FujiID: time.Date(2021, time.March, 26, 14, 0, 0, 0, time.UTC), + } + ApricotPhase1DefaultTime = time.Date(2020, time.December, 5, 5, 0, 0, 0, time.UTC) +) + +func GetApricotPhase0Time(networkID uint32) time.Time { + if upgradeTime, exists := ApricotPhase0Times[networkID]; exists { + return upgradeTime + } + return ApricotPhase0DefaultTime +} + +func GetApricotPhase1Time(networkID uint32) time.Time { + if upgradeTime, exists := ApricotPhase1Times[networkID]; exists { + return upgradeTime + } + return ApricotPhase1DefaultTime +} diff --git a/scripts/ansible/.ansible-lint b/scripts/ansible/.ansible-lint index a9514eae0f8e..b576e34d7454 100644 --- a/scripts/ansible/.ansible-lint +++ b/scripts/ansible/.ansible-lint @@ -1,7 +1,3 @@ parsable: true -skip_list: - - '301' # Commands should not change things if nothing needs doing - - '204' # Lines should be no longer than 160 chars - - '502' # All tasks should be named # vim: filetype=yaml diff 
--git a/scripts/ansible/README.md b/scripts/ansible/README.md new file mode 100644 index 000000000000..8ac1d03ddd86 --- /dev/null +++ b/scripts/ansible/README.md @@ -0,0 +1,202 @@ +# Ansible for AvalancheGo + +[Ansible](https://ansible.com) playbooks, roles, & inventories to install +[AvalancheGo](https://github.com/ava-labs/avalanchego) as a systemd service. +Target(s) can be + +- localhost +- Cloud VMs (e.g. Amazon, Azure, Digital Ocean) +- Raspberry Pi +- any host running a supported operating system +- any combination of the above + + +## Using + +To create an AvalancheGo service on localhost + +1. Check you have Ansible 2.9+ (see [Installing](#installing)) +2. Clone the AvalancheGo git repository + ``` + $ git clone https://github.com/ava-labs/avalanchego + ``` + +3. Change to this directory + ``` + $ cd avalanchego/scripts/ansible + ``` + +4. Run the service playbook + ``` + $ ansible-playbook \ + -i inventories/localhost.yml \ + --ask-become-pass \ + service_playbook.yml + ``` + + You don't need `--ask-become-pass` if your account doesn't require a sudo + password. To install on remote hosts you will need to create an inventory, + see [customising]. + + Output should look similar to the [example run](#example-run). + +5. Check the service is running + ``` + $ systemctl status avalanchego + ``` + + The output should look similar to + ``` + ● avalanchego.service - AvalancheGo node for Avalanche consensus network + Loaded: loaded (/etc/systemd/system/avalanchego.service; enabled; vendor preset: enabled) + Active: active (running) since Wed 2020-10-21 10:00:00 UTC; 32s ago + ... + ``` + + +## Installing + +Ansible 2.9 (or higher) is required. To check, run + +``` +ansible --version +``` + +the first line of output should be like `ansible 2.9.x`, or `ansible 2.10.x` +(`x` can be any number). If the output includes `ansible: command not found`, +or an earlier version (e.g. `ansible 2.8.16`), then you need to install a +supported version. 
+ +To install a supported version + +1. Create a Python Virtualenv + ``` + $ python3 -m venv venv/ + ``` + +2. Activate the Virtualenv + ``` + $ source venv/bin/activate + ``` + +4. Install Ansible inside the virtualenv + ``` + $ pip install "ansible>=2.9" + ``` + + +## Customising + +To run against remote targets you'll need an [Ansible inventory](https://docs.ansible.com/ansible/latest/user_guide/intro_inventory.html#inventory-basics-formats-hosts-and-groups). +Here are some examples to use as a starting point. + +### Amazon + +```yaml +avalanche_nodes: + hosts: + ec2-203-0-113-42.us-east-1.compute.amazonaws.com: + ec2-203-0-113-9.ap-southeast-1.compute.amazonaws.com: + vars: + ansible_ssh_private_key_file: "~/.ssh/aws_identity.pem" + ansible_user: "ubuntu" +``` + +### Raspberry Pi + +```yaml +avalanche_nodes: + hosts: + raspberrypi.local: + vars: + ansible_user: "pi" +``` + +## Requirements + +Target operating systems supported by these roles & playbooks are + +- CentOS 7 +- CentOS 8 +- Debian 10 +- Raspberry Pi OS +- Ubuntu 18.04 LTS +- Ubuntu 20.04 LTS + + +## Example run + +``` +PLAY [Configure Avalanche service] **************************************************** + +TASK [Gathering Facts] **************************************************************** +ok: [localhost] + +TASK [golang_base : Dispatch tasks] *************************************************** +included: …/roles/golang_base/tasks/ubuntu-20.04.yml for localhost + +TASK [golang_base : Install Go] ******************************************************* +ok: [localhost] + +TASK [avalanche_base : Dispatch tasks] ************************************************ +included: …/roles/avalanche_base/tasks/ubuntu.yml for localhost + +TASK [avalanche_base : Install Avalanche dependencies] ******************************** +ok: [localhost] + +TASK [avalanche_build : Update git clone] ********************************************* +changed: [localhost] + +TASK [avalanche_build : Build project] 
************************************************ +changed: [localhost] + +TASK [avalanche_user : Create Avalanche daemon group] ********************************* +changed: [localhost] + +TASK [avalanche_user : Create Avalanche daemon user] ********************************** +[WARNING]: The value False (type bool) in a string field was converted to 'False' +(type string). If this does not look like what you expect, quote the entire value to +ensure it does not change. +changed: [localhost] + +TASK [avalanche_install : Create shared directories] ********************************** +ok: [localhost] => (item={'path': '/var/local/lib'}) +changed: [localhost] => (item={'path': '/var/local/log'}) + +TASK [avalanche_install : Create Avalanche directories] ******************************* +ok: [localhost] => (item=/var/local/lib/avalanchego) +changed: [localhost] => (item=/var/local/lib/avalanchego/db) +changed: [localhost] => (item=/var/local/lib/avalanchego/staking) +changed: [localhost] => (item=/var/local/log/avalanchego) +changed: [localhost] => (item=/usr/local/lib/avalanchego) + +TASK [avalanche_install : Install Avalanche binary] *********************************** +changed: [localhost] + +TASK [avalanche_install : Install Avalanche plugins] ********************************** +changed: [localhost] => (item={'path': '~auser/go/src/github.com/ava-labs/avalanchego/build/plugins/evm'}) + +TASK [avalanche_staker : Create staking key] ****************************************** +changed: [localhost] + +TASK [avalanche_staker : Create staking certificate signing request] ****************** +changed: [localhost] + +TASK [avalanche_staker : Create staking certificate] ********************************** +changed: [localhost] + +TASK [avalanche_service : Configure Avalanche service] ******************************** +changed: [localhost] + +TASK [avalanche_service : Enable Avalanche service] *********************************** +changed: [localhost] + +RUNNING HANDLER 
[avalanche_service : Reload systemd] ********************************** +ok: [localhost] + +RUNNING HANDLER [avalanche_service : Restart Avalanche service] *********************** +changed: [localhost] + +PLAY RECAP **************************************************************************** +localhost : ok=20 changed=14 unreachable=0 failed=0 skipped=0 rescued=0 ignored=0 +``` diff --git a/scripts/ansible/ansible.cfg b/scripts/ansible/ansible.cfg old mode 100755 new mode 100644 diff --git a/scripts/ansible/inventories/examples/amazon.yml b/scripts/ansible/inventories/examples/amazon.yml new file mode 100644 index 000000000000..d018bf495028 --- /dev/null +++ b/scripts/ansible/inventories/examples/amazon.yml @@ -0,0 +1,7 @@ +avalanche_nodes: + hosts: + ec2-203-0-113-42.us-east-1.compute.amazonaws.com: + ec2-203-0-113-9.ap-southeast-1.compute.amazonaws.com: + vars: + ansible_ssh_private_key_file: "~/.ssh/aws_identity.pem" + ansible_user: "ubuntu" diff --git a/scripts/ansible/inventories/localhost.yml b/scripts/ansible/inventories/localhost.yml new file mode 100644 index 000000000000..2227abd082f9 --- /dev/null +++ b/scripts/ansible/inventories/localhost.yml @@ -0,0 +1,4 @@ +avalanche_nodes: + hosts: + localhost: + ansible_connection: local diff --git a/scripts/ansible/inventories/raspberrypi.yml b/scripts/ansible/inventories/raspberrypi.yml new file mode 100644 index 000000000000..e5eaa4efc6de --- /dev/null +++ b/scripts/ansible/inventories/raspberrypi.yml @@ -0,0 +1,5 @@ +avalanche_nodes: + hosts: + raspberrypi.local: + vars: + ansible_user: "pi" diff --git a/scripts/ansible/restart_playbook.yml b/scripts/ansible/restart_playbook.yml index aa357eaeb0d3..da1d641660ad 100755 --- a/scripts/ansible/restart_playbook.yml +++ b/scripts/ansible/restart_playbook.yml @@ -2,7 +2,6 @@ --- - name: Update the network connection: ssh - gather_facts: false hosts: all roles: - name: avalanche_stop diff --git a/scripts/ansible/restart_with_coreth_playbook.yml 
b/scripts/ansible/restart_with_coreth_playbook.yml new file mode 100755 index 000000000000..39e712fcaf02 --- /dev/null +++ b/scripts/ansible/restart_with_coreth_playbook.yml @@ -0,0 +1,11 @@ +#!/usr/bin/env ansible-playbook +--- +- name: Update the network + connection: ssh + hosts: all + roles: + - name: avalanche_stop + - name: avalanche_build + - name: coreth_build + - name: avalanche_reset + - name: avalanche_start diff --git a/scripts/ansible/roles/avalanche_base/tasks/main.yml b/scripts/ansible/roles/avalanche_base/tasks/main.yml index f619e990ec10..9370eb906ba8 100644 --- a/scripts/ansible/roles/avalanche_base/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_base/tasks/main.yml @@ -3,5 +3,6 @@ with_first_found: - "{{ ansible_facts.distribution | lower }}-{{ ansible_facts.distribution_major_version }}.yml" - "{{ ansible_facts.distribution | lower }}.yml" + - "not_supported.yml" tags: - - golang_base + - avalanche_base diff --git a/scripts/ansible/roles/avalanche_base/tasks/not_supported.yml b/scripts/ansible/roles/avalanche_base/tasks/not_supported.yml new file mode 100644 index 000000000000..4ccf53846dc0 --- /dev/null +++ b/scripts/ansible/roles/avalanche_base/tasks/not_supported.yml @@ -0,0 +1,10 @@ +- name: Not supported + fail: + msg: + This operating system is not supported by this role + (system={{ ansible_facts.system }}, + os_family={{ ansible_facts.os_family }}, + distribution={{ ansible_facts.distribution }}, + distribution_version={{ ansible_facts.distribution_version }}). 
+ tags: + - avalanche_base diff --git a/scripts/ansible/roles/avalanche_build/defaults/main.yml b/scripts/ansible/roles/avalanche_build/defaults/main.yml index 7ebabe2ba269..717e992865de 100644 --- a/scripts/ansible/roles/avalanche_build/defaults/main.yml +++ b/scripts/ansible/roles/avalanche_build/defaults/main.yml @@ -1,5 +1,5 @@ -avalanche_binary: ~/go/src/github.com/ava-labs/avalanchego/build/avalanchego -repo_folder: ~/go/src/github.com/ava-labs/avalanchego +avalanche_binary: "{{ repo_folder }}/build/avalanchego" +repo_folder: "~{{ ansible_facts.user_id }}/go/src/github.com/ava-labs/avalanchego" repo_name: ava-labs/avalanchego -repo_url: https://github.com/{{ repo_name }} +repo_url: git@github.com:{{ repo_name }} repo_branch: dev diff --git a/scripts/ansible/roles/avalanche_build/tasks/main.yml b/scripts/ansible/roles/avalanche_build/tasks/main.yml index 48900626ee61..10b44d890c80 100644 --- a/scripts/ansible/roles/avalanche_build/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_build/tasks/main.yml @@ -6,8 +6,9 @@ update: yes - name: Build project + # noqa 301 command: ./scripts/build.sh args: chdir: "{{ repo_folder }}" environment: - PATH: /sbin:/usr/sbin:/bin:/usr/bin:/usr/local/bin:/snap/bin + PATH: /usr/lib/go-{{ golang_version_min_major }}.{{ golang_version_min_minor }}/bin:/sbin:/usr/sbin:/bin:/usr/bin:/usr/local/bin:/snap/bin diff --git a/scripts/ansible/roles/avalanche_install/defaults/main.yml b/scripts/ansible/roles/avalanche_install/defaults/main.yml index 0fb4f1db9dbe..b83f7da86f82 100644 --- a/scripts/ansible/roles/avalanche_install/defaults/main.yml +++ b/scripts/ansible/roles/avalanche_install/defaults/main.yml @@ -13,10 +13,11 @@ logdir: "{{ localstatedir }}/log" # These names are specific to Avalanche. Default values are based loosely on *nix # conventions. 
-avalanche_daemon_home_dir: "{{ sharedstatedir }}/avalanche" +avalanche_daemon_home_dir: "{{ sharedstatedir }}/avalanchego" avalanche_daemon_db_dir: "{{ avalanche_daemon_home_dir }}/db" -avalanche_daemon_log_dir: "{{ logdir }}/avalanche" -avalanche_daemon_plugin_dir: "{{ libdir }}/avalanche/plugins" +avalanche_daemon_log_dir: "{{ logdir }}/avalanchego" +avalanche_daemon_plugin_dir: "{{ libdir }}/avalanchego" avalanche_daemon_staking_dir: "{{ avalanche_daemon_home_dir }}/staking" avalanche_daemon_staking_tls_cert: "{{ avalanche_daemon_staking_dir }}/staker.crt" +avalanche_daemon_staking_tls_csr: "{{ avalanche_daemon_staking_dir }}/staker.csr" avalanche_daemon_staking_tls_key: "{{ avalanche_daemon_staking_dir }}/staker.key" diff --git a/scripts/ansible/roles/avalanche_install/tasks/main.yml b/scripts/ansible/roles/avalanche_install/tasks/main.yml index 3c08eb6c5f9c..d1b8853a3b3c 100644 --- a/scripts/ansible/roles/avalanche_install/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_install/tasks/main.yml @@ -3,7 +3,13 @@ file: path: "{{ item.path }}" state: directory - mode: preserve + # This task intentionally leaves the owner, group, & mode unspecified. + # These directories (e.g. /var/local/log) are shared resources. + # We lack the authority to unilaterally decide their permissions. + # The goals are + # - Don't modify existing permissions, if the directories already exist. + # - Follow system policy (e.g. umask), if the directories are created. 
+ # noqa 208 loop: - path: "{{ sharedstatedir }}" - path: "{{ logdir }}" @@ -40,7 +46,7 @@ become: true copy: src: "{{ avalanche_binary }}" - dest: "{{ bindir }}/avalanche" + dest: "{{ bindir }}/{{ avalanche_binary | basename }}" remote_src: true owner: root group: root diff --git a/scripts/ansible/roles/avalanche_service/defaults/main.yml b/scripts/ansible/roles/avalanche_service/defaults/main.yml index 59f44fc6c3ba..22ad478c149d 100644 --- a/scripts/ansible/roles/avalanche_service/defaults/main.yml +++ b/scripts/ansible/roles/avalanche_service/defaults/main.yml @@ -1,2 +1,6 @@ +avalanche_daemon_dynamic_public_ip: ifconfigme avalanche_daemon_http_host: localhost +# This variable is ignored, if avalanche_daemon_dynamic_public_ip is not false +avalanche_daemon_public_ip: "{{ ansible_facts.default_ipv4.address }}" +avalanche_daemon_service_name: avalanchego log_level: info diff --git a/scripts/ansible/roles/avalanche_service/handlers/main.yml b/scripts/ansible/roles/avalanche_service/handlers/main.yml index fee6950b3e7c..e6eba9c2c874 100644 --- a/scripts/ansible/roles/avalanche_service/handlers/main.yml +++ b/scripts/ansible/roles/avalanche_service/handlers/main.yml @@ -6,5 +6,5 @@ - name: Restart Avalanche service become: true service: - name: avalanche + name: "{{ avalanche_daemon_service_name }}" state: restarted diff --git a/scripts/ansible/roles/avalanche_service/tasks/main.yml b/scripts/ansible/roles/avalanche_service/tasks/main.yml index 626eb1529841..5b491525cee4 100644 --- a/scripts/ansible/roles/avalanche_service/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_service/tasks/main.yml @@ -2,7 +2,7 @@ become: true template: src: avalanche.service - dest: /etc/systemd/system + dest: "/etc/systemd/system/{{ avalanche_daemon_service_name }}.service" owner: root group: root mode: u=rw,go=r @@ -13,7 +13,7 @@ - name: Enable Avalanche service become: true systemd: - name: avalanche + name: "{{ avalanche_daemon_service_name }}" state: started enabled: true 
daemon_reload: true diff --git a/scripts/ansible/roles/avalanche_service/templates/avalanche.service b/scripts/ansible/roles/avalanche_service/templates/avalanche.service index 0a7a5e71e022..0e21f6e14863 100644 --- a/scripts/ansible/roles/avalanche_service/templates/avalanche.service +++ b/scripts/ansible/roles/avalanche_service/templates/avalanche.service @@ -1,7 +1,7 @@ # {{ ansible_managed }} [Unit] -Description=Avalanche test node +Description=AvalancheGo node Documentation=https://docs.avax.network/ After=network.target StartLimitIntervalSec=0 @@ -12,8 +12,12 @@ WorkingDirectory={{ avalanche_daemon_home_dir }} Restart=always RestartSec=1 User={{ avalanche_daemon_user }} -ExecStart={{ bindir }}/avalanche \ - --public-ip="{{ ansible_facts.default_ipv4.address }}" \ +ExecStart={{ bindir }}/{{ avalanche_binary | basename }} \ +{% if avalanche_daemon_dynamic_public_ip %} + --dynamic-public-ip="{{ avalanche_daemon_dynamic_public_ip }}" \ +{% else %} + --public-ip="{{ avalanche_daemon_public_ip }}" \ +{% endif %} --http-host="{{ avalanche_daemon_http_host }}" \ --db-dir="{{ avalanche_daemon_db_dir }}" \ --plugin-dir="{{ avalanche_daemon_plugin_dir }}" \ diff --git a/scripts/ansible/roles/avalanche_staker/tasks/main.yml b/scripts/ansible/roles/avalanche_staker/tasks/main.yml new file mode 100644 index 000000000000..ed93a5910be3 --- /dev/null +++ b/scripts/ansible/roles/avalanche_staker/tasks/main.yml @@ -0,0 +1,46 @@ +- name: Create staking key + become: true + openssl_privatekey: + path: "{{ avalanche_daemon_staking_tls_key }}" + size: 4096 + backup: true + owner: "{{ avalanche_daemon_user }}" + group: "{{ avalanche_daemon_group }}" + mode: u=rw,go= + +# This CSR isn't used at all by avalanchego. It's a required step when creating +# a certificate with Ansible, even a self signed one. Leaving the CSR in place +# (rather than deleting it) keeps the role idempotent. 
+- name: Create staking certificate signing request + become: true + openssl_csr: + path: "{{ avalanche_daemon_staking_tls_csr }}" + basic_constraints: + - "CA:FALSE" + basic_constraints_critical: true + key_usage: + - digitalSignature + - dataEncipherment + key_usage_critical: true + privatekey_path: "{{ avalanche_daemon_staking_tls_key }}" + use_common_name_for_san: false + owner: "{{ avalanche_daemon_user }}" + group: "{{ avalanche_daemon_group }}" + mode: u=rw,go=r + +# These parameters in this role were arrived at by inspecting a certificate +# that was generated by avalanchego. There is one intentional difference. +# The role generates certificates valid for the Ansible default of 10 years. +# avalanchego generates them valid for 100 years. By using Ansible's default +# (rather than overriding it) the role can be kept idempotent. +- name: Create staking certificate + become: true + openssl_certificate: + path: "{{ avalanche_daemon_staking_tls_cert }}" + backup: true + csr_path: "{{ avalanche_daemon_staking_tls_csr }}" + privatekey_path: "{{ avalanche_daemon_staking_tls_key }}" + provider: selfsigned + owner: "{{ avalanche_daemon_user }}" + group: "{{ avalanche_daemon_group }}" + mode: u=rw,go=r diff --git a/scripts/ansible/roles/avalanche_start/tasks/main.yml b/scripts/ansible/roles/avalanche_start/tasks/main.yml index dd14640195b5..e918c70871bf 100644 --- a/scripts/ansible/roles/avalanche_start/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_start/tasks/main.yml @@ -1,4 +1,5 @@ - name: Start node + # noqa 301 shell: nohup {{ avalanche_binary }} --network-id="{{ network_id }}" diff --git a/scripts/ansible/roles/avalanche_stop/tasks/main.yml b/scripts/ansible/roles/avalanche_stop/tasks/main.yml index 22da3e5ba248..1e23fa33d054 100644 --- a/scripts/ansible/roles/avalanche_stop/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_stop/tasks/main.yml @@ -1,7 +1,17 @@ - name: Kill Node command: killall -SIGTERM avalanchego - ignore_errors: true + register: 
killall_avalanchego + changed_when: + - "killall_avalanchego.rc in [0]" + failed_when: + - "killall_avalanchego.rc not in [0]" + - "killall_avalanchego.stderr not in ['avalanchego: no process found']" - name: Kill EVM command: killall -SIGTERM evm - ignore_errors: true + register: killall_evm + changed_when: + - "killall_evm.rc in [0]" + failed_when: + - "killall_evm.rc not in [0]" + - "killall_evm.stderr not in ['evm: no process found']" diff --git a/scripts/ansible/roles/avalanche_upgrade/tasks/10-staking-migrate.yml b/scripts/ansible/roles/avalanche_upgrade/tasks/10-staking-migrate.yml deleted file mode 100644 index 0f31ebfee01f..000000000000 --- a/scripts/ansible/roles/avalanche_upgrade/tasks/10-staking-migrate.yml +++ /dev/null @@ -1,51 +0,0 @@ -- name: Migrate staking key - become: true - vars: - old_key: "{{ avalanche_daemon_home_dir }}/keys/staker.key" - new_key: "{{ avalanche_daemon_home_dir }}/staking/staker.key" - block: - - name: Check for Gecko 0.2.0 staking key - stat: - path: "{{ old_key }}" - register: gecko_0_2_0_staking_key - - - name: Check for Gecko newer staking key - stat: - path: "{{ new_key }}" - register: gecko_newer_staking_key - - - name: Move staking key - command: - cmd: mv "{{ old_key }}" "{{ new_key }}" - creates: "{{ new_key }}" - when: - - gecko_0_2_0_staking_key.stat.exists - - not gecko_newer_staking_key.stat.exists - notify: - - Restart Avalanche service - -- name: Migrate staking certificate - become: true - vars: - old_cert: "{{ avalanche_daemon_home_dir }}/keys/staker.crt" - new_cert: "{{ avalanche_daemon_home_dir }}/staking/staker.crt" - block: - - name: Check for Gecko 0.2.0 staking certificate - stat: - path: "{{ old_cert }}" - register: gecko_0_2_0_staking_cert - - - name: Check for Gecko newer staking certificate - stat: - path: "{{ new_cert }}" - register: gecko_newer_staking_cert - - - name: Migrate staking certificate - command: - cmd: mv "{{ old_cert }}" "{{ new_cert }}" - creates: "{{ new_cert }}" - when: - - 
gecko_0_2_0_staking_cert.stat.exists - - not gecko_newer_staking_cert.stat.exists - notify: - - Restart Avalanche service diff --git a/scripts/ansible/roles/avalanche_upgrade/tasks/main.yml b/scripts/ansible/roles/avalanche_upgrade/tasks/main.yml deleted file mode 100644 index b63665c293ba..000000000000 --- a/scripts/ansible/roles/avalanche_upgrade/tasks/main.yml +++ /dev/null @@ -1 +0,0 @@ -- import_tasks: 10-staking-migrate.yml diff --git a/scripts/ansible/roles/avalanche_user/tasks/main.yml b/scripts/ansible/roles/avalanche_user/tasks/main.yml index 248275568890..f18a804f32cf 100644 --- a/scripts/ansible/roles/avalanche_user/tasks/main.yml +++ b/scripts/ansible/roles/avalanche_user/tasks/main.yml @@ -11,5 +11,4 @@ group: "{{ avalanche_daemon_group }}" home: "{{ avalanche_daemon_home_dir }}" shell: /bin/false - skeleton: false system: true diff --git a/scripts/ansible/roles/coreth_build/defaults/main.yml b/scripts/ansible/roles/coreth_build/defaults/main.yml new file mode 100644 index 000000000000..af2a41892f12 --- /dev/null +++ b/scripts/ansible/roles/coreth_build/defaults/main.yml @@ -0,0 +1,5 @@ +evm_binary: "{{ repo_folder }}/build/plugins/evm" +repo_folder: "~{{ ansible_facts.user_id }}/go/src/github.com/ava-labs/coreth" +repo_name: ava-labs/coreth +repo_url: git@github.com:{{ repo_name }} +repo_branch: dev diff --git a/scripts/ansible/roles/coreth_build/tasks/main.yml b/scripts/ansible/roles/coreth_build/tasks/main.yml new file mode 100644 index 000000000000..10b44d890c80 --- /dev/null +++ b/scripts/ansible/roles/coreth_build/tasks/main.yml @@ -0,0 +1,14 @@ +- name: Update git clone + git: + repo: "{{ repo_url }}" + dest: "{{ repo_folder }}" + version: "{{ repo_branch }}" + update: yes + +- name: Build project + # noqa 301 + command: ./scripts/build.sh + args: + chdir: "{{ repo_folder }}" + environment: + PATH: /usr/lib/go-{{ golang_version_min_major }}.{{ golang_version_min_minor }}/bin:/sbin:/usr/sbin:/bin:/usr/bin:/usr/local/bin:/snap/bin diff --git 
a/scripts/ansible/roles/golang_base/defaults/main.yml b/scripts/ansible/roles/golang_base/defaults/main.yml new file mode 100644 index 000000000000..bedf4cda74d3 --- /dev/null +++ b/scripts/ansible/roles/golang_base/defaults/main.yml @@ -0,0 +1,7 @@ +--- +# Changes to this minimum version should be coordinated with changes to the +# documented minimum in README.md +golang_version_min: 1.15.5 +golang_version_min_info: "{{ golang_version_min.split('.') | map('int') | list }}" +golang_version_min_major: "{{ golang_version_min_info[0] }}" +golang_version_min_minor: "{{ golang_version_min_info[1] }}" diff --git a/scripts/ansible/roles/golang_base/tasks/centos-7.yml b/scripts/ansible/roles/golang_base/tasks/centos-7.yml index 2d890c443654..e2bf1b4b3443 100644 --- a/scripts/ansible/roles/golang_base/tasks/centos-7.yml +++ b/scripts/ansible/roles/golang_base/tasks/centos-7.yml @@ -1,12 +1,17 @@ +# CentOS 7.x does not include golang. +# The EPEL repository includes it, and tracks golang packages from Fedora. + - name: Install Go repo yum: name: - epel-release + tags: + - golang_base - name: Install Go become: true yum: name: - - golang # 1.13.11 in July 2020 + - "golang >= {{ golang_version_min }}" tags: - golang_base diff --git a/scripts/ansible/roles/golang_base/tasks/centos-8.yml b/scripts/ansible/roles/golang_base/tasks/centos-8.yml index be4ec20d916b..ca9d4a277f77 100644 --- a/scripts/ansible/roles/golang_base/tasks/centos-8.yml +++ b/scripts/ansible/roles/golang_base/tasks/centos-8.yml @@ -1,7 +1,28 @@ +# CentOS Linux 8.x has periodic minor releases, each one has a Go release. +# CentOS Stream 8 is a rolling release, Go is updated on a continuous basis. + +- name: Add Go repository + # Typically the version of Go in CentOS Linux will be too old. So use the + # CentOS Stream AppStream repository, but only for Go packages. 
+ yum_repository: + name: centos-stream-appstream + description: CentOS Stream $releasever - AppStream + mirrorlist: http://mirrorlist.centos.org/?release=8-stream&arch=$basearch&repo=AppStream&infra=$infra + gpgcheck: true + gpgkey: file:///etc/pki/rpm-gpg/RPM-GPG-KEY-centosofficial + file: CentOS-Stream-AppStream # .repo extension is added by Ansible + includepkgs: golang* + # Only run on CentOS Linux. CentOS Stream already has this repo configured. + # On CentOS Linux distribution_version is ".". + # On CentOS Stream distribution_version is just "". + when: (ansible_facts.distribution_version | length) >= 3 + tags: + - golang_base + - name: Install Go become: true yum: name: - - golang # 1.13.4 in July 2020 + - "golang >= {{ golang_version_min }}" tags: - golang_base diff --git a/scripts/ansible/roles/golang_base/tasks/debian-10.yml b/scripts/ansible/roles/golang_base/tasks/debian-10.yml index d571f07a70fe..3f5fdaa3ae5a 100644 --- a/scripts/ansible/roles/golang_base/tasks/debian-10.yml +++ b/scripts/ansible/roles/golang_base/tasks/debian-10.yml @@ -1,15 +1,58 @@ +# Debian 10.x (buster) includes golang 1.11.x. +# The testing repository tracks https://golang.org releases more closely. 
+ +- name: Configure Apt + become: true + copy: + content: "{{ golang_base_apt_conf.content }}" + dest: "{{ golang_base_apt_conf.dest }}" + owner: root + group: root + mode: u=rw,go=r + loop: + - content: | + APT::Default-Release "stable"; + dest: /etc/apt/apt.conf.d/99defaultrelease + - content: | + Package: * + Pin: release o=Debian,a=testing,n=bullseye + Pin-Priority: 400 + dest: /etc/apt/preferences.d/99pin-bullseye + loop_control: + label: "{{ golang_base_apt_conf.dest }}" + loop_var: golang_base_apt_conf + tags: + - golang_base + - name: Add Go repository become: true apt_repository: - repo: deb http://deb.debian.org/debian buster-backports main + repo: deb http://deb.debian.org/debian testing main tags: - golang_base - name: Install Go become: true + # Apt doesn't support specifying a minimum version (e.g. foo >= 1.0) + # https://github.com/ansible/ansible/issues/69034 apt: name: - - golang-go - default_release: buster-backports + - "golang-{{ golang_version_min_major }}.{{ golang_version_min_minor }}-go" + default_release: testing + tags: + - golang_base + +- name: Query installed packages + package_facts: + tags: + - golang_base + +- name: Check minimum Go version + vars: + pkg_name: golang-{{ golang_version_min_major }}.{{ golang_version_min_minor }}-go + assert: + that: + - pkg_name in ansible_facts.packages + - ansible_facts.packages[pkg_name][0].version is version(golang_version_min, '>=') tags: - golang_base diff --git a/scripts/ansible/roles/golang_base/tasks/main.yml b/scripts/ansible/roles/golang_base/tasks/main.yml index 9911784e0ae4..3537aa28559f 100644 --- a/scripts/ansible/roles/golang_base/tasks/main.yml +++ b/scripts/ansible/roles/golang_base/tasks/main.yml @@ -3,5 +3,6 @@ with_first_found: - "{{ ansible_facts.distribution | lower }}-{{ ansible_facts.distribution_version }}.yml" - "{{ ansible_facts.distribution | lower }}-{{ ansible_facts.distribution_major_version }}.yml" + - "not_supported.yml" tags: - golang_base diff --git 
a/scripts/ansible/roles/golang_base/tasks/not_supported.yml b/scripts/ansible/roles/golang_base/tasks/not_supported.yml new file mode 100644 index 000000000000..2021acc4f9e8 --- /dev/null +++ b/scripts/ansible/roles/golang_base/tasks/not_supported.yml @@ -0,0 +1,10 @@ +- name: Not supported + fail: + msg: + This operating system is not supported by this role + (system={{ ansible_facts.system }}, + os_family={{ ansible_facts.os_family }}, + distribution={{ ansible_facts.distribution }}, + distribution_version={{ ansible_facts.distribution_version }}). + tags: + - golang_base diff --git a/scripts/ansible/roles/golang_base/tasks/ubuntu-18.04.yml b/scripts/ansible/roles/golang_base/tasks/ubuntu-18.04.yml index bf237d13e660..84f4b6c9085e 100644 --- a/scripts/ansible/roles/golang_base/tasks/ubuntu-18.04.yml +++ b/scripts/ansible/roles/golang_base/tasks/ubuntu-18.04.yml @@ -1,5 +1,5 @@ # As mentioned by https://github.com/golang/go/wiki/Ubuntu -- task: Add Go repository +- name: Add Go repository become: true apt_repository: repo: ppa:longsleep/golang-backports @@ -8,8 +8,25 @@ - name: Install Go become: true + # Apt doesn't support specifying a minimum version (e.g. 
foo >= 1.0) + # https://github.com/ansible/ansible/issues/69034 apt: name: - - golang-go + - "golang-{{ golang_version_min_major }}.{{ golang_version_min_minor }}-go" + tags: + - golang_base + +- name: Query installed packages + package_facts: + tags: + - golang_base + +- name: Check minimum Go version + vars: + pkg_name: golang-{{ golang_version_min_major }}.{{ golang_version_min_minor }}-go + assert: + that: + - pkg_name in ansible_facts.packages + - ansible_facts.packages[pkg_name][0].version is version(golang_version_min, '>=') tags: - golang_base diff --git a/scripts/ansible/roles/golang_base/tasks/ubuntu-20.04.yml b/scripts/ansible/roles/golang_base/tasks/ubuntu-20.04.yml deleted file mode 100644 index 9f574ed62185..000000000000 --- a/scripts/ansible/roles/golang_base/tasks/ubuntu-20.04.yml +++ /dev/null @@ -1,7 +0,0 @@ -- name: Install Go - become: true - apt: - name: - - golang-go - tags: - - golang_base diff --git a/scripts/ansible/roles/golang_base/tasks/ubuntu-20.04.yml b/scripts/ansible/roles/golang_base/tasks/ubuntu-20.04.yml new file mode 120000 index 000000000000..8710b1685eb8 --- /dev/null +++ b/scripts/ansible/roles/golang_base/tasks/ubuntu-20.04.yml @@ -0,0 +1 @@ +ubuntu-18.04.yml \ No newline at end of file diff --git a/scripts/ansible/roles/gopath/tasks/main.yml b/scripts/ansible/roles/gopath/tasks/main.yml deleted file mode 100644 index c7f165c6ae86..000000000000 --- a/scripts/ansible/roles/gopath/tasks/main.yml +++ /dev/null @@ -1,5 +0,0 @@ -- name: Set GOPATH - lineinfile: - path: ~/.bashrc - line: GOPATH=$HOME/go - mode: preserve diff --git a/scripts/ansible/service_playbook.yml b/scripts/ansible/service_playbook.yml index 2c0b38406626..3a529e4e660d 100755 --- a/scripts/ansible/service_playbook.yml +++ b/scripts/ansible/service_playbook.yml @@ -4,10 +4,9 @@ hosts: avalanche_nodes roles: - name: golang_base - - name: gopath - name: avalanche_base - name: avalanche_build - name: avalanche_user - name: avalanche_install - - name: 
avalanche_upgrade + - name: avalanche_staker + - name: avalanche_service diff --git a/scripts/ansible/update_playbook.yml b/scripts/ansible/update_playbook.yml index d3ed46bf6551..99ab04143bc6 100755 --- a/scripts/ansible/update_playbook.yml +++ b/scripts/ansible/update_playbook.yml @@ -3,7 +3,6 @@ --- - name: Update the network connection: ssh - gather_facts: false hosts: all roles: - name: avalanche_stop diff --git a/scripts/ansible/update_with_coreth_playbook.yml b/scripts/ansible/update_with_coreth_playbook.yml new file mode 100755 index 000000000000..4707cf6b3000 --- /dev/null +++ b/scripts/ansible/update_with_coreth_playbook.yml @@ -0,0 +1,10 @@ +#!/usr/bin/env ansible-playbook +--- +- name: Update the network + connection: ssh + hosts: all + roles: + - name: avalanche_stop + - name: avalanche_build + - name: coreth_build + - name: avalanche_start diff --git a/scripts/build_avalanche.sh b/scripts/build_avalanche.sh index 3a2dad5b76cc..78c14c63ac67 100755 --- a/scripts/build_avalanche.sh +++ b/scripts/build_avalanche.sh @@ -10,7 +10,7 @@ GOPATH="$(go env GOPATH)" AVALANCHE_PATH=$( cd "$( dirname "${BASH_SOURCE[0]}" )"; cd .. && pwd ) # Directory above this script BUILD_DIR=$AVALANCHE_PATH/build # Where binaries go -GIT_COMMIT=$( git rev-list -1 HEAD ) +GIT_COMMIT=${AVALANCHEGO_COMMIT:-$( git rev-list -1 HEAD )} # Build aVALANCHE echo "Building Avalanche..." 
diff --git a/scripts/build_coreth.sh b/scripts/build_coreth.sh index 741f206e21fa..c4fcf2cb71fd 100755 --- a/scripts/build_coreth.sh +++ b/scripts/build_coreth.sh @@ -13,7 +13,7 @@ BUILD_DIR="$AVALANCHE_PATH/build" # Where binaries go PLUGIN_DIR="$BUILD_DIR/plugins" # Where plugin binaries (namely coreth) go BINARY_PATH="$PLUGIN_DIR/evm" -CORETH_VER="v0.4.0-rc.7" +CORETH_VER="v0.4.2-rc.4" CORETH_PATH="$GOPATH/pkg/mod/github.com/ava-labs/coreth@$CORETH_VER" diff --git a/scripts/build_local_image.sh b/scripts/build_local_image.sh index b808e5be6f1a..5189e7e64583 100755 --- a/scripts/build_local_image.sh +++ b/scripts/build_local_image.sh @@ -14,4 +14,4 @@ FULL_COMMIT_HASH="$(git --git-dir="$AVALANCHE_PATH/.git" rev-parse HEAD)" COMMIT_HASH="${FULL_COMMIT_HASH::8}" echo "Building Docker Image: $DOCKERHUB_REPO:$COMMIT_HASH" -docker build -t "$DOCKERHUB_REPO:$COMMIT_HASH" "$AVALANCHE_PATH" -f "$AVALANCHE_PATH/Dockerfile" +docker build -t "$DOCKERHUB_REPO:$COMMIT_HASH" "$AVALANCHE_PATH" -f "$AVALANCHE_PATH/Dockerfile" --build-arg AVALANCHEGO_COMMIT="$FULL_COMMIT_HASH" diff --git a/scripts/build_test.sh b/scripts/build_test.sh index 06f2ceaaeb63..4e5b495f9937 100755 --- a/scripts/build_test.sh +++ b/scripts/build_test.sh @@ -6,4 +6,4 @@ set -o pipefail # Ted: contact me when you make any changes -go test -race -timeout="90s" -coverprofile="coverage.out" -covermode="atomic" ./... +go test -race -timeout="90s" -coverprofile="coverage.out" -covermode="atomic" $(go list ./... 
| grep -v /mocks | grep -v proto) \ No newline at end of file diff --git a/snow/consensus/avalanche/topological.go b/snow/consensus/avalanche/topological.go index 096d204abd00..36d33e98e694 100644 --- a/snow/consensus/avalanche/topological.go +++ b/snow/consensus/avalanche/topological.go @@ -114,7 +114,9 @@ func (ta *Topological) Add(vtx Vertex) error { return nil // Already inserted this vertex } - ta.ctx.ConsensusDispatcher.Issue(ta.ctx, vtxID, vtx.Bytes()) + if err := ta.ctx.ConsensusDispatcher.Issue(ta.ctx, vtxID, vtx.Bytes()); err != nil { + return err + } txs, err := vtx.Txs() if err != nil { @@ -457,7 +459,9 @@ func (ta *Topological) update(vtx Vertex) error { if err := vtx.Reject(); err != nil { return err } - ta.ctx.ConsensusDispatcher.Reject(ta.ctx, vtxID, vtx.Bytes()) + if err := ta.ctx.ConsensusDispatcher.Reject(ta.ctx, vtxID, vtx.Bytes()); err != nil { + return err + } delete(ta.nodes, vtxID) ta.Metrics.Rejected(vtxID) @@ -508,18 +512,24 @@ func (ta *Topological) update(vtx Vertex) error { switch { case acceptable: // I'm acceptable, why not accept? + // Note that ConsensusDispatcher.Accept must be called before vtx.Accept to honor + // EventDispatcher.Accept's invariant. + if err := ta.ctx.ConsensusDispatcher.Accept(ta.ctx, vtxID, vtx.Bytes()); err != nil { + return err + } if err := vtx.Accept(); err != nil { return err } - ta.ctx.ConsensusDispatcher.Accept(ta.ctx, vtxID, vtx.Bytes()) delete(ta.nodes, vtxID) ta.Metrics.Accepted(vtxID) case rejectable: // I'm rejectable, why not reject? 
+ if err := ta.ctx.ConsensusDispatcher.Reject(ta.ctx, vtxID, vtx.Bytes()); err != nil { + return err + } if err := vtx.Reject(); err != nil { return err } - ta.ctx.ConsensusDispatcher.Reject(ta.ctx, vtxID, vtx.Bytes()) delete(ta.nodes, vtxID) ta.Metrics.Rejected(vtxID) } diff --git a/snow/consensus/snowman/topological.go b/snow/consensus/snowman/topological.go index 55ebab9e101b..a5d30d9d5235 100644 --- a/snow/consensus/snowman/topological.go +++ b/snow/consensus/snowman/topological.go @@ -106,8 +106,12 @@ func (ts *Topological) Add(blk Block) error { blkBytes := blk.Bytes() // Notify anyone listening that this block was issued. - ts.ctx.DecisionDispatcher.Issue(ts.ctx, blkID, blkBytes) - ts.ctx.ConsensusDispatcher.Issue(ts.ctx, blkID, blkBytes) + if err := ts.ctx.DecisionDispatcher.Issue(ts.ctx, blkID, blkBytes); err != nil { + return err + } + if err := ts.ctx.ConsensusDispatcher.Issue(ts.ctx, blkID, blkBytes); err != nil { + return err + } ts.Metrics.Issued(blkID) parentNode, ok := ts.blocks[parentID] @@ -120,8 +124,12 @@ } // Notify anyone listening that this block was rejected. - ts.ctx.DecisionDispatcher.Reject(ts.ctx, blkID, blkBytes) - ts.ctx.ConsensusDispatcher.Reject(ts.ctx, blkID, blkBytes) + if err := ts.ctx.DecisionDispatcher.Reject(ts.ctx, blkID, blkBytes); err != nil { + return err + } + if err := ts.ctx.ConsensusDispatcher.Reject(ts.ctx, blkID, blkBytes); err != nil { + return err + } ts.Metrics.Rejected(blkID) return nil } @@ -511,20 +519,25 @@ func (ts *Topological) accept(n *snowmanBlock) error { // Get the child and accept it child := n.children[pref] + // Notify anyone listening that this block was accepted. + bytes := child.Bytes() + // Note that DecisionDispatcher.Accept / ConsensusDispatcher.Accept must be called before + // child.Accept to honor EventDispatcher.Accept's invariant. 
+ if err := ts.ctx.DecisionDispatcher.Accept(ts.ctx, pref, bytes); err != nil { + return err + } + if err := ts.ctx.ConsensusDispatcher.Accept(ts.ctx, pref, bytes); err != nil { + return err + } if err := child.Accept(); err != nil { return err } - // Notify anyone listening that this block was accepted. - bytes := child.Bytes() - ts.ctx.DecisionDispatcher.Accept(ts.ctx, pref, bytes) - ts.ctx.ConsensusDispatcher.Accept(ts.ctx, pref, bytes) ts.Metrics.Accepted(pref) // Because this is the newest accepted block, this is the new head. ts.head = pref ts.height = child.Height() - // Remove the decided block from the set of processing IDs, as its status // now implies its preferredness. ts.preferredIDs.Remove(pref) @@ -545,8 +558,12 @@ func (ts *Topological) accept(n *snowmanBlock) error { // Notify anyone listening that this block was rejected. bytes := child.Bytes() - ts.ctx.DecisionDispatcher.Reject(ts.ctx, childID, bytes) - ts.ctx.ConsensusDispatcher.Reject(ts.ctx, childID, bytes) + if err := ts.ctx.DecisionDispatcher.Reject(ts.ctx, childID, bytes); err != nil { + return err + } + if err := ts.ctx.ConsensusDispatcher.Reject(ts.ctx, childID, bytes); err != nil { + return err + } ts.Metrics.Rejected(childID) // Track which blocks have been directly rejected @@ -578,8 +595,12 @@ func (ts *Topological) rejectTransitively(rejected []ids.ID) error { // Notify anyone listening that this block was rejected. 
bytes := child.Bytes() - ts.ctx.DecisionDispatcher.Reject(ts.ctx, childID, bytes) - ts.ctx.ConsensusDispatcher.Reject(ts.ctx, childID, bytes) + if err := ts.ctx.DecisionDispatcher.Reject(ts.ctx, childID, bytes); err != nil { + return err + } + if err := ts.ctx.ConsensusDispatcher.Reject(ts.ctx, childID, bytes); err != nil { + return err + } ts.Metrics.Rejected(childID) // add the newly rejected block to the end of the queue diff --git a/snow/consensus/snowstorm/common.go b/snow/consensus/snowstorm/common.go index 4fa89f053c7d..4ef4c01ed016 100644 --- a/snow/consensus/snowstorm/common.go +++ b/snow/consensus/snowstorm/common.go @@ -123,7 +123,9 @@ func (c *common) shouldVote(con Consensus, tx Tx) (bool, error) { bytes := tx.Bytes() // Notify the IPC socket that this tx has been issued. - c.ctx.DecisionDispatcher.Issue(c.ctx, txID, bytes) + if err := c.ctx.DecisionDispatcher.Issue(c.ctx, txID, bytes); err != nil { + return false, err + } // Notify the metrics that this transaction is being issued. c.Metrics.Issued(txID) @@ -137,42 +139,44 @@ func (c *common) shouldVote(con Consensus, tx Tx) (bool, error) { // any conflicting transactions. Therefore, this transaction is treated as // vacuously accepted and doesn't need to be voted on. - // Accept is called before notifying the IPC so that acceptances that - // cause fatal errors aren't sent to an IPC peer. - if err := tx.Accept(); err != nil { + // Notify those listening for accepted txs + // Note that DecisionDispatcher.Accept must be called before + // tx.Accept to honor EventDispatcher.Accept's invariant. + if err := c.ctx.DecisionDispatcher.Accept(c.ctx, txID, bytes); err != nil { return false, err } - // Notify the IPC socket that this tx has been accepted. - c.ctx.DecisionDispatcher.Accept(c.ctx, txID, bytes) + if err := tx.Accept(); err != nil { + return false, err + } - // Notify the metrics that this transaction was just accepted. + // Notify the metrics that this transaction was accepted. 
c.Metrics.Accepted(txID) return false, nil } // accept the provided tx. func (c *common) acceptTx(tx Tx) error { - // Accept is called before notifying the IPC so that acceptances that cause - // fatal errors aren't sent to an IPC peer. + txID := tx.ID() + // Notify those listening that this tx has been accepted. + // Note that DecisionDispatcher.Accept must be called before + // tx.Accept to honor EventDispatcher.Accept's invariant. + if err := c.ctx.DecisionDispatcher.Accept(c.ctx, txID, tx.Bytes()); err != nil { + return err + } if err := tx.Accept(); err != nil { return err } - txID := tx.ID() - - // Notify the IPC socket that this tx has been accepted. - c.ctx.DecisionDispatcher.Accept(c.ctx, txID, tx.Bytes()) - // Update the metrics to account for this transaction's acceptance c.Metrics.Accepted(txID) - // If there is a tx that was accepted pending on this tx, the ancestor // should be notified that it doesn't need to block on this tx anymore. c.pendingAccept.Fulfill(txID) // If there is a tx that was issued pending on this tx, the ancestor tx // doesn't need to be rejected because of this tx. 
c.pendingReject.Abandon(txID) + return nil } @@ -187,7 +191,9 @@ func (c *common) rejectTx(tx Tx) error { txID := tx.ID() // Notify the IPC that the tx was rejected - c.ctx.DecisionDispatcher.Reject(c.ctx, txID, tx.Bytes()) + if err := c.ctx.DecisionDispatcher.Reject(c.ctx, txID, tx.Bytes()); err != nil { + return err + } // Update the metrics to account for this transaction's rejection c.Metrics.Rejected(txID) diff --git a/snow/context.go b/snow/context.go index 20b24a77241d..206aa6065a4b 100644 --- a/snow/context.go +++ b/snow/context.go @@ -11,8 +11,8 @@ import ( "github.com/prometheus/client_golang/prometheus" + "github.com/ava-labs/avalanchego/api/keystore" "github.com/ava-labs/avalanchego/chains/atomic" - "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/timer" @@ -20,14 +20,12 @@ import ( // EventDispatcher ... type EventDispatcher interface { - Issue(ctx *Context, containerID ids.ID, container []byte) - Accept(ctx *Context, containerID ids.ID, container []byte) - Reject(ctx *Context, containerID ids.ID, container []byte) -} - -// Keystore ... -type Keystore interface { - GetDatabase(username, password string) (database.Database, error) + Issue(ctx *Context, containerID ids.ID, container []byte) error + // If the returned error is non-nil, the chain associated with [ctx] should shut + // down and not commit [container] or any other container to its database as accepted. + // Accept must be called before [containerID] is committed to the VM as accepted. + Accept(ctx *Context, containerID ids.ID, container []byte) error + Reject(ctx *Context, containerID ids.ID, container []byte) error } // AliasLookup ... 
@@ -58,7 +56,7 @@ type Context struct { DecisionDispatcher EventDispatcher ConsensusDispatcher EventDispatcher Lock sync.RWMutex - Keystore Keystore + Keystore keystore.BlockchainKeystore SharedMemory atomic.SharedMemory BCLookup AliasLookup SNLookup SubnetLookup @@ -116,6 +114,6 @@ func DefaultContextTest() *Context { type emptyEventDispatcher struct{} -func (emptyEventDispatcher) Issue(*Context, ids.ID, []byte) {} -func (emptyEventDispatcher) Accept(*Context, ids.ID, []byte) {} -func (emptyEventDispatcher) Reject(*Context, ids.ID, []byte) {} +func (emptyEventDispatcher) Issue(*Context, ids.ID, []byte) error { return nil } +func (emptyEventDispatcher) Accept(*Context, ids.ID, []byte) error { return nil } +func (emptyEventDispatcher) Reject(*Context, ids.ID, []byte) error { return nil } diff --git a/snow/engine/avalanche/bootstrap/bootstrapper.go b/snow/engine/avalanche/bootstrap/bootstrapper.go index 4cea3bd66b20..8f0766a0d78d 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper.go @@ -120,7 +120,7 @@ func (b *Bootstrapper) CurrentAcceptedFrontier() ([]ids.ID, error) { func (b *Bootstrapper) FilterAccepted(containerIDs []ids.ID) []ids.ID { acceptedVtxIDs := make([]ids.ID, 0, len(containerIDs)) for _, vtxID := range containerIDs { - if vtx, err := b.Manager.Get(vtxID); err == nil && vtx.Status() == choices.Accepted { + if vtx, err := b.Manager.GetVtx(vtxID); err == nil && vtx.Status() == choices.Accepted { acceptedVtxIDs = append(acceptedVtxIDs, vtxID) } } @@ -142,7 +142,7 @@ func (b *Bootstrapper) fetch(vtxIDs ...ids.ID) error { } // Make sure we don't already have this vertex - if _, err := b.Manager.Get(vtxID); err == nil { + if _, err := b.Manager.GetVtx(vtxID); err == nil { continue } @@ -195,7 +195,11 @@ func (b *Bootstrapper) process(vtxs ...avalanche.Vertex) error { b.numFetchedVts.Inc() b.NumFetched++ // Progress tracker if b.NumFetched%common.StatusUpdateFrequency == 0 { - b.Ctx.Log.Info("fetched 
%d vertices", b.NumFetched) + if !b.Restarted { + b.Ctx.Log.Info("fetched %d vertices", b.NumFetched) + } else { + b.Ctx.Log.Debug("fetched %d vertices", b.NumFetched) + } } } else { b.Ctx.Log.Verbo("couldn't push to vtxBlocked: %s", err) @@ -268,7 +272,7 @@ func (b *Bootstrapper) MultiPut(vdr ids.ShortID, requestID uint32, vtxs [][]byte } requestedVtxID, requested := b.OutstandingRequests.Remove(vdr, requestID) - vtx, err := b.Manager.Parse(vtxs[0]) // first vertex should be the one we requested in GetAncestors request + vtx, err := b.Manager.ParseVtx(vtxs[0]) // first vertex should be the one we requested in GetAncestors request if err != nil { if !requested { b.Ctx.Log.Debug("failed to parse unrequested vertex from %s with requestID %d: %s", vdr, requestID, err) @@ -309,7 +313,7 @@ func (b *Bootstrapper) MultiPut(vdr ids.ShortID, requestID uint32, vtxs [][]byte } for _, vtxBytes := range vtxs[1:] { // Parse/persist all the vertices - vtx, err := b.Manager.Parse(vtxBytes) // Persists the vtx + vtx, err := b.Manager.ParseVtx(vtxBytes) // Persists the vtx if err != nil { b.Ctx.Log.Debug("failed to parse vertex: %s", err) b.Ctx.Log.Verbo("vertex: %s", formatting.DumpBytes{Bytes: vtxBytes}) @@ -356,7 +360,7 @@ func (b *Bootstrapper) ForceAccepted(acceptedContainerIDs []ids.ID) error { b.NumFetched = 0 toProcess := make([]avalanche.Vertex, 0, len(acceptedContainerIDs)) for _, vtxID := range acceptedContainerIDs { - if vtx, err := b.Manager.Get(vtxID); err == nil { + if vtx, err := b.Manager.GetVtx(vtxID); err == nil { toProcess = append(toProcess, vtx) // Process this vertex. } else { b.needToFetch.Add(vtxID) // We don't have this vertex. Mark that we have to fetch it. @@ -373,15 +377,22 @@ func (b *Bootstrapper) checkFinish() error { return nil } - b.Ctx.Log.Info("bootstrapping fetched %d vertices. executing transaction state transitions...", - b.NumFetched) + if !b.Restarted { + b.Ctx.Log.Info("bootstrapping fetched %d vertices. 
Executing transaction state transitions...", b.NumFetched) + } else { + b.Ctx.Log.Debug("bootstrapping fetched %d vertices. Executing transaction state transitions...", b.NumFetched) + } _, err := b.executeAll(b.TxBlocked, b.Ctx.DecisionDispatcher) if err != nil { return err } - b.Ctx.Log.Info("executing vertex state transitions...") + if !b.Restarted { + b.Ctx.Log.Info("executing vertex state transitions...") + } else { + b.Ctx.Log.Debug("executing vertex state transitions...") + } executedVts, err := b.executeAll(b.VtxBlocked, b.Ctx.ConsensusDispatcher) if err != nil { return err @@ -394,19 +405,21 @@ func (b *Bootstrapper) checkFinish() error { // bootstrapping process will terminate even as new vertices are being // issued. if executedVts > 0 && executedVts < previouslyExecuted/2 && b.RetryBootstrap { - b.Ctx.Log.Info("bootstrapping is checking for more vertices before finishing the bootstrap process...") + b.Ctx.Log.Debug("checking for more vertices before finishing bootstrapping") return b.RestartBootstrap(true) } - b.Ctx.Log.Info("bootstrapping fetched enough vertices to finish the bootstrap process...") - // Notify the subnet that this chain is synced b.Subnet.Bootstrapped(b.Ctx.ChainID) // If the subnet hasn't finished bootstrapping, this chain should remain // syncing. if !b.Subnet.IsBootstrapped() { - b.Ctx.Log.Info("bootstrapping is waiting for the remaining chains in this subnet to finish syncing...") + if !b.Restarted { + b.Ctx.Log.Info("waiting for the remaining chains in this subnet to finish syncing") + } else { + b.Ctx.Log.Debug("waiting for the remaining chains in this subnet to finish syncing") + } // Delay new incoming messages to avoid consuming unnecessary resources // while keeping up to date on the latest tip. 
b.Config.Delay.Delay(b.delayAmount) @@ -442,6 +455,11 @@ func (b *Bootstrapper) executeAll(jobs *queue.Jobs, events snow.EventDispatcher) for job, err := jobs.Pop(); err == nil; job, err = jobs.Pop() { b.Ctx.Log.Debug("Executing: %s", job.ID()) + // Note that events.Accept must be called before + // job.Execute to honor EventDispatcher.Accept's invariant. + if err := events.Accept(b.Ctx, job.ID(), job.Bytes()); err != nil { + return numExecuted, err + } if err := jobs.Execute(job); err != nil { b.Ctx.Log.Error("Error executing: %s", err) return numExecuted, err @@ -451,12 +469,19 @@ func (b *Bootstrapper) executeAll(jobs *queue.Jobs, events snow.EventDispatcher) } numExecuted++ if numExecuted%common.StatusUpdateFrequency == 0 { // Periodically print progress - b.Ctx.Log.Info("executed %d operations", numExecuted) + if !b.Restarted { + b.Ctx.Log.Info("executed %d operations", numExecuted) + } else { + b.Ctx.Log.Debug("executed %d operations", numExecuted) + } } - events.Accept(b.Ctx, job.ID(), job.Bytes()) } - b.Ctx.Log.Info("executed %d operations", numExecuted) + if !b.Restarted { + b.Ctx.Log.Info("executed %d operations", numExecuted) + } else { + b.Ctx.Log.Debug("executed %d operations", numExecuted) + } return numExecuted, nil } diff --git a/snow/engine/avalanche/bootstrap/bootstrapper_test.go b/snow/engine/avalanche/bootstrap/bootstrapper_test.go index 9855f9e17348..267337d7da1d 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper_test.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper_test.go @@ -131,7 +131,7 @@ func TestBootstrapperSingleFrontier(t *testing.T) { acceptedIDs := []ids.ID{vtxID0, vtxID1, vtxID2} - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID0: return vtx0, nil @@ -145,7 +145,7 @@ func TestBootstrapperSingleFrontier(t *testing.T) { } } - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = 
func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes0): return vtx0, nil @@ -232,7 +232,7 @@ func TestBootstrapperByzantineResponses(t *testing.T) { acceptedIDs := []ids.ID{vtxID1} - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID1: return vtx1, nil @@ -258,7 +258,7 @@ func TestBootstrapperByzantineResponses(t *testing.T) { reqVtxID = vtxID } - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes0): vtx0.StatusV = choices.Processing @@ -293,7 +293,7 @@ func TestBootstrapperByzantineResponses(t *testing.T) { } oldReqID = *requestID - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID1: return vtx1, nil @@ -363,7 +363,7 @@ func TestBootstrapperTxDependencies(t *testing.T) { vtxBytes0 := []byte{2} vtxBytes1 := []byte{3} - vm.ParseF = func(b []byte) (snowstorm.Tx, error) { + vm.ParseTxF = func(b []byte) (snowstorm.Tx, error) { switch { case bytes.Equal(b, txBytes0): return tx0, nil @@ -408,7 +408,7 @@ func TestBootstrapperTxDependencies(t *testing.T) { acceptedIDs := []ids.ID{vtxID1} - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes1): return vtx1, nil @@ -418,7 +418,7 @@ func TestBootstrapperTxDependencies(t *testing.T) { t.Fatal(errParsedUnknownVertex) return nil, errParsedUnknownVertex } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID1: return vtx1, nil @@ -450,7 +450,7 @@ func TestBootstrapperTxDependencies(t *testing.T) { t.Fatal(err) } - 
manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes1): return vtx1, nil @@ -551,7 +551,7 @@ func TestBootstrapperMissingTxDependency(t *testing.T) { acceptedIDs := []ids.ID{vtxID1} - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID1: return vtx1, nil @@ -562,7 +562,7 @@ func TestBootstrapperMissingTxDependency(t *testing.T) { panic(errUnknownVertex) } } - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes1): return vtx1, nil @@ -693,7 +693,7 @@ func TestBootstrapperFilterAccepted(t *testing.T) { vtxIDs := []ids.ID{vtxID0, vtxID1, vtxID2} - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID0: return vtx0, nil @@ -710,7 +710,7 @@ func TestBootstrapperFilterAccepted(t *testing.T) { acceptedSet := ids.Set{} acceptedSet.Add(accepted...) 
- manager.GetF = nil + manager.GetVtxF = nil if !acceptedSet.Contains(vtxID0) { t.Fatalf("Vtx should be accepted") @@ -776,7 +776,7 @@ func TestBootstrapperIncompleteMultiPut(t *testing.T) { acceptedIDs := []ids.ID{vtxID2} - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch { case vtxID == vtxID0: return nil, errUnknownVertex @@ -789,7 +789,7 @@ func TestBootstrapperIncompleteMultiPut(t *testing.T) { panic(errUnknownVertex) } } - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes0): vtx0.StatusV = choices.Processing @@ -897,7 +897,7 @@ func TestBootstrapperFinalized(t *testing.T) { parsedVtx0 := false parsedVtx1 := false - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID0: if parsedVtx0 { @@ -914,7 +914,7 @@ func TestBootstrapperFinalized(t *testing.T) { panic(errUnknownVertex) } } - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes0): vtx0.StatusV = choices.Processing @@ -1028,7 +1028,7 @@ func TestBootstrapperAcceptsMultiPutParents(t *testing.T) { parsedVtx0 := false parsedVtx1 := false parsedVtx2 := false - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtxID0: if parsedVtx0 { @@ -1050,7 +1050,7 @@ func TestBootstrapperAcceptsMultiPutParents(t *testing.T) { } return nil, errUnknownVertex } - manager.ParseF = func(vtxBytes []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(vtxBytes []byte) (avalanche.Vertex, error) { switch { case bytes.Equal(vtxBytes, vtxBytes0): vtx0.StatusV = choices.Processing 
diff --git a/snow/engine/avalanche/bootstrap/tx_job.go b/snow/engine/avalanche/bootstrap/tx_job.go index fd042c34657e..283c83fdd178 100644 --- a/snow/engine/avalanche/bootstrap/tx_job.go +++ b/snow/engine/avalanche/bootstrap/tx_job.go @@ -24,7 +24,7 @@ type txParser struct { } func (p *txParser) Parse(txBytes []byte) (queue.Job, error) { - tx, err := p.vm.Parse(txBytes) + tx, err := p.vm.ParseTx(txBytes) if err != nil { return nil, err } diff --git a/snow/engine/avalanche/bootstrap/vertex_job.go b/snow/engine/avalanche/bootstrap/vertex_job.go index 259156926a12..693d2994f347 100644 --- a/snow/engine/avalanche/bootstrap/vertex_job.go +++ b/snow/engine/avalanche/bootstrap/vertex_job.go @@ -24,7 +24,7 @@ type vtxParser struct { } func (p *vtxParser) Parse(vtxBytes []byte) (queue.Job, error) { - vtx, err := p.manager.Parse(vtxBytes) + vtx, err := p.manager.ParseVtx(vtxBytes) if err != nil { return nil, err } diff --git a/snow/engine/avalanche/engine.go b/snow/engine/avalanche/engine.go index 0afe8ec088d5..da5583816120 100644 --- a/snow/engine/avalanche/engine.go +++ b/snow/engine/avalanche/engine.go @@ -4,6 +4,8 @@ package avalanche import ( + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/consensus/avalanche" "github.com/ava-labs/avalanchego/snow/engine/common" ) @@ -11,12 +13,10 @@ import ( type Engine interface { common.Engine - /* - *************************************************************************** - ***************************** Setup/Teardown ****************************** - *************************************************************************** - */ - // Initialize this engine. - Initialize(Config) + Initialize(Config) error + + // GetVtx returns a vertex by its ID. + // Returns an error if unknown. 
+ GetVtx(vtxID ids.ID) (avalanche.Vertex, error) } diff --git a/snow/engine/avalanche/issuer.go b/snow/engine/avalanche/issuer.go index 16e49a60fdf7..aab47f7a2c02 100644 --- a/snow/engine/avalanche/issuer.go +++ b/snow/engine/avalanche/issuer.go @@ -110,7 +110,7 @@ func (i *issuer) Update() { } // Issue a repoll - i.t.errs.Add(i.t.repoll()) + i.t.repoll() } type vtxIssuer struct{ i *issuer } diff --git a/snow/engine/avalanche/mocks/engine.go b/snow/engine/avalanche/mocks/engine.go new file mode 100644 index 000000000000..d1e048eafd45 --- /dev/null +++ b/snow/engine/avalanche/mocks/engine.go @@ -0,0 +1,435 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import ( + avalanche "github.com/ava-labs/avalanchego/snow/engine/avalanche" + common "github.com/ava-labs/avalanchego/snow/engine/common" + + consensusavalanche "github.com/ava-labs/avalanchego/snow/consensus/avalanche" + + ids "github.com/ava-labs/avalanchego/ids" + + mock "github.com/stretchr/testify/mock" + + snow "github.com/ava-labs/avalanchego/snow" +) + +// Engine is an autogenerated mock type for the Engine type +type Engine struct { + mock.Mock +} + +// Accepted provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) Accepted(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AcceptedFrontier provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) AcceptedFrontier(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, 
containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Chits provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) Chits(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Connected provides a mock function with given fields: validatorID +func (_m *Engine) Connected(validatorID ids.ShortID) error { + ret := _m.Called(validatorID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID) error); ok { + r0 = rf(validatorID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Context provides a mock function with given fields: +func (_m *Engine) Context() *snow.Context { + ret := _m.Called() + + var r0 *snow.Context + if rf, ok := ret.Get(0).(func() *snow.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*snow.Context) + } + } + + return r0 +} + +// Disconnected provides a mock function with given fields: validatorID +func (_m *Engine) Disconnected(validatorID ids.ShortID) error { + ret := _m.Called(validatorID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID) error); ok { + r0 = rf(validatorID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: validatorID, requestID, containerID +func (_m *Engine) Get(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { + ret := _m.Called(validatorID, requestID, containerID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAccepted provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m 
*Engine) GetAccepted(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcceptedFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAcceptedFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcceptedFrontier provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAcceptedFrontier(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcceptedFrontierFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAcceptedFrontierFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAncestors provides a mock function with given fields: validatorID, requestID, containerID +func (_m *Engine) GetAncestors(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { + ret := _m.Called(validatorID, requestID, containerID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// 
GetAncestorsFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAncestorsFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetVM provides a mock function with given fields: +func (_m *Engine) GetVM() common.VM { + ret := _m.Called() + + var r0 common.VM + if rf, ok := ret.Get(0).(func() common.VM); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.VM) + } + } + + return r0 +} + +// GetVtx provides a mock function with given fields: vtxID +func (_m *Engine) GetVtx(vtxID ids.ID) (consensusavalanche.Vertex, error) { + ret := _m.Called(vtxID) + + var r0 consensusavalanche.Vertex + if rf, ok := ret.Get(0).(func(ids.ID) consensusavalanche.Vertex); ok { + r0 = rf(vtxID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(consensusavalanche.Vertex) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(ids.ID) error); ok { + r1 = rf(vtxID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Gossip provides a mock function with given fields: +func (_m *Engine) Gossip() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HealthCheck provides a mock function with given fields: +func (_m *Engine) HealthCheck() (interface{}, error) { + ret := _m.Called() + + var r0 interface{} + if rf, ok := 
ret.Get(0).(func() interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: _a0 +func (_m *Engine) Initialize(_a0 avalanche.Config) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(avalanche.Config) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IsBootstrapped provides a mock function with given fields: +func (_m *Engine) IsBootstrapped() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MultiPut provides a mock function with given fields: validatorID, requestID, containers +func (_m *Engine) MultiPut(validatorID ids.ShortID, requestID uint32, containers [][]byte) error { + ret := _m.Called(validatorID, requestID, containers) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, [][]byte) error); ok { + r0 = rf(validatorID, requestID, containers) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Notify provides a mock function with given fields: _a0 +func (_m *Engine) Notify(_a0 common.Message) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(common.Message) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PullQuery provides a mock function with given fields: validatorID, requestID, containerID +func (_m *Engine) PullQuery(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { + ret := _m.Called(validatorID, requestID, containerID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PushQuery 
provides a mock function with given fields: validatorID, requestID, containerID, container +func (_m *Engine) PushQuery(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) error { + ret := _m.Called(validatorID, requestID, containerID, container) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID, []byte) error); ok { + r0 = rf(validatorID, requestID, containerID, container) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Put provides a mock function with given fields: validatorID, requestID, containerID, container +func (_m *Engine) Put(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) error { + ret := _m.Called(validatorID, requestID, containerID, container) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID, []byte) error); ok { + r0 = rf(validatorID, requestID, containerID, container) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// QueryFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) QueryFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Shutdown provides a mock function with given fields: +func (_m *Engine) Shutdown() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Startup provides a mock function with given fields: +func (_m *Engine) Startup() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/snow/engine/avalanche/state/serializer.go b/snow/engine/avalanche/state/serializer.go index 1f20337f4131..7b513409e5ed 100644 --- 
a/snow/engine/avalanche/state/serializer.go +++ b/snow/engine/avalanche/state/serializer.go @@ -30,6 +30,8 @@ var ( errWrongChainID = errors.New("wrong ChainID in vertex") ) +var _ vertex.Manager = &Serializer{} + // Serializer manages the state of multiple vertices type Serializer struct { ctx *snow.Context @@ -58,12 +60,12 @@ func (s *Serializer) Initialize(ctx *snow.Context, vm vertex.DAGVM, db database. } // Parse implements the avalanche.State interface -func (s *Serializer) Parse(b []byte) (avalanche.Vertex, error) { +func (s *Serializer) ParseVtx(b []byte) (avalanche.Vertex, error) { return newUniqueVertex(s, b) } // Build implements the avalanche.State interface -func (s *Serializer) Build( +func (s *Serializer) BuildVtx( epoch uint32, parentIDs []ids.ID, txs []snowstorm.Tx, @@ -105,7 +107,7 @@ func (s *Serializer) Build( } // Get implements the avalanche.State interface -func (s *Serializer) Get(vtxID ids.ID) (avalanche.Vertex, error) { return s.getVertex(vtxID) } +func (s *Serializer) GetVtx(vtxID ids.ID) (avalanche.Vertex, error) { return s.getVertex(vtxID) } // Edge implements the avalanche.State interface func (s *Serializer) Edge() []ids.ID { return s.edge.List() } diff --git a/snow/engine/avalanche/state/unique_vertex.go b/snow/engine/avalanche/state/unique_vertex.go index d9188bea2e52..af084e4d9896 100644 --- a/snow/engine/avalanche/state/unique_vertex.go +++ b/snow/engine/avalanche/state/unique_vertex.go @@ -236,7 +236,7 @@ func (vtx *uniqueVertex) Txs() ([]snowstorm.Tx, error) { if len(txs) != len(vtx.v.txs) { vtx.v.txs = make([]snowstorm.Tx, len(txs)) for i, txBytes := range txs { - tx, err := vtx.serializer.vm.Parse(txBytes) + tx, err := vtx.serializer.vm.ParseTx(txBytes) if err != nil { return nil, err } diff --git a/snow/engine/avalanche/state/unique_vertex_test.go b/snow/engine/avalanche/state/unique_vertex_test.go index 5ccbd07c9317..2dbb775eaf36 100644 --- a/snow/engine/avalanche/state/unique_vertex_test.go +++ 
b/snow/engine/avalanche/state/unique_vertex_test.go @@ -19,7 +19,7 @@ func newSerializer(t *testing.T, parse func([]byte) (snowstorm.Tx, error)) *Seri vm := vertex.TestVM{} vm.T = t vm.Default(true) - vm.ParseF = parse + vm.ParseTxF = parse baseDB := memdb.New() ctx := snow.DefaultContextTest() diff --git a/snow/engine/avalanche/transitive.go b/snow/engine/avalanche/transitive.go index 4be7ac3f2f03..e21101631394 100644 --- a/snow/engine/avalanche/transitive.go +++ b/snow/engine/avalanche/transitive.go @@ -29,6 +29,8 @@ const ( maxContainersLen = int(4 * network.DefaultMaxMessageSize / 5) ) +var _ Engine = &Transitive{} + // Transitive implements the Engine interface by attempting to fetch all // transitive dependencies. type Transitive struct { @@ -93,7 +95,7 @@ func (t *Transitive) finishBootstrapping() error { edge := t.Manager.Edge() frontier := make([]avalanche.Vertex, 0, len(edge)) for _, vtxID := range edge { - if vtx, err := t.Manager.Get(vtxID); err == nil { + if vtx, err := t.Manager.GetVtx(vtxID); err == nil { frontier = append(frontier, vtx) } else { t.Ctx.Log.Error("vertex %s failed to be loaded from the frontier with %s", vtxID, err) @@ -121,7 +123,7 @@ func (t *Transitive) Gossip() error { return err // Also should never really happen because the edge has positive length } vtxID := edge[int(indices[0])] - vtx, err := t.Manager.Get(vtxID) + vtx, err := t.Manager.GetVtx(vtxID) if err != nil { t.Ctx.Log.Warn("dropping gossip request as %s couldn't be loaded due to: %s", vtxID, err) return nil @@ -141,7 +143,7 @@ func (t *Transitive) Shutdown() error { // Get implements the Engine interface func (t *Transitive) Get(vdr ids.ShortID, requestID uint32, vtxID ids.ID) error { // If this engine has access to the requested vertex, provide it - if vtx, err := t.Manager.Get(vtxID); err == nil { + if vtx, err := t.Manager.GetVtx(vtxID); err == nil { t.Sender.Put(vdr, requestID, vtxID, vtx.Bytes()) } return nil @@ -151,7 +153,7 @@ func (t *Transitive) Get(vdr 
ids.ShortID, requestID uint32, vtxID ids.ID) error func (t *Transitive) GetAncestors(vdr ids.ShortID, requestID uint32, vtxID ids.ID) error { startTime := time.Now() t.Ctx.Log.Verbo("GetAncestors(%s, %d, %s) called", vdr, requestID, vtxID) - vertex, err := t.Manager.Get(vtxID) + vertex, err := t.Manager.GetVtx(vtxID) if err != nil || vertex.Status() == choices.Unknown { t.Ctx.Log.Verbo("dropping getAncestors") return nil // Don't have the requested vertex. Drop message. @@ -209,7 +211,7 @@ func (t *Transitive) Put(vdr ids.ShortID, requestID uint32, vtxID ids.ID, vtxByt return nil } - vtx, err := t.Manager.Parse(vtxBytes) + vtx, err := t.Manager.ParseVtx(vtxBytes) if err != nil { t.Ctx.Log.Debug("failed to parse vertex %s due to: %s", vtxID, err) t.Ctx.Log.Verbo("vertex:\n%s", formatting.DumpBytes{Bytes: vtxBytes}) @@ -291,7 +293,7 @@ func (t *Transitive) PushQuery(vdr ids.ShortID, requestID uint32, vtxID ids.ID, return nil } - vtx, err := t.Manager.Parse(vtxBytes) + vtx, err := t.Manager.ParseVtx(vtxBytes) if err != nil { t.Ctx.Log.Debug("failed to parse vertex %s due to: %s", vtxID, err) t.Ctx.Log.Verbo("vertex:\n%s", formatting.DumpBytes{Bytes: vtxBytes}) @@ -344,7 +346,7 @@ func (t *Transitive) Notify(msg common.Message) error { switch msg { case common.PendingTxs: - t.pendingTxs = append(t.pendingTxs, t.VM.Pending()...) + t.pendingTxs = append(t.pendingTxs, t.VM.PendingTxs()...) return t.attemptToIssueTxs() default: t.Ctx.Log.Warn("unexpected message from the VM: %s", msg) @@ -365,22 +367,17 @@ func (t *Transitive) attemptToIssueTxs() error { // If there are pending transactions from the VM, issue them. // If we're not already at the limit for number of concurrent polls, issue a new // query. 
-func (t *Transitive) repoll() error { - if t.polls.Len() >= t.Params.ConcurrentRepolls || t.errs.Errored() { - return nil - } - - for i := t.polls.Len(); i < t.Params.ConcurrentRepolls; i++ { +func (t *Transitive) repoll() { + for i := t.polls.Len(); i < t.Params.ConcurrentRepolls && !t.errs.Errored(); i++ { t.issueRepoll() } - return nil } // issueFromByID issues the branch ending with vertex [vtxID] to consensus. // Fetches [vtxID] if we don't have it locally. // Returns true if [vtx] has been added to consensus (now or previously) func (t *Transitive) issueFromByID(vdr ids.ShortID, vtxID ids.ID) (bool, error) { - vtx, err := t.Manager.Get(vtxID) + vtx, err := t.Manager.GetVtx(vtxID) if err != nil { // We don't have [vtxID]. Request it. t.sendRequest(vdr, vtxID) @@ -611,7 +608,7 @@ func (t *Transitive) issueBatch(txs []snowstorm.Tx) error { parentIDs[i] = virtuousIDs[int(index)] } - vtx, err := t.Manager.Build(0, parentIDs, txs, nil) + vtx, err := t.Manager.BuildVtx(0, parentIDs, txs, nil) if err != nil { t.Ctx.Log.Warn("error building new vertex with %d parents and %d transactions", len(parentIDs), len(txs)) @@ -654,3 +651,13 @@ func (t *Transitive) HealthCheck() (interface{}, error) { } return intf, fmt.Errorf("vm: %s ; consensus: %s", vmErr, consensusErr) } + +// GetVtx returns a vertex by its ID. +// Returns database.ErrNotFound if unknown. 
+func (t *Transitive) GetVtx(vtxID ids.ID) (avalanche.Vertex, error) { + return t.Manager.GetVtx(vtxID) +} + +func (t *Transitive) GetVM() common.VM { + return t.VM +} diff --git a/snow/engine/avalanche/transitive_test.go b/snow/engine/avalanche/transitive_test.go index 78c57748be30..0f2b2cf6ef96 100644 --- a/snow/engine/avalanche/transitive_test.go +++ b/snow/engine/avalanche/transitive_test.go @@ -114,7 +114,7 @@ func TestEngineAdd(t *testing.T) { } } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if !bytes.Equal(b, vtx.Bytes()) { t.Fatalf("Wrong bytes") } @@ -125,7 +125,7 @@ func TestEngineAdd(t *testing.T) { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil if !*asked { t.Fatalf("Didn't ask for a missing vertex") @@ -135,13 +135,13 @@ func TestEngineAdd(t *testing.T) { t.Fatalf("Should have been blocking on request") } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { return nil, errFailedParsing } + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { return nil, errFailedParsing } if err := te.Put(vdr, *reqID, vtx.ParentsV[0].ID(), nil); err != nil { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil if len(te.vtxBlocked) != 0 { t.Fatalf("Should have finished blocking issue") @@ -201,7 +201,7 @@ func TestEngineQuery(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -219,7 +219,7 @@ func TestEngineQuery(t *testing.T) { } vertexed := new(bool) - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if *vertexed { t.Fatalf("Sent multiple requests") } @@ -284,7 +284,7 @@ func TestEngineQuery(t *testing.T) { } } - manager.ParseF = func(b []byte) 
(avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if !bytes.Equal(b, vtx0.Bytes()) { t.Fatalf("Wrong bytes") } @@ -293,7 +293,7 @@ func TestEngineQuery(t *testing.T) { if err := te.Put(vdr, 0, vtx0.ID(), vtx0.Bytes()); err != nil { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil if !*queried { t.Fatalf("Didn't ask for preferences") @@ -313,7 +313,7 @@ func TestEngineQuery(t *testing.T) { BytesV: []byte{5, 4, 3, 2, 1, 9}, } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtx0.ID() { return &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ @@ -363,12 +363,12 @@ func TestEngineQuery(t *testing.T) { } } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if !bytes.Equal(b, vtx1.Bytes()) { t.Fatalf("Wrong bytes") } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtx0.ID() { return &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ @@ -388,7 +388,7 @@ func TestEngineQuery(t *testing.T) { if err := te.Put(vdr, 0, vtx1.ID(), vtx1.Bytes()); err != nil { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil if vtx0.Status() != choices.Accepted { t.Fatalf("Should have executed vertex") @@ -466,7 +466,7 @@ func TestEngineMultipleQuery(t *testing.T) { utxos := []ids.ID{ids.GenerateTestID()} manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -530,7 +530,7 @@ func TestEngineMultipleQuery(t *testing.T) { TxsV: []snowstorm.Tx{tx0}, } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) 
(avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -716,7 +716,7 @@ func TestEngineAbandonResponse(t *testing.T) { TxsV: []snowstorm.Tx{tx0}, } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { return nil, errUnknownVertex } + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { return nil, errUnknownVertex } te := &Transitive{} if err := te.Initialize(config); err != nil { @@ -892,7 +892,7 @@ func TestEngineRejectDoubleSpendTx(t *testing.T) { tx1.InputIDsV = append(tx1.InputIDsV, utxos[0]) manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -902,7 +902,7 @@ func TestEngineRejectDoubleSpendTx(t *testing.T) { t.Fatalf("Unknown vertex") panic("Should have errored") } - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { return &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ IDV: ids.GenerateTestID(), @@ -928,7 +928,7 @@ func TestEngineRejectDoubleSpendTx(t *testing.T) { sender.CantPushQuery = false - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } @@ -1000,7 +1000,7 @@ func TestEngineRejectDoubleSpendIssuedTx(t *testing.T) { tx1.InputIDsV = append(tx1.InputIDsV, utxos[0]) manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1021,7 +1021,7 @@ func TestEngineRejectDoubleSpendIssuedTx(t *testing.T) { 
vm.CantBootstrapping = true vm.CantBootstrapped = true - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { return &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ IDV: ids.GenerateTestID(), @@ -1036,12 +1036,12 @@ func TestEngineRejectDoubleSpendIssuedTx(t *testing.T) { sender.CantPushQuery = false - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx1} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx1} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } @@ -1082,7 +1082,7 @@ func TestEngineIssueRepoll(t *testing.T) { }} manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1109,7 +1109,8 @@ func TestEngineIssueRepoll(t *testing.T) { } } - if err := te.repoll(); err != nil { + te.repoll() + if err := te.errs.Err; err != nil { t.Fatal(err) } } @@ -1210,7 +1211,7 @@ func TestEngineReissue(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1235,7 +1236,7 @@ func TestEngineReissue(t *testing.T) { vm.CantBootstrapped = true lastVtx := new(avalanche.TestVertex) - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) 
(avalanche.Vertex, error) { lastVtx = &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ IDV: ids.GenerateTestID(), @@ -1249,7 +1250,7 @@ func TestEngineReissue(t *testing.T) { return lastVtx, nil } - vm.GetF = func(id ids.ID) (snowstorm.Tx, error) { + vm.GetTxF = func(id ids.ID) (snowstorm.Tx, error) { if id != tx0.ID() { t.Fatalf("Wrong tx") } @@ -1261,12 +1262,12 @@ func TestEngineReissue(t *testing.T) { *queryRequestID = requestID } - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if !bytes.Equal(b, vtx.Bytes()) { t.Fatalf("Wrong bytes") } @@ -1275,9 +1276,9 @@ func TestEngineReissue(t *testing.T) { if err := te.Put(vdr, 0, vtx.ID(), vtx.Bytes()); err != nil { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx3} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx3} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } @@ -1358,7 +1359,7 @@ func TestEngineLargeIssue(t *testing.T) { tx1.InputIDsV = append(tx1.InputIDsV, utxos[1]) manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1381,7 +1382,7 @@ func TestEngineLargeIssue(t *testing.T) { vm.CantBootstrapped = true lastVtx := new(avalanche.TestVertex) - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { lastVtx = &avalanche.TestVertex{ TestDecidable: 
choices.TestDecidable{ IDV: ids.GenerateTestID(), @@ -1397,7 +1398,7 @@ func TestEngineLargeIssue(t *testing.T) { sender.CantPushQuery = false - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } @@ -1434,7 +1435,7 @@ func TestEngineGetVertex(t *testing.T) { }} manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1504,7 +1505,7 @@ func TestEngineInsufficientValidators(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1579,7 +1580,7 @@ func TestEnginePushGossip(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1602,7 +1603,7 @@ func TestEnginePushGossip(t *testing.T) { *requested = true } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtx.BytesV) { return vtx, nil } @@ -1666,7 +1667,7 @@ func TestEngineSingleQuery(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1757,7 +1758,7 @@ func TestEngineParentBlockingInsert(t *testing.T) { } manager.EdgeF = func() []ids.ID { 
return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1861,7 +1862,7 @@ func TestEngineBlockingChitRequest(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -1881,14 +1882,14 @@ func TestEngineBlockingChitRequest(t *testing.T) { t.Fatal(err) } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == blockingVtx.ID() { return blockingVtx, nil } t.Fatalf("Unknown vertex") panic("Should have errored") } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, blockingVtx.Bytes()) { return blockingVtx, nil } @@ -1982,7 +1983,7 @@ func TestEngineBlockingChitResponse(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -2019,7 +2020,7 @@ func TestEngineBlockingChitResponse(t *testing.T) { t.Fatal(err) } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { if id == blockingVtx.ID() { return blockingVtx, nil } @@ -2114,7 +2115,7 @@ func TestEngineMissingTx(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -2151,7 +2152,7 @@ func 
TestEngineMissingTx(t *testing.T) { t.Fatal(err) } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { if id == blockingVtx.ID() { return blockingVtx, nil } @@ -2301,7 +2302,7 @@ func TestEngineReissueAbortedVertex(t *testing.T) { return []ids.ID{gVtx.ID()} } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == gVtx.ID() { return gVtx, nil } @@ -2315,20 +2316,20 @@ func TestEngineReissueAbortedVertex(t *testing.T) { } manager.EdgeF = nil - manager.GetF = nil + manager.GetVtxF = nil requestID := new(uint32) sender.GetF = func(vID ids.ShortID, reqID uint32, vtxID ids.ID) { *requestID = reqID } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtxBytes1) { return vtx1, nil } t.Fatalf("Unknown bytes provided") panic("Unknown bytes provided") } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID1 { return vtx1, nil } @@ -2341,7 +2342,7 @@ func TestEngineReissueAbortedVertex(t *testing.T) { } sender.GetF = nil - manager.ParseF = nil + manager.ParseVtxF = nil sender.CantChits = false if err := te.GetFailed(vdr, *requestID); err != nil { @@ -2354,7 +2355,7 @@ func TestEngineReissueAbortedVertex(t *testing.T) { *requested = true } } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID1 { return vtx1, nil } @@ -2509,7 +2510,7 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { t.Fatalf("Should have requested from the validators during AcceptedFrontier") } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID0 { return nil, 
errMissing } @@ -2531,17 +2532,17 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { t.Fatal(err) } - manager.GetF = nil + manager.GetVtxF = nil sender.GetF = nil - vm.ParseF = func(b []byte) (snowstorm.Tx, error) { + vm.ParseTxF = func(b []byte) (snowstorm.Tx, error) { if bytes.Equal(b, txBytes0) { return tx0, nil } t.Fatalf("Unknown bytes provided") panic("Unknown bytes provided") } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtxBytes0) { return vtx0, nil } @@ -2551,7 +2552,7 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { manager.EdgeF = func() []ids.ID { return []ids.ID{vtxID0} } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID0 { return vtx0, nil } @@ -2563,10 +2564,10 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { t.Fatal(err) } - vm.ParseF = nil - manager.ParseF = nil + vm.ParseTxF = nil + manager.ParseVtxF = nil manager.EdgeF = nil - manager.GetF = nil + manager.GetVtxF = nil if tx0.Status() != choices.Accepted { t.Fatalf("Should have accepted %s", txID0) @@ -2575,7 +2576,7 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { t.Fatalf("Should have accepted %s", vtxID0) } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtxBytes1) { return vtx1, nil } @@ -2608,7 +2609,7 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { t.Fatalf("Sent wrong query bytes") } } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID1 { return vtx1, nil } @@ -2620,10 +2621,10 @@ func TestEngineBootstrappingIntoConsensus(t *testing.T) { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil sender.ChitsF = nil sender.PushQueryF = nil 
- manager.GetF = nil + manager.GetVtxF = nil } func TestEngineReBootstrapFails(t *testing.T) { @@ -2923,7 +2924,7 @@ func TestEngineReBootstrappingIntoConsensus(t *testing.T) { t.Fatalf("Should have requested from the validators during AcceptedFrontier") } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID0 { return nil, errMissing } @@ -2945,16 +2946,16 @@ func TestEngineReBootstrappingIntoConsensus(t *testing.T) { t.Fatal(err) } - manager.GetF = nil + manager.GetVtxF = nil - vm.ParseF = func(b []byte) (snowstorm.Tx, error) { + vm.ParseTxF = func(b []byte) (snowstorm.Tx, error) { if bytes.Equal(b, txBytes0) { return tx0, nil } t.Fatalf("Unknown bytes provided") panic("Unknown bytes provided") } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtxBytes0) { return vtx0, nil } @@ -2964,7 +2965,7 @@ func TestEngineReBootstrappingIntoConsensus(t *testing.T) { manager.EdgeF = func() []ids.ID { return []ids.ID{vtxID0} } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID0 { return vtx0, nil } @@ -2978,10 +2979,10 @@ func TestEngineReBootstrappingIntoConsensus(t *testing.T) { sender.GetAcceptedFrontierF = nil sender.GetF = nil - vm.ParseF = nil - manager.ParseF = nil + vm.ParseTxF = nil + manager.ParseVtxF = nil manager.EdgeF = nil - manager.GetF = nil + manager.GetVtxF = nil if tx0.Status() != choices.Accepted { t.Fatalf("Should have accepted %s", txID0) @@ -2990,7 +2991,7 @@ func TestEngineReBootstrappingIntoConsensus(t *testing.T) { t.Fatalf("Should have accepted %s", vtxID0) } - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtxBytes1) { return vtx1, nil } @@ -3023,7 +3024,7 @@ func 
TestEngineReBootstrappingIntoConsensus(t *testing.T) { t.Fatalf("Sent wrong query bytes") } } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == vtxID1 { return vtx1, nil } @@ -3035,10 +3036,10 @@ func TestEngineReBootstrappingIntoConsensus(t *testing.T) { t.Fatal(err) } - manager.ParseF = nil + manager.ParseVtxF = nil sender.ChitsF = nil sender.PushQueryF = nil - manager.GetF = nil + manager.GetVtxF = nil } func TestEngineUndeclaredDependencyDeadlock(t *testing.T) { @@ -3123,7 +3124,7 @@ func TestEngineUndeclaredDependencyDeadlock(t *testing.T) { t.Fatal(err) } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { switch vtxID { case vtx0.ID(): return vtx0, nil @@ -3195,7 +3196,7 @@ func TestEnginePartiallyValidVertex(t *testing.T) { } expectedVtxID := ids.GenerateTestID() - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { return &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ IDV: expectedVtxID, @@ -3246,7 +3247,7 @@ func TestEngineGossip(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID()} } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if vtxID == gVtx.ID() { return gVtx, nil } @@ -3347,7 +3348,7 @@ func TestEngineInvalidVertexIgnoredFromUnexpectedPeer(t *testing.T) { } parsed := new(bool) - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtx1.Bytes()) { *parsed = true return vtx1, nil @@ -3355,7 +3356,7 @@ func TestEngineInvalidVertexIgnoredFromUnexpectedPeer(t *testing.T) { return nil, errUnknownVertex } - 
manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if !*parsed { return nil, errUnknownVertex } @@ -3386,7 +3387,7 @@ func TestEngineInvalidVertexIgnoredFromUnexpectedPeer(t *testing.T) { } *parsed = false - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtx0.Bytes()) { *parsed = true return vtx0, nil @@ -3394,7 +3395,7 @@ func TestEngineInvalidVertexIgnoredFromUnexpectedPeer(t *testing.T) { return nil, errUnknownVertex } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if !*parsed { return nil, errUnknownVertex } @@ -3490,7 +3491,7 @@ func TestEnginePushQueryRequestIDConflict(t *testing.T) { } parsed := new(bool) - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtx1.Bytes()) { *parsed = true return vtx1, nil @@ -3498,7 +3499,7 @@ func TestEnginePushQueryRequestIDConflict(t *testing.T) { return nil, errUnknownVertex } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if !*parsed { return nil, errUnknownVertex } @@ -3532,7 +3533,7 @@ func TestEnginePushQueryRequestIDConflict(t *testing.T) { } *parsed = false - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtx0.Bytes()) { *parsed = true return vtx0, nil @@ -3540,7 +3541,7 @@ func TestEnginePushQueryRequestIDConflict(t *testing.T) { return nil, errUnknownVertex } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if !*parsed { return nil, errUnknownVertex } @@ -3638,7 +3639,7 @@ func TestEngineAggressivePolling(t 
*testing.T) { vm.CantBootstrapped = true parsed := new(bool) - manager.ParseF = func(b []byte) (avalanche.Vertex, error) { + manager.ParseVtxF = func(b []byte) (avalanche.Vertex, error) { if bytes.Equal(b, vtx.Bytes()) { *parsed = true return vtx, nil @@ -3646,7 +3647,7 @@ func TestEngineAggressivePolling(t *testing.T) { return nil, errUnknownVertex } - manager.GetF = func(vtxID ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(vtxID ids.ID) (avalanche.Vertex, error) { if !*parsed { return nil, errUnknownVertex } @@ -3663,7 +3664,7 @@ func TestEngineAggressivePolling(t *testing.T) { numPullQueries := new(int) sender.PullQueryF = func(ids.ShortSet, uint32, ids.ID) { *numPullQueries++ } - vm.CantPending = false + vm.CantPendingTxs = false if err := te.Put(vdr, 0, vtx.ID(), vtx.Bytes()); err != nil { t.Fatal(err) @@ -3735,7 +3736,7 @@ func TestEngineDuplicatedIssuance(t *testing.T) { tx.InputIDsV = append(tx.InputIDsV, utxos[0]) manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -3758,7 +3759,7 @@ func TestEngineDuplicatedIssuance(t *testing.T) { vm.CantBootstrapped = true lastVtx := new(avalanche.TestVertex) - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { lastVtx = &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ IDV: ids.GenerateTestID(), @@ -3774,7 +3775,7 @@ func TestEngineDuplicatedIssuance(t *testing.T) { sender.CantPushQuery = false - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } @@ -3783,7 +3784,7 @@ func 
TestEngineDuplicatedIssuance(t *testing.T) { t.Fatalf("Should have issued txs differently") } - manager.BuildF = func(uint32, []ids.ID, []snowstorm.Tx, []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(uint32, []ids.ID, []snowstorm.Tx, []ids.ID) (avalanche.Vertex, error) { t.Fatalf("shouldn't have attempted to issue a duplicated tx") return nil, nil } @@ -3854,7 +3855,7 @@ func TestEngineDoubleChit(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -3880,7 +3881,7 @@ func TestEngineDoubleChit(t *testing.T) { t.Fatalf("Wrong vertex requested") } } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { if id == vtx.ID() { return vtx, nil } @@ -4011,7 +4012,7 @@ func TestEngineBubbleVotes(t *testing.T) { } manager.EdgeF = func() []ids.ID { return nil } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case vtx.ID(): return vtx, nil @@ -4135,7 +4136,7 @@ func TestEngineIssue(t *testing.T) { } manager.EdgeF = func() []ids.ID { return []ids.ID{gVtx.ID(), mVtx.ID()} } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -4158,7 +4159,7 @@ func TestEngineIssue(t *testing.T) { vm.CantBootstrapped = true numBuilt := 0 - manager.BuildF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { + manager.BuildVtxF = func(_ uint32, _ []ids.ID, txs []snowstorm.Tx, _ []ids.ID) (avalanche.Vertex, error) { numBuilt++ vtx := &avalanche.TestVertex{ TestDecidable: choices.TestDecidable{ @@ -4171,7 +4172,7 @@ func TestEngineIssue(t *testing.T) { BytesV: 
[]byte{1}, } - manager.GetF = func(id ids.ID) (avalanche.Vertex, error) { + manager.GetVtxF = func(id ids.ID) (avalanche.Vertex, error) { switch id { case gVtx.ID(): return gVtx, nil @@ -4196,7 +4197,7 @@ func TestEngineIssue(t *testing.T) { queryRequestID = requestID } - vm.PendingF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } + vm.PendingTxsF = func() []snowstorm.Tx { return []snowstorm.Tx{tx0, tx1} } if err := te.Notify(common.PendingTxs); err != nil { t.Fatal(err) } diff --git a/snow/engine/avalanche/vertex/builder.go b/snow/engine/avalanche/vertex/builder.go index dd99d9b23e3f..77e52bbada25 100644 --- a/snow/engine/avalanche/vertex/builder.go +++ b/snow/engine/avalanche/vertex/builder.go @@ -13,7 +13,7 @@ import ( // Builder builds a vertex given a set of parentIDs and transactions. type Builder interface { // Build a new vertex from the contents of a vertex - Build( + BuildVtx( epoch uint32, parentIDs []ids.ID, txs []snowstorm.Tx, @@ -47,7 +47,7 @@ func Build( return nil, err } - vtxBytes, err := Codec.Marshal(innerVtx.Version, innerVtx) + vtxBytes, err := c.Marshal(innerVtx.Version, innerVtx) vtx := statelessVertex{ innerStatelessVertex: innerVtx, id: hashing.ComputeHash256Array(vtxBytes), diff --git a/snow/engine/avalanche/vertex/codec.go b/snow/engine/avalanche/vertex/codec.go index 977b16459e36..0d5a6623d1dc 100644 --- a/snow/engine/avalanche/vertex/codec.go +++ b/snow/engine/avalanche/vertex/codec.go @@ -23,18 +23,18 @@ const ( ) var ( - Codec codec.Manager + c codec.Manager ) func init() { codecV0 := linearcodec.New("serializeV0", maxSize) codecV1 := linearcodec.New("serializeV1", maxSize) - Codec = codec.NewManager(maxSize) + c = codec.NewManager(maxSize) errs := wrappers.Errs{} errs.Add( - Codec.RegisterCodec(noEpochTransitionsCodecVersion, codecV0), - Codec.RegisterCodec(apricotCodecVersion, codecV1), + c.RegisterCodec(noEpochTransitionsCodecVersion, codecV0), + c.RegisterCodec(apricotCodecVersion, codecV1), ) if errs.Errored() { 
panic(errs.Err) diff --git a/snow/engine/avalanche/vertex/mocks/dag_vm.go b/snow/engine/avalanche/vertex/mocks/dag_vm.go new file mode 100644 index 000000000000..afa1da750cd1 --- /dev/null +++ b/snow/engine/avalanche/vertex/mocks/dag_vm.go @@ -0,0 +1,187 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import ( + database "github.com/ava-labs/avalanchego/database" + common "github.com/ava-labs/avalanchego/snow/engine/common" + + ids "github.com/ava-labs/avalanchego/ids" + + mock "github.com/stretchr/testify/mock" + + snow "github.com/ava-labs/avalanchego/snow" + snowstorm "github.com/ava-labs/avalanchego/snow/consensus/snowstorm" + vertex "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex" +) + +var _ vertex.DAGVM = &DAGVM{} + +// DAGVM is an autogenerated mock type for the DAGVM type +type DAGVM struct { + mock.Mock +} + +// Bootstrapped provides a mock function with given fields: +func (_m *DAGVM) Bootstrapped() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Bootstrapping provides a mock function with given fields: +func (_m *DAGVM) Bootstrapping() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CreateHandlers provides a mock function with given fields: +func (_m *DAGVM) CreateHandlers() (map[string]*common.HTTPHandler, error) { + ret := _m.Called() + + var r0 map[string]*common.HTTPHandler + if rf, ok := ret.Get(0).(func() map[string]*common.HTTPHandler); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*common.HTTPHandler) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTx provides a mock function with given fields: _a0 +func (_m *DAGVM) GetTx(_a0 ids.ID) 
(snowstorm.Tx, error) { + ret := _m.Called(_a0) + + var r0 snowstorm.Tx + if rf, ok := ret.Get(0).(func(ids.ID) snowstorm.Tx); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snowstorm.Tx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(ids.ID) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HealthCheck provides a mock function with given fields: +func (_m *DAGVM) HealthCheck() (interface{}, error) { + ret := _m.Called() + + var r0 interface{} + if rf, ok := ret.Get(0).(func() interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: ctx, db, genesisBytes, toEngine, fxs +func (_m *DAGVM) Initialize(ctx *snow.Context, db database.Database, genesisBytes []byte, toEngine chan<- common.Message, fxs []*common.Fx) error { + ret := _m.Called(ctx, db, genesisBytes, toEngine, fxs) + + var r0 error + if rf, ok := ret.Get(0).(func(*snow.Context, database.Database, []byte, chan<- common.Message, []*common.Fx) error); ok { + r0 = rf(ctx, db, genesisBytes, toEngine, fxs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ParseTx provides a mock function with given fields: tx +func (_m *DAGVM) ParseTx(tx []byte) (snowstorm.Tx, error) { + ret := _m.Called(tx) + + var r0 snowstorm.Tx + if rf, ok := ret.Get(0).(func([]byte) snowstorm.Tx); ok { + r0 = rf(tx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snowstorm.Tx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(tx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PendingTxs provides a mock function with given fields: +func (_m *DAGVM) PendingTxs() []snowstorm.Tx { + ret := _m.Called() + + var r0 []snowstorm.Tx + if rf, ok := ret.Get(0).(func() 
[]snowstorm.Tx); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]snowstorm.Tx) + } + } + + return r0 +} + +// Shutdown provides a mock function with given fields: +func (_m *DAGVM) Shutdown() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/snow/engine/avalanche/vertex/parser.go b/snow/engine/avalanche/vertex/parser.go index 6218570b55c6..e37f7575c480 100644 --- a/snow/engine/avalanche/vertex/parser.go +++ b/snow/engine/avalanche/vertex/parser.go @@ -11,13 +11,13 @@ import ( // Parser parses bytes into a vertex. type Parser interface { // Parse a vertex from a slice of bytes - Parse(vertex []byte) (avalanche.Vertex, error) + ParseVtx(vertex []byte) (avalanche.Vertex, error) } // Parse the provided vertex bytes into a stateless vertex func Parse(vertex []byte) (StatelessVertex, error) { vtx := innerStatelessVertex{} - version, err := Codec.Unmarshal(vertex, &vtx) + version, err := c.Unmarshal(vertex, &vtx) vtx.Version = version return statelessVertex{ innerStatelessVertex: vtx, diff --git a/snow/engine/avalanche/vertex/storage.go b/snow/engine/avalanche/vertex/storage.go index 74a58549ba9d..7463e5ab675f 100644 --- a/snow/engine/avalanche/vertex/storage.go +++ b/snow/engine/avalanche/vertex/storage.go @@ -12,7 +12,7 @@ import ( // engine. type Storage interface { // Get a vertex by its hash from storage. - Get(vtxID ids.ID) (avalanche.Vertex, error) + GetVtx(vtxID ids.ID) (avalanche.Vertex, error) // Edge returns a list of accepted vertex IDs with no accepted children. 
Edge() (vtxIDs []ids.ID) diff --git a/snow/engine/avalanche/vertex/test_builder.go b/snow/engine/avalanche/vertex/test_builder.go index 581ffbe5414a..a28cff4f439b 100644 --- a/snow/engine/avalanche/vertex/test_builder.go +++ b/snow/engine/avalanche/vertex/test_builder.go @@ -19,9 +19,9 @@ var ( ) type TestBuilder struct { - T *testing.T - CantBuild bool - BuildF func( + T *testing.T + CantBuildVtx bool + BuildVtxF func( epoch uint32, parentIDs []ids.ID, txs []snowstorm.Tx, @@ -29,18 +29,18 @@ type TestBuilder struct { ) (avalanche.Vertex, error) } -func (b *TestBuilder) Default(cant bool) { b.CantBuild = cant } +func (b *TestBuilder) Default(cant bool) { b.CantBuildVtx = cant } -func (b *TestBuilder) Build( +func (b *TestBuilder) BuildVtx( epoch uint32, parentIDs []ids.ID, txs []snowstorm.Tx, restrictions []ids.ID, ) (avalanche.Vertex, error) { - if b.BuildF != nil { - return b.BuildF(epoch, parentIDs, txs, restrictions) + if b.BuildVtxF != nil { + return b.BuildVtxF(epoch, parentIDs, txs, restrictions) } - if b.CantBuild && b.T != nil { + if b.CantBuildVtx && b.T != nil { b.T.Fatal(errBuild) } return nil, errBuild diff --git a/snow/engine/avalanche/vertex/test_parser.go b/snow/engine/avalanche/vertex/test_parser.go index 4e79b78ebd09..1e37edb44fc3 100644 --- a/snow/engine/avalanche/vertex/test_parser.go +++ b/snow/engine/avalanche/vertex/test_parser.go @@ -17,18 +17,18 @@ var ( ) type TestParser struct { - T *testing.T - CantParse bool - ParseF func([]byte) (avalanche.Vertex, error) + T *testing.T + CantParseVtx bool + ParseVtxF func([]byte) (avalanche.Vertex, error) } -func (p *TestParser) Default(cant bool) { p.CantParse = cant } +func (p *TestParser) Default(cant bool) { p.CantParseVtx = cant } -func (p *TestParser) Parse(b []byte) (avalanche.Vertex, error) { - if p.ParseF != nil { - return p.ParseF(b) +func (p *TestParser) ParseVtx(b []byte) (avalanche.Vertex, error) { + if p.ParseVtxF != nil { + return p.ParseVtxF(b) } - if p.CantParse && p.T != nil { + if 
p.CantParseVtx && p.T != nil { p.T.Fatal(errParse) } return nil, errParse diff --git a/snow/engine/avalanche/vertex/test_storage.go b/snow/engine/avalanche/vertex/test_storage.go index 1b81262b3df0..68335f0c818a 100644 --- a/snow/engine/avalanche/vertex/test_storage.go +++ b/snow/engine/avalanche/vertex/test_storage.go @@ -19,22 +19,22 @@ var ( ) type TestStorage struct { - T *testing.T - CantGet, CantEdge bool - GetF func(ids.ID) (avalanche.Vertex, error) - EdgeF func() []ids.ID + T *testing.T + CantGetVtx, CantEdge bool + GetVtxF func(ids.ID) (avalanche.Vertex, error) + EdgeF func() []ids.ID } func (s *TestStorage) Default(cant bool) { - s.CantGet = cant + s.CantGetVtx = cant s.CantEdge = cant } -func (s *TestStorage) Get(id ids.ID) (avalanche.Vertex, error) { - if s.GetF != nil { - return s.GetF(id) +func (s *TestStorage) GetVtx(id ids.ID) (avalanche.Vertex, error) { + if s.GetVtxF != nil { + return s.GetVtxF(id) } - if s.CantGet && s.T != nil { + if s.CantGetVtx && s.T != nil { s.T.Fatal(errGet) } return nil, errGet diff --git a/snow/engine/avalanche/vertex/test_vm.go b/snow/engine/avalanche/vertex/test_vm.go index 8b5daf8ef946..9cd39a56eff4 100644 --- a/snow/engine/avalanche/vertex/test_vm.go +++ b/snow/engine/avalanche/vertex/test_vm.go @@ -20,34 +20,34 @@ var ( type TestVM struct { common.TestVM - CantPending, CantParse, CantGet bool + CantPendingTxs, CantParse, CantGet bool - PendingF func() []snowstorm.Tx - ParseF func([]byte) (snowstorm.Tx, error) - GetF func(ids.ID) (snowstorm.Tx, error) + PendingTxsF func() []snowstorm.Tx + ParseTxF func([]byte) (snowstorm.Tx, error) + GetTxF func(ids.ID) (snowstorm.Tx, error) } func (vm *TestVM) Default(cant bool) { vm.TestVM.Default(cant) - vm.CantPending = cant + vm.CantPendingTxs = cant vm.CantParse = cant vm.CantGet = cant } -func (vm *TestVM) Pending() []snowstorm.Tx { - if vm.PendingF != nil { - return vm.PendingF() +func (vm *TestVM) PendingTxs() []snowstorm.Tx { + if vm.PendingTxsF != nil { + return 
vm.PendingTxsF() } - if vm.CantPending && vm.T != nil { + if vm.CantPendingTxs && vm.T != nil { vm.T.Fatal(errPending) } return nil } -func (vm *TestVM) Parse(b []byte) (snowstorm.Tx, error) { - if vm.ParseF != nil { - return vm.ParseF(b) +func (vm *TestVM) ParseTx(b []byte) (snowstorm.Tx, error) { + if vm.ParseTxF != nil { + return vm.ParseTxF(b) } if vm.CantParse && vm.T != nil { vm.T.Fatal(errParse) @@ -55,9 +55,9 @@ func (vm *TestVM) Parse(b []byte) (snowstorm.Tx, error) { return nil, errParse } -func (vm *TestVM) Get(txID ids.ID) (snowstorm.Tx, error) { - if vm.GetF != nil { - return vm.GetF(txID) +func (vm *TestVM) GetTx(txID ids.ID) (snowstorm.Tx, error) { + if vm.GetTxF != nil { + return vm.GetTxF(txID) } if vm.CantGet && vm.T != nil { vm.T.Fatal(errGet) diff --git a/snow/engine/avalanche/vertex/vm.go b/snow/engine/avalanche/vertex/vm.go index 554daab11b8b..4e125b445811 100644 --- a/snow/engine/avalanche/vertex/vm.go +++ b/snow/engine/avalanche/vertex/vm.go @@ -15,11 +15,11 @@ type DAGVM interface { common.VM // Return any transactions that have not been sent to consensus yet - Pending() []snowstorm.Tx + PendingTxs() []snowstorm.Tx // Convert a stream of bytes to a transaction or return an error - Parse(tx []byte) (snowstorm.Tx, error) + ParseTx(tx []byte) (snowstorm.Tx, error) // Retrieve a transaction that was submitted previously - Get(ids.ID) (snowstorm.Tx, error) + GetTx(ids.ID) (snowstorm.Tx, error) } diff --git a/snow/engine/avalanche/voter.go b/snow/engine/avalanche/voter.go index 478d06187e0d..9c06354c1a3d 100644 --- a/snow/engine/avalanche/voter.go +++ b/snow/engine/avalanche/voter.go @@ -53,7 +53,7 @@ func (v *voter) Update() { orphans := v.t.Consensus.Orphans() txs := make([]snowstorm.Tx, 0, orphans.Len()) for orphanID := range orphans { - if tx, err := v.t.VM.Get(orphanID); err == nil { + if tx, err := v.t.VM.GetTx(orphanID); err == nil { txs = append(txs, tx) } else { v.t.Ctx.Log.Warn("Failed to fetch %s during attempted re-issuance", 
orphanID) @@ -73,13 +73,13 @@ func (v *voter) Update() { } v.t.Ctx.Log.Debug("Avalanche engine can't quiesce") - v.t.errs.Add(v.t.repoll()) + v.t.repoll() } func (v *voter) bubbleVotes(votes ids.UniqueBag) (ids.UniqueBag, error) { vertexHeap := vertex.NewHeap() for vote := range votes { - vtx, err := v.t.Manager.Get(vote) + vtx, err := v.t.Manager.GetVtx(vote) if err != nil { continue } diff --git a/snow/engine/common/bootstrapper.go b/snow/engine/common/bootstrapper.go index dd714d10a381..48e9a7eef645 100644 --- a/snow/engine/common/bootstrapper.go +++ b/snow/engine/common/bootstrapper.go @@ -21,11 +21,11 @@ const ( // StatusUpdateFrequency is how many containers should be processed between // logs - StatusUpdateFrequency = 2500 + StatusUpdateFrequency = 5000 // MaxOutstandingRequests is the maximum number of GetAncestors sent but not // responded to/failed - MaxOutstandingRequests = 8 + MaxOutstandingRequests = 10 // MaxTimeFetchingAncestors is the maximum amount of time to spend fetching // vertices during a call to GetAncestors @@ -42,6 +42,8 @@ type Bootstrapper struct { // received a reply from pendingAcceptedFrontier ids.ShortSet acceptedFrontier ids.Set + // True if RestartBootstrap has been called at least once + Restarted bool // holds the beacons that were sampled for the accepted frontier sampledBeacons validators.Set @@ -293,7 +295,11 @@ func (b *Bootstrapper) Accepted(validatorID ids.ShortID, requestID uint32, conta } } - b.Ctx.Log.Info("Bootstrapping started syncing with %d vertices in the accepted frontier", size) + if !b.Restarted { + b.Ctx.Log.Info("Bootstrapping started syncing with %d vertices in the accepted frontier", size) + } else { + b.Ctx.Log.Debug("Bootstrapping started syncing with %d vertices in the accepted frontier", size) + } return b.Bootstrapable.ForceAccepted(accepted) } @@ -333,10 +339,12 @@ func (b *Bootstrapper) Disconnected(validatorID ids.ShortID) error { } func (b *Bootstrapper) RestartBootstrap(reset bool) error { + 
b.Restarted = true + // resets the attempts when we're pulling blocks/vertices // we don't want to fail the bootstrap at that stage if reset { - b.Ctx.Log.Info("Checking for new frontiers...") + b.Ctx.Log.Debug("Checking for new frontiers") b.bootstrapAttempts = 0 } diff --git a/snow/engine/common/engine.go b/snow/engine/common/engine.go index 9eba2ab72370..845e6bc3df9b 100644 --- a/snow/engine/common/engine.go +++ b/snow/engine/common/engine.go @@ -22,6 +22,9 @@ type Engine interface { // Returns nil if the engine is healthy. // Periodically called and reported through the health API health.Checkable + + // GetVM returns this engine's VM + GetVM() VM } // Handler defines the functions that are acted on the node diff --git a/snow/engine/common/test_engine.go b/snow/engine/common/test_engine.go index 328945265526..c6af6e65d0da 100644 --- a/snow/engine/common/test_engine.go +++ b/snow/engine/common/test_engine.go @@ -9,6 +9,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/snow/consensus/avalanche" ) // EngineTest is a test engine @@ -47,7 +48,9 @@ type EngineTest struct { CantConnected, CantDisconnected, - CantHealth bool + CantHealth, + + CantGetVtx, CantGetVM bool IsBootstrappedF func() bool ContextF func() *snow.Context @@ -61,6 +64,8 @@ type EngineTest struct { QueryFailedF, GetAcceptedFrontierFailedF, GetAcceptedFailedF func(validatorID ids.ShortID, requestID uint32) error ConnectedF, DisconnectedF func(validatorID ids.ShortID) error HealthF func() (interface{}, error) + GetVtxF func() (avalanche.Vertex, error) + GetVMF func() VM } var _ Engine = &EngineTest{} @@ -101,9 +106,11 @@ func (e *EngineTest) Default(cant bool) { e.CantDisconnected = cant e.CantHealth = cant + + e.CantGetVtx = cant + e.CantGetVM = cant } -// Context ... 
func (e *EngineTest) Context() *snow.Context { if e.ContextF != nil { return e.ContextF() @@ -114,7 +121,6 @@ func (e *EngineTest) Context() *snow.Context { return nil } -// Startup ... func (e *EngineTest) Startup() error { if e.StartupF != nil { return e.StartupF() @@ -128,7 +134,6 @@ func (e *EngineTest) Startup() error { return errors.New("unexpectedly called Startup") } -// Gossip ... func (e *EngineTest) Gossip() error { if e.GossipF != nil { return e.GossipF() @@ -142,7 +147,6 @@ func (e *EngineTest) Gossip() error { return errors.New("unexpectedly called Gossip") } -// Shutdown ... func (e *EngineTest) Shutdown() error { if e.ShutdownF != nil { return e.ShutdownF() @@ -156,7 +160,6 @@ func (e *EngineTest) Shutdown() error { return errors.New("unexpectedly called Shutdown") } -// Notify ... func (e *EngineTest) Notify(msg Message) error { if e.NotifyF != nil { return e.NotifyF(msg) @@ -170,7 +173,6 @@ func (e *EngineTest) Notify(msg Message) error { return errors.New("unexpectedly called Notify") } -// GetAcceptedFrontier ... func (e *EngineTest) GetAcceptedFrontier(validatorID ids.ShortID, requestID uint32) error { if e.GetAcceptedFrontierF != nil { return e.GetAcceptedFrontierF(validatorID, requestID) @@ -184,7 +186,6 @@ func (e *EngineTest) GetAcceptedFrontier(validatorID ids.ShortID, requestID uint return errors.New("unexpectedly called GetAcceptedFrontier") } -// GetAcceptedFrontierFailed ... func (e *EngineTest) GetAcceptedFrontierFailed(validatorID ids.ShortID, requestID uint32) error { if e.GetAcceptedFrontierFailedF != nil { return e.GetAcceptedFrontierFailedF(validatorID, requestID) @@ -198,7 +199,6 @@ func (e *EngineTest) GetAcceptedFrontierFailed(validatorID ids.ShortID, requestI return errors.New("unexpectedly called GetAcceptedFrontierFailed") } -// AcceptedFrontier ... 
func (e *EngineTest) AcceptedFrontier(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { if e.AcceptedFrontierF != nil { return e.AcceptedFrontierF(validatorID, requestID, containerIDs) @@ -212,7 +212,6 @@ func (e *EngineTest) AcceptedFrontier(validatorID ids.ShortID, requestID uint32, return errors.New("unexpectedly called AcceptedFrontierF") } -// GetAccepted ... func (e *EngineTest) GetAccepted(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { if e.GetAcceptedF != nil { return e.GetAcceptedF(validatorID, requestID, containerIDs) @@ -226,7 +225,6 @@ func (e *EngineTest) GetAccepted(validatorID ids.ShortID, requestID uint32, cont return errors.New("unexpectedly called GetAccepted") } -// GetAcceptedFailed ... func (e *EngineTest) GetAcceptedFailed(validatorID ids.ShortID, requestID uint32) error { if e.GetAcceptedFailedF != nil { return e.GetAcceptedFailedF(validatorID, requestID) @@ -240,7 +238,6 @@ func (e *EngineTest) GetAcceptedFailed(validatorID ids.ShortID, requestID uint32 return errors.New("unexpectedly called GetAcceptedFailed") } -// Accepted ... func (e *EngineTest) Accepted(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { if e.AcceptedF != nil { return e.AcceptedF(validatorID, requestID, containerIDs) @@ -254,7 +251,6 @@ func (e *EngineTest) Accepted(validatorID ids.ShortID, requestID uint32, contain return errors.New("unexpectedly called Accepted") } -// Get ... func (e *EngineTest) Get(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { if e.GetF != nil { return e.GetF(validatorID, requestID, containerID) @@ -268,7 +264,6 @@ func (e *EngineTest) Get(validatorID ids.ShortID, requestID uint32, containerID return errors.New("unexpectedly called Get") } -// GetAncestors ... 
func (e *EngineTest) GetAncestors(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { if e.GetAncestorsF != nil { return e.GetAncestorsF(validatorID, requestID, containerID) @@ -283,7 +278,6 @@ func (e *EngineTest) GetAncestors(validatorID ids.ShortID, requestID uint32, con } -// GetFailed ... func (e *EngineTest) GetFailed(validatorID ids.ShortID, requestID uint32) error { if e.GetFailedF != nil { return e.GetFailedF(validatorID, requestID) @@ -297,7 +291,6 @@ func (e *EngineTest) GetFailed(validatorID ids.ShortID, requestID uint32) error return errors.New("unexpectedly called GetFailed") } -// GetAncestorsFailed ... func (e *EngineTest) GetAncestorsFailed(validatorID ids.ShortID, requestID uint32) error { if e.GetAncestorsFailedF != nil { return e.GetAncestorsFailedF(validatorID, requestID) @@ -311,7 +304,6 @@ func (e *EngineTest) GetAncestorsFailed(validatorID ids.ShortID, requestID uint3 return errors.New("unexpectedly called GetAncestorsFailed") } -// Put ... func (e *EngineTest) Put(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) error { if e.PutF != nil { return e.PutF(validatorID, requestID, containerID, container) @@ -325,7 +317,6 @@ func (e *EngineTest) Put(validatorID ids.ShortID, requestID uint32, containerID return errors.New("unexpectedly called Put") } -// MultiPut ... func (e *EngineTest) MultiPut(validatorID ids.ShortID, requestID uint32, containers [][]byte) error { if e.MultiPutF != nil { return e.MultiPutF(validatorID, requestID, containers) @@ -339,7 +330,6 @@ func (e *EngineTest) MultiPut(validatorID ids.ShortID, requestID uint32, contain return errors.New("unexpectedly called MultiPut") } -// PushQuery ... 
func (e *EngineTest) PushQuery(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) error { if e.PushQueryF != nil { return e.PushQueryF(validatorID, requestID, containerID, container) @@ -353,7 +343,6 @@ func (e *EngineTest) PushQuery(validatorID ids.ShortID, requestID uint32, contai return errors.New("unexpectedly called PushQuery") } -// PullQuery ... func (e *EngineTest) PullQuery(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { if e.PullQueryF != nil { return e.PullQueryF(validatorID, requestID, containerID) @@ -367,7 +356,6 @@ func (e *EngineTest) PullQuery(validatorID ids.ShortID, requestID uint32, contai return errors.New("unexpectedly called PullQuery") } -// QueryFailed ... func (e *EngineTest) QueryFailed(validatorID ids.ShortID, requestID uint32) error { if e.QueryFailedF != nil { return e.QueryFailedF(validatorID, requestID) @@ -381,7 +369,6 @@ func (e *EngineTest) QueryFailed(validatorID ids.ShortID, requestID uint32) erro return errors.New("unexpectedly called QueryFailed") } -// Chits ... func (e *EngineTest) Chits(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { if e.ChitsF != nil { return e.ChitsF(validatorID, requestID, containerIDs) @@ -395,7 +382,6 @@ func (e *EngineTest) Chits(validatorID ids.ShortID, requestID uint32, containerI return errors.New("unexpectedly called Chits") } -// Connected ... func (e *EngineTest) Connected(validatorID ids.ShortID) error { if e.ConnectedF != nil { return e.ConnectedF(validatorID) @@ -409,7 +395,6 @@ func (e *EngineTest) Connected(validatorID ids.ShortID) error { return errors.New("unexpectedly called Connected") } -// Disconnected ... func (e *EngineTest) Disconnected(validatorID ids.ShortID) error { if e.DisconnectedF != nil { return e.DisconnectedF(validatorID) @@ -423,7 +408,6 @@ func (e *EngineTest) Disconnected(validatorID ids.ShortID) error { return errors.New("unexpectedly called Disconnected") } -// IsBootstrapped ... 
func (e *EngineTest) IsBootstrapped() bool { if e.IsBootstrappedF != nil { return e.IsBootstrappedF() @@ -434,7 +418,6 @@ func (e *EngineTest) IsBootstrapped() bool { return false } -// Health ... func (e *EngineTest) HealthCheck() (interface{}, error) { if e.HealthF != nil { return e.HealthF() @@ -444,3 +427,23 @@ func (e *EngineTest) HealthCheck() (interface{}, error) { } return nil, errors.New("unexpectedly called Health") } + +func (e *EngineTest) GetVtx() (avalanche.Vertex, error) { + if e.GetVtxF != nil { + return e.GetVtxF() + } + if e.CantGetVtx && e.T != nil { + e.T.Fatalf("Unexpectedly called GetVtx") + } + return nil, errors.New("unexpectedly called GetVtx") +} + +func (e *EngineTest) GetVM() VM { + if e.GetVMF != nil { + return e.GetVMF() + } + if e.CantGetVM && e.T != nil { + e.T.Fatalf("Unexpectedly called GetVM") + } + return nil +} diff --git a/snow/engine/snowman/block/mocks/chain_vm.go b/snow/engine/snowman/block/mocks/chain_vm.go new file mode 100644 index 000000000000..e160b1488472 --- /dev/null +++ b/snow/engine/snowman/block/mocks/chain_vm.go @@ -0,0 +1,232 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import ( + database "github.com/ava-labs/avalanchego/database" + common "github.com/ava-labs/avalanchego/snow/engine/common" + + ids "github.com/ava-labs/avalanchego/ids" + + mock "github.com/stretchr/testify/mock" + + snow "github.com/ava-labs/avalanchego/snow" + + snowman "github.com/ava-labs/avalanchego/snow/consensus/snowman" + block "github.com/ava-labs/avalanchego/snow/engine/snowman/block" +) + +var _ block.ChainVM = &ChainVM{} + +// ChainVM is an autogenerated mock type for the ChainVM type +type ChainVM struct { + mock.Mock +} + +// Bootstrapped provides a mock function with given fields: +func (_m *ChainVM) Bootstrapped() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Bootstrapping provides a mock function with given fields: +func (_m *ChainVM) Bootstrapping() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// BuildBlock provides a mock function with given fields: +func (_m *ChainVM) BuildBlock() (snowman.Block, error) { + ret := _m.Called() + + var r0 snowman.Block + if rf, ok := ret.Get(0).(func() snowman.Block); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snowman.Block) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateHandlers provides a mock function with given fields: +func (_m *ChainVM) CreateHandlers() (map[string]*common.HTTPHandler, error) { + ret := _m.Called() + + var r0 map[string]*common.HTTPHandler + if rf, ok := ret.Get(0).(func() map[string]*common.HTTPHandler); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*common.HTTPHandler) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = 
ret.Error(1) + } + + return r0, r1 +} + +// GetBlock provides a mock function with given fields: _a0 +func (_m *ChainVM) GetBlock(_a0 ids.ID) (snowman.Block, error) { + ret := _m.Called(_a0) + + var r0 snowman.Block + if rf, ok := ret.Get(0).(func(ids.ID) snowman.Block); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snowman.Block) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(ids.ID) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HealthCheck provides a mock function with given fields: +func (_m *ChainVM) HealthCheck() (interface{}, error) { + ret := _m.Called() + + var r0 interface{} + if rf, ok := ret.Get(0).(func() interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: ctx, db, genesisBytes, toEngine, fxs +func (_m *ChainVM) Initialize(ctx *snow.Context, db database.Database, genesisBytes []byte, toEngine chan<- common.Message, fxs []*common.Fx) error { + ret := _m.Called(ctx, db, genesisBytes, toEngine, fxs) + + var r0 error + if rf, ok := ret.Get(0).(func(*snow.Context, database.Database, []byte, chan<- common.Message, []*common.Fx) error); ok { + r0 = rf(ctx, db, genesisBytes, toEngine, fxs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LastAccepted provides a mock function with given fields: +func (_m *ChainVM) LastAccepted() (ids.ID, error) { + ret := _m.Called() + + var r0 ids.ID + if rf, ok := ret.Get(0).(func() ids.ID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(ids.ID) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ParseBlock provides a mock function with given fields: _a0 +func (_m *ChainVM) 
ParseBlock(_a0 []byte) (snowman.Block, error) { + ret := _m.Called(_a0) + + var r0 snowman.Block + if rf, ok := ret.Get(0).(func([]byte) snowman.Block); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(snowman.Block) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SetPreference provides a mock function with given fields: _a0 +func (_m *ChainVM) SetPreference(_a0 ids.ID) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ID) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Shutdown provides a mock function with given fields: +func (_m *ChainVM) Shutdown() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/snow/engine/snowman/bootstrap/bootstrapper.go b/snow/engine/snowman/bootstrap/bootstrapper.go index ce4efe611414..96d79b992f0e 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper.go +++ b/snow/engine/snowman/bootstrap/bootstrapper.go @@ -44,6 +44,13 @@ type Bootstrapper struct { common.Fetcher metrics + // Greatest height of the blocks passed in ForceAccepted + tipHeight uint64 + // Height of the last accepted block when bootstrapping starts + startingHeight uint64 + // Blocks passed into ForceAccepted + startingAcceptedFrontier ids.Set + // Blocked tracks operations that are blocked on blocks Blocked *queue.Jobs @@ -72,6 +79,16 @@ func (b *Bootstrapper) Initialize( b.OnFinished = onFinished b.executedStateTransitions = math.MaxInt32 b.delayAmount = initialBootstrappingDelay + b.startingAcceptedFrontier = ids.Set{} + lastAcceptedID, err := b.VM.LastAccepted() + if err != nil { + return fmt.Errorf("couldn't get last accepted ID: %s", err) + } + lastAccepted, err := b.VM.GetBlock(lastAcceptedID) + if err != nil { + return fmt.Errorf("couldn't get 
last accepted block: %s", err) + } + b.startingHeight = lastAccepted.Height() if err := b.metrics.Initialize(namespace, registerer); err != nil { return err @@ -114,7 +131,11 @@ func (b *Bootstrapper) ForceAccepted(acceptedContainerIDs []ids.ID) error { b.NumFetched = 0 for _, blkID := range acceptedContainerIDs { + b.startingAcceptedFrontier.Add(blkID) if blk, err := b.VM.GetBlock(blkID); err == nil { + if height := blk.Height(); height > b.tipHeight { + b.tipHeight = height + } if err := b.process(blk); err != nil { return err } @@ -213,6 +234,10 @@ func (b *Bootstrapper) GetAncestorsFailed(vdr ids.ShortID, requestID uint32) err func (b *Bootstrapper) process(blk snowman.Block) error { status := blk.Status() blkID := blk.ID() + blkHeight := blk.Height() + if blkHeight > b.tipHeight && b.startingAcceptedFrontier.Contains(blkID) { + b.tipHeight = blkHeight + } for status == choices.Processing { if err := b.Blocked.Push(&blockJob{ numAccepted: b.numAccepted, @@ -222,7 +247,11 @@ func (b *Bootstrapper) process(blk snowman.Block) error { b.numFetched.Inc() b.NumFetched++ // Progress tracker if b.NumFetched%common.StatusUpdateFrequency == 0 { // Periodically print progress - b.Ctx.Log.Info("fetched %d blocks", b.NumFetched) + if !b.Restarted { + b.Ctx.Log.Info("fetched %d of %d blocks", b.NumFetched, b.tipHeight-b.startingHeight) + } else { + b.Ctx.Log.Debug("fetched %d of %d blocks", b.NumFetched, b.tipHeight-b.startingHeight) + } } } @@ -258,8 +287,11 @@ func (b *Bootstrapper) checkFinish() error { return nil } - b.Ctx.Log.Info("bootstrapping fetched %d blocks. executing state transitions...", - b.NumFetched) + if !b.Restarted { + b.Ctx.Log.Info("bootstrapping fetched %d blocks. Executing state transitions...", b.NumFetched) + } else { + b.Ctx.Log.Debug("bootstrapping fetched %d blocks. 
Executing state transitions...", b.NumFetched) + } executedBlocks, err := b.executeAll(b.Blocked) if err != nil { @@ -272,14 +304,10 @@ func (b *Bootstrapper) checkFinish() error { // Note that executedVts < c*previouslyExecuted is enforced so that the // bootstrapping process will terminate even as new blocks are being issued. if executedBlocks > 0 && executedBlocks < previouslyExecuted/2 && b.RetryBootstrap { - b.Ctx.Log.Info("bootstrapping is checking for more blocks before finishing the bootstrap process...") - b.processedStartingAcceptedFrontier = false return b.RestartBootstrap(true) } - b.Ctx.Log.Info("bootstrapping fetched enough blocks to finish the bootstrap process...") - // Notify the subnet that this chain is synced b.Subnet.Bootstrapped(b.Ctx.ChainID) @@ -292,7 +320,11 @@ func (b *Bootstrapper) checkFinish() error { // If the subnet hasn't finished bootstrapping, this chain should remain // syncing. if !b.Subnet.IsBootstrapped() { - b.Ctx.Log.Info("bootstrapping is waiting for the remaining chains in this subnet to finish syncing...") + if !b.Restarted { + b.Ctx.Log.Info("waiting for the remaining chains in this subnet to finish syncing") + } else { + b.Ctx.Log.Debug("waiting for the remaining chains in this subnet to finish syncing") + } // Delay new incoming messages to avoid consuming unnecessary resources // while keeping up to date on the latest tip. b.Config.Delay.Delay(b.delayAmount) @@ -325,6 +357,17 @@ func (b *Bootstrapper) finish() error { func (b *Bootstrapper) executeAll(jobs *queue.Jobs) (int, error) { numExecuted := 0 for job, err := jobs.Pop(); err == nil; job, err = jobs.Pop() { + jobID := job.ID() + jobBytes := job.Bytes() + // Note that ConsensusDispatcher.Accept / DecisionDispatcher.Accept must be + // called before job.Execute to honor EventDispatcher.Accept's invariant. 
+ if err := b.Ctx.ConsensusDispatcher.Accept(b.Ctx, jobID, jobBytes); err != nil { + return numExecuted, err + } + if err := b.Ctx.DecisionDispatcher.Accept(b.Ctx, jobID, jobBytes); err != nil { + return numExecuted, err + } + if err := jobs.Execute(job); err != nil { return numExecuted, err } @@ -333,13 +376,19 @@ func (b *Bootstrapper) executeAll(jobs *queue.Jobs) (int, error) { } numExecuted++ if numExecuted%common.StatusUpdateFrequency == 0 { // Periodically print progress - b.Ctx.Log.Info("executed %d blocks", numExecuted) + if !b.Restarted { + b.Ctx.Log.Info("executed %d of %d blocks", numExecuted, b.tipHeight-b.startingHeight) + } else { + b.Ctx.Log.Debug("executed %d of %d blocks", numExecuted, b.tipHeight-b.startingHeight) + } } - b.Ctx.ConsensusDispatcher.Accept(b.Ctx, job.ID(), job.Bytes()) - b.Ctx.DecisionDispatcher.Accept(b.Ctx, job.ID(), job.Bytes()) } - b.Ctx.Log.Info("executed %d blocks", numExecuted) + if !b.Restarted { + b.Ctx.Log.Info("executed %d blocks", numExecuted) + } else { + b.Ctx.Log.Debug("executed %d blocks", numExecuted) + } return numExecuted, nil } diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index fb8692e44945..3686c8e251bc 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/prometheus/client_golang/prometheus" + "gotest.tools/assert" "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" @@ -102,6 +103,13 @@ func TestBootstrapperSingleFrontier(t *testing.T) { BytesV: blkBytes1, } + vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blk0.ID(), nil } + vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + assert.Equal(t, blk0.ID(), blkID) + return blk0, nil + } + finished := new(bool) bs := Bootstrapper{} err := bs.Initialize( @@ -194,6 +202,13 @@ func 
TestBootstrapperUnknownByzantineResponse(t *testing.T) { BytesV: blkBytes2, } + vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blk0.ID(), nil } + vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + assert.Equal(t, blk0.ID(), blkID) + return blk0, nil + } + finished := new(bool) bs := Bootstrapper{} err := bs.Initialize( @@ -344,6 +359,13 @@ func TestBootstrapperPartialFetch(t *testing.T) { BytesV: blkBytes3, } + vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blk0.ID(), nil } + vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + assert.Equal(t, blk0.ID(), blkID) + return blk0, nil + } + finished := new(bool) bs := Bootstrapper{} err := bs.Initialize( @@ -498,7 +520,12 @@ func TestBootstrapperMultiPut(t *testing.T) { } vm.CantBootstrapping = false - + vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blk0.ID(), nil } + vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + assert.Equal(t, blk0.ID(), blkID) + return blk0, nil + } finished := new(bool) bs := Bootstrapper{} err := bs.Initialize( @@ -599,6 +626,20 @@ func TestBootstrapperAcceptedFrontier(t *testing.T) { blkID := ids.GenerateTestID() + dummyBlk := &snowman.TestBlock{ + TestDecidable: choices.TestDecidable{ + IDV: blkID, + StatusV: choices.Accepted, + }, + HeightV: 0, + BytesV: []byte{1, 2, 3}, + } + vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blkID, nil } + vm.GetBlockF = func(bID ids.ID) (snowman.Block, error) { + assert.Equal(t, blkID, bID) + return dummyBlk, nil + } bs := Bootstrapper{} err := bs.Initialize( config, @@ -610,8 +651,6 @@ func TestBootstrapperAcceptedFrontier(t *testing.T) { t.Fatal(err) } - vm.LastAcceptedF = func() (ids.ID, error) { return blkID, nil } - accepted, err := bs.CurrentAcceptedFrontier() if err != nil { t.Fatal(err) @@ -641,6 +680,13 @@ func TestBootstrapperFilterAccepted(t *testing.T) { StatusV: choices.Accepted, }} 
+ vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blk1.ID(), nil } + vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + assert.Equal(t, blk1.ID(), blkID) + return blk1, nil + } + bs := Bootstrapper{} err := bs.Initialize( config, @@ -653,7 +699,6 @@ func TestBootstrapperFilterAccepted(t *testing.T) { } blkIDs := []ids.ID{blkID0, blkID1, blkID2} - vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { switch blkID { case blkID0: @@ -726,6 +771,12 @@ func TestBootstrapperFinalized(t *testing.T) { finished := new(bool) bs := Bootstrapper{} + vm.CantLastAccepted = false + vm.LastAcceptedF = func() (ids.ID, error) { return blk0.ID(), nil } + vm.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + assert.Equal(t, blk0.ID(), blkID) + return blk0, nil + } err := bs.Initialize( config, func() error { *finished = true; return nil }, diff --git a/snow/engine/snowman/engine.go b/snow/engine/snowman/engine.go index c12f7e6effe5..bcc684b42324 100644 --- a/snow/engine/snowman/engine.go +++ b/snow/engine/snowman/engine.go @@ -4,6 +4,8 @@ package snowman import ( + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/snow/engine/common" ) @@ -19,5 +21,9 @@ type Engine interface { common.Engine // Initialize this engine. - Initialize(Config) + Initialize(Config) error + + // GetBlock returns a block by its ID. + // Returns an error if unknown. + GetBlock(blkID ids.ID) (snowman.Block, error) } diff --git a/snow/engine/snowman/mocks/engine.go b/snow/engine/snowman/mocks/engine.go new file mode 100644 index 000000000000..48121f441f00 --- /dev/null +++ b/snow/engine/snowman/mocks/engine.go @@ -0,0 +1,435 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. 
+ +package mocks + +import ( + consensussnowman "github.com/ava-labs/avalanchego/snow/consensus/snowman" + common "github.com/ava-labs/avalanchego/snow/engine/common" + + ids "github.com/ava-labs/avalanchego/ids" + + mock "github.com/stretchr/testify/mock" + + snow "github.com/ava-labs/avalanchego/snow" + + snowman "github.com/ava-labs/avalanchego/snow/engine/snowman" +) + +// Engine is an autogenerated mock type for the Engine type +type Engine struct { + mock.Mock +} + +// Accepted provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) Accepted(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AcceptedFrontier provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) AcceptedFrontier(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Chits provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) Chits(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Connected provides a mock function with given fields: validatorID +func (_m *Engine) Connected(validatorID ids.ShortID) error { + ret := _m.Called(validatorID) + + 
var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID) error); ok { + r0 = rf(validatorID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Context provides a mock function with given fields: +func (_m *Engine) Context() *snow.Context { + ret := _m.Called() + + var r0 *snow.Context + if rf, ok := ret.Get(0).(func() *snow.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*snow.Context) + } + } + + return r0 +} + +// Disconnected provides a mock function with given fields: validatorID +func (_m *Engine) Disconnected(validatorID ids.ShortID) error { + ret := _m.Called(validatorID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID) error); ok { + r0 = rf(validatorID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: validatorID, requestID, containerID +func (_m *Engine) Get(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { + ret := _m.Called(validatorID, requestID, containerID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAccepted provides a mock function with given fields: validatorID, requestID, containerIDs +func (_m *Engine) GetAccepted(validatorID ids.ShortID, requestID uint32, containerIDs []ids.ID) error { + ret := _m.Called(validatorID, requestID, containerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, []ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcceptedFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAcceptedFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, 
requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcceptedFrontier provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAcceptedFrontier(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcceptedFrontierFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAcceptedFrontierFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAncestors provides a mock function with given fields: validatorID, requestID, containerID +func (_m *Engine) GetAncestors(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { + ret := _m.Called(validatorID, requestID, containerID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAncestorsFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetAncestorsFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetBlock provides a mock function with given fields: blkID +func (_m *Engine) GetBlock(blkID ids.ID) (consensussnowman.Block, error) { + ret := _m.Called(blkID) + + var r0 consensussnowman.Block + if rf, ok := ret.Get(0).(func(ids.ID) consensussnowman.Block); ok { + r0 = rf(blkID) + } 
else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(consensussnowman.Block) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(ids.ID) error); ok { + r1 = rf(blkID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) GetFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetVM provides a mock function with given fields: +func (_m *Engine) GetVM() common.VM { + ret := _m.Called() + + var r0 common.VM + if rf, ok := ret.Get(0).(func() common.VM); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.VM) + } + } + + return r0 +} + +// Gossip provides a mock function with given fields: +func (_m *Engine) Gossip() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HealthCheck provides a mock function with given fields: +func (_m *Engine) HealthCheck() (interface{}, error) { + ret := _m.Called() + + var r0 interface{} + if rf, ok := ret.Get(0).(func() interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Initialize provides a mock function with given fields: _a0 +func (_m *Engine) Initialize(_a0 snowman.Config) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(snowman.Config) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IsBootstrapped provides a mock function with given fields: +func (_m *Engine) IsBootstrapped() bool { + ret := _m.Called() + + var 
r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MultiPut provides a mock function with given fields: validatorID, requestID, containers +func (_m *Engine) MultiPut(validatorID ids.ShortID, requestID uint32, containers [][]byte) error { + ret := _m.Called(validatorID, requestID, containers) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, [][]byte) error); ok { + r0 = rf(validatorID, requestID, containers) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Notify provides a mock function with given fields: _a0 +func (_m *Engine) Notify(_a0 common.Message) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(common.Message) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PullQuery provides a mock function with given fields: validatorID, requestID, containerID +func (_m *Engine) PullQuery(validatorID ids.ShortID, requestID uint32, containerID ids.ID) error { + ret := _m.Called(validatorID, requestID, containerID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID) error); ok { + r0 = rf(validatorID, requestID, containerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PushQuery provides a mock function with given fields: validatorID, requestID, containerID, container +func (_m *Engine) PushQuery(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) error { + ret := _m.Called(validatorID, requestID, containerID, container) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID, []byte) error); ok { + r0 = rf(validatorID, requestID, containerID, container) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Put provides a mock function with given fields: validatorID, requestID, containerID, container +func (_m *Engine) Put(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) 
error { + ret := _m.Called(validatorID, requestID, containerID, container) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32, ids.ID, []byte) error); ok { + r0 = rf(validatorID, requestID, containerID, container) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// QueryFailed provides a mock function with given fields: validatorID, requestID +func (_m *Engine) QueryFailed(validatorID ids.ShortID, requestID uint32) error { + ret := _m.Called(validatorID, requestID) + + var r0 error + if rf, ok := ret.Get(0).(func(ids.ShortID, uint32) error); ok { + r0 = rf(validatorID, requestID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Shutdown provides a mock function with given fields: +func (_m *Engine) Shutdown() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Startup provides a mock function with given fields: +func (_m *Engine) Startup() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index 1525d58665a5..430da3768983 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -27,6 +27,8 @@ const ( maxContainersLen = int(4 * network.DefaultMaxMessageSize / 5) ) +var _ Engine = &Transitive{} + // Transitive implements the Engine interface by attempting to fetch all // transitive dependencies. 
type Transitive struct { @@ -745,3 +747,13 @@ func (t *Transitive) HealthCheck() (interface{}, error) { } return intf, fmt.Errorf("vm: %s ; consensus: %s", vmErr, consensusErr) } + +// GetBlock implements the snowman.Engine interface +func (t *Transitive) GetBlock(blkID ids.ID) (snowman.Block, error) { + return t.VM.GetBlock(blkID) +} + +// GetVM implements the snowman.Engine interface +func (t *Transitive) GetVM() common.VM { + return t.VM +} diff --git a/snow/networking/router/chain_router.go b/snow/networking/router/chain_router.go index 7fc874f680bc..24711e5496ed 100644 --- a/snow/networking/router/chain_router.go +++ b/snow/networking/router/chain_router.go @@ -101,7 +101,7 @@ func (cr *ChainRouter) Initialize( cr.healthConfig = healthConfig // Register metrics - rMetrics, err := newRouterMetrics(cr.log, metricsNamespace, metricsRegisterer) + rMetrics, err := newRouterMetrics(metricsNamespace, metricsRegisterer) if err != nil { return err } @@ -787,22 +787,14 @@ func (cr *ChainRouter) HealthCheck() (interface{}, error) { healthy = healthy && numOutstandingReqs <= cr.healthConfig.MaxOutstandingRequests details["outstandingRequests"] = numOutstandingReqs - now := cr.clock.Time() - if numOutstandingReqs == 0 { - cr.lastTimeNoOutstanding = now - } - timeSinceNoOutstandingRequests := now.Sub(cr.lastTimeNoOutstanding) - healthy = healthy && timeSinceNoOutstandingRequests <= cr.healthConfig.MaxTimeSinceNoOutstandingRequests - details["timeSinceNoOutstandingRequests"] = timeSinceNoOutstandingRequests.String() - cr.metrics.timeSinceNoOutstandingRequests.Set(float64(timeSinceNoOutstandingRequests.Milliseconds())) - // check for long running requests + now := cr.clock.Time() processingRequest := now if longestRunning, exists := cr.timedRequests.Oldest(); exists { processingRequest = longestRunning.(requestEntry).time } timeReqRunning := now.Sub(processingRequest) - healthy = healthy && timeReqRunning <= cr.healthConfig.MaxTimeSinceNoOutstandingRequests + healthy = 
healthy && timeReqRunning <= cr.healthConfig.MaxOutstandingDuration details["longestRunningRequest"] = timeReqRunning.String() cr.metrics.longestRunningRequest.Set(float64(timeReqRunning.Milliseconds())) diff --git a/snow/networking/router/chain_router_metrics.go b/snow/networking/router/chain_router_metrics.go index aa14e23f6150..6ed5b5509699 100644 --- a/snow/networking/router/chain_router_metrics.go +++ b/snow/networking/router/chain_router_metrics.go @@ -4,20 +4,18 @@ package router import ( - "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/prometheus/client_golang/prometheus" ) // routerMetrics about router messages type routerMetrics struct { - outstandingRequests prometheus.Gauge - msgDropRate prometheus.Gauge - timeSinceNoOutstandingRequests prometheus.Gauge - longestRunningRequest prometheus.Gauge + outstandingRequests prometheus.Gauge + msgDropRate prometheus.Gauge + longestRunningRequest prometheus.Gauge } -func newRouterMetrics(log logging.Logger, namespace string, registerer prometheus.Registerer) (*routerMetrics, error) { +func newRouterMetrics(namespace string, registerer prometheus.Registerer) (*routerMetrics, error) { rMetrics := &routerMetrics{} rMetrics.outstandingRequests = prometheus.NewGauge( prometheus.GaugeOpts{ @@ -33,13 +31,6 @@ func newRouterMetrics(log logging.Logger, namespace string, registerer prometheu Help: "Rate of messages dropped", }, ) - rMetrics.timeSinceNoOutstandingRequests = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: namespace, - Name: "time_since_no_outstanding_requests", - Help: "Time with no requests being processed in milliseconds", - }, - ) rMetrics.longestRunningRequest = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: namespace, @@ -52,7 +43,6 @@ func newRouterMetrics(log logging.Logger, namespace string, registerer prometheu errs.Add( registerer.Register(rMetrics.outstandingRequests), registerer.Register(rMetrics.msgDropRate), - 
registerer.Register(rMetrics.timeSinceNoOutstandingRequests), registerer.Register(rMetrics.longestRunningRequest), ) return rMetrics, errs.Err diff --git a/snow/networking/router/chain_router_test.go b/snow/networking/router/chain_router_test.go index 81abe5c10c2a..7c3ca76d08b3 100644 --- a/snow/networking/router/chain_router_test.go +++ b/snow/networking/router/chain_router_test.go @@ -53,7 +53,7 @@ func TestShutdown(t *testing.T) { engine.ShutdownF = func() error { shutdownCalled <- struct{}{}; return nil } handler := &Handler{} - handler.Initialize( + err = handler.Initialize( &engine, vdrs, nil, @@ -65,6 +65,7 @@ func TestShutdown(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) go handler.Dispatch() @@ -127,7 +128,7 @@ func TestShutdownTimesOut(t *testing.T) { engine.ShutdownF = func() error { *closed++; return nil } handler := &Handler{} - handler.Initialize( + err = handler.Initialize( &engine, vdrs, nil, @@ -139,6 +140,7 @@ func TestShutdownTimesOut(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) chainRouter.AddChain(handler) @@ -226,7 +228,7 @@ func TestRouterTimeout(t *testing.T) { engine.ContextF = snow.DefaultContextTest handler := &Handler{} - handler.Initialize( + err = handler.Initialize( &engine, validators.NewSet(), nil, @@ -238,6 +240,7 @@ func TestRouterTimeout(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) chainRouter.AddChain(handler) go handler.Dispatch() @@ -293,7 +296,7 @@ func TestRouterClearTimeouts(t *testing.T) { engine.ContextF = snow.DefaultContextTest handler := &Handler{} - handler.Initialize( + err = handler.Initialize( &engine, validators.NewSet(), nil, @@ -305,6 +308,7 @@ func TestRouterClearTimeouts(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) chainRouter.AddChain(handler) go handler.Dispatch() diff --git a/snow/networking/router/handler.go b/snow/networking/router/handler.go index 
584e30958b26..9830bd36e603 100644 --- a/snow/networking/router/handler.go +++ b/snow/networking/router/handler.go @@ -92,7 +92,7 @@ const ( // Handler passes incoming messages from the network to the consensus engine // (Actually, it receives the incoming messages from a ChainRouter, but same difference) type Handler struct { - metrics + metrics handlerMetrics validators validators.Set @@ -132,7 +132,7 @@ func (h *Handler) Initialize( namespace string, metrics prometheus.Registerer, delay *Delay, -) { +) error { h.ctx = engine.Context() if err := h.metrics.Initialize(namespace, metrics); err != nil { h.ctx.Log.Warn("initializing handler metrics errored with: %s", err) @@ -162,7 +162,7 @@ func (h *Handler) Initialize( h.cpuTracker = tracker.NewCPUTracker(uptime.IntervalFactory{}, cpuInterval) msgTracker := tracker.NewMessageTracker() - msgManager := NewMsgManager( + msgManager, err := NewMsgManager( validators, h.ctx.Log, msgTracker, @@ -171,7 +171,12 @@ func (h *Handler) Initialize( maxNonStakerPendingMsgs, stakerMsgPortion, stakerCPUPortion, + namespace, + metrics, ) + if err != nil { + return err + } h.serviceQueue, h.msgSema = newMultiLevelQueue( msgManager, @@ -184,6 +189,7 @@ func (h *Handler) Initialize( h.engine = engine h.validators = validators h.delay = delay + return nil } // Context of this Handler @@ -282,10 +288,10 @@ func (h *Handler) dispatchMsg(msg message) { switch msg.messageType { case constants.NotifyMsg: err = h.engine.Notify(msg.notification) - h.notify.Observe(float64(h.clock.Time().Sub(startTime))) + h.metrics.notify.Observe(float64(h.clock.Time().Sub(startTime))) case constants.GossipMsg: err = h.engine.Gossip() - h.gossip.Observe(float64(h.clock.Time().Sub(startTime))) + h.metrics.gossip.Observe(float64(h.clock.Time().Sub(startTime))) default: err = h.handleValidatorMsg(msg, startTime) } @@ -536,7 +542,7 @@ func (h *Handler) shutdownDispatch() { go h.toClose() } h.closing.SetValue(true) - h.shutdown.Observe(float64(time.Since(startTime))) 
+ h.metrics.shutdown.Observe(float64(time.Since(startTime))) close(h.closed) } @@ -585,7 +591,7 @@ func (h *Handler) handleValidatorMsg(msg message, startTime time.Time) error { endTime := h.clock.Time() timeConsumed := endTime.Sub(startTime) - histogram := h.getMSGHistogram(msg.messageType) + histogram := h.metrics.getMSGHistogram(msg.messageType) histogram.Observe(float64(timeConsumed)) h.cpuTracker.UtilizeTime(msg.validatorID, startTime, endTime) diff --git a/snow/networking/router/metrics.go b/snow/networking/router/handler_metrics.go similarity index 89% rename from snow/networking/router/metrics.go rename to snow/networking/router/handler_metrics.go index bdac6956c1e5..33aa249ab005 100644 --- a/snow/networking/router/metrics.go +++ b/snow/networking/router/handler_metrics.go @@ -32,11 +32,11 @@ func initHistogram(namespace, name string, registerer prometheus.Registerer, err return histogram } -type metrics struct { - namespace string - registerer prometheus.Registerer - pending prometheus.Gauge - dropped, expired, throttled prometheus.Counter +type handlerMetrics struct { + namespace string + registerer prometheus.Registerer + pending prometheus.Gauge + dropped, expired prometheus.Counter getAcceptedFrontier, acceptedFrontier, getAcceptedFrontierFailed, getAccepted, accepted, getAcceptedFailed, getAncestors, multiPut, getAncestorsFailed, @@ -50,7 +50,7 @@ type metrics struct { } // Initialize implements the Engine interface -func (m *metrics) Initialize(namespace string, registerer prometheus.Registerer) error { +func (m *handlerMetrics) Initialize(namespace string, registerer prometheus.Registerer) error { m.namespace = namespace m.registerer = registerer errs := wrappers.Errs{} @@ -83,15 +83,6 @@ func (m *metrics) Initialize(namespace string, registerer prometheus.Registerer) errs.Add(fmt.Errorf("failed to register expired statistics due to %w", err)) } - m.throttled = prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: namespace, - Name: 
"throttled", - Help: "Number of throttled events", - }) - if err := registerer.Register(m.throttled); err != nil { - errs.Add(fmt.Errorf("failed to register throttled statistics due to %w", err)) - } - m.getAcceptedFrontier = initHistogram(namespace, "get_accepted_frontier", registerer, &errs) m.acceptedFrontier = initHistogram(namespace, "accepted_frontier", registerer, &errs) m.getAcceptedFrontierFailed = initHistogram(namespace, "get_accepted_frontier_failed", registerer, &errs) @@ -135,7 +126,7 @@ func (m *metrics) Initialize(namespace string, registerer prometheus.Registerer) return errs.Err } -func (m *metrics) registerTierStatistics(tier int) (prometheus.Gauge, prometheus.Histogram, error) { +func (m *handlerMetrics) registerTierStatistics(tier int) (prometheus.Gauge, prometheus.Histogram, error) { errs := wrappers.Errs{} gauge := prometheus.NewGauge(prometheus.GaugeOpts{ @@ -159,7 +150,7 @@ func (m *metrics) registerTierStatistics(tier int) (prometheus.Gauge, prometheus return gauge, histogram, errs.Err } -func (m *metrics) getMSGHistogram(msg constants.MsgType) prometheus.Histogram { +func (m *handlerMetrics) getMSGHistogram(msg constants.MsgType) prometheus.Histogram { switch msg { case constants.GetAcceptedFrontierMsg: return m.getAcceptedFrontier diff --git a/snow/networking/router/handler_test.go b/snow/networking/router/handler_test.go index d73f05a381d7..35a7ed2f5cce 100644 --- a/snow/networking/router/handler_test.go +++ b/snow/networking/router/handler_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" @@ -38,7 +39,7 @@ func TestHandlerDropsTimedOutMessages(t *testing.T) { if err := vdrs.AddWeight(vdr0, 1); err != nil { t.Fatal(err) } - handler.Initialize( + err := handler.Initialize( &engine, vdrs, nil, @@ -50,6 +51,7 @@ func TestHandlerDropsTimedOutMessages(t *testing.T) { 
prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) currentTime := time.Now() handler.clock.Set(currentTime) @@ -82,7 +84,7 @@ func TestHandlerDoesntDrop(t *testing.T) { handler := &Handler{} validators := validators.NewSet() - handler.Initialize( + err := handler.Initialize( &engine, validators, nil, @@ -94,6 +96,7 @@ func TestHandlerDoesntDrop(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) handler.GetAcceptedFrontier(ids.ShortID{}, 1, time.Time{}) go handler.Dispatch() @@ -119,7 +122,7 @@ func TestHandlerClosesOnError(t *testing.T) { } handler := &Handler{} - handler.Initialize( + err := handler.Initialize( &engine, validators.NewSet(), nil, @@ -131,6 +134,8 @@ func TestHandlerClosesOnError(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) + handler.clock.Set(time.Now()) handler.toClose = func() { @@ -163,7 +168,7 @@ func TestHandlerDropsGossipDuringBootstrapping(t *testing.T) { } handler := &Handler{} - handler.Initialize( + err := handler.Initialize( &engine, validators.NewSet(), nil, @@ -175,6 +180,8 @@ func TestHandlerDropsGossipDuringBootstrapping(t *testing.T) { prometheus.NewRegistry(), &Delay{}, ) + assert.NoError(t, err) + handler.clock.Set(time.Now()) go handler.Dispatch() diff --git a/snow/networking/router/health.go b/snow/networking/router/health.go index 59a54fbe83a2..3aa93477cd53 100644 --- a/snow/networking/router/health.go +++ b/snow/networking/router/health.go @@ -16,9 +16,8 @@ type HealthConfig struct { // Must be > 0 MaxOutstandingRequests int - // Reports unhealthy if there is at least 1 outstanding request continuously - // for longer than this - MaxTimeSinceNoOutstandingRequests time.Duration + // Reports unhealthy if there is a request outstanding for longer than this + MaxOutstandingDuration time.Duration // Reports unhealthy if there is at least 1 outstanding not processed // before this mark diff --git a/snow/networking/router/msg_manager.go 
b/snow/networking/router/msg_manager.go index 259e25b3a8b3..8bada7b6bf6d 100644 --- a/snow/networking/router/msg_manager.go +++ b/snow/networking/router/msg_manager.go @@ -4,16 +4,20 @@ package router import ( + "fmt" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/networking/tracker" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/timer" + "github.com/ava-labs/avalanchego/utils/wrappers" + "github.com/prometheus/client_golang/prometheus" ) const ( // DefaultMaxNonStakerPendingMsgs is the default number of messages that can be taken from - // the shared message pool by a single validator + // the shared message pool by a single node DefaultMaxNonStakerPendingMsgs uint32 = 20 // DefaultStakerPortion is the default portion of resources to reserve for stakers DefaultStakerPortion float64 = 0.375 @@ -37,11 +41,11 @@ type msgManager struct { vdrs validators.Set maxNonStakerPendingMsgs uint32 poolMessages, reservedMessages uint32 - stakerMsgPortion float64 msgTracker tracker.CountingTracker stakerCPUPortion float64 cpuTracker tracker.TimeTracker clock timer.Clock + metrics msgManagerMetrics } // NewMsgManager returns a new MsgManager @@ -60,10 +64,16 @@ func NewMsgManager( maxNonStakerPendingMsgs uint32, stakerMsgPortion, stakerCPUPortion float64, -) MsgManager { + metricsNamespace string, + metricsRegisterer prometheus.Registerer, +) (MsgManager, error) { // Number of messages reserved for stakers vs. 
non-stakers reservedMessages := uint32(stakerMsgPortion * float64(maxPendingMsgs)) poolMessages := maxPendingMsgs - reservedMessages + metrics := msgManagerMetrics{} + if err := metrics.initialize(metricsNamespace, metricsRegisterer); err != nil { + return nil, err + } return &msgManager{ vdrs: vdrs, @@ -74,7 +84,8 @@ func NewMsgManager( poolMessages: poolMessages, maxNonStakerPendingMsgs: maxNonStakerPendingMsgs, stakerCPUPortion: stakerCPUPortion, - } + metrics: metrics, + }, nil } // AddPending marks that there is a message from [vdr] ready to be processed. @@ -83,7 +94,14 @@ func (rm *msgManager) AddPending(vdr ids.ShortID) bool { // Attempt to take the message from the pool outstandingPoolMessages := rm.msgTracker.PoolCount() totalPeerMessages, peerPoolMessages := rm.msgTracker.OutstandingCount(vdr) - if outstandingPoolMessages < rm.poolMessages && peerPoolMessages < rm.maxNonStakerPendingMsgs { + + rm.metrics.poolMsgsAvailable.Set(float64(rm.poolMessages - outstandingPoolMessages)) + // True if all the messages in the at-large message pool have been used + poolEmpty := outstandingPoolMessages >= rm.poolMessages + // True if this node has used the maximum number of messages from the at-large message pool + poolAllocUsed := peerPoolMessages >= rm.maxNonStakerPendingMsgs + if !poolEmpty && !poolAllocUsed { + // This node can use a message from the at-large message pool rm.msgTracker.AddPool(vdr) return true } @@ -91,6 +109,11 @@ func (rm *msgManager) AddPending(vdr ids.ShortID) bool { // Attempt to take the message from the individual allotment weight, isStaker := rm.vdrs.GetWeight(vdr) if !isStaker { + if poolEmpty { + rm.metrics.throttledPoolEmpty.Inc() + } else if poolAllocUsed { + rm.metrics.throttledPoolAllocExhausted.Inc() + } rm.log.Verbo("Throttling message from non-staker %s. 
%d/%d.", vdr, peerPoolMessages, rm.poolMessages) return false } @@ -104,6 +127,7 @@ func (rm *msgManager) AddPending(vdr ids.ShortID) bool { rm.msgTracker.Add(vdr) return true } + rm.metrics.throttledVdrAllocExhausted.Inc() rm.log.Debug("Throttling message from staker %s. %d/%d. %d/%d.", vdr, messageCount, messageAllotment, peerPoolMessages, rm.poolMessages) return false @@ -133,3 +157,50 @@ func (rm *msgManager) Utilization(vdr ids.ShortID) float64 { return vdrUtilization / stakerAllotment } + +type msgManagerMetrics struct { + poolMsgsAvailable prometheus.Gauge + throttledPoolEmpty, + throttledPoolAllocExhausted, + throttledVdrAllocExhausted prometheus.Counter +} + +func (m *msgManagerMetrics) initialize(namespace string, registerer prometheus.Registerer) error { + errs := wrappers.Errs{} + m.poolMsgsAvailable = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: namespace, + Name: "pool_msgs_available", + Help: "Number of available messages in the at-large pending message pool", + }) + if err := registerer.Register(m.poolMsgsAvailable); err != nil { + errs.Add(fmt.Errorf("failed to register throttled statistics due to %w", err)) + } + + m.throttledPoolEmpty = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: namespace, + Name: "throttled_pool_empty", + Help: "Number of incoming messages dropped because at-large pending message pool is empty", + }) + if err := registerer.Register(m.throttledPoolEmpty); err != nil { + errs.Add(fmt.Errorf("failed to register throttled statistics due to %w", err)) + } + + m.throttledPoolAllocExhausted = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: namespace, + Name: "throttled_pool_alloc_exhausted", + Help: "Number of incoming messages dropped because a non-validator used the max number of messages from the at-large pool", + }) + if err := registerer.Register(m.throttledPoolAllocExhausted); err != nil { + errs.Add(fmt.Errorf("failed to register throttled statistics due to %w", err)) + } + + 
m.throttledVdrAllocExhausted = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: namespace, + Name: "throttled_validator_alloc_exhausted", + Help: "Number of incoming messages dropped because a validator used the max number of pending messages allocated to them", + }) + if err := registerer.Register(m.throttledVdrAllocExhausted); err != nil { + errs.Add(fmt.Errorf("failed to register throttled statistics due to %w", err)) + } + return errs.Err +} diff --git a/snow/networking/router/msg_manager_test.go b/snow/networking/router/msg_manager_test.go index 9b23f7165b2e..d4d5ddd01c86 100644 --- a/snow/networking/router/msg_manager_test.go +++ b/snow/networking/router/msg_manager_test.go @@ -12,6 +12,8 @@ import ( "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/uptime" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" ) func TestAddPending(t *testing.T) { @@ -29,7 +31,7 @@ func TestAddPending(t *testing.T) { if err := vdrs.Set(vdrList); err != nil { t.Fatal(err) } - resourceManager := NewMsgManager( + resourceManager, err := NewMsgManager( vdrs, logging.NoLog{}, msgTracker, @@ -38,7 +40,10 @@ func TestAddPending(t *testing.T) { 1, // Allow each peer to take at most one message from pool 0.5, // Allot half of message queue to stakers 0.5, // Allot half of CPU time to stakers + "", + prometheus.NewRegistry(), ) + assert.NoError(t, err) for i, vdr := range vdrList { if success := resourceManager.AddPending(vdr.ID()); !success { @@ -75,7 +80,7 @@ func TestStakerGetsThrottled(t *testing.T) { if err := vdrs.Set(vdrList); err != nil { t.Fatal(err) } - resourceManager := NewMsgManager( + resourceManager, err := NewMsgManager( vdrs, logging.NoLog{}, msgTracker, @@ -84,7 +89,10 @@ func TestStakerGetsThrottled(t *testing.T) { 1, // Allow each peer to take at most one message from pool 0.5, // Allot half of message queue to stakers 0.5, // Allot 
half of CPU time to stakers + "", + prometheus.NewRegistry(), ) + assert.NoError(t, err) // Ensure that a staker with only part of the stake // cannot take up the entire message queue diff --git a/snow/networking/router/service_queue.go b/snow/networking/router/service_queue.go index 2f9123c76739..3c81bdd8b360 100644 --- a/snow/networking/router/service_queue.go +++ b/snow/networking/router/service_queue.go @@ -41,7 +41,6 @@ type multiLevelQueue struct { queues []singleLevelQueue cpuRanges []float64 // CPU Utilization ranges that should be attributed to a corresponding queue cpuAllotments []time.Duration // Allotments of CPU time per cycle that should be spent on each level of queue - cpuPortion float64 // Message throttling maxPendingMsgs uint32 @@ -50,7 +49,7 @@ type multiLevelQueue struct { semaChan chan struct{} log logging.Logger - metrics *metrics + metrics *handlerMetrics } // newMultiLevelQueue creates a new MultilevelQueue and counting semaphore for signaling when messages are available @@ -63,7 +62,7 @@ func newMultiLevelQueue( consumptionAllotments []time.Duration, maxPendingMsgs uint32, log logging.Logger, - metrics *metrics, + metrics *handlerMetrics, ) (messageQueue, chan struct{}) { semaChan := make(chan struct{}, maxPendingMsgs) singleLevelSize := int(maxPendingMsgs) / len(consumptionRanges) @@ -209,7 +208,6 @@ func (ml *multiLevelQueue) pushMessage(msg message) bool { processing := ml.msgManager.AddPending(msg.validatorID) if !processing { ml.metrics.dropped.Inc() - ml.metrics.throttled.Inc() return false } diff --git a/snow/networking/router/service_queue_test.go b/snow/networking/router/service_queue_test.go index bde23104ba39..1926d0048270 100644 --- a/snow/networking/router/service_queue_test.go +++ b/snow/networking/router/service_queue_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "github.com/ava-labs/avalanchego/ids" 
"github.com/ava-labs/avalanchego/snow/networking/tracker" @@ -19,7 +20,7 @@ import ( // returns a new multi-level queue that will never throttle or prioritize func setupMultiLevelQueue(t *testing.T, bufferSize uint32) (messageQueue, chan struct{}) { - metrics := &metrics{} + metrics := &handlerMetrics{} if err := metrics.Initialize("", prometheus.NewRegistry()); err != nil { t.Fatal(err) } @@ -134,7 +135,7 @@ func TestMultiLevelQueuePrioritizes(t *testing.T) { t.Fatal(err) } - metrics := &metrics{} + metrics := &handlerMetrics{} if err := metrics.Initialize("", prometheus.NewRegistry()); err != nil { t.Fatal(err) } @@ -159,7 +160,7 @@ func TestMultiLevelQueuePrioritizes(t *testing.T) { cpuTracker := tracker.NewCPUTracker(uptime.IntervalFactory{}, time.Second) msgTracker := tracker.NewMessageTracker() - resourceManager := NewMsgManager( + resourceManager, err := NewMsgManager( vdrs, logging.NoLog{}, msgTracker, @@ -168,7 +169,11 @@ func TestMultiLevelQueuePrioritizes(t *testing.T) { DefaultMaxNonStakerPendingMsgs, DefaultStakerPortion, DefaultStakerPortion, + "", + prometheus.NewRegistry(), ) + assert.NoError(t, err) + queue, semaChan := newMultiLevelQueue( resourceManager, consumptionRanges, @@ -247,7 +252,7 @@ func TestMultiLevelQueuePushesDownOldMessages(t *testing.T) { t.Fatal(err) } - metrics := &metrics{} + metrics := &handlerMetrics{} if err := metrics.Initialize("", prometheus.NewRegistry()); err != nil { t.Fatal(err) } @@ -272,7 +277,7 @@ func TestMultiLevelQueuePushesDownOldMessages(t *testing.T) { cpuTracker := tracker.NewCPUTracker(uptime.IntervalFactory{}, time.Second) msgTracker := tracker.NewMessageTracker() - resourceManager := NewMsgManager( + resourceManager, err := NewMsgManager( vdrs, logging.NoLog{}, msgTracker, @@ -281,7 +286,11 @@ func TestMultiLevelQueuePushesDownOldMessages(t *testing.T) { DefaultMaxNonStakerPendingMsgs, DefaultStakerPortion, DefaultStakerPortion, + "", + prometheus.NewRegistry(), ) + assert.NoError(t, err) + queue, semaChan 
:= newMultiLevelQueue( resourceManager, consumptionRanges, @@ -347,7 +356,7 @@ func TestMultiLevelQueueFreesSpace(t *testing.T) { t.Fatal(err) } - metrics := &metrics{} + metrics := &handlerMetrics{} if err := metrics.Initialize("", prometheus.NewRegistry()); err != nil { t.Fatal(err) } @@ -375,7 +384,7 @@ func TestMultiLevelQueueFreesSpace(t *testing.T) { cpuTracker := tracker.NewCPUTracker(uptime.IntervalFactory{}, time.Second) msgTracker := tracker.NewMessageTracker() - resourceManager := NewMsgManager( + resourceManager, err := NewMsgManager( vdrs, logging.NoLog{}, msgTracker, @@ -384,7 +393,11 @@ func TestMultiLevelQueueFreesSpace(t *testing.T) { DefaultMaxNonStakerPendingMsgs, DefaultStakerPortion, DefaultStakerPortion, + "", + prometheus.NewRegistry(), ) + assert.NoError(t, err) + queue, semaChan := newMultiLevelQueue( resourceManager, consumptionRanges, @@ -445,7 +458,7 @@ func TestMultiLevelQueueThrottles(t *testing.T) { t.Fatal(err) } - metrics := &metrics{} + metrics := &handlerMetrics{} if err := metrics.Initialize("", prometheus.NewRegistry()); err != nil { t.Fatal(err) } diff --git a/snow/networking/sender/sender_test.go b/snow/networking/sender/sender_test.go index e562274d1187..633e5dcdb7ad 100644 --- a/snow/networking/sender/sender_test.go +++ b/snow/networking/sender/sender_test.go @@ -85,7 +85,7 @@ func TestTimeout(t *testing.T) { } handler := router.Handler{} - handler.Initialize( + err = handler.Initialize( &engine, vdrs, nil, @@ -97,6 +97,8 @@ func TestTimeout(t *testing.T) { prometheus.NewRegistry(), &router.Delay{}, ) + assert.NoError(t, err) + go handler.Dispatch() chainRouter.AddChain(&handler) @@ -159,7 +161,7 @@ func TestReliableMessages(t *testing.T) { } handler := router.Handler{} - handler.Initialize( + err = handler.Initialize( &engine, vdrs, nil, @@ -171,6 +173,8 @@ func TestReliableMessages(t *testing.T) { prometheus.NewRegistry(), &router.Delay{}, ) + assert.NoError(t, err) + go handler.Dispatch() chainRouter.AddChain(&handler) @@ 
-242,7 +246,7 @@ func TestReliableMessagesToMyself(t *testing.T) { } handler := router.Handler{} - handler.Initialize( + err = handler.Initialize( &engine, vdrs, nil, @@ -254,6 +258,8 @@ func TestReliableMessagesToMyself(t *testing.T) { prometheus.NewRegistry(), &router.Delay{}, ) + assert.NoError(t, err) + go handler.Dispatch() chainRouter.AddChain(&handler) diff --git a/snow/triggers/dispatcher.go b/snow/triggers/dispatcher.go index f24cb4e102f7..7ace1028fe29 100644 --- a/snow/triggers/dispatcher.go +++ b/snow/triggers/dispatcher.go @@ -12,23 +12,36 @@ import ( "github.com/ava-labs/avalanchego/utils/logging" ) +var _ snow.EventDispatcher = &EventDispatcher{} + +type handler struct { + // Must implement at least one of Acceptor, Rejector, Issuer + handlerFunc interface{} + // If true and [handlerFunc] returns an error during a call to Accept, + // the chain this handler corresponds to will stop. + dieOnError bool +} + // EventDispatcher receives events from consensus and dispatches the events to triggers type EventDispatcher struct { - lock sync.Mutex - log logging.Logger - chainHandlers map[ids.ID]map[string]interface{} + lock sync.Mutex + log logging.Logger + // Chain ID --> Identifier --> handler + chainHandlers map[ids.ID]map[string]handler handlers map[string]interface{} } // Initialize creates the EventDispatcher's initial values func (ed *EventDispatcher) Initialize(log logging.Logger) { ed.log = log - ed.chainHandlers = make(map[ids.ID]map[string]interface{}) + ed.chainHandlers = make(map[ids.ID]map[string]handler) ed.handlers = make(map[string]interface{}) } -// Accept is called when a transaction or block is accepted -func (ed *EventDispatcher) Accept(ctx *snow.Context, containerID ids.ID, container []byte) { +// Accept is called when a transaction or block is accepted. +// If the returned error is non-nil, the chain associated with [ctx] should shut +// down and not commit [container] or any other container to its database as accepted. 
+func (ed *EventDispatcher) Accept(ctx *snow.Context, containerID ids.ID, container []byte) error { ed.lock.Lock() defer ed.lock.Unlock() @@ -39,28 +52,32 @@ func (ed *EventDispatcher) Accept(ctx *snow.Context, containerID ids.ID, contain } if err := handler.Accept(ctx, containerID, container); err != nil { - ed.log.Error("unable to Accept on %s for chainID %s: %s", id, ctx.ChainID, err) + ed.log.Error("handler %s on chain %s errored while accepting %s: %s", id, ctx.ChainID, containerID, err) } } events, exist := ed.chainHandlers[ctx.ChainID] if !exist { - return + return nil } for id, handler := range events { - handler, ok := handler.(Acceptor) + handlerFunc, ok := handler.handlerFunc.(Acceptor) if !ok { continue } - if err := handler.Accept(ctx, containerID, container); err != nil { - ed.log.Error("unable to Accept on %s for chainID %s: %s", id, ctx.ChainID, err) + if err := handlerFunc.Accept(ctx, containerID, container); err != nil { + ed.log.Error("handler %s on chain %s errored while accepting %s: %s", id, ctx.ChainID, containerID, err) + if handler.dieOnError { + return fmt.Errorf("handler %s on chain %s errored while accepting %s: %w", id, ctx.ChainID, containerID, err) + } } } + return nil } // Reject is called when a transaction or block is rejected -func (ed *EventDispatcher) Reject(ctx *snow.Context, containerID ids.ID, container []byte) { +func (ed *EventDispatcher) Reject(ctx *snow.Context, containerID ids.ID, container []byte) error { ed.lock.Lock() defer ed.lock.Unlock() @@ -77,10 +94,10 @@ func (ed *EventDispatcher) Reject(ctx *snow.Context, containerID ids.ID, contain events, exist := ed.chainHandlers[ctx.ChainID] if !exist { - return + return nil } for id, handler := range events { - handler, ok := handler.(Rejector) + handler, ok := handler.handlerFunc.(Rejector) if !ok { continue } @@ -89,10 +106,11 @@ func (ed *EventDispatcher) Reject(ctx *snow.Context, containerID ids.ID, contain ed.log.Error("unable to Reject on %s for chainID %s: %s", id, 
ctx.ChainID, err) } } + return nil } // Issue is called when a transaction or block is issued -func (ed *EventDispatcher) Issue(ctx *snow.Context, containerID ids.ID, container []byte) { +func (ed *EventDispatcher) Issue(ctx *snow.Context, containerID ids.ID, container []byte) error { ed.lock.Lock() defer ed.lock.Unlock() @@ -109,10 +127,10 @@ func (ed *EventDispatcher) Issue(ctx *snow.Context, containerID ids.ID, containe events, exist := ed.chainHandlers[ctx.ChainID] if !exist { - return + return nil } for id, handler := range events { - handler, ok := handler.(Issuer) + handler, ok := handler.handlerFunc.(Issuer) if !ok { continue } @@ -121,16 +139,19 @@ func (ed *EventDispatcher) Issue(ctx *snow.Context, containerID ids.ID, containe ed.log.Error("unable to Issue on %s for chainID %s: %s", id, ctx.ChainID, err) } } + return nil } -// RegisterChain places a new chain handler into the system -func (ed *EventDispatcher) RegisterChain(chainID ids.ID, identifier string, handler interface{}) error { +// RegisterChain causes [handlerFunc] to be invoked every time a container is issued, accepted or rejected on chain [chainID]. +// [handlerFunc] should implement at least one of Acceptor, Rejector, Issuer. +// If [dieOnError], chain [chainID] stops if [handler].Accept is invoked and returns a non-nil error. 
+func (ed *EventDispatcher) RegisterChain(chainID ids.ID, identifier string, handlerFunc interface{}, dieOnError bool) error { ed.lock.Lock() defer ed.lock.Unlock() events, exist := ed.chainHandlers[chainID] if !exist { - events = make(map[string]interface{}) + events = make(map[string]handler) ed.chainHandlers[chainID] = events } @@ -138,7 +159,10 @@ func (ed *EventDispatcher) RegisterChain(chainID ids.ID, identifier string, hand return fmt.Errorf("handler %s already exists on chain %s", identifier, chainID) } - events[identifier] = handler + events[identifier] = handler{ + handlerFunc: handlerFunc, + dieOnError: dieOnError, + } return nil } diff --git a/utils/bytes.go b/utils/bytes.go index 789b881df2e9..b39c1a3635e4 100644 --- a/utils/bytes.go +++ b/utils/bytes.go @@ -3,6 +3,8 @@ package utils +import "crypto/rand" + // CopyBytes returns a copy of the provided byte slice. If nil is provided, nil // will be returned. func CopyBytes(b []byte) []byte { @@ -14,3 +16,11 @@ func CopyBytes(b []byte) []byte { copy(cb, b) return cb } + +// RandomBytes returns a slice of n random bytes +// Intended for use in testing +func RandomBytes(n int) []byte { + b := make([]byte, n) + _, _ = rand.Read(b) + return b +} diff --git a/utils/logging/log.go b/utils/logging/log.go index 0212d8eff0f1..9e7bfc6fb3ad 100644 --- a/utils/logging/log.go +++ b/utils/logging/log.go @@ -363,7 +363,7 @@ func (fw *fileWriter) Close() error { func (fw *fileWriter) Rotate() error { fw.fileIndex = (fw.fileIndex + 1) % fw.config.RotationSize - writer, file, err := fw.create(fw.fileIndex) + writer, file, err := fw.create() if err != nil { return err } @@ -372,7 +372,7 @@ func (fw *fileWriter) Rotate() error { return nil } -func (fw *fileWriter) create(fileIndex int) (*bufio.Writer, *os.File, error) { +func (fw *fileWriter) create() (*bufio.Writer, *os.File, error) { filename := filepath.Join(fw.config.Directory, fmt.Sprintf("%d.log", fw.fileIndex)) file, err := perms.Create(filename, perms.ReadWrite) if 
err != nil { @@ -384,7 +384,7 @@ func (fw *fileWriter) create(fileIndex int) (*bufio.Writer, *os.File, error) { func (fw *fileWriter) Initialize(config Config) error { fw.config = config - writer, file, err := fw.create(fw.fileIndex) + writer, file, err := fw.create() if err != nil { return err } diff --git a/utils/perms/perms.go b/utils/perms/perms.go index 882cdf3dcbee..f5a9ed45e44a 100644 --- a/utils/perms/perms.go +++ b/utils/perms/perms.go @@ -5,6 +5,6 @@ package perms const ( ReadOnly = 0400 - ReadWrite = 0600 - ReadWriteExecute = 0700 + ReadWrite = 0640 + ReadWriteExecute = 0750 ) diff --git a/utils/wrappers/packing.go b/utils/wrappers/packing.go index 98a14ea389da..0a7afedf0014 100644 --- a/utils/wrappers/packing.go +++ b/utils/wrappers/packing.go @@ -6,6 +6,7 @@ package wrappers import ( "encoding/binary" "errors" + "fmt" "math" "github.com/ava-labs/avalanchego/utils" @@ -502,3 +503,19 @@ func TryPackIPList(packer *Packer, valIntf interface{}) { func TryUnpackIPList(packer *Packer) interface{} { return packer.UnpackIPs() } + +// PackLong returns the byte representation of a uint64 +func PackLong(val uint64) []byte { + bytes := make([]byte, 8) + binary.BigEndian.PutUint64(bytes, val) + return bytes +} + +// UnpackLong returns the uint64 represented by bytes. 
+// Returns an error if len(bytes) != 8 +func UnpackLong(bytes []byte) (uint64, error) { + if len(bytes) != 8 { + return 0, fmt.Errorf("expected len(bytes) to be 8 but is %d", len(bytes)) + } + return binary.BigEndian.Uint64(bytes), nil +} diff --git a/utils/wrappers/packing_test.go b/utils/wrappers/packing_test.go index faaa8d136b20..e598f2b66664 100644 --- a/utils/wrappers/packing_test.go +++ b/utils/wrappers/packing_test.go @@ -5,8 +5,11 @@ package wrappers import ( "bytes" + "math" "reflect" "testing" + + "github.com/stretchr/testify/assert" ) const ( @@ -581,3 +584,30 @@ func TestPacker2DByteSlice(t *testing.T) { t.Fatal("should match") } } + +func TestPackLong(t *testing.T) { + for _, n := range []uint64{0, 10000, math.MaxUint64} { + bytes := PackLong(n) + got, err := UnpackLong(bytes) + assert.NoError(t, err) + assert.Equal(t, n, got) + } +} + +func TestUnpackLong(t *testing.T) { + // Too few bytes + bytes := make([]byte, 7) + _, err := UnpackLong(bytes) + assert.Error(t, err) + + // Too many bytes + bytes = make([]byte, 9) + _, err = UnpackLong(bytes) + assert.Error(t, err) + + // Right number of bytes + bytes = make([]byte, 8) + n, err := UnpackLong(bytes) + assert.NoError(t, err) + assert.EqualValues(t, 0, n) +} diff --git a/version/compatibility.go b/version/compatibility.go new file mode 100644 index 000000000000..445a3d43eaf9 --- /dev/null +++ b/version/compatibility.go @@ -0,0 +1,148 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package version + +import ( + "errors" + "time" + + "github.com/ava-labs/avalanchego/utils/timer" +) + +var ( + errIncompatible = errors.New("peers version is incompatible") + errMaskable = errors.New("peers version is maskable") +) + +// Compatibility is a utility for checking the compatibility of peer versions +type Compatibility interface { + // Returns the local version + Version() Version + + // Returns nil if the provided version is able to connect with the local + // version. This means that the node will keep connections open with the + // peer. + Connectable(Version) error + + // Returns nil if the provided version is compatible with the local version. + // This means that the version is connectable and that consensus messages + // can be made to them. + Compatible(Version) error + + // Returns nil if the provided version shouldn't be masked. This means that + // the version is connectable but not compatible. The version is so old that + // it should just be masked. + Unmaskable(Version) error + + // Returns nil if the provided version will not be masked by this version. + WontMask(Version) error + + // Returns when additional masking will occur. 
+ MaskTime() time.Time +} + +type compatibility struct { + version Version + + minCompatable Version + minCompatableTime time.Time + prevMinCompatable Version + + minUnmaskable Version + minUnmaskableTime time.Time + prevMinUnmaskable Version + + clock timer.Clock +} + +// NewCompatibility returns a compatibility checker with the provided options +func NewCompatibility( + version Version, + minCompatable Version, + minCompatableTime time.Time, + prevMinCompatable Version, + minUnmaskable Version, + minUnmaskableTime time.Time, + prevMinUnmaskable Version, +) Compatibility { + return &compatibility{ + version: version, + minCompatable: minCompatable, + minCompatableTime: minCompatableTime, + prevMinCompatable: prevMinCompatable, + minUnmaskable: minUnmaskable, + minUnmaskableTime: minUnmaskableTime, + prevMinUnmaskable: prevMinUnmaskable, + } +} + +func (c *compatibility) Version() Version { return c.version } + +func (c *compatibility) Connectable(peer Version) error { + return c.version.Compatible(peer) +} + +func (c *compatibility) Compatible(peer Version) error { + if err := c.Connectable(peer); err != nil { + return err + } + + if !peer.Before(c.minCompatable) { + // The peer is at least the minimum compatible version. + return nil + } + + // The peer is going to be marked as incompatible at [c.minCompatableTime]. + now := c.clock.Time() + if !now.Before(c.minCompatableTime) { + return errIncompatible + } + + // The minCompatable check isn't being enforced yet. + if !peer.Before(c.prevMinCompatable) { + // The peer is at least the previous minimum compatible version. + return nil + } + return errIncompatible +} + +func (c *compatibility) Unmaskable(peer Version) error { + if err := c.Connectable(peer); err != nil { + return err + } + + if !peer.Before(c.minUnmaskable) { + // The peer is at least the minimum unmaskable version. + return nil + } + + // The peer is going to be marked as maskable at [c.minUnmaskableTime]. 
+ now := c.clock.Time() + if !now.Before(c.minUnmaskableTime) { + return errMaskable + } + + // The minUnmaskable check isn't being enforced yet. + if !peer.Before(c.prevMinUnmaskable) { + // The peer is at least the previous minimum unmaskable version. + return nil + } + return errMaskable +} + +func (c *compatibility) WontMask(peer Version) error { + if err := c.Connectable(peer); err != nil { + return err + } + + if !peer.Before(c.minUnmaskable) { + // The peer is at least the minimum unmaskable version. + return nil + } + return errMaskable +} + +func (c *compatibility) MaskTime() time.Time { + return c.minUnmaskableTime +} diff --git a/version/compatibility_test.go b/version/compatibility_test.go new file mode 100644 index 000000000000..b21e94fde691 --- /dev/null +++ b/version/compatibility_test.go @@ -0,0 +1,142 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package version + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCompatibility(t *testing.T) { + v := NewDefaultVersion("avalanche", 1, 4, 3) + minCompatable := NewDefaultVersion("avalanche", 1, 4, 0) + minCompatableTime := time.Unix(9000, 0) + prevMinCompatable := NewDefaultVersion("avalanche", 1, 3, 0) + minUnmaskable := NewDefaultVersion("avalanche", 1, 2, 0) + minUnmaskableTime := time.Unix(7000, 0) + prevMinUnmaskable := NewDefaultVersion("avalanche", 1, 1, 0) + + compatibility := NewCompatibility( + v, + minCompatable, + minCompatableTime, + prevMinCompatable, + minUnmaskable, + minUnmaskableTime, + prevMinUnmaskable, + ).(*compatibility) + assert.Equal(t, v, compatibility.Version()) + assert.Equal(t, minUnmaskableTime, compatibility.MaskTime()) + + tests := []struct { + peer Version + time time.Time + connectable bool + compatible bool + unmaskable bool + wontMask bool + }{ + { + peer: NewDefaultVersion("avalanche", 1, 5, 0), + time: minCompatableTime, + connectable: true, + compatible: true, 
+ unmaskable: true, + wontMask: true, + }, + { + peer: NewDefaultVersion("avalanche", 1, 3, 5), + time: time.Unix(8500, 0), + connectable: true, + compatible: true, + unmaskable: true, + wontMask: true, + }, + { + peer: NewDefaultVersion("ava", 1, 5, 0), + time: minCompatableTime, + connectable: false, + compatible: false, + unmaskable: false, + wontMask: false, + }, + { + peer: NewDefaultVersion("avalanche", 0, 1, 0), + time: minCompatableTime, + connectable: false, + compatible: false, + unmaskable: false, + wontMask: false, + }, + { + peer: NewDefaultVersion("avalanche", 1, 3, 5), + time: minCompatableTime, + connectable: true, + compatible: false, + unmaskable: true, + wontMask: true, + }, + { + peer: NewDefaultVersion("avalanche", 1, 2, 5), + time: time.Unix(8500, 0), + connectable: true, + compatible: false, + unmaskable: true, + wontMask: true, + }, + { + peer: NewDefaultVersion("avalanche", 1, 1, 5), + time: time.Unix(7500, 0), + connectable: true, + compatible: false, + unmaskable: false, + wontMask: false, + }, + { + peer: NewDefaultVersion("avalanche", 1, 1, 5), + time: time.Unix(6500, 0), + connectable: true, + compatible: false, + unmaskable: true, + wontMask: false, + }, + { + peer: NewDefaultVersion("avalanche", 1, 0, 5), + time: time.Unix(6500, 0), + connectable: true, + compatible: false, + unmaskable: false, + wontMask: false, + }, + } + for _, test := range tests { + peer := test.peer + compatibility.clock.Set(test.time) + t.Run(fmt.Sprintf("%s-%s", peer, test.time), func(t *testing.T) { + if err := compatibility.Connectable(peer); test.connectable && err != nil { + t.Fatalf("incorrectly marked %s as un-connectable with %s", peer, err) + } else if !test.connectable && err == nil { + t.Fatalf("incorrectly marked %s as connectable", peer) + } + if err := compatibility.Compatible(peer); test.compatible && err != nil { + t.Fatalf("incorrectly marked %s as incompatible with %s", peer, err) + } else if !test.compatible && err == nil { + 
t.Fatalf("incorrectly marked %s as compatible", peer) + } + if err := compatibility.Unmaskable(peer); test.unmaskable && err != nil { + t.Fatalf("incorrectly marked %s as un-maskable with %s", peer, err) + } else if !test.unmaskable && err == nil { + t.Fatalf("incorrectly marked %s as maskable", peer) + } + if err := compatibility.WontMask(peer); test.wontMask && err != nil { + t.Fatalf("incorrectly marked %s as unmaskable with %s", peer, err) + } else if !test.wontMask && err == nil { + t.Fatalf("incorrectly marked %s as maskable", peer) + } + }) + } +} diff --git a/vms/avm/base_tx_test.go b/vms/avm/base_tx_test.go index 419ff8b9096f..b4a222198655 100644 --- a/vms/avm/base_tx_test.go +++ b/vms/avm/base_tx_test.go @@ -1133,7 +1133,7 @@ func TestBaseTxSemanticVerifyPendingInvalidUTXO(t *testing.T) { ctx.Lock.Unlock() }() - vm.Pending() + vm.PendingTxs() tx := &Tx{UnsignedTx: &BaseTx{BaseTx: avax.BaseTx{ NetworkID: networkID, @@ -1228,7 +1228,7 @@ func TestBaseTxSemanticVerifyPendingWrongAssetID(t *testing.T) { ctx.Lock.Unlock() }() - vm.Pending() + vm.PendingTxs() tx := &Tx{UnsignedTx: &BaseTx{BaseTx: avax.BaseTx{ NetworkID: networkID, @@ -1366,7 +1366,7 @@ func TestBaseTxSemanticVerifyPendingUnauthorizedFx(t *testing.T) { ctx.Lock.Unlock() }() - vm.Pending() + vm.PendingTxs() tx := &Tx{ UnsignedTx: &BaseTx{BaseTx: avax.BaseTx{ @@ -1508,7 +1508,7 @@ func TestBaseTxSemanticVerifyPendingInvalidSignature(t *testing.T) { ctx.Lock.Unlock() }() - vm.Pending() + vm.PendingTxs() tx := &Tx{ UnsignedTx: &BaseTx{BaseTx: avax.BaseTx{ diff --git a/vms/avm/export_tx_test.go b/vms/avm/export_tx_test.go index cdbe28bce14e..45ac32117d43 100644 --- a/vms/avm/export_tx_test.go +++ b/vms/avm/export_tx_test.go @@ -774,7 +774,7 @@ func TestExportTxSemanticVerify(t *testing.T) { t.Fatal(err) } - tx, err := vm.Parse(rawTx.Bytes()) + tx, err := vm.ParseTx(rawTx.Bytes()) if err != nil { t.Fatal(err) } @@ -834,7 +834,7 @@ func TestExportTxSemanticVerifyUnknownCredFx(t *testing.T) { 
t.Fatal(err) } - tx, err := vm.Parse(rawTx.Bytes()) + tx, err := vm.ParseTx(rawTx.Bytes()) if err != nil { t.Fatal(err) } @@ -894,7 +894,7 @@ func TestExportTxSemanticVerifyMissingUTXO(t *testing.T) { t.Fatal(err) } - tx, err := vm.Parse(rawTx.Bytes()) + tx, err := vm.ParseTx(rawTx.Bytes()) if err != nil { t.Fatal(err) } @@ -978,7 +978,7 @@ func TestExportTxSemanticVerifyInvalidAssetID(t *testing.T) { t.Fatal(err) } - tx, err := vm.Parse(rawTx.Bytes()) + tx, err := vm.ParseTx(rawTx.Bytes()) if err != nil { t.Fatal(err) } @@ -1008,11 +1008,8 @@ func TestExportTxSemanticVerifyInvalidFx(t *testing.T) { ctx.Lock.Lock() - userKeystore, err := keystore.CreateTestKeystore() - if err != nil { - t.Fatal(err) - } - if err := userKeystore.AddUser(username, password); err != nil { + userKeystore := keystore.New(logging.NoLog{}, memdb.New()) + if err := userKeystore.CreateUser(username, password); err != nil { t.Fatal(err) } ctx.Keystore = userKeystore.NewBlockchainKeyStore(ctx.ChainID) @@ -1097,7 +1094,7 @@ func TestExportTxSemanticVerifyInvalidFx(t *testing.T) { t.Fatal(err) } - tx, err := vm.Parse(rawTx.Bytes()) + tx, err := vm.ParseTx(rawTx.Bytes()) if err != nil { t.Fatal(err) } @@ -1157,7 +1154,7 @@ func TestExportTxSemanticVerifyInvalidTransfer(t *testing.T) { t.Fatal(err) } - tx, err := vm.Parse(rawTx.Bytes()) + tx, err := vm.ParseTx(rawTx.Bytes()) if err != nil { t.Fatal(err) } @@ -1269,7 +1266,7 @@ func TestIssueExportTx(t *testing.T) { ctx.Lock.Unlock() }() - txs := vm.Pending() + txs := vm.PendingTxs() if len(txs) != 1 { t.Fatalf("Should have returned %d tx(s)", 1) } @@ -1400,7 +1397,7 @@ func TestClearForceAcceptedExportTx(t *testing.T) { ctx.Lock.Unlock() }() - txs := vm.Pending() + txs := vm.PendingTxs() if len(txs) != 1 { t.Fatalf("Should have returned %d tx(s)", 1) } diff --git a/vms/avm/import_tx_test.go b/vms/avm/import_tx_test.go index 1a2a9f554e4c..0df8416ebd3c 100644 --- a/vms/avm/import_tx_test.go +++ b/vms/avm/import_tx_test.go @@ -343,7 +343,7 @@ func 
TestIssueImportTx(t *testing.T) { ctx.Lock.Unlock() }() - txs := vm.Pending() + txs := vm.PendingTxs() if len(txs) != 1 { t.Fatalf("Should have returned %d tx(s)", 1) } @@ -442,7 +442,7 @@ func TestForceAcceptImportTx(t *testing.T) { t.Fatal(err) } - parsedTx, err := vm.Parse(tx.Bytes()) + parsedTx, err := vm.ParseTx(tx.Bytes()) if err != nil { t.Fatal(err) } diff --git a/vms/avm/service_test.go b/vms/avm/service_test.go index cf7e347f5cca..074a036f5e78 100644 --- a/vms/avm/service_test.go +++ b/vms/avm/service_test.go @@ -14,12 +14,14 @@ import ( "github.com/ava-labs/avalanchego/api" "github.com/ava-labs/avalanchego/api/keystore" "github.com/ava-labs/avalanchego/chains/atomic" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/utils/formatting" "github.com/ava-labs/avalanchego/utils/json" + "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/sampler" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/secp256k1fx" @@ -37,11 +39,8 @@ var ( // 4) atomic memory to use in tests func setup(t *testing.T) ([]byte, *VM, *Service, *atomic.Memory) { genesisBytes, _, vm, m := GenesisVM(t) - keystore, err := keystore.CreateTestKeystore() - if err != nil { - t.Fatal(err) - } - if err := keystore.AddUser(username, password); err != nil { + keystore := keystore.New(logging.NoLog{}, memdb.New()) + if err := keystore.CreateUser(username, password); err != nil { t.Fatalf("couldn't add user: %s", err) } vm.ctx.Keystore = keystore.NewBlockchainKeyStore(chainID) @@ -1486,7 +1485,7 @@ func TestSendMultiple(t *testing.T) { t.Fatal("Transaction ID returned by SendMultiple does not match the transaction found in vm's pending transactions") } - if _, err = vm.Get(reply.TxID); err != nil { + if _, err = 
vm.GetTx(reply.TxID); err != nil { t.Fatalf("Failed to retrieve created transaction: %s", err) } } diff --git a/vms/avm/unique_tx.go b/vms/avm/unique_tx.go index f73c88c6fdd1..3b6174dd7ba4 100644 --- a/vms/avm/unique_tx.go +++ b/vms/avm/unique_tx.go @@ -145,6 +145,7 @@ func (tx *UniqueTx) Accept() error { } txID := tx.ID() + commitBatch, err := tx.vm.db.CommitBatch() if err != nil { tx.vm.ctx.Log.Error("Failed to calculate CommitBatch for %s due to %s", txID, err) @@ -159,6 +160,7 @@ func (tx *UniqueTx) Accept() error { tx.vm.ctx.Log.Verbo("Accepted Tx: %s", txID) tx.vm.pubsub.Publish("accepted", txID) + tx.vm.walletService.decided(txID) tx.deps = nil // Needed to prevent a memory leak diff --git a/vms/avm/user_state.go b/vms/avm/user_state.go index 228844e8fe14..b1fb432161b0 100644 --- a/vms/avm/user_state.go +++ b/vms/avm/user_state.go @@ -6,7 +6,7 @@ package avm import ( "fmt" - "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/encdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/vms/secp256k1fx" @@ -17,7 +17,7 @@ var addresses = ids.Empty type userState struct{ vm *VM } // SetAddresses ... -func (s *userState) SetAddresses(db database.Database, addrs []ids.ShortID) error { +func (s *userState) SetAddresses(db *encdb.Database, addrs []ids.ShortID) error { bytes, err := s.vm.codec.Marshal(codecVersion, addrs) if err != nil { return err @@ -26,7 +26,7 @@ func (s *userState) SetAddresses(db database.Database, addrs []ids.ShortID) erro } // Addresses ... -func (s *userState) Addresses(db database.Database) ([]ids.ShortID, error) { +func (s *userState) Addresses(db *encdb.Database) ([]ids.ShortID, error) { bytes, err := db.Get(addresses[:]) if err != nil { return nil, err @@ -43,7 +43,7 @@ func (s *userState) Addresses(db database.Database) ([]ids.ShortID, error) { // in addresses. If any key is missing, an error is returned. 
// If [addresses] is empty, then it will create a keychain using // every address in [db]. -func (s *userState) Keychain(db database.Database, addresses ids.ShortSet) (*secp256k1fx.Keychain, error) { +func (s *userState) Keychain(db *encdb.Database, addresses ids.ShortSet) (*secp256k1fx.Keychain, error) { kc := secp256k1fx.NewKeychain() addrsList := addresses.List() @@ -63,12 +63,12 @@ func (s *userState) Keychain(db database.Database, addresses ids.ShortSet) (*sec } // SetKey ... -func (s *userState) SetKey(db database.Database, sk *crypto.PrivateKeySECP256K1R) error { +func (s *userState) SetKey(db *encdb.Database, sk *crypto.PrivateKeySECP256K1R) error { return db.Put(sk.PublicKey().Address().Bytes(), sk.Bytes()) } // Key ... -func (s *userState) Key(db database.Database, address ids.ShortID) (*crypto.PrivateKeySECP256K1R, error) { +func (s *userState) Key(db *encdb.Database, address ids.ShortID) (*crypto.PrivateKeySECP256K1R, error) { factory := crypto.FactorySECP256K1R{} bytes, err := db.Get(address.Bytes()) diff --git a/vms/avm/vm.go b/vms/avm/vm.go index d99150b64684..fee0284be7d9 100644 --- a/vms/avm/vm.go +++ b/vms/avm/vm.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/rpc/v2" + "github.com/ava-labs/avalanchego/api/pubsub" "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/codec/linearcodec" @@ -61,7 +62,8 @@ var ( errBootstrapping = errors.New("chain is currently bootstrapping") errInsufficientFunds = errors.New("insufficient funds") - _ vertex.DAGVM = &VM{} + _ vertex.DAGVM = &VM{} + _ common.StaticVM = &VM{} ) // VM implements the avalanche.DAGVM interface @@ -79,7 +81,7 @@ type VM struct { codec codec.Manager codecRegistry codec.Registry - pubsub *cjson.PubSubServer + pubsub *pubsub.Server // State management state *prefixedState @@ -132,7 +134,7 @@ func (vm *VM) Initialize( vm.Aliaser.Initialize() vm.assetToFxCache = &cache.LRU{Size: assetToFxCacheSize} - vm.pubsub = 
cjson.NewPubSubServer(ctx) + vm.pubsub = pubsub.NewServer(ctx) genesisCodec := linearcodec.New(reflectcodec.DefaultTagName, 1<<20) c := linearcodec.NewDefault() @@ -299,22 +301,21 @@ func (vm *VM) CreateHandlers() (map[string]*common.HTTPHandler, error) { }, err } -// CreateStaticHandlers implements the avalanche.DAGVM interface -func (vm *VM) CreateStaticHandlers() map[string]*common.HTTPHandler { +// CreateStaticHandlers implements the common.StaticVM interface +func (vm *VM) CreateStaticHandlers() (map[string]*common.HTTPHandler, error) { newServer := rpc.NewServer() codec := cjson.NewCodec() newServer.RegisterCodec(codec, "application/json") newServer.RegisterCodec(codec, "application/json;charset=UTF-8") // name this service "avm" staticService := CreateStaticService() - _ = newServer.RegisterService(staticService, "avm") return map[string]*common.HTTPHandler{ "": {LockOptions: common.WriteLock, Handler: newServer}, - } + }, newServer.RegisterService(staticService, "avm") } // Pending implements the avalanche.DAGVM interface -func (vm *VM) Pending() []snowstorm.Tx { +func (vm *VM) PendingTxs() []snowstorm.Tx { vm.metrics.numPendingCalls.Inc() vm.timer.Cancel() @@ -325,14 +326,14 @@ func (vm *VM) Pending() []snowstorm.Tx { } // Parse implements the avalanche.DAGVM interface -func (vm *VM) Parse(b []byte) (snowstorm.Tx, error) { +func (vm *VM) ParseTx(b []byte) (snowstorm.Tx, error) { vm.metrics.numParseCalls.Inc() return vm.parseTx(b) } // Get implements the avalanche.DAGVM interface -func (vm *VM) Get(txID ids.ID) (snowstorm.Tx, error) { +func (vm *VM) GetTx(txID ids.ID) (snowstorm.Tx, error) { vm.metrics.numGetCalls.Inc() tx := &UniqueTx{ diff --git a/vms/avm/vm_test.go b/vms/avm/vm_test.go index 3859a4a7717c..fb11f24836f9 100644 --- a/vms/avm/vm_test.go +++ b/vms/avm/vm_test.go @@ -252,11 +252,8 @@ func GenesisVMWithArgs(tb testing.TB, args *BuildGenesisArgs) ([]byte, chan comm // The caller of this function is responsible for unlocking. 
ctx.Lock.Lock() - userKeystore, err := keystore.CreateTestKeystore() - if err != nil { - tb.Fatal(err) - } - if err := userKeystore.AddUser(username, password); err != nil { + userKeystore := keystore.New(logging.NoLog{}, memdb.New()) + if err := userKeystore.CreateUser(username, password); err != nil { tb.Fatal(err) } ctx.Keystore = userKeystore.NewBlockchainKeyStore(ctx.ChainID) @@ -600,7 +597,7 @@ func TestIssueTx(t *testing.T) { } ctx.Lock.Lock() - if txs := vm.Pending(); len(txs) != 1 { + if txs := vm.PendingTxs(); len(txs) != 1 { t.Fatalf("Should have returned %d tx(s)", 1) } } @@ -799,7 +796,7 @@ func TestIssueDependentTx(t *testing.T) { } ctx.Lock.Lock() - if txs := vm.Pending(); len(txs) != 2 { + if txs := vm.PendingTxs(); len(txs) != 2 { t.Fatalf("Should have returned %d tx(s)", 2) } } @@ -1154,7 +1151,7 @@ func TestTxCached(t *testing.T) { newTx := NewTx(t, genesisBytes, vm) txBytes := newTx.Bytes() - _, err := vm.Parse(txBytes) + _, err := vm.ParseTx(txBytes) assert.NoError(t, err) db := mockdb.New() @@ -1166,7 +1163,7 @@ func TestTxCached(t *testing.T) { vm.state.state.DB = db vm.state.state.Cache.Flush() - _, err = vm.Parse(txBytes) + _, err = vm.ParseTx(txBytes) assert.NoError(t, err) assert.False(t, *called, "shouldn't have called the DB") } @@ -1184,7 +1181,7 @@ func TestTxNotCached(t *testing.T) { newTx := NewTx(t, genesisBytes, vm) txBytes := newTx.Bytes() - _, err := vm.Parse(txBytes) + _, err := vm.ParseTx(txBytes) assert.NoError(t, err) db := mockdb.New() @@ -1198,7 +1195,7 @@ func TestTxNotCached(t *testing.T) { vm.state.uniqueTx.Flush() vm.state.state.Cache.Flush() - _, err = vm.Parse(txBytes) + _, err = vm.ParseTx(txBytes) assert.NoError(t, err) assert.True(t, *called, "should have called the DB") } @@ -1281,7 +1278,7 @@ func TestTxVerifyAfterIssueTx(t *testing.T) { t.Fatal(err) } - parsedSecondTx, err := vm.Parse(secondTx.Bytes()) + parsedSecondTx, err := vm.ParseTx(secondTx.Bytes()) if err != nil { t.Fatal(err) } @@ -1302,7 +1299,7 @@ 
func TestTxVerifyAfterIssueTx(t *testing.T) { } ctx.Lock.Lock() - txs := vm.Pending() + txs := vm.PendingTxs() if len(txs) != 1 { t.Fatalf("Should have returned %d tx(s)", 1) } @@ -1391,7 +1388,7 @@ func TestTxVerifyAfterGet(t *testing.T) { t.Fatal(err) } - parsedSecondTx, err := vm.Parse(secondTx.Bytes()) + parsedSecondTx, err := vm.ParseTx(secondTx.Bytes()) if err != nil { t.Fatal(err) } @@ -1401,7 +1398,7 @@ func TestTxVerifyAfterGet(t *testing.T) { if _, err := vm.IssueTx(firstTx.Bytes()); err != nil { t.Fatal(err) } - parsedFirstTx, err := vm.Get(firstTx.ID()) + parsedFirstTx, err := vm.GetTx(firstTx.ID()) if err != nil { t.Fatal(err) } @@ -1524,7 +1521,7 @@ func TestTxVerifyAfterVerifyAncestorTx(t *testing.T) { t.Fatal(err) } - parsedSecondTx, err := vm.Parse(secondTx.Bytes()) + parsedSecondTx, err := vm.ParseTx(secondTx.Bytes()) if err != nil { t.Fatal(err) } @@ -1537,7 +1534,7 @@ func TestTxVerifyAfterVerifyAncestorTx(t *testing.T) { if _, err := vm.IssueTx(firstTxDescendant.Bytes()); err != nil { t.Fatal(err) } - parsedFirstTx, err := vm.Get(firstTx.ID()) + parsedFirstTx, err := vm.GetTx(firstTx.ID()) if err != nil { t.Fatal(err) } diff --git a/vms/manager.go b/vms/manager.go index 59e1a7047c30..d82f9f4fd128 100644 --- a/vms/manager.go +++ b/vms/manager.go @@ -7,7 +7,7 @@ import ( "fmt" "sync" - "github.com/ava-labs/avalanchego/api" + "github.com/ava-labs/avalanchego/api/server" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/engine/common" @@ -59,13 +59,13 @@ type manager struct { // The node's API server. 
// [manager] adds routes to this server to expose new API endpoints/services - apiServer *api.Server + apiServer *server.Server log logging.Logger } // NewManager returns an instance of a VM manager -func NewManager(apiServer *api.Server, log logging.Logger) Manager { +func NewManager(apiServer *server.Server, log logging.Logger) Manager { m := &manager{ vmFactories: make(map[ids.ID]VMFactory), apiServer: apiServer, diff --git a/vms/platformvm/add_delegator_tx_test.go b/vms/platformvm/add_delegator_tx_test.go index c8885c4d37c5..01e923a5ca5f 100644 --- a/vms/platformvm/add_delegator_tx_test.go +++ b/vms/platformvm/add_delegator_tx_test.go @@ -362,3 +362,276 @@ func TestAddDelegatorTxSemanticVerify(t *testing.T) { }) } } + +func TestAddDelegatorTxOverDelegatedRegression(t *testing.T) { + vm, _ := defaultVM() + vm.Ctx.Lock.Lock() + defer func() { + if err := vm.Shutdown(); err != nil { + t.Fatal(err) + } + vm.Ctx.Lock.Unlock() + }() + + validatorStartTime := defaultGenesisTime.Add(syncBound).Add(1 * time.Second) + validatorEndTime := validatorStartTime.Add(360 * 24 * time.Hour) + key, err := vm.factory.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + id := key.PublicKey().Address() + + // create valid tx + addValidatorTx, err := vm.newAddValidatorTx( + vm.minValidatorStake, + uint64(validatorStartTime.Unix()), + uint64(validatorEndTime.Unix()), + id, + id, + PercentDenominator, + []*crypto.PrivateKeySECP256K1R{keys[0]}, + ids.ShortEmpty, // change addr + ) + if err != nil { + t.Fatal(err) + } + + // trigger block creation + if err := vm.mempool.IssueTx(addValidatorTx); err != nil { + t.Fatal(err) + } + addValidatorBlockIntf, err := vm.BuildBlock() + if err != nil { + t.Fatal(err) + } + + // Verify the proposed block + if err := addValidatorBlockIntf.Verify(); err != nil { + t.Fatal(err) + } + + // Assert preferences are correct + addValidatorBlock := addValidatorBlockIntf.(*ProposalBlock) + options, err := addValidatorBlock.Options() + if err != nil { + 
t.Fatal(err) + } + + // verify the commit block + commit := options[0].(*Commit) + if err := commit.Verify(); err != nil { + t.Fatal(err) + } + + // Accept the proposal block and the commit block + if err := addValidatorBlock.Accept(); err != nil { + t.Fatal(err) + } + if err := commit.Accept(); err != nil { + t.Fatal(err) + } + + vm.clock.Set(validatorStartTime) + + firstAdvanceTimeBlockIntf, err := vm.BuildBlock() + if err != nil { + t.Fatal(err) + } + + // Verify the proposed block + if err := firstAdvanceTimeBlockIntf.Verify(); err != nil { + t.Fatal(err) + } + + // Assert preferences are correct + firstAdvanceTimeBlock := firstAdvanceTimeBlockIntf.(*ProposalBlock) + options, err = firstAdvanceTimeBlock.Options() + if err != nil { + t.Fatal(err) + } + + // verify the commit block + commit = options[0].(*Commit) + if err := commit.Verify(); err != nil { + t.Fatal(err) + } + + // Accept the proposal block and the commit block + if err := firstAdvanceTimeBlock.Accept(); err != nil { + t.Fatal(err) + } + if err := commit.Accept(); err != nil { + t.Fatal(err) + } + + firstDelegatorStartTime := validatorStartTime.Add(syncBound).Add(1 * time.Second) + firstDelegatorEndTime := firstDelegatorStartTime.Add(vm.minStakeDuration) + + // create valid tx + addFirstDelegatorTx, err := vm.newAddDelegatorTx( + 4*vm.minValidatorStake, // maximum amount of stake this delegator can provide + uint64(firstDelegatorStartTime.Unix()), + uint64(firstDelegatorEndTime.Unix()), + id, + keys[0].PublicKey().Address(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, + ids.ShortEmpty, // change addr + ) + if err != nil { + t.Fatal(err) + } + + // trigger block creation + if err := vm.mempool.IssueTx(addFirstDelegatorTx); err != nil { + t.Fatal(err) + } + addFirstDelegatorBlockIntf, err := vm.BuildBlock() + if err != nil { + t.Fatal(err) + } + + // Verify the proposed block + if err := addFirstDelegatorBlockIntf.Verify(); err != nil { + t.Fatal(err) + } + + // Assert preferences are correct 
+ addFirstDelegatorBlock := addFirstDelegatorBlockIntf.(*ProposalBlock) + options, err = addFirstDelegatorBlock.Options() + if err != nil { + t.Fatal(err) + } + + // verify the commit block + commit = options[0].(*Commit) + if err := commit.Verify(); err != nil { + t.Fatal(err) + } + + // Accept the proposal block and the commit block + if err := addFirstDelegatorBlock.Accept(); err != nil { + t.Fatal(err) + } + if err := commit.Accept(); err != nil { + t.Fatal(err) + } + + vm.clock.Set(firstDelegatorStartTime) + + secondAdvanceTimeBlockIntf, err := vm.BuildBlock() + if err != nil { + t.Fatal(err) + } + + // Verify the proposed block + if err := secondAdvanceTimeBlockIntf.Verify(); err != nil { + t.Fatal(err) + } + + // Assert preferences are correct + secondAdvanceTimeBlock := secondAdvanceTimeBlockIntf.(*ProposalBlock) + options, err = secondAdvanceTimeBlock.Options() + if err != nil { + t.Fatal(err) + } + + // verify the commit block + commit = options[0].(*Commit) + if err := commit.Verify(); err != nil { + t.Fatal(err) + } + + // Accept the proposal block and the commit block + if err := secondAdvanceTimeBlock.Accept(); err != nil { + t.Fatal(err) + } + if err := commit.Accept(); err != nil { + t.Fatal(err) + } + + secondDelegatorStartTime := firstDelegatorEndTime.Add(2 * time.Second) + secondDelegatorEndTime := secondDelegatorStartTime.Add(vm.minStakeDuration) + + vm.clock.Set(secondDelegatorStartTime.Add(-10 * syncBound)) + + // create valid tx + addSecondDelegatorTx, err := vm.newAddDelegatorTx( + vm.minDelegatorStake, + uint64(secondDelegatorStartTime.Unix()), + uint64(secondDelegatorEndTime.Unix()), + id, + keys[0].PublicKey().Address(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1], keys[3]}, + ids.ShortEmpty, // change addr + ) + if err != nil { + t.Fatal(err) + } + + // trigger block creation + if err := vm.mempool.IssueTx(addSecondDelegatorTx); err != nil { + t.Fatal(err) + } + addSecondDelegatorBlockIntf, err := vm.BuildBlock() + if err != nil { 
+ t.Fatal(err) + } + + // Verify the proposed block + if err := addSecondDelegatorBlockIntf.Verify(); err != nil { + t.Fatal(err) + } + + // Assert preferences are correct + addSecondDelegatorBlock := addSecondDelegatorBlockIntf.(*ProposalBlock) + options, err = addSecondDelegatorBlock.Options() + if err != nil { + t.Fatal(err) + } + + // verify the commit block + commit = options[0].(*Commit) + if err := commit.Verify(); err != nil { + t.Fatal(err) + } + + // Accept the proposal block and the commit block + if err := addSecondDelegatorBlock.Accept(); err != nil { + t.Fatal(err) + } + if err := commit.Accept(); err != nil { + t.Fatal(err) + } + + thirdDelegatorStartTime := firstDelegatorEndTime.Add(-time.Second) + thirdDelegatorEndTime := thirdDelegatorStartTime.Add(vm.minStakeDuration) + + // create valid tx + addThirdDelegatorTx, err := vm.newAddDelegatorTx( + vm.minDelegatorStake, + uint64(thirdDelegatorStartTime.Unix()), + uint64(thirdDelegatorEndTime.Unix()), + id, + keys[0].PublicKey().Address(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1], keys[4]}, + ids.ShortEmpty, // change addr + ) + if err != nil { + t.Fatal(err) + } + + // trigger block creation + if err := vm.mempool.IssueTx(addThirdDelegatorTx); err != nil { + t.Fatal(err) + } + + addThirdDelegatorBlockIntf, err := vm.BuildBlock() + if err != nil { + t.Fatal(err) + } + + // Verify the proposed block is invalid + if err := addThirdDelegatorBlockIntf.Verify(); err == nil { + t.Fatalf("should have marked the delegator as being over delegated") + } +} diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index 8f7ab07d9d24..e57253038ff5 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -2035,7 +2035,7 @@ func (service *Service) getStakeHelper(tx *Tx, addrs ids.ShortSet) (uint64, []av var ( totalAmountStaked uint64 err error - stakedOuts []avax.TransferableOutput + stakedOuts = make([]avax.TransferableOutput, 0, len(outs)) ) // Go through all of the staked 
outputs for _, stake := range outs { @@ -2064,15 +2064,6 @@ func (service *Service) getStakeHelper(tx *Tx, addrs ids.ShortSet) (uint64, []av // This output isn't owned by one of the given addresses. Ignore. continue } - // Parse the owners of this output to their formatted string representations - ownersStrs := []string{} - for _, addr := range secpOut.Addrs { - addrStr, err := service.vm.FormatLocalAddress(addr) - if err != nil { - return 0, nil, fmt.Errorf("couldn't format address %s: %w", addr, err) - } - ownersStrs = append(ownersStrs, addrStr) - } totalAmountStaked, err = math.Add64(totalAmountStaked, stake.Out.Amount()) if err != nil { return 0, stakedOuts, err diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 6fe8367e9b51..1ebad04b8d5d 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -14,10 +14,12 @@ import ( "github.com/ava-labs/avalanchego/api" "github.com/ava-labs/avalanchego/api/keystore" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/utils/formatting" + "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/vms/avm" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/secp256k1fx" @@ -49,11 +51,8 @@ func defaultService(t *testing.T) *Service { vm, _ := defaultVM() vm.Ctx.Lock.Lock() defer vm.Ctx.Lock.Unlock() - ks, err := keystore.CreateTestKeystore() - if err != nil { - t.Fatal(err) - } - if err := ks.AddUser(testUsername, testPassword); err != nil { + ks := keystore.New(logging.NoLog{}, memdb.New()) + if err := ks.CreateUser(testUsername, testPassword); err != nil { t.Fatal(err) } vm.SnowmanVM.Ctx.Keystore = ks.NewBlockchainKeyStore(vm.SnowmanVM.Ctx.ChainID) diff --git a/vms/platformvm/state.go b/vms/platformvm/state.go index 253e66753bd0..0a23e2b805c5 
100644 --- a/vms/platformvm/state.go +++ b/vms/platformvm/state.go @@ -521,7 +521,7 @@ func (vm *VM) getPaginatedUTXOs( start = startUTXOID } - utxoIDs, err := vm.getReferencingUTXOs(vm.DB, addr.Bytes(), start, searchSize) // Get UTXOs associated with [addr] + utxoIDs, err := vm.getReferencingUTXOs(db, addr.Bytes(), start, searchSize) // Get UTXOs associated with [addr] if err != nil { return nil, ids.ShortID{}, ids.ID{}, fmt.Errorf("couldn't get UTXOs for address %s: %w", addr, err) } @@ -533,7 +533,7 @@ func (vm *VM) getPaginatedUTXOs( continue } - utxo, err := vm.getUTXO(vm.DB, utxoID) + utxo, err := vm.getUTXO(db, utxoID) if err != nil { return nil, ids.ShortID{}, ids.ID{}, fmt.Errorf("couldn't get UTXO %s: %w", utxoID, err) } diff --git a/vms/platformvm/static_service.go b/vms/platformvm/static_service.go index 60fc2543a3e3..e0c7dedd632a 100644 --- a/vms/platformvm/static_service.go +++ b/vms/platformvm/static_service.go @@ -233,7 +233,6 @@ func (ss *StaticService) BuildGenesis(_ *http.Request, args *BuildGenesisArgs, r weight := uint64(0) stake := make([]*avax.TransferableOutput, len(validator.Staked)) sortAPIUTXOs(validator.Staked) - memo := []byte(nil) for i, apiUTXO := range validator.Staked { addrID, err := bech32ToID(apiUTXO.Address) if err != nil { @@ -264,11 +263,6 @@ func (ss *StaticService) BuildGenesis(_ *http.Request, args *BuildGenesisArgs, r return errStakeOverflow } weight = newWeight - messageBytes, err := formatting.Decode(args.Encoding, apiUTXO.Message) - if err != nil { - return fmt.Errorf("problem decoding validator UTXO message bytes: %w", err) - } - memo = append(memo, messageBytes...) 
} if weight == 0 { diff --git a/vms/platformvm/user.go b/vms/platformvm/user.go index 289f9628ca72..6729a8cd5e5f 100644 --- a/vms/platformvm/user.go +++ b/vms/platformvm/user.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" - "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/encdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto" ) @@ -23,7 +23,7 @@ var ( type user struct { // This user's database, acquired from the keystore - db database.Database + db *encdb.Database } // Get the addresses controlled by this user diff --git a/vms/platformvm/user_test.go b/vms/platformvm/user_test.go index ff9dd9561a0c..54ea7064deb0 100644 --- a/vms/platformvm/user_test.go +++ b/vms/platformvm/user_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/ava-labs/avalanchego/database/encdb" "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto" @@ -37,8 +38,10 @@ func TestUserNilDB(t *testing.T) { } func TestUserClosedDB(t *testing.T) { - db := memdb.New() - err := db.Close() + db, err := encdb.New([]byte(testPassword), memdb.New()) + assert.NoError(t, err) + + err = db.Close() assert.NoError(t, err) u := user{db} @@ -64,14 +67,20 @@ func TestUserClosedDB(t *testing.T) { } func TestUserNilSK(t *testing.T) { - u := user{db: memdb.New()} + db, err := encdb.New([]byte(testPassword), memdb.New()) + assert.NoError(t, err) - err := u.putAddress(nil) + u := user{db: db} + + err = u.putAddress(nil) assert.Error(t, err, "nil key should have caused an error") } func TestUser(t *testing.T) { - u := user{db: memdb.New()} + db, err := encdb.New([]byte(testPassword), memdb.New()) + assert.NoError(t, err) + + u := user{db: db} addresses, err := u.getAddresses() assert.NoError(t, err) diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index c11eeaf0e73e..97a737e27c4e 100644 --- a/vms/platformvm/vm.go +++ 
b/vms/platformvm/vm.go @@ -90,6 +90,7 @@ var ( _ block.ChainVM = &VM{} _ validators.Connector = &VM{} + _ common.StaticVM = &VM{} ) // VM implements the snowman.ChainVM interface @@ -1214,43 +1215,6 @@ func (vm *VM) getStakers() ([]validators.Validator, error) { return stakers, errs.Err } -// Returns the pending staker set of the Primary Network. -// Each element corresponds to a staking transaction. -// There may be multiple elements with the same node ID. -// TODO implement this more efficiently -func (vm *VM) getPendingStakers() ([]validators.Validator, error) { - startDBPrefix := []byte(fmt.Sprintf("%s%s", constants.PrimaryNetworkID, startDBPrefix)) - startDB := prefixdb.NewNested(startDBPrefix, vm.DB) - defer startDB.Close() - startIter := startDB.NewIterator() - defer startIter.Release() - - stakers := []validators.Validator{} - for startIter.Next() { // Iterates in order of increasing start time - txBytes := startIter.Value() - tx := rewardTx{} - if _, err := vm.codec.Unmarshal(txBytes, &tx); err != nil { - return nil, fmt.Errorf("couldn't unmarshal validator tx: %w", err) - } else if err := tx.Tx.Sign(vm.codec, nil); err != nil { - return nil, err - } - - switch staker := tx.Tx.UnsignedTx.(type) { - case *UnsignedAddDelegatorTx: - stakers = append(stakers, &staker.Validator) - case *UnsignedAddValidatorTx: - stakers = append(stakers, &staker.Validator) - } - } - - errs := wrappers.Errs{} - errs.Add( - startIter.Error(), - startDB.Close(), - ) - return stakers, errs.Err -} - // Returns the percentage of the total stake on the Primary Network // of nodes connected to this node. 
func (vm *VM) getPercentConnected() (float64, error) { @@ -1379,6 +1343,10 @@ func (vm *VM) maxStakeAmount(db database.Database, subnetID ids.ID, nodeID ids.S toRemove := toRemoveHeap[0] toRemoveHeap = toRemoveHeap[1:] + if currentWeight > maxWeight && !startTime.After(toRemove.EndTime()) { + maxWeight = currentWeight + } + newWeight, err := safemath.Sub64(currentWeight, toRemove.Wght) if err != nil { return 0, err diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index f67b926ed309..d5feaf68ac0b 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -350,10 +350,6 @@ func defaultVM() (*VM, database.Database) { return vm, baseDB } -func GenesisVM(t *testing.T) ([]byte, chan common.Message, *VM, *atomic.Memory) { - return GenesisVMWithArgs(t, nil) -} - func GenesisVMWithArgs(t *testing.T, args *BuildGenesisArgs) ([]byte, chan common.Message, *VM, *atomic.Memory) { var genesisBytes []byte @@ -2088,7 +2084,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { // Asynchronously passes messages from the network to the consensus engine handler := &router.Handler{} - handler.Initialize( + err = handler.Initialize( &engine, vdrs, msgChan, @@ -2100,6 +2096,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { prometheus.NewRegistry(), &router.Delay{}, ) + assert.NoError(t, err) // Allow incoming messages to be routed to the new chain chainRouter.AddChain(handler) diff --git a/vms/rpcchainvm/vm_client.go b/vms/rpcchainvm/vm_client.go index 5c31839420c6..520f06c08ba5 100644 --- a/vms/rpcchainvm/vm_client.go +++ b/vms/rpcchainvm/vm_client.go @@ -13,6 +13,8 @@ import ( "github.com/hashicorp/go-plugin" "github.com/prometheus/client_golang/prometheus" + "github.com/ava-labs/avalanchego/api/keystore/gkeystore" + "github.com/ava-labs/avalanchego/api/keystore/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/cache/metercacher" "github.com/ava-labs/avalanchego/database" @@ -30,8 +32,6 @@ import 
( "github.com/ava-labs/avalanchego/vms/rpcchainvm/galiaslookup/galiaslookupproto" "github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp" "github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp/ghttpproto" - "github.com/ava-labs/avalanchego/vms/rpcchainvm/gkeystore" - "github.com/ava-labs/avalanchego/vms/rpcchainvm/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/vms/rpcchainvm/grpcutils" "github.com/ava-labs/avalanchego/vms/rpcchainvm/gsharedmemory" "github.com/ava-labs/avalanchego/vms/rpcchainvm/gsharedmemory/gsharedmemoryproto" diff --git a/vms/rpcchainvm/vm_server.go b/vms/rpcchainvm/vm_server.go index 3e9357f689bf..3ab891133feb 100644 --- a/vms/rpcchainvm/vm_server.go +++ b/vms/rpcchainvm/vm_server.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/go-plugin" + "github.com/ava-labs/avalanchego/api/keystore/gkeystore" + "github.com/ava-labs/avalanchego/api/keystore/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/database/rpcdb" "github.com/ava-labs/avalanchego/database/rpcdb/rpcdbproto" "github.com/ava-labs/avalanchego/ids" @@ -24,8 +26,6 @@ import ( "github.com/ava-labs/avalanchego/vms/rpcchainvm/galiaslookup/galiaslookupproto" "github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp" "github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp/ghttpproto" - "github.com/ava-labs/avalanchego/vms/rpcchainvm/gkeystore" - "github.com/ava-labs/avalanchego/vms/rpcchainvm/gkeystore/gkeystoreproto" "github.com/ava-labs/avalanchego/vms/rpcchainvm/grpcutils" "github.com/ava-labs/avalanchego/vms/rpcchainvm/gsharedmemory" "github.com/ava-labs/avalanchego/vms/rpcchainvm/gsharedmemory/gsharedmemoryproto" diff --git a/vms/timestampvm/vm.go b/vms/timestampvm/vm.go index bc6c1bd3783e..e26dbec9e54a 100644 --- a/vms/timestampvm/vm.go +++ b/vms/timestampvm/vm.go @@ -28,7 +28,8 @@ var ( errNoPendingBlocks = errors.New("there is no block to propose") errBadGenesisBytes = errors.New("genesis data should be bytes (max length 32)") - _ block.ChainVM = &VM{} + _ block.ChainVM 
= &VM{} + _ common.StaticVM = &VM{} ) // VM implements the snowman.VM interface