Skip to content

Commit

Permalink
Functions: fixed subscriptions tracking logic (#11167)
Browse files Browse the repository at this point in the history
* Functions: fixed subscriptions tracking logic

* Addressed PR feedback

* Addressed PR feedback

* Addressed PR feedback
  • Loading branch information
Andrei Smirnov authored Nov 6, 2023
1 parent 28e9596 commit b24af13
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 19 deletions.
12 changes: 10 additions & 2 deletions core/services/gateway/handlers/functions/handler.functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math/big"
"time"

"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -178,8 +179,15 @@ func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Messa
return ErrRateLimited
}
if h.subscriptions != nil && h.minimumBalance != nil {
if balance, err := h.subscriptions.GetMaxUserBalance(sender); err != nil || balance.Cmp(h.minimumBalance.ToInt()) < 0 {
h.lggr.Debug("received a message from a user having insufficient balance", "sender", msg.Body.Sender, "balance", balance.String())
balance, err := h.subscriptions.GetMaxUserBalance(sender)
if err != nil {
h.lggr.Debugw("error getting max user balance", "sender", msg.Body.Sender, "err", err)
}
if balance == nil {
balance = big.NewInt(0)
}
if err != nil || balance.Cmp(h.minimumBalance.ToInt()) < 0 {
h.lggr.Debugw("received a message from a user having insufficient balance", "sender", msg.Body.Sender, "balance", balance.String())
return fmt.Errorf("sender has insufficient balance: %v juels", balance.String())
}
}
Expand Down
22 changes: 9 additions & 13 deletions core/services/gateway/handlers/functions/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,16 @@ func (s *onchainSubscriptions) queryLoop() {

blockNumber := big.NewInt(0).Sub(latestBlockHeight, s.blockConfirmations)

updateLastKnownCount := func() {
if lastKnownCount == 0 || start > lastKnownCount {
count, err := s.getSubscriptionsCount(ctx, blockNumber)
if err != nil {
s.lggr.Errorw("Error getting subscriptions count", "err", err)
return
s.lggr.Errorw("Error getting new subscriptions count", "err", err)
} else {
s.lggr.Infow("Updated subscriptions count", "count", count, "blockNumber", blockNumber.Int64())
lastKnownCount = count
}
s.lggr.Infow("Updated subscriptions count", "err", err, "count", count, "blockNumber", blockNumber.Int64())
lastKnownCount = count
}

if lastKnownCount == 0 {
updateLastKnownCount()
}
if lastKnownCount == 0 {
s.lggr.Info("Router has no subscriptions yet")
return
Expand All @@ -152,12 +149,9 @@ func (s *onchainSubscriptions) queryLoop() {
start = 1
}

end := start + uint64(s.config.UpdateRangeSize)
end := start + uint64(s.config.UpdateRangeSize) - 1
if end > lastKnownCount {
updateLastKnownCount()
if end > lastKnownCount {
end = lastKnownCount
}
end = lastKnownCount
}
if err := s.querySubscriptionsRange(ctx, blockNumber, start, end); err != nil {
s.lggr.Errorw("Error querying subscriptions", "err", err, "start", start, "end", end)
Expand All @@ -180,6 +174,8 @@ func (s *onchainSubscriptions) queryLoop() {
}

func (s *onchainSubscriptions) querySubscriptionsRange(ctx context.Context, blockNumber *big.Int, start, end uint64) error {
s.lggr.Debugw("Querying subscriptions", "blockNumber", blockNumber, "start", start, "end", end)

subscriptions, err := s.router.GetSubscriptionsInRange(&bind.CallOpts{
Pending: false,
BlockNumber: blockNumber,
Expand Down
52 changes: 48 additions & 4 deletions core/services/gateway/handlers/functions/subscriptions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package functions_test

import (
"math/big"
"sync/atomic"
"testing"
"time"

Expand All @@ -24,9 +25,7 @@ const (
invalidUser = "0x6E2dc0F9DB014aE19888F539E59285D2Ea04244C"
)

func TestSubscriptions(t *testing.T) {
t.Parallel()

func TestSubscriptions_OnePass(t *testing.T) {
getSubscriptionCount := hexutil.MustDecode("0x0000000000000000000000000000000000000000000000000000000000000003")
getSubscriptionsInRange := hexutil.MustDecode("0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000240000000000000000000000000000000000000000000000000de0b6b3a76400000000000000000000000000000109e6e1b12098cc8f3a1e9719a817ec53ab9b35c000000000000000000000000000000000000000000000000000034e23f515cb0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000f5340f0968ee8b7dfd97e3327a6139273cc2c4fa000000000000000000000000000000000000000000000001158e460913d000000000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001bc14b92364c75e20000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000005439e5881a529f3ccbffc0e82d49f9db3950aefe")

Expand All @@ -46,7 +45,7 @@ func TestSubscriptions(t *testing.T) {
BlockConfirmations: 1,
UpdateFrequencySec: 1,
UpdateTimeoutSec: 1,
UpdateRangeSize: 10,
UpdateRangeSize: 3,
}
subscriptions, err := functions.NewOnchainSubscriptions(client, config, logger.TestLogger(t))
require.NoError(t, err)
Expand All @@ -57,10 +56,55 @@ func TestSubscriptions(t *testing.T) {
assert.NoError(t, subscriptions.Close())
})

// initially we have 3 subs and range is 3, which needs one pass
gomega.NewGomegaWithT(t).Eventually(func() bool {
expectedBalance := big.NewInt(0).SetBytes(hexutil.MustDecode("0x01158e460913d00000"))
balance, err1 := subscriptions.GetMaxUserBalance(common.HexToAddress(validUser))
_, err2 := subscriptions.GetMaxUserBalance(common.HexToAddress(invalidUser))
return err1 == nil && err2 != nil && balance.Cmp(expectedBalance) == 0
}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
}

func TestSubscriptions_MultiPass(t *testing.T) {
const ncycles int32 = 5
var currentCycle atomic.Int32
getSubscriptionCount := hexutil.MustDecode("0x0000000000000000000000000000000000000000000000000000000000000006")
getSubscriptionsInRange := hexutil.MustDecode("0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000240000000000000000000000000000000000000000000000000de0b6b3a76400000000000000000000000000000109e6e1b12098cc8f3a1e9719a817ec53ab9b35c000000000000000000000000000000000000000000000000000034e23f515cb0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000f5340f0968ee8b7dfd97e3327a6139273cc2c4fa000000000000000000000000000000000000000000000001158e460913d000000000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001bc14b92364c75e20000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000005439e5881a529f3ccbffc0e82d49f9db3950aefe")

ctx := testutils.Context(t)
client := mocks.NewClient(t)
client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
client.On("CallContract", mock.Anything, ethereum.CallMsg{ // getSubscriptionCount
To: &common.Address{},
Data: hexutil.MustDecode("0x66419970"),
}, mock.Anything).Run(func(args mock.Arguments) {
currentCycle.Add(1)
}).Return(getSubscriptionCount, nil)
client.On("CallContract", mock.Anything, ethereum.CallMsg{ // GetSubscriptionsInRange(1,3)
To: &common.Address{},
Data: hexutil.MustDecode("0xec2454e500000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000003"),
}, mock.Anything).Return(getSubscriptionsInRange, nil)
client.On("CallContract", mock.Anything, ethereum.CallMsg{ // GetSubscriptionsInRange(4,6)
To: &common.Address{},
Data: hexutil.MustDecode("0xec2454e500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000006"),
}, mock.Anything).Return(getSubscriptionsInRange, nil)
config := functions.OnchainSubscriptionsConfig{
ContractAddress: common.Address{},
BlockConfirmations: 1,
UpdateFrequencySec: 1,
UpdateTimeoutSec: 1,
UpdateRangeSize: 3,
}
subscriptions, err := functions.NewOnchainSubscriptions(client, config, logger.TestLogger(t))
require.NoError(t, err)

err = subscriptions.Start(ctx)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, subscriptions.Close())
})

gomega.NewGomegaWithT(t).Eventually(func() bool {
return currentCycle.Load() == ncycles
}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
}

0 comments on commit b24af13

Please sign in to comment.