diff --git a/common/model/api_key_info.go b/common/model/api_key_info.go index a22daf68..daebcb32 100644 --- a/common/model/api_key_info.go +++ b/common/model/api_key_info.go @@ -1,17 +1,20 @@ package model -import "golang.org/x/time/rate" +import ( + mapset "github.com/deckarep/golang-set/v2" + "golang.org/x/time/rate" +) type ApiKeyModel struct { - Disable bool `json:"disable"` - ApiKey string `json:"api_key"` - RateLimit rate.Limit `json:"rate_limit"` - UserId int64 `json:"user_id"` - NetWorkLimitEnable bool `json:"network_limit_enable"` - DomainWhitelist []string `json:"domain_whitelist"` - IPWhiteList []string `json:"ip_white_list"` - PaymasterEnable bool `json:"paymaster_enable"` - Erc20PaymasterEnable bool `json:"erc20_paymaster_enable"` - ProjectSponsorPaymasterEnable bool `json:"project_sponsor_paymaster_enable"` - UserPayPaymasterEnable bool `json:"user_pay_paymaster_enable"` + Disable bool `json:"disable"` + ApiKey string `json:"api_key"` + RateLimit rate.Limit `json:"rate_limit"` + UserId int64 `json:"user_id"` + NetWorkLimitEnable bool `json:"network_limit_enable"` + DomainWhitelist mapset.Set[string] `json:"domain_whitelist"` + IPWhiteList mapset.Set[string] `json:"ip_white_list"` + PaymasterEnable bool `json:"paymaster_enable"` + Erc20PaymasterEnable bool `json:"erc20_paymaster_enable"` + ProjectSponsorPaymasterEnable bool `json:"project_sponsor_paymaster_enable"` + UserPayPaymasterEnable bool `json:"user_pay_paymaster_enable"` } diff --git a/common/price_compoent/price_util.go b/common/price_compoent/price_util.go index 59ab5e5e..8fb3e99c 100644 --- a/common/price_compoent/price_util.go +++ b/common/price_compoent/price_util.go @@ -4,7 +4,6 @@ import ( "AAStarCommunity/EthPaymaster_BackService/common/global_const" "AAStarCommunity/EthPaymaster_BackService/config" "fmt" - "github.com/sirupsen/logrus" "golang.org/x/xerrors" "io" "io/ioutil" @@ -57,22 +56,12 @@ func GetPriceUsd(tokenType global_const.TokenType) (float64, error) { req.Header.Add("x-cg-demo-api-key", config.GetPriceOracleApiKey()) res, _ := http.DefaultClient.Do(req) - logrus.Debugf("get price req: %v", req) - logrus.Debugf("get price response: %v", res) - if res == nil { - return 0, xerrors.Errorf("get price error: %w", "response is nil") - } - if res.StatusCode != 200 { - return 0, xerrors.Errorf("get price error: %w", res.Status) - - } + defer res.Body.Close() body, _ := io.ReadAll(res.Body) bodystr := string(body) strarr := strings.Split(bodystr, ":") usdstr := strings.TrimRight(strarr[2], "}}") - defer res.Body.Close() - return strconv.ParseFloat(usdstr, 64) } diff --git a/rpc_server/middlewares/auth.go b/rpc_server/middlewares/auth.go index d8eabd66..8523f4ac 100644 --- a/rpc_server/middlewares/auth.go +++ b/rpc_server/middlewares/auth.go @@ -3,8 +3,8 @@ package middlewares import ( "AAStarCommunity/EthPaymaster_BackService/common/global_const" "AAStarCommunity/EthPaymaster_BackService/service/dashboard_service" + "errors" "github.com/gin-gonic/gin" - "github.com/sirupsen/logrus" "net/http" ) @@ -12,27 +12,45 @@ type ApiKey struct { Key string `form:"apiKey" json:"apiKey" binding:"required"` } -func AuthHandler() gin.HandlerFunc { +func ApiVerificationHandler() gin.HandlerFunc { return func(c *gin.Context) { apiKey := c.Query("apiKey") if apiKey == "" { - c.JSON(http.StatusForbidden, gin.H{"error": "ApiKey is mandatory, visit to https://dashboard.aastar.io for more detail."}) - c.Abort() + _ = c.AbortWithError(http.StatusForbidden, errors.New("ApiKey is mandatory, visit to https://dashboard.aastar.io for more detail")) return } apiModel, err := dashboard_service.GetAPiInfoByApiKey(apiKey) if err != nil { - logrus.Errorf("GetAPiInfoByApiKey err: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Can Not Find Your Api Key"}) - c.Abort() + _ = c.AbortWithError(http.StatusBadRequest, errors.New("can Not Find Your Api Key")) return } if apiModel.Disable { - c.JSON(http.StatusForbidden, gin.H{"error": "Api Key Is Disabled"}) - c.Abort() + _ = c.AbortWithError(http.StatusForbidden, errors.New("api Key Is Disabled")) + return + } + if !apiModel.PaymasterEnable { + _ = c.AbortWithError(http.StatusForbidden, errors.New("api Key Is Disabled Paymaster")) + return + } + if !VerifyRateLimit(*apiModel) { + _ = c.AbortWithError(http.StatusTooManyRequests, errors.New("too many requests")) return } + if apiModel.IPWhiteList != nil && apiModel.IPWhiteList.Cardinality() > 0 { + clientIp := c.ClientIP() + if !apiModel.IPWhiteList.Contains(clientIp) { + _ = c.AbortWithError(http.StatusForbidden, errors.New("ip not in whitelist")) + return + } + } + if apiModel.DomainWhitelist != nil && apiModel.DomainWhitelist.Cardinality() > 0 { + domain := c.Request.Host + if !apiModel.DomainWhitelist.Contains(domain) { + _ = c.AbortWithError(http.StatusForbidden, errors.New("domain not in whitelist")) + return + } + } c.Set(global_const.ContextKeyApiMoDel, apiModel) } } diff --git a/rpc_server/middlewares/pv_mertics.go b/rpc_server/middlewares/pv_mertics.go index ad60b112..ddbdc905 100644 --- a/rpc_server/middlewares/pv_mertics.go +++ b/rpc_server/middlewares/pv_mertics.go @@ -69,7 +69,14 @@ func PvMetrics() gin.HandlerFunc { metricsParam.ApiKey = apiKeyModel.ApiKey metricsParam.ApiUserId = apiKeyModel.UserId } - metricsPaymaster(c, metricsParam) + + extraMap := make(map[string]any) + clientIp := c.ClientIP() + extraMap["client_ip"] = clientIp + + domain := c.Request.Host + extraMap["client_domain"] = domain + metricsPaymaster(c, metricsParam, &extraMap) } else { return } @@ -86,7 +93,7 @@ func (w *CustomResponseWriter) Write(b []byte) (int, error) { w.body.Write(b) return w.ResponseWriter.Write(b) } -func metricsPaymaster(c *gin.Context, metricsParam PayMasterParam) { +func metricsPaymaster(c *gin.Context, metricsParam PayMasterParam, extraMap *map[string]any) { recallModel := dashboard_service.PaymasterRecallLogDbModel{ ProjectApikey: metricsParam.ApiKey, @@ -99,6 +106,13 @@ func metricsPaymaster(c *gin.Context, metricsParam PayMasterParam) { Status: metricsParam.Status, NetWork: metricsParam.NetWork, } + if extraMap != nil { + executeRestrictionJson, err := json.Marshal(extraMap) + if err != nil { + logrus.Error("executeRestrictionJson error:", err) + } + recallModel.Extra = executeRestrictionJson + } err := dashboard_service.CreatePaymasterCall(&recallModel) if err != nil { logrus.Error("CreatePaymasterCall error:", err) diff --git a/rpc_server/middlewares/rate_limit.go b/rpc_server/middlewares/rate_limit.go index fc1b6079..20b9d1d7 100644 --- a/rpc_server/middlewares/rate_limit.go +++ b/rpc_server/middlewares/rate_limit.go @@ -1,12 +1,8 @@ package middlewares import ( - "AAStarCommunity/EthPaymaster_BackService/common/global_const" "AAStarCommunity/EthPaymaster_BackService/common/model" - "errors" - "github.com/gin-gonic/gin" "golang.org/x/time/rate" - "net/http" ) const ( @@ -16,21 +12,10 @@ const ( var limiter map[string]*rate.Limiter -// RateLimiterByApiKeyHandler represents the rate limit by each ApiKey for each api calling -func RateLimiterByApiKeyHandler() gin.HandlerFunc { - return func(ctx *gin.Context) { - apiKeyModelInterface := ctx.MustGet(global_const.ContextKeyApiMoDel) - defaultLimit := DefaultLimit - apiKeyModel := apiKeyModelInterface.(*model.ApiKeyModel) - defaultLimit = apiKeyModel.RateLimit - - if limiting(&apiKeyModel.ApiKey, defaultLimit) { - ctx.Next() - } else { - _ = ctx.AbortWithError(http.StatusTooManyRequests, errors.New("too many requests")) - } - } +func VerifyRateLimit(keyModel model.ApiKeyModel) bool { + return limiting(&keyModel.ApiKey, keyModel.RateLimit) } + func clearLimiter(apiKey *string) { delete(limiter, *apiKey) } diff --git a/rpc_server/routers/boot.go b/rpc_server/routers/boot.go index a08235e8..db7c7d53 100644 --- a/rpc_server/routers/boot.go +++ b/rpc_server/routers/boot.go @@ -52,8 +52,7 @@ func buildRoute(routers *gin.Engine) { //build the routers not need api access like auth or Traffic limit buildRouters(routers, PublicRouterMaps) - routers.Use(middlewares.AuthHandler()) - routers.Use(middlewares.RateLimiterByApiKeyHandler()) + routers.Use(middlewares.ApiVerificationHandler()) buildRouters(routers, PrivateRouterMaps) } diff --git a/service/dashboard_service/dashboard_service.go b/service/dashboard_service/dashboard_service.go index 6b3fcbda..c0120232 100644 --- a/service/dashboard_service/dashboard_service.go +++ b/service/dashboard_service/dashboard_service.go @@ -234,13 +234,52 @@ func (*ApiKeyDbModel) TableName() string { func (m *ApiKeyDbModel) GetRateLimit() rate.Limit { return 10 } + +type APiModelExtra struct { + NetWorkLimitEnable bool `json:"network_limit_enable"` + DomainWhitelist []string `json:"domain_whitelist"` + IPWhiteList []string `json:"ip_white_list"` + PaymasterEnable bool `json:"paymaster_enable"` + Erc20PaymasterEnable bool `json:"erc20_paymaster_enable"` + ProjectSponsorPaymasterEnable bool `json:"project_sponsor_paymaster_enable"` + UserPayPaymasterEnable bool `json:"user_pay_paymaster_enable"` +} + func convertApiKeyDbModelToApiKeyModel(apiKeyDbModel *ApiKeyDbModel) *model.ApiKeyModel { - return &model.ApiKeyModel{ + apiKeyModel := &model.ApiKeyModel{ Disable: apiKeyDbModel.Disable, ApiKey: apiKeyDbModel.ApiKey, RateLimit: 10, UserId: apiKeyDbModel.UserId, } + if apiKeyDbModel.Extra != nil { + // convert To map + eJson, _ := apiKeyDbModel.Extra.MarshalJSON() + apiKeyExtra := &APiModelExtra{} + err := json.Unmarshal(eJson, apiKeyExtra) + if err != nil { + return nil + } + apiKeyModel.NetWorkLimitEnable = apiKeyExtra.NetWorkLimitEnable + if apiKeyExtra.IPWhiteList != nil { + apiKeyModel.IPWhiteList = mapset.NewSetWithSize[string](len(apiKeyExtra.IPWhiteList)) + for _, v := range apiKeyExtra.IPWhiteList { + apiKeyModel.IPWhiteList.Add(v) + } + } + if apiKeyExtra.DomainWhitelist != nil { + apiKeyModel.DomainWhitelist = mapset.NewSetWithSize[string](len(apiKeyExtra.DomainWhitelist)) + for _, v := range apiKeyExtra.DomainWhitelist { + apiKeyModel.DomainWhitelist.Add(v) + } + } + apiKeyModel.PaymasterEnable = apiKeyExtra.PaymasterEnable + apiKeyModel.Erc20PaymasterEnable = apiKeyExtra.Erc20PaymasterEnable + apiKeyModel.ProjectSponsorPaymasterEnable = apiKeyExtra.ProjectSponsorPaymasterEnable + apiKeyModel.UserPayPaymasterEnable = apiKeyExtra.UserPayPaymasterEnable + + } + return apiKeyModel } func GetAPiInfoByApiKey(apiKey string) (*model.ApiKeyModel, error) { apikeyModel := &ApiKeyDbModel{} @@ -262,8 +301,8 @@ type PaymasterRecallLogDbModel struct { PaymasterMethod string `gorm:"column:paymaster_method;type:varchar(25)" json:"paymaster_method"` SendTime string `gorm:"column:send_time;type:varchar(50)" json:"send_time"` Latency int64 `gorm:"column:latency;type:integer" json:"latency"` - RequestBody string `gorm:"column:request_body;type:varchar(500)" json:"request_body"` - ResponseBody string `gorm:"column:response_body;type:varchar(1000)" json:"response_body"` + RequestBody string `gorm:"column:request_body;type:varchar(2000)" json:"request_body"` + ResponseBody string `gorm:"column:response_body;type:varchar(2000)" json:"response_body"` NetWork string `gorm:"column:network;type:varchar(25)" json:"network"` Status int `gorm:"column:status;type:integer" json:"status"` Extra datatypes.JSON `gorm:"column:extra" json:"extra"` diff --git a/service/operator/try_pay_user_op_execute.go b/service/operator/try_pay_user_op_execute.go index 18fc79e6..252d1f3f 100644 --- a/service/operator/try_pay_user_op_execute.go +++ b/service/operator/try_pay_user_op_execute.go @@ -20,7 +20,7 @@ import ( ) func TryPayUserOpExecute(apiKeyModel *model.ApiKeyModel, request *model.UserOpRequest) (*model.TryPayUserOpResponse, error) { - userOp, strategy, paymasterDataInput, err := prepareExecute(request) + userOp, strategy, paymasterDataInput, err := prepareExecute(request, apiKeyModel) if err != nil { return nil, err } @@ -46,14 +46,14 @@ func TryPayUserOpExecute(apiKeyModel *model.ApiKeyModel, request *model.UserOpRe return result, nil } -func prepareExecute(request *model.UserOpRequest) (*user_op.UserOpInput, *model.Strategy, *paymaster_data.PaymasterDataInput, error) { +func prepareExecute(request *model.UserOpRequest, apiKeyModel *model.ApiKeyModel) (*user_op.UserOpInput, *model.Strategy, *paymaster_data.PaymasterDataInput, error) { var strategy *model.Strategy strategy, generateErr := StrategyGenerate(request) if generateErr != nil { return nil, nil, nil, generateErr } - if err := validator_service.ValidateStrategy(strategy, request); err != nil { + if err := validator_service.ValidateStrategy(strategy, request, apiKeyModel); err != nil { return nil, nil, nil, err } diff --git a/service/validator_service/basic_validator.go b/service/validator_service/basic_validator.go index db5c10ef..c9dbe65f 100644 --- a/service/validator_service/basic_validator.go +++ b/service/validator_service/basic_validator.go @@ -11,7 +11,7 @@ import ( "time" ) -func ValidateStrategy(strategy *model.Strategy, request *model.UserOpRequest) error { +func ValidateStrategy(strategy *model.Strategy, request *model.UserOpRequest, keyModel *model.ApiKeyModel) error { if strategy == nil { return xerrors.Errorf("empty strategy") } @@ -31,7 +31,7 @@ func ValidateStrategy(strategy *model.Strategy, request *model.UserOpRequest) er } if strategy.ExecuteRestriction == nil { - return xerrors.Errorf("ExecuteRestriction is Empty") + return nil } if strategy.ExecuteRestriction.Status != global_const.StrategyStatusAchive { return xerrors.Errorf("strategy status is not active") @@ -77,6 +77,21 @@ func ValidateStrategy(strategy *model.Strategy, request *model.UserOpRequest) er return xerrors.Errorf("strategy not support chainId [%s]", netWorkStr) } } + payType := strategy.GetPayType() + switch payType { + case global_const.PayTypeERC20: + if !keyModel.Erc20PaymasterEnable { + return xerrors.Errorf("strategy pay type is erc20 but not enable") + } + case global_const.PayTypeUserSponsor: + if !keyModel.UserPayPaymasterEnable { + return xerrors.Errorf("strategy pay type is user sponsor but not enable") + } + case global_const.PayTypeProjectSponsor: + if !keyModel.ProjectSponsorPaymasterEnable { + return xerrors.Errorf("strategy pay type is project sponsor but not enable") + } + } return nil diff --git a/service/validator_service/validator_test.go b/service/validator_service/validator_test.go index b619c8d2..448517e5 100644 --- a/service/validator_service/validator_test.go +++ b/service/validator_service/validator_test.go @@ -38,7 +38,9 @@ func TestValidatorService(t *testing.T) { } func testValidateStrategy(t *testing.T, strategy *model.Strategy, request *model.UserOpRequest) { - if err := ValidateStrategy(strategy, request); err != nil { + if err := ValidateStrategy(strategy, request, &model.ApiKeyModel{ + UserId: 5, + }); err != nil { t.Fatalf("ValidateStrategy error: %v", err) } }