diff --git a/ecosystem/cache/cache_test.go b/ecosystem/cache/cache_test.go index 0500c4344d..6c1620038f 100644 --- a/ecosystem/cache/cache_test.go +++ b/ecosystem/cache/cache_test.go @@ -444,12 +444,15 @@ func TestCacheSetGetLatestWhenAdvancingLatest(t *testing.T) { func TestCacheSetGetJsonRPCWithID(t *testing.T) { t.Parallel() tests := []struct { - name string - valid bool - delay time.Duration - finalized bool - hash []byte + name string + valid bool + delay time.Duration + finalized bool + hash []byte + nullIdInGet bool + nullIdInSet bool }{ + // No null ID {name: "Finalized No Hash", valid: true, delay: time.Millisecond, finalized: true, hash: nil}, {name: "Finalized After delay No Hash", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil}, {name: "NonFinalized No Hash", valid: true, delay: time.Millisecond, finalized: false, hash: nil}, @@ -458,6 +461,50 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { {name: "Finalized After delay With Hash", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: []byte{1, 2, 3}}, {name: "NonFinalized With Hash", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}}, {name: "NonFinalized After delay With Hash", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}}, + + // Null ID in get and set + {name: "Finalized No Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: true, hash: nil, nullIdInGet: true, nullIdInSet: true}, + {name: "Finalized After delay No Hash, with null id in get and set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil, nullIdInGet: true, nullIdInSet: true}, + {name: "NonFinalized No Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: false, hash: nil, nullIdInGet: true, nullIdInSet: true}, + {name: "NonFinalized After delay No Hash", valid: false, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: nil, nullIdInGet: true, nullIdInSet: true}, + {name: "Finalized With Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: true, hash: []byte{1, 2, 3}, nullIdInGet: true, nullIdInSet: true}, + {name: "Finalized After delay With Hash, with null id in get and set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: []byte{1, 2, 3}, nullIdInGet: true, nullIdInSet: true}, + {name: "NonFinalized With Hash, with null id in get and set", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true, nullIdInSet: true}, + {name: "NonFinalized After delay With Hash, with null id in get and set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true, nullIdInSet: true}, + + // Null ID only in get + {name: "Finalized No Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: true, hash: nil, nullIdInGet: true}, + {name: "Finalized After delay No Hash, with null id only in get", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil, nullIdInGet: true}, + {name: "NonFinalized No Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: false, hash: nil, nullIdInGet: true}, + {name: "NonFinalized After delay No Hash", valid: false, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: nil, nullIdInGet: true}, + {name: "Finalized With Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: true, hash: []byte{1, 2, 3}, nullIdInGet: true}, + {name: "Finalized After delay With Hash, with null id only in get", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: []byte{1, 2, 3}, nullIdInGet: true}, + {name: "NonFinalized With Hash, with null id only in get", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true}, + {name: "NonFinalized After delay With Hash, with null id only in get", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInGet: true}, + + // Null ID only in set + {name: "Finalized No Hash, with null id only in set", valid: true, delay: time.Millisecond, finalized: true, hash: nil, nullIdInSet: true}, + {name: "Finalized After delay No Hash, with null id only in set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: nil, nullIdInSet: true}, + {name: "NonFinalized No Hash, with null id only in set", valid: true, delay: time.Millisecond, finalized: false, hash: nil, nullIdInSet: true}, + {name: "NonFinalized After delay No Hash only in set", valid: false, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: nil, nullIdInSet: true}, + {name: "Finalized With Hash, with null id only in set", valid: true, delay: time.Millisecond, finalized: true, hash: []byte{1, 2, 3}, nullIdInSet: true}, + {name: "Finalized After delay With Hash, with null id only in set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: true, hash: []byte{1, 2, 3}, nullIdInSet: true}, + {name: "NonFinalized With Hash, with null id only in set", valid: true, delay: time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInSet: true}, + {name: "NonFinalized After delay With Hash, with null id only in set", valid: true, delay: cache.DefaultExpirationForNonFinalized + time.Millisecond, finalized: false, hash: []byte{1, 2, 3}, nullIdInSet: true}, + } + + formatIDInJson := func(idNum int64, nullId bool) []byte { + if nullId { + return []byte(`{"jsonrpc":"2.0","method":"status","params":[],"id":null}`) + } + return []byte(fmt.Sprintf(`{"jsonrpc":"2.0","method":"status","params":[],"id":%d}`, idNum)) + } + + formatIDInJsonResponse := func(idNum int64, nullId bool) []byte { + if nullId { + return []byte(`{"jsonrpc":"2.0","result":0x12345,"id":null}`) + } + return []byte(fmt.Sprintf(`{"jsonrpc":"2.0","result":0x12345,"id":%d}`, idNum)) } for _, tt := range tests { @@ -465,17 +512,12 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { ctx, cacheServer := initTest() id := rand.Int63() - formatIDInJson := func(idNum int64) []byte { - return []byte(fmt.Sprintf(`{"jsonrpc":"2.0","method":"status","params":[],"id":%d}`, idNum)) - } - formatIDInJsonResponse := func(idNum int64) []byte { - return []byte(fmt.Sprintf(`{"jsonrpc":"2.0","result":0x12345,"id":%d}`, idNum)) - } - request := getRequest(1230, formatIDInJson(id), spectypes.APIInterfaceJsonRPC) // &pairingtypes.RelayRequest{ + request := getRequest(1230, formatIDInJson(id, tt.nullIdInSet), spectypes.APIInterfaceJsonRPC) // &pairingtypes.RelayRequest{ response := &pairingtypes.RelayReply{ - Data: formatIDInJsonResponse(id), // response has the old id when cached + Data: formatIDInJsonResponse(id, tt.nullIdInSet), // response has the old id when cached } + messageSet := pairingtypes.RelayCacheSet{ RequestHash: HashRequest(t, request, StubChainID), BlockHash: tt.hash, @@ -492,9 +534,9 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { time.Sleep(tt.delay) // now to get it - changedID := id + 1 // now we change the ID: - request.Data = formatIDInJson(changedID) + changedID := id + 1 + request.Data = formatIDInJson(changedID, tt.nullIdInGet) hash, outputFormatter := HashRequestFormatter(t, request, StubChainID) messageGet := pairingtypes.RelayCacheGet{ RequestHash: hash, @@ -503,13 +545,20 @@ func TestCacheSetGetJsonRPCWithID(t *testing.T) { Finalized: tt.finalized, RequestedBlock: request.RequestBlock, } + cacheReply, err := cacheServer.GetRelay(ctx, &messageGet) if tt.valid { cacheReply.Reply.Data = outputFormatter(cacheReply.Reply.Data) require.NoError(t, err) + result := gjson.GetBytes(cacheReply.GetReply().Data, format.IDFieldName) extractedID := result.Raw - require.Equal(t, strconv.FormatInt(changedID, 10), extractedID) + + if tt.nullIdInGet { + require.Equal(t, "null", extractedID) + } else { + require.Equal(t, strconv.FormatInt(changedID, 10), extractedID) + } } else { require.Error(t, err) } diff --git a/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go b/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go index 0e6f3d6b59..efa93c30bf 100644 --- a/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go +++ b/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go @@ -104,6 +104,8 @@ func IdFromRawMessage(rawID json.RawMessage) (jsonrpcId, error) { case float64: // json.Unmarshal uses float64 for all numbers return JSONRPCIntID(int(id)), nil + case nil: + return jsonrpcId(nil), nil default: typ := reflect.TypeOf(id) return nil, utils.LavaFormatError("failed to unmarshal id not a string or float", err, []utils.Attribute{{Key: "id", Value: string(rawID)}, {Key: "id type", Value: typ}}...) diff --git a/protocol/integration/protocol_test.go b/protocol/integration/protocol_test.go index 0b12b6f22e..02dcf6b2c0 100644 --- a/protocol/integration/protocol_test.go +++ b/protocol/integration/protocol_test.go @@ -1,6 +1,7 @@ package integration_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -11,6 +12,7 @@ import ( "time" "github.com/lavanet/lava/protocol/chainlib" + "github.com/lavanet/lava/protocol/chainlib/chainproxy/rpcInterfaceMessages" "github.com/lavanet/lava/protocol/chaintracker" "github.com/lavanet/lava/protocol/common" "github.com/lavanet/lava/protocol/lavaprotocol" @@ -189,6 +191,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string w.WriteHeader(status) fmt.Fprint(w, string(data)) }) + chainParser, chainRouter, chainFetcher, _, endpoint, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, "../../", addons) require.NoError(t, err) require.NotNil(t, chainParser) @@ -240,7 +243,8 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string ConsistencyCallback: nil, Pmetrics: nil, } - mockChainFetcher := NewMockChainFetcher(1000, 10, nil) + + mockChainFetcher := NewMockChainFetcher(1000, int64(blocksToSaveChainTracker), nil) chainTracker, err := chaintracker.NewChainTracker(ctx, mockChainFetcher, chainTrackerConfig) require.NoError(t, err) reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser) @@ -550,3 +554,114 @@ func TestConsumerProviderTx(t *testing.T) { }) } } + +func TestConsumerProviderJsonRpcWithNullID(t *testing.T) { + playbook := []struct { + name string + specId string + method string + expected string + apiInterface string + }{ + { + name: "jsonrpc", + specId: "ETH1", + method: "eth_blockNumber", + expected: `{"jsonrpc":"2.0","id":null,"result":{}}`, + apiInterface: spectypes.APIInterfaceJsonRPC, + }, + { + name: "tendermintrpc", + specId: "LAV1", + method: "status", + expected: `{"jsonrpc":"2.0","result":{}}`, + apiInterface: spectypes.APIInterfaceTendermintRPC, + }, + } + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ctx := context.Background() + // can be any spec and api interface + specId := play.specId + apiInterface := play.apiInterface + epoch := uint64(100) + requiredResponses := 1 + lavaChainID := "lava" + numProviders := 5 + + consumerListenAddress := addressGen.GetAddress() + pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{} + type providerData struct { + account sigs.Account + endpoint *lavasession.RPCProviderEndpoint + server *rpcprovider.RPCProviderServer + replySetter *ReplySetter + mockChainFetcher *MockChainFetcher + } + providers := []providerData{} + + for i := 0; i < numProviders; i++ { + // providerListenAddress := "localhost:111" + strconv.Itoa(i) + account := sigs.GenerateDeterministicFloatingKey(randomizer) + providerDataI := providerData{account: account} + providers = append(providers, providerDataI) + } + consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer) + for i := 0; i < numProviders; i++ { + ctx := context.Background() + providerDataI := providers[i] + listenAddress := addressGen.GetAddress() + providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil)) + providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1)) + } + for i := 0; i < numProviders; i++ { + pairingList[uint64(i)] = &lavasession.ConsumerSessionsWithProvider{ + PublicLavaAddress: providers[i].account.Addr.String(), + Endpoints: []*lavasession.Endpoint{ + { + NetworkAddress: providers[i].endpoint.NetworkAddress.Address, + Enabled: true, + Geolocation: 1, + }, + }, + Sessions: map[int64]*lavasession.SingleConsumerSession{}, + MaxComputeUnits: 10000, + UsedComputeUnits: 0, + PairingEpoch: epoch, + } + } + rpcconsumerServer := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID) + require.NotNil(t, rpcconsumerServer) + + for i := 0; i < numProviders; i++ { + handler := func(req []byte, header http.Header) (data []byte, status int) { + var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage + err := json.Unmarshal(req, &jsonRpcMessage) + require.NoError(t, err) + + response := fmt.Sprintf(`{"jsonrpc":"2.0","result": {}, "id": %v}`, string(jsonRpcMessage.ID)) + return []byte(response), http.StatusOK + } + providers[i].replySetter.handler = handler + } + + client := http.Client{Timeout: 500 * time.Millisecond} + jsonMsg := fmt.Sprintf(`{"jsonrpc":"2.0","method":"%v","params": [], "id":null}`, play.method) + msgBuffer := bytes.NewBuffer([]byte(jsonMsg)) + req, err := http.NewRequest(http.MethodPost, "http://"+consumerListenAddress, msgBuffer) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + + bodyBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, string(bodyBytes)) + + resp.Body.Close() + + require.Equal(t, play.expected, string(bodyBytes)) + }) + } +}