From db8248cdb5e533232c861ff6bd38d6c35acbe818 Mon Sep 17 00:00:00 2001
From: Elad Gildnur <6321801+shleikes@users.noreply.github.com>
Date: Tue, 24 Sep 2024 13:08:03 +0300
Subject: [PATCH 1/6] feat: PRT - Add cache to protocol integration tests
 (#1708)

* Add golang/mock as a direct requirement in go.mod
* Move all createRpcConsumer function arguments to a struct
* Small fix to one of the integration tests
* Add option to add cache for consumer in integration tests
* Add consumer with cache integration test
* Small code clean for the test
* Move all createRpcProvider function arguments to a struct
* Small fixes to TestConsumerProviderWithConsumerSideCache
* Add TestConsumerProviderWithProviderSideCache
* Small fix to cache server
* Disable metrics server for cache in tests
* Add TestConsumerProviderWithConsumerAndProviderSideCache
* Attempt to fix test
* Small fix to a different test
---
 ecosystem/cache/server.go             |  10 +-
 go.mod                                |   2 +-
 protocol/integration/protocol_test.go | 704 ++++++++++++++++++++++++--
 3 files changed, 656 insertions(+), 60 deletions(-)

diff --git a/ecosystem/cache/server.go b/ecosystem/cache/server.go
index 9f3f8df1dc..fd867b303c 100644
--- a/ecosystem/cache/server.go
+++ b/ecosystem/cache/server.go
@@ -107,7 +107,7 @@ func (cs *CacheServer) Serve(ctx context.Context,
 	if strings.HasPrefix(listenAddr, unixPrefix) { // Unix socket
 		host, port, err := net.SplitHostPort(listenAddr)
 		if err != nil {
-			utils.LavaFormatFatal("Failed to parse unix socket, provide address in this format unix:/tmp/example.sock: %v\n", err)
+			utils.LavaFormatFatal("Failed to parse unix socket, provide address in this format unix:/tmp/example.sock", err)
 			return
 		}
@@ -115,26 +115,26 @@ func (cs *CacheServer) Serve(ctx context.Context,
 		addr, err := net.ResolveUnixAddr(host, port)
 		if err != nil {
-			utils.LavaFormatFatal("Failed to resolve unix socket address: %v\n", err)
+			utils.LavaFormatFatal("Failed to resolve unix socket address", err)
 			return
 		}
 
 		lis, err = net.ListenUnix(host, addr)
 		if err != nil {
-			utils.LavaFormatFatal("Faild to listen to unix socket listener: %v\n", err)
+			utils.LavaFormatFatal("Failed to listen to unix socket listener", err)
 			return
 		}
 
 		// Set permissions for the Unix socket
 		err = os.Chmod(port, 0o600)
 		if err != nil {
-			utils.LavaFormatFatal("Failed to set permissions for Unix socket: %v\n", err)
+			utils.LavaFormatFatal("Failed to set permissions for Unix socket", err)
 			return
 		}
 	} else {
 		lis, err = net.Listen("tcp", listenAddr)
 		if err != nil {
-			utils.LavaFormatFatal("Cache server failure setting up TCP listener: %v\n", err)
+			utils.LavaFormatFatal("Cache server failure setting up TCP listener", err)
 			return
 		}
 	}
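All four server.go hunks make the same correction: utils.LavaFormatFatal is a structured logger, not a printf-style one, so the error travels as a dedicated argument instead of being rendered through a dangling "%v\n" verb that was never substituted. A minimal sketch of the calling convention, assuming the signature utils.LavaFormatFatal(description string, err error, attributes ...utils.Attribute) used elsewhere in this codebase:

	// Sketch only: structured fatal logging as used in the Lava codebase.
	// The description stays a fixed string; the error and any key/value
	// attributes are passed as separate, structured arguments.
	utils.LavaFormatFatal("Cache server failure setting up TCP listener", err,
		utils.Attribute{Key: "listenAddr", Value: listenAddr},
	)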
diff --git a/go.mod b/go.mod
index 67e086b80a..82fc25f95b 100644
--- a/go.mod
+++ b/go.mod
@@ -37,6 +37,7 @@ require (
 	github.com/fullstorydev/grpcurl v1.8.5
 	github.com/goccy/go-json v0.10.2
 	github.com/gogo/status v1.1.0
+	github.com/golang/mock v1.6.0
 	github.com/golang/protobuf v1.5.4
 	github.com/itchyny/gojq v0.12.16
 	github.com/jhump/protoreflect v1.15.1
@@ -85,7 +86,6 @@ require (
 	github.com/go-logr/stdr v1.2.2 // indirect
 	github.com/gogo/googleapis v1.4.1 // indirect
 	github.com/golang/glog v1.2.0 // indirect
-	github.com/golang/mock v1.6.0 // indirect
 	github.com/google/flatbuffers v1.12.1 // indirect
 	github.com/google/go-cmp v0.6.0 // indirect
 	github.com/google/s2a-go v0.1.7 // indirect
diff --git a/protocol/integration/protocol_test.go b/protocol/integration/protocol_test.go
index 048bafedda..33a69599f1 100644
--- a/protocol/integration/protocol_test.go
+++ b/protocol/integration/protocol_test.go
@@ -10,11 +10,13 @@ import (
 	"net/url"
 	"os"
 	"strconv"
+	"strings"
 	"sync"
 	"testing"
 	"time"
 
 	"github.com/gorilla/websocket"
+	"github.com/lavanet/lava/v3/ecosystem/cache"
 	"github.com/lavanet/lava/v3/protocol/chainlib"
 	"github.com/lavanet/lava/v3/protocol/chainlib/chainproxy/rpcInterfaceMessages"
 	"github.com/lavanet/lava/v3/protocol/chaintracker"
@@ -22,6 +24,7 @@ import (
 	"github.com/lavanet/lava/v3/protocol/lavaprotocol/finalizationconsensus"
 	"github.com/lavanet/lava/v3/protocol/lavasession"
 	"github.com/lavanet/lava/v3/protocol/metrics"
+	"github.com/lavanet/lava/v3/protocol/performance"
 	"github.com/lavanet/lava/v3/protocol/provideroptimizer"
 	"github.com/lavanet/lava/v3/protocol/rpcconsumer"
 	"github.com/lavanet/lava/v3/protocol/rpcprovider"
@@ -162,21 +165,33 @@ func createInMemoryRewardDb(specs []string) (*rewardserver.RewardDB, error) {
 	return rewardDB, nil
 }
 
-func createRpcConsumer(t *testing.T, ctx context.Context, specId string, apiInterface string, account sigs.Account, consumerListenAddress string, epoch uint64, pairingList map[uint64]*lavasession.ConsumerSessionsWithProvider, requiredResponses int, lavaChainID string) (*rpcconsumer.RPCConsumerServer, *mockConsumerStateTracker) {
+type rpcConsumerOptions struct {
+	specId                string
+	apiInterface          string
+	account               sigs.Account
+	consumerListenAddress string
+	epoch                 uint64
+	pairingList           map[uint64]*lavasession.ConsumerSessionsWithProvider
+	requiredResponses     int
+	lavaChainID           string
+	cacheListenAddress    string
+}
+
+func createRpcConsumer(t *testing.T, ctx context.Context, rpcConsumerOptions rpcConsumerOptions) (*rpcconsumer.RPCConsumerServer, *mockConsumerStateTracker) {
 	serverHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		// Handle the incoming request and provide the desired response
 		w.WriteHeader(http.StatusOK)
 	})
-	chainParser, _, chainFetcher, _, _, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, nil, "../../", nil)
+	chainParser, _, chainFetcher, _, _, err := chainlib.CreateChainLibMocks(ctx, rpcConsumerOptions.specId, rpcConsumerOptions.apiInterface, serverHandler, nil, "../../", nil)
 	require.NoError(t, err)
 	require.NotNil(t, chainParser)
 	require.NotNil(t, chainFetcher)
 
 	rpcConsumerServer := &rpcconsumer.RPCConsumerServer{}
 	rpcEndpoint := &lavasession.RPCEndpoint{
-		NetworkAddress:  consumerListenAddress,
-		ChainID:         specId,
-		ApiInterface:    apiInterface,
+		NetworkAddress:  rpcConsumerOptions.consumerListenAddress,
+		ChainID:         rpcConsumerOptions.specId,
+		ApiInterface:    rpcConsumerOptions.apiInterface,
 		TLSEnabled:      false,
 		HealthCheckPath: "",
 		Geolocation:     1,
@@ -187,13 +202,21 @@ func createRpcConsumer(t *testing.T, ctx context.Context, specId string, apiInte
 	baseLatency := common.AverageWorldLatency / 2
 	optimizer := provideroptimizer.NewProviderOptimizer(provideroptimizer.STRATEGY_BALANCED, averageBlockTime, baseLatency, 2)
 	consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, nil, nil, "test", lavasession.NewActiveSubscriptionProvidersStorage())
-	consumerSessionManager.UpdateAllProviders(epoch, pairingList)
+	consumerSessionManager.UpdateAllProviders(rpcConsumerOptions.epoch, rpcConsumerOptions.pairingList)
+
+	var cache *performance.Cache = nil
+	if rpcConsumerOptions.cacheListenAddress != "" {
+		cache, err = performance.InitCache(ctx, rpcConsumerOptions.cacheListenAddress)
+		if err != nil {
+			t.Fatalf("failed to connect to cache at address %s: %v", rpcConsumerOptions.cacheListenAddress, err)
+		}
+	}
 
-	consumerConsistency := rpcconsumer.NewConsumerConsistency(specId)
+	consumerConsistency := rpcconsumer.NewConsumerConsistency(rpcConsumerOptions.specId)
 	consumerCmdFlags := common.ConsumerCmdFlags{}
 	rpcconsumerLogs, err := metrics.NewRPCConsumerLogs(nil, nil)
 	require.NoError(t, err)
-	err = rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, requiredResponses, account.SK, lavaChainID, nil, rpcconsumerLogs, account.Addr, consumerConsistency, nil, consumerCmdFlags, false, nil, nil, nil)
+	err = rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, rpcConsumerOptions.requiredResponses, rpcConsumerOptions.account.SK, rpcConsumerOptions.lavaChainID, cache, rpcconsumerLogs, rpcConsumerOptions.account.Addr, consumerConsistency, nil, consumerCmdFlags, false, nil, nil, nil)
 	require.NoError(t, err)
 
 	// wait for consumer to finish initialization
@@ -214,7 +237,19 @@ func createRpcConsumer(t *testing.T, ctx context.Context, specId string, apiInte
 	return rpcConsumerServer, consumerStateTracker
 }
 
-func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string, specId string, apiInterface string, listenAddress string, account sigs.Account, lavaChainID string, addons []string, providerUniqueId string) (*rpcprovider.RPCProviderServer, *lavasession.RPCProviderEndpoint, *ReplySetter, *MockChainFetcher, *MockReliabilityManager) {
+type rpcProviderOptions struct {
+	consumerAddress    string
+	specId             string
+	apiInterface       string
+	listenAddress      string
+	account            sigs.Account
+	lavaChainID        string
+	addons             []string
+	providerUniqueId   string
+	cacheListenAddress string
+}
+
+func createRpcProvider(t *testing.T, ctx context.Context, rpcProviderOptions rpcProviderOptions) (*rpcprovider.RPCProviderServer, *lavasession.RPCProviderEndpoint, *ReplySetter, *MockChainFetcher, *MockReliabilityManager) {
 	replySetter := ReplySetter{
 		status:       http.StatusOK,
 		replyDataBuf: []byte(`{"reply": "REPLY-STUB"}`),
@@ -234,16 +269,16 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
 		fmt.Fprint(w, string(data))
 	})
 
-	chainParser, chainRouter, chainFetcher, _, endpoint, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, nil, "../../", addons)
+	chainParser, chainRouter, chainFetcher, _, endpoint, err := chainlib.CreateChainLibMocks(ctx, rpcProviderOptions.specId, rpcProviderOptions.apiInterface, serverHandler, nil, "../../", rpcProviderOptions.addons)
 	require.NoError(t, err)
 	require.NotNil(t, chainParser)
 	require.NotNil(t, chainFetcher)
 	require.NotNil(t, chainRouter)
 
-	endpoint.NetworkAddress.Address = listenAddress
+	endpoint.NetworkAddress.Address = rpcProviderOptions.listenAddress
 
 	rpcProviderServer := &rpcprovider.RPCProviderServer{}
-	if providerUniqueId != "" {
-		rpcProviderServer.SetProviderUniqueId(providerUniqueId)
+	if rpcProviderOptions.providerUniqueId != "" {
+		rpcProviderServer.SetProviderUniqueId(rpcProviderOptions.providerUniqueId)
 	}
 
 	rpcProviderEndpoint := &lavasession.RPCProviderEndpoint{
@@ -253,8 +288,8 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
 			CertPem:    "",
 			DisableTLS: false,
 		},
-		ChainID:      specId,
-		ApiInterface: apiInterface,
+		ChainID:      rpcProviderOptions.specId,
+		ApiInterface: rpcProviderOptions.apiInterface,
 		Geolocation:  1,
 		NodeUrls: []common.NodeUrl{
 			{
@@ -263,22 +298,22 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
 				AuthConfig:        common.AuthConfig{},
 				IpForwarding:      false,
 				Timeout:           0,
-				Addons:            addons,
+				Addons:            rpcProviderOptions.addons,
 				SkipVerifications: []string{},
 			},
 		},
 	}
-	rewardDB, err := createInMemoryRewardDb([]string{specId})
+	rewardDB, err := createInMemoryRewardDb([]string{rpcProviderOptions.specId})
 	require.NoError(t, err)
 
 	_, averageBlockTime, blocksToFinalization, blocksInFinalizationData := chainParser.ChainBlockStats()
-	mockProviderStateTracker := mockProviderStateTracker{consumerAddressForPairing: consumerAddress, averageBlockTime: averageBlockTime}
+	mockProviderStateTracker := mockProviderStateTracker{consumerAddressForPairing: rpcProviderOptions.consumerAddress, averageBlockTime: averageBlockTime}
 	rws := rewardserver.NewRewardServer(&mockProviderStateTracker, nil, rewardDB, "badger_test", 1, 10, nil)
 
 	blockMemorySize, err := mockProviderStateTracker.GetEpochSizeMultipliedByRecommendedEpochNumToCollectPayment(ctx)
 	require.NoError(t, err)
 	providerSessionManager := lavasession.NewProviderSessionManager(rpcProviderEndpoint, blockMemorySize)
 	providerPolicy := rpcprovider.GetAllAddonsAndExtensionsFromNodeUrlSlice(rpcProviderEndpoint.NodeUrls)
-	chainParser.SetPolicy(providerPolicy, specId, apiInterface)
+	chainParser.SetPolicy(providerPolicy, rpcProviderOptions.specId, rpcProviderOptions.apiInterface)
 
 	blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData)
 	chainTrackerConfig := chaintracker.ChainTrackerConfig{
@@ -290,13 +325,21 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
 		Pmetrics: nil,
 	}
 
+	var cache *performance.Cache = nil
+	if rpcProviderOptions.cacheListenAddress != "" {
+		cache, err = performance.InitCache(ctx, rpcProviderOptions.cacheListenAddress)
+		if err != nil {
+			t.Fatalf("failed to connect to cache at address %s: %v", rpcProviderOptions.cacheListenAddress, err)
+		}
+	}
+
 	mockChainFetcher := NewMockChainFetcher(1000, int64(blocksToSaveChainTracker), nil)
 	chainTracker, err := chaintracker.NewChainTracker(ctx, mockChainFetcher, chainTrackerConfig)
 	require.NoError(t, err)
 	chainTracker.StartAndServe(ctx)
-	reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser)
+	reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, rpcProviderOptions.account.Addr.String(), chainRouter, chainParser)
 	mockReliabilityManager := NewMockReliabilityManager(reliabilityManager)
-	rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false)
+	rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, rpcProviderOptions.account.SK, cache, chainRouter, &mockProviderStateTracker, rpcProviderOptions.account.Addr, rpcProviderOptions.lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false)
 	listener := rpcprovider.NewProviderListener(ctx, rpcProviderEndpoint.NetworkAddress, "/health")
 	err = listener.RegisterReceiver(rpcProviderServer, rpcProviderEndpoint)
 	require.NoError(t, err)
@@ -307,13 +350,22 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
 	return rpcProviderServer, endpoint, &replySetter, mockChainFetcher, mockReliabilityManager
 }
 
+func createCacheServer(t *testing.T, ctx context.Context, listenAddress string) {
+	go func() {
+		cs := cache.CacheServer{CacheMaxCost: 2 * 1024 * 1024 * 1024} // taken from max-items default value
+		cs.InitCache(ctx, cache.DefaultExpirationTimeFinalized, cache.DefaultExpirationForNonFinalized, cache.DefaultExpirationNodeErrors, cache.DefaultExpirationBlocksHashesToHeights, "disabled", cache.DefaultExpirationTimeFinalizedMultiplier, cache.DefaultExpirationTimeNonFinalizedMultiplier)
+		cs.Serve(ctx, listenAddress)
+	}()
+	cacheServerUp := checkServerStatusWithTimeout("http://"+listenAddress, time.Second*7)
+	require.True(t, cacheServerUp)
+}
+
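createCacheServer boots the real cache service on a goroutine (with its metrics server "disabled", per the commit message) and then blocks until the cache answers, so the tests never race its startup. checkServerStatusWithTimeout is an existing helper in this test package that the patch does not show; a plausible, purely illustrative sketch of such a readiness poll, with the poll interval being an assumption:

	// Hypothetical sketch of the readiness helper used above: poll the given
	// URL until it responds or the timeout elapses.
	func checkServerStatusWithTimeout(url string, timeout time.Duration) bool {
		client := http.Client{Timeout: 100 * time.Millisecond} // assumed per-request timeout
		deadline := time.Now().Add(timeout)
		for time.Now().Before(deadline) {
			if resp, err := client.Get(url); err == nil {
				resp.Body.Close()
				return true
			}
			time.Sleep(100 * time.Millisecond) // assumed poll interval
		}
		return false
	}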
 func TestConsumerProviderBasic(t *testing.T) {
 	ctx := context.Background()
 	// can be any spec and api interface
 	specId := "LAV1"
-	apiInterface := spectypes.APIInterfaceTendermintRPC
+	apiInterface := spectypes.APIInterfaceRest
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := 1
@@ -339,7 +391,17 @@ func TestConsumerProviderBasic(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, rpcProviderOptions)
 	}
 	for i := 0; i < numProviders; i++ {
 		pairingList[uint64(i)] = &lavasession.ConsumerSessionsWithProvider{
@@ -357,10 +419,21 @@ func TestConsumerProviderBasic(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 	client := http.Client{}
-	resp, err := client.Get("http://" + consumerListenAddress + "/status")
+	resp, err := client.Get("http://" + consumerListenAddress + "/cosmos/base/tendermint/v1beta1/blocks/latest")
 	require.NoError(t, err)
 	require.Equal(t, http.StatusOK, resp.StatusCode)
 	bodyBytes, err := io.ReadAll(resp.Body)
@@ -390,7 +463,6 @@ func TestConsumerProviderWithProviders(t *testing.T) {
 	specId := "LAV1"
 	apiInterface := spectypes.APIInterfaceTendermintRPC
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := 5
@@ -415,7 +487,17 @@ func TestConsumerProviderWithProviders(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"reply": %d}`, i+1))
 	}
 	for i := 0; i < numProviders; i++ {
@@ -434,7 +516,18 @@ func TestConsumerProviderWithProviders(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 	if play.scenario != 1 {
 		counter := map[int]int{}
@@ -524,7 +617,6 @@ func TestConsumerProviderTx(t *testing.T) {
 	specId := "LAV1"
 	apiInterface := spectypes.APIInterfaceRest
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := 5
@@ -550,7 +642,17 @@ func TestConsumerProviderTx(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1))
 	}
 	for i := 0; i < numProviders; i++ {
@@ -569,7 +671,18 @@ func TestConsumerProviderTx(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 
 	for i := 0; i < numProviders; i++ {
@@ -631,7 +744,6 @@ func TestConsumerProviderJsonRpcWithNullID(t *testing.T) {
 	specId := play.specId
 	apiInterface := play.apiInterface
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := 5
@@ -656,7 +768,17 @@ func TestConsumerProviderJsonRpcWithNullID(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1))
 	}
 	for i := 0; i < numProviders; i++ {
@@ -675,7 +797,18 @@ func TestConsumerProviderJsonRpcWithNullID(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 
 	for i := 0; i < numProviders; i++ {
@@ -741,7 +874,6 @@ func TestConsumerProviderSubscriptionsHappyFlow(t *testing.T) {
 	specId := play.specId
 	apiInterface := play.apiInterface
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := 5
@@ -766,7 +898,17 @@ func TestConsumerProviderSubscriptionsHappyFlow(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1))
 	}
 	for i := 0; i < numProviders; i++ {
@@ -785,7 +927,18 @@ func TestConsumerProviderSubscriptionsHappyFlow(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 
 	for i := 0; i < numProviders; i++ {
@@ -855,7 +1008,6 @@ func TestSameProviderConflictBasicResponseCheck(t *testing.T) {
 	specId := "LAV1"
 	apiInterface := spectypes.APIInterfaceRest
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := play.numOfProviders
@@ -883,7 +1035,17 @@ func TestSameProviderConflictBasicResponseCheck(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, providers[i].mockReliabilityManager = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, providers[i].mockReliabilityManager = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1))
 	}
@@ -903,7 +1065,18 @@ func TestSameProviderConflictBasicResponseCheck(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 
 	// Set first provider as a "liar", to return wrong block hashes
@@ -991,7 +1164,6 @@ func TestArchiveProvidersRetry(t *testing.T) {
 	specId := "LAV1"
 	apiInterface := spectypes.APIInterfaceRest
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := play.numOfProviders
@@ -1023,7 +1195,18 @@ func TestArchiveProvidersRetry(t *testing.T) {
 		if i+1 <= play.archiveProviders {
 			addons = []string{"archive"}
 		}
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, providers[i].mockReliabilityManager = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, addons, fmt.Sprintf("provider%d", i))
+
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           addons,
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, providers[i].mockReliabilityManager = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(`{"result": "success"}`)
 		if i+1 <= play.nodeErrorProviders {
 			providers[i].replySetter.replyDataBuf = []byte(`{"error": "failure", "message": "test", "code": "-32132"}`)
@@ -1047,7 +1230,18 @@ func TestArchiveProvidersRetry(t *testing.T) {
 			PairingEpoch:     epoch,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 
 	client := http.Client{Timeout: 1000 * time.Millisecond}
@@ -1095,7 +1289,17 @@ func TestSameProviderConflictReport(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, providers[i].mockReliabilityManager = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, providers[i].mockReliabilityManager = createRpcProvider(t, ctx, rpcProviderOptions)
 		providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1))
 	}
 }
@@ -1129,7 +1333,6 @@ func TestSameProviderConflictReport(t *testing.T) {
 		specId := "LAV1"
 		apiInterface := spectypes.APIInterfaceRest
 		epoch := uint64(100)
-		requiredResponses := 1
 		lavaChainID := "lava"
 
 		numProviders := 1
@@ -1140,8 +1343,17 @@ func TestSameProviderConflictReport(t *testing.T) {
 
 		initProvidersData(consumerAccount, providers, specId, apiInterface, lavaChainID)
 
-		pairingList := initPairingList(providers, epoch)
-		rpcconsumerServer, mockConsumerStateTracker := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+		rpcConsumerOptions := rpcConsumerOptions{
+			specId:                specId,
+			apiInterface:          apiInterface,
+			account:               consumerAccount,
+			consumerListenAddress: consumerListenAddress,
+			epoch:                 epoch,
+			pairingList:           initPairingList(providers, epoch),
+			requiredResponses:     1,
+			lavaChainID:           lavaChainID,
+		}
+		rpcconsumerServer, mockConsumerStateTracker := createRpcConsumer(t, ctx, rpcConsumerOptions)
 
 		conflictSent := false
 		wg := sync.WaitGroup{}
@@ -1202,7 +1414,6 @@ func TestSameProviderConflictReport(t *testing.T) {
 		specId := "LAV1"
 		apiInterface := spectypes.APIInterfaceRest
 		epoch := uint64(100)
-		requiredResponses := 1
 		lavaChainID := "lava"
 
 		numProviders := 2
@@ -1212,8 +1423,17 @@ func TestSameProviderConflictReport(t *testing.T) {
 
 		initProvidersData(consumerAccount, providers, specId, apiInterface, lavaChainID)
 
-		pairingList := initPairingList(providers, epoch)
-		rpcconsumerServer, mockConsumerStateTracker := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+		rpcConsumerOptions := rpcConsumerOptions{
+			specId:                specId,
+			apiInterface:          apiInterface,
+			account:               consumerAccount,
+			consumerListenAddress: consumerListenAddress,
+			epoch:                 epoch,
+			pairingList:           initPairingList(providers, epoch),
+			requiredResponses:     1,
+			lavaChainID:           lavaChainID,
+		}
+		rpcconsumerServer, mockConsumerStateTracker := createRpcConsumer(t, ctx, rpcConsumerOptions)
 
 		twoProvidersConflictSent := false
 		sameProviderConflictSent := false
@@ -1285,7 +1505,6 @@ func TestConsumerProviderStatic(t *testing.T) {
 	specId := "LAV1"
 	apiInterface := spectypes.APIInterfaceTendermintRPC
 	epoch := uint64(100)
-	requiredResponses := 1
 	lavaChainID := "lava"
 
 	numProviders := 1
@@ -1311,7 +1530,17 @@ func TestConsumerProviderStatic(t *testing.T) {
 		ctx := context.Background()
 		providerDataI := providers[i]
 		listenAddress := addressGen.GetAddress()
-		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
+		rpcProviderOptions := rpcProviderOptions{
+			consumerAddress:  consumerAccount.Addr.String(),
+			specId:           specId,
+			apiInterface:     apiInterface,
+			listenAddress:    listenAddress,
+			account:          providerDataI.account,
+			lavaChainID:      lavaChainID,
+			addons:           []string(nil),
+			providerUniqueId: fmt.Sprintf("provider%d", i),
+		}
+		providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, rpcProviderOptions)
 	}
 	// provider is static
 	for i := 0; i < numProviders; i++ {
@@ -1331,7 +1560,18 @@ func TestConsumerProviderStatic(t *testing.T) {
 			StaticProvider:   true,
 		}
 	}
-	rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
 	require.NotNil(t, rpcconsumerServer)
 	client := http.Client{}
 	// consumer sends the relay to a provider with an address BANANA+%d so the provider needs to skip validations for this to work
@@ -1351,3 +1591,359 @@ func TestConsumerProviderStatic(t *testing.T) {
 	require.Equal(t, providers[0].replySetter.replyDataBuf, bodyBytes)
 	resp.Body.Close()
 }
+
+func jsonRpcIdToInt(t *testing.T, rawID json.RawMessage) int {
+	var idInterface interface{}
+	err := json.Unmarshal(rawID, &idInterface)
+	require.NoError(t, err)
+
+	id, ok := idInterface.(float64)
+	require.True(t, ok, idInterface)
+	return int(id)
+}
+
+func TestConsumerProviderWithConsumerSideCache(t *testing.T) {
+	ctx := context.Background()
+	// can be any spec and api interface
+	specId := "LAV1"
+	apiInterface := spectypes.APIInterfaceTendermintRPC
+	epoch := uint64(100)
+	lavaChainID := "lava"
+
+	consumerListenAddress := addressGen.GetAddress()
+	cacheListenAddress := addressGen.GetAddress()
+	pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{}
+	type providerData struct {
+		account     sigs.Account
+		endpoint    *lavasession.RPCProviderEndpoint
+		replySetter *ReplySetter
+	}
+
+	consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	providerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	provider := providerData{account: providerAccount}
+	listenAddress := addressGen.GetAddress()
+	rpcProviderOptions := rpcProviderOptions{
+		consumerAddress:  consumerAccount.Addr.String(),
+		specId:           specId,
+		apiInterface:     apiInterface,
+		listenAddress:    listenAddress,
+		account:          provider.account,
+		lavaChainID:      lavaChainID,
+		addons:           []string(nil),
+		providerUniqueId: "provider",
+	}
+	_, provider.endpoint, provider.replySetter, _, _ = createRpcProvider(t, ctx, rpcProviderOptions)
+	provider.replySetter.handler = func(req []byte, header http.Header) (data []byte, status int) {
+		var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage
+		err := json.Unmarshal(req, &jsonRpcMessage)
+		require.NoError(t, err, req)
+
+		response := fmt.Sprintf(`{"jsonrpc":"2.0","result": {}, "id": %v}`, string(jsonRpcMessage.ID))
+		return []byte(response), http.StatusOK
+	}
+
+	pairingList[0] = &lavasession.ConsumerSessionsWithProvider{
+		PublicLavaAddress: provider.account.Addr.String(),
+		Endpoints: []*lavasession.Endpoint{
+			{
+				NetworkAddress: provider.endpoint.NetworkAddress.Address,
+				Enabled:        true,
+				Geolocation:    1,
+			},
+		},
+		Sessions:         map[int64]*lavasession.SingleConsumerSession{},
+		MaxComputeUnits:  10000,
+		UsedComputeUnits: 0,
+		PairingEpoch:     epoch,
+	}
+
+	createCacheServer(t, ctx, cacheListenAddress)
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+		cacheListenAddress:    cacheListenAddress,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
+	require.NotNil(t, rpcconsumerServer)
+
+	client := http.Client{}
+	id := 0
+	sendMessage := func(method string, params []string) http.Header {
+		// Build and send a JSON-RPC request, returning the response headers
+		body := fmt.Sprintf(`{"jsonrpc":"2.0","method":"%v","params": [%v], "id":%v}`, method, strings.Join(params, ","), id)
+		resp, err := client.Post("http://"+consumerListenAddress, "application/json", bytes.NewBuffer([]byte(body)))
+		require.NoError(t, err)
+		require.Equal(t, http.StatusOK, resp.StatusCode)
+
+		bodyBytes, err := io.ReadAll(resp.Body)
+		require.NoError(t, err)
+
+		var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage
+		err = json.Unmarshal(bodyBytes, &jsonRpcMessage)
+		require.NoError(t, err)
+
+		respId := jsonRpcIdToInt(t, jsonRpcMessage.ID)
+		require.Equal(t, id, respId)
+		resp.Body.Close()
+		id++
+
+		return resp.Header
+	}
+
+	// Get latest for sanity check
+	providerAddr := provider.account.Addr.String()
+	headers := sendMessage("status", []string{})
+	require.Equal(t, providerAddr, headers.Get(common.PROVIDER_ADDRESS_HEADER_NAME))
+
+	// Get block, this should be cached for next time
+	headers = sendMessage("block", []string{"1000"})
+	require.Equal(t, providerAddr, headers.Get(common.PROVIDER_ADDRESS_HEADER_NAME))
+
+	// Get block again, this time it should be from cache
+	headers = sendMessage("block", []string{"1000"})
+	require.Equal(t, "Cached", headers.Get(common.PROVIDER_ADDRESS_HEADER_NAME))
+}
+
+func TestConsumerProviderWithProviderSideCache(t *testing.T) {
+	ctx := context.Background()
+	// can be any spec and api interface
+	specId := "LAV1"
+	apiInterface := spectypes.APIInterfaceTendermintRPC
+	epoch := uint64(100)
+	lavaChainID := "lava"
+
+	consumerListenAddress := addressGen.GetAddress()
+	cacheListenAddress := addressGen.GetAddress()
+	pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{}
+	type providerData struct {
+		account     sigs.Account
+		endpoint    *lavasession.RPCProviderEndpoint
+		replySetter *ReplySetter
+	}
+
+	consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	providerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	provider := providerData{account: providerAccount}
+	providerListenAddress := addressGen.GetAddress()
+
+	createCacheServer(t, ctx, cacheListenAddress)
+	testJsonRpcId := 42
+	nodeRequestsCounter := 0
+	rpcProviderOptions := rpcProviderOptions{
+		consumerAddress:    consumerAccount.Addr.String(),
+		specId:             specId,
+		apiInterface:       apiInterface,
+		listenAddress:      providerListenAddress,
+		account:            provider.account,
+		lavaChainID:        lavaChainID,
+		addons:             []string(nil),
+		providerUniqueId:   "provider",
+		cacheListenAddress: cacheListenAddress,
+	}
+	_, provider.endpoint, provider.replySetter, _, _ = createRpcProvider(t, ctx, rpcProviderOptions)
+	provider.replySetter.handler = func(req []byte, header http.Header) (data []byte, status int) {
+		var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage
+		err := json.Unmarshal(req, &jsonRpcMessage)
+		require.NoError(t, err, req)
+
+		reqId := jsonRpcIdToInt(t, jsonRpcMessage.ID)
+		if reqId == testJsonRpcId {
+			nodeRequestsCounter++
+		}
+
+		response := fmt.Sprintf(`{"jsonrpc":"2.0","result": {}, "id": %v}`, string(jsonRpcMessage.ID))
+		return []byte(response), http.StatusOK
+	}
+
+	pairingList[0] = &lavasession.ConsumerSessionsWithProvider{
+		PublicLavaAddress: provider.account.Addr.String(),
+		Endpoints: []*lavasession.Endpoint{
+			{
+				NetworkAddress: provider.endpoint.NetworkAddress.Address,
+				Enabled:        true,
+				Geolocation:    1,
+			},
+		},
+		Sessions:         map[int64]*lavasession.SingleConsumerSession{},
+		MaxComputeUnits:  10000,
+		UsedComputeUnits: 0,
+		PairingEpoch:     epoch,
+	}
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
+	require.NotNil(t, rpcconsumerServer)
+
+	client := http.Client{}
+	sendMessage := func(method string, params []string) http.Header {
+		body := fmt.Sprintf(`{"jsonrpc":"2.0","method":"%v","params": [%v], "id":%v}`, method, strings.Join(params, ","), testJsonRpcId)
+		resp, err := client.Post("http://"+consumerListenAddress, "application/json", bytes.NewBuffer([]byte(body)))
+		require.NoError(t, err)
+		require.Equal(t, http.StatusOK, resp.StatusCode)
+
+		bodyBytes, err := io.ReadAll(resp.Body)
+		require.NoError(t, err)
+
+		var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage
+		err = json.Unmarshal(bodyBytes, &jsonRpcMessage)
+		require.NoError(t, err)
+
+		respId := jsonRpcIdToInt(t, jsonRpcMessage.ID)
+		require.Equal(t, testJsonRpcId, respId)
+		resp.Body.Close()
+
+		return resp.Header
+	}
+
+	// Get latest for sanity check
+	sendMessage("status", []string{})
+
+	for i := 0; i < 5; i++ {
+		// Get block
+		sendMessage("block", []string{"1000"})
+	}
+
+	// Verify the node was called only twice: once for the sanity check and
+	// once for the first block request; the rest were served from the cache
+	require.Equal(t, 2, nodeRequestsCounter)
+}
+
+func TestConsumerProviderWithConsumerAndProviderSideCache(t *testing.T) {
+	ctx := context.Background()
+	// can be any spec and api interface
+	specId := "LAV1"
+	apiInterface := spectypes.APIInterfaceTendermintRPC
+	epoch := uint64(100)
+	lavaChainID := "lava"
+
+	consumerListenAddress := addressGen.GetAddress()
+	consumerCacheListenAddress := addressGen.GetAddress()
+	providerCacheListenAddress := addressGen.GetAddress()
+
+	createCacheServer(t, ctx, consumerCacheListenAddress)
+	createCacheServer(t, ctx, providerCacheListenAddress)
+
+	pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{}
+	type providerData struct {
+		account     sigs.Account
+		endpoint    *lavasession.RPCProviderEndpoint
+		replySetter *ReplySetter
+	}
+
+	consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	providerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
+	provider := providerData{account: providerAccount}
+	providerListenAddress := addressGen.GetAddress()
+
+	testJsonRpcId := 42
+	nodeRequestsCounter := 0
+	rpcProviderOptions := rpcProviderOptions{
+		consumerAddress:    consumerAccount.Addr.String(),
+		specId:             specId,
+		apiInterface:       apiInterface,
+		listenAddress:      providerListenAddress,
+		account:            provider.account,
+		lavaChainID:        lavaChainID,
+		addons:             []string(nil),
+		providerUniqueId:   "provider",
+		cacheListenAddress: providerCacheListenAddress,
+	}
+	_, provider.endpoint, provider.replySetter, _, _ = createRpcProvider(t, ctx, rpcProviderOptions)
+	provider.replySetter.handler = func(req []byte, header http.Header) (data []byte, status int) {
+		var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage
+		err := json.Unmarshal(req, &jsonRpcMessage)
+		require.NoError(t, err, req)
+
+		reqId := jsonRpcIdToInt(t, jsonRpcMessage.ID)
+		if reqId == testJsonRpcId {
+			nodeRequestsCounter++
+		}
+
+		response := fmt.Sprintf(`{"jsonrpc":"2.0","result": {}, "id": %v}`, string(jsonRpcMessage.ID))
+		return []byte(response), http.StatusOK
+	}
+
+	pairingList[0] = &lavasession.ConsumerSessionsWithProvider{
+		PublicLavaAddress: provider.account.Addr.String(),
+		Endpoints: []*lavasession.Endpoint{
+			{
+				NetworkAddress: provider.endpoint.NetworkAddress.Address,
+				Enabled:        true,
+				Geolocation:    1,
+			},
+		},
+		Sessions:         map[int64]*lavasession.SingleConsumerSession{},
+		MaxComputeUnits:  10000,
+		UsedComputeUnits: 0,
+		PairingEpoch:     epoch,
+	}
+
+	rpcConsumerOptions := rpcConsumerOptions{
+		specId:                specId,
+		apiInterface:          apiInterface,
+		account:               consumerAccount,
+		consumerListenAddress: consumerListenAddress,
+		epoch:                 epoch,
+		pairingList:           pairingList,
+		requiredResponses:     1,
+		lavaChainID:           lavaChainID,
+		cacheListenAddress:    consumerCacheListenAddress,
+	}
+	rpcconsumerServer, _ := createRpcConsumer(t, ctx, rpcConsumerOptions)
+	require.NotNil(t, rpcconsumerServer)
+
+	client := http.Client{}
+	sendMessage := func(method string, params []string) http.Header {
+		body := fmt.Sprintf(`{"jsonrpc":"2.0","method":"%v","params": [%v], "id":%v}`, method, strings.Join(params, ","), testJsonRpcId)
+		resp, err := client.Post("http://"+consumerListenAddress, "application/json", bytes.NewBuffer([]byte(body)))
+		require.NoError(t, err)
+		require.Equal(t, http.StatusOK, resp.StatusCode)
+
+		bodyBytes, err := io.ReadAll(resp.Body)
+		require.NoError(t, err)
+
+		var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage
+		err = json.Unmarshal(bodyBytes, &jsonRpcMessage)
+		require.NoError(t, err)
+
+		respId := jsonRpcIdToInt(t, jsonRpcMessage.ID)
+		require.Equal(t, testJsonRpcId, respId)
+		resp.Body.Close()
+
+		return resp.Header
+	}
+
+	providerAddr := provider.account.Addr.String()
+	// Get latest for sanity check
+	headers := sendMessage("status", []string{})
+	require.Equal(t, providerAddr, headers.Get(common.PROVIDER_ADDRESS_HEADER_NAME))
+
+	// Get block, this should be cached for next time
+	headers = sendMessage("block", []string{"1000"})
+	require.Equal(t, providerAddr, headers.Get(common.PROVIDER_ADDRESS_HEADER_NAME))
+
+	for i := 0; i < 5; i++ {
+		// Get block again, this time it should be from cache
+		headers = sendMessage("block", []string{"1000"})
+		require.Equal(t, "Cached", headers.Get(common.PROVIDER_ADDRESS_HEADER_NAME))
+	}
+
+	// Verify the node was called only twice
+	require.Equal(t, 2, nodeRequestsCounter)
+}
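The three new tests exercise the cache on the consumer side, the provider side, and both at once. With the patch applied they can be run in isolation; a typical invocation from the repository root, assuming the standard Go toolchain, is something like:

	go test ./protocol/integration/ -v -run 'TestConsumerProviderWith.*Cache'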
From ab51b7307a891467b04a9870a2eae680d76af990 Mon Sep 17 00:00:00 2001
From: Ran Mishael <106548467+ranlavanet@users.noreply.github.com>
Date: Wed, 25 Sep 2024 11:40:21 +0200
Subject: [PATCH 2/6] feat: PRT - adding stateful api header for hanging relays
 (#1712)

---
 protocol/common/endpoints.go               | 1 +
 protocol/rpcconsumer/rpcconsumer_server.go | 9 +++++++++
 2 files changed, 10 insertions(+)

diff --git a/protocol/common/endpoints.go b/protocol/common/endpoints.go
index 99cd6504ac..435fcde26c 100644
--- a/protocol/common/endpoints.go
+++ b/protocol/common/endpoints.go
@@ -28,6 +28,7 @@ const (
 	NODE_ERRORS_PROVIDERS_HEADER_NAME = "Lava-Node-Errors-providers"
 	REPORTED_PROVIDERS_HEADER_NAME    = "Lava-Reported-Providers"
 	USER_REQUEST_TYPE                 = "lava-user-request-type"
+	STATEFUL_API_HEADER               = "lava-stateful-api"
 	LAVA_IDENTIFIED_NODE_ERROR_HEADER = "lava-identified-node-error"
 	LAVAP_VERSION_HEADER_NAME         = "Lavap-Version"
 	LAVA_CONSUMER_PROCESS_GUID        = "lava-consumer-process-guid"
diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go
index d74a4fd141..af11a2d952 100644
--- a/protocol/rpcconsumer/rpcconsumer_server.go
+++ b/protocol/rpcconsumer/rpcconsumer_server.go
@@ -1306,6 +1306,15 @@ func (rpccs *RPCConsumerServer) appendHeadersToRelayResult(ctx context.Context,
 		})
 	}
 
+	// add stateful API (hanging, transactions)
+	if protocolMessage.GetApi().Category.Stateful == common.CONSISTENCY_SELECT_ALL_PROVIDERS {
+		metadataReply = append(metadataReply,
+			pairingtypes.Metadata{
+				Name:  common.STATEFUL_API_HEADER,
+				Value: "true",
+			})
+	}
+
 	// add user requested API
 	metadataReply = append(metadataReply,
 		pairingtypes.Metadata{
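Relay replies for stateful API categories (those whose Stateful field equals CONSISTENCY_SELECT_ALL_PROVIDERS, i.e. transactions and other hanging relays that go to all providers) now carry the "lava-stateful-api" header with the value "true". A minimal sketch of how a client in front of the consumer endpoint could use it; the consumer address, REST path, and request body are placeholders for this example:

	// Sketch: detect whether the consumer treated a request as a stateful
	// relay. The endpoint and txJSON payload here are assumptions.
	func isStatefulRelay(consumerAddr, txJSON string) (bool, error) {
		resp, err := http.Post("http://"+consumerAddr+"/cosmos/tx/v1beta1/txs", "application/json", strings.NewReader(txJSON))
		if err != nil {
			return false, err
		}
		defer resp.Body.Close()
		return resp.Header.Get("lava-stateful-api") == "true", nil
	}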
From a5206e752de5fb558943caf12f7840c4e0236e3e Mon Sep 17 00:00:00 2001
From: oren-lava <111131399+oren-lava@users.noreply.github.com>
Date: Wed, 25 Sep 2024 14:27:33 +0300
Subject: [PATCH 3/6] fix: audit fixes (#1672)

* nilaway fixes
* added warning comment on detection index
* remove redundant conflict data checks
* reverted qos req and geo req turning pointers
* removed pairing query cache
* fix comment
---
 x/conflict/keeper/conflict.go             | 12 +++++++
 x/conflict/keeper/msg_server_detection.go |  3 ++
 x/conflict/types/conflict.go              | 24 ++++++++++++++
 x/pairing/keeper/keeper.go                |  8 -----
 x/pairing/keeper/pairing.go               | 30 +++--------------
 x/pairing/keeper/pairing_cache.go         | 24 --------------
 x/pairing/keeper/pairing_cache_test.go    | 39 +----------------------
 x/pairing/keeper/scores/pairing_slot.go   |  2 +-
 x/pairing/keeper/scores/stake_req.go      |  3 ++
 x/projects/keeper/creation.go             |  5 ++-
 x/projects/types/project.go               |  3 ++
 x/spec/keeper/spec.go                     |  6 ++++
 x/spec/types/api_collection.go            |  3 ++
 x/spec/types/combinable.go                |  2 +-
 x/spec/types/spec.go                      |  5 ++-
 x/subscription/keeper/subscription.go     | 10 ++++++
 16 files changed, 78 insertions(+), 101 deletions(-)
 create mode 100644 x/conflict/types/conflict.go

diff --git a/x/conflict/keeper/conflict.go b/x/conflict/keeper/conflict.go
index a3e6e31a1a..b1f5c19ae9 100644
--- a/x/conflict/keeper/conflict.go
+++ b/x/conflict/keeper/conflict.go
@@ -20,6 +20,11 @@ func (k Keeper) ValidateFinalizationConflict(ctx sdk.Context, conflictData *type
 }
 
 func (k Keeper) ValidateResponseConflict(ctx sdk.Context, conflictData *types.ResponseConflict, clientAddr sdk.AccAddress) error {
+	// 0. validate conflictData is not nil
+	if conflictData.IsDataNil() {
+		return fmt.Errorf("ValidateResponseConflict: conflict data is nil")
+	}
+
 	// 1. validate mismatching data
 	chainID := conflictData.ConflictRelayData0.Request.RelaySession.SpecId
 	if chainID != conflictData.ConflictRelayData1.Request.RelaySession.SpecId {
@@ -279,6 +284,10 @@ func (k Keeper) ValidateSameProviderConflict(ctx sdk.Context, conflictData *type
 func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization, spec *spectypes.Spec) (finalizedBlocksMarshalled map[int64]string, earliestFinalizedBlock int64, latestFinalizedBlock int64, err error) {
 	EMPTY_MAP := map[int64]string{}
 
+	// verify spec is not nil
+	if spec == nil {
+		return EMPTY_MAP, 0, 0, fmt.Errorf("validateBlockHeights: spec is nil")
+	}
 
 	// Unmarshall finalized blocks
 	finalizedBlocks := map[int64]string{}
@@ -312,6 +321,9 @@ func (k Keeper) validateBlockHeights(relayFinalization *types.RelayFinalization,
 }
 
 func (k Keeper) validateFinalizedBlock(relayFinalization *types.RelayFinalization, latestFinalizedBlock int64, spec *spectypes.Spec) error {
+	if spec == nil {
+		return fmt.Errorf("validateFinalizedBlock: spec is nil")
+	}
 	latestBlock := relayFinalization.GetLatestBlock()
 	blockDistanceToFinalization := int64(spec.BlockDistanceForFinalizedData)
 
diff --git a/x/conflict/keeper/msg_server_detection.go b/x/conflict/keeper/msg_server_detection.go
index e9d262062d..5a2dbd3727 100644
--- a/x/conflict/keeper/msg_server_detection.go
+++ b/x/conflict/keeper/msg_server_detection.go
@@ -14,6 +14,9 @@ import (
 	"golang.org/x/exp/slices"
 )
 
+// DetectionIndex creates an index for detection instances.
+// WARNING: the detection index should not be used for prefixed iteration since it doesn't contain delimiters;
+// thus it's not sanitized for such iterations and could cause issues in the future as the codebase evolves.
 func DetectionIndex(creatorAddr string, conflict *types.ResponseConflict, epochStart uint64) string {
 	return creatorAddr + conflict.ConflictRelayData0.Request.RelaySession.Provider + conflict.ConflictRelayData1.Request.RelaySession.Provider + strconv.FormatUint(epochStart, 10)
 }
diff --git a/x/conflict/types/conflict.go b/x/conflict/types/conflict.go
new file mode 100644
index 0000000000..3b9a7c72bb
--- /dev/null
+++ b/x/conflict/types/conflict.go
@@ -0,0 +1,24 @@
+package types
+
+func (c *ResponseConflict) IsDataNil() bool {
+	if c == nil {
+		return true
+	}
+	if c.ConflictRelayData0 == nil || c.ConflictRelayData1 == nil {
+		return true
+	}
+	if c.ConflictRelayData0.Request == nil || c.ConflictRelayData1.Request == nil {
+		return true
+	}
+	if c.ConflictRelayData0.Request.RelayData == nil || c.ConflictRelayData1.Request.RelayData == nil {
+		return true
+	}
+	if c.ConflictRelayData0.Request.RelaySession == nil || c.ConflictRelayData1.Request.RelaySession == nil {
+		return true
+	}
+	if c.ConflictRelayData0.Reply == nil || c.ConflictRelayData1.Reply == nil {
+		return true
+	}
+
+	return false
+}
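IsDataNil walks the nested pointers of a ResponseConflict once, up front, so the validation code that follows can dereference them freely; this replaces the redundant per-field checks mentioned in the commit message. A small illustrative test of the guard, assuming this repo's usual testify setup (the test name and placement are hypothetical):

	// Sketch: the guard catches both a nil receiver and nil nested fields.
	func TestIsDataNil(t *testing.T) {
		var nilConflict *types.ResponseConflict
		require.True(t, nilConflict.IsDataNil())                  // nil receiver
		require.True(t, (&types.ResponseConflict{}).IsDataNil()) // empty struct, nested pointers nil
	}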
k.getPairingForClient(ctx, chainID, uint64(ctx.BlockHeight()), policy, cluster, "dummy", true) if err != nil { return cosmosmath.LegacyZeroDec(), err } @@ -117,22 +117,12 @@ func (k Keeper) CalculatePairingChance(ctx sdk.Context, provider string, chainID // function used to get a new pairing from provider and client // first argument has all metadata, second argument is only the addresses -// useCache is a boolean argument that is used to determine whether pairing cache should be used -// Note: useCache should only be true for queries! functions that write to the state and use this function should never put useCache=true -func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool, useCache bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) { +func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint64, policy *planstypes.Policy, cluster string, projectIndex string, calcChance bool) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, providerScores []*pairingscores.PairingScore, errorRet error) { epoch, providersType, err := k.VerifyPairingData(ctx, chainID, block) if err != nil { return nil, 0, nil, fmt.Errorf("invalid pairing data: %s", err) } - // to be used only in queries as this changes gas calculations, and therefore must not be part of consensus - if useCache { - providers, found := k.GetPairingQueryCache(projectIndex, chainID, epoch) - if found { - return providers, policy.EpochCuLimit, nil, nil - } - } - stakeEntries := k.epochStorageKeeper.GetAllStakeEntriesForEpochChainId(ctx, epoch, chainID) if len(stakeEntries) == 0 { return nil, 0, nil, fmt.Errorf("did not find providers for pairing: epoch:%d, chainID: %s", block, chainID) @@ -149,9 +139,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6 stakeEntriesFiltered = append(stakeEntriesFiltered, stakeEntries[i]) } } - if useCache { - k.SetPairingQueryCache(projectIndex, chainID, epoch, stakeEntriesFiltered) - } return stakeEntriesFiltered, policy.EpochCuLimit, nil, nil } @@ -171,9 +158,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6 for _, score := range providerScores { filteredEntries = append(filteredEntries, *score.Provider) } - if useCache { - k.SetPairingQueryCache(projectIndex, chainID, epoch, filteredEntries) - } return filteredEntries, policy.EpochCuLimit, nil, nil } @@ -194,10 +178,6 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, block uint6 prevGroupSlot = group } - if useCache { - k.SetPairingQueryCache(projectIndex, chainID, epoch, providers) - } - return providers, policy.EpochCuLimit, providerScores, nil } @@ -350,7 +330,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid return false, allowedCU, []epochstoragetypes.StakeEntry{}, fmt.Errorf("invalid user for pairing: %s", err.Error()) } - validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false, false) + validAddresses, allowedCU, _, err = k.getPairingForClient(ctx, chainID, epoch, strictestPolicy, cluster, project.Index, false) if err != nil { return false, allowedCU, []epochstoragetypes.StakeEntry{}, err } @@ -363,7 +343,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, provid utils.LavaFormatPanic("critical: invalid provider 
address for payment", err, utils.Attribute{Key: "chainID", Value: chainID}, utils.Attribute{Key: "client", Value: project.Subscription}, - utils.Attribute{Key: "provider", Value: providerAccAddr.String()}, + utils.Attribute{Key: "provider", Value: possibleAddr.Address}, utils.Attribute{Key: "epochBlock", Value: strconv.FormatUint(epoch, 10)}, ) } diff --git a/x/pairing/keeper/pairing_cache.go b/x/pairing/keeper/pairing_cache.go index 5cad220cd6..a3652babc8 100644 --- a/x/pairing/keeper/pairing_cache.go +++ b/x/pairing/keeper/pairing_cache.go @@ -38,27 +38,3 @@ func (k Keeper) ResetPairingRelayCache(ctx sdk.Context) { store.Delete(iterator.Key()) } } - -// the cache used for the query, does not write into state -func (k Keeper) SetPairingQueryCache(project string, chainID string, epoch uint64, pairedProviders []epochstoragetypes.StakeEntry) { - if k.pairingQueryCache == nil { - // pairing cache is not initialized, will be in next epoch so simply skip - return - } - key := types.NewPairingCacheKey(project, chainID, epoch) - - (*k.pairingQueryCache)[key] = pairedProviders -} - -func (k Keeper) GetPairingQueryCache(project string, chainID string, epoch uint64) ([]epochstoragetypes.StakeEntry, bool) { - if k.pairingQueryCache == nil { - // pairing cache is not initialized, will be in next epoch so simply skip - return nil, false - } - key := types.NewPairingCacheKey(project, chainID, epoch) - if providers, ok := (*k.pairingQueryCache)[key]; ok { - return providers, true - } - - return nil, false -} diff --git a/x/pairing/keeper/pairing_cache_test.go b/x/pairing/keeper/pairing_cache_test.go index 31b6669192..fb1f0829ef 100644 --- a/x/pairing/keeper/pairing_cache_test.go +++ b/x/pairing/keeper/pairing_cache_test.go @@ -7,44 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -// TestPairingQueryCache tests the following: -// 1. The pairing query cache is reset every epoch -// 2. Getting pairing with a query using an existent cache entry consumes fewer gas than without one -func TestPairingQueryCache(t *testing.T) { - ts := newTester(t) - ts.setupForPayments(1, 1, 0) // 1 provider, 1 client, default providers-to-pair - - _, consumer := ts.GetAccount(common.CONSUMER, 0) - - getPairingGas := func(ts *tester) uint64 { - gm := ts.Ctx.GasMeter() - before := gm.GasConsumed() - _, err := ts.QueryPairingGetPairing(ts.spec.Index, consumer) - require.NoError(t, err) - return gm.GasConsumed() - before - } - - // query for pairing for the first time - empty cache - emptyCacheGas := getPairingGas(ts) - - // query for pairing for the second time - non-empty cache - filledCacheGas := getPairingGas(ts) - - // second time gas should be smaller than first time - require.Less(t, filledCacheGas, emptyCacheGas) - - // advance block to test it stays the same (should still be less than empty cache gas) - ts.AdvanceBlock() - filledAfterBlockCacheGas := getPairingGas(ts) - require.Less(t, filledAfterBlockCacheGas, emptyCacheGas) - - // advance epoch to reset the cache - ts.AdvanceEpoch() - emptyCacheAgainGas := getPairingGas(ts) - require.Equal(t, emptyCacheGas, emptyCacheAgainGas) -} - -// TestPairingQueryCache tests the following: +// TestPairingRelayCache tests the following: // 1. The pairing relay cache is reset every block // 2. 
Getting pairing in relay payment using an existent cache entry consumes fewer gas than without one func TestPairingRelayCache(t *testing.T) { diff --git a/x/pairing/keeper/scores/pairing_slot.go b/x/pairing/keeper/scores/pairing_slot.go index 5d6935f56e..fe13eaff63 100644 --- a/x/pairing/keeper/scores/pairing_slot.go +++ b/x/pairing/keeper/scores/pairing_slot.go @@ -36,7 +36,7 @@ func (psg PairingSlotGroup) Subtract(other *PairingSlotGroup) *PairingSlot { otherReq, found := other.Reqs[key] if !found { reqsDiff[key] = req - } else if !req.Equal(otherReq) { + } else if req != nil && !req.Equal(otherReq) { reqsDiff[key] = req } } diff --git a/x/pairing/keeper/scores/stake_req.go b/x/pairing/keeper/scores/stake_req.go index 44eb3ff5c0..74af1713fb 100644 --- a/x/pairing/keeper/scores/stake_req.go +++ b/x/pairing/keeper/scores/stake_req.go @@ -17,6 +17,9 @@ func (sr *StakeReq) Init(policy planstypes.Policy) bool { // Score calculates the the provider score as the normalized stake func (sr *StakeReq) Score(score PairingScore) math.Uint { + if sr == nil { + return math.OneUint() + } effectiveStake := score.Provider.TotalStake() if !effectiveStake.IsPositive() { return math.OneUint() diff --git a/x/projects/keeper/creation.go b/x/projects/keeper/creation.go index 423aed7227..a3f1c75741 100644 --- a/x/projects/keeper/creation.go +++ b/x/projects/keeper/creation.go @@ -191,7 +191,7 @@ func (k Keeper) registerKey(ctx sdk.Context, key types.ProjectKey, project *type // check that the developer key is valid, and that it does not already // belong to a different project. - if found && devkeyData.ProjectID != project.Index { + if found && devkeyData.ProjectID != project.GetIndex() { return utils.LavaFormatWarning("failed to register key", fmt.Errorf("key already exists"), utils.Attribute{Key: "key", Value: key.Key}, @@ -254,10 +254,9 @@ func (k Keeper) unregisterKey(ctx sdk.Context, key types.ProjectKey, project *ty // the developer key belongs to a different project if devkeyData.ProjectID != project.GetIndex() { return utils.LavaFormatWarning("failed to unregister key", legacyerrors.ErrNotFound, - utils.Attribute{Key: "projectID", Value: project.Index}, + utils.Attribute{Key: "projectID", Value: project.GetIndex()}, utils.Attribute{Key: "key", Value: key.Key}, utils.Attribute{Key: "keyTypes", Value: key.Kinds}, - utils.Attribute{Key: "projectID", Value: project.GetIndex()}, utils.Attribute{Key: "otherID", Value: devkeyData.ProjectID}, ) } diff --git a/x/projects/types/project.go b/x/projects/types/project.go index d1f7f5e3a2..9289e67257 100644 --- a/x/projects/types/project.go +++ b/x/projects/types/project.go @@ -82,6 +82,9 @@ func (project *Project) GetKey(key string) ProjectKey { } func (project *Project) AppendKey(key ProjectKey) bool { + if project == nil { + return false + } for i, projectKey := range project.ProjectKeys { if projectKey.Key == key.Key { project.ProjectKeys[i].Kinds |= key.Kinds diff --git a/x/spec/keeper/spec.go b/x/spec/keeper/spec.go index 59c0a5e6ce..b52e77fe1a 100644 --- a/x/spec/keeper/spec.go +++ b/x/spec/keeper/spec.go @@ -107,6 +107,9 @@ func (k Keeper) RefreshSpec(ctx sdk.Context, spec types.Spec, ancestors []types. } if details, err := spec.ValidateSpec(k.MaxCU(ctx)); err != nil { + if details == nil { + details = map[string]string{} + } details["invalidates"] = spec.Index attrs := utils.StringMapToAttributes(details) return nil, utils.LavaFormatWarning("spec refresh failed (invalidate)", err, attrs...)
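For context on the guard in the RefreshSpec hunk above: in Go, writing into a nil map panics at runtime, so details must be re-initialized when ValidateSpec returns it as nil, before the details["invalidates"] write that follows. A minimal, self-contained sketch of the pattern (illustrative only, not part of the patch):

package main

import "fmt"

func main() {
	// A declared-but-unassigned map is nil; any write to it panics.
	var details map[string]string
	if details == nil {
		details = map[string]string{} // initialize before the first write
	}
	details["invalidates"] = "LAV1" // safe: the map is guaranteed non-nil here
	fmt.Println(details)
}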
@@ -137,6 +140,9 @@ func (k Keeper) doExpandSpec( inherit *map[string]bool, details string, ) (string, error) { + if spec == nil { + return "", fmt.Errorf("doExpandSpec: spec is nil") + } parentsCollections := map[types.CollectionData][]*types.ApiCollection{} if len(spec.Imports) != 0 { diff --git a/x/spec/types/api_collection.go b/x/spec/types/api_collection.go index 06561b16bf..7df10e64ee 100644 --- a/x/spec/types/api_collection.go +++ b/x/spec/types/api_collection.go @@ -65,6 +65,9 @@ func (apic *ApiCollection) InheritAllFields(myCollections map[CollectionData]*Ap // changes in place inside the apic // nil merge maps means not to combine that field func (apic *ApiCollection) CombineWithOthers(others []*ApiCollection, combineWithDisabled, allowOverwrite bool) (err error) { + if apic == nil { + return fmt.Errorf("CombineWithOthers: API collection is nil") + } mergedApis := map[string]interface{}{} mergedHeaders := map[string]interface{}{} mergedParsers := map[string]interface{}{} diff --git a/x/spec/types/combinable.go b/x/spec/types/combinable.go index f0c9b781a5..2ee3c5a7b8 100644 --- a/x/spec/types/combinable.go +++ b/x/spec/types/combinable.go @@ -141,7 +141,7 @@ func CombineUnique[T Combinable](appendFrom, appendTo []T, currentMap map[string } else { // overwriting the inherited field might need Overwrite actions if overwritten, isOverwritten := current.currentCombinable.Overwrite(combinable); isOverwritten { - if appendTo[current.index].Differeniator() != combinable.Differeniator() { + if len(appendTo) <= current.index || appendTo[current.index].Differeniator() != combinable.Differeniator() { return nil, fmt.Errorf("differentiator mismatch in overwrite %s vs %s", combinable.Differeniator(), appendTo[current.index].Differeniator()) } overwrittenT, ok := overwritten.(T) diff --git a/x/spec/types/spec.go b/x/spec/types/spec.go index 237feb9bf4..f41ee8fda0 100644 --- a/x/spec/types/spec.go +++ b/x/spec/types/spec.go @@ -200,6 +200,9 @@ func (spec Spec) ValidateSpec(maxCU uint64) (map[string]string, error) { } func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*ApiCollection) error { + if spec == nil { + return fmt.Errorf("CombineCollections: spec is nil") + } collectionDataList := make([]CollectionData, 0) // Populate the keys slice with the map keys for key := range parentsCollections { @@ -225,7 +228,7 @@ func (spec *Spec) CombineCollections(parentsCollections map[CollectionData][]*Ap break } } - if !combined.Enabled { + if combined == nil || !combined.Enabled { // no collections enabled to combine, we skip this continue } diff --git a/x/subscription/keeper/subscription.go b/x/subscription/keeper/subscription.go index 7a8a04d3f9..7cde34e7ed 100644 --- a/x/subscription/keeper/subscription.go +++ b/x/subscription/keeper/subscription.go @@ -192,6 +192,9 @@ func (k Keeper) verifySubscriptionBuyInputAndGetPlan(ctx sdk.Context, block uint func (k Keeper) createNewSubscription(ctx sdk.Context, plan *planstypes.Plan, creator, consumer string, block uint64, autoRenewalFlag bool, ) (types.Subscription, error) { + if plan == nil { + return types.Subscription{}, utils.LavaFormatError("plan is nil", fmt.Errorf("createNewSubscription: cannot create new subscription")) + } autoRenewalNextPlan := types.AUTO_RENEWAL_PLAN_NONE if autoRenewalFlag { // On subscription creation, auto renewal is set to the subscription's plan @@ -223,6 +226,13 @@ func (k Keeper) createNewSubscription(ctx sdk.Context, plan *planstypes.Plan, cr } func (k Keeper) upgradeSubscriptionPlan(ctx 
sdk.Context, sub *types.Subscription, newPlan *planstypes.Plan) error { + if newPlan == nil { + return utils.LavaFormatError("new plan is nil", fmt.Errorf("upgradeSubscriptionPlan: cannot upgrade subscription plan")) + } + if sub == nil { + return utils.LavaFormatError("subscription is nil", fmt.Errorf("upgradeSubscriptionPlan: cannot upgrade subscription plan")) + } + block := uint64(ctx.BlockHeight()) nextEpoch, err := k.epochstorageKeeper.GetNextEpoch(ctx, block) From bdb5f045a83dba4f579322a272b687ebd07f5afb Mon Sep 17 00:00:00 2001 From: Leon Magma Date: Wed, 25 Sep 2024 16:02:50 +0200 Subject: [PATCH 4/6] feat: PRT: Add subscription metrics (#1695) * add subscription metrics * fix metric typo * fix pr * fix typo * fix lint * fix tests * Add "Connection refused" to allowedErrorsDuringEmergencyMode * remove consumerSessionManager from consumer ws sub * add back session t ws sub * fix pr * fix lint * fix lint * change disconnect reason map * fix pr * fix pr --------- Co-authored-by: leon mandel Co-authored-by: Elad Gildnur <6321801+shleikes@users.noreply.github.com> Co-authored-by: Elad Gildnur Co-authored-by: Ran Mishael <106548467+ranlavanet@users.noreply.github.com> --- .../chainlib/consumer_websocket_manager.go | 2 +- .../consumer_ws_subscription_manager.go | 11 +++- .../consumer_ws_subscription_manager_test.go | 35 +++++----- protocol/metrics/metrics_consumer_manager.go | 66 +++++++++++++++++++ protocol/rpcconsumer/rpcconsumer.go | 2 +- 5 files changed, 97 insertions(+), 19 deletions(-) diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index 1d5339638e..a3bd553424 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -160,7 +160,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { continue } - // check whether its a normal relay / unsubscribe / unsubscribe_all otherwise its a subscription flow. + // check whether it's a normal relay / unsubscribe / unsubscribe_all otherwise it's a subscription flow.
if !IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { if IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, protocolMessage, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) diff --git a/protocol/chainlib/consumer_ws_subscription_manager.go b/protocol/chainlib/consumer_ws_subscription_manager.go index 10d8972e38..102bd8240a 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager.go +++ b/protocol/chainlib/consumer_ws_subscription_manager.go @@ -56,6 +56,7 @@ type ConsumerWSSubscriptionManager struct { activeSubscriptionProvidersStorage *lavasession.ActiveSubscriptionProvidersStorage currentlyPendingSubscriptions map[string]*pendingSubscriptionsBroadcastManager lock sync.RWMutex + consumerMetricsManager *metrics.ConsumerMetricsManager } func NewConsumerWSSubscriptionManager( @@ -65,6 +66,7 @@ func NewConsumerWSSubscriptionManager( connectionType string, chainParser ChainParser, activeSubscriptionProvidersStorage *lavasession.ActiveSubscriptionProvidersStorage, + consumerMetricsManager *metrics.ConsumerMetricsManager, ) *ConsumerWSSubscriptionManager { return &ConsumerWSSubscriptionManager{ connectedDapps: make(map[string]map[string]*common.SafeChannelSender[*pairingtypes.RelayReply]), @@ -76,6 +78,7 @@ func NewConsumerWSSubscriptionManager( relaySender: relaySender, connectionType: connectionType, activeSubscriptionProvidersStorage: activeSubscriptionProvidersStorage, + consumerMetricsManager: consumerMetricsManager, } } @@ -216,6 +219,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( // called after send relay failure or parsing failure afterwards onSubscriptionFailure := func() { + go cwsm.consumerMetricsManager.SetFailedWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType) cwsm.failedPendingSubscription(hashedParams) closeWebsocketRepliesChannel() } @@ -255,6 +259,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( // Validated there are no active subscriptions that we can use. 
firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, protocolMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) if firstSubscriptionReply != nil { + go cwsm.consumerMetricsManager.SetDuplicatedWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType) if returnWebsocketRepliesChan { return firstSubscriptionReply, websocketRepliesChan, nil } @@ -412,7 +417,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( cwsm.successfulPendingSubscription(hashedParams) // Need to be run once for subscription go cwsm.listenForSubscriptionMessages(webSocketCtx, dappID, consumerIp, replyServer, hashedParams, providerAddr, metricsData, closeSubscriptionChan) - + go cwsm.consumerMetricsManager.SetWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType) return &reply, websocketRepliesChan, nil } @@ -524,12 +529,14 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), ) + go cwsm.consumerMetricsManager.SetWsSubscriptionDisconnectRequestMetric(metricsData.ChainID, metricsData.APIType, metrics.WsDisconnectionReasonUser) return case <-replyServer.Context().Done(): utils.LavaFormatTrace("reply server context canceled", utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), ) + go cwsm.consumerMetricsManager.SetWsSubscriptionDisconnectRequestMetric(metricsData.ChainID, metricsData.APIType, metrics.WsDisconnectionReasonConsumer) return default: var reply pairingtypes.RelayReply @@ -537,6 +544,7 @@ if err != nil { // The connection was closed by the provider utils.LavaFormatTrace("error reading from subscription stream", utils.LogAttr("original error", err.Error())) + go cwsm.consumerMetricsManager.SetWsSubscriptionDisconnectRequestMetric(metricsData.ChainID, metricsData.APIType, metrics.WsDisconnectionReasonProvider) return } err = cwsm.handleIncomingSubscriptionNodeMessage(hashedParams, &reply, providerAddr) @@ -545,6 +553,7 @@ utils.LogAttr("hashedParams", hashedParams), utils.LogAttr("reply", reply), ) + go cwsm.consumerMetricsManager.SetFailedWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType) return } } diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go index c549cb6772..9aebc649a4 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager_test.go +++ b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -27,6 +27,9 @@ import ( const ( numberOfParallelSubscriptions = 10 uniqueId = "1234" + projectHashTest = "test_projecthash" + chainIdTest = "test_chainId" + apiTypeTest = "test_apiType" ) func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *testing.T) { @@ -51,7 +54,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`), }, } - + metricsData := metrics.NewRelayAnalytics(projectHashTest, chainIdTest, apiTypeTest) for _, play := range playbook { t.Run(play.name, func(t *testing.T) { ts := SetupForTests(t, 1, play.specId, "../../") @@ -136,7 +139,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes
consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String()) // Create a new ConsumerWSSubscriptionManager - manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage()) + manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage(), nil) uniqueIdentifiers := make([]string, numberOfParallelSubscriptions) wg := sync.WaitGroup{} wg.Add(numberOfParallelSubscriptions) @@ -151,7 +154,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes var repliesChan <-chan *pairingtypes.RelayReply var firstReply *pairingtypes.RelayReply - firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[index], nil) + firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[index], metricsData) go func() { for subMsg := range repliesChan { // utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index)) @@ -169,7 +172,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes // now we have numberOfParallelSubscriptions subscriptions currently running require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions) // remove one - err = manager.Unsubscribe(ts.Ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[0], nil) + err = manager.Unsubscribe(ts.Ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[0], metricsData) require.NoError(t, err) // now we have numberOfParallelSubscriptions - 1 require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-1) @@ -177,7 +180,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes require.Len(t, manager.activeSubscriptions, 1) // same flow for unsubscribe all - err = manager.UnsubscribeAll(ts.Ctx, dapp, ip, uniqueIdentifiers[1], nil) + err = manager.UnsubscribeAll(ts.Ctx, dapp, ip, uniqueIdentifiers[1], metricsData) require.NoError(t, err) // now we have numberOfParallelSubscriptions - 2 require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-2) @@ -209,7 +212,6 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`), }, } - for _, play := range playbook { t.Run(play.name, func(t *testing.T) { ts := SetupForTests(t, 1, play.specId, "../../") @@ -291,9 +293,9 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { Times(1) // Should call SendParsedRelay, because it is the first time we subscribe consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String()) - + metricsData := metrics.NewRelayAnalytics(projectHashTest, chainIdTest, apiTypeTest) // Create a new ConsumerWSSubscriptionManager - manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage()) + manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage(), nil) wg := sync.WaitGroup{} wg.Add(10) @@ -305,7 +307,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { ctx := utils.WithUniqueIdentifier(ts.Ctx, 
utils.GenerateUniqueIdentifier()) var repliesChan <-chan *pairingtypes.RelayReply var firstReply *pairingtypes.RelayReply - firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, metricsData) go func() { for subMsg := range repliesChan { require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data)) @@ -379,6 +381,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { unsubscribeMessage2: []byte(`{"jsonrpc":"2.0","method":"eth_unsubscribe","params":["0x2134567890"],"id":1}`), }, } + metricsData := metrics.NewRelayAnalytics(projectHashTest, chainIdTest, apiTypeTest) for _, play := range playbook { t.Run(play.name, func(t *testing.T) { @@ -538,12 +541,12 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String()) // Create a new ConsumerWSSubscriptionManager - manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage()) + manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage(), nil) // Start a new subscription for the first time, called SendParsedRelay once ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData) assert.NoError(t, err) unsubscribeMessageWg.Add(1) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) @@ -559,7 +562,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Start a subscription again, same params, same dappKey, should not call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData) assert.NoError(t, err) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) assert.Nil(t, repliesChan2) // Same subscription, same dappKey, no need for a new channel @@ -568,7 +571,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Start a subscription again, same params, different dappKey, should not call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp2, ts.Consumer.Addr.String(), uniqueId, metricsData) assert.NoError(t, err) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) assert.NotNil(t, repliesChan3) // Same subscription, but different dappKey, so will create new channel @@ -652,7 +655,7 @@ func 
TestConsumerWSSubscriptionManager(t *testing.T) { // Start a subscription again, different params, same dappKey, should call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeProtocolMessage2, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeProtocolMessage2, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData) assert.NoError(t, err) unsubscribeMessageWg.Add(1) assert.Equal(t, string(play.subscriptionFirstReply2), string(firstReply.Data)) @@ -671,7 +674,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) unsubProtocolMessage := NewProtocolMessage(unsubscribeChainMessage1, nil, relayResult1.Request.RelayData, dapp2, ts.Consumer.Addr.String()) - err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, metricsData) require.NoError(t, err) listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1)) @@ -697,7 +700,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { Times(2) // Should call SendParsedRelay, because it unsubscribed ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - err = manager.UnsubscribeAll(ctx, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + err = manager.UnsubscribeAll(ctx, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData) require.NoError(t, err) expectNoMoreMessages(ctx, repliesChan1) diff --git a/protocol/metrics/metrics_consumer_manager.go b/protocol/metrics/metrics_consumer_manager.go index 2f3337e432..83cd72d025 100644 --- a/protocol/metrics/metrics_consumer_manager.go +++ b/protocol/metrics/metrics_consumer_manager.go @@ -13,6 +13,12 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) +const ( + WsDisconnectionReasonConsumer = "consumer-disconnect" + WsDisconnectionReasonProvider = "provider-disconnect" + WsDisconnectionReasonUser = "user-disconnect" +) + type LatencyTracker struct { AverageLatency time.Duration // in nano seconds (time.Since result) TotalRequests int @@ -34,6 +40,10 @@ type ConsumerMetricsManager struct { totalNodeErroredRecoveryAttemptsMetric *prometheus.CounterVec totalRelaysSentToProvidersMetric *prometheus.CounterVec totalRelaysSentByNewBatchTickerMetric *prometheus.CounterVec + totalWsSubscriptionRequestsMetric *prometheus.CounterVec + totalFailedWsSubscriptionRequestsMetric *prometheus.CounterVec + totalWsSubscriptionDisconnectMetric *prometheus.CounterVec + totalDuplicatedWsSubscriptionRequestsMetric *prometheus.CounterVec blockMetric *prometheus.GaugeVec latencyMetric *prometheus.GaugeVec qosMetric *prometheus.GaugeVec @@ -88,6 +98,26 @@ func NewConsumerMetricsManager(options ConsumerMetricsManagerOptions) *ConsumerM Help: "The total number of errors encountered by the consumer over time.", }, []string{"spec", "apiInterface"}) + totalWsSubscriptionRequestsMetric := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "lava_consumer_total_ws_subscription_requests", + Help: "The total number of websocket subscription requests over time per chain id per api interface.", + }, []string{"spec", "apiInterface"}) + + totalFailedWsSubscriptionRequestsMetric := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name:
"lava_consumer_total_failed_ws_subscription_requests", + Help: "The total number of failed websocket subscription requests over time per chain id per api interface.", + }, []string{"spec", "apiInterface"}) + + totalDuplicatedWsSubscriptionRequestsMetric := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "lava_consumer_total_duplicated_ws_subscription_requests", + Help: "The total number of duplicated websocket subscription requests over time per chain id per api interface.", + }, []string{"spec", "apiInterface"}) + + totalWsSubscriptionDisconnectMetric := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "lava_consumer_total_ws_subscription_disconnect", + Help: "The total number of websocket subscription disconnects over time per chain id per api interface per disconnect reason.", + }, []string{"spec", "apiInterface", "disconnectReason"}) + blockMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{ Name: "lava_latest_block", Help: "The latest block measured", @@ -196,10 +226,18 @@ func NewConsumerMetricsManager(options ConsumerMetricsManagerOptions) *ConsumerM prometheus.MustRegister(totalNodeErroredRecoveryAttemptsMetric) prometheus.MustRegister(relayProcessingLatencyBeforeProvider) prometheus.MustRegister(relayProcessingLatencyAfterProvider) + prometheus.MustRegister(totalWsSubscriptionRequestsMetric) + prometheus.MustRegister(totalFailedWsSubscriptionRequestsMetric) + prometheus.MustRegister(totalDuplicatedWsSubscriptionRequestsMetric) + prometheus.MustRegister(totalWsSubscriptionDisconnectMetric) consumerMetricsManager := &ConsumerMetricsManager{ totalCURequestedMetric: totalCURequestedMetric, totalRelaysRequestedMetric: totalRelaysRequestedMetric, + totalWsSubscriptionRequestsMetric: totalWsSubscriptionRequestsMetric, + totalFailedWsSubscriptionRequestsMetric: totalFailedWsSubscriptionRequestsMetric, + totalDuplicatedWsSubscriptionRequestsMetric: totalDuplicatedWsSubscriptionRequestsMetric, + totalWsSubscriptionDisconnectMetric: totalWsSubscriptionDisconnectMetric, totalErroredMetric: totalErroredMetric, blockMetric: blockMetric, latencyMetric: latencyMetric, @@ -460,3 +498,31 @@ func SetVersionInner(protocolVersionMetric *prometheus.GaugeVec, version string) combined := major*1000000 + minor*1000 + patch protocolVersionMetric.WithLabelValues("version").Set(float64(combined)) } + +func (pme *ConsumerMetricsManager) SetWsSubscriptionRequestMetric(chainId string, apiInterface string) { + if pme == nil { + return + } + pme.totalWsSubscriptionRequestsMetric.WithLabelValues(chainId, apiInterface).Inc() +} + +func (pme *ConsumerMetricsManager) SetFailedWsSubscriptionRequestMetric(chainId string, apiInterface string) { + if pme == nil { + return + } + pme.totalFailedWsSubscriptionRequestsMetric.WithLabelValues(chainId, apiInterface).Inc() +} + +func (pme *ConsumerMetricsManager) SetDuplicatedWsSubscriptionRequestMetric(chainId string, apiInterface string) { + if pme == nil { + return + } + pme.totalDuplicatedWsSubscriptionRequestsMetric.WithLabelValues(chainId, apiInterface).Inc() +} + +func (pme *ConsumerMetricsManager) SetWsSubscriptionDisconnectRequestMetric(chainId string, apiInterface string, disconnectReason string) { + if pme == nil { + return + } + pme.totalWsSubscriptionDisconnectMetric.WithLabelValues(chainId, apiInterface, disconnectReason).Inc() +} diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index c919678735..1223ef34be 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@
-298,7 +298,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt if rpcEndpoint.ApiInterface == spectypes.APIInterfaceJsonRPC { specMethodType = http.MethodPost } - consumerWsSubscriptionManager = chainlib.NewConsumerWSSubscriptionManager(consumerSessionManager, rpcConsumerServer, options.refererData, specMethodType, chainParser, activeSubscriptionProvidersStorage) + consumerWsSubscriptionManager = chainlib.NewConsumerWSSubscriptionManager(consumerSessionManager, rpcConsumerServer, options.refererData, specMethodType, chainParser, activeSubscriptionProvidersStorage, consumerMetricsManager) utils.LavaFormatInfo("RPCConsumer Listening", utils.Attribute{Key: "endpoints", Value: rpcEndpoint.String()}) err = rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, rpcc.consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, options.requiredResponses, privKey, lavaChainID, options.cache, rpcConsumerMetrics, consumerAddr, consumerConsistency, relaysMonitor, options.cmdFlags, options.stateShare, options.refererData, consumerReportsManager, consumerWsSubscriptionManager) From 7d0aefa0080889313f87f0913c8f5584b1340bbc Mon Sep 17 00:00:00 2001 From: Ran Mishael <106548467+ranlavanet@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:40:28 +0200 Subject: [PATCH 5/6] feat: PRT - add rate limit to ws (#1713) * feat: PRT - add rate limit to ws * lintush * Update protocol/chainlib/consumer_websocket_manager.go --- .../chainlib/consumer_websocket_manager.go | 58 ++++++++++++++++++- .../consumer_ws_subscription_manager_test.go | 7 +++ protocol/common/cobra_common.go | 3 + protocol/common/return_errors.go | 9 +++ protocol/rpcconsumer/rpcconsumer.go | 1 + 5 files changed, 75 insertions(+), 3 deletions(-) diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index a3bd553424..e6edb7aaa6 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -3,9 +3,10 @@ package chainlib import ( "context" "strconv" + "sync/atomic" "time" - gojson "github.com/goccy/go-json" + "github.com/goccy/go-json" "github.com/gofiber/websocket/v2" formatter "github.com/lavanet/lava/v3/ecosystem/cache/format" "github.com/lavanet/lava/v3/protocol/common" @@ -13,8 +14,11 @@ import ( "github.com/lavanet/lava/v3/utils" "github.com/lavanet/lava/v3/utils/rand" spectypes "github.com/lavanet/lava/v3/x/spec/types" + "github.com/tidwall/gjson" ) +var WebSocketRateLimit = -1 // rate limit requests per second on websocket connection + type ConsumerWebsocketManager struct { websocketConn *websocket.Conn rpcConsumerLogs *metrics.RPCConsumerLogs @@ -67,6 +71,27 @@ func (cwm *ConsumerWebsocketManager) GetWebSocketConnectionUniqueId(dappId, user return dappId + "__" + userIp + "__" + cwm.WebsocketConnectionUID } +func (cwm *ConsumerWebsocketManager) handleRateLimitReached(inpData []byte) ([]byte, error) { + rateLimitError := common.JsonRpcRateLimitError + id := 0 + result := gjson.GetBytes(inpData, "id") + switch result.Type { + case gjson.Number: + id = int(result.Int()) + case gjson.String: + idParsed, err := strconv.Atoi(result.Str) + if err == nil { + id = idParsed + } + } + rateLimitError.Id = id + bytesRateLimitError, err := json.Marshal(rateLimitError) + if err != nil { + return []byte{}, utils.LavaFormatError("failed marshalling jsonrpc rate limit error", err) + } + return bytesRateLimitError, nil +} + func (cwm *ConsumerWebsocketManager) ListenToMessages() { var ( messageType
int @@ -110,6 +135,24 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } }() + // rate limit routine + requestsPerSecond := &atomic.Uint64{} + go func() { + if WebSocketRateLimit <= 0 { + return + } + ticker := time.NewTicker(time.Second) // rate limit per second. + defer ticker.Stop() + for { + select { + case <-webSocketCtx.Done(): + return + case <-ticker.C: + requestsPerSecond.Store(0) + } + } + }() + for { startTime := time.Now() msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int @@ -125,6 +168,15 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { break } + // Check rate limit is met + if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) { + rateLimitResponse, err := cwm.handleRateLimitReached(msg) + if err == nil { + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse} + } + continue + } + dappID, ok := websocketConn.Locals("dapp-id").(string) if !ok { // Log and remove the analyze @@ -167,7 +219,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { if err != nil { utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx)) if err == common.SubscriptionNotFoundError { - msgData, err := gojson.Marshal(common.JsonRpcSubscriptionNotFoundError) + msgData, err := json.Marshal(common.JsonRpcSubscriptionNotFoundError) if err != nil { continue } @@ -224,7 +276,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { // Handle the case when the error is a method not found error if common.APINotSupportedError.Is(err) { - msgData, err := gojson.Marshal(common.JsonRpcMethodNotFoundError) + msgData, err := json.Marshal(common.JsonRpcMethodNotFoundError) if err != nil { continue } diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go index 9aebc649a4..48573a3512 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager_test.go +++ b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -2,9 +2,11 @@ package chainlib import ( "context" + "fmt" "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -324,6 +326,11 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { } } +func TestRateLimit(t *testing.T) { + numberOfRequests := &atomic.Uint64{} + fmt.Println(numberOfRequests.Load()) +} + func TestConsumerWSSubscriptionManager(t *testing.T) { // This test does the following: // 1. 
Create a new ConsumerWSSubscriptionManager diff --git a/protocol/common/cobra_common.go b/protocol/common/cobra_common.go index 40cbffdce2..fe75c8f31f 100644 --- a/protocol/common/cobra_common.go +++ b/protocol/common/cobra_common.go @@ -38,6 +38,9 @@ const ( SetProviderOptimizerBestTierPickChance = "set-provider-optimizer-best-tier-pick-chance" SetProviderOptimizerWorstTierPickChance = "set-provider-optimizer-worst-tier-pick-chance" SetProviderOptimizerNumberOfTiersToCreate = "set-provider-optimizer-number-of-tiers-to-create" + + // websocket flags + RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection" ) const ( diff --git a/protocol/common/return_errors.go b/protocol/common/return_errors.go index 5394ba1f3d..9020a26f17 100644 --- a/protocol/common/return_errors.go +++ b/protocol/common/return_errors.go @@ -27,6 +27,15 @@ var JsonRpcMethodNotFoundError = JsonRPCErrorMessage{ }, } +var JsonRpcRateLimitError = JsonRPCErrorMessage{ + JsonRPC: "2.0", + Id: 1, + Error: JsonRPCError{ + Code: 429, + Message: "Too Many Requests", + }, +} + var JsonRpcSubscriptionNotFoundError = JsonRPCErrorMessage{ JsonRPC: "2.0", Id: 1, diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 1223ef34be..77b07b4bf1 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -618,6 +618,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77 cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.ATierChance, common.SetProviderOptimizerBestTierPickChance, 0.75, "set the chances for picking a provider from the best group, default is 75% -> 0.75") cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.LastTierChance, common.SetProviderOptimizerWorstTierPickChance, 0.0, "set the chances for picking a provider from the worse group, default is 0% -> 0.0") cmdRPCConsumer.Flags().IntVar(&provideroptimizer.OptimizerNumTiers, common.SetProviderOptimizerNumberOfTiersToCreate, 4, "set the number of groups to create, default is 4") + cmdRPCConsumer.Flags().IntVar(&chainlib.WebSocketRateLimit, common.RateLimitWebSocketFlag, chainlib.WebSocketRateLimit, "rate limit (per second) websocket requests per user connection, default is unlimited") common.AddRollingLogConfig(cmdRPCConsumer) return cmdRPCConsumer } From 5e3e277718d43590f45ad6749c6583b2b38634eb Mon Sep 17 00:00:00 2001 From: Yaroms <103432884+Yaroms@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:59:29 +0300 Subject: [PATCH 6/6] add script (#1704) Co-authored-by: Yarom Swisa --- .../short_block_proposers.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 scripts/automation_scripts/short_block_proposers.py diff --git a/scripts/automation_scripts/short_block_proposers.py b/scripts/automation_scripts/short_block_proposers.py new file mode 100644 index 0000000000..578b3381b9 --- /dev/null +++ b/scripts/automation_scripts/short_block_proposers.py @@ -0,0 +1,50 @@ +import requests +from datetime import datetime + +def fetch_block_data(block_number): + # URL that returns a JSON response + url = 'https://lava.rest.lava.build/cosmos/base/tendermint/v1beta1/blocks/' + block_number + + # Fetch the data from the URL + response = requests.get(url) + + # Parse the JSON data + json_data = response.json() + block_header = json_data.get("sdk_block").get("header") + + height = block_header.get("height") + + dt = datetime.strptime(block_header.get("time")[:26] + "Z", "%Y-%m-%dT%H:%M:%S.%fZ") + time = dt.timestamp() + + proposer = 
block_header.get("proposer_address") + + return [int(height), float(time), proposer] + +stats_before = {} +stats_current = {} +current = fetch_block_data("latest") +for i in range(10000): + print(f"Progress: {i}", end='\r') + before = fetch_block_data(str(current[0]-1)) + + if current[2] not in stats_current: + stats_current[current[2]] = {"good": 0, "bad": 0} + + if before[2] not in stats_before: + stats_before[before[2]] = {"good": 0, "bad": 0} + + if current[1] - before[1] < 10: + stats_current[current[2]]["bad"] += 1 + stats_before[before[2]]["bad"] += 1 + else: + stats_current[current[2]]["good"] += 1 + stats_before[before[2]]["good"] += 1 + + current = before + +print(stats_before) +print("---------------------------------------------------------------------") +# print(stats_current) + +
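A closing note on the websocket rate-limit patch earlier in this series: the per-connection limiter is a fixed-window counter, that is, an atomic counter that a background goroutine zeroes once per second, and a message is rejected as soon as the current window's count exceeds WebSocketRateLimit. A minimal, self-contained Go sketch of that pattern (the type and function names below are illustrative, not taken from the codebase):

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// fixedWindowLimiter mirrors the per-connection pattern used in the patch:
// one counter per websocket connection, zeroed every second by a ticker goroutine.
type fixedWindowLimiter struct {
	count atomic.Uint64
	limit uint64
}

// run resets the window once per second until stop is closed.
func (l *fixedWindowLimiter) run(stop <-chan struct{}) {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			l.count.Store(0) // start a new one-second window
		}
	}
}

// allow counts the incoming message and reports whether it is within the limit.
func (l *fixedWindowLimiter) allow() bool {
	return l.count.Add(1) <= l.limit
}

func main() {
	l := &fixedWindowLimiter{limit: 3}
	stop := make(chan struct{})
	go l.run(stop)
	defer close(stop)

	for i := 1; i <= 5; i++ {
		fmt.Printf("message %d allowed=%v\n", i, l.allow()) // messages 4 and 5 are rejected
	}
}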