diff --git a/network/acp118/aggregator.go b/network/acp118/aggregator.go new file mode 100644 index 000000000000..506f39eac7a9 --- /dev/null +++ b/network/acp118/aggregator.go @@ -0,0 +1,231 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package acp118 + +import ( + "context" + "errors" + "fmt" + "sync" + + "go.uber.org/zap" + "golang.org/x/sync/semaphore" + "google.golang.org/protobuf/proto" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" + "github.com/ava-labs/avalanchego/proto/pb/sdk" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" +) + +var ( + ErrDuplicateValidator = errors.New("duplicate validator") + ErrInsufficientSignatures = errors.New("failed to aggregate sufficient stake weight of signatures") +) + +type result struct { + message *warp.Message + err error +} + +type Validator struct { + NodeID ids.NodeID + PublicKey *bls.PublicKey + Weight uint64 +} + +type indexedValidator struct { + Validator + I int +} + +// NewSignatureAggregator returns an instance of SignatureAggregator +func NewSignatureAggregator( + log logging.Logger, + client *p2p.Client, + maxPending int, +) *SignatureAggregator { + return &SignatureAggregator{ + log: log, + client: client, + maxPending: int64(maxPending), + } +} + +// SignatureAggregator aggregates validator signatures for warp messages +type SignatureAggregator struct { + log logging.Logger + client *p2p.Client + codec codec.Codec + maxPending int64 +} + +// AggregateSignatures blocks until stakeWeightThreshold of validators signs the +// provided message. Validators are issued requests in the caller-specified +// order. +func (s *SignatureAggregator) AggregateSignatures( + parentCtx context.Context, + message *warp.UnsignedMessage, + justification []byte, + validators []Validator, + stakeWeightThreshold uint64, +) (*warp.Message, error) { + ctx, cancel := context.WithCancel(parentCtx) + defer cancel() + + request := &sdk.SignatureRequest{ + Message: message.Bytes(), + Justification: justification, + } + + requestBytes, err := proto.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal signature request: %w", err) + } + + done := make(chan result) + pendingRequests := semaphore.NewWeighted(s.maxPending) + lock := &sync.Mutex{} + aggregatedStakeWeight := uint64(0) + attemptedStakeWeight := uint64(0) + totalStakeWeight := uint64(0) + signatures := make([]*bls.Signature, 0) + signerBitSet := set.NewBits() + + nodeIDsToValidator := make(map[ids.NodeID]indexedValidator) + for i, v := range validators { + totalStakeWeight += v.Weight + + // Sanity check the validator set provided by the caller + if _, ok := nodeIDsToValidator[v.NodeID]; ok { + return nil, fmt.Errorf("%w: %s", ErrDuplicateValidator, v.NodeID) + } + + nodeIDsToValidator[v.NodeID] = indexedValidator{ + I: i, + Validator: v, + } + } + + onResponse := func( + ctx context.Context, + nodeID ids.NodeID, + responseBytes []byte, + err error, + ) { + // We are guaranteed a response from a node in the validator set + validator := nodeIDsToValidator[nodeID] + + defer func() { + lock.Lock() + attemptedStakeWeight += validator.Weight + remainingStakeWeight := totalStakeWeight - attemptedStakeWeight + failed := remainingStakeWeight < stakeWeightThreshold + lock.Unlock() + + if failed { + done <- result{err: ErrInsufficientSignatures} + } + + pendingRequests.Release(1) + }() + + if err != nil { + s.log.Debug( + "dropping response", + zap.Stringer("nodeID", nodeID), + zap.Error(err), + ) + return + } + + response := &sdk.SignatureResponse{} + if err := proto.Unmarshal(responseBytes, response); err != nil { + s.log.Debug( + "dropping response", + zap.Stringer("nodeID", nodeID), + zap.Error(err), + ) + return + } + + signature, err := bls.SignatureFromBytes(response.Signature) + if err != nil { + s.log.Debug( + "dropping response", + zap.Stringer("nodeID", nodeID), + zap.String("reason", "invalid signature"), + zap.Error(err), + ) + return + } + + if !bls.Verify(validator.PublicKey, signature, message.Bytes()) { + s.log.Debug( + "dropping response", + zap.Stringer("nodeID", nodeID), + zap.String("reason", "public key failed verification"), + ) + return + } + + lock.Lock() + signerBitSet.Add(validator.I) + signatures = append(signatures, signature) + aggregatedStakeWeight += validator.Weight + + if aggregatedStakeWeight >= stakeWeightThreshold { + aggregateSignature, err := bls.AggregateSignatures(signatures) + if err != nil { + done <- result{err: err} + lock.Unlock() + return + } + + bitSetSignature := &warp.BitSetSignature{ + Signers: signerBitSet.Bytes(), + Signature: [bls.SignatureLen]byte{}, + } + + copy(bitSetSignature.Signature[:], bls.SignatureToBytes(aggregateSignature)) + signedMessage, err := warp.NewMessage(message, bitSetSignature) + done <- result{message: signedMessage, err: err} + lock.Unlock() + return + } + + lock.Unlock() + } + + for _, validator := range validators { + if err := pendingRequests.Acquire(ctx, 1); err != nil { + return nil, err + } + + // Avoid loop shadowing in goroutine + validatorCopy := validator + go func() { + if err := s.client.AppRequest( + ctx, + set.Of(validatorCopy.NodeID), + requestBytes, + onResponse, + ); err != nil { + done <- result{err: err} + return + } + }() + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-done: + return r.message, r.err + } +} diff --git a/network/acp118/aggregator_test.go b/network/acp118/aggregator_test.go new file mode 100644 index 000000000000..38185d57af34 --- /dev/null +++ b/network/acp118/aggregator_test.go @@ -0,0 +1,123 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package acp118 + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" + "github.com/ava-labs/avalanchego/network/p2p/p2ptest" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" +) + +func TestVerifier_Verify(t *testing.T) { + nodeID0 := ids.GenerateTestNodeID() + sk0, err := bls.NewSecretKey() + require.NoError(t, err) + pk0 := bls.PublicFromSecretKey(sk0) + networkID := uint32(123) + chainID := ids.GenerateTestID() + signer := warp.NewSigner(sk0, networkID, chainID) + + tests := []struct { + name string + + ctx context.Context + validators []Validator + threshold uint64 + + handler p2p.Handler + + wantErr error + }{ + { + name: "passes verification", + ctx: context.Background(), + validators: []Validator{ + { + NodeID: nodeID0, + PublicKey: pk0, + Weight: 1, + }, + }, + threshold: 1, + handler: NewHandler(&testAttestor{}, signer, networkID, chainID), + }, + { + name: "fails verification", + ctx: context.Background(), + validators: []Validator{ + { + NodeID: nodeID0, + PublicKey: pk0, + Weight: 1, + }, + }, + threshold: 1, + handler: NewHandler( + &testAttestor{Err: errors.New("foobar")}, + signer, + networkID, + chainID, + ), + wantErr: ErrInsufficientSignatures, + }, + { + name: "invalid validator set", + ctx: context.Background(), + validators: []Validator{ + { + NodeID: nodeID0, + PublicKey: pk0, + Weight: 1, + }, + { + NodeID: nodeID0, + PublicKey: pk0, + Weight: 1, + }, + }, + wantErr: ErrDuplicateValidator, + }, + { + name: "context canceled", + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + return ctx + }(), + threshold: 1, + wantErr: context.Canceled, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := context.Background() + message, err := warp.NewUnsignedMessage(networkID, chainID, []byte("payload")) + require.NoError(err) + client := p2ptest.NewClient(t, ctx, tt.handler, ids.GenerateTestNodeID(), nodeID0) + verifier := NewSignatureAggregator(logging.NoLog{}, client, 1) + + _, err = verifier.AggregateSignatures( + tt.ctx, + message, + []byte("justification"), + tt.validators, + tt.threshold, + ) + require.ErrorIs(err, tt.wantErr) + }) + } +} diff --git a/network/acp118/handler.go b/network/acp118/handler.go new file mode 100644 index 000000000000..e4aae28e3ce6 --- /dev/null +++ b/network/acp118/handler.go @@ -0,0 +1,107 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package acp118 + +import ( + "context" + "fmt" + "time" + + "google.golang.org/protobuf/proto" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" + "github.com/ava-labs/avalanchego/proto/pb/sdk" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" +) + +var _ p2p.Handler = (*Handler)(nil) + +// Attestor defines whether to a warp message payload should be attested to +type Attestor interface { + Attest(message *warp.UnsignedMessage, justification []byte) (bool, error) +} + +// NewHandler returns an instance of Handler +func NewHandler( + attestor Attestor, + signer warp.Signer, + networkID uint32, + chainID ids.ID, +) *Handler { + return &Handler{ + attestor: attestor, + signer: signer, + networkID: networkID, + chainID: chainID, + } +} + +// Handler signs warp messages +type Handler struct { + p2p.NoOpHandler + + attestor Attestor + signer warp.Signer + networkID uint32 + chainID ids.ID +} + +func (h *Handler) AppRequest( + _ context.Context, + _ ids.NodeID, + _ time.Time, + requestBytes []byte, +) ([]byte, *common.AppError) { + request := &sdk.SignatureRequest{} + if err := proto.Unmarshal(requestBytes, request); err != nil { + return nil, &common.AppError{ + Code: p2p.ErrUnexpected.Code, + Message: fmt.Sprintf("failed to unmarshal request: %s", err), + } + } + + msg, err := warp.ParseUnsignedMessage(request.Message) + if err != nil { + return nil, &common.AppError{ + Code: p2p.ErrUnexpected.Code, + Message: fmt.Sprintf("failed to initialize warp unsigned message: %s", err), + } + } + + ok, err := h.attestor.Attest(msg, request.Justification) + if err != nil { + return nil, &common.AppError{ + Code: p2p.ErrUnexpected.Code, + Message: fmt.Sprintf("failed to attest request: %s", err), + } + } + + if !ok { + return nil, p2p.ErrAttestFailed + } + + signature, err := h.signer.Sign(msg) + if err != nil { + return nil, &common.AppError{ + Code: p2p.ErrUnexpected.Code, + Message: fmt.Sprintf("failed to sign message: %s", err), + } + } + + response := &sdk.SignatureResponse{ + Signature: signature, + } + + responseBytes, err := proto.Marshal(response) + if err != nil { + return nil, &common.AppError{ + Code: p2p.ErrUnexpected.Code, + Message: fmt.Sprintf("failed to marshal response: %s", err), + } + } + + return responseBytes, nil +} diff --git a/network/acp118/handler_test.go b/network/acp118/handler_test.go new file mode 100644 index 000000000000..26a2f1002a2f --- /dev/null +++ b/network/acp118/handler_test.go @@ -0,0 +1,118 @@ +package acp118 + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" + "github.com/ava-labs/avalanchego/network/p2p/p2ptest" + "github.com/ava-labs/avalanchego/proto/pb/sdk" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" +) + +var _ Attestor = (*testAttestor)(nil) + +func TestHandler(t *testing.T) { + tests := []struct { + name string + attestor Attestor + expectedErr error + expectedVerify bool + }{ + { + name: "signature fails attestation", + attestor: &testAttestor{Err: errors.New("foo")}, + expectedErr: p2p.ErrUnexpected, + }, + { + name: "signature not attested", + attestor: &testAttestor{CantAttest: true}, + expectedErr: p2p.ErrAttestFailed, + }, + { + name: "signature attested", + attestor: &testAttestor{}, + expectedVerify: true, + }, + } + + for _, test := range tests { + // Prevent loop shadowing in goroutines + tt := test + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + ctx := context.Background() + sk, err := bls.NewSecretKey() + require.NoError(err) + pk := bls.PublicFromSecretKey(sk) + networkID := uint32(123) + chainID := ids.GenerateTestID() + signer := warp.NewSigner(sk, networkID, chainID) + h := NewHandler(tt.attestor, signer, networkID, chainID) + clientNodeID := ids.GenerateTestNodeID() + serverNodeID := ids.GenerateTestNodeID() + c := p2ptest.NewClient( + t, + ctx, + h, + clientNodeID, + serverNodeID, + ) + + unsignedMessage, err := warp.NewUnsignedMessage( + networkID, + chainID, + []byte("payload"), + ) + require.NoError(err) + + request := &sdk.SignatureRequest{ + Message: unsignedMessage.Bytes(), + Justification: []byte("justification"), + } + + requestBytes, err := proto.Marshal(request) + require.NoError(err) + + done := make(chan struct{}) + onResponse := func(ctx context.Context, _ ids.NodeID, responseBytes []byte, appErr error) { + defer close(done) + + if appErr != nil { + require.ErrorIs(tt.expectedErr, appErr) + return + } + + response := &sdk.SignatureResponse{} + require.NoError(proto.Unmarshal(responseBytes, response)) + + signature, err := bls.SignatureFromBytes(response.Signature) + require.NoError(err) + + require.Equal(tt.expectedVerify, bls.Verify(pk, signature, request.Message)) + } + + require.NoError(c.AppRequest(ctx, set.Of(serverNodeID), requestBytes, onResponse)) + <-done + }) + } + +} + +// The zero value of testAttestor attests +type testAttestor struct { + CantAttest bool + Err error +} + +func (t testAttestor) Attest(*warp.UnsignedMessage, []byte) (bool, error) { + return !t.CantAttest, t.Err +} diff --git a/network/p2p/error.go b/network/p2p/error.go index 07207319a041..120cfeb75bc4 100644 --- a/network/p2p/error.go +++ b/network/p2p/error.go @@ -30,4 +30,8 @@ var ( Code: -4, Message: "throttled", } + ErrAttestFailed = &common.AppError{ + Code: -5, + Message: "failed attestation", + } )