diff --git a/core/services/functions/connector_handler.go b/core/services/functions/connector_handler.go index 8a8710e6ea6..c5dbff6f10f 100644 --- a/core/services/functions/connector_handler.go +++ b/core/services/functions/connector_handler.go @@ -76,24 +76,21 @@ func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, ga h.lggr.Errorw("request rate-limited", "id", gatewayId, "address", fromAddr) return } - if balance, err := h.subscriptions.GetMaxUserBalance(fromAddr); err != nil || balance.Cmp(h.minimumBalance.ToInt()) < 0 { - h.lggr.Errorw("user subscription has insufficient balance", "id", gatewayId, "address", fromAddr, "balance", balance, "minBalance", h.minimumBalance) - response := functions.SecretsResponseBase{ - Success: false, - ErrorMessage: "user subscription has insufficient balance", - } - if err := h.sendResponse(ctx, gatewayId, body, response); err != nil { - h.lggr.Errorw("failed to send response to gateway", "id", gatewayId, "error", err) - } - return - } - h.lggr.Debugw("handling gateway request", "id", gatewayId, "method", body.Method) switch body.Method { case functions.MethodSecretsList: h.handleSecretsList(ctx, gatewayId, body, fromAddr) case functions.MethodSecretsSet: + if balance, err := h.subscriptions.GetMaxUserBalance(fromAddr); err != nil || balance.Cmp(h.minimumBalance.ToInt()) < 0 { + h.lggr.Errorw("user subscription has insufficient balance", "id", gatewayId, "address", fromAddr, "balance", balance, "minBalance", h.minimumBalance) + response := functions.SecretsResponseBase{ + Success: false, + ErrorMessage: "user subscription has insufficient balance", + } + h.sendResponseAndLog(ctx, gatewayId, body, response) + return + } h.handleSecretsSet(ctx, gatewayId, body, fromAddr) default: h.lggr.Errorw("unsupported method", "id", gatewayId, "method", body.Method) @@ -133,10 +130,7 @@ func (h *functionsConnectorHandler) handleSecretsList(ctx context.Context, gatew } else { response.ErrorMessage = fmt.Sprintf("Failed to list secrets: %v", err) } - - if err := h.sendResponse(ctx, gatewayId, body, response); err != nil { - h.lggr.Errorw("failed to send response to gateway", "id", gatewayId, "error", err) - } + h.sendResponseAndLog(ctx, gatewayId, body, response) } func (h *functionsConnectorHandler) handleSecretsSet(ctx context.Context, gatewayId string, body *api.MessageBody, fromAddr ethCommon.Address) { @@ -163,9 +157,15 @@ func (h *functionsConnectorHandler) handleSecretsSet(ctx context.Context, gatewa } else { response.ErrorMessage = fmt.Sprintf("Bad request to set secret: %v", err) } + h.sendResponseAndLog(ctx, gatewayId, body, response) +} - if err := h.sendResponse(ctx, gatewayId, body, response); err != nil { +func (h *functionsConnectorHandler) sendResponseAndLog(ctx context.Context, gatewayId string, requestBody *api.MessageBody, payload any) { + err := h.sendResponse(ctx, gatewayId, requestBody, payload) + if err != nil { h.lggr.Errorw("failed to send response to gateway", "id", gatewayId, "error", err) + } else { + h.lggr.Debugw("sent to gateway", "id", gatewayId, "messageId", requestBody.MessageId, "donId", requestBody.DonId, "method", requestBody.Method) } } @@ -187,10 +187,5 @@ func (h *functionsConnectorHandler) sendResponse(ctx context.Context, gatewayId if err = msg.Sign(h.signerKey); err != nil { return err } - - err = h.connector.SendToGateway(ctx, gatewayId, msg) - if err == nil { - h.lggr.Debugw("sent to gateway", "id", gatewayId, "messageId", requestBody.MessageId, "donId", requestBody.DonId, "method", requestBody.Method) - } - return err + return h.connector.SendToGateway(ctx, gatewayId, msg) } diff --git a/core/services/functions/connector_handler_test.go b/core/services/functions/connector_handler_test.go index 7bf98d7501d..fa9f74712be 100644 --- a/core/services/functions/connector_handler_test.go +++ b/core/services/functions/connector_handler_test.go @@ -78,7 +78,6 @@ func TestFunctionsConnectorHandler(t *testing.T) { } storage.On("List", ctx, addr).Return(snapshot, nil).Once() allowlist.On("Allow", addr).Return(true).Once() - subscriptions.On("GetMaxUserBalance", mock.Anything).Return(big.NewInt(100), nil).Once() connector.On("SendToGateway", ctx, "gw1", mock.Anything).Run(func(args mock.Arguments) { msg, ok := args[2].(*api.Message) require.True(t, ok) @@ -91,7 +90,6 @@ func TestFunctionsConnectorHandler(t *testing.T) { t.Run("orm error", func(t *testing.T) { storage.On("List", ctx, addr).Return(nil, errors.New("boom")).Once() allowlist.On("Allow", addr).Return(true).Once() - subscriptions.On("GetMaxUserBalance", mock.Anything).Return(big.NewInt(100), nil).Once() connector.On("SendToGateway", ctx, "gw1", mock.Anything).Run(func(args mock.Arguments) { msg, ok := args[2].(*api.Message) require.True(t, ok) @@ -218,7 +216,6 @@ func TestFunctionsConnectorHandler(t *testing.T) { require.NoError(t, msg.Sign(privateKey)) allowlist.On("Allow", addr).Return(true).Once() - subscriptions.On("GetMaxUserBalance", mock.Anything).Return(big.NewInt(100), nil).Once() handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg) }) }) diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index d0011145d40..bb6812c1f9b 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -178,7 +178,7 @@ func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Messa promHandlerError.WithLabelValues(h.donConfig.DonId, ErrRateLimited.Error()).Inc() return ErrRateLimited } - if h.subscriptions != nil && h.minimumBalance != nil { + if msg.Body.Method == MethodSecretsSet && h.subscriptions != nil && h.minimumBalance != nil { balance, err := h.subscriptions.GetMaxUserBalance(sender) if err != nil { h.lggr.Debugw("error getting max user balance", "sender", msg.Body.Sender, "err", err) diff --git a/core/services/gateway/handlers/functions/handler.functions_test.go b/core/services/gateway/handlers/functions/handler.functions_test.go index 1446bc84571..49fdae2bb24 100644 --- a/core/services/gateway/handlers/functions/handler.functions_test.go +++ b/core/services/gateway/handlers/functions/handler.functions_test.go @@ -148,11 +148,10 @@ func TestFunctionsHandler_HandleUserMessage_InvalidMethod(t *testing.T) { t.Parallel() nodes, user := gc.NewTestNodes(t, 4), gc.NewTestNodes(t, 1)[0] - handler, _, allowlist, subscriptions := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24) + handler, _, allowlist, _ := newFunctionsHandlerForATestDON(t, nodes, time.Hour*24) userRequestMsg := newSignedMessage(t, "1234", "secrets_reveal_all_please", "don_id", user.PrivateKey) allowlist.On("Allow", common.HexToAddress(user.Address)).Return(true, nil) - subscriptions.On("GetMaxUserBalance", common.HexToAddress(user.Address)).Return(big.NewInt(1000), nil) err := handler.HandleUserMessage(testutils.Context(t), &userRequestMsg, make(chan handlers.UserCallbackPayload)) require.Error(t, err) }