Skip to content

Commit

Permalink
[Functions] Require minimum balance only for secrets_set, not for list (
Browse files Browse the repository at this point in the history
#11309)

Additionally refactor a helper method in connector_handler.go.
  • Loading branch information
bolekk authored Nov 16, 2023
1 parent 2a947cc commit b012cc4
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 29 deletions.
41 changes: 18 additions & 23 deletions core/services/functions/connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,21 @@ func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, ga
h.lggr.Errorw("request rate-limited", "id", gatewayId, "address", fromAddr)
return
}
if balance, err := h.subscriptions.GetMaxUserBalance(fromAddr); err != nil || balance.Cmp(h.minimumBalance.ToInt()) < 0 {
h.lggr.Errorw("user subscription has insufficient balance", "id", gatewayId, "address", fromAddr, "balance", balance, "minBalance", h.minimumBalance)
response := functions.SecretsResponseBase{
Success: false,
ErrorMessage: "user subscription has insufficient balance",
}
if err := h.sendResponse(ctx, gatewayId, body, response); err != nil {
h.lggr.Errorw("failed to send response to gateway", "id", gatewayId, "error", err)
}
return
}

h.lggr.Debugw("handling gateway request", "id", gatewayId, "method", body.Method)

switch body.Method {
case functions.MethodSecretsList:
h.handleSecretsList(ctx, gatewayId, body, fromAddr)
case functions.MethodSecretsSet:
if balance, err := h.subscriptions.GetMaxUserBalance(fromAddr); err != nil || balance.Cmp(h.minimumBalance.ToInt()) < 0 {
h.lggr.Errorw("user subscription has insufficient balance", "id", gatewayId, "address", fromAddr, "balance", balance, "minBalance", h.minimumBalance)
response := functions.SecretsResponseBase{
Success: false,
ErrorMessage: "user subscription has insufficient balance",
}
h.sendResponseAndLog(ctx, gatewayId, body, response)
return
}
h.handleSecretsSet(ctx, gatewayId, body, fromAddr)
default:
h.lggr.Errorw("unsupported method", "id", gatewayId, "method", body.Method)
Expand Down Expand Up @@ -133,10 +130,7 @@ func (h *functionsConnectorHandler) handleSecretsList(ctx context.Context, gatew
} else {
response.ErrorMessage = fmt.Sprintf("Failed to list secrets: %v", err)
}

if err := h.sendResponse(ctx, gatewayId, body, response); err != nil {
h.lggr.Errorw("failed to send response to gateway", "id", gatewayId, "error", err)
}
h.sendResponseAndLog(ctx, gatewayId, body, response)
}

func (h *functionsConnectorHandler) handleSecretsSet(ctx context.Context, gatewayId string, body *api.MessageBody, fromAddr ethCommon.Address) {
Expand All @@ -163,9 +157,15 @@ func (h *functionsConnectorHandler) handleSecretsSet(ctx context.Context, gatewa
} else {
response.ErrorMessage = fmt.Sprintf("Bad request to set secret: %v", err)
}
h.sendResponseAndLog(ctx, gatewayId, body, response)
}

if err := h.sendResponse(ctx, gatewayId, body, response); err != nil {
func (h *functionsConnectorHandler) sendResponseAndLog(ctx context.Context, gatewayId string, requestBody *api.MessageBody, payload any) {
err := h.sendResponse(ctx, gatewayId, requestBody, payload)
if err != nil {
h.lggr.Errorw("failed to send response to gateway", "id", gatewayId, "error", err)
} else {
h.lggr.Debugw("sent to gateway", "id", gatewayId, "messageId", requestBody.MessageId, "donId", requestBody.DonId, "method", requestBody.Method)
}
}

Expand All @@ -187,10 +187,5 @@ func (h *functionsConnectorHandler) sendResponse(ctx context.Context, gatewayId
if err = msg.Sign(h.signerKey); err != nil {
return err
}

err = h.connector.SendToGateway(ctx, gatewayId, msg)
if err == nil {
h.lggr.Debugw("sent to gateway", "id", gatewayId, "messageId", requestBody.MessageId, "donId", requestBody.DonId, "method", requestBody.Method)
}
return err
return h.connector.SendToGateway(ctx, gatewayId, msg)
}
3 changes: 0 additions & 3 deletions core/services/functions/connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ func TestFunctionsConnectorHandler(t *testing.T) {
}
storage.On("List", ctx, addr).Return(snapshot, nil).Once()
allowlist.On("Allow", addr).Return(true).Once()
subscriptions.On("GetMaxUserBalance", mock.Anything).Return(big.NewInt(100), nil).Once()
connector.On("SendToGateway", ctx, "gw1", mock.Anything).Run(func(args mock.Arguments) {
msg, ok := args[2].(*api.Message)
require.True(t, ok)
Expand All @@ -91,7 +90,6 @@ func TestFunctionsConnectorHandler(t *testing.T) {
t.Run("orm error", func(t *testing.T) {
storage.On("List", ctx, addr).Return(nil, errors.New("boom")).Once()
allowlist.On("Allow", addr).Return(true).Once()
subscriptions.On("GetMaxUserBalance", mock.Anything).Return(big.NewInt(100), nil).Once()
connector.On("SendToGateway", ctx, "gw1", mock.Anything).Run(func(args mock.Arguments) {
msg, ok := args[2].(*api.Message)
require.True(t, ok)
Expand Down Expand Up @@ -218,7 +216,6 @@ func TestFunctionsConnectorHandler(t *testing.T) {
require.NoError(t, msg.Sign(privateKey))

allowlist.On("Allow", addr).Return(true).Once()
subscriptions.On("GetMaxUserBalance", mock.Anything).Return(big.NewInt(100), nil).Once()
handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg)
})
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Messa
promHandlerError.WithLabelValues(h.donConfig.DonId, ErrRateLimited.Error()).Inc()
return ErrRateLimited
}
if h.subscriptions != nil && h.minimumBalance != nil {
if msg.Body.Method == MethodSecretsSet && h.subscriptions != nil && h.minimumBalance != nil {
balance, err := h.subscriptions.GetMaxUserBalance(sender)
if err != nil {
h.lggr.Debugw("error getting max user balance", "sender", msg.Body.Sender, "err", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,10 @@ func TestFunctionsHandler_HandleUserMessage_InvalidMethod(t *testing.T) {
t.Parallel()

nodes, user := gc.NewTestNodes(t, 4), gc.NewTestNodes(t, 1)[0]
handler, _, allowlist, subscriptions := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24)
handler, _, allowlist, _ := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24)
userRequestMsg := newSignedMessage(t, "1234", "secrets_reveal_all_please", "don_id", user.PrivateKey)

allowlist.On("Allow", common.HexToAddress(user.Address)).Return(true, nil)
subscriptions.On("GetMaxUserBalance", common.HexToAddress(user.Address)).Return(big.NewInt(1000), nil)
err := handler.HandleUserMessage(testutils.Context(t), &userRequestMsg, make(chan handlers.UserCallbackPayload))
require.Error(t, err)
}
Expand Down

0 comments on commit b012cc4

Please sign in to comment.