diff --git a/pkg/solana/chainreader/account_read_binding.go b/pkg/solana/chainreader/account_read_binding.go index 128d38cd1..71ebb131b 100644 --- a/pkg/solana/chainreader/account_read_binding.go +++ b/pkg/solana/chainreader/account_read_binding.go @@ -2,7 +2,6 @@ package chainreader import ( "context" - "fmt" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" @@ -10,94 +9,37 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" ) -// BinaryDataReader provides an interface for reading bytes from a source. This is likely a wrapper -// for a solana client. -type BinaryDataReader interface { - ReadAll(context.Context, solana.PublicKey, *rpc.GetAccountInfoOpts) ([]byte, error) -} - // accountReadBinding provides decoding and reading Solana Account data using a defined codec. The // `idlAccount` refers to the account name in the IDL for which the codec has a type mapping. type accountReadBinding struct { idlAccount string codec types.RemoteCodec - reader BinaryDataReader + key solana.PublicKey opts *rpc.GetAccountInfoOpts } -func newAccountReadBinding(acct string, codec types.RemoteCodec, reader BinaryDataReader, opts *rpc.GetAccountInfoOpts) *accountReadBinding { +func newAccountReadBinding(acct string, codec types.RemoteCodec, opts *rpc.GetAccountInfoOpts) *accountReadBinding { return &accountReadBinding{ idlAccount: acct, codec: codec, - reader: reader, opts: opts, } } var _ readBinding = &accountReadBinding{} -func (b *accountReadBinding) PreLoad(ctx context.Context, address string, result *loadedResult) { - if result == nil { - return - } - - account, err := solana.PublicKeyFromBase58(address) - if err != nil { - result.err <- err - - return - } - - bts, err := b.reader.ReadAll(ctx, account, b.opts) - if err != nil { - result.err <- fmt.Errorf("%w: failed to get binary data", err) - - return - } - - select { - case <-ctx.Done(): - result.err <- ctx.Err() - default: - result.value <- bts - } +func (b *accountReadBinding) SetAddress(key solana.PublicKey) { + b.key = key } -func (b *accountReadBinding) GetLatestValue(ctx context.Context, address string, _ any, outVal any, result *loadedResult) error { - var ( - bts []byte - err error - ) - - if result != nil { - // when preloading, the process will wait for one of three conditions: - // 1. the context ends and returns an error - // 2. bytes were loaded in the bytes channel - // 3. an error was loaded in the err channel - select { - case <-ctx.Done(): - err = ctx.Err() - case bts = <-result.value: - case err = <-result.err: - } - - if err != nil { - return err - } - } else { - account, err := solana.PublicKeyFromBase58(address) - if err != nil { - return err - } - - if bts, err = b.reader.ReadAll(ctx, account, b.opts); err != nil { - return fmt.Errorf("%w: failed to get binary data", err) - } - } - - return b.codec.Decode(ctx, bts, outVal, b.idlAccount) +func (b *accountReadBinding) GetAddress() solana.PublicKey { + return b.key } func (b *accountReadBinding) CreateType(_ bool) (any, error) { return b.codec.CreateType(b.idlAccount, false) } + +func (b *accountReadBinding) Decode(ctx context.Context, bts []byte, outVal any) error { + return b.codec.Decode(ctx, bts, outVal, b.idlAccount) +} diff --git a/pkg/solana/chainreader/account_read_binding_test.go b/pkg/solana/chainreader/account_read_binding_test.go deleted file mode 100644 index 3ea899cc2..000000000 --- a/pkg/solana/chainreader/account_read_binding_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package chainreader - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/gagliardetto/solana-go" - "github.com/gagliardetto/solana-go/rpc" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" - "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary" - "github.com/smartcontractkit/chainlink-common/pkg/types" -) - -func TestPreload(t *testing.T) { - t.Parallel() - - testCodec := makeTestCodec(t) - - t.Run("get latest value waits for preload", func(t *testing.T) { - t.Parallel() - - reader := new(mockReader) - binding := newAccountReadBinding(testCodecKey, testCodec, reader, nil) - - expected := testStruct{A: true, B: 42} - bts, err := testCodec.Encode(context.Background(), expected, testCodecKey) - - require.NoError(t, err) - - reader.On("ReadAll", mock.Anything, mock.Anything, mock.Anything).Return(bts, nil).After(time.Second) - - ctx := context.Background() - start := time.Now() - loaded := &loadedResult{ - value: make(chan []byte, 1), - err: make(chan error, 1), - } - - pubKey := solana.NewWallet().PublicKey() - - binding.PreLoad(ctx, pubKey.String(), loaded) - - var result testStruct - - err = binding.GetLatestValue(ctx, pubKey.String(), nil, &result, loaded) - elapsed := time.Since(start) - - require.NoError(t, err) - assert.GreaterOrEqual(t, elapsed, time.Second) - assert.Less(t, elapsed, 1100*time.Millisecond) - assert.Equal(t, expected, result) - }) - - t.Run("cancelled context exits preload and returns error on get latest value", func(t *testing.T) { - t.Parallel() - - reader := new(mockReader) - binding := newAccountReadBinding(testCodecKey, testCodec, reader, nil) - - ctx, cancel := context.WithCancelCause(context.Background()) - - // make the readall pause until after the context is cancelled - reader.On("ReadAll", mock.Anything, mock.Anything, mock.Anything). - Return([]byte{}, nil). - After(600 * time.Millisecond) - - expectedErr := errors.New("test error") - go func() { - time.Sleep(500 * time.Millisecond) - cancel(expectedErr) - }() - - pubKey := solana.NewWallet().PublicKey() - loaded := &loadedResult{ - value: make(chan []byte, 1), - err: make(chan error, 1), - } - start := time.Now() - binding.PreLoad(ctx, pubKey.String(), loaded) - - var result testStruct - err := binding.GetLatestValue(ctx, pubKey.String(), nil, &result, loaded) - elapsed := time.Since(start) - - assert.ErrorIs(t, err, ctx.Err()) - assert.ErrorIs(t, context.Cause(ctx), expectedErr) - assert.GreaterOrEqual(t, elapsed, 600*time.Millisecond) - assert.Less(t, elapsed, 700*time.Millisecond) - }) - - t.Run("error from preload is returned in get latest value", func(t *testing.T) { - t.Parallel() - - reader := new(mockReader) - binding := newAccountReadBinding(testCodecKey, testCodec, reader, nil) - ctx := context.Background() - expectedErr := errors.New("test error") - - reader.On("ReadAll", mock.Anything, mock.Anything, mock.Anything). - Return([]byte{}, expectedErr) - - pubKey := solana.NewWallet().PublicKey() - loaded := &loadedResult{ - value: make(chan []byte, 1), - err: make(chan error, 1), - } - binding.PreLoad(ctx, pubKey.String(), loaded) - - var result testStruct - err := binding.GetLatestValue(ctx, pubKey.String(), nil, &result, loaded) - - assert.ErrorIs(t, err, expectedErr) - }) -} - -type mockReader struct { - mock.Mock -} - -func (_m *mockReader) ReadAll(ctx context.Context, pk solana.PublicKey, opts *rpc.GetAccountInfoOpts) ([]byte, error) { - ret := _m.Called(ctx, pk) - - var r0 []byte - if val, ok := ret.Get(0).([]byte); ok { - r0 = val - } - - var r1 error - if fn, ok := ret.Get(1).(func() error); ok { - r1 = fn() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type testStruct struct { - A bool - B int64 -} - -const testCodecKey = "TEST" - -func makeTestCodec(t *testing.T) types.RemoteCodec { - t.Helper() - - builder := binary.LittleEndian() - - structCodec, err := encodings.NewStructCodec([]encodings.NamedTypeCodec{ - {Name: "A", Codec: builder.Bool()}, - {Name: "B", Codec: builder.Int64()}, - }) - - require.NoError(t, err) - - return encodings.CodecFromTypeCodec(map[string]encodings.TypeCodec{testCodecKey: structCodec}) -} diff --git a/pkg/solana/chainreader/batch.go b/pkg/solana/chainreader/batch.go new file mode 100644 index 000000000..43e4971b9 --- /dev/null +++ b/pkg/solana/chainreader/batch.go @@ -0,0 +1,107 @@ +package chainreader + +import ( + "context" + "errors" + + "github.com/gagliardetto/solana-go" + + "github.com/smartcontractkit/chainlink-common/pkg/values" +) + +type call struct { + ContractName, ReadName string + Params, ReturnVal any +} + +type batchResultWithErr struct { + address string + contractName, readName string + returnVal any + err error +} + +var ( + ErrMissingAccountData = errors.New("account data not found") +) + +type MultipleAccountGetter interface { + GetMultipleAccountData(context.Context, ...solana.PublicKey) ([][]byte, error) +} + +func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindings namespaceBindings, batch []call) ([]batchResultWithErr, error) { + // Create the list of public keys to fetch + keys := make([]solana.PublicKey, len(batch)) + for idx, call := range batch { + binding, err := bindings.GetReadBinding(call.ContractName, call.ReadName) + if err != nil { + return nil, err + } + + keys[idx] = binding.GetAddress() + } + + // Fetch the account data + data, err := client.GetMultipleAccountData(ctx, keys...) + if err != nil { + return nil, err + } + + results := make([]batchResultWithErr, len(batch)) + + // decode batch call results + for idx, call := range batch { + results[idx] = batchResultWithErr{ + address: keys[idx].String(), + contractName: call.ContractName, + readName: call.ReadName, + returnVal: call.ReturnVal, + } + + if data[idx] == nil || len(data[idx]) == 0 { + results[idx].err = ErrMissingAccountData + + continue + } + + binding, err := bindings.GetReadBinding(results[idx].contractName, results[idx].readName) + if err != nil { + results[idx].err = err + + continue + } + + ptrToValue, isValue := call.ReturnVal.(*values.Value) + if !isValue { + results[idx].err = errors.Join( + results[idx].err, + binding.Decode(ctx, data[idx], results[idx].returnVal), + ) + + continue + } + + contractType, err := binding.CreateType(false) + if err != nil { + results[idx].err = err + + continue + } + + results[idx].err = errors.Join( + results[idx].err, + binding.Decode(ctx, data[idx], contractType), + ) + + value, err := values.Wrap(contractType) + if err != nil { + results[idx].err = errors.Join(results[idx].err, err) + + continue + } + + *ptrToValue = value + } + + return results, nil +} diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index 39eb07f8a..51cc8980a 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -3,7 +3,6 @@ package chainreader import ( "context" "fmt" - "reflect" "github.com/gagliardetto/solana-go" @@ -11,128 +10,63 @@ import ( ) type readBinding interface { - PreLoad(context.Context, string, *loadedResult) - GetLatestValue(ctx context.Context, address string, params, returnVal any, preload *loadedResult) error + SetAddress(solana.PublicKey) + GetAddress() solana.PublicKey CreateType(bool) (any, error) + Decode(context.Context, []byte, any) error } // key is namespace -type namespaceBindings map[string]methodBindings +type namespaceBindings map[string]readNameBindings // key is method name -type methodBindings map[string]readBindings +type readNameBindings map[string]readBinding -// read bindings is a list of bindings by index -type readBindings []readBinding - -func (b namespaceBindings) AddReadBinding(namespace, methodName string, reader readBinding) { - nbs, nbsExists := b[namespace] - if !nbsExists { - nbs = methodBindings{} - b[namespace] = nbs - } - - rbs, rbsExists := nbs[methodName] - if !rbsExists { - rbs = []readBinding{} +func (b namespaceBindings) AddReadBinding(namespace, readName string, reader readBinding) { + if _, nbsExists := b[namespace]; !nbsExists { + b[namespace] = readNameBindings{} } - b[namespace][methodName] = append(rbs, reader) + b[namespace][readName] = reader } -func (b namespaceBindings) GetReadBindings(namespace, methodName string) ([]readBinding, error) { +func (b namespaceBindings) GetReadBinding(namespace, readName string) (readBinding, error) { nbs, nbsExists := b[namespace] if !nbsExists { return nil, fmt.Errorf("%w: no read binding exists for %s", types.ErrInvalidConfig, namespace) } - rbs, rbsExists := nbs[methodName] + rbs, rbsExists := nbs[readName] if !rbsExists { - return nil, fmt.Errorf("%w: no read binding exists for %s and %s", types.ErrInvalidConfig, namespace, methodName) + return nil, fmt.Errorf("%w: no read binding exists for %s and %s", types.ErrInvalidConfig, namespace, readName) } return rbs, nil } -func (b namespaceBindings) CreateType(namespace, methodName string, forEncoding bool) (any, error) { - bindings, err := b.GetReadBindings(namespace, methodName) +func (b namespaceBindings) CreateType(namespace, readName string, forEncoding bool) (any, error) { + binding, err := b.GetReadBinding(namespace, readName) if err != nil { return nil, err } - if len(bindings) == 1 { - // get the item type from the binding codec - return bindings[0].CreateType(forEncoding) - } - - // build a merged struct from all bindings - fields := make([]reflect.StructField, 0) - var fieldIdx int - fieldNames := make(map[string]struct{}) - - for _, binding := range bindings { - bindingType, err := binding.CreateType(forEncoding) - if err != nil { - return nil, err - } - - tBinding := reflect.TypeOf(bindingType) - if tBinding.Kind() == reflect.Pointer { - tBinding = tBinding.Elem() - } - - // all bindings must be structs to allow multiple bindings - if tBinding.Kind() != reflect.Struct { - return nil, fmt.Errorf("%w: support for multiple bindings only applies to all bindings having the type struct", types.ErrInvalidType) - } - - for idx := 0; idx < tBinding.NumField(); idx++ { - value := tBinding.FieldByIndex([]int{idx}) - - _, exists := fieldNames[value.Name] - if exists { - return nil, fmt.Errorf("%w: field name overlap on %s", types.ErrInvalidConfig, value.Name) - } - - field := reflect.StructField{ - Name: value.Name, - Type: value.Type, - Index: []int{fieldIdx}, - } - - fields = append(fields, field) - - fieldIdx++ - fieldNames[value.Name] = struct{}{} - } - } - - return reflect.New(reflect.StructOf(fields)).Interface(), nil + return binding.CreateType(forEncoding) } func (b namespaceBindings) Bind(binding types.BoundContract) error { - _, nbsExist := b[binding.Name] + bnd, nbsExist := b[binding.Name] if !nbsExist { return fmt.Errorf("%w: no namespace named %s", types.ErrInvalidConfig, binding.Name) } - readAddresses, err := decodeAddressMappings(binding.Address) + key, err := solana.PublicKeyFromBase58(binding.Address) if err != nil { return err } - for readName, addresses := range readAddresses { - for idx, address := range addresses { - if _, err := solana.PublicKeyFromBase58(address); err != nil { - return fmt.Errorf("%w: invalid address binding for %s at index %d: %s", types.ErrInvalidConfig, readName, idx, err.Error()) - } - } + for _, rb := range bnd { + rb.SetAddress(key) } return nil } - -type loadedResult struct { - value chan []byte - err chan error -} diff --git a/pkg/solana/chainreader/bindings_test.go b/pkg/solana/chainreader/bindings_test.go index 9ba66aa5f..ee6287afe 100644 --- a/pkg/solana/chainreader/bindings_test.go +++ b/pkg/solana/chainreader/bindings_test.go @@ -2,9 +2,9 @@ package chainreader import ( "context" - "reflect" "testing" + "github.com/gagliardetto/solana-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -31,70 +31,11 @@ func TestBindings_CreateType(t *testing.T) { assert.Equal(t, expected, returned) }) - t.Run("multiple bindings return merged struct", func(t *testing.T) { + t.Run("returns error when binding does not exist", func(t *testing.T) { t.Parallel() - bindingA := new(mockBinding) - bindingB := new(mockBinding) bindings := namespaceBindings{} - bindings.AddReadBinding("A", "B", bindingA) - bindings.AddReadBinding("A", "B", bindingB) - - bindingA.On("CreateType", mock.Anything).Return(struct{ A string }{A: "test"}, nil) - bindingB.On("CreateType", mock.Anything).Return(struct{ B int }{B: 8}, nil) - - result, err := bindings.CreateType("A", "B", true) - - expected := reflect.New(reflect.StructOf([]reflect.StructField{ - {Name: "A", Type: reflect.TypeOf("")}, - {Name: "B", Type: reflect.TypeOf(0)}, - })) - - require.NoError(t, err) - assert.Equal(t, expected.Type(), reflect.TypeOf(result)) - }) - - t.Run("multiple bindings fails when not a struct", func(t *testing.T) { - t.Parallel() - - bindingA := new(mockBinding) - bindingB := new(mockBinding) - bindings := namespaceBindings{} - - bindings.AddReadBinding("A", "B", bindingA) - bindings.AddReadBinding("A", "B", bindingB) - - bindingA.On("CreateType", mock.Anything).Return(8, nil) - bindingB.On("CreateType", mock.Anything).Return(struct{ A string }{A: "test"}, nil) - - _, err := bindings.CreateType("A", "B", true) - - require.ErrorIs(t, err, types.ErrInvalidType) - }) - - t.Run("multiple bindings errors when fields overlap", func(t *testing.T) { - t.Parallel() - - bindingA := new(mockBinding) - bindingB := new(mockBinding) - bindings := namespaceBindings{} - - bindings.AddReadBinding("A", "B", bindingA) - bindings.AddReadBinding("A", "B", bindingB) - - type A struct { - A string - B int - } - - type B struct { - A int - } - - bindingA.On("CreateType", mock.Anything).Return(A{A: ""}, nil) - bindingB.On("CreateType", mock.Anything).Return(B{A: 8}, nil) - _, err := bindings.CreateType("A", "B", true) require.ErrorIs(t, err, types.ErrInvalidConfig) @@ -105,10 +46,10 @@ type mockBinding struct { mock.Mock } -func (_m *mockBinding) PreLoad(context.Context, string, *loadedResult) {} +func (_m *mockBinding) SetAddress(_ solana.PublicKey) {} -func (_m *mockBinding) GetLatestValue(ctx context.Context, address string, params, returnVal any, _ *loadedResult) error { - return nil +func (_m *mockBinding) GetAddress() solana.PublicKey { + return solana.PublicKey{} } func (_m *mockBinding) CreateType(b bool) (any, error) { @@ -116,3 +57,7 @@ func (_m *mockBinding) CreateType(b bool) (any, error) { return ret.Get(0), ret.Error(1) } + +func (_ *mockBinding) Decode(_ context.Context, _ []byte, _ any) error { + return nil +} diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index ba0093edc..d017eb25d 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -2,14 +2,12 @@ package chainreader import ( "context" - "encoding/base64" "encoding/json" "errors" "fmt" - "reflect" "sync" - ag_solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" codeccommon "github.com/smartcontractkit/chainlink-common/pkg/codec" @@ -18,7 +16,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" - "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" @@ -31,7 +28,7 @@ type SolanaChainReaderService struct { // provided values lggr logger.Logger - client BinaryDataReader + client MultipleAccountGetter // internal values bindings namespaceBindings @@ -48,7 +45,7 @@ var ( ) // NewChainReaderService is a constructor for a new ChainReaderService for Solana. Returns a nil service on error. -func NewChainReaderService(lggr logger.Logger, dataReader BinaryDataReader, cfg config.ChainReader) (*SolanaChainReaderService, error) { +func NewChainReaderService(lggr logger.Logger, dataReader MultipleAccountGetter, cfg config.ChainReader) (*SolanaChainReaderService, error) { svc := &SolanaChainReaderService{ lggr: logger.Named(lggr, ServiceName), client: dataReader, @@ -114,123 +111,73 @@ func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, readIdent return fmt.Errorf("%w: no contract for read identifier %s", types.ErrInvalidType, readIdentifier) } - addressMappings, err := decodeAddressMappings(vals.address) - if err != nil { - return fmt.Errorf("%w: %s", types.ErrInvalidConfig, err) - } - - addresses, ok := addressMappings[vals.readName] - if !ok { - return fmt.Errorf("%w: no addresses for readName %s", types.ErrInvalidConfig, vals.readName) + batch := []call{ + { + ContractName: vals.contract, + ReadName: vals.readName, + Params: params, + ReturnVal: returnVal, + }, } - bindings, err := s.bindings.GetReadBindings(vals.contract, vals.readName) + results, err := doMethodBatchCall(ctx, s.client, s.bindings, batch) if err != nil { return err } - if len(addresses) != len(bindings) { - return fmt.Errorf("%w: addresses and bindings lengths do not match", types.ErrInvalidConfig) + if len(results) != len(batch) { + return fmt.Errorf("%w: unexpected number of results", types.ErrInternal) } - // if the returnVal is not a *values.Value, run normally without using the ptrToValue - ptrToValue, isValue := returnVal.(*values.Value) - if !isValue { - return s.runAllBindings(ctx, bindings, addresses, params, returnVal) - } - - // if the returnVal is a *values.Value, create the type from the contract, run normally, and wrap the value - contractType, err := s.bindings.CreateType(vals.contract, vals.readName, false) - if err != nil { - return err - } - - if err = s.runAllBindings(ctx, bindings, addresses, params, contractType); err != nil { - return err - } - - value, err := values.Wrap(contractType) - if err != nil { - return err + if results[0].err != nil { + return fmt.Errorf("%w: %s", types.ErrInternal, results[0].err) } - *ptrToValue = value - return nil } -func (s *SolanaChainReaderService) runAllBindings( - ctx context.Context, - bindings []readBinding, - addresses []string, - params, returnVal any, -) error { - localCtx, localCancel := context.WithCancel(ctx) - - // the wait group ensures GetLatestValue returns only after all go-routines have completed - var wg sync.WaitGroup - - results := make(map[int]*loadedResult) - - if len(bindings) > 1 { - // might go for some guardrails when dealing with multiple bindings - // the returnVal should be compatible with multiple passes by the codec decoder - // this should only apply to types struct{} and map[any]any - tReturnVal := reflect.TypeOf(returnVal) - if tReturnVal.Kind() == reflect.Pointer { - tReturnVal = reflect.Indirect(reflect.ValueOf(returnVal)).Type() - } - - switch tReturnVal.Kind() { - case reflect.Struct, reflect.Map: - default: - localCancel() - - wg.Wait() - - return fmt.Errorf("%w: multiple bindings is only supported for struct and map", types.ErrInvalidType) +// BatchGetLatestValues implements the types.ContractReader interface. +func (s *SolanaChainReaderService) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { + idxLookup := make(map[types.BoundContract][]int) + batch := []call{} + + for bound, req := range request { + idxLookup[bound] = make([]int, len(req)) + + for idx, readReq := range req { + idxLookup[bound][idx] = len(batch) + batch = append(batch, call{ + ContractName: bound.Name, + ReadName: readReq.ReadName, + Params: readReq.Params, + ReturnVal: readReq.ReturnVal, + }) } + } - // for multiple bindings, preload the remote data in parallel - for idx, binding := range bindings { - results[idx] = &loadedResult{ - value: make(chan []byte, 1), - err: make(chan error, 1), - } - - wg.Add(1) - go func(ctx context.Context, rb readBinding, res *loadedResult, address string) { - defer wg.Done() + results, err := doMethodBatchCall(ctx, s.client, s.bindings, batch) + if err != nil { + return nil, err + } - rb.PreLoad(ctx, address, res) - }(localCtx, binding, results[idx], addresses[idx]) - } + if len(results) != len(batch) { + return nil, errors.New("unexpected number of results") } - // in the case of parallel preloading, GetLatestValue will still run in - // sequence because the function will block until the data is loaded. - // in the case of no preloading, GetLatestValue will load and decode in - // sequence. - for idx, binding := range bindings { - if err := binding.GetLatestValue(ctx, addresses[idx], params, returnVal, results[idx]); err != nil { - localCancel() + result := make(types.BatchGetLatestValuesResult) - wg.Wait() + for bound, idxs := range idxLookup { + result[bound] = make(types.ContractBatchResults, len(idxs)) - return err + for idx, callIdx := range idxs { + res := types.BatchReadResult{ReadName: results[callIdx].readName} + res.SetResult(results[callIdx].returnVal, results[callIdx].err) + + result[bound][idx] = res } } - localCancel() - - wg.Wait() - - return nil -} - -// BatchGetLatestValues implements the types.ContractReader interface. -func (s *SolanaChainReaderService) BatchGetLatestValues(_ context.Context, _ types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { - return nil, errors.New("unimplemented") + return result, nil } // QueryKey implements the types.ContractReader interface. @@ -288,26 +235,25 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader s.lookup.addReadNameForContract(namespace, methodName) - for _, procedure := range method.Procedures { - injectAddressModifier(procedure.OutputModifications) - - mod, err := procedure.OutputModifications.ToModifier(codec.DecoderHooks...) - if err != nil { - return err - } - - codecWithModifiers, err := codec.NewNamedModifierCodec(idlCodec, procedure.IDLAccount, mod) - if err != nil { - return err - } - - s.bindings.AddReadBinding(namespace, methodName, newAccountReadBinding( - procedure.IDLAccount, - codecWithModifiers, - s.client, - createRPCOpts(procedure.RPCOpts), - )) + procedure := method.Procedure + + injectAddressModifier(procedure.OutputModifications) + + mod, err := procedure.OutputModifications.ToModifier(codec.DecoderHooks...) + if err != nil { + return err + } + + codecWithModifiers, err := codec.NewNamedModifierCodec(idlCodec, procedure.IDLAccount, mod) + if err != nil { + return err } + + s.bindings.AddReadBinding(namespace, methodName, newAccountReadBinding( + procedure.IDLAccount, + codecWithModifiers, + createRPCOpts(procedure.RPCOpts), + )) } } @@ -353,7 +299,7 @@ func NewAccountDataReader(client *rpc.Client) *accountDataReader { return &accountDataReader{client: client} } -func (r *accountDataReader) ReadAll(ctx context.Context, pk ag_solana.PublicKey, opts *rpc.GetAccountInfoOpts) ([]byte, error) { +func (r *accountDataReader) ReadAll(ctx context.Context, pk solana.PublicKey, opts *rpc.GetAccountInfoOpts) ([]byte, error) { result, err := r.client.GetAccountInfoWithOpts(ctx, pk, opts) if err != nil { return nil, err @@ -363,19 +309,3 @@ func (r *accountDataReader) ReadAll(ctx context.Context, pk ag_solana.PublicKey, return bts, nil } - -func decodeAddressMappings(encoded string) (map[string][]string, error) { - decoded, err := base64.StdEncoding.DecodeString(encoded) - if err != nil { - return nil, err - } - - var readAddresses map[string][]string - - err = json.Unmarshal(decoded, &readAddresses) - if err != nil { - return nil, err - } - - return readAddresses, nil -} diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index 7a1255c07..bf8b246dd 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "math/big" "os" @@ -78,11 +77,6 @@ func TestSolanaChainReaderService_ServiceCtx(t *testing.T) { } func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { - // TODO fix Solana tests - t.Skip() - - t.Parallel() - ctx := tests.Context(t) // encode values from unmodified test struct to be read and decoded @@ -107,16 +101,20 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { require.NoError(t, svc.Close()) }) - client.SetNext(encoded, nil, 0) + pk := solana.NewWallet().PublicKey() + + client.SetForAddress(pk, encoded, nil, 0) var result modifiedStructWithNestedStruct binding := types.BoundContract{ Name: Namespace, - Address: "", + Address: pk.String(), } + require.NoError(t, svc.Bind(ctx, []types.BoundContract{binding})) require.NoError(t, svc.GetLatestValue(ctx, binding.ReadIdentifier(NamedMethod), primitives.Unconfirmed, nil, &result)) + assert.Equal(t, expected.InnerStruct, result.InnerStruct) assert.Equal(t, expected.Value, result.V) assert.Equal(t, expected.TimeVal, result.TimeVal) @@ -151,7 +149,11 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { } assert.NoError(t, svc.Bind(ctx, []types.BoundContract{binding})) - assert.ErrorIs(t, svc.GetLatestValue(ctx, binding.ReadIdentifier(NamedMethod), primitives.Unconfirmed, nil, &result), expectedErr) + + err = svc.GetLatestValue(ctx, binding.ReadIdentifier(NamedMethod), primitives.Unconfirmed, nil, &result) + + assert.Contains(t, err.Error(), chainreader.ErrMissingAccountData.Error()) + assert.ErrorIs(t, err, types.ErrInternal) }) t.Run("Method Not Found", func(t *testing.T) { @@ -196,33 +198,6 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { assert.NotNil(t, svc.GetLatestValue(ctx, types.BoundContract{Name: "Unknown"}.ReadIdentifier("Unknown"), primitives.Unconfirmed, nil, &result)) }) - t.Run("Bind Success", func(t *testing.T) { - t.Parallel() - - _, conf := newTestConfAndCodec(t) - - client := new(mockedRPCClient) - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) - - require.NoError(t, err) - require.NotNil(t, svc) - require.NoError(t, svc.Start(ctx)) - - t.Cleanup(func() { - require.NoError(t, svc.Close()) - }) - - pk := ag_solana.NewWallet().PublicKey() - err = svc.Bind(ctx, []types.BoundContract{ - { - Address: pk.String(), - Name: fmt.Sprintf("%s.%s.%d", Namespace, NamedMethod, 0), - }, - }) - - assert.NoError(t, err) - }) - t.Run("Bind Errors", func(t *testing.T) { t.Parallel() @@ -315,12 +290,10 @@ func newTestConfAndCodec(t *testing.T) (types.RemoteCodec, config.ChainReader) { Methods: map[string]config.ChainDataReader{ NamedMethod: { AnchorIDL: rawIDL, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: testutils.TestStructWithNestedStruct, - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}}, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: testutils.TestStructWithNestedStruct, + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}}, }, }, }, @@ -358,36 +331,21 @@ type mockedRPCClient struct { sequence []mockedRPCCall } -func (_m *mockedRPCClient) ReadAll(_ context.Context, pk ag_solana.PublicKey, _ *rpc.GetAccountInfoOpts) ([]byte, error) { - _m.mu.Lock() - defer _m.mu.Unlock() +func (_m *mockedRPCClient) GetMultipleAccountData(_ context.Context, keys ...solana.PublicKey) ([][]byte, error) { + result := make([][]byte, len(keys)) - if _m.responseByAddress == nil { - _m.responseByAddress = make(map[string]mockedRPCCall) - } + for idx, key := range keys { + call, ok := _m.responseByAddress[key.String()] + if !ok || call.err != nil { + result[idx] = nil - if resp, ok := _m.responseByAddress[pk.String()]; ok { - if resp.delay > 0 { - time.Sleep(resp.delay) + continue } - delete(_m.responseByAddress, pk.String()) - - return resp.bts, resp.err + result[idx] = call.bts } - if len(_m.sequence) == 0 { - return nil, errors.New(" no values to return") - } - - next := _m.sequence[0] - _m.sequence = _m.sequence[1:len(_m.sequence)] - - if next.delay > 0 { - time.Sleep(next.delay) - } - - return next.bts, next.err + return result, nil } func (_m *mockedRPCClient) SetNext(bts []byte, err error, delay time.Duration) { @@ -425,9 +383,13 @@ type chainReaderInterfaceTester struct { func (r *chainReaderInterfaceTester) GetAccountBytes(i int) []byte { account := [20]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + account[i%20] += byte(i) account[(i+3)%20] += byte(i + 3) - return account[:] + + pk := solana.PublicKeyFromBytes(account[:]) + + return pk.Bytes() } func (r *chainReaderInterfaceTester) GetAccountString(i int) string { @@ -456,64 +418,48 @@ func (r *chainReaderInterfaceTester) Setup(t *testing.T) { MethodTakingLatestParamsReturningTestStruct: { AnchorIDL: fullStructIDL(t), Encoding: config.EncodingTypeBorsh, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: "TestStructB", - RPCOpts: &config.RPCOpts{ - Encoding: &encodingBase64, - Commitment: &commitment, - DataSlice: &rpc.DataSlice{ - Offset: &offset, - Length: &length, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: "TestStruct", + RPCOpts: &config.RPCOpts{ + Encoding: &encodingBase64, + Commitment: &commitment, + DataSlice: &rpc.DataSlice{ + Offset: &offset, + Length: &length, }, }, - { - IDLAccount: "TestStructA", - }, }, }, MethodReturningUint64: { AnchorIDL: fmt.Sprintf(baseIDL, uint64BaseTypeIDL, ""), Encoding: config.EncodingTypeBorsh, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: "SimpleUint64Value", - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.PropertyExtractorConfig{FieldName: "I"}, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: "SimpleUint64Value", + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.PropertyExtractorConfig{FieldName: "I"}, }, }, }, MethodReturningUint64Slice: { AnchorIDL: fmt.Sprintf(baseIDL, uint64SliceBaseTypeIDL, ""), Encoding: config.EncodingTypeBincode, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: "Uint64Slice", - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.PropertyExtractorConfig{FieldName: "Vals"}, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: "Uint64Slice", + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.PropertyExtractorConfig{FieldName: "Vals"}, }, }, }, MethodReturningSeenStruct: { AnchorIDL: fullStructIDL(t), Encoding: config.EncodingTypeBorsh, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: "TestStructB", - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.AddressBytesToStringModifierConfig{ - Fields: []string{"Accountstruct.Accountstr"}, - }, - }, - }, - { - IDLAccount: "TestStructA", - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.HardCodeModifierConfig{OffChainValues: map[string]any{"ExtraField": AnyExtraValue}}, + Procedure: config.ChainReaderProcedure{ + IDLAccount: "TestStruct", + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.AddressBytesToStringModifierConfig{ + Fields: []string{"Accountstruct.Accountstr"}, }, + &codeccommon.HardCodeModifierConfig{OffChainValues: map[string]any{"ExtraField": AnyExtraValue}}, }, }, }, @@ -524,12 +470,10 @@ func (r *chainReaderInterfaceTester) Setup(t *testing.T) { MethodReturningUint64: { AnchorIDL: fmt.Sprintf(baseIDL, uint64BaseTypeIDL, ""), Encoding: config.EncodingTypeBorsh, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: "SimpleUint64Value", - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.PropertyExtractorConfig{FieldName: "I"}, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: "SimpleUint64Value", + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.PropertyExtractorConfig{FieldName: "I"}, }, }, }, @@ -600,9 +544,11 @@ func (r *wrappedTestChainReader) Name() string { func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentifier string, confidenceLevel primitives.ConfidenceLevel, params, returnVal any) error { var ( - a ag_solana.PublicKey - b ag_solana.PublicKey + bts []byte + acct int + err error ) + parts := strings.Split(readIdentifier, "-") if len(parts) < 3 { panic("unexpected readIdentifier length") @@ -611,6 +557,10 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif contractName := parts[1] method := parts[2] + if contractName == AnySecondContractName { + acct = 1 + } + switch contractName + method { case AnyContractName + EventName: r.test.Skip("Events are not yet supported in Solana") @@ -622,13 +572,11 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif I: AnyValueToReadWithoutAnArgument, } - bts, err := cdc.Encode(ctx, onChainStruct, "SimpleUint64Value") + bts, err = cdc.Encode(ctx, onChainStruct, "SimpleUint64Value") if err != nil { r.test.Log(err.Error()) r.test.FailNow() } - - r.client.SetNext(bts, nil, 0) case AnyContractName + MethodReturningUint64Slice: cdc := makeTestCodec(r.test, fmt.Sprintf(baseIDL, uint64SliceBaseTypeIDL, ""), config.EncodingTypeBincode) onChainStruct := struct { @@ -637,12 +585,10 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif Vals: AnySliceToReadWithoutAnArgument, } - bts, err := cdc.Encode(ctx, onChainStruct, "Uint64Slice") + bts, err = cdc.Encode(ctx, onChainStruct, "Uint64Slice") if err != nil { r.test.FailNow() } - - r.client.SetNext(bts, nil, 0) case AnySecondContractName + MethodReturningUint64, AnyContractName: cdc := makeTestCodec(r.test, fmt.Sprintf(baseIDL, uint64BaseTypeIDL, ""), config.EncodingTypeBorsh) onChainStruct := struct { @@ -651,37 +597,26 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif I: AnyDifferentValueToReadWithoutAnArgument, } - bts, err := cdc.Encode(ctx, onChainStruct, "SimpleUint64Value") + bts, err = cdc.Encode(ctx, onChainStruct, "SimpleUint64Value") if err != nil { r.test.FailNow() } - - r.client.SetNext(bts, nil, 0) case AnyContractName + MethodReturningSeenStruct: nextStruct := CreateTestStruct[*testing.T](0, r.tester) r.testStructQueue = append(r.testStructQueue, &nextStruct) - a, b = getAddresses(r.test, r.tester, AnyContractName, MethodReturningSeenStruct) - fallthrough default: - if len(r.testStructQueue) == 0 { r.test.FailNow() } - if contractName+method != AnyContractName+MethodReturningSeenStruct { - a, b = getAddresses(r.test, r.tester, AnyContractName, MethodTakingLatestParamsReturningTestStruct) - } - nextTestStruct := r.testStructQueue[0] r.testStructQueue = r.testStructQueue[1:len(r.testStructQueue)] // split into two encoded parts to test the preloading function cdc := makeTestCodec(r.test, fullStructIDL(r.test), config.EncodingTypeBorsh) - var bts []byte - var err error if strings.Contains(r.test.Name(), "wraps_config_with_modifiers_using_its_own_mapstructure_overrides") { // TODO: This is a temporary solution. We are manually retyping this struct to avoid breaking unrelated tests. // Once input modifiers are fully implemented, revisit this code and remove this manual struct conversion @@ -716,28 +651,20 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif NestedStaticStruct: nextTestStruct.NestedStaticStruct, } - bts, err = cdc.Encode(ctx, tempStruct, "TestStructB") + bts, err = cdc.Encode(ctx, tempStruct, "TestStruct") if err != nil { r.test.FailNow() } } else { - bts, err = cdc.Encode(ctx, nextTestStruct, "TestStructB") + bts, err = cdc.Encode(ctx, nextTestStruct, "TestStruct") if err != nil { r.test.FailNow() } } - - // make part A return slower than part B - r.client.SetForAddress(a, bts, nil, 300*time.Millisecond) - - bts, err = cdc.Encode(ctx, nextTestStruct, "TestStructA") - if err != nil { - r.test.FailNow() - } - - r.client.SetForAddress(b, bts, nil, 50*time.Millisecond) } + r.client.SetForAddress(ag_solana.PublicKey(r.tester.GetAccountBytes(acct)), bts, nil, 0) + return r.service.GetLatestValue(ctx, readIdentifier, confidenceLevel, params, returnVal) } @@ -849,28 +776,9 @@ func (r *chainReaderInterfaceTester) TriggerEvent(t *testing.T, testStruct *Test } func (r *chainReaderInterfaceTester) GetBindings(t *testing.T) []types.BoundContract { - mainContractMethods := map[string][]string{ - MethodTakingLatestParamsReturningTestStruct: {r.address[0], r.address[1]}, - MethodReturningUint64: {r.address[2]}, - MethodReturningUint64Slice: {r.address[3]}, - MethodReturningSeenStruct: {r.address[4], r.address[5]}, - } - - addrBts, err := json.Marshal(mainContractMethods) - if err != nil { - t.Log(err.Error()) - t.FailNow() - } - - secondAddrBts, err := json.Marshal(map[string][]string{MethodReturningUint64: {r.address[6]}}) - if err != nil { - t.Log(err.Error()) - t.FailNow() - } - return []types.BoundContract{ - {Name: AnyContractName, Address: base64.StdEncoding.EncodeToString(addrBts)}, - {Name: AnySecondContractName, Address: base64.StdEncoding.EncodeToString(secondAddrBts)}, + {Name: AnyContractName, Address: solana.PublicKeyFromBytes(r.GetAccountBytes(0)).String()}, + {Name: AnySecondContractName, Address: solana.PublicKeyFromBytes(r.GetAccountBytes(1)).String()}, } } @@ -912,7 +820,7 @@ func fullStructIDL(t *testing.T) string { return fmt.Sprintf( baseIDL, - strings.Join([]string{testStructAIDL, testStructBIDL}, ","), + testStructIDL, strings.Join([]string{midLevelDynamicStructIDL, midLevelStaticStructIDL, innerDynamicStructIDL, innerStaticStructIDL, accountStructIDL}, ","), ) } @@ -925,8 +833,8 @@ const ( "types": [%s] }` - testStructAIDL = `{ - "name": "TestStructA", + testStructIDL = `{ + "name": "TestStruct", "type": { "kind": "struct", "fields": [ @@ -934,20 +842,12 @@ const ( {"name": "differentField","type": "string"}, {"name": "bigField","type": "i128"}, {"name": "nestedDynamicStruct","type": {"defined": "MidLevelDynamicStruct"}}, - {"name": "nestedStaticStruct","type": {"defined": "MidLevelStaticStruct"}} - ] - } - }` - - testStructBIDL = `{ - "name": "TestStructB", - "type": { - "kind": "struct", - "fields": [ + {"name": "nestedStaticStruct","type": {"defined": "MidLevelStaticStruct"}}, {"name": "oracleID","type": "u8"}, {"name": "oracleIDs","type": {"array": ["u8",32]}}, {"name": "accountstruct","type": {"defined": "accountstruct"}}, {"name": "accounts","type": {"vec": "bytes"}} + ] } }` diff --git a/pkg/solana/config/chain_reader.go b/pkg/solana/config/chain_reader.go index a1fed147d..3fe9c771f 100644 --- a/pkg/solana/config/chain_reader.go +++ b/pkg/solana/config/chain_reader.go @@ -25,8 +25,8 @@ type ChainDataReader struct { AnchorIDL string `json:"anchorIDL" toml:"anchorIDL"` // Encoding defines the type of encoding used for on-chain data. Currently supported // are 'borsh' and 'bincode'. - Encoding EncodingType `json:"encoding" toml:"encoding"` - Procedures []ChainReaderProcedure `json:"procedures" toml:"procedures"` + Encoding EncodingType `json:"encoding" toml:"encoding"` + Procedure ChainReaderProcedure `json:"procedures" toml:"procedures"` } type EncodingType int diff --git a/pkg/solana/config/chain_reader_test.go b/pkg/solana/config/chain_reader_test.go index b0ad49181..7d290b50c 100644 --- a/pkg/solana/config/chain_reader_test.go +++ b/pkg/solana/config/chain_reader_test.go @@ -90,28 +90,24 @@ var validChainReaderConfig = config.ChainReader{ "Method": { AnchorIDL: "test idl 1", Encoding: config.EncodingTypeBorsh, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: testutils.TestStructWithNestedStruct, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: testutils.TestStructWithNestedStruct, }, }, "MethodWithOpts": { AnchorIDL: "test idl 2", Encoding: config.EncodingTypeBorsh, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: testutils.TestStructWithNestedStruct, - OutputModifications: codeccommon.ModifiersConfig{ - &codeccommon.PropertyExtractorConfig{FieldName: "DurationVal"}, - }, - RPCOpts: &config.RPCOpts{ - Encoding: &encodingBase64, - Commitment: &commitment, - DataSlice: &rpc.DataSlice{ - Offset: &offset, - Length: &length, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: testutils.TestStructWithNestedStruct, + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.PropertyExtractorConfig{FieldName: "DurationVal"}, + }, + RPCOpts: &config.RPCOpts{ + Encoding: &encodingBase64, + Commitment: &commitment, + DataSlice: &rpc.DataSlice{ + Offset: &offset, + Length: &length, }, }, }, @@ -123,10 +119,8 @@ var validChainReaderConfig = config.ChainReader{ "Method": { AnchorIDL: "test idl 3", Encoding: config.EncodingTypeBincode, - Procedures: []config.ChainReaderProcedure{ - { - IDLAccount: testutils.TestStructWithNestedStruct, - }, + Procedure: config.ChainReaderProcedure{ + IDLAccount: testutils.TestStructWithNestedStruct, }, }, },